Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions src/core/services/response_middleware.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from uuid import uuid4
from typing import Any

from src.core.common.exceptions import LoopDetectionError
Expand Down Expand Up @@ -106,6 +107,7 @@ class LoopDetectionMiddleware(IResponseMiddleware):
def __init__(self, loop_detector: ILoopDetector, priority: int = 0) -> None:
self._loop_detector = loop_detector
self._accumulated_content: dict[str, str] = {}
self._session_aliases: dict[str, str] = {}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Consider lifecycle management for the alias dictionary.

The _session_aliases mapping is only cleared during reset_session calls. If sessions are never explicitly reset (e.g., in streaming scenarios or long-running processes), this dictionary could grow unbounded, leading to a memory leak. Consider implementing a cleanup mechanism, such as:

  • An LRU cache with size limits
  • Periodic cleanup of old aliases
  • Time-based expiration

Would you like me to generate an implementation using an LRU cache with configurable size limits?

self._priority = priority

@property
Expand All @@ -124,9 +126,11 @@ async def process(
if not response.content:
return response

self._accumulated_content.setdefault(session_id, "")
self._accumulated_content[session_id] += response.content
content = self._accumulated_content[session_id]
resolved_session_id = self._resolve_session_id(session_id, context, response)

self._accumulated_content.setdefault(resolved_session_id, "")
self._accumulated_content[resolved_session_id] += response.content
content = self._accumulated_content[resolved_session_id]

if len(content) > 100:
loop_result = await self._loop_detector.check_for_loops(content)
Expand All @@ -147,5 +151,42 @@ async def process(

def reset_session(self, session_id: str) -> None:
    """Discard the loop-detection state held for a session.

    The alias recorded for ``session_id`` (if any) is removed, and the
    accumulated content is dropped both under the identifier supplied by
    the caller and under the identifier it was aliased to. Unknown
    identifiers are ignored.

    Args:
        session_id: Identifier of the session whose state is cleared.
    """
    # Fall back to the raw id when no alias was ever recorded for it.
    target_session = self._session_aliases.pop(session_id, session_id)

    # pop(..., None) replaces the check-then-delete pair: one lookup per
    # key and no KeyError when the buffer does not exist.
    for key in (session_id, target_session):
        self._accumulated_content.pop(key, None)

def _resolve_session_id(
    self,
    session_id: str,
    context: dict[str, Any] | None,
    response: Any,
) -> str:
    """Resolve a stable session identifier for loop detection state.

    Candidates are tried in order and the first truthy one wins:

    1. ``session_id`` itself,
    2. ``context["stream_id"]`` then ``context["_loop_detection_session_id"]``,
    3. ``response.metadata["stream_id"]`` then ``response.metadata["session_id"]``
       (only when ``response.metadata`` is a dict),
    4. a freshly generated ``uuid4`` hex string.

    Empty or otherwise falsy candidates are treated as missing, so an
    empty id can never become a buffer key shared by unrelated sessions.

    Args:
        session_id: Identifier supplied by the caller; may be empty.
        context: Optional mutable mapping. When given, the resolved id is
            stored under ``"_loop_detection_session_id"`` unless a
            non-empty value is already present.
        response: Object whose optional ``metadata`` dict may carry an id.

    Returns:
        The resolved identifier as a non-empty string.
    """
    cache_key = "_loop_detection_session_id"

    resolved: Any = session_id if session_id else None

    if not resolved and context:
        resolved = context.get("stream_id") or context.get(cache_key)

    if not resolved:
        # Looked up lazily so response attributes are only touched when
        # the cheaper sources produced nothing.
        metadata = getattr(response, "metadata", None)
        if isinstance(metadata, dict):
            resolved = metadata.get("stream_id") or metadata.get("session_id")

    # Any falsy leftover ("" / None / 0) means no usable id was found.
    resolved = str(resolved) if resolved else uuid4().hex

    # Remember the id on the context so later calls sharing this context
    # resolve to the same buffer; an empty cached value is overwritten.
    if context is not None and not context.get(cache_key):
        context[cache_key] = resolved

    # Record a mapping from the caller's id to the normalized id when they
    # differ, so reset_session can find the buffer from the original id.
    if session_id and session_id != resolved:
        self._session_aliases[session_id] = resolved

    return resolved
47 changes: 43 additions & 4 deletions tests/unit/core/services/test_response_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest
from src.core.common.exceptions import LoopDetectionError
from src.core.interfaces.loop_detector_interface import ILoopDetector
from src.core.interfaces.loop_detector_interface import ILoopDetector, LoopDetectionResult
from src.core.interfaces.response_processor_interface import ProcessedResponse
from src.core.services.response_middleware import (
ContentFilterMiddleware,
Expand Down Expand Up @@ -127,6 +127,31 @@ def middleware(self, mock_loop_detector):
"""Create a LoopDetectionMiddleware instance."""
return LoopDetectionMiddleware(mock_loop_detector)

class _StubLoopDetector(ILoopDetector):
    """Minimal detector double that records every content string it is asked to check."""

    def __init__(self) -> None:
        # Each check_for_loops() argument is appended here, in call order.
        self.calls: list[str] = []

    def is_enabled(self) -> bool:
        """Report the stub as always enabled."""
        return True

    def process_chunk(self, chunk: str):
        """Ignore the chunk; the stub never reports a detection here."""
        return None

    def reset(self) -> None:
        """Forget all recorded check_for_loops() arguments."""
        self.calls.clear()

    def get_loop_history(self):
        """Return an empty history."""
        return list()

    def get_current_state(self):
        """Return an empty state mapping."""
        return dict()

    async def check_for_loops(self, content: str) -> LoopDetectionResult:
        """Record ``content`` and answer that no loop was found."""
        self.calls.append(content)
        no_loop = LoopDetectionResult(has_loop=False)
        return no_loop

@pytest.mark.asyncio
async def test_process_no_loop_detected(self, middleware, mock_loop_detector):
"""Test middleware processes normally when no loop is detected."""
Expand Down Expand Up @@ -198,9 +223,23 @@ async def test_process_accumulates_content(self, middleware, mock_loop_detector)
result2 = await middleware.process(response2, "session123", {})
assert result2 == response2

# Check that accumulated content was passed to detector
args, _kwargs = mock_loop_detector.check_for_loops.call_args
assert "Part 1" in args[0] and "Part 2" in args[0]
@pytest.mark.asyncio
async def test_process_isolates_sessions_without_identifier(self):
    """Two streams with an empty session id must not share one content buffer."""
    stub = self._StubLoopDetector()
    loop_middleware = LoopDetectionMiddleware(stub)

    # 120 characters: long enough for the middleware to consult the detector.
    payload = "A" * 120

    # Same payload, empty session id, distinct stream ids in the context.
    for stream_id in ("stream-one", "stream-two"):
        chunk = ProcessedResponse(content=payload)
        await loop_middleware.process(chunk, "", {"stream_id": stream_id})

    # One detector call per stream, each seeing only its own payload and
    # not the concatenation of both.
    assert len(stub.calls) == 2
    seen_lengths = [len(seen) for seen in stub.calls]
    assert seen_lengths == [len(payload), len(payload)]


def test_reset_session(self, middleware):
"""Test resetting session accumulated content."""
Expand Down
Loading