Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from braintrust.integrations.base import BaseIntegration

from .patchers import ClaudeSDKClientPatcher, SdkMcpToolPatcher
from .patchers import ClaudeSDKClientPatcher, ClaudeSDKQueryPatcher, SdkMcpToolPatcher


class ClaudeAgentSDKIntegration(BaseIntegration):
Expand All @@ -11,4 +11,4 @@ class ClaudeAgentSDKIntegration(BaseIntegration):
name = "claude_agent_sdk"
import_names = ("claude_agent_sdk",)
min_version = "0.1.10"
patchers = (ClaudeSDKClientPatcher, SdkMcpToolPatcher)
patchers = (ClaudeSDKClientPatcher, ClaudeSDKQueryPatcher, SdkMcpToolPatcher)
17 changes: 15 additions & 2 deletions py/src/braintrust/integrations/claude_agent_sdk/patchers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Claude Agent SDK patchers — class-replacement patchers for ClaudeSDKClient and SdkMcpTool."""
"""Claude Agent SDK patchers — replacement patchers for ClaudeSDKClient, query, and SdkMcpTool."""

from braintrust.integrations.base import ClassReplacementPatcher

from .tracing import _create_client_wrapper_class, _create_tool_wrapper_class
from .tracing import _create_client_wrapper_class, _create_query_wrapper_function, _create_tool_wrapper_class


class ClaudeSDKClientPatcher(ClassReplacementPatcher):
Expand All @@ -18,6 +18,19 @@ class ClaudeSDKClientPatcher(ClassReplacementPatcher):
replacement_factory = staticmethod(_create_client_wrapper_class)


class ClaudeSDKQueryPatcher(ClassReplacementPatcher):
    """Swap the exported ``claude_agent_sdk.query`` function for a tracing wrapper.

    ``query`` is a one-shot async iterator and the wrapper drives the full
    span lifecycle across it, so this integration relies on
    exported-function replacement: modules that imported ``query`` before
    setup must be updated as well.
    """

    # Factory that receives the original ``query`` and returns its wrapper.
    replacement_factory = staticmethod(_create_query_wrapper_function)
    # Attribute on the ``claude_agent_sdk`` module that gets replaced.
    target_attr = "query"
    name = "claude_agent_sdk.query"


class SdkMcpToolPatcher(ClassReplacementPatcher):
"""Replace ``claude_agent_sdk.SdkMcpTool`` with a tracing wrapper class.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def memory_logger():
def _patched_claude_sdk(*, wrap_client: bool = False, wrap_tool_class: bool = False):
original_client = claude_agent_sdk.ClaudeSDKClient
original_tool_class = claude_agent_sdk.SdkMcpTool
original_query = claude_agent_sdk.query

if wrap_client:
claude_agent_sdk.ClaudeSDKClient = _create_client_wrapper_class(original_client)
Expand All @@ -72,6 +73,7 @@ def _patched_claude_sdk(*, wrap_client: bool = False, wrap_tool_class: bool = Fa
finally:
claude_agent_sdk.ClaudeSDKClient = original_client
claude_agent_sdk.SdkMcpTool = original_tool_class
claude_agent_sdk.query = original_query


@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed")
Expand Down Expand Up @@ -2116,6 +2118,51 @@ async def main() -> None:
assert task_spans[0]["input"] == "Say hi"


@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed")
@pytest.mark.asyncio
async def test_setup_claude_agent_sdk_query_repro_import_before_setup(memory_logger, monkeypatch):
    """Check that setup also wraps a ``query`` reference imported before setup.

    A throwaway module holding a pre-setup reference to ``query`` is
    registered in ``sys.modules``. After ``setup_claude_agent_sdk`` runs, both
    that module's ``query`` and ``claude_agent_sdk.query`` must have been
    rebound, and calling the rebound function must emit one task span and one
    LLM span.
    """
    # Start from an empty log so the span assertions only see this test's spans.
    assert not memory_logger.pop()

    async def fake_query(*, prompt, **kwargs):
        # Stand-in for the SDK's ``query``: drains a streaming prompt, then
        # yields one assistant message followed by a result message.
        del kwargs
        if isinstance(prompt, AsyncIterable):
            async for _ in prompt:
                pass
        yield AssistantMessage(content=[TextBlock("hi")])
        yield ResultMessage()

    monkeypatch.setattr(claude_agent_sdk, "query", fake_query)
    original_query = claude_agent_sdk.query

    # Build a module that already holds a reference to ``query`` before setup
    # runs, as ``from claude_agent_sdk import query`` would leave behind.
    consumer_module_name = "test_query_import_before_setup_module"
    consumer_module = types.ModuleType(consumer_module_name)
    consumer_module.query = original_query
    monkeypatch.setitem(sys.modules, consumer_module_name, consumer_module)

    received_types = []

    with _patched_claude_sdk():
        assert setup_claude_agent_sdk(project=PROJECT_NAME, api_key=logger.TEST_API_KEY)
        # Both the consumer's early reference and the SDK export must be rebound.
        assert getattr(consumer_module, "query") is not original_query
        assert claude_agent_sdk.query is not original_query

        # Call through the consumer's reference, not the SDK attribute.
        async for message in getattr(consumer_module, "query")(prompt="Say hi"):
            received_types.append(type(message).__name__)

    # The wrapper must pass messages through unchanged and in order.
    assert "AssistantMessage" in received_types
    assert received_types[-1] == "ResultMessage"

    spans = memory_logger.pop()
    task_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.TASK]
    llm_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.LLM]

    assert len(task_spans) == 1
    assert task_spans[0]["span_attributes"]["name"] == "Claude Agent"
    assert task_spans[0]["input"] == "Say hi"
    assert task_spans[0]["output"] is not None
    assert len(llm_spans) == 1


