Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 62 additions & 14 deletions src/core/services/edit_precision_response_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ class EditPrecisionResponseMiddleware(IResponseMiddleware):
re.IGNORECASE | re.DOTALL,
),
]
# Identifier fields probed, in priority order, when no explicit session_id is
# supplied to the middleware; _resolve_session_key checks these first in the
# request context dict and then in the response metadata dict.
_SESSION_KEY_FIELDS = (
    "session_id",
    "stream_id",
    "id",
    "request_id",
    "conversation_id",
    "thread_id",
    "message_id",
)

def __init__(self, app_state: IApplicationState) -> None:
super().__init__(priority=10)
Expand Down Expand Up @@ -80,6 +89,40 @@ def __init__(self, app_state: IApplicationState) -> None:
# Use only default patterns if config loading fails
pass

def _resolve_session_key(
self,
session_id: str,
context: dict[str, Any] | None,
metadata: dict[str, Any] | None,
) -> str | None:
if session_id:
normalized = str(session_id).strip()
if normalized:
return normalized

candidates: list[str] = []
if isinstance(context, dict):
for field in self._SESSION_KEY_FIELDS:
value = context.get(field)
if value is None:
continue
candidate = str(value).strip()
if candidate:
candidates.append(candidate)

if isinstance(metadata, dict):
for field in self._SESSION_KEY_FIELDS:
value = metadata.get(field)
if value is None:
continue
candidate = str(value).strip()
if candidate:
candidates.append(candidate)

if candidates:
return candidates[0]
return None

