diff --git a/evalbench/dataset/cortadoinput.py b/evalbench/dataset/cortadoinput.py
new file mode 100644
index 00000000..0eec1e3d
--- /dev/null
+++ b/evalbench/dataset/cortadoinput.py
@@ -0,0 +1,61 @@
+import json
+import copy
+
+
+class EvalCortadoRequest:
+    def __init__(self, raw_dict: dict, job_id: str = "", trace_id: str = ""):
+        """
+        Initializes an EvalCortadoRequest from a parsed JSON dictionary.
+        """
+        # Store the raw dictionary so process_scenario can read max_turns and the plan
+        self.scenario = raw_dict
+
+        # Extract top-level identification
+        self.id = str(raw_dict.get("id", "-1"))
+        self.job_id = job_id
+        self.trace_id = trace_id
+
+        # Evalbench core routing needs these to match the YAML
+        self.dialect = raw_dict.get("dialect", "")
+        self.dialects = [self.dialect]
+        self.database = raw_dict.get("database", "")
+        self.nl_prompt = raw_dict.get("starting_prompt", "")
+
+        # Ensure the stringified payload is ready for the gRPC Proxy
+        self.payload_str = json.dumps(raw_dict)
+        self.payload = self.payload_str
+
+        self.agent_results = []
+        self.scoring_results = []
+
+    @classmethod
+    def init_from_proto(cls, proto):
+        """Unpacks the Protobuf from Google3 back into the object."""
+        payload_str = getattr(proto, "payload", "{}")
+        try:
+            raw_dict = json.loads(payload_str)
+        except json.JSONDecodeError:
+            raw_dict = {}
+
+        raw_dict["id"] = str(getattr(proto, "id", "-1"))
+
+        return cls(
+            raw_dict=raw_dict,
+            job_id=getattr(proto, "job_id", ""),
+            trace_id=getattr(proto, "trace_id", ""),
+        )
+
+    def to_proto(self):
+        """Packs the object into the Protobuf to send to Google3."""
+        # Imported here to avoid a circular dependency with evalproto.
+        from evalproto import eval_request_pb2
+
+        return eval_request_pb2.EvalInputRequest(
+            id=int(self.id) if self.id.isdigit() else 0,
+            payload=self.payload_str,
+            # starting_prompt is mapped to nl_prompt for backwards compatibility
+            nl_prompt=self.nl_prompt,
+        )
+
+    def copy(self):
+        return copy.deepcopy(self)
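A minimal sketch (not part of the patch) of the round trip this class supports, assuming evalbench and its compiled protos are importable; the field names follow the constructor above and the values are illustrative:

```python
from dataset.cortadoinput import EvalCortadoRequest

scenario = {
    "id": "42",
    "dialect": "postgres",
    "database": "retail_db",
    "starting_prompt": "Show me last month's orders",
    "max_turns": 3,
    "conversation_plan": ["ask for orders", "filter by region"],
}

req = EvalCortadoRequest(raw_dict=scenario)
assert req.nl_prompt == "Show me last month's orders"
assert req.dialects == ["postgres"]

# to_proto() packs the stringified scenario; init_from_proto() reverses it.
proto = req.to_proto()
restored = EvalCortadoRequest.init_from_proto(proto)
assert restored.scenario["database"] == "retail_db"
```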
config["orchestrator"] = "interact" input_items = all_items + elif dataset_format == "cortado-format": + if "orchestrator" not in config: + config["orchestrator"] = "cortado" elif dataset_format in ("gemini-cli-format", "agent-format"): if "orchestrator" not in config: config["orchestrator"] = "agent" if dataset_format == "agent-format" else "geminicli" @@ -159,7 +182,7 @@ def load_dataset_from_json(json_file_path, config): else: raise ValueError("Dataset not in any of the recognised formats") - if dataset_format not in ["gemini-cli-format", "agent-format", "bird-interact-format"]: + if dataset_format not in ["gemini-cli-format", "bird-interact-format", "agent-format", "cortado-format"]: totalEntries = sum(len(input_items.get(q, [])) for q in ["dql", "dml", "ddl"]) logging.info(f"Converted {totalEntries} entries to EvalInput.") diff --git a/evalbench/eval_service.py b/evalbench/eval_service.py index 0382aa69..5b6c7408 100644 --- a/evalbench/eval_service.py +++ b/evalbench/eval_service.py @@ -12,6 +12,7 @@ import yaml import grpc import pathlib +import queue from dataset.dataset import load_json from dataset import evalinput from evaluator import get_orchestrator, get_streaming_orchestrator @@ -33,11 +34,13 @@ load_session_configs, get_dataset_from_request, ) +from generators.models.grpc_proxy import PROXY_QUEUES import threading from util.context import rpc_id_var from util import get_SessionManager + SESSIONMANAGER = get_SessionManager() @@ -85,6 +88,7 @@ async def Connect( session = SESSIONMANAGER.get_session(session_id) if session is not None: session["streaming_eval"] = request.streaming_eval + session["bidirectional_stream"] = request.bidirectional_stream return eval_response_pb2.EvalResponse(response="ack", session_id=session_id) async def EvalConfig( @@ -133,7 +137,8 @@ async def Eval( try: session_id = rpc_id_var.get() session = SESSIONMANAGER.get_session(session_id) - config, db_configs, model_config, setup_config = load_session_configs(session) + config, db_configs, model_config, setup_config = load_session_configs( + session) if config is None: context.set_code(grpc.StatusCode.FAILED_PRECONDITION) context.set_details("Session not configured") @@ -145,12 +150,15 @@ async def Eval( set_up_script = config.get("set_up_script") if set_up_script: if os.path.exists(set_up_script): - logging.info(f"Eval: Executing set_up_script '{set_up_script}'") + logging.info( + f"Eval: Executing set_up_script '{set_up_script}'") run_script(set_up_script, session_dir, "setup") else: - logging.error(f"Eval: Cannot run set_up_script, file not found at '{set_up_script}'") + logging.error( + f"Eval: Cannot run set_up_script, file not found at '{set_up_script}'") - streaming_eval = session.get("streaming_eval", False) if session else False + streaming_eval = session.get( + "streaming_eval", False) if session else False loop = asyncio.get_event_loop() if streaming_eval: @@ -177,7 +185,8 @@ async def Eval( evaluator = get_orchestrator( config, db_configs, setup_config, report_progress=True ) - logging.info("Batch eval mode: evaluating all items together...") + logging.info( + "Batch eval mode: evaluating all items together...") ctx = contextvars.copy_context() await loop.run_in_executor( None, ctx.run, evaluator.evaluate, dataset @@ -219,10 +228,12 @@ async def Eval( tear_down_script = config.get("tear_down_script") if tear_down_script: if os.path.exists(tear_down_script): - logging.info(f"Eval: Executing tear_down_script '{tear_down_script}'") + logging.info( + f"Eval: Executing tear_down_script 
'{tear_down_script}'") run_script(tear_down_script, session_dir, "teardown") else: - logging.error(f"Eval: Cannot run tear_down_script, file not found at '{tear_down_script}'") + logging.error( + f"Eval: Cannot run tear_down_script, file not found at '{tear_down_script}'") return eval_response_pb2.EvalResponse(response=response, session_id=session_id) @@ -230,17 +241,203 @@ async def Eval( display_config = "Unknown" # Attempt retrieval of configuration details if successfully loaded try: - loaded_config = SESSIONMANAGER.get_session(rpc_id_var.get()).get("config", {}) + loaded_config = SESSIONMANAGER.get_session( + rpc_id_var.get()).get("config", {}) cand = loaded_config.get("dataset_config", "Unknown") g3_idx = cand.find("google3/") display_config = cand[g3_idx:] if g3_idx != -1 else cand except Exception as e_ctx: # Best effort retrieval of metadata for tracing. Do not mask original fault. - logging.debug(f"Unable to determine active dataset path for log context: {e_ctx}") + logging.debug( + f"Unable to determine active dataset path for log context: {e_ctx}") - logging.exception(f"gRPC Eval failed for config/dataset '{display_config}': {e}") + logging.exception( + f"gRPC Eval failed for config/dataset '{display_config}': {e}") raise + async def Interact( + self, + request_iterator: AsyncIterator[eval_request_pb2.EvalInputRequest], + context: grpc.ServicerContext, + ) -> AsyncGenerator[eval_request_pb2.EvalInputRequest, None]: + """Bidirectional stream linking Google3 Agents to Evalbench Orchestrators.""" + + session_id = rpc_id_var.get() + session = SESSIONMANAGER.get_session(session_id) + config, db_configs, model_config, setup_config = load_session_configs( + session) + + if config is None: + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("Session not configured") + return + + is_bidirectional = session.get( + "bidirectional_stream", False) if session else False + + if not is_bidirectional: + error_msg = ( + "Interact must be used with bidirectional streaming" + ) + logging.error(error_msg) + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details(error_msg) + return + logging.info("Starting a bidirectional Interact stream...") + + config["session_id"] = session_id + + # Create thread-safe queues + in_queue = {} # Google3 -> Evalbench (mapped by conversation_id) + out_queue = queue.Queue() # Evalbench -> Google3 + + config["grpc_in_queues"] = in_queue + config["grpc_out_queue"] = out_queue + logging.info(f"CONFIG: {config}") + generator = model_config.get("generator") + + if generator != "grpc_proxy": + error_msg = ( + "Interactive evaluation failed: must use 'grpc_proxy' generator" + ) + logging.error(error_msg) + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details(error_msg) + return + + # Load dataset and instantiate the Orchestrator + dataset_config_json = config["dataset_config"] + dataset_dict = load_dataset_from_json(dataset_config_json, config) + + dataset = [] + for _, item_list in dataset_dict.items(): + dataset.extend(item_list) + + num_evals = config.get("num_evals_to_run") + if num_evals and int(num_evals) > 0: + dataset = dataset[:int(num_evals)] + + orchestrator = get_orchestrator( + config, db_configs, setup_config, report_progress=True) + loop = asyncio.get_event_loop() + ctx = contextvars.copy_context() + + try: + PROXY_QUEUES[session_id] = (in_queue, out_queue) + + def _cleanup_on_drop(ctx): + if session_id in PROXY_QUEUES: + PROXY_QUEUES.pop(session_id, None) + logging.info( + f"Cleaned up proxy 
diff --git a/evalbench/evalproto/eval_connect.proto b/evalbench/evalproto/eval_connect.proto
index 390c757b..6319359a 100644
--- a/evalbench/evalproto/eval_connect.proto
+++ b/evalbench/evalproto/eval_connect.proto
@@ -7,4 +7,5 @@ option java_multiple_files = true;
 message EvalConnectRequest {
   string client_id = 1;
   bool streaming_eval = 2;
+  bool bidirectional_stream = 3;
 }
diff --git a/evalbench/evalproto/eval_request.proto b/evalbench/evalproto/eval_request.proto
index 3aec80bf..872c983e 100644
--- a/evalbench/evalproto/eval_request.proto
+++ b/evalbench/evalproto/eval_request.proto
@@ -32,6 +32,8 @@ message EvalInputRequest {
   string job_id = 15;
   string trace_id = 16;
   string payload = 17;
+  string conversation_id = 18;
+  string generated_nl_response = 19;
 }
 
 message UserAction {
diff --git a/evalbench/evalproto/eval_service.proto b/evalbench/evalproto/eval_service.proto
index 60066053..490ce6d0 100644
--- a/evalbench/evalproto/eval_service.proto
+++ b/evalbench/evalproto/eval_service.proto
@@ -35,6 +35,11 @@ service EvalService {
     // option deadline = 1800;
   }
 
+  // Interact: multi-turn evaluation over a bidirectional stream.
+  rpc Interact(stream EvalInputRequest) returns (stream EvalInputRequest) {
+    // option deadline = 1800;
+  }
+
   // PrepareCodeEvalInputs for NL2Code Evaluation
   rpc PrepareCodeEvalInputs(EvalCodeInputRequest) returns (stream EvalCodeInputRequest) {
     // option deadline = 1800;
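A quick sketch of how the two new `EvalInputRequest` fields travel on the wire (values illustrative; assumes the compiled protos are importable):

```python
from evalproto import eval_request_pb2

# Server -> client: the proxy tags each prompt with the conversation it belongs to.
prompt_msg = eval_request_pb2.EvalInputRequest(
    conversation_id="42",
    nl_prompt="Which customers ordered in March?",
)

# Client -> server: the reply must carry the same conversation_id so the
# Interact handler can route it back to the blocked evaluator thread.
reply_msg = eval_request_pb2.EvalInputRequest(
    conversation_id=prompt_msg.conversation_id,
    generated_nl_response="Here are the March customers...",
)
```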
diff --git a/evalbench/evaluator/__init__.py b/evalbench/evaluator/__init__.py
index 21ba34e8..0898262b 100644
--- a/evalbench/evaluator/__init__.py
+++ b/evalbench/evaluator/__init__.py
@@ -1,5 +1,6 @@
 from evaluator.orchestrator import Orchestrator
 from evaluator.oneshotorchestrator import OneShotOrchestrator
+from evaluator.cortadoorchestrator import CortadoOrchestrator
 from evaluator.interactorchestrator import InteractOrchestrator
 from evaluator.dataagentorchestrator import DataAgentOrchestrator
 from evaluator.agentorchestrator import AgentOrchestrator
@@ -9,7 +10,6 @@
 
 def get_orchestrator(config, db_configs, setup_config, report_progress=False):
     orchestrator_type = config.get("orchestrator", "oneshot")
-    logging.info(f"Orchestrator Type: {orchestrator_type}")
     if orchestrator_type == "oneshot":
         return OneShotOrchestrator(config, db_configs, setup_config, report_progress)
     elif orchestrator_type == "interact":
@@ -18,6 +18,8 @@ def get_orchestrator(config, db_configs, setup_config, report_progress=False):
         return DataAgentOrchestrator(config, db_configs, setup_config, report_progress)
     elif orchestrator_type in ("geminicli", "agent"):
         return AgentOrchestrator(config, db_configs, setup_config, report_progress)
+    elif orchestrator_type == "cortado":
+        return CortadoOrchestrator(config, db_configs, setup_config, report_progress)
     else:
         return Orchestrator(config, db_configs, setup_config, report_progress)
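A minimal sketch of wiring a config through `get_orchestrator` (the empty `db_configs`/`setup_config` are illustrative; note that running `evaluate` blocks on the gRPC proxy, so a client must be attached via the Interact stream):

```python
from dataset.cortadoinput import EvalCortadoRequest
from evaluator import get_orchestrator

config = {
    "orchestrator": "cortado",  # set automatically for cortado-format datasets
    "generator": "grpc_proxy",  # CortadoEvaluator rejects anything else
}
dataset = [
    EvalCortadoRequest(raw_dict={"id": "1", "starting_prompt": "hi", "dialect": "postgres"})
]

orchestrator = get_orchestrator(config, db_configs={}, setup_config={})
orchestrator.evaluate(dataset)  # raises RuntimeError if no Interact stream registered queues
job_id, run_time, results_tf, scores_tf = orchestrator.process()
```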
diff --git a/evalbench/evaluator/cortadoevaluator.py b/evalbench/evaluator/cortadoevaluator.py
new file mode 100644
index 00000000..64ceb803
--- /dev/null
+++ b/evalbench/evaluator/cortadoevaluator.py
@@ -0,0 +1,184 @@
+# cortadoevaluator.py
+
+from typing import Any, List, Dict
+import datetime
+import concurrent.futures
+import logging
+import json
+
+from dataset.cortadoinput import EvalCortadoRequest
+from generators.models.grpc_proxy import GrpcProxyModel
+from util.config import load_yaml_config
+from mp import mprunner
+from work.agentgenwork import AgentGenWork
+from evaluator.simulateduser import SimulatedUser
+from work.agentscorework import AgentScoreWork
+
+
+class CortadoEvaluator:
+    def __init__(self, config):
+        self.config = config
+
+        # Load the model config; a string value is a path to a YAML file.
+        model_config = config
+        if "model_config" in config and isinstance(config["model_config"], str):
+            loaded_config = load_yaml_config(config["model_config"])
+            model_config = loaded_config.copy()
+            model_config.update(config)
+
+        generator_type = model_config.get("generator")
+        if generator_type == "grpc_proxy":
+            self.generator = GrpcProxyModel(model_config)
+        else:
+            raise ValueError(
+                f"CortadoEvaluator requires 'grpc_proxy' generator, got {generator_type}")
+
+        runner_config = self.config.get("runners", {})
+        self.agent_runners = runner_config.get("agent_runners", 10)
+        self.agentrunner = mprunner.MPRunner(self.agent_runners)
+
+    def evaluate(self, dataset: List[EvalCortadoRequest], job_id: str, run_time: datetime.datetime):
+        eval_outputs: List[Any] = []
+        scoring_results: List[Any] = []
+        logging.info("Running Cortado gRPC evaluation")
+
+        self.agentrunner.futures.clear()
+
+        metadata = {
+            "dialects": self.config.get("dialects", []),
+            "database": self.config.get("database", "unknown"),
+            "scorers": self.config.get("scorers", {}),
+        }
+
+        # Spin up threads for concurrent conversation processing
+        for item in dataset:
+            simulated_user = SimulatedUser(self.config)
+            work = AgentGenWork(
+                processor=self.process_scenario,
+                eval_result=item,
+                job_id=job_id,
+                metadata=metadata,
+                simulated_user=simulated_user,
+            )
+            self.agentrunner.execute_work(work)
+
+        for future in concurrent.futures.as_completed(self.agentrunner.futures):
+            try:
+                # The future resolves to the item returned by process_scenario
+                modified_item = future.result()
+                if hasattr(modified_item, "agent_results"):
+                    eval_outputs.extend(modified_item.agent_results)
+                if hasattr(modified_item, "scoring_results"):
+                    scoring_results.extend(modified_item.scoring_results)
+            except Exception as e:
+                logging.error(
+                    f"Error getting result from future: {e}", exc_info=True)
+
+        return eval_outputs, scoring_results
+
+    def process_scenario(
+        self, scenario: Dict[str, Any], eval_result: Any, job_id: str,
+        metadata: Dict[str, Any], simulated_user: Any = None
+    ) -> Any:
+        """Runs the conversation between Cortado and the Simulated User."""
+
+        current_prompt = scenario.get("starting_prompt", "")
+        max_turns = scenario.get("max_turns", 1)
+        conversation_plan = scenario.get("conversation_plan", [])
+        conversation_history = []
+        last_agent_text = ""
+        last_sql_reply = ""
+
+        # Parity tracking lists
+        accumulated_tools = []
+        accumulated_skills = []
+
+        for turn in range(max_turns):
+            logging.info(
+                f"Turn {turn + 1}/{max_turns} - Prompt: {current_prompt}")
+
+            # Inject the current prompt into the object
+            eval_result.nl_prompt = current_prompt
+
+            # Hand it to the gRPC Proxy (blocks until the client replies)
+            agent_text = ""
+            try:
+                self.generator.generate(eval_result)
+
+                nl_reply = getattr(eval_result, "generated_nl_response", "")
+                sql_reply = getattr(eval_result, "generated_sql", "")
+                last_sql_reply = sql_reply
+                agent_text = nl_reply
+
+            except Exception as e:
+                logging.error(f"gRPC generation failed: {e}", exc_info=True)
+                agent_text = f"Error: {e}"
+                last_sql_reply = ""
+
+            last_agent_text = agent_text
+            logging.info(
+                f"Turn {turn + 1}/{max_turns} - Agent Reply to Simulated User: {agent_text}")
+
+            # Log history
+            conversation_history.append({
+                "user": current_prompt,
+                "agent": agent_text,
+            })
+
+            # Invoke the Simulated User to check the plan and generate the next turn
+            if turn < max_turns - 1 and simulated_user:
+                next_response = simulated_user.get_next_response(
+                    conversation_plan, conversation_history, agent_text
+                )
+                if "TERMINATE" in next_response:
+                    logging.info(
+                        "Simulated user met the goal and terminated the conversation.")
+                    break
+                current_prompt = next_response
+            else:
+                break
+
+        # Finalize and Score
+        self._finalize_scenario(
+            scenario, last_agent_text, conversation_history,
+            accumulated_tools, accumulated_skills,
+            eval_result, job_id, metadata,
+            last_sql_reply,
+        )
+        return eval_result
+
+    def _finalize_scenario(
+        self, scenario: Dict[str, Any], last_response: str,
+        conversation_history: List[Dict[str, str]],
+        accumulated_tools: List[str], accumulated_skills: List[str],
+        eval_result: Any, job_id: str, metadata: Dict[str, Any],
+        last_sql: str,
+    ):
+        """Packages the conversation and sends it to the scoring engine."""
+
+        eval_output_data = {
+            "eval_id": scenario["id"],
+            "stdout": last_response,  # The text seen by the simulated user
+            "stderr": "",
+            "returncode": 0 if not last_response.startswith("Error") else 1,
+            "prompt_generator_error": None,
+            "generated_error": None,
+            "sql_generator_error": None,
+            "golden_error": None,
+            "generated_sql": last_sql,
+            "prompt": scenario["starting_prompt"],
+            "conversation_history": json.dumps(conversation_history, indent=2),
+            "scenario": scenario,
+            "accumulated_tools": accumulated_tools,  # Empty list for now
+            "accumulated_skills": accumulated_skills,  # Empty list for now
+            "job_id": job_id,
+            "metadata": metadata,
+        }
+
+        score_work = AgentScoreWork(
+            config=self.config,
+            eval_output=eval_output_data,
+            scoring_results=eval_result.scoring_results,
+        )
+        score_work.run()
+        eval_result.agent_results.append(eval_output_data)
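The turn loop only assumes two things of `SimulatedUser`: a `get_next_response(plan, history, last_agent_text)` method, and the literal `TERMINATE` marker in its output once the plan is satisfied. A minimal stand-in that satisfies that contract (illustrative, not the real class):

```python
class ScriptedUser:
    """Replays a fixed conversation plan, then terminates."""

    def __init__(self, config=None):
        self._step = 0

    def get_next_response(self, conversation_plan, conversation_history, agent_text):
        if self._step >= len(conversation_plan):
            return "TERMINATE"  # the evaluator breaks the loop on this marker
        prompt = conversation_plan[self._step]
        self._step += 1
        return prompt
```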
"user": current_prompt, + "agent": agent_text + }) + + # Invoke Simulated User to check plan and generate next turn + if turn < max_turns - 1 and simulated_user: + next_response = simulated_user.get_next_response( + conversation_plan, conversation_history, agent_text + ) + if "TERMINATE" in next_response: + logging.info( + "Simulated user met the goal and terminated the conversation.") + break + current_prompt = next_response + else: + break + + # Finalize and Score + self._finalize_scenario( + scenario, last_agent_text, conversation_history, + accumulated_tools, accumulated_skills, + eval_result, job_id, metadata, + last_sql_reply + ) + return eval_result + + def _finalize_scenario( + self, scenario: Dict[str, Any], last_response: str, + conversation_history: List[Dict[str, str]], + accumulated_tools: List[str], accumulated_skills: List[str], + eval_result: Any, job_id: str, metadata: Dict[str, Any], + last_sql: str + ): + """Packages the conversation and sends it to the scoring engine.""" + + eval_output_data = { + "eval_id": scenario["id"], + "stdout": last_response, # This is the text seen by the simulated user + "stderr": "", + "returncode": 0 if not last_response.startswith("Error") else 1, + "prompt_generator_error": None, + "generated_error": None, + "sql_generator_error": None, + "golden_error": None, + "generated_sql": last_sql, + "prompt": scenario["starting_prompt"], + "conversation_history": json.dumps(conversation_history, indent=2), + "scenario": scenario, + "accumulated_tools": accumulated_tools, # Passes empty list for now + "accumulated_skills": accumulated_skills, # Passes empty list for now + "job_id": job_id, + "metadata": metadata + } + + score_work = AgentScoreWork( + config=self.config, + eval_output=eval_output_data, + scoring_results=eval_result.scoring_results + ) + score_work.run() + eval_result.agent_results.append(eval_output_data) diff --git a/evalbench/evaluator/cortadoorchestrator.py b/evalbench/evaluator/cortadoorchestrator.py new file mode 100644 index 00000000..968cd933 --- /dev/null +++ b/evalbench/evaluator/cortadoorchestrator.py @@ -0,0 +1,37 @@ +from evaluator.orchestrator import Orchestrator +import uuid +import datetime +import tempfile +import json +from dataset.cortadoinput import EvalCortadoRequest +from evaluator.cortadoevaluator import CortadoEvaluator + + +class CortadoOrchestrator(Orchestrator): + def __init__(self, config, db_configs, setup_config, report_progress=False): + self.config = config + self.db_configs = db_configs + self.setup_config = setup_config + self.job_id = f"{uuid.uuid4()}" + self.run_time = datetime.datetime.now() + self.total_eval_outputs = [] + self.total_scoring_results = [] + + def evaluate(self, dataset: list[EvalCortadoRequest]): + evaluator = CortadoEvaluator(self.config) + eval_outputs, scoring_results = evaluator.evaluate( + dataset, self.job_id, self.run_time + ) + self.total_eval_outputs.extend(eval_outputs) + self.total_scoring_results.extend(scoring_results) + + def process(self): + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f: + json.dump(self.total_eval_outputs, f, + sort_keys=True, indent=4, default=str) + results_tf = f.name + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f: + json.dump(self.total_scoring_results, f, + sort_keys=True, indent=4, default=str) + scores_tf = f.name + return self.job_id, self.run_time, results_tf, scores_tf diff --git a/evalbench/generators/models/__init__.py b/evalbench/generators/models/__init__.py index 
af99514b..61913cde 100644
--- a/evalbench/generators/models/__init__.py
+++ b/evalbench/generators/models/__init__.py
@@ -3,6 +3,7 @@
 from generators.models.generator import QueryGenerator
 from .gemini import GeminiGenerator
 from .passthrough import NOOPGenerator
+from .grpc_proxy import GrpcProxyModel
 from .claude import ClaudeGenerator
 from .querydata import QueryData
 from .query_data_api import QueryDataAPIGenerator
@@ -33,6 +34,8 @@ def get_generator(global_models, model_config_path: str, db: DB = None):
         model = QueryData(config)
     if config["generator"] == "query_data_api":
         model = QueryDataAPIGenerator(config)
+    if config["generator"] == "grpc_proxy":
+        model = GrpcProxyModel(config)
     if config["generator"] == "gemini_cli":
         model = GeminiCliGenerator(config)
     if config["generator"] == "claude_code":
diff --git a/evalbench/generators/models/grpc_proxy.py b/evalbench/generators/models/grpc_proxy.py
new file mode 100644
index 00000000..78e20b55
--- /dev/null
+++ b/evalbench/generators/models/grpc_proxy.py
@@ -0,0 +1,123 @@
+"""
+This module defines the GrpcProxyModel, which acts as a client-side interface
+within the Evalbench system to an external generator running in Google3.
+
+The GrpcProxyModel does not perform any generation itself. Instead, it marshals
+requests from the Evalbench evaluator (e.g., CortadoEvaluator), sends them
+over a bidirectional gRPC stream to the Google3 orchestrator, and waits for
+the results to be sent back.
+"""
+
+import threading
+import queue
+import logging
+import traceback
+from generators.models.generator import QueryGenerator
+from evalproto import eval_request_pb2
+
+# Maps session_id -> (in_queues_dict, out_queue); populated by eval_service.Interact.
+PROXY_QUEUES = {}
+
+
+class GrpcProxyModel(QueryGenerator):
+    def __init__(self, config):
+        super().__init__(config)
+        self.name = "grpc_proxy"
+
+    def generate_internal(self, prompt: str) -> str:
+        # Unused in the bidirectional flow; generation happens in generate().
+        pass
+
+    def generate(self, eval_output):
+        """
+        Proxies the request to Google3 and waits for a response.
+        eval_output is typically an instance of EvalCortadoRequest.
+        """
+        conv_id = None
+        in_queues_dict = None
+        thread_id = threading.get_ident()
+        try:
+            if not PROXY_QUEUES:
+                raise RuntimeError("PROXY_QUEUES is empty!")
+
+            # Assuming a single session per proxy for now
+            session_id = list(PROXY_QUEUES.keys())[0]
+            in_queues_dict, out_queue = PROXY_QUEUES[session_id]
+
+            def get_val(obj, *keys, default=None):
+                for k in keys:
+                    if hasattr(obj, "get") and callable(obj.get):
+                        val = obj.get(k)
+                        if val is not None:
+                            return val
+                    if hasattr(obj, k):
+                        val = getattr(obj, k)
+                        if val is not None:
+                            return val
+                return default
+
+            prompt_text = get_val(eval_output, "nl_prompt", default="")
+            database = get_val(eval_output, "database",
+                               "db_id", "database_name", default="")
+            dialects = get_val(eval_output, "dialects", "dialect", default=[])
+            if isinstance(dialects, str):
+                dialects = [dialects]
+            query_type = get_val(eval_output, "query_type", default="")
+            scenario_payload = get_val(
+                eval_output, "payload_str", "payload", default="{}")
+            item_id_str = str(
+                get_val(eval_output, "id", "eval_id", default=thread_id))
+            conv_id = item_id_str
+
+            thread_inbox = queue.Queue()
+            if conv_id in in_queues_dict:
+                logging.warning(
+                    f"[TRACE] Proxy[Thread-{thread_id}]: conv_id {conv_id} already exists in in_queue. Overwriting.")
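            in_queues_dict[conv_id] = thread_inbox

            outbound_req = eval_request_pb2.EvalInputRequest(
                conversation_id=conv_id,
                nl_prompt=prompt_text,
                database=database,
                dialects=dialects,
                query_type=query_type,
                payload=scenario_payload,
            )

            logging.debug(
                f"[DEBUG] Routing prompt to client. conv_id: {conv_id}")
            out_queue.put(outbound_req)

            logging.debug(
                f"[TRACE] Blocked and waiting for client reply on {conv_id}...")
            inbound_response: eval_request_pb2.EvalInputRequest = thread_inbox.get(
                block=True, timeout=300.0)
            logging.debug(
                f"[TRACE] Received reply from client for conv_id {conv_id}!")

            # Extract fields from the received proto
            nl_response = getattr(
                inbound_response, "generated_nl_response", "")
            sql_response = getattr(inbound_response, "generated_sql", "")

            # Update the eval_output object with the results from the client.
            if hasattr(eval_output, "__setitem__"):
                eval_output["generated_sql"] = sql_response
                eval_output["generated_nl_response"] = nl_response
            else:
                setattr(eval_output, "generated_sql", sql_response)
                setattr(eval_output, "generated_nl_response", nl_response)

            return eval_output

        except queue.Empty:
            logging.error(f"[ERROR] Client TIMEOUT on {conv_id}")
            raise TimeoutError(
                f"Client disconnected or timed out on conv_id {conv_id}.")
        except Exception as e:
            logging.error(
                f"[ERROR] Proxy crashed on conv_id {conv_id}: {e}\n{traceback.format_exc()}")
            raise
        finally:
            if in_queues_dict is not None and conv_id is not None:
                in_queues_dict.pop(conv_id, None)
                logging.debug(f"[DEBUG] Cleaned up inbox for {conv_id}")
```

The proxy's blocking handoff reduces to a per-conversation inbox plus one shared outbox. A self-contained miniature of that rendezvous, with no gRPC, just the queue mechanics the module relies on:

```python
import queue
import threading

in_queues: dict[str, queue.Queue] = {}  # conversation_id -> inbox
out_queue: queue.Queue = queue.Queue()  # shared outbox to the client


def evaluator_thread(conv_id: str) -> str:
    inbox = queue.Queue()
    in_queues[conv_id] = inbox               # register inbox BEFORE publishing
    out_queue.put((conv_id, "prompt"))       # analogous to out_queue.put(outbound_req)
    try:
        return inbox.get(timeout=5.0)        # blocks like thread_inbox.get(...)
    finally:
        in_queues.pop(conv_id, None)         # mirrors the finally-block cleanup


def client_loop():
    conv_id, _prompt = out_queue.get()
    in_queues[conv_id].put("reply")          # routed by conversation_id


t = threading.Thread(target=client_loop)
t.start()
print(evaluator_thread("42"))  # -> "reply"
t.join()
```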
Overwriting.") + in_queues_dict[conv_id] = thread_inbox + + outbound_req = eval_request_pb2.EvalInputRequest( + conversation_id=conv_id, + nl_prompt=prompt_text, + database=database, + dialects=dialects, + query_type=query_type, + payload=scenario_payload + ) + + logging.debug( + f"[DEBUG] Routing prompt to client. conv_id: {conv_id}") + out_queue.put(outbound_req) + + logging.debug( + f"[TRACE] Blocked and waiting for client reply on {conv_id}...") + inbound_response: eval_request_pb2.EvalInputRequest = thread_inbox.get( + block=True, timeout=300.0) + logging.debug( + f"[TRACE] Received Reply from client for conv_id {conv_id}!") + + # Extract fields from the received proto + nl_response = getattr( + inbound_response, "generated_nl_response", "") + sql_response = getattr(inbound_response, "generated_sql", "") + + # Update the eval_output object with the results from the client. + if hasattr(eval_output, "__setitem__"): + eval_output["generated_sql"] = sql_response + eval_output["generated_nl_response"] = nl_response + else: + setattr(eval_output, "generated_sql", sql_response) + setattr(eval_output, "generated_nl_response", nl_response) + + return eval_output + + except queue.Empty: + logging.error(f"[ERROR] Client TIMEOUT on {conv_id}") + raise TimeoutError( + f"Client disconnected or timed out on conv_id {conv_id}.") + except Exception as e: + logging.error( + f"[ERROR] crashed hard on conv_id {conv_id}: {e}\n{traceback.format_exc()}") + raise e + finally: + if in_queues_dict is not None and conv_id is not None: + in_queues_dict.pop(conv_id, None) + logging.debug(f"[DEBUG] Cleaned up inbox for {conv_id}") diff --git a/evalbench/work/agentgenwork.py b/evalbench/work/agentgenwork.py index a17743eb..5b5e552c 100644 --- a/evalbench/work/agentgenwork.py +++ b/evalbench/work/agentgenwork.py @@ -41,9 +41,21 @@ def run(self, work_config: Any = None) -> Any: try: eval_set = json.loads(eval_result.payload) - for scenario in eval_set.get("scenarios", []): + + if "scenarios" in eval_set: + # Loop through the scenarios array + for scenario in eval_set["scenarios"]: + self.processor( + scenario, + self.eval_result, + self.job_id, + self.metadata, + self.simulated_user + ) + else: + # When payload is the scenario self.processor( - scenario, + eval_set, self.eval_result, self.job_id, self.metadata, diff --git a/viewer/version.txt b/viewer/version.txt new file mode 100644 index 00000000..41c5857d --- /dev/null +++ b/viewer/version.txt @@ -0,0 +1 @@ +01802cf