Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 98 additions & 26 deletions src/core/services/streaming/stream_utils.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,98 @@
from __future__ import annotations

"""Utility helpers for streaming response processors."""

from uuid import uuid4

from src.core.ports.streaming import StreamingContent


def get_stream_id(content: StreamingContent) -> str:
    """Resolve the identifier that ties a chunk to its stream.

    Stream processors key their per-stream buffers on this value. The
    chunk metadata is searched for the first truthy entry among
    ``stream_id``, ``session_id`` and ``id`` (in that order). When none of
    them is set, a fresh UUID hex string is generated and written back to
    ``metadata["stream_id"]`` so later processors see the same identifier.

    Args:
        content: The streaming chunk whose ``metadata`` mapping is inspected
            (and mutated when an identifier has to be generated).

    Returns:
        The stream identifier, always as a ``str``.
    """
    meta = content.metadata

    # An identifier already present in the metadata wins; first truthy hit.
    for key in ("stream_id", "session_id", "id"):
        existing = meta.get(key)
        if existing:
            return str(existing)

    # Nothing assigned upstream: mint an identifier and persist it.
    generated = uuid4().hex
    meta["stream_id"] = generated
    return str(generated)
from __future__ import annotations

"""Utility helpers for streaming response processors."""

from typing import Any
from uuid import uuid4
import threading

from src.core.ports.streaming import StreamingContent

# Metadata keys that can carry an explicit identifier for a stream, listed in
# the priority order in which get_stream_id consults them.
_UNIQUE_METADATA_KEYS = (
"stream_id",
"request_id",
"response_id",
"id",
"chunk_id",
"event_id",
)

# Fallback lookup key built by _build_fallback_key:
# (request component, id component, session component); each may be None.
_StreamKey = tuple[str | None, str | None, str | None]

# Guards every read and write of the two registries below.
_fallback_lock = threading.Lock()
# Fallback key -> generated stream id, for streams lacking an explicit id.
_active_stream_ids: dict[_StreamKey, str] = {}
# Generated stream id -> fallback key; lets a finished stream release its
# entry in _active_stream_ids.
_reverse_stream_keys: dict[str, _StreamKey] = {}


def _normalize_component(value: Any) -> str | None:
    """Coerce a metadata value into a non-empty string, or ``None``.

    Args:
        value: Arbitrary metadata value; may be ``None`` or a non-string.

    Returns:
        ``str(value)`` when that yields a non-empty string; ``None`` when the
        value is ``None``, stringifies to ``""``, or cannot be stringified.
    """
    if value is not None:
        try:
            rendered = str(value)
        except Exception:
            # Best-effort: a value with a broken __str__ counts as absent.
            rendered = ""
        if rendered:
            return rendered
    return None


def _build_fallback_key(metadata: dict[str, Any]) -> _StreamKey:
    """Derive the lookup key used when no explicit stream id is available.

    Args:
        metadata: Chunk metadata mapping.

    Returns:
        A ``(request, id, session)`` tuple of normalized strings. The request
        slot comes from ``request_id`` or ``response_id``; the id slot from
        ``id``, ``chunk_id`` or ``event_id``; the session slot from
        ``session_id``. A slot is ``None`` when nothing usable was found.
    """
    # Each ``or`` chain yields the first truthy value, or the last operand
    # when all are falsy; that raw value is normalized afterwards.
    request_raw = metadata.get("request_id") or metadata.get("response_id")
    id_raw = (
        metadata.get("id")
        or metadata.get("chunk_id")
        or metadata.get("event_id")
    )
    session_raw = metadata.get("session_id")

    return (
        _normalize_component(request_raw),
        _normalize_component(id_raw),
        _normalize_component(session_raw),
    )


