Skip to content

Commit 3cbaad6

Browse files
committed
move text massaging to core file
1 parent c0a0e10 commit 3cbaad6

5 files changed

Lines changed: 145 additions & 113 deletions

File tree

sentry_sdk/ai/span_config.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -33,27 +33,31 @@ def set_request_span_data(span, kwargs, integration, config, span_data=None):
3333
set_data_normalized(span, span_attr, value)
3434

3535
if should_send_default_pii() and integration.include_prompts:
36-
extract = config.get("extract_messages")
37-
if extract is not None:
38-
messages = extract(kwargs)
39-
if messages:
40-
messages = normalize_message_roles(messages)
41-
scope = sentry_sdk.get_current_scope()
42-
messages = truncate_and_annotate_messages(messages, span, scope)
43-
if messages is not None:
44-
target = config.get(
45-
"message_target", SPANDATA.GEN_AI_REQUEST_MESSAGES
46-
)
47-
set_data_normalized(span, target, messages, unpack=False)
48-
4936
for kwarg_key, span_attr in config.get("pii_params", {}).items():
5037
if kwarg_key in kwargs:
5138
value = kwargs[kwarg_key]
5239
set_data_normalized(span, span_attr, value)
5340

5441

42+
def set_request_messages(span, messages, target=None):
43+
# type: (Span, Any, Optional[str]) -> None
44+
"""Normalize, truncate, and set request messages on the span.
45+
46+
Caller is responsible for PII gating.
47+
"""
48+
if not messages:
49+
return
50+
messages = normalize_message_roles(messages)
51+
scope = sentry_sdk.get_current_scope()
52+
messages = truncate_and_annotate_messages(messages, span, scope)
53+
if messages is not None:
54+
set_data_normalized(
55+
span, target or SPANDATA.GEN_AI_REQUEST_MESSAGES, messages, unpack=False
56+
)
57+
58+
5559
def set_response_span_data(
56-
span, response, include_pii, response_config, collected_text=None
60+
span, response, include_pii, response_config, response_text=None
5761
):
5862
# type: (Span, Any, bool, Dict[str, Any], Optional[List[str]]) -> None
5963
"""Set response span data from a declarative config."""
@@ -65,16 +69,8 @@ def set_response_span_data(
6569
pii_sources = response_config.get("pii_sources")
6670
if pii_sources:
6771
set_span_data_from_sources(span, response, pii_sources, require_truthy=True)
68-
if collected_text:
69-
set_data_normalized(
70-
span, SPANDATA.GEN_AI_RESPONSE_TEXT, ["".join(collected_text)]
71-
)
72-
else:
73-
extract_text = response_config.get("extract_text")
74-
if extract_text:
75-
texts = extract_text(response)
76-
if texts:
77-
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, texts)
72+
if response_text:
73+
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_text)
7874

7975
usage_config = response_config.get("usage")
8076
if usage_config:

sentry_sdk/integrations/cohere/__init__.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import sys
22
from functools import wraps
33

4-
from sentry_sdk.consts import OP
5-
from sentry_sdk.ai.span_config import set_request_span_data, set_response_span_data
4+
from sentry_sdk.consts import OP, SPANDATA
5+
from sentry_sdk.ai.span_config import (
6+
set_request_span_data,
7+
set_request_messages,
8+
set_response_span_data,
9+
)
610
from sentry_sdk.integrations.cohere.configs import COHERE_EMBED_CONFIG
711

812
from typing import TYPE_CHECKING
913

14+
from sentry_sdk.scope import should_send_default_pii
1015
from sentry_sdk.tracing_utils import set_span_errored
1116

1217
if TYPE_CHECKING:
@@ -64,24 +69,41 @@ def new_embed(*args, **kwargs):
6469

6570
model = kwargs.get("model", "")
6671

72+
include_pii = should_send_default_pii() and integration.include_prompts
73+
6774
with sentry_sdk.start_span(
6875
op=OP.GEN_AI_EMBEDDINGS,
6976
name=f"embeddings {model}".strip(),
7077
origin=CohereIntegration.origin,
7178
) as span:
7279
set_request_span_data(span, kwargs, integration, COHERE_EMBED_CONFIG)
80+
if include_pii and "texts" in kwargs:
81+
set_request_messages(
82+
span,
83+
_normalize_embedding_input(kwargs["texts"]),
84+
target=SPANDATA.GEN_AI_EMBEDDINGS_INPUT,
85+
)
7386

7487
try:
75-
res = f(*args, **kwargs)
88+
response = f(*args, **kwargs)
7689
except Exception as e:
7790
exc_info = sys.exc_info()
7891
with capture_internal_exceptions():
7992
_capture_exception(e)
8093
reraise(*exc_info)
8194

