Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions sentry_sdk/integrations/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from sentry_sdk.consts import SPANDATA
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing_utils import (
has_span_streaming_enabled,
should_truncate_gen_ai_input,
)
from sentry_sdk.utils import event_from_exception

if TYPE_CHECKING:
Expand Down Expand Up @@ -68,7 +72,8 @@ def _convert_message_parts(messages: "List[Dict[str, Any]]") -> "List[Dict[str,

def _input_callback(kwargs: "Dict[str, Any]") -> None:
"""Handle the start of a request."""
integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration)
client = sentry_sdk.get_client()
integration = client.get_integration(LiteLLMIntegration)

if integration is None:
return
Expand All @@ -88,16 +93,29 @@ def _input_callback(kwargs: "Dict[str, Any]") -> None:
operation = "chat"

# Start a new span/transaction
span = get_start_span_function()(
op=(
consts.OP.GEN_AI_CHAT
if operation == "chat"
else consts.OP.GEN_AI_EMBEDDINGS
),
name=f"{operation} {model}",
origin=LiteLLMIntegration.origin,
)
span.__enter__()
if has_span_streaming_enabled(client.options):
span = sentry_sdk.traces.start_span(
name=f"{operation} {model}",
attributes={
"sentry.op": (
consts.OP.GEN_AI_CHAT
if operation == "chat"
else consts.OP.GEN_AI_EMBEDDINGS
),
"sentry.origin": LiteLLMIntegration.origin,
},
)
else:
span = get_start_span_function()(
op=(
consts.OP.GEN_AI_CHAT
if operation == "chat"
else consts.OP.GEN_AI_EMBEDDINGS
),
name=f"{operation} {model}",
origin=LiteLLMIntegration.origin,
)
span.__enter__()

# Store span for later
_get_metadata_dict(kwargs)["_sentry_span"] = span
Expand All @@ -121,9 +139,9 @@ def _input_callback(kwargs: "Dict[str, Any]") -> None:
)
client = sentry_sdk.get_client()
messages_data = (
input_list
if client.options.get("stream_gen_ai_spans", False)
else truncate_and_annotate_embedding_inputs(input_list, span, scope)
truncate_and_annotate_embedding_inputs(input_list, span, scope)
if should_truncate_gen_ai_input(client.options)
else input_list
)
if messages_data is not None:
set_data_normalized(
Expand All @@ -140,9 +158,9 @@ def _input_callback(kwargs: "Dict[str, Any]") -> None:
scope = sentry_sdk.get_current_scope()
messages = _convert_message_parts(messages)
messages_data = (
messages
if client.options.get("stream_gen_ai_spans", False)
else truncate_and_annotate_messages(messages, span, scope)
truncate_and_annotate_messages(messages, span, scope)
if should_truncate_gen_ai_input(client.options)
else messages
)
if messages_data is not None:
set_data_normalized(
Expand Down
9 changes: 9 additions & 0 deletions sentry_sdk/tracing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ def has_span_streaming_enabled(options: "Optional[dict[str, Any]]") -> bool:
return (options.get("_experiments") or {}).get("trace_lifecycle") == "stream"


def should_truncate_gen_ai_input(options: "Optional[dict[str, Any]]") -> bool:
    """Tell whether gen-AI input data should be truncated for these options.

    :param options: Client options mapping, or ``None`` when unavailable.
    :returns: ``True`` when ``options`` is ``None``. Otherwise ``True`` only
        if the ``stream_gen_ai_spans`` option is falsy *and*
        ``has_span_streaming_enabled(options)`` reports false; ``False`` in
        every other case.
    """
    if options is None:
        return True

    # Checked first so the span-streaming helper is only consulted when the
    # explicit ``stream_gen_ai_spans`` flag has not already decided the answer.
    if options.get("stream_gen_ai_spans", False):
        return False

    return not has_span_streaming_enabled(options)


@contextlib.contextmanager
def record_sql_queries(
cursor: "Any",
Expand Down
Loading
Loading