Add native MX8×MX4 mixed-precision GEMM kernel (f8f4bf16) #313
Open
isratnisa wants to merge 1 commit into
Summary:
Add a native mixed-precision CUTLASS GEMM kernel for MX8 (MXFP8 e4m3) x MX4 (MXFP4 e2m1) using the Blackwell mxf8f6f4 block-scaled tensor core instruction.
This follows the same dispatch path as the CUTLASS 4.3.5 example 72c_blackwell_mixed_mxfp8_bf16_gemm.cu: ElementA = mx_float8_t<float_e4m3_t> and ElementB = mx_float4_t<float_e2m1_t> are passed as separate types with OpClassBlockScaledTensorOp and KernelScheduleAuto, so the CollectiveBuilder auto-selects the mxf8f6f4 variant of the block-scaled tensor core MMA instruction (tcgen05.mma.blockscaled). Both operands use E8M0 scale factors with block size 32. The kernel template uses AlignmentA = 128 / sizeof_bits<ElementA> (= 16 elements), AlignmentB = 128 (FP4, taken from the CUTLASS 72c example), TileShapeK = 256, and TN layout.
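To make the operand formats concrete, here is a minimal pure-Python reference for decoding block-scaled values and accumulating one dot product. It is a sketch of the numerics only, not the kernel: the helper names (`decode_e8m0`, `decode_e4m3`, `decode_e2m1`, `block_scaled_dot`) are hypothetical, and the bit layouts are assumed to follow the OCP MX definitions of E8M0, FP8 E4M3, and FP4 E2M1.

```python
BLOCK = 32  # elements sharing one E8M0 scale, per the description above

# Magnitudes representable by FP4 E2M1; the top bit of the nibble is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e8m0(byte):
    """E8M0 is an exponent-only scale: 2**(byte - 127); 0xFF encodes NaN."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


def decode_e2m1(nibble):
    """Decode a 4-bit E2M1 code (1 sign, 2 exponent, 1 mantissa bit)."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def decode_e4m3(byte):
    """Decode FP8 E4M3 (bias 7, no infinities, S.1111.111 is NaN)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def block_scaled_dot(a_codes, a_scales, b_codes, b_scales):
    """Dot product of an MX8 row (A) with an MX4 column (B) along K."""
    assert len(a_codes) == len(b_codes) and len(a_codes) % BLOCK == 0
    total = 0.0
    for i, (a, b) in enumerate(zip(a_codes, b_codes)):
        sa = decode_e8m0(a_scales[i // BLOCK])  # one scale per 32 elements
        sb = decode_e8m0(b_scales[i // BLOCK])
        total += decode_e4m3(a) * sa * decode_e2m1(b) * sb
    return total
```

The hardware instruction applies the scales per block rather than per element, but the result is mathematically the same because every element in a block shares its scale.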
Changes:
- New CUTLASS kernel directory mx8mx4bf16/ with 8 tile configurations
- Dispatch file with heuristic-based and autotune-based kernel selection
- TORCH_CHECK(K % 32 == 0) for block size validation
- Registered as torch.ops.mslk.mx8mx4bf16 PyTorch op (CUDA + Meta dispatch)
- Added NativeMX8MX4Quant class in quant_ops.py for benchmarking framework
- Scale format fix: to_mxfp8() returns plain row-major scales, so _to_blocked() is applied to convert them to the interleaved layout the CUTLASS block-scaled kernel expects
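The scale handling in the last two items can be sketched in pure Python. This assumes that _to_blocked() performs the 128x4-tile interleave commonly used for block-scaled GEMM scale factors (each 128x4 tile is stored as a 32x16 tile); the PR's own helper is not shown here, so `to_blocked`, `scale_shape`, and `ceil_div` below are hypothetical stand-ins.

```python
def ceil_div(a, b):
    return (a + b - 1) // b


def scale_shape(rows, k, block=32):
    """Shape of the row-major scale matrix for a (rows, k) operand."""
    # Mirrors the K % 32 == 0 validation listed in the changes above.
    if k % block != 0:
        raise ValueError(f"K={k} must be a multiple of {block}")
    return rows, k // block


def to_blocked(scales, rows, cols, pad=0):
    """Rearrange a flat row-major (rows, cols) scale list into 128x4 tiles.

    Assumed layout: tiles are ordered row-block-major; inside a tile,
    element (r, c) lands at (r % 32) * 16 + (r // 32) * 4 + c.
    Rows and columns are padded up to multiples of 128 and 4.
    """
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)
    out = [pad] * (n_row_blocks * n_col_blocks * 512)
    for r in range(rows):
        for c in range(cols):
            tile = (r // 128) * n_col_blocks + (c // 4)
            rr, cc = r % 128, c % 4
            offset = tile * 512 + (rr % 32) * 16 + (rr // 32) * 4 + cc
            out[offset] = scales[r * cols + c]
    return out
```

For example, `scale_shape(1, 8192)` gives a 1x256 scale matrix for an M=1 activation, which `to_blocked` pads to a full 128-row tile per group of four scale columns.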
Note: native_mx8 (MX8xMX8 CUTLASS kernel) numbers shown below are from D99386004, a separate stacked diff. cuBLAS via torch._scaled_mm is 3-6x slower at small M (e.g., 1.4 vs 8.9 TFLOPS at M=1).
Benchmark on B200 (750W), N=K=8192, num_iters=20:
GEMM-Only TFLOPS (higher is better):
     | BF16 Fallback  | Native Block-Scaled Tensor Core                             |
     | (torch.matmul) | (tcgen05.mma.blockscaled)                                   |
   M | bf16_mm        | native_mx8*  | native_mx8_mx4 | native_nvfp4 | native_mxfp4 |
     |                | (mx8mx8bf16) | (mx8mx4bf16)   | (f4f4bf16)   | (f4f4bf16)   |
-----+----------------+--------------+----------------+--------------+--------------+
   1 |            4.7 |          8.9 |            8.8 |         10.1 |         10.7 |
  16 |           88.3 |        141.4 |          143.5 |        164.5 |        175.6 |
  64 |          337.7 |        556.1 |          572.7 |        691.1 |        700.2 |
 128 |          636.3 |       1010.7 |         1120.3 |       1367.8 |       1432.4 |
 256 |          984.8 |       1806.7 |         2013.4 |       2796.2 |       2834.6 |
 512 |         1109.0 |       1987.9 |         2380.5 |       3618.0 |       4012.9 |
1024 |         1211.0 |       2066.1 |         2457.5 |       4431.2 |       4635.4 |
* native_mx8 CUTLASS kernel is in a separate diff (D99386004). Numbers shown for comparison only.
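For readers converting between TFLOPS and kernel time, the sketch below assumes the benchmark uses the usual 2*M*N*K FLOP count for a GEMM; the function names are hypothetical and not taken from the benchmarking framework.

```python
def gemm_tflops(m, n, k, seconds):
    """Throughput in TFLOPS, counting one multiply and one add per MAC."""
    return 2.0 * m * n * k / seconds / 1e12


def gemm_time_ms(m, n, k, tflops):
    """Invert gemm_tflops: kernel time in milliseconds for a given rate."""
    return 2.0 * m * n * k / (tflops * 1e12) * 1e3


# native_mx8_mx4 at M=1024, N=K=8192 from the table above.
time_ms = gemm_time_ms(1024, 8192, 8192, 2457.5)
print(f"{time_ms:.4f} ms")
```

Under that assumption the M=1024 entry for native_mx8_mx4 corresponds to a kernel time of roughly 0.056 ms.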
Output SQNR vs BF16 (dB, higher is better):
   M | native_mx8*  | native_mx8_mx4 | native_nvfp4 | native_mxfp4 |
     | (mx8mx8bf16) | (mx8mx4bf16)   | (f4f4bf16)   | (f4f4bf16)   |
-----+--------------+----------------+--------------+--------------+
   1 |        28.56 |          18.80 |        17.56 |        16.13 |
  16 |        28.46 |          18.80 |        17.41 |        16.03 |
  64 |        28.45 |          18.79 |        17.43 |        16.06 |
 128 |        28.46 |          18.79 |        17.43 |        16.04 |
 256 |        28.46 |          18.79 |        17.43 |        16.04 |
 512 |        28.46 |          18.79 |        17.43 |        16.05 |
1024 |        28.46 |          18.79 |        17.44 |        16.04 |
* native_mx8 numbers from D99386004.
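The SQNR figures can be read against the standard definition, sketched below. This assumes the benchmark computes SQNR as the ratio of reference signal power to error power in decibels, with the BF16 output as the reference; `sqnr_db` is a hypothetical helper, not the framework's function.

```python
import math


def sqnr_db(reference, output):
    """Signal-to-quantization-noise ratio in dB over flat value sequences."""
    signal = sum(r * r for r in reference)
    noise = sum((r - o) ** 2 for r, o in zip(reference, output))
    if noise == 0.0:
        return math.inf  # outputs match the reference exactly
    return 10.0 * math.log10(signal / noise)


def noise_fraction(sqnr):
    """Error power as a fraction of signal power for a given SQNR in dB."""
    return 10.0 ** (-sqnr / 10.0)
```

By this definition, 18.79 dB means the error power is roughly 1.3% of the signal power, while 28.46 dB corresponds to roughly 0.14%.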
Activation Quantization Overhead (ms):
     | Per-Format Quant Time                |                                       |
   M | to_mxfp8 | triton_mx4 | triton_nvfp4 | mx8_mx4 total                         |
     |          |            |              | (to_mxfp8 + _to_blocked + triton_mx4) |
-----+----------+------------+--------------+---------------------------------------+
   1 |    0.182 |      0.063 |        0.054 |                                 0.310 |
  16 |    0.191 |      0.055 |        0.053 |                                 0.306 |
  64 |    0.173 |      0.053 |        0.054 |                                 0.308 |
 128 |    0.179 |      0.053 |        0.053 |                                 0.272 |
 256 |    0.179 |      0.056 |        0.053 |                                 0.279 |
 512 |    0.183 |      0.071 |        0.063 |                                 0.287 |
1024 |    0.189 |      0.053 |        0.056 |                                 0.278 |
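One way to put the overhead in context is to fold it into an end-to-end rate. The sketch below assumes the 2*M*N*K FLOP count and assumes the full "mx8_mx4 total" time is paid on every call; if part of it is amortized (for example, weights quantized once offline), the effective rate would be higher. `effective_tflops` is a hypothetical helper.

```python
def effective_tflops(m, n, k, gemm_tflops, quant_ms):
    """GEMM throughput after adding a fixed quantization time per call."""
    flops = 2.0 * m * n * k
    gemm_s = flops / (gemm_tflops * 1e12)  # GEMM-only kernel time
    total_s = gemm_s + quant_ms * 1e-3     # plus quantization overhead
    return flops / total_s / 1e12


# native_mx8_mx4 at M=1024: 2457.5 GEMM-only TFLOPS, 0.278 ms quant total.
print(f"{effective_tflops(1024, 8192, 8192, 2457.5, 0.278):.1f} TFLOPS")
```

At these sizes the quantization time is several times larger than the GEMM time itself, so the GEMM-only numbers are an upper bound on what an unfused pipeline would see.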
Key takeaways:
- native_mx8_mx4 SQNR is 18.79 dB — 2.75 dB better than native_mxfp4 (16.04 dB), thanks to MX8 activation fidelity (31.49 dB input A SQNR vs 19.03 dB for MX4)
- mxf8f8 and mxf8f6f4 have the same peak throughput on B200 (~3,600 TFLOPS at 750W), so native_mx8 and native_mx8_mx4 show similar GEMM TFLOPS as expected
- The value of native_mx8_mx4 over native_mx8 is 2x weight compression (4-bit vs 8-bit) with no compute penalty and acceptable quality (18.79 dB SQNR)
- Correctness verified: 6/6 shapes pass, no NaN/Inf, consistent 18.79 dB SQNR
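The 2x weight-compression figure can be checked with a short calculation that also accounts for scale storage. It assumes one 8-bit E8M0 scale per 32-element block, as stated in the summary, and ignores any padding introduced by the blocked scale layout; the helper names are hypothetical.

```python
def mx_bits_per_element(element_bits, scale_bits=8, block=32):
    """Average storage per element, including the shared block scale."""
    return element_bits + scale_bits / block


def weight_bytes(n, k, element_bits):
    """Total bytes for an (n, k) weight matrix in an MX format."""
    return n * k * mx_bits_per_element(element_bits) / 8


mx8 = weight_bytes(8192, 8192, 8)  # MX8 weights, N=K=8192
mx4 = weight_bytes(8192, 8192, 4)  # MX4 weights, N=K=8192
print(mx8 / 2**20, mx4 / 2**20, mx8 / mx4)
```

With scales included, the ratio is 8.25 / 4.25, about 1.94x, so "2x" holds for the element data and is slightly lower once scale factors are counted.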
Reviewed By: cthi
Differential Revision: D97844552
@isratnisa has exported this pull request. If you are a Meta employee, you can view the originating Diff in D97844552.