8295
set_response_span_data(
83-
span, res, False, COHERE_EMBED_CONFIG["response"]
96+
span, response, False, COHERE_EMBED_CONFIG["response"]
8497
)
85-
return res
98+
return response
8699

87100
return new_embed
101+
102+
103+
def _normalize_embedding_input(texts):
104+
# type: (Any) -> Any
105+
if isinstance(texts, list):
106+
return texts
107+
if isinstance(texts, tuple):
108+
return list(texts)
109+
return [texts]

sentry_sdk/integrations/cohere/configs.py

Lines changed: 1 addition & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
from sentry_sdk.ai.utils import (
2-
get_first_from_sources,
3-
transform_message_content,
4-
)
51
from sentry_sdk.consts import SPANDATA
62

73
from typing import TYPE_CHECKING
84

95
if TYPE_CHECKING:
10-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
6+
from typing import Any, Dict, Sequence, Tuple
117
from typing_extensions import TypedDict
128

139
# Source paths: list of attribute chains to try in order.
@@ -32,9 +28,6 @@ class ResponseConfig(TypedDict, total=False):
3228
sources: SourceMapping
3329
# Attributes extracted only when PII sending is enabled.
3430
pii_sources: SourceMapping
35-
# Custom extractor for response text (PII only).
36-
# Returns list of text strings, or None.
37-
extract_text: Callable[[Any], Optional[List[str]]]
3831
# Declarative token usage paths.
3932
usage: UsageConfig
4033

@@ -47,10 +40,6 @@ class OperationConfig(TypedDict, total=False):
4740
params: Dict[str, str]
4841
# Maps kwarg names to SPANDATA keys (only set when PII is enabled).
4942
pii_params: Dict[str, str]
50-
# Extracts messages from kwargs for the span.
51-
extract_messages: Callable[[Dict[str, Any]], Optional[List[Dict[str, Any]]]]
52-
# SPANDATA key for messages (default: GEN_AI_REQUEST_MESSAGES).
53-
message_target: str
5443
# Non-streaming response config.
5544
response: ResponseConfig
5645
# Streaming response config (different attribute paths).
@@ -60,62 +49,6 @@ class OperationConfig(TypedDict, total=False):
6049
stream_response_object: SourcePaths
6150

6251

63-
# ── Helpers ──────────────────────────────────────────────────────────────────
64-
65-
66-
def _normalize_embedding_input(texts):
67-
# type: (Any) -> Any
68-
if isinstance(texts, list):
69-
return texts
70-
if isinstance(texts, tuple):
71-
return list(texts)
72-
return [texts]
73-
74-
75-
def _extract_v1_messages(kwargs):
76-
# type: (dict[str, Any]) -> list[dict[str, str]]
77-
messages = []
78-
for x in kwargs.get("chat_history", []):
79-
messages.append(
80-
{
81-
"role": getattr(x, "role", ""),
82-
"content": transform_message_content(getattr(x, "message", "")),
83-
}
84-
)
85-
message = kwargs.get("message")
86-
if message:
87-
messages.append({"role": "user", "content": transform_message_content(message)})
88-
return messages
89-
90-
91-
def _extract_v1_response_text(response):
92-
# type: (Any) -> list[str] | None
93-
text = getattr(response, "text", None)
94-
return [text] if text is not None else None
95-
96-
97-
def _extract_v2_messages(messages):
98-
# type: (Any) -> list[dict[str, Any]]
99-
result = []
100-
for msg in messages:
101-
role = msg["role"] if isinstance(msg, dict) else getattr(msg, "role", "unknown")
102-
content = (
103-
msg["content"] if isinstance(msg, dict) else getattr(msg, "content", "")
104-
)
105-
result.append({"role": role, "content": transform_message_content(content)})
106-
return result
107-
108-
109-
def _extract_v2_response_text(response):
110-
# type: (Any) -> list[str] | None
111-
content = get_first_from_sources(response, [("message", "content")], True)
112-
if content:
113-
texts = [item.text for item in content if hasattr(item, "text")]
114-
if texts:
115-
return texts
116-
return None
117-
118-
11952
# ── Configs ──────────────────────────────────────────────────────────────────
12053

12154