def get_stream_id(content: StreamingContent) -> str:
    """Return a stable identifier for the current stream.

    Processors rely on this value to keep per-stream buffers isolated.

    Resolution order:

    1. The first metadata entry among ``_UNIQUE_METADATA_KEYS`` that
       normalizes to a non-empty string (``stream_id`` first, then request /
       response / chunk / event identifiers). Preferring these over the
       session identifier keeps parallel requests from the same client
       isolated from each other.
    2. Otherwise a generated UUID registered under the chunk's fallback key
       (see ``_build_fallback_key``), so that later chunks carrying the same
       fallback metadata resolve to the same identifier.
    3. Otherwise, when the metadata offers nothing to key on, a one-off UUID.

    The resolved identifier is written back to ``metadata["stream_id"]`` so
    that subsequent processors handling the same metadata can reuse it.

    When the chunk is terminal (``is_done`` or ``is_cancellation``), the
    fallback registration is released. This covers both the entry recorded
    for the resolved identifier and any entry still registered under the
    chunk's session. The second case matters when a stream started without
    identifiers (and therefore got a fallback UUID) but finished with an
    explicit ``stream_id``/``request_id``: without it the stale UUID would
    stay registered forever and be handed to unrelated later streams of the
    same session, making them share buffers.

    Args:
        content: The streaming chunk; its ``metadata`` mapping is read and
            updated in place.

    Returns:
        The stream identifier as a non-empty string.
    """
    metadata = content.metadata

    # Explicit identifiers win; the tuple is already in priority order.
    stream_id: str | None = None
    for key in _UNIQUE_METADATA_KEYS:
        candidate = _normalize_component(metadata.get(key))
        if candidate is not None:
            stream_id = candidate
            break

    fallback_key = _build_fallback_key(metadata)

    if stream_id is None:
        if fallback_key != (None, None, None):
            # Reuse (or create) the identifier registered for this key so all
            # chunks with the same fallback metadata map to one stream.
            with _fallback_lock:
                stream_id = _active_stream_ids.get(fallback_key)
                if stream_id is None:
                    stream_id = uuid4().hex
                    _active_stream_ids[fallback_key] = stream_id
                    _reverse_stream_keys[stream_id] = fallback_key
        else:
            # Nothing to correlate on: the identifier only lives in metadata.
            stream_id = uuid4().hex

    metadata["stream_id"] = stream_id

    if content.is_done or content.is_cancellation:
        # Fallback entries are registered only when no explicit identifier
        # was present, so a stale entry for this stream sits under the
        # session-only form of the key.
        session_only_key = (None, None, fallback_key[2])
        with _fallback_lock:
            registered_key = _reverse_stream_keys.pop(stream_id, None)
            if registered_key is not None:
                _active_stream_ids.pop(registered_key, None)
            # Drop mappings left behind when the stream switched from a
            # fallback UUID to an explicit identifier mid-flight.
            # NOTE(review): this also releases the entry of any id-less
            # stream running in parallel in the same session; such streams
            # already share one fallback identifier.
            for stale_key in (fallback_key, session_only_key):
                stale_id = _active_stream_ids.pop(stale_key, None)
                if stale_id is not None:
                    _reverse_stream_keys.pop(stale_id, None)

    return stream_id
57 changes: 57 additions & 0 deletions tests/unit/core/services/streaming/test_stream_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

from src.core.ports.streaming import StreamingContent
from src.core.services.streaming.stream_utils import get_stream_id


def _build_chunk(
    *,
    session_id: str | None = None,
    request_id: str | None = None,
    stream_id: str | None = None,
    is_done: bool = False,
) -> StreamingContent:
    """Build an empty-content chunk carrying only the identifiers given.

    Identifiers left as ``None`` are omitted from the metadata entirely
    rather than stored as ``None`` values.
    """
    supplied = {
        "session_id": session_id,
        "request_id": request_id,
        "stream_id": stream_id,
    }
    metadata: dict[str, str] = {
        key: value for key, value in supplied.items() if value is not None
    }
    return StreamingContent(content="", is_done=is_done, metadata=metadata)


def test_get_stream_id_prefers_request_over_session() -> None:
    """Two requests in one session must resolve to different stream ids."""

    chunk_a = _build_chunk(session_id="session-1", request_id="req-a")
    chunk_b = _build_chunk(session_id="session-1", request_id="req-b")

    id_a = get_stream_id(chunk_a)
    id_b = get_stream_id(chunk_b)

    assert id_a != id_b

    # A later chunk of request "req-a" has to land on the same identifier.
    chunk_a_again = _build_chunk(session_id="session-1", request_id="req-a")
    id_a_again = get_stream_id(chunk_a_again)
    assert id_a_again == id_a


def test_get_stream_id_releases_mapping_on_completion() -> None:
    """After a stream completes, the same session gets a brand-new id."""

    opening_chunk = _build_chunk(session_id="session-42")
    first_id = get_stream_id(opening_chunk)

    # Resolving the id of the terminal chunk is what releases the mapping.
    final_chunk = _build_chunk(
        session_id="session-42",
        stream_id=first_id,
        is_done=True,
    )
    resolved_on_completion = get_stream_id(final_chunk)
    assert resolved_on_completion == first_id

    next_stream_chunk = _build_chunk(session_id="session-42")
    second_id = get_stream_id(next_stream_chunk)

    assert second_id != first_id
Loading