Skip to content

Commit f80ecfd

Browse files
committed
Add saved parameters support to Python SDK
1 parent fc2d5fe commit f80ecfd

13 files changed

Lines changed: 1051 additions & 110 deletions

File tree

py/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"chevron",
1919
"tqdm",
2020
"exceptiongroup>=1.2.0",
21+
"jsonschema",
2122
"python-dotenv",
2223
"sseclient-py",
2324
"python-slugify",

py/src/braintrust/cli/eval.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
set_thread_pool_max_workers,
2323
)
2424
from ..logger import Dataset
25+
from ..parameters import RemoteEvalParameters
2526
from ..util import eprint
2627

2728

@@ -131,6 +132,12 @@ async def run_evaluator_task(evaluator, position, opts: EvaluatorOpts):
131132
if isinstance(evaluator.data, Dataset):
132133
dataset = evaluator.data
133134

135+
parameters = None
136+
if RemoteEvalParameters.is_parameters(evaluator.parameters) and evaluator.parameters.id is not None:
137+
parameters = {"id": evaluator.parameters.id}
138+
if evaluator.parameters.version is not None:
139+
parameters["version"] = evaluator.parameters.version
140+
134141
# NOTE: This code is duplicated with _EvalCommon in py/src/braintrust/framework.py.
135142
# Make sure to update those arguments if you change this.
136143
experiment = init_experiment(
@@ -147,6 +154,7 @@ async def run_evaluator_task(evaluator, position, opts: EvaluatorOpts):
147154
git_metadata_settings=evaluator.git_metadata_settings,
148155
repo_info=evaluator.repo_info,
149156
dataset=dataset,
157+
parameters=parameters,
150158
)
151159

152160
try:

