27 changes: 25 additions & 2 deletions .github/workflows/tests.yml
@@ -36,5 +36,28 @@ jobs:
       - name: Run ruff
         run: python -m ruff check src/ tests/

-      - name: Run tests
-        run: python -m pytest tests/ -v --tb=short
+      - name: Guard — ANTHROPIC_API_KEY must NOT be set for default suite
+        run: |
+          if [ -n "${ANTHROPIC_API_KEY:-}" ]; then
+            echo "::error::ANTHROPIC_API_KEY is set in the default test environment."
+            echo "This can cause non-'live'-marked tests to leak real API calls."
+            echo "Either unset the key, or mark the test with @pytest.mark.live."
+            exit 1
+          fi
+        shell: bash
+
+      - name: Run tests (with coverage on ubuntu x py3.11)
+        run: |
+          if [ "${{ matrix.os }}" = "ubuntu-latest" ] && [ "${{ matrix.python-version }}" = "3.11" ]; then
+            python -m pytest tests/ -v --tb=short --cov --cov-report=term-missing --cov-report=xml
+          else
+            python -m pytest tests/ -v --tb=short
+          fi
+        shell: bash
+
+      - name: Upload coverage artifact
+        if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11'
+        uses: actions/upload-artifact@v4
+        with:
+          name: coverage-attune-rag
+          path: coverage.xml
20 changes: 20 additions & 0 deletions pyproject.toml
@@ -84,6 +84,26 @@ where = ["src"]
 [tool.pytest.ini_options]
 testpaths = ["tests"]
 asyncio_mode = "auto"
+addopts = "-ra -m 'not live'"
+markers = [
+    "live: opt-in tests that hit a real LLM API. Skipped by default; require ANTHROPIC_API_KEY and any test-specific env flags.",
+]
+
+[tool.coverage.run]
+source = ["src/attune_rag"]
+branch = true
+omit = ["*/tests/*", "*/conftest.py"]
+
+[tool.coverage.report]
+show_missing = true
+skip_covered = false
+fail_under = 77
+exclude_lines = [
+    "pragma: no cover",
+    "raise NotImplementedError",
+    "if TYPE_CHECKING:",
+    "if __name__ == .__main__.:",
+]

 [tool.ruff]
 line-length = 100
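
Review note on `exclude_lines`: coverage.py treats each entry as a regular expression searched against source lines, which is why the `__main__` entry uses `.` where the quote characters would be. The sketch below only illustrates that matching behaviour; `is_excluded` is a hypothetical helper, not part of coverage.py or attune-rag.

```python
import re

# Patterns copied from the [tool.coverage.report] table above.
EXCLUDE_LINES = [
    "pragma: no cover",
    "raise NotImplementedError",
    "if TYPE_CHECKING:",
    "if __name__ == .__main__.:",
]


def is_excluded(source_line: str) -> bool:
    """Hypothetical helper: True when any pattern matches somewhere in the line."""
    return any(re.search(pattern, source_line) for pattern in EXCLUDE_LINES)


# The "." wildcards let one pattern cover both quote styles.
single_quoted = is_excluded("if __name__ == '__main__':")
double_quoted = is_excluded('if __name__ == "__main__":')
ordinary_line = is_excluded("total = compute(values)")
```

Under these assumptions both guard spellings are excluded with one entry, and ordinary statements are left in the report.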
68 changes: 68 additions & 0 deletions tests/README.md
@@ -0,0 +1,68 @@
# attune-rag tests

## Running locally

```bash
# Install dev deps (includes pytest-cov)
pip install -e ".[dev]"

# Full suite (tests marked `live` are deselected by default via addopts)
pytest

# With coverage (matches CI's ubuntu x py3.11 cell)
pytest --cov --cov-report=term-missing

# Just unit tests (skip the golden retrieval suite)
pytest tests/unit/

# Just golden retrieval (requires [attune-help] extra; long-running)
pytest tests/golden/
```

## LLM mocking standard, `live` marker, CI guard, cost policy

See **`testing-conventions.md`** in the attune workspace umbrella —
the canonical reference (mocking pattern, `live` marker semantics, CI
guard expectation, cost & quota policy). Applies to all four layers.

attune-rag itself is LLM-agnostic. The `live` marker is registered in
`pyproject.toml` so any future opt-in tests have a consistent home.
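
The CI guard in `.github/workflows/tests.yml` fails the default job when `ANTHROPIC_API_KEY` is set to a non-empty value. As a rough illustration of the same rule outside CI, the sketch below expresses it as a plain Python check. The name `default_suite_guard` is hypothetical and does not exist in attune-rag; the canonical wording of the policy lives in `testing-conventions.md`.

```python
from __future__ import annotations

from collections.abc import Mapping


def default_suite_guard(environ: Mapping[str, str]) -> tuple[int, str]:
    """Hypothetical helper mirroring the CI guard step.

    Returns (exit_code, message). The exit code is 1 when
    ANTHROPIC_API_KEY is set to a non-empty value and 0 otherwise,
    matching the shell test `[ -n "${ANTHROPIC_API_KEY:-}" ]`.
    """
    if environ.get("ANTHROPIC_API_KEY", ""):
        return 1, (
            "ANTHROPIC_API_KEY is set in the default test environment; "
            "unset it or mark the test with @pytest.mark.live."
        )
    return 0, "ok"
```

Passing `os.environ` gives the same verdict the workflow step would reach for the current shell.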

## Public API contract tests

`tests/unit/test_contracts.py` pins the public surface of `attune_rag`:

- Every name in `__all__` must be importable, the right kind (class /
  callable / dict), and have a docstring.
- Function signatures for the most-consumed callables (`build_augmented_prompt`,
  `RagPipeline.run`) preserve documented kwargs.
- `CitedSource` keeps the consumer-facing `template_path`/`score`/`excerpt`/`category`
  fields (attune-gui maps `result.citation.hits` directly into its own
  `RagHit` shape).

Adding a new public export requires updating `EXPECTED_ALL` in this
file — that's deliberate friction. attune-rag is the API contract
source for attune-gui, attune-help, and attune-author.
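
A minimal sketch of the kind of check described above, using only the standard library. It is not the contents of `test_contracts.py`: the helper names `check_public_surface` and `check_kwargs`, and the stand-in module `demo_pkg`, are hypothetical and exist only so the example runs without attune-rag installed.

```python
from __future__ import annotations

import inspect
import types


def check_public_surface(module: types.ModuleType, expected_all: set[str]) -> list[str]:
    """Return contract violations; an empty list means the surface is intact."""
    problems: list[str] = []
    declared = set(getattr(module, "__all__", []))
    if declared != expected_all:
        problems.append(f"__all__ drifted: {sorted(declared ^ expected_all)}")
    for name in sorted(declared):
        obj = getattr(module, name, None)
        if obj is None:
            problems.append(f"{name}: listed in __all__ but not importable")
        elif callable(obj) and not getattr(obj, "__doc__", None):
            problems.append(f"{name}: missing docstring")
    return problems


def check_kwargs(func, documented: set[str]) -> set[str]:
    """Return documented keyword names that are missing from func's signature."""
    return documented - set(inspect.signature(func).parameters)


# Stand-in module; a real contract test would import attune_rag instead.
demo = types.ModuleType("demo_pkg")


def build_prompt(query: str, k: int = 3) -> str:
    """Build a prompt for ``query`` from the top ``k`` hits."""
    return f"{query}:{k}"


demo.build_prompt = build_prompt
demo.__all__ = ["build_prompt"]
```

Keeping the expected set in the test rather than deriving it from the package is what creates the friction: a new export fails the comparison until someone edits the expectation on purpose.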

## Golden retrieval suite

`tests/golden/` runs each `queries.yaml` entry through a real
`RagPipeline` and asserts overlap with `expected_in_top_3`. Hard
queries are dynamically marked `@pytest.mark.xfail(strict=False)`, so a
retriever upgrade that starts answering them surfaces as `XPASS`
instead of passing unnoticed.

Adding a query is a one-line YAML edit in `tests/golden/queries.yaml`.
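
The overlap rule and the non-strict xfail behaviour can be sketched in plain Python. This is an illustration under two assumptions that the golden suite itself may not share: "overlap" is read as at least one expected path appearing in the top 3, and hard queries are identified by a `difficulty: hard` field. The function names are hypothetical.

```python
from __future__ import annotations


def overlaps_expected(retrieved: list[str], expected: list[str], k: int = 3) -> bool:
    """True when any expected path appears among the first k retrieved paths."""
    return bool(set(retrieved[:k]) & set(expected))


def classify(entry: dict, retrieved: list[str]) -> str:
    """Map one query entry to the outcome label pytest would report.

    Assumes hard queries carry difficulty == "hard" and are treated as
    non-strict expected failures, so a pass is XPASS and a miss is XFAIL.
    """
    passed = overlaps_expected(retrieved, entry["expected_in_top_3"])
    if entry.get("difficulty") == "hard":
        return "XPASS" if passed else "XFAIL"
    return "PASSED" if passed else "FAILED"


easy = {"id": "q1", "query": "auth", "expected_in_top_3": ["concepts/auth.md"]}
hard = {**easy, "id": "q2", "difficulty": "hard"}
```

With `strict=False`, neither `XFAIL` nor `XPASS` fails the run; the `-ra` flag in `addopts` lists both in the short test summary.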

## What's tested vs. not

Tracked in
`/Users/patrickroebuck/attune/specs/test-strategy/current-state.md`. After
pass 1, the highest-value remaining gaps in this layer are:

- `dashboard/show.py` (172 statements, 0% covered) — Rich CLI module
with no tests; biggest remaining gap by line count.
- `dashboard/refresh.py` (~66%) — refresh paths.
- `eval/bench_prompts.py` (~70%) — prompt builders.

Pass 2 will revisit these gaps; the stretch ceiling for this layer is **90%**
(rag is the contract source of truth, so it gets the highest gate).
270 changes: 270 additions & 0 deletions tests/unit/test_benchmark.py
@@ -0,0 +1,270 @@
"""Tests for attune_rag.benchmark — CLI exit codes, aggregation, helpers.

Targets the highest coverage gap identified in the test-strategy audit:
``benchmark.py`` was at 10% line coverage. These tests exercise the pure
helpers (`_percentile`, `_load_queries`) and the CLI happy + error paths
without spending API tokens (`--with-faithfulness` is gated behind a
real ANTHROPIC_API_KEY check we don't satisfy here).
"""

from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch

import pytest
import yaml

from attune_rag.benchmark import (
    _default_queries_path,
    _load_queries,
    _percentile,
    _run_benchmark,
    main,
)

# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------


def _write_queries(path: Path, queries: list[dict]) -> Path:
    """Write a queries.yaml-shaped file to ``path``."""
    path.write_text(yaml.safe_dump({"queries": queries}), encoding="utf-8")
    return path


def _hit(template_path: str) -> SimpleNamespace:
    """Minimal RagPipeline.run hit shape — only template_path used by benchmark."""
    return SimpleNamespace(template_path=template_path)


def _result(*hits: SimpleNamespace) -> SimpleNamespace:
    return SimpleNamespace(citation=SimpleNamespace(hits=list(hits)))


class _FakeRetriever:
    pass


class _FakePipeline:
    """Stub RagPipeline that returns scripted results per query string."""

    def __init__(self, scripted: dict[str, list[str]]) -> None:
        self._scripted = scripted
        self.retriever = _FakeRetriever()
        self.corpus = SimpleNamespace(name="fake-corpus")

    def run(self, query: str, k: int = 3) -> SimpleNamespace:
        paths = self._scripted.get(query, [])
        return _result(*[_hit(p) for p in paths])


# ---------------------------------------------------------------------------
# _default_queries_path
# ---------------------------------------------------------------------------


def test_default_queries_path_resolves_inside_repo() -> None:
    path = _default_queries_path()
    assert path.name == "queries.yaml"
    # Path-component check (OS-independent — Windows uses backslashes).
    assert path.parent.name == "golden"
    assert path.parent.parent.name == "tests"


# ---------------------------------------------------------------------------
# _load_queries
# ---------------------------------------------------------------------------


def test_load_queries_returns_query_list(tmp_path: Path) -> None:
    p = _write_queries(
        tmp_path / "q.yaml",
        [{"id": "q1", "query": "hello", "expected_in_top_3": ["a.md"]}],
    )
    out = _load_queries(p)
    assert len(out) == 1
    assert out[0]["id"] == "q1"


def test_load_queries_raises_when_queries_list_empty(tmp_path: Path) -> None:
    # The key is present but the list is empty; the missing-key case is below.
    p = tmp_path / "empty.yaml"
    p.write_text("queries: []\n", encoding="utf-8")
    with pytest.raises(ValueError, match="No queries"):
        _load_queries(p)


def test_load_queries_raises_when_top_level_missing_queries(tmp_path: Path) -> None:
    p = tmp_path / "junk.yaml"
    p.write_text("not_queries: []\n", encoding="utf-8")
    with pytest.raises(ValueError):
        _load_queries(p)


# ---------------------------------------------------------------------------
# _percentile
# ---------------------------------------------------------------------------


def test_percentile_empty_list_returns_zero() -> None:
    assert _percentile([], 0.95) == 0.0


@pytest.mark.parametrize(
    "values,pct,expected",
    [
        ([1.0], 0.5, 1.0),
        ([1.0, 2.0, 3.0, 4.0, 5.0], 0.0, 1.0),  # min
        ([1.0, 2.0, 3.0, 4.0, 5.0], 1.0, 5.0),  # max
        ([10.0, 20.0, 30.0, 40.0, 50.0], 0.5, 30.0),  # median
    ],
)
def test_percentile_typical_values(values: list[float], pct: float, expected: float) -> None:
    assert _percentile(values, pct) == expected


def test_percentile_handles_unsorted_input() -> None:
    assert _percentile([5.0, 2.0, 9.0, 1.0, 3.0], 0.0) == 1.0


# ---------------------------------------------------------------------------
# _run_benchmark — aggregation math
# ---------------------------------------------------------------------------


def test_run_benchmark_perfect_precision_and_recall() -> None:
    queries = [
        {"id": "q1", "query": "auth", "expected_in_top_3": ["concepts/auth.md"]},
        {"id": "q2", "query": "memory", "expected_in_top_3": ["concepts/memory.md"]},
    ]
    pipeline = _FakePipeline(
        {
            "auth": ["concepts/auth.md"],
            "memory": ["concepts/memory.md"],
        }
    )
    with patch("attune_rag.RagPipeline", return_value=pipeline):
        report = _run_benchmark(queries, k=3)
    assert report["total_queries"] == 2
    assert report["precision_at_1"] == 1.0
    assert report["recall_at_k"] == 1.0
    assert report["k"] == 3


def test_run_benchmark_zero_precision_when_top1_misses() -> None:
    queries = [
        {"id": "q1", "query": "auth", "expected_in_top_3": ["concepts/auth.md"]},
    ]
    pipeline = _FakePipeline(
        {"auth": ["concepts/wrong.md", "concepts/auth.md"]},
    )
    with patch("attune_rag.RagPipeline", return_value=pipeline):
        report = _run_benchmark(queries, k=3)
    assert report["precision_at_1"] == 0.0
    # But recall@3 still counts since auth.md is in the top-k set
    assert report["recall_at_k"] == 1.0


def test_run_benchmark_zero_recall_when_no_match() -> None:
    queries = [
        {"id": "q1", "query": "auth", "expected_in_top_3": ["concepts/auth.md"]},
    ]
    pipeline = _FakePipeline({"auth": ["concepts/elsewhere.md"]})
    with patch("attune_rag.RagPipeline", return_value=pipeline):
        report = _run_benchmark(queries, k=3)
    assert report["precision_at_1"] == 0.0
    assert report["recall_at_k"] == 0.0


def test_run_benchmark_records_per_query_detail() -> None:
    queries = [
        {
            "id": "q1",
            "query": "auth",
            "expected_in_top_3": ["a.md"],
            "difficulty": "easy",
        },
        {
            "id": "q2",
            "query": "memory",
            "expected_in_top_3": ["m.md"],
            "difficulty": "hard",
        },
    ]
    pipeline = _FakePipeline({"auth": ["a.md"], "memory": ["wrong.md"]})
    with patch("attune_rag.RagPipeline", return_value=pipeline):
        report = _run_benchmark(queries, k=3)
    by_id = {q["id"]: q for q in report["per_query"]}
    assert by_id["q1"]["top1_match"] is True
    assert by_id["q2"]["top1_match"] is False
    assert by_id["q1"]["difficulty"] == "easy"
    assert by_id["q2"]["difficulty"] == "hard"


def test_run_benchmark_empty_queries_yields_zero_metrics() -> None:
    """Defensive: total=0 must not divide-by-zero."""
    pipeline = _FakePipeline({})
    with patch("attune_rag.RagPipeline", return_value=pipeline):
        report = _run_benchmark([], k=3)
    assert report["total_queries"] == 0
    assert report["precision_at_1"] == 0.0
    assert report["recall_at_k"] == 0.0
    assert report["mean_latency_ms"] == 0.0


# ---------------------------------------------------------------------------
# main() — exit codes
# ---------------------------------------------------------------------------


def test_main_exits_2_when_queries_file_missing(
    tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
    rc = main(["--queries", str(tmp_path / "nope.yaml")])
    assert rc == 2
    assert "Queries file not found" in capsys.readouterr().err


def test_main_exits_0_on_perfect_precision(tmp_path: Path) -> None:
    p = _write_queries(
        tmp_path / "q.yaml",
        [{"id": "q1", "query": "auth", "expected_in_top_3": ["a.md"]}],
    )
    pipeline = _FakePipeline({"auth": ["a.md"]})
    with patch("attune_rag.RagPipeline", return_value=pipeline):
        rc = main(["--queries", str(p), "--min-precision", "0.5"])
    assert rc == 0


def test_main_exits_1_when_precision_below_gate(
    tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
    p = _write_queries(
        tmp_path / "q.yaml",
        [{"id": "q1", "query": "auth", "expected_in_top_3": ["a.md"]}],
    )
    pipeline = _FakePipeline({"auth": ["wrong.md"]})  # 0% precision
    with patch("attune_rag.RagPipeline", return_value=pipeline):
        rc = main(["--queries", str(p), "--min-precision", "0.5"])
    assert rc == 1
    assert "FAIL" in capsys.readouterr().err


def test_main_with_faithfulness_requires_api_key(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
    """--with-faithfulness without ANTHROPIC_API_KEY exits 2."""
    p = _write_queries(
        tmp_path / "q.yaml",
        [{"id": "q1", "query": "auth", "expected_in_top_3": ["a.md"]}],
    )
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    pipeline = _FakePipeline({"auth": ["a.md"]})
    with patch("attune_rag.RagPipeline", return_value=pipeline):
        rc = main(
            ["--queries", str(p), "--min-precision", "0.0", "--with-faithfulness"],
        )
    assert rc == 2
    assert "ANTHROPIC_API_KEY" in capsys.readouterr().err
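
Review note: the diff does not include `benchmark.py`, so for readers following the assertions, here is one way the helpers could behave and still satisfy every case in this file. It is a sketch under assumptions, not the module's code: `percentile` uses a rounded nearest-index rule, `aggregate` counts a query as recalled when any expected path is in the top `k`, and both names are hypothetical.

```python
from __future__ import annotations


def percentile(values: list[float], pct: float) -> float:
    """Pick the value at the rounded index pct * (n - 1); 0.0 for empty input."""
    if not values:
        return 0.0
    ordered = sorted(values)  # callers may pass unsorted latencies
    index = round(pct * (len(ordered) - 1))
    index = max(0, min(index, len(ordered) - 1))  # clamp out-of-range pct
    return ordered[index]


def aggregate(queries: list[dict], retrieved: dict[str, list[str]], k: int = 3) -> dict:
    """Compute precision@1 and recall@k over scripted retrieval results."""
    total = len(queries)
    top1_hits = 0
    recalled = 0
    for entry in queries:
        expected = set(entry["expected_in_top_3"])
        paths = retrieved.get(entry["query"], [])[:k]
        if paths and paths[0] in expected:
            top1_hits += 1
        if expected & set(paths):
            recalled += 1
    return {
        "total_queries": total,
        # Guard the division so an empty query list reports zeros.
        "precision_at_1": top1_hits / total if total else 0.0,
        "recall_at_k": recalled / total if total else 0.0,
        "k": k,
    }
```

The top-1 miss case shows why the two metrics are reported separately: a correct document ranked second leaves precision@1 at 0.0 while recall@3 stays at 1.0.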