diff --git a/src/art/model.py b/src/art/model.py
index 2960fa34..5c082fde 100644
--- a/src/art/model.py
+++ b/src/art/model.py
@@ -114,6 +114,7 @@ class Model(
     _openai_client: AsyncOpenAI | None = None
     _wandb_run: Optional["Run"] = None  # Private, for lazy wandb initialization
     _wandb_defined_metrics: set[str]
+    _wandb_config: dict[str, Any]
     _run_start_time: float
     _run_start_monotonic: float
     _last_local_train_log_monotonic: float
@@ -150,6 +151,7 @@ def __init__(
             **kwargs,
         )
        object.__setattr__(self, "_wandb_defined_metrics", set())
+        object.__setattr__(self, "_wandb_config", {})
        object.__setattr__(self, "_run_start_time", time.time())
        object.__setattr__(self, "_run_start_monotonic", time.monotonic())
        object.__setattr__(
@@ -371,6 +373,34 @@ def _deep_merge_dicts(
                merged[key] = value
        return merged

+    @staticmethod
+    def _merge_wandb_config(
+        existing: dict[str, Any],
+        updates: dict[str, Any],
+        *,
+        path: str = "",
+    ) -> dict[str, Any]:
+        merged = dict(existing)
+        for key, value in updates.items():
+            key_path = f"{path}.{key}" if path else key
+            if key not in merged:
+                merged[key] = value
+                continue
+            existing_value = merged[key]
+            if isinstance(existing_value, dict) and isinstance(value, dict):
+                merged[key] = Model._merge_wandb_config(
+                    existing_value,
+                    value,
+                    path=key_path,
+                )
+                continue
+            if existing_value != value:
+                raise ValueError(
+                    "W&B config is immutable once set. "
+                    f"Conflicting value for '{key_path}'."
+                )
+        return merged
+
    def read_state(self) -> StateType | None:
        """Read persistent state from the model directory.

@@ -390,6 +420,43 @@ def read_state(self) -> StateType | None:
        with open(state_path, "r") as f:
            return json.load(f)

+    def update_wandb_config(
+        self,
+        config: dict[str, Any],
+    ) -> None:
+        """Merge configuration into the W&B run config for this model.
+
+        This can be called before the W&B run exists, in which case the config is
+        passed to `wandb.init(...)` when ART first creates the run. If the run is
+        already active, ART updates the run config immediately.
+
+        Args:
+            config: JSON-serializable configuration to store on the W&B run.
+        """
+        if not isinstance(config, dict):
+            raise TypeError("config must be a dict[str, Any]")
+
+        merged = self._merge_wandb_config(self._wandb_config, config)
+        object.__setattr__(self, "_wandb_config", merged)
+
+        if self._wandb_run is not None and not self._wandb_run._is_finished:
+            self._sync_wandb_config(self._wandb_run)
+
+    def _sync_wandb_config(
+        self,
+        run: "Run",
+    ) -> None:
+        if not self._wandb_config:
+            return
+
+        run_config = getattr(run, "config", None)
+        if run_config is None or not hasattr(run_config, "update"):
+            return
+
+        run_config.update(
+            self._wandb_config,
+        )
+
    def _get_wandb_run(self) -> Optional["Run"]:
        """Get or create the wandb run for this model."""
        import wandb
@@ -401,6 +468,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
                project=self.project,
                name=self.name,
                id=self.name,
+                config=self._wandb_config or None,
                resume="allow",
                settings=wandb.Settings(
                    x_stats_open_metrics_endpoints={
@@ -436,6 +504,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
            wandb.define_metric("val/*", step_metric="training_step")
            wandb.define_metric("test/*", step_metric="training_step")
            wandb.define_metric("discarded/*", step_metric="training_step")
+            self._sync_wandb_config(run)
        return self._wandb_run

    def _log_metrics(
diff --git a/tests/unit/test_metric_routing.py b/tests/unit/test_metric_routing.py
index ad7a94a9..7cb57d91 100644
--- a/tests/unit/test_metric_routing.py
+++ b/tests/unit/test_metric_routing.py
@@ -4,6 +4,8 @@
 import types
 from unittest.mock import MagicMock, patch
 
+import pytest
+
 from art import Model
 
 
@@ -80,6 +82,7 @@ def test_log_metrics_defines_nested_cost_keys_with_training_step(
    ) -> None:
        fake_run = MagicMock()
        fake_run._is_finished = False
+        fake_run.config = MagicMock()

        fake_wandb = types.SimpleNamespace()
        fake_wandb.init = MagicMock(return_value=fake_run)
@@ -121,3 +124,90 @@ def test_log_metrics_defines_nested_cost_keys_with_training_step(
        assert logged_metrics["training_step"] == 1
        assert "time/wall_clock_sec" in logged_metrics
        assert fake_run.log.call_args.kwargs == {}
+
+    def test_update_wandb_config_seeds_wandb_init(self, tmp_path: Path) -> None:
+        fake_run = MagicMock()
+        fake_run._is_finished = False
+        fake_run.config = MagicMock()
+
+        fake_wandb = types.SimpleNamespace()
+        fake_wandb.init = MagicMock(return_value=fake_run)
+        fake_wandb.define_metric = MagicMock()
+        fake_wandb.Settings = lambda **kwargs: kwargs
+
+        payload = {
+            "experiment": {"learning_rate": 1e-5, "batch_size": 4},
+            "dataset": {"split": "train"},
+        }
+
+        with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}, clear=False):
+            with patch.dict("sys.modules", {"wandb": fake_wandb}):
+                model = Model(
+                    name="test-model",
+                    project="test-project",
+                    base_path=str(tmp_path),
+                )
+                model.update_wandb_config(payload)
+                run = model._get_wandb_run()
+
+        assert run is fake_run
+        init_kwargs = fake_wandb.init.call_args.kwargs
+        assert init_kwargs["config"] == payload
+        assert "allow_val_change" not in init_kwargs
+        fake_run.config.update.assert_called_once_with(payload)
+
+    def test_update_wandb_config_updates_active_run(self, tmp_path: Path) -> None:
+        fake_run = MagicMock()
+        fake_run._is_finished = False
+        fake_run.config = MagicMock()
+
+        fake_wandb = types.SimpleNamespace()
+        fake_wandb.init = MagicMock(return_value=fake_run)
+        fake_wandb.define_metric = MagicMock()
+        fake_wandb.Settings = lambda **kwargs: kwargs
+
+        with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}, clear=False):
+            with patch.dict("sys.modules", {"wandb": fake_wandb}):
+                model = Model(
+                    name="test-model",
+                    project="test-project",
+                    base_path=str(tmp_path),
+                )
+                model.update_wandb_config({"experiment": {"learning_rate": 1e-5}})
+                _ = model._get_wandb_run()
+                fake_run.config.update.reset_mock()
+
+                model.update_wandb_config(
+                    {"experiment": {"learning_rate": 1e-5, "batch_size": 8}},
+                )
+                model.update_wandb_config(
+                    {"dataset": {"split": "train"}},
+                )
+
+        assert fake_run.config.update.call_count == 2
+        assert fake_run.config.update.call_args_list[0].args == (
+            {"experiment": {"learning_rate": 1e-5, "batch_size": 8}},
+        )
+        assert fake_run.config.update.call_args_list[1].args == (
+            {
+                "experiment": {"learning_rate": 1e-5, "batch_size": 8},
+                "dataset": {"split": "train"},
+            },
+        )
+
+    def test_update_wandb_config_rejects_conflicting_values(
+        self, tmp_path: Path
+    ) -> None:
+        model = Model(
+            name="test-model",
+            project="test-project",
+            base_path=str(tmp_path),
+        )
+
+        model.update_wandb_config({"experiment": {"learning_rate": 1e-5}})
+
+        with pytest.raises(
+            ValueError,
+            match="Conflicting value for 'experiment.learning_rate'",
+        ):
+            model.update_wandb_config({"experiment": {"learning_rate": 2e-5}})