Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,47 @@ def main_parser() -> argparse.ArgumentParser:
choices=["model-branch", "type-map", "descriptor", "fitting-net", "size"],
nargs="+",
)
# grad-probe: per-task descriptor gradient probe
parser_grad_probe = subparsers.add_parser(
"grad-probe",
parents=[parser_log, parser_mpi_log],
help="Compute per-task descriptor gradient vectors from a multitask checkpoint.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser_grad_probe.add_argument("INPUT", help="Training config JSON file.")
parser_grad_probe.add_argument(
"--ckpt", required=True, help="Checkpoint .pt file path."
)
parser_grad_probe.add_argument(
"-o",
"--output",
default="descriptor_grads.npz",
help="Output NPZ file path.",
)
parser_grad_probe.add_argument(
"-n",
"--nbatches",
type=int,
default=1,
help="Number of batches per task to average gradient over.",
)
parser_grad_probe.add_argument(
"-k",
"--accumulate-k",
type=int,
default=1,
dest="accumulate_k",
help="Group size for gradient accumulation before computing each dot product. "
"Total batches collected = nbatches; number of dot product samples = nbatches // k. "
"Larger K reduces norm variance and improves SNR of mean/std.",
)
parser_grad_probe.add_argument(
"--pcgrad",
action="store_true",
default=False,
help="Apply PCGrad projection to descriptor gradients before analysis. "
"Reports similarity of the projected vectors instead of raw gradients.",
)
return parser


Expand Down Expand Up @@ -919,6 +960,7 @@ def main(args: Optional[list[str]] = None) -> None:
"convert-from",
"train-nvnmd",
"change-bias",
"grad-probe",
):
deepmd_main = BACKENDS[args.backend]().entry_point_hook
elif args.command is None:
Expand Down
242 changes: 242 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,246 @@ def change_bias(
log.info(f"Saved model to {output_path}")


def grad_probe(FLAGS) -> None:
    """Compute per-task descriptor gradient vectors for multitask conflict analysis.

    For every batch and every task in ``trainer.model_keys`` the loss is
    back-propagated and the gradients of the task's descriptor parameters are
    flattened into one vector.  Vectors are averaged over groups of
    ``accumulate_k`` consecutive batches; for each complete group the per-task
    norm, the pairwise dot products and the pairwise cosines are recorded.
    Averages over all batches, the group statistics and the parameter
    names/shapes are written to ``FLAGS.output`` with ``numpy.savez``.

    Parameters
    ----------
    FLAGS
        Parsed command line namespace.  Reads ``INPUT`` (training config
        path), ``ckpt`` (checkpoint passed as ``restart_model``), ``output``
        (NPZ path), ``nbatches`` and, when present, ``accumulate_k``
        (default 1) and ``pcgrad`` (default False).

    Raises
    ------
    ValueError
        If ``nbatches`` or ``accumulate_k`` is smaller than 1.

    Notes
    -----
    Batches left over when ``nbatches`` is not a multiple of ``accumulate_k``
    contribute to the overall average gradient but not to the group-level
    statistics.  ``param_names`` and ``param_shapes`` are stored as object
    arrays.
    """
    import numpy as np

    from deepmd.common import j_loader
    from deepmd.pt.utils.multi_task import preprocess_shared_params
    from deepmd.utils.argcheck import normalize
    from deepmd.utils.compat import update_deepmd_input

    accumulate_k = getattr(FLAGS, "accumulate_k", 1)
    total_batches = FLAGS.nbatches
    use_pcgrad = getattr(FLAGS, "pcgrad", False)

    # Fail early on values that would otherwise divide by zero in the group
    # bookkeeping or silently produce an empty result.
    if total_batches < 1:
        raise ValueError(f"nbatches must be >= 1, got {total_batches}.")
    if accumulate_k < 1:
        raise ValueError(f"accumulate_k must be >= 1, got {accumulate_k}.")

    config = j_loader(FLAGS.INPUT)
    config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
    multi_task = "model_dict" in config.get("model", {})
    shared_links = None
    if multi_task:
        config["model"], shared_links = preprocess_shared_params(config["model"])
    config = normalize(config, multi_task=multi_task)

    trainer = get_trainer(
        config,
        restart_model=FLAGS.ckpt,
        shared_links=shared_links,
    )
    trainer.wrapper.eval()

    # Unwrap a wrapper that exposes the real module as ``.module``.
    module = (
        trainer.wrapper.module
        if hasattr(trainer.wrapper, "module")
        else trainer.wrapper
    )
    model_keys = list(trainer.model_keys)

    if use_pcgrad:
        log.info("grad_probe: PCGrad projection enabled — analysing projected descriptor gradients.")
        if len(model_keys) != 2:
            log.warning(
                "grad_probe: PCGrad projection is only applied to batches in which "
                "exactly two tasks have descriptor gradients; %d tasks are configured.",
                len(model_keys),
            )
    if total_batches % accumulate_k != 0:
        log.warning(
            "grad_probe: nbatches=%d is not a multiple of accumulate_k=%d; the last "
            "%d batch(es) are excluded from the group-level statistics.",
            total_batches,
            accumulate_k,
            total_batches % accumulate_k,
        )

    grads_per_task: dict = {}
    param_names = None
    param_shapes = None

    # Per-task trackers: group norms and the running sum over all batches.
    task_batch_norms = {k: [] for k in model_keys}
    task_accum_grads = {k: None for k in model_keys}
    # Per-pair trackers, one entry per unordered pair of distinct tasks.
    task_pairs = [
        (k1, k2)
        for i, k1 in enumerate(model_keys)
        for j, k2 in enumerate(model_keys)
        if i < j
    ]
    pairwise_dots = {pair: [] for pair in task_pairs}
    pairwise_cos = {pair: [] for pair in task_pairs}

    # Running sums within the current group of ``accumulate_k`` batches.
    group_grads = {k: None for k in model_keys}

    cur_lr = config["learning_rate"]["start_lr"]

    for batch_idx in range(total_batches):
        # Collect the descriptor gradient vector of every task for this batch.
        batch_grad_vecs = {}
        for task_key in model_keys:
            trainer.optimizer.zero_grad(set_to_none=True)
            input_dict, label_dict, _ = trainer.get_data(
                is_train=True, task_key=task_key
            )
            _, loss, _ = trainer.wrapper(
                **input_dict, cur_lr=cur_lr, label=label_dict, task_key=task_key
            )
            loss.backward()

            descriptor = module.model[task_key].get_descriptor()
            if param_names is None:
                trainable = [
                    (n, p) for n, p in descriptor.named_parameters() if p.requires_grad
                ]
                param_names = [n for n, _ in trainable]
                param_shapes = [list(p.shape) for _, p in trainable]

            grad_vec = _descriptor_grad_vector(descriptor)
            if grad_vec is not None:
                batch_grad_vecs[task_key] = grad_vec

        # Apply PCGrad projection to descriptor grad vectors if requested.
        if use_pcgrad and len(batch_grad_vecs) == 2:
            key_a, key_b = list(batch_grad_vecs)
            batch_grad_vecs[key_a], batch_grad_vecs[key_b] = _pcgrad_project_pair(
                batch_grad_vecs[key_a], batch_grad_vecs[key_b]
            )

        # Accumulate the (possibly projected) vectors, overall and per group.
        for task_key, grad_vec in batch_grad_vecs.items():
            for sums in (task_accum_grads, group_grads):
                if sums[task_key] is None:
                    sums[task_key] = grad_vec.copy()
                else:
                    sums[task_key] += grad_vec

        # At the end of each group record norms, dots and cosines, then reset.
        if (batch_idx + 1) % accumulate_k == 0:
            # Normalize to per-batch mean so scale is independent of k.
            group_means = {
                k: g / accumulate_k if g is not None else None
                for k, g in group_grads.items()
            }
            for task_key in model_keys:
                if group_means[task_key] is not None:
                    task_batch_norms[task_key].append(
                        np.linalg.norm(group_means[task_key])
                    )
            for k1, k2 in task_pairs:
                g1 = group_means[k1]
                g2 = group_means[k2]
                if g1 is None or g2 is None:
                    continue
                dot = float(np.dot(g1, g2))
                pairwise_dots[(k1, k2)].append(dot)
                # The cosine is taken here, from the same pair of vectors as
                # the dot product, so it can never be paired with a norm that
                # belongs to a different group.
                denom = float(np.linalg.norm(g1)) * float(np.linalg.norm(g2))
                if denom > 1e-12:
                    pairwise_cos[(k1, k2)].append(dot / denom)
            group_grads = {k: None for k in model_keys}

    # Compute final statistics and log for each task.
    for task_key in model_keys:
        accumulated_grads = task_accum_grads[task_key]
        norms = task_batch_norms[task_key]
        if accumulated_grads is not None:
            avg_grad_vec = accumulated_grads / total_batches
        else:
            avg_grad_vec = np.array([], dtype=np.float32)

        if norms:
            norms_arr = np.array(norms)
            norm_mean = np.mean(norms_arr)
            norm_std = np.std(norms_arr)
            mean_sq_norm = float(np.mean(norms_arr**2))
        else:
            # No complete group with gradients: avoid the empty-slice warnings
            # of np.mean/np.std and report the same values explicitly.
            norm_mean = norm_std = 0.0 if accumulated_grads is None else float("nan")
            mean_sq_norm = float("nan")

        G_sq_norm = float(np.dot(avg_grad_vec, avg_grad_vec))
        # Mean squared group norm minus squared norm of the overall mean.
        tr_sigma = mean_sq_norm - G_sq_norm
        noise_scale = tr_sigma / G_sq_norm if G_sq_norm > 1e-30 else float("nan")

        grads_per_task[task_key] = avg_grad_vec
        grads_per_task[f"{task_key}_norm_mean"] = norm_mean
        grads_per_task[f"{task_key}_norm_std"] = norm_std
        grads_per_task[f"{task_key}_noise_scale"] = noise_scale

        log.info(
            "Task '%s': collected %d batches (groups=%d, k=%d). "
            "Avg Grad Norm=%.4e, Mean(Group Norms)=%.4e, Std(Group Norms)=%.4e, "
            "Noise Scale(group)=%.4e",
            task_key,
            total_batches,
            total_batches // accumulate_k,
            accumulate_k,
            float(np.linalg.norm(avg_grad_vec)),
            float(norm_mean),
            float(norm_std),
            noise_scale,
        )

    # Compute pairwise dot product and cosine similarity statistics.
    for (k1, k2), dots in pairwise_dots.items():
        if not dots:
            continue
        dot_mean = float(np.mean(dots))
        dot_std = float(np.std(dots))

        g1_avg = grads_per_task[k1]
        g2_avg = grads_per_task[k2]
        dot_global = float(np.dot(g1_avg, g2_avg))

        # Group-level cosine: statistics over the per-group cosines.
        valid_cos = pairwise_cos[(k1, k2)]
        if valid_cos:
            cos_group_mean = float(np.mean(valid_cos))
            cos_group_std = float(np.std(valid_cos))
        else:
            cos_group_mean = float("nan")
            cos_group_std = float("nan")

        # Global-level cosine: cosine of the overall average gradient vectors.
        denom_global = float(np.linalg.norm(g1_avg)) * float(np.linalg.norm(g2_avg))
        cos_global = (
            float(dot_global / denom_global)
            if denom_global > 1e-12
            else float("nan")
        )

        grads_per_task[f"dot_{k1}_{k2}_mean"] = dot_mean
        grads_per_task[f"dot_{k1}_{k2}_std"] = dot_std
        grads_per_task[f"dot_{k1}_{k2}_global"] = dot_global
        grads_per_task[f"cos_group_mean_{k1}_{k2}"] = cos_group_mean
        grads_per_task[f"cos_group_std_{k1}_{k2}"] = cos_group_std
        grads_per_task[f"cos_global_{k1}_{k2}"] = cos_global

        log.info(
            "Grad similarity '%s' vs '%s': "
            "dot_mean=%.4e, dot_std=%.4e, dot_global=%.4e, "
            "cos_group_mean=%.4f, cos_group_std=%.4f, cos_global=%.4f "
            "(n_groups=%d)",
            k1,
            k2,
            dot_mean,
            dot_std,
            dot_global,
            cos_group_mean,
            cos_group_std,
            cos_global,
            len(dots),
        )

    save_dict = {f"grads_{k}": v for k, v in grads_per_task.items()}
    save_dict["param_names"] = np.array(param_names, dtype=object)
    save_dict["param_shapes"] = np.array(param_shapes, dtype=object)
    np.savez(FLAGS.output, **save_dict)
    log.info("Descriptor gradient vectors and similarity stats saved to: %s", FLAGS.output)


def _descriptor_grad_vector(descriptor):
    """Flatten the current gradients of a descriptor's trainable parameters.

    Parameters are visited in ``named_parameters()`` order and only those with
    ``requires_grad`` set are included, which is the same selection used for
    the saved ``param_names``/``param_shapes``.  A trainable parameter without
    a gradient contributes zeros, so the vector layout does not depend on
    which parameters happened to receive a gradient.

    Returns
    -------
    numpy.ndarray or None
        A 1-D float32 vector, or None when no trainable parameter has a
        gradient.
    """
    import numpy as np

    chunks = []
    has_grad = False
    for _, param in descriptor.named_parameters():
        if not param.requires_grad:
            continue
        if param.grad is None:
            chunks.append(np.zeros(param.numel(), dtype=np.float32))
        else:
            has_grad = True
            chunks.append(param.grad.detach().cpu().float().numpy().ravel())
    if not has_grad:
        return None
    return np.concatenate(chunks)


def _pcgrad_project_pair(g0, g1):
    """Return the pair with each vector's component along the other removed.

    The inputs are returned unchanged when their dot product is not negative.
    Both projections use the unprojected vectors, and the inputs are never
    modified in place.
    """
    import numpy as np

    dot = float(np.dot(g0, g1))
    if dot >= 0:
        return g0, g1
    projected_0 = g0 - dot / max(np.dot(g1, g1), 1e-12) * g1
    projected_1 = g1 - dot / max(np.dot(g0, g0), 1e-12) * g0
    return projected_0, projected_1



@record
def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None:
if not isinstance(args, argparse.Namespace):
Expand Down Expand Up @@ -572,6 +812,8 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None:
check_frequency=FLAGS.frequency,
training_script=FLAGS.training_script,
)
elif FLAGS.command == "grad-probe":
grad_probe(FLAGS)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")

