Conversation
Signed-off-by: qiyuw <qiyuw@nvidia.com>
Greptile Summary

This PR implements NVFP4 partial cast support for distributed training, enabling efficient FP32-to-NVFP4 weight quantization with coordinated scaling across data parallel ranks.

Confidence Score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as Training Loop
    participant API as quantize_master_weights()
    participant Utils as utils.py
    participant Cpp as C++ Extensions
    participant CUDA as CUDA Kernels
    participant NCCL as Distributed
    User->>API: Cast FP32 shards to NVFP4
    API->>Utils: Batch dtype conversion
    Note over Utils: torch.cat() -> .to(dtype) -> split()
    API->>Utils: Collect metadata for multi-tensor
    Note over Utils: h_list, w_list, offset_list
    Utils->>Cpp: nvfp4_multi_tensor_compute_partial_amax()
    loop For each tensor
        Cpp->>CUDA: nvfp4_2d_compute_partial_amax_kernel
        Note over CUDA: Compute per-block amax<br/>in shard range [start, end)
        Cpp->>CUDA: nvte_compute_amax()
        Note over CUDA: Compute global amax<br/>for full shard
    end
    Utils->>NCCL: all_reduce(packed_amaxes, MAX)
    Utils->>NCCL: all_reduce(global_amaxes, MAX)
    Utils->>Cpp: nvfp4_compute_global_scale()
    Cpp->>CUDA: nvfp4_compute_global_scale_kernel
    Note over CUDA: global_scale = 2688 / global_amax
    Utils->>Cpp: nvfp4_multi_tensor_fused_scale()
    loop For each tensor
        Cpp->>CUDA: nvfp4_fused_scale_kernel
        Note over CUDA: 1. Compute per-block scale<br/>2. Copy global amax<br/>3. Expand to FP8 E4M3
    end
    Utils->>Cpp: nvfp4_multi_tensor_2d_partial_cast()
    loop For each tensor
        Cpp->>CUDA: nvfp4_2d_partial_cast_kernel
        Note over CUDA: Cast shard to FP4<br/>with nibble-accurate RMW
    end
    User->>API: post_all_gather_processing()
    API->>Utils: _nvfp4_2d_multi_tensor_transpose()
    Utils->>Cpp: nvfp4_2d_multi_tensor_transpose()
    loop For each tensor
        Cpp->>CUDA: nvfp4_transpose_kernel
        Note over CUDA: Transpose with nibble repacking<br/>[M, K/2] -> [K, M/2]
        Cpp->>CUDA: nvfp4_scale_transpose_kernel
        Note over CUDA: Transpose tile scales<br/>[M_pad, K_tiles] -> [K_pad, M_tiles]
    end
```

Last reviewed commit: 5f5f48f
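The coordinated-scaling step in the flow above reduces per-shard amaxes across data-parallel ranks before deriving one global scale, where 2688 = 448 × 6 (the FP8 E4M3 maximum times the FP4 E2M1 maximum). A minimal Python sketch of that arithmetic (helper names are illustrative, not the actual TE bindings, and the all-reduce is simulated locally):

```python
# Illustrative sketch of the coordinated scaling in the diagram above.
# Helper names are hypothetical; the real work happens in fused CUDA
# kernels with an NCCL all_reduce(..., op=MAX) across DP ranks.
E4M3_MAX = 448.0  # largest representable FP8 E4M3 value
E2M1_MAX = 6.0    # largest representable FP4 E2M1 value

def reduce_amax(shard_amaxes):
    # stands in for torch.distributed.all_reduce(amaxes, op=ReduceOp.MAX)
    return max(shard_amaxes)

def global_scale(amax):
    # matches the kernel's formula: global_scale = 2688 / global_amax
    return (E4M3_MAX * E2M1_MAX) / amax  # 448 * 6 = 2688

amax = reduce_amax([3.5, 7.0, 2.1])  # one amax per rank's shard
scale = global_scale(amax)           # 2688 / 7.0 = 384.0
```

Because every rank sees the same reduced amax, all shards of a weight are quantized with an identical scale, which is what makes the later all-gather of quantized shards numerically consistent.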
```python
    start_offsets,
    group,
    fsdp_shard_model_weights=None,
    manual_post_all_gather_processing=False,
```
We added this kwarg to the FP8 functions for backward compatibility, but there's no point keeping them for these brand-new NVFP4 APIs:
```python
    manual_post_all_gather_processing=False,
```
fsdp_shard_model_weights=None is for future FSDP support. It's in the plan.
manual_post_all_gather_processing is also needed for the same reason as FP8 blockwise scaling:
https://github.com/WanZzzzzz/TransformerEngine/blob/38b92b1a168dcfaa6242fea50f03e5a1b873e3a0/transformer_engine/pytorch/tensor/utils.py#L535
I see, that makes sense for now then. Let's change the default to True though since that's preferred.
I want to flag a potential future problem with manual_post_all_gather_processing=False: it assumes that the quantized tensor has some way to handle the post-processing automatically. For FP8 on Hopper:

```python
cast_master_weights_to_fp8(..., manual_post_all_gather_processing=False)
torch.all_gather(...)
y = model(x)  # Float8Tensor internally performs FP8 transpose
```

This is not something TE will guarantee for future data formats. Maybe the next recipe has some interleaved format:

```python
cast_master_weights_to_futureformat(...)
torch.all_gather(...)
fix_futureformat_interleaving(...)
y = model(x)  # FutureFormatTensor assumes data is interleaved
```

In this case, we should throw an error when the user passes manual_post_all_gather_processing=False, and it should be Mcore's responsibility to perform the post-processing in a way that's friendly to overlapping.
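As a concrete illustration of the guard proposed above (function name, signature, and message are hypothetical, not TE API):

```python
# Hypothetical sketch: a cast function for a format that cannot lazily
# fix its own layout rejects manual_post_all_gather_processing=False
# up front, forcing the caller (e.g. Mcore) to own the post-processing.
def cast_master_weights_to_futureformat(weights, *, manual_post_all_gather_processing=True):
    if not manual_post_all_gather_processing:
        raise ValueError(
            "This format requires explicit post-all-gather processing; "
            "call the dedicated fix-up routine after all_gather instead."
        )
    # ... perform the cast; post-processing is left to the caller ...
    return weights
```

Failing fast here is cheaper than silently producing tensors whose layout the model later misinterprets.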
(Two outdated review threads on transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py — resolved.)
```python
if isinstance(self.weights[0], QuantizedTensor):
    weight_buffer_dtype = torch.uint8
    if self.weights_are_nvfp4:
        weight_buffer_length = self.storage_total
        buffer_rank_start = storage_rank_start
        buffer_rank_end = storage_rank_end
    else:
        weight_buffer_length = self.offsets[-1]
        buffer_rank_start = rank_start
        buffer_rank_end = rank_end
else:
    weight_buffer_dtype = weights[0].dtype
    weight_buffer_length = self.offsets[-1]
    buffer_rank_start = rank_start
    buffer_rank_end = rank_end
```
Nit: It's a bit convoluted, isn't it? It would be much nicer to disentangle the quantization logic from the buffer allocation by computing storage offsets in all cases (even if it's trivial for non-NVFP4 cases) and then using that blindly here.
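For illustration, the disentangling suggested above could look roughly like this (a sketch with a made-up function name; `offsets`, `storage_total`, and friends mirror the snippet above):

```python
import torch

def plan_weight_buffer(is_quantized, is_nvfp4, offsets, rank_start, rank_end,
                       storage_total=None, storage_rank_start=None,
                       storage_rank_end=None, dtype=torch.float32):
    """Pick (dtype, length, start, end) for the weight buffer.

    Storage-space quantities are computed for every case; for non-NVFP4
    weights they trivially equal the element-space ones, so the caller
    can use the returned values blindly."""
    if not (is_quantized and is_nvfp4):
        storage_total = offsets[-1]
        storage_rank_start, storage_rank_end = rank_start, rank_end
    buf_dtype = torch.uint8 if is_quantized else dtype
    return buf_dtype, storage_total, storage_rank_start, storage_rank_end
```

With the storage offsets normalized once, the allocation site no longer needs nested dtype/length branches.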
/te-ci L1
timmoon10 left a comment:
Overall LGTM, although there are some test failures related to missing licenses and linter warnings. I also still have some nits, although they are not blocking.
```python
# Compare amax
amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax)

# Compare scale
scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale)

# Compare data
data_match = torch.equal(test_tensor._rowwise_data, ref_data)
```
Missing assertions - amax_match, scale_match, and data_match are computed but never checked
```suggestion
# Compare amax
amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax)
assert amax_match, "Amax mismatch between partial cast and reference"
# Compare scale
scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale)
assert scale_match, "Scale mismatch between partial cast and reference"
# Compare data
data_match = torch.equal(test_tensor._rowwise_data, ref_data)
assert data_match, "Data mismatch between partial cast and reference"
```
Description
This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers. It enables efficient casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks, while minimizing CPU overhead in large-scale training.
Type of change
Changes
This PR introduces NVFP4 partial cast infrastructure and optimizations for distributed training:
- NVFP4 Partial Cast Kernel (`nvfp4_2d_partial_cast`)
- NVFP4 Transpose Kernel (`nvfp4_transpose`): `uint2` loads/stores with 64×64 tiles for efficient memory access
- Fused Scale Kernel (`nvfp4_fused_scale`)
- Multi-Tensor Dispatch Pattern
- CPU Overhead Optimizations: batched dtype conversion via `torch.cat`/`torch.split`; replaced `torch.zeros()` with `torch.empty()` for immediately written buffers
- Scale Computation Improvements
- New Public API: `cast_master_weights_to_nvfp4()`

Testing

- `test_nvfp4_transpose_kernel`
- `test_nvfp4_partial_cast_matches_full`
- `test_single_gpu_partial_cast_vs_full`
- `test_cast_master_weights_to_nvfp4`

This feature also passed numeric validation in GPT-3 training on the corresponding Megatron-Core branch:
https://gitlab-master.nvidia.com/qiyuw/megatron-lm-all/-/tree/fp4_primary_opt?ref_type=heads
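The nibble-accurate read-modify-write that `nvfp4_2d_partial_cast` performs can be illustrated in pure Python: NVFP4 packs two 4-bit codes per byte, so a shard boundary at an odd element index falls mid-byte, and the neighboring nibble (owned by another rank's shard) must survive the write. Low-nibble-first packing order is an assumption in this sketch:

```python
# Pure-Python sketch of nibble-accurate RMW for a packed FP4 buffer.
# The actual kernel does this on GPU with coalesced accesses; packing
# order (low nibble = even index) is assumed, not taken from the source.
def write_fp4(buf: bytearray, idx: int, code: int) -> None:
    byte = idx // 2
    if idx % 2 == 0:  # even index -> low nibble; preserve the high one
        buf[byte] = (buf[byte] & 0xF0) | (code & 0x0F)
    else:             # odd index -> high nibble; preserve the low one
        buf[byte] = (buf[byte] & 0x0F) | ((code & 0x0F) << 4)

buf = bytearray([0xAB])
write_fp4(buf, 0, 0x3)  # overwrite only element 0's nibble -> 0xA3
```

A naive byte-granular store at a shard edge would clobber the adjacent rank's element, which is exactly what the RMW avoids.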
Checklist: