Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions evalbench/dataset/cortadoinput.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import json
import copy


class EvalCortadoRequest:
    """A single Cortado evaluation scenario, convertible to/from the gRPC proto.

    Wraps the raw scenario dictionary parsed from a Cortado dataset JSON file
    and exposes the fields Evalbench core routing expects (dialect, database,
    nl_prompt), plus the stringified payload forwarded to the gRPC proxy.
    """

    def __init__(self, raw_dict: dict, job_id: str = "", trace_id: str = ""):
        """Initializes an EvalCortadoRequest from a parsed JSON dictionary.

        Args:
            raw_dict: The scenario dictionary. Reads the optional keys "id",
                "dialect", "database" and "starting_prompt"; everything else
                is preserved verbatim in `self.scenario`.
            job_id: Optional job identifier propagated from the proto.
            trace_id: Optional trace identifier propagated from the proto.
        """
        # Store the raw dictionary so process_scenario can read max_turns and the plan
        self.scenario = raw_dict

        # Extract top-level identification
        self.id = str(raw_dict.get("id", "-1"))
        self.job_id = job_id
        self.trace_id = trace_id

        # Evalbench core routing needs these to match the YAML
        self.dialect = raw_dict.get("dialect", "")
        self.dialects = [self.dialect]
        self.database = raw_dict.get("database", "")
        self.nl_prompt = raw_dict.get("starting_prompt", "")

        # Ensure the stringified payload is ready for the gRPC Proxy.
        # `payload` is kept as an alias of `payload_str` so callers may use
        # either attribute name.
        self.payload_str = json.dumps(raw_dict)
        self.payload = self.payload_str

        # Populated later by the orchestrator / scorers.
        self.agent_results = []
        self.scoring_results = []

    @classmethod
    def init_from_proto(cls, proto):
        """Unpacks the Protobuf from Google3 back into the object."""
        payload_str = getattr(proto, "payload", "{}")
        try:
            raw_dict = json.loads(payload_str)
        except json.JSONDecodeError:
            raw_dict = {}

        # json.loads may legally return a non-dict value (list, string,
        # number); fall back to an empty dict so the "id" assignment below
        # cannot raise TypeError on a malformed payload.
        if not isinstance(raw_dict, dict):
            raw_dict = {}

        raw_dict["id"] = str(getattr(proto, "id", "-1"))

        return cls(
            raw_dict=raw_dict,
            job_id=getattr(proto, "job_id", ""),
            trace_id=getattr(proto, "trace_id", ""),
        )

    def to_proto(self):
        """Packs the object into the Protobuf to send to Google3."""
        # Note: You must import eval_request_pb2 here to prevent circular dependencies
        from evalproto import eval_request_pb2

        return eval_request_pb2.EvalInputRequest(
            id=int(self.id) if self.id.isdigit() else 0,
            payload=self.payload_str,
            # We map starting_prompt to nl_prompt for backwards compatibility
            nl_prompt=self.nl_prompt,
        )

    def copy(self):
        """Returns an independent deep copy of this request."""
        return copy.deepcopy(self)
25 changes: 24 additions & 1 deletion evalbench/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataset.evalinput import EvalInputRequest
from dataset.evalinteractinput import EvalInteractInputRequest
from dataset.evalgeminicliinput import EvalGeminiCliRequest
from dataset.cortadoinput import EvalCortadoRequest
from itertools import chain
import os

Expand Down Expand Up @@ -95,6 +96,23 @@ def load_bird_interact_dataset(json_file_path, config):
return input_items


def load_cortado_json(json_file_path):
    """Loads a Cortado-format dataset from a JSON file.

    Args:
        json_file_path: Path to a JSON file whose top level contains a
            "scenarios" list of scenario dictionaries.

    Returns:
        A dict with the single key "cortado-format" mapping to one
        EvalCortadoRequest per scenario (an empty list when "scenarios"
        is absent).
    """
    with open(json_file_path, "r") as json_file:
        data = json.load(json_file)

    # One request per scenario; original code used the .extend([x]) anti-idiom
    # to append single items — a comprehension says the same thing directly.
    # A missing "scenarios" key yields an empty dataset rather than raising.
    requests: list[EvalCortadoRequest] = [
        EvalCortadoRequest(raw_dict=scenario)
        for scenario in data.get("scenarios", [])
    ]
    return {"cortado-format": requests}


