Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/attune_rag/corpus/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,18 @@

@dataclass(frozen=True)
class RetrievalEntry:
"""A single corpus entry. Task 1.1 shape; task 1.2 wires loaders."""
"""A single corpus entry. Task 1.1 shape; task 1.2 wires loaders.

The ``_tokens_cache`` field is a per-instance mutable cache used by
:mod:`attune_rag.retrieval` to memoize tokenized representations
(path / summary / content-preview / aliases) so the keyword
retriever doesn't re-tokenize on every query. ``frozen=True``
prevents the field itself from being reassigned but doesn't stop
callers from mutating its contents — exactly what we want for a
write-once-on-first-access cache. Excluded from hash, equality,
and repr so two entries with the same content compare equal even
if one has cached tokens and the other hasn't.
"""

path: str
category: str
Expand All @@ -18,6 +29,12 @@ class RetrievalEntry:
related: tuple[str, ...] = ()
aliases: tuple[str, ...] = ()
metadata: dict[str, Any] = field(default_factory=dict)
_tokens_cache: dict[Any, Any] = field(
default_factory=dict,
compare=False,
hash=False,
repr=False,
)


class AliasInfo(TypedDict):
Expand Down
62 changes: 49 additions & 13 deletions src/attune_rag/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,25 +170,61 @@ class KeywordRetriever:
def _category_weight(self, entry: RetrievalEntry) -> float:
return self.CATEGORY_WEIGHTS.get(entry.category, self.DEFAULT_CATEGORY_WEIGHT)

def _entry_field_tokens(self, entry: RetrievalEntry) -> dict[str, set[str]]:
"""Memoized tokens for the entry's own fields (path/summary/content/aliases).

Stored on the entry itself via the ``_tokens_cache`` sidecar so
the keyword retriever stops re-tokenizing on every query — the
review's primary perf concern. Keyed by ``CONTENT_PREVIEW_CHARS``
so a retriever subclass with a different preview size sees
independent cache entries instead of stale ones.
"""
cache_key = ("field_tokens", self.CONTENT_PREVIEW_CHARS)
cached = entry._tokens_cache.get(cache_key)
if cached is not None:
return cached
tokens = {
"path": _tokenize(entry.path),
"summary": _tokenize(entry.summary or ""),
"content_preview": _tokenize(entry.content[: self.CONTENT_PREVIEW_CHARS]),
"aliases": _tokenize(" ".join(entry.aliases)),
}
entry._tokens_cache[cache_key] = tokens
return tokens

def _related_summary_tokens(self, entry: RetrievalEntry, corpus: CorpusProtocol) -> set[str]:
"""Memoized union of related-entry summary tokens.

Cached on the entry under a key that includes the corpus name so
the same entry surfaced across different corpora gets independent
caches. When a corpus rebuilds, fresh entries are created and
the cache is naturally empty.
"""
cache_key = ("related_tokens", corpus.name)
cached = entry._tokens_cache.get(cache_key)
if cached is not None:
return cached
related_tokens: set[str] = set()
for related_path in entry.related:
related_entry = corpus.get(related_path)
if related_entry is None or not related_entry.summary:
continue
related_tokens |= _tokenize(related_entry.summary)
entry._tokens_cache[cache_key] = related_tokens
return related_tokens

def _score_entry(
self,
query_tokens: set[str],
entry: RetrievalEntry,
corpus: CorpusProtocol,
) -> tuple[float, str]:
path_tokens = _tokenize(entry.path)
summary_tokens = _tokenize(entry.summary or "")
content_preview = entry.content[: self.CONTENT_PREVIEW_CHARS]
content_tokens = _tokenize(content_preview)

related_summary_tokens: set[str] = set()
for related_path in entry.related:
related_entry = corpus.get(related_path)
if related_entry is None or not related_entry.summary:
continue
related_summary_tokens |= _tokenize(related_entry.summary)

aliases_tokens = _tokenize(" ".join(entry.aliases))
field_tokens = self._entry_field_tokens(entry)
path_tokens = field_tokens["path"]
summary_tokens = field_tokens["summary"]
content_tokens = field_tokens["content_preview"]
aliases_tokens = field_tokens["aliases"]
related_summary_tokens = self._related_summary_tokens(entry, corpus)

path_hits_raw = len(query_tokens & path_tokens)
path_hits = min(path_hits_raw, self.PATH_HIT_CAP)
Expand Down
89 changes: 89 additions & 0 deletions tests/unit/test_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,92 @@ def test_stemming_preserves_short_tokens() -> None:
assert _stem("is") == "is"
assert _stem("on") == "on"
assert _stem("bug") == "bug"


# ---------------------------------------------------------------------------
# Per-entry token cache (precomputation on first use, reused on subsequent)
# ---------------------------------------------------------------------------


def test_entry_field_tokens_cached_after_first_call() -> None:
"""``_entry_field_tokens`` must populate ``_tokens_cache`` and reuse it."""
from attune_rag.retrieval import KeywordRetriever

entry = _entry(
path="concepts/example.md",
category="concepts",
summary="explains the example flow",
content="example flow content",
)
retriever = KeywordRetriever()

first = retriever._entry_field_tokens(entry)
second = retriever._entry_field_tokens(entry)

# Same dict instance returned on both calls — no recomputation
assert first is second
# Cache is keyed by ("field_tokens", CONTENT_PREVIEW_CHARS)
assert ("field_tokens", retriever.CONTENT_PREVIEW_CHARS) in entry._tokens_cache


def test_entry_field_tokens_recomputed_when_preview_size_differs() -> None:
"""A retriever with a different ``CONTENT_PREVIEW_CHARS`` keys its
own cache slot, so the two don't collide.
"""
from attune_rag.retrieval import KeywordRetriever

class WidePreview(KeywordRetriever):
CONTENT_PREVIEW_CHARS = 200 # narrower than the 500 default

entry = _entry(
path="concepts/example.md",
category="concepts",
summary="example",
content="long content " * 100,
)

default = KeywordRetriever()._entry_field_tokens(entry)
narrow = WidePreview()._entry_field_tokens(entry)

assert default is not narrow
assert ("field_tokens", 500) in entry._tokens_cache
assert ("field_tokens", 200) in entry._tokens_cache


def test_score_entry_does_not_re_tokenize_on_repeat_calls() -> None:
"""The hot path: scoring the same entry against multiple queries
should tokenize the entry once. Patches ``_tokenize`` to count.
"""
from attune_rag import retrieval as rmod
from attune_rag.retrieval import KeywordRetriever

real_tokenize = rmod._tokenize
calls = 0

def counting(text):
nonlocal calls
calls += 1
return real_tokenize(text)

entry = _entry(
path="concepts/example.md",
category="concepts",
summary="explains the example flow",
content="example flow content",
)
corpus = FakeCorpus([entry])
retriever = KeywordRetriever()

# Prime the cache by calling once with one query
rmod._tokenize = counting
try:
retriever._score_entry({"example"}, entry, corpus)
first_pass = calls
# Subsequent scorings of the SAME entry against different queries
# must not re-tokenize the entry's fields.
retriever._score_entry({"flow"}, entry, corpus)
retriever._score_entry({"content"}, entry, corpus)
finally:
rmod._tokenize = real_tokenize

assert calls == first_pass, "field tokens must not be recomputed across queries"
Loading