From 171f04013a74a6d16367048c35be9fecd7f06156 Mon Sep 17 00:00:00 2001 From: anyangml Date: Tue, 10 Mar 2026 09:39:51 +0000 Subject: [PATCH 01/19] feat:support grad prob --- deepmd/main.py | 24 +++++++++++ deepmd/pt/entrypoints/main.py | 80 +++++++++++++++++++++++++++++++++++ deepmd/pt/loss/ener.py | 12 ++---- deepmd/pt/utils/dataloader.py | 1 + 4 files changed, 108 insertions(+), 9 deletions(-) diff --git a/deepmd/main.py b/deepmd/main.py index 492f3b085e..68b4e161ef 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -854,6 +854,30 @@ def main_parser() -> argparse.ArgumentParser: choices=["model-branch", "type-map", "descriptor", "fitting-net", "size"], nargs="+", ) + # grad-probe: per-task descriptor gradient probe + parser_grad_probe = subparsers.add_parser( + "grad-probe", + parents=[parser_log, parser_mpi_log], + help="Compute per-task descriptor gradient vectors from a multitask checkpoint.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser_grad_probe.add_argument("INPUT", help="Training config JSON file.") + parser_grad_probe.add_argument( + "--ckpt", required=True, help="Checkpoint .pt file path." + ) + parser_grad_probe.add_argument( + "-o", + "--output", + default="descriptor_grads.npz", + help="Output NPZ file path.", + ) + parser_grad_probe.add_argument( + "-n", + "--nbatches", + type=int, + default=1, + help="Number of batches per task to average gradient over.", + ) return parser diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 0e248583ec..887f441352 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -513,6 +513,84 @@ def change_bias( log.info(f"Saved model to {output_path}") +def grad_probe(FLAGS) -> None: + """Compute per-task descriptor gradient vectors for multitask conflict analysis.""" + import numpy as np + + config = j_loader(FLAGS.INPUT) + config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") + multi_task = "model_dict" in config.get("model", {}) + shared_links = None + if multi_task: + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = normalize(config, multi_task=multi_task) + + trainer = get_trainer( + config, + restart_model=FLAGS.ckpt, + shared_links=shared_links, + ) + trainer.wrapper.eval() + + module = ( + trainer.wrapper.module + if hasattr(trainer.wrapper, "module") + else trainer.wrapper + ) + + grads_per_task: dict = {} + param_names: list | None = None + param_shapes: list | None = None + + for task_key in trainer.model_keys: + trainer.optimizer.zero_grad(set_to_none=True) + + for _ in range(FLAGS.nbatches): + input_dict, label_dict, _ = trainer.get_data( + is_train=True, task_key=task_key + ) + cur_lr = trainer.lr_schedule.value(0) + _, loss, _ = trainer.wrapper( + **input_dict, cur_lr=cur_lr, label=label_dict, task_key=task_key + ) + loss.backward() + + model = module.model[task_key] + descriptor = model.get_descriptor() + grads, names, shapes = [], [], [] + for name, param in descriptor.named_parameters(): + if param.grad is not None: + g = param.grad.detach().cpu().float().numpy().ravel() + else: + g = np.zeros(param.numel(), dtype=np.float32) + grads.append(g) + names.append(name) + shapes.append(list(param.shape)) + + grad_vec = np.concatenate(grads) if grads else np.array([], dtype=np.float32) + if FLAGS.nbatches > 1: + grad_vec /= FLAGS.nbatches + + grads_per_task[task_key] = grad_vec + if param_names is None: + param_names = names + param_shapes = shapes + + trainer.optimizer.zero_grad(set_to_none=True) + log.info( + "Task '%s': descriptor gradient collected, norm=%.4e", + task_key, + float(np.linalg.norm(grad_vec)), + ) + + save_dict = {f"grads_{k}": v for k, v in grads_per_task.items()} + save_dict["param_names"] = np.array(param_names, dtype=object) + save_dict["param_shapes"] = np.array(param_shapes, dtype=object) + np.savez(FLAGS.output, **save_dict) + log.info("Descriptor gradient vectors saved to: %s", FLAGS.output) + + + @record def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None: if not isinstance(args, argparse.Namespace): @@ -572,6 +650,8 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None: check_frequency=FLAGS.frequency, training_script=FLAGS.training_script, ) + elif FLAGS.command == "grad-probe": + grad_probe(FLAGS) else: raise RuntimeError(f"Invalid command {FLAGS.command}!") diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py index 00a352424e..9df0164a1a 100644 --- a/deepmd/pt/loss/ener.py +++ b/deepmd/pt/loss/ener.py @@ -151,13 +151,7 @@ def __init__( assert self.use_huber or self.use_l1_all, ( "f_use_norm can only be True when use_huber or use_l1_all is True." ) - if self.use_huber and ( - self.has_pf or self.has_gf or self.relative_f is not None - ): - raise RuntimeError( - "Huber loss is not implemented for force with atom_pref, generalized force and relative force. " - ) - + def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): """Return loss on energy and force. @@ -346,8 +340,8 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): ) else: l_huber_loss = custom_huber_loss( - (atom_pref * force_pred).reshape(-1), - (atom_pref * force_label).reshape(-1), + atom_pref_reshape * force_pred.reshape(-1), + atom_pref_reshape * force_label.reshape(-1), delta=self.huber_delta, ) loss += pref_pf * l_huber_loss diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index bc771b41d4..716dbc98c5 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -214,6 +214,7 @@ def print_summary( name: str, prob: list[float], ) -> None: + return rank = dist.get_rank() if dist.is_initialized() else 0 if rank == 0: print_summary( From b7209c9b19cd38326fa0cece6db2a052aa6935f8 Mon Sep 17 00:00:00 2001 From: anyangml Date: Tue, 10 Mar 2026 09:56:39 +0000 Subject: [PATCH 02/19] fix: register arg --- deepmd/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/main.py b/deepmd/main.py index 68b4e161ef..df8700e136 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -943,6 +943,7 @@ def main(args: Optional[list[str]] = None) -> None: "convert-from", "train-nvnmd", "change-bias", + "grad-probe", ): deepmd_main = BACKENDS[args.backend]().entry_point_hook elif args.command is None: From 5b5207057de8478ced5fd58584c3f701902a7be6 Mon Sep 17 00:00:00 2001 From: anyangml Date: Wed, 11 Mar 2026 03:09:23 +0000 Subject: [PATCH 03/19] fix: use start lr --- deepmd/pt/entrypoints/main.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 887f441352..3d6e9f0304 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -516,6 +516,10 @@ def change_bias( def grad_probe(FLAGS) -> None: """Compute per-task descriptor gradient vectors for multitask conflict analysis.""" import numpy as np + from deepmd.common import j_loader + from deepmd.utils.compat import update_deepmd_input + from deepmd.pt.utils.multi_task import preprocess_shared_params + from deepmd.utils.argcheck import normalize config = j_loader(FLAGS.INPUT) config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") @@ -549,7 +553,7 @@ def grad_probe(FLAGS) -> None: input_dict, label_dict, _ = trainer.get_data( is_train=True, task_key=task_key ) - cur_lr = trainer.lr_schedule.value(0) + cur_lr = config["learning_rate"]["start_lr"] _, loss, _ = trainer.wrapper( **input_dict, cur_lr=cur_lr, label=label_dict, task_key=task_key ) From 00669b7c60d068771b120a0d7b0202f17c2ba75b Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Sat, 4 Apr 2026 10:14:05 +0800 Subject: [PATCH 04/19] feat: add fparam and case_embd arguments to descriptor forward methods for multitask training support --- deepmd/pt/model/descriptor/dpa1.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 97b1e29da3..82944442cc 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -656,6 +656,8 @@ def forward( nlist: torch.Tensor, mapping: Optional[torch.Tensor] = None, comm_dict: Optional[dict[str, torch.Tensor]] = None, + fparam: Optional[torch.Tensor] = None, + case_embd: Optional[torch.Tensor] = None, ): """Compute the descriptor. @@ -671,6 +673,10 @@ def forward( The index mapping, not required by this descriptor. comm_dict The data needed for communication for parallel inference. + fparam + The frame-level parameters. shape: nf x nfparam + case_embd + The case (dataset) embedding for multitask training with shared fitting. shape: nf x dim_case_embd Returns ------- From ab2bf084d532048796b2b89bd9c18c1b26c639a3 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Sat, 4 Apr 2026 14:58:22 +0800 Subject: [PATCH 05/19] feat: add fparam and case_embd support to descriptor compute methods and update training state dictionary loading logic --- deepmd/pt/train/training.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 7e0761915f..f374fc52c5 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -559,12 +559,13 @@ def collect_single_finetune_params( for kk in rm_list: state_dict.pop(kk) state_dict["_extra_state"]["model_params"] = old_model_params - out_shape_list = [ - "model.Default.atomic_model.out_bias", - "model.Default.atomic_model.out_std", - ] + out_shape_list = [] + for model_key in self.model_keys: + out_shape_list.append(f"model.{model_key}.atomic_model.out_bias") + out_shape_list.append(f"model.{model_key}.atomic_model.out_std") for kk in out_shape_list: - state_dict[kk] = state_dict[kk][:1, :, :1] + if kk in state_dict: + state_dict[kk] = state_dict[kk][:1, :, :1] self.wrapper.load_state_dict(state_dict) # change bias for fine-tuning From ff60390b5da1e4f2d2a91b95fa1d5fd6ffbbdc05 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 14 Apr 2026 13:11:52 +0800 Subject: [PATCH 06/19] support similarity std --- deepmd/pt/entrypoints/main.py | 137 +++++++++++++++++++++++++++------- deepmd/pt/loss/ener.py | 4 +- 2 files changed, 111 insertions(+), 30 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 3d6e9f0304..4ce55b55be 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -516,10 +516,11 @@ def change_bias( def grad_probe(FLAGS) -> None: """Compute per-task descriptor gradient vectors for multitask conflict analysis.""" import numpy as np + from deepmd.common import j_loader - from deepmd.utils.compat import update_deepmd_input from deepmd.pt.utils.multi_task import preprocess_shared_params from deepmd.utils.argcheck import normalize + from deepmd.utils.compat import update_deepmd_input config = j_loader(FLAGS.INPUT) config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") @@ -546,52 +547,132 @@ def grad_probe(FLAGS) -> None: param_names: list | None = None param_shapes: list | None = None - for task_key in trainer.model_keys: - trainer.optimizer.zero_grad(set_to_none=True) - - for _ in range(FLAGS.nbatches): + # Initialize trackers + task_batch_norms = {k: [] for k in trainer.model_keys} + task_accum_grads = {k: None for k in trainer.model_keys} + pairwise_sims = { + (k1, k2): [] + for i, k1 in enumerate(trainer.model_keys) + for j, k2 in enumerate(trainer.model_keys) + if i < j + } + + cur_lr = config["learning_rate"]["start_lr"] + + for b in range(FLAGS.nbatches): + batch_grads = {} + for task_key in trainer.model_keys: + trainer.optimizer.zero_grad(set_to_none=True) input_dict, label_dict, _ = trainer.get_data( is_train=True, task_key=task_key ) - cur_lr = config["learning_rate"]["start_lr"] _, loss, _ = trainer.wrapper( **input_dict, cur_lr=cur_lr, label=label_dict, task_key=task_key ) loss.backward() - model = module.model[task_key] - descriptor = model.get_descriptor() - grads, names, shapes = [], [], [] - for name, param in descriptor.named_parameters(): - if param.grad is not None: - g = param.grad.detach().cpu().float().numpy().ravel() - else: - g = np.zeros(param.numel(), dtype=np.float32) - grads.append(g) - names.append(name) - shapes.append(list(param.shape)) + model = module.model[task_key] + descriptor = model.get_descriptor() + + # Extract current batch gradient + current_grads = [] + for name, param in descriptor.named_parameters(): + if param.grad is not None: + current_grads.append( + param.grad.detach().cpu().float().numpy().ravel() + ) - grad_vec = np.concatenate(grads) if grads else np.array([], dtype=np.float32) - if FLAGS.nbatches > 1: - grad_vec /= FLAGS.nbatches + if current_grads: + grad_vec = np.concatenate(current_grads) + task_batch_norms[task_key].append(np.linalg.norm(grad_vec)) + if task_accum_grads[task_key] is None: + task_accum_grads[task_key] = grad_vec + else: + task_accum_grads[task_key] += grad_vec + batch_grads[task_key] = grad_vec + + if param_names is None: + # get names and shapes from descriptor parameters + param_names = [ + n for n, p in descriptor.named_parameters() if p.requires_grad + ] + param_shapes = [ + list(p.shape) + for n, p in descriptor.named_parameters() + if p.requires_grad + ] + + # Compute pairwise similarity for this batch + for k1, k2 in pairwise_sims.keys(): + if k1 in batch_grads and k2 in batch_grads: + g1 = batch_grads[k1] + g2 = batch_grads[k2] + norm1 = np.linalg.norm(g1) + norm2 = np.linalg.norm(g2) + if norm1 > 1e-10 and norm2 > 1e-10: + sim = np.dot(g1, g2) / (norm1 * norm2) + pairwise_sims[(k1, k2)].append(sim) + + # Compute final statistics and log for each task + for task_key in trainer.model_keys: + accumulated_grads = task_accum_grads[task_key] + norms = task_batch_norms[task_key] + if accumulated_grads is not None: + avg_grad_vec = accumulated_grads / FLAGS.nbatches + norm_mean = np.mean(norms) + norm_std = np.std(norms) + else: + avg_grad_vec = np.array([], dtype=np.float32) + norm_mean = 0.0 + norm_std = 0.0 - grads_per_task[task_key] = grad_vec - if param_names is None: - param_names = names - param_shapes = shapes + grads_per_task[task_key] = avg_grad_vec + grads_per_task[f"{task_key}_norm_mean"] = norm_mean + grads_per_task[f"{task_key}_norm_std"] = norm_std - trainer.optimizer.zero_grad(set_to_none=True) log.info( - "Task '%s': descriptor gradient collected, norm=%.4e", + "Task '%s': collected %d batches. Avg Grad Norm=%.4e, Mean(Batch Norms)=%.4e, Std(Batch Norms)=%.4e", task_key, - float(np.linalg.norm(grad_vec)), + len(norms), + float(np.linalg.norm(avg_grad_vec)), + float(norm_mean), + float(norm_std), ) + # Compute pairwise similarity statistics + for (k1, k2), sims in pairwise_sims.items(): + if sims: + sim_mean = np.mean(sims) + sim_std = np.std(sims) + + # Compute similarity of accumulated gradients (reproducible from avg_grad_vec) + g1_avg = grads_per_task[k1] + g2_avg = grads_per_task[k2] + norm1_avg = np.linalg.norm(g1_avg) + norm2_avg = np.linalg.norm(g2_avg) + if norm1_avg > 1e-10 and norm2_avg > 1e-10: + sim_accum = np.dot(g1_avg, g2_avg) / (norm1_avg * norm2_avg) + else: + sim_accum = 0.0 + + grads_per_task[f"sim_{k1}_{k2}_mean"] = sim_mean + grads_per_task[f"sim_{k1}_{k2}_std"] = sim_std + grads_per_task[f"sim_{k1}_{k2}_accum"] = sim_accum + + log.info( + "Similarity '%s' vs '%s': Mean=%.4f, Std=%.4f, Accum=%.4f", + k1, + k2, + float(sim_mean), + float(sim_std), + float(sim_accum), + ) + save_dict = {f"grads_{k}": v for k, v in grads_per_task.items()} save_dict["param_names"] = np.array(param_names, dtype=object) save_dict["param_shapes"] = np.array(param_shapes, dtype=object) np.savez(FLAGS.output, **save_dict) - log.info("Descriptor gradient vectors saved to: %s", FLAGS.output) + log.info("Descriptor gradient vectors and similarity stats saved to: %s", FLAGS.output) diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py index 9df0164a1a..e3df7dfbe2 100644 --- a/deepmd/pt/loss/ener.py +++ b/deepmd/pt/loss/ener.py @@ -214,7 +214,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): l2_ener_loss.detach(), find_energy ) if not self.use_huber: - loss += atom_norm * (pref_e * l2_ener_loss) + loss += atom_norm**2 * (pref_e * l2_ener_loss) else: l_huber_loss = custom_huber_loss( atom_norm * model_pred["energy"], @@ -399,7 +399,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): l2_virial_loss.detach(), find_virial ) if not self.use_huber: - loss += atom_norm * (pref_v * l2_virial_loss) + loss += atom_norm**2 * (pref_v * l2_virial_loss) else: l_huber_loss = custom_huber_loss( atom_norm * model_pred["virial"].reshape(-1), From 77e44190ad876ef20263ccb1227b7b0e498c21ae Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 20 Apr 2026 15:40:46 +0800 Subject: [PATCH 07/19] update similarity calculation using dot product --- deepmd/pt/entrypoints/main.py | 41 ++++++++++++++--------------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 4ce55b55be..b985378de4 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -602,16 +602,12 @@ def grad_probe(FLAGS) -> None: if p.requires_grad ] - # Compute pairwise similarity for this batch + # Compute pairwise dot products for this batch for k1, k2 in pairwise_sims.keys(): if k1 in batch_grads and k2 in batch_grads: g1 = batch_grads[k1] g2 = batch_grads[k2] - norm1 = np.linalg.norm(g1) - norm2 = np.linalg.norm(g2) - if norm1 > 1e-10 and norm2 > 1e-10: - sim = np.dot(g1, g2) / (norm1 * norm2) - pairwise_sims[(k1, k2)].append(sim) + pairwise_sims[(k1, k2)].append(float(np.dot(g1, g2))) # Compute final statistics and log for each task for task_key in trainer.model_keys: @@ -639,33 +635,28 @@ def grad_probe(FLAGS) -> None: float(norm_std), ) - # Compute pairwise similarity statistics - for (k1, k2), sims in pairwise_sims.items(): - if sims: - sim_mean = np.mean(sims) - sim_std = np.std(sims) + # Compute pairwise dot product statistics + for (k1, k2), dots in pairwise_sims.items(): + if dots: + dot_mean = np.mean(dots) + dot_std = np.std(dots) - # Compute similarity of accumulated gradients (reproducible from avg_grad_vec) + # Dot product of accumulated average gradient vectors g1_avg = grads_per_task[k1] g2_avg = grads_per_task[k2] - norm1_avg = np.linalg.norm(g1_avg) - norm2_avg = np.linalg.norm(g2_avg) - if norm1_avg > 1e-10 and norm2_avg > 1e-10: - sim_accum = np.dot(g1_avg, g2_avg) / (norm1_avg * norm2_avg) - else: - sim_accum = 0.0 + dot_accum = float(np.dot(g1_avg, g2_avg)) - grads_per_task[f"sim_{k1}_{k2}_mean"] = sim_mean - grads_per_task[f"sim_{k1}_{k2}_std"] = sim_std - grads_per_task[f"sim_{k1}_{k2}_accum"] = sim_accum + grads_per_task[f"dot_{k1}_{k2}_mean"] = dot_mean + grads_per_task[f"dot_{k1}_{k2}_std"] = dot_std + grads_per_task[f"dot_{k1}_{k2}_accum"] = dot_accum log.info( - "Similarity '%s' vs '%s': Mean=%.4f, Std=%.4f, Accum=%.4f", + "Grad dot product '%s' vs '%s': Mean=%.4e, Std=%.4e, Accum=%.4e", k1, k2, - float(sim_mean), - float(sim_std), - float(sim_accum), + float(dot_mean), + float(dot_std), + float(dot_accum), ) save_dict = {f"grads_{k}": v for k, v in grads_per_task.items()} From cf84059ffdac702d654a94c5f61ee2e118ba47e6 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 20 Apr 2026 16:07:42 +0800 Subject: [PATCH 08/19] fix sampling method in grad-prob --- deepmd/main.py | 10 ++++++ deepmd/pt/entrypoints/main.py | 57 +++++++++++++++++++++++------------ 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/deepmd/main.py b/deepmd/main.py index df8700e136..a2ca4e25d2 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -878,6 +878,16 @@ def main_parser() -> argparse.ArgumentParser: default=1, help="Number of batches per task to average gradient over.", ) + parser_grad_probe.add_argument( + "-k", + "--accumulate-k", + type=int, + default=1, + dest="accumulate_k", + help="Accumulate K batches before computing each dot product. " + "Total batches collected = nbatches * accumulate_k. " + "Larger K reduces norm variance and improves SNR of mean/std.", + ) return parser diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index b985378de4..0834a238c9 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -547,20 +547,25 @@ def grad_probe(FLAGS) -> None: param_names: list | None = None param_shapes: list | None = None + accumulate_k = getattr(FLAGS, "accumulate_k", 1) + total_batches = FLAGS.nbatches * accumulate_k + # Initialize trackers task_batch_norms = {k: [] for k in trainer.model_keys} task_accum_grads = {k: None for k in trainer.model_keys} - pairwise_sims = { + pairwise_dots = { (k1, k2): [] for i, k1 in enumerate(trainer.model_keys) for j, k2 in enumerate(trainer.model_keys) if i < j } + # Accumulators within each group of K batches + group_grads = {k: None for k in trainer.model_keys} + cur_lr = config["learning_rate"]["start_lr"] - for b in range(FLAGS.nbatches): - batch_grads = {} + for b in range(total_batches): for task_key in trainer.model_keys: trainer.optimizer.zero_grad(set_to_none=True) input_dict, label_dict, _ = trainer.get_data( @@ -574,7 +579,6 @@ def grad_probe(FLAGS) -> None: model = module.model[task_key] descriptor = model.get_descriptor() - # Extract current batch gradient current_grads = [] for name, param in descriptor.named_parameters(): if param.grad is not None: @@ -584,15 +588,17 @@ def grad_probe(FLAGS) -> None: if current_grads: grad_vec = np.concatenate(current_grads) - task_batch_norms[task_key].append(np.linalg.norm(grad_vec)) if task_accum_grads[task_key] is None: - task_accum_grads[task_key] = grad_vec + task_accum_grads[task_key] = grad_vec.copy() else: task_accum_grads[task_key] += grad_vec - batch_grads[task_key] = grad_vec + + if group_grads[task_key] is None: + group_grads[task_key] = grad_vec.copy() + else: + group_grads[task_key] += grad_vec if param_names is None: - # get names and shapes from descriptor parameters param_names = [ n for n, p in descriptor.named_parameters() if p.requires_grad ] @@ -602,19 +608,26 @@ def grad_probe(FLAGS) -> None: if p.requires_grad ] - # Compute pairwise dot products for this batch - for k1, k2 in pairwise_sims.keys(): - if k1 in batch_grads and k2 in batch_grads: - g1 = batch_grads[k1] - g2 = batch_grads[k2] - pairwise_sims[(k1, k2)].append(float(np.dot(g1, g2))) + # At the end of each group of K batches, record norms and dot products then reset + if (b + 1) % accumulate_k == 0: + for task_key in trainer.model_keys: + if group_grads[task_key] is not None: + task_batch_norms[task_key].append( + np.linalg.norm(group_grads[task_key]) + ) + for k1, k2 in pairwise_dots.keys(): + g1 = group_grads.get(k1) + g2 = group_grads.get(k2) + if g1 is not None and g2 is not None: + pairwise_dots[(k1, k2)].append(float(np.dot(g1, g2))) + group_grads = {k: None for k in trainer.model_keys} # Compute final statistics and log for each task for task_key in trainer.model_keys: accumulated_grads = task_accum_grads[task_key] norms = task_batch_norms[task_key] if accumulated_grads is not None: - avg_grad_vec = accumulated_grads / FLAGS.nbatches + avg_grad_vec = accumulated_grads / total_batches norm_mean = np.mean(norms) norm_std = np.std(norms) else: @@ -627,16 +640,19 @@ def grad_probe(FLAGS) -> None: grads_per_task[f"{task_key}_norm_std"] = norm_std log.info( - "Task '%s': collected %d batches. Avg Grad Norm=%.4e, Mean(Batch Norms)=%.4e, Std(Batch Norms)=%.4e", + "Task '%s': collected %d batches (groups=%d, k=%d). " + "Avg Grad Norm=%.4e, Mean(Batch Norms)=%.4e, Std(Batch Norms)=%.4e", task_key, - len(norms), + total_batches, + FLAGS.nbatches, + accumulate_k, float(np.linalg.norm(avg_grad_vec)), float(norm_mean), float(norm_std), ) # Compute pairwise dot product statistics - for (k1, k2), dots in pairwise_sims.items(): + for (k1, k2), dots in pairwise_dots.items(): if dots: dot_mean = np.mean(dots) dot_std = np.std(dots) @@ -651,12 +667,15 @@ def grad_probe(FLAGS) -> None: grads_per_task[f"dot_{k1}_{k2}_accum"] = dot_accum log.info( - "Grad dot product '%s' vs '%s': Mean=%.4e, Std=%.4e, Accum=%.4e", + "Grad dot product '%s' vs '%s': Mean=%.4e, Std=%.4e, Accum=%.4e " + "(SNR=%.2f, n_groups=%d)", k1, k2, float(dot_mean), float(dot_std), float(dot_accum), + abs(dot_mean) / dot_std if dot_std > 0 else float("inf"), + len(dots), ) save_dict = {f"grads_{k}": v for k, v in grads_per_task.items()} From c9d6b27ea867bd673cbd8c8bc21c142bd2577733 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 20 Apr 2026 16:52:18 +0800 Subject: [PATCH 09/19] fix dot_mean normalization --- deepmd/main.py | 4 ++-- deepmd/pt/entrypoints/main.py | 19 ++++++++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/deepmd/main.py b/deepmd/main.py index a2ca4e25d2..0cea509295 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -884,8 +884,8 @@ def main_parser() -> argparse.ArgumentParser: type=int, default=1, dest="accumulate_k", - help="Accumulate K batches before computing each dot product. " - "Total batches collected = nbatches * accumulate_k. " + help="Group size for gradient accumulation before computing each dot product. " + "Total batches collected = nbatches; number of dot product samples = nbatches // k. " "Larger K reduces norm variance and improves SNR of mean/std.", ) return parser diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 0834a238c9..398ed940f2 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -548,7 +548,7 @@ def grad_probe(FLAGS) -> None: param_shapes: list | None = None accumulate_k = getattr(FLAGS, "accumulate_k", 1) - total_batches = FLAGS.nbatches * accumulate_k + total_batches = FLAGS.nbatches # Initialize trackers task_batch_norms = {k: [] for k in trainer.model_keys} @@ -610,14 +610,19 @@ def grad_probe(FLAGS) -> None: # At the end of each group of K batches, record norms and dot products then reset if (b + 1) % accumulate_k == 0: + # Normalize to per-batch mean so scale is independent of k + group_means = { + k: g / accumulate_k if g is not None else None + for k, g in group_grads.items() + } for task_key in trainer.model_keys: - if group_grads[task_key] is not None: + if group_means[task_key] is not None: task_batch_norms[task_key].append( - np.linalg.norm(group_grads[task_key]) + np.linalg.norm(group_means[task_key]) ) for k1, k2 in pairwise_dots.keys(): - g1 = group_grads.get(k1) - g2 = group_grads.get(k2) + g1 = group_means.get(k1) + g2 = group_means.get(k2) if g1 is not None and g2 is not None: pairwise_dots[(k1, k2)].append(float(np.dot(g1, g2))) group_grads = {k: None for k in trainer.model_keys} @@ -641,10 +646,10 @@ def grad_probe(FLAGS) -> None: log.info( "Task '%s': collected %d batches (groups=%d, k=%d). " - "Avg Grad Norm=%.4e, Mean(Batch Norms)=%.4e, Std(Batch Norms)=%.4e", + "Avg Grad Norm=%.4e, Mean(Group Norms)=%.4e, Std(Group Norms)=%.4e", task_key, total_batches, - FLAGS.nbatches, + total_batches // accumulate_k, accumulate_k, float(np.linalg.norm(avg_grad_vec)), float(norm_mean), From b29fe5f20969818d0d7ed354c5fb25d6fa6aef2e Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 23 Apr 2026 14:02:37 +0800 Subject: [PATCH 10/19] add metrics --- deepmd/pt/entrypoints/main.py | 75 ++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 398ed940f2..c61ecbc614 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -656,20 +656,72 @@ def grad_probe(FLAGS) -> None: float(norm_std), ) - # Compute pairwise dot product statistics + # Compute pairwise dot product and Pearson correlation statistics for (k1, k2), dots in pairwise_dots.items(): if dots: dot_mean = np.mean(dots) dot_std = np.std(dots) - # Dot product of accumulated average gradient vectors g1_avg = grads_per_task[k1] g2_avg = grads_per_task[k2] dot_accum = float(np.dot(g1_avg, g2_avg)) + # Pearson correlation across groups: + # cov_sum = Σ_j Cov(g1_j, g2_j) = E[dot(g1,g2)] - dot(ḡ1, ḡ2) + # var_sum = Σ_j Var(g_j) = E[||g||²] - ||ḡ||² + # pearson = cov_sum / sqrt(var1_sum * var2_sum) + cov_sum = float(dot_mean - dot_accum) + norms_1 = np.array(task_batch_norms[k1]) + norms_2 = np.array(task_batch_norms[k2]) + var1_sum = float(np.mean(norms_1**2) - np.dot(g1_avg, g1_avg)) + var2_sum = float(np.mean(norms_2**2) - np.dot(g2_avg, g2_avg)) + denom = np.sqrt(max(var1_sum, 0.0) * max(var2_sum, 0.0)) + pearson = float(cov_sum / denom) if denom > 1e-12 else float("nan") + + # Group-level weighted cosine similarity: + # per-group cos_i = dot_i / (||g1_i|| * ||g2_i||) + # weight w_i = ||g1_i|| * ||g2_i|| so high-norm groups dominate + norm_products = norms_1 * norms_2 # (n_groups,) weight per group + valid = norm_products > 1e-12 + cos_per_group = np.where( + valid, + np.array(dots) / np.where(valid, norm_products, 1.0), + 0.0, + ) + norm_product_sum = float(np.sum(norm_products)) + if norm_product_sum > 1e-12: + cos_group_mean = float( + np.sum(norm_products * cos_per_group) / norm_product_sum + ) + cos_group_std = float( + np.sqrt( + np.sum(norm_products * (cos_per_group - cos_group_mean) ** 2) + / norm_product_sum + ) + ) + else: + cos_group_mean = float("nan") + cos_group_std = float("nan") + + # Global-level cosine similarity: + # cosine of accumulated average gradient vectors; + # large-norm batches naturally dominate the direction of the average + norm_g1 = float(np.linalg.norm(g1_avg)) + norm_g2 = float(np.linalg.norm(g2_avg)) + denom_global = norm_g1 * norm_g2 + cos_global = ( + float(dot_accum / denom_global) + if denom_global > 1e-12 + else float("nan") + ) + grads_per_task[f"dot_{k1}_{k2}_mean"] = dot_mean grads_per_task[f"dot_{k1}_{k2}_std"] = dot_std grads_per_task[f"dot_{k1}_{k2}_accum"] = dot_accum + grads_per_task[f"pearson_{k1}_{k2}"] = pearson + grads_per_task[f"cos_group_mean_{k1}_{k2}"] = cos_group_mean + grads_per_task[f"cos_group_std_{k1}_{k2}"] = cos_group_std + grads_per_task[f"cos_global_{k1}_{k2}"] = cos_global log.info( "Grad dot product '%s' vs '%s': Mean=%.4e, Std=%.4e, Accum=%.4e " @@ -682,6 +734,25 @@ def grad_probe(FLAGS) -> None: abs(dot_mean) / dot_std if dot_std > 0 else float("inf"), len(dots), ) + log.info( + "Pearson correlation '%s' vs '%s': %.4f " + "(cov_sum=%.4e, var1=%.4e, var2=%.4e)", + k1, + k2, + pearson, + cov_sum, + var1_sum, + var2_sum, + ) + log.info( + "Cosine similarity '%s' vs '%s': " + "group_mean=%.4f, group_std=%.4f, global=%.4f", + k1, + k2, + cos_group_mean, + cos_group_std, + cos_global, + ) save_dict = {f"grads_{k}": v for k, v in grads_per_task.items()} save_dict["param_names"] = np.array(param_names, dtype=object) From 0f56a5f01f57b4bea12548307c626ea89f8a10b9 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 27 Apr 2026 15:52:39 +0800 Subject: [PATCH 11/19] fix: revert covariance --- deepmd/pt/entrypoints/main.py | 81 ++++++++++------------------------- 1 file changed, 22 insertions(+), 59 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index c61ecbc614..d7bbbdc072 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -656,102 +656,65 @@ def grad_probe(FLAGS) -> None: float(norm_std), ) - # Compute pairwise dot product and Pearson correlation statistics + # Compute pairwise dot product and cosine similarity statistics for (k1, k2), dots in pairwise_dots.items(): if dots: - dot_mean = np.mean(dots) - dot_std = np.std(dots) + dot_mean = float(np.mean(dots)) + dot_std = float(np.std(dots)) g1_avg = grads_per_task[k1] g2_avg = grads_per_task[k2] - dot_accum = float(np.dot(g1_avg, g2_avg)) + dot_global = float(np.dot(g1_avg, g2_avg)) - # Pearson correlation across groups: - # cov_sum = Σ_j Cov(g1_j, g2_j) = E[dot(g1,g2)] - dot(ḡ1, ḡ2) - # var_sum = Σ_j Var(g_j) = E[||g||²] - ||ḡ||² - # pearson = cov_sum / sqrt(var1_sum * var2_sum) - cov_sum = float(dot_mean - dot_accum) + # Group-level cosine: per-group cos_i = dot_i / (||g1_i|| * ||g2_i||) norms_1 = np.array(task_batch_norms[k1]) norms_2 = np.array(task_batch_norms[k2]) - var1_sum = float(np.mean(norms_1**2) - np.dot(g1_avg, g1_avg)) - var2_sum = float(np.mean(norms_2**2) - np.dot(g2_avg, g2_avg)) - denom = np.sqrt(max(var1_sum, 0.0) * max(var2_sum, 0.0)) - pearson = float(cov_sum / denom) if denom > 1e-12 else float("nan") - - # Group-level weighted cosine similarity: - # per-group cos_i = dot_i / (||g1_i|| * ||g2_i||) - # weight w_i = ||g1_i|| * ||g2_i|| so high-norm groups dominate - norm_products = norms_1 * norms_2 # (n_groups,) weight per group + norm_products = norms_1 * norms_2 valid = norm_products > 1e-12 cos_per_group = np.where( valid, np.array(dots) / np.where(valid, norm_products, 1.0), - 0.0, + float("nan"), ) - norm_product_sum = float(np.sum(norm_products)) - if norm_product_sum > 1e-12: - cos_group_mean = float( - np.sum(norm_products * cos_per_group) / norm_product_sum - ) - cos_group_std = float( - np.sqrt( - np.sum(norm_products * (cos_per_group - cos_group_mean) ** 2) - / norm_product_sum - ) - ) + valid_cos = cos_per_group[valid] + if len(valid_cos) > 0: + cos_group_mean = float(np.mean(valid_cos)) + cos_group_std = float(np.std(valid_cos)) else: cos_group_mean = float("nan") cos_group_std = float("nan") - # Global-level cosine similarity: - # cosine of accumulated average gradient vectors; - # large-norm batches naturally dominate the direction of the average + # Global-level cosine: cosine of accumulated average gradient vectors norm_g1 = float(np.linalg.norm(g1_avg)) norm_g2 = float(np.linalg.norm(g2_avg)) denom_global = norm_g1 * norm_g2 cos_global = ( - float(dot_accum / denom_global) + float(dot_global / denom_global) if denom_global > 1e-12 else float("nan") ) grads_per_task[f"dot_{k1}_{k2}_mean"] = dot_mean grads_per_task[f"dot_{k1}_{k2}_std"] = dot_std - grads_per_task[f"dot_{k1}_{k2}_accum"] = dot_accum - grads_per_task[f"pearson_{k1}_{k2}"] = pearson + grads_per_task[f"dot_{k1}_{k2}_global"] = dot_global grads_per_task[f"cos_group_mean_{k1}_{k2}"] = cos_group_mean grads_per_task[f"cos_group_std_{k1}_{k2}"] = cos_group_std grads_per_task[f"cos_global_{k1}_{k2}"] = cos_global log.info( - "Grad dot product '%s' vs '%s': Mean=%.4e, Std=%.4e, Accum=%.4e " - "(SNR=%.2f, n_groups=%d)", - k1, - k2, - float(dot_mean), - float(dot_std), - float(dot_accum), - abs(dot_mean) / dot_std if dot_std > 0 else float("inf"), - len(dots), - ) - log.info( - "Pearson correlation '%s' vs '%s': %.4f " - "(cov_sum=%.4e, var1=%.4e, var2=%.4e)", - k1, - k2, - pearson, - cov_sum, - var1_sum, - var2_sum, - ) - log.info( - "Cosine similarity '%s' vs '%s': " - "group_mean=%.4f, group_std=%.4f, global=%.4f", + "Grad similarity '%s' vs '%s': " + "dot_mean=%.4e, dot_std=%.4e, dot_global=%.4e, " + "cos_group_mean=%.4f, cos_group_std=%.4f, cos_global=%.4f " + "(n_groups=%d)", k1, k2, + dot_mean, + dot_std, + dot_global, cos_group_mean, cos_group_std, cos_global, + len(dots), ) save_dict = {f"grads_{k}": v for k, v in grads_per_task.items()} From 95d33ca9f9f8e10c5ea07bca37c0941d798b350d Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Wed, 6 May 2026 12:55:55 +0800 Subject: [PATCH 12/19] try: pc-grad --- deepmd/pt/train/training.py | 89 +++++++++++++++++++++++++++++++++++-- deepmd/utils/argcheck.py | 1 + 2 files changed, 86 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index f374fc52c5..483ccef9af 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -626,6 +626,10 @@ def single_model_finetune( assert sum_prob > 0.0, "Sum of model prob must be larger than 0!" self.model_prob = self.model_prob / sum_prob + self.use_pcgrad = self.multi_task and training_params.get("use_pcgrad", False) + if self.use_pcgrad: + log.info("PCGrad enabled: descriptor gradients will be projected each step.") + # Multi-task share params if shared_links is not None: _data_stat_protect = np.array( @@ -764,10 +768,87 @@ def step(_step_id, task_key="Default") -> None: pref_lr = _lr.start_lr else: pref_lr = cur_lr - model_pred, loss, more_loss = self.wrapper( - **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key - ) - loss.backward() + if self.use_pcgrad: + module = ( + self.wrapper.module + if hasattr(self.wrapper, "module") + else self.wrapper + ) + descriptor = module.model[self.model_keys[0]].get_descriptor() + desc_params = list(descriptor.parameters()) + desc_param_ids = {id(p) for p in desc_params} + all_params = list(self.wrapper.parameters()) + + task_grads = {} + more_loss_per_task = {} + for tk in self.model_keys: + self.optimizer.zero_grad(set_to_none=True) + in_d, lbl_d, _ = self.get_data(is_train=True, task_key=tk) + _, loss, more_loss = self.wrapper( + **in_d, cur_lr=pref_lr, label=lbl_d, task_key=tk + ) + loss.backward() + task_grads[tk] = { + id(p): p.grad.clone() if p.grad is not None else None + for p in all_params + } + more_loss_per_task[tk] = more_loss + + k0, k1 = self.model_keys[0], self.model_keys[1] + + # PCGrad projection: only on descriptor params + desc_pids = [ + pid for pid in desc_param_ids + if task_grads[k0].get(pid) is not None + and task_grads[k1].get(pid) is not None + ] + if desc_pids: + g0 = torch.cat([task_grads[k0][pid].flatten() for pid in desc_pids]) + g1 = torch.cat([task_grads[k1][pid].flatten() for pid in desc_pids]) + dot = (g0 * g1).sum() + projected = dot < 0 + if projected: + g0_proj = g0 - dot / (g1 * g1).sum().clamp(min=1e-12) * g1 + g1_proj = g1 - dot / (g0 * g0).sum().clamp(min=1e-12) * g0 + offset = 0 + for pid in desc_pids: + numel = task_grads[k0][pid].numel() + shape = task_grads[k0][pid].shape + task_grads[k0][pid] = g0_proj[offset:offset + numel].reshape(shape) + task_grads[k1][pid] = g1_proj[offset:offset + numel].reshape(shape) + offset += numel + if self.rank == 0 and (_step_id + 1) % self.disp_freq == 0: + cos_sim = dot / ( + g0.norm() * g1.norm() + ).clamp(min=1e-12) + log.info( + "PCGrad step %d: desc dot=%.4e, cos_sim=%.4f, projected=%s", + _step_id + 1, + dot.item(), + cos_sim.item(), + projected, + ) + + # Set final grads: + # - descriptor: projected sum (from above) + # - all other params: plain sum of two tasks' grads + self.optimizer.zero_grad(set_to_none=True) + for p in all_params: + pid = id(p) + g0p, g1p = task_grads[k0][pid], task_grads[k1][pid] + if g0p is not None and g1p is not None: + p.grad = g0p + g1p + elif g0p is not None: + p.grad = g0p + elif g1p is not None: + p.grad = g1p + + more_loss = more_loss_per_task[task_key] + else: + _, loss, more_loss = self.wrapper( + **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key + ) + loss.backward() if self.gradient_max_norm > 0.0: torch.nn.utils.clip_grad_norm_( self.wrapper.parameters(), diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index bbb0c01a4c..3ea1e023d8 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3386,6 +3386,7 @@ def training_args( if not multi_task else [ Argument("model_prob", dict, optional=True, default={}, doc=doc_model_prob), + Argument("use_pcgrad", bool, optional=True, default=False, doc="Apply PCGrad gradient surgery on the shared descriptor parameters in multi-task training."), Argument("data_dict", dict, data_args, repeat=True, doc=doc_data_dict), ] ) From b9464e899cfb3472e27fa7cfa058542467e65b4e Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 7 May 2026 09:58:33 +0800 Subject: [PATCH 13/19] try: prob-pcgrad --- deepmd/main.py | 7 +++++++ deepmd/pt/entrypoints/main.py | 39 ++++++++++++++++++++++++++--------- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/deepmd/main.py b/deepmd/main.py index 0cea509295..65d417b6b4 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -888,6 +888,13 @@ def main_parser() -> argparse.ArgumentParser: "Total batches collected = nbatches; number of dot product samples = nbatches // k. " "Larger K reduces norm variance and improves SNR of mean/std.", ) + parser_grad_probe.add_argument( + "--pcgrad", + action="store_true", + default=False, + help="Apply PCGrad projection to descriptor gradients before analysis. " + "Reports similarity of the projected vectors instead of raw gradients.", + ) return parser diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index d7bbbdc072..a2489aabd7 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -549,6 +549,9 @@ def grad_probe(FLAGS) -> None: accumulate_k = getattr(FLAGS, "accumulate_k", 1) total_batches = FLAGS.nbatches + use_pcgrad = getattr(FLAGS, "pcgrad", False) + if use_pcgrad: + log.info("grad_probe: PCGrad projection enabled — analysing projected descriptor gradients.") # Initialize trackers task_batch_norms = {k: [] for k in trainer.model_keys} @@ -566,6 +569,8 @@ def grad_probe(FLAGS) -> None: cur_lr = config["learning_rate"]["start_lr"] for b in range(total_batches): + # Collect raw descriptor grad vectors for all tasks this batch + batch_grad_vecs = {} for task_key in trainer.model_keys: trainer.optimizer.zero_grad(set_to_none=True) input_dict, label_dict, _ = trainer.get_data( @@ -587,16 +592,7 @@ def grad_probe(FLAGS) -> None: ) if current_grads: - grad_vec = np.concatenate(current_grads) - if task_accum_grads[task_key] is None: - task_accum_grads[task_key] = grad_vec.copy() - else: - task_accum_grads[task_key] += grad_vec - - if group_grads[task_key] is None: - group_grads[task_key] = grad_vec.copy() - else: - group_grads[task_key] += grad_vec + batch_grad_vecs[task_key] = np.concatenate(current_grads) if param_names is None: param_names = [ @@ -608,6 +604,29 @@ def grad_probe(FLAGS) -> None: if p.requires_grad ] + # Apply PCGrad projection to descriptor grad vectors if requested + if use_pcgrad and len(batch_grad_vecs) == 2: + keys = list(batch_grad_vecs.keys()) + k0, k1 = keys[0], keys[1] + g0, g1 = batch_grad_vecs[k0], batch_grad_vecs[k1] + dot = float(np.dot(g0, g1)) + if dot < 0: + g0_orig = g0.copy() + batch_grad_vecs[k0] = g0 - dot / max(np.dot(g1, g1), 1e-12) * g1 + batch_grad_vecs[k1] = g1 - dot / max(np.dot(g0_orig, g0_orig), 1e-12) * g0_orig + + # Accumulate (projected) grad vectors + for task_key, grad_vec in batch_grad_vecs.items(): + if task_accum_grads[task_key] is None: + task_accum_grads[task_key] = grad_vec.copy() + else: + task_accum_grads[task_key] += grad_vec + + if group_grads[task_key] is None: + group_grads[task_key] = grad_vec.copy() + else: + group_grads[task_key] += grad_vec + # At the end of each group of K batches, record norms and dot products then reset if (b + 1) % accumulate_k == 0: # Normalize to per-batch mean so scale is independent of k From 8db863d6a899a6302a965469dfec4c3345fb18ed Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 7 May 2026 12:32:26 +0800 Subject: [PATCH 14/19] feat: add dual batch baseline --- deepmd/pt/train/training.py | 86 ++++++++++++++++++------------------- deepmd/utils/argcheck.py | 1 + 2 files changed, 42 insertions(+), 45 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 483ccef9af..98860570bd 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -627,8 +627,13 @@ def single_model_finetune( self.model_prob = self.model_prob / sum_prob self.use_pcgrad = self.multi_task and training_params.get("use_pcgrad", False) + self.use_dual_batch = self.multi_task and ( + self.use_pcgrad or training_params.get("use_dual_batch", False) + ) if self.use_pcgrad: log.info("PCGrad enabled: descriptor gradients will be projected each step.") + elif self.use_dual_batch: + log.info("Dual-batch enabled: all tasks sampled per step, gradients summed without projection.") # Multi-task share params if shared_links is not None: @@ -768,17 +773,8 @@ def step(_step_id, task_key="Default") -> None: pref_lr = _lr.start_lr else: pref_lr = cur_lr - if self.use_pcgrad: - module = ( - self.wrapper.module - if hasattr(self.wrapper, "module") - else self.wrapper - ) - descriptor = module.model[self.model_keys[0]].get_descriptor() - desc_params = list(descriptor.parameters()) - desc_param_ids = {id(p) for p in desc_params} + if self.use_dual_batch: all_params = list(self.wrapper.parameters()) - task_grads = {} more_loss_per_task = {} for tk in self.model_keys: @@ -796,42 +792,42 @@ def step(_step_id, task_key="Default") -> None: k0, k1 = self.model_keys[0], self.model_keys[1] - # PCGrad projection: only on descriptor params - desc_pids = [ - pid for pid in desc_param_ids - if task_grads[k0].get(pid) is not None - and task_grads[k1].get(pid) is not None - ] - if desc_pids: - g0 = torch.cat([task_grads[k0][pid].flatten() for pid in desc_pids]) - g1 = torch.cat([task_grads[k1][pid].flatten() for pid in desc_pids]) - dot = (g0 * g1).sum() - projected = dot < 0 - if projected: - g0_proj = g0 - dot / (g1 * g1).sum().clamp(min=1e-12) * g1 - g1_proj = g1 - dot / (g0 * g0).sum().clamp(min=1e-12) * g0 - offset = 0 - for pid in desc_pids: - numel = task_grads[k0][pid].numel() - shape = task_grads[k0][pid].shape - task_grads[k0][pid] = g0_proj[offset:offset + numel].reshape(shape) - task_grads[k1][pid] = g1_proj[offset:offset + numel].reshape(shape) - offset += numel - if self.rank == 0 and (_step_id + 1) % self.disp_freq == 0: - cos_sim = dot / ( - g0.norm() * g1.norm() - ).clamp(min=1e-12) - log.info( - "PCGrad step %d: desc dot=%.4e, cos_sim=%.4f, projected=%s", - _step_id + 1, - dot.item(), - cos_sim.item(), - projected, - ) + if self.use_pcgrad: + module = ( + self.wrapper.module + if hasattr(self.wrapper, "module") + else self.wrapper + ) + descriptor = module.model[k0].get_descriptor() + desc_param_ids = {id(p) for p in descriptor.parameters()} + desc_pids = [ + pid for pid in desc_param_ids + if task_grads[k0].get(pid) is not None + and task_grads[k1].get(pid) is not None + ] + if desc_pids: + g0 = torch.cat([task_grads[k0][pid].flatten() for pid in desc_pids]) + g1 = torch.cat([task_grads[k1][pid].flatten() for pid in desc_pids]) + dot = (g0 * g1).sum() + projected = dot < 0 + if projected: + g0_proj = g0 - dot / (g1 * g1).sum().clamp(min=1e-12) * g1 + g1_proj = g1 - dot / (g0 * g0).sum().clamp(min=1e-12) * g0 + offset = 0 + for pid in desc_pids: + numel = task_grads[k0][pid].numel() + shape = task_grads[k0][pid].shape + task_grads[k0][pid] = g0_proj[offset:offset + numel].reshape(shape) + task_grads[k1][pid] = g1_proj[offset:offset + numel].reshape(shape) + offset += numel + if self.rank == 0 and (_step_id + 1) % self.disp_freq == 0: + cos_sim = dot / (g0.norm() * g1.norm()).clamp(min=1e-12) + log.info( + "PCGrad step %d: desc dot=%.4e, cos_sim=%.4f, projected=%s", + _step_id + 1, dot.item(), cos_sim.item(), projected, + ) - # Set final grads: - # - descriptor: projected sum (from above) - # - all other params: plain sum of two tasks' grads + # Set final grads: sum of (projected) per-task grads self.optimizer.zero_grad(set_to_none=True) for p in all_params: pid = id(p) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 3ea1e023d8..5a2d19114d 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3387,6 +3387,7 @@ def training_args( else [ Argument("model_prob", dict, optional=True, default={}, doc=doc_model_prob), Argument("use_pcgrad", bool, optional=True, default=False, doc="Apply PCGrad gradient surgery on the shared descriptor parameters in multi-task training."), + Argument("use_dual_batch", bool, optional=True, default=False, doc="Sample all tasks every step and sum gradients without projection. Use as control group to isolate PCGrad effect from dual-batch effect."), Argument("data_dict", dict, data_args, repeat=True, doc=doc_data_dict), ] ) From 4308a74815e1eeb319a177656f801b0beb2ee54a Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 12 May 2026 13:03:36 +0800 Subject: [PATCH 15/19] feat: alter sampling --- deepmd/pt/train/training.py | 19 ++++++++++++++----- deepmd/utils/argcheck.py | 1 + 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 98860570bd..16728c6c81 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -630,10 +630,16 @@ def single_model_finetune( self.use_dual_batch = self.multi_task and ( self.use_pcgrad or training_params.get("use_dual_batch", False) ) + self.use_alternating = self.multi_task and ( + not self.use_dual_batch + and training_params.get("alternating_tasks", False) + ) if self.use_pcgrad: log.info("PCGrad enabled: descriptor gradients will be projected each step.") elif self.use_dual_batch: log.info("Dual-batch enabled: all tasks sampled per step, gradients summed without projection.") + elif self.use_alternating: + log.info("Alternating-tasks enabled: tasks cycled deterministically A→B→A→B each step.") # Multi-task share params if shared_links is not None: @@ -745,11 +751,14 @@ def run(self) -> None: def step(_step_id, task_key="Default") -> None: if self.multi_task: - model_index = dp_random.choice( - np.arange(self.num_model, dtype=np.int_), - p=self.model_prob, - ) - task_key = self.model_keys[model_index] + if self.use_alternating: + task_key = self.model_keys[_step_id % self.num_model] + else: + model_index = dp_random.choice( + np.arange(self.num_model, dtype=np.int_), + p=self.model_prob, + ) + task_key = self.model_keys[model_index] # PyTorch Profiler if self.enable_profiler or self.profiling: prof.step() diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 5a2d19114d..6b91f9caf2 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3388,6 +3388,7 @@ def training_args( Argument("model_prob", dict, optional=True, default={}, doc=doc_model_prob), Argument("use_pcgrad", bool, optional=True, default=False, doc="Apply PCGrad gradient surgery on the shared descriptor parameters in multi-task training."), Argument("use_dual_batch", bool, optional=True, default=False, doc="Sample all tasks every step and sum gradients without projection. Use as control group to isolate PCGrad effect from dual-batch effect."), + Argument("alternating_tasks", bool, optional=True, default=False, doc="Cycle through tasks deterministically (A→B→A→B) each step instead of random sampling. Ablation control to isolate balanced-sampling effect from combined-gradient effect."), Argument("data_dict", dict, data_args, repeat=True, doc=doc_data_dict), ] ) From 9e0f0bf22bf50c83266347bbd1e758c5684f4caf Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 12 May 2026 14:22:30 +0800 Subject: [PATCH 16/19] fix grad reduce --- deepmd/pt/train/training.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 16728c6c81..3ffcac6e52 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -838,11 +838,12 @@ def step(_step_id, task_key="Default") -> None: # Set final grads: sum of (projected) per-task grads self.optimizer.zero_grad(set_to_none=True) + num_tasks = len(self.model_keys) for p in all_params: pid = id(p) g0p, g1p = task_grads[k0][pid], task_grads[k1][pid] if g0p is not None and g1p is not None: - p.grad = g0p + g1p + p.grad = (g0p + g1p) / num_tasks elif g0p is not None: p.grad = g0p elif g1p is not None: From 043586df2727438aab0f145ea7d40ed33c948caa Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Fri, 15 May 2026 14:27:06 +0800 Subject: [PATCH 17/19] feat: ema norm scaling --- deepmd/pt/train/training.py | 103 +++++++++++++++++++++++++++++++----- deepmd/utils/argcheck.py | 3 ++ 2 files changed, 94 insertions(+), 12 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 3ffcac6e52..a7e6748238 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -634,10 +634,32 @@ def single_model_finetune( not self.use_dual_batch and training_params.get("alternating_tasks", False) ) + self.use_grad_norm_reweight = self.use_dual_batch and training_params.get( + "grad_norm_reweight", False + ) + self.use_loss_ratio_reweight = self.use_dual_batch and training_params.get( + "loss_ratio_reweight", False + ) + self.reweight_ema_decay = training_params.get("reweight_ema_decay", 0.99) + if self.use_grad_norm_reweight or self.use_loss_ratio_reweight: + self.grad_norm_ema = {k: None for k in self.model_keys} + self.loss_val_ema = {k: None for k in self.model_keys} if self.use_pcgrad: log.info("PCGrad enabled: descriptor gradients will be projected each step.") elif self.use_dual_batch: - log.info("Dual-batch enabled: all tasks sampled per step, gradients summed without projection.") + reweight_modes = [] + if self.use_grad_norm_reweight: + reweight_modes.append("grad-norm-EMA") + if self.use_loss_ratio_reweight: + reweight_modes.append("loss-ratio") + if reweight_modes: + log.info( + "Dual-batch enabled with reweighting: %s (ema_decay=%.3f).", + "+".join(reweight_modes), + self.reweight_ema_decay, + ) + else: + log.info("Dual-batch enabled: all tasks sampled per step, gradients averaged.") elif self.use_alternating: log.info("Alternating-tasks enabled: tasks cycled deterministically A→B→A→B each step.") @@ -769,13 +791,14 @@ def step(_step_id, task_key="Default") -> None: cur_lr = _lr.value(_step_id) pref_lr = cur_lr self.optimizer.zero_grad(set_to_none=True) - input_dict, label_dict, log_dict = self.get_data( - is_train=True, task_key=task_key - ) - if SAMPLER_RECORD: - print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" - fout1.write(print_str) - fout1.flush() + if not self.use_dual_batch: + input_dict, label_dict, log_dict = self.get_data( + is_train=True, task_key=task_key + ) + if SAMPLER_RECORD: + print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" + fout1.write(print_str) + fout1.flush() if self.opt_type in ["Adam", "AdamW"]: cur_lr = self.scheduler.get_last_lr()[0] if _step_id < self.warmup_steps: @@ -786,9 +809,14 @@ def step(_step_id, task_key="Default") -> None: all_params = list(self.wrapper.parameters()) task_grads = {} more_loss_per_task = {} + task_loss_val = {} for tk in self.model_keys: self.optimizer.zero_grad(set_to_none=True) - in_d, lbl_d, _ = self.get_data(is_train=True, task_key=tk) + in_d, lbl_d, log_d = self.get_data(is_train=True, task_key=tk) + if SAMPLER_RECORD and tk == self.model_keys[0]: + print_str = f"Step {_step_id}: sample system{log_d['sid']} frame{log_d['fid']}\n" + fout1.write(print_str) + fout1.flush() _, loss, more_loss = self.wrapper( **in_d, cur_lr=pref_lr, label=lbl_d, task_key=tk ) @@ -798,6 +826,7 @@ def step(_step_id, task_key="Default") -> None: for p in all_params } more_loss_per_task[tk] = more_loss + task_loss_val[tk] = loss.item() k0, k1 = self.model_keys[0], self.model_keys[1] @@ -836,14 +865,64 @@ def step(_step_id, task_key="Default") -> None: _step_id + 1, dot.item(), cos_sim.item(), projected, ) - # Set final grads: sum of (projected) per-task grads + # Compute per-task weights for gradient combination + task_weights = {k: 1.0 for k in self.model_keys} + d = self.reweight_ema_decay + + if self.use_grad_norm_reweight: + for k in self.model_keys: + grads = [ + task_grads[k][id(p)] + for p in all_params + if task_grads[k][id(p)] is not None + ] + if grads: + cur_norm = torch.stack( + [g.norm() for g in grads] + ).norm().item() + if self.grad_norm_ema[k] is None: + self.grad_norm_ema[k] = cur_norm + else: + self.grad_norm_ema[k] = ( + d * self.grad_norm_ema[k] + + (1 - d) * cur_norm + ) + task_weights[k] /= self.grad_norm_ema[k] + 1e-8 + + if self.use_loss_ratio_reweight: + for k in self.model_keys: + cur_loss_val = task_loss_val[k] + if self.loss_val_ema[k] is None: + self.loss_val_ema[k] = cur_loss_val + else: + self.loss_val_ema[k] = ( + d * self.loss_val_ema[k] + + (1 - d) * cur_loss_val + ) + task_weights[k] *= self.loss_val_ema[k] + + # Normalize weights to sum to 1 (keeps gradient scale on par with baseline) + w_sum = sum(task_weights.values()) + for k in task_weights: + task_weights[k] /= w_sum + + if self.rank == 0 and ( + self.use_grad_norm_reweight or self.use_loss_ratio_reweight + ) and (_step_id + 1) % self.disp_freq == 0: + weight_str = ", ".join( + f"{k}={task_weights[k]:.4f}" for k in self.model_keys + ) + log.info( + "Reweight step %d: [%s]", _step_id + 1, weight_str + ) + + # Set final grads: weighted average for shared params, full grad for exclusive params self.optimizer.zero_grad(set_to_none=True) - num_tasks = len(self.model_keys) for p in all_params: pid = id(p) g0p, g1p = task_grads[k0][pid], task_grads[k1][pid] if g0p is not None and g1p is not None: - p.grad = (g0p + g1p) / num_tasks + p.grad = task_weights[k0] * g0p + task_weights[k1] * g1p elif g0p is not None: p.grad = g0p elif g1p is not None: diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 6b91f9caf2..bf7ea5944f 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3389,6 +3389,9 @@ def training_args( Argument("use_pcgrad", bool, optional=True, default=False, doc="Apply PCGrad gradient surgery on the shared descriptor parameters in multi-task training."), Argument("use_dual_batch", bool, optional=True, default=False, doc="Sample all tasks every step and sum gradients without projection. Use as control group to isolate PCGrad effect from dual-batch effect."), Argument("alternating_tasks", bool, optional=True, default=False, doc="Cycle through tasks deterministically (A→B→A→B) each step instead of random sampling. Ablation control to isolate balanced-sampling effect from combined-gradient effect."), + Argument("grad_norm_reweight", bool, optional=True, default=False, doc="(dual-batch only) Reweight per-task gradients inversely proportional to their EMA gradient norm before combining, equalizing each task's directional contribution to shared parameters."), + Argument("loss_ratio_reweight", bool, optional=True, default=False, doc="(dual-batch only) Reweight per-task gradients proportional to their EMA loss value, giving more weight to the higher-loss task to prevent it from being sacrificed."), + Argument("reweight_ema_decay", float, optional=True, default=0.99, doc="EMA decay factor for grad_norm_reweight and loss_ratio_reweight tracking. Higher values give smoother but slower-adapting estimates."), Argument("data_dict", dict, data_args, repeat=True, doc=doc_data_dict), ] ) From 412d4b73176af4f7041ec55eb1b8d6cb92089c5d Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Fri, 15 May 2026 16:00:05 +0800 Subject: [PATCH 18/19] fix: norm only use descript --- deepmd/pt/train/training.py | 171 +++++++++++++++++++++++------------- 1 file changed, 110 insertions(+), 61 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index a7e6748238..a3660473c2 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -642,8 +642,12 @@ def single_model_finetune( ) self.reweight_ema_decay = training_params.get("reweight_ema_decay", 0.99) if self.use_grad_norm_reweight or self.use_loss_ratio_reweight: - self.grad_norm_ema = {k: None for k in self.model_keys} - self.loss_val_ema = {k: None for k in self.model_keys} + # Per-component EMA norms: descriptor and shared fitting_net tracked separately + self.grad_norm_ema_desc = {k: None for k in self.model_keys} + self.grad_norm_ema_fit = {k: None for k in self.model_keys} + # Two-speed EMA for loss: relative rate = fast/slow, scale-invariant + self.loss_val_ema_fast = {k: None for k in self.model_keys} + self.loss_val_ema_slow = {k: None for k in self.model_keys} if self.use_pcgrad: log.info("PCGrad enabled: descriptor gradients will be projected each step.") elif self.use_dual_batch: @@ -865,68 +869,113 @@ def step(_step_id, task_key="Default") -> None: _step_id + 1, dot.item(), cos_sim.item(), projected, ) - # Compute per-task weights for gradient combination - task_weights = {k: 1.0 for k in self.model_keys} - d = self.reweight_ema_decay - - if self.use_grad_norm_reweight: - for k in self.model_keys: - grads = [ - task_grads[k][id(p)] - for p in all_params - if task_grads[k][id(p)] is not None - ] - if grads: - cur_norm = torch.stack( - [g.norm() for g in grads] - ).norm().item() - if self.grad_norm_ema[k] is None: - self.grad_norm_ema[k] = cur_norm - else: - self.grad_norm_ema[k] = ( - d * self.grad_norm_ema[k] - + (1 - d) * cur_norm - ) - task_weights[k] /= self.grad_norm_ema[k] + 1e-8 + # Gradient combination: per-component reweighting when active, + # otherwise simple equal average for shared params + self.optimizer.zero_grad(set_to_none=True) + num_tasks = len(self.model_keys) - if self.use_loss_ratio_reweight: - for k in self.model_keys: - cur_loss_val = task_loss_val[k] - if self.loss_val_ema[k] is None: - self.loss_val_ema[k] = cur_loss_val - else: - self.loss_val_ema[k] = ( - d * self.loss_val_ema[k] - + (1 - d) * cur_loss_val - ) - task_weights[k] *= self.loss_val_ema[k] - - # Normalize weights to sum to 1 (keeps gradient scale on par with baseline) - w_sum = sum(task_weights.values()) - for k in task_weights: - task_weights[k] /= w_sum - - if self.rank == 0 and ( - self.use_grad_norm_reweight or self.use_loss_ratio_reweight - ) and (_step_id + 1) % self.disp_freq == 0: - weight_str = ", ".join( - f"{k}={task_weights[k]:.4f}" for k in self.model_keys - ) - log.info( - "Reweight step %d: [%s]", _step_id + 1, weight_str + if self.use_grad_norm_reweight or self.use_loss_ratio_reweight: + d = self.reweight_ema_decay + d_slow = 1.0 - (1.0 - d) / 10.0 + + _module = ( + self.wrapper.module + if hasattr(self.wrapper, "module") + else self.wrapper ) + _desc_param_ids = { + id(p) for p in _module.model[k0].get_descriptor().parameters() + } + _shared_pids = { + id(p) for p in all_params + if task_grads[k0].get(id(p)) is not None + and task_grads[k1].get(id(p)) is not None + } + _fit_shared_pids = _shared_pids - _desc_param_ids + + if self.use_grad_norm_reweight: + for k in self.model_keys: + desc_g = [ + task_grads[k][pid] + for pid in (_shared_pids & _desc_param_ids) + if task_grads[k].get(pid) is not None + ] + if desc_g: + cur = torch.stack([g.norm() for g in desc_g]).norm().item() + if self.grad_norm_ema_desc[k] is None: + self.grad_norm_ema_desc[k] = cur + else: + self.grad_norm_ema_desc[k] = d * self.grad_norm_ema_desc[k] + (1 - d) * cur + fit_g = [ + task_grads[k][pid] + for pid in _fit_shared_pids + if task_grads[k].get(pid) is not None + ] + if fit_g: + cur = torch.stack([g.norm() for g in fit_g]).norm().item() + if self.grad_norm_ema_fit[k] is None: + self.grad_norm_ema_fit[k] = cur + else: + self.grad_norm_ema_fit[k] = d * self.grad_norm_ema_fit[k] + (1 - d) * cur + + if self.use_loss_ratio_reweight: + for k in self.model_keys: + v = task_loss_val[k] + if self.loss_val_ema_fast[k] is None: + self.loss_val_ema_fast[k] = v + self.loss_val_ema_slow[k] = v + else: + self.loss_val_ema_fast[k] = d * self.loss_val_ema_fast[k] + (1 - d) * v + self.loss_val_ema_slow[k] = d_slow * self.loss_val_ema_slow[k] + (1 - d_slow) * v + + # Build per-component normalized weights + def _w(norm_ema): + return {k: 1.0 / (norm_ema[k] + 1e-8) if norm_ema[k] is not None else 1.0 + for k in self.model_keys} + + dw = _w(self.grad_norm_ema_desc) if self.use_grad_norm_reweight else {k: 1.0 for k in self.model_keys} + fw = _w(self.grad_norm_ema_fit) if self.use_grad_norm_reweight else {k: 1.0 for k in self.model_keys} + + if self.use_loss_ratio_reweight: + for k in self.model_keys: + rel = (self.loss_val_ema_fast[k] / (self.loss_val_ema_slow[k] + 1e-8) + if self.loss_val_ema_fast[k] is not None else 1.0) + dw[k] *= rel + fw[k] *= rel + + dw_sum = sum(dw.values()) + fw_sum = sum(fw.values()) + desc_w = {k: dw[k] / dw_sum for k in self.model_keys} + fit_w = {k: fw[k] / fw_sum for k in self.model_keys} + + if self.rank == 0 and (_step_id + 1) % self.disp_freq == 0: + log.info( + "Reweight step %d: desc=[%s] fit=[%s]", + _step_id + 1, + ", ".join(f"{k}={desc_w[k]:.4f}" for k in self.model_keys), + ", ".join(f"{k}={fit_w[k]:.4f}" for k in self.model_keys), + ) - # Set final grads: weighted average for shared params, full grad for exclusive params - self.optimizer.zero_grad(set_to_none=True) - for p in all_params: - pid = id(p) - g0p, g1p = task_grads[k0][pid], task_grads[k1][pid] - if g0p is not None and g1p is not None: - p.grad = task_weights[k0] * g0p + task_weights[k1] * g1p - elif g0p is not None: - p.grad = g0p - elif g1p is not None: - p.grad = g1p + for p in all_params: + pid = id(p) + g0p, g1p = task_grads[k0][pid], task_grads[k1][pid] + if g0p is not None and g1p is not None: + w0, w1 = (desc_w[k0], desc_w[k1]) if pid in _desc_param_ids else (fit_w[k0], fit_w[k1]) + p.grad = w0 * g0p + w1 * g1p + elif g0p is not None: + p.grad = g0p + elif g1p is not None: + p.grad = g1p + else: + for p in all_params: + pid = id(p) + g0p, g1p = task_grads[k0][pid], task_grads[k1][pid] + if g0p is not None and g1p is not None: + p.grad = (g0p + g1p) / num_tasks + elif g0p is not None: + p.grad = g0p + elif g1p is not None: + p.grad = g1p more_loss = more_loss_per_task[task_key] else: From 543e6dd0d0eff5905766e3f1a2ffdaa0b91a0bca Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 18 May 2026 18:33:35 +0800 Subject: [PATCH 19/19] feat: add noise scale --- deepmd/pt/entrypoints/main.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index a2489aabd7..8cf19c8697 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -659,13 +659,21 @@ def grad_probe(FLAGS) -> None: norm_mean = 0.0 norm_std = 0.0 + G_sq_norm = float(np.dot(avg_grad_vec, avg_grad_vec)) + norms_arr = np.array(norms) + mean_sq_norm = float(np.mean(norms_arr ** 2)) + tr_sigma = mean_sq_norm - G_sq_norm + noise_scale = tr_sigma / G_sq_norm if G_sq_norm > 1e-30 else float("nan") + grads_per_task[task_key] = avg_grad_vec grads_per_task[f"{task_key}_norm_mean"] = norm_mean grads_per_task[f"{task_key}_norm_std"] = norm_std + grads_per_task[f"{task_key}_noise_scale"] = noise_scale log.info( "Task '%s': collected %d batches (groups=%d, k=%d). " - "Avg Grad Norm=%.4e, Mean(Group Norms)=%.4e, Std(Group Norms)=%.4e", + "Avg Grad Norm=%.4e, Mean(Group Norms)=%.4e, Std(Group Norms)=%.4e, " + "Noise Scale(group)=%.4e", task_key, total_batches, total_batches // accumulate_k, @@ -673,6 +681,7 @@ def grad_probe(FLAGS) -> None: float(np.linalg.norm(avg_grad_vec)), float(norm_mean), float(norm_std), + noise_scale, ) # Compute pairwise dot product and cosine similarity statistics