diff --git a/examples/weather/stormcast/test_training.py b/examples/weather/stormcast/test_training.py index 34e4af98ce..e286a1ee54 100644 --- a/examples/weather/stormcast/test_training.py +++ b/examples/weather/stormcast/test_training.py @@ -24,12 +24,6 @@ import pytest import torch from torch.distributed.checkpoint.state_dict import get_state_dict, StateDictOptions -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ( - StateDictType, - ShardedStateDictConfig, - ShardedOptimStateDictConfig, -) from torch.distributed.tensor import DTensor from physicsnemo.distributed import DistributedManager diff --git a/examples/weather/stormcast/utils/parallel.py b/examples/weather/stormcast/utils/parallel.py index aad080f1f8..b04ea1c318 100644 --- a/examples/weather/stormcast/utils/parallel.py +++ b/examples/weather/stormcast/utils/parallel.py @@ -21,11 +21,7 @@ import numpy as np import torch -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - ShardingStrategy, - BackwardPrefetch, -) +from torch.distributed.fsdp import FSDPModule, fully_shard from torch.distributed.tensor import DTensor, distribute_module, distribute_tensor from torch.distributed.tensor.placement_types import Replicate, Shard @@ -236,8 +232,17 @@ def distribute_tensor(self, x: torch.Tensor) -> ShardTensor: else: return x - def distribute_model(self, model: torch.nn.Module) -> FSDP: - """Shard model parameters across the domain mesh and wrap with FSDP. + def distribute_model(self, model: torch.nn.Module) -> FSDPModule: + """Shard model parameters with FSDP2 (``fully_shard``). + + Parameters that are already DTensors from ``distribute_module`` (when + ``use_shard_tensor`` is True) are sharded on the domain mesh; FSDP2 + then additionally shards across the data-parallel mesh, producing + 2D-mesh DTensor parameters. + + Identical parameter initialization across ranks is assumed (the + trainer sets ``torch.manual_seed`` before model construction); FSDP2 + does not perform a sync-from-rank-0 broadcast on its own. Parameters ---------- @@ -246,8 +251,8 @@ def distribute_model(self, model: torch.nn.Module) -> FSDP: Returns ------- - torch.distributed.fsdp.FullyShardedDataParallel - Distributed model wrapper. + torch.distributed.fsdp.FSDPModule + The input model, now an ``FSDPModule`` with sharded parameters. """ if self.use_shard_tensor: model = distribute_module( @@ -255,15 +260,20 @@ def distribute_model(self, model: torch.nn.Module) -> FSDP: device_mesh=self.mesh["domain"], partition_fn=partition_model_selective, ) - return FSDP( - model, - device_mesh=self.mesh["ddp"], - use_orig_params=False, # Required for use with ShardTensor - sharding_strategy=ShardingStrategy.NO_SHARD, - sync_module_states=True, # Ensure initialized weights match across ranks - forward_prefetch=True, # Optimization for faster training - backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Backward prefetching for overlap - ) + # FSDP2 rejects non-contiguous parameters (PyTorch <= 2.10): + # NotImplementedError: FSDP does not support non-contiguous parameters + # Models created with ``.to(memory_format=torch.channels_last)`` have + # 4D params with channels_last strides. Force standard contiguity on + # the parameter storage — kernels still convert activations to + # channels_last when inputs arrive in that layout, so the perf win is + # retained. + with torch.no_grad(): + for p in model.parameters(): + if p.is_contiguous(): + continue + p.data = p.data.contiguous() + fully_shard(model, mesh=self.mesh["ddp"]) + return model def make_domain_parallel_scheduler(self, scheduler: object) -> object: """Wrap a noise scheduler for domain-parallel diffusion. diff --git a/examples/weather/stormcast/utils/trainer.py b/examples/weather/stormcast/utils/trainer.py index a61b723b90..6b0e29bd01 100644 --- a/examples/weather/stormcast/utils/trainer.py +++ b/examples/weather/stormcast/utils/trainer.py @@ -130,8 +130,9 @@ def __init__(self, cfg: DictConfig): self._setup_data() # All ranks use the same seed so parameter initialization is identical. - # FSDP sync_module_states and distribute_tensor also broadcast from - # rank 0, but explicit seeding avoids silent dependence on those. + # FSDP2 (fully_shard) does not broadcast initial weights from rank 0, + # so this deterministic seeding is what keeps the unsharded parameter + # values consistent across ranks. torch.manual_seed(self.cfg.training.seed) # Create model and move to device @@ -147,10 +148,10 @@ def __init__(self, cfg: DictConfig): # Sharding and FSDP wrapping if self.use_shard_tensor: self.logger.info( - "Distributing model with FSDP and sharding for domain parallelism" + "Distributing model with FSDP2 and sharding for domain parallelism" ) else: - self.logger.info("Distributing model with FSDP") + self.logger.info("Distributing model with FSDP2") self.net = self.parallel_helper.distribute_model(self.net) if self.regression_net is not None: self.regression_net = self.parallel_helper.distribute_model( diff --git a/physicsnemo/utils/checkpoint.py b/physicsnemo/utils/checkpoint.py index df3390dca3..47e109b77c 100644 --- a/physicsnemo/utils/checkpoint.py +++ b/physicsnemo/utils/checkpoint.py @@ -47,7 +47,7 @@ set_model_state_dict, set_optimizer_state_dict, ) -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import FSDPModule, FullyShardedDataParallel as FSDP from torch.distributed.fsdp import ShardingStrategy from torch.distributed.tensor import DTensor, distribute_tensor from torch.optim.lr_scheduler import LRScheduler @@ -94,12 +94,35 @@ def _unwrap_ddp_compile( def _unwrap_fsdp(model: torch.nn.Module) -> torch.nn.Module: - """Unwrap one FSDP layer (if present) to reach the user module.""" + """Unwrap one FSDP layer (if present) to reach the user module. + + For FSDP1 this returns ``model.module`` (the user module inside the + ``FullyShardedDataParallel`` wrapper). For FSDP2 the user module IS the + same object — ``fully_shard`` mutates its ``__class__`` in place — so + nothing to unwrap; callers that need the *original* class name should use + :func:`_unwrapped_class_name` rather than ``type(...).__name__``. + """ if isinstance(model, FSDP): return model.module return model +def _unwrapped_class_name(model: torch.nn.Module) -> str: + """Return the user-facing class name, peeling FSDP1/FSDP2 wrappers. + + FSDP2's ``fully_shard`` rebinds ``model.__class__`` to a dynamically + generated ``FSDP{ClassName}`` subclass with bases ``(FSDPModule, original_cls)``. + Stripping the prefix isn't reliable because user classes may legitimately + start with ``FSDP``; instead we read ``__bases__[1]`` directly. + """ + inner = _unwrap_fsdp(model) + if isinstance(inner, FSDPModule): + bases = type(inner).__bases__ + if len(bases) >= 2 and bases[0] is FSDPModule: + return bases[1].__name__ + return type(inner).__name__ + + def _cpu_offload_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]: """Move every tensor in *state_dict* to CPU (shallow copy).""" out: dict[str, Any] = {} @@ -165,16 +188,38 @@ def _has_non_fsdp_dtensors( model: torch.nn.Module, dtensor_plc: dict[str, tuple[Any, tuple[Any, ...]]], ) -> bool: - """Return ``True`` when *dtensor_plc* contains placements not managed by FSDP. - - FSDP with ``FULL_SHARD`` or ``SHARD_GRAD_OP`` wraps parameters as - DTensors on its own mesh. ``broadcast_from_rank0`` handles these - natively, so manual redistribution should be skipped. Only - user-created DTensors (e.g. ShardTensor on a separate domain mesh) - require explicit redistribution. + """Return ``True`` when *dtensor_plc* requires the manual broadcast path. + + The function answers a single operational question: does loading this + model's state dict need to bypass DCP's ``broadcast_from_rank0`` in + favour of the explicit ``broadcast_object_list`` / + ``_redistribute_sd_for_dtensor`` path? The answer is ``True`` for: + + * Plain (non-FSDP) modules holding DTensors (e.g. user-created ShardTensors + on a domain mesh) — DCP's broadcast cannot rebuild user placements. + * ``FSDP`` (FSDP1) with ``NO_SHARD`` — equivalent to plain DDP, no + FSDP-managed DTensors to be aware of. + * ``FSDPModule`` (FSDP2) whose parameters sit on a mesh with any + *degenerate* dimension (``size == 1``). In that case DCP's underlying + ``dist.broadcast`` raises + ``RuntimeError: found no DeviceMesh from dtensor args for c10d::broadcast_`` + because c10d cannot dispatch through a size-1 mesh axis. The same + configuration also breaks DCP's optimizer load with + ``KeyError: 'state.0.step'`` (the freshly-constructed live optimizer has + no materialised ``state``, so DCP's flatten-mapping is missing the + checkpoint's ``state.X.step`` keys). Both surface when ``fully_shard`` + runs on world_size == 1, or on a 2D mesh whose ``ddp`` axis collapses + to 1 (e.g. ``(ddp=1, domain=N)`` after batch_size==1 on N ranks). + The manual path sidesteps both — it broadcasts the dict as Python + objects and then sets state with ``options=full_options`` (no DCP + broadcast, no unflatten round-trip). """ if not dtensor_plc: return False + if isinstance(model, FSDPModule): + return any( + any(s == 1 for s in mesh.shape) for mesh, _ in dtensor_plc.values() + ) if not isinstance(model, FSDP): return True if model.sharding_strategy == ShardingStrategy.NO_SHARD: @@ -603,7 +648,7 @@ def _unique_model_names( model_dict: dict[str, list[torch.nn.Module]] = {} for model0 in models: model0 = _unwrap_ddp_compile(model0, loading=loading) - base_name = type(_unwrap_fsdp(model0)).__name__ + base_name = _unwrapped_class_name(model0) if base_name in model_dict: model_dict[base_name].append(model0) else: