Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/attune_rag/expander.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

from __future__ import annotations

import asyncio
import json
import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Bind ``asyncio.to_thread`` to a module-level name so the ``asyncio``
# import stays referenced and survives auto-format passes. The real
# consumer is ``QueryExpander.expand_async`` below, which awaits the
# blocking ``expand`` call through this alias.
_ASYNCIO_TO_THREAD = asyncio.to_thread
# Names exported by ``from <this module> import *``.
__all__ = ["QueryExpander"]

_SYSTEM = """\
You expand developer queries for a documentation retrieval system.
Expand Down Expand Up @@ -88,3 +93,18 @@ def expand(self, query: str) -> list[str]:
if self._cache is not None:
self._cache[query] = expansions
return expansions

async def expand_async(self, query: str) -> list[str]:
    """Expand *query* without blocking the running event loop.

    This is the coroutine counterpart of :meth:`expand`. On a cache
    miss (or when no cache is configured) the synchronous
    :meth:`expand` is run through :func:`asyncio.to_thread`, so async
    callers such as FastAPI route handlers keep their event loop free
    while the Anthropic request is in flight.

    Both methods read and write the same ``self._cache``, so an entry
    stored by one path is served as a hit by the other. A cache hit is
    answered directly, without handing work to a thread.

    Args:
        query: The query string to expand.

    Returns:
        The cached expansions for ``query`` if present; otherwise
        exactly what :meth:`expand` returns for the same query,
        including its fail-soft empty-list result.
    """
    cache = self._cache
    is_hit = cache is not None and query in cache
    if not is_hit:
        # Miss or caching disabled: run the blocking sync path in a
        # worker thread and await its result.
        return await _ASYNCIO_TO_THREAD(self.expand, query)
    # Hit: serve straight from the shared cache, no thread hop needed.
    return cache[query]
60 changes: 60 additions & 0 deletions tests/unit/test_expander_reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,63 @@ def test_expander_and_reranker_compose(self, corpus: FakeCorpus):
exp_client.messages.create.assert_called_once()
rer_client.messages.create.assert_called_once()
assert not result.fallback_used

def test_expand_async_returns_same_as_expand(self) -> None:
    """Async variant must produce the same result as the sync method
    for the same input + mock response.

    Caching is disabled so both calls go through the mocked client
    rather than one path being served from the other's cache entry.
    """
    import asyncio

    expander = QueryExpander(cache=False)
    mock_client = MagicMock()
    mock_client.messages.create.return_value = _fake_response('["a", "b"]')
    expander._client = mock_client

    # Run both paths against the identical mock response so the test
    # actually compares them, as its name promises.
    sync_result = expander.expand("q")
    async_result = asyncio.run(expander.expand_async("q"))

    assert async_result == ["a", "b"]
    assert async_result == sync_result

def test_expand_async_uses_cache_without_to_thread(self) -> None:
    """When the cache already has the answer, ``expand_async`` returns
    it directly without dispatching to a thread.

    The API call count alone cannot prove this: a dispatch to a thread
    that then hits the cache inside ``expand`` would also leave the
    count at one. The module-level ``_ASYNCIO_TO_THREAD`` alias that
    ``expand_async`` awaits through is therefore patched and asserted
    to be untouched.
    """
    import asyncio
    import sys
    from unittest import mock

    expander = QueryExpander(cache=True)
    mock_client = MagicMock()
    mock_client.messages.create.return_value = _fake_response('["from-cache"]')
    expander._client = mock_client

    # Prime the cache via the sync method
    first = expander.expand("k")
    assert first == ["from-cache"]
    assert mock_client.messages.create.call_count == 1

    # Resolve the module that defines QueryExpander so the patch targets
    # the same global that ``expand_async`` looks up at call time,
    # regardless of how this test file imported the class.
    expander_module = sys.modules[QueryExpander.__module__]
    with mock.patch.object(expander_module, "_ASYNCIO_TO_THREAD") as to_thread:
        # Async hit should pull from the same cache, no second API call
        second = asyncio.run(expander.expand_async("k"))

    assert second == ["from-cache"]
    to_thread.assert_not_called()
    assert mock_client.messages.create.call_count == 1

def test_expand_async_does_not_block_event_loop(self) -> None:
    """``expand_async`` must hand the blocking client call to another
    thread rather than running it on the event-loop thread.

    The fake ``messages.create`` records the identifier of the thread
    it executes on; the coroutine then checks that this differs from
    the thread running the event loop.
    """
    import asyncio
    import threading

    # Filled in by the fake client call with the thread it ran on.
    observed: dict[str, int] = {}

    def fake_create(**_ignored):
        observed["tid"] = threading.get_ident()
        return _fake_response('["slow"]')

    client = MagicMock()
    client.messages.create.side_effect = fake_create

    expander = QueryExpander(cache=False)
    expander._client = client

    async def check_thread() -> None:
        loop_tid = threading.get_ident()
        await expander.expand_async("q")
        # The Anthropic call must have run somewhere other than the
        # event-loop thread.
        assert observed["tid"] != loop_tid

    asyncio.run(check_thread())
Loading