From 9e61a1a3dab6086a71b88c86c447bb6e32dad082 Mon Sep 17 00:00:00 2001 From: GeneAI Date: Fri, 8 May 2026 15:46:29 -0400 Subject: [PATCH] feat(regenerate): batch path via Anthropic Message Batches API Adds opt-in `attune-author regenerate --batch` that submits all polish requests for stale features as a single Anthropic batch (~50% cost vs synchronous per-call pricing). Submit-and-detach UX with a `--resume` partner that polls + splices results when the user is ready. CLI surface: --batch submit; print batch ID + resume hint; exit --resume poll + splice + write templates --status one-shot status (supports --json) --cancel cancel batch + remove state file --force with --batch: overwrite an existing pending state Argparse mutex enforces "exactly one batch-mode flag at a time". Bare `regenerate` (no batch flags) prints a one-liner hint when a state file exists but never auto-resumes. State file lives at .help/.batch-state.json. Schema-versioned; 29-day retention check matches Anthropic's batch retention window. On poll-loop timeout the state file is intentionally KEPT so the next --resume picks up cleanly. SIGINT during polling cancels the batch and re-raises. Refactor: the synchronous path's render+write phases are extracted to generator.prepare_polish_phase / generator.apply_polish_results, and polish.build_polish_prompt is shared between paths. Both paths produce byte-identical on-disk templates for the same inputs (parity test in tests/test_maintenance_batch.py guards the invariant). Per-feature failure isolation: any errored depth fails the whole feature (no templates written for it); other features in the same batch succeed normally. Tests: +57 new unit tests, +1 opt-in live integration test (gated on ANTHROPIC_API_KEY + RUN_LIVE_BATCH=1, ~$0.02/run). Full suite 670 pass / 32 skipped (live + skill-export integration). Ruff clean on touched files. 
Spec: specs/author-batch-maintain/{requirements,design,tasks}.md Co-Authored-By: Claude Opus 4.7 --- CHANGELOG.md | 53 ++ docs/regenerate-batch.md | 158 ++++ pyproject.toml | 5 +- src/attune_author/__init__.py | 2 +- src/attune_author/cli.py | 228 +++++- src/attune_author/doc_gen/_anthropic_batch.py | 380 ++++++++++ src/attune_author/generator.py | 147 +++- src/attune_author/maintenance.py | 24 + src/attune_author/maintenance_batch.py | 704 ++++++++++++++++++ src/attune_author/polish.py | 69 +- tests/golden/batch_response.json | 65 ++ tests/integration/__init__.py | 0 tests/integration/test_batch_live.py | 81 ++ tests/test_anthropic_batch.py | 434 +++++++++++ tests/test_cli_batch.py | 396 ++++++++++ tests/test_maintenance_batch.py | 390 ++++++++++ 16 files changed, 3089 insertions(+), 47 deletions(-) create mode 100644 docs/regenerate-batch.md create mode 100644 src/attune_author/doc_gen/_anthropic_batch.py create mode 100644 src/attune_author/maintenance_batch.py create mode 100644 tests/golden/batch_response.json create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/test_batch_live.py create mode 100644 tests/test_anthropic_batch.py create mode 100644 tests/test_cli_batch.py create mode 100644 tests/test_maintenance_batch.py diff --git a/CHANGELOG.md b/CHANGELOG.md index bcc36d2..2e54dfc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,59 @@ and this project adheres to Work in progress for the next release. Add entries here as changes land, not at tag time. +## [0.11.0] - 2026-05-08 + +### Added + +- **`attune-author regenerate --batch`** — opt-in batch path + that submits all polish requests for stale features as one + Anthropic Message Batches API call (~50% cost vs per-call + pricing). Detaches; results are spliced back into the help + corpus by a follow-up `--resume` invocation. 
+ - Companion flags: `--resume` (poll + splice + write), + `--status` (one-shot query, supports `--json`), + `--cancel` (call SDK cancel + remove state file), + `--force` (with `--batch`: overwrite an existing pending + state file). + - Argparse mutex enforces "exactly one batch-mode flag at a + time"; `--force` and `--json` are modifiers. + - Bare `regenerate` (no batch flags) prints a one-liner hint + when a state file exists; never auto-resumes. +- **`attune_author.maintenance_batch`** — new module containing + `submit_maintenance_batch`, `resume_maintenance_batch`, + `status_maintenance_batch`, `cancel_maintenance_batch` plus + the `BatchState` schema and `.help/.batch-state.json` + read/write/delete helpers. Includes 29-day stale-state + detection (Anthropic's batch retention window). +- **`attune_author.doc_gen._anthropic_batch`** — SDK wrapper + module exposing `submit_batch` / `poll_batch` plus typed + `BatchPolishRequest` / `BatchPolishResult` records and an + adaptive `timeout_secs` helper. SIGINT during polling cancels + the batch and re-raises. +- **`attune_author.polish.build_polish_prompt`** — extracted + prompt-building helper. Both the synchronous path and the + batch path call it so the wire-level prompts are byte-identical. +- **`attune_author.generator.prepare_polish_phase`** / + **`apply_polish_results`** — extracted Phase 1 (render) and + Phase 3 (write) helpers. The synchronous + `generate_feature_templates` now calls them; the batch path + uses them too. Single source of truth for rendering. +- **Docs**: `docs/regenerate-batch.md` covers when to use which + path, fallback behavior, env-var overrides, cancel semantics, + and stale-state recovery. +- Live integration test at `tests/integration/test_batch_live.py`, + gated on both `ANTHROPIC_API_KEY` and `RUN_LIVE_BATCH=1`. Tagged + `@pytest.mark.live`. Default-skipped; ~$0.02 per run. 
+ +### Changed + +- **`generate_feature_templates`** internal layout: refactored + to call the new `prepare_polish_phase` and `apply_polish_results` + helpers. **No behavior change** for existing callers — same + inputs produce byte-identical on-disk templates. Verified by + the existing 152 polish tests + the new parity test in + `tests/test_maintenance_batch.py`. + ## [0.10.0] - 2026-05-08 ### Added diff --git a/docs/regenerate-batch.md b/docs/regenerate-batch.md new file mode 100644 index 0000000..e93e9c5 --- /dev/null +++ b/docs/regenerate-batch.md @@ -0,0 +1,158 @@ +# `attune-author regenerate --batch` + +Cost-saving alternative to the synchronous regen path. Submits all +polish requests for stale features as a single Anthropic Message +Batches API call (~50% cost savings) and detaches; you splice the +results back into your help corpus with a follow-up `--resume` +invocation. + +## TL;DR + +```sh +# kick off the batch (returns immediately) +attune-author regenerate --batch +# go to lunch ... +# when ready, splice results +attune-author regenerate --resume +``` + +## When to use which path + +| Path | Cost | Wall-clock latency | Ergonomics | +|----------------------------|------------|-------------------|-------------------------------------| +| Default (synchronous) | full price | seconds-to-minutes | terminal blocks; one command | +| `--batch` (then `--resume`)| ~50% price | minutes-to-hours | terminal returns immediately twice | + +Default is right for: small ad-hoc regen (1–5 stale features), CI +that needs a clean exit code in one step, or when you want +sub-minute end-to-end latency. + +`--batch` is right for: bulk regen (10+ stale features), nightly +cron jobs where two scheduled invocations is fine, or any time +the cost savings outweighs a few minutes of wait. + +## How it works + +The `--batch` invocation: + +1. Resolves stale features from the staleness report. +2. 
Renders all templates for those features (Phase 1 of the + normal generator pipeline). +3. Builds polish prompts for every (feature, depth) pair using + the **same** `build_polish_prompt` helper the synchronous + path uses, so the prompts go out byte-identical. +4. Posts a single Anthropic batch (`messages.batches.create`). +5. Writes a small state file at `.help/.batch-state.json` + containing the batch ID and a per-request manifest. +6. Prints a copy-paste-ready `--resume` hint and exits. + +The `--resume` invocation: + +1. Loads the state file. +2. Polls Anthropic until the batch terminates (or the adaptive + timeout fires). +3. Splices polished text back into per-feature + `GenerationResult` records. +4. Writes the templates exactly the way the synchronous path + does (Phase 3 of the normal generator pipeline). +5. Deletes the state file (on terminal completion) or keeps it + (on poll-loop timeout, so the next `--resume` picks up). + +## Resume ergonomics + +A few specific affordances make `--resume` painless: + +- **No batch ID required.** `--resume` reads the state file + automatically. +- **Bare `regenerate` prints a hint.** If you accidentally run + the synchronous path while a batch is pending, the CLI tells + you about the pending batch (without auto-resuming). +- **Resume-after-timeout works.** If polling hit our 30-min cap + but Anthropic is still working, the state file is **kept** and + you can run `--resume` again later. No work lost. +- **`--status` is one-shot.** Want to know if your batch is done + without committing to a wait? `attune-author regenerate --status`. + Add `--json` for cron parsing. +- **`--cancel` is one-shot.** Changed your mind, or batch is + wedged? `attune-author regenerate --cancel`. + +## Per-request failure isolation + +If one (feature, depth) request errors inside the batch (rate +limit, content policy, model error), the **whole feature** is +marked as failed in `result.failed`. 
Other features in the same +batch are still written to disk normally. This mirrors the +synchronous path's per-feature failure isolation. + +## Stale-state recovery + +Anthropic retains batch results for 29 days. If you somehow leave +a state file around longer than that, `--resume` surfaces a clean +error pointing at the cleanup path: + +``` +error: batch submitted 2026-04-01T12:00:00+00:00 is older than + Anthropic's 29-day retention window; delete + .help/.batch-state.json and rerun --batch +``` + +## Multi-batch guard + +If you run `--batch` while a state file already exists, you get a +clear refusal: + +``` +error: pending batch already submitted; state file at + .help/.batch-state.json. Run --resume, --cancel, or pass + --force to overwrite. +``` + +`--force` overwrites the state file and starts a new batch. (The +old Anthropic batch continues to run on its side; you can let it +finish or cancel via the dashboard.) + +## Environment variables + +| Variable | Default | Purpose | +|--------------------------------|---------|---------------------------------------------| +| `ATTUNE_BATCH_POLL_SECS` | 30 | Poll interval for `--resume`. | +| `ATTUNE_BATCH_TIMEOUT_SECS` | adaptive | Override the poll-loop ceiling. | +| `ANTHROPIC_API_KEY` | — | Required for any LLM call. | + +The adaptive timeout default scales with batch size: +`min(30min, max(5min, 60·N/20))` — 5min minimum (handles +cold-start variance) plus 1min per 20 requests, capped at +30min. + +## Ctrl+C semantics + +During `--resume` polling, SIGINT calls Anthropic's +`messages.batches.cancel` and re-raises so the CLI can clean up. +Anthropic doesn't refund work already in progress, but it stops +new work and lets you walk away without lingering charges. + +## Cost model + +The Batches API charges ~50% of the per-token prices of the +synchronous Messages API. There's no per-batch overhead; cost is +purely a function of total tokens. 
Batch jobs typically complete within minutes for our N (~10–50 polish prompts), well inside Anthropic's 24h SLA. + +## Live integration test + +`tests/integration/test_batch_live.py` covers end-to-end +submit→poll→splice with two real polish requests. Skipped by +default; opt-in via: + +```sh +ANTHROPIC_API_KEY=sk-ant-... RUN_LIVE_BATCH=1 \ + pytest tests/integration/test_batch_live.py -m live +``` + +Cost: ~$0.02 per run. + +## See also + +- Spec: `specs/author-batch-maintain/{requirements,design,tasks}.md` +- Anthropic docs: Message Batches API guide and the `messages.batches` SDK reference diff --git a/pyproject.toml b/pyproject.toml index e4e17ac..7d58974 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "attune-author" -version = "0.10.0" +version = "0.11.0" description = "Documentation authoring and maintenance for the attune ecosystem — generate, maintain, and validate help content with AI assistance." readme = {file = "README.md", content-type = "text/markdown"} requires-python = ">=3.10" @@ -102,6 +102,9 @@ select = ["E", "F", "W", "I", "UP", "BLE"] [tool.pytest.ini_options] testpaths = ["tests"] +markers = [ + "live: opt-in tests that hit the live Anthropic API. Skipped by default; require ANTHROPIC_API_KEY and any test-specific env flags.", +] [tool.coverage.run] source = ["attune_author"] diff --git a/src/attune_author/__init__.py b/src/attune_author/__init__.py index c3fc3ad..2bac38f 100644 --- a/src/attune_author/__init__.py +++ b/src/attune_author/__init__.py @@ -5,7 +5,7 @@ attune-help (reader) and attune-ai (full dev workflows). 
""" -__version__ = "0.10.0" +__version__ = "0.11.0" from attune_author.manifest import Feature, Manifest, load_manifest from attune_author.staleness import StalenessReport, check_staleness, compute_source_hash diff --git a/src/attune_author/cli.py b/src/attune_author/cli.py index d945c52..2ae7af6 100644 --- a/src/attune_author/cli.py +++ b/src/attune_author/cli.py @@ -174,6 +174,47 @@ def _build_parser() -> argparse.ArgumentParser: action="store_true", help="Report stale features without regenerating.", ) + # --- batch-mode flags (mutually exclusive) ----------------------------- + # Sub-group enforces "exactly one of these (or none) at a time" so the + # user can't ask the CLI to both submit and resume in one invocation. + batch_group = p_regen.add_mutually_exclusive_group() + batch_group.add_argument( + "--batch", + action="store_true", + help=( + "Submit polish requests as one Anthropic batch (~50%% cost). " + "Detaches; run --resume later to splice results." + ), + ) + batch_group.add_argument( + "--resume", + action="store_true", + help="Resume a previously submitted batch and write its templates.", + ) + batch_group.add_argument( + "--status", + action="store_true", + help="One-shot status of the pending batch. No polling.", + ) + batch_group.add_argument( + "--cancel", + action="store_true", + help="Cancel the pending batch and remove its state file.", + ) + p_regen.add_argument( + "--force", + action="store_true", + help=( + "With --batch: overwrite an existing pending state file " + "(cancels the previous batch from the local view; the " + "still-running Anthropic batch is left to finish on its own)." 
+ ), + ) + p_regen.add_argument( + "--json", + action="store_true", + help="With --status: emit JSON instead of human-readable output.", + ) p_cache = sub.add_parser( "cache", @@ -464,11 +505,38 @@ def _cmd_generate(args: argparse.Namespace) -> int: def _cmd_regenerate(args: argparse.Namespace) -> int: """Handle the regenerate command.""" - from attune_author.maintenance import run_maintenance - root = validate_file_path(args.project_root) help_dir = validate_file_path(args.help_dir) + if getattr(args, "batch", False): + return _cmd_regenerate_batch_submit(args, help_dir, root) + if getattr(args, "resume", False): + return _cmd_regenerate_batch_resume(args, help_dir, root) + if getattr(args, "status", False): + return _cmd_regenerate_batch_status(args, help_dir) + if getattr(args, "cancel", False): + return _cmd_regenerate_batch_cancel(args, help_dir) + + return _cmd_regenerate_synchronous(args, help_dir, root) + + +def _cmd_regenerate_synchronous( + args: argparse.Namespace, + help_dir: Path, + root: Path, +) -> int: + """Default synchronous path. 
Also surfaces the pending-batch hint.""" + from attune_author.maintenance import run_maintenance + from attune_author.maintenance_batch import has_pending_batch + + if has_pending_batch(help_dir): + print( + "note: a pending batch was previously submitted from this corpus.\n" + " run `attune-author regenerate --resume` to splice its results,\n" + " or `attune-author regenerate --status` to see progress.", + file=sys.stderr, + ) + try: result = run_maintenance( help_dir=help_dir, @@ -491,6 +559,159 @@ def _cmd_regenerate(args: argparse.Namespace) -> int: return 0 +def _cmd_regenerate_batch_submit( + args: argparse.Namespace, + help_dir: Path, + root: Path, +) -> int: + from attune_author.maintenance_batch import ( + BatchAlreadyPending, + submit_maintenance_batch, + ) + + try: + state = submit_maintenance_batch( + help_dir=help_dir, + project_root=root, + force=args.force, + ) + except BatchAlreadyPending as exc: + print(f"error: {exc}", file=sys.stderr) + return 2 + except FileNotFoundError: + _print_missing_manifest_hint(help_dir) + return 1 + except ValueError as exc: + # No stale features / nothing to do — surface gracefully. + print(f"note: {exc}") + return 0 + + print( + f"Submitted batch {state.batch_id} ({len(state.requests)} requests, model {state.model}).\n" + f"Estimated completion: ~{_format_eta(state)} (Anthropic 24h SLA)." + ) + print() + print("Run attune-author regenerate --resume to splice results when ready.") + return 0 + + +def _cmd_regenerate_batch_resume( + args: argparse.Namespace, + help_dir: Path, + root: Path, +) -> int: + from attune_author.maintenance_batch import ( + BatchStateExpired, + BatchStateNotFound, + resume_maintenance_batch, + ) + + _ = args + try: + outcome = resume_maintenance_batch( + help_dir=help_dir, + project_root=root, + ) + except BatchStateNotFound: + print("no pending batch found. 
Nothing to resume.", file=sys.stderr) + return 1 + except BatchStateExpired as exc: + print(f"error: {exc}", file=sys.stderr) + return 2 + + if outcome.final_status == "timed_out": + print( + f"Batch {outcome.batch_id} still running. State file kept; run " + "`attune-author regenerate --resume` again later." + ) + else: + print( + f"Batch {outcome.batch_id} {outcome.final_status}. " + f"Regenerated {outcome.regenerated_count} feature(s)." + ) + if outcome.failed: + print(f"Failed: {', '.join(outcome.failed)}") + return 0 + + +def _cmd_regenerate_batch_status( + args: argparse.Namespace, + help_dir: Path, +) -> int: + from attune_author.maintenance_batch import ( + BatchStateExpired, + BatchStateNotFound, + status_maintenance_batch, + ) + + try: + info = status_maintenance_batch(help_dir=help_dir) + except BatchStateNotFound: + print("no pending batch.", file=sys.stderr) + return 1 + except BatchStateExpired as exc: + print(f"error: {exc}", file=sys.stderr) + return 2 + + if args.json: + import json as _json + + print(_json.dumps(info, indent=2, sort_keys=True)) + return 0 + + print("Pending batch:") + print(f" id: {info['batch_id']}") + print(f" submitted: {info['submitted_at']}") + print(f" expected completion: {info['expected_completion_at']}") + print(f" model: {info['model']}") + print(f" request count: {info['request_count']}") + print(f" state (Anthropic): {info.get('processing_status')}") + rc = info.get("request_counts") or {} + if rc: + print( + f" succeeded={rc.get('succeeded', 0)} " + f"errored={rc.get('errored', 0)} " + f"expired={rc.get('expired', 0)} " + f"canceled={rc.get('canceled', 0)} " + f"processing={rc.get('processing', 0)}" + ) + return 0 + + +def _cmd_regenerate_batch_cancel( + args: argparse.Namespace, + help_dir: Path, +) -> int: + from attune_author.maintenance_batch import ( + BatchStateExpired, + BatchStateNotFound, + cancel_maintenance_batch, + ) + + _ = args + try: + batch_id = cancel_maintenance_batch(help_dir=help_dir) + except 
BatchStateNotFound: + print("no pending batch to cancel.", file=sys.stderr) + return 1 + except BatchStateExpired as exc: + print(f"error: {exc}", file=sys.stderr) + return 2 + print(f"Canceled batch {batch_id}.") + return 0 + + +def _format_eta(state: object) -> str: + """Human-readable ETA for a batch state.""" + submitted = getattr(state, "submitted_at") + eta = getattr(state, "expected_completion_at") + delta = eta - submitted + minutes = int(delta.total_seconds() // 60) + if minutes < 1: + return "<1 min" + return f"{minutes} min" + + + def _cmd_cache(args: argparse.Namespace) -> int: """Handle the cache command and its subcommands.""" from attune_author.polish import _cache_dir, clear_cache @@ -747,8 +968,7 @@ def _print_generate_usage(help_dir: Path) -> None: manifest = load_manifest(help_dir) except Exception: # noqa: BLE001 print( - f"\nNo manifest found at {help_dir / 'features.yaml'}. " - "Run `attune-author init` first.", + f"\nNo manifest found at {help_dir / 'features.yaml'}. Run `attune-author init` first.", file=sys.stderr, ) return diff --git a/src/attune_author/doc_gen/_anthropic_batch.py new file mode 100644 index 0000000..2b9aee5 --- /dev/null +++ b/src/attune_author/doc_gen/_anthropic_batch.py @@ -0,0 +1,380 @@ +"""Anthropic Message Batches API wrapper. + +Mirrors :mod:`attune_author.doc_gen._anthropic`'s contract for +*N* prompts: build a single batch request, submit it, and poll +until the batch reaches a terminal state. Used by the +``attune-author regenerate --batch`` flow as the cost-saving +alternative to per-call ``messages.create``. + +The split across :func:`submit_batch` and :func:`poll_batch` is +deliberate — submit is a fast one-shot that the CLI uses on the +``--batch`` invocation, while polling happens later on +``--resume`` (potentially in a fresh process). They are +connected by the batch ID stored in the on-disk +``.help/.batch-state.json``. 
+ +Verified against ``anthropic`` SDK 0.89+: + +- Resource path: ``client.messages.batches`` with + ``create / cancel / retrieve / results``. +- :class:`anthropic.types.messages.MessageBatch` exposes + ``processing_status``, ``request_counts``, ``expires_at``, + ``cancel_initiated_at``, ``results_url``. +- :class:`MessageBatchIndividualResponse` carries + ``custom_id`` + ``result`` (discriminated union of + Succeeded / Errored / Expired / Canceled). +""" + +from __future__ import annotations + +import time +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from anthropic import Anthropic + + +# Status strings the SDK can return from ``processing_status``. +# We only branch on a small subset; the rest map to ``in_progress``. +_TERMINAL_STATUSES = {"ended", "canceled"} + + +@dataclass(frozen=True) +class BatchPolishRequest: + """One polish request inside a batch. + + ``custom_id`` is the Anthropic-side per-request key. Callers + set it to a stable string (we use ``feat__{feature}__{depth}``) + so the resume splicer can map results back to (feature, depth) + without consulting any other state. + """ + + custom_id: str + feature: str + depth: str + system: str + user_message: str + model: str + max_tokens: int + cache_system: bool = False + + +@dataclass(frozen=True) +class BatchPolishResult: + """One polish response from a batch. + + Exactly one of ``text`` / ``error`` is non-None. ``error`` + is a short string that mirrors today's per-call failure + messages (e.g. ``"anthropic.RateLimitError: ..."``). + Special sentinel values: ``"timed_out"`` (poll-loop ceiling + exceeded), ``"expired"`` (batch hit Anthropic's 24h window), + ``"canceled"`` (batch was canceled). 
+ """ + + custom_id: str + feature: str + depth: str + text: str | None + error: str | None + + +def _build_request_param(req: BatchPolishRequest) -> dict[str, Any]: + """Render one ``BatchPolishRequest`` as the SDK's per-request envelope. + + Anthropic batches wrap the same Messages payload we use + today; the only difference is the per-entry ``custom_id`` + plus the ``params`` envelope. ``cache_control`` on the + system prompt works the same as the synchronous path. + """ + if req.cache_system: + system_payload: object = [ + { + "type": "text", + "text": req.system, + "cache_control": {"type": "ephemeral"}, + } + ] + else: + system_payload = req.system + + return { + "custom_id": req.custom_id, + "params": { + "model": req.model, + "max_tokens": req.max_tokens, + "system": system_payload, + "messages": [{"role": "user", "content": req.user_message}], + }, + } + + +def submit_batch( + requests: list[BatchPolishRequest], + *, + client: Anthropic | None = None, +) -> str: + """Post one Anthropic batch and return its ``batch_id``. + + Does not poll. Mirrors the synchronous-path pattern: caller + supplies an explicit client (for tests) or we build one from + ``ANTHROPIC_API_KEY`` via the shared helper. + """ + if not requests: + raise ValueError("submit_batch requires at least one request") + + if client is None: + from ._anthropic import get_client + + client = get_client() + + payload = [_build_request_param(r) for r in requests] + batch = client.messages.batches.create(requests=payload) + return batch.id + + +def adaptive_timeout_secs(num_requests: int) -> float: + """Default poll-loop ceiling, scaled to batch size. + + Empirical Anthropic batch latency is ~30s baseline plus + linear growth. We give the larger of a 5min floor (handles + cold-start variance) and 1min per 20 requests, capped at + 30min so a single ``--resume`` invocation can't burn an + entire afternoon. Spec design.md "Adaptive polling and timeout (D2)". 
+ """ + return min(30 * 60, max(5 * 60, 60 * (num_requests / 20))) + + +def poll_batch( + batch_id: str, + expected_requests: list[BatchPolishRequest], + *, + client: Anthropic | None = None, + poll_interval_secs: float = 30.0, + timeout_secs: float | None = None, + on_progress: Callable[[Any], None] | None = None, + _now: Callable[[], float] = time.monotonic, + _sleep: Callable[[float], None] = time.sleep, +) -> tuple[str, list[BatchPolishResult]]: + """Poll until the batch reaches terminal state. + + Returns ``(final_status, results)``: + + - ``final_status`` ∈ ``{"ended", "canceled", "expired", "timed_out"}``. + - ``results`` has one entry per ``expected_requests`` item, in + the same order. On ``timed_out`` (poll-loop ceiling exceeded + before Anthropic finished), missing entries become + ``BatchPolishResult(error="timed_out")``. + + SIGINT raises :class:`KeyboardInterrupt` after attempting + to cancel the batch — surviving completed work is reflected + in the partial results returned by a follow-up + ``poll_batch`` call. + + ``_now`` and ``_sleep`` are injection points for tests; do + not pass them in production. + """ + if client is None: + from ._anthropic import get_client + + client = get_client() + + if timeout_secs is None: + timeout_secs = adaptive_timeout_secs(len(expected_requests)) + + started = _now() + deadline = started + timeout_secs + + try: + while True: + batch = client.messages.batches.retrieve(batch_id) + if on_progress is not None: + on_progress(batch) + + status = getattr(batch, "processing_status", "in_progress") or "in_progress" + if status in _TERMINAL_STATUSES: + results = _collect_results(client, batch_id, expected_requests) + final_status = _final_status_from_terminal(batch, status) + return final_status, results + + if _now() >= deadline: + # Best-effort: pull any results Anthropic has produced so far. + # Some batches expose partial results before they're terminal, + # so we attempt a fetch but tolerate failure. 
+ partial = _collect_partial_results(client, batch_id, expected_requests) + return "timed_out", partial + + _sleep(poll_interval_secs) + except KeyboardInterrupt: + # Best-effort cancel so the user isn't billed for new work + # after they hit Ctrl+C. SIGINT propagates so the CLI can + # clean up the state file and print a clear message. + try: + client.messages.batches.cancel(batch_id) + except Exception: # noqa: BLE001 — cancel is best-effort + pass + raise + + +def _final_status_from_terminal(batch: Any, status: str) -> str: + """Translate the SDK's terminal ``processing_status`` to our taxonomy. + + ``ended`` is the success path; we further check + ``request_counts`` and the presence of ``cancel_initiated_at`` + or expired-result rows to surface ``canceled`` vs + ``expired`` as appropriate. + """ + if status == "canceled": + return "canceled" + if getattr(batch, "cancel_initiated_at", None): + return "canceled" + counts = getattr(batch, "request_counts", None) + if counts is not None and getattr(counts, "expired", 0) > 0: + return "expired" + return "ended" + + +def _collect_results( + client: Anthropic, + batch_id: str, + expected: list[BatchPolishRequest], +) -> list[BatchPolishResult]: + """Fetch and parse all per-request results. + + The SDK's ``results()`` returns an iterator of + :class:`MessageBatchIndividualResponse`. We map each by + ``custom_id`` then walk ``expected`` to produce a + same-order list — keeps splicing on the orchestration side + O(N) and order-deterministic for tests. 
+ """ + by_id: dict[str, Any] = {} + for entry in client.messages.batches.results(batch_id): + cid = getattr(entry, "custom_id", None) + if cid: + by_id[cid] = entry + + out: list[BatchPolishResult] = [] + for req in expected: + entry = by_id.get(req.custom_id) + if entry is None: + out.append(_missing_result(req, "missing_in_results")) + continue + out.append(_parse_individual(entry, req)) + return out + + +def _collect_partial_results( + client: Anthropic, + batch_id: str, + expected: list[BatchPolishRequest], +) -> list[BatchPolishResult]: + """Try to fetch results for a non-terminal batch (timeout path). + + Most batches won't expose results until terminal, but the + SDK doesn't refuse the call. Any request with no result + becomes ``error="timed_out"`` so the caller's failure + surface is filled. + """ + try: + results = _collect_results(client, batch_id, expected) + except Exception: # noqa: BLE001 — partial-fetch is best-effort + results = [_missing_result(r, "timed_out") for r in expected] + return results + + timed_out: list[BatchPolishResult] = [] + for req, res in zip(expected, results, strict=True): + if res.error == "missing_in_results": + timed_out.append(_missing_result(req, "timed_out")) + else: + timed_out.append(res) + return timed_out + + +def _missing_result(req: BatchPolishRequest, error: str) -> BatchPolishResult: + return BatchPolishResult( + custom_id=req.custom_id, + feature=req.feature, + depth=req.depth, + text=None, + error=error, + ) + + +def _parse_individual(entry: Any, req: BatchPolishRequest) -> BatchPolishResult: + """Parse one ``MessageBatchIndividualResponse`` into our type. + + The ``result`` field is a discriminated union; we branch on + its ``type`` and pull either the response text (succeeded) + or a short error string (everything else). 
+ """ + result = getattr(entry, "result", None) + rtype = getattr(result, "type", None) + + if rtype == "succeeded": + message = getattr(result, "message", None) + text = _first_text_block(message) + return BatchPolishResult( + custom_id=req.custom_id, + feature=req.feature, + depth=req.depth, + text=text, + error=None, + ) + + if rtype == "errored": + err = getattr(result, "error", None) + # SDK shape: result.error.error.{type, message} or similar. + # Walk defensively so unfamiliar shapes don't crash. + message = ( + _resolve_attr(err, "message") + or _resolve_attr(err, "error", "message") + or "unknown error" + ) + return BatchPolishResult( + custom_id=req.custom_id, + feature=req.feature, + depth=req.depth, + text=None, + error=str(message), + ) + + if rtype in {"expired", "canceled"}: + return BatchPolishResult( + custom_id=req.custom_id, + feature=req.feature, + depth=req.depth, + text=None, + error=rtype, + ) + + # Unknown or missing result type — defensive default. + return BatchPolishResult( + custom_id=req.custom_id, + feature=req.feature, + depth=req.depth, + text=None, + error=f"unknown_result_type:{rtype}", + ) + + +def _first_text_block(message: Any) -> str: + """Return the first text block from a message-shaped object.""" + if message is None: + return "" + content = getattr(message, "content", None) or [] + for block in content: + text = getattr(block, "text", None) + if text: + return text + return "" + + +def _resolve_attr(obj: Any, *path: str) -> Any: + """Walk ``path`` of attributes; return None on any miss.""" + cur = obj + for name in path: + if cur is None: + return None + cur = getattr(cur, name, None) + return cur diff --git a/src/attune_author/generator.py b/src/attune_author/generator.py index fc03f76..4ff7836 100644 --- a/src/attune_author/generator.py +++ b/src/attune_author/generator.py @@ -179,6 +179,52 @@ class GeneratedTemplate: source_hash: str +@dataclass(frozen=True) +class _PendingPolish: + """One depth pre-polish state. 
+ + Returned by :func:`prepare_polish_phase` so the batch path + can submit polish requests for many features at once and + splice results back via :func:`apply_polish_results`. + """ + + depth: str + rendered_content: str + out_path: Path + + +@dataclass(frozen=True) +class PolishPreparation: + """Per-feature pre-polish state for the batch path. + + Encapsulates everything :func:`generate_feature_templates` + computes BEFORE the polish pass so a sibling caller (the + batch maintenance flow) can collect prompts across many + features, submit one batch, and later splice polished + results back without re-doing the expensive Phase 1 work. + + The synchronous path does not use this — it stays inside + ``generate_feature_templates`` as today. + """ + + feature: object # Feature; kept loosely-typed to avoid import cycle + source_hash: str + matched_files: list[str] + source_info: object # _SourceInfo; ditto + pending: tuple[_PendingPolish, ...] + use_rag: bool + + def pending_legacy_tuples(self) -> tuple[tuple[str, str, Path], ...]: + """Adapt ``pending`` to the legacy 3-tuple shape for ``_parallel_polish``. + + ``_parallel_polish`` predates this dataclass and expects + ``(depth, content, out_path)``. We keep the legacy shape + local to this conversion so the rest of the code can use + the more readable :class:`_PendingPolish`. + """ + return tuple((p.depth, p.rendered_content, p.out_path) for p in self.pending) + + @dataclass class GenerationResult: """Result of generating templates for a feature. @@ -245,6 +291,47 @@ def generate_feature_templates( Returns: GenerationResult with paths and metadata. """ + prep = prepare_polish_phase( + feature=feature, + help_dir=help_dir, + project_root=project_root, + depths=depths, + overwrite=overwrite, + use_rag=use_rag, + ) + + # Phase 2: LLM polish — run all depths concurrently. 
+ polished = _parallel_polish( + list(prep.pending_legacy_tuples()), + prep.feature, + prep.source_info, + prep.use_rag, + ) + polished_text: dict[str, str] = {depth: text for depth, (text, _path) in polished.items()} + + return apply_polish_results(prep, polished_text) + + +def prepare_polish_phase( + feature: Feature, + help_dir: str | Path, + project_root: str | Path, + depths: list[str] | None = None, + overwrite: bool = False, + use_rag: bool = True, +) -> PolishPreparation: + """Run the pre-polish (Phase 1) work for one feature. + + Computes the source hash, extracts source info, builds the + Jinja environment, renders all active templates, and returns + a :class:`PolishPreparation` carrying everything the polish + phase (and later the write phase) need. + + Pure of LLM calls. Both the synchronous path + (:func:`generate_feature_templates`) and the batched path + (:mod:`attune_author.maintenance_batch`) call this so they + render identically — the parity guard against drift. + """ help_path = Path(help_dir) root = Path(project_root) target_depths = depths or list(_CORE_DEPTH_NAMES) @@ -252,22 +339,12 @@ def generate_feature_templates( if not is_safe_feature_name(feature.name): raise ValueError(f"Invalid feature name: {feature.name!r}") - # Compute source hash source_hash, matched_files = compute_source_hash(feature, root) - - # Extract info from source files source_info = _extract_source_info(matched_files, root) - result = GenerationResult( - feature=feature.name, - source_hash=source_hash, - matched_files=matched_files, - ) - template_dir = help_path / "templates" / feature.name template_dir.mkdir(parents=True, exist_ok=True) - # Build Jinja2 environment with project-first resolution env = _build_jinja_env(help_path) if ( @@ -284,9 +361,7 @@ def generate_feature_templates( ", ".join(feature.doc_paths[1:]), ) - # Phase 1: render all templates (fast Jinja2, sequential). - # Determines which depths are active and builds the rendered skeleton. 
def apply_polish_results(
    prep: PolishPreparation,
    polished_by_depth: dict[str, str],
) -> GenerationResult:
    """Phase 3: write each pending template's final content to disk.

    ``polished_by_depth`` maps depth → polished markdown. A depth
    absent from the map falls back to its original rendered content —
    the batch path relies on this when a per-request failure left us
    without a polished version, mirroring the synchronous path's
    "write the raw template on lenient-mode failure" behavior.
    """
    feature = prep.feature
    outcome = GenerationResult(
        feature=feature.name,
        source_hash=prep.source_hash,
        matched_files=list(prep.matched_files),
    )
    for item in prep.pending:
        # Polished text wins; otherwise keep the Phase-1 render.
        text = polished_by_depth.get(item.depth, item.rendered_content)
        item.out_path.write_text(text, encoding="utf-8")
        outcome.templates.append(
            GeneratedTemplate(
                feature=feature.name,
                depth=item.depth,
                path=item.out_path,
                source_hash=prep.source_hash,
            )
        )
    return outcome
The
+synchronous default path lives there; this module owns the
+opt-in batch path triggered by ``attune-author regenerate --batch``.
+
+Two phases that run in two separate CLI invocations:
+
+- **Submit** — :func:`submit_maintenance_batch` resolves stale
+  features, builds a list of polish prompts, posts a single
+  batch to Anthropic, and persists the resulting batch ID + a
+  per-request manifest to ``.help/.batch-state.json``.
+
+- **Resume** — :func:`resume_maintenance_batch` loads that
+  state file, polls Anthropic until the batch terminates,
+  splices results back into per-feature ``GenerationResult``
+  records, writes templates, and deletes the state file.
+
+Connecting the phases is the on-disk state file. We do NOT
+persist *partial* polling progress; if a resume polling loop is
+killed mid-flight, the next ``--resume`` re-polls from where
+Anthropic says the batch is. Anthropic's view is the source of
+truth.
+
+See ``specs/author-batch-maintain/design.md`` for the full
+state-file layout and resume-ergonomics design choices.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+from dataclasses import dataclass
+from datetime import datetime, timedelta, timezone
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from attune_author.generator import PolishPreparation
+
+from attune_author.doc_gen._anthropic_batch import (
+    BatchPolishRequest,
+    BatchPolishResult,
+    adaptive_timeout_secs,
+    poll_batch,
+    submit_batch,
+)
+
+logger = logging.getLogger(__name__)
+
+
+#: Filename under ``<help_dir>/`` where pending-batch metadata is stored.
+STATE_FILENAME = ".batch-state.json"
+
+#: Schema version for the state file. Bump if the on-disk shape changes
+#: in an incompatible way; older files surface :class:`BatchStateExpired`
+#: with a clean cleanup message.
+STATE_SCHEMA_VERSION = 1
+
+#: Anthropic retains batch results for 29 days. After that, ``retrieve``
+#: returns 404 and ``results`` returns nothing.
@dataclass(frozen=True)
class BatchStateRequest:
    """One (custom_id, feature, depth) entry in the state file's ``requests``."""

    custom_id: str
    feature: str
    depth: str

    def to_dict(self) -> dict[str, str]:
        """Serialize to the JSON shape stored in the state file."""
        return {
            "custom_id": self.custom_id,
            "feature": self.feature,
            "depth": self.depth,
        }

    @classmethod
    def from_dict(cls, raw: dict[str, str]) -> BatchStateRequest:
        """Rebuild an entry from its state-file JSON shape."""
        return cls(
            custom_id=raw["custom_id"],
            feature=raw["feature"],
            depth=raw["depth"],
        )


@dataclass(frozen=True)
class BatchState:
    """Persisted metadata for a batch that was submitted but not yet resumed.

    Written at submit time and removed when a resume completes
    (whether with success or permanent failure). A poll-loop timeout
    deliberately leaves the file in place so the user's next
    ``--resume`` picks up cleanly.
    """

    schema_version: int
    batch_id: str
    submitted_at: datetime
    expected_completion_at: datetime
    model: str
    requests: tuple[BatchStateRequest, ...]

    def to_dict(self) -> dict[str, Any]:
        """Serialize to the JSON document written to the state file."""
        return {
            "schema_version": self.schema_version,
            "batch_id": self.batch_id,
            "submitted_at": self.submitted_at.isoformat(),
            "expected_completion_at": self.expected_completion_at.isoformat(),
            "model": self.model,
            "requests": [entry.to_dict() for entry in self.requests],
        }

    @classmethod
    def from_dict(cls, raw: dict[str, Any]) -> BatchState:
        """Rebuild state from its JSON document.

        Timestamps go through :func:`_parse_iso`, which tolerates a
        trailing "Z" (Python 3.11+ accepts it natively; we normalize
        for 3.10 compatibility).
        """
        return cls(
            schema_version=int(raw["schema_version"]),
            batch_id=str(raw["batch_id"]),
            submitted_at=_parse_iso(raw["submitted_at"]),
            expected_completion_at=_parse_iso(raw["expected_completion_at"]),
            model=str(raw["model"]),
            requests=tuple(
                BatchStateRequest.from_dict(entry) for entry in raw.get("requests", [])
            ),
        )

    def to_polish_requests(
        self,
        prompt_lookup: dict[tuple[str, str], BatchPolishRequest],
    ) -> list[BatchPolishRequest]:
        """Reconstruct ``BatchPolishRequest`` objects for resume polling.

        Only (custom_id, feature, depth) triples are persisted — the
        actual prompts are reproducible from the manifest at resume
        time. ``prompt_lookup`` maps ``(feature, depth)`` to the
        freshly rebuilt request, and this list preserves the submitted
        order. A missing lookup entry (the manifest changed between
        submit and resume) yields an empty placeholder request, which
        surfaces downstream as ``missing_in_results`` rather than
        crashing here.
        """
        reconstructed: list[BatchPolishRequest] = []
        for entry in self.requests:
            found = prompt_lookup.get((entry.feature, entry.depth))
            if found is None:
                found = BatchPolishRequest(
                    custom_id=entry.custom_id,
                    feature=entry.feature,
                    depth=entry.depth,
                    system="",
                    user_message="",
                    model=self.model,
                    max_tokens=0,
                )
            reconstructed.append(found)
        return reconstructed
+ """ + target = state_path(help_dir) + if not target.exists(): + raise BatchStateNotFound(f"no pending batch at {target}") + + try: + raw = json.loads(target.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise BatchStateExpired( + f"state file at {target} is corrupt; delete it and rerun --batch: {exc}" + ) from None + + schema_version = int(raw.get("schema_version", 0)) + if schema_version != STATE_SCHEMA_VERSION: + raise BatchStateExpired( + f"state file at {target} has unsupported schema_version " + f"{schema_version}; delete it and rerun --batch" + ) + + state = BatchState.from_dict(raw) + + horizon = (now or _now_utc()) - timedelta(days=BATCH_RETENTION_DAYS) + if state.submitted_at < horizon: + raise BatchStateExpired( + f"batch submitted {state.submitted_at.isoformat()} is older than " + f"Anthropic's {BATCH_RETENTION_DAYS}-day retention window; " + f"delete {target} and rerun --batch" + ) + + return state + + +def delete_batch_state(help_dir: str | Path) -> None: + """Remove the state file, no-op if it doesn't exist.""" + target = state_path(help_dir) + if target.exists(): + target.unlink() + + +def has_pending_batch(help_dir: str | Path) -> bool: + """Lightweight check for a state file's existence. + + Used by the bare ``maintain`` command to surface a one-liner + hint without parsing the whole file. Does NOT raise on stale + state — that surfaces only when a real load is attempted. 
+ """ + return state_path(help_dir).exists() + + +# --- helpers --------------------------------------------------------------- + + +def _parse_iso(value: str) -> datetime: + """Parse an ISO-8601 timestamp, normalizing the trailing Z.""" + if value.endswith("Z"): + value = value[:-1] + "+00:00" + parsed = datetime.fromisoformat(value) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed + + +def _now_utc() -> datetime: + return datetime.now(timezone.utc) + + +# --- prompt construction --------------------------------------------------- + + +def custom_id_for(feature: str, depth: str) -> str: + """Stable Anthropic ``custom_id`` for one (feature, depth) request.""" + return f"feat__{feature}__{depth}" + + +def _build_polish_request( + feature_name: str, + depth: str, + rendered_content: str, + source_summary: str, + augmented_context: str | None, +) -> BatchPolishRequest: + """Build a :class:`BatchPolishRequest` for one polish call. + + Uses the same ``build_polish_prompt`` helper the synchronous + path calls, so the wire-level prompts are byte-identical. + """ + from attune_author.polish import ( + _POLISH_MODEL, + POLISH_CACHE_SYSTEM, + POLISH_MAX_TOKENS, + build_polish_prompt, + ) + + system, user = build_polish_prompt( + rendered_content, + feature_name, + source_summary, + depth, + augmented_context=augmented_context, + ) + return BatchPolishRequest( + custom_id=custom_id_for(feature_name, depth), + feature=feature_name, + depth=depth, + system=system, + user_message=user, + model=_POLISH_MODEL, + max_tokens=POLISH_MAX_TOKENS, + cache_system=POLISH_CACHE_SYSTEM, + ) + + +def _collect_polish_prompts( + stale_features: list, + help_dir: str | Path, + project_root: str | Path, + *, + use_rag: bool = True, +) -> tuple[list[BatchPolishRequest], dict[str, PolishPreparation]]: + """Render templates and build polish requests for stale features. 
+ + Returns ``(requests, prep_by_feature)``: + + - ``requests`` is the flat list of all polish requests across + every stale feature, in feature×depth order. + - ``prep_by_feature`` maps feature name to the + :class:`PolishPreparation` used at submit time, so the + resume splicer can re-derive the same out_paths and + depth ordering. + + Pure of LLM calls. The RAG grounding pass DOES happen here + (mirroring the synchronous path's ``_maybe_polish``) — we + bake the augmented_context into the user message so the + batch result is comparable to a synchronous polish call. + """ + from attune_author.generator import prepare_polish_phase + from attune_author.polish import build_source_summary + + requests: list[BatchPolishRequest] = [] + prep_by_feature: dict[str, PolishPreparation] = {} + + for feature in stale_features: + prep = prepare_polish_phase( + feature=feature, + help_dir=help_dir, + project_root=project_root, + use_rag=use_rag, + ) + prep_by_feature[feature.name] = prep + if not prep.pending: + continue + + # Mirror _maybe_polish: build source summary + (optional) RAG context + # once per feature, reuse across depths. 
def submit_maintenance_batch(
    help_dir: str | Path,
    project_root: str | Path,
    features: list[str] | None = None,
    *,
    force: bool = False,
    client: Any | None = None,
) -> BatchState:
    """SUBMIT phase: build prompts, post one batch, persist state.

    Resolves the stale features (or the explicit ``features`` list),
    renders every active template, builds the polish requests, posts
    a single Anthropic batch, and writes the state file. The
    persisted :class:`BatchState` is returned so the CLI can print a
    confirmation message.

    Raises:
        BatchAlreadyPending: A state file already exists. Pass
            ``force=True`` to overwrite it (or run ``--cancel``
            first to release the existing batch).
        ValueError: Nothing is stale, or every feature was skipped.
    """
    if has_pending_batch(help_dir) and not force:
        raise BatchAlreadyPending(
            f"pending batch already submitted; state file at {state_path(help_dir)}. "
            "Run --resume, --cancel, or pass --force to overwrite."
        )

    # Local import avoids a module-level cycle with maintenance.py.
    from attune_author.maintenance import _resolve_stale_features

    stale = _resolve_stale_features(help_dir, project_root, features)
    if not stale:
        raise ValueError("no stale features to regenerate")

    requests, _ = _collect_polish_prompts(stale, help_dir=help_dir, project_root=project_root)
    if not requests:
        raise ValueError("no polish requests to submit (all features skipped)")

    batch_id = submit_batch(requests, client=client)
    now = _now_utc()
    eta = now + _eta_for(len(requests))
    state = BatchState(
        schema_version=STATE_SCHEMA_VERSION,
        batch_id=batch_id,
        submitted_at=now,
        expected_completion_at=eta,
        model=requests[0].model,
        requests=tuple(
            BatchStateRequest(custom_id=r.custom_id, feature=r.feature, depth=r.depth)
            for r in requests
        ),
    )
    write_batch_state(help_dir, state)
    logger.info(
        "submitted batch %s (%d requests, eta %s)",
        batch_id,
        len(requests),
        eta.isoformat(),
    )
    return state


def _eta_for(num_requests: int) -> timedelta:
    """Completion-time estimate derived from the adaptive poll timeout."""
    return timedelta(seconds=adaptive_timeout_secs(num_requests))
+ + Per-feature failure isolation: if any depth of a feature's + requests errored in the batch, the entire feature is marked + failed (no templates written). Other features in the same + batch are written normally. + """ + state = read_batch_state(help_dir) + # Re-resolve current stale features so we know what's currently expected. + # We trust the state file's request list as the authoritative "what we + # asked for"; current-staleness is only for the report. + from attune_author.manifest import load_manifest as _load_manifest + from attune_author.staleness import check_staleness as _check_staleness + + manifest = _load_manifest(help_dir) + stale_features_in_state = [ + manifest.features[f] for f in {r.feature for r in state.requests} if f in manifest.features + ] + _, prep_by_feature = _collect_polish_prompts( + stale_features_in_state, + help_dir=help_dir, + project_root=project_root, + ) + + # Re-build the request list so poll_batch can map results in the + # same order as the state file (the on-the-wire entries have the + # actual prompts; we don't need them at resume time, but we need + # the request list to hold custom_ids and order). + expected = [ + BatchPolishRequest( + custom_id=r.custom_id, + feature=r.feature, + depth=r.depth, + system="", + user_message="", + model=state.model, + max_tokens=0, + ) + for r in state.requests + ] + + final_status, results = poll_batch( + state.batch_id, + expected, + client=client, + poll_interval_secs=poll_interval_secs, + timeout_secs=timeout_secs, + on_progress=on_progress, + ) + + # Group results by feature so we can apply per-feature isolation. 
+ per_feature: dict[str, list[BatchPolishResult]] = {} + for r in results: + per_feature.setdefault(r.feature, []).append(r) + + report_features = stale_features_in_state + report = _check_staleness(manifest, help_dir, project_root, [f.name for f in report_features]) + rb = MaintenanceResultBatch( + staleness=report, + final_status=final_status, + batch_id=state.batch_id, + ) + + for feature_name, results_for_feat in per_feature.items(): + prep = prep_by_feature.get(feature_name) + if prep is None: + # Manifest changed between submit and resume. We can't write + # anything sensibly. Mark the feature failed and continue. + rb.failed.append(feature_name) + logger.warning( + "Feature '%s' was in batch but is no longer in manifest; skipping", + feature_name, + ) + continue + + if any(r.error for r in results_for_feat): + rb.failed.append(feature_name) + for r in results_for_feat: + if r.error: + logger.warning( + "Polish failed for %s/%s: %s", + r.feature, + r.depth, + r.error, + ) + continue + + polished = {r.depth: r.text or "" for r in results_for_feat} + from attune_author.generator import apply_polish_results + + gen_result = apply_polish_results(prep, polished) + rb.regenerated.append(gen_result) + + if final_status == "timed_out": + # Keep state file. User can run --resume again later. + logger.info( + "Batch %s still running (%d/%d done); state file kept for re-resume", + state.batch_id, + sum(1 for r in results if r.text is not None), + len(results), + ) + else: + delete_batch_state(help_dir) + + return rb + + +# --- status & cancel ------------------------------------------------------- + + +def status_maintenance_batch( + help_dir: str | Path, + *, + client: Any | None = None, +) -> dict[str, Any]: + """STATUS phase: one-shot query, no polling. + + Returns a JSON-friendly dict with both the persisted local + state and Anthropic's current view of the batch. 
@dataclass
class MaintenanceResultBatch:
    """Result of a batch maintenance run.

    Mirrors :class:`MaintenanceResult` so downstream consumers
    (status reports, hooks) can treat the two interchangeably. Adds
    ``final_status`` and ``batch_id`` so callers can render timeouts
    and partial-completion states correctly.
    """

    # StalenessReport — typed loosely to avoid an import cycle.
    staleness: Any
    final_status: str = "ended"
    batch_id: str = ""
    # GenerationResult entries for features whose templates were written.
    # default_factory replaces the previous ``= None`` + __post_init__
    # workaround for the mutable-default pitfall (and drops the
    # ``type: ignore`` the mis-typed None default required).
    regenerated: list = field(default_factory=list)
    # Names of features where any depth errored (no templates written).
    failed: list = field(default_factory=list)

    def __post_init__(self) -> None:
        # Tolerate explicit ``None`` from callers written against the
        # old signature, which accepted and coerced it.
        if self.regenerated is None:
            self.regenerated = []
        if self.failed is None:
            self.failed = []

    @property
    def stale_count(self) -> int:
        """Number of stale features reported at resume time."""
        return self.staleness.stale_count

    @property
    def regenerated_count(self) -> int:
        """Number of features whose templates were successfully rewritten."""
        return len(self.regenerated)
+ """ + system_prompt = get_system_prompt(template_type) + grounding = augmented_context or "" + user_message = ( + f"Polish this auto-generated {template_type} template " + f"for the '{feature_name}' feature.\n\n" + f"{grounding}" + f"## Source info (for accuracy checking)\n\n" + f"{source_summary}\n\n" + f"## Template to polish\n\n" + f"{content}" + ) + return system_prompt, user_message + + +#: Polish-call defaults. Exposed as constants so the batch path +#: can build identical ``BatchPolishRequest`` envelopes without +#: hard-coding values that the synchronous path might drift away +#: from. ``cache_system=True`` because the polish system prompt is +#: ~6000 tokens — well over the 1024-token sonnet caching threshold; +#: cache hits cut input cost ~90% on repeated polish passes. +POLISH_MAX_TOKENS = 4096 +POLISH_CACHE_SYSTEM = True + + def _call_llm( content: str, feature_name: str, @@ -403,17 +446,12 @@ def _call_llm( the SDK call fails. """ client = get_client() - system_prompt = get_system_prompt(template_type) - - grounding = augmented_context or "" - user_message = ( - f"Polish this auto-generated {template_type} template " - f"for the '{feature_name}' feature.\n\n" - f"{grounding}" - f"## Source info (for accuracy checking)\n\n" - f"{source_summary}\n\n" - f"## Template to polish\n\n" - f"{content}" + system_prompt, user_message = build_polish_prompt( + content, + feature_name, + source_summary, + template_type, + augmented_context=augmented_context, ) polished = call_anthropic( @@ -421,11 +459,8 @@ def _call_llm( system=system_prompt, user_message=user_message, model=_POLISH_MODEL, - max_tokens=4096, - # Polish system prompt is ~6000 tokens — well over the 1024-token - # threshold for sonnet caching. Cache hits cut input cost ~90% on - # repeated polish passes (template regen, retries, etc.). 
- cache_system=True, + max_tokens=POLISH_MAX_TOKENS, + cache_system=POLISH_CACHE_SYSTEM, ) return polished or content diff --git a/tests/golden/batch_response.json b/tests/golden/batch_response.json new file mode 100644 index 0000000..31b4a31 --- /dev/null +++ b/tests/golden/batch_response.json @@ -0,0 +1,65 @@ +{ + "_comment": "Recorded shape mirroring an Anthropic Messages Batches retrieve() response after the batch has terminated. SDK 0.89.0. Used by tests/test_anthropic_batch.py to lock the parser against the real shape; when the SDK shape shifts, this fixture's parser-vs-types parity test fails first.", + "id": "msgbatch_01abc123", + "type": "message_batch", + "processing_status": "ended", + "request_counts": { + "canceled": 0, + "errored": 1, + "expired": 0, + "processing": 0, + "succeeded": 2 + }, + "ended_at": "2026-05-08T18:46:00Z", + "expires_at": "2026-05-09T18:35:00Z", + "created_at": "2026-05-08T18:35:00Z", + "cancel_initiated_at": null, + "results_url": "https://api.anthropic.com/v1/messages/batches/msgbatch_01abc123/results", + "_results": [ + { + "custom_id": "feat__security-audit__concept", + "result": { + "type": "succeeded", + "message": { + "id": "msg_aaa", + "type": "message", + "role": "assistant", + "content": [ + {"type": "text", "text": "Polished concept template body for security-audit."} + ], + "model": "claude-sonnet-4-6", + "stop_reason": "end_turn" + } + } + }, + { + "custom_id": "feat__security-audit__task", + "result": { + "type": "succeeded", + "message": { + "id": "msg_bbb", + "type": "message", + "role": "assistant", + "content": [ + {"type": "text", "text": "Polished task template body for security-audit."} + ], + "model": "claude-sonnet-4-6", + "stop_reason": "end_turn" + } + } + }, + { + "custom_id": "feat__smart-test__concept", + "result": { + "type": "errored", + "error": { + "type": "error", + "error": { + "type": "rate_limit_error", + "message": "Rate limit exceeded for batch processing." 
+ } + } + } + } + ] +} diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_batch_live.py b/tests/integration/test_batch_live.py new file mode 100644 index 0000000..dd4eafd --- /dev/null +++ b/tests/integration/test_batch_live.py @@ -0,0 +1,81 @@ +"""Live integration test for the Anthropic Message Batches API. + +Default-skipped. To run: + + ANTHROPIC_API_KEY=sk-ant-... RUN_LIVE_BATCH=1 \\ + pytest tests/integration/test_batch_live.py -m live + +Costs ~$0.02 per run (two batched polish calls). Asserts: + +- ``submit_batch`` returns a real batch id. +- ``poll_batch`` reaches ``ended`` (or marks as ``timed_out`` if + Anthropic is slow today). +- Each result has either non-empty text or a structured error. + +Tags ``@pytest.mark.live`` so projects that allow some live tests +in CI can opt in selectively. +""" + +from __future__ import annotations + +import os + +import pytest + +from attune_author.doc_gen._anthropic_batch import ( + BatchPolishRequest, + poll_batch, + submit_batch, +) + +pytestmark = [ + pytest.mark.live, + pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY") or os.environ.get("RUN_LIVE_BATCH") != "1", + reason="Requires ANTHROPIC_API_KEY and RUN_LIVE_BATCH=1.", + ), +] + + +def _request(custom_id: str, query: str) -> BatchPolishRequest: + """Tiny prompt so the live call is cheap. + + System prompt is short on purpose — keeps token cost minimal + and exercises the wire format without polishing real content. + """ + return BatchPolishRequest( + custom_id=custom_id, + feature="test", + depth="concept", + system="Reply with a one-word sentence answering the user.", + user_message=query, + model="claude-haiku-4-5-20251001", + max_tokens=64, + cache_system=False, + ) + + +def test_submit_then_poll_two_request_batch() -> None: + requests = [ + _request("req1", "What color is the sky? Answer in one word."), + _request("req2", "What color is grass? 
Answer in one word."), + ] + + batch_id = submit_batch(requests) + assert batch_id.startswith("msgbatch_") + + final_status, results = poll_batch( + batch_id, + requests, + # Live batch latency is hard to predict; give it 5 minutes. + timeout_secs=300.0, + poll_interval_secs=10.0, + ) + + assert final_status in {"ended", "timed_out"} + assert len(results) == 2 + if final_status == "ended": + # Each request should have either text or a structured error, + # never both, never neither. + for r in results: + assert (r.text is None) ^ (r.error is None) diff --git a/tests/test_anthropic_batch.py b/tests/test_anthropic_batch.py new file mode 100644 index 0000000..96ede8a --- /dev/null +++ b/tests/test_anthropic_batch.py @@ -0,0 +1,434 @@ +"""Unit tests for the Anthropic Message Batches API wrapper. + +Three surfaces covered: + +1. **Payload shape** — :func:`_build_request_param` produces the + exact wire format the SDK expects, including the + ``cache_control`` wrap when ``cache_system=True``. +2. **Submit** — :func:`submit_batch` posts the rendered payload + and returns the batch id from the SDK response. +3. **Poll loop** — :func:`poll_batch` covers the full state + matrix: in_progress→ended, ended-with-errors, canceled, + timed_out, ctrl-c-during-poll. All driven by a stub SDK; no + live calls. + +The recorded fixture lives at ``tests/golden/batch_response.json`` +and mirrors the SDK's ``MessageBatch`` + ``MessageBatchIndividual +Response`` shape. 
+""" + +from __future__ import annotations + +import json +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from attune_author.doc_gen._anthropic_batch import ( + BatchPolishRequest, + BatchPolishResult, + _build_request_param, + adaptive_timeout_secs, + poll_batch, + submit_batch, +) + +_FIXTURE = Path(__file__).resolve().parent / "golden" / "batch_response.json" + + +def _req( + custom_id: str, + feature: str = "f", + depth: str = "concept", + *, + cache_system: bool = False, +) -> BatchPolishRequest: + return BatchPolishRequest( + custom_id=custom_id, + feature=feature, + depth=depth, + system="sys-prompt", + user_message="user-msg", + model="claude-sonnet-4-6", + max_tokens=2048, + cache_system=cache_system, + ) + + +def _ns(**kwargs: object) -> SimpleNamespace: + return SimpleNamespace(**kwargs) + + +def _fixture_results() -> list[SimpleNamespace]: + """Translate the fixture's ``_results`` array into SimpleNamespaces. + + Mirrors the runtime shape: each entry has ``custom_id`` and + ``result``; ``result`` carries ``type`` plus a discriminated + payload. The same recursive walk would work for any future + SDK shape change. 
+ """ + raw = json.loads(_FIXTURE.read_text()) + out: list[SimpleNamespace] = [] + for entry in raw["_results"]: + result_dict = entry["result"] + rtype = result_dict["type"] + if rtype == "succeeded": + msg_dict = result_dict["message"] + content = [_ns(**block) for block in msg_dict.get("content", [])] + message = _ns(content=content, **{k: v for k, v in msg_dict.items() if k != "content"}) + result = _ns(type=rtype, message=message) + elif rtype == "errored": + err_outer = result_dict["error"] + inner = err_outer.get("error", {}) or {} + error = _ns( + type=err_outer.get("type"), + error=_ns(**inner), + message=inner.get("message"), + ) + result = _ns(type=rtype, error=error) + else: + result = _ns(type=rtype) + out.append(_ns(custom_id=entry["custom_id"], result=result)) + return out + + +def _fixture_batch(processing_status: str = "ended") -> SimpleNamespace: + raw = json.loads(_FIXTURE.read_text()) + return _ns( + id=raw["id"], + processing_status=processing_status, + request_counts=_ns(**raw["request_counts"]), + ended_at=raw["ended_at"], + expires_at=raw["expires_at"], + created_at=raw["created_at"], + cancel_initiated_at=raw["cancel_initiated_at"], + results_url=raw["results_url"], + ) + + +def _stub_client( + *, + batch_states: list[SimpleNamespace] | None = None, + results: list[SimpleNamespace] | None = None, + create_id: str = "msgbatch_01abc123", +) -> MagicMock: + """Build a stub SDK client. + + ``batch_states`` is the sequence ``retrieve`` returns on + successive calls (so a poll loop can transition through + in_progress → ended). ``results`` is the list ``results()`` + yields on terminal state. 
+ """ + client = MagicMock() + client.messages.batches.create.return_value = _ns(id=create_id) + if batch_states: + client.messages.batches.retrieve.side_effect = batch_states + if results is not None: + client.messages.batches.results.return_value = iter(results) + return client + + +# --- payload shape ---------------------------------------------------------- + + +def test_build_request_param_default_no_cache() -> None: + p = _build_request_param(_req("feat__a__concept")) + assert p["custom_id"] == "feat__a__concept" + assert p["params"]["model"] == "claude-sonnet-4-6" + assert p["params"]["max_tokens"] == 2048 + assert p["params"]["system"] == "sys-prompt" + assert p["params"]["messages"] == [{"role": "user", "content": "user-msg"}] + + +def test_build_request_param_with_cache_wraps_system_block() -> None: + p = _build_request_param(_req("feat__a__concept", cache_system=True)) + assert p["params"]["system"] == [ + { + "type": "text", + "text": "sys-prompt", + "cache_control": {"type": "ephemeral"}, + } + ] + + +# --- submit ----------------------------------------------------------------- + + +def test_submit_batch_rejects_empty_requests() -> None: + with pytest.raises(ValueError, match="at least one request"): + submit_batch([], client=MagicMock()) + + +def test_submit_batch_renders_payload_and_returns_id() -> None: + client = _stub_client(create_id="msgbatch_xyz") + batch_id = submit_batch( + [ + _req("feat__a__concept", feature="a"), + _req("feat__b__task", feature="b", depth="task"), + ], + client=client, + ) + assert batch_id == "msgbatch_xyz" + sent = client.messages.batches.create.call_args.kwargs["requests"] + assert len(sent) == 2 + assert sent[0]["custom_id"] == "feat__a__concept" + assert sent[1]["custom_id"] == "feat__b__task" + + +# --- poll loop -------------------------------------------------------------- + + +def test_poll_terminal_ended_with_results_splices_in_input_order() -> None: + """The order of ``BatchPolishResult`` matches the order of 
the + ``expected`` argument, regardless of what order Anthropic + returns results in. + """ + expected = [ + _req("feat__security-audit__concept", feature="security-audit"), + _req("feat__security-audit__task", feature="security-audit", depth="task"), + _req("feat__smart-test__concept", feature="smart-test"), + ] + client = _stub_client( + batch_states=[_fixture_batch("ended")], + results=list(reversed(_fixture_results())), # out-of-order on the wire + ) + status, results = poll_batch( + "msgbatch_01abc123", + expected, + client=client, + _now=lambda: 0.0, + _sleep=lambda _: None, + ) + assert status == "ended" + assert [r.custom_id for r in results] == [r.custom_id for r in expected] + assert results[0].text == "Polished concept template body for security-audit." + assert results[1].text == "Polished task template body for security-audit." + assert results[2].text is None + assert "Rate limit" in (results[2].error or "") + + +def test_poll_in_progress_then_ended_advances_through_states() -> None: + expected = [_req("feat__a__concept", feature="a")] + in_progress = _fixture_batch("in_progress") + in_progress.request_counts = _ns(canceled=0, errored=0, expired=0, processing=1, succeeded=0) + ended = _fixture_batch("ended") + ended.request_counts = _ns(canceled=0, errored=0, expired=0, processing=0, succeeded=1) + client = _stub_client( + batch_states=[in_progress, ended], + results=[_fixture_results()[0]], + ) + sleeps: list[float] = [] + status, results = poll_batch( + "msgbatch_01abc123", + expected, + client=client, + poll_interval_secs=12.5, + _now=lambda: 0.0, + _sleep=sleeps.append, + ) + assert status == "ended" + assert sleeps == [12.5] # exactly one sleep between two poll cycles + assert client.messages.batches.retrieve.call_count == 2 + + +def test_poll_canceled_status_returns_canceled() -> None: + expected = [_req("feat__a__concept", feature="a")] + canceled = _fixture_batch("canceled") + client = _stub_client(batch_states=[canceled], results=[]) + 
status, _ = poll_batch( + "msgbatch_01abc123", + expected, + client=client, + _now=lambda: 0.0, + _sleep=lambda _: None, + ) + assert status == "canceled" + + +def test_poll_cancel_initiated_treated_as_canceled() -> None: + """Even with processing_status='ended', a non-null + cancel_initiated_at means the user canceled mid-flight.""" + expected = [_req("feat__a__concept", feature="a")] + batch = _fixture_batch("ended") + batch.cancel_initiated_at = "2026-05-08T18:40:00Z" + client = _stub_client(batch_states=[batch], results=_fixture_results()[:1]) + status, _ = poll_batch( + "msgbatch_01abc123", + expected, + client=client, + _now=lambda: 0.0, + _sleep=lambda _: None, + ) + assert status == "canceled" + + +def test_poll_expired_count_surfaces_expired_status() -> None: + expected = [_req("feat__a__concept", feature="a")] + batch = _fixture_batch("ended") + batch.request_counts = _ns(canceled=0, errored=0, expired=1, processing=0, succeeded=0) + client = _stub_client(batch_states=[batch], results=[]) + status, _ = poll_batch( + "msgbatch_01abc123", + expected, + client=client, + _now=lambda: 0.0, + _sleep=lambda _: None, + ) + assert status == "expired" + + +def test_poll_timeout_returns_timed_out_with_unfinished_marked() -> None: + """When the poll-loop ceiling fires before the batch is + terminal, every unfinished entry surfaces as + ``error="timed_out"`` so the orchestration layer can mark + them in ``MaintenanceResult.failed``. 
+ """ + expected = [ + _req("feat__a__concept", feature="a"), + _req("feat__b__concept", feature="b"), + ] + in_progress = _fixture_batch("in_progress") + client = _stub_client(batch_states=[in_progress, in_progress, in_progress], results=[]) + ticks = iter([0.0, 100.0, 200.0]) + status, results = poll_batch( + "msgbatch_01abc123", + expected, + client=client, + timeout_secs=50.0, + poll_interval_secs=10.0, + _now=lambda: next(ticks), + _sleep=lambda _: None, + ) + assert status == "timed_out" + assert all(r.text is None and r.error == "timed_out" for r in results) + assert [r.custom_id for r in results] == [r.custom_id for r in expected] + + +def test_poll_ctrl_c_cancels_batch_and_re_raises() -> None: + """SIGINT mid-poll calls cancel and re-raises so the CLI can + clean up its state file. Cancel failures must not mask the + KeyboardInterrupt.""" + expected = [_req("feat__a__concept", feature="a")] + client = MagicMock() + client.messages.batches.retrieve.side_effect = KeyboardInterrupt + with pytest.raises(KeyboardInterrupt): + poll_batch( + "msgbatch_01abc123", + expected, + client=client, + _now=lambda: 0.0, + _sleep=lambda _: None, + ) + client.messages.batches.cancel.assert_called_once_with("msgbatch_01abc123") + + +def test_poll_ctrl_c_cancel_failure_does_not_mask_interrupt() -> None: + expected = [_req("feat__a__concept", feature="a")] + client = MagicMock() + client.messages.batches.retrieve.side_effect = KeyboardInterrupt + client.messages.batches.cancel.side_effect = RuntimeError("network down") + with pytest.raises(KeyboardInterrupt): + poll_batch( + "msgbatch_01abc123", + expected, + client=client, + _now=lambda: 0.0, + _sleep=lambda _: None, + ) + + +def test_poll_missing_result_for_request_marks_missing() -> None: + """If results() returns an entry for an unknown custom_id (or + omits one), the affected request becomes + ``error="missing_in_results"`` rather than crashing.""" + expected = [ + _req("feat__security-audit__concept", 
feature="security-audit"), + _req("feat__missing__concept", feature="missing"), + ] + client = _stub_client( + batch_states=[_fixture_batch("ended")], + results=[_fixture_results()[0]], # only the security-audit entry + ) + _, results = poll_batch( + "msgbatch_01abc123", + expected, + client=client, + _now=lambda: 0.0, + _sleep=lambda _: None, + ) + assert results[0].text is not None + assert results[1].text is None + assert results[1].error == "missing_in_results" + + +def test_poll_progress_callback_invoked_each_retrieve() -> None: + expected = [_req("feat__a__concept", feature="a")] + client = _stub_client( + batch_states=[_fixture_batch("ended")], + results=[_fixture_results()[0]], + ) + seen: list[object] = [] + poll_batch( + "msgbatch_01abc123", + expected, + client=client, + on_progress=seen.append, + _now=lambda: 0.0, + _sleep=lambda _: None, + ) + assert len(seen) == 1 + + +def test_poll_unknown_result_type_does_not_crash() -> None: + expected = [_req("feat__sample__concept", feature="sample")] + client = _stub_client( + batch_states=[_fixture_batch("ended")], + results=[ + _ns( + custom_id="feat__sample__concept", + result=_ns(type="some_future_status"), + ) + ], + ) + status, results = poll_batch( + "msgbatch_01abc123", + expected, + client=client, + _now=lambda: 0.0, + _sleep=lambda _: None, + ) + assert status == "ended" + assert results[0].error == "unknown_result_type:some_future_status" + + +# --- adaptive_timeout_secs -------------------------------------------------- + + +def test_adaptive_timeout_floor_for_small_batches() -> None: + assert adaptive_timeout_secs(0) == 5 * 60 + assert adaptive_timeout_secs(1) == 5 * 60 + assert adaptive_timeout_secs(50) == 5 * 60 + + +def test_adaptive_timeout_grows_linearly_above_floor() -> None: + assert adaptive_timeout_secs(200) == pytest.approx(600.0) + + +def test_adaptive_timeout_capped_at_30_minutes() -> None: + assert adaptive_timeout_secs(10_000) == 30 * 60 + + +# --- BatchPolishResult contract 
-------------------------------------------- + + +def test_batch_polish_result_text_xor_error_invariant() -> None: + """Construction can produce either successful (text set, error + None) or failed (text None, error set) results — but the + parser layer never produces both, so this contract gives + downstream consumers a clean discriminator.""" + ok = BatchPolishResult(custom_id="x", feature="f", depth="d", text="hello", error=None) + bad = BatchPolishResult(custom_id="x", feature="f", depth="d", text=None, error="boom") + assert (ok.text is None) ^ (ok.error is None) + assert (bad.text is None) ^ (bad.error is None) diff --git a/tests/test_cli_batch.py b/tests/test_cli_batch.py new file mode 100644 index 0000000..cb21e31 --- /dev/null +++ b/tests/test_cli_batch.py @@ -0,0 +1,396 @@ +"""CLI tests for the batch flags on `attune-author regenerate`. + +Covers: + +- argparse mutex enforcement (--batch / --resume / --status / --cancel + are mutually exclusive; --force and --json are modifiers) +- submit prints batch ID + copy-paste-ready resume hint +- bare `regenerate` prints a one-liner when a state file exists, + but doesn't auto-resume +- --resume and --status surface clean errors when no state file +- --status JSON mode produces parseable output +- --cancel removes the state file + +Each test mocks the maintenance_batch entry points; nothing +reaches Anthropic. 
+""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from pathlib import Path +from unittest.mock import patch + +import pytest + +from attune_author.cli import _build_parser, main +from attune_author.maintenance_batch import ( + BatchState, + BatchStateRequest, + state_path, + write_batch_state, +) + + +def _state(batch_id: str = "msgbatch_test") -> BatchState: + return BatchState( + schema_version=1, + batch_id=batch_id, + submitted_at=datetime(2026, 5, 8, 18, 35, tzinfo=timezone.utc), + expected_completion_at=datetime(2026, 5, 8, 18, 41, tzinfo=timezone.utc), + model="claude-sonnet-4-20250514", + requests=( + BatchStateRequest("feat__auth__concept", "auth", "concept"), + BatchStateRequest("feat__auth__task", "auth", "task"), + ), + ) + + +# --- mutex / parser -------------------------------------------------------- + + +class TestMutex: + """Each batch-mode flag excludes the others; --force / --json modify.""" + + def test_batch_and_resume_rejected(self) -> None: + with pytest.raises(SystemExit): + _build_parser().parse_args(["regenerate", "--batch", "--resume"]) + + def test_status_and_cancel_rejected(self) -> None: + with pytest.raises(SystemExit): + _build_parser().parse_args(["regenerate", "--status", "--cancel"]) + + def test_batch_with_force_allowed(self) -> None: + args = _build_parser().parse_args(["regenerate", "--batch", "--force"]) + assert args.batch and args.force + + def test_status_with_json_allowed(self) -> None: + args = _build_parser().parse_args(["regenerate", "--status", "--json"]) + assert args.status and args.json + + +# --- submit ---------------------------------------------------------------- + + +class TestSubmit: + def test_prints_batch_id_and_resume_hint( + self, help_dir: Path, project_root: Path, capsys, monkeypatch + ) -> None: + monkeypatch.chdir(project_root) + with patch( + "attune_author.maintenance_batch.submit_maintenance_batch", + return_value=_state("msgbatch_xyz"), + ): + rc = 
main( + [ + "regenerate", + "--batch", + "--help-dir", + str(help_dir), + "--project-root", + str(project_root), + ] + ) + assert rc == 0 + out = capsys.readouterr().out + assert "msgbatch_xyz" in out + assert "regenerate --resume" in out + # Estimated completion should be rendered somewhere. + assert "completion" in out.lower() or "min" in out.lower() + + def test_already_pending_returns_error( + self, help_dir: Path, project_root: Path, capsys, monkeypatch + ) -> None: + monkeypatch.chdir(project_root) + write_batch_state(help_dir, _state()) + # No --force: should fail without ever calling submit_maintenance_batch. + rc = main( + [ + "regenerate", + "--batch", + "--help-dir", + str(help_dir), + "--project-root", + str(project_root), + ] + ) + err = capsys.readouterr().err + assert rc == 2 + assert "already" in err.lower() or "pending" in err.lower() + + +# --- resume ---------------------------------------------------------------- + + +class TestResume: + def test_no_state_returns_error( + self, help_dir: Path, project_root: Path, capsys, monkeypatch + ) -> None: + monkeypatch.chdir(project_root) + rc = main( + [ + "regenerate", + "--resume", + "--help-dir", + str(help_dir), + "--project-root", + str(project_root), + ] + ) + assert rc == 1 + assert "no pending batch" in capsys.readouterr().err.lower() + + def test_completed_batch_prints_summary( + self, help_dir: Path, project_root: Path, capsys, monkeypatch + ) -> None: + monkeypatch.chdir(project_root) + write_batch_state(help_dir, _state("msgbatch_done")) + + from attune_author.maintenance_batch import MaintenanceResultBatch + from attune_author.staleness import StalenessReport + + outcome = MaintenanceResultBatch( + staleness=StalenessReport(help_entries=[], doc_entries=[]), + final_status="ended", + batch_id="msgbatch_done", + ) + outcome.regenerated = ["fake-result"] # mimic 1 regenerated + with patch( + "attune_author.maintenance_batch.resume_maintenance_batch", + return_value=outcome, + ): + rc = main( + 
[ + "regenerate", + "--resume", + "--help-dir", + str(help_dir), + "--project-root", + str(project_root), + ] + ) + assert rc == 0 + out = capsys.readouterr().out + assert "msgbatch_done" in out + assert "ended" in out + + def test_timed_out_keeps_state_and_prints_hint( + self, help_dir: Path, project_root: Path, capsys, monkeypatch + ) -> None: + monkeypatch.chdir(project_root) + write_batch_state(help_dir, _state("msgbatch_inflight")) + + from attune_author.maintenance_batch import MaintenanceResultBatch + from attune_author.staleness import StalenessReport + + outcome = MaintenanceResultBatch( + staleness=StalenessReport(help_entries=[], doc_entries=[]), + final_status="timed_out", + batch_id="msgbatch_inflight", + ) + with patch( + "attune_author.maintenance_batch.resume_maintenance_batch", + return_value=outcome, + ): + rc = main( + [ + "regenerate", + "--resume", + "--help-dir", + str(help_dir), + "--project-root", + str(project_root), + ] + ) + assert rc == 0 + out = capsys.readouterr().out + assert "still running" in out.lower() + assert "regenerate --resume" in out + + +# --- status ---------------------------------------------------------------- + + +class TestStatus: + _info = { + "batch_id": "msgbatch_status", + "submitted_at": "2026-05-08T18:35:00+00:00", + "expected_completion_at": "2026-05-08T18:41:00+00:00", + "model": "claude-sonnet-4-20250514", + "request_count": 3, + "processing_status": "in_progress", + "request_counts": { + "succeeded": 1, + "errored": 0, + "expired": 0, + "canceled": 0, + "processing": 2, + }, + "ended_at": None, + "expires_at": "2026-05-09T18:35:00+00:00", + "cancel_initiated_at": None, + } + + def test_human_output(self, help_dir: Path, project_root: Path, capsys, monkeypatch) -> None: + monkeypatch.chdir(project_root) + with patch( + "attune_author.maintenance_batch.status_maintenance_batch", + return_value=self._info, + ): + rc = main( + [ + "regenerate", + "--status", + "--help-dir", + str(help_dir), + "--project-root", + 
str(project_root), + ] + ) + assert rc == 0 + out = capsys.readouterr().out + assert "msgbatch_status" in out + assert "in_progress" in out + assert "succeeded=1" in out + + def test_json_output_is_parseable( + self, help_dir: Path, project_root: Path, capsys, monkeypatch + ) -> None: + monkeypatch.chdir(project_root) + with patch( + "attune_author.maintenance_batch.status_maintenance_batch", + return_value=self._info, + ): + rc = main( + [ + "regenerate", + "--status", + "--json", + "--help-dir", + str(help_dir), + "--project-root", + str(project_root), + ] + ) + assert rc == 0 + parsed = json.loads(capsys.readouterr().out) + assert parsed["batch_id"] == "msgbatch_status" + assert parsed["request_counts"]["processing"] == 2 + + def test_no_pending_returns_error( + self, help_dir: Path, project_root: Path, capsys, monkeypatch + ) -> None: + monkeypatch.chdir(project_root) + rc = main( + [ + "regenerate", + "--status", + "--help-dir", + str(help_dir), + "--project-root", + str(project_root), + ] + ) + assert rc == 1 + assert "no pending batch" in capsys.readouterr().err.lower() + + +# --- cancel ---------------------------------------------------------------- + + +class TestCancel: + def test_cancel_removes_state( + self, help_dir: Path, project_root: Path, capsys, monkeypatch + ) -> None: + monkeypatch.chdir(project_root) + with patch( + "attune_author.maintenance_batch.cancel_maintenance_batch", + return_value="msgbatch_killed", + ): + rc = main( + [ + "regenerate", + "--cancel", + "--help-dir", + str(help_dir), + "--project-root", + str(project_root), + ] + ) + assert rc == 0 + assert "msgbatch_killed" in capsys.readouterr().out + + def test_cancel_no_state_returns_error( + self, help_dir: Path, project_root: Path, capsys, monkeypatch + ) -> None: + monkeypatch.chdir(project_root) + rc = main( + [ + "regenerate", + "--cancel", + "--help-dir", + str(help_dir), + "--project-root", + str(project_root), + ] + ) + assert rc == 1 + assert "no pending" in 
capsys.readouterr().err.lower() + + +# --- bare command pending-batch hint -------------------------------------- + + +class TestBareCommandHint: + def test_pending_batch_hint_printed_to_stderr( + self, help_dir: Path, project_root: Path, capsys, monkeypatch + ) -> None: + """Plain `regenerate` (no batch flags) prints a one-liner + when a state file exists, but does not auto-resume.""" + monkeypatch.chdir(project_root) + write_batch_state(help_dir, _state()) + # patch run_maintenance to a no-op so we just exercise the hint path + with ( + patch("attune_author.maintenance.run_maintenance") as mock_run, + patch( + "attune_author.maintenance_batch.has_pending_batch", + return_value=True, + ), + ): + mock_run.return_value.regenerated_count = 0 + mock_run.return_value.failed = [] + rc = main( + [ + "regenerate", + "--help-dir", + str(help_dir), + "--project-root", + str(project_root), + ] + ) + assert rc == 0 + err = capsys.readouterr().err + assert "pending batch" in err.lower() + assert "--resume" in err + # state file is NOT auto-deleted + assert state_path(help_dir).exists() + + def test_no_hint_when_no_pending_batch( + self, help_dir: Path, project_root: Path, capsys, monkeypatch + ) -> None: + monkeypatch.chdir(project_root) + with patch("attune_author.maintenance.run_maintenance") as mock_run: + mock_run.return_value.regenerated_count = 0 + mock_run.return_value.failed = [] + main( + [ + "regenerate", + "--help-dir", + str(help_dir), + "--project-root", + str(project_root), + ] + ) + err = capsys.readouterr().err + assert "pending batch" not in err.lower() diff --git a/tests/test_maintenance_batch.py b/tests/test_maintenance_batch.py new file mode 100644 index 0000000..743e042 --- /dev/null +++ b/tests/test_maintenance_batch.py @@ -0,0 +1,390 @@ +"""Tests for the batched maintenance flow. 
+ +Covers: + +- State file lifecycle (write/read/delete, schema, retention) +- ``submit_maintenance_batch`` happy path + multiple-batch guard +- ``resume_maintenance_batch`` happy path, partial-failure + isolation per feature, and timeout-keeps-state semantics +- ``status_maintenance_batch`` shape +- ``cancel_maintenance_batch`` cleanup +- **Parity**: same stale features → same on-disk template + files, regardless of synchronous vs batch path + +The Anthropic SDK is stubbed at the +``attune_author.doc_gen._anthropic_batch.{submit_batch, +poll_batch}`` boundary so no live calls leave the test process. +""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from attune_author.doc_gen._anthropic_batch import ( + BatchPolishRequest, + BatchPolishResult, +) +from attune_author.maintenance_batch import ( + BatchAlreadyPending, + BatchState, + BatchStateExpired, + BatchStateNotFound, + BatchStateRequest, + MaintenanceResultBatch, + cancel_maintenance_batch, + custom_id_for, + delete_batch_state, + has_pending_batch, + read_batch_state, + resume_maintenance_batch, + state_path, + status_maintenance_batch, + submit_maintenance_batch, + write_batch_state, +) + +# --- helpers --------------------------------------------------------------- + + +def _state(submitted_at: datetime | None = None) -> BatchState: + return BatchState( + schema_version=1, + batch_id="msgbatch_test", + submitted_at=submitted_at or datetime(2026, 5, 8, 18, 35, tzinfo=timezone.utc), + expected_completion_at=datetime(2026, 5, 8, 18, 41, tzinfo=timezone.utc), + model="claude-sonnet-4-20250514", + requests=( + BatchStateRequest("feat__auth__concept", "auth", "concept"), + BatchStateRequest("feat__auth__task", "auth", "task"), + BatchStateRequest("feat__auth__reference", "auth", "reference"), + ), + ) + + +def 
_stub_succeeded_results(features_and_depths: list[tuple[str, str]]) -> list[BatchPolishResult]: + """Build a successful batch result for each (feature, depth).""" + return [ + BatchPolishResult( + custom_id=custom_id_for(f, d), + feature=f, + depth=d, + text=f"# polished {f}/{d}\n\npolished body for {f} at depth {d}.\n", + error=None, + ) + for f, d in features_and_depths + ] + + +# --- state file lifecycle -------------------------------------------------- + + +class TestStateFile: + def test_round_trip_preserves_all_fields(self, tmp_path: Path) -> None: + s = _state() + write_batch_state(tmp_path, s) + loaded = read_batch_state(tmp_path, now=datetime(2026, 5, 9, tzinfo=timezone.utc)) + assert loaded == s + + def test_state_path_uses_dotfile(self, tmp_path: Path) -> None: + assert state_path(tmp_path).name == ".batch-state.json" + + def test_has_pending_batch_false_when_no_state(self, tmp_path: Path) -> None: + assert has_pending_batch(tmp_path) is False + + def test_has_pending_batch_true_after_write(self, tmp_path: Path) -> None: + write_batch_state(tmp_path, _state()) + assert has_pending_batch(tmp_path) is True + + def test_delete_state_no_op_when_missing(self, tmp_path: Path) -> None: + delete_batch_state(tmp_path) # should not raise + + def test_read_raises_not_found_when_missing(self, tmp_path: Path) -> None: + with pytest.raises(BatchStateNotFound): + read_batch_state(tmp_path) + + def test_read_raises_expired_for_old_state(self, tmp_path: Path) -> None: + old = _state(submitted_at=datetime(2026, 1, 1, tzinfo=timezone.utc)) + write_batch_state(tmp_path, old) + with pytest.raises(BatchStateExpired, match="29-day retention"): + read_batch_state(tmp_path, now=datetime(2026, 5, 8, tzinfo=timezone.utc)) + + def test_read_raises_expired_for_unknown_schema(self, tmp_path: Path) -> None: + target = state_path(tmp_path) + target.write_text( + json.dumps( + { + "schema_version": 999, + "batch_id": "x", + "submitted_at": "2026-05-08T18:35:00+00:00", + 
"expected_completion_at": "2026-05-08T18:41:00+00:00", + "model": "m", + "requests": [], + } + ) + ) + with pytest.raises(BatchStateExpired, match="schema_version"): + read_batch_state(tmp_path) + + def test_read_raises_expired_for_corrupt_json(self, tmp_path: Path) -> None: + state_path(tmp_path).write_text("not json {") + with pytest.raises(BatchStateExpired, match="corrupt"): + read_batch_state(tmp_path) + + def test_write_is_atomic_via_tempfile(self, tmp_path: Path) -> None: + write_batch_state(tmp_path, _state()) + # No leftover .tmp file after a successful write. + assert not list(tmp_path.glob("*.tmp")) + assert state_path(tmp_path).exists() + + +# --- submit ---------------------------------------------------------------- + + +class TestSubmit: + def test_raises_when_state_already_exists(self, help_dir: Path, project_root: Path) -> None: + write_batch_state(help_dir, _state()) + with pytest.raises(BatchAlreadyPending): + submit_maintenance_batch(help_dir, project_root) + + def test_force_overwrites_existing_state(self, help_dir: Path, project_root: Path) -> None: + write_batch_state(help_dir, _state()) + with patch("attune_author.maintenance_batch.submit_batch", return_value="msgbatch_new"): + new_state = submit_maintenance_batch( + help_dir, project_root, force=True, client=MagicMock() + ) + assert new_state.batch_id == "msgbatch_new" + + def test_submit_writes_state_file_with_request_manifest( + self, help_dir: Path, project_root: Path + ) -> None: + with patch("attune_author.maintenance_batch.submit_batch", return_value="msgbatch_aaa"): + submitted = submit_maintenance_batch(help_dir, project_root, client=MagicMock()) + assert submitted.batch_id == "msgbatch_aaa" + assert has_pending_batch(help_dir) + # Every request in state has a (feature, depth) we can map back. + assert all(r.custom_id.startswith("feat__") for r in submitted.requests) + # Persisted file matches the in-memory state. 
Read with ``now`` set + # to the submit timestamp so the 29-day retention check doesn't fire + # on the fresh state file we just wrote. + loaded = read_batch_state(help_dir, now=submitted.submitted_at) + assert loaded.batch_id == submitted.batch_id + + def test_submit_passes_polish_requests_to_submit_batch( + self, help_dir: Path, project_root: Path + ) -> None: + captured: list[BatchPolishRequest] = [] + + def _capture(reqs, client=None): # noqa: ARG001 + captured.extend(reqs) + return "msgbatch_captured" + + with patch("attune_author.maintenance_batch.submit_batch", side_effect=_capture): + submit_maintenance_batch(help_dir, project_root, client=MagicMock()) + assert len(captured) > 0 + # Each captured request has a non-empty system + user_message + # (build_polish_prompt produced real content) and a custom_id + # matching feat__{feature}__{depth}. + for r in captured: + assert r.system.strip() + assert r.user_message.strip() + assert r.custom_id == custom_id_for(r.feature, r.depth) + + def test_submit_raises_when_no_stale_features(self, help_dir: Path, project_root: Path) -> None: + # Run normal maintain first so everything is current. 
+ from attune_author.maintenance import run_maintenance + + run_maintenance(help_dir=help_dir, project_root=project_root) + with pytest.raises(ValueError, match="no stale features"): + submit_maintenance_batch(help_dir, project_root, client=MagicMock()) + + +# --- resume ---------------------------------------------------------------- + + +class TestResume: + def _seed_state_for(self, help_dir: Path, project_root: Path) -> BatchState: + with patch( + "attune_author.maintenance_batch.submit_batch", + return_value="msgbatch_resume_test", + ): + return submit_maintenance_batch(help_dir, project_root, client=MagicMock()) + + def test_resume_writes_polished_templates_for_successful_features( + self, help_dir: Path, project_root: Path + ) -> None: + state = self._seed_state_for(help_dir, project_root) + results = _stub_succeeded_results([(r.feature, r.depth) for r in state.requests]) + with patch( + "attune_author.maintenance_batch.poll_batch", + return_value=("ended", results), + ): + outcome = resume_maintenance_batch(help_dir, project_root, client=MagicMock()) + assert isinstance(outcome, MaintenanceResultBatch) + assert outcome.final_status == "ended" + assert outcome.regenerated_count >= 1 + assert outcome.failed == [] + # State file was deleted after successful resume. + assert not has_pending_batch(help_dir) + # On-disk file actually contains the polished body. + regenerated_paths = [t.path for r in outcome.regenerated for t in r.templates] + for path in regenerated_paths: + assert "polished body" in path.read_text(encoding="utf-8") + + def test_resume_isolates_per_feature_failure(self, help_dir: Path, project_root: Path) -> None: + state = self._seed_state_for(help_dir, project_root) + results: list[BatchPolishResult] = [] + for r in state.requests: + if r.feature == "auth" and r.depth == "concept": + # One failed depth fails the whole feature. 
+ results.append( + BatchPolishResult( + custom_id=r.custom_id, + feature=r.feature, + depth=r.depth, + text=None, + error="rate_limit_error", + ) + ) + else: + results.append( + BatchPolishResult( + custom_id=r.custom_id, + feature=r.feature, + depth=r.depth, + text=f"# polished {r.feature}/{r.depth}\n", + error=None, + ) + ) + with patch( + "attune_author.maintenance_batch.poll_batch", + return_value=("ended", results), + ): + outcome = resume_maintenance_batch(help_dir, project_root, client=MagicMock()) + assert "auth" in outcome.failed + # State file deleted on terminal completion (even with failures). + assert not has_pending_batch(help_dir) + + def test_resume_timeout_keeps_state_file(self, help_dir: Path, project_root: Path) -> None: + state = self._seed_state_for(help_dir, project_root) + timed = [ + BatchPolishResult( + custom_id=r.custom_id, + feature=r.feature, + depth=r.depth, + text=None, + error="timed_out", + ) + for r in state.requests + ] + with patch( + "attune_author.maintenance_batch.poll_batch", + return_value=("timed_out", timed), + ): + outcome = resume_maintenance_batch(help_dir, project_root, client=MagicMock()) + assert outcome.final_status == "timed_out" + # State file retained for next --resume invocation. 
+ assert has_pending_batch(help_dir) + + +# --- status & cancel ------------------------------------------------------- + + +class TestStatus: + def test_status_returns_local_state_plus_anthropic_view( + self, help_dir: Path, project_root: Path + ) -> None: + write_batch_state(help_dir, _state()) + client = MagicMock() + client.messages.batches.retrieve.return_value = SimpleNamespace( + processing_status="in_progress", + request_counts=SimpleNamespace( + succeeded=1, errored=0, expired=0, canceled=0, processing=2 + ), + ended_at=None, + expires_at="2026-05-09T18:35:00Z", + cancel_initiated_at=None, + ) + out = status_maintenance_batch(help_dir, client=client) + assert out["batch_id"] == "msgbatch_test" + assert out["processing_status"] == "in_progress" + assert out["request_counts"]["processing"] == 2 + assert out["request_count"] == 3 + # JSON-serializable for cron consumers + json.dumps(out) + + +class TestCancel: + def test_cancel_removes_state_file_and_calls_sdk( + self, help_dir: Path, project_root: Path + ) -> None: + write_batch_state(help_dir, _state()) + client = MagicMock() + cancel_maintenance_batch(help_dir, client=client) + client.messages.batches.cancel.assert_called_once_with("msgbatch_test") + assert not has_pending_batch(help_dir) + + def test_cancel_removes_state_even_when_sdk_fails( + self, help_dir: Path, project_root: Path + ) -> None: + write_batch_state(help_dir, _state()) + client = MagicMock() + client.messages.batches.cancel.side_effect = RuntimeError("network down") + cancel_maintenance_batch(help_dir, client=client) + assert not has_pending_batch(help_dir) + + +# --- parity ---------------------------------------------------------------- + + +class TestParity: + def test_synchronous_and_batch_paths_write_same_file_set( + self, help_dir: Path, project_root: Path + ) -> None: + """Same stale features → same set of template files written. + + Guards against drift between the two paths' rendering. 
The + actual *content* differs because the polish pass is mocked + in the batch path; we assert the set of files and their + ``feature/depth`` shape match. + """ + from attune_author.maintenance import run_maintenance + + # --- synchronous run --- + sync_dir = help_dir.parent / ".help_sync" + sync_dir.mkdir() + (sync_dir / "features.yaml").write_text( + (help_dir / "features.yaml").read_text(), + encoding="utf-8", + ) + sync_outcome = run_maintenance(help_dir=sync_dir, project_root=project_root) + sync_paths = { + t.path.relative_to(sync_dir).as_posix() + for r in sync_outcome.regenerated + for t in r.templates + } + + # --- batch run --- + with patch( + "attune_author.maintenance_batch.submit_batch", + return_value="msgbatch_parity", + ): + state = submit_maintenance_batch(help_dir, project_root, client=MagicMock()) + + results = _stub_succeeded_results([(r.feature, r.depth) for r in state.requests]) + with patch( + "attune_author.maintenance_batch.poll_batch", + return_value=("ended", results), + ): + batch_outcome = resume_maintenance_batch(help_dir, project_root, client=MagicMock()) + batch_paths = { + t.path.relative_to(help_dir).as_posix() + for r in batch_outcome.regenerated + for t in r.templates + } + assert sync_paths == batch_paths