
Add density-weighted NormMAE loss function#108

Draft
forklady42 wants to merge 4 commits into main from loss/density-weighted

Conversation

@forklady42
Collaborator

Summary

Adds a density-weighted variant of NormMAE that upweights errors at high-density voxels, addressing the systematic underprediction of peak charge densities identified in voxel error analysis (#101).

  • DensityWeightedNormMAE: weight function w(ρ) = 1 + α(ρ/ρ_max)^power, configurable via loss_alpha and loss_power YAML keys
  • Config-driven loss selection: set loss_fn: density_weighted_normmae in YAML config (defaults to normmae for backward compatibility)
  • val_normmae metric: always logged during validation for apples-to-apples comparison across loss experiments
  • test_normmae fix: test_step now logs NormMAE separately from training loss, fixing a bug where the training loss value was written to metrics CSV
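For reference, the weight function described above can be sketched as a standalone PyTorch function. This is a hypothetical simplification, not the project's DensityWeightedNormMAE class (which also handles list inputs and config-driven construction), and the alpha/power defaults here are illustrative:

```python
import torch

def density_weighted_normmae(pred, target, alpha=4.0, power=1.0):
    # Sketch of the weighting above: w(rho) = 1 + alpha * (rho / rho_max)^power,
    # with rho_max taken per-sample over the spatial voxel dimensions.
    rho_max = target.amax(dim=(-3, -2, -1), keepdim=True).clamp(min=1e-8)
    normalized = (target / rho_max).clamp(min=0)  # guard spurious negative voxels
    if power != 1.0:
        normalized = normalized ** power
    weights = 1.0 + alpha * normalized
    # NormMAE-style normalization: weighted absolute error over total |charge|
    return (weights * (pred - target).abs()).sum() / target.abs().sum().clamp(min=1e-8)
```

With alpha=0 every weight is 1 and this reduces to a plain normalized MAE, which is one of the invariants the unit tests in the test plan check.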

Experiment results (from loss/all-experiments)

On the full MP dataset (113K, 18 epochs, 4× A100), density-weighted ties baseline on mean NMAE (2.19%) and achieves the lowest max NMAE (4.17% vs 4.66%), improving 2/3 of the worst-10% tail structures. See #101 for full analysis.

Test plan

  • 7 unit tests for DensityWeightedNormMAE (scalar output, alpha=0 matches NormMAE, higher alpha increases loss, list inputs, perfect prediction, zero target)
  • Full test suite passes (59/60 — 1 pre-existing failure in test_srgan_layernorm_pbc.py)

🤖 Generated with Claude Code

forklady42 and others added 3 commits March 30, 2026 16:14
Add DensityWeightedNormMAE loss function that upweights errors at high-density
voxels using w(rho) = 1 + alpha * (rho/rho_max)^power. This addresses the
systematic underprediction of peak charge densities observed in voxel error
analysis (issue #101).

Key changes:
- DensityWeightedNormMAE class in model/loss/charge.py
- Config-driven loss selection via `loss_fn` YAML key (defaults to normmae)
- Always log val_normmae during validation for cross-experiment comparison
- Log test_normmae in test_step (fixes bug where training loss was written
  instead of NormMAE to metrics CSV)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… CSV metric

- Extract _ChargeLoss base class with shared list-dispatch forward, eliminating
  duplicated boilerplate between NormMAE and DensityWeightedNormMAE
- Change strict=False to strict=True in loss forward (match caller convention)
- Skip redundant NormMAE computation in validation/test when loss_fn is already
  NormMAE (avoids duplicate 128³ voxel L1 calculation per batch)
- Short-circuit power=1.0 in DensityWeightedNormMAE to avoid needless torch.pow
- Fix test_step CSV bug: was writing training loss as "nmae", now writes normmae

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@forklady42
Collaborator Author

@claude review this PR

@claude

claude bot commented Mar 31, 2026

Claude finished @forklady42's task in 2m 25s


PR Review: Add density-weighted NormMAE loss function

  • Gather context and read changed files
  • Review src/electrai/model/loss/charge.py
  • Review src/electrai/lightning.py
  • Review tests/electrai/model/loss/test_charge.py
  • Summarize findings

Overall this is a solid, well-motivated PR. The implementation is correct, the bug fix to test_step CSV metrics is real and worth having, and the 7 unit tests cover the important cases. A few issues worth addressing before merge, ranging from a potential numerical bug to design concerns.


src/electrai/model/loss/charge.py

Potential numerical issue: negative target values not guarded

DensityWeightedNormMAE._forward computes normalized = target / rho_max using target values directly. If any voxel has a negative value (plausible with noisy model outputs or pre-processed data), weights = 1 + alpha * normalized can go negative for alpha > 1, turning the loss into a reward for errors at those voxels. The .clamp(min=1e-8) only protects rho_max, not the per-voxel values.

# charge.py:49-53
rho_max = target.amax(dim=(-3, -2, -1), keepdim=True).clamp(min=1e-8)
normalized = target / rho_max          # ← negative target → negative weight
if self.power != 1.0:
    normalized = normalized**self.power
weights = 1.0 + self.alpha * normalized

If charge densities are always physically non-negative by the time they reach this loss, this is fine — but it's worth either asserting that or clamping: normalized = (target / rho_max).clamp(min=0).
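A minimal numerical illustration of the concern, using hypothetical values rather than real model outputs:

```python
import torch

alpha = 4.0
target = torch.tensor([-0.5, 1.0])  # one spurious negative voxel
rho_max = target.amax().clamp(min=1e-8)     # 1.0
weights = 1.0 + alpha * (target / rho_max)  # tensor([-1., 5.])
# A weight of -1 at the negative voxel rewards error there instead of
# penalizing it. With the suggested clamp, weights floor at 1.0:
safe = 1.0 + alpha * (target / rho_max).clamp(min=0)  # tensor([1., 5.])
```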

axis= vs dim= inconsistency

NormMAE._forward uses torch.sum(target, axis=...) while DensityWeightedNormMAE._forward uses torch.sum(target, dim=...). Both work in PyTorch, but the inconsistency is distracting in the same file. Prefer dim= uniformly.


src/electrai/lightning.py

_build_loss_fn is hard-coded to DensityWeightedNormMAE

The special-case if loss_cls is DensityWeightedNormMAE means adding any other parameterized loss requires touching this method again. A cleaner approach would be to pass all relevant config keys as **kwargs or to store loss_alpha/loss_power in a sub-dict. For a two-loss codebase this is fine as-is, but worth noting if more variants are planned.
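One registry-based shape the suggestion could take, sketched with stand-in classes (the real constructor signatures and config layout may differ):

```python
class NormMAE:
    """Stand-in for the real loss class."""

class DensityWeightedNormMAE:
    """Stand-in; real class lives in model/loss/charge.py."""
    def __init__(self, alpha=1.0, power=1.0):
        self.alpha, self.power = alpha, power

# Map each loss_fn name to (class, config keys it consumes), so new
# parameterized losses only need a registry entry, not a new if-branch.
LOSS_REGISTRY = {
    "normmae": (NormMAE, ()),
    "density_weighted_normmae": (DensityWeightedNormMAE, ("loss_alpha", "loss_power")),
}

def build_loss_fn(config: dict):
    cls, keys = LOSS_REGISTRY[config.get("loss_fn", "normmae")]
    # strip the "loss_" prefix so config keys line up with constructor kwargs
    kwargs = {k.removeprefix("loss_"): config[k] for k in keys if k in config}
    return cls(**kwargs)
```

Missing config keys simply fall back to the constructor defaults, preserving the current backward-compatible behavior of defaulting to normmae.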

normmae is potentially unbound per static analysis

In _loss_calculation, normmae is assigned inside if isinstance(x, list): ... if compute_normmae: normmae = ... and else: ... if compute_normmae: normmae = .... These are guarded by the same compute_normmae flag as the final return, so it's safe at runtime. But linters/mypy will flag normmae as possibly unbound. Initializing normmae: torch.Tensor | None = None at the top of the method would silence this.

normmae_fn is always instantiated even when unused

self.normmae_fn = NormMAE() is always created in __init__, even when loss_fn is already NormMAE. This wastes a tiny bit of memory/state and doubles the parameter count shown by save_hyperparameters. A simple guard: self.normmae_fn = self.loss_fn if isinstance(self.loss_fn, NormMAE) else NormMAE().

test_step doesn't handle list inputs

test_step calls self(x) directly without checking isinstance(x, list), while _loss_calculation handles both paths. This was pre-existing, but the new normmae computation at line 147 (self.normmae_fn(preds, y)) would silently pass list tensors to NormMAE.forward which does handle lists — so it won't crash. Still, the inconsistency between test_step and validation_step is worth a note for future maintainers.


tests/electrai/model/loss/test_charge.py

strict=False → strict=True change is untested

The base class _ChargeLoss.forward changed zip(..., strict=False) to zip(..., strict=True) — a behavior change that would now raise ValueError on mismatched-length list inputs instead of silently truncating. There's no test asserting this new behavior.

No test for non-unity power behavior in list inputs

The list-input tests only use default power=1.0. The test_perfect_prediction_gives_zero test uses power=2.0 but only for a single tensor — worth adding a list test with power != 1 to confirm the code path through _ChargeLoss.forward → _forward → normalized**self.power works end-to-end with list dispatch.


Summary

| Severity | Issue |
| --- | --- |
| Should fix | Negative target values can produce negative weights in DensityWeightedNormMAE |
| Should fix | strict=False → strict=True behavior change has no test coverage |
| Nice to fix | axis= vs dim= inconsistency in NormMAE._forward |
| Nice to fix | normmae_fn always instantiated; normmae possibly unbound |
| Informational | _build_loss_fn special-casing will need updating for future losses |
| Pre-existing | test_step doesn't handle list batch inputs |

The negative-weight concern and the missing strict=True test are the only things I'd push back on before merging. Everything else is either cosmetic or pre-existing.

- Clamp normalized target to [0, inf) in DensityWeightedNormMAE to prevent
  negative weights if target has negative voxels
- Use dim= consistently instead of axis= in NormMAE._forward
- Reuse loss_fn as normmae_fn when it's already NormMAE (avoids duplicate
  module in state_dict)
- Initialize normmae to None to satisfy static analysis
- Add test for strict=True on mismatched list lengths
- Add test for list input with non-unity power

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@forklady42
Collaborator Author

@claude review again

@claude

claude bot commented Apr 7, 2026

Claude finished @forklady42's task in 1m 16s


PR Review: Add density-weighted NormMAE loss function

  • Gather context
  • Read changed files
  • Review findings
  • Post final review

All issues from the previous review have been addressed. This is ready to merge.

What was fixed

| Previous issue | Resolution |
| --- | --- |
| Negative target values → negative weights | (target / rho_max).clamp(min=0) at charge.py:50 |
| strict=True behavior change untested | test_mismatched_list_lengths_raises added at test_charge.py:80 |
| axis= vs dim= inconsistency | NormMAE._forward now uses dim= at charge.py:29 |
| normmae_fn always instantiated | Conditional assignment at lightning.py:27-29 |
| normmae possibly unbound | normmae: torch.Tensor \| None = None initialized in _loss_calculation |
| No list test with power != 1 | test_list_with_non_unity_power added at test_charge.py:86 |

One minor note

test_mismatched_list_lengths_raises uses match="zip" (test_charge.py:83). The strict=True zip error message ("zip() argument 2 is longer than argument 1") has been stable since CPython 3.10, so this is fine; just worth knowing the match couples to CPython's error text if you ever swap implementations.

Everything else looks correct: the _loss_is_normmae property correctly returns False for DensityWeightedNormMAE(alpha=0) (separate class, no inheritance from NormMAE), _loss_calculation list dispatch is clean, and test_step now logs test_normmae separately from the loss value via self.normmae_fn.
