Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions examples/weather/stormcast/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@
import pytest
import torch
from torch.distributed.checkpoint.state_dict import get_state_dict, StateDictOptions
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
StateDictType,
ShardedStateDictConfig,
ShardedOptimStateDictConfig,
)
from torch.distributed.tensor import DTensor

from physicsnemo.distributed import DistributedManager
Expand Down
46 changes: 28 additions & 18 deletions examples/weather/stormcast/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@

import numpy as np
import torch
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
BackwardPrefetch,
)
from torch.distributed.fsdp import FSDPModule, fully_shard
from torch.distributed.tensor import DTensor, distribute_module, distribute_tensor
from torch.distributed.tensor.placement_types import Replicate, Shard

Expand Down Expand Up @@ -236,8 +232,17 @@ def distribute_tensor(self, x: torch.Tensor) -> ShardTensor:
else:
return x

def distribute_model(self, model: torch.nn.Module) -> FSDP:
"""Shard model parameters across the domain mesh and wrap with FSDP.
def distribute_model(self, model: torch.nn.Module) -> FSDPModule:
"""Shard model parameters with FSDP2 (``fully_shard``).

Parameters that are already DTensors from ``distribute_module`` (when
``use_shard_tensor`` is True) are sharded on the domain mesh; FSDP2
then additionally shards across the data-parallel mesh, producing
2D-mesh DTensor parameters.

Identical parameter initialization across ranks is assumed (the
trainer sets ``torch.manual_seed`` before model construction); FSDP2
does not perform a sync-from-rank-0 broadcast on its own.

Parameters
----------
Expand All @@ -246,24 +251,29 @@ def distribute_model(self, model: torch.nn.Module) -> FSDP:

Returns
-------
torch.distributed.fsdp.FullyShardedDataParallel
Distributed model wrapper.
torch.distributed.fsdp.FSDPModule
The input model, now an ``FSDPModule`` with sharded parameters.
"""
if self.use_shard_tensor:
model = distribute_module(
model,
device_mesh=self.mesh["domain"],
partition_fn=partition_model_selective,
)
return FSDP(
model,
device_mesh=self.mesh["ddp"],
use_orig_params=False, # Required for use with ShardTensor
sharding_strategy=ShardingStrategy.NO_SHARD,
sync_module_states=True, # Ensure initialized weights match across ranks
forward_prefetch=True, # Optimization for faster training
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Backward prefetching for overlap
)
# FSDP2 rejects non-contiguous parameters (PyTorch <= 2.10):
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: this block exists solely for backward compatibility with PyTorch <= 2.10. Do we care about backward compatibility?

# NotImplementedError: FSDP does not support non-contiguous parameters
# Models created with ``.to(memory_format=torch.channels_last)`` have
# 4D params with channels_last strides. Force standard contiguity on
# the parameter storage — kernels still convert activations to
# channels_last when inputs arrive in that layout, so the perf win is
# retained.
with torch.no_grad():
for p in model.parameters():
if p.is_contiguous():
continue
p.data = p.data.contiguous()
Comment on lines +270 to +274
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 When use_shard_tensor=True, distribute_module has already been called and model.parameters() yields DTensor-backed nn.Parameters. Assigning p.data = p.data.contiguous() on a DTensor parameter is not documented PyTorch API; Tensor.set_() (which backs the .data setter) with a DTensor argument may silently strip the DTensor's mesh/placements metadata, breaking the subsequent fully_shard call. In practice distribute_tensor normalises contiguity internally so the guard p.is_contiguous() is usually True for DTensor params and the assignment is skipped — but making the skip explicit prevents a silent breakage if that behaviour changes.

Suggested change
with torch.no_grad():
for p in model.parameters():
if p.is_contiguous():
continue
p.data = p.data.contiguous()
with torch.no_grad():
for p in model.parameters():
if isinstance(p.data, DTensor):
continue # distribute_module already normalises DTensor local shards
if p.is_contiguous():
continue
p.data = p.data.contiguous()

fully_shard(model, mesh=self.mesh["ddp"])
return model

def make_domain_parallel_scheduler(self, scheduler: object) -> object:
"""Wrap a noise scheduler for domain-parallel diffusion.
Expand Down
9 changes: 5 additions & 4 deletions examples/weather/stormcast/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ def __init__(self, cfg: DictConfig):
self._setup_data()

