diff --git a/pyproject.toml b/pyproject.toml index eca6d7c..18d0bc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "attune-author" -version = "0.5.1" +version = "0.6.0" description = "Documentation authoring and maintenance for the attune ecosystem — generate, maintain, and validate help content with AI assistance." readme = {file = "README.md", content-type = "text/markdown"} requires-python = ">=3.10" diff --git a/src/attune_author/cli.py b/src/attune_author/cli.py index 6ff6d87..24f6610 100644 --- a/src/attune_author/cli.py +++ b/src/attune_author/cli.py @@ -169,6 +169,27 @@ def _build_parser() -> argparse.ArgumentParser: help="Report stale features without regenerating.", ) + p_cache = sub.add_parser( + "cache", + help="Manage the on-disk polish cache", + description=( + "Inspect and clear the on-disk LLM polish cache used by the " + "generator. Entries are pruned automatically by mtime (default " + "TTL 30 days, configurable via ATTUNE_AUTHOR_POLISH_CACHE_TTL_SECONDS); " + "this command exposes a manual nuke." + ), + ) + cache_sub = p_cache.add_subparsers(dest="cache_command", help="Cache subcommands") + cache_sub.add_parser( + "clear", + help="Delete every cached polish entry", + description=( + "Remove all entries from the polish cache directory. Useful " + "after a prompt change in attune-author itself, or to reclaim " + "disk space without waiting for the TTL sweep." 
+ ), + ) + p_docs = sub.add_parser( "docs", help="Generate docs from source (requires [ai])", @@ -220,6 +241,7 @@ def _dispatch(args: argparse.Namespace, parser: argparse.ArgumentParser) -> int: "generate": _cmd_generate, "regenerate": _cmd_regenerate, "docs": _cmd_docs, + "cache": _cmd_cache, } handler = handlers.get(args.command) if handler is None: @@ -403,6 +425,20 @@ def _cmd_regenerate(args: argparse.Namespace) -> int: return 0 +def _cmd_cache(args: argparse.Namespace) -> int: + """Handle the cache command and its subcommands.""" + from attune_author.polish import _cache_dir, clear_cache + + if args.cache_command == "clear": + deleted = clear_cache() + cache_path = _cache_dir() + print(f"Cleared {deleted} entries from {cache_path}") + return 0 + + print("Usage: attune-author cache clear", file=sys.stderr) + return 1 + + def _cmd_docs(args: argparse.Namespace) -> int: """Handle the docs command.""" if not args.target: diff --git a/src/attune_author/doc_gen/_anthropic.py b/src/attune_author/doc_gen/_anthropic.py index 1194c0a..7d24579 100644 --- a/src/attune_author/doc_gen/_anthropic.py +++ b/src/attune_author/doc_gen/_anthropic.py @@ -12,6 +12,7 @@ import logging import os import re +import time from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -19,6 +20,9 @@ logger = logging.getLogger(__name__) +_MAX_RETRIES = 3 +_RETRY_BASE_DELAY = 1.0 # seconds; doubles each attempt + #: Source-content character budgets per doc-gen stage. 
Tuned so #: the outline and review stages see enough code for accuracy #: without dominating the prompt context, while the write stage @@ -55,6 +59,19 @@ def _redact(text: str) -> str: return _KEY_PATTERN.sub(_REDACTED, text) +def _is_retryable(exc: Exception) -> bool: + """Return True for transient Anthropic errors that are safe to retry.""" + try: + from anthropic import APIConnectionError, APIStatusError + except ImportError: + return False + if isinstance(exc, APIConnectionError): + return True + if isinstance(exc, APIStatusError): + return exc.status_code in (429, 529) + return False + + def get_client(api_key: str | None = None) -> Anthropic: """Instantiate an Anthropic client. @@ -85,11 +102,11 @@ def call_anthropic( model: str, max_tokens: int, ) -> str: - """Make a single-turn ``messages.create`` call. + """Make a single-turn ``messages.create`` call with retry/backoff. - Wraps the SDK call so every caller shares identical error - handling, message shape, and response unwrapping. Any - exception raised by the SDK is re-raised as + Retries up to ``_MAX_RETRIES`` times on transient errors (rate + limits and overload responses). Non-transient SDK errors fail + immediately. All exceptions are re-raised as :class:`AnthropicCallError` with a redacted message and an empty ``__cause__`` chain to guarantee credential material cannot leak through ``str(exc.__cause__)``. @@ -106,23 +123,40 @@ def call_anthropic( string if the response carried no content. Raises: - AnthropicCallError: On any SDK or transport failure. + AnthropicCallError: On any SDK or transport failure after + retries are exhausted. 
""" - try: - response = client.messages.create( - model=model, - max_tokens=max_tokens, - system=system, - messages=[{"role": "user", "content": user_message}], - ) - except Exception as exc: # noqa: BLE001 - # INTENTIONAL: every SDK exception type funnels through - # one redaction pass so credential material can't leak - # into logs, error surfaces, or upstream exception - # chains. `from None` strips __cause__ so callers - # inspecting the chain only ever see the redacted form. - raise AnthropicCallError(_redact(str(exc))) from None - - if response.content: - return response.content[0].text - return "" + last_exc: Exception | None = None + for attempt in range(_MAX_RETRIES + 1): + if attempt: + delay = _RETRY_BASE_DELAY * (2 ** (attempt - 1)) + logger.warning( + "Anthropic call failed (attempt %d/%d), retrying in %.1fs: %s", + attempt, + _MAX_RETRIES, + delay, + _redact(str(last_exc)), + ) + time.sleep(delay) + try: + response = client.messages.create( + model=model, + max_tokens=max_tokens, + system=system, + messages=[{"role": "user", "content": user_message}], + ) + if response.content: + return response.content[0].text + return "" + except Exception as exc: # noqa: BLE001 + # INTENTIONAL: every SDK exception type funnels through + # one redaction pass so credential material can't leak + # into logs, error surfaces, or upstream exception + # chains. `from None` strips __cause__ so callers + # inspecting the chain only ever see the redacted form. 
+ if _is_retryable(exc): + last_exc = exc + continue + raise AnthropicCallError(_redact(str(exc))) from None + + raise AnthropicCallError(_redact(str(last_exc))) from None diff --git a/src/attune_author/generator.py b/src/attune_author/generator.py index 9078f32..b4eb25d 100644 --- a/src/attune_author/generator.py +++ b/src/attune_author/generator.py @@ -14,6 +14,7 @@ import ast import logging +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path @@ -25,6 +26,55 @@ logger = logging.getLogger(__name__) +#: Cap on concurrent LLM calls during the parallel polish phase. +#: Sized to comfortably fit under Anthropic's per-minute rate +#: limits while still saturating the LLM-bound wall time of a +#: typical ``regenerate --all-kinds`` run. +_POLISH_MAX_WORKERS = 4 + + +def _parallel_polish( + pending: list[tuple[str, str, Path]], + feature: object, + source_info: object, + use_rag: bool, +) -> dict[str, tuple[str, Path]]: + """Polish a batch of rendered templates concurrently. + + Args: + pending: List of (depth, rendered_content, out_path) tuples. + feature: Feature being documented (read-only, thread-safe). + source_info: Extracted source info (read-only, thread-safe). + use_rag: Whether to use RAG grounding during polish. + + Returns: + Mapping of depth -> (polished_content, out_path). Raises + the first exception encountered (propagated from the future). 
+ """ + + def _task(depth: str, content: str, out_path: Path) -> tuple[str, str, Path]: + polished = _maybe_polish( + content, + feature, # type: ignore[arg-type] + source_info, # type: ignore[arg-type] + template_type=depth, + use_rag=use_rag, + ) + return depth, polished, out_path + + results: dict[str, tuple[str, Path]] = {} + workers = min(len(pending), _POLISH_MAX_WORKERS) + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = { + executor.submit(_task, depth, content, out_path): depth + for depth, content, out_path in pending + } + for future in as_completed(futures): + depth, polished, out_path = future.result() + results[depth] = (polished, out_path) + return results + + #: Core progressive-depth template kinds. These form the #: progressive disclosure path that attune-help renders: #: concept → task → reference. They are generated by @@ -234,6 +284,9 @@ def generate_feature_templates( ", ".join(feature.doc_paths[1:]), ) + # Phase 1: render all templates (fast Jinja2, sequential). + # Determines which depths are active and builds the rendered skeleton. + pending: list[tuple[str, str, Path]] = [] for depth in target_depths: if depth not in _ALL_TEMPLATE_NAMES: logger.warning("Unknown template kind '%s', skipping", depth) @@ -278,17 +331,15 @@ def generate_feature_templates( source_hash=source_hash, source_info=source_info, ) + pending.append((depth, content, out_path)) - # LLM polish pass — improves writing quality - content = _maybe_polish( - content, - feature, - source_info, - template_type=depth, - use_rag=use_rag, - ) + # Phase 2: LLM polish — run all depths concurrently. + polished = _parallel_polish(pending, feature, source_info, use_rag) - out_path.write_text(content, encoding="utf-8") + # Phase 3: write results in original depth order. 
+ for depth, content, out_path in pending: + final_content, _ = polished[depth] + out_path.write_text(final_content, encoding="utf-8") result.templates.append( GeneratedTemplate( feature=feature.name, diff --git a/src/attune_author/polish.py b/src/attune_author/polish.py index 215c438..1d3e65b 100644 --- a/src/attune_author/polish.py +++ b/src/attune_author/polish.py @@ -34,8 +34,11 @@ from __future__ import annotations +import hashlib import logging import os +import time +from pathlib import Path from attune_author.doc_gen._anthropic import ( AnthropicCallError, @@ -44,6 +47,149 @@ ) from attune_author.polish_prompts import get_system_prompt +#: Anthropic model used for the polish pass. Hoisted to a +#: module-level constant so it participates in the cache key — +#: bumping the model invalidates cache entries automatically. +_POLISH_MODEL = "claude-sonnet-4-20250514" + +#: Env var that overrides the default polish cache directory. +_CACHE_DIR_ENV = "ATTUNE_AUTHOR_POLISH_CACHE" +_CACHE_DIR_DEFAULT = Path.home() / ".attune" / "polish_cache" + +#: Env var (in seconds) that overrides the default cache-entry +#: TTL. ``0`` disables pruning entirely; negatives and unparseable +#: values fall back to the default. Tracked via mtime, which we +#: bump on every cache hit so heat is observed reliably even on +#: filesystems mounted ``noatime``. +_CACHE_TTL_ENV = "ATTUNE_AUTHOR_POLISH_CACHE_TTL_SECONDS" +_CACHE_TTL_DEFAULT_SECONDS = 30 * 24 * 60 * 60 # 30 days + + +def _cache_dir() -> Path: + env = os.environ.get(_CACHE_DIR_ENV, "").strip() + return Path(env) if env else _CACHE_DIR_DEFAULT + + +def _cache_ttl_seconds() -> int: + """Read the cache TTL from the environment. + + Returns the default (30 days) when the env var is unset, + blank, negative, or unparseable. ``0`` is honored as + "disable pruning" and returned verbatim. 
+ """ + raw = os.environ.get(_CACHE_TTL_ENV, "").strip() + if not raw: + return _CACHE_TTL_DEFAULT_SECONDS + try: + val = int(raw) + except ValueError: + return _CACHE_TTL_DEFAULT_SECONDS + return val if val >= 0 else _CACHE_TTL_DEFAULT_SECONDS + + +def _cache_key(*parts: str) -> str: + h = hashlib.sha256() + for part in parts: + h.update(part.encode()) + h.update(b"\x00") + return h.hexdigest() + + +def _cache_get(key: str) -> str | None: + path = _cache_dir() / f"{key}.md" + try: + content = path.read_text(encoding="utf-8") + except OSError: + return None + # Bump mtime so the prune sweeper treats this entry as hot. + # Best-effort: a permission error here just means the entry + # may age out sooner than its actual access pattern warrants, + # which is acceptable degradation. + try: + os.utime(path, None) + except OSError: + pass + return content + + +def _cache_put(key: str, content: str) -> None: + cache = _cache_dir() + cache.mkdir(parents=True, exist_ok=True) + tmp = cache / f"{key}.tmp" + tmp.write_text(content, encoding="utf-8") + tmp.rename(cache / f"{key}.md") + # Lazy disk-space hygiene: piggyback the prune onto write + # operations so we don't need a separate scheduler. Cost is + # one stat per entry in the cache dir, which is negligible + # for the entry counts a single project produces. + try: + _cache_prune(cache) + except OSError: + # INTENTIONAL: prune failure must not break a successful + # write — the cache stays usable, it just doesn't shrink. + pass + + +def _cache_prune(cache: Path | None = None) -> int: + """Delete cache entries whose mtime is older than the TTL. + + Args: + cache: Cache directory to sweep. Defaults to + :func:`_cache_dir`. Accepting an explicit path keeps + the function trivially testable without monkey-patching + the env. + + Returns: + Count of files deleted. ``0`` when the cache directory + does not exist or the TTL is set to ``0`` (disabled). 
+ """ + cache = cache if cache is not None else _cache_dir() + ttl = _cache_ttl_seconds() + if ttl <= 0 or not cache.exists(): + return 0 + cutoff = time.time() - ttl + deleted = 0 + for entry in cache.iterdir(): + if entry.suffix != ".md": + continue + try: + if entry.stat().st_mtime < cutoff: + entry.unlink() + deleted += 1 + except OSError: + # Swallow per-entry failures so a single locked or + # vanished file doesn't block the rest of the sweep. + continue + return deleted + + +def clear_cache() -> int: + """Delete every entry in the polish cache directory. + + Removes both finalized ``.md`` entries and any leftover + ``.tmp`` files from interrupted writes. Best-effort: per-entry + failures are swallowed so a single locked file does not block + the remaining deletions. + + Returns: + Count of files deleted. ``0`` when the cache directory + does not exist. + """ + cache = _cache_dir() + if not cache.exists(): + return 0 + deleted = 0 + for entry in cache.iterdir(): + if entry.suffix not in {".md", ".tmp"}: + continue + try: + entry.unlink() + deleted += 1 + except OSError: + continue + return deleted + + logger = logging.getLogger(__name__) #: Environment variable that flips polish out of strict mode. 
@@ -124,6 +270,20 @@ def polish_template( """ effective_strict = _env_strict_default() if strict is None else strict + system_prompt = get_system_prompt(template_type) + key = _cache_key( + content, + source_summary, + template_type, + system_prompt, + augmented_context or "", + _POLISH_MODEL, + ) + cached = _cache_get(key) + if cached is not None: + logger.debug("Polish cache hit for %s/%s", feature_name, template_type) + return cached + try: polished = _call_llm( content, @@ -132,7 +292,12 @@ def polish_template( template_type, augmented_context=augmented_context, ) - return _sanitize_output(polished) + result = _sanitize_output(polished) + try: + _cache_put(key, result) + except OSError as cache_exc: + logger.debug("Polish cache write failed (non-fatal): %s", cache_exc) + return result except Exception as exc: # noqa: BLE001 # INTENTIONAL: lenient mode swallows any LLM failure # so that `attune-author` can still run without an @@ -229,7 +394,7 @@ def _call_llm( client, system=system_prompt, user_message=user_message, - model="claude-sonnet-4-20250514", + model=_POLISH_MODEL, max_tokens=4096, ) return polished or content @@ -244,6 +409,7 @@ def _call_llm( "STRICT_ENV_VAR", "_env_strict_default", "build_source_summary", + "clear_cache", "polish_template", ] diff --git a/src/attune_author/rag_hook.py b/src/attune_author/rag_hook.py index 19f38ce..7875377 100644 --- a/src/attune_author/rag_hook.py +++ b/src/attune_author/rag_hook.py @@ -27,11 +27,31 @@ import logging import os +import threading logger = logging.getLogger(__name__) _DISABLE_ENV = "ATTUNE_AUTHOR_RAG" +# Module-level singleton so the corpus is loaded once per process rather than +# once per template kind. Thread-safe: _PIPELINE_LOCK guards first-time +# construction; subsequent reads need no lock (pipeline is immutable after init). 
+_PIPELINE = None # type: ignore[assignment] # RagPipeline | None +_PIPELINE_LOCK = threading.Lock() + + +def _get_pipeline(): # type: ignore[return] + """Return the process-level RagPipeline, constructing it on first call.""" + global _PIPELINE + if _PIPELINE is not None: + return _PIPELINE + with _PIPELINE_LOCK: + if _PIPELINE is None: + from attune_rag import RagPipeline + + _PIPELINE = RagPipeline() + return _PIPELINE + def rag_enabled() -> bool: """Return True when RAG grounding should be used. @@ -91,15 +111,10 @@ def ground_polish_context( if not rag_enabled(): return None - try: - from attune_rag import RagPipeline - except ImportError: - return None - query = f"{template_type} template for {feature_name}" try: - pipeline = RagPipeline() + pipeline = _get_pipeline() result = pipeline.run(query, k=k) except Exception: # noqa: BLE001 # INTENTIONAL: grounding is best-effort. Any diff --git a/tests/conftest.py b/tests/conftest.py index 24aa02f..e3da073 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,19 @@ def _lenient_polish_by_default(monkeypatch: pytest.MonkeyPatch) -> Iterator[None yield +@pytest.fixture(autouse=True) +def _reset_rag_pipeline(monkeypatch: pytest.MonkeyPatch) -> None: + """Reset the module-level RagPipeline singleton before each test. + + ground_polish_context() caches the pipeline after first construction. + Tests that patch attune_rag.RagPipeline need the singleton to be None + so the patch intercepts construction rather than being bypassed. + """ + import attune_author.rag_hook as _rh + + monkeypatch.setattr(_rh, "_PIPELINE", None) + + @pytest.fixture def help_dir(tmp_path: Path) -> Path: """Create a .help/ directory with a features.yaml.""" diff --git a/tests/test_anthropic_retry.py b/tests/test_anthropic_retry.py new file mode 100644 index 0000000..79234ec --- /dev/null +++ b/tests/test_anthropic_retry.py @@ -0,0 +1,234 @@ +"""Tests for the retry/backoff logic in attune_author.doc_gen._anthropic. 
+ +Covers the three behaviors that ``call_anthropic`` is supposed +to guarantee: + +1. Transient SDK errors (429, 529, ``APIConnectionError``) are + retried up to ``_MAX_RETRIES`` with exponential backoff. +2. Non-retryable SDK errors raise immediately. +3. Credential material in exception text is redacted before + the wrapped ``AnthropicCallError`` surfaces. + +``time.sleep`` is patched in every test so the suite runs fast +and the backoff schedule is asserted via the mock's call args +rather than wall-clock time. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + + +def _api_status_error(status: int, message: str = "transient") -> Exception: + """Build an anthropic.APIStatusError with the given status code. + + The real exception class needs a response object with a + status_code; we hand it a MagicMock so the call_anthropic + code path that reads ``exc.status_code`` works as it would + against a live SDK error. + """ + from anthropic import APIStatusError + + response = MagicMock() + response.status_code = status + response.headers = {} + err = APIStatusError(message, response=response, body=None) + # APIStatusError forwards .status_code from the response. 
+ return err + + +def _api_connection_error(message: str = "connection reset") -> Exception: + from anthropic import APIConnectionError + + return APIConnectionError(request=MagicMock()) + + +def _ok_response(text: str) -> MagicMock: + block = MagicMock() + block.text = text + response = MagicMock() + response.content = [block] + return response + + +class TestRetryableErrors: + """429 and 529 should be retried up to _MAX_RETRIES.""" + + def test_retries_on_429_then_succeeds(self) -> None: + from attune_author.doc_gen import _anthropic + + client = MagicMock() + client.messages.create.side_effect = [ + _api_status_error(429), + _ok_response("hello"), + ] + + with patch.object(_anthropic.time, "sleep") as mock_sleep: + result = _anthropic.call_anthropic( + client, + system="sys", + user_message="hi", + model="m", + max_tokens=10, + ) + + assert result == "hello" + assert client.messages.create.call_count == 2 + mock_sleep.assert_called_once_with(1.0) # base delay on first retry + + def test_retries_on_529_overload(self) -> None: + from attune_author.doc_gen import _anthropic + + client = MagicMock() + client.messages.create.side_effect = [ + _api_status_error(529), + _ok_response("ok"), + ] + + with patch.object(_anthropic.time, "sleep"): + result = _anthropic.call_anthropic( + client, system="s", user_message="u", model="m", max_tokens=10 + ) + + assert result == "ok" + + def test_retries_on_connection_error(self) -> None: + from attune_author.doc_gen import _anthropic + + client = MagicMock() + client.messages.create.side_effect = [ + _api_connection_error(), + _ok_response("recovered"), + ] + + with patch.object(_anthropic.time, "sleep"): + result = _anthropic.call_anthropic( + client, system="s", user_message="u", model="m", max_tokens=10 + ) + + assert result == "recovered" + + def test_exponential_backoff_schedule(self) -> None: + """Two retries should sleep 1s then 2s before the next attempt.""" + from attune_author.doc_gen import _anthropic + + client = 
MagicMock() + client.messages.create.side_effect = [ + _api_status_error(429), + _api_status_error(429), + _ok_response("finally"), + ] + + with patch.object(_anthropic.time, "sleep") as mock_sleep: + result = _anthropic.call_anthropic( + client, system="s", user_message="u", model="m", max_tokens=10 + ) + + assert result == "finally" + # Backoff: 1.0 * 2**0, then 1.0 * 2**1. + assert [c.args[0] for c in mock_sleep.call_args_list] == [1.0, 2.0] + + def test_gives_up_after_max_retries(self) -> None: + """Persistent 429s exhaust retries and raise AnthropicCallError.""" + from attune_author.doc_gen import _anthropic + + client = MagicMock() + client.messages.create.side_effect = [ + _api_status_error(429, "rate limit hit"), + ] * 4 # _MAX_RETRIES + 1 attempts + + with patch.object(_anthropic.time, "sleep"): + with pytest.raises(_anthropic.AnthropicCallError): + _anthropic.call_anthropic( + client, + system="s", + user_message="u", + model="m", + max_tokens=10, + ) + + # _MAX_RETRIES = 3 → 4 total attempts (1 initial + 3 retries). + assert client.messages.create.call_count == 4 + + +class TestNonRetryableErrors: + """4xx that aren't 429, 5xx that aren't 529, and unknown + exception types must raise immediately without retry. 
+ """ + + def test_400_bad_request_raises_immediately(self) -> None: + from attune_author.doc_gen import _anthropic + + client = MagicMock() + client.messages.create.side_effect = _api_status_error(400, "bad input") + + with patch.object(_anthropic.time, "sleep") as mock_sleep: + with pytest.raises(_anthropic.AnthropicCallError): + _anthropic.call_anthropic( + client, + system="s", + user_message="u", + model="m", + max_tokens=10, + ) + + assert client.messages.create.call_count == 1 + mock_sleep.assert_not_called() + + def test_unknown_exception_raises_immediately(self) -> None: + from attune_author.doc_gen import _anthropic + + client = MagicMock() + client.messages.create.side_effect = ValueError("something else") + + with patch.object(_anthropic.time, "sleep") as mock_sleep: + with pytest.raises(_anthropic.AnthropicCallError): + _anthropic.call_anthropic( + client, + system="s", + user_message="u", + model="m", + max_tokens=10, + ) + + assert client.messages.create.call_count == 1 + mock_sleep.assert_not_called() + + +class TestCredentialRedaction: + """Credential material in error text must be scrubbed.""" + + def test_api_key_redacted_in_error_message(self) -> None: + from attune_author.doc_gen import _anthropic + + client = MagicMock() + # pragma: allowlist nextline secret + leaky_msg = "auth failed for sk-ant-abc123def456ghi789jkl0" + client.messages.create.side_effect = ValueError(leaky_msg) + + with pytest.raises(_anthropic.AnthropicCallError) as exc_info: + _anthropic.call_anthropic( + client, system="s", user_message="u", model="m", max_tokens=10 + ) + + msg = str(exc_info.value) + assert "sk-ant-abc123" not in msg + assert "[REDACTED]" in msg + + def test_cause_chain_is_stripped(self) -> None: + """``from None`` must scrub __cause__ so credentials can't + leak through ``str(exc.__cause__)``. 
+ """ + from attune_author.doc_gen import _anthropic + + client = MagicMock() + client.messages.create.side_effect = ValueError("boom") + + with pytest.raises(_anthropic.AnthropicCallError) as exc_info: + _anthropic.call_anthropic( + client, system="s", user_message="u", model="m", max_tokens=10 + ) + + assert exc_info.value.__cause__ is None diff --git a/tests/test_polish_cache.py b/tests/test_polish_cache.py new file mode 100644 index 0000000..596dc45 --- /dev/null +++ b/tests/test_polish_cache.py @@ -0,0 +1,231 @@ +"""Tests for the on-disk polish cache. + +Covers the cache primitives in :mod:`attune_author.polish`: +hit/miss, mtime bump on hit, model name participating in the +key, mtime-based pruning, the TTL=0 disable, and ``clear_cache``. + +Each test uses a fresh ``tmp_path`` cache directory via the +``ATTUNE_AUTHOR_POLISH_CACHE`` env var so the user's real +``~/.attune/polish_cache`` is never touched. +""" + +from __future__ import annotations + +import os +import time +from pathlib import Path +from unittest.mock import patch + +import pytest + + +@pytest.fixture +def cache_dir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + """Point the polish cache at a clean tmp directory.""" + monkeypatch.setenv("ATTUNE_AUTHOR_POLISH_CACHE", str(tmp_path)) + # Clear any TTL override left by other tests so the default + # (30 days) governs unless a test sets it explicitly. 
+ monkeypatch.delenv("ATTUNE_AUTHOR_POLISH_CACHE_TTL_SECONDS", raising=False) + return tmp_path + + +class TestCachePutGet: + """Round-trip behavior of _cache_put / _cache_get.""" + + def test_put_then_get_returns_content(self, cache_dir: Path) -> None: + from attune_author.polish import _cache_get, _cache_put + + _cache_put("abc", "polished content") + assert _cache_get("abc") == "polished content" + + def test_get_returns_none_on_miss(self, cache_dir: Path) -> None: + from attune_author.polish import _cache_get + + assert _cache_get("never-written") is None + + def test_put_uses_atomic_tmp_rename(self, cache_dir: Path) -> None: + """No .tmp file should remain after a successful put.""" + from attune_author.polish import _cache_put + + _cache_put("abc", "x") + leftovers = list(cache_dir.glob("*.tmp")) + assert leftovers == [], f"orphaned .tmp files: {leftovers}" + + def test_get_bumps_mtime(self, cache_dir: Path) -> None: + """A cache hit must update the entry's mtime so the prune + sweeper treats it as hot. Verified by stamping an ancient + mtime, hitting the cache, then confirming mtime advanced. + """ + from attune_author.polish import _cache_get, _cache_put + + _cache_put("hot", "content") + path = next(cache_dir.glob("*.md")) + ancient = time.time() - 10_000 + os.utime(path, (ancient, ancient)) + + _cache_get("hot") + + new_mtime = path.stat().st_mtime + assert new_mtime > ancient + 100, f"mtime did not bump on hit: {new_mtime - ancient}s delta" + + +class TestCacheKeyIncludesModel: + """The model name must participate in the cache key so a model + bump in attune-author invalidates entries automatically. + """ + + def test_model_change_invalidates_cache(self, cache_dir: Path) -> None: + """Changing _POLISH_MODEL produces a different key for the + same content/summary/template_type/system_prompt/context. + """ + from attune_author import polish + + # First call with the production model produces a key K1. 
+        with patch.object(polish, "_POLISH_MODEL", "claude-old"):
+            with patch("attune_author.polish._call_llm", return_value="from old"):
+                with patch.dict(
+                    "os.environ",
+                    {"ANTHROPIC_API_KEY": "fake"},  # pragma: allowlist secret
+                ):
+                    out_old = polish.polish_template("# T", "feat", "sum", template_type="concept")
+        assert out_old.startswith("from old")
+
+        # Second call with a different model must miss the cache —
+        # the LLM mock returns a different value, which we observe.
+        with patch.object(polish, "_POLISH_MODEL", "claude-new"):
+            with patch("attune_author.polish._call_llm", return_value="from new"):
+                with patch.dict(
+                    "os.environ",
+                    {"ANTHROPIC_API_KEY": "fake"},  # pragma: allowlist secret
+                ):
+                    out_new = polish.polish_template("# T", "feat", "sum", template_type="concept")
+        assert out_new.startswith(
+            "from new"
+        ), "model change should invalidate cache key, but a stale entry was served"
+
+
+class TestCachePrune:
+    """Mtime-based pruning behavior."""
+
+    def test_prune_deletes_stale_entries(
+        self, cache_dir: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        from attune_author.polish import _cache_prune, _cache_put
+
+        monkeypatch.setenv("ATTUNE_AUTHOR_POLISH_CACHE_TTL_SECONDS", "1")
+        _cache_put("a", "alpha")
+        _cache_put("b", "beta")
+        # File names are sha256 digests, so entries cannot be targeted by
+        # their logical key name. Deterministically backdate the first
+        # entry in sorted order instead; the prune sweep must then delete
+        # exactly that entry while the still-fresh one survives.
+        # (TTL is 1 second and the backdated mtime is ~10k seconds in the
+        # past, so only the backdated file falls past the cutoff.)
+        targets = sorted(cache_dir.glob("*.md"))
+        old_path = targets[0]
+        ancient = time.time() - 10_000
+        os.utime(old_path, (ancient, ancient))
+
+        deleted = _cache_prune()
+
+        assert deleted == 1
+        assert not old_path.exists()
+        # The other entry survives.
+ survivors = list(cache_dir.glob("*.md")) + assert len(survivors) == 1 + + def test_prune_with_ttl_zero_disables( + self, cache_dir: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + from attune_author.polish import _cache_prune, _cache_put + + monkeypatch.setenv("ATTUNE_AUTHOR_POLISH_CACHE_TTL_SECONDS", "0") + _cache_put("a", "alpha") + ancient = time.time() - 10_000 + for path in cache_dir.glob("*.md"): + os.utime(path, (ancient, ancient)) + + assert _cache_prune() == 0 + # Entry survives despite being ancient. + assert list(cache_dir.glob("*.md")) + + def test_prune_handles_missing_dir( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + from attune_author.polish import _cache_prune + + nonexistent = tmp_path / "does-not-exist" + monkeypatch.setenv("ATTUNE_AUTHOR_POLISH_CACHE", str(nonexistent)) + assert _cache_prune() == 0 + + def test_invalid_ttl_falls_back_to_default( + self, cache_dir: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Unparseable / negative TTL must NOT disable the cache — + the prune should fall back to the 30-day default. We verify + by stamping an entry one hour old; default TTL leaves it. + """ + from attune_author.polish import _cache_prune, _cache_put + + monkeypatch.setenv("ATTUNE_AUTHOR_POLISH_CACHE_TTL_SECONDS", "garbage") + _cache_put("a", "alpha") + an_hour_ago = time.time() - 3600 + for path in cache_dir.glob("*.md"): + os.utime(path, (an_hour_ago, an_hour_ago)) + + assert _cache_prune() == 0 + assert list(cache_dir.glob("*.md")), "fresh entry was wrongly pruned" + + +class TestClearCache: + """clear_cache() removes all entries unconditionally.""" + + def test_clear_removes_all_md_and_tmp(self, cache_dir: Path) -> None: + from attune_author.polish import _cache_put, clear_cache + + _cache_put("a", "alpha") + _cache_put("b", "beta") + # Drop a stray .tmp simulating an interrupted write. + (cache_dir / "interrupted.tmp").write_text("partial") + + # And an unrelated file that must NOT be deleted. 
+        (cache_dir / "README.txt").write_text("don't touch me")
+
+        deleted = clear_cache()
+
+        assert deleted == 3
+        assert not list(cache_dir.glob("*.md"))
+        assert not list(cache_dir.glob("*.tmp"))
+        assert (cache_dir / "README.txt").exists()
+
+    def test_clear_handles_missing_dir(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        from attune_author.polish import clear_cache
+
+        nonexistent = tmp_path / "does-not-exist"
+        monkeypatch.setenv("ATTUNE_AUTHOR_POLISH_CACHE", str(nonexistent))
+        assert clear_cache() == 0
+
+
+class TestPolishTemplateCacheIntegration:
+    """End-to-end: polish_template hits the cache on a second call."""
+
+    def test_second_call_uses_cache_and_skips_llm(self, cache_dir: Path) -> None:
+        from attune_author import polish
+
+        with patch.dict(
+            "os.environ",
+            {"ANTHROPIC_API_KEY": "fake"},  # pragma: allowlist secret
+        ):
+            with patch(
+                "attune_author.polish._call_llm",
+                return_value="polished output",
+            ) as mock_llm:
+                first = polish.polish_template("# T", "feat", "sum", template_type="concept")
+                second = polish.polish_template("# T", "feat", "sum", template_type="concept")
+
+        assert first == second
+        # The LLM must have been called exactly once — the second
+        # call is served from the on-disk cache.
+        assert mock_llm.call_count == 1
diff --git a/uv.lock b/uv.lock
index 2be830a..724350a 100644
--- a/uv.lock
+++ b/uv.lock
@@ -55,7 +55,7 @@ wheels = [
 
 [[package]]
 name = "attune-author"
-version = "0.5.0"
+version = "0.6.0"
 source = { editable = "." }
 dependencies = [
     { name = "attune-help" },
@@ -113,7 +113,7 @@ provides-extras = ["ai", "rag", "plugin", "rich", "dev"]
 
 [[package]]
 name = "attune-help"
-version = "0.9.0"
+version = "0.10.0"
 source = { editable = "../attune-help" }
 dependencies = [
     { name = "python-frontmatter" },