Skip to content

Commit c7a1b58

Browse files
committed
simplify
1 parent 068ce5d commit c7a1b58

6 files changed

Lines changed: 61 additions & 67 deletions

File tree

sentry_sdk/ai/span_config.py

Lines changed: 14 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -14,44 +14,39 @@
1414
from sentry_sdk.tracing import Span
1515

1616

17-
def set_input_span_data(span, kwargs, integration, config):
18-
# type: (Span, Dict[str, Any], Any, Dict[str, Any]) -> None
17+
def set_input_span_data(span, kwargs, integration, config, span_data=None):
18+
# type: (Span, Dict[str, Any], Any, Dict[str, Any], Dict[str, Any] | None) -> None
1919
"""
2020
Set input span data from a declarative config.
2121
2222
Config keys:
23-
system: str - gen_ai.system value
24-
operation: str - gen_ai.operation.name value
23+
static: dict - key/value pairs to set unconditionally
2524
params: dict - kwargs key -> span attr (always set if present)
2625
pii_params: dict - kwargs key -> span attr (only when PII allowed)
2726
extract_messages: callable(kwargs) -> list or None
2827
message_target: str - span attr for messages (default: GEN_AI_REQUEST_MESSAGES)
29-
truncation_fn: callable or None - truncation function (default: truncate_and_annotate_messages, None to skip)
30-
is_given: callable(value) -> bool - for NotGiven sentinels
31-
extra_static: dict - additional key/value pairs to set
28+
29+
span_data: additional key/value pairs for dynamic per-call values
3230
"""
33-
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, config["system"])
34-
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, config["operation"])
31+
for key, value in config.get("static", {}).items():
32+
set_data_normalized(span, key, value)
33+
if span_data:
34+
for key, value in span_data.items():
35+
set_data_normalized(span, key, value)
3536

36-
is_given = config.get("is_given")
3737
for kwarg_key, span_attr in config.get("params", {}).items():
3838
if kwarg_key in kwargs:
3939
value = kwargs[kwarg_key]
40-
if is_given is None or is_given(value):
41-
set_data_normalized(span, span_attr, value)
40+
set_data_normalized(span, span_attr, value)
4241

