diff --git a/alembic/versions/f7a8b9c0d123_add_exogenous_signal_and_sales_returns_tables.py b/alembic/versions/f7a8b9c0d123_add_exogenous_signal_and_sales_returns_tables.py new file mode 100644 index 00000000..e19d6e9b --- /dev/null +++ b/alembic/versions/f7a8b9c0d123_add_exogenous_signal_and_sales_returns_tables.py @@ -0,0 +1,154 @@ +"""add exogenous_signal and sales_returns tables + +Revision ID: f7a8b9c0d123 +Revises: d6e0f2g3h456 +Create Date: 2026-05-11 12:00:00.000000 + +Phase 1 of the seeder realism extension. Additive only — creates two new +fact tables to support exogenous demand signals (weather / macro / events) +and synthetic returns volume. No existing rows are touched. + +Downgrade drops both tables; any seeded rows are lost. This is acceptable +because the data is synthetic; do not run downgrade against an environment +that holds user-loaded data. +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "f7a8b9c0d123" +down_revision: str | None = "d6e0f2g3h456" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply migration: create exogenous_signal and sales_returns.""" + op.create_table( + "exogenous_signal", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("date", sa.Date(), nullable=False), + sa.Column("signal_name", sa.String(length=50), nullable=False), + sa.Column("store_id", sa.Integer(), nullable=True), + sa.Column("is_global", sa.Boolean(), nullable=False), + sa.Column("value", sa.Float(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.CheckConstraint( + "(is_global = true AND store_id IS NULL) OR " + "(is_global = false AND store_id IS NOT NULL)", + name="ck_exogenous_signal_global_consistency", + ), + sa.ForeignKeyConstraint(["date"], ["calendar.date"]), + sa.ForeignKeyConstraint(["store_id"], ["store.id"]), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_exogenous_signal_date"), "exogenous_signal", ["date"], unique=False + ) + op.create_index( + op.f("ix_exogenous_signal_signal_name"), + "exogenous_signal", + ["signal_name"], + unique=False, + ) + op.create_index( + op.f("ix_exogenous_signal_store_id"), + "exogenous_signal", + ["store_id"], + unique=False, + ) + op.create_index( + "ix_exogenous_signal_name_date", + "exogenous_signal", + ["signal_name", "date"], + unique=False, + ) + op.create_index( + "uq_exogenous_signal_global", + "exogenous_signal", + ["date", "signal_name"], + unique=True, + postgresql_where=sa.text("is_global = true"), + ) + op.create_index( + "uq_exogenous_signal_per_store", + "exogenous_signal", + ["date", "signal_name", "store_id"], + unique=True, + postgresql_where=sa.text("is_global = 
false"), + ) + + op.create_table( + "sales_returns", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("date", sa.Date(), nullable=False), + sa.Column("store_id", sa.Integer(), nullable=False), + sa.Column("product_id", sa.Integer(), nullable=False), + sa.Column("return_quantity", sa.Integer(), nullable=False), + sa.Column("return_reason", sa.String(length=50), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.CheckConstraint("return_quantity >= 1", name="ck_sales_returns_quantity_positive"), + sa.ForeignKeyConstraint(["date"], ["calendar.date"]), + sa.ForeignKeyConstraint(["product_id"], ["product.id"]), + sa.ForeignKeyConstraint(["store_id"], ["store.id"]), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_sales_returns_product_id"), "sales_returns", ["product_id"], unique=False + ) + op.create_index( + op.f("ix_sales_returns_store_id"), "sales_returns", ["store_id"], unique=False + ) + op.create_index( + "ix_sales_returns_store_product_date", + "sales_returns", + ["store_id", "product_id", "date"], + unique=False, + ) + op.create_index("ix_sales_returns_date", "sales_returns", ["date"], unique=False) + + +def downgrade() -> None: + """Revert migration: drop sales_returns and exogenous_signal. + + WARNING: Any seeded Phase 1 rows are lost. Acceptable for synthetic data + only — do not run against an environment with user-loaded signals. 
class ExogenousSignal(TimestampMixin, Base):
    """Exogenous demand-relevant signals (weather, macro index, events).

    A signal is either chain-wide (``is_global=True``, ``store_id IS NULL``)
    or per-store (``is_global=False``, ``store_id IS NOT NULL``). The pairing
    is enforced by ``ck_exogenous_signal_global_consistency``.

    Uniqueness is enforced by two partial unique indexes, one per case:
    ``(date, signal_name)`` for global rows and
    ``(date, signal_name, store_id)`` for per-store rows.

    NOTE(review): the unique indexes only make a duplicate insert fail; whether
    a seeder re-run is actually idempotent depends on how the seeder handles
    the conflict (not visible in this module) — confirm.

    Attributes:
        id: Surrogate primary key.
        date: Signal date (FK to calendar).
        signal_name: Short identifier (e.g. ``"weather_temp_c"``, ``"macro_index"``).
        store_id: Store (FK) — NULL when ``is_global=True``.
        is_global: True for chain-wide signals; mirrors ``store_id IS NULL``.
        value: Numeric value of the signal on the given date.
    """

    __tablename__ = "exogenous_signal"

    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    date: Mapped[datetime.date] = mapped_column(Date, ForeignKey("calendar.date"), index=True)
    signal_name: Mapped[str] = mapped_column(String(50), index=True)
    store_id: Mapped[int | None] = mapped_column(
        Integer, ForeignKey("store.id"), nullable=True, index=True
    )
    is_global: Mapped[bool] = mapped_column(Boolean, nullable=False)
    value: Mapped[float] = mapped_column(Float, nullable=False)

    __table_args__ = (
        # Composite lookup index for "one signal over a date window" queries.
        Index("ix_exogenous_signal_name_date", "signal_name", "date"),
        # At most one global row per (date, signal_name).
        Index(
            "uq_exogenous_signal_global",
            "date",
            "signal_name",
            unique=True,
            postgresql_where=("is_global = true"),
        ),
        # At most one row per (date, signal_name, store) for store-scoped signals.
        Index(
            "uq_exogenous_signal_per_store",
            "date",
            "signal_name",
            "store_id",
            unique=True,
            postgresql_where=("is_global = false"),
        ),
        # is_global and store_id must agree: global <=> no store.
        CheckConstraint(
            "(is_global = true AND store_id IS NULL) OR "
            "(is_global = false AND store_id IS NOT NULL)",
            name="ck_exogenous_signal_global_consistency",
        ),
    )
@router.get(
    "/exogenous",
    response_model=schemas.ExogenousSignalResponse,
    summary="Query exogenous signals",
    description=(
        "Return exogenous signal rows (Phase 1) for a given signal name and date "
        "window. Available signals: `weather_temp_c`, `macro_index`, `event_flag`."
    ),
)
async def query_exogenous(
    signal_name: str = Query(
        ...,
        min_length=1,
        max_length=50,
        description="Signal identifier (e.g. weather_temp_c, macro_index, event_flag)",
    ),
    start_date: date = Query(..., description="Window start (inclusive)"),
    end_date: date = Query(..., description="Window end (inclusive)"),
    store_id: int | None = Query(
        default=None,
        ge=1,
        description="Optional store filter. Omit to include global + per-store rows.",
    ),
    db: AsyncSession = Depends(get_db),
) -> schemas.ExogenousSignalResponse:
    """Query exogenous_signal rows for a signal name and date window.

    Delegates to ``service.query_exogenous`` and maps its failures onto HTTP
    status codes.

    Args:
        signal_name: Signal identifier to match (1-50 characters).
        start_date: Window start (inclusive).
        end_date: Window end (inclusive).
        store_id: Optional store filter (>= 1).
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        The service's ``ExogenousSignalResponse`` unchanged.

    Raises:
        HTTPException: 400 when the service raises ``ValueError`` (its message
            is passed through as the response detail); 500 with a generic
            detail for any other failure.
    """
    try:
        return await service.query_exogenous(
            db,
            signal_name=signal_name,
            start_date=start_date,
            end_date=end_date,
            store_id=store_id,
        )
    except ValueError as e:
        logger.error(
            "seeder.exogenous.query_failed",
            error=str(e),
            error_type=type(e).__name__,
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        # Full details (message, type, traceback) go to the log only. The
        # exception text may contain SQL or driver internals, so it must not
        # be echoed back to the API client.
        logger.error(
            "seeder.exogenous.query_failed",
            error=str(e),
            error_type=type(e).__name__,
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Exogenous query failed",
        ) from e
description="Number of sales_returns records (Phase 1)", + ) date_range_start: date | None = Field( default=None, description="Earliest date in sales_daily", @@ -30,6 +39,22 @@ class SeederStatus(BaseModel): ) +class ChangepointEventParam(BaseModel): + """API-facing representation of a demand changepoint (Phase 1).""" + + date: _datetime_module.date = Field(description="Changepoint impulse date") + demand_multiplier: float = Field( + ge=0.0, + description="Peak multiplier on the changepoint date", + ) + decay_days: int = Field( + default=30, + ge=0, + le=3650, + description="Exponential decay e-folding time (days). 0 = pure impulse.", + ) + + class ScenarioInfo(BaseModel): """Information about a scenario preset.""" @@ -84,6 +109,69 @@ class GenerateParams(BaseModel): description="Preview only, do not execute", ) + # Phase 1 — realism extension. All flags default off so existing + # scenarios remain byte-identical when this endpoint is called without + # the new fields. + enable_exogenous: bool = Field( + default=False, + description="Seed weather/macro/event exogenous signals (Phase 1)", + ) + enable_returns: bool = Field( + default=False, + description="Seed sales_returns rows derived from sales (Phase 1)", + ) + enable_substitution: bool = Field( + default=False, + description="Apply cross-product substitution lift on stockouts (Phase 1)", + ) + yearly_seasonality_amplitude: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description=( + "Yearly sin-wave demand amplitude (fraction). None or 0 = disabled. (Phase 1)" + ), + ) + weather_temperature_sensitivity: float | None = Field( + default=None, + ge=-1.0, + le=1.0, + description=( + "Demand delta per °C above the climatology mean. " + "Only applied when enable_exogenous=true. 
(Phase 1)" + ), + ) + changepoints: list[ChangepointEventParam] | None = Field( + default=None, + description="Optional list of demand changepoints (Phase 1)", + ) + substitute_groups: list[list[int]] | None = Field( + default=None, + description=( + "Optional list of product-ID groups whose members substitute for " + "each other on stockout. Only applied when enable_substitution=true. " + "(Phase 1)" + ), + ) + substitution_lift_on_stockout: float | None = Field( + default=None, + ge=0.0, + le=10.0, + description=( + "Demand lift distributed across in-stock group-mates when a member " + "is stocked out. Only applied when enable_substitution=true. (Phase 1)" + ), + ) + + @model_validator(mode="after") + def _validate_date_range(self) -> "GenerateParams": + """Reject inverted date ranges with a clear message.""" + if self.end_date < self.start_date: + raise ValueError( + f"end_date ({self.end_date}) must be on or after start_date ({self.start_date})" + ) + return self + class AppendParams(BaseModel): """Parameters for appending data to existing dataset.""" @@ -156,3 +244,37 @@ class VerifyResult(BaseModel): passed_count: int = Field(description="Number of passed checks") warning_count: int = Field(description="Number of warnings") failed_count: int = Field(description="Number of failures") + + +# ============================================================================ +# PHASE 1 — Exogenous signal read API +# ============================================================================ + + +class ExogenousSignalRecord(BaseModel): + """One row of the exogenous_signal table.""" + + date: _datetime_module.date = Field(description="Signal date") + signal_name: str = Field(description="Signal identifier") + store_id: int | None = Field( + default=None, + description="Store ID. 
def _apply_phase1_overrides(config: SeederConfig, params: schemas.GenerateParams) -> None:
    """Apply Phase 1 (realism) overrides from API params onto ``config``.

    Mutates ``config`` in place. Each override is a no-op when the matching
    flag/field is absent, so existing scenarios stay byte-identical when
    Phase 1 params are omitted.

    Args:
        config: Seeder configuration to mutate.
        params: API request parameters carrying the optional Phase 1 fields.
    """
    if params.enable_exogenous:
        # Replaces the whole exogenous sub-config: weather and macro on, events
        # off, sensitivity from the request (0.0 when not supplied). Every
        # other field falls back to the ExogenousSignalConfig defaults.
        config.exogenous = ExogenousSignalConfig(
            enable_weather=True,
            enable_macro=True,
            enable_events=False,
            weather_temperature_sensitivity=(
                params.weather_temperature_sensitivity
                if params.weather_temperature_sensitivity is not None
                else 0.0
            ),
        )
    elif params.weather_temperature_sensitivity is not None:
        # Sensitivity passed without enable_exogenous: the value IS recorded on
        # the existing exogenous config (the enable_* flags are left untouched).
        # NOTE(review): it is expected to have no demand effect because no
        # weather rows are generated in that case — confirm against the
        # generator, which is not visible here.
        config.exogenous = replace(
            config.exogenous,
            weather_temperature_sensitivity=params.weather_temperature_sensitivity,
        )

    # None and 0.0 both mean "leave yearly seasonality as configured".
    if (
        params.yearly_seasonality_amplitude is not None
        and params.yearly_seasonality_amplitude > 0.0
    ):
        config.multi_seasonality = MultiSeasonalityConfig(
            yearly_seasonality_amplitude=params.yearly_seasonality_amplitude,
        )

    # Truthiness check: both None and an empty list skip the override.
    if params.changepoints:
        config.changepoints = ChangepointConfig(
            changepoints=[
                ChangepointEvent(
                    date=cp.date,
                    demand_multiplier=cp.demand_multiplier,
                    decay_days=cp.decay_days,
                )
                for cp in params.changepoints
            ]
        )

    if params.enable_returns:
        config.returns = ReturnsConfig(enable=True)

    if params.enable_substitution:
        config.substitution = SubstitutionConfig(
            enable=True,
            # Copy each group so the config does not alias the request's lists.
            substitute_groups=(
                [list(group) for group in params.substitute_groups]
                if params.substitute_groups is not None
                else []
            ),
            # 0.5 is the fallback lift when the request does not supply one.
            substitution_lift_on_stockout=(
                params.substitution_lift_on_stockout
                if params.substitution_lift_on_stockout is not None
                else 0.5
            ),
        )
@@ -80,6 +156,8 @@ def _build_config_from_params(params: schemas.GenerateParams) -> SeederConfig: sparsity=SparsityConfig(missing_combinations_pct=params.sparsity), ) + _apply_phase1_overrides(config, params) + settings = get_settings() config.batch_size = settings.seeder_batch_size config.enable_progress = settings.seeder_enable_progress @@ -107,6 +185,8 @@ async def get_status(db: AsyncSession) -> schemas.SeederStatus: ("inventory", InventorySnapshotDaily), ("price_history", PriceHistory), ("promotions", Promotion), + ("exogenous_signals", ExogenousSignal), + ("sales_returns", SalesReturn), ] counts: dict[str, int] = {} @@ -141,6 +221,8 @@ async def get_status(db: AsyncSession) -> schemas.SeederStatus: inventory=counts["inventory"], price_history=counts["price_history"], promotions=counts["promotions"], + exogenous_signals=counts["exogenous_signals"], + sales_returns=counts["sales_returns"], date_range_start=date_range_start, date_range_end=date_range_end, last_updated=last_updated, @@ -257,6 +339,8 @@ async def generate_data( "price_history": 0, "promotions": 0, "inventory": 0, + "exogenous_signals": 0, + "sales_returns": 0, }, duration_seconds=0.0, message=f"Dry run: would generate data with scenario '{params.scenario}'", @@ -299,6 +383,8 @@ async def generate_data( "price_history": result.price_history_count, "promotions": result.promotions_count, "inventory": result.inventory_count, + "exogenous_signals": result.exogenous_count, + "sales_returns": result.returns_count, }, duration_seconds=round(duration, 2), message=f"Successfully generated {result.sales_count:,} sales records with seed {params.seed}", @@ -367,6 +453,8 @@ async def append_data( "price_history": result.price_history_count, "promotions": result.promotions_count, "inventory": result.inventory_count, + "exogenous_signals": result.exogenous_count, + "sales_returns": result.returns_count, }, duration_seconds=round(duration, 2), message=f"Appended {result.sales_count:,} sales records for date range 
# ============================================================================
# PHASE 1 — Exogenous signal read API
# ============================================================================


EXOGENOUS_MAX_DATE_RANGE_DAYS = 365 * 3  # 3 years — matches feature_max_lookback_days
EXOGENOUS_MAX_RECORDS = 50_000


async def query_exogenous(
    db: AsyncSession,
    signal_name: str,
    start_date: date,
    end_date: date,
    store_id: int | None,
) -> schemas.ExogenousSignalResponse:
    """Fetch ``exogenous_signal`` rows for one signal inside a date window.

    Args:
        db: Async database session.
        signal_name: Exact signal identifier (e.g. ``"weather_temp_c"``).
        start_date: Window start (inclusive).
        end_date: Window end (inclusive).
        store_id: Optional store filter. When given, rows for that store are
            returned together with global rows; when None, no store filter is
            applied, so global and all store-scoped rows are returned.

    Returns:
        ExogenousSignalResponse echoing the query arguments, with rows ordered
        by date ascending and, within a date, NULL ``store_id`` rows first.

    Raises:
        ValueError: If the window is inverted, spans more than
            ``EXOGENOUS_MAX_DATE_RANGE_DAYS`` days, or matches more than
            ``EXOGENOUS_MAX_RECORDS`` rows.
    """
    # Validate the window before touching the database.
    if start_date > end_date:
        raise ValueError(f"end_date ({end_date}) must be on or after start_date ({start_date})")
    window_days = (end_date - start_date).days
    if window_days > EXOGENOUS_MAX_DATE_RANGE_DAYS:
        raise ValueError(
            f"Date range too large ({window_days} days); max is {EXOGENOUS_MAX_DATE_RANGE_DAYS} days"
        )

    # Collect every filter first; they are AND-ed together by ``where``.
    conditions = [
        ExogenousSignal.signal_name == signal_name,
        ExogenousSignal.date >= start_date,
        ExogenousSignal.date <= end_date,
    ]
    if store_id is not None:
        conditions.append(
            (ExogenousSignal.store_id == store_id) | (ExogenousSignal.is_global.is_(True))
        )

    # Ask for one row beyond the cap so an overflow can be detected.
    query = (
        select(ExogenousSignal)
        .where(*conditions)
        .order_by(ExogenousSignal.date.asc(), ExogenousSignal.store_id.asc().nullsfirst())
        .limit(EXOGENOUS_MAX_RECORDS + 1)
    )
    fetched = (await db.execute(query)).scalars().all()
    if len(fetched) > EXOGENOUS_MAX_RECORDS:
        raise ValueError(
            f"Query exceeded maximum row cap ({EXOGENOUS_MAX_RECORDS}); "
            "narrow the date range or filter by store_id"
        )

    records = [
        schemas.ExogenousSignalRecord(
            date=signal.date,
            signal_name=signal.signal_name,
            store_id=signal.store_id,
            is_global=signal.is_global,
            value=signal.value,
        )
        for signal in fetched
    ]
    record_count = len(records)

    logger.info(
        "seeder.exogenous.queried",
        signal_name=signal_name,
        start_date=str(start_date),
        end_date=str(end_date),
        store_id=store_id,
        rows=record_count,
    )

    return schemas.ExogenousSignalResponse(
        signal_name=signal_name,
        start_date=start_date,
        end_date=end_date,
        store_id=store_id,
        records=records,
        total=record_count,
    )
fastapi.testclient import TestClient + +from app.features.seeder import schemas +from app.main import app + + +@pytest.fixture +def client(): + return TestClient(app) + + +class TestExogenousRoute: + def test_happy_path(self, client): + mock_response = schemas.ExogenousSignalResponse( + signal_name="weather_temp_c", + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 2), + store_id=None, + records=[ + schemas.ExogenousSignalRecord( + date=date(2024, 1, 1), + signal_name="weather_temp_c", + store_id=1, + is_global=False, + value=12.3, + ), + schemas.ExogenousSignalRecord( + date=date(2024, 1, 2), + signal_name="weather_temp_c", + store_id=1, + is_global=False, + value=13.1, + ), + ], + total=2, + ) + + async def _fake(*args, **kwargs): + return mock_response + + with patch( + "app.features.seeder.routes.service.query_exogenous", + side_effect=_fake, + ): + response = client.get( + "/seeder/exogenous", + params={ + "signal_name": "weather_temp_c", + "start_date": "2024-01-01", + "end_date": "2024-01-02", + }, + ) + + assert response.status_code == status.HTTP_200_OK + body = response.json() + assert body["signal_name"] == "weather_temp_c" + assert body["total"] == 2 + assert len(body["records"]) == 2 + + def test_rejects_inverted_window(self, client): + # Service raises ValueError → 400 per the error handler. + async def _fake(*args, **kwargs): + raise ValueError("end_date must be on or after start_date") + + with patch( + "app.features.seeder.routes.service.query_exogenous", + side_effect=_fake, + ): + response = client.get( + "/seeder/exogenous", + params={ + "signal_name": "weather_temp_c", + "start_date": "2024-12-31", + "end_date": "2024-01-01", + }, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_requires_signal_name(self, client): + response = client.get( + "/seeder/exogenous", + params={"start_date": "2024-01-01", "end_date": "2024-01-02"}, + ) + # Missing required param → FastAPI validation 422. 
+ assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_optional_store_id_passes_through(self, client): + captured: dict[str, object] = {} + + async def _fake(db, signal_name, start_date, end_date, store_id): + captured["store_id"] = store_id + return schemas.ExogenousSignalResponse( + signal_name=signal_name, + start_date=start_date, + end_date=end_date, + store_id=store_id, + records=[], + total=0, + ) + + with patch( + "app.features.seeder.routes.service.query_exogenous", + side_effect=_fake, + ): + response = client.get( + "/seeder/exogenous", + params={ + "signal_name": "weather_temp_c", + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "store_id": 7, + }, + ) + assert response.status_code == status.HTTP_200_OK + assert captured["store_id"] == 7 diff --git a/app/features/seeder/tests/test_phase1_service.py b/app/features/seeder/tests/test_phase1_service.py new file mode 100644 index 00000000..51d5bd57 --- /dev/null +++ b/app/features/seeder/tests/test_phase1_service.py @@ -0,0 +1,136 @@ +"""Service-layer tests for Phase 1 seeder features. + +Covers: +- _apply_phase1_overrides / _build_config_from_params translation of new + GenerateParams fields into SeederConfig sub-configs. +- GenerateParams validation (inverted date range). +- query_exogenous date-window guards. 
+""" + +from datetime import date + +import pytest + +from app.features.seeder import schemas, service + + +class TestApplyPhase1Overrides: + def test_defaults_leave_phase1_off(self): + """Calling generate with default params must keep Phase 1 off.""" + params = schemas.GenerateParams() + config = service._build_config_from_params(params) + assert config.exogenous.enable_weather is False + assert config.exogenous.enable_macro is False + assert config.multi_seasonality.yearly_seasonality_amplitude == 0.0 + assert config.changepoints.changepoints == [] + assert config.returns.enable is False + assert config.substitution.enable is False + + def test_enable_exogenous_turns_on_weather_and_macro(self): + params = schemas.GenerateParams( + enable_exogenous=True, + weather_temperature_sensitivity=0.03, + ) + config = service._build_config_from_params(params) + assert config.exogenous.enable_weather is True + assert config.exogenous.enable_macro is True + assert config.exogenous.weather_temperature_sensitivity == 0.03 + + def test_yearly_seasonality_passthrough(self): + params = schemas.GenerateParams(yearly_seasonality_amplitude=0.25) + config = service._build_config_from_params(params) + assert config.multi_seasonality.yearly_seasonality_amplitude == 0.25 + + def test_changepoint_list_translation(self): + params = schemas.GenerateParams( + changepoints=[ + schemas.ChangepointEventParam( + date=date(2024, 3, 15), + demand_multiplier=2.0, + decay_days=60, + ) + ] + ) + config = service._build_config_from_params(params) + assert len(config.changepoints.changepoints) == 1 + cp = config.changepoints.changepoints[0] + assert cp.date == date(2024, 3, 15) + assert cp.demand_multiplier == 2.0 + assert cp.decay_days == 60 + + def test_enable_returns_flips_returns_config(self): + params = schemas.GenerateParams(enable_returns=True) + config = service._build_config_from_params(params) + assert config.returns.enable is True + + def test_enable_substitution_with_groups(self): + params 
= schemas.GenerateParams( + enable_substitution=True, + substitute_groups=[[1, 2, 3], [4, 5]], + substitution_lift_on_stockout=0.4, + ) + config = service._build_config_from_params(params) + assert config.substitution.enable is True + assert config.substitution.substitute_groups == [[1, 2, 3], [4, 5]] + assert config.substitution.substitution_lift_on_stockout == 0.4 + + def test_phase1_overrides_preserve_scenario_dimensions(self): + """A Phase 1 override must not clobber scenario-defined region/brand + lists — regression for the bug fix in service._build_config_from_params. + """ + params = schemas.GenerateParams( + scenario="holiday_rush", + stores=20, + products=80, + enable_returns=True, + ) + config = service._build_config_from_params(params) + assert config.dimensions.stores == 20 + assert config.dimensions.products == 80 + # Holiday rush keeps its 4 holidays + monthly seasonality through + # the phase-1 path. + assert len(config.holidays) == 4 + assert config.time_series.monthly_seasonality[12] == 1.8 + + +class TestGenerateParamsValidation: + def test_rejects_inverted_date_range(self): + with pytest.raises(ValueError, match="must be on or after"): + schemas.GenerateParams( + start_date=date(2024, 12, 31), + end_date=date(2024, 1, 1), + ) + + def test_yearly_amplitude_bounds(self): + # ge=0.0 / le=1.0 enforced by Field. + with pytest.raises(ValueError): + schemas.GenerateParams(yearly_seasonality_amplitude=-0.1) + with pytest.raises(ValueError): + schemas.GenerateParams(yearly_seasonality_amplitude=1.5) + + +class TestQueryExogenousValidation: + """Date-window guards on the service helper. 
The DB path is covered in + integration tests.""" + + @pytest.mark.asyncio + async def test_rejects_inverted_window(self): + with pytest.raises(ValueError, match="must be on or after"): + await service.query_exogenous( + db=None, # type: ignore[arg-type] + signal_name="weather_temp_c", + start_date=date(2024, 12, 31), + end_date=date(2024, 1, 1), + store_id=None, + ) + + @pytest.mark.asyncio + async def test_rejects_overlong_window(self): + with pytest.raises(ValueError, match="too large"): + await service.query_exogenous( + db=None, # type: ignore[arg-type] + signal_name="weather_temp_c", + start_date=date(2020, 1, 1), + end_date=date(2030, 1, 1), + store_id=None, + ) diff --git a/app/features/seeder/tests/test_service.py b/app/features/seeder/tests/test_service.py index 2ec313aa..3e938ddf 100644 --- a/app/features/seeder/tests/test_service.py +++ b/app/features/seeder/tests/test_service.py @@ -189,10 +189,10 @@ async def test_returns_status(self): """Test status is returned with counts.""" mock_db = AsyncMock() - # Mock the count queries - return different values for each table - mock_results = [10, 50, 365, 182500, 182500, 1500, 500] + # Mock the count queries - return different values for each table. + # Phase 1 adds exogenous_signals (2520) and sales_returns (3650). 
+ mock_results = [10, 50, 365, 182500, 182500, 1500, 500, 2520, 3650] mock_db.execute.side_effect = [ - # Counts for each table *[MagicMock(scalar=MagicMock(return_value=count)) for count in mock_results], # Date range query MagicMock(fetchone=MagicMock(return_value=(date(2024, 1, 1), date(2024, 12, 31)))), @@ -206,15 +206,17 @@ async def test_returns_status(self): assert status.products == 50 assert status.calendar == 365 assert status.sales == 182500 + assert status.exogenous_signals == 2520 + assert status.sales_returns == 3650 @pytest.mark.asyncio async def test_empty_database(self): """Test status for empty database.""" mock_db = AsyncMock() - # Mock empty counts + # Mock empty counts (9 tables: 7 original + 2 Phase 1). mock_db.execute.side_effect = [ - *[MagicMock(scalar=MagicMock(return_value=0)) for _ in range(7)], + *[MagicMock(scalar=MagicMock(return_value=0)) for _ in range(9)], ] status = await service.get_status(mock_db) @@ -222,6 +224,8 @@ async def test_empty_database(self): assert status.stores == 0 assert status.products == 0 assert status.sales == 0 + assert status.exogenous_signals == 0 + assert status.sales_returns == 0 assert status.date_range_start is None assert status.date_range_end is None diff --git a/app/shared/seeder/config.py b/app/shared/seeder/config.py index 3f7cd922..d1b3783c 100644 --- a/app/shared/seeder/config.py +++ b/app/shared/seeder/config.py @@ -126,6 +126,148 @@ class HolidayConfig: multiplier: float = 1.5 +@dataclass +class ExogenousSignalConfig: + """Configuration for exogenous demand signals (weather, macro, events). + + All signals are disabled by default — enabling them does not change the + sales math unless `weather_temperature_sensitivity` is also non-zero or + a feature consumer reads `exogenous_signal` rows. Default values keep + existing scenarios byte-identical. + + Attributes: + enable_weather: Emit `weather_temp_c` rows per (store, date). + enable_macro: Emit `macro_index` rows per date (random walk). 
+ enable_events: Emit `event_flag` rows per date (binary, sparse). + weather_temperature_sensitivity: Demand delta as a fraction per °C + above/below the climatological mean. 0.0 = no demand impact even + when weather rows are emitted. + weather_climatology_mean_c: Annual mean temperature (°C) used as the + sinusoidal center for weather generation. + weather_amplitude_c: Peak amplitude of the seasonal sin wave (mean ± this). + weather_noise_sigma_c: Gaussian noise standard deviation in °C. + macro_indicator_lag_days: How many days a macro signal lags demand by + (consumers may use this; the generator itself emits values daily). + macro_initial_value: Starting value of the random-walk index. + macro_step_sigma: Standard deviation of the daily Gaussian increment. + event_dates: Specific dates marked with `event_flag=1` (e.g. promo + launch days). Empty list = no event rows emitted even when + `enable_events=True`. + """ + + enable_weather: bool = False + enable_macro: bool = False + enable_events: bool = False + weather_temperature_sensitivity: float = 0.0 + weather_climatology_mean_c: float = 15.0 + weather_amplitude_c: float = 12.0 + weather_noise_sigma_c: float = 2.0 + macro_indicator_lag_days: int = 0 + macro_initial_value: float = 100.0 + macro_step_sigma: float = 0.5 + event_dates: list[date] = field(default_factory=list) + + +@dataclass +class MultiSeasonalityConfig: + """Configuration for yearly seasonality on top of weekly + monthly. + + Demand multiplier on day-of-year d is `1 + amplitude * sin(2π·(d + phase)/365)`. + + Attributes: + yearly_seasonality_amplitude: Fraction of base demand swung by the + yearly sin wave (e.g. 0.15 = ±15%). 0.0 disables. + yearly_phase_offset_days: Phase shift in days (positive = earlier peak). + """ + + yearly_seasonality_amplitude: float = 0.0 + yearly_phase_offset_days: int = 0 + + +@dataclass +class ChangepointEvent: + """A single demand changepoint (COVID-style impulse + exponential decay).
+ + Demand multiplier on day t for a changepoint at day t0 is: + `1 + (demand_multiplier - 1) * exp(-(t - t0) / decay_days)` + for `t >= t0`; 1.0 otherwise. + + Attributes: + date: Date of the changepoint impulse. + demand_multiplier: Peak multiplier on the changepoint date. + decay_days: e-folding time of the exponential decay. 0 = pure impulse. + """ + + date: date + demand_multiplier: float = 1.0 + decay_days: int = 30 + + +@dataclass +class ChangepointConfig: + """Configuration for trend changepoints. + + Attributes: + changepoints: List of changepoint events. Empty = disabled. + """ + + changepoints: list[ChangepointEvent] = field(default_factory=list) + + +@dataclass +class ReturnsConfig: + """Configuration for synthetic returns volume. + + Attributes: + enable: Whether to emit `sales_returns` rows at all. + return_probability: Probability that a given sale generates a return + (0.0 to 1.0). + return_lag_days_min: Minimum days between sale and return. + return_lag_days_max: Maximum days between sale and return. + return_quantity_fraction: Fraction of the original sale quantity that + is returned (clamped to ≥ 1 unit when a return fires). + return_reason_distribution: Probability-weighted reasons. Weights are + normalized at use time. + """ + + enable: bool = False + return_probability: float = 0.02 + return_lag_days_min: int = 1 + return_lag_days_max: int = 14 + return_quantity_fraction: float = 0.5 + return_reason_distribution: dict[str, float] = field( + default_factory=lambda: { + "defective": 0.25, + "wrong_size": 0.20, + "not_as_described": 0.15, + "changed_mind": 0.30, + "damaged_in_transit": 0.10, + } + ) + + +@dataclass +class SubstitutionConfig: + """Configuration for cross-product substitution on stockout. + + When product A in a substitute group is stocked out at a given store on + a given date, each other group-mate B sees its demand multiplied by + `1 + substitution_lift_on_stockout / (group_size - 1)` for that day. 
+ + Attributes: + enable: Whether substitution is applied. + substitute_groups: Sets of product IDs that substitute for each + other. A product may appear in multiple groups. + substitution_lift_on_stockout: Total demand lift distributed across + group-mates when one member is stocked out (e.g. 0.5 = +50% + split among the others). + """ + + enable: bool = False + substitute_groups: list[list[int]] = field(default_factory=list) + substitution_lift_on_stockout: float = 0.0 + + @dataclass class SeederConfig: """Master configuration for the data seeder. @@ -139,6 +281,11 @@ class SeederConfig: retail: Retail-specific pattern configuration. sparsity: Data sparsity configuration. holidays: List of holiday configurations. + exogenous: Phase 1 exogenous signal generation (disabled by default). + multi_seasonality: Phase 1 yearly seasonality (disabled by default). + changepoints: Phase 1 trend changepoints (empty by default). + returns: Phase 1 returns volume (disabled by default). + substitution: Phase 1 stockout substitution (disabled by default). batch_size: Batch size for database inserts. enable_progress: Whether to show progress bars. 
""" @@ -151,6 +298,11 @@ class SeederConfig: retail: RetailPatternConfig = field(default_factory=RetailPatternConfig) sparsity: SparsityConfig = field(default_factory=SparsityConfig) holidays: list[HolidayConfig] = field(default_factory=list) + exogenous: ExogenousSignalConfig = field(default_factory=ExogenousSignalConfig) + multi_seasonality: MultiSeasonalityConfig = field(default_factory=MultiSeasonalityConfig) + changepoints: ChangepointConfig = field(default_factory=ChangepointConfig) + returns: ReturnsConfig = field(default_factory=ReturnsConfig) + substitution: SubstitutionConfig = field(default_factory=SubstitutionConfig) batch_size: int = 1000 enable_progress: bool = True diff --git a/app/shared/seeder/core.py b/app/shared/seeder/core.py index 830ac962..373f4209 100644 --- a/app/shared/seeder/core.py +++ b/app/shared/seeder/core.py @@ -15,22 +15,27 @@ from app.core.logging import get_logger from app.features.data_platform.models import ( Calendar, + ExogenousSignal, InventorySnapshotDaily, PriceHistory, Product, Promotion, SalesDaily, + SalesReturn, Store, ) from app.shared.seeder.generators import ( CalendarGenerator, + ExogenousSignalGenerator, InventorySnapshotGenerator, PriceHistoryGenerator, ProductGenerator, PromotionGenerator, + ReturnsGenerator, SalesDailyGenerator, StoreGenerator, ) +from app.shared.seeder.generators.exogenous import WEATHER_SIGNAL_NAME if TYPE_CHECKING: from app.shared.seeder.config import SeederConfig @@ -50,6 +55,8 @@ class SeederResult: price_history_count: Number of price history records. promotions_count: Number of promotions generated. inventory_count: Number of inventory snapshots. + exogenous_count: Number of exogenous signal records (Phase 1). + returns_count: Number of sales return records (Phase 1). seed: Random seed used. 
""" @@ -60,6 +67,8 @@ class SeederResult: price_history_count: int = 0 promotions_count: int = 0 inventory_count: int = 0 + exogenous_count: int = 0 + returns_count: int = 0 seed: int = 42 @@ -182,13 +191,53 @@ async def _generate_dimensions( return store_ids, product_data, dates + async def _generate_exogenous( + self, + db: AsyncSession, + store_ids: list[int], + dates: list[date], + ) -> tuple[int, dict[tuple[int, date], float]]: + """Generate exogenous signals (Phase 1). + + Returns: + Tuple of (rows_inserted, weather_lookup) where ``weather_lookup`` + is ``{(store_id, date): temp_c}`` for downstream demand math. + Empty dict if weather is disabled. + """ + exo_gen = ExogenousSignalGenerator(self.rng, self.config.exogenous) + records = exo_gen.generate(dates, store_ids) + + if not records: + return 0, {} + + logger.info("seeder.exogenous.generating", count=len(records)) + inserted = await self._batch_insert(db, ExogenousSignal, records) + + weather_lookup: dict[tuple[int, date], float] = {} + if self.config.exogenous.enable_weather: + for r in records: + if r["signal_name"] != WEATHER_SIGNAL_NAME: + continue + store_id = r["store_id"] + signal_date = r["date"] + value = r["value"] + if ( + isinstance(store_id, int) + and isinstance(signal_date, date) + and isinstance(value, float) + ): + weather_lookup[(store_id, signal_date)] = value + + return inserted, weather_lookup + async def _generate_facts( self, db: AsyncSession, store_ids: list[int], product_data: list[tuple[int, Decimal]], dates: list[date], - ) -> tuple[int, int, int, int]: + weather_lookup: dict[tuple[int, date], float] | None = None, + ) -> tuple[int, int, int, int, int]: """Generate and insert fact tables. Args: @@ -196,9 +245,14 @@ async def _generate_facts( store_ids: List of store IDs. product_data: List of (product_id, base_price) tuples. dates: List of dates. + weather_lookup: Optional ``{(store_id, date): temp_c}`` from the + exogenous generator. 
Demand picks up weather sensitivity only + when this dict is non-empty AND + ``config.exogenous.weather_temperature_sensitivity`` is non-zero. Returns: - Tuple of (sales_count, price_history_count, promotions_count, inventory_count). + Tuple of (sales_count, price_history_count, promotions_count, + inventory_count, returns_count). """ product_ids = [pid for pid, _ in product_data] @@ -255,13 +309,26 @@ async def _generate_facts( await self._batch_insert(db, InventorySnapshotDaily, inventory_records) - # Generate sales (depends on promotions and stockouts) + # Generate sales (depends on promotions and stockouts). Phase 1 + # extensions stay as None / 0 when their config flags are off so the + # disabled-path is byte-identical with pre-Phase-1. + weather_lookup_for_sales = ( + weather_lookup + if weather_lookup and self.config.exogenous.weather_temperature_sensitivity != 0.0 + else None + ) sales_gen = SalesDailyGenerator( self.rng, self.config.time_series, self.config.retail, self.config.sparsity, self.config.holidays, + multi_seasonality=self.config.multi_seasonality, + changepoints=self.config.changepoints, + substitution=self.config.substitution, + exogenous_weather=weather_lookup_for_sales, + weather_temperature_sensitivity=(self.config.exogenous.weather_temperature_sensitivity), + weather_climatology_mean_c=self.config.exogenous.weather_climatology_mean_c, ) sales_records = sales_gen.generate( store_ids, @@ -278,11 +345,20 @@ async def _generate_facts( await self._batch_insert(db, SalesDaily, sales_records) + # Generate returns (Phase 1) — depends on sales. Returns config is + # disabled by default; generator short-circuits to an empty list. 
+ returns_gen = ReturnsGenerator(self.rng, self.config.returns) + returns_records = returns_gen.generate(sales_records, self.config.end_date) + if returns_records: + logger.info("seeder.returns.generating", count=len(returns_records)) + await self._batch_insert(db, SalesReturn, returns_records) + return ( len(sales_records), len(price_records), len(promo_records), len(inventory_records), + len(returns_records), ) async def generate_full(self, db: AsyncSession) -> SeederResult: @@ -309,10 +385,17 @@ async def generate_full(self, db: AsyncSession) -> SeederResult: # Generate dimensions first store_ids, product_data, dates = await self._generate_dimensions(db) + # Phase 1: generate exogenous signals (no-op when no signal is enabled). + exogenous_count, weather_lookup = await self._generate_exogenous(db, store_ids, dates) + # Generate facts - sales_count, price_count, promo_count, inventory_count = await self._generate_facts( - db, store_ids, product_data, dates - ) + ( + sales_count, + price_count, + promo_count, + inventory_count, + returns_count, + ) = await self._generate_facts(db, store_ids, product_data, dates, weather_lookup) # Commit all changes await db.commit() @@ -325,6 +408,8 @@ async def generate_full(self, db: AsyncSession) -> SeederResult: price_history_count=price_count, promotions_count=promo_count, inventory_count=inventory_count, + exogenous_count=exogenous_count, + returns_count=returns_count, seed=self.config.seed, ) @@ -334,6 +419,8 @@ async def generate_full(self, db: AsyncSession) -> SeederResult: products=result.products_count, calendar_days=result.calendar_days, sales=result.sales_count, + exogenous=result.exogenous_count, + returns=result.returns_count, seed=self.config.seed, ) @@ -397,10 +484,17 @@ async def append_data( dates.append(current) current += timedelta(days=1) + # Phase 1: append exogenous signals for the new range (no-op when off). 
+ exogenous_count, weather_lookup = await self._generate_exogenous(db, store_ids, dates) + # Generate facts for new date range - sales_count, price_count, promo_count, inventory_count = await self._generate_facts( - db, store_ids, product_data, dates - ) + ( + sales_count, + price_count, + promo_count, + inventory_count, + returns_count, + ) = await self._generate_facts(db, store_ids, product_data, dates, weather_lookup) await db.commit() @@ -412,6 +506,8 @@ async def append_data( price_history_count=price_count, promotions_count=promo_count, inventory_count=inventory_count, + exogenous_count=exogenous_count, + returns_count=returns_count, seed=self.config.seed, ) @@ -419,6 +515,8 @@ async def append_data( "seeder.append.completed", calendar_days=result_data.calendar_days, sales=result_data.sales_count, + exogenous=result_data.exogenous_count, + returns=result_data.returns_count, ) return result_data @@ -441,8 +539,13 @@ async def delete_data( """ counts: dict[str, int] = {} - # Get current counts + # Get current counts. Phase 1 tables are listed BEFORE the older fact + # tables so they're deleted first — sales_returns FKs to product/store + # and exogenous_signal FKs to store/calendar, so cleaning them up + # ahead of the dimension/calendar wipe avoids FK violations. 
fact_tables = [ + ("sales_returns", SalesReturn), + ("exogenous_signal", ExogenousSignal), ("sales_daily", SalesDaily), ("inventory_snapshot_daily", InventorySnapshotDaily), ("price_history", PriceHistory), @@ -527,6 +630,8 @@ async def get_current_counts(self, db: AsyncSession) -> dict[str, int]: ("price_history", PriceHistory), ("promotion", Promotion), ("inventory_snapshot_daily", InventorySnapshotDaily), + ("exogenous_signal", ExogenousSignal), + ("sales_returns", SalesReturn), ] counts: dict[str, int] = {} @@ -585,4 +690,27 @@ async def verify_data_integrity(self, db: AsyncSession) -> list[str]: f"Calendar gap detected: expected {expected_days} days, found {actual_days}" ) + # Phase 1: sales_returns must never carry quantity <= 0 (CHECK + # constraint guards this at the DB layer, but a defensive count + # catches drift if a future generator drops the invariant). + neg_return_check = text("SELECT COUNT(*) FROM sales_returns WHERE return_quantity < 1") + result = await db.execute(neg_return_check) + neg_returns = result.scalar() or 0 + if neg_returns > 0: + errors.append(f"Found {neg_returns} sales_returns with non-positive quantity") + + # Phase 1: exogenous_signal global/per-store consistency. 
+ bad_global_check = text( + "SELECT COUNT(*) FROM exogenous_signal " + "WHERE (is_global = true AND store_id IS NOT NULL) " + " OR (is_global = false AND store_id IS NULL)" + ) + result = await db.execute(bad_global_check) + bad_global = result.scalar() or 0 + if bad_global > 0: + errors.append( + f"Found {bad_global} exogenous_signal rows violating " + "is_global / store_id consistency" + ) + return errors diff --git a/app/shared/seeder/generators/__init__.py b/app/shared/seeder/generators/__init__.py index a8083550..53c993f5 100644 --- a/app/shared/seeder/generators/__init__.py +++ b/app/shared/seeder/generators/__init__.py @@ -1,21 +1,25 @@ -"""Data generators for dimensions and facts.""" - -from app.shared.seeder.generators.calendar import CalendarGenerator -from app.shared.seeder.generators.facts import ( - InventorySnapshotGenerator, - PriceHistoryGenerator, - PromotionGenerator, - SalesDailyGenerator, -) -from app.shared.seeder.generators.product import ProductGenerator -from app.shared.seeder.generators.store import StoreGenerator - -__all__ = [ - "CalendarGenerator", - "InventorySnapshotGenerator", - "PriceHistoryGenerator", - "ProductGenerator", - "PromotionGenerator", - "SalesDailyGenerator", - "StoreGenerator", -] +"""Data generators for dimensions and facts.""" + +from app.shared.seeder.generators.calendar import CalendarGenerator +from app.shared.seeder.generators.exogenous import ExogenousSignalGenerator +from app.shared.seeder.generators.facts import ( + InventorySnapshotGenerator, + PriceHistoryGenerator, + PromotionGenerator, + SalesDailyGenerator, +) +from app.shared.seeder.generators.product import ProductGenerator +from app.shared.seeder.generators.returns import ReturnsGenerator +from app.shared.seeder.generators.store import StoreGenerator + +__all__ = [ + "CalendarGenerator", + "ExogenousSignalGenerator", + "InventorySnapshotGenerator", + "PriceHistoryGenerator", + "ProductGenerator", + "PromotionGenerator", + "ReturnsGenerator", + 
"SalesDailyGenerator", + "StoreGenerator", +] diff --git a/app/shared/seeder/generators/exogenous.py b/app/shared/seeder/generators/exogenous.py new file mode 100644 index 00000000..c720ba48 --- /dev/null +++ b/app/shared/seeder/generators/exogenous.py @@ -0,0 +1,146 @@ +"""Exogenous signal generator (weather, macro index, event flags). + +Phase 1 of the seeder realism extension. Produces rows for the +``exogenous_signal`` table. Each enabled signal contributes records; +disabled signals contribute zero rows so callers that don't opt in see +no Phase 1 side effects. + +The output schema matches ``app.features.data_platform.models.ExogenousSignal``: + + {"date", "signal_name", "store_id", "is_global", "value"} + +Reproducibility: this generator uses the seeder's ``random.Random`` instance +(NOT numpy.random) so identical seeds produce identical signal series. +""" + +from __future__ import annotations + +import math +import random +from datetime import date +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from app.shared.seeder.config import ExogenousSignalConfig + + +WEATHER_SIGNAL_NAME = "weather_temp_c" +MACRO_SIGNAL_NAME = "macro_index" +EVENT_SIGNAL_NAME = "event_flag" + + +class ExogenousSignalGenerator: + """Generator for exogenous demand signals. + + Produces one row per (signal, date[, store]) for each enabled signal: + + - ``weather_temp_c``: per (store, date). Temperature in °C following a + yearly sin wave with Gaussian noise. ``is_global=False``. + - ``macro_index``: per date. Random walk starting at + ``macro_initial_value``. ``is_global=True``. + - ``event_flag``: per ``event_dates`` entry. Binary 1.0 marker. + ``is_global=True``. + """ + + def __init__(self, rng: random.Random, config: ExogenousSignalConfig) -> None: + """Initialize the generator. + + Args: + rng: Seeded random number generator. + config: Exogenous signal configuration. 
+ """ + self.rng = rng + self.config = config + + def _weather_row( + self, + signal_date: date, + store_id: int, + day_of_year: int, + ) -> dict[str, date | int | bool | str | float | None]: + """Compute one weather sample for (store, date). + + Uses a sinusoidal seasonal cycle around the climatological mean with + peak in mid-July (day-of-year 196) for the northern hemisphere. + """ + # Phase chosen so peak is around day 196 (mid-July): sin peaks at π/2, + # so we want 2π(d - 105)/365 = π/2 → d = 196. + phase_rad = 2.0 * math.pi * (day_of_year - 105) / 365.0 + seasonal = self.config.weather_amplitude_c * math.sin(phase_rad) + noise = self.rng.gauss(0.0, self.config.weather_noise_sigma_c) + value = self.config.weather_climatology_mean_c + seasonal + noise + return { + "date": signal_date, + "signal_name": WEATHER_SIGNAL_NAME, + "store_id": store_id, + "is_global": False, + "value": value, + } + + def _macro_rows( + self, dates: list[date] + ) -> list[dict[str, date | int | bool | str | float | None]]: + """Random-walk macro index, one row per date.""" + records: list[dict[str, date | int | bool | str | float | None]] = [] + value = self.config.macro_initial_value + for d in dates: + value += self.rng.gauss(0.0, self.config.macro_step_sigma) + records.append( + { + "date": d, + "signal_name": MACRO_SIGNAL_NAME, + "store_id": None, + "is_global": True, + "value": value, + } + ) + return records + + def _event_rows( + self, dates: list[date] + ) -> list[dict[str, date | int | bool | str | float | None]]: + """Binary event-flag rows for configured event dates within range.""" + if not self.config.event_dates: + return [] + date_set = set(dates) + return [ + { + "date": event_date, + "signal_name": EVENT_SIGNAL_NAME, + "store_id": None, + "is_global": True, + "value": 1.0, + } + for event_date in self.config.event_dates + if event_date in date_set + ] + + def generate( + self, dates: list[date], store_ids: list[int] + ) -> list[dict[str, date | int | bool | str | float 
| None]]: + """Generate exogenous signal rows. + + Args: + dates: Dates in the seeded range (sorted ascending). + store_ids: Store IDs for per-store signals. + + Returns: + List of row dicts ready for batch insert. Empty when no signal + is enabled. + """ + records: list[dict[str, date | int | bool | str | float | None]] = [] + + if self.config.enable_weather and store_ids and dates: + # Iterate stores in the outer loop so the rng draws per store + # are deterministic and reproducible. + for store_id in store_ids: + for d in dates: + records.append(self._weather_row(d, store_id, d.timetuple().tm_yday)) + + if self.config.enable_macro and dates: + records.extend(self._macro_rows(dates)) + + if self.config.enable_events: + records.extend(self._event_rows(dates)) + + return records diff --git a/app/shared/seeder/generators/facts.py b/app/shared/seeder/generators/facts.py index 78e0f5eb..61758d71 100644 --- a/app/shared/seeder/generators/facts.py +++ b/app/shared/seeder/generators/facts.py @@ -10,15 +10,24 @@ if TYPE_CHECKING: from app.shared.seeder.config import ( + ChangepointConfig, HolidayConfig, + MultiSeasonalityConfig, RetailPatternConfig, SparsityConfig, + SubstitutionConfig, TimeSeriesConfig, ) class SalesDailyGenerator: - """Generator for daily sales fact data with realistic time-series patterns.""" + """Generator for daily sales fact data with realistic time-series patterns. + + Phase 1 extensions (``multi_seasonality``, ``changepoints``, + ``substitution``, ``exogenous_weather``) are all opt-in. When every Phase + 1 input is None / disabled, the generator's output is byte-identical to + its pre-Phase-1 behavior. 
+ """ def __init__( self, @@ -27,6 +36,12 @@ def __init__( retail_config: RetailPatternConfig, sparsity_config: SparsityConfig, holidays: list[HolidayConfig], + multi_seasonality: MultiSeasonalityConfig | None = None, + changepoints: ChangepointConfig | None = None, + substitution: SubstitutionConfig | None = None, + exogenous_weather: dict[tuple[int, date], float] | None = None, + weather_temperature_sensitivity: float = 0.0, + weather_climatology_mean_c: float = 15.0, ) -> None: """Initialize the sales generator. @@ -36,12 +51,143 @@ def __init__( retail_config: Retail-specific pattern configuration. sparsity_config: Data sparsity configuration. holidays: List of holiday configurations with multipliers. + multi_seasonality: Optional yearly seasonality configuration. + When None or amplitude=0, no yearly multiplier is applied. + changepoints: Optional list of demand changepoints. When None or + empty, no changepoint multiplier is applied. + substitution: Optional substitution configuration. When None or + disabled, no substitution lift is applied. + exogenous_weather: Optional lookup ``{(store_id, date): temp_c}``. + Each entry shifts demand by + ``weather_temperature_sensitivity * (temp_c - climatology_mean_c)`` + fraction (i.e. linear, centered on the climatology mean). + When None, no weather effect. + weather_temperature_sensitivity: Demand delta per °C above the + climatology mean (used only when ``exogenous_weather`` is set). + weather_climatology_mean_c: Reference temperature for the linear + weather term. 
""" self.rng = rng self.ts_config = time_series_config self.retail_config = retail_config self.sparsity_config = sparsity_config self.holiday_map = {h.date: h.multiplier for h in holidays} + self.multi_seasonality = multi_seasonality + self.changepoints = changepoints + self.substitution = substitution + self.exogenous_weather = exogenous_weather + self.weather_sensitivity = weather_temperature_sensitivity + self.weather_climatology_mean_c = weather_climatology_mean_c + + # Pre-compute substitution group memberships for O(1) lookup. + self._substitution_groups_by_product: dict[int, list[list[int]]] = {} + if self.substitution is not None and self.substitution.enable: + for group in self.substitution.substitute_groups: + for product_id in group: + self._substitution_groups_by_product.setdefault(product_id, []).append(group) + + def _yearly_seasonality_multiplier(self, current_date: date) -> float: + """Return the yearly seasonality multiplier for ``current_date``. + + Returns 1.0 when multi-seasonality is unset or amplitude is 0 — that + preserves the pre-Phase-1 output byte-for-byte. + """ + if ( + self.multi_seasonality is None + or self.multi_seasonality.yearly_seasonality_amplitude == 0.0 + ): + return 1.0 + day_of_year = current_date.timetuple().tm_yday + offset = self.multi_seasonality.yearly_phase_offset_days + phase = 2.0 * math.pi * (day_of_year + offset) / 365.0 + return 1.0 + self.multi_seasonality.yearly_seasonality_amplitude * math.sin(phase) + + def _changepoint_multiplier(self, current_date: date) -> float: + """Aggregate multiplier from all changepoints active on ``current_date``. + + Each changepoint contributes ``(multiplier - 1) * exp(-Δ/decay)`` if + ``current_date >= changepoint.date`` and 0 otherwise. The total + multiplier is ``1 + Σ contributions``. + + Returns 1.0 when there are no changepoints — preserving byte-identical + output for callers that don't opt in. 
+ """ + if self.changepoints is None or not self.changepoints.changepoints: + return 1.0 + contribution = 0.0 + for cp in self.changepoints.changepoints: + delta_days = (current_date - cp.date).days + if delta_days < 0: + continue + if cp.decay_days <= 0: + # Pure impulse on the changepoint date only. + if delta_days == 0: + contribution += cp.demand_multiplier - 1.0 + continue + decay = math.exp(-delta_days / cp.decay_days) + contribution += (cp.demand_multiplier - 1.0) * decay + return 1.0 + contribution + + def _weather_multiplier(self, current_date: date, store_id: int) -> float: + """Linear weather effect centered on the climatology mean. + + Returns 1.0 when no weather data is configured. + """ + if self.exogenous_weather is None or self.weather_sensitivity == 0.0: + return 1.0 + temp_c = self.exogenous_weather.get((store_id, current_date)) + if temp_c is None: + return 1.0 + return 1.0 + self.weather_sensitivity * (temp_c - self.weather_climatology_mean_c) + + def _substitution_multiplier( + self, + product_id: int, + stockouts_today: set[int], + ) -> float: + """Lift demand for ``product_id`` when stocked-out group-mates exist. + + ``stockouts_today`` is the set of product IDs stocked out on the + current date at the same store. For each substitution group the + product belongs to, we count how many other members are stocked out + and distribute ``substitution_lift_on_stockout`` across the surviving + in-stock members. + + Returns 1.0 when substitution is disabled or no group-mate is out. + """ + if ( + self.substitution is None + or not self.substitution.enable + or self.substitution.substitution_lift_on_stockout == 0.0 + ): + return 1.0 + groups = self._substitution_groups_by_product.get(product_id) + if not groups: + return 1.0 + if product_id in stockouts_today: + return 1.0 # A stocked-out product can't pick up lift. 
+ + contribution = 0.0 + for group in groups: + out_members = sum( + 1 for member in group if member != product_id and member in stockouts_today + ) + survivors = sum( + 1 for member in group if member != product_id and member not in stockouts_today + ) + if out_members == 0 or survivors == 0: + # Either no group-mate is out, or this product is the only + # in-stock member (everyone else is out — in that case give all + # the lift to this product). + if out_members > 0 and survivors == 0: + contribution += self.substitution.substitution_lift_on_stockout * out_members + continue + # Each out member's lift is split among (survivors + 1) including + # this product, so this product captures one share per out member. + contribution += ( + self.substitution.substitution_lift_on_stockout * out_members / (survivors + 1) + ) + return 1.0 + contribution def _compute_demand( self, @@ -52,6 +198,9 @@ def _compute_demand( is_promotion: bool, is_stockout: bool, product_launch_date: date | None, + store_id: int | None = None, + product_id: int | None = None, + stockouts_today_for_store: set[int] | None = None, ) -> int: """Compute demand for a single observation. @@ -63,6 +212,12 @@ def _compute_demand( is_promotion: Whether there's an active promotion. is_stockout: Whether there's a stockout. product_launch_date: Optional launch date for new product ramp. + store_id: Store ID (used only by Phase 1 weather + substitution + effects). Required when those features are enabled. + product_id: Product ID (used only by Phase 1 substitution). + Required when substitution is enabled. + stockouts_today_for_store: Set of product IDs stocked out at + ``store_id`` on ``current_date``. Used only by substitution. Returns: Computed demand quantity (non-negative integer). @@ -111,6 +266,14 @@ def _compute_demand( demand *= ramp_factor # If ramp_days == 0, skip ramp calculation (demand unchanged) + # Phase 1 multipliers (each returns 1.0 when its feature is off).
+ demand *= self._yearly_seasonality_multiplier(current_date) + demand *= self._changepoint_multiplier(current_date) + if store_id is not None: + demand *= self._weather_multiplier(current_date, store_id) + if product_id is not None and stockouts_today_for_store is not None: + demand *= self._substitution_multiplier(product_id, stockouts_today_for_store) + # Apply noise if self.ts_config.noise_sigma > 0: noise = self.rng.gauss(0, self.ts_config.noise_sigma) @@ -182,6 +345,15 @@ def generate( gaps.add(dates[gap_start_idx + i]) gap_dates[key] = gaps + # Phase 1: per-(store, date) lookup of stocked-out product IDs for + # substitution. Only build it when substitution is enabled — keeps + # the disabled-path byte-identical with pre-Phase-1. + stockouts_by_store_date: dict[tuple[int, date], set[int]] = {} + if self.substitution is not None and self.substitution.enable: + for (s_id, p_id), out_dates in stockouts.items(): + for d in out_dates: + stockouts_by_store_date.setdefault((s_id, d), set()).add(p_id) + # Generate sales for each active combination and date for store_id in store_ids: for product_id, base_price in product_data: @@ -203,6 +375,8 @@ def generate( is_promotion = current_date in promo_dates is_stockout = current_date in stockout_dates + stockouts_today = stockouts_by_store_date.get((store_id, current_date)) + quantity = self._compute_demand( current_date=current_date, base_date=base_date, @@ -211,6 +385,9 @@ def generate( is_promotion=is_promotion, is_stockout=is_stockout, product_launch_date=None, # Could be extended + store_id=store_id, + product_id=product_id, + stockouts_today_for_store=stockouts_today, ) # Skip zero sales from stockouts to reduce data volume diff --git a/app/shared/seeder/generators/returns.py b/app/shared/seeder/generators/returns.py new file mode 100644 index 00000000..96432d3a --- /dev/null +++ b/app/shared/seeder/generators/returns.py @@ -0,0 +1,121 @@ +"""Returns generator: synthetic ``sales_returns`` rows. 
class ReturnsGenerator:
    """Generate synthetic ``sales_returns`` rows from generated sales rows.

    Every random draw goes through the injected ``random.Random`` instance,
    so the same seed and the same input always produce the same rows.
    """

    def __init__(self, rng: random.Random, config: ReturnsConfig) -> None:
        """Initialize the generator.

        Args:
            rng: Seeded random number generator.
            config: Returns configuration.
        """
        self.rng = rng
        self.config = config

    def _pick_reason(self) -> str:
        """Sample a return reason from the configured distribution.

        Returns:
            Reason string. Falls back to ``"unspecified"`` when the
            distribution is empty or its weights do not sum to a positive
            number. ``random.choices`` raises ``ValueError`` for such
            weights, which would otherwise abort the whole generation run.
            The fallback path consumes no random draw.
        """
        distribution = self.config.return_reason_distribution
        if not distribution or sum(distribution.values()) <= 0:
            return "unspecified"
        # random.choices is deterministic under self.rng.
        return self.rng.choices(
            list(distribution),
            weights=list(distribution.values()),
            k=1,
        )[0]

    def generate(
        self,
        sales_records: list[dict[str, date | int | Decimal]],
        end_date: date,
    ) -> list[dict[str, date | int | str]]:
        """Generate return rows from a list of sales rows.

        Args:
            sales_records: Sales dicts. Each must contain ``date``,
                ``store_id``, ``product_id`` and ``quantity``; rows whose
                values do not have the expected runtime types are skipped.
            end_date: Calendar end date. Returns lagged beyond ``end_date``
                are clamped to ``end_date`` so the row never points past
                the last available calendar date.

        Returns:
            List of return-row dicts with the keys ``date``, ``store_id``,
            ``product_id``, ``return_quantity`` and ``return_reason``.
            Empty when the returns feature is disabled or no sales qualify.
        """
        if not self.config.enable or not sales_records:
            return []

        lag_min = self.config.return_lag_days_min
        # Guard against an inverted range, which would make randint raise.
        lag_max = max(self.config.return_lag_days_max, lag_min)

        returns: list[dict[str, date | int | str]] = []
        for sale in sales_records:
            quantity = sale["quantity"]
            sale_date = sale["date"]
            store_id = sale["store_id"]
            product_id = sale["product_id"]
            # The dict annotation is wider than what sales rows carry at
            # runtime; narrowing here satisfies the type checker without a
            # cast and silently drops malformed rows.
            if not (
                isinstance(quantity, int)
                and isinstance(sale_date, date)
                and isinstance(store_id, int)
                and isinstance(product_id, int)
            ):
                continue
            if quantity <= 0:
                continue
            # Draw order (probability, lag, reason) is part of the seeded
            # output; do not reorder these calls.
            if self.rng.random() >= self.config.return_probability:
                continue

            lag = self.rng.randint(lag_min, lag_max)
            return_date = min(sale_date + timedelta(days=lag), end_date)

            # Fraction of the original quantity, at least 1 unit and never
            # more than what was sold.
            raw_qty = quantity * self.config.return_quantity_fraction
            return_qty = min(max(1, round(raw_qty)), quantity)

            returns.append(
                {
                    "date": return_date,
                    "store_id": store_id,
                    "product_id": product_id,
                    "return_quantity": return_qty,
                    "return_reason": self._pick_reason(),
                }
            )

        return returns
(Phase 1) exogenous_signal is_global/store_id consistency mock_result1 = MagicMock() mock_result1.scalar.return_value = 0 # no orphans mock_result2 = MagicMock() @@ -247,8 +248,19 @@ async def test_returns_empty_list_when_valid(self, seeder): mock_result3.fetchone.return_value = (date(2024, 1, 1), date(2024, 1, 31)) mock_result4 = MagicMock() mock_result4.scalar.return_value = 31 # 31 days matches Jan 1-31 - - mock_db.execute.side_effect = [mock_result1, mock_result2, mock_result3, mock_result4] + mock_result5 = MagicMock() + mock_result5.scalar.return_value = 0 # no bad returns + mock_result6 = MagicMock() + mock_result6.scalar.return_value = 0 # no inconsistent exogenous rows + + mock_db.execute.side_effect = [ + mock_result1, + mock_result2, + mock_result3, + mock_result4, + mock_result5, + mock_result6, + ] errors = await seeder.verify_data_integrity(mock_db) @@ -258,7 +270,6 @@ async def test_returns_empty_list_when_valid(self, seeder): async def test_detects_orphaned_sales(self, seeder): """Test orphaned sales are detected.""" mock_db = AsyncMock() - # Create separate mock results for each execute call mock_result1 = MagicMock() mock_result1.scalar.return_value = 5 # orphan check returns 5 errors mock_result2 = MagicMock() @@ -267,8 +278,19 @@ async def test_detects_orphaned_sales(self, seeder): mock_result3.fetchone.return_value = (date(2024, 1, 1), date(2024, 1, 31)) mock_result4 = MagicMock() mock_result4.scalar.return_value = 31 # calendar count - - mock_db.execute.side_effect = [mock_result1, mock_result2, mock_result3, mock_result4] + mock_result5 = MagicMock() + mock_result5.scalar.return_value = 0 # bad returns + mock_result6 = MagicMock() + mock_result6.scalar.return_value = 0 # inconsistent exogenous + + mock_db.execute.side_effect = [ + mock_result1, + mock_result2, + mock_result3, + mock_result4, + mock_result5, + mock_result6, + ] errors = await seeder.verify_data_integrity(mock_db) @@ -278,7 +300,6 @@ async def test_detects_orphaned_sales(self, 
seeder): async def test_detects_negative_quantities(self, seeder): """Test negative quantities are detected.""" mock_db = AsyncMock() - # Create separate mock results for each execute call mock_result1 = MagicMock() mock_result1.scalar.return_value = 0 # orphan check mock_result2 = MagicMock() @@ -287,8 +308,19 @@ async def test_detects_negative_quantities(self, seeder): mock_result3.fetchone.return_value = (date(2024, 1, 1), date(2024, 1, 31)) mock_result4 = MagicMock() mock_result4.scalar.return_value = 31 # calendar count - - mock_db.execute.side_effect = [mock_result1, mock_result2, mock_result3, mock_result4] + mock_result5 = MagicMock() + mock_result5.scalar.return_value = 0 # bad returns + mock_result6 = MagicMock() + mock_result6.scalar.return_value = 0 # inconsistent exogenous + + mock_db.execute.side_effect = [ + mock_result1, + mock_result2, + mock_result3, + mock_result4, + mock_result5, + mock_result6, + ] errors = await seeder.verify_data_integrity(mock_db) diff --git a/app/shared/seeder/tests/test_exogenous.py b/app/shared/seeder/tests/test_exogenous.py new file mode 100644 index 00000000..4cb2831a --- /dev/null +++ b/app/shared/seeder/tests/test_exogenous.py @@ -0,0 +1,125 @@ +"""Tests for ExogenousSignalGenerator (Phase 1).""" + +# mypy: disable-error-code="union-attr,arg-type,operator,return-value" +# Generator dicts have a wide union; tests narrow at access time. 
+ +import math +import random +from datetime import date, timedelta + +from app.shared.seeder.config import ExogenousSignalConfig +from app.shared.seeder.generators.exogenous import ( + EVENT_SIGNAL_NAME, + MACRO_SIGNAL_NAME, + WEATHER_SIGNAL_NAME, + ExogenousSignalGenerator, +) + + +def _date_range(start: date, days: int) -> list[date]: + return [start + timedelta(days=i) for i in range(days)] + + +class TestExogenousSignalGeneratorDisabled: + def test_all_disabled_produces_no_rows(self): + gen = ExogenousSignalGenerator(random.Random(42), ExogenousSignalConfig()) + rows = gen.generate(_date_range(date(2024, 1, 1), 5), [1, 2]) + assert rows == [] + + +class TestWeather: + def test_weather_emits_row_per_store_and_date(self): + cfg = ExogenousSignalConfig(enable_weather=True, weather_noise_sigma_c=0.0) + gen = ExogenousSignalGenerator(random.Random(42), cfg) + store_ids = [1, 2, 3] + dates = _date_range(date(2024, 1, 1), 7) + rows = gen.generate(dates, store_ids) + weather_rows = [r for r in rows if r["signal_name"] == WEATHER_SIGNAL_NAME] + assert len(weather_rows) == len(store_ids) * len(dates) + # Sanity: each row is per-store (is_global=False), store_id non-null. + for r in weather_rows: + assert r["is_global"] is False + assert r["store_id"] in store_ids + assert isinstance(r["value"], float) + + def test_weather_seasonal_peak_in_summer(self): + # With zero noise the value should follow the deterministic sin wave. + cfg = ExogenousSignalConfig( + enable_weather=True, + weather_amplitude_c=10.0, + weather_climatology_mean_c=15.0, + weather_noise_sigma_c=0.0, + ) + gen = ExogenousSignalGenerator(random.Random(0), cfg) + # July 14 = doy 196 → peak + # January 14 = doy 14 → near trough + rows = gen.generate([date(2024, 7, 14), date(2024, 1, 14)], [1]) + by_date = {r["date"]: r["value"] for r in rows} + # Peak is roughly mean + amplitude; trough roughly mean - amplitude. 
+ assert by_date[date(2024, 7, 14)] > by_date[date(2024, 1, 14)] + assert abs(by_date[date(2024, 7, 14)] - 25.0) < 0.5 + assert by_date[date(2024, 1, 14)] < 10.0 + + def test_weather_reproducible(self): + cfg = ExogenousSignalConfig(enable_weather=True) + gen1 = ExogenousSignalGenerator(random.Random(7), cfg) + gen2 = ExogenousSignalGenerator(random.Random(7), cfg) + dates = _date_range(date(2024, 1, 1), 30) + assert gen1.generate(dates, [1, 2]) == gen2.generate(dates, [1, 2]) + + +class TestMacroIndex: + def test_macro_row_per_date(self): + cfg = ExogenousSignalConfig(enable_macro=True) + gen = ExogenousSignalGenerator(random.Random(42), cfg) + dates = _date_range(date(2024, 6, 1), 10) + rows = gen.generate(dates, []) + macro = [r for r in rows if r["signal_name"] == MACRO_SIGNAL_NAME] + assert len(macro) == len(dates) + for r in macro: + assert r["is_global"] is True + assert r["store_id"] is None + + def test_macro_random_walk_changes_value(self): + cfg = ExogenousSignalConfig( + enable_macro=True, macro_initial_value=100.0, macro_step_sigma=1.0 + ) + gen = ExogenousSignalGenerator(random.Random(1), cfg) + dates = _date_range(date(2024, 1, 1), 30) + rows = [r for r in gen.generate(dates, []) if r["signal_name"] == MACRO_SIGNAL_NAME] + values = [r["value"] for r in rows] + # The first value already has one rng step applied so it's not + # exactly 100; just confirm the walk produces variation. 
+ assert len({round(v, 6) for v in values}) > 1 + assert abs(values[0] - 100.0) < 5.0 # one small step + + def test_zero_step_sigma_yields_constant(self): + cfg = ExogenousSignalConfig( + enable_macro=True, macro_initial_value=42.0, macro_step_sigma=0.0 + ) + gen = ExogenousSignalGenerator(random.Random(99), cfg) + rows = gen.generate(_date_range(date(2024, 1, 1), 5), []) + macro_values = [r["value"] for r in rows if r["signal_name"] == MACRO_SIGNAL_NAME] + assert all(math.isclose(v, 42.0) for v in macro_values) + + +class TestEvents: + def test_events_only_within_range(self): + cfg = ExogenousSignalConfig( + enable_events=True, + event_dates=[date(2024, 1, 3), date(2025, 6, 1)], + ) + gen = ExogenousSignalGenerator(random.Random(0), cfg) + rows = gen.generate(_date_range(date(2024, 1, 1), 31), []) + events = [r for r in rows if r["signal_name"] == EVENT_SIGNAL_NAME] + # 2024-01-03 is in range; 2025-06-01 is not. + assert len(events) == 1 + assert events[0]["date"] == date(2024, 1, 3) + assert events[0]["value"] == 1.0 + assert events[0]["is_global"] is True + + def test_events_disabled_emits_nothing(self): + cfg = ExogenousSignalConfig(enable_events=False, event_dates=[date(2024, 1, 3)]) + gen = ExogenousSignalGenerator(random.Random(0), cfg) + rows = gen.generate(_date_range(date(2024, 1, 1), 31), []) + assert all(r["signal_name"] != EVENT_SIGNAL_NAME for r in rows) diff --git a/app/shared/seeder/tests/test_integration.py b/app/shared/seeder/tests/test_integration.py index facb43f5..784ba246 100644 --- a/app/shared/seeder/tests/test_integration.py +++ b/app/shared/seeder/tests/test_integration.py @@ -20,11 +20,13 @@ from app.core.config import get_settings from app.features.data_platform.models import ( Calendar, + ExogenousSignal, InventorySnapshotDaily, PriceHistory, Product, Promotion, SalesDaily, + SalesReturn, Store, ) from app.shared.seeder import DataSeeder, SeederConfig @@ -76,7 +78,10 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]: # 
Pre-test cleanup for proper isolation async with session_maker() as cleanup_session: try: - # Delete in FK order (facts before dimensions) + # Delete in FK order (facts before dimensions). Phase 1 tables + # come first because they FK to store/product/calendar. + await cleanup_session.execute(delete(SalesReturn)) + await cleanup_session.execute(delete(ExogenousSignal)) await cleanup_session.execute(delete(SalesDaily)) await cleanup_session.execute(delete(InventorySnapshotDaily)) await cleanup_session.execute(delete(PriceHistory)) @@ -102,7 +107,10 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]: # Post-test cleanup async with session_maker() as cleanup_session: try: - # Delete in FK order (facts before dimensions) + # Delete in FK order (facts before dimensions). Phase 1 tables + # come first because they FK to store/product/calendar. + await cleanup_session.execute(delete(SalesReturn)) + await cleanup_session.execute(delete(ExogenousSignal)) await cleanup_session.execute(delete(SalesDaily)) await cleanup_session.execute(delete(InventorySnapshotDaily)) await cleanup_session.execute(delete(PriceHistory)) diff --git a/app/shared/seeder/tests/test_phase1_config.py b/app/shared/seeder/tests/test_phase1_config.py new file mode 100644 index 00000000..b82123b2 --- /dev/null +++ b/app/shared/seeder/tests/test_phase1_config.py @@ -0,0 +1,103 @@ +"""Tests for Phase 1 seeder configuration dataclasses. + +Covers ExogenousSignalConfig, MultiSeasonalityConfig, ChangepointEvent / +ChangepointConfig, ReturnsConfig, SubstitutionConfig — and confirms the +SeederConfig defaults wire them in with disabled / empty defaults. 
class TestExogenousSignalConfig:
    """Defaults of ``ExogenousSignalConfig``."""

    def test_defaults_disabled(self):
        """A bare config enables no signal and carries no event dates."""
        defaults = ExogenousSignalConfig()
        for toggle in ("enable_weather", "enable_macro", "enable_events"):
            assert getattr(defaults, toggle) is False
        assert defaults.weather_temperature_sensitivity == 0.0
        assert defaults.event_dates == []

    def test_event_dates_is_independent(self):
        """Each instance owns its own ``event_dates`` list."""
        mutated = ExogenousSignalConfig()
        untouched = ExogenousSignalConfig()
        mutated.event_dates.append(date(2024, 1, 1))
        # A shared default list would leak the appended date into this one.
        assert untouched.event_dates == []


class TestMultiSeasonalityConfig:
    """Defaults of ``MultiSeasonalityConfig``."""

    def test_defaults_zero(self):
        """Amplitude and phase offset both default to zero."""
        defaults = MultiSeasonalityConfig()
        assert defaults.yearly_seasonality_amplitude == 0.0
        assert defaults.yearly_phase_offset_days == 0


class TestChangepointConfig:
    """``ChangepointConfig`` defaults and ``ChangepointEvent`` fields."""

    def test_default_empty(self):
        """No changepoints are configured by default."""
        defaults = ChangepointConfig()
        assert defaults.changepoints == []

    def test_event_fields(self):
        """Constructor arguments are exposed unchanged as attributes."""
        when = date(2024, 3, 15)
        event = ChangepointEvent(date=when, demand_multiplier=2.5, decay_days=60)
        assert event.date == date(2024, 3, 15)
        assert event.demand_multiplier == 2.5
        assert event.decay_days == 60


class TestReturnsConfig:
    """Defaults of ``ReturnsConfig``."""

    def test_defaults_disabled(self):
        """Returns are off by default and the default parameters are sane."""
        defaults = ReturnsConfig()
        assert defaults.enable is False
        assert 0.0 <= defaults.return_probability <= 1.0
        assert defaults.return_lag_days_min <= defaults.return_lag_days_max
        # The default reason weights must carry positive total mass so a
        # real reason can always be sampled.
        total_weight = sum(defaults.return_reason_distribution.values())
        assert total_weight > 0


class TestSubstitutionConfig:
    """Defaults of ``SubstitutionConfig``."""

    def test_defaults_disabled(self):
        """Substitution is off, with no groups and zero lift."""
        defaults = SubstitutionConfig()
        assert defaults.enable is False
        assert defaults.substitute_groups == []
        assert defaults.substitution_lift_on_stockout == 0.0


class TestSeederConfigPhase1Wiring:
    """``SeederConfig`` exposes every Phase 1 sub-config, switched off."""

    def test_phase1_defaults_present_and_disabled(self):
        """Each Phase 1 sub-config exists on a default ``SeederConfig``."""
        seeder_config = SeederConfig()
        expected_types = {
            "exogenous": ExogenousSignalConfig,
            "multi_seasonality": MultiSeasonalityConfig,
            "changepoints": ChangepointConfig,
            "returns": ReturnsConfig,
            "substitution": SubstitutionConfig,
        }
        for attribute, expected_type in expected_types.items():
            assert isinstance(getattr(seeder_config, attribute), expected_type)
        # Disabled defaults keep existing scenarios unchanged unless a
        # caller opts in.
        assert seeder_config.exogenous.enable_weather is False
        assert seeder_config.multi_seasonality.yearly_seasonality_amplitude == 0.0
        assert seeder_config.changepoints.changepoints == []
        assert seeder_config.returns.enable is False
        assert seeder_config.substitution.enable is False

    def test_from_scenario_does_not_enable_phase1(self):
        """No built-in scenario preset turns on a Phase 1 feature."""
        # Regression invariant protecting pre-Phase-1 outputs.
        for scenario in ScenarioPreset:
            preset = SeederConfig.from_scenario(scenario)
            exogenous = preset.exogenous
            assert exogenous.enable_weather is False, f"{scenario} unexpectedly enables weather"
            assert exogenous.enable_macro is False
            assert exogenous.enable_events is False
            assert preset.multi_seasonality.yearly_seasonality_amplitude == 0.0
            assert preset.changepoints.changepoints == []
            assert preset.returns.enable is False
            assert preset.substitution.enable is False
pytestmark = pytest.mark.integration


def _check_destructive_test_guard() -> None:
    """Refuse to continue unless destructive DB operations were opted into.

    The opt-in is any one of: ``settings.testing`` being truthy,
    ``APP_ENV=testing``, or ``ALLOW_DESTRUCTIVE_TEST_DB=true``
    (environment values are compared case-insensitively).

    Raises:
        RuntimeError: If none of the opt-in signals is present.
    """
    settings = get_settings()
    is_testing = getattr(settings, "testing", False)
    app_env_testing = os.environ.get("APP_ENV", "").lower() == "testing"
    allow_destructive = os.environ.get("ALLOW_DESTRUCTIVE_TEST_DB", "").lower() == "true"
    if not is_testing and not app_env_testing and not allow_destructive:
        raise RuntimeError(
            "Destructive test operations require explicit opt-in. "
            "Set ALLOW_DESTRUCTIVE_TEST_DB=true, APP_ENV=testing, or settings.testing=True"
        )


async def _purge_tables(session: AsyncSession) -> None:
    """Delete all rows from every seeded table and commit.

    Tables are emptied in the same order the suite has always used: the
    Phase 1 tables first, then the remaining fact tables, then calendar,
    product and store. Errors propagate to the caller.
    """
    purge_order = (
        SalesReturn,
        ExogenousSignal,
        SalesDaily,
        InventorySnapshotDaily,
        PriceHistory,
        Promotion,
        Calendar,
        Product,
        Store,
    )
    for model in purge_order:
        await session.execute(delete(model))
    await session.commit()


async def _purge_tables_best_effort(
    session_maker: async_sessionmaker[AsyncSession],
) -> None:
    """Purge all seeded tables in a throwaway session, ignoring failures.

    Cleanup around a test is deliberately best-effort: a failed purge is
    rolled back instead of failing the test that triggered it.
    """
    async with session_maker() as cleanup_session:
        try:
            await _purge_tables(cleanup_session)
        except Exception:
            await cleanup_session.rollback()


@pytest_asyncio.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session on a database emptied before and after the test.

    The destructive-operation guard is checked before each purge. The
    engine is disposed even when the guard or a cleanup step raises.

    Yields:
        An ``AsyncSession`` created with ``expire_on_commit=False``.
    """
    _check_destructive_test_guard()
    settings = get_settings()
    engine = create_async_engine(settings.database_url, echo=False)
    session_maker = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    try:
        # Pre-test cleanup for isolation from earlier runs.
        await _purge_tables_best_effort(session_maker)

        async with session_maker() as session:
            try:
                yield session
            finally:
                # Discard whatever the test left pending; never mask the
                # test's own outcome with a rollback error.
                with suppress(Exception):
                    await session.rollback()

        _check_destructive_test_guard()

        # Post-test cleanup.
        await _purge_tables_best_effort(session_maker)
    finally:
        await engine.dispose()


class TestPhase1Disabled:
    """Seeding with every Phase 1 option left at its default."""

    @pytest.mark.asyncio
    async def test_default_run_creates_no_phase1_rows(self, db_session: AsyncSession) -> None:
        """With Phase 1 fully off, exogenous_signal and sales_returns stay empty."""
        config = SeederConfig(
            seed=42,
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 7),
            dimensions=DimensionConfig(stores=2, products=3),
        )
        result = await DataSeeder(config).generate_full(db_session)
        assert result.exogenous_count == 0
        assert result.returns_count == 0

        # Confirm against the tables themselves, not only the reported counts.
        exo_count = (
            await db_session.execute(select(func.count()).select_from(ExogenousSignal))
        ).scalar() or 0
        ret_count = (
            await db_session.execute(select(func.count()).select_from(SalesReturn))
        ).scalar() or 0
        assert exo_count == 0
        assert ret_count == 0


class TestPhase1Enabled:
    """Seeding with individual Phase 1 features switched on."""

    @pytest.mark.asyncio
    async def test_exogenous_weather_and_macro_persisted(self, db_session: AsyncSession) -> None:
        """Weather rows (per store, per date) and macro rows (per date) are stored."""
        config = SeederConfig(
            seed=42,
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 7),  # 7 days
            dimensions=DimensionConfig(stores=2, products=2),
            exogenous=ExogenousSignalConfig(
                enable_weather=True,
                enable_macro=True,
            ),
        )
        result = await DataSeeder(config).generate_full(db_session)
        # 2 stores x 7 dates weather + 7 dates macro = 21 rows.
        assert result.exogenous_count == 21

        weather_rows = (
            await db_session.execute(
                select(func.count())
                .select_from(ExogenousSignal)
                .where(ExogenousSignal.signal_name == "weather_temp_c")
            )
        ).scalar() or 0
        macro_rows = (
            await db_session.execute(
                select(func.count())
                .select_from(ExogenousSignal)
                .where(ExogenousSignal.signal_name == "macro_index")
            )
        ).scalar() or 0
        assert weather_rows == 14
        assert macro_rows == 7

    @pytest.mark.asyncio
    async def test_returns_table_populated(self, db_session: AsyncSession) -> None:
        """Enabling returns stores at least one row, none below one unit."""
        config = SeederConfig(
            seed=42,
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 31),
            dimensions=DimensionConfig(stores=2, products=3),
            returns=ReturnsConfig(enable=True, return_probability=0.2),
        )
        result = await DataSeeder(config).generate_full(db_session)
        assert result.returns_count > 0
        # Quantity invariant: every stored return is for at least one unit.
        bad = (
            await db_session.execute(
                select(func.count()).select_from(SalesReturn).where(SalesReturn.return_quantity < 1)
            )
        ).scalar() or 0
        assert bad == 0

    @pytest.mark.asyncio
    async def test_changepoint_lifts_demand_at_date(self, db_session: AsyncSession) -> None:
        """A 5x changepoint on day 0 with no decay should produce strictly
        higher total demand than the baseline run."""
        # Baseline (no changepoint).
        base_config = SeederConfig(
            seed=42,
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 14),
            dimensions=DimensionConfig(stores=2, products=2),
        )
        await DataSeeder(base_config).generate_full(db_session)
        baseline_total = (
            await db_session.execute(
                select(func.sum(SalesDaily.quantity)).where(SalesDaily.date == date(2024, 1, 1))
            )
        ).scalar() or 0

        # Reset and re-run with a changepoint. Unlike the fixture's cleanup,
        # a failure here must surface, so the strict purge is used.
        await _purge_tables(db_session)

        cp_config = SeederConfig(
            seed=42,
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 14),
            dimensions=DimensionConfig(stores=2, products=2),
            changepoints=ChangepointConfig(
                changepoints=[
                    ChangepointEvent(
                        date=date(2024, 1, 1),
                        demand_multiplier=5.0,
                        decay_days=0,
                    )
                ]
            ),
        )
        await DataSeeder(cp_config).generate_full(db_session)
        cp_total = (
            await db_session.execute(
                select(func.sum(SalesDaily.quantity)).where(SalesDaily.date == date(2024, 1, 1))
            )
        ).scalar() or 0
        # Conservative floor: the configured lift is 5x, so requiring only
        # more than 2x leaves headroom for other demand effects.
        assert cp_total > baseline_total * 2

    @pytest.mark.asyncio
    async def test_verify_integrity_clean_with_phase1(self, db_session: AsyncSession) -> None:
        """Integrity verification reports no errors after a Phase 1 seed."""
        config = SeederConfig(
            seed=42,
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 7),
            dimensions=DimensionConfig(stores=2, products=2),
            exogenous=ExogenousSignalConfig(enable_weather=True),
            returns=ReturnsConfig(enable=True, return_probability=0.5),
            multi_seasonality=MultiSeasonalityConfig(yearly_seasonality_amplitude=0.1),
        )
        seeder = DataSeeder(config)
        await seeder.generate_full(db_session)
        errors = await seeder.verify_data_integrity(db_session)
        assert errors == []


class TestQueryExogenousService:
    """``service.query_exogenous`` against seeded data."""

    @pytest.mark.asyncio
    async def test_query_returns_persisted_weather(self, db_session: AsyncSession) -> None:
        """A date-window query returns the weather rows of every store."""
        config = SeederConfig(
            seed=42,
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 7),
            dimensions=DimensionConfig(stores=2, products=2),
            exogenous=ExogenousSignalConfig(enable_weather=True),
        )
        await DataSeeder(config).generate_full(db_session)

        # No explicit commit here: per the original author's note,
        # DataSeeder.generate_full already commits, and the fixture's
        # expire_on_commit=False keeps loaded objects usable afterwards.

        resp = await service.query_exogenous(
            db_session,
            signal_name="weather_temp_c",
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 2),
            store_id=None,
        )
        assert isinstance(resp, schemas.ExogenousSignalResponse)
        # 2 stores x 2 dates = 4 weather rows in this window.
        assert resp.total == 4
        for r in resp.records:
            assert r.signal_name == "weather_temp_c"
            assert r.is_global is False

    @pytest.mark.asyncio
    async def test_query_filter_by_store(self, db_session: AsyncSession) -> None:
        """Passing ``store_id`` restricts the result to that store's rows."""
        config = SeederConfig(
            seed=42,
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 3),
            dimensions=DimensionConfig(stores=3, products=2),
            exogenous=ExogenousSignalConfig(enable_weather=True, enable_macro=True),
        )
        await DataSeeder(config).generate_full(db_session)

        # Use whichever store id the database returns first.
        store_id_row = (await db_session.execute(select(Store.id).limit(1))).scalar()
        assert store_id_row is not None

        resp = await service.query_exogenous(
            db_session,
            signal_name="weather_temp_c",
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 3),
            store_id=store_id_row,
        )
        # Only the rows for this store_id over 3 dates.
        assert resp.total == 3
        for r in resp.records:
            assert r.store_id == store_id_row

    @pytest.mark.asyncio
    async def test_query_empty_signal_returns_no_rows(self, db_session: AsyncSession) -> None:
        """Querying a signal that was never emitted yields an empty response."""
        # Seed with Phase 1 off so dimension and calendar rows exist, then
        # ask for a signal that was never generated: the service should
        # answer with an empty result, not an error.
        config = SeederConfig(
            seed=42,
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 3),
            dimensions=DimensionConfig(stores=1, products=1),
        )
        await DataSeeder(config).generate_full(db_session)

        resp = await service.query_exogenous(
            db_session,
            signal_name="weather_temp_c",
            start_date=date(2024, 1, 1),
            end_date=date(2024, 1, 3),
            store_id=None,
        )
        assert resp.total == 0
        assert resp.records == []


