Skip to content

NVFP4 primary weight support#2691

Open
WanZzzzzz wants to merge 8 commits intoNVIDIA:mainfrom
WanZzzzzz:fp4_native_weights
Open

NVFP4 primary weight support#2691
WanZzzzzz wants to merge 8 commits intoNVIDIA:mainfrom
WanZzzzzz:fp4_native_weights

Conversation

@WanZzzzzz
Copy link

Description

This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers. It enables efficient casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks, while minimizing CPU overhead in large-scale training.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This PR introduces NVFP4 partial cast infrastructure and optimizations for distributed training:

NVFP4 Partial Cast Kernel (nvfp4_2d_partial_cast)

  • Implements nibble-accurate partial updates for NVFP4 tensors in distributed settings
  • Supports two-level NVFP4 scaling: global FP32 scale + per-block FP8 E4M3 scale

NVFP4 Transpose Kernel (nvfp4_transpose)

  • Custom transpose kernel for nibble-packed NVFP4 data with shared memory optimization
  • Uses vectorized uint2 loads/stores with 64×64 tiles for efficient memory access
  • Handles nibble repacking during transpose (unlike FP8 byte transpose)
  • Enables columnwise data generation for GEMM operations after rowwise AllGather

Fused Scale Kernel (nvfp4_fused_scale)

  • Fuses per-block scale computation, global amax copy, and FP8 scale expansion into a single kernel
  • Eliminates multiple kernel launches and avoids D2H transfers by accepting tensor pointers
  • Reduces kernel launch overhead in the critical path

Multi-Tensor Dispatch Pattern

  • C++-side loop dispatch for NVFP4 multi-tensor operations
  • Reduces Python–C++ transition overhead compared to per-tensor Python loops
  • Collects metadata in Python and executes batched operations in C++ wrappers

CPU Overhead Optimizations

  • Batched dtype conversion via torch.cat / torch.split
  • Replaced torch.zeros() with torch.empty() for immediately written buffers
  • Consolidated metadata collection and allocation phases
  • Optimized bucket partitioning for expert parallel buffers

Scale Computation Improvements

  • Fixed floating-point precision mismatch between Python and CUDA
  • Uses FP32 constants consistent with CUDA arithmetic
  • Ensures bitwise-identical results between partial and full quantization paths

New Public API

cast_master_weights_to_nvfp4()

  • Casts FP32 master weights to NVFP4 model weights
  • Handles global and per-block amax reduction across data parallel groups
  • Designed for low CPU overhead in distributed training loops

Testing

Test Description
test_nvfp4_transpose_kernel Verifies correctness for nibble-packed transpose
test_nvfp4_partial_cast_matches_full Multi-GPU: partial cast + all-gather equals full cast
test_single_gpu_partial_cast_vs_full Single-GPU: offset=0 partial cast matches reference quantizer
_test_cast_master_weights_to_nvfp4 500-iteration training loop with bitwise-identical loss

This feature also passed numeric validation in GPT-3 training on the corresponding Megatron-Core branch:

https://gitlab-master.nvidia.com/qiyuw/megatron-lm-all/-/tree/fp4_primary_opt?ref_type=heads

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: qiyuw <qiyuw@nvidia.com>
@WanZzzzzz WanZzzzzz mentioned this pull request Feb 19, 2026
13 tasks
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 19, 2026

Greptile Summary

This PR implements NVFP4 partial cast support for distributed training, enabling efficient FP32-to-NVFP4 weight quantization with coordinated scaling across data parallel ranks.

Key changes:

  • NVFP4 partial cast kernel (nvfp4_2d_partial_cast_kernel) - performs nibble-accurate partial updates for sharded tensors with two-level scaling (global FP32 + per-block FP8 E4M3)
  • NVFP4 transpose kernel (nvfp4_transpose_kernel) - custom transpose handling nibble repacking using vectorized uint2 loads/stores with 64×64 tiles
  • Fused scale kernel (nvfp4_fused_scale_kernel) - combines per-block scale computation, global amax copy, and FP8 scale expansion into single kernel
  • Multi-tensor dispatch - C++-side batching reduces Python-C++ transition overhead for multiple parameters
  • CPU optimizations - batched dtype conversion via torch.cat/torch.split, replaced torch.zeros() with torch.empty() for pre-written buffers
  • New public API: quantize_master_weights() (formerly cast_master_weights_to_fp8) now supports both FP8 and NVFP4
  • Comprehensive testing - includes 500-iteration training loop with bitwise-identical loss validation, multi-GPU partial cast tests, and transpose correctness tests

