class ClaudeSDKQueryPatcher(ClassReplacementPatcher):
    """Swap the exported ``claude_agent_sdk.query`` helper for a tracing wrapper.

    The replacement callable is produced by ``_create_query_wrapper_function``.
    Exported-function replacement is used here because the wrapper has to
    drive the whole span lifecycle across the one-shot async iterator, and
    because modules that imported ``query`` before setup must be updated too.
    """

    # Identifier of this patcher.
    name = "claude_agent_sdk.query"
    # Name of the exported attribute that gets replaced.
    target_attr = "query"
    # Factory that receives the original ``query`` and returns the wrapper;
    # ``staticmethod`` keeps it from being bound as a method.
    replacement_factory = staticmethod(_create_query_wrapper_function)
AssistantMessage(content=[TextBlock("hi")]) + yield ResultMessage() + + monkeypatch.setattr(claude_agent_sdk, "query", fake_query) + original_query = claude_agent_sdk.query + + consumer_module_name = "test_query_import_before_setup_module" + consumer_module = types.ModuleType(consumer_module_name) + consumer_module.query = original_query + monkeypatch.setitem(sys.modules, consumer_module_name, consumer_module) + + received_types = [] + + with _patched_claude_sdk(): + assert setup_claude_agent_sdk(project=PROJECT_NAME, api_key=logger.TEST_API_KEY) + assert getattr(consumer_module, "query") is not original_query + assert claude_agent_sdk.query is not original_query + + async for message in getattr(consumer_module, "query")(prompt="Say hi"): + received_types.append(type(message).__name__) + + assert "AssistantMessage" in received_types + assert received_types[-1] == "ResultMessage" + + spans = memory_logger.pop() + task_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.TASK] + llm_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.LLM] + + assert len(task_spans) == 1 + assert task_spans[0]["span_attributes"]["name"] == "Claude Agent" + assert task_spans[0]["input"] == "Say hi" + assert task_spans[0]["output"] is not None + assert len(llm_spans) == 1 + + @pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed") @pytest.mark.asyncio async def test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting(memory_logger): diff --git a/py/src/braintrust/integrations/claude_agent_sdk/tracing.py b/py/src/braintrust/integrations/claude_agent_sdk/tracing.py index d31c3551..4dea63d6 100644 --- a/py/src/braintrust/integrations/claude_agent_sdk/tracing.py +++ b/py/src/braintrust/integrations/claude_agent_sdk/tracing.py @@ -944,6 +944,80 @@ def _hook_parent_export(self, tool_use_id: str | None) -> str: return self._root_span.export() +def _prepare_prompt_for_tracing(prompt: Any) -> 
tuple[Any, str | None, list[dict[str, Any]] | None]: + if prompt is None: + return None, None, None + + if isinstance(prompt, str): + return prompt, prompt, None + + if isinstance(prompt, AsyncIterable): + captured: list[dict[str, Any]] = [] + + async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]: + async for msg in prompt: + captured.append(msg) + yield msg + + return capturing_wrapper(), None, captured + + return prompt, str(prompt), None + + +async def _stream_messages_with_tracing( + generator: AsyncIterable[Any], + *, + request_tracker: RequestTracker, + finish_request_tracker: Any, +) -> AsyncGenerator[Any, None]: + try: + async for message in generator: + request_tracker.add_message(message) + yield message + except asyncio.CancelledError: + # The CancelledError may come from the subprocess transport + # (e.g., anyio internal cleanup when subagents complete) rather + # than a genuine external cancellation. We suppress it here so + # the response stream ends cleanly. If the caller genuinely + # cancelled the task, they still have pending cancellation + # requests that will fire at their next await point. 
+ finish_request_tracker(log_output=True) + else: + finish_request_tracker(log_output=True) + finally: + finish_request_tracker() + + +def _create_query_wrapper_function(original_query: Any) -> Any: + """Create a tracing wrapper for the exported one-shot ``query()`` helper.""" + + async def wrapped_query(*args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: + query_start_time = time.time() + prompt = args[0] if args else kwargs.get("prompt") + prompt, traced_prompt, captured_messages = _prepare_prompt_for_tracing(prompt) + + if args: + args = (prompt,) + args[1:] + else: + kwargs = dict(kwargs) + kwargs["prompt"] = prompt + + request_tracker = RequestTracker( + prompt=traced_prompt, + query_start_time=query_start_time, + captured_messages=captured_messages, + ) + + async for message in _stream_messages_with_tracing( + original_query(*args, **kwargs), + request_tracker=request_tracker, + finish_request_tracker=request_tracker.finish, + ): + yield message + + return wrapped_query + + def _create_client_wrapper_class(original_client_class: Any) -> Any: """Creates a wrapper class for ClaudeSDKClient that wraps query and receive_response.""" @@ -1018,32 +1092,15 @@ async def query(self, *args: Any, **kwargs: Any) -> Any: """Wrap query to capture the prompt and start time for tracing.""" # Capture the time when query is called (when LLM call starts) self.__query_start_time = time.time() - self.__captured_messages = None # Capture the prompt for use in receive_response prompt = args[0] if args else kwargs.get("prompt") + prompt, self.__last_prompt, self.__captured_messages = _prepare_prompt_for_tracing(prompt) - if prompt is not None: - if isinstance(prompt, str): - self.__last_prompt = prompt - elif isinstance(prompt, AsyncIterable): - # AsyncIterable[dict] - wrap it to capture messages as they're yielded - captured: list[dict[str, Any]] = [] - self.__captured_messages = captured - self.__last_prompt = None # Will be set after messages are captured - - async def 
capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]: - async for msg in prompt: - captured.append(msg) - yield msg - - # Replace the prompt with our capturing wrapper - if args: - args = (capturing_wrapper(),) + args[1:] - else: - kwargs["prompt"] = capturing_wrapper() - else: - self.__last_prompt = str(prompt) + if args: + args = (prompt,) + args[1:] + else: + kwargs["prompt"] = prompt self.__instrument_hook_callbacks() self.__start_request_tracker() @@ -1061,23 +1118,12 @@ async def receive_response(self) -> AsyncGenerator[Any, None]: generator = self.__client.receive_response() request_tracker = self.__request_tracker or self.__start_request_tracker() - try: - async for message in generator: - request_tracker.add_message(message) - yield message - except asyncio.CancelledError: - # The CancelledError may come from the subprocess transport - # (e.g., anyio internal cleanup when subagents complete) rather - # than a genuine external cancellation. We suppress it here so - # the response stream ends cleanly. If the caller genuinely - # cancelled the task, they still have pending cancellation - # requests that will fire at their next await point. - self.__finish_request_tracker(log_output=True) - else: - self.__finish_request_tracker(log_output=True) - finally: - if self.__request_tracker is not None: - self.__finish_request_tracker() + async for message in _stream_messages_with_tracing( + generator, + request_tracker=request_tracker, + finish_request_tracker=self.__finish_request_tracker, + ): + yield message async def __aenter__(self) -> "WrappedClaudeSDKClient": await self.__client.__aenter__()