Conversation
Add DensityWeightedNormMAE loss function that upweights errors at high-density voxels using w(rho) = 1 + alpha * (rho/rho_max)^power. This addresses the systematic underprediction of peak charge densities observed in voxel error analysis (issue #101).

Key changes:
- DensityWeightedNormMAE class in model/loss/charge.py
- Config-driven loss selection via `loss_fn` YAML key (defaults to normmae)
- Always log val_normmae during validation for cross-experiment comparison
- Log test_normmae in test_step (fixes bug where training loss was written instead of NormMAE to metrics CSV)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… CSV metric

- Extract _ChargeLoss base class with shared list-dispatch forward, eliminating duplicated boilerplate between NormMAE and DensityWeightedNormMAE
- Fix strict=False to strict=True in loss forward (match caller convention)
- Skip redundant NormMAE computation in validation/test when loss_fn is already NormMAE (avoids duplicate 128³ voxel L1 calculation per batch)
- Short-circuit power=1.0 in DensityWeightedNormMAE to avoid needless torch.pow
- Fix test_step CSV bug: was writing training loss as "nmae", now writes normmae

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
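The weighting scheme in the commits above can be sketched as a minimal stand-alone function. This is a sketch under assumptions, not the repo's implementation: it assumes NormMAE is the L1 error normalized by the L1 norm of the target, and that weights multiply both numerator and denominator; the actual class in `model/loss/charge.py` may normalize differently.

```python
import torch


def density_weighted_normmae(pred, target, alpha=4.0, power=1.0, eps=1e-12):
    """Sketch of w(rho) = 1 + alpha * (rho / rho_max)^power applied to NormMAE.

    Assumption: NormMAE here means sum(|pred - target|) / sum(|target|),
    so alpha=0 reduces exactly to the unweighted loss.
    """
    rho = target.clamp(min=0.0)             # guard against negative voxels
    rho_max = rho.amax().clamp(min=eps)     # avoid divide-by-zero on all-zero targets
    w = rho / rho_max
    if power != 1.0:                        # short-circuit power=1.0, as in the PR
        w = w.pow(power)
    w = 1.0 + alpha * w
    num = (w * (pred - target).abs()).sum()
    den = (w * target.abs()).sum().clamp(min=eps)
    return num / den
```

With alpha=0 the weights are all 1, so the function matches the plain normalized MAE; increasing alpha penalizes errors at the densest voxels more heavily.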
@claude review this PR
Claude finished @forklady42's task in 2m 25s

PR Review: Add density-weighted NormMAE loss function
Overall this is a solid, well-motivated PR. The implementation is correct, and the test_step fix (training loss was previously written to the metrics CSV in place of NormMAE) is a genuine improvement.
| Severity | Issue |
|---|---|
| Should fix | Negative target values can produce negative weights in DensityWeightedNormMAE |
| Should fix | strict=False → strict=True behavior change has no test coverage |
| Nice to fix | axis= vs dim= inconsistency in NormMAE._forward |
| Nice to fix | normmae_fn always instantiated; normmae possibly-unbound |
| Informational | _build_loss_fn special-casing will need updating for future losses |
| Pre-existing | test_step doesn't handle list batch inputs |
The negative-weight concern and the missing strict=True test are the only things I'd push back on before merging. Everything else is either cosmetic or pre-existing.
- Clamp normalized target to [0, inf) in DensityWeightedNormMAE to prevent negative weights if target has negative voxels
- Use dim= consistently instead of axis= in NormMAE._forward
- Reuse loss_fn as normmae_fn when it's already NormMAE (avoids duplicate module in state_dict)
- Initialize normmae to None to satisfy static analysis
- Add test for strict=True on mismatched list lengths
- Add test for list input with non-unity power

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
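The clamp and module-reuse points in the commit above can be illustrated with a small sketch. Names such as `density_weights` and `build_normmae_metric` are illustrative, not the repo's actual API; the NormMAE definition is the same simplified one assumed earlier.

```python
import torch
import torch.nn as nn


class NormMAE(nn.Module):
    """Minimal NormMAE sketch: L1 error over the L1 norm of the target."""

    def forward(self, pred, target):
        return (pred - target).abs().sum() / target.abs().sum().clamp(min=1e-12)


def density_weights(target, alpha, power):
    # Clamp the normalized density to [0, inf) so negative voxels cannot
    # push the weight below 1 (or make it negative).
    rho = (target / target.abs().amax().clamp(min=1e-12)).clamp(min=0.0)
    return 1.0 + alpha * rho.pow(power)


def build_normmae_metric(loss_fn):
    # Reuse the training loss as the val/test NormMAE metric when they
    # coincide, so state_dict does not carry a duplicate module.
    return loss_fn if isinstance(loss_fn, NormMAE) else NormMAE()
```

Without the clamp, a target voxel of -ρ_max with alpha > 2 would yield a negative weight and could flip the sign of the loss contribution at that voxel.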
@claude review again
Claude finished @forklady42's task in 1m 16s

PR Review: Add density-weighted NormMAE loss function
All issues from the previous review have been addressed. This is ready to merge.

What was fixed
One minor note
Everything else looks correct.
Summary
Adds a density-weighted variant of NormMAE that upweights errors at high-density voxels, addressing the systematic underprediction of peak charge densities identified in voxel error analysis (#101).
- `DensityWeightedNormMAE`: weight function `w(ρ) = 1 + α(ρ/ρ_max)^power`, configurable via `loss_alpha` and `loss_power` YAML keys
- `loss_fn: density_weighted_normmae` in YAML config (defaults to `normmae` for backward compatibility)
- `val_normmae` metric: always logged during validation for apples-to-apples comparison across loss experiments
- `test_normmae` fix: test_step now logs NormMAE separately from training loss, fixing a bug where the training loss value was written to the metrics CSV

Experiment results (from loss/all-experiments)

On the full MP dataset (113K, 18 epochs, 4× A100), density-weighted ties baseline on mean NMAE (2.19%) and achieves the lowest max NMAE (4.17% vs 4.66%), improving 2/3 of the worst-10% tail structures. See #101 for full analysis.
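The config-driven selection described in the summary might look like the fragment below. This is a hypothetical sketch: the key names `loss_fn`, `loss_alpha`, and `loss_power` come from this PR, but the values and exact schema are illustrative.

```yaml
# Hypothetical config fragment; exact schema depends on the repo's loader.
loss_fn: density_weighted_normmae  # omit (or set normmae) for the old behavior
loss_alpha: 4.0                    # α in w(ρ) = 1 + α(ρ/ρ_max)^power — value illustrative
loss_power: 1.0                    # power = 1.0 short-circuits the torch.pow call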
Test plan
- `DensityWeightedNormMAE` (scalar output, alpha=0 matches NormMAE, higher alpha increases loss, list inputs, perfect prediction, zero target)
- Existing suite (`test_srgan_layernorm_pbc.py`)

🤖 Generated with Claude Code