diff --git a/src/whygraph/db/migrations/versions/b4de974b9f54_add_rationale_cache.py b/src/whygraph/db/migrations/versions/b4de974b9f54_add_rationale_cache.py new file mode 100644 index 0000000..41ced86 --- /dev/null +++ b/src/whygraph/db/migrations/versions/b4de974b9f54_add_rationale_cache.py @@ -0,0 +1,47 @@ +"""add rationale_cache + +Revision ID: b4de974b9f54 +Revises: 4ebdddf127cf +Create Date: 2026-05-23 19:59:51.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'b4de974b9f54' +down_revision: Union[str, Sequence[str], None] = '4ebdddf127cf' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.create_table('rationale_cache', + sa.Column('path', sa.Text(), nullable=False), + sa.Column('line_start', sa.Integer(), nullable=False), + sa.Column('line_end', sa.Integer(), nullable=False), + sa.Column('provider', sa.Text(), nullable=False), + sa.Column('model', sa.Text(), nullable=False), + sa.Column('evidence_fingerprint', sa.Text(), nullable=False), + sa.Column('cached_at', sa.Text(), nullable=False), + sa.Column('purpose', sa.Text(), nullable=False), + sa.Column('why', sa.Text(), nullable=False), + sa.Column('constraints', sa.Text(), nullable=False), + sa.Column('tradeoffs', sa.Text(), nullable=False), + sa.Column('risks', sa.Text(), nullable=False), + sa.Column('input_tokens', sa.Integer(), nullable=True), + sa.Column('output_tokens', sa.Integer(), nullable=True), + sa.Column('actual_provider', sa.Text(), nullable=True), + sa.Column('actual_model', sa.Text(), nullable=True), + sa.Column('qualified_name', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('path', 'line_start', 'line_end', 'provider', 'model') + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_table('rationale_cache') diff --git a/src/whygraph/db/models/__init__.py b/src/whygraph/db/models/__init__.py index dbffba5..a067aa7 100644 --- a/src/whygraph/db/models/__init__.py +++ b/src/whygraph/db/models/__init__.py @@ -20,18 +20,10 @@ * Float columns use :class:`sqlalchemy.REAL` for the same reason. * Several columns are typed ``str`` even though they hold JSON-encoded Python lists (e.g. ``Commit.parent_shas``, ``PullRequest.labels``, - ``PullRequest.commit_titles``). Callers encode/decode with ``json`` at - the boundary. Moving to a proper JSON column type is a follow-up that - needs a real Alembic migration. - -Current models cover the five tables whose auto-derived snake_case name -does *not* collide with the hand-rolled tables owned by -:mod:`whygraph.scan.db`: ``Author``, ``Commit``, ``Issue``, -``PullRequest``, ``PRIssueLink``. The remaining two scan-owned tables -(``rationale_cache``, ``scan_state``) intentionally have no SQLModel -yet — their natural snake_case names collide with scan/db.py, so they -will get models as part of the eventual scanner-side migration to ORM -rather than today. + ``PullRequest.commit_titles``, ``RationaleCache.constraints``). + Callers encode/decode with ``json`` at the boundary. Moving to a + proper JSON column type is a follow-up that needs a real Alembic + migration. """ from __future__ import annotations @@ -41,5 +33,6 @@ from whygraph.db.models.issue import Issue from whygraph.db.models.pr_issue_link import PRIssueLink from whygraph.db.models.pull_request import PullRequest +from whygraph.db.models.rationale_cache import RationaleCache -__all__ = ["Author", "Commit", "Issue", "PRIssueLink", "PullRequest"] +__all__ = ["Author", "Commit", "Issue", "PRIssueLink", "PullRequest", "RationaleCache"] diff --git a/src/whygraph/db/models/rationale_cache.py b/src/whygraph/db/models/rationale_cache.py new file mode 100644 index 0000000..9e0fadc --- /dev/null +++ b/src/whygraph/db/models/rationale_cache.py @@ -0,0 +1,62 @@ +"""SQLModel for the ``rationale_cache`` table. + +One row per cached LLM-generated rationale, keyed by target plus the +``(provider, model)`` identity of the LLM that produced it. Lookups +happen *after* evidence collection so that a change in the blamed-commit +set — a new commit landing on those lines — invalidates the cache via +the ``evidence_fingerprint`` column without needing TTLs. + +Notes +----- +The list-shaped rationale fields (``constraints``, ``tradeoffs``, +``risks``) are stored as JSON-encoded strings, matching the convention +already used by :attr:`Commit.parent_shas`, :attr:`PullRequest.labels`, +and :attr:`Issue.labels`. Callers encode/decode at the boundary +(:mod:`whygraph.mcp.rationale_cache`). + +``model`` is part of the composite PK; when +:attr:`whygraph.core.config.RationaleConfig.model` is ``None`` the cache +key uses the literal string ``"default"``. The LLM-reported model +identity lands in the separate ``actual_model`` column so rows keyed +under ``"default"`` retain provenance. +""" + +from __future__ import annotations + +from sqlalchemy import Text +from sqlmodel import Field + +from whygraph.db.base import WhygraphTable + + +class RationaleCache(WhygraphTable, table=True): + """Cached :class:`whygraph.analyze.Rationale` for a (target, LLM) pair. + + The composite PK ``(path, line_start, line_end, provider, model)`` + lets two LLMs cache their results for the same target side by side. + ``qualified_name`` is observational only — a path/line target may + have no symbol attached, and including it in the PK would split the + cache between symbol and line-range lookups of the same lines. + """ + + path: str = Field(primary_key=True, sa_type=Text) + line_start: int = Field(primary_key=True) + line_end: int = Field(primary_key=True) + provider: str = Field(primary_key=True, sa_type=Text) + model: str = Field(primary_key=True, sa_type=Text) + + evidence_fingerprint: str = Field(sa_type=Text) + cached_at: str = Field(sa_type=Text) + + purpose: str = Field(sa_type=Text) + why: str = Field(sa_type=Text) + constraints: str = Field(sa_type=Text) # JSON-encoded list[str] + tradeoffs: str = Field(sa_type=Text) # JSON-encoded list[str] + risks: str = Field(sa_type=Text) # JSON-encoded list[str] + + input_tokens: int | None = Field(default=None) + output_tokens: int | None = Field(default=None) + + actual_provider: str | None = Field(default=None, sa_type=Text) + actual_model: str | None = Field(default=None, sa_type=Text) + qualified_name: str | None = Field(default=None, sa_type=Text) diff --git a/src/whygraph/mcp/rationale.py b/src/whygraph/mcp/rationale.py index bdac193..849ab6b 100644 --- a/src/whygraph/mcp/rationale.py +++ b/src/whygraph/mcp/rationale.py @@ -15,9 +15,12 @@ from whygraph.services.codegraph import CodeGraph, CodeGraphError, SymbolContext from whygraph.services.llm import LlmError +from whygraph.analyze import CommitEvidence, Rationale + from .errors import WhyGraphError from .targets import Target, repo_root, resolve_target, target_dict from .evidence import collect_evidence +from .rationale_cache import lookup_cached, store_cached _TOOL_DESCRIPTION = ( "Generate a structured rationale card (purpose / why / constraints / " @@ -48,6 +51,31 @@ def _symbol_context(target: Target) -> SymbolContext | None: return None +def _format_response( + target: Target, + rationale: Rationale, + evidence: list[CommitEvidence], + cached_at: str, +) -> dict: + """Shape the MCP response payload around a (fresh or cached) rationale.""" + return { + "target": target_dict(target), + "purpose": rationale.purpose, + "why": rationale.why, + "constraints": list(rationale.constraints), + "tradeoffs": list(rationale.tradeoffs), + "risks": list(rationale.risks), + "model": rationale.model, + "provider": rationale.provider, + "cached_at": cached_at, + "evidence_count": { + "commits": len(evidence), + "prs": sum(len(item.pull_requests) for item in evidence), + "issues": sum(len(item.issues) for item in evidence), + }, + } + + def whygraph_rationale_brief( path: str | None = None, line_start: int | None = None, @@ -57,6 +85,10 @@ def whygraph_rationale_brief( """MCP tool — a rationale card for a chunk of code. See :data:`_TOOL_DESCRIPTION` for the agent-facing summary. + + A previously generated card is returned from the SQLite-backed cache + (see :mod:`whygraph.mcp.rationale_cache`) when the same target, + provider, model, and evidence fingerprint are all unchanged. """ target = resolve_target( path=path, @@ -71,27 +103,20 @@ def whygraph_rationale_brief( "scanned commit. Run `whygraph scan` to populate the database." ) + config = get_config().rationale + cached = lookup_cached(target, evidence, config.provider, config.model) + if cached is not None: + rationale, cached_at = cached + return _format_response(target, rationale, evidence, cached_at) + try: - generator = RationaleGenerator.from_config(get_config().rationale) + generator = RationaleGenerator.from_config(config) rationale = generator.generate(evidence, symbol_context=_symbol_context(target)) except (AnalyzeError, LlmError) as exc: raise WhyGraphError.wrap("rationale generation failed", exc) - return { - "target": target_dict(target), - "purpose": rationale.purpose, - "why": rationale.why, - "constraints": list(rationale.constraints), - "tradeoffs": list(rationale.tradeoffs), - "risks": list(rationale.risks), - "model": rationale.model, - "provider": rationale.provider, - "evidence_count": { - "commits": len(evidence), - "prs": sum(len(item.pull_requests) for item in evidence), - "issues": sum(len(item.issues) for item in evidence), - }, - } + cached_at = store_cached(target, evidence, rationale, config.provider, config.model) + return _format_response(target, rationale, evidence, cached_at) def register(mcp: FastMCP) -> None: diff --git a/src/whygraph/mcp/rationale_cache.py b/src/whygraph/mcp/rationale_cache.py new file mode 100644 index 0000000..76f9d86 --- /dev/null +++ b/src/whygraph/mcp/rationale_cache.py @@ -0,0 +1,169 @@ +"""SQLite-backed cache for ``whygraph_rationale_brief`` LLM output. + +Keys a cached card by ``(path, line_start, line_end, provider, model)`` +and invalidates by an *evidence fingerprint* — the sha256 of the sorted +commit SHAs returned by :func:`whygraph.mcp.evidence.collect_evidence`. +A new commit landing on the blamed lines changes the fingerprint and +forces a regeneration on the next call; the stale row is overwritten by +:func:`store_cached`. + +Notes +----- +:attr:`whygraph.core.config.RationaleConfig.model` can be ``None`` +(meaning *use whatever model the provider's adapter defaults to*). The +cache PK still needs a deterministic ``model`` token at lookup time — +*before* the LLM call returns and reports its actual model identity — so +``None`` is translated to the literal string ``"default"`` via +:func:`_model_key`. The LLM-reported identity is persisted separately in +:attr:`RationaleCache.actual_model` so rows keyed under ``"default"`` +keep their provenance. Pinning ``rationale.model`` in ``whygraph.toml`` +gives the cleanest per-model cache semantics. + +The fingerprint is computed only over commit SHAs — PR/issue updates +that don't change the underlying commit set do not invalidate the cache. +That matches :func:`collect_evidence`'s own derivation: PRs and issues +are looked up *from* the commit set. +""" + +from __future__ import annotations + +import hashlib +import json +from datetime import datetime, timezone +from typing import Sequence + +from whygraph.analyze import CommitEvidence, Rationale +from whygraph.db import get_session +from whygraph.db.models import RationaleCache + +from .targets import Target + + +_DEFAULT_MODEL_TAG = "default" + + +def _fingerprint(evidence: Sequence[CommitEvidence]) -> str: + """sha256 of the newline-joined, sorted commit SHAs from ``evidence``. + + Sorting decouples the fingerprint from :func:`collect_evidence`'s + return order, so a stable evidence set hashes to the same value + regardless of timing or future ordering tweaks. + """ + payload = "\n".join(sorted(item.commit.sha for item in evidence)) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + +def _model_key(config_model: str | None) -> str: + """Translate a (possibly absent) configured model into a cache-key token.""" + return config_model if config_model else _DEFAULT_MODEL_TAG + + +def _now_iso() -> str: + """UTC ISO-8601 timestamp at second resolution. + + Matches the timestamp shape used by :attr:`Commit.committed_at`, + :attr:`Issue.created_at`, and the rest of the WhyGraph schema. + """ + return datetime.now(timezone.utc).isoformat(timespec="seconds") + + +def lookup_cached( + target: Target, + evidence: Sequence[CommitEvidence], + provider: str, + config_model: str | None, +) -> tuple[Rationale, str] | None: + """Return ``(rationale, cached_at)`` for ``target`` or ``None`` on miss. + + A hit requires the row's ``evidence_fingerprint`` to match the + fingerprint of ``evidence``; a row with a stale fingerprint is + treated as a miss and will be overwritten by the next + :func:`store_cached` call. + + Parameters + ---------- + target + Resolved target (path + line range + optional symbol name). + evidence + The evidence sequence as returned by + :func:`whygraph.mcp.evidence.collect_evidence` — its commit SHAs + drive the fingerprint check. + provider + LLM provider tag from :attr:`RationaleConfig.provider`. + config_model + Configured model name (or ``None`` for the provider default); + translated to the cache-key token via :func:`_model_key`. + + Returns + ------- + tuple of (Rationale, str), or None + Reconstructed :class:`~whygraph.analyze.Rationale` and the ISO + timestamp the row was originally cached at. ``None`` if no row + matches the key or its fingerprint is stale. + """ + fp = _fingerprint(evidence) + model_key = _model_key(config_model) + with get_session() as session: + row = session.get( + RationaleCache, + (target.path, target.line_start, target.line_end, provider, model_key), + ) + if row is None or row.evidence_fingerprint != fp: + return None + rationale = Rationale( + purpose=row.purpose, + why=row.why, + constraints=tuple(json.loads(row.constraints)), + tradeoffs=tuple(json.loads(row.tradeoffs)), + risks=tuple(json.loads(row.risks)), + model=row.actual_model or row.model, + provider=row.actual_provider or row.provider, + input_tokens=row.input_tokens, + output_tokens=row.output_tokens, + ) + cached_at = row.cached_at + return rationale, cached_at + + +def store_cached( + target: Target, + evidence: Sequence[CommitEvidence], + rationale: Rationale, + provider: str, + config_model: str | None, +) -> str: + """Upsert a freshly generated rationale; return its ``cached_at``. + + The PK ``(provider, model)`` columns mirror the *configured* LLM + identity passed in by the caller — they must match what + :func:`lookup_cached` will use, since the lookup happens *before* + the LLM is invoked. The LLM-reported model identity is persisted + separately in :attr:`RationaleCache.actual_model` so rows keyed + under the ``"default"`` model tag retain provenance. + """ + fp = _fingerprint(evidence) + model_key = _model_key(config_model) + cached_at = _now_iso() + with get_session() as session: + session.merge( + RationaleCache( + path=target.path, + line_start=target.line_start, + line_end=target.line_end, + provider=provider, + model=model_key, + evidence_fingerprint=fp, + cached_at=cached_at, + purpose=rationale.purpose, + why=rationale.why, + constraints=json.dumps(list(rationale.constraints)), + tradeoffs=json.dumps(list(rationale.tradeoffs)), + risks=json.dumps(list(rationale.risks)), + input_tokens=rationale.input_tokens, + output_tokens=rationale.output_tokens, + actual_provider=rationale.provider, + actual_model=rationale.model, + qualified_name=target.qualified_name, + ) + ) + return cached_at diff --git a/tests/test_db_plumbing.py b/tests/test_db_plumbing.py index f34470b..a7f91a8 100644 --- a/tests/test_db_plumbing.py +++ b/tests/test_db_plumbing.py @@ -19,7 +19,14 @@ from whygraph.db import engine as db_engine from whygraph.db.bootstrap import alembic_config -SQLMODEL_TABLES = {"author", "commit", "issue", "pr_issue_link", "pull_request"} +SQLMODEL_TABLES = { + "author", + "commit", + "issue", + "pr_issue_link", + "pull_request", + "rationale_cache", +} @pytest.fixture(autouse=True) diff --git a/tests/test_mcp_rationale_cache.py b/tests/test_mcp_rationale_cache.py new file mode 100644 index 0000000..b1d4bc0 --- /dev/null +++ b/tests/test_mcp_rationale_cache.py @@ -0,0 +1,200 @@ +"""Tests for the SQLite-backed rationale cache. + +Exercises :mod:`whygraph.mcp.rationale_cache` end-to-end via the +:func:`whygraph_rationale_brief` MCP tool — the LLM round-trip is +stubbed (``_CountingGenerator``) so we can assert exact call counts +across repeat invocations and an invalidation event. +""" + +from __future__ import annotations + +import json +import subprocess +from collections.abc import Sequence +from pathlib import Path + +import pytest + +from whygraph.analyze import CommitEvidence, Rationale +from whygraph.db import get_session +from whygraph.db.models import Commit, RationaleCache +from whygraph.mcp.rationale import whygraph_rationale_brief +from whygraph.mcp.rationale_cache import _fingerprint, lookup_cached +from whygraph.mcp.targets import Target +from whygraph.services.codegraph import SymbolContext +from whygraph.services.git import Repository + + +def _db_commit(sha: str, *, committed_at: str) -> Commit: + """A WhyGraph ``commit`` row with sensible defaults for tests.""" + return Commit( + sha=sha, + parent_shas="", + author_name="Test User", + author_email="tester@example.com", + authored_at=committed_at, + committed_at=committed_at, + subject="a change", + body="", + files_changed=1, + insertions=1, + deletions=0, + scanned_at="2026-05-01T00:00:00+00:00", + llm_description="Mechanical diff summary.", + ) + + +def _seed_two_commits(repo_root: Path) -> None: + """Seed the WhyGraph DB with the two commits of ``temp_git_repo``.""" + newest, oldest = list(Repository(repo_root).commits) + with get_session() as session: + session.add(_db_commit(oldest.sha, committed_at="2026-01-01T00:00:00+00:00")) + session.add(_db_commit(newest.sha, committed_at="2026-02-01T00:00:00+00:00")) + + +def _add_third_commit(repo_root: Path) -> None: + """Land one more commit on ``sample.py`` and seed it into the DB. + + Rewrites line 2 so a blame of lines 1-3 now spans *three* commits + instead of two — the evidence fingerprint changes and a previously + cached card must be regenerated. + """ + sample = repo_root / "sample.py" + sample.write_text("line one\nline two updated\nline three\n") + subprocess.run( + ["git", "add", "sample.py"], cwd=repo_root, check=True, capture_output=True + ) + subprocess.run( + ["git", "commit", "-m", "third commit"], + cwd=repo_root, + check=True, + capture_output=True, + ) + newest = list(Repository(repo_root).commits)[0] + with get_session() as session: + session.add(_db_commit(newest.sha, committed_at="2026-03-01T00:00:00+00:00")) + + +class _CountingGenerator: + """Stub for :class:`RationaleGenerator` — counts how often it's called.""" + + calls = 0 + + @classmethod + def reset(cls) -> None: + cls.calls = 0 + + @classmethod + def from_config(cls, config: object) -> "_CountingGenerator": + return cls() + + def generate( + self, + evidence: Sequence[CommitEvidence], + *, + symbol_context: SymbolContext | None = None, + ) -> Rationale: + type(self).calls += 1 + return Rationale( + purpose="Holds two sample lines.", + why="Built up across two commits.", + constraints=("keep it small",), + tradeoffs=(), + risks=("none worth noting",), + model="fake-1", + provider="fake", + ) + + +def test_second_call_returns_cached( + temp_git_repo: Path, + whygraph_db_initialized: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + _seed_two_commits(temp_git_repo) + monkeypatch.chdir(temp_git_repo) + _CountingGenerator.reset() + monkeypatch.setattr( + "whygraph.mcp.rationale.RationaleGenerator", _CountingGenerator + ) + + first = whygraph_rationale_brief(path="sample.py", line_start=1, line_end=3) + second = whygraph_rationale_brief(path="sample.py", line_start=1, line_end=3) + + assert _CountingGenerator.calls == 1 + assert first["cached_at"] == second["cached_at"] + assert first["purpose"] == second["purpose"] == "Holds two sample lines." + assert first["constraints"] == second["constraints"] == ["keep it small"] + assert first["model"] == second["model"] == "fake-1" + + +def test_new_commit_invalidates_cache( + temp_git_repo: Path, + whygraph_db_initialized: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + _seed_two_commits(temp_git_repo) + monkeypatch.chdir(temp_git_repo) + _CountingGenerator.reset() + monkeypatch.setattr( + "whygraph.mcp.rationale.RationaleGenerator", _CountingGenerator + ) + + first = whygraph_rationale_brief(path="sample.py", line_start=1, line_end=3) + _add_third_commit(temp_git_repo) + second = whygraph_rationale_brief(path="sample.py", line_start=1, line_end=3) + + assert _CountingGenerator.calls == 2 + assert first["evidence_count"]["commits"] == 2 + assert second["evidence_count"]["commits"] == 3 + # ``cached_at`` is the row's persisted timestamp; the regenerated + # entry overwrites the old one, so the second call's stamp is + # always >= the first. + assert second["cached_at"] >= first["cached_at"] + + +def test_lookup_cached_returns_none_on_stale_fingerprint( + temp_git_repo: Path, + whygraph_db_initialized: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Stale ``evidence_fingerprint`` is treated as a miss, not a hit.""" + from whygraph.mcp.evidence import collect_evidence + + _seed_two_commits(temp_git_repo) + monkeypatch.chdir(temp_git_repo) + target = Target(path="sample.py", line_start=1, line_end=3, qualified_name=None) + evidence = collect_evidence(target, limit=20) + assert evidence, "fixture must produce at least one evidence row" + + with get_session() as session: + session.add( + RationaleCache( + path="sample.py", + line_start=1, + line_end=3, + provider="fake", + model="default", + evidence_fingerprint="bogus-fingerprint", + cached_at="2026-01-01T00:00:00+00:00", + purpose="stale", + why="stale", + constraints=json.dumps([]), + tradeoffs=json.dumps([]), + risks=json.dumps([]), + actual_model="fake-1", + ) + ) + + assert lookup_cached(target, evidence, "fake", None) is None + + +def test_fingerprint_independent_of_evidence_order() -> None: + """``_fingerprint`` sorts SHAs — order in the input list must not matter.""" + e1 = [ + CommitEvidence(commit=_db_commit("aaa", committed_at="t1")), + CommitEvidence(commit=_db_commit("bbb", committed_at="t2")), + CommitEvidence(commit=_db_commit("ccc", committed_at="t3")), + ] + e2 = list(reversed(e1)) + assert _fingerprint(e1) == _fingerprint(e2)