Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/art/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class Model(
_openai_client: AsyncOpenAI | None = None
_wandb_run: Optional["Run"] = None # Private, for lazy wandb initialization
_wandb_defined_metrics: set[str]
_wandb_config: dict[str, Any]
_run_start_time: float
_run_start_monotonic: float
_last_local_train_log_monotonic: float
Expand Down Expand Up @@ -150,6 +151,7 @@ def __init__(
**kwargs,
)
object.__setattr__(self, "_wandb_defined_metrics", set())
object.__setattr__(self, "_wandb_config", {})
object.__setattr__(self, "_run_start_time", time.time())
object.__setattr__(self, "_run_start_monotonic", time.monotonic())
object.__setattr__(
Expand Down Expand Up @@ -371,6 +373,34 @@ def _deep_merge_dicts(
merged[key] = value
return merged

@staticmethod
def _merge_wandb_config(
    existing: dict[str, Any],
    updates: dict[str, Any],
    *,
    path: str = "",
) -> dict[str, Any]:
    """Return ``existing`` merged with ``updates`` without mutating either.

    Nested dicts are merged recursively, key by key. The config is treated
    as write-once: assigning a different value to an already-set leaf key
    is rejected.

    Args:
        existing: Previously accumulated config.
        updates: New config values to fold in.
        path: Dotted prefix used to build readable error messages.

    Raises:
        ValueError: If ``updates`` conflicts with a value already present.
    """
    result = dict(existing)
    for key, incoming in updates.items():
        dotted = f"{path}.{key}" if path else key
        if key not in result:
            result[key] = incoming
        elif isinstance(result[key], dict) and isinstance(incoming, dict):
            # Descend into matching sub-dicts, extending the dotted path.
            result[key] = Model._merge_wandb_config(
                result[key],
                incoming,
                path=dotted,
            )
        elif result[key] != incoming:
            raise ValueError(
                "W&B config is immutable once set. "
                f"Conflicting value for '{dotted}'."
            )
        # Equal values: keep the existing entry untouched.
    return result

def read_state(self) -> StateType | None:
"""Read persistent state from the model directory.

Expand All @@ -390,6 +420,43 @@ def read_state(self) -> StateType | None:
with open(state_path, "r") as f:
return json.load(f)

def update_wandb_config(
    self,
    config: dict[str, Any],
) -> None:
    """Merge configuration into the W&B run config for this model.

    This can be called before the W&B run exists, in which case the config is
    passed to `wandb.init(...)` when ART first creates the run. If the run is
    already active, ART updates the run config immediately.

    Previously stored values are immutable: merging a different value for an
    already-set (non-dict) key raises.

    Args:
        config: JSON-serializable configuration to store on the W&B run.

    Raises:
        TypeError: If ``config`` is not a dict.
        ValueError: If ``config`` conflicts with previously stored values.
    """
    if not isinstance(config, dict):
        raise TypeError("config must be a dict[str, Any]")

    merged = self._merge_wandb_config(self._wandb_config, config)
    # Private state is written via object.__setattr__, matching how
    # __init__ initializes these fields on this (validated) model.
    object.__setattr__(self, "_wandb_config", merged)

    run = self._wandb_run
    # `_is_finished` is a private wandb attribute; read it defensively so
    # a wandb-internal rename degrades to a best-effort sync rather than
    # an AttributeError (consistent with _sync_wandb_config's style).
    if run is not None and not getattr(run, "_is_finished", False):
        self._sync_wandb_config(run)

def _sync_wandb_config(
    self,
    run: "Run",
) -> None:
    """Best-effort push of the accumulated config onto ``run.config``.

    Silently does nothing when there is no pending config or when the run
    object does not expose a usable ``config`` (e.g. stub/mock runs).
    """
    pending = self._wandb_config
    if not pending:
        return
    target = getattr(run, "config", None)
    if target is None or not hasattr(target, "update"):
        # Defensive: tolerate run objects without a real config attribute.
        return
    target.update(pending)

def _get_wandb_run(self) -> Optional["Run"]:
"""Get or create the wandb run for this model."""
import wandb
Expand All @@ -401,6 +468,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
project=self.project,
name=self.name,
id=self.name,
config=self._wandb_config or None,
resume="allow",
settings=wandb.Settings(
x_stats_open_metrics_endpoints={
Expand Down Expand Up @@ -436,6 +504,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
wandb.define_metric("val/*", step_metric="training_step")
wandb.define_metric("test/*", step_metric="training_step")
wandb.define_metric("discarded/*", step_metric="training_step")
self._sync_wandb_config(run)
return self._wandb_run

def _log_metrics(
Expand Down
90 changes: 90 additions & 0 deletions tests/unit/test_metric_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import types
from unittest.mock import MagicMock, patch

import pytest

from art import Model


Expand Down Expand Up @@ -80,6 +82,7 @@ def test_log_metrics_defines_nested_cost_keys_with_training_step(
) -> None:
fake_run = MagicMock()
fake_run._is_finished = False
fake_run.config = MagicMock()

fake_wandb = types.SimpleNamespace()
fake_wandb.init = MagicMock(return_value=fake_run)
Expand Down Expand Up @@ -121,3 +124,90 @@ def test_log_metrics_defines_nested_cost_keys_with_training_step(
assert logged_metrics["training_step"] == 1
assert "time/wall_clock_sec" in logged_metrics
assert fake_run.log.call_args.kwargs == {}

def test_update_wandb_config_seeds_wandb_init(self, tmp_path: Path) -> None:
    """Config registered before the run exists is handed to wandb.init."""
    run_stub = MagicMock()
    run_stub._is_finished = False
    run_stub.config = MagicMock()

    wandb_stub = types.SimpleNamespace(
        init=MagicMock(return_value=run_stub),
        define_metric=MagicMock(),
        Settings=lambda **kwargs: kwargs,
    )

    expected = {
        "experiment": {"learning_rate": 1e-5, "batch_size": 4},
        "dataset": {"split": "train"},
    }

    with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}, clear=False):
        with patch.dict("sys.modules", {"wandb": wandb_stub}):
            model = Model(
                name="test-model",
                project="test-project",
                base_path=str(tmp_path),
            )
            model.update_wandb_config(expected)
            run = model._get_wandb_run()

    assert run is run_stub
    init_kwargs = wandb_stub.init.call_args.kwargs
    assert init_kwargs["config"] == expected
    assert "allow_val_change" not in init_kwargs
    run_stub.config.update.assert_called_once_with(expected)

def test_update_wandb_config_updates_active_run(self, tmp_path: Path) -> None:
    """Config merged after the run exists is applied to run.config."""
    run_stub = MagicMock()
    run_stub._is_finished = False
    run_stub.config = MagicMock()

    wandb_stub = types.SimpleNamespace(
        init=MagicMock(return_value=run_stub),
        define_metric=MagicMock(),
        Settings=lambda **kwargs: kwargs,
    )

    with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}, clear=False):
        with patch.dict("sys.modules", {"wandb": wandb_stub}):
            model = Model(
                name="test-model",
                project="test-project",
                base_path=str(tmp_path),
            )
            model.update_wandb_config({"experiment": {"learning_rate": 1e-5}})
            _ = model._get_wandb_run()
            # Discard the sync performed during run creation; only the
            # post-creation updates matter below.
            run_stub.config.update.reset_mock()

            model.update_wandb_config(
                {"experiment": {"learning_rate": 1e-5, "batch_size": 8}},
            )
            model.update_wandb_config(
                {"dataset": {"split": "train"}},
            )

    update_calls = run_stub.config.update.call_args_list
    assert len(update_calls) == 2
    assert update_calls[0].args == (
        {"experiment": {"learning_rate": 1e-5, "batch_size": 8}},
    )
    assert update_calls[1].args == (
        {
            "experiment": {"learning_rate": 1e-5, "batch_size": 8},
            "dataset": {"split": "train"},
        },
    )

def test_update_wandb_config_rejects_conflicting_values(
    self, tmp_path: Path
) -> None:
    """Re-assigning an already-set leaf to a new value raises ValueError."""
    model = Model(
        name="test-model",
        project="test-project",
        base_path=str(tmp_path),
    )
    model.update_wandb_config({"experiment": {"learning_rate": 1e-5}})

    expected_message = "Conflicting value for 'experiment.learning_rate'"
    with pytest.raises(ValueError, match=expected_message):
        model.update_wandb_config({"experiment": {"learning_rate": 2e-5}})
Loading