diff --git a/py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py b/py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py index 8e1d53b8..f8a5cfbf 100644 --- a/py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py +++ b/py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py @@ -172,7 +172,14 @@ async def calculator_handler(args): assert task_span.get("metadata", {}).get("num_turns") is not None if hasattr(result_message, "session_id"): assert task_span.get("metadata", {}).get("session_id") is not None - + if hasattr(result_message, "stop_reason"): + assert task_span.get("metadata", {}).get("stop_reason") is not None + if hasattr(result_message, "total_cost_usd"): + assert task_span.get("metadata", {}).get("total_cost_usd") is not None + if hasattr(result_message, "duration_ms"): + assert task_span.get("metadata", {}).get("duration_ms") is not None + if hasattr(result_message, "duration_api_ms"): + assert task_span.get("metadata", {}).get("duration_api_ms") is not None llm_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.LLM] assert len(llm_spans) >= 1, f"Should have at least one LLM span, got {len(llm_spans)}" llm_span_ids = {span["span_id"] for span in llm_spans} @@ -1357,6 +1364,10 @@ def __init__( cache_creation_input_tokens: int = 0, num_turns: int = 1, session_id: str = "session-123", + stop_reason: str = "end_turn", + total_cost_usd: float = 1, + duration_ms: float = 1, + duration_api_ms: float = 1, ): self.usage = types.SimpleNamespace( input_tokens=input_tokens, @@ -1365,6 +1376,10 @@ def __init__( ) self.num_turns = num_turns self.session_id = session_id + self.stop_reason = stop_reason + self.total_cost_usd = total_cost_usd + self.duration_ms = duration_ms + self.duration_api_ms = duration_api_ms class FakeClaudeSDKClient: diff --git a/py/src/braintrust/integrations/claude_agent_sdk/tracing.py b/py/src/braintrust/integrations/claude_agent_sdk/tracing.py index 398d66f8..291869aa 100644 --- a/py/src/braintrust/integrations/claude_agent_sdk/tracing.py +++ b/py/src/braintrust/integrations/claude_agent_sdk/tracing.py @@ -730,6 +730,10 @@ def _handle_result(self, message: Any) -> None: for k, v in { "num_turns": getattr(message, "num_turns", None), "session_id": getattr(message, "session_id", None), + "stop_reason": getattr(message, "stop_reason", None), + "total_cost_usd": getattr(message, "total_cost_usd", None), + "duration_ms": getattr(message, "duration_ms", None), + "duration_api_ms": getattr(message, "duration_api_ms", None), }.items() if v is not None }