From 2ca243041acbb8cf474e59d1ee6c66ceebd536f4 Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Tue, 7 Apr 2026 13:31:44 -0400 Subject: [PATCH] fix(claude_agent_sdk): trace top-level query calls Patch the exported claude_agent_sdk.query helper so one-shot queries emit the same task and LLM spans as ClaudeSDKClient. This preserves tracing when callers import query before setup and use the standalone API directly. Refactor prompt capture and streamed message handling into shared helpers so the client and one-shot query paths stay aligned, and add a regression test that fails before the patch and verifies import-before-setup alias propagation. --- .../claude_agent_sdk/integration.py | 4 +- .../integrations/claude_agent_sdk/patchers.py | 17 ++- .../claude_agent_sdk/test_claude_agent_sdk.py | 47 +++++++ .../integrations/claude_agent_sdk/tracing.py | 124 ++++++++++++------ 4 files changed, 149 insertions(+), 43 deletions(-) diff --git a/py/src/braintrust/integrations/claude_agent_sdk/integration.py b/py/src/braintrust/integrations/claude_agent_sdk/integration.py index 3f1a3b00..f6a99446 100644 --- a/py/src/braintrust/integrations/claude_agent_sdk/integration.py +++ b/py/src/braintrust/integrations/claude_agent_sdk/integration.py @@ -2,7 +2,7 @@ from braintrust.integrations.base import BaseIntegration -from .patchers import ClaudeSDKClientPatcher, SdkMcpToolPatcher +from .patchers import ClaudeSDKClientPatcher, ClaudeSDKQueryPatcher, SdkMcpToolPatcher class ClaudeAgentSDKIntegration(BaseIntegration): @@ -11,4 +11,4 @@ class ClaudeAgentSDKIntegration(BaseIntegration): name = "claude_agent_sdk" import_names = ("claude_agent_sdk",) min_version = "0.1.10" - patchers = (ClaudeSDKClientPatcher, SdkMcpToolPatcher) + patchers = (ClaudeSDKClientPatcher, ClaudeSDKQueryPatcher, SdkMcpToolPatcher) diff --git a/py/src/braintrust/integrations/claude_agent_sdk/patchers.py b/py/src/braintrust/integrations/claude_agent_sdk/patchers.py index 5f288d53..f5659584 100644 --- a/py/src/braintrust/integrations/claude_agent_sdk/patchers.py +++ b/py/src/braintrust/integrations/claude_agent_sdk/patchers.py @@ -1,8 +1,8 @@ -"""Claude Agent SDK patchers — class-replacement patchers for ClaudeSDKClient and SdkMcpTool.""" +"""Claude Agent SDK patchers — replacement patchers for ClaudeSDKClient, query, and SdkMcpTool.""" from braintrust.integrations.base import ClassReplacementPatcher -from .tracing import _create_client_wrapper_class, _create_tool_wrapper_class +from .tracing import _create_client_wrapper_class, _create_query_wrapper_function, _create_tool_wrapper_class class ClaudeSDKClientPatcher(ClassReplacementPatcher): @@ -18,6 +18,19 @@ class ClaudeSDKClientPatcher(ClassReplacementPatcher): replacement_factory = staticmethod(_create_client_wrapper_class) +class ClaudeSDKQueryPatcher(ClassReplacementPatcher): + """Replace exported ``claude_agent_sdk.query`` with a tracing wrapper. + + This integration needs exported-function replacement because the wrapper + drives the full span lifecycle across the one-shot async iterator and must + update modules that imported ``query`` before setup. + """ + + name = "claude_agent_sdk.query" + target_attr = "query" + replacement_factory = staticmethod(_create_query_wrapper_function) + + class SdkMcpToolPatcher(ClassReplacementPatcher): """Replace ``claude_agent_sdk.SdkMcpTool`` with a tracing wrapper class. diff --git a/py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py b/py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py index ff71a13a..6fe29edd 100644 --- a/py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py +++ b/py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py @@ -61,6 +61,7 @@ def memory_logger(): def _patched_claude_sdk(*, wrap_client: bool = False, wrap_tool_class: bool = False): original_client = claude_agent_sdk.ClaudeSDKClient original_tool_class = claude_agent_sdk.SdkMcpTool + original_query = claude_agent_sdk.query if wrap_client: claude_agent_sdk.ClaudeSDKClient = _create_client_wrapper_class(original_client) @@ -72,6 +73,7 @@ def _patched_claude_sdk(*, wrap_client: bool = False, wrap_tool_class: bool = Fa finally: claude_agent_sdk.ClaudeSDKClient = original_client claude_agent_sdk.SdkMcpTool = original_tool_class + claude_agent_sdk.query = original_query @pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed") @@ -2116,6 +2118,51 @@ async def main() -> None: assert task_spans[0]["input"] == "Say hi" +@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed") +@pytest.mark.asyncio +async def test_setup_claude_agent_sdk_query_repro_import_before_setup(memory_logger, monkeypatch): + assert not memory_logger.pop() + + async def fake_query(*, prompt, **kwargs): + del kwargs + if isinstance(prompt, AsyncIterable): + async for _ in prompt: + pass + yield AssistantMessage(content=[TextBlock("hi")]) + yield ResultMessage() + + monkeypatch.setattr(claude_agent_sdk, "query", fake_query) + original_query = claude_agent_sdk.query + + consumer_module_name = "test_query_import_before_setup_module" + consumer_module = types.ModuleType(consumer_module_name) + consumer_module.query = original_query + monkeypatch.setitem(sys.modules, consumer_module_name, consumer_module) + + received_types = [] + + with _patched_claude_sdk(): + assert setup_claude_agent_sdk(project=PROJECT_NAME, api_key=logger.TEST_API_KEY) + assert getattr(consumer_module, "query") is not original_query + assert claude_agent_sdk.query is not original_query + + async for message in getattr(consumer_module, "query")(prompt="Say hi"): + received_types.append(type(message).__name__) + + assert "AssistantMessage" in received_types + assert received_types[-1] == "ResultMessage" + + spans = memory_logger.pop() + task_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.TASK] + llm_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.LLM] + + assert len(task_spans) == 1 + assert task_spans[0]["span_attributes"]["name"] == "Claude Agent" + assert task_spans[0]["input"] == "Say hi" + assert task_spans[0]["output"] is not None + assert len(llm_spans) == 1 + + @pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed") @pytest.mark.asyncio async def test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting(memory_logger): diff --git a/py/src/braintrust/integrations/claude_agent_sdk/tracing.py b/py/src/braintrust/integrations/claude_agent_sdk/tracing.py index d31c3551..4dea63d6 100644 --- a/py/src/braintrust/integrations/claude_agent_sdk/tracing.py +++ b/py/src/braintrust/integrations/claude_agent_sdk/tracing.py @@ -944,6 +944,80 @@ def _hook_parent_export(self, tool_use_id: str | None) -> str: return self._root_span.export() +def _prepare_prompt_for_tracing(prompt: Any) -> tuple[Any, str | None, list[dict[str, Any]] | None]: + if prompt is None: + return None, None, None + + if isinstance(prompt, str): + return prompt, prompt, None + + if isinstance(prompt, AsyncIterable): + captured: list[dict[str, Any]] = [] + + async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]: + async for msg in prompt: + captured.append(msg) + yield msg + + return capturing_wrapper(), None, captured + + return prompt, str(prompt), None + + +async def _stream_messages_with_tracing( + generator: AsyncIterable[Any], + *, + request_tracker: RequestTracker, + finish_request_tracker: Any, +) -> AsyncGenerator[Any, None]: + try: + async for message in generator: + request_tracker.add_message(message) + yield message + except asyncio.CancelledError: + # The CancelledError may come from the subprocess transport + # (e.g., anyio internal cleanup when subagents complete) rather + # than a genuine external cancellation. We suppress it here so + # the response stream ends cleanly. If the caller genuinely + # cancelled the task, they still have pending cancellation + # requests that will fire at their next await point. + finish_request_tracker(log_output=True) + else: + finish_request_tracker(log_output=True) + finally: + finish_request_tracker() + + +def _create_query_wrapper_function(original_query: Any) -> Any: + """Create a tracing wrapper for the exported one-shot ``query()`` helper.""" + + async def wrapped_query(*args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: + query_start_time = time.time() + prompt = args[0] if args else kwargs.get("prompt") + prompt, traced_prompt, captured_messages = _prepare_prompt_for_tracing(prompt) + + if args: + args = (prompt,) + args[1:] + else: + kwargs = dict(kwargs) + kwargs["prompt"] = prompt + + request_tracker = RequestTracker( + prompt=traced_prompt, + query_start_time=query_start_time, + captured_messages=captured_messages, + ) + + async for message in _stream_messages_with_tracing( + original_query(*args, **kwargs), + request_tracker=request_tracker, + finish_request_tracker=request_tracker.finish, + ): + yield message + + return wrapped_query + + def _create_client_wrapper_class(original_client_class: Any) -> Any: """Creates a wrapper class for ClaudeSDKClient that wraps query and receive_response.""" @@ -1018,32 +1092,15 @@ async def query(self, *args: Any, **kwargs: Any) -> Any: """Wrap query to capture the prompt and start time for tracing.""" # Capture the time when query is called (when LLM call starts) self.__query_start_time = time.time() - self.__captured_messages = None # Capture the prompt for use in receive_response prompt = args[0] if args else kwargs.get("prompt") + prompt, self.__last_prompt, self.__captured_messages = _prepare_prompt_for_tracing(prompt) - if prompt is not None: - if isinstance(prompt, str): - self.__last_prompt = prompt - elif isinstance(prompt, AsyncIterable): - # AsyncIterable[dict] - wrap it to capture messages as they're yielded - captured: list[dict[str, Any]] = [] - self.__captured_messages = captured - self.__last_prompt = None # Will be set after messages are captured - - async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]: - async for msg in prompt: - captured.append(msg) - yield msg - - # Replace the prompt with our capturing wrapper - if args: - args = (capturing_wrapper(),) + args[1:] - else: - kwargs["prompt"] = capturing_wrapper() - else: - self.__last_prompt = str(prompt) + if args: + args = (prompt,) + args[1:] + else: + kwargs["prompt"] = prompt self.__instrument_hook_callbacks() self.__start_request_tracker() @@ -1061,23 +1118,12 @@ async def receive_response(self) -> AsyncGenerator[Any, None]: generator = self.__client.receive_response() request_tracker = self.__request_tracker or self.__start_request_tracker() - try: - async for message in generator: - request_tracker.add_message(message) - yield message - except asyncio.CancelledError: - # The CancelledError may come from the subprocess transport - # (e.g., anyio internal cleanup when subagents complete) rather - # than a genuine external cancellation. We suppress it here so - # the response stream ends cleanly. If the caller genuinely - # cancelled the task, they still have pending cancellation - # requests that will fire at their next await point. - self.__finish_request_tracker(log_output=True) - else: - self.__finish_request_tracker(log_output=True) - finally: - if self.__request_tracker is not None: - self.__finish_request_tracker() + async for message in _stream_messages_with_tracing( + generator, + request_tracker=request_tracker, + finish_request_tracker=self.__finish_request_tracker, + ): + yield message async def __aenter__(self) -> "WrappedClaudeSDKClient": await self.__client.__aenter__()