Add fused_adam, quantized_model_init, and fsdp2 example #2698
pstjohn wants to merge 2 commits into NVIDIA:main
Conversation
Force-pushed from 22604c4 to 4d89e04
Greptile Summary

This PR enables FusedAdam to work with PyTorch-native FSDP2 when parameters are DTensor-wrapped quantized tensors. Key changes:
Confidence Score: 5/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Model as TransformerLayer<br/>(FP8 params via quantized_model_init)
    participant FSDP2 as fully_shard<br/>(DTensor wrapping)
    participant Optimizer as FusedAdam<br/>(FP32 master weights)
    participant Kernel as Multi-tensor kernels
    User->>Model: Build model in quantized_model_init context
    Model-->>User: Float8Tensor/QuantizedTensor params
    User->>FSDP2: Apply fully_shard(model)
    FSDP2-->>User: DTensor-wrapped params
    Note over FSDP2: DTensor._local_tensor = Float8Tensor
    User->>Optimizer: FusedAdam(params, master_weights=True)
    Optimizer->>Optimizer: Extract local_tensor from DTensor
    Optimizer->>Optimizer: Dequantize QuantizedTensor
    Optimizer-->>User: FP32 master_param states initialized
    User->>Model: Forward + backward pass
    Model-->>Optimizer: Gradients (DTensor-wrapped)
    Optimizer->>Optimizer: Extract p_grad._local_tensor
    Optimizer->>Optimizer: Extract p._local_tensor for FP8 params
    Optimizer->>Kernel: Call multi_tensor_adam with plain CUDA tensors
    Kernel-->>Optimizer: Updated FP32 master weights
    Optimizer->>Optimizer: Quantize and update FP8 params
```
Last reviewed commit: 96c123e
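The tail of the diagram (the kernel updating FP32 master weights, then the low-precision params being refreshed by re-quantization) can be sketched framework-free. This is a toy illustration only: `fake_quantize` and the plain-list "tensors" are stand-ins, not the real multi_tensor_adam kernel or FP8 quantization.

```python
# Toy, framework-free sketch of the update flow in the diagram above:
# the optimizer updates FP32 master weights, then the low-precision
# params are refreshed by re-quantizing the masters.
def fake_quantize(x, step=0.25):
    # stand-in for FP8 quantization: snap values to a coarse grid
    return round(x / step) * step

master = [0.5, 1.0]   # FP32 master weights held by the optimizer
grads = [0.1, -0.2]   # local gradients for this shard
lr = 0.1

# stand-in for the multi_tensor_adam kernel: a plain SGD-style update
master = [m - lr * g for m, g in zip(master, grads)]

# quantize the updated masters back into the low-precision params
params = [fake_quantize(m) for m in master]
```

The master copy retains the small update (roughly 0.49 and 1.02) while the quantized params snap back to the grid, which is why the optimizer must keep FP32 masters separate from the FP8 params.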
```python
# to get a plain float32 copy for the master weight.
local_param = param._local_tensor if isinstance(param, DTensor) else param
if isinstance(local_param, QuantizedTensor):
    master = local_param.dequantize().clone().detach().float()
```
Should we use dequantize(dtype=torch.float32) to fuse the cast into the de-quantization's output buffer? (Likely not a big deal since I don't think this will change anything numerically, and you only call this function during init and whenever you save and load DCP checkpoints.)
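The unwrap-then-dequantize step being discussed here can be sketched with stand-in classes. These are illustrative stubs only, not the real torch.distributed DTensor or Transformer Engine QuantizedTensor, and make_master_weight is a hypothetical helper name:

```python
# Illustrative stubs mimicking how the optimizer unwraps a
# DTensor-wrapped quantized param before building a master weight.
class QuantizedTensor:  # stand-in for TE's QuantizedTensor
    def __init__(self, data):
        self._data = data  # pretend these are FP8-quantized values

    def dequantize(self):
        return list(self._data)  # returns a plain "FP32" copy

class DTensor:  # stand-in for torch.distributed.tensor.DTensor
    def __init__(self, local_tensor):
        self._local_tensor = local_tensor

def make_master_weight(param):
    # Step 1: unwrap DTensor to get the local shard
    local = param._local_tensor if isinstance(param, DTensor) else param
    # Step 2: dequantize quantized params to get a plain float copy
    if isinstance(local, QuantizedTensor):
        return local.dequantize()
    return list(local)

p = DTensor(QuantizedTensor([1.0, 2.0]))
print(make_master_weight(p))  # [1.0, 2.0]
```

The same two-step unwrap applies whether the param is a bare tensor, a quantized tensor, or a DTensor wrapping either.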
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Force-pushed from 0103b53 to 3c3dbd2
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn Hi, thanks for the great work! Does this PR plan to also handle the BF16 path? I noticed the BF16 branch still operates on the original p/p_grad without unwrapping when they're DTensors. In my experiments with FSDP2 + BF16, I'm seeing non-trivial overhead during the optimizer step from repeated DTensor dispatch. Curious if that's intentional or a planned follow-up.
Summary
- Enables FusedAdam to work with PyTorch-native FSDP2 (fully_shard) when parameters are DTensor-wrapped Float8Tensor/QuantizedTensor
- Adds a fuse_wgrad_accumulation guard to avoid crashing with vanilla FSDP2 (previously assumed Megatron-Core FSDP exclusively)
- Adds examples of quantized_model_init on single-GPU (main.py) and multi-GPU FSDP2 (fully_shard.py)

Note: fuse_wgrad_accumulation remains incompatible with vanilla FSDP2. The feature writes weight gradients directly into main_grad and returns None to autograd, bypassing FSDP2's reduce-scatter; each rank ends up with an unreduced gradient. Megatron-Core FSDP solves this by wiring get_main_grad() into its own reduce-scatter infrastructure. Vanilla FSDP2 does not yet expose an equivalent hook.

Fixes #2682
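The incompatibility described in the note can be illustrated with a toy, framework-free sketch. The function names below (fused_wgrad_backward, the inline reduce) are illustrative stand-ins, not real FSDP2 or Transformer Engine hooks:

```python
# Toy sketch: fuse_wgrad_accumulation writes the weight gradient into
# main_grad and returns None to autograd, so FSDP2's reduce-scatter
# hook never fires and each rank keeps an unreduced gradient.
def fused_wgrad_backward(local_wgrad, main_grad):
    main_grad[0] += local_wgrad  # grad accumulated directly into main_grad
    return None                  # autograd sees no grad -> no reduce-scatter

rank_wgrads = [1.0, 2.0, 3.0, 4.0]  # one local weight gradient per rank

# Normal FSDP2 path: autograd returns grads, which get reduced across ranks.
reduced = sum(rank_wgrads)  # every rank ends up with 10.0

# Fused path: each rank only ever accumulates its own contribution.
main_grads = []
for g in rank_wgrads:
    mg = [0.0]
    assert fused_wgrad_backward(g, mg) is None
    main_grads.append(mg[0])

print(main_grads)  # [1.0, 2.0, 3.0, 4.0] -- unreduced, differs per rank
```

Megatron-Core FSDP avoids this by feeding main_grad into its own reduce-scatter infrastructure, which is the hook vanilla FSDP2 currently lacks.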