# Suppress unused-import warning for timedelta — kept for future use.
_ = timedelta
def _short_dates(n: int) -> list[date]:
    """Return ``n`` consecutive dates from 2024-01-01; short to keep tests fast."""
    first_day = date(2024, 1, 1)
    return [first_day + timedelta(days=offset) for offset in range(n)]


def _run_with_kwargs(config: SeederConfig, **extra_kwargs):
    """Generate sales for a fixed 2-store, 2-product, 30-day fixture.

    A fresh ``random.Random`` seeded from ``config.seed`` is used on every
    call, so two calls differ only through ``extra_kwargs``, which are
    forwarded unchanged to the ``SalesDailyGenerator`` constructor.
    """
    generator = SalesDailyGenerator(
        random.Random(config.seed),
        config.time_series,
        config.retail,
        config.sparsity,
        config.holidays,
        **extra_kwargs,
    )
    products = [(1, Decimal("9.99")), (2, Decimal("4.99"))]
    return generator.generate(
        store_ids=[1, 2],
        product_data=products,
        dates=_short_dates(30),
        promotions={},
        stockouts={},
    )


class TestRegressionWithoutKwargs:
    """Omitting the Phase 1 kwargs must equal passing their explicit
    defaults / None / empty configs."""

    @pytest.mark.parametrize("scenario", list(ScenarioPreset))
    def test_no_kwargs_matches_explicit_defaults(self, scenario: ScenarioPreset):
        """Explicit default Phase 1 arguments leave every scenario's output unchanged."""
        config = SeederConfig.from_scenario(scenario, seed=42)
        baseline = _run_with_kwargs(config)
        explicit_defaults = {
            "multi_seasonality": MultiSeasonalityConfig(),  # amplitude=0 default
            "changepoints": ChangepointConfig(),  # empty default
            "substitution": SubstitutionConfig(),  # disabled default
            "exogenous_weather": None,
            "weather_temperature_sensitivity": 0.0,
        }
        with_defaults = _run_with_kwargs(config, **explicit_defaults)
        assert baseline == with_defaults, (
            f"Phase 1 defaults must not alter output for scenario {scenario.value}"
        )

    def test_disabled_phase1_does_not_consume_rng(self):
        """Features that are enabled but have nothing to act on (no
        substitute groups, an empty weather lookup) must reproduce the
        baseline rows exactly.
        """
        config = SeederConfig.from_scenario(ScenarioPreset.RETAIL_STANDARD, seed=42)
        baseline = _run_with_kwargs(config)
        # Substitution is on but has no groups, so no product can match.
        inert_substitution = SubstitutionConfig(
            enable=True,
            substitute_groups=[],
            substitution_lift_on_stockout=0.5,
        )
        no_op = _run_with_kwargs(
            config,
            substitution=inert_substitution,
            exogenous_weather={},  # empty lookup
            weather_temperature_sensitivity=0.1,  # nonzero but no rows match
        )
        assert baseline == no_op
class TestYearlySeasonality:
    """Checks on ``SalesDailyGenerator._yearly_seasonality_multiplier``."""

    def test_amplitude_zero_no_effect(self):
        """A zero amplitude leaves the multiplier at exactly 1.0."""
        flat = MultiSeasonalityConfig(yearly_seasonality_amplitude=0.0)
        generator = SalesDailyGenerator(
            random.Random(0),
            _deterministic_ts_config(),
            _deterministic_retail_config(),
            _flat_sparsity(),
            [],
            multi_seasonality=flat,
        )
        mid_year = date(2024, 7, 1)
        assert generator._yearly_seasonality_multiplier(mid_year) == 1.0

    def test_amplitude_nonzero_introduces_swing(self):
        """A 0.2 amplitude lifts early April and depresses early October,
        with both multipliers staying inside [0.8, 1.2]."""
        generator = SalesDailyGenerator(
            random.Random(0),
            _deterministic_ts_config(),
            _deterministic_retail_config(),
            _flat_sparsity(),
            [],
            multi_seasonality=MultiSeasonalityConfig(yearly_seasonality_amplitude=0.2),
        )
        # NOTE(review): the original author describes the curve as
        # sin(2π · day_of_year / 365), peaking near day 91 (≈ April 1); the
        # implementation is not visible from this test, so only sign and
        # bounds are checked.
        spring = generator._yearly_seasonality_multiplier(date(2024, 4, 1))
        autumn = generator._yearly_seasonality_multiplier(date(2024, 10, 1))
        # Strictly above / below baseline, and never beyond ±amplitude
        # (a 1e-9 tolerance absorbs float rounding at the extremes).
        assert 1.0 < spring <= 1.2 + 1e-9
        assert 0.8 - 1e-9 <= autumn < 1.0
