diff --git a/src/core/services/response_middleware.py b/src/core/services/response_middleware.py index 6140a8d09..a6b67cf19 100644 --- a/src/core/services/response_middleware.py +++ b/src/core/services/response_middleware.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from uuid import uuid4 from typing import Any from src.core.common.exceptions import LoopDetectionError @@ -106,6 +107,7 @@ class LoopDetectionMiddleware(IResponseMiddleware): def __init__(self, loop_detector: ILoopDetector, priority: int = 0) -> None: self._loop_detector = loop_detector self._accumulated_content: dict[str, str] = {} + self._session_aliases: dict[str, str] = {} self._priority = priority @property @@ -124,9 +126,11 @@ async def process( if not response.content: return response - self._accumulated_content.setdefault(session_id, "") - self._accumulated_content[session_id] += response.content - content = self._accumulated_content[session_id] + resolved_session_id = self._resolve_session_id(session_id, context, response) + + self._accumulated_content.setdefault(resolved_session_id, "") + self._accumulated_content[resolved_session_id] += response.content + content = self._accumulated_content[resolved_session_id] if len(content) > 100: loop_result = await self._loop_detector.check_for_loops(content) @@ -147,5 +151,42 @@ async def process( def reset_session(self, session_id: str) -> None: """Reset the accumulated content for a session.""" - if session_id in self._accumulated_content: - del self._accumulated_content[session_id] + target_session = self._session_aliases.pop(session_id, session_id) + if target_session in self._accumulated_content: + del self._accumulated_content[target_session] + + def _resolve_session_id( + self, + session_id: str, + context: dict[str, Any] | None, + response: Any, + ) -> str: + """Resolve a stable session identifier for loop detection state.""" + + if session_id: + resolved = str(session_id) + else: + resolved = None + + if resolved is None and context: + resolved = context.get("stream_id") or context.get( + "_loop_detection_session_id" + ) + + if resolved is None: + metadata = getattr(response, "metadata", None) + if isinstance(metadata, dict): + resolved = metadata.get("stream_id") or metadata.get("session_id") + + if resolved is None: + resolved = uuid4().hex + + resolved = str(resolved) + + if context is not None: + context.setdefault("_loop_detection_session_id", resolved) + + if session_id and session_id != resolved: + self._session_aliases[session_id] = resolved + + return resolved diff --git a/tests/unit/core/services/test_response_middleware.py b/tests/unit/core/services/test_response_middleware.py index 280636e9f..b961cbb8a 100644 --- a/tests/unit/core/services/test_response_middleware.py +++ b/tests/unit/core/services/test_response_middleware.py @@ -5,7 +5,7 @@ import pytest from src.core.common.exceptions import LoopDetectionError -from src.core.interfaces.loop_detector_interface import ILoopDetector +from src.core.interfaces.loop_detector_interface import ILoopDetector, LoopDetectionResult from src.core.interfaces.response_processor_interface import ProcessedResponse from src.core.services.response_middleware import ( ContentFilterMiddleware, @@ -127,6 +127,31 @@ def middleware(self, mock_loop_detector): """Create a LoopDetectionMiddleware instance.""" return LoopDetectionMiddleware(mock_loop_detector) + class _StubLoopDetector(ILoopDetector): + """Simple loop detector stub to capture processed content.""" + + def __init__(self) -> None: + self.calls: list[str] = [] + + def is_enabled(self) -> bool: + return True + + def process_chunk(self, chunk: str): + return None + + def reset(self) -> None: + self.calls.clear() + + def get_loop_history(self): + return [] + + def get_current_state(self): + return {} + + async def check_for_loops(self, content: str) -> LoopDetectionResult: + self.calls.append(content) + return LoopDetectionResult(has_loop=False) + @pytest.mark.asyncio async def test_process_no_loop_detected(self, middleware, mock_loop_detector): """Test middleware processes normally when no loop is detected.""" @@ -198,9 +223,23 @@ async def test_process_accumulates_content(self, middleware, mock_loop_detector) result2 = await middleware.process(response2, "session123", {}) assert result2 == response2 - # Check that accumulated content was passed to detector - args, _kwargs = mock_loop_detector.check_for_loops.call_args - assert "Part 1" in args[0] and "Part 2" in args[0] + @pytest.mark.asyncio + async def test_process_isolates_sessions_without_identifier(self): + """Ensure content buffers do not leak between sessions when IDs are missing.""" + + detector = self._StubLoopDetector() + middleware = LoopDetectionMiddleware(detector) + + content = "A" * 120 + response_one = ProcessedResponse(content=content) + response_two = ProcessedResponse(content=content) + + await middleware.process(response_one, "", {"stream_id": "stream-one"}) + await middleware.process(response_two, "", {"stream_id": "stream-two"}) + + assert len(detector.calls) == 2 + assert [len(call) for call in detector.calls] == [len(content), len(content)] + def test_reset_session(self, middleware): """Test resetting session accumulated content."""