Skip to content

Commit 7048b5d

Browse files
committed
fix(integrations): Auto-wrap root gen_ai spans for openai, cohere, langgraph, huggingface_hub
1 parent 3684f01 commit 7048b5d

6 files changed

Lines changed: 37 additions & 22 deletions

File tree

sentry_sdk/integrations/cohere.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from sentry_sdk import consts
66
from sentry_sdk.ai.monitoring import record_token_usage
7-
from sentry_sdk.ai.utils import set_data_normalized
7+
from sentry_sdk.ai.utils import get_start_span_function, set_data_normalized
88
from sentry_sdk.consts import SPANDATA
99
from sentry_sdk.tracing_utils import set_span_errored
1010

@@ -142,7 +142,7 @@ def new_chat(*args: "Any", **kwargs: "Any") -> "Any":
142142

143143
message = kwargs.get("message")
144144

145-
span = sentry_sdk.start_span(
145+
span = get_start_span_function()(
146146
op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE,
147147
name="cohere.client.Chat",
148148
origin=CohereIntegration.origin,
@@ -225,7 +225,7 @@ def new_embed(*args: "Any", **kwargs: "Any") -> "Any":
225225
if integration is None:
226226
return f(*args, **kwargs)
227227

228-
with sentry_sdk.start_span(
228+
with get_start_span_function()(
229229
op=consts.OP.COHERE_EMBEDDINGS_CREATE,
230230
name="Cohere Embedding Creation",
231231
origin=CohereIntegration.origin,

sentry_sdk/integrations/huggingface_hub.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
import sentry_sdk
77
from sentry_sdk.ai.monitoring import record_token_usage
8-
from sentry_sdk.ai.utils import _set_span_data_attribute, set_data_normalized
8+
from sentry_sdk.ai.utils import (
9+
_set_span_data_attribute,
10+
get_start_span_function,
11+
set_data_normalized,
12+
)
913
from sentry_sdk.consts import OP, SPANDATA
1014
from sentry_sdk.integrations import DidNotEnable, Integration
1115
from sentry_sdk.scope import should_send_default_pii
@@ -97,7 +101,7 @@ def new_huggingface_task(*args: "Any", **kwargs: "Any") -> "Any":
97101
},
98102
)
99103
else:
100-
span = sentry_sdk.start_span(
104+
span = get_start_span_function()(
101105
op=op,
102106
name=f"{operation_name} {model}",
103107
origin=HuggingfaceHubIntegration.origin,

sentry_sdk/integrations/langgraph.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import sentry_sdk
55
from sentry_sdk.ai.utils import (
6+
get_start_span_function,
67
normalize_message_roles,
78
set_data_normalized,
89
truncate_and_annotate_messages,
@@ -159,7 +160,7 @@ def new_invoke(self: "Any", *args: "Any", **kwargs: "Any") -> "Any":
159160
f"invoke_agent {graph_name}".strip() if graph_name else "invoke_agent"
160161
)
161162

162-
with sentry_sdk.start_span(
163+
with get_start_span_function()(
163164
op=OP.GEN_AI_INVOKE_AGENT,
164165
name=span_name,
165166
origin=LanggraphIntegration.origin,
@@ -219,7 +220,7 @@ async def new_ainvoke(self: "Any", *args: "Any", **kwargs: "Any") -> "Any":
219220
f"invoke_agent {graph_name}".strip() if graph_name else "invoke_agent"
220221
)
221222

222-
with sentry_sdk.start_span(
223+
with get_start_span_function()(
223224
op=OP.GEN_AI_INVOKE_AGENT,
224225
name=span_name,
225226
origin=LanggraphIntegration.origin,

sentry_sdk/integrations/openai.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from sentry_sdk.ai.monitoring import record_token_usage
2727
from sentry_sdk.ai.utils import (
28+
get_start_span_function,
2829
normalize_message_roles,
2930
set_data_normalized,
3031
truncate_and_annotate_embedding_inputs,
@@ -713,7 +714,7 @@ def _new_sync_chat_completion(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
713714

714715
model = kwargs.get("model")
715716

716-
span = sentry_sdk.start_span(
717+
span = get_start_span_function()(
717718
op=consts.OP.GEN_AI_CHAT,
718719
name=f"chat {model}",
719720
origin=OpenAIIntegration.origin,
@@ -781,7 +782,7 @@ async def _new_async_chat_completion(f: "Any", *args: "Any", **kwargs: "Any") ->
781782

782783
model = kwargs.get("model")
783784

784-
span = sentry_sdk.start_span(
785+
span = get_start_span_function()(
785786
op=consts.OP.GEN_AI_CHAT,
786787
name=f"chat {model}",
787788
origin=OpenAIIntegration.origin,
@@ -1177,7 +1178,7 @@ def _new_sync_embeddings_create(f: "Any", *args: "Any", **kwargs: "Any") -> "Any
11771178

11781179
model = kwargs.get("model")
11791180

1180-
with sentry_sdk.start_span(
1181+
with get_start_span_function()(
11811182
op=consts.OP.GEN_AI_EMBEDDINGS,
11821183
name=f"embeddings {model}",
11831184
origin=OpenAIIntegration.origin,
@@ -1209,7 +1210,7 @@ async def _new_async_embeddings_create(
12091210

12101211
model = kwargs.get("model")
12111212

1212-
with sentry_sdk.start_span(
1213+
with get_start_span_function()(
12131214
op=consts.OP.GEN_AI_EMBEDDINGS,
12141215
name=f"embeddings {model}",
12151216
origin=OpenAIIntegration.origin,
@@ -1263,7 +1264,7 @@ def _new_sync_responses_create(f: "Any", *args: "Any", **kwargs: "Any") -> "Any"
12631264

12641265
model = kwargs.get("model")
12651266

1266-
span = sentry_sdk.start_span(
1267+
span = get_start_span_function()(
12671268
op=consts.OP.GEN_AI_RESPONSES,
12681269
name=f"responses {model}",
12691270
origin=OpenAIIntegration.origin,
@@ -1321,7 +1322,7 @@ async def _new_async_responses_create(f: "Any", *args: "Any", **kwargs: "Any") -
13211322

13221323
model = kwargs.get("model")
13231324

1324-
span = sentry_sdk.start_span(
1325+
span = get_start_span_function()(
13251326
op=consts.OP.GEN_AI_RESPONSES,
13261327
name=f"responses {model}",
13271328
origin=OpenAIIntegration.origin,

tests/integrations/cohere/test_cohere.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,9 @@ def test_bad_chat(sentry_init, capture_events):
162162
with pytest.raises(httpx.HTTPError):
163163
client.chat(model="some-model", message="hello")
164164

165-
(event,) = events
165+
(event, transaction) = events
166166
assert event["level"] == "error"
167+
assert transaction["contexts"]["trace"]["status"] == "internal_error"
167168

168169

169170
def test_span_status_error(sentry_init, capture_events):

tests/integrations/openai/test_openai.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2164,7 +2164,7 @@ def test_bad_chat_completion(
21642164
)
21652165

21662166
if stream_gen_ai_spans:
2167-
items = capture_items("event")
2167+
items = capture_items("event", "transaction")
21682168

21692169
client = OpenAI(api_key="z")
21702170
client.chat.completions._post = mock.Mock(
@@ -2177,6 +2177,7 @@ def test_bad_chat_completion(
21772177
)
21782178

21792179
(event,) = (item.payload for item in items if item.type == "event")
2180+
(transaction,) = (item.payload for item in items if item.type == "transaction")
21802181
else:
21812182
events = capture_events()
21822183

@@ -2190,9 +2191,10 @@ def test_bad_chat_completion(
21902191
messages=[{"role": "system", "content": "hello"}],
21912192
)
21922193

2193-
(event,) = events
2194+
(event, transaction) = events
21942195

21952196
assert event["level"] == "error"
2197+
assert transaction["contexts"]["trace"]["status"] == "internal_error"
21962198

21972199

21982200
@pytest.mark.parametrize("stream_gen_ai_spans", [True, False])
@@ -2266,14 +2268,15 @@ async def test_bad_chat_completion_async(
22662268
side_effect=OpenAIError("API rate limit reached")
22672269
)
22682270
if stream_gen_ai_spans:
2269-
items = capture_items("event")
2271+
items = capture_items("event", "transaction")
22702272

22712273
with pytest.raises(OpenAIError):
22722274
await client.chat.completions.create(
22732275
model="some-model", messages=[{"role": "system", "content": "hello"}]
22742276
)
22752277

22762278
(event,) = (item.payload for item in items if item.type == "event")
2279+
(transaction,) = (item.payload for item in items if item.type == "transaction")
22772280
else:
22782281
events = capture_events()
22792282

@@ -2282,9 +2285,10 @@ async def test_bad_chat_completion_async(
22822285
model="some-model", messages=[{"role": "system", "content": "hello"}]
22832286
)
22842287

2285-
(event,) = events
2288+
(event, transaction) = events
22862289

22872290
assert event["level"] == "error"
2291+
assert transaction["contexts"]["trace"]["status"] == "internal_error"
22882292

22892293

22902294
@pytest.mark.parametrize("stream_gen_ai_spans", [True, False])
@@ -2834,21 +2838,23 @@ def test_embeddings_create_raises_error(
28342838
)
28352839

28362840
if stream_gen_ai_spans:
2837-
items = capture_items("event")
2841+
items = capture_items("event", "transaction")
28382842

28392843
with pytest.raises(OpenAIError):
28402844
client.embeddings.create(input="hello", model="text-embedding-3-large")
28412845

28422846
(event,) = (item.payload for item in items if item.type == "event")
2847+
(transaction,) = (item.payload for item in items if item.type == "transaction")
28432848
else:
28442849
events = capture_events()
28452850

28462851
with pytest.raises(OpenAIError):
28472852
client.embeddings.create(input="hello", model="text-embedding-3-large")
28482853

2849-
(event,) = events
2854+
(event, transaction) = events
28502855

28512856
assert event["level"] == "error"
2857+
assert transaction["contexts"]["trace"]["status"] == "internal_error"
28522858

28532859

28542860
@pytest.mark.parametrize("stream_gen_ai_spans", [True, False])
@@ -2879,14 +2885,15 @@ async def test_embeddings_create_raises_error_async(
28792885
)
28802886

28812887
if stream_gen_ai_spans:
2882-
items = capture_items("event")
2888+
items = capture_items("event", "transaction")
28832889

28842890
with pytest.raises(OpenAIError):
28852891
await client.embeddings.create(
28862892
input="hello", model="text-embedding-3-large"
28872893
)
28882894

28892895
(event,) = (item.payload for item in items if item.type == "event")
2896+
(transaction,) = (item.payload for item in items if item.type == "transaction")
28902897
else:
28912898
events = capture_events()
28922899

@@ -2895,9 +2902,10 @@ async def test_embeddings_create_raises_error_async(
28952902
input="hello", model="text-embedding-3-large"
28962903
)
28972904

2898-
(event,) = events
2905+
(event, transaction) = events
28992906

29002907
assert event["level"] == "error"
2908+
assert transaction["contexts"]["trace"]["status"] == "internal_error"
29012909

29022910

29032911
@pytest.mark.parametrize("stream_gen_ai_spans", [True, False])

0 commit comments

Comments
 (0)