@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed")
@pytest.mark.asyncio
async def test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting(memory_logger):
Expand Down
124 changes: 85 additions & 39 deletions py/src/braintrust/integrations/claude_agent_sdk/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,80 @@ def _hook_parent_export(self, tool_use_id: str | None) -> str:
return self._root_span.export()


def _prepare_prompt_for_tracing(prompt: Any) -> tuple[Any, str | None, list[dict[str, Any]] | None]:
if prompt is None:
return None, None, None

if isinstance(prompt, str):
return prompt, prompt, None

if isinstance(prompt, AsyncIterable):
captured: list[dict[str, Any]] = []

async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]:
async for msg in prompt:
captured.append(msg)
yield msg

return capturing_wrapper(), None, captured

return prompt, str(prompt), None


async def _stream_messages_with_tracing(
generator: AsyncIterable[Any],
*,
request_tracker: RequestTracker,
finish_request_tracker: Any,
) -> AsyncGenerator[Any, None]:
try:
async for message in generator:
request_tracker.add_message(message)
yield message
except asyncio.CancelledError:
# The CancelledError may come from the subprocess transport
# (e.g., anyio internal cleanup when subagents complete) rather
# than a genuine external cancellation. We suppress it here so
# the response stream ends cleanly. If the caller genuinely
# cancelled the task, they still have pending cancellation
# requests that will fire at their next await point.
finish_request_tracker(log_output=True)
else:
finish_request_tracker(log_output=True)
finally:
finish_request_tracker()


def _create_query_wrapper_function(original_query: Any) -> Any:
    """Create a tracing wrapper for the exported one-shot ``query()`` helper.

    Args:
        original_query: The callable being wrapped. It is called with the
            caller's arguments and must return an async iterable of messages.

    Returns:
        An async generator function that forwards every message produced by
        ``original_query`` while recording the request through a
        ``RequestTracker``.
    """

    async def wrapped_query(*args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
        # Taken before any work so the tracker sees when the call started.
        query_start_time = time.time()
        prompt = args[0] if args else kwargs.get("prompt")
        prompt, traced_prompt, captured_messages = _prepare_prompt_for_tracing(prompt)

        # Hand the (possibly wrapped) prompt back in the same position the
        # caller used. If the caller passed no prompt at all, leave the
        # arguments untouched so ``original_query`` reports the missing
        # argument itself instead of receiving an injected ``prompt=None``.
        if args:
            args = (prompt,) + args[1:]
        elif "prompt" in kwargs:
            kwargs["prompt"] = prompt

        request_tracker = RequestTracker(
            prompt=traced_prompt,
            query_start_time=query_start_time,
            captured_messages=captured_messages,
        )

        async for message in _stream_messages_with_tracing(
            original_query(*args, **kwargs),
            request_tracker=request_tracker,
            finish_request_tracker=request_tracker.finish,
        ):
            yield message

    return wrapped_query


def _create_client_wrapper_class(original_client_class: Any) -> Any:
"""Creates a wrapper class for ClaudeSDKClient that wraps query and receive_response."""

Expand Down Expand Up @@ -1018,32 +1092,15 @@ async def query(self, *args: Any, **kwargs: Any) -> Any:
"""Wrap query to capture the prompt and start time for tracing."""
# Capture the time when query is called (when LLM call starts)
self.__query_start_time = time.time()
self.__captured_messages = None

# Capture the prompt for use in receive_response
prompt = args[0] if args else kwargs.get("prompt")
prompt, self.__last_prompt, self.__captured_messages = _prepare_prompt_for_tracing(prompt)

if prompt is not None:
if isinstance(prompt, str):
self.__last_prompt = prompt
elif isinstance(prompt, AsyncIterable):
# AsyncIterable[dict] - wrap it to capture messages as they're yielded
captured: list[dict[str, Any]] = []
self.__captured_messages = captured
self.__last_prompt = None # Will be set after messages are captured

async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]:
async for msg in prompt:
captured.append(msg)
yield msg

# Replace the prompt with our capturing wrapper
if args:
args = (capturing_wrapper(),) + args[1:]
else:
kwargs["prompt"] = capturing_wrapper()
else:
self.__last_prompt = str(prompt)
if args:
args = (prompt,) + args[1:]
else:
kwargs["prompt"] = prompt

self.__instrument_hook_callbacks()
self.__start_request_tracker()
Expand All @@ -1061,23 +1118,12 @@ async def receive_response(self) -> AsyncGenerator[Any, None]:
generator = self.__client.receive_response()
request_tracker = self.__request_tracker or self.__start_request_tracker()

try:
async for message in generator:
request_tracker.add_message(message)
yield message
except asyncio.CancelledError:
# The CancelledError may come from the subprocess transport
# (e.g., anyio internal cleanup when subagents complete) rather
# than a genuine external cancellation. We suppress it here so
# the response stream ends cleanly. If the caller genuinely
# cancelled the task, they still have pending cancellation
# requests that will fire at their next await point.
self.__finish_request_tracker(log_output=True)
else:
self.__finish_request_tracker(log_output=True)
finally:
if self.__request_tracker is not None:
self.__finish_request_tracker()
async for message in _stream_messages_with_tracing(
generator,
request_tracker=request_tracker,
finish_request_tracker=self.__finish_request_tracker,
):
yield message

async def __aenter__(self) -> "WrappedClaudeSDKClient":
await self.__client.__aenter__()
Expand Down
Loading