Skip to content

Commit b3e6b5d

Browse files
committed
Fix inline remote eval parameter handling
1 parent 8382fd5 commit b3e6b5d

3 files changed

Lines changed: 87 additions & 7 deletions

File tree

py/src/braintrust/devserver/server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from ..framework import EvalAsync, EvalScorer, Evaluator, ExperimentSummary, SSEProgressEvent
2929
from ..generated_types import FunctionId
3030
from ..logger import BraintrustState, bt_iscoroutinefunction
31-
from ..parameters import serialize_remote_eval_parameters_container, validate_parameters
31+
from ..parameters import RemoteEvalParameters, serialize_remote_eval_parameters_container, validate_parameters
3232
from ..span_identifier_v4 import parse_parent
3333
from .auth import AuthorizationMiddleware
3434
from .cache import cached_login
@@ -228,6 +228,8 @@ def stream_fn(event: SSEProgressEvent):
228228
eval_kwargs = {
229229
k: v for (k, v) in evaluator.__dict__.items() if k not in ["eval_name", "project_name", "parameter_values"]
230230
}
231+
if validated_parameters is not None and not RemoteEvalParameters.is_parameters(evaluator.parameters):
232+
eval_kwargs["parameters"] = validated_parameters
231233

232234
try:
233235
eval_task = asyncio.create_task(

py/src/braintrust/devserver/test_server_integration.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from braintrust.test_helpers import has_devserver_installed
99

1010

11+
HAS_PYDANTIC = __import__("importlib.util").util.find_spec("pydantic") is not None
12+
13+
1114
@pytest.fixture
1215
def client():
1316
"""Create test client using the real simple_eval.py example."""
@@ -205,3 +208,67 @@ def test_eval_error_handling(client, api_key, org_name):
205208
error = response.json()
206209
assert "error" in error
207210
assert "not found" in error["error"].lower()
211+
212+
213+
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
214+
def test_eval_uses_inline_request_parameters(api_key, org_name, monkeypatch):
215+
from braintrust import Evaluator
216+
from braintrust.devserver import server as devserver_module
217+
from braintrust.devserver.server import create_app
218+
from braintrust.logger import BraintrustState
219+
from pydantic import BaseModel
220+
from starlette.testclient import TestClient
221+
222+
class RequiredInt(BaseModel):
223+
value: int
224+
225+
def task(input: str, hooks) -> dict[str, Any]:
226+
return {"input": input, "num_samples": hooks.parameters["num_samples_without_default"]}
227+
228+
evaluator = Evaluator(
229+
project_name="test-math-eval",
230+
eval_name="inline-parameter-eval",
231+
data=lambda: [{"input": "What is 2+2?", "expected": "4"}],
232+
task=task,
233+
scores=[],
234+
experiment_name=None,
235+
metadata=None,
236+
parameters={"num_samples_without_default": RequiredInt},
237+
)
238+
239+
async def fake_cached_login(**_kwargs):
240+
return BraintrustState()
241+
242+
class FakeSummary:
243+
def as_dict(self):
244+
return {"experiment_name": "inline-parameter-eval", "project_name": "test-math-eval", "scores": {}}
245+
246+
class FakeResult:
247+
summary = FakeSummary()
248+
249+
async def fake_eval_async(*, task, data, parameters, **_kwargs):
250+
assert parameters == {"num_samples_without_default": 1}
251+
datum = data[0]
252+
hooks = type("Hooks", (), {"parameters": parameters, "report_progress": lambda self, _progress: None})()
253+
await task(datum["input"], hooks)
254+
return FakeResult()
255+
256+
monkeypatch.setattr(devserver_module, "cached_login", fake_cached_login)
257+
monkeypatch.setattr(devserver_module, "EvalAsync", fake_eval_async)
258+
259+
response = TestClient(create_app([evaluator])).post(
260+
"/eval",
261+
headers={
262+
"x-bt-auth-token": api_key,
263+
"x-bt-org-name": org_name,
264+
"Content-Type": "application/json",
265+
},
266+
json={
267+
"name": "inline-parameter-eval",
268+
"stream": False,
269+
"parameters": {"num_samples_without_default": 1},
270+
"data": [{"input": "What is 2+2?", "expected": "4"}],
271+
},
272+
)
273+
274+
assert response.status_code == 200

py/src/braintrust/parameters.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -321,27 +321,38 @@ def serialize_eval_parameters(parameters: EvalParameters) -> dict[str, Any]:
321321

322322
for name, schema in parameters.items():
323323
if _is_prompt_parameter(schema):
324-
result[name] = {
324+
parameter_data = {
325325
"type": "prompt",
326-
"default": _prompt_data_to_dict(schema.get("default")),
327326
"description": schema.get("description"),
328327
}
328+
default = schema.get("default")
329+
if default is not None:
330+
parameter_data["default"] = _prompt_data_to_dict(default)
331+
result[name] = parameter_data
329332
elif _is_model_parameter(schema):
330-
result[name] = {
333+
parameter_data = {
331334
"type": "model",
332-
"default": schema.get("default"),
333335
"description": schema.get("description"),
334336
}
337+
default = schema.get("default")
338+
if default is not None:
339+
parameter_data["default"] = default
340+
result[name] = parameter_data
335341
elif schema is None:
336342
result[name] = {
337343
"type": "data",
338344
"schema": {},
339345
}
340346
else:
341-
result[name] = {
347+
schema_json = _serialize_pydantic_parameter_schema(schema)
348+
parameter_data = {
342349
"type": "data",
343-
"schema": _serialize_pydantic_parameter_schema(schema),
350+
"schema": schema_json,
351+
"description": schema_json.get("description"),
344352
}
353+
if "default" in schema_json:
354+
parameter_data["default"] = schema_json["default"]
355+
result[name] = parameter_data
345356

346357
return result
347358

0 commit comments

Comments
 (0)