def load_gemini_cli_json(json_file_path):
all_items: dict[str, list[EvalGeminiCliRequest]] = {
"gemini-cli-format": [],
Expand Down Expand Up @@ -139,6 +157,8 @@ def load_dataset_from_json(json_file_path, config):
all_items = load_bird_interact_dataset(json_file_path, config)
elif dataset_format in ("gemini-cli-format", "agent-format"):
all_items = load_gemini_cli_json(json_file_path)
elif dataset_format == "cortado-format":
all_items = load_cortado_json(json_file_path)
else:
all_items = load_json(json_file_path)

Expand All @@ -152,14 +172,17 @@ def load_dataset_from_json(json_file_path, config):
if "orchestrator" not in config:
config["orchestrator"] = "interact"
input_items = all_items
elif dataset_format == "cortado-format":
if "orchestrator" not in config:
config["orchestrator"] = "cortado"
elif dataset_format in ("gemini-cli-format", "agent-format"):
if "orchestrator" not in config:
config["orchestrator"] = "agent" if dataset_format == "agent-format" else "geminicli"
input_items = all_items
else:
raise ValueError("Dataset not in any of the recognised formats")

if dataset_format not in ["gemini-cli-format", "agent-format", "bird-interact-format"]:
if dataset_format not in ["gemini-cli-format", "bird-interact-format", "agent-format", "cortado-format"]:
totalEntries = sum(len(input_items.get(q, []))
for q in ["dql", "dml", "ddl"])
logging.info(f"Converted {totalEntries} entries to EvalInput.")
Expand Down
167 changes: 167 additions & 0 deletions evalbench/eval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import yaml
import grpc
import pathlib
import queue
from dataset.dataset import load_json
from dataset import evalinput
from evaluator import get_orchestrator, get_streaming_orchestrator
Expand All @@ -33,11 +34,13 @@
load_session_configs,
get_dataset_from_request,
)
from generators.models.grpc_proxy import PROXY_QUEUES

import threading
from util.context import rpc_id_var
from util import get_SessionManager


SESSIONMANAGER = get_SessionManager()


