-
Notifications
You must be signed in to change notification settings - Fork 674
Fsdp2 stormscope [WIP] #1671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fsdp2 stormscope [WIP] #1671
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -21,11 +21,7 @@ | |||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||
| from torch.distributed.fsdp import ( | ||||||||||||||||||||||||||
| FullyShardedDataParallel as FSDP, | ||||||||||||||||||||||||||
| ShardingStrategy, | ||||||||||||||||||||||||||
| BackwardPrefetch, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| from torch.distributed.fsdp import FSDPModule, fully_shard | ||||||||||||||||||||||||||
| from torch.distributed.tensor import DTensor, distribute_module, distribute_tensor | ||||||||||||||||||||||||||
| from torch.distributed.tensor.placement_types import Replicate, Shard | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -236,8 +232,17 @@ def distribute_tensor(self, x: torch.Tensor) -> ShardTensor: | |||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| return x | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def distribute_model(self, model: torch.nn.Module) -> FSDP: | ||||||||||||||||||||||||||
| """Shard model parameters across the domain mesh and wrap with FSDP. | ||||||||||||||||||||||||||
| def distribute_model(self, model: torch.nn.Module) -> FSDPModule: | ||||||||||||||||||||||||||
| """Shard model parameters with FSDP2 (``fully_shard``). | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Parameters that are already DTensors from ``distribute_module`` (when | ||||||||||||||||||||||||||
| ``use_shard_tensor`` is True) are sharded on the domain mesh; FSDP2 | ||||||||||||||||||||||||||
| then additionally shards across the data-parallel mesh, producing | ||||||||||||||||||||||||||
| 2D-mesh DTensor parameters. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Identical parameter initialization across ranks is assumed (the | ||||||||||||||||||||||||||
| trainer sets ``torch.manual_seed`` before model construction); FSDP2 | ||||||||||||||||||||||||||
| does not perform a sync-from-rank-0 broadcast on its own. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||
|
|
@@ -246,24 +251,29 @@ def distribute_model(self, model: torch.nn.Module) -> FSDP: | |||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||
| torch.distributed.fsdp.FullyShardedDataParallel | ||||||||||||||||||||||||||
| Distributed model wrapper. | ||||||||||||||||||||||||||
| torch.distributed.fsdp.FSDPModule | ||||||||||||||||||||||||||
| The input model, now an ``FSDPModule`` with sharded parameters. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| if self.use_shard_tensor: | ||||||||||||||||||||||||||
| model = distribute_module( | ||||||||||||||||||||||||||
| model, | ||||||||||||||||||||||||||
| device_mesh=self.mesh["domain"], | ||||||||||||||||||||||||||
| partition_fn=partition_model_selective, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| return FSDP( | ||||||||||||||||||||||||||
| model, | ||||||||||||||||||||||||||
| device_mesh=self.mesh["ddp"], | ||||||||||||||||||||||||||
| use_orig_params=False, # Required for use with ShardTensor | ||||||||||||||||||||||||||
| sharding_strategy=ShardingStrategy.NO_SHARD, | ||||||||||||||||||||||||||
| sync_module_states=True, # Ensure initialized weights match across ranks | ||||||||||||||||||||||||||
| forward_prefetch=True, # Optimization for faster training | ||||||||||||||||||||||||||
| backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Backward prefetching for overlap | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| # FSDP2 rejects non-contiguous parameters (PyTorch <= 2.10): | ||||||||||||||||||||||||||
| # NotImplementedError: FSDP does not support non-contiguous parameters | ||||||||||||||||||||||||||
| # Models created with ``.to(memory_format=torch.channels_last)`` have | ||||||||||||||||||||||||||
| # 4D params with channels_last strides. Force standard contiguity on | ||||||||||||||||||||||||||
| # the parameter storage — kernels still convert activations to | ||||||||||||||||||||||||||
| # channels_last when inputs arrive in that layout, so the perf win is | ||||||||||||||||||||||||||
| # retained. | ||||||||||||||||||||||||||
| with torch.no_grad(): | ||||||||||||||||||||||||||
| for p in model.parameters(): | ||||||||||||||||||||||||||
| if p.is_contiguous(): | ||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||
| p.data = p.data.contiguous() | ||||||||||||||||||||||||||
|
Comment on lines
+270
to
+274
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
| fully_shard(model, mesh=self.mesh["ddp"]) | ||||||||||||||||||||||||||
| return model | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def make_domain_parallel_scheduler(self, scheduler: object) -> object: | ||||||||||||||||||||||||||
| """Wrap a noise scheduler for domain-parallel diffusion. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,7 +47,7 @@ | |
| set_model_state_dict, | ||
| set_optimizer_state_dict, | ||
| ) | ||
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | ||
| from torch.distributed.fsdp import FSDPModule, FullyShardedDataParallel as FSDP | ||
| from torch.distributed.fsdp import ShardingStrategy | ||
| from torch.distributed.tensor import DTensor, distribute_tensor | ||
| from torch.optim.lr_scheduler import LRScheduler | ||
|
|
@@ -94,12 +94,35 @@ def _unwrap_ddp_compile( | |
|
|
||
|
|
||
| def _unwrap_fsdp(model: torch.nn.Module) -> torch.nn.Module: | ||
| """Unwrap one FSDP layer (if present) to reach the user module.""" | ||
| """Unwrap one FSDP layer (if present) to reach the user module. | ||
|
|
||
| For FSDP1 this returns ``model.module`` (the user module inside the | ||
| ``FullyShardedDataParallel`` wrapper). For FSDP2 the user module IS the | ||
| same object — ``fully_shard`` mutates its ``__class__`` in place — so | ||
| nothing to unwrap; callers that need the *original* class name should use | ||
| :func:`_unwrapped_class_name` rather than ``type(...).__name__``. | ||
| """ | ||
| if isinstance(model, FSDP): | ||
| return model.module | ||
| return model | ||
|
|
||
|
|
||
| def _unwrapped_class_name(model: torch.nn.Module) -> str: | ||
| """Return the user-facing class name, peeling FSDP1/FSDP2 wrappers. | ||
|
|
||
| FSDP2's ``fully_shard`` rebinds ``model.__class__`` to a dynamically | ||
| generated ``FSDP{ClassName}`` subclass with bases ``(FSDPModule, original_cls)``. | ||
| Stripping the prefix isn't reliable because user classes may legitimately | ||
| start with ``FSDP``; instead we read ``__bases__[1]`` directly. | ||
| """ | ||
| inner = _unwrap_fsdp(model) | ||
| if isinstance(inner, FSDPModule): | ||
| bases = type(inner).__bases__ | ||
| if len(bases) >= 2 and bases[0] is FSDPModule: | ||
| return bases[1].__name__ | ||
| return type(inner).__name__ | ||
|
Comment on lines
+119
to
+123
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
|
|
||
| def _cpu_offload_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]: | ||
| """Move every tensor in *state_dict* to CPU (shallow copy).""" | ||
| out: dict[str, Any] = {} | ||
|
|
@@ -165,16 +188,38 @@ def _has_non_fsdp_dtensors( | |
| model: torch.nn.Module, | ||
| dtensor_plc: dict[str, tuple[Any, tuple[Any, ...]]], | ||
| ) -> bool: | ||
| """Return ``True`` when *dtensor_plc* contains placements not managed by FSDP. | ||
|
|
||
| FSDP with ``FULL_SHARD`` or ``SHARD_GRAD_OP`` wraps parameters as | ||
| DTensors on its own mesh. ``broadcast_from_rank0`` handles these | ||
| natively, so manual redistribution should be skipped. Only | ||
| user-created DTensors (e.g. ShardTensor on a separate domain mesh) | ||
| require explicit redistribution. | ||
| """Return ``True`` when *dtensor_plc* requires the manual broadcast path. | ||
|
|
||
| The function answers a single operational question: does loading this | ||
| model's state dict need to bypass DCP's ``broadcast_from_rank0`` in | ||
| favour of the explicit ``broadcast_object_list`` / | ||
| ``_redistribute_sd_for_dtensor`` path? The answer is ``True`` for: | ||
|
|
||
| * Plain (non-FSDP) modules holding DTensors (e.g. user-created ShardTensors | ||
| on a domain mesh) — DCP's broadcast cannot rebuild user placements. | ||
| * ``FSDP`` (FSDP1) with ``NO_SHARD`` — equivalent to plain DDP, no | ||
| FSDP-managed DTensors to be aware of. | ||
| * ``FSDPModule`` (FSDP2) whose parameters sit on a mesh with any | ||
| *degenerate* dimension (``size == 1``). In that case DCP's underlying | ||
| ``dist.broadcast`` raises | ||
| ``RuntimeError: found no DeviceMesh from dtensor args for c10d::broadcast_`` | ||
| because c10d cannot dispatch through a size-1 mesh axis. The same | ||
| configuration also breaks DCP's optimizer load with | ||
| ``KeyError: 'state.0.step'`` (the freshly-constructed live optimizer has | ||
| no materialised ``state``, so DCP's flatten-mapping is missing the | ||
| checkpoint's ``state.X.step`` keys). Both surface when ``fully_shard`` | ||
| runs on world_size == 1, or on a 2D mesh whose ``ddp`` axis collapses | ||
| to 1 (e.g. ``(ddp=1, domain=N)`` after batch_size==1 on N ranks). | ||
| The manual path sidesteps both — it broadcasts the dict as Python | ||
| objects and then sets state with ``options=full_options`` (no DCP | ||
| broadcast, no unflatten round-trip). | ||
| """ | ||
| if not dtensor_plc: | ||
| return False | ||
| if isinstance(model, FSDPModule): | ||
| return any( | ||
| any(s == 1 for s in mesh.shape) for mesh, _ in dtensor_plc.values() | ||
| ) | ||
| if not isinstance(model, FSDP): | ||
| return True | ||
| if model.sharding_strategy == ShardingStrategy.NO_SHARD: | ||
|
|
@@ -603,7 +648,7 @@ def _unique_model_names( | |
| model_dict: dict[str, list[torch.nn.Module]] = {} | ||
| for model0 in models: | ||
| model0 = _unwrap_ddp_compile(model0, loading=loading) | ||
| base_name = type(_unwrap_fsdp(model0)).__name__ | ||
| base_name = _unwrapped_class_name(model0) | ||
| if base_name in model_dict: | ||
| model_dict[base_name].append(model0) | ||
| else: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NOTE: this block exists solely for backward compatibility with PyTorch <= 2.10. Do we care about backward compatibility?