diff --git a/.env.example b/.env.example index 0545e3b1..df1b1465 100644 --- a/.env.example +++ b/.env.example @@ -97,7 +97,7 @@ AGENT_SESSION_TTL_MINUTES=120 AGENT_MAX_SESSIONS_PER_USER=5 # Human-in-the-loop actions (JSON array format required for safe parsing) -AGENT_REQUIRE_APPROVAL=["create_alias","archive_run"] +AGENT_REQUIRE_APPROVAL=["create_alias","archive_run","save_scenario"] AGENT_APPROVAL_TIMEOUT_MINUTES=60 # Streaming diff --git a/.gitignore b/.gitignore index d9b12311..bdb8bf25 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ infra/terraform.tfvars .terraform .terraform* .langgraph_api +.mcp.json # Virtual environments .venv diff --git a/CLAUDE.md b/CLAUDE.md index 44557d80..cf8adbd8 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -21,8 +21,8 @@ Claude pulls these in on demand — load only what the task touches. - Pipeline contract (CI/CD): @docs/_base/PIPELINE_CONTRACT.md > Project rules are enforced via `.claude/rules/` (commit-format, branch-naming, -> security-patterns, product-vision, test-requirements, ui-design, versioning, -> output-formatting). Read those first — they are authoritative on detail. +> security-patterns, product-vision, test-requirements, ui-design, shadcn-ui, +> versioning, output-formatting). Read those first — they are authoritative on detail. ## Safety diff --git a/PRPs/PRP-23-rag-corpus-manager.md b/PRPs/PRP-23-rag-corpus-manager.md new file mode 100644 index 00000000..4eeca2e2 --- /dev/null +++ b/PRPs/PRP-23-rag-corpus-manager.md @@ -0,0 +1,869 @@ +name: "PRP-23 — RAG Corpus Manager: one-click bulk-index of bundled project docs" +description: | + Promote the MVP of `docs/optional-features/01-rag-corpus-manager.md` into code. + + A fresh ForecastLabAI install has an **empty RAG corpus** (`0 sources / 0 chunks`), + so the RAG Assistant agent can cite nothing and the Knowledge page is a permanent + empty state — despite the repo bundling ~115 markdown files under `docs/`, `PRPs/`, + and the root. 
+ + This PRP adds **one new orchestration endpoint** — `POST /rag/index/project-docs` — + that discovers the bundled markdown and indexes each file through the existing + `RAGService.index_document` path (reusing its chunking, embedding, SHA-256 + content-hash idempotency, and upsert). The Admin → "RAG Sources" tab gets an + **"Index Project Docs"** button that calls it and toasts the summary. + + Everything else the feature doc lists is **already done** or **out of scope**: + source listing / deletion / provider-health / Knowledge empty-state all exist; + stale-detection, re-index, chunk-preview, and the Knowledge source-type filter are + explicitly deferred to a follow-up ("Full Version" — see Anti-Patterns / NOTES). + +> **PRP numbering:** `PRP-16` is reserved (Phase-2 LightGBM). `PRP-17`–`PRP-22` are +> used. This is `PRP-23`. Source plan: `.agents/plans/rag-index-project-docs.md`. + +## Purpose + +Close the "the RAG corpus starts empty and there is no operator-facing way to fill +it" gap. Today the only ways to populate the corpus are (a) `POST /rag/index` once +per file (~115 calls, requires pasting each path) or (b) the seeder's synthetic +3-document scenario, which indexes throwaway test prose rather than the real project +documentation. An operator or demo reviewer needs **one click** that turns +`0 sources` into a populated, citable corpus drawn from the repo's own docs. + +## Core Principles + +1. **Context is King** — every endpoint shape, schema field, service method, hook + name, and pattern below is linked to a real source file with verified line + numbers. +2. **Reuse, don't reinvent** — `index_project_docs` is a thin orchestrator over the + existing `RAGService.index_document`; it does NOT re-implement hashing, chunking, + embedding, or upsert. The route mirrors the existing `index_document` route's + exception handling; the hook mirrors `useIndexDocument`. +3. 
**Additive only** — NO Alembic migration (no schema change — `category` rides in + the existing `DocumentSource.metadata_` JSONB), NO new slice, NO `.env` var, NO + `app/main.py` change (the `rag` router is already wired). +4. **Strict gates honored** — `.py` files in the `rag` slice change, so the repo-wide + `ruff` / `mypy --strict` / `pyright --strict` / `pytest` CI jobs genuinely apply; + the new endpoint ships with unit + integration tests. +5. **UI through the running app** — the Admin button is verified in a real browser + via `webapp-testing` per `.claude/rules/ui-design.md`. A green `tsc` is NOT proof + the UI works. + +--- + +## Goal + +**Backend (additive, no migration, no `main.py` change):** + +- `POST /rag/index/project-docs` — discovers markdown under `docs/**`, `PRPs/**`, + and a fixed root-file allow-list (`README.md`, `AGENTS.md`, `CHANGELOG.md`), + indexes each through `RAGService.index_document`, and returns a per-file + + aggregate summary. Request body is three optional booleans (`include_docs`, + `include_prps`, `include_root`, all default `true`). Idempotent — re-runs return + every file `unchanged` via the existing SHA-256 short-circuit. A single + unreadable / non-UTF-8 file is reported `failed` without aborting the batch; + `EmbeddingError` / `SQLAlchemyError` are batch-fatal and surface as `502` / + `application/problem+json`. + +**Frontend:** + +- New TanStack mutation hook `useIndexProjectDocs` in `use-rag-sources.ts`. +- Three new TS types (`IndexProjectDocsRequest`, `ProjectDocResult`, + `IndexProjectDocsResponse`) in `types/api.ts`. +- An **"Index Project Docs"** button in the Admin → "RAG Sources" tab + (`RagSourcesPanel`) — spinner while running, a `toast` summary on completion, + and a `['rag-sources']` query invalidation so the list + counts refresh. + +## Why + +- **Portfolio identity.** `.claude/rules/product-vision.md` principle 1 — + "portfolio-grade, end-to-end … every phase ships working code". 
The RAG slice + exists end-to-end but is invisible: a reviewer opening a fresh system sees an + empty Knowledge page and an agent that can cite nothing. This makes the existing + RAG investment demonstrable. +- **Demo narrative.** `docs/optional-features/README.md` § "Promotion Criteria" — + a feature should "improve the demo narrative without breaking the local-first + setup". Bulk-indexing the repo's own docs is the most direct way to show the RAG + Assistant working off real evidence. +- **Operator workflow.** The feature doc's user value: "Demo reviewers can index + project docs without CLI setup"; "the Knowledge page becomes a real corpus + browser instead of mostly an empty-state page". + +## What + +A logged-in operator opens **Admin → RAG Sources**, sees `0 sources • 0 chunks`, +clicks **Index Project Docs**, watches a spinner for up to ~1–3 minutes (first run, +real embedding provider), then sees a toast — e.g. *"Indexed 112, updated 0, +unchanged 0, 3 failed — 1 480 chunks"* — and the source list populates. Opening +**Knowledge** now shows the corpus and semantic search returns cited chunks. +Clicking **Index Project Docs** again completes near-instantly with every file +`unchanged`. + +### Success Criteria + +- [ ] `POST /rag/index/project-docs` indexes `docs/**/*.md`, `PRPs/**/*.md`, and the + root allow-list; returns `IndexProjectDocsResponse` with per-file results + + aggregate counts. +- [ ] Idempotent — a second call with unchanged files returns every result + `unchanged` and creates no new chunks. +- [ ] `include_docs` / `include_prps` / `include_root` toggles select roots + independently; an empty `{}` body indexes all three. +- [ ] A single unreadable / non-UTF-8 file is reported `status="failed"` with an + `error` string and does not abort the batch. +- [ ] `EmbeddingError` → `502`, `SQLAlchemyError` → `DatabaseError` / + `application/problem+json` (no partial commit — the request rolls back). 
+- [ ] Admin → "RAG Sources" has a working "Index Project Docs" button: spinner, + toast summary (`toast.warning` when `failed > 0`, else `toast.success`), and + a live source-list refresh. +- [ ] All validation gates (ruff, mypy --strict, pyright --strict, pytest unit + + integration, frontend tsc/lint/test) pass; integration tests leave no + `test-` rows in `document_source`. +- [ ] `docs/_base/API_CONTRACTS.md` lists the new endpoint. +- [ ] No regression in existing RAG tests or `app/core/tests/test_strict_mode_policy.py`. + +## All Needed Context + +### Documentation & References + +```yaml +# ---- External docs ---- +- url: https://docs.python.org/3/library/pathlib.html#pathlib.Path.rglob + why: Path.rglob("*.md") for recursive discovery. CRITICAL — rglob on a + NON-EXISTENT directory yields nothing (no exception); relied on so an + absent docs/ or PRPs/ root simply contributes 0 files. +- url: https://fastapi.tiangolo.com/tutorial/body/ + why: A Pydantic model as a request body whose fields ALL have defaults + validates an empty `{}` payload — the frontend always posts `{}`. +- url: https://docs.pydantic.dev/latest/concepts/models/#extra-fields + why: ConfigDict(extra="forbid") on the request → an unknown body field 422s. + Mirrors the existing IndexRequest. + +# ---- Source feature spec ---- +- file: docs/optional-features/01-rag-corpus-manager.md + why: The spec. Implement ONLY the "MVP Scope" section. "Full Version" (stale + detection, re-index, chunk preview, Knowledge filters) is OUT OF SCOPE. +- file: .agents/plans/rag-index-project-docs.md + why: The source implementation plan this PRP refines (notably: the unit test + now targets a new pure _discover_project_doc_files helper, not a mocked + index_document — see "Resolved Decisions"). 
+ +# ---- Backend: the rag slice (all changes land here) ---- +- file: app/features/rag/routes.py + why: lines 61-133 — `index_document` route: the EXACT exception-handling + shape to mirror (EmbeddingError→502, SQLAlchemyError→DatabaseError) and + the structured-logging style. Lines 12-19 — the schema import block to + extend. Lines 1-24 — router, `logger`, `RAGService` imports. +- file: app/features/rag/service.py + why: lines 130-251 — `index_document`, the method `index_project_docs` + orchestrates per file. lines 159-163 — the `if request.content:` branch + (see the empty-file GOTCHA). lines 173-191 — the SHA-256 idempotency + short-circuit. lines 61-81 — `__init__` + the `base_dir` test override. + lines 94-128 — `_read_content_from_path` (path-traversal pattern). + lines 29-38 — the schema import block to extend. +- file: app/features/rag/schemas.py + why: lines 17-43 `IndexRequest`, 46-65 `IndexResponse` — the schema style to + mirror: `ConfigDict(extra="forbid")` on the request, `Literal` status + field, `Args:` docstrings. NOTE: `IndexRequest` is NOT `strict=True`, so + the new request model needs no `Field(strict=False)` overrides and + `app/core/tests/test_strict_mode_policy.py` is unaffected. +- file: app/features/rag/models.py + why: lines 35-66 `DocumentSource` — confirms `source_type` is free-form + `String(50)` (we keep `"markdown"`), `metadata_` is JSONB (we store + `{"category": ...}`), and `uq_source_type_path` drives idempotency. +- file: app/features/rag/chunkers.py + why: `get_chunker("markdown")` → `MarkdownChunker`. Confirms `"markdown"` is a + valid `source_type` for every project doc. +- file: app/core/exceptions.py + why: `DatabaseError` — re-raised on `SQLAlchemyError`; already imported in + `routes.py:9`. 
+ +# ---- Backend: tests ---- +- file: app/features/rag/tests/conftest.py + why: `db_session` + `client` integration fixtures, `mock_embedding_service` + unit fixture, and the cleanup at LINE 46 + (`DocumentSource.source_path.like("test-%")`) — this PRP widens it to + `"%test-%"` so nested fixture paths (`docs/test-*.md`) are cleaned up. +- file: app/features/rag/tests/test_routes.py + why: lines 22-37 `create_mock_embedding_service()` and the + `patch("app.features.rag.service.get_embedding_service", ...)` pattern; + `TestIndexEndpoint` (45-167) class layout to mirror. +- file: app/features/rag/tests/test_service.py + why: `TestRAGServiceUnit` — pure-unit class layout (`RAGService()` with no DB, + no mocks); the home for the new `_discover_project_doc_files` unit test. + +# ---- Frontend ---- +- file: frontend/src/hooks/use-rag-sources.ts + why: lines 29-41 `useIndexDocument` — the EXACT mutation-hook shape + (`useMutation` + `api(...)` + `invalidateQueries(['rag-sources'])`). +- file: frontend/src/pages/admin.tsx + why: lines 116-253 `RagSourcesPanel` — where the button goes; the `CardHeader` + actions area (148-205); the lucide import block (4-21 — `Library` must be + ADDED); `toast` already imported (line 68); the `handleGenerate` toast + pattern (470-488); the `Loader2` spinner-in-button pattern (line 199). +- file: frontend/src/types/api.ts + why: lines 258-313 — the `// === RAG ===` block to extend; `RagSource`, + `IndexDocumentResponse`, `RetrieveResponse` naming convention. +- file: frontend/src/lib/api.ts + why: lines 23-44 — `api(endpoint, {method, body})`; a truthy `{}` body is + JSON-stringified to `"{}"`. + +# ---- Rules ---- +- file: .claude/rules/security-patterns.md + why: § "File operations" — `pathlib.Path.resolve()`, allow-listed roots, no + `..`. Discovery globs only fixed roots under `base_dir` (no user input) → + inherently allow-listed; keep it that way. 
+- file: .claude/rules/test-requirements.md + why: new endpoint ⇒ route test with 2xx happy path + ≥1 error path. +- file: .claude/rules/commit-format.md + why: commit `type(scope): description (#issue)`; `rag,ui` comma-pair scope is + allowed; every commit references an open issue; NO AI co-author trailer. +``` + +### Current Codebase tree (relevant) + +``` +app/features/rag/ +├── __init__.py +├── chunkers.py # MarkdownChunker / OpenAPIChunker — UNCHANGED +├── embeddings.py # OpenAI / Ollama providers — UNCHANGED +├── models.py # DocumentSource / DocumentChunk — UNCHANGED (no migration) +├── routes.py # /rag/index, /retrieve, /sources — ADD one route +├── schemas.py # IndexRequest, …, DeleteResponse — ADD three models +├── service.py # RAGService — ADD _discover_project_doc_files + index_project_docs +└── tests/ + ├── conftest.py # MODIFY line 46 cleanup glob + ├── test_chunkers.py # UNCHANGED + ├── test_embeddings.py # UNCHANGED + ├── test_routes.py # ADD TestIndexProjectDocsEndpoint + ├── test_schemas.py # ADD new-schema cases + └── test_service.py # ADD _discover_project_doc_files unit test + +frontend/src/ +├── hooks/use-rag-sources.ts # ADD useIndexProjectDocs +├── pages/admin.tsx # ADD button in RagSourcesPanel +└── types/api.ts # ADD 3 interfaces + +docs/_base/API_CONTRACTS.md # ADD one table row +``` + +### Desired Codebase tree (files added / changed) + +No new files. 
Eleven existing files are modified: + +``` +MODIFY app/features/rag/schemas.py + IndexProjectDocsRequest / ProjectDocResult / IndexProjectDocsResponse +MODIFY app/features/rag/service.py + _discover_project_doc_files() + index_project_docs() + 1 module constant (_PROJECT_ROOT_FILES) +MODIFY app/features/rag/routes.py + POST /rag/index/project-docs route +MODIFY app/features/rag/tests/conftest.py ~ cleanup glob "test-%" -> "%test-%" +MODIFY app/features/rag/tests/test_schemas.py + new-schema validation cases +MODIFY app/features/rag/tests/test_service.py + _discover_project_doc_files unit test +MODIFY app/features/rag/tests/test_routes.py + TestIndexProjectDocsEndpoint (integration) +MODIFY frontend/src/types/api.ts + 3 interfaces +MODIFY frontend/src/hooks/use-rag-sources.ts + useIndexProjectDocs hook +MODIFY frontend/src/pages/admin.tsx + "Index Project Docs" button (+ Library icon import) +MODIFY docs/_base/API_CONTRACTS.md + endpoint-table row +``` + +### Known Gotchas & Library Quirks + +```python +# CRITICAL: CRLF line endings. Every existing app/**/*.py file in this repo is +# CRLF-terminated (no .gitattributes — project memory). A FULL-FILE rewrite +# (the Write tool, or a text-mode dump) silently flips them to LF and produces +# a whole-file diff. Use the Edit tool (exact string replacement — it preserves +# the surrounding line endings) for every .py change. After EACH edit run +# `git diff --stat`: the changed-line count must be small. If you see a +# whole-file churn, the EOLs flipped — restore CRLF before continuing. New +# files: none here. Frontend .ts/.tsx files are LF — safe. + +# CRITICAL: NO app/main.py change. The `rag` router is already wired +# (main.py:27 import, main.py:142 include_router). The new route attaches to +# the existing `router = APIRouter(prefix="/rag", ...)` in routes.py. + +# CRITICAL: NO Alembic migration. DocumentSource / DocumentChunk are unchanged. 
+# The per-source `category` ("docs" | "prp" | "root") rides inside the EXISTING +# `DocumentSource.metadata_` JSONB column. `.claude/rules` require a migration +# only when the SCHEMA changes — adding one here would be wrong. + +# CRITICAL: index_document's content branch (service.py:160) is `if request.content:` +# — an EMPTY string is FALSY, so an empty .md file passed as content="" falls +# through to `_read_content_from_path(rel)`, which resolves the relative path +# against CWD. In production CWD == base_dir (uvicorn runs from the repo root), +# so the redundant re-read succeeds and the file indexes to 0 chunks. In a +# base_dir-OVERRIDE test (CWD != base_dir) it raises FileNotFoundError — which +# is a subclass of OSError and is therefore caught by the per-file +# `except (OSError, ValueError)` and reported `status="failed"`. NEVER fatal. +# Mitigation: make every test fixture file NON-EMPTY (`"# Test\n\nContent."`). +# Do NOT "fix" index_document — it is shared with POST /rag/index. + +# CRITICAL: pass BOTH source_path (the clean RELATIVE posix path — the DB id) AND +# content (the file text) to IndexRequest. source_path drives the +# `(source_type, source_path)` idempotency lookup + is stored; content (when +# truthy) is hashed/chunked. NEVER store an absolute path — it is +# machine-specific and breaks idempotency across machines/CI. + +# CRITICAL: route-test base_dir injection. The route does `RAGService()` with no +# args (→ base_dir = Path.cwd()). To point an integration test at a tmp_path, +# patch the class symbol in the routes module: +# patch("app.features.rag.routes.RAGService", +# functools.partial(RAGService, base_dir=str(tmp_path))) +# `partial(Cls, kw=v)()` constructs `Cls(kw=v)`. Patch +# `app.features.rag.service.get_embedding_service` SEPARATELY (the existing +# test pattern) so __init__ picks up the mock provider. + +# CRITICAL: integration-test cleanup. conftest.py:46 deletes +# `source_path LIKE 'test-%'`. 
Project-doc source paths are NESTED +# (`docs/test-proj-1.md`) and do NOT start with `test-`. Widen the glob to +# `"%test-%"` (Task 1). Real corpus paths (`docs/ARCHITECTURE.md`, `PRPs/PRP-1-…`) +# never contain `test-`, so the wider LIKE is safe; existing `test-`-prefixed +# fixtures still match. Name every new fixture file with a `test-` token. + +# GOTCHA: synchronous by design. Indexing runs in-request. ~115 bundled markdown +# files ⇒ the first run with a real embedding provider takes ~1-3 min (one +# batched embedding call per file). `fetch` has no default timeout and the +# TanStack mutation waits, so this is acceptable for an admin action; re-runs +# are fast (all `unchanged`). The jobs-layer upgrade is the deferred +# "Full Version" — OUT OF SCOPE here. + +# GOTCHA: status-Literal widening. IndexResponse.status is +# Literal["indexed","updated","unchanged"]; ProjectDocResult.status is +# Literal["indexed","updated","unchanged","failed"]. Assigning the narrower +# into the wider is fine for mypy/pyright (subtype). "failed" is only ever set +# in the per-file except branch. + +# GOTCHA: EmbeddingError is NOT an OSError/ValueError (it extends Exception), so +# it is NOT caught by the per-file `except (OSError, ValueError)` — it +# propagates out of the loop, the request rolls back, and the route maps it to +# 502. Same for SQLAlchemyError. This is intentional: a dead embedding provider +# makes the whole batch pointless. + +# GOTCHA: `RAGService()` is safe to construct in a pure unit test with no mocks — +# __init__ only builds the (lazy) embedding client + a tiktoken encoder, no +# network. test_service.py::TestRAGServiceUnit already relies on this. + +# GOTCHA: admin.tsx does NOT currently import `Library` from lucide-react. Add it +# to the existing import block (admin.tsx:4-21). `toast` IS already imported +# (admin.tsx:68). 
+``` + +### Resolved Decisions (carried from `.agents/plans/rag-index-project-docs.md`) + +- **Scope = MVP only.** Out of scope: `GET /rag/sources/{id}/chunks` + chunk + preview, `POST /rag/sources/{id}/reindex` + stale detection, the Knowledge-page + source-type filter, per-source embedding-metadata columns (would need a + migration). Keeping the PR small matches the maintainer preference in + `CLAUDE.local.md` ("prefer a smaller PR over a bundled one"). +- **Root file allow-list = `("README.md", "AGENTS.md", "CHANGELOG.md")`.** `CLAUDE.md` + is excluded — it is mostly an operating index and `@import`s `AGENTS.md` (whose + substance is already indexed). +- **`source_type` stays `"markdown"` for every project doc.** The `docs|prp|root` + distinction is stored as `metadata.category`, which powers the existing + `RetrieveRequest.filters.category` path (`service.py:585-589`) for free, with no + schema change. +- **Refinement vs the plan:** discovery is extracted into a pure, sync + `_discover_project_doc_files` helper so it can be unit-tested with no DB and no + mocks (the plan's "mock `index_document`" approach would pass a `MagicMock` + where `mypy --strict` expects an `AsyncSession`). The full `index_project_docs` + loop/aggregate path is covered by the route integration test. +- **Synchronous in-request** (not the jobs layer) — see the GOTCHA above. + +## Implementation Blueprint + +### Data models — backend schemas (`app/features/rag/schemas.py`) + +Append after `DeleteResponse`. Mirror `IndexRequest` / `IndexResponse` style. + +```python +# Pseudocode — do not copy verbatim; add full `Args:` docstrings per file style. 
+ +class IndexProjectDocsRequest(BaseModel): + """Request to bulk-index bundled project documentation.""" + model_config = ConfigDict(extra="forbid") # NOT strict=True (mirror IndexRequest) + include_docs: bool = Field(default=True, description="Index docs/**/*.md") + include_prps: bool = Field(default=True, description="Index PRPs/**/*.md") + include_root: bool = Field(default=True, description="Index README/AGENTS/CHANGELOG") + +class ProjectDocResult(BaseModel): + """Per-file outcome of a project-docs index run.""" + source_path: str + status: Literal["indexed", "updated", "unchanged", "failed"] + chunks_created: int + error: str | None = None + +class IndexProjectDocsResponse(BaseModel): + """Aggregate result of POST /rag/index/project-docs.""" + results: list[ProjectDocResult] + total_files: int + indexed: int + updated: int + unchanged: int + failed: int + total_chunks: int + duration_ms: float +``` + +`Literal`, `BaseModel`, `ConfigDict`, `Field` are already imported (`schemas.py:11-14`). + +### Data models — frontend types (`frontend/src/types/api.ts`, in the `// === RAG ===` block, after `RetrieveResponse` ~line 313) + +```ts +export interface IndexProjectDocsRequest { + include_docs?: boolean + include_prps?: boolean + include_root?: boolean +} +export interface ProjectDocResult { + source_path: string + status: 'indexed' | 'updated' | 'unchanged' | 'failed' + chunks_created: number + error: string | null +} +export interface IndexProjectDocsResponse { + results: ProjectDocResult[] + total_files: number + indexed: number + updated: number + unchanged: number + failed: number + total_chunks: number + duration_ms: number +} +``` + +### Backend service (`app/features/rag/service.py`) + +Add one module-level constant (after the imports, before `class RAGService`) and +two methods on `RAGService`. + +```python +# Module-level — the allow-listed project-doc roots. +_PROJECT_ROOT_FILES: tuple[str, ...] 
= ("README.md", "AGENTS.md", "CHANGELOG.md") + +class RAGService: + ... + def _discover_project_doc_files( + self, request: IndexProjectDocsRequest + ) -> list[tuple[Path, str]]: + """Discover bundled markdown under allow-listed roots. Pure + sync. + + Returns a deterministically sorted list of (absolute_path, category) + where category is "docs" | "prp" | "root". + """ + found: list[tuple[Path, str]] = [] + if request.include_docs: + found += [(p, "docs") for p in (self._base_dir / "docs").rglob("*.md")] + if request.include_prps: + found += [(p, "prp") for p in (self._base_dir / "PRPs").rglob("*.md")] + if request.include_root: + for name in _PROJECT_ROOT_FILES: + candidate = self._base_dir / name + if candidate.is_file(): + found.append((candidate, "root")) + # GOTCHA: rglob order is filesystem-dependent — sort for stable results. + return sorted(found, key=lambda pair: str(pair[0])) + + async def index_project_docs( + self, db: AsyncSession, request: IndexProjectDocsRequest + ) -> IndexProjectDocsResponse: + """Bulk-index discovered project docs via index_document. Idempotent.""" + start = time.time() + logger.info("rag.index_project_docs_started", + include_docs=request.include_docs, + include_prps=request.include_prps, + include_root=request.include_root) + + results: list[ProjectDocResult] = [] + for abs_path, category in self._discover_project_doc_files(request): + # abs_path came from globbing UNDER self._base_dir → relative_to is safe. 
+ rel = abs_path.relative_to(self._base_dir).as_posix() + try: + content = abs_path.read_text(encoding="utf-8") + index_response = await self.index_document( + db, + IndexRequest( + source_type="markdown", + source_path=rel, # clean relative DB id + content=content, + metadata={"category": category}, + ), + ) + results.append(ProjectDocResult( + source_path=rel, + status=index_response.status, # narrower Literal → wider: OK + chunks_created=index_response.chunks_created, + error=None, + )) + except (OSError, ValueError) as exc: + # FileNotFoundError ⊂ OSError ; UnicodeDecodeError ⊂ ValueError. + # EmbeddingError / SQLAlchemyError are NOT caught → batch-fatal. + logger.warning("rag.index_project_docs_file_failed", + source_path=rel, error=str(exc), + error_type=type(exc).__name__) + results.append(ProjectDocResult( + source_path=rel, status="failed", + chunks_created=0, error=str(exc))) + + duration_ms = (time.time() - start) * 1000 + summary = IndexProjectDocsResponse( + results=results, + total_files=len(results), + indexed=sum(r.status == "indexed" for r in results), + updated=sum(r.status == "updated" for r in results), + unchanged=sum(r.status == "unchanged" for r in results), + failed=sum(r.status == "failed" for r in results), + total_chunks=sum(r.chunks_created for r in results), + duration_ms=duration_ms, + ) + logger.info("rag.index_project_docs_completed", + total_files=summary.total_files, indexed=summary.indexed, + updated=summary.updated, unchanged=summary.unchanged, + failed=summary.failed, total_chunks=summary.total_chunks, + duration_ms=duration_ms) + return summary +``` + +IMPORTS to add to `service.py`: extend the existing +`from app.features.rag.schemas import (...)` block (lines 29-38) with +`IndexProjectDocsRequest`, `IndexProjectDocsResponse`, `ProjectDocResult`. `time`, +`Path`, `IndexRequest`, `AsyncSession`, `logger` are already imported. 
+ +### Backend route (`app/features/rag/routes.py`) + +Add `IndexProjectDocsRequest, IndexProjectDocsResponse` to the schema import block +(lines 12-19), then append after the `index_document` route: + +```python +@router.post( + "/index/project-docs", + response_model=IndexProjectDocsResponse, + summary="Index bundled project documentation", + description="Discover and index docs/**, PRPs/**, and selected root markdown. " + "Idempotent via content hash; per-file + aggregate summary.", +) +async def index_project_docs( + request: IndexProjectDocsRequest, + db: AsyncSession = Depends(get_db), +) -> IndexProjectDocsResponse: + logger.info("rag.index_project_docs_request_received", + include_docs=request.include_docs, + include_prps=request.include_prps, + include_root=request.include_root) + service = RAGService() + try: + response = await service.index_project_docs(db=db, request=request) + logger.info("rag.index_project_docs_request_completed", + total_files=response.total_files, + total_chunks=response.total_chunks, failed=response.failed) + return response + except EmbeddingError as e: # mirror index_document route + logger.error("rag.index_project_docs_request_failed", error=str(e), + error_type=type(e).__name__, exc_info=True) + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Embedding generation failed: {e}") from e + except SQLAlchemyError as e: + logger.error("rag.index_project_docs_request_failed", error=str(e), + error_type=type(e).__name__, exc_info=True) + raise DatabaseError(message="Failed to index project docs", + details={"error": str(e)}) from e +``` + +NO explicit `status_code` → default `200` (this is a mixed, idempotent batch — not a +single-resource create). `/index/project-docs` and `/index` are distinct static +paths — no route-ordering conflict. Do NOT add a `FileNotFoundError` handler: the +service swallows per-file read errors as `status="failed"` and never raises it. 
+ +### Frontend hook (`frontend/src/hooks/use-rag-sources.ts`) — mirror `useIndexDocument` + +```ts +// extend the type import with IndexProjectDocsRequest, IndexProjectDocsResponse +export function useIndexProjectDocs() { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (body: IndexProjectDocsRequest) => + api<IndexProjectDocsResponse>('/rag/index/project-docs', { method: 'POST', body }), + onSuccess: () => { + void queryClient.invalidateQueries({ queryKey: ['rag-sources'] }) + }, + }) +} +``` + +### Admin button (`frontend/src/pages/admin.tsx` → `RagSourcesPanel`) + +- Add `Library` to the lucide-react import (admin.tsx:4-21). +- In `RagSourcesPanel`, call `const indexProjectDocs = useIndexProjectDocs()`. +- Add a handler: + +```tsx +const handleIndexProjectDocs = async () => { + try { + const r = await indexProjectDocs.mutateAsync({}) // {} → all roots + const summary = + `Indexed ${r.indexed}, updated ${r.updated}, unchanged ${r.unchanged}, ` + + `${r.failed} failed — ${r.total_chunks} chunks` + if (r.failed > 0) toast.warning(summary) + else toast.success(summary) + } catch (err) { + toast.error(err instanceof Error ? err.message : 'Project-docs indexing failed') + } +} +``` + +- In the `CardHeader`, wrap the existing "Index Document" `<Dialog>` and a new + `<Button>` in a flex row. Sketch only — match the existing `Loader2` + spinner-in-button styling (admin.tsx:199) and use the mutation's pending flag + (`isPending` on TanStack Query v5): + +```tsx +<div className="flex items-center gap-2"> + {/* existing "Index Document" <Dialog> … unchanged */} + <Button + variant="outline" + onClick={() => void handleIndexProjectDocs()} + disabled={indexProjectDocs.isPending} + > + {indexProjectDocs.isPending ? ( + <Loader2 className="mr-2 h-4 w-4 animate-spin" /> + ) : ( + <Library className="mr-2 h-4 w-4" /> + )} + Index Project Docs + </Button> +</div> +``` + +Do NOT restructure the existing "Index Document" dialog — only wrap + add beside it. +No confirm dialog — indexing is additive and idempotent. + +### list of tasks (in execution order) + +```yaml +Task 0 — PRECONDITION: + - Find or open a GitHub issue (promote docs/optional-features/01-rag-corpus-manager.md). + - VERIFY: gh issue view <N> --json state → "OPEN". + - git switch -c feat/rag-index-project-docs (off an up-to-date dev). + +Task 1 — MODIFY app/features/rag/tests/conftest.py: + - FIND: DocumentSource.source_path.like("test-%") # line 46 + - REPLACE: DocumentSource.source_path.like("%test-%") + - Use the Edit tool (preserve CRLF). git diff --stat → 1 line changed. 
+ +Task 2 — MODIFY app/features/rag/schemas.py: + - APPEND IndexProjectDocsRequest, ProjectDocResult, IndexProjectDocsResponse + after DeleteResponse. MIRROR IndexRequest/IndexResponse style. + +Task 3 — MODIFY app/features/rag/service.py: + - ADD module constant _PROJECT_ROOT_FILES after the imports. + - EXTEND the rag.schemas import block with the 3 new names. + - ADD RAGService._discover_project_doc_files (pure/sync) and + RAGService.index_project_docs (async orchestrator). + +Task 4 — MODIFY app/features/rag/routes.py: + - EXTEND the rag.schemas import block with IndexProjectDocsRequest/Response. + - ADD the POST /rag/index/project-docs route after index_document. + +Task 5 — MODIFY app/features/rag/tests/test_schemas.py: + - ADD cases: empty IndexProjectDocsRequest() defaults all True; + model_validate({}) ok; unknown field → ValidationError (extra="forbid"); + ProjectDocResult rejects an out-of-Literal status; IndexProjectDocsResponse + round-trips a populated payload. + +Task 6 — MODIFY app/features/rag/tests/test_service.py: + - ADD a UNIT test for _discover_project_doc_files: build a tmp_path tree + (docs/test-a.md, docs/sub/test-b.md, PRPs/test-c.md, README.md, notes.txt), + RAGService(base_dir=str(tmp_path)), assert discovery counts, category tags, + .md-only filtering, root allow-list, and include_* toggles. No DB, no mocks. + +Task 7 — MODIFY app/features/rag/tests/test_routes.py: + - ADD @pytest.mark.integration TestIndexProjectDocsEndpoint (see pseudocode). + +Task 8 — MODIFY frontend/src/types/api.ts: + - ADD the 3 interfaces in the // === RAG === block. + +Task 9 — MODIFY frontend/src/hooks/use-rag-sources.ts: + - ADD useIndexProjectDocs (extend the type import). + +Task 10 — MODIFY frontend/src/pages/admin.tsx: + - ADD Library to the lucide import; ADD the button + handler in RagSourcesPanel. 
+ +Task 11 — MODIFY docs/_base/API_CONTRACTS.md: + - ADD a rag row: + | rag | POST | /rag/index/project-docs | Bulk-index bundled docs/, PRPs/, and root markdown; per-file + aggregate summary; idempotent via content hash | + +Task 12 — Run the full Validation Loop (Levels 1-4); fix until green. +``` + +### Per-task pseudocode (highest-risk task) + +```python +# Task 7 — app/features/rag/tests/test_routes.py — the integration test. +# IMPORTS to add: `from functools import partial`, +# `from app.features.rag.service import RAGService`, +# `from app.features.rag.embeddings import EmbeddingError` (EmbeddingService already imported). + +@pytest.mark.integration +class TestIndexProjectDocsEndpoint: + @pytest.mark.asyncio + async def test_indexes_discovered_docs(self, client, tmp_path): + # fixture files — NON-EMPTY, names contain `test-` so conftest cleanup catches them + (tmp_path / "docs").mkdir() + (tmp_path / "PRPs").mkdir() + (tmp_path / "docs" / "test-proj-1.md").write_text("# A\n\nAlpha content.") + (tmp_path / "PRPs" / "test-proj-2.md").write_text("# B\n\nBeta content.") + mock = create_mock_embedding_service() + with patch("app.features.rag.routes.RAGService", + partial(RAGService, base_dir=str(tmp_path))), \ + patch("app.features.rag.service.get_embedding_service", return_value=mock): + r1 = await client.post("/rag/index/project-docs", json={}) + assert r1.status_code == 200 + d1 = r1.json() + assert d1["total_files"] == 2 and d1["indexed"] == 2 + assert d1["total_chunks"] >= 2 and d1["failed"] == 0 + # idempotent re-run + r2 = await client.post("/rag/index/project-docs", json={}) + assert r2.json()["unchanged"] == 2 + + @pytest.mark.asyncio + async def test_empty_roots_returns_zero(self, client, tmp_path): + mock = create_mock_embedding_service() + with patch("app.features.rag.routes.RAGService", + partial(RAGService, base_dir=str(tmp_path))), \ + patch("app.features.rag.service.get_embedding_service", return_value=mock): + r = await 
client.post("/rag/index/project-docs", json={}) + assert r.status_code == 200 and r.json()["total_files"] == 0 + + @pytest.mark.asyncio + async def test_unknown_field_rejected(self, client): + r = await client.post("/rag/index/project-docs", json={"bogus": True}) + assert r.status_code == 422 # extra="forbid" + + @pytest.mark.asyncio + async def test_embedding_failure_returns_502(self, client, tmp_path): + (tmp_path / "docs").mkdir() + (tmp_path / "docs" / "test-proj-3.md").write_text("# C\n\nGamma content.") + mock = create_mock_embedding_service() + mock.embed_texts = AsyncMock(side_effect=EmbeddingError("no key")) + with patch("app.features.rag.routes.RAGService", + partial(RAGService, base_dir=str(tmp_path))), \ + patch("app.features.rag.service.get_embedding_service", return_value=mock): + r = await client.post("/rag/index/project-docs", json={}) + assert r.status_code == 502 +``` + +### Integration Points + +```yaml +DATABASE: + - migration: NONE — no schema change (category rides in DocumentSource.metadata_ JSONB). +ROUTES: + - app/features/rag/routes.py — new route on the EXISTING `/rag` APIRouter. + - app/main.py — NO change (rag router already wired at main.py:142). +CONFIG: + - NONE — `_PROJECT_ROOT_FILES` is a code constant, not a Settings field. No .env.example change. +FRONTEND: + - frontend/src/hooks/use-rag-sources.ts — new hook beside useIndexDocument. + - frontend/src/pages/admin.tsx — button in RagSourcesPanel; invalidates ['rag-sources']. +DOCS: + - docs/_base/API_CONTRACTS.md — one new rag endpoint row. +``` + +## Validation Loop + +### Level 1: Syntax & Style + +```bash +# Run from the repo root. Fix every error before proceeding. +uv run ruff check . --fix +uv run ruff format . 
+git diff --stat # CRLF guard: confirm NO whole-file churn on the .py edits +``` + +### Level 2: Type Checks + Unit Tests + +```bash +uv run mypy app/ && uv run pyright app/ # both --strict — both gate merge +uv run pytest app/features/rag/ -v -m "not integration" +uv run pytest app/core/tests/test_strict_mode_policy.py -v # must still pass +``` + +Watch: the `index_response.status` (3-Literal) → `ProjectDocResult.status` +(4-Literal) assignment, and `sum(r.status == "..." for r in results)` returning +`int`, are the most likely strict-mode snags — both are fine, but verify. + +### Level 3: Integration Tests + +```bash +docker compose up -d +uv run alembic upgrade head +uv run pytest app/features/rag/tests/test_routes.py -v -m integration +# If they fail on a stale local Postgres: +# docker compose down -v && docker compose up -d && uv run alembic upgrade head +``` + +### Level 4: Frontend + Manual / Browser QA + +```bash +cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run +``` + +Manual dogfood (per `.claude/rules/ui-design.md` — use the `webapp-testing` skill): + +```bash +# Backend MUST run from the repo root so Path.cwd() == repo root. +uv run uvicorn app.main:app --reload --port 8123 +cd frontend && ./node_modules/.bin/vite --host 0.0.0.0 +``` + +1. `curl -s -X POST localhost:8123/rag/index/project-docs -H 'content-type: application/json' -d '{}' | head -c 400` + → `200`, a JSON summary with `total_files` ≈ 110+. +2. Open `/admin` → "RAG Sources" tab → on a fresh DB it shows `0 sources • 0 chunks`. +3. Click **Index Project Docs** → spinner → toast summary → the source list + + counts populate (the `['rag-sources']` invalidation). +4. Open `/knowledge` → the empty state is gone; "N sources • M chunks" reflects + the corpus; a semantic search ("How does backtesting prevent leakage?") returns + cited chunks. +5. Click **Index Project Docs** again → toast shows all `unchanged` (idempotency). 
+ +## Final validation Checklist + +- [ ] `uv run ruff check .` and `uv run ruff format --check .` — clean. +- [ ] `uv run mypy app/` and `uv run pyright app/` — clean (`--strict`). +- [ ] `uv run pytest app/features/rag/ -v -m "not integration"` — green. +- [ ] `uv run pytest app/features/rag/tests/test_routes.py -v -m integration` — green; + no `document_source` row with `source_path` containing `test-` remains. +- [ ] `uv run pytest app/core/tests/test_strict_mode_policy.py -v` — still green. +- [ ] `cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run` — green. +- [ ] `git diff --stat` shows small, line-level diffs on the `.py` files — NO + whole-file CRLF→LF churn. +- [ ] Manual: Index Project Docs populates the corpus; a re-run is all `unchanged`; + Knowledge search returns cited chunks. +- [ ] `docs/_base/API_CONTRACTS.md` lists `POST /rag/index/project-docs`. +- [ ] Commit `feat(rag,ui): index bundled project docs into the RAG corpus (#<N>)` + — references the open issue, NO AI co-author / "Generated with" trailer; PR + into `dev`. + +--- + +## Anti-Patterns to Avoid + +- ❌ Don't re-implement chunking / embedding / hashing — orchestrate + `index_document`. +- ❌ Don't add an Alembic migration — there is no schema change. +- ❌ Don't touch `app/main.py` — the `rag` router is already wired. +- ❌ Don't "fix" `index_document`'s `if request.content:` branch — it is shared + with `POST /rag/index`; the empty-file edge is already handled by the per-file + `OSError` catch. +- ❌ Don't store absolute paths as `source_path` — use the clean relative POSIX id. +- ❌ Don't rewrite existing `.py` files with the Write tool — CRLF will flip to LF. + Use Edit; verify with `git diff --stat`. +- ❌ Don't widen scope into the "Full Version" (chunk preview, re-index, stale + detection, Knowledge filters) — that is a separate, deferred PR. 
+- ❌ Don't catch `EmbeddingError` / `SQLAlchemyError` per file — they are + batch-fatal and must reach the route's `502` / `problem+json` handlers. +- ❌ Don't claim the UI works on a green `tsc` alone — dogfood it in a browser. + +## Confidence Score + +**8.5 / 10** for one-pass implementation success. + +The feature is almost entirely additive on a mature, well-tested slice; every +endpoint shape, schema field, and pattern is pinned to a verified source line. The +residual risks are all identified and mitigated in-PRP: (1) CRLF EOL churn on the +`.py` edits — mitigated by the explicit Edit-tool + `git diff --stat` gotcha; +(2) integration-test DB cleanup of nested fixture paths — mitigated by the Task-1 +`LIKE` widening + `test-`-token fixture names; (3) the `RAGService` `base_dir` +injection in the route test — mitigated by the documented `partial(...)` patch +point; (4) the empty-file / falsy-`content` interaction — mitigated by non-empty +fixtures + the `OSError` safety net. The half-point deduction is for the manual +browser-QA step, which depends on a live embedding provider being configured and +reachable in the implementer's environment. diff --git a/PRPs/PRP-24-forecastops-control-center.md b/PRPs/PRP-24-forecastops-control-center.md new file mode 100644 index 00000000..123b09e5 --- /dev/null +++ b/PRPs/PRP-24-forecastops-control-center.md @@ -0,0 +1,924 @@ +name: "PRP-24 — ForecastOps Control Center" +description: | + Context-rich PRP for a new read-only `ops` backend slice + an `/ops` frontend + Control Center page that aggregates operational state (jobs, runs, aliases, + data freshness) and ranks retraining candidates. One-pass implementation target. + +## Purpose + +Add an operator-facing dashboard that connects ForecastLabAI's isolated Explorer/Visualize +pages into one workflow. A new read-only vertical slice `app/features/ops/` exposes two +server-side aggregation endpoints; a new `frontend/src/pages/ops.tsx` page consumes them. 
+ +--- + +## Goal + +Ship a fully working ForecastOps Control Center: + +- **Backend** — a new read-only slice `app/features/ops/` with two endpoints: + - `GET /ops/summary` — system health, job-status counts, run/alias health, data + freshness, and a "needs attention" list (failed jobs/runs + stale aliases). + - `GET /ops/retraining-candidates?limit=` — a `(store, product)` queue ranked by a + deterministic retraining-priority score. +- **Frontend** — a new `/ops` page wired into the top nav, consuming both endpoints, + reusing existing `KPICard` / `StatusBadge` / `Card` / `Table` / loading-error-empty + components, with attention items linking to existing Explorer detail pages. + +End state: `docker compose up` → seed → open `/ops` → operator sees, at a glance, what +needs attention. No new tables, no Alembic migration, no new external dependency. + +## Why + +- **User value** — operators can answer "which forecasts need attention?" without + cross-referencing four CRUD pages. Failed jobs and stale models become visible before + they affect decisions. Retraining candidates are ranked by recency + error. +- **Demo value** — reviewers see a mature ForecastOps story instead of isolated CRUD pages. +- **Integration** — this is the natural layer above the existing `jobs`, `registry`, + `backtesting`, and `analytics` slices and the Explorer pages. It reads their state; it + does not duplicate it. +- **Source docs** — `docs/optional-features/02-forecastops-control-center.md` (feature + brief) and `.agents/plans/forecastops-control-center.md` (the 17-task implementation + plan this PRP is derived from). + +## What + +### User-visible behavior + +A new **Control Center** nav item opens `/ops`, a dense single page with: + +1. **System Health** card — API up, database connected, embedding-provider reachability, + timestamp of the latest successful job. +2. **KPI row** — Active Jobs, Failed Jobs, Run Success Rate, Stale Aliases. +3. 
**Data Freshness** card — latest sales date, latest completed job, latest successful run. +4. **Needs Attention** table — recent failed jobs, failed runs, and stale aliases; each row + links to the matching Explorer detail page. +5. **Retraining Queue** table — `(store, product)` pairs ranked by priority score, showing + staleness, WAPE, and a human-readable reason. + +The page polls every 15 s, shows loading/error/empty states, and degrades gracefully when +fields are null (no sales yet, metrics missing, etc.). + +### Technical requirements + +- New vertical slice `app/features/ops/` with `__init__.py`, `schemas.py`, `service.py`, + `routes.py`, `tests/`. **No `models.py`, no migration** (read-only — mirrors `analytics`). +- Server-side SQL aggregation (`COUNT … GROUP BY`, `DISTINCT ON`) — never fetch lists and + count in Python. +- RFC 7807 errors, Pydantic v2 response models, SQLAlchemy 2.0 async, `mypy --strict` + + `pyright --strict` clean. +- Frontend: new page + hook module + pure util module (+ vitest tests) + route + nav item + + API response types. + +### Success Criteria + +- [ ] `GET /ops/summary` → 200 with `system`, `jobs`, `runs`, `aliases`, `freshness`, + `attention_items`, `generated_at`. +- [ ] `GET /ops/retraining-candidates` → 200, candidates sorted by `priority_score` desc, + honoring `limit`; 422 when `limit` is outside `[1, 100]`. +- [ ] `GET /ops/summary` → 200 (never 500) when the database has no jobs/runs/aliases. +- [ ] `/ops` page renders all five sections, appears in the top nav, and attention items + link to the correct Explorer detail routes. +- [ ] Backend reads sibling slices via **ORM models only** (no `service.py`/`schemas.py` + cross-slice imports); the vertical-slice tension is called out in the PR description. +- [ ] All validation gates pass: `ruff`, `mypy --strict`, `pyright --strict`, `pytest` + (unit + integration), frontend `tsc` + `lint` + `test`. +- [ ] No new external dependency, no new table, no Alembic migration. 
+ +--- + +## All Needed Context + +### DECISIONS LOCKED (resolved during planning — do NOT re-litigate) + +1. **Backend approach** — a real `app/features/ops/` slice that **imports the ORM models** + of sibling slices (`Job`, `ModelRun`, `DeploymentAlias`, `SalesDaily`) for server-side + SQL aggregation. This is a deliberate, accepted tension with the *"a slice may NOT import + from another slice"* rule (`AGENTS.md` § Architecture). `data_platform` ORM is already a + sanctioned cross-slice import (`analytics` uses it); importing `jobs`/`registry` ORM is + the new tension. **Restrict imports to ORM models + read-only `select()` — NEVER import a + sibling `service.py` or `schemas.py`.** This approach was chosen over an `ASGITransport` + in-process-HTTP approach (the `demo` slice's pattern). **MUST be called out in the PR + description** per `.claude/rules/product-vision.md` § "When Ideas Don't Align". +2. **Scope** — feature-doc MVP **plus** the retraining-candidate queue: two endpoints total. + `/ops/model-health` and `/ops/job-health` from the feature brief are **folded into + `/ops/summary`** (model-health → the alias section; job-health → the jobs section). + **Deferred — do NOT build:** drift indicators, bulk-action queue, action drawer, + WebSocket live updates, exportable incident report. +3. **Provider health** — `config.service.get_provider_health()` is a *service function*, so + the `ops` backend does NOT import it. The frontend reuses the **existing** + `useProviderHealth()` hook from `frontend/src/hooks/use-config.ts` (it already calls + `GET /config/providers/health`). + +### Documentation & References + +```yaml +# MUST READ — backend slice pattern to mirror exactly +- file: app/features/analytics/routes.py + why: Canonical read-only aggregation router. Router decl: `router = APIRouter(prefix="/analytics", tags=["analytics"])`. Endpoint signatures, Query() validation, `db: AsyncSession = Depends(get_db)`, `response_model=`. Imports header lines 1-23. 
+- file: app/features/analytics/service.py + why: `AnalyticsService` class; SQLAlchemy 2.0 `select()` + `func.sum/count`; `.where()` before `.group_by()`; `DISTINCT ON` latest-per-grain in `compute_inventory_status`; `result.one()/.all()`; `logger.info("analytics.", ...)`. +- file: app/features/analytics/schemas.py + why: Pydantic v2 response models — `model_config = ConfigDict(from_attributes=True)`, `Field(..., description=...)`, str-Enum pattern. +- file: app/features/analytics/__init__.py + why: EXACT slice `__init__.py` shape to mirror (docstring + imports + `__all__`). +- file: app/features/analytics/tests/conftest.py + why: `db_session` + `client` fixtures; `app.dependency_overrides[get_db]`; `AsyncClient(transport=ASGITransport(app=app), base_url="http://test")`; `TEST-`-prefixed sample data; FK-safe cleanup. +- file: app/features/analytics/tests/test_routes_integration.py + why: `@pytest.mark.integration` + `@pytest.mark.asyncio`, `client.get(path, params=...)`, status + JSON assertions. +- file: app/features/analytics/tests/test_schemas.py + why: Unmarked unit tests for Pydantic construction/validation. + +# MUST READ — data-source ORM models (the `ops` service queries these) +- file: app/features/jobs/models.py + why: `Job` table `job`; `JobStatus`/`JobType` enums; columns incl. `job_id`, `status`, `completed_at`, `created_at`, `error_message`, `error_type`, `run_id`. +- file: app/features/registry/models.py + why: `ModelRun` table `model_run` (status, model_type, metrics JSONB, data_window_end, store_id, product_id, completed_at, created_at); `RunStatus` enum; `DeploymentAlias` table `deployment_alias` (alias_name, run_id FK→model_run.id, relationship `.run`). +- file: app/features/data_platform/models.py + why: `SalesDaily` (table `sales_daily`, column `date: Mapped[datetime.date]`) — for the latest-sales-date freshness query. +- file: app/core/database.py + why: `get_db` async dependency (auto-commit/rollback) and `Base`. 
Import: `from app.core.database import get_db`. +- file: app/core/exceptions.py + why: `BadRequestError` etc.; RFC 7807 handler already registered in `app/main.py`. +- file: app/core/health.py + why: DB-connectivity check pattern — `await db.execute(text("SELECT 1"))` in try/except. +- file: app/main.py + why: router import + `app.include_router(...)` wiring (analytics: import line 17, include line 133). + +# MUST READ — frontend patterns +- file: frontend/src/pages/visualize/demand.tsx + why: Closest dense data page — header, error→loading→empty early returns, hooks, useMemo, Card/Table, inline helper subcomponents, `@/` imports. +- file: frontend/src/hooks/use-runs.ts + why: TanStack Query hook module pattern to mirror for `use-ops.ts`. +- file: frontend/src/hooks/use-jobs.ts + why: `refetchInterval` polling pattern. +- file: frontend/src/hooks/use-config.ts + why: ALREADY EXPORTS `useProviderHealth()` — reuse it, do NOT duplicate. +- file: frontend/src/lib/api.ts + why: `api<T>(endpoint, config)` generic client; `ApiError`; `formatNumber`/`formatPercent`. +- file: frontend/src/types/api.ts + why: response-type conventions; `ProviderHealth` already at line ~575; reuse `JobStatus`/`RunStatus` unions. +- file: frontend/src/lib/constants.ts + why: `ROUTES` object + `NAV_ITEMS` array — add `ROUTES.OPS` + a single-link nav entry. +- file: frontend/src/lib/status-utils.ts + why: ALREADY EXPORTS `getStatusVariant(status)` → StatusBadge variant — reuse for job/run status badges. +- file: frontend/src/App.tsx + why: lazy-load + `<Route>` registration inside the `<Route element={<AppShell />}>` layout route. +- file: frontend/src/components/charts/kpi-card.tsx + why: `KPICard` props — `title`, `value:string|number`, `description?`, `icon?:LucideIcon`, `trend?`, `isLoading?`. +- file: frontend/src/components/common/status-badge.tsx + why: `StatusBadge` variants — `default|success|warning|error|info|pending`. 
+- file: frontend/src/components/common/error-display.tsx + why: `ErrorDisplay({error,title?,onRetry?})` + `EmptyState({title,description?,action?,icon?})`. +- file: frontend/src/lib/knowledge-utils.ts + why: pattern for a PURE util module + colocated `*.test.ts` vitest file. + +# External docs +- url: https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#aggregate-functions-with-group-by-having + why: `select(Job.status, func.count()).group_by(Job.status)` aggregation. +- url: https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#postgresql-distinct-on + why: `.distinct(col, col)` DISTINCT ON — order_by MUST lead with the same columns. +- url: https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html + why: `await db.execute(select(...))`, `.scalar()`, `.scalars()`, `.all()`. +- url: https://fastapi.tiangolo.com/tutorial/query-params-str-validations/ + why: `Query(default=20, ge=1, le=100)` bounded validation. +- url: https://tanstack.com/query/latest/docs/framework/react/reference/useQuery + why: `refetchInterval`, `enabled`, `queryKey` for polling hooks. +- url: https://docs.pydantic.dev/latest/concepts/config/ + why: `ConfigDict(from_attributes=True)` for ORM-row → Pydantic response models. + +- docfile: .agents/plans/forecastops-control-center.md + why: The full 17-task implementation plan with per-task IMPLEMENT/PATTERN/GOTCHA/VALIDATE. +- docfile: docs/optional-features/02-forecastops-control-center.md + why: Original feature brief — UX sections, risk model, validation plan. 
+``` + +### Current Codebase tree (relevant subset) + +```bash +app/ +├── main.py # router wiring — UPDATE +├── core/ +│ ├── database.py # get_db, Base +│ ├── exceptions.py # BadRequestError, RFC 7807 +│ └── health.py # SELECT 1 connectivity pattern +├── features/ +│ ├── analytics/ # ← MIRROR THIS SLICE (read-only aggregation) +│ │ ├── __init__.py routes.py schemas.py service.py +│ │ └── tests/ (conftest.py, test_routes_integration.py, test_schemas.py) +│ ├── jobs/models.py # Job, JobStatus, JobType +│ ├── registry/models.py # ModelRun, RunStatus, DeploymentAlias +│ └── data_platform/models.py # SalesDaily +frontend/src/ +├── App.tsx # route registration — UPDATE +├── hooks/ (use-runs.ts, use-jobs.ts, use-config.ts, index.ts) +├── lib/ (api.ts, constants.ts, status-utils.ts, knowledge-utils.ts) +├── types/api.ts # response types — UPDATE +├── pages/visualize/demand.tsx # ← MIRROR for the dense page layout +└── components/ (charts/kpi-card.tsx, common/status-badge.tsx, common/error-display.tsx) +``` + +### Desired Codebase tree (files to add / touch) + +```bash +app/features/ops/ # NEW SLICE — read-only, no models.py, no migration +├── __init__.py # NEW — slice exports (mirror analytics/__init__.py) +├── schemas.py # NEW — Pydantic v2 response models +├── service.py # NEW — OpsService + pure score/extract helpers +├── routes.py # NEW — APIRouter(prefix="/ops") + 2 endpoints +└── tests/ + ├── __init__.py # NEW — empty package marker + ├── conftest.py # NEW — db_session, client, sample-data fixtures + ├── test_schemas.py # NEW — unit (unmarked) + ├── test_service.py # NEW — unit for score_retraining_candidate/extract_wape + └── test_routes_integration.py # NEW — @pytest.mark.integration + +app/main.py # UPDATE — import + include_router(ops_router) + +frontend/src/ +├── hooks/use-ops.ts # NEW — useOpsSummary, useRetrainingCandidates +├── hooks/index.ts # UPDATE — export * from './use-ops' +├── lib/ops-utils.ts # NEW — pure helpers +├── lib/ops-utils.test.ts # NEW — 
vitest unit tests +├── pages/ops.tsx # NEW — the Control Center page +├── types/api.ts # UPDATE — Ops* response interfaces +├── lib/constants.ts # UPDATE — ROUTES.OPS + NAV_ITEMS entry +└── App.tsx # UPDATE — lazy import + /ops <Route> +``` + +### Known Gotchas & Library Quirks + +```python +# CRITICAL: ORM status columns are String, NOT enum-typed. Compare against the +# `.value`, never the enum object — mirror registry/service.py & jobs/service.py: +# select(Job.status, func.count()).where(Job.status == JobStatus.COMPLETED.value) +# +# CRITICAL: ruff DTZ rules — do NOT use `date.today()` or naive `datetime.now()`. +# Use timezone-aware forms: +# now = datetime.now(UTC) # from `datetime import UTC` +# today = datetime.now(UTC).date() # for staleness math +# +# CRITICAL: PostgreSQL DISTINCT ON — `.distinct(a, b)` REQUIRES `order_by` to lead +# with the SAME columns. Order the "latest" tiebreaker by `created_at.desc()` +# (TimestampMixin, always non-null) — NOT `completed_at` (nullable; DESC puts +# NULLs first in Postgres and would pick a NULL-completed row): +# select(ModelRun).where(ModelRun.status == RunStatus.SUCCESS.value) +# .distinct(ModelRun.store_id, ModelRun.product_id) +# .order_by(ModelRun.store_id, ModelRun.product_id, ModelRun.created_at.desc()) +# +# CRITICAL: `func.count()` with no arg = COUNT(*) — valid, used by analytics. +# `select(Col, func.count()).group_by(Col)` returns only EXISTING statuses; +# zero-fill the missing enum members in Python. +# +# CRITICAL: AsyncSession FORBIDS implicit IO / lazy-loading (SQLAlchemy async +# docs). The alias query MUST select BOTH entities — +# `select(DeploymentAlias, ModelRun).join(ModelRun, DeploymentAlias.run_id == ModelRun.id)` +# — rows come back as (alias, run) tuples; use the joined `ModelRun` row +# DIRECTLY. NEVER touch the `DeploymentAlias.run` relationship attribute — it +# triggers a lazy load → `MissingGreenlet` error. 
Same rule for every +# relationship: eager-select it or `selectinload()` it; never access it lazily. +# +# CRITICAL: model_run.metrics JSONB is frequently None or lacks WAPE — backtest +# metrics persist to job.result, NOT model_run.metrics (only an explicit +# update_run writes run metrics). `extract_wape()` MUST tolerate None / unrelated +# dicts / non-numeric values. Scoring MUST NEVER raise on missing data. +# +# CRITICAL: DeploymentAlias.run_id is the INTEGER model_run.id (FK), NOT the +# 32-char run_id string. In fixtures set it from the persisted run's `.id`. +# Insert ModelRun before DeploymentAlias; clean up DeploymentAlias first. +# +# CRITICAL: Pydantic strict-mode linter (app/core/tests/test_strict_mode_policy.py) +# only inspects request models with ConfigDict(strict=True). Ops schemas are +# RESPONSE models with ConfigDict(from_attributes=True) — date/datetime fields +# need NO Field(strict=False). Do NOT add strict=True. +# +# CRITICAL: Cross-slice ORM import is the ACCEPTED design (decision #1). No CI +# import-linter enforces slice boundaries, so this will not fail the build — +# but it MUST be flagged in the PR description. +# +# GOTCHA: commit-format scope allow-list has NO `ops` scope. Use feat(api): for +# backend commits, feat(ui): for frontend, feat(api,ui): for the wiring commit. +# +# GOTCHA: the /ops page renders inside AppShell (<Outlet />) — do NOT add nav, +# container, or Toaster. Never hardcode raw colors — use shadcn variants / +# semantic tokens (.claude/rules/shadcn-ui.md). +``` + +--- + +## External Research Findings + +Verified May 2026 against the docs the feature brief cited. Each finding ends with a +**verdict** — what it changes (or confirms) for this PRP. + +### 1. SQLAlchemy async ORM — `https://docs.sqlalchemy.org/en/21/orm/extensions/asyncio.html` + +The repo pins `sqlalchemy[asyncio]>=2.0.36`; the doc page is the 2.1 line — the async +contract below is identical across 2.0/2.1. 
+ +- Execution idioms confirmed: `await session.execute(stmt)` → `.scalars()` / `.all()` / + `.one()`; single aggregate value via `await session.scalar(select(func.max(col)))`. +- **CRITICAL — implicit IO is forbidden.** *"the application needs to avoid any points at + which IO-on-attribute access may occur."* Accessing an un-loaded relationship under + `AsyncSession` raises `MissingGreenlet`. Two safe patterns: eager-select the related + entity in the same `select(...)`, or `selectinload()` it. +- A single `AsyncSession` is not safe across concurrent tasks — irrelevant here (one + request = one `get_db` session), but do not `asyncio.gather` queries on the same session. +- **Verdict — applied.** The alias query already selects both entities + (`select(DeploymentAlias, ModelRun).join(...)`). Added a CRITICAL gotcha: iterate the + `(alias, run)` tuples and use the joined `ModelRun` directly; **never touch + `DeploymentAlias.run`**. No other relationship is accessed, so no `selectinload()` is + needed. No code-shape change beyond the explicit warning. + +### 2. FastAPI query-param validation — `https://fastapi.tiangolo.com/tutorial/query-params-str-validations/` + +- Bounded numeric query params: `Query(ge=…, le=…)`; a violation returns **HTTP 422** + (`Unprocessable Entity`) automatically — confirms the `?limit=0` / `?limit=200` → 422 + test cases. +- Current docs favour the `Annotated[int, Query(ge=1, le=100)] = 20` form over the + legacy `limit: int = Query(default=20, ge=1, le=100)` form. +- **Verdict — mirror the repo, not the docs.** `analytics/routes.py` uses the non-`Annotated` + `Query(...)` form. Consistency with the existing slice wins (`.claude/rules` § "don't + create new patterns when existing ones work"). Keep the PRP's `Query(default=20, ge=1, + le=100, …)` signature; if `analytics/routes.py` is found to use `Annotated`, match that + instead. + +### 3. 
TanStack Query v5 — `https://tanstack.com/query/.../guides/important-defaults` + +- v4→v5 renames: `cacheTime` → `gcTime`; `isLoading` → `isPending` (`isLoading` still + exists = `isPending && isFetching`); `keepPreviousData` is now + `placeholderData: keepPreviousData`. `use-runs.ts` already uses the v5 form — mirror it. +- Polling: pair `refetchInterval` with `refetchOnWindowFocus: false` to avoid focus-storms. + The repo's `query-client.ts` already sets `refetchOnWindowFocus: false` and + `staleTime: 5min` globally. +- **Verdict — applied.** `useOpsSummary` keeps `refetchInterval: 15000` (operational state + worth polling); `useRetrainingCandidates` gets **no** `refetchInterval` (slow-moving — + refetch-on-mount suffices). Task 12 updated accordingly. + +### 4. MLflow Model Registry — `https://www.mlflow.org/docs/latest/ml/model-registry/` + +- MLflow defines an **alias** as *"a mutable, named reference to a particular version of a + registered model"*; the `champion`/production-alias promotion pattern decouples + deployment from a specific version. +- ForecastLabAI's `DeploymentAlias` is the same concept. MLflow frames alias governance as + managing **staleness** — an alias is "stale" when it still points at an old version + after a better one exists. +- **Verdict — confirms the design.** The PRP's `is_stale` detection (alias → non-`success` + run, or a newer `success` run exists for the same store/product) is the industry-standard + alias-staleness check. No change; cite MLflow as the conceptual basis in the PR. + +### 5. NIST AI RMF — `https://www.nist.gov/itl/ai-risk-management-framework` + +- Four core functions: **Govern, Map, Measure, Manage**. *Measure* = continuously track + trustworthiness/performance of deployed AI; *Manage* = act on what monitoring surfaces. + (Operational depth lives in the AI RMF 1.0 PDF + Playbook, not the overview page.) 
+- **Verdict — framing only.** The Control Center operationalises *Measure* (model-health + metrics, freshness) and *Manage* (the "needs attention" + retraining queue). Useful + one-line justification for the PR description; no implementation impact. + +### 6. Model-retraining triggers (MLOps best practice — web search, May 2026) + +- Established taxonomy: **time-based** (simple, predictable, but "may lead to unnecessary + retraining"), **performance-based** (retrain on metric degradation — needs monitoring), + **drift-based** (data/concept drift — needs drift detection). Sources recommend combining + signals over a pure time-based trigger. +- **Verdict — confirms the heuristic.** The PRP's score blends a **time-based** signal + (staleness) with a **performance-based** signal (WAPE) — exactly the recommended hybrid. + Drift-based is correctly deferred (needs infra the repo doesn't have). When WAPE is + unknown the score degrades to time-based only — an acceptable, documented fallback. The + 60/40 staleness/error weighting is a defensible, deterministic heuristic; keep it. +- Sources: [When to Retrain Your ML Models](https://tech.flowblog.io/blog/when-to-retrain-your-ml-models-for-success), + [Model Retraining 2026 (AIMultiple)](https://research.aimultiple.com/model-retraining/), + [CMU SEI — Automated Retraining](https://www.sei.cmu.edu/blog/improving-automated-retraining-of-machine-learning-models/). + +### 7. WAPE as the error signal (web search, May 2026) + +- WAPE/WMAPE is volume-weighted — a miss on a high-volume SKU counts more — and is not + destabilised by low-demand items; it is the recommended single accuracy metric for + demand forecasting. sMAPE is widely considered broken (unstable near zero, can go + negative). MAE is interpretable but scale-blind. +- **Verdict — confirms the choice.** Using WAPE as the score's error component is correct. 
+ There is no universal "bad WAPE" threshold; the score's cap at WAPE 100 (total error = + total demand) is a reasonable normalisation ceiling. Keep it. +- Sources: [Forecast Accuracy Metrics 2026](https://prospeo.io/s/forecast-accuracy-metrics), + [MAPE vs WMAPE vs SMAPE](https://medium.com/@vinitkothari.24/time-series-evaluation-metrics-mape-vs-wmape-vs-smape-which-one-to-use-why-and-when-part1-32d3852b4779). + +### 8. Cited docs assessed as out-of-scope (do NOT pull these in) + +- **OpenTelemetry** (`opentelemetry.io`, `opentelemetry-python-contrib…fastapi`) — the repo + deliberately ships **no metrics/traces** (`docs/_base/SECURITY.md`: "Metrics — none … + Traces — none"). The `/ops` Control Center **is** the observability surface. Adding OTel + would be a new dependency and a scope violation — **do not add it.** +- **scikit-learn model persistence / `TimeSeriesSplit`** — the `ops` slice does no training + or cross-validation; those belong to `forecasting`/`backtesting`. Not relevant here. +- **Recharts** — MVP is cards + tables (no charts). Recharts is already available + (`frontend/src/components/ui/chart.tsx`) if a sparkline is wanted in a later iteration; + deferred, not part of this PRP. + +--- + +## Implementation Blueprint + +### Data models and structure — `app/features/ops/schemas.py` + +All response models. Every model: `model_config = ConfigDict(from_attributes=True)`. +Every field: `Field(..., description="...")`. Counts: `Field(..., ge=0, ...)`. 
+ +```python +# Pydantic v2 response models (NOT request bodies — no strict=True) +class SystemHealth(BaseModel): + api_ok: bool + database_connected: bool + latest_successful_job_at: datetime | None + +class StatusCount(BaseModel): + status: str + count: int = Field(..., ge=0) + +class JobHealth(BaseModel): + counts: list[StatusCount] # one per JobStatus, zero-filled + completed_today: int = Field(..., ge=0) + failed_total: int = Field(..., ge=0) + active_total: int = Field(..., ge=0) # pending + running + +class RunHealth(BaseModel): + counts: list[StatusCount] # one per RunStatus, zero-filled + success_rate: float | None # success / (total - archived); None if denom 0 + failed_total: int = Field(..., ge=0) + +class AliasHealth(BaseModel): + alias_name: str + run_id: str + run_status: str + model_type: str + store_id: int + product_id: int + is_stale: bool + stale_reason: str | None + wape: float | None + +class DataFreshness(BaseModel): + latest_sales_date: date | None + latest_job_completed_at: datetime | None + latest_run_completed_at: datetime | None + +class AttentionItem(BaseModel): + item_type: Literal["failed_job", "failed_run", "stale_alias"] + entity_id: str # job_id for failed_job; run_id for failed_run AND stale_alias + label: str + detail: str + occurred_at: datetime | None + +class OpsSummaryResponse(BaseModel): + system: SystemHealth + jobs: JobHealth + runs: RunHealth + aliases: list[AliasHealth] + freshness: DataFreshness + attention_items: list[AttentionItem] + generated_at: datetime + +class RetrainingCandidate(BaseModel): + store_id: int + product_id: int + priority_score: float = Field(..., ge=0.0, le=1.0) + staleness_days: int = Field(..., ge=0) + wape: float | None + latest_run_id: str | None + latest_run_status: str | None + reason: str + +class RetrainingCandidatesResponse(BaseModel): + candidates: list[RetrainingCandidate] + total_evaluated: int = Field(..., ge=0) + generated_at: datetime +``` + +### Per-task pseudocode (critical details 
— not full code) + +```python +# ── app/features/ops/service.py — pure helpers (module scope, above OpsService) ── + +def extract_wape(metrics: dict[str, Any] | None) -> float | None: + # GOTCHA: match the param type to ModelRun.metrics' Mapped[...] annotation. + # Try "wape", "wape_mean", "WAPE"; return first numeric (int|float, not bool); else None. + if not metrics: + return None + for key in ("wape", "wape_mean", "WAPE"): + v = metrics.get(key) + if isinstance(v, (int, float)) and not isinstance(v, bool): + return float(v) + return None + +def score_retraining_candidate(staleness_days: int, wape: float | None) -> float: + """Retraining priority in [0.0, 1.0]; higher = more urgent. + 60% staleness (cap 90 days) + 40% error (cap WAPE 100).""" + staleness_norm = min(max(staleness_days, 0), 90) / 90.0 + error_norm = min(max(wape, 0.0), 100.0) / 100.0 if wape is not None else 0.0 + return round(0.6 * staleness_norm + 0.4 * error_norm, 4) + + +# ── OpsService — no custom __init__; just two async methods ── + +class OpsService: + async def get_summary(self, db: AsyncSession) -> OpsSummaryResponse: + now = datetime.now(UTC) + + # SYSTEM + try: + await db.execute(text("SELECT 1")) + db_ok = True + except Exception: # noqa: BLE001 — connectivity probe + db_ok = False + latest_job = await db.scalar( + select(func.max(Job.completed_at)).where(Job.status == JobStatus.COMPLETED.value) + ) + + # JOBS — server-side GROUP BY, zero-fill the enum + job_rows = (await db.execute( + select(Job.status, func.count()).group_by(Job.status))).all() + job_map = {s: c for s, c in job_rows} + job_counts = [StatusCount(status=s.value, count=job_map.get(s.value, 0)) + for s in JobStatus] + start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0) + completed_today = await db.scalar(select(func.count()).select_from(Job).where( + Job.status == JobStatus.COMPLETED.value, Job.completed_at >= start_of_day)) or 0 + # active_total = pending + running ; failed_total = failed (from 
job_map) + + # RUNS — same GROUP BY pattern over RunStatus; success_rate = success/(total-archived) + + # ALIASES — join, compute staleness + alias_rows = (await db.execute( + select(DeploymentAlias, ModelRun) + .join(ModelRun, DeploymentAlias.run_id == ModelRun.id))).all() + # For each (alias, run): is_stale via _is_alias_stale(run, db-derived newer-success), + # wape = extract_wape(run.metrics). + + # FRESHNESS + latest_sales_date = await db.scalar(select(func.max(SalesDaily.date))) + # latest_job_completed_at, latest_run_completed_at (status==SUCCESS) likewise. + + # ATTENTION ITEMS — 10 most-recent failed jobs + 10 failed runs + every stale alias + failed_jobs = (await db.execute( + select(Job).where(Job.status == JobStatus.FAILED.value) + .order_by(Job.created_at.desc()).limit(10))).scalars().all() + # failed_job → AttentionItem(entity_id=job.job_id, occurred_at=job.created_at, ...) + # failed_run → AttentionItem(entity_id=run.run_id, ...) + # stale_alias → AttentionItem(entity_id=<run_id>, label="alias '<alias_name>'", ...) + + logger.info("ops.summary_computed", db_ok=db_ok, failed_jobs=..., stale_aliases=...) + return OpsSummaryResponse(...) 
+ + async def get_retraining_candidates(self, db: AsyncSession, limit: int + ) -> RetrainingCandidatesResponse: + today = datetime.now(UTC).date() + # latest SUCCESS run per (store, product) — DISTINCT ON, order_by created_at.desc() + runs = (await db.execute( + select(ModelRun).where(ModelRun.status == RunStatus.SUCCESS.value) + .distinct(ModelRun.store_id, ModelRun.product_id) + .order_by(ModelRun.store_id, ModelRun.product_id, ModelRun.created_at.desc()) + )).scalars().all() + candidates = [] + for run in runs: + staleness = (today - run.data_window_end).days + wape = extract_wape(run.metrics) + score = score_retraining_candidate(staleness, wape) + reason = f"{staleness}d since last train window" + ( + f"; WAPE {wape:.1f}" if wape is not None else "; WAPE unknown") + candidates.append(RetrainingCandidate( + store_id=run.store_id, product_id=run.product_id, priority_score=score, + staleness_days=max(staleness, 0), wape=wape, latest_run_id=run.run_id, + latest_run_status=run.status, reason=reason)) + candidates.sort(key=lambda c: c.priority_score, reverse=True) + return RetrainingCandidatesResponse( + candidates=candidates[:limit], total_evaluated=len(candidates), + generated_at=datetime.now(UTC)) +``` + +```python +# ── app/features/ops/routes.py ── +router = APIRouter(prefix="/ops", tags=["ops"]) + +@router.get("/summary", response_model=OpsSummaryResponse, + summary="Operational summary for the Control Center") +async def get_ops_summary(db: AsyncSession = Depends(get_db)) -> OpsSummaryResponse: + return await OpsService().get_summary(db) + +@router.get("/retraining-candidates", response_model=RetrainingCandidatesResponse, + summary="Ranked retraining-candidate queue") +async def get_retraining_candidates( + limit: int = Query(default=20, ge=1, le=100, description="Max candidates to return"), + db: AsyncSession = Depends(get_db), +) -> RetrainingCandidatesResponse: + return await OpsService().get_retraining_candidates(db, limit) +``` + +```typescript +// ── 
frontend/src/lib/ops-utils.ts — PURE (no React, no fetch) ── +import { ROUTES } from '@/lib/constants' +import type { AttentionItem, RetrainingCandidate, SystemHealth } from '@/types/api' + +export function summaryHealthVariant(s: SystemHealth): 'success' | 'error' { + return s.api_ok && s.database_connected ? 'success' : 'error' +} +export function attentionItemLink(item: AttentionItem): string { + // failed_job → /explorer/jobs/:id ; failed_run + stale_alias → /explorer/runs/:id + if (item.item_type === 'failed_job') return `/explorer/jobs/${item.entity_id}` + return `/explorer/runs/${item.entity_id}` +} +export function attentionBadgeVariant(t: AttentionItem['item_type']): 'error' | 'warning' { + return t === 'stale_alias' ? 'warning' : 'error' +} +export function formatStaleness(days: number): string { + return days <= 0 ? 'today' : `${days}d` +} +export function sortRetrainingCandidates(rows: RetrainingCandidate[]): RetrainingCandidate[] { + return [...rows].sort((a, b) => b.priority_score - a.priority_score) +} +``` + +### list of tasks to be completed (in order) + +```yaml +Task 1 — CREATE app/features/ops/schemas.py: + - MIRROR pattern from: app/features/analytics/schemas.py + - DEFINE the 11 response models above; every model ConfigDict(from_attributes=True) + - IMPORTS: from datetime import date, datetime; from typing import Literal; + from pydantic import BaseModel, ConfigDict, Field + - GOTCHA: response models — NO ConfigDict(strict=True), NO Field(strict=False) + - VALIDATE: uv run python -c "from app.features.ops.schemas import OpsSummaryResponse, RetrainingCandidatesResponse; print('ok')" + +Task 2 — CREATE app/features/ops/service.py (pure helpers first): + - ADD module-scope `extract_wape` and `score_retraining_candidate` (pseudocode above) + - IMPORTS: from typing import Any + - GOTCHA: never raise on None/missing metrics + - VALIDATE: uv run python -c "from app.features.ops.service import score_retraining_candidate as s; assert s(90,100.0)==1.0 
and s(0,None)==0.0; print('ok')" + +Task 3 — CREATE app/features/ops/tests/__init__.py + test_schemas.py + test_service.py: + - __init__.py empty + - test_schemas.py: construct each model; assert ge=0 rejects negatives (pytest.raises(ValidationError)) + - test_service.py: score boundaries (0,None)->0.0, (90,100.0)->1.0, mid, negative clamp, + wape>100 clamp; extract_wape each key / None / {} / non-numeric / bool + - MIRROR: app/features/analytics/tests/test_schemas.py — plain def test_*(), UNMARKED + - VALIDATE: uv run pytest -v -m "not integration" app/features/ops/tests/test_schemas.py app/features/ops/tests/test_service.py + +Task 4 — UPDATE app/features/ops/service.py — implement OpsService: + - ADD class OpsService (no custom __init__) with get_summary + get_retraining_candidates + - MIRROR: AnalyticsService.compute_inventory_status (DISTINCT ON), compute_kpis (func + scalar) + - IMPORTS: from datetime import UTC, datetime; + from sqlalchemy import func, select, text; + from sqlalchemy.ext.asyncio import AsyncSession; + from app.core.logging import get_logger; + from app.features.jobs.models import Job, JobStatus; + from app.features.registry.models import DeploymentAlias, ModelRun, RunStatus; + from app.features.data_platform.models import SalesDaily; + from app.features.ops.schemas import (... all ...) 
+ - GOTCHA: compare status == Enum.X.value; use datetime.now(UTC); created_at.desc() for DISTINCT ON + - VALIDATE: uv run mypy app/features/ops/ && uv run pyright app/features/ops/ + +Task 5 — CREATE app/features/ops/routes.py: + - MIRROR: app/features/analytics/routes.py header + endpoint signatures + - router = APIRouter(prefix="/ops", tags=["ops"]); 2 endpoints (pseudocode above) + - VALIDATE: uv run python -c "from app.features.ops.routes import router; print(sorted(r.path for r in router.routes))" + +Task 6 — CREATE app/features/ops/__init__.py: + - MIRROR: app/features/analytics/__init__.py — docstring + imports + __all__ + - VALIDATE: uv run python -c "from app.features.ops import router, OpsService; print('ok')" + +Task 7 — UPDATE app/main.py: + - FIND the block of `from app.features.<slice>.routes import router as <slice>_router` imports + - INJECT: from app.features.ops.routes import router as ops_router + - FIND in create_app(): the run of `app.include_router(...)` calls + - INJECT: app.include_router(ops_router) (e.g. after analytics_router) + - PRESERVE ruff import sorting (keep grouped with app.features.* imports) + - VALIDATE: uv run python -c "from app.main import app; p={r.path for r in app.routes}; assert '/ops/summary' in p and '/ops/retraining-candidates' in p; print('wired')" + +Task 8 — CREATE app/features/ops/tests/conftest.py: + - MIRROR: app/features/analytics/tests/conftest.py (db_session, client fixtures verbatim; + extend TEST- cleanup to Job/ModelRun/DeploymentAlias — DeploymentAlias before ModelRun) + - ADD fixtures: sample_jobs (statuses incl. failed+error_message), + sample_runs (statuses incl. 
success with metrics={"wape":31.0} + failed; varied + store_id/product_id/data_window_end), sample_alias (DeploymentAlias→success run), + sample_sales (a couple SalesDaily rows) + - GOTCHA: DeploymentAlias.run_id = persisted run.id (int); insert ModelRun first + - VALIDATE: uv run pytest -m integration app/features/ops/tests/ --collect-only + +Task 9 — CREATE app/features/ops/tests/test_routes_integration.py: + - @pytest.mark.integration + @pytest.mark.asyncio + - tests: /ops/summary 200 happy (seeded) ; /ops/summary 200 resilient (no fixtures → + counts >= 0, status keys all present) ; /ops/retraining-candidates 200 sorted desc, + len <= limit ; ?limit=0 → 422 ; ?limit=200 → 422 + - MIRROR: app/features/analytics/tests/test_routes_integration.py + - GOTCHA: idempotent — assert structural invariants, not exact global totals + - VALIDATE: docker compose up -d && uv run pytest -v -m integration app/features/ops/ + +Task 10 — UPDATE frontend/src/types/api.ts: + - ADD interfaces: SystemHealth, StatusCount, JobHealth, RunHealth, AliasHealth, + DataFreshness, AttentionItem, OpsSummaryResponse, RetrainingCandidate, + RetrainingCandidatesResponse (dates as string; nullable → `| null`) + - VALIDATE: cd frontend && pnpm tsc --noEmit + +Task 11 — CREATE frontend/src/lib/ops-utils.ts + ops-utils.test.ts: + - IMPLEMENT pure functions (pseudocode above) + - MIRROR: frontend/src/lib/knowledge-utils.ts + knowledge-utils.test.ts + - VALIDATE: cd frontend && pnpm test --run src/lib/ops-utils.test.ts + +Task 12 — CREATE frontend/src/hooks/use-ops.ts + UPDATE hooks/index.ts: + - useOpsSummary(enabled=true): queryKey ['ops','summary'], api('/ops/summary'), + refetchInterval: 15000 (operational state — poll. Global query-client already + sets refetchOnWindowFocus:false, so this won't double-fire on tab focus.) 
+ - useRetrainingCandidates(limit=20, enabled=true): queryKey ['ops','retraining',limit], + api('/ops/retraining-candidates', {params:{limit}}) + — NO refetchInterval: the queue moves slowly (changes only on a new run); + refetch-on-mount + manual invalidation is sufficient. Avoids needless load. + - DO NOT add useProviderHealth — it already exists in use-config.ts; reuse that. + - index.ts: add `export * from './use-ops'` + - MIRROR: frontend/src/hooks/use-runs.ts, use-jobs.ts + - VALIDATE: cd frontend && pnpm tsc --noEmit + +Task 13 — UPDATE frontend/src/lib/constants.ts: + - ADD `OPS: '/ops'` to ROUTES (after SHOWCASE) + - ADD `{ label: 'Control Center', href: ROUTES.OPS }` to NAV_ITEMS (after Showcase) + - PRESERVE the `as const` literal types + - VALIDATE: cd frontend && pnpm tsc --noEmit + +Task 14 — UPDATE frontend/src/App.tsx: + - ADD `const OpsPage = lazy(() => import('@/pages/ops'))` + - ADD `<Route path={ROUTES.OPS} element={<Suspense fallback={<PageLoader />}><OpsPage /></Suspense>} />` + inside the `<Route element={<AppShell />}>` block (reuse the same Suspense fallback component the sibling lazy routes use) + - VALIDATE: cd frontend && pnpm tsc --noEmit && pnpm lint + +Task 15 — CREATE frontend/src/pages/ops.tsx: + - export default function OpsPage() + - hooks: useOpsSummary(), useRetrainingCandidates(), useProviderHealth() [from use-config] + - early returns: ErrorDisplay(onRetry) → LoadingState → EmptyState (zero jobs AND runs) + - sections: System Health card, KPI row (KPICard ×4), Data Freshness card, + Needs Attention table (Link to attentionItemLink(item)), Retraining Queue table + - reuse getStatusVariant from @/lib/status-utils for job/run status badges + - MIRROR: frontend/src/pages/visualize/demand.tsx + - GOTCHA: renders inside AppShell — no nav/container; no raw colors + - VALIDATE: cd frontend && pnpm tsc --noEmit && pnpm lint + +Task 16 — FULL validation sweep (all gates — see Validation Loop) + +Task 17 — Browser dogfood per .claude/rules/ui-design.md (webapp-testing / agent-browser) +``` + +### Integration Points + +```yaml +DATABASE: + - migration: NONE — read-only slice, no schema change. 
+ - tables read (existing): job, model_run, deployment_alias, sales_daily. + +ROUTES (backend): + - add to: app/main.py + - import: from app.features.ops.routes import router as ops_router + - wire: app.include_router(ops_router) + +ROUTES (frontend): + - add to: frontend/src/lib/constants.ts → ROUTES.OPS = '/ops' ; NAV_ITEMS entry + - add to: frontend/src/App.tsx → lazy import + `<Route>` + +HOOKS: + - new: frontend/src/hooks/use-ops.ts + - update: frontend/src/hooks/index.ts → export * from './use-ops' + - reuse: useProviderHealth from frontend/src/hooks/use-config.ts (do NOT duplicate) + +CONFIG: none — no new settings, no new env var. +``` + +--- + +## Validation Loop + +### Level 1: Syntax & Style + +```bash +uv run ruff check . --fix +uv run ruff format --check . +cd frontend && pnpm lint +# Expected: no errors. Common trap: date.today() → ruff DTZ — use datetime.now(UTC).date(). +``` + +### Level 2: Type Checks + +```bash +uv run mypy app/ && uv run pyright app/ # both --strict +cd frontend && pnpm tsc --noEmit +# Expected: no errors. 
+``` + +### Level 3: Unit Tests + +```bash +uv run pytest -v -m "not integration" app/features/ops/ +cd frontend && pnpm test --run src/lib/ops-utils.test.ts +``` + +Backend unit cases (`test_service.py`, pure — no DB, no mocks): +```python +def test_score_zero_when_fresh_and_no_error(): + assert score_retraining_candidate(0, None) == 0.0 + +def test_score_max_when_fully_stale_and_max_error(): + assert score_retraining_candidate(90, 100.0) == 1.0 + +def test_score_clamps_negative_staleness_and_high_wape(): + assert score_retraining_candidate(-5, 250.0) == 0.4 # staleness→0, error→1.0, *0.4 + +def test_extract_wape_prefers_wape_then_wape_mean(): + assert extract_wape({"wape": 12.0}) == 12.0 + assert extract_wape({"wape_mean": 8.5}) == 8.5 + assert extract_wape(None) is None + assert extract_wape({}) is None + assert extract_wape({"wape": "bad"}) is None + assert extract_wape({"wape": True}) is None # bool is not a metric +``` + +### Level 4: Integration Tests + +```bash +docker compose up -d +uv run alembic upgrade head +uv run pytest -v -m integration app/features/ops/ +``` + +`test_routes_integration.py` (`@pytest.mark.integration` + `@pytest.mark.asyncio`): +- `/ops/summary` → 200; `system.database_connected is True`; job & run `counts` cover + every status key; seeded failed job appears in `attention_items`; `freshness.latest_sales_date` set. +- `/ops/summary` with no fixtures → 200 (never 500); all counts `>= 0`; `attention_items` is a list. +- `/ops/retraining-candidates` → 200; `candidates` sorted by `priority_score` desc; `len <= limit`. +- `/ops/retraining-candidates?limit=0` → 422. +- `/ops/retraining-candidates?limit=200` → 422. 
+ +### Level 5: Manual Validation + +```bash +uv run uvicorn app.main:app --reload --port 8123 & +curl -s http://localhost:8123/ops/summary | head -c 400 +curl -s "http://localhost:8123/ops/retraining-candidates?limit=5" | head -c 400 +curl -s -o /dev/null -w '%{http_code}\n' "http://localhost:8123/ops/retraining-candidates?limit=0" # 422 +# Frontend: seed first (make demo), then open http://localhost:5173/ops via +# the webapp-testing skill / agent-browser — verify all 5 sections, nav item, +# attention-item links route to Explorer detail pages, retraining table sorted, +# empty-state on a fresh DB. Type-check passing ≠ UI works. +``` + +--- + +## Final Validation Checklist + +- [ ] `uv run ruff check . && uv run ruff format --check .` — clean +- [ ] `uv run mypy app/ && uv run pyright app/` — clean (`--strict`) +- [ ] `uv run pytest -v -m "not integration"` — green +- [ ] `docker compose up -d && uv run pytest -v -m integration` — green +- [ ] `cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run` — green +- [ ] `GET /ops/summary` and `GET /ops/retraining-candidates` behave per Success Criteria +- [ ] `/ops` page renders all 5 sections in a real browser; nav item present; links work +- [ ] No new dependency, no new table, no migration +- [ ] Backend cross-slice imports are ORM-models-only; PR description flags the tension +- [ ] Commits use `feat(api)` / `feat(ui)` scopes (no `ops` scope exists) and reference an open issue + +--- + +## Anti-Patterns to Avoid + +- ❌ Don't import a sibling slice's `service.py` or `schemas.py` — ORM models only. +- ❌ Don't fetch full lists and count in Python — use `func.count()` + `GROUP BY`. +- ❌ Don't use `date.today()` / naive `datetime.now()` — ruff DTZ; use `datetime.now(UTC)`. +- ❌ Don't add `ConfigDict(strict=True)` to response models. +- ❌ Don't duplicate `useProviderHealth` — reuse the one in `use-config.ts`. +- ❌ Don't re-implement a status→badge mapper — reuse `getStatusVariant` from `status-utils.ts`. 
+- ❌ Don't create `app/features/ops/models.py` or an Alembic migration. +- ❌ Don't let scoring raise on `None`/missing `metrics` — degrade to staleness-only. +- ❌ Don't claim the UI works on a green type-check — dogfood it in a browser. +- ❌ Don't catch-all silently except for the deliberate DB-connectivity probe. + +--- + +## Workflow Notes + +- Open a GitHub issue first (`gh issue list` / `gh issue create`); branch + `feat/ops-control-center` off `dev` (`.claude/rules/branch-naming.md`); every commit + references the issue and uses `feat(api)` / `feat(ui)` / `feat(api,ai)` scopes; PR into `dev`. +- The cross-slice ORM import (`jobs`, `registry`) is a deliberate, accepted tension with the + vertical-slice rule — **state it explicitly in the PR description** per + `.claude/rules/product-vision.md`. + +## Confidence Score + +**9 / 10** for one-pass implementation success. + +Rationale: the `analytics` slice is a near-exact backend template; every data source, ORM +column, enum, and frontend pattern is verified against the live codebase; all three open +items the plan flagged (`SalesDaily`, `useProviderHealth`, `getStatusVariant`) are resolved +inline. The external-research pass (§ External Research Findings) validated the retraining +heuristic (hybrid time+performance trigger, WAPE error signal) and the alias-staleness +design against MLOps/MLflow guidance, and caught one latent bug — the `AsyncSession` +lazy-load trap on `DeploymentAlias.run` — now fixed with an explicit CRITICAL gotcha. +Residual risk: (1) integration-test fixture FK ordering for `DeploymentAlias`→`ModelRun`; +(2) `model_run.metrics` shape variability — both mitigated by defensive `extract_wape` and +structural (not exact-total) test assertions; (3) minor SQLAlchemy `DISTINCT ON` / typing +friction under `--strict`, mitigated by the explicit gotcha and `created_at`-ordering +guidance. 
diff --git a/PRPs/PRP-25-forecastops-control-center-full.md b/PRPs/PRP-25-forecastops-control-center-full.md new file mode 100644 index 00000000..62bd1bf3 --- /dev/null +++ b/PRPs/PRP-25-forecastops-control-center-full.md @@ -0,0 +1,827 @@ +name: "PRP-25 — ForecastOps Control Center (Full Version)" +description: | + Context-rich PRP that takes the ForecastOps Control Center from its PRP-24 MVP + to the "Full Version" of docs/optional-features/02-forecastops-control-center.md: + model-health + performance-drift indicators, an exportable incident report, and + an operator action layer (bulk retrain + promote-to-alias). Phased so each phase + is independently shippable and one-pass implementable. + +## Purpose + +PRP-24 shipped the Control Center MVP — a read-only `app/features/ops/` slice +(`GET /ops/summary`, `GET /ops/retraining-candidates`) and a `/ops` page. This PRP +delivers the remaining "Full Version" capabilities from the feature brief: + +- **Phase A — Model Health & Drift**: a new `GET /ops/model-health` endpoint that + classifies forecast-error *performance drift* per `(store, product)` from run + history, plus a Model Health section on `/ops`. +- **Phase B — Incident Report Export**: client-side CSV + Markdown export of the + operational snapshot. +- **Phase C — Action Layer**: operator bulk-retrain (multi-select the retraining + queue → fan out to `POST /jobs`) and promote-to-alias (`POST /registry/aliases`), + both behind a confirmation dialog. + +--- + +## DEPENDS ON — read before starting + +This PRP **builds directly on PRP-24** (`PRPs/PRP-24-forecastops-control-center.md`, +issue #217, PR #218). It modifies files PRP-24 created. **PR #218 is merged to +`dev`** (dev tip `aac7735 Merge pull request #218`), so the dependency gate is +satisfied — cut this PRP's branch from `dev`. 
Sanity-check before starting: if +`app/features/ops/service.py` does not already define `OpsService`, +`extract_wape`, and `score_retraining_candidate`, stop — the dependency is missing. + +--- + +## Goal + +Ship the Full Version of the Control Center as three independently-validatable +phases on top of the PRP-24 slice: + +- **Backend** — one new read-only endpoint (`GET /ops/model-health`). The `ops` + slice stays read-only: **no new table, no migration, no new mutating endpoint.** +- **Frontend** — a Model Health section, an incident-report export control, and an + action layer on the Retraining Queue, all on the existing `/ops` page. + +End state: `docker compose up` → seed → open `/ops` → operator sees drift signals, +can export an incident report, and can trigger retraining / promote a model — +without leaving the page. + +## Why + +- **User value** — the MVP shows *what* needs attention; the Full Version lets the + operator *understand why* (drift), *communicate it* (export), and *act on it* + (retrain / promote) in one place. This is the "operator workflow" the feature + brief calls for (`docs/optional-features/02-forecastops-control-center.md`). +- **Demo value** — a closed-loop ForecastOps story (observe → diagnose → act) + instead of a read-only dashboard. +- **Integration** — Phase A extends the `ops` slice's existing aggregation + pattern; Phases B & C are frontend layers over **already-shipped endpoints** + (`POST /jobs`, `POST /registry/aliases`) — no new backend mutation surface. + +## What + +### User-visible behavior + +On the existing `/ops` page, the Full Version adds: + +1. **Model Health section** (Phase A) — a table of `(store, product)` grains, each + showing its forecast-error (WAPE) history, a **drift badge** + (`improving` / `stable` / `degrading` / `unknown`), and the WAPE delta. Backed + by `GET /ops/model-health`. +2. 
**Export control** (Phase B) — an "Export report" button in the page header + offering **CSV** (attention items) and **Markdown** (full incident report) + downloads, generated entirely client-side. +3. **Action layer** (Phase C) — the Retraining Queue gains row checkboxes and a + **"Retrain selected (N)"** button; each Model Health / candidate row gains a + **"Promote to alias"** action. Both open a confirmation dialog, then call the + existing job / alias endpoints, reporting per-item success/failure via toasts. + +### Technical requirements + +- Phase A: extend `app/features/ops/{schemas,service,routes,__init__}.py` and the + slice's tests. Server-side SQL, `mypy --strict` + `pyright --strict` clean, + RFC 7807 errors, Pydantic v2 response models (`from_attributes=True`, no + `strict=True`). +- Phases B & C: **frontend only** — new pure util modules (+ vitest), new TanStack + Query hooks reusing existing ones, page wiring. No backend changes. +- No new external dependency, no new table, no Alembic migration. + +### Success Criteria + +- [ ] `GET /ops/model-health?limit=<n>` → 200 with `entries` (each carrying + `drift_direction`, `latest_wape`, `wape_delta`, `wape_history`), sorted with + `degrading` grains first; `422` when `limit` is outside `[1, 100]`. +- [ ] `GET /ops/model-health` → 200 (never 500) on an empty database. +- [ ] `classify_drift` is a pure, unit-tested function that never raises on + missing / sparse WAPE history. +- [ ] `/ops` renders a Model Health section with drift badges and a delta. +- [ ] An "Export report" control downloads a valid CSV and a Markdown report + built entirely client-side from already-loaded data. +- [ ] The Retraining Queue supports multi-select; "Retrain selected" opens a + confirm dialog and creates one `train` job per selected grain via + `POST /jobs`, reporting per-item outcome. +- [ ] A "Promote to alias" action creates/updates an alias via + `POST /registry/aliases` behind a confirm dialog. 
+- [ ] The `ops` backend slice remains read-only — no new `models.py`, no + migration, no new mutating endpoint. +- [ ] All gates pass: `ruff`, `mypy --strict`, `pyright --strict`, `pytest` + (unit + integration), frontend `tsc` + `lint` + `test`. + +--- + +## All Needed Context + +### DECISIONS LOCKED (resolved during planning — do NOT re-litigate) + +1. **The `ops` backend stays read-only.** Phase C's "action layer" is a + *frontend* feature: the `/ops` page calls the **existing, already-sanctioned** + `POST /jobs` and `POST /registry/aliases` endpoints. No `POST /ops/*` mutating + endpoint is added; the slice keeps no `models.py` and ships no migration. This + keeps the PRP-24 "read-only slice" invariant intact. The mild tension — the + `/ops` *page* becomes an action launcher — is **accepted** (the Forecast page + already triggers train jobs the same way) and **MUST be noted in the PR + description** per `.claude/rules/product-vision.md`. + +2. **Drift = performance drift, not data drift.** Phase A classifies the **trend + of forecast error (WAPE) across a grain's successful-run history** — a + performance-based signal computable read-only from `model_run.metrics`. True + *data drift* (input-feature distribution shift, PSI/KL tests) is **OUT OF + SCOPE**: featuresets are computed in-memory and never persisted, so there are + no feature snapshots to compare. Do NOT add drift infrastructure. + +3. **WebSocket live job updates — DECLINED, do NOT build.** The feature brief's + Full Version lists "WebSocket updates for running jobs". This conflicts with + the `.claude/rules/product-vision.md` guardrail *"Not a real-time streaming + system … the agent WebSocket is for response streaming only."* PRP-24 already + polls `GET /ops/summary` every 15 s and `use-jobs.ts` polls every 5 s — live + job state is already covered. Keeping polling is the deliberate decision; a WS + would be a new streaming surface for no real gain. 
If a WS is ever wanted it is + a **separate PRP** mirroring the `demo/stream` pattern — not this one. + +4. **Retraining scoring is unchanged.** PRP-24's `score_retraining_candidate` + (60% staleness / 40% WAPE) is locked. Phase A adds drift as a *separate* + `/ops/model-health` signal; it does NOT fold drift into the retraining score. + +5. **No new dependencies.** Recharts, shadcn `Checkbox`, `AlertDialog`, and + `sonner` are already installed (`frontend/package.json`, + `frontend/src/components/ui/{chart,checkbox,alert-dialog,sonner}.tsx`). Reuse + them; do not `pnpm add` anything. + +### Documentation & References + +```yaml +# MUST READ — the PRP-24 slice this PRP extends (already on dev after #218) +- file: PRPs/PRP-24-forecastops-control-center.md + why: The MVP PRP. Its "Known Gotchas" section applies verbatim here — status + columns are String (compare `Enum.X.value`); `datetime.now(UTC)` (ruff DTZ); + DISTINCT ON order_by must lead with the distinct cols; AsyncSession forbids + lazy-loading; response models use ConfigDict(from_attributes=True), NEVER + strict=True; do NOT add a `# noqa: BLE001` (BLE is not in the ruff select). +- file: app/features/ops/service.py + why: EXTEND THIS. Already defines `extract_wape(metrics)` — REUSE it, do not + redefine. `OpsService.get_retraining_candidates` is the near-exact mirror + for the new `get_model_health` (DISTINCT ON vs. full-history difference + noted in Gotchas). Module-scope pure helpers (`score_retraining_candidate`) + are the pattern for the new `classify_drift`. +- file: app/features/ops/schemas.py + why: EXTEND THIS. `RetrainingCandidate` / `RetrainingCandidatesResponse` are the + exact shape to mirror for `ModelHealthEntry` / `ModelHealthResponse`. +- file: app/features/ops/routes.py + why: EXTEND THIS. `get_retraining_candidates` is the exact mirror for the new + `get_model_health` route — same `Query(default=20, ge=1, le=100)` bound. 
+- file: app/features/ops/__init__.py + why: EXTEND `__all__` with the new response model. +- file: app/features/ops/tests/conftest.py + why: EXTEND. `sample_runs` already creates two success runs for grain + (9001, 8001) — that grain already has a 2-point WAPE history (31.0 → 12.0, + i.e. `improving`). Add a third run if a `degrading` case is wanted. +- file: app/features/ops/tests/{test_service,test_schemas,test_routes_integration}.py + why: EXTEND. `test_service.py` is the pattern for pure-function tests of + `classify_drift`; `test_routes_integration.py` for the new endpoint. +- file: app/features/registry/models.py + why: `ModelRun.metrics` (JSONB, nullable), `status`, `store_id`, `product_id`, + `created_at`, `run_id`. `RunStatus.SUCCESS`. + +# MUST READ — endpoints Phase C reuses (no backend changes — frontend calls these) +- file: app/features/jobs/routes.py + why: `POST /jobs` (202) creates+executes a job. Body is `JobCreate` + {job_type, params}. The train-job `params` contract is documented in the + route docstring — model_type, store_id, product_id, start_date, end_date. +- file: app/features/jobs/schemas.py + why: `JobCreate` = {job_type: JobType, params: dict}. `JobResponse` shape. +- file: app/features/jobs/service.py + why: VERIFY the exact train-job `params` keys `JobService` consumes BEFORE + writing `buildRetrainJob` — this is Phase C's main risk. +- file: app/features/forecasting/schemas.py + why: `ModelConfig` (discriminated union on `model_type`) — the shape of + `model_run.model_config`; tells you what to flatten into the retrain params. +- file: app/features/registry/routes.py + why: `POST /registry/aliases` (201) body `AliasCreate` {alias_name, run_id, + description?}; aliases only point at SUCCESS runs (400 otherwise). +- file: app/features/registry/schemas.py + why: `AliasCreate`, `AliasResponse`, `RunResponse` field names. + +# MUST READ — frontend patterns +- file: frontend/src/pages/ops.tsx + why: MODIFY THIS in every phase. 
PRP-24's page — header, error/loading/empty + early returns, Card/Table sections, `@/` imports. +- file: frontend/src/hooks/use-ops.ts + why: EXTEND. `useOpsSummary` / `useRetrainingCandidates` are the exact pattern + for `useModelHealth`. +- file: frontend/src/hooks/use-jobs.ts + why: `useCreateJob()` — REUSE for Phase C bulk retrain. Do not write a new + job-creation hook. +- file: frontend/src/hooks/use-runs.ts + why: `useCreateAlias()` — REUSE for Phase C promote. `useRun(runId)` fetches a + run's detail (needed to clone model_config for a retrain). +- file: frontend/src/lib/csv-export.ts + why: `toCsv` / `downloadCsv` / `CsvColumn` — REUSE for Phase B CSV export. + Already CSV-injection-safe. +- file: frontend/src/lib/ops-utils.ts + ops-utils.test.ts + why: PRP-24's pure util module + colocated vitest test — the exact pattern for + the new `incident-report.ts` and `ops-actions.ts` modules. +- file: frontend/src/pages/visualize/demand.tsx + why: A dense data page that already does CSV export (`downloadCsv`/`toCsv`) and + row interaction — mirror its export-button placement and table patterns. +- file: frontend/src/components/ui/checkbox.tsx + why: shadcn Checkbox — Phase C row selection. Already installed. +- file: frontend/src/components/ui/alert-dialog.tsx + why: shadcn AlertDialog — Phase C confirm dialogs. Already installed. +- file: frontend/src/components/ui/sonner.tsx + why: `sonner` toast. VERIFY a `<Toaster />` is mounted (app-shell / main) before + calling `toast()`; if not, mount it once in the app shell. +- file: frontend/src/components/ui/chart.tsx + why: shadcn Recharts wrapper — optional WAPE sparkline in Model Health. +- file: frontend/src/components/charts/time-series-chart.tsx + why: existing Recharts usage pattern if a sparkline is added. +- file: frontend/src/types/api.ts + why: EXTEND. The `Ops*` interfaces PRP-24 added (`OpsSummaryResponse`, + `RetrainingCandidate`, …) are the mirror for `ModelHealth*`.
+ +# External docs +- url: https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#order-by + why: ordering the full run history per grain (NOT DISTINCT ON — see Gotchas). +- url: https://tanstack.com/query/latest/docs/framework/react/guides/mutations + why: reusing `useMutation` (useCreateJob/useCreateAlias) for the action layer. +- url: https://ui.shadcn.com/docs/components/alert-dialog + why: AlertDialog composition for the confirm gates. +- url: https://recharts.org/en-US/api/LineChart + why: minimal WAPE sparkline (optional Phase A polish). +- url: https://www.mlflow.org/docs/latest/ml/model-registry/ + why: alias-promotion governance — the conceptual basis for Phase C promote. + +- docfile: docs/optional-features/02-forecastops-control-center.md + why: the feature brief — § "Full Version" (lines 76-83) is exactly this PRP's + scope, minus retraining scoring (done in PRP-24) and WebSocket (declined). +``` + +### Current Codebase tree (post-PRP-24, relevant subset) + +```bash +app/features/ops/ # read-only slice from PRP-24 +├── __init__.py schemas.py service.py routes.py +└── tests/ (__init__.py conftest.py test_schemas.py test_service.py + test_routes_integration.py) +frontend/src/ +├── pages/ops.tsx # PRP-24 Control Center page (5 sections) +├── hooks/use-ops.ts # useOpsSummary, useRetrainingCandidates +├── lib/ops-utils.ts (+ .test.ts) # pure helpers +├── types/api.ts # Ops* response interfaces +├── hooks/use-jobs.ts # useCreateJob (REUSE) +├── hooks/use-runs.ts # useCreateAlias, useRun (REUSE) +└── lib/csv-export.ts # toCsv, downloadCsv (REUSE) +``` + +### Desired Codebase tree (files to add / touch) + +```bash +# ── Phase A — Model Health & Drift ── +app/features/ops/schemas.py # MODIFY — add WapePoint, ModelHealthEntry, + # ModelHealthResponse +app/features/ops/service.py # MODIFY — add classify_drift() + get_model_health() +app/features/ops/routes.py # MODIFY — add GET /ops/model-health +app/features/ops/__init__.py # MODIFY — export 
ModelHealthResponse +app/features/ops/tests/test_schemas.py # MODIFY — new models +app/features/ops/tests/test_service.py # MODIFY — classify_drift +app/features/ops/tests/conftest.py # MODIFY — degrading-grain fixture +app/features/ops/tests/test_routes_integration.py # MODIFY — /ops/model-health +frontend/src/types/api.ts # MODIFY — ModelHealth* interfaces +frontend/src/hooks/use-ops.ts # MODIFY — useModelHealth +frontend/src/pages/ops.tsx # MODIFY — Model Health section + +# ── Phase B — Incident Report Export ── +frontend/src/lib/incident-report.ts # NEW — pure builders (CSV cols + markdown) +frontend/src/lib/incident-report.test.ts # NEW — vitest +frontend/src/pages/ops.tsx # MODIFY — export control in header + +# ── Phase C — Action Layer ── +frontend/src/lib/ops-actions.ts # NEW — pure buildRetrainJob() +frontend/src/lib/ops-actions.test.ts # NEW — vitest +frontend/src/pages/ops.tsx # MODIFY — selection, dialogs, actions + +# NOT created: any app/features/ops/models.py, any Alembic migration, +# any POST /ops/* endpoint, any WebSocket. +``` + +### Known Gotchas & Library Quirks + +```python +# CRITICAL: ALL PRP-24 gotchas still apply. Re-read PRP-24 § "Known Gotchas". +# Headlines: compare String status columns against `Enum.X.value`; use +# `datetime.now(UTC)` (ruff DTZ bans date.today()/naive now()); response models +# use ConfigDict(from_attributes=True) and NEVER strict=True; never add a +# `# noqa: BLE001` (BLE is not in the ruff select — it would trip RUF100). +# +# CRITICAL (Phase A): model-health needs the FULL run history per grain, NOT the +# latest-per-grain. Do NOT use DISTINCT ON here. Query every SUCCESS run, ordered +# by (store_id, product_id, created_at ASC), and group in Python: +# select(ModelRun).where(ModelRun.status == RunStatus.SUCCESS.value) +# .order_by(ModelRun.store_id, ModelRun.product_id, ModelRun.created_at) +# Then itertools.groupby over (store_id, product_id) — rows are already ordered. 
+# +# CRITICAL (Phase A): REUSE the existing `extract_wape` from ops/service.py — it +# already tolerates None / non-numeric / bool. Do NOT redefine it. WAPE history +# will contain None entries (runs whose metrics lack WAPE); classify_drift MUST +# tolerate a list with None gaps and never raise. +# +# CRITICAL (Phase C): POST /jobs executes SYNCHRONOUSLY (returns 202 but runs the +# job inline before responding). Bulk-retrain of N grains = N blocking calls on +# a single-process backend. Fire them SEQUENTIALLY (await each before the next), +# show per-item progress, and keep N modest. Do NOT Promise.all() them. +# +# CRITICAL (Phase C): a train job's `params` shape is consumed by JobService — +# VERIFY the exact keys in app/features/jobs/service.py + forecasting/schemas.py +# BEFORE writing buildRetrainJob(). `_execute_train` in jobs/service.py reads a +# FLAT params dict: model_type, store_id, product_id, start_date, end_date, plus +# model-specific keys `season_length` (seasonal_naive) / `window_size` +# (moving_average) — there is NO `period` key. Pick those keys explicitly from +# the source run's `model_config`; do NOT blind-spread `model_config` (it also +# carries `schema_version` + a duplicate `model_type`). +# +# CRITICAL (Phase C): aliases may only point at SUCCESS runs (registry returns 400 +# otherwise). Only offer "Promote to alias" on rows whose run status is success. +# +# GOTCHA (Phase C): retrain window — clone the source run's `model_type` + +# `model_config`; set start_date = source run `data_window_start`, end_date = +# `summary.freshness.latest_sales_date` (the freshest data). If latest_sales_date +# is null, fall back to the run's own data_window_end and surface a warning. +# +# GOTCHA: `sonner` `toast()` needs a mounted `<Toaster />`. It is ALREADY mounted — +# `frontend/src/components/layout/app-shell.tsx` renders `<Toaster />` from +# `@/components/ui/sonner`. Task C4 is verification-only; do NOT add a second one.
+# +# GOTCHA: commit-format scope allow-list has NO `ops` scope. Use feat(api) for the +# backend phase, feat(ui) for the frontend phases. +# +# GOTCHA: the /ops page renders inside AppShell — no nav/container/Toaster added by +# the page; semantic shadcn tokens only, never raw colors. +``` + +--- + +## External Research Findings + +Verified May 2026. Each finding ends with a **verdict**. + +### 1. Performance-drift vs. data-drift triggers (web search, May 2026) + +MLOps practice splits retraining triggers into **performance-based** (monitor a +core error metric; retrain when it degrades past a threshold) and **drift-based** +(statistical tests — PSI, KL divergence — on input/target distributions). 2025 +reviews report models left unmonitored for 6+ months saw error rates rise ~35%, +and that proactive performance-trigger policies outperform reactive ones. + +- **Verdict — Phase A is a performance-drift indicator.** It tracks the WAPE trend + across a grain's run history and classifies `improving / stable / degrading`. + This is exactly the recommended performance-based signal and needs **no new + infrastructure** — it reads `model_run.metrics`. PSI/KL **data drift** is + correctly out of scope (no persisted feature snapshots; see Decision #2). +- Sources: [What Is Model Drift?](https://www.articsledge.com/post/model-drift), + [Advanced ML Model Monitoring](https://enhancedmlops.com/advanced-ml-model-monitoring-drift-detection-explainability-and-automated-retraining/), + [MLOps Model Monitoring](https://durapid.com/blog/mlops-model-monitoring-how-to-track-model-drift-and-performance-in-production/). + +### 2. Drift threshold (heuristic) + +There is no universal "drift threshold"; teams pick a relative tolerance band. +A ±10% relative change in the error metric is a defensible, deterministic default +for a portfolio system. + +- **Verdict — applied.** `classify_drift` uses a ±10% relative band: latest WAPE + vs. the mean of prior WAPEs. 
`degrading` if latest is >10% worse, `improving` if + >10% better, `stable` within the band, `unknown` if fewer than two numeric WAPEs. + +### 3. Alias-promotion governance (MLflow) + +MLflow models alias promotion (`champion` / production alias) as a deliberate, +human-gated step decoupling deployment from a specific version. + +- **Verdict — confirms Phase C.** "Promote to alias" reuses `POST /registry/aliases` + behind a confirmation `AlertDialog` — the human gate. No new backend gate is + needed; the registry already restricts aliases to SUCCESS runs. +- Source: [MLflow Model Registry](https://www.mlflow.org/docs/latest/ml/model-registry/). + +### 4. WebSocket job updates — assessed and DECLINED + +The feature brief lists WS job updates under "Full Version". `product-vision.md` +forbids new streaming surfaces ("Not a real-time streaming system"). PRP-24's +`/ops/summary` 15 s poll + `use-jobs.ts` 5 s poll already deliver live job state. + +- **Verdict — do NOT build (Decision #3).** Polling is the deliberate choice. + +--- + +## Implementation Blueprint + +### Phase A — data models (`app/features/ops/schemas.py`, additions) + +All response models — `ConfigDict(from_attributes=True)`, every field a +`Field(..., description=...)`, counts `ge=0`, **no `strict=True`**. 
+ +```python +from typing import Literal +DriftDirection = Literal["improving", "stable", "degrading", "unknown"] + +class WapePoint(BaseModel): # one run's WAPE observation + run_id: str + created_at: datetime + wape: float | None # None when the run lacks WAPE + +class ModelHealthEntry(BaseModel): + store_id: int + product_id: int + run_count: int = Field(..., ge=0) + latest_run_id: str | None + latest_run_status: str | None + latest_wape: float | None + previous_wape: float | None # the prior numeric WAPE + wape_delta: float | None # latest - previous (numeric only) + drift_direction: DriftDirection + last_trained_at: datetime | None + staleness_days: int = Field(..., ge=0) + wape_history: list[WapePoint] # chronological, may hold gaps + +class ModelHealthResponse(BaseModel): + entries: list[ModelHealthEntry] # degrading-first sort + total_evaluated: int = Field(..., ge=0) + generated_at: datetime +``` + +### Phase A — pseudocode (`app/features/ops/service.py`, additions) + +```python +# ── module-scope pure helper (mirror of score_retraining_candidate) ── +_DRIFT_BAND = 0.10 # ±10% relative WAPE change + +def classify_drift(wape_history: list[float | None]) -> tuple[str, float | None]: + """Classify the WAPE trend. Pure; never raises. Returns (direction, delta). 
+ direction ∈ improving|stable|degrading|unknown; delta = latest - previous.""" + numeric = [w for w in wape_history if w is not None] + if len(numeric) < 2: + return "unknown", None + latest = numeric[-1] + prior = numeric[:-1] + baseline = sum(prior) / len(prior) + delta = round(latest - prior[-1], 4) + if baseline <= 0: # avoid div-by-zero on a 0 WAPE + return ("degrading" if latest > 0 else "stable"), delta + rel = (latest - baseline) / baseline + if rel > _DRIFT_BAND: + return "degrading", delta + if rel < -_DRIFT_BAND: + return "improving", delta + return "stable", delta + +# ── OpsService.get_model_health ── +async def get_model_health(self, db, limit: int) -> ModelHealthResponse: + today = datetime.now(UTC).date() + # FULL history — NOT DISTINCT ON. Ordered so itertools.groupby works. + runs = (await db.execute( + select(ModelRun).where(ModelRun.status == RunStatus.SUCCESS.value) + .order_by(ModelRun.store_id, ModelRun.product_id, ModelRun.created_at) + )).scalars().all() + entries = [] + for (store_id, product_id), grain_runs in groupby(runs, key=lambda r: (r.store_id, r.product_id)): + grain_runs = list(grain_runs) # already chronological + history = [WapePoint(run_id=r.run_id, created_at=r.created_at, + wape=extract_wape(r.metrics)) for r in grain_runs] + direction, delta = classify_drift([p.wape for p in history]) + numeric = [p.wape for p in history if p.wape is not None] + latest_run = grain_runs[-1] + entries.append(ModelHealthEntry( + store_id=store_id, product_id=product_id, run_count=len(grain_runs), + latest_run_id=latest_run.run_id, latest_run_status=latest_run.status, + latest_wape=(numeric[-1] if numeric else None), + previous_wape=(numeric[-2] if len(numeric) > 1 else None), + wape_delta=delta, drift_direction=direction, + last_trained_at=latest_run.created_at, + staleness_days=max((today - latest_run.data_window_end).days, 0), + wape_history=history)) + # degrading first, then by |wape_delta| desc; unknown/stable last + _rank = {"degrading": 
0, "improving": 1, "stable": 2, "unknown": 3} + entries.sort(key=lambda e: (_rank[e.drift_direction], -abs(e.wape_delta or 0.0))) + logger.info("ops.model_health_computed", grains=len(entries)) + return ModelHealthResponse(entries=entries[:limit], total_evaluated=len(entries), + generated_at=datetime.now(UTC)) +``` + +```python +# ── app/features/ops/routes.py (add; mirror get_retraining_candidates) ── +@router.get("/model-health", response_model=ModelHealthResponse, + summary="Per-(store,product) forecast-error health and drift") +async def get_model_health( + limit: int = Query(default=20, ge=1, le=100, description="Max grains to return"), + db: AsyncSession = Depends(get_db), +) -> ModelHealthResponse: + return await OpsService().get_model_health(db, limit) +``` + +### Phase B — pseudocode (`frontend/src/lib/incident-report.ts`, NEW, pure) + +```typescript +import type { CsvColumn } from '@/lib/csv-export' +import type { AttentionItem, OpsSummaryResponse, RetrainingCandidate } from '@/types/api' + +// CSV column set for the attention-items export (reuse toCsv/downloadCsv). +export const attentionCsvColumns: CsvColumn[] = [ + { key: 'item_type', header: 'Type' }, { key: 'entity_id', header: 'Entity' }, + { key: 'label', header: 'Item' }, { key: 'detail', header: 'Detail' }, + { key: 'occurred_at', header: 'When' }, +] + +// Build a human-readable Markdown incident report from already-loaded page data. +export function buildIncidentMarkdown( + summary: OpsSummaryResponse, candidates: RetrainingCandidate[], +): string { + // Sections: # ForecastOps Incident Report (generated_at) ; System Health + // (api/db + provider lines) ; KPIs (active/failed jobs, success rate, stale + // aliases) ; Data Freshness ; Needs Attention (a markdown table) ; Top + // Retraining Candidates (a markdown table). Pure string assembly — no fetch. + // Return the assembled string. 
+} +``` + +### Phase C — pseudocode (`frontend/src/lib/ops-actions.ts`, NEW, pure) + +```typescript +import type { JobCreate, ModelRun, RetrainingCandidate } from '@/types/api' + +// Build the POST /jobs body that retrains a grain from its latest run. +// VERIFY param keys against app/features/jobs/service.py before finalizing. +export function buildRetrainJob( + run: ModelRun, // GET /registry/runs/{latest_run_id} + latestSalesDate: string | null, // summary.freshness.latest_sales_date +): JobCreate { + return { + job_type: 'train', + params: { + model_type: run.model_type, + store_id: run.store_id, + product_id: run.product_id, + start_date: run.data_window_start, + end_date: latestSalesDate ?? run.data_window_end, // freshest data + // model-specific keys picked explicitly — NOT a blind ...model_config spread: + ...(run.model_config.season_length != null + ? { season_length: run.model_config.season_length } : {}), + ...(run.model_config.window_size != null + ? { window_size: run.model_config.window_size } : {}), + }, + } +} +``` + +Page wiring (`ops.tsx`): Retraining Queue rows get a `Checkbox`; selection is +a `Set` of the selected rows held in `useState`. "Retrain selected (N)" opens an `AlertDialog`; on confirm, +**sequentially** for each selected candidate: `useRun`-fetch its `latest_run_id`, +`buildRetrainJob(...)`, `useCreateJob().mutateAsync(...)`, `toast` the outcome. +"Promote to alias" (on success-status rows only) opens an `AlertDialog` with an +alias-name input, then `useCreateAlias().mutateAsync({alias_name, run_id})`.
+ +### Tasks (in order) + +```yaml +# ════════ PHASE A — Model Health & Drift (backend + page section) ════════ +Task A1 — MODIFY app/features/ops/schemas.py: + - ADD DriftDirection Literal, WapePoint, ModelHealthEntry, ModelHealthResponse + - VALIDATE: uv run python -c "from app.features.ops.schemas import ModelHealthResponse; print('ok')" + +Task A2 — MODIFY app/features/ops/service.py: + - ADD module-scope `classify_drift` (pseudocode above); import `groupby` from itertools + - ADD `OpsService.get_model_health` (full-history query — NOT DISTINCT ON; reuse extract_wape) + - VALIDATE: uv run mypy app/features/ops/ && uv run pyright app/features/ops/ + +Task A3 — MODIFY app/features/ops/routes.py + __init__.py: + - ADD GET /ops/model-health (mirror get_retraining_candidates); export ModelHealthResponse + - VALIDATE: uv run python -c "from app.main import app; assert '/ops/model-health' in {r.path for r in app.routes}; print('wired')" + +Task A4 — MODIFY ops tests (test_schemas.py, test_service.py, conftest.py): + - test_service.py: classify_drift cases — <2 numeric → unknown; degrading; + improving; stable within band; None-gap tolerance; zero-baseline guard + - test_schemas.py: construct ModelHealthEntry/Response; ge=0 rejects negatives + - conftest.py: extend sample_runs (or a new fixture) so one grain has a + degrading 3-point WAPE history + - VALIDATE: uv run pytest -v -m "not integration" app/features/ops/tests/test_service.py app/features/ops/tests/test_schemas.py + +Task A5 — MODIFY app/features/ops/tests/test_routes_integration.py: + - /ops/model-health 200 happy (seeded), entries carry drift_direction; + 200 resilient (empty); ?limit=0 → 422; ?limit=200 → 422; degrading-first sort + - VALIDATE: docker compose up -d && uv run pytest -v -m integration app/features/ops/ + +Task A6 — MODIFY frontend/src/types/api.ts: + - ADD DriftDirection, WapePoint, ModelHealthEntry, ModelHealthResponse (dates as string) + - VALIDATE: cd frontend && pnpm tsc --noEmit + +Task 
A7 — MODIFY frontend/src/hooks/use-ops.ts: + - ADD useModelHealth(limit=20, enabled=true) — queryKey ['ops','model-health',limit]; + no refetchInterval (slow-moving). MIRROR useRetrainingCandidates. + - VALIDATE: cd frontend && pnpm tsc --noEmit + +Task A8 — MODIFY frontend/src/pages/ops.tsx: + - ADD a "Model Health" Card+Table section: grain, drift StatusBadge + (degrading→error, improving→success, stable→info, unknown→default), + latest WAPE, wape_delta, run_count. Optional: a Recharts sparkline of + wape_history. GOTCHA: renders inside AppShell — no raw colors. + - VALIDATE: cd frontend && pnpm tsc --noEmit && pnpm lint + +# ════════ PHASE B — Incident Report Export (frontend only) ════════ +Task B1 — CREATE frontend/src/lib/incident-report.ts + incident-report.test.ts: + - attentionCsvColumns + buildIncidentMarkdown (pure; pseudocode above) + - MIRROR: csv-export.ts + ops-utils.test.ts + - VALIDATE: cd frontend && pnpm test --run src/lib/incident-report.test.ts + +Task B2 — MODIFY frontend/src/pages/ops.tsx: + - ADD an "Export report" control in the page header — a dropdown (or two + buttons): "CSV (attention items)" → downloadCsv(toCsv(...)); "Markdown + report" → download buildIncidentMarkdown(...) as ops-incident-report.md + - MIRROR: demand.tsx export button + - VALIDATE: cd frontend && pnpm tsc --noEmit && pnpm lint + +# ════════ PHASE C — Action Layer (frontend only) ════════ +Task C0 — RESEARCH (no code): read app/features/jobs/service.py + + app/features/forecasting/schemas.py — confirm the exact train-job `params` + keys. Adjust buildRetrainJob accordingly. This de-risks the whole phase. 
+ +Task C1 — CREATE frontend/src/lib/ops-actions.ts + ops-actions.test.ts: + - buildRetrainJob(run, latestSalesDate) (pseudocode above) + - VALIDATE: cd frontend && pnpm test --run src/lib/ops-actions.test.ts + +Task C2 — MODIFY frontend/src/pages/ops.tsx — bulk retrain: + - Retraining Queue rows get a shadcn Checkbox; selection via useState + - "Retrain selected (N)" → AlertDialog confirm → SEQUENTIALLY per candidate: + fetch run (useRun / api), buildRetrainJob, useCreateJob().mutateAsync, toast + - GOTCHA: sequential awaits, not Promise.all (POST /jobs runs synchronously) + - VALIDATE: cd frontend && pnpm tsc --noEmit && pnpm lint + +Task C3 — MODIFY frontend/src/pages/ops.tsx — promote to alias: + - "Promote to alias" action on success-status rows → AlertDialog with an + alias-name input → useCreateAlias().mutateAsync({alias_name, run_id}) → toast + - GOTCHA: only success runs are promotable (registry returns 400 otherwise) + - VALIDATE: cd frontend && pnpm tsc --noEmit && pnpm lint + +Task C4 — VERIFY the <Toaster /> in frontend/src/components/layout/app-shell.tsx + is mounted (it already is). Verification-only — do NOT add a second Toaster. + - VALIDATE: cd frontend && pnpm tsc --noEmit + +# ════════ FINAL ════════ +Task D1 — FULL validation sweep (all gates — see Validation Loop). +Task D2 — Browser dogfood per .claude/rules/ui-design.md (webapp-testing / + agent-browser): Model Health renders with drift badges; export downloads a + valid CSV + Markdown; bulk-retrain creates jobs (verify on /explorer/jobs); + promote creates an alias (verify on /ops summary aliases). +``` + +### Integration Points + +```yaml +DATABASE: + - migration: NONE — Phase A is a read-only query; Phases B/C touch no schema. + - tables read (existing): model_run (Phase A). + +ROUTES (backend): + - add to app/features/ops/routes.py: GET /ops/model-health + - already wired: app/main.py includes ops_router (PRP-24) — no main.py change.
+ +ROUTES (frontend): + - none — no new page, no new route; all work lands on the existing /ops page. + +HOOKS: + - new: useModelHealth in frontend/src/hooks/use-ops.ts + - reuse: useCreateJob (use-jobs.ts), useCreateAlias + useRun (use-runs.ts) + +CONFIG: none — no new settings, no new env var, no new dependency. +``` + +--- + +## Validation Loop + +### Level 1: Syntax & Style + +```bash +uv run ruff check . --fix && uv run ruff format --check . +cd frontend && pnpm lint +# Trap: date.today() → ruff DTZ; a stray `# noqa: BLE001` → RUF100. +``` + +### Level 2: Type Checks + +```bash +uv run mypy app/ && uv run pyright app/ # both --strict +cd frontend && pnpm tsc --noEmit +``` + +### Level 3: Unit Tests + +```bash +uv run pytest -v -m "not integration" app/features/ops/ +cd frontend && pnpm test --run src/lib/incident-report.test.ts src/lib/ops-actions.test.ts +``` + +Pure-function cases that MUST exist (`test_service.py`): +```python +def test_classify_drift_unknown_when_under_two_numeric(): + assert classify_drift([None, 10.0]) == ("unknown", None) + +def test_classify_drift_degrading(): + d, delta = classify_drift([10.0, 10.0, 20.0]) # latest 20 vs baseline 10 + assert d == "degrading" + +def test_classify_drift_improving(): + d, _ = classify_drift([20.0, 20.0, 10.0]) + assert d == "improving" + +def test_classify_drift_stable_within_band(): + d, _ = classify_drift([10.0, 10.5]) # +5% < 10% band + assert d == "stable" + +def test_classify_drift_tolerates_none_gaps(): + assert classify_drift([None, 10.0, None, 12.0])[0] in {"stable", "degrading"} +``` + +### Level 4: Integration Tests + +```bash +docker compose up -d && uv run alembic upgrade head +uv run pytest -v -m integration app/features/ops/ +``` + +`/ops/model-health` → 200; entries cover seeded grains; `drift_direction` present; +empty DB → 200 (never 500); `?limit=0` and `?limit=200` → 422; degrading-first sort. 
+ +### Level 5: Manual Validation + +```bash +uv run uvicorn app.main:app --reload --port 8123 & +curl -s "http://localhost:8123/ops/model-health?limit=5" | head -c 400 +curl -s -o /dev/null -w '%{http_code}\n' "http://localhost:8123/ops/model-health?limit=0" # 422 +# Frontend: seed (make demo), open http://localhost:5173/ops via the +# webapp-testing skill / agent-browser — verify the Model Health section, +# CSV + Markdown export downloads, bulk-retrain → new jobs on /explorer/jobs, +# promote → alias on the summary. Type-check passing ≠ UI works. +``` + +--- + +## Final Validation Checklist + +- [ ] `uv run ruff check . && uv run ruff format --check .` — clean +- [ ] `uv run mypy app/ && uv run pyright app/` — clean (`--strict`) +- [ ] `uv run pytest -v -m "not integration"` — green +- [ ] `docker compose up -d && uv run pytest -v -m integration` — green +- [ ] `cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run` — green +- [ ] `GET /ops/model-health` behaves per Success Criteria (sort, 422, empty-DB) +- [ ] `/ops` shows Model Health with drift badges; export downloads CSV + Markdown; + bulk-retrain creates jobs; promote creates an alias — dogfooded in a browser +- [ ] `ops` backend slice still read-only — no `models.py`, no migration, no + `POST /ops/*` +- [ ] No new dependency +- [ ] PR description flags: the `/ops` page is now an action launcher (calls + existing `POST /jobs` / `POST /registry/aliases`); WebSocket job updates + were assessed and deliberately declined (Decision #3) +- [ ] Commits use `feat(api)` (Phase A) / `feat(ui)` (Phases B, C) and reference + an open issue + +--- + +## Anti-Patterns to Avoid + +- ❌ Don't add a `POST /ops/*` mutating endpoint — Phase C is frontend-only over + existing endpoints; the `ops` slice stays read-only. +- ❌ Don't build the WebSocket — it was assessed and declined (Decision #3). +- ❌ Don't add data-drift / PSI infrastructure — performance drift only (Decision #2). 
+- ❌ Don't use DISTINCT ON for model-health — it needs the full per-grain history. +- ❌ Don't redefine `extract_wape` — reuse the one in `ops/service.py`. +- ❌ Don't `Promise.all()` the bulk retrains — `POST /jobs` runs synchronously; + go sequential. +- ❌ Don't change `score_retraining_candidate` — PRP-24's scoring is locked. +- ❌ Don't `pnpm add` anything — Recharts/Checkbox/AlertDialog/sonner are installed. +- ❌ Don't add `ConfigDict(strict=True)` to the new response models. +- ❌ Don't claim the UI works on a green type-check — dogfood it in a browser. + +## Workflow Notes + +- Open a GitHub issue first; branch `feat/ops-control-center-full` off `dev` + (`.claude/rules/branch-naming.md`) — **only after PR #218 (PRP-24) is merged**. +- The phases are independently shippable. Prefer **one PR per phase** (smaller + reviews) or a single phased PR — either way: `feat(api)` for Phase A, + `feat(ui)` for Phases B and C. +- The PR description MUST state (a) the `/ops` page becomes an action launcher + via existing endpoints, and (b) WebSocket job updates were deliberately + declined — per `.claude/rules/product-vision.md` § "When Ideas Don't Align". + +## Confidence Score + +**8 / 10** for one-pass implementation success. + +Rationale: Phase A is a near-exact mirror of PRP-24's verified `get_retraining_candidates` +pattern (the one structural difference — full history vs. DISTINCT ON — is called +out as a CRITICAL gotcha). Phase B is pure frontend over the existing, CSV-safe +`csv-export.ts`. Both score ~9/10. Phase C is the residual risk: it depends on the +exact `train`-job `params` contract, mitigated by the mandatory Task C0 research +step (read `jobs/service.py` + `forecasting/schemas.py` first) and by reusing the +existing `useCreateJob` / `useCreateAlias` hooks rather than inventing mutation +code. 
The biggest scope risks of the feature brief's "Full Version" — a WebSocket +streaming surface and turning `ops` into a backend mutation slice — are removed by +Decisions #1 and #3, keeping every phase aligned with the single-host, non-streaming +product vision. diff --git a/PRPs/PRP-26-scenario-simulation-what-if-planning.md b/PRPs/PRP-26-scenario-simulation-what-if-planning.md new file mode 100644 index 00000000..c13e28dd --- /dev/null +++ b/PRPs/PRP-26-scenario-simulation-what-if-planning.md @@ -0,0 +1,1047 @@ +name: "PRP-26 — Scenario Simulation / What-If Planning (MVP)" +description: | + Context-rich PRP that promotes the **MVP scope** of + `docs/optional-features/03-scenario-simulation-what-if-planning.md` into code: a + new `app/features/scenarios/` vertical slice that turns ForecastLabAI from + "predict the future" into "plan possible futures". It runs a baseline forecast + from an existing trained model, applies **deterministic, transparent uplift / + drag factors** for future assumptions (price change, promotion, holiday, + inventory, lifecycle), and returns a baseline-vs-scenario comparison. Scenarios + can be saved as named JSON plans. A new `Visualize → What-If Planner` page drives + the slice. Phased so each phase is independently shippable and one-pass + implementable. + +## Purpose + +ForecastLabAI can train, predict, backtest, register and visualise demand — but +every forecast answers exactly one question: *"what happens if nothing changes?"* +There is **no surface that answers "what if we discount this SKU 15% next week?"** +This PRP delivers the **MVP** of the Scenario Simulation feature brief: + +- **Phase A — Stateless Simulation Engine (backend)**: a pure deterministic + adjustment engine (`adjustments.py`) and a stateless `POST /scenarios/simulate` + endpoint that resolves a baseline model, runs a baseline forecast, applies + per-day adjustment factors, and returns a `ScenarioComparison`. No table yet. 
+- **Phase B — Saved Scenario Plans (persistence)**: a new `scenario_plan` table + + Alembic migration, and `POST /scenarios` / `GET /scenarios` / + `GET /scenarios/{id}` / `DELETE /scenarios/{id}` CRUD over saved plans. +- **Phase C — What-If Planner Page (frontend)**: a `/visualize/planner` page — + baseline picker → assumption form → run → baseline-vs-scenario chart + delta + table → save / reload / delete named plans. + +The **"Full Version"** (a future-feature-frame generator, exogenous-regressor +model support, agent-generated scenarios, multi-scenario comparison) is explicitly +**out of scope** — see DECISIONS LOCKED #1. + +> Source plan: `.agents/plans/scenario-simulation-what-if-planning.md` (validated +> against the repo as of 2026-05-19). Feature brief: +> `docs/optional-features/03-scenario-simulation-what-if-planning.md`. + +--- + +## DEPENDS ON — read before starting + +This PRP has **no new dependency** on an unmerged PRP. It builds on already-merged +slices: `forecasting` (PRP-5), `registry` (PRP-7), `jobs` (PRP-8), +`data_platform` (PRP-2), and the frontend dashboard (PRP-11). Sanity-check before +starting: if `app/features/forecasting/service.py` does not define +`ForecastingService.predict` and `app/features/forecasting/persistence.py` does +not expose `load_model_bundle`, stop — a dependency moved and the artifact-resolution +plan needs revisiting. + +--- + +## Goal + +**Feature Goal**: Ship the MVP of Scenario Simulation as a new +`app/features/scenarios/` vertical slice — the first slice since `rag` to ship +both a persisted table **and** a non-trivial compute path — plus a `Visualize → +What-If Planner` page, delivered as three independently-shippable phases. 
+ +**Deliverable**: +- **Backend** — a new `scenarios` slice (`models.py`, `schemas.py`, + `adjustments.py`, `service.py`, `routes.py`, `tests/`), one Alembic migration + creating `scenario_plan`, and five endpoints (`POST /scenarios/simulate`, + `POST /scenarios`, `GET /scenarios`, `GET /scenarios/{scenario_id}`, + `DELETE /scenarios/{scenario_id}`). +- **Frontend** — a `/visualize/planner` page, `use-scenarios.ts` hooks, a pure + `scenario-utils.ts` module (+ vitest), new `Scenario*` TS types, a route + nav + entry. + +**Success Definition**: `docker compose up` → seed → `make demo` (so completed +`predict` jobs + trained models exist) → open `/visualize/planner` → a planner +picks a baseline `predict` job, defines assumptions (e.g. −15% price + a `pct_off` +promotion), runs a simulation, sees a baseline-vs-scenario chart + delta table + +a visible **heuristic disclaimer**, exports the delta CSV, and saves / reloads / +deletes a named plan — with every gate (`ruff`, `mypy --strict`, +`pyright --strict`, `pytest` unit + integration, frontend `tsc`/`lint`/`test`) +green. + +## Why + +- **User value** — business users can quantify the demand + revenue impact of a + decision (a discount, a promotion, a holiday) *before* committing to it, and + save the analysis as a reusable plan. Inventory users get a coverage / stockout + verdict under demand spikes. +- **Demo value** — a demo reviewer currently sees a *forecasting* system; this + feature makes it a *planning* system — a recognisably high-value retail + workflow. +- **Integration** — the data platform already models the relevant drivers (price + history, promotions, inventory snapshots, calendar/holiday flags) and Phase-2 + feature engineering already supports promotion / lifecycle / exogenous features; + none of that was reachable as a planning workflow. This slice surfaces it. + +## What + +### User-visible behavior + +A new `Visualize → What-If Planner` page (`/visualize/planner`) lets a planner: + +1. 
**Pick a baseline** — choose a completed `predict` job (its `run_id` is the + baseline model artifact key) and a horizon (7 / 14 / 30 / 60 / 90 days). +2. **Define assumptions** — all optional: a price `change_pct` over a date window, + a promotion `kind` + window, holiday/event dates, an inventory `on_hand_units` + cap, a lifecycle `stage` override. +3. **Run the simulation** — `POST /scenarios/simulate` returns a baseline series, + a scenario series, per-day + aggregate deltas, a revenue delta, and a coverage + verdict. +4. **Review** — a baseline-vs-scenario two-series chart, KPI tiles + (units/revenue delta, coverage verdict), a per-day delta table with CSV export, + and a **prominent heuristic-disclaimer banner**. +5. **Save / reload / delete** — persist the scenario as a named plan (inputs + + the comparison snapshot), list saved plans, reload one, delete one. + +### Technical requirements + +- Phase A: a new `scenarios` slice — pure `adjustments.py`, Pydantic v2 request + + response models, `ScenarioService.simulate`, an `APIRouter`, RFC 7807 errors, + `mypy --strict` + `pyright --strict` clean. **Stateless** — no table. +- Phase B: a `scenario_plan` ORM model (JSONB inputs + JSONB comparison snapshot), + an Alembic migration (upgrade **and** downgrade), CRUD service methods + routes. +- Phase C: **frontend only** — a pure `scenario-utils.ts` (+ vitest), TanStack + Query hooks, the planner page, routing + nav wiring. No backend changes. +- No new external dependency, no managed-cloud SDK, no WebSocket. Reuses FastAPI, + SQLAlchemy 2.0 async, Pydantic v2, numpy, structlog, Recharts, TanStack Query — + all already present. + +### Success Criteria + +- [ ] `POST /scenarios/simulate` returns a `ScenarioComparison` with `points` + length == `horizon`, baseline + scenario series, per-day + aggregate deltas, + a revenue delta, a `coverage_verdict`, and `method == "heuristic"` plus a + non-empty `disclaimer`. 
+- [ ] An empty `ScenarioAssumptions` yields scenario == baseline, all deltas 0.0. +- [ ] A bogus `run_id` → RFC 7807 problem response (404/400), never a 500. +- [ ] `adjustments.py` helpers are pure, never raise on junk input, and are + unit-tested directly. +- [ ] `test_leakage.py` proves the scenario adjustment touches only horizon + (future) points and never mutates / reads the historical series — treated as + a load-bearing spec (never weakened to make a feature pass). +- [ ] `POST /scenarios` persists a plan; `GET /scenarios` lists; `GET + /scenarios/{id}` returns the embedded comparison snapshot; `DELETE` removes + it; `GET /scenarios` on an empty table → 200 + empty list (never 404). +- [ ] The Alembic migration creates `scenario_plan` and upgrades **and** + downgrades cleanly on a fresh DB. +- [ ] `Visualize → What-If Planner` lets a user pick a baseline `predict` job, + define assumptions, run a simulation, see a baseline-vs-scenario chart + + delta table + a visible heuristic disclaimer, export the delta CSV, and + save / reload / delete a named plan. +- [ ] All gates pass: `ruff`, `mypy --strict`, `pyright --strict`, `pytest` + (unit + integration), frontend `tsc` + `lint` + `test`. +- [ ] No new external dependency; no managed-cloud SDK; no WebSocket; the slice + respects the no-cross-slice-service-import rule (DECISIONS LOCKED #2). +- [ ] README + `docs/_base/{API_CONTRACTS,REPO_MAP_INDEX,DOMAIN_MODEL}.md` updated. + +--- + +## All Needed Context + +### DECISIONS LOCKED (resolved during planning — do NOT re-litigate) + +1. **MVP scope only — the heuristic adjustment is a post-forecast multiplier.** + The baseline models (`naive`, `seasonal_naive`, `moving_average`) forecast from + the historical target series only and **ignore the exogenous `X` argument** + (verified: `# noqa: ARG002` on every `fit`/`predict` in + `forecasting/models.py`). 
The MVP therefore applies assumptions as a + **deterministic post-forecast multiplier** on the baseline forecast — never a + leakage-prone re-training. Every result is explicitly labelled + `method = "heuristic"` with a fixed `disclaimer` string. The "Full Version" + (future-feature-frame generator, exogenous-regressor model support, + agent-generated scenarios, multi-scenario comparison) is **out of scope** — it + needs models that consume future feature frames, which the MVP does not add. + +2. **The scenario service does NOT import a sibling slice's `service.py`.** + `AGENTS.md` § Architecture: "a slice may NOT import from another slice; + cross-cutting code goes through `app/core/` or `app/shared/`." The sanctioned + narrow exception (used by `ops`) is importing a sibling's **ORM `models.py`** + read-only — NOT its `service.py`. Calling `ForecastingService` from `scenarios` + would be a genuine cross-slice *service* import and **violates the rule**. + RESOLUTION: the scenario service imports only the **stable, lower-level + building blocks** — `load_model_bundle` from `forecasting/persistence.py` — and + produces the baseline forecast by calling `bundle.model.predict(horizon)` + directly (the `BaseForecaster` interface), replicating the ~30-line + `ForecastPoint`-construction block from `ForecastingService.predict` + (`forecasting/service.py`, the predict body). Read-only ORM imports of sibling + `models.py` (`data_platform`, `registry`) are allowed. Alternative considered + + rejected: promoting the predict logic to `app/shared/` — larger blast radius, + deferred to the Full Version. **This decision MUST be cited in the PR + description** per `product-vision.md` § "When Ideas Don't Align". + +3. 
**`scenario_plan` stores the comparison SNAPSHOT, not just the inputs.** A + saved plan persists both the raw `ScenarioAssumptions` **and** the full + `ScenarioComparison` as JSONB, so a reloaded plan re-renders without + recomputation (and without needing the original model artifact to still exist). + Persist via `model_dump(mode="json")` so `date`/`datetime` serialise to strings + (JSONB rejects Python `date`). + +4. **There is no `scenarios` commit scope.** The `.claude/rules/commit-format.md` + allow-list has **no `scenarios` scope**. Use `feat(api)` for the backend slice, + `feat(api,db)` for the slice + migration, `feat(ui)` for the frontend, + `test(api)` for backend tests, `docs(docs)` for docs. Do **not** invent a + scope. + +5. **The current Alembic head is `378c112e4b32`, NOT `d6e0f2g3h456`.** The source + plan guessed `d6e0f2g3h456`; the verified head (via `uv run alembic heads`) is + `378c112e4b32_create_app_config_table.py`. Set the new migration's + `down_revision = "378c112e4b32"` — but **re-verify with `uv run alembic heads` + immediately before writing the migration** (a PRP merging first would move it). + +6. **No WebSocket.** Simulation is request/response — `POST /scenarios/simulate` + computes synchronously and returns. No streaming surface — consistent with + `product-vision.md` "not a real-time streaming system". + +### Documentation & References + +```yaml +# ── MUST READ — repo files (read BEFORE implementing) ── + +- file: PRPs/PRP-25-forecastops-control-center-full.md + why: The most recent, highest-quality PRP. Mirror its DECISIONS LOCKED section, + its Known Gotchas table, its phased task list, its Anti-Patterns, and its + Confidence Score rationale. + +- file: PRPs/PRP-22-visualize-demand-planner.md + why: The immediate sibling — it added a new Visualize page (demand.tsx), a new + hook, a new pure util module, new TS types, a route + nav entry. The EXACT + frontend shape this feature reuses. 
Its Resolved Decisions + Known Gotchas + apply. + +- file: PRPs/PRP-9-rag-knowledge-base.md + why: The precedent for a slice that ships a new table + Alembic migration + + service compute path. Read its migration + model + service layering. + +- file: app/features/forecasting/service.py + why: ForecastingService.predict() is the baseline-forecast engine — loads a + .joblib bundle, validates store/product, calls bundle.model.predict(horizon), + builds ForecastPoints. The scenario service REPLICATES the ~30-line predict + body (per DECISIONS LOCKED #2) rather than importing this class. + critical: Path-traversal validation in predict() is load-bearing — mirror it. + +- file: app/features/forecasting/persistence.py + why: load_model_bundle — the SANCTIONED lower-level building block the scenario + service imports (NOT ForecastingService). Read what the bundle carries + (model, metadata with store_id/product_id/train_end_date). + +- file: app/features/forecasting/schemas.py + why: ForecastPoint (date, forecast, lower_bound?, upper_bound?) + PredictResponse + — the baseline series shape. TrainRequest is the request-body strict-mode + pattern (ConfigDict(strict=True) + Field(strict=False) on date fields). + +- file: app/features/forecasting/models.py + why: Confirms the baseline forecasters IGNORE the exogenous X argument + (# noqa: ARG002) — the reason the MVP is a post-forecast multiplier. + +- file: app/features/jobs/service.py + why: _execute_predict shows how a run_id resolves to a model artifact — + {artifacts_dir}/model_{run_id}.joblib, then load_model_bundle to read + store_id/product_id from bundle.metadata. The scenario service resolves the + baseline artifact the SAME way. + critical: A predict/train job's run_id is the ARTIFACT KEY (model_{run_id}.joblib), + NOT a registry model_run.run_id — see Known Gotchas. 
+ +- file: app/features/jobs/schemas.py + why: JobCreate / JobResponse — the scenario page picks a completed predict job + whose result carries run_id, store_id, product_id, forecasts. + +- file: app/features/jobs/models.py + why: Job — JSONB columns for params/result, CheckConstraint, Index, + TimestampMixin. The scenario_plan table mirrors this. + +- file: app/features/registry/models.py + why: ModelRun (JSONB model_config/metrics; data_window_end; store_id/product_id; + run_id 32-char string; RunStatus). Read-only context lookup only. + +- file: app/features/data_platform/models.py + why: SalesDaily (store_id, product_id, date, quantity, unit_price) — used to + estimate a baseline unit price for the revenue-delta calc. PriceHistory / + Promotion (kind in {pct_off,bogo,bundle,markdown}, discount_pct) document + the real driver semantics the heuristic factors approximate. + +- file: app/features/ops/service.py + why: The pure module-scope helper pattern (extract_wape, score_retraining_candidate) + — exactly how adjustments.py helpers should be written (pure, never raise, + unit-tested directly). OpsService class shape mirrors ScenarioService. + +- file: app/features/ops/schemas.py + why: Response-model conventions — ConfigDict(from_attributes=True), every field + a Field(..., description=), counts ge=0, NO strict=True on response models. + +- file: app/features/ops/routes.py + why: APIRouter(prefix=, tags=), Query(default=, ge=, le=, description=), + Depends(get_db), rich docstrings. + +- file: app/features/rag/models.py + why: DocumentSource — the model-with-JSONB + TimestampMixin + String(32) + external-id + GIN-index pattern for the new scenario_plan table. + +- file: alembic/versions/37e16ecef223_create_jobs_table.py + why: The EXACT migration shape for a new JSONB-bearing table — op.create_table, + postgresql.JSONB(astext_type=sa.Text()), GIN index, CheckConstraint, + server_default=sa.text('now()') on timestamps, a real downgrade(). 
+ +- file: app/features/ops/tests/conftest.py + why: Real-Postgres integration fixtures — ASGITransport client, FK-safe scoped + cleanup, TEST-/test- marker on every seeded natural key. The scenario + conftest MUST add delete(ScenarioPlan) to its cleanup. + +- file: app/features/ops/tests/test_service.py + why: The pattern for unit-testing pure helpers — adjustments.py gets the same. + +- file: app/features/ops/tests/test_routes_integration.py + why: The @pytest.mark.integration route-test pattern (happy + empty-DB + 422). + +- file: app/features/forecasting/tests/test_service.py + why: How ForecastingService is tested with a real model bundle on disk — needed + for the /scenarios/simulate integration test (it needs a trained model). + +- file: app/features/featuresets/tests/test_leakage.py + why: The leakage-spec precedent — a test that IS the spec, never weakened to + make a feature pass. The scenario test_leakage.py follows this philosophy. + +- file: app/core/problem_details.py + why: RFC 7807 application/problem+json envelope. The route layer maps + FileNotFoundError/ValueError from the service to structured problems. + +- file: frontend/src/pages/visualize/demand.tsx + why: The closest page — header, loading/error/empty early returns, Card/Table + sections, Select controls, useMemo derivations, a drill-in Card, CSV + export, formatNumber, @/ imports, keyboard-operable rows. The planner page + mirrors its skeleton. + +- file: frontend/src/pages/visualize/forecast.tsx + why: The in-page job-launch pattern — useCreateJob().mutateAsync(...), JobPicker, + a horizon Select, runError state, getErrorMessage. The planner reuses this + to pick a baseline predict job. + +- file: frontend/src/hooks/use-jobs.ts + why: useCreateJob (a useMutation) is the pattern for useSimulateScenario / + useCreateScenario / useDeleteScenario; useJobs({jobType:'predict', + status:'completed'}) lists baseline-candidate jobs. 
+ +- file: frontend/src/hooks/use-ops.ts + why: The query-hook pattern (useQuery, queryKey array, api(path, {params})). + Mirror for useScenario / useScenarios. + +- file: frontend/src/lib/demand-utils.ts + why: The pure-util module pattern — typed, no React, no I/O, @/types/api + imports, fully unit-tested. scenario-utils.ts follows this exactly. + +- file: frontend/src/lib/csv-export.ts + why: toCsv/downloadCsv/CsvColumn — reuse for the delta-table export. + CSV-injection-safe; do NOT re-implement. + +- file: frontend/src/components/charts/time-series-chart.tsx + why: The Recharts wrapper — data, actualKey/predictedKey, showActual/showPredicted, + optional lowerKey/upperKey/showInterval band. The baseline-vs-scenario chart + renders TWO series here (baseline as actualKey, scenario as predictedKey). + critical: Verify the exact prop names in the file before wiring. + +- file: frontend/src/components/common/job-picker.tsx + why: JobPicker — reused verbatim to pick a baseline predict job. Also reuse + ErrorDisplay/EmptyState, LoadingState, StatusBadge from components/common/. + +- file: frontend/src/lib/constants.ts + why: ROUTES.VISUALIZE (FORECAST/BACKTEST/DEMAND) + the Visualize NAV_ITEMS + submenu — add PLANNER here. + +- file: frontend/src/App.tsx + why: The lazy(() => import()) block + the ROUTES.VISUALIZE.* block — + add the planner identically (copy the DEMAND route). + +- file: frontend/src/types/api.ts + why: Job, JobCreate, ForecastPoint, Product already defined — add Scenario* + interfaces here, mirroring the Ops* / InventoryStatus* additions. + +# ── Rules — read before writing any code ── + +- file: .claude/rules/product-vision.md + why: Principle 5 (time-safety), principle 8 (single-host), the "not a streaming + system" guardrail. Answer all 6 Litmus-Test questions in the PR description. 
+ +- file: .claude/rules/security-patterns.md + why: Pydantic v2 at every boundary, SQLAlchemy parameter binding, + pathlib.Path.resolve() for the model-artifact path, the strict-mode + request-body policy. + +- file: .claude/rules/test-requirements.md + why: New module -> test file; new endpoint -> route test (2xx + 1 error path); + new model -> constraint test; new migration -> upgrade/downgrade clean. + +- file: .claude/rules/commit-format.md + why: type(scope): description (#issue). GOTCHA: no `scenarios` scope — see + DECISIONS LOCKED #4. + +- file: .claude/rules/branch-naming.md + why: branch feat/scenario-what-if-planner off dev. + +- file: .claude/rules/ui-design.md +- file: .claude/rules/shadcn-ui.md + why: Build the page via frontend-design + shadcn-ui skills; dogfood in a real + browser via webapp-testing / agent-browser. A green tsc is NOT proof the + UI works. + +# ── External documentation ── + +- url: https://fastapi.tiangolo.com/tutorial/body/ + why: The new slice's request-body endpoints follow this. + +- url: https://fastapi.tiangolo.com/tutorial/bigger-applications/#apirouter + why: Confirms APIRouter(prefix=...) + include_router registration. + +- url: https://docs.pydantic.dev/latest/concepts/strict_mode/ + why: Request bodies use ConfigDict(strict=True) + per-field Field(strict=False) + on JSON-non-native types (date); response models do NOT. This is the repo's + docs/_base/SECURITY.md policy — get it right or every HTTP caller 422s on + date fields. + +- url: https://docs.pydantic.dev/latest/concepts/models/ + why: model_dump(mode="json") for JSONB persistence of date/datetime fields. + +- url: https://recharts.org/en-US/api/LineChart + why: The baseline-vs-scenario two-series chart; TimeSeriesChart already wraps + Recharts — pass two series, do not hand-roll a chart. 
+ +- url: https://tanstack.com/query/latest/docs/framework/react/guides/mutations + why: useSimulateScenario/useCreateScenario/useDeleteScenario are mutations; + useScenarios/useScenario are queries. Mirror use-jobs.ts. + +- url: https://www.nist.gov/itl/ai-risk-management-framework + why: The "Risks" section of the brief — over-trust of heuristic numbers. Drives + the MANDATORY method: "heuristic" label + a disclaimer string on every + ScenarioComparison (a transparency / explainability control). +``` + +> Note: LightGBM / XGBoost / Prophet / scikit-learn `TimeSeriesSplit` (cited in +> the feature brief) are **context for the Full Version only**. The MVP does NOT +> add an exogenous-regressor model — do not pull these in for MVP implementation. + +### Current Codebase tree (relevant slices) + +```bash +app/ +├── main.py # router wiring — add scenarios_router +├── core/ +│ ├── config.py # get_settings() — forecast_model_artifacts_dir (str) +│ ├── database.py # get_db dependency +│ └── problem_details.py # RFC 7807 envelope +└── features/ + ├── data_platform/models.py # SalesDaily, PriceHistory, Promotion, Calendar + ├── forecasting/ + │ ├── persistence.py # load_model_bundle <- IMPORT THIS + │ ├── service.py # ForecastingService.predict <- replicate body + │ ├── models.py # baseline forecasters (ignore exogenous X) + │ └── schemas.py # ForecastPoint, PredictResponse, TrainRequest + ├── jobs/ # Job model, _execute_predict (artifact resolution) + ├── registry/models.py # ModelRun + └── ops/ # the pattern-mirror slice (pure helpers, schemas) +alembic/versions/ +└── 378c112e4b32_create_app_config_table.py # <- current head (VERIFY) +frontend/src/ +├── pages/visualize/{demand,forecast,backtest}.tsx +├── hooks/{use-jobs,use-ops}.ts +├── lib/{demand-utils,csv-export,constants}.ts +├── components/charts/time-series-chart.tsx +├── components/common/{job-picker,error-display,loading-state,status-badge}.tsx +├── types/api.ts +└── App.tsx +``` + +### Desired Codebase tree — 
files to ADD + +```bash +# ── Backend: the new `scenarios` vertical slice ── +app/features/scenarios/__init__.py # slice package + __all__ +app/features/scenarios/models.py # ScenarioPlan ORM model (JSONB) +app/features/scenarios/schemas.py # request + response Pydantic models +app/features/scenarios/adjustments.py # PURE deterministic factor math +app/features/scenarios/service.py # ScenarioService (simulate + CRUD) +app/features/scenarios/routes.py # APIRouter — 5 endpoints +app/features/scenarios/tests/__init__.py +app/features/scenarios/tests/conftest.py # real-Postgres fixtures + cleanup +app/features/scenarios/tests/test_adjustments.py # PURE-function unit tests +app/features/scenarios/tests/test_schemas.py # schema unit tests +app/features/scenarios/tests/test_leakage.py # leakage spec (load-bearing) +app/features/scenarios/tests/test_routes_integration.py # @pytest.mark.integration + +# ── Backend: migration ── +alembic/versions/{revision_id}_create_scenario_plan_table.py # new table + +# ── Frontend ── +frontend/src/hooks/use-scenarios.ts # query + mutation hooks +frontend/src/lib/scenario-utils.ts # PURE chart-merge + delta utils +frontend/src/lib/scenario-utils.test.ts # vitest +frontend/src/pages/visualize/planner.tsx # the /visualize/planner What-If page +``` + +### Files to MODIFY (all additive) + +```bash +app/main.py # +1 import, +1 include_router(scenarios_router) +frontend/src/types/api.ts # +Scenario* interfaces +frontend/src/lib/constants.ts # +ROUTES.VISUALIZE.PLANNER + nav entry +frontend/src/App.tsx # +1 lazy import + 1 Route element +README.md # feature-list mention +docs/_base/API_CONTRACTS.md # +/scenarios/* rows +docs/_base/REPO_MAP_INDEX.md # +scenarios slice + planner.tsx rows +docs/_base/DOMAIN_MODEL.md # +scenario_plan aggregate + ubiquitous-language rows +``` + +### Known Gotchas of our codebase & Library Quirks + +| # | Gotcha | Mitigation | +|---|--------|-----------| +| 1 | A predict/train job's `run_id` is the **artifact key** (`model_{run_id}.joblib`),
NOT a registry `model_run.run_id`. Passing the wrong one yields a missing artifact. | Resolve the artifact exactly as `jobs/service.py:_execute_predict` does. The page picks a *completed predict job* whose `result.run_id` is the artifact key. A bogus `run_id` must surface as a 404/400 problem, never a 500. | +| 2 | A slice may **NOT** import another slice's `service.py`. Importing `ForecastingService` from `scenarios` violates `AGENTS.md` § Architecture. | Import only `load_model_bundle` from `forecasting/persistence.py`; produce the baseline by calling `bundle.model.predict(horizon)` and replicating the ~30-line `ForecastPoint`-construction block. Cite in the PR (DECISIONS LOCKED #2). | +| 3 | FastAPI calls `TypeAdapter.validate_python` on request bodies. With `ConfigDict(strict=True)`, Pydantic refuses to coerce ISO-string dates → every HTTP caller 422s on `date` fields. | Request bodies: `ConfigDict(strict=True)` + `Field(strict=False, ...)` on **every** `date`/`datetime` field. Response models: `from_attributes=True`, **NO** `strict=True`. (`docs/_base/SECURITY.md`.) | +| 4 | JSONB rejects Python `date`/`datetime` objects. | Persist `assumptions` / `comparison` via `model_dump(mode="json")` so dates serialise to ISO strings. | +| 5 | SQLAlchemy reserves the attribute name `metadata` on declarative models (`rag` works around it with `metadata_`/`"metadata"`). | Name the `scenario_plan` JSONB columns `assumptions` and `comparison` — never `metadata`. | +| 6 | The current Alembic head is **`378c112e4b32`**, not the `d6e0f2g3h456` the source plan guessed. | Set `down_revision = "378c112e4b32"`, but re-verify with `uv run alembic heads` immediately before writing the migration. Migrations are forward-only after merge. | +| 7 | There is **no `scenarios` commit scope** in the allow-list. | Use `feat(api)` / `feat(api,db)` / `feat(ui)` / `test(api)` / `docs(docs)` (DECISIONS LOCKED #4). 
| +| 8 | The baseline forecasters ignore exogenous regressors — a "what-if" cannot be done by re-prediction. | The MVP applies a **post-forecast deterministic multiplier**; label every result `method = "heuristic"` + a `disclaimer`. | +| 9 | A green `pnpm tsc` is NOT proof the UI works. | Dogfood the running page in a real browser via `webapp-testing` / `agent-browser` (Task C7) — mandatory per `.claude/rules/ui-design.md`. | +| 10 | `units_delta_pct` divide-by-zero when baseline demand is 0. | Guard: return `0.0` when `baseline_total_units == 0`. | +| 11 | An assumption window can fall entirely **before** the forecast start. | The adjustment touches **only** horizon (future) days; out-of-window days contribute factor `1.0`. The leakage test asserts this. | +| 12 | The repo uses **CRLF line endings** on `.py` files (no `.gitattributes`); scripted text-mode writes can silently flip them to LF. | Edit `app/main.py` minimally; preserve existing line endings. | + +--- + +## Implementation Blueprint + +### Data models and structure + +**Backend — `adjustments.py` (PURE, no DB, no I/O, never raises):** + +```python +# Module constants — documented deterministic heuristic factors. +# Final values are a planning DECISION — lock them before coding (see NOTES). +PRICE_ELASTICITY = -1.2 # demand_factor = (1 + change_pct) ** PRICE_ELASTICITY +PROMOTION_UPLIFT_BY_KIND = {"pct_off": 1.25, "bogo": 1.40, "bundle": 1.15, "markdown": 1.30} +HOLIDAY_UPLIFT = 1.30 +LIFECYCLE_FACTOR = {"launch": 1.2, "growth": 1.1, "maturity": 1.0, "decline": 0.85} +FACTOR_BAND = (0.1, 5.0) # clamp band — no negative / explosive forecast + +def clamp(value, lo, hi) -> float: ... +def price_factor(price_change_pct: float) -> float: ... # constant-elasticity +def promotion_factor(kind: str, active: bool) -> float: ... # 1.0 if not active / unknown kind +def holiday_factor(is_holiday: bool) -> float: ... +def lifecycle_factor(stage: str | None) -> float: ... 
# 1.0 for None / unknown +def combined_daily_factor(*, day_index, horizon, assumptions) -> float: ... # multiply applicable, clamp +def apply_adjustment(baseline: list[float], factors: list[float]) -> list[float]: + # element-wise multiply; len asserted equal; every output max(0.0, ...) +``` + +**Backend — `schemas.py` (Pydantic v2):** + +```python +# Request models — ConfigDict(strict=True) + Field(strict=False) on every date: +class PriceAssumption: change_pct: float (ge=-0.9, le=5.0); start_date/end_date: date +class PromotionAssumption: kind: Literal["pct_off","bogo","bundle","markdown"]; start_date/end_date: date +class HolidayAssumption: dates: list[date] +class InventoryAssumption: on_hand_units: int (ge=0) # caps coverage, not demand +class LifecycleAssumption: stage: Literal["launch","growth","maturity","decline"] +class ScenarioAssumptions: price/promotion/holiday/inventory/lifecycle: ... | None = None +class SimulateScenarioRequest: run_id: str; horizon: int (ge=1, le=90); assumptions; name: str | None +class CreateScenarioRequest: name: str; run_id: str; horizon: int; assumptions: ScenarioAssumptions + +# Response models — ConfigDict(from_attributes=True), NO strict=True, Field(..., description=): +class ScenarioPoint: date; baseline; scenario; delta; applied_factor: float +class ScenarioComparison: store_id; product_id; model_type; horizon; points: list[ScenarioPoint]; + baseline_total_units; scenario_total_units; units_delta; units_delta_pct; + baseline_revenue; scenario_revenue; revenue_delta; unit_price_used; + coverage_verdict: Literal["covered","at_risk","stockout","unknown"]; + method: Literal["heuristic"]; disclaimer: str; generated_at: datetime +class ScenarioPlanResponse: scenario_id; name; store_id; product_id; run_id; horizon; method; + created_at; comparison: ScenarioComparison; assumptions: ScenarioAssumptions +class ScenarioListItem: scenario_id; name; store_id; product_id; units_delta; revenue_delta; created_at +class 
ScenarioListResponse: scenarios: list[ScenarioListItem]; total: int (ge=0) +``` + +**Backend — `models.py` (`ScenarioPlan(TimestampMixin, Base)`):** + +```python +id: int (PK); scenario_id: str String(32) unique index; name: str String(200) +store_id: int (index); product_id: int (index); run_id: str String(32) (index) +horizon: int +assumptions: dict[str, Any] -> JSONB # raw ScenarioAssumptions dump +comparison: dict[str, Any] -> JSONB # full ScenarioComparison snapshot +method: str String(20) -> CheckConstraint("method IN ('heuristic')") +__table_args__: GIN index on assumptions + comparison; composite (store_id, product_id) +``` + +### list of tasks to be completed (dependency-ordered) + +The work is **three independently-shippable phases**. Prefer one PR per phase (or +one phased PR). **Phase A must merge before Phase C can be dogfooded** (the page +needs the endpoint). + +```yaml +Task 0 — SETUP: tracking issue + branch: + - Open ONE GitHub issue "Scenario Simulation / What-If Planning (MVP)"; confirm OPEN. + - git fetch origin && git switch -c feat/scenario-what-if-planner origin/dev + - GOTCHA: no `scenarios` commit scope — use feat(api)/feat(api,db)/feat(ui)/docs(docs). + - VALIDATE: gh issue view {issue-number} --json state -> OPEN + +# ════════ PHASE A — Stateless Simulation Engine (backend) ════════ + +Task A1 — CREATE app/features/scenarios/__init__.py + tests/__init__.py: + - Docstring + empty __all__ (extend as schemas land); empty tests/__init__.py. + - PATTERN: app/features/ops/__init__.py + - VALIDATE: uv run python -c "import app.features.scenarios" + +Task A2 — CREATE app/features/scenarios/adjustments.py: + - The PURE deterministic adjustment engine (see Data models above). stdlib only, + `from __future__ import annotations`, no numpy. Every helper tolerates junk + input (negative pct, unknown kind, None stage) and returns a sane factor — + NEVER raises. + - PATTERN: app/features/ops/service.py pure module-scope helpers.
+ - VALIDATE: uv run python -c "from app.features.scenarios.adjustments import combined_daily_factor; print('ok')" + +Task A3 — CREATE app/features/scenarios/schemas.py (Phase A subset): + - The simulate request models + ScenarioComparison + ScenarioPoint (see above). + - PATTERN: forecasting/schemas.py:TrainRequest (request, strict); ops/schemas.py + (response, from_attributes). + - GOTCHA: strict=True ONLY on request bodies; Field(strict=False) on every date. + - VALIDATE: uv run python -c "from app.features.scenarios.schemas import SimulateScenarioRequest, ScenarioComparison; print('ok')" + +Task A4 — CREATE app/features/scenarios/service.py (Phase A subset): + - ScenarioService.simulate(db, request) -> ScenarioComparison: + 1. Resolve artifact: artifacts_dir = Path(settings.forecast_model_artifacts_dir).resolve(); + model_path = (artifacts_dir / f"model_{run_id}.joblib").resolve() + (mirror jobs/service.py:_execute_predict — the setting is a str, wrap in Path). + Then mirror the LOAD-BEARING path-traversal guard from + forecasting/service.py:218-248 — reject a non-`.joblib` suffix and any path + that escapes artifacts_dir (`resolved_path.relative_to(artifacts_dir)`) with + ValueError. FileNotFoundError if the validated path is absent. + 2. load_model_bundle -> read store_id/product_id from bundle.metadata. + 3. Produce the baseline series by calling bundle.model.predict(horizon) and + replicating the ForecastPoint-construction block from + ForecastingService.predict (DECISIONS LOCKED #2 — do NOT import the + sibling service). + 4. Estimate unit_price_used: most-recent non-null SalesDaily.unit_price for + (store, product); fall back to a documented default + log a warning. + 5. Per horizon day: applied_factor = adjustments.combined_daily_factor(...); + scenario = max(0.0, baseline * applied_factor). + 6. Aggregate totals, units_delta, units_delta_pct (guard /0), revenue, deltas. + 7.
coverage_verdict from the inventory assumption (covered / at_risk / stockout + / unknown). + 8. Return ScenarioComparison(method="heuristic", disclaimer=HEURISTIC_DISCLAIMER). + logger.info("scenarios.simulated", ...). + - PATTERN: ops/service.py (class shape, logging); jobs/service.py:_execute_predict + (artifact resolution). + - VALIDATE: uv run mypy app/features/scenarios/ && uv run pyright app/features/scenarios/ + +Task A5 — CREATE app/features/scenarios/routes.py (Phase A subset): + - router = APIRouter(prefix="/scenarios", tags=["scenarios"]); + POST /scenarios/simulate -> response_model=ScenarioComparison, rich docstring. + Map FileNotFoundError/ValueError -> RFC 7807 problem (read forecasting/routes.py + + app/core/problem_details.py first). + - PATTERN: ops/routes.py; forecasting/routes.py. + - VALIDATE: uv run ruff check app/features/scenarios/ && uv run mypy app/ + +Task A6 — UPDATE app/main.py: + - +1 import (alphabetical) + app.include_router(scenarios_router). + - GOTCHA: preserve line endings; edit minimally (Gotcha #12). + - VALIDATE: uv run python -c "from app.main import app; assert '/scenarios/simulate' in {r.path for r in app.routes}; print('wired')" + +Task A7 — CREATE tests/test_adjustments.py + test_schemas.py: + - test_adjustments.py: every pure helper — factor math, clamp bounds, + kind/stage fallthrough, junk-input tolerance, apply_adjustment element-wise + + non-negative. + - test_schemas.py: SimulateScenarioRequest from ISO-string dates via + model_validate({...}) (the FastAPI validate_python path); change_pct bounds. + - PATTERN: ops/tests/test_service.py + test_schemas.py.
+ - VALIDATE: uv run pytest -v -m "not integration" app/features/scenarios/tests/test_adjustments.py app/features/scenarios/tests/test_schemas.py + +Task A8 — CREATE tests/test_leakage.py (LOAD-BEARING): + - Assert the adjustment touches ONLY horizon points: apply_adjustment returns a + new list and leaves the input baseline unchanged; out-of-window days + contribute factor 1.0; len(points) == horizon; an assumption window before + the forecast start contributes no factor. + - PATTERN: app/features/featuresets/tests/test_leakage.py. + - GOTCHA: never weaken this test to make a feature pass (AGENTS.md § Safety). + - VALIDATE: uv run pytest -v -m "not integration" app/features/scenarios/tests/test_leakage.py + +# ════════ PHASE B — Saved Scenario Plans (persistence) ════════ + +Task B1 — CREATE app/features/scenarios/models.py: + - ScenarioPlan(TimestampMixin, Base) — see Data models above. + - PATTERN: jobs/models.py + rag/models.py:DocumentSource. + - GOTCHA: do NOT name a column `metadata` (Gotcha #5). + - VALIDATE: uv run python -c "from app.features.scenarios.models import ScenarioPlan; print(ScenarioPlan.__tablename__)" + +Task B2 — CREATE the Alembic migration: + - uv run alembic revision -m "create scenario plan table", then hand-write + upgrade() (op.create_table with all columns, postgresql.JSONB(astext_type= + sa.Text()), GIN indexes, the CheckConstraint, created_at/updated_at with + server_default=sa.text('now()')) and a real downgrade(). + - PATTERN: alembic/versions/37e16ecef223_create_jobs_table.py. + - GOTCHA: confirm head with `uv run alembic heads` and set down_revision to it + (currently 378c112e4b32 — VERIFY, do not assume; Gotcha #6). + - VALIDATE: docker compose up -d && uv run alembic upgrade head && uv run alembic downgrade -1 && uv run alembic upgrade head + +Task B3 — EXTEND app/features/scenarios/schemas.py: + - Add CreateScenarioRequest (request), ScenarioPlanResponse, ScenarioListItem, + ScenarioListResponse (responses — from_attributes). 
+ - PATTERN: ops/schemas.py; jobs/schemas.py:JobListResponse. + - VALIDATE: uv run mypy app/features/scenarios/schemas.py + +Task B4 — EXTEND app/features/scenarios/service.py: + - create_plan (runs simulate, persists ScenarioPlan, scenario_id=uuid4().hex), + list_plans, get_plan, delete_plan. + - GOTCHA: persist comparison + assumptions via model_dump(mode="json") + (Gotcha #4). + - PATTERN: jobs/service.py (create/list/get); registry/service.py. + - VALIDATE: uv run mypy app/ && uv run pyright app/ + +Task B5 — EXTEND app/features/scenarios/routes.py: + - POST /scenarios (201); GET /scenarios (limit/offset Query params, bounded); + GET /scenarios/{scenario_id} (404 problem when missing); + DELETE /scenarios/{scenario_id} (204, 404 when missing). + - PATTERN: registry/routes.py (alias CRUD with 404 mapping). + - VALIDATE: uv run python -c "from app.main import app; paths={r.path for r in app.routes}; assert {'/scenarios','/scenarios/{scenario_id}'} <= paths; print('wired')" + +Task B6 — CREATE tests/conftest.py + test_routes_integration.py: + - conftest.py: real-Postgres fixtures (ASGITransport client; scoped cleanup that + INCLUDES delete(ScenarioPlan) + FK-safe deletes of seeded TEST-/test- rows; a + trained_model fixture that puts a real bundle on disk for simulate). + - test_routes_integration.py (@pytest.mark.integration): simulate happy path + (200, points length == horizon, method == "heuristic"); bogus run_id -> RFC + 7807 problem (not 500); full CRUD round-trip; GET /scenarios on empty table -> + 200 + []. Plus a constraint test for the method CheckConstraint. + - PATTERN: ops/tests/conftest.py + test_routes_integration.py; + forecasting/tests/conftest.py. + - GOTCHA: never mock the DB; integration tests need docker compose up + alembic + upgrade head. 
+ - VALIDATE: docker compose up -d && uv run alembic upgrade head && uv run pytest -v -m integration app/features/scenarios/ + +# ════════ PHASE C — What-If Planner Page (frontend) ════════ + +Task C1 — UPDATE frontend/src/types/api.ts: + - Add Scenario* TS interfaces (dates as string) mirroring the backend schemas. + - PATTERN: the Ops* / InventoryStatus* blocks already in types/api.ts. + - VALIDATE: cd frontend && pnpm tsc --noEmit + +Task C2 — CREATE frontend/src/hooks/use-scenarios.ts: + - useSimulateScenario (mutation), useCreateScenario (mutation, invalidates list), + useScenarios (query), useScenario(scenarioId) (query), useDeleteScenario + (mutation). + - PATTERN: use-jobs.ts (useCreateJob mutation + useJobs query); use-ops.ts. + - VALIDATE: cd frontend && pnpm tsc --noEmit + +Task C3 — CREATE frontend/src/lib/scenario-utils.ts + scenario-utils.test.ts: + - PURE utils: mergeComparisonSeries (ScenarioPoint[] -> chart rows), + formatDelta (signed), deltaCsvColumns (CsvColumn[]), + summariseAssumptions (human-readable bullets). + - PATTERN: demand-utils.ts. + - VALIDATE: cd frontend && pnpm test --run src/lib/scenario-utils.test.ts + +Task C4 — UPDATE frontend/src/lib/constants.ts + frontend/src/App.tsx: + - ROUTES.VISUALIZE.PLANNER = '/visualize/planner'; a Visualize NAV_ITEMS entry; + a lazy import + a Route element in App.tsx (copy the DEMAND route). + - GOTCHA: pnpm tsc fails until Task C5 creates the page — re-run after C5. + - VALIDATE: (after C5) cd frontend && pnpm tsc --noEmit + +Task C5 — CREATE frontend/src/pages/visualize/planner.tsx: + - Build via frontend-design + shadcn-ui skills.
Header Card with a prominent + heuristic-disclaimer banner; baseline picker (JobPicker jobType="predict" + + horizon Select); assumptions form (price slider, promotion kind + window, + holiday dates, inventory units, lifecycle stage — all optional); a "Run + simulation" Button; results (TimeSeriesChart two-series, KPI tiles, per-day + delta Table + Export CSV, "Save as plan"); a saved-plans Card (list, reload, + delete). Standard LoadingState / ErrorDisplay / EmptyState early returns. + - PATTERN: demand.tsx (skeleton, states, drill-in, CSV); forecast.tsx (in-page + job launch). + - GOTCHA: renders inside AppShell; shadcn semantic tokens only; a green tsc is + NOT proof the UI works (Gotcha #9). + - VALIDATE: cd frontend && pnpm tsc --noEmit && pnpm lint + +Task C6 — UPDATE docs: + - README.md (feature list); docs/_base/API_CONTRACTS.md (5 /scenarios/* rows); + docs/_base/REPO_MAP_INDEX.md (scenarios slice + planner.tsx); + docs/_base/DOMAIN_MODEL.md (scenario_plan aggregate + ubiquitous-language rows). + - VALIDATE: git diff --stat docs/ README.md + +Task C7 — Dogfood the running UI (MANDATORY per .claude/rules/ui-design.md): + - docker compose up -d && alembic upgrade head && seed_random --full-new, then + `make demo` (so completed predict jobs + trained models exist), start uvicorn + + vite, exercise via webapp-testing / agent-browser: pick a baseline job, + define a -15% price + a pct_off promotion, run, confirm the two-series chart + + non-zero deltas + the disclaimer banner, export the delta CSV, save the plan, + reload it, delete it. Capture screenshots. + - VALIDATE: screenshots captured; all 8 manual-check scenarios pass. 
+ +Task C8 — Commit + PR: + - Commits (each (#issue), no AI co-author trailer): + feat(api): add scenario simulation engine and simulate endpoint (#N) + test(api): cover scenario adjustments, schemas, and leakage spec (#N) + feat(api,db): add scenario_plan table and CRUD endpoints (#N) + test(api): cover scenario plan persistence and CRUD (#N) + feat(ui): add scenario data layer — types, hooks, scenario-utils (#N) + feat(ui): add Visualize What-If Planner page (#N) + docs(docs): document scenario simulation slice and planner page (#N) + - GOTCHA: the PR description MUST flag (a) results are heuristic, deliberately + labelled, not model-causal — the Full-Version exogenous-model path is out of + scope; (b) the scenario service deliberately does NOT import sibling + ForecastingService (DECISIONS LOCKED #2). Answer the 6 product-vision Litmus + questions. + - VALIDATE: open PR into dev; CI green; merge. +``` + +### Per-task pseudocode (critical details only) + +```python +# Task A4 — ScenarioService.simulate (the heart of Phase A) +async def simulate(self, db: AsyncSession, request: SimulateScenarioRequest) -> ScenarioComparison: + settings = get_settings() + # GOTCHA #1/#2: resolve the ARTIFACT, mirror jobs/service.py:_execute_predict. + # forecast_model_artifacts_dir is a str — wrap in Path (NOT settings.artifacts_dir). 
+ artifacts_dir = Path(settings.forecast_model_artifacts_dir).resolve() + model_path = (artifacts_dir / f"model_{request.run_id}.joblib").resolve() + # LOAD-BEARING path-traversal guard — mirror forecasting/service.py:218-248 + if model_path.suffix != ".joblib": + raise ValueError(f"Invalid model path for run_id={request.run_id}") + try: + model_path.relative_to(artifacts_dir) # rejects ../ escape + except ValueError: + raise ValueError(f"Invalid model path for run_id={request.run_id}") from None + if not model_path.exists(): + raise FileNotFoundError(f"No model artifact for run_id={request.run_id}") + bundle = load_model_bundle(model_path) # forecasting/persistence.py + # bundle.metadata is dict[str, object] — int(str(...)) keeps mypy --strict happy + store_id = int(str(bundle.metadata["store_id"])) + product_id = int(str(bundle.metadata["product_id"])) + # DECISIONS LOCKED #2: replicate the ForecastingService.predict body — do NOT import it + raw = bundle.model.predict(request.horizon) # BaseForecaster interface + # train_end_date is stored as an ISO STRING — parse it; fall back to today if absent + train_end_raw = bundle.metadata.get("train_end_date") + train_end_date = (date.fromisoformat(train_end_raw) + if isinstance(train_end_raw, str) + else datetime.now(UTC).date()) + start = train_end_date + timedelta(days=1) + baseline_pts = [ForecastPoint(date=start + timedelta(days=i), forecast=float(v)) + for i, v in enumerate(raw)] + # estimate unit price (revenue delta) + unit_price = await self._latest_unit_price(db, store_id, product_id) # default + warn if none + # apply per-day deterministic factors — adjustments.py is PURE + factors = [adjustments.combined_daily_factor(day_index=i, horizon=request.horizon, + assumptions=request.assumptions) for i in range(request.horizon)] + baseline = [p.forecast for p in baseline_pts] + scenario = adjustments.apply_adjustment(baseline, factors) # element-wise, max(0.0,...) 
+ # aggregate — guard divide-by-zero (Gotcha #10) + ... + return ScenarioComparison(method="heuristic", disclaimer=HEURISTIC_DISCLAIMER, ...) +``` + +### Integration Points + +```yaml +DATABASE: + - migration: "create scenario_plan table (id, scenario_id, name, store_id, + product_id, run_id, horizon, assumptions JSONB, comparison JSONB, + method, created_at, updated_at)" + - index: "GIN on assumptions + comparison; composite (store_id, product_id); + unique on scenario_id" + - constraint: "CheckConstraint method IN ('heuristic')" + - down_revision: "378c112e4b32 (VERIFY with `uv run alembic heads`)" + +ROUTES: + - add to: app/main.py + - pattern: "from app.features.scenarios.routes import router as scenarios_router + ... app.include_router(scenarios_router)" + +FRONTEND ROUTING: + - add to: frontend/src/lib/constants.ts + - pattern: "ROUTES.VISUALIZE.PLANNER = '/visualize/planner' + a Visualize + NAV_ITEMS entry" + - add to: frontend/src/App.tsx + - pattern: "lazy(() => import('@/pages/visualize/planner')) + a Route element" + +CONFIG: + - no new config — reuses settings.forecast_model_artifacts_dir (a str; + wrap with Path(...) before use, app/core/config.py) +``` + +--- + +## Validation Loop + +### Level 1: Syntax & Style + +```bash +uv run ruff check . --fix && uv run ruff format --check . +cd frontend && pnpm lint +# Traps: date.today() / naive datetime -> ruff DTZ (use datetime.now(UTC)); +# os.path -> ruff PTH (use pathlib.Path); a stray # noqa -> RUF100.
+``` + +### Level 2: Type Checks + +```bash +uv run mypy app/ && uv run pyright app/ # both --strict, both gate merge +cd frontend && pnpm tsc --noEmit +``` + +### Level 3: Unit Tests + +```bash +uv run pytest -v -m "not integration" app/features/scenarios/ +cd frontend && pnpm test --run src/lib/scenario-utils.test.ts +``` + +### Level 4: Integration Tests + +```bash +docker compose up -d && uv run alembic upgrade head +uv run pytest -v -m integration app/features/scenarios/ +# Migration up/down check: +uv run alembic downgrade -1 && uv run alembic upgrade head +``` + +### Level 5: Manual Validation (dogfood — REQUIRED) + +```bash +docker compose up -d && uv run alembic upgrade head +uv run python scripts/seed_random.py --full-new --seed 42 --confirm +make demo # populates predict jobs + models +uv run uvicorn app.main:app --port 8123 & +until curl -fs http://127.0.0.1:8123/health; do sleep 2; done +# stateless simulate: +curl -s -X POST http://localhost:8123/scenarios/simulate \ + -H 'content-type: application/json' \ + -d '{"run_id":"REPLACE_WITH_RUN_ID","horizon":14, + "assumptions":{"price":{"change_pct":-0.15, + "start_date":"2026-06-01","end_date":"2026-06-14"}}}' | head -c 600 +curl -s -o /dev/null -w '%{http_code}\n' -X POST \ + http://localhost:8123/scenarios/simulate -H 'content-type: application/json' \ + -d '{"run_id":"does-not-exist","horizon":14,"assumptions":{}}' # expect 404, not 500 +# Frontend: cd frontend && ./node_modules/.bin/vite --host 0.0.0.0 +# -> open http://localhost:5173/visualize/planner via webapp-testing/agent-browser: +# pick a baseline job, set a price + promotion assumption, run, verify the +# two-series chart + non-zero deltas + the heuristic disclaimer, export the +# delta CSV, save the plan, reload it from the saved list, delete it.
+``` + +### Level 6: Additional Validation (optional) + +```bash +# Confirm Recharts / TanStack Query usage against current docs via the context7 MCP +# if the TimeSeriesChart two-series wiring or a mutation pattern is uncertain. +``` + +--- + +## Final Validation Checklist + +- [ ] `uv run ruff check . && uv run ruff format --check .` — clean +- [ ] `uv run mypy app/ && uv run pyright app/` — clean (`--strict`) +- [ ] `uv run pytest -v -m "not integration"` — green (incl. the leakage spec) +- [ ] `docker compose up -d && uv run pytest -v -m integration` — green +- [ ] `uv run alembic upgrade head && uv run alembic downgrade -1 && uv run alembic upgrade head` — clean +- [ ] `cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run` — green +- [ ] `POST /scenarios/simulate` behaves per Success Criteria (points length == + horizon, method == "heuristic", bogus run_id -> RFC 7807, not 500) +- [ ] CRUD round-trip works; `GET /scenarios` on an empty table -> 200 + [] +- [ ] `/visualize/planner` runs a simulation, shows a two-series chart + delta + table + heuristic disclaimer, exports CSV, saves/reloads/deletes a plan — + dogfooded in a browser (screenshots captured) +- [ ] No new external dependency; no managed-cloud SDK; no WebSocket +- [ ] `scenario_plan` table created via migration; columns named `assumptions` / + `comparison` (never `metadata`) +- [ ] README + `docs/_base/{API_CONTRACTS,REPO_MAP_INDEX,DOMAIN_MODEL}.md` updated +- [ ] Branch `feat/scenario-what-if-planner`; every commit references the tracking + issue; commit scopes are `api`/`api,db`/`ui`/`docs` (no `scenarios` scope); + no AI co-author trailer +- [ ] PR description flags: (a) heuristic, not model-causal — Full Version out of + scope; (b) the scenario service deliberately does NOT import sibling + `ForecastingService` (DECISIONS LOCKED #2); and answers the 6 Litmus-Test + questions + +--- + +## Anti-Patterns to Avoid + +- ❌ Don't import `ForecastingService` (or any sibling slice's
`service.py`) into + `scenarios` — import only `load_model_bundle` from `forecasting/persistence.py` + and replicate the predict body (DECISIONS LOCKED #2). +- ❌ Don't add an exogenous-regressor model or a future-feature-frame generator — + that is the Full Version, out of scope (DECISIONS LOCKED #1). +- ❌ Don't re-train a model to produce the scenario — the MVP applies a + post-forecast deterministic multiplier. +- ❌ Don't drop the `method: "heuristic"` label or the `disclaimer` — they are the + NIST-AI-RMF transparency control against over-trust. +- ❌ Don't put `ConfigDict(strict=True)` on response models; don't omit + `Field(strict=False)` on `date` fields of request bodies (Gotcha #3). +- ❌ Don't name a `scenario_plan` column `metadata` — SQLAlchemy reserves it + (Gotcha #5). +- ❌ Don't persist Python `date`/`datetime` into JSONB — use + `model_dump(mode="json")` (Gotcha #4). +- ❌ Don't guess the migration `down_revision` — verify with `uv run alembic heads` + (Gotcha #6). +- ❌ Don't weaken `test_leakage.py` to make a feature pass — it is the spec. +- ❌ Don't `raise HTTPException(500, "raw string")` — use the RFC 7807 envelope. +- ❌ Don't add a WebSocket — simulation is request/response (DECISIONS LOCKED #6). +- ❌ Don't `pnpm add` anything — Recharts / TanStack Query / shadcn primitives are + installed. +- ❌ Don't hand-roll a chart — pass two series to the existing `TimeSeriesChart`. +- ❌ Don't claim the UI works on a green type-check — dogfood it in a browser. +- ❌ Don't invent a `scenarios` commit scope — use `api`/`api,db`/`ui`/`docs`. + +## NOTES — open questions / planning decisions to lock before coding + +- **Heuristic factor values are not yet final.** `PRICE_ELASTICITY = -1.2`, + `PROMOTION_UPLIFT_BY_KIND`, `HOLIDAY_UPLIFT = 1.30`, `LIFECYCLE_FACTOR`, and the + `FACTOR_BAND` clamp `[0.1, 5.0]` are *suggested starting values*. They are a + planning decision — confirm them (or adjust) before coding `adjustments.py`. 
+ They are deliberately conservative and documented as constants so a reviewer can + see and tune them. The tests assert *direction and bounds* (a price cut → uplift + > 1, a clamp keeps the factor in band), not exact magnitudes — so reasonable + re-tuning does not break tests. +- **`coverage_verdict` band**: `at_risk` is suggested as "scenario total within + 10% of `on_hand_units`". Confirm the band before coding. +- **`unit_price_used` fallback**: when no `SalesDaily` row exists for the + `(store, product)`, the service falls back to a documented default (suggested + `1.0`) and logs a warning. Confirm the default. +- **Lifecycle stage source**: the MVP takes `lifecycle.stage` as a direct user + override on the assumption form — it does not derive the current stage from + `product.launch_date`. Deriving it is a Full-Version concern. + +## Confidence Score + +**9 / 10** for one-pass implementation success. + +Rationale: the source plan (`.agents/plans/scenario-simulation-what-if-planning.md`) +is unusually thorough and was validated against the repo as of 2026-05-19 — every +file path, class name, and pattern reference here was cross-checked. The two +highest-risk areas are both de-risked: (1) the cross-slice-import constraint is +resolved by a locked decision (import `load_model_bundle`, replicate the predict +body) rather than left for the implementer to discover; (2) the artifact-resolution +gotcha (`run_id` is the artifact key, not a registry id) is called out explicitly +with the exact reference (`jobs/service.py:_execute_predict`). `adjustments.py` is +pure, dependency-free, and trivially unit-testable. Phase C is a near-exact mirror +of PRP-22's `demand.tsx` shape over a deterministic backend. The residual 1-point +risk is the heuristic factor *values* (NOTES) — a planning decision that does not +affect the structure, and the tests assert direction/bounds rather than exact +magnitudes, so re-tuning is safe. 
The biggest scope risks of the feature brief — +an exogenous-regressor model and a WebSocket — are removed by DECISIONS LOCKED #1 +and #6, keeping every phase aligned with the single-host, non-streaming, time-safe +product vision. diff --git a/PRPs/PRP-27-scenario-simulation-full-version.md b/PRPs/PRP-27-scenario-simulation-full-version.md new file mode 100644 index 00000000..9e333ba9 --- /dev/null +++ b/PRPs/PRP-27-scenario-simulation-full-version.md @@ -0,0 +1,1465 @@ +name: "PRP-27 — Scenario Simulation / What-If Planning (Full Version)" +description: | + Context-rich PRP that promotes the **"Full Version"** section of + `docs/optional-features/03-scenario-simulation-what-if-planning.md` (lines + 82-89) into code. It is a strict **increment on the already-shipped MVP** + (`app/features/scenarios/`, PRP-26, issue #221, branch + `feat/scenarios-what-if-planning`). It adds: a leakage-safe future + feature-frame generator, an exogenous-regressor forecaster the scenario + engine can drive, a real `method="model_exogenous"` simulation path, a + multi-scenario comparison surface, and an agent that proposes scenarios + behind the HITL approval gate. Delivered as **four independently-shippable + phases** so each is one-pass implementable. + +## Purpose + +The MVP turned ForecastLabAI from "predict the future" into "plan possible +futures" — but every scenario number is a **deterministic post-forecast +multiplier** stamped `method="heuristic"`. The factors (`PRICE_ELASTICITY`, +`PROMOTION_UPLIFT_BY_KIND`, …) are hand-picked constants, not learned from data. +The feature brief's **"Full Version"** closes that gap: + +- **Phase A — Future Feature-Frame Generator (backend)**: a leakage-safe module + that builds a per-horizon-day feature matrix (`X_future`) for a `(store, + product)` series, with the scenario assumptions injected as exogenous columns. 
+- **Phase B — Exogenous-Regressor Forecaster + model-driven simulation + (backend)**: a `RegressionForecaster` (`BaseForecaster` subclass) that + *consumes* `X`, a new `method="model_exogenous"` value, and a + `ScenarioService.simulate` path that produces a model-causal comparison when + the baseline model supports exogenous features — falling back to the + heuristic path otherwise. +- **Phase C — Scenario Library + Multi-Scenario Comparison (backend + frontend)**: + scenario-library tagging/cloning over the existing `scenario_plan` table, a + `POST /scenarios/compare` endpoint, and a What-If Planner comparison view that + charts a baseline against N saved scenarios. +- **Phase D — Agent-Proposed Scenarios + Approval Flow (backend)**: two `scenarios` + agent tools — a read-only `propose_scenario` that returns a candidate + `ScenarioAssumptions` payload + an operational recommendation, and a + mutating `save_scenario` that persists a `scenario_plan` row **only after + the human approves it via the existing HITL gate** (`save_scenario` is added + to `agent_require_approval`). Persisted agent scenarios carry author/source + provenance and an audit trail linking back to the agent session and the + approval decision; a Phase D migration adds the provenance/audit columns to + `scenario_plan`. + +The MVP code is **not re-specified** — see § "What the MVP already delivered". + +> Source brief: `docs/optional-features/03-scenario-simulation-what-if-planning.md` +> § "Full Version". MVP PRP: `PRPs/PRP-26-scenario-simulation-what-if-planning.md`. +> Validated against the repo as of 2026-05-19. + +--- + +## SCOPE WARNING — this is a large PRP; it is phased deliberately + +The Full Version spans two genuinely large pieces of work — a leakage-safe +future-feature-frame generator and exogenous-regressor model support. Either one +alone is a normal PRP's worth of work. 
**This PRP is therefore explicitly +phased: ship one PR per phase, in order.** Phase A and B are backend-only and +gate Phase C/D. A team that can only land part of this should land **Phase A + +B** (the model-causal core) and defer C + D — the brief's other three bullets +(scenario library, multi-scenario comparison, agent suggestions) are valuable +but lower-risk and do not unblock anything. + +If the implementer judges Phase A + B alone is still too large for one PR, split +Phase B into B1 (the `RegressionForecaster` + training path) and B2 (the +scenario `simulate` integration) — they have a clean seam at the +`BaseForecaster` interface. + +--- + +## What the MVP already delivered (DO NOT re-build) + +PRP-26 shipped and merged the entire `app/features/scenarios/` slice. **Every +file below already exists** — this PRP modifies or extends them, never recreates +them. + +| File | What it already does | +|------|----------------------| +| `app/features/scenarios/__init__.py` | Slice package + `__all__`. | +| `app/features/scenarios/adjustments.py` | PURE deterministic factor engine — `price_factor`, `promotion_factor`, `holiday_factor`, `lifecycle_factor`, `combined_daily_factor`, `apply_adjustment`, `coverage_verdict`. Constants `PRICE_ELASTICITY`, `PROMOTION_UPLIFT_BY_KIND`, `HOLIDAY_UPLIFT`, `LIFECYCLE_FACTOR`, `FACTOR_BAND`. | +| `app/features/scenarios/schemas.py` | `PriceAssumption`, `PromotionAssumption`, `HolidayAssumption`, `InventoryAssumption`, `LifecycleAssumption`, `ScenarioAssumptions`, `SimulateScenarioRequest`, `CreateScenarioRequest`, `ScenarioPoint`, `ScenarioComparison`, `ScenarioPlanResponse`, `ScenarioListItem`, `ScenarioListResponse`. Request bodies use `ConfigDict(strict=True)` + `Field(strict=False)` on dates; responses use `from_attributes=True`. `ScenarioComparison.method` is `Literal["heuristic"]`. 
| +| `app/features/scenarios/models.py` | `ScenarioPlan(TimestampMixin, Base)` ORM — `scenario_id`, `name`, `store_id`, `product_id`, `run_id`, `horizon`, `assumptions` JSONB, `comparison` JSONB, `method` String(20). `CheckConstraint("method IN ('heuristic')")`, GIN indexes, composite `(store_id, product_id)` index. Constant `SCENARIO_METHOD_HEURISTIC = "heuristic"`. | +| `app/features/scenarios/service.py` | `ScenarioService` — `simulate` (heuristic post-forecast multiplier; loads a bundle via `load_model_bundle`, calls `bundle.model.predict(horizon)`, applies `adjustments.combined_daily_factor`), plus `create_plan` / `list_plans` / `get_plan` / `delete_plan`. Helpers `_load_baseline_bundle` (path-traversal guard), `_forecast_start_date`, `_latest_unit_price`, `_to_plan_response`, `_to_list_item`. Constant `HEURISTIC_DISCLAIMER`. | +| `app/features/scenarios/routes.py` | `router = APIRouter(prefix="/scenarios", tags=["scenarios"])`. Endpoints: `POST /scenarios/simulate`, `POST /scenarios`, `GET /scenarios`, `GET /scenarios/{scenario_id}`, `DELETE /scenarios/{scenario_id}`. Maps `FileNotFoundError`→`NotFoundError`, `ValueError`→`BadRequestError`, `SQLAlchemyError`→`DatabaseError`. | +| `app/features/scenarios/tests/` | `conftest.py`, `test_adjustments.py`, `test_schemas.py`, `test_leakage.py` (LOAD-BEARING), `test_routes_integration.py`. | +| `alembic/versions/43e35957a248_create_scenario_plan_table.py` | Creates `scenario_plan`. **This is the current Alembic head** — verified `uv run alembic heads` → `43e35957a248`. | +| `frontend/src/types/api.ts` | `Scenario*` interfaces. | +| `frontend/src/hooks/use-scenarios.ts` | `useSimulateScenario`, `useScenarios`, `useScenario`, `useCreateScenario`, `useDeleteScenario`. | +| `frontend/src/lib/scenario-utils.ts` (+ `.test.ts`) | `mergeComparisonSeries`, `formatDelta`, `coverageLabel`, `coverageVariant`, `deltaCsvColumns`, `summariseAssumptions`. 
| +| `frontend/src/pages/visualize/planner.tsx` | The `/visualize/planner` What-If Planner page. | + +> The Full Version **never weakens** `test_leakage.py`, never edits the merged +> migration `43e35957a248`, and never drops the `method="heuristic"` path — +> it adds a second, model-driven path alongside it. + +--- + +## DEPENDS ON — read before starting + +- **MVP merged** — `app/features/scenarios/` exists. If it does not, this PRP + cannot start; build PRP-26 first. +- **No unmerged-PRP dependency.** Builds on already-merged `forecasting` + (PRP-5), `featuresets` (PRP-4 + PRP-3.1*), `data_platform` (PRP-2), + `registry` (PRP-7), `jobs` (PRP-8), `agents` (PRP-10), dashboard (PRP-11). +- **Sanity-check before starting**: `app/features/forecasting/models.py` must + still define `BaseForecaster` with the `fit(y, X=None)` / `predict(horizon, + X=None)` interface and `model_factory`; `app/features/featuresets/service.py` + must still define `FeatureEngineeringService.compute_features` and + `FeatureDataLoader`. If either moved, the Phase A/B plan needs revisiting. + +--- + +## Goal + +**Feature Goal**: Make Scenario Simulation **model-causal** — a what-if can be +answered by re-forecasting through a model that consumes a leakage-safe future +feature frame built from the scenario assumptions — and **collaborative** — an +agent can propose a scenario and a human approves it. The MVP heuristic path +stays as the transparent fallback. + +**Deliverable**: +- **Phase A** — `app/features/scenarios/feature_frame.py` (new): a pure + + DB-reading future-feature-frame generator, plus a load-bearing leakage test. +- **Phase B** — `RegressionForecaster` added to `forecasting/models.py`, a + `RegressionModelConfig` schema, `model_factory` wiring, a training path that + builds historical features and fits the estimator; `ScenarioService.simulate` + extended with a `method="model_exogenous"` branch; one Alembic migration + widening the `scenario_plan.method` CHECK constraint. 
+- **Phase C** — `POST /scenarios/compare` + `MultiScenarioComparison` schema + + scenario-library fields (`tags`, `cloned_from`) on `scenario_plan` (second + migration); a multi-series chart variant and a comparison view on the planner + page; new hooks + types. +- **Phase D** — `app/features/scenarios/agent_tools.py` (new): a read-only + `propose_scenario` tool and a mutating `save_scenario` tool registered on the + experiment agent, with `save_scenario` gated by `agent_require_approval`; a + Phase D Alembic migration adding provenance/audit columns to `scenario_plan` + (`source`, `agent_session_id`, `approved_by`, `approved_at`, `approval_decision`); + the `ScenarioPlan` ORM model + schemas extended accordingly; `save_scenario` + added to the `agent_require_approval` config list in `app/core/config.py`. + +**Success Definition**: `docker compose up` → seed → train a `regression` +model → `POST /scenarios/simulate` with that model's `run_id` and a price-cut +assumption returns a `ScenarioComparison` with `method="model_exogenous"` whose +deltas come from re-forecasting (not a fixed multiplier); the future-frame +leakage test proves no observed target at/after the forecast origin is read; +`POST /scenarios/compare` ranks N saved scenarios; the planner comparison view +charts baseline + N scenarios; the experiment agent can `propose_scenario` +(read-only) and `save_scenario`, and a `save_scenario` call is blocked pending +`/agents/sessions/{id}/approve` — once approved, the persisted `scenario_plan` +row carries `source="agent"`, the originating `agent_session_id`, and the +approval audit trail (`approved_by`, `approved_at`, `approval_decision`); every gate +(`ruff`, `mypy --strict`, `pyright --strict`, `pytest` unit + integration, +frontend `tsc`/`lint`/`test`) is green. + +## Why + +- **User value** — heuristic deltas are directional only; a model-causal + scenario lets a planner trust the *magnitude* of "discount 15% next week". 
+ A scenario library + multi-scenario comparison turns one-off analyses into a + reusable planning portfolio. Agent-proposed scenarios surface options a + planner would not have thought to try. +- **Demo value** — the brief calls this out: the Full Version is what makes the + feature a *planning system* rather than a labelled-heuristic demo. +- **Integration** — featuresets already produces every exogenous feature + (`price_lag_*`, `promo_*`, lifecycle, calendar) and is time-safe by + construction; this PRP reuses that machinery for the *future* frame and adds + the one model class that can consume it. + +## What + +### User-visible behavior + +1. **Train a regression model** — a new `model_type="regression"` is trainable + via the existing `POST /forecasting/train` (or a `train` job). It fits a + feature-driven estimator on historical features. +2. **Model-causal simulation** — on the What-If Planner, when the picked + baseline job's model is a `regression` model, `POST /scenarios/simulate` + returns `method="model_exogenous"`: the price/promotion/holiday assumptions + become real future feature-frame columns and the model re-forecasts. A + `naive`/`seasonal_naive`/`moving_average` baseline still returns + `method="heuristic"` exactly as today. The result always declares its method. +3. **Scenario library** — saved plans carry `tags` and can be cloned (a new plan + pre-filled from an existing one's assumptions). The saved-plans list filters + by tag. +4. **Multi-scenario comparison** — pick 2-5 saved plans → a comparison view + charts the baseline against every scenario series, ranks them by revenue + delta, and shows a verdict table. +5. **Agent-proposed scenarios** — in the chat agent, a user can ask "what + scenarios should I try for store 1 product 101?"; the agent calls + `propose_scenario` (read-only), which returns a candidate + `ScenarioAssumptions` + a plain-language recommendation. 
If the user then + asks the agent to save the proposal, the agent calls `save_scenario` — a + mutating tool that is in `agent_require_approval`, so it pauses at the + existing HITL gate and writes the `scenario_plan` row **only after** the + human approves via `/agents/sessions/{id}/approve`. The persisted row records + who/what created it (`source="agent"`, the `agent_session_id`) and the + approval decision (`approved_by`, `approved_at`, `approval_decision`). + +### Technical requirements + +- **Time-safety is the #1 invariant.** The future-feature-frame generator must + never read an observed target at or after the forecast origin. A new + load-bearing leakage test is the spec. +- New forecaster implements the existing `BaseForecaster` ABC — `fit(y, X)` / + `predict(horizon, X)` — and is deterministic (`random_state`). +- **No new external dependency by default** — use scikit-learn's + `HistGradientBoostingRegressor` (already a transitive dep via `scikit-learn`). + Adding LightGBM is a **stop-and-ask** gate (see § Vision Tensions). +- `method` stays forward-compatible: a migration widens the CHECK constraint to + `IN ('heuristic','model_exogenous')`. +- Pydantic v2 at every boundary; SQLAlchemy 2.0 async; RFC 7807 errors; + `mypy --strict` + `pyright --strict` clean. No WebSocket. No managed-cloud SDK. + +### Success Criteria + +- [ ] `app/features/scenarios/feature_frame.py` builds a horizon-length feature + matrix for a `(store, product)` series with assumption-driven exogenous + columns; it is unit-tested and a leakage test proves no at/after-origin + target read. +- [ ] A `regression` model is trainable and persists a `ModelBundle` whose + `model` is a `RegressionForecaster`; `bundle.metadata` carries the feature + column list and the historical tail needed to seed lags. +- [ ] `POST /scenarios/simulate` with a `regression` baseline returns + `method="model_exogenous"`; with a baseline forecaster returns + `method="heuristic"` (unchanged). 
Both pass through RFC 7807 on a bogus + `run_id` — never a 500. +- [ ] An empty `ScenarioAssumptions` on the model path yields scenario ≈ + baseline (the unmodified future frame re-forecast). +- [ ] The Alembic migration widening the `method` CHECK upgrades **and** + downgrades cleanly on a fresh DB; `down_revision = "43e35957a248"`. +- [ ] `POST /scenarios/compare` accepts 2-5 `scenario_id`s and returns a ranked + `MultiScenarioComparison`; saved plans carry `tags`; cloning works. +- [ ] The planner page renders a multi-series comparison chart + ranked verdict + table; dogfooded in a real browser. +- [ ] The experiment agent exposes `tool_propose_scenario` (read-only — never + writes a row) and `tool_save_scenario` (mutating). `save_scenario` is in + `agent_require_approval`, so a call pauses for HITL approval and persists + the `scenario_plan` row only after `/agents/sessions/{id}/approve`. +- [ ] An agent-persisted `scenario_plan` row carries provenance — `source`, + `agent_session_id` — and an audit trail — `approved_by`, `approved_at`, + `approval_decision`. The Phase D migration adding those columns upgrades + **and** downgrades cleanly on a fresh DB. +- [ ] `test_leakage.py` (MVP) still passes unweakened; the new future-frame + leakage test passes. +- [ ] All gates green; no new external dependency (unless the LightGBM + stop-and-ask is explicitly approved); no WebSocket; no cross-slice + `service.py` import (see DECISIONS LOCKED #3). +- [ ] README + `docs/_base/{API_CONTRACTS,REPO_MAP_INDEX,DOMAIN_MODEL}.md` + updated. + +--- + +## All Needed Context + +### DECISIONS LOCKED (resolved during planning — do NOT re-litigate) + +1. **The heuristic path STAYS — the model path is ADDITIVE.** The MVP's + `method="heuristic"` post-forecast multiplier is the documented fallback for + any baseline that cannot consume exogenous features (`naive`, + `seasonal_naive`, `moving_average` — every `fit`/`predict` carries + `# noqa: ARG002`). 
`ScenarioService.simulate` branches on the loaded + `bundle.config.model_type`: `regression` → model path, anything else → + the existing heuristic path. A scenario result always carries an accurate + `method`. Do NOT delete `adjustments.py` or the heuristic branch. + +2. **Use `HistGradientBoostingRegressor`, NOT LightGBM, by default.** LightGBM + is **not in `pyproject.toml`** (only `scikit-learn>=1.6.0` is) and + `model_factory` raises `NotImplementedError` for `lightgbm`. + `HistGradientBoostingRegressor` (`sklearn.ensemble`) is already importable, + deterministic with `random_state`, NaN-tolerant (critical — lag features are + `NaN` at series start), and needs no `pyproject.toml` change and no + stop-and-ask gate. The new `RegressionForecaster` wraps it. LightGBM remains + a future option behind `forecast_enable_lightgbm` — adding it is a separate, + explicitly-approved change (§ Vision Tensions). See + `PRPs/ai_docs/exogenous-regressor-forecasting.md` § 1, § 5. + +3. **No cross-slice `service.py` import — same rule as the MVP.** A slice may + NOT import another slice's `service.py` (`AGENTS.md` § Architecture). The + future-feature-frame generator (`scenarios/feature_frame.py`) imports the + **stable lower-level building blocks** only. + `FeatureEngineeringService` + `FeatureDataLoader` from + `featuresets/service.py` are *service-layer* classes — importing them would be the + forbidden cross-slice service import, so they are NOT among those building blocks. RESOLUTION: `feature_frame.py` reuses + the *featureset config schemas* (`featuresets/schemas.py` — + `FeatureSetConfig`, `LagConfig`, `CalendarConfig`, `ExogenousConfig` — these + are schema/value objects, allowed) and reads `data_platform` ORM models + directly (allowed read-only ORM import, the sanctioned exception). It + **replicates** the small slice of leakage-safe lag/calendar logic it needs — + exactly as the MVP `scenarios/service.py` replicates the + `ForecastingService.predict` body rather than importing `ForecastingService`.
+ The `RegressionForecaster` itself lives in `forecasting/models.py` (it IS a + forecasting concern — a `BaseForecaster` subclass), so the `scenarios` slice + only imports `load_model_bundle` + `model_factory` + the `BaseForecaster` + interface from `forecasting`, never `ForecastingService`. **Cite this in the + PR** per `product-vision.md`. + +4. **Long-lag + calendar + exogenous feature set — no recursion in v1.** The + future feature frame uses ONLY: lags `k ≥ horizon` (knowable at the forecast + origin), calendar features (pure function of the date), and assumption-driven + exogenous columns (`price_*`, `promo_*`, `is_holiday`, lifecycle). It + deliberately does NOT use lags shorter than the horizon, which would require + recursive (iterative) forecasting — `ŷ[T+j-k]` feeding `lag_k` at `T+j`. + Recursion is a documented Phase-2 extension. This keeps the leakage proof a + direct assertion and the PRP one-pass implementable. See + `PRPs/ai_docs/exogenous-regressor-forecasting.md` § 2. + +5. **New `method` value = `"model_exogenous"`.** The MVP CHECK constraint is + `method IN ('heuristic')`. The Phase B migration widens it to + `IN ('heuristic','model_exogenous')`. The `ScenarioComparison.method` field + becomes `Literal["heuristic", "model_exogenous"]`. The `disclaimer` string is + method-specific: the model path gets a *model-driven* disclaimer (still a + transparency control — a model estimate is not certainty) distinct from the + heuristic one. + +6. **There is no `scenarios` commit scope.** `.claude/rules/commit-format.md` + has no `scenarios` scope. Use `feat(forecast)` for `RegressionForecaster` + and the training path, `feat(api)` for the `scenarios`-slice backend, + `feat(api,db)` for slice + migration, `feat(agents)` for the agent tool, + `feat(ui)` for the frontend, `test(...)` matching the slice, `docs(docs)`. + +7. **Current Alembic head is `43e35957a248`** (`create_scenario_plan_table`) — + verified via `uv run alembic heads`. 
The Phase B migration sets + `down_revision = "43e35957a248"`; the Phase C migration chains off the + Phase B revision, and the Phase D migration chains off the Phase C revision. **Re-verify with `uv run alembic heads` immediately before + writing each migration** — another PRP merging first would move the head. + +8. **No WebSocket, no managed-cloud SDK, no streaming.** Simulation and + comparison are request/response. Consistent with `product-vision.md`. + +9. **`scenario_plan` JSONB columns stay `assumptions` / `comparison`.** New + library fields are real columns (`tags` as `JSONB` array, `cloned_from` as + `String(32)` nullable) — never folded into the JSONB blobs, so they are + queryable/indexable. Never name a column `metadata` (SQLAlchemy reserves it). + +10. **Exogenous feature lag offsets = `(1, 7, 14, 28)` days — PINNED.** The + maintainer resolved this (formerly an Open Question). The exogenous-feature + lag offsets used by both the historical feature matrix (Phase B training, + Task B4) and the future feature frame (Phase A, `feature_frame.py`) are + `EXOGENOUS_LAGS = (1, 7, 14, 28)` — daily, weekly, fortnightly, and a + four-week lag covering the dominant retail seasonality. The future *target* + long-lag frame may use only the subset with `k >= horizon` (DECISIONS + LOCKED #4); the rest are exogenous-driven (`price_*`, `promo_*`, + `is_holiday`, lifecycle) and therefore knowable at the origin regardless of + `k`. The trained bundle's `feature_columns` must reflect this exact offset + set so the future frame matches column-for-column. + +11. **`history_tail` length = `90` days — PINNED.** The maintainer resolved + this (formerly an Open Question). The persisted historical tail — + `history_tail` in the bundle metadata, fed to the future-feature-frame + generator and used as regression context — is the last + `HISTORY_TAIL_DAYS = 90` observed target values ending at the forecast + origin `T`.
90 days exceeds the largest lag offset (28) with a comfortable + buffer, so every long-lag column resolves inside the tail. + +12. **Multi-scenario comparison cap = `5` — PINNED.** The maintainer resolved + this (formerly an Open Question). `POST /scenarios/compare` accepts 2–5 + `scenario_id`s; the upper bound `MAX_COMPARE_SCENARIOS = 5` keeps the + multi-series chart legible. Enforced at the schema boundary via + `Field(..., min_length=2, max_length=5)` on `CompareScenariosRequest. + scenario_ids` (a >5 list → 422), and the comparison-route integration test + asserts the 6-scenario rejection. + +13. **The agent CAN persist a scenario — behind the HITL approval gate.** The + maintainer resolved this (formerly Open Question 4). Phase D ships TWO agent + tools, not one: + - `propose_scenario` — READ-ONLY. Returns a candidate `ScenarioAssumptions` + + recommendation. No DB write, no approval needed. + - `save_scenario` — MUTATING. Persists a `scenario_plan` row, but ONLY after + a human approves via the existing HITL gate. It is gated EXACTLY like + `tool_create_alias` / `tool_archive_run`: the tool name `save_scenario` + is added to `agent_require_approval` in `app/core/config.py` (currently + `["create_alias", "archive_run"]` → `["create_alias", "archive_run", + "save_scenario"]`). + This widens the agent's mutation surface. `AGENTS.md` § Safety requires a + stop-and-ask before "widening an agent's mutation surface without adding + the tool name to `agent_require_approval`" — this PRP IS that approval, and + the change correctly adds the tool to the list, so the gate is satisfied. + The PR description MUST call the widening out explicitly. 
+ An agent-persisted plan MUST capture, beyond the MVP's `assumptions` + + `comparison`: target scope (`store_id`/`product_id` already exist; a + `category` field is added if category-scoped plans are in play), horizon + (already exists), AUTHOR/SOURCE metadata (`source ∈ {"agent","user"}`, + `agent_session_id`), and an AUDIT TRAIL of the approval decision + (`approved_by`, `approved_at`, `approval_decision ∈ {"approved","rejected"}`). + The MVP `scenario_plan` table (per PRP-26) has only `assumptions` + + `comparison` JSONB plus the scalar columns — so Phase D ships an Alembic + migration adding the provenance/audit columns. Verified against + `app/features/scenarios/models.py` and + `alembic/versions/43e35957a248_create_scenario_plan_table.py`: the table + today has `id, scenario_id, name, store_id, product_id, run_id, horizon, + assumptions, comparison, method, created_at, updated_at` — no provenance + columns exist yet. The Phase D migration adds discrete columns: + `source` `String(16)` NOT NULL `server_default='user'` (CHECK + `source IN ('agent','user')`); `agent_session_id` `String(32)` nullable; + `approved_by` `String(120)` nullable; `approved_at` `DateTime(timezone=True)` + nullable; `approval_decision` `String(16)` nullable (CHECK + `approval_decision IN ('approved','rejected')`). Discrete columns (not a + single `provenance` JSONB blob) so they are queryable/indexable, consistent + with DECISIONS LOCKED #9. User-created plans default to `source='user'` + with the audit columns NULL, so the existing MVP create path is + backward-compatible. + +### Documentation & References + +```yaml +# ── MUST READ — the MVP (this PRP extends it) ── + +- file: PRPs/PRP-26-scenario-simulation-what-if-planning.md + why: The MVP PRP. Its DECISIONS LOCKED, Known Gotchas table, and Anti-Patterns + all still hold. This PRP is an increment — read the MVP first to know + what already exists. 
+ +- file: app/features/scenarios/service.py + why: ScenarioService.simulate is the method this PRP branches. Read the + heuristic body, _load_baseline_bundle (path-traversal guard — REUSE it + verbatim), _forecast_start_date, _latest_unit_price. The model path is a + NEW branch alongside the existing one. + +- file: app/features/scenarios/schemas.py + why: ScenarioComparison.method is Literal["heuristic"] — Phase B widens it. + ScenarioAssumptions is the input both paths consume. Request bodies use + ConfigDict(strict=True) + Field(strict=False) on dates — keep that. + +- file: app/features/scenarios/models.py + why: ScenarioPlan ORM + the method CHECK constraint Phase B widens. Phase C + adds tags/cloned_from columns here. + +- file: app/features/scenarios/adjustments.py + why: The heuristic engine — NOT modified, but the model path mirrors its + "pure, never raises" discipline for feature_frame.py helpers. + +- file: app/features/scenarios/tests/test_leakage.py + why: The MVP leakage spec — LOAD-BEARING, never weakened. The new + future-frame leakage test follows this exact philosophy. + +# ── MUST READ — forecasting (where RegressionForecaster lands) ── + +- file: app/features/forecasting/models.py + why: BaseForecaster ABC — fit(y, X=None) / predict(horizon, X=None) / + get_params / set_params. RegressionForecaster subclasses it. model_factory + (line 429) is the dispatch — add a "regression" branch. ModelType alias + (line 426) gains "regression". CRITICAL: the existing baselines carry + # noqa: ARG002 because they ignore X — RegressionForecaster is the FIRST + to actually use X. + +- file: app/features/forecasting/schemas.py + why: ModelConfigBase (frozen, extra=forbid, config_hash). LightGBMModelConfig + (line 107) is the closest precedent for a new ML model config — mirror + it for RegressionModelConfig. ModelConfig union (line 148) gains the new + config. TrainRequest strict-mode pattern. 
+ +- file: app/features/forecasting/persistence.py + why: ModelBundle (model + config + metadata dict) and load_model_bundle / + save_model_bundle. The regression bundle's metadata MUST additionally + carry the feature column list and the historical tail (last N target + values + dates) needed to build long-lag future-frame columns. + +- file: app/features/forecasting/service.py + why: ForecastingService.train_model + predict. The regression training path + loads features instead of raw y; predict() must pass X for a regression + model. Read _load_training_data (line 314) — the regression path needs a + feature-loading sibling. The path-traversal guard in predict() (lines + 218-249) is the pattern _load_baseline_bundle already mirrors. + +# ── MUST READ — featuresets (the time-safe machinery the future frame reuses) ── + +- file: app/features/featuresets/service.py + why: FeatureEngineeringService.compute_features — the time-safe HISTORICAL + feature builder. _compute_lag_features (shift(positive)), + _compute_rolling_features (shift(1).rolling), _compute_calendar_features, + _compute_exogenous_features. CRITICAL: cutoff filter happens BEFORE any + compute. feature_frame.py REPLICATES the leakage-safe lag+calendar logic + it needs (DECISIONS LOCKED #3) — it does NOT import this class. + FeatureDataLoader shows the SQL load patterns to mirror. + +- file: app/features/featuresets/schemas.py + why: FeatureSetConfig, LagConfig, CalendarConfig, ExogenousConfig — schema + value-objects feature_frame.py MAY import (they are not service code). + +- file: app/features/featuresets/tests/test_leakage.py + why: The original load-bearing leakage spec. The future-frame leakage test + mirrors its assertion style. + +- file: PRPs/PRP-3.1B-lifecycle-compute.md +- file: PRPs/PRP-3.1D-promotion-compute.md + why: The time-safety reasoning for lifecycle / promotion features — the + future frame's exogenous columns must respect the same boundaries. 
+ +# ── MUST READ — agents (Phase D) ── + +- file: app/features/agents/agents/experiment.py + why: The agent that gains tool_propose_scenario (read-only) and + tool_save_scenario (HITL-gated). For tool_save_scenario, mirror tool_create_alias + EXACTLY — it shows the @agent.tool + @recoverable decorators, the + requires_approval("create_alias") check, and the + {"status":"approval_required", ...} early return. + +- file: app/features/agents/agents/base.py + why: requires_approval(action_name) (line ~255) checks + settings.agent_require_approval. SYSTEM_PROMPT_HEADER / + TOOL_USAGE_INSTRUCTIONS — the new tools' names go in + TOOL_USAGE_INSTRUCTIONS. + +- file: app/features/agents/tools/registry_tools.py + why: create_alias / archive_run — the "REQUIRES HUMAN APPROVAL" tool-function + shape. scenarios/agent_tools.py mirrors this. + +- file: app/features/agents/tools/forecasting_tools.py + why: train_model / predict tool functions — how a tool wraps a service call + and returns model_dump(). The propose_scenario tool mirrors the shape but + is READ-ONLY (it proposes, it does not persist). + +- file: app/core/config.py + why: agent_require_approval (line 164) = ["create_alias", "archive_run"]. + Phase D adds "save_scenario" to this list — verified the current value + is exactly the two-element list. This is a deliberate widening of the + agent's mutation surface; AGENTS.md § Safety requires a stop-and-ask for + that — this PRP IS that approval (see DECISIONS LOCKED #13). + forecast_model_artifacts_dir (line 100), forecast_enable_lightgbm (101). + +# ── MUST READ — frontend (Phase C) ── + +- file: frontend/src/pages/visualize/planner.tsx + why: The existing What-If Planner page — Phase C adds a multi-scenario + comparison view to it (or a sibling tab). Read its Card/Select/Table + skeleton, the saved-plans table, the TimeSeriesChart wiring. + +- file: frontend/src/components/charts/time-series-chart.tsx + why: The Recharts wrapper — currently 2-series (actualKey/predictedKey).
The + multi-scenario chart needs an M+1-series variant; verify the exact prop + names before extending. Do NOT hand-roll a chart. + +- file: frontend/src/hooks/use-scenarios.ts + why: The existing hooks. Phase C adds useCompareScenarios (a mutation) and + extends useScenarios with a tag filter param. Mirror the existing shape. + +- file: frontend/src/lib/scenario-utils.ts + why: Pure utils. Phase C adds mergeMultiScenarioSeries and a ranking helper — + unit-tested in scenario-utils.test.ts. + +- file: frontend/src/types/api.ts + why: Scenario* interfaces — Phase C adds MultiScenarioComparison and extends + ScenarioPlanResponse with tags/cloned_from. Phase B changes method to a + 'heuristic' | 'model_exogenous' union. + +# ── External documentation (curated) ── + +- docfile: PRPs/ai_docs/exogenous-regressor-forecasting.md + why: THE primary reference for Phase A + B. Condenses the exogenous-regressor + model contract, the leakage rule for FUTURE feature frames (the + load-bearing part), the recursion-avoiding "long-lag" feature set, the + HistGradientBoostingRegressor-vs-LightGBM decision, and multi-scenario + comparison math. Read it before writing feature_frame.py or the + RegressionForecaster. + +- url: https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.HistGradientBoostingRegressor.html + why: The estimator RegressionForecaster wraps — fit/predict signatures, + random_state for determinism, native NaN handling. + +- url: https://lightgbm.readthedocs.io/en/stable/pythonapi/lightgbm.LGBMRegressor.html + why: ONLY if the LightGBM stop-and-ask is approved — the alternative estimator. + +- url: https://pandas.pydata.org/docs/user_guide/timeseries.html + why: Date ranges, shifting, rolling — for building the future feature frame. + +- url: https://recharts.org/en-US/api/LineChart + why: The multi-series scenario-comparison chart (Phase C). 
+ +- url: https://tanstack.com/query/latest/docs/framework/react/guides/mutations + why: useCompareScenarios is a mutation; the tag-filtered list stays a query. + +- url: https://www.nist.gov/itl/ai-risk-management-framework + why: The transparency control — every ScenarioComparison declares its + `method` and carries a method-appropriate `disclaimer`. A model-driven + estimate still gets a caveat (not certainty). + +# ── Rules — read before writing any code ── + +- file: .claude/rules/product-vision.md + why: Principle 5 (time-safety — the leakage test is load-bearing), principle 8 + (single-host — the HistGradientBoostingRegressor choice keeps it so), + "not a generic ML platform" / "not a streaming system". Answer all 6 + Litmus-Test questions in the PR description. + +- file: .claude/rules/security-patterns.md + why: Pydantic v2 at every boundary; SQLAlchemy parameter binding; + pathlib.Path.resolve() for artifact paths; strict-mode request-body + policy; agent mutation tools MUST be in agent_require_approval. + +- file: .claude/rules/test-requirements.md + why: New module → test file; new endpoint → route test (2xx + 1 error); + new model → constraint test; new migration → upgrade/downgrade clean. + +- file: .claude/rules/commit-format.md + why: type(scope): description (#issue). No `scenarios` scope (DECISIONS + LOCKED #6). + +- file: .claude/rules/branch-naming.md + why: branch feat/scenario-simulation-full-version off dev. + +- file: .claude/rules/ui-design.md +- file: .claude/rules/shadcn-ui.md + why: Build the comparison view via frontend-design + shadcn-ui skills; + dogfood in a real browser. Green tsc ≠ working UI. 
+``` + +### Current Codebase tree (relevant slices — all already exist) + +```bash +app/features/ +├── scenarios/ # THE MVP SLICE — extended by this PRP +│ ├── __init__.py +│ ├── adjustments.py # heuristic engine — unchanged +│ ├── models.py # ScenarioPlan — Phase B/C extend +│ ├── schemas.py # Phase B/C extend +│ ├── service.py # ScenarioService.simulate — Phase B branches +│ ├── routes.py # Phase C adds /scenarios/compare +│ └── tests/ +├── forecasting/ +│ ├── models.py # BaseForecaster + model_factory — Phase B +│ ├── schemas.py # ModelConfig union — Phase B +│ ├── persistence.py # ModelBundle, load/save +│ └── service.py # train_model/predict — Phase B +├── featuresets/ +│ ├── service.py # FeatureEngineeringService — pattern source +│ └── schemas.py # FeatureSetConfig etc. — importable value-objects +├── agents/ +│ ├── agents/experiment.py # gains tool_propose_scenario — Phase D +│ ├── agents/base.py # requires_approval() +│ └── tools/ # tool-function shape +└── data_platform/models.py # SalesDaily, Calendar, Promotion, Product, PriceHistory +alembic/versions/ +└── 43e35957a248_create_scenario_plan_table.py # CURRENT HEAD (verify) +frontend/src/ +├── pages/visualize/planner.tsx +├── hooks/use-scenarios.ts +├── lib/scenario-utils.ts (+ .test.ts) +├── components/charts/time-series-chart.tsx +└── types/api.ts +``` + +### Desired Codebase tree — files to ADD + +```bash +# ── Phase A — future feature-frame generator ── +app/features/scenarios/feature_frame.py # leakage-safe X_future builder +app/features/scenarios/tests/test_feature_frame.py # unit tests +app/features/scenarios/tests/test_future_frame_leakage.py # LOAD-BEARING leakage spec + +# ── Phase B — exogenous-regressor model + migration ── +alembic/versions/_widen_scenario_method_check.py # method CHECK widen +app/features/forecasting/tests/test_regression_forecaster.py # new forecaster unit tests + +# ── Phase C — scenario library + multi-scenario comparison ── 
+alembic/versions/_add_scenario_library_columns.py # tags + cloned_from columns +app/features/scenarios/tests/test_compare_integration.py + +# ── Phase D — agent-proposed scenarios ── +alembic/versions/_add_scenario_provenance_columns.py # source + audit-trail columns +app/features/scenarios/agent_tools.py # propose_scenario + save_scenario tool functions +app/features/scenarios/tests/test_agent_tools.py + +# ── docs ── +PRPs/ai_docs/exogenous-regressor-forecasting.md # ALREADY CREATED by this PRP +``` + +### Files to MODIFY (all additive) + +```bash +# Phase A +app/features/scenarios/__init__.py # +export the frame builder if public + +# Phase B +app/features/forecasting/models.py # +RegressionForecaster, +model_factory branch, +ModelType +app/features/forecasting/schemas.py # +RegressionModelConfig, +ModelConfig union member +app/features/forecasting/service.py # +regression training/predict feature path +app/features/scenarios/schemas.py # method -> Literal[...,"model_exogenous"]; +model disclaimer +app/features/scenarios/models.py # CHECK constraint widened to match the migration +app/features/scenarios/service.py # +_simulate_model_exogenous branch in simulate() + +# Phase C +app/features/scenarios/models.py # +tags JSONB, +cloned_from String(32) +app/features/scenarios/schemas.py # +MultiScenarioComparison, +CompareScenariosRequest, +tags +app/features/scenarios/service.py # +compare_scenarios, +clone, +tag filter on list_plans +app/features/scenarios/routes.py # +POST /scenarios/compare, +tag query param, +clone +frontend/src/types/api.ts # +MultiScenarioComparison; method union; tags +frontend/src/hooks/use-scenarios.ts # +useCompareScenarios, tag filter +frontend/src/lib/scenario-utils.ts (+test)# +mergeMultiScenarioSeries, +rankScenarios +frontend/src/pages/visualize/planner.tsx # +comparison view +frontend/src/components/charts/time-series-chart.tsx # +multi-series variant (or new component) + +# Phase D +app/features/scenarios/models.py # +source, 
+agent_session_id, +approved_by, +approved_at, +approval_decision +app/features/scenarios/schemas.py # +provenance/audit fields on ScenarioPlanResponse; +SaveScenarioRequest +app/features/scenarios/service.py # +create_plan provenance args; +approval-decision write path +app/features/agents/agents/experiment.py # +tool_propose_scenario, +tool_save_scenario (HITL-gated) +app/features/agents/agents/base.py # +tool names in TOOL_USAGE_INSTRUCTIONS +app/core/config.py # +"save_scenario" in agent_require_approval + +# docs (all phases) +README.md +docs/_base/API_CONTRACTS.md +docs/_base/REPO_MAP_INDEX.md +docs/_base/DOMAIN_MODEL.md +``` + +### Known Gotchas of our codebase & Library Quirks + +| # | Gotcha | Mitigation | +|---|--------|-----------| +| 1 | The future feature frame is a NEW leakage surface the MVP did not have — building `X_future` wrong leaks future targets into the forecast. | DECISIONS LOCKED #4: long-lag (`k ≥ horizon`) + calendar + assumption-driven exogenous columns ONLY — every value knowable at the forecast origin. The new `test_future_frame_leakage.py` is the load-bearing spec. | +| 2 | `BaseForecaster.predict(horizon, X=None)` — the baseline forecasters carry `# noqa: ARG002` because they ignore `X`. `RegressionForecaster` is the FIRST that uses it. | `RegressionForecaster.predict` must reject a `None` X (it cannot forecast without features) with a clear `ValueError`, and assert `X.shape[0] == horizon`. | +| 3 | `HistGradientBoostingRegressor` is deterministic ONLY with a fixed `random_state`; without it the bundle hash drifts. | Pass `random_state=settings.forecast_random_seed`. `BaseForecaster.__init__` already stores `random_state` — use it. | +| 4 | Lag features have `NaN` at the start of a series. A model that cannot handle `NaN` would crash on fit. | `HistGradientBoostingRegressor` handles `NaN` natively — do NOT impute it away (imputation that uses the series mean would itself leak). 
| +| 5 | The regression bundle must carry MORE metadata than a baseline bundle — the feature column list (so predict reproduces column order) and the historical target tail (to seed long lags). | Extend the `metadata` dict in the training path: `feature_columns: list[str]`, `history_tail: list[float]`, `history_tail_dates: list[str]`. `metadata` is `dict[str, object]` — JSON-safe values only. | +| 6 | `ScenarioComparison.method` is `Literal["heuristic"]` and the `scenario_plan.method` column has `CHECK method IN ('heuristic')`. Persisting `"model_exogenous"` against the un-migrated DB fails the CHECK. | Phase B ships the migration widening the CHECK **and** updates the `Literal` + the ORM `CheckConstraint` in `models.py` in the SAME PR. Migrations are forward-only — never edit `43e35957a248`. | +| 7 | The current Alembic head is `43e35957a248`. THREE new migrations chain in order: Phase B off `43e35957a248`, Phase C off the Phase B revision, Phase D off the Phase C revision. | `uv run alembic heads` immediately before writing EACH migration; never guess `down_revision`. Each migration's `down_revision` is the *previous* PRP-27 migration's revision id. | +| 8 | Importing `FeatureEngineeringService` / `FeatureDataLoader` from `scenarios` is a cross-slice **service** import — forbidden. | DECISIONS LOCKED #3: `feature_frame.py` imports only featureset *schema* value-objects + `data_platform` ORM models, and replicates the leakage-safe lag/calendar logic. Cite in the PR. | +| 9 | `JSONB` rejects Python `date`/`datetime`; the `tags` column is a JSON array. | `tags` is `list[str]` — JSON-native, fine. Continue persisting `assumptions`/`comparison` via `model_dump(mode="json")`. | +| 10 | An agent tool that *persists* a scenario without approval widens the agent's mutation surface — a security regression. | Phase D ships two tools: `propose_scenario` is READ-ONLY (returns a candidate payload, no approval). 
`save_scenario` is MUTATING — its name MUST be in `agent_require_approval` and it MUST be gated exactly like `tool_create_alias` (DECISIONS LOCKED #13). The PR explicitly calls out the mutation-surface widening. | +| 11 | A `regression` model trained for store A / product B cannot simulate store C — same store/product check as `ForecastingService.predict`. | `simulate` already reads `store_id`/`product_id` from `bundle.metadata`; the model path reuses that — no cross-grain forecast. | +| 12 | The repo uses CRLF on `.py` files (no `.gitattributes`); scripted text-mode writes flip them to LF. | Edit `forecasting/models.py`, `app/main.py`, `config.py` minimally; preserve line endings. | +| 13 | `units_delta_pct` divide-by-zero when baseline demand is 0 — already guarded in the MVP. | The model path reuses the same guard; do not regress it. | +| 14 | A regression model with too few training rows (short history) over-fits or fails the long-lag construction (no row at `T+1-k` for `k ≥ horizon`). | The training path requires `n_observations >= horizon_max + min_train_rows`; raise a clear `ValueError` otherwise — surfaces as RFC 7807. | + +--- + +## Implementation Blueprint + +### Data models and structure + +**Phase A — `app/features/scenarios/feature_frame.py` (pure builder + thin DB read):** + +```python +# Builds X_future — a (horizon, n_features) matrix — for one (store, product). +# Time-safety: every column is knowable at the forecast origin T (DECISIONS #4). + +@dataclass +class FutureFeatureFrame: + dates: list[date] # T+1 .. T+horizon + feature_columns: list[str] # column order — MUST match the trained bundle + matrix: list[list[float]] # row-major; NaN allowed (HGBR handles it) + +def build_calendar_columns(dates: list[date], config: CalendarConfig) -> ...: ... + # PURE — dow/month/quarter/is_weekend from the date itself. Never leaks. + +def build_long_lag_columns(history_tail: list[float], dates: list[date], + lags: tuple[int, ...]) -> ...: ... 
+ # lag_k at T+j == observed y[T+j-k]; REQUIRES k >= horizon so the index + # T+j-k <= T is always inside history_tail. ASSERT min(lags) >= horizon. + # PINNED: exogenous feature lag offsets are EXOGENOUS_LAGS = (1, 7, 14, 28) + # days (DECISION LOCKED #10). For the long-lag *target* frame the same + # offsets apply but only those with k >= horizon are usable in v1; offsets + # below the horizon are deferred to the recursive Phase-2 extension. + +def apply_assumption_columns(matrix, dates, assumptions: ScenarioAssumptions) -> ...: ... + # price_*, promo_*, is_holiday, lifecycle columns DRIVEN BY the assumptions — + # this is the intended what-if input (not leakage). Out-of-window day -> neutral. + +async def build_future_frame(db, *, store_id, product_id, forecast_origin: date, + horizon, feature_columns, history_tail, + assumptions) -> FutureFeatureFrame: ... + # Orchestrates the above; the only DB read is the optional Calendar lookup + # for baseline (non-assumption) holidays — reading a Calendar row is a + # timeless attribute, allowed. + # PINNED: `history_tail` carries the last HISTORY_TAIL_DAYS = 90 observed + # target values (ending at the forecast origin T). 90 days >= max lag + # offset (28) plus a comfortable buffer, so every long-lag column resolves + # inside the tail (DECISION LOCKED #11). + +# ── PINNED modelling constants (DECISIONS LOCKED #10/#11/#12) ── +EXOGENOUS_LAGS: tuple[int, ...] = (1, 7, 14, 28) # exogenous-feature lag offsets, days +HISTORY_TAIL_DAYS: int = 90 # history fed to the future-frame generator +MAX_COMPARE_SCENARIOS: int = 5 # cap on POST /scenarios/compare +``` + +**Phase B — `RegressionForecaster` in `forecasting/models.py`:** + +```python +class RegressionForecaster(BaseForecaster): + """Feature-driven forecaster wrapping HistGradientBoostingRegressor. + + The FIRST forecaster that consumes the exogenous X argument. + """ + def __init__(self, *, max_iter=200, learning_rate=0.05, max_depth=6, + random_state=42) -> None: ... 
+ def fit(self, y, X) -> RegressionForecaster: # X REQUIRED — no noqa + # raise ValueError if X is None or X.shape[0] != len(y) + def predict(self, horizon, X) -> np.ndarray: # X REQUIRED + # raise ValueError if X is None or X.shape[0] != horizon + def get_params(self) -> dict[str, Any]: ... + def set_params(self, **params) -> RegressionForecaster: ... + +# forecasting/schemas.py: +class RegressionModelConfig(ModelConfigBase): + model_type: Literal["regression"] = "regression" + max_iter: int = Field(default=200, ge=10, le=1000) + learning_rate: float = Field(default=0.05, ge=0.001, le=1.0) + max_depth: int = Field(default=6, ge=1, le=20) + feature_config_hash: str | None = None +# ModelConfig union gains RegressionModelConfig. +# model_factory gains: model_type == "regression" -> RegressionForecaster(...) +``` + +**Phase B — `scenarios/schemas.py` changes:** + +```python +ScenarioMethod = Literal["heuristic", "model_exogenous"] +# ScenarioComparison.method: ScenarioMethod +MODEL_EXOGENOUS_DISCLAIMER = ( + "Model estimate: this scenario re-forecasts demand through a feature-driven " + "model using the assumptions as future inputs. It reflects learned patterns " + "but remains an estimate under uncertainty — not a guarantee." +) +``` + +**Phase C — `scenarios/models.py` + `schemas.py` additions:** + +```python +# models.py — ScenarioPlan gains: +tags: Mapped[list[str]] = mapped_column(JSONB, nullable=False, default=list) +cloned_from: Mapped[str | None] = mapped_column(String(32), nullable=True) +# + a GIN index on tags. + +# schemas.py: +class CompareScenariosRequest(BaseModel): # ConfigDict(strict=True) + # max_length is the PINNED cap MAX_COMPARE_SCENARIOS = 5 (DECISION LOCKED + # #12). Keep the literal 5 in the Field constraint (Pydantic constraints + # must be literal) and reference the constant in the docstring/validation + # test so the two never drift. 
+ scenario_ids: list[str] = Field(..., min_length=2, max_length=5) + rank_by: Literal["revenue_delta","units_delta"] = "revenue_delta" +class ScenarioComparisonRow(BaseModel): # from_attributes + scenario_id, name, units_delta, revenue_delta, coverage_verdict, rank +class MultiScenarioComparison(BaseModel): # from_attributes + baseline_total_units, baseline_revenue + scenarios: list[ScenarioComparisonRow] + chart_series: list[dict[str, ...]] # merged date-keyed rows for Recharts +``` + +**Phase D — `scenarios/models.py` + `schemas.py` provenance/audit additions:** + +```python +# models.py — ScenarioPlan gains (DECISIONS LOCKED #13): +SCENARIO_SOURCE_USER = "user" +SCENARIO_SOURCE_AGENT = "agent" + +source: Mapped[str] = mapped_column( + String(16), nullable=False, server_default=SCENARIO_SOURCE_USER +) +agent_session_id: Mapped[str | None] = mapped_column(String(32), nullable=True) +approved_by: Mapped[str | None] = mapped_column(String(120), nullable=True) +approved_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True +) +approval_decision: Mapped[str | None] = mapped_column(String(16), nullable=True) +# + CheckConstraint("source IN ('user','agent')", name="ck_scenario_plan_source") +# + CheckConstraint( +# "approval_decision IN ('approved','rejected')", +# name="ck_scenario_plan_approval_decision") +# + an index on source for "show me agent-proposed plans" queries. + +# schemas.py: +class SaveScenarioRequest(BaseModel): # ConfigDict(strict=True) + # what the save_scenario agent tool persists once HITL-approved. + name: str + assumptions: ScenarioAssumptions + store_id: int + product_id: int + horizon: int + run_id: str + source: Literal["user", "agent"] = "agent" + agent_session_id: str | None = None # the originating agent session +# ScenarioPlanResponse / ScenarioListItem gain: source, agent_session_id, +# approved_by, approved_at, approval_decision (all from_attributes). 
+``` + +### list of tasks (dependency-ordered) + +```yaml +Task 0 — SETUP: + - Open a GitHub issue "Scenario Simulation — Full Version (#N)"; confirm OPEN. + - git fetch origin && git switch -c feat/scenario-simulation-full-version origin/dev + - GOTCHA: no `scenarios` commit scope (DECISIONS LOCKED #6). + - VALIDATE: gh issue view N --json state -> OPEN + +# ════════ PHASE A — Future Feature-Frame Generator (backend) ════════ + +Task A1 — CREATE app/features/scenarios/feature_frame.py: + - The leakage-safe X_future builder (see Data models above). PURE helpers for + calendar + long-lag + assumption columns; one async build_future_frame that + does the optional Calendar read. + - PINNED constants (DECISIONS LOCKED #10/#11/#12): define module-level + EXOGENOUS_LAGS = (1, 7, 14, 28), HISTORY_TAIL_DAYS = 90, + MAX_COMPARE_SCENARIOS = 5. The exogenous columns built are price_*, promo_*, + is_holiday, and lifecycle, lagged at EXOGENOUS_LAGS. + - GOTCHA #1/#8: long-lag only (assert min(lags) >= horizon); import featureset + SCHEMA value-objects + data_platform ORM only — never FeatureEngineeringService. + - PATTERN: featuresets/service.py _compute_calendar_features / + _compute_lag_features (REPLICATE the leakage-safe logic, do not import). + - VALIDATE: uv run mypy app/features/scenarios/feature_frame.py && uv run pyright app/features/scenarios/ + +Task A2 — CREATE tests/test_feature_frame.py: + - Calendar columns are a pure function of the date; long-lag columns equal the + correct history_tail index; assumption columns apply only inside windows; + matrix shape == (horizon, len(feature_columns)). + - PATTERN: scenarios/tests/test_adjustments.py. 
+ - VALIDATE: uv run pytest -v -m "not integration" app/features/scenarios/tests/test_feature_frame.py + +Task A3 — CREATE tests/test_future_frame_leakage.py (LOAD-BEARING): + - Assert: no feature value for any horizon day reads an observed target + strictly after the forecast origin T (y[T] itself is observed history, the + last history_tail entry); a long-lag column with k >= horizon only ever + indexes history_tail (index <= T); calendar columns ignore the target + entirely; an assumption window before T contributes nothing. + - PATTERN: scenarios/tests/test_leakage.py + featuresets/tests/test_leakage.py. + - GOTCHA: never weaken this test to make a feature pass (AGENTS.md § Safety). + - VALIDATE: uv run pytest -v -m "not integration" app/features/scenarios/tests/test_future_frame_leakage.py + +# ════════ PHASE B — Exogenous-Regressor Model + model-driven simulation ════════ + +Task B1 — MODIFY app/features/forecasting/schemas.py: + - Add RegressionModelConfig (mirror LightGBMModelConfig); add it to the + ModelConfig union. + - VALIDATE: uv run python -c "from app.features.forecasting.schemas import RegressionModelConfig; print('ok')" + +Task B2 — MODIFY app/features/forecasting/models.py: + - Add RegressionForecaster(BaseForecaster) wrapping HistGradientBoostingRegressor; + add "regression" to the ModelType alias; add the model_factory branch. + - GOTCHA #2/#3/#4: X is REQUIRED for fit/predict (raise ValueError on None / + shape mismatch); pass random_state; do NOT impute NaN away. + - PATTERN: the existing forecaster classes (interface), LightGBM branch in + model_factory (the feature-flag shape — but regression needs NO flag). + - VALIDATE: uv run mypy app/features/forecasting/ && uv run pyright app/features/forecasting/ + +Task B3 — CREATE app/features/forecasting/tests/test_regression_forecaster.py: + - fit/predict round-trip on synthetic features; predict rejects None X and a + wrong-shape X; determinism (same random_state -> same output); get/set_params. + - PATTERN: forecasting/tests/test_models.py. 
+ - VALIDATE: uv run pytest -v -m "not integration" app/features/forecasting/tests/test_regression_forecaster.py + +Task B4 — MODIFY app/features/forecasting/service.py: + - Add a regression training path: when config.model_type == "regression", + build HISTORICAL features (replicate the leakage-safe lag/calendar logic, or + a minimal long-lag set matching feature_frame.py), fit RegressionForecaster, + and persist a ModelBundle whose metadata carries feature_columns + + history_tail + history_tail_dates (GOTCHA #5). Predict for a regression + model passes X. + - PINNED (DECISIONS LOCKED #10/#11): the historical feature matrix uses the + SAME exogenous lag offsets EXOGENOUS_LAGS = (1, 7, 14, 28) as the future + frame; metadata `history_tail` / `history_tail_dates` persist the last + HISTORY_TAIL_DAYS = 90 observed target values + dates. The persisted + `feature_columns` order MUST match feature_frame.py column-for-column. + - GOTCHA #14: require enough history; raise ValueError otherwise. + - VALIDATE: uv run mypy app/ && uv run pyright app/ + +Task B5 — CREATE alembic migration _widen_scenario_method_check.py: + - uv run alembic heads (expect 43e35957a248); revision -m "widen scenario + method check"; hand-write upgrade()/downgrade() that drop+recreate the + ck_scenario_plan_method CheckConstraint (upgrade -> IN + ('heuristic','model_exogenous'); downgrade -> IN ('heuristic')). + - GOTCHA #6/#7: down_revision = "43e35957a248" (VERIFY). Postgres needs + op.drop_constraint + op.create_check_constraint. + - VALIDATE: docker compose up -d && uv run alembic upgrade head && uv run alembic downgrade -1 && uv run alembic upgrade head + +Task B6 — MODIFY app/features/scenarios/models.py + schemas.py: + - models.py: widen the ORM CheckConstraint to match the migration; add the + SCENARIO_METHOD_MODEL_EXOGENOUS constant. + - schemas.py: ScenarioComparison.method -> Literal["heuristic","model_exogenous"]; + add MODEL_EXOGENOUS_DISCLAIMER. 
+ - VALIDATE: uv run mypy app/features/scenarios/ + +Task B7 — MODIFY app/features/scenarios/service.py: + - In simulate(): after loading the bundle, branch on bundle.config.model_type. + "regression" -> _simulate_model_exogenous (build the future frame via + feature_frame.build_future_frame from the bundle's feature_columns + + history_tail + the assumptions; call bundle.model.predict(horizon, X); + derive the same ScenarioPoint/aggregate shape; method="model_exogenous", + disclaimer=MODEL_EXOGENOUS_DISCLAIMER). Anything else -> the existing + heuristic branch UNCHANGED. + - GOTCHA: a regression baseline run that lacks feature_columns / history_tail + metadata -> a clear ValueError -> RFC 7807 (never a 500). + - PATTERN: the existing heuristic simulate body (point/aggregate construction). + - VALIDATE: uv run mypy app/ && uv run pyright app/ + +Task B8 — EXTEND tests/test_routes_integration.py + add a model-path test: + - A trained_regression_model fixture (real bundle on disk); simulate with a + regression run_id -> 200, method=="model_exogenous"; empty assumptions -> + scenario ≈ baseline; persisting that comparison -> the CHECK accepts it; a + baseline run_id still -> method=="heuristic". Migration constraint test. + - GOTCHA: never mock the DB; integration needs docker compose up + alembic. + - VALIDATE: docker compose up -d && uv run alembic upgrade head && uv run pytest -v -m integration app/features/scenarios/ + +# ════════ PHASE C — Scenario Library + Multi-Scenario Comparison ════════ + +Task C1 — CREATE alembic migration _add_scenario_library_columns.py: + - uv run alembic heads (expect the Phase B revision id from Task B5); add tags + JSONB (server_default '[]', not null) + cloned_from String(32) nullable + a + GIN index on tags. + - GOTCHA #7: down_revision = the Phase B (Task B5) migration's revision id. 
+ - VALIDATE: docker compose up -d && uv run alembic upgrade head && uv run alembic downgrade -1 && uv run alembic upgrade head + +Task C2 — MODIFY scenarios/models.py + schemas.py: + - models.py: tags + cloned_from columns + GIN index (match the migration). + - schemas.py: CompareScenariosRequest, ScenarioComparisonRow, + MultiScenarioComparison; add tags to ScenarioPlanResponse/ScenarioListItem; + add an optional cloned_from to CreateScenarioRequest. + - VALIDATE: uv run mypy app/features/scenarios/ + +Task C3 — MODIFY scenarios/service.py + routes.py: + - service.py: compare_scenarios (load N plans, rank by the metric, build + chart_series), tag handling on create_plan, a tag filter on list_plans, + clone (a create_plan variant pre-filled from an existing plan). + - routes.py: POST /scenarios/compare; a `tags` query param on GET /scenarios; + a POST /scenarios/{id}/clone (or a cloned_from field on POST /scenarios). + - PATTERN: the MVP routes (404 mapping, RFC 7807). + - VALIDATE: uv run python -c "from app.main import app; assert '/scenarios/compare' in {r.path for r in app.routes}; print('wired')" + +Task C4 — CREATE tests/test_compare_integration.py: + - compare 2-5 saved plans -> ranked rows; <2 or >5 -> 422 (the >5 case + exercises the PINNED MAX_COMPARE_SCENARIOS = 5 cap, DECISIONS LOCKED #12); + a bogus id -> 404; tag filter; clone round-trip. + - VALIDATE: docker compose up -d && uv run pytest -v -m integration app/features/scenarios/tests/test_compare_integration.py + +Task C5 — MODIFY frontend types/hooks/utils: + - types/api.ts: method union; tags; MultiScenarioComparison + rows. + - use-scenarios.ts: useCompareScenarios (mutation); tag filter on useScenarios. + - scenario-utils.ts (+test): mergeMultiScenarioSeries, rankScenarios. 
+ - VALIDATE: cd frontend && pnpm tsc --noEmit && pnpm test --run src/lib/scenario-utils.test.ts + +Task C6 — MODIFY frontend planner page + chart: + - A multi-series chart variant (extend time-series-chart.tsx with a `series` + prop, or a new MultiSeriesChart) — pass M+1 series; a comparison panel on + planner.tsx (multi-select saved plans, run compare, ranked verdict table, + chart). Build via frontend-design + shadcn-ui skills. + - GOTCHA: green tsc ≠ working UI — the view must be dogfooded in a browser + (Task E1). + - VALIDATE: cd frontend && pnpm tsc --noEmit && pnpm lint + +# ════════ PHASE D — Agent-Proposed Scenarios + Approval Flow ════════ + +Task D1 — CREATE alembic migration _add_scenario_provenance_columns.py: + - uv run alembic heads (expect the Phase C revision id from Task C1); revision + -m "add scenario provenance + columns". upgrade() adds five columns to scenario_plan: source String(16) + NOT NULL server_default 'user'; agent_session_id String(32) nullable; + approved_by String(120) nullable; approved_at DateTime(timezone=True) + nullable; approval_decision String(16) nullable. Add CHECK constraints + ck_scenario_plan_source ("source IN ('user','agent')") and + ck_scenario_plan_approval_decision ("approval_decision IN + ('approved','rejected')"), and an index on source. downgrade() drops the + index, the two constraints, and the five columns. + - GOTCHA #6/#7: down_revision = the Phase C (Task C1) migration's revision id + (VERIFY with alembic heads). The + server_default 'user' makes existing rows backward-compatible. + - VALIDATE: docker compose up -d && uv run alembic upgrade head && uv run alembic downgrade -1 && uv run alembic upgrade head + +Task D2 — MODIFY app/features/scenarios/models.py + schemas.py + service.py: + - models.py: add source / agent_session_id / approved_by / approved_at / + approval_decision columns + the two CHECK constraints + the source index + (match the migration). Add SCENARIO_SOURCE_USER / SCENARIO_SOURCE_AGENT. 
+ - schemas.py: add SaveScenarioRequest; add source/agent_session_id/ + approved_by/approved_at/approval_decision to ScenarioPlanResponse + + ScenarioListItem. + - service.py: extend create_plan to accept the provenance fields (defaulting + source='user', audit columns None — keeps the MVP create path + backward-compatible); add the approval-decision write path that stamps + approved_by/approved_at/approval_decision when an agent save is approved. + - GOTCHA #9: discrete columns, never a `metadata`-named column. + - VALIDATE: uv run mypy app/features/scenarios/ + +Task D3 — CREATE app/features/scenarios/agent_tools.py (TWO tools): + - propose_scenario(db, store_id, product_id, horizon, objective) -> a candidate + ScenarioAssumptions + a plain-language recommendation. READ-ONLY — proposes, + never persists (GOTCHA #10). + - save_scenario(db, request: SaveScenarioRequest, *, agent_session_id) -> + persists a scenario_plan row via the scenarios service create path, stamping + source='agent', the agent_session_id, and the approval audit trail. This is + the MUTATING tool — it runs only after the HITL gate releases it. + - PATTERN: agents/tools/forecasting_tools.py (read-only tool shape); + agents/tools/registry_tools.py create_alias (the mutating, approval-gated + tool shape). + - GOTCHA #8: import scenarios *schemas* + the service create path through this + module — agent_tools.py is the seam; agents/ imports agent_tools.py, never + scenarios/service.py directly. + - VALIDATE: uv run mypy app/features/scenarios/agent_tools.py + +Task D4 — MODIFY app/core/config.py + agents/agents/base.py: + - config.py: add "save_scenario" to agent_require_approval. Current value is + ["create_alias", "archive_run"] (verified, app/core/config.py:164) -> + ["create_alias", "archive_run", "save_scenario"]. This is a deliberate + mutation-surface widening — DECISIONS LOCKED #13; flag it in the PR. + - base.py: add tool_propose_scenario + tool_save_scenario to + TOOL_USAGE_INSTRUCTIONS. 
+ - GOTCHA #12: preserve line endings. + - VALIDATE: uv run python -c "from app.core.config import get_settings; print(get_settings().agent_require_approval)" + +Task D5 — MODIFY app/features/agents/agents/experiment.py: + - Register @agent.tool @recoverable tool_propose_scenario (read-only — calls + scenarios.agent_tools.propose_scenario; no approval gate). + - Register @agent.tool @recoverable tool_save_scenario, gated EXACTLY like + tool_create_alias: a requires_approval("save_scenario") check + a + {"status":"approval_required", ...} early return; the persist happens only + once the approval is granted. + - GOTCHA #8/#10: do NOT import scenarios/service.py into agents — the tool + functions in scenarios/agent_tools.py are the seam; agents imports that module. + - VALIDATE: uv run mypy app/ && uv run pyright app/ + +Task D6 — CREATE tests/test_agent_tools.py: + - propose_scenario returns a valid ScenarioAssumptions + a non-empty + recommendation; it performs NO DB writes. + - save_scenario, when approved, persists a scenario_plan row with + source='agent', the agent_session_id, and the audit columns populated. + - An integration test asserts the HITL gate fires: a save_scenario call on the + experiment agent returns {"status":"approval_required"} and writes no row + until /agents/sessions/{id}/approve is called. + - PATTERN: agents/tests/ tool tests; the create_alias HITL test. 
+ - VALIDATE: docker compose up -d && uv run pytest -v -m integration app/features/scenarios/tests/test_agent_tools.py + +# ════════ CROSS-PHASE — docs + dogfood + PR ════════ + +Task E1 — Dogfood (MANDATORY per .claude/rules/ui-design.md): + - docker compose up -d && alembic upgrade head && seed; train a regression + model; exercise /visualize/planner via webapp-testing / agent-browser: + model-causal simulation (confirm method="model_exogenous" in the response), + save 2-3 plans, run a multi-scenario comparison, confirm the chart + ranked + table; in the chat agent ask for scenario suggestions and confirm the + proposal renders. Capture screenshots. + - VALIDATE: screenshots captured; all manual checks pass. + +Task E2 — UPDATE docs: + - README.md; docs/_base/API_CONTRACTS.md (+/scenarios/compare row, the new + method value, the regression model_type); REPO_MAP_INDEX.md + (feature_frame.py, agent_tools.py); DOMAIN_MODEL.md (the model_exogenous + method, tags/cloned_from + the source/agent_session_id/approved_by/ + approved_at/approval_decision provenance-audit columns on the scenario_plan + aggregate, the future-feature-frame concept, the agent save_scenario HITL + invariant, + ubiquitous-language rows). docs/_base/SECURITY.md — note + save_scenario added to agent_require_approval (the HITL-gated tool list). + - VALIDATE: git diff --stat docs/ README.md + +Task E3 — Commit + PR (one PR per phase preferred): + - Commits (each (#issue), no AI co-author trailer), e.g.: + feat(forecast): add exogenous-regressor forecaster (#N) + feat(api): add leakage-safe future feature-frame generator (#N) + feat(api,db): add model-driven scenario simulation path (#N) + feat(api,db): add scenario library and multi-scenario comparison (#N) + feat(ui): add multi-scenario comparison view (#N) + feat(api,db): add scenario provenance and audit columns (#N) + feat(agents): add agent-proposed and HITL-gated save scenario tools (#N) + test(...) / docs(docs): ... 
+ - GOTCHA: the PR description MUST (a) cite the no-cross-slice-service-import + decision (DECISIONS LOCKED #3); (b) flag the new leakage surface and point + at test_future_frame_leakage.py as its spec; (c) state HistGradientBoosting + over LightGBM and why (no new dependency); (d) answer the 6 product-vision + Litmus questions; (e) note the model path is additive — the heuristic path + stays; (f) explicitly call out the agent mutation-surface widening — the + `save_scenario` tool added to `agent_require_approval` (DECISIONS LOCKED + #13) — per AGENTS.md § Safety "Stop and ask before widening an agent's + mutation surface". + - VALIDATE: open PR(s) into dev; CI green; merge. +``` + +### Per-task pseudocode (critical details only) + +```python +# Task A1 — the long-lag column builder (the leakage-critical helper) +def build_long_lag_columns(history_tail, dates, lags, horizon): + # history_tail[-1] is the observed target at the forecast origin T. + # history_tail holds HISTORY_TAIL_DAYS = 90 values (DECISIONS LOCKED #11). + # `lags` is the subset of EXOGENOUS_LAGS = (1, 7, 14, 28) with k >= horizon. + # For horizon day T+j (j in 1..horizon) and lag k, the value is y[T+j-k]. + # SAFETY: require k >= horizon so j-k <= 0, i.e. the index lands in history. + assert min(lags) >= horizon, "long-lag frame forbids k < horizon (DECISIONS #4)" + columns = {} + for k in lags: + col = [] + for j in range(1, horizon + 1): + # offset back from the END of history_tail: index = -k + (j-1) ... 
<= -1 + idx = -k + (j - 1) + col.append(history_tail[idx] if -len(history_tail) <= idx < 0 else float("nan")) + columns[f"lag_{k}"] = col + return columns + +# Task B7 — the model-exogenous simulate branch +async def _simulate_model_exogenous(self, db, request, bundle) -> ScenarioComparison: + meta = bundle.metadata + feature_columns = meta.get("feature_columns") + history_tail = meta.get("history_tail") + if not feature_columns or not history_tail: + raise ValueError( + f"run_id '{request.run_id}' is a regression model without the " + "feature metadata required for a scenario forecast." + ) + origin = self._forecast_start_date(meta.get("train_end_date")) - timedelta(days=1) + frame = await build_future_frame( + db, store_id=..., product_id=..., forecast_origin=origin, + horizon=request.horizon, feature_columns=feature_columns, + history_tail=history_tail, assumptions=request.assumptions) + X = np.array(frame.matrix, dtype=np.float64) + scenario_values = [float(v) for v in bundle.model.predict(request.horizon, X)] + # baseline = predict with the SAME frame but assumptions stripped (empty) + baseline_frame = await build_future_frame(..., assumptions=ScenarioAssumptions()) + baseline_values = [float(v) for v in bundle.model.predict( + request.horizon, np.array(baseline_frame.matrix, dtype=np.float64))] + # ... build ScenarioPoint list + aggregates exactly like the heuristic path, + # guarding units_delta_pct /0; method="model_exogenous", + # disclaimer=MODEL_EXOGENOUS_DISCLAIMER. 
+``` + +### Integration Points + +```yaml +DATABASE: + - migration B: "drop+recreate ck_scenario_plan_method -> + IN ('heuristic','model_exogenous')" + - migration C: "add scenario_plan.tags JSONB (default '[]') + + cloned_from String(32) nullable + GIN index on tags" + - migration D: "add scenario_plan.source String(16) NOT NULL default 'user' + + agent_session_id String(32) + approved_by String(120) + + approved_at DateTime(tz) + approval_decision String(16); + CHECK source IN ('user','agent'); CHECK approval_decision IN + ('approved','rejected'); index on source" + - down_revision: "B off 43e35957a248; C off the Phase B revision; D off the + Phase C revision — VERIFY with alembic heads before each" + +ROUTES: + - POST /scenarios/compare added to the existing scenarios router; a `tags` + query param on GET /scenarios; a clone path. (Router already wired in + app/main.py — no main.py change.) + +FORECASTING: + - model_factory gains a "regression" branch (no feature flag — unlike lightgbm). + - ModelType / ModelConfig union gain the regression member. + +AGENTS: + - tool_propose_scenario (read-only) AND tool_save_scenario (mutating, + HITL-gated) on the experiment agent; TOOL_USAGE_INSTRUCTIONS updated for + both; agent_require_approval gains "save_scenario". + +CONFIG: + - app/core/config.py: agent_require_approval gains "save_scenario" + (["create_alias","archive_run"] -> [...,"save_scenario"]) — a deliberate + mutation-surface widening (DECISIONS LOCKED #13). + - no other new setting; reuses forecast_model_artifacts_dir, + forecast_random_seed. forecast_enable_lightgbm stays unused by this PRP. + +FRONTEND ROUTING: + - no new route — the comparison view extends the existing /visualize/planner + page (a panel or tab). +``` + +--- + +## Validation Loop + +### Level 1: Syntax & Style + +```bash +uv run ruff check . --fix && uv run ruff format --check . +cd frontend && pnpm lint +# Traps: date.today()/naive datetime -> ruff DTZ (use datetime.now(UTC)); +# os.path -> ruff PTH; a stray # noqa -> RUF100. 
+``` + +### Level 2: Type Checks + +```bash +uv run mypy app/ && uv run pyright app/ # both --strict, both gate merge +cd frontend && pnpm tsc --noEmit +``` + +### Level 3: Unit Tests + +```bash +uv run pytest -v -m "not integration" app/features/scenarios/ app/features/forecasting/ +cd frontend && pnpm test --run src/lib/scenario-utils.test.ts +# MUST include the un-weakened MVP leakage spec + the new future-frame leakage spec. +``` + +### Level 4: Integration Tests + Migrations + +```bash +docker compose up -d && uv run alembic upgrade head +uv run pytest -v -m integration app/features/scenarios/ +uv run alembic downgrade -3 && uv run alembic upgrade head # all 3 new migrations up/down +``` + +### Level 5: Manual Validation (dogfood — REQUIRED) + +```bash +docker compose up -d && uv run alembic upgrade head +uv run python scripts/seed_random.py --full-new --seed 42 --confirm +# train a regression model: +curl -s -X POST http://localhost:8123/forecasting/train -H 'content-type: application/json' \ + -d '{"store_id":1,"product_id":101,"train_start_date":"2025-01-01", + "train_end_date":"2026-04-30","config":{"model_type":"regression"}}' +# model-causal simulate (replace RUN_ID with the run_id from the train response): +curl -s -X POST http://localhost:8123/scenarios/simulate -H 'content-type: application/json' \ + -d '{"run_id":"RUN_ID","horizon":14, + "assumptions":{"price":{"change_pct":-0.15, + "start_date":"2026-05-02","end_date":"2026-05-15"}}}' | grep -o '"method":"[a-z_]*"' +# -> expect "method":"model_exogenous" +# bogus run_id -> 404, not 500: +curl -s -o /dev/null -w '%{http_code}\n' -X POST http://localhost:8123/scenarios/simulate \ + -H 'content-type: application/json' -d '{"run_id":"nope","horizon":14,"assumptions":{}}' +# multi-scenario compare (after saving >=2 plans; replace PLAN_ID_1 / PLAN_ID_2 with saved scenario ids): +curl -s -X POST http://localhost:8123/scenarios/compare -H 'content-type: application/json' \ + -d '{"scenario_ids":["PLAN_ID_1","PLAN_ID_2"],"rank_by":"revenue_delta"}' | head -c 400 +# Frontend: cd frontend && 
./node_modules/.bin/vite --host 0.0.0.0 +# -> /visualize/planner via webapp-testing/agent-browser: run a model-causal +# sim, save 2-3 plans, run the comparison view, confirm the multi-series +# chart + ranked table; chat agent -> ask for scenario suggestions. +``` + +--- + +## Final Validation Checklist + +- [ ] `uv run ruff check . && uv run ruff format --check .` — clean +- [ ] `uv run mypy app/ && uv run pyright app/` — clean (`--strict`) +- [ ] `uv run pytest -v -m "not integration"` — green (MVP leakage spec + UNWEAKENED + new future-frame leakage spec passing) +- [ ] `docker compose up -d && uv run pytest -v -m integration` — green +- [ ] All three new migrations (B method-CHECK widen, C library columns, + D provenance/audit columns) upgrade **and** downgrade cleanly on a fresh DB +- [ ] `cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run` — green +- [ ] A `regression` model trains and persists feature_columns + history_tail +- [ ] `POST /scenarios/simulate` returns `method="model_exogenous"` for a + regression baseline, `method="heuristic"` for a baseline forecaster; a + bogus run_id → RFC 7807, not 500 +- [ ] `POST /scenarios/compare` ranks 2-5 saved plans; tag filter + clone work +- [ ] The planner comparison view renders a multi-series chart + ranked table — + dogfooded in a browser (screenshots captured) +- [ ] `tool_propose_scenario` works and persists nothing; `tool_save_scenario` + is in `agent_require_approval`, pauses for HITL approval, and only then + writes a `scenario_plan` row carrying `source="agent"`, `agent_session_id`, + and the approval audit trail (`approved_by`/`approved_at`/`approval_decision`) +- [ ] No new external dependency (LightGBM NOT added unless the stop-and-ask was + explicitly approved); no WebSocket; no cross-slice `service.py` import +- [ ] README + `docs/_base/{API_CONTRACTS,REPO_MAP_INDEX,DOMAIN_MODEL}.md` + updated +- [ ] Branch `feat/scenario-simulation-full-version`; every commit references + the issue; 
scopes are `forecast`/`api`/`api,db`/`agents`/`ui`/`docs`; no + AI co-author trailer +- [ ] PR description cites DECISIONS LOCKED #3, flags the new leakage surface + + its spec, states the HistGradientBoosting-over-LightGBM choice, and + answers the 6 Litmus questions + +--- + +## Anti-Patterns to Avoid + +- ❌ Don't re-build the MVP slice — `app/features/scenarios/` already exists. +- ❌ Don't delete `adjustments.py` or the heuristic `simulate` branch — the + model path is ADDITIVE (DECISIONS LOCKED #1). +- ❌ Don't build a future feature frame that uses lags `k < horizon` — that + needs recursion and a far harder leakage proof (DECISIONS LOCKED #4). +- ❌ Don't import `FeatureEngineeringService` / `FeatureDataLoader` / + `ForecastingService` into `scenarios` — cross-slice service import + (DECISIONS LOCKED #3). Import schema value-objects + ORM models; replicate. +- ❌ Don't weaken `scenarios/tests/test_leakage.py` or skip + `test_future_frame_leakage.py` — they are the leakage spec. +- ❌ Don't add LightGBM to `pyproject.toml` without an explicit stop-and-ask — + `HistGradientBoostingRegressor` is the no-new-dependency default. +- ❌ Don't impute the `NaN` out of lag features with a series mean — that + itself leaks; `HistGradientBoostingRegressor` handles `NaN` natively. +- ❌ Don't edit the merged migration `43e35957a248` — add new forward-only + migrations. +- ❌ Don't persist `"model_exogenous"` before the CHECK-widening migration + ships in the same PR. +- ❌ Don't let the agent persist a scenario without the HITL gate — + `propose_scenario` is read-only; `save_scenario` is mutating and MUST be in + `agent_require_approval` and gated exactly like `tool_create_alias`. +- ❌ Don't fold the provenance/audit fields into the `assumptions`/`comparison` + JSONB blobs — they are discrete, queryable columns (DECISIONS LOCKED #13). +- ❌ Don't add a WebSocket — simulation and comparison are request/response. 
+- ❌ Don't hand-roll a chart — extend `TimeSeriesChart` / add a multi-series + variant. +- ❌ Don't claim the UI works on a green type-check — dogfood it in a browser. +- ❌ Don't invent a `scenarios` commit scope. + +## Vision Tensions — flag these in the PR + +1. **Exogenous ML vs "not a generic ML platform" (`product-vision.md`).** The + Full Version adds a feature-driven ML model. This is *aligned* — it stays + retail-demand-specific (one model class, retail features, no + classification/NLP/vision) and single-host (`HistGradientBoostingRegressor` + ships with `scikit-learn`, already a dependency; nothing managed-cloud). + Note this reasoning explicitly in the PR. +2. **LightGBM is a deliberate non-goal of this PRP.** `forecast_enable_lightgbm` + and `LightGBMModelConfig` exist but `model_factory` raises + `NotImplementedError` and LightGBM is not in `pyproject.toml`. Adding it is a + separate change — and adding any new dependency is a **stop-and-ask** gate + (`AGENTS.md` § Safety: "Bumping … major versions"; a new core dependency + warrants the same pause). If a reviewer wants LightGBM, that is its own + issue + PR. +3. **Over-trust of revenue claims (the brief's "Risks").** A model-driven number + reads as more authoritative than a heuristic one. Mitigation: every + `ScenarioComparison` still declares its `method` and carries a + method-appropriate `disclaimer` — `MODEL_EXOGENOUS_DISCLAIMER` states the + result is an estimate under uncertainty, not a guarantee (NIST AI RMF + transparency control). +4. **Agent mutation-surface widening (`AGENTS.md` § Safety).** Phase D's + `save_scenario` tool lets the agent write a `scenario_plan` row — a new + mutation. `AGENTS.md` § Safety lists "widening an agent's mutation surface + without adding the tool name to `agent_require_approval`" as a stop-and-ask + gate. 
This PRP is that approval: the maintainer resolved OQ4 to allow the + persist tool, and the design correctly (a) adds `save_scenario` to + `agent_require_approval` so every agent save pauses for explicit human + approval, and (b) records a full provenance/audit trail on the persisted + row. The PR description MUST call this widening out explicitly so a reviewer + sees the gate was satisfied deliberately, not by omission. + +## Open Questions — ALL RESOLVED + +The maintainer resolved every Open Question during planning. They are recorded +here as DECISION LOCKED entries (and cross-referenced into the DECISIONS LOCKED +section above) — there is nothing left to confirm before coding. + +- **DECISION LOCKED — Regression feature set composition.** Exogenous-feature + lag offsets are PINNED to `EXOGENOUS_LAGS = (1, 7, 14, 28)` days; the + exogenous columns are `price_*`, `promo_*`, `is_holiday`, and lifecycle. Both + the historical feature matrix (Task B4) and the future feature frame + (`feature_frame.py`) use this exact set, so the trained bundle's + `feature_columns` matches column-for-column. See DECISIONS LOCKED #10. +- **DECISION LOCKED — `history_tail` length.** PINNED to + `HISTORY_TAIL_DAYS = 90` days — the last 90 observed target values ending at + the forecast origin, comfortably exceeding the largest lag offset (28). See + DECISIONS LOCKED #11. +- **DECISION LOCKED — Multi-scenario comparison cap.** PINNED to + `MAX_COMPARE_SCENARIOS = 5`. `POST /scenarios/compare` accepts 2–5 + `scenario_id`s, enforced via `Field(..., min_length=2, max_length=5)`. See + DECISIONS LOCKED #12. +- **DECISION LOCKED (OQ4) — Agent persist tool.** The agent gets BOTH a + read-only `propose_scenario` tool AND a mutating `save_scenario` tool. The + `save_scenario` tool persists a `scenario_plan` row only after the human + approves via the existing HITL gate — it is added to `agent_require_approval` + and gated exactly like `tool_create_alias`. 
The persisted row carries + author/source provenance (`source`, `agent_session_id`) and an approval + audit trail (`approved_by`, `approved_at`, `approval_decision`); Phase D + ships an Alembic migration adding those columns. See DECISIONS LOCKED #13. + +## Confidence Score + +**7 / 10** for one-pass implementation success. + +Rationale: the PRP is grounded in a fully-read MVP and verified repo facts (the +Alembic head, the un-implemented `lightgbm` factory branch, the missing LightGBM +dependency, the `BaseForecaster` interface, the agent HITL pattern). The two +highest-risk pieces are de-risked by locked decisions: (a) the future-feature- +frame leakage surface is bounded to "long-lag + calendar + exogenous" so the +proof is a direct assertion and no recursion is needed; (b) exogenous-regressor +support uses `HistGradientBoostingRegressor` — already a dependency — so there +is no `pyproject.toml` change and no stop-and-ask gate on the critical path. +Phase C and D are near-mechanical (CRUD + a chart variant + two agent tools that +mirror `tool_create_alias`). The score is 7 rather than 9 because the Full +Version is genuinely large — four phases, three migrations, a new model class, a +new leakage spec, a frontend comparison view, and an HITL-gated agent persist +tool — and the regression *training* path (Task B4) requires building a +historical feature matrix that the future frame must exactly mirror; a +column-order or lag-offset mismatch between the two is the most likely one-pass +failure. 
That risk is now substantially reduced because the maintainer pinned +the three modelling defaults — the exogenous lag offsets `(1, 7, 14, 28)`, the +90-day `history_tail`, and the 5-scenario comparison cap — so `feature_frame.py` +and the Task B4 training path build against the same fixed constants rather than +an implementer's guess; the residual risk is mechanical column-order discipline, +mitigated by persisting `feature_columns` in the bundle metadata and asserting +it on both sides. The PRP explicitly recommends shipping Phase A + B first and +deferring C + D if scope pressure appears — splitting reduces per-PR risk +substantially. diff --git a/PRPs/PRP-28-forecast-explainability-driver-attribution.md b/PRPs/PRP-28-forecast-explainability-driver-attribution.md new file mode 100644 index 00000000..79aab65a --- /dev/null +++ b/PRPs/PRP-28-forecast-explainability-driver-attribution.md @@ -0,0 +1,1092 @@ +name: "PRP-28 — Forecast Explainability & Driver Attribution" +description: | + A new `explainability` vertical slice that produces structured, rule-based + explanations for forecast and registry-run outcomes. Decomposes each of the + three baseline forecasters (`naive`, `seasonal_naive`, `moving_average`) into + named, interpretable demand drivers, layers advisory retail "reason codes" + read time-safely from the data-platform fact tables, and surfaces an + agent-readable summary plus confidence band and caveats. MVP is rule-based + only — SHAP is deliberately excluded. + +--- + +## Goal + +Ship `app/features/explainability/` — a read-only vertical slice that, given a +baseline model + the series it was trained on, computes a `ForecastExplanation`: +an ordered list of `DriverContribution` objects (name, feature value, +contribution, direction), a list of `ReasonCode` objects (advisory retail +signals), a `confidence` band, `caveats`, and an `agent_summary` string. 
+ +End state: + +- Three HTTP endpoints under a self-owned `/explain` namespace: explain a + completed `predict` job, explain a registry run, and an ad-hoc + `POST /explain/forecast`. +- A pure, deterministic rule-based explainer registry keyed by `model_type` + (`NaiveExplainer`, `SeasonalNaiveExplainer`, `MovingAverageExplainer`). +- A reason-code engine reading `inventory_snapshot_daily`, `promotion`, + `product.launch_date`, and `calendar` — never causal claims. +- A `forecast_explanation` persistence table + Alembic migration. +- Frontend explanation panels on the run-detail page and the forecast page. +- The slice imports **only** from `app/core/`, `app/shared/`, and reads + data-platform / registry / jobs ORM models as read-only data contracts — it + imports no other slice's `service.py`. + +## Why + +- **Business value** — ForecastLabAI trains models and reports point forecasts + plus aggregate error metrics (MAE/sMAPE/WAPE/bias), but the *reasoning* is + opaque. A demand planner sees a number on `/visualize/forecast` or a WAPE on + the run-detail page with no narrative connecting the backend's time-safe + signals to the forecast. +- **User impact** — planners cannot tell whether a high forecast reflects + genuine demand or stockout-suppressed history; they cannot compare two models + on *behaviour*, only on a single error scalar; the chat agents cannot cite + feature-level reasons when recommending a model. +- **Integration** — the slice composes existing surfaces (forecasting, registry, + jobs, data platform) without owning any of them; it re-loads the series + itself and re-fits a baseline explainer from the stored `model_config` JSONB. +- **Honest by construction** — a naive forecaster's "explanation" *is* its last + observation; there is no inference gap. Rule-based explainers are exact, not + approximate. Retail reason codes are explicitly labelled as correlation, not + causation (grounded in NIST AI RMF guidance). 
+ +## What + +User-visible behaviour: + +- `POST /explain/forecast` — given `{store_id, product_id, model_type, as_of_date, + season_length?, window_size?}`, returns a `ForecastExplanation` for the h=1 + forecast that the named baseline model would produce on the series ending at + `as_of_date`. +- `GET /explain/runs/{run_id}` — explains a registry `model_run`; reconstructs + the baseline config from `model_run.model_config` JSONB and uses + `data_window_end` as the cutoff. +- `GET /explain/jobs/{job_id}` — explains a completed `predict` job; pulls + `store_id`/`product_id`/`model_type`/`horizon`/`forecasts` from `job.result`. +- Run-detail page gains a "Forecast Explanation" card; forecast page gains an + explanation panel below the chart. + +Technical requirements: + +- New vertical slice, RFC 7807 errors, Pydantic v2 + SQLAlchemy 2.0 async. +- `mypy --strict` + `pyright --strict` + `ruff` clean. +- Alembic migration chaining to the current head. +- Every new module/endpoint/model/migration ships a matching test; new + endpoints get a 2xx happy path + ≥1 error path. +- Time-safety: every series load and every reason-code DB query bounded + `<= as_of_date` / `<= data_window_end`. No future data, ever. + +### Success Criteria + +- [ ] New `app/features/explainability/` slice with + `models/schemas/service/routes/explainers/reason_codes/tests` — no import + from another `app/features/{slice}/service.py`. +- [ ] 3 endpoints live: `GET /explain/runs/{run_id}`, `GET /explain/jobs/{job_id}`, + `POST /explain/forecast` — registered in `app/main.py`. +- [ ] Rule-based explainers for `naive`, `seasonal_naive`, `moving_average`; + each reproduces the corresponding forecaster's h=1 value **exactly**. +- [ ] `ForecastExplanation` returns `drivers`, `reason_codes`, `confidence`, + `caveats`, `agent_summary`. +- [ ] Retail reason codes (stockout, promotion, lifecycle, holiday, + insufficient-history) computed time-safely (`<= as_of_date`). 
+- [ ] `lightgbm` runs return a clean 400 (MVP scope guard); SHAP is NOT added to + `pyproject.toml`. +- [ ] New Alembic migration chains to head **`43e35957a248`**, applies and rolls + back cleanly. +- [ ] All errors RFC 7807; all request schemas Pydantic v2; the `date` field on + the `strict=True` request body has `Field(strict=False, ...)`. +- [ ] Frontend explanation panel on run-detail + forecast pages; + `pnpm tsc --noEmit` clean. +- [ ] `ruff` + `mypy --strict` + `pyright --strict` + unit + integration tests + pass. + +## All Needed Context + +### Documentation & References + +```yaml +# MUST READ - external docs +- url: https://fastapi.tiangolo.com/tutorial/bigger-applications/#apirouter + why: APIRouter wiring, path operations, Depends — new slice router mirrors existing slices. + +- url: https://docs.pydantic.dev/latest/concepts/models/ + section: model_config, ConfigDict, field_validator + why: Response/request schema construction. + +- url: https://docs.pydantic.dev/latest/concepts/strict_mode/ + section: strict mode, per-field overrides + critical: ExplainForecastRequest has an `as_of_date` (date) field on a + ConfigDict(strict=True) body — it MUST carry Field(strict=False). + Without it every HTTP caller 422s on the ISO-string JSON path. + +- url: https://docs.sqlalchemy.org/en/20/orm/declarative_tables.html + section: declarative mapping, JSONB columns, indexes + why: ForecastExplanation ORM model + GIN-index pattern. + +- url: https://alembic.sqlalchemy.org/en/latest/ops.html#alembic.operations.Operations.create_table + section: create_table, postgresql.JSONB, create_index + critical: down_revision MUST be "43e35957a248" (current head — verified via + `uv run alembic heads`). + +- url: https://www.nist.gov/itl/ai-risk-management-framework + why: grounding for the "drivers describe correlation/contribution, not + business causality" caveat baked into every explanation. 
+ +- url: https://shap.readthedocs.io/en/stable/generated/shap.TreeExplainer.html + why: REFERENCE ONLY — do NOT implement. Read so the MVP `method` field and + DriverContribution shape stay forward-compatible with SHAP output. + +# MUST READ - codebase files (verified 2026-05-19 against HEAD on feat/scenarios-what-if-planning) +- file: app/features/forecasting/models.py + why: | + NaiveForecaster (class L138-220), SeasonalNaiveForecaster (L223-323), + MovingAverageForecaster (L326-422), model_factory (L429-476). The + explainers MIRROR this exact math (see "Forecaster math — verified" below). + `ModelType` alias at L426. + +- file: app/features/forecasting/schemas.py + why: | + ModelConfigBase (L22-49: ConfigDict(frozen=True, extra="forbid"), + schema_version, config_hash()), NaiveModelConfig (L52-62), + SeasonalNaiveModelConfig (L65-83: season_length default 7, ge=1, le=365), + MovingAverageModelConfig (L86-104: window_size default 7, ge=1, le=90), + LightGBMModelConfig (L107-144), ModelConfig union (L148-150). + TrainRequest (L158-198) is the canonical strict=True body + Field(strict=False) + date + field_validator pattern to copy. + +- file: app/features/forecasting/service.py + why: | + _load_training_data (L314-367) — the EXACT query to load a + (store_id, product_id, date-range) series. Mirror in + ExplainabilityService._load_series. Do NOT import ForecastingService. + +- file: app/features/registry/models.py + why: | + ModelRun ORM model (class L51, __tablename__ L80). JSONB columns: + model_config L88, feature_config L89, metrics L99, config_hash L90, + data_window_start/end L93-94, store_id/product_id L95-96. __table_args__ + (L125-143) — GIN index + CheckConstraint pattern to copy. TimestampMixin used. + +- file: app/features/registry/service.py + why: | + get_run (L247-265): `select(ModelRun).where(ModelRun.run_id == run_id)` → + `.scalar_one_or_none()`. 
_model_to_response (L658+) documents the alias + quirk — ORM attr `model_config` vs schema field `model_config_data`. + +- file: app/features/jobs/models.py + why: | + Job ORM model (class L68, __tablename__ L88). job_id L91, job_type L92, + status L93, params L96, result L99 (JSONB nullable), run_id L112. + JobType enum L29, JobStatus enum L43. __table_args__ L114-128. + +- file: app/features/jobs/service.py + why: | + _execute_predict (L498-574) — a completed predict job's `result` dict has the + keys {"store_id", "product_id", "model_type", "horizon", "duration_ms", + "forecasts": [{"date": iso, "forecast": float, "lower_bound", "upper_bound"}]}. + The explainer uses store_id/product_id/model_type/horizon/forecasts; duration_ms + is ignored. + +- file: app/features/scenarios/routes.py + why: | + The most recent slice (PRP-26) — canonical router style: APIRouter(prefix=..., + tags=[...]), Depends(get_db), rich summary/description, SQLAlchemyError → + DatabaseError, service-layer ValueError/FileNotFoundError → RFC 7807. + +- file: app/features/scenarios/models.py + why: | + Most recent ORM model — JSONB columns, CheckConstraint, TimestampMixin, + Base import. GOTCHA documented: SQLAlchemy reserves the attr name `metadata`. + +- file: alembic/versions/43e35957a248_create_scenario_plan_table.py + why: | + Migration TEMPLATE — `revision = "43e35957a248"`, `down_revision = "378c112e4b32"`. + The NEW migration's `down_revision` MUST be "43e35957a248" (this file is the + current head). Shows op.create_table + postgresql.JSONB + server_default now(). + +- file: app/features/registry/routes.py + why: | + get_run (L200-232): canonical GET /{id} → service → None-check → + `HTTPException(status.HTTP_404_NOT_FOUND, detail=...)`. + +- file: app/core/exceptions.py + why: | + NotFoundError (L63-83, status 404), ValidationError (L85), DatabaseError + (L107-127), BadRequestError (L151-171, status 400). All RFC 7807. Handlers + already registered in main.py. 
NEVER bare `raise HTTPException(500, "string")`. + +- file: app/core/database.py + why: | + get_db (L43-53) — auto-commits on success. Therefore the service should + `flush`/`refresh`, NOT `commit`. + +- file: app/main.py + why: | + Import block L15-31, include_router block L133-149. The new + explainability_router registers here (after forecasting/backtesting/registry). + +- file: app/features/data_platform/models.py + why: | + Verified column names: SalesDaily (date L199, store_id L200, product_id L201, + quantity L202, unit_price L203). InventorySnapshotDaily (date L364, + store_id L365, product_id L366, on_hand_qty L367, is_stockout L369). + Promotion (product_id L300, store_id L301, kind L305, start_date L315, + end_date L316). Product (launch_date L101). Calendar (date L146 PK, + is_holiday L151, holiday_name L152). + +- file: app/features/featuresets/tests/test_leakage.py + why: | + The LOAD-BEARING time-safety spec. Internalise the leakage rule before + writing the series load and reason-code queries — bound everything + `<= as_of_date`. + +- file: app/features/forecasting/tests/test_service.py + why: | + UNIT service-test pattern — class-grouped tests, numpy fixtures, + AsyncMock/MagicMock DB. Mirror this for explainability's unit test_service.py. + +- file: app/features/forecasting/tests/test_routes.py + why: | + UNIT route-test pattern — httpx AsyncClient(ASGITransport) with the test-DB + dependency override, 2xx happy path + error paths, RFC 7807 body assertions. + Mirror this for explainability's unit test_routes.py. + +- file: app/features/scenarios/tests/test_routes_integration.py + why: | + @pytest.mark.integration route-test pattern — httpx AsyncClient fixture, + happy path + 404 path, idempotent (no pre-seed assumptions). NOTE: the + scenarios slice ships ONLY an integration route test — there is no unit + test_routes.py / test_service.py there; mirror the forecasting slice for + the unit-level patterns. 
+ +- file: frontend/src/hooks/use-runs.ts + why: TanStack Query hook pattern — useQuery/useMutation, queryKey, api(), enabled. + +- file: frontend/src/types/api.ts + why: | + ModelRun (L173), ForecastPoint (L102), Job (L234), ScenarioComparison + (L775) interfaces — snake_case field names mirroring the Pydantic schemas. + +- file: frontend/src/pages/explorer/run-detail.tsx + why: Card / CardHeader / CardContent composition; gets a new Explanation card. + +- file: frontend/src/pages/visualize/forecast.tsx + why: useJob, in-page predict-job results, EmptyState; gets an explanation panel. + +- doc: .claude/rules/ui-design.md + .claude/rules/shadcn-ui.md + section: frontend toolchain, shadcn skill + MCP + critical: Use the shadcn skill/MCP for any NEW shadcn component; reuse + already-installed card/badge/table first; verify in a real browser. +``` + +### Current Codebase tree (relevant subset) + +```bash +app/ + core/ # config, database, exceptions, logging, problem_details + shared/ + models.py # TimestampMixin + features/ + data_platform/models.py # SalesDaily, InventorySnapshotDaily, Promotion, Product, Calendar + forecasting/{models,schemas,service,routes,persistence}.py + tests/ + registry/{models,schemas,service,routes}.py + tests/ + jobs/{models,schemas,service,routes}.py + tests/ + backtesting/{...}/ # FoldResult shape (read-only reference for full version) + scenarios/{models,schemas,service,routes,adjustments}.py + tests/ # newest slice — PRP-26 +alembic/versions/ # head: 43e35957a248_create_scenario_plan_table.py +frontend/src/ + hooks/use-runs.ts + types/api.ts + pages/explorer/run-detail.tsx + pages/visualize/forecast.tsx + components/ # shadcn ui under components/ui/ +``` + +### Desired Codebase tree (files added / modified) + +```bash +app/features/explainability/ # NEW SLICE + __init__.py # empty package marker + schemas.py # Pydantic v2: DriverContribution, ReasonCode, + # ConfidenceLevel(enum), ForecastExplanation, + # ExplainForecastRequest + 
+ models.py # ForecastExplanation ORM (forecast_explanation table) + explainers.py # BaseExplainer ABC + 3 explainers + explainer_factory + reason_codes.py # pure reason-code functions + build_caveats + service.py # ExplainabilityService + routes.py # APIRouter, 3 endpoints + tests/ + __init__.py + conftest.py # series fixtures + sample ModelRun/Job rows + test_explainers.py # unit (no DB) — forecaster-parity assertions + test_reason_codes.py # unit (no DB) + test_schemas.py # schema validation + JSON-path test + test_service.py # service unit (mocked AsyncSession) + test_routes.py # route tests (2xx + 404 + 400) + test_models_integration.py # @pytest.mark.integration — CRUD + CheckConstraints + test_routes_integration.py # @pytest.mark.integration — end-to-end +alembic/versions/ + {rev}_create_forecast_explanation_table.py # down_revision = "43e35957a248" +frontend/src/ + hooks/use-explanations.ts # NEW — 3 TanStack Query hooks + components/explainability/explanation-panel.tsx # NEW — drivers + reason codes + confidence + caveats + components/explainability/explanation-panel.test.tsx # NEW — vitest render test + +# MODIFIED +app/main.py # import + include explainability_router +frontend/src/types/api.ts # + DriverContribution, ReasonCode, ConfidenceLevel, + # ForecastExplanation +frontend/src/pages/explorer/run-detail.tsx # + Forecast Explanation card +frontend/src/pages/visualize/forecast.tsx # + explanation panel for loaded predict job +docs/_base/API_CONTRACTS.md # + 3 endpoint rows (optional but recommended) +``` + +### Forecaster math — VERIFIED against `app/features/forecasting/models.py` + +The explainer for a model MUST produce the **same h=1 value** the forecaster +would, or it is wrong. `test_explainers.py` asserts this against the real +forecasters. + +```python +# NaiveForecaster (L156-199): fit stores last_value = float(y[-1]); +# predict(h) → np.full(h, last_value). h=1 forecast == y[-1]. 
+# fit raises ValueError("Cannot fit on empty array") when len(y) == 0. + +# SeasonalNaiveForecaster (L252-302): fit stores _last_values = y[-season_length:]; +# predict(h): forecasts[k] = _last_values[k % season_length]. +# => h=1 forecast (k=0) == _last_values[0] == y[-season_length]. +# fit raises ValueError(f"Need at least {season_length} observations") +# when len(y) < season_length. + +# MovingAverageForecaster (L356-401): fit stores _forecast_value = +# float(np.mean(y[-window_size:])); predict(h) → np.full(h, _forecast_value). +# => h=1 forecast == mean(y[-window_size:]). +# fit raises ValueError(f"Need at least {window_size} observations") +# when len(y) < window_size. +``` + +### Predict-job `result` shape — VERIFIED against `jobs/service.py:_execute_predict` (L559-574) + +The relevant keys the explainer reads are `store_id`, `product_id`, +`model_type`, `horizon`, and `forecasts`. `_execute_predict` also emits a +`duration_ms` key, which the explainer ignores. + +```python +{ + "store_id": int, "product_id": int, "model_type": str, "horizon": int, + "duration_ms": float, # emitted by _execute_predict; ignored here + "forecasts": [ + {"date": "YYYY-MM-DD", "forecast": float, + "lower_bound": float | None, "upper_bound": float | None}, + ... + ], +} +# NOTE: `run_id` in a predict job's *params* is the model-ARTIFACT key +# (model_{run_id}.joblib), NOT a registry run_id. +# as_of_date for a predict job = day before the FIRST forecast date. +``` + +### Known Gotchas of our codebase & Library Quirks + +```python +# CRITICAL: Current Alembic head is "43e35957a248" (create_scenario_plan_table). +# The new migration's down_revision MUST be exactly "43e35957a248" or the +# CI `migration-check` job fails. (The source plan said "378c112e4b32" — +# that is STALE; PRP-26 added 43e35957a248 on top of it.) + +# CRITICAL: Pydantic strict-mode policy (docs/_base/SECURITY.md, enforced by +# app/core/tests/test_strict_mode_policy.py — an AST linter that FAILS CI). 
+# On a ConfigDict(strict=True) request body, any field typed +# date/datetime/time/UUID/Decimal MUST carry Field(strict=False, ...). +# ExplainForecastRequest.as_of_date is a `date` → it MUST be +# `Field(..., strict=False, ...)`. Response schemas are NOT strict=True → exempt. + +# CRITICAL: Vertical-slice rule — app/features/explainability/ may import ONLY +# from app/core/, app/shared/, and ORM models. It MUST NOT import +# app.features.forecasting.service / registry.service / jobs.service / +# backtesting.service. To explain a run/job, query ModelRun/Job ORM rows +# directly. Importing registry.models.ModelRun and jobs.models.Job read-only +# is the LOCKED decision (see "Open Questions & Decisions" #1) — same pattern +# as importing data_platform.models. NEVER import a sibling slice's service. + +# CRITICAL: Time-safety is LOAD-BEARING. Every series load and reason-code DB +# query MUST be bounded `SalesDaily.date <= as_of_date` (and reason-code +# tables `<= as_of_date`). Mirror forecasting/service.py:_load_training_data. +# Treat any leakage path as a blocker, same discipline as test_leakage.py. + +# GOTCHA: SQLAlchemy reserves the declarative attribute name `metadata` — do +# NOT name any ORM column `metadata` (scenarios/models.py documents this). + +# GOTCHA: get_db (app/core/database.py:43-53) auto-commits on success. The +# service should `await db.flush()` + `await db.refresh(obj)` — do NOT call +# `await db.commit()` inside the service. + +# GOTCHA: ModelRun has an attr `model_config` (the JSONB column). Pydantic also +# reserves `model_config` for ConfigDict — the registry schema aliases it to +# `model_config_data`. The explainability slice reads the ORM attr +# `ModelRun.model_config` (a dict) directly; no Pydantic alias needed there. + +# GOTCHA: numpy floats are not JSON-serialisable as-is — cast every explainer +# output through `float(...)` before placing it in a Pydantic schema. + +# GOTCHA: FastAPI bodies — never raw `Body(Any)`. 
Use the Pydantic request model. +# New endpoints need a 2xx happy path + ≥1 error path (test-requirements.md). +``` + +## Implementation Blueprint + +### Data models and structure + +```python +# ---- app/features/explainability/schemas.py (Pydantic v2) ---- +class ConfidenceLevel(str, Enum): + HIGH = "high"; MEDIUM = "medium"; LOW = "low" + +class DriverContribution(BaseModel): # plain BaseModel (response sub-object) + name: str + feature_value: float + contribution: float # model-units amount this driver adds + direction: Literal["positive", "negative", "neutral"] + description: str + +class ReasonCode(BaseModel): # advisory only — correlation, never causation + code: Literal["stockout_constrained", "promotion_overlap", "holiday_effect", + "lifecycle_decay", "trend_shift", "insufficient_history"] + severity: Literal["info", "warn"] + detail: str + +class ForecastExplanation(BaseModel): # response — NOT strict + model_config = ConfigDict(from_attributes=True) + store_id: int + product_id: int + model_type: str + method: Literal["rule_based"] # "shap"/"component" reserved for full version + forecast_value: float + drivers: list[DriverContribution] + reason_codes: list[ReasonCode] + confidence: ConfidenceLevel + caveats: list[str] + agent_summary: str + as_of_date: date_type + generated_at: datetime + +class ExplainForecastRequest(BaseModel): # request — strict=True + model_config = ConfigDict(strict=True) + store_id: int = Field(..., ge=1) + product_id: int = Field(..., ge=1) + model_type: Literal["naive", "seasonal_naive", "moving_average"] + # date has no native JSON type -> strict=False per docs/_base/SECURITY.md + as_of_date: date_type = Field(..., strict=False, description="Series cutoff date") + season_length: int | None = Field(None, ge=1, le=365) + window_size: int | None = Field(None, ge=1, le=90) + +# ---- app/features/explainability/models.py (SQLAlchemy 2.0) ---- +class ForecastExplanation(TimestampMixin, Base): + __tablename__ = "forecast_explanation" 
+ id: Mapped[int] = mapped_column(Integer, primary_key=True) + explanation_id: Mapped[str] = mapped_column(String(32), unique=True, index=True) + run_id: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True) + job_id: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True) + store_id: Mapped[int] = mapped_column(Integer, index=True) + product_id: Mapped[int] = mapped_column(Integer, index=True) + model_type: Mapped[str] = mapped_column(String(50)) + method: Mapped[str] = mapped_column(String(20), default="rule_based") + as_of_date: Mapped[datetime.date] = mapped_column(Date) + forecast_value: Mapped[float] = mapped_column(Float) + confidence: Mapped[str] = mapped_column(String(10)) + drivers: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False) + reason_codes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False) + caveats: Mapped[list[str]] = mapped_column(JSONB, nullable=False) + agent_summary: Mapped[str] = mapped_column(String(2000)) + __table_args__ = ( + Index("ix_forecast_explanation_drivers_gin", "drivers", postgresql_using="gin"), + Index("ix_forecast_explanation_store_product", "store_id", "product_id"), + CheckConstraint("confidence IN ('high','medium','low')", + name="ck_forecast_explanation_confidence"), + CheckConstraint("method IN ('rule_based','shap','component')", + name="ck_forecast_explanation_method"), + ) +``` + +### List of tasks (execute in order — each is atomic and independently testable) + +```yaml +Task 1 — CREATE app/features/explainability/__init__.py: + - Empty file (slice package marker). + - MIRROR: any app/features/<slice>/__init__.py. + - VALIDATE: test -f app/features/explainability/__init__.py && echo OK + +Task 2 — CREATE app/features/explainability/schemas.py: + - IMPLEMENT ConfidenceLevel, DriverContribution, ReasonCode, ForecastExplanation, + ExplainForecastRequest exactly as in "Data models and structure" above.
+ - Optional field_validator on ExplainForecastRequest: default season_length=7 / + window_size=7 when None (mirror SeasonalNaiveModelConfig / MovingAverageModelConfig + defaults) rather than hard-failing. + - MIRROR: app/features/forecasting/schemas.py L22-150 (config style), + L158-198 (strict=True body + Field(strict=False) date + field_validator). + - IMPORTS: from __future__ import annotations; from datetime import + date as date_type, datetime; from enum import Enum; from typing import Literal; + from pydantic import BaseModel, ConfigDict, Field, field_validator. + - GOTCHA: as_of_date MUST carry Field(strict=False) (test_strict_mode_policy.py). + - VALIDATE: uv run python -c "from app.features.explainability.schemas import + ForecastExplanation, ExplainForecastRequest, DriverContribution, ReasonCode; print('OK')" + +Task 3 — CREATE app/features/explainability/models.py: + - IMPLEMENT ForecastExplanation ORM model exactly as above. + - MIRROR: app/features/scenarios/models.py (newest — JSONB cols, CheckConstraint, + TimestampMixin, Base import) + app/features/registry/models.py L125-143 + (GIN index pattern). + - IMPORTS: from __future__ import annotations; import datetime; from typing + import Any; from sqlalchemy import CheckConstraint, Date, Float, Index, + Integer, String; from sqlalchemy.dialects.postgresql import JSONB; + from sqlalchemy.orm import Mapped, mapped_column; from app.core.database + import Base; from app.shared.models import TimestampMixin. + - VALIDATE: uv run python -c "from app.features.explainability.models import + ForecastExplanation; print(ForecastExplanation.__tablename__)" + +Task 3b — REGISTER the model with Alembic env (so migration drift check sees it): + - FIND: how registry/jobs/scenarios models are imported into the Alembic + `target_metadata` (alembic/env.py or a models aggregator). Commit + 9e7a9e1 registered scenario_plan for the drift check — mirror it. 
+ - VALIDATE: uv run alembic check (should report no drift after Task 4). + +Task 4 — CREATE alembic/versions/<revision_id>_create_forecast_explanation_table.py: + - IMPLEMENT op.create_table("forecast_explanation", ...) with EVERY column + from the ORM model + created_at/updated_at (DateTime(timezone=True), + server_default=sa.text("now()")), then op.create_index for the GIN index + (postgresql_using="gin"), the composite (store_id, product_id) index, and + the unique explanation_id index. downgrade() = op.drop_table. + - revision = "<revision_id>"; down_revision = "43e35957a248". <-- CRITICAL + - MIRROR: alembic/versions/43e35957a248_create_scenario_plan_table.py + + a2f7b3c8d901_create_model_registry_tables.py (JSONB + GIN + CheckConstraint). + - Generate with: `uv run alembic revision -m "create forecast explanation table"` + then fill in the body, OR hand-pick a 12-hex id not already used. + - VALIDATE (needs docker compose up -d): uv run alembic upgrade head && + uv run alembic downgrade -1 && uv run alembic upgrade head + +Task 5 — CREATE app/features/explainability/explainers.py: + - IMPLEMENT BaseExplainer(ABC) with abstract + explain(self, y: np.ndarray) -> tuple[float, list[DriverContribution]] and + confidence(self, y: np.ndarray) -> ConfidenceLevel. + - NaiveExplainer: forecast = float(y[-1]); main driver "last_observation" + (feature_value=y[-1], contribution=y[-1], direction="positive"). Add an + informational "recent_trend" driver = mean(y[-7:]) - mean(y[-14:-7]) with + contribution=0.0 (direction from sign) labelled "context, not used by model". + confidence: LOW if len(y) < 14 else MEDIUM. + - SeasonalNaiveExplainer(season_length): forecast = float(y[-season_length]); + main driver "season_match" (feature_value/contribution = y[-season_length], + direction="positive"). confidence: LOW if len(y) < 2*season_length else MEDIUM.
+ - MovingAverageExplainer(window_size): forecast = float(mean(y[-window_size:])); + one aggregate "window_mean" driver (contribution = forecast) + informational + "window_dispersion" driver = float(std(y[-window_size:])) contribution=0.0. + confidence: HIGH if len(y) >= window_size and std/mean < 0.5; + MEDIUM if len(y) >= window_size; else LOW. + - explainer_factory(model_type, season_length, window_size) -> BaseExplainer; + raise ValueError for "lightgbm"/unknown (caught at route layer -> 400). + - Empty y -> raise ValueError (mirror NaiveForecaster.fit). Cast all outputs + to float(). Sum of MAIN driver contributions must ≈ forecast_value + (informational drivers contribution=0.0, excluded). + - MIRROR: app/features/forecasting/models.py — BaseForecaster ABC (L44-126), + each forecaster's math, model_factory (L429-476) dispatch pattern. + - IMPORTS: from __future__ import annotations; from abc import ABC, + abstractmethod; import numpy as np; from app.features.explainability.schemas + import ConfidenceLevel, DriverContribution. + - VALIDATE: uv run python -c "import numpy as np; + from app.features.explainability.explainers import explainer_factory; + e=explainer_factory('moving_average',None,7); print(e.explain(np.arange(30.0)))" + +Task 6 — CREATE app/features/explainability/reason_codes.py: + - IMPLEMENT pure functions (DB-free — the service does the queries and passes + already-windowed rows): + * stockout_reason(inventory_rows) -> ReasonCode | None — fires + "stockout_constrained" (warn) if any is_stockout day in the trailing + window; detail counts stockout days. + * promotion_reason(promotion_rows, as_of_date) -> ReasonCode | None — + "promotion_overlap" (info) if a promotion overlaps the trailing window + or is active at as_of_date. + * lifecycle_reason(launch_date, as_of_date) -> ReasonCode | None — + days_since_launch = (as_of_date - launch_date).days; < 30 -> "lifecycle_decay" + (info). 
+ * holiday_reason(calendar_rows, forecast_date) -> ReasonCode | None — + "holiday_effect" (info) if the immediate forecast horizon hits a flagged + Calendar.is_holiday row. + * history_reason(n_obs, min_required) -> ReasonCode | None — + "insufficient_history" (warn) if n_obs < min_required. + * build_caveats(model_type, reason_codes) -> list[str] — ALWAYS includes the + NIST-grounded "drivers describe correlation/contribution, not business + causality" caveat; adds model-specific ones (naive: "ignores seasonality + & trend"; seasonal_naive: "assumes the prior cycle repeats"; + moving_average: "smooths over recent shifts"). + - GOTCHA: each function's docstring states it must only receive rows already + windowed `<= as_of_date` by the caller. holiday_reason must not peek past + the explained horizon's last date. + - IMPORTS: from __future__ import annotations; from datetime import + date as date_type; from app.features.explainability.schemas import ReasonCode. + - VALIDATE: uv run python -c "from app.features.explainability.reason_codes + import build_caveats; print(build_caveats('naive', []))" + +Task 7 — CREATE app/features/explainability/service.py: + - IMPLEMENT ExplainabilityService (see pseudocode below). + - MIRROR: app/features/forecasting/service.py (_load_training_data shape), + registry/service.py:get_run (select(ModelRun)... scalar_one_or_none), + scenarios/service.py (persist via flush/refresh, NOT commit). 
+ - IMPORTS: from __future__ import annotations; import uuid; from datetime + import UTC, datetime, timedelta; from datetime import date as date_type; + import numpy as np; import structlog; from sqlalchemy import select; + from sqlalchemy.ext.asyncio import AsyncSession; from app.core.config import + get_settings; from app.core.exceptions import BadRequestError, NotFoundError; + from app.features.data_platform.models import (SalesDaily, + InventorySnapshotDaily, Promotion, Product, Calendar); + from app.features.registry.models import ModelRun; # read-only data contract + from app.features.jobs.models import Job; # read-only data contract + + slice-local explainers / reason_codes / schemas / models. + - NOTE: importing registry/jobs ORM models read-only is the locked decision + (see "Open Questions & Decisions" #1) — document it in the service module + docstring + PR description. DO NOT import any *.service. + - VALIDATE: uv run mypy app/features/explainability/service.py && + uv run pyright app/features/explainability/service.py + +Task 8 — CREATE app/features/explainability/routes.py: + - IMPLEMENT router = APIRouter(prefix="/explain", tags=["explainability"]): + * GET /explain/jobs/{job_id} -> explain_job (404 missing, 400 not a + completed predict job) + * GET /explain/runs/{run_id} -> explain_run (404 missing, 400 lightgbm) + * POST /explain/forecast -> explain_forecast (status_code=200) + - Each route: service = ExplainabilityService(); try/except + ValueError -> HTTPException(400, ...); SQLAlchemyError -> + DatabaseError(...); None from service -> HTTPException(404, ...). + - Add rich summary/description on every route (agents read these). + - MIRROR: app/features/scenarios/routes.py + registry/routes.py:get_run (L200-232). 
+ - IMPORTS: from fastapi import APIRouter, Depends, HTTPException, status; + from sqlalchemy.exc import SQLAlchemyError; from sqlalchemy.ext.asyncio + import AsyncSession; from app.core.database import get_db; from + app.core.exceptions import DatabaseError; from app.core.logging import + get_logger; + slice schemas + service. + - VALIDATE: uv run python -c "from app.features.explainability.routes import + router; print([r.path for r in router.routes])" + +Task 9 — MODIFY app/main.py: + - FIND import block (L15-31). INJECT (near forecasting): + `from app.features.explainability.routes import router as explainability_router` + - FIND include_router block (L133-149). INJECT after `forecasting_router`: + `app.include_router(explainability_router)` + - PRESERVE all other wiring; no other change. + - VALIDATE: uv run python -c "from app.main import app; + print([r.path for r in app.routes if 'explain' in getattr(r,'path','')])" + +Task 10 — CREATE tests/__init__.py + tests/conftest.py: + - __init__.py empty. conftest.py: sample_series (np.ndarray ~60 values), + flat_series, short_series (<14), sample_run_row (ModelRun-shaped object or + dict with model_config JSONB), sample_predict_job (completed Job-shaped row). + Reuse the integration DB/client fixture pattern from + app/features/scenarios/tests/conftest.py for @pytest.mark.integration tests. + - VALIDATE: uv run pytest app/features/explainability/tests/ --collect-only -q + +Task 11 — CREATE tests/test_explainers.py: + - For each explainer: (a) forecast value EQUALS the real forecasting/models.py + forecaster's `.fit(y).predict(1)[0]` on the same y (import the real + forecasters — allowed in tests); (b) main driver contribution sum ≈ + forecast_value (pytest.approx); (c) direction signs correct; + (d) empty y raises ValueError; (e) confidence downgrades on short series. + - MIRROR: app/features/forecasting/tests/test_service.py (class-grouped, numpy). 
+ - VALIDATE: uv run pytest app/features/explainability/tests/test_explainers.py + -v -m "not integration" + +Task 12 — CREATE tests/test_reason_codes.py: + - stockout_reason fires on stockout rows / None otherwise; promotion_reason on + overlap; lifecycle_reason on recent launch; history_reason on short series; + build_caveats always includes the correlation-vs-causation caveat. + - VALIDATE: uv run pytest app/features/explainability/tests/test_reason_codes.py + -v -m "not integration" + +Task 13 — CREATE tests/test_schemas.py: + - ExplainForecastRequest.model_validate({...}) accepts an ISO date STRING for + as_of_date (THE JSON PATH — required by docs/_base/SECURITY.md, catches the + strict regression). ForecastExplanation round-trips model_dump/model_validate. + ConfidenceLevel enum values. Invalid model_type rejected. + - VALIDATE: uv run pytest app/features/explainability/tests/test_schemas.py + -v -m "not integration" + +Task 14 — CREATE tests/test_service.py: + - ExplainabilityService with a mocked AsyncSession (AsyncMock) — explain_forecast + returns a well-formed ForecastExplanation; explain_run resolves config from a + fake ModelRun.model_config dict; explain_job rejects a non-completed / + non-predict job (BadRequestError); missing run/job -> None or NotFoundError. + - GOTCHA: mock the select(...) result chain — db.execute.return_value -> + object with .all() / .scalar_one_or_none(). + - MIRROR: app/features/forecasting/tests/test_service.py (UNIT service-test + pattern — class-grouped, AsyncMock/MagicMock DB). The scenarios slice ships + no unit test_service.py — do NOT mirror it for this. + - VALIDATE: uv run pytest app/features/explainability/tests/test_service.py + -v -m "not integration" + +Task 15 — CREATE tests/test_routes.py: + - httpx AsyncClient(transport=ASGITransport(app=app)) with the test-DB + dependency override. 
POST /explain/forecast 200; + GET /explain/runs/{missing} 404; GET /explain/jobs/{missing} 404; + GET /explain/runs/{lightgbm-run} 400. Assert RFC 7807 body shape + (type, title, status, detail) on error paths. + - MIRROR: app/features/forecasting/tests/test_routes.py (UNIT route-test + pattern — ASGITransport, dependency override, RFC 7807 assertions). The + scenarios slice ships only an integration route test — do NOT mirror it for + this unit task; see Task 16 for the integration route test. + - VALIDATE: uv run pytest app/features/explainability/tests/test_routes.py + -v -m "not integration" + +Task 16 — CREATE tests/test_models_integration.py + tests/test_routes_integration.py: + - @pytest.mark.integration — ForecastExplanation CRUD against real Postgres; + confidence/method CheckConstraint rejects bad values; end-to-end: seed a + tiny series + a baseline ModelRun -> GET /explain/runs/{run_id} returns a + real explanation; explanation row persisted and re-readable. + - GOTCHA: never mock the DB; tests idempotent (no pre-seed assumptions). + - VALIDATE (docker compose up -d): uv run alembic upgrade head && + uv run pytest app/features/explainability/tests/ -v -m integration + +Task 17 — MODIFY frontend/src/types/api.ts: + - ADD DriverContribution, ReasonCode, ConfidenceLevel ('high'|'medium'|'low'), + ForecastExplanation TS interfaces — snake_case field names mirroring the + Pydantic response schema EXACTLY. + - MIRROR: existing ModelRun (L173) / ForecastPoint (L102) interfaces. + - VALIDATE: cd frontend && pnpm tsc --noEmit + +Task 18 — CREATE frontend/src/hooks/use-explanations.ts: + - useRunExplanation(runId, enabled) -> api('/explain/runs/'+runId); + useJobExplanation(jobId, enabled) -> '/explain/jobs/'+jobId; + useExplainForecast() -> useMutation POST '/explain/forecast'. + - MIRROR: frontend/src/hooks/use-runs.ts. 
+ - VALIDATE: cd frontend && pnpm tsc --noEmit + +Task 19 — CREATE frontend/src/components/explainability/explanation-panel.tsx: + - <ExplanationPanel /> — a Card with a drivers + table (name, value, contribution, direction colour), a reason-codes list + (icon by severity), a confidence Badge, and a caveats footnote list. + - Reuse already-installed shadcn card/badge/table FIRST. For any NEW shadcn + component follow .claude/rules/shadcn-ui.md (shadcn skill + MCP). + - MIRROR: run-detail.tsx Card composition; forecast.tsx EmptyState/LoadingState. + - VALIDATE: cd frontend && pnpm tsc --noEmit && pnpm lint + +Task 20 — CREATE frontend/src/components/explainability/explanation-panel.test.tsx: + - vitest render test — panel renders drivers/reason codes/confidence/caveats + from a fixture; loading + error states. + - VALIDATE: cd frontend && pnpm test --run + +Task 21 — MODIFY frontend/src/pages/explorer/run-detail.tsx: + - After the Metrics card, add a "Forecast Explanation" Card rendering <ExplanationPanel /> + fed by useRunExplanation(runId, true). Gracefully handle + lightgbm runs (400 -> "Explanations are available for baseline models only"). + - VALIDATE: cd frontend && pnpm tsc --noEmit && pnpm lint + +Task 22 — MODIFY frontend/src/pages/visualize/forecast.tsx: + - When a completed predict job is loaded (job.status==='completed' && + job.job_type==='predict'), render <ExplanationPanel /> fed by + useJobExplanation(job.job_id, true) below the forecast chart. + - VALIDATE: cd frontend && pnpm tsc --noEmit && pnpm lint + +Task 23 — MODIFY docs/_base/API_CONTRACTS.md: + - ADD 3 rows under a new `explainability` slice grouping in the HTTP endpoint + table: GET /explain/runs/{run_id}, GET /explain/jobs/{job_id}, + POST /explain/forecast. + - VALIDATE: grep -c "/explain/" docs/_base/API_CONTRACTS.md +``` + +### Per-task pseudocode (CRITICAL details) + +```python +# ---- Task 5: explainers.py ---- +class BaseExplainer(ABC): + @abstractmethod + def explain(self, y: np.ndarray) -> tuple[float, list[DriverContribution]]: ...
+ @abstractmethod + def confidence(self, y: np.ndarray) -> ConfidenceLevel: ... + +class NaiveExplainer(BaseExplainer): + def explain(self, y): + if len(y) == 0: # GOTCHA: mirror NaiveForecaster.fit + raise ValueError("Cannot explain an empty series") + forecast = float(y[-1]) + drivers = [DriverContribution( + name="last_observation", feature_value=forecast, contribution=forecast, + direction="positive", + description="Naive forecast IS the last observed value.")] + if len(y) >= 14: # informational only — contribution=0.0 + trend = float(np.mean(y[-7:]) - np.mean(y[-14:-7])) + drivers.append(DriverContribution( + name="recent_trend", feature_value=trend, contribution=0.0, + direction="positive" if trend > 0 else "negative" if trend < 0 else "neutral", + description="Context only — the naive model does not use trend.")) + return forecast, drivers + +# ---- Task 7: service.py ---- +class ExplainabilityService: + """Read-only explainability slice service. + + Imports registry.models.ModelRun and jobs.models.Job as READ-ONLY data + contracts (locked decision — see PRP "Open Questions & Decisions" #1; same + pattern as importing data_platform.models). It imports NO other slice's + service.py. To explain a run/job it re-loads the series from sales_daily and + re-fits a rule-based explainer from the stored config. 
+ """ + def __init__(self) -> None: + self.settings = get_settings() + + async def _load_series(self, db, store_id, product_id, end_date): + # MIRROR forecasting/service.py:_load_training_data — TIME-SAFE upper bound + stmt = (select(SalesDaily.date, SalesDaily.quantity) + .where((SalesDaily.store_id == store_id) + & (SalesDaily.product_id == product_id) + & (SalesDaily.date <= end_date)) # <-- LOAD-BEARING + .order_by(SalesDaily.date)) + rows = (await db.execute(stmt)).all() + y = np.array([float(r.quantity) for r in rows], dtype=np.float64) + return y, [r.date for r in rows] + + async def explain_forecast(self, db, request: ExplainForecastRequest) -> ForecastExplanation: + y, dates = await self._load_series(db, request.store_id, request.product_id, + request.as_of_date) + explainer = explainer_factory(request.model_type, request.season_length, + request.window_size) # ValueError -> route 400 + forecast_value, drivers = explainer.explain(y) # ValueError on empty y + confidence = explainer.confidence(y) + inv, promos, product, cal = await self._load_reason_code_inputs( + db, request.store_id, request.product_id, request.as_of_date) + reason_codes = self._assemble_reason_codes(inv, promos, product, cal, + request.as_of_date, len(y), + request.model_type) + caveats = build_caveats(request.model_type, reason_codes) + agent_summary = self._build_agent_summary(drivers, reason_codes, confidence, + forecast_value) + explanation = ForecastExplanation(... fields ...) + # PERSIST — flush/refresh, NOT commit (get_db auto-commits) + row = ExplanationORM(explanation_id=uuid.uuid4().hex, ...) 
+ db.add(row); await db.flush(); await db.refresh(row) + return explanation + + async def explain_run(self, db, run_id: str) -> ForecastExplanation | None: + run = (await db.execute( + select(ModelRun).where(ModelRun.run_id == run_id))).scalar_one_or_none() + if run is None: + return None # route -> 404 + cfg = run.model_config # JSONB dict + model_type = cfg["model_type"] + if model_type == "lightgbm": + raise ValueError("Explanations available for baseline models only") # -> 400 + season_length = cfg.get("season_length") + window_size = cfg.get("window_size") + # as_of_date = run.data_window_end ; store/product from the run + ... dispatch as explain_forecast ... + + async def explain_job(self, db, job_id: str) -> ForecastExplanation | None: + job = (await db.execute( + select(Job).where(Job.job_id == job_id))).scalar_one_or_none() + if job is None: + return None # route -> 404 + if job.job_type != "predict" or job.status != "completed": + raise BadRequestError(message="explain_job requires a completed predict job", + details={"job_id": job_id, "status": job.status}) + result = job.result or {} + forecasts = result.get("forecasts", []) + # as_of_date = day BEFORE the first forecast date + first = date_type.fromisoformat(forecasts[0]["date"]) + as_of_date = first - timedelta(days=1) + ... dispatch ... +``` + +### Integration Points + +```yaml +DATABASE: + - migration: "create forecast_explanation table, down_revision = 43e35957a248" + - indexes: GIN on drivers, composite (store_id, product_id), unique explanation_id + - constraints: CHECK confidence IN (high,medium,low); CHECK method IN + (rule_based,shap,component) + - alembic env: register the model with target_metadata (mirror commit 9e7a9e1 + which registered scenario_plan for the drift check) + +CONFIG: + - none new — ExplainabilityService.__init__ calls get_settings() only for + parity; no new settings keys. 
+ +ROUTES: + - add to: app/main.py + - import: from app.features.explainability.routes import router as explainability_router + - include: app.include_router(explainability_router) # after forecasting_router + +FRONTEND: + - types: frontend/src/types/api.ts (+4 interfaces) + - hook: frontend/src/hooks/use-explanations.ts + - component: frontend/src/components/explainability/explanation-panel.tsx + - pages: run-detail.tsx + forecast.tsx mount + +DOCS: + - docs/_base/API_CONTRACTS.md — 3 endpoint rows under an `explainability` group +``` + +## Validation Loop + +### Level 1: Syntax & Style + +```bash +uv run ruff check . +uv run ruff format --check . +# Expected: no errors. If errors, READ the error and fix (ruff check --fix for autofixable). +``` + +### Level 2: Type checks + Unit Tests + +```bash +uv run mypy app/ && uv run pyright app/ # both --strict — gate merge +uv run pytest -v -m "not integration" +# Iterate until green. Never weaken a test to pass; fix the code. +# Key assertions: each explainer's h=1 value == the real forecaster's +# .fit(y).predict(1)[0]; ExplainForecastRequest accepts an ISO-string as_of_date. 
+``` + +### Level 3: Integration Tests + +```bash +docker compose up -d +uv run alembic upgrade head +uv run alembic downgrade -1 && uv run alembic upgrade head # migration round-trips +uv run pytest -v -m integration +``` + +### Level 4: Manual Validation + +```bash +# Backend: uv run uvicorn app.main:app --reload --port 8123 +curl -s -X POST http://localhost:8123/explain/forecast \ + -H 'Content-Type: application/json' \ + -d '{"store_id":1,"product_id":1,"model_type":"moving_average","as_of_date":"2024-06-30","window_size":7}' \ + | python -m json.tool +# Expect: drivers[], reason_codes[], confidence, caveats[], agent_summary, forecast_value +curl -s http://localhost:8123/explain/runs/<run_id> | python -m json.tool +curl -s -o /dev/null -w '%{http_code}\n' http://localhost:8123/explain/runs/does-not-exist # 404 +``` + +### Level 5: Frontend + +```bash +cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run +# Then: ./node_modules/.bin/vite --host 0.0.0.0 +# Browser dogfood (webapp-testing / agent-browser per .claude/rules/ui-design.md): +# /explorer/runs/<run_id> -> Forecast Explanation card renders +# /visualize/forecast -> explanation panel renders below the chart +``` + +## Final Validation Checklist + +- [ ] `uv run ruff check .
&& uv run ruff format --check .` clean +- [ ] `uv run mypy app/ && uv run pyright app/` clean (both --strict) +- [ ] `uv run pytest -v -m "not integration"` green +- [ ] `docker compose up -d && uv run pytest -v -m integration` green +- [ ] Migration applies + rolls back cleanly; `uv run alembic check` reports no drift +- [ ] `cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run` green +- [ ] Manual curl + browser dogfood confirm the feature +- [ ] All Success Criteria met +- [ ] Commit `feat(api,ui): add forecast explainability & driver attribution slice (#<issue>)` + referencing an OPEN GitHub issue; branch `feat/forecast-explainability` off `dev` + +## Testing Strategy + +### Unit Tests (`-m "not integration"`, mocked externals, no DB) + +Cover: every explainer's forecast-value parity vs the real forecasters; driver +contribution sums; confidence downgrades; reason-code firing logic; caveat +content; schema JSON-path validation; service dispatch with `AsyncMock` DB; +route happy + error paths via `ASGITransport`. + +### Integration Tests (`@pytest.mark.integration`, real docker-compose Postgres) + +Cover: the Alembic migration applies + rolls back cleanly; `ForecastExplanation` +`CheckConstraint`s reject bad `confidence`/`method`; end-to-end +`GET /explain/runs/{run_id}` after seeding a real series + `ModelRun`; +explanation row persisted and re-readable. Never mock the DB. + +### Edge Cases + +- Empty series for a `(store, product)` → explainer raises `ValueError` → route 400. +- Series shorter than `season_length`/`window_size` → confidence `LOW` + + `insufficient_history` reason code (not a crash). +- A `lightgbm` run → 400 "baseline models only" (MVP scope guard). +- A `predict` job that is `pending`/`failed`/`cancelled` → 400. +- A series with stockout days → `stockout_constrained` reason code present, + caveat about understated demand. +- Flat (constant) series → moving-average confidence `HIGH`, naive + `recent_trend` driver `neutral`.
+- ISO-date string body for `POST /explain/forecast` (the `strict`-mode JSON path). + +## Open Questions & Decisions + +1. **[DECISION LOCKED] Cross-slice ORM-model import is ALLOWED (read-only).** + `explain_run`/`explain_job` need to read `ModelRun`/`Job` rows. The + vertical-slice rule forbids importing another slice's `service.py`; importing + another slice's `models.py` was the one gray area. **Maintainer ruling:** the + `explainability` slice MAY import `app.features.registry.models.ModelRun` and + `app.features.jobs.models.Job` directly as **read-only data contracts** — it + must NEVER import those slices' `service.py`. This mirrors how slices already + import `app.features.data_platform.models` directly. The decision is final; + there is no fallback path. The implementer must document the choice in the + `service.py` module docstring and the PR description. + +2. **[ASSUMPTION] Endpoint namespace.** This PRP owns a self-contained `/explain` + prefix rather than mounting paths under `/forecasting` or `/registry` (the + original brief suggested `GET /forecasting/explanations/{job_id}` and + `GET /registry/runs/{run_id}/explanations`). A self-owned prefix keeps the + slice from mounting routes under another slice's prefix. Flag in the PR for + maintainer sign-off. + +3. **[ASSUMPTION] `seasonal_naive` horizon.** Explanations are computed for the + **h=1** forecast (the dominant, most-interpretable case). Multi-horizon + driver attribution is deferred to a future PRP. + +4. **[ASSUMPTION] `as_of_date` for a predict job.** Derived as the day before + the first forecast date in `job.result["forecasts"]`. If a future predict-job + result records the training cutoff explicitly, prefer that. + +5. **[ASSUMPTION] No GitHub issue exists yet.** The implementer must open/secure + an open issue before committing (`commit-format.md` requires `(#issue)`). + Branch `feat/forecast-explainability` off `dev`. + +6. 
**[OUT OF SCOPE] Backtest explanation.** Per-fold backtest explanation is + future-version work — noted, not built here. MVP covers forecast + run + job. + +7. **[OUT OF SCOPE] SHAP / tree-model explainers.** Deliberately excluded. SHAP + would add `shap` + a transitive tree (numba, llvmlite, cloudpickle) — a heavy + footprint for a single-host portfolio repo, and `lightgbm` itself is still + feature-flagged and `NotImplementedError` in `model_factory`. There is no tree + model to explain yet. The MVP rule-based explainers are **exact**, not + approximate. The `method` field (`"rule_based"` now; `"shap"`/`"component"` + reserved) keeps the schema forward-compatible. **Adding SHAP later needs its + own PRP + an ADR** — a new core-path dependency touches + `.claude/rules/product-vision.md` (single-host vision). + +## Anti-Patterns to Avoid + +- ❌ Importing another slice's `service.py` (vertical-slice rule violation). +- ❌ Setting `down_revision` to `378c112e4b32` — the head is `43e35957a248`. +- ❌ Omitting `Field(strict=False)` on `ExplainForecastRequest.as_of_date` + (`test_strict_mode_policy.py` fails CI). +- ❌ Reading any data past `as_of_date` / `data_window_end` (leakage — blocker). +- ❌ Calling `await db.commit()` inside the service (`get_db` auto-commits). +- ❌ Bare `raise HTTPException(500, "string")` — use the RFC 7807 exception classes. +- ❌ Adding `shap` / `lightgbm` to `pyproject.toml`. +- ❌ Naming any ORM column `metadata` (SQLAlchemy reserves it). +- ❌ Returning numpy floats unwrapped — cast through `float()`. +- ❌ Making a reason code a causal claim — they are advisory correlation only. +- ❌ Hand-rolling a shadcn install — use the `shadcn` skill + MCP. + +--- + +## Confidence Score: 9/10 for one-pass success + +The slice pattern, schema/route/migration patterns, and the exact forecaster +math are all well-established in the codebase and reproduced above with verified +line/symbol references. 
The newest slice (`scenarios`, PRP-26) is a near-exact +structural template. The cross-slice ORM-import question (Open Question #1) is +now a locked decision — no mid-implementation maintainer call is needed. +Residual risks: (a) the frontend panel needs real-browser verification per +`ui-design.md`, which can surface layout iteration; (b) the Alembic +model-registration step (Task 3b) must mirror commit 9e7a9e1 precisely so the +drift check stays green. Both are bounded and flagged. diff --git a/PRPs/ai_docs/exogenous-regressor-forecasting.md b/PRPs/ai_docs/exogenous-regressor-forecasting.md new file mode 100644 index 00000000..cd9353ce --- /dev/null +++ b/PRPs/ai_docs/exogenous-regressor-forecasting.md @@ -0,0 +1,164 @@ +# Exogenous-Regressor Forecasting & Leakage-Safe Future Feature Frames + +> Curated reference for **PRP-27 (Scenario Simulation — Full Version)**. ForecastLabAI's +> baseline forecasters (`naive`, `seasonal_naive`, `moving_average`) ignore the exogenous +> `X` argument (every `fit`/`predict` carries `# noqa: ARG002`). The Full Version needs a +> forecaster that *consumes* `X` so a scenario assumption can be expressed as a real +> regressor change instead of a post-forecast multiplier. This doc condenses the parts of +> the LightGBM / scikit-learn / pandas docs that matter for that, plus the leakage rule. + +--- + +## 1. The exogenous-regressor model contract (what to build) + +A "regression-on-features" forecaster predicts demand from a **feature row per future +day**, not from the historical target series. The flow: + +``` +TRAIN: y, X_hist ─fit─► estimator (X_hist built by featuresets, cutoff-safe) +PREDICT: X_future ─predict─► ŷ_future (X_future = the future feature frame) +``` + +- `X_hist` is a 2-D array `[n_samples, n_features]` — the columns featuresets already + produces (`lag_*`, `rolling_*`, calendar, `price_lag_*`, `promo_*`, lifecycle). +- `X_future` is the **same columns** for the horizon days. 
This is the *future feature + frame* — the central new artifact of PRP-27. +- The estimator is a gradient-boosted tree regressor (`LGBMRegressor`) — or, to avoid a + new dependency, scikit-learn's `HistGradientBoostingRegressor` (already in the + `scikit-learn` dep). **Prefer the scikit-learn option** — see §5. + +### scikit-learn `HistGradientBoostingRegressor` (no new dependency) + +```python +from sklearn.ensemble import HistGradientBoostingRegressor + +est = HistGradientBoostingRegressor( + max_iter=200, learning_rate=0.05, max_depth=6, random_state=42, +) +est.fit(X_hist, y) # X_hist: ndarray [n, k]; y: ndarray [n] +y_future = est.predict(X_future) # X_future: ndarray [horizon, k] +``` + +- Histogram-based, fast, handles `NaN` natively (important — lag features have `NaN` + at series start). Deterministic with a fixed `random_state`. +- Docs: https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.HistGradientBoostingRegressor.html + +### LightGBM `LGBMRegressor` (only if a new dependency is approved) + +```python +from lightgbm import LGBMRegressor +est = LGBMRegressor(n_estimators=200, learning_rate=0.05, max_depth=6, + random_state=42, n_jobs=1, verbose=-1) +est.fit(X_hist, y) +y_future = est.predict(X_future) +``` + +- API: https://lightgbm.readthedocs.io/en/stable/pythonapi/lightgbm.LGBMRegressor.html +- Set `n_jobs=1` + `random_state` for reproducibility; `verbose=-1` to silence. +- `LightGBMModelConfig` already exists in `forecasting/schemas.py` and + `forecast_enable_lightgbm` already exists in config — but **LightGBM is NOT in + `pyproject.toml`** and `model_factory` raises `NotImplementedError`. Adding it is a + `pyproject.toml` change + a stop-and-ask gate (see PRP-27 § Vision Tensions). + +--- + +## 2. 
The leakage rule for FUTURE feature frames (the load-bearing part) + +`app/features/featuresets/service.py` builds **historical** features and is time-safe by +construction: it filters to `cutoff_date` *before* any compute, lags via `shift(positive)`, +rolls via `shift(1).rolling(...)`, all `groupby` entity-aware. `test_leakage.py` is its spec. + +A **future** feature frame is different and dangerous: for horizon day `D` you must +produce the SAME feature columns, but `D` has **no observed target**. The rule: + +> **A future feature row for day `D` may only use information available at the forecast +> origin `T` (the last training day) — never an observed value at `D` or later.** + +Concretely, for a horizon `T+1 … T+H`: + +| Feature family | How to populate the future frame | Leakage trap to avoid | +|----------------|----------------------------------|------------------------| +| `lag_k` (k ≥ horizon) | Real observed `y[T+j-k]` for horizon day `T+j` — always at or before `T`, so available at the origin. | — | +| `lag_k` (k < horizon) | **Recursive**: `lag_k` at `T+j` = the model's own prediction `ŷ[T+j-k]`. NEVER a real future `y`. | Using a real `y[T+j-k]` (does not exist) or 0. | +| `rolling_*` | Built from the same `shift(1)`-then-roll over the *extended* (history + predicted) series. | Rolling over un-shifted future values. | +| calendar (`dow`, `month`, `is_weekend`, …) | Pure function of the date `D` — always safe, compute directly. | — | +| `price_lag_*`, `promo_*` | Driven by the **scenario assumptions** — the planner is *positing* a future price/promo. This is the intended what-if input, not leakage. | Reading real future `price_history` rows. | +| `is_holiday` | From the scenario's holiday assumption OR the `calendar` table (a `calendar` row is a timeless attribute, like `launch_date`). | — | +| lifecycle (`days_since_launch`) | Pure function of `D - product.launch_date` — safe. 
| — | + +**Recursive (iterative) forecasting** is the standard technique for multi-step horizons +when lags shorter than the horizon exist: predict `T+1`, append `ŷ[T+1]` to the working +series, recompute lags, predict `T+2`, and so on. Pandas time-series guide: +https://pandas.pydata.org/docs/user_guide/timeseries.html + +**Simplification that sidesteps recursion entirely:** if the future feature frame uses +ONLY lags `k ≥ horizon`, calendar features, and assumption-driven exogenous columns, then +every feature value is knowable at `T` with no recursion. PRP-27 recommends this +"long-lag + exogenous + calendar" feature set for the MVP of the Full Version — it keeps +the leakage proof tractable (`test_leakage.py` can assert it directly) and is one-pass +implementable. Recursion is a documented Phase-2 extension. + +--- + +## 3. Why this is leakage-critical for a planner + +The MVP (PRP-26) is *immune* to leakage because it never builds a future feature frame — +it multiplies the baseline forecast by a deterministic factor. The Full Version +*introduces* the future feature frame, so it introduces the leakage surface the MVP did +not have. PRP-27 therefore ships a NEW load-bearing test +`app/features/scenarios/tests/test_leakage.py` extension (or a sibling +`test_future_frame_leakage.py`) that asserts the future-frame generator never reads an +observed target at or after the forecast origin. This mirrors +`app/features/featuresets/tests/test_leakage.py` — never weaken it to make a feature pass. + +--- + +## 4. Multi-scenario comparison (UX + math) + +Comparing N scenarios against one baseline is an aggregation over N `ScenarioComparison` +objects: + +- Each scenario contributes one `(units_delta, revenue_delta, coverage_verdict)` triple. +- The comparison view ranks scenarios by a chosen metric (revenue delta default) and + renders all series on one chart (baseline + one line per scenario). 
+- Recharts renders N+1 `<Line>` series from one merged row array keyed by date — + `frontend/src/components/charts/time-series-chart.tsx` currently wraps a 2-series case; + a multi-series variant passes a `series: {key,label,color}[]` prop. Recharts LineChart: + https://recharts.org/en-US/api/LineChart +- TanStack Query: the comparison page issues one query per saved scenario id (or one + batch endpoint). Mutations vs queries pattern: + https://tanstack.com/query/latest/docs/framework/react/guides/mutations + +--- + +## 5. Recommendation for PRP-27 (de-risking) + +1. **Prefer `HistGradientBoostingRegressor`** over LightGBM — it already ships inside the + existing `scikit-learn` dependency, so no `pyproject.toml` change and no stop-and-ask gate. + It is deterministic, NaN-tolerant, and fast enough for single-series horizons. +2. **Use the long-lag + calendar + exogenous feature set** so the future frame needs no + recursion — the leakage proof stays simple and the PRP stays one-pass implementable. +3. **Keep `method` forward-compatible** — the MVP locked `method="heuristic"` behind a + CHECK constraint. The Full Version adds `method="model_exogenous"`; the migration must + widen the CHECK to `IN ('heuristic','model_exogenous')`. +4. **Never replace the heuristic path** — it stays as the fallback when a baseline model + does not support exogenous features. A scenario result always declares which `method` + produced it, and the heuristic disclaimer stays on heuristic results. 
+ +--- + +## Source URLs (with the sections that matter) + +- scikit-learn `HistGradientBoostingRegressor` — fit/predict, NaN handling, `random_state`: + https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.HistGradientBoostingRegressor.html +- scikit-learn `TimeSeriesSplit` — for any backtest of the exogenous model: + https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.TimeSeriesSplit.html +- LightGBM `LGBMRegressor` Python API — only if the dependency is approved: + https://lightgbm.readthedocs.io/en/stable/pythonapi/lightgbm.LGBMRegressor.html +- pandas time-series user guide — date ranges, shifting, rolling for the future frame: + https://pandas.pydata.org/docs/user_guide/timeseries.html +- Recharts LineChart — multi-series scenario comparison chart: + https://recharts.org/en-US/api/LineChart +- NIST AI Risk Management Framework — transparency controls for model-driven revenue + claims (the `disclaimer` / `method` labelling requirement): + https://www.nist.gov/itl/ai-risk-management-framework diff --git a/README.md b/README.md index 27210c98..3466a525 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Portfolio-grade end-to-end retail demand forecasting system. 
- **Dashboard**: React 19 + Vite + Tailwind CSS 4 + shadcn/ui for data exploration and model management - **Explorer**: Click-through detail pages for stores, products, model runs, and jobs; run-vs-run comparison and SHA-256 artifact integrity verification; server-side sortable, CSV-exportable tables with column-visibility toggles and URL-shareable filter/sort/page state across every Explorer page; date-scoped KPIs, revenue bar/line charts, and cross-filtering on the Sales page - **Demand Planner**: `/visualize/demand` — every completed forecast rolled into a multi-SKU table (tomorrow / next-week / next-month demand + inventory requirement), with a lead-time selector and a single-SKU drill-in; the Forecast and Backtest pages run jobs in-page, export CSV, toggle a prediction-interval band, and cross-link to runs/jobs +- **What-If Planner**: `/visualize/planner` — take an existing forecast, apply price / promotion / holiday / inventory / lifecycle assumptions, and see the baseline-vs-scenario demand and revenue impact; a regression baseline genuinely re-forecasts through the assumptions (`method="model_exogenous"`), any other baseline applies a clearly-labelled deterministic heuristic; save, tag, reload, clone, and delete named scenario plans, and rank 2-5 saved plans side by side in a multi-scenario comparison. 
The experiment chat agent can also propose a scenario and — behind the human-in-the-loop approval gate — save it for you - **RAG Knowledge Base**: Postgres pgvector embeddings + evidence-grounded answers with citations - **Agentic Layer**: PydanticAI agents for autonomous experimentation and evidence-grounded Q&A with human-in-the-loop approval - **Data Seeder (The Forge)**: Reproducible synthetic data generator with realistic time-series patterns, scenario presets, and retail effects diff --git a/alembic/env.py b/alembic/env.py index 6abccc57..dd83996b 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -15,9 +15,11 @@ from app.features.agents import models as agents_models # noqa: F401 from app.features.config import models as config_models # noqa: F401 from app.features.data_platform import models as data_platform_models # noqa: F401 +from app.features.explainability import models as explainability_models # noqa: F401 from app.features.jobs import models as jobs_models # noqa: F401 from app.features.rag import models as rag_models # noqa: F401 from app.features.registry import models as registry_models # noqa: F401 +from app.features.scenarios import models as scenarios_models # noqa: F401 # Alembic Config object config = context.config diff --git a/alembic/versions/43e35957a248_create_scenario_plan_table.py b/alembic/versions/43e35957a248_create_scenario_plan_table.py new file mode 100644 index 00000000..8fdd4a8b --- /dev/null +++ b/alembic/versions/43e35957a248_create_scenario_plan_table.py @@ -0,0 +1,97 @@ +"""create scenario plan table + +Revision ID: 43e35957a248 +Revises: 378c112e4b32 +Create Date: 2026-05-19 07:34:30.545495 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision: str = '43e35957a248' +down_revision: Union[str, None] = '378c112e4b32' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Apply migration — create the scenario_plan table.""" + op.create_table( + 'scenario_plan', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('scenario_id', sa.String(length=32), nullable=False), + sa.Column('name', sa.String(length=200), nullable=False), + sa.Column('store_id', sa.Integer(), nullable=False), + sa.Column('product_id', sa.Integer(), nullable=False), + sa.Column('run_id', sa.String(length=32), nullable=False), + sa.Column('horizon', sa.Integer(), nullable=False), + sa.Column('assumptions', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('comparison', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('method', sa.String(length=20), nullable=False), + sa.Column( + 'created_at', + sa.DateTime(timezone=True), + server_default=sa.text('now()'), + nullable=False, + ), + sa.Column( + 'updated_at', + sa.DateTime(timezone=True), + server_default=sa.text('now()'), + nullable=False, + ), + sa.CheckConstraint("method IN ('heuristic')", name='ck_scenario_plan_method'), + sa.PrimaryKeyConstraint('id'), + ) + op.create_index( + op.f('ix_scenario_plan_scenario_id'), 'scenario_plan', ['scenario_id'], unique=True + ) + op.create_index( + op.f('ix_scenario_plan_store_id'), 'scenario_plan', ['store_id'], unique=False + ) + op.create_index( + op.f('ix_scenario_plan_product_id'), 'scenario_plan', ['product_id'], unique=False + ) + op.create_index( + op.f('ix_scenario_plan_run_id'), 'scenario_plan', ['run_id'], unique=False + ) + op.create_index( + 'ix_scenario_plan_assumptions_gin', + 'scenario_plan', + ['assumptions'], + unique=False, + postgresql_using='gin', + ) + op.create_index( + 'ix_scenario_plan_comparison_gin', + 'scenario_plan', + ['comparison'], + unique=False, + postgresql_using='gin', + ) + 
op.create_index( + 'ix_scenario_plan_store_product', + 'scenario_plan', + ['store_id', 'product_id'], + unique=False, + ) + + +def downgrade() -> None: + """Revert migration — drop the scenario_plan table.""" + op.drop_index('ix_scenario_plan_store_product', table_name='scenario_plan') + op.drop_index( + 'ix_scenario_plan_comparison_gin', table_name='scenario_plan', postgresql_using='gin' + ) + op.drop_index( + 'ix_scenario_plan_assumptions_gin', table_name='scenario_plan', postgresql_using='gin' + ) + op.drop_index(op.f('ix_scenario_plan_run_id'), table_name='scenario_plan') + op.drop_index(op.f('ix_scenario_plan_product_id'), table_name='scenario_plan') + op.drop_index(op.f('ix_scenario_plan_store_id'), table_name='scenario_plan') + op.drop_index(op.f('ix_scenario_plan_scenario_id'), table_name='scenario_plan') + op.drop_table('scenario_plan') diff --git a/alembic/versions/7e8f9748581e_add_scenario_provenance_columns.py b/alembic/versions/7e8f9748581e_add_scenario_provenance_columns.py new file mode 100644 index 00000000..6f600f96 --- /dev/null +++ b/alembic/versions/7e8f9748581e_add_scenario_provenance_columns.py @@ -0,0 +1,73 @@ +"""add scenario provenance columns + +Revision ID: 7e8f9748581e +Revises: bb8c4587ef1d +Create Date: 2026-05-19 10:47:09.829097 + +PRP-27 Phase D — adds provenance + approval-audit columns to ``scenario_plan`` +so an agent-proposed plan records who/what created it and the human approval +decision that released it. ``source`` server-defaults to ``'user'`` so every +pre-existing row stays valid. Forward-only. +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision: str = '7e8f9748581e' +down_revision: Union[str, None] = 'bb8c4587ef1d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add the source + approval-audit columns, their CHECKs and an index.""" + op.add_column( + 'scenario_plan', + sa.Column( + 'source', + sa.String(length=16), + nullable=False, + server_default=sa.text("'user'"), + ), + ) + op.add_column( + 'scenario_plan', + sa.Column('agent_session_id', sa.String(length=32), nullable=True), + ) + op.add_column( + 'scenario_plan', + sa.Column('approved_by', sa.String(length=120), nullable=True), + ) + op.add_column( + 'scenario_plan', + sa.Column('approved_at', sa.DateTime(timezone=True), nullable=True), + ) + op.add_column( + 'scenario_plan', + sa.Column('approval_decision', sa.String(length=16), nullable=True), + ) + op.create_check_constraint( + 'ck_scenario_plan_source', + 'scenario_plan', + "source IN ('user', 'agent')", + ) + op.create_check_constraint( + 'ck_scenario_plan_approval_decision', + 'scenario_plan', + "approval_decision IS NULL OR approval_decision IN ('approved', 'rejected')", + ) + op.create_index('ix_scenario_plan_source', 'scenario_plan', ['source'], unique=False) + + +def downgrade() -> None: + """Drop the index, the two CHECKs and the provenance columns.""" + op.drop_index('ix_scenario_plan_source', table_name='scenario_plan') + op.drop_constraint('ck_scenario_plan_approval_decision', 'scenario_plan', type_='check') + op.drop_constraint('ck_scenario_plan_source', 'scenario_plan', type_='check') + op.drop_column('scenario_plan', 'approval_decision') + op.drop_column('scenario_plan', 'approved_at') + op.drop_column('scenario_plan', 'approved_by') + op.drop_column('scenario_plan', 'agent_session_id') + op.drop_column('scenario_plan', 'source') diff --git a/alembic/versions/bb8c4587ef1d_add_scenario_library_columns.py b/alembic/versions/bb8c4587ef1d_add_scenario_library_columns.py new file mode 
100644 index 00000000..38e65d10 --- /dev/null +++ b/alembic/versions/bb8c4587ef1d_add_scenario_library_columns.py @@ -0,0 +1,52 @@ +"""add scenario library columns + +Revision ID: bb8c4587ef1d +Revises: e47f5739d7d0 +Create Date: 2026-05-19 10:26:58.473203 + +PRP-27 Phase C — adds the scenario-library columns to ``scenario_plan``: +``tags`` (a JSONB string array, queryable via a GIN index) and ``cloned_from`` +(the ``scenario_id`` a plan was cloned from, nullable). Forward-only. +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'bb8c4587ef1d' +down_revision: Union[str, None] = 'e47f5739d7d0' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add the tags and cloned_from columns plus a GIN index on tags.""" + op.add_column( + 'scenario_plan', + sa.Column( + 'tags', + postgresql.JSONB(astext_type=sa.Text()), + nullable=False, + server_default=sa.text("'[]'::jsonb"), + ), + ) + op.add_column( + 'scenario_plan', + sa.Column('cloned_from', sa.String(length=32), nullable=True), + ) + op.create_index( + 'ix_scenario_plan_tags_gin', + 'scenario_plan', + ['tags'], + unique=False, + postgresql_using='gin', + ) + + +def downgrade() -> None: + """Drop the GIN index and the scenario-library columns.""" + op.drop_index('ix_scenario_plan_tags_gin', table_name='scenario_plan', postgresql_using='gin') + op.drop_column('scenario_plan', 'cloned_from') + op.drop_column('scenario_plan', 'tags') diff --git a/alembic/versions/e47f5739d7d0_widen_scenario_method_check.py b/alembic/versions/e47f5739d7d0_widen_scenario_method_check.py new file mode 100644 index 00000000..7b861cde --- /dev/null +++ b/alembic/versions/e47f5739d7d0_widen_scenario_method_check.py @@ -0,0 +1,43 @@ +"""widen scenario method check + +Revision ID: e47f5739d7d0 +Revises: 43e35957a248 
+Create Date: 2026-05-19 10:06:15.179816 + +PRP-27 Phase B — widens the ``scenario_plan.method`` CHECK constraint so a +model-driven simulation can persist ``method='model_exogenous'`` alongside the +MVP's ``'heuristic'``. Forward-only: never edits the merged migration that +created the table. +""" +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = 'e47f5739d7d0' +down_revision: Union[str, None] = '43e35957a248' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +_CONSTRAINT = "ck_scenario_plan_method" +_TABLE = "scenario_plan" + + +def upgrade() -> None: + """Allow method IN ('heuristic', 'model_exogenous').""" + op.drop_constraint(_CONSTRAINT, _TABLE, type_="check") + op.create_check_constraint( + _CONSTRAINT, + _TABLE, + "method IN ('heuristic', 'model_exogenous')", + ) + + +def downgrade() -> None: + """Revert to method IN ('heuristic') only.""" + op.drop_constraint(_CONSTRAINT, _TABLE, type_="check") + op.create_check_constraint( + _CONSTRAINT, + _TABLE, + "method IN ('heuristic')", + ) diff --git a/alembic/versions/f84258c4cb44_create_forecast_explanation_table.py b/alembic/versions/f84258c4cb44_create_forecast_explanation_table.py new file mode 100644 index 00000000..5d5cede9 --- /dev/null +++ b/alembic/versions/f84258c4cb44_create_forecast_explanation_table.py @@ -0,0 +1,133 @@ +"""create forecast explanation table + +Revision ID: f84258c4cb44 +Revises: 7e8f9748581e +Create Date: 2026-05-19 11:46:00.062839 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision: str = 'f84258c4cb44' +down_revision: Union[str, None] = '7e8f9748581e' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Apply migration — create the forecast_explanation table.""" + op.create_table( + 'forecast_explanation', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('explanation_id', sa.String(length=32), nullable=False), + sa.Column('run_id', sa.String(length=32), nullable=True), + sa.Column('job_id', sa.String(length=32), nullable=True), + sa.Column('store_id', sa.Integer(), nullable=False), + sa.Column('product_id', sa.Integer(), nullable=False), + sa.Column('model_type', sa.String(length=50), nullable=False), + sa.Column('method', sa.String(length=20), nullable=False), + sa.Column('as_of_date', sa.Date(), nullable=False), + sa.Column('forecast_value', sa.Float(), nullable=False), + sa.Column('confidence', sa.String(length=10), nullable=False), + sa.Column('drivers', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('reason_codes', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('caveats', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('agent_summary', sa.String(length=2000), nullable=False), + sa.Column( + 'created_at', + sa.DateTime(timezone=True), + server_default=sa.text('now()'), + nullable=False, + ), + sa.Column( + 'updated_at', + sa.DateTime(timezone=True), + server_default=sa.text('now()'), + nullable=False, + ), + sa.CheckConstraint( + "confidence IN ('high', 'medium', 'low')", + name='ck_forecast_explanation_confidence', + ), + sa.CheckConstraint( + "method IN ('rule_based', 'shap', 'component')", + name='ck_forecast_explanation_method', + ), + sa.PrimaryKeyConstraint('id'), + ) + op.create_index( + op.f('ix_forecast_explanation_explanation_id'), + 'forecast_explanation', + ['explanation_id'], + unique=True, + ) + op.create_index( + 
op.f('ix_forecast_explanation_run_id'), + 'forecast_explanation', + ['run_id'], + unique=False, + ) + op.create_index( + op.f('ix_forecast_explanation_job_id'), + 'forecast_explanation', + ['job_id'], + unique=False, + ) + op.create_index( + op.f('ix_forecast_explanation_store_id'), + 'forecast_explanation', + ['store_id'], + unique=False, + ) + op.create_index( + op.f('ix_forecast_explanation_product_id'), + 'forecast_explanation', + ['product_id'], + unique=False, + ) + op.create_index( + 'ix_forecast_explanation_drivers_gin', + 'forecast_explanation', + ['drivers'], + unique=False, + postgresql_using='gin', + ) + op.create_index( + 'ix_forecast_explanation_store_product', + 'forecast_explanation', + ['store_id', 'product_id'], + unique=False, + ) + + +def downgrade() -> None: + """Revert migration — drop the forecast_explanation table.""" + op.drop_index( + 'ix_forecast_explanation_store_product', table_name='forecast_explanation' + ) + op.drop_index( + 'ix_forecast_explanation_drivers_gin', + table_name='forecast_explanation', + postgresql_using='gin', + ) + op.drop_index( + op.f('ix_forecast_explanation_product_id'), table_name='forecast_explanation' + ) + op.drop_index( + op.f('ix_forecast_explanation_store_id'), table_name='forecast_explanation' + ) + op.drop_index( + op.f('ix_forecast_explanation_job_id'), table_name='forecast_explanation' + ) + op.drop_index( + op.f('ix_forecast_explanation_run_id'), table_name='forecast_explanation' + ) + op.drop_index( + op.f('ix_forecast_explanation_explanation_id'), + table_name='forecast_explanation', + ) + op.drop_table('forecast_explanation') diff --git a/app/core/config.py b/app/core/config.py index cdbe5cf1..d3ac4a24 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -161,7 +161,10 @@ class Settings(BaseSettings): agent_retry_delay_seconds: float = 1.0 # Human-in-the-Loop Configuration - agent_require_approval: list[str] = ["create_alias", "archive_run"] + # ``save_scenario`` (PRP-27 Phase D) lets the 
experiment agent persist a + # scenario_plan row — a deliberate mutation-surface widening, so it is + # gated here exactly like create_alias / archive_run. + agent_require_approval: list[str] = ["create_alias", "archive_run", "save_scenario"] agent_approval_timeout_minutes: int = 60 # Session Configuration diff --git a/app/features/agents/agents/base.py b/app/features/agents/agents/base.py index 272e59e8..f9f4e1a0 100644 --- a/app/features/agents/agents/base.py +++ b/app/features/agents/agents/base.py @@ -284,6 +284,8 @@ def requires_approval(action_name: str) -> bool: - Use tool_compare_runs to analyze differences between registered runs - Use tool_create_alias to deploy successful models (requires approval) - Use tool_archive_run to clean up old experiments (requires approval) +- Use tool_propose_scenario to draft a candidate what-if scenario (read-only) +- Use tool_save_scenario to persist an approved scenario plan (requires approval) """ SAFETY_INSTRUCTIONS = """ diff --git a/app/features/agents/agents/experiment.py b/app/features/agents/agents/experiment.py index 7a2de475..d9bad5bd 100644 --- a/app/features/agents/agents/experiment.py +++ b/app/features/agents/agents/experiment.py @@ -38,6 +38,8 @@ get_run, list_runs, ) +from app.features.scenarios.agent_tools import propose_scenario, save_scenario +from app.features.scenarios.schemas import SaveScenarioRequest logger = structlog.get_logger() @@ -373,6 +375,111 @@ async def tool_archive_run( return await archive_run(db=ctx.deps.db, run_id=run_id) + @agent.tool + @recoverable + async def tool_propose_scenario( + ctx: RunContext[AgentDeps], + store_id: int, + product_id: int, + horizon: int = 14, + objective: str = "", + ) -> dict[str, Any]: + """Propose a candidate what-if scenario for a store / product. + + READ-ONLY: this drafts a candidate scenario; it persists nothing. To + save the proposal, call tool_save_scenario (which requires approval). + + Args: + store_id: Store the proposed scenario targets. 
+ product_id: Product the proposed scenario targets. + horizon: Number of days the proposed scenario should span (default 14). + objective: Free-text planning objective — keywords like 'promotion' + steer the proposal toward a promotion instead of a price cut. + + Returns: + A candidate scenario with assumptions and a recommendation. + """ + ctx.deps.increment_tool_calls() + logger.info( + "agents.experiment.tool_propose_scenario", + session_id=ctx.deps.session_id, + store_id=store_id, + product_id=product_id, + ) + return await propose_scenario( + db=ctx.deps.db, + store_id=store_id, + product_id=product_id, + horizon=horizon, + objective=objective, + ) + + @agent.tool + @recoverable + async def tool_save_scenario( + ctx: RunContext[AgentDeps], + name: str, + run_id: str, + store_id: int, + product_id: int, + horizon: int, + assumptions: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Persist a proposed what-if scenario as a saved scenario plan. + + REQUIRES HUMAN APPROVAL. This action writes a scenario_plan row. + + Use this only after tool_propose_scenario, passing back its candidate + assumptions. The plan is persisted with agent provenance and the + approval audit trail. + + Args: + name: Human-readable name for the saved plan. + run_id: Artifact key of the baseline model. + store_id: Store the scenario targets. + product_id: Product the scenario targets. + horizon: Number of days to simulate. + assumptions: The candidate assumptions dict from tool_propose_scenario. + + Returns: + The saved plan details, or an approval request. 
+ """ + ctx.deps.increment_tool_calls() + logger.info( + "agents.experiment.tool_save_scenario", + session_id=ctx.deps.session_id, + store_id=store_id, + product_id=product_id, + requires_approval=requires_approval("save_scenario"), + ) + + arguments: dict[str, Any] = { + "name": name, + "run_id": run_id, + "store_id": store_id, + "product_id": product_id, + "horizon": horizon, + "assumptions": assumptions or {}, + "source": "agent", + "agent_session_id": ctx.deps.session_id, + } + + # Check if approval is required — mirrors tool_create_alias exactly. + if requires_approval("save_scenario"): + return { + "status": "approval_required", + "action": "save_scenario", + "arguments": arguments, + "message": "This action requires human approval. Please approve to proceed.", + } + + request = SaveScenarioRequest.model_validate(arguments) + return await save_scenario( + db=ctx.deps.db, + request=request, + agent_session_id=ctx.deps.session_id, + ) + return agent diff --git a/app/features/agents/service.py b/app/features/agents/service.py index 751c984c..1b3c4644 100644 --- a/app/features/agents/service.py +++ b/app/features/agents/service.py @@ -869,6 +869,8 @@ async def _execute_pending_action( ValueError: If action_type is not recognized. """ from app.features.agents.tools.registry_tools import archive_run, create_alias + from app.features.scenarios.agent_tools import save_scenario + from app.features.scenarios.schemas import SaveScenarioRequest if action_type == "create_alias": alias_name = arguments.get("alias_name", "") @@ -886,7 +888,17 @@ async def _execute_pending_action( if result is None: raise ValueError(f"Run not found: {run_id}") return result + elif action_type == "save_scenario": + # The HITL gate has released the agent's save_scenario call — persist + # the scenario_plan row now, stamped with the approved audit trail. 
+ request = SaveScenarioRequest.model_validate(arguments) + return await save_scenario( + db=db, + request=request, + agent_session_id=arguments.get("agent_session_id"), + ) else: raise ValueError( - f"Unknown action type: {action_type}. Supported actions: create_alias, archive_run" + f"Unknown action type: {action_type}. Supported actions: " + "create_alias, archive_run, save_scenario" ) diff --git a/app/features/analytics/routes.py b/app/features/analytics/routes.py index 762d5204..161b655a 100644 --- a/app/features/analytics/routes.py +++ b/app/features/analytics/routes.py @@ -6,11 +6,12 @@ from datetime import date -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import get_settings from app.core.database import get_db +from app.core.exceptions import BadRequestError from app.core.logging import get_logger from app.features.analytics.schemas import ( DrilldownDimension, @@ -40,23 +41,23 @@ def validate_date_range(start_date: date, end_date: date) -> None: end_date: End of analysis period. Raises: - HTTPException: If date range is invalid. + BadRequestError: If date range is invalid. Surfaces as an RFC 7807 + ``application/problem+json`` 400 via the registered handler — a + raw ``HTTPException`` would bypass the problem-details envelope. 
""" settings = get_settings() if end_date < start_date: - raise HTTPException( - status_code=400, - detail=f"end_date ({end_date}) must be >= start_date ({start_date})", + raise BadRequestError( + message=f"end_date ({end_date}) must be >= start_date ({start_date})", ) days_diff = (end_date - start_date).days max_days = settings.analytics_max_date_range_days if days_diff > max_days: - raise HTTPException( - status_code=400, - detail=f"Date range ({days_diff} days) exceeds maximum allowed ({max_days} days)", + raise BadRequestError( + message=f"Date range ({days_diff} days) exceeds maximum allowed ({max_days} days)", ) @@ -108,10 +109,12 @@ async def get_kpis( ), store_id: int | None = Query( None, + ge=1, description="Filter by store ID. Use GET /dimensions/stores to find valid IDs.", ), product_id: int | None = Query( None, + ge=1, description="Filter by product ID. Use GET /dimensions/products to find valid IDs.", ), category: str | None = Query( @@ -134,7 +137,7 @@ async def get_kpis( Aggregated KPI metrics. Raises: - HTTPException: If date range is invalid. + BadRequestError: If date range is invalid (RFC 7807 400). """ # Validate date range before processing validate_date_range(start_date, end_date) @@ -206,10 +209,12 @@ async def get_drilldowns( ), store_id: int | None = Query( None, + ge=1, description="Filter by store ID. Use GET /dimensions/stores to find valid IDs.", ), product_id: int | None = Query( None, + ge=1, description="Filter by product ID. Use GET /dimensions/products to find valid IDs.", ), max_items: int = Query( @@ -235,7 +240,7 @@ async def get_drilldowns( Drilldown analysis with ranked items. Raises: - HTTPException: If date range is invalid. + BadRequestError: If date range is invalid (RFC 7807 400). """ # Validate date range before processing validate_date_range(start_date, end_date) @@ -301,10 +306,12 @@ async def get_timeseries( ), store_id: int | None = Query( None, + ge=1, description="Filter by store ID. 
Use GET /dimensions/stores to find valid IDs.", ), product_id: int | None = Query( None, + ge=1, description="Filter by product ID. Use GET /dimensions/products to find valid IDs.", ), category: str | None = Query( @@ -328,7 +335,7 @@ async def get_timeseries( Time series response with points in ascending period order. Raises: - HTTPException: If date range is invalid. + BadRequestError: If date range is invalid (RFC 7807 400). """ # Validate date range before processing validate_date_range(start_date, end_date) @@ -380,10 +387,12 @@ async def get_timeseries( async def get_inventory_status( store_id: int | None = Query( None, + ge=1, description="Filter by store ID. Use GET /dimensions/stores to find valid IDs.", ), product_id: int | None = Query( None, + ge=1, description="Filter by product ID. Use GET /dimensions/products to find valid IDs.", ), db: AsyncSession = Depends(get_db), diff --git a/app/features/analytics/tests/conftest.py b/app/features/analytics/tests/conftest.py index b101fabc..a6aa09cd 100644 --- a/app/features/analytics/tests/conftest.py +++ b/app/features/analytics/tests/conftest.py @@ -7,7 +7,7 @@ import pytest from httpx import ASGITransport, AsyncClient -from sqlalchemy import delete +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from app.core.config import get_settings @@ -122,11 +122,25 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]: try: yield session finally: - # Clean up test data (delete in FK-safe order). InventorySnapshotDaily - # FK-references store/product/calendar, so it must be cleared before - # the Store/Product/Calendar deletes below. - await session.execute(delete(InventorySnapshotDaily)) - await session.execute(delete(SalesDaily)) + # Clean up test data (delete in FK-safe order). Scope the fact-table + # deletes to TEST-prefixed stores/products so a shared dev or + # integration dataset is never wiped. 
InventorySnapshotDaily / + # SalesDaily FK-reference store/product/calendar, so they must be + # cleared before the Store/Product/Calendar deletes below. + test_store_ids = select(Store.id).where(Store.code.like("TEST-%")) + test_product_ids = select(Product.id).where(Product.sku.like("TEST-%")) + await session.execute( + delete(InventorySnapshotDaily).where( + InventorySnapshotDaily.store_id.in_(test_store_ids) + | InventorySnapshotDaily.product_id.in_(test_product_ids) + ) + ) + await session.execute( + delete(SalesDaily).where( + SalesDaily.store_id.in_(test_store_ids) + | SalesDaily.product_id.in_(test_product_ids) + ) + ) await session.execute(delete(Product).where(Product.sku.like("TEST-%"))) await session.execute(delete(Store).where(Store.code.like("TEST-%"))) await session.execute( @@ -154,7 +168,9 @@ async def override_get_db() -> AsyncGenerator[AsyncSession, None]: ) as ac: yield ac - app.dependency_overrides.clear() + # Remove only this fixture's override — clear() would also drop overrides + # installed by other fixtures sharing the app instance. 
+ app.dependency_overrides.pop(get_db, None) @pytest.fixture diff --git a/app/features/config/tests/test_service.py b/app/features/config/tests/test_service.py index e97009c1..1bcdfd99 100644 --- a/app/features/config/tests/test_service.py +++ b/app/features/config/tests/test_service.py @@ -101,19 +101,34 @@ async def test_get_effective_config_masks_secrets(self): async def test_get_effective_config_maps_agent_limits(self): """The agent session-limit fields are sourced from the Settings singleton.""" settings = get_settings() - settings.agent_max_tool_calls = 7 - settings.agent_timeout_seconds = 99 - settings.agent_retry_attempts = 2 - settings.agent_session_ttl_minutes = 45 - settings.agent_require_approval = ["create_alias"] - - config = await service.get_effective_config(_mock_db()) - - assert config.agent_max_tool_calls == 7 - assert config.agent_timeout_seconds == 99 - assert config.agent_retry_attempts == 2 - assert config.agent_session_ttl_minutes == 45 - assert config.agent_require_approval == ["create_alias"] + # get_settings() returns a cached singleton — snapshot every field this + # test mutates and restore it in a finally block so the mutation never + # leaks into another test. 
+ fields = ( + "agent_max_tool_calls", + "agent_timeout_seconds", + "agent_retry_attempts", + "agent_session_ttl_minutes", + "agent_require_approval", + ) + original = {field: getattr(settings, field) for field in fields} + try: + settings.agent_max_tool_calls = 7 + settings.agent_timeout_seconds = 99 + settings.agent_retry_attempts = 2 + settings.agent_session_ttl_minutes = 45 + settings.agent_require_approval = ["create_alias"] + + config = await service.get_effective_config(_mock_db()) + + assert config.agent_max_tool_calls == 7 + assert config.agent_timeout_seconds == 99 + assert config.agent_retry_attempts == 2 + assert config.agent_session_ttl_minutes == 45 + assert config.agent_require_approval == ["create_alias"] + finally: + for field, value in original.items(): + setattr(settings, field, value) # ============================================================================= diff --git a/app/features/dimensions/service.py b/app/features/dimensions/service.py index 54854aba..a77af056 100644 --- a/app/features/dimensions/service.py +++ b/app/features/dimensions/service.py @@ -108,9 +108,11 @@ async def list_stores( else: order_by = Store.code.asc() - # Apply pagination + # Apply pagination. Append the unique `code` as a tie-breaker so rows + # with equal sort values keep a stable order across pages (offset + # pagination over a non-unique sort key is otherwise non-deterministic). offset = (page - 1) * page_size - stmt = stmt.order_by(order_by).offset(offset).limit(page_size) + stmt = stmt.order_by(order_by, Store.code.asc()).offset(offset).limit(page_size) # Execute query result = await db.execute(stmt) @@ -233,9 +235,11 @@ async def list_products( else: order_by = Product.sku.asc() - # Apply pagination + # Apply pagination. Append the unique `sku` as a tie-breaker so rows + # with equal sort values keep a stable order across pages (offset + # pagination over a non-unique sort key is otherwise non-deterministic). 
offset = (page - 1) * page_size - stmt = stmt.order_by(order_by).offset(offset).limit(page_size) + stmt = stmt.order_by(order_by, Product.sku.asc()).offset(offset).limit(page_size) # Execute query result = await db.execute(stmt) diff --git a/app/features/dimensions/tests/conftest.py b/app/features/dimensions/tests/conftest.py index 9befdffe..fb9b12a5 100644 --- a/app/features/dimensions/tests/conftest.py +++ b/app/features/dimensions/tests/conftest.py @@ -7,7 +7,7 @@ import pytest from httpx import ASGITransport, AsyncClient -from sqlalchemy import delete +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from app.core.config import get_settings @@ -66,8 +66,17 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]: try: yield session finally: - # Clean up test data (delete in FK-safe order). - await session.execute(delete(SalesDaily)) + # Clean up test data (delete in FK-safe order). Scope the SalesDaily + # delete to TEST-prefixed stores/products so a shared dev or + # integration dataset is never wiped. + test_store_ids = select(Store.id).where(Store.code.like("TEST-%")) + test_product_ids = select(Product.id).where(Product.sku.like("TEST-%")) + await session.execute( + delete(SalesDaily).where( + SalesDaily.store_id.in_(test_store_ids) + | SalesDaily.product_id.in_(test_product_ids) + ) + ) await session.execute(delete(Product).where(Product.sku.like("TEST-%"))) await session.execute(delete(Store).where(Store.code.like("TEST-%"))) await session.execute( @@ -95,7 +104,9 @@ async def override_get_db() -> AsyncGenerator[AsyncSession, None]: ) as ac: yield ac - app.dependency_overrides.clear() + # Remove only this fixture's override — clear() would also drop overrides + # installed by other fixtures sharing the app instance. 
+ app.dependency_overrides.pop(get_db, None) @pytest.fixture diff --git a/app/features/explainability/__init__.py b/app/features/explainability/__init__.py new file mode 100644 index 00000000..ae3fa227 --- /dev/null +++ b/app/features/explainability/__init__.py @@ -0,0 +1,5 @@ +"""Forecast explainability & driver-attribution vertical slice (PRP-28). + +Rule-based, deterministic explanations for the three baseline forecasters. +SHAP is deliberately out of scope — see PRP-28. +""" diff --git a/app/features/explainability/explainers.py b/app/features/explainability/explainers.py new file mode 100644 index 00000000..8c4b9d3e --- /dev/null +++ b/app/features/explainability/explainers.py @@ -0,0 +1,272 @@ +"""Rule-based, deterministic explainers for the three baseline forecasters. + +Each explainer MIRRORS the exact h=1 math of the matching forecaster in +``app/features/forecasting/models.py`` — a rule-based explainer is *exact*, not +an approximation. ``test_explainers.py`` asserts each explainer's forecast value +equals the real forecaster's ``.fit(y).predict(1)[0]`` on the same series. + +A driver with ``contribution == 0.0`` is informational context only — the +baseline model does not consume it. The sum of all driver contributions equals +the forecast value. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +import numpy as np + +from app.features.explainability.schemas import ( + ConfidenceLevel, + Direction, + DriverContribution, +) + +# A 1-D float series, matching the forecasters' target-array type. +FloatArray = np.ndarray[Any, np.dtype[np.floating[Any]]] + +# Below this many observations a naive explanation is treated as low-confidence. 
+_NAIVE_MIN_COMFORTABLE = 14 + + +def _direction(value: float) -> Direction: + """Map a signed value to a driver direction literal.""" + if value > 0: + return "positive" + if value < 0: + return "negative" + return "neutral" + + +class BaseExplainer(ABC): + """Abstract base for a rule-based forecast explainer.""" + + @abstractmethod + def explain(self, y: FloatArray) -> tuple[float, list[DriverContribution]]: + """Decompose the h=1 forecast into named driver contributions. + + Args: + y: The historical target series (time-ordered, ``<= as_of_date``). + + Returns: + The h=1 forecast value and its ordered driver contributions. + + Raises: + ValueError: If the series is too short to produce the forecast. + """ + + @abstractmethod + def confidence(self, y: FloatArray) -> ConfidenceLevel: + """Return a qualitative confidence band for the explanation. + + Args: + y: The historical target series. + + Returns: + The confidence band. + """ + + +class NaiveExplainer(BaseExplainer): + """Explainer for the naive forecaster — the forecast IS the last value.""" + + def explain(self, y: FloatArray) -> tuple[float, list[DriverContribution]]: + """Decompose the naive h=1 forecast. + + Args: + y: The historical target series. + + Returns: + The h=1 forecast (``y[-1]``) and its driver contributions. + + Raises: + ValueError: If ``y`` is empty. 
+ """ + if len(y) == 0: + raise ValueError("Cannot explain an empty series") + forecast = float(y[-1]) + drivers = [ + DriverContribution( + name="last_observation", + feature_value=forecast, + contribution=forecast, + direction="positive", + description="The naive forecast is exactly the last observed value.", + ) + ] + if len(y) >= _NAIVE_MIN_COMFORTABLE: + trend = float(np.mean(y[-7:]) - np.mean(y[-14:-7])) + drivers.append( + DriverContribution( + name="recent_trend", + feature_value=trend, + contribution=0.0, + direction=_direction(trend), + description=( + "Context only — week-over-week change in mean demand; " + "the naive model does not use trend." + ), + ) + ) + return forecast, drivers + + def confidence(self, y: FloatArray) -> ConfidenceLevel: + """Return ``LOW`` for a short series, otherwise ``MEDIUM``.""" + if len(y) < _NAIVE_MIN_COMFORTABLE: + return ConfidenceLevel.LOW + return ConfidenceLevel.MEDIUM + + +class SeasonalNaiveExplainer(BaseExplainer): + """Explainer for the seasonal-naive forecaster — last season's value.""" + + def __init__(self, season_length: int = 7) -> None: + """Initialise the explainer. + + Args: + season_length: Seasonal period in days (must be >= 1). + + Raises: + ValueError: If ``season_length`` < 1. + """ + if season_length < 1: + raise ValueError(f"season_length must be >= 1, got {season_length}") + self.season_length = season_length + + def explain(self, y: FloatArray) -> tuple[float, list[DriverContribution]]: + """Decompose the seasonal-naive h=1 forecast. + + Args: + y: The historical target series. + + Returns: + The h=1 forecast (``y[-season_length]``) and its driver contributions. + + Raises: + ValueError: If ``y`` has fewer observations than ``season_length``. 
+ """ + if len(y) < self.season_length: + raise ValueError(f"Need at least {self.season_length} observations") + forecast = float(y[-self.season_length]) + drivers = [ + DriverContribution( + name="season_match", + feature_value=forecast, + contribution=forecast, + direction="positive", + description=( + f"The forecast repeats the value observed {self.season_length} " + "days ago (one seasonal cycle back)." + ), + ) + ] + return forecast, drivers + + def confidence(self, y: FloatArray) -> ConfidenceLevel: + """Return ``LOW`` for under two seasonal cycles, otherwise ``MEDIUM``.""" + if len(y) < 2 * self.season_length: + return ConfidenceLevel.LOW + return ConfidenceLevel.MEDIUM + + +class MovingAverageExplainer(BaseExplainer): + """Explainer for the moving-average forecaster — mean of the last window.""" + + def __init__(self, window_size: int = 7) -> None: + """Initialise the explainer. + + Args: + window_size: Averaging window in days (must be >= 1). + + Raises: + ValueError: If ``window_size`` < 1. + """ + if window_size < 1: + raise ValueError(f"window_size must be >= 1, got {window_size}") + self.window_size = window_size + + def explain(self, y: FloatArray) -> tuple[float, list[DriverContribution]]: + """Decompose the moving-average h=1 forecast. + + Args: + y: The historical target series. + + Returns: + The h=1 forecast (``mean(y[-window_size:])``) and driver contributions. + + Raises: + ValueError: If ``y`` has fewer observations than ``window_size``. + """ + if len(y) < self.window_size: + raise ValueError(f"Need at least {self.window_size} observations") + window = y[-self.window_size :] + forecast = float(np.mean(window)) + dispersion = float(np.std(window)) + drivers = [ + DriverContribution( + name="window_mean", + feature_value=forecast, + contribution=forecast, + direction="positive", + description=( + f"The forecast is the mean of the last {self.window_size} observed values." 
+ ), + ), + DriverContribution( + name="window_dispersion", + feature_value=dispersion, + contribution=0.0, + direction="neutral", + description=( + "Context only — standard deviation within the averaging " + "window; higher values mean a noisier, less reliable mean." + ), + ), + ] + return forecast, drivers + + def confidence(self, y: FloatArray) -> ConfidenceLevel: + """Return ``HIGH`` for a stable full window, ``MEDIUM``/``LOW`` otherwise.""" + if len(y) < self.window_size: + return ConfidenceLevel.LOW + window = y[-self.window_size :] + mean = float(np.mean(window)) + std = float(np.std(window)) + cv = std / mean if mean > 0 else 0.0 + if cv < 0.5: + return ConfidenceLevel.HIGH + return ConfidenceLevel.MEDIUM + + +def explainer_factory( + model_type: str, + season_length: int | None = None, + window_size: int | None = None, +) -> BaseExplainer: + """Build the rule-based explainer for a baseline model type. + + Args: + model_type: One of ``naive``, ``seasonal_naive``, ``moving_average``. + season_length: Seasonal period for ``seasonal_naive`` (defaults to 7). + window_size: Averaging window for ``moving_average`` (defaults to 7). + + Returns: + The matching explainer instance. + + Raises: + ValueError: For ``lightgbm``/``regression`` (MVP scope guard) or an + unknown model type. + """ + if model_type == "naive": + return NaiveExplainer() + if model_type == "seasonal_naive": + return SeasonalNaiveExplainer(season_length=season_length or 7) + if model_type == "moving_average": + return MovingAverageExplainer(window_size=window_size or 7) + if model_type in ("lightgbm", "regression"): + raise ValueError( + f"Explanations are available for baseline models only; " + f"'{model_type}' is not supported (rule-based MVP)." 
+ ) + raise ValueError(f"Unknown model type: {model_type}") diff --git a/app/features/explainability/models.py b/app/features/explainability/models.py new file mode 100644 index 00000000..e12538c0 --- /dev/null +++ b/app/features/explainability/models.py @@ -0,0 +1,81 @@ +"""ORM model for the explainability slice. + +A ``forecast_explanation`` row persists one rule-based explanation: the driver +breakdown, advisory reason codes, and caveats as JSONB, plus scalar columns for +the forecast context. Persisting it means a re-requested explanation is a cheap +read and gives the slice an audit trail. + +GOTCHA: SQLAlchemy reserves the declarative attribute name ``metadata`` — no +column here uses it. +""" + +from __future__ import annotations + +import datetime +from typing import Any + +from sqlalchemy import CheckConstraint, Date, Float, Index, Integer, String +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column + +from app.core.database import Base +from app.shared.models import TimestampMixin + + +class ForecastExplanation(TimestampMixin, Base): + """A persisted rule-based forecast explanation. + + Attributes: + id: Surrogate primary key. + explanation_id: Unique external identifier (UUID hex, 32 chars). + run_id: Originating registry run, when explained via ``/explain/runs``. + job_id: Originating predict job, when explained via ``/explain/jobs``. + store_id: Store the forecast targets. + product_id: Product the forecast targets. + model_type: Baseline model type explained. + method: Explanation method — always ``rule_based`` for the MVP. + as_of_date: Series cutoff date. + forecast_value: The h=1 forecast value. + confidence: Qualitative confidence band (``high|medium|low``). + drivers: Driver contributions as JSONB. + reason_codes: Advisory reason codes as JSONB. + caveats: Plain-language caveats as a JSONB string array. + agent_summary: One-paragraph natural-language summary. 
+ """ + + __tablename__ = "forecast_explanation" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + explanation_id: Mapped[str] = mapped_column(String(32), unique=True, index=True) + run_id: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True) + job_id: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True) + store_id: Mapped[int] = mapped_column(Integer, index=True) + product_id: Mapped[int] = mapped_column(Integer, index=True) + model_type: Mapped[str] = mapped_column(String(50)) + method: Mapped[str] = mapped_column(String(20), default="rule_based") + as_of_date: Mapped[datetime.date] = mapped_column(Date) + forecast_value: Mapped[float] = mapped_column(Float) + confidence: Mapped[str] = mapped_column(String(10)) + + # JSONB blobs — never named ``metadata`` (SQLAlchemy reserves it). + drivers: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False) + reason_codes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False) + caveats: Mapped[list[str]] = mapped_column(JSONB, nullable=False) + + agent_summary: Mapped[str] = mapped_column(String(2000)) + + __table_args__ = ( + # GIN index for JSONB containment queries on the driver breakdown. + Index("ix_forecast_explanation_drivers_gin", "drivers", postgresql_using="gin"), + # Composite index for the common "explanations for this store/product" query. + Index("ix_forecast_explanation_store_product", "store_id", "product_id"), + # Kept in lock-step with the alembic migration that created this table. 
+ CheckConstraint( + "confidence IN ('high', 'medium', 'low')", + name="ck_forecast_explanation_confidence", + ), + CheckConstraint( + "method IN ('rule_based', 'shap', 'component')", + name="ck_forecast_explanation_method", + ), + ) diff --git a/app/features/explainability/reason_codes.py b/app/features/explainability/reason_codes.py new file mode 100644 index 00000000..d047abbe --- /dev/null +++ b/app/features/explainability/reason_codes.py @@ -0,0 +1,186 @@ +"""Advisory retail reason-code engine for the explainability slice. + +These are PURE functions — they perform no database access and take only +primitive inputs. The service layer runs the time-safe queries and extracts the +primitives; every input below therefore reflects only data the caller already +bounded ``<= as_of_date`` (or, for ``holiday_reason``, the explained horizon +date). + +CRITICAL: a reason code is an advisory *correlation* signal, never a causal +claim. ``build_caveats`` always emits the NIST-grounded correlation-vs-causation +disclaimer. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from datetime import date as date_type + +from app.features.explainability.schemas import ReasonCode + + +def stockout_reason(stockout_flags: Sequence[bool]) -> ReasonCode | None: + """Flag stockout-suppressed history in the trailing window. + + Args: + stockout_flags: One ``is_stockout`` flag per day in the trailing window + (already bounded ``<= as_of_date`` by the caller). + + Returns: + A ``stockout_constrained`` warning when any day was a stockout, + otherwise ``None``. + """ + stockout_days = sum(1 for flag in stockout_flags if flag) + if stockout_days == 0: + return None + return ReasonCode( + code="stockout_constrained", + severity="warn", + detail=( + f"{stockout_days} stockout day(s) in the trailing " + f"{len(stockout_flags)}-day window — observed demand may understate " + "true demand because units could not be sold while out of stock." 
+ ), + ) + + +def promotion_reason( + promotion_windows: Sequence[tuple[date_type, date_type]], as_of_date: date_type +) -> ReasonCode | None: + """Flag promotions overlapping the trailing window. + + Args: + promotion_windows: ``(start_date, end_date)`` tuples for promotions + overlapping the trailing window (already bounded ``<= as_of_date``). + as_of_date: The series cutoff date. + + Returns: + A ``promotion_overlap`` info code when any promotion overlaps, + otherwise ``None``. + """ + if not promotion_windows: + return None + active_now = sum(1 for start, end in promotion_windows if start <= as_of_date <= end) + active_clause = f"; {active_now} still active on {as_of_date.isoformat()}" if active_now else "" + return ReasonCode( + code="promotion_overlap", + severity="info", + detail=( + f"{len(promotion_windows)} promotion(s) overlap the trailing window" + f"{active_clause} — promotional demand may not represent the baseline." + ), + ) + + +def lifecycle_reason(launch_date: date_type | None, as_of_date: date_type) -> ReasonCode | None: + """Flag a product still in its early lifecycle. + + Args: + launch_date: The product's launch date, or ``None`` if unknown. + as_of_date: The series cutoff date. + + Returns: + A ``lifecycle_decay`` info code when the product launched fewer than + 30 days before ``as_of_date``, otherwise ``None``. + """ + if launch_date is None: + return None + days_since_launch = (as_of_date - launch_date).days + if 0 <= days_since_launch < 30: + return ReasonCode( + code="lifecycle_decay", + severity="info", + detail=( + f"Product launched {days_since_launch} day(s) ago — early-" + "lifecycle demand is volatile and may not represent a stable " + "baseline." + ), + ) + return None + + +def holiday_reason( + is_holiday: bool, holiday_name: str | None, forecast_date: date_type +) -> ReasonCode | None: + """Flag a holiday landing on the explained forecast horizon. + + Args: + is_holiday: Whether ``forecast_date`` is flagged as a holiday. 
+ holiday_name: The holiday's name, when known. + forecast_date: The date of the explained h=1 forecast. + + Returns: + A ``holiday_effect`` info code when ``forecast_date`` is a holiday, + otherwise ``None``. + """ + if not is_holiday: + return None + name = holiday_name or "a holiday" + return ReasonCode( + code="holiday_effect", + severity="info", + detail=( + f"The forecast date {forecast_date.isoformat()} is {name} — " + "holiday demand typically deviates from a normal day." + ), + ) + + +def history_reason(n_obs: int, min_required: int) -> ReasonCode | None: + """Flag a series too short for a comfortable explanation. + + Args: + n_obs: Number of observations in the series. + min_required: Minimum comfortable observation count for the model. + + Returns: + An ``insufficient_history`` warning when ``n_obs < min_required``, + otherwise ``None``. + """ + if n_obs < min_required: + return ReasonCode( + code="insufficient_history", + severity="warn", + detail=( + f"Only {n_obs} observation(s) available; {min_required} or more " + "is recommended for a confident explanation." + ), + ) + return None + + +# The NIST-grounded disclaimer baked into every explanation (see +# https://www.nist.gov/itl/ai-risk-management-framework). +CORRELATION_CAVEAT = ( + "Drivers describe correlation and contribution, not business causality — " + "they explain the model's arithmetic, not why demand moved." +) + +_MODEL_CAVEATS: dict[str, str] = { + "naive": "The naive model ignores seasonality and trend entirely.", + "seasonal_naive": "The seasonal-naive model assumes the prior cycle repeats exactly.", + "moving_average": "The moving-average model smooths over recent shifts in demand.", +} + + +def build_caveats(model_type: str, reason_codes: Sequence[ReasonCode]) -> list[str]: + """Assemble the caveat list for an explanation. + + Args: + model_type: The baseline model type explained. + reason_codes: The reason codes already computed for the explanation. 
+ + Returns: + Plain-language caveats, always starting with the correlation-vs- + causation disclaimer. + """ + caveats = [CORRELATION_CAVEAT] + model_caveat = _MODEL_CAVEATS.get(model_type) + if model_caveat is not None: + caveats.append(model_caveat) + codes = {rc.code for rc in reason_codes} + if "stockout_constrained" in codes: + caveats.append("Stockout days in the history mean the forecast may understate true demand.") + if "insufficient_history" in codes: + caveats.append("The short history makes this explanation less reliable than usual.") + return caveats diff --git a/app/features/explainability/routes.py b/app/features/explainability/routes.py new file mode 100644 index 00000000..ae93e2bc --- /dev/null +++ b/app/features/explainability/routes.py @@ -0,0 +1,169 @@ +"""API routes for the explainability slice. + +Three endpoints under a self-owned ``/explain`` namespace produce rule-based +forecast explanations. Service-layer ``ValueError`` (unsupported model type, +too-short series) maps to an RFC 7807 400; a missing run/job maps to a 404; +``SQLAlchemyError`` maps to a 500 — never a bare ``HTTPException``. 
+""" + +from fastapi import APIRouter, Depends, status +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.core.exceptions import BadRequestError, DatabaseError, NotFoundError +from app.core.logging import get_logger +from app.features.explainability.schemas import ( + ExplainForecastRequest, + ForecastExplanation, +) +from app.features.explainability.service import ExplainabilityService + +logger = get_logger(__name__) + +router = APIRouter(prefix="/explain", tags=["explainability"]) + + +@router.post( + "/forecast", + response_model=ForecastExplanation, + status_code=status.HTTP_200_OK, + summary="Explain an ad-hoc baseline forecast", + description=""" +Compute a rule-based explanation for the h=1 forecast a named baseline model +would produce on the series ending at `as_of_date`. + +**Inputs:** `store_id`, `product_id`, `model_type` (`naive` / `seasonal_naive` / +`moving_average`), `as_of_date`, and optional `season_length` / `window_size`. + +**Output:** a `ForecastExplanation` — ordered driver contributions, advisory +retail reason codes (correlation, never causation), a confidence band, caveats, +and an agent-readable summary. The series and every reason-code query are +time-safe (`<= as_of_date`). + +An unsupported `model_type` (`lightgbm` / `regression`) or a series too short to +forecast returns an RFC 7807 400 — never a 500. +""", +) +async def explain_forecast( + request: ExplainForecastRequest, + db: AsyncSession = Depends(get_db), +) -> ForecastExplanation: + """Explain an ad-hoc baseline forecast. + + Args: + request: Store/product/model/cutoff parameters. + db: Async database session from dependency. + + Returns: + The rule-based forecast explanation. + + Raises: + BadRequestError: For an unsupported model type or a too-short series. + DatabaseError: When persistence fails. 
+ """ + try: + return await ExplainabilityService().explain_forecast(db, request) + except ValueError as exc: + logger.warning("explainability.forecast_invalid", error=str(exc)) + raise BadRequestError(message=str(exc)) from exc + except SQLAlchemyError as exc: + logger.error("explainability.forecast_db_error", error=str(exc), exc_info=True) + raise DatabaseError( + message="Failed to generate forecast explanation", + details={"error": str(exc)}, + ) from exc + + +@router.get( + "/runs/{run_id}", + response_model=ForecastExplanation, + summary="Explain a registry model run", + description=""" +Explain a registry `model_run`. The baseline config is reconstructed from +`model_run.model_config`, and `data_window_end` is used as the series cutoff. + +A missing `run_id` returns a 404; a non-baseline run (`lightgbm` / `regression`) +returns a 400 — explanations are available for baseline models only. +""", +) +async def explain_run( + run_id: str, + db: AsyncSession = Depends(get_db), +) -> ForecastExplanation: + """Explain a registry model run. + + Args: + run_id: External run identifier. + db: Async database session from dependency. + + Returns: + The rule-based forecast explanation. + + Raises: + NotFoundError: When no run matches ``run_id``. + BadRequestError: For a non-baseline run or a too-short series. + DatabaseError: When persistence fails. 
+ """ + try: + explanation = await ExplainabilityService().explain_run(db, run_id) + except ValueError as exc: + logger.warning("explainability.run_invalid", run_id=run_id, error=str(exc)) + raise BadRequestError(message=str(exc)) from exc + except SQLAlchemyError as exc: + logger.error("explainability.run_db_error", error=str(exc), exc_info=True) + raise DatabaseError( + message="Failed to generate run explanation", + details={"error": str(exc)}, + ) from exc + if explanation is None: + raise NotFoundError(message=f"Model run not found: {run_id}") + return explanation + + +@router.get( + "/jobs/{job_id}", + response_model=ForecastExplanation, + summary="Explain a completed predict job", + description=""" +Explain a completed `predict` job. `store_id`, `product_id`, and `model_type` +are read from `job.result`; the series cutoff is the day before the first +forecast date. + +A missing `job_id` returns a 404; a job that is not a completed predict job +returns a 400. +""", +) +async def explain_job( + job_id: str, + db: AsyncSession = Depends(get_db), +) -> ForecastExplanation: + """Explain a completed predict job. + + Args: + job_id: External job identifier. + db: Async database session from dependency. + + Returns: + The rule-based forecast explanation. + + Raises: + NotFoundError: When no job matches ``job_id``. + BadRequestError: When the job is not a completed predict job, or for a + too-short series. + DatabaseError: When persistence fails. 
+ """ + try: + explanation = await ExplainabilityService().explain_job(db, job_id) + except ValueError as exc: + logger.warning("explainability.job_invalid", job_id=job_id, error=str(exc)) + raise BadRequestError(message=str(exc)) from exc + except SQLAlchemyError as exc: + logger.error("explainability.job_db_error", error=str(exc), exc_info=True) + raise DatabaseError( + message="Failed to generate job explanation", + details={"error": str(exc)}, + ) from exc + if explanation is None: + raise NotFoundError(message=f"Job not found: {job_id}") + return explanation diff --git a/app/features/explainability/schemas.py b/app/features/explainability/schemas.py new file mode 100644 index 00000000..9c767a09 --- /dev/null +++ b/app/features/explainability/schemas.py @@ -0,0 +1,147 @@ +"""Pydantic v2 schemas for the explainability slice. + +The response schemas (``DriverContribution``, ``ReasonCode``, +``ForecastExplanation``) are plain ``BaseModel`` — NOT ``strict=True`` — so they +serialise cleanly. The single request body (``ExplainForecastRequest``) IS +``strict=True``; its ``as_of_date`` field therefore carries ``Field(strict=False)`` +because ``date`` has no native JSON representation (see ``docs/_base/SECURITY.md`` +-> "Pydantic v2 strict mode on FastAPI request bodies"; enforced by +``app/core/tests/test_strict_mode_policy.py``). +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from datetime import date as date_type +from enum import Enum +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field + +# Direction of a driver's influence on the forecast. +Direction = Literal["positive", "negative", "neutral"] + +# Advisory retail reason-code identifiers — correlation signals, never causal claims. +ReasonCodeId = Literal[ + "stockout_constrained", + "promotion_overlap", + "holiday_effect", + "lifecycle_decay", + "trend_shift", + "insufficient_history", +] + +# Baseline model types this slice can explain. 
``lightgbm``/``regression`` are +# rejected with a clean 400 (MVP scope guard). +ExplainableModelType = Literal["naive", "seasonal_naive", "moving_average"] + + +class ConfidenceLevel(str, Enum): + """Qualitative confidence band for an explanation.""" + + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + + +class DriverContribution(BaseModel): + """One named, interpretable demand driver behind a forecast. + + Attributes: + name: Stable machine identifier for the driver. + feature_value: The observed value of the underlying feature. + contribution: Amount (in model units) this driver adds to the forecast. + Informational/context drivers carry ``0.0``. + direction: Sign of the driver's influence. + description: Human-readable explanation of the driver. + """ + + name: str + feature_value: float + contribution: float + direction: Direction + description: str + + +class ReasonCode(BaseModel): + """An advisory retail signal correlated with the forecast. + + CRITICAL: reason codes describe correlation, never business causality. + + Attributes: + code: Machine-readable reason-code identifier. + severity: ``info`` for context, ``warn`` for a quality caveat. + detail: Human-readable detail for the signal. + """ + + code: ReasonCodeId + severity: Literal["info", "warn"] + detail: str + + +class ForecastExplanation(BaseModel): + """Structured, rule-based explanation of a baseline h=1 forecast. + + Attributes: + store_id: Store the forecast targets. + product_id: Product the forecast targets. + model_type: Baseline model type explained. + method: Always ``rule_based`` for the MVP (``shap``/``component`` reserved). + forecast_value: The h=1 forecast the baseline model produces. + drivers: Ordered, named driver contributions. + reason_codes: Advisory retail reason codes (correlation only). + confidence: Qualitative confidence band. + caveats: Plain-language caveats, always including the correlation-vs- + causation disclaimer. 
+ agent_summary: One-paragraph natural-language summary for chat agents. + as_of_date: Series cutoff — no data past this date informs the explanation. + generated_at: UTC timestamp the explanation was produced. + """ + + model_config = ConfigDict(from_attributes=True) + + store_id: int + product_id: int + model_type: str + method: Literal["rule_based"] = "rule_based" + forecast_value: float + drivers: list[DriverContribution] + reason_codes: list[ReasonCode] + confidence: ConfidenceLevel + caveats: list[str] + agent_summary: str + as_of_date: date_type + generated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + + +class ExplainForecastRequest(BaseModel): + """Request body for ``POST /explain/forecast``. + + Attributes: + store_id: Store ID to explain. + product_id: Product ID to explain. + model_type: Baseline model type to reproduce and explain. + as_of_date: Series cutoff date — the explainer reads only ``<= as_of_date``. + season_length: Seasonal period (``seasonal_naive`` only; defaults to 7). + window_size: Averaging window (``moving_average`` only; defaults to 7). + """ + + model_config = ConfigDict(strict=True) + + store_id: int = Field(..., ge=1, description="Store ID") + product_id: int = Field(..., ge=1, description="Product ID") + model_type: ExplainableModelType = Field(..., description="Baseline model type") + # ``date`` has no native JSON representation — ``strict=False`` lets FastAPI's + # ``validate_python`` accept an ISO-string body. Repo-wide policy; see module + # docstring. 
+ as_of_date: date_type = Field( + ..., + strict=False, + description="Series cutoff date (the explainer reads only <= this date)", + ) + season_length: int | None = Field( + None, ge=1, le=365, description="Seasonal period for seasonal_naive (default 7)" + ) + window_size: int | None = Field( + None, ge=1, le=90, description="Averaging window for moving_average (default 7)" + ) diff --git a/app/features/explainability/service.py b/app/features/explainability/service.py new file mode 100644 index 00000000..2f10fe6e --- /dev/null +++ b/app/features/explainability/service.py @@ -0,0 +1,421 @@ +"""Service layer for the explainability slice. + +``ExplainabilityService`` is READ-ONLY with respect to every other slice. It +imports ``app.features.registry.models.ModelRun`` and +``app.features.jobs.models.Job`` directly, but ONLY as read-only data contracts +— a locked maintainer decision (PRP-28 "Open Questions & Decisions" #1), the +same pattern by which slices already import ``app.features.data_platform.models``. +It NEVER imports another slice's ``service.py``. + +To explain a run or job, the service re-loads the target series from +``sales_daily`` and re-fits a rule-based explainer from the stored config — it +does not reload the model artifact. Every series load and reason-code query is +bounded ``<= as_of_date`` (time-safety is load-bearing). 
+""" + +from __future__ import annotations + +import uuid +from datetime import date as date_type +from datetime import timedelta +from typing import Any + +import numpy as np +import structlog +from sqlalchemy import or_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import get_settings +from app.core.exceptions import BadRequestError +from app.features.data_platform.models import ( + Calendar, + InventorySnapshotDaily, + Product, + Promotion, + SalesDaily, +) +from app.features.explainability.explainers import FloatArray, explainer_factory +from app.features.explainability.models import ( + ForecastExplanation as ForecastExplanationORM, +) +from app.features.explainability.reason_codes import ( + build_caveats, + history_reason, + holiday_reason, + lifecycle_reason, + promotion_reason, + stockout_reason, +) +from app.features.explainability.schemas import ( + ConfidenceLevel, + DriverContribution, + ExplainForecastRequest, + ForecastExplanation, + ReasonCode, +) +from app.features.jobs.models import Job # read-only data contract — see module docstring +from app.features.registry.models import ( # read-only data contract — see module docstring + ModelRun, +) + +logger = structlog.get_logger(__name__) + +# Trailing window (days) used for stockout / promotion reason-code lookups. +_REASON_WINDOW_DAYS = 30 + + +class ExplainabilityService: + """Produces rule-based forecast explanations for the baseline models.""" + + def __init__(self) -> None: + """Initialise the service.""" + self.settings = get_settings() + + # ------------------------------------------------------------------ # + # Public entry points + # ------------------------------------------------------------------ # + + async def explain_forecast( + self, db: AsyncSession, request: ExplainForecastRequest + ) -> ForecastExplanation: + """Explain the h=1 forecast a baseline model produces ad hoc. + + Args: + db: Async database session. 
+ request: Store/product/model/cutoff parameters. + + Returns: + The persisted forecast explanation. + + Raises: + ValueError: For an unsupported model type or a too-short series. + """ + return await self._explain( + db, + store_id=request.store_id, + product_id=request.product_id, + model_type=request.model_type, + as_of_date=request.as_of_date, + season_length=request.season_length, + window_size=request.window_size, + ) + + async def explain_run(self, db: AsyncSession, run_id: str) -> ForecastExplanation | None: + """Explain a registry ``model_run``. + + Args: + db: Async database session. + run_id: External run identifier. + + Returns: + The explanation, or ``None`` when the run does not exist. + + Raises: + ValueError: For a non-baseline run or a too-short series. + """ + run = ( + await db.execute(select(ModelRun).where(ModelRun.run_id == run_id)) + ).scalar_one_or_none() + if run is None: + return None + config: dict[str, Any] = run.model_config or {} + return await self._explain( + db, + store_id=run.store_id, + product_id=run.product_id, + model_type=run.model_type, + as_of_date=run.data_window_end, + season_length=config.get("season_length"), + window_size=config.get("window_size"), + run_id=run_id, + ) + + async def explain_job(self, db: AsyncSession, job_id: str) -> ForecastExplanation | None: + """Explain a completed ``predict`` job. + + Args: + db: Async database session. + job_id: External job identifier. + + Returns: + The explanation, or ``None`` when the job does not exist. + + Raises: + BadRequestError: When the job is not a completed predict job, or + its result carries no forecasts. + ValueError: For an unsupported model type or a too-short series. 
+ """ + job = (await db.execute(select(Job).where(Job.job_id == job_id))).scalar_one_or_none() + if job is None: + return None + if job.job_type != "predict" or job.status != "completed": + raise BadRequestError( + message="explain_job requires a completed predict job", + details={"job_id": job_id, "job_type": job.job_type, "status": job.status}, + ) + result: dict[str, Any] = job.result or {} + forecasts: list[Any] = result.get("forecasts") or [] + if not forecasts: + raise BadRequestError( + message="predict job has no forecasts to explain", + details={"job_id": job_id}, + ) + store_id = result.get("store_id") + product_id = result.get("product_id") + model_type = result.get("model_type") + if store_id is None or product_id is None or model_type is None: + raise BadRequestError( + message="predict job result is missing store/product/model_type", + details={"job_id": job_id}, + ) + # as_of_date = the day before the first forecast date (PRP-28 assumption #4). + first_forecast_date = date_type.fromisoformat(forecasts[0]["date"]) + as_of_date = first_forecast_date - timedelta(days=1) + return await self._explain( + db, + store_id=int(store_id), + product_id=int(product_id), + model_type=str(model_type), + as_of_date=as_of_date, + # A predict job's result does not record season_length/window_size; + # the explainer falls back to the forecaster defaults (7). 
+ season_length=None, + window_size=None, + job_id=job_id, + ) + + # ------------------------------------------------------------------ # + # Core + # ------------------------------------------------------------------ # + + async def _explain( + self, + db: AsyncSession, + *, + store_id: int, + product_id: int, + model_type: str, + as_of_date: date_type, + season_length: int | None, + window_size: int | None, + run_id: str | None = None, + job_id: str | None = None, + ) -> ForecastExplanation: + """Build, persist, and return one rule-based explanation.""" + explainer = explainer_factory(model_type, season_length, window_size) + y, _dates = await self._load_series(db, store_id, product_id, as_of_date) + forecast_value, drivers = explainer.explain(y) + confidence = explainer.confidence(y) + forecast_date = as_of_date + timedelta(days=1) + + reason_codes = await self._assemble_reason_codes( + db, + store_id=store_id, + product_id=product_id, + model_type=model_type, + as_of_date=as_of_date, + forecast_date=forecast_date, + season_length=season_length, + window_size=window_size, + n_obs=len(y), + ) + caveats = build_caveats(model_type, reason_codes) + agent_summary = self._build_agent_summary( + store_id=store_id, + product_id=product_id, + model_type=model_type, + forecast_value=forecast_value, + forecast_date=forecast_date, + drivers=drivers, + reason_codes=reason_codes, + confidence=confidence, + ) + explanation = ForecastExplanation( + store_id=store_id, + product_id=product_id, + model_type=model_type, + forecast_value=forecast_value, + drivers=drivers, + reason_codes=reason_codes, + confidence=confidence, + caveats=caveats, + agent_summary=agent_summary, + as_of_date=as_of_date, + ) + await self._persist(db, explanation, run_id=run_id, job_id=job_id) + logger.info( + "explainability.explanation_generated", + store_id=store_id, + product_id=product_id, + model_type=model_type, + confidence=confidence.value, + n_reason_codes=len(reason_codes), + ) + return 
explanation + + async def _load_series( + self, + db: AsyncSession, + store_id: int, + product_id: int, + end_date: date_type, + ) -> tuple[FloatArray, list[date_type]]: + """Load the time-ordered sales series, bounded ``<= end_date``. + + TIME-SAFETY: the ``date <= end_date`` bound is load-bearing — no data + past the cutoff may inform an explanation. + """ + stmt = ( + select(SalesDaily) + .where( + SalesDaily.store_id == store_id, + SalesDaily.product_id == product_id, + SalesDaily.date <= end_date, + ) + .order_by(SalesDaily.date) + ) + rows = (await db.execute(stmt)).scalars().all() + y: FloatArray = np.array([float(r.quantity) for r in rows], dtype=np.float64) + return y, [r.date for r in rows] + + async def _assemble_reason_codes( + self, + db: AsyncSession, + *, + store_id: int, + product_id: int, + model_type: str, + as_of_date: date_type, + forecast_date: date_type, + season_length: int | None, + window_size: int | None, + n_obs: int, + ) -> list[ReasonCode]: + """Run the time-safe reason-code queries and assemble the code list.""" + window_start = as_of_date - timedelta(days=_REASON_WINDOW_DAYS) + + inventory_rows = ( + ( + await db.execute( + select(InventorySnapshotDaily).where( + InventorySnapshotDaily.store_id == store_id, + InventorySnapshotDaily.product_id == product_id, + InventorySnapshotDaily.date <= as_of_date, + InventorySnapshotDaily.date >= window_start, + ) + ) + ) + .scalars() + .all() + ) + + promotion_rows = ( + ( + await db.execute( + select(Promotion).where( + Promotion.product_id == product_id, + or_(Promotion.store_id == store_id, Promotion.store_id.is_(None)), + Promotion.start_date <= as_of_date, + Promotion.end_date >= window_start, + ) + ) + ) + .scalars() + .all() + ) + + product = ( + await db.execute(select(Product).where(Product.id == product_id)) + ).scalar_one_or_none() + + calendar_row = ( + await db.execute(select(Calendar).where(Calendar.date == forecast_date)) + ).scalar_one_or_none() + + # Extract primitives — the 
reason-code engine is DB- and ORM-free. + stockout_flags = [row.is_stockout for row in inventory_rows] + promotion_windows = [(row.start_date, row.end_date) for row in promotion_rows] + launch_date = product.launch_date if product is not None else None + is_holiday = calendar_row.is_holiday if calendar_row is not None else False + holiday_name = calendar_row.holiday_name if calendar_row is not None else None + min_required = self._min_required_history(model_type, season_length, window_size) + + candidates = [ + stockout_reason(stockout_flags), + promotion_reason(promotion_windows, as_of_date), + lifecycle_reason(launch_date, as_of_date), + holiday_reason(is_holiday, holiday_name, forecast_date), + history_reason(n_obs, min_required), + ] + return [code for code in candidates if code is not None] + + @staticmethod + def _min_required_history( + model_type: str, season_length: int | None, window_size: int | None + ) -> int: + """Comfortable minimum observation count for a confident explanation.""" + if model_type == "seasonal_naive": + return 2 * (season_length or 7) + if model_type == "moving_average": + return 2 * (window_size or 7) + return 14 + + @staticmethod + def _build_agent_summary( + *, + store_id: int, + product_id: int, + model_type: str, + forecast_value: float, + forecast_date: date_type, + drivers: list[DriverContribution], + reason_codes: list[ReasonCode], + confidence: ConfidenceLevel, + ) -> str: + """Compose a one-paragraph natural-language summary for chat agents.""" + main_driver = drivers[0] + sentences = [ + f"For store {store_id} / product {product_id}, the {model_type} model " + f"forecasts {forecast_value:.1f} units for {forecast_date.isoformat()}.", + f"The forecast is driven by '{main_driver.name}' " + f"(value {main_driver.feature_value:.1f}).", + f"Explanation confidence is {confidence.value}.", + ] + if reason_codes: + codes = ", ".join(rc.code for rc in reason_codes) + sentences.append(f"Advisory retail signals present: {codes}.") + 
else: + sentences.append("No advisory retail signals were detected.") + return " ".join(sentences) + + async def _persist( + self, + db: AsyncSession, + explanation: ForecastExplanation, + *, + run_id: str | None, + job_id: str | None, + ) -> None: + """Persist the explanation as a ``forecast_explanation`` row. + + Uses ``flush``/``refresh`` — ``get_db`` auto-commits on success. + """ + row = ForecastExplanationORM( + explanation_id=uuid.uuid4().hex, + run_id=run_id, + job_id=job_id, + store_id=explanation.store_id, + product_id=explanation.product_id, + model_type=explanation.model_type, + method=explanation.method, + as_of_date=explanation.as_of_date, + forecast_value=explanation.forecast_value, + confidence=explanation.confidence.value, + drivers=[d.model_dump() for d in explanation.drivers], + reason_codes=[rc.model_dump() for rc in explanation.reason_codes], + caveats=list(explanation.caveats), + agent_summary=explanation.agent_summary, + ) + db.add(row) + await db.flush() + await db.refresh(row) diff --git a/app/features/explainability/tests/__init__.py b/app/features/explainability/tests/__init__.py new file mode 100644 index 00000000..4b942f89 --- /dev/null +++ b/app/features/explainability/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the explainability slice.""" diff --git a/app/features/explainability/tests/conftest.py b/app/features/explainability/tests/conftest.py new file mode 100644 index 00000000..64480968 --- /dev/null +++ b/app/features/explainability/tests/conftest.py @@ -0,0 +1,241 @@ +"""Test fixtures for the explainability slice. + +Unit fixtures supply numpy series and a ``make_mock_db`` factory that builds an +``AsyncMock`` session whose ``execute`` calls are scripted in order. Integration +fixtures (``@pytest.mark.integration``) seed a real ``docker compose`` Postgres +and clean up after themselves; ``forecast_explanation`` is a slice-private table +so its teardown wipes it whole. 
+""" + +from __future__ import annotations + +import datetime +import uuid +from collections.abc import AsyncGenerator +from datetime import date, timedelta +from decimal import Decimal +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import numpy as np +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.core.config import get_settings +from app.core.database import get_db +from app.features.data_platform.models import Calendar, Product, SalesDaily, Store +from app.features.explainability.models import ForecastExplanation +from app.features.registry.models import ModelRun, RunStatus +from app.main import app + +# Test date range — kept narrow so the calendar teardown is precise. +TEST_START = date(2024, 1, 1) +TEST_DAYS = 90 +TEST_END = TEST_START + timedelta(days=TEST_DAYS - 1) + + +# ============================================================================= +# Unit fixtures — numpy series + scripted-mock DB factory +# ============================================================================= + + +@pytest.fixture +def sample_series() -> np.ndarray: + """A 60-observation float series with mild variation.""" + return np.array([float(10 + (i % 7)) for i in range(60)], dtype=np.float64) + + +@pytest.fixture +def flat_series() -> np.ndarray: + """A 30-observation constant series.""" + return np.full(30, 25.0, dtype=np.float64) + + +@pytest.fixture +def short_series() -> np.ndarray: + """A 3-observation series (shorter than every comfortable threshold).""" + return np.array([5.0, 7.0, 6.0], dtype=np.float64) + + +def mock_result(*, scalars: list[Any] | None = None, one: Any | None = None) -> MagicMock: + """Build a mock SQLAlchemy ``Result`` for one ``execute`` call.""" + result = MagicMock() + result.scalars.return_value.all.return_value = scalars or [] + 
result.scalar_one_or_none.return_value = one + return result + + +def make_mock_db(results: list[MagicMock]) -> AsyncMock: + """Build an ``AsyncMock`` session whose ``execute`` returns ``results`` in order. + + Args: + results: Mock ``Result`` objects (see ``mock_result``), one per expected + ``execute`` call, in call order. + + Returns: + A mock session ready to pass to ``ExplainabilityService``. + """ + db = AsyncMock() + db.execute = AsyncMock(side_effect=results) + db.flush = AsyncMock() + db.refresh = AsyncMock() + db.add = MagicMock() + return db + + +def sales_rows(values: list[float], start: date = TEST_START) -> list[SimpleNamespace]: + """Build sales-row stand-ins (``.quantity`` / ``.date``) for the mock DB.""" + return [ + SimpleNamespace(quantity=int(v), date=start + timedelta(days=i)) + for i, v in enumerate(values) + ] + + +def forecast_result_db(values: list[float]) -> AsyncMock: + """Mock DB for one ``explain_forecast`` call (series + 4 reason-code queries).""" + return make_mock_db( + [ + mock_result(scalars=sales_rows(values)), # _load_series + mock_result(scalars=[]), # inventory + mock_result(scalars=[]), # promotion + mock_result(one=None), # product + mock_result(one=None), # calendar + ] + ) + + +# ============================================================================= +# Integration fixtures — real Postgres +# ============================================================================= + + +@pytest.fixture +async def db_session() -> AsyncGenerator[AsyncSession, None]: + """Yield an async session; wipe explainability + test data on teardown.""" + settings = get_settings() + engine = create_async_engine(settings.database_url, echo=False) + session_maker = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async with session_maker() as session: + try: + yield session + finally: + await session.execute(delete(ForecastExplanation)) + await session.execute(delete(SalesDaily)) + await 
session.execute(delete(ModelRun).where(ModelRun.run_id.like("texpl%"))) + await session.execute(delete(Product).where(Product.sku.like("TEXPL-%"))) + await session.execute(delete(Store).where(Store.code.like("TEXPL-%"))) + await session.execute( + delete(Calendar).where((Calendar.date >= TEST_START) & (Calendar.date <= TEST_END)) + ) + await session.commit() + + await engine.dispose() + + +@pytest.fixture +async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: + """Test client with the database dependency overridden.""" + + async def override_get_db() -> AsyncGenerator[AsyncSession, None]: + yield db_session + + app.dependency_overrides[get_db] = override_get_db + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + yield ac + app.dependency_overrides.pop(get_db, None) + + +@pytest.fixture +async def seeded_series(db_session: AsyncSession) -> dict[str, int]: + """Seed a store, product, calendar, and a sales series; return ids. + + The series is a clean weekly pattern so the seasonal-naive h=1 forecast is + deterministic. 
+ """ + suffix = uuid.uuid4().hex[:8] + store = Store(code=f"TEXPL-{suffix}", name="Explain Store", region="R", store_type="x") + product = Product( + sku=f"TEXPL-{suffix}", + name="Explain Product", + category="C", + base_price=Decimal("9.99"), + launch_date=TEST_START, + ) + db_session.add_all([store, product]) + await db_session.commit() + await db_session.refresh(store) + await db_session.refresh(product) + + weekly = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0] + for i in range(TEST_DAYS): + d = TEST_START + timedelta(days=i) + await db_session.merge( + Calendar( + date=d, + day_of_week=d.weekday(), + month=d.month, + quarter=(d.month - 1) // 3 + 1, + year=d.year, + is_holiday=False, + ) + ) + await db_session.commit() + + for i in range(TEST_DAYS): + qty = weekly[i % 7] + db_session.add( + SalesDaily( + date=TEST_START + timedelta(days=i), + store_id=store.id, + product_id=product.id, + quantity=int(qty), + unit_price=Decimal("9.99"), + total_amount=Decimal("9.99") * int(qty), + ) + ) + await db_session.commit() + + return {"store_id": store.id, "product_id": product.id} + + +@pytest.fixture +async def seeded_run(db_session: AsyncSession, seeded_series: dict[str, int]) -> str: + """Seed a successful baseline ModelRun over the seeded series; return run_id.""" + run_id = f"texpl{uuid.uuid4().hex[:11]}" + run = ModelRun( + run_id=run_id, + status=RunStatus.SUCCESS.value, + model_type="naive", + model_config={"model_type": "naive", "schema_version": "1.0"}, + config_hash="deadbeefdeadbeef", + data_window_start=TEST_START, + data_window_end=TEST_END, + store_id=seeded_series["store_id"], + product_id=seeded_series["product_id"], + ) + db_session.add(run) + await db_session.commit() + return run_id + + +@pytest.fixture +def explanation_row_kwargs() -> dict[str, Any]: + """Valid keyword args for constructing a ForecastExplanation ORM row.""" + return { + "explanation_id": uuid.uuid4().hex, + "store_id": 1, + "product_id": 2, + "model_type": "naive", + "method": 
"rule_based", + "as_of_date": datetime.date(2024, 3, 1), + "forecast_value": 42.0, + "confidence": "medium", + "drivers": [{"name": "last_observation", "contribution": 42.0}], + "reason_codes": [], + "caveats": ["correlation not causation"], + "agent_summary": "A test explanation.", + } diff --git a/app/features/explainability/tests/test_explainers.py b/app/features/explainability/tests/test_explainers.py new file mode 100644 index 00000000..05045145 --- /dev/null +++ b/app/features/explainability/tests/test_explainers.py @@ -0,0 +1,161 @@ +"""Unit tests for the rule-based explainers. + +The load-bearing assertion: each explainer's h=1 forecast value EQUALS the real +forecaster's ``.fit(y).predict(1)[0]`` on the same series. A rule-based +explainer is exact — if it diverges from the forecaster, it is wrong. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from app.features.explainability.explainers import ( + MovingAverageExplainer, + NaiveExplainer, + SeasonalNaiveExplainer, + explainer_factory, +) +from app.features.explainability.schemas import ConfidenceLevel +from app.features.forecasting.models import ( + MovingAverageForecaster, + NaiveForecaster, + SeasonalNaiveForecaster, +) + + +class TestNaiveExplainer: + """Tests for NaiveExplainer.""" + + def test_forecast_matches_real_forecaster(self, sample_series: np.ndarray) -> None: + """h=1 value equals NaiveForecaster's prediction on the same series.""" + forecast, _ = NaiveExplainer().explain(sample_series) + expected = float(NaiveForecaster().fit(sample_series).predict(1)[0]) + assert forecast == pytest.approx(expected) + + def test_main_driver_contribution_sums_to_forecast(self, sample_series: np.ndarray) -> None: + """Driver contributions sum to the forecast value.""" + forecast, drivers = NaiveExplainer().explain(sample_series) + assert sum(d.contribution for d in drivers) == pytest.approx(forecast) + + def test_recent_trend_driver_present_for_long_series(self, 
sample_series: np.ndarray) -> None: + """A long series gets an informational recent_trend driver.""" + _, drivers = NaiveExplainer().explain(sample_series) + names = {d.name for d in drivers} + assert "last_observation" in names + assert "recent_trend" in names + trend = next(d for d in drivers if d.name == "recent_trend") + assert trend.contribution == 0.0 + + def test_no_trend_driver_for_short_series(self, short_series: np.ndarray) -> None: + """A short series gets only the last_observation driver.""" + _, drivers = NaiveExplainer().explain(short_series) + assert [d.name for d in drivers] == ["last_observation"] + + def test_empty_series_raises(self) -> None: + """An empty series raises ValueError (mirrors NaiveForecaster.fit).""" + with pytest.raises(ValueError, match="empty"): + NaiveExplainer().explain(np.array([], dtype=np.float64)) + + def test_confidence_downgrades_on_short_series( + self, sample_series: np.ndarray, short_series: np.ndarray + ) -> None: + """Confidence is LOW for a short series, MEDIUM otherwise.""" + assert NaiveExplainer().confidence(short_series) == ConfidenceLevel.LOW + assert NaiveExplainer().confidence(sample_series) == ConfidenceLevel.MEDIUM + + +class TestSeasonalNaiveExplainer: + """Tests for SeasonalNaiveExplainer.""" + + def test_forecast_matches_real_forecaster(self, sample_series: np.ndarray) -> None: + """h=1 value equals SeasonalNaiveForecaster's prediction.""" + forecast, _ = SeasonalNaiveExplainer(season_length=7).explain(sample_series) + expected = float(SeasonalNaiveForecaster(season_length=7).fit(sample_series).predict(1)[0]) + assert forecast == pytest.approx(expected) + + def test_main_driver_contribution_sums_to_forecast(self, sample_series: np.ndarray) -> None: + """Driver contributions sum to the forecast value.""" + forecast, drivers = SeasonalNaiveExplainer(season_length=7).explain(sample_series) + assert sum(d.contribution for d in drivers) == pytest.approx(forecast) + assert drivers[0].direction == "positive" + 
+ def test_too_short_series_raises(self, short_series: np.ndarray) -> None: + """A series shorter than season_length raises ValueError.""" + with pytest.raises(ValueError, match="at least"): + SeasonalNaiveExplainer(season_length=7).explain(short_series) + + def test_invalid_season_length_raises(self) -> None: + """season_length < 1 raises ValueError.""" + with pytest.raises(ValueError, match="season_length"): + SeasonalNaiveExplainer(season_length=0) + + def test_confidence_downgrades_under_two_cycles(self) -> None: + """Confidence is LOW under two seasonal cycles, MEDIUM otherwise.""" + short = np.arange(10.0, dtype=np.float64) # 10 < 2*7 + long = np.arange(40.0, dtype=np.float64) # 40 >= 2*7 + assert SeasonalNaiveExplainer(7).confidence(short) == ConfidenceLevel.LOW + assert SeasonalNaiveExplainer(7).confidence(long) == ConfidenceLevel.MEDIUM + + +class TestMovingAverageExplainer: + """Tests for MovingAverageExplainer.""" + + def test_forecast_matches_real_forecaster(self, sample_series: np.ndarray) -> None: + """h=1 value equals MovingAverageForecaster's prediction.""" + forecast, _ = MovingAverageExplainer(window_size=7).explain(sample_series) + expected = float(MovingAverageForecaster(window_size=7).fit(sample_series).predict(1)[0]) + assert forecast == pytest.approx(expected) + + def test_main_driver_contribution_sums_to_forecast(self, sample_series: np.ndarray) -> None: + """Driver contributions sum to the forecast (dispersion contributes 0).""" + forecast, drivers = MovingAverageExplainer(window_size=7).explain(sample_series) + assert sum(d.contribution for d in drivers) == pytest.approx(forecast) + dispersion = next(d for d in drivers if d.name == "window_dispersion") + assert dispersion.contribution == 0.0 + assert dispersion.direction == "neutral" + + def test_too_short_series_raises(self, short_series: np.ndarray) -> None: + """A series shorter than window_size raises ValueError.""" + with pytest.raises(ValueError, match="at least"): + 
MovingAverageExplainer(window_size=7).explain(short_series) + + def test_confidence_high_for_stable_window(self, flat_series: np.ndarray) -> None: + """A flat (zero-dispersion) window yields HIGH confidence.""" + assert MovingAverageExplainer(7).confidence(flat_series) == ConfidenceLevel.HIGH + + def test_confidence_medium_for_noisy_window(self) -> None: + """A high-variance window yields MEDIUM confidence.""" + noisy = np.array([1.0, 100.0, 2.0, 99.0, 3.0, 98.0, 4.0], dtype=np.float64) + assert MovingAverageExplainer(7).confidence(noisy) == ConfidenceLevel.MEDIUM + + +class TestExplainerFactory: + """Tests for explainer_factory.""" + + def test_builds_each_baseline(self) -> None: + """The factory builds the matching explainer per baseline model type.""" + assert isinstance(explainer_factory("naive"), NaiveExplainer) + assert isinstance( + explainer_factory("seasonal_naive", season_length=14), SeasonalNaiveExplainer + ) + assert isinstance( + explainer_factory("moving_average", window_size=21), MovingAverageExplainer + ) + + def test_seasonal_defaults_to_seven(self) -> None: + """A None season_length defaults to 7.""" + explainer = explainer_factory("seasonal_naive") + assert isinstance(explainer, SeasonalNaiveExplainer) + assert explainer.season_length == 7 + + @pytest.mark.parametrize("model_type", ["lightgbm", "regression"]) + def test_rejects_non_baseline_models(self, model_type: str) -> None: + """lightgbm/regression are rejected (MVP scope guard).""" + with pytest.raises(ValueError, match="baseline models only"): + explainer_factory(model_type) + + def test_rejects_unknown_model(self) -> None: + """An unknown model type raises ValueError.""" + with pytest.raises(ValueError, match="Unknown model type"): + explainer_factory("transformer") diff --git a/app/features/explainability/tests/test_models_integration.py b/app/features/explainability/tests/test_models_integration.py new file mode 100644 index 00000000..2fc210ff --- /dev/null +++ 
b/app/features/explainability/tests/test_models_integration.py @@ -0,0 +1,62 @@ +"""Integration tests for the ForecastExplanation ORM model. + +Run against the real docker-compose Postgres (``docker compose up -d``). +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.features.explainability.models import ForecastExplanation + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestForecastExplanationModel: + """CRUD and constraint tests for the forecast_explanation table.""" + + async def test_insert_and_read_back( + self, db_session: AsyncSession, explanation_row_kwargs: dict[str, Any] + ) -> None: + """A forecast_explanation row persists and reads back intact.""" + row = ForecastExplanation(**explanation_row_kwargs) + db_session.add(row) + await db_session.commit() + + fetched = ( + await db_session.execute( + select(ForecastExplanation).where( + ForecastExplanation.explanation_id == explanation_row_kwargs["explanation_id"] + ) + ) + ).scalar_one() + assert fetched.forecast_value == 42.0 + assert fetched.model_type == "naive" + assert fetched.confidence == "medium" + assert fetched.drivers[0]["name"] == "last_observation" + assert fetched.created_at is not None + + async def test_confidence_check_constraint_rejects_bad_value( + self, db_session: AsyncSession, explanation_row_kwargs: dict[str, Any] + ) -> None: + """An out-of-allow-list confidence value is rejected by the CHECK.""" + bad = ForecastExplanation(**{**explanation_row_kwargs, "confidence": "bogus"}) + db_session.add(bad) + with pytest.raises(IntegrityError): + await db_session.flush() + await db_session.rollback() + + async def test_method_check_constraint_rejects_bad_value( + self, db_session: AsyncSession, explanation_row_kwargs: dict[str, Any] + ) -> None: + """An out-of-allow-list method value is rejected by the CHECK.""" + 
bad = ForecastExplanation(**{**explanation_row_kwargs, "method": "telepathy"}) + db_session.add(bad) + with pytest.raises(IntegrityError): + await db_session.flush() + await db_session.rollback() diff --git a/app/features/explainability/tests/test_reason_codes.py b/app/features/explainability/tests/test_reason_codes.py new file mode 100644 index 00000000..066ebba8 --- /dev/null +++ b/app/features/explainability/tests/test_reason_codes.py @@ -0,0 +1,128 @@ +"""Unit tests for the advisory reason-code engine.""" + +from __future__ import annotations + +from datetime import date + +from app.features.explainability.reason_codes import ( + CORRELATION_CAVEAT, + build_caveats, + history_reason, + holiday_reason, + lifecycle_reason, + promotion_reason, + stockout_reason, +) +from app.features.explainability.schemas import ReasonCode + +AS_OF = date(2024, 3, 1) + + +class TestStockoutReason: + """Tests for stockout_reason.""" + + def test_fires_on_stockout_days(self) -> None: + """A stockout day produces a warn-level reason code.""" + code = stockout_reason([False, True, False, True]) + assert code is not None + assert code.code == "stockout_constrained" + assert code.severity == "warn" + assert "2 stockout" in code.detail + + def test_none_when_no_stockout(self) -> None: + """No stockout days yields None.""" + assert stockout_reason([False, False, False]) is None + + def test_none_for_empty_window(self) -> None: + """An empty window yields None.""" + assert stockout_reason([]) is None + + +class TestPromotionReason: + """Tests for promotion_reason.""" + + def test_fires_on_overlap(self) -> None: + """An overlapping promotion produces an info reason code.""" + code = promotion_reason([(date(2024, 2, 20), date(2024, 2, 25))], AS_OF) + assert code is not None + assert code.code == "promotion_overlap" + assert code.severity == "info" + + def test_detects_promotion_active_at_cutoff(self) -> None: + """A promotion active on as_of_date is called out in the detail.""" + code = 
promotion_reason([(date(2024, 2, 25), date(2024, 3, 10))], AS_OF) + assert code is not None + assert "still active" in code.detail + + def test_none_when_no_promotions(self) -> None: + """No promotions yields None.""" + assert promotion_reason([], AS_OF) is None + + +class TestLifecycleReason: + """Tests for lifecycle_reason.""" + + def test_fires_for_recent_launch(self) -> None: + """A launch under 30 days ago produces an info reason code.""" + code = lifecycle_reason(date(2024, 2, 15), AS_OF) + assert code is not None + assert code.code == "lifecycle_decay" + + def test_none_for_old_launch(self) -> None: + """A launch over 30 days ago yields None.""" + assert lifecycle_reason(date(2023, 1, 1), AS_OF) is None + + def test_none_when_launch_unknown(self) -> None: + """An unknown launch date yields None.""" + assert lifecycle_reason(None, AS_OF) is None + + +class TestHolidayReason: + """Tests for holiday_reason.""" + + def test_fires_on_holiday(self) -> None: + """A holiday forecast date produces an info reason code.""" + code = holiday_reason(True, "New Year", date(2024, 3, 2)) + assert code is not None + assert code.code == "holiday_effect" + assert "New Year" in code.detail + + def test_none_on_normal_day(self) -> None: + """A non-holiday forecast date yields None.""" + assert holiday_reason(False, None, date(2024, 3, 2)) is None + + +class TestHistoryReason: + """Tests for history_reason.""" + + def test_fires_for_short_series(self) -> None: + """Fewer observations than required produces a warn reason code.""" + code = history_reason(5, 14) + assert code is not None + assert code.code == "insufficient_history" + assert code.severity == "warn" + + def test_none_for_sufficient_history(self) -> None: + """Enough observations yields None.""" + assert history_reason(30, 14) is None + + +class TestBuildCaveats: + """Tests for build_caveats.""" + + def test_always_includes_correlation_caveat(self) -> None: + """Every caveat list starts with the correlation-vs-causation 
disclaimer.""" + caveats = build_caveats("naive", []) + assert caveats[0] == CORRELATION_CAVEAT + + def test_includes_model_specific_caveat(self) -> None: + """A model-specific caveat is appended for each baseline.""" + assert any("seasonality" in c for c in build_caveats("naive", [])) + assert any("prior cycle" in c for c in build_caveats("seasonal_naive", [])) + assert any("smooths" in c for c in build_caveats("moving_average", [])) + + def test_adds_stockout_caveat(self) -> None: + """A stockout reason code adds an understated-demand caveat.""" + stockout = ReasonCode(code="stockout_constrained", severity="warn", detail="x") + caveats = build_caveats("naive", [stockout]) + assert any("understate" in c for c in caveats) diff --git a/app/features/explainability/tests/test_routes.py b/app/features/explainability/tests/test_routes.py new file mode 100644 index 00000000..cf73ad84 --- /dev/null +++ b/app/features/explainability/tests/test_routes.py @@ -0,0 +1,153 @@ +"""Unit route tests for the explainability endpoints. + +Each test overrides ``get_db`` with a scripted-mock session, so the routes are +exercised over the HTTP boundary without a real database. Error paths assert the +RFC 7807 problem-detail shape. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from datetime import date +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock + +import pytest +from httpx import ASGITransport, AsyncClient + +from app.core.database import get_db +from app.features.explainability.tests.conftest import ( + forecast_result_db, + make_mock_db, + mock_result, +) +from app.main import app + + +@asynccontextmanager +async def _client(db: AsyncMock) -> AsyncGenerator[AsyncClient, None]: + """Yield a test client whose get_db dependency yields ``db``.""" + + async def override_get_db() -> AsyncGenerator[AsyncMock, None]: + yield db + + app.dependency_overrides[get_db] = override_get_db + try: + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + yield ac + finally: + app.dependency_overrides.pop(get_db, None) + + +def _assert_problem_detail(body: dict[str, Any], expected_status: int) -> None: + """Assert an RFC 7807 problem-detail body shape.""" + for key in ("type", "title", "status", "detail"): + assert key in body, f"missing RFC 7807 field: {key}" + assert body["status"] == expected_status + + +@pytest.mark.asyncio +async def test_explain_forecast_returns_200() -> None: + """POST /explain/forecast returns 200 with a well-formed explanation.""" + db = forecast_result_db([10.0, 12.0, 11.0, 9.0, 14.0]) + async with _client(db) as ac: + response = await ac.post( + "/explain/forecast", + json={ + "store_id": 1, + "product_id": 2, + "model_type": "naive", + "as_of_date": "2024-03-01", + }, + ) + assert response.status_code == 200 + body = response.json() + assert body["forecast_value"] == 14.0 + assert body["method"] == "rule_based" + assert body["drivers"][0]["name"] == "last_observation" + + +@pytest.mark.asyncio +async def test_explain_forecast_rejects_iso_string_path() -> None: + """An ISO-string as_of_date is accepted 
(strict-mode JSON path).""" + db = forecast_result_db([10.0, 12.0, 11.0]) + async with _client(db) as ac: + response = await ac.post( + "/explain/forecast", + json={ + "store_id": 1, + "product_id": 2, + "model_type": "naive", + "as_of_date": "2024-03-01", + }, + ) + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_explain_forecast_empty_series_returns_400() -> None: + """An empty series yields an RFC 7807 400.""" + db = forecast_result_db([]) + async with _client(db) as ac: + response = await ac.post( + "/explain/forecast", + json={ + "store_id": 1, + "product_id": 2, + "model_type": "naive", + "as_of_date": "2024-03-01", + }, + ) + assert response.status_code == 400 + _assert_problem_detail(response.json(), 400) + + +@pytest.mark.asyncio +async def test_explain_run_missing_returns_404() -> None: + """GET /explain/runs/{missing} yields an RFC 7807 404.""" + db = make_mock_db([mock_result(one=None)]) + async with _client(db) as ac: + response = await ac.get("/explain/runs/does-not-exist") + assert response.status_code == 404 + _assert_problem_detail(response.json(), 404) + + +@pytest.mark.asyncio +async def test_explain_run_lightgbm_returns_400() -> None: + """GET /explain/runs/{lightgbm-run} yields an RFC 7807 400.""" + run = SimpleNamespace( + run_id="run-lgbm", + model_type="lightgbm", + model_config={"model_type": "lightgbm"}, + store_id=1, + product_id=2, + data_window_end=date(2024, 3, 1), + ) + db = make_mock_db([mock_result(one=run)]) + async with _client(db) as ac: + response = await ac.get("/explain/runs/run-lgbm") + assert response.status_code == 400 + _assert_problem_detail(response.json(), 400) + + +@pytest.mark.asyncio +async def test_explain_job_missing_returns_404() -> None: + """GET /explain/jobs/{missing} yields an RFC 7807 404.""" + db = make_mock_db([mock_result(one=None)]) + async with _client(db) as ac: + response = await ac.get("/explain/jobs/does-not-exist") + assert response.status_code == 404 + 
_assert_problem_detail(response.json(), 404) + + +@pytest.mark.asyncio +async def test_explain_job_non_predict_returns_400() -> None: + """GET /explain/jobs/{train-job} yields an RFC 7807 400.""" + job = SimpleNamespace(job_id="job-train", job_type="train", status="completed", result={}) + db = make_mock_db([mock_result(one=job)]) + async with _client(db) as ac: + response = await ac.get("/explain/jobs/job-train") + assert response.status_code == 400 + _assert_problem_detail(response.json(), 400) diff --git a/app/features/explainability/tests/test_routes_integration.py b/app/features/explainability/tests/test_routes_integration.py new file mode 100644 index 00000000..6d79bc59 --- /dev/null +++ b/app/features/explainability/tests/test_routes_integration.py @@ -0,0 +1,85 @@ +"""End-to-end integration tests for the explainability endpoints. + +Run against the real docker-compose Postgres (``docker compose up -d``). The +``client`` fixture shares the test session, so a persisted explanation is +readable back through the same session after the request. 
+""" + +from __future__ import annotations + +import pytest +from httpx import AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.features.explainability.models import ForecastExplanation +from app.features.explainability.tests.conftest import TEST_END + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestExplainEndpointsIntegration: + """End-to-end tests over a real database.""" + + async def test_explain_run_returns_explanation( + self, client: AsyncClient, seeded_run: str + ) -> None: + """GET /explain/runs/{run_id} explains a real baseline run.""" + response = await client.get(f"/explain/runs/{seeded_run}") + assert response.status_code == 200 + body = response.json() + assert body["model_type"] == "naive" + assert body["method"] == "rule_based" + assert body["drivers"] + assert body["confidence"] in ("high", "medium", "low") + assert body["caveats"] + assert body["agent_summary"] + # The naive forecast is the last observed value — a positive quantity. 
+ assert body["forecast_value"] > 0 + + async def test_explain_run_persists_row( + self, client: AsyncClient, db_session: AsyncSession, seeded_run: str + ) -> None: + """The explanation is persisted as a forecast_explanation row.""" + await client.get(f"/explain/runs/{seeded_run}") + row = ( + await db_session.execute( + select(ForecastExplanation).where(ForecastExplanation.run_id == seeded_run) + ) + ).scalar_one() + assert row.model_type == "naive" + assert row.run_id == seeded_run + + async def test_explain_forecast_end_to_end( + self, client: AsyncClient, seeded_series: dict[str, int] + ) -> None: + """POST /explain/forecast explains an ad-hoc forecast over a real series.""" + response = await client.post( + "/explain/forecast", + json={ + "store_id": seeded_series["store_id"], + "product_id": seeded_series["product_id"], + "model_type": "seasonal_naive", + "as_of_date": TEST_END.isoformat(), + "season_length": 7, + }, + ) + assert response.status_code == 200 + body = response.json() + assert body["model_type"] == "seasonal_naive" + assert body["drivers"][0]["name"] == "season_match" + assert body["forecast_value"] > 0 + + async def test_explain_run_missing_returns_404(self, client: AsyncClient) -> None: + """GET /explain/runs/{missing} returns an RFC 7807 404.""" + response = await client.get("/explain/runs/no-such-run-id") + assert response.status_code == 404 + body = response.json() + assert body["status"] == 404 + assert "title" in body + + async def test_explain_job_missing_returns_404(self, client: AsyncClient) -> None: + """GET /explain/jobs/{missing} returns an RFC 7807 404.""" + response = await client.get("/explain/jobs/no-such-job-id") + assert response.status_code == 404 + assert response.json()["status"] == 404 diff --git a/app/features/explainability/tests/test_schemas.py b/app/features/explainability/tests/test_schemas.py new file mode 100644 index 00000000..156cdaf0 --- /dev/null +++ b/app/features/explainability/tests/test_schemas.py @@ -0,0 
+1,134 @@ +"""Unit tests for the explainability Pydantic schemas. + +The JSON-path test (``test_request_accepts_iso_string_date``) is required by +``docs/_base/SECURITY.md`` — it exercises the ``validate_python`` path FastAPI +uses, catching the strict-mode date regression at unit-test time. +""" + +from __future__ import annotations + +from datetime import date + +import pytest +from pydantic import ValidationError + +from app.features.explainability.schemas import ( + ConfidenceLevel, + DriverContribution, + ExplainForecastRequest, + ForecastExplanation, + ReasonCode, +) + + +class TestExplainForecastRequest: + """Tests for the strict request body.""" + + def test_request_accepts_iso_string_date(self) -> None: + """as_of_date accepts an ISO-string (the FastAPI JSON path).""" + request = ExplainForecastRequest.model_validate( + { + "store_id": 1, + "product_id": 2, + "model_type": "naive", + "as_of_date": "2024-03-01", + } + ) + assert request.as_of_date == date(2024, 3, 1) + + def test_request_accepts_native_date(self) -> None: + """as_of_date also accepts a native date object.""" + request = ExplainForecastRequest( + store_id=1, product_id=2, model_type="naive", as_of_date=date(2024, 3, 1) + ) + assert request.as_of_date == date(2024, 3, 1) + + def test_invalid_model_type_rejected(self) -> None: + """A non-baseline model_type fails validation.""" + with pytest.raises(ValidationError): + ExplainForecastRequest.model_validate( + { + "store_id": 1, + "product_id": 2, + "model_type": "lightgbm", + "as_of_date": "2024-03-01", + } + ) + + def test_non_positive_store_id_rejected(self) -> None: + """store_id must be >= 1.""" + with pytest.raises(ValidationError): + ExplainForecastRequest.model_validate( + { + "store_id": 0, + "product_id": 2, + "model_type": "naive", + "as_of_date": "2024-03-01", + } + ) + + def test_optional_params_default_to_none(self) -> None: + """season_length and window_size default to None.""" + request = ExplainForecastRequest( + store_id=1, 
product_id=2, model_type="naive", as_of_date=date(2024, 3, 1) + ) + assert request.season_length is None + assert request.window_size is None + + +class TestForecastExplanation: + """Tests for the response schema.""" + + def test_round_trips_through_model_dump(self) -> None: + """A ForecastExplanation survives model_dump -> model_validate.""" + explanation = ForecastExplanation( + store_id=1, + product_id=2, + model_type="naive", + forecast_value=42.0, + drivers=[ + DriverContribution( + name="last_observation", + feature_value=42.0, + contribution=42.0, + direction="positive", + description="x", + ) + ], + reason_codes=[ReasonCode(code="holiday_effect", severity="info", detail="x")], + confidence=ConfidenceLevel.MEDIUM, + caveats=["correlation not causation"], + agent_summary="A summary.", + as_of_date=date(2024, 3, 1), + ) + restored = ForecastExplanation.model_validate(explanation.model_dump()) + assert restored.forecast_value == 42.0 + assert restored.method == "rule_based" + assert restored.confidence == ConfidenceLevel.MEDIUM + assert restored.drivers[0].name == "last_observation" + + def test_method_defaults_to_rule_based(self) -> None: + """method defaults to rule_based.""" + explanation = ForecastExplanation( + store_id=1, + product_id=2, + model_type="naive", + forecast_value=1.0, + drivers=[], + reason_codes=[], + confidence=ConfidenceLevel.LOW, + caveats=[], + agent_summary="x", + as_of_date=date(2024, 3, 1), + ) + assert explanation.method == "rule_based" + + +class TestConfidenceLevel: + """Tests for the ConfidenceLevel enum.""" + + def test_values(self) -> None: + """The enum carries the three expected string values.""" + assert ConfidenceLevel.HIGH.value == "high" + assert ConfidenceLevel.MEDIUM.value == "medium" + assert ConfidenceLevel.LOW.value == "low" diff --git a/app/features/explainability/tests/test_service.py b/app/features/explainability/tests/test_service.py new file mode 100644 index 00000000..f7be684d --- /dev/null +++ 
b/app/features/explainability/tests/test_service.py @@ -0,0 +1,179 @@ +"""Unit tests for ExplainabilityService with a scripted-mock AsyncSession. + +The mock session returns pre-built ``Result`` objects in ``execute`` call order +(see ``conftest.make_mock_db``) so the service logic is exercised without a DB. +""" + +from __future__ import annotations + +from datetime import date +from types import SimpleNamespace +from typing import Literal + +import pytest + +from app.core.exceptions import BadRequestError +from app.features.explainability.schemas import ( + ConfidenceLevel, + ExplainForecastRequest, + ForecastExplanation, +) +from app.features.explainability.service import ExplainabilityService +from app.features.explainability.tests.conftest import ( + forecast_result_db, + make_mock_db, + mock_result, + sales_rows, +) + + +def _request( + model_type: Literal["naive", "seasonal_naive", "moving_average"] = "naive", +) -> ExplainForecastRequest: + """Build an ExplainForecastRequest for the given model type.""" + return ExplainForecastRequest( + store_id=1, product_id=2, model_type=model_type, as_of_date=date(2024, 3, 1) + ) + + +class TestExplainForecast: + """Tests for ExplainabilityService.explain_forecast.""" + + async def test_returns_well_formed_explanation(self) -> None: + """A naive forecast explanation reproduces the last observed value.""" + db = forecast_result_db([10.0, 12.0, 11.0, 9.0, 14.0]) + explanation = await ExplainabilityService().explain_forecast(db, _request()) + + assert isinstance(explanation, ForecastExplanation) + assert explanation.forecast_value == 14.0 # last observation + assert explanation.method == "rule_based" + assert explanation.drivers[0].name == "last_observation" + assert explanation.agent_summary + # The correlation-vs-causation caveat is always present. 
+ assert any("causality" in c for c in explanation.caveats) + + async def test_persists_the_explanation(self) -> None: + """The service adds, flushes, and refreshes a forecast_explanation row.""" + db = forecast_result_db([10.0, 12.0, 11.0]) + await ExplainabilityService().explain_forecast(db, _request()) + + db.add.assert_called_once() + db.flush.assert_awaited_once() + db.refresh.assert_awaited_once() + + async def test_short_series_flags_insufficient_history(self) -> None: + """A short series yields LOW confidence and an insufficient_history code.""" + db = forecast_result_db([10.0, 12.0, 11.0]) + explanation = await ExplainabilityService().explain_forecast(db, _request()) + + assert explanation.confidence == ConfidenceLevel.LOW + codes = {rc.code for rc in explanation.reason_codes} + assert "insufficient_history" in codes + + async def test_empty_series_raises_value_error(self) -> None: + """An empty series raises ValueError (route maps it to 400).""" + db = forecast_result_db([]) + with pytest.raises(ValueError, match="empty"): + await ExplainabilityService().explain_forecast(db, _request()) + + +class TestExplainRun: + """Tests for ExplainabilityService.explain_run.""" + + async def test_missing_run_returns_none(self) -> None: + """A missing run_id returns None (route maps it to 404).""" + db = make_mock_db([mock_result(one=None)]) + result = await ExplainabilityService().explain_run(db, "does-not-exist") + assert result is None + + async def test_explains_a_baseline_run(self) -> None: + """A baseline run resolves its config and produces an explanation.""" + run = SimpleNamespace( + run_id="run-abc", + model_type="naive", + model_config={"model_type": "naive"}, + store_id=1, + product_id=2, + data_window_end=date(2024, 3, 1), + ) + db = make_mock_db( + [ + mock_result(one=run), + mock_result(scalars=sales_rows([10.0, 20.0, 15.0])), + mock_result(scalars=[]), + mock_result(scalars=[]), + mock_result(one=None), + mock_result(one=None), + ] + ) + explanation = 
await ExplainabilityService().explain_run(db, "run-abc") + assert explanation is not None + assert explanation.forecast_value == 15.0 + + async def test_lightgbm_run_raises_value_error(self) -> None: + """A lightgbm run raises ValueError before any series load.""" + run = SimpleNamespace( + run_id="run-lgbm", + model_type="lightgbm", + model_config={"model_type": "lightgbm"}, + store_id=1, + product_id=2, + data_window_end=date(2024, 3, 1), + ) + db = make_mock_db([mock_result(one=run)]) + with pytest.raises(ValueError, match="baseline models only"): + await ExplainabilityService().explain_run(db, "run-lgbm") + + +class TestExplainJob: + """Tests for ExplainabilityService.explain_job.""" + + async def test_missing_job_returns_none(self) -> None: + """A missing job_id returns None (route maps it to 404).""" + db = make_mock_db([mock_result(one=None)]) + result = await ExplainabilityService().explain_job(db, "does-not-exist") + assert result is None + + async def test_non_completed_job_raises_bad_request(self) -> None: + """A pending predict job raises BadRequestError.""" + job = SimpleNamespace(job_id="job-1", job_type="predict", status="pending", result=None) + db = make_mock_db([mock_result(one=job)]) + with pytest.raises(BadRequestError, match="completed predict job"): + await ExplainabilityService().explain_job(db, "job-1") + + async def test_non_predict_job_raises_bad_request(self) -> None: + """A completed train job raises BadRequestError.""" + job = SimpleNamespace(job_id="job-2", job_type="train", status="completed", result={}) + db = make_mock_db([mock_result(one=job)]) + with pytest.raises(BadRequestError, match="completed predict job"): + await ExplainabilityService().explain_job(db, "job-2") + + async def test_explains_a_completed_predict_job(self) -> None: + """A completed predict job produces an explanation at the right cutoff.""" + job = SimpleNamespace( + job_id="job-3", + job_type="predict", + status="completed", + result={ + "store_id": 1, + 
"product_id": 2, + "model_type": "naive", + "horizon": 7, + "forecasts": [{"date": "2024-03-02", "forecast": 25.0}], + }, + ) + db = make_mock_db( + [ + mock_result(one=job), + mock_result(scalars=sales_rows([10.0, 20.0, 25.0])), + mock_result(scalars=[]), + mock_result(scalars=[]), + mock_result(one=None), + mock_result(one=None), + ] + ) + explanation = await ExplainabilityService().explain_job(db, "job-3") + assert explanation is not None + # as_of_date = day before the first forecast date. + assert explanation.as_of_date == date(2024, 3, 1) + assert explanation.forecast_value == 25.0 diff --git a/app/features/forecasting/models.py b/app/features/forecasting/models.py index 04f9dc05..ecb510b8 100644 --- a/app/features/forecasting/models.py +++ b/app/features/forecasting/models.py @@ -17,6 +17,9 @@ from typing import TYPE_CHECKING, Any, Literal import numpy as np +from sklearn.ensemble import ( # type: ignore[import-untyped] + HistGradientBoostingRegressor, +) if TYPE_CHECKING: from app.features.forecasting.schemas import ModelConfig @@ -422,8 +425,147 @@ def set_params(self, **params: Any) -> MovingAverageForecaster: # noqa: ANN401 return self +class RegressionForecaster(BaseForecaster): + """Feature-driven forecaster wrapping ``HistGradientBoostingRegressor``. + + CRITICAL: this is the FIRST forecaster that *consumes* the exogenous ``X`` + argument — the baseline forecasters all ignore it (each ``fit``/``predict`` + carries ``# noqa: ARG002``). Both ``fit`` and ``predict`` therefore REQUIRE + a non-``None`` ``X`` whose row count matches, and raise ``ValueError`` + otherwise — a regression model cannot forecast without its feature frame. + + ``HistGradientBoostingRegressor`` is deterministic given a fixed + ``random_state`` and tolerates ``NaN`` natively, which matters because the + future feature frame leaves lag cells ``NaN`` when their source target + lies in the (un-observed) horizon. + + Attributes: + max_iter: Number of boosting iterations. 
+ learning_rate: Gradient-boosting learning rate. + max_depth: Maximum depth of each tree. + """ + + def __init__( + self, + *, + max_iter: int = 200, + learning_rate: float = 0.05, + max_depth: int = 6, + random_state: int = 42, + ) -> None: + """Initialize the regression forecaster. + + Args: + max_iter: Number of boosting iterations. + learning_rate: Gradient-boosting learning rate. + max_depth: Maximum depth of each tree. + random_state: Random seed for reproducibility (determinism). + """ + super().__init__(random_state) + self.max_iter = max_iter + self.learning_rate = learning_rate + self.max_depth = max_depth + self._estimator: Any = None + + def fit( + self, + y: np.ndarray[Any, np.dtype[np.floating[Any]]], + X: np.ndarray[Any, np.dtype[np.floating[Any]]] | None = None, + ) -> RegressionForecaster: + """Fit the gradient-boosted regressor on historical features. + + Args: + y: Target values (1D array of shape ``[n_samples]``). + X: Exogenous features (2D array of shape ``[n_samples, n_features]``). + REQUIRED — unlike the baseline forecasters. + + Returns: + self (for method chaining). + + Raises: + ValueError: If ``X`` is ``None``, ``y`` is empty, or the row counts + of ``X`` and ``y`` do not match. 
+ """ + if X is None: + raise ValueError("RegressionForecaster requires exogenous features X for fit()") + if len(y) == 0: + raise ValueError("Cannot fit on empty array") + if X.shape[0] != len(y): + raise ValueError( + f"X has {X.shape[0]} rows but y has {len(y)} — feature/target rows must match" + ) + estimator: Any = HistGradientBoostingRegressor( + max_iter=self.max_iter, + learning_rate=self.learning_rate, + max_depth=self.max_depth, + random_state=self.random_state, + ) + estimator.fit(X, y) + self._estimator = estimator + self._last_values = np.asarray(y[-1:], dtype=np.float64) + self._is_fitted = True + return self + + def predict( + self, + horizon: int, + X: np.ndarray[Any, np.dtype[np.floating[Any]]] | None = None, + ) -> np.ndarray[Any, np.dtype[np.floating[Any]]]: + """Generate forecasts from a future feature frame. + + Args: + horizon: Number of steps to forecast. + X: Exogenous features for the forecast period, shape + ``[horizon, n_features]``. REQUIRED. + + Returns: + Array of forecasts with shape ``[horizon]``. + + Raises: + RuntimeError: If the model has not been fitted. + ValueError: If ``X`` is ``None`` or its row count is not ``horizon``. + """ + if not self._is_fitted or self._estimator is None: + raise RuntimeError("Model must be fitted before predict") + if X is None: + raise ValueError("RegressionForecaster requires exogenous features X for predict()") + if X.shape[0] != horizon: + raise ValueError(f"X has {X.shape[0]} rows but horizon is {horizon} — they must match") + predictions = self._estimator.predict(X) + result: np.ndarray[Any, np.dtype[np.floating[Any]]] = np.asarray( + predictions, dtype=np.float64 + ) + return result + + def get_params(self) -> dict[str, Any]: + """Get model parameters. + + Returns: + Dictionary with max_iter, learning_rate, max_depth, random_state. 
+ """ + return { + "max_iter": self.max_iter, + "learning_rate": self.learning_rate, + "max_depth": self.max_depth, + "random_state": self.random_state, + } + + def set_params(self, **params: Any) -> RegressionForecaster: # noqa: ANN401 + """Set model parameters. + + Args: + **params: Parameter names and values to set. + + Returns: + self (for method chaining). + """ + for key, value in params.items(): + setattr(self, key, value) + return self + + # Type alias for model type literals -ModelType = Literal["naive", "seasonal_naive", "moving_average", "lightgbm"] +ModelType = Literal["naive", "seasonal_naive", "moving_average", "lightgbm", "regression"] def model_factory(config: ModelConfig, random_state: int = 42) -> BaseForecaster: @@ -472,5 +614,16 @@ def model_factory(config: ModelConfig, random_state: int = 42) -> BaseForecaster ) # LightGBM implementation would go here when feature-flagged raise NotImplementedError("LightGBM forecaster not yet implemented") + elif model_type == "regression": + from app.features.forecasting.schemas import RegressionModelConfig + + if isinstance(config, RegressionModelConfig): + return RegressionForecaster( + max_iter=config.max_iter, + learning_rate=config.learning_rate, + max_depth=config.max_depth, + random_state=random_state, + ) + raise ValueError("Invalid config type for regression") else: raise ValueError(f"Unknown model type: {model_type}") diff --git a/app/features/forecasting/schemas.py b/app/features/forecasting/schemas.py index 33d534f7..b019529c 100644 --- a/app/features/forecasting/schemas.py +++ b/app/features/forecasting/schemas.py @@ -144,9 +144,58 @@ class LightGBMModelConfig(ModelConfigBase): ) +class RegressionModelConfig(ModelConfigBase): + """Configuration for the exogenous-regressor forecaster (PRP-27). + + Wraps scikit-learn's ``HistGradientBoostingRegressor`` — a deterministic, + NaN-tolerant gradient-boosted tree model. 
Unlike the baseline forecasters, + a ``regression`` model *consumes* a per-day exogenous feature frame, so a + scenario what-if can be answered by genuinely re-forecasting demand + (``method="model_exogenous"``) rather than by a post-forecast multiplier. + + No feature flag and no new dependency: ``HistGradientBoostingRegressor`` + ships with the already-pinned ``scikit-learn`` (see + ``PRPs/ai_docs/exogenous-regressor-forecasting.md`` § 5). + + Attributes: + max_iter: Number of boosting iterations. + learning_rate: Gradient-boosting learning rate. + max_depth: Maximum depth of each tree. + feature_config_hash: Optional hash of the feature contract used. + """ + + model_type: Literal["regression"] = "regression" + max_iter: int = Field( + default=200, + ge=10, + le=1000, + description="Number of boosting iterations", + ) + learning_rate: float = Field( + default=0.05, + ge=0.001, + le=1.0, + description="Gradient-boosting learning rate", + ) + max_depth: int = Field( + default=6, + ge=1, + le=20, + description="Maximum depth of each tree", + ) + feature_config_hash: str | None = Field( + default=None, + description="Hash of the feature contract used for training", + ) + + # Union type for all model configs ModelConfig = ( - NaiveModelConfig | SeasonalNaiveModelConfig | MovingAverageModelConfig | LightGBMModelConfig + NaiveModelConfig + | SeasonalNaiveModelConfig + | MovingAverageModelConfig + | LightGBMModelConfig + | RegressionModelConfig ) diff --git a/app/features/forecasting/service.py b/app/features/forecasting/service.py index a1cf03fa..839f1475 100644 --- a/app/features/forecasting/service.py +++ b/app/features/forecasting/service.py @@ -11,6 +11,7 @@ from __future__ import annotations +import math import time import uuid from dataclasses import dataclass, field @@ -25,7 +26,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import get_settings -from app.features.data_platform.models import SalesDaily +from 
app.features.data_platform.models import Calendar, Product, Promotion, SalesDaily from app.features.forecasting.models import model_factory from app.features.forecasting.persistence import ( ModelBundle, @@ -68,6 +69,60 @@ def __post_init__(self) -> None: self.n_observations = len(self.y) +# Minimum observed rows required to train a regression model — enough to +# resolve the lag features and still leave training signal (PRP-27 GOTCHA #14). +_MIN_REGRESSION_TRAIN_ROWS = 30 +# Observed-target tail persisted in the bundle so the scenario future-frame +# generator can resolve long lags (PRP-27 DECISIONS LOCKED #11 — 90 days). +_REGRESSION_HISTORY_TAIL_DAYS = 90 +# Target lag offsets — PRP-27 DECISIONS LOCKED #10 (EXOGENOUS_LAGS). +_REGRESSION_LAGS: tuple[int, ...] = (1, 7, 14, 28) +# Canonical regression feature columns — a PAIRED CONTRACT with +# ``app/features/scenarios/feature_frame.canonical_feature_columns()``. The +# scenarios slice owns the future-frame generator; this slice owns training. +# A cross-slice import is forbidden (AGENTS.md § Architecture, PRP-27 +# DECISIONS LOCKED #3), so the column names and order are replicated here and +# kept in lock-step by the scenarios integration test (a column mismatch +# surfaces as a non-zero delta on an empty-assumption simulation). +_REGRESSION_FEATURE_COLUMNS: list[str] = [ + *(f"lag_{lag}" for lag in _REGRESSION_LAGS), + "dow_sin", + "dow_cos", + "month_sin", + "month_cos", + "is_weekend", + "is_month_end", + "price_factor", + "promo_active", + "is_holiday", + "days_since_launch", +] + + +@dataclass +class RegressionFeatureMatrix: + """Historical feature matrix + bundle metadata for a regression model. + + Attributes: + X: Feature matrix, shape ``[n_observations, n_features]`` (NaN allowed). + y: Target values, shape ``[n_observations]``. + feature_columns: Column order — persisted so the future frame matches. 
+ history_tail: The last ``_REGRESSION_HISTORY_TAIL_DAYS`` observed + targets, ending at the forecast origin ``T``. + history_tail_dates: ISO dates aligned with ``history_tail``. + launch_date_iso: The product launch date (ISO) or ``None``. + n_observations: Number of training rows. + """ + + X: np.ndarray[Any, np.dtype[np.floating[Any]]] + y: np.ndarray[Any, np.dtype[np.floating[Any]]] + feature_columns: list[str] + history_tail: list[float] + history_tail_dates: list[str] + launch_date_iso: str | None + n_observations: int + + class ForecastingService: """Service for training and predicting with forecasting models. @@ -121,24 +176,43 @@ async def train_model( config_hash=config.config_hash(), ) - # Load training data - training_data = await self._load_training_data( - db=db, - store_id=store_id, - product_id=product_id, - start_date=train_start_date, - end_date=train_end_date, - ) - - if training_data.n_observations == 0: - raise ValueError( - f"No training data found for store={store_id}, product={product_id} " - f"between {train_start_date} and {train_end_date}" + # Build the model + bundle metadata. The regression path consumes a + # historical feature matrix; every other model trains on the raw + # target series exactly as before. 
+ extra_metadata: dict[str, object] = {} + if config.model_type == "regression": + features = await self._build_regression_features( + db=db, + store_id=store_id, + product_id=product_id, + start_date=train_start_date, + end_date=train_end_date, ) - - # Create and fit model - model = model_factory(config, random_state=self.settings.forecast_random_seed) - model.fit(training_data.y) + model = model_factory(config, random_state=self.settings.forecast_random_seed) + model.fit(features.y, features.X) + n_observations = features.n_observations + extra_metadata = { + "feature_columns": features.feature_columns, + "history_tail": features.history_tail, + "history_tail_dates": features.history_tail_dates, + "launch_date": features.launch_date_iso, + } + else: + training_data = await self._load_training_data( + db=db, + store_id=store_id, + product_id=product_id, + start_date=train_start_date, + end_date=train_end_date, + ) + if training_data.n_observations == 0: + raise ValueError( + f"No training data found for store={store_id}, product={product_id} " + f"between {train_start_date} and {train_end_date}" + ) + model = model_factory(config, random_state=self.settings.forecast_random_seed) + model.fit(training_data.y) + n_observations = training_data.n_observations # Create bundle with metadata bundle = ModelBundle( @@ -149,7 +223,8 @@ async def train_model( "product_id": product_id, "train_start_date": str(train_start_date), "train_end_date": str(train_end_date), - "n_observations": training_data.n_observations, + "n_observations": n_observations, + **extra_metadata, }, ) @@ -166,7 +241,7 @@ async def train_model( product_id=product_id, model_type=config.model_type, config_hash=config.config_hash(), - n_observations=training_data.n_observations, + n_observations=n_observations, model_path=str(saved_path), duration_ms=duration_ms, ) @@ -177,7 +252,7 @@ async def train_model( model_type=config.model_type, model_path=str(saved_path), config_hash=config.config_hash(), - 
n_observations=training_data.n_observations, + n_observations=n_observations, train_start_date=train_start_date, train_end_date=train_end_date, duration_ms=duration_ms, @@ -267,6 +342,16 @@ async def predict( f"but prediction requested for product={product_id}" ) + # Regression models need an exogenous feature frame to forecast — that + # is built (from scenario assumptions) by POST /scenarios/simulate. The + # plain predict endpoint cannot supply one, so it rejects them cleanly. + if bundle.config.model_type == "regression": + raise ValueError( + "Regression models forecast through POST /scenarios/simulate, " + "which supplies the exogenous feature frame. POST /forecasting/" + "predict does not support model_type='regression'." + ) + # Generate forecasts forecasts_array = bundle.model.predict(horizon) @@ -365,3 +450,146 @@ async def _load_training_data( store_id=store_id, product_id=product_id, ) + + async def _build_regression_features( + self, + db: AsyncSession, + store_id: int, + product_id: int, + start_date: date_type, + end_date: date_type, + ) -> RegressionFeatureMatrix: + """Build the historical feature matrix for a regression model. + + Time-safe by construction: every lag column at row ``i`` reads only + the observed target at ``i - lag`` (a strictly earlier day); calendar + columns are pure functions of the date; ``price_factor`` / + ``promo_active`` / ``is_holiday`` / ``days_since_launch`` read the + same-day exogenous attributes. No row reads a future observation. + + The column set is the paired contract with the scenarios slice's + future-frame generator (see ``_REGRESSION_FEATURE_COLUMNS``). + + Args: + db: Database session. + store_id: Store ID. + product_id: Product ID. + start_date: Start of the training window (inclusive). + end_date: End of the training window (inclusive) — the origin ``T``. + + Returns: + The feature matrix plus the bundle metadata the future frame needs. 
+ + Raises: + ValueError: When fewer than ``_MIN_REGRESSION_TRAIN_ROWS`` observed + days are available. + """ + sales_rows = ( + await db.execute( + select(SalesDaily.date, SalesDaily.quantity, SalesDaily.unit_price) + .where( + (SalesDaily.store_id == store_id) + & (SalesDaily.product_id == product_id) + & (SalesDaily.date >= start_date) + & (SalesDaily.date <= end_date) + ) + .order_by(SalesDaily.date) + ) + ).all() + if len(sales_rows) < _MIN_REGRESSION_TRAIN_ROWS: + raise ValueError( + f"A regression model needs at least {_MIN_REGRESSION_TRAIN_ROWS} " + f"observed days; store={store_id} product={product_id} has " + f"{len(sales_rows)} between {start_date} and {end_date}." + ) + + dates = [row.date for row in sales_rows] + quantities = [float(row.quantity) for row in sales_rows] + prices = [float(row.unit_price) for row in sales_rows] + + # Baseline price = median of the positive prices, so price_factor is + # ~1.0 on a typical day and < 1.0 on a markdown/promo day. + positive_prices = sorted(price for price in prices if price > 0.0) + baseline_price = positive_prices[len(positive_prices) // 2] if positive_prices else 1.0 + + holiday_dates: set[date_type] = set( + ( + await db.execute( + select(Calendar.date).where( + Calendar.date >= start_date, + Calendar.date <= end_date, + Calendar.is_holiday.is_(True), + ) + ) + ) + .scalars() + .all() + ) + + # Promotion-active days: store-specific OR chain-wide rows that overlap + # the training window, expanded to the set of dates they cover. 
+ promo_rows = ( + await db.execute( + select(Promotion.start_date, Promotion.end_date).where( + Promotion.product_id == product_id, + (Promotion.store_id == store_id) | (Promotion.store_id.is_(None)), + Promotion.start_date <= end_date, + Promotion.end_date >= start_date, + ) + ) + ).all() + promo_dates: set[date_type] = set() + for promo in promo_rows: + day = max(promo.start_date, start_date) + last = min(promo.end_date, end_date) + while day <= last: + promo_dates.add(day) + day += timedelta(days=1) + + launch_date: date_type | None = await db.scalar( + select(Product.launch_date).where(Product.id == product_id) + ) + + feature_rows: list[list[float]] = [] + for index, day in enumerate(dates): + row_values: list[float] = [] + # Target long-lag columns — read only strictly-earlier observations. + for lag in _REGRESSION_LAGS: + row_values.append(quantities[index - lag] if index >= lag else math.nan) + # Calendar columns — pure functions of the date. + dow = day.weekday() + row_values.append(math.sin(2.0 * math.pi * dow / 7.0)) + row_values.append(math.cos(2.0 * math.pi * dow / 7.0)) + row_values.append(math.sin(2.0 * math.pi * day.month / 12.0)) + row_values.append(math.cos(2.0 * math.pi * day.month / 12.0)) + row_values.append(1.0 if dow >= 5 else 0.0) + row_values.append(1.0 if (day + timedelta(days=1)).month != day.month else 0.0) + # Exogenous columns — same-day attributes. 
+ row_values.append(prices[index] / baseline_price) + row_values.append(1.0 if day in promo_dates else 0.0) + row_values.append(1.0 if day in holiday_dates else 0.0) + row_values.append( + float((day - launch_date).days) if launch_date is not None else math.nan + ) + feature_rows.append(row_values) + + tail = quantities[-_REGRESSION_HISTORY_TAIL_DAYS:] + tail_dates = [day.isoformat() for day in dates[-_REGRESSION_HISTORY_TAIL_DAYS:]] + + logger.info( + "forecasting.regression_features_built", + store_id=store_id, + product_id=product_id, + n_observations=len(dates), + n_features=len(_REGRESSION_FEATURE_COLUMNS), + ) + + return RegressionFeatureMatrix( + X=np.array(feature_rows, dtype=np.float64), + y=np.array(quantities, dtype=np.float64), + feature_columns=list(_REGRESSION_FEATURE_COLUMNS), + history_tail=[float(value) for value in tail], + history_tail_dates=tail_dates, + launch_date_iso=launch_date.isoformat() if launch_date is not None else None, + n_observations=len(dates), + ) diff --git a/app/features/forecasting/tests/test_regression_forecaster.py b/app/features/forecasting/tests/test_regression_forecaster.py new file mode 100644 index 00000000..22caae79 --- /dev/null +++ b/app/features/forecasting/tests/test_regression_forecaster.py @@ -0,0 +1,116 @@ +"""Unit tests for ``RegressionForecaster`` (PRP-27 Phase B). + +The regression forecaster is the first model that *consumes* the exogenous +``X`` argument, so these tests focus on the new contract: ``X`` is required, +its shape is validated, fits are deterministic, and ``NaN`` features are +tolerated (the future feature frame deliberately emits ``NaN`` cells). 
+""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + +from app.features.forecasting.models import RegressionForecaster, model_factory +from app.features.forecasting.schemas import RegressionModelConfig + +FloatArray = np.ndarray[Any, np.dtype[np.floating[Any]]] + + +def _synthetic_data( + n: int = 120, n_features: int = 6, seed: int = 0 +) -> tuple[FloatArray, FloatArray]: + """Build a synthetic feature matrix and a target that depends on it.""" + rng = np.random.default_rng(seed) + features = rng.normal(size=(n, n_features)) + target = 50.0 + 5.0 * features[:, 0] - 3.0 * features[:, 1] + rng.normal(scale=0.5, size=n) + return features.astype(np.float64), target.astype(np.float64) + + +def test_fit_predict_roundtrip() -> None: + """A fitted regression model produces a finite forecast of horizon length.""" + features, target = _synthetic_data() + model = RegressionForecaster() + model.fit(target, features) + assert model.is_fitted + + horizon = 10 + predictions = model.predict(horizon, features[:horizon]) + assert predictions.shape == (horizon,) + assert bool(np.all(np.isfinite(predictions))) + + +def test_fit_rejects_none_features() -> None: + """``fit`` raises when no exogenous features are supplied.""" + _, target = _synthetic_data() + with pytest.raises(ValueError, match="requires exogenous features"): + RegressionForecaster().fit(target, None) + + +def test_fit_rejects_mismatched_rows() -> None: + """``fit`` raises when feature and target row counts differ.""" + features, target = _synthetic_data() + with pytest.raises(ValueError, match="rows must match"): + RegressionForecaster().fit(target, features[:-5]) + + +def test_predict_rejects_none_features() -> None: + """``predict`` raises when no exogenous features are supplied.""" + features, target = _synthetic_data() + model = RegressionForecaster().fit(target, features) + with pytest.raises(ValueError, match="requires exogenous features"): + 
model.predict(5, None) + + +def test_predict_rejects_wrong_shape_features() -> None: + """``predict`` raises when the feature row count is not the horizon.""" + features, target = _synthetic_data() + model = RegressionForecaster().fit(target, features) + with pytest.raises(ValueError, match="horizon"): + model.predict(5, features[:8]) + + +def test_predict_before_fit_raises() -> None: + """``predict`` raises a RuntimeError before the model is fitted.""" + model = RegressionForecaster() + with pytest.raises(RuntimeError, match="fitted"): + model.predict(5, np.zeros((5, 3), dtype=np.float64)) + + +def test_determinism_same_random_state() -> None: + """Two fits with the same random_state yield identical forecasts.""" + features, target = _synthetic_data() + future = features[:12] + first = RegressionForecaster(random_state=7).fit(target, features) + second = RegressionForecaster(random_state=7).fit(target, features) + np.testing.assert_array_equal(first.predict(12, future), second.predict(12, future)) + + +def test_handles_nan_features() -> None: + """``HistGradientBoostingRegressor`` tolerates NaN feature cells natively.""" + features, target = _synthetic_data() + model = RegressionForecaster().fit(target, features) + future = features[:6].copy() + future[2, 0] = np.nan # the future frame emits NaN for un-resolvable lags + predictions = model.predict(6, future) + assert bool(np.all(np.isfinite(predictions))) + + +def test_get_and_set_params() -> None: + """``get_params`` reflects construction; ``set_params`` mutates in place.""" + model = RegressionForecaster(max_iter=150, learning_rate=0.03, max_depth=4) + params = model.get_params() + assert params["max_iter"] == 150 + assert params["learning_rate"] == 0.03 + assert params["max_depth"] == 4 + model.set_params(max_depth=9) + assert model.max_depth == 9 + + +def test_model_factory_creates_regression_forecaster() -> None: + """``model_factory`` dispatches a RegressionModelConfig to the right class.""" + model = 
model_factory(RegressionModelConfig(max_iter=120), random_state=42) + assert isinstance(model, RegressionForecaster) + assert model.max_iter == 120 diff --git a/app/features/jobs/service.py b/app/features/jobs/service.py index e559fb99..528bb20e 100644 --- a/app/features/jobs/service.py +++ b/app/features/jobs/service.py @@ -261,9 +261,16 @@ async def list_jobs( else: order_by = Job.created_at.desc() - # Apply pagination + # Apply pagination. Append created_at then the unique `job_id` as + # tie-breakers so rows with equal sort values keep a stable order + # across pages (offset pagination over a non-unique sort key is + # otherwise non-deterministic). offset = (page - 1) * page_size - stmt = stmt.order_by(order_by).offset(offset).limit(page_size) + stmt = ( + stmt.order_by(order_by, Job.created_at.desc(), Job.job_id.asc()) + .offset(offset) + .limit(page_size) + ) # Execute query result = await db.execute(stmt) diff --git a/app/features/jobs/tests/conftest.py b/app/features/jobs/tests/conftest.py index 41589643..2777253c 100644 --- a/app/features/jobs/tests/conftest.py +++ b/app/features/jobs/tests/conftest.py @@ -71,7 +71,9 @@ async def override_get_db() -> AsyncGenerator[AsyncSession, None]: ) as ac: yield ac - app.dependency_overrides.clear() + # Remove only this fixture's override — clear() would also drop overrides + # installed by other fixtures sharing the app instance. + app.dependency_overrides.pop(get_db, None) @pytest.fixture diff --git a/app/features/ops/__init__.py b/app/features/ops/__init__.py new file mode 100644 index 00000000..2c534a50 --- /dev/null +++ b/app/features/ops/__init__.py @@ -0,0 +1,23 @@ +"""ForecastOps Control Center slice. + +A read-only vertical slice that aggregates operational state across the +``jobs``, ``registry``, and ``data_platform`` slices: system health, job / run / +alias health, data freshness, a needs-attention list, and a ranked +retraining-candidate queue. Has no models and no migration — it only reads. 
+""" + +from app.features.ops.routes import router +from app.features.ops.schemas import ( + ModelHealthResponse, + OpsSummaryResponse, + RetrainingCandidatesResponse, +) +from app.features.ops.service import OpsService + +__all__ = [ + "ModelHealthResponse", + "OpsService", + "OpsSummaryResponse", + "RetrainingCandidatesResponse", + "router", +] diff --git a/app/features/ops/routes.py b/app/features/ops/routes.py new file mode 100644 index 00000000..fd7381b6 --- /dev/null +++ b/app/features/ops/routes.py @@ -0,0 +1,128 @@ +"""API routes for the ForecastOps Control Center. + +Three read-only aggregation endpoints backing the ``/ops`` Control Center page: +operational summary, the ranked retraining-candidate queue, and per-grain +forecast-error health with a drift verdict. +""" + +from fastapi import APIRouter, Depends, Query +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.features.ops.schemas import ( + ModelHealthResponse, + OpsSummaryResponse, + RetrainingCandidatesResponse, +) +from app.features.ops.service import OpsService + +router = APIRouter(prefix="/ops", tags=["ops"]) + + +@router.get( + "/summary", + response_model=OpsSummaryResponse, + summary="Operational summary for the Control Center", + description=""" +Aggregate the system's operational state into one response. + +**Sections**: +- `system`: API liveness, database connectivity, latest completed job. +- `jobs`: per-status job histogram plus active / failed / completed-today counts. +- `runs`: per-status model-run histogram plus success rate and failed count. +- `aliases`: every deployment alias with a staleness verdict. +- `freshness`: latest sales date, latest completed job, latest successful run. +- `attention_items`: recent failed jobs, failed runs, and stale aliases. + +Returns HTTP 200 even on an empty database — every section degrades to +zeros / nulls / empty lists rather than erroring. 
+""", +) +async def get_ops_summary( + db: AsyncSession = Depends(get_db), +) -> OpsSummaryResponse: + """Return the aggregated operational summary. + + Args: + db: Database session. + + Returns: + The full operational summary. + """ + return await OpsService().get_summary(db) + + +@router.get( + "/retraining-candidates", + response_model=RetrainingCandidatesResponse, + summary="Ranked retraining-candidate queue", + description=""" +Rank `(store, product)` grains by a deterministic retraining-priority score. + +Each grain is evaluated from its latest successful model run. The score blends +a time-based signal (staleness since the training-data window ended) with a +performance-based signal (WAPE), so the highest-scoring rows are the most +overdue and/or least accurate. + +Candidates are sorted by `priority_score` descending and capped at `limit`. +""", +) +async def get_retraining_candidates( + limit: int = Query( + default=20, + ge=1, + le=100, + description="Maximum number of candidates to return (1-100, default 20).", + ), + db: AsyncSession = Depends(get_db), +) -> RetrainingCandidatesResponse: + """Return the ranked retraining-candidate queue. + + Args: + limit: Maximum number of candidates to return. + db: Database session. + + Returns: + Candidates sorted by priority score (highest first). + """ + return await OpsService().get_retraining_candidates(db, limit) + + +@router.get( + "/model-health", + response_model=ModelHealthResponse, + summary="Per-(store, product) forecast-error health and drift", + description=""" +Classify forecast-error **performance drift** for every `(store, product)` grain. + +For each grain the endpoint reads the **full** successful-run history, extracts +each run's WAPE, and compares the latest WAPE against the mean of the prior +WAPEs within a ±10% relative band — yielding a drift verdict +(`improving` / `stable` / `degrading` / `unknown`). 
+ +Entries are sorted **degrading-first**, then by the magnitude of the WAPE +change, and capped at `limit`. Returns HTTP 200 even on an empty database. + +This is a performance-drift signal, not data drift — it needs no feature +snapshots and adds no new table or migration. +""", +) +async def get_model_health( + limit: int = Query( + default=20, + ge=1, + le=100, + description="Maximum number of grains to return (1-100, default 20).", + ), + db: AsyncSession = Depends(get_db), +) -> ModelHealthResponse: + """Return per-grain forecast-error health and drift. + + Args: + limit: Maximum number of grains to return. + db: Database session. + + Returns: + Grains sorted degrading-first, then by absolute WAPE change. + """ + return await OpsService().get_model_health(db, limit) diff --git a/app/features/ops/schemas.py b/app/features/ops/schemas.py new file mode 100644 index 00000000..02a8405c --- /dev/null +++ b/app/features/ops/schemas.py @@ -0,0 +1,336 @@ +"""Pydantic schemas for the ForecastOps Control Center. + +All models here are HTTP **response** models — they are built from aggregated +operational state, never parsed from a request body. They therefore use +``ConfigDict(from_attributes=True)`` and deliberately do NOT set +``strict=True`` (the strict-mode request-body policy does not apply). 
+""" + +from datetime import date, datetime +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field + +# ============================================================================= +# System & freshness +# ============================================================================= + + +class SystemHealth(BaseModel): + """Liveness snapshot for the Control Center header.""" + + model_config = ConfigDict(from_attributes=True) + + api_ok: bool = Field( + ..., + description="True whenever this response was produced (the API served the request).", + ) + database_connected: bool = Field( + ..., + description="True when a 'SELECT 1' probe against PostgreSQL succeeded.", + ) + latest_successful_job_at: datetime | None = Field( + None, + description="Completion timestamp of the most recent completed job. " + "Null when no job has completed yet.", + ) + + +class DataFreshness(BaseModel): + """How current the underlying data and model state are.""" + + model_config = ConfigDict(from_attributes=True) + + latest_sales_date: date | None = Field( + None, + description="Most recent date present in sales_daily. Null when no sales exist.", + ) + latest_job_completed_at: datetime | None = Field( + None, + description="Completion timestamp of the most recently finished job (any outcome).", + ) + latest_run_completed_at: datetime | None = Field( + None, + description="Completion timestamp of the most recent successful model run.", + ) + + +# ============================================================================= +# Job & run health +# ============================================================================= + + +class StatusCount(BaseModel): + """One row of a status histogram (e.g. 
how many jobs are 'failed').""" + + model_config = ConfigDict(from_attributes=True) + + status: str = Field(..., description="The lifecycle status value.") + count: int = Field(..., ge=0, description="Number of entities in that status.") + + +class JobHealth(BaseModel): + """Aggregated job-execution health.""" + + model_config = ConfigDict(from_attributes=True) + + counts: list[StatusCount] = Field( + ..., + description="One entry per JobStatus, zero-filled for statuses with no rows.", + ) + completed_today: int = Field( + ..., + ge=0, + description="Jobs that completed since 00:00 UTC today.", + ) + failed_total: int = Field(..., ge=0, description="Total jobs in the 'failed' status.") + active_total: int = Field( + ..., + ge=0, + description="Jobs currently pending or running.", + ) + + +class RunHealth(BaseModel): + """Aggregated model-run health.""" + + model_config = ConfigDict(from_attributes=True) + + counts: list[StatusCount] = Field( + ..., + description="One entry per RunStatus, zero-filled for statuses with no rows.", + ) + success_rate: float | None = Field( + None, + description="Successful runs divided by non-archived runs. 
" + "Null when there are no non-archived runs.", + ) + failed_total: int = Field(..., ge=0, description="Total runs in the 'failed' status.") + + +# ============================================================================= +# Alias health +# ============================================================================= + + +class AliasHealth(BaseModel): + """Deployment-alias health, including a staleness verdict.""" + + model_config = ConfigDict(from_attributes=True) + + alias_name: str = Field(..., description="The deployment alias name.") + run_id: str = Field(..., description="External run_id of the aliased model run.") + run_status: str = Field(..., description="Lifecycle status of the aliased run.") + model_type: str = Field(..., description="Model type of the aliased run.") + store_id: int = Field(..., description="Store the aliased run targets.") + product_id: int = Field(..., description="Product the aliased run targets.") + is_stale: bool = Field( + ..., + description="True when the alias points at a non-successful run, or a newer " + "successful run exists for the same store/product.", + ) + stale_reason: str | None = Field( + None, + description="Human-readable explanation when is_stale is true; null otherwise.", + ) + wape: float | None = Field( + None, + description="WAPE of the aliased run, when present in its metrics; null otherwise.", + ) + + +# ============================================================================= +# Attention items +# ============================================================================= + + +class AttentionItem(BaseModel): + """One entry in the 'needs attention' list.""" + + model_config = ConfigDict(from_attributes=True) + + item_type: Literal["failed_job", "failed_run", "stale_alias"] = Field( + ..., + description="What kind of problem this row represents.", + ) + entity_id: str = Field( + ..., + description="job_id for failed_job; run_id for failed_run and stale_alias. 
" + "Used to deep-link to the matching Explorer detail page.", + ) + label: str = Field(..., description="Short title for the row.") + detail: str = Field(..., description="Longer explanation (error message or stale reason).") + occurred_at: datetime | None = Field( + None, + description="When the entity was created. Null when unknown.", + ) + + +# ============================================================================= +# Summary response +# ============================================================================= + + +class OpsSummaryResponse(BaseModel): + """Aggregated operational summary for the Control Center page.""" + + model_config = ConfigDict(from_attributes=True) + + system: SystemHealth = Field(..., description="Liveness and connectivity.") + jobs: JobHealth = Field(..., description="Job-execution health.") + runs: RunHealth = Field(..., description="Model-run health.") + aliases: list[AliasHealth] = Field( + ..., + description="Every deployment alias with its staleness verdict.", + ) + freshness: DataFreshness = Field(..., description="How current data and models are.") + attention_items: list[AttentionItem] = Field( + ..., + description="Recent failed jobs, failed runs, and stale aliases.", + ) + generated_at: datetime = Field(..., description="When this summary was computed (UTC).") + + +# ============================================================================= +# Retraining-candidate queue +# ============================================================================= + + +class RetrainingCandidate(BaseModel): + """One (store, product) pair ranked for retraining.""" + + model_config = ConfigDict(from_attributes=True) + + store_id: int = Field(..., description="Store the candidate covers.") + product_id: int = Field(..., description="Product the candidate covers.") + priority_score: float = Field( + ..., + ge=0.0, + le=1.0, + description="Retraining-priority score in [0, 1]; higher means more urgent.", + ) + staleness_days: int = 
Field( + ..., + ge=0, + description="Days since the latest successful run's training-data window ended.", + ) + wape: float | None = Field( + None, + description="WAPE of the latest successful run, when known; null otherwise.", + ) + latest_run_id: str | None = Field( + None, + description="External run_id of the latest successful run for this grain.", + ) + latest_run_status: str | None = Field( + None, + description="Status of that run (always 'success' for evaluated candidates).", + ) + reason: str = Field(..., description="Human-readable rationale for the score.") + + +class RetrainingCandidatesResponse(BaseModel): + """Ranked retraining-candidate queue.""" + + model_config = ConfigDict(from_attributes=True) + + candidates: list[RetrainingCandidate] = Field( + ..., + description="Candidates sorted by priority_score (descending), capped at the limit.", + ) + total_evaluated: int = Field( + ..., + ge=0, + description="Total (store, product) grains evaluated before applying the limit.", + ) + generated_at: datetime = Field(..., description="When this queue was computed (UTC).") + + +# ============================================================================= +# Model health & drift +# ============================================================================= + + +# Forecast-error trend verdict for a (store, product) grain. 
+DriftDirection = Literal["improving", "stable", "degrading", "unknown"] + + +class WapePoint(BaseModel): + """One run's WAPE observation in a grain's chronological history.""" + + model_config = ConfigDict(from_attributes=True) + + run_id: str = Field(..., description="External run_id of the observed run.") + created_at: datetime = Field(..., description="When the run was created (UTC).") + wape: float | None = Field( + None, + description="WAPE of the run, when present in its metrics; null otherwise.", + ) + + +class ModelHealthEntry(BaseModel): + """Forecast-error health and drift verdict for one (store, product) grain.""" + + model_config = ConfigDict(from_attributes=True) + + store_id: int = Field(..., description="Store the grain covers.") + product_id: int = Field(..., description="Product the grain covers.") + run_count: int = Field( + ..., + ge=0, + description="Number of successful runs evaluated for this grain.", + ) + latest_run_id: str | None = Field( + None, + description="External run_id of the most recent successful run.", + ) + latest_run_status: str | None = Field( + None, + description="Status of that run (always 'success' for evaluated grains).", + ) + latest_wape: float | None = Field( + None, + description="Most recent numeric WAPE in the grain's history; null when none.", + ) + previous_wape: float | None = Field( + None, + description="The prior numeric WAPE, used as the drift baseline; null when none.", + ) + wape_delta: float | None = Field( + None, + description="latest_wape minus previous_wape; null when fewer than two numeric WAPEs.", + ) + drift_direction: DriftDirection = Field( + ..., + description="Forecast-error trend: improving / stable / degrading / unknown.", + ) + last_trained_at: datetime | None = Field( + None, + description="created_at of the latest successful run; null when none.", + ) + staleness_days: int = Field( + ..., + ge=0, + description="Days since the latest run's training-data window ended.", + ) + wape_history: 
list[WapePoint] = Field( + ..., + description="Chronological WAPE observations; may carry null gaps.", + ) + + +class ModelHealthResponse(BaseModel): + """Per-grain forecast-error health, degrading grains first.""" + + model_config = ConfigDict(from_attributes=True) + + entries: list[ModelHealthEntry] = Field( + ..., + description="Grains sorted degrading-first, then by |wape_delta| descending.", + ) + total_evaluated: int = Field( + ..., + ge=0, + description="Total grains evaluated before applying the limit.", + ) + generated_at: datetime = Field(..., description="When this report was computed (UTC).") diff --git a/app/features/ops/service.py b/app/features/ops/service.py new file mode 100644 index 00000000..e88e2054 --- /dev/null +++ b/app/features/ops/service.py @@ -0,0 +1,543 @@ +"""Service layer for the ForecastOps Control Center. + +Read-only aggregation across sibling slices. This module imports the ORM +**models** of the ``jobs``, ``registry``, and ``data_platform`` slices and runs +read-only ``select()`` queries against them. It deliberately does NOT import any +sibling ``service.py`` or ``schemas.py`` — the cross-slice coupling is confined +to the verified, read-only ORM surface (see PRP-24, decision #1). 
+""" + +from datetime import UTC, datetime +from itertools import groupby +from typing import Any + +from sqlalchemy import func, select, text +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.logging import get_logger +from app.features.data_platform.models import SalesDaily +from app.features.jobs.models import Job, JobStatus +from app.features.ops.schemas import ( + AliasHealth, + AttentionItem, + DataFreshness, + DriftDirection, + JobHealth, + ModelHealthEntry, + ModelHealthResponse, + OpsSummaryResponse, + RetrainingCandidate, + RetrainingCandidatesResponse, + RunHealth, + StatusCount, + SystemHealth, + WapePoint, +) +from app.features.registry.models import DeploymentAlias, ModelRun, RunStatus + +logger = get_logger(__name__) + +# Staleness (days) at which the time-based component of the score saturates. +_STALENESS_CAP_DAYS = 90 +# WAPE value at which the error-based component of the score saturates. +_WAPE_CAP = 100.0 +# How many recent failed jobs / runs to surface in the attention list. +_ATTENTION_LIMIT = 10 +# Relative WAPE-change band: forecast-error drift is only flagged outside ±10%. +_DRIFT_BAND = 0.10 + + +# ============================================================================= +# Pure helpers (no DB, no I/O — unit-tested directly) +# ============================================================================= + + +def extract_wape(metrics: dict[str, Any] | None) -> float | None: + """Pull a WAPE value out of a model run's ``metrics`` JSONB blob. + + Tolerant by design: ``model_run.metrics`` is frequently None or carries an + unrelated metric set (backtest WAPE persists to ``job.result``, not run + metrics), so this returns None rather than raising whenever a numeric WAPE + cannot be found. Booleans are rejected — ``bool`` is an ``int`` subclass but + is never a valid metric value. + + Args: + metrics: The ``ModelRun.metrics`` JSONB dict, or None. + + Returns: + The WAPE as a float, or None when absent / non-numeric. 
+ """ + if not metrics: + return None + for key in ("wape", "wape_mean", "WAPE"): + value = metrics.get(key) + if isinstance(value, bool): + continue + if isinstance(value, (int, float)): + return float(value) + return None + + +def score_retraining_candidate(staleness_days: int, wape: float | None) -> float: + """Compute a deterministic retraining-priority score in ``[0.0, 1.0]``. + + Blends a time-based signal (staleness, capped at 90 days, 60% weight) with a + performance-based signal (WAPE, capped at 100, 40% weight) — the hybrid + trigger recommended by MLOps retraining guidance. When WAPE is unknown the + score degrades gracefully to staleness-only. Never raises. + + Args: + staleness_days: Days since the run's training-data window ended. + wape: The run's WAPE, or None when unknown. + + Returns: + Priority score rounded to 4 decimals; higher means more urgent. + """ + staleness_norm = min(max(staleness_days, 0), _STALENESS_CAP_DAYS) / _STALENESS_CAP_DAYS + error_norm = min(max(wape, 0.0), _WAPE_CAP) / _WAPE_CAP if wape is not None else 0.0 + return round(0.6 * staleness_norm + 0.4 * error_norm, 4) + + +def classify_drift( + wape_history: list[float | None], +) -> tuple[DriftDirection, float | None]: + """Classify a grain's forecast-error (WAPE) trend. + + Pure and total: never raises, tolerates None gaps and sparse history. + Compares the latest numeric WAPE against the mean of all prior numeric + WAPEs, applying a ±10% relative band — the heuristic drift tolerance from + MLOps monitoring guidance (a universal threshold does not exist). + + Args: + wape_history: Chronological WAPE values; None marks a run with no WAPE. + + Returns: + A ``(direction, delta)`` tuple. ``direction`` is improving / stable / + degrading / unknown; ``delta`` is the latest numeric WAPE minus the + previous numeric WAPE, or None when fewer than two numeric values exist. 
+ """ + numeric = [wape for wape in wape_history if wape is not None] + if len(numeric) < 2: + return "unknown", None + latest = numeric[-1] + prior = numeric[:-1] + baseline = sum(prior) / len(prior) + delta = round(latest - prior[-1], 4) + if baseline <= 0: + # Avoid div-by-zero on a zero baseline: any positive error is degrading. + return ("degrading" if latest > 0 else "stable"), delta + relative = (latest - baseline) / baseline + if relative > _DRIFT_BAND: + return "degrading", delta + if relative < -_DRIFT_BAND: + return "improving", delta + return "stable", delta + + +def _alias_staleness( + run: ModelRun, + latest_success_by_grain: dict[tuple[int, int], ModelRun], +) -> tuple[bool, str | None]: + """Decide whether an aliased run is stale, and why. + + An alias is stale when its run is no longer a successful run, or when a + newer successful run exists for the same ``(store, product)`` grain — the + industry-standard alias-staleness check (cf. MLflow alias governance). + + Args: + run: The model run the alias points at. + latest_success_by_grain: Latest successful run keyed by (store, product). + + Returns: + A ``(is_stale, reason)`` tuple; ``reason`` is None when not stale. + """ + if run.status != RunStatus.SUCCESS.value: + return True, f"aliased run status is '{run.status}', not 'success'" + latest = latest_success_by_grain.get((run.store_id, run.product_id)) + if latest is not None and latest.id != run.id and latest.created_at > run.created_at: + return True, "a newer successful run exists for this store/product" + return False, None + + +# ============================================================================= +# Service +# ============================================================================= + + +class OpsService: + """Read-only operational aggregation for the Control Center.""" + + async def get_summary(self, db: AsyncSession) -> OpsSummaryResponse: + """Aggregate system, job, run, alias, and freshness state. 
+ + Args: + db: Database session. + + Returns: + The full operational summary. Never raises on an empty database — + every section degrades to zeros / nulls / empty lists. + """ + now = datetime.now(UTC) + + # ---- System health ------------------------------------------------ + try: + await db.execute(text("SELECT 1")) + database_connected = True + except Exception: + # Deliberate connectivity probe: any failure means "not connected". + database_connected = False + + latest_successful_job_at = await db.scalar( + select(func.max(Job.completed_at)).where(Job.status == JobStatus.COMPLETED.value) + ) + + # ---- Job health --------------------------------------------------- + job_count_rows = ( + await db.execute(select(Job.status, func.count()).group_by(Job.status)) + ).all() + job_count_map: dict[str, int] = {str(row[0]): int(row[1]) for row in job_count_rows} + job_counts = [ + StatusCount(status=status.value, count=job_count_map.get(status.value, 0)) + for status in JobStatus + ] + start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0) + completed_today = int( + await db.scalar( + select(func.count()) + .select_from(Job) + .where( + Job.status == JobStatus.COMPLETED.value, + Job.completed_at >= start_of_day, + ) + ) + or 0 + ) + jobs = JobHealth( + counts=job_counts, + completed_today=completed_today, + failed_total=job_count_map.get(JobStatus.FAILED.value, 0), + active_total=( + job_count_map.get(JobStatus.PENDING.value, 0) + + job_count_map.get(JobStatus.RUNNING.value, 0) + ), + ) + + # ---- Run health --------------------------------------------------- + run_count_rows = ( + await db.execute(select(ModelRun.status, func.count()).group_by(ModelRun.status)) + ).all() + run_count_map: dict[str, int] = {str(row[0]): int(row[1]) for row in run_count_rows} + run_counts = [ + StatusCount(status=status.value, count=run_count_map.get(status.value, 0)) + for status in RunStatus + ] + eligible = sum(run_count_map.values()) - 
run_count_map.get(RunStatus.ARCHIVED.value, 0) + success_rate = ( + run_count_map.get(RunStatus.SUCCESS.value, 0) / eligible if eligible > 0 else None + ) + runs = RunHealth( + counts=run_counts, + success_rate=success_rate, + failed_total=run_count_map.get(RunStatus.FAILED.value, 0), + ) + + # ---- Alias health ------------------------------------------------- + # Latest successful run per (store, product) — the staleness baseline. + latest_success_runs = ( + ( + await db.execute( + select(ModelRun) + .where(ModelRun.status == RunStatus.SUCCESS.value) + .distinct(ModelRun.store_id, ModelRun.product_id) + .order_by( + ModelRun.store_id, + ModelRun.product_id, + ModelRun.created_at.desc(), + ) + ) + ) + .scalars() + .all() + ) + latest_success_by_grain: dict[tuple[int, int], ModelRun] = { + (run.store_id, run.product_id): run for run in latest_success_runs + } + + # Two-query alias load. NEVER touch DeploymentAlias.run — accessing that + # relationship under AsyncSession triggers a lazy load (MissingGreenlet). + # Resolve the integer FK into a typed map of single-entity rows instead. 
+ alias_rows = (await db.execute(select(DeploymentAlias))).scalars().all() + alias_run_ids = {alias.run_id for alias in alias_rows} + runs_by_id: dict[int, ModelRun] = {} + if alias_run_ids: + runs_by_id = { + run.id: run + for run in ( + (await db.execute(select(ModelRun).where(ModelRun.id.in_(alias_run_ids)))) + .scalars() + .all() + ) + } + + aliases: list[AliasHealth] = [] + stale_alias_items: list[AttentionItem] = [] + for alias in alias_rows: + run = runs_by_id.get(alias.run_id) + if run is None: # orphan FK — defensive; the FK constraint forbids it + continue + is_stale, stale_reason = _alias_staleness(run, latest_success_by_grain) + aliases.append( + AliasHealth( + alias_name=alias.alias_name, + run_id=run.run_id, + run_status=run.status, + model_type=run.model_type, + store_id=run.store_id, + product_id=run.product_id, + is_stale=is_stale, + stale_reason=stale_reason, + wape=extract_wape(run.metrics), + ) + ) + if is_stale: + stale_alias_items.append( + AttentionItem( + item_type="stale_alias", + entity_id=run.run_id, + label=f"alias '{alias.alias_name}' is stale", + detail=stale_reason or "alias is stale", + occurred_at=run.created_at, + ) + ) + + # ---- Data freshness ----------------------------------------------- + freshness = DataFreshness( + latest_sales_date=await db.scalar(select(func.max(SalesDaily.date))), + latest_job_completed_at=await db.scalar(select(func.max(Job.completed_at))), + latest_run_completed_at=await db.scalar( + select(func.max(ModelRun.completed_at)).where( + ModelRun.status == RunStatus.SUCCESS.value + ) + ), + ) + + # ---- Attention items ---------------------------------------------- + failed_jobs = ( + ( + await db.execute( + select(Job) + .where(Job.status == JobStatus.FAILED.value) + .order_by(Job.created_at.desc()) + .limit(_ATTENTION_LIMIT) + ) + ) + .scalars() + .all() + ) + failed_runs = ( + ( + await db.execute( + select(ModelRun) + .where(ModelRun.status == RunStatus.FAILED.value) + 
.order_by(ModelRun.created_at.desc()) + .limit(_ATTENTION_LIMIT) + ) + ) + .scalars() + .all() + ) + + attention_items: list[AttentionItem] = [ + AttentionItem( + item_type="failed_job", + entity_id=job.job_id, + label=f"{job.job_type} job failed", + detail=job.error_message or job.error_type or "Job failed", + occurred_at=job.created_at, + ) + for job in failed_jobs + ] + attention_items.extend( + AttentionItem( + item_type="failed_run", + entity_id=run.run_id, + label=f"{run.model_type} run failed", + detail=run.error_message or "Run failed", + occurred_at=run.created_at, + ) + for run in failed_runs + ) + attention_items.extend(stale_alias_items) + + logger.info( + "ops.summary_computed", + database_connected=database_connected, + failed_jobs=len(failed_jobs), + failed_runs=len(failed_runs), + stale_aliases=len(stale_alias_items), + ) + + return OpsSummaryResponse( + system=SystemHealth( + api_ok=True, + database_connected=database_connected, + latest_successful_job_at=latest_successful_job_at, + ), + jobs=jobs, + runs=runs, + aliases=aliases, + freshness=freshness, + attention_items=attention_items, + generated_at=now, + ) + + async def get_retraining_candidates( + self, db: AsyncSession, limit: int + ) -> RetrainingCandidatesResponse: + """Rank ``(store, product)`` grains by retraining priority. + + One candidate per grain — derived from its latest successful run. + + Args: + db: Database session. + limit: Maximum candidates to return (bounded 1..100 by the route). + + Returns: + Candidates sorted by ``priority_score`` descending, capped at limit. + """ + today = datetime.now(UTC).date() + + # Latest successful run per (store, product) — DISTINCT ON requires the + # ORDER BY to lead with the DISTINCT ON columns; created_at (non-null + # TimestampMixin column) is the "latest" tiebreaker. 
+ latest_success_runs = ( + ( + await db.execute( + select(ModelRun) + .where(ModelRun.status == RunStatus.SUCCESS.value) + .distinct(ModelRun.store_id, ModelRun.product_id) + .order_by( + ModelRun.store_id, + ModelRun.product_id, + ModelRun.created_at.desc(), + ) + ) + ) + .scalars() + .all() + ) + + candidates: list[RetrainingCandidate] = [] + for run in latest_success_runs: + raw_staleness = (today - run.data_window_end).days + staleness_days = max(raw_staleness, 0) + wape = extract_wape(run.metrics) + score = score_retraining_candidate(raw_staleness, wape) + wape_part = f"WAPE {wape:.1f}" if wape is not None else "WAPE unknown" + candidates.append( + RetrainingCandidate( + store_id=run.store_id, + product_id=run.product_id, + priority_score=score, + staleness_days=staleness_days, + wape=wape, + latest_run_id=run.run_id, + latest_run_status=run.status, + reason=f"{staleness_days}d since last training window; {wape_part}", + ) + ) + + candidates.sort(key=lambda candidate: candidate.priority_score, reverse=True) + + logger.info( + "ops.retraining_candidates_computed", + total_evaluated=len(candidates), + returned=min(limit, len(candidates)), + ) + + return RetrainingCandidatesResponse( + candidates=candidates[:limit], + total_evaluated=len(candidates), + generated_at=datetime.now(UTC), + ) + + async def get_model_health(self, db: AsyncSession, limit: int) -> ModelHealthResponse: + """Classify per-grain forecast-error drift from full run history. + + Unlike the retraining queue, this needs the *full* WAPE history per + grain (not just the latest run), so it queries every successful run + ordered by grain then creation time and groups in Python with + ``itertools.groupby`` — NOT ``DISTINCT ON``. + + Args: + db: Database session. + limit: Maximum grains to return (bounded 1..100 by the route). + + Returns: + Grains sorted degrading-first, then by ``|wape_delta|`` descending, + capped at ``limit``. Never raises on an empty database. 
+ """ + today = datetime.now(UTC).date() + + # FULL history — NOT DISTINCT ON. Ordered by (store, product, created_at) + # so itertools.groupby batches each grain in chronological order. + success_runs = ( + ( + await db.execute( + select(ModelRun) + .where(ModelRun.status == RunStatus.SUCCESS.value) + .order_by( + ModelRun.store_id, + ModelRun.product_id, + ModelRun.created_at, + ) + ) + ) + .scalars() + .all() + ) + + entries: list[ModelHealthEntry] = [] + for (store_id, product_id), grain_iter in groupby( + success_runs, key=lambda run: (run.store_id, run.product_id) + ): + grain_runs = list(grain_iter) # already chronological + history = [ + WapePoint( + run_id=run.run_id, + created_at=run.created_at, + wape=extract_wape(run.metrics), + ) + for run in grain_runs + ] + direction, delta = classify_drift([point.wape for point in history]) + numeric = [point.wape for point in history if point.wape is not None] + latest_run = grain_runs[-1] + entries.append( + ModelHealthEntry( + store_id=store_id, + product_id=product_id, + run_count=len(grain_runs), + latest_run_id=latest_run.run_id, + latest_run_status=latest_run.status, + latest_wape=(numeric[-1] if numeric else None), + previous_wape=(numeric[-2] if len(numeric) > 1 else None), + wape_delta=delta, + drift_direction=direction, + last_trained_at=latest_run.created_at, + staleness_days=max((today - latest_run.data_window_end).days, 0), + wape_history=history, + ) + ) + + # Degrading grains first; within a tier, the largest WAPE move leads. 
+ rank: dict[str, int] = {"degrading": 0, "improving": 1, "stable": 2, "unknown": 3} + entries.sort(key=lambda entry: (rank[entry.drift_direction], -abs(entry.wape_delta or 0.0))) + + logger.info("ops.model_health_computed", grains=len(entries)) + + return ModelHealthResponse( + entries=entries[:limit], + total_evaluated=len(entries), + generated_at=datetime.now(UTC), + ) diff --git a/app/features/ops/tests/__init__.py b/app/features/ops/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/features/ops/tests/conftest.py b/app/features/ops/tests/conftest.py new file mode 100644 index 00000000..f3663f92 --- /dev/null +++ b/app/features/ops/tests/conftest.py @@ -0,0 +1,297 @@ +"""Test fixtures for the ops slice. + +Mirrors ``app/features/analytics/tests/conftest.py``: a real PostgreSQL session +(integration tests need ``docker-compose up -d``) with FK-safe, scoped cleanup. + +All seeded rows carry a ``test-`` / ``TEST-`` marker so the teardown never +touches a shared dev or CI dataset. +""" + +import uuid +from collections.abc import AsyncGenerator +from datetime import date + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.core.config import get_settings +from app.core.database import get_db +from app.features.data_platform.models import Calendar, Product, SalesDaily, Store +from app.features.jobs.models import Job, JobStatus, JobType +from app.features.registry.models import DeploymentAlias, ModelRun, RunStatus +from app.main import app + +# Calendar dates the ops sales fixture occupies — deleted on teardown. 
+_SALES_DATES = [date(2026, 3, 1), date(2026, 3, 2), date(2026, 3, 3)] + + +def _short_id() -> str: + """Return a short unique hex token for test natural keys.""" + return uuid.uuid4().hex[:12] + + +# ============================================================================= +# Database + client fixtures +# ============================================================================= + + +@pytest.fixture +async def db_session() -> AsyncGenerator[AsyncSession, None]: + """Yield an async session, then clean up every ``test-``/``TEST-`` row. + + Cleanup runs in FK-safe order: DeploymentAlias before ModelRun (alias FKs + the run), and Sales before its Store/Product parents. Jobs are independent. + Calendar rows are intentionally left in place — the sales fixture's dates + fall inside the seeder's window, so a seeded dataset may already reference + them; deleting them would hit a foreign-key violation. + """ + settings = get_settings() + engine = create_async_engine(settings.database_url, echo=False) + async_session_maker = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async with async_session_maker() as session: + try: + yield session + finally: + test_store_ids = select(Store.id).where(Store.code.like("TEST-%")) + test_product_ids = select(Product.id).where(Product.sku.like("TEST-%")) + # Aliases first — they FK-reference model_run. 
+ await session.execute( + delete(DeploymentAlias).where(DeploymentAlias.alias_name.like("test-%")) + ) + await session.execute(delete(ModelRun).where(ModelRun.run_id.like("test-%"))) + await session.execute(delete(Job).where(Job.job_id.like("test-%"))) + await session.execute( + delete(SalesDaily).where( + SalesDaily.store_id.in_(test_store_ids) + | SalesDaily.product_id.in_(test_product_ids) + ) + ) + await session.execute(delete(Product).where(Product.sku.like("TEST-%"))) + await session.execute(delete(Store).where(Store.code.like("TEST-%"))) + await session.commit() + + await engine.dispose() + + +@pytest.fixture +async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: + """Create a test client with the database dependency overridden.""" + + async def override_get_db() -> AsyncGenerator[AsyncSession, None]: + yield db_session + + app.dependency_overrides[get_db] = override_get_db + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + yield ac + + app.dependency_overrides.pop(get_db, None) + + +# ============================================================================= +# Sample-data fixtures +# ============================================================================= + + +@pytest.fixture +async def sample_jobs(db_session: AsyncSession) -> list[Job]: + """Create one job per lifecycle status; the failed job carries an error.""" + jobs = [ + Job( + job_id=f"test-{_short_id()}", + job_type=JobType.TRAIN.value, + status=JobStatus.PENDING.value, + params={"_test": True}, + ), + Job( + job_id=f"test-{_short_id()}", + job_type=JobType.PREDICT.value, + status=JobStatus.RUNNING.value, + params={"_test": True}, + ), + Job( + job_id=f"test-{_short_id()}", + job_type=JobType.BACKTEST.value, + status=JobStatus.COMPLETED.value, + params={"_test": True}, + result={"ok": True}, + ), + Job( + job_id=f"test-{_short_id()}", + job_type=JobType.TRAIN.value, + status=JobStatus.FAILED.value, + params={"_test": 
True}, + error_message="seeded failure", + error_type="ValueError", + ), + ] + for job in jobs: + db_session.add(job) + await db_session.commit() + for job in jobs: + await db_session.refresh(job) + return jobs + + +@pytest.fixture +async def sample_runs(db_session: AsyncSession) -> dict[str, ModelRun]: + """Create model runs across grains and statuses. + + ``success_old`` and ``success_new`` share grain (9001, 8001) so an alias to + ``success_old`` is provably stale (a newer successful run exists). + """ + + def _run( + status: str, + store_id: int, + product_id: int, + window_end: date, + metrics: dict[str, float] | None, + error_message: str | None = None, + ) -> ModelRun: + return ModelRun( + run_id=f"test-{_short_id()}", + status=status, + model_type="naive", + model_config={"_test": True}, + config_hash=_short_id()[:16], + data_window_start=date(2025, 1, 1), + data_window_end=window_end, + store_id=store_id, + product_id=product_id, + metrics=metrics, + error_message=error_message, + ) + + runs = { + "success_old": _run(RunStatus.SUCCESS.value, 9001, 8001, date(2026, 1, 1), {"wape": 31.0}), + "failed": _run( + RunStatus.FAILED.value, + 9002, + 8002, + date(2026, 2, 1), + None, + error_message="seeded run failure", + ), + "success_other": _run(RunStatus.SUCCESS.value, 9003, 8003, date(2026, 2, 15), None), + } + for run in runs.values(): + db_session.add(run) + await db_session.commit() + for run in runs.values(): + await db_session.refresh(run) + + # success_new is committed after success_old so its created_at is strictly + # later — making success_old the stale one for grain (9001, 8001). 
+ success_new = _run(RunStatus.SUCCESS.value, 9001, 8001, date(2026, 4, 1), {"wape": 12.0}) + db_session.add(success_new) + await db_session.commit() + await db_session.refresh(success_new) + runs["success_new"] = success_new + return runs + + +@pytest.fixture +async def sample_alias( + db_session: AsyncSession, sample_runs: dict[str, ModelRun] +) -> DeploymentAlias: + """Alias pointing at the OLDER successful run — provably stale.""" + alias = DeploymentAlias( + alias_name=f"test-{_short_id()}", + run_id=sample_runs["success_old"].id, + description="ops slice test alias", + ) + db_session.add(alias) + await db_session.commit() + await db_session.refresh(alias) + return alias + + +@pytest.fixture +async def sample_sales(db_session: AsyncSession) -> list[SalesDaily]: + """Create a TEST- store/product, calendar rows, and a few sales days.""" + store = Store( + code=f"TEST-{_short_id()}", + name="Ops Test Store", + region="Test Region", + city="Test City", + store_type="supermarket", + ) + product = Product( + sku=f"TEST-{_short_id()}", + name="Ops Test Product", + category="Test Category", + brand="Test Brand", + base_price=10, + base_cost=5, + ) + db_session.add_all([store, product]) + await db_session.commit() + await db_session.refresh(store) + await db_session.refresh(product) + + for day in _SALES_DATES: + await db_session.merge( + Calendar( + date=day, + day_of_week=day.weekday(), + month=day.month, + quarter=(day.month - 1) // 3 + 1, + year=day.year, + is_holiday=False, + ) + ) + await db_session.commit() + + sales = [ + SalesDaily( + date=day, + store_id=store.id, + product_id=product.id, + quantity=5, + unit_price=10, + total_amount=50, + ) + for day in _SALES_DATES + ] + for row in sales: + db_session.add(row) + await db_session.commit() + for row in sales: + await db_session.refresh(row) + return sales + + +@pytest.fixture +async def sample_health_runs(db_session: AsyncSession) -> list[ModelRun]: + """Three successful runs for one grain forming a degrading 
WAPE history. + + Grain (9101, 8101): WAPE 10.0 -> 11.0 -> 25.0. Each run is committed in its + own transaction so its server-side ``created_at`` is strictly later than + the prior one — making the chronological history deterministic for the + model-health endpoint's ``itertools.groupby`` ordering. + """ + wapes = [10.0, 11.0, 25.0] + window_ends = [date(2026, 1, 1), date(2026, 2, 1), date(2026, 3, 1)] + runs: list[ModelRun] = [] + for wape, window_end in zip(wapes, window_ends, strict=True): + run = ModelRun( + run_id=f"test-{_short_id()}", + status=RunStatus.SUCCESS.value, + model_type="naive", + model_config={"_test": True}, + config_hash=_short_id()[:16], + data_window_start=date(2025, 1, 1), + data_window_end=window_end, + store_id=9101, + product_id=8101, + metrics={"wape": wape}, + ) + db_session.add(run) + await db_session.commit() + await db_session.refresh(run) + runs.append(run) + return runs diff --git a/app/features/ops/tests/test_routes_integration.py b/app/features/ops/tests/test_routes_integration.py new file mode 100644 index 00000000..e6b96ec1 --- /dev/null +++ b/app/features/ops/tests/test_routes_integration.py @@ -0,0 +1,196 @@ +"""Integration tests for the ops Control Center routes. + +Runs against a real PostgreSQL database — the full path from HTTP request +through SQL aggregation to response. Requires ``docker-compose up -d``. + +Assertions are structural (status-key coverage, sort order, bounds) rather than +exact global totals, so the tests stay idempotent against a shared dataset. 
+""" + +import pytest +from httpx import AsyncClient + +from app.features.data_platform.models import SalesDaily +from app.features.jobs.models import Job, JobStatus +from app.features.registry.models import DeploymentAlias, ModelRun, RunStatus + +_JOB_STATUSES = {s.value for s in JobStatus} +_RUN_STATUSES = {s.value for s in RunStatus} +_DRIFT_RANK = {"degrading": 0, "improving": 1, "stable": 2, "unknown": 3} + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestOpsSummary: + """Integration tests for GET /ops/summary.""" + + async def test_summary_happy_path( + self, + client: AsyncClient, + sample_jobs: list[Job], + sample_runs: dict[str, ModelRun], + sample_alias: DeploymentAlias, + sample_sales: list[SalesDaily], + ) -> None: + """A seeded database yields a fully populated summary.""" + response = await client.get("/ops/summary") + + assert response.status_code == 200 + data = response.json() + + assert data["system"]["api_ok"] is True + assert data["system"]["database_connected"] is True + + # Job and run histograms cover every status key (zero-filled). + assert {c["status"] for c in data["jobs"]["counts"]} == _JOB_STATUSES + assert {c["status"] for c in data["runs"]["counts"]} == _RUN_STATUSES + + # The seeded failed job surfaces in the attention list. It is the most + # recently created failed job, so the limit-10 window always includes it. + failed_job_id = next(j.job_id for j in sample_jobs if j.status == JobStatus.FAILED.value) + failed_job_ids = { + item["entity_id"] + for item in data["attention_items"] + if item["item_type"] == "failed_job" + } + assert failed_job_id in failed_job_ids + + # Freshness reflects the seeded sales. + assert data["freshness"]["latest_sales_date"] is not None + assert data["freshness"]["latest_sales_date"] >= "2026-03-03" + + # The alias seeded against the older successful run is reported stale. 
+ stale_alias = next(a for a in data["aliases"] if a["alias_name"] == sample_alias.alias_name) + assert stale_alias["is_stale"] is True + assert stale_alias["stale_reason"] is not None + assert stale_alias["wape"] == 31.0 + + async def test_summary_resilient_structural(self, client: AsyncClient) -> None: + """Without any seeded fixtures the summary still returns 200, never 500.""" + response = await client.get("/ops/summary") + + assert response.status_code == 200 + data = response.json() + + # Every histogram bucket is non-negative and every status key present. + for section in ("jobs", "runs"): + for count in data[section]["counts"]: + assert count["count"] >= 0 + assert {c["status"] for c in data["jobs"]["counts"]} == _JOB_STATUSES + assert {c["status"] for c in data["runs"]["counts"]} == _RUN_STATUSES + + assert data["jobs"]["completed_today"] >= 0 + assert data["jobs"]["active_total"] >= 0 + assert data["jobs"]["failed_total"] >= 0 + assert data["runs"]["failed_total"] >= 0 + assert isinstance(data["attention_items"], list) + assert isinstance(data["aliases"], list) + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestRetrainingCandidates: + """Integration tests for GET /ops/retraining-candidates.""" + + async def test_candidates_sorted_and_limited( + self, + client: AsyncClient, + sample_runs: dict[str, ModelRun], + ) -> None: + """Candidates are sorted by priority_score desc and capped at limit.""" + response = await client.get("/ops/retraining-candidates", params={"limit": 5}) + + assert response.status_code == 200 + data = response.json() + + candidates = data["candidates"] + assert len(candidates) <= 5 + assert data["total_evaluated"] >= len(candidates) + + scores = [c["priority_score"] for c in candidates] + assert scores == sorted(scores, reverse=True), "candidates must be sorted desc" + + for candidate in candidates: + assert 0.0 <= candidate["priority_score"] <= 1.0 + assert candidate["staleness_days"] >= 0 + assert 
candidate["latest_run_status"] == RunStatus.SUCCESS.value + + async def test_candidates_default_limit(self, client: AsyncClient) -> None: + """The endpoint works with no explicit limit (default 20).""" + response = await client.get("/ops/retraining-candidates") + + assert response.status_code == 200 + assert len(response.json()["candidates"]) <= 20 + + async def test_candidates_limit_zero_rejected(self, client: AsyncClient) -> None: + """limit=0 is below the ge=1 bound and returns 422.""" + response = await client.get("/ops/retraining-candidates", params={"limit": 0}) + assert response.status_code == 422 + + async def test_candidates_limit_too_high_rejected(self, client: AsyncClient) -> None: + """limit=200 is above the le=100 bound and returns 422.""" + response = await client.get("/ops/retraining-candidates", params={"limit": 200}) + assert response.status_code == 422 + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestModelHealth: + """Integration tests for GET /ops/model-health.""" + + async def test_model_health_happy_path( + self, + client: AsyncClient, + sample_health_runs: list[ModelRun], + ) -> None: + """The seeded 3-run degrading grain surfaces with a drift verdict.""" + response = await client.get("/ops/model-health", params={"limit": 100}) + + assert response.status_code == 200 + data = response.json() + + entry = next( + e for e in data["entries"] if e["store_id"] == 9101 and e["product_id"] == 8101 + ) + assert entry["drift_direction"] == "degrading" + assert entry["run_count"] == 3 + assert entry["latest_wape"] == 25.0 + assert entry["previous_wape"] == 11.0 + assert entry["wape_delta"] == 14.0 + assert len(entry["wape_history"]) == 3 + assert data["total_evaluated"] >= 1 + + async def test_model_health_degrading_first_sort( + self, + client: AsyncClient, + sample_health_runs: list[ModelRun], + ) -> None: + """Entries are ordered degrading-first (drift rank non-decreasing).""" + response = await client.get("/ops/model-health", 
params={"limit": 100}) + + assert response.status_code == 200 + ranks = [_DRIFT_RANK[e["drift_direction"]] for e in response.json()["entries"]] + assert ranks == sorted(ranks), "entries must be sorted degrading-first" + + async def test_model_health_resilient_structural(self, client: AsyncClient) -> None: + """Without seeded fixtures the endpoint still returns 200, never 500.""" + response = await client.get("/ops/model-health") + + assert response.status_code == 200 + data = response.json() + assert isinstance(data["entries"], list) + assert data["total_evaluated"] >= 0 + for entry in data["entries"]: + assert entry["drift_direction"] in _DRIFT_RANK + assert entry["run_count"] >= 0 + assert entry["staleness_days"] >= 0 + + async def test_model_health_limit_zero_rejected(self, client: AsyncClient) -> None: + """limit=0 is below the ge=1 bound and returns 422.""" + response = await client.get("/ops/model-health", params={"limit": 0}) + assert response.status_code == 422 + + async def test_model_health_limit_too_high_rejected(self, client: AsyncClient) -> None: + """limit=200 is above the le=100 bound and returns 422.""" + response = await client.get("/ops/model-health", params={"limit": 200}) + assert response.status_code == 422 diff --git a/app/features/ops/tests/test_schemas.py b/app/features/ops/tests/test_schemas.py new file mode 100644 index 00000000..4ee7314d --- /dev/null +++ b/app/features/ops/tests/test_schemas.py @@ -0,0 +1,272 @@ +"""Unit tests for the ops slice's Pydantic response schemas. + +These run without a database (-m "not integration"). 
+""" + +from datetime import UTC, date, datetime + +import pytest +from pydantic import ValidationError + +from app.features.ops.schemas import ( + AliasHealth, + AttentionItem, + DataFreshness, + JobHealth, + ModelHealthEntry, + ModelHealthResponse, + OpsSummaryResponse, + RetrainingCandidate, + RetrainingCandidatesResponse, + RunHealth, + StatusCount, + SystemHealth, + WapePoint, +) + +_NOW = datetime(2026, 5, 19, 12, 0, tzinfo=UTC) + + +def test_system_health_construct() -> None: + """SystemHealth carries liveness flags and an optional job timestamp.""" + system = SystemHealth(api_ok=True, database_connected=True, latest_successful_job_at=_NOW) + assert system.api_ok is True + assert system.latest_successful_job_at == _NOW + + +def test_system_health_allows_null_job_timestamp() -> None: + """latest_successful_job_at defaults to None when no job has completed.""" + system = SystemHealth(api_ok=True, database_connected=False) + assert system.latest_successful_job_at is None + + +def test_status_count_rejects_negative_count() -> None: + """A negative count violates the ge=0 constraint.""" + with pytest.raises(ValidationError): + StatusCount(status="failed", count=-1) + + +def test_job_health_construct_and_reject_negative() -> None: + """JobHealth aggregates counts; negative totals are rejected.""" + health = JobHealth( + counts=[StatusCount(status="completed", count=3)], + completed_today=2, + failed_total=1, + active_total=0, + ) + assert health.completed_today == 2 + with pytest.raises(ValidationError): + JobHealth(counts=[], completed_today=-1, failed_total=0, active_total=0) + + +def test_run_health_allows_null_success_rate() -> None: + """success_rate is None when there are no eligible runs.""" + health = RunHealth(counts=[], success_rate=None, failed_total=0) + assert health.success_rate is None + + +def test_alias_health_construct() -> None: + """AliasHealth carries the staleness verdict and an optional WAPE.""" + alias = AliasHealth( + 
alias_name="production", + run_id="abc123", + run_status="success", + model_type="naive", + store_id=1, + product_id=2, + is_stale=True, + stale_reason="a newer successful run exists for this store/product", + wape=18.4, + ) + assert alias.is_stale is True + assert alias.wape == 18.4 + + +def test_data_freshness_defaults_to_null() -> None: + """Every freshness field is optional and defaults to None.""" + freshness = DataFreshness() + assert freshness.latest_sales_date is None + assert freshness.latest_job_completed_at is None + assert freshness.latest_run_completed_at is None + + +def test_attention_item_rejects_unknown_type() -> None: + """item_type is constrained to the three known literals.""" + with pytest.raises(ValidationError): + AttentionItem( + item_type="something_else", # type: ignore[arg-type] + entity_id="x", + label="x", + detail="x", + ) + + +def test_attention_item_construct() -> None: + """A valid AttentionItem accepts the known literals.""" + item = AttentionItem( + item_type="failed_job", + entity_id="job-1", + label="train job failed", + detail="boom", + occurred_at=_NOW, + ) + assert item.item_type == "failed_job" + + +def test_ops_summary_response_construct() -> None: + """OpsSummaryResponse nests every section.""" + summary = OpsSummaryResponse( + system=SystemHealth(api_ok=True, database_connected=True), + jobs=JobHealth(counts=[], completed_today=0, failed_total=0, active_total=0), + runs=RunHealth(counts=[], success_rate=None, failed_total=0), + aliases=[], + freshness=DataFreshness(), + attention_items=[], + generated_at=_NOW, + ) + assert summary.generated_at == _NOW + assert summary.aliases == [] + + +def test_retraining_candidate_rejects_out_of_range_score() -> None: + """priority_score is bounded to [0.0, 1.0].""" + with pytest.raises(ValidationError): + RetrainingCandidate( + store_id=1, + product_id=2, + priority_score=1.5, + staleness_days=10, + wape=None, + latest_run_id="r1", + latest_run_status="success", + reason="x", + ) + + 
+def test_retraining_candidate_rejects_negative_staleness() -> None: + """staleness_days violates ge=0 when negative.""" + with pytest.raises(ValidationError): + RetrainingCandidate( + store_id=1, + product_id=2, + priority_score=0.5, + staleness_days=-1, + wape=None, + latest_run_id="r1", + latest_run_status="success", + reason="x", + ) + + +def test_retraining_candidates_response_construct() -> None: + """RetrainingCandidatesResponse wraps candidates with a total and timestamp.""" + candidate = RetrainingCandidate( + store_id=1, + product_id=2, + priority_score=0.75, + staleness_days=30, + wape=12.0, + latest_run_id="r1", + latest_run_status="success", + reason="30d since last training window; WAPE 12.0", + ) + response = RetrainingCandidatesResponse( + candidates=[candidate], + total_evaluated=1, + generated_at=_NOW, + ) + assert response.total_evaluated == 1 + assert response.candidates[0].priority_score == 0.75 + + +def test_data_freshness_accepts_date() -> None: + """latest_sales_date accepts a date value.""" + freshness = DataFreshness(latest_sales_date=date(2026, 5, 1)) + assert freshness.latest_sales_date == date(2026, 5, 1) + + +# ============================================================================= +# Model health & drift +# ============================================================================= + + +def test_wape_point_construct() -> None: + """WapePoint carries a run_id, timestamp, and an optional WAPE.""" + point = WapePoint(run_id="r1", created_at=_NOW, wape=14.0) + assert point.wape == 14.0 + null_point = WapePoint(run_id="r2", created_at=_NOW) + assert null_point.wape is None + + +def test_model_health_entry_construct() -> None: + """A valid ModelHealthEntry accepts a known drift direction and history.""" + entry = ModelHealthEntry( + store_id=1, + product_id=2, + run_count=3, + latest_run_id="r3", + latest_run_status="success", + latest_wape=25.0, + previous_wape=11.0, + wape_delta=14.0, + drift_direction="degrading", + 
last_trained_at=_NOW, + staleness_days=30, + wape_history=[WapePoint(run_id="r3", created_at=_NOW, wape=25.0)], + ) + assert entry.drift_direction == "degrading" + assert entry.wape_delta == 14.0 + + +def test_model_health_entry_rejects_negative_run_count() -> None: + """run_count violates the ge=0 constraint when negative.""" + with pytest.raises(ValidationError): + ModelHealthEntry( + store_id=1, + product_id=2, + run_count=-1, + drift_direction="unknown", + staleness_days=0, + wape_history=[], + ) + + +def test_model_health_entry_rejects_negative_staleness() -> None: + """staleness_days violates the ge=0 constraint when negative.""" + with pytest.raises(ValidationError): + ModelHealthEntry( + store_id=1, + product_id=2, + run_count=0, + drift_direction="unknown", + staleness_days=-1, + wape_history=[], + ) + + +def test_model_health_entry_rejects_unknown_drift_direction() -> None: + """drift_direction is constrained to the four known literals.""" + with pytest.raises(ValidationError): + ModelHealthEntry( + store_id=1, + product_id=2, + run_count=0, + drift_direction="exploding", # type: ignore[arg-type] + staleness_days=0, + wape_history=[], + ) + + +def test_model_health_response_construct() -> None: + """ModelHealthResponse wraps entries with a total and timestamp.""" + entry = ModelHealthEntry( + store_id=1, + product_id=2, + run_count=2, + drift_direction="stable", + staleness_days=5, + wape_history=[], + ) + response = ModelHealthResponse(entries=[entry], total_evaluated=1, generated_at=_NOW) + assert response.total_evaluated == 1 + assert response.entries[0].drift_direction == "stable" diff --git a/app/features/ops/tests/test_service.py b/app/features/ops/tests/test_service.py new file mode 100644 index 00000000..3c228660 --- /dev/null +++ b/app/features/ops/tests/test_service.py @@ -0,0 +1,135 @@ +"""Unit tests for the ops slice's pure scoring helpers. + +These run without a database (-m "not integration"): the helpers are pure +functions with no I/O. 
+""" + +from app.features.ops.service import classify_drift, extract_wape, score_retraining_candidate + +# ============================================================================= +# score_retraining_candidate +# ============================================================================= + + +def test_score_zero_when_fresh_and_no_error() -> None: + """A brand-new run with no WAPE scores 0.0.""" + assert score_retraining_candidate(0, None) == 0.0 + + +def test_score_max_when_fully_stale_and_max_error() -> None: + """90+ days stale with WAPE 100 saturates both terms to 1.0.""" + assert score_retraining_candidate(90, 100.0) == 1.0 + + +def test_score_clamps_negative_staleness_and_high_wape() -> None: + """Negative staleness clamps to 0; WAPE above the cap clamps to 1.0.""" + # staleness -> 0.0, error -> 1.0; score = 0.6*0 + 0.4*1.0 = 0.4 + assert score_retraining_candidate(-5, 250.0) == 0.4 + + +def test_score_midpoint() -> None: + """Half-stale with half-max WAPE lands at the weighted midpoint.""" + # staleness 45/90 -> 0.5, error 50/100 -> 0.5; score = 0.6*0.5 + 0.4*0.5 = 0.5 + assert score_retraining_candidate(45, 50.0) == 0.5 + + +def test_score_staleness_only_when_wape_unknown() -> None: + """With WAPE unknown the score degrades to the staleness term alone.""" + # staleness 90 -> 1.0, error -> 0.0; score = 0.6 + assert score_retraining_candidate(90, None) == 0.6 + + +def test_score_is_bounded() -> None: + """The score never escapes [0.0, 1.0] for extreme inputs.""" + assert score_retraining_candidate(10_000, 10_000.0) == 1.0 + assert score_retraining_candidate(-10_000, -10_000.0) == 0.0 + + +# ============================================================================= +# extract_wape +# ============================================================================= + + +def test_extract_wape_prefers_wape_then_wape_mean() -> None: + """The 'wape' key wins; 'wape_mean' and 'WAPE' are fallbacks.""" + assert extract_wape({"wape": 12.0}) == 12.0 + assert 
extract_wape({"wape_mean": 8.5}) == 8.5 + assert extract_wape({"WAPE": 4.0}) == 4.0 + assert extract_wape({"wape": 1.0, "wape_mean": 99.0}) == 1.0 + + +def test_extract_wape_returns_none_for_missing_or_empty() -> None: + """None and an empty / unrelated dict yield None — never an exception.""" + assert extract_wape(None) is None + assert extract_wape({}) is None + assert extract_wape({"mae": 3.2}) is None + + +def test_extract_wape_rejects_non_numeric_and_bool() -> None: + """A non-numeric value yields None; bool is rejected (it is not a metric).""" + assert extract_wape({"wape": "bad"}) is None + assert extract_wape({"wape": None}) is None + assert extract_wape({"wape": True}) is None + assert extract_wape({"wape": False}) is None + + +def test_extract_wape_coerces_int_to_float() -> None: + """An integer WAPE is returned as a float.""" + result = extract_wape({"wape": 25}) + assert result == 25.0 + assert isinstance(result, float) + + +# ============================================================================= +# classify_drift +# ============================================================================= + + +def test_classify_drift_unknown_when_empty() -> None: + """An empty history has no trend — direction is 'unknown', delta None.""" + assert classify_drift([]) == ("unknown", None) + + +def test_classify_drift_unknown_when_under_two_numeric() -> None: + """Fewer than two numeric WAPEs yields 'unknown' (None gaps don't count).""" + assert classify_drift([None, 10.0]) == ("unknown", None) + assert classify_drift([10.0]) == ("unknown", None) + + +def test_classify_drift_degrading() -> None: + """A latest WAPE far above the prior mean is 'degrading'; delta is positive.""" + direction, delta = classify_drift([10.0, 10.0, 20.0]) + assert direction == "degrading" + assert delta == 10.0 + + +def test_classify_drift_improving() -> None: + """A latest WAPE far below the prior mean is 'improving'; delta is negative.""" + direction, delta = classify_drift([20.0, 
20.0, 10.0]) + assert direction == "improving" + assert delta == -10.0 + + +def test_classify_drift_stable_within_band() -> None: + """A change inside the ±10% relative band is 'stable'.""" + direction, delta = classify_drift([10.0, 10.5]) # +5% < 10% band + assert direction == "stable" + assert delta == 0.5 + + +def test_classify_drift_tolerates_none_gaps() -> None: + """None gaps are skipped; classification uses only numeric observations.""" + direction, delta = classify_drift([None, 10.0, None, 12.0]) # +20% over baseline 10 + assert direction == "degrading" + assert delta == 2.0 + + +def test_classify_drift_zero_baseline_guard() -> None: + """A zero baseline never divides by zero: positive error degrades, zero is stable.""" + assert classify_drift([0.0, 5.0])[0] == "degrading" + assert classify_drift([0.0, 0.0])[0] == "stable" + + +def test_classify_drift_never_raises_on_sparse_history() -> None: + """Sparse / all-None history degrades gracefully to 'unknown'.""" + assert classify_drift([None, None, None]) == ("unknown", None) diff --git a/app/features/rag/routes.py b/app/features/rag/routes.py index 4daf0106..e4474fb2 100644 --- a/app/features/rag/routes.py +++ b/app/features/rag/routes.py @@ -11,6 +11,8 @@ from app.features.rag.embeddings import EmbeddingError from app.features.rag.schemas import ( DeleteResponse, + IndexProjectDocsRequest, + IndexProjectDocsResponse, IndexRequest, IndexResponse, RetrieveRequest, @@ -133,6 +135,91 @@ async def index_document( ) from e +@router.post( + "/index/project-docs", + response_model=IndexProjectDocsResponse, + summary="Index bundled project documentation", + description=""" +Discover and bulk-index the repository's own bundled markdown. 
+ +**Discovery roots (all toggleable, all default on):** +- `include_docs`: every `docs/**/*.md` +- `include_prps`: every `PRPs/**/*.md` +- `include_root`: `README.md`, `AGENTS.md`, `CHANGELOG.md` + +Each file is indexed through the same path as `POST /rag/index`, so chunking, +embedding, the SHA-256 content-hash idempotency short-circuit, and upsert are +all reused. Re-runs return every unchanged file as `status: "unchanged"`. + +**Returns:** per-file results plus aggregate counts (indexed / updated / +unchanged / failed / total_chunks). A single unreadable file is reported +`status: "failed"` without aborting the batch; an embedding-provider or +database failure is batch-fatal and surfaces as `502` / problem+json. +""", +) +async def index_project_docs( + request: IndexProjectDocsRequest, + db: AsyncSession = Depends(get_db), +) -> IndexProjectDocsResponse: + """Bulk-index bundled project documentation into the knowledge base. + + Args: + request: Toggles selecting which doc roots to index. + db: Async database session from dependency. + + Returns: + Per-file results plus aggregate indexing statistics. + + Raises: + HTTPException: If embedding generation fails (502). + DatabaseError: If a database operation fails. 
+ """ + logger.info( + "rag.index_project_docs_request_received", + include_docs=request.include_docs, + include_prps=request.include_prps, + include_root=request.include_root, + ) + + service = RAGService() + + try: + response = await service.index_project_docs(db=db, request=request) + + logger.info( + "rag.index_project_docs_request_completed", + total_files=response.total_files, + total_chunks=response.total_chunks, + failed=response.failed, + ) + + return response + + except EmbeddingError as e: + logger.error( + "rag.index_project_docs_request_failed", + error=str(e), + error_type=type(e).__name__, + exc_info=True, + ) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Embedding generation failed: {e}", + ) from e + + except SQLAlchemyError as e: + logger.error( + "rag.index_project_docs_request_failed", + error=str(e), + error_type=type(e).__name__, + exc_info=True, + ) + raise DatabaseError( + message="Failed to index project docs", + details={"error": str(e)}, + ) from e + + # ============================================================================= # Retrieve Endpoint # ============================================================================= diff --git a/app/features/rag/schemas.py b/app/features/rag/schemas.py index 41a31d1b..44dbac2d 100644 --- a/app/features/rag/schemas.py +++ b/app/features/rag/schemas.py @@ -179,3 +179,63 @@ class DeleteResponse(BaseModel): source_id: str chunks_deleted: int status: Literal["deleted"] + + +class IndexProjectDocsRequest(BaseModel): + """Request to bulk-index bundled project documentation. + + All fields default to True so an empty ``{}`` body indexes every root. + + Args: + include_docs: Index markdown discovered under docs/**. + include_prps: Index markdown discovered under PRPs/**. + include_root: Index the root allow-list (README/AGENTS/CHANGELOG). 
+ """ + + model_config = ConfigDict(extra="forbid") + + include_docs: bool = Field(default=True, description="Index docs/**/*.md") + include_prps: bool = Field(default=True, description="Index PRPs/**/*.md") + include_root: bool = Field( + default=True, description="Index README.md / AGENTS.md / CHANGELOG.md" + ) + + +class ProjectDocResult(BaseModel): + """Per-file outcome of a project-docs index run. + + Args: + source_path: Relative POSIX path of the file (the source identifier). + status: Outcome — indexed, updated, unchanged, or failed. + chunks_created: Number of chunks created (0 when unchanged or failed). + error: Error message when status is "failed", otherwise None. + """ + + source_path: str + status: Literal["indexed", "updated", "unchanged", "failed"] + chunks_created: int + error: str | None = None + + +class IndexProjectDocsResponse(BaseModel): + """Aggregate result of POST /rag/index/project-docs. + + Args: + results: Per-file outcomes. + total_files: Total files discovered and processed. + indexed: Count of newly indexed files. + updated: Count of re-indexed (changed) files. + unchanged: Count of files skipped by the content-hash short-circuit. + failed: Count of files that could not be read. + total_chunks: Total chunks created across all files. + duration_ms: Wall-clock time taken for the batch. 
+ """ + + results: list[ProjectDocResult] + total_files: int + indexed: int + updated: int + unchanged: int + failed: int + total_chunks: int + duration_ms: float diff --git a/app/features/rag/service.py b/app/features/rag/service.py index d77f6f7e..8229bb33 100644 --- a/app/features/rag/service.py +++ b/app/features/rag/service.py @@ -29,8 +29,11 @@ from app.features.rag.schemas import ( ChunkResult, DeleteResponse, + IndexProjectDocsRequest, + IndexProjectDocsResponse, IndexRequest, IndexResponse, + ProjectDocResult, RetrieveRequest, RetrieveResponse, SourceListResponse, @@ -39,6 +42,10 @@ logger = structlog.get_logger() +# Allow-listed root markdown files indexed by index_project_docs. CLAUDE.md is +# deliberately excluded — it is an operating index that @imports AGENTS.md. +_PROJECT_ROOT_FILES: tuple[str, ...] = ("README.md", "AGENTS.md", "CHANGELOG.md") + class SourceNotFoundError(ValueError): """Source not found in the knowledge base.""" @@ -250,6 +257,135 @@ async def index_document( status=status, ) + def _discover_project_doc_files( + self, request: IndexProjectDocsRequest + ) -> list[tuple[Path, str]]: + """Discover bundled markdown under the allow-listed project-doc roots. + + Pure and synchronous — no DB, no network. ``rglob`` on a non-existent + directory yields nothing (no exception), so an absent docs/ or PRPs/ + root simply contributes 0 files. + + Args: + request: Toggles selecting which roots to discover. + + Returns: + A deterministically sorted list of (absolute_path, category) pairs + where category is "docs", "prp", or "root". 
+ """ + found: list[tuple[Path, str]] = [] + + if request.include_docs: + found += [(p, "docs") for p in (self._base_dir / "docs").rglob("*.md")] + + if request.include_prps: + found += [(p, "prp") for p in (self._base_dir / "PRPs").rglob("*.md")] + + if request.include_root: + for name in _PROJECT_ROOT_FILES: + candidate = self._base_dir / name + if candidate.is_file(): + found.append((candidate, "root")) + + # rglob order is filesystem-dependent — sort for stable, reproducible runs. + return sorted(found, key=lambda pair: str(pair[0])) + + async def index_project_docs( + self, + db: AsyncSession, + request: IndexProjectDocsRequest, + ) -> IndexProjectDocsResponse: + """Bulk-index discovered project docs via index_document. Idempotent. + + Each file is indexed through index_document, reusing its chunking, + embedding, SHA-256 content-hash idempotency, and upsert. A single + unreadable / non-UTF-8 file is reported status="failed" and does NOT + abort the batch. EmbeddingError / SQLAlchemyError are NOT caught here — + they are batch-fatal and propagate to the route's error handlers. + + Args: + db: Database session. + request: Toggles selecting which roots to index. + + Returns: + Per-file results plus aggregate counts. + """ + start_time = time.time() + + logger.info( + "rag.index_project_docs_started", + include_docs=request.include_docs, + include_prps=request.include_prps, + include_root=request.include_root, + ) + + results: list[ProjectDocResult] = [] + + for abs_path, category in self._discover_project_doc_files(request): + # abs_path was globbed under self._base_dir, so relative_to is safe. 
+ rel = abs_path.relative_to(self._base_dir).as_posix() + try: + content = abs_path.read_text(encoding="utf-8") + index_response = await self.index_document( + db, + IndexRequest( + source_type="markdown", + source_path=rel, + content=content, + metadata={"category": category}, + ), + ) + results.append( + ProjectDocResult( + source_path=rel, + status=index_response.status, + chunks_created=index_response.chunks_created, + error=None, + ) + ) + except (OSError, ValueError) as exc: + # FileNotFoundError ⊂ OSError; UnicodeDecodeError ⊂ ValueError. + logger.warning( + "rag.index_project_docs_file_failed", + source_path=rel, + error=str(exc), + error_type=type(exc).__name__, + ) + results.append( + ProjectDocResult( + source_path=rel, + status="failed", + chunks_created=0, + error=str(exc), + ) + ) + + duration_ms = (time.time() - start_time) * 1000 + + summary = IndexProjectDocsResponse( + results=results, + total_files=len(results), + indexed=sum(r.status == "indexed" for r in results), + updated=sum(r.status == "updated" for r in results), + unchanged=sum(r.status == "unchanged" for r in results), + failed=sum(r.status == "failed" for r in results), + total_chunks=sum(r.chunks_created for r in results), + duration_ms=duration_ms, + ) + + logger.info( + "rag.index_project_docs_completed", + total_files=summary.total_files, + indexed=summary.indexed, + updated=summary.updated, + unchanged=summary.unchanged, + failed=summary.failed, + total_chunks=summary.total_chunks, + duration_ms=duration_ms, + ) + + return summary + async def retrieve( self, db: AsyncSession, diff --git a/app/features/rag/tests/conftest.py b/app/features/rag/tests/conftest.py index 3bf7f318..30d5fe6a 100644 --- a/app/features/rag/tests/conftest.py +++ b/app/features/rag/tests/conftest.py @@ -41,9 +41,10 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]: try: yield session finally: - # Clean up test data (delete sources with test- prefix) + # Clean up test data (delete sources whose 
path contains a test- token, + # including nested project-doc fixture paths like docs/test-*.md) test_source_ids = delete(DocumentSource).where( - DocumentSource.source_path.like("test-%") + DocumentSource.source_path.like("%test-%") ) await session.execute(test_source_ids) await session.commit() diff --git a/app/features/rag/tests/test_routes.py b/app/features/rag/tests/test_routes.py index ce09a05a..f898a9f1 100644 --- a/app/features/rag/tests/test_routes.py +++ b/app/features/rag/tests/test_routes.py @@ -7,12 +7,14 @@ Note: These tests mock the OpenAI embedding service to avoid API calls. """ +from functools import partial from unittest.mock import AsyncMock, MagicMock, patch import pytest from httpx import AsyncClient -from app.features.rag.embeddings import EmbeddingService +from app.features.rag.embeddings import EmbeddingError, EmbeddingService +from app.features.rag.service import RAGService # ============================================================================= # Mock Embedding Service for Integration Tests @@ -431,3 +433,135 @@ async def test_index_openapi_creates_endpoint_chunks(self, client: AsyncClient): data = response.json() # Should have at least: info chunk + 2 endpoint chunks assert data["chunks_created"] >= 3 + + +# ============================================================================= +# Index Project Docs Endpoint Tests +# ============================================================================= + + +@pytest.mark.integration +class TestIndexProjectDocsEndpoint: + """Integration tests for POST /rag/index/project-docs endpoint.""" + + @pytest.mark.asyncio + async def test_indexes_discovered_docs(self, client: AsyncClient, tmp_path): + """Test that discovered docs are indexed and re-runs are idempotent.""" + (tmp_path / "docs").mkdir() + (tmp_path / "PRPs").mkdir() + # Non-empty content; `test-` token so conftest cleanup catches the rows. 
+ (tmp_path / "docs" / "test-proj-1.md").write_text( + "# Alpha\n\nAlpha content.", encoding="utf-8" + ) + (tmp_path / "PRPs" / "test-proj-2.md").write_text( + "# Beta\n\nBeta content.", encoding="utf-8" + ) + mock_service = create_mock_embedding_service() + + with ( + patch( + "app.features.rag.routes.RAGService", + partial(RAGService, base_dir=str(tmp_path)), + ), + patch( + "app.features.rag.service.get_embedding_service", + return_value=mock_service, + ), + ): + response1 = await client.post("/rag/index/project-docs", json={}) + assert response1.status_code == 200 + data1 = response1.json() + assert data1["total_files"] == 2 + assert data1["indexed"] == 2 + assert data1["failed"] == 0 + assert data1["total_chunks"] >= 2 + + # Idempotent re-run — every file unchanged, no new chunks. + response2 = await client.post("/rag/index/project-docs", json={}) + assert response2.status_code == 200 + assert response2.json()["unchanged"] == 2 + + @pytest.mark.asyncio + async def test_empty_roots_returns_zero(self, client: AsyncClient, tmp_path): + """Test that an empty doc tree returns zero files without error.""" + mock_service = create_mock_embedding_service() + + with ( + patch( + "app.features.rag.routes.RAGService", + partial(RAGService, base_dir=str(tmp_path)), + ), + patch( + "app.features.rag.service.get_embedding_service", + return_value=mock_service, + ), + ): + response = await client.post("/rag/index/project-docs", json={}) + + assert response.status_code == 200 + assert response.json()["total_files"] == 0 + + @pytest.mark.asyncio + async def test_toggles_select_roots(self, client: AsyncClient, tmp_path): + """Test that include_* toggles restrict discovery.""" + (tmp_path / "docs").mkdir() + (tmp_path / "PRPs").mkdir() + (tmp_path / "docs" / "test-toggle-1.md").write_text( + "# Docs\n\nDocs content.", encoding="utf-8" + ) + (tmp_path / "PRPs" / "test-toggle-2.md").write_text( + "# Prp\n\nPrp content.", encoding="utf-8" + ) + mock_service = 
create_mock_embedding_service() + + with ( + patch( + "app.features.rag.routes.RAGService", + partial(RAGService, base_dir=str(tmp_path)), + ), + patch( + "app.features.rag.service.get_embedding_service", + return_value=mock_service, + ), + ): + response = await client.post( + "/rag/index/project-docs", + json={"include_prps": False, "include_root": False}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["total_files"] == 1 + assert data["results"][0]["source_path"] == "docs/test-toggle-1.md" + + @pytest.mark.asyncio + async def test_unknown_field_rejected(self, client: AsyncClient): + """Test that an unknown body field is rejected (extra='forbid').""" + response = await client.post("/rag/index/project-docs", json={"bogus": True}) + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_embedding_failure_returns_502(self, client: AsyncClient, tmp_path): + """Test that an embedding-provider failure is batch-fatal (502).""" + (tmp_path / "docs").mkdir() + (tmp_path / "docs" / "test-proj-3.md").write_text( + "# Gamma\n\nGamma content.", encoding="utf-8" + ) + # Build a mock whose embed_texts raises — a MagicMock var (not the + # EmbeddingService-typed factory return) so mypy permits the assignment. 
+ mock_service = MagicMock(spec=EmbeddingService) + mock_service.embed_texts = AsyncMock(side_effect=EmbeddingError("no key")) + + with ( + patch( + "app.features.rag.routes.RAGService", + partial(RAGService, base_dir=str(tmp_path)), + ), + patch( + "app.features.rag.service.get_embedding_service", + return_value=mock_service, + ), + ): + response = await client.post("/rag/index/project-docs", json={}) + + assert response.status_code == 502 diff --git a/app/features/rag/tests/test_schemas.py b/app/features/rag/tests/test_schemas.py index 3a1881e7..479a95a2 100644 --- a/app/features/rag/tests/test_schemas.py +++ b/app/features/rag/tests/test_schemas.py @@ -6,8 +6,11 @@ from app.features.rag.schemas import ( ChunkResult, DeleteResponse, + IndexProjectDocsRequest, + IndexProjectDocsResponse, IndexRequest, IndexResponse, + ProjectDocResult, RetrieveRequest, RetrieveResponse, SourceListResponse, @@ -344,3 +347,115 @@ def test_valid_delete_response(self): ) assert response.status == "deleted" assert response.chunks_deleted == 10 + + +class TestIndexProjectDocsRequest: + """Tests for IndexProjectDocsRequest schema.""" + + def test_defaults_all_true(self): + """Test that an empty request defaults every root to True.""" + request = IndexProjectDocsRequest() + assert request.include_docs is True + assert request.include_prps is True + assert request.include_root is True + + def test_model_validate_empty_dict(self): + """Test that an empty {} body validates (the frontend always posts {}).""" + request = IndexProjectDocsRequest.model_validate({}) + assert request.include_docs is True + assert request.include_prps is True + assert request.include_root is True + + def test_toggles_select_roots_independently(self): + """Test that each include_* toggle is honored independently.""" + request = IndexProjectDocsRequest(include_docs=True, include_prps=False, include_root=False) + assert request.include_docs is True + assert request.include_prps is False + assert request.include_root 
is False + + def test_extra_fields_rejected(self): + """Test that an unknown body field is rejected (extra='forbid').""" + with pytest.raises(ValidationError) as exc_info: + IndexProjectDocsRequest(bogus=True) # type: ignore[call-arg] + assert "bogus" in str(exc_info.value) + + +class TestProjectDocResult: + """Tests for ProjectDocResult schema.""" + + def test_valid_result(self): + """Test a valid per-file result.""" + result = ProjectDocResult( + source_path="docs/ARCHITECTURE.md", + status="indexed", + chunks_created=7, + ) + assert result.status == "indexed" + assert result.chunks_created == 7 + assert result.error is None + + def test_failed_result_carries_error(self): + """Test a failed result carries an error string.""" + result = ProjectDocResult( + source_path="docs/bad.md", + status="failed", + chunks_created=0, + error="not valid UTF-8", + ) + assert result.status == "failed" + assert result.error == "not valid UTF-8" + + def test_invalid_status_rejected(self): + """Test that an out-of-Literal status is rejected.""" + with pytest.raises(ValidationError) as exc_info: + ProjectDocResult( + source_path="docs/x.md", + status="bogus", # type: ignore[arg-type] + chunks_created=0, + ) + assert "status" in str(exc_info.value) + + +class TestIndexProjectDocsResponse: + """Tests for IndexProjectDocsResponse schema.""" + + def test_valid_response_round_trips(self): + """Test a populated aggregate response round-trips through validation.""" + response = IndexProjectDocsResponse( + results=[ + ProjectDocResult(source_path="docs/a.md", status="indexed", chunks_created=3), + ProjectDocResult( + source_path="PRPs/b.md", + status="failed", + chunks_created=0, + error="boom", + ), + ], + total_files=2, + indexed=1, + updated=0, + unchanged=0, + failed=1, + total_chunks=3, + duration_ms=42.5, + ) + assert response.total_files == 2 + assert response.indexed == 1 + assert response.failed == 1 + assert response.total_chunks == 3 + assert len(response.results) == 2 + + def 
test_empty_response(self): + """Test an aggregate response with no discovered files.""" + response = IndexProjectDocsResponse( + results=[], + total_files=0, + indexed=0, + updated=0, + unchanged=0, + failed=0, + total_chunks=0, + duration_ms=1.0, + ) + assert response.total_files == 0 + assert len(response.results) == 0 diff --git a/app/features/rag/tests/test_service.py b/app/features/rag/tests/test_service.py index 52a7afc2..836bc84b 100644 --- a/app/features/rag/tests/test_service.py +++ b/app/features/rag/tests/test_service.py @@ -1,11 +1,16 @@ """Unit tests for RAG service.""" import hashlib +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest -from app.features.rag.schemas import IndexRequest, RetrieveRequest +from app.features.rag.schemas import ( + IndexProjectDocsRequest, + IndexRequest, + RetrieveRequest, +) from app.features.rag.service import RAGService, SourceNotFoundError @@ -70,6 +75,88 @@ def test_read_content_from_path_traversal_blocked(self, tmp_path): service._read_content_from_path("/etc/passwd") +class TestRAGServiceDiscoverProjectDocFiles: + """Unit tests for RAGService._discover_project_doc_files (pure, no DB).""" + + @staticmethod + def _build_tree(tmp_path: Path) -> None: + """Create a fixture doc tree under tmp_path.""" + (tmp_path / "docs" / "sub").mkdir(parents=True) + (tmp_path / "PRPs").mkdir() + (tmp_path / "docs" / "test-a.md").write_text("# A", encoding="utf-8") + (tmp_path / "docs" / "sub" / "test-b.md").write_text("# B", encoding="utf-8") + (tmp_path / "docs" / "notes.txt").write_text("not markdown", encoding="utf-8") + (tmp_path / "PRPs" / "test-c.md").write_text("# C", encoding="utf-8") + (tmp_path / "README.md").write_text("# Readme", encoding="utf-8") + + def test_discovers_all_roots(self, tmp_path): + """Test discovery across docs/, PRPs/, and the root allow-list.""" + self._build_tree(tmp_path) + service = RAGService(base_dir=str(tmp_path)) + + found = 
service._discover_project_doc_files(IndexProjectDocsRequest()) + + rel = {p.relative_to(tmp_path).as_posix(): cat for p, cat in found} + assert rel == { + "docs/test-a.md": "docs", + "docs/sub/test-b.md": "docs", + "PRPs/test-c.md": "prp", + "README.md": "root", + } + + def test_filters_non_markdown(self, tmp_path): + """Test that non-.md files (notes.txt) are excluded.""" + self._build_tree(tmp_path) + service = RAGService(base_dir=str(tmp_path)) + + found = service._discover_project_doc_files(IndexProjectDocsRequest()) + + assert all(p.suffix == ".md" for p, _ in found) + + def test_result_is_sorted(self, tmp_path): + """Test that discovery returns a deterministically sorted list.""" + self._build_tree(tmp_path) + service = RAGService(base_dir=str(tmp_path)) + + found = service._discover_project_doc_files(IndexProjectDocsRequest()) + + paths = [str(p) for p, _ in found] + assert paths == sorted(paths) + + def test_toggles_select_roots(self, tmp_path): + """Test that include_* toggles select roots independently.""" + self._build_tree(tmp_path) + service = RAGService(base_dir=str(tmp_path)) + + docs_only = service._discover_project_doc_files( + IndexProjectDocsRequest(include_prps=False, include_root=False) + ) + assert {cat for _, cat in docs_only} == {"docs"} + + no_root = service._discover_project_doc_files(IndexProjectDocsRequest(include_root=False)) + assert "root" not in {cat for _, cat in no_root} + + def test_missing_root_directory_yields_nothing(self, tmp_path): + """Test that an absent docs/ or PRPs/ root contributes 0 files.""" + # tmp_path is empty — no docs/, no PRPs/, no root markdown. 
+ service = RAGService(base_dir=str(tmp_path)) + + found = service._discover_project_doc_files(IndexProjectDocsRequest()) + + assert found == [] + + def test_root_allow_list_only(self, tmp_path): + """Test that only allow-listed root files are discovered.""" + (tmp_path / "README.md").write_text("# Readme", encoding="utf-8") + (tmp_path / "NOTES.md").write_text("# Notes", encoding="utf-8") + service = RAGService(base_dir=str(tmp_path)) + + found = service._discover_project_doc_files(IndexProjectDocsRequest()) + + names = {p.name for p, _ in found} + assert names == {"README.md"} + + class TestRAGServiceIndexDocument: """Tests for index_document method.""" diff --git a/app/features/registry/service.py b/app/features/registry/service.py index ceb0b4bf..076910c1 100644 --- a/app/features/registry/service.py +++ b/app/features/registry/service.py @@ -322,9 +322,11 @@ async def list_runs( else: order_by = ModelRun.created_at.desc() - # Apply pagination + # Apply pagination. Append the unique `run_id` as a tie-breaker so rows + # with equal sort values keep a stable order across pages (offset + # pagination over a non-unique sort key is otherwise non-deterministic). offset = (page - 1) * page_size - stmt = stmt.order_by(order_by).offset(offset).limit(page_size) + stmt = stmt.order_by(order_by, ModelRun.run_id.asc()).offset(offset).limit(page_size) result = await db.execute(stmt) runs = result.scalars().all() diff --git a/app/features/scenarios/__init__.py b/app/features/scenarios/__init__.py new file mode 100644 index 00000000..20f9d538 --- /dev/null +++ b/app/features/scenarios/__init__.py @@ -0,0 +1,31 @@ +"""Scenario Simulation / What-If Planning slice. + +A vertical slice that turns a baseline forecast into a *plan*: it loads an +already-trained baseline model, runs its forecast, applies deterministic, +transparent uplift / drag factors for future assumptions (price change, +promotion, holiday, inventory, lifecycle), and returns a baseline-vs-scenario +comparison. 
Comparisons can be persisted as named ``scenario_plan`` rows. + +DECISIONS LOCKED (PRP-26): the baseline forecasters ignore exogenous +regressors, so a "what-if" is applied as a deterministic post-forecast +multiplier — never a leakage-prone re-training. Every result is explicitly +labelled ``method = "heuristic"`` with a fixed disclaimer. +""" + +from app.features.scenarios.models import ScenarioPlan +from app.features.scenarios.routes import router +from app.features.scenarios.schemas import ( + ScenarioComparison, + ScenarioListResponse, + ScenarioPlanResponse, +) +from app.features.scenarios.service import ScenarioService + +__all__ = [ + "ScenarioComparison", + "ScenarioListResponse", + "ScenarioPlan", + "ScenarioPlanResponse", + "ScenarioService", + "router", +] diff --git a/app/features/scenarios/adjustments.py b/app/features/scenarios/adjustments.py new file mode 100644 index 00000000..a4152e3f --- /dev/null +++ b/app/features/scenarios/adjustments.py @@ -0,0 +1,169 @@ +"""Pure deterministic adjustment engine for scenario simulation. + +Every function here is a pure factor computation — no DB, no I/O, no mutation +of its inputs, and it NEVER raises on junk input (a negative price change, an +unknown promotion kind, a ``None`` lifecycle stage all return a sane factor). + +DECISIONS LOCKED (PRP-26 #1): the baseline forecasters ignore exogenous +regressors, so a what-if cannot be answered by re-prediction. The MVP applies +these factors as a post-forecast multiplier on a baseline forecast. Each factor +is a documented, tunable constant so a reviewer can see and adjust the +heuristic; the tests assert direction and bounds, not exact magnitudes. +""" + +from __future__ import annotations + +from datetime import date +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + from app.features.scenarios.schemas import ScenarioAssumptions + +# Constant-elasticity price response: factor = (1 + change_pct) ** PRICE_ELASTICITY. 
+# A negative elasticity means a price cut (change_pct < 0) lifts demand. +PRICE_ELASTICITY: float = -1.2 + +# Multiplicative demand uplift per promotion kind (1.0 == no effect). +PROMOTION_UPLIFT_BY_KIND: dict[str, float] = { + "pct_off": 1.25, + "bogo": 1.40, + "bundle": 1.15, + "markdown": 1.30, +} + +# Demand uplift applied on a holiday / event day. +HOLIDAY_UPLIFT: float = 1.30 + +# Demand multiplier per forced product lifecycle stage. +LIFECYCLE_FACTOR: dict[str, float] = { + "launch": 1.2, + "growth": 1.1, + "maturity": 1.0, + "decline": 0.85, +} + +# Clamp band — keeps a combined factor away from a zero / explosive forecast. +FACTOR_BAND: tuple[float, float] = (0.1, 5.0) + +# Relative band around on-hand stock within which coverage is "at_risk". +COVERAGE_AT_RISK_BAND: float = 0.10 + +CoverageVerdict = Literal["covered", "at_risk", "stockout", "unknown"] + + +def clamp(value: float, lo: float, hi: float) -> float: + """Clamp ``value`` into the inclusive ``[lo, hi]`` range.""" + return max(lo, min(hi, value)) + + +def price_factor(price_change_pct: float) -> float: + """Return the demand multiplier for a relative price change. + + Constant-elasticity response: ``(1 + change) ** PRICE_ELASTICITY``. A price + cut (negative change) yields a factor > 1; a price rise yields < 1. + Tolerates junk — a change of -100% or worse (a non-positive price) clamps to + the upper band rather than raising or returning a complex / NaN value. + """ + base = 1.0 + price_change_pct + if base <= 0.0: + return FACTOR_BAND[1] + return clamp(base**PRICE_ELASTICITY, *FACTOR_BAND) + + +def promotion_factor(kind: str, active: bool) -> float: + """Return the demand multiplier for a promotion of ``kind``. + + Returns ``1.0`` when the promotion is not active or the kind is unknown. 
+ """ + if not active: + return 1.0 + return PROMOTION_UPLIFT_BY_KIND.get(kind, 1.0) + + +def holiday_factor(is_holiday: bool) -> float: + """Return the demand multiplier for a holiday / event day.""" + return HOLIDAY_UPLIFT if is_holiday else 1.0 + + +def lifecycle_factor(stage: str | None) -> float: + """Return the demand multiplier for a product lifecycle stage. + + Returns ``1.0`` for ``None`` or an unknown stage. + """ + if stage is None: + return 1.0 + return LIFECYCLE_FACTOR.get(stage, 1.0) + + +def _in_window(point_date: date, start: date, end: date) -> bool: + """True when ``point_date`` is inside the inclusive ``[start, end]`` window. + + A reversed window (``start`` after ``end``) is normalised rather than + treated as empty — junk input must not raise. + """ + lo, hi = (start, end) if start <= end else (end, start) + return lo <= point_date <= hi + + +def combined_daily_factor(point_date: date, assumptions: ScenarioAssumptions) -> float: + """Multiply every applicable per-day factor for ``point_date``, then clamp. + + Time-safety: every window test is keyed on ``point_date`` — always a horizon + (future) date — so an assumption window that falls entirely before the + forecast start contributes factor ``1.0`` and can never reach back into the + historical series. An empty ``ScenarioAssumptions`` yields exactly ``1.0``. 
+ """ + factor = 1.0 + + price = assumptions.price + if price is not None and _in_window(point_date, price.start_date, price.end_date): + factor *= price_factor(price.change_pct) + + promotion = assumptions.promotion + if promotion is not None and _in_window(point_date, promotion.start_date, promotion.end_date): + factor *= promotion_factor(promotion.kind, active=True) + + holiday = assumptions.holiday + if holiday is not None and point_date in holiday.dates: + factor *= holiday_factor(True) + + lifecycle = assumptions.lifecycle + if lifecycle is not None: + factor *= lifecycle_factor(lifecycle.stage) + + return clamp(factor, *FACTOR_BAND) + + +def apply_adjustment(baseline: list[float], factors: list[float]) -> list[float]: + """Element-wise multiply ``baseline`` by ``factors``, flooring each at 0.0. + + Returns a NEW list — the input ``baseline`` is never mutated (the leakage + spec depends on this). Raises ``ValueError`` on a length mismatch: that is a + caller-contract violation, not junk data. + """ + if len(baseline) != len(factors): + raise ValueError( + f"baseline and factors must be equal length: {len(baseline)} != {len(factors)}" + ) + return [max(0.0, value * factor) for value, factor in zip(baseline, factors, strict=True)] + + +def coverage_verdict(scenario_total_units: float, on_hand_units: int | None) -> CoverageVerdict: + """Classify whether projected demand is covered by on-hand stock. + + Returns ``unknown`` when no inventory assumption was supplied. Otherwise: + ``covered`` when demand sits comfortably below stock, ``at_risk`` when it is + within ``COVERAGE_AT_RISK_BAND`` of stock, ``stockout`` when it exceeds that + band. Never raises. 
+ """ + if on_hand_units is None: + return "unknown" + if on_hand_units <= 0: + return "stockout" if scenario_total_units > 0.0 else "at_risk" + upper = on_hand_units * (1.0 + COVERAGE_AT_RISK_BAND) + lower = on_hand_units * (1.0 - COVERAGE_AT_RISK_BAND) + if scenario_total_units > upper: + return "stockout" + if scenario_total_units >= lower: + return "at_risk" + return "covered" diff --git a/app/features/scenarios/agent_tools.py b/app/features/scenarios/agent_tools.py new file mode 100644 index 00000000..ab99c17c --- /dev/null +++ b/app/features/scenarios/agent_tools.py @@ -0,0 +1,195 @@ +"""Agent-facing tools for the Scenario Simulation slice (PRP-27 Phase D). + +This module is the *integration seam* between the agent layer and the +scenarios slice. ``app/features/agents/`` imports THIS module — never +``scenarios/service.py`` directly — so the no-cross-slice-``service.py``-import +rule (DECISIONS LOCKED #3) holds while the agent still gains scenario tools. + +Two tools live here: + +* :func:`propose_scenario` — **read-only**. Returns a candidate + ``ScenarioAssumptions`` plus a plain-language recommendation. It proposes, + it never persists, so it needs no approval. +* :func:`save_scenario` — **mutating**. Persists a ``scenario_plan`` row via the + scenarios service create path, stamped ``source='agent'`` with the originating + ``agent_session_id`` and the HITL approval audit trail. It runs only after the + human-in-the-loop gate releases it — its tool name is in + ``agent_require_approval`` (DECISIONS LOCKED #13). 
+""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from typing import Any + +import structlog +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.features.data_platform.models import SalesDaily +from app.features.scenarios.models import SCENARIO_SOURCE_AGENT +from app.features.scenarios.schemas import ( + CreateScenarioRequest, + PriceAssumption, + PromotionAssumption, + SaveScenarioRequest, + ScenarioAssumptions, +) +from app.features.scenarios.service import ScenarioService + +logger = structlog.get_logger() + +# Keywords in a free-text objective that steer the proposal toward a promotion +# rather than the default price cut. +_PROMOTION_KEYWORDS = ("promo", "promotion", "discount", "sale", "markdown") + +# The default magnitude of the proposed price cut (a 15% reduction). +_PROPOSED_PRICE_CHANGE_PCT = -0.15 + +# Recorded as ``approved_by`` on an agent-saved plan. The system is single-host +# and unauthenticated, so the approving party is simply the local operator who +# released the HITL gate. +AGENT_SAVE_APPROVED_BY = "operator" + + +async def propose_scenario( + db: AsyncSession, + store_id: int, + product_id: int, + horizon: int, + objective: str, +) -> dict[str, Any]: + """Propose a candidate what-if scenario for a (store, product) grain. + + READ-ONLY: this tool builds a candidate ``ScenarioAssumptions`` and a + recommendation; it performs no database writes. Persisting the proposal is + a separate, approval-gated step (:func:`save_scenario`). + + Args: + db: Database session (used only to read a recent unit price for a + grounded recommendation). + store_id: Store the proposed scenario targets. + product_id: Product the proposed scenario targets. + horizon: Number of days the proposed scenario should span. + objective: Free-text planning objective — keywords such as "promotion" + steer the proposal toward a promotion instead of a price cut. 
+ + Returns: + A dict with the target grain, the horizon, the originating objective, + the candidate ``assumptions`` (JSON-mode dump so dates are ISO strings, + ready to pass straight back into ``save_scenario``), and a + plain-language ``recommendation``. + """ + logger.info( + "agents.scenario_tool.propose_scenario_called", + store_id=store_id, + product_id=product_id, + horizon=horizon, + ) + + # Read the most recent unit price for a grounded recommendation. Read-only. + latest_price = await db.scalar( + select(SalesDaily.unit_price) + .where(SalesDaily.store_id == store_id, SalesDaily.product_id == product_id) + .order_by(SalesDaily.date.desc()) + .limit(1) + ) + + start = datetime.now(UTC).date() + timedelta(days=1) + end = start + timedelta(days=horizon - 1) + + if any(keyword in objective.lower() for keyword in _PROMOTION_KEYWORDS): + assumptions = ScenarioAssumptions( + promotion=PromotionAssumption(kind="pct_off", start_date=start, end_date=end) + ) + rationale = ( + f"Run a pct_off promotion from {start} to {end} ({horizon} days) and " + "simulate the demand lift before committing." + ) + else: + assumptions = ScenarioAssumptions( + price=PriceAssumption( + change_pct=_PROPOSED_PRICE_CHANGE_PCT, start_date=start, end_date=end + ) + ) + price_note = ( + f" The most recent unit price is ~{float(latest_price):.2f}." + if latest_price is not None + else "" + ) + rationale = ( + f"Cut price {abs(_PROPOSED_PRICE_CHANGE_PCT) * 100:.0f}% from {start} to " + f"{end} ({horizon} days) to test the demand response.{price_note}" + ) + + recommendation = ( + f"Proposed what-if for store {store_id} / product {product_id} toward the " + f"objective '{objective}'. {rationale} This is a candidate only — review it " + "and save it explicitly to persist a scenario plan." 
+ ) + return { + "store_id": store_id, + "product_id": product_id, + "horizon": horizon, + "objective": objective, + "assumptions": assumptions.model_dump(mode="json"), + "recommendation": recommendation, + } + + +async def save_scenario( + db: AsyncSession, + request: SaveScenarioRequest, + *, + agent_session_id: str | None, +) -> dict[str, Any]: + """Persist an agent-proposed scenario as a saved ``scenario_plan`` row. + + MUTATING: this tool writes a row. It runs only once the HITL approval gate + has released it (``save_scenario`` is in ``agent_require_approval``), so the + persisted plan always carries an ``approved`` audit trail. + + Args: + db: Database session. + request: The validated scenario to persist (name, run_id, horizon, + assumptions). + agent_session_id: The originating agent session id — the runtime truth, + authoritative over any value carried on ``request``. + + Returns: + The saved plan as a JSON-mode dict, including its embedded comparison + snapshot and the provenance / audit fields. + + Raises: + FileNotFoundError: When no model artifact exists for ``request.run_id``. + ValueError: When the artifact path or its metadata is invalid. 
+ """ + logger.info( + "agents.scenario_tool.save_scenario_called", + store_id=request.store_id, + product_id=request.product_id, + agent_session_id=agent_session_id, + ) + + service = ScenarioService() + create_request = CreateScenarioRequest( + name=request.name, + run_id=request.run_id, + horizon=request.horizon, + assumptions=request.assumptions, + ) + plan = await service.create_plan( + db, + create_request, + source=SCENARIO_SOURCE_AGENT, + agent_session_id=agent_session_id, + approved_by=AGENT_SAVE_APPROVED_BY, + approval_decision="approved", + ) + + logger.info( + "agents.scenario_tool.save_scenario_completed", + scenario_id=plan.scenario_id, + agent_session_id=agent_session_id, + ) + return plan.model_dump(mode="json") diff --git a/app/features/scenarios/feature_frame.py b/app/features/scenarios/feature_frame.py new file mode 100644 index 00000000..0c9c2635 --- /dev/null +++ b/app/features/scenarios/feature_frame.py @@ -0,0 +1,407 @@ +"""Leakage-safe future feature-frame generator (PRP-27 Phase A). + +The scenario MVP (PRP-26) never builds a future feature matrix — it multiplies +a baseline forecast by a deterministic factor, so it is *immune* to leakage. +The Full Version introduces a model-driven path (``method="model_exogenous"``) +that re-forecasts demand through a feature-consuming regressor, and that needs +a **future feature frame**: the same feature columns the model was trained on, +produced for each horizon day ``T+1 … T+horizon``. + +That is a new and dangerous surface — a horizon day ``D`` has *no observed +target* — so this module is governed by one rule: + + A future feature value for day ``D`` may only use information knowable at + the forecast origin ``T`` (the last training day): the observed history + up to and including ``T``, the calendar (a pure function of the date), or + the scenario assumptions (the planner's *posited* future inputs). + It may NEVER read an observed target at a horizon day. 
+ +``app/features/scenarios/tests/test_future_frame_leakage.py`` is the +load-bearing spec for that rule — it must never be weakened (AGENTS.md +§ Safety), mirroring ``app/features/featuresets/tests/test_leakage.py``. + +DECISIONS LOCKED (PRP-27): +* #3 — no cross-slice ``service.py`` import. This module imports only the + ``data_platform`` ORM (a sanctioned read-only ORM import) and same-slice + schema value-objects; it replicates the small slice of leakage-safe + lag/calendar logic it needs rather than importing + ``FeatureEngineeringService``. +* #4 — long-lag + calendar + assumption-driven columns ONLY; no recursion. + A target lag value for horizon day ``T+j`` is the observed ``y[T+j-k]``; + when ``T+j-k > T`` (a future target) the cell is ``NaN`` — the model + (``HistGradientBoostingRegressor``) handles ``NaN`` natively. No recursion + ever fills those gaps in v1. +* #10/#11/#12 — the PINNED constants ``EXOGENOUS_LAGS``, + ``HISTORY_TAIL_DAYS`` and ``MAX_COMPARE_SCENARIOS`` live here. + +Feature-column contract: ``canonical_feature_columns()`` is the single source +of truth for the regression feature set and column order. The Phase B training +path persists exactly this list in the bundle metadata, and the future frame +reproduces it column-for-column, so a model trained today re-forecasts cleanly. 
+""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from datetime import date, timedelta +from typing import TYPE_CHECKING + +from sqlalchemy import select + +from app.core.logging import get_logger +from app.features.data_platform.models import Calendar + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + + from app.features.scenarios.schemas import ScenarioAssumptions + +logger = get_logger(__name__) + +# ── PINNED modelling constants (PRP-27 DECISIONS LOCKED #10/#11/#12) ── +# Lag offsets (days) for the target long-lag columns: daily, weekly, +# fortnightly, and a four-week lag covering the dominant retail seasonality. +EXOGENOUS_LAGS: tuple[int, ...] = (1, 7, 14, 28) +# Observed-target tail (days, ending at the forecast origin T) fed to the +# generator — 90 comfortably exceeds the largest lag offset (28). +HISTORY_TAIL_DAYS: int = 90 +# Upper bound on the multi-scenario comparison (Phase C) so the chart stays +# legible; defined here as the slice's single modelling-constants home. +MAX_COMPARE_SCENARIOS: int = 5 + +# Fixed calendar columns — each a pure function of the date, never a leak. +CALENDAR_COLUMNS: tuple[str, ...] = ( + "dow_sin", + "dow_cos", + "month_sin", + "month_cos", + "is_weekend", + "is_month_end", +) +# Fixed current-day exogenous columns — driven by the scenario assumptions +# (the planner's posited future inputs) and by timeless attributes (the +# calendar, the product launch date). Every value is knowable at origin T. +EXOGENOUS_COLUMNS: tuple[str, ...] = ( + "price_factor", + "promo_active", + "is_holiday", + "days_since_launch", +) + + +@dataclass +class FutureFeatureFrame: + """A horizon-length feature matrix for one ``(store, product)`` series. + + Attributes: + dates: The horizon days ``T+1 … T+horizon`` (chronological). + feature_columns: Column order — matches the trained bundle exactly. 
+ matrix: Row-major ``[horizon][n_features]``; ``NaN`` is allowed and + expected (a long-lag cell whose source target lies in the future, + or ``days_since_launch`` when the product has no launch date). + """ + + dates: list[date] + feature_columns: list[str] + matrix: list[list[float]] + + +def canonical_feature_columns(lags: tuple[int, ...] = EXOGENOUS_LAGS) -> list[str]: + """Return the fixed, ordered regression feature-column list. + + This is the single source of truth for the regression feature set. The + Phase B training path persists exactly this list in the model bundle's + metadata; the future frame reproduces it column-for-column. The column + set is deliberately *fixed* (not horizon-dependent): for a long horizon + some target-lag columns are mostly ``NaN``, which the NaN-tolerant + estimator handles — far safer than a horizon-varying column set. + + Args: + lags: Target long-lag offsets (defaults to the pinned ``EXOGENOUS_LAGS``). + + Returns: + Ordered column names: target lags, then calendar, then exogenous. + """ + target_lags = [f"lag_{k}" for k in lags] + return [*target_lags, *CALENDAR_COLUMNS, *EXOGENOUS_COLUMNS] + + +def _in_window(point_date: date, start: date, end: date) -> bool: + """True when ``point_date`` is inside the inclusive ``[start, end]`` window. + + A reversed window (``start`` after ``end``) is normalised rather than + treated as empty — junk input must never raise (mirrors + ``adjustments._in_window``). + """ + lo, hi = (start, end) if start <= end else (end, start) + return lo <= point_date <= hi + + +def _is_month_end(point_date: date) -> bool: + """True when ``point_date`` is the last day of its month.""" + return (point_date + timedelta(days=1)).month != point_date.month + + +def build_calendar_columns(dates: list[date]) -> dict[str, list[float]]: + """Build the calendar feature columns — a pure function of each date. + + Calendar features carry zero leakage risk: they read only the date + itself, never the target series. 
Day-of-week and month use cyclical + (sin/cos) encoding so the estimator sees their periodic structure. + + Args: + dates: The horizon days. + + Returns: + A mapping of every name in :data:`CALENDAR_COLUMNS` to its per-day + values. + """ + columns: dict[str, list[float]] = {name: [] for name in CALENDAR_COLUMNS} + for point_date in dates: + dow = point_date.weekday() # 0 = Monday … 6 = Sunday + month = point_date.month + columns["dow_sin"].append(math.sin(2.0 * math.pi * dow / 7.0)) + columns["dow_cos"].append(math.cos(2.0 * math.pi * dow / 7.0)) + columns["month_sin"].append(math.sin(2.0 * math.pi * month / 12.0)) + columns["month_cos"].append(math.cos(2.0 * math.pi * month / 12.0)) + columns["is_weekend"].append(1.0 if dow >= 5 else 0.0) + columns["is_month_end"].append(1.0 if _is_month_end(point_date) else 0.0) + return columns + + +def build_long_lag_columns( + history_tail: list[float], + horizon: int, + lags: tuple[int, ...] = EXOGENOUS_LAGS, +) -> dict[str, list[float]]: + """Build the target long-lag columns — the leakage-critical helper. + + ``history_tail`` is the observed target series ending at the forecast + origin ``T``: ``history_tail[-1] == y[T]``, ``history_tail[-2] == y[T-1]``, + and so on. The lag-``k`` column at horizon day ``T+j`` (``j`` in + ``1 … horizon``) is the observed target ``y[T+j-k]``. + + SAFETY (PRP-27 DECISIONS LOCKED #4): the source index into + ``history_tail`` is ``idx = (j - 1) - k``. The cell is populated **only + when ``idx < 0``** — i.e. the source day ``T+j-k`` lies at or before the + origin ``T`` and therefore inside ``history_tail``. When ``idx >= 0`` the + source day is a *future* horizon day with no observed target, so the cell + is ``NaN`` — never a recursive prediction, never a fabricated value. This + function structurally **cannot** read a future target: its only data + input is ``history_tail`` (entirely ``<= T``). + + Args: + history_tail: Observed target values ending at the origin ``T``. 
+ horizon: Number of horizon days. + lags: Lag offsets (defaults to the pinned ``EXOGENOUS_LAGS``). + + Returns: + A mapping ``"lag_{k}" -> [horizon values]``; out-of-range cells are + ``NaN``. + """ + tail_len = len(history_tail) + columns: dict[str, list[float]] = {} + for lag in lags: + column: list[float] = [] + for j in range(1, horizon + 1): + # Negative index from the end of history_tail. idx < 0 means the + # source day T+j-k is at/before the origin T — safe to read. + idx = (j - 1) - lag + if idx < 0 and -tail_len <= idx: + column.append(float(history_tail[idx])) + else: + column.append(math.nan) + columns[f"lag_{lag}"] = column + return columns + + +def build_exogenous_columns( + dates: list[date], + assumptions: ScenarioAssumptions, + holiday_dates: set[date], + launch_date: date | None, +) -> dict[str, list[float]]: + """Build the current-day exogenous columns from the scenario assumptions. + + These columns are the *intended* what-if input — the planner is positing a + future price / promotion / holiday — so reading them is not leakage. Each + is knowable at origin ``T``: + + * ``price_factor`` — ``1.0`` (the typical price) outside any price window, + ``1.0 + change_pct`` inside it. + * ``promo_active`` — ``1.0`` when a promotion assumption covers the day. + * ``is_holiday`` — ``1.0`` when the day is in the holiday assumption OR a + ``calendar`` holiday (a calendar row is a timeless attribute). + * ``days_since_launch`` — ``(date - launch_date).days``, a pure function of + the date; ``NaN`` when the product has no launch date. + + Args: + dates: The horizon days. + assumptions: The scenario assumptions. + holiday_dates: Calendar holiday dates inside the horizon. + launch_date: The product's launch date, or ``None``. + + Returns: + A mapping of every name in :data:`EXOGENOUS_COLUMNS` to its per-day + values. 
+ """ + price = assumptions.price + promotion = assumptions.promotion + holiday = assumptions.holiday + assumption_holidays: set[date] = set(holiday.dates) if holiday is not None else set() + + price_factor: list[float] = [] + promo_active: list[float] = [] + is_holiday: list[float] = [] + days_since_launch: list[float] = [] + + for point_date in dates: + if price is not None and _in_window(point_date, price.start_date, price.end_date): + price_factor.append(1.0 + price.change_pct) + else: + price_factor.append(1.0) + + if promotion is not None and _in_window( + point_date, promotion.start_date, promotion.end_date + ): + promo_active.append(1.0) + else: + promo_active.append(0.0) + + is_holiday.append( + 1.0 if point_date in assumption_holidays or point_date in holiday_dates else 0.0 + ) + + if launch_date is not None: + days_since_launch.append(float((point_date - launch_date).days)) + else: + days_since_launch.append(math.nan) + + return { + "price_factor": price_factor, + "promo_active": promo_active, + "is_holiday": is_holiday, + "days_since_launch": days_since_launch, + } + + +def assemble_future_frame( + *, + dates: list[date], + feature_columns: list[str], + history_tail: list[float], + assumptions: ScenarioAssumptions, + holiday_dates: set[date], + launch_date: date | None, +) -> FutureFeatureFrame: + """Assemble a :class:`FutureFeatureFrame` from already-resolved inputs. + + Pure (no DB, no I/O) so it is fully unit-testable; :func:`build_future_frame` + is the thin async wrapper that resolves ``holiday_dates`` from the + ``calendar`` table first. + + Any requested column not produced by the builders is filled with ``NaN`` + so the matrix always matches ``feature_columns`` in width and order. + + Args: + dates: The horizon days ``T+1 … T+horizon``. + feature_columns: The exact column order to emit. + history_tail: Observed target values ending at the origin ``T``. + assumptions: The scenario assumptions. 
+ holiday_dates: Calendar holiday dates inside the horizon. + launch_date: The product's launch date, or ``None``. + + Returns: + The assembled future feature frame. + """ + horizon = len(dates) + column_data: dict[str, list[float]] = {} + column_data.update(build_long_lag_columns(history_tail, horizon)) + column_data.update(build_calendar_columns(dates)) + column_data.update(build_exogenous_columns(dates, assumptions, holiday_dates, launch_date)) + + # Defensive: any column the trained bundle expects but this generator does + # not produce becomes an all-NaN column (the estimator tolerates NaN). + for column in feature_columns: + if column not in column_data: + column_data[column] = [math.nan] * horizon + + matrix: list[list[float]] = [ + [column_data[column][j] for column in feature_columns] for j in range(horizon) + ] + return FutureFeatureFrame( + dates=list(dates), + feature_columns=list(feature_columns), + matrix=matrix, + ) + + +async def build_future_frame( + db: AsyncSession, + *, + store_id: int, + product_id: int, + forecast_origin: date, + horizon: int, + feature_columns: list[str], + history_tail: list[float], + assumptions: ScenarioAssumptions, + launch_date: date | None = None, +) -> FutureFeatureFrame: + """Build the future feature frame for one ``(store, product)`` series. + + The only database read is the ``calendar`` holiday lookup for the horizon + window — a ``calendar`` row is a timeless attribute, so reading it is not + leakage. Everything else is derived from ``history_tail`` (observed, + ``<= T``), the dates, or the assumptions. + + Args: + db: Async database session (used only for the calendar lookup). + store_id: Store the baseline model targets (logged). + product_id: Product the baseline model targets (logged). + forecast_origin: The origin ``T`` — the last training day. The horizon + runs ``T+1 … T+horizon``. + horizon: Number of horizon days (``>= 1``). + feature_columns: The trained bundle's feature-column order. 
+ history_tail: Observed target values ending at ``T``. + assumptions: The scenario assumptions. + launch_date: The product's launch date, or ``None``. + + Returns: + The assembled future feature frame. + + Raises: + ValueError: When ``horizon`` is below 1. + """ + if horizon < 1: + raise ValueError(f"horizon must be >= 1, got {horizon}") + + dates = [forecast_origin + timedelta(days=offset) for offset in range(1, horizon + 1)] + + result = await db.execute( + select(Calendar.date).where( + Calendar.date >= dates[0], + Calendar.date <= dates[-1], + Calendar.is_holiday.is_(True), + ) + ) + holiday_dates: set[date] = set(result.scalars().all()) + + frame = assemble_future_frame( + dates=dates, + feature_columns=feature_columns, + history_tail=history_tail, + assumptions=assumptions, + holiday_dates=holiday_dates, + launch_date=launch_date, + ) + logger.info( + "scenarios.future_frame_built", + store_id=store_id, + product_id=product_id, + horizon=horizon, + n_features=len(feature_columns), + n_calendar_holidays=len(holiday_dates), + ) + return frame diff --git a/app/features/scenarios/models.py b/app/features/scenarios/models.py new file mode 100644 index 00000000..39f3b67e --- /dev/null +++ b/app/features/scenarios/models.py @@ -0,0 +1,116 @@ +"""Scenario plan ORM model. + +A ``scenario_plan`` row persists a saved what-if analysis: the raw +``ScenarioAssumptions`` *and* the full ``ScenarioComparison`` snapshot, both as +JSONB. Storing the snapshot (PRP-26 decision #3) means a reloaded plan +re-renders without recomputation — and without the original model artifact +still having to exist on disk. + +GOTCHA: SQLAlchemy reserves the declarative attribute name ``metadata``; the +JSONB columns are therefore named ``assumptions`` and ``comparison``. 
+""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from sqlalchemy import CheckConstraint, DateTime, Index, Integer, String, text +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column + +from app.core.database import Base +from app.shared.models import TimestampMixin + +# Adjustment methods — guarded by a CHECK constraint. ``heuristic`` is the MVP +# post-forecast multiplier; ``model_exogenous`` (PRP-27) re-forecasts through a +# feature-consuming regression model. +SCENARIO_METHOD_HEURISTIC = "heuristic" +SCENARIO_METHOD_MODEL_EXOGENOUS = "model_exogenous" + +# Provenance — who or what created a scenario plan (PRP-27 Phase D). +SCENARIO_SOURCE_USER = "user" +SCENARIO_SOURCE_AGENT = "agent" + + +class ScenarioPlan(TimestampMixin, Base): + """A saved scenario plan. + + Attributes: + id: Surrogate primary key. + scenario_id: Unique external identifier (UUID hex, 32 chars). + name: Human-readable plan name. + store_id: Store the baseline model targets. + product_id: Product the baseline model targets. + run_id: Artifact key of the baseline model (model_{run_id}.joblib). + horizon: Number of days simulated. + assumptions: Raw ScenarioAssumptions as JSONB. + comparison: Full ScenarioComparison snapshot as JSONB. + method: Adjustment method — 'heuristic' or 'model_exogenous' (CHECK-constrained).
+ """ + + __tablename__ = "scenario_plan" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + scenario_id: Mapped[str] = mapped_column(String(32), unique=True, index=True) + name: Mapped[str] = mapped_column(String(200), nullable=False) + store_id: Mapped[int] = mapped_column(Integer, index=True, nullable=False) + product_id: Mapped[int] = mapped_column(Integer, index=True, nullable=False) + run_id: Mapped[str] = mapped_column(String(32), index=True, nullable=False) + horizon: Mapped[int] = mapped_column(Integer, nullable=False) + + # JSONB blobs — never named ``metadata`` (SQLAlchemy reserves it). + assumptions: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + comparison: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + + method: Mapped[str] = mapped_column( + String(20), nullable=False, default=SCENARIO_METHOD_HEURISTIC + ) + + # Scenario-library columns (PRP-27 Phase C). ``tags`` is a JSONB string + # array — a real column (never folded into a JSONB blob) so it is + # queryable/indexable. ``cloned_from`` is the scenario_id this plan was + # cloned from, or NULL. + tags: Mapped[list[str]] = mapped_column( + JSONB, nullable=False, default=list, server_default=text("'[]'::jsonb") + ) + cloned_from: Mapped[str | None] = mapped_column(String(32), nullable=True) + + # Provenance + approval-audit columns (PRP-27 Phase D). ``source`` defaults + # to 'user'; an agent-saved plan carries 'agent' plus the originating + # session id and the human approval audit trail. 
+ source: Mapped[str] = mapped_column( + String(16), nullable=False, default=SCENARIO_SOURCE_USER, server_default=text("'user'") + ) + agent_session_id: Mapped[str | None] = mapped_column(String(32), nullable=True) + approved_by: Mapped[str | None] = mapped_column(String(120), nullable=True) + approved_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + approval_decision: Mapped[str | None] = mapped_column(String(16), nullable=True) + + __table_args__ = ( + # GIN indexes for JSONB containment queries on either blob. + Index("ix_scenario_plan_assumptions_gin", "assumptions", postgresql_using="gin"), + Index("ix_scenario_plan_comparison_gin", "comparison", postgresql_using="gin"), + # Composite index for the common "plans for this store/product" query. + Index("ix_scenario_plan_store_product", "store_id", "product_id"), + # GIN index so the saved-plans list can filter by tag containment. + Index("ix_scenario_plan_tags_gin", "tags", postgresql_using="gin"), + # Index on source for the "show me agent-proposed plans" query. + Index("ix_scenario_plan_source", "source"), + # heuristic (MVP) or model_exogenous (PRP-27) — kept in lock-step with + # the alembic migration that widened this CHECK. + CheckConstraint( + "method IN ('heuristic', 'model_exogenous')", + name="ck_scenario_plan_method", + ), + # Provenance + approval-audit CHECKs — kept in lock-step with the + # alembic migration that added these columns. + CheckConstraint( + "source IN ('user', 'agent')", + name="ck_scenario_plan_source", + ), + CheckConstraint( + "approval_decision IS NULL OR approval_decision IN ('approved', 'rejected')", + name="ck_scenario_plan_approval_decision", + ), + ) diff --git a/app/features/scenarios/routes.py b/app/features/scenarios/routes.py new file mode 100644 index 00000000..6ce4c1ec --- /dev/null +++ b/app/features/scenarios/routes.py @@ -0,0 +1,247 @@ +"""API routes for the Scenario Simulation slice. 
+ +Six endpoints back the ``Visualize → What-If Planner`` page: a stateless +``POST /scenarios/simulate``, a ranking ``POST /scenarios/compare``, and ``scenario_plan`` CRUD. + +Service-layer ``FileNotFoundError`` / ``ValueError`` map to RFC 7807 problem +responses via the ``app.core.exceptions`` ``ForecastLabError`` hierarchy +(``application/problem+json``) — a bogus ``run_id`` never surfaces as a 500. +""" + +from fastapi import APIRouter, Depends, Query, status +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.core.exceptions import BadRequestError, DatabaseError, NotFoundError +from app.core.logging import get_logger +from app.features.scenarios.schemas import ( + CompareScenariosRequest, + CreateScenarioRequest, + MultiScenarioComparison, + ScenarioComparison, + ScenarioListResponse, + ScenarioPlanResponse, + SimulateScenarioRequest, +) +from app.features.scenarios.service import ScenarioService + +logger = get_logger(__name__) + +router = APIRouter(prefix="/scenarios", tags=["scenarios"]) + + +@router.post( + "/simulate", + response_model=ScenarioComparison, + status_code=status.HTTP_200_OK, + summary="Run a stateless what-if simulation", + description=""" +Run a baseline forecast for an existing trained model and apply deterministic +what-if adjustment factors. + +**Inputs:** +- `run_id`: artifact key of a baseline model (the `run_id` on a completed + predict/train job — `model_{run_id}.joblib`). +- `horizon`: number of days to simulate (1-90). +- `assumptions`: optional price / promotion / holiday / inventory / lifecycle + assumptions. Omit them all for a no-change baseline. + +**Output:** a `ScenarioComparison` — per-day baseline vs. scenario demand, +aggregate unit and revenue deltas, a coverage verdict, and a `method` +(`heuristic`) plus a `disclaimer`. The result is a deterministic post-forecast +multiplier, NOT a re-trained causal model.
+ +A bogus `run_id` returns a 404 problem response; an invalid artifact path +returns 400 — never a 500. +""", +) +async def simulate_scenario( + request: SimulateScenarioRequest, + db: AsyncSession = Depends(get_db), +) -> ScenarioComparison: + """Run a stateless scenario simulation. + + Args: + request: Baseline run_id, horizon, and what-if assumptions. + db: Async database session from dependency. + + Returns: + A baseline-vs-scenario comparison. + + Raises: + NotFoundError: When the model artifact is missing. + BadRequestError: When the request is otherwise invalid. + """ + try: + return await ScenarioService().simulate(db, request) + except FileNotFoundError as exc: + logger.warning("scenarios.simulate_not_found", run_id=request.run_id, error=str(exc)) + raise NotFoundError(message=str(exc)) from exc + except ValueError as exc: + logger.warning("scenarios.simulate_invalid", run_id=request.run_id, error=str(exc)) + raise BadRequestError(message=str(exc)) from exc + + +@router.post( + "", + response_model=ScenarioPlanResponse, + status_code=status.HTTP_201_CREATED, + summary="Save a scenario plan", + description=""" +Run a simulation and persist it as a named plan. + +The saved plan stores both the raw assumptions and the full comparison +snapshot, so a reloaded plan re-renders without recomputation. +""", +) +async def create_scenario( + request: CreateScenarioRequest, + db: AsyncSession = Depends(get_db), +) -> ScenarioPlanResponse: + """Persist a scenario plan. + + Args: + request: Plan name plus baseline run_id, horizon, and assumptions. + db: Async database session from dependency. + + Returns: + The saved plan with its embedded comparison snapshot. + + Raises: + NotFoundError: When the model artifact is missing. + BadRequestError: When the request is otherwise invalid. + DatabaseError: When the persistence operation fails. 
+ """ + try: + return await ScenarioService().create_plan(db, request) + except FileNotFoundError as exc: + logger.warning("scenarios.create_not_found", run_id=request.run_id, error=str(exc)) + raise NotFoundError(message=str(exc)) from exc + except ValueError as exc: + logger.warning("scenarios.create_invalid", run_id=request.run_id, error=str(exc)) + raise BadRequestError(message=str(exc)) from exc + except SQLAlchemyError as exc: + logger.error("scenarios.create_db_error", error=str(exc), exc_info=True) + raise DatabaseError( + message="Failed to save scenario plan", + details={"error": str(exc)}, + ) from exc + + +@router.post( + "/compare", + response_model=MultiScenarioComparison, + status_code=status.HTTP_200_OK, + summary="Compare saved scenario plans", + description=""" +Rank 2-5 saved scenario plans against a shared baseline. + +Each saved plan embeds its own comparison snapshot, so this is a pure +aggregation — no model artifact is reloaded. An unknown `scenario_id` returns +a 404 problem response. +""", +) +async def compare_scenarios( + request: CompareScenariosRequest, + db: AsyncSession = Depends(get_db), +) -> MultiScenarioComparison: + """Compare 2-5 saved scenario plans. + + Args: + request: The 2-5 scenario_ids and the ranking metric. + db: Async database session from dependency. + + Returns: + A ranked multi-scenario comparison plus merged chart series. + + Raises: + NotFoundError: When any requested scenario_id does not exist. + """ + try: + return await ScenarioService().compare_scenarios(db, request) + except FileNotFoundError as exc: + logger.warning("scenarios.compare_not_found", error=str(exc)) + raise NotFoundError(message=str(exc)) from exc + + +@router.get( + "", + response_model=ScenarioListResponse, + summary="List saved scenario plans", + description="List saved scenario plans, newest first. Returns 200 + an " + "empty list when no plans exist. 
Pass one or more `tags` to filter to " + "plans carrying every listed tag.", +) +async def list_scenarios( + db: AsyncSession = Depends(get_db), + limit: int = Query(default=20, ge=1, le=100, description="Maximum plans to return."), + offset: int = Query(default=0, ge=0, description="Number of plans to skip."), + tags: list[str] | None = Query( + default=None, description="Filter to plans carrying every listed tag." + ), +) -> ScenarioListResponse: + """List saved scenario plans. + + Args: + db: Async database session from dependency. + limit: Maximum plans to return (1-100). + offset: Number of plans to skip. + tags: Optional library tags to filter by. + + Returns: + A page of saved plans plus the total count. + """ + return await ScenarioService().list_plans(db, limit=limit, offset=offset, tags=tags) + + +@router.get( + "/{scenario_id}", + response_model=ScenarioPlanResponse, + summary="Get a saved scenario plan", + description="Fetch one saved plan, including its embedded comparison snapshot.", +) +async def get_scenario( + scenario_id: str, + db: AsyncSession = Depends(get_db), +) -> ScenarioPlanResponse: + """Get a saved scenario plan by id. + + Args: + scenario_id: External identifier of the plan. + db: Async database session from dependency. + + Returns: + The saved plan with its embedded comparison snapshot. + + Raises: + NotFoundError: When no plan matches ``scenario_id``. + """ + plan = await ScenarioService().get_plan(db, scenario_id) + if plan is None: + raise NotFoundError(message=f"Scenario plan not found: {scenario_id}") + return plan + + +@router.delete( + "/{scenario_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Delete a saved scenario plan", + description="Delete a saved scenario plan by id.", +) +async def delete_scenario( + scenario_id: str, + db: AsyncSession = Depends(get_db), +) -> None: + """Delete a saved scenario plan. + + Args: + scenario_id: External identifier of the plan. + db: Async database session from dependency. 
+ + Raises: + NotFoundError: When no plan matches ``scenario_id``. + """ + deleted = await ScenarioService().delete_plan(db, scenario_id) + if not deleted: + raise NotFoundError(message=f"Scenario plan not found: {scenario_id}") diff --git a/app/features/scenarios/schemas.py b/app/features/scenarios/schemas.py new file mode 100644 index 00000000..fe1a71ef --- /dev/null +++ b/app/features/scenarios/schemas.py @@ -0,0 +1,463 @@ +"""Pydantic schemas for the Scenario Simulation slice. + +Two families of model live here: + +* **Request bodies** — ``SimulateScenarioRequest``, ``CreateScenarioRequest`` and + the ``*Assumption`` inputs. They carry ``ConfigDict(strict=True)`` to catch + silent coercion bugs on JSON-native types, and every ``date`` field carries a + ``Field(strict=False, ...)`` override so FastAPI's ``validate_python`` path + still accepts ISO-string dates (see ``docs/_base/SECURITY.md`` — "Pydantic v2 + strict mode on FastAPI request bodies"). +* **Responses** — ``ScenarioComparison``, ``ScenarioPlanResponse`` and the list + models. They use ``ConfigDict(from_attributes=True)`` and deliberately do NOT + set ``strict=True``. +""" + +from __future__ import annotations + +from datetime import date as date_type +from datetime import datetime +from typing import Annotated, Literal + +from pydantic import BaseModel, ConfigDict, Field + +# Promotion mechanics mirror data_platform.models.Promotion.kind. +PromotionKind = Literal["pct_off", "bogo", "bundle", "markdown"] +# Lifecycle stages a planner can force on the assumption form. +LifecycleStage = Literal["launch", "growth", "maturity", "decline"] +# Whether projected demand is covered by on-hand stock. 
+CoverageVerdict = Literal["covered", "at_risk", "stockout", "unknown"] + + +# ============================================================================= +# Assumption inputs (request fragments) +# ============================================================================= + + +class PriceAssumption(BaseModel): + """A relative price change applied over a future date window.""" + + model_config = ConfigDict(strict=True) + + change_pct: float = Field( + ..., + ge=-0.9, + le=5.0, + description="Relative price change as a fraction (-0.15 == 15% cheaper, " + "0.10 == 10% dearer).", + ) + start_date: date_type = Field( + ..., + strict=False, + description="First day the price change is in effect (inclusive).", + ) + end_date: date_type = Field( + ..., + strict=False, + description="Last day the price change is in effect (inclusive).", + ) + + +class PromotionAssumption(BaseModel): + """A promotion of a given kind running over a future date window.""" + + model_config = ConfigDict(strict=True) + + kind: PromotionKind = Field( + ..., + description="Promotion mechanic: pct_off, bogo, bundle, or markdown.", + ) + start_date: date_type = Field( + ..., + strict=False, + description="First day the promotion runs (inclusive).", + ) + end_date: date_type = Field( + ..., + strict=False, + description="Last day the promotion runs (inclusive).", + ) + + +class HolidayAssumption(BaseModel): + """Explicit holiday / event days that lift demand.""" + + model_config = ConfigDict(strict=True) + + # ``strict=False`` on the outer Field satisfies the strict-mode policy + # linter; the per-element ``Annotated[..., Field(strict=False)]`` is what + # actually lets each ISO-string date coerce — field-level strict does NOT + # propagate into list members. 
+ dates: list[Annotated[date_type, Field(strict=False)]] = Field( + ..., + strict=False, + min_length=1, + description="Calendar dates treated as holiday / event days.", + ) + + +class InventoryAssumption(BaseModel): + """On-hand stock used only to derive a coverage verdict — never demand.""" + + model_config = ConfigDict(strict=True) + + on_hand_units: int = Field( + ..., + ge=0, + description="Units of stock on hand for the horizon. Caps coverage, not demand.", + ) + + +class LifecycleAssumption(BaseModel): + """A forced product lifecycle stage for the whole horizon.""" + + model_config = ConfigDict(strict=True) + + stage: LifecycleStage = Field( + ..., + description="Lifecycle stage override: launch, growth, maturity, or decline.", + ) + + +class ScenarioAssumptions(BaseModel): + """The full set of optional what-if assumptions. + + Every field is optional — an empty ``ScenarioAssumptions`` is the "nothing + changes" case and yields a scenario identical to the baseline. + """ + + model_config = ConfigDict(strict=True) + + price: PriceAssumption | None = Field(default=None, description="Price-change assumption.") + promotion: PromotionAssumption | None = Field(default=None, description="Promotion assumption.") + holiday: HolidayAssumption | None = Field(default=None, description="Holiday / event days.") + inventory: InventoryAssumption | None = Field( + default=None, description="On-hand stock for the coverage verdict." + ) + lifecycle: LifecycleAssumption | None = Field( + default=None, description="Lifecycle-stage override." 
+ ) + + +# ============================================================================= +# Request bodies +# ============================================================================= + + +class SimulateScenarioRequest(BaseModel): + """Request body for ``POST /scenarios/simulate`` (stateless).""" + + model_config = ConfigDict(strict=True) + + run_id: str = Field( + ..., + min_length=1, + max_length=64, + description="Artifact key of a baseline model — the run_id stored on a " + "completed predict/train job (model_{run_id}.joblib).", + ) + horizon: int = Field( + ..., + ge=1, + le=90, + description="Number of days to simulate.", + ) + assumptions: ScenarioAssumptions = Field( + default_factory=ScenarioAssumptions, + description="Optional what-if assumptions. Omit for a no-change baseline.", + ) + name: str | None = Field( + default=None, + max_length=200, + description="Optional label echoed back; suggested name when saving a plan.", + ) + + +class CreateScenarioRequest(BaseModel): + """Request body for ``POST /scenarios`` — runs a simulation and persists it.""" + + model_config = ConfigDict(strict=True) + + name: str = Field( + ..., + min_length=1, + max_length=200, + description="Human-readable name for the saved plan.", + ) + run_id: str = Field( + ..., + min_length=1, + max_length=64, + description="Artifact key of the baseline model.", + ) + horizon: int = Field( + ..., + ge=1, + le=90, + description="Number of days to simulate.", + ) + assumptions: ScenarioAssumptions = Field( + default_factory=ScenarioAssumptions, + description="What-if assumptions for this plan.", + ) + tags: list[str] = Field( + default_factory=list, + max_length=20, + description="Optional library tags for filtering and grouping saved plans.", + ) + cloned_from: str | None = Field( + default=None, + max_length=32, + description="scenario_id this plan was cloned from, when it originated as a clone.", + ) + + +class SaveScenarioRequest(BaseModel): + """What the ``save_scenario`` agent 
tool persists once HITL-approved. + + Unlike ``CreateScenarioRequest`` this carries ``store_id`` / ``product_id`` + explicitly — the agent proposed the scenario for a known grain, so the + identity travels with the request rather than being re-derived. The + persisted ``scenario_plan`` row is stamped ``source='agent'`` plus the + originating ``agent_session_id`` (PRP-27 Phase D, DECISIONS LOCKED #13). + """ + + model_config = ConfigDict(strict=True) + + name: str = Field( + ..., + min_length=1, + max_length=200, + description="Human-readable name for the saved plan.", + ) + assumptions: ScenarioAssumptions = Field( + ..., + description="The what-if assumptions the agent proposed.", + ) + store_id: int = Field(..., ge=1, description="Store the proposed scenario targets.") + product_id: int = Field(..., ge=1, description="Product the proposed scenario targets.") + horizon: int = Field(..., ge=1, le=90, description="Number of days to simulate.") + run_id: str = Field( + ..., + min_length=1, + max_length=64, + description="Artifact key of the baseline model.", + ) + source: Literal["user", "agent"] = Field( + default="agent", + description="Provenance of the plan — an agent save defaults to 'agent'.", + ) + agent_session_id: str | None = Field( + default=None, + max_length=32, + description="The originating agent session id, when agent-created.", + ) + + +# ============================================================================= +# Response models +# ============================================================================= + + +class ScenarioPoint(BaseModel): + """One horizon day: baseline vs. 
scenario demand and the factor applied.""" + + model_config = ConfigDict(from_attributes=True) + + date: date_type = Field(..., description="Forecast date.") + baseline: float = Field(..., description="Baseline forecast demand for the day.") + scenario: float = Field(..., description="Scenario-adjusted demand for the day.") + delta: float = Field(..., description="scenario minus baseline for the day.") + applied_factor: float = Field( + ..., + description="Combined deterministic multiplier applied on the day (1.0 == no change).", + ) + + +class ScenarioComparison(BaseModel): + """A full baseline-vs-scenario comparison for one (store, product) series.""" + + model_config = ConfigDict(from_attributes=True) + + store_id: int = Field(..., description="Store the baseline model targets.") + product_id: int = Field(..., description="Product the baseline model targets.") + model_type: str = Field(..., description="Model type of the baseline artifact.") + horizon: int = Field(..., ge=1, description="Number of days simulated.") + points: list[ScenarioPoint] = Field( + ..., + description="Per-day baseline / scenario series; length equals horizon.", + ) + baseline_total_units: float = Field(..., description="Summed baseline demand.") + scenario_total_units: float = Field(..., description="Summed scenario demand.") + units_delta: float = Field(..., description="scenario_total_units minus baseline_total_units.") + units_delta_pct: float = Field( + ..., + description="units_delta as a percentage of baseline; 0.0 when baseline is 0.", + ) + unit_price_used: float = Field( + ..., + description="Unit price used for the revenue estimate (most recent sale, " + "or a documented fallback).", + ) + baseline_revenue: float = Field(..., description="baseline_total_units * unit_price_used.") + scenario_revenue: float = Field(..., description="scenario_total_units * unit_price_used.") + revenue_delta: float = Field(..., description="scenario_revenue minus baseline_revenue.") + 
coverage_verdict: CoverageVerdict = Field( + ..., + description="covered / at_risk / stockout, or unknown when no inventory " + "assumption was supplied.", + ) + method: Literal["heuristic", "model_exogenous"] = Field( + ..., + description="How the scenario was produced: 'heuristic' (a deterministic " + "post-forecast multiplier) or 'model_exogenous' (a re-forecast through a " + "feature-consuming regression model).", + ) + disclaimer: str = Field( + ..., + description="Plain-language caveat appropriate to the method that produced the comparison.", + ) + generated_at: datetime = Field(..., description="When the comparison was computed (UTC).") + + +class ScenarioPlanResponse(BaseModel): + """A persisted scenario plan, including the embedded comparison snapshot.""" + + model_config = ConfigDict(from_attributes=True) + + scenario_id: str = Field(..., description="Unique external identifier of the plan.") + name: str = Field(..., description="Human-readable plan name.") + store_id: int = Field(..., description="Store the plan targets.") + product_id: int = Field(..., description="Product the plan targets.") + run_id: str = Field(..., description="Artifact key of the baseline model.") + horizon: int = Field(..., ge=1, description="Number of days simulated.") + method: str = Field(..., description="Adjustment method — always 'heuristic'.") + created_at: datetime = Field(..., description="When the plan was saved (UTC).") + assumptions: ScenarioAssumptions = Field( + ..., description="The raw what-if assumptions the plan was built from." + ) + comparison: ScenarioComparison = Field( + ..., description="The full baseline-vs-scenario snapshot, re-rendered without recompute." + ) + tags: list[str] = Field(default_factory=list, description="Library tags attached to the plan.") + cloned_from: str | None = Field( + default=None, description="scenario_id this plan was cloned from, if any." 
+ ) + source: str = Field(default="user", description="Who created the plan — 'user' or 'agent'.") + agent_session_id: str | None = Field( + default=None, description="Originating agent session id, when agent-created." + ) + approved_by: str | None = Field( + default=None, description="Who approved an agent-created plan, if any." + ) + approved_at: datetime | None = Field( + default=None, description="When an agent-created plan was approved (UTC)." + ) + approval_decision: str | None = Field( + default=None, description="The HITL decision — 'approved' or 'rejected'." + ) + + +class ScenarioListItem(BaseModel): + """A compact row in the saved-plans list.""" + + model_config = ConfigDict(from_attributes=True) + + scenario_id: str = Field(..., description="Unique external identifier of the plan.") + name: str = Field(..., description="Human-readable plan name.") + store_id: int = Field(..., description="Store the plan targets.") + product_id: int = Field(..., description="Product the plan targets.") + horizon: int = Field(..., ge=1, description="Number of days simulated.") + units_delta: float = Field(..., description="Summed scenario-minus-baseline demand.") + revenue_delta: float = Field(..., description="Scenario-minus-baseline revenue.") + created_at: datetime = Field(..., description="When the plan was saved (UTC).") + tags: list[str] = Field(default_factory=list, description="Library tags attached to the plan.") + source: str = Field(default="user", description="Who created the plan — 'user' or 'agent'.") + agent_session_id: str | None = Field( + default=None, description="Originating agent session id, when agent-created." + ) + approved_by: str | None = Field( + default=None, description="Who approved an agent-created plan, if any." + ) + approved_at: datetime | None = Field( + default=None, description="When an agent-created plan was approved (UTC)." 
+ ) + approval_decision: str | None = Field( + default=None, description="The HITL decision — 'approved' or 'rejected'." + ) + + +class ScenarioListResponse(BaseModel): + """A page of saved scenario plans, newest first.""" + + model_config = ConfigDict(from_attributes=True) + + scenarios: list[ScenarioListItem] = Field( + ..., description="Saved plans for the current page; empty when none exist." + ) + total: int = Field(..., ge=0, description="Total saved plans matching the query.") + + +# ============================================================================= +# Multi-scenario comparison (PRP-27 Phase C) +# ============================================================================= + +# Metric a multi-scenario comparison ranks by. +RankBy = Literal["revenue_delta", "units_delta"] + + +class CompareScenariosRequest(BaseModel): + """Request body for ``POST /scenarios/compare``. + + The 2..5 bound keeps the multi-series chart legible — the upper bound is + the pinned ``MAX_COMPARE_SCENARIOS`` (PRP-27 DECISIONS LOCKED #12); the + literal ``5`` must stay in sync with that constant in ``feature_frame.py``. 
+ """ + + model_config = ConfigDict(strict=True) + + scenario_ids: list[str] = Field( + ..., + min_length=2, + max_length=5, + description="2-5 saved scenario_ids to compare side by side.", + ) + rank_by: RankBy = Field( + default="revenue_delta", + description="Metric the ranked rows are ordered by (descending).", + ) + + +class ScenarioComparisonRow(BaseModel): + """One saved plan's headline numbers within a multi-scenario comparison.""" + + model_config = ConfigDict(from_attributes=True) + + scenario_id: str = Field(..., description="The plan's external identifier.") + name: str = Field(..., description="The plan's human-readable name.") + units_delta: float = Field(..., description="Scenario-minus-baseline demand for the plan.") + revenue_delta: float = Field(..., description="Scenario-minus-baseline revenue for the plan.") + coverage_verdict: CoverageVerdict = Field(..., description="The plan's coverage verdict.") + rank: int = Field(..., ge=1, description="1-based rank by the chosen metric (1 == best).") + + +class MultiScenarioComparison(BaseModel): + """A baseline compared against 2-5 saved scenarios, ranked.""" + + model_config = ConfigDict(from_attributes=True) + + baseline_total_units: float = Field( + ..., description="Reference baseline demand (from the first compared plan)." + ) + baseline_revenue: float = Field( + ..., description="Reference baseline revenue (from the first compared plan)." + ) + rank_by: RankBy = Field(..., description="Metric the rows are ranked by.") + scenarios: list[ScenarioComparisonRow] = Field( + ..., description="The compared plans, ordered best-first by rank_by." 
+ ) + chart_series: list[dict[str, float | str]] = Field( + ..., + description="Date-keyed merged rows for the multi-series chart — each row " + "carries 'date', 'baseline', and one entry per scenario keyed by scenario_id.", + ) diff --git a/app/features/scenarios/service.py b/app/features/scenarios/service.py new file mode 100644 index 00000000..4fcdb308 --- /dev/null +++ b/app/features/scenarios/service.py @@ -0,0 +1,655 @@ +"""Service layer for the Scenario Simulation slice. + +``ScenarioService`` does two things: + +* **simulate** — resolve a baseline model artifact, run its forecast, apply the + pure deterministic factors from ``adjustments.py``, and return a + ``ScenarioComparison``. Stateless. +* **CRUD** — persist a comparison as a named ``scenario_plan`` row, then list / + fetch / delete saved plans. + +DECISIONS LOCKED (PRP-26 #2): this service must NOT import a sibling slice's +``service.py``. It imports only the stable lower-level building block +``load_model_bundle`` from ``forecasting/persistence.py`` and produces the +baseline forecast by calling ``bundle.model.predict(horizon)`` directly — +replicating the ``ForecastPoint``-construction block of +``ForecastingService.predict`` rather than calling that class. Read-only ORM +imports of sibling ``models.py`` (``data_platform``) are allowed. 
+""" + +from __future__ import annotations + +import uuid +from datetime import UTC, date, datetime, timedelta +from pathlib import Path +from typing import Any, cast + +import numpy as np +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import get_settings +from app.core.logging import get_logger +from app.features.data_platform.models import SalesDaily +from app.features.forecasting.persistence import ModelBundle, load_model_bundle +from app.features.scenarios import adjustments +from app.features.scenarios.feature_frame import build_future_frame +from app.features.scenarios.models import SCENARIO_SOURCE_USER, ScenarioPlan +from app.features.scenarios.schemas import ( + CompareScenariosRequest, + CreateScenarioRequest, + MultiScenarioComparison, + ScenarioAssumptions, + ScenarioComparison, + ScenarioComparisonRow, + ScenarioListItem, + ScenarioListResponse, + ScenarioPlanResponse, + ScenarioPoint, + SimulateScenarioRequest, +) + +logger = get_logger(__name__) + +# Plain-language caveat stamped on every comparison — the NIST-AI-RMF +# transparency control against over-trusting a heuristic number. +HEURISTIC_DISCLAIMER = ( + "Heuristic estimate: this scenario applies fixed, deterministic adjustment " + "factors to a baseline forecast — it is not a re-trained, causal model. " + "Treat the demand and revenue deltas as directional planning signals, not " + "precise predictions." +) + +# Caveat for the model-driven path — a re-forecast IS model-causal, but a model +# estimate is still an estimate (NIST-AI-RMF transparency control). +MODEL_EXOGENOUS_DISCLAIMER = ( + "Model estimate: this scenario re-forecasts demand through a feature-driven " + "model using the assumptions as future inputs. It reflects learned patterns " + "but remains an estimate under uncertainty — not a guarantee." +) + +# Fallback unit price when a (store, product) has no sales history. 
+DEFAULT_UNIT_PRICE = 1.0 + + +class ScenarioService: + """Stateless simulation plus saved-plan CRUD for scenario planning.""" + + # -- Simulation -------------------------------------------------------- + + async def simulate( + self, db: AsyncSession, request: SimulateScenarioRequest + ) -> ScenarioComparison: + """Run a baseline forecast and apply the what-if assumptions. + + Args: + db: Database session (unit-price lookup; regression future-frame build). + request: The baseline ``run_id``, horizon, and assumptions. + + Returns: + A full baseline-vs-scenario comparison. + + Raises: + FileNotFoundError: When no model artifact exists for ``run_id``. + ValueError: When the artifact path is invalid or its metadata is + missing the store / product identity or regression feature metadata. + """ + bundle = self._load_baseline_bundle(request.run_id) + + store_id_raw = bundle.metadata.get("store_id") + product_id_raw = bundle.metadata.get("product_id") + if store_id_raw is None or product_id_raw is None: + raise ValueError( + f"Model artifact for run_id '{request.run_id}' is missing " + "store_id / product_id metadata." + ) + store_id = int(str(store_id_raw)) + product_id = int(str(product_id_raw)) + + # A regression baseline answers the what-if by genuinely re-forecasting + # through the future feature frame; every other model type uses the + # deterministic heuristic multiplier below (PRP-27 DECISIONS LOCKED #1). + if bundle.config.model_type == "regression": + return await self._simulate_model_exogenous(db, request, bundle, store_id, product_id) + + # Replicate the ForecastingService.predict body (DECISIONS LOCKED #2). + raw_forecast = bundle.model.predict(request.horizon) + baseline_values = [float(value) for value in raw_forecast] + start_date = self._forecast_start_date(bundle.metadata.get("train_end_date")) + + # Per-day deterministic factors — adjustments.py is pure. 
+ factors: list[float] = [] + for offset in range(request.horizon): + point_date = start_date + timedelta(days=offset) + factors.append(adjustments.combined_daily_factor(point_date, request.assumptions)) + scenario_values = adjustments.apply_adjustment(baseline_values, factors) + + points = [ + ScenarioPoint( + date=start_date + timedelta(days=offset), + baseline=baseline_values[offset], + scenario=scenario_values[offset], + delta=scenario_values[offset] - baseline_values[offset], + applied_factor=factors[offset], + ) + for offset in range(request.horizon) + ] + + baseline_total = sum(baseline_values) + scenario_total = sum(scenario_values) + units_delta = scenario_total - baseline_total + units_delta_pct = (units_delta / baseline_total * 100.0) if baseline_total > 0 else 0.0 + + unit_price = await self._latest_unit_price(db, store_id, product_id) + baseline_revenue = baseline_total * unit_price + scenario_revenue = scenario_total * unit_price + + inventory = request.assumptions.inventory + on_hand = inventory.on_hand_units if inventory is not None else None + verdict = adjustments.coverage_verdict(scenario_total, on_hand) + + logger.info( + "scenarios.simulated", + run_id=request.run_id, + store_id=store_id, + product_id=product_id, + horizon=request.horizon, + model_type=bundle.config.model_type, + units_delta=round(units_delta, 4), + coverage_verdict=verdict, + ) + + return ScenarioComparison( + store_id=store_id, + product_id=product_id, + model_type=bundle.config.model_type, + horizon=request.horizon, + points=points, + baseline_total_units=baseline_total, + scenario_total_units=scenario_total, + units_delta=units_delta, + units_delta_pct=units_delta_pct, + unit_price_used=unit_price, + baseline_revenue=baseline_revenue, + scenario_revenue=scenario_revenue, + revenue_delta=scenario_revenue - baseline_revenue, + coverage_verdict=verdict, + method="heuristic", + disclaimer=HEURISTIC_DISCLAIMER, + generated_at=datetime.now(UTC), + ) + + async def 
_simulate_model_exogenous( + self, + db: AsyncSession, + request: SimulateScenarioRequest, + bundle: ModelBundle, + store_id: int, + product_id: int, + ) -> ScenarioComparison: + """Re-forecast a regression baseline through the future feature frame. + + Builds two leakage-safe future frames — one carrying the scenario + assumptions, one with none — feeds both to the model, and compares the + re-forecasts. Unlike the heuristic path the deltas come from the model + itself, so the result is stamped ``method="model_exogenous"``. + + Args: + db: Database session. + request: The baseline ``run_id``, horizon, and assumptions. + bundle: The already-loaded regression model bundle. + store_id: Store the baseline model targets. + product_id: Product the baseline model targets. + + Returns: + A model-driven baseline-vs-scenario comparison. + + Raises: + ValueError: When the bundle lacks the feature metadata a scenario + forecast needs (an older artifact trained before PRP-27). + """ + feature_columns_raw = bundle.metadata.get("feature_columns") + history_tail_raw = bundle.metadata.get("history_tail") + if not isinstance(feature_columns_raw, list) or not isinstance(history_tail_raw, list): + raise ValueError( + f"Model artifact for run_id '{request.run_id}' is a regression " + "model without the feature metadata a scenario forecast needs — " + "retrain it with the current pipeline." + ) + feature_columns = [str(column) for column in cast("list[str]", feature_columns_raw)] + history_tail = [float(value) for value in cast("list[float]", history_tail_raw)] + + # The forecast origin T is the day before the first forecast day. 
+ origin = self._forecast_start_date(bundle.metadata.get("train_end_date")) - timedelta( + days=1 + ) + launch_raw = bundle.metadata.get("launch_date") + launch_date = date.fromisoformat(launch_raw) if isinstance(launch_raw, str) else None + + scenario_frame = await build_future_frame( + db, + store_id=store_id, + product_id=product_id, + forecast_origin=origin, + horizon=request.horizon, + feature_columns=feature_columns, + history_tail=history_tail, + assumptions=request.assumptions, + launch_date=launch_date, + ) + # The baseline is the SAME frame with the assumptions stripped. + baseline_frame = await build_future_frame( + db, + store_id=store_id, + product_id=product_id, + forecast_origin=origin, + horizon=request.horizon, + feature_columns=feature_columns, + history_tail=history_tail, + assumptions=ScenarioAssumptions(), + launch_date=launch_date, + ) + + scenario_x = np.array(scenario_frame.matrix, dtype=np.float64) + baseline_x = np.array(baseline_frame.matrix, dtype=np.float64) + # Demand can never be negative — floor the model output at 0. + scenario_values = [ + max(0.0, float(value)) for value in bundle.model.predict(request.horizon, scenario_x) + ] + baseline_values = [ + max(0.0, float(value)) for value in bundle.model.predict(request.horizon, baseline_x) + ] + + points = [ + ScenarioPoint( + date=scenario_frame.dates[offset], + baseline=baseline_values[offset], + scenario=scenario_values[offset], + delta=scenario_values[offset] - baseline_values[offset], + # The realised per-day multiplier the model implied (1.0 == no + # change); guards a zero-baseline day. 
+ applied_factor=( + scenario_values[offset] / baseline_values[offset] + if baseline_values[offset] > 0.0 + else 1.0 + ), + ) + for offset in range(request.horizon) + ] + + baseline_total = sum(baseline_values) + scenario_total = sum(scenario_values) + units_delta = scenario_total - baseline_total + units_delta_pct = (units_delta / baseline_total * 100.0) if baseline_total > 0 else 0.0 + + unit_price = await self._latest_unit_price(db, store_id, product_id) + baseline_revenue = baseline_total * unit_price + scenario_revenue = scenario_total * unit_price + + inventory = request.assumptions.inventory + on_hand = inventory.on_hand_units if inventory is not None else None + verdict = adjustments.coverage_verdict(scenario_total, on_hand) + + logger.info( + "scenarios.simulated", + run_id=request.run_id, + store_id=store_id, + product_id=product_id, + horizon=request.horizon, + model_type=bundle.config.model_type, + method="model_exogenous", + units_delta=round(units_delta, 4), + coverage_verdict=verdict, + ) + + return ScenarioComparison( + store_id=store_id, + product_id=product_id, + model_type=bundle.config.model_type, + horizon=request.horizon, + points=points, + baseline_total_units=baseline_total, + scenario_total_units=scenario_total, + units_delta=units_delta, + units_delta_pct=units_delta_pct, + unit_price_used=unit_price, + baseline_revenue=baseline_revenue, + scenario_revenue=scenario_revenue, + revenue_delta=scenario_revenue - baseline_revenue, + coverage_verdict=verdict, + method="model_exogenous", + disclaimer=MODEL_EXOGENOUS_DISCLAIMER, + generated_at=datetime.now(UTC), + ) + + # -- Persistence ------------------------------------------------------- + + async def create_plan( + self, + db: AsyncSession, + request: CreateScenarioRequest, + *, + source: str = SCENARIO_SOURCE_USER, + agent_session_id: str | None = None, + approved_by: str | None = None, + approval_decision: str | None = None, + ) -> ScenarioPlanResponse: + """Run a simulation and persist it 
as a named scenario plan. + + The provenance keyword arguments default to a plain user-created plan, + so the MVP create path stays backward-compatible. An agent-saved plan + (PRP-27 Phase D) passes ``source='agent'`` plus the originating + ``agent_session_id`` and the HITL approval audit trail; ``approved_at`` + is stamped automatically whenever ``approval_decision`` is supplied. + + Args: + db: Database session. + request: Plan name plus the baseline / horizon / assumptions. + source: Who created the plan — 'user' (default) or 'agent'. + agent_session_id: Originating agent session id, when agent-created. + approved_by: Who approved an agent-created plan, if any. + approval_decision: The HITL decision — 'approved' or 'rejected'. + + Returns: + The saved plan with its embedded comparison snapshot. + + Raises: + FileNotFoundError: When no model artifact exists for ``run_id``. + ValueError: When the artifact path or its metadata is invalid. + """ + comparison = await self.simulate( + db, + SimulateScenarioRequest( + run_id=request.run_id, + horizon=request.horizon, + assumptions=request.assumptions, + name=request.name, + ), + ) + + # An approval decision implies an approval moment — stamp it here so + # callers never have to thread a timestamp through. + approved_at = datetime.now(UTC) if approval_decision is not None else None + + plan = ScenarioPlan( + scenario_id=uuid.uuid4().hex, + name=request.name, + store_id=comparison.store_id, + product_id=comparison.product_id, + run_id=request.run_id, + horizon=request.horizon, + # JSONB cannot store Python date/datetime — dump in JSON mode. + assumptions=request.assumptions.model_dump(mode="json"), + comparison=comparison.model_dump(mode="json"), + # heuristic or model_exogenous — taken from the comparison the + # baseline model actually produced. 
+ method=comparison.method, + tags=list(request.tags), + cloned_from=request.cloned_from, + source=source, + agent_session_id=agent_session_id, + approved_by=approved_by, + approved_at=approved_at, + approval_decision=approval_decision, + ) + db.add(plan) + await db.commit() + await db.refresh(plan) + + logger.info( + "scenarios.plan_created", + scenario_id=plan.scenario_id, + store_id=plan.store_id, + product_id=plan.product_id, + source=source, + agent_session_id=agent_session_id, + ) + return self._to_plan_response(plan) + + async def list_plans( + self, + db: AsyncSession, + limit: int, + offset: int, + tags: list[str] | None = None, + ) -> ScenarioListResponse: + """List saved scenario plans, newest first. + + Args: + db: Database session. + limit: Maximum plans to return. + offset: Number of plans to skip. + tags: Optional library tags — when given, only plans carrying every + listed tag are returned. + + Returns: + A page of plan list items plus the total count. + """ + count_stmt = select(func.count()).select_from(ScenarioPlan) + rows_stmt = ( + select(ScenarioPlan) + .order_by(ScenarioPlan.created_at.desc(), ScenarioPlan.id.desc()) + .limit(limit) + .offset(offset) + ) + if tags: + # JSONB @> containment — a plan matches when it carries every tag. + count_stmt = count_stmt.where(ScenarioPlan.tags.contains(tags)) + rows_stmt = rows_stmt.where(ScenarioPlan.tags.contains(tags)) + + total = int(await db.scalar(count_stmt) or 0) + rows = (await db.execute(rows_stmt)).scalars().all() + return ScenarioListResponse( + scenarios=[self._to_list_item(row) for row in rows], + total=total, + ) + + async def compare_scenarios( + self, db: AsyncSession, request: CompareScenariosRequest + ) -> MultiScenarioComparison: + """Rank 2-5 saved scenario plans against a shared baseline. + + Each saved plan already embeds its full ``ScenarioComparison`` snapshot, + so the comparison is a pure aggregation — no model artifact is reloaded. 
+ The reference baseline is taken from the first requested plan. + + Args: + db: Database session. + request: The 2-5 ``scenario_id``s and the ranking metric. + + Returns: + A ranked multi-scenario comparison plus merged chart series. + + Raises: + FileNotFoundError: When any requested ``scenario_id`` does not exist. + """ + rows = ( + ( + await db.execute( + select(ScenarioPlan).where(ScenarioPlan.scenario_id.in_(request.scenario_ids)) + ) + ) + .scalars() + .all() + ) + found = {row.scenario_id: row for row in rows} + missing = [sid for sid in request.scenario_ids if sid not in found] + if missing: + raise FileNotFoundError(f"Scenario plan(s) not found: {', '.join(missing)}") + + # Preserve the caller's order; the first plan supplies the baseline. + plans = [found[sid] for sid in request.scenario_ids] + first_comparison = plans[0].comparison + baseline_total = float(first_comparison.get("baseline_total_units", 0.0)) + baseline_revenue = float(first_comparison.get("baseline_revenue", 0.0)) + + ranked = sorted( + plans, + key=lambda plan: float(plan.comparison.get(request.rank_by, 0.0)), + reverse=True, + ) + comparison_rows = [ + ScenarioComparisonRow( + scenario_id=plan.scenario_id, + name=plan.name, + units_delta=float(plan.comparison.get("units_delta", 0.0)), + revenue_delta=float(plan.comparison.get("revenue_delta", 0.0)), + coverage_verdict=plan.comparison.get("coverage_verdict", "unknown"), + rank=index + 1, + ) + for index, plan in enumerate(ranked) + ] + + logger.info( + "scenarios.compared", + scenario_count=len(plans), + rank_by=request.rank_by, + ) + return MultiScenarioComparison( + baseline_total_units=baseline_total, + baseline_revenue=baseline_revenue, + rank_by=request.rank_by, + scenarios=comparison_rows, + chart_series=self._build_chart_series(plans), + ) + + @staticmethod + def _build_chart_series(plans: list[ScenarioPlan]) -> list[dict[str, float | str]]: + """Merge every plan's per-day series into date-keyed chart rows. 
+ + Each row carries ``date``, the reference ``baseline`` (from the first + plan), and one entry per plan keyed by its ``scenario_id`` — a + CSS-identifier-safe key, unlike a free-text plan name. + """ + by_date: dict[str, dict[str, float | str]] = {} + for plan_index, plan in enumerate(plans): + points = cast("list[dict[str, Any]]", plan.comparison.get("points", [])) + for point in points: + point_date = str(point.get("date", "")) + row = by_date.setdefault(point_date, {"date": point_date}) + if plan_index == 0: + row["baseline"] = float(point.get("baseline", 0.0)) + row[plan.scenario_id] = float(point.get("scenario", 0.0)) + return [by_date[key] for key in sorted(by_date)] + + async def get_plan(self, db: AsyncSession, scenario_id: str) -> ScenarioPlanResponse | None: + """Fetch one saved plan by its external id, or ``None`` when absent.""" + plan = await db.scalar(select(ScenarioPlan).where(ScenarioPlan.scenario_id == scenario_id)) + if plan is None: + return None + return self._to_plan_response(plan) + + async def delete_plan(self, db: AsyncSession, scenario_id: str) -> bool: + """Delete a saved plan; return ``True`` when a row was removed.""" + plan = await db.scalar(select(ScenarioPlan).where(ScenarioPlan.scenario_id == scenario_id)) + if plan is None: + return False + await db.delete(plan) + await db.commit() + logger.info("scenarios.plan_deleted", scenario_id=scenario_id) + return True + + # -- Internal helpers -------------------------------------------------- + + def _load_baseline_bundle(self, run_id: str) -> ModelBundle: + """Resolve and load the baseline model artifact for ``run_id``. + + Mirrors the load-bearing path-traversal guard in + ``ForecastingService.predict``: reject a non-``.joblib`` suffix and any + path that escapes the configured artifacts directory. 
+ """ + settings = get_settings() + artifacts_dir = Path(settings.forecast_model_artifacts_dir).resolve() + model_path = (artifacts_dir / f"model_{run_id}.joblib").resolve() + + if model_path.suffix != ".joblib": + raise ValueError(f"Invalid model path for run_id '{run_id}'.") + try: + model_path.relative_to(artifacts_dir) + except ValueError: + raise ValueError(f"Invalid model path for run_id '{run_id}'.") from None + if not model_path.exists(): + raise FileNotFoundError(f"No model artifact found for run_id '{run_id}'.") + + return load_model_bundle(model_path) + + @staticmethod + def _forecast_start_date(train_end_raw: object) -> date: + """Return the first forecast day — train_end_date + 1, or today + 1. + + ``train_end_date`` is persisted as an ISO string in the bundle metadata; + when it is absent the forecast simply starts tomorrow. + """ + if isinstance(train_end_raw, str): + return date.fromisoformat(train_end_raw) + timedelta(days=1) + return datetime.now(UTC).date() + timedelta(days=1) + + async def _latest_unit_price(self, db: AsyncSession, store_id: int, product_id: int) -> float: + """Estimate a unit price from the most recent sale of this grain. + + Falls back to ``DEFAULT_UNIT_PRICE`` (and logs a warning) when the + grain has no sales history. + """ + price = await db.scalar( + select(SalesDaily.unit_price) + .where(SalesDaily.store_id == store_id, SalesDaily.product_id == product_id) + .order_by(SalesDaily.date.desc()) + .limit(1) + ) + if price is None: + logger.warning( + "scenarios.unit_price_fallback", + store_id=store_id, + product_id=product_id, + fallback=DEFAULT_UNIT_PRICE, + ) + return DEFAULT_UNIT_PRICE + return float(price) + + @staticmethod + def _to_plan_response(plan: ScenarioPlan) -> ScenarioPlanResponse: + """Build a full plan response from a persisted row. 
+ + The JSONB blobs round-trip cleanly: ``ScenarioComparison`` is not strict, + and every ``date`` field of ``ScenarioAssumptions`` carries + ``Field(strict=False)``, so the stored ISO strings re-validate. + """ + return ScenarioPlanResponse( + scenario_id=plan.scenario_id, + name=plan.name, + store_id=plan.store_id, + product_id=plan.product_id, + run_id=plan.run_id, + horizon=plan.horizon, + method=plan.method, + created_at=plan.created_at, + assumptions=ScenarioAssumptions.model_validate(plan.assumptions), + comparison=ScenarioComparison.model_validate(plan.comparison), + tags=list(plan.tags), + cloned_from=plan.cloned_from, + source=plan.source, + agent_session_id=plan.agent_session_id, + approved_by=plan.approved_by, + approved_at=plan.approved_at, + approval_decision=plan.approval_decision, + ) + + @staticmethod + def _to_list_item(plan: ScenarioPlan) -> ScenarioListItem: + """Build a compact list row, reading the deltas from the snapshot.""" + return ScenarioListItem( + scenario_id=plan.scenario_id, + name=plan.name, + store_id=plan.store_id, + product_id=plan.product_id, + horizon=plan.horizon, + units_delta=float(plan.comparison.get("units_delta", 0.0)), + revenue_delta=float(plan.comparison.get("revenue_delta", 0.0)), + created_at=plan.created_at, + tags=list(plan.tags), + source=plan.source, + agent_session_id=plan.agent_session_id, + approved_by=plan.approved_by, + approved_at=plan.approved_at, + approval_decision=plan.approval_decision, + ) diff --git a/app/features/scenarios/tests/__init__.py b/app/features/scenarios/tests/__init__.py new file mode 100644 index 00000000..b9535cd8 --- /dev/null +++ b/app/features/scenarios/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the scenarios vertical slice.""" diff --git a/app/features/scenarios/tests/conftest.py b/app/features/scenarios/tests/conftest.py new file mode 100644 index 00000000..b1cc6eeb --- /dev/null +++ b/app/features/scenarios/tests/conftest.py @@ -0,0 +1,152 @@ +"""Test fixtures for the 
scenarios slice. + +Integration tests run against a real PostgreSQL database (``docker compose up +-d`` required). The ``trained_model`` fixture writes a real model bundle into +the configured artifacts directory so ``POST /scenarios/simulate`` can resolve +it, exactly as a completed predict job would. + +``scenario_plan`` is a slice-private table — no seeder or demo writes it — so +the teardown safely wipes it whole rather than relying on a row marker. +""" + +import uuid +from collections.abc import AsyncGenerator, Generator +from datetime import date, timedelta +from pathlib import Path + +import numpy as np +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.core.config import get_settings +from app.core.database import get_db +from app.features.forecasting.models import NaiveForecaster, RegressionForecaster +from app.features.forecasting.persistence import ModelBundle, save_model_bundle +from app.features.forecasting.schemas import NaiveModelConfig, RegressionModelConfig +from app.features.scenarios.feature_frame import canonical_feature_columns +from app.features.scenarios.models import ScenarioPlan +from app.main import app + +# Store / product the test bundle is trained for. High IDs that no seeder uses, +# so the revenue calc deterministically hits the unit-price fallback. +TEST_STORE_ID = 990001 +TEST_PRODUCT_ID = 990002 +# train_end_date baked into the bundle metadata — the forecast starts the next day. 
+TEST_TRAIN_END_DATE = "2026-06-30" + + +@pytest.fixture +async def db_session() -> AsyncGenerator[AsyncSession, None]: + """Yield an async session, then wipe every scenario_plan row on teardown.""" + settings = get_settings() + engine = create_async_engine(settings.database_url, echo=False) + async_session_maker = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async with async_session_maker() as session: + try: + yield session + finally: + await session.execute(delete(ScenarioPlan)) + await session.commit() + + await engine.dispose() + + +@pytest.fixture +async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: + """Create a test client with the database dependency overridden.""" + + async def override_get_db() -> AsyncGenerator[AsyncSession, None]: + yield db_session + + app.dependency_overrides[get_db] = override_get_db + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + yield ac + + app.dependency_overrides.pop(get_db, None) + + +@pytest.fixture +def trained_model() -> Generator[str, None, None]: + """Save a real fitted naive-model bundle on disk; yield its run_id. + + The bundle lands in ``settings.forecast_model_artifacts_dir`` as + ``model_{run_id}.joblib`` — the exact artifact key ``ScenarioService`` + resolves. The file is removed on teardown. 
+ """ + settings = get_settings() + artifacts_dir = Path(settings.forecast_model_artifacts_dir) + artifacts_dir.mkdir(parents=True, exist_ok=True) + + run_id = uuid.uuid4().hex[:12] + model = NaiveForecaster() + model.fit(np.array([10.0, 12.0, 11.0, 13.0, 9.0, 14.0, 10.0], dtype=np.float64)) + bundle = ModelBundle( + model=model, + config=NaiveModelConfig(), + metadata={ + "store_id": TEST_STORE_ID, + "product_id": TEST_PRODUCT_ID, + "train_end_date": TEST_TRAIN_END_DATE, + "n_observations": 7, + }, + ) + save_model_bundle(bundle, artifacts_dir / f"model_{run_id}") + + yield run_id + + (artifacts_dir / f"model_{run_id}.joblib").unlink(missing_ok=True) + + +@pytest.fixture +def trained_regression_model() -> Generator[str, None, None]: + """Save a real fitted ``RegressionForecaster`` bundle on disk; yield run_id. + + The bundle carries the full PRP-27 metadata contract — ``feature_columns``, + ``history_tail``, ``launch_date`` — so the model-exogenous simulate path can + build a future feature frame and genuinely re-forecast. Demand is wired to + respond negatively to ``price_factor`` so a price cut lifts the forecast. + """ + settings = get_settings() + artifacts_dir = Path(settings.forecast_model_artifacts_dir) + artifacts_dir.mkdir(parents=True, exist_ok=True) + + run_id = uuid.uuid4().hex[:12] + columns = canonical_feature_columns() + rng = np.random.default_rng(7) + n_rows = 200 + features = rng.normal(size=(n_rows, len(columns))) + # Strong, negative price_factor coefficient: price_factor < 1.0 (a cut) + # lifts demand. The signal dwarfs the 0.5-scale noise, so the model learns + # a clean, deterministic price response. 
+ price_index = columns.index("price_factor") + target = 40.0 - 20.0 * features[:, price_index] + rng.normal(scale=0.5, size=n_rows) + + model = RegressionForecaster(random_state=7) + model.fit(target.astype(np.float64), features.astype(np.float64)) + + history_start = date(2026, 4, 1) + bundle = ModelBundle( + model=model, + config=RegressionModelConfig(), + metadata={ + "store_id": TEST_STORE_ID, + "product_id": TEST_PRODUCT_ID, + "train_end_date": TEST_TRAIN_END_DATE, + "n_observations": n_rows, + "feature_columns": columns, + "history_tail": [12.0] * 90, + "history_tail_dates": [ + (history_start + timedelta(days=offset)).isoformat() for offset in range(90) + ], + "launch_date": "2025-01-01", + }, + ) + save_model_bundle(bundle, artifacts_dir / f"model_{run_id}") + + yield run_id + + (artifacts_dir / f"model_{run_id}.joblib").unlink(missing_ok=True) diff --git a/app/features/scenarios/tests/test_adjustments.py b/app/features/scenarios/tests/test_adjustments.py new file mode 100644 index 00000000..1cde5fe6 --- /dev/null +++ b/app/features/scenarios/tests/test_adjustments.py @@ -0,0 +1,206 @@ +"""Unit tests for the pure scenario adjustment engine. + +These run without a database (-m "not integration"): every function in +``adjustments.py`` is pure. The tests assert *direction and bounds* — a price +cut lifts demand, the clamp keeps a factor in band — not exact magnitudes, so +re-tuning the heuristic constants does not break them. 
+""" + +from datetime import date + +import pytest + +from app.features.scenarios import adjustments +from app.features.scenarios.schemas import ( + HolidayAssumption, + InventoryAssumption, + LifecycleAssumption, + PriceAssumption, + PromotionAssumption, + ScenarioAssumptions, +) + +# ============================================================================= +# clamp +# ============================================================================= + + +def test_clamp_inside_range_returns_value() -> None: + """A value already in range is returned unchanged.""" + assert adjustments.clamp(0.5, 0.1, 5.0) == 0.5 + + +def test_clamp_below_and_above() -> None: + """Values outside the range snap to the nearest bound.""" + assert adjustments.clamp(-3.0, 0.1, 5.0) == 0.1 + assert adjustments.clamp(99.0, 0.1, 5.0) == 5.0 + + +# ============================================================================= +# price_factor +# ============================================================================= + + +def test_price_cut_lifts_demand() -> None: + """A price cut (negative change) yields a factor above 1.""" + assert adjustments.price_factor(-0.15) > 1.0 + + +def test_price_rise_drags_demand() -> None: + """A price rise (positive change) yields a factor below 1.""" + assert adjustments.price_factor(0.20) < 1.0 + + +def test_price_factor_no_change_is_neutral() -> None: + """A zero price change is exactly neutral.""" + assert adjustments.price_factor(0.0) == 1.0 + + +def test_price_factor_tolerates_non_positive_price() -> None: + """A change of -100% or worse clamps to the upper band, never raises.""" + assert adjustments.price_factor(-1.0) == adjustments.FACTOR_BAND[1] + assert adjustments.price_factor(-5.0) == adjustments.FACTOR_BAND[1] + + +def test_price_factor_stays_in_band() -> None: + """The factor never escapes the clamp band for extreme inputs.""" + lo, hi = adjustments.FACTOR_BAND + for change in (-0.95, -0.5, 0.0, 1.0, 5.0): + assert lo <= 
adjustments.price_factor(change) <= hi + + +# ============================================================================= +# promotion_factor / holiday_factor / lifecycle_factor +# ============================================================================= + + +def test_promotion_factor_known_kinds_lift_demand() -> None: + """Every known promotion kind lifts demand when active.""" + for kind in ("pct_off", "bogo", "bundle", "markdown"): + assert adjustments.promotion_factor(kind, active=True) > 1.0 + + +def test_promotion_factor_inactive_or_unknown_is_neutral() -> None: + """An inactive promotion or an unknown kind is neutral.""" + assert adjustments.promotion_factor("pct_off", active=False) == 1.0 + assert adjustments.promotion_factor("mystery", active=True) == 1.0 + + +def test_holiday_factor() -> None: + """A holiday lifts demand; a non-holiday is neutral.""" + assert adjustments.holiday_factor(True) > 1.0 + assert adjustments.holiday_factor(False) == 1.0 + + +def test_lifecycle_factor_known_stages() -> None: + """Known lifecycle stages map to their documented multipliers.""" + assert adjustments.lifecycle_factor("launch") > 1.0 + assert adjustments.lifecycle_factor("maturity") == 1.0 + assert adjustments.lifecycle_factor("decline") < 1.0 + + +def test_lifecycle_factor_none_or_unknown_is_neutral() -> None: + """``None`` and an unknown stage are neutral, never an exception.""" + assert adjustments.lifecycle_factor(None) == 1.0 + assert adjustments.lifecycle_factor("zombie") == 1.0 + + +# ============================================================================= +# combined_daily_factor +# ============================================================================= + + +def test_combined_factor_empty_assumptions_is_neutral() -> None: + """An empty ScenarioAssumptions yields exactly 1.0 for any day.""" + assert adjustments.combined_daily_factor(date(2026, 6, 1), ScenarioAssumptions()) == 1.0 + + +def test_combined_factor_applies_price_inside_window() -> 
None: + """A price assumption applies only inside its date window.""" + assumptions = ScenarioAssumptions( + price=PriceAssumption( + change_pct=-0.20, start_date=date(2026, 6, 5), end_date=date(2026, 6, 10) + ) + ) + inside = adjustments.combined_daily_factor(date(2026, 6, 7), assumptions) + outside = adjustments.combined_daily_factor(date(2026, 6, 20), assumptions) + assert inside > 1.0 + assert outside == 1.0 + + +def test_combined_factor_stacks_promotion_and_holiday() -> None: + """Overlapping promotion and holiday assumptions compound multiplicatively.""" + day = date(2026, 6, 7) + assumptions = ScenarioAssumptions( + promotion=PromotionAssumption( + kind="bogo", start_date=date(2026, 6, 1), end_date=date(2026, 6, 30) + ), + holiday=HolidayAssumption(dates=[day]), + ) + assert adjustments.combined_daily_factor(day, assumptions) > 1.0 + + +def test_combined_factor_is_clamped() -> None: + """Even a stack of strong uplifts stays within the clamp band.""" + lo, hi = adjustments.FACTOR_BAND + assumptions = ScenarioAssumptions( + price=PriceAssumption( + change_pct=-0.9, start_date=date(2026, 6, 1), end_date=date(2026, 6, 30) + ), + promotion=PromotionAssumption( + kind="bogo", start_date=date(2026, 6, 1), end_date=date(2026, 6, 30) + ), + holiday=HolidayAssumption(dates=[date(2026, 6, 7)]), + lifecycle=LifecycleAssumption(stage="launch"), + ) + assert lo <= adjustments.combined_daily_factor(date(2026, 6, 7), assumptions) <= hi + + +# ============================================================================= +# apply_adjustment +# ============================================================================= + + +def test_apply_adjustment_element_wise() -> None: + """Each baseline value is multiplied by its matching factor.""" + assert adjustments.apply_adjustment([10.0, 20.0], [1.5, 0.5]) == [15.0, 10.0] + + +def test_apply_adjustment_floors_at_zero() -> None: + """A negative product is floored at 0.0 — demand is never negative.""" + assert 
adjustments.apply_adjustment([10.0], [-1.0]) == [0.0] + + +def test_apply_adjustment_length_mismatch_raises() -> None: + """A length mismatch is a caller-contract violation and raises ValueError.""" + with pytest.raises(ValueError, match="equal length"): + adjustments.apply_adjustment([1.0, 2.0], [1.0]) + + +# ============================================================================= +# coverage_verdict +# ============================================================================= + + +def test_coverage_verdict_unknown_without_inventory() -> None: + """No inventory assumption yields an 'unknown' verdict.""" + assert adjustments.coverage_verdict(100.0, None) == "unknown" + + +def test_coverage_verdict_covered_at_risk_stockout() -> None: + """Demand vs. on-hand stock maps to the three coverage bands.""" + assert adjustments.coverage_verdict(50.0, 100) == "covered" + assert adjustments.coverage_verdict(100.0, 100) == "at_risk" + assert adjustments.coverage_verdict(500.0, 100) == "stockout" + + +def test_coverage_verdict_zero_stock() -> None: + """Zero stock is a stockout under any demand, and at_risk under zero demand.""" + assert adjustments.coverage_verdict(10.0, 0) == "stockout" + assert adjustments.coverage_verdict(0.0, 0) == "at_risk" + + +def test_coverage_verdict_uses_inventory_assumption_field() -> None: + """The verdict reads on_hand_units straight off the assumption model.""" + inventory = InventoryAssumption(on_hand_units=100) + assert adjustments.coverage_verdict(50.0, inventory.on_hand_units) == "covered" diff --git a/app/features/scenarios/tests/test_agent_tools.py b/app/features/scenarios/tests/test_agent_tools.py new file mode 100644 index 00000000..5f1ba739 --- /dev/null +++ b/app/features/scenarios/tests/test_agent_tools.py @@ -0,0 +1,222 @@ +"""Tests for the scenarios agent tools and the HITL save gate (PRP-27 Phase D). + +Two layers: + +* A unit test that the ``save_scenario`` tool name is wired into + ``agent_require_approval`` — the mutation-surface guard. 
+* Integration tests (real PostgreSQL + a real model bundle, ``docker compose up + -d`` required) covering ``propose_scenario`` (read-only — persists nothing), + ``save_scenario`` (persists with agent provenance), and the HITL gate firing + through ``AgentService.approve_action``. +""" + +from __future__ import annotations + +import uuid +from datetime import UTC, datetime, timedelta + +import pytest +from sqlalchemy import delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import get_settings +from app.features.agents.agents.base import requires_approval +from app.features.agents.models import AgentSession, AgentType, SessionStatus +from app.features.agents.service import AgentService +from app.features.scenarios.agent_tools import propose_scenario, save_scenario +from app.features.scenarios.models import ScenarioPlan +from app.features.scenarios.schemas import SaveScenarioRequest, ScenarioAssumptions +from app.features.scenarios.tests.conftest import TEST_PRODUCT_ID, TEST_STORE_ID + + +def test_save_scenario_requires_approval() -> None: + """``save_scenario`` is in agent_require_approval — the mutation-surface gate.""" + assert "save_scenario" in get_settings().agent_require_approval + assert requires_approval("save_scenario") is True + + +@pytest.mark.integration +class TestProposeScenario: + """propose_scenario drafts a candidate and persists nothing.""" + + async def test_returns_valid_assumptions_and_recommendation( + self, db_session: AsyncSession + ) -> None: + """A default objective yields a valid price-cut candidate.""" + result = await propose_scenario( + db_session, + store_id=TEST_STORE_ID, + product_id=TEST_PRODUCT_ID, + horizon=14, + objective="grow demand for the summer range", + ) + + # The candidate assumptions round-trip through the real schema. 
+ assumptions = ScenarioAssumptions.model_validate(result["assumptions"]) + assert assumptions.price is not None + assert assumptions.price.change_pct < 0.0 + assert isinstance(result["recommendation"], str) + assert result["recommendation"] + + async def test_promotion_keyword_proposes_a_promotion(self, db_session: AsyncSession) -> None: + """An objective mentioning a promotion steers the candidate accordingly.""" + result = await propose_scenario( + db_session, + store_id=TEST_STORE_ID, + product_id=TEST_PRODUCT_ID, + horizon=7, + objective="run a promotion next week", + ) + + assumptions = ScenarioAssumptions.model_validate(result["assumptions"]) + assert assumptions.promotion is not None + assert assumptions.price is None + + async def test_persists_no_row(self, db_session: AsyncSession) -> None: + """propose_scenario is read-only — it never writes a scenario_plan row.""" + await propose_scenario( + db_session, + store_id=TEST_STORE_ID, + product_id=TEST_PRODUCT_ID, + horizon=10, + objective="test", + ) + count = await db_session.scalar(select(func.count()).select_from(ScenarioPlan)) + assert count == 0 + + +@pytest.mark.integration +class TestSaveScenario: + """save_scenario persists a plan stamped with agent provenance.""" + + async def test_persists_with_agent_provenance( + self, db_session: AsyncSession, trained_model: str + ) -> None: + """A save stamps source='agent', the session id, and the audit trail.""" + request = SaveScenarioRequest( + name="Saved by agent", + assumptions=ScenarioAssumptions(), + store_id=TEST_STORE_ID, + product_id=TEST_PRODUCT_ID, + horizon=7, + run_id=trained_model, + ) + + result = await save_scenario(db_session, request, agent_session_id="sess-xyz") + + assert result["source"] == "agent" + assert result["agent_session_id"] == "sess-xyz" + assert result["approved_by"] == "operator" + assert result["approval_decision"] == "approved" + assert result["approved_at"] is not None + + rows = (await 
db_session.execute(select(ScenarioPlan))).scalars().all() + assert len(rows) == 1 + assert rows[0].source == "agent" + assert rows[0].agent_session_id == "sess-xyz" + + +@pytest.mark.integration +class TestSaveScenarioHITLGate: + """The save_scenario HITL gate persists a row only once approved.""" + + @staticmethod + def _pending_save_action(session_id: str, run_id: str) -> dict[str, object]: + """Build a pending save_scenario action for the given session.""" + now = datetime.now(UTC) + return { + "action_id": "act-save-1", + "action_type": "save_scenario", + "description": "Save the proposed scenario", + "arguments": { + "name": "Agent-proposed plan", + "run_id": run_id, + "store_id": TEST_STORE_ID, + "product_id": TEST_PRODUCT_ID, + "horizon": 7, + "assumptions": {}, + "source": "agent", + "agent_session_id": session_id, + }, + "created_at": now.isoformat(), + "expires_at": (now + timedelta(minutes=5)).isoformat(), + } + + async def _seed_session(self, db_session: AsyncSession, session_id: str, run_id: str) -> None: + """Insert an experiment session awaiting a save_scenario approval.""" + now = datetime.now(UTC) + db_session.add( + AgentSession( + session_id=session_id, + agent_type=AgentType.EXPERIMENT.value, + status=SessionStatus.AWAITING_APPROVAL.value, + message_history=[], + pending_action=self._pending_save_action(session_id, run_id), + total_tokens_used=0, + tool_calls_count=1, + last_activity=now, + expires_at=now + timedelta(minutes=30), + ) + ) + await db_session.commit() + + async def test_approve_persists_agent_plan( + self, db_session: AsyncSession, trained_model: str + ) -> None: + """Approving the pending action persists a row with the audit trail.""" + session_id = uuid.uuid4().hex # session_id is VARCHAR(32) — hex is exactly 32 + await self._seed_session(db_session, session_id, trained_model) + try: + response = await AgentService().approve_action( + db=db_session, + session_id=session_id, + action_id="act-save-1", + approved=True, + ) + + 
assert response.status == "executed" + assert isinstance(response.result, dict) + assert response.result["source"] == "agent" + + rows = ( + ( + await db_session.execute( + select(ScenarioPlan).where(ScenarioPlan.agent_session_id == session_id) + ) + ) + .scalars() + .all() + ) + assert len(rows) == 1 + assert rows[0].source == "agent" + assert rows[0].approved_by == "operator" + assert rows[0].approval_decision == "approved" + assert rows[0].approved_at is not None + finally: + await db_session.execute( + delete(AgentSession).where(AgentSession.session_id == session_id) + ) + await db_session.commit() + + async def test_reject_persists_no_plan( + self, db_session: AsyncSession, trained_model: str + ) -> None: + """Rejecting the pending action writes no scenario_plan row.""" + session_id = uuid.uuid4().hex # session_id is VARCHAR(32) — hex is exactly 32 + await self._seed_session(db_session, session_id, trained_model) + try: + response = await AgentService().approve_action( + db=db_session, + session_id=session_id, + action_id="act-save-1", + approved=False, + ) + + assert response.status == "rejected" + count = await db_session.scalar(select(func.count()).select_from(ScenarioPlan)) + assert count == 0 + finally: + await db_session.execute( + delete(AgentSession).where(AgentSession.session_id == session_id) + ) + await db_session.commit() diff --git a/app/features/scenarios/tests/test_compare_integration.py b/app/features/scenarios/tests/test_compare_integration.py new file mode 100644 index 00000000..a2da7a61 --- /dev/null +++ b/app/features/scenarios/tests/test_compare_integration.py @@ -0,0 +1,139 @@ +"""Integration tests for the scenario library + multi-scenario comparison. + +PRP-27 Phase C. Runs against a real PostgreSQL database and a real model +bundle on disk. Requires ``docker compose up -d``. 
+""" + +from __future__ import annotations + +from typing import Any + +import pytest +from httpx import AsyncClient + +_PRICE_CUT: dict[str, object] = { + "price": {"change_pct": -0.15, "start_date": "2026-07-01", "end_date": "2026-07-14"}, +} +_PRICE_RISE: dict[str, object] = { + "price": {"change_pct": 0.20, "start_date": "2026-07-01", "end_date": "2026-07-14"}, +} + + +async def _create_plan( + client: AsyncClient, + run_id: str, + name: str, + assumptions: dict[str, object], + *, + tags: list[str] | None = None, + cloned_from: str | None = None, +) -> dict[str, Any]: + """Create a saved scenario plan and return its JSON body.""" + body: dict[str, object] = { + "name": name, + "run_id": run_id, + "horizon": 14, + "assumptions": assumptions, + } + if tags is not None: + body["tags"] = tags + if cloned_from is not None: + body["cloned_from"] = cloned_from + response = await client.post("/scenarios", json=body) + assert response.status_code == 201, response.text + created: dict[str, Any] = response.json() + return created + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestCompareScenarios: + """Integration tests for POST /scenarios/compare.""" + + async def test_compare_ranks_plans(self, client: AsyncClient, trained_model: str) -> None: + """Comparing a price cut and a price rise ranks the cut first.""" + cut = await _create_plan(client, trained_model, "Cut", _PRICE_CUT) + rise = await _create_plan(client, trained_model, "Rise", _PRICE_RISE) + + response = await client.post( + "/scenarios/compare", + json={ + "scenario_ids": [rise["scenario_id"], cut["scenario_id"]], + "rank_by": "revenue_delta", + }, + ) + assert response.status_code == 200 + data = response.json() + + assert data["rank_by"] == "revenue_delta" + assert len(data["scenarios"]) == 2 + # A price cut lifts revenue, so it outranks the price rise. 
+ assert data["scenarios"][0]["name"] == "Cut" + assert data["scenarios"][0]["rank"] == 1 + assert data["scenarios"][1]["rank"] == 2 + assert data["chart_series"], "chart_series must carry merged date rows" + assert "baseline" in data["chart_series"][0] + + async def test_compare_too_few_returns_422(self, client: AsyncClient) -> None: + """Fewer than 2 scenario_ids is rejected at the schema boundary.""" + response = await client.post("/scenarios/compare", json={"scenario_ids": ["only-one"]}) + assert response.status_code == 422 + + async def test_compare_too_many_returns_422(self, client: AsyncClient) -> None: + """More than MAX_COMPARE_SCENARIOS (5) scenario_ids is rejected.""" + response = await client.post( + "/scenarios/compare", + json={"scenario_ids": [f"id-{index}" for index in range(6)]}, + ) + assert response.status_code == 422 + + async def test_compare_bogus_id_returns_404( + self, client: AsyncClient, trained_model: str + ) -> None: + """An unknown scenario_id returns an RFC 7807 404 — never a 500.""" + plan = await _create_plan(client, trained_model, "Real", _PRICE_CUT) + response = await client.post( + "/scenarios/compare", + json={"scenario_ids": [plan["scenario_id"], "does-not-exist-xyz"]}, + ) + assert response.status_code == 404 + assert "application/problem+json" in response.headers.get("content-type", "") + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestScenarioLibrary: + """Integration tests for tag filtering and plan cloning.""" + + async def test_create_with_tags_and_filter( + self, client: AsyncClient, trained_model: str + ) -> None: + """A tag filter returns only plans carrying every listed tag.""" + await _create_plan(client, trained_model, "Tagged", _PRICE_CUT, tags=["q3", "promo"]) + await _create_plan(client, trained_model, "Untagged", _PRICE_CUT, tags=[]) + + response = await client.get("/scenarios", params={"tags": ["q3"]}) + assert response.status_code == 200 + data = response.json() + + names = {item["name"] for item 
in data["scenarios"]} + assert "Tagged" in names + assert "Untagged" not in names + for item in data["scenarios"]: + assert "q3" in item["tags"] + + async def test_clone_records_cloned_from(self, client: AsyncClient, trained_model: str) -> None: + """A plan created with cloned_from records its origin.""" + original = await _create_plan(client, trained_model, "Original", _PRICE_CUT) + clone = await _create_plan( + client, + trained_model, + "Clone of original", + _PRICE_CUT, + cloned_from=original["scenario_id"], + ) + assert clone["cloned_from"] == original["scenario_id"] + + fetched = await client.get(f"/scenarios/{clone['scenario_id']}") + assert fetched.status_code == 200 + assert fetched.json()["cloned_from"] == original["scenario_id"] diff --git a/app/features/scenarios/tests/test_feature_frame.py b/app/features/scenarios/tests/test_feature_frame.py new file mode 100644 index 00000000..22306eb1 --- /dev/null +++ b/app/features/scenarios/tests/test_feature_frame.py @@ -0,0 +1,214 @@ +"""Unit tests for the future feature-frame generator (PRP-27 Phase A). + +These exercise the pure builders — calendar columns, target long-lag columns, +assumption-driven exogenous columns, and the :func:`assemble_future_frame` +orchestration. The leakage invariants live separately in +``test_future_frame_leakage.py`` (the load-bearing spec). 
+""" + +from __future__ import annotations + +import math +from datetime import date, timedelta + +from app.features.scenarios.feature_frame import ( + CALENDAR_COLUMNS, + EXOGENOUS_COLUMNS, + EXOGENOUS_LAGS, + HISTORY_TAIL_DAYS, + MAX_COMPARE_SCENARIOS, + assemble_future_frame, + build_calendar_columns, + build_exogenous_columns, + build_long_lag_columns, + canonical_feature_columns, +) +from app.features.scenarios.schemas import ( + HolidayAssumption, + PriceAssumption, + PromotionAssumption, + ScenarioAssumptions, +) + +_ORIGIN = date(2026, 6, 30) +_HORIZON = 14 +_HORIZON_DATES = [_ORIGIN + timedelta(days=offset) for offset in range(1, _HORIZON + 1)] + + +# --- pinned constants --------------------------------------------------------- + + +def test_pinned_constants() -> None: + """The PRP-27 pinned modelling constants hold their decided values.""" + assert EXOGENOUS_LAGS == (1, 7, 14, 28) + assert HISTORY_TAIL_DAYS == 90 + assert MAX_COMPARE_SCENARIOS == 5 + + +def test_canonical_feature_columns_order() -> None: + """The canonical column list is target lags, then calendar, then exogenous.""" + columns = canonical_feature_columns() + assert columns[:4] == ["lag_1", "lag_7", "lag_14", "lag_28"] + assert columns[4 : 4 + len(CALENDAR_COLUMNS)] == list(CALENDAR_COLUMNS) + assert columns[-len(EXOGENOUS_COLUMNS) :] == list(EXOGENOUS_COLUMNS) + assert len(columns) == len(EXOGENOUS_LAGS) + len(CALENDAR_COLUMNS) + len(EXOGENOUS_COLUMNS) + + +# --- calendar columns --------------------------------------------------------- + + +def test_calendar_columns_are_pure_function_of_date() -> None: + """Calendar columns depend only on the dates — two calls match exactly.""" + first = build_calendar_columns(_HORIZON_DATES) + second = build_calendar_columns(list(_HORIZON_DATES)) + assert first == second + assert set(first) == set(CALENDAR_COLUMNS) + for values in first.values(): + assert len(values) == _HORIZON + + +def test_calendar_is_weekend_and_month_end() -> None: + 
"""``is_weekend`` and ``is_month_end`` reflect the date itself.""" + dates = [date(2026, 7, 30), date(2026, 7, 31), date(2026, 8, 1), date(2026, 8, 2)] + columns = build_calendar_columns(dates) + # 2026-07-31 is the month end; 2026-08-01 (Sat) and 2026-08-02 (Sun) weekend. + assert columns["is_month_end"] == [0.0, 1.0, 0.0, 0.0] + assert columns["is_weekend"] == [1.0 if d.weekday() >= 5 else 0.0 for d in dates] + + +def test_calendar_cyclical_encoding_bounded() -> None: + """Cyclical sin/cos encodings stay within [-1, 1].""" + columns = build_calendar_columns(_HORIZON_DATES) + for name in ("dow_sin", "dow_cos", "month_sin", "month_cos"): + assert all(-1.0 <= value <= 1.0 for value in columns[name]) + + +# --- long-lag columns --------------------------------------------------------- + + +def test_long_lag_indexing_is_correct() -> None: + """``lag_k`` at horizon day ``j`` equals the observed ``y[T+j-k]``.""" + # history_tail[-1] == y[T], history_tail[-2] == y[T-1], ... + history_tail = [float(value) for value in range(HISTORY_TAIL_DAYS)] + columns = build_long_lag_columns(history_tail, _HORIZON) + + assert set(columns) == {f"lag_{k}" for k in EXOGENOUS_LAGS} + # lag_1 at j=1 reads history_tail[-1] (y[T]); j>=2 needs a future target. + assert columns["lag_1"][0] == history_tail[-1] + assert all(math.isnan(value) for value in columns["lag_1"][1:]) + # lag_7 at j=1 reads history_tail[-7]; populated for j in 1..7. 
+ assert columns["lag_7"][0] == history_tail[-7] + for j in range(1, _HORIZON + 1): + cell = columns["lag_7"][j - 1] + if j <= 7: + assert cell == history_tail[(j - 1) - 7] + else: + assert math.isnan(cell) + + +def test_long_lag_all_columns_present_for_long_horizon() -> None: + """Every lag column is emitted even when the horizon exceeds the offset.""" + history_tail = [float(value) for value in range(HISTORY_TAIL_DAYS)] + columns = build_long_lag_columns(history_tail, horizon=40) + assert set(columns) == {f"lag_{k}" for k in EXOGENOUS_LAGS} + # lag_28 over a 40-day horizon: first 28 days populated, last 12 NaN. + assert all(not math.isnan(v) for v in columns["lag_28"][:28]) + assert all(math.isnan(v) for v in columns["lag_28"][28:]) + + +def test_long_lag_short_history_yields_nan() -> None: + """A history shorter than the lag offset produces NaN, never an error.""" + columns = build_long_lag_columns([5.0, 6.0, 7.0], horizon=4) + # lag_28 cannot resolve from a 3-element tail — all NaN. + assert all(math.isnan(value) for value in columns["lag_28"]) + # lag_1 at j=1 still resolves to the origin observation. 
+ assert columns["lag_1"][0] == 7.0 + + +# --- exogenous columns -------------------------------------------------------- + + +def test_exogenous_price_window() -> None: + """``price_factor`` is ``1 + change_pct`` inside the window, ``1.0`` outside.""" + assumptions = ScenarioAssumptions( + price=PriceAssumption( + change_pct=-0.15, + start_date=_HORIZON_DATES[2], + end_date=_HORIZON_DATES[5], + ) + ) + columns = build_exogenous_columns(_HORIZON_DATES, assumptions, set(), launch_date=None) + for index, value in enumerate(columns["price_factor"]): + if 2 <= index <= 5: + assert value == 0.85 + else: + assert value == 1.0 + + +def test_exogenous_promo_and_holiday() -> None: + """``promo_active`` flags the promotion window; ``is_holiday`` unions sources.""" + calendar_holiday = _HORIZON_DATES[0] + assumption_holiday = _HORIZON_DATES[9] + assumptions = ScenarioAssumptions( + promotion=PromotionAssumption( + kind="pct_off", + start_date=_HORIZON_DATES[1], + end_date=_HORIZON_DATES[3], + ), + holiday=HolidayAssumption(dates=[assumption_holiday]), + ) + columns = build_exogenous_columns( + _HORIZON_DATES, assumptions, {calendar_holiday}, launch_date=None + ) + assert [i for i, v in enumerate(columns["promo_active"]) if v == 1.0] == [1, 2, 3] + # is_holiday unions the calendar holiday and the assumption holiday. 
+ assert [i for i, v in enumerate(columns["is_holiday"]) if v == 1.0] == [0, 9] + + +def test_exogenous_days_since_launch() -> None: + """``days_since_launch`` is a pure date delta, NaN without a launch date.""" + launch = _ORIGIN - timedelta(days=100) + with_launch = build_exogenous_columns( + _HORIZON_DATES, ScenarioAssumptions(), set(), launch_date=launch + ) + assert with_launch["days_since_launch"][0] == float((_HORIZON_DATES[0] - launch).days) + without_launch = build_exogenous_columns( + _HORIZON_DATES, ScenarioAssumptions(), set(), launch_date=None + ) + assert all(math.isnan(value) for value in without_launch["days_since_launch"]) + + +# --- assembly ----------------------------------------------------------------- + + +def test_assemble_future_frame_shape_and_order() -> None: + """The assembled matrix matches ``feature_columns`` in width and order.""" + columns = canonical_feature_columns() + history_tail = [float(value) for value in range(HISTORY_TAIL_DAYS)] + frame = assemble_future_frame( + dates=_HORIZON_DATES, + feature_columns=columns, + history_tail=history_tail, + assumptions=ScenarioAssumptions(), + holiday_dates=set(), + launch_date=None, + ) + assert frame.feature_columns == columns + assert frame.dates == _HORIZON_DATES + assert len(frame.matrix) == _HORIZON + assert all(len(row) == len(columns) for row in frame.matrix) + + +def test_assemble_future_frame_unknown_column_is_nan() -> None: + """A requested column the builders do not produce becomes an all-NaN column.""" + columns = [*canonical_feature_columns(), "mystery_feature"] + frame = assemble_future_frame( + dates=_HORIZON_DATES, + feature_columns=columns, + history_tail=[float(v) for v in range(HISTORY_TAIL_DAYS)], + assumptions=ScenarioAssumptions(), + holiday_dates=set(), + launch_date=None, + ) + mystery_index = columns.index("mystery_feature") + assert all(math.isnan(row[mystery_index]) for row in frame.matrix) diff --git a/app/features/scenarios/tests/test_future_frame_leakage.py 
b/app/features/scenarios/tests/test_future_frame_leakage.py new file mode 100644 index 00000000..4a4659de --- /dev/null +++ b/app/features/scenarios/tests/test_future_frame_leakage.py @@ -0,0 +1,166 @@ +"""Leakage spec for the future feature frame — LOAD-BEARING (PRP-27 Phase A). + +This file IS the spec, mirroring ``app/features/featuresets/tests/test_leakage.py`` +and ``app/features/scenarios/tests/test_leakage.py``: it must NEVER be weakened +to make a feature pass (AGENTS.md § Safety). + +The model-driven scenario path re-forecasts demand through a feature-consuming +regressor, which means it builds a *future feature frame*. A horizon day has no +observed target, so the invariant is: + + A future feature value for horizon day ``D`` may use ONLY information + knowable at the forecast origin ``T``: the observed history up to and + including ``T``, the calendar (a pure function of the date), or the + scenario assumptions (the planner's posited future inputs). It may NEVER + read an observed target at a horizon day ``D`` (which lies after ``T``). + +Concretely this spec asserts: + +1. ``build_long_lag_columns`` returns only values drawn from ``history_tail`` + (entirely ``<= T``) or ``NaN`` — never a value from the future target + series. +2. A lag cell whose source day lies at or after the first horizon day is + ``NaN`` — the generator never fabricates or recursively predicts it. +3. Calendar columns are independent of the target series entirely. +4. An assumption window that falls before the forecast origin contributes + nothing — every horizon day lies strictly after ``T``. +5. Every non-``NaN`` ``lag_*`` cell in an assembled frame is a member of + ``history_tail``. 
+""" + +from __future__ import annotations + +import math +from datetime import date, timedelta + +from app.features.scenarios.feature_frame import ( + EXOGENOUS_LAGS, + assemble_future_frame, + build_calendar_columns, + build_exogenous_columns, + build_long_lag_columns, + canonical_feature_columns, +) +from app.features.scenarios.schemas import PriceAssumption, ScenarioAssumptions + +# The forecast origin T is the last observed day; the horizon runs T+1 … T+H. +_ORIGIN = date(2026, 6, 30) +_HORIZON = 21 +_HORIZON_DATES = [_ORIGIN + timedelta(days=offset) for offset in range(1, _HORIZON + 1)] + +# Observed history (all <= T): 90 distinct values 1000.0 … 1089.0. +# history_tail[-1] == y[T], the origin observation. +_HISTORY_TAIL = [1000.0 + float(i) for i in range(90)] +# A DISJOINT "future target" series the generator must never be able to read. +# Any of these values appearing in a feature cell is a leak. +_FUTURE_TARGETS = {9000.0 + float(i) for i in range(_HORIZON)} + + +def test_long_lag_columns_never_emit_a_future_target() -> None: + """Every non-NaN long-lag cell is drawn from the observed history. + + ``build_long_lag_columns`` takes ONLY ``history_tail`` as data input — it + is structurally incapable of reading the future target series. This spec + pins that: no value disjoint from ``history_tail`` may ever appear. + """ + history_values = set(_HISTORY_TAIL) + columns = build_long_lag_columns(_HISTORY_TAIL, _HORIZON) + + for name, values in columns.items(): + for cell in values: + if math.isnan(cell): + continue + assert cell in history_values, ( + f"{name} emitted {cell}, which is not an observed history value" + ) + assert cell not in _FUTURE_TARGETS, f"{name} leaked a future target value {cell}" + + +def test_long_lag_source_index_is_never_at_or_after_the_horizon() -> None: + """A lag cell is populated only when its source day lies at/before ``T``. + + For lag ``k`` and horizon day ``j`` the source index into ``history_tail`` + is ``(j-1)-k``. 
A non-NaN cell REQUIRES that index to be negative — i.e. + the source target lies at or before the origin ``T``. A non-negative index + would point at a future horizon day and MUST yield ``NaN``. + """ + columns = build_long_lag_columns(_HISTORY_TAIL, _HORIZON) + for lag in EXOGENOUS_LAGS: + column = columns[f"lag_{lag}"] + for j in range(1, _HORIZON + 1): + source_index = (j - 1) - lag + cell = column[j - 1] + if source_index >= 0: + assert math.isnan(cell), ( + f"lag_{lag} day {j}: source index {source_index} is in the " + "future but the cell is not NaN" + ) + else: + assert not math.isnan(cell), ( + f"lag_{lag} day {j}: source index {source_index} is in " + "history but the cell is NaN" + ) + + +def test_calendar_columns_are_independent_of_the_target_series() -> None: + """Calendar columns read only the dates — they cannot leak the target. + + ``build_calendar_columns`` does not accept the target series at all; this + spec pins that structural fact by asserting its output is identical no + matter what history precedes it. + """ + calendar_a = build_calendar_columns(_HORIZON_DATES) + calendar_b = build_calendar_columns(_HORIZON_DATES) + assert calendar_a == calendar_b + # No calendar value coincides with a history or future target value. + history_values = set(_HISTORY_TAIL) + for values in calendar_a.values(): + for cell in values: + assert cell not in history_values + assert cell not in _FUTURE_TARGETS + + +def test_assumption_window_before_origin_has_no_effect() -> None: + """A price window entirely before the forecast origin contributes nothing. + + Every horizon day lies strictly after ``T``; a window that ends on or + before ``T`` can never intersect the horizon, so ``price_factor`` stays + neutral (``1.0``) for every day — the assumption cannot reach into history. 
+ """ + past_window = ScenarioAssumptions( + price=PriceAssumption( + change_pct=-0.40, + start_date=_ORIGIN - timedelta(days=30), + end_date=_ORIGIN, + ) + ) + columns = build_exogenous_columns(_HORIZON_DATES, past_window, set(), launch_date=None) + assert columns["price_factor"] == [1.0] * _HORIZON, ( + "a price window ending at/before the origin must not move price_factor" + ) + + +def test_assembled_frame_lag_cells_are_history_or_nan() -> None: + """Every non-NaN ``lag_*`` cell in an assembled frame is an observed value. + + This is the end-to-end leakage assertion: assemble a full frame and verify + every target-lag column still only ever shows a history value or ``NaN``. + """ + columns = canonical_feature_columns() + frame = assemble_future_frame( + dates=_HORIZON_DATES, + feature_columns=columns, + history_tail=_HISTORY_TAIL, + assumptions=ScenarioAssumptions(), + holiday_dates=set(), + launch_date=None, + ) + history_values = set(_HISTORY_TAIL) + lag_indices = {columns.index(f"lag_{k}") for k in EXOGENOUS_LAGS} + for row in frame.matrix: + for col_index in lag_indices: + cell = row[col_index] + if math.isnan(cell): + continue + assert cell in history_values, f"assembled frame leaked non-history value {cell}" + assert cell not in _FUTURE_TARGETS, f"assembled frame leaked future target {cell}" diff --git a/app/features/scenarios/tests/test_leakage.py b/app/features/scenarios/tests/test_leakage.py new file mode 100644 index 00000000..256947b7 --- /dev/null +++ b/app/features/scenarios/tests/test_leakage.py @@ -0,0 +1,94 @@ +"""Leakage spec for scenario simulation — LOAD-BEARING. + +This file IS the spec, mirroring the precedent of +``app/features/featuresets/tests/test_leakage.py``: it must NEVER be weakened +to make a feature pass (AGENTS.md § Safety). + +The invariant: a scenario adjustment touches ONLY horizon (future) points. 
It +applies a deterministic post-forecast multiplier to the baseline forecast and +can never reach back into, read, or mutate the historical target series. + +Concretely this spec asserts: + +1. ``apply_adjustment`` returns a NEW list and never mutates its ``baseline`` + input, and the adjusted series has exactly ``horizon`` points. +2. An assumption window that falls entirely BEFORE the forecast start + contributes factor ``1.0`` to every horizon day — it cannot affect history, + and it cannot affect the future either. +3. A day outside any assumption window contributes factor ``1.0``. +4. An empty ``ScenarioAssumptions`` leaves the baseline exactly unchanged. +""" + +from datetime import date, timedelta + +from app.features.scenarios import adjustments +from app.features.scenarios.schemas import PriceAssumption, ScenarioAssumptions + +# A deterministic forecast horizon used throughout this spec. +_FORECAST_START = date(2026, 7, 1) +_HORIZON = 14 +_HORIZON_DATES = [_FORECAST_START + timedelta(days=offset) for offset in range(_HORIZON)] + + +def test_apply_adjustment_does_not_mutate_baseline() -> None: + """``apply_adjustment`` returns a new list; the input baseline is untouched.""" + baseline = [10.0] * _HORIZON + baseline_snapshot = list(baseline) + factors = [1.5] * _HORIZON + + adjusted = adjustments.apply_adjustment(baseline, factors) + + assert adjusted is not baseline, "apply_adjustment must return a NEW list" + assert baseline == baseline_snapshot, "the input baseline must never be mutated" + assert len(adjusted) == _HORIZON, "the adjusted series must keep the horizon length" + + +def test_assumption_window_before_forecast_start_has_no_effect() -> None: + """A price window entirely before the forecast start contributes no factor. + + The window 2026-06-01 .. 2026-06-15 ends before the forecast starts on + 2026-07-01. Every horizon day must therefore receive factor 1.0 — the + adjustment can never reach a date outside the future horizon. 
+ """ + past_window = ScenarioAssumptions( + price=PriceAssumption( + change_pct=-0.30, + start_date=date(2026, 6, 1), + end_date=date(2026, 6, 15), + ) + ) + factors = [ + adjustments.combined_daily_factor(point_date, past_window) for point_date in _HORIZON_DATES + ] + assert factors == [1.0] * _HORIZON, "a pre-forecast window must not affect the horizon" + + +def test_out_of_window_days_contribute_unit_factor() -> None: + """Only days inside the assumption window are adjusted; the rest stay 1.0. + + The window covers exactly the first three horizon days; days 4..14 must be + untouched (factor 1.0). + """ + windowed = ScenarioAssumptions( + price=PriceAssumption( + change_pct=-0.25, + start_date=_HORIZON_DATES[0], + end_date=_HORIZON_DATES[2], + ) + ) + factors = [ + adjustments.combined_daily_factor(point_date, windowed) for point_date in _HORIZON_DATES + ] + assert all(factor > 1.0 for factor in factors[:3]), "in-window days must be adjusted" + assert factors[3:] == [1.0] * (_HORIZON - 3), "out-of-window days must stay neutral" + + +def test_empty_assumptions_leave_baseline_unchanged() -> None: + """With no assumptions the scenario series equals the baseline exactly.""" + baseline = [float(value) for value in range(1, _HORIZON + 1)] + factors = [ + adjustments.combined_daily_factor(point_date, ScenarioAssumptions()) + for point_date in _HORIZON_DATES + ] + scenario = adjustments.apply_adjustment(baseline, factors) + assert scenario == baseline, "an empty scenario must not move the baseline" diff --git a/app/features/scenarios/tests/test_routes_integration.py b/app/features/scenarios/tests/test_routes_integration.py new file mode 100644 index 00000000..f1d16976 --- /dev/null +++ b/app/features/scenarios/tests/test_routes_integration.py @@ -0,0 +1,260 @@ +"""Integration tests for the scenarios routes. 
+ +Runs against a real PostgreSQL database and a real model bundle on disk — the +full path from HTTP request through artifact resolution, forecast, adjustment, +and persistence. Requires ``docker compose up -d``. +""" + +import uuid + +import pytest +from httpx import AsyncClient +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.features.scenarios.models import ScenarioPlan + +# A price window covering the test bundle's 14-day horizon (train_end 2026-06-30). +_PRICE_ASSUMPTION = { + "price": {"change_pct": -0.15, "start_date": "2026-07-01", "end_date": "2026-07-14"}, +} + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestSimulate: + """Integration tests for POST /scenarios/simulate.""" + + async def test_simulate_happy_path(self, client: AsyncClient, trained_model: str) -> None: + """A price-cut simulation returns a full, well-formed comparison.""" + response = await client.post( + "/scenarios/simulate", + json={"run_id": trained_model, "horizon": 14, "assumptions": _PRICE_ASSUMPTION}, + ) + + assert response.status_code == 200 + data = response.json() + + assert len(data["points"]) == 14 + assert data["horizon"] == 14 + assert data["method"] == "heuristic" + assert data["disclaimer"], "every comparison must carry a non-empty disclaimer" + # A price cut lifts demand — the scenario total must exceed the baseline. 
+ assert data["units_delta"] > 0.0 + assert data["scenario_total_units"] > data["baseline_total_units"] + for point in data["points"]: + assert point["applied_factor"] > 1.0 + + async def test_simulate_empty_assumptions_equals_baseline( + self, client: AsyncClient, trained_model: str + ) -> None: + """An empty ScenarioAssumptions yields scenario == baseline, all deltas 0.""" + response = await client.post( + "/scenarios/simulate", + json={"run_id": trained_model, "horizon": 10, "assumptions": {}}, + ) + + assert response.status_code == 200 + data = response.json() + + assert data["units_delta"] == 0.0 + assert data["revenue_delta"] == 0.0 + assert data["coverage_verdict"] == "unknown" + for point in data["points"]: + assert point["delta"] == 0.0 + assert point["applied_factor"] == 1.0 + + async def test_simulate_bogus_run_id_returns_404(self, client: AsyncClient) -> None: + """A run_id with no artifact returns an RFC 7807 404 — never a 500.""" + response = await client.post( + "/scenarios/simulate", + json={"run_id": "does-not-exist-999", "horizon": 14, "assumptions": {}}, + ) + + assert response.status_code == 404 + assert response.status_code != 500 + assert "application/problem+json" in response.headers.get("content-type", "") + + async def test_simulate_invalid_horizon_rejected(self, client: AsyncClient) -> None: + """horizon below the ge=1 bound returns 422.""" + response = await client.post( + "/scenarios/simulate", + json={"run_id": "anything", "horizon": 0, "assumptions": {}}, + ) + assert response.status_code == 422 + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestScenarioPlanCrud: + """Integration tests for the scenario_plan CRUD endpoints.""" + + async def test_crud_round_trip(self, client: AsyncClient, trained_model: str) -> None: + """A plan can be created, listed, fetched, and deleted.""" + create = await client.post( + "/scenarios", + json={ + "name": "Summer price cut", + "run_id": trained_model, + "horizon": 14, + "assumptions": 
_PRICE_ASSUMPTION, + }, + ) + assert create.status_code == 201 + plan = create.json() + scenario_id = plan["scenario_id"] + assert plan["name"] == "Summer price cut" + assert plan["method"] == "heuristic" + assert len(plan["comparison"]["points"]) == 14 + + listed = await client.get("/scenarios") + assert listed.status_code == 200 + list_data = listed.json() + assert list_data["total"] >= 1 + assert scenario_id in {item["scenario_id"] for item in list_data["scenarios"]} + + fetched = await client.get(f"/scenarios/{scenario_id}") + assert fetched.status_code == 200 + assert fetched.json()["comparison"]["units_delta"] > 0.0 + + deleted = await client.delete(f"/scenarios/{scenario_id}") + assert deleted.status_code == 204 + + missing = await client.get(f"/scenarios/{scenario_id}") + assert missing.status_code == 404 + + async def test_list_scenarios_empty_is_200(self, client: AsyncClient) -> None: + """GET /scenarios returns 200 + an empty list, never 404.""" + response = await client.get("/scenarios") + assert response.status_code == 200 + data = response.json() + assert isinstance(data["scenarios"], list) + assert data["total"] >= 0 + + async def test_get_missing_plan_returns_404(self, client: AsyncClient) -> None: + """Fetching an unknown scenario_id returns 404.""" + response = await client.get(f"/scenarios/{uuid.uuid4().hex}") + assert response.status_code == 404 + + async def test_delete_missing_plan_returns_404(self, client: AsyncClient) -> None: + """Deleting an unknown scenario_id returns 404.""" + response = await client.delete(f"/scenarios/{uuid.uuid4().hex}") + assert response.status_code == 404 + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestSimulateModelExogenous: + """Integration tests for the model-driven (regression) simulate path.""" + + async def test_regression_baseline_returns_model_exogenous( + self, client: AsyncClient, trained_regression_model: str + ) -> None: + """A regression baseline re-forecasts — method is 
'model_exogenous'.""" + response = await client.post( + "/scenarios/simulate", + json={ + "run_id": trained_regression_model, + "horizon": 14, + "assumptions": _PRICE_ASSUMPTION, + }, + ) + assert response.status_code == 200 + data = response.json() + + assert data["method"] == "model_exogenous" + assert data["disclaimer"], "every comparison must carry a non-empty disclaimer" + assert len(data["points"]) == 14 + # A price cut moves the re-forecast — the deltas are model-driven, not + # a fixed multiplier, and the modelled price response lifts demand. + assert data["units_delta"] > 0.0 + + async def test_regression_empty_assumptions_equals_baseline( + self, client: AsyncClient, trained_regression_model: str + ) -> None: + """With no assumptions the model re-forecasts to exactly the baseline.""" + response = await client.post( + "/scenarios/simulate", + json={"run_id": trained_regression_model, "horizon": 10, "assumptions": {}}, + ) + assert response.status_code == 200 + data = response.json() + + assert data["method"] == "model_exogenous" + assert data["units_delta"] == 0.0 + for point in data["points"]: + assert point["delta"] == 0.0 + + async def test_baseline_forecaster_still_heuristic( + self, client: AsyncClient, trained_model: str + ) -> None: + """A naive baseline still produces a heuristic comparison — unchanged.""" + response = await client.post( + "/scenarios/simulate", + json={"run_id": trained_model, "horizon": 14, "assumptions": _PRICE_ASSUMPTION}, + ) + assert response.status_code == 200 + assert response.json()["method"] == "heuristic" + + async def test_model_exogenous_plan_persists( + self, client: AsyncClient, trained_regression_model: str + ) -> None: + """A model_exogenous comparison saves cleanly — the widened CHECK accepts it.""" + create = await client.post( + "/scenarios", + json={ + "name": "Model-driven price cut", + "run_id": trained_regression_model, + "horizon": 14, + "assumptions": _PRICE_ASSUMPTION, + }, + ) + assert create.status_code 
== 201 + plan = create.json() + assert plan["method"] == "model_exogenous" + + fetched = await client.get(f"/scenarios/{plan['scenario_id']}") + assert fetched.status_code == 200 + assert fetched.json()["comparison"]["method"] == "model_exogenous" + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestScenarioPlanModel: + """Constraint tests for the ScenarioPlan ORM model.""" + + async def test_method_check_rejects_unknown_value(self, db_session: AsyncSession) -> None: + """The method CHECK constraint rejects a value outside the allow-list.""" + plan = ScenarioPlan( + scenario_id=uuid.uuid4().hex, + name="bad method", + store_id=1, + product_id=1, + run_id="abc", + horizon=7, + assumptions={}, + comparison={}, + method="not_a_method", + ) + db_session.add(plan) + with pytest.raises(IntegrityError): + await db_session.commit() + await db_session.rollback() + + async def test_method_check_accepts_model_exogenous(self, db_session: AsyncSession) -> None: + """The widened CHECK constraint accepts the 'model_exogenous' method.""" + plan = ScenarioPlan( + scenario_id=uuid.uuid4().hex, + name="model exogenous plan", + store_id=1, + product_id=1, + run_id="abc", + horizon=7, + assumptions={}, + comparison={}, + method="model_exogenous", + ) + db_session.add(plan) + await db_session.commit() + await db_session.refresh(plan) + assert plan.method == "model_exogenous" diff --git a/app/features/scenarios/tests/test_schemas.py b/app/features/scenarios/tests/test_schemas.py new file mode 100644 index 00000000..47339b0e --- /dev/null +++ b/app/features/scenarios/tests/test_schemas.py @@ -0,0 +1,79 @@ +"""Unit tests for the scenario request / response schemas. + +The critical case exercises the FastAPI ``validate_python`` path — calling +``model_validate`` on a dict with ISO-string dates — to prove the +``Field(strict=False)`` overrides on every ``date`` field hold. Without them +every HTTP caller would 422 (see ``docs/_base/SECURITY.md``). 
+""" + +from datetime import date + +import pytest +from pydantic import ValidationError + +from app.features.scenarios.schemas import ( + CreateScenarioRequest, + PriceAssumption, + ScenarioAssumptions, + SimulateScenarioRequest, +) + + +def test_simulate_request_accepts_iso_string_dates() -> None: + """A JSON-shaped dict with ISO-string dates validates (validate_python path).""" + request = SimulateScenarioRequest.model_validate( + { + "run_id": "abc123def456", + "horizon": 14, + "assumptions": { + "price": { + "change_pct": -0.15, + "start_date": "2026-06-01", + "end_date": "2026-06-14", + }, + "holiday": {"dates": ["2026-06-07", "2026-06-08"]}, + }, + } + ) + assert request.assumptions.price is not None + assert request.assumptions.price.start_date == date(2026, 6, 1) + assert request.assumptions.holiday is not None + assert request.assumptions.holiday.dates == [date(2026, 6, 7), date(2026, 6, 8)] + + +def test_simulate_request_defaults_to_empty_assumptions() -> None: + """Omitting ``assumptions`` yields an empty (no-change) ScenarioAssumptions.""" + request = SimulateScenarioRequest.model_validate({"run_id": "abc", "horizon": 7}) + assert isinstance(request.assumptions, ScenarioAssumptions) + assert request.assumptions.price is None + assert request.assumptions.promotion is None + + +def test_price_assumption_change_pct_bounds() -> None: + """change_pct outside [-0.9, 5.0] is rejected.""" + with pytest.raises(ValidationError): + PriceAssumption.model_validate( + {"change_pct": -1.5, "start_date": "2026-06-01", "end_date": "2026-06-14"} + ) + with pytest.raises(ValidationError): + PriceAssumption.model_validate( + {"change_pct": 9.0, "start_date": "2026-06-01", "end_date": "2026-06-14"} + ) + + +def test_simulate_request_horizon_bounds() -> None: + """horizon must be within 1..90.""" + with pytest.raises(ValidationError): + SimulateScenarioRequest.model_validate({"run_id": "abc", "horizon": 0}) + with pytest.raises(ValidationError): + 
SimulateScenarioRequest.model_validate({"run_id": "abc", "horizon": 200}) + + +def test_create_request_requires_name() -> None: + """CreateScenarioRequest requires a non-empty name.""" + with pytest.raises(ValidationError): + CreateScenarioRequest.model_validate({"name": "", "run_id": "abc", "horizon": 14}) + request = CreateScenarioRequest.model_validate( + {"name": "Summer discount", "run_id": "abc", "horizon": 14} + ) + assert request.name == "Summer discount" diff --git a/app/main.py b/app/main.py index f473127d..3cb36c8e 100644 --- a/app/main.py +++ b/app/main.py @@ -20,12 +20,15 @@ from app.features.config.service import apply_overrides_on_startup from app.features.demo.routes import router as demo_router from app.features.dimensions.routes import router as dimensions_router +from app.features.explainability.routes import router as explainability_router from app.features.featuresets.routes import router as featuresets_router from app.features.forecasting.routes import router as forecasting_router from app.features.ingest.routes import router as ingest_router from app.features.jobs.routes import router as jobs_router +from app.features.ops.routes import router as ops_router from app.features.rag.routes import router as rag_router from app.features.registry.routes import router as registry_router +from app.features.scenarios.routes import router as scenarios_router from app.features.seeder.routes import router as seeder_router logger = get_logger(__name__) @@ -131,13 +134,16 @@ def create_app() -> FastAPI: app.include_router(health_router) app.include_router(dimensions_router) app.include_router(analytics_router) + app.include_router(ops_router) app.include_router(jobs_router) app.include_router(ingest_router) app.include_router(featuresets_router) app.include_router(forecasting_router) + app.include_router(explainability_router) app.include_router(backtesting_router) app.include_router(registry_router) app.include_router(rag_router) + 
app.include_router(scenarios_router) app.include_router(agents_router) app.include_router(agents_ws_router) app.include_router(seeder_router) diff --git a/docs/_base/API_CONTRACTS.md b/docs/_base/API_CONTRACTS.md index 00e93fb4..d57debcc 100644 --- a/docs/_base/API_CONTRACTS.md +++ b/docs/_base/API_CONTRACTS.md @@ -19,9 +19,18 @@ All endpoints serve JSON; error responses use `application/problem+json` (RFC 78 | analytics | GET | `/analytics/inventory-status` | Latest `inventory_snapshot_daily` row per `(store, product)` grain (Postgres `DISTINCT ON`); optional `store_id`/`product_id` filters; `200` + empty list on an empty table (never `404`) | | featuresets | POST | `/featuresets/compute` | Compute time-safe features (lag/rolling/calendar, leakage-prevented) | | featuresets | POST | `/featuresets/preview` | Preview features with sample rows | -| forecasting | POST | `/forecasting/train` | Train a model (naive / seasonal_naive / moving_average / lightgbm) | +| forecasting | POST | `/forecasting/train` | Train a model (naive / seasonal_naive / moving_average / lightgbm / regression). `regression` wraps `HistGradientBoostingRegressor` on lag + calendar + exogenous features — the baseline a `model_exogenous` scenario re-forecasts through | | forecasting | POST | `/forecasting/predict` | Generate horizon predictions from a trained model | | backtesting | POST | `/backtesting/run` | Time-series CV (rolling/expanding splits, MAE/sMAPE/WAPE/bias/stability) | +| explainability | POST | `/explain/forecast` | Rule-based explanation of the h=1 forecast a named baseline model (`naive`/`seasonal_naive`/`moving_average`) produces on the series ending at `as_of_date`; returns a `ForecastExplanation` — driver contributions, advisory retail reason codes (correlation, not causation), confidence band, caveats, agent summary. 
Time-safe (`<= as_of_date`); a non-baseline `model_type` or a too-short series → RFC 7807 400 | +| explainability | GET | `/explain/runs/{run_id}` | Explain a registry `model_run` — config reconstructed from `model_run.model_config`, cutoff `data_window_end`. Missing run → 404; a non-baseline (`lightgbm`/`regression`) run → 400 | +| explainability | GET | `/explain/jobs/{job_id}` | Explain a completed `predict` job — store/product/model read from `job.result`, cutoff = day before the first forecast date. Missing job → 404; a job that is not a completed predict job → 400 | +| scenarios | POST | `/scenarios/simulate` | Stateless what-if: load a baseline model, forecast, apply price/promotion/holiday/inventory/lifecycle assumptions, return a `ScenarioComparison`. A `regression` baseline genuinely re-forecasts through a leakage-safe future feature frame (`method="model_exogenous"`); any other baseline applies a deterministic post-forecast multiplier (`method="heuristic"`). Bogus `run_id` → RFC 7807 404 | +| scenarios | POST | `/scenarios` | Run a simulation and persist it as a named `scenario_plan` (raw assumptions + full comparison snapshot); optional `tags` + `cloned_from` | +| scenarios | GET | `/scenarios` | List saved scenario plans, newest first (`limit`/`offset`, optional repeated `tags` filter — JSONB containment); `200` + empty list on an empty table | +| scenarios | GET | `/scenarios/{scenario_id}` | Saved plan + embedded comparison snapshot; `404` when missing | +| scenarios | POST | `/scenarios/compare` | Rank 2-5 saved plans (`scenario_ids`, `rank_by`) against a shared baseline; returns a `MultiScenarioComparison` with ranked rows + merged multi-series chart data. 
Unknown `scenario_id` → 404 | +| scenarios | DELETE | `/scenarios/{scenario_id}` | Delete a saved plan; `404` when missing | | registry | POST | `/registry/runs` | Create model run (pending) | | registry | GET | `/registry/runs` | List with filters + pagination + optional allow-listed `sort_by`/`sort_order` (created_at/model_type/status/store_id/product_id; unknown → default `created_at desc`) | | registry | GET | `/registry/runs/{run_id}` | Run details + JSONB metrics + runtime_info | @@ -37,6 +46,7 @@ All endpoints serve JSON; error responses use `application/problem+json` (RFC 78 | jobs | GET | `/jobs/{job_id}` | Status + result JSON | | jobs | DELETE | `/jobs/{job_id}` | Cancel pending | | rag | POST | `/rag/index` | Index a markdown/openapi document; idempotent via content hash | +| rag | POST | `/rag/index/project-docs` | Bulk-index bundled `docs/`, `PRPs/`, and root markdown; per-file + aggregate summary; idempotent via content hash; `502` if the embedding provider fails | | rag | POST | `/rag/retrieve` | Semantic search (HNSW), top-k with similarity threshold | | rag | GET | `/rag/sources` | List indexed sources | | rag | DELETE | `/rag/sources/{source_id}` | Delete source + cascaded chunks | diff --git a/docs/_base/DOMAIN_MODEL.md b/docs/_base/DOMAIN_MODEL.md index 25007301..c2e6e8bc 100644 --- a/docs/_base/DOMAIN_MODEL.md +++ b/docs/_base/DOMAIN_MODEL.md @@ -9,6 +9,7 @@ | Featuresets | Computed feature matrices (in-memory; not persisted) | Time-cutoff parameter — never reads beyond `cutoff_date` | | Forecasting | Trained model artifacts on disk (joblib `.pkl`) | Model interface in `examples/models/model_interface.md`; artifact_uri returned to caller | | Backtesting | Fold results, metrics (returned in response; persisted via Registry) | `SplitConfig` (expanding/sliding, gap, horizon) — `app/features/backtesting/splitter.py` | +| Scenarios | `scenario_plan` (saved what-if plans, JSONB assumptions + comparison) | `load_model_bundle` only (never a sibling 
`service.py`); `adjustments.py` heuristic multiplier or `feature_frame.py` model re-forecast; `agent_tools.py` is the agent-integration seam | | Registry | `model_run`, `run_alias`, `model_artifact` | SHA-256 hash on artifact_uri; status state machine | | RAG | `rag_source`, `rag_chunk` (with pgvector embedding column) | Content hash for idempotent indexing; embedding dimension fixed per provider | | Agents | `agent_session` (JSONB message_history) | Pydantic-validated tool args; HITL approval queue | @@ -42,11 +43,21 @@ - `store_id`, `product_id`, `date` must reference existing dimension rows. - Idempotent upsert via `ON CONFLICT (store_id, product_id, date) DO UPDATE` (`app/features/ingest/service.py`). +### `scenario_plan` (Scenarios) +- **Root:** `ScenarioPlan(scenario_id: str, name: str)` +- **JSONB fields:** `assumptions` (the raw `ScenarioAssumptions`), `comparison` (the full `ScenarioComparison` snapshot — stored so a reloaded plan re-renders without recomputation or the original artifact). +- **Scalar columns:** `tags` (a queryable JSONB string array — its own column, GIN-indexed, never folded into a blob) and `cloned_from` (the `scenario_id` a plan was cloned from); provenance/audit columns `source` (`'user'`|`'agent'`), `agent_session_id`, `approved_by`, `approved_at`, `approval_decision` (`'approved'`|`'rejected'`). +- **Invariants:** + - `method` is CHECK-constrained to `IN ('heuristic','model_exogenous')`. `heuristic` is a deterministic post-forecast multiplier; `model_exogenous` is a genuine re-forecast of a regression baseline through a leakage-safe future feature frame (`feature_frame.py`). + - A scenario adjustment touches only horizon (future) points; it never reads or mutates the historical series (`app/features/scenarios/tests/test_leakage.py` and `test_future_frame_leakage.py` are the spec). + - JSONB columns are persisted via `model_dump(mode="json")` so `date`/`datetime` serialise to ISO strings. 
+ - An agent-saved plan (`source='agent'`) is persisted ONLY after the human approves it through the HITL gate — it always carries the approval audit trail. + ## Key Invariants — NEVER violate 1. **Time safety in features.** `app/features/featuresets/` uses only data at or before `cutoff_date`. Lags via `shift(positive)`, rolling via `shift(1).rolling(...)`, all `groupby` entity-aware. The test `app/features/featuresets/tests/test_leakage.py` is the spec — it MUST keep passing. 2. **Forward-only migrations.** Once an Alembic migration is merged, never edit it. Add a new migration to fix or evolve. -3. **HITL approval gates the agent's mutation surface.** Every tool that writes to the registry (`create_alias`, `archive_run`, …) must be in `agent_require_approval`. Widening the surface without updating that list is a security regression. +3. **HITL approval gates the agent's mutation surface.** Every tool that writes state (`create_alias`, `archive_run`, `save_scenario`, …) must be in `agent_require_approval`. Widening the surface without updating that list is a security regression. 4. **Single-host deployable.** No managed cloud service in the core path. `docker-compose up` must continue to be the only prerequisite besides Python + Node. 5. **Pre-1.0 contracts may move.** Pin the version you build against. After `v1.0.0`, full SemVer applies. 6. **Seeder is idempotent + scoped.** Never introduce a "wipe everything" path that isn't behind `--confirm` + scope flag. 
@@ -70,6 +81,12 @@ | `replenishment event` | One row in `replenishment_event` representing inbound stock at `(store, product, date)`; feature cadence is derived from event spacing | inbound order, restock (those would be different grains) | | `promotion (kind)` | One row in `promotion` with `kind ∈ {pct_off, bogo, bundle, markdown}`; features are one-hot per kind via `PromotionConfig.kinds_to_track` | discount, sale (kind is the discriminator, not "promotion" in the colloquial sense) | | `scenario` (seeder) | A YAML or in-code preset (`retail_standard`, `holiday_rush`, …) that wires `DimensionConfig` + `FactsConfig` | template, profile | +| `scenario plan` | A saved what-if analysis — a `scenario_plan` row pairing raw `ScenarioAssumptions` with a `ScenarioComparison` snapshot | seeder `scenario` (a different concept entirely) | +| `assumption` (what-if) | One future change a planner posits — a price change, promotion, holiday set, inventory cap, or lifecycle stage — fed to `POST /scenarios/simulate` | forecast input, feature | +| `applied factor` | The deterministic per-day multiplier `combined_daily_factor` derives from the assumptions; `1.0` means no change | weight, coefficient | +| `model_exogenous` | The scenario `method` where a regression baseline genuinely re-forecasts through the assumptions — as opposed to the `heuristic` post-forecast multiplier | re-trained model (the baseline is not re-trained, only re-run) | +| `future feature frame` | The leakage-safe `X_future` matrix `feature_frame.py` builds — long-lag, calendar, and exogenous columns the regression model consumes to re-forecast a scenario | feature matrix (that is the training-time term) | +| `scenario tag` | A free-text label on a saved `scenario_plan` (its own queryable JSONB-array column) for filtering and grouping the library | seeder `scenario` preset, registry `alias` | ## Event Taxonomy @@ -92,6 +109,8 @@ rag_source ──owns──► rag_chunk (with pgvector embedding) agent_session ──owns──► 
message_history (JSONB) ──may-contain──► tool_call (pending approval) job ──may-reference──► model_run (for train/backtest jobs) + +scenario_plan ──built-from──► model artifact (a baseline run_id) ──embeds──► comparison snapshot (JSONB) ``` ## Glossary (cross-cutting) diff --git a/docs/_base/REPO_MAP_INDEX.md b/docs/_base/REPO_MAP_INDEX.md index 8158a010..5300f157 100644 --- a/docs/_base/REPO_MAP_INDEX.md +++ b/docs/_base/REPO_MAP_INDEX.md @@ -31,7 +31,9 @@ ForecastLabAI is a portfolio-grade, single-host retail-demand-forecasting system | [`frontend/src/pages/explorer/job-detail.tsx`](../../frontend/src/pages/explorer/job-detail.tsx) | The job detail page — profile, params/result JSON, error details, linked run, cancel action, live polling | Investigating a single job | | [`frontend/src/pages/explorer/run-compare.tsx`](../../frontend/src/pages/explorer/run-compare.tsx) | The run-comparison page — two run pickers, side-by-side profile, config_diff, metrics_diff with delta indicators; deep-linkable via `?a=&b=` | Comparing two model runs | | [`frontend/src/pages/visualize/demand.tsx`](../../frontend/src/pages/visualize/demand.tsx) | The Demand Planner page — completed `predict` jobs rolled into a multi-SKU table (tomorrow/next-week/next-month demand + inventory requirement), lead-time selector, single-SKU drill-in | Answering "how much will this SKU sell, and do I have enough stock?" 
| -| [`alembic/versions/`](../../alembic/versions/) | Six migrations through `d6e0f2g3h456_create_agent_session_table.py` | DB-schema questions, migration drift | +| [`app/features/scenarios/`](../../app/features/scenarios/) | Scenario Simulation slice — `/scenarios/simulate` (stateless what-if) + `scenario_plan` CRUD + `/scenarios/compare`; pure `adjustments.py` heuristic factors and `feature_frame.py` (leakage-safe future X for the `model_exogenous` re-forecast); `agent_tools.py` is the agent-integration seam; never imports a sibling `service.py` | What-If planning, baseline-vs-scenario comparisons, model-driven re-forecasts | +| [`frontend/src/pages/visualize/planner.tsx`](../../frontend/src/pages/visualize/planner.tsx) | The What-If Planner page — pick a baseline predict job, define price/promotion/holiday/inventory/lifecycle assumptions, run a simulation, save / tag / reload / delete named plans, and rank 2-5 saved plans in a multi-scenario comparison | Answering "what if we discount this SKU 15% next week?" | +| [`alembic/versions/`](../../alembic/versions/) | Migrations through `43e35957a248_create_scenario_plan_table.py` | DB-schema questions, migration drift | | [`docs/ARCHITECTURE.md`](../ARCHITECTURE.md) | Phase-by-phase architecture narrative | High-level component reasoning | | [`docs/PHASE-index.md`](../PHASE-index.md) | Index of all 11 phase docs | Locating per-phase deep-dive | | [`docs/PHASE/*.md`](../PHASE/) | Per-phase implementation reference | Slice-specific deep dives | diff --git a/docs/_base/RULES.md b/docs/_base/RULES.md index 73baa422..ad9937b8 100644 --- a/docs/_base/RULES.md +++ b/docs/_base/RULES.md @@ -1,6 +1,6 @@ # ForecastLabAI Rules > Generated by: w7_generating-claudemd skill -> Source of truth: `.claude/rules/` directory (commit-format, branch-naming, security-patterns, product-vision, test-requirements, ui-design, versioning, output-formatting). 
This file consolidates the constraint matrix; the rule files are authoritative on detail. +> Source of truth: `.claude/rules/` directory (commit-format, branch-naming, security-patterns, product-vision, test-requirements, ui-design, shadcn-ui, versioning, output-formatting). This file consolidates the constraint matrix; the rule files are authoritative on detail. > Last reviewed: 2026-05-11 ## Change Authority Matrix diff --git a/docs/_base/SECURITY.md b/docs/_base/SECURITY.md index 786f35b6..8a8e3e78 100644 --- a/docs/_base/SECURITY.md +++ b/docs/_base/SECURITY.md @@ -67,7 +67,7 @@ Reference: PR #115 (issue #109) introduced this pattern on `ComputeFeaturesReque - Token budget cap per session (`agent_max_tokens=4096` default). - Tool-call cap per session (`agent_max_tool_calls=10` default). - Timeout wrap around `agent.run()` / `agent.run_stream()` (`agent_timeout_seconds=120`). -- HITL approval required for mutating tools — `agent_require_approval=["create_alias","archive_run"]`. Never widen the agent's mutation surface without adding the new tool name to that list. +- HITL approval required for mutating tools — `agent_require_approval=["create_alias","archive_run","save_scenario"]`. `save_scenario` (PRP-27 Phase D) lets the experiment agent persist a `scenario_plan` row; it is gated here exactly like the registry mutations. Never widen the agent's mutation surface without adding the new tool name to that list. - Never log full prompts/responses at INFO; DEBUG only with explicit operator opt-in. ## External Integrations Security diff --git a/docs/user-guide/agents-and-rag-guide.md b/docs/user-guide/agents-and-rag-guide.md new file mode 100644 index 00000000..052d8b33 --- /dev/null +++ b/docs/user-guide/agents-and-rag-guide.md @@ -0,0 +1,121 @@ +# Agents and RAG Guide + +ForecastLab includes a conversational AI layer — chat agents — backed by a +**RAG knowledge base** (retrieval-augmented generation). This guide explains how both +work and how to use them safely. 
+ +## The RAG Knowledge Base + +RAG lets the system answer questions using a body of indexed documents rather than +only the language model's general training. ForecastLab uses it to ground answers in +**project documentation**. + +### How indexing works + +When you index a document: + +1. The document is split into overlapping **chunks** (markdown is split by heading, + OpenAPI specs by endpoint). +2. Each chunk is converted into an **embedding** — a numeric vector capturing its meaning. +3. Chunks and embeddings are stored in PostgreSQL using the `pgvector` extension. + +Indexing is **idempotent**: each document is identified by its path and a content +hash, so re-indexing unchanged content does nothing, and changed content replaces the +old chunks cleanly. + +### How retrieval works + +A search query is embedded the same way, then compared against every stored chunk by +**cosine similarity**. The closest chunks above a similarity threshold are returned, +each with a relevance score and a citation back to its source document. Retrieval +returns evidence — passages — not a generated answer; the agent decides what to do +with them. + +### Using it + +- **Knowledge page** (`/knowledge`) — browse the indexed corpus and run live semantic + searches. +- **Admin → RAG Sources** — index a new document, list sources, or delete one. +- **API** — `POST /rag/index`, `POST /rag/retrieve`, `GET /rag/sources`, + `DELETE /rag/sources/{id}`. + +### Embedding providers + +Embeddings come from either **OpenAI** or a local **Ollama** server. The active +provider, model, and vector dimension are shown and changed under **Admin → AI models** +(`GET` / `PATCH /config/ai`). Local Ollama keeps document content off external services. + +## The Chat Agents + +The agents are conversational assistants built with PydanticAI. Two agent types exist: + +- **`rag_assistant`** — answers questions using the RAG knowledge base. 
+- **`experiment`** — can run forecasting experiments (training, backtesting, registry + actions) on your behalf. + +### Talking to an agent + +Use the **Chat** page (`/chat`) or the API: + +1. `POST /agents/sessions` — open a session, choosing the agent type. +2. `POST /agents/sessions/{id}/chat` — send a message and get the full response, or + connect to `WS /agents/stream` for token-by-token streaming. +3. `DELETE /agents/sessions/{id}` — close the session. + +A session keeps its message history, so the agent remembers earlier turns in the +conversation. + +### Tools + +Agents can call **tools** — typed functions that fetch data or perform actions +(retrieve documentation, list model runs, start a backtest, and so on). When an agent +uses a tool, the chat UI shows the call and its result, so you can see exactly how an +answer was produced. + +## The Human-in-the-Loop Approval Gate + +Most tools are read-only and run immediately. Tools that **change state** — for +example creating a registry alias, archiving a run, or saving a scenario plan — are different: they **pause and +wait for your approval**. + +When an agent wants to run one of these tools: + +1. The session enters an `awaiting_approval` state and an `approval_required` event is + emitted. +2. Nothing happens until you respond. +3. You approve or reject via `POST /agents/sessions/{id}/approve` (the Chat page + surfaces this as a prompt). +4. On approval the tool runs; on rejection it is skipped. + +This gate means an agent can never silently mutate the model registry or persist a scenario plan — a person is +always in the loop for consequential actions. The set of approval-gated tools is a +deliberate, fixed list. + +### Other safety limits + +Each session is bounded so an agent cannot run away: + +- a **token budget** per session, +- a **maximum number of tool calls** per session, +- a **timeout** wrapping each agent run. + +The **Agent Guide** page (`/guide`) shows these limits live, along with the available +tools and example prompts. 
+ +## Putting It Together + +A typical RAG-assisted exchange: you ask a question on the Chat page → the +`rag_assistant` agent calls its retrieval tool → the tool runs a semantic search over +the indexed corpus → the agent reads the returned passages → it answers, grounded in +real documentation, and you can see the citations. For experiments, the `experiment` +agent can additionally trigger training or backtesting — pausing for your approval +before anything that writes to the registry or saves a scenario plan. + +## Tips + +- The agents need an LLM API key (`OPENAI_API_KEY` or `ANTHROPIC_API_KEY`) in `.env`. + Without one, the chat features are unavailable but the rest of the system still works. +- For useful RAG answers, index relevant documentation first — an empty corpus means + the assistant has nothing to cite. +- Watch the tool-call display in the chat: it is the simplest way to understand how + the agent reached its answer. diff --git a/docs/user-guide/dashboard-guide.md b/docs/user-guide/dashboard-guide.md new file mode 100644 index 00000000..83aec0bd --- /dev/null +++ b/docs/user-guide/dashboard-guide.md @@ -0,0 +1,93 @@ +# Dashboard Guide + +The ForecastLab dashboard is a React web app at **http://localhost:5173**. This guide +walks through every page. The top navigation bar groups pages as: Dashboard, +Showcase, **Explorer** (menu), **Visualize** (menu), Knowledge, Chat, Agent Guide, +and Admin. A light/dark theme toggle sits on the right. + +## Dashboard (`/`) + +The landing page. It shows headline **KPI cards** — total revenue, units sold, +transactions, average unit price, average basket — plus a revenue-over-time chart. +Use it for a quick health check of the seeded dataset. If the database is empty, +the cards read zero; seed data first (see Admin, or run `make demo`). + +## Showcase (`/showcase`) + +Runs the **end-to-end demo pipeline live in your browser**. 
Click to start, and the +page streams one status card per step: seed → features → train three models → +backtest → register the winner → alias → agent check. Each card flips to a +pass / fail / skip state, and a summary banner reports the winning model and its +accuracy. This is the best page for a guided demo of the whole system. + +Tip: tick **Re-seed first** if the database is empty or stale. Only one pipeline can +run at a time. + +## Explorer + +The Explorer menu contains read-only pages for browsing the underlying data and +model history. Tables support pagination, filtering, search, and sorting; clicking a +row opens a detail page. + +- **Sales** (`/explorer/sales`) — browse daily sales records. +- **Stores** (`/explorer/stores`) — list of retail stores. Click a store to open its + **detail page**: an entity profile, date-scoped KPIs, a revenue-over-time chart, + and a top-products drilldown. +- **Products** (`/explorer/products`) — list of products (SKUs). Click a product for + its **detail page**: profile, KPIs, revenue and lifecycle-demand curves, and a + top-stores drilldown. +- **Model Runs** (`/explorer/runs`) — every trained model tracked in the registry. + A run **detail page** shows its configuration, metrics, and runtime info as JSON, + cross-links to the store/product, an artifact-integrity check, and a compare link. + Two runs can be compared side by side (config diff + metrics diff with deltas). +- **Jobs** (`/explorer/jobs`) — submitted train/predict/backtest jobs. A job + **detail page** shows parameters, result JSON, error details, the linked run, a + cancel action, and live status polling. + +## Visualize + +The Visualize menu holds the analytical, chart-heavy pages. + +- **Demand Planner** (`/visualize/demand`) — rolls completed `predict` jobs into a + multi-SKU table showing tomorrow / next-week / next-month demand and the + inventory required to cover it. Includes a lead-time selector and a single-SKU + drill-in. 
Answers "how much will this SKU sell, and do I have enough stock?" +- **Forecast** (`/visualize/forecast`) — visualizes a model's horizon predictions. +- **Backtest Results** (`/visualize/backtest`) — charts backtest folds and the + accuracy metrics (MAE, sMAPE, WAPE, bias, stability) for a model run. + +## Knowledge (`/knowledge`) + +Surfaces the **RAG knowledge base**: the indexed document corpus, a live semantic +search box, and current system state. Type a question to retrieve the most relevant +documentation passages with similarity scores. If the corpus is empty, the page +shows an empty state until documents are indexed (see the Admin page). + +## Chat (`/chat`) + +The **AI agent chat**. Ask questions in natural language; the assistant streams its +answer token by token and shows any tools it calls. Some actions pause for your +approval before they run. See the Agents and RAG Guide for details. + +## Agent Guide (`/guide`) + +An in-app reference for the chat agents: the tools they can use, the human-in-the-loop +approval gate, live session limits, and example prompts to try. + +## Admin (`/admin`) + +Operational controls, organized into tabs: + +- **Data seeding** — generate synthetic retail data from named scenarios, append more, + verify integrity, or clear the dataset. +- **RAG Sources** — list indexed knowledge documents, index a new document, and + delete sources. +- **Aliases** — manage model registry aliases (e.g. promote a run to `production`). +- **AI models** — view and change the agent LLM and RAG embedding configuration + live, with per-provider health indicators. + +## Notes + +- Pages fetch data from the backend API; if everything shows "Loading…", confirm the + backend is running and `VITE_API_BASE_URL` points at it. +- Explorer detail pages are reached by clicking table rows — they are not in the nav. 
diff --git a/docs/user-guide/feature-reference.md b/docs/user-guide/feature-reference.md new file mode 100644 index 00000000..91e1ba51 --- /dev/null +++ b/docs/user-guide/feature-reference.md @@ -0,0 +1,142 @@ +# Feature Reference + +This is a capability-by-capability reference for ForecastLab's backend. Every feature +is a REST API served at **http://localhost:8123**; the interactive Swagger UI at +**/docs** is the authoritative, always-current contract. All errors use the RFC 7807 +`application/problem+json` format. + +## Health + +- `GET /health` — liveness probe; returns `{"status": "ok"}`. + +## Data Platform and Ingest + +The data platform owns seven retail tables: `store`, `product`, `calendar`, +`sales_daily`, `price_history`, `promotion`, and `inventory_snapshot_daily`. + +- `POST /ingest/sales-daily` — batch-load daily sales. Resolves natural keys + (store code, SKU) to IDs and upserts idempotently, so re-sending the same batch is + safe. + +## Dimensions + +Reference data — the "who" and "what" behind the sales facts. + +- `GET /dimensions/stores` — list stores (pagination, region / store-type filters, + case-insensitive search, optional sorting). +- `GET /dimensions/stores/{store_id}` — one store by ID. +- `GET /dimensions/products` — list products (category / brand filters, SKU / name + search, optional sorting). +- `GET /dimensions/products/{product_id}` — one product by ID. + +## Analytics + +Read-only aggregates computed over the sales data. + +- `GET /analytics/kpis` — headline KPIs: revenue, units, transactions, average unit + price, average basket. +- `GET /analytics/drilldowns` — group sales by store, product, category, region, or date. +- `GET /analytics/timeseries` — period-bucketed sales series (day / week / month / + quarter) for revenue-over-time charts. +- `GET /analytics/inventory-status` — latest inventory snapshot per store-product pair. 
+ +## Feature Engineering + +Turns raw sales into model-ready features while strictly preventing **data leakage** — +features never use information from the future. + +- `POST /featuresets/compute` — compute time-safe features (lags, rolling-window + statistics, calendar effects) up to a cutoff date. +- `POST /featuresets/preview` — preview computed features with sample rows. + +## Forecasting + +Trains demand-forecasting models and generates predictions. + +- `POST /forecasting/train` — train a model. Supported model types: `naive`, + `seasonal_naive`, `moving_average` (baselines) and `lightgbm` (machine learning). +- `POST /forecasting/predict` — generate horizon predictions from a trained model. + +The three baselines exist as honest comparison points — a machine-learning model is +only worth using if it beats them. + +## Backtesting + +Measures how accurate a model would have been, using time-series cross-validation. + +- `POST /backtesting/run` — run rolling or expanding train/test splits and report + accuracy metrics: **MAE**, **sMAPE**, **WAPE**, **bias**, and **stability**. + +## Model Registry + +Tracks every trained model so runs are reproducible and comparable. + +- `POST /registry/runs` — create a model run record (starts `pending`). +- `GET /registry/runs` — list runs with filters, pagination, and sorting. +- `GET /registry/runs/{run_id}` — run details, including metrics and runtime info. +- `PATCH /registry/runs/{run_id}` — update a run's status, metrics, or artifact location. +- `GET /registry/runs/{run_id}/verify` — verify the model artifact's SHA-256 integrity. +- `GET /registry/compare/{run_id_a}/{run_id_b}` — diff two runs. +- `POST /registry/aliases` — create or move an alias (e.g. `production`); aliases may + point only to a successful run. +- `GET /registry/aliases`, `GET /registry/aliases/{name}`, `DELETE /registry/aliases/{name}` + — manage aliases. 
+ +A run moves through `pending → running → success` (or `failed`), and an alias is a +human-friendly pointer (like `production` or `champion`) to a chosen successful run. + +## Jobs + +Long-running work — training, prediction, backtesting — submitted as jobs. + +- `POST /jobs` — submit a `train`, `predict`, or `backtest` job; returns a `job_id`. +- `GET /jobs` — list jobs with filters and sorting. +- `GET /jobs/{job_id}` — job status and result JSON. +- `DELETE /jobs/{job_id}` — cancel a pending job. + +## RAG Knowledge Base + +Semantic search over indexed documents. See the Agents and RAG Guide for the full +picture. + +- `POST /rag/index` — index a markdown or OpenAPI document; idempotent via content hash. +- `POST /rag/retrieve` — semantic search; returns the top-k most relevant passages. +- `GET /rag/sources` — list indexed sources. +- `DELETE /rag/sources/{source_id}` — delete a source and its chunks. + +## Agents + +The conversational AI layer. See the Agents and RAG Guide. + +- `POST /agents/sessions` — open a chat session (`experiment` or `rag_assistant`). +- `GET /agents/sessions/{id}` — session status and message history. +- `POST /agents/sessions/{id}/chat` — send a message; returns the full response. +- `POST /agents/sessions/{id}/approve` — approve or reject a pending tool call. +- `DELETE /agents/sessions/{id}` — close a session. +- `WS /agents/stream` — token-by-token streaming with tool-call events. + +## Seeder ("The Forge") + +Generates realistic synthetic retail data so you have something to forecast. + +- `GET /seeder/status` — current dataset state. +- `GET /seeder/scenarios` — available named scenarios. +- `GET /seeder/channels` — available sales channels. +- `POST /seeder/generate` — generate a dataset from a scenario. +- `POST /seeder/append` — append more data to an existing dataset. +- `DELETE /seeder/data` — clear the generated data. +- `GET /seeder/exogenous` — exogenous signal data. +- `POST /seeder/verify` — verify dataset integrity. 
+ +## Demo Pipeline + +- `POST /demo/run` — run the full end-to-end pipeline in one call. +- `WS /demo/stream` — stream per-step events for the live Showcase page. + +## Configuration + +- `GET /config/ai` — effective AI-model configuration (agent LLM + RAG embeddings); + API keys are always masked. +- `PATCH /config/ai` — change AI-model settings live, with no restart. +- `GET /config/providers/health` — per-provider connectivity status. +- `GET /config/ollama/models` — models available on the configured Ollama host. diff --git a/docs/user-guide/getting-started.md b/docs/user-guide/getting-started.md new file mode 100644 index 00000000..f3109d4b --- /dev/null +++ b/docs/user-guide/getting-started.md @@ -0,0 +1,100 @@ +# Getting Started with ForecastLab + +This guide takes you from a fresh clone to a running ForecastLab system with data, +trained models, and a working dashboard — in about ten minutes. + +## What ForecastLab Is + +ForecastLab is a **retail demand-forecasting system** you run on a single machine. It +covers the whole forecasting lifecycle end to end: + +1. **Data platform** — stores, products, calendar, daily sales, prices, promotions, inventory. +2. **Ingest** — load sales data through a batch API. +3. **Feature engineering** — build time-safe features (lags, rolling windows, calendar effects). +4. **Forecasting** — train baseline and machine-learning models. +5. **Backtesting** — measure accuracy with time-series cross-validation. +6. **Model registry** — track every trained model, compare runs, promote a champion. +7. **RAG knowledge base** — semantic search over project documentation. +8. **AI agents** — a chat assistant that can run experiments and answer questions. +9. **Dashboard** — a React web app that surfaces all of the above. + +It is built for learning, demos, and portfolio use. It is **not** a multi-tenant SaaS, +not a real-time streaming system, and needs no cloud account — everything runs locally. 
+ +## Prerequisites + +- **Docker** (for the PostgreSQL database) +- **Python 3.12** with [`uv`](https://docs.astral.sh/uv/) (the Python package manager) +- **Node.js** with `pnpm` (enabled through `corepack`) + +## Install and Run + +Run these from the repository root. + +```bash +# 1. Configure environment — add your OpenAI / Anthropic API keys to .env +cp .env.example .env + +# 2. Start PostgreSQL + pgvector (listens on host port 5433) +docker compose up -d + +# 3. Install backend dependencies +uv sync --extra dev + +# 4. Apply database migrations +uv run alembic upgrade head + +# 5. Start the backend API (http://localhost:8123) +uv run uvicorn app.main:app --reload --port 8123 +``` + +In a second terminal, start the web dashboard: + +```bash +cd frontend +corepack enable pnpm +pnpm install +pnpm dev # dashboard at http://localhost:5173 +``` + +Open **http://localhost:5173** in your browser. The interactive API documentation +(Swagger UI) is available at **http://localhost:8123/docs**. + +## Load Data and See It Work + +A fresh database is empty. The fastest way to see the whole system in action is the +**end-to-end demo**, which seeds data, computes features, trains three models, +backtests them, registers the winner, and exercises the agent: + +```bash +make demo +``` + +You can also watch the same pipeline run live in the browser on the **Showcase** page +(see the Dashboard Guide). To generate data without the full pipeline, use the +**Admin** page or the seeder API directly. + +## Key Ports and URLs + +| Service | URL | +|----------------|------------------------------| +| Dashboard | http://localhost:5173 | +| Backend API | http://localhost:8123 | +| API docs | http://localhost:8123/docs | +| PostgreSQL | localhost:5433 | + +## If Something Goes Wrong + +- **Dashboard shows "Loading…" everywhere** — the frontend cannot reach the backend. 
+ Check that the API is running (`curl http://localhost:8123/health`) and that + `frontend/.env` has `VITE_API_BASE_URL=http://localhost:8123`. +- **Database connection refused** — make sure `docker compose up -d` succeeded and + migrations are applied (`uv run alembic upgrade head`). +- **API keys** — the AI agent and RAG features need `OPENAI_API_KEY` and/or + `ANTHROPIC_API_KEY` set in `.env`. Forecasting and the dashboard work without them. + +## Next Steps + +- **Dashboard Guide** — a tour of every page in the web app. +- **Feature Reference** — what each part of the system does and its API endpoints. +- **Agents and RAG Guide** — how the chat assistant and knowledge base work. diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 82ad780b..6619ead2 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -10,6 +10,7 @@ import { ROUTES } from '@/lib/constants' // Lazy-loaded page components const DashboardPage = lazy(() => import('@/pages/dashboard')) const ShowcasePage = lazy(() => import('@/pages/showcase')) +const OpsPage = lazy(() => import('@/pages/ops')) const SalesExplorerPage = lazy(() => import('@/pages/explorer/sales')) const StoresExplorerPage = lazy(() => import('@/pages/explorer/stores')) const StoreDetailPage = lazy(() => import('@/pages/explorer/store-detail')) @@ -23,6 +24,7 @@ const JobDetailPage = lazy(() => import('@/pages/explorer/job-detail')) const ForecastPage = lazy(() => import('@/pages/visualize/forecast')) const BacktestPage = lazy(() => import('@/pages/visualize/backtest')) const DemandPlannerPage = lazy(() => import('@/pages/visualize/demand')) +const WhatIfPlannerPage = lazy(() => import('@/pages/visualize/planner')) const ChatPage = lazy(() => import('@/pages/chat')) const KnowledgePage = lazy(() => import('@/pages/knowledge')) const GuidePage = lazy(() => import('@/pages/guide')) @@ -55,6 +57,14 @@ function App() { } /> + }> + + + } + /> } /> + }> + + + } + /> {description}} - + {/* Height is passed via inline 
style — a `h-[${height}px]` class is a + dynamic string Tailwind cannot statically discover, so the JIT + compiler drops it at build time. */} + } /> - + } /> {formattedData.map((_, index) => ( diff --git a/frontend/src/components/charts/index.ts b/frontend/src/components/charts/index.ts index 2e439f7a..6e4231b4 100644 --- a/frontend/src/components/charts/index.ts +++ b/frontend/src/components/charts/index.ts @@ -1,4 +1,5 @@ export * from './kpi-card' export * from './time-series-chart' +export * from './multi-series-chart' export * from './backtest-folds-chart' export * from './revenue-bar-chart' diff --git a/frontend/src/components/charts/kpi-card.tsx b/frontend/src/components/charts/kpi-card.tsx index 7f9ff6c3..d734fa2c 100644 --- a/frontend/src/components/charts/kpi-card.tsx +++ b/frontend/src/components/charts/kpi-card.tsx @@ -47,16 +47,16 @@ export function KPICard({ {Icon && } -
{value}
+
{value}
{trend && ( 0 - ? 'text-green-600 dark:text-green-400' + ? 'text-success' : trend.value < 0 - ? 'text-red-600 dark:text-red-400' + ? 'text-destructive' : '' )} > diff --git a/frontend/src/components/charts/multi-series-chart.tsx b/frontend/src/components/charts/multi-series-chart.tsx new file mode 100644 index 00000000..33aade53 --- /dev/null +++ b/frontend/src/components/charts/multi-series-chart.tsx @@ -0,0 +1,106 @@ +import { CartesianGrid, ComposedChart, Line, XAxis, YAxis } from 'recharts' +import { + ChartConfig, + ChartContainer, + ChartLegend, + ChartLegendContent, + ChartTooltip, + ChartTooltipContent, +} from '@/components/ui/chart' +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' + +/** One line in a multi-series chart — a row key plus its display label. */ +export interface ChartSeries { + key: string + label: string +} + +interface MultiSeriesChartProps { + title: string + description?: string + /** Date-keyed rows; each carries `xAxisKey` plus a value per series key. */ + data: Record[] + /** The lines to draw — the first is rendered solid, the rest dashed. */ + series: ChartSeries[] + xAxisKey?: string + height?: number + className?: string +} + +// Deterministic palette — the shadcn chart CSS vars cycled across the lines. +const PALETTE = [ + 'var(--chart-1)', + 'var(--chart-2)', + 'var(--chart-3)', + 'var(--chart-4)', + 'var(--chart-5)', +] + +/** + * Renders M+1 demand lines on one chart — a shared baseline plus one line per + * scenario — for the What-If Planner's multi-scenario comparison view. + * + * Every series key MUST be a valid CSS identifier (the shadcn `ChartContainer` + * emits a `--color-` var per key); callers key scenario lines by the + * CSS-safe `scenario_id`, never by a free-text plan name. 
+ */ +export function MultiSeriesChart({ + title, + description, + data, + series, + xAxisKey = 'date', + height = 320, + className, +}: MultiSeriesChartProps) { + const chartConfig: ChartConfig = {} + series.forEach((line, index) => { + chartConfig[line.key] = { + label: line.label, + color: PALETTE[index % PALETTE.length], + } + }) + + return ( + + + {title} + {description && {description}} + + + {/* Height via inline style — a dynamic `h-[${height}px]` class is not + statically discoverable by the Tailwind JIT compiler. */} + + + + { + const date = new Date(value) + return date.toLocaleDateString('en-US', { month: 'short', day: 'numeric' }) + }} + /> + + } /> + } /> + {series.map((line, index) => ( + + ))} + + + + + ) +} diff --git a/frontend/src/components/charts/revenue-bar-chart.tsx b/frontend/src/components/charts/revenue-bar-chart.tsx index 58ff87b9..200c19f2 100644 --- a/frontend/src/components/charts/revenue-bar-chart.tsx +++ b/frontend/src/components/charts/revenue-bar-chart.tsx @@ -41,7 +41,10 @@ export function RevenueBarChart({ {description && {description}} - + {/* Height is passed via inline style — a `h-[${height}px]` class is a + dynamic string Tailwind cannot statically discover, so the JIT + compiler drops it at build time. 
*/} + diff --git a/frontend/src/components/charts/time-series-chart.tsx b/frontend/src/components/charts/time-series-chart.tsx index b32bf641..767da19a 100644 --- a/frontend/src/components/charts/time-series-chart.tsx +++ b/frontend/src/components/charts/time-series-chart.tsx @@ -1,7 +1,9 @@ -import { Area, CartesianGrid, ComposedChart, Legend, Line, XAxis, YAxis } from 'recharts' +import { Area, CartesianGrid, ComposedChart, Line, XAxis, YAxis } from 'recharts' import { ChartConfig, ChartContainer, + ChartLegend, + ChartLegendContent, ChartTooltip, ChartTooltipContent, } from '@/components/ui/chart' @@ -70,7 +72,10 @@ export function TimeSeriesChart({ {description && {description}} - + {/* Height is passed via inline style — a `h-[${height}px]` class is a + dynamic string Tailwind cannot statically discover, so the JIT + compiler drops it at build time. */} + } /> - + } /> {/* Prediction-interval band — drawn first so the forecast line sits on top. A function dataKey returns the [lower, upper] tuple recharts renders as a range area. */} diff --git a/frontend/src/components/chat/chat-input.tsx b/frontend/src/components/chat/chat-input.tsx index b4ad2876..74005673 100644 --- a/frontend/src/components/chat/chat-input.tsx +++ b/frontend/src/components/chat/chat-input.tsx @@ -87,7 +87,7 @@ export function ApprovalPrompt({ isLoading = false, }: ApprovalPromptProps) { return ( -
+

Approval Required

The agent wants to perform: {action} diff --git a/frontend/src/components/chat/tool-call-display.tsx b/frontend/src/components/chat/tool-call-display.tsx index 05e35b90..22a53713 100644 --- a/frontend/src/components/chat/tool-call-display.tsx +++ b/frontend/src/components/chat/tool-call-display.tsx @@ -18,9 +18,9 @@ export function ToolCallDisplay({ toolCall, className }: ToolCallDisplayProps) { const statusIcon = { pending: , - running: , - completed: , - failed: , + running: , + completed: , + failed: , } return ( @@ -77,9 +77,9 @@ export function ToolCallProgress({ toolName, status }: ToolCallProgressProps) { {status === 'running' || status === 'starting' ? ( ) : status === 'completed' ? ( - + ) : ( - + )} diff --git a/frontend/src/components/common/json-block.tsx b/frontend/src/components/common/json-block.tsx index 656f894c..245961eb 100644 --- a/frontend/src/components/common/json-block.tsx +++ b/frontend/src/components/common/json-block.tsx @@ -1,3 +1,4 @@ +import type { ReactNode } from 'react' import { cn } from '@/lib/utils' interface JsonBlockProps { @@ -5,10 +6,52 @@ interface JsonBlockProps { className?: string } +// Matches a JSON string (group 1) with an optional key colon (group 2), a +// literal (group 3), or a number (group 4) in pretty-printed JSON. +const JSON_TOKEN = + /("(?:\\.|[^"\\])*")(\s*:)?|\b(true|false|null)\b|(-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)/g + +/** + * Tokenises pretty-printed JSON into syntax-highlighted spans. Token colours use + * the semantic status tokens — each verified to clear WCAG AA on the muted block + * background in both light and dark themes. 
+ */ +function highlightJson(json: string): ReactNode[] { + const out: ReactNode[] = [] + let lastIndex = 0 + let key = 0 + let match: RegExpExecArray | null + JSON_TOKEN.lastIndex = 0 + while ((match = JSON_TOKEN.exec(json)) !== null) { + if (match.index > lastIndex) { + out.push(json.slice(lastIndex, match.index)) + } + const [token, str, colon, literal, num] = match + let className = '' + if (str !== undefined) { + className = colon ? 'text-info' : 'text-success' + } else if (literal !== undefined) { + className = 'text-destructive' + } else if (num !== undefined) { + className = 'text-warning' + } + out.push( + + {token} + , + ) + lastIndex = match.index + token.length + } + if (lastIndex < json.length) { + out.push(json.slice(lastIndex)) + } + return out +} + /** * Read-only formatted-JSON viewer. Renders a muted em-dash for null/undefined, - * otherwise a scrollable, pretty-printed

 block. Intentionally has no
- * syntax-highlighter dependency — it surfaces run/job JSONB payloads as-is.
+ * otherwise a scrollable, syntax-highlighted 
 block surfacing run/job
+ * JSONB payloads.
  */
 export function JsonBlock({ value, className }: JsonBlockProps) {
   if (value === null || value === undefined) {
@@ -22,7 +65,7 @@ export function JsonBlock({ value, className }: JsonBlockProps) {
         className,
       )}
     >
-      {JSON.stringify(value, null, 2)}
+      {highlightJson(JSON.stringify(value, null, 2))}
     
) } diff --git a/frontend/src/components/common/status-badge.tsx b/frontend/src/components/common/status-badge.tsx index 5340ce7d..ee4de6fc 100644 --- a/frontend/src/components/common/status-badge.tsx +++ b/frontend/src/components/common/status-badge.tsx @@ -7,11 +7,11 @@ const statusBadgeVariants = cva( variants: { variant: { default: 'bg-secondary text-secondary-foreground', - success: 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-400', - warning: 'bg-yellow-100 text-yellow-800 dark:bg-yellow-900/30 dark:text-yellow-400', - error: 'bg-red-100 text-red-800 dark:bg-red-900/30 dark:text-red-400', - info: 'bg-blue-100 text-blue-800 dark:bg-blue-900/30 dark:text-blue-400', - pending: 'bg-gray-100 text-gray-800 dark:bg-gray-700/30 dark:text-gray-400', + success: 'bg-success text-success-foreground', + warning: 'bg-warning text-warning-foreground', + error: 'bg-destructive text-destructive-foreground', + info: 'bg-info text-info-foreground', + pending: 'bg-muted text-muted-foreground', }, }, defaultVariants: { diff --git a/frontend/src/components/data-table/data-table.tsx b/frontend/src/components/data-table/data-table.tsx index 2224a87c..858ff375 100644 --- a/frontend/src/components/data-table/data-table.tsx +++ b/frontend/src/components/data-table/data-table.tsx @@ -77,7 +77,9 @@ export function DataTable({
)} -
+ {/* Light: elevated white panel. dark: overrides keep the existing + borderless-on-canvas look unchanged. */} +
{table.getHeaderGroups().map((headerGroup) => ( @@ -97,11 +99,13 @@ export function DataTable({ {isLoading ? ( - // Loading skeleton + // Loading skeleton — one cell per *visible* column so the + // skeleton width matches the rendered table after a + // column-visibility toggle. Array.from({ length: pagination.pageSize }).map((_, i) => ( - {columns.map((_, j) => ( - + {table.getVisibleLeafColumns().map((column) => ( + ))} @@ -113,6 +117,20 @@ export function DataTable({ key={row.id} data-state={row.getIsSelected() && 'selected'} onClick={onRowClick ? () => onRowClick(row.original) : undefined} + // A clickable row must be keyboard-operable: expose it as a + // button to assistive tech and fire on Enter / Space. + onKeyDown={ + onRowClick + ? (event) => { + if (event.key === 'Enter' || event.key === ' ') { + event.preventDefault() + onRowClick(row.original) + } + } + : undefined + } + tabIndex={onRowClick ? 0 : undefined} + role={onRowClick ? 'button' : undefined} className={cn(onRowClick && 'cursor-pointer')} > {row.getVisibleCells().map((cell) => ( @@ -128,7 +146,7 @@ export function DataTable({ ) : ( {emptyMessage} diff --git a/frontend/src/components/demo/demo-step-card.tsx b/frontend/src/components/demo/demo-step-card.tsx index 6d230066..e24b8e9d 100644 --- a/frontend/src/components/demo/demo-step-card.tsx +++ b/frontend/src/components/demo/demo-step-card.tsx @@ -15,11 +15,11 @@ const STATUS_GLYPH: Record = { // Left-border accent colour per status. 
const STATUS_ACCENT: Record = { idle: 'border-l-border', - running: 'border-l-blue-500', - pass: 'border-l-green-500', - fail: 'border-l-red-500', + running: 'border-l-info', + pass: 'border-l-success', + fail: 'border-l-destructive', skip: 'border-l-muted-foreground/40', - warn: 'border-l-yellow-500', + warn: 'border-l-warning', } function formatDuration(ms: number): string { @@ -50,7 +50,7 @@ function BacktestBreakdown({ data }: { data: Record }) { className={cn( 'flex items-center justify-between rounded-md px-2 py-1 text-xs', row.model === winner - ? 'bg-green-100 font-semibold dark:bg-green-900/30' + ? 'bg-success/10 font-semibold' : 'bg-muted' )} > diff --git a/frontend/src/components/explainability/explanation-panel.test.tsx b/frontend/src/components/explainability/explanation-panel.test.tsx new file mode 100644 index 00000000..11e4f801 --- /dev/null +++ b/frontend/src/components/explainability/explanation-panel.test.tsx @@ -0,0 +1,66 @@ +import { describe, expect, it } from 'vitest' +import { render, screen } from '@testing-library/react' +import { ApiError } from '@/lib/api' +import { ExplanationPanel } from './explanation-panel' +import type { ForecastExplanation } from '@/types/api' + +const sampleExplanation: ForecastExplanation = { + store_id: 1, + product_id: 2, + model_type: 'naive', + method: 'rule_based', + forecast_value: 42, + drivers: [ + { + name: 'last_observation', + feature_value: 42, + contribution: 42, + direction: 'positive', + description: 'The naive forecast is the last observed value.', + }, + ], + reason_codes: [ + { code: 'stockout_constrained', severity: 'warn', detail: '2 stockout days.' 
}, + ], + confidence: 'medium', + caveats: ['Drivers describe correlation, not causation.'], + agent_summary: 'The naive model forecasts 42 units.', + as_of_date: '2024-03-01', + generated_at: '2024-03-01T00:00:00Z', +} + +describe('ExplanationPanel', () => { + it('renders drivers, reason codes, confidence, and caveats', () => { + render() + + expect(screen.getByText('last observation')).toBeTruthy() + expect(screen.getByText('Positive')).toBeTruthy() + expect(screen.getByText('medium')).toBeTruthy() + expect(screen.getByText(/stockout constrained/)).toBeTruthy() + expect(screen.getByText(/correlation, not causation/)).toBeTruthy() + expect(screen.getByText(sampleExplanation.agent_summary)).toBeTruthy() + }) + + it('renders a loading state', () => { + render() + expect(screen.getByText(/Generating explanation/)).toBeTruthy() + }) + + it('renders a destructive error for an unexpected failure', () => { + render() + expect(screen.getByText('boom')).toBeTruthy() + }) + + it('renders a neutral message for a 400 (non-baseline model)', () => { + const apiError = new ApiError('Explanations are available for baseline models only', 400) + render() + expect(screen.getByText(/baseline models only/)).toBeTruthy() + }) + + it('shows a no-signals message when there are no reason codes', () => { + render( + , + ) + expect(screen.getByText(/No advisory retail signals/)).toBeTruthy() + }) +}) diff --git a/frontend/src/components/explainability/explanation-panel.tsx b/frontend/src/components/explainability/explanation-panel.tsx new file mode 100644 index 00000000..ecd9b0a7 --- /dev/null +++ b/frontend/src/components/explainability/explanation-panel.tsx @@ -0,0 +1,218 @@ +import type { ReactNode } from 'react' +import { AlertTriangle, Info, Lightbulb, Minus, TrendingDown, TrendingUp } from 'lucide-react' +import { ApiError, formatNumber, getErrorMessage } from '@/lib/api' +import { Badge } from '@/components/ui/badge' +import { Card, CardContent, CardDescription, CardHeader, CardTitle 
} from '@/components/ui/card' +import { LoadingState } from '@/components/common/loading-state' +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from '@/components/ui/table' +import type { + ConfidenceLevel, + DriverContribution, + ForecastExplanation, + ReasonCode, +} from '@/types/api' + +interface ExplanationPanelProps { + explanation?: ForecastExplanation + isLoading?: boolean + error?: unknown +} + +const CONFIDENCE_VARIANT: Record = { + high: 'default', + medium: 'secondary', + low: 'outline', +} + +function DirectionLabel({ direction }: { direction: DriverContribution['direction'] }) { + if (direction === 'positive') { + return ( + + + Positive + + ) + } + if (direction === 'negative') { + return ( + + + Negative + + ) + } + return ( + + + Neutral + + ) +} + +function ReasonCodeRow({ reason }: { reason: ReasonCode }) { + const isWarn = reason.severity === 'warn' + return ( +
  • + {isWarn ? ( + + ) : ( + + )} + + {reason.code.replace(/_/g, ' ')} + {' — '} + {reason.detail} + +
  • + ) +} + +/** Card shell shared by every panel state so the layout never jumps. */ +function PanelShell({ children }: { children: ReactNode }) { + return ( + + + + + Forecast Explanation + + + Rule-based driver attribution for the h=1 baseline forecast. + + + {children} + + ) +} + +export function ExplanationPanel({ explanation, isLoading, error }: ExplanationPanelProps) { + if (isLoading) { + return ( + + + + ) + } + + if (error) { + // A 400 here means the run/job is not a baseline model — an expected, + // non-error outcome, so it is shown in a neutral (not destructive) tone. + const isExpected = error instanceof ApiError && error.status === 400 + return ( + +
    + {isExpected ? ( + + ) : ( + + )} + {getErrorMessage(error)} +
    +
    + ) + } + + if (!explanation) { + return ( + +

    No explanation available.

    +
    + ) + } + + return ( + +
    +
    +
    +

    h=1 forecast

    +

    {formatNumber(explanation.forecast_value, 1)}

    +
    +
    +

    Confidence

    + + {explanation.confidence} + +
    +
    +

    Model

    +

    {explanation.model_type}

    +
    +
    + +

    {explanation.agent_summary}

    + +
    +

    Drivers

    +
    + + + Driver + Value + Contribution + Direction + + + + {explanation.drivers.map((driver) => ( + + + {driver.name.replace(/_/g, ' ')} + + {driver.description} + + + + {formatNumber(driver.feature_value, 1)} + + + {formatNumber(driver.contribution, 1)} + + + + + + ))} + +
    +
    + +
    +

    Retail signals

    + {explanation.reason_codes.length > 0 ? ( +
      + {explanation.reason_codes.map((reason) => ( + + ))} +
    + ) : ( +

    + No advisory retail signals were detected. +

    + )} +
    + +
    + {explanation.caveats.map((caveat) => ( +

    + {caveat} +

    + ))} +
    +
    + + ) +} diff --git a/frontend/src/components/layout/top-nav.tsx b/frontend/src/components/layout/top-nav.tsx index 6e2b5758..041263e5 100644 --- a/frontend/src/components/layout/top-nav.tsx +++ b/frontend/src/components/layout/top-nav.tsx @@ -59,7 +59,8 @@ export function TopNav() { to={subItem.href} className={cn( 'block select-none rounded-md p-2 text-sm leading-none no-underline outline-none transition-colors hover:bg-accent hover:text-accent-foreground focus:bg-accent focus:text-accent-foreground', - isActive(subItem.href) && 'bg-accent/50' + isActive(subItem.href) && + 'bg-accent font-medium text-accent-foreground' )} > {subItem.label} @@ -75,7 +76,8 @@ export function TopNav() { to={item.href} className={cn( navigationMenuTriggerStyle(), - isActive(item.href) && 'bg-accent/50' + isActive(item.href) && + 'bg-primary text-primary-foreground hover:bg-primary hover:text-primary-foreground focus:bg-primary focus:text-primary-foreground' )} > {item.label} @@ -121,7 +123,8 @@ export function TopNav() { onClick={() => setMobileMenuOpen(false)} className={cn( 'block rounded-md px-2 py-1.5 text-sm hover:bg-accent', - isActive(subItem.href) && 'bg-accent/50 font-medium' + isActive(subItem.href) && + 'bg-accent font-medium text-accent-foreground' )} > {subItem.label} @@ -135,7 +138,8 @@ export function TopNav() { onClick={() => setMobileMenuOpen(false)} className={cn( 'block rounded-md px-2 py-1.5 text-sm font-medium hover:bg-accent', - isActive(item.href) && 'bg-accent/50' + isActive(item.href) && + 'bg-primary text-primary-foreground hover:bg-primary hover:text-primary-foreground' )} > {item.label} diff --git a/frontend/src/components/ui/card.tsx b/frontend/src/components/ui/card.tsx index 681ad980..f5c4d945 100644 --- a/frontend/src/components/ui/card.tsx +++ b/frontend/src/components/ui/card.tsx @@ -7,7 +7,7 @@ function Card({ className, ...props }: React.ComponentProps<"div">) {
    ) { return ( ) diff --git a/frontend/src/hooks/index.ts b/frontend/src/hooks/index.ts index 43fdc939..1c47074d 100644 --- a/frontend/src/hooks/index.ts +++ b/frontend/src/hooks/index.ts @@ -3,9 +3,12 @@ export * from './use-products' export * from './use-kpis' export * from './use-drilldowns' export * from './use-timeseries' +export * from './use-inventory' export * from './use-lifecycle-curve' export * from './use-runs' export * from './use-jobs' +export * from './use-ops' +export * from './use-scenarios' export * from './use-rag-sources' export * from './use-websocket' export * from './use-seeder' diff --git a/frontend/src/hooks/use-explanations.ts b/frontend/src/hooks/use-explanations.ts new file mode 100644 index 00000000..d6ac799f --- /dev/null +++ b/frontend/src/hooks/use-explanations.ts @@ -0,0 +1,48 @@ +import { useMutation, useQuery } from '@tanstack/react-query' +import { api } from '@/lib/api' +import type { ForecastExplanation } from '@/types/api' + +/** + * Explain a registry model run. Disabled until `runId` is set. `retry: false` + * because a 404 (no run) or 400 (non-baseline run) is a final answer, not a + * transient failure. + */ +export function useRunExplanation(runId: string, enabled = true) { + return useQuery({ + queryKey: ['explanations', 'run', runId], + queryFn: () => api(`/explain/runs/${runId}`), + enabled: enabled && !!runId, + retry: false, + }) +} + +/** + * Explain a completed predict job. Disabled until `jobId` is set; `retry: false` + * for the same reason as {@link useRunExplanation}. + */ +export function useJobExplanation(jobId: string, enabled = true) { + return useQuery({ + queryKey: ['explanations', 'job', jobId], + queryFn: () => api(`/explain/jobs/${jobId}`), + enabled: enabled && !!jobId, + retry: false, + }) +} + +/** Request body for POST /explain/forecast. 
*/ +export interface ExplainForecastBody { + store_id: number + product_id: number + model_type: 'naive' | 'seasonal_naive' | 'moving_average' + as_of_date: string // ISO date + season_length?: number + window_size?: number +} + +/** Run an ad-hoc forecast explanation. */ +export function useExplainForecast() { + return useMutation({ + mutationFn: (body: ExplainForecastBody) => + api('/explain/forecast', { method: 'POST', body }), + }) +} diff --git a/frontend/src/hooks/use-ops.ts b/frontend/src/hooks/use-ops.ts new file mode 100644 index 00000000..33141d64 --- /dev/null +++ b/frontend/src/hooks/use-ops.ts @@ -0,0 +1,47 @@ +import { useQuery } from '@tanstack/react-query' +import { api } from '@/lib/api' +import type { + ModelHealthResponse, + OpsSummaryResponse, + RetrainingCandidatesResponse, +} from '@/types/api' + +/** + * Operational summary for the Control Center. Polled every 15s — job/run state + * changes quickly. The global query client already disables refetch-on-focus, + * so this will not double-fire when the tab regains focus. + */ +export function useOpsSummary(enabled = true) { + return useQuery({ + queryKey: ['ops', 'summary'], + queryFn: () => api('/ops/summary'), + refetchInterval: 15000, + enabled, + }) +} + +/** + * Ranked retraining-candidate queue. Deliberately NOT polled — the queue only + * changes when a new run lands, so refetch-on-mount is sufficient. + */ +export function useRetrainingCandidates(limit = 20, enabled = true) { + return useQuery({ + queryKey: ['ops', 'retraining', limit], + queryFn: () => + api('/ops/retraining-candidates', { params: { limit } }), + enabled, + }) +} + +/** + * Per-(store, product) forecast-error health and drift. Deliberately NOT + * polled — drift is slow-moving and only changes when a new run lands, so + * refetch-on-mount is sufficient (mirrors useRetrainingCandidates). 
+ */ +export function useModelHealth(limit = 20, enabled = true) { + return useQuery({ + queryKey: ['ops', 'model-health', limit], + queryFn: () => api('/ops/model-health', { params: { limit } }), + enabled, + }) +} diff --git a/frontend/src/hooks/use-rag-sources.ts b/frontend/src/hooks/use-rag-sources.ts index ee80cc93..55902285 100644 --- a/frontend/src/hooks/use-rag-sources.ts +++ b/frontend/src/hooks/use-rag-sources.ts @@ -4,6 +4,8 @@ import type { SourceListResponse, IndexDocumentRequest, IndexDocumentResponse, + IndexProjectDocsRequest, + IndexProjectDocsResponse, RetrieveRequest, RetrieveResponse, } from '@/types/api' @@ -40,6 +42,23 @@ export function useIndexDocument() { }) } +// Mutation: bulk-index the repo's bundled project docs (POST /rag/index/project-docs). +// Synchronous server-side — the first run can take ~1-3 min with a real embedding +// provider. Invalidates ['rag-sources'] so the list + counts refresh on completion. +export function useIndexProjectDocs() { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (body: IndexProjectDocsRequest) => + api('/rag/index/project-docs', { + method: 'POST', + body, + }), + onSuccess: () => { + void queryClient.invalidateQueries({ queryKey: ['rag-sources'] }) + }, + }) +} + // Mutation: semantic search over the knowledge base (POST /rag/retrieve). // Search results are ephemeral — no cache invalidation. A 502 (no embedding // provider configured) surfaces as an ApiError the caller degrades gracefully. 
diff --git a/frontend/src/hooks/use-scenarios.ts b/frontend/src/hooks/use-scenarios.ts new file mode 100644 index 00000000..198165cc --- /dev/null +++ b/frontend/src/hooks/use-scenarios.ts @@ -0,0 +1,82 @@ +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query' +import { api } from '@/lib/api' +import type { + CompareScenariosRequest, + CreateScenarioRequest, + MultiScenarioComparison, + ScenarioComparison, + ScenarioListResponse, + ScenarioPlanResponse, + SimulateScenarioRequest, +} from '@/types/api' + +/** + * Run a stateless what-if simulation. A mutation, not a query — each run is an + * explicit user action and the result is held in component state. + */ +export function useSimulateScenario() { + return useMutation({ + mutationFn: (data: SimulateScenarioRequest) => + api('/scenarios/simulate', { method: 'POST', body: data }), + }) +} + +/** + * List saved scenario plans, newest first. Pass one or more `tags` to filter + * to plans carrying every listed tag. + */ +export function useScenarios(tags: string[] = [], enabled = true) { + const query = + tags.length > 0 + ? `?${tags.map((tag) => `tags=${encodeURIComponent(tag)}`).join('&')}` + : '' + return useQuery({ + queryKey: ['scenarios', { tags }], + queryFn: () => api(`/scenarios${query}`), + enabled, + }) +} + +/** Fetch one saved plan, including its embedded comparison snapshot. */ +export function useScenario(scenarioId: string, enabled = true) { + return useQuery({ + queryKey: ['scenarios', scenarioId], + queryFn: () => api(`/scenarios/${scenarioId}`), + enabled: enabled && !!scenarioId, + }) +} + +/** Persist a scenario plan; invalidates the saved-plans list on success. 
*/ +export function useCreateScenario() { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (data: CreateScenarioRequest) => + api('/scenarios', { method: 'POST', body: data }), + onSuccess: () => { + void queryClient.invalidateQueries({ queryKey: ['scenarios'] }) + }, + }) +} + +/** Delete a saved scenario plan; invalidates the saved-plans list on success. */ +export function useDeleteScenario() { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (scenarioId: string) => + api(`/scenarios/${scenarioId}`, { method: 'DELETE' }), + onSuccess: () => { + void queryClient.invalidateQueries({ queryKey: ['scenarios'] }) + }, + }) +} + +/** + * Compare 2-5 saved scenario plans. A mutation, not a query — the comparison + * is an explicit user action and the result is held in component state. + */ +export function useCompareScenarios() { + return useMutation({ + mutationFn: (data: CompareScenariosRequest) => + api('/scenarios/compare', { method: 'POST', body: data }), + }) +} diff --git a/frontend/src/index.css b/frontend/src/index.css index 3b8289ce..c2eeafad 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -3,25 +3,6 @@ @custom-variant dark (&:is(.dark *)); -/* shadcn/ui chart color variables */ -@layer base { - :root { - --chart-1: 221.2 83.2% 53.3%; - --chart-2: 142.1 76.2% 36.3%; - --chart-3: 47.9 95.8% 53.1%; - --chart-4: 24.6 95% 53.1%; - --chart-5: 280.1 93.6% 53.1%; - } - - .dark { - --chart-1: 217.2 91.2% 59.8%; - --chart-2: 142.1 70.6% 45.3%; - --chart-3: 47.9 95.8% 53.1%; - --chart-4: 24.6 95% 53.1%; - --chart-5: 280.1 93.6% 53.1%; - } -} - @theme inline { --radius-sm: calc(var(--radius) - 4px); --radius-md: calc(var(--radius) - 2px); @@ -30,6 +11,10 @@ --radius-2xl: calc(var(--radius) + 8px); --radius-3xl: calc(var(--radius) + 12px); --radius-4xl: calc(var(--radius) + 16px); + /* Soft, cool-tinted elevation shadow — gives white cards real lift. 
*/ + --shadow-card: + 0 1px 2px -1px oklch(0.45 0.05 245 / 0.1), + 0 4px 12px -3px oklch(0.45 0.05 245 / 0.12); --color-background: var(--background); --color-foreground: var(--foreground); --color-card: var(--card); @@ -45,6 +30,13 @@ --color-accent: var(--accent); --color-accent-foreground: var(--accent-foreground); --color-destructive: var(--destructive); + --color-destructive-foreground: var(--destructive-foreground); + --color-success: var(--success); + --color-success-foreground: var(--success-foreground); + --color-warning: var(--warning); + --color-warning-foreground: var(--warning-foreground); + --color-info: var(--info); + --color-info-foreground: var(--info-foreground); --color-border: var(--border); --color-input: var(--input); --color-ring: var(--ring); @@ -63,73 +55,97 @@ --color-sidebar-ring: var(--sidebar-ring); } +/* + * ForecastLabAI theme — ocean-blue brand (OKLCH hue 245). + * All foreground/background and chart/card pairs verified to clear + * WCAG AA (body text >= 4.5:1, large text & UI >= 3:1) in both themes. + */ :root { --radius: 0.625rem; - --background: oklch(1 0 0); - --foreground: oklch(0.145 0 0); + /* Cool blue-gray canvas so the pure-white cards visibly lift off it. 
*/ + --background: oklch(0.96 0.013 245); + --foreground: oklch(0.21 0.02 245); --card: oklch(1 0 0); - --card-foreground: oklch(0.145 0 0); + --card-foreground: oklch(0.21 0.02 245); --popover: oklch(1 0 0); - --popover-foreground: oklch(0.145 0 0); - --primary: oklch(0.205 0 0); + --popover-foreground: oklch(0.21 0.02 245); + --primary: oklch(0.5 0.18 245); --primary-foreground: oklch(0.985 0 0); - --secondary: oklch(0.97 0 0); - --secondary-foreground: oklch(0.205 0 0); - --muted: oklch(0.97 0 0); - --muted-foreground: oklch(0.556 0 0); - --accent: oklch(0.97 0 0); - --accent-foreground: oklch(0.205 0 0); + --secondary: oklch(0.948 0.014 245); + --secondary-foreground: oklch(0.28 0.03 245); + --muted: oklch(0.948 0.014 245); + --muted-foreground: oklch(0.47 0.03 245); + --accent: oklch(0.92 0.05 245); + --accent-foreground: oklch(0.32 0.09 245); --destructive: oklch(0.577 0.245 27.325); - --border: oklch(0.922 0 0); - --input: oklch(0.922 0 0); - --ring: oklch(0.708 0 0); - --chart-1: oklch(0.646 0.222 41.116); - --chart-2: oklch(0.6 0.118 184.704); - --chart-3: oklch(0.398 0.07 227.392); - --chart-4: oklch(0.828 0.189 84.429); - --chart-5: oklch(0.769 0.188 70.08); - --sidebar: oklch(0.985 0 0); - --sidebar-foreground: oklch(0.145 0 0); - --sidebar-primary: oklch(0.205 0 0); + --destructive-foreground: oklch(0.985 0 0); + --success: oklch(0.55 0.13 150); + --success-foreground: oklch(0.985 0 0); + --warning: oklch(0.55 0.13 70); + --warning-foreground: oklch(0.99 0.01 75); + --info: oklch(0.54 0.15 245); + --info-foreground: oklch(0.985 0 0); + --border: oklch(0.87 0.02 245); + --input: oklch(0.87 0.02 245); + --ring: oklch(0.5 0.18 245); + /* Light-only muted table-header fill (dark overrides this to transparent). 
*/ + --table-header: var(--muted); + --chart-1: oklch(0.5 0.18 245); + --chart-2: oklch(0.55 0.11 195); + --chart-3: oklch(0.63 0.16 65); + --chart-4: oklch(0.55 0.2 12); + --chart-5: oklch(0.5 0.19 305); + --sidebar: oklch(0.985 0.004 245); + --sidebar-foreground: oklch(0.21 0.02 245); + --sidebar-primary: oklch(0.5 0.18 245); --sidebar-primary-foreground: oklch(0.985 0 0); - --sidebar-accent: oklch(0.97 0 0); - --sidebar-accent-foreground: oklch(0.205 0 0); - --sidebar-border: oklch(0.922 0 0); - --sidebar-ring: oklch(0.708 0 0); + --sidebar-accent: oklch(0.92 0.05 245); + --sidebar-accent-foreground: oklch(0.32 0.09 245); + --sidebar-border: oklch(0.87 0.02 245); + --sidebar-ring: oklch(0.5 0.18 245); } .dark { - --background: oklch(0.145 0 0); - --foreground: oklch(0.985 0 0); - --card: oklch(0.205 0 0); - --card-foreground: oklch(0.985 0 0); - --popover: oklch(0.205 0 0); - --popover-foreground: oklch(0.985 0 0); - --primary: oklch(0.922 0 0); - --primary-foreground: oklch(0.205 0 0); - --secondary: oklch(0.269 0 0); - --secondary-foreground: oklch(0.985 0 0); - --muted: oklch(0.269 0 0); - --muted-foreground: oklch(0.708 0 0); - --accent: oklch(0.269 0 0); - --accent-foreground: oklch(0.985 0 0); + --background: oklch(0.18 0.012 245); + --foreground: oklch(0.97 0.004 245); + --card: oklch(0.22 0.014 245); + --card-foreground: oklch(0.97 0.004 245); + --popover: oklch(0.22 0.014 245); + --popover-foreground: oklch(0.97 0.004 245); + --primary: oklch(0.52 0.18 245); + --primary-foreground: oklch(0.985 0 0); + --secondary: oklch(0.27 0.015 245); + --secondary-foreground: oklch(0.97 0.004 245); + --muted: oklch(0.27 0.015 245); + --muted-foreground: oklch(0.72 0.015 245); + --accent: oklch(0.31 0.03 245); + --accent-foreground: oklch(0.97 0.004 245); --destructive: oklch(0.704 0.191 22.216); - --border: oklch(1 0 0 / 10%); - --input: oklch(1 0 0 / 15%); - --ring: oklch(0.556 0 0); - --chart-1: oklch(0.488 0.243 264.376); - --chart-2: oklch(0.696 0.17 162.48); - 
--chart-3: oklch(0.769 0.188 70.08); - --chart-4: oklch(0.627 0.265 303.9); - --chart-5: oklch(0.645 0.246 16.439); - --sidebar: oklch(0.205 0 0); - --sidebar-foreground: oklch(0.985 0 0); - --sidebar-primary: oklch(0.488 0.243 264.376); + --destructive-foreground: oklch(0.18 0.02 25); + --success: oklch(0.7 0.14 155); + --success-foreground: oklch(0.18 0.02 155); + --warning: oklch(0.8 0.14 80); + --warning-foreground: oklch(0.2 0.03 80); + --info: oklch(0.65 0.15 245); + --info-foreground: oklch(0.18 0.02 245); + --border: oklch(0.32 0.012 245); + --input: oklch(0.34 0.014 245); + --ring: oklch(0.55 0.16 245); + /* Preserves the existing dark table header (no fill) — do not change. */ + --table-header: transparent; + --chart-1: oklch(0.65 0.16 245); + --chart-2: oklch(0.7 0.11 195); + --chart-3: oklch(0.78 0.15 70); + --chart-4: oklch(0.68 0.18 12); + --chart-5: oklch(0.66 0.16 305); + --sidebar: oklch(0.22 0.014 245); + --sidebar-foreground: oklch(0.97 0.004 245); + --sidebar-primary: oklch(0.65 0.16 245); --sidebar-primary-foreground: oklch(0.985 0 0); - --sidebar-accent: oklch(0.269 0 0); - --sidebar-accent-foreground: oklch(0.985 0 0); - --sidebar-border: oklch(1 0 0 / 10%); - --sidebar-ring: oklch(0.556 0 0); + --sidebar-accent: oklch(0.31 0.03 245); + --sidebar-accent-foreground: oklch(0.97 0.004 245); + --sidebar-border: oklch(0.32 0.012 245); + --sidebar-ring: oklch(0.55 0.16 245); } @layer base { diff --git a/frontend/src/lib/api.test.ts b/frontend/src/lib/api.test.ts new file mode 100644 index 00000000..f09dee8a --- /dev/null +++ b/frontend/src/lib/api.test.ts @@ -0,0 +1,52 @@ +import { afterEach, describe, expect, it, vi } from 'vitest' +import { ApiError, api } from './api' + +/** Build a fake `fetch` that returns one canned `Response`. 
*/ +function stubFetch(body: string, init: ResponseInit) { + const fetchMock = vi.fn().mockResolvedValue(new Response(body, init)) + vi.stubGlobal('fetch', fetchMock) + return fetchMock +} + +afterEach(() => { + vi.unstubAllGlobals() +}) + +describe('api()', () => { + it('parses an RFC 7807 application/problem+json error body into ApiError.detail', async () => { + // Regression: api() previously only treated `application/json` as JSON, so + // `application/problem+json` error bodies went unparsed and the raw JSON + // string leaked into the UI via getErrorMessage(). + const problem = { + type: '/errors/bad-request', + title: 'Bad Request', + status: 400, + detail: 'Need at least 7 observations', + code: 'BAD_REQUEST', + } + stubFetch(JSON.stringify(problem), { + status: 400, + headers: { 'content-type': 'application/problem+json' }, + }) + + const err = await api('/explain/forecast', { method: 'POST', body: {} }).catch( + (e: unknown) => e, + ) + + expect(err).toBeInstanceOf(ApiError) + expect((err as ApiError).status).toBe(400) + expect((err as ApiError).message).toBe('Need at least 7 observations') + expect((err as ApiError).detail?.detail).toBe('Need at least 7 observations') + }) + + it('parses a plain application/json success body', async () => { + stubFetch(JSON.stringify({ status: 'ok' }), { + status: 200, + headers: { 'content-type': 'application/json' }, + }) + + const data = await api<{ status: string }>('/health') + + expect(data).toEqual({ status: 'ok' }) + }) +}) diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index c3eeec81..bc6ddaed 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -48,7 +48,11 @@ export async function api(endpoint: string, config: RequestConfig = {}): Prom const contentType = response.headers.get('content-type') || '' const rawBody = await response.text() - const isJson = contentType.includes('application/json') + // RFC 7807 error responses use `application/problem+json`, so match the + // `+json` 
structured-syntax suffix as well as plain `application/json` -- + // otherwise error bodies go unparsed and `ApiError.detail` is always empty. + const isJson = + contentType.includes('application/json') || contentType.includes('+json') let data: unknown = undefined if (rawBody && isJson) { diff --git a/frontend/src/lib/constants.ts b/frontend/src/lib/constants.ts index d802c664..64f60654 100644 --- a/frontend/src/lib/constants.ts +++ b/frontend/src/lib/constants.ts @@ -2,6 +2,7 @@ export const ROUTES = { DASHBOARD: '/', SHOWCASE: '/showcase', + OPS: '/ops', EXPLORER: { SALES: '/explorer/sales', STORES: '/explorer/stores', @@ -22,6 +23,7 @@ export const ROUTES = { FORECAST: '/visualize/forecast', BACKTEST: '/visualize/backtest', DEMAND: '/visualize/demand', + PLANNER: '/visualize/planner', }, KNOWLEDGE: '/knowledge', CHAT: '/chat', @@ -33,6 +35,7 @@ export const ROUTES = { export const NAV_ITEMS = [ { label: 'Dashboard', href: ROUTES.DASHBOARD }, { label: 'Showcase', href: ROUTES.SHOWCASE }, + { label: 'Control Center', href: ROUTES.OPS }, { label: 'Explorer', items: [ @@ -47,6 +50,7 @@ export const NAV_ITEMS = [ label: 'Visualize', items: [ { label: 'Demand Planner', href: ROUTES.VISUALIZE.DEMAND }, + { label: 'What-If Planner', href: ROUTES.VISUALIZE.PLANNER }, { label: 'Forecast', href: ROUTES.VISUALIZE.FORECAST }, { label: 'Backtest Results', href: ROUTES.VISUALIZE.BACKTEST }, ], diff --git a/frontend/src/lib/csv-export.test.ts b/frontend/src/lib/csv-export.test.ts index 4112527c..7678f3c8 100644 --- a/frontend/src/lib/csv-export.test.ts +++ b/frontend/src/lib/csv-export.test.ts @@ -53,4 +53,17 @@ describe('toCsv', () => { ] expect(toCsv([{ a: null, b: undefined }], nullableColumns)).toBe('A,B\r\n,') }) + + it('neutralizes CSV formula injection by prefixing a quote', () => { + interface AttackRow { + val: string + } + const attackColumns: CsvColumn[] = [{ key: 'val', header: 'Val' }] + expect(toCsv([{ val: '=SUM(A1:A2)' }], 
attackColumns)).toBe("Val\r\n'=SUM(A1:A2)") + expect(toCsv([{ val: '+1+1' }], attackColumns)).toBe("Val\r\n'+1+1") + expect(toCsv([{ val: '-2+3' }], attackColumns)).toBe("Val\r\n'-2+3") + expect(toCsv([{ val: '@cmd' }], attackColumns)).toBe("Val\r\n'@cmd") + // A neutralized value that also needs RFC 4180 quoting is still wrapped. + expect(toCsv([{ val: '=1,2' }], attackColumns)).toBe('Val\r\n"\'=1,2"') + }) }) diff --git a/frontend/src/lib/csv-export.ts b/frontend/src/lib/csv-export.ts index 692cbdc7..3ceb8737 100644 --- a/frontend/src/lib/csv-export.ts +++ b/frontend/src/lib/csv-export.ts @@ -6,7 +6,13 @@ export interface CsvColumn { /** Quote a single CSV field per RFC 4180 (wrap + double internal quotes). */ function quoteField(value: unknown): string { - const str = value === null || value === undefined ? '' : String(value) + let str = value === null || value === undefined ? '' : String(value) + // CSV formula injection: a spreadsheet executes a cell whose value begins + // with =, +, -, @, or a control char (tab / CR). Prefix a single quote so + // the value is rendered as literal text instead of an evaluated formula. 
+ if (/^[=+\-@\t\r]/.test(str)) { + str = `'${str}` + } if (/[",\r\n]/.test(str)) { return `"${str.replace(/"/g, '""')}"` } diff --git a/frontend/src/lib/demand-utils.test.ts b/frontend/src/lib/demand-utils.test.ts index e4294481..6f877553 100644 --- a/frontend/src/lib/demand-utils.test.ts +++ b/frontend/src/lib/demand-utils.test.ts @@ -122,6 +122,11 @@ describe('inventoryRequirement', () => { it('treats a null on-order as zero', () => { expect(inventoryRequirement(100, 30, null)).toBe(70) }) + + it('rounds a fractional shortfall up to avoid under-ordering', () => { + // shortfall = 10.4 - 5 - 0 = 5.4 → ceil → 6 (Math.round would yield 5) + expect(inventoryRequirement(10.4, 5, 0)).toBe(6) + }) }) describe('extractForecasts', () => { @@ -186,6 +191,20 @@ describe('joinDemandRows', () => { expect(row.inventoryRequirement).toBeNull() }) + it('keeps the latest inventory snapshot per grain by date', () => { + const job = makePredictJob('j1', { + store_id: 1, + product_id: 10, + forecasts: flatForecasts(7, 1), + }) + const older: InventoryStatusItem = { ...makeInventory(1, 10, 99, 0), date: '2026-01-01' } + const newer: InventoryStatusItem = { ...makeInventory(1, 10, 12, 3), date: '2026-02-01' } + // `newer` listed first — a naive Map-from-entries would keep `older`. + const [row] = joinDemandRows([job], products, [newer, older], 7) + expect(row.onHand).toBe(12) + expect(row.onOrder).toBe(3) + }) + it('falls back to a #id SKU when the product is unknown', () => { const job = makePredictJob('j1', { store_id: 1, diff --git a/frontend/src/lib/demand-utils.ts b/frontend/src/lib/demand-utils.ts index 8be1b9fc..fc475269 100644 --- a/frontend/src/lib/demand-utils.ts +++ b/frontend/src/lib/demand-utils.ts @@ -43,7 +43,9 @@ export function inventoryRequirement( onOrder: number | null, ): number | null { if (onHand === null) return null - return Math.max(0, Math.round(leadTimeDemand - onHand - (onOrder ?? 
0))) + // Round up: a fractional shortfall still needs a whole extra unit ordered — + // Math.round would under-order when the fraction is below 0.5. + return Math.max(0, Math.ceil(leadTimeDemand - onHand - (onOrder ?? 0))) } /** @@ -86,9 +88,17 @@ export function joinDemandRows( leadTimeDays: number, ): DemandRow[] { const productById = new Map(products.map((product) => [product.id, product])) - const inventoryByGrain = new Map( - inventory.map((item) => [`${item.store_id}:${item.product_id}`, item]), - ) + // Keep the latest snapshot per grain. The API returns one row per grain, but + // a naive Map-from-entries would silently keep whichever row is last in the + // array — pick by `date` so order in the response can never matter. + const inventoryByGrain = new Map() + for (const item of inventory) { + const key = `${item.store_id}:${item.product_id}` + const existing = inventoryByGrain.get(key) + if (!existing || item.date > existing.date) { + inventoryByGrain.set(key, item) + } + } const rows: DemandRow[] = [] for (const job of predictJobs) { diff --git a/frontend/src/lib/incident-report.test.ts b/frontend/src/lib/incident-report.test.ts new file mode 100644 index 00000000..7b80a66a --- /dev/null +++ b/frontend/src/lib/incident-report.test.ts @@ -0,0 +1,138 @@ +import { describe, it, expect } from 'vitest' +import { attentionCsvColumns, buildIncidentMarkdown } from './incident-report' +import type { + AttentionItem, + ModelHealthEntry, + OpsSummaryResponse, + RetrainingCandidate, +} from '@/types/api' + +/** Build an OpsSummaryResponse with sensible defaults for fields not under test. 
*/ +function makeSummary(overrides: Partial = {}): OpsSummaryResponse { + return { + system: { + api_ok: true, + database_connected: true, + latest_successful_job_at: '2026-05-19T10:00:00Z', + }, + jobs: { counts: [], completed_today: 2, failed_total: 1, active_total: 3 }, + runs: { counts: [], success_rate: 0.8, failed_total: 1 }, + aliases: [], + freshness: { + latest_sales_date: '2026-05-18', + latest_job_completed_at: '2026-05-19T09:00:00Z', + latest_run_completed_at: '2026-05-19T08:00:00Z', + }, + attention_items: [], + generated_at: '2026-05-19T12:00:00Z', + ...overrides, + } +} + +/** Build an AttentionItem with sensible defaults. */ +function makeAttentionItem( + partial: Partial & Pick, +): AttentionItem { + return { + item_type: partial.item_type, + entity_id: partial.entity_id ?? 'e1', + label: partial.label ?? 'label', + detail: partial.detail ?? 'detail', + occurred_at: partial.occurred_at ?? null, + } +} + +/** Build a ModelHealthEntry with sensible defaults. */ +function makeHealthEntry(partial: Partial = {}): ModelHealthEntry { + return { + store_id: partial.store_id ?? 1, + product_id: partial.product_id ?? 2, + run_count: partial.run_count ?? 3, + latest_run_id: partial.latest_run_id ?? 'r1', + latest_run_status: partial.latest_run_status ?? 'success', + latest_wape: partial.latest_wape ?? 25, + previous_wape: partial.previous_wape ?? 11, + wape_delta: partial.wape_delta ?? 14, + drift_direction: partial.drift_direction ?? 'degrading', + last_trained_at: partial.last_trained_at ?? null, + staleness_days: partial.staleness_days ?? 10, + wape_history: partial.wape_history ?? [], + } +} + +/** Build a RetrainingCandidate with sensible defaults. */ +function makeCandidate(partial: Partial = {}): RetrainingCandidate { + return { + store_id: partial.store_id ?? 1, + product_id: partial.product_id ?? 2, + priority_score: partial.priority_score ?? 0.75, + staleness_days: partial.staleness_days ?? 30, + wape: partial.wape ?? 
12, + latest_run_id: partial.latest_run_id ?? 'r1', + latest_run_status: partial.latest_run_status ?? 'success', + reason: partial.reason ?? 'reason', + } +} + +describe('attentionCsvColumns', () => { + it('exposes the five attention-item columns in order', () => { + expect(attentionCsvColumns.map((column) => column.key)).toEqual([ + 'item_type', + 'entity_id', + 'label', + 'detail', + 'occurred_at', + ]) + }) +}) + +describe('buildIncidentMarkdown', () => { + it('renders the report title and generated timestamp', () => { + const md = buildIncidentMarkdown(makeSummary(), [], []) + expect(md).toContain('# ForecastOps Incident Report') + expect(md).toContain('_Generated 2026-05-19T12:00:00Z_') + }) + + it('renders KPI lines from the summary', () => { + const md = buildIncidentMarkdown(makeSummary(), [], []) + expect(md).toContain('- Active jobs: 3') + expect(md).toContain('- Run success rate: 80.0%') + }) + + it('shows the empty-state line when nothing needs attention', () => { + const md = buildIncidentMarkdown(makeSummary(), [], []) + expect(md).toContain('## Needs Attention (0)') + expect(md).toContain('_Nothing needs attention._') + }) + + it('renders an attention table row and escapes pipe characters', () => { + const summary = makeSummary({ + attention_items: [ + makeAttentionItem({ item_type: 'failed_job', label: 'train', detail: 'a | b' }), + ], + }) + const md = buildIncidentMarkdown(summary, [], []) + expect(md).toContain('## Needs Attention (1)') + expect(md).toContain('a \\| b') + }) + + it('renders the model-health drift section with a signed delta', () => { + const md = buildIncidentMarkdown(makeSummary(), [], [ + makeHealthEntry({ drift_direction: 'degrading', wape_delta: 14 }), + ]) + expect(md).toContain('## Model Health — Drift (1)') + expect(md).toContain('degrading') + expect(md).toContain('+14.0') + }) + + it('renders the retraining-candidates section', () => { + const md = buildIncidentMarkdown(makeSummary(), [makeCandidate({ priority_score: 
0.91 })], []) + expect(md).toContain('## Top Retraining Candidates (1)') + expect(md).toContain('0.91') + }) + + it('handles a null success rate', () => { + const summary = makeSummary({ runs: { counts: [], success_rate: null, failed_total: 0 } }) + expect(buildIncidentMarkdown(summary, [], [])).toContain('- Run success rate: —') + }) +}) diff --git a/frontend/src/lib/incident-report.ts b/frontend/src/lib/incident-report.ts new file mode 100644 index 00000000..d04e6783 --- /dev/null +++ b/frontend/src/lib/incident-report.ts @@ -0,0 +1,140 @@ +// Builders for the ForecastOps incident report — a client-side CSV + Markdown +// export of the operational snapshot already loaded on the /ops page. The +// builders take no I/O and are unit-tested (incident-report.test.ts); +// downloadMarkdown is the one DOM-touching helper (mirrors csv-export.ts). +import type { CsvColumn } from '@/lib/csv-export' +import { formatWapeDelta } from '@/lib/ops-utils' +import type { + AttentionItem, + ModelHealthEntry, + OpsSummaryResponse, + RetrainingCandidate, +} from '@/types/api' + +/** CSV column set for the attention-items export (feed to toCsv / downloadCsv). */ +export const attentionCsvColumns: CsvColumn[] = [ + { key: 'item_type', header: 'Type' }, + { key: 'entity_id', header: 'Entity' }, + { key: 'label', header: 'Item' }, + { key: 'detail', header: 'Detail' }, + { key: 'occurred_at', header: 'When' }, +] + +/** Render a value for a Markdown table cell: '—' for empty, pipes/newlines neutralised. */ +function mdCell(value: string | number | null | undefined): string { + if (value === null || value === undefined || value === '') return '—' + return String(value) + .replace(/\|/g, '\\|') + .replace(/[\r\n]+/g, ' ') +} + +/** Percentage display for the run success rate; '—' when null. */ +function pct(rate: number | null): string { + return rate === null ? '—' : `${(rate * 100).toFixed(1)}%` +} + +/** One-decimal WAPE display; '—' when null. 
*/ +function wape(value: number | null): string { + return value === null ? '—' : value.toFixed(1) +} + +/** + * Build a human-readable Markdown incident report from already-loaded /ops + * page data. Pure: no fetch, no DOM, deterministic for a given input — the + * timestamps are emitted verbatim so the output is stable for unit tests. + */ +export function buildIncidentMarkdown( + summary: OpsSummaryResponse, + candidates: RetrainingCandidate[], + modelHealth: ModelHealthEntry[], +): string { + const staleAliases = summary.aliases.filter((alias) => alias.is_stale).length + const lines: string[] = [ + '# ForecastOps Incident Report', + '', + `_Generated ${summary.generated_at}_`, + '', + '## System Health', + '', + `- API: ${summary.system.api_ok ? 'ok' : 'down'}`, + `- Database: ${summary.system.database_connected ? 'connected' : 'down'}`, + `- Latest successful job: ${summary.system.latest_successful_job_at ?? '—'}`, + '', + '## KPIs', + '', + `- Active jobs: ${summary.jobs.active_total}`, + `- Failed jobs: ${summary.jobs.failed_total}`, + `- Completed today: ${summary.jobs.completed_today}`, + `- Run success rate: ${pct(summary.runs.success_rate)}`, + `- Failed runs: ${summary.runs.failed_total}`, + `- Stale aliases: ${staleAliases} of ${summary.aliases.length}`, + '', + '## Data Freshness', + '', + `- Latest sales date: ${summary.freshness.latest_sales_date ?? '—'}`, + `- Latest completed job: ${summary.freshness.latest_job_completed_at ?? '—'}`, + `- Latest successful run: ${summary.freshness.latest_run_completed_at ?? 
'—'}`, + '', + `## Needs Attention (${summary.attention_items.length})`, + '', + ] + + if (summary.attention_items.length === 0) { + lines.push('_Nothing needs attention._', '') + } else { + lines.push('| Type | Item | Detail | When |', '| --- | --- | --- | --- |') + for (const item of summary.attention_items) { + lines.push( + `| ${mdCell(item.item_type)} | ${mdCell(item.label)} | ${mdCell(item.detail)} | ${mdCell(item.occurred_at)} |`, + ) + } + lines.push('') + } + + lines.push(`## Model Health — Drift (${modelHealth.length})`, '') + if (modelHealth.length === 0) { + lines.push('_No model health to evaluate._', '') + } else { + lines.push( + '| Store | Product | Drift | Latest WAPE | Δ WAPE | Runs |', + '| --- | --- | --- | --- | --- | --- |', + ) + for (const entry of modelHealth) { + lines.push( + `| ${mdCell(entry.store_id)} | ${mdCell(entry.product_id)} | ${mdCell(entry.drift_direction)} | ${wape(entry.latest_wape)} | ${formatWapeDelta(entry.wape_delta)} | ${mdCell(entry.run_count)} |`, + ) + } + lines.push('') + } + + lines.push(`## Top Retraining Candidates (${candidates.length})`, '') + if (candidates.length === 0) { + lines.push('_No retraining candidates._', '') + } else { + lines.push( + '| Store | Product | Priority | Staleness (days) | WAPE | Reason |', + '| --- | --- | --- | --- | --- | --- |', + ) + for (const candidate of candidates) { + lines.push( + `| ${mdCell(candidate.store_id)} | ${mdCell(candidate.product_id)} | ${candidate.priority_score.toFixed(2)} | ${mdCell(candidate.staleness_days)} | ${wape(candidate.wape)} | ${mdCell(candidate.reason)} |`, + ) + } + lines.push('') + } + + return lines.join('\n') +} + +/** Trigger a browser download of `content` as a Markdown file. 
*/ +export function downloadMarkdown(filename: string, content: string): void { + const blob = new Blob([content], { type: 'text/markdown;charset=utf-8;' }) + const url = URL.createObjectURL(blob) + const link = document.createElement('a') + link.href = url + link.download = filename + document.body.appendChild(link) + link.click() + document.body.removeChild(link) + URL.revokeObjectURL(url) +} diff --git a/frontend/src/lib/ops-actions.test.ts b/frontend/src/lib/ops-actions.test.ts new file mode 100644 index 00000000..a5bc2dcb --- /dev/null +++ b/frontend/src/lib/ops-actions.test.ts @@ -0,0 +1,79 @@ +import { describe, it, expect } from 'vitest' +import { buildRetrainJob } from './ops-actions' +import type { ModelRun } from '@/types/api' + +/** Build a ModelRun with sensible defaults for fields not under test. */ +function makeRun(overrides: Partial = {}): ModelRun { + return { + run_id: 'r1', + status: 'success', + model_type: 'naive', + model_config: { model_type: 'naive', schema_version: '1.0' }, + feature_config: null, + config_hash: 'abc', + data_window_start: '2025-01-01', + data_window_end: '2026-01-01', + store_id: 9, + product_id: 8, + metrics: null, + artifact_uri: null, + artifact_hash: null, + artifact_size_bytes: null, + runtime_info: null, + agent_context: null, + git_sha: null, + error_message: null, + started_at: null, + completed_at: null, + created_at: '2026-01-01T00:00:00Z', + updated_at: '2026-01-01T00:00:00Z', + ...overrides, + } +} + +describe('buildRetrainJob', () => { + it('builds a train job with the flat param contract', () => { + const job = buildRetrainJob(makeRun(), '2026-05-18') + expect(job.job_type).toBe('train') + expect(job.params).toEqual({ + model_type: 'naive', + store_id: 9, + product_id: 8, + start_date: '2025-01-01', + end_date: '2026-05-18', + }) + }) + + it('falls back to the run window end when no latest sales date is known', () => { + expect(buildRetrainJob(makeRun(), null).params.end_date).toBe('2026-01-01') + }) + + 
it('lifts season_length for a seasonal_naive run', () => { + const run = makeRun({ + model_type: 'seasonal_naive', + model_config: { model_type: 'seasonal_naive', schema_version: '1.0', season_length: 7 }, + }) + expect(buildRetrainJob(run, null).params.season_length).toBe(7) + }) + + it('lifts window_size for a moving_average run', () => { + const run = makeRun({ + model_type: 'moving_average', + model_config: { model_type: 'moving_average', schema_version: '1.0', window_size: 14 }, + }) + expect(buildRetrainJob(run, null).params.window_size).toBe(14) + }) + + it('omits model-specific keys when model_config carries none', () => { + const job = buildRetrainJob(makeRun(), null) + expect(job.params).not.toHaveProperty('season_length') + expect(job.params).not.toHaveProperty('window_size') + }) + + it('ignores non-numeric model_config values', () => { + const run = makeRun({ + model_config: { model_type: 'seasonal_naive', season_length: 'weekly' }, + }) + expect(buildRetrainJob(run, null).params).not.toHaveProperty('season_length') + }) +}) diff --git a/frontend/src/lib/ops-actions.ts b/frontend/src/lib/ops-actions.ts new file mode 100644 index 00000000..54723a99 --- /dev/null +++ b/frontend/src/lib/ops-actions.ts @@ -0,0 +1,43 @@ +// Pure builder for the ForecastOps action layer — turns a source model run +// into the POST /jobs body that retrains its (store, product) grain. Kept +// React-free and unit-tested (ops-actions.test.ts). +import type { JobCreate, ModelRun } from '@/types/api' + +/** Read a numeric field from a run's model_config JSONB, or null when absent. */ +function numericConfig(config: Record, key: string): number | null { + const value = config[key] + return typeof value === 'number' ? value : null +} + +/** + * Build the `POST /jobs` body that retrains a grain from its source run. 
+ * + * The train job consumes a FLAT params dict — verified against + * `app/features/jobs/service.py::_execute_train`: `model_type`, `store_id`, + * `product_id`, `start_date`, `end_date`, plus the model-specific + * `season_length` (seasonal_naive) / `window_size` (moving_average) lifted + * from the source run's `model_config`. There is no `period` key. The + * end date is advanced to the freshest available sales date so the retrain + * sees every observation since the original training window. + * + * @param run - The source model run (GET /registry/runs/{latest_run_id}). + * @param latestSalesDate - summary.freshness.latest_sales_date, or null. + */ +export function buildRetrainJob(run: ModelRun, latestSalesDate: string | null): JobCreate { + const params: Record = { + model_type: run.model_type, + store_id: run.store_id, + product_id: run.product_id, + start_date: run.data_window_start, + end_date: latestSalesDate ?? run.data_window_end, + } + const seasonLength = numericConfig(run.model_config, 'season_length') + if (seasonLength !== null) { + params.season_length = seasonLength + } + const windowSize = numericConfig(run.model_config, 'window_size') + if (windowSize !== null) { + params.window_size = windowSize + } + return { job_type: 'train', params } +} diff --git a/frontend/src/lib/ops-utils.test.ts b/frontend/src/lib/ops-utils.test.ts new file mode 100644 index 00000000..a3a49c60 --- /dev/null +++ b/frontend/src/lib/ops-utils.test.ts @@ -0,0 +1,144 @@ +import { describe, it, expect } from 'vitest' +import { + attentionBadgeVariant, + attentionItemLink, + driftBadgeVariant, + formatStaleness, + formatWapeDelta, + sortRetrainingCandidates, + summaryHealthVariant, +} from './ops-utils' +import type { AttentionItem, RetrainingCandidate, SystemHealth } from '@/types/api' + +/** Build an AttentionItem with sensible defaults for fields not under test. 
*/ +function makeItem(partial: Partial & Pick): AttentionItem { + return { + item_type: partial.item_type, + entity_id: partial.entity_id ?? 'entity-1', + label: partial.label ?? 'label', + detail: partial.detail ?? 'detail', + occurred_at: partial.occurred_at ?? null, + } +} + +/** Build a RetrainingCandidate with sensible defaults. */ +function makeCandidate( + partial: Partial & Pick, +): RetrainingCandidate { + return { + store_id: partial.store_id ?? 1, + product_id: partial.product_id ?? 1, + priority_score: partial.priority_score, + staleness_days: partial.staleness_days ?? 0, + wape: partial.wape ?? null, + latest_run_id: partial.latest_run_id ?? 'run-1', + latest_run_status: partial.latest_run_status ?? 'success', + reason: partial.reason ?? 'reason', + } +} + +describe('summaryHealthVariant', () => { + it('is success when API and database are both up', () => { + const system: SystemHealth = { + api_ok: true, + database_connected: true, + latest_successful_job_at: null, + } + expect(summaryHealthVariant(system)).toBe('success') + }) + + it('is error when the database is down', () => { + const system: SystemHealth = { + api_ok: true, + database_connected: false, + latest_successful_job_at: null, + } + expect(summaryHealthVariant(system)).toBe('error') + }) +}) + +describe('attentionItemLink', () => { + it('links a failed job to the job detail page', () => { + expect(attentionItemLink(makeItem({ item_type: 'failed_job', entity_id: 'job-9' }))).toBe( + '/explorer/jobs/job-9', + ) + }) + + it('links a failed run to the run detail page', () => { + expect(attentionItemLink(makeItem({ item_type: 'failed_run', entity_id: 'run-9' }))).toBe( + '/explorer/runs/run-9', + ) + }) + + it('links a stale alias to the run detail page', () => { + expect(attentionItemLink(makeItem({ item_type: 'stale_alias', entity_id: 'run-3' }))).toBe( + '/explorer/runs/run-3', + ) + }) +}) + +describe('attentionBadgeVariant', () => { + it('warns for a stale alias', () => { + 
expect(attentionBadgeVariant('stale_alias')).toBe('warning') + }) + + it('errors for failed jobs and runs', () => { + expect(attentionBadgeVariant('failed_job')).toBe('error') + expect(attentionBadgeVariant('failed_run')).toBe('error') + }) +}) + +describe('formatStaleness', () => { + it('renders a positive day count', () => { + expect(formatStaleness(12)).toBe('12d') + }) + + it('renders "today" at zero or negative days', () => { + expect(formatStaleness(0)).toBe('today') + expect(formatStaleness(-3)).toBe('today') + }) +}) + +describe('sortRetrainingCandidates', () => { + it('sorts by priority score descending', () => { + const sorted = sortRetrainingCandidates([ + makeCandidate({ priority_score: 0.2 }), + makeCandidate({ priority_score: 0.9 }), + makeCandidate({ priority_score: 0.5 }), + ]) + expect(sorted.map((c) => c.priority_score)).toEqual([0.9, 0.5, 0.2]) + }) + + it('does not mutate the input array', () => { + const input = [makeCandidate({ priority_score: 0.1 }), makeCandidate({ priority_score: 0.8 })] + sortRetrainingCandidates(input) + expect(input.map((c) => c.priority_score)).toEqual([0.1, 0.8]) + }) + + it('returns an empty array unchanged', () => { + expect(sortRetrainingCandidates([])).toEqual([]) + }) +}) + +describe('driftBadgeVariant', () => { + it('maps each drift direction to its badge variant', () => { + expect(driftBadgeVariant('degrading')).toBe('error') + expect(driftBadgeVariant('improving')).toBe('success') + expect(driftBadgeVariant('stable')).toBe('info') + expect(driftBadgeVariant('unknown')).toBe('default') + }) +}) + +describe('formatWapeDelta', () => { + it('prefixes a positive delta with +', () => { + expect(formatWapeDelta(14)).toBe('+14.0') + }) + + it('keeps a negative delta sign', () => { + expect(formatWapeDelta(-9.3)).toBe('-9.3') + }) + + it('renders an em dash for a null delta', () => { + expect(formatWapeDelta(null)).toBe('—') + }) +}) diff --git a/frontend/src/lib/ops-utils.ts b/frontend/src/lib/ops-utils.ts new file 
mode 100644 index 00000000..20f48eaa --- /dev/null +++ b/frontend/src/lib/ops-utils.ts @@ -0,0 +1,80 @@ +// Pure, React-free helpers for the ForecastOps Control Center page. Kept +// separate from the page component so they are cheap to unit-test (see +// ops-utils.test.ts) — mirrors the knowledge-utils.ts / status-utils.ts precedent. +import { ROUTES } from '@/lib/constants' +import type { AttentionItem, DriftDirection, RetrainingCandidate, SystemHealth } from '@/types/api' + +/** + * System-health badge variant: 'success' only when the API and database are + * both up, 'error' otherwise. + */ +export function summaryHealthVariant(system: SystemHealth): 'success' | 'error' { + return system.api_ok && system.database_connected ? 'success' : 'error' +} + +/** + * Deep-link an attention item to its Explorer detail page. A failed job links + * to the job detail page; a failed run and a stale alias both carry a run_id + * and link to the run detail page. + */ +export function attentionItemLink(item: AttentionItem): string { + if (item.item_type === 'failed_job') { + return `${ROUTES.EXPLORER.JOBS}/${item.entity_id}` + } + return `${ROUTES.EXPLORER.RUNS}/${item.entity_id}` +} + +/** + * Badge variant for an attention row — a stale alias is a 'warning', a failed + * job or run is an 'error'. + */ +export function attentionBadgeVariant( + itemType: AttentionItem['item_type'], +): 'error' | 'warning' { + return itemType === 'stale_alias' ? 'warning' : 'error' +} + +/** + * Human-readable staleness: "today" at zero or negative days, "{n}d" otherwise. + */ +export function formatStaleness(days: number): string { + return days <= 0 ? 'today' : `${days}d` +} + +/** + * Return a copy of the candidates sorted by priority score, most urgent first. + * The backend already sorts, but sorting again keeps the page correct if the + * order ever changes upstream. 
+ */ +export function sortRetrainingCandidates(rows: RetrainingCandidate[]): RetrainingCandidate[] { + return [...rows].sort((a, b) => b.priority_score - a.priority_score) +} + +/** + * Badge variant for a drift verdict — 'degrading' is an error, 'improving' a + * success, 'stable' an info, and 'unknown' a neutral default. + */ +export function driftBadgeVariant( + direction: DriftDirection, +): 'success' | 'error' | 'info' | 'default' { + switch (direction) { + case 'degrading': + return 'error' + case 'improving': + return 'success' + case 'stable': + return 'info' + default: + return 'default' + } +} + +/** + * Signed, one-decimal WAPE delta for display ("+14.0" / "-9.3"); '—' when the + * grain has fewer than two numeric WAPEs (delta is null). + */ +export function formatWapeDelta(delta: number | null): string { + if (delta === null) return '—' + const sign = delta > 0 ? '+' : '' + return `${sign}${delta.toFixed(1)}` +} diff --git a/frontend/src/lib/scenario-utils.test.ts b/frontend/src/lib/scenario-utils.test.ts new file mode 100644 index 00000000..b091470a --- /dev/null +++ b/frontend/src/lib/scenario-utils.test.ts @@ -0,0 +1,145 @@ +import { describe, expect, it } from 'vitest' +import { + buildMultiSeries, + coverageLabel, + coverageVariant, + formatDelta, + mergeComparisonSeries, + methodLabel, + summariseAssumptions, +} from './scenario-utils' +import type { + MultiScenarioComparison, + ScenarioAssumptions, + ScenarioComparisonRow, + ScenarioPoint, +} from '@/types/api' + +function makePoint(date: string, baseline: number, scenario: number): ScenarioPoint { + return { + date, + baseline, + scenario, + delta: scenario - baseline, + applied_factor: baseline === 0 ? 
1 : scenario / baseline, + } +} + +describe('mergeComparisonSeries', () => { + it('flattens points into date / baseline / scenario rows', () => { + const rows = mergeComparisonSeries([makePoint('2026-07-01', 10, 12)]) + expect(rows).toEqual([{ date: '2026-07-01', baseline: 10, scenario: 12 }]) + }) + + it('returns an empty array for no points', () => { + expect(mergeComparisonSeries([])).toEqual([]) + }) +}) + +describe('formatDelta', () => { + it('prefixes a plus sign for positive values', () => { + expect(formatDelta(12.34)).toBe('+12.3') + }) + + it('keeps the minus sign for negative values', () => { + expect(formatDelta(-4.5)).toBe('-4.5') + }) + + it('formats zero without a sign', () => { + expect(formatDelta(0)).toBe('0.0') + }) + + it('honours the decimals argument', () => { + expect(formatDelta(3, 0)).toBe('+3') + }) +}) + +describe('coverageLabel / coverageVariant', () => { + it('maps every verdict to a label and a badge variant', () => { + expect(coverageLabel('covered')).toBe('Covered') + expect(coverageLabel('at_risk')).toBe('At risk') + expect(coverageLabel('stockout')).toBe('Stockout') + expect(coverageLabel('unknown')).toBe('Unknown') + expect(coverageVariant('covered')).toBe('success') + expect(coverageVariant('at_risk')).toBe('warning') + expect(coverageVariant('stockout')).toBe('error') + expect(coverageVariant('unknown')).toBe('default') + }) +}) + +describe('summariseAssumptions', () => { + it('returns a baseline-only line for empty assumptions', () => { + expect(summariseAssumptions({})).toEqual(['No assumptions — baseline only']) + }) + + it('summarises a price cut with sign-aware wording', () => { + const assumptions: ScenarioAssumptions = { + price: { change_pct: -0.15, start_date: '2026-07-01', end_date: '2026-07-14' }, + } + const [line] = summariseAssumptions(assumptions) + expect(line).toContain('Price cut of 15%') + }) + + it('lists every supplied assumption', () => { + const assumptions: ScenarioAssumptions = { + price: { change_pct: 
0.1, start_date: '2026-07-01', end_date: '2026-07-07' }, + promotion: { kind: 'bogo', start_date: '2026-07-02', end_date: '2026-07-05' }, + holiday: { dates: ['2026-07-04'] }, + inventory: { on_hand_units: 500 }, + lifecycle: { stage: 'growth' }, + } + const lines = summariseAssumptions(assumptions) + expect(lines).toHaveLength(5) + expect(lines[0]).toContain('Price increase of 10%') + expect(lines[1]).toContain('bogo promotion') + expect(lines[2]).toContain('1 holiday/event day') + expect(lines[3]).toContain('500 units') + expect(lines[4]).toContain('growth') + }) +}) + +function makeRow(scenarioId: string, name: string, rank: number): ScenarioComparisonRow { + return { + scenario_id: scenarioId, + name, + units_delta: 0, + revenue_delta: 0, + coverage_verdict: 'unknown', + rank, + } +} + +describe('buildMultiSeries', () => { + it('puts baseline first, then one line per scenario keyed by scenario_id', () => { + const comparison: MultiScenarioComparison = { + baseline_total_units: 100, + baseline_revenue: 1000, + rank_by: 'revenue_delta', + scenarios: [makeRow('sid-a', 'Cut', 1), makeRow('sid-b', 'Rise', 2)], + chart_series: [], + } + expect(buildMultiSeries(comparison)).toEqual([ + { key: 'baseline', label: 'Baseline' }, + { key: 'sid-a', label: 'Cut' }, + { key: 'sid-b', label: 'Rise' }, + ]) + }) + + it('returns only the baseline line when there are no scenarios', () => { + const comparison: MultiScenarioComparison = { + baseline_total_units: 0, + baseline_revenue: 0, + rank_by: 'units_delta', + scenarios: [], + chart_series: [], + } + expect(buildMultiSeries(comparison)).toEqual([{ key: 'baseline', label: 'Baseline' }]) + }) +}) + +describe('methodLabel', () => { + it('labels each comparison method', () => { + expect(methodLabel('heuristic')).toBe('Heuristic') + expect(methodLabel('model_exogenous')).toBe('Model-driven') + }) +}) diff --git a/frontend/src/lib/scenario-utils.ts b/frontend/src/lib/scenario-utils.ts new file mode 100644 index 00000000..bbd7b7ed 
--- /dev/null +++ b/frontend/src/lib/scenario-utils.ts @@ -0,0 +1,137 @@ +/** + * Pure helpers for the What-If Planner page. + * + * No React, no I/O — every function here is unit-tested in + * scenario-utils.test.ts. The planner composes these to turn a + * `ScenarioComparison` into chart rows, a delta table, and readable summaries. + */ +import type { CsvColumn } from '@/lib/csv-export' +import type { + CoverageVerdict, + MultiScenarioComparison, + ScenarioAssumptions, + ScenarioPoint, +} from '@/types/api' + +/** One charted day: a date plus the baseline and scenario demand values. */ +export interface ComparisonChartRow { + date: string + baseline: number + scenario: number + // Index signature so the row is assignable to TimeSeriesChart's data prop. + [key: string]: string | number | null | undefined +} + +/** Flatten comparison points into the two-series rows TimeSeriesChart renders. */ +export function mergeComparisonSeries(points: ScenarioPoint[]): ComparisonChartRow[] { + return points.map((point) => ({ + date: point.date, + baseline: point.baseline, + scenario: point.scenario, + })) +} + +/** Format a number with an explicit sign (+1.5 / -2.0 / 0.0). */ +export function formatDelta(value: number, decimals = 1): string { + const sign = value > 0 ? '+' : '' + return `${sign}${value.toFixed(decimals)}` +} + +/** Human label for a coverage verdict. */ +export function coverageLabel(verdict: CoverageVerdict): string { + switch (verdict) { + case 'covered': + return 'Covered' + case 'at_risk': + return 'At risk' + case 'stockout': + return 'Stockout' + default: + return 'Unknown' + } +} + +/** StatusBadge variant for a coverage verdict. 
*/ +export function coverageVariant( + verdict: CoverageVerdict, +): 'success' | 'warning' | 'error' | 'default' { + switch (verdict) { + case 'covered': + return 'success' + case 'at_risk': + return 'warning' + case 'stockout': + return 'error' + default: + return 'default' + } +} + +/** CSV columns for the per-day delta-table export. */ +export const deltaCsvColumns: CsvColumn[] = [ + { key: 'date', header: 'Date' }, + { key: 'baseline', header: 'Baseline' }, + { key: 'scenario', header: 'Scenario' }, + { key: 'delta', header: 'Delta' }, + { key: 'applied_factor', header: 'Factor' }, +] + +/** Render the active what-if assumptions as human-readable bullet lines. */ +export function summariseAssumptions(assumptions: ScenarioAssumptions): string[] { + const lines: string[] = [] + + if (assumptions.price) { + const pct = Math.round(assumptions.price.change_pct * 100) + const verb = pct < 0 ? 'cut' : 'increase' + lines.push( + `Price ${verb} of ${Math.abs(pct)}% from ${assumptions.price.start_date} ` + + `to ${assumptions.price.end_date}`, + ) + } + if (assumptions.promotion) { + lines.push( + `${assumptions.promotion.kind} promotion from ${assumptions.promotion.start_date} ` + + `to ${assumptions.promotion.end_date}`, + ) + } + if (assumptions.holiday && assumptions.holiday.dates.length > 0) { + const count = assumptions.holiday.dates.length + lines.push(`${count} holiday/event day${count === 1 ? '' : 's'}`) + } + if (assumptions.inventory) { + lines.push(`On-hand stock of ${assumptions.inventory.on_hand_units} units`) + } + if (assumptions.lifecycle) { + lines.push(`Lifecycle stage forced to "${assumptions.lifecycle.stage}"`) + } + + if (lines.length === 0) { + lines.push('No assumptions — baseline only') + } + return lines +} + +/** A line in the multi-scenario comparison chart. 
*/ +export interface MultiSeriesLine { + key: string + label: string +} + +/** + * Derive the chart lines for a multi-scenario comparison: the shared baseline + * first, then one line per scenario keyed by `scenario_id` — matching the keys + * the backend put in each `chart_series` row (a CSS-identifier-safe key) — and + * labelled by the plan name. + */ +export function buildMultiSeries(comparison: MultiScenarioComparison): MultiSeriesLine[] { + const lines: MultiSeriesLine[] = [{ key: 'baseline', label: 'Baseline' }] + for (const row of comparison.scenarios) { + lines.push({ key: row.scenario_id, label: row.name }) + } + return lines +} + +/** Human label for the method that produced a comparison. */ +export function methodLabel(method: 'heuristic' | 'model_exogenous'): string { + return method === 'model_exogenous' ? 'Model-driven' : 'Heuristic' +} diff --git a/frontend/src/lib/url-params.test.ts b/frontend/src/lib/url-params.test.ts new file mode 100644 index 00000000..fc81d889 --- /dev/null +++ b/frontend/src/lib/url-params.test.ts @@ -0,0 +1,48 @@ +import { describe, it, expect } from 'vitest' +import { parseEnumParam, parseIdParam, parsePageParam } from './url-params' + +describe('parsePageParam', () => { + it('returns the integer for a valid positive page', () => { + expect(parsePageParam('3')).toBe(3) + }) + + it('floors a fractional page above 1', () => { + expect(parsePageParam('2.9')).toBe(2) + }) + + it('falls back to 1 for null, non-numeric, zero, and negative input', () => { + expect(parsePageParam(null)).toBe(1) + expect(parsePageParam('')).toBe(1) + expect(parsePageParam('abc')).toBe(1) + expect(parsePageParam('0')).toBe(1) + expect(parsePageParam('-4')).toBe(1) + }) +}) + +describe('parseIdParam', () => { + it('returns a positive integer ID', () => { + expect(parseIdParam('42')).toBe(42) + }) + + it('returns undefined for null, empty, non-numeric, fractional, or non-positive input', () => { + expect(parseIdParam(null)).toBeUndefined() + 
expect(parseIdParam('')).toBeUndefined() + expect(parseIdParam('abc')).toBeUndefined() + expect(parseIdParam('1.5')).toBeUndefined() + expect(parseIdParam('0')).toBeUndefined() + expect(parseIdParam('-3')).toBeUndefined() + }) +}) + +describe('parseEnumParam', () => { + const allowed = ['asc', 'desc'] as const + + it('returns the value when it is a member of the allow-list', () => { + expect(parseEnumParam('desc', allowed)).toBe('desc') + }) + + it('returns undefined for null or an unknown value', () => { + expect(parseEnumParam(null, allowed)).toBeUndefined() + expect(parseEnumParam('sideways', allowed)).toBeUndefined() + }) +}) diff --git a/frontend/src/lib/url-params.ts b/frontend/src/lib/url-params.ts new file mode 100644 index 00000000..a9270ed0 --- /dev/null +++ b/frontend/src/lib/url-params.ts @@ -0,0 +1,48 @@ +/** + * Safe readers for URL query-string state. + * + * The explorer pages treat the query string as the single source of truth for + * filter / sort / page state, so a hand-edited, stale, or truncated URL can + * carry a NaN page, a negative page, or an unknown enum value. These helpers + * validate at the read boundary so a junk param degrades to a sane default + * instead of reaching a hook (and the API) unverified. + */ + +/** + * Parse a `page` query param into a positive integer (>= 1). + * + * `null`, non-numeric, zero, negative, and fractional inputs all fall back to + * `1`; a fractional input above 1 is floored (`"2.9"` -> `2`). + */ +export function parsePageParam(value: string | null): number { + const parsed = Number(value) + if (!Number.isFinite(parsed) || parsed < 1) return 1 + return Math.floor(parsed) +} + +/** + * Parse a positive-integer ID query param (`store_id`, `product_id`, ...). + * + * Returns `undefined` for `null`, empty, non-numeric, fractional, or + * non-positive input — never `NaN`. 
+ */ +export function parseIdParam(value: string | null): number | undefined { + if (value === null || value === '') return undefined + const parsed = Number(value) + return Number.isInteger(parsed) && parsed >= 1 ? parsed : undefined +} + +/** + * Return `value` only when it is a member of `allowed`; otherwise `undefined`. + * + * Use for enum-typed query params (status, model_type, sort_by, dimension, ...) + * so an unknown value is dropped rather than blind-cast into a typed slot. + */ +export function parseEnumParam( + value: string | null, + allowed: readonly T[], +): T | undefined { + return value !== null && (allowed as readonly string[]).includes(value) + ? (value as T) + : undefined +} diff --git a/frontend/src/pages/admin.tsx b/frontend/src/pages/admin.tsx index b933238c..a5de6024 100644 --- a/frontend/src/pages/admin.tsx +++ b/frontend/src/pages/admin.tsx @@ -18,8 +18,14 @@ import { History, Percent, Bot, + Library, } from 'lucide-react' -import { useRagSources, useDeleteRagSource, useIndexDocument } from '@/hooks/use-rag-sources' +import { + useRagSources, + useDeleteRagSource, + useIndexDocument, + useIndexProjectDocs, +} from '@/hooks/use-rag-sources' import { useAliases, useDeleteAlias, useCreateAlias } from '@/hooks/use-runs' import { useSeederStatus, @@ -117,6 +123,7 @@ function RagSourcesPanel() { const { data, isLoading, error, refetch } = useRagSources() const deleteSource = useDeleteRagSource() const indexDocument = useIndexDocument() + const indexProjectDocs = useIndexProjectDocs() const [newSource, setNewSource] = useState({ type: 'markdown', path: '' }) const [isDialogOpen, setIsDialogOpen] = useState(false) @@ -131,6 +138,19 @@ function RagSourcesPanel() { setIsDialogOpen(false) } + const handleIndexProjectDocs = async () => { + try { + const r = await indexProjectDocs.mutateAsync({}) + const summary = + `Indexed ${r.indexed}, updated ${r.updated}, unchanged ${r.unchanged}, ` + + `${r.failed} failed — ${r.total_chunks} chunks` + if 
(r.failed > 0) toast.warning(summary) + else toast.success(summary) + } catch (err) { + toast.error(err instanceof Error ? err.message : 'Project-docs indexing failed') + } + } + const handleDelete = async (sourceId: string) => { await deleteSource.mutateAsync(sourceId) } @@ -152,6 +172,20 @@ function RagSourcesPanel() { {data?.total_sources ?? 0} sources • {data?.total_chunks ?? 0} chunks
    +
    + +
    {data?.sources.length ? ( diff --git a/frontend/src/pages/chat.tsx b/frontend/src/pages/chat.tsx index 1a987401..cc22a9d5 100644 --- a/frontend/src/pages/chat.tsx +++ b/frontend/src/pages/chat.tsx @@ -240,9 +240,9 @@ export default function ChatPage() {

    {agentType === 'rag_assistant' ? 'RAG Assistant' : 'Experiment Agent'} •{' '} {wsStatus === 'connected' ? ( - Connected + Connected ) : ( - {wsStatus} + {wsStatus} )}

    diff --git a/frontend/src/pages/explorer/job-detail.tsx b/frontend/src/pages/explorer/job-detail.tsx index e112b18a..77b15fed 100644 --- a/frontend/src/pages/explorer/job-detail.tsx +++ b/frontend/src/pages/explorer/job-detail.tsx @@ -21,6 +21,7 @@ import { AlertDialogTitle, AlertDialogTrigger, } from '@/components/ui/alert-dialog' +import { toast } from 'sonner' import { ROUTES } from '@/lib/constants' function fmtDate(value: string | null | undefined): string { @@ -74,10 +75,16 @@ export default function JobDetailPage() { const job = jobQuery.data async function handleCancel() { - await cancelJob.mutateAsync(jobId) - // useCancelJob invalidates ['jobs']; refresh this detail query explicitly - // so the page reflects the cancelled status immediately. - void queryClient.invalidateQueries({ queryKey: ['jobs', jobId] }) + // mutateAsync rejects on failure — catch it so a cancel error surfaces as + // a toast instead of an unhandled promise rejection. + try { + await cancelJob.mutateAsync(jobId) + // useCancelJob invalidates ['jobs']; refresh this detail query explicitly + // so the page reflects the cancelled status immediately. + void queryClient.invalidateQueries({ queryKey: ['jobs', jobId] }) + } catch (err) { + toast.error(err instanceof Error ? 
err.message : 'Failed to cancel job') + } } return ( diff --git a/frontend/src/pages/explorer/jobs.tsx b/frontend/src/pages/explorer/jobs.tsx index 380a732c..ee864dc5 100644 --- a/frontend/src/pages/explorer/jobs.tsx +++ b/frontend/src/pages/explorer/jobs.tsx @@ -21,7 +21,9 @@ import { AlertDialogTitle, AlertDialogTrigger, } from '@/components/ui/alert-dialog' +import { toast } from 'sonner' import { toCsv, downloadCsv, type CsvColumn } from '@/lib/csv-export' +import { parsePageParam } from '@/lib/url-params' import type { Job, JobStatus, JobType } from '@/types/api' import { DEFAULT_PAGE_SIZE } from '@/lib/constants' @@ -42,7 +44,9 @@ export default function JobsMonitorPage() { // so a pasted URL reproduces the exact view. const jobType = searchParams.get('job_type') ?? undefined const status = searchParams.get('status') ?? undefined - const page = Number(searchParams.get('page')) || 1 + // Clamp `page` to a positive integer — a hand-edited NaN/negative value + // would otherwise reach the API as-is. + const page = parsePageParam(searchParams.get('page')) const sortBy = searchParams.get('sort_by') ?? undefined const sortOrder: 'asc' | 'desc' = searchParams.get('sort_order') === 'desc' ? 'desc' : 'asc' @@ -64,7 +68,13 @@ export default function JobsMonitorPage() { const cancelJob = useCancelJob() const handleCancelJob = async (jobId: string) => { - await cancelJob.mutateAsync(jobId) + // mutateAsync rejects on failure — catch it so a cancel error surfaces as + // a toast instead of an unhandled promise rejection. + try { + await cancelJob.mutateAsync(jobId) + } catch (err) { + toast.error(err instanceof Error ? 
err.message : 'Failed to cancel job') + } } const columns: ColumnDef[] = [ diff --git a/frontend/src/pages/explorer/products.tsx b/frontend/src/pages/explorer/products.tsx index 893ee81e..ecc604c9 100644 --- a/frontend/src/pages/explorer/products.tsx +++ b/frontend/src/pages/explorer/products.tsx @@ -9,6 +9,7 @@ import { ErrorDisplay } from '@/components/common/error-display' import { Button } from '@/components/ui/button' import { formatCurrency } from '@/lib/api' import { toCsv, downloadCsv, type CsvColumn } from '@/lib/csv-export' +import { parsePageParam } from '@/lib/url-params' import type { Product } from '@/types/api' import { DEFAULT_PAGE_SIZE } from '@/lib/constants' @@ -62,7 +63,9 @@ export default function ProductsExplorerPage() { // URL query string is the single source of truth for filter/sort/page state. const search = searchParams.get('search') ?? '' const category = searchParams.get('category') ?? undefined - const page = Number(searchParams.get('page')) || 1 + // Clamp `page` to a positive integer — a hand-edited NaN/negative value + // would otherwise reach the API as-is. + const page = parsePageParam(searchParams.get('page')) const sortBy = searchParams.get('sort_by') ?? undefined const sortOrder: 'asc' | 'desc' = searchParams.get('sort_order') === 'desc' ? 
'desc' : 'asc' diff --git a/frontend/src/pages/explorer/run-detail.tsx b/frontend/src/pages/explorer/run-detail.tsx index 1a41ca5e..f769e07f 100644 --- a/frontend/src/pages/explorer/run-detail.tsx +++ b/frontend/src/pages/explorer/run-detail.tsx @@ -10,6 +10,8 @@ import { ShieldCheck, } from 'lucide-react' import { useRun, useVerifyArtifact } from '@/hooks/use-runs' +import { useRunExplanation } from '@/hooks/use-explanations' +import { ExplanationPanel } from '@/components/explainability/explanation-panel' import { JsonBlock } from '@/components/common/json-block' import { ErrorDisplay } from '@/components/common/error-display' import { LoadingState } from '@/components/common/loading-state' @@ -41,6 +43,9 @@ export default function RunDetailPage() { const [verifyOn, setVerifyOn] = useState(false) const verifyQuery = useVerifyArtifact(runId ?? '', verifyOn) + // The explanation panel self-handles a 400 for non-baseline (lightgbm) runs. + const explanationQuery = useRunExplanation(runId ?? '', !!runId) + if (!runId) { return (
    @@ -158,6 +163,12 @@ export default function RunDetailPage() { + +
    @@ -243,8 +254,8 @@ export default function RunDetailPage() { !verifyQuery.isFetching && verifyQuery.data && (verifyQuery.data.verified ? ( -
    - +
    + Artifact verified — the stored checksum matches. {verifyQuery.data.computed_hash && ( diff --git a/frontend/src/pages/explorer/runs.tsx b/frontend/src/pages/explorer/runs.tsx index c6fb43fc..5fc1bf63 100644 --- a/frontend/src/pages/explorer/runs.tsx +++ b/frontend/src/pages/explorer/runs.tsx @@ -11,9 +11,16 @@ import { getStatusVariant } from '@/lib/status-utils' import { ErrorDisplay } from '@/components/common/error-display' import { Button } from '@/components/ui/button' import { toCsv, downloadCsv, type CsvColumn } from '@/lib/csv-export' +import { parseEnumParam, parsePageParam } from '@/lib/url-params' import type { ModelRun, RunStatus } from '@/types/api' import { DEFAULT_PAGE_SIZE, ROUTES } from '@/lib/constants' +// Allow-lists for the URL-driven filter/sort params. A hand-edited URL value +// outside these is dropped (treated as "no filter") rather than blind-cast. +const RUN_STATUSES: readonly RunStatus[] = ['pending', 'running', 'success', 'failed', 'archived'] +const MODEL_TYPES = ['naive', 'seasonal_naive', 'moving_average', 'lightgbm'] as const +const RUN_SORT_KEYS = ['created_at', 'model_type', 'status', 'store_id', 'product_id'] as const + const columns: ColumnDef[] = [ { accessorKey: 'run_id', @@ -89,11 +96,12 @@ export default function RunsExplorerPage() { const [searchParams, setSearchParams] = useSearchParams() // URL query string is the single source of truth for filter/sort/page state, - // so a pasted URL reproduces the exact view. - const modelType = searchParams.get('model_type') ?? undefined - const status = searchParams.get('status') ?? undefined - const page = Number(searchParams.get('page')) || 1 - const sortBy = searchParams.get('sort_by') ?? undefined + // so a pasted URL reproduces the exact view. Each param is validated so a + // junk value degrades gracefully instead of reaching the API. 
+ const modelType = parseEnumParam(searchParams.get('model_type'), MODEL_TYPES) + const status = parseEnumParam(searchParams.get('status'), RUN_STATUSES) + const page = parsePageParam(searchParams.get('page')) + const sortBy = parseEnumParam(searchParams.get('sort_by'), RUN_SORT_KEYS) const sortOrder: 'asc' | 'desc' = searchParams.get('sort_order') === 'desc' ? 'desc' : 'asc' const pagination: PaginationState = { @@ -106,7 +114,7 @@ export default function RunsExplorerPage() { page, pageSize: pagination.pageSize, modelType, - status: status as RunStatus | undefined, + status, sortBy, sortOrder: sortBy ? sortOrder : undefined, }) diff --git a/frontend/src/pages/explorer/sales.tsx b/frontend/src/pages/explorer/sales.tsx index ba32036b..eff5ee6d 100644 --- a/frontend/src/pages/explorer/sales.tsx +++ b/frontend/src/pages/explorer/sales.tsx @@ -14,18 +14,28 @@ import { Badge } from '@/components/ui/badge' import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs' import { dateRangeToStrings, stringsToDateRange } from '@/lib/date-utils' +import { parseEnumParam, parseIdParam } from '@/lib/url-params' import { formatCurrency, formatNumber } from '@/lib/api' import type { DrilldownDimension } from '@/types/api' +/** The drilldown dimensions a shareable URL is allowed to select. */ +const DRILLDOWN_DIMENSIONS: readonly DrilldownDimension[] = [ + 'store', + 'product', + 'category', + 'region', + 'date', +] + export default function SalesExplorerPage() { const [searchParams, setSearchParams] = useSearchParams() // dimension + cross-filter state live in the URL so the view is shareable. - const dimension = (searchParams.get('dimension') as DrilldownDimension | null) ?? 'store' - const storeIdParam = searchParams.get('store_id') - const productIdParam = searchParams.get('product_id') - const storeId = storeIdParam ? 
Number(storeIdParam) : undefined - const productId = productIdParam ? Number(productIdParam) : undefined + // A hand-edited URL can carry an unknown dimension or a NaN id — validate + // both before they reach the drilldown/timeseries hooks. + const dimension = parseEnumParam(searchParams.get('dimension'), DRILLDOWN_DIMENSIONS) ?? 'store' + const storeId = parseIdParam(searchParams.get('store_id')) + const productId = parseIdParam(searchParams.get('product_id')) const startParam = searchParams.get('start_date') const endParam = searchParams.get('end_date') diff --git a/frontend/src/pages/explorer/stores.tsx b/frontend/src/pages/explorer/stores.tsx index 99914107..b5df3c39 100644 --- a/frontend/src/pages/explorer/stores.tsx +++ b/frontend/src/pages/explorer/stores.tsx @@ -8,9 +8,17 @@ import { DataTableColumnHeader } from '@/components/data-table/data-table-column import { ErrorDisplay } from '@/components/common/error-display' import { Button } from '@/components/ui/button' import { toCsv, downloadCsv, type CsvColumn } from '@/lib/csv-export' +import { parseEnumParam, parsePageParam } from '@/lib/url-params' import type { Store } from '@/types/api' import { DEFAULT_PAGE_SIZE } from '@/lib/constants' +// Allow-lists for the URL-driven filter/sort params — these mirror the filter +// dropdown options and the backend sort allow-list. A value outside them is +// dropped rather than blind-cast. +const REGIONS = ['North', 'South', 'East', 'West'] as const +const STORE_TYPES = ['Supermarket', 'Convenience', 'Hypermarket'] as const +const STORE_SORT_KEYS = ['code', 'name', 'region', 'city', 'store_type'] as const + const columns: ColumnDef[] = [ { accessorKey: 'id', @@ -59,12 +67,13 @@ export default function StoresExplorerPage() { const [searchParams, setSearchParams] = useSearchParams() // URL query string is the single source of truth for filter/sort/page state, - // so a pasted URL reproduces the exact view. + // so a pasted URL reproduces the exact view. 
Each param is validated so a + // junk value degrades gracefully instead of reaching the API. const search = searchParams.get('search') ?? '' - const region = searchParams.get('region') ?? undefined - const storeType = searchParams.get('store_type') ?? undefined - const page = Number(searchParams.get('page')) || 1 - const sortBy = searchParams.get('sort_by') ?? undefined + const region = parseEnumParam(searchParams.get('region'), REGIONS) + const storeType = parseEnumParam(searchParams.get('store_type'), STORE_TYPES) + const page = parsePageParam(searchParams.get('page')) + const sortBy = parseEnumParam(searchParams.get('sort_by'), STORE_SORT_KEYS) const sortOrder: 'asc' | 'desc' = searchParams.get('sort_order') === 'desc' ? 'desc' : 'asc' const pagination: PaginationState = { diff --git a/frontend/src/pages/guide.tsx b/frontend/src/pages/guide.tsx index 787a24e8..fe18ca6f 100644 --- a/frontend/src/pages/guide.tsx +++ b/frontend/src/pages/guide.tsx @@ -196,8 +196,12 @@ export default function GuidePage() { {tool} )) - ) : ( + ) : configLoading ? ( + ) : ( + + Unavailable — the configuration endpoint could not be reached. + )}
    diff --git a/frontend/src/pages/ops.tsx b/frontend/src/pages/ops.tsx new file mode 100644 index 00000000..04a5ba1a --- /dev/null +++ b/frontend/src/pages/ops.tsx @@ -0,0 +1,652 @@ +import { useState } from 'react' +import { useNavigate, Link } from 'react-router-dom' +import { Activity, AlertTriangle, CheckCircle2, Clock, Download, RefreshCw } from 'lucide-react' +import { toast } from 'sonner' +import { useModelHealth, useOpsSummary, useRetrainingCandidates } from '@/hooks/use-ops' +import { useProviderHealth } from '@/hooks/use-config' +import { useCreateJob } from '@/hooks/use-jobs' +import { useCreateAlias } from '@/hooks/use-runs' +import { + attentionBadgeVariant, + attentionItemLink, + driftBadgeVariant, + formatStaleness, + formatWapeDelta, + sortRetrainingCandidates, + summaryHealthVariant, +} from '@/lib/ops-utils' +import { getStatusVariant } from '@/lib/status-utils' +import { KPICard } from '@/components/charts/kpi-card' +import { EmptyState, ErrorDisplay } from '@/components/common/error-display' +import { LoadingState } from '@/components/common/loading-state' +import { StatusBadge } from '@/components/common/status-badge' +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from '@/components/ui/table' +import { Button } from '@/components/ui/button' +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from '@/components/ui/dropdown-menu' +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from '@/components/ui/alert-dialog' +import { Checkbox } from '@/components/ui/checkbox' +import { Input } from '@/components/ui/input' +import { downloadCsv, toCsv } from '@/lib/csv-export' +import { attentionCsvColumns, buildIncidentMarkdown, downloadMarkdown } 
from '@/lib/incident-report' +import { buildRetrainJob } from '@/lib/ops-actions' +import { api, formatPercent, getErrorMessage } from '@/lib/api' +import { ROUTES } from '@/lib/constants' +import type { ModelRun } from '@/types/api' + +/** The run + grain a "Promote to alias" dialog is currently targeting. */ +interface PromoteTarget { + runId: string + storeId: number + productId: number +} + +/** Format an ISO timestamp / date string for display; '—' when null. */ +function formatWhen(value: string | null): string { + if (!value) return '—' + const parsed = new Date(value) + return Number.isNaN(parsed.getTime()) ? value : parsed.toLocaleString() +} + +/** A labelled health row inside the System Health card. */ +function HealthRow({ label, ok, detail }: { label: string; ok: boolean; detail?: string }) { + return ( +
    + {label} + + {detail && {detail}} + {ok ? 'ok' : 'down'} + +
    + ) +} + +/** A labelled value pair for the Data Freshness card. */ +function FreshnessRow({ label, value }: { label: string; value: string }) { + return ( +
    + {label} + {value} +
    + ) +} + +export default function OpsPage() { + const navigate = useNavigate() + const summaryQuery = useOpsSummary() + const candidatesQuery = useRetrainingCandidates() + const modelHealthQuery = useModelHealth() + const providerQuery = useProviderHealth() + const createJob = useCreateJob() + const createAlias = useCreateAlias() + const [selected, setSelected] = useState>(new Set()) + const [retrainConfirmOpen, setRetrainConfirmOpen] = useState(false) + const [actionBusy, setActionBusy] = useState(false) + const [promoteTarget, setPromoteTarget] = useState(null) + const [aliasName, setAliasName] = useState('') + + if (summaryQuery.error) { + return ( +
    +

    Control Center

    + void summaryQuery.refetch()} /> +
    + ) + } + + if (summaryQuery.isLoading || !summaryQuery.data) { + return ( +
    +

    Control Center

    + +
    + ) + } + + const summary = summaryQuery.data + const providers = providerQuery.data ?? [] + const totalJobs = summary.jobs.counts.reduce((sum, c) => sum + c.count, 0) + const totalRuns = summary.runs.counts.reduce((sum, c) => sum + c.count, 0) + const staleAliases = summary.aliases.filter((a) => a.is_stale).length + const candidates = sortRetrainingCandidates(candidatesQuery.data?.candidates ?? []) + const modelHealthEntries = modelHealthQuery.data?.entries ?? [] + + /** Download the needs-attention list as a CSV, built client-side. */ + function handleExportCsv() { + downloadCsv('ops-attention-items.csv', toCsv(summary.attention_items, attentionCsvColumns)) + } + + /** Download the full operational snapshot as a Markdown incident report. */ + function handleExportMarkdown() { + downloadMarkdown( + 'ops-incident-report.md', + buildIncidentMarkdown(summary, candidates, modelHealthEntries), + ) + } + + /** Stable selection key for a (store, product) grain. */ + const grainKey = (storeId: number, productId: number) => `${storeId}-${productId}` + + /** Toggle one grain in the bulk-retrain selection set. */ + function toggleSelected(key: string) { + setSelected((prev) => { + const next = new Set(prev) + if (next.has(key)) { + next.delete(key) + } else { + next.add(key) + } + return next + }) + } + + // Selected candidates that carry a source run — the bulk-retrain work list. + const selectedCandidates = candidates.filter( + (candidate) => + selected.has(grainKey(candidate.store_id, candidate.product_id)) && + candidate.latest_run_id !== null, + ) + + /** + * Bulk-retrain every selected grain. POST /jobs runs synchronously + * server-side, so jobs are fired SEQUENTIALLY (never Promise.all) with a + * per-item toast; the action layer reuses the existing /jobs endpoint. 
+ */ + async function runBulkRetrain() { + setRetrainConfirmOpen(false) + setActionBusy(true) + let succeeded = 0 + let failed = 0 + for (const candidate of selectedCandidates) { + const runId = candidate.latest_run_id + if (runId === null) continue + const where = `store ${candidate.store_id} / product ${candidate.product_id}` + try { + const run = await api(`/registry/runs/${runId}`) + await createJob.mutateAsync(buildRetrainJob(run, summary.freshness.latest_sales_date)) + succeeded += 1 + toast.success(`Retrain queued — ${where}`) + } catch (error) { + failed += 1 + toast.error(`Retrain failed — ${where}: ${getErrorMessage(error)}`) + } + } + setSelected(new Set()) + setActionBusy(false) + toast.message(`Bulk retrain complete — ${succeeded} queued, ${failed} failed`) + } + + /** Open the promote-to-alias dialog for a grain's latest successful run. */ + function openPromote(runId: string | null, storeId: number, productId: number) { + if (runId === null) return + setAliasName('') + setPromoteTarget({ runId, storeId, productId }) + } + + /** Promote the targeted run to a deployment alias via POST /registry/aliases. */ + async function runPromote() { + if (promoteTarget === null) return + const target = promoteTarget + const name = aliasName.trim() + setActionBusy(true) + try { + await createAlias.mutateAsync({ alias_name: name, run_id: target.runId }) + toast.success(`Promoted run to alias '${name}'`) + } catch (error) { + toast.error(`Promote failed: ${getErrorMessage(error)}`) + } + setActionBusy(false) + setPromoteTarget(null) + } + + return ( +
    +
    +
    +

    Control Center

    +

    + One operational view across jobs, model runs, deployment aliases, and data + freshness — surfacing what needs attention before it affects decisions. +

    +
    + + + + + + CSV — attention items + + Markdown — full report + + + +
    + + {totalJobs === 0 && totalRuns === 0 ? ( + } + action={{ label: 'Go to Showcase', onClick: () => navigate(ROUTES.SHOWCASE) }} + /> + ) : ( + <> +
    + {/* Section 1 — System Health */} + + +
    + System Health + + {summaryHealthVariant(summary.system) === 'success' ? 'healthy' : 'degraded'} + +
    + API, database, and embedding-provider reachability. +
    + +
    + + + {providers.map((p) => ( + + ))} +
    +
    + Latest successful job: + + {formatWhen(summary.system.latest_successful_job_at)} + +
    +
    +
    + Jobs + {summary.jobs.counts.map((c) => ( + + {c.status} {c.count} + + ))} +
    +
    + Runs + {summary.runs.counts.map((c) => ( + + {c.status} {c.count} + + ))} +
    +
    +
    +
    + + {/* Section 3 — Data Freshness */} + + + Data Freshness + How current the data and model state are. + + +
    + + + +
    +
    +
    +
    + + {/* Section 2 — KPI row */} +
    + + + + +
    + + {/* Section 4 — Needs Attention */} + + + Needs Attention + + Recent failed jobs, failed runs, and stale deployment aliases. Each row links + to its Explorer detail page. + + + + {summary.attention_items.length === 0 ? ( +

    + Nothing needs attention — no failed jobs, failed runs, or stale aliases. +

    + ) : ( + + + + Type + Item + Detail + When + + + + {summary.attention_items.map((item) => ( + + + + {item.item_type.replace('_', ' ')} + + + + + {item.label} + + + + {item.detail} + + + {formatWhen(item.occurred_at)} + + + ))} + +
    + )} +
    +
    + + {/* Section 5 — Model Health */} + + + Model Health + + Forecast-error (WAPE) drift per store / product, classified from each grain's + successful-run history. Degrading grains are listed first. + + + + {modelHealthQuery.isLoading ? ( + + ) : modelHealthEntries.length === 0 ? ( +

    + No model health to evaluate — no successful model runs yet. +

    + ) : ( + + + + Store + Product + Drift + Latest WAPE + Δ WAPE + Runs + Action + + + + {modelHealthEntries.map((entry) => ( + + {entry.store_id} + {entry.product_id} + + + {entry.drift_direction} + + + + {entry.latest_wape === null ? '—' : entry.latest_wape.toFixed(1)} + + + {formatWapeDelta(entry.wape_delta)} + + + {entry.run_count} + + + + + + ))} + +
    + )} +
    +
    + + {/* Section 6 — Retraining Queue */} + + +
    +
    + Retraining Queue + + Store / product pairs ranked by a retraining-priority score that blends + staleness with forecast error (WAPE). Select rows to retrain in bulk. + +
    + +
    +
    + + {candidatesQuery.isLoading ? ( + + ) : candidates.length === 0 ? ( +

    + No retraining candidates — no successful model runs to evaluate yet. +

    + ) : ( + + + + + Select + + Store + Product + Priority + Staleness + WAPE + Reason + Action + + + + {candidates.map((candidate) => { + const key = grainKey(candidate.store_id, candidate.product_id) + return ( + + + toggleSelected(key)} + aria-label={`Select store ${candidate.store_id} product ${candidate.product_id}`} + /> + + + {candidate.store_id} + + + {candidate.product_id} + + + {candidate.priority_score.toFixed(2)} + + + {formatStaleness(candidate.staleness_days)} + + + {candidate.wape === null ? '—' : candidate.wape.toFixed(1)} + + + {candidate.reason} + + + + + + ) + })} + +
    + )} +
    +
    + + )} + + {/* Confirm gate — bulk retrain of the selected grains. */} + + + + + Retrain {selectedCandidates.length} grain + {selectedCandidates.length === 1 ? '' : 's'}? + + + This creates one training job per selected store / product via the existing + POST /jobs endpoint. Jobs run sequentially and each may take a moment; the + outcome of every job is reported individually. + + + + Cancel + void runBulkRetrain()}>Retrain + + + + + {/* Confirm gate — promote a run to a deployment alias. */} + { + if (!open) setPromoteTarget(null) + }} + > + + + Promote to alias + + {promoteTarget + ? `Point a deployment alias at the latest successful run for store ${promoteTarget.storeId} / product ${promoteTarget.productId}. An existing alias of the same name is repointed.` + : ''} + + +
    + + setAliasName(event.target.value)} + placeholder="e.g. production" + autoComplete="off" + /> +
    + + Cancel + void runPromote()} + > + Promote + + +
    +
    +
    + ) +} diff --git a/frontend/src/pages/showcase.tsx b/frontend/src/pages/showcase.tsx index 2b1d8687..34035812 100644 --- a/frontend/src/pages/showcase.tsx +++ b/frontend/src/pages/showcase.tsx @@ -110,7 +110,7 @@ export default function ShowcasePage() { @@ -118,7 +118,7 @@ export default function ShowcasePage() { {summary.overallStatus === 'pass' diff --git a/frontend/src/pages/visualize/backtest.tsx b/frontend/src/pages/visualize/backtest.tsx index a2ce517f..3e36d763 100644 --- a/frontend/src/pages/visualize/backtest.tsx +++ b/frontend/src/pages/visualize/backtest.tsx @@ -85,7 +85,15 @@ export default function BacktestPage() { // Extract backtest result from job const backtestResult = job?.result as BacktestResult | undefined - const formReady = !!storeId && !!productId && !!dateRange?.from && !!dateRange?.to + // The number inputs can be cleared to 0; require a valid split count and + // test size so an invalid backtest config can never be submitted. + const formReady = + !!storeId && + !!productId && + !!dateRange?.from && + !!dateRange?.to && + nSplits >= 2 && + testSize >= 1 async function handleRunBacktest() { if (!storeId || !productId || !dateRange?.from || !dateRange?.to) return @@ -210,7 +218,8 @@ export default function BacktestPage() { {!formReady && ( - Pick a store, product and date window to enable. + Pick a store, product and date window, with at least 2 splits and a 1-day test + size, to enable. )}
    diff --git a/frontend/src/pages/visualize/demand.tsx b/frontend/src/pages/visualize/demand.tsx index 46f549d5..30bf7204 100644 --- a/frontend/src/pages/visualize/demand.tsx +++ b/frontend/src/pages/visualize/demand.tsx @@ -83,6 +83,17 @@ function SortHead({ onSort(columnKey)} + // Keyboard-operable sort header: focusable, fires on Enter / Space, and + // exposes the current sort direction to assistive tech. + onKeyDown={(event) => { + if (event.key === 'Enter' || event.key === ' ') { + event.preventDefault() + onSort(columnKey) + } + }} + tabIndex={0} + role="button" + aria-sort={active ? (sortDir === 'asc' ? 'ascending' : 'descending') : 'none'} > {label} @@ -304,6 +315,15 @@ export default function DemandPlannerPage() { setSelectedJobId(row.jobId)} + // Keyboard-operable row: focusable and fires on Enter / Space. + onKeyDown={(event) => { + if (event.key === 'Enter' || event.key === ' ') { + event.preventDefault() + setSelectedJobId(row.jobId) + } + }} + tabIndex={0} + role="button" className={cn( 'cursor-pointer', row.jobId === selectedJobId && 'bg-muted', diff --git a/frontend/src/pages/visualize/forecast.tsx b/frontend/src/pages/visualize/forecast.tsx index bda6f244..bed6cfe9 100644 --- a/frontend/src/pages/visualize/forecast.tsx +++ b/frontend/src/pages/visualize/forecast.tsx @@ -2,6 +2,8 @@ import { useState } from 'react' import { Link } from 'react-router-dom' import { BarChart3, Download, ExternalLink, Loader2, Play } from 'lucide-react' import { useJob, useCreateJob } from '@/hooks/use-jobs' +import { useJobExplanation } from '@/hooks/use-explanations' +import { ExplanationPanel } from '@/components/explainability/explanation-panel' import { TimeSeriesChart } from '@/components/charts/time-series-chart' import { EmptyState } from '@/components/common/error-display' import { JobPicker } from '@/components/common/job-picker' @@ -48,11 +50,20 @@ export default function ForecastPage() { // A completed `predict` job stores result.forecasts (date + 
forecast, plus // optional lower/upper bounds for models that emit a prediction interval). - const forecastData = job?.result?.forecasts as ForecastPoint[] | undefined - const hasBounds = !!forecastData?.some( + // `job.result` is untyped JSONB — guard with Array.isArray before treating + // `forecasts` as an array so a malformed result can never throw on `.some()`. + const rawForecasts = job?.result?.forecasts + const forecastData: ForecastPoint[] = Array.isArray(rawForecasts) + ? (rawForecasts as ForecastPoint[]) + : [] + const hasBounds = forecastData.some( (point) => point.lower_bound != null && point.upper_bound != null, ) + // Explain the loaded job only when it is a completed predict job. + const isPredictDone = job?.status === 'completed' && job?.job_type === 'predict' + const explanationQuery = useJobExplanation(job?.job_id ?? '', !!job && isPredictDone) + async function handleRunForecast() { if (!trainRunId) return setRunError(null) @@ -68,7 +79,7 @@ export default function ForecastPage() { } function handleExport() { - if (!forecastData || !job) return + if (forecastData.length === 0 || !job) return downloadCsv(`forecast-${job.job_id}.csv`, toCsv(forecastData, csvColumns)) } @@ -189,7 +200,7 @@ export default function ForecastPage() {
    {/* Forecast Chart */} - {forecastData && forecastData.length > 0 ? ( + {forecastData.length > 0 ? (