diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_round_trip.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_round_trip.yaml new file mode 100644 index 00000000..46f16f94 --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_round_trip.yaml @@ -0,0 +1,155 @@ +interactions: +- request: + body: '{"input":"What is the capital of Italy?","model":"gemini-2.5-flash"}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '68' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChdsaExWYWZYWE9mMkIxTWtQNk1EemlBbxIXbGhMVmFmWFhPZjJCMU1rUDZNRHppQW8","status":"completed","outputs":[{"signature":"CpMBAb4+9vv4hIyHoqn7P0G1RcfNsV+D0TVGVpbCfzmIK/crkqR6Qa0S42ec4l0Oq4Z0RbmycMFOP/vrRY+Sc4LXg05Rq9VWmajdjv6nOTQcjUcEiJtClTBse8396TbBWKKXRVDbmV1agysShvPyxEIv0Ics6mNi1npUq+w2TOFJ/KLgd19fIsh8eBVEV3pWdjvJd0ja","type":"thought"},{"text":"The + capital of Italy is **Rome**.","type":"text"}],"usage":{"total_tokens":41,"total_input_tokens":8,"input_tokens_by_modality":[{"modality":"text","tokens":8}],"total_cached_tokens":0,"total_output_tokens":8,"total_tool_use_tokens":0,"total_thought_tokens":25},"role":"model","created":"2026-04-07T14:20:08Z","updated":"2026-04-07T14:20:08Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:08 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1348 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '747' + status: + code: 200 + message: OK +- request: + body: '' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: GET + uri: https://generativelanguage.googleapis.com/v1beta/interactions/v1_ChdsaExWYWZYWE9mMkIxTWtQNk1EemlBbxIXbGhMVmFmWFhPZjJCMU1rUDZNRHppQW8?include_input=true + response: + body: + string: '{"id":"v1_ChdsaExWYWZYWE9mMkIxTWtQNk1EemlBbxIXbGhMVmFmWFhPZjJCMU1rUDZNRHppQW8","status":"completed","outputs":[{"signature":"CpMBAb4+9vv4hIyHoqn7P0G1RcfNsV+D0TVGVpbCfzmIK/crkqR6Qa0S42ec4l0Oq4Z0RbmycMFOP/vrRY+Sc4LXg05Rq9VWmajdjv6nOTQcjUcEiJtClTBse8396TbBWKKXRVDbmV1agysShvPyxEIv0Ics6mNi1npUq+w2TOFJ/KLgd19fIsh8eBVEV3pWdjvJd0ja","type":"thought"},{"text":"The + capital of Italy is **Rome**.","type":"text"}],"usage":{"total_tokens":41,"total_input_tokens":8,"input_tokens_by_modality":[{"modality":"text","tokens":8}],"total_cached_tokens":0,"total_output_tokens":8,"total_tool_use_tokens":0,"total_thought_tokens":25},"role":"model","created":"2026-04-07T14:20:08Z","updated":"2026-04-07T14:20:08Z","object":"interaction","model":"gemini-2.5-flash","input":[{"role":"user","content":[{"text":"What + is the capital of Italy?","type":"text"}]}]}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:08 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=130 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '840' + status: + code: 200 + message: OK +- request: + body: '' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: DELETE + uri: https://generativelanguage.googleapis.com/v1beta/interactions/v1_ChdsaExWYWZYWE9mMkIxTWtQNk1EemlBbxIXbGhMVmFmWFhPZjJCMU1rUDZNRHppQW8 + response: + body: + string: '{}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:09 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=382 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '2' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_stream.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_stream.yaml new file mode 100644 index 00000000..7bda7e65 --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_async_stream.yaml @@ -0,0 +1,104 @@ +interactions: +- request: + body: '{"input":"Say hi shortly.","model":"gemini-2.5-flash","stream":true}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '68' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: 'event: interaction.start + + data: {"interaction":{"id":"v1_ChdtUkxWYWVqRUVZQ3YxTWtQb2IydS1RRRIXbVJMVmFlakVFWUN2MU1rUG9iMnUtUUU","status":"in_progress","object":"interaction","model":"gemini-2.5-flash"},"event_type":"interaction.start"} + + + event: interaction.status_update + + data: {"interaction_id":"v1_ChdtUkxWYWVqRUVZQ3YxTWtQb2IydS1RRRIXbVJMVmFlakVFWUN2MU1rUG9iMnUtUUU","status":"in_progress","event_type":"interaction.status_update"} + + + event: content.start + + data: {"index":0,"content":{"type":"thought"},"event_type":"content.start"} + + + event: content.delta + + data: {"index":0,"delta":{"signature":"ClUBvj72+9F73CQIW9AyrjKse6nIY6vxogBksIaH5zxIIsOMHbbkpLYbB3nSdvBf3485j4bntqHPZjSbypKwI84aO8RagejmAGNRu/Ss+Mebq15i5UIuCnUBvj72+0Ewr8MFjWarSxByDjIn8WzosxL4jSr61M/Cdf2bvYpcks9icyyuHi+zvrfpTOop2ODlpoa7Y/UexoemmbjdKpZeXy5wHb4R3eaiRAeucVPtpoMNFQi5TocaZXfK9hP5BP4p44e9zrJBVUYigBehaec=","type":"thought_signature"},"event_type":"content.delta"} + + + event: content.stop + + data: {"index":0,"event_type":"content.stop"} + + + event: content.start + + data: {"index":1,"content":{"type":"text"},"event_type":"content.start"} + + + event: content.delta + + data: {"index":1,"delta":{"text":"Hi!","type":"text"},"event_type":"content.delta"} + + + event: content.stop + + data: {"index":1,"event_type":"content.stop"} + + + event: interaction.complete + + data: {"interaction":{"id":"v1_ChdtUkxWYWVqRUVZQ3YxTWtQb2IydS1RRRIXbVJMVmFlakVFWUN2MU1rUG9iMnUtUUU","status":"completed","usage":{"total_tokens":42,"total_input_tokens":5,"input_tokens_by_modality":[{"modality":"text","tokens":5}],"total_cached_tokens":0,"total_output_tokens":2,"total_tool_use_tokens":0,"total_thought_tokens":35},"role":"model","created":"2026-04-07T14:20:10Z","updated":"2026-04-07T14:20:10Z","object":"interaction","model":"gemini-2.5-flash"},"event_type":"interaction.complete"} + + + event: done + + data: [DONE] + + + ' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - text/event-stream + Date: + - Tue, 07 Apr 2026 14:20:10 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=795 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '1816' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_and_get.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_and_get.yaml new file mode 100644 index 00000000..fb100cbb --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_and_get.yaml @@ -0,0 +1,105 @@ +interactions: +- request: + body: '{"input":"What is the capital of France?","model":"gemini-2.5-flash"}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '69' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChdqaExWYWVIQUQ4YTIxTWtQZ2ZxUDRRWRIXamhMVmFlSEFEOGEyMU1rUGdmcVA0UVk","status":"completed","outputs":[{"signature":"Cp8BAb4+9vtLsikCpBEO0omEzOmvrAc1XfoRIv4pg6pRC203MxRPHMkBCpCP1EfSaBQ8Ypk3CQsck/Yi+7N/+yj9vbqfCh2+idJjFcWg7Sxo64B10gs8O67D06FhP/gDpnv/hLOImOAKwnw3JYsQtVnAINE2zYTMVeirk116zvayi5D5iOzQ+2/SuQCtuUhIXWt2Jyy3LHcEUDfeJgO7uHXw","type":"thought"},{"text":"The + capital of France is **Paris**.","type":"text"}],"usage":{"total_tokens":41,"total_input_tokens":8,"input_tokens_by_modality":[{"modality":"text","tokens":8}],"total_cached_tokens":0,"total_output_tokens":8,"total_tool_use_tokens":0,"total_thought_tokens":25},"role":"model","created":"2026-04-07T14:19:59Z","updated":"2026-04-07T14:19:59Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:19:59 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1225 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '765' + status: + code: 200 + message: OK +- request: + body: '' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: GET + uri: https://generativelanguage.googleapis.com/v1beta/interactions/v1_ChdqaExWYWVIQUQ4YTIxTWtQZ2ZxUDRRWRIXamhMVmFlSEFEOGEyMU1rUGdmcVA0UVk?include_input=true + response: + body: + string: '{"id":"v1_ChdqaExWYWVIQUQ4YTIxTWtQZ2ZxUDRRWRIXamhMVmFlSEFEOGEyMU1rUGdmcVA0UVk","status":"completed","outputs":[{"signature":"Cp8BAb4+9vtLsikCpBEO0omEzOmvrAc1XfoRIv4pg6pRC203MxRPHMkBCpCP1EfSaBQ8Ypk3CQsck/Yi+7N/+yj9vbqfCh2+idJjFcWg7Sxo64B10gs8O67D06FhP/gDpnv/hLOImOAKwnw3JYsQtVnAINE2zYTMVeirk116zvayi5D5iOzQ+2/SuQCtuUhIXWt2Jyy3LHcEUDfeJgO7uHXw","type":"thought"},{"text":"The + capital of France is **Paris**.","type":"text"}],"usage":{"total_tokens":41,"total_input_tokens":8,"input_tokens_by_modality":[{"modality":"text","tokens":8}],"total_cached_tokens":0,"total_output_tokens":8,"total_tool_use_tokens":0,"total_thought_tokens":25},"role":"model","created":"2026-04-07T14:19:59Z","updated":"2026-04-07T14:19:59Z","object":"interaction","model":"gemini-2.5-flash","input":[{"role":"user","content":[{"text":"What + is the capital of France?","type":"text"}]}]}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:19:59 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=102 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '859' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_stream.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_stream.yaml new file mode 100644 index 00000000..549a88fa --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_create_stream.yaml @@ -0,0 +1,109 @@ +interactions: +- request: + body: '{"input":"Say hi in five words or less.","model":"gemini-2.5-flash","stream":true}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '82' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: 'event: interaction.start + + data: {"interaction":{"id":"v1_ChdrQkxWYWJfc0NOV1o5TW9QbTZtRm9RSRIXa0JMVmFiX3NDTldaOU1vUG02bUZvUUk","status":"in_progress","object":"interaction","model":"gemini-2.5-flash"},"event_type":"interaction.start"} + + + event: interaction.status_update + + data: {"interaction_id":"v1_ChdrQkxWYWJfc0NOV1o5TW9QbTZtRm9RSRIXa0JMVmFiX3NDTldaOU1vUG02bUZvUUk","status":"in_progress","event_type":"interaction.status_update"} + + + event: content.start + + data: {"index":0,"content":{"type":"thought"},"event_type":"content.start"} + + + event: content.delta + + data: {"index":0,"delta":{"signature":"ClsBvj72+zoBtQP6EjTMrMyD5hM3sBNS/6cPoB/L7fFYkvzXlH/JhNd91N/Lf2BehMZvRijqF0Nk04nFHWjRcey23m0ly4uN3PaKv1jPos1jruzVSo2kjBna5/dvCrUBAb4+9vuoYgWiAXCse1w4Zc6moR0NuIZEblr4QuJSCF8OmKREmC6SLGT22wyV/Hpdhdy8AnoQGxF7d/gnR+RDBgBSH2krIW7HchLuqxZKmahXFuTjuFD8QDbCvOUJ9JGdGvM5tpW/akOdcPzMnZqm0C0/0NCPOG6VhbaloyGrQbI370a546lOHiY7ZdocwKb/pDEZRJIWqHOpvXviWhVIj/jM6gk+n9kRejz+ZfBGtPL/jeRhJQ==","type":"thought_signature"},"event_type":"content.delta"} + + + event: content.stop + + data: {"index":0,"event_type":"content.stop"} + + + event: content.start + + data: {"index":1,"content":{"type":"text"},"event_type":"content.start"} + + + event: content.delta + + data: {"index":1,"delta":{"text":"Hi there","type":"text"},"event_type":"content.delta"} + + + event: content.delta + + data: {"index":1,"delta":{"text":"!","type":"text"},"event_type":"content.delta"} + + + event: content.stop + + data: {"index":1,"event_type":"content.stop"} + + + event: interaction.complete + + data: {"interaction":{"id":"v1_ChdrQkxWYWJfc0NOV1o5TW9QbTZtRm9RSRIXa0JMVmFiX3NDTldaOU1vUG02bUZvUUk","status":"completed","usage":{"total_tokens":74,"total_input_tokens":9,"input_tokens_by_modality":[{"modality":"text","tokens":9}],"total_cached_tokens":0,"total_output_tokens":3,"total_tool_use_tokens":0,"total_thought_tokens":62},"role":"model","created":"2026-04-07T14:20:01Z","updated":"2026-04-07T14:20:01Z","object":"interaction","model":"gemini-2.5-flash"},"event_type":"interaction.complete"} + + + event: done + + data: [DONE] + + + ' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - text/event-stream + Date: + - Tue, 07 Apr 2026 14:20:00 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=887 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '2021' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_delete.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_delete.yaml new file mode 100644 index 00000000..19575b5a --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_delete.yaml @@ -0,0 +1,104 @@ +interactions: +- request: + body: '{"input":"Reply with exactly ok.","model":"gemini-2.5-flash"}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '61' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChdsQkxWYWJtU09ZMmMxTWtQcFpQYXdRdxIXbEJMVmFibVNPWTJjMU1rUHBaUGF3UXc","status":"completed","outputs":[{"signature":"CmABvj72+7WzytAAFU/DR5hYJRFlxm4yJKjxhaIxo1oQQjH2hur9PHCtCrldHToJoyOVozXlgHA9xq7trYBtqGiCZAO/MfzKIz9M0OGv6W9JVTUlAV/wNZyqz9j3OJV1wOU=","type":"thought"},{"text":"ok","type":"text"}],"usage":{"total_tokens":23,"total_input_tokens":6,"input_tokens_by_modality":[{"modality":"text","tokens":6}],"total_cached_tokens":0,"total_output_tokens":1,"total_tool_use_tokens":0,"total_thought_tokens":16},"role":"model","created":"2026-04-07T14:20:06Z","updated":"2026-04-07T14:20:06Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:06 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1269 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '648' + status: + code: 200 + message: OK +- request: + body: '' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: DELETE + uri: https://generativelanguage.googleapis.com/v1beta/interactions/v1_ChdsQkxWYWJtU09ZMmMxTWtQcFpQYXdRdxIXbEJMVmFibVNPWTJjMU1rUHBaUGF3UXc + response: + body: + string: '{}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:06 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=390 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '2' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_call_and_follow_up.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_call_and_follow_up.yaml new file mode 100644 index 00000000..4b36853d --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_call_and_follow_up.yaml @@ -0,0 +1,109 @@ +interactions: +- request: + body: '{"input":"What is the weather like in Paris? Use the tool.","model":"gemini-2.5-flash","tools":[{"type":"function","description":"Get + the current weather for a location.","name":"get_weather","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}]}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '293' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChdrUkxWYWZIbEtxeWsxTWtQNTl1Ry1BbxIXa1JMVmFmSGxLcXlrMU1rUDU5dUctQW8","status":"requires_action","outputs":[{"signature":"Cv4BAb4+9vv0gIjK+dI8G/Gk9ZNL+77B3vRd+v3fJn9pPaT8958QL7Fh3Jzss/lUrm0BiaMlHjTqifxqb5h9Mzy8I5SJ1X7sTN8wWBGhpnFUTsZGVw3Cg8e0mdUJ6MnOxDCPNEmPsG8xe+N7forjuaU74YxGhtt8Ase3PWlvsQ9rXWQSQI3+EwiriF9zT+rF0iUXYsmlNC2jeVydCRJzetADEP1whtaY6Qm+8SHGwELS/T1ZyhOJtzE56mtdVQwvt1pr1fSINrVApTUAwmvJNJsDnndwNX06AXqKyx6l61qAlgSqoeo/1Zb7a2R6oJhmDeuvifsJyNbFJFOmyd9aClw=","type":"thought"},{"name":"get_weather","arguments":{"location":"Paris"},"type":"function_call","id":"y124y6bd"}],"usage":{"total_tokens":119,"total_input_tokens":53,"input_tokens_by_modality":[{"modality":"text","tokens":53}],"total_cached_tokens":0,"total_output_tokens":15,"total_tool_use_tokens":0,"total_thought_tokens":51},"role":"model","created":"2026-04-07T14:20:03Z","updated":"2026-04-07T14:20:03Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:03 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1385 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '937' + status: + code: 200 + message: OK +- request: + body: '{"input":{"call_id":"y124y6bd","result":{"forecast":"sunny"},"type":"function_result","name":"get_weather"},"model":"gemini-2.5-flash","previous_interaction_id":"v1_ChdrUkxWYWZIbEtxeWsxTWtQNTl1Ry1BbxIXa1JMVmFmSGxLcXlrMU1rUDU5dUctQW8","tools":[{"type":"function","description":"Get + the current weather for a location.","name":"get_weather","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}]}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '440' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChdrUkxWYWZIbEtxeWsxTWtQNTl1Ry1BbxIXa3hMVmFhTzJFOW1XMU1rUHhLMmlzUXM","status":"completed","outputs":[{"signature":"CiRlMjQ4MzBhNy01Y2Q2LTQyZmUtOTk4Yi1lZTUzOWU3MmI5YzM=","type":"thought"},{"text":"The + weather in Paris is sunny.","type":"text"}],"usage":{"total_tokens":61,"total_input_tokens":54,"input_tokens_by_modality":[{"modality":"text","tokens":54}],"total_cached_tokens":0,"total_output_tokens":7,"total_tool_use_tokens":0,"total_thought_tokens":0},"role":"model","created":"2026-04-07T14:20:04Z","updated":"2026-04-07T14:20:04Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Tue, 07 Apr 2026 14:20:04 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1483 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '597' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_span_stays_active_during_local_tool_work.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_span_stays_active_during_local_tool_work.yaml new file mode 100644 index 00000000..98d357b4 --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_interactions_tool_span_stays_active_during_local_tool_work.yaml @@ -0,0 +1,109 @@ +interactions: +- request: + body: '{"input":"What is the weather like in Paris? Use the tool.","model":"gemini-2.5-flash","tools":[{"type":"function","description":"Get + the current weather for a location.","name":"get_weather","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}]}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '293' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChczbVhXYWJ1SEhxbW0xTWtQLVA2dDZRYxIXM21YV2FidUhIcW1tMU1rUC1QNnQ2UWM","status":"requires_action","outputs":[{"signature":"Cu4BAb4+9vtHU+JvHsX1tR3UG/YI7ouu7mJBwpb/nalP7MVRXsxe3Ro+Q2pjSV1SdVyHirQfKgeCrZfYRnegxPbVvDIdWi+MAZoHdvfiNBNL5LsLK0pTA6bztJRmE7f2pAhaISzsl2CXXbdqDPMz8K5xOWoV51a/C+9OfzaI2BtqqXuUdU2QipSJYXCyEo5RWzLuuSHjNSLPGP+o2pFJSyE4FepwBLsgh5YZd84KM76nNGU6MkpO09EU/m07bMSX4e+0GoWalYRVk8/tjibpCvxm11Rrmc4gIMoE1xO7VTGhe0qfhleKd9XVjs8bjJxL7Q==","type":"thought"},{"name":"get_weather","arguments":{"location":"Paris"},"type":"function_call","id":"d526dpq4"}],"usage":{"total_tokens":115,"total_input_tokens":53,"input_tokens_by_modality":[{"modality":"text","tokens":53}],"total_cached_tokens":0,"total_output_tokens":15,"total_tool_use_tokens":0,"total_thought_tokens":47},"role":"model","created":"2026-04-08T14:27:43Z","updated":"2026-04-08T14:27:43Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Wed, 08 Apr 2026 14:27:43 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1398 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '917' + status: + code: 200 + message: OK +- request: + body: '{"input":{"call_id":"d526dpq4","result":{"forecast":"sunny"},"type":"function_result","name":"get_weather"},"model":"gemini-2.5-flash","previous_interaction_id":"v1_ChczbVhXYWJ1SEhxbW0xTWtQLVA2dDZRYxIXM21YV2FidUhIcW1tMU1rUC1QNnQ2UWM","tools":[{"type":"function","description":"Get + the current weather for a location.","name":"get_weather","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}]}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '440' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + User-Agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/interactions + response: + body: + string: '{"id":"v1_ChczbVhXYWJ1SEhxbW0xTWtQLVA2dDZRYxIWNEdYV2FmbEM1TERVeVFfLTY3aWdDUQ","status":"completed","outputs":[{"signature":"CiRlMjQ4MzBhNy01Y2Q2LTQyZmUtOTk4Yi1lZTUzOWU3MmI5YzM=","type":"thought"},{"text":"The + weather in Paris is sunny.","type":"text"}],"usage":{"total_tokens":61,"total_input_tokens":54,"input_tokens_by_modality":[{"modality":"text","tokens":54}],"total_cached_tokens":0,"total_output_tokens":7,"total_tool_use_tokens":0,"total_thought_tokens":0},"role":"model","created":"2026-04-08T14:27:45Z","updated":"2026-04-08T14:27:45Z","object":"interaction","model":"gemini-2.5-flash"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json + Date: + - Wed, 08 Apr 2026 14:27:45 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1141 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '596' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/integration.py b/py/src/braintrust/integrations/google_genai/integration.py index 3b9365e9..4efdf28f 100644 --- a/py/src/braintrust/integrations/google_genai/integration.py +++ b/py/src/braintrust/integrations/google_genai/integration.py @@ -5,10 +5,18 @@ from braintrust.integrations.base import BaseIntegration from .patchers import ( + AsyncInteractionsCancelPatcher, + AsyncInteractionsCreatePatcher, + AsyncInteractionsDeletePatcher, + AsyncInteractionsGetPatcher, AsyncModelsEmbedContentPatcher, AsyncModelsGenerateContentPatcher, AsyncModelsGenerateContentStreamPatcher, AsyncModelsGenerateImagesPatcher, + InteractionsCancelPatcher, + InteractionsCreatePatcher, + InteractionsDeletePatcher, + InteractionsGetPatcher, ModelsEmbedContentPatcher, ModelsGenerateContentPatcher, ModelsGenerateContentStreamPatcher, @@ -30,8 +38,16 @@ class GoogleGenAIIntegration(BaseIntegration): ModelsGenerateContentStreamPatcher, ModelsEmbedContentPatcher, ModelsGenerateImagesPatcher, + InteractionsCreatePatcher, + InteractionsGetPatcher, + InteractionsCancelPatcher, + InteractionsDeletePatcher, AsyncModelsGenerateContentPatcher, AsyncModelsGenerateContentStreamPatcher, AsyncModelsEmbedContentPatcher, AsyncModelsGenerateImagesPatcher, + AsyncInteractionsCreatePatcher, + AsyncInteractionsGetPatcher, + AsyncInteractionsCancelPatcher, + AsyncInteractionsDeletePatcher, ) diff --git a/py/src/braintrust/integrations/google_genai/patchers.py b/py/src/braintrust/integrations/google_genai/patchers.py index 7bc895f1..dc4433a0 100644 --- a/py/src/braintrust/integrations/google_genai/patchers.py +++ b/py/src/braintrust/integrations/google_genai/patchers.py @@ -7,10 +7,18 @@ _async_generate_content_stream_wrapper, _async_generate_content_wrapper, _async_generate_images_wrapper, + _async_interactions_cancel_wrapper, + _async_interactions_create_wrapper, + _async_interactions_delete_wrapper, + _async_interactions_get_wrapper, _embed_content_wrapper, _generate_content_stream_wrapper, _generate_content_wrapper, _generate_images_wrapper, + _interactions_cancel_wrapper, + _interactions_create_wrapper, + _interactions_delete_wrapper, + _interactions_get_wrapper, ) @@ -55,6 +63,47 @@ class ModelsGenerateImagesPatcher(FunctionWrapperPatcher): wrapper = _generate_images_wrapper +# --------------------------------------------------------------------------- +# Sync Interactions patchers +# --------------------------------------------------------------------------- + + +class InteractionsCreatePatcher(FunctionWrapperPatcher): + """Patch ``InteractionsResource.create`` for tracing.""" + + name = "google_genai.interactions.create" + target_module = "google.genai._interactions.resources.interactions" + target_path = "InteractionsResource.create" + wrapper = _interactions_create_wrapper + + +class InteractionsGetPatcher(FunctionWrapperPatcher): + """Patch ``InteractionsResource.get`` for tracing.""" + + name = "google_genai.interactions.get" + target_module = "google.genai._interactions.resources.interactions" + target_path = "InteractionsResource.get" + wrapper = _interactions_get_wrapper + + +class InteractionsCancelPatcher(FunctionWrapperPatcher): + """Patch ``InteractionsResource.cancel`` for tracing.""" + + name = "google_genai.interactions.cancel" + target_module = "google.genai._interactions.resources.interactions" + target_path = "InteractionsResource.cancel" + wrapper = _interactions_cancel_wrapper + + +class InteractionsDeletePatcher(FunctionWrapperPatcher): + """Patch ``InteractionsResource.delete`` for tracing.""" + + name = "google_genai.interactions.delete" + target_module = "google.genai._interactions.resources.interactions" + target_path = "InteractionsResource.delete" + wrapper = _interactions_delete_wrapper + + # --------------------------------------------------------------------------- # Async Models patchers # --------------------------------------------------------------------------- @@ -94,3 +143,44 @@ class AsyncModelsGenerateImagesPatcher(FunctionWrapperPatcher): target_module = "google.genai.models" target_path = "AsyncModels.generate_images" wrapper = _async_generate_images_wrapper + + +# --------------------------------------------------------------------------- +# Async Interactions patchers +# --------------------------------------------------------------------------- + + +class AsyncInteractionsCreatePatcher(FunctionWrapperPatcher): + """Patch ``AsyncInteractionsResource.create`` for tracing.""" + + name = "google_genai.async_interactions.create" + target_module = "google.genai._interactions.resources.interactions" + target_path = "AsyncInteractionsResource.create" + wrapper = _async_interactions_create_wrapper + + +class AsyncInteractionsGetPatcher(FunctionWrapperPatcher): + """Patch ``AsyncInteractionsResource.get`` for tracing.""" + + name = "google_genai.async_interactions.get" + target_module = "google.genai._interactions.resources.interactions" + target_path = "AsyncInteractionsResource.get" + wrapper = _async_interactions_get_wrapper + + +class AsyncInteractionsCancelPatcher(FunctionWrapperPatcher): + """Patch ``AsyncInteractionsResource.cancel`` for tracing.""" + + name = "google_genai.async_interactions.cancel" + target_module = "google.genai._interactions.resources.interactions" + target_path = "AsyncInteractionsResource.cancel" + wrapper = _async_interactions_cancel_wrapper + + +class AsyncInteractionsDeletePatcher(FunctionWrapperPatcher): + """Patch ``AsyncInteractionsResource.delete`` for tracing.""" + + name = "google_genai.async_interactions.delete" + target_module = "google.genai._interactions.resources.interactions" + target_path = "AsyncInteractionsResource.delete" + wrapper = _async_interactions_delete_wrapper diff --git a/py/src/braintrust/integrations/google_genai/test_google_genai.py b/py/src/braintrust/integrations/google_genai/test_google_genai.py index fc48d671..8fd93d6d 100644 --- a/py/src/braintrust/integrations/google_genai/test_google_genai.py +++ b/py/src/braintrust/integrations/google_genai/test_google_genai.py @@ -8,9 +8,18 @@ from braintrust import logger from braintrust.integrations.google_genai import setup_genai from braintrust.logger import Attachment +from braintrust.span_types import SpanTypeAttribute from braintrust.test_helpers import init_test_logger from braintrust.wrappers.test_utils import verify_autoinstrument_script from google.genai import types + + +try: + from google.genai import interactions +except ImportError: + interactions = None + +_needs_interactions = pytest.mark.skipif(interactions is None, reason="google-genai too old for interactions API") from google.genai.client import Client @@ -19,6 +28,7 @@ EMBEDDING_MODEL = "gemini-embedding-001" IMAGE_MODEL = "imagen-4.0-fast-generate-001" REASONING_MODEL = "gemini-2.5-flash" +INTERACTIONS_MODEL = "gemini-2.5-flash" FIXTURES_DIR = Path(__file__).resolve().parent.parent.parent.parent.parent.parent / "internal/golden/fixtures" TINY_PNG_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" @@ -1165,6 +1175,275 @@ async def test_google_search_grounding_async(memory_logger, mode): _assert_grounding_metadata(span["output"]) +def _find_spans_by_type(spans, span_type): + return [span for span in spans if span["span_attributes"]["type"] == span_type] + + +def _find_span_by_name(spans, name): + return next(span for span in spans if span["span_attributes"]["name"] == name) + + +def _interaction_function_tool(): + return interactions.Function( + type="function", + name="get_weather", + description="Get the current weather for a location.", + parameters={ + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + ) + + +@_needs_interactions +@pytest.mark.vcr +def test_interactions_create_and_get(memory_logger): + assert not memory_logger.pop() + + client = Client() + response = client.interactions.create( + model=INTERACTIONS_MODEL, + input="What is the capital of France?", + ) + fetched = client.interactions.get(response.id, include_input=True) + + assert response.status == "completed" + assert fetched.id == response.id + assert fetched.status == "completed" + + spans = memory_logger.pop() + create_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.LLM), "interactions.create") + get_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.TASK), "interactions.get") + + assert create_span["metadata"]["model"] == INTERACTIONS_MODEL + assert create_span["metadata"]["interaction_id"] == response.id + assert create_span["output"]["status"] == "completed" + assert "Paris" in create_span["output"]["text"] + assert create_span["metrics"]["prompt_tokens"] > 0 + assert create_span["metrics"]["completion_tokens"] > 0 + + assert get_span["input"]["id"] == response.id + assert get_span["metadata"]["interaction_id"] == response.id + assert get_span["output"]["status"] == "completed" + assert "Paris" in get_span["output"]["text"] + assert "France" in str(get_span["output"]["outputs"]) + + +@_needs_interactions +@pytest.mark.vcr +def test_interactions_create_stream(memory_logger): + assert not memory_logger.pop() + + client = Client() + events = list( + client.interactions.create( + model=INTERACTIONS_MODEL, + input="Say hi in five words or less.", + stream=True, + ) + ) + + assert events + + spans = memory_logger.pop() + create_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.LLM), "interactions.create") + + assert create_span["metadata"]["model"] == INTERACTIONS_MODEL + assert create_span["output"]["status"] == "completed" + assert create_span["output"]["text"] + assert "hi" in create_span["output"]["text"].lower() + assert create_span["metrics"]["time_to_first_token"] >= 0 + assert "content.start" in create_span["metadata"]["stream_event_types"] + assert "interaction.complete" in create_span["metadata"]["stream_event_types"] + + +@_needs_interactions +@pytest.mark.vcr +def test_interactions_tool_call_and_follow_up(memory_logger): + assert not memory_logger.pop() + + client = Client() + tool = _interaction_function_tool() + + first_response = client.interactions.create( + model=INTERACTIONS_MODEL, + input="What is the weather like in Paris? Use the tool.", + tools=[tool], + ) + tool_call = next(output for output in first_response.outputs if output.type == "function_call") + + second_response = client.interactions.create( + model=INTERACTIONS_MODEL, + previous_interaction_id=first_response.id, + input=interactions.FunctionResultContent( + type="function_result", + call_id=tool_call.id, + name=tool_call.name, + result={"forecast": "sunny"}, + ), + tools=[tool], + ) + + assert first_response.status == "requires_action" + assert second_response.status == "completed" + assert "sunny" in second_response.outputs[-1].text.lower() + + spans = memory_logger.pop() + llm_spans = _find_spans_by_type(spans, SpanTypeAttribute.LLM) + tool_spans = _find_spans_by_type(spans, SpanTypeAttribute.TOOL) + + first_span = next(span for span in llm_spans if span["metadata"]["interaction_id"] == first_response.id) + second_span = next(span for span in llm_spans if span["metadata"]["interaction_id"] == second_response.id) + tool_span = _find_span_by_name(tool_spans, "get_weather") + + assert first_span["output"]["status"] == "requires_action" + assert second_span["metadata"]["previous_interaction_id"] == first_response.id + assert second_span["output"]["status"] == "completed" + assert "sunny" in second_span["output"]["text"].lower() + + assert tool_span["input"] == {"location": "Paris"} + assert tool_span["span_parents"] == [first_span["span_id"]] + + +@_needs_interactions +@pytest.mark.vcr +def test_interactions_tool_span_stays_active_during_local_tool_work(memory_logger): + assert not memory_logger.pop() + + client = Client() + tool = _interaction_function_tool() + + first_response = client.interactions.create( + model=INTERACTIONS_MODEL, + input="What is the weather like in Paris? Use the tool.", + tools=[tool], + ) + tool_call = next(output for output in first_response.outputs if output.type == "function_call") + + with logger.start_span(name="nested_tool_work", type=SpanTypeAttribute.TASK) as nested_tool_work: + nested_tool_work.log(output={"forecast": "sunny"}) + + second_response = client.interactions.create( + model=INTERACTIONS_MODEL, + previous_interaction_id=first_response.id, + input=interactions.FunctionResultContent( + type="function_result", + call_id=tool_call.id, + name=tool_call.name, + result={"forecast": "sunny"}, + ), + tools=[tool], + ) + + assert first_response.status == "requires_action" + assert second_response.status == "completed" + + spans = memory_logger.pop() + llm_spans = _find_spans_by_type(spans, SpanTypeAttribute.LLM) + tool_spans = _find_spans_by_type(spans, SpanTypeAttribute.TOOL) + + first_span = next(span for span in llm_spans if span["metadata"]["interaction_id"] == first_response.id) + second_span = next(span for span in llm_spans if span["metadata"]["interaction_id"] == second_response.id) + tool_span = _find_span_by_name(tool_spans, "get_weather") + nested_span = _find_span_by_name(spans, "nested_tool_work") + + assert tool_span["span_parents"] == [first_span["span_id"]] + assert nested_span["span_parents"] == [tool_span["span_id"]] + assert tool_span["metrics"]["start"] <= nested_span["metrics"]["start"] + assert tool_span["metrics"]["end"] >= nested_span["metrics"]["end"] + assert second_span.get("span_parents") in (None, []) + + +@_needs_interactions +@pytest.mark.vcr +def test_interactions_delete(memory_logger): + assert not memory_logger.pop() + + client = Client() + response = client.interactions.create( + model=INTERACTIONS_MODEL, + input="Reply with exactly ok.", + ) + assert response.id + + create_spans = memory_logger.pop() + assert create_spans + + delete_response = client.interactions.delete(response.id) + assert delete_response == {} + + spans = memory_logger.pop() + delete_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.TASK), "interactions.delete") + + assert delete_span["input"]["id"] == response.id + assert delete_span["output"] == {} + assert delete_span["metrics"]["duration"] >= 0 + + +@_needs_interactions +@pytest.mark.vcr +@pytest.mark.asyncio +async def test_interactions_async_round_trip(memory_logger): + assert not memory_logger.pop() + + client = Client() + response = await client.aio.interactions.create( + model=INTERACTIONS_MODEL, + input="What is the capital of Italy?", + ) + fetched = await client.aio.interactions.get(response.id, include_input=True) + deleted = await client.aio.interactions.delete(response.id) + + assert response.status == "completed" + assert fetched.id == response.id + assert deleted == {} + + spans = memory_logger.pop() + create_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.LLM), "interactions.create") + get_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.TASK), "interactions.get") + delete_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.TASK), "interactions.delete") + + assert create_span["metadata"]["model"] == INTERACTIONS_MODEL + assert create_span["output"]["status"] == "completed" + assert "Rome" in create_span["output"]["text"] + + assert get_span["input"]["id"] == response.id + assert "Italy" in str(get_span["output"]["outputs"]) + + assert delete_span["input"]["id"] == response.id + assert delete_span["output"] == {} + + +@_needs_interactions +@pytest.mark.vcr +@pytest.mark.asyncio +async def test_interactions_async_stream(memory_logger): + assert not memory_logger.pop() + + client = Client() + stream = await client.aio.interactions.create( + model=INTERACTIONS_MODEL, + input="Say hi shortly.", + stream=True, + ) + + events = [] + async for event in stream: + events.append(event) + + assert events + + spans = memory_logger.pop() + create_span = _find_span_by_name(_find_spans_by_type(spans, SpanTypeAttribute.LLM), "interactions.create") + + assert create_span["output"]["status"] == "completed" + assert create_span["output"]["text"] + assert "hi" in create_span["output"]["text"].lower() + assert create_span["metrics"]["time_to_first_token"] >= 0 + assert "content.delta" in create_span["metadata"]["stream_event_types"] + + class TestAutoInstrumentGoogleGenAI: """Tests for auto_instrument() with Google GenAI.""" diff --git a/py/src/braintrust/integrations/google_genai/tracing.py b/py/src/braintrust/integrations/google_genai/tracing.py index 0ee6fa56..f6d0fb57 100644 --- a/py/src/braintrust/integrations/google_genai/tracing.py +++ b/py/src/braintrust/integrations/google_genai/tracing.py @@ -1,8 +1,14 @@ """Google GenAI-specific span creation, metadata extraction, stream handling, and output normalization.""" +import base64 +import binascii +import contextvars +import dataclasses import logging import time from collections.abc import Awaitable, Callable, Iterable +from datetime import date, datetime +from enum import Enum from typing import TYPE_CHECKING, Any from braintrust.bt_json import bt_safe_deep_copy @@ -11,6 +17,7 @@ if TYPE_CHECKING: + from google.genai._interactions.types.interaction import Interaction from google.genai.types import ( EmbedContentResponse, GenerateContentResponse, @@ -19,6 +26,35 @@ logger = logging.getLogger(__name__) +_MEDIA_CONTENT_TYPES = {"image", "audio", "video", "document"} +_TOOL_CALL_TYPES = { + "function_call", + "code_execution_call", + "url_context_call", + "google_search_call", + "mcp_server_tool_call", + "file_search_call", +} +_TOOL_RESULT_TYPES = { + "function_result", + "code_execution_result", + "url_context_result", + "google_search_result", + "mcp_server_tool_result", + "file_search_result", +} + + +@dataclasses.dataclass +class _ActiveInteractionToolSpan: + span: Any + is_current: bool = False + + +_interaction_tool_spans: contextvars.ContextVar[dict[str, _ActiveInteractionToolSpan] | None] = contextvars.ContextVar( + "braintrust_google_genai_interaction_tool_spans", default=None +) + # --------------------------------------------------------------------------- # Serialization helpers @@ -119,6 +155,54 @@ def _serialize_tools(api_client: Any, input: Any | None) -> Any | None: return None +def _attachment_from_base64_data(data: str, mime_type: str, *, label: str) -> Attachment | None: + raw_data = data + if raw_data.startswith("data:"): + _, _, encoded = raw_data.partition(",") + raw_data = encoded + + try: + decoded = base64.b64decode(raw_data, validate=True) + except (ValueError, binascii.Error): + return None + + extension = mime_type.split("/")[1] if "/" in mime_type else "bin" + return Attachment(data=decoded, filename=f"{label}.{extension}", content_type=mime_type) + + +def _serialize_interaction_content_dict(value: dict[str, Any]) -> dict[str, Any]: + serialized = {key: _serialize_interaction_value(val) for key, val in value.items() if val is not None} + + content_type = serialized.get("type") + data = serialized.get("data") + mime_type = serialized.get("mime_type") + if content_type in _MEDIA_CONTENT_TYPES and isinstance(data, str) and isinstance(mime_type, str): + attachment = _attachment_from_base64_data(data, mime_type, label=content_type) + if attachment is not None: + serialized["data"] = attachment + + return serialized + + +def _serialize_interaction_value(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool, Attachment)): + return value + if isinstance(value, Enum): + return value.value + if isinstance(value, (date, datetime)): + return value.isoformat() + if isinstance(value, (list, tuple)): + return [_serialize_interaction_value(item) for item in value] + if isinstance(value, dict): + return _serialize_interaction_content_dict(value) + if hasattr(value, "model_dump"): + try: + return _serialize_interaction_value(value.model_dump(exclude_none=True)) + except TypeError: + return _serialize_interaction_value(value.model_dump()) + return value + + # --------------------------------------------------------------------------- # Argument extraction helpers # --------------------------------------------------------------------------- @@ -155,6 +239,69 @@ def _prepare_generate_images_traced_call( return _clean(input), clean_kwargs +def _prepare_interaction_create_traced_call( + api_client: Any, args: list[Any], kwargs: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: + del api_client, args + + input_data = _clean( + { + "model": kwargs.get("model"), + "agent": kwargs.get("agent"), + "input": _serialize_interaction_value(kwargs.get("input")), + "background": kwargs.get("background"), + "generation_config": _serialize_interaction_value(kwargs.get("generation_config")), + "previous_interaction_id": kwargs.get("previous_interaction_id"), + "response_format": _serialize_interaction_value(kwargs.get("response_format")), + "response_mime_type": kwargs.get("response_mime_type"), + "response_modalities": _serialize_interaction_value(kwargs.get("response_modalities")), + "store": kwargs.get("store"), + "stream": kwargs.get("stream"), + "system_instruction": kwargs.get("system_instruction"), + "tools": _serialize_interaction_value(kwargs.get("tools")), + "agent_config": _serialize_interaction_value(kwargs.get("agent_config")), + } + ) + metadata = _clean( + { + "api_version": kwargs.get("api_version"), + "model": kwargs.get("model"), + "agent": kwargs.get("agent"), + "previous_interaction_id": kwargs.get("previous_interaction_id"), + } + ) + return input_data, metadata + + +def _prepare_interaction_get_traced_call( + api_client: Any, args: list[Any], kwargs: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: + del api_client + + interaction_id = args[0] if args else kwargs.get("id") + input_data = _clean( + { + "id": interaction_id, + "include_input": kwargs.get("include_input"), + "last_event_id": kwargs.get("last_event_id"), + "stream": kwargs.get("stream"), + } + ) + metadata = _clean({"api_version": kwargs.get("api_version")}) + return input_data, metadata + + +def _prepare_interaction_id_traced_call( + api_client: Any, args: list[Any], kwargs: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: + del api_client + + interaction_id = args[0] if args else kwargs.get("id") + input_data = _clean({"id": interaction_id}) + metadata = _clean({"api_version": kwargs.get("api_version")}) + return input_data, metadata + + # --------------------------------------------------------------------------- # Metric extraction helpers # --------------------------------------------------------------------------- @@ -286,17 +433,94 @@ def _extract_generate_images_output(response: Any) -> dict[str, Any]: ) -def _extract_generate_images_metrics(start: float) -> dict[str, Any]: +def _extract_generic_timing_metrics(start: float) -> dict[str, Any]: end_time = time.time() return _clean( - dict( - start=start, - end=end_time, - duration=end_time - start, + { + "start": start, + "end": end_time, + "duration": end_time - start, + } + ) + + +def _extract_interaction_usage_metrics(usage: Any, metrics: dict[str, Any]) -> None: + if usage is None: + return + + if hasattr(usage, "total_input_tokens") and usage.total_input_tokens is not None: + metrics["prompt_tokens"] = usage.total_input_tokens + if hasattr(usage, "total_output_tokens") and usage.total_output_tokens is not None: + metrics["completion_tokens"] = usage.total_output_tokens + if hasattr(usage, "total_tokens") and usage.total_tokens is not None: + metrics["tokens"] = usage.total_tokens + if hasattr(usage, "total_cached_tokens") and usage.total_cached_tokens is not None: + metrics["prompt_cached_tokens"] = usage.total_cached_tokens + if hasattr(usage, "total_thought_tokens") and usage.total_thought_tokens is not None: + metrics["completion_reasoning_tokens"] = usage.total_thought_tokens + if hasattr(usage, "total_tool_use_tokens") and usage.total_tool_use_tokens is not None: + metrics["tool_use_tokens"] = usage.total_tool_use_tokens + + +def _extract_interaction_text(outputs: list[dict[str, Any]]) -> str | None: + text_parts = [] + for item in outputs: + if item.get("type") == "text" and isinstance(item.get("text"), str): + text_parts.append(item["text"]) + return "".join(text_parts) or None + + +def _serialize_interaction_outputs(response: "Interaction") -> list[dict[str, Any]]: + outputs = _serialize_interaction_value(getattr(response, "outputs", None)) + return outputs if isinstance(outputs, list) else ([] if outputs is None else [outputs]) + + +def _extract_interaction_output( + response: "Interaction", serialized_outputs: list[dict[str, Any]] | None = None +) -> dict[str, Any]: + outputs_list = serialized_outputs if serialized_outputs is not None else _serialize_interaction_outputs(response) + + return _clean( + { + "status": getattr(response, "status", None), + "outputs": outputs_list, + "text": _extract_interaction_text(outputs_list), + } + ) + + +def _extract_interaction_metadata(response: "Interaction") -> dict[str, Any]: + usage = getattr(response, "usage", None) + usage_serialized = _serialize_interaction_value(usage) + usage_by_modality = None + if isinstance(usage_serialized, dict): + usage_by_modality = _clean( + { + "input_tokens_by_modality": usage_serialized.get("input_tokens_by_modality"), + "output_tokens_by_modality": usage_serialized.get("output_tokens_by_modality"), + "cached_tokens_by_modality": usage_serialized.get("cached_tokens_by_modality"), + "tool_use_tokens_by_modality": usage_serialized.get("tool_use_tokens_by_modality"), + } ) + + return _clean( + { + "interaction_id": getattr(response, "id", None), + "previous_interaction_id": getattr(response, "previous_interaction_id", None), + "role": getattr(response, "role", None), + "response_mime_type": getattr(response, "response_mime_type", None), + "response_modalities": _serialize_interaction_value(getattr(response, "response_modalities", None)), + "usage_by_modality": usage_by_modality, + } ) +def _extract_interaction_metrics(response: "Interaction", start: float) -> dict[str, Any]: + metrics = _extract_generic_timing_metrics(start) + _extract_interaction_usage_metrics(getattr(response, "usage", None), metrics) + return metrics + + # --------------------------------------------------------------------------- # Result processing helpers # --------------------------------------------------------------------------- @@ -311,7 +535,65 @@ def _embed_process_result(result: "EmbedContentResponse", start: float) -> tuple def _generate_images_process_result(result: Any, start: float) -> tuple[Any, dict[str, Any]]: - return _extract_generate_images_output(result), _extract_generate_images_metrics(start) + return _extract_generate_images_output(result), _extract_generic_timing_metrics(start) + + +def _tool_span_name(call_item: dict[str, Any] | None, result_item: dict[str, Any] | None) -> str: + item = call_item or result_item or {} + if item.get("server_name") and item.get("name"): + return f"{item['server_name']}.{item['name']}" + if item.get("name"): + return str(item["name"]) + return str(item.get("type") or "interaction_tool") + + +def _tool_span_input(call_item: dict[str, Any] | None) -> Any: + if not call_item: + return None + if call_item.get("arguments") is not None: + return call_item["arguments"] + return ( + _clean( + { + key: value + for key, value in call_item.items() + if key not in {"id", "name", "type", "signature", "server_name"} + } + ) + or None + ) + + +def _tool_span_output(result_item: dict[str, Any] | None) -> Any: + if not result_item: + return None + if result_item.get("result") is not None: + return result_item["result"] + return ( + _clean( + { + key: value + for key, value in result_item.items() + if key not in {"call_id", "name", "type", "signature", "server_name", "is_error"} + } + ) + or None + ) + + +def _interaction_process_result( + result: "Interaction", start: float +) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + outputs_list = _serialize_interaction_outputs(result) + return ( + _extract_interaction_output(result, outputs_list), + _extract_interaction_metrics(result, start), + _extract_interaction_metadata(result), + ) + + +def _generic_process_result(result: Any, start: float) -> tuple[Any, dict[str, Any]]: + return _serialize_interaction_value(result), _extract_generic_timing_metrics(start) # --------------------------------------------------------------------------- @@ -417,11 +699,318 @@ def _aggregate_generate_content_chunks( return aggregated, clean_metrics +def _is_interaction_content_event(event: Any) -> bool: + return getattr(event, "event_type", None) in {"content.start", "content.delta"} + + +def _merge_interaction_content_delta(item: dict[str, Any], delta: dict[str, Any]) -> dict[str, Any]: + delta_type = delta.get("type") + if item.get("type") is None: + if delta_type == "thought_signature": + item["type"] = "thought" + elif delta_type == "thought_summary": + item["type"] = "thought" + elif delta_type is not None: + item["type"] = delta_type + + for key, value in delta.items(): + if key == "type" or value is None: + continue + if ( + key in item + and isinstance(item[key], str) + and isinstance(value, str) + and key in {"text", "data", "signature"} + ): + item[key] += value + else: + item[key] = value + + return item + + +def _reconstruct_interaction_outputs_from_events(events: list[Any]) -> list[dict[str, Any]]: + outputs_by_index: dict[int, dict[str, Any]] = {} + + for event in events: + event_type = getattr(event, "event_type", None) + index = getattr(event, "index", None) + if not isinstance(index, int): + continue + + if event_type == "content.start": + outputs_by_index[index] = _serialize_interaction_value(getattr(event, "content", None)) or {} + elif event_type == "content.delta": + item = outputs_by_index.setdefault(index, {}) + delta = _serialize_interaction_value(getattr(event, "delta", None)) or {} + if isinstance(delta, dict): + outputs_by_index[index] = _merge_interaction_content_delta(item, delta) + + return [outputs_by_index[index] for index in sorted(outputs_by_index)] + + +def _get_active_interaction_tool_spans() -> dict[str, _ActiveInteractionToolSpan]: + active_tool_spans = _interaction_tool_spans.get() + if active_tool_spans is None: + active_tool_spans = {} + _interaction_tool_spans.set(active_tool_spans) + return active_tool_spans + + +def _tool_span_metadata(call_item: dict[str, Any] | None, result_item: dict[str, Any] | None) -> dict[str, Any] | None: + return ( + _clean( + { + "tool_type": (call_item or result_item or {}).get("type"), + "call_id": (call_item or {}).get("id") or (result_item or {}).get("call_id"), + "server_name": (call_item or result_item or {}).get("server_name"), + "signature": (call_item or result_item or {}).get("signature"), + } + ) + or None + ) + + +def _log_posthoc_interaction_tool_span(call_item: dict[str, Any] | None, result_item: dict[str, Any] | None) -> None: + with start_span( + name=_tool_span_name(call_item, result_item), + type=SpanTypeAttribute.TOOL, + input=_tool_span_input(call_item), + metadata=_tool_span_metadata(call_item, result_item), + ) as tool_span: + if not result_item: + return + if result_item.get("is_error"): + tool_span.log(error=_tool_span_output(result_item)) + else: + tool_span.log(output=_tool_span_output(result_item)) + + +def _cleanup_interaction_tool_span_state(active_tool_spans: dict[str, _ActiveInteractionToolSpan]) -> None: + if active_tool_spans: + return + _interaction_tool_spans.set(None) + + +def _close_active_interaction_tool_span( + call_id: str, result_item: dict[str, Any] | None = None, *, end_time: float | None = None +) -> bool: + active_tool_spans = _get_active_interaction_tool_spans() + active_tool_span = active_tool_spans.pop(call_id, None) + if active_tool_span is None: + return False + + if active_tool_span.is_current: + active_tool_span.span.unset_current() + + if result_item is not None: + if result_item.get("is_error"): + active_tool_span.span.log(error=_tool_span_output(result_item)) + else: + active_tool_span.span.log(output=_tool_span_output(result_item)) + + active_tool_span.span.end(end_time=end_time) + _cleanup_interaction_tool_span_state(active_tool_spans) + return True + + +def _activate_interaction_tool_span( + call_item: dict[str, Any], *, parent_export: str, start_time: float | None = None, set_current: bool = False +) -> None: + # Keep the tool span open across local tool execution so any nested spans + # started by user code naturally inherit from it until the corresponding + # function_result is submitted on a follow-up interactions.create call. + call_id = call_item.get("id") + if not isinstance(call_id, str): + _log_posthoc_interaction_tool_span(call_item, None) + return + + active_tool_spans = _get_active_interaction_tool_spans() + if call_id in active_tool_spans: + return + + tool_span = start_span( + name=_tool_span_name(call_item, None), + type=SpanTypeAttribute.TOOL, + input=_tool_span_input(call_item), + metadata=_tool_span_metadata(call_item, None), + parent=parent_export, + start_time=start_time, + set_current=True, + ) + active_tool_spans[call_id] = _ActiveInteractionToolSpan(span=tool_span, is_current=False) + + if set_current: + tool_span.set_current() + active_tool_spans[call_id].is_current = True + + +def _serialize_interaction_items(value: Any) -> list[dict[str, Any]]: + serialized = _serialize_interaction_value(value) + if serialized is None: + return [] + items = serialized if isinstance(serialized, list) else [serialized] + return [item for item in items if isinstance(item, dict)] + + +def _close_interaction_tool_spans_from_input(input_value: Any) -> None: + # Tool spans should end when the client hands the tool result back to the + # interactions API, before the follow-up LLM/TASK span begins. + end_time = time.time() + for item in _serialize_interaction_items(input_value): + if item.get("type") not in _TOOL_RESULT_TYPES: + continue + call_id = item.get("call_id") + if isinstance(call_id, str): + _close_active_interaction_tool_span(call_id, item, end_time=end_time) + + +def _finalize_interaction_tool_spans( + output: Any, metrics: dict[str, Any], metadata: dict[str, Any] | None, parent_export: str +) -> None: + del metadata + + if not isinstance(output, dict): + return + + outputs = output.get("outputs") + if not isinstance(outputs, list): + return + + active_tool_spans = _get_active_interaction_tool_spans() + calls_by_id: dict[str, dict[str, Any]] = {} + pending_results: dict[str, list[dict[str, Any]]] = {} + pairs: list[tuple[dict[str, Any] | None, dict[str, Any] | None]] = [] + emitted_call_ids: set[str] = set() + + for item in outputs: + item_type = item.get("type") + if item_type in _TOOL_CALL_TYPES: + call_id = item.get("id") + if isinstance(call_id, str): + calls_by_id[call_id] = item + for pending in pending_results.pop(call_id, []): + pairs.append((item, pending)) + emitted_call_ids.add(call_id) + else: + pairs.append((item, None)) + elif item_type in _TOOL_RESULT_TYPES: + call_id = item.get("call_id") + if isinstance(call_id, str) and call_id in active_tool_spans: + _close_active_interaction_tool_span(call_id, item, end_time=time.time()) + emitted_call_ids.add(call_id) + elif isinstance(call_id, str) and call_id in calls_by_id: + pairs.append((calls_by_id[call_id], item)) + emitted_call_ids.add(call_id) + elif isinstance(call_id, str): + pending_results.setdefault(call_id, []).append(item) + else: + pairs.append((None, item)) + + for call_id, result_items in pending_results.items(): + call_item = calls_by_id.get(call_id) + for result_item in result_items: + pairs.append((call_item, result_item)) + if call_item is not None: + emitted_call_ids.add(call_id) + + unpaired_call_items: list[dict[str, Any]] = [] + for call_id, call_item in calls_by_id.items(): + if call_id not in emitted_call_ids: + unpaired_call_items.append(call_item) + + for call_item, result_item in pairs: + _log_posthoc_interaction_tool_span(call_item, result_item) + + activatable_call_items = [ + call_item + for call_item in unpaired_call_items + if isinstance(call_item.get("id"), str) and call_item.get("id") not in active_tool_spans + ] + claim_current = len(activatable_call_items) == 1 and not any( + active_tool_span.is_current for active_tool_span in active_tool_spans.values() + ) + + for call_item in unpaired_call_items: + call_id = call_item.get("id") + if not isinstance(call_id, str): + _log_posthoc_interaction_tool_span(call_item, None) + continue + if call_id in active_tool_spans: + continue + _activate_interaction_tool_span( + call_item, + parent_export=parent_export, + start_time=metrics.get("end"), + set_current=claim_current and call_item is activatable_call_items[0], + ) + + +def _aggregate_interaction_events( + events: list[Any], start: float, first_token_time: float | None = None +) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + metrics = _extract_generic_timing_metrics(start) + if first_token_time is not None: + metrics["time_to_first_token"] = first_token_time - start + + metadata = _clean({"stream_event_types": [et for event in events if (et := getattr(event, "event_type", None))]}) + reconstructed_outputs = _reconstruct_interaction_outputs_from_events(events) + + final_interaction = next( + ( + event.interaction + for event in reversed(events) + if hasattr(event, "interaction") and getattr(event, "interaction", None) is not None + ), + None, + ) + if final_interaction is None: + if reconstructed_outputs: + return ( + {"outputs": reconstructed_outputs, "text": _extract_interaction_text(reconstructed_outputs)}, + _clean(metrics), + metadata, + ) + error_event = next( + ( + event + for event in reversed(events) + if getattr(event, "event_type", None) == "error" and getattr(event, "error", None) is not None + ), + None, + ) + if error_event is not None: + metadata["stream_error"] = _serialize_interaction_value(error_event.error) + return {"events": _serialize_interaction_value(events)}, _clean(metrics), metadata + + final_outputs_list = _serialize_interaction_outputs(final_interaction) + + _extract_interaction_usage_metrics(getattr(final_interaction, "usage", None), metrics) + metadata.update(_extract_interaction_metadata(final_interaction)) + + output = _extract_interaction_output(final_interaction, final_outputs_list) + if reconstructed_outputs and not output.get("outputs"): + output["outputs"] = reconstructed_outputs + output["text"] = _extract_interaction_text(reconstructed_outputs) + + return output, _clean(metrics), _clean(metadata) + + # --------------------------------------------------------------------------- # Traced call orchestration # --------------------------------------------------------------------------- +def _normalize_logged_result(result: Any) -> tuple[Any, dict[str, Any], dict[str, Any] | None]: + if isinstance(result, tuple) and len(result) == 3: + output, metrics, metadata = result + return output, metrics, metadata + if isinstance(result, tuple) and len(result) == 2: + output, metrics = result + return output, metrics, None + raise ValueError("Expected process_result/aggregate to return a 2-tuple or 3-tuple") + + def _run_traced_call( api_client: Any, args: list[Any], @@ -429,19 +1018,34 @@ def _run_traced_call( *, name: str, invoke: Callable[[], Any], - process_result: Callable[[Any, float], tuple[Any, dict[str, Any]]], + process_result: Callable[[Any, float], tuple[Any, dict[str, Any]] | tuple[Any, dict[str, Any], dict[str, Any]]], prepare_call: Callable[ [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] ] = _prepare_traced_call, + span_type: SpanTypeAttribute = SpanTypeAttribute.LLM, + before_invoke: Callable[[], None] | None = None, + finalize_logged_output: Callable[[Any, dict[str, Any], dict[str, Any] | None, str], None] | None = None, ) -> Any: input, clean_kwargs = prepare_call(api_client, args, kwargs) + if before_invoke is not None: + before_invoke() + start = time.time() - with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + parent_export = None + output = None + metrics = None + metadata = None + with start_span(name=name, type=span_type, input=input, metadata=clean_kwargs or None) as span: result = invoke() - output, metrics = process_result(result, start) - span.log(output=output, metrics=metrics) - return result + output, metrics, metadata = _normalize_logged_result(process_result(result, start)) + span.log(output=output, metrics=metrics, metadata=metadata) + parent_export = span.export() + + if finalize_logged_output is not None and parent_export is not None and metrics is not None: + finalize_logged_output(output, metrics, metadata, parent_export) + + return result async def _run_async_traced_call( @@ -451,19 +1055,34 @@ async def _run_async_traced_call( *, name: str, invoke: Callable[[], Awaitable[Any]], - process_result: Callable[[Any, float], tuple[Any, dict[str, Any]]], + process_result: Callable[[Any, float], tuple[Any, dict[str, Any]] | tuple[Any, dict[str, Any], dict[str, Any]]], prepare_call: Callable[ [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] ] = _prepare_traced_call, + span_type: SpanTypeAttribute = SpanTypeAttribute.LLM, + before_invoke: Callable[[], None] | None = None, + finalize_logged_output: Callable[[Any, dict[str, Any], dict[str, Any] | None, str], None] | None = None, ) -> Any: input, clean_kwargs = prepare_call(api_client, args, kwargs) + if before_invoke is not None: + before_invoke() + start = time.time() - with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + parent_export = None + output = None + metrics = None + metadata = None + with start_span(name=name, type=span_type, input=input, metadata=clean_kwargs or None) as span: result = await invoke() - output, metrics = process_result(result, start) - span.log(output=output, metrics=metrics) - return result + output, metrics, metadata = _normalize_logged_result(process_result(result, start)) + span.log(output=output, metrics=metrics, metadata=metadata) + parent_export = span.export() + + if finalize_logged_output is not None and parent_export is not None and metrics is not None: + finalize_logged_output(output, metrics, metadata, parent_export) + + return result def _run_stream_traced_call( @@ -473,23 +1092,46 @@ def _run_stream_traced_call( *, name: str, invoke: Callable[[], Any], - aggregate: Callable[[list[Any], float, float | None], tuple[Any, dict[str, Any]]], + aggregate: Callable[ + [list[Any], float, float | None], tuple[Any, dict[str, Any]] | tuple[Any, dict[str, Any], dict[str, Any]] + ], + span_type: SpanTypeAttribute = SpanTypeAttribute.LLM, + first_token_predicate: Callable[[Any], bool] | None = None, + prepare_call: Callable[ + [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] + ] = _prepare_traced_call, + before_invoke: Callable[[], None] | None = None, + finalize_logged_output: Callable[[Any, dict[str, Any], dict[str, Any] | None, str], None] | None = None, ) -> Any: - input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + input, clean_kwargs = prepare_call(api_client, args, kwargs) + + if before_invoke is not None: + before_invoke() start = time.time() first_token_time = None - with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + output = None + metrics = None + metadata = None + parent_export = None + with start_span(name=name, type=span_type, input=input, metadata=clean_kwargs or None) as span: chunks = [] for chunk in invoke(): - if first_token_time is None: + if first_token_time is None and ( + first_token_predicate(chunk) if first_token_predicate is not None else True + ): first_token_time = time.time() chunks.append(chunk) yield chunk - output, metrics = aggregate(chunks, start, first_token_time) - span.log(output=output, metrics=metrics) - return output + output, metrics, metadata = _normalize_logged_result(aggregate(chunks, start, first_token_time)) + span.log(output=output, metrics=metrics, metadata=metadata) + parent_export = span.export() + + if finalize_logged_output is not None and parent_export is not None and metrics is not None: + finalize_logged_output(output, metrics, metadata, parent_export) + + return output def _run_async_stream_traced_call( @@ -499,23 +1141,45 @@ def _run_async_stream_traced_call( *, name: str, invoke: Callable[[], Awaitable[Any]], - aggregate: Callable[[list[Any], float, float | None], tuple[Any, dict[str, Any]]], + aggregate: Callable[ + [list[Any], float, float | None], tuple[Any, dict[str, Any]] | tuple[Any, dict[str, Any], dict[str, Any]] + ], + span_type: SpanTypeAttribute = SpanTypeAttribute.LLM, + first_token_predicate: Callable[[Any], bool] | None = None, + prepare_call: Callable[ + [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] + ] = _prepare_traced_call, + before_invoke: Callable[[], None] | None = None, + finalize_logged_output: Callable[[Any, dict[str, Any], dict[str, Any] | None, str], None] | None = None, ) -> Any: - input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + input, clean_kwargs = prepare_call(api_client, args, kwargs) async def stream_generator(): + if before_invoke is not None: + before_invoke() + start = time.time() first_token_time = None - with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + output = None + metrics = None + metadata = None + parent_export = None + with start_span(name=name, type=span_type, input=input, metadata=clean_kwargs or None) as span: chunks = [] async for chunk in await invoke(): - if first_token_time is None: + if first_token_time is None and ( + first_token_predicate(chunk) if first_token_predicate is not None else True + ): first_token_time = time.time() chunks.append(chunk) yield chunk - output, metrics = aggregate(chunks, start, first_token_time) - span.log(output=output, metrics=metrics) + output, metrics, metadata = _normalize_logged_result(aggregate(chunks, start, first_token_time)) + span.log(output=output, metrics=metrics, metadata=metadata) + parent_export = span.export() + + if finalize_logged_output is not None and parent_export is not None and metrics is not None: + finalize_logged_output(output, metrics, metadata, parent_export) return stream_generator() @@ -613,3 +1277,175 @@ async def _async_generate_images_wrapper(wrapped: Any, instance: Any, args: Any, process_result=_generate_images_process_result, prepare_call=_prepare_generate_images_traced_call, ) + + +def _interactions_create_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + before_invoke = lambda: _close_interaction_tool_spans_from_input(kwargs.get("input")) + + if kwargs.get("stream"): + return _run_stream_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.create", + invoke=lambda: wrapped(*args, **kwargs), + aggregate=_aggregate_interaction_events, + first_token_predicate=_is_interaction_content_event, + prepare_call=_prepare_interaction_create_traced_call, + span_type=SpanTypeAttribute.LLM, + before_invoke=before_invoke, + finalize_logged_output=_finalize_interaction_tool_spans, + ) + + return _run_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.create", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_interaction_process_result, + prepare_call=_prepare_interaction_create_traced_call, + span_type=SpanTypeAttribute.LLM, + before_invoke=before_invoke, + finalize_logged_output=_finalize_interaction_tool_spans, + ) + + +async def _async_interactions_create_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + before_invoke = lambda: _close_interaction_tool_spans_from_input(kwargs.get("input")) + + if kwargs.get("stream"): + return _run_async_stream_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.create", + invoke=lambda: wrapped(*args, **kwargs), + aggregate=_aggregate_interaction_events, + first_token_predicate=_is_interaction_content_event, + prepare_call=_prepare_interaction_create_traced_call, + span_type=SpanTypeAttribute.LLM, + before_invoke=before_invoke, + finalize_logged_output=_finalize_interaction_tool_spans, + ) + + return await _run_async_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.create", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_interaction_process_result, + prepare_call=_prepare_interaction_create_traced_call, + span_type=SpanTypeAttribute.LLM, + before_invoke=before_invoke, + finalize_logged_output=_finalize_interaction_tool_spans, + ) + + +def _interactions_get_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + if kwargs.get("stream"): + return _run_stream_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.get", + invoke=lambda: wrapped(*args, **kwargs), + aggregate=_aggregate_interaction_events, + first_token_predicate=_is_interaction_content_event, + prepare_call=_prepare_interaction_get_traced_call, + span_type=SpanTypeAttribute.TASK, + finalize_logged_output=_finalize_interaction_tool_spans, + ) + + return _run_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.get", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_interaction_process_result, + prepare_call=_prepare_interaction_get_traced_call, + span_type=SpanTypeAttribute.TASK, + finalize_logged_output=_finalize_interaction_tool_spans, + ) + + +async def _async_interactions_get_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + if kwargs.get("stream"): + return _run_async_stream_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.get", + invoke=lambda: wrapped(*args, **kwargs), + aggregate=_aggregate_interaction_events, + first_token_predicate=_is_interaction_content_event, + prepare_call=_prepare_interaction_get_traced_call, + span_type=SpanTypeAttribute.TASK, + finalize_logged_output=_finalize_interaction_tool_spans, + ) + + return await _run_async_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.get", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_interaction_process_result, + prepare_call=_prepare_interaction_get_traced_call, + span_type=SpanTypeAttribute.TASK, + finalize_logged_output=_finalize_interaction_tool_spans, + ) + + +def _interactions_cancel_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return _run_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.cancel", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_interaction_process_result, + prepare_call=_prepare_interaction_id_traced_call, + span_type=SpanTypeAttribute.TASK, + ) + + +async def _async_interactions_cancel_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return await _run_async_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.cancel", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_interaction_process_result, + prepare_call=_prepare_interaction_id_traced_call, + span_type=SpanTypeAttribute.TASK, + ) + + +def _interactions_delete_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return _run_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.delete", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_generic_process_result, + prepare_call=_prepare_interaction_id_traced_call, + span_type=SpanTypeAttribute.TASK, + ) + + +async def _async_interactions_delete_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return await _run_async_traced_call( + getattr(instance, "_client", None), + args, + kwargs, + name="interactions.delete", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_generic_process_result, + prepare_call=_prepare_interaction_id_traced_call, + span_type=SpanTypeAttribute.TASK, + )