Implementation notes:

  • Moved nvfp4.cu to arch-specific sources (likely due to PTX instruction usage)
  • Scale computation uses FP32 constants matching CUDA arithmetic for bitwise-identical results
  • Kernels extensively document tile/warp mapping and partial-shard design
  • Multi-tensor operations reduce kernel launch overhead in critical path

Issue found:

  • test_single_gpu_partial_cast_vs_full computes comparison results but never asserts them (test always passes)

Confidence Score: 4/5

  • Safe to merge with one test fix - production code is well-tested and thoroughly documented
  • Score reflects high code quality, comprehensive testing (including 500-iteration training validation), and detailed documentation. Deducted one point for the missing test assertions in test_single_gpu_partial_cast_vs_full, which should be fixed before merge to ensure the test actually validates correctness.
  • tests/pytorch/distributed/test_cast_master_weights_to_fp8.py - fix missing assertions at line 1237-1244

Important Files Changed

Filename Overview
transformer_engine/common/recipe/nvfp4.cu Implements NVFP4 partial cast and transpose kernels with optimized shared memory usage and vectorized I/O. Well-structured with detailed comments.
transformer_engine/pytorch/csrc/extensions/transpose.cpp Adds NVFP4 transpose and scale operations with comprehensive shape validation and multi-tensor dispatch.
transformer_engine/pytorch/tensor/utils.py Implements NVFP4 quantization with batched operations to reduce CPU overhead. Introduces batched dtype conversion and multi-tensor kernel dispatch.
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py Adds comprehensive NVFP4 tests including 500-iteration training loop and multi-GPU partial cast validation. Contains one test with missing assertions.
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py Adds _create_columnwise() method for NVFP4 using specialized transpose kernels that handle nibble repacking.

Sequence Diagram

sequenceDiagram
    participant User as Training Loop
    participant API as quantize_master_weights()
    participant Utils as utils.py
    participant Cpp as C++ Extensions
    participant CUDA as CUDA Kernels
    participant NCCL as Distributed

    User->>API: Cast FP32 shards to NVFP4
    API->>Utils: Batch dtype conversion
    Note over Utils: torch.cat() -> .to(dtype) -> split()
    
    API->>Utils: Collect metadata for multi-tensor
    Note over Utils: h_list, w_list, offset_list
    
    Utils->>Cpp: nvfp4_multi_tensor_compute_partial_amax()
    loop For each tensor
        Cpp->>CUDA: nvfp4_2d_compute_partial_amax_kernel
        Note over CUDA: Compute per-block amax<br/>in shard range [start, end)
        Cpp->>CUDA: nvte_compute_amax()
        Note over CUDA: Compute global amax<br/>for full shard
    end
    
    Utils->>NCCL: all_reduce(packed_amaxes, MAX)
    Utils->>NCCL: all_reduce(global_amaxes, MAX)
    
    Utils->>Cpp: nvfp4_compute_global_scale()
    Cpp->>CUDA: nvfp4_compute_global_scale_kernel
    Note over CUDA: global_scale = 2688 / global_amax
    
    Utils->>Cpp: nvfp4_multi_tensor_fused_scale()
    loop For each tensor
        Cpp->>CUDA: nvfp4_fused_scale_kernel
        Note over CUDA: 1. Compute per-block scale<br/>2. Copy global amax<br/>3. Expand to FP8 E4M3
    end
    
    Utils->>Cpp: nvfp4_multi_tensor_2d_partial_cast()
    loop For each tensor
        Cpp->>CUDA: nvfp4_2d_partial_cast_kernel
        Note over CUDA: Cast shard to FP4<br/>with nibble-accurate RMW
    end
    
    User->>API: post_all_gather_processing()
    API->>Utils: _nvfp4_2d_multi_tensor_transpose()
    Utils->>Cpp: nvfp4_2d_multi_tensor_transpose()
    loop For each tensor
        Cpp->>CUDA: nvfp4_transpose_kernel
        Note over CUDA: Transpose with nibble repacking<br/>[M, K/2] -> [K, M/2]
        Cpp->>CUDA: nvfp4_scale_transpose_kernel
        Note over CUDA: Transpose tile scales<br/>[M_pad, K_tiles] -> [K_pad, M_tiles]
    end
Loading

Last reviewed commit: 5f5f48f

