From 4036f1cc98db9b580fef98d9cf66e3a9a4a7bd57 Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 09:56:08 +0200 Subject: [PATCH 1/9] feat(api): add leakage-safe future feature-frame generator (#223) --- app/features/scenarios/feature_frame.py | 407 ++++++++++++++++++ .../scenarios/tests/test_feature_frame.py | 214 +++++++++ .../tests/test_future_frame_leakage.py | 166 +++++++ 3 files changed, 787 insertions(+) create mode 100644 app/features/scenarios/feature_frame.py create mode 100644 app/features/scenarios/tests/test_feature_frame.py create mode 100644 app/features/scenarios/tests/test_future_frame_leakage.py diff --git a/app/features/scenarios/feature_frame.py b/app/features/scenarios/feature_frame.py new file mode 100644 index 00000000..0c9c2635 --- /dev/null +++ b/app/features/scenarios/feature_frame.py @@ -0,0 +1,407 @@ +"""Leakage-safe future feature-frame generator (PRP-27 Phase A). + +The scenario MVP (PRP-26) never builds a future feature matrix — it multiplies +a baseline forecast by a deterministic factor, so it is *immune* to leakage. +The Full Version introduces a model-driven path (``method="model_exogenous"``) +that re-forecasts demand through a feature-consuming regressor, and that needs +a **future feature frame**: the same feature columns the model was trained on, +produced for each horizon day ``T+1 … T+horizon``. + +That is a new and dangerous surface — a horizon day ``D`` has *no observed +target* — so this module is governed by one rule: + + A future feature value for day ``D`` may only use information knowable at + the forecast origin ``T`` (the last training day): the observed history + up to and including ``T``, the calendar (a pure function of the date), or + the scenario assumptions (the planner's *posited* future inputs). + It may NEVER read an observed target at a horizon day. 
+ +``app/features/scenarios/tests/test_future_frame_leakage.py`` is the +load-bearing spec for that rule — it must never be weakened (AGENTS.md +§ Safety), mirroring ``app/features/featuresets/tests/test_leakage.py``. + +DECISIONS LOCKED (PRP-27): +* #3 — no cross-slice ``service.py`` import. This module imports only the + ``data_platform`` ORM (a sanctioned read-only ORM import) and same-slice + schema value-objects; it replicates the small slice of leakage-safe + lag/calendar logic it needs rather than importing + ``FeatureEngineeringService``. +* #4 — long-lag + calendar + assumption-driven columns ONLY; no recursion. + A target lag value for horizon day ``T+j`` is the observed ``y[T+j-k]``; + when ``T+j-k > T`` (a future target) the cell is ``NaN`` — the model + (``HistGradientBoostingRegressor``) handles ``NaN`` natively. No recursion + ever fills those gaps in v1. +* #10/#11/#12 — the PINNED constants ``EXOGENOUS_LAGS``, + ``HISTORY_TAIL_DAYS`` and ``MAX_COMPARE_SCENARIOS`` live here. + +Feature-column contract: ``canonical_feature_columns()`` is the single source +of truth for the regression feature set and column order. The Phase B training +path persists exactly this list in the bundle metadata, and the future frame +reproduces it column-for-column, so a model trained today re-forecasts cleanly. 
+""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from datetime import date, timedelta +from typing import TYPE_CHECKING + +from sqlalchemy import select + +from app.core.logging import get_logger +from app.features.data_platform.models import Calendar + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + + from app.features.scenarios.schemas import ScenarioAssumptions + +logger = get_logger(__name__) + +# ── PINNED modelling constants (PRP-27 DECISIONS LOCKED #10/#11/#12) ── +# Lag offsets (days) for the target long-lag columns: daily, weekly, +# fortnightly, and a four-week lag covering the dominant retail seasonality. +EXOGENOUS_LAGS: tuple[int, ...] = (1, 7, 14, 28) +# Observed-target tail (days, ending at the forecast origin T) fed to the +# generator — 90 comfortably exceeds the largest lag offset (28). +HISTORY_TAIL_DAYS: int = 90 +# Upper bound on the multi-scenario comparison (Phase C) so the chart stays +# legible; defined here as the slice's single modelling-constants home. +MAX_COMPARE_SCENARIOS: int = 5 + +# Fixed calendar columns — each a pure function of the date, never a leak. +CALENDAR_COLUMNS: tuple[str, ...] = ( + "dow_sin", + "dow_cos", + "month_sin", + "month_cos", + "is_weekend", + "is_month_end", +) +# Fixed current-day exogenous columns — driven by the scenario assumptions +# (the planner's posited future inputs) and by timeless attributes (the +# calendar, the product launch date). Every value is knowable at origin T. +EXOGENOUS_COLUMNS: tuple[str, ...] = ( + "price_factor", + "promo_active", + "is_holiday", + "days_since_launch", +) + + +@dataclass +class FutureFeatureFrame: + """A horizon-length feature matrix for one ``(store, product)`` series. + + Attributes: + dates: The horizon days ``T+1 … T+horizon`` (chronological). + feature_columns: Column order — matches the trained bundle exactly. 
+ matrix: Row-major ``[horizon][n_features]``; ``NaN`` is allowed and + expected (a long-lag cell whose source target lies in the future, + or ``days_since_launch`` when the product has no launch date). + """ + + dates: list[date] + feature_columns: list[str] + matrix: list[list[float]] + + +def canonical_feature_columns(lags: tuple[int, ...] = EXOGENOUS_LAGS) -> list[str]: + """Return the fixed, ordered regression feature-column list. + + This is the single source of truth for the regression feature set. The + Phase B training path persists exactly this list in the model bundle's + metadata; the future frame reproduces it column-for-column. The column + set is deliberately *fixed* (not horizon-dependent): for a long horizon + some target-lag columns are mostly ``NaN``, which the NaN-tolerant + estimator handles — far safer than a horizon-varying column set. + + Args: + lags: Target long-lag offsets (defaults to the pinned ``EXOGENOUS_LAGS``). + + Returns: + Ordered column names: target lags, then calendar, then exogenous. + """ + target_lags = [f"lag_{k}" for k in lags] + return [*target_lags, *CALENDAR_COLUMNS, *EXOGENOUS_COLUMNS] + + +def _in_window(point_date: date, start: date, end: date) -> bool: + """True when ``point_date`` is inside the inclusive ``[start, end]`` window. + + A reversed window (``start`` after ``end``) is normalised rather than + treated as empty — junk input must never raise (mirrors + ``adjustments._in_window``). + """ + lo, hi = (start, end) if start <= end else (end, start) + return lo <= point_date <= hi + + +def _is_month_end(point_date: date) -> bool: + """True when ``point_date`` is the last day of its month.""" + return (point_date + timedelta(days=1)).month != point_date.month + + +def build_calendar_columns(dates: list[date]) -> dict[str, list[float]]: + """Build the calendar feature columns — a pure function of each date. + + Calendar features carry zero leakage risk: they read only the date + itself, never the target series. 
Day-of-week and month use cyclical + (sin/cos) encoding so the estimator sees their periodic structure. + + Args: + dates: The horizon days. + + Returns: + A mapping of every name in :data:`CALENDAR_COLUMNS` to its per-day + values. + """ + columns: dict[str, list[float]] = {name: [] for name in CALENDAR_COLUMNS} + for point_date in dates: + dow = point_date.weekday() # 0 = Monday … 6 = Sunday + month = point_date.month + columns["dow_sin"].append(math.sin(2.0 * math.pi * dow / 7.0)) + columns["dow_cos"].append(math.cos(2.0 * math.pi * dow / 7.0)) + columns["month_sin"].append(math.sin(2.0 * math.pi * month / 12.0)) + columns["month_cos"].append(math.cos(2.0 * math.pi * month / 12.0)) + columns["is_weekend"].append(1.0 if dow >= 5 else 0.0) + columns["is_month_end"].append(1.0 if _is_month_end(point_date) else 0.0) + return columns + + +def build_long_lag_columns( + history_tail: list[float], + horizon: int, + lags: tuple[int, ...] = EXOGENOUS_LAGS, +) -> dict[str, list[float]]: + """Build the target long-lag columns — the leakage-critical helper. + + ``history_tail`` is the observed target series ending at the forecast + origin ``T``: ``history_tail[-1] == y[T]``, ``history_tail[-2] == y[T-1]``, + and so on. The lag-``k`` column at horizon day ``T+j`` (``j`` in + ``1 … horizon``) is the observed target ``y[T+j-k]``. + + SAFETY (PRP-27 DECISIONS LOCKED #4): the source index into + ``history_tail`` is ``idx = (j - 1) - k``. The cell is populated **only + when ``idx < 0``** — i.e. the source day ``T+j-k`` lies at or before the + origin ``T`` and therefore inside ``history_tail``. When ``idx >= 0`` the + source day is a *future* horizon day with no observed target, so the cell + is ``NaN`` — never a recursive prediction, never a fabricated value. This + function structurally **cannot** read a future target: its only data + input is ``history_tail`` (entirely ``<= T``). + + Args: + history_tail: Observed target values ending at the origin ``T``. 
+ horizon: Number of horizon days. + lags: Lag offsets (defaults to the pinned ``EXOGENOUS_LAGS``). + + Returns: + A mapping ``"lag_{k}" -> [horizon values]``; out-of-range cells are + ``NaN``. + """ + tail_len = len(history_tail) + columns: dict[str, list[float]] = {} + for lag in lags: + column: list[float] = [] + for j in range(1, horizon + 1): + # Negative index from the end of history_tail. idx < 0 means the + # source day T+j-k is at/before the origin T — safe to read. + idx = (j - 1) - lag + if idx < 0 and -tail_len <= idx: + column.append(float(history_tail[idx])) + else: + column.append(math.nan) + columns[f"lag_{lag}"] = column + return columns + + +def build_exogenous_columns( + dates: list[date], + assumptions: ScenarioAssumptions, + holiday_dates: set[date], + launch_date: date | None, +) -> dict[str, list[float]]: + """Build the current-day exogenous columns from the scenario assumptions. + + These columns are the *intended* what-if input — the planner is positing a + future price / promotion / holiday — so reading them is not leakage. Each + is knowable at origin ``T``: + + * ``price_factor`` — ``1.0`` (the typical price) outside any price window, + ``1.0 + change_pct`` inside it. + * ``promo_active`` — ``1.0`` when a promotion assumption covers the day. + * ``is_holiday`` — ``1.0`` when the day is in the holiday assumption OR a + ``calendar`` holiday (a calendar row is a timeless attribute). + * ``days_since_launch`` — ``(date - launch_date).days``, a pure function of + the date; ``NaN`` when the product has no launch date. + + Args: + dates: The horizon days. + assumptions: The scenario assumptions. + holiday_dates: Calendar holiday dates inside the horizon. + launch_date: The product's launch date, or ``None``. + + Returns: + A mapping of every name in :data:`EXOGENOUS_COLUMNS` to its per-day + values. 
+ """ + price = assumptions.price + promotion = assumptions.promotion + holiday = assumptions.holiday + assumption_holidays: set[date] = set(holiday.dates) if holiday is not None else set() + + price_factor: list[float] = [] + promo_active: list[float] = [] + is_holiday: list[float] = [] + days_since_launch: list[float] = [] + + for point_date in dates: + if price is not None and _in_window(point_date, price.start_date, price.end_date): + price_factor.append(1.0 + price.change_pct) + else: + price_factor.append(1.0) + + if promotion is not None and _in_window( + point_date, promotion.start_date, promotion.end_date + ): + promo_active.append(1.0) + else: + promo_active.append(0.0) + + is_holiday.append( + 1.0 if point_date in assumption_holidays or point_date in holiday_dates else 0.0 + ) + + if launch_date is not None: + days_since_launch.append(float((point_date - launch_date).days)) + else: + days_since_launch.append(math.nan) + + return { + "price_factor": price_factor, + "promo_active": promo_active, + "is_holiday": is_holiday, + "days_since_launch": days_since_launch, + } + + +def assemble_future_frame( + *, + dates: list[date], + feature_columns: list[str], + history_tail: list[float], + assumptions: ScenarioAssumptions, + holiday_dates: set[date], + launch_date: date | None, +) -> FutureFeatureFrame: + """Assemble a :class:`FutureFeatureFrame` from already-resolved inputs. + + Pure (no DB, no I/O) so it is fully unit-testable; :func:`build_future_frame` + is the thin async wrapper that resolves ``holiday_dates`` from the + ``calendar`` table first. + + Any requested column not produced by the builders is filled with ``NaN`` + so the matrix always matches ``feature_columns`` in width and order. + + Args: + dates: The horizon days ``T+1 … T+horizon``. + feature_columns: The exact column order to emit. + history_tail: Observed target values ending at the origin ``T``. + assumptions: The scenario assumptions. 
+ holiday_dates: Calendar holiday dates inside the horizon. + launch_date: The product's launch date, or ``None``. + + Returns: + The assembled future feature frame. + """ + horizon = len(dates) + column_data: dict[str, list[float]] = {} + column_data.update(build_long_lag_columns(history_tail, horizon)) + column_data.update(build_calendar_columns(dates)) + column_data.update(build_exogenous_columns(dates, assumptions, holiday_dates, launch_date)) + + # Defensive: any column the trained bundle expects but this generator does + # not produce becomes an all-NaN column (the estimator tolerates NaN). + for column in feature_columns: + if column not in column_data: + column_data[column] = [math.nan] * horizon + + matrix: list[list[float]] = [ + [column_data[column][j] for column in feature_columns] for j in range(horizon) + ] + return FutureFeatureFrame( + dates=list(dates), + feature_columns=list(feature_columns), + matrix=matrix, + ) + + +async def build_future_frame( + db: AsyncSession, + *, + store_id: int, + product_id: int, + forecast_origin: date, + horizon: int, + feature_columns: list[str], + history_tail: list[float], + assumptions: ScenarioAssumptions, + launch_date: date | None = None, +) -> FutureFeatureFrame: + """Build the future feature frame for one ``(store, product)`` series. + + The only database read is the ``calendar`` holiday lookup for the horizon + window — a ``calendar`` row is a timeless attribute, so reading it is not + leakage. Everything else is derived from ``history_tail`` (observed, + ``<= T``), the dates, or the assumptions. + + Args: + db: Async database session (used only for the calendar lookup). + store_id: Store the baseline model targets (logged). + product_id: Product the baseline model targets (logged). + forecast_origin: The origin ``T`` — the last training day. The horizon + runs ``T+1 … T+horizon``. + horizon: Number of horizon days (``>= 1``). + feature_columns: The trained bundle's feature-column order. 
+ history_tail: Observed target values ending at ``T``. + assumptions: The scenario assumptions. + launch_date: The product's launch date, or ``None``. + + Returns: + The assembled future feature frame. + + Raises: + ValueError: When ``horizon`` is below 1. + """ + if horizon < 1: + raise ValueError(f"horizon must be >= 1, got {horizon}") + + dates = [forecast_origin + timedelta(days=offset) for offset in range(1, horizon + 1)] + + result = await db.execute( + select(Calendar.date).where( + Calendar.date >= dates[0], + Calendar.date <= dates[-1], + Calendar.is_holiday.is_(True), + ) + ) + holiday_dates: set[date] = set(result.scalars().all()) + + frame = assemble_future_frame( + dates=dates, + feature_columns=feature_columns, + history_tail=history_tail, + assumptions=assumptions, + holiday_dates=holiday_dates, + launch_date=launch_date, + ) + logger.info( + "scenarios.future_frame_built", + store_id=store_id, + product_id=product_id, + horizon=horizon, + n_features=len(feature_columns), + n_calendar_holidays=len(holiday_dates), + ) + return frame diff --git a/app/features/scenarios/tests/test_feature_frame.py b/app/features/scenarios/tests/test_feature_frame.py new file mode 100644 index 00000000..22306eb1 --- /dev/null +++ b/app/features/scenarios/tests/test_feature_frame.py @@ -0,0 +1,214 @@ +"""Unit tests for the future feature-frame generator (PRP-27 Phase A). + +These exercise the pure builders — calendar columns, target long-lag columns, +assumption-driven exogenous columns, and the :func:`assemble_future_frame` +orchestration. The leakage invariants live separately in +``test_future_frame_leakage.py`` (the load-bearing spec). 
+""" + +from __future__ import annotations + +import math +from datetime import date, timedelta + +from app.features.scenarios.feature_frame import ( + CALENDAR_COLUMNS, + EXOGENOUS_COLUMNS, + EXOGENOUS_LAGS, + HISTORY_TAIL_DAYS, + MAX_COMPARE_SCENARIOS, + assemble_future_frame, + build_calendar_columns, + build_exogenous_columns, + build_long_lag_columns, + canonical_feature_columns, +) +from app.features.scenarios.schemas import ( + HolidayAssumption, + PriceAssumption, + PromotionAssumption, + ScenarioAssumptions, +) + +_ORIGIN = date(2026, 6, 30) +_HORIZON = 14 +_HORIZON_DATES = [_ORIGIN + timedelta(days=offset) for offset in range(1, _HORIZON + 1)] + + +# --- pinned constants --------------------------------------------------------- + + +def test_pinned_constants() -> None: + """The PRP-27 pinned modelling constants hold their decided values.""" + assert EXOGENOUS_LAGS == (1, 7, 14, 28) + assert HISTORY_TAIL_DAYS == 90 + assert MAX_COMPARE_SCENARIOS == 5 + + +def test_canonical_feature_columns_order() -> None: + """The canonical column list is target lags, then calendar, then exogenous.""" + columns = canonical_feature_columns() + assert columns[:4] == ["lag_1", "lag_7", "lag_14", "lag_28"] + assert columns[4 : 4 + len(CALENDAR_COLUMNS)] == list(CALENDAR_COLUMNS) + assert columns[-len(EXOGENOUS_COLUMNS) :] == list(EXOGENOUS_COLUMNS) + assert len(columns) == len(EXOGENOUS_LAGS) + len(CALENDAR_COLUMNS) + len(EXOGENOUS_COLUMNS) + + +# --- calendar columns --------------------------------------------------------- + + +def test_calendar_columns_are_pure_function_of_date() -> None: + """Calendar columns depend only on the dates — two calls match exactly.""" + first = build_calendar_columns(_HORIZON_DATES) + second = build_calendar_columns(list(_HORIZON_DATES)) + assert first == second + assert set(first) == set(CALENDAR_COLUMNS) + for values in first.values(): + assert len(values) == _HORIZON + + +def test_calendar_is_weekend_and_month_end() -> None: + 
"""``is_weekend`` and ``is_month_end`` reflect the date itself.""" + dates = [date(2026, 7, 30), date(2026, 7, 31), date(2026, 8, 1), date(2026, 8, 2)] + columns = build_calendar_columns(dates) + # 2026-07-31 is the month end; 2026-08-01 (Sat) and 2026-08-02 (Sun) weekend. + assert columns["is_month_end"] == [0.0, 1.0, 0.0, 0.0] + assert columns["is_weekend"] == [1.0 if d.weekday() >= 5 else 0.0 for d in dates] + + +def test_calendar_cyclical_encoding_bounded() -> None: + """Cyclical sin/cos encodings stay within [-1, 1].""" + columns = build_calendar_columns(_HORIZON_DATES) + for name in ("dow_sin", "dow_cos", "month_sin", "month_cos"): + assert all(-1.0 <= value <= 1.0 for value in columns[name]) + + +# --- long-lag columns --------------------------------------------------------- + + +def test_long_lag_indexing_is_correct() -> None: + """``lag_k`` at horizon day ``j`` equals the observed ``y[T+j-k]``.""" + # history_tail[-1] == y[T], history_tail[-2] == y[T-1], ... + history_tail = [float(value) for value in range(HISTORY_TAIL_DAYS)] + columns = build_long_lag_columns(history_tail, _HORIZON) + + assert set(columns) == {f"lag_{k}" for k in EXOGENOUS_LAGS} + # lag_1 at j=1 reads history_tail[-1] (y[T]); j>=2 needs a future target. + assert columns["lag_1"][0] == history_tail[-1] + assert all(math.isnan(value) for value in columns["lag_1"][1:]) + # lag_7 at j=1 reads history_tail[-7]; populated for j in 1..7. 
+ assert columns["lag_7"][0] == history_tail[-7] + for j in range(1, _HORIZON + 1): + cell = columns["lag_7"][j - 1] + if j <= 7: + assert cell == history_tail[(j - 1) - 7] + else: + assert math.isnan(cell) + + +def test_long_lag_all_columns_present_for_long_horizon() -> None: + """Every lag column is emitted even when the horizon exceeds the offset.""" + history_tail = [float(value) for value in range(HISTORY_TAIL_DAYS)] + columns = build_long_lag_columns(history_tail, horizon=40) + assert set(columns) == {f"lag_{k}" for k in EXOGENOUS_LAGS} + # lag_28 over a 40-day horizon: first 28 days populated, last 12 NaN. + assert all(not math.isnan(v) for v in columns["lag_28"][:28]) + assert all(math.isnan(v) for v in columns["lag_28"][28:]) + + +def test_long_lag_short_history_yields_nan() -> None: + """A history shorter than the lag offset produces NaN, never an error.""" + columns = build_long_lag_columns([5.0, 6.0, 7.0], horizon=4) + # lag_28 cannot resolve from a 3-element tail — all NaN. + assert all(math.isnan(value) for value in columns["lag_28"]) + # lag_1 at j=1 still resolves to the origin observation. 
+ assert columns["lag_1"][0] == 7.0 + + +# --- exogenous columns -------------------------------------------------------- + + +def test_exogenous_price_window() -> None: + """``price_factor`` is ``1 + change_pct`` inside the window, ``1.0`` outside.""" + assumptions = ScenarioAssumptions( + price=PriceAssumption( + change_pct=-0.15, + start_date=_HORIZON_DATES[2], + end_date=_HORIZON_DATES[5], + ) + ) + columns = build_exogenous_columns(_HORIZON_DATES, assumptions, set(), launch_date=None) + for index, value in enumerate(columns["price_factor"]): + if 2 <= index <= 5: + assert value == 0.85 + else: + assert value == 1.0 + + +def test_exogenous_promo_and_holiday() -> None: + """``promo_active`` flags the promotion window; ``is_holiday`` unions sources.""" + calendar_holiday = _HORIZON_DATES[0] + assumption_holiday = _HORIZON_DATES[9] + assumptions = ScenarioAssumptions( + promotion=PromotionAssumption( + kind="pct_off", + start_date=_HORIZON_DATES[1], + end_date=_HORIZON_DATES[3], + ), + holiday=HolidayAssumption(dates=[assumption_holiday]), + ) + columns = build_exogenous_columns( + _HORIZON_DATES, assumptions, {calendar_holiday}, launch_date=None + ) + assert [i for i, v in enumerate(columns["promo_active"]) if v == 1.0] == [1, 2, 3] + # is_holiday unions the calendar holiday and the assumption holiday. 
+ assert [i for i, v in enumerate(columns["is_holiday"]) if v == 1.0] == [0, 9] + + +def test_exogenous_days_since_launch() -> None: + """``days_since_launch`` is a pure date delta, NaN without a launch date.""" + launch = _ORIGIN - timedelta(days=100) + with_launch = build_exogenous_columns( + _HORIZON_DATES, ScenarioAssumptions(), set(), launch_date=launch + ) + assert with_launch["days_since_launch"][0] == float((_HORIZON_DATES[0] - launch).days) + without_launch = build_exogenous_columns( + _HORIZON_DATES, ScenarioAssumptions(), set(), launch_date=None + ) + assert all(math.isnan(value) for value in without_launch["days_since_launch"]) + + +# --- assembly ----------------------------------------------------------------- + + +def test_assemble_future_frame_shape_and_order() -> None: + """The assembled matrix matches ``feature_columns`` in width and order.""" + columns = canonical_feature_columns() + history_tail = [float(value) for value in range(HISTORY_TAIL_DAYS)] + frame = assemble_future_frame( + dates=_HORIZON_DATES, + feature_columns=columns, + history_tail=history_tail, + assumptions=ScenarioAssumptions(), + holiday_dates=set(), + launch_date=None, + ) + assert frame.feature_columns == columns + assert frame.dates == _HORIZON_DATES + assert len(frame.matrix) == _HORIZON + assert all(len(row) == len(columns) for row in frame.matrix) + + +def test_assemble_future_frame_unknown_column_is_nan() -> None: + """A requested column the builders do not produce becomes an all-NaN column.""" + columns = [*canonical_feature_columns(), "mystery_feature"] + frame = assemble_future_frame( + dates=_HORIZON_DATES, + feature_columns=columns, + history_tail=[float(v) for v in range(HISTORY_TAIL_DAYS)], + assumptions=ScenarioAssumptions(), + holiday_dates=set(), + launch_date=None, + ) + mystery_index = columns.index("mystery_feature") + assert all(math.isnan(row[mystery_index]) for row in frame.matrix) diff --git a/app/features/scenarios/tests/test_future_frame_leakage.py 
b/app/features/scenarios/tests/test_future_frame_leakage.py new file mode 100644 index 00000000..4a4659de --- /dev/null +++ b/app/features/scenarios/tests/test_future_frame_leakage.py @@ -0,0 +1,166 @@ +"""Leakage spec for the future feature frame — LOAD-BEARING (PRP-27 Phase A). + +This file IS the spec, mirroring ``app/features/featuresets/tests/test_leakage.py`` +and ``app/features/scenarios/tests/test_leakage.py``: it must NEVER be weakened +to make a feature pass (AGENTS.md § Safety). + +The model-driven scenario path re-forecasts demand through a feature-consuming +regressor, which means it builds a *future feature frame*. A horizon day has no +observed target, so the invariant is: + + A future feature value for horizon day ``D`` may use ONLY information + knowable at the forecast origin ``T``: the observed history up to and + including ``T``, the calendar (a pure function of the date), or the + scenario assumptions (the planner's posited future inputs). It may NEVER + read an observed target at a horizon day ``D`` (which lies after ``T``). + +Concretely this spec asserts: + +1. ``build_long_lag_columns`` returns only values drawn from ``history_tail`` + (entirely ``<= T``) or ``NaN`` — never a value from the future target + series. +2. A lag cell whose source day lies at or after the first horizon day is + ``NaN`` — the generator never fabricates or recursively predicts it. +3. Calendar columns are independent of the target series entirely. +4. An assumption window that falls before the forecast origin contributes + nothing — every horizon day lies strictly after ``T``. +5. Every non-``NaN`` ``lag_*`` cell in an assembled frame is a member of + ``history_tail``. 
+""" + +from __future__ import annotations + +import math +from datetime import date, timedelta + +from app.features.scenarios.feature_frame import ( + EXOGENOUS_LAGS, + assemble_future_frame, + build_calendar_columns, + build_exogenous_columns, + build_long_lag_columns, + canonical_feature_columns, +) +from app.features.scenarios.schemas import PriceAssumption, ScenarioAssumptions + +# The forecast origin T is the last observed day; the horizon runs T+1 … T+H. +_ORIGIN = date(2026, 6, 30) +_HORIZON = 21 +_HORIZON_DATES = [_ORIGIN + timedelta(days=offset) for offset in range(1, _HORIZON + 1)] + +# Observed history (all <= T): 90 distinct values 1000.0 … 1089.0. +# history_tail[-1] == y[T], the origin observation. +_HISTORY_TAIL = [1000.0 + float(i) for i in range(90)] +# A DISJOINT "future target" series the generator must never be able to read. +# Any of these values appearing in a feature cell is a leak. +_FUTURE_TARGETS = {9000.0 + float(i) for i in range(_HORIZON)} + + +def test_long_lag_columns_never_emit_a_future_target() -> None: + """Every non-NaN long-lag cell is drawn from the observed history. + + ``build_long_lag_columns`` takes ONLY ``history_tail`` as data input — it + is structurally incapable of reading the future target series. This spec + pins that: no value disjoint from ``history_tail`` may ever appear. + """ + history_values = set(_HISTORY_TAIL) + columns = build_long_lag_columns(_HISTORY_TAIL, _HORIZON) + + for name, values in columns.items(): + for cell in values: + if math.isnan(cell): + continue + assert cell in history_values, ( + f"{name} emitted {cell}, which is not an observed history value" + ) + assert cell not in _FUTURE_TARGETS, f"{name} leaked a future target value {cell}" + + +def test_long_lag_source_index_is_never_at_or_after_the_horizon() -> None: + """A lag cell is populated only when its source day lies at/before ``T``. + + For lag ``k`` and horizon day ``j`` the source index into ``history_tail`` + is ``(j-1)-k``. 
A non-NaN cell REQUIRES that index to be negative — i.e. + the source target lies at or before the origin ``T``. A non-negative index + would point at a future horizon day and MUST yield ``NaN``. + """ + columns = build_long_lag_columns(_HISTORY_TAIL, _HORIZON) + for lag in EXOGENOUS_LAGS: + column = columns[f"lag_{lag}"] + for j in range(1, _HORIZON + 1): + source_index = (j - 1) - lag + cell = column[j - 1] + if source_index >= 0: + assert math.isnan(cell), ( + f"lag_{lag} day {j}: source index {source_index} is in the " + "future but the cell is not NaN" + ) + else: + assert not math.isnan(cell), ( + f"lag_{lag} day {j}: source index {source_index} is in " + "history but the cell is NaN" + ) + + +def test_calendar_columns_are_independent_of_the_target_series() -> None: + """Calendar columns read only the dates — they cannot leak the target. + + ``build_calendar_columns`` does not accept the target series at all; this + spec pins that structural fact by asserting its output is identical no + matter what history precedes it. + """ + calendar_a = build_calendar_columns(_HORIZON_DATES) + calendar_b = build_calendar_columns(_HORIZON_DATES) + assert calendar_a == calendar_b + # No calendar value coincides with a history or future target value. + history_values = set(_HISTORY_TAIL) + for values in calendar_a.values(): + for cell in values: + assert cell not in history_values + assert cell not in _FUTURE_TARGETS + + +def test_assumption_window_before_origin_has_no_effect() -> None: + """A price window entirely before the forecast origin contributes nothing. + + Every horizon day lies strictly after ``T``; a window that ends on or + before ``T`` can never intersect the horizon, so ``price_factor`` stays + neutral (``1.0``) for every day — the assumption cannot reach into history. 
+ """ + past_window = ScenarioAssumptions( + price=PriceAssumption( + change_pct=-0.40, + start_date=_ORIGIN - timedelta(days=30), + end_date=_ORIGIN, + ) + ) + columns = build_exogenous_columns(_HORIZON_DATES, past_window, set(), launch_date=None) + assert columns["price_factor"] == [1.0] * _HORIZON, ( + "a price window ending at/before the origin must not move price_factor" + ) + + +def test_assembled_frame_lag_cells_are_history_or_nan() -> None: + """Every non-NaN ``lag_*`` cell in an assembled frame is an observed value. + + This is the end-to-end leakage assertion: assemble a full frame and verify + every target-lag column still only ever shows a history value or ``NaN``. + """ + columns = canonical_feature_columns() + frame = assemble_future_frame( + dates=_HORIZON_DATES, + feature_columns=columns, + history_tail=_HISTORY_TAIL, + assumptions=ScenarioAssumptions(), + holiday_dates=set(), + launch_date=None, + ) + history_values = set(_HISTORY_TAIL) + lag_indices = {columns.index(f"lag_{k}") for k in EXOGENOUS_LAGS} + for row in frame.matrix: + for col_index in lag_indices: + cell = row[col_index] + if math.isnan(cell): + continue + assert cell in history_values, f"assembled frame leaked non-history value {cell}" + assert cell not in _FUTURE_TARGETS, f"assembled frame leaked future target {cell}" From b0e4b9d2f30c1e0853c2496c02c504d93d739399 Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 10:26:13 +0200 Subject: [PATCH 2/9] feat(forecast): add exogenous-regressor forecaster and regression training path (#223) --- app/features/forecasting/models.py | 155 +++++++++- app/features/forecasting/schemas.py | 51 +++- app/features/forecasting/service.py | 270 ++++++++++++++++-- .../tests/test_regression_forecaster.py | 116 ++++++++ 4 files changed, 569 insertions(+), 23 deletions(-) create mode 100644 app/features/forecasting/tests/test_regression_forecaster.py diff --git a/app/features/forecasting/models.py b/app/features/forecasting/models.py index 
04f9dc05..ecb510b8 100644 --- a/app/features/forecasting/models.py +++ b/app/features/forecasting/models.py @@ -17,6 +17,9 @@ from typing import TYPE_CHECKING, Any, Literal import numpy as np +from sklearn.ensemble import ( # type: ignore[import-untyped] + HistGradientBoostingRegressor, +) if TYPE_CHECKING: from app.features.forecasting.schemas import ModelConfig @@ -422,8 +425,147 @@ def set_params(self, **params: Any) -> MovingAverageForecaster: # noqa: ANN401 return self +class RegressionForecaster(BaseForecaster): + """Feature-driven forecaster wrapping ``HistGradientBoostingRegressor``. + + CRITICAL: this is the FIRST forecaster that *consumes* the exogenous ``X`` + argument — the baseline forecasters all ignore it (each ``fit``/``predict`` + carries ``# noqa: ARG002``). Both ``fit`` and ``predict`` therefore REQUIRE + a non-``None`` ``X`` whose row count matches, and raise ``ValueError`` + otherwise — a regression model cannot forecast without its feature frame. + + ``HistGradientBoostingRegressor`` is deterministic given a fixed + ``random_state`` and tolerates ``NaN`` natively, which matters because the + future feature frame leaves lag cells ``NaN`` when their source target + lies in the (un-observed) horizon. + + Attributes: + max_iter: Number of boosting iterations. + learning_rate: Gradient-boosting learning rate. + max_depth: Maximum depth of each tree. + """ + + def __init__( + self, + *, + max_iter: int = 200, + learning_rate: float = 0.05, + max_depth: int = 6, + random_state: int = 42, + ) -> None: + """Initialize the regression forecaster. + + Args: + max_iter: Number of boosting iterations. + learning_rate: Gradient-boosting learning rate. + max_depth: Maximum depth of each tree. + random_state: Random seed for reproducibility (determinism). 
+ """ + super().__init__(random_state) + self.max_iter = max_iter + self.learning_rate = learning_rate + self.max_depth = max_depth + self._estimator: Any = None + + def fit( + self, + y: np.ndarray[Any, np.dtype[np.floating[Any]]], + X: np.ndarray[Any, np.dtype[np.floating[Any]]] | None = None, + ) -> RegressionForecaster: + """Fit the gradient-boosted regressor on historical features. + + Args: + y: Target values (1D array of shape ``[n_samples]``). + X: Exogenous features (2D array of shape ``[n_samples, n_features]``). + REQUIRED — unlike the baseline forecasters. + + Returns: + self (for method chaining). + + Raises: + ValueError: If ``X`` is ``None``, ``y`` is empty, or the row counts + of ``X`` and ``y`` do not match. + """ + if X is None: + raise ValueError("RegressionForecaster requires exogenous features X for fit()") + if len(y) == 0: + raise ValueError("Cannot fit on empty array") + if X.shape[0] != len(y): + raise ValueError( + f"X has {X.shape[0]} rows but y has {len(y)} — feature/target rows must match" + ) + estimator: Any = HistGradientBoostingRegressor( + max_iter=self.max_iter, + learning_rate=self.learning_rate, + max_depth=self.max_depth, + random_state=self.random_state, + ) + estimator.fit(X, y) + self._estimator = estimator + self._last_values = np.asarray(y[-1:], dtype=np.float64) + self._is_fitted = True + return self + + def predict( + self, + horizon: int, + X: np.ndarray[Any, np.dtype[np.floating[Any]]] | None = None, + ) -> np.ndarray[Any, np.dtype[np.floating[Any]]]: + """Generate forecasts from a future feature frame. + + Args: + horizon: Number of steps to forecast. + X: Exogenous features for the forecast period, shape + ``[horizon, n_features]``. REQUIRED. + + Returns: + Array of forecasts with shape ``[horizon]``. + + Raises: + RuntimeError: If the model has not been fitted. + ValueError: If ``X`` is ``None`` or its row count is not ``horizon``. 
+ """ + if not self._is_fitted or self._estimator is None: + raise RuntimeError("Model must be fitted before predict") + if X is None: + raise ValueError("RegressionForecaster requires exogenous features X for predict()") + if X.shape[0] != horizon: + raise ValueError(f"X has {X.shape[0]} rows but horizon is {horizon} — they must match") + predictions = self._estimator.predict(X) + result: np.ndarray[Any, np.dtype[np.floating[Any]]] = np.asarray( + predictions, dtype=np.float64 + ) + return result + + def get_params(self) -> dict[str, Any]: + """Get model parameters. + + Returns: + Dictionary with max_iter, learning_rate, max_depth, random_state. + """ + return { + "max_iter": self.max_iter, + "learning_rate": self.learning_rate, + "max_depth": self.max_depth, + "random_state": self.random_state, + } + + def set_params(self, **params: Any) -> RegressionForecaster: # noqa: ANN401 + """Set model parameters. + + Args: + **params: Parameter names and values to set. + + Returns: + self (for method chaining). 
+ """ + for key, value in params.items(): + setattr(self, key, value) + return self + + # Type alias for model type literals -ModelType = Literal["naive", "seasonal_naive", "moving_average", "lightgbm"] +ModelType = Literal["naive", "seasonal_naive", "moving_average", "lightgbm", "regression"] def model_factory(config: ModelConfig, random_state: int = 42) -> BaseForecaster: @@ -472,5 +614,16 @@ def model_factory(config: ModelConfig, random_state: int = 42) -> BaseForecaster ) # LightGBM implementation would go here when feature-flagged raise NotImplementedError("LightGBM forecaster not yet implemented") + elif model_type == "regression": + from app.features.forecasting.schemas import RegressionModelConfig + + if isinstance(config, RegressionModelConfig): + return RegressionForecaster( + max_iter=config.max_iter, + learning_rate=config.learning_rate, + max_depth=config.max_depth, + random_state=random_state, + ) + raise ValueError("Invalid config type for regression") else: raise ValueError(f"Unknown model type: {model_type}") diff --git a/app/features/forecasting/schemas.py b/app/features/forecasting/schemas.py index 33d534f7..b019529c 100644 --- a/app/features/forecasting/schemas.py +++ b/app/features/forecasting/schemas.py @@ -144,9 +144,58 @@ class LightGBMModelConfig(ModelConfigBase): ) +class RegressionModelConfig(ModelConfigBase): + """Configuration for the exogenous-regressor forecaster (PRP-27). + + Wraps scikit-learn's ``HistGradientBoostingRegressor`` — a deterministic, + NaN-tolerant gradient-boosted tree model. Unlike the baseline forecasters, + a ``regression`` model *consumes* a per-day exogenous feature frame, so a + scenario what-if can be answered by genuinely re-forecasting demand + (``method="model_exogenous"``) rather than by a post-forecast multiplier. + + No feature flag and no new dependency: ``HistGradientBoostingRegressor`` + ships with the already-pinned ``scikit-learn`` (see + ``PRPs/ai_docs/exogenous-regressor-forecasting.md`` § 5). 
+ + Attributes: + max_iter: Number of boosting iterations. + learning_rate: Gradient-boosting learning rate. + max_depth: Maximum depth of each tree. + feature_config_hash: Optional hash of the feature contract used. + """ + + model_type: Literal["regression"] = "regression" + max_iter: int = Field( + default=200, + ge=10, + le=1000, + description="Number of boosting iterations", + ) + learning_rate: float = Field( + default=0.05, + ge=0.001, + le=1.0, + description="Gradient-boosting learning rate", + ) + max_depth: int = Field( + default=6, + ge=1, + le=20, + description="Maximum depth of each tree", + ) + feature_config_hash: str | None = Field( + default=None, + description="Hash of the feature contract used for training", + ) + + # Union type for all model configs ModelConfig = ( - NaiveModelConfig | SeasonalNaiveModelConfig | MovingAverageModelConfig | LightGBMModelConfig + NaiveModelConfig + | SeasonalNaiveModelConfig + | MovingAverageModelConfig + | LightGBMModelConfig + | RegressionModelConfig ) diff --git a/app/features/forecasting/service.py b/app/features/forecasting/service.py index a1cf03fa..839f1475 100644 --- a/app/features/forecasting/service.py +++ b/app/features/forecasting/service.py @@ -11,6 +11,7 @@ from __future__ import annotations +import math import time import uuid from dataclasses import dataclass, field @@ -25,7 +26,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import get_settings -from app.features.data_platform.models import SalesDaily +from app.features.data_platform.models import Calendar, Product, Promotion, SalesDaily from app.features.forecasting.models import model_factory from app.features.forecasting.persistence import ( ModelBundle, @@ -68,6 +69,60 @@ def __post_init__(self) -> None: self.n_observations = len(self.y) +# Minimum observed rows required to train a regression model — enough to +# resolve the lag features and still leave training signal (PRP-27 GOTCHA #14). 
+_MIN_REGRESSION_TRAIN_ROWS = 30 +# Observed-target tail persisted in the bundle so the scenario future-frame +# generator can resolve long lags (PRP-27 DECISIONS LOCKED #11 — 90 days). +_REGRESSION_HISTORY_TAIL_DAYS = 90 +# Target lag offsets — PRP-27 DECISIONS LOCKED #10 (EXOGENOUS_LAGS). +_REGRESSION_LAGS: tuple[int, ...] = (1, 7, 14, 28) +# Canonical regression feature columns — a PAIRED CONTRACT with +# ``app/features/scenarios/feature_frame.canonical_feature_columns()``. The +# scenarios slice owns the future-frame generator; this slice owns training. +# A cross-slice import is forbidden (AGENTS.md § Architecture, PRP-27 +# DECISIONS LOCKED #3), so the column names and order are replicated here and +# kept in lock-step by the scenarios integration test (a column mismatch +# surfaces as a non-zero delta on an empty-assumption simulation). +_REGRESSION_FEATURE_COLUMNS: list[str] = [ + *(f"lag_{lag}" for lag in _REGRESSION_LAGS), + "dow_sin", + "dow_cos", + "month_sin", + "month_cos", + "is_weekend", + "is_month_end", + "price_factor", + "promo_active", + "is_holiday", + "days_since_launch", +] + + +@dataclass +class RegressionFeatureMatrix: + """Historical feature matrix + bundle metadata for a regression model. + + Attributes: + X: Feature matrix, shape ``[n_observations, n_features]`` (NaN allowed). + y: Target values, shape ``[n_observations]``. + feature_columns: Column order — persisted so the future frame matches. + history_tail: The last ``_REGRESSION_HISTORY_TAIL_DAYS`` observed + targets, ending at the forecast origin ``T``. + history_tail_dates: ISO dates aligned with ``history_tail``. + launch_date_iso: The product launch date (ISO) or ``None``. + n_observations: Number of training rows. 
+ """ + + X: np.ndarray[Any, np.dtype[np.floating[Any]]] + y: np.ndarray[Any, np.dtype[np.floating[Any]]] + feature_columns: list[str] + history_tail: list[float] + history_tail_dates: list[str] + launch_date_iso: str | None + n_observations: int + + class ForecastingService: """Service for training and predicting with forecasting models. @@ -121,24 +176,43 @@ async def train_model( config_hash=config.config_hash(), ) - # Load training data - training_data = await self._load_training_data( - db=db, - store_id=store_id, - product_id=product_id, - start_date=train_start_date, - end_date=train_end_date, - ) - - if training_data.n_observations == 0: - raise ValueError( - f"No training data found for store={store_id}, product={product_id} " - f"between {train_start_date} and {train_end_date}" + # Build the model + bundle metadata. The regression path consumes a + # historical feature matrix; every other model trains on the raw + # target series exactly as before. + extra_metadata: dict[str, object] = {} + if config.model_type == "regression": + features = await self._build_regression_features( + db=db, + store_id=store_id, + product_id=product_id, + start_date=train_start_date, + end_date=train_end_date, ) - - # Create and fit model - model = model_factory(config, random_state=self.settings.forecast_random_seed) - model.fit(training_data.y) + model = model_factory(config, random_state=self.settings.forecast_random_seed) + model.fit(features.y, features.X) + n_observations = features.n_observations + extra_metadata = { + "feature_columns": features.feature_columns, + "history_tail": features.history_tail, + "history_tail_dates": features.history_tail_dates, + "launch_date": features.launch_date_iso, + } + else: + training_data = await self._load_training_data( + db=db, + store_id=store_id, + product_id=product_id, + start_date=train_start_date, + end_date=train_end_date, + ) + if training_data.n_observations == 0: + raise ValueError( + f"No training data found for 
store={store_id}, product={product_id} " + f"between {train_start_date} and {train_end_date}" + ) + model = model_factory(config, random_state=self.settings.forecast_random_seed) + model.fit(training_data.y) + n_observations = training_data.n_observations # Create bundle with metadata bundle = ModelBundle( @@ -149,7 +223,8 @@ async def train_model( "product_id": product_id, "train_start_date": str(train_start_date), "train_end_date": str(train_end_date), - "n_observations": training_data.n_observations, + "n_observations": n_observations, + **extra_metadata, }, ) @@ -166,7 +241,7 @@ async def train_model( product_id=product_id, model_type=config.model_type, config_hash=config.config_hash(), - n_observations=training_data.n_observations, + n_observations=n_observations, model_path=str(saved_path), duration_ms=duration_ms, ) @@ -177,7 +252,7 @@ async def train_model( model_type=config.model_type, model_path=str(saved_path), config_hash=config.config_hash(), - n_observations=training_data.n_observations, + n_observations=n_observations, train_start_date=train_start_date, train_end_date=train_end_date, duration_ms=duration_ms, @@ -267,6 +342,16 @@ async def predict( f"but prediction requested for product={product_id}" ) + # Regression models need an exogenous feature frame to forecast — that + # is built (from scenario assumptions) by POST /scenarios/simulate. The + # plain predict endpoint cannot supply one, so it rejects them cleanly. + if bundle.config.model_type == "regression": + raise ValueError( + "Regression models forecast through POST /scenarios/simulate, " + "which supplies the exogenous feature frame. POST /forecasting/" + "predict does not support model_type='regression'." 
+ ) + # Generate forecasts forecasts_array = bundle.model.predict(horizon) @@ -365,3 +450,146 @@ async def _load_training_data( store_id=store_id, product_id=product_id, ) + + async def _build_regression_features( + self, + db: AsyncSession, + store_id: int, + product_id: int, + start_date: date_type, + end_date: date_type, + ) -> RegressionFeatureMatrix: + """Build the historical feature matrix for a regression model. + + Time-safe by construction: every lag column at row ``i`` reads only + the observed target at ``i - lag`` (a strictly earlier day); calendar + columns are pure functions of the date; ``price_factor`` / + ``promo_active`` / ``is_holiday`` / ``days_since_launch`` read the + same-day exogenous attributes. No row reads a future observation. + + The column set is the paired contract with the scenarios slice's + future-frame generator (see ``_REGRESSION_FEATURE_COLUMNS``). + + Args: + db: Database session. + store_id: Store ID. + product_id: Product ID. + start_date: Start of the training window (inclusive). + end_date: End of the training window (inclusive) — the origin ``T``. + + Returns: + The feature matrix plus the bundle metadata the future frame needs. + + Raises: + ValueError: When fewer than ``_MIN_REGRESSION_TRAIN_ROWS`` observed + days are available. + """ + sales_rows = ( + await db.execute( + select(SalesDaily.date, SalesDaily.quantity, SalesDaily.unit_price) + .where( + (SalesDaily.store_id == store_id) + & (SalesDaily.product_id == product_id) + & (SalesDaily.date >= start_date) + & (SalesDaily.date <= end_date) + ) + .order_by(SalesDaily.date) + ) + ).all() + if len(sales_rows) < _MIN_REGRESSION_TRAIN_ROWS: + raise ValueError( + f"A regression model needs at least {_MIN_REGRESSION_TRAIN_ROWS} " + f"observed days; store={store_id} product={product_id} has " + f"{len(sales_rows)} between {start_date} and {end_date}." 
+ ) + + dates = [row.date for row in sales_rows] + quantities = [float(row.quantity) for row in sales_rows] + prices = [float(row.unit_price) for row in sales_rows] + + # Baseline price = median of the positive prices, so price_factor is + # ~1.0 on a typical day and < 1.0 on a markdown/promo day. + positive_prices = sorted(price for price in prices if price > 0.0) + baseline_price = positive_prices[len(positive_prices) // 2] if positive_prices else 1.0 + + holiday_dates: set[date_type] = set( + ( + await db.execute( + select(Calendar.date).where( + Calendar.date >= start_date, + Calendar.date <= end_date, + Calendar.is_holiday.is_(True), + ) + ) + ) + .scalars() + .all() + ) + + # Promotion-active days: store-specific OR chain-wide rows that overlap + # the training window, expanded to the set of dates they cover. + promo_rows = ( + await db.execute( + select(Promotion.start_date, Promotion.end_date).where( + Promotion.product_id == product_id, + (Promotion.store_id == store_id) | (Promotion.store_id.is_(None)), + Promotion.start_date <= end_date, + Promotion.end_date >= start_date, + ) + ) + ).all() + promo_dates: set[date_type] = set() + for promo in promo_rows: + day = max(promo.start_date, start_date) + last = min(promo.end_date, end_date) + while day <= last: + promo_dates.add(day) + day += timedelta(days=1) + + launch_date: date_type | None = await db.scalar( + select(Product.launch_date).where(Product.id == product_id) + ) + + feature_rows: list[list[float]] = [] + for index, day in enumerate(dates): + row_values: list[float] = [] + # Target long-lag columns — read only strictly-earlier observations. + for lag in _REGRESSION_LAGS: + row_values.append(quantities[index - lag] if index >= lag else math.nan) + # Calendar columns — pure functions of the date. 
+ dow = day.weekday() + row_values.append(math.sin(2.0 * math.pi * dow / 7.0)) + row_values.append(math.cos(2.0 * math.pi * dow / 7.0)) + row_values.append(math.sin(2.0 * math.pi * day.month / 12.0)) + row_values.append(math.cos(2.0 * math.pi * day.month / 12.0)) + row_values.append(1.0 if dow >= 5 else 0.0) + row_values.append(1.0 if (day + timedelta(days=1)).month != day.month else 0.0) + # Exogenous columns — same-day attributes. + row_values.append(prices[index] / baseline_price) + row_values.append(1.0 if day in promo_dates else 0.0) + row_values.append(1.0 if day in holiday_dates else 0.0) + row_values.append( + float((day - launch_date).days) if launch_date is not None else math.nan + ) + feature_rows.append(row_values) + + tail = quantities[-_REGRESSION_HISTORY_TAIL_DAYS:] + tail_dates = [day.isoformat() for day in dates[-_REGRESSION_HISTORY_TAIL_DAYS:]] + + logger.info( + "forecasting.regression_features_built", + store_id=store_id, + product_id=product_id, + n_observations=len(dates), + n_features=len(_REGRESSION_FEATURE_COLUMNS), + ) + + return RegressionFeatureMatrix( + X=np.array(feature_rows, dtype=np.float64), + y=np.array(quantities, dtype=np.float64), + feature_columns=list(_REGRESSION_FEATURE_COLUMNS), + history_tail=[float(value) for value in tail], + history_tail_dates=tail_dates, + launch_date_iso=launch_date.isoformat() if launch_date is not None else None, + n_observations=len(dates), + ) diff --git a/app/features/forecasting/tests/test_regression_forecaster.py b/app/features/forecasting/tests/test_regression_forecaster.py new file mode 100644 index 00000000..22caae79 --- /dev/null +++ b/app/features/forecasting/tests/test_regression_forecaster.py @@ -0,0 +1,116 @@ +"""Unit tests for ``RegressionForecaster`` (PRP-27 Phase B). 
+ +The regression forecaster is the first model that *consumes* the exogenous +``X`` argument, so these tests focus on the new contract: ``X`` is required, +its shape is validated, fits are deterministic, and ``NaN`` features are +tolerated (the future feature frame deliberately emits ``NaN`` cells). +""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + +from app.features.forecasting.models import RegressionForecaster, model_factory +from app.features.forecasting.schemas import RegressionModelConfig + +FloatArray = np.ndarray[Any, np.dtype[np.floating[Any]]] + + +def _synthetic_data( + n: int = 120, n_features: int = 6, seed: int = 0 +) -> tuple[FloatArray, FloatArray]: + """Build a synthetic feature matrix and a target that depends on it.""" + rng = np.random.default_rng(seed) + features = rng.normal(size=(n, n_features)) + target = 50.0 + 5.0 * features[:, 0] - 3.0 * features[:, 1] + rng.normal(scale=0.5, size=n) + return features.astype(np.float64), target.astype(np.float64) + + +def test_fit_predict_roundtrip() -> None: + """A fitted regression model produces a finite forecast of horizon length.""" + features, target = _synthetic_data() + model = RegressionForecaster() + model.fit(target, features) + assert model.is_fitted + + horizon = 10 + predictions = model.predict(horizon, features[:horizon]) + assert predictions.shape == (horizon,) + assert bool(np.all(np.isfinite(predictions))) + + +def test_fit_rejects_none_features() -> None: + """``fit`` raises when no exogenous features are supplied.""" + _, target = _synthetic_data() + with pytest.raises(ValueError, match="requires exogenous features"): + RegressionForecaster().fit(target, None) + + +def test_fit_rejects_mismatched_rows() -> None: + """``fit`` raises when feature and target row counts differ.""" + features, target = _synthetic_data() + with pytest.raises(ValueError, match="rows must match"): + RegressionForecaster().fit(target, features[:-5]) + + 
+def test_predict_rejects_none_features() -> None: + """``predict`` raises when no exogenous features are supplied.""" + features, target = _synthetic_data() + model = RegressionForecaster().fit(target, features) + with pytest.raises(ValueError, match="requires exogenous features"): + model.predict(5, None) + + +def test_predict_rejects_wrong_shape_features() -> None: + """``predict`` raises when the feature row count is not the horizon.""" + features, target = _synthetic_data() + model = RegressionForecaster().fit(target, features) + with pytest.raises(ValueError, match="horizon"): + model.predict(5, features[:8]) + + +def test_predict_before_fit_raises() -> None: + """``predict`` raises a RuntimeError before the model is fitted.""" + model = RegressionForecaster() + with pytest.raises(RuntimeError, match="fitted"): + model.predict(5, np.zeros((5, 3), dtype=np.float64)) + + +def test_determinism_same_random_state() -> None: + """Two fits with the same random_state yield identical forecasts.""" + features, target = _synthetic_data() + future = features[:12] + first = RegressionForecaster(random_state=7).fit(target, features) + second = RegressionForecaster(random_state=7).fit(target, features) + np.testing.assert_array_equal(first.predict(12, future), second.predict(12, future)) + + +def test_handles_nan_features() -> None: + """``HistGradientBoostingRegressor`` tolerates NaN feature cells natively.""" + features, target = _synthetic_data() + model = RegressionForecaster().fit(target, features) + future = features[:6].copy() + future[2, 0] = np.nan # the future frame emits NaN for un-resolvable lags + predictions = model.predict(6, future) + assert bool(np.all(np.isfinite(predictions))) + + +def test_get_and_set_params() -> None: + """``get_params`` reflects construction; ``set_params`` mutates in place.""" + model = RegressionForecaster(max_iter=150, learning_rate=0.03, max_depth=4) + params = model.get_params() + assert params["max_iter"] == 150 + assert 
params["learning_rate"] == 0.03 + assert params["max_depth"] == 4 + model.set_params(max_depth=9) + assert model.max_depth == 9 + + +def test_model_factory_creates_regression_forecaster() -> None: + """``model_factory`` dispatches a RegressionModelConfig to the right class.""" + model = model_factory(RegressionModelConfig(max_iter=120), random_state=42) + assert isinstance(model, RegressionForecaster) + assert model.max_iter == 120 From 69b707ab33318336be483336c0feb8d4c6b36d98 Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 10:26:13 +0200 Subject: [PATCH 3/9] feat(api,db): add model-driven scenario simulation path (#223) --- ...47f5739d7d0_widen_scenario_method_check.py | 43 +++++ app/features/scenarios/models.py | 13 +- app/features/scenarios/schemas.py | 9 +- app/features/scenarios/service.py | 166 +++++++++++++++++- app/features/scenarios/tests/conftest.py | 57 +++++- .../tests/test_routes_integration.py | 100 ++++++++++- 6 files changed, 374 insertions(+), 14 deletions(-) create mode 100644 alembic/versions/e47f5739d7d0_widen_scenario_method_check.py diff --git a/alembic/versions/e47f5739d7d0_widen_scenario_method_check.py b/alembic/versions/e47f5739d7d0_widen_scenario_method_check.py new file mode 100644 index 00000000..7b861cde --- /dev/null +++ b/alembic/versions/e47f5739d7d0_widen_scenario_method_check.py @@ -0,0 +1,43 @@ +"""widen scenario method check + +Revision ID: e47f5739d7d0 +Revises: 43e35957a248 +Create Date: 2026-05-19 10:06:15.179816 + +PRP-27 Phase B — widens the ``scenario_plan.method`` CHECK constraint so a +model-driven simulation can persist ``method='model_exogenous'`` alongside the +MVP's ``'heuristic'``. Forward-only: never edits the merged migration that +created the table. +""" +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = 'e47f5739d7d0' +down_revision: Union[str, None] = '43e35957a248' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +_CONSTRAINT = "ck_scenario_plan_method" +_TABLE = "scenario_plan" + + +def upgrade() -> None: + """Allow method IN ('heuristic', 'model_exogenous').""" + op.drop_constraint(_CONSTRAINT, _TABLE, type_="check") + op.create_check_constraint( + _CONSTRAINT, + _TABLE, + "method IN ('heuristic', 'model_exogenous')", + ) + + +def downgrade() -> None: + """Revert to method IN ('heuristic') only.""" + op.drop_constraint(_CONSTRAINT, _TABLE, type_="check") + op.create_check_constraint( + _CONSTRAINT, + _TABLE, + "method IN ('heuristic')", + ) diff --git a/app/features/scenarios/models.py b/app/features/scenarios/models.py index 66e40c2d..5c0b333d 100644 --- a/app/features/scenarios/models.py +++ b/app/features/scenarios/models.py @@ -21,8 +21,11 @@ from app.core.database import Base from app.shared.models import TimestampMixin -# The only adjustment method the MVP produces — guarded by a CHECK constraint. +# Adjustment methods — guarded by a CHECK constraint. ``heuristic`` is the MVP +# post-forecast multiplier; ``model_exogenous`` (PRP-27) re-forecasts through a +# feature-consuming regression model. SCENARIO_METHOD_HEURISTIC = "heuristic" +SCENARIO_METHOD_MODEL_EXOGENOUS = "model_exogenous" class ScenarioPlan(TimestampMixin, Base): @@ -65,6 +68,10 @@ class ScenarioPlan(TimestampMixin, Base): Index("ix_scenario_plan_comparison_gin", "comparison", postgresql_using="gin"), # Composite index for the common "plans for this store/product" query. Index("ix_scenario_plan_store_product", "store_id", "product_id"), - # The MVP only ever produces heuristic comparisons. - CheckConstraint("method IN ('heuristic')", name="ck_scenario_plan_method"), + # heuristic (MVP) or model_exogenous (PRP-27) — kept in lock-step with + # the alembic migration that widened this CHECK. 
+ CheckConstraint( + "method IN ('heuristic', 'model_exogenous')", + name="ck_scenario_plan_method", + ), ) diff --git a/app/features/scenarios/schemas.py b/app/features/scenarios/schemas.py index d5bc0d50..fea90ee9 100644 --- a/app/features/scenarios/schemas.py +++ b/app/features/scenarios/schemas.py @@ -255,14 +255,15 @@ class ScenarioComparison(BaseModel): description="covered / at_risk / stockout, or unknown when no inventory " "assumption was supplied.", ) - method: Literal["heuristic"] = Field( + method: Literal["heuristic", "model_exogenous"] = Field( ..., - description="Always 'heuristic' — the result is a deterministic post-forecast " - "multiplier, not a re-trained causal model.", + description="How the scenario was produced: 'heuristic' (a deterministic " + "post-forecast multiplier) or 'model_exogenous' (a re-forecast through a " + "feature-consuming regression model).", ) disclaimer: str = Field( ..., - description="Plain-language caveat that the numbers are heuristic estimates.", + description="Plain-language caveat appropriate to the method that produced the comparison.", ) generated_at: datetime = Field(..., description="When the comparison was computed (UTC).") diff --git a/app/features/scenarios/service.py b/app/features/scenarios/service.py index 6fb19ec6..901647ee 100644 --- a/app/features/scenarios/service.py +++ b/app/features/scenarios/service.py @@ -22,7 +22,9 @@ import uuid from datetime import UTC, date, datetime, timedelta from pathlib import Path +from typing import cast +import numpy as np from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession @@ -31,7 +33,8 @@ from app.features.data_platform.models import SalesDaily from app.features.forecasting.persistence import ModelBundle, load_model_bundle from app.features.scenarios import adjustments -from app.features.scenarios.models import SCENARIO_METHOD_HEURISTIC, ScenarioPlan +from app.features.scenarios.feature_frame import build_future_frame +from 
app.features.scenarios.models import ScenarioPlan from app.features.scenarios.schemas import ( CreateScenarioRequest, ScenarioAssumptions, @@ -54,6 +57,14 @@ "precise predictions." ) +# Caveat for the model-driven path — a re-forecast IS model-causal, but a model +# estimate is still an estimate (NIST-AI-RMF transparency control). +MODEL_EXOGENOUS_DISCLAIMER = ( + "Model estimate: this scenario re-forecasts demand through a feature-driven " + "model using the assumptions as future inputs. It reflects learned patterns " + "but remains an estimate under uncertainty — not a guarantee." +) + # Fallback unit price when a (store, product) has no sales history. DEFAULT_UNIT_PRICE = 1.0 @@ -92,6 +103,12 @@ async def simulate( store_id = int(str(store_id_raw)) product_id = int(str(product_id_raw)) + # A regression baseline answers the what-if by genuinely re-forecasting + # through the future feature frame; every other model type uses the + # deterministic heuristic multiplier below (PRP-27 DECISIONS LOCKED #1). + if bundle.config.model_type == "regression": + return await self._simulate_model_exogenous(db, request, bundle, store_id, product_id) + # Replicate the ForecastingService.predict body (DECISIONS LOCKED #2). raw_forecast = bundle.model.predict(request.horizon) baseline_values = [float(value) for value in raw_forecast] @@ -159,6 +176,149 @@ async def simulate( generated_at=datetime.now(UTC), ) + async def _simulate_model_exogenous( + self, + db: AsyncSession, + request: SimulateScenarioRequest, + bundle: ModelBundle, + store_id: int, + product_id: int, + ) -> ScenarioComparison: + """Re-forecast a regression baseline through the future feature frame. + + Builds two leakage-safe future frames — one carrying the scenario + assumptions, one with none — feeds both to the model, and compares the + re-forecasts. Unlike the heuristic path the deltas come from the model + itself, so the result is stamped ``method="model_exogenous"``. + + Args: + db: Database session. 
+ request: The baseline ``run_id``, horizon, and assumptions. + bundle: The already-loaded regression model bundle. + store_id: Store the baseline model targets. + product_id: Product the baseline model targets. + + Returns: + A model-driven baseline-vs-scenario comparison. + + Raises: + ValueError: When the bundle lacks the feature metadata a scenario + forecast needs (an older artifact trained before PRP-27). + """ + feature_columns_raw = bundle.metadata.get("feature_columns") + history_tail_raw = bundle.metadata.get("history_tail") + if not isinstance(feature_columns_raw, list) or not isinstance(history_tail_raw, list): + raise ValueError( + f"Model artifact for run_id '{request.run_id}' is a regression " + "model without the feature metadata a scenario forecast needs — " + "retrain it with the current pipeline." + ) + feature_columns = [str(column) for column in cast("list[str]", feature_columns_raw)] + history_tail = [float(value) for value in cast("list[float]", history_tail_raw)] + + # The forecast origin T is the day before the first forecast day. + origin = self._forecast_start_date(bundle.metadata.get("train_end_date")) - timedelta( + days=1 + ) + launch_raw = bundle.metadata.get("launch_date") + launch_date = date.fromisoformat(launch_raw) if isinstance(launch_raw, str) else None + + scenario_frame = await build_future_frame( + db, + store_id=store_id, + product_id=product_id, + forecast_origin=origin, + horizon=request.horizon, + feature_columns=feature_columns, + history_tail=history_tail, + assumptions=request.assumptions, + launch_date=launch_date, + ) + # The baseline is the SAME frame with the assumptions stripped. 
+ baseline_frame = await build_future_frame( + db, + store_id=store_id, + product_id=product_id, + forecast_origin=origin, + horizon=request.horizon, + feature_columns=feature_columns, + history_tail=history_tail, + assumptions=ScenarioAssumptions(), + launch_date=launch_date, + ) + + scenario_x = np.array(scenario_frame.matrix, dtype=np.float64) + baseline_x = np.array(baseline_frame.matrix, dtype=np.float64) + # Demand can never be negative — floor the model output at 0. + scenario_values = [ + max(0.0, float(value)) for value in bundle.model.predict(request.horizon, scenario_x) + ] + baseline_values = [ + max(0.0, float(value)) for value in bundle.model.predict(request.horizon, baseline_x) + ] + + points = [ + ScenarioPoint( + date=scenario_frame.dates[offset], + baseline=baseline_values[offset], + scenario=scenario_values[offset], + delta=scenario_values[offset] - baseline_values[offset], + # The realised per-day multiplier the model implied (1.0 == no + # change); guards a zero-baseline day. 
+ applied_factor=( + scenario_values[offset] / baseline_values[offset] + if baseline_values[offset] > 0.0 + else 1.0 + ), + ) + for offset in range(request.horizon) + ] + + baseline_total = sum(baseline_values) + scenario_total = sum(scenario_values) + units_delta = scenario_total - baseline_total + units_delta_pct = (units_delta / baseline_total * 100.0) if baseline_total > 0 else 0.0 + + unit_price = await self._latest_unit_price(db, store_id, product_id) + baseline_revenue = baseline_total * unit_price + scenario_revenue = scenario_total * unit_price + + inventory = request.assumptions.inventory + on_hand = inventory.on_hand_units if inventory is not None else None + verdict = adjustments.coverage_verdict(scenario_total, on_hand) + + logger.info( + "scenarios.simulated", + run_id=request.run_id, + store_id=store_id, + product_id=product_id, + horizon=request.horizon, + model_type=bundle.config.model_type, + method="model_exogenous", + units_delta=round(units_delta, 4), + coverage_verdict=verdict, + ) + + return ScenarioComparison( + store_id=store_id, + product_id=product_id, + model_type=bundle.config.model_type, + horizon=request.horizon, + points=points, + baseline_total_units=baseline_total, + scenario_total_units=scenario_total, + units_delta=units_delta, + units_delta_pct=units_delta_pct, + unit_price_used=unit_price, + baseline_revenue=baseline_revenue, + scenario_revenue=scenario_revenue, + revenue_delta=scenario_revenue - baseline_revenue, + coverage_verdict=verdict, + method="model_exogenous", + disclaimer=MODEL_EXOGENOUS_DISCLAIMER, + generated_at=datetime.now(UTC), + ) + # -- Persistence ------------------------------------------------------- async def create_plan( @@ -197,7 +357,9 @@ async def create_plan( # JSONB cannot store Python date/datetime — dump in JSON mode. 
assumptions=request.assumptions.model_dump(mode="json"), comparison=comparison.model_dump(mode="json"), - method=SCENARIO_METHOD_HEURISTIC, + # heuristic or model_exogenous — taken from the comparison the + # baseline model actually produced. + method=comparison.method, ) db.add(plan) await db.commit() diff --git a/app/features/scenarios/tests/conftest.py b/app/features/scenarios/tests/conftest.py index d899ff2c..b1cc6eeb 100644 --- a/app/features/scenarios/tests/conftest.py +++ b/app/features/scenarios/tests/conftest.py @@ -11,6 +11,7 @@ import uuid from collections.abc import AsyncGenerator, Generator +from datetime import date, timedelta from pathlib import Path import numpy as np @@ -21,9 +22,10 @@ from app.core.config import get_settings from app.core.database import get_db -from app.features.forecasting.models import NaiveForecaster +from app.features.forecasting.models import NaiveForecaster, RegressionForecaster from app.features.forecasting.persistence import ModelBundle, save_model_bundle -from app.features.forecasting.schemas import NaiveModelConfig +from app.features.forecasting.schemas import NaiveModelConfig, RegressionModelConfig +from app.features.scenarios.feature_frame import canonical_feature_columns from app.features.scenarios.models import ScenarioPlan from app.main import app @@ -97,3 +99,54 @@ def trained_model() -> Generator[str, None, None]: yield run_id (artifacts_dir / f"model_{run_id}.joblib").unlink(missing_ok=True) + + +@pytest.fixture +def trained_regression_model() -> Generator[str, None, None]: + """Save a real fitted ``RegressionForecaster`` bundle on disk; yield run_id. + + The bundle carries the full PRP-27 metadata contract — ``feature_columns``, + ``history_tail``, ``launch_date`` — so the model-exogenous simulate path can + build a future feature frame and genuinely re-forecast. Demand is wired to + respond negatively to ``price_factor`` so a price cut lifts the forecast. 
+ """ + settings = get_settings() + artifacts_dir = Path(settings.forecast_model_artifacts_dir) + artifacts_dir.mkdir(parents=True, exist_ok=True) + + run_id = uuid.uuid4().hex[:12] + columns = canonical_feature_columns() + rng = np.random.default_rng(7) + n_rows = 200 + features = rng.normal(size=(n_rows, len(columns))) + # Strong, negative price_factor coefficient: price_factor < 1.0 (a cut) + # lifts demand. The signal dwarfs the 0.5-scale noise, so the model learns + # a clean, deterministic price response. + price_index = columns.index("price_factor") + target = 40.0 - 20.0 * features[:, price_index] + rng.normal(scale=0.5, size=n_rows) + + model = RegressionForecaster(random_state=7) + model.fit(target.astype(np.float64), features.astype(np.float64)) + + history_start = date(2026, 4, 1) + bundle = ModelBundle( + model=model, + config=RegressionModelConfig(), + metadata={ + "store_id": TEST_STORE_ID, + "product_id": TEST_PRODUCT_ID, + "train_end_date": TEST_TRAIN_END_DATE, + "n_observations": n_rows, + "feature_columns": columns, + "history_tail": [12.0] * 90, + "history_tail_dates": [ + (history_start + timedelta(days=offset)).isoformat() for offset in range(90) + ], + "launch_date": "2025-01-01", + }, + ) + save_model_bundle(bundle, artifacts_dir / f"model_{run_id}") + + yield run_id + + (artifacts_dir / f"model_{run_id}.joblib").unlink(missing_ok=True) diff --git a/app/features/scenarios/tests/test_routes_integration.py b/app/features/scenarios/tests/test_routes_integration.py index dd9b7e8d..f1d16976 100644 --- a/app/features/scenarios/tests/test_routes_integration.py +++ b/app/features/scenarios/tests/test_routes_integration.py @@ -142,13 +142,89 @@ async def test_delete_missing_plan_returns_404(self, client: AsyncClient) -> Non assert response.status_code == 404 +@pytest.mark.integration +@pytest.mark.asyncio +class TestSimulateModelExogenous: + """Integration tests for the model-driven (regression) simulate path.""" + + async def 
test_regression_baseline_returns_model_exogenous( + self, client: AsyncClient, trained_regression_model: str + ) -> None: + """A regression baseline re-forecasts — method is 'model_exogenous'.""" + response = await client.post( + "/scenarios/simulate", + json={ + "run_id": trained_regression_model, + "horizon": 14, + "assumptions": _PRICE_ASSUMPTION, + }, + ) + assert response.status_code == 200 + data = response.json() + + assert data["method"] == "model_exogenous" + assert data["disclaimer"], "every comparison must carry a non-empty disclaimer" + assert len(data["points"]) == 14 + # A price cut moves the re-forecast — the deltas are model-driven, not + # a fixed multiplier, and the modelled price response lifts demand. + assert data["units_delta"] > 0.0 + + async def test_regression_empty_assumptions_equals_baseline( + self, client: AsyncClient, trained_regression_model: str + ) -> None: + """With no assumptions the model re-forecasts to exactly the baseline.""" + response = await client.post( + "/scenarios/simulate", + json={"run_id": trained_regression_model, "horizon": 10, "assumptions": {}}, + ) + assert response.status_code == 200 + data = response.json() + + assert data["method"] == "model_exogenous" + assert data["units_delta"] == 0.0 + for point in data["points"]: + assert point["delta"] == 0.0 + + async def test_baseline_forecaster_still_heuristic( + self, client: AsyncClient, trained_model: str + ) -> None: + """A naive baseline still produces a heuristic comparison — unchanged.""" + response = await client.post( + "/scenarios/simulate", + json={"run_id": trained_model, "horizon": 14, "assumptions": _PRICE_ASSUMPTION}, + ) + assert response.status_code == 200 + assert response.json()["method"] == "heuristic" + + async def test_model_exogenous_plan_persists( + self, client: AsyncClient, trained_regression_model: str + ) -> None: + """A model_exogenous comparison saves cleanly — the widened CHECK accepts it.""" + create = await client.post( + 
"/scenarios", + json={ + "name": "Model-driven price cut", + "run_id": trained_regression_model, + "horizon": 14, + "assumptions": _PRICE_ASSUMPTION, + }, + ) + assert create.status_code == 201 + plan = create.json() + assert plan["method"] == "model_exogenous" + + fetched = await client.get(f"/scenarios/{plan['scenario_id']}") + assert fetched.status_code == 200 + assert fetched.json()["comparison"]["method"] == "model_exogenous" + + @pytest.mark.integration @pytest.mark.asyncio class TestScenarioPlanModel: """Constraint tests for the ScenarioPlan ORM model.""" - async def test_method_check_constraint(self, db_session: AsyncSession) -> None: - """The method CHECK constraint rejects any value other than 'heuristic'.""" + async def test_method_check_rejects_unknown_value(self, db_session: AsyncSession) -> None: + """The method CHECK constraint rejects a value outside the allow-list.""" plan = ScenarioPlan( scenario_id=uuid.uuid4().hex, name="bad method", @@ -158,9 +234,27 @@ async def test_method_check_constraint(self, db_session: AsyncSession) -> None: horizon=7, assumptions={}, comparison={}, - method="not_heuristic", + method="not_a_method", ) db_session.add(plan) with pytest.raises(IntegrityError): await db_session.commit() await db_session.rollback() + + async def test_method_check_accepts_model_exogenous(self, db_session: AsyncSession) -> None: + """The widened CHECK constraint accepts the 'model_exogenous' method.""" + plan = ScenarioPlan( + scenario_id=uuid.uuid4().hex, + name="model exogenous plan", + store_id=1, + product_id=1, + run_id="abc", + horizon=7, + assumptions={}, + comparison={}, + method="model_exogenous", + ) + db_session.add(plan) + await db_session.commit() + await db_session.refresh(plan) + assert plan.method == "model_exogenous" From 75446ee9205f01eb778e8688952b65c12ace04ff Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 10:34:04 +0200 Subject: [PATCH 4/9] feat(api,db): add scenario library and multi-scenario comparison 
(#223) --- ...8c4587ef1d_add_scenario_library_columns.py | 52 +++++++ app/features/scenarios/models.py | 13 +- app/features/scenarios/routes.py | 47 +++++- app/features/scenarios/schemas.py | 84 +++++++++++ app/features/scenarios/service.py | 123 ++++++++++++++-- .../tests/test_compare_integration.py | 139 ++++++++++++++++++ 6 files changed, 445 insertions(+), 13 deletions(-) create mode 100644 alembic/versions/bb8c4587ef1d_add_scenario_library_columns.py create mode 100644 app/features/scenarios/tests/test_compare_integration.py diff --git a/alembic/versions/bb8c4587ef1d_add_scenario_library_columns.py b/alembic/versions/bb8c4587ef1d_add_scenario_library_columns.py new file mode 100644 index 00000000..38e65d10 --- /dev/null +++ b/alembic/versions/bb8c4587ef1d_add_scenario_library_columns.py @@ -0,0 +1,52 @@ +"""add scenario library columns + +Revision ID: bb8c4587ef1d +Revises: e47f5739d7d0 +Create Date: 2026-05-19 10:26:58.473203 + +PRP-27 Phase C — adds the scenario-library columns to ``scenario_plan``: +``tags`` (a JSONB string array, queryable via a GIN index) and ``cloned_from`` +(the ``scenario_id`` a plan was cloned from, nullable). Forward-only. +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision: str = 'bb8c4587ef1d' +down_revision: Union[str, None] = 'e47f5739d7d0' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add the tags and cloned_from columns plus a GIN index on tags.""" + op.add_column( + 'scenario_plan', + sa.Column( + 'tags', + postgresql.JSONB(astext_type=sa.Text()), + nullable=False, + server_default=sa.text("'[]'::jsonb"), + ), + ) + op.add_column( + 'scenario_plan', + sa.Column('cloned_from', sa.String(length=32), nullable=True), + ) + op.create_index( + 'ix_scenario_plan_tags_gin', + 'scenario_plan', + ['tags'], + unique=False, + postgresql_using='gin', + ) + + +def downgrade() -> None: + """Drop the GIN index and the scenario-library columns.""" + op.drop_index('ix_scenario_plan_tags_gin', table_name='scenario_plan', postgresql_using='gin') + op.drop_column('scenario_plan', 'cloned_from') + op.drop_column('scenario_plan', 'tags') diff --git a/app/features/scenarios/models.py b/app/features/scenarios/models.py index 5c0b333d..1bab6a0c 100644 --- a/app/features/scenarios/models.py +++ b/app/features/scenarios/models.py @@ -14,7 +14,7 @@ from typing import Any -from sqlalchemy import CheckConstraint, Index, Integer, String +from sqlalchemy import CheckConstraint, Index, Integer, String, text from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column @@ -62,12 +62,23 @@ class ScenarioPlan(TimestampMixin, Base): String(20), nullable=False, default=SCENARIO_METHOD_HEURISTIC ) + # Scenario-library columns (PRP-27 Phase C). ``tags`` is a JSONB string + # array — a real column (never folded into a JSONB blob) so it is + # queryable/indexable. ``cloned_from`` is the scenario_id this plan was + # cloned from, or NULL. 
+ tags: Mapped[list[str]] = mapped_column( + JSONB, nullable=False, default=list, server_default=text("'[]'::jsonb") + ) + cloned_from: Mapped[str | None] = mapped_column(String(32), nullable=True) + __table_args__ = ( # GIN indexes for JSONB containment queries on either blob. Index("ix_scenario_plan_assumptions_gin", "assumptions", postgresql_using="gin"), Index("ix_scenario_plan_comparison_gin", "comparison", postgresql_using="gin"), # Composite index for the common "plans for this store/product" query. Index("ix_scenario_plan_store_product", "store_id", "product_id"), + # GIN index so the saved-plans list can filter by tag containment. + Index("ix_scenario_plan_tags_gin", "tags", postgresql_using="gin"), # heuristic (MVP) or model_exogenous (PRP-27) — kept in lock-step with # the alembic migration that widened this CHECK. CheckConstraint( diff --git a/app/features/scenarios/routes.py b/app/features/scenarios/routes.py index 3e0bf46e..6ce4c1ec 100644 --- a/app/features/scenarios/routes.py +++ b/app/features/scenarios/routes.py @@ -16,7 +16,9 @@ from app.core.exceptions import BadRequestError, DatabaseError, NotFoundError from app.core.logging import get_logger from app.features.scenarios.schemas import ( + CompareScenariosRequest, CreateScenarioRequest, + MultiScenarioComparison, ScenarioComparison, ScenarioListResponse, ScenarioPlanResponse, @@ -127,17 +129,57 @@ async def create_scenario( ) from exc +@router.post( + "/compare", + response_model=MultiScenarioComparison, + status_code=status.HTTP_200_OK, + summary="Compare saved scenario plans", + description=""" +Rank 2-5 saved scenario plans against a shared baseline. + +Each saved plan embeds its own comparison snapshot, so this is a pure +aggregation — no model artifact is reloaded. An unknown `scenario_id` returns +a 404 problem response. 
+""", +) +async def compare_scenarios( + request: CompareScenariosRequest, + db: AsyncSession = Depends(get_db), +) -> MultiScenarioComparison: + """Compare 2-5 saved scenario plans. + + Args: + request: The 2-5 scenario_ids and the ranking metric. + db: Async database session from dependency. + + Returns: + A ranked multi-scenario comparison plus merged chart series. + + Raises: + NotFoundError: When any requested scenario_id does not exist. + """ + try: + return await ScenarioService().compare_scenarios(db, request) + except FileNotFoundError as exc: + logger.warning("scenarios.compare_not_found", error=str(exc)) + raise NotFoundError(message=str(exc)) from exc + + @router.get( "", response_model=ScenarioListResponse, summary="List saved scenario plans", description="List saved scenario plans, newest first. Returns 200 + an " - "empty list when no plans exist.", + "empty list when no plans exist. Pass one or more `tags` to filter to " + "plans carrying every listed tag.", ) async def list_scenarios( db: AsyncSession = Depends(get_db), limit: int = Query(default=20, ge=1, le=100, description="Maximum plans to return."), offset: int = Query(default=0, ge=0, description="Number of plans to skip."), + tags: list[str] | None = Query( + default=None, description="Filter to plans carrying every listed tag." + ), ) -> ScenarioListResponse: """List saved scenario plans. @@ -145,11 +187,12 @@ async def list_scenarios( db: Async database session from dependency. limit: Maximum plans to return (1-100). offset: Number of plans to skip. + tags: Optional library tags to filter by. Returns: A page of saved plans plus the total count. 
""" - return await ScenarioService().list_plans(db, limit=limit, offset=offset) + return await ScenarioService().list_plans(db, limit=limit, offset=offset, tags=tags) @router.get( diff --git a/app/features/scenarios/schemas.py b/app/features/scenarios/schemas.py index fea90ee9..be696e1b 100644 --- a/app/features/scenarios/schemas.py +++ b/app/features/scenarios/schemas.py @@ -200,6 +200,16 @@ class CreateScenarioRequest(BaseModel): default_factory=ScenarioAssumptions, description="What-if assumptions for this plan.", ) + tags: list[str] = Field( + default_factory=list, + max_length=20, + description="Optional library tags for filtering and grouping saved plans.", + ) + cloned_from: str | None = Field( + default=None, + max_length=32, + description="scenario_id this plan was cloned from, when it originated as a clone.", + ) # ============================================================================= @@ -287,6 +297,12 @@ class ScenarioPlanResponse(BaseModel): comparison: ScenarioComparison = Field( ..., description="The full baseline-vs-scenario snapshot, re-rendered without recompute." ) + tags: list[str] = Field( + default_factory=list, description="Library tags attached to the plan." + ) + cloned_from: str | None = Field( + default=None, description="scenario_id this plan was cloned from, if any." + ) class ScenarioListItem(BaseModel): @@ -302,6 +318,9 @@ class ScenarioListItem(BaseModel): units_delta: float = Field(..., description="Summed scenario-minus-baseline demand.") revenue_delta: float = Field(..., description="Scenario-minus-baseline revenue.") created_at: datetime = Field(..., description="When the plan was saved (UTC).") + tags: list[str] = Field( + default_factory=list, description="Library tags attached to the plan." + ) class ScenarioListResponse(BaseModel): @@ -313,3 +332,68 @@ class ScenarioListResponse(BaseModel): ..., description="Saved plans for the current page; empty when none exist." 
) total: int = Field(..., ge=0, description="Total saved plans matching the query.") + + +# ============================================================================= +# Multi-scenario comparison (PRP-27 Phase C) +# ============================================================================= + +# Metric a multi-scenario comparison ranks by. +RankBy = Literal["revenue_delta", "units_delta"] + + +class CompareScenariosRequest(BaseModel): + """Request body for ``POST /scenarios/compare``. + + The 2..5 bound keeps the multi-series chart legible — the upper bound is + the pinned ``MAX_COMPARE_SCENARIOS`` (PRP-27 DECISIONS LOCKED #12); the + literal ``5`` must stay in sync with that constant in ``feature_frame.py``. + """ + + model_config = ConfigDict(strict=True) + + scenario_ids: list[str] = Field( + ..., + min_length=2, + max_length=5, + description="2-5 saved scenario_ids to compare side by side.", + ) + rank_by: RankBy = Field( + default="revenue_delta", + description="Metric the ranked rows are ordered by (descending).", + ) + + +class ScenarioComparisonRow(BaseModel): + """One saved plan's headline numbers within a multi-scenario comparison.""" + + model_config = ConfigDict(from_attributes=True) + + scenario_id: str = Field(..., description="The plan's external identifier.") + name: str = Field(..., description="The plan's human-readable name.") + units_delta: float = Field(..., description="Scenario-minus-baseline demand for the plan.") + revenue_delta: float = Field(..., description="Scenario-minus-baseline revenue for the plan.") + coverage_verdict: CoverageVerdict = Field(..., description="The plan's coverage verdict.") + rank: int = Field(..., ge=1, description="1-based rank by the chosen metric (1 == best).") + + +class MultiScenarioComparison(BaseModel): + """A baseline compared against 2-5 saved scenarios, ranked.""" + + model_config = ConfigDict(from_attributes=True) + + baseline_total_units: float = Field( + ..., description="Reference baseline 
demand (from the first compared plan)." + ) + baseline_revenue: float = Field( + ..., description="Reference baseline revenue (from the first compared plan)." + ) + rank_by: RankBy = Field(..., description="Metric the rows are ranked by.") + scenarios: list[ScenarioComparisonRow] = Field( + ..., description="The compared plans, ordered best-first by rank_by." + ) + chart_series: list[dict[str, float | str]] = Field( + ..., + description="Date-keyed merged rows for the multi-series chart — each row " + "carries 'date', 'baseline', and one entry per scenario name.", + ) diff --git a/app/features/scenarios/service.py b/app/features/scenarios/service.py index 901647ee..2f7846ac 100644 --- a/app/features/scenarios/service.py +++ b/app/features/scenarios/service.py @@ -22,7 +22,7 @@ import uuid from datetime import UTC, date, datetime, timedelta from pathlib import Path -from typing import cast +from typing import Any, cast import numpy as np from sqlalchemy import func, select @@ -36,9 +36,12 @@ from app.features.scenarios.feature_frame import build_future_frame from app.features.scenarios.models import ScenarioPlan from app.features.scenarios.schemas import ( + CompareScenariosRequest, CreateScenarioRequest, + MultiScenarioComparison, ScenarioAssumptions, ScenarioComparison, + ScenarioComparisonRow, ScenarioListItem, ScenarioListResponse, ScenarioPlanResponse, @@ -360,6 +363,8 @@ async def create_plan( # heuristic or model_exogenous — taken from the comparison the # baseline model actually produced. 
method=comparison.method, + tags=list(request.tags), + cloned_from=request.cloned_from, ) db.add(plan) await db.commit() @@ -373,36 +378,131 @@ async def create_plan( ) return self._to_plan_response(plan) - async def list_plans(self, db: AsyncSession, limit: int, offset: int) -> ScenarioListResponse: + async def list_plans( + self, + db: AsyncSession, + limit: int, + offset: int, + tags: list[str] | None = None, + ) -> ScenarioListResponse: """List saved scenario plans, newest first. Args: db: Database session. limit: Maximum plans to return. offset: Number of plans to skip. + tags: Optional library tags — when given, only plans carrying every + listed tag are returned. Returns: A page of plan list items plus the total count. """ - total = int(await db.scalar(select(func.count()).select_from(ScenarioPlan)) or 0) + count_stmt = select(func.count()).select_from(ScenarioPlan) + rows_stmt = ( + select(ScenarioPlan) + .order_by(ScenarioPlan.created_at.desc(), ScenarioPlan.id.desc()) + .limit(limit) + .offset(offset) + ) + if tags: + # JSONB @> containment — a plan matches when it carries every tag. + count_stmt = count_stmt.where(ScenarioPlan.tags.contains(tags)) + rows_stmt = rows_stmt.where(ScenarioPlan.tags.contains(tags)) + + total = int(await db.scalar(count_stmt) or 0) + rows = (await db.execute(rows_stmt)).scalars().all() + return ScenarioListResponse( + scenarios=[self._to_list_item(row) for row in rows], + total=total, + ) + async def compare_scenarios( + self, db: AsyncSession, request: CompareScenariosRequest + ) -> MultiScenarioComparison: + """Rank 2-5 saved scenario plans against a shared baseline. + + Each saved plan already embeds its full ``ScenarioComparison`` snapshot, + so the comparison is a pure aggregation — no model artifact is reloaded. + The reference baseline is taken from the first requested plan. + + Args: + db: Database session. + request: The 2-5 ``scenario_id``s and the ranking metric. 
+ + Returns: + A ranked multi-scenario comparison plus merged chart series. + + Raises: + FileNotFoundError: When any requested ``scenario_id`` does not exist. + """ rows = ( ( await db.execute( - select(ScenarioPlan) - .order_by(ScenarioPlan.created_at.desc(), ScenarioPlan.id.desc()) - .limit(limit) - .offset(offset) + select(ScenarioPlan).where(ScenarioPlan.scenario_id.in_(request.scenario_ids)) ) ) .scalars() .all() ) - return ScenarioListResponse( - scenarios=[self._to_list_item(row) for row in rows], - total=total, + found = {row.scenario_id: row for row in rows} + missing = [sid for sid in request.scenario_ids if sid not in found] + if missing: + raise FileNotFoundError(f"Scenario plan(s) not found: {', '.join(missing)}") + + # Preserve the caller's order; the first plan supplies the baseline. + plans = [found[sid] for sid in request.scenario_ids] + first_comparison = plans[0].comparison + baseline_total = float(first_comparison.get("baseline_total_units", 0.0)) + baseline_revenue = float(first_comparison.get("baseline_revenue", 0.0)) + + ranked = sorted( + plans, + key=lambda plan: float(plan.comparison.get(request.rank_by, 0.0)), + reverse=True, + ) + comparison_rows = [ + ScenarioComparisonRow( + scenario_id=plan.scenario_id, + name=plan.name, + units_delta=float(plan.comparison.get("units_delta", 0.0)), + revenue_delta=float(plan.comparison.get("revenue_delta", 0.0)), + coverage_verdict=plan.comparison.get("coverage_verdict", "unknown"), + rank=index + 1, + ) + for index, plan in enumerate(ranked) + ] + + logger.info( + "scenarios.compared", + scenario_count=len(plans), + rank_by=request.rank_by, + ) + return MultiScenarioComparison( + baseline_total_units=baseline_total, + baseline_revenue=baseline_revenue, + rank_by=request.rank_by, + scenarios=comparison_rows, + chart_series=self._build_chart_series(plans), ) + @staticmethod + def _build_chart_series(plans: list[ScenarioPlan]) -> list[dict[str, float | str]]: + """Merge every plan's per-day series into 
date-keyed chart rows. + + Each row carries ``date``, the reference ``baseline`` (from the first + plan), and one entry per plan keyed by the plan name. + """ + by_date: dict[str, dict[str, float | str]] = {} + for plan_index, plan in enumerate(plans): + points = cast("list[dict[str, Any]]", plan.comparison.get("points", [])) + for point in points: + point_date = str(point.get("date", "")) + row = by_date.setdefault(point_date, {"date": point_date}) + if plan_index == 0: + row["baseline"] = float(point.get("baseline", 0.0)) + row[plan.name] = float(point.get("scenario", 0.0)) + return [by_date[key] for key in sorted(by_date)] + async def get_plan(self, db: AsyncSession, scenario_id: str) -> ScenarioPlanResponse | None: """Fetch one saved plan by its external id, or ``None`` when absent.""" plan = await db.scalar(select(ScenarioPlan).where(ScenarioPlan.scenario_id == scenario_id)) @@ -496,6 +596,8 @@ def _to_plan_response(plan: ScenarioPlan) -> ScenarioPlanResponse: created_at=plan.created_at, assumptions=ScenarioAssumptions.model_validate(plan.assumptions), comparison=ScenarioComparison.model_validate(plan.comparison), + tags=list(plan.tags), + cloned_from=plan.cloned_from, ) @staticmethod @@ -510,4 +612,5 @@ def _to_list_item(plan: ScenarioPlan) -> ScenarioListItem: units_delta=float(plan.comparison.get("units_delta", 0.0)), revenue_delta=float(plan.comparison.get("revenue_delta", 0.0)), created_at=plan.created_at, + tags=list(plan.tags), ) diff --git a/app/features/scenarios/tests/test_compare_integration.py b/app/features/scenarios/tests/test_compare_integration.py new file mode 100644 index 00000000..a2da7a61 --- /dev/null +++ b/app/features/scenarios/tests/test_compare_integration.py @@ -0,0 +1,139 @@ +"""Integration tests for the scenario library + multi-scenario comparison. + +PRP-27 Phase C. Runs against a real PostgreSQL database and a real model +bundle on disk. Requires ``docker compose up -d``. 
+""" + +from __future__ import annotations + +from typing import Any + +import pytest +from httpx import AsyncClient + +_PRICE_CUT: dict[str, object] = { + "price": {"change_pct": -0.15, "start_date": "2026-07-01", "end_date": "2026-07-14"}, +} +_PRICE_RISE: dict[str, object] = { + "price": {"change_pct": 0.20, "start_date": "2026-07-01", "end_date": "2026-07-14"}, +} + + +async def _create_plan( + client: AsyncClient, + run_id: str, + name: str, + assumptions: dict[str, object], + *, + tags: list[str] | None = None, + cloned_from: str | None = None, +) -> dict[str, Any]: + """Create a saved scenario plan and return its JSON body.""" + body: dict[str, object] = { + "name": name, + "run_id": run_id, + "horizon": 14, + "assumptions": assumptions, + } + if tags is not None: + body["tags"] = tags + if cloned_from is not None: + body["cloned_from"] = cloned_from + response = await client.post("/scenarios", json=body) + assert response.status_code == 201, response.text + created: dict[str, Any] = response.json() + return created + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestCompareScenarios: + """Integration tests for POST /scenarios/compare.""" + + async def test_compare_ranks_plans(self, client: AsyncClient, trained_model: str) -> None: + """Comparing a price cut and a price rise ranks the cut first.""" + cut = await _create_plan(client, trained_model, "Cut", _PRICE_CUT) + rise = await _create_plan(client, trained_model, "Rise", _PRICE_RISE) + + response = await client.post( + "/scenarios/compare", + json={ + "scenario_ids": [rise["scenario_id"], cut["scenario_id"]], + "rank_by": "revenue_delta", + }, + ) + assert response.status_code == 200 + data = response.json() + + assert data["rank_by"] == "revenue_delta" + assert len(data["scenarios"]) == 2 + # A price cut lifts revenue, so it outranks the price rise. 
+ assert data["scenarios"][0]["name"] == "Cut" + assert data["scenarios"][0]["rank"] == 1 + assert data["scenarios"][1]["rank"] == 2 + assert data["chart_series"], "chart_series must carry merged date rows" + assert "baseline" in data["chart_series"][0] + + async def test_compare_too_few_returns_422(self, client: AsyncClient) -> None: + """Fewer than 2 scenario_ids is rejected at the schema boundary.""" + response = await client.post("/scenarios/compare", json={"scenario_ids": ["only-one"]}) + assert response.status_code == 422 + + async def test_compare_too_many_returns_422(self, client: AsyncClient) -> None: + """More than MAX_COMPARE_SCENARIOS (5) scenario_ids is rejected.""" + response = await client.post( + "/scenarios/compare", + json={"scenario_ids": [f"id-{index}" for index in range(6)]}, + ) + assert response.status_code == 422 + + async def test_compare_bogus_id_returns_404( + self, client: AsyncClient, trained_model: str + ) -> None: + """An unknown scenario_id returns an RFC 7807 404 — never a 500.""" + plan = await _create_plan(client, trained_model, "Real", _PRICE_CUT) + response = await client.post( + "/scenarios/compare", + json={"scenario_ids": [plan["scenario_id"], "does-not-exist-xyz"]}, + ) + assert response.status_code == 404 + assert "application/problem+json" in response.headers.get("content-type", "") + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestScenarioLibrary: + """Integration tests for tag filtering and plan cloning.""" + + async def test_create_with_tags_and_filter( + self, client: AsyncClient, trained_model: str + ) -> None: + """A tag filter returns only plans carrying every listed tag.""" + await _create_plan(client, trained_model, "Tagged", _PRICE_CUT, tags=["q3", "promo"]) + await _create_plan(client, trained_model, "Untagged", _PRICE_CUT, tags=[]) + + response = await client.get("/scenarios", params={"tags": ["q3"]}) + assert response.status_code == 200 + data = response.json() + + names = {item["name"] for item 
in data["scenarios"]} + assert "Tagged" in names + assert "Untagged" not in names + for item in data["scenarios"]: + assert "q3" in item["tags"] + + async def test_clone_records_cloned_from(self, client: AsyncClient, trained_model: str) -> None: + """A plan created with cloned_from records its origin.""" + original = await _create_plan(client, trained_model, "Original", _PRICE_CUT) + clone = await _create_plan( + client, + trained_model, + "Clone of original", + _PRICE_CUT, + cloned_from=original["scenario_id"], + ) + assert clone["cloned_from"] == original["scenario_id"] + + fetched = await client.get(f"/scenarios/{clone['scenario_id']}") + assert fetched.status_code == 200 + assert fetched.json()["cloned_from"] == original["scenario_id"] From 1a1c84af9d11a7b921fa188fb91eec52e0f911fa Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 10:42:21 +0200 Subject: [PATCH 5/9] feat(api): key scenario compare chart series by scenario_id (#223) --- app/features/scenarios/schemas.py | 2 +- app/features/scenarios/service.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/app/features/scenarios/schemas.py b/app/features/scenarios/schemas.py index be696e1b..5a998e0e 100644 --- a/app/features/scenarios/schemas.py +++ b/app/features/scenarios/schemas.py @@ -395,5 +395,5 @@ class MultiScenarioComparison(BaseModel): chart_series: list[dict[str, float | str]] = Field( ..., description="Date-keyed merged rows for the multi-series chart — each row " - "carries 'date', 'baseline', and one entry per scenario name.", + "carries 'date', 'baseline', and one entry per scenario keyed by scenario_id.", ) diff --git a/app/features/scenarios/service.py b/app/features/scenarios/service.py index 2f7846ac..9b95cc68 100644 --- a/app/features/scenarios/service.py +++ b/app/features/scenarios/service.py @@ -490,7 +490,8 @@ def _build_chart_series(plans: list[ScenarioPlan]) -> list[dict[str, float | str """Merge every plan's per-day series into date-keyed chart 
rows. Each row carries ``date``, the reference ``baseline`` (from the first - plan), and one entry per plan keyed by the plan name. + plan), and one entry per plan keyed by its ``scenario_id`` — a + CSS-identifier-safe key, unlike a free-text plan name. """ by_date: dict[str, dict[str, float | str]] = {} for plan_index, plan in enumerate(plans): @@ -500,7 +501,7 @@ def _build_chart_series(plans: list[ScenarioPlan]) -> list[dict[str, float | str row = by_date.setdefault(point_date, {"date": point_date}) if plan_index == 0: row["baseline"] = float(point.get("baseline", 0.0)) - row[plan.name] = float(point.get("scenario", 0.0)) + row[plan.scenario_id] = float(point.get("scenario", 0.0)) return [by_date[key] for key in sorted(by_date)] async def get_plan(self, db: AsyncSession, scenario_id: str) -> ScenarioPlanResponse | None: From b2bd4809e58d5e54089e629e92d89ace485cfea1 Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 10:42:21 +0200 Subject: [PATCH 6/9] feat(ui): add what-if planner multi-scenario comparison view (#223) --- frontend/src/components/charts/index.ts | 1 + .../components/charts/multi-series-chart.tsx | 106 ++++++++++ frontend/src/hooks/use-scenarios.ts | 28 ++- frontend/src/lib/scenario-utils.test.ts | 55 ++++- frontend/src/lib/scenario-utils.ts | 32 ++- frontend/src/pages/visualize/planner.tsx | 200 ++++++++++++++---- frontend/src/types/api.ts | 39 +++- 7 files changed, 416 insertions(+), 45 deletions(-) create mode 100644 frontend/src/components/charts/multi-series-chart.tsx diff --git a/frontend/src/components/charts/index.ts b/frontend/src/components/charts/index.ts index 2e439f7a..6e4231b4 100644 --- a/frontend/src/components/charts/index.ts +++ b/frontend/src/components/charts/index.ts @@ -1,4 +1,5 @@ export * from './kpi-card' export * from './time-series-chart' +export * from './multi-series-chart' export * from './backtest-folds-chart' export * from './revenue-bar-chart' diff --git 
a/frontend/src/components/charts/multi-series-chart.tsx b/frontend/src/components/charts/multi-series-chart.tsx new file mode 100644 index 00000000..33aade53 --- /dev/null +++ b/frontend/src/components/charts/multi-series-chart.tsx @@ -0,0 +1,106 @@ +import { CartesianGrid, ComposedChart, Line, XAxis, YAxis } from 'recharts' +import { + ChartConfig, + ChartContainer, + ChartLegend, + ChartLegendContent, + ChartTooltip, + ChartTooltipContent, +} from '@/components/ui/chart' +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' + +/** One line in a multi-series chart — a row key plus its display label. */ +export interface ChartSeries { + key: string + label: string +} + +interface MultiSeriesChartProps { + title: string + description?: string + /** Date-keyed rows; each carries `xAxisKey` plus a value per series key. */ + data: Record[] + /** The lines to draw — the first is rendered solid, the rest dashed. */ + series: ChartSeries[] + xAxisKey?: string + height?: number + className?: string +} + +// Deterministic palette — the shadcn chart CSS vars cycled across the lines. +const PALETTE = [ + 'var(--chart-1)', + 'var(--chart-2)', + 'var(--chart-3)', + 'var(--chart-4)', + 'var(--chart-5)', +] + +/** + * Renders M+1 demand lines on one chart — a shared baseline plus one line per + * scenario — for the What-If Planner's multi-scenario comparison view. + * + * Every series key MUST be a valid CSS identifier (the shadcn `ChartContainer` + * emits a `--color-`-prefixed CSS var per key); callers key scenario lines by the + * CSS-safe `scenario_id`, never by a free-text plan name. 
+ */ +export function MultiSeriesChart({ + title, + description, + data, + series, + xAxisKey = 'date', + height = 320, + className, +}: MultiSeriesChartProps) { + const chartConfig: ChartConfig = {} + series.forEach((line, index) => { + chartConfig[line.key] = { + label: line.label, + color: PALETTE[index % PALETTE.length], + } + }) + + return ( + + + {title} + {description && {description}} + + + {/* Height via inline style — a dynamic `h-[${height}px]` class is not + statically discoverable by the Tailwind JIT compiler. */} + + + + { + const date = new Date(value) + return date.toLocaleDateString('en-US', { month: 'short', day: 'numeric' }) + }} + /> + + } /> + } /> + {series.map((line, index) => ( + + ))} + + + + + ) +} diff --git a/frontend/src/hooks/use-scenarios.ts b/frontend/src/hooks/use-scenarios.ts index cd68335d..198165cc 100644 --- a/frontend/src/hooks/use-scenarios.ts +++ b/frontend/src/hooks/use-scenarios.ts @@ -1,7 +1,9 @@ import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query' import { api } from '@/lib/api' import type { + CompareScenariosRequest, CreateScenarioRequest, + MultiScenarioComparison, ScenarioComparison, ScenarioListResponse, ScenarioPlanResponse, @@ -19,11 +21,18 @@ export function useSimulateScenario() { }) } -/** List saved scenario plans, newest first. */ -export function useScenarios(enabled = true) { +/** + * List saved scenario plans, newest first. Pass one or more `tags` to filter + * to plans carrying every listed tag. + */ +export function useScenarios(tags: string[] = [], enabled = true) { + const query = + tags.length > 0 + ? `?${tags.map((tag) => `tags=${encodeURIComponent(tag)}`).join('&')}` + : '' return useQuery({ - queryKey: ['scenarios'], - queryFn: () => api('/scenarios'), + queryKey: ['scenarios', { tags }], + queryFn: () => api(`/scenarios${query}`), enabled, }) } @@ -60,3 +69,14 @@ export function useDeleteScenario() { }, }) } + +/** + * Compare 2-5 saved scenario plans. 
A mutation, not a query — the comparison + * is an explicit user action and the result is held in component state. + */ +export function useCompareScenarios() { + return useMutation({ + mutationFn: (data: CompareScenariosRequest) => + api('/scenarios/compare', { method: 'POST', body: data }), + }) +} diff --git a/frontend/src/lib/scenario-utils.test.ts b/frontend/src/lib/scenario-utils.test.ts index e5f5cdf5..b091470a 100644 --- a/frontend/src/lib/scenario-utils.test.ts +++ b/frontend/src/lib/scenario-utils.test.ts @@ -1,12 +1,19 @@ import { describe, expect, it } from 'vitest' import { + buildMultiSeries, coverageLabel, coverageVariant, formatDelta, mergeComparisonSeries, + methodLabel, summariseAssumptions, } from './scenario-utils' -import type { ScenarioAssumptions, ScenarioPoint } from '@/types/api' +import type { + MultiScenarioComparison, + ScenarioAssumptions, + ScenarioComparisonRow, + ScenarioPoint, +} from '@/types/api' function makePoint(date: string, baseline: number, scenario: number): ScenarioPoint { return { @@ -90,3 +97,49 @@ describe('summariseAssumptions', () => { expect(lines[4]).toContain('growth') }) }) + +function makeRow(scenarioId: string, name: string, rank: number): ScenarioComparisonRow { + return { + scenario_id: scenarioId, + name, + units_delta: 0, + revenue_delta: 0, + coverage_verdict: 'unknown', + rank, + } +} + +describe('buildMultiSeries', () => { + it('puts baseline first, then one line per scenario keyed by scenario_id', () => { + const comparison: MultiScenarioComparison = { + baseline_total_units: 100, + baseline_revenue: 1000, + rank_by: 'revenue_delta', + scenarios: [makeRow('sid-a', 'Cut', 1), makeRow('sid-b', 'Rise', 2)], + chart_series: [], + } + expect(buildMultiSeries(comparison)).toEqual([ + { key: 'baseline', label: 'Baseline' }, + { key: 'sid-a', label: 'Cut' }, + { key: 'sid-b', label: 'Rise' }, + ]) + }) + + it('returns only the baseline line when there are no scenarios', () => { + const comparison: 
MultiScenarioComparison = { + baseline_total_units: 0, + baseline_revenue: 0, + rank_by: 'units_delta', + scenarios: [], + chart_series: [], + } + expect(buildMultiSeries(comparison)).toEqual([{ key: 'baseline', label: 'Baseline' }]) + }) +}) + +describe('methodLabel', () => { + it('labels each comparison method', () => { + expect(methodLabel('heuristic')).toBe('Heuristic') + expect(methodLabel('model_exogenous')).toBe('Model-driven') + }) +}) diff --git a/frontend/src/lib/scenario-utils.ts b/frontend/src/lib/scenario-utils.ts index 3bb6e880..bbd7b7ed 100644 --- a/frontend/src/lib/scenario-utils.ts +++ b/frontend/src/lib/scenario-utils.ts @@ -6,7 +6,12 @@ * `ScenarioComparison` into chart rows, a delta table, and readable summaries. */ import type { CsvColumn } from '@/lib/csv-export' -import type { CoverageVerdict, ScenarioAssumptions, ScenarioPoint } from '@/types/api' +import type { + CoverageVerdict, + MultiScenarioComparison, + ScenarioAssumptions, + ScenarioPoint, +} from '@/types/api' /** One charted day: a date plus the baseline and scenario demand values. */ export interface ComparisonChartRow { @@ -105,3 +110,28 @@ export function summariseAssumptions(assumptions: ScenarioAssumptions): string[] } return lines } + +/** A line in the multi-scenario comparison chart. */ +export interface MultiSeriesLine { + key: string + label: string +} + +/** + * Derive the chart lines for a multi-scenario comparison: the shared baseline + * first, then one line per scenario keyed by `scenario_id` — matching the keys + * the backend put in each `chart_series` row (a CSS-identifier-safe key) — and + * labelled by the plan name. 
+ */ +export function buildMultiSeries(comparison: MultiScenarioComparison): MultiSeriesLine[] { + const lines: MultiSeriesLine[] = [{ key: 'baseline', label: 'Baseline' }] + for (const row of comparison.scenarios) { + lines.push({ key: row.scenario_id, label: row.name }) + } + return lines +} + +/** Human label for the method that produced a comparison. */ +export function methodLabel(method: 'heuristic' | 'model_exogenous'): string { + return method === 'model_exogenous' ? 'Model-driven' : 'Heuristic' +} diff --git a/frontend/src/pages/visualize/planner.tsx b/frontend/src/pages/visualize/planner.tsx index ed1c2515..24aa3194 100644 --- a/frontend/src/pages/visualize/planner.tsx +++ b/frontend/src/pages/visualize/planner.tsx @@ -1,13 +1,15 @@ import { useState } from 'react' -import { AlertTriangle, Download, Loader2, Play, Save, Trash2 } from 'lucide-react' +import { AlertTriangle, BarChart3, Download, Loader2, Play, Save, Trash2 } from 'lucide-react' import { useJob } from '@/hooks/use-jobs' import { + useCompareScenarios, useCreateScenario, useDeleteScenario, useScenario, useScenarios, useSimulateScenario, } from '@/hooks/use-scenarios' +import { MultiSeriesChart } from '@/components/charts/multi-series-chart' import { TimeSeriesChart } from '@/components/charts/time-series-chart' import { JobPicker } from '@/components/common/job-picker' import { Button } from '@/components/ui/button' @@ -33,13 +35,16 @@ import { import { downloadCsv, toCsv } from '@/lib/csv-export' import { formatCurrency, formatNumber, getErrorMessage } from '@/lib/api' import { + buildMultiSeries, coverageLabel, coverageVariant, deltaCsvColumns, formatDelta, mergeComparisonSeries, + methodLabel, } from '@/lib/scenario-utils' import type { + MultiScenarioComparison, PromotionAssumption, ScenarioAssumptions, ScenarioComparison, @@ -98,9 +103,15 @@ export default function WhatIfPlannerPage() { const [runError, setRunError] = useState(null) const [reloadId, setReloadId] = useState('') + // -- 
Multi-scenario comparison state ----------------------------------- + const [selectedPlanIds, setSelectedPlanIds] = useState>(new Set()) + const [multiComparison, setMultiComparison] = useState(null) + const [compareError, setCompareError] = useState(null) + const simulate = useSimulateScenario() const createScenario = useCreateScenario() const deleteScenario = useDeleteScenario() + const compareScenarios = useCompareScenarios() const scenariosQuery = useScenarios() const reloadedPlan = useScenario(reloadId, !!reloadId) @@ -178,6 +189,31 @@ export default function WhatIfPlannerPage() { downloadCsv('scenario-deltas.csv', toCsv(comparison.points, deltaCsvColumns)) } + /** Toggle a saved plan in the multi-scenario comparison selection. */ + function togglePlanSelection(scenarioId: string) { + setSelectedPlanIds((current) => { + const next = new Set(current) + if (next.has(scenarioId)) next.delete(scenarioId) + else next.add(scenarioId) + return next + }) + } + + async function handleCompare() { + if (selectedPlanIds.size < 2) return + setCompareError(null) + try { + const result = await compareScenarios.mutateAsync({ + scenario_ids: [...selectedPlanIds], + rank_by: 'revenue_delta', + }) + setMultiComparison(result) + } catch (caught) { + setCompareError(getErrorMessage(caught)) + setMultiComparison(null) + } + } + return (
@@ -193,11 +229,12 @@ export default function WhatIfPlannerPage() {
-

Heuristic estimates — not a causal model

+

Scenario estimates — directional planning signals

- Scenario results apply fixed, deterministic adjustment factors to a baseline - forecast. Treat the demand and revenue deltas as directional planning signals, - not precise predictions. + A baseline forecaster yields a heuristic estimate (fixed adjustment factors); a + regression baseline is genuinely re-forecast through the model. Either way, treat + the demand and revenue deltas as planning signals, not precise predictions — each + result states the method that produced it.

@@ -446,7 +483,8 @@ export default function WhatIfPlannerPage() { Scenario impact {comparison.model_type} model · store {comparison.store_id} · product{' '} - {comparison.product_id} · {comparison.horizon}-day horizon + {comparison.product_id} · {comparison.horizon}-day horizon ·{' '} + {methodLabel(comparison.method)} estimate @@ -573,51 +611,88 @@ export default function WhatIfPlannerPage() { {/* Saved plans */} - Saved plans - - Reload a saved plan to re-render its comparison, or delete one. - +
+
+ Saved plans + + Reload a plan to re-render its comparison, or select 2-5 plans to + compare them side by side. + +
+ +
+ {compareError &&

{compareError}

} {scenariosQuery.data && scenariosQuery.data.scenarios.length > 0 ? ( + Name + Tags Units delta Revenue delta Actions - {scenariosQuery.data.scenarios.map((plan) => ( - - {plan.name} - {formatDelta(plan.units_delta)} - - {formatCurrency(plan.revenue_delta)} - - -
- - -
-
-
- ))} + {scenariosQuery.data.scenarios.map((plan) => { + const selected = selectedPlanIds.has(plan.scenario_id) + return ( + + + = 5} + onCheckedChange={() => togglePlanSelection(plan.scenario_id)} + aria-label={`Select ${plan.name}`} + /> + + {plan.name} + + {plan.tags.length > 0 ? plan.tags.join(', ') : '—'} + + + {formatDelta(plan.units_delta)} + + + {formatCurrency(plan.revenue_delta)} + + +
+ + +
+
+
+ ) + })}
) : ( @@ -627,6 +702,55 @@ export default function WhatIfPlannerPage() { )}
+ + {/* Multi-scenario comparison result */} + {multiComparison && ( + + + Scenario comparison + + {multiComparison.scenarios.length} plans ranked by{' '} + {multiComparison.rank_by === 'revenue_delta' ? 'revenue delta' : 'units delta'}. + + + + + + + Rank + Plan + Units delta + Revenue delta + Coverage + + + + {multiComparison.scenarios.map((row) => ( + + {row.rank} + {row.name} + {formatDelta(row.units_delta)} + + {formatCurrency(row.revenue_delta)} + + + + {coverageLabel(row.coverage_verdict)} + + + + ))} + +
+ +
+
+ )}
) } diff --git a/frontend/src/types/api.ts b/frontend/src/types/api.ts index 163f6c84..80bc6da5 100644 --- a/frontend/src/types/api.ts +++ b/frontend/src/types/api.ts @@ -757,6 +757,8 @@ export interface CreateScenarioRequest { run_id: string horizon: number assumptions: ScenarioAssumptions + tags?: string[] + cloned_from?: string | null } // Whether projected demand is covered by on-hand stock. @@ -787,7 +789,9 @@ export interface ScenarioComparison { scenario_revenue: number revenue_delta: number coverage_verdict: CoverageVerdict - method: 'heuristic' + // 'heuristic' = a deterministic post-forecast multiplier; 'model_exogenous' + // = a re-forecast through a feature-consuming regression model (PRP-27). + method: 'heuristic' | 'model_exogenous' disclaimer: string generated_at: string } @@ -804,6 +808,8 @@ export interface ScenarioPlanResponse { created_at: string assumptions: ScenarioAssumptions comparison: ScenarioComparison + tags: string[] + cloned_from: string | null } // A compact row in the saved-plans list. @@ -816,6 +822,7 @@ export interface ScenarioListItem { units_delta: number revenue_delta: number created_at: string + tags: string[] } // A page of saved scenario plans — GET /scenarios. @@ -823,3 +830,33 @@ export interface ScenarioListResponse { scenarios: ScenarioListItem[] total: number } + +// Metric a multi-scenario comparison ranks by. +export type RankBy = 'revenue_delta' | 'units_delta' + +// Request body for POST /scenarios/compare. +export interface CompareScenariosRequest { + scenario_ids: string[] + rank_by?: RankBy +} + +// One saved plan's headline numbers within a multi-scenario comparison. +export interface ScenarioComparisonRow { + scenario_id: string + name: string + units_delta: number + revenue_delta: number + coverage_verdict: CoverageVerdict + rank: number +} + +// A baseline compared against 2-5 saved scenarios — POST /scenarios/compare. 
+export interface MultiScenarioComparison { + baseline_total_units: number + baseline_revenue: number + rank_by: RankBy + scenarios: ScenarioComparisonRow[] + // Date-keyed merged rows: each carries 'date', 'baseline', and one numeric + // entry per scenario, keyed by scenario_id. + chart_series: Record[] +} From 64c16efee38de88f9c6375c20af0c72eb3724232 Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 11:01:44 +0200 Subject: [PATCH 7/9] feat(api,db): add scenario provenance and audit columns (#223) --- ...748581e_add_scenario_provenance_columns.py | 73 ++++++++++++++++++ app/features/scenarios/models.py | 30 +++++++- app/features/scenarios/schemas.py | 74 +++++++++++++++++-- app/features/scenarios/service.py | 42 ++++++++++- 4 files changed, 211 insertions(+), 8 deletions(-) create mode 100644 alembic/versions/7e8f9748581e_add_scenario_provenance_columns.py diff --git a/alembic/versions/7e8f9748581e_add_scenario_provenance_columns.py b/alembic/versions/7e8f9748581e_add_scenario_provenance_columns.py new file mode 100644 index 00000000..6f600f96 --- /dev/null +++ b/alembic/versions/7e8f9748581e_add_scenario_provenance_columns.py @@ -0,0 +1,73 @@ +"""add scenario provenance columns + +Revision ID: 7e8f9748581e +Revises: bb8c4587ef1d +Create Date: 2026-05-19 10:47:09.829097 + +PRP-27 Phase D — adds provenance + approval-audit columns to ``scenario_plan`` +so an agent-proposed plan records who/what created it and the human approval +decision that released it. ``source`` server-defaults to ``'user'`` so every +pre-existing row stays valid. Forward-only. +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision: str = '7e8f9748581e' +down_revision: Union[str, None] = 'bb8c4587ef1d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add the source + approval-audit columns, their CHECKs and an index.""" + op.add_column( + 'scenario_plan', + sa.Column( + 'source', + sa.String(length=16), + nullable=False, + server_default=sa.text("'user'"), + ), + ) + op.add_column( + 'scenario_plan', + sa.Column('agent_session_id', sa.String(length=32), nullable=True), + ) + op.add_column( + 'scenario_plan', + sa.Column('approved_by', sa.String(length=120), nullable=True), + ) + op.add_column( + 'scenario_plan', + sa.Column('approved_at', sa.DateTime(timezone=True), nullable=True), + ) + op.add_column( + 'scenario_plan', + sa.Column('approval_decision', sa.String(length=16), nullable=True), + ) + op.create_check_constraint( + 'ck_scenario_plan_source', + 'scenario_plan', + "source IN ('user', 'agent')", + ) + op.create_check_constraint( + 'ck_scenario_plan_approval_decision', + 'scenario_plan', + "approval_decision IS NULL OR approval_decision IN ('approved', 'rejected')", + ) + op.create_index('ix_scenario_plan_source', 'scenario_plan', ['source'], unique=False) + + +def downgrade() -> None: + """Drop the index, the two CHECKs and the provenance columns.""" + op.drop_index('ix_scenario_plan_source', table_name='scenario_plan') + op.drop_constraint('ck_scenario_plan_approval_decision', 'scenario_plan', type_='check') + op.drop_constraint('ck_scenario_plan_source', 'scenario_plan', type_='check') + op.drop_column('scenario_plan', 'approval_decision') + op.drop_column('scenario_plan', 'approved_at') + op.drop_column('scenario_plan', 'approved_by') + op.drop_column('scenario_plan', 'agent_session_id') + op.drop_column('scenario_plan', 'source') diff --git a/app/features/scenarios/models.py b/app/features/scenarios/models.py index 1bab6a0c..39f3b67e 100644 --- a/app/features/scenarios/models.py 
+++ b/app/features/scenarios/models.py @@ -12,9 +12,10 @@ from __future__ import annotations +from datetime import datetime from typing import Any -from sqlalchemy import CheckConstraint, Index, Integer, String, text +from sqlalchemy import CheckConstraint, DateTime, Index, Integer, String, text from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column @@ -27,6 +28,10 @@ SCENARIO_METHOD_HEURISTIC = "heuristic" SCENARIO_METHOD_MODEL_EXOGENOUS = "model_exogenous" +# Provenance — who or what created a scenario plan (PRP-27 Phase D). +SCENARIO_SOURCE_USER = "user" +SCENARIO_SOURCE_AGENT = "agent" + class ScenarioPlan(TimestampMixin, Base): """A saved scenario plan. @@ -71,6 +76,17 @@ class ScenarioPlan(TimestampMixin, Base): ) cloned_from: Mapped[str | None] = mapped_column(String(32), nullable=True) + # Provenance + approval-audit columns (PRP-27 Phase D). ``source`` defaults + # to 'user'; an agent-saved plan carries 'agent' plus the originating + # session id and the human approval audit trail. + source: Mapped[str] = mapped_column( + String(16), nullable=False, default=SCENARIO_SOURCE_USER, server_default=text("'user'") + ) + agent_session_id: Mapped[str | None] = mapped_column(String(32), nullable=True) + approved_by: Mapped[str | None] = mapped_column(String(120), nullable=True) + approved_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + approval_decision: Mapped[str | None] = mapped_column(String(16), nullable=True) + __table_args__ = ( # GIN indexes for JSONB containment queries on either blob. Index("ix_scenario_plan_assumptions_gin", "assumptions", postgresql_using="gin"), @@ -79,10 +95,22 @@ class ScenarioPlan(TimestampMixin, Base): Index("ix_scenario_plan_store_product", "store_id", "product_id"), # GIN index so the saved-plans list can filter by tag containment. 
Index("ix_scenario_plan_tags_gin", "tags", postgresql_using="gin"), + # Index on source for the "show me agent-proposed plans" query. + Index("ix_scenario_plan_source", "source"), # heuristic (MVP) or model_exogenous (PRP-27) — kept in lock-step with # the alembic migration that widened this CHECK. CheckConstraint( "method IN ('heuristic', 'model_exogenous')", name="ck_scenario_plan_method", ), + # Provenance + approval-audit CHECKs — kept in lock-step with the + # alembic migration that added these columns. + CheckConstraint( + "source IN ('user', 'agent')", + name="ck_scenario_plan_source", + ), + CheckConstraint( + "approval_decision IS NULL OR approval_decision IN ('approved', 'rejected')", + name="ck_scenario_plan_approval_decision", + ), ) diff --git a/app/features/scenarios/schemas.py b/app/features/scenarios/schemas.py index 5a998e0e..fe1a71ef 100644 --- a/app/features/scenarios/schemas.py +++ b/app/features/scenarios/schemas.py @@ -212,6 +212,48 @@ class CreateScenarioRequest(BaseModel): ) +class SaveScenarioRequest(BaseModel): + """What the ``save_scenario`` agent tool persists once HITL-approved. + + Unlike ``CreateScenarioRequest`` this carries ``store_id`` / ``product_id`` + explicitly — the agent proposed the scenario for a known grain, so the + identity travels with the request rather than being re-derived. The + persisted ``scenario_plan`` row is stamped ``source='agent'`` plus the + originating ``agent_session_id`` (PRP-27 Phase D, DECISIONS LOCKED #13). 
+ """ + + model_config = ConfigDict(strict=True) + + name: str = Field( + ..., + min_length=1, + max_length=200, + description="Human-readable name for the saved plan.", + ) + assumptions: ScenarioAssumptions = Field( + ..., + description="The what-if assumptions the agent proposed.", + ) + store_id: int = Field(..., ge=1, description="Store the proposed scenario targets.") + product_id: int = Field(..., ge=1, description="Product the proposed scenario targets.") + horizon: int = Field(..., ge=1, le=90, description="Number of days to simulate.") + run_id: str = Field( + ..., + min_length=1, + max_length=64, + description="Artifact key of the baseline model.", + ) + source: Literal["user", "agent"] = Field( + default="agent", + description="Provenance of the plan — an agent save defaults to 'agent'.", + ) + agent_session_id: str | None = Field( + default=None, + max_length=32, + description="The originating agent session id, when agent-created.", + ) + + # ============================================================================= # Response models # ============================================================================= @@ -297,12 +339,23 @@ class ScenarioPlanResponse(BaseModel): comparison: ScenarioComparison = Field( ..., description="The full baseline-vs-scenario snapshot, re-rendered without recompute." ) - tags: list[str] = Field( - default_factory=list, description="Library tags attached to the plan." - ) + tags: list[str] = Field(default_factory=list, description="Library tags attached to the plan.") cloned_from: str | None = Field( default=None, description="scenario_id this plan was cloned from, if any." ) + source: str = Field(default="user", description="Who created the plan — 'user' or 'agent'.") + agent_session_id: str | None = Field( + default=None, description="Originating agent session id, when agent-created." + ) + approved_by: str | None = Field( + default=None, description="Who approved an agent-created plan, if any." 
+ ) + approved_at: datetime | None = Field( + default=None, description="When an agent-created plan was approved (UTC)." + ) + approval_decision: str | None = Field( + default=None, description="The HITL decision — 'approved' or 'rejected'." + ) class ScenarioListItem(BaseModel): @@ -318,8 +371,19 @@ class ScenarioListItem(BaseModel): units_delta: float = Field(..., description="Summed scenario-minus-baseline demand.") revenue_delta: float = Field(..., description="Scenario-minus-baseline revenue.") created_at: datetime = Field(..., description="When the plan was saved (UTC).") - tags: list[str] = Field( - default_factory=list, description="Library tags attached to the plan." + tags: list[str] = Field(default_factory=list, description="Library tags attached to the plan.") + source: str = Field(default="user", description="Who created the plan — 'user' or 'agent'.") + agent_session_id: str | None = Field( + default=None, description="Originating agent session id, when agent-created." + ) + approved_by: str | None = Field( + default=None, description="Who approved an agent-created plan, if any." + ) + approved_at: datetime | None = Field( + default=None, description="When an agent-created plan was approved (UTC)." + ) + approval_decision: str | None = Field( + default=None, description="The HITL decision — 'approved' or 'rejected'." 
) diff --git a/app/features/scenarios/service.py b/app/features/scenarios/service.py index 9b95cc68..4fcdb308 100644 --- a/app/features/scenarios/service.py +++ b/app/features/scenarios/service.py @@ -34,7 +34,7 @@ from app.features.forecasting.persistence import ModelBundle, load_model_bundle from app.features.scenarios import adjustments from app.features.scenarios.feature_frame import build_future_frame -from app.features.scenarios.models import ScenarioPlan +from app.features.scenarios.models import SCENARIO_SOURCE_USER, ScenarioPlan from app.features.scenarios.schemas import ( CompareScenariosRequest, CreateScenarioRequest, @@ -325,13 +325,30 @@ async def _simulate_model_exogenous( # -- Persistence ------------------------------------------------------- async def create_plan( - self, db: AsyncSession, request: CreateScenarioRequest + self, + db: AsyncSession, + request: CreateScenarioRequest, + *, + source: str = SCENARIO_SOURCE_USER, + agent_session_id: str | None = None, + approved_by: str | None = None, + approval_decision: str | None = None, ) -> ScenarioPlanResponse: """Run a simulation and persist it as a named scenario plan. + The provenance keyword arguments default to a plain user-created plan, + so the MVP create path stays backward-compatible. An agent-saved plan + (PRP-27 Phase D) passes ``source='agent'`` plus the originating + ``agent_session_id`` and the HITL approval audit trail; ``approved_at`` + is stamped automatically whenever ``approval_decision`` is supplied. + Args: db: Database session. request: Plan name plus the baseline / horizon / assumptions. + source: Who created the plan — 'user' (default) or 'agent'. + agent_session_id: Originating agent session id, when agent-created. + approved_by: Who approved an agent-created plan, if any. + approval_decision: The HITL decision — 'approved' or 'rejected'. Returns: The saved plan with its embedded comparison snapshot. 
@@ -350,6 +367,10 @@ async def create_plan( ), ) + # An approval decision implies an approval moment — stamp it here so + # callers never have to thread a timestamp through. + approved_at = datetime.now(UTC) if approval_decision is not None else None + plan = ScenarioPlan( scenario_id=uuid.uuid4().hex, name=request.name, @@ -365,6 +386,11 @@ async def create_plan( method=comparison.method, tags=list(request.tags), cloned_from=request.cloned_from, + source=source, + agent_session_id=agent_session_id, + approved_by=approved_by, + approved_at=approved_at, + approval_decision=approval_decision, ) db.add(plan) await db.commit() @@ -375,6 +401,8 @@ async def create_plan( scenario_id=plan.scenario_id, store_id=plan.store_id, product_id=plan.product_id, + source=source, + agent_session_id=agent_session_id, ) return self._to_plan_response(plan) @@ -599,6 +627,11 @@ def _to_plan_response(plan: ScenarioPlan) -> ScenarioPlanResponse: comparison=ScenarioComparison.model_validate(plan.comparison), tags=list(plan.tags), cloned_from=plan.cloned_from, + source=plan.source, + agent_session_id=plan.agent_session_id, + approved_by=plan.approved_by, + approved_at=plan.approved_at, + approval_decision=plan.approval_decision, ) @staticmethod @@ -614,4 +647,9 @@ def _to_list_item(plan: ScenarioPlan) -> ScenarioListItem: revenue_delta=float(plan.comparison.get("revenue_delta", 0.0)), created_at=plan.created_at, tags=list(plan.tags), + source=plan.source, + agent_session_id=plan.agent_session_id, + approved_by=plan.approved_by, + approved_at=plan.approved_at, + approval_decision=plan.approval_decision, ) From 9245c92d0154eb8b1e4ea0feae10ea9e7f60b5e5 Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 11:02:03 +0200 Subject: [PATCH 8/9] feat(agents): add agent-proposed and hitl-gated save scenario tools (#223) --- .env.example | 2 +- app/core/config.py | 5 +- app/features/agents/agents/base.py | 2 + app/features/agents/agents/experiment.py | 107 +++++++++ 
app/features/agents/service.py | 14 +- app/features/scenarios/agent_tools.py | 195 +++++++++++++++ .../scenarios/tests/test_agent_tools.py | 222 ++++++++++++++++++ 7 files changed, 544 insertions(+), 3 deletions(-) create mode 100644 app/features/scenarios/agent_tools.py create mode 100644 app/features/scenarios/tests/test_agent_tools.py diff --git a/.env.example b/.env.example index 0545e3b1..df1b1465 100644 --- a/.env.example +++ b/.env.example @@ -97,7 +97,7 @@ AGENT_SESSION_TTL_MINUTES=120 AGENT_MAX_SESSIONS_PER_USER=5 # Human-in-the-loop actions (JSON array format required for safe parsing) -AGENT_REQUIRE_APPROVAL=["create_alias","archive_run"] +AGENT_REQUIRE_APPROVAL=["create_alias","archive_run","save_scenario"] AGENT_APPROVAL_TIMEOUT_MINUTES=60 # Streaming diff --git a/app/core/config.py b/app/core/config.py index cdbe5cf1..d3ac4a24 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -161,7 +161,10 @@ class Settings(BaseSettings): agent_retry_delay_seconds: float = 1.0 # Human-in-the-Loop Configuration - agent_require_approval: list[str] = ["create_alias", "archive_run"] + # ``save_scenario`` (PRP-27 Phase D) lets the experiment agent persist a + # scenario_plan row — a deliberate mutation-surface widening, so it is + # gated here exactly like create_alias / archive_run. 
+ agent_require_approval: list[str] = ["create_alias", "archive_run", "save_scenario"] agent_approval_timeout_minutes: int = 60 # Session Configuration diff --git a/app/features/agents/agents/base.py b/app/features/agents/agents/base.py index 272e59e8..f9f4e1a0 100644 --- a/app/features/agents/agents/base.py +++ b/app/features/agents/agents/base.py @@ -284,6 +284,8 @@ def requires_approval(action_name: str) -> bool: - Use tool_compare_runs to analyze differences between registered runs - Use tool_create_alias to deploy successful models (requires approval) - Use tool_archive_run to clean up old experiments (requires approval) +- Use tool_propose_scenario to draft a candidate what-if scenario (read-only) +- Use tool_save_scenario to persist an approved scenario plan (requires approval) """ SAFETY_INSTRUCTIONS = """ diff --git a/app/features/agents/agents/experiment.py b/app/features/agents/agents/experiment.py index 7a2de475..d9bad5bd 100644 --- a/app/features/agents/agents/experiment.py +++ b/app/features/agents/agents/experiment.py @@ -38,6 +38,8 @@ get_run, list_runs, ) +from app.features.scenarios.agent_tools import propose_scenario, save_scenario +from app.features.scenarios.schemas import SaveScenarioRequest logger = structlog.get_logger() @@ -373,6 +375,111 @@ async def tool_archive_run( return await archive_run(db=ctx.deps.db, run_id=run_id) + @agent.tool + @recoverable + async def tool_propose_scenario( + ctx: RunContext[AgentDeps], + store_id: int, + product_id: int, + horizon: int = 14, + objective: str = "", + ) -> dict[str, Any]: + """Propose a candidate what-if scenario for a store / product. + + READ-ONLY: this drafts a candidate scenario; it persists nothing. To + save the proposal, call tool_save_scenario (which requires approval). + + Args: + store_id: Store the proposed scenario targets. + product_id: Product the proposed scenario targets. + horizon: Number of days the proposed scenario should span (default 14). 
+ objective: Free-text planning objective — keywords like 'promotion' + steer the proposal toward a promotion instead of a price cut. + + Returns: + A candidate scenario with assumptions and a recommendation. + """ + ctx.deps.increment_tool_calls() + logger.info( + "agents.experiment.tool_propose_scenario", + session_id=ctx.deps.session_id, + store_id=store_id, + product_id=product_id, + ) + return await propose_scenario( + db=ctx.deps.db, + store_id=store_id, + product_id=product_id, + horizon=horizon, + objective=objective, + ) + + @agent.tool + @recoverable + async def tool_save_scenario( + ctx: RunContext[AgentDeps], + name: str, + run_id: str, + store_id: int, + product_id: int, + horizon: int, + assumptions: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Persist a proposed what-if scenario as a saved scenario plan. + + REQUIRES HUMAN APPROVAL. This action writes a scenario_plan row. + + Use this only after tool_propose_scenario, passing back its candidate + assumptions. The plan is persisted with agent provenance and the + approval audit trail. + + Args: + name: Human-readable name for the saved plan. + run_id: Artifact key of the baseline model. + store_id: Store the scenario targets. + product_id: Product the scenario targets. + horizon: Number of days to simulate. + assumptions: The candidate assumptions dict from tool_propose_scenario. + + Returns: + The saved plan details, or an approval request. + """ + ctx.deps.increment_tool_calls() + logger.info( + "agents.experiment.tool_save_scenario", + session_id=ctx.deps.session_id, + store_id=store_id, + product_id=product_id, + requires_approval=requires_approval("save_scenario"), + ) + + arguments: dict[str, Any] = { + "name": name, + "run_id": run_id, + "store_id": store_id, + "product_id": product_id, + "horizon": horizon, + "assumptions": assumptions or {}, + "source": "agent", + "agent_session_id": ctx.deps.session_id, + } + + # Check if approval is required — mirrors tool_create_alias exactly. 
+ if requires_approval("save_scenario"): + return { + "status": "approval_required", + "action": "save_scenario", + "arguments": arguments, + "message": "This action requires human approval. Please approve to proceed.", + } + + request = SaveScenarioRequest.model_validate(arguments) + return await save_scenario( + db=ctx.deps.db, + request=request, + agent_session_id=ctx.deps.session_id, + ) + return agent diff --git a/app/features/agents/service.py b/app/features/agents/service.py index 751c984c..1b3c4644 100644 --- a/app/features/agents/service.py +++ b/app/features/agents/service.py @@ -869,6 +869,8 @@ async def _execute_pending_action( ValueError: If action_type is not recognized. """ from app.features.agents.tools.registry_tools import archive_run, create_alias + from app.features.scenarios.agent_tools import save_scenario + from app.features.scenarios.schemas import SaveScenarioRequest if action_type == "create_alias": alias_name = arguments.get("alias_name", "") @@ -886,7 +888,17 @@ async def _execute_pending_action( if result is None: raise ValueError(f"Run not found: {run_id}") return result + elif action_type == "save_scenario": + # The HITL gate has released the agent's save_scenario call — persist + # the scenario_plan row now, stamped with the approved audit trail. + request = SaveScenarioRequest.model_validate(arguments) + return await save_scenario( + db=db, + request=request, + agent_session_id=arguments.get("agent_session_id"), + ) else: raise ValueError( - f"Unknown action type: {action_type}. Supported actions: create_alias, archive_run" + f"Unknown action type: {action_type}. Supported actions: " + "create_alias, archive_run, save_scenario" ) diff --git a/app/features/scenarios/agent_tools.py b/app/features/scenarios/agent_tools.py new file mode 100644 index 00000000..ab99c17c --- /dev/null +++ b/app/features/scenarios/agent_tools.py @@ -0,0 +1,195 @@ +"""Agent-facing tools for the Scenario Simulation slice (PRP-27 Phase D). 
+ +This module is the *integration seam* between the agent layer and the +scenarios slice. ``app/features/agents/`` imports THIS module — never +``scenarios/service.py`` directly — so the no-cross-slice-``service.py``-import +rule (DECISIONS LOCKED #3) holds while the agent still gains scenario tools. + +Two tools live here: + +* :func:`propose_scenario` — **read-only**. Returns a candidate + ``ScenarioAssumptions`` plus a plain-language recommendation. It proposes, + it never persists, so it needs no approval. +* :func:`save_scenario` — **mutating**. Persists a ``scenario_plan`` row via the + scenarios service create path, stamped ``source='agent'`` with the originating + ``agent_session_id`` and the HITL approval audit trail. It runs only after the + human-in-the-loop gate releases it — its tool name is in + ``agent_require_approval`` (DECISIONS LOCKED #13). +""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from typing import Any + +import structlog +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.features.data_platform.models import SalesDaily +from app.features.scenarios.models import SCENARIO_SOURCE_AGENT +from app.features.scenarios.schemas import ( + CreateScenarioRequest, + PriceAssumption, + PromotionAssumption, + SaveScenarioRequest, + ScenarioAssumptions, +) +from app.features.scenarios.service import ScenarioService + +logger = structlog.get_logger() + +# Keywords in a free-text objective that steer the proposal toward a promotion +# rather than the default price cut. +_PROMOTION_KEYWORDS = ("promo", "promotion", "discount", "sale", "markdown") + +# The default magnitude of the proposed price cut (a 15% reduction). +_PROPOSED_PRICE_CHANGE_PCT = -0.15 + +# Recorded as ``approved_by`` on an agent-saved plan. The system is single-host +# and unauthenticated, so the approving party is simply the local operator who +# released the HITL gate. 
+AGENT_SAVE_APPROVED_BY = "operator" + + +async def propose_scenario( + db: AsyncSession, + store_id: int, + product_id: int, + horizon: int, + objective: str, +) -> dict[str, Any]: + """Propose a candidate what-if scenario for a (store, product) grain. + + READ-ONLY: this tool builds a candidate ``ScenarioAssumptions`` and a + recommendation; it performs no database writes. Persisting the proposal is + a separate, approval-gated step (:func:`save_scenario`). + + Args: + db: Database session (used only to read a recent unit price for a + grounded recommendation). + store_id: Store the proposed scenario targets. + product_id: Product the proposed scenario targets. + horizon: Number of days the proposed scenario should span. + objective: Free-text planning objective — keywords such as "promotion" + steer the proposal toward a promotion instead of a price cut. + + Returns: + A dict with the target grain, the horizon, the originating objective, + the candidate ``assumptions`` (JSON-mode dump so dates are ISO strings, + ready to pass straight back into ``save_scenario``), and a + plain-language ``recommendation``. + """ + logger.info( + "agents.scenario_tool.propose_scenario_called", + store_id=store_id, + product_id=product_id, + horizon=horizon, + ) + + # Read the most recent unit price for a grounded recommendation. Read-only. + latest_price = await db.scalar( + select(SalesDaily.unit_price) + .where(SalesDaily.store_id == store_id, SalesDaily.product_id == product_id) + .order_by(SalesDaily.date.desc()) + .limit(1) + ) + + start = datetime.now(UTC).date() + timedelta(days=1) + end = start + timedelta(days=horizon - 1) + + if any(keyword in objective.lower() for keyword in _PROMOTION_KEYWORDS): + assumptions = ScenarioAssumptions( + promotion=PromotionAssumption(kind="pct_off", start_date=start, end_date=end) + ) + rationale = ( + f"Run a pct_off promotion from {start} to {end} ({horizon} days) and " + "simulate the demand lift before committing." 
+ ) + else: + assumptions = ScenarioAssumptions( + price=PriceAssumption( + change_pct=_PROPOSED_PRICE_CHANGE_PCT, start_date=start, end_date=end + ) + ) + price_note = ( + f" The most recent unit price is ~{float(latest_price):.2f}." + if latest_price is not None + else "" + ) + rationale = ( + f"Cut price {abs(_PROPOSED_PRICE_CHANGE_PCT) * 100:.0f}% from {start} to " + f"{end} ({horizon} days) to test the demand response.{price_note}" + ) + + recommendation = ( + f"Proposed what-if for store {store_id} / product {product_id} toward the " + f"objective '{objective}'. {rationale} This is a candidate only — review it " + "and save it explicitly to persist a scenario plan." + ) + return { + "store_id": store_id, + "product_id": product_id, + "horizon": horizon, + "objective": objective, + "assumptions": assumptions.model_dump(mode="json"), + "recommendation": recommendation, + } + + +async def save_scenario( + db: AsyncSession, + request: SaveScenarioRequest, + *, + agent_session_id: str | None, +) -> dict[str, Any]: + """Persist an agent-proposed scenario as a saved ``scenario_plan`` row. + + MUTATING: this tool writes a row. It runs only once the HITL approval gate + has released it (``save_scenario`` is in ``agent_require_approval``), so the + persisted plan always carries an ``approved`` audit trail. + + Args: + db: Database session. + request: The validated scenario to persist (name, run_id, horizon, + assumptions). + agent_session_id: The originating agent session id — the runtime truth, + authoritative over any value carried on ``request``. + + Returns: + The saved plan as a JSON-mode dict, including its embedded comparison + snapshot and the provenance / audit fields. + + Raises: + FileNotFoundError: When no model artifact exists for ``request.run_id``. + ValueError: When the artifact path or its metadata is invalid. 
+ """ + logger.info( + "agents.scenario_tool.save_scenario_called", + store_id=request.store_id, + product_id=request.product_id, + agent_session_id=agent_session_id, + ) + + service = ScenarioService() + create_request = CreateScenarioRequest( + name=request.name, + run_id=request.run_id, + horizon=request.horizon, + assumptions=request.assumptions, + ) + plan = await service.create_plan( + db, + create_request, + source=SCENARIO_SOURCE_AGENT, + agent_session_id=agent_session_id, + approved_by=AGENT_SAVE_APPROVED_BY, + approval_decision="approved", + ) + + logger.info( + "agents.scenario_tool.save_scenario_completed", + scenario_id=plan.scenario_id, + agent_session_id=agent_session_id, + ) + return plan.model_dump(mode="json") diff --git a/app/features/scenarios/tests/test_agent_tools.py b/app/features/scenarios/tests/test_agent_tools.py new file mode 100644 index 00000000..5f1ba739 --- /dev/null +++ b/app/features/scenarios/tests/test_agent_tools.py @@ -0,0 +1,222 @@ +"""Tests for the scenarios agent tools and the HITL save gate (PRP-27 Phase D). + +Two layers: + +* A unit test that the ``save_scenario`` tool name is wired into + ``agent_require_approval`` — the mutation-surface guard. +* Integration tests (real PostgreSQL + a real model bundle, ``docker compose up + -d`` required) covering ``propose_scenario`` (read-only — persists nothing), + ``save_scenario`` (persists with agent provenance), and the HITL gate firing + through ``AgentService.approve_action``. 
+""" + +from __future__ import annotations + +import uuid +from datetime import UTC, datetime, timedelta + +import pytest +from sqlalchemy import delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import get_settings +from app.features.agents.agents.base import requires_approval +from app.features.agents.models import AgentSession, AgentType, SessionStatus +from app.features.agents.service import AgentService +from app.features.scenarios.agent_tools import propose_scenario, save_scenario +from app.features.scenarios.models import ScenarioPlan +from app.features.scenarios.schemas import SaveScenarioRequest, ScenarioAssumptions +from app.features.scenarios.tests.conftest import TEST_PRODUCT_ID, TEST_STORE_ID + + +def test_save_scenario_requires_approval() -> None: + """``save_scenario`` is in agent_require_approval — the mutation-surface gate.""" + assert "save_scenario" in get_settings().agent_require_approval + assert requires_approval("save_scenario") is True + + +@pytest.mark.integration +class TestProposeScenario: + """propose_scenario drafts a candidate and persists nothing.""" + + async def test_returns_valid_assumptions_and_recommendation( + self, db_session: AsyncSession + ) -> None: + """A default objective yields a valid price-cut candidate.""" + result = await propose_scenario( + db_session, + store_id=TEST_STORE_ID, + product_id=TEST_PRODUCT_ID, + horizon=14, + objective="grow demand for the summer range", + ) + + # The candidate assumptions round-trip through the real schema. 
+ assumptions = ScenarioAssumptions.model_validate(result["assumptions"]) + assert assumptions.price is not None + assert assumptions.price.change_pct < 0.0 + assert isinstance(result["recommendation"], str) + assert result["recommendation"] + + async def test_promotion_keyword_proposes_a_promotion(self, db_session: AsyncSession) -> None: + """An objective mentioning a promotion steers the candidate accordingly.""" + result = await propose_scenario( + db_session, + store_id=TEST_STORE_ID, + product_id=TEST_PRODUCT_ID, + horizon=7, + objective="run a promotion next week", + ) + + assumptions = ScenarioAssumptions.model_validate(result["assumptions"]) + assert assumptions.promotion is not None + assert assumptions.price is None + + async def test_persists_no_row(self, db_session: AsyncSession) -> None: + """propose_scenario is read-only — it never writes a scenario_plan row.""" + await propose_scenario( + db_session, + store_id=TEST_STORE_ID, + product_id=TEST_PRODUCT_ID, + horizon=10, + objective="test", + ) + count = await db_session.scalar(select(func.count()).select_from(ScenarioPlan)) + assert count == 0 + + +@pytest.mark.integration +class TestSaveScenario: + """save_scenario persists a plan stamped with agent provenance.""" + + async def test_persists_with_agent_provenance( + self, db_session: AsyncSession, trained_model: str + ) -> None: + """A save stamps source='agent', the session id, and the audit trail.""" + request = SaveScenarioRequest( + name="Saved by agent", + assumptions=ScenarioAssumptions(), + store_id=TEST_STORE_ID, + product_id=TEST_PRODUCT_ID, + horizon=7, + run_id=trained_model, + ) + + result = await save_scenario(db_session, request, agent_session_id="sess-xyz") + + assert result["source"] == "agent" + assert result["agent_session_id"] == "sess-xyz" + assert result["approved_by"] == "operator" + assert result["approval_decision"] == "approved" + assert result["approved_at"] is not None + + rows = (await 
db_session.execute(select(ScenarioPlan))).scalars().all() + assert len(rows) == 1 + assert rows[0].source == "agent" + assert rows[0].agent_session_id == "sess-xyz" + + +@pytest.mark.integration +class TestSaveScenarioHITLGate: + """The save_scenario HITL gate persists a row only once approved.""" + + @staticmethod + def _pending_save_action(session_id: str, run_id: str) -> dict[str, object]: + """Build a pending save_scenario action for the given session.""" + now = datetime.now(UTC) + return { + "action_id": "act-save-1", + "action_type": "save_scenario", + "description": "Save the proposed scenario", + "arguments": { + "name": "Agent-proposed plan", + "run_id": run_id, + "store_id": TEST_STORE_ID, + "product_id": TEST_PRODUCT_ID, + "horizon": 7, + "assumptions": {}, + "source": "agent", + "agent_session_id": session_id, + }, + "created_at": now.isoformat(), + "expires_at": (now + timedelta(minutes=5)).isoformat(), + } + + async def _seed_session(self, db_session: AsyncSession, session_id: str, run_id: str) -> None: + """Insert an experiment session awaiting a save_scenario approval.""" + now = datetime.now(UTC) + db_session.add( + AgentSession( + session_id=session_id, + agent_type=AgentType.EXPERIMENT.value, + status=SessionStatus.AWAITING_APPROVAL.value, + message_history=[], + pending_action=self._pending_save_action(session_id, run_id), + total_tokens_used=0, + tool_calls_count=1, + last_activity=now, + expires_at=now + timedelta(minutes=30), + ) + ) + await db_session.commit() + + async def test_approve_persists_agent_plan( + self, db_session: AsyncSession, trained_model: str + ) -> None: + """Approving the pending action persists a row with the audit trail.""" + session_id = uuid.uuid4().hex # session_id is VARCHAR(32) — hex is exactly 32 + await self._seed_session(db_session, session_id, trained_model) + try: + response = await AgentService().approve_action( + db=db_session, + session_id=session_id, + action_id="act-save-1", + approved=True, + ) + + 
assert response.status == "executed" + assert isinstance(response.result, dict) + assert response.result["source"] == "agent" + + rows = ( + ( + await db_session.execute( + select(ScenarioPlan).where(ScenarioPlan.agent_session_id == session_id) + ) + ) + .scalars() + .all() + ) + assert len(rows) == 1 + assert rows[0].source == "agent" + assert rows[0].approved_by == "operator" + assert rows[0].approval_decision == "approved" + assert rows[0].approved_at is not None + finally: + await db_session.execute( + delete(AgentSession).where(AgentSession.session_id == session_id) + ) + await db_session.commit() + + async def test_reject_persists_no_plan( + self, db_session: AsyncSession, trained_model: str + ) -> None: + """Rejecting the pending action writes no scenario_plan row.""" + session_id = uuid.uuid4().hex # session_id is VARCHAR(32) — hex is exactly 32 + await self._seed_session(db_session, session_id, trained_model) + try: + response = await AgentService().approve_action( + db=db_session, + session_id=session_id, + action_id="act-save-1", + approved=False, + ) + + assert response.status == "rejected" + count = await db_session.scalar(select(func.count()).select_from(ScenarioPlan)) + assert count == 0 + finally: + await db_session.execute( + delete(AgentSession).where(AgentSession.session_id == session_id) + ) + await db_session.commit() From 3f87ec1737164bc0996f3dd3af0d703d61e23d43 Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 11:04:50 +0200 Subject: [PATCH 9/9] docs(docs): document model-driven scenarios, library, and agent tools (#223) --- README.md | 2 +- docs/_base/API_CONTRACTS.md | 9 +++++---- docs/_base/DOMAIN_MODEL.md | 13 +++++++++---- docs/_base/REPO_MAP_INDEX.md | 4 ++-- docs/_base/SECURITY.md | 2 +- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 6298904e..3466a525 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Portfolio-grade end-to-end retail demand forecasting system. 
- **Dashboard**: React 19 + Vite + Tailwind CSS 4 + shadcn/ui for data exploration and model management - **Explorer**: Click-through detail pages for stores, products, model runs, and jobs; run-vs-run comparison and SHA-256 artifact integrity verification; server-side sortable, CSV-exportable tables with column-visibility toggles and URL-shareable filter/sort/page state across every Explorer page; date-scoped KPIs, revenue bar/line charts, and cross-filtering on the Sales page - **Demand Planner**: `/visualize/demand` — every completed forecast rolled into a multi-SKU table (tomorrow / next-week / next-month demand + inventory requirement), with a lead-time selector and a single-SKU drill-in; the Forecast and Backtest pages run jobs in-page, export CSV, toggle a prediction-interval band, and cross-link to runs/jobs -- **What-If Planner**: `/visualize/planner` — take an existing forecast, apply deterministic price / promotion / holiday / inventory / lifecycle assumptions, and see the baseline-vs-scenario demand and revenue impact (clearly labelled heuristic); save, reload, and delete named scenario plans +- **What-If Planner**: `/visualize/planner` — take an existing forecast, apply price / promotion / holiday / inventory / lifecycle assumptions, and see the baseline-vs-scenario demand and revenue impact; a regression baseline genuinely re-forecasts through the assumptions (`method="model_exogenous"`), any other baseline applies a clearly-labelled deterministic heuristic; save, tag, reload, clone, and delete named scenario plans, and rank 2-5 saved plans side by side in a multi-scenario comparison. 
The experiment chat agent can also propose a scenario and — behind the human-in-the-loop approval gate — save it for you - **RAG Knowledge Base**: Postgres pgvector embeddings + evidence-grounded answers with citations - **Agentic Layer**: PydanticAI agents for autonomous experimentation and evidence-grounded Q&A with human-in-the-loop approval - **Data Seeder (The Forge)**: Reproducible synthetic data generator with realistic time-series patterns, scenario presets, and retail effects diff --git a/docs/_base/API_CONTRACTS.md b/docs/_base/API_CONTRACTS.md index 0fa10d9a..3719a802 100644 --- a/docs/_base/API_CONTRACTS.md +++ b/docs/_base/API_CONTRACTS.md @@ -19,13 +19,14 @@ All endpoints serve JSON; error responses use `application/problem+json` (RFC 78 | analytics | GET | `/analytics/inventory-status` | Latest `inventory_snapshot_daily` row per `(store, product)` grain (Postgres `DISTINCT ON`); optional `store_id`/`product_id` filters; `200` + empty list on an empty table (never `404`) | | featuresets | POST | `/featuresets/compute` | Compute time-safe features (lag/rolling/calendar, leakage-prevented) | | featuresets | POST | `/featuresets/preview` | Preview features with sample rows | -| forecasting | POST | `/forecasting/train` | Train a model (naive / seasonal_naive / moving_average / lightgbm) | +| forecasting | POST | `/forecasting/train` | Train a model (naive / seasonal_naive / moving_average / lightgbm / regression). 
`regression` wraps `HistGradientBoostingRegressor` on lag + calendar + exogenous features — the baseline a `model_exogenous` scenario re-forecasts through | | forecasting | POST | `/forecasting/predict` | Generate horizon predictions from a trained model | | backtesting | POST | `/backtesting/run` | Time-series CV (rolling/expanding splits, MAE/sMAPE/WAPE/bias/stability) | -| scenarios | POST | `/scenarios/simulate` | Stateless what-if: load a baseline model, forecast, apply deterministic price/promotion/holiday/inventory/lifecycle factors, return a `ScenarioComparison` (`method="heuristic"`). Bogus `run_id` → RFC 7807 404 | -| scenarios | POST | `/scenarios` | Run a simulation and persist it as a named `scenario_plan` (raw assumptions + full comparison snapshot) | -| scenarios | GET | `/scenarios` | List saved scenario plans, newest first (`limit`/`offset`); `200` + empty list on an empty table | +| scenarios | POST | `/scenarios/simulate` | Stateless what-if: load a baseline model, forecast, apply price/promotion/holiday/inventory/lifecycle assumptions, return a `ScenarioComparison`. A `regression` baseline genuinely re-forecasts through a leakage-safe future feature frame (`method="model_exogenous"`); any other baseline applies a deterministic post-forecast multiplier (`method="heuristic"`). 
Bogus `run_id` → RFC 7807 404 | +| scenarios | POST | `/scenarios` | Run a simulation and persist it as a named `scenario_plan` (raw assumptions + full comparison snapshot); optional `tags` + `cloned_from` | +| scenarios | GET | `/scenarios` | List saved scenario plans, newest first (`limit`/`offset`, optional repeated `tags` filter — JSONB containment); `200` + empty list on an empty table | | scenarios | GET | `/scenarios/{scenario_id}` | Saved plan + embedded comparison snapshot; `404` when missing | +| scenarios | POST | `/scenarios/compare` | Rank 2-5 saved plans (`scenario_ids`, `rank_by`) against a shared baseline; returns a `MultiScenarioComparison` with ranked rows + merged multi-series chart data. Unknown `scenario_id` → 404 | | scenarios | DELETE | `/scenarios/{scenario_id}` | Delete a saved plan; `404` when missing | | registry | POST | `/registry/runs` | Create model run (pending) | | registry | GET | `/registry/runs` | List with filters + pagination + optional allow-listed `sort_by`/`sort_order` (created_at/model_type/status/store_id/product_id; unknown → default `created_at desc`) | diff --git a/docs/_base/DOMAIN_MODEL.md b/docs/_base/DOMAIN_MODEL.md index 733c9d8f..c2e6e8bc 100644 --- a/docs/_base/DOMAIN_MODEL.md +++ b/docs/_base/DOMAIN_MODEL.md @@ -9,7 +9,7 @@ | Featuresets | Computed feature matrices (in-memory; not persisted) | Time-cutoff parameter — never reads beyond `cutoff_date` | | Forecasting | Trained model artifacts on disk (joblib `.pkl`) | Model interface in `examples/models/model_interface.md`; artifact_uri returned to caller | | Backtesting | Fold results, metrics (returned in response; persisted via Registry) | `SplitConfig` (expanding/sliding, gap, horizon) — `app/features/backtesting/splitter.py` | -| Scenarios | `scenario_plan` (saved what-if plans, JSONB assumptions + comparison) | `load_model_bundle` only (never a sibling `service.py`); deterministic `adjustments.py` post-forecast multiplier | +| Scenarios | `scenario_plan` 
(saved what-if plans, JSONB assumptions + comparison) | `load_model_bundle` only (never a sibling `service.py`); `adjustments.py` heuristic multiplier or `feature_frame.py` model re-forecast; `agent_tools.py` is the agent-integration seam | | Registry | `model_run`, `run_alias`, `model_artifact` | SHA-256 hash on artifact_uri; status state machine | | RAG | `rag_source`, `rag_chunk` (with pgvector embedding column) | Content hash for idempotent indexing; embedding dimension fixed per provider | | Agents | `agent_session` (JSONB message_history) | Pydantic-validated tool args; HITL approval queue | @@ -46,16 +46,18 @@ ### `scenario_plan` (Scenarios) - **Root:** `ScenarioPlan(scenario_id: str, name: str)` - **JSONB fields:** `assumptions` (the raw `ScenarioAssumptions`), `comparison` (the full `ScenarioComparison` snapshot — stored so a reloaded plan re-renders without recomputation or the original artifact). +- **Scalar columns:** `tags` (a queryable JSONB string array — its own column, GIN-indexed, never folded into a blob) and `cloned_from` (the `scenario_id` a plan was cloned from); provenance/audit columns `source` (`'user'`|`'agent'`), `agent_session_id`, `approved_by`, `approved_at`, `approval_decision` (`'approved'`|`'rejected'`). - **Invariants:** - - `method` is CHECK-constrained to `'heuristic'` — the MVP only ever produces a deterministic post-forecast multiplier, never a re-trained causal model. - - A scenario adjustment touches only horizon (future) points; it never reads or mutates the historical series (`app/features/scenarios/tests/test_leakage.py` is the spec). + - `method` is CHECK-constrained to `IN ('heuristic','model_exogenous')`. `heuristic` is a deterministic post-forecast multiplier; `model_exogenous` is a genuine re-forecast of a regression baseline through a leakage-safe future feature frame (`feature_frame.py`). 
+ - A scenario adjustment touches only horizon (future) points; it never reads or mutates the historical series (`app/features/scenarios/tests/test_leakage.py` and `test_future_frame_leakage.py` are the spec). - JSONB columns are persisted via `model_dump(mode="json")` so `date`/`datetime` serialise to ISO strings. + - An agent-saved plan (`source='agent'`) is persisted ONLY after the human approves it through the HITL gate — it always carries the approval audit trail. ## Key Invariants — NEVER violate 1. **Time safety in features.** `app/features/featuresets/` uses only data at or before `cutoff_date`. Lags via `shift(positive)`, rolling via `shift(1).rolling(...)`, all `groupby` entity-aware. The test `app/features/featuresets/tests/test_leakage.py` is the spec — it MUST keep passing. 2. **Forward-only migrations.** Once an Alembic migration is merged, never edit it. Add a new migration to fix or evolve. -3. **HITL approval gates the agent's mutation surface.** Every tool that writes to the registry (`create_alias`, `archive_run`, …) must be in `agent_require_approval`. Widening the surface without updating that list is a security regression. +3. **HITL approval gates the agent's mutation surface.** Every tool that writes state (`create_alias`, `archive_run`, `save_scenario`, …) must be in `agent_require_approval`. Widening the surface without updating that list is a security regression. 4. **Single-host deployable.** No managed cloud service in the core path. `docker-compose up` must continue to be the only prerequisite besides Python + Node. 5. **Pre-1.0 contracts may move.** Pin the version you build against. After `v1.0.0`, full SemVer applies. 6. **Seeder is idempotent + scoped.** Never introduce a "wipe everything" path that isn't behind `--confirm` + scope flag. 
@@ -82,6 +84,9 @@ | `scenario plan` | A saved what-if analysis — a `scenario_plan` row pairing raw `ScenarioAssumptions` with a `ScenarioComparison` snapshot | seeder `scenario` (a different concept entirely) | | `assumption` (what-if) | One future change a planner posits — a price change, promotion, holiday set, inventory cap, or lifecycle stage — fed to `POST /scenarios/simulate` | forecast input, feature | | `applied factor` | The deterministic per-day multiplier `combined_daily_factor` derives from the assumptions; `1.0` means no change | weight, coefficient | +| `model_exogenous` | The scenario `method` where a regression baseline genuinely re-forecasts through the assumptions — as opposed to the `heuristic` post-forecast multiplier | re-trained model (the baseline is not re-trained, only re-run) | +| `future feature frame` | The leakage-safe `X_future` matrix `feature_frame.py` builds — long-lag, calendar, and exogenous columns the regression model consumes to re-forecast a scenario | feature matrix (that is the training-time term) | +| `scenario tag` | A free-text label on a saved `scenario_plan` (its own queryable JSONB-array column) for filtering and grouping the library | seeder `scenario` preset, registry `alias` | ## Event Taxonomy diff --git a/docs/_base/REPO_MAP_INDEX.md b/docs/_base/REPO_MAP_INDEX.md index e797bd92..5300f157 100644 --- a/docs/_base/REPO_MAP_INDEX.md +++ b/docs/_base/REPO_MAP_INDEX.md @@ -31,8 +31,8 @@ ForecastLabAI is a portfolio-grade, single-host retail-demand-forecasting system | [`frontend/src/pages/explorer/job-detail.tsx`](../../frontend/src/pages/explorer/job-detail.tsx) | The job detail page — profile, params/result JSON, error details, linked run, cancel action, live polling | Investigating a single job | | [`frontend/src/pages/explorer/run-compare.tsx`](../../frontend/src/pages/explorer/run-compare.tsx) | The run-comparison page — two run pickers, side-by-side profile, config_diff, metrics_diff with delta indicators; 
deep-linkable via `?a=&b=` | Comparing two model runs | | [`frontend/src/pages/visualize/demand.tsx`](../../frontend/src/pages/visualize/demand.tsx) | The Demand Planner page — completed `predict` jobs rolled into a multi-SKU table (tomorrow/next-week/next-month demand + inventory requirement), lead-time selector, single-SKU drill-in | Answering "how much will this SKU sell, and do I have enough stock?" | -| [`app/features/scenarios/`](../../app/features/scenarios/) | Scenario Simulation slice — `POST /scenarios/simulate` (stateless what-if) + `scenario_plan` CRUD; pure `adjustments.py` deterministic factors, never imports a sibling `service.py` | What-If planning, baseline-vs-scenario comparisons | -| [`frontend/src/pages/visualize/planner.tsx`](../../frontend/src/pages/visualize/planner.tsx) | The What-If Planner page — pick a baseline predict job, define price/promotion/holiday/inventory/lifecycle assumptions, run a simulation, save / reload / delete named plans | Answering "what if we discount this SKU 15% next week?" | +| [`app/features/scenarios/`](../../app/features/scenarios/) | Scenario Simulation slice — `/scenarios/simulate` (stateless what-if) + `scenario_plan` CRUD + `/scenarios/compare`; pure `adjustments.py` heuristic factors and `feature_frame.py` (leakage-safe future X for the `model_exogenous` re-forecast); `agent_tools.py` is the agent-integration seam; never imports a sibling `service.py` | What-If planning, baseline-vs-scenario comparisons, model-driven re-forecasts | +| [`frontend/src/pages/visualize/planner.tsx`](../../frontend/src/pages/visualize/planner.tsx) | The What-If Planner page — pick a baseline predict job, define price/promotion/holiday/inventory/lifecycle assumptions, run a simulation, save / tag / reload / delete named plans, and rank 2-5 saved plans in a multi-scenario comparison | Answering "what if we discount this SKU 15% next week?" 
| | [`alembic/versions/`](../../alembic/versions/) | Migrations through `43e35957a248_create_scenario_plan_table.py` | DB-schema questions, migration drift | | [`docs/ARCHITECTURE.md`](../ARCHITECTURE.md) | Phase-by-phase architecture narrative | High-level component reasoning | | [`docs/PHASE-index.md`](../PHASE-index.md) | Index of all 11 phase docs | Locating per-phase deep-dive | diff --git a/docs/_base/SECURITY.md b/docs/_base/SECURITY.md index 786f35b6..8a8e3e78 100644 --- a/docs/_base/SECURITY.md +++ b/docs/_base/SECURITY.md @@ -67,7 +67,7 @@ Reference: PR #115 (issue #109) introduced this pattern on `ComputeFeaturesReque - Token budget cap per session (`agent_max_tokens=4096` default). - Tool-call cap per session (`agent_max_tool_calls=10` default). - Timeout wrap around `agent.run()` / `agent.run_stream()` (`agent_timeout_seconds=120`). -- HITL approval required for mutating tools — `agent_require_approval=["create_alias","archive_run"]`. Never widen the agent's mutation surface without adding the new tool name to that list. +- HITL approval required for mutating tools — `agent_require_approval=["create_alias","archive_run","save_scenario"]`. `save_scenario` (PRP-27 Phase D) lets the experiment agent persist a `scenario_plan` row; it is gated here exactly like the registry mutations. Never widen the agent's mutation surface without adding the new tool name to that list. - Never log full prompts/responses at INFO; DEBUG only with explicit operator opt-in. ## External Integrations Security