Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 120 additions & 11 deletions app/features/backtesting/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- Training and predicting with models per fold
- Calculating metrics and aggregating results
- Running baseline comparisons
- Saving results to configured directory

CRITICAL: All operations respect time-safety constraints.
"""
Expand All @@ -16,6 +17,7 @@
import uuid
from dataclasses import dataclass, field
from datetime import date as date_type
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
Expand Down Expand Up @@ -88,6 +90,73 @@ def __init__(self) -> None:
self.settings = get_settings()
self.metrics_calculator = MetricsCalculator()

def _validate_config(self, config: BacktestConfig) -> None:
"""Validate backtest configuration against settings constraints.

Args:
config: Backtest configuration to validate.

Raises:
ValueError: If config violates settings constraints.
"""
split_config = config.split_config

# Validate n_splits against backtest_max_splits
if split_config.n_splits > self.settings.backtest_max_splits:
raise ValueError(
f"n_splits ({split_config.n_splits}) exceeds maximum allowed "
f"({self.settings.backtest_max_splits}). "
f"Adjust split_config.n_splits or increase BACKTEST_MAX_SPLITS setting."
)

# Validate gap against backtest_max_gap
if split_config.gap > self.settings.backtest_max_gap:
raise ValueError(
f"gap ({split_config.gap}) exceeds maximum allowed "
f"({self.settings.backtest_max_gap}). "
f"Adjust split_config.gap or increase BACKTEST_MAX_GAP setting."
)

# Validate min_train_size meets minimum threshold
if split_config.min_train_size < self.settings.backtest_default_min_train_size:
logger.warning(
"backtesting.min_train_size_below_default",
provided=split_config.min_train_size,
default=self.settings.backtest_default_min_train_size,
message="Using provided min_train_size below recommended default",
)

def save_results(
self,
response: BacktestResponse,
filename: str | None = None,
) -> Path:
"""Save backtest results to configured results directory.

Args:
response: BacktestResponse to save.
filename: Optional custom filename. Defaults to backtest_id.json.

Returns:
Path to saved results file.
"""
results_dir = Path(self.settings.backtest_results_dir)
results_dir.mkdir(parents=True, exist_ok=True)

if filename is None:
filename = f"{response.backtest_id}.json"

file_path = results_dir / filename
file_path.write_text(response.model_dump_json(indent=2))

logger.info(
"backtesting.results_saved",
backtest_id=response.backtest_id,
file_path=str(file_path),
)

return file_path

async def run_backtest(
self,
db: AsyncSession,
Expand All @@ -111,8 +180,12 @@ async def run_backtest(
BacktestResponse with all results.

Raises:
ValueError: If insufficient data for requested splits.
ValueError: If insufficient data for requested splits or config
violates settings constraints.
"""
# Validate config against settings constraints
self._validate_config(config)

start_time = time.perf_counter()
backtest_id = uuid.uuid4().hex[:16]

Expand Down Expand Up @@ -331,6 +404,10 @@ def _run_baseline_comparisons(

return results

# Metrics where the sign matters and we should compare absolute values
# for percentage improvement calculations
SIGNED_METRICS: frozenset[str] = frozenset({"bias"})

def _generate_comparison_summary(
self,
main_results: ModelBacktestResult,
Expand All @@ -345,11 +422,16 @@ def _generate_comparison_summary(
Returns:
Dictionary with comparison metrics.
Keys are metric names, values are dicts with:
- main: Main model value
- naive: Naive baseline value (if available)
- seasonal_naive: Seasonal naive value (if available)
- main: Main model value (original signed value)
- naive: Naive baseline value (original signed value, if available)
- seasonal_naive: Seasonal naive value (original signed value, if available)
- vs_naive_pct: Percentage improvement over naive
- vs_seasonal_pct: Percentage improvement over seasonal

Note:
For signed metrics (e.g., bias), percentage improvements are computed
using absolute values since a smaller absolute value is better
regardless of sign.
"""
summary: dict[str, dict[str, float]] = {}

Expand All @@ -362,21 +444,48 @@ def _generate_comparison_summary(
for metric_name, main_value in main_results.aggregated_metrics.items():
comparison: dict[str, float] = {"main": main_value}

# Determine if this is a signed metric
is_signed = metric_name in self.SIGNED_METRICS

# Add baseline values and compute improvements
if "naive" in baseline_by_type:
naive_value = baseline_by_type["naive"].get(metric_name, np.nan)
comparison["naive"] = naive_value
if not np.isnan(naive_value) and naive_value != 0:
# Negative improvement means main is worse
comparison["vs_naive_pct"] = ((naive_value - main_value) / naive_value) * 100

if not np.isnan(naive_value):
if is_signed:
# For signed metrics, compare absolute values
abs_main = abs(main_value)
abs_naive = abs(naive_value)
if abs_naive != 0:
# Improvement = (abs_baseline - abs_main) / abs_baseline * 100
comparison["vs_naive_pct"] = (
(abs_naive - abs_main) / abs_naive
) * 100
elif naive_value != 0:
# For unsigned metrics, use original formula
comparison["vs_naive_pct"] = (
(naive_value - main_value) / naive_value
) * 100

if "seasonal_naive" in baseline_by_type:
seasonal_value = baseline_by_type["seasonal_naive"].get(metric_name, np.nan)
comparison["seasonal_naive"] = seasonal_value
if not np.isnan(seasonal_value) and seasonal_value != 0:
comparison["vs_seasonal_pct"] = (
(seasonal_value - main_value) / seasonal_value
) * 100

if not np.isnan(seasonal_value):
if is_signed:
# For signed metrics, compare absolute values
abs_main = abs(main_value)
abs_seasonal = abs(seasonal_value)
if abs_seasonal != 0:
comparison["vs_seasonal_pct"] = (
(abs_seasonal - abs_main) / abs_seasonal
) * 100
elif seasonal_value != 0:
# For unsigned metrics, use original formula
comparison["vs_seasonal_pct"] = (
(seasonal_value - main_value) / seasonal_value
) * 100

summary[metric_name] = comparison

Expand Down
Loading
Loading