From 0db8ff8411ce2a46f9564faae030d35796ac9992 Mon Sep 17 00:00:00 2001 From: Bohdan Date: Mon, 16 Mar 2026 16:30:09 -0700 Subject: [PATCH 1/2] Fix wandb run reuse across models --- src/art/model.py | 31 +++++----- tests/unit/test_wandb_multi_run.py | 92 ++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 14 deletions(-) create mode 100644 tests/unit/test_wandb_multi_run.py diff --git a/src/art/model.py b/src/art/model.py index 5c082fde..ab08a88d 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -470,6 +470,7 @@ def _get_wandb_run(self) -> Optional["Run"]: id=self.name, config=self._wandb_config or None, resume="allow", + reinit="create_new", settings=wandb.Settings( x_stats_open_metrics_endpoints={ "vllm": "http://localhost:8000/metrics", @@ -492,18 +493,18 @@ def _get_wandb_run(self) -> Optional["Run"]: # Define training_step as the x-axis for all metrics. # This allows out-of-order logging (e.g., async validation for previous steps). - wandb.define_metric("training_step") - wandb.define_metric("time/wall_clock_sec") - wandb.define_metric("reward/*", step_metric="training_step") - wandb.define_metric("loss/*", step_metric="training_step") - wandb.define_metric("throughput/*", step_metric="training_step") - wandb.define_metric("costs/*", step_metric="training_step") - wandb.define_metric("time/*", step_metric="training_step") - wandb.define_metric("data/*", step_metric="training_step") - wandb.define_metric("train/*", step_metric="training_step") - wandb.define_metric("val/*", step_metric="training_step") - wandb.define_metric("test/*", step_metric="training_step") - wandb.define_metric("discarded/*", step_metric="training_step") + run.define_metric("training_step") + run.define_metric("time/wall_clock_sec") + run.define_metric("reward/*", step_metric="training_step") + run.define_metric("loss/*", step_metric="training_step") + run.define_metric("throughput/*", step_metric="training_step") + run.define_metric("costs/*", step_metric="training_step") + run.define_metric("time/*", step_metric="training_step") + run.define_metric("data/*", step_metric="training_step") + run.define_metric("train/*", step_metric="training_step") + run.define_metric("val/*", step_metric="training_step") + run.define_metric("test/*", step_metric="training_step") + run.define_metric("discarded/*", step_metric="training_step") self._sync_wandb_config(run) return self._wandb_run @@ -562,14 +563,16 @@ def _log_metrics( run.log(prefixed) def _define_wandb_step_metrics(self, keys: Iterable[str]) -> None: - import wandb + run = self._wandb_run + if run is None or run._is_finished: + return for key in keys: if not key.startswith("costs/"): continue if key in self._wandb_defined_metrics: continue - wandb.define_metric(key, step_metric="training_step") + run.define_metric(key, step_metric="training_step") self._wandb_defined_metrics.add(key) def _route_metrics_and_collect_non_costs( diff --git a/tests/unit/test_wandb_multi_run.py b/tests/unit/test_wandb_multi_run.py new file mode 100644 index 00000000..33eb9eae --- /dev/null +++ b/tests/unit/test_wandb_multi_run.py @@ -0,0 +1,92 @@ +import os +import sys +from pathlib import Path +from unittest.mock import patch + +from art import Model + + +def test_wandb_creates_separate_runs_per_model(tmp_path: Path): + class FakeRun: + def __init__(self, name: str): + self.name = name + self.id = name + self._is_finished = False + self.defined_metrics: list[tuple[str, str | None]] = [] + + def define_metric(self, name: str, *, step_metric: str | None = None) -> None: + self.defined_metrics.append((name, step_metric)) + + class FakeWandb: + def __init__(self): + self.init_calls: list[dict] = [] + self.runs: list[FakeRun] = [] + + @staticmethod + def Settings(**kwargs): + return kwargs + + def init(self, **kwargs): + self.init_calls.append(kwargs) + run = FakeRun(kwargs["name"]) + self.runs.append(run) + return run + + def define_metric(self, *args, **kwargs) -> None: + raise AssertionError("Model should define metrics on the run object") + + fake_wandb = FakeWandb() + model_one = Model( + name="run-one", + project="test-project", + base_path=str(tmp_path), + ) + model_two = Model( + name="run-two", + project="test-project", + base_path=str(tmp_path), + ) + + with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}): + with patch.dict(sys.modules, {"wandb": fake_wandb}): + run_one = model_one._get_wandb_run() + run_two = model_two._get_wandb_run() + model_one._define_wandb_step_metrics(["costs/train/custom"]) + + assert run_one is not None + assert run_two is not None + assert run_one is not run_two + assert [call["name"] for call in fake_wandb.init_calls] == [ + "run-one", + "run-two", + ] + assert all(call["reinit"] == "create_new" for call in fake_wandb.init_calls) + assert run_one.defined_metrics == [ + ("training_step", None), + ("time/wall_clock_sec", None), + ("reward/*", "training_step"), + ("loss/*", "training_step"), + ("throughput/*", "training_step"), + ("costs/*", "training_step"), + ("time/*", "training_step"), + ("data/*", "training_step"), + ("train/*", "training_step"), + ("val/*", "training_step"), + ("test/*", "training_step"), + ("discarded/*", "training_step"), + ("costs/train/custom", "training_step"), + ] + assert run_two.defined_metrics == [ + ("training_step", None), + ("time/wall_clock_sec", None), + ("reward/*", "training_step"), + ("loss/*", "training_step"), + ("throughput/*", "training_step"), + ("costs/*", "training_step"), + ("time/*", "training_step"), + ("data/*", "training_step"), + ("train/*", "training_step"), + ("val/*", "training_step"), + ("test/*", "training_step"), + ("discarded/*", "training_step"), + ] From 7ee516e0293dc92627e61fafaa5c2437aae08660 Mon Sep 17 00:00:00 2001 From: Bohdan Date: Mon, 16 Mar 2026 16:34:34 -0700 Subject: [PATCH 2/2] Remove ART wandb regression test --- tests/unit/test_wandb_multi_run.py | 92 ------------------------------ 1 file changed, 92 deletions(-) delete mode 100644 tests/unit/test_wandb_multi_run.py diff --git a/tests/unit/test_wandb_multi_run.py b/tests/unit/test_wandb_multi_run.py deleted file mode 100644 index 33eb9eae..00000000 --- a/tests/unit/test_wandb_multi_run.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import sys -from pathlib import Path -from unittest.mock import patch - -from art import Model - - -def test_wandb_creates_separate_runs_per_model(tmp_path: Path): - class FakeRun: - def __init__(self, name: str): - self.name = name - self.id = name - self._is_finished = False - self.defined_metrics: list[tuple[str, str | None]] = [] - - def define_metric(self, name: str, *, step_metric: str | None = None) -> None: - self.defined_metrics.append((name, step_metric)) - - class FakeWandb: - def __init__(self): - self.init_calls: list[dict] = [] - self.runs: list[FakeRun] = [] - - @staticmethod - def Settings(**kwargs): - return kwargs - - def init(self, **kwargs): - self.init_calls.append(kwargs) - run = FakeRun(kwargs["name"]) - self.runs.append(run) - return run - - def define_metric(self, *args, **kwargs) -> None: - raise AssertionError("Model should define metrics on the run object") - - fake_wandb = FakeWandb() - model_one = Model( - name="run-one", - project="test-project", - base_path=str(tmp_path), - ) - model_two = Model( - name="run-two", - project="test-project", - base_path=str(tmp_path), - ) - - with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}): - with patch.dict(sys.modules, {"wandb": fake_wandb}): - run_one = model_one._get_wandb_run() - run_two = model_two._get_wandb_run() - model_one._define_wandb_step_metrics(["costs/train/custom"]) - - assert run_one is not None - assert run_two is not None - assert run_one is not run_two - assert [call["name"] for call in fake_wandb.init_calls] == [ - "run-one", - "run-two", - ] - assert all(call["reinit"] == "create_new" for call in fake_wandb.init_calls) - assert run_one.defined_metrics == [ - ("training_step", None), - ("time/wall_clock_sec", None), - ("reward/*", "training_step"), - ("loss/*", "training_step"), - ("throughput/*", "training_step"), - ("costs/*", "training_step"), - ("time/*", "training_step"), - ("data/*", "training_step"), - ("train/*", "training_step"), - ("val/*", "training_step"), - ("test/*", "training_step"), - ("discarded/*", "training_step"), - ("costs/train/custom", "training_step"), - ] - assert run_two.defined_metrics == [ - ("training_step", None), - ("time/wall_clock_sec", None), - ("reward/*", "training_step"), - ("loss/*", "training_step"), - ("throughput/*", "training_step"), - ("costs/*", "training_step"), - ("time/*", "training_step"), - ("data/*", "training_step"), - ("train/*", "training_step"), - ("val/*", "training_step"), - ("test/*", "training_step"), - ("discarded/*", "training_step"), - ]