diff --git a/src/art/model.py b/src/art/model.py index 5c082fde..ab08a88d 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -470,6 +470,7 @@ def _get_wandb_run(self) -> Optional["Run"]: id=self.name, config=self._wandb_config or None, resume="allow", + reinit="create_new", settings=wandb.Settings( x_stats_open_metrics_endpoints={ "vllm": "http://localhost:8000/metrics", @@ -492,18 +493,18 @@ def _get_wandb_run(self) -> Optional["Run"]: # Define training_step as the x-axis for all metrics. # This allows out-of-order logging (e.g., async validation for previous steps). - wandb.define_metric("training_step") - wandb.define_metric("time/wall_clock_sec") - wandb.define_metric("reward/*", step_metric="training_step") - wandb.define_metric("loss/*", step_metric="training_step") - wandb.define_metric("throughput/*", step_metric="training_step") - wandb.define_metric("costs/*", step_metric="training_step") - wandb.define_metric("time/*", step_metric="training_step") - wandb.define_metric("data/*", step_metric="training_step") - wandb.define_metric("train/*", step_metric="training_step") - wandb.define_metric("val/*", step_metric="training_step") - wandb.define_metric("test/*", step_metric="training_step") - wandb.define_metric("discarded/*", step_metric="training_step") + run.define_metric("training_step") + run.define_metric("time/wall_clock_sec") + run.define_metric("reward/*", step_metric="training_step") + run.define_metric("loss/*", step_metric="training_step") + run.define_metric("throughput/*", step_metric="training_step") + run.define_metric("costs/*", step_metric="training_step") + run.define_metric("time/*", step_metric="training_step") + run.define_metric("data/*", step_metric="training_step") + run.define_metric("train/*", step_metric="training_step") + run.define_metric("val/*", step_metric="training_step") + run.define_metric("test/*", step_metric="training_step") + run.define_metric("discarded/*", step_metric="training_step") self._sync_wandb_config(run) return self._wandb_run @@ -562,14 +563,16 @@ def _log_metrics( run.log(prefixed) def _define_wandb_step_metrics(self, keys: Iterable[str]) -> None: - import wandb + run = self._wandb_run + if run is None or run._is_finished: + return for key in keys: if not key.startswith("costs/"): continue if key in self._wandb_defined_metrics: continue - wandb.define_metric(key, step_metric="training_step") + run.define_metric(key, step_metric="training_step") self._wandb_defined_metrics.add(key) def _route_metrics_and_collect_non_costs(