
Commit fd34303

feat: Port load_parameters to the Python SDK (#90)
* Add saved parameters support to Python SDK
* Fix inline remote eval parameter handling
* Address PR comments for python parameters impl
* Prefer version over environment in Python loaders
* Update json schema creation to dereference pydantic complex data types
1 parent 08210a4 commit fd34303

Showing 15 changed files with 1,686 additions and 135 deletions.
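For orientation before the file-by-file diff: this commit lets a Python eval reference a saved parameter set instead of defining one inline. A hedged usage sketch follows; `RemoteEvalParameters` and its module path appear in the diffs below, but its constructor signature and the placeholder id/version values are assumptions:

    # Hypothetical usage sketch, not code from this commit.
    from braintrust import Eval
    from braintrust.parameters import RemoteEvalParameters  # module path per the diffs below

    Eval(
        "my-project",
        data=lambda: [{"input": "What is 2+2?", "expected": "4"}],
        # Tasks read the loaded, validated values off hooks.parameters
        # (see the EvalHooks changes below).
        task=lambda input, hooks: hooks.parameters,
        scores=[],
        # Reference a saved parameter set by id; per the commit message, an
        # explicit version is preferred over an environment when resolving it.
        parameters=RemoteEvalParameters(id="<parameter-id>", version="<version>"),
    )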

py/setup.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
         "chevron",
         "tqdm",
         "exceptiongroup>=1.2.0",
+        "jsonschema",
         "python-dotenv",
         "sseclient-py",
         "python-slugify",

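The new `jsonschema` dependency presumably backs validation of parameter values against their JSON schemas. For readers unfamiliar with the library, a minimal sketch of its core call (illustrative only, not SDK code):

    from jsonschema import ValidationError, validate

    schema = {
        "type": "object",
        "properties": {"num_samples": {"type": "integer"}},
        "required": ["num_samples"],
    }
    validate(instance={"num_samples": 1}, schema=schema)  # passes silently
    try:
        validate(instance={}, schema=schema)
    except ValidationError as err:
        print(f"invalid parameters: {err.message}")  # 'num_samples' is a required property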
py/src/braintrust/cli/eval.py

Lines changed: 8 additions & 0 deletions
@@ -22,6 +22,7 @@
     set_thread_pool_max_workers,
 )
 from ..logger import Dataset
+from ..parameters import RemoteEvalParameters
 from ..util import eprint


@@ -131,6 +132,12 @@ async def run_evaluator_task(evaluator, position, opts: EvaluatorOpts):
     if isinstance(evaluator.data, Dataset):
         dataset = evaluator.data

+    parameters = None
+    if isinstance(evaluator.parameters, RemoteEvalParameters) and evaluator.parameters.id is not None:
+        parameters = {"id": evaluator.parameters.id}
+        if evaluator.parameters.version is not None:
+            parameters["version"] = evaluator.parameters.version
+
     # NOTE: This code is duplicated with _EvalCommon in py/src/braintrust/framework.py.
     # Make sure to update those arguments if you change this.
     experiment = init_experiment(
@@ -147,6 +154,7 @@ async def run_evaluator_task(evaluator, position, opts: EvaluatorOpts):
         git_metadata_settings=evaluator.git_metadata_settings,
         repo_info=evaluator.repo_info,
         dataset=dataset,
+        parameters=parameters,
     )

     try:
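The id/version projection added above recurs in framework.py further down. In isolation, with a stand-in dataclass rather than the SDK's `RemoteEvalParameters`, the behavior is:

    from dataclasses import dataclass

    @dataclass
    class RemoteParamsStandIn:  # stand-in, not the SDK class
        id: str | None = None
        version: str | None = None

    def to_experiment_parameters(p: RemoteParamsStandIn | None) -> dict | None:
        # Mirrors the logic above: send a payload only when an id exists,
        # and attach the version only when one is set.
        if p is None or p.id is None:
            return None
        out = {"id": p.id}
        if p.version is not None:
            out["version"] = p.version
        return out

    assert to_experiment_parameters(RemoteParamsStandIn()) is None
    assert to_experiment_parameters(RemoteParamsStandIn(id="abc")) == {"id": "abc"}
    assert to_experiment_parameters(RemoteParamsStandIn(id="abc", version="2")) == {"id": "abc", "version": "2"}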

py/src/braintrust/cli/push.py

Lines changed: 9 additions & 0 deletions
@@ -322,6 +322,13 @@ def _collect_evaluator_defs(
     )


+def _collect_parameters_function_defs(
+    project_ids: ProjectIdCache, functions: list[dict[str, Any]], if_exists: IfExists
+) -> None:
+    for p in global_.parameters:
+        functions.append(p.to_function_definition(if_exists, project_ids))
+
+
 def run(args):
     """Runs the braintrust push subcommand."""
     login(
@@ -379,6 +386,8 @@ def run(args):

     if len(global_.prompts) > 0:
         _collect_prompt_function_defs(project_ids, functions, args.if_exists)
+    if len(global_.parameters) > 0:
+        _collect_parameters_function_defs(project_ids, functions, args.if_exists)

     if len(functions) > 0:
         api_conn().post_json("insert-functions", {"functions": functions})

py/src/braintrust/devserver/eval_hooks.py

Lines changed: 3 additions & 1 deletion
@@ -10,14 +10,16 @@
 from collections.abc import Callable
 from typing import Any

+from ..parameters import ValidatedParameters
+

 class EvalHooks:
     """Hooks provided to eval tasks for progress reporting."""

     def __init__(
         self,
         report_progress: Callable[[dict[str, Any]], None] | None = None,
-        parameters: dict[str, Any] | None = None,
+        parameters: ValidatedParameters | None = None,
     ):
         self._report_progress = report_progress
         self.parameters = parameters or {}
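With the typed signature above, a task reads validated values off `hooks.parameters`, which falls back to `{}` when nothing was supplied (the `parameters or {}` line). A short sketch; the `temperature` key is made up:

    def task(input, hooks):
        # hooks.parameters is the validated mapping, or {} when none were given.
        temperature = hooks.parameters.get("temperature", 0.0)  # hypothetical key
        return {"input": input, "temperature": temperature}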

py/src/braintrust/devserver/server.py

Lines changed: 43 additions & 11 deletions
@@ -25,10 +25,22 @@
     )
 )

-from ..framework import EvalAsync, EvalScorer, Evaluator, ExperimentSummary, SSEProgressEvent
+from ..framework import (
+    EvalAsync,
+    EvalHooks,
+    EvalScorer,
+    Evaluator,
+    ExperimentSummary,
+    SSEProgressEvent,
+)
 from ..generated_types import FunctionId
 from ..logger import BraintrustState, bt_iscoroutinefunction
-from ..parameters import parameters_to_json_schema, validate_parameters
+from ..parameters import (
+    RemoteEvalParameters,
+    ValidatedParameters,
+    serialize_remote_eval_parameters_container,
+    validate_parameters,
+)
 from ..span_identifier_v4 import parse_parent
 from .auth import AuthorizationMiddleware
 from .cache import cached_login
@@ -41,6 +53,19 @@
 _all_evaluators: dict[str, Evaluator[Any, Any]] = {}


+class _ParameterOverrideHooks:
+    def __init__(self, hooks: EvalHooks[Any], parameters: ValidatedParameters):
+        self._hooks = hooks
+        self._parameters = parameters
+
+    @property
+    def parameters(self) -> ValidatedParameters:
+        return self._parameters
+
+    def __getattr__(self, name: str):
+        return getattr(self._hooks, name)
+
+
 class CheckAuthorizedMiddleware(BaseHTTPMiddleware):
     def __init__(self, app, allowed_org_name: str | None = None):
         super().__init__(app)
@@ -95,7 +120,9 @@ async def list_evaluators(request: Request) -> JSONResponse:
     evaluator_list = {}
     for name, evaluator in _all_evaluators.items():
         evaluator_list[name] = {
-            "parameters": parameters_to_json_schema(evaluator.parameters) if evaluator.parameters else {},
+            "parameters": (
+                serialize_remote_eval_parameters_container(evaluator.parameters) if evaluator.parameters else None
+            ),
             "scores": [{"name": getattr(score, "name", f"score_{i}")} for i, score in enumerate(evaluator.scores)],
         }

@@ -154,12 +181,13 @@ async def run_eval(request: Request) -> JSONResponse | StreamingResponse:
     # Set up SSE headers for streaming
     sse_queue = SSEQueue()

-    async def task(input, hooks):
+    async def task(input: Any, hooks: EvalHooks[Any]):
+        task_hooks = hooks if validated_parameters is None else _ParameterOverrideHooks(hooks, validated_parameters)
         if bt_iscoroutinefunction(evaluator.task):
-            result = await evaluator.task(input, hooks)
+            result = await evaluator.task(input, task_hooks)
         else:
-            result = evaluator.task(input, hooks)
-        hooks.report_progress(
+            result = evaluator.task(input, task_hooks)
+        task_hooks.report_progress(
             {
                 "format": "code",
                 "output_type": "completion",
@@ -186,9 +214,10 @@ def stream_fn(event: SSEProgressEvent):
     if parent:
         parent = parse_parent(parent)

-    # Override evaluator parameters with validated ones if provided
-    eval_kwargs = {k: v for (k, v) in evaluator.__dict__.items() if k not in ["eval_name", "project_name"]}
-    if validated_parameters is not None:
+    eval_kwargs = {
+        k: v for (k, v) in evaluator.__dict__.items() if k not in ["eval_name", "project_name", "parameter_values"]
+    }
+    if validated_parameters is not None and not isinstance(evaluator.parameters, RemoteEvalParameters):
         eval_kwargs["parameters"] = validated_parameters

     try:
@@ -289,7 +318,10 @@ def create_app(evaluators: list[Evaluator[Any, Any]], org_name: str | None = Non


 def run_dev_server(
-    evaluators: list[Evaluator[Any, Any]], host: str = "localhost", port: int = 8300, org_name: str | None = None
+    evaluators: list[Evaluator[Any, Any]],
+    host: str = "localhost",
+    port: int = 8300,
+    org_name: str | None = None,
 ):
     """Start the dev server.

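`_ParameterOverrideHooks` overrides only `parameters` and forwards everything else to the wrapped hooks via `__getattr__`, which Python consults only after normal attribute lookup fails. A standalone sketch of the pattern with toy classes (not SDK code):

    class InnerHooks:
        parameters = {"n": 1}

        def report_progress(self, event):
            return f"progress: {event}"

    class OverrideHooks:
        def __init__(self, inner, parameters):
            self._inner = inner
            self._parameters = parameters

        @property
        def parameters(self):
            return self._parameters

        def __getattr__(self, name):
            # Reached only when normal lookup fails, so the property above wins.
            return getattr(self._inner, name)

    hooks = OverrideHooks(InnerHooks(), {"n": 2})
    assert hooks.parameters == {"n": 2}                 # overridden
    assert hooks.report_progress("x") == "progress: x"  # delegated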
py/src/braintrust/devserver/test_server_integration.py

Lines changed: 67 additions & 0 deletions
@@ -8,6 +8,9 @@
 from braintrust.test_helpers import has_devserver_installed


+HAS_PYDANTIC = __import__("importlib.util").util.find_spec("pydantic") is not None
+
+
 @pytest.fixture
 def client():
     """Create test client using the real simple_eval.py example."""
@@ -205,3 +208,67 @@ def test_eval_error_handling(client, api_key, org_name):
     error = response.json()
     assert "error" in error
     assert "not found" in error["error"].lower()
+
+
+@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
+def test_eval_uses_inline_request_parameters(api_key, org_name, monkeypatch):
+    from braintrust import Evaluator
+    from braintrust.devserver import server as devserver_module
+    from braintrust.devserver.server import create_app
+    from braintrust.logger import BraintrustState
+    from pydantic import BaseModel
+    from starlette.testclient import TestClient
+
+    class RequiredInt(BaseModel):
+        value: int
+
+    def task(input: str, hooks) -> dict[str, Any]:
+        return {"input": input, "num_samples": hooks.parameters["num_samples_without_default"]}
+
+    evaluator = Evaluator(
+        project_name="test-math-eval",
+        eval_name="inline-parameter-eval",
+        data=lambda: [{"input": "What is 2+2?", "expected": "4"}],
+        task=task,
+        scores=[],
+        experiment_name=None,
+        metadata=None,
+        parameters={"num_samples_without_default": RequiredInt},
+    )
+
+    async def fake_cached_login(**_kwargs):
+        return BraintrustState()
+
+    class FakeSummary:
+        def as_dict(self):
+            return {"experiment_name": "inline-parameter-eval", "project_name": "test-math-eval", "scores": {}}
+
+    class FakeResult:
+        summary = FakeSummary()
+
+    async def fake_eval_async(*, task, data, parameters, **_kwargs):
+        assert parameters == {"num_samples_without_default": 1}
+        datum = data[0]
+        hooks = type("Hooks", (), {"parameters": parameters, "report_progress": lambda self, _progress: None})()
+        await task(datum["input"], hooks)
+        return FakeResult()
+
+    monkeypatch.setattr(devserver_module, "cached_login", fake_cached_login)
+    monkeypatch.setattr(devserver_module, "EvalAsync", fake_eval_async)
+
+    response = TestClient(create_app([evaluator])).post(
+        "/eval",
+        headers={
+            "x-bt-auth-token": api_key,
+            "x-bt-org-name": org_name,
+            "Content-Type": "application/json",
+        },
+        json={
+            "name": "inline-parameter-eval",
+            "stream": False,
+            "parameters": {"num_samples_without_default": 1},
+            "data": [{"input": "What is 2+2?", "expected": "4"}],
+        },
+    )
+
+    assert response.status_code == 200

py/src/braintrust/framework.py

Lines changed: 33 additions & 9 deletions
@@ -42,7 +42,13 @@
     stringify_exception,
 )
 from .logger import init as _init_experiment
-from .parameters import EvalParameters
+from .parameters import (
+    EvalParameters,
+    RemoteEvalParameters,
+    ValidatedParameters,
+    is_eval_parameter_schema,
+    validate_parameters,
+)
 from .resource_manager import ResourceManager
 from .score import Score, is_score, is_scorer
 from .serializable_data_class import SerializableDataClass
@@ -215,7 +221,7 @@ def meta(self, **info: Any) -> None:

     @property
     @abc.abstractmethod
-    def parameters(self) -> dict[str, Any] | None:
+    def parameters(self) -> ValidatedParameters | None:
         """
         The parameters for the current evaluation. These are the validated parameter values
         that were passed to the evaluator.
@@ -439,12 +445,14 @@ class Evaluator(Generic[Input, Output]):
     Whether to summarize the scores of the experiment after it has run.
     """

-    parameters: EvalParameters | None = None
+    parameters: EvalParameters | RemoteEvalParameters | None = None
     """
     A set of parameters that will be passed to the evaluator.
     Can be used to define prompts or other configurable values.
     """

+    parameter_values: dict[str, Any] | None = None
+

 @dataclasses.dataclass
 class EvalResultWithSummary(SerializableDataClass, Generic[Input, Output]):
@@ -675,7 +683,7 @@ def _EvalCommon(
     summarize_scores: bool,
     no_send_logs: bool,
     error_score_handler: ErrorScoreHandler | None = None,
-    parameters: EvalParameters | None = None,
+    parameters: EvalParameters | RemoteEvalParameters | None = None,
     on_start: Callable[[ExperimentSummary], None] | None = None,
     stream: Callable[[SSEProgressEvent], None] | None = None,
     parent: str | None = None,
@@ -741,6 +749,12 @@ async def make_empty_summary():
     if isinstance(evaluator.data, Dataset):
         dataset = evaluator.data

+    experiment_parameters = None
+    if isinstance(evaluator.parameters, RemoteEvalParameters) and evaluator.parameters.id is not None:
+        experiment_parameters = {"id": evaluator.parameters.id}
+        if evaluator.parameters.version is not None:
+            experiment_parameters["version"] = evaluator.parameters.version
+
     # NOTE: This code is duplicated with run_evaluator_task in py/src/braintrust/cli/eval.py.
     # Make sure to update those arguments if you change this.
     experiment = None
@@ -759,6 +773,7 @@
         git_metadata_settings=evaluator.git_metadata_settings,
         repo_info=evaluator.repo_info,
         dataset=dataset,
+        parameters=experiment_parameters,
         state=state,
     )

@@ -804,7 +819,7 @@ async def EvalAsync(
     description: str | None = None,
     summarize_scores: bool = True,
     no_send_logs: bool = False,
-    parameters: EvalParameters | None = None,
+    parameters: EvalParameters | RemoteEvalParameters | None = None,
     on_start: Callable[[ExperimentSummary], None] | None = None,
     stream: Callable[[SSEProgressEvent], None] | None = None,
     parent: str | None = None,
@@ -931,7 +946,7 @@ def Eval(
     description: str | None = None,
     summarize_scores: bool = True,
     no_send_logs: bool = False,
-    parameters: EvalParameters | None = None,
+    parameters: EvalParameters | RemoteEvalParameters | None = None,
     on_start: Callable[[ExperimentSummary], None] | None = None,
     stream: Callable[[SSEProgressEvent], None] | None = None,
     parent: str | None = None,
@@ -1153,7 +1168,7 @@ def __init__(
         trial_index: int = 0,
         tags: Sequence[str] | None = None,
         report_progress: Callable[[TaskProgressEvent], None] = None,
-        parameters: dict[str, Any] | None = None,
+        parameters: ValidatedParameters | None = None,
     ):
         if metadata is not None:
             self.update({"metadata": metadata})
@@ -1211,7 +1226,7 @@ def report_progress(self, event: TaskProgressEvent):
         return self._report_progress(event)

     @property
-    def parameters(self) -> dict[str, Any] | None:
+    def parameters(self) -> ValidatedParameters | None:
         return self._parameters


@@ -1392,6 +1407,15 @@ def get_other_fields(s):
     scorer_names = [_scorer_name(scorer, i) for i, scorer in enumerate(scorers)]
     unhandled_scores = scorer_names

+    if evaluator.parameter_values is not None:
+        resolved_evaluator_parameters = evaluator.parameter_values
+    elif isinstance(evaluator.parameters, RemoteEvalParameters):
+        resolved_evaluator_parameters = validate_parameters({}, evaluator.parameters)
+    elif is_eval_parameter_schema(evaluator.parameters):
+        resolved_evaluator_parameters = validate_parameters({}, evaluator.parameters)
+    else:
+        resolved_evaluator_parameters = evaluator.parameters
+
     async def run_evaluator_task(datum, trial_index=0):
         if isinstance(datum, dict):
             datum = EvalCase.from_dict(datum)
@@ -1451,7 +1475,7 @@ def report_progress(event: TaskProgressEvent):
             trial_index=trial_index,
             tags=tags,
             report_progress=report_progress,
-            parameters=evaluator.parameters,
+            parameters=resolved_evaluator_parameters,
         )

     # Check if the task takes a hooks argument
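The parameter-resolution block added above gives each task a single resolved parameters value with a clear precedence: explicit `parameter_values` win; a remote reference or inline schema is validated against an empty override set (presumably filling defaults and raising on missing required values); anything else passes through unchanged. Restated as a standalone sketch with stand-in callables:

    def resolve(parameter_values, parameters, *, is_schema_like, validate):
        if parameter_values is not None:
            return parameter_values          # explicit values win
        if is_schema_like(parameters):
            return validate({}, parameters)  # fill defaults / reject missing required
        return parameters                    # plain mapping or None passes through

    resolved = resolve(
        None,
        {"num_samples": 3},  # a plain dict is not schema-like in this sketch
        is_schema_like=lambda p: False,
        validate=lambda overrides, schema: overrides,
    )
    assert resolved == {"num_samples": 3}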