@@ -125,10 +58,6 @@ def _extract_v2_response_text(response):
12558
SPANDATA.GEN_AI_OPERATION_NAME: "embeddings",
12659
},
12760
"params": {"model": SPANDATA.GEN_AI_REQUEST_MODEL},
128-
"extract_messages": lambda kw: (
129-
_normalize_embedding_input(kw["texts"]) if "texts" in kw else None
130-
),
131-
"message_target": SPANDATA.GEN_AI_EMBEDDINGS_INPUT,
13261
"response": {
13362
"usage": {
13463
"input_tokens": [("meta", "billed_units", "input_tokens")],
@@ -143,7 +72,6 @@ def _extract_v2_response_text(response):
14372
SPANDATA.GEN_AI_SYSTEM: "cohere",
14473
SPANDATA.GEN_AI_OPERATION_NAME: "chat",
14574
},
146-
"extract_messages": lambda kw: _extract_v1_messages(kw),
14775
"response": {
14876
"sources": {
14977
SPANDATA.GEN_AI_RESPONSE_MODEL: [("model",)],
@@ -153,7 +81,6 @@ def _extract_v2_response_text(response):
15381
"pii_sources": {
15482
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS: [("tool_calls",)],
15583
},
156-
"extract_text": _extract_v1_response_text,
15784
"usage": {
15885
"input_tokens": [
15986
("meta", "billed_units", "input_tokens"),
@@ -180,7 +107,6 @@ def _extract_v2_response_text(response):
180107
"pii_params": {
181108
"tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
182109
},
183-
"extract_messages": lambda kw: _extract_v2_messages(kw.get("messages", [])),
184110
"response": {
185111
"sources": {
186112
SPANDATA.GEN_AI_RESPONSE_MODEL: [("model",)],
@@ -190,7 +116,6 @@ def _extract_v2_response_text(response):
190116
"pii_sources": {
191117
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS: [("message", "tool_calls")],
192118
},
193-
"extract_text": _extract_v2_response_text,
194119
"usage": {
195120
"input_tokens": [
196121
("usage", "billed_units", "input_tokens"),

sentry_sdk/integrations/cohere/v1.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import sys
22
from functools import wraps
33

4-
from sentry_sdk.ai.span_config import set_request_span_data, set_response_span_data
5-
from sentry_sdk.ai.utils import get_first_from_sources
4+
from sentry_sdk.ai.span_config import (
5+
set_request_span_data,
6+
set_request_messages,
7+
set_response_span_data,
8+
)
9+
from sentry_sdk.ai.utils import get_first_from_sources, transform_message_content
610
from sentry_sdk.consts import OP, SPANDATA
711
from sentry_sdk.integrations.cohere.configs import COHERE_V1_CHAT_CONFIG
812

@@ -88,11 +92,20 @@ def new_chat(*args, **kwargs):
8892
set_request_span_data(
8993
span, kwargs, integration, COHERE_V1_CHAT_CONFIG, span_data
9094
)
95+
if include_pii:
96+
set_request_messages(span, _extract_v1_messages(kwargs))
9197

9298
if streaming:
9399
return _iter_stream_events(response, span, include_pii)
100+
response_text = (
101+
_extract_v1_response_text(response) if include_pii else None
102+
)
94103
set_response_span_data(
95-
span, response, include_pii, COHERE_V1_CHAT_CONFIG["response"]
104+
span,
105+
response,
106+
include_pii,
107+
COHERE_V1_CHAT_CONFIG["response"],
108+
response_text,
96109
)
97110
return response
98111

@@ -110,7 +123,36 @@ def _iter_stream_events(old_iterator, span, include_pii):
110123
x, COHERE_V1_CHAT_CONFIG["stream_response_object"]
111124
)
112125
if response is not None:
126+
response_text = (
127+
_extract_v1_response_text(response) if include_pii else None
128+
)
113129
set_response_span_data(
114-
span, response, include_pii, COHERE_V1_CHAT_CONFIG["response"]
130+
span,
131+
response,
132+
include_pii,
133+
COHERE_V1_CHAT_CONFIG["response"],
134+
response_text,
115135
)
116136
yield x
137+
138+
139+
def _extract_v1_messages(kwargs):
140+
# type: (Any) -> list[dict[str, str]]
141+
messages = []
142+
for x in kwargs.get("chat_history", []):
143+
messages.append(
144+
{
145+
"role": getattr(x, "role", ""),
146+
"content": transform_message_content(getattr(x, "message", "")),
147+
}
148+
)
149+
message = kwargs.get("message")
150+
if message:
151+
messages.append({"role": "user", "content": transform_message_content(message)})
152+
return messages
153+
154+
155+
def _extract_v1_response_text(response):
156+
# type: (Any) -> list[str] | None
157+
text = getattr(response, "text", None)
158+
return [text] if text is not None else None

0 commit comments

Comments
 (0)