py/src/braintrust/cli/push.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,6 @@ def _collect_prompt_function_defs(
271271
for p in global_.prompts:
272272
functions.append(p.to_function_definition(if_exists, project_ids))
273273

274-
275274
def _collect_evaluator_defs(
276275
project_ids: ProjectIdCache,
277276
functions: list[dict[str, Any]],
@@ -322,6 +321,13 @@ def _collect_evaluator_defs(
322321
)
323322

324323

324+
def _collect_parameters_function_defs(
    project_ids: ProjectIdCache, functions: list[dict[str, Any]], if_exists: IfExists
) -> None:
    """Append a function definition to `functions` for every globally
    registered saved-parameters object (collected during lazy load)."""
    functions.extend(
        params.to_function_definition(if_exists, project_ids) for params in global_.parameters
    )
329+
330+
325331
def run(args):
326332
"""Runs the braintrust push subcommand."""
327333
login(
@@ -379,6 +385,8 @@ def run(args):
379385

380386
if len(global_.prompts) > 0:
381387
_collect_prompt_function_defs(project_ids, functions, args.if_exists)
388+
if len(global_.parameters) > 0:
389+
_collect_parameters_function_defs(project_ids, functions, args.if_exists)
382390

383391
if len(functions) > 0:
384392
api_conn().post_json("insert-functions", {"functions": functions})

py/src/braintrust/devserver/server.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from ..framework import EvalAsync, EvalScorer, Evaluator, ExperimentSummary, SSEProgressEvent
2929
from ..generated_types import FunctionId
3030
from ..logger import BraintrustState, bt_iscoroutinefunction
31-
from ..parameters import parameters_to_json_schema, validate_parameters
31+
from ..parameters import serialize_remote_eval_parameters_container, validate_parameters
3232
from ..span_identifier_v4 import parse_parent
3333
from .auth import AuthorizationMiddleware
3434
from .cache import cached_login
@@ -41,6 +41,42 @@
4141
_all_evaluators: dict[str, Evaluator[Any, Any]] = {}
4242

4343

44+
class _ParameterOverrideHooks:
45+
def __init__(self, hooks: Any, parameters: dict[str, Any]):
46+
self._hooks = hooks
47+
self._parameters = parameters
48+
49+
@property
50+
def metadata(self):
51+
return self._hooks.metadata
52+
53+
@property
54+
def expected(self):
55+
return self._hooks.expected
56+
57+
@property
58+
def span(self):
59+
return self._hooks.span
60+
61+
@property
62+
def trial_index(self):
63+
return self._hooks.trial_index
64+
65+
@property
66+
def tags(self):
67+
return self._hooks.tags
68+
69+
@property
70+
def parameters(self):
71+
return self._parameters
72+
73+
def report_progress(self, progress):
74+
return self._hooks.report_progress(progress)
75+
76+
def meta(self, **info: Any):
77+
return self._hooks.meta(**info)
78+
79+
4480
class CheckAuthorizedMiddleware(BaseHTTPMiddleware):
4581
def __init__(self, app, allowed_org_name: str | None = None):
4682
super().__init__(app)
@@ -95,7 +131,9 @@ async def list_evaluators(request: Request) -> JSONResponse:
95131
evaluator_list = {}
96132
for name, evaluator in _all_evaluators.items():
97133
evaluator_list[name] = {
98-
"parameters": parameters_to_json_schema(evaluator.parameters) if evaluator.parameters else {},
134+
"parameters": (
135+
serialize_remote_eval_parameters_container(evaluator.parameters) if evaluator.parameters else None
136+
),
99137
"scores": [{"name": getattr(score, "name", f"score_{i}")} for i, score in enumerate(evaluator.scores)],
100138
}
101139

@@ -155,11 +193,12 @@ async def run_eval(request: Request) -> JSONResponse | StreamingResponse:
155193
sse_queue = SSEQueue()
156194

157195
async def task(input, hooks):
196+
task_hooks = hooks if validated_parameters is None else _ParameterOverrideHooks(hooks, validated_parameters)
158197
if bt_iscoroutinefunction(evaluator.task):
159-
result = await evaluator.task(input, hooks)
198+
result = await evaluator.task(input, task_hooks)
160199
else:
161-
result = evaluator.task(input, hooks)
162-
hooks.report_progress(
200+
result = evaluator.task(input, task_hooks)
201+
task_hooks.report_progress(
163202
{
164203
"format": "code",
165204
"output_type": "completion",
@@ -186,10 +225,9 @@ def stream_fn(event: SSEProgressEvent):
186225
if parent:
187226
parent = parse_parent(parent)
188227

189-
# Override evaluator parameters with validated ones if provided
190-
eval_kwargs = {k: v for (k, v) in evaluator.__dict__.items() if k not in ["eval_name", "project_name"]}
191-
if validated_parameters is not None:
192-
eval_kwargs["parameters"] = validated_parameters
228+
eval_kwargs = {
229+
k: v for (k, v) in evaluator.__dict__.items() if k not in ["eval_name", "project_name", "parameter_values"]
230+
}
193231

194232
try:
195233
eval_task = asyncio.create_task(

py/src/braintrust/framework.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
stringify_exception,
4343
)
4444
from .logger import init as _init_experiment
45-
from .parameters import EvalParameters
45+
from .parameters import EvalParameters, RemoteEvalParameters, is_eval_parameter_schema, validate_parameters
4646
from .resource_manager import ResourceManager
4747
from .score import Score, is_score, is_scorer
4848
from .serializable_data_class import SerializableDataClass
@@ -439,12 +439,14 @@ class Evaluator(Generic[Input, Output]):
439439
Whether to summarize the scores of the experiment after it has run.
440440
"""
441441

442-
parameters: EvalParameters | None = None
442+
parameters: EvalParameters | RemoteEvalParameters | None = None
443443
"""
444444
A set of parameters that will be passed to the evaluator.
445445
Can be used to define prompts or other configurable values.
446446
"""
447447

448+
parameter_values: dict[str, Any] | None = None
449+
448450

449451
@dataclasses.dataclass
450452
class EvalResultWithSummary(SerializableDataClass, Generic[Input, Output]):
@@ -675,7 +677,7 @@ def _EvalCommon(
675677
summarize_scores: bool,
676678
no_send_logs: bool,
677679
error_score_handler: ErrorScoreHandler | None = None,
678-
parameters: EvalParameters | None = None,
680+
parameters: EvalParameters | RemoteEvalParameters | None = None,
679681
on_start: Callable[[ExperimentSummary], None] | None = None,
680682
stream: Callable[[SSEProgressEvent], None] | None = None,
681683
parent: str | None = None,
@@ -741,6 +743,12 @@ async def make_empty_summary():
741743
if isinstance(evaluator.data, Dataset):
742744
dataset = evaluator.data
743745

746+
experiment_parameters = None
747+
if RemoteEvalParameters.is_parameters(evaluator.parameters) and evaluator.parameters.id is not None:
748+
experiment_parameters = {"id": evaluator.parameters.id}
749+
if evaluator.parameters.version is not None:
750+
experiment_parameters["version"] = evaluator.parameters.version
751+
744752
# NOTE: This code is duplicated with run_evaluator_task in py/src/braintrust/cli/eval.py.
745753
# Make sure to update those arguments if you change this.
746754
experiment = None
@@ -759,6 +767,7 @@ async def make_empty_summary():
759767
git_metadata_settings=evaluator.git_metadata_settings,
760768
repo_info=evaluator.repo_info,
761769
dataset=dataset,
770+
parameters=experiment_parameters,
762771
state=state,
763772
)
764773

@@ -804,7 +813,7 @@ async def EvalAsync(
804813
description: str | None = None,
805814
summarize_scores: bool = True,
806815
no_send_logs: bool = False,
807-
parameters: EvalParameters | None = None,
816+
parameters: EvalParameters | RemoteEvalParameters | None = None,
808817
on_start: Callable[[ExperimentSummary], None] | None = None,
809818
stream: Callable[[SSEProgressEvent], None] | None = None,
810819
parent: str | None = None,
@@ -931,7 +940,7 @@ def Eval(
931940
description: str | None = None,
932941
summarize_scores: bool = True,
933942
no_send_logs: bool = False,
934-
parameters: EvalParameters | None = None,
943+
parameters: EvalParameters | RemoteEvalParameters | None = None,
935944
on_start: Callable[[ExperimentSummary], None] | None = None,
936945
stream: Callable[[SSEProgressEvent], None] | None = None,
937946
parent: str | None = None,
@@ -1392,6 +1401,15 @@ def get_other_fields(s):
13921401
scorer_names = [_scorer_name(scorer, i) for i, scorer in enumerate(scorers)]
13931402
unhandled_scores = scorer_names
13941403

1404+
if evaluator.parameter_values is not None:
1405+
resolved_evaluator_parameters = evaluator.parameter_values
1406+
elif RemoteEvalParameters.is_parameters(evaluator.parameters):
1407+
resolved_evaluator_parameters = validate_parameters({}, evaluator.parameters)
1408+
elif is_eval_parameter_schema(evaluator.parameters):
1409+
resolved_evaluator_parameters = validate_parameters({}, evaluator.parameters)
1410+
else:
1411+
resolved_evaluator_parameters = evaluator.parameters
1412+
13951413
async def run_evaluator_task(datum, trial_index=0):
13961414
if isinstance(datum, dict):
13971415
datum = EvalCase.from_dict(datum)
@@ -1451,7 +1469,7 @@ def report_progress(event: TaskProgressEvent):
14511469
trial_index=trial_index,
14521470
tags=tags,
14531471
report_progress=report_progress,
1454-
parameters=evaluator.parameters,
1472+
parameters=resolved_evaluator_parameters,
14551473
)
14561474

14571475
# Check if the task takes a hooks argument

py/src/braintrust/framework2.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
SavedFunctionId,
1717
ToolFunctionDefinition,
1818
)
19+
from .parameters import EvalParameters, get_default_data_from_parameters_schema, parameters_to_json_schema
1920
from .util import eprint
2021

2122

@@ -40,6 +41,7 @@ class _GlobalState:
4041
def __init__(self):
4142
self.functions: list[CodeFunction] = []
4243
self.prompts: list[CodePrompt] = []
44+
self.parameters: list["CodeParameters"] = []
4345

4446

4547
global_ = _GlobalState()
@@ -116,6 +118,36 @@ def to_function_definition(self, if_exists: IfExists | None, project_ids: Projec
116118
return j
117119

118120

121+
@dataclasses.dataclass
class CodeParameters:
    """A saved-parameters definition attached to a project, registered via
    `ParametersBuilder` and later pushed to Braintrust as a `parameters`
    function."""

    project: "Project"
    name: str
    slug: str
    description: str | None
    schema: EvalParameters
    if_exists: IfExists | None
    metadata: dict[str, Any] | None = None

    def to_function_definition(self, if_exists: IfExists | None, project_ids: ProjectIdCache) -> dict[str, Any]:
        """Serialize into the payload shape expected by `insert-functions`.

        The per-object `if_exists` (set at creation time) takes precedence
        over the call-level default passed in here.
        """
        json_schema = parameters_to_json_schema(self.schema)
        effective_if_exists = self.if_exists if self.if_exists is not None else if_exists
        definition: dict[str, Any] = {
            "project_id": project_ids.get(self.project),
            "name": self.name,
            "slug": self.slug,
            "description": self.description or "",
            "function_type": "parameters",
            "function_data": {
                "type": "parameters",
                # Seed the function's data with the schema's declared defaults.
                "data": get_default_data_from_parameters_schema(json_schema),
                "__schema": json_schema,
            },
            "if_exists": effective_if_exists,
        }
        if self.metadata is not None:
            definition["metadata"] = self.metadata
        return definition
149+
150+
119151
class ToolBuilder:
120152
"""Builder to create a tool in Braintrust."""
121153

@@ -305,6 +337,38 @@ def create(
305337
return p
306338

307339

340+
class ParametersBuilder:
    """Builder to create saved parameters in Braintrust."""

    def __init__(self, project: "Project"):
        self.project = project

    def create(
        self,
        *,
        name: str,
        schema: EvalParameters,
        slug: str | None = None,
        description: str | None = None,
        if_exists: IfExists | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> EvalParameters:
        """Register a saved-parameters definition on this builder's project.

        If `slug` is omitted or empty, one is derived from `name`. Returns the
        schema unchanged so the same object can be passed directly to
        `Eval(..., parameters=...)`.
        """
        effective_slug = slug if slug else slugify.slugify(name)

        self.project.add_parameters(
            CodeParameters(
                project=self.project,
                name=name,
                slug=effective_slug,
                description=description,
                schema=schema,
                if_exists=if_exists,
                metadata=metadata,
            )
        )
        return schema
370+
371+
308372
class ScorerBuilder:
309373
"""Builder to create a scorer in Braintrust."""
310374

@@ -486,10 +550,12 @@ def __init__(self, name: str):
486550
self.name = name
487551
self.tools = ToolBuilder(self)
488552
self.prompts = PromptBuilder(self)
553+
self.parameters = ParametersBuilder(self)
489554
self.scorers = ScorerBuilder(self)
490555

491556
self._publishable_code_functions: list[CodeFunction] = []
492557
self._publishable_prompts: list[CodePrompt] = []
558+
self._publishable_parameters: list[CodeParameters] = []
493559

494560
def add_code_function(self, fn: CodeFunction):
495561
self._publishable_code_functions.append(fn)
@@ -501,6 +567,11 @@ def add_prompt(self, prompt: CodePrompt):
501567
if _is_lazy_load():
502568
global_.prompts.append(prompt)
503569

570+
def add_parameters(self, parameters: CodeParameters):
    """Queue saved parameters for publishing on this project; during lazy
    load (i.e. under `braintrust push`) also register them globally."""
    self._publishable_parameters.append(parameters)
    if not _is_lazy_load():
        return
    global_.parameters.append(parameters)
574+
504575
def publish(self):
505576
if _is_lazy_load():
506577
eprint(f"{bcolors.WARNING}publish() is a no-op when running `braintrust push`.{bcolors.ENDC}")
@@ -518,6 +589,8 @@ def publish(self):
518589
for prompt in self._publishable_prompts:
519590
prompt_definition = prompt.to_function_definition(None, project_id_cache)
520591
definitions.append(prompt_definition)
592+
for parameters in self._publishable_parameters:
593+
definitions.append(parameters.to_function_definition(None, project_id_cache))
521594
return api_conn().post_json("insert-functions", {"functions": definitions})
522595

523596

0 commit comments

Comments
 (0)