async def process(
self,
response: Any,
Expand All @@ -97,6 +140,12 @@ async def process(
out = ProcessedResponse(content=text)

metadata = getattr(out, "metadata", {}) or {}
session_key = self._resolve_session_key(session_id, context, metadata)
if session_key is None:
self._logger.debug(
"Edit-precision: unable to determine session scope; skipping state updates"
)
return out

text_sources: list[str] = []
if text:
Expand Down Expand Up @@ -140,14 +189,13 @@ async def process(
except Exception:
pending_map = {}

key = session_id or ""
if key:
if active_disable_map.get(key):
if session_key:
if active_disable_map.get(session_key):
# We already flagged this response; still update stream tracking
self._update_stream_tracking(key, context, out)
self._update_stream_tracking(session_key, context, out)
self._logger.debug(
"Edit-precision: session %s already has hybrid reasoning disable flag",
key,
session_key,
)
return out

Expand All @@ -168,19 +216,19 @@ async def process(
)
except Exception:
stream_id = ""
last_stream_id = self._last_stream_ids.get(key)
last_stream_id = self._last_stream_ids.get(session_key)
if stream_id and last_stream_id == stream_id:
return out

pending_map[key] = int(pending_map.get(key, 0)) + 1
pending_map[session_key] = int(pending_map.get(session_key, 0)) + 1
if response_type == "stream" and stream_id:
self._last_stream_ids[key] = stream_id
self._last_stream_ids[session_key] = stream_id
elif response_type != "stream":
self._last_stream_ids.pop(key, None)
self._last_stream_ids.pop(session_key, None)
self._app_state.set_setting("edit_precision_pending", pending_map)

# Mark hybrid reasoning disable active until consumed by request processor
active_disable_map[key] = {"timestamp": time.time()}
active_disable_map[session_key] = {"timestamp": time.time()}
self._app_state.set_setting(
"edit_precision_hybrid_reasoning_active", active_disable_map
)
Expand All @@ -200,7 +248,7 @@ async def process(
hybrid_reasoning_disabled_map = {}

# Mark that hybrid reasoning should be disabled for next request
hybrid_reasoning_disabled_map[key] = True
hybrid_reasoning_disabled_map[session_key] = True
self._app_state.set_setting(
"edit_precision_hybrid_reasoning_disabled",
hybrid_reasoning_disabled_map,
Expand All @@ -213,14 +261,14 @@ async def process(
)
self._logger.info(
"Edit-precision trigger detected; session_id=%s pattern=%s count=%s response_type=%s",
key,
session_key,
matched_pattern,
pending_map.get(key, 0),
pending_map.get(session_key, 0),
response_type,
)
self._logger.info(
"Hybrid reasoning disabled for next request in session %s due to edit failure",
key,
session_key,
)
except Exception as e:
self._logger.debug(
Expand Down
125 changes: 104 additions & 21 deletions src/core/services/response_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from typing import Any
from uuid import uuid4

from src.core.common.exceptions import LoopDetectionError
from src.core.interfaces.loop_detector_interface import ILoopDetector
Expand Down Expand Up @@ -107,11 +108,79 @@ def __init__(self, loop_detector: ILoopDetector, priority: int = 0) -> None:
self._loop_detector = loop_detector
self._accumulated_content: dict[str, str] = {}
self._priority = priority
self._anonymous_session_aliases: dict[int, str] = {}

@property
def priority(self) -> int:
return self._priority

def _resolve_session_key(
self,
session_id: str,
context: dict[str, Any] | None,
response: Any,
stop_event: Any,
) -> tuple[str, bool]:
candidate_fields = (
"session_id",
"stream_id",
"id",
"request_id",
"conversation_id",
"thread_id",
"message_id",
)
if session_id:
normalized = str(session_id).strip()
if normalized:
return normalized, False
sources: list[dict[str, Any]] = []
if isinstance(context, dict):
sources.append(context)
metadata = getattr(response, "metadata", None)
if isinstance(metadata, dict):
sources.append(metadata)
for source in sources:
for field in candidate_fields:
try:
value = source.get(field) # type: ignore[call-arg]
except AttributeError:
continue
if value is None:
continue
candidate = str(value).strip()
if candidate:
return candidate, False
if stop_event is not None:
alias = self._anonymous_session_aliases.get(id(stop_event))
if alias is None:
alias = uuid4().hex
self._anonymous_session_aliases[id(stop_event)] = alias
return alias, False
return uuid4().hex, True

def _cleanup_session_state(
self,
resolved_session_id: str,
ephemeral_key: bool,
stop_event: Any,
) -> None:
if ephemeral_key:
self._accumulated_content.pop(resolved_session_id, None)
return
if stop_event is None:
return
alias_id = id(stop_event)
alias_value = self._anonymous_session_aliases.get(alias_id)
try:
is_done = bool(stop_event.is_set()) # type: ignore[attr-defined]
except AttributeError:
is_done = False
if is_done:
if alias_value is not None:
self._accumulated_content.pop(alias_value, None)
self._anonymous_session_aliases.pop(alias_id, None)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Session Cleanup Fails for Metadata-Resolved Sessions

The _cleanup_session_state method fails to clean up accumulated content when sessions are resolved via metadata fields (like stream_id) rather than via stop_event. The cleanup logic checks _anonymous_session_aliases to find which session to clean, but this dictionary is only populated when stop_event is used as the fallback identifier in _resolve_session_key. When a session is resolved earlier via metadata, the mapping is never created, causing accumulated content to persist indefinitely even after the stream completes, resulting in a memory leak.

Fix in Cursor Fix in Web

Comment on lines +162 to +182

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Leak loop detector state for completed anonymous streams

The new cleanup logic only removes _accumulated_content when an alias was registered for the stop_event, but when the session key comes from metadata (e.g. stream_id) no alias exists. After stop_event.is_set() the method returns without popping the resolved key, so each anonymous stream leaves its accumulated text in memory forever. On a busy server this dictionary will grow unbounded and can also contaminate later streams if an identifier is reused. Consider popping resolved_session_id whenever the stop event finishes, even when alias_value is None.

Useful? React with 👍 / 👎.


async def process(
self,
response: Any,
Expand All @@ -121,31 +190,45 @@ async def process(
stop_event: Any = None,
) -> Any:
"""Process a response, checking for loops."""
if not response.content:
return response
resolved_session_id, ephemeral_key = self._resolve_session_key(
session_id, context, response, stop_event
)

try:
if not response.content:
return response

self._accumulated_content.setdefault(session_id, "")
self._accumulated_content[session_id] += response.content
content = self._accumulated_content[session_id]

if len(content) > 100:
loop_result = await self._loop_detector.check_for_loops(content)
if loop_result.has_loop:
error_message = f"Loop detected: The response contains repetitive content. Detected {loop_result.repetitions} repetitions."
logger.warning(
f"Loop detected in session {session_id}: {loop_result.repetitions} repetitions"
)
raise LoopDetectionError(
message=error_message,
details={
"repetitions": loop_result.repetitions,
"pattern": loop_result.pattern,
},
)
previous = self._accumulated_content.get(resolved_session_id, "")
self._accumulated_content[resolved_session_id] = previous + response.content
content = self._accumulated_content[resolved_session_id]

if len(content) > 100:
loop_result = await self._loop_detector.check_for_loops(content)
if loop_result.has_loop:
error_message = (
"Loop detected: The response contains repetitive content. "
f"Detected {loop_result.repetitions} repetitions."
)
logger.warning(
f"Loop detected in session {resolved_session_id}: {loop_result.repetitions} repetitions"
)
raise LoopDetectionError(
message=error_message,
details={
"repetitions": loop_result.repetitions,
"pattern": loop_result.pattern,
"session_id": resolved_session_id,
},
)

return response
return response
finally:
self._cleanup_session_state(resolved_session_id, ephemeral_key, stop_event)

def reset_session(self, session_id: str) -> None:
    """Drop all accumulated loop-detection state for *session_id*."""
    self._accumulated_content.pop(session_id, None)
    # Also forget any stop-event aliases that pointed at this session so the
    # alias table does not accumulate stale entries.
    stale_ids = [
        alias_id
        for alias_id, alias_value in self._anonymous_session_aliases.items()
        if alias_value == session_id
    ]
    for alias_id in stale_ids:
        self._anonymous_session_aliases.pop(alias_id, None)
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,46 @@ async def test_streaming_processor_applies_middleware_and_sets_pending(
assert "stream-abc" in active_flags


@pytest.mark.asyncio
async def test_streaming_without_session_id_uses_stream_identifier(
    app_state: ApplicationStateService,
) -> None:
    """A chunk carrying only metadata['stream_id'] must still scope state."""
    middleware = EditPrecisionResponseMiddleware(app_state)
    pipeline = MiddlewareApplicationProcessor([middleware], app_state=app_state)

    await pipeline.process(
        StreamingContent(
            content="... diff_error ...",
            metadata={"stream_id": "anon-stream"},
        )
    )

    pending = app_state.get_setting("edit_precision_pending", {})
    assert pending.get("anon-stream", 0) >= 1
    active = app_state.get_setting("edit_precision_hybrid_reasoning_active", {})
    assert "anon-stream" in active


@pytest.mark.asyncio
async def test_streaming_without_identifiers_skips_state_changes(
    app_state: ApplicationStateService,
) -> None:
    """With no session identifiers at all, the middleware must not write state."""
    middleware = EditPrecisionResponseMiddleware(app_state)
    pipeline = MiddlewareApplicationProcessor([middleware], app_state=app_state)

    await pipeline.process(
        StreamingContent(content="... diff_error ...", metadata={})
    )

    assert app_state.get_setting("edit_precision_pending", {}) == {}
    assert (
        app_state.get_setting("edit_precision_hybrid_reasoning_active", {}) == {}
    )


@pytest.mark.asyncio
async def test_streaming_duplicate_without_stream_id_only_flags_once(
app_state: ApplicationStateService,
Expand Down
Loading
Loading