Skip to content

Commit bb9fe7f

Browse files
committed
chore: formatting
1 parent b07e75c commit bb9fe7f

File tree

3 files changed

+33
-12
lines changed

3 files changed

+33
-12
lines changed

py/src/braintrust/oai.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,12 @@ def _postprocess_streaming_results(cls, all_results: list[dict[str, Any]]) -> di
350350

351351

352352
class ResponseWrapper:
353-
def __init__(self, create_fn: Callable[..., Any] | None, acreate_fn: Callable[..., Any] | None, name: str = "openai.responses.create"):
353+
def __init__(
354+
self,
355+
create_fn: Callable[..., Any] | None,
356+
acreate_fn: Callable[..., Any] | None,
357+
name: str = "openai.responses.create",
358+
):
354359
self.create_fn = create_fn
355360
self.acreate_fn = acreate_fn
356361
self.name = name
@@ -359,9 +364,7 @@ def create(self, *args: Any, **kwargs: Any) -> Any:
359364
params = self._parse_params(kwargs)
360365
stream = kwargs.get("stream", False)
361366

362-
span = start_span(
363-
**merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params)
364-
)
367+
span = start_span(**merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params))
365368
should_end = True
366369

367370
try:
@@ -373,6 +376,7 @@ def create(self, *args: Any, **kwargs: Any) -> Any:
373376
else:
374377
raw_response = create_response
375378
if stream:
379+
376380
def gen():
377381
try:
378382
first = True
@@ -410,9 +414,7 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any:
410414
params = self._parse_params(kwargs)
411415
stream = kwargs.get("stream", False)
412416

413-
span = start_span(
414-
**merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params)
415-
)
417+
span = start_span(**merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params))
416418
should_end = True
417419

418420
try:
@@ -424,6 +426,7 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any:
424426
else:
425427
raw_response = create_response
426428
if stream:
429+
427430
async def gen():
428431
try:
429432
first = True
@@ -506,7 +509,12 @@ def _postprocess_streaming_results(cls, all_results: list[Any]) -> dict[str, Any
506509

507510
for result in all_results:
508511
usage = getattr(result, "usage", None)
509-
if not usage and hasattr(result, "type") and result.type == "response.completed" and hasattr(result, "response"):
512+
if (
513+
not usage
514+
and hasattr(result, "type")
515+
and result.type == "response.completed"
516+
and hasattr(result, "response")
517+
):
510518
# Handle summaries from completed response if present
511519
if hasattr(result.response, "output") and result.response.output:
512520
for output_item in result.response.output:
@@ -795,7 +803,9 @@ def create(self, *args: Any, **kwargs: Any) -> Any:
795803
return ResponseWrapper(self.__responses.with_raw_response.create, None).create(*args, **kwargs)
796804

797805
def parse(self, *args: Any, **kwargs: Any) -> Any:
798-
return ResponseWrapper(self.__responses.with_raw_response.parse, None, "openai.responses.parse").create(*args, **kwargs)
806+
return ResponseWrapper(self.__responses.with_raw_response.parse, None, "openai.responses.parse").create(
807+
*args, **kwargs
808+
)
799809

800810

801811
class AsyncResponsesV1Wrapper(NamedWrapper):
@@ -808,7 +818,9 @@ async def create(self, *args: Any, **kwargs: Any) -> Any:
808818
return AsyncResponseWrapper(response)
809819

810820
async def parse(self, *args: Any, **kwargs: Any) -> Any:
811-
response = await ResponseWrapper(None, self.__responses.with_raw_response.parse, "openai.responses.parse").acreate(*args, **kwargs)
821+
response = await ResponseWrapper(
822+
None, self.__responses.with_raw_response.parse, "openai.responses.parse"
823+
).acreate(*args, **kwargs)
812824
return AsyncResponseWrapper(response)
813825

814826

@@ -938,7 +950,6 @@ def _parse_metrics_from_usage(usage: Any) -> dict[str, Any]:
938950
return metrics
939951

940952

941-
942953
def prettify_params(params: dict[str, Any]) -> dict[str, Any]:
943954
# Filter out NOT_GIVEN parameters
944955
# https://linear.app/braintrustdata/issue/BRA-2467

py/src/braintrust/test_pydantic_parameters.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
SCHEMA_TITLE = "SystemPrompt"
99
PARAM_NAME = "system_prompt"
1010

11+
1112
def test_extract_single_field_with_default_and_description():
1213
schema = {
1314
"value": {"type": "string", "default": DEFAULT_VALUE, "description": DEFAULT_DESCRIPTION},
@@ -63,6 +64,7 @@ def _make(default=DEFAULT_VALUE, description=DEFAULT_DESCRIPTION):
6364
}
6465
del model.get
6566
return model
67+
6668
return _make
6769

6870

@@ -78,6 +80,7 @@ def _make(default=DEFAULT_VALUE):
7880
}
7981
del model.get
8082
return model
83+
8184
return _make
8285

8386

@@ -95,6 +98,7 @@ def _make():
9598
}
9699
del model.get
97100
return model
101+
98102
return _make
99103

100104

@@ -113,6 +117,7 @@ def _make():
113117
}
114118
del model.get
115119
return model
120+
116121
return _make
117122

118123

@@ -138,7 +143,10 @@ def test_pydantic_v2_multi_field_model(v2_multi_field_model):
138143
assert schema[PARAM_NAME]["type"] == "data"
139144
assert schema[PARAM_NAME]["schema"]["title"] == "ModelConfig"
140145
assert schema[PARAM_NAME]["default"] == {"temperature": 0.7, "max_tokens": 1024}
141-
assert schema[PARAM_NAME]["description"] == {"temperature": "Sampling temperature", "max_tokens": "Maximum tokens to generate"}
146+
assert schema[PARAM_NAME]["description"] == {
147+
"temperature": "Sampling temperature",
148+
"max_tokens": "Maximum tokens to generate",
149+
}
142150

143151

144152
def test_pydantic_v1_multi_field_model(v1_multi_field_model):

py/src/braintrust/wrappers/test_openai.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ def __init__(self, id="test_id", type="message"):
378378
# No spans should be generated from this unit test
379379
assert not memory_logger.pop()
380380

381+
381382
@pytest.mark.vcr
382383
def test_openai_embeddings(memory_logger):
383384
assert not memory_logger.pop()
@@ -1935,6 +1936,7 @@ def test_auto_instrument_openai(self):
19351936
"""Test auto_instrument patches OpenAI, creates spans, and uninstrument works."""
19361937
verify_autoinstrument_script("test_auto_openai.py")
19371938

1939+
19381940
class TestZAICompatibleOpenAI:
19391941
"""Tests for validating some ZAI compatibility with OpenAI wrapper."""
19401942

0 commit comments

Comments (0)