Skip to content

Commit f87ba31

Browse files
committed
feat: Add extraction of default and descriptions for fields in Pydantic models
1 parent fdd0202 commit f87ba31

2 files changed

Lines changed: 175 additions & 2 deletions

File tree

py/src/braintrust/parameters.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,26 @@ def _pydantic_to_json_schema(model: Any) -> dict[str, Any]:
3434
raise ValueError(f"Cannot convert {model} to JSON schema - not a pydantic model")
3535

3636

37+
def _extract_pydantic_fields(schema: dict[str, dict[str, Any]]) -> tuple[Any, Any] | tuple[dict[str, Any], dict[str, Any]]:
38+
"""Extract pydantic fields default and description metadata"""
39+
flatten_defaults = {}
40+
flatten_description = {}
41+
schema_items = schema.items()
42+
if len(schema_items) == 1:
43+
for _, field_metadata in schema_items:
44+
return (
45+
field_metadata.get("default"),
46+
field_metadata.get("description")
47+
)
48+
49+
for field_name, field_metadata in schema_items:
50+
flatten_defaults[field_name] = field_metadata.get("default")
51+
flatten_description[field_name] = field_metadata.get("description")
52+
return (
53+
flatten_defaults,
54+
flatten_description
55+
)
56+
3757
def validate_parameters(
3858
parameters: dict[str, Any],
3959
parameter_schema: EvalParameters,
@@ -143,10 +163,13 @@ def parameters_to_json_schema(parameters: EvalParameters) -> dict[str, Any]:
143163
else:
144164
# Pydantic model
145165
try:
166+
pydantic_schema = _pydantic_to_json_schema(schema)
167+
model_defaults, model_descriptions = _extract_pydantic_fields(pydantic_schema.get("properties", {}))
146168
result[name] = {
147169
"type": "data",
148-
"schema": _pydantic_to_json_schema(schema),
149-
# TODO: Extract default and description from pydantic model
170+
"schema": pydantic_schema,
171+
"default": model_defaults,
172+
"description": model_descriptions
150173
}
151174
except ValueError:
152175
# Not a pydantic model, skip
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
from braintrust.parameters import _extract_pydantic_fields, parameters_to_json_schema
5+
6+
DEFAULT_VALUE = "You are a helpful assistant."
7+
DEFAULT_DESCRIPTION = "System prompt for the model"
8+
SCHEMA_TITLE = "SystemPrompt"
9+
PARAM_NAME = "system_prompt"
10+
11+
def test_extract_single_field_with_default_and_description():
12+
schema = {
13+
"value": {"type": "string", "default": DEFAULT_VALUE, "description": DEFAULT_DESCRIPTION},
14+
}
15+
defaults, descriptions = _extract_pydantic_fields(schema)
16+
assert defaults == DEFAULT_VALUE
17+
assert descriptions == DEFAULT_DESCRIPTION
18+
19+
20+
def test_extract_single_field_missing_default_and_description():
21+
schema = {
22+
"value": {"type": "string"},
23+
}
24+
defaults, descriptions = _extract_pydantic_fields(schema)
25+
assert defaults is None
26+
assert descriptions is None
27+
28+
29+
def test_extract_multi_field():
30+
schema = {
31+
"temperature": {"type": "number", "default": 0.7, "description": "Sampling temperature"},
32+
"max_tokens": {"type": "integer", "default": 1024, "description": "Maximum tokens to generate"},
33+
}
34+
defaults, descriptions = _extract_pydantic_fields(schema)
35+
assert defaults == {"temperature": 0.7, "max_tokens": 1024}
36+
assert descriptions == {"temperature": "Sampling temperature", "max_tokens": "Maximum tokens to generate"}
37+
38+
39+
def test_extract_multi_field_partial_metadata():
40+
schema = {
41+
"temperature": {"type": "number", "default": 0.7},
42+
"max_tokens": {"type": "integer", "description": "Maximum tokens to generate"},
43+
}
44+
defaults, descriptions = _extract_pydantic_fields(schema)
45+
assert defaults == {"temperature": 0.7, "max_tokens": None}
46+
assert descriptions == {"temperature": None, "max_tokens": "Maximum tokens to generate"}
47+
48+
49+
def test_extract_empty_schema():
50+
defaults, descriptions = _extract_pydantic_fields({})
51+
assert defaults == {}
52+
assert descriptions == {}
53+
54+
55+
@pytest.fixture
56+
def v2_model():
57+
def _make(default=DEFAULT_VALUE, description=DEFAULT_DESCRIPTION):
58+
model = MagicMock()
59+
model.model_json_schema.return_value = {
60+
"title": SCHEMA_TITLE,
61+
"type": "object",
62+
"properties": {"value": {"type": "string", "default": default, "description": description}},
63+
}
64+
del model.get
65+
return model
66+
return _make
67+
68+
69+
@pytest.fixture
70+
def v1_model():
71+
def _make(default=DEFAULT_VALUE):
72+
model = MagicMock()
73+
del model.model_json_schema
74+
model.schema.return_value = {
75+
"title": SCHEMA_TITLE,
76+
"type": "object",
77+
"properties": {"value": {"type": "string", "default": default}},
78+
}
79+
del model.get
80+
return model
81+
return _make
82+
83+
84+
@pytest.fixture
85+
def v2_multi_field_model():
86+
def _make():
87+
model = MagicMock()
88+
model.model_json_schema.return_value = {
89+
"title": "ModelConfig",
90+
"type": "object",
91+
"properties": {
92+
"temperature": {"type": "number", "default": 0.7, "description": "Sampling temperature"},
93+
"max_tokens": {"type": "integer", "default": 1024, "description": "Maximum tokens to generate"},
94+
},
95+
}
96+
del model.get
97+
return model
98+
return _make
99+
100+
101+
@pytest.fixture
102+
def v1_multi_field_model():
103+
def _make():
104+
model = MagicMock()
105+
del model.model_json_schema
106+
model.schema.return_value = {
107+
"title": "ModelConfig",
108+
"type": "object",
109+
"properties": {
110+
"temperature": {"type": "number", "default": 0.7},
111+
"max_tokens": {"type": "integer", "default": 1024},
112+
},
113+
}
114+
del model.get
115+
return model
116+
return _make
117+
118+
119+
def test_pydantic_v2_model(v2_model):
120+
schema = parameters_to_json_schema({PARAM_NAME: v2_model()})
121+
122+
assert schema[PARAM_NAME]["type"] == "data"
123+
assert schema[PARAM_NAME]["default"] == DEFAULT_VALUE
124+
assert schema[PARAM_NAME]["description"] == DEFAULT_DESCRIPTION
125+
126+
127+
def test_pydantic_v1_model(v1_model):
128+
schema = parameters_to_json_schema({PARAM_NAME: v1_model()})
129+
130+
assert schema[PARAM_NAME]["type"] == "data"
131+
assert schema[PARAM_NAME]["default"] == DEFAULT_VALUE
132+
assert schema[PARAM_NAME]["description"] is None
133+
134+
135+
def test_pydantic_v2_multi_field_model(v2_multi_field_model):
136+
schema = parameters_to_json_schema({PARAM_NAME: v2_multi_field_model()})
137+
138+
assert schema[PARAM_NAME]["type"] == "data"
139+
assert schema[PARAM_NAME]["schema"]["title"] == "ModelConfig"
140+
assert schema[PARAM_NAME]["default"] == {"temperature": 0.7, "max_tokens": 1024}
141+
assert schema[PARAM_NAME]["description"] == {"temperature": "Sampling temperature", "max_tokens": "Maximum tokens to generate"}
142+
143+
144+
def test_pydantic_v1_multi_field_model(v1_multi_field_model):
145+
schema = parameters_to_json_schema({PARAM_NAME: v1_multi_field_model()})
146+
147+
assert schema[PARAM_NAME]["type"] == "data"
148+
assert schema[PARAM_NAME]["schema"]["title"] == "ModelConfig"
149+
assert schema[PARAM_NAME]["default"] == {"temperature": 0.7, "max_tokens": 1024}
150+
assert schema[PARAM_NAME]["description"] == {"temperature": None, "max_tokens": None}

0 commit comments

Comments
 (0)