diff --git a/src/core/services/edit_precision_response_middleware.py b/src/core/services/edit_precision_response_middleware.py index 48568c765..6af3508ad 100644 --- a/src/core/services/edit_precision_response_middleware.py +++ b/src/core/services/edit_precision_response_middleware.py @@ -47,6 +47,15 @@ class EditPrecisionResponseMiddleware(IResponseMiddleware): re.IGNORECASE | re.DOTALL, ), ] + _SESSION_KEY_FIELDS = ( + "session_id", + "stream_id", + "id", + "request_id", + "conversation_id", + "thread_id", + "message_id", + ) def __init__(self, app_state: IApplicationState) -> None: super().__init__(priority=10) @@ -80,6 +89,40 @@ def __init__(self, app_state: IApplicationState) -> None: # Use only default patterns if config loading fails pass + def _resolve_session_key( + self, + session_id: str, + context: dict[str, Any] | None, + metadata: dict[str, Any] | None, + ) -> str | None: + if session_id: + normalized = str(session_id).strip() + if normalized: + return normalized + + candidates: list[str] = [] + if isinstance(context, dict): + for field in self._SESSION_KEY_FIELDS: + value = context.get(field) + if value is None: + continue + candidate = str(value).strip() + if candidate: + candidates.append(candidate) + + if isinstance(metadata, dict): + for field in self._SESSION_KEY_FIELDS: + value = metadata.get(field) + if value is None: + continue + candidate = str(value).strip() + if candidate: + candidates.append(candidate) + + if candidates: + return candidates[0] + return None + async def process( self, response: Any, @@ -97,6 +140,12 @@ async def process( out = ProcessedResponse(content=text) metadata = getattr(out, "metadata", {}) or {} + session_key = self._resolve_session_key(session_id, context, metadata) + if session_key is None: + self._logger.debug( + "Edit-precision: unable to determine session scope; skipping state updates" + ) + return out text_sources: list[str] = [] if text: @@ -140,14 +189,13 @@ async def process( except Exception: pending_map = {} - key = session_id or "" - if key: - if active_disable_map.get(key): + if session_key: + if active_disable_map.get(session_key): # We already flagged this response; still update stream tracking - self._update_stream_tracking(key, context, out) + self._update_stream_tracking(session_key, context, out) self._logger.debug( "Edit-precision: session %s already has hybrid reasoning disable flag", - key, + session_key, ) return out @@ -168,19 +216,19 @@ async def process( ) except Exception: stream_id = "" - last_stream_id = self._last_stream_ids.get(key) + last_stream_id = self._last_stream_ids.get(session_key) if stream_id and last_stream_id == stream_id: return out - pending_map[key] = int(pending_map.get(key, 0)) + 1 + pending_map[session_key] = int(pending_map.get(session_key, 0)) + 1 if response_type == "stream" and stream_id: - self._last_stream_ids[key] = stream_id + self._last_stream_ids[session_key] = stream_id elif response_type != "stream": - self._last_stream_ids.pop(key, None) + self._last_stream_ids.pop(session_key, None) self._app_state.set_setting("edit_precision_pending", pending_map) # Mark hybrid reasoning disable active until consumed by request processor - active_disable_map[key] = {"timestamp": time.time()} + active_disable_map[session_key] = {"timestamp": time.time()} self._app_state.set_setting( "edit_precision_hybrid_reasoning_active", active_disable_map ) @@ -200,7 +248,7 @@ async def process( hybrid_reasoning_disabled_map = {} # Mark that hybrid reasoning should be disabled for next request - hybrid_reasoning_disabled_map[key] = True + hybrid_reasoning_disabled_map[session_key] = True self._app_state.set_setting( "edit_precision_hybrid_reasoning_disabled", hybrid_reasoning_disabled_map, @@ -213,14 +261,14 @@ async def process( ) self._logger.info( "Edit-precision trigger detected; session_id=%s pattern=%s count=%s response_type=%s", - key, + session_key, matched_pattern, - pending_map.get(key, 0), + pending_map.get(session_key, 0), response_type, ) self._logger.info( "Hybrid reasoning disabled for next request in session %s due to edit failure", - key, + session_key, ) except Exception as e: self._logger.debug( diff --git a/src/core/services/response_middleware.py b/src/core/services/response_middleware.py index 6140a8d09..72c756116 100644 --- a/src/core/services/response_middleware.py +++ b/src/core/services/response_middleware.py @@ -2,6 +2,7 @@ import logging from typing import Any +from uuid import uuid4 from src.core.common.exceptions import LoopDetectionError from src.core.interfaces.loop_detector_interface import ILoopDetector @@ -107,11 +108,79 @@ def __init__(self, loop_detector: ILoopDetector, priority: int = 0) -> None: self._loop_detector = loop_detector self._accumulated_content: dict[str, str] = {} self._priority = priority + self._anonymous_session_aliases: dict[int, str] = {} @property def priority(self) -> int: return self._priority + def _resolve_session_key( + self, + session_id: str, + context: dict[str, Any] | None, + response: Any, + stop_event: Any, + ) -> tuple[str, bool]: + candidate_fields = ( + "session_id", + "stream_id", + "id", + "request_id", + "conversation_id", + "thread_id", + "message_id", + ) + if session_id: + normalized = str(session_id).strip() + if normalized: + return normalized, False + sources: list[dict[str, Any]] = [] + if isinstance(context, dict): + sources.append(context) + metadata = getattr(response, "metadata", None) + if isinstance(metadata, dict): + sources.append(metadata) + for source in sources: + for field in candidate_fields: + try: + value = source.get(field) # type: ignore[call-arg] + except AttributeError: + continue + if value is None: + continue + candidate = str(value).strip() + if candidate: + return candidate, False + if stop_event is not None: + alias = self._anonymous_session_aliases.get(id(stop_event)) + if alias is None: + alias = uuid4().hex + self._anonymous_session_aliases[id(stop_event)] = alias + return alias, False + return uuid4().hex, True + + def _cleanup_session_state( + self, + resolved_session_id: str, + ephemeral_key: bool, + stop_event: Any, + ) -> None: + if ephemeral_key: + self._accumulated_content.pop(resolved_session_id, None) + return + if stop_event is None: + return + alias_id = id(stop_event) + alias_value = self._anonymous_session_aliases.get(alias_id) + try: + is_done = bool(stop_event.is_set()) # type: ignore[attr-defined] + except AttributeError: + is_done = False + if is_done: + if alias_value is not None: + self._accumulated_content.pop(alias_value, None) + self._anonymous_session_aliases.pop(alias_id, None) + async def process( self, response: Any, @@ -121,31 +190,45 @@ async def process( stop_event: Any = None, ) -> Any: """Process a response, checking for loops.""" - if not response.content: - return response + resolved_session_id, ephemeral_key = self._resolve_session_key( + session_id, context, response, stop_event + ) + + try: + if not response.content: + return response - self._accumulated_content.setdefault(session_id, "") - self._accumulated_content[session_id] += response.content - content = self._accumulated_content[session_id] - - if len(content) > 100: - loop_result = await self._loop_detector.check_for_loops(content) - if loop_result.has_loop: - error_message = f"Loop detected: The response contains repetitive content. Detected {loop_result.repetitions} repetitions." - logger.warning( - f"Loop detected in session {session_id}: {loop_result.repetitions} repetitions" - ) - raise LoopDetectionError( - message=error_message, - details={ - "repetitions": loop_result.repetitions, - "pattern": loop_result.pattern, - }, - ) + previous = self._accumulated_content.get(resolved_session_id, "") + self._accumulated_content[resolved_session_id] = previous + response.content + content = self._accumulated_content[resolved_session_id] + + if len(content) > 100: + loop_result = await self._loop_detector.check_for_loops(content) + if loop_result.has_loop: + error_message = ( + "Loop detected: The response contains repetitive content. " + f"Detected {loop_result.repetitions} repetitions." + ) + logger.warning( + f"Loop detected in session {resolved_session_id}: {loop_result.repetitions} repetitions" + ) + raise LoopDetectionError( + message=error_message, + details={ + "repetitions": loop_result.repetitions, + "pattern": loop_result.pattern, + "session_id": resolved_session_id, + }, + ) - return response + return response + finally: + self._cleanup_session_state(resolved_session_id, ephemeral_key, stop_event) def reset_session(self, session_id: str) -> None: """Reset the accumulated content for a session.""" if session_id in self._accumulated_content: del self._accumulated_content[session_id] + for alias_id, alias_value in list(self._anonymous_session_aliases.items()): + if alias_value == session_id: + self._anonymous_session_aliases.pop(alias_id, None) diff --git a/tests/unit/core/services/test_edit_precision_response_middleware.py b/tests/unit/core/services/test_edit_precision_response_middleware.py index e8f75b669..9782519f7 100644 --- a/tests/unit/core/services/test_edit_precision_response_middleware.py +++ b/tests/unit/core/services/test_edit_precision_response_middleware.py @@ -64,6 +64,46 @@ async def test_streaming_processor_applies_middleware_and_sets_pending( assert "stream-abc" in active_flags +@pytest.mark.asyncio +async def test_streaming_without_session_id_uses_stream_identifier( + app_state: ApplicationStateService, +) -> None: + mw = EditPrecisionResponseMiddleware(app_state) + processor = MiddlewareApplicationProcessor([mw], app_state=app_state) + + chunk = StreamingContent( + content="... diff_error ...", + metadata={"stream_id": "anon-stream"}, + ) + + await processor.process(chunk) + + pending = app_state.get_setting("edit_precision_pending", {}) + assert pending.get("anon-stream", 0) >= 1 + active_flags = app_state.get_setting("edit_precision_hybrid_reasoning_active", {}) + assert "anon-stream" in active_flags + + +@pytest.mark.asyncio +async def test_streaming_without_identifiers_skips_state_changes( + app_state: ApplicationStateService, +) -> None: + mw = EditPrecisionResponseMiddleware(app_state) + processor = MiddlewareApplicationProcessor([mw], app_state=app_state) + + chunk = StreamingContent( + content="... diff_error ...", + metadata={}, + ) + + await processor.process(chunk) + + pending = app_state.get_setting("edit_precision_pending", {}) + assert pending == {} + active_flags = app_state.get_setting("edit_precision_hybrid_reasoning_active", {}) + assert active_flags == {} + + @pytest.mark.asyncio async def test_streaming_duplicate_without_stream_id_only_flags_once( app_state: ApplicationStateService, diff --git a/tests/unit/core/services/test_response_middleware.py b/tests/unit/core/services/test_response_middleware.py index 280636e9f..dd85dd87c 100644 --- a/tests/unit/core/services/test_response_middleware.py +++ b/tests/unit/core/services/test_response_middleware.py @@ -202,6 +202,64 @@ async def test_process_accumulates_content(self, middleware, mock_loop_detector) args, _kwargs = mock_loop_detector.check_for_loops.call_args assert "Part 1" in args[0] and "Part 2" in args[0] + @pytest.mark.asyncio + async def test_streams_without_session_id_are_isolated( + self, middleware, mock_loop_detector + ): + """Ensure anonymous streaming sessions do not leak state between each other.""" + mock_result = MagicMock() + mock_result.has_loop = False + mock_loop_detector.check_for_loops.return_value = mock_result + + class DummyStopEvent: + def __init__(self) -> None: + self._set = False + + def is_set(self) -> bool: + return self._set + + stop_a = DummyStopEvent() + stop_b = DummyStopEvent() + + chunk_a1 = ProcessedResponse( + content="A" * 60, metadata={"stream_id": "stream-a"} + ) + chunk_b = ProcessedResponse( + content="B" * 60, metadata={"stream_id": "stream-b"} + ) + chunk_a2 = ProcessedResponse( + content="A" * 60, metadata={"stream_id": "stream-a"} + ) + + await middleware.process( + chunk_a1, + "", + {"response_type": "stream"}, + is_streaming=True, + stop_event=stop_a, + ) + await middleware.process( + chunk_b, + "", + {"response_type": "stream"}, + is_streaming=True, + stop_event=stop_b, + ) + await middleware.process( + chunk_a2, + "", + {"response_type": "stream"}, + is_streaming=True, + stop_event=stop_a, + ) + + assert "stream-a" in middleware._accumulated_content + assert "stream-b" in middleware._accumulated_content + assert middleware._accumulated_content["stream-b"] == "B" * 60 + + args, _ = mock_loop_detector.check_for_loops.call_args + assert "B" not in args[0] + def test_reset_session(self, middleware): """Test resetting session accumulated content.""" # Manually add content to test reset