Expand Down
16 changes: 5 additions & 11 deletions deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,7 @@ def __init__(
assert self.use_huber or self.use_l1_all, (
"f_use_norm can only be True when use_huber or use_l1_all is True."
)
if self.use_huber and (
self.has_pf or self.has_gf or self.relative_f is not None
):
raise RuntimeError(
"Huber loss is not implemented for force with atom_pref, generalized force and relative force. "
)


def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
"""Return loss on energy and force.

Expand Down Expand Up @@ -220,7 +214,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
l2_ener_loss.detach(), find_energy
)
if not self.use_huber:
loss += atom_norm * (pref_e * l2_ener_loss)
loss += atom_norm**2 * (pref_e * l2_ener_loss)
else:
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["energy"],
Expand Down Expand Up @@ -346,8 +340,8 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
)
else:
l_huber_loss = custom_huber_loss(
(atom_pref * force_pred).reshape(-1),
(atom_pref * force_label).reshape(-1),
atom_pref_reshape * force_pred.reshape(-1),
atom_pref_reshape * force_label.reshape(-1),
delta=self.huber_delta,
)
loss += pref_pf * l_huber_loss
Expand Down Expand Up @@ -405,7 +399,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
l2_virial_loss.detach(), find_virial
)
if not self.use_huber:
loss += atom_norm * (pref_v * l2_virial_loss)
loss += atom_norm**2 * (pref_v * l2_virial_loss)
else:
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["virial"].reshape(-1),
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,8 @@ def forward(
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
fparam: Optional[torch.Tensor] = None,
case_embd: Optional[torch.Tensor] = None,
):
"""Compute the descriptor.

Expand All @@ -671,6 +673,10 @@ def forward(
The index mapping, not required by this descriptor.
comm_dict
The data needed for communication for parallel inference.
fparam
The frame-level parameters. shape: nf x nfparam
case_embd
The case (dataset) embedding for multitask training with shared fitting. shape: nf x dim_case_embd

Returns
-------
Expand Down
Loading
Loading