Skip to content

Commit e5a97bc

Browse files
committed
feat: Add extraction of default and descriptions for fields in Pydantic models
1 parent fdd0202 commit e5a97bc

File tree

2 files changed

+200
-2
lines changed

2 files changed

+200
-2
lines changed

py/src/braintrust/parameters.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,20 @@ def _pydantic_to_json_schema(model: Any) -> dict[str, Any]:
3434
raise ValueError(f"Cannot convert {model} to JSON schema - not a pydantic model")
3535

3636

37+
def _extract_pydantic_fields(schema: dict[str, dict[str, str]]) -> dict[str, list[dict[str, str]]]:
38+
"""Extract pydantic fields default and description metadata"""
39+
flatten_defaults = []
40+
flatten_description = []
41+
for field_name, field_metadata in schema.items():
42+
flatten_defaults.append({
43+
field_name: field_metadata.get("default")
44+
})
45+
flatten_description.append({field_name: field_metadata.get("description")})
46+
return {
47+
"default": flatten_defaults,
48+
"description": flatten_description
49+
}
50+
3751
def validate_parameters(
3852
parameters: dict[str, Any],
3953
parameter_schema: EvalParameters,
@@ -143,10 +157,13 @@ def parameters_to_json_schema(parameters: EvalParameters) -> dict[str, Any]:
143157
else:
144158
# Pydantic model
145159
try:
160+
pydantic_schema = _pydantic_to_json_schema(schema)
161+
parameter_values = _extract_pydantic_fields(pydantic_schema.get("properties", {}))
146162
result[name] = {
147163
"type": "data",
148-
"schema": _pydantic_to_json_schema(schema),
149-
# TODO: Extract default and description from pydantic model
164+
"schema": pydantic_schema,
165+
"default": parameter_values.get("default"),
166+
"description": parameter_values.get("description")
150167
}
151168
except ValueError:
152169
# Not a pydantic model, skip
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
from braintrust.parameters import _extract_pydantic_fields, parameters_to_json_schema
5+
6+
DEFAULT_VALUE = "You are a helpful assistant."
7+
DEFAULT_DESCRIPTION = "System prompt for the model"
8+
SCHEMA_TITLE = "SystemPrompt"
9+
PARAM_NAME = "system_prompt"
10+
11+
def test_extract_single_field_with_default_and_description():
12+
schema = {
13+
"value": {"type": "string", "default": DEFAULT_VALUE, "description": DEFAULT_DESCRIPTION},
14+
}
15+
result = _extract_pydantic_fields(schema)
16+
assert result == {
17+
"default": [{"value": DEFAULT_VALUE}],
18+
"description": [{"value": DEFAULT_DESCRIPTION}],
19+
}
20+
21+
22+
def test_extract_single_field_missing_default_and_description():
23+
schema = {
24+
"value": {"type": "string"},
25+
}
26+
result = _extract_pydantic_fields(schema)
27+
assert result == {
28+
"default": [{"value": None}],
29+
"description": [{"value": None}],
30+
}
31+
32+
33+
def test_extract_multi_field():
34+
schema = {
35+
"temperature": {"type": "number", "default": 0.7, "description": "Sampling temperature"},
36+
"max_tokens": {"type": "integer", "default": 1024, "description": "Maximum tokens to generate"},
37+
}
38+
result = _extract_pydantic_fields(schema)
39+
assert result == {
40+
"default": [
41+
{"temperature": 0.7},
42+
{"max_tokens": 1024},
43+
],
44+
"description": [
45+
{"temperature": "Sampling temperature"},
46+
{"max_tokens": "Maximum tokens to generate"},
47+
],
48+
}
49+
50+
51+
def test_extract_multi_field_partial_metadata():
52+
schema = {
53+
"temperature": {"type": "number", "default": 0.7},
54+
"max_tokens": {"type": "integer", "description": "Maximum tokens to generate"},
55+
}
56+
result = _extract_pydantic_fields(schema)
57+
assert result == {
58+
"default": [
59+
{"temperature": 0.7},
60+
{"max_tokens": None},
61+
],
62+
"description": [
63+
{"temperature": None},
64+
{"max_tokens": "Maximum tokens to generate"},
65+
],
66+
}
67+
68+
69+
def test_extract_empty_schema():
70+
result = _extract_pydantic_fields({})
71+
assert result == {"default": [], "description": []}
72+
73+
74+
@pytest.fixture
75+
def v2_model():
76+
def _make(default=DEFAULT_VALUE, description=DEFAULT_DESCRIPTION):
77+
model = MagicMock()
78+
model.model_json_schema.return_value = {
79+
"title": SCHEMA_TITLE,
80+
"type": "object",
81+
"properties": {"value": {"type": "string", "default": default, "description": description}},
82+
}
83+
del model.get
84+
return model
85+
return _make
86+
87+
88+
@pytest.fixture
89+
def v1_model():
90+
def _make(default=DEFAULT_VALUE):
91+
model = MagicMock()
92+
del model.model_json_schema
93+
model.schema.return_value = {
94+
"title": SCHEMA_TITLE,
95+
"type": "object",
96+
"properties": {"value": {"type": "string", "default": default}},
97+
}
98+
del model.get
99+
return model
100+
return _make
101+
102+
103+
@pytest.fixture
104+
def v2_multi_field_model():
105+
def _make():
106+
model = MagicMock()
107+
model.model_json_schema.return_value = {
108+
"title": "ModelConfig",
109+
"type": "object",
110+
"properties": {
111+
"temperature": {"type": "number", "default": 0.7, "description": "Sampling temperature"},
112+
"max_tokens": {"type": "integer", "default": 1024, "description": "Maximum tokens to generate"},
113+
},
114+
}
115+
del model.get
116+
return model
117+
return _make
118+
119+
120+
@pytest.fixture
121+
def v1_multi_field_model():
122+
def _make():
123+
model = MagicMock()
124+
del model.model_json_schema
125+
model.schema.return_value = {
126+
"title": "ModelConfig",
127+
"type": "object",
128+
"properties": {
129+
"temperature": {"type": "number", "default": 0.7},
130+
"max_tokens": {"type": "integer", "default": 1024},
131+
},
132+
}
133+
del model.get
134+
return model
135+
return _make
136+
137+
138+
def test_pydantic_v2_model(v2_model):
139+
schema = parameters_to_json_schema({PARAM_NAME: v2_model()})
140+
141+
assert schema[PARAM_NAME]["type"] == "data"
142+
assert schema[PARAM_NAME]["default"] == [{"value": DEFAULT_VALUE}]
143+
assert schema[PARAM_NAME]["description"] == [{"value": DEFAULT_DESCRIPTION}]
144+
145+
146+
def test_pydantic_v1_model(v1_model):
147+
schema = parameters_to_json_schema({PARAM_NAME: v1_model()})
148+
149+
assert schema[PARAM_NAME]["type"] == "data"
150+
assert schema[PARAM_NAME]["default"] == [{"value": DEFAULT_VALUE}]
151+
assert schema[PARAM_NAME]["description"] == [{"value": None}]
152+
153+
154+
def test_pydantic_v2_multi_field_model(v2_multi_field_model):
155+
schema = parameters_to_json_schema({PARAM_NAME: v2_multi_field_model()})
156+
157+
assert schema[PARAM_NAME]["type"] == "data"
158+
assert schema[PARAM_NAME]["schema"]["title"] == "ModelConfig"
159+
assert schema[PARAM_NAME]["default"] == [
160+
{"temperature": 0.7},
161+
{"max_tokens": 1024},
162+
]
163+
assert schema[PARAM_NAME]["description"] == [
164+
{"temperature": "Sampling temperature"},
165+
{"max_tokens": "Maximum tokens to generate"},
166+
]
167+
168+
169+
def test_pydantic_v1_multi_field_model(v1_multi_field_model):
170+
schema = parameters_to_json_schema({PARAM_NAME: v1_multi_field_model()})
171+
172+
assert schema[PARAM_NAME]["type"] == "data"
173+
assert schema[PARAM_NAME]["schema"]["title"] == "ModelConfig"
174+
assert schema[PARAM_NAME]["default"] == [
175+
{"temperature": 0.7},
176+
{"max_tokens": 1024},
177+
]
178+
assert schema[PARAM_NAME]["description"] == [
179+
{"temperature": None},
180+
{"max_tokens": None},
181+
]

0 commit comments

Comments
 (0)