feat: from-scratch, external-checkpoint, and Gram-loss training variants#67
Open
anas-zafar wants to merge 2 commits into MedARC-AI:main from
Conversation
…ariants

Implements three alternative training configurations to compare against the existing DINOv2-pretrained fine-tuning baseline:

1. From-scratch training (vitg14_reg4_scratch.yaml / run_scratch.sh)
   - use_pretrained: False -- random initialisation, no torch.hub download
   - base_lr: 1.5e-03 -- ~7.5x higher than the fine-tuning LR
   - warmup_epochs: 20 -- longer warmup for stable early training
   - layerwise_decay: 0.9 -- standard LLRD (fine-tuning uses 1.0)

2. External / DINOv3-checkpoint init (vitg14_reg4_dinov3init.yaml / run_dinov3init.sh)
   New _load_from_external_checkpoint() in train.py auto-detects three checkpoint formats (teacher ckpt, nested backbone, flat state-dict) and loads backbone weights with strict=False. Architecture-incompatible keys (e.g. DINOv3 RoPE tensors) are skipped with a warning. Pass the path via the train.pretrained_weights CLI override or the STAGE1_CHECKPOINT env var.

3. DINOv3-style Gram (patch-similarity) loss (vitg14_reg4_gram.yaml / run_gram.sh)
   New GramLoss module (dinov2/loss/gram_loss.py) ported from MedARC-AI/path-fm-dinov3. Computes the MSE between the pairwise cosine-similarity matrices of student and teacher patch tokens (the key DINOv3 objective). EMA-teacher mode: the existing EMA teacher's patch tokens serve as Gram targets, so no separate frozen model is needed. Enabled with gram.use_loss: true.

The default config (ssl_default_config.yaml) gains:
- train.pretrained_weights: '' (used by the external-checkpoint loader)
- a gram: block with all options defaulting to off
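The three-format auto-detection described in (2) can be sketched roughly as follows. This is an illustrative reconstruction, not the PR's exact code: the helper name `extract_backbone_state_dict` and the exact prefix handling are assumptions.

```python
def extract_backbone_state_dict(ckpt: dict) -> dict:
    """Return flat backbone weights from any of the three layouts.

    Illustrative sketch only -- the helper name and prefix handling
    are assumptions, not the PR's exact implementation.
    """
    if "teacher" in ckpt:
        # Format 1: {"teacher": {"backbone.*": ...}} -- keep only
        # backbone entries and strip the "backbone." prefix.
        return {
            k[len("backbone."):]: v
            for k, v in ckpt["teacher"].items()
            if k.startswith("backbone.")
        }
    for key in ("model", "backbone"):
        # Format 2: weights nested under "model" or "backbone".
        if key in ckpt:
            return ckpt[key]
    # Format 3: already a flat state-dict.
    return ckpt
```

The extracted dict would then be passed to `model.load_state_dict(..., strict=False)`, which is what lets architecture-incompatible keys (e.g. DINOv3 RoPE tensors) be skipped rather than raise.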
Loss values stored in loss_dict are used only for metric logging; keeping the computational graph attached wastes memory. This matches the pattern already used by all other losses in forward_backward().
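A minimal sketch of the pattern this comment describes (the tensor names here are hypothetical; only the `.detach()` call is the point):

```python
import torch

x = torch.ones(3, requires_grad=True)
loss = (x * 2.0).sum()  # still attached to the autograd graph

# Store a detached copy for metric logging, so the graph (and the
# activations it references) can be freed after backward().
loss_dict = {"gram_loss": loss.detach()}
```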
Force-pushed from 615ee02 to d924304
Summary
Implements three alternative training configurations to ablate against the existing DINOv2-pretrained fine-tuning baseline.
1. From-scratch training
New config vitg14_reg4_scratch.yaml / launch script run_scratch.sh.

Key changes vs. the pretrained baseline:
- use_pretrained: False
- base_lr: 1.5e-03 (~7.5x higher, calibrated from the DINOv2 paper's ViT-g numbers scaled to our effective batch size)
- warmup_epochs: 20 (2x longer, for early-training stability)
- layerwise_decay: 0.9 (standard LLRD; the baseline uses 1.0 to preserve pretrained features)

2. External / DINOv3-checkpoint starting point
New config vitg14_reg4_dinov3init.yaml / launch script run_dinov3init.sh (set the STAGE1_CHECKPOINT= env var).

New _load_from_external_checkpoint() in train.py auto-detects three checkpoint formats:
- {"teacher": {"backbone.*": ...}} (OpenMidnight teacher_checkpoint.pth or MedARC-AI/path-fm-dinov3 format)
- {"model": {...}} or {"backbone": {...}}
- a flat state-dict

Loads with strict=False; architecture-incompatible keys (e.g. DINOv3 RoPE tensors) are skipped with a warning.

3. DINOv3-style Gram (patch-similarity) loss
New module dinov2/loss/gram_loss.py (ported from MedARC-AI/path-fm-dinov3).

New config vitg14_reg4_gram.yaml / launch script run_gram.sh (tunable GRAM_LOSS_WEIGHT env var).

The loss computes the MSE between the pairwise cosine-similarity matrices of student and teacher patch tokens - the defining new objective in DINOv3. It uses EMA-teacher mode: the existing EMA teacher's patch tokens serve as Gram targets; no separate frozen model is needed. Enabled with gram.use_loss: true; fully configurable (loss_weight, img_level, normalized, remove_neg).

Files changed
- dinov2/loss/gram_loss.py
- dinov2/loss/__init__.py
- dinov2/configs/ssl_default_config.yaml (train.pretrained_weights default + gram: config block)
- dinov2/train/ssl_meta_arch.py
- dinov2/train/train.py (_load_from_external_checkpoint(), wired into main())
- dinov2/configs/train/vitg14_reg4_scratch.yaml
- dinov2/configs/train/vitg14_reg4_dinov3init.yaml
- dinov2/configs/train/vitg14_reg4_gram.yaml
- run_scratch.sh
- run_gram.sh
- run_dinov3init.sh

Backward compatibility
All changes are fully additive and opt-in:
- gram.use_loss defaults to false - existing runs are unaffected
- train.pretrained_weights defaults to '' - external loading is only triggered when non-empty
- use_pretrained: True (the existing default) continues to use _load_pretrained_backbone unchanged
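For illustration, the Gram objective described above can be sketched in NumPy (the actual GramLoss module is a PyTorch module operating on batched tensors; the shapes, argument names, and the `normalized` flag here are assumptions for the sketch):

```python
import numpy as np

def gram_loss(student_tokens, teacher_tokens, normalized=True):
    """MSE between pairwise cosine-similarity (Gram) matrices.

    Sketch only: student_tokens / teacher_tokens are (N, D) patch
    tokens; the real module works on batched torch tensors.
    """
    def cosine_gram(t):
        if normalized:
            # L2-normalize rows so the Gram entries are cosine similarities.
            t = t / np.linalg.norm(t, axis=-1, keepdims=True)
        return t @ t.T  # (N, N) pairwise similarities

    diff = cosine_gram(student_tokens) - cosine_gram(teacher_tokens)
    return float(np.mean(diff ** 2))
```

In EMA-teacher mode the teacher tokens come from the existing EMA teacher, which is why no separate frozen model is needed.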