4342
if should_send_default_pii() and integration.include_prompts:
4443
extract = config.get("extract_messages")
4544
if extract is not None:
4645
messages = extract(kwargs)
4746
if messages:
4847
messages = normalize_message_roles(messages)
49-
truncation_fn = config.get(
50-
"truncation_fn", truncate_and_annotate_messages
51-
)
52-
if truncation_fn is not None:
53-
scope = sentry_sdk.get_current_scope()
54-
messages = truncation_fn(messages, span, scope)
48+
scope = sentry_sdk.get_current_scope()
49+
messages = truncate_and_annotate_messages(messages, span, scope)
5550
if messages is not None:
5651
target = config.get(
5752
"message_target", SPANDATA.GEN_AI_REQUEST_MESSAGES
@@ -61,8 +56,4 @@ def set_input_span_data(span, kwargs, integration, config):
6156
for kwarg_key, span_attr in config.get("pii_params", {}).items():
6257
if kwarg_key in kwargs:
6358
value = kwargs[kwarg_key]
64-
if is_given is None or is_given(value):
65-
set_data_normalized(span, span_attr, value)
66-
67-
for key, value in config.get("extra_static", {}).items():
68-
set_data_normalized(span, key, value)
59+
set_data_normalized(span, span_attr, value)

sentry_sdk/ai/utils.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -503,7 +503,8 @@ def normalize_message_role(role: str) -> str:
503503
Normalize a message role to one of the 4 allowed gen_ai role values.
504504
Maps "ai" -> "assistant" and keeps other standard roles unchanged.
505505
"""
506-
return GEN_AI_MESSAGE_ROLE_MAPPING.get(role, role)
506+
role_lower = role.lower()
507+
return GEN_AI_MESSAGE_ROLE_MAPPING.get(role_lower, role_lower)
507508

508509

509510
def normalize_message_roles(messages: "list[dict[str, Any]]") -> "list[dict[str, Any]]":

sentry_sdk/integrations/cohere/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,8 +32,10 @@ def _normalize_embedding_input(texts):
3232

3333

3434
COHERE_EMBED_CONFIG = {
35-
"system": "cohere",
36-
"operation": "embeddings",
35+
"static": {
36+
SPANDATA.GEN_AI_SYSTEM: "cohere",
37+
SPANDATA.GEN_AI_OPERATION_NAME: "embeddings",
38+
},
3739
"params": {"model": SPANDATA.GEN_AI_REQUEST_MODEL},
3840
"extract_messages": lambda kw: (
3941
_normalize_embedding_input(kw["texts"]) if "texts" in kw else None

sentry_sdk/integrations/cohere/utils.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from typing import TYPE_CHECKING
44

55
if TYPE_CHECKING:
6-
from typing import Any
6+
from typing import Any, Mapping, Sequence
77

88

99
def transitive_getattr(obj, *attrs):
@@ -17,7 +17,7 @@ def transitive_getattr(obj, *attrs):
1717

1818

1919
def get_first_from_sources(obj, source_paths, require_truthy=False):
20-
# type: (Any, list[tuple[str, ...]], bool) -> Any
20+
# type: (Any, Sequence[tuple[str, ...]], bool) -> Any
2121
for source_path in source_paths:
2222
value = transitive_getattr(obj, *source_path)
2323
if not value:
@@ -28,7 +28,7 @@ def get_first_from_sources(obj, source_paths, require_truthy=False):
2828

2929

3030
def set_span_data_from_sources(span, obj, target_sources, require_truthy):
31-
# type: (Any, Any, dict[str, list[tuple[str, ...]]], bool) -> None
31+
# type: (Any, Any, Mapping[str, Sequence[tuple[str, ...]]], bool) -> None
3232
for spandata_key, source_paths in target_sources.items():
3333
value = get_first_from_sources(obj, source_paths, require_truthy=require_truthy)
3434
if value is not None:

sentry_sdk/integrations/cohere/v1.py

Lines changed: 18 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,14 @@
3838
except ImportError:
3939
_has_chat_types = False
4040

41+
COHERE_V1_CHAT_CONFIG = {
42+
"static": {
43+
SPANDATA.GEN_AI_SYSTEM: "cohere",
44+
SPANDATA.GEN_AI_OPERATION_NAME: "chat",
45+
},
46+
"extract_messages": lambda kw: _extract_messages(kw),
47+
}
48+
4149
CHAT_RESPONSE_SOURCES = {
4250
SPANDATA.GEN_AI_RESPONSE_ID: [("generation_id",)],
4351
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS: [("finish_reason",)],
@@ -105,38 +113,31 @@ def new_chat(*args, **kwargs):
105113
reraise(*exc_info)
106114

107115
with capture_internal_exceptions():
116+
span_data = {SPANDATA.GEN_AI_RESPONSE_STREAMING: streaming}
108117
if model:
109-
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model)
118+
span_data[SPANDATA.GEN_AI_REQUEST_MODEL] = model
110119
set_input_span_data(
111-
span,
112-
kwargs,
113-
integration,
114-
{
115-
"system": "cohere",
116-
"operation": "chat",
117-
"extract_messages": _extract_messages_v1,
118-
"extra_static": {SPANDATA.GEN_AI_RESPONSE_STREAMING: streaming},
119-
},
120+
span, kwargs, integration, COHERE_V1_CHAT_CONFIG, span_data
120121
)
121122

122123
if streaming:
123-
return _iter_v1_stream_events(res, span, include_pii)
124+
return _iter_stream_events(res, span, include_pii)
124125
if isinstance(res, NonStreamedChatResponse):
125-
_collect_v1_response_fields(span, res, include_pii=include_pii)
126+
_collect_response_fields(span, res, include_pii=include_pii)
126127
else:
127128
set_data_normalized(span, "unknown_response", True)
128129
return res
129130

130131
return new_chat
131132

132133

133-
def _extract_messages_v1(kwargs):
134+
def _extract_messages(kwargs):
134135
# type: (dict[str, Any]) -> list[dict[str, str]]
135136
messages = []
136137
for x in kwargs.get("chat_history", []):
137138
messages.append(
138139
{
139-
"role": getattr(x, "role", "").lower(),
140+
"role": getattr(x, "role", ""),
140141
"content": transform_message_content(getattr(x, "message", "")),
141142
}
142143
)
@@ -146,7 +147,7 @@ def _extract_messages_v1(kwargs):
146147
return messages
147148

148149

149-
def _iter_v1_stream_events(old_iterator, span, include_pii):
150+
def _iter_stream_events(old_iterator, span, include_pii):
150151
# type: (Any, Any, bool) -> Iterator[StreamedChatResponse]
151152
with capture_internal_exceptions():
152153
for x in old_iterator:
@@ -161,10 +162,10 @@ def _collect_v1_stream_end_fields(span, event, include_pii):
161162
# type: (Any, Any, bool) -> None
162163
response = get_first_from_sources(event, STREAM_RESPONSE_SOURCES)
163164
if response is not None:
164-
_collect_v1_response_fields(span, response, include_pii)
165+
_collect_response_fields(span, response, include_pii)
165166

166167

167-
def _collect_v1_response_fields(span, response, include_pii):
168+
def _collect_response_fields(span, response, include_pii):
168169
# type: (Any, Any, bool) -> None
169170
if include_pii:
170171
text = get_first_from_sources(response, CHAT_RESPONSE_TEXT_SOURCES)

sentry_sdk/integrations/cohere/v2.py

Lines changed: 20 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,17 @@
4646
except ImportError:
4747
_has_v2 = False
4848

49+
COHERE_V2_CHAT_CONFIG = {
50+
"static": {
51+
SPANDATA.GEN_AI_SYSTEM: "cohere",
52+
SPANDATA.GEN_AI_OPERATION_NAME: "chat",
53+
},
54+
"pii_params": {
55+
"tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
56+
},
57+
"extract_messages": lambda kw: _extract_messages_v2(kw.get("messages", [])),
58+
}
59+
4960
CHAT_RESPONSE_SOURCES = {
5061
SPANDATA.GEN_AI_RESPONSE_ID: [("id",)],
5162
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS: [("finish_reason",)],
@@ -55,7 +66,7 @@
5566
}
5667
CHAT_USAGE_SOURCES = [("usage",)]
5768
STREAM_DELTA_TEXT_SOURCES = [("delta", "message", "content", "text")]
58-
STREAM_CHAT_RESPONSE_SOURCES = {
69+
STREAM_CHAT_RESPONSE_SOURCES: "dict[str, list[tuple[str, ...]]]" = {
5970
SPANDATA.GEN_AI_RESPONSE_ID: [("id",)],
6071
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS: [("delta", "finish_reason")],
6172
}
@@ -104,29 +115,17 @@ def new_chat(*args, **kwargs):
104115
reraise(*exc_info)
105116

106117
with capture_internal_exceptions():
107-
extra = {SPANDATA.GEN_AI_RESPONSE_STREAMING: streaming}
108-
if model:
109-
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model)
110-
extra[SPANDATA.GEN_AI_RESPONSE_MODEL] = model
118+
span_data = {
119+
SPANDATA.GEN_AI_RESPONSE_STREAMING: streaming,
120+
SPANDATA.GEN_AI_REQUEST_MODEL: model if model else None,
121+
SPANDATA.GEN_AI_RESPONSE_MODEL: model if model else None,
122+
}
111123
set_input_span_data(
112-
span,
113-
kwargs,
114-
integration,
115-
{
116-
"system": "cohere",
117-
"operation": "chat",
118-
"pii_params": {
119-
"tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
120-
},
121-
"extract_messages": lambda kw: _extract_messages_v2(
122-
kw.get("messages", [])
123-
),
124-
"extra_static": extra,
125-
},
124+
span, kwargs, integration, COHERE_V2_CHAT_CONFIG, span_data
126125
)
127126
if streaming:
128127
return _iter_v2_stream_events(res, span, include_pii)
129-
_collect_v2_response_fields(span, res, include_pii=include_pii)
128+
_collect_v2_response_fields(span, res, include_pii)
130129
return res
131130

132131
return new_chat
@@ -146,7 +145,7 @@ def _extract_messages_v2(messages):
146145

147146
def _iter_v2_stream_events(old_iterator, span, include_pii):
148147
# type: (Any, Span, bool) -> Iterator[V2ChatStreamResponse]
149-
collected_text = []
148+
collected_text = [] # type: list[str]
150149
with capture_internal_exceptions():
151150
for x in old_iterator:
152151
_append_stream_delta_text(collected_text, x)

0 commit comments

Comments
 (0)