Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions src/forecast/custom_models/constant_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Minimal forecasters for tests and as integration examples.

These mirror the column layout expected by
:func:`src.forecast.forecast.generate_forecast` and
:func:`src.evaluation.evaluate.evaluate_forecast` (point column + ``{alias}-q-{pp}``).
"""

import numpy as np
import pandas as pd
from timecopilot.models.utils.forecaster import Forecaster

from src.forecast.constants import QUANTILES


class ConstantForecastModel(Forecaster):
"""Point and quantile forecasts equal a fixed value (default ``1.0``).

Subclasses :class:`timecopilot.models.utils.forecaster.Forecaster`. Use as a
sanity-check baseline and a minimal template for adding models: ``forecast``
returns ``unique_id``, ``ds``, ``{alias}``, and ``{alias}-q-10`` … ``{alias}-q-90``.
"""

def __init__(self, value: float = 1.0, alias: str = "constant_one") -> None:
self.value = float(value)
self.alias = alias

def forecast(
self,
df: pd.DataFrame,
h: int,
freq: str | None = None,
level: list[int | float] | None = None,
quantiles: list[float] | None = None,
) -> pd.DataFrame:
if level is not None:
raise NotImplementedError(
"ConstantForecastModel does not support ``level``"
)
inferred = self._maybe_infer_freq(df, freq)
if not pd.api.types.is_datetime64_any_dtype(df["ds"]):
df = df.copy()
df["ds"] = pd.to_datetime(df["ds"])

# Last observed timestamp is shared across series in this benchmark loader.
max_ds = df["ds"].max()
future = pd.date_range(
start=max_ds,
periods=h + 1,
freq=inferred,
)[1:]

qs = QUANTILES if quantiles is None else quantiles
uids = df["unique_id"].unique()
n_u, n_f = len(uids), len(future)
# Vectorized (series × horizon); no per-row Python loop.
out = pd.DataFrame(
{
"unique_id": np.repeat(uids, n_f),
"ds": np.tile(np.asarray(future, dtype="datetime64[ns]"), n_u),
self.alias: self.value,
}
)
# Same constant for each quantile column; names match evaluate.py.
q_block = {
f"{self.alias}-q-{int(round(q * 100))}": np.full(n_u * n_f, self.value)
for q in qs
}
return pd.concat([out, pd.DataFrame(q_block)], axis=1)
3 changes: 3 additions & 0 deletions src/forecast/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ZeroModel,
)

from src.forecast.constant_models import ConstantForecastModel
from src.forecast.constants import QUANTILES

# Model categories for infrastructure selection
Expand All @@ -27,6 +28,8 @@
"zero_model": lambda: ZeroModel(),
"historic_average": lambda: HistoricAverage(),
"seasonal_naive": lambda: SeasonalNaive(),
# Example
"constant_one": lambda: ConstantForecastModel(1.0, alias="constant_one"),
# Statistical models
"auto_arima": lambda: AutoARIMA(),
"auto_ets": lambda: AutoETS(),
Expand Down
Loading