# All ranks use the same seed so parameter initialization is identical.
# FSDP sync_module_states and distribute_tensor also broadcast from
# rank 0, but explicit seeding avoids silent dependence on those.
# FSDP2 (fully_shard) does not broadcast initial weights from rank 0,
# so this deterministic seeding is what keeps the unsharded parameter
# values consistent across ranks.
torch.manual_seed(self.cfg.training.seed)

# Create model and move to device
Expand All @@ -147,10 +148,10 @@ def __init__(self, cfg: DictConfig):
# Sharding and FSDP wrapping
if self.use_shard_tensor:
self.logger.info(
"Distributing model with FSDP and sharding for domain parallelism"
"Distributing model with FSDP2 and sharding for domain parallelism"
)
else:
self.logger.info("Distributing model with FSDP")
self.logger.info("Distributing model with FSDP2")
self.net = self.parallel_helper.distribute_model(self.net)
if self.regression_net is not None:
self.regression_net = self.parallel_helper.distribute_model(
Expand Down
65 changes: 55 additions & 10 deletions physicsnemo/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
set_model_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FSDPModule, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.tensor import DTensor, distribute_tensor
from torch.optim.lr_scheduler import LRScheduler
Expand Down Expand Up @@ -94,12 +94,35 @@ def _unwrap_ddp_compile(


def _unwrap_fsdp(model: torch.nn.Module) -> torch.nn.Module:
    """Return the user module behind *model*, peeling at most one FSDP1 layer.

    A ``FullyShardedDataParallel`` (FSDP1) wrapper keeps the user module in
    its ``module`` attribute, and that attribute is what gets returned.
    Every other module is handed back untouched. That includes FSDP2
    modules: ``fully_shard`` mutates the module's ``__class__`` in place, so
    the user module and the sharded module are one and the same object and
    there is nothing to unwrap. Callers that need the *original* class name
    should use :func:`_unwrapped_class_name` rather than
    ``type(...).__name__``.

    Parameters
    ----------
    model : torch.nn.Module
        Possibly FSDP1-wrapped module.

    Returns
    -------
    torch.nn.Module
        ``model.module`` for an FSDP1 wrapper, otherwise ``model`` itself.
    """
    return model.module if isinstance(model, FSDP) else model


def _unwrapped_class_name(model: torch.nn.Module) -> str:
    """Return the user-facing class name, peeling FSDP1/FSDP2 wrappers.

    FSDP2's ``fully_shard`` rebinds ``model.__class__`` to a dynamically
    generated ``FSDP{ClassName}`` subclass that derives from both
    ``FSDPModule`` and the original class. Stripping the ``FSDP`` prefix
    isn't reliable because user classes may legitimately start with
    ``FSDP``. Indexing ``__bases__`` positionally is also fragile: it
    silently breaks if the base order changes or a mixin is inserted.
    Instead, the MRO of the generated class is walked and the first
    ``torch.nn.Module`` subclass that is not itself FSDP-derived is taken
    as the original user class.

    Parameters
    ----------
    model : torch.nn.Module
        Module that may be FSDP1-wrapped, FSDP2-sharded, or plain.

    Returns
    -------
    str
        ``__name__`` of the original user class when it can be identified,
        otherwise ``__name__`` of the (unwrapped) module's current class.
    """
    inner = _unwrap_fsdp(model)
    if isinstance(inner, FSDPModule):
        for cls in type(inner).__mro__:
            # Skip the generated class and FSDPModule itself (both are
            # FSDPModule subclasses) as well as any non-module mixins; the
            # first remaining nn.Module subclass is the user's class.
            if issubclass(cls, torch.nn.Module) and not issubclass(
                cls, FSDPModule
            ):
                return cls.__name__
    # Not an FSDP2 module, or no module base could be identified.
    return type(inner).__name__
Comment on lines +119 to +123
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Fragile MRO assumption for FSDP2 class name

_unwrapped_class_name returns bases[1].__name__ only when bases[0] is FSDPModule. If a future PyTorch version changes the order of bases or introduces an intermediate mixin in the dynamically-generated class (e.g. (FSDPModule, SomeMixin, OriginalCls)), the condition bases[0] is FSDPModule still holds but bases[1].__name__ would return SomeMixin instead of the real user class, silently generating the wrong checkpoint filename. Using type(inner).__mro__ to find the first non-FSDPModule/torch.nn.Module base would be more resilient.



def _cpu_offload_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
"""Move every tensor in *state_dict* to CPU (shallow copy)."""
out: dict[str, Any] = {}
Expand Down Expand Up @@ -165,16 +188,38 @@ def _has_non_fsdp_dtensors(
model: torch.nn.Module,
dtensor_plc: dict[str, tuple[Any, tuple[Any, ...]]],
) -> bool:
"""Return ``True`` when *dtensor_plc* contains placements not managed by FSDP.

FSDP with ``FULL_SHARD`` or ``SHARD_GRAD_OP`` wraps parameters as
DTensors on its own mesh. ``broadcast_from_rank0`` handles these
natively, so manual redistribution should be skipped. Only
user-created DTensors (e.g. ShardTensor on a separate domain mesh)
require explicit redistribution.
"""Return ``True`` when *dtensor_plc* requires the manual broadcast path.

The function answers a single operational question: does loading this
model's state dict need to bypass DCP's ``broadcast_from_rank0`` in
favour of the explicit ``broadcast_object_list`` /
``_redistribute_sd_for_dtensor`` path? The answer is ``True`` for:

* Plain (non-FSDP) modules holding DTensors (e.g. user-created ShardTensors
on a domain mesh) — DCP's broadcast cannot rebuild user placements.
* ``FSDP`` (FSDP1) with ``NO_SHARD`` — equivalent to plain DDP, no
FSDP-managed DTensors to be aware of.
* ``FSDPModule`` (FSDP2) whose parameters sit on a mesh with any
*degenerate* dimension (``size == 1``). In that case DCP's underlying
``dist.broadcast`` raises
``RuntimeError: found no DeviceMesh from dtensor args for c10d::broadcast_``
because c10d cannot dispatch through a size-1 mesh axis. The same
configuration also breaks DCP's optimizer load with
``KeyError: 'state.0.step'`` (the freshly-constructed live optimizer has
no materialised ``state``, so DCP's flatten-mapping is missing the
checkpoint's ``state.X.step`` keys). Both surface when ``fully_shard``
runs on world_size == 1, or on a 2D mesh whose ``ddp`` axis collapses
to 1 (e.g. ``(ddp=1, domain=N)`` after batch_size==1 on N ranks).
The manual path sidesteps both — it broadcasts the dict as Python
objects and then sets state with ``options=full_options`` (no DCP
broadcast, no unflatten round-trip).
"""
if not dtensor_plc:
return False
if isinstance(model, FSDPModule):
return any(
any(s == 1 for s in mesh.shape) for mesh, _ in dtensor_plc.values()
)
if not isinstance(model, FSDP):
return True
if model.sharding_strategy == ShardingStrategy.NO_SHARD:
Expand Down Expand Up @@ -603,7 +648,7 @@ def _unique_model_names(
model_dict: dict[str, list[torch.nn.Module]] = {}
for model0 in models:
model0 = _unwrap_ddp_compile(model0, loading=loading)
base_name = type(_unwrap_fsdp(model0)).__name__
base_name = _unwrapped_class_name(model0)
if base_name in model_dict:
model_dict[base_name].append(model0)
else:
Expand Down
Loading