Skip to content

Commit 1e975ce

Browse files
committed
move configs to a separate file
1 parent 3e077fe commit 1e975ce

4 files changed

Lines changed: 181 additions & 169 deletions

File tree

sentry_sdk/integrations/cohere/__init__.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import sys
22
from functools import wraps
33

4-
from sentry_sdk.consts import OP, SPANDATA
4+
from sentry_sdk.consts import OP
55
from sentry_sdk.ai.span_config import set_request_span_data, set_response_span_data
6+
from sentry_sdk.integrations.cohere.configs import COHERE_EMBED_CONFIG
67

78
from typing import TYPE_CHECKING
89

910
from sentry_sdk.tracing_utils import set_span_errored
1011

1112
if TYPE_CHECKING:
1213
from typing import Any, Callable
13-
from sentry_sdk.ai.span_config import OperationConfig
1414

1515
import sentry_sdk
1616
from sentry_sdk.integrations import DidNotEnable, Integration
@@ -22,35 +22,6 @@
2222
raise DidNotEnable("Cohere not installed")
2323

2424

25-
def _normalize_embedding_input(texts):
26-
# type: (Any) -> Any
27-
if isinstance(texts, list):
28-
return texts
29-
if isinstance(texts, tuple):
30-
return list(texts)
31-
return [texts]
32-
33-
34-
COHERE_EMBED_CONFIG: "OperationConfig" = {
35-
"static": {
36-
SPANDATA.GEN_AI_SYSTEM: "cohere",
37-
SPANDATA.GEN_AI_OPERATION_NAME: "embeddings",
38-
},
39-
"params": {"model": SPANDATA.GEN_AI_REQUEST_MODEL},
40-
"extract_messages": lambda kw: (
41-
_normalize_embedding_input(kw["texts"]) if "texts" in kw else None
42-
),
43-
"message_target": SPANDATA.GEN_AI_EMBEDDINGS_INPUT,
44-
"truncation_fn": None,
45-
"response": {
46-
"usage": {
47-
"input_tokens": [("meta", "billed_units", "input_tokens")],
48-
"total_tokens": [("meta", "billed_units", "input_tokens")],
49-
},
50-
},
51-
}
52-
53-
5425
class CohereIntegration(Integration):
5526
identifier = "cohere"
5627
origin = f"auto.ai.{identifier}"
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
from sentry_sdk.ai.utils import (
2+
get_first_from_sources,
3+
transform_message_content,
4+
)
5+
from sentry_sdk.consts import SPANDATA
6+
7+
from typing import TYPE_CHECKING
8+
9+
if TYPE_CHECKING:
10+
from typing import Any
11+
from sentry_sdk.ai.span_config import OperationConfig
12+
13+
14+
# ── Helpers ──────────────────────────────────────────────────────────────────
15+
16+
17+
def _normalize_embedding_input(texts):
18+
# type: (Any) -> Any
19+
if isinstance(texts, list):
20+
return texts
21+
if isinstance(texts, tuple):
22+
return list(texts)
23+
return [texts]
24+
25+
26+
def _extract_v1_messages(kwargs):
27+
# type: (dict[str, Any]) -> list[dict[str, str]]
28+
messages = []
29+
for x in kwargs.get("chat_history", []):
30+
messages.append(
31+
{
32+
"role": getattr(x, "role", ""),
33+
"content": transform_message_content(getattr(x, "message", "")),
34+
}
35+
)
36+
message = kwargs.get("message")
37+
if message:
38+
messages.append({"role": "user", "content": transform_message_content(message)})
39+
return messages
40+
41+
42+
def _extract_v1_response_text(response):
43+
# type: (Any) -> list[str] | None
44+
text = getattr(response, "text", None)
45+
return [text] if text is not None else None
46+
47+
48+
def _extract_v2_messages(messages):
49+
# type: (Any) -> list[dict[str, Any]]
50+
result = []
51+
for msg in messages:
52+
role = msg["role"] if isinstance(msg, dict) else getattr(msg, "role", "unknown")
53+
content = (
54+
msg["content"] if isinstance(msg, dict) else getattr(msg, "content", "")
55+
)
56+
result.append({"role": role, "content": transform_message_content(content)})
57+
return result
58+
59+
60+
def _extract_v2_response_text(response):
61+
# type: (Any) -> list[str] | None
62+
content = get_first_from_sources(response, [("message", "content")], True)
63+
if content:
64+
texts = [item.text for item in content if hasattr(item, "text")]
65+
if texts:
66+
return texts
67+
return None
68+
69+
70+
# ── Configs ──────────────────────────────────────────────────────────────────
71+
72+
73+
COHERE_EMBED_CONFIG: "OperationConfig" = {
74+
"static": {
75+
SPANDATA.GEN_AI_SYSTEM: "cohere",
76+
SPANDATA.GEN_AI_OPERATION_NAME: "embeddings",
77+
},
78+
"params": {"model": SPANDATA.GEN_AI_REQUEST_MODEL},
79+
"extract_messages": lambda kw: (
80+
_normalize_embedding_input(kw["texts"]) if "texts" in kw else None
81+
),
82+
"message_target": SPANDATA.GEN_AI_EMBEDDINGS_INPUT,
83+
"response": {
84+
"usage": {
85+
"input_tokens": [("meta", "billed_units", "input_tokens")],
86+
"total_tokens": [("meta", "billed_units", "input_tokens")],
87+
},
88+
},
89+
}
90+
91+
92+
COHERE_V1_CHAT_CONFIG: "OperationConfig" = {
93+
"static": {
94+
SPANDATA.GEN_AI_SYSTEM: "cohere",
95+
SPANDATA.GEN_AI_OPERATION_NAME: "chat",
96+
},
97+
"extract_messages": lambda kw: _extract_v1_messages(kw),
98+
"response": {
99+
"sources": {
100+
SPANDATA.GEN_AI_RESPONSE_MODEL: [("model",)],
101+
SPANDATA.GEN_AI_RESPONSE_ID: [("generation_id",)],
102+
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS: [("finish_reason",)],
103+
},
104+
"pii_sources": {
105+
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS: [("tool_calls",)],
106+
},
107+
"extract_text": _extract_v1_response_text,
108+
"usage": {
109+
"input_tokens": [
110+
("meta", "billed_units", "input_tokens"),
111+
("meta", "tokens", "input_tokens"),
112+
],
113+
"output_tokens": [
114+
("meta", "billed_units", "output_tokens"),
115+
("meta", "tokens", "output_tokens"),
116+
],
117+
},
118+
},
119+
"stream_response_object": [("response",)],
120+
}
121+
122+
123+
STREAM_DELTA_TEXT_SOURCES = [("delta", "message", "content", "text")]
124+
125+
126+
COHERE_V2_CHAT_CONFIG: "OperationConfig" = {
127+
"static": {
128+
SPANDATA.GEN_AI_SYSTEM: "cohere",
129+
SPANDATA.GEN_AI_OPERATION_NAME: "chat",
130+
},
131+
"pii_params": {
132+
"tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
133+
},
134+
"extract_messages": lambda kw: _extract_v2_messages(kw.get("messages", [])),
135+
"response": {
136+
"sources": {
137+
SPANDATA.GEN_AI_RESPONSE_MODEL: [("model",)],
138+
SPANDATA.GEN_AI_RESPONSE_ID: [("id",)],
139+
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS: [("finish_reason",)],
140+
},
141+
"pii_sources": {
142+
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS: [("message", "tool_calls")],
143+
},
144+
"extract_text": _extract_v2_response_text,
145+
"usage": {
146+
"input_tokens": [
147+
("usage", "billed_units", "input_tokens"),
148+
("usage", "tokens", "input_tokens"),
149+
],
150+
"output_tokens": [
151+
("usage", "billed_units", "output_tokens"),
152+
("usage", "tokens", "output_tokens"),
153+
],
154+
},
155+
},
156+
"stream_response": {
157+
"sources": {
158+
SPANDATA.GEN_AI_RESPONSE_ID: [("id",)],
159+
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS: [("delta", "finish_reason")],
160+
},
161+
"usage": {
162+
"input_tokens": [
163+
("delta", "usage", "billed_units", "input_tokens"),
164+
("delta", "usage", "tokens", "input_tokens"),
165+
],
166+
"output_tokens": [
167+
("delta", "usage", "billed_units", "output_tokens"),
168+
("delta", "usage", "tokens", "output_tokens"),
169+
],
170+
},
171+
},
172+
}

