Skip to content

Commit a0e7d54

Browse files
authored
fix(claude_agent_sdk): trace top-level query calls (#216)
Patch the exported claude_agent_sdk.query helper so one-shot queries emit the same task and LLM spans as ClaudeSDKClient. This preserves tracing when callers import query before setup and use the standalone API directly. Refactor prompt capture and streamed message handling into shared helpers so the client and one-shot query paths stay aligned, and add a regression test that fails before the patch and verifies import-before-setup alias propagation.
1 parent b8eaf0e commit a0e7d54

4 files changed

Lines changed: 149 additions & 43 deletions

File tree

py/src/braintrust/integrations/claude_agent_sdk/integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from braintrust.integrations.base import BaseIntegration
44

5-
from .patchers import ClaudeSDKClientPatcher, SdkMcpToolPatcher
5+
from .patchers import ClaudeSDKClientPatcher, ClaudeSDKQueryPatcher, SdkMcpToolPatcher
66

77

88
class ClaudeAgentSDKIntegration(BaseIntegration):
@@ -11,4 +11,4 @@ class ClaudeAgentSDKIntegration(BaseIntegration):
1111
name = "claude_agent_sdk"
1212
import_names = ("claude_agent_sdk",)
1313
min_version = "0.1.10"
14-
patchers = (ClaudeSDKClientPatcher, SdkMcpToolPatcher)
14+
patchers = (ClaudeSDKClientPatcher, ClaudeSDKQueryPatcher, SdkMcpToolPatcher)

py/src/braintrust/integrations/claude_agent_sdk/patchers.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
"""Claude Agent SDK patchers — class-replacement patchers for ClaudeSDKClient and SdkMcpTool."""
1+
"""Claude Agent SDK patchers — replacement patchers for ClaudeSDKClient, query, and SdkMcpTool."""
22

33
from braintrust.integrations.base import ClassReplacementPatcher
44

5-
from .tracing import _create_client_wrapper_class, _create_tool_wrapper_class
5+
from .tracing import _create_client_wrapper_class, _create_query_wrapper_function, _create_tool_wrapper_class
66

77

88
class ClaudeSDKClientPatcher(ClassReplacementPatcher):
@@ -18,6 +18,19 @@ class ClaudeSDKClientPatcher(ClassReplacementPatcher):
1818
replacement_factory = staticmethod(_create_client_wrapper_class)
1919

2020

21+
class ClaudeSDKQueryPatcher(ClassReplacementPatcher):
22+
"""Replace exported ``claude_agent_sdk.query`` with a tracing wrapper.
23+
24+
This integration needs exported-function replacement because the wrapper
25+
drives the full span lifecycle across the one-shot async iterator and must
26+
update modules that imported ``query`` before setup.
27+
"""
28+
29+
name = "claude_agent_sdk.query"
30+
target_attr = "query"
31+
replacement_factory = staticmethod(_create_query_wrapper_function)
32+
33+
2134
class SdkMcpToolPatcher(ClassReplacementPatcher):
2235
"""Replace ``claude_agent_sdk.SdkMcpTool`` with a tracing wrapper class.
2336

py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def memory_logger():
6161
def _patched_claude_sdk(*, wrap_client: bool = False, wrap_tool_class: bool = False):
6262
original_client = claude_agent_sdk.ClaudeSDKClient
6363
original_tool_class = claude_agent_sdk.SdkMcpTool
64+
original_query = claude_agent_sdk.query
6465

6566
if wrap_client:
6667
claude_agent_sdk.ClaudeSDKClient = _create_client_wrapper_class(original_client)
@@ -72,6 +73,7 @@ def _patched_claude_sdk(*, wrap_client: bool = False, wrap_tool_class: bool = Fa
7273
finally:
7374
claude_agent_sdk.ClaudeSDKClient = original_client
7475
claude_agent_sdk.SdkMcpTool = original_tool_class
76+
claude_agent_sdk.query = original_query
7577

7678

7779
@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed")
@@ -2116,6 +2118,51 @@ async def main() -> None:
21162118
assert task_spans[0]["input"] == "Say hi"
21172119

21182120

2121+
@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed")
2122+
@pytest.mark.asyncio
2123+
async def test_setup_claude_agent_sdk_query_repro_import_before_setup(memory_logger, monkeypatch):
2124+
assert not memory_logger.pop()
2125+
2126+
async def fake_query(*, prompt, **kwargs):
2127+
del kwargs
2128+
if isinstance(prompt, AsyncIterable):
2129+
async for _ in prompt:
2130+
pass
2131+
yield AssistantMessage(content=[TextBlock("hi")])
2132+
yield ResultMessage()
2133+
2134+
monkeypatch.setattr(claude_agent_sdk, "query", fake_query)
2135+
original_query = claude_agent_sdk.query
2136+
2137+
consumer_module_name = "test_query_import_before_setup_module"
2138+
consumer_module = types.ModuleType(consumer_module_name)
2139+
consumer_module.query = original_query
2140+
monkeypatch.setitem(sys.modules, consumer_module_name, consumer_module)
2141+
2142+
received_types = []
2143+
2144+
with _patched_claude_sdk():
2145+
assert setup_claude_agent_sdk(project=PROJECT_NAME, api_key=logger.TEST_API_KEY)
2146+
assert getattr(consumer_module, "query") is not original_query
2147+
assert claude_agent_sdk.query is not original_query
2148+
2149+
async for message in getattr(consumer_module, "query")(prompt="Say hi"):
2150+
received_types.append(type(message).__name__)
2151+
2152+
assert "AssistantMessage" in received_types
2153+
assert received_types[-1] == "ResultMessage"
2154+
2155+
spans = memory_logger.pop()
2156+
task_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.TASK]
2157+
llm_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.LLM]
2158+
2159+
assert len(task_spans) == 1
2160+
assert task_spans[0]["span_attributes"]["name"] == "Claude Agent"
2161+
assert task_spans[0]["input"] == "Say hi"
2162+
assert task_spans[0]["output"] is not None
2163+
assert len(llm_spans) == 1
2164+
2165+
21192166
@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed")
21202167
@pytest.mark.asyncio
21212168
async def test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting(memory_logger):

py/src/braintrust/integrations/claude_agent_sdk/tracing.py

Lines changed: 85 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,80 @@ def _hook_parent_export(self, tool_use_id: str | None) -> str:
944944
return self._root_span.export()
945945

946946

947+
def _prepare_prompt_for_tracing(prompt: Any) -> tuple[Any, str | None, list[dict[str, Any]] | None]:
948+
if prompt is None:
949+
return None, None, None
950+
951+
if isinstance(prompt, str):
952+
return prompt, prompt, None
953+
954+
if isinstance(prompt, AsyncIterable):
955+
captured: list[dict[str, Any]] = []
956+
957+
async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]:
958+
async for msg in prompt:
959+
captured.append(msg)
960+
yield msg
961+
962+
return capturing_wrapper(), None, captured
963+
964+
return prompt, str(prompt), None
965+
966+
967+
async def _stream_messages_with_tracing(
968+
generator: AsyncIterable[Any],
969+
*,
970+
request_tracker: RequestTracker,
971+
finish_request_tracker: Any,
972+
) -> AsyncGenerator[Any, None]:
973+
try:
974+
async for message in generator:
975+
request_tracker.add_message(message)
976+
yield message
977+
except asyncio.CancelledError:
978+
# The CancelledError may come from the subprocess transport
979+
# (e.g., anyio internal cleanup when subagents complete) rather
980+
# than a genuine external cancellation. We suppress it here so
981+
# the response stream ends cleanly. If the caller genuinely
982+
# cancelled the task, they still have pending cancellation
983+
# requests that will fire at their next await point.
984+
finish_request_tracker(log_output=True)
985+
else:
986+
finish_request_tracker(log_output=True)
987+
finally:
988+
finish_request_tracker()
989+
990+
991+
def _create_query_wrapper_function(original_query: Any) -> Any:
992+
"""Create a tracing wrapper for the exported one-shot ``query()`` helper."""
993+
994+
async def wrapped_query(*args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
995+
query_start_time = time.time()
996+
prompt = args[0] if args else kwargs.get("prompt")
997+
prompt, traced_prompt, captured_messages = _prepare_prompt_for_tracing(prompt)
998+
999+
if args:
1000+
args = (prompt,) + args[1:]
1001+
else:
1002+
kwargs = dict(kwargs)
1003+
kwargs["prompt"] = prompt
1004+
1005+
request_tracker = RequestTracker(
1006+
prompt=traced_prompt,
1007+
query_start_time=query_start_time,
1008+
captured_messages=captured_messages,
1009+
)
1010+
1011+
async for message in _stream_messages_with_tracing(
1012+
original_query(*args, **kwargs),
1013+
request_tracker=request_tracker,
1014+
finish_request_tracker=request_tracker.finish,
1015+
):
1016+
yield message
1017+
1018+
return wrapped_query
1019+
1020+
9471021
def _create_client_wrapper_class(original_client_class: Any) -> Any:
9481022
"""Creates a wrapper class for ClaudeSDKClient that wraps query and receive_response."""
9491023

@@ -1018,32 +1092,15 @@ async def query(self, *args: Any, **kwargs: Any) -> Any:
10181092
"""Wrap query to capture the prompt and start time for tracing."""
10191093
# Capture the time when query is called (when LLM call starts)
10201094
self.__query_start_time = time.time()
1021-
self.__captured_messages = None
10221095

10231096
# Capture the prompt for use in receive_response
10241097
prompt = args[0] if args else kwargs.get("prompt")
1098+
prompt, self.__last_prompt, self.__captured_messages = _prepare_prompt_for_tracing(prompt)
10251099

1026-
if prompt is not None:
1027-
if isinstance(prompt, str):
1028-
self.__last_prompt = prompt
1029-
elif isinstance(prompt, AsyncIterable):
1030-
# AsyncIterable[dict] - wrap it to capture messages as they're yielded
1031-
captured: list[dict[str, Any]] = []
1032-
self.__captured_messages = captured
1033-
self.__last_prompt = None # Will be set after messages are captured
1034-
1035-
async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]:
1036-
async for msg in prompt:
1037-
captured.append(msg)
1038-
yield msg
1039-
1040-
# Replace the prompt with our capturing wrapper
1041-
if args:
1042-
args = (capturing_wrapper(),) + args[1:]
1043-
else:
1044-
kwargs["prompt"] = capturing_wrapper()
1045-
else:
1046-
self.__last_prompt = str(prompt)
1100+
if args:
1101+
args = (prompt,) + args[1:]
1102+
else:
1103+
kwargs["prompt"] = prompt
10471104

10481105
self.__instrument_hook_callbacks()
10491106
self.__start_request_tracker()
@@ -1061,23 +1118,12 @@ async def receive_response(self) -> AsyncGenerator[Any, None]:
10611118
generator = self.__client.receive_response()
10621119
request_tracker = self.__request_tracker or self.__start_request_tracker()
10631120

1064-
try:
1065-
async for message in generator:
1066-
request_tracker.add_message(message)
1067-
yield message
1068-
except asyncio.CancelledError:
1069-
# The CancelledError may come from the subprocess transport
1070-
# (e.g., anyio internal cleanup when subagents complete) rather
1071-
# than a genuine external cancellation. We suppress it here so
1072-
# the response stream ends cleanly. If the caller genuinely
1073-
# cancelled the task, they still have pending cancellation
1074-
# requests that will fire at their next await point.
1075-
self.__finish_request_tracker(log_output=True)
1076-
else:
1077-
self.__finish_request_tracker(log_output=True)
1078-
finally:
1079-
if self.__request_tracker is not None:
1080-
self.__finish_request_tracker()
1121+
async for message in _stream_messages_with_tracing(
1122+
generator,
1123+
request_tracker=request_tracker,
1124+
finish_request_tracker=self.__finish_request_tracker,
1125+
):
1126+
yield message
10811127

10821128
async def __aenter__(self) -> "WrappedClaudeSDKClient":
10831129
await self.__client.__aenter__()

0 commit comments

Comments
 (0)