class TestSubstitution:
    """Checks on ``SalesDailyGenerator._substitution_multiplier``.

    In every call below the first argument is the product being priced and
    the second is the set of product ids that are stocked out.
    """

    @staticmethod
    def _generator_with(substitution_cfg):
        """Build a seed-0 generator from this module's deterministic helper
        configs, varying only the substitution config."""
        return SalesDailyGenerator(
            random.Random(0),
            _deterministic_ts_config(),
            _deterministic_retail_config(),
            _flat_sparsity(),
            [],
            substitution=substitution_cfg,
        )

    def test_disabled_returns_one(self):
        """With ``enable=False`` a stocked-out group-mate gives no lift."""
        generator = self._generator_with(
            SubstitutionConfig(
                enable=False,
                substitute_groups=[[1, 2]],
                substitution_lift_on_stockout=0.5,
            )
        )
        assert generator._substitution_multiplier(1, {2}) == 1.0

    def test_no_group_member_returns_one(self):
        """A product outside every substitute group never receives a lift."""
        generator = self._generator_with(
            SubstitutionConfig(
                enable=True,
                substitute_groups=[[1, 2]],
                substitution_lift_on_stockout=0.5,
            )
        )
        # Product 3 belongs to no group, so product 2 being out is irrelevant.
        assert generator._substitution_multiplier(3, {2}) == 1.0

    def test_lift_when_groupmate_out(self):
        """An in-stock member gains a share of the lift when a group-mate is out."""
        generator = self._generator_with(
            SubstitutionConfig(
                enable=True,
                substitute_groups=[[1, 2, 3]],
                substitution_lift_on_stockout=0.6,
            )
        )
        # Products 1 and 2 are in stock; product 3 is stocked out.
        # Per the original author's arithmetic: out_members=1, survivors=1
        # (product 2), so product 1's share is
        # 0.6 * 1 / (survivors + 1) = 0.3 → multiplier 1.3.
        multiplier = generator._substitution_multiplier(1, {3})
        assert math.isclose(multiplier, 1.3)

    def test_stocked_out_product_gets_no_lift(self):
        """A product that is itself stocked out keeps a multiplier of 1.0."""
        generator = self._generator_with(
            SubstitutionConfig(
                enable=True,
                substitute_groups=[[1, 2]],
                substitution_lift_on_stockout=0.5,
            )
        )
        assert generator._substitution_multiplier(1, {1}) == 1.0
