From 051a076133460110ceb3ac17b9c98564d9cb4aa2 Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Tue, 7 Apr 2026 10:35:58 -0400 Subject: [PATCH 1/3] feat(google_genai): trace interactions api methods Instrument the Google GenAI Interactions API for sync and async create, get, cancel, delete, and streaming calls. The new tracing normalizes interaction inputs and outputs, records usage and timing, tracks previous_interaction_id, and emits nested tool spans for interaction tool calls. Replace the earlier fake-based interaction coverage with VCR-backed integration tests and recorded cassettes so the span shape is validated against real SDK behavior. Closes #198 --- .../test_interactions_async_round_trip.yaml | 155 ++++ .../test_interactions_async_stream.yaml | 104 +++ .../test_interactions_create_and_get.yaml | 105 +++ .../test_interactions_create_stream.yaml | 109 +++ .../cassettes/test_interactions_delete.yaml | 104 +++ ..._interactions_tool_call_and_follow_up.yaml | 109 +++ .../integrations/google_genai/integration.py | 16 + .../integrations/google_genai/patchers.py | 90 +++ .../google_genai/test_google_genai.py | 218 +++++- .../integrations/google_genai/tracing.py | 675 +++++++++++++++++- 10 files changed, 1658 insertions(+), 27 deletions(-) create mode 100644 py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_round_trip.yaml create mode 100644 py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_stream.yaml create mode 100644 py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_and_get.yaml create mode 100644 py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_stream.yaml create mode 100644 py/src/braintrust/integrations/google_genai/cassettes/test_interactions_delete.yaml create mode 100644 py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_call_and_follow_up.yaml diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_round_trip.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_round_trip.yaml new file mode 100644 index 00000000..46f16f94 --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_round_trip.yaml @@ -0,0 +1,155 @@ +interactions: +- request: + body: '{"input":"What is the capital of Italy?","model":"gemini-2.5-flash"}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '68' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChdsaExWYWZYWE9mMkIxTWtQNk1EemlBbxIXbGhMVmFmWFhPZjJCMU1rUDZNRHppQW8","status":"completed","outputs":[{"signature":"CpMBAb4+9vv4hIyHoqn7P0G1RcfNsV+D0TVGVpbCfzmIK/crkqR6Qa0S42ec4l0Oq4Z0RbmycMFOP/vrRY+Sc4LXg05Rq9VWmajdjv6nOTQcjUcEiJtClTBse8396TbBWKKXRVDbmV1agysShvPyxEIv0Ics6mNi1npUq+w2TOFJ/KLgd19fIsh8eBVEV3pWdjvJd0ja","type":"thought"},{"text":"The + capital of Italy is **Rome**.","type":"text"}],"usage":{"total_tokens":41,"total_input_tokens":8,"input_tokens_by_modality":[{"modality":"text","tokens":8}],"total_cached_tokens":0,"total_output_tokens":8,"total_tool_use_tokens":0,"total_thought_tokens":25},"role":"model","created":"2026-04-07T14:20:08Z","updated":"2026-04-07T14:20:08Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:08 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1348 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '747' + status: + code: 200 + message: OK +- request: + body: '' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: GET + uri: https://generativelanguage.googleapis.com/v1beta/interactions/v1_ChdsaExWYWZYWE9mMkIxTWtQNk1EemlBbxIXbGhMVmFmWFhPZjJCMU1rUDZNRHppQW8?include_input=true + response: + body: + string: '{"id":"v1_ChdsaExWYWZYWE9mMkIxTWtQNk1EemlBbxIXbGhMVmFmWFhPZjJCMU1rUDZNRHppQW8","status":"completed","outputs":[{"signature":"CpMBAb4+9vv4hIyHoqn7P0G1RcfNsV+D0TVGVpbCfzmIK/crkqR6Qa0S42ec4l0Oq4Z0RbmycMFOP/vrRY+Sc4LXg05Rq9VWmajdjv6nOTQcjUcEiJtClTBse8396TbBWKKXRVDbmV1agysShvPyxEIv0Ics6mNi1npUq+w2TOFJ/KLgd19fIsh8eBVEV3pWdjvJd0ja","type":"thought"},{"text":"The + capital of Italy is **Rome**.","type":"text"}],"usage":{"total_tokens":41,"total_input_tokens":8,"input_tokens_by_modality":[{"modality":"text","tokens":8}],"total_cached_tokens":0,"total_output_tokens":8,"total_tool_use_tokens":0,"total_thought_tokens":25},"role":"model","created":"2026-04-07T14:20:08Z","updated":"2026-04-07T14:20:08Z","object":"interaction","model":"gemini-2.5-flash","input":[{"role":"user","content":[{"text":"What + is the capital of Italy?","type":"text"}]}]}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:08 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=130 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '840' + status: + code: 200 + message: OK +- request: + body: '' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: DELETE + uri: https://generativelanguage.googleapis.com/v1beta/interactions/v1_ChdsaExWYWZYWE9mMkIxTWtQNk1EemlBbxIXbGhMVmFmWFhPZjJCMU1rUDZNRHppQW8 + response: + body: + string: '{}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:09 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=382 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '2' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_stream.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_stream.yaml new file mode 100644 index 00000000..7bda7e65 --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_stream.yaml @@ -0,0 +1,104 @@ +interactions: +- request: + body: '{"input":"Say hi shortly.","model":"gemini-2.5-flash","stream":true}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '68' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: 'event: interaction.start + + data: {"interaction":{"id":"v1_ChdtUkxWYWVqRUVZQ3YxTWtQb2IydS1RRRIXbVJMVmFlakVFWUN2MU1rUG9iMnUtUUU","status":"in_progress","object":"interaction","model":"gemini-2.5-flash"},"event_type":"interaction.start"} + + + event: interaction.status_update + + data: {"interaction_id":"v1_ChdtUkxWYWVqRUVZQ3YxTWtQb2IydS1RRRIXbVJMVmFlakVFWUN2MU1rUG9iMnUtUUU","status":"in_progress","event_type":"interaction.status_update"} + + + event: content.start + + data: {"index":0,"content":{"type":"thought"},"event_type":"content.start"} + + + event: content.delta + + data: {"index":0,"delta":{"signature":"ClUBvj72+9F73CQIW9AyrjKse6nIY6vxogBksIaH5zxIIsOMHbbkpLYbB3nSdvBf3485j4bntqHPZjSbypKwI84aO8RagejmAGNRu/Ss+Mebq15i5UIuCnUBvj72+0Ewr8MFjWarSxByDjIn8WzosxL4jSr61M/Cdf2bvYpcks9icyyuHi+zvrfpTOop2ODlpoa7Y/UexoemmbjdKpZeXy5wHb4R3eaiRAeucVPtpoMNFQi5TocaZXfK9hP5BP4p44e9zrJBVUYigBehaec=","type":"thought_signature"},"event_type":"content.delta"} + + + event: content.stop + + data: {"index":0,"event_type":"content.stop"} + + + event: content.start + + data: {"index":1,"content":{"type":"text"},"event_type":"content.start"} + + + event: content.delta + + data: {"index":1,"delta":{"text":"Hi!","type":"text"},"event_type":"content.delta"} + + + event: content.stop + + data: {"index":1,"event_type":"content.stop"} + + + event: interaction.complete + + data: {"interaction":{"id":"v1_ChdtUkxWYWVqRUVZQ3YxTWtQb2IydS1RRRIXbVJMVmFlakVFWUN2MU1rUG9iMnUtUUU","status":"completed","usage":{"total_tokens":42,"total_input_tokens":5,"input_tokens_by_modality":[{"modality":"text","tokens":5}],"total_cached_tokens":0,"total_output_tokens":2,"total_tool_use_tokens":0,"total_thought_tokens":35},"role":"model","created":"2026-04-07T14:20:10Z","updated":"2026-04-07T14:20:10Z","object":"interaction","model":"gemini-2.5-flash"},"event_type":"interaction.complete"} + + + event: done + + data: [DONE] + + + ' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - text/event-stream + Date: + - Tue, 07 Apr 2026 14:20:10 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=795 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '1816' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_and_get.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_and_get.yaml new file mode 100644 index 00000000..fb100cbb --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_and_get.yaml @@ -0,0 +1,105 @@ +interactions: +- request: + body: '{"input":"What is the capital of France?","model":"gemini-2.5-flash"}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '69' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChdqaExWYWVIQUQ4YTIxTWtQZ2ZxUDRRWRIXamhMVmFlSEFEOGEyMU1rUGdmcVA0UVk","status":"completed","outputs":[{"signature":"Cp8BAb4+9vtLsikCpBEO0omEzOmvrAc1XfoRIv4pg6pRC203MxRPHMkBCpCP1EfSaBQ8Ypk3CQsck/Yi+7N/+yj9vbqfCh2+idJjFcWg7Sxo64B10gs8O67D06FhP/gDpnv/hLOImOAKwnw3JYsQtVnAINE2zYTMVeirk116zvayi5D5iOzQ+2/SuQCtuUhIXWt2Jyy3LHcEUDfeJgO7uHXw","type":"thought"},{"text":"The + capital of France is **Paris**.","type":"text"}],"usage":{"total_tokens":41,"total_input_tokens":8,"input_tokens_by_modality":[{"modality":"text","tokens":8}],"total_cached_tokens":0,"total_output_tokens":8,"total_tool_use_tokens":0,"total_thought_tokens":25},"role":"model","created":"2026-04-07T14:19:59Z","updated":"2026-04-07T14:19:59Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:19:59 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1225 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '765' + status: + code: 200 + message: OK +- request: + body: '' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: GET + uri: https://generativelanguage.googleapis.com/v1beta/interactions/v1_ChdqaExWYWVIQUQ4YTIxTWtQZ2ZxUDRRWRIXamhMVmFlSEFEOGEyMU1rUGdmcVA0UVk?include_input=true + response: + body: + string: '{"id":"v1_ChdqaExWYWVIQUQ4YTIxTWtQZ2ZxUDRRWRIXamhMVmFlSEFEOGEyMU1rUGdmcVA0UVk","status":"completed","outputs":[{"signature":"Cp8BAb4+9vtLsikCpBEO0omEzOmvrAc1XfoRIv4pg6pRC203MxRPHMkBCpCP1EfSaBQ8Ypk3CQsck/Yi+7N/+yj9vbqfCh2+idJjFcWg7Sxo64B10gs8O67D06FhP/gDpnv/hLOImOAKwnw3JYsQtVnAINE2zYTMVeirk116zvayi5D5iOzQ+2/SuQCtuUhIXWt2Jyy3LHcEUDfeJgO7uHXw","type":"thought"},{"text":"The + capital of France is **Paris**.","type":"text"}],"usage":{"total_tokens":41,"total_input_tokens":8,"input_tokens_by_modality":[{"modality":"text","tokens":8}],"total_cached_tokens":0,"total_output_tokens":8,"total_tool_use_tokens":0,"total_thought_tokens":25},"role":"model","created":"2026-04-07T14:19:59Z","updated":"2026-04-07T14:19:59Z","object":"interaction","model":"gemini-2.5-flash","input":[{"role":"user","content":[{"text":"What + is the capital of France?","type":"text"}]}]}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:19:59 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=102 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '859' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_stream.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_stream.yaml new file mode 100644 index 00000000..549a88fa --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_stream.yaml @@ -0,0 +1,109 @@ +interactions: +- request: + body: '{"input":"Say hi in five words or less.","model":"gemini-2.5-flash","stream":true}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '82' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: 'event: interaction.start + + data: {"interaction":{"id":"v1_ChdrQkxWYWJfc0NOV1o5TW9QbTZtRm9RSRIXa0JMVmFiX3NDTldaOU1vUG02bUZvUUk","status":"in_progress","object":"interaction","model":"gemini-2.5-flash"},"event_type":"interaction.start"} + + + event: interaction.status_update + + data: {"interaction_id":"v1_ChdrQkxWYWJfc0NOV1o5TW9QbTZtRm9RSRIXa0JMVmFiX3NDTldaOU1vUG02bUZvUUk","status":"in_progress","event_type":"interaction.status_update"} + + + event: content.start + + data: {"index":0,"content":{"type":"thought"},"event_type":"content.start"} + + + event: content.delta + + data: {"index":0,"delta":{"signature":"ClsBvj72+zoBtQP6EjTMrMyD5hM3sBNS/6cPoB/L7fFYkvzXlH/JhNd91N/Lf2BehMZvRijqF0Nk04nFHWjRcey23m0ly4uN3PaKv1jPos1jruzVSo2kjBna5/dvCrUBAb4+9vuoYgWiAXCse1w4Zc6moR0NuIZEblr4QuJSCF8OmKREmC6SLGT22wyV/Hpdhdy8AnoQGxF7d/gnR+RDBgBSH2krIW7HchLuqxZKmahXFuTjuFD8QDbCvOUJ9JGdGvM5tpW/akOdcPzMnZqm0C0/0NCPOG6VhbaloyGrQbI370a546lOHiY7ZdocwKb/pDEZRJIWqHOpvXviWhVIj/jM6gk+n9kRejz+ZfBGtPL/jeRhJQ==","type":"thought_signature"},"event_type":"content.delta"} + + + event: content.stop + + data: {"index":0,"event_type":"content.stop"} + + + event: content.start + + data: {"index":1,"content":{"type":"text"},"event_type":"content.start"} + + + event: content.delta + + data: {"index":1,"delta":{"text":"Hi there","type":"text"},"event_type":"content.delta"} + + + event: content.delta + + data: {"index":1,"delta":{"text":"!","type":"text"},"event_type":"content.delta"} + + + event: content.stop + + data: {"index":1,"event_type":"content.stop"} + + + event: interaction.complete + + data: {"interaction":{"id":"v1_ChdrQkxWYWJfc0NOV1o5TW9QbTZtRm9RSRIXa0JMVmFiX3NDTldaOU1vUG02bUZvUUk","status":"completed","usage":{"total_tokens":74,"total_input_tokens":9,"input_tokens_by_modality":[{"modality":"text","tokens":9}],"total_cached_tokens":0,"total_output_tokens":3,"total_tool_use_tokens":0,"total_thought_tokens":62},"role":"model","created":"2026-04-07T14:20:01Z","updated":"2026-04-07T14:20:01Z","object":"interaction","model":"gemini-2.5-flash"},"event_type":"interaction.complete"} + + + event: done + + data: [DONE] + + + ' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - text/event-stream + Date: + - Tue, 07 Apr 2026 14:20:00 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=887 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '2021' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_delete.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_delete.yaml new file mode 100644 index 00000000..19575b5a --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_delete.yaml @@ -0,0 +1,104 @@ +interactions: +- request: + body: '{"input":"Reply with exactly ok.","model":"gemini-2.5-flash"}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '61' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChdsQkxWYWJtU09ZMmMxTWtQcFpQYXdRdxIXbEJMVmFibVNPWTJjMU1rUHBaUGF3UXc","status":"completed","outputs":[{"signature":"CmABvj72+7WzytAAFU/DR5hYJRFlxm4yJKjxhaIxo1oQQjH2hur9PHCtCrldHToJoyOVozXlgHA9xq7trYBtqGiCZAO/MfzKIz9M0OGv6W9JVTUlAV/wNZyqz9j3OJV1wOU=","type":"thought"},{"text":"ok","type":"text"}],"usage":{"total_tokens":23,"total_input_tokens":6,"input_tokens_by_modality":[{"modality":"text","tokens":6}],"total_cached_tokens":0,"total_output_tokens":1,"total_tool_use_tokens":0,"total_thought_tokens":16},"role":"model","created":"2026-04-07T14:20:06Z","updated":"2026-04-07T14:20:06Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:06 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1269 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '648' + status: + code: 200 + message: OK +- request: + body: '' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: DELETE + uri: https://generativelanguage.googleapis.com/v1beta/interactions/v1_ChdsQkxWYWJtU09ZMmMxTWtQcFpQYXdRdxIXbEJMVmFibVNPWTJjMU1rUHBaUGF3UXc + response: + body: + string: '{}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:06 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=390 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '2' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_call_and_follow_up.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_call_and_follow_up.yaml new file mode 100644 index 00000000..4b36853d --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_call_and_follow_up.yaml @@ -0,0 +1,109 @@ +interactions: +- request: + body: '{"input":"What is the weather like in Paris? Use the tool.","model":"gemini-2.5-flash","tools":[{"type":"function","description":"Get + the current weather for a location.","name":"get_weather","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}]}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '293' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChdrUkxWYWZIbEtxeWsxTWtQNTl1Ry1BbxIXa1JMVmFmSGxLcXlrMU1rUDU5dUctQW8","status":"requires_action","outputs":[{"signature":"Cv4BAb4+9vv0gIjK+dI8G/Gk9ZNL+77B3vRd+v3fJn9pPaT8958QL7Fh3Jzss/lUrm0BiaMlHjTqifxqb5h9Mzy8I5SJ1X7sTN8wWBGhpnFUTsZGVw3Cg8e0mdUJ6MnOxDCPNEmPsG8xe+N7forjuaU74YxGhtt8Ase3PWlvsQ9rXWQSQI3+EwiriF9zT+rF0iUXYsmlNC2jeVydCRJzetADEP1whtaY6Qm+8SHGwELS/T1ZyhOJtzE56mtdVQwvt1pr1fSINrVApTUAwmvJNJsDnndwNX06AXqKyx6l61qAlgSqoeo/1Zb7a2R6oJhmDeuvifsJyNbFJFOmyd9aClw=","type":"thought"},{"name":"get_weather","arguments":{"location":"Paris"},"type":"function_call","id":"y124y6bd"}],"usage":{"total_tokens":119,"total_input_tokens":53,"input_tokens_by_modality":[{"modality":"text","tokens":53}],"total_cached_tokens":0,"total_output_tokens":15,"total_tool_use_tokens":0,"total_thought_tokens":51},"role":"model","created":"2026-04-07T14:20:03Z","updated":"2026-04-07T14:20:03Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:03 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1385 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '937' + status: + code: 200 + message: OK +- request: + body: '{"input":{"call_id":"y124y6bd","result":{"forecast":"sunny"},"type":"function_result","name":"get_weather"},"model":"gemini-2.5-flash","previous_interaction_id":"v1_ChdrUkxWYWZIbEtxeWsxTWtQNTl1Ry1BbxIXa1JMVmFmSGxLcXlrMU1rUDU5dUctQW8","tools":[{"type":"function","description":"Get + the current weather for a location.","name":"get_weather","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}]}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '440' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChdrUkxWYWZIbEtxeWsxTWtQNTl1Ry1BbxIXa3hMVmFhTzJFOW1XMU1rUHhLMmlzUXM","status":"completed","outputs":[{"signature":"CiRlMjQ4MzBhNy01Y2Q2LTQyZmUtOTk4Yi1lZTUzOWU3MmI5YzM=","type":"thought"},{"text":"The + weather in Paris is sunny.","type":"text"}],"usage":{"total_tokens":61,"total_input_tokens":54,"input_tokens_by_modality":[{"modality":"text","tokens":54}],"total_cached_tokens":0,"total_output_tokens":7,"total_tool_use_tokens":0,"total_thought_tokens":0},"role":"model","created":"2026-04-07T14:20:04Z","updated":"2026-04-07T14:20:04Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:04 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1483 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '597' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/integration.py b/py/src/braintrust/integrations/google_genai/integration.py index 48faf088..b48099f5 100644 --- a/py/src/braintrust/integrations/google_genai/integration.py +++ b/py/src/braintrust/integrations/google_genai/integration.py @@ -5,10 +5,18 @@ from braintrust.integrations.base import BaseIntegration from .patchers import ( + AsyncInteractionsCancelPatcher, + AsyncInteractionsCreatePatcher, + AsyncInteractionsDeletePatcher, + AsyncInteractionsGetPatcher, AsyncModelsEmbedContentPatcher, AsyncModelsGenerateContentPatcher, AsyncModelsGenerateContentStreamPatcher, AsyncModelsGenerateImagesPatcher, + InteractionsCancelPatcher, + InteractionsCreatePatcher, + InteractionsDeletePatcher, + InteractionsGetPatcher, ModelsEmbedContentPatcher, ModelsGenerateContentPatcher, ModelsGenerateContentStreamPatcher, @@ -29,8 +37,16 @@ class GoogleGenAIIntegration(BaseIntegration): ModelsGenerateContentStreamPatcher, ModelsEmbedContentPatcher, ModelsGenerateImagesPatcher, + InteractionsCreatePatcher, + InteractionsGetPatcher, + InteractionsCancelPatcher, + InteractionsDeletePatcher, AsyncModelsGenerateContentPatcher, AsyncModelsGenerateContentStreamPatcher, AsyncModelsEmbedContentPatcher, AsyncModelsGenerateImagesPatcher, + AsyncInteractionsCreatePatcher, + AsyncInteractionsGetPatcher, + AsyncInteractionsCancelPatcher, + AsyncInteractionsDeletePatcher, ) diff --git a/py/src/braintrust/integrations/google_genai/patchers.py b/py/src/braintrust/integrations/google_genai/patchers.py index 7bc895f1..dc4433a0 100644 --- a/py/src/braintrust/integrations/google_genai/patchers.py +++ b/py/src/braintrust/integrations/google_genai/patchers.py @@ -7,10 +7,18 @@ _async_generate_content_stream_wrapper, _async_generate_content_wrapper, _async_generate_images_wrapper, + _async_interactions_cancel_wrapper, + _async_interactions_create_wrapper, + _async_interactions_delete_wrapper, + _async_interactions_get_wrapper, _embed_content_wrapper, _generate_content_stream_wrapper, _generate_content_wrapper, _generate_images_wrapper, + _interactions_cancel_wrapper, + _interactions_create_wrapper, + _interactions_delete_wrapper, + _interactions_get_wrapper, ) @@ -55,6 +63,47 @@ class ModelsGenerateImagesPatcher(FunctionWrapperPatcher): wrapper = _generate_images_wrapper +# --------------------------------------------------------------------------- +# Sync Interactions patchers +# --------------------------------------------------------------------------- + + +class InteractionsCreatePatcher(FunctionWrapperPatcher): + """Patch ``InteractionsResource.create`` for tracing.""" + + name = "google_genai.interactions.create" + target_module = "google.genai._interactions.resources.interactions" + target_path = "InteractionsResource.create" + wrapper = _interactions_create_wrapper + + +class InteractionsGetPatcher(FunctionWrapperPatcher): + """Patch ``InteractionsResource.get`` for tracing.""" + + name = "google_genai.interactions.get" + target_module = "google.genai._interactions.resources.interactions" + target_path = "InteractionsResource.get" + wrapper = _interactions_get_wrapper + + +class InteractionsCancelPatcher(FunctionWrapperPatcher): + """Patch ``InteractionsResource.cancel`` for tracing.""" + + name = "google_genai.interactions.cancel" + target_module = "google.genai._interactions.resources.interactions" + target_path = "InteractionsResource.cancel" + wrapper = _interactions_cancel_wrapper + + +class InteractionsDeletePatcher(FunctionWrapperPatcher): + """Patch ``InteractionsResource.delete`` for tracing.""" + + name = "google_genai.interactions.delete" + target_module = "google.genai._interactions.resources.interactions" + target_path = "InteractionsResource.delete" + wrapper = _interactions_delete_wrapper + + # --------------------------------------------------------------------------- # Async Models patchers # --------------------------------------------------------------------------- @@ -94,3 +143,44 @@ class AsyncModelsGenerateImagesPatcher(FunctionWrapperPatcher): target_module = "google.genai.models" target_path = "AsyncModels.generate_images" wrapper = _async_generate_images_wrapper + + +# --------------------------------------------------------------------------- +# Async Interactions patchers +# --------------------------------------------------------------------------- + + +class AsyncInteractionsCreatePatcher(FunctionWrapperPatcher): + """Patch ``AsyncInteractionsResource.create`` for tracing.""" + + name = "google_genai.async_interactions.create" + target_module = "google.genai._interactions.resources.interactions" + target_path = "AsyncInteractionsResource.create" + wrapper = _async_interactions_create_wrapper + + +class AsyncInteractionsGetPatcher(FunctionWrapperPatcher): + """Patch ``AsyncInteractionsResource.get`` for tracing.""" + + name = "google_genai.async_interactions.get" + target_module = "google.genai._interactions.resources.interactions" + target_path = "AsyncInteractionsResource.get" + wrapper = _async_interactions_get_wrapper + + +class AsyncInteractionsCancelPatcher(FunctionWrapperPatcher): + """Patch ``AsyncInteractionsResource.cancel`` for tracing.""" + + name = "google_genai.async_interactions.cancel" + target_module = "google.genai._interactions.resources.interactions" + target_path = "AsyncInteractionsResource.cancel" + wrapper = _async_interactions_cancel_wrapper + + +class AsyncInteractionsDeletePatcher(FunctionWrapperPatcher): + """Patch ``AsyncInteractionsResource.delete`` for tracing.""" + + name = "google_genai.async_interactions.delete" + target_module = "google.genai._interactions.resources.interactions" + target_path = "AsyncInteractionsResource.delete" + wrapper = _async_interactions_delete_wrapper diff --git a/py/src/braintrust/integrations/google_genai/test_google_genai.py b/py/src/braintrust/integrations/google_genai/test_google_genai.py index fc48d671..1a8c8311 100644 --- a/py/src/braintrust/integrations/google_genai/test_google_genai.py +++ b/py/src/braintrust/integrations/google_genai/test_google_genai.py @@ -8,9 +8,10 @@ from braintrust import logger from braintrust.integrations.google_genai import setup_genai from braintrust.logger import Attachment +from braintrust.span_types import SpanTypeAttribute from braintrust.test_helpers import init_test_logger from braintrust.wrappers.test_utils import verify_autoinstrument_script -from google.genai import types +from google.genai import interactions, types from google.genai.client import Client @@ -19,6 +20,7 @@ EMBEDDING_MODEL = "gemini-embedding-001" IMAGE_MODEL = "imagen-4.0-fast-generate-001" REASONING_MODEL = "gemini-2.5-flash" +INTERACTIONS_MODEL = "gemini-2.5-flash" FIXTURES_DIR = Path(__file__).resolve().parent.parent.parent.parent.parent.parent / "internal/golden/fixtures" TINY_PNG_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" @@ -1165,6 +1167,220 @@ async def test_google_search_grounding_async(memory_logger, mode): _assert_grounding_metadata(span["output"]) +def _find_spans_by_type(spans, span_type): + return [span for span in spans if span["span_attributes"]["type"] == span_type] + + +def _find_span_by_name(spans, name): + return next(span for span in spans if span["span_attributes"]["name"] == name) + + +def _interaction_function_tool(): + return interactions.Function( + type="function", + name="get_weather", + description="Get the current weather for a location.", + parameters={ + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + ) + + +@pytest.mark.vcr +def test_interactions_create_and_get(memory_logger): + assert not memory_logger.pop() + + client = Client() + response = client.interactions.create( + model=INTERACTIONS_MODEL, + input="What is the capital of France?", + ) + fetched = client.interactions.get(response.id, include_input=True) + + assert response.status == "completed" + assert fetched.id == response.id + assert fetched.status == "completed" + + spans = memory_logger.pop() + create_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.LLM), "interactions.create") + get_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.TASK), "interactions.get") + + assert create_span["metadata"]["model"] == INTERACTIONS_MODEL + assert create_span["metadata"]["interaction_id"] == response.id + assert create_span["output"]["status"] == "completed" + assert "Paris" in create_span["output"]["text"] + assert create_span["metrics"]["prompt_tokens"] > 0 + assert create_span["metrics"]["completion_tokens"] > 0 + + assert get_span["input"]["id"] == response.id + assert get_span["metadata"]["interaction_id"] == response.id + assert get_span["output"]["status"] == "completed" + assert "Paris" in get_span["output"]["text"] + assert "France" in str(get_span["output"]["outputs"]) + + +@pytest.mark.vcr +def test_interactions_create_stream(memory_logger): + assert not memory_logger.pop() + + client = Client() + events = list( + client.interactions.create( + model=INTERACTIONS_MODEL, + input="Say hi in five words or less.", + stream=True, + ) + ) + + assert events + + spans = memory_logger.pop() + create_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.LLM), "interactions.create") + + assert create_span["metadata"]["model"] == INTERACTIONS_MODEL + assert create_span["output"]["status"] == "completed" + assert create_span["output"]["text"] + assert "hi" in create_span["output"]["text"].lower() + assert create_span["metrics"]["time_to_first_token"] >= 0 + assert "content.start" in create_span["metadata"]["stream_event_types"] + assert "interaction.complete" in create_span["metadata"]["stream_event_types"] + + +@pytest.mark.vcr +def test_interactions_tool_call_and_follow_up(memory_logger): + assert not memory_logger.pop() + + client = Client() + tool = _interaction_function_tool() + + first_response = client.interactions.create( + model=INTERACTIONS_MODEL, + input="What is the weather like in Paris? Use the tool.", + tools=[tool], + ) + tool_call = next(output for output in first_response.outputs if output.type == "function_call") + + second_response = client.interactions.create( + model=INTERACTIONS_MODEL, + previous_interaction_id=first_response.id, + input=interactions.FunctionResultContent( + type="function_result", + call_id=tool_call.id, + name=tool_call.name, + result={"forecast": "sunny"}, + ), + tools=[tool], + ) + + assert first_response.status == "requires_action" + assert second_response.status == "completed" + assert "sunny" in second_response.outputs[-1].text.lower() + + spans = memory_logger.pop() + llm_spans = _find_spans_by_type(spans, SpanTypeAttribute.LLM) + tool_spans = _find_spans_by_type(spans, SpanTypeAttribute.TOOL) + + first_span = next(span for span in llm_spans if span["metadata"]["interaction_id"] == first_response.id) + second_span = next(span for span in llm_spans if span["metadata"]["interaction_id"] == second_response.id) + tool_span = _find_span_by_name(tool_spans, "get_weather") + + assert first_span["output"]["status"] == "requires_action" + assert second_span["metadata"]["previous_interaction_id"] == first_response.id + assert second_span["output"]["status"] == "completed" + assert "sunny" in second_span["output"]["text"].lower() + + assert tool_span["input"] == {"location": "Paris"} + assert tool_span["span_parents"] == [first_span["span_id"]] + + +@pytest.mark.vcr +def test_interactions_delete(memory_logger): + assert not memory_logger.pop() + + client = Client() + response = client.interactions.create( + model=INTERACTIONS_MODEL, + input="Reply with exactly ok.", + ) + assert response.id + + create_spans = memory_logger.pop() + assert create_spans + + delete_response = client.interactions.delete(response.id) + assert delete_response == {} + + spans = memory_logger.pop() + delete_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.TASK), "interactions.delete") + + assert delete_span["input"]["id"] == response.id + assert delete_span["output"] == {} + assert delete_span["metrics"]["duration"] >= 0 + + +@pytest.mark.vcr +@pytest.mark.asyncio +async def test_interactions_async_round_trip(memory_logger): + assert not memory_logger.pop() + + client = Client() + response = await client.aio.interactions.create( + model=INTERACTIONS_MODEL, + input="What is the capital of Italy?", + ) + fetched = await client.aio.interactions.get(response.id, include_input=True) + deleted = await client.aio.interactions.delete(response.id) + + assert response.status == "completed" + assert fetched.id == response.id + assert deleted == {} + + spans = memory_logger.pop() + create_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.LLM), "interactions.create") + get_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.TASK), "interactions.get") + delete_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.TASK), "interactions.delete") + + assert create_span["metadata"]["model"] == INTERACTIONS_MODEL + assert create_span["output"]["status"] == "completed" + assert "Rome" in create_span["output"]["text"] + + assert get_span["input"]["id"] == response.id + assert "Italy" in str(get_span["output"]["outputs"]) + + assert delete_span["input"]["id"] == response.id + assert delete_span["output"] == {} + + +@pytest.mark.vcr +@pytest.mark.asyncio +async def test_interactions_async_stream(memory_logger): + assert not memory_logger.pop() + + client = Client() + stream = await client.aio.interactions.create( + model=INTERACTIONS_MODEL, + input="Say hi shortly.", + stream=True, + ) + + events = [] + async for event in stream: + events.append(event) + + assert events + + spans = memory_logger.pop() + create_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.LLM), "interactions.create") + + assert create_span["output"]["status"] == "completed" + assert create_span["output"]["text"] + assert "hi" in create_span["output"]["text"].lower() + assert create_span["metrics"]["time_to_first_token"] >= 0 + assert "content.delta" in create_span["metadata"]["stream_event_types"] + + class TestAutoInstrumentGoogleGenAI: """Tests for auto_instrument() with Google GenAI.""" diff --git a/py/src/braintrust/integrations/google_genai/tracing.py b/py/src/braintrust/integrations/google_genai/tracing.py index 0ee6fa56..b899b897 100644 --- a/py/src/braintrust/integrations/google_genai/tracing.py +++ b/py/src/braintrust/integrations/google_genai/tracing.py @@ -1,8 +1,12 @@ """Google GenAI-specific span creation, metadata extraction, stream handling, and output normalization.""" +import base64 +import binascii import logging import time from collections.abc import Awaitable, Callable, Iterable +from datetime import date, datetime +from enum import Enum from typing import TYPE_CHECKING, Any from braintrust.bt_json import bt_safe_deep_copy @@ -11,6 +15,7 @@ if TYPE_CHECKING: + from google.genai._interactions.types.interaction import Interaction from google.genai.types import ( EmbedContentResponse, GenerateContentResponse, @@ -19,6 +24,24 @@ logger = logging.getLogger(__name__) +_MEDIA_CONTENT_TYPES = {"image", "audio", "video", "document"} +_TOOL_CALL_TYPES = { + "function_call", + "code_execution_call", + "url_context_call", + "google_search_call", + "mcp_server_tool_call", + "file_search_call", +} +_TOOL_RESULT_TYPES = { + "function_result", + "code_execution_result", + "url_context_result", + "google_search_result", + "mcp_server_tool_result", + "file_search_result", +} + # --------------------------------------------------------------------------- # Serialization helpers @@ -119,6 +142,54 @@ def _serialize_tools(api_client: Any, input: Any | None) -> Any | None: return None +def _attachment_from_base64_data(data: str, mime_type: str, *, label: str) -> Attachment | None: + raw_data = data + if raw_data.startswith("data:"): + _, _, encoded = raw_data.partition(",") + raw_data = encoded + + try: + decoded = base64.b64decode(raw_data, validate=True) + except (ValueError, binascii.Error): + return None + + extension = mime_type.split("/")[1] if "/" in mime_type else "bin" + return Attachment(data=decoded, filename=f"{label}.{extension}", content_type=mime_type) + + +def _serialize_interaction_content_dict(value: dict[str, Any]) -> dict[str, Any]: + serialized = {key: _serialize_interaction_value(val) for key, val in value.items() if val is not None} + + content_type = serialized.get("type") + data = serialized.get("data") + mime_type = serialized.get("mime_type") + if content_type in _MEDIA_CONTENT_TYPES and isinstance(data, str) and isinstance(mime_type, str): + attachment = _attachment_from_base64_data(data, mime_type, label=content_type) + if attachment is not None: + serialized["data"] = attachment + + return serialized + + +def _serialize_interaction_value(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool, Attachment)): + return value + if isinstance(value, Enum): + return value.value + if isinstance(value, (date, datetime)): + return value.isoformat() + if isinstance(value, (list, tuple)): + return [_serialize_interaction_value(item) for item in value] + if isinstance(value, dict): + return _serialize_interaction_content_dict(value) + if hasattr(value, "model_dump"): + try: + return _serialize_interaction_value(value.model_dump(exclude_none=True)) + except TypeError: + return _serialize_interaction_value(value.model_dump()) + return value + + # --------------------------------------------------------------------------- # Argument extraction helpers # --------------------------------------------------------------------------- @@ -155,6 +226,69 @@ def _prepare_generate_images_traced_call( return _clean(input), clean_kwargs +def _prepare_interaction_create_traced_call( + api_client: Any, args: list[Any], kwargs: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: + del api_client, args + + input_data = _clean( + { + "model": kwargs.get("model"), + "agent": kwargs.get("agent"), + "input": _serialize_interaction_value(kwargs.get("input")), + "background": kwargs.get("background"), + "generation_config": _serialize_interaction_value(kwargs.get("generation_config")), + "previous_interaction_id": kwargs.get("previous_interaction_id"), + "response_format": _serialize_interaction_value(kwargs.get("response_format")), + "response_mime_type": kwargs.get("response_mime_type"), + "response_modalities": _serialize_interaction_value(kwargs.get("response_modalities")), + "store": kwargs.get("store"), + "stream": kwargs.get("stream"), + "system_instruction": kwargs.get("system_instruction"), + "tools": _serialize_interaction_value(kwargs.get("tools")), + "agent_config": _serialize_interaction_value(kwargs.get("agent_config")), + } + ) + metadata = _clean( + { + "api_version": kwargs.get("api_version"), + "model": kwargs.get("model"), + "agent": kwargs.get("agent"), + "previous_interaction_id": kwargs.get("previous_interaction_id"), + } + ) + return input_data, metadata + + +def _prepare_interaction_get_traced_call( + api_client: Any, args: list[Any], kwargs: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: + del api_client + + interaction_id = args[0] if args else kwargs.get("id") + input_data = _clean( + { + "id": interaction_id, + "include_input": kwargs.get("include_input"), + "last_event_id": kwargs.get("last_event_id"), + "stream": kwargs.get("stream"), + } + ) + metadata = _clean({"api_version": kwargs.get("api_version")}) + return input_data, metadata + + +def _prepare_interaction_id_traced_call( + api_client: Any, args: list[Any], kwargs: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: + del api_client + + interaction_id = args[0] if args else kwargs.get("id") + input_data = _clean({"id": interaction_id}) + metadata = _clean({"api_version": kwargs.get("api_version")}) + return input_data, metadata + + # --------------------------------------------------------------------------- # Metric extraction helpers # --------------------------------------------------------------------------- @@ -286,17 +420,94 @@ def _extract_generate_images_output(response: Any) -> dict[str, Any]: ) -def _extract_generate_images_metrics(start: float) -> dict[str, Any]: +def _extract_generic_timing_metrics(start: float) -> dict[str, Any]: end_time = time.time() return _clean( - dict( - start=start, - end=end_time, - duration=end_time - start, + { + "start": start, + "end": end_time, + "duration": end_time - start, + } + ) + + +def _extract_interaction_usage_metrics(usage: Any, metrics: dict[str, Any]) -> None: + if usage is None: + return + + if hasattr(usage, "total_input_tokens") and usage.total_input_tokens is not None: + metrics["prompt_tokens"] = usage.total_input_tokens + if hasattr(usage, "total_output_tokens") and usage.total_output_tokens is not None: + metrics["completion_tokens"] = usage.total_output_tokens + if hasattr(usage, "total_tokens") and usage.total_tokens is not None: + metrics["tokens"] = usage.total_tokens + if hasattr(usage, "total_cached_tokens") and usage.total_cached_tokens is not None: + metrics["prompt_cached_tokens"] = usage.total_cached_tokens + if hasattr(usage, "total_thought_tokens") and usage.total_thought_tokens is not None: + metrics["completion_reasoning_tokens"] = usage.total_thought_tokens + if hasattr(usage, "total_tool_use_tokens") and usage.total_tool_use_tokens is not None: + metrics["tool_use_tokens"] = usage.total_tool_use_tokens + + +def _extract_interaction_text(outputs: list[dict[str, Any]]) -> str | None: + text_parts = [] + for item in outputs: + if item.get("type") == "text" and isinstance(item.get("text"), str): + text_parts.append(item["text"]) + return "".join(text_parts) or None + + +def _serialize_interaction_outputs(response: "Interaction") -> list[dict[str, Any]]: + outputs = _serialize_interaction_value(getattr(response, "outputs", None)) + return outputs if isinstance(outputs, list) else ([] if outputs is None else [outputs]) + + +def _extract_interaction_output( + response: "Interaction", serialized_outputs: list[dict[str, Any]] | None = None +) -> dict[str, Any]: + outputs_list = serialized_outputs if serialized_outputs is not None else _serialize_interaction_outputs(response) + + return _clean( + { + "status": getattr(response, "status", None), + "outputs": outputs_list, + "text": _extract_interaction_text(outputs_list), + } + ) + + +def _extract_interaction_metadata(response: "Interaction") -> dict[str, Any]: + usage = getattr(response, "usage", None) + usage_serialized = _serialize_interaction_value(usage) + usage_by_modality = None + if isinstance(usage_serialized, dict): + usage_by_modality = _clean( + { + "input_tokens_by_modality": usage_serialized.get("input_tokens_by_modality"), + "output_tokens_by_modality": usage_serialized.get("output_tokens_by_modality"), + "cached_tokens_by_modality": usage_serialized.get("cached_tokens_by_modality"), + "tool_use_tokens_by_modality": usage_serialized.get("tool_use_tokens_by_modality"), + } ) + + return _clean( + { + "interaction_id": getattr(response, "id", None), + "previous_interaction_id": getattr(response, "previous_interaction_id", None), + "role": getattr(response, "role", None), + "response_mime_type": getattr(response, "response_mime_type", None), + "response_modalities": _serialize_interaction_value(getattr(response, "response_modalities", None)), + "usage_by_modality": usage_by_modality, + } ) +def _extract_interaction_metrics(response: "Interaction", start: float) -> dict[str, Any]: + metrics = _extract_generic_timing_metrics(start) + _extract_interaction_usage_metrics(getattr(response, "usage", None), metrics) + return metrics + + # --------------------------------------------------------------------------- # Result processing helpers # --------------------------------------------------------------------------- @@ -311,7 +522,66 @@ def _embed_process_result(result: "EmbedContentResponse", start: float) -> tuple def _generate_images_process_result(result: Any, start: float) -> tuple[Any, dict[str, Any]]: - return _extract_generate_images_output(result), _extract_generate_images_metrics(start) + return _extract_generate_images_output(result), _extract_generic_timing_metrics(start) + + +def _tool_span_name(call_item: dict[str, Any] | None, result_item: dict[str, Any] | None) -> str: + item = call_item or result_item or {} + if item.get("server_name") and item.get("name"): + return f"{item['server_name']}.{item['name']}" + if item.get("name"): + return str(item["name"]) + return str(item.get("type") or "interaction_tool") + + +def _tool_span_input(call_item: dict[str, Any] | None) -> Any: + if not call_item: + return None + if call_item.get("arguments") is not None: + return call_item["arguments"] + return ( + _clean( + { + key: value + for key, value in call_item.items() + if key not in {"id", "name", "type", "signature", "server_name"} + } + ) + or None + ) + + +def _tool_span_output(result_item: dict[str, Any] | None) -> Any: + if not result_item: + return None + if result_item.get("result") is not None: + return result_item["result"] + return ( + _clean( + { + key: value + for key, value in result_item.items() + if key not in {"call_id", "name", "type", "signature", "server_name", "is_error"} + } + ) + or None + ) + + +def _interaction_process_result( + result: "Interaction", start: float +) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + outputs_list = _serialize_interaction_outputs(result) + _log_interaction_tool_spans_from_outputs(outputs_list) + return ( + _extract_interaction_output(result, outputs_list), + _extract_interaction_metrics(result, start), + _extract_interaction_metadata(result), + ) + + +def _generic_process_result(result: Any, start: float) -> tuple[Any, dict[str, Any]]: + return _serialize_interaction_value(result), _extract_generic_timing_metrics(start) # --------------------------------------------------------------------------- @@ -417,11 +687,188 @@ def _aggregate_generate_content_chunks( return aggregated, clean_metrics +def _is_interaction_content_event(event: Any) -> bool: + return getattr(event, "event_type", None) in {"content.start", "content.delta"} + + +def _merge_interaction_content_delta(item: dict[str, Any], delta: dict[str, Any]) -> dict[str, Any]: + delta_type = delta.get("type") + if item.get("type") is None: + if delta_type == "thought_signature": + item["type"] = "thought" + elif delta_type == "thought_summary": + item["type"] = "thought" + elif delta_type is not None: + item["type"] = delta_type + + for key, value in delta.items(): + if key == "type" or value is None: + continue + if ( + key in item + and isinstance(item[key], str) + and isinstance(value, str) + and key in {"text", "data", "signature"} + ): + item[key] += value + else: + item[key] = value + + return item + + +def _reconstruct_interaction_outputs_from_events(events: list[Any]) -> list[dict[str, Any]]: + outputs_by_index: dict[int, dict[str, Any]] = {} + + for event in events: + event_type = getattr(event, "event_type", None) + index = getattr(event, "index", None) + if not isinstance(index, int): + continue + + if event_type == "content.start": + outputs_by_index[index] = _serialize_interaction_value(getattr(event, "content", None)) or {} + elif event_type == "content.delta": + item = outputs_by_index.setdefault(index, {}) + delta = _serialize_interaction_value(getattr(event, "delta", None)) or {} + if isinstance(delta, dict): + outputs_by_index[index] = _merge_interaction_content_delta(item, delta) + + return [outputs_by_index[index] for index in sorted(outputs_by_index)] + + +def _log_interaction_tool_spans_from_outputs(outputs: list[dict[str, Any]]) -> None: + calls_by_id: dict[str, dict[str, Any]] = {} + pending_results: dict[str, list[dict[str, Any]]] = {} + pairs: list[tuple[dict[str, Any] | None, dict[str, Any] | None]] = [] + emitted_call_ids: set[str] = set() + + for item in outputs: + item_type = item.get("type") + if item_type in _TOOL_CALL_TYPES: + call_id = item.get("id") + if isinstance(call_id, str): + calls_by_id[call_id] = item + for pending in pending_results.pop(call_id, []): + pairs.append((item, pending)) + emitted_call_ids.add(call_id) + else: + pairs.append((item, None)) + elif item_type in _TOOL_RESULT_TYPES: + call_id = item.get("call_id") + if isinstance(call_id, str) and call_id in calls_by_id: + pairs.append((calls_by_id[call_id], item)) + emitted_call_ids.add(call_id) + elif isinstance(call_id, str): + pending_results.setdefault(call_id, []).append(item) + else: + pairs.append((None, item)) + + for call_id, result_items in pending_results.items(): + call_item = calls_by_id.get(call_id) + for result_item in result_items: + pairs.append((call_item, result_item)) + if call_item is not None: + emitted_call_ids.add(call_id) + + for call_id, call_item in calls_by_id.items(): + if call_id not in emitted_call_ids: + pairs.append((call_item, None)) + + for call_item, result_item in pairs: + metadata = _clean( + { + "tool_type": (call_item or result_item or {}).get("type"), + "call_id": (call_item or {}).get("id") or (result_item or {}).get("call_id"), + "server_name": (call_item or result_item or {}).get("server_name"), + "signature": (call_item or result_item or {}).get("signature"), + } + ) + + with start_span( + name=_tool_span_name(call_item, result_item), + type=SpanTypeAttribute.TOOL, + input=_tool_span_input(call_item), + metadata=metadata or None, + ) as tool_span: + if not result_item: + continue + if result_item.get("is_error"): + tool_span.log(error=_tool_span_output(result_item)) + else: + tool_span.log(output=_tool_span_output(result_item)) + + +def _aggregate_interaction_events( + events: list[Any], start: float, first_token_time: float | None = None +) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + metrics = _extract_generic_timing_metrics(start) + if first_token_time is not None: + metrics["time_to_first_token"] = first_token_time - start + + metadata = _clean({"stream_event_types": [et for event in events if (et := getattr(event, "event_type", None))]}) + reconstructed_outputs = _reconstruct_interaction_outputs_from_events(events) + + final_interaction = next( + ( + event.interaction + for event in reversed(events) + if hasattr(event, "interaction") and getattr(event, "interaction", None) is not None + ), + None, + ) + if final_interaction is None: + if reconstructed_outputs: + _log_interaction_tool_spans_from_outputs(reconstructed_outputs) + return ( + {"outputs": reconstructed_outputs, "text": _extract_interaction_text(reconstructed_outputs)}, + _clean(metrics), + metadata, + ) + error_event = next( + ( + event + for event in reversed(events) + if getattr(event, "event_type", None) == "error" and getattr(event, "error", None) is not None + ), + None, + ) + if error_event is not None: + metadata["stream_error"] = _serialize_interaction_value(error_event.error) + return {"events": _serialize_interaction_value(events)}, _clean(metrics), metadata + + final_outputs_list = _serialize_interaction_outputs(final_interaction) + if final_outputs_list: + _log_interaction_tool_spans_from_outputs(final_outputs_list) + elif reconstructed_outputs: + _log_interaction_tool_spans_from_outputs(reconstructed_outputs) + + _extract_interaction_usage_metrics(getattr(final_interaction, "usage", None), metrics) + metadata.update(_extract_interaction_metadata(final_interaction)) + + output = _extract_interaction_output(final_interaction, final_outputs_list) + if reconstructed_outputs and not output.get("outputs"): + output["outputs"] = reconstructed_outputs + output["text"] = _extract_interaction_text(reconstructed_outputs) + + return output, _clean(metrics), _clean(metadata) + + # --------------------------------------------------------------------------- # Traced call orchestration # --------------------------------------------------------------------------- +def _normalize_logged_result(result: Any) -> tuple[Any, dict[str, Any], dict[str, Any] | None]: + if isinstance(result, tuple) and len(result) == 3: + output, metrics, metadata = result + return output, metrics, metadata + if isinstance(result, tuple) and len(result) == 2: + output, metrics = result + return output, metrics, None + raise ValueError("Expected process_result/aggregate to return a 2-tuple or 3-tuple") + + def _run_traced_call( api_client: Any, args: list[Any], @@ -429,18 +876,19 @@ def _run_traced_call( *, name: str, invoke: Callable[[], Any], - process_result: Callable[[Any, float], tuple[Any, dict[str, Any]]], + process_result: Callable[[Any, float], tuple[Any, dict[str, Any]] | tuple[Any, dict[str, Any], dict[str, Any]]], prepare_call: Callable[ [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] ] = _prepare_traced_call, + span_type: SpanTypeAttribute = SpanTypeAttribute.LLM, ) -> Any: input, clean_kwargs = prepare_call(api_client, args, kwargs) start = time.time() - with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + with start_span(name=name, type=span_type, input=input, metadata=clean_kwargs or None) as span: result = invoke() - output, metrics = process_result(result, start) - span.log(output=output, metrics=metrics) + output, metrics, metadata = _normalize_logged_result(process_result(result, start)) + span.log(output=output, metrics=metrics, metadata=metadata) return result @@ -451,18 +899,19 @@ async def _run_async_traced_call( *, name: str, invoke: Callable[[], Awaitable[Any]], - process_result: Callable[[Any, float], tuple[Any, dict[str, Any]]], + process_result: Callable[[Any, float], tuple[Any, dict[str, Any]] | tuple[Any, dict[str, Any], dict[str, Any]]], prepare_call: Callable[ [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] ] = _prepare_traced_call, + span_type: SpanTypeAttribute = SpanTypeAttribute.LLM, ) -> Any: input, clean_kwargs = prepare_call(api_client, args, kwargs) start = time.time() - with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + with start_span(name=name, type=span_type, input=input, metadata=clean_kwargs or None) as span: result = await invoke() - output, metrics = process_result(result, start) - span.log(output=output, metrics=metrics) + output, metrics, metadata = _normalize_logged_result(process_result(result, start)) + span.log(output=output, metrics=metrics, metadata=metadata) return result @@ -473,22 +922,31 @@ def _run_stream_traced_call( *, name: str, invoke: Callable[[], Any], - aggregate: Callable[[list[Any], float, float | None], tuple[Any, dict[str, Any]]], + aggregate: Callable[ + [list[Any], float, float | None], tuple[Any, dict[str, Any]] | tuple[Any, dict[str, Any], dict[str, Any]] + ], + span_type: SpanTypeAttribute = SpanTypeAttribute.LLM, + first_token_predicate: Callable[[Any], bool] | None = None, + prepare_call: Callable[ + [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] + ] = _prepare_traced_call, ) -> Any: - input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + input, clean_kwargs = prepare_call(api_client, args, kwargs) start = time.time() first_token_time = None - with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + with start_span(name=name, type=span_type, input=input, metadata=clean_kwargs or None) as span: chunks = [] for chunk in invoke(): - if first_token_time is None: + if first_token_time is None and ( + first_token_predicate(chunk) if first_token_predicate is not None else True + ): first_token_time = time.time() chunks.append(chunk) yield chunk - output, metrics = aggregate(chunks, start, first_token_time) - span.log(output=output, metrics=metrics) + output, metrics, metadata = _normalize_logged_result(aggregate(chunks, start, first_token_time)) + span.log(output=output, metrics=metrics, metadata=metadata) return output @@ -499,23 +957,32 @@ def _run_async_stream_traced_call( *, name: str, invoke: Callable[[], Awaitable[Any]], - aggregate: Callable[[list[Any], float, float | None], tuple[Any, dict[str, Any]]], + aggregate: Callable[ + [list[Any], float, float | None], tuple[Any, dict[str, Any]] | tuple[Any, dict[str, Any], dict[str, Any]] + ], + span_type: SpanTypeAttribute = SpanTypeAttribute.LLM, + first_token_predicate: Callable[[Any], bool] | None = None, + prepare_call: Callable[ + [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] + ] = _prepare_traced_call, ) -> Any: - input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + input, clean_kwargs = prepare_call(api_client, args, kwargs) async def stream_generator(): start = time.time() first_token_time = None - with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + with start_span(name=name, type=span_type, input=input, metadata=clean_kwargs or None) as span: chunks = [] async for chunk in await invoke(): - if first_token_time is None: + if first_token_time is None and ( + first_token_predicate(chunk) if first_token_predicate is not None else True + ): first_token_time = time.time() chunks.append(chunk) yield chunk - output, metrics = aggregate(chunks, start, first_token_time) - span.log(output=output, metrics=metrics) + output, metrics, metadata = _normalize_logged_result(aggregate(chunks, start, first_token_time)) + span.log(output=output, metrics=metrics, metadata=metadata) return stream_generator() @@ -613,3 +1080,159 @@ async def _async_generate_images_wrapper(wrapped: Any, instance: Any, args: Any, process_result=_generate_images_process_result, prepare_call=_prepare_generate_images_traced_call, ) + + +def _interactions_create_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + if kwargs.get("stream"): + return _run_stream_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.create", + invoke=lambda: wrapped(*args, **kwargs), + aggregate=_aggregate_interaction_events, + first_token_predicate=_is_interaction_content_event, + prepare_call=_prepare_interaction_create_traced_call, + span_type=SpanTypeAttribute.LLM, + ) + + return _run_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.create", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_interaction_process_result, + prepare_call=_prepare_interaction_create_traced_call, + span_type=SpanTypeAttribute.LLM, + ) + + +async def _async_interactions_create_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + if kwargs.get("stream"): + return _run_async_stream_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.create", + invoke=lambda: wrapped(*args, **kwargs), + aggregate=_aggregate_interaction_events, + first_token_predicate=_is_interaction_content_event, + prepare_call=_prepare_interaction_create_traced_call, + span_type=SpanTypeAttribute.LLM, + ) + + return await _run_async_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.create", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_interaction_process_result, + prepare_call=_prepare_interaction_create_traced_call, + span_type=SpanTypeAttribute.LLM, + ) + + +def _interactions_get_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + if kwargs.get("stream"): + return _run_stream_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.get", + invoke=lambda: wrapped(*args, **kwargs), + aggregate=_aggregate_interaction_events, + first_token_predicate=_is_interaction_content_event, + prepare_call=_prepare_interaction_get_traced_call, + span_type=SpanTypeAttribute.TASK, + ) + + return _run_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.get", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_interaction_process_result, + prepare_call=_prepare_interaction_get_traced_call, + span_type=SpanTypeAttribute.TASK, + ) + + +async def _async_interactions_get_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + if kwargs.get("stream"): + return _run_async_stream_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.get", + invoke=lambda: wrapped(*args, **kwargs), + aggregate=_aggregate_interaction_events, + first_token_predicate=_is_interaction_content_event, + prepare_call=_prepare_interaction_get_traced_call, + span_type=SpanTypeAttribute.TASK, + ) + + return await _run_async_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.get", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_interaction_process_result, + prepare_call=_prepare_interaction_get_traced_call, + span_type=SpanTypeAttribute.TASK, + ) + + +def _interactions_cancel_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return _run_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.cancel", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_interaction_process_result, + prepare_call=_prepare_interaction_id_traced_call, + span_type=SpanTypeAttribute.TASK, + ) + + +async def _async_interactions_cancel_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return await _run_async_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.cancel", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_interaction_process_result, + prepare_call=_prepare_interaction_id_traced_call, + span_type=SpanTypeAttribute.TASK, + ) + + +def _interactions_delete_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return _run_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.delete", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_generic_process_result, + prepare_call=_prepare_interaction_id_traced_call, + span_type=SpanTypeAttribute.TASK, + ) + + +async def _async_interactions_delete_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return await _run_async_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.delete", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_generic_process_result, + prepare_call=_prepare_interaction_id_traced_call, + span_type=SpanTypeAttribute.TASK, + ) From d5520e4d14fc2adff32d9225e298b4ad8544804f Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Wed, 8 Apr 2026 10:36:18 -0400 Subject: [PATCH 2/3] fix(google-genai): preserve interaction tool spans during local tool work --- ...n_stays_active_during_local_tool_work.yaml | 109 +++++++ .../google_genai/test_google_genai.py | 48 +++ .../integrations/google_genai/tracing.py | 277 ++++++++++++++++-- 3 files changed, 402 insertions(+), 32 deletions(-) create mode 100644 py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_span_stays_active_during_local_tool_work.yaml diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_span_stays_active_during_local_tool_work.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_span_stays_active_during_local_tool_work.yaml new file mode 100644 index 00000000..98d357b4 --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_span_stays_active_during_local_tool_work.yaml @@ -0,0 +1,109 @@ +interactions: +- request: + body: '{"input":"What is the weather like in Paris? Use the tool.","model":"gemini-2.5-flash","tools":[{"type":"function","description":"Get + the current weather for a location.","name":"get_weather","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}]}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '293' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChczbVhXYWJ1SEhxbW0xTWtQLVA2dDZRYxIXM21YV2FidUhIcW1tMU1rUC1QNnQ2UWM","status":"requires_action","outputs":[{"signature":"Cu4BAb4+9vtHU+JvHsX1tR3UG/YI7ouu7mJBwpb/nalP7MVRXsxe3Ro+Q2pjSV1SdVyHirQfKgeCrZfYRnegxPbVvDIdWi+MAZoHdvfiNBNL5LsLK0pTA6bztJRmE7f2pAhaISzsl2CXXbdqDPMz8K5xOWoV51a/C+9OfzaI2BtqqXuUdU2QipSJYXCyEo5RWzLuuSHjNSLPGP+o2pFJSyE4FepwBLsgh5YZd84KM76nNGU6MkpO09EU/m07bMSX4e+0GoWalYRVk8/tjibpCvxm11Rrmc4gIMoE1xO7VTGhe0qfhleKd9XVjs8bjJxL7Q==","type":"thought"},{"name":"get_weather","arguments":{"location":"Paris"},"type":"function_call","id":"d526dpq4"}],"usage":{"total_tokens":115,"total_input_tokens":53,"input_tokens_by_modality":[{"modality":"text","tokens":53}],"total_cached_tokens":0,"total_output_tokens":15,"total_tool_use_tokens":0,"total_thought_tokens":47},"role":"model","created":"2026-04-08T14:27:43Z","updated":"2026-04-08T14:27:43Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Wed, 08 Apr 2026 14:27:43 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1398 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '917' + status: + code: 200 + message: OK +- request: + body: '{"input":{"call_id":"d526dpq4","result":{"forecast":"sunny"},"type":"function_result","name":"get_weather"},"model":"gemini-2.5-flash","previous_interaction_id":"v1_ChczbVhXYWJ1SEhxbW0xTWtQLVA2dDZRYxIXM21YV2FidUhIcW1tMU1rUC1QNnQ2UWM","tools":[{"type":"function","description":"Get + the current weather for a location.","name":"get_weather","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}]}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '440' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChczbVhXYWJ1SEhxbW0xTWtQLVA2dDZRYxIWNEdYV2FmbEM1TERVeVFfLTY3aWdDUQ","status":"completed","outputs":[{"signature":"CiRlMjQ4MzBhNy01Y2Q2LTQyZmUtOTk4Yi1lZTUzOWU3MmI5YzM=","type":"thought"},{"text":"The + weather in Paris is sunny.","type":"text"}],"usage":{"total_tokens":61,"total_input_tokens":54,"input_tokens_by_modality":[{"modality":"text","tokens":54}],"total_cached_tokens":0,"total_output_tokens":7,"total_tool_use_tokens":0,"total_thought_tokens":0},"role":"model","created":"2026-04-08T14:27:45Z","updated":"2026-04-08T14:27:45Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Wed, 08 Apr 2026 14:27:45 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1141 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '596' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/test_google_genai.py b/py/src/braintrust/integrations/google_genai/test_google_genai.py index 1a8c8311..d2a80676 100644 --- a/py/src/braintrust/integrations/google_genai/test_google_genai.py +++ b/py/src/braintrust/integrations/google_genai/test_google_genai.py @@ -1295,6 +1295,54 @@ def test_interactions_tool_call_and_follow_up(memory_logger): assert tool_span["span_parents"] == [first_span["span_id"]] +@pytest.mark.vcr +def test_interactions_tool_span_stays_active_during_local_tool_work(memory_logger): + assert not memory_logger.pop() + + client = Client() + tool = _interaction_function_tool() + + first_response = client.interactions.create( + model=INTERACTIONS_MODEL, + input="What is the weather like in Paris? Use the tool.", + tools=[tool], + ) + tool_call = next(output for output in first_response.outputs if output.type == "function_call") + + with logger.start_span(name="nested_tool_work", type=SpanTypeAttribute.TASK) as nested_tool_work: + nested_tool_work.log(output={"forecast": "sunny"}) + + second_response = client.interactions.create( + model=INTERACTIONS_MODEL, + previous_interaction_id=first_response.id, + input=interactions.FunctionResultContent( + type="function_result", + call_id=tool_call.id, + name=tool_call.name, + result={"forecast": "sunny"}, + ), + tools=[tool], + ) + + assert first_response.status == "requires_action" + assert second_response.status == "completed" + + spans = memory_logger.pop() + llm_spans = _find_spans_by_type(spans, SpanTypeAttribute.LLM) + tool_spans = _find_spans_by_type(spans, SpanTypeAttribute.TOOL) + + first_span = next(span for span in llm_spans if span["metadata"]["interaction_id"] == first_response.id) + second_span = next(span for span in llm_spans if span["metadata"]["interaction_id"] == second_response.id) + tool_span = _find_span_by_name(tool_spans, "get_weather") + nested_span = _find_span_by_name(spans, "nested_tool_work") + + assert tool_span["span_parents"] == [first_span["span_id"]] + assert nested_span["span_parents"] == [tool_span["span_id"]] + assert tool_span["metrics"]["start"] <= nested_span["metrics"]["start"] + assert tool_span["metrics"]["end"] >= nested_span["metrics"]["end"] + assert second_span.get("span_parents") in (None, []) + + @pytest.mark.vcr def test_interactions_delete(memory_logger): assert not memory_logger.pop() diff --git a/py/src/braintrust/integrations/google_genai/tracing.py b/py/src/braintrust/integrations/google_genai/tracing.py index b899b897..f6d0fb57 100644 --- a/py/src/braintrust/integrations/google_genai/tracing.py +++ b/py/src/braintrust/integrations/google_genai/tracing.py @@ -2,6 +2,8 @@ import base64 import binascii +import contextvars +import dataclasses import logging import time from collections.abc import Awaitable, Callable, Iterable @@ -43,6 +45,17 @@ } +@dataclasses.dataclass +class _ActiveInteractionToolSpan: + span: Any + is_current: bool = False + + +_interaction_tool_spans: contextvars.ContextVar[dict[str, _ActiveInteractionToolSpan] | None] = contextvars.ContextVar( + "braintrust_google_genai_interaction_tool_spans", default=None +) + + # --------------------------------------------------------------------------- # Serialization helpers # --------------------------------------------------------------------------- @@ -572,7 +585,6 @@ def _interaction_process_result( result: "Interaction", start: float ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: outputs_list = _serialize_interaction_outputs(result) - _log_interaction_tool_spans_from_outputs(outputs_list) return ( _extract_interaction_output(result, outputs_list), _extract_interaction_metrics(result, start), @@ -737,7 +749,135 @@ def _reconstruct_interaction_outputs_from_events(events: list[Any]) -> list[dict return [outputs_by_index[index] for index in sorted(outputs_by_index)] -def _log_interaction_tool_spans_from_outputs(outputs: list[dict[str, Any]]) -> None: +def _get_active_interaction_tool_spans() -> dict[str, _ActiveInteractionToolSpan]: + active_tool_spans = _interaction_tool_spans.get() + if active_tool_spans is None: + active_tool_spans = {} + _interaction_tool_spans.set(active_tool_spans) + return active_tool_spans + + +def _tool_span_metadata(call_item: dict[str, Any] | None, result_item: dict[str, Any] | None) -> dict[str, Any] | None: + return ( + _clean( + { + "tool_type": (call_item or result_item or {}).get("type"), + "call_id": (call_item or {}).get("id") or (result_item or {}).get("call_id"), + "server_name": (call_item or result_item or {}).get("server_name"), + "signature": (call_item or result_item or {}).get("signature"), + } + ) + or None + ) + + +def _log_posthoc_interaction_tool_span(call_item: dict[str, Any] | None, result_item: dict[str, Any] | None) -> None: + with start_span( + name=_tool_span_name(call_item, result_item), + type=SpanTypeAttribute.TOOL, + input=_tool_span_input(call_item), + metadata=_tool_span_metadata(call_item, result_item), + ) as tool_span: + if not result_item: + return + if result_item.get("is_error"): + tool_span.log(error=_tool_span_output(result_item)) + else: + tool_span.log(output=_tool_span_output(result_item)) + + +def _cleanup_interaction_tool_span_state(active_tool_spans: dict[str, _ActiveInteractionToolSpan]) -> None: + if active_tool_spans: + return + _interaction_tool_spans.set(None) + + +def _close_active_interaction_tool_span( + call_id: str, result_item: dict[str, Any] | None = None, *, end_time: float | None = None +) -> bool: + active_tool_spans = _get_active_interaction_tool_spans() + active_tool_span = active_tool_spans.pop(call_id, None) + if active_tool_span is None: + return False + + if active_tool_span.is_current: + active_tool_span.span.unset_current() + + if result_item is not None: + if result_item.get("is_error"): + active_tool_span.span.log(error=_tool_span_output(result_item)) + else: + active_tool_span.span.log(output=_tool_span_output(result_item)) + + active_tool_span.span.end(end_time=end_time) + _cleanup_interaction_tool_span_state(active_tool_spans) + return True + + +def _activate_interaction_tool_span( + call_item: dict[str, Any], *, parent_export: str, start_time: float | None = None, set_current: bool = False +) -> None: + # Keep the tool span open across local tool execution so any nested spans + # started by user code naturally inherit from it until the corresponding + # function_result is submitted on a follow-up interactions.create call. + call_id = call_item.get("id") + if not isinstance(call_id, str): + _log_posthoc_interaction_tool_span(call_item, None) + return + + active_tool_spans = _get_active_interaction_tool_spans() + if call_id in active_tool_spans: + return + + tool_span = start_span( + name=_tool_span_name(call_item, None), + type=SpanTypeAttribute.TOOL, + input=_tool_span_input(call_item), + metadata=_tool_span_metadata(call_item, None), + parent=parent_export, + start_time=start_time, + set_current=True, + ) + active_tool_spans[call_id] = _ActiveInteractionToolSpan(span=tool_span, is_current=False) + + if set_current: + tool_span.set_current() + active_tool_spans[call_id].is_current = True + + +def _serialize_interaction_items(value: Any) -> list[dict[str, Any]]: + serialized = _serialize_interaction_value(value) + if serialized is None: + return [] + items = serialized if isinstance(serialized, list) else [serialized] + return [item for item in items if isinstance(item, dict)] + + +def _close_interaction_tool_spans_from_input(input_value: Any) -> None: + # Tool spans should end when the client hands the tool result back to the + # interactions API, before the follow-up LLM/TASK span begins. + end_time = time.time() + for item in _serialize_interaction_items(input_value): + if item.get("type") not in _TOOL_RESULT_TYPES: + continue + call_id = item.get("call_id") + if isinstance(call_id, str): + _close_active_interaction_tool_span(call_id, item, end_time=end_time) + + +def _finalize_interaction_tool_spans( + output: Any, metrics: dict[str, Any], metadata: dict[str, Any] | None, parent_export: str +) -> None: + del metadata + + if not isinstance(output, dict): + return + + outputs = output.get("outputs") + if not isinstance(outputs, list): + return + + active_tool_spans = _get_active_interaction_tool_spans() calls_by_id: dict[str, dict[str, Any]] = {} pending_results: dict[str, list[dict[str, Any]]] = {} pairs: list[tuple[dict[str, Any] | None, dict[str, Any] | None]] = [] @@ -756,7 +896,10 @@ def _log_interaction_tool_spans_from_outputs(outputs: list[dict[str, Any]]) -> N pairs.append((item, None)) elif item_type in _TOOL_RESULT_TYPES: call_id = item.get("call_id") - if isinstance(call_id, str) and call_id in calls_by_id: + if isinstance(call_id, str) and call_id in active_tool_spans: + _close_active_interaction_tool_span(call_id, item, end_time=time.time()) + emitted_call_ids.add(call_id) + elif isinstance(call_id, str) and call_id in calls_by_id: pairs.append((calls_by_id[call_id], item)) emitted_call_ids.add(call_id) elif isinstance(call_id, str): @@ -771,32 +914,36 @@ def _log_interaction_tool_spans_from_outputs(outputs: list[dict[str, Any]]) -> N if call_item is not None: emitted_call_ids.add(call_id) + unpaired_call_items: list[dict[str, Any]] = [] for call_id, call_item in calls_by_id.items(): if call_id not in emitted_call_ids: - pairs.append((call_item, None)) + unpaired_call_items.append(call_item) for call_item, result_item in pairs: - metadata = _clean( - { - "tool_type": (call_item or result_item or {}).get("type"), - "call_id": (call_item or {}).get("id") or (result_item or {}).get("call_id"), - "server_name": (call_item or result_item or {}).get("server_name"), - "signature": (call_item or result_item or {}).get("signature"), - } - ) + _log_posthoc_interaction_tool_span(call_item, result_item) + + activatable_call_items = [ + call_item + for call_item in unpaired_call_items + if isinstance(call_item.get("id"), str) and call_item.get("id") not in active_tool_spans + ] + claim_current = len(activatable_call_items) == 1 and not any( + active_tool_span.is_current for active_tool_span in active_tool_spans.values() + ) - with start_span( - name=_tool_span_name(call_item, result_item), - type=SpanTypeAttribute.TOOL, - input=_tool_span_input(call_item), - metadata=metadata or None, - ) as tool_span: - if not result_item: - continue - if result_item.get("is_error"): - tool_span.log(error=_tool_span_output(result_item)) - else: - tool_span.log(output=_tool_span_output(result_item)) + for call_item in unpaired_call_items: + call_id = call_item.get("id") + if not isinstance(call_id, str): + _log_posthoc_interaction_tool_span(call_item, None) + continue + if call_id in active_tool_spans: + continue + _activate_interaction_tool_span( + call_item, + parent_export=parent_export, + start_time=metrics.get("end"), + set_current=claim_current and call_item is activatable_call_items[0], + ) def _aggregate_interaction_events( @@ -819,7 +966,6 @@ def _aggregate_interaction_events( ) if final_interaction is None: if reconstructed_outputs: - _log_interaction_tool_spans_from_outputs(reconstructed_outputs) return ( {"outputs": reconstructed_outputs, "text": _extract_interaction_text(reconstructed_outputs)}, _clean(metrics), @@ -838,10 +984,6 @@ def _aggregate_interaction_events( return {"events": _serialize_interaction_value(events)}, _clean(metrics), metadata final_outputs_list = _serialize_interaction_outputs(final_interaction) - if final_outputs_list: - _log_interaction_tool_spans_from_outputs(final_outputs_list) - elif reconstructed_outputs: - _log_interaction_tool_spans_from_outputs(reconstructed_outputs) _extract_interaction_usage_metrics(getattr(final_interaction, "usage", None), metrics) metadata.update(_extract_interaction_metadata(final_interaction)) @@ -881,15 +1023,29 @@ def _run_traced_call( [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] ] = _prepare_traced_call, span_type: SpanTypeAttribute = SpanTypeAttribute.LLM, + before_invoke: Callable[[], None] | None = None, + finalize_logged_output: Callable[[Any, dict[str, Any], dict[str, Any] | None, str], None] | None = None, ) -> Any: input, clean_kwargs = prepare_call(api_client, args, kwargs) + if before_invoke is not None: + before_invoke() + start = time.time() + parent_export = None + output = None + metrics = None + metadata = None with start_span(name=name, type=span_type, input=input, metadata=clean_kwargs or None) as span: result = invoke() output, metrics, metadata = _normalize_logged_result(process_result(result, start)) span.log(output=output, metrics=metrics, metadata=metadata) - return result + parent_export = span.export() + + if finalize_logged_output is not None and parent_export is not None and metrics is not None: + finalize_logged_output(output, metrics, metadata, parent_export) + + return result async def _run_async_traced_call( @@ -904,15 +1060,29 @@ async def _run_async_traced_call( [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] ] = _prepare_traced_call, span_type: SpanTypeAttribute = SpanTypeAttribute.LLM, + before_invoke: Callable[[], None] | None = None, + finalize_logged_output: Callable[[Any, dict[str, Any], dict[str, Any] | None, str], None] | None = None, ) -> Any: input, clean_kwargs = prepare_call(api_client, args, kwargs) + if before_invoke is not None: + before_invoke() + start = time.time() + parent_export = None + output = None + metrics = None + metadata = None with start_span(name=name, type=span_type, input=input, metadata=clean_kwargs or None) as span: result = await invoke() output, metrics, metadata = _normalize_logged_result(process_result(result, start)) span.log(output=output, metrics=metrics, metadata=metadata) - return result + parent_export = span.export() + + if finalize_logged_output is not None and parent_export is not None and metrics is not None: + finalize_logged_output(output, metrics, metadata, parent_export) + + return result def _run_stream_traced_call( @@ -930,11 +1100,20 @@ def _run_stream_traced_call( prepare_call: Callable[ [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] ] = _prepare_traced_call, + before_invoke: Callable[[], None] | None = None, + finalize_logged_output: Callable[[Any, dict[str, Any], dict[str, Any] | None, str], None] | None = None, ) -> Any: input, clean_kwargs = prepare_call(api_client, args, kwargs) + if before_invoke is not None: + before_invoke() + start = time.time() first_token_time = None + output = None + metrics = None + metadata = None + parent_export = None with start_span(name=name, type=span_type, input=input, metadata=clean_kwargs or None) as span: chunks = [] for chunk in invoke(): @@ -947,7 +1126,12 @@ def _run_stream_traced_call( output, metrics, metadata = _normalize_logged_result(aggregate(chunks, start, first_token_time)) span.log(output=output, metrics=metrics, metadata=metadata) - return output + parent_export = span.export() + + if finalize_logged_output is not None and parent_export is not None and metrics is not None: + finalize_logged_output(output, metrics, metadata, parent_export) + + return output def _run_async_stream_traced_call( @@ -965,12 +1149,21 @@ def _run_async_stream_traced_call( prepare_call: Callable[ [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] ] = _prepare_traced_call, + before_invoke: Callable[[], None] | None = None, + finalize_logged_output: Callable[[Any, dict[str, Any], dict[str, Any] | None, str], None] | None = None, ) -> Any: input, clean_kwargs = prepare_call(api_client, args, kwargs) async def stream_generator(): + if before_invoke is not None: + before_invoke() + start = time.time() first_token_time = None + output = None + metrics = None + metadata = None + parent_export = None with start_span(name=name, type=span_type, input=input, metadata=clean_kwargs or None) as span: chunks = [] async for chunk in await invoke(): @@ -983,6 +1176,10 @@ async def stream_generator(): output, metrics, metadata = _normalize_logged_result(aggregate(chunks, start, first_token_time)) span.log(output=output, metrics=metrics, metadata=metadata) + parent_export = span.export() + + if finalize_logged_output is not None and parent_export is not None and metrics is not None: + finalize_logged_output(output, metrics, metadata, parent_export) return stream_generator() @@ -1083,6 +1280,8 @@ async def _async_generate_images_wrapper(wrapped: Any, instance: Any, args: Any, def _interactions_create_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + before_invoke = lambda: _close_interaction_tool_spans_from_input(kwargs.get("input")) + if kwargs.get("stream"): return _run_stream_traced_call( getattr(instance, "_client", None), @@ -1094,6 +1293,8 @@ def _interactions_create_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: first_token_predicate=_is_interaction_content_event, prepare_call=_prepare_interaction_create_traced_call, span_type=SpanTypeAttribute.LLM, + before_invoke=before_invoke, + finalize_logged_output=_finalize_interaction_tool_spans, ) return _run_traced_call( @@ -1105,10 +1306,14 @@ def _interactions_create_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: process_result=_interaction_process_result, prepare_call=_prepare_interaction_create_traced_call, span_type=SpanTypeAttribute.LLM, + before_invoke=before_invoke, + finalize_logged_output=_finalize_interaction_tool_spans, ) async def _async_interactions_create_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + before_invoke = lambda: _close_interaction_tool_spans_from_input(kwargs.get("input")) + if kwargs.get("stream"): return _run_async_stream_traced_call( getattr(instance, "_client", None), @@ -1120,6 +1325,8 @@ async def _async_interactions_create_wrapper(wrapped: Any, instance: Any, args: first_token_predicate=_is_interaction_content_event, prepare_call=_prepare_interaction_create_traced_call, span_type=SpanTypeAttribute.LLM, + before_invoke=before_invoke, + finalize_logged_output=_finalize_interaction_tool_spans, ) return await _run_async_traced_call( @@ -1131,6 +1338,8 @@ async def _async_interactions_create_wrapper(wrapped: Any, instance: Any, args: process_result=_interaction_process_result, prepare_call=_prepare_interaction_create_traced_call, span_type=SpanTypeAttribute.LLM, + before_invoke=before_invoke, + finalize_logged_output=_finalize_interaction_tool_spans, ) @@ -1146,6 +1355,7 @@ def _interactions_get_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: An first_token_predicate=_is_interaction_content_event, prepare_call=_prepare_interaction_get_traced_call, span_type=SpanTypeAttribute.TASK, + finalize_logged_output=_finalize_interaction_tool_spans, ) return _run_traced_call( @@ -1157,6 +1367,7 @@ def _interactions_get_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: An process_result=_interaction_process_result, prepare_call=_prepare_interaction_get_traced_call, span_type=SpanTypeAttribute.TASK, + finalize_logged_output=_finalize_interaction_tool_spans, ) @@ -1172,6 +1383,7 @@ async def _async_interactions_get_wrapper(wrapped: Any, instance: Any, args: Any first_token_predicate=_is_interaction_content_event, prepare_call=_prepare_interaction_get_traced_call, span_type=SpanTypeAttribute.TASK, + finalize_logged_output=_finalize_interaction_tool_spans, ) return await _run_async_traced_call( @@ -1183,6 +1395,7 @@ async def _async_interactions_get_wrapper(wrapped: Any, instance: Any, args: Any process_result=_interaction_process_result, prepare_call=_prepare_interaction_get_traced_call, span_type=SpanTypeAttribute.TASK, + finalize_logged_output=_finalize_interaction_tool_spans, ) From d024e0012cec4f1c4753f26b1474fcafe578dbcf Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Wed, 8 Apr 2026 12:28:40 -0400 Subject: [PATCH 3/3] fix(google_genai): skip interactions tests on older google-genai versions The interactions API was introduced in google-genai 1.55.0. The test file imported google.genai.interactions unconditionally at module level, which caused collection to fail on google-genai 1.30.0 (the minimum pinned version in noxfile.py). Make the import conditional and add a skipif marker so the seven interaction tests are gracefully skipped when running against older SDK versions. --- .../google_genai/test_google_genai.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/py/src/braintrust/integrations/google_genai/test_google_genai.py b/py/src/braintrust/integrations/google_genai/test_google_genai.py index d2a80676..8fd93d6d 100644 --- a/py/src/braintrust/integrations/google_genai/test_google_genai.py +++ b/py/src/braintrust/integrations/google_genai/test_google_genai.py @@ -11,7 +11,15 @@ from braintrust.span_types import SpanTypeAttribute from braintrust.test_helpers import init_test_logger from braintrust.wrappers.test_utils import verify_autoinstrument_script -from google.genai import interactions, types +from google.genai import types + + +try: + from google.genai import interactions +except ImportError: + interactions = None + +_needs_interactions = pytest.mark.skipif(interactions is None, reason="google-genai too old for interactions API") from google.genai.client import Client @@ -1188,6 +1196,7 @@ def _interaction_function_tool(): ) +@_needs_interactions @pytest.mark.vcr def test_interactions_create_and_get(memory_logger): assert not memory_logger.pop() @@ -1221,6 +1230,7 @@ def test_interactions_create_and_get(memory_logger): assert "France" in str(get_span["output"]["outputs"]) +@_needs_interactions @pytest.mark.vcr def test_interactions_create_stream(memory_logger): assert not memory_logger.pop() @@ -1248,6 +1258,7 @@ def test_interactions_create_stream(memory_logger): assert "interaction.complete" in create_span["metadata"]["stream_event_types"] +@_needs_interactions @pytest.mark.vcr def test_interactions_tool_call_and_follow_up(memory_logger): assert not memory_logger.pop() @@ -1295,6 +1306,7 @@ def test_interactions_tool_call_and_follow_up(memory_logger): assert tool_span["span_parents"] == [first_span["span_id"]] +@_needs_interactions @pytest.mark.vcr def test_interactions_tool_span_stays_active_during_local_tool_work(memory_logger): assert not memory_logger.pop() @@ -1343,6 +1355,7 @@ def test_interactions_tool_span_stays_active_during_local_tool_work(memory_logge assert second_span.get("span_parents") in (None, []) +@_needs_interactions @pytest.mark.vcr def test_interactions_delete(memory_logger): assert not memory_logger.pop() @@ -1368,6 +1381,7 @@ def test_interactions_delete(memory_logger): assert delete_span["metrics"]["duration"] >= 0 +@_needs_interactions @pytest.mark.vcr @pytest.mark.asyncio async def test_interactions_async_round_trip(memory_logger): @@ -1401,6 +1415,7 @@ async def test_interactions_async_round_trip(memory_logger): assert delete_span["output"] == {} +@_needs_interactions @pytest.mark.vcr @pytest.mark.asyncio async def test_interactions_async_stream(memory_logger):