diff --git a/src/attune_rag/expander.py b/src/attune_rag/expander.py
index fe61316..84140e7 100644
--- a/src/attune_rag/expander.py
+++ b/src/attune_rag/expander.py
@@ -2,10 +2,12 @@
 
 from __future__ import annotations
 
+import asyncio
 import json
 import logging
 
 logger = logging.getLogger(__name__)
+__all__ = ["QueryExpander"]
 
 _SYSTEM = """\
 You expand developer queries for a documentation retrieval system.
@@ -88,3 +90,19 @@ def expand(self, query: str) -> list[str]:
         if self._cache is not None:
             self._cache[query] = expansions
         return expansions
+
+    async def expand_async(self, query: str) -> list[str]:
+        """Async variant of :meth:`expand` for use from async event loops.
+
+        Runs the synchronous :meth:`expand` in a worker thread via
+        :func:`asyncio.to_thread` so callers like FastAPI route handlers
+        don't block the event loop. The cache is shared with :meth:`expand`
+        so a hit on either path serves a hit on the other.
+
+        Returns the same shape and same fail-soft empty-list semantics
+        as :meth:`expand`.
+        """
+        # Cache hits are answered inline; only misses pay for a thread hop.
+        if self._cache is not None and query in self._cache:
+            return self._cache[query]
+        return await asyncio.to_thread(self.expand, query)
diff --git a/tests/unit/test_expander_reranker.py b/tests/unit/test_expander_reranker.py
index cce14cc..3ea0b7a 100644
--- a/tests/unit/test_expander_reranker.py
+++ b/tests/unit/test_expander_reranker.py
@@ -339,3 +339,67 @@ def test_expander_and_reranker_compose(self, corpus: FakeCorpus):
         exp_client.messages.create.assert_called_once()
         rer_client.messages.create.assert_called_once()
         assert not result.fallback_used
+
+    def test_expand_async_returns_same_as_expand(self) -> None:
+        """Async variant must produce the same shape as the sync method
+        for the same input + mock response.
+        """
+        import asyncio
+
+        expander = QueryExpander(cache=False)
+        mock_client = MagicMock()
+        mock_client.messages.create.return_value = _fake_response('["a", "b"]')
+        expander._client = mock_client
+        result = asyncio.run(expander.expand_async("q"))
+        assert result == ["a", "b"]
+
+    def test_expand_async_uses_cache_without_to_thread(self) -> None:
+        """When the cache already has the answer, ``expand_async`` serves
+        it directly: no second API call and no dispatch to a thread.
+        """
+        import asyncio
+        from unittest.mock import patch
+
+        expander = QueryExpander(cache=True)
+        mock_client = MagicMock()
+        mock_client.messages.create.return_value = _fake_response('["from-cache"]')
+        expander._client = mock_client
+
+        # Prime the cache via the sync method
+        first = expander.expand("k")
+        assert first == ["from-cache"]
+        assert mock_client.messages.create.call_count == 1
+
+        # Async hit should pull from the same cache: no thread dispatch
+        # and no second API call
+        with patch("asyncio.to_thread") as to_thread:
+            second = asyncio.run(expander.expand_async("k"))
+        to_thread.assert_not_called()
+        assert second == ["from-cache"]
+        assert mock_client.messages.create.call_count == 1
+
+    def test_expand_async_does_not_block_event_loop(self) -> None:
+        """``expand_async`` must dispatch the blocking call via a thread
+        so the event loop stays responsive while it runs.
+        """
+        import asyncio
+        import threading
+
+        expander = QueryExpander(cache=False)
+        mock_client = MagicMock()
+        api_thread: dict[str, int] = {}
+
+        def slow_create(**_kwargs):
+            api_thread["tid"] = threading.get_ident()
+            return _fake_response('["slow"]')
+
+        mock_client.messages.create.side_effect = slow_create
+        expander._client = mock_client
+
+        async def runner() -> None:
+            main_tid = threading.get_ident()
+            await expander.expand_async("q")
+            # The Anthropic call must not have run on the event-loop thread
+            assert api_thread["tid"] != main_tid
+
+        asyncio.run(runner())