class TestReturnsGeneratorEnabled:
    """Tests for ``ReturnsGenerator.generate`` with ``enable=True``.

    Tests that check a property of every generated row first assert that
    at least one row was produced. Without that guard a ``for`` loop or
    ``all(...)`` over an empty result passes without checking anything.
    """

    def test_returns_fire_at_configured_rate(self):
        """With probability 1.0 each of the 200 sales yields one return row."""
        # Probability 1.0 means every sale generates a return.
        cfg = ReturnsConfig(enable=True, return_probability=1.0)
        gen = ReturnsGenerator(random.Random(0), cfg)
        sales = _sales_records(200)
        returns = gen.generate(sales, date(2024, 1, 31))
        assert len(returns) == 200

    def test_probability_zero_no_returns(self):
        """With probability 0.0 no return rows are produced."""
        cfg = ReturnsConfig(enable=True, return_probability=0.0)
        gen = ReturnsGenerator(random.Random(0), cfg)
        assert gen.generate(_sales_records(50), date(2024, 1, 31)) == []

    def test_return_quantity_is_positive_and_capped(self):
        """Returned quantity stays within 1..sale quantity even for fraction > 1."""
        # quantity_fraction=2.0 should be clamped to original quantity.
        cfg = ReturnsConfig(enable=True, return_probability=1.0, return_quantity_fraction=2.0)
        gen = ReturnsGenerator(random.Random(0), cfg)
        sales = _sales_records(20)
        returns = gen.generate(sales, date(2024, 1, 31))
        # Guard against a vacuous pass: the loop asserts nothing on [].
        assert returns, "expected at least one return with return_probability=1.0"
        for r in returns:
            assert 1 <= r["return_quantity"] <= 10  # capped at sale quantity

    def test_return_date_clamped_to_end_date(self):
        """A fixed 30-day lag that overshoots ``end`` never yields a later date."""
        cfg = ReturnsConfig(
            enable=True,
            return_probability=1.0,
            return_lag_days_min=30,
            return_lag_days_max=30,
        )
        gen = ReturnsGenerator(random.Random(0), cfg)
        sales = _sales_records(5, start=date(2024, 1, 20))
        end = date(2024, 1, 31)
        returns = gen.generate(sales, end)
        # Guard against a vacuous pass: silently dropping late returns would
        # otherwise satisfy the loop below without exercising any clamping.
        assert returns, "expected at least one return with return_probability=1.0"
        for r in returns:
            assert r["date"] <= end

    def test_reasons_drawn_from_distribution(self):
        """A single-reason distribution labels every return with that reason."""
        cfg = ReturnsConfig(
            enable=True,
            return_probability=1.0,
            return_reason_distribution={"defective": 1.0},
        )
        gen = ReturnsGenerator(random.Random(0), cfg)
        sales = _sales_records(10)
        returns = gen.generate(sales, date(2024, 1, 31))
        # all() over an empty list is True, so require rows first.
        assert returns, "expected at least one return with return_probability=1.0"
        assert all(r["return_reason"] == "defective" for r in returns)

    def test_reproducible(self):
        """Two generators seeded identically produce identical output."""
        cfg = ReturnsConfig(enable=True, return_probability=0.5)
        sales = _sales_records(100)
        a = ReturnsGenerator(random.Random(7), cfg).generate(sales, date(2024, 12, 31))
        b = ReturnsGenerator(random.Random(7), cfg).generate(sales, date(2024, 12, 31))
        assert a == b

    def test_zero_quantity_sales_skipped(self):
        """A sale with quantity 0 produces no return even at probability 1.0."""
        cfg = ReturnsConfig(enable=True, return_probability=1.0)
        gen = ReturnsGenerator(random.Random(0), cfg)
        sales = [
            {
                "date": date(2024, 1, 1),
                "store_id": 1,
                "product_id": 2,
                "quantity": 0,
                "unit_price": Decimal("9.99"),
                "total_amount": Decimal("0.00"),
            }
        ]
        assert gen.generate(sales, date(2024, 1, 31)) == []
