Fsdp2 stormscope [WIP]#1671

Open
negin513 wants to merge 4 commits into
NVIDIA:mainfrom
negin513:fsdp2-stormcast

@negin513
Member

PhysicsNeMo Pull Request

Migrates StormScope off FSDP1 onto FSDP2 (fully_shard / FSDPModule)....

Description

FSDP1's flat-param machinery doesn't compose with ShardTensor / DTensor.

This is the immediate motivator: the refactored ShardTensor in #1556 breaks FSDP1's backward pass in StormScope. Until StormScope/StormCast move to FSDP2, domain-parallel implementations cannot be combined with FSDP and have to run on DDP alone. The current implementation is DDP only.

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@negin513 negin513 requested a review from CharlelieLrt as a code owner May 26, 2026 17:09
@copy-pr-bot

copy-pr-bot Bot commented May 26, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

forward_prefetch=True, # Optimization for faster training
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Backward prefetching for overlap
)
# FSDP2 rejects non-contiguous parameters (PyTorch <= 2.10):
Member Author

NOTE: this block exists solely for backward compatibility with PyTorch <= 2.10. Do we care about backward compatibility?

@greptile-apps
Contributor

greptile-apps Bot commented May 26, 2026

Greptile Summary

This PR migrates StormScope/StormCast from FSDP1 (FullyShardedDataParallel with NO_SHARD) to FSDP2 (fully_shard / FSDPModule), motivated by FSDP1's flat-param machinery being incompatible with the refactored ShardTensor/DTensor in #1556.

  • parallel.py: replaces the FSDP(NO_SHARD) wrapper with fully_shard, adds a pre-shard loop to force standard contiguity on channels-last parameters (FSDP2 rejects non-contiguous params ≤ PyTorch 2.10), and updates the docstring/return type to FSDPModule.
  • checkpoint.py: adds _unwrapped_class_name to recover the original class name from FSDP2's dynamically-generated subclass, extends _has_non_fsdp_dtensors with a degenerate-mesh guard for FSDP2 (size-1 mesh axes break DCP's broadcast), and threads both helpers through _unique_model_names.

Important Files Changed

| Filename | Overview |
| --- | --- |
| physicsnemo/utils/checkpoint.py | Adds FSDPModule (FSDP2) support: new _unwrapped_class_name helper, extended _has_non_fsdp_dtensors logic for degenerate meshes, and updated _unique_model_names to use the new helper; _is_distributed_model does not add an explicit FSDPModule check. |
| examples/weather/stormcast/utils/parallel.py | Replaces FSDP1 (NO_SHARD) with FSDP2 fully_shard; adds a pre-shard contiguity normalization loop that also iterates over DTensor parameters when use_shard_tensor=True. |
| examples/weather/stormcast/utils/trainer.py | Minor comment and log-message updates to reflect FSDP2 terminology; no logic changes. |
| examples/weather/stormcast/test_training.py | Removes now-unused FSDP1 imports (StateDictType, ShardedStateDictConfig, ShardedOptimStateDictConfig); no test logic changes. |

Comments Outside Diff (1)

  1. physicsnemo/utils/checkpoint.py, line 70-74 (link)

    P2 _is_distributed_model relies solely on the DTensor parameter check to detect FSDP2 models. In practice fully_shard always converts parameters to DTensors, so this works, but a model with no learnable parameters would return False even after FSDP2 wrapping, silently routing it through the non-distributed checkpoint path. Adding an explicit FSDPModule branch makes the intent clear and is defensive against edge cases.
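A minimal sketch of the explicit branch suggested above. It does not reproduce the real _is_distributed_model, whose body is not shown in this thread. FSDPModule and DTensor here are empty stand-in classes (assumptions, so the snippet runs without a distributed setup), and ParamFreeModel, WrappedParamFreeModel, and both function names are hypothetical.

```python
class FSDPModule:
    """Stand-in for torch.distributed.fsdp.FSDPModule (assumption, not the real class)."""


class DTensor:
    """Stand-in for the DTensor parameter type (assumption, not the real class)."""


class ParamFreeModel:
    """Hypothetical model with no learnable parameters."""

    def parameters(self):
        return iter(())


class WrappedParamFreeModel(FSDPModule, ParamFreeModel):
    """Mimics a parameter-free model after FSDP2 wrapping."""


def is_distributed_by_params(model):
    # Parameter-only detection: True if any parameter is a DTensor.
    return any(isinstance(p, DTensor) for p in model.parameters())


def is_distributed_explicit(model):
    # Explicit wrapper check first, then fall back to the parameter scan.
    if isinstance(model, FSDPModule):
        return True
    return is_distributed_by_params(model)


wrapped = WrappedParamFreeModel()
params_only_result = is_distributed_by_params(wrapped)
explicit_result = is_distributed_explicit(wrapped)
```

With the stand-ins above, the parameter-only check misses the wrapped, parameter-free model while the explicit check catches it, which is the edge case the comment describes.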

Reviews (1): Last reviewed commit: "improving checkpoint"

Comment on lines +119 to +123
    if isinstance(inner, FSDPModule):
        bases = type(inner).__bases__
        if len(bases) >= 2 and bases[0] is FSDPModule:
            return bases[1].__name__
    return type(inner).__name__
Contributor

P2 Fragile MRO assumption for FSDP2 class name

_unwrapped_class_name returns bases[1].__name__ only when bases[0] is FSDPModule. If a future PyTorch version changes the order of bases or introduces an intermediate mixin in the dynamically-generated class (e.g. (FSDPModule, SomeMixin, OriginalCls)), the condition bases[0] is FSDPModule still holds but bases[1].__name__ would return SomeMixin instead of the real user class, silently generating the wrong checkpoint filename. Using type(inner).__mro__ to find the first non-FSDPModule/torch.nn.Module base would be more resilient.
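One way to read the MRO suggestion, as a sketch only: walk type(inner).__mro__ and return the first class that is a module subclass but is neither the FSDP wrapper nor the module base itself. FSDPModule and Module below are stand-ins for the PyTorch classes (assumptions, so the snippet runs without torch), and StormCastUNet, SomeMixin, and unwrapped_class_name are hypothetical names.

```python
class FSDPModule:
    """Stand-in for torch.distributed.fsdp.FSDPModule (assumption)."""


class Module:
    """Stand-in for torch.nn.Module (assumption)."""


class StormCastUNet(Module):
    """Hypothetical user model class."""


class SomeMixin:
    """Hypothetical mixin that a future release might insert."""


def unwrapped_class_name(inner):
    for cls in type(inner).__mro__:
        if issubclass(cls, FSDPModule):
            continue  # skips the dynamic wrapper class and FSDPModule itself
        if cls is not Module and issubclass(cls, Module):
            return cls.__name__  # first real user module class
    return type(inner).__name__  # nothing matched: fall back to the runtime type


# Dynamically generated wrapper classes, with and without an extra mixin.
two_bases = type("FSDPStormCastUNet", (FSDPModule, StormCastUNet), {})()
three_bases = type("FSDPStormCastUNet", (FSDPModule, SomeMixin, StormCastUNet), {})()

name_two = unwrapped_class_name(two_bases)
name_three = unwrapped_class_name(three_bases)
name_plain = unwrapped_class_name(StormCastUNet())
```

Requiring a module subclass, rather than just "not FSDPModule", is a design choice of this sketch: it is what lets the walk step over a non-module mixin such as SomeMixin.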

Comment on lines +270 to +274
    with torch.no_grad():
        for p in model.parameters():
            if p.is_contiguous():
                continue
            p.data = p.data.contiguous()
Contributor

P2 When use_shard_tensor=True, distribute_module has already been called and model.parameters() yields DTensor-backed nn.Parameters. Assigning p.data = p.data.contiguous() on a DTensor parameter is not documented PyTorch API; Tensor.set_() (which backs the .data setter) with a DTensor argument may silently strip the DTensor's mesh/placements metadata, breaking the subsequent fully_shard call. In practice distribute_tensor normalises contiguity internally so the guard p.is_contiguous() is usually True for DTensor params and the assignment is skipped — but making the skip explicit prevents a silent breakage if that behaviour changes.

Suggested change
-    with torch.no_grad():
-        for p in model.parameters():
-            if p.is_contiguous():
-                continue
-            p.data = p.data.contiguous()
+    with torch.no_grad():
+        for p in model.parameters():
+            if isinstance(p.data, DTensor):
+                continue  # distribute_module already normalises DTensor local shards
+            if p.is_contiguous():
+                continue
+            p.data = p.data.contiguous()
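For context, the contiguity loop discussed in this thread can be tried on a plain, non-distributed module. This is a sketch under assumptions: it needs only torch, leaves out fully_shard and the DTensor skip entirely, and normalize_param_contiguity is a hypothetical helper name rather than a function from the PR.

```python
import torch


def normalize_param_contiguity(model: torch.nn.Module) -> int:
    """Force standard (row-major) contiguity on every parameter; return how many changed."""
    fixed = 0
    with torch.no_grad():
        for p in model.parameters():
            if p.is_contiguous():
                continue
            # Same assignment as in the diff: swap in a contiguous copy of the data.
            p.data = p.data.contiguous()
            fixed += 1
    return fixed


# A channels-last conv layer: its 4-D weight is not contiguous in the default layout.
conv = torch.nn.Conv2d(3, 8, kernel_size=3).to(memory_format=torch.channels_last)
weight_was_contiguous = conv.weight.is_contiguous()

num_fixed = normalize_param_contiguity(conv)
all_contiguous = all(p.is_contiguous() for p in conv.parameters())
```

Only the 4-D weight is affected here; the 1-D bias is already contiguous, so the loop skips it.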

@negin513 negin513 changed the title Fsdp2 stormscope Fsdp2 stormscope [WIP] May 26, 2026