From e82d9e0d781230ffc763374943d3df2fe42d1c5c Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 12:29:33 +0200 Subject: [PATCH 1/2] feat(api,ui): add forecast explainability & driver attribution slice (#224) --- alembic/env.py | 1 + ...4cb44_create_forecast_explanation_table.py | 133 ++++++ app/features/explainability/__init__.py | 5 + app/features/explainability/explainers.py | 272 +++++++++++ app/features/explainability/models.py | 81 ++++ app/features/explainability/reason_codes.py | 186 ++++++++ app/features/explainability/routes.py | 169 +++++++ app/features/explainability/schemas.py | 147 ++++++ app/features/explainability/service.py | 421 ++++++++++++++++++ app/features/explainability/tests/__init__.py | 1 + app/features/explainability/tests/conftest.py | 241 ++++++++++ .../explainability/tests/test_explainers.py | 161 +++++++ .../tests/test_models_integration.py | 62 +++ .../explainability/tests/test_reason_codes.py | 128 ++++++ .../explainability/tests/test_routes.py | 153 +++++++ .../tests/test_routes_integration.py | 85 ++++ .../explainability/tests/test_schemas.py | 134 ++++++ .../explainability/tests/test_service.py | 179 ++++++++ app/main.py | 2 + docs/_base/API_CONTRACTS.md | 3 + .../explainability/explanation-panel.test.tsx | 66 +++ .../explainability/explanation-panel.tsx | 218 +++++++++ frontend/src/hooks/use-explanations.ts | 48 ++ frontend/src/pages/explorer/run-detail.tsx | 11 + frontend/src/pages/visualize/forecast.tsx | 15 + frontend/src/types/api.ts | 41 ++ 26 files changed, 2963 insertions(+) create mode 100644 alembic/versions/f84258c4cb44_create_forecast_explanation_table.py create mode 100644 app/features/explainability/__init__.py create mode 100644 app/features/explainability/explainers.py create mode 100644 app/features/explainability/models.py create mode 100644 app/features/explainability/reason_codes.py create mode 100644 app/features/explainability/routes.py create mode 100644 app/features/explainability/schemas.py 
create mode 100644 app/features/explainability/service.py create mode 100644 app/features/explainability/tests/__init__.py create mode 100644 app/features/explainability/tests/conftest.py create mode 100644 app/features/explainability/tests/test_explainers.py create mode 100644 app/features/explainability/tests/test_models_integration.py create mode 100644 app/features/explainability/tests/test_reason_codes.py create mode 100644 app/features/explainability/tests/test_routes.py create mode 100644 app/features/explainability/tests/test_routes_integration.py create mode 100644 app/features/explainability/tests/test_schemas.py create mode 100644 app/features/explainability/tests/test_service.py create mode 100644 frontend/src/components/explainability/explanation-panel.test.tsx create mode 100644 frontend/src/components/explainability/explanation-panel.tsx create mode 100644 frontend/src/hooks/use-explanations.ts diff --git a/alembic/env.py b/alembic/env.py index 4c4db209..dd83996b 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -15,6 +15,7 @@ from app.features.agents import models as agents_models # noqa: F401 from app.features.config import models as config_models # noqa: F401 from app.features.data_platform import models as data_platform_models # noqa: F401 +from app.features.explainability import models as explainability_models # noqa: F401 from app.features.jobs import models as jobs_models # noqa: F401 from app.features.rag import models as rag_models # noqa: F401 from app.features.registry import models as registry_models # noqa: F401 diff --git a/alembic/versions/f84258c4cb44_create_forecast_explanation_table.py b/alembic/versions/f84258c4cb44_create_forecast_explanation_table.py new file mode 100644 index 00000000..5d5cede9 --- /dev/null +++ b/alembic/versions/f84258c4cb44_create_forecast_explanation_table.py @@ -0,0 +1,133 @@ +"""create forecast explanation table + +Revision ID: f84258c4cb44 +Revises: 7e8f9748581e +Create Date: 2026-05-19 11:46:00.062839 
+ +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'f84258c4cb44' +down_revision: Union[str, None] = '7e8f9748581e' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Apply migration — create the forecast_explanation table.""" + op.create_table( + 'forecast_explanation', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('explanation_id', sa.String(length=32), nullable=False), + sa.Column('run_id', sa.String(length=32), nullable=True), + sa.Column('job_id', sa.String(length=32), nullable=True), + sa.Column('store_id', sa.Integer(), nullable=False), + sa.Column('product_id', sa.Integer(), nullable=False), + sa.Column('model_type', sa.String(length=50), nullable=False), + sa.Column('method', sa.String(length=20), nullable=False), + sa.Column('as_of_date', sa.Date(), nullable=False), + sa.Column('forecast_value', sa.Float(), nullable=False), + sa.Column('confidence', sa.String(length=10), nullable=False), + sa.Column('drivers', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('reason_codes', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('caveats', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('agent_summary', sa.String(length=2000), nullable=False), + sa.Column( + 'created_at', + sa.DateTime(timezone=True), + server_default=sa.text('now()'), + nullable=False, + ), + sa.Column( + 'updated_at', + sa.DateTime(timezone=True), + server_default=sa.text('now()'), + nullable=False, + ), + sa.CheckConstraint( + "confidence IN ('high', 'medium', 'low')", + name='ck_forecast_explanation_confidence', + ), + sa.CheckConstraint( + "method IN ('rule_based', 'shap', 'component')", + name='ck_forecast_explanation_method', + ), + sa.PrimaryKeyConstraint('id'), + ) + 
op.create_index( + op.f('ix_forecast_explanation_explanation_id'), + 'forecast_explanation', + ['explanation_id'], + unique=True, + ) + op.create_index( + op.f('ix_forecast_explanation_run_id'), + 'forecast_explanation', + ['run_id'], + unique=False, + ) + op.create_index( + op.f('ix_forecast_explanation_job_id'), + 'forecast_explanation', + ['job_id'], + unique=False, + ) + op.create_index( + op.f('ix_forecast_explanation_store_id'), + 'forecast_explanation', + ['store_id'], + unique=False, + ) + op.create_index( + op.f('ix_forecast_explanation_product_id'), + 'forecast_explanation', + ['product_id'], + unique=False, + ) + op.create_index( + 'ix_forecast_explanation_drivers_gin', + 'forecast_explanation', + ['drivers'], + unique=False, + postgresql_using='gin', + ) + op.create_index( + 'ix_forecast_explanation_store_product', + 'forecast_explanation', + ['store_id', 'product_id'], + unique=False, + ) + + +def downgrade() -> None: + """Revert migration — drop the forecast_explanation table.""" + op.drop_index( + 'ix_forecast_explanation_store_product', table_name='forecast_explanation' + ) + op.drop_index( + 'ix_forecast_explanation_drivers_gin', + table_name='forecast_explanation', + postgresql_using='gin', + ) + op.drop_index( + op.f('ix_forecast_explanation_product_id'), table_name='forecast_explanation' + ) + op.drop_index( + op.f('ix_forecast_explanation_store_id'), table_name='forecast_explanation' + ) + op.drop_index( + op.f('ix_forecast_explanation_job_id'), table_name='forecast_explanation' + ) + op.drop_index( + op.f('ix_forecast_explanation_run_id'), table_name='forecast_explanation' + ) + op.drop_index( + op.f('ix_forecast_explanation_explanation_id'), + table_name='forecast_explanation', + ) + op.drop_table('forecast_explanation') diff --git a/app/features/explainability/__init__.py b/app/features/explainability/__init__.py new file mode 100644 index 00000000..ae3fa227 --- /dev/null +++ b/app/features/explainability/__init__.py @@ -0,0 +1,5 @@ 
+"""Forecast explainability & driver-attribution vertical slice (PRP-28). + +Rule-based, deterministic explanations for the three baseline forecasters. +SHAP is deliberately out of scope — see PRP-28. +""" diff --git a/app/features/explainability/explainers.py b/app/features/explainability/explainers.py new file mode 100644 index 00000000..8c4b9d3e --- /dev/null +++ b/app/features/explainability/explainers.py @@ -0,0 +1,272 @@ +"""Rule-based, deterministic explainers for the three baseline forecasters. + +Each explainer MIRRORS the exact h=1 math of the matching forecaster in +``app/features/forecasting/models.py`` — a rule-based explainer is *exact*, not +an approximation. ``test_explainers.py`` asserts each explainer's forecast value +equals the real forecaster's ``.fit(y).predict(1)[0]`` on the same series. + +A driver with ``contribution == 0.0`` is informational context only — the +baseline model does not consume it. The sum of all driver contributions equals +the forecast value. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +import numpy as np + +from app.features.explainability.schemas import ( + ConfidenceLevel, + Direction, + DriverContribution, +) + +# A 1-D float series, matching the forecasters' target-array type. +FloatArray = np.ndarray[Any, np.dtype[np.floating[Any]]] + +# Below this many observations a naive explanation is treated as low-confidence. +_NAIVE_MIN_COMFORTABLE = 14 + + +def _direction(value: float) -> Direction: + """Map a signed value to a driver direction literal.""" + if value > 0: + return "positive" + if value < 0: + return "negative" + return "neutral" + + +class BaseExplainer(ABC): + """Abstract base for a rule-based forecast explainer.""" + + @abstractmethod + def explain(self, y: FloatArray) -> tuple[float, list[DriverContribution]]: + """Decompose the h=1 forecast into named driver contributions. 
+ + Args: + y: The historical target series (time-ordered, ``<= as_of_date``). + + Returns: + The h=1 forecast value and its ordered driver contributions. + + Raises: + ValueError: If the series is too short to produce the forecast. + """ + + @abstractmethod + def confidence(self, y: FloatArray) -> ConfidenceLevel: + """Return a qualitative confidence band for the explanation. + + Args: + y: The historical target series. + + Returns: + The confidence band. + """ + + +class NaiveExplainer(BaseExplainer): + """Explainer for the naive forecaster — the forecast IS the last value.""" + + def explain(self, y: FloatArray) -> tuple[float, list[DriverContribution]]: + """Decompose the naive h=1 forecast. + + Args: + y: The historical target series. + + Returns: + The h=1 forecast (``y[-1]``) and its driver contributions. + + Raises: + ValueError: If ``y`` is empty. + """ + if len(y) == 0: + raise ValueError("Cannot explain an empty series") + forecast = float(y[-1]) + drivers = [ + DriverContribution( + name="last_observation", + feature_value=forecast, + contribution=forecast, + direction="positive", + description="The naive forecast is exactly the last observed value.", + ) + ] + if len(y) >= _NAIVE_MIN_COMFORTABLE: + trend = float(np.mean(y[-7:]) - np.mean(y[-14:-7])) + drivers.append( + DriverContribution( + name="recent_trend", + feature_value=trend, + contribution=0.0, + direction=_direction(trend), + description=( + "Context only — week-over-week change in mean demand; " + "the naive model does not use trend." 
+ ), + ) + ) + return forecast, drivers + + def confidence(self, y: FloatArray) -> ConfidenceLevel: + """Return ``LOW`` for a short series, otherwise ``MEDIUM``.""" + if len(y) < _NAIVE_MIN_COMFORTABLE: + return ConfidenceLevel.LOW + return ConfidenceLevel.MEDIUM + + +class SeasonalNaiveExplainer(BaseExplainer): + """Explainer for the seasonal-naive forecaster — last season's value.""" + + def __init__(self, season_length: int = 7) -> None: + """Initialise the explainer. + + Args: + season_length: Seasonal period in days (must be >= 1). + + Raises: + ValueError: If ``season_length`` < 1. + """ + if season_length < 1: + raise ValueError(f"season_length must be >= 1, got {season_length}") + self.season_length = season_length + + def explain(self, y: FloatArray) -> tuple[float, list[DriverContribution]]: + """Decompose the seasonal-naive h=1 forecast. + + Args: + y: The historical target series. + + Returns: + The h=1 forecast (``y[-season_length]``) and its driver contributions. + + Raises: + ValueError: If ``y`` has fewer observations than ``season_length``. + """ + if len(y) < self.season_length: + raise ValueError(f"Need at least {self.season_length} observations") + forecast = float(y[-self.season_length]) + drivers = [ + DriverContribution( + name="season_match", + feature_value=forecast, + contribution=forecast, + direction="positive", + description=( + f"The forecast repeats the value observed {self.season_length} " + "days ago (one seasonal cycle back)." + ), + ) + ] + return forecast, drivers + + def confidence(self, y: FloatArray) -> ConfidenceLevel: + """Return ``LOW`` for under two seasonal cycles, otherwise ``MEDIUM``.""" + if len(y) < 2 * self.season_length: + return ConfidenceLevel.LOW + return ConfidenceLevel.MEDIUM + + +class MovingAverageExplainer(BaseExplainer): + """Explainer for the moving-average forecaster — mean of the last window.""" + + def __init__(self, window_size: int = 7) -> None: + """Initialise the explainer. 
+ + Args: + window_size: Averaging window in days (must be >= 1). + + Raises: + ValueError: If ``window_size`` < 1. + """ + if window_size < 1: + raise ValueError(f"window_size must be >= 1, got {window_size}") + self.window_size = window_size + + def explain(self, y: FloatArray) -> tuple[float, list[DriverContribution]]: + """Decompose the moving-average h=1 forecast. + + Args: + y: The historical target series. + + Returns: + The h=1 forecast (``mean(y[-window_size:])``) and driver contributions. + + Raises: + ValueError: If ``y`` has fewer observations than ``window_size``. + """ + if len(y) < self.window_size: + raise ValueError(f"Need at least {self.window_size} observations") + window = y[-self.window_size :] + forecast = float(np.mean(window)) + dispersion = float(np.std(window)) + drivers = [ + DriverContribution( + name="window_mean", + feature_value=forecast, + contribution=forecast, + direction="positive", + description=( + f"The forecast is the mean of the last {self.window_size} observed values." + ), + ), + DriverContribution( + name="window_dispersion", + feature_value=dispersion, + contribution=0.0, + direction="neutral", + description=( + "Context only — standard deviation within the averaging " + "window; higher values mean a noisier, less reliable mean." + ), + ), + ] + return forecast, drivers + + def confidence(self, y: FloatArray) -> ConfidenceLevel: + """Return ``HIGH`` for a stable full window, ``MEDIUM``/``LOW`` otherwise.""" + if len(y) < self.window_size: + return ConfidenceLevel.LOW + window = y[-self.window_size :] + mean = float(np.mean(window)) + std = float(np.std(window)) + cv = std / mean if mean > 0 else 0.0 + if cv < 0.5: + return ConfidenceLevel.HIGH + return ConfidenceLevel.MEDIUM + + +def explainer_factory( + model_type: str, + season_length: int | None = None, + window_size: int | None = None, +) -> BaseExplainer: + """Build the rule-based explainer for a baseline model type. 
+ + Args: + model_type: One of ``naive``, ``seasonal_naive``, ``moving_average``. + season_length: Seasonal period for ``seasonal_naive`` (defaults to 7). + window_size: Averaging window for ``moving_average`` (defaults to 7). + + Returns: + The matching explainer instance. + + Raises: + ValueError: For ``lightgbm``/``regression`` (MVP scope guard) or an + unknown model type. + """ + if model_type == "naive": + return NaiveExplainer() + if model_type == "seasonal_naive": + return SeasonalNaiveExplainer(season_length=season_length or 7) + if model_type == "moving_average": + return MovingAverageExplainer(window_size=window_size or 7) + if model_type in ("lightgbm", "regression"): + raise ValueError( + f"Explanations are available for baseline models only; " + f"'{model_type}' is not supported (rule-based MVP)." + ) + raise ValueError(f"Unknown model type: {model_type}") diff --git a/app/features/explainability/models.py b/app/features/explainability/models.py new file mode 100644 index 00000000..e12538c0 --- /dev/null +++ b/app/features/explainability/models.py @@ -0,0 +1,81 @@ +"""ORM model for the explainability slice. + +A ``forecast_explanation`` row persists one rule-based explanation: the driver +breakdown, advisory reason codes, and caveats as JSONB, plus scalar columns for +the forecast context. Persisting it means a re-requested explanation is a cheap +read and gives the slice an audit trail. + +GOTCHA: SQLAlchemy reserves the declarative attribute name ``metadata`` — no +column here uses it. +""" + +from __future__ import annotations + +import datetime +from typing import Any + +from sqlalchemy import CheckConstraint, Date, Float, Index, Integer, String +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column + +from app.core.database import Base +from app.shared.models import TimestampMixin + + +class ForecastExplanation(TimestampMixin, Base): + """A persisted rule-based forecast explanation. 
+ + Attributes: + id: Surrogate primary key. + explanation_id: Unique external identifier (UUID hex, 32 chars). + run_id: Originating registry run, when explained via ``/explain/runs``. + job_id: Originating predict job, when explained via ``/explain/jobs``. + store_id: Store the forecast targets. + product_id: Product the forecast targets. + model_type: Baseline model type explained. + method: Explanation method — always ``rule_based`` for the MVP. + as_of_date: Series cutoff date. + forecast_value: The h=1 forecast value. + confidence: Qualitative confidence band (``high|medium|low``). + drivers: Driver contributions as JSONB. + reason_codes: Advisory reason codes as JSONB. + caveats: Plain-language caveats as a JSONB string array. + agent_summary: One-paragraph natural-language summary. + """ + + __tablename__ = "forecast_explanation" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + explanation_id: Mapped[str] = mapped_column(String(32), unique=True, index=True) + run_id: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True) + job_id: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True) + store_id: Mapped[int] = mapped_column(Integer, index=True) + product_id: Mapped[int] = mapped_column(Integer, index=True) + model_type: Mapped[str] = mapped_column(String(50)) + method: Mapped[str] = mapped_column(String(20), default="rule_based") + as_of_date: Mapped[datetime.date] = mapped_column(Date) + forecast_value: Mapped[float] = mapped_column(Float) + confidence: Mapped[str] = mapped_column(String(10)) + + # JSONB blobs — never named ``metadata`` (SQLAlchemy reserves it). 
+ drivers: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False) + reason_codes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False) + caveats: Mapped[list[str]] = mapped_column(JSONB, nullable=False) + + agent_summary: Mapped[str] = mapped_column(String(2000)) + + __table_args__ = ( + # GIN index for JSONB containment queries on the driver breakdown. + Index("ix_forecast_explanation_drivers_gin", "drivers", postgresql_using="gin"), + # Composite index for the common "explanations for this store/product" query. + Index("ix_forecast_explanation_store_product", "store_id", "product_id"), + # Kept in lock-step with the alembic migration that created this table. + CheckConstraint( + "confidence IN ('high', 'medium', 'low')", + name="ck_forecast_explanation_confidence", + ), + CheckConstraint( + "method IN ('rule_based', 'shap', 'component')", + name="ck_forecast_explanation_method", + ), + ) diff --git a/app/features/explainability/reason_codes.py b/app/features/explainability/reason_codes.py new file mode 100644 index 00000000..d047abbe --- /dev/null +++ b/app/features/explainability/reason_codes.py @@ -0,0 +1,186 @@ +"""Advisory retail reason-code engine for the explainability slice. + +These are PURE functions — they perform no database access and take only +primitive inputs. The service layer runs the time-safe queries and extracts the +primitives; every input below therefore reflects only data the caller already +bounded ``<= as_of_date`` (or, for ``holiday_reason``, the explained horizon +date). + +CRITICAL: a reason code is an advisory *correlation* signal, never a causal +claim. ``build_caveats`` always emits the NIST-grounded correlation-vs-causation +disclaimer. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence +from datetime import date as date_type + +from app.features.explainability.schemas import ReasonCode + + +def stockout_reason(stockout_flags: Sequence[bool]) -> ReasonCode | None: + """Flag stockout-suppressed history in the trailing window. + + Args: + stockout_flags: One ``is_stockout`` flag per day in the trailing window + (already bounded ``<= as_of_date`` by the caller). + + Returns: + A ``stockout_constrained`` warning when any day was a stockout, + otherwise ``None``. + """ + stockout_days = sum(1 for flag in stockout_flags if flag) + if stockout_days == 0: + return None + return ReasonCode( + code="stockout_constrained", + severity="warn", + detail=( + f"{stockout_days} stockout day(s) in the trailing " + f"{len(stockout_flags)}-day window — observed demand may understate " + "true demand because units could not be sold while out of stock." + ), + ) + + +def promotion_reason( + promotion_windows: Sequence[tuple[date_type, date_type]], as_of_date: date_type +) -> ReasonCode | None: + """Flag promotions overlapping the trailing window. + + Args: + promotion_windows: ``(start_date, end_date)`` tuples for promotions + overlapping the trailing window (already bounded ``<= as_of_date``). + as_of_date: The series cutoff date. + + Returns: + A ``promotion_overlap`` info code when any promotion overlaps, + otherwise ``None``. + """ + if not promotion_windows: + return None + active_now = sum(1 for start, end in promotion_windows if start <= as_of_date <= end) + active_clause = f"; {active_now} still active on {as_of_date.isoformat()}" if active_now else "" + return ReasonCode( + code="promotion_overlap", + severity="info", + detail=( + f"{len(promotion_windows)} promotion(s) overlap the trailing window" + f"{active_clause} — promotional demand may not represent the baseline." 
+ ), + ) + + +def lifecycle_reason(launch_date: date_type | None, as_of_date: date_type) -> ReasonCode | None: + """Flag a product still in its early lifecycle. + + Args: + launch_date: The product's launch date, or ``None`` if unknown. + as_of_date: The series cutoff date. + + Returns: + A ``lifecycle_decay`` info code when the product launched fewer than + 30 days before ``as_of_date``, otherwise ``None``. + """ + if launch_date is None: + return None + days_since_launch = (as_of_date - launch_date).days + if 0 <= days_since_launch < 30: + return ReasonCode( + code="lifecycle_decay", + severity="info", + detail=( + f"Product launched {days_since_launch} day(s) ago — early-" + "lifecycle demand is volatile and may not represent a stable " + "baseline." + ), + ) + return None + + +def holiday_reason( + is_holiday: bool, holiday_name: str | None, forecast_date: date_type +) -> ReasonCode | None: + """Flag a holiday landing on the explained forecast horizon. + + Args: + is_holiday: Whether ``forecast_date`` is flagged as a holiday. + holiday_name: The holiday's name, when known. + forecast_date: The date of the explained h=1 forecast. + + Returns: + A ``holiday_effect`` info code when ``forecast_date`` is a holiday, + otherwise ``None``. + """ + if not is_holiday: + return None + name = holiday_name or "a holiday" + return ReasonCode( + code="holiday_effect", + severity="info", + detail=( + f"The forecast date {forecast_date.isoformat()} is {name} — " + "holiday demand typically deviates from a normal day." + ), + ) + + +def history_reason(n_obs: int, min_required: int) -> ReasonCode | None: + """Flag a series too short for a comfortable explanation. + + Args: + n_obs: Number of observations in the series. + min_required: Minimum comfortable observation count for the model. + + Returns: + An ``insufficient_history`` warning when ``n_obs < min_required``, + otherwise ``None``. 
+ """ + if n_obs < min_required: + return ReasonCode( + code="insufficient_history", + severity="warn", + detail=( + f"Only {n_obs} observation(s) available; {min_required} or more " + "is recommended for a confident explanation." + ), + ) + return None + + +# The NIST-grounded disclaimer baked into every explanation (see +# https://www.nist.gov/itl/ai-risk-management-framework). +CORRELATION_CAVEAT = ( + "Drivers describe correlation and contribution, not business causality — " + "they explain the model's arithmetic, not why demand moved." +) + +_MODEL_CAVEATS: dict[str, str] = { + "naive": "The naive model ignores seasonality and trend entirely.", + "seasonal_naive": "The seasonal-naive model assumes the prior cycle repeats exactly.", + "moving_average": "The moving-average model smooths over recent shifts in demand.", +} + + +def build_caveats(model_type: str, reason_codes: Sequence[ReasonCode]) -> list[str]: + """Assemble the caveat list for an explanation. + + Args: + model_type: The baseline model type explained. + reason_codes: The reason codes already computed for the explanation. + + Returns: + Plain-language caveats, always starting with the correlation-vs- + causation disclaimer. + """ + caveats = [CORRELATION_CAVEAT] + model_caveat = _MODEL_CAVEATS.get(model_type) + if model_caveat is not None: + caveats.append(model_caveat) + codes = {rc.code for rc in reason_codes} + if "stockout_constrained" in codes: + caveats.append("Stockout days in the history mean the forecast may understate true demand.") + if "insufficient_history" in codes: + caveats.append("The short history makes this explanation less reliable than usual.") + return caveats diff --git a/app/features/explainability/routes.py b/app/features/explainability/routes.py new file mode 100644 index 00000000..ae93e2bc --- /dev/null +++ b/app/features/explainability/routes.py @@ -0,0 +1,169 @@ +"""API routes for the explainability slice. 
+ +Three endpoints under a self-owned ``/explain`` namespace produce rule-based +forecast explanations. Service-layer ``ValueError`` (unsupported model type, +too-short series) maps to an RFC 7807 400; a missing run/job maps to a 404; +``SQLAlchemyError`` maps to a 500 — never a bare ``HTTPException``. +""" + +from fastapi import APIRouter, Depends, status +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.core.exceptions import BadRequestError, DatabaseError, NotFoundError +from app.core.logging import get_logger +from app.features.explainability.schemas import ( + ExplainForecastRequest, + ForecastExplanation, +) +from app.features.explainability.service import ExplainabilityService + +logger = get_logger(__name__) + +router = APIRouter(prefix="/explain", tags=["explainability"]) + + +@router.post( + "/forecast", + response_model=ForecastExplanation, + status_code=status.HTTP_200_OK, + summary="Explain an ad-hoc baseline forecast", + description=""" +Compute a rule-based explanation for the h=1 forecast a named baseline model +would produce on the series ending at `as_of_date`. + +**Inputs:** `store_id`, `product_id`, `model_type` (`naive` / `seasonal_naive` / +`moving_average`), `as_of_date`, and optional `season_length` / `window_size`. + +**Output:** a `ForecastExplanation` — ordered driver contributions, advisory +retail reason codes (correlation, never causation), a confidence band, caveats, +and an agent-readable summary. The series and every reason-code query are +time-safe (`<= as_of_date`). + +An unsupported `model_type` (`lightgbm` / `regression`) or a series too short to +forecast returns an RFC 7807 400 — never a 500. +""", +) +async def explain_forecast( + request: ExplainForecastRequest, + db: AsyncSession = Depends(get_db), +) -> ForecastExplanation: + """Explain an ad-hoc baseline forecast. + + Args: + request: Store/product/model/cutoff parameters. 
+ db: Async database session from dependency. + + Returns: + The rule-based forecast explanation. + + Raises: + BadRequestError: For an unsupported model type or a too-short series. + DatabaseError: When persistence fails. + """ + try: + return await ExplainabilityService().explain_forecast(db, request) + except ValueError as exc: + logger.warning("explainability.forecast_invalid", error=str(exc)) + raise BadRequestError(message=str(exc)) from exc + except SQLAlchemyError as exc: + logger.error("explainability.forecast_db_error", error=str(exc), exc_info=True) + raise DatabaseError( + message="Failed to generate forecast explanation", + details={"error": str(exc)}, + ) from exc + + +@router.get( + "/runs/{run_id}", + response_model=ForecastExplanation, + summary="Explain a registry model run", + description=""" +Explain a registry `model_run`. The baseline config is reconstructed from +`model_run.model_config`, and `data_window_end` is used as the series cutoff. + +A missing `run_id` returns a 404; a non-baseline run (`lightgbm` / `regression`) +returns a 400 — explanations are available for baseline models only. +""", +) +async def explain_run( + run_id: str, + db: AsyncSession = Depends(get_db), +) -> ForecastExplanation: + """Explain a registry model run. + + Args: + run_id: External run identifier. + db: Async database session from dependency. + + Returns: + The rule-based forecast explanation. + + Raises: + NotFoundError: When no run matches ``run_id``. + BadRequestError: For a non-baseline run or a too-short series. + DatabaseError: When persistence fails. 
+ """ + try: + explanation = await ExplainabilityService().explain_run(db, run_id) + except ValueError as exc: + logger.warning("explainability.run_invalid", run_id=run_id, error=str(exc)) + raise BadRequestError(message=str(exc)) from exc + except SQLAlchemyError as exc: + logger.error("explainability.run_db_error", error=str(exc), exc_info=True) + raise DatabaseError( + message="Failed to generate run explanation", + details={"error": str(exc)}, + ) from exc + if explanation is None: + raise NotFoundError(message=f"Model run not found: {run_id}") + return explanation + + +@router.get( + "/jobs/{job_id}", + response_model=ForecastExplanation, + summary="Explain a completed predict job", + description=""" +Explain a completed `predict` job. `store_id`, `product_id`, and `model_type` +are read from `job.result`; the series cutoff is the day before the first +forecast date. + +A missing `job_id` returns a 404; a job that is not a completed predict job +returns a 400. +""", +) +async def explain_job( + job_id: str, + db: AsyncSession = Depends(get_db), +) -> ForecastExplanation: + """Explain a completed predict job. + + Args: + job_id: External job identifier. + db: Async database session from dependency. + + Returns: + The rule-based forecast explanation. + + Raises: + NotFoundError: When no job matches ``job_id``. + BadRequestError: When the job is not a completed predict job, or for a + too-short series. + DatabaseError: When persistence fails. 
+ """ + try: + explanation = await ExplainabilityService().explain_job(db, job_id) + except ValueError as exc: + logger.warning("explainability.job_invalid", job_id=job_id, error=str(exc)) + raise BadRequestError(message=str(exc)) from exc + except SQLAlchemyError as exc: + logger.error("explainability.job_db_error", error=str(exc), exc_info=True) + raise DatabaseError( + message="Failed to generate job explanation", + details={"error": str(exc)}, + ) from exc + if explanation is None: + raise NotFoundError(message=f"Job not found: {job_id}") + return explanation diff --git a/app/features/explainability/schemas.py b/app/features/explainability/schemas.py new file mode 100644 index 00000000..9c767a09 --- /dev/null +++ b/app/features/explainability/schemas.py @@ -0,0 +1,147 @@ +"""Pydantic v2 schemas for the explainability slice. + +The response schemas (``DriverContribution``, ``ReasonCode``, +``ForecastExplanation``) are plain ``BaseModel`` — NOT ``strict=True`` — so they +serialise cleanly. The single request body (``ExplainForecastRequest``) IS +``strict=True``; its ``as_of_date`` field therefore carries ``Field(strict=False)`` +because ``date`` has no native JSON representation (see ``docs/_base/SECURITY.md`` +-> "Pydantic v2 strict mode on FastAPI request bodies"; enforced by +``app/core/tests/test_strict_mode_policy.py``). +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from datetime import date as date_type +from enum import Enum +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field + +# Direction of a driver's influence on the forecast. +Direction = Literal["positive", "negative", "neutral"] + +# Advisory retail reason-code identifiers — correlation signals, never causal claims. +ReasonCodeId = Literal[ + "stockout_constrained", + "promotion_overlap", + "holiday_effect", + "lifecycle_decay", + "trend_shift", + "insufficient_history", +] + +# Baseline model types this slice can explain. 
``lightgbm``/``regression`` are +# rejected with a clean 400 (MVP scope guard). +ExplainableModelType = Literal["naive", "seasonal_naive", "moving_average"] + + +class ConfidenceLevel(str, Enum): + """Qualitative confidence band for an explanation.""" + + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + + +class DriverContribution(BaseModel): + """One named, interpretable demand driver behind a forecast. + + Attributes: + name: Stable machine identifier for the driver. + feature_value: The observed value of the underlying feature. + contribution: Amount (in model units) this driver adds to the forecast. + Informational/context drivers carry ``0.0``. + direction: Sign of the driver's influence. + description: Human-readable explanation of the driver. + """ + + name: str + feature_value: float + contribution: float + direction: Direction + description: str + + +class ReasonCode(BaseModel): + """An advisory retail signal correlated with the forecast. + + CRITICAL: reason codes describe correlation, never business causality. + + Attributes: + code: Machine-readable reason-code identifier. + severity: ``info`` for context, ``warn`` for a quality caveat. + detail: Human-readable detail for the signal. + """ + + code: ReasonCodeId + severity: Literal["info", "warn"] + detail: str + + +class ForecastExplanation(BaseModel): + """Structured, rule-based explanation of a baseline h=1 forecast. + + Attributes: + store_id: Store the forecast targets. + product_id: Product the forecast targets. + model_type: Baseline model type explained. + method: Always ``rule_based`` for the MVP (``shap``/``component`` reserved). + forecast_value: The h=1 forecast the baseline model produces. + drivers: Ordered, named driver contributions. + reason_codes: Advisory retail reason codes (correlation only). + confidence: Qualitative confidence band. + caveats: Plain-language caveats, always including the correlation-vs- + causation disclaimer. 
+ agent_summary: One-paragraph natural-language summary for chat agents. + as_of_date: Series cutoff — no data past this date informs the explanation. + generated_at: UTC timestamp the explanation was produced. + """ + + model_config = ConfigDict(from_attributes=True) + + store_id: int + product_id: int + model_type: str + method: Literal["rule_based"] = "rule_based" + forecast_value: float + drivers: list[DriverContribution] + reason_codes: list[ReasonCode] + confidence: ConfidenceLevel + caveats: list[str] + agent_summary: str + as_of_date: date_type + generated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + + +class ExplainForecastRequest(BaseModel): + """Request body for ``POST /explain/forecast``. + + Attributes: + store_id: Store ID to explain. + product_id: Product ID to explain. + model_type: Baseline model type to reproduce and explain. + as_of_date: Series cutoff date — the explainer reads only ``<= as_of_date``. + season_length: Seasonal period (``seasonal_naive`` only; defaults to 7). + window_size: Averaging window (``moving_average`` only; defaults to 7). + """ + + model_config = ConfigDict(strict=True) + + store_id: int = Field(..., ge=1, description="Store ID") + product_id: int = Field(..., ge=1, description="Product ID") + model_type: ExplainableModelType = Field(..., description="Baseline model type") + # ``date`` has no native JSON representation — ``strict=False`` lets FastAPI's + # ``validate_python`` accept an ISO-string body. Repo-wide policy; see module + # docstring. 
+ as_of_date: date_type = Field( + ..., + strict=False, + description="Series cutoff date (the explainer reads only <= this date)", + ) + season_length: int | None = Field( + None, ge=1, le=365, description="Seasonal period for seasonal_naive (default 7)" + ) + window_size: int | None = Field( + None, ge=1, le=90, description="Averaging window for moving_average (default 7)" + ) diff --git a/app/features/explainability/service.py b/app/features/explainability/service.py new file mode 100644 index 00000000..2f10fe6e --- /dev/null +++ b/app/features/explainability/service.py @@ -0,0 +1,421 @@ +"""Service layer for the explainability slice. + +``ExplainabilityService`` is READ-ONLY with respect to every other slice. It +imports ``app.features.registry.models.ModelRun`` and +``app.features.jobs.models.Job`` directly, but ONLY as read-only data contracts +— a locked maintainer decision (PRP-28 "Open Questions & Decisions" #1), the +same pattern by which slices already import ``app.features.data_platform.models``. +It NEVER imports another slice's ``service.py``. + +To explain a run or job, the service re-loads the target series from +``sales_daily`` and re-fits a rule-based explainer from the stored config — it +does not reload the model artifact. Every series load and reason-code query is +bounded ``<= as_of_date`` (time-safety is load-bearing). 
+""" + +from __future__ import annotations + +import uuid +from datetime import date as date_type +from datetime import timedelta +from typing import Any + +import numpy as np +import structlog +from sqlalchemy import or_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import get_settings +from app.core.exceptions import BadRequestError +from app.features.data_platform.models import ( + Calendar, + InventorySnapshotDaily, + Product, + Promotion, + SalesDaily, +) +from app.features.explainability.explainers import FloatArray, explainer_factory +from app.features.explainability.models import ( + ForecastExplanation as ForecastExplanationORM, +) +from app.features.explainability.reason_codes import ( + build_caveats, + history_reason, + holiday_reason, + lifecycle_reason, + promotion_reason, + stockout_reason, +) +from app.features.explainability.schemas import ( + ConfidenceLevel, + DriverContribution, + ExplainForecastRequest, + ForecastExplanation, + ReasonCode, +) +from app.features.jobs.models import Job # read-only data contract — see module docstring +from app.features.registry.models import ( # read-only data contract — see module docstring + ModelRun, +) + +logger = structlog.get_logger(__name__) + +# Trailing window (days) used for stockout / promotion reason-code lookups. +_REASON_WINDOW_DAYS = 30 + + +class ExplainabilityService: + """Produces rule-based forecast explanations for the baseline models.""" + + def __init__(self) -> None: + """Initialise the service.""" + self.settings = get_settings() + + # ------------------------------------------------------------------ # + # Public entry points + # ------------------------------------------------------------------ # + + async def explain_forecast( + self, db: AsyncSession, request: ExplainForecastRequest + ) -> ForecastExplanation: + """Explain the h=1 forecast a baseline model produces ad hoc. + + Args: + db: Async database session. 
+ request: Store/product/model/cutoff parameters. + + Returns: + The persisted forecast explanation. + + Raises: + ValueError: For an unsupported model type or a too-short series. + """ + return await self._explain( + db, + store_id=request.store_id, + product_id=request.product_id, + model_type=request.model_type, + as_of_date=request.as_of_date, + season_length=request.season_length, + window_size=request.window_size, + ) + + async def explain_run(self, db: AsyncSession, run_id: str) -> ForecastExplanation | None: + """Explain a registry ``model_run``. + + Args: + db: Async database session. + run_id: External run identifier. + + Returns: + The explanation, or ``None`` when the run does not exist. + + Raises: + ValueError: For a non-baseline run or a too-short series. + """ + run = ( + await db.execute(select(ModelRun).where(ModelRun.run_id == run_id)) + ).scalar_one_or_none() + if run is None: + return None + config: dict[str, Any] = run.model_config or {} + return await self._explain( + db, + store_id=run.store_id, + product_id=run.product_id, + model_type=run.model_type, + as_of_date=run.data_window_end, + season_length=config.get("season_length"), + window_size=config.get("window_size"), + run_id=run_id, + ) + + async def explain_job(self, db: AsyncSession, job_id: str) -> ForecastExplanation | None: + """Explain a completed ``predict`` job. + + Args: + db: Async database session. + job_id: External job identifier. + + Returns: + The explanation, or ``None`` when the job does not exist. + + Raises: + BadRequestError: When the job is not a completed predict job, or + its result carries no forecasts. + ValueError: For an unsupported model type or a too-short series. 
+        """
+        job = (await db.execute(select(Job).where(Job.job_id == job_id))).scalar_one_or_none()
+        if job is None:
+            # Unknown job id: the caller (route) turns ``None`` into a not-found error.
+            return None
+        if job.job_type != "predict" or job.status != "completed":
+            raise BadRequestError(
+                message="explain_job requires a completed predict job",
+                details={"job_id": job_id, "job_type": job.job_type, "status": job.status},
+            )
+        result: dict[str, Any] = job.result or {}
+        forecasts: list[Any] = result.get("forecasts") or []
+        if not forecasts:
+            raise BadRequestError(
+                message="predict job has no forecasts to explain",
+                details={"job_id": job_id},
+            )
+        store_id = result.get("store_id")
+        product_id = result.get("product_id")
+        model_type = result.get("model_type")
+        if store_id is None or product_id is None or model_type is None:
+            raise BadRequestError(
+                message="predict job result is missing store/product/model_type",
+                details={"job_id": job_id},
+            )
+        # as_of_date = the day before the first forecast date (PRP-28 assumption #4).
+        # ``job.result`` is free-form JSON, so check the first entry's shape before
+        # indexing into it: a non-dict entry or a missing/non-string ``date`` would
+        # otherwise escape as an unhandled KeyError/TypeError instead of a
+        # BadRequestError.
+        first_forecast = forecasts[0] if isinstance(forecasts, list) else None
+        raw_date = first_forecast.get("date") if isinstance(first_forecast, dict) else None
+        if not isinstance(raw_date, str):
+            raise BadRequestError(
+                message="predict job forecasts carry no usable date",
+                details={"job_id": job_id},
+            )
+        # A non-ISO string still raises ValueError here (documented in Raises).
+        first_forecast_date = date_type.fromisoformat(raw_date)
+        as_of_date = first_forecast_date - timedelta(days=1)
+        return await self._explain(
+            db,
+            store_id=int(store_id),
+            product_id=int(product_id),
+            model_type=str(model_type),
+            as_of_date=as_of_date,
+            # A predict job's result does not record season_length/window_size;
+            # the explainer falls back to the forecaster defaults (7).
+ season_length=None, + window_size=None, + job_id=job_id, + ) + + # ------------------------------------------------------------------ # + # Core + # ------------------------------------------------------------------ # + + async def _explain( + self, + db: AsyncSession, + *, + store_id: int, + product_id: int, + model_type: str, + as_of_date: date_type, + season_length: int | None, + window_size: int | None, + run_id: str | None = None, + job_id: str | None = None, + ) -> ForecastExplanation: + """Build, persist, and return one rule-based explanation.""" + explainer = explainer_factory(model_type, season_length, window_size) + y, _dates = await self._load_series(db, store_id, product_id, as_of_date) + forecast_value, drivers = explainer.explain(y) + confidence = explainer.confidence(y) + forecast_date = as_of_date + timedelta(days=1) + + reason_codes = await self._assemble_reason_codes( + db, + store_id=store_id, + product_id=product_id, + model_type=model_type, + as_of_date=as_of_date, + forecast_date=forecast_date, + season_length=season_length, + window_size=window_size, + n_obs=len(y), + ) + caveats = build_caveats(model_type, reason_codes) + agent_summary = self._build_agent_summary( + store_id=store_id, + product_id=product_id, + model_type=model_type, + forecast_value=forecast_value, + forecast_date=forecast_date, + drivers=drivers, + reason_codes=reason_codes, + confidence=confidence, + ) + explanation = ForecastExplanation( + store_id=store_id, + product_id=product_id, + model_type=model_type, + forecast_value=forecast_value, + drivers=drivers, + reason_codes=reason_codes, + confidence=confidence, + caveats=caveats, + agent_summary=agent_summary, + as_of_date=as_of_date, + ) + await self._persist(db, explanation, run_id=run_id, job_id=job_id) + logger.info( + "explainability.explanation_generated", + store_id=store_id, + product_id=product_id, + model_type=model_type, + confidence=confidence.value, + n_reason_codes=len(reason_codes), + ) + return 
explanation + + async def _load_series( + self, + db: AsyncSession, + store_id: int, + product_id: int, + end_date: date_type, + ) -> tuple[FloatArray, list[date_type]]: + """Load the time-ordered sales series, bounded ``<= end_date``. + + TIME-SAFETY: the ``date <= end_date`` bound is load-bearing — no data + past the cutoff may inform an explanation. + """ + stmt = ( + select(SalesDaily) + .where( + SalesDaily.store_id == store_id, + SalesDaily.product_id == product_id, + SalesDaily.date <= end_date, + ) + .order_by(SalesDaily.date) + ) + rows = (await db.execute(stmt)).scalars().all() + y: FloatArray = np.array([float(r.quantity) for r in rows], dtype=np.float64) + return y, [r.date for r in rows] + + async def _assemble_reason_codes( + self, + db: AsyncSession, + *, + store_id: int, + product_id: int, + model_type: str, + as_of_date: date_type, + forecast_date: date_type, + season_length: int | None, + window_size: int | None, + n_obs: int, + ) -> list[ReasonCode]: + """Run the time-safe reason-code queries and assemble the code list.""" + window_start = as_of_date - timedelta(days=_REASON_WINDOW_DAYS) + + inventory_rows = ( + ( + await db.execute( + select(InventorySnapshotDaily).where( + InventorySnapshotDaily.store_id == store_id, + InventorySnapshotDaily.product_id == product_id, + InventorySnapshotDaily.date <= as_of_date, + InventorySnapshotDaily.date >= window_start, + ) + ) + ) + .scalars() + .all() + ) + + promotion_rows = ( + ( + await db.execute( + select(Promotion).where( + Promotion.product_id == product_id, + or_(Promotion.store_id == store_id, Promotion.store_id.is_(None)), + Promotion.start_date <= as_of_date, + Promotion.end_date >= window_start, + ) + ) + ) + .scalars() + .all() + ) + + product = ( + await db.execute(select(Product).where(Product.id == product_id)) + ).scalar_one_or_none() + + calendar_row = ( + await db.execute(select(Calendar).where(Calendar.date == forecast_date)) + ).scalar_one_or_none() + + # Extract primitives — the 
reason-code engine is DB- and ORM-free. + stockout_flags = [row.is_stockout for row in inventory_rows] + promotion_windows = [(row.start_date, row.end_date) for row in promotion_rows] + launch_date = product.launch_date if product is not None else None + is_holiday = calendar_row.is_holiday if calendar_row is not None else False + holiday_name = calendar_row.holiday_name if calendar_row is not None else None + min_required = self._min_required_history(model_type, season_length, window_size) + + candidates = [ + stockout_reason(stockout_flags), + promotion_reason(promotion_windows, as_of_date), + lifecycle_reason(launch_date, as_of_date), + holiday_reason(is_holiday, holiday_name, forecast_date), + history_reason(n_obs, min_required), + ] + return [code for code in candidates if code is not None] + + @staticmethod + def _min_required_history( + model_type: str, season_length: int | None, window_size: int | None + ) -> int: + """Comfortable minimum observation count for a confident explanation.""" + if model_type == "seasonal_naive": + return 2 * (season_length or 7) + if model_type == "moving_average": + return 2 * (window_size or 7) + return 14 + + @staticmethod + def _build_agent_summary( + *, + store_id: int, + product_id: int, + model_type: str, + forecast_value: float, + forecast_date: date_type, + drivers: list[DriverContribution], + reason_codes: list[ReasonCode], + confidence: ConfidenceLevel, + ) -> str: + """Compose a one-paragraph natural-language summary for chat agents.""" + main_driver = drivers[0] + sentences = [ + f"For store {store_id} / product {product_id}, the {model_type} model " + f"forecasts {forecast_value:.1f} units for {forecast_date.isoformat()}.", + f"The forecast is driven by '{main_driver.name}' " + f"(value {main_driver.feature_value:.1f}).", + f"Explanation confidence is {confidence.value}.", + ] + if reason_codes: + codes = ", ".join(rc.code for rc in reason_codes) + sentences.append(f"Advisory retail signals present: {codes}.") + 
else: + sentences.append("No advisory retail signals were detected.") + return " ".join(sentences) + + async def _persist( + self, + db: AsyncSession, + explanation: ForecastExplanation, + *, + run_id: str | None, + job_id: str | None, + ) -> None: + """Persist the explanation as a ``forecast_explanation`` row. + + Uses ``flush``/``refresh`` — ``get_db`` auto-commits on success. + """ + row = ForecastExplanationORM( + explanation_id=uuid.uuid4().hex, + run_id=run_id, + job_id=job_id, + store_id=explanation.store_id, + product_id=explanation.product_id, + model_type=explanation.model_type, + method=explanation.method, + as_of_date=explanation.as_of_date, + forecast_value=explanation.forecast_value, + confidence=explanation.confidence.value, + drivers=[d.model_dump() for d in explanation.drivers], + reason_codes=[rc.model_dump() for rc in explanation.reason_codes], + caveats=list(explanation.caveats), + agent_summary=explanation.agent_summary, + ) + db.add(row) + await db.flush() + await db.refresh(row) diff --git a/app/features/explainability/tests/__init__.py b/app/features/explainability/tests/__init__.py new file mode 100644 index 00000000..4b942f89 --- /dev/null +++ b/app/features/explainability/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the explainability slice.""" diff --git a/app/features/explainability/tests/conftest.py b/app/features/explainability/tests/conftest.py new file mode 100644 index 00000000..64480968 --- /dev/null +++ b/app/features/explainability/tests/conftest.py @@ -0,0 +1,241 @@ +"""Test fixtures for the explainability slice. + +Unit fixtures supply numpy series and a ``make_mock_db`` factory that builds an +``AsyncMock`` session whose ``execute`` calls are scripted in order. Integration +fixtures (``@pytest.mark.integration``) seed a real ``docker compose`` Postgres +and clean up after themselves; ``forecast_explanation`` is a slice-private table +so its teardown wipes it whole. 
+""" + +from __future__ import annotations + +import datetime +import uuid +from collections.abc import AsyncGenerator +from datetime import date, timedelta +from decimal import Decimal +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import numpy as np +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.core.config import get_settings +from app.core.database import get_db +from app.features.data_platform.models import Calendar, Product, SalesDaily, Store +from app.features.explainability.models import ForecastExplanation +from app.features.registry.models import ModelRun, RunStatus +from app.main import app + +# Test date range — kept narrow so the calendar teardown is precise. +TEST_START = date(2024, 1, 1) +TEST_DAYS = 90 +TEST_END = TEST_START + timedelta(days=TEST_DAYS - 1) + + +# ============================================================================= +# Unit fixtures — numpy series + scripted-mock DB factory +# ============================================================================= + + +@pytest.fixture +def sample_series() -> np.ndarray: + """A 60-observation float series with mild variation.""" + return np.array([float(10 + (i % 7)) for i in range(60)], dtype=np.float64) + + +@pytest.fixture +def flat_series() -> np.ndarray: + """A 30-observation constant series.""" + return np.full(30, 25.0, dtype=np.float64) + + +@pytest.fixture +def short_series() -> np.ndarray: + """A 3-observation series (shorter than every comfortable threshold).""" + return np.array([5.0, 7.0, 6.0], dtype=np.float64) + + +def mock_result(*, scalars: list[Any] | None = None, one: Any | None = None) -> MagicMock: + """Build a mock SQLAlchemy ``Result`` for one ``execute`` call.""" + result = MagicMock() + result.scalars.return_value.all.return_value = scalars or [] + 
+    result.scalar_one_or_none.return_value = one
+    return result
+
+
+def make_mock_db(results: list[MagicMock]) -> AsyncMock:
+    """Build an ``AsyncMock`` session whose ``execute`` returns ``results`` in order.
+
+    Args:
+        results: Mock ``Result`` objects (see ``mock_result``), one per expected
+            ``execute`` call, in call order.
+
+    Returns:
+        A mock session ready to pass to ``ExplainabilityService``.
+    """
+    db = AsyncMock()
+    db.execute = AsyncMock(side_effect=results)
+    db.flush = AsyncMock()
+    db.refresh = AsyncMock()
+    db.add = MagicMock()
+    return db
+
+
+def sales_rows(values: list[float], start: date = TEST_START) -> list[SimpleNamespace]:
+    """Build sales-row stand-ins (``.quantity`` / ``.date``) for the mock DB."""
+    return [
+        SimpleNamespace(quantity=int(v), date=start + timedelta(days=i))
+        for i, v in enumerate(values)
+    ]
+
+
+def forecast_result_db(values: list[float]) -> AsyncMock:
+    """Mock DB for one ``explain_forecast`` call (series + 4 reason-code queries)."""
+    return make_mock_db(
+        [
+            mock_result(scalars=sales_rows(values)),  # _load_series
+            mock_result(scalars=[]),  # inventory
+            mock_result(scalars=[]),  # promotion
+            mock_result(one=None),  # product
+            mock_result(one=None),  # calendar
+        ]
+    )
+
+
+# =============================================================================
+# Integration fixtures — real Postgres
+# =============================================================================
+
+
+@pytest.fixture
+async def db_session() -> AsyncGenerator[AsyncSession, None]:
+    """Yield an async session; remove explainability + test-seeded data on teardown.
+
+    ``forecast_explanation`` is slice-private and is wiped whole. Every shared
+    table is cleaned only for rows this suite seeded (``TEXPL-`` stores/products,
+    ``texpl`` runs, the test calendar window), so running the suite against a
+    database that holds other data does not destroy it.
+    """
+    # Local import: this module only imports ``delete`` from sqlalchemy at the top.
+    from sqlalchemy import select
+
+    settings = get_settings()
+    engine = create_async_engine(settings.database_url, echo=False)
+    session_maker = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
+
+    async with session_maker() as session:
+        try:
+            yield session
+        finally:
+            # A test that failed mid-transaction leaves the session unusable until
+            # it is rolled back; without this the cleanup statements would raise.
+            await session.rollback()
+            await session.execute(delete(ForecastExplanation))
+            # Sales rows must go before the products/stores they reference, but
+            # only the rows that belong to this suite's seeded store/product.
+            test_products = select(Product.id).where(Product.sku.like("TEXPL-%"))
+            test_stores = select(Store.id).where(Store.code.like("TEXPL-%"))
+            await session.execute(
+                delete(SalesDaily)
+                .where(
+                    SalesDaily.product_id.in_(test_products)
+                    | SalesDaily.store_id.in_(test_stores)
+                )
+                # Subquery criteria cannot be evaluated in Python, and the session
+                # is being discarded, so skip in-session synchronisation.
+                .execution_options(synchronize_session=False)
+            )
+            await session.execute(delete(ModelRun).where(ModelRun.run_id.like("texpl%")))
+            await session.execute(delete(Product).where(Product.sku.like("TEXPL-%")))
+            await session.execute(delete(Store).where(Store.code.like("TEXPL-%")))
+            await session.execute(
+                delete(Calendar).where((Calendar.date >= TEST_START) & (Calendar.date <= TEST_END))
+            )
+            await session.commit()
+
+    await engine.dispose()
+
+
+@pytest.fixture
+async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
+    """Test client with the database dependency overridden."""
+
+    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
+        yield db_session
+
+    app.dependency_overrides[get_db] = override_get_db
+    try:
+        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
+            yield ac
+    finally:
+        # Always restore the real dependency, even if closing the client raises.
+        app.dependency_overrides.pop(get_db, None)
+
+
+@pytest.fixture
+async def seeded_series(db_session: AsyncSession) -> dict[str, int]:
+    """Seed a store, product, calendar, and a sales series; return ids.
+
+    The series is a clean weekly pattern so the seasonal-naive h=1 forecast is
+    deterministic.
+ """ + suffix = uuid.uuid4().hex[:8] + store = Store(code=f"TEXPL-{suffix}", name="Explain Store", region="R", store_type="x") + product = Product( + sku=f"TEXPL-{suffix}", + name="Explain Product", + category="C", + base_price=Decimal("9.99"), + launch_date=TEST_START, + ) + db_session.add_all([store, product]) + await db_session.commit() + await db_session.refresh(store) + await db_session.refresh(product) + + weekly = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0] + for i in range(TEST_DAYS): + d = TEST_START + timedelta(days=i) + await db_session.merge( + Calendar( + date=d, + day_of_week=d.weekday(), + month=d.month, + quarter=(d.month - 1) // 3 + 1, + year=d.year, + is_holiday=False, + ) + ) + await db_session.commit() + + for i in range(TEST_DAYS): + qty = weekly[i % 7] + db_session.add( + SalesDaily( + date=TEST_START + timedelta(days=i), + store_id=store.id, + product_id=product.id, + quantity=int(qty), + unit_price=Decimal("9.99"), + total_amount=Decimal("9.99") * int(qty), + ) + ) + await db_session.commit() + + return {"store_id": store.id, "product_id": product.id} + + +@pytest.fixture +async def seeded_run(db_session: AsyncSession, seeded_series: dict[str, int]) -> str: + """Seed a successful baseline ModelRun over the seeded series; return run_id.""" + run_id = f"texpl{uuid.uuid4().hex[:11]}" + run = ModelRun( + run_id=run_id, + status=RunStatus.SUCCESS.value, + model_type="naive", + model_config={"model_type": "naive", "schema_version": "1.0"}, + config_hash="deadbeefdeadbeef", + data_window_start=TEST_START, + data_window_end=TEST_END, + store_id=seeded_series["store_id"], + product_id=seeded_series["product_id"], + ) + db_session.add(run) + await db_session.commit() + return run_id + + +@pytest.fixture +def explanation_row_kwargs() -> dict[str, Any]: + """Valid keyword args for constructing a ForecastExplanation ORM row.""" + return { + "explanation_id": uuid.uuid4().hex, + "store_id": 1, + "product_id": 2, + "model_type": "naive", + "method": 
"rule_based", + "as_of_date": datetime.date(2024, 3, 1), + "forecast_value": 42.0, + "confidence": "medium", + "drivers": [{"name": "last_observation", "contribution": 42.0}], + "reason_codes": [], + "caveats": ["correlation not causation"], + "agent_summary": "A test explanation.", + } diff --git a/app/features/explainability/tests/test_explainers.py b/app/features/explainability/tests/test_explainers.py new file mode 100644 index 00000000..05045145 --- /dev/null +++ b/app/features/explainability/tests/test_explainers.py @@ -0,0 +1,161 @@ +"""Unit tests for the rule-based explainers. + +The load-bearing assertion: each explainer's h=1 forecast value EQUALS the real +forecaster's ``.fit(y).predict(1)[0]`` on the same series. A rule-based +explainer is exact — if it diverges from the forecaster, it is wrong. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from app.features.explainability.explainers import ( + MovingAverageExplainer, + NaiveExplainer, + SeasonalNaiveExplainer, + explainer_factory, +) +from app.features.explainability.schemas import ConfidenceLevel +from app.features.forecasting.models import ( + MovingAverageForecaster, + NaiveForecaster, + SeasonalNaiveForecaster, +) + + +class TestNaiveExplainer: + """Tests for NaiveExplainer.""" + + def test_forecast_matches_real_forecaster(self, sample_series: np.ndarray) -> None: + """h=1 value equals NaiveForecaster's prediction on the same series.""" + forecast, _ = NaiveExplainer().explain(sample_series) + expected = float(NaiveForecaster().fit(sample_series).predict(1)[0]) + assert forecast == pytest.approx(expected) + + def test_main_driver_contribution_sums_to_forecast(self, sample_series: np.ndarray) -> None: + """Driver contributions sum to the forecast value.""" + forecast, drivers = NaiveExplainer().explain(sample_series) + assert sum(d.contribution for d in drivers) == pytest.approx(forecast) + + def test_recent_trend_driver_present_for_long_series(self, 
sample_series: np.ndarray) -> None: + """A long series gets an informational recent_trend driver.""" + _, drivers = NaiveExplainer().explain(sample_series) + names = {d.name for d in drivers} + assert "last_observation" in names + assert "recent_trend" in names + trend = next(d for d in drivers if d.name == "recent_trend") + assert trend.contribution == 0.0 + + def test_no_trend_driver_for_short_series(self, short_series: np.ndarray) -> None: + """A short series gets only the last_observation driver.""" + _, drivers = NaiveExplainer().explain(short_series) + assert [d.name for d in drivers] == ["last_observation"] + + def test_empty_series_raises(self) -> None: + """An empty series raises ValueError (mirrors NaiveForecaster.fit).""" + with pytest.raises(ValueError, match="empty"): + NaiveExplainer().explain(np.array([], dtype=np.float64)) + + def test_confidence_downgrades_on_short_series( + self, sample_series: np.ndarray, short_series: np.ndarray + ) -> None: + """Confidence is LOW for a short series, MEDIUM otherwise.""" + assert NaiveExplainer().confidence(short_series) == ConfidenceLevel.LOW + assert NaiveExplainer().confidence(sample_series) == ConfidenceLevel.MEDIUM + + +class TestSeasonalNaiveExplainer: + """Tests for SeasonalNaiveExplainer.""" + + def test_forecast_matches_real_forecaster(self, sample_series: np.ndarray) -> None: + """h=1 value equals SeasonalNaiveForecaster's prediction.""" + forecast, _ = SeasonalNaiveExplainer(season_length=7).explain(sample_series) + expected = float(SeasonalNaiveForecaster(season_length=7).fit(sample_series).predict(1)[0]) + assert forecast == pytest.approx(expected) + + def test_main_driver_contribution_sums_to_forecast(self, sample_series: np.ndarray) -> None: + """Driver contributions sum to the forecast value.""" + forecast, drivers = SeasonalNaiveExplainer(season_length=7).explain(sample_series) + assert sum(d.contribution for d in drivers) == pytest.approx(forecast) + assert drivers[0].direction == "positive" + 
+ def test_too_short_series_raises(self, short_series: np.ndarray) -> None: + """A series shorter than season_length raises ValueError.""" + with pytest.raises(ValueError, match="at least"): + SeasonalNaiveExplainer(season_length=7).explain(short_series) + + def test_invalid_season_length_raises(self) -> None: + """season_length < 1 raises ValueError.""" + with pytest.raises(ValueError, match="season_length"): + SeasonalNaiveExplainer(season_length=0) + + def test_confidence_downgrades_under_two_cycles(self) -> None: + """Confidence is LOW under two seasonal cycles, MEDIUM otherwise.""" + short = np.arange(10.0, dtype=np.float64) # 10 < 2*7 + long = np.arange(40.0, dtype=np.float64) # 40 >= 2*7 + assert SeasonalNaiveExplainer(7).confidence(short) == ConfidenceLevel.LOW + assert SeasonalNaiveExplainer(7).confidence(long) == ConfidenceLevel.MEDIUM + + +class TestMovingAverageExplainer: + """Tests for MovingAverageExplainer.""" + + def test_forecast_matches_real_forecaster(self, sample_series: np.ndarray) -> None: + """h=1 value equals MovingAverageForecaster's prediction.""" + forecast, _ = MovingAverageExplainer(window_size=7).explain(sample_series) + expected = float(MovingAverageForecaster(window_size=7).fit(sample_series).predict(1)[0]) + assert forecast == pytest.approx(expected) + + def test_main_driver_contribution_sums_to_forecast(self, sample_series: np.ndarray) -> None: + """Driver contributions sum to the forecast (dispersion contributes 0).""" + forecast, drivers = MovingAverageExplainer(window_size=7).explain(sample_series) + assert sum(d.contribution for d in drivers) == pytest.approx(forecast) + dispersion = next(d for d in drivers if d.name == "window_dispersion") + assert dispersion.contribution == 0.0 + assert dispersion.direction == "neutral" + + def test_too_short_series_raises(self, short_series: np.ndarray) -> None: + """A series shorter than window_size raises ValueError.""" + with pytest.raises(ValueError, match="at least"): + 
MovingAverageExplainer(window_size=7).explain(short_series) + + def test_confidence_high_for_stable_window(self, flat_series: np.ndarray) -> None: + """A flat (zero-dispersion) window yields HIGH confidence.""" + assert MovingAverageExplainer(7).confidence(flat_series) == ConfidenceLevel.HIGH + + def test_confidence_medium_for_noisy_window(self) -> None: + """A high-variance window yields MEDIUM confidence.""" + noisy = np.array([1.0, 100.0, 2.0, 99.0, 3.0, 98.0, 4.0], dtype=np.float64) + assert MovingAverageExplainer(7).confidence(noisy) == ConfidenceLevel.MEDIUM + + +class TestExplainerFactory: + """Tests for explainer_factory.""" + + def test_builds_each_baseline(self) -> None: + """The factory builds the matching explainer per baseline model type.""" + assert isinstance(explainer_factory("naive"), NaiveExplainer) + assert isinstance( + explainer_factory("seasonal_naive", season_length=14), SeasonalNaiveExplainer + ) + assert isinstance( + explainer_factory("moving_average", window_size=21), MovingAverageExplainer + ) + + def test_seasonal_defaults_to_seven(self) -> None: + """A None season_length defaults to 7.""" + explainer = explainer_factory("seasonal_naive") + assert isinstance(explainer, SeasonalNaiveExplainer) + assert explainer.season_length == 7 + + @pytest.mark.parametrize("model_type", ["lightgbm", "regression"]) + def test_rejects_non_baseline_models(self, model_type: str) -> None: + """lightgbm/regression are rejected (MVP scope guard).""" + with pytest.raises(ValueError, match="baseline models only"): + explainer_factory(model_type) + + def test_rejects_unknown_model(self) -> None: + """An unknown model type raises ValueError.""" + with pytest.raises(ValueError, match="Unknown model type"): + explainer_factory("transformer") diff --git a/app/features/explainability/tests/test_models_integration.py b/app/features/explainability/tests/test_models_integration.py new file mode 100644 index 00000000..2fc210ff --- /dev/null +++ 
b/app/features/explainability/tests/test_models_integration.py @@ -0,0 +1,62 @@ +"""Integration tests for the ForecastExplanation ORM model. + +Run against the real docker-compose Postgres (``docker compose up -d``). +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.features.explainability.models import ForecastExplanation + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestForecastExplanationModel: + """CRUD and constraint tests for the forecast_explanation table.""" + + async def test_insert_and_read_back( + self, db_session: AsyncSession, explanation_row_kwargs: dict[str, Any] + ) -> None: + """A forecast_explanation row persists and reads back intact.""" + row = ForecastExplanation(**explanation_row_kwargs) + db_session.add(row) + await db_session.commit() + + fetched = ( + await db_session.execute( + select(ForecastExplanation).where( + ForecastExplanation.explanation_id == explanation_row_kwargs["explanation_id"] + ) + ) + ).scalar_one() + assert fetched.forecast_value == 42.0 + assert fetched.model_type == "naive" + assert fetched.confidence == "medium" + assert fetched.drivers[0]["name"] == "last_observation" + assert fetched.created_at is not None + + async def test_confidence_check_constraint_rejects_bad_value( + self, db_session: AsyncSession, explanation_row_kwargs: dict[str, Any] + ) -> None: + """An out-of-allow-list confidence value is rejected by the CHECK.""" + bad = ForecastExplanation(**{**explanation_row_kwargs, "confidence": "bogus"}) + db_session.add(bad) + with pytest.raises(IntegrityError): + await db_session.flush() + await db_session.rollback() + + async def test_method_check_constraint_rejects_bad_value( + self, db_session: AsyncSession, explanation_row_kwargs: dict[str, Any] + ) -> None: + """An out-of-allow-list method value is rejected by the CHECK.""" + 
bad = ForecastExplanation(**{**explanation_row_kwargs, "method": "telepathy"}) + db_session.add(bad) + with pytest.raises(IntegrityError): + await db_session.flush() + await db_session.rollback() diff --git a/app/features/explainability/tests/test_reason_codes.py b/app/features/explainability/tests/test_reason_codes.py new file mode 100644 index 00000000..066ebba8 --- /dev/null +++ b/app/features/explainability/tests/test_reason_codes.py @@ -0,0 +1,128 @@ +"""Unit tests for the advisory reason-code engine.""" + +from __future__ import annotations + +from datetime import date + +from app.features.explainability.reason_codes import ( + CORRELATION_CAVEAT, + build_caveats, + history_reason, + holiday_reason, + lifecycle_reason, + promotion_reason, + stockout_reason, +) +from app.features.explainability.schemas import ReasonCode + +AS_OF = date(2024, 3, 1) + + +class TestStockoutReason: + """Tests for stockout_reason.""" + + def test_fires_on_stockout_days(self) -> None: + """A stockout day produces a warn-level reason code.""" + code = stockout_reason([False, True, False, True]) + assert code is not None + assert code.code == "stockout_constrained" + assert code.severity == "warn" + assert "2 stockout" in code.detail + + def test_none_when_no_stockout(self) -> None: + """No stockout days yields None.""" + assert stockout_reason([False, False, False]) is None + + def test_none_for_empty_window(self) -> None: + """An empty window yields None.""" + assert stockout_reason([]) is None + + +class TestPromotionReason: + """Tests for promotion_reason.""" + + def test_fires_on_overlap(self) -> None: + """An overlapping promotion produces an info reason code.""" + code = promotion_reason([(date(2024, 2, 20), date(2024, 2, 25))], AS_OF) + assert code is not None + assert code.code == "promotion_overlap" + assert code.severity == "info" + + def test_detects_promotion_active_at_cutoff(self) -> None: + """A promotion active on as_of_date is called out in the detail.""" + code = 
promotion_reason([(date(2024, 2, 25), date(2024, 3, 10))], AS_OF) + assert code is not None + assert "still active" in code.detail + + def test_none_when_no_promotions(self) -> None: + """No promotions yields None.""" + assert promotion_reason([], AS_OF) is None + + +class TestLifecycleReason: + """Tests for lifecycle_reason.""" + + def test_fires_for_recent_launch(self) -> None: + """A launch under 30 days ago produces an info reason code.""" + code = lifecycle_reason(date(2024, 2, 15), AS_OF) + assert code is not None + assert code.code == "lifecycle_decay" + + def test_none_for_old_launch(self) -> None: + """A launch over 30 days ago yields None.""" + assert lifecycle_reason(date(2023, 1, 1), AS_OF) is None + + def test_none_when_launch_unknown(self) -> None: + """An unknown launch date yields None.""" + assert lifecycle_reason(None, AS_OF) is None + + +class TestHolidayReason: + """Tests for holiday_reason.""" + + def test_fires_on_holiday(self) -> None: + """A holiday forecast date produces an info reason code.""" + code = holiday_reason(True, "New Year", date(2024, 3, 2)) + assert code is not None + assert code.code == "holiday_effect" + assert "New Year" in code.detail + + def test_none_on_normal_day(self) -> None: + """A non-holiday forecast date yields None.""" + assert holiday_reason(False, None, date(2024, 3, 2)) is None + + +class TestHistoryReason: + """Tests for history_reason.""" + + def test_fires_for_short_series(self) -> None: + """Fewer observations than required produces a warn reason code.""" + code = history_reason(5, 14) + assert code is not None + assert code.code == "insufficient_history" + assert code.severity == "warn" + + def test_none_for_sufficient_history(self) -> None: + """Enough observations yields None.""" + assert history_reason(30, 14) is None + + +class TestBuildCaveats: + """Tests for build_caveats.""" + + def test_always_includes_correlation_caveat(self) -> None: + """Every caveat list starts with the correlation-vs-causation 
disclaimer.""" + caveats = build_caveats("naive", []) + assert caveats[0] == CORRELATION_CAVEAT + + def test_includes_model_specific_caveat(self) -> None: + """A model-specific caveat is appended for each baseline.""" + assert any("seasonality" in c for c in build_caveats("naive", [])) + assert any("prior cycle" in c for c in build_caveats("seasonal_naive", [])) + assert any("smooths" in c for c in build_caveats("moving_average", [])) + + def test_adds_stockout_caveat(self) -> None: + """A stockout reason code adds an understated-demand caveat.""" + stockout = ReasonCode(code="stockout_constrained", severity="warn", detail="x") + caveats = build_caveats("naive", [stockout]) + assert any("understate" in c for c in caveats) diff --git a/app/features/explainability/tests/test_routes.py b/app/features/explainability/tests/test_routes.py new file mode 100644 index 00000000..cf73ad84 --- /dev/null +++ b/app/features/explainability/tests/test_routes.py @@ -0,0 +1,153 @@ +"""Unit route tests for the explainability endpoints. + +Each test overrides ``get_db`` with a scripted-mock session, so the routes are +exercised over the HTTP boundary without a real database. Error paths assert the +RFC 7807 problem-detail shape. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from datetime import date +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock + +import pytest +from httpx import ASGITransport, AsyncClient + +from app.core.database import get_db +from app.features.explainability.tests.conftest import ( + forecast_result_db, + make_mock_db, + mock_result, +) +from app.main import app + + +@asynccontextmanager +async def _client(db: AsyncMock) -> AsyncGenerator[AsyncClient, None]: + """Yield a test client whose get_db dependency yields ``db``.""" + + async def override_get_db() -> AsyncGenerator[AsyncMock, None]: + yield db + + app.dependency_overrides[get_db] = override_get_db + try: + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + yield ac + finally: + app.dependency_overrides.pop(get_db, None) + + +def _assert_problem_detail(body: dict[str, Any], expected_status: int) -> None: + """Assert an RFC 7807 problem-detail body shape.""" + for key in ("type", "title", "status", "detail"): + assert key in body, f"missing RFC 7807 field: {key}" + assert body["status"] == expected_status + + +@pytest.mark.asyncio +async def test_explain_forecast_returns_200() -> None: + """POST /explain/forecast returns 200 with a well-formed explanation.""" + db = forecast_result_db([10.0, 12.0, 11.0, 9.0, 14.0]) + async with _client(db) as ac: + response = await ac.post( + "/explain/forecast", + json={ + "store_id": 1, + "product_id": 2, + "model_type": "naive", + "as_of_date": "2024-03-01", + }, + ) + assert response.status_code == 200 + body = response.json() + assert body["forecast_value"] == 14.0 + assert body["method"] == "rule_based" + assert body["drivers"][0]["name"] == "last_observation" + + +@pytest.mark.asyncio +async def test_explain_forecast_rejects_iso_string_path() -> None: + """An ISO-string as_of_date is accepted 
(strict-mode JSON path).""" + db = forecast_result_db([10.0, 12.0, 11.0]) + async with _client(db) as ac: + response = await ac.post( + "/explain/forecast", + json={ + "store_id": 1, + "product_id": 2, + "model_type": "naive", + "as_of_date": "2024-03-01", + }, + ) + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_explain_forecast_empty_series_returns_400() -> None: + """An empty series yields an RFC 7807 400.""" + db = forecast_result_db([]) + async with _client(db) as ac: + response = await ac.post( + "/explain/forecast", + json={ + "store_id": 1, + "product_id": 2, + "model_type": "naive", + "as_of_date": "2024-03-01", + }, + ) + assert response.status_code == 400 + _assert_problem_detail(response.json(), 400) + + +@pytest.mark.asyncio +async def test_explain_run_missing_returns_404() -> None: + """GET /explain/runs/{missing} yields an RFC 7807 404.""" + db = make_mock_db([mock_result(one=None)]) + async with _client(db) as ac: + response = await ac.get("/explain/runs/does-not-exist") + assert response.status_code == 404 + _assert_problem_detail(response.json(), 404) + + +@pytest.mark.asyncio +async def test_explain_run_lightgbm_returns_400() -> None: + """GET /explain/runs/{lightgbm-run} yields an RFC 7807 400.""" + run = SimpleNamespace( + run_id="run-lgbm", + model_type="lightgbm", + model_config={"model_type": "lightgbm"}, + store_id=1, + product_id=2, + data_window_end=date(2024, 3, 1), + ) + db = make_mock_db([mock_result(one=run)]) + async with _client(db) as ac: + response = await ac.get("/explain/runs/run-lgbm") + assert response.status_code == 400 + _assert_problem_detail(response.json(), 400) + + +@pytest.mark.asyncio +async def test_explain_job_missing_returns_404() -> None: + """GET /explain/jobs/{missing} yields an RFC 7807 404.""" + db = make_mock_db([mock_result(one=None)]) + async with _client(db) as ac: + response = await ac.get("/explain/jobs/does-not-exist") + assert response.status_code == 404 + 
_assert_problem_detail(response.json(), 404) + + +@pytest.mark.asyncio +async def test_explain_job_non_predict_returns_400() -> None: + """GET /explain/jobs/{train-job} yields an RFC 7807 400.""" + job = SimpleNamespace(job_id="job-train", job_type="train", status="completed", result={}) + db = make_mock_db([mock_result(one=job)]) + async with _client(db) as ac: + response = await ac.get("/explain/jobs/job-train") + assert response.status_code == 400 + _assert_problem_detail(response.json(), 400) diff --git a/app/features/explainability/tests/test_routes_integration.py b/app/features/explainability/tests/test_routes_integration.py new file mode 100644 index 00000000..6d79bc59 --- /dev/null +++ b/app/features/explainability/tests/test_routes_integration.py @@ -0,0 +1,85 @@ +"""End-to-end integration tests for the explainability endpoints. + +Run against the real docker-compose Postgres (``docker compose up -d``). The +``client`` fixture shares the test session, so a persisted explanation is +readable back through the same session after the request. 
+""" + +from __future__ import annotations + +import pytest +from httpx import AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.features.explainability.models import ForecastExplanation +from app.features.explainability.tests.conftest import TEST_END + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestExplainEndpointsIntegration: + """End-to-end tests over a real database.""" + + async def test_explain_run_returns_explanation( + self, client: AsyncClient, seeded_run: str + ) -> None: + """GET /explain/runs/{run_id} explains a real baseline run.""" + response = await client.get(f"/explain/runs/{seeded_run}") + assert response.status_code == 200 + body = response.json() + assert body["model_type"] == "naive" + assert body["method"] == "rule_based" + assert body["drivers"] + assert body["confidence"] in ("high", "medium", "low") + assert body["caveats"] + assert body["agent_summary"] + # The naive forecast is the last observed value — a positive quantity. 
+ assert body["forecast_value"] > 0 + + async def test_explain_run_persists_row( + self, client: AsyncClient, db_session: AsyncSession, seeded_run: str + ) -> None: + """The explanation is persisted as a forecast_explanation row.""" + await client.get(f"/explain/runs/{seeded_run}") + row = ( + await db_session.execute( + select(ForecastExplanation).where(ForecastExplanation.run_id == seeded_run) + ) + ).scalar_one() + assert row.model_type == "naive" + assert row.run_id == seeded_run + + async def test_explain_forecast_end_to_end( + self, client: AsyncClient, seeded_series: dict[str, int] + ) -> None: + """POST /explain/forecast explains an ad-hoc forecast over a real series.""" + response = await client.post( + "/explain/forecast", + json={ + "store_id": seeded_series["store_id"], + "product_id": seeded_series["product_id"], + "model_type": "seasonal_naive", + "as_of_date": TEST_END.isoformat(), + "season_length": 7, + }, + ) + assert response.status_code == 200 + body = response.json() + assert body["model_type"] == "seasonal_naive" + assert body["drivers"][0]["name"] == "season_match" + assert body["forecast_value"] > 0 + + async def test_explain_run_missing_returns_404(self, client: AsyncClient) -> None: + """GET /explain/runs/{missing} returns an RFC 7807 404.""" + response = await client.get("/explain/runs/no-such-run-id") + assert response.status_code == 404 + body = response.json() + assert body["status"] == 404 + assert "title" in body + + async def test_explain_job_missing_returns_404(self, client: AsyncClient) -> None: + """GET /explain/jobs/{missing} returns an RFC 7807 404.""" + response = await client.get("/explain/jobs/no-such-job-id") + assert response.status_code == 404 + assert response.json()["status"] == 404 diff --git a/app/features/explainability/tests/test_schemas.py b/app/features/explainability/tests/test_schemas.py new file mode 100644 index 00000000..156cdaf0 --- /dev/null +++ b/app/features/explainability/tests/test_schemas.py @@ -0,0 
+1,134 @@ +"""Unit tests for the explainability Pydantic schemas. + +The JSON-path test (``test_request_accepts_iso_string_date``) is required by +``docs/_base/SECURITY.md`` — it exercises the ``validate_python`` path FastAPI +uses, catching the strict-mode date regression at unit-test time. +""" + +from __future__ import annotations + +from datetime import date + +import pytest +from pydantic import ValidationError + +from app.features.explainability.schemas import ( + ConfidenceLevel, + DriverContribution, + ExplainForecastRequest, + ForecastExplanation, + ReasonCode, +) + + +class TestExplainForecastRequest: + """Tests for the strict request body.""" + + def test_request_accepts_iso_string_date(self) -> None: + """as_of_date accepts an ISO-string (the FastAPI JSON path).""" + request = ExplainForecastRequest.model_validate( + { + "store_id": 1, + "product_id": 2, + "model_type": "naive", + "as_of_date": "2024-03-01", + } + ) + assert request.as_of_date == date(2024, 3, 1) + + def test_request_accepts_native_date(self) -> None: + """as_of_date also accepts a native date object.""" + request = ExplainForecastRequest( + store_id=1, product_id=2, model_type="naive", as_of_date=date(2024, 3, 1) + ) + assert request.as_of_date == date(2024, 3, 1) + + def test_invalid_model_type_rejected(self) -> None: + """A non-baseline model_type fails validation.""" + with pytest.raises(ValidationError): + ExplainForecastRequest.model_validate( + { + "store_id": 1, + "product_id": 2, + "model_type": "lightgbm", + "as_of_date": "2024-03-01", + } + ) + + def test_non_positive_store_id_rejected(self) -> None: + """store_id must be >= 1.""" + with pytest.raises(ValidationError): + ExplainForecastRequest.model_validate( + { + "store_id": 0, + "product_id": 2, + "model_type": "naive", + "as_of_date": "2024-03-01", + } + ) + + def test_optional_params_default_to_none(self) -> None: + """season_length and window_size default to None.""" + request = ExplainForecastRequest( + store_id=1, 
product_id=2, model_type="naive", as_of_date=date(2024, 3, 1) + ) + assert request.season_length is None + assert request.window_size is None + + +class TestForecastExplanation: + """Tests for the response schema.""" + + def test_round_trips_through_model_dump(self) -> None: + """A ForecastExplanation survives model_dump -> model_validate.""" + explanation = ForecastExplanation( + store_id=1, + product_id=2, + model_type="naive", + forecast_value=42.0, + drivers=[ + DriverContribution( + name="last_observation", + feature_value=42.0, + contribution=42.0, + direction="positive", + description="x", + ) + ], + reason_codes=[ReasonCode(code="holiday_effect", severity="info", detail="x")], + confidence=ConfidenceLevel.MEDIUM, + caveats=["correlation not causation"], + agent_summary="A summary.", + as_of_date=date(2024, 3, 1), + ) + restored = ForecastExplanation.model_validate(explanation.model_dump()) + assert restored.forecast_value == 42.0 + assert restored.method == "rule_based" + assert restored.confidence == ConfidenceLevel.MEDIUM + assert restored.drivers[0].name == "last_observation" + + def test_method_defaults_to_rule_based(self) -> None: + """method defaults to rule_based.""" + explanation = ForecastExplanation( + store_id=1, + product_id=2, + model_type="naive", + forecast_value=1.0, + drivers=[], + reason_codes=[], + confidence=ConfidenceLevel.LOW, + caveats=[], + agent_summary="x", + as_of_date=date(2024, 3, 1), + ) + assert explanation.method == "rule_based" + + +class TestConfidenceLevel: + """Tests for the ConfidenceLevel enum.""" + + def test_values(self) -> None: + """The enum carries the three expected string values.""" + assert ConfidenceLevel.HIGH.value == "high" + assert ConfidenceLevel.MEDIUM.value == "medium" + assert ConfidenceLevel.LOW.value == "low" diff --git a/app/features/explainability/tests/test_service.py b/app/features/explainability/tests/test_service.py new file mode 100644 index 00000000..f7be684d --- /dev/null +++ 
b/app/features/explainability/tests/test_service.py @@ -0,0 +1,179 @@ +"""Unit tests for ExplainabilityService with a scripted-mock AsyncSession. + +The mock session returns pre-built ``Result`` objects in ``execute`` call order +(see ``conftest.make_mock_db``) so the service logic is exercised without a DB. +""" + +from __future__ import annotations + +from datetime import date +from types import SimpleNamespace +from typing import Literal + +import pytest + +from app.core.exceptions import BadRequestError +from app.features.explainability.schemas import ( + ConfidenceLevel, + ExplainForecastRequest, + ForecastExplanation, +) +from app.features.explainability.service import ExplainabilityService +from app.features.explainability.tests.conftest import ( + forecast_result_db, + make_mock_db, + mock_result, + sales_rows, +) + + +def _request( + model_type: Literal["naive", "seasonal_naive", "moving_average"] = "naive", +) -> ExplainForecastRequest: + """Build an ExplainForecastRequest for the given model type.""" + return ExplainForecastRequest( + store_id=1, product_id=2, model_type=model_type, as_of_date=date(2024, 3, 1) + ) + + +class TestExplainForecast: + """Tests for ExplainabilityService.explain_forecast.""" + + async def test_returns_well_formed_explanation(self) -> None: + """A naive forecast explanation reproduces the last observed value.""" + db = forecast_result_db([10.0, 12.0, 11.0, 9.0, 14.0]) + explanation = await ExplainabilityService().explain_forecast(db, _request()) + + assert isinstance(explanation, ForecastExplanation) + assert explanation.forecast_value == 14.0 # last observation + assert explanation.method == "rule_based" + assert explanation.drivers[0].name == "last_observation" + assert explanation.agent_summary + # The correlation-vs-causation caveat is always present. 
+ assert any("causality" in c for c in explanation.caveats) + + async def test_persists_the_explanation(self) -> None: + """The service adds, flushes, and refreshes a forecast_explanation row.""" + db = forecast_result_db([10.0, 12.0, 11.0]) + await ExplainabilityService().explain_forecast(db, _request()) + + db.add.assert_called_once() + db.flush.assert_awaited_once() + db.refresh.assert_awaited_once() + + async def test_short_series_flags_insufficient_history(self) -> None: + """A short series yields LOW confidence and an insufficient_history code.""" + db = forecast_result_db([10.0, 12.0, 11.0]) + explanation = await ExplainabilityService().explain_forecast(db, _request()) + + assert explanation.confidence == ConfidenceLevel.LOW + codes = {rc.code for rc in explanation.reason_codes} + assert "insufficient_history" in codes + + async def test_empty_series_raises_value_error(self) -> None: + """An empty series raises ValueError (route maps it to 400).""" + db = forecast_result_db([]) + with pytest.raises(ValueError, match="empty"): + await ExplainabilityService().explain_forecast(db, _request()) + + +class TestExplainRun: + """Tests for ExplainabilityService.explain_run.""" + + async def test_missing_run_returns_none(self) -> None: + """A missing run_id returns None (route maps it to 404).""" + db = make_mock_db([mock_result(one=None)]) + result = await ExplainabilityService().explain_run(db, "does-not-exist") + assert result is None + + async def test_explains_a_baseline_run(self) -> None: + """A baseline run resolves its config and produces an explanation.""" + run = SimpleNamespace( + run_id="run-abc", + model_type="naive", + model_config={"model_type": "naive"}, + store_id=1, + product_id=2, + data_window_end=date(2024, 3, 1), + ) + db = make_mock_db( + [ + mock_result(one=run), + mock_result(scalars=sales_rows([10.0, 20.0, 15.0])), + mock_result(scalars=[]), + mock_result(scalars=[]), + mock_result(one=None), + mock_result(one=None), + ] + ) + explanation = 
await ExplainabilityService().explain_run(db, "run-abc") + assert explanation is not None + assert explanation.forecast_value == 15.0 + + async def test_lightgbm_run_raises_value_error(self) -> None: + """A lightgbm run raises ValueError before any series load.""" + run = SimpleNamespace( + run_id="run-lgbm", + model_type="lightgbm", + model_config={"model_type": "lightgbm"}, + store_id=1, + product_id=2, + data_window_end=date(2024, 3, 1), + ) + db = make_mock_db([mock_result(one=run)]) + with pytest.raises(ValueError, match="baseline models only"): + await ExplainabilityService().explain_run(db, "run-lgbm") + + +class TestExplainJob: + """Tests for ExplainabilityService.explain_job.""" + + async def test_missing_job_returns_none(self) -> None: + """A missing job_id returns None (route maps it to 404).""" + db = make_mock_db([mock_result(one=None)]) + result = await ExplainabilityService().explain_job(db, "does-not-exist") + assert result is None + + async def test_non_completed_job_raises_bad_request(self) -> None: + """A pending predict job raises BadRequestError.""" + job = SimpleNamespace(job_id="job-1", job_type="predict", status="pending", result=None) + db = make_mock_db([mock_result(one=job)]) + with pytest.raises(BadRequestError, match="completed predict job"): + await ExplainabilityService().explain_job(db, "job-1") + + async def test_non_predict_job_raises_bad_request(self) -> None: + """A completed train job raises BadRequestError.""" + job = SimpleNamespace(job_id="job-2", job_type="train", status="completed", result={}) + db = make_mock_db([mock_result(one=job)]) + with pytest.raises(BadRequestError, match="completed predict job"): + await ExplainabilityService().explain_job(db, "job-2") + + async def test_explains_a_completed_predict_job(self) -> None: + """A completed predict job produces an explanation at the right cutoff.""" + job = SimpleNamespace( + job_id="job-3", + job_type="predict", + status="completed", + result={ + "store_id": 1, + 
"product_id": 2, + "model_type": "naive", + "horizon": 7, + "forecasts": [{"date": "2024-03-02", "forecast": 25.0}], + }, + ) + db = make_mock_db( + [ + mock_result(one=job), + mock_result(scalars=sales_rows([10.0, 20.0, 25.0])), + mock_result(scalars=[]), + mock_result(scalars=[]), + mock_result(one=None), + mock_result(one=None), + ] + ) + explanation = await ExplainabilityService().explain_job(db, "job-3") + assert explanation is not None + # as_of_date = day before the first forecast date. + assert explanation.as_of_date == date(2024, 3, 1) + assert explanation.forecast_value == 25.0 diff --git a/app/main.py b/app/main.py index bcc09d35..3cb36c8e 100644 --- a/app/main.py +++ b/app/main.py @@ -20,6 +20,7 @@ from app.features.config.service import apply_overrides_on_startup from app.features.demo.routes import router as demo_router from app.features.dimensions.routes import router as dimensions_router +from app.features.explainability.routes import router as explainability_router from app.features.featuresets.routes import router as featuresets_router from app.features.forecasting.routes import router as forecasting_router from app.features.ingest.routes import router as ingest_router @@ -138,6 +139,7 @@ def create_app() -> FastAPI: app.include_router(ingest_router) app.include_router(featuresets_router) app.include_router(forecasting_router) + app.include_router(explainability_router) app.include_router(backtesting_router) app.include_router(registry_router) app.include_router(rag_router) diff --git a/docs/_base/API_CONTRACTS.md b/docs/_base/API_CONTRACTS.md index 3719a802..d57debcc 100644 --- a/docs/_base/API_CONTRACTS.md +++ b/docs/_base/API_CONTRACTS.md @@ -22,6 +22,9 @@ All endpoints serve JSON; error responses use `application/problem+json` (RFC 78 | forecasting | POST | `/forecasting/train` | Train a model (naive / seasonal_naive / moving_average / lightgbm / regression). 
`regression` wraps `HistGradientBoostingRegressor` on lag + calendar + exogenous features — the baseline a `model_exogenous` scenario re-forecasts through | | forecasting | POST | `/forecasting/predict` | Generate horizon predictions from a trained model | | backtesting | POST | `/backtesting/run` | Time-series CV (rolling/expanding splits, MAE/sMAPE/WAPE/bias/stability) | +| explainability | POST | `/explain/forecast` | Rule-based explanation of the h=1 forecast a named baseline model (`naive`/`seasonal_naive`/`moving_average`) produces on the series ending at `as_of_date`; returns a `ForecastExplanation` — driver contributions, advisory retail reason codes (correlation, not causation), confidence band, caveats, agent summary. Time-safe (`<= as_of_date`); a non-baseline `model_type` or a too-short series → RFC 7807 400 | +| explainability | GET | `/explain/runs/{run_id}` | Explain a registry `model_run` — config reconstructed from `model_run.model_config`, cutoff `data_window_end`. Missing run → 404; a non-baseline (`lightgbm`/`regression`) run → 400 | +| explainability | GET | `/explain/jobs/{job_id}` | Explain a completed `predict` job — store/product/model read from `job.result`, cutoff = day before the first forecast date. Missing job → 404; a job that is not a completed predict job → 400 | | scenarios | POST | `/scenarios/simulate` | Stateless what-if: load a baseline model, forecast, apply price/promotion/holiday/inventory/lifecycle assumptions, return a `ScenarioComparison`. A `regression` baseline genuinely re-forecasts through a leakage-safe future feature frame (`method="model_exogenous"`); any other baseline applies a deterministic post-forecast multiplier (`method="heuristic"`). 
Bogus `run_id` → RFC 7807 404 | | scenarios | POST | `/scenarios` | Run a simulation and persist it as a named `scenario_plan` (raw assumptions + full comparison snapshot); optional `tags` + `cloned_from` | | scenarios | GET | `/scenarios` | List saved scenario plans, newest first (`limit`/`offset`, optional repeated `tags` filter — JSONB containment); `200` + empty list on an empty table | diff --git a/frontend/src/components/explainability/explanation-panel.test.tsx b/frontend/src/components/explainability/explanation-panel.test.tsx new file mode 100644 index 00000000..11e4f801 --- /dev/null +++ b/frontend/src/components/explainability/explanation-panel.test.tsx @@ -0,0 +1,66 @@ +import { describe, expect, it } from 'vitest' +import { render, screen } from '@testing-library/react' +import { ApiError } from '@/lib/api' +import { ExplanationPanel } from './explanation-panel' +import type { ForecastExplanation } from '@/types/api' + +const sampleExplanation: ForecastExplanation = { + store_id: 1, + product_id: 2, + model_type: 'naive', + method: 'rule_based', + forecast_value: 42, + drivers: [ + { + name: 'last_observation', + feature_value: 42, + contribution: 42, + direction: 'positive', + description: 'The naive forecast is the last observed value.', + }, + ], + reason_codes: [ + { code: 'stockout_constrained', severity: 'warn', detail: '2 stockout days.' 
}, + ], + confidence: 'medium', + caveats: ['Drivers describe correlation, not causation.'], + agent_summary: 'The naive model forecasts 42 units.', + as_of_date: '2024-03-01', + generated_at: '2024-03-01T00:00:00Z', +} + +describe('ExplanationPanel', () => { + it('renders drivers, reason codes, confidence, and caveats', () => { + render() + + expect(screen.getByText('last observation')).toBeTruthy() + expect(screen.getByText('Positive')).toBeTruthy() + expect(screen.getByText('medium')).toBeTruthy() + expect(screen.getByText(/stockout constrained/)).toBeTruthy() + expect(screen.getByText(/correlation, not causation/)).toBeTruthy() + expect(screen.getByText(sampleExplanation.agent_summary)).toBeTruthy() + }) + + it('renders a loading state', () => { + render() + expect(screen.getByText(/Generating explanation/)).toBeTruthy() + }) + + it('renders a destructive error for an unexpected failure', () => { + render() + expect(screen.getByText('boom')).toBeTruthy() + }) + + it('renders a neutral message for a 400 (non-baseline model)', () => { + const apiError = new ApiError('Explanations are available for baseline models only', 400) + render() + expect(screen.getByText(/baseline models only/)).toBeTruthy() + }) + + it('shows a no-signals message when there are no reason codes', () => { + render( + , + ) + expect(screen.getByText(/No advisory retail signals/)).toBeTruthy() + }) +}) diff --git a/frontend/src/components/explainability/explanation-panel.tsx b/frontend/src/components/explainability/explanation-panel.tsx new file mode 100644 index 00000000..ecd9b0a7 --- /dev/null +++ b/frontend/src/components/explainability/explanation-panel.tsx @@ -0,0 +1,218 @@ +import type { ReactNode } from 'react' +import { AlertTriangle, Info, Lightbulb, Minus, TrendingDown, TrendingUp } from 'lucide-react' +import { ApiError, formatNumber, getErrorMessage } from '@/lib/api' +import { Badge } from '@/components/ui/badge' +import { Card, CardContent, CardDescription, CardHeader, CardTitle 
} from '@/components/ui/card' +import { LoadingState } from '@/components/common/loading-state' +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from '@/components/ui/table' +import type { + ConfidenceLevel, + DriverContribution, + ForecastExplanation, + ReasonCode, +} from '@/types/api' + +interface ExplanationPanelProps { + explanation?: ForecastExplanation + isLoading?: boolean + error?: unknown +} + +const CONFIDENCE_VARIANT: Record = { + high: 'default', + medium: 'secondary', + low: 'outline', +} + +function DirectionLabel({ direction }: { direction: DriverContribution['direction'] }) { + if (direction === 'positive') { + return ( + + + Positive + + ) + } + if (direction === 'negative') { + return ( + + + Negative + + ) + } + return ( + + + Neutral + + ) +} + +function ReasonCodeRow({ reason }: { reason: ReasonCode }) { + const isWarn = reason.severity === 'warn' + return ( +
  • + {isWarn ? ( + + ) : ( + + )} + + {reason.code.replace(/_/g, ' ')} + {' — '} + {reason.detail} + +
  • + ) +} + +/** Card shell shared by every panel state so the layout never jumps. */ +function PanelShell({ children }: { children: ReactNode }) { + return ( + + + + + Forecast Explanation + + + Rule-based driver attribution for the h=1 baseline forecast. + + + {children} + + ) +} + +export function ExplanationPanel({ explanation, isLoading, error }: ExplanationPanelProps) { + if (isLoading) { + return ( + + + + ) + } + + if (error) { + // A 400 here means the run/job is not a baseline model — an expected, + // non-error outcome, so it is shown in a neutral (not destructive) tone. + const isExpected = error instanceof ApiError && error.status === 400 + return ( + +
    + {isExpected ? ( + + ) : ( + + )} + {getErrorMessage(error)} +
    +
    + ) + } + + if (!explanation) { + return ( + +

    No explanation available.

    +
    + ) + } + + return ( + +
    +
    +
    +

    h=1 forecast

    +

    {formatNumber(explanation.forecast_value, 1)}

    +
    +
    +

    Confidence

    + + {explanation.confidence} + +
    +
    +

    Model

    +

    {explanation.model_type}

    +
    +
    + +

    {explanation.agent_summary}

    + +
    +

    Drivers

    + + + + Driver + Value + Contribution + Direction + + + + {explanation.drivers.map((driver) => ( + + + {driver.name.replace(/_/g, ' ')} + + {driver.description} + + + + {formatNumber(driver.feature_value, 1)} + + + {formatNumber(driver.contribution, 1)} + + + + + + ))} + +
    +
    + +
    +

    Retail signals

    + {explanation.reason_codes.length > 0 ? ( +
      + {explanation.reason_codes.map((reason) => ( + + ))} +
    + ) : ( +

    + No advisory retail signals were detected. +

    + )} +
    + +
    + {explanation.caveats.map((caveat) => ( +

    + {caveat} +

    + ))} +
    +
    +
    + ) +} diff --git a/frontend/src/hooks/use-explanations.ts b/frontend/src/hooks/use-explanations.ts new file mode 100644 index 00000000..d6ac799f --- /dev/null +++ b/frontend/src/hooks/use-explanations.ts @@ -0,0 +1,48 @@ +import { useMutation, useQuery } from '@tanstack/react-query' +import { api } from '@/lib/api' +import type { ForecastExplanation } from '@/types/api' + +/** + * Explain a registry model run. Disabled until `runId` is set. `retry: false` + * because a 404 (no run) or 400 (non-baseline run) is a final answer, not a + * transient failure. + */ +export function useRunExplanation(runId: string, enabled = true) { + return useQuery({ + queryKey: ['explanations', 'run', runId], + queryFn: () => api(`/explain/runs/${runId}`), + enabled: enabled && !!runId, + retry: false, + }) +} + +/** + * Explain a completed predict job. Disabled until `jobId` is set; `retry: false` + * for the same reason as {@link useRunExplanation}. + */ +export function useJobExplanation(jobId: string, enabled = true) { + return useQuery({ + queryKey: ['explanations', 'job', jobId], + queryFn: () => api(`/explain/jobs/${jobId}`), + enabled: enabled && !!jobId, + retry: false, + }) +} + +/** Request body for POST /explain/forecast. */ +export interface ExplainForecastBody { + store_id: number + product_id: number + model_type: 'naive' | 'seasonal_naive' | 'moving_average' + as_of_date: string // ISO date + season_length?: number + window_size?: number +} + +/** Run an ad-hoc forecast explanation. 
*/ +export function useExplainForecast() { + return useMutation({ + mutationFn: (body: ExplainForecastBody) => + api('/explain/forecast', { method: 'POST', body }), + }) +} diff --git a/frontend/src/pages/explorer/run-detail.tsx b/frontend/src/pages/explorer/run-detail.tsx index 15038b4e..f769e07f 100644 --- a/frontend/src/pages/explorer/run-detail.tsx +++ b/frontend/src/pages/explorer/run-detail.tsx @@ -10,6 +10,8 @@ import { ShieldCheck, } from 'lucide-react' import { useRun, useVerifyArtifact } from '@/hooks/use-runs' +import { useRunExplanation } from '@/hooks/use-explanations' +import { ExplanationPanel } from '@/components/explainability/explanation-panel' import { JsonBlock } from '@/components/common/json-block' import { ErrorDisplay } from '@/components/common/error-display' import { LoadingState } from '@/components/common/loading-state' @@ -41,6 +43,9 @@ export default function RunDetailPage() { const [verifyOn, setVerifyOn] = useState(false) const verifyQuery = useVerifyArtifact(runId ?? '', verifyOn) + // The explanation panel self-handles a 400 for non-baseline (lightgbm) runs. + const explanationQuery = useRunExplanation(runId ?? '', !!runId) + if (!runId) { return (
    @@ -158,6 +163,12 @@ export default function RunDetailPage() { + +
    diff --git a/frontend/src/pages/visualize/forecast.tsx b/frontend/src/pages/visualize/forecast.tsx index 611864ec..bed6cfe9 100644 --- a/frontend/src/pages/visualize/forecast.tsx +++ b/frontend/src/pages/visualize/forecast.tsx @@ -2,6 +2,8 @@ import { useState } from 'react' import { Link } from 'react-router-dom' import { BarChart3, Download, ExternalLink, Loader2, Play } from 'lucide-react' import { useJob, useCreateJob } from '@/hooks/use-jobs' +import { useJobExplanation } from '@/hooks/use-explanations' +import { ExplanationPanel } from '@/components/explainability/explanation-panel' import { TimeSeriesChart } from '@/components/charts/time-series-chart' import { EmptyState } from '@/components/common/error-display' import { JobPicker } from '@/components/common/job-picker' @@ -58,6 +60,10 @@ export default function ForecastPage() { (point) => point.lower_bound != null && point.upper_bound != null, ) + // Explain the loaded job only when it is a completed predict job. + const isPredictDone = job?.status === 'completed' && job?.job_type === 'predict' + const explanationQuery = useJobExplanation(job?.job_id ?? '', !!job && isPredictDone) + async function handleRunForecast() { if (!trainRunId) return setRunError(null) @@ -246,6 +252,15 @@ export default function ForecastPage() { )} + + {/* Forecast explanation — only for a completed predict job */} + {isPredictDone && ( + + )} )} diff --git a/frontend/src/types/api.ts b/frontend/src/types/api.ts index 80bc6da5..b29b23f1 100644 --- a/frontend/src/types/api.ts +++ b/frontend/src/types/api.ts @@ -860,3 +860,44 @@ export interface MultiScenarioComparison { // entry per scenario name. chart_series: Record[] } + +// ============================================================================= +// Explainability — PRP-28 forecast explanation & driver attribution +// ============================================================================= + +// Qualitative confidence band for a forecast explanation. 
+export type ConfidenceLevel = 'high' | 'medium' | 'low' + +// One named, interpretable demand driver behind a forecast. A driver with +// contribution === 0 is informational context the model does not consume. +export interface DriverContribution { + name: string + feature_value: number + contribution: number + direction: 'positive' | 'negative' | 'neutral' + description: string +} + +// An advisory retail signal correlated with the forecast — never a causal claim. +export interface ReasonCode { + code: string + severity: 'info' | 'warn' + detail: string +} + +// A structured, rule-based explanation of a baseline h=1 forecast — +// GET /explain/runs/{run_id}, GET /explain/jobs/{job_id}, POST /explain/forecast. +export interface ForecastExplanation { + store_id: number + product_id: number + model_type: string + method: 'rule_based' + forecast_value: number + drivers: DriverContribution[] + reason_codes: ReasonCode[] + confidence: ConfidenceLevel + caveats: string[] + agent_summary: string + as_of_date: string // ISO date + generated_at: string // ISO datetime +} From 4f23f687143c6491d859ef1a09791bc6fa2a426c Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 12:29:39 +0200 Subject: [PATCH 2/2] fix(ui): parse application/problem+json error bodies (#230) --- frontend/src/lib/api.test.ts | 52 ++++++++++++++++++++++++++++++++++++ frontend/src/lib/api.ts | 6 ++++- 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 frontend/src/lib/api.test.ts diff --git a/frontend/src/lib/api.test.ts b/frontend/src/lib/api.test.ts new file mode 100644 index 00000000..f09dee8a --- /dev/null +++ b/frontend/src/lib/api.test.ts @@ -0,0 +1,52 @@ +import { afterEach, describe, expect, it, vi } from 'vitest' +import { ApiError, api } from './api' + +/** Build a fake `fetch` that returns one canned `Response`. 
*/ +function stubFetch(body: string, init: ResponseInit) { + const fetchMock = vi.fn().mockResolvedValue(new Response(body, init)) + vi.stubGlobal('fetch', fetchMock) + return fetchMock +} + +afterEach(() => { + vi.unstubAllGlobals() +}) + +describe('api()', () => { + it('parses an RFC 7807 application/problem+json error body into ApiError.detail', async () => { + // Regression: api() previously only treated `application/json` as JSON, so + // `application/problem+json` error bodies went unparsed and the raw JSON + // string leaked into the UI via getErrorMessage(). + const problem = { + type: '/errors/bad-request', + title: 'Bad Request', + status: 400, + detail: 'Need at least 7 observations', + code: 'BAD_REQUEST', + } + stubFetch(JSON.stringify(problem), { + status: 400, + headers: { 'content-type': 'application/problem+json' }, + }) + + const err = await api('/explain/forecast', { method: 'POST', body: {} }).catch( + (e: unknown) => e, + ) + + expect(err).toBeInstanceOf(ApiError) + expect((err as ApiError).status).toBe(400) + expect((err as ApiError).message).toBe('Need at least 7 observations') + expect((err as ApiError).detail?.detail).toBe('Need at least 7 observations') + }) + + it('parses a plain application/json success body', async () => { + stubFetch(JSON.stringify({ status: 'ok' }), { + status: 200, + headers: { 'content-type': 'application/json' }, + }) + + const data = await api<{ status: string }>('/health') + + expect(data).toEqual({ status: 'ok' }) + }) +}) diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index c3eeec81..bc6ddaed 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -48,7 +48,11 @@ export async function api(endpoint: string, config: RequestConfig = {}): Prom const contentType = response.headers.get('content-type') || '' const rawBody = await response.text() - const isJson = contentType.includes('application/json') + // RFC 7807 error responses use `application/problem+json`, so match the + // `+json` 
structured-syntax suffix as well as plain `application/json` -- + // otherwise error bodies go unparsed and `ApiError.detail` is always empty. + const isJson = + contentType.includes('application/json') || contentType.includes('+json') let data: unknown = undefined if (rawBody && isJson) {