sentry_sdk/integrations/cohere/v1.py

Lines changed: 2 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,15 @@
22
from functools import wraps
33

44
from sentry_sdk.ai.span_config import set_request_span_data, set_response_span_data
5-
from sentry_sdk.ai.utils import (
6-
get_first_from_sources,
7-
transform_message_content,
8-
)
5+
from sentry_sdk.ai.utils import get_first_from_sources
96
from sentry_sdk.consts import OP, SPANDATA
7+
from sentry_sdk.integrations.cohere.configs import COHERE_V1_CHAT_CONFIG
108

119
from typing import TYPE_CHECKING
1210

1311
if TYPE_CHECKING:
1412
from typing import Any, Callable, Iterator
1513
from cohere import StreamedChatResponse
16-
from sentry_sdk.ai.span_config import OperationConfig
1714

1815
import sentry_sdk
1916
from sentry_sdk.integrations.cohere import (
@@ -51,43 +48,6 @@ def setup_v1(wrap_embed_fn):
5148
Client.embed = wrap_embed_fn(Client.embed)
5249

5350

54-
def _extract_response_text(response):
55-
# type: (Any) -> list[str] | None
56-
text = getattr(response, "text", None)
57-
return [text] if text is not None else None
58-
59-
60-
COHERE_V1_CHAT_CONFIG: "OperationConfig" = {
61-
"static": {
62-
SPANDATA.GEN_AI_SYSTEM: "cohere",
63-
SPANDATA.GEN_AI_OPERATION_NAME: "chat",
64-
},
65-
"extract_messages": lambda kw: _extract_messages(kw),
66-
"response": {
67-
"sources": {
68-
SPANDATA.GEN_AI_RESPONSE_MODEL: [("model",)],
69-
SPANDATA.GEN_AI_RESPONSE_ID: [("generation_id",)],
70-
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS: [("finish_reason",)],
71-
},
72-
"pii_sources": {
73-
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS: [("tool_calls",)],
74-
},
75-
"extract_text": _extract_response_text,
76-
"usage": {
77-
"input_tokens": [
78-
("meta", "billed_units", "input_tokens"),
79-
("meta", "tokens", "input_tokens"),
80-
],
81-
"output_tokens": [
82-
("meta", "billed_units", "output_tokens"),
83-
("meta", "tokens", "output_tokens"),
84-
],
85-
},
86-
},
87-
"stream_response_object": [("response",)],
88-
}
89-
90-
9151
def _wrap_chat(f, streaming):
9252
# type: (Callable[..., Any], bool) -> Callable[..., Any]
9353
if not _has_chat_types:
@@ -140,22 +100,6 @@ def new_chat(*args, **kwargs):
140100
return new_chat
141101

142102

143-
def _extract_messages(kwargs):
144-
# type: (dict[str, Any]) -> list[dict[str, str]]
145-
messages = []
146-
for x in kwargs.get("chat_history", []):
147-
messages.append(
148-
{
149-
"role": getattr(x, "role", ""),
150-
"content": transform_message_content(getattr(x, "message", "")),
151-
}
152-
)
153-
message = kwargs.get("message")
154-
if message:
155-
messages.append({"role": "user", "content": transform_message_content(message)})
156-
return messages
157-
158-
159103
def _iter_stream_events(old_iterator, span, include_pii):
160104
# type: (Any, Any, bool) -> Iterator[StreamedChatResponse]
161105
with capture_internal_exceptions():

0 commit comments

Comments
 (0)