Expand Down Expand Up @@ -85,6 +88,7 @@ async def Connect(
session = SESSIONMANAGER.get_session(session_id)
if session is not None:
session["streaming_eval"] = request.streaming_eval
session["bidirectional_stream"] = request.bidirectional_stream
return eval_response_pb2.EvalResponse(response="ack", session_id=session_id)

async def EvalConfig(
Expand Down Expand Up @@ -241,6 +245,169 @@ async def Eval(
logging.exception(f"gRPC Eval failed for config/dataset '{display_config}': {e}")
raise

    async def Interact(
        self,
        request_iterator: AsyncIterator[eval_request_pb2.EvalInputRequest],
        context: grpc.ServicerContext,
    ) -> AsyncGenerator[eval_request_pb2.EvalInputRequest, None]:
        """Bidirectional stream linking Google3 Agents to Evalbench Orchestrators.

        Flow:
          1. Validates that the session exists, is marked bidirectional, and
             that the model config uses the 'grpc_proxy' generator.
          2. Loads the dataset and runs the orchestrator on a thread-pool
             executor, wiring per-conversation inbound queues and one shared
             outbound queue through `config` and the global PROXY_QUEUES.
          3. Pumps outbound requests from the orchestrator to the client and
             routes inbound client replies to the matching per-conversation
             queue, until the orchestrator task completes.
          4. Processes scoring/reporting off-thread and yields a final
             summary payload before cleaning up PROXY_QUEUES.

        Args:
            request_iterator: Inbound EvalInputRequest stream from the client.
            context: gRPC servicer context, used to signal precondition and
                argument errors.

        Yields:
            eval_request_pb2.EvalInputRequest messages destined for the
            client, ending with a final job summary payload.
        """

        session_id = rpc_id_var.get()
        session = SESSIONMANAGER.get_session(session_id)
        config, db_configs, model_config, setup_config = load_session_configs(
            session)

        if config is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("Session not configured")
            return

        # Set during Connect from request.bidirectional_stream.
        is_bidirectional = session.get(
            "bidirectional_stream", False) if session else False

        if not is_bidirectional:
            error_msg = (
                "Interact must be used with bidirectional streaming"
            )
            logging.error(error_msg)
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(error_msg)
            return
        logging.info("Starting a bidirectional Interact stream...")

        config["session_id"] = session_id

        # Create thread-safe queues
        in_queue = {}  # Google3 -> Evalbench (mapped by conversation_id)
        out_queue = queue.Queue()  # Evalbench -> Google3

        # The orchestrator/generator reads these queues out of the config.
        config["grpc_in_queues"] = in_queue
        config["grpc_out_queue"] = out_queue
        # NOTE(review): logs the full config — may include sensitive values.
        logging.info(f"CONFIG: {config}")
        generator = model_config.get("generator")

        if generator != "grpc_proxy":
            error_msg = (
                "Interactive evaluation failed: must use 'grpc_proxy' generator"
            )
            logging.error(error_msg)
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(error_msg)
            return

        # Load dataset and instantiate the Orchestrator
        dataset_config_json = config["dataset_config"]
        dataset_dict = load_dataset_from_json(dataset_config_json, config)

        # Flatten the {format: [items]} mapping into a single list.
        dataset = []
        for _, item_list in dataset_dict.items():
            dataset.extend(item_list)

        # Optional cap on the number of evaluations to run.
        num_evals = config.get("num_evals_to_run")
        if num_evals and int(num_evals) > 0:
            dataset = dataset[:int(num_evals)]

        orchestrator = get_orchestrator(
            config, db_configs, setup_config, report_progress=True)
        loop = asyncio.get_event_loop()
        # Copy the context so rpc_id_var etc. survive into executor threads.
        ctx = contextvars.copy_context()

        try:
            # Register the queues globally so the grpc_proxy generator
            # (running on a worker thread) can find them by session id.
            PROXY_QUEUES[session_id] = (in_queue, out_queue)
            # Run the blocking orchestrator evaluate() off the event loop.
            eval_task = loop.run_in_executor(
                None, ctx.run, orchestrator.evaluate, dataset
            )

            async def read_from_client():
                """Reads messages from the Google3 client stream."""
                async for response in request_iterator:
                    # Prefer conversation_id; fall back to the message id.
                    conv_id = str(
                        getattr(response, "conversation_id", getattr(response, "id", "")))
                    logging.debug(
                        "Server-Inbound: Received from Google3 for conv_id %s", conv_id)

                    if conv_id in in_queue:
                        logging.info(
                            f"[TRACE] Server-Inbound: Matched {conv_id} to active thread. Unblocking...")
                        in_queue[conv_id].put(response)
                    else:
                        # Reply arrived for a conversation with no waiting
                        # thread; it is dropped after logging.
                        logging.error(
                            "Server-Inbound: Orphaned reply! conv_id: '%s' not in active queues. Active keys: %s",
                            conv_id, list(in_queue.keys())
                        )
            read_task = asyncio.create_task(read_from_client())

            # Yield Loop: Read from out_queue and yield to Google3
            while True:
                if eval_task.done():
                    logging.info("Evaluator task finished for session %s.", session_id)
                    try:
                        eval_task.result()  # Propagate exceptions
                    except Exception as e:
                        logging.error("Orchestrator/Evaluator task failed: %s", e, exc_info=True)
                    break

                try:
                    # 1s timeout so we periodically re-check eval_task.done().
                    out_request: eval_request_pb2.EvalInputRequest = await asyncio.to_thread(out_queue.get, True, 1.0)

                    logging.debug("Server-Outbound: Yielding to Google3 for conv_id %s", out_request.conversation_id)
                    yield out_request

                except queue.Empty:
                    continue  # Loop back and check if eval_task is done
                except Exception as e:
                    import traceback
                    logging.error("Server-Outbound: Yield Loop error: %s", e, exc_info=True)
                    continue

            # Orchestrator is done; stop listening for client messages.
            read_task.cancel()
            try:
                await read_task
            except asyncio.CancelledError:
                logging.debug("Read task cancelled as expected.")

            # Process final scoring and reporting
            job_id, run_time, results_tf, scores_tf = orchestrator.process()
            reporters = get_reporters(config.get(
                "reporting") or {}, job_id, run_time)

            logging.info(
                "Offloading interactive results processing to thread pool...")

            summary = await loop.run_in_executor(
                None,
                ctx.run,
                _process_results,
                reporters,
                job_id,
                run_time,
                results_tf,
                scores_tf,
                config,
                model_config,
                db_configs,
            )
            logging.info(
                f"Finished Interactive Job ID {job_id}. Summary: {summary}")

            # Send the final payload back to the client to close the stream cleanly.
            final_request = eval_request_pb2.EvalInputRequest()
            final_request.payload = json.dumps(
                {"job_id": job_id, "summary": summary})

            # Tag the summary with the first item's conversation id so the
            # client can correlate it — assumes the first dataset item's id
            # is a valid conversation key; TODO confirm.
            if dataset:
                first_item = dataset[0]
                conv_id = str(getattr(first_item, "id", ""))
                if conv_id:
                    final_request.conversation_id = conv_id
            logging.info(f"Yielding final summary payload: {final_request}")
            yield final_request

        finally:
            # Clean up the global registry to prevent memory leaks.
            PROXY_QUEUES.pop(session_id, None)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a client connection drops forcefully or hangs and the Interact RPC never reaches its `finally`-block cleanup (or if the worker thread stays blocked indefinitely on `inbound_response = thread_inbox.get(block=True, timeout=300.0)`), the corresponding entry in PROXY_QUEUES will leak and remain in memory even after the background reaper deletes the session from disk/SessionManager.

We should also clear the PROXY_QUEUES entry when the session itself gets deleted.

logging.info(f"Cleaned up proxy queues for session {session_id}")


def _process_results(
reporters, job_id, run_time, results_tf, scores_tf, multi_trial_scores_tf, config, model_config, db_configs
Expand Down
1 change: 1 addition & 0 deletions evalbench/evalproto/eval_connect.proto
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ option java_multiple_files = true;
// Handshake request sent by a client when opening an Evalbench session.
message EvalConnectRequest {
  // Caller-supplied identifier for the connecting client.
  string client_id = 1;
  // Stored on the session as `streaming_eval`; enables streaming evaluation
  // responses for this session.
  bool streaming_eval = 2;
  // Stored on the session as `bidirectional_stream`; must be true for the
  // bidirectional Interact RPC to be accepted.
  bool bidirectional_stream = 3;
}
2 changes: 2 additions & 0 deletions evalbench/evalproto/eval_request.proto
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ message DialectBasedSQLStatements {

message EvalInputRequest {
int64 id = 1;
string conversation_id = 18;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new field should be placed in ascending field-number order within the message, not appended after higher-numbered fields.

string nl_prompt = 2;
string query_type = 3;
string database = 4;
Expand All @@ -29,6 +30,7 @@ message EvalInputRequest {
string sql_generator_error = 12;
float sql_generator_time = 13;
string generated_sql = 14;
string generated_nl_response = 19;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

string job_id = 15;
string trace_id = 16;
string payload = 17;
Expand Down
5 changes: 5 additions & 0 deletions evalbench/evalproto/eval_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ service EvalService {
// option deadline = 1800;
}

// MultiTurnEval. Bidirectional stream.
rpc Interact(stream EvalInputRequest) returns (stream EvalInputRequest) {
// option deadline = 1800;
}

// PrepareCodeEvalInputs for NL2Code Evaluation
rpc PrepareCodeEvalInputs(EvalCodeInputRequest) returns (stream EvalCodeInputRequest) {
// option deadline = 1800;
Expand Down
4 changes: 3 additions & 1 deletion evalbench/evaluator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from evaluator.orchestrator import Orchestrator
from evaluator.oneshotorchestrator import OneShotOrchestrator
from evaluator.cortadoorchestrator import CortadoOrchestrator
from evaluator.interactorchestrator import InteractOrchestrator
from evaluator.dataagentorchestrator import DataAgentOrchestrator
from evaluator.agentorchestrator import AgentOrchestrator
Expand All @@ -9,7 +10,6 @@

def get_orchestrator(config, db_configs, setup_config, report_progress=False):
orchestrator_type = config.get("orchestrator", "oneshot")
logging.info(f"Orchestrator Type: {orchestrator_type}")
if orchestrator_type == "oneshot":
return OneShotOrchestrator(config, db_configs, setup_config, report_progress)
elif orchestrator_type == "interact":
Expand All @@ -18,6 +18,8 @@ def get_orchestrator(config, db_configs, setup_config, report_progress=False):
return DataAgentOrchestrator(config, db_configs, setup_config, report_progress)
elif orchestrator_type in ("geminicli", "agent"):
return AgentOrchestrator(config, db_configs, setup_config, report_progress)
elif orchestrator_type == "cortado":
return CortadoOrchestrator(config, db_configs, setup_config, report_progress)
else:
return Orchestrator(config, db_configs, setup_config, report_progress)

Expand Down
Loading
Loading