Demand adjustment based on price changes - **New Product Ramps**: Gradual demand increase for new launches +## Phase 1 Realism Extensions + +Phase 1 adds opt-in realism: exogenous signals, multi-seasonality, trend changepoints, +returns volume, and stockout substitution. Each extension is gated behind its own flag +on `GenerateParams` (or its dataclass on `SeederConfig`). **Existing scenarios with no +flags set produce byte-identical seeded data to pre-Phase-1** — the regression invariant +is enforced by `app/shared/seeder/tests/test_phase1_regression.py`. + +### Exogenous Signals + +Persisted in the `exogenous_signal` table. Three signals are available: + +| Signal | Scope | Shape | +|--------|-------|-------| +| `weather_temp_c` | per (store, date) | sinusoidal climatology + Gaussian noise | +| `macro_index` | per date (global) | random walk from `macro_initial_value` | +| `event_flag` | per `event_dates` entry | binary 1.0 marker on configured dates | + +Toggle via `GenerateParams.enable_exogenous=true` (turns on weather + macro). To also +drive demand from weather, pass `weather_temperature_sensitivity` (e.g. `0.02` = +2% +demand per °C above the climatology mean). + +Read back: + +```bash +curl "http://localhost:8123/seeder/exogenous?signal_name=weather_temp_c&start_date=2024-01-01&end_date=2024-01-31" +``` + +### Multi-Seasonality + +Yearly sine wave on top of weekly + monthly seasonality: + +```json +{"yearly_seasonality_amplitude": 0.15} +``` + +Amplitude is a fraction of base demand (0–1). 0 or unset = disabled. + +### Changepoints + +COVID-style demand impulses with exponential decay: + +```json +{ + "changepoints": [ + {"date": "2024-03-15", "demand_multiplier": 2.0, "decay_days": 60} + ] +} +``` + +`decay_days=0` means a pure impulse on the changepoint date. + +### Returns + +Synthetic returns volume in the `sales_returns` table.
A configurable fraction of +sales rows generates a delayed return: + +```json +{"enable_returns": true} +``` + +Tune via `ReturnsConfig` on `SeederConfig` (default ~2% of sales, lag 1–14 days, with +reasons drawn from `defective`/`wrong_size`/`not_as_described`/`changed_mind`/ +`damaged_in_transit`). + +### Substitution on Stockout + +When a member of a substitute group is stocked out, the surviving members pick up a +share of demand: + +```json +{ + "enable_substitution": true, + "substitute_groups": [[1, 2, 3]], + "substitution_lift_on_stockout": 0.5 +} +``` + +`product_id` values must already exist in the dataset. The lift is split across in-stock +group-mates. + +### Phase 1 API surface + +- `POST /seeder/generate` accepts the Phase 1 fields above; defaults keep Phase 1 off. +- `GET /seeder/exogenous?signal_name=&start_date=&end_date=&store_id=` returns signal rows. +- `GET /seeder/status` adds `exogenous_signals` and `sales_returns` counts. + ## Data Integrity The seeder enforces data integrity: @@ -204,6 +291,10 @@ The seeder enforces data integrity: 2. **Non-Negative Values**: Quantities and prices are always non-negative 3. **Date Coverage**: Calendar table covers entire date range 4. **Uniqueness**: Store codes and product SKUs are unique +5. **Phase 1 — Returns positive**: `sales_returns.return_quantity` is always ≥ 1 +6. **Phase 1 — Exogenous consistency**: every `exogenous_signal` row satisfies + `is_global = true ⇔ store_id IS NULL` (enforced by a CHECK constraint and verified + by `verify_data_integrity`) Verify with: ```bash