greptile-apps[bot]

This comment was marked as resolved.

@timmoon10

This comment was marked as outdated.

start_offsets,
group,
fsdp_shard_model_weights=None,
manual_post_all_gather_processing=False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We added this kwarg to the FP8 functions for backward compatibility, but there's no point keeping them for these brand-new NVFP4 APIs:

Suggested change
manual_post_all_gather_processing=False,

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fsdp_shard_model_weights=None is for future FSDP support. It's in the plan.
manual_post_all_gather_processing is also needed for the same reason as FP8 blockwise scaling:
https://github.com/WanZzzzzz/TransformerEngine/blob/38b92b1a168dcfaa6242fea50f03e5a1b873e3a0/transformer_engine/pytorch/tensor/utils.py#L535

Copy link
Collaborator

@timmoon10 timmoon10 Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, that makes sense for now then. Let's change the default to True though since that's preferred.

I want to flag a potential future problem with manual_post_all_gather_processing=False: it assumes that the quantized tensor has some way to handle the post-processing automatically. For FP8 on Hopper:

cast_master_weights_to_fp8(..., manual_post_all_gather_processing=False)
torch.all_gather(...)

y = model(x)  # Float8Tensor internally performs FP8 transpose

This is not something TE will guarantee for future data formats. Maybe the next recipe has some interleaved format:

cast_master_weights_to_futureformat(...)
torch.all_gather(...)
fix_futureformat_interleaving(...)

y = model(x)  # FutureFormatTensor assumes data is interleaved

In this case, we should throw an error with the user passes manual_post_all_gather_processing=False and it should be Mcore's responsibility to perform the post-processing in a way that's friendly to overlapping.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, note it down.

Comment on lines 245 to +259
if isinstance(self.weights[0], QuantizedTensor):
weight_buffer_dtype = torch.uint8
if self.weights_are_nvfp4:
weight_buffer_length = self.storage_total
buffer_rank_start = storage_rank_start
buffer_rank_end = storage_rank_end
else:
weight_buffer_length = self.offsets[-1]
buffer_rank_start = rank_start
buffer_rank_end = rank_end
else:
weight_buffer_dtype = weights[0].dtype
weight_buffer_length = self.offsets[-1]
buffer_rank_start = rank_start
buffer_rank_end = rank_end
Copy link
Collaborator

@timmoon10 timmoon10 Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: It's a bit convoluted, isn't it? It would be much nicer to disentangle the quantization logic from the buffer allocation by computing storage offsets in all cases (even if it's trivial for non-NVFP4 cases) and then using that blindly here.

qiyuw and others added 2 commits February 20, 2026 05:52
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

11 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Signed-off-by: qiyuw <qiyuw@nvidia.com>
greptile-apps[bot]

This comment was marked as outdated.

Signed-off-by: qiyuw <qiyuw@nvidia.com>
greptile-apps[bot]

This comment was marked as resolved.

@greptile-apps

This comment was marked as resolved.

Signed-off-by: qiyuw <qiyuw@nvidia.com>
greptile-apps[bot]

This comment was marked as outdated.

Signed-off-by: qiyuw <qiyuw@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

11 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@timmoon10
Copy link
Collaborator

/te-ci L1

timmoon10
timmoon10 previously approved these changes Feb 21, 2026
Copy link
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM, although there are some test failures related to missing licenses and linter warnings. I also still have some nits, although they are not blocking.

@timmoon10 timmoon10 self-requested a review February 21, 2026 00:09
@timmoon10 timmoon10 dismissed their stale review February 21, 2026 00:09

Test failures

Signed-off-by: qiyuw <qiyuw@nvidia.com>
Comment on lines +1237 to +1244
# Compare amax
amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax)

# Compare scale
scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale)

# Compare data
data_match = torch.equal(test_tensor._rowwise_data, ref_data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing assertions - amax_match, scale_match, and data_match are computed but never checked

Suggested change
# Compare amax
amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax)
# Compare scale
scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale)
# Compare data
data_match = torch.equal(test_tensor._rowwise_data, ref_data)
# Compare amax
amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax)
assert amax_match, "Amax mismatch between partial cast and reference"
# Compare scale
scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale)
assert scale_match, "Scale mismatch between partial cast and reference"
# Compare data
data_match = torch.equal(test_tensor._rowwise_data, ref_data)
assert data_match, "Data mismatch between partial cast and reference"

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants