
feat: from-scratch, external-checkpoint, and Gram-loss training variants #67

Open

anas-zafar wants to merge 2 commits into MedARC-AI:main from anas-zafar:feat/training-starting-points
Conversation

@anas-zafar

Summary

Implements three alternative training configurations to ablate against the existing DINOv2-pretrained fine-tuning baseline.

1. From-scratch training

New config vitg14_reg4_scratch.yaml / launch script run_scratch.sh.
Key changes vs. the pretrained baseline:

  • use_pretrained: False
  • base_lr: 1.5e-03 (~7.5x higher, calibrated from the DINOv2 paper's ViT-g numbers scaled to our effective batch size)
  • warmup_epochs: 20 (2x longer, for early-training stability)
  • layerwise_decay: 0.9 (standard LLRD; the baseline uses 1.0 to preserve pretrained features)
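In config form, the overrides above might look roughly like this (field names are taken from the list above; the exact nesting inside vitg14_reg4_scratch.yaml is a sketch, not a copy of the file):

```yaml
# sketch of vitg14_reg4_scratch.yaml overrides (nesting is assumed)
train:
  use_pretrained: false   # random init, no torch.hub download
optim:
  base_lr: 1.5e-03        # ~7.5x the fine-tuning LR
  warmup_epochs: 20       # 2x longer warmup for early-training stability
  layerwise_decay: 0.9    # standard LLRD; pretrained baseline keeps 1.0
```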

2. External / DINOv3-checkpoint starting point

New config vitg14_reg4_dinov3init.yaml / launch script run_dinov3init.sh (set the STAGE1_CHECKPOINT environment variable to the checkpoint path).
New _load_from_external_checkpoint() in train.py auto-detects three checkpoint formats:

  • Teacher checkpoint {"teacher": {"backbone.*": ...}} (OpenMidnight teacher_checkpoint.pth or MedARC-AI/path-fm-dinov3 format)
  • Nested backbone {"model": {...}} or {"backbone": {...}}
  • Flat state-dict (Meta-style raw backbone release)

Loads with strict=False; architecture-incompatible keys (e.g. DINOv3 RoPE tensors) are skipped with a warning.
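The format auto-detection can be sketched like this (plain dicts stand in for tensor state-dicts; the helper name and the strict=False / key-skipping handling are simplified from the description above, not copied from train.py):

```python
def extract_backbone_state(ckpt: dict) -> dict:
    """Pick backbone weights out of one of three checkpoint layouts.

    Simplified sketch: the real loader additionally calls
    load_state_dict(..., strict=False) and warns about skipped
    architecture-incompatible keys (e.g. DINOv3 RoPE tensors).
    """
    if "teacher" in ckpt:
        # Teacher checkpoint: {"teacher": {"backbone.*": ...}}
        prefix = "backbone."
        return {k[len(prefix):]: v
                for k, v in ckpt["teacher"].items()
                if k.startswith(prefix)}
    for key in ("model", "backbone"):
        # Nested backbone: {"model": {...}} or {"backbone": {...}}
        if key in ckpt:
            return ckpt[key]
    # Flat state-dict (Meta-style raw backbone release)
    return ckpt
```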

3. DINOv3-style Gram (patch-similarity) loss

New module dinov2/loss/gram_loss.py (ported from MedARC-AI/path-fm-dinov3).
New config vitg14_reg4_gram.yaml / launch script run_gram.sh (tunable GRAM_LOSS_WEIGHT env var).

The loss computes MSE between pairwise cosine-similarity matrices of student and teacher patch tokens - the defining new objective in DINOv3. Uses EMA-teacher mode: the existing EMA teacher patch tokens serve as Gram targets; no separate frozen model is needed. Enabled with gram.use_loss: true; fully configurable (loss_weight, img_level, normalized, remove_neg).
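The objective can be illustrated with a small NumPy sketch (shapes here are arbitrary; the actual GramLoss is a PyTorch module with the extra options listed above):

```python
import numpy as np

def cosine_gram(tokens: np.ndarray) -> np.ndarray:
    """Pairwise cosine-similarity (Gram) matrix of patch tokens: (N, D) -> (N, N)."""
    normed = tokens / np.linalg.norm(tokens, axis=-1, keepdims=True)
    return normed @ normed.T

def gram_loss(student_tokens: np.ndarray, teacher_tokens: np.ndarray) -> float:
    """MSE between the student's and teacher's patch-token Gram matrices."""
    diff = cosine_gram(student_tokens) - cosine_gram(teacher_tokens)
    return float(np.mean(diff ** 2))

rng = np.random.default_rng(0)
teacher = rng.standard_normal((196, 64))            # e.g. 14x14 patches, dim 64
student = teacher + 0.1 * rng.standard_normal((196, 64))
print(gram_loss(teacher, teacher))  # 0.0 for identical token sets
print(gram_loss(student, teacher))  # small positive value
```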

Files changed

  • dinov2/loss/gram_loss.py - new: GramLoss class
  • dinov2/loss/__init__.py - export GramLoss
  • dinov2/configs/ssl_default_config.yaml - add train.pretrained_weights default + gram: config block
  • dinov2/train/ssl_meta_arch.py - integrate Gram loss (init + forward_backward)
  • dinov2/train/train.py - add _load_from_external_checkpoint(), wire into main()
  • dinov2/configs/train/vitg14_reg4_scratch.yaml - new from-scratch config
  • dinov2/configs/train/vitg14_reg4_dinov3init.yaml - new external-checkpoint config
  • dinov2/configs/train/vitg14_reg4_gram.yaml - new Gram-loss config
  • run_scratch.sh - new from-scratch launch script
  • run_gram.sh - new Gram-loss launch script
  • run_dinov3init.sh - new external-checkpoint launch script

Backward compatibility

All changes are fully additive and opt-in:

  • gram.use_loss defaults to false - existing runs are unaffected
  • train.pretrained_weights defaults to '' - external loading is only triggered when non-empty
  • use_pretrained: True (existing default) continues to use _load_pretrained_backbone unchanged

feat: from-scratch, external-checkpoint, and Gram-loss training variants

Implements three alternative training configurations to compare against the
existing DINOv2-pretrained fine-tuning baseline:

1. From-scratch training (vitg14_reg4_scratch.yaml / run_scratch.sh)
   - use_pretrained: False -- random initialisation, no torch.hub download
   - base_lr: 1.5e-03 -- ~7.5x higher than fine-tuning LR
   - warmup_epochs: 20 -- longer warmup for stable early training
   - layerwise_decay: 0.9 -- standard LLRD (fine-tuning uses 1.0)

2. External / DINOv3-checkpoint init (vitg14_reg4_dinov3init.yaml / run_dinov3init.sh)
   New _load_from_external_checkpoint() in train.py auto-detects three
   checkpoint formats (teacher ckpt, nested backbone, flat state-dict) and
   loads backbone weights with strict=False. Architecture-incompatible keys
   (e.g. DINOv3 RoPE tensors) are skipped with a warning. Pass the path via
   train.pretrained_weights CLI override or STAGE1_CHECKPOINT env-var.

3. DINOv3-style Gram (patch-similarity) loss (vitg14_reg4_gram.yaml / run_gram.sh)
   New GramLoss module (dinov2/loss/gram_loss.py) ported from
   MedARC-AI/path-fm-dinov3. Computes MSE between pairwise cosine-similarity
   matrices of student and teacher patch tokens (the key DINOv3 objective).
   EMA-teacher mode: existing EMA teacher patch tokens serve as Gram targets,
   no separate frozen model needed. Enabled with gram.use_loss: true.
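For reference, the EMA teacher that supplies those target tokens is maintained by the usual momentum update; a minimal sketch (the momentum value below is illustrative, not the repo's setting):

```python
def ema_update(teacher_params: list, student_params: list,
               momentum: float = 0.994) -> list:
    """One EMA step per parameter: teacher <- m * teacher + (1 - m) * student."""
    return [momentum * t + (1.0 - momentum) * s
            for t, s in zip(teacher_params, student_params)]

# The teacher drifts slowly toward the student; its patch tokens are then
# used as Gram targets, so no separate frozen model is kept in memory.
print(ema_update([1.0], [0.0], momentum=0.9))  # [0.9]
```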

Default config (ssl_default_config.yaml) gains:
   - train.pretrained_weights: '' (used by external-checkpoint loader)
   - gram: block with all options defaulting to off
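The new gram: block might look like this (only use_loss: false is stated above as the default; the other values are illustrative placeholders for the options named in the PR):

```yaml
# sketch of the gram: block in ssl_default_config.yaml (values are illustrative)
gram:
  use_loss: false    # opt-in; false leaves existing runs unchanged
  loss_weight: 1.0   # tunable via the GRAM_LOSS_WEIGHT env var
  img_level: false
  normalized: true
  remove_neg: false
```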
Loss values stored in loss_dict are used for metric logging only; keeping
the computational graph attached wastes memory. Matches the pattern already
used by all other losses in forward_backward().
@anas-zafar force-pushed the feat/training-starting-points branch 2 times, most recently from 615ee02 to d924304 on March 22, 2026.
