Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 166 additions & 9 deletions src/eva/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
``RunConfig(_env_file=".env", _cli_parse_args=True)``.
"""

import copy
import logging
from datetime import UTC, datetime
from pathlib import Path
from typing import Annotated, Any, ClassVar, Literal
Expand All @@ -34,9 +36,12 @@

from eva.models.provenance import RunProvenance

logger = logging.getLogger(__name__)

def current_date_and_time():
return f"{datetime.now(UTC):%Y-%m-%d_%H-%M-%S.%f}"

def _param_alias(params: dict[str, Any]) -> str:
"""Return the display alias from a params dict."""
return params.get("alias") or params["model"]


class PipelineConfig(BaseModel):
Expand Down Expand Up @@ -73,6 +78,15 @@ class PipelineConfig(BaseModel):
),
)

@property
def pipeline_parts(self) -> dict[str, str]:
    """Component names for this pipeline."""
    stt_name = _param_alias(self.stt_params) or self.stt
    tts_name = _param_alias(self.tts_params) or self.tts
    return {"stt": stt_name, "llm": self.llm, "tts": tts_name}

@model_validator(mode="before")
@classmethod
def _migrate_legacy_fields(cls, data: Any) -> Any:
Expand Down Expand Up @@ -107,6 +121,11 @@ class SpeechToSpeechConfig(BaseModel):
"Set via EVA_MODEL__TURN_STRATEGY=external."
),
)

@property
def pipeline_parts(self) -> dict[str, str]:
    """Component names for this pipeline."""
    s2s_name = _param_alias(self.s2s_params) or self.s2s
    return {"s2s": s2s_name}


class AudioLLMConfig(BaseModel):
Expand All @@ -129,6 +148,14 @@ class AudioLLMConfig(BaseModel):
tts: str = Field(description="TTS model", examples=["cartesia", "elevenlabs"])
tts_params: dict[str, Any] = Field({}, description="Additional TTS model parameters (JSON)")

@property
def pipeline_parts(self) -> dict[str, str]:
    """Component names for this pipeline."""
    audio_llm_name = _param_alias(self.audio_llm_params) or self.audio_llm
    tts_name = _param_alias(self.tts_params) or self.tts
    return {"audio_llm": audio_llm_name, "tts": tts_name}


_PIPELINE_FIELDS = {
"llm",
Expand Down Expand Up @@ -280,6 +307,21 @@ class RunConfig(BaseSettings):
"EVA_METRICS_TO_RUN": "EVA_METRICS",
}

# Providers that manage their own model/key resolution (e.g. WebSocket-based)
_SKIP_PARAMS_VALIDATION: ClassVar[set[str]] = {"nvidia"}

# Maps *_params field names to their provider field for env override logic
_PARAMS_TO_PROVIDER: ClassVar[dict[str, str]] = {
"stt_params": "stt",
"tts_params": "tts",
"s2s_params": "s2s",
"audio_llm_params": "audio_llm",
}
# Keys always read from the live environment (not persisted across runs)
_ENV_OVERRIDE_KEYS: ClassVar[set[str]] = {"url", "urls"}
# Substrings that identify secret keys (redacted in logs and config.json)
_SECRET_KEY_PATTERNS: ClassVar[set[str]] = {"key", "credentials", "secret"}

class ModelDeployment(DeploymentTypedDict):
"""DeploymentTypedDict that preserves extra keys in litellm_params."""

Expand All @@ -294,7 +336,7 @@ class ModelDeployment(DeploymentTypedDict):

# Run identifier
run_id: str = Field(
default_factory=current_date_and_time,
"timestamp and model name(s)", # Overwritten by _set_default_run_id()
description="Run identifier, auto-generated if not provided",
)

Expand Down Expand Up @@ -464,8 +506,12 @@ def _check_companion_services(self) -> "RunConfig":
self._validate_service_params("S2S", self.model.s2s, ["api_key"], self.model.s2s_params)
return self

# Providers that manage their own model/key resolution (e.g. WebSocket-based)
_SKIP_PARAMS_VALIDATION: ClassVar[set[str]] = {"nvidia"}
@model_validator(mode="after")
def _set_default_run_id(self) -> "RunConfig":
    """Generate a timestamp + pipeline-components run_id unless the user set one."""
    if "run_id" in self.model_fields_set:
        return self
    parts = [name for name in self.model.pipeline_parts.values() if name]
    self.run_id = f"{datetime.now(UTC):%Y-%m-%d_%H-%M-%S.%f}_{'_'.join(parts)}"
    return self

@classmethod
def _validate_service_params(
Expand Down Expand Up @@ -503,20 +549,131 @@ def _expand_metrics_all(cls, v: list[str] | None) -> list[str] | None:
return [m for m in get_global_registry().list_metrics() if m not in cls._VALIDATION_METRIC_NAMES]
return v

@classmethod
def _is_secret_key(cls, key: str) -> bool:
    """Return True if *key* matches any pattern in _SECRET_KEY_PATTERNS."""
    for pattern in cls._SECRET_KEY_PATTERNS:
        if pattern in key:
            return True
    return False

@classmethod
def _redact_dict(cls, params: dict) -> dict:
    """Return a copy of *params* with secret values replaced by ``***``."""
    redacted: dict = {}
    for key, value in params.items():
        redacted[key] = "***" if cls._is_secret_key(key) else value
    return redacted

@field_serializer("model_list")
@classmethod
def _redact_model_list(cls, deployments: list[ModelDeployment]) -> list[dict]:
    """Redact secret values in litellm_params when serializing.

    Operates on deep copies so the live configuration keeps its real
    secrets; only the serialized representation is redacted.
    """
    redacted = []
    for deployment in deployments:
        deployment = copy.deepcopy(deployment)
        # Guard: a deployment entry may omit litellm_params entirely.
        # Redaction is delegated to _redact_dict so the key patterns
        # (_SECRET_KEY_PATTERNS, incl. "secret") stay in one place —
        # the previous inline loop matched only "key"/"credentials".
        if "litellm_params" in deployment:
            deployment["litellm_params"] = cls._redact_dict(deployment["litellm_params"])
        redacted.append(deployment)
    return redacted

@field_serializer("model")
@classmethod
def _redact_model_params(cls, model: ModelConfigUnion) -> dict:
    """Redact secret values in STT/TTS/S2S/AudioLLM params when serializing."""
    dumped = model.model_dump(mode="json")
    params_fields = [
        name
        for name, value in dumped.items()
        if name.endswith("_params") and isinstance(value, dict)
    ]
    for name in params_fields:
        dumped[name] = cls._redact_dict(dumped[name])
    return dumped

def apply_env_overrides(self, live: "RunConfig") -> None:
    """Apply environment-dependent values from *live* config onto this (saved) config.

    Restores redacted secrets (``***``) and overrides dynamic fields (``url``,
    ``urls``) in ``model.*_params`` and ``model_list[].litellm_params``.

    Raises:
        ValueError: If provider or alias differs for a service with redacted secrets.
    """
    # ── model.*_params (STT / TTS / S2S / AudioLLM) ──
    for params_field, provider_field in self._PARAMS_TO_PROVIDER.items():
        saved = getattr(self.model, params_field, None)
        source = getattr(live.model, params_field, None)
        # Skip params fields that the active pipeline type does not define
        # (each pipeline model only has a subset of the mapped fields).
        if not isinstance(saved, dict) or not isinstance(source, dict):
            continue

        has_redacted = any(v == "***" for v in saved.values())
        has_env_overrides = any(k in saved or k in source for k in self._ENV_OVERRIDE_KEYS)
        if not has_redacted and not has_env_overrides:
            continue

        if has_redacted:
            # An alias mismatch means the saved config targets a different
            # service instance than the live environment — restoring the
            # live secrets would be wrong, so this is a hard error.
            saved_alias = saved.get("alias")
            live_alias = source.get("alias")
            if saved_alias and live_alias and saved_alias != live_alias:
                raise ValueError(
                    f"Cannot restore secrets: saved {params_field}[alias]={saved_alias!r} "
                    f"but current environment has {params_field}[alias]={live_alias!r}"
                )

            # Provider/model mismatches are suspicious but tolerated: warn only.
            saved_provider = getattr(self.model, provider_field, None)
            live_provider = getattr(live.model, provider_field, None)
            if saved_provider != live_provider:
                logger.warning(
                    f"Provider mismatch for {params_field}: saved {saved_provider!r}, "
                    f"current environment has {live_provider!r}"
                )

            saved_model = saved.get("model")
            live_model = source.get("model")
            if saved_model and live_model and saved_model != live_model:
                logger.warning(
                    f"Model mismatch for {params_field}: saved {saved_model!r}, "
                    f"current environment has {live_model!r}"
                )

            # Restore each redacted value in place from the live params;
            # keys absent from the live environment stay as "***".
            for key, value in saved.items():
                if value == "***" and key in source:
                    saved[key] = source[key]

        # Always use url/urls from the live environment
        for key in self._ENV_OVERRIDE_KEYS:
            if key in source:
                saved_val = saved.get(key)
                if saved_val and saved_val != source[key]:
                    logger.warning(
                        f"{params_field}[{key}] differs: saved {saved_val!r}, "
                        f"using {source[key]!r} from current environment"
                    )
                saved[key] = source[key]

    # ── model_list[].litellm_params (LLM deployments) ──
    live_by_name = {d["model_name"]: d for d in live.model_list if "model_name" in d}
    for deployment in self.model_list:
        name = deployment.get("model_name")
        if not name:
            continue
        saved_params = deployment.get("litellm_params", {})
        has_redacted = any(v == "***" for v in saved_params.values())
        if not has_redacted:
            continue
        # Secrets can only be restored from a live deployment with the same
        # model_name; a missing entry is unrecoverable, so fail loudly.
        if name not in live_by_name:
            raise ValueError(
                f"Cannot restore secrets: deployment {name!r} not found in "
                f"current EVA_MODEL_LIST (available: {list(live_by_name)})"
            )
        live_params = live_by_name[name].get("litellm_params", {})
        for key, value in saved_params.items():
            if value == "***" and key in live_params:
                saved_params[key] = live_params[key]

    # ── Log resolved configuration ──
    # Always log through _redact_dict so restored secrets never reach the logs.
    for params_field, provider_field in self._PARAMS_TO_PROVIDER.items():
        params = getattr(self.model, params_field, None)
        provider = getattr(self.model, provider_field, None)
        if isinstance(params, dict) and params:
            logger.info(f"Resolved {provider_field} ({provider}): {self._redact_dict(params)}")

    for deployment in self.model_list:
        name = deployment.get("model_name", "?")
        params = deployment.get("litellm_params", {})
        logger.info(f"Resolved deployment {name}: {self._redact_dict(params)}")

@classmethod
def from_yaml(cls, path: Path | str) -> "RunConfig":
"""Load configuration from YAML file."""
Expand Down
5 changes: 4 additions & 1 deletion src/eva/orchestrator/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,10 @@ async def run(self, records: list[EvaluationRecord]) -> RunResult:
}

config_path = self.output_dir / "config.json"
config_path.write_text(self.config.model_dump_json(indent=2))
config_data = self.config.model_dump(mode="json")
pipeline_parts = self.config.model.pipeline_parts
config_data["pipeline_parts"] = pipeline_parts
config_path.write_text(json.dumps(config_data, indent=2))

# Build output_id list for tracking (supports pass@k)
num_trials = self.config.num_trials
Expand Down
3 changes: 3 additions & 0 deletions src/eva/run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ async def run_benchmark(config: RunConfig) -> int:
logger.error(str(e))
return 1

# Apply env-dependent values (secrets, urls) from live env onto saved config
runner.config.apply_env_overrides(config)

# Apply CLI overrides
runner.config.max_rerun_attempts = config.max_rerun_attempts
runner.config.force_rerun_metrics = config.force_rerun_metrics
Expand Down
Loading