diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1c68258..e59767b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,5 +36,28 @@ jobs: - name: Run ruff run: python -m ruff check src/ tests/ - - name: Run tests - run: python -m pytest tests/ -v --tb=short + - name: Guard — ANTHROPIC_API_KEY must NOT be set for default suite + run: | + if [ -n "${ANTHROPIC_API_KEY:-}" ]; then + echo "::error::ANTHROPIC_API_KEY is set in the default test environment." + echo "This can cause non-'live'-marked tests to leak real API calls." + echo "Either unset the key, or mark the test with @pytest.mark.live." + exit 1 + fi + shell: bash + + - name: Run tests (with coverage on ubuntu x py3.11) + run: | + if [ "${{ matrix.os }}" = "ubuntu-latest" ] && [ "${{ matrix.python-version }}" = "3.11" ]; then + python -m pytest tests/ -v --tb=short --cov --cov-report=term-missing --cov-report=xml + else + python -m pytest tests/ -v --tb=short + fi + shell: bash + + - name: Upload coverage artifact + if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' + uses: actions/upload-artifact@v4 + with: + name: coverage-attune-rag + path: coverage.xml diff --git a/pyproject.toml b/pyproject.toml index 4eabf28..6f99d51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,26 @@ where = ["src"] [tool.pytest.ini_options] testpaths = ["tests"] asyncio_mode = "auto" +addopts = "-ra -m 'not live'" +markers = [ + "live: opt-in tests that hit a real LLM API. Skipped by default; require ANTHROPIC_API_KEY and any test-specific env flags.", +] + +[tool.coverage.run] +source = ["src/attune_rag"] +branch = true +omit = ["*/tests/*", "*/conftest.py"] + +[tool.coverage.report] +show_missing = true +skip_covered = false +fail_under = 77 +exclude_lines = [ + "pragma: no cover", + "raise NotImplementedError", + "if TYPE_CHECKING:", + "if __name__ == .__main__.:", +] [tool.ruff] line-length = 100 diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..b179014 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,68 @@ +# attune-rag tests + +## Running locally + +```bash +# Install dev deps (includes pytest-cov) +pip install -e ".[dev]" + +# Full suite +pytest + +# With coverage (matches CI's ubuntu x py3.11 cell) +pytest --cov --cov-report=term-missing + +# Just unit tests (skip the golden retrieval suite) +pytest tests/unit/ + +# Just golden retrieval (requires [attune-help] extra; long-running) +pytest tests/golden/ +``` + +## LLM mocking standard, `live` marker, CI guard, cost policy + +See **`testing-conventions.md`** in the attune workspace umbrella — +the canonical reference (mocking pattern, `live` marker semantics, CI +guard expectation, cost & quota policy). Applies to all four layers. + +attune-rag itself is LLM-agnostic. The `live` marker is registered in +`pyproject.toml` so any future opt-in tests have a consistent home. + +## Public API contract tests + +`tests/unit/test_contracts.py` pins the public surface of `attune_rag`: + +- Every name in `__all__` must be importable, the right kind (class / + callable / dict), and have a docstring. +- Function signatures for the most-consumed callables (`build_augmented_prompt`, + `RagPipeline.run`) preserve documented kwargs. +- `CitedSource` keeps the consumer-facing `template_path`/`score`/`excerpt`/`category` + fields (attune-gui maps `result.citation.hits` directly into its own + `RagHit` shape). + +Adding a new public export requires updating `EXPECTED_ALL` in this +file — that's deliberate friction. attune-rag is the API contract +source for attune-gui, attune-help, and attune-author. + +## Golden retrieval suite + +`tests/golden/` runs each `queries.yaml` entry through a real +`RagPipeline` and asserts overlap with `expected_in_top_3`. Hard +queries are dynamically `@pytest.mark.xfail(strict=False)` so retriever +upgrades surface as `XPASS` rather than silent regressions. + +Adding a query is a one-line YAML edit in `tests/golden/queries.yaml`. + +## What's tested vs. not + +Tracked in +`/Users/patrickroebuck/attune/specs/test-strategy/current-state.md`. After +pass 1, the highest-value remaining gaps in this layer are: + +- `dashboard/show.py` (172 statements, 0% covered) — Rich CLI module + with no tests; biggest remaining gap by line count. +- `dashboard/refresh.py` (~66%) — refresh paths. +- `eval/bench_prompts.py` (~70%) — prompt builders. + +Pass 2 will revisit; stretch ceiling for this layer is **90%** +(rag is the contract source of truth, gets the highest gate). diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py new file mode 100644 index 0000000..3beee90 --- /dev/null +++ b/tests/unit/test_benchmark.py @@ -0,0 +1,270 @@ +"""Tests for attune_rag.benchmark — CLI exit codes, aggregation, helpers. + +Targets the highest coverage gap identified in the test-strategy audit: +``benchmark.py`` was at 10% line coverage. These tests exercise the pure +helpers (`_percentile`, `_load_queries`) and the CLI happy + error paths +without spending API tokens (`--with-faithfulness` is gated behind a +real ANTHROPIC_API_KEY check we don't satisfy here). +""" + +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch + +import pytest +import yaml + +from attune_rag.benchmark import ( + _default_queries_path, + _load_queries, + _percentile, + _run_benchmark, + main, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +def _write_queries(path: Path, queries: list[dict]) -> Path: + """Write a queries.yaml-shaped file to ``path``.""" + path.write_text(yaml.safe_dump({"queries": queries}), encoding="utf-8") + return path + + +def _hit(template_path: str) -> SimpleNamespace: + """Minimal RagPipeline.run hit shape — only template_path used by benchmark.""" + return SimpleNamespace(template_path=template_path) + + +def _result(*hits: SimpleNamespace) -> SimpleNamespace: + return SimpleNamespace(citation=SimpleNamespace(hits=list(hits))) + + +class _FakeRetriever: + pass + + +class _FakePipeline: + """Stub RagPipeline that returns scripted results per query string.""" + + def __init__(self, scripted: dict[str, list[str]]) -> None: + self._scripted = scripted + self.retriever = _FakeRetriever() + self.corpus = SimpleNamespace(name="fake-corpus") + + def run(self, query: str, k: int = 3) -> SimpleNamespace: + paths = self._scripted.get(query, []) + return _result(*[_hit(p) for p in paths]) + + +# --------------------------------------------------------------------------- +# _default_queries_path +# --------------------------------------------------------------------------- + + +def test_default_queries_path_resolves_inside_repo() -> None: + path = _default_queries_path() + assert path.name == "queries.yaml" + # Path-component check (OS-independent — Windows uses backslashes). + assert path.parent.name == "golden" + assert path.parent.parent.name == "tests" + + +# --------------------------------------------------------------------------- +# _load_queries +# --------------------------------------------------------------------------- + + +def test_load_queries_returns_query_list(tmp_path: Path) -> None: + p = _write_queries( + tmp_path / "q.yaml", + [{"id": "q1", "query": "hello", "expected_in_top_3": ["a.md"]}], + ) + out = _load_queries(p) + assert len(out) == 1 + assert out[0]["id"] == "q1" + + +def test_load_queries_raises_when_no_queries_key(tmp_path: Path) -> None: + p = tmp_path / "empty.yaml" + p.write_text("queries: []\n", encoding="utf-8") + with pytest.raises(ValueError, match="No queries"): + _load_queries(p) + + +def test_load_queries_raises_when_top_level_missing_queries(tmp_path: Path) -> None: + p = tmp_path / "junk.yaml" + p.write_text("not_queries: []\n", encoding="utf-8") + with pytest.raises(ValueError): + _load_queries(p) + + +# --------------------------------------------------------------------------- +# _percentile +# --------------------------------------------------------------------------- + + +def test_percentile_empty_list_returns_zero() -> None: + assert _percentile([], 0.95) == 0.0 + + +@pytest.mark.parametrize( + "values,pct,expected", + [ + ([1.0], 0.5, 1.0), + ([1.0, 2.0, 3.0, 4.0, 5.0], 0.0, 1.0), # min + ([1.0, 2.0, 3.0, 4.0, 5.0], 1.0, 5.0), # max + ([10.0, 20.0, 30.0, 40.0, 50.0], 0.5, 30.0), # median + ], +) +def test_percentile_typical_values(values: list[float], pct: float, expected: float) -> None: + assert _percentile(values, pct) == expected + + +def test_percentile_handles_unsorted_input() -> None: + assert _percentile([5.0, 2.0, 9.0, 1.0, 3.0], 0.0) == 1.0 + + +# --------------------------------------------------------------------------- +# _run_benchmark — aggregation math +# --------------------------------------------------------------------------- + + +def test_run_benchmark_perfect_precision_and_recall() -> None: + queries = [ + {"id": "q1", "query": "auth", "expected_in_top_3": ["concepts/auth.md"]}, + {"id": "q2", "query": "memory", "expected_in_top_3": ["concepts/memory.md"]}, + ] + pipeline = _FakePipeline( + { + "auth": ["concepts/auth.md"], + "memory": ["concepts/memory.md"], + } + ) + with patch("attune_rag.RagPipeline", return_value=pipeline): + report = _run_benchmark(queries, k=3) + assert report["total_queries"] == 2 + assert report["precision_at_1"] == 1.0 + assert report["recall_at_k"] == 1.0 + assert report["k"] == 3 + + +def test_run_benchmark_zero_precision_when_top1_misses() -> None: + queries = [ + {"id": "q1", "query": "auth", "expected_in_top_3": ["concepts/auth.md"]}, + ] + pipeline = _FakePipeline( + {"auth": ["concepts/wrong.md", "concepts/auth.md"]}, + ) + with patch("attune_rag.RagPipeline", return_value=pipeline): + report = _run_benchmark(queries, k=3) + assert report["precision_at_1"] == 0.0 + # But recall@3 still counts since auth.md is in the top-k set + assert report["recall_at_k"] == 1.0 + + +def test_run_benchmark_zero_recall_when_no_match() -> None: + queries = [ + {"id": "q1", "query": "auth", "expected_in_top_3": ["concepts/auth.md"]}, + ] + pipeline = _FakePipeline({"auth": ["concepts/elsewhere.md"]}) + with patch("attune_rag.RagPipeline", return_value=pipeline): + report = _run_benchmark(queries, k=3) + assert report["precision_at_1"] == 0.0 + assert report["recall_at_k"] == 0.0 + + +def test_run_benchmark_records_per_query_detail() -> None: + queries = [ + { + "id": "q1", + "query": "auth", + "expected_in_top_3": ["a.md"], + "difficulty": "easy", + }, + { + "id": "q2", + "query": "memory", + "expected_in_top_3": ["m.md"], + "difficulty": "hard", + }, + ] + pipeline = _FakePipeline({"auth": ["a.md"], "memory": ["wrong.md"]}) + with patch("attune_rag.RagPipeline", return_value=pipeline): + report = _run_benchmark(queries, k=3) + by_id = {q["id"]: q for q in report["per_query"]} + assert by_id["q1"]["top1_match"] is True + assert by_id["q2"]["top1_match"] is False + assert by_id["q1"]["difficulty"] == "easy" + assert by_id["q2"]["difficulty"] == "hard" + + +def test_run_benchmark_empty_queries_yields_zero_metrics() -> None: + """Defensive: total=0 must not divide-by-zero.""" + pipeline = _FakePipeline({}) + with patch("attune_rag.RagPipeline", return_value=pipeline): + report = _run_benchmark([], k=3) + assert report["total_queries"] == 0 + assert report["precision_at_1"] == 0.0 + assert report["recall_at_k"] == 0.0 + assert report["mean_latency_ms"] == 0.0 + + +# --------------------------------------------------------------------------- +# main() — exit codes +# --------------------------------------------------------------------------- + + +def test_main_exits_2_when_queries_file_missing( + tmp_path: Path, capsys: pytest.CaptureFixture[str] +) -> None: + rc = main(["--queries", str(tmp_path / "nope.yaml")]) + assert rc == 2 + assert "Queries file not found" in capsys.readouterr().err + + +def test_main_exits_0_on_perfect_precision(tmp_path: Path) -> None: + p = _write_queries( + tmp_path / "q.yaml", + [{"id": "q1", "query": "auth", "expected_in_top_3": ["a.md"]}], + ) + pipeline = _FakePipeline({"auth": ["a.md"]}) + with patch("attune_rag.RagPipeline", return_value=pipeline): + rc = main(["--queries", str(p), "--min-precision", "0.5"]) + assert rc == 0 + + +def test_main_exits_1_when_precision_below_gate( + tmp_path: Path, capsys: pytest.CaptureFixture[str] +) -> None: + p = _write_queries( + tmp_path / "q.yaml", + [{"id": "q1", "query": "auth", "expected_in_top_3": ["a.md"]}], + ) + pipeline = _FakePipeline({"auth": ["wrong.md"]}) # 0% precision + with patch("attune_rag.RagPipeline", return_value=pipeline): + rc = main(["--queries", str(p), "--min-precision", "0.5"]) + assert rc == 1 + assert "FAIL" in capsys.readouterr().err + + +def test_main_with_faithfulness_requires_api_key( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] +) -> None: + """--with-faithfulness without ANTHROPIC_API_KEY exits 2.""" + p = _write_queries( + tmp_path / "q.yaml", + [{"id": "q1", "query": "auth", "expected_in_top_3": ["a.md"]}], + ) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + pipeline = _FakePipeline({"auth": ["a.md"]}) + with patch("attune_rag.RagPipeline", return_value=pipeline): + rc = main( + ["--queries", str(p), "--min-precision", "0.0", "--with-faithfulness"], + ) + assert rc == 2 + assert "ANTHROPIC_API_KEY" in capsys.readouterr().err diff --git a/tests/unit/test_contracts.py b/tests/unit/test_contracts.py new file mode 100644 index 0000000..5cc700c --- /dev/null +++ b/tests/unit/test_contracts.py @@ -0,0 +1,194 @@ +"""Public-API contract tests for attune_rag. + +attune-rag is the contract source of truth for the attune product +family — attune-gui, attune-help, and attune-author all consume objects +from this package. This file pins the public surface so silent breaking +changes (renames, signature drift, missing docstrings) fail at PR time +instead of surfacing as runtime errors weeks later in consumer layers. + +Each named export in ``attune_rag.__all__`` must: + +- Be importable directly from ``attune_rag`` +- Be the right kind (class vs callable) +- Carry a docstring +- Preserve its declared module identity (catches accidental re-binds) +""" + +from __future__ import annotations + +import inspect + +import pytest + +import attune_rag + +# --------------------------------------------------------------------------- +# __all__ surface +# --------------------------------------------------------------------------- + + +EXPECTED_ALL = { + "RagPipeline", + "RagResult", + "CitationRecord", + "CitedSource", + "ClaimCitation", + "format_citations_markdown", + "format_claim_citations_markdown", + "CorpusProtocol", + "RetrievalEntry", + "DirectoryCorpus", + "KeywordRetriever", + "RetrievalHit", + "RetrieverProtocol", + "build_augmented_prompt", + "PROMPT_VARIANTS", + "QueryExpander", + "LLMReranker", +} + + +def test_all_lists_the_documented_public_surface() -> None: + """``__all__`` is the contract; new exports require an explicit add.""" + assert set(attune_rag.__all__) == EXPECTED_ALL + + +@pytest.mark.parametrize("name", sorted(EXPECTED_ALL)) +def test_export_is_importable_from_top_level(name: str) -> None: + """Each declared export must be reachable as ``attune_rag.``.""" + assert hasattr(attune_rag, name), f"missing public export: {name}" + assert getattr(attune_rag, name) is not None + + +@pytest.mark.parametrize("name", sorted(EXPECTED_ALL)) +def test_export_has_docstring(name: str) -> None: + """Every public symbol needs a docstring — consumer doc generators rely on it.""" + obj = getattr(attune_rag, name) + if isinstance(obj, dict): + # PROMPT_VARIANTS is a module-level dict; no docstring expected. + return + doc = inspect.getdoc(obj) + assert doc, f"{name} is missing a docstring" + assert len(doc.strip()) >= 10, f"{name} docstring is too short: {doc!r}" + + +# --------------------------------------------------------------------------- +# Type / shape contracts +# --------------------------------------------------------------------------- + + +CLASSES = { + "RagPipeline", + "RagResult", + "CitationRecord", + "CitedSource", + "ClaimCitation", + "DirectoryCorpus", + "KeywordRetriever", + "RetrievalHit", + "RetrievalEntry", + "QueryExpander", + "LLMReranker", +} + +PROTOCOLS = {"CorpusProtocol", "RetrieverProtocol"} + +CALLABLES = { + "format_citations_markdown", + "format_claim_citations_markdown", + "build_augmented_prompt", +} + +CONTAINERS = {"PROMPT_VARIANTS"} + + +def test_export_classification_is_complete() -> None: + """Every export falls into exactly one category.""" + classified = CLASSES | PROTOCOLS | CALLABLES | CONTAINERS + assert classified == EXPECTED_ALL, ( + f"unclassified exports: {EXPECTED_ALL - classified}; " + f"unknown classifiers: {classified - EXPECTED_ALL}" + ) + + +@pytest.mark.parametrize("name", sorted(CLASSES | PROTOCOLS)) +def test_class_exports_are_classes(name: str) -> None: + obj = getattr(attune_rag, name) + assert inspect.isclass(obj), f"{name} should be a class" + + +@pytest.mark.parametrize("name", sorted(CALLABLES)) +def test_callable_exports_are_functions(name: str) -> None: + obj = getattr(attune_rag, name) + assert callable(obj), f"{name} should be callable" + # Must have a real signature (not a builtin or C extension stub). + inspect.signature(obj) + + +def test_prompt_variants_is_a_dict_of_strings() -> None: + """PROMPT_VARIANTS schema: dict[str, str] (variant name → template body).""" + pv = attune_rag.PROMPT_VARIANTS + assert isinstance(pv, dict) + assert pv, "PROMPT_VARIANTS must be non-empty" + for k, v in pv.items(): + assert isinstance(k, str) + assert isinstance(v, str) + assert v.strip(), f"PROMPT_VARIANTS[{k!r}] is empty" + + +def test_version_is_pep440_string() -> None: + """``attune_rag.__version__`` is the canonical version string consumers may pin.""" + v = attune_rag.__version__ + assert isinstance(v, str) + parts = v.split(".") + assert len(parts) >= 2 + # Major / minor / patch are digit-only. + int(parts[0]) + int(parts[1]) + + +# --------------------------------------------------------------------------- +# Signature pins for the most-consumed callables +# --------------------------------------------------------------------------- + + +def test_build_augmented_prompt_signature() -> None: + sig = inspect.signature(attune_rag.build_augmented_prompt) + # Must accept query + context-bearing args; downstream consumers + # (attune-gui's RagPipeline.run flow) depend on the keyword names. + params = list(sig.parameters) + assert "query" in params, f"build_augmented_prompt lost ``query`` param: {params}" + + +def test_rag_pipeline_run_signature() -> None: + sig = inspect.signature(attune_rag.RagPipeline.run) + params = list(sig.parameters) + # ``self`` + ``query`` must be present; ``k`` keeps default semantics. + assert "query" in params + assert "k" in params + + +def test_retrieval_hit_has_internal_shape() -> None: + """RetrievalHit is the internal retriever record — keeps its own shape.""" + hit_cls = attune_rag.RetrievalHit + assert hasattr( + hit_cls, "__dataclass_fields__" + ), "RetrievalHit must remain a dataclass for predictable structural use" + fields = {f.name for f in hit_cls.__dataclass_fields__.values()} + # Pinning current internal shape; if these change, downstream + # callers (notably the reranker + benchmark) need adjustments. + for required in ("entry", "score", "match_reason"): + assert required in fields, f"RetrievalHit lost ``{required}`` field" + + +def test_cited_source_exposes_consumer_fields() -> None: + """``CitedSource`` is what attune-gui's ``RagPipeline.run`` flow ultimately + surfaces — the flat ``template_path / score / excerpt / category`` shape + consumers depend on must live on this class.""" + cls = attune_rag.CitedSource + if hasattr(cls, "__dataclass_fields__"): + fields = {f.name for f in cls.__dataclass_fields__.values()} + else: + fields = {a for a in dir(cls) if not a.startswith("_")} + for required in ("template_path", "score", "excerpt", "category"): + assert required in fields, f"CitedSource lost ``{required}`` attribute — consumer break" diff --git a/tests/unit/test_providers_base.py b/tests/unit/test_providers_base.py new file mode 100644 index 0000000..1f78f2f --- /dev/null +++ b/tests/unit/test_providers_base.py @@ -0,0 +1,146 @@ +"""Tests for attune_rag.providers.base — protocol shape + dataclass invariants. + +Pins the LLMProvider protocol and the two dataclasses concrete providers +(claude, gemini) consume. These are part of attune-rag's contract surface +since downstream callers can implement their own provider by conforming +to the protocol. +""" + +from __future__ import annotations + +import asyncio +import inspect +from dataclasses import FrozenInstanceError + +import pytest + +from attune_rag.providers.base import ( + CitationDocument, + CitedResponse, + LLMProvider, +) + +# --------------------------------------------------------------------------- +# CitationDocument +# --------------------------------------------------------------------------- + + +def test_citation_document_is_frozen() -> None: + doc = CitationDocument(title="t.md", text="content") + with pytest.raises(FrozenInstanceError): + doc.title = "other.md" # type: ignore[misc] + + +def test_citation_document_equality_by_field() -> None: + a = CitationDocument(title="t.md", text="x") + b = CitationDocument(title="t.md", text="x") + assert a == b + assert hash(a) == hash(b) + + +def test_citation_document_unequal_when_different_text() -> None: + a = CitationDocument(title="t.md", text="one") + b = CitationDocument(title="t.md", text="two") + assert a != b + + +# --------------------------------------------------------------------------- +# CitedResponse +# --------------------------------------------------------------------------- + + +def test_cited_response_is_frozen() -> None: + r = CitedResponse(text="ans", claim_citations=()) + with pytest.raises(FrozenInstanceError): + r.text = "other" # type: ignore[misc] + + +def test_cited_response_claim_citations_is_tuple() -> None: + """``claim_citations`` is a tuple to keep the dataclass hashable.""" + r = CitedResponse(text="ans", claim_citations=()) + assert isinstance(r.claim_citations, tuple) + + +# --------------------------------------------------------------------------- +# LLMProvider protocol +# --------------------------------------------------------------------------- + + +def test_llm_provider_is_runtime_checkable() -> None: + """Pipeline uses ``isinstance(p, LLMProvider)`` to dispatch — runtime check + must work.""" + + class _Conforming: + name = "fake" + supports_native_citations = False + + async def generate( + self, + prompt: str, + model: str | None = None, + max_tokens: int = 2048, + cached_prefix: str | None = None, + ) -> str: + return "ok" + + async def generate_with_citations(self, *args, **kwargs): # noqa: ANN + raise NotImplementedError + + assert isinstance(_Conforming(), LLMProvider) + + +def test_llm_provider_rejects_missing_attrs() -> None: + """Class missing ``name`` or required methods must NOT pass isinstance.""" + + class _Incomplete: + # Missing name + supports_native_citations + methods. + pass + + assert not isinstance(_Incomplete(), LLMProvider) + + +def test_default_generate_with_citations_raises_not_implemented() -> None: + """Concrete providers that don't override the citations method must + surface a clear NotImplementedError when callers reach for it.""" + + class _NoCitations: + name = "nocite" + supports_native_citations = False + + async def generate( + self, + prompt: str, + model: str | None = None, + max_tokens: int = 2048, + cached_prefix: str | None = None, + ) -> str: + return "ok" + + # Inherit default behavior — call via the protocol method directly. + generate_with_citations = LLMProvider.generate_with_citations + + inst = _NoCitations() + with pytest.raises(NotImplementedError, match="does not support native citations"): + asyncio.run(inst.generate_with_citations(documents=[], query="q")) + + +# --------------------------------------------------------------------------- +# Signature pins — providers depend on these defaults +# --------------------------------------------------------------------------- + + +def test_generate_signature_keeps_documented_kwargs() -> None: + sig = inspect.signature(LLMProvider.generate) + params = sig.parameters + # Pipeline + tests pass these by keyword. + for required in ("prompt", "model", "max_tokens", "cached_prefix"): + assert required in params, f"generate() lost ``{required}`` kwarg" + assert params["max_tokens"].default == 2048 + + +def test_generate_with_citations_signature_keeps_documented_kwargs() -> None: + sig = inspect.signature(LLMProvider.generate_with_citations) + params = sig.parameters + for required in ("documents", "query", "system", "model", "max_tokens"): + assert required in params, f"generate_with_citations() lost ``{required}`` kwarg" + assert params["max_tokens"].default == 2048