Skip to content

Add native MX8×MX4 mixed-precision GEMM kernel (f8f4bf16)#313

Open
isratnisa wants to merge 1 commit into
meta-pytorch:mainfrom
isratnisa:export-D97844552
Open

Add native MX8×MX4 mixed-precision GEMM kernel (f8f4bf16)#313
isratnisa wants to merge 1 commit into
meta-pytorch:mainfrom
isratnisa:export-D97844552

Conversation

@isratnisa
Copy link
Copy Markdown

Summary:
Add a native mixed-precision CUTLASS GEMM kernel for MX8 (MXFP8 e4m3) x MX4 (MXFP4 e2m1) using the Blackwell mxf8f6f4 block-scaled tensor core instruction.

This follows the exact same dispatch path as the CUTLASS 4.3.5 example 72c_blackwell_mixed_mxfp8_bf16_gemm.cu: separate ElementA = mx_float8_t<float_e4m3_t> and ElementB = mx_float4_t<float_e2m1_t> types with OpClassBlockScaledTensorOp and KernelScheduleAuto, which causes the CollectiveBuilder to auto-select the mxf8f6f4 variant of the block-scaled tensor core MMA instruction (tcgen05.mma.blockscaled). Both operands use E8M0 scale factors with block size 32. The kernel template uses AlignmentA=128/sizeof_bits (=16), AlignmentB=128 (FP4, from CUTLASS 72c example), TileShapeK=256, and TN layout.

Changes:

  • New CUTLASS kernel directory mx8mx4bf16/ with 8 tile configurations
  • Dispatch file with heuristic-based and autotune-based kernel selection
  • TORCH_CHECK(K % 32 == 0) for block size validation
  • Registered as torch.ops.mslk.mx8mx4bf16 PyTorch op (CUDA + Meta dispatch)
  • Added NativeMX8MX4Quant class in quant_ops.py for benchmarking framework
  • Scale format fix: to_mxfp8() returns plain row-major scales, applied _to_blocked() conversion to match CUTLASS block-scaled kernel interleaved layout

Note: native_mx8 (MX8xMX8 CUTLASS kernel) numbers shown below are from D99386004, a separate stacked diff. cuBLAS via torch._scaled_mm is 3-6x slower at small M (e.g., 1.4 vs 8.9 TFLOPS at M=1).

Benchmark on B200 (750W), N=K=8192, num_iters=20:

GEMM-Only TFLOPS (higher is better):
| BF16 Fallback | Native Block-Scaled Tensor Core |
| (torch.matmul)| (tcgen05.mma.blockscaled) |
M | bf16_mm | native_mx8* | native_mx8_mx4 | native_nvfp4 | native_mxfp4 |
| | (mx8mx8bf16) | (mx8mx4bf16) | (f4f4bf16) | (f4f4bf16) |
-----+---------------+--------------+----------------+--------------+------------------+
1 | 4.7 | 8.9 | 8.8 | 10.1 | 10.7 |
16 | 88.3 | 141.4 | 143.5 | 164.5 | 175.6 |
64 | 337.7 | 556.1 | 572.7 | 691.1 | 700.2 |
128 | 636.3 | 1010.7 | 1120.3 | 1367.8 | 1432.4 |
256 | 984.8 | 1806.7 | 2013.4 | 2796.2 | 2834.6 |
512 | 1109.0 | 1987.9 | 2380.5 | 3618.0 | 4012.9 |
1024 | 1211.0 | 2066.1 | 2457.5 | 4431.2 | 4635.4 |

  • native_mx8 CUTLASS kernel is in a separate diff (D99386004). Numbers shown for comparison only.

Output SQNR vs BF16 (dB, higher is better):
M | native_mx8* | native_mx8_mx4 | native_nvfp4 | native_mxfp4 |
| (mx8mx8bf16) | (mx8mx4bf16) | (f4f4bf16) | (f4f4bf16) |
-----+--------------+----------------+--------------+--------------+
1 | 28.56 | 18.80 | 17.56 | 16.13 |
16 | 28.46 | 18.80 | 17.41 | 16.03 |
64 | 28.45 | 18.79 | 17.43 | 16.06 |
128 | 28.46 | 18.79 | 17.43 | 16.04 |
256 | 28.46 | 18.79 | 17.43 | 16.04 |
512 | 28.46 | 18.79 | 17.43 | 16.05 |
1024 | 28.46 | 18.79 | 17.44 | 16.04 |

  • native_mx8 numbers from D99386004.

Activation Quantization Overhead (ms):
| Per-Format Quant Time | |
M | to_mxfp8 | triton_mx4 | triton_nvfp4 | mx8_mx4 total |
| | | | (to_mxfp8 + _to_blocked + triton_mx4) |
-----+----------+------------+--------------+----------------------------------------+
1 | 0.182 | 0.063 | 0.054 | 0.310 |
16 | 0.191 | 0.055 | 0.053 | 0.306 |
64 | 0.173 | 0.053 | 0.054 | 0.308 |
128 | 0.179 | 0.053 | 0.053 | 0.272 |
256 | 0.179 | 0.056 | 0.053 | 0.279 |
512 | 0.183 | 0.071 | 0.063 | 0.287 |
1024 | 0.189 | 0.053 | 0.056 | 0.278 |

Key takeaways:

  • native_mx8_mx4 SQNR is 18.79 dB — 2.75 dB better than native_mxfp4 (16.04 dB), thanks to MX8 activation fidelity (31.49 dB input A SQNR vs 19.03 dB for MX4)
  • mxf8f8 and mxf8f6f4 have the same peak throughput on B200 (~3,600 TFLOPS at 750W), so native_mx8 and native_mx8_mx4 show similar GEMM TFLOPS as expected
  • The value of native_mx8_mx4 over native_mx8 is 2x weight compression (4-bit vs 8-bit) with no compute penalty and acceptable quality (18.79 dB SQNR)
  • Correctness verified: 6/6 shapes pass, no NaN/Inf, consistent 18.79 dB SQNR

Reviewed By: cthi

Differential Revision: D97844552

Summary:
Add a native mixed-precision CUTLASS GEMM kernel for MX8 (MXFP8 e4m3) x MX4 (MXFP4 e2m1) using the Blackwell mxf8f6f4 block-scaled tensor core instruction.

This follows the exact same dispatch path as the CUTLASS 4.3.5 example 72c_blackwell_mixed_mxfp8_bf16_gemm.cu: separate ElementA = mx_float8_t<float_e4m3_t> and ElementB = mx_float4_t<float_e2m1_t> types with OpClassBlockScaledTensorOp and KernelScheduleAuto, which causes the CollectiveBuilder to auto-select the mxf8f6f4 variant of the block-scaled tensor core MMA instruction (tcgen05.mma.blockscaled). Both operands use E8M0 scale factors with block size 32. The kernel template uses AlignmentA=128/sizeof_bits<ElementA> (=16), AlignmentB=128 (FP4, from CUTLASS 72c example), TileShapeK=256, and TN layout.

Changes:
- New CUTLASS kernel directory mx8mx4bf16/ with 8 tile configurations
- Dispatch file with heuristic-based and autotune-based kernel selection
- TORCH_CHECK(K % 32 == 0) for block size validation
- Registered as torch.ops.mslk.mx8mx4bf16 PyTorch op (CUDA + Meta dispatch)
- Added NativeMX8MX4Quant class in quant_ops.py for benchmarking framework
- Scale format fix: to_mxfp8() returns plain row-major scales, applied _to_blocked() conversion to match CUTLASS block-scaled kernel interleaved layout

Note: native_mx8 (MX8xMX8 CUTLASS kernel) numbers shown below are from D99386004, a separate stacked diff. cuBLAS via torch._scaled_mm is 3-6x slower at small M (e.g., 1.4 vs 8.9 TFLOPS at M=1).

Benchmark on B200 (750W), N=K=8192, num_iters=20:

GEMM-Only TFLOPS (higher is better):
       | BF16 Fallback |          Native Block-Scaled Tensor Core                          |
       | (torch.matmul)|          (tcgen05.mma.blockscaled)                                |
     M | bf16_mm       | native_mx8*  | native_mx8_mx4 | native_nvfp4 | native_mxfp4     |
       |               | (mx8mx8bf16) | (mx8mx4bf16)   | (f4f4bf16)   | (f4f4bf16)       |
  -----+---------------+--------------+----------------+--------------+------------------+
     1 |           4.7 |          8.9 |            8.8 |         10.1 |             10.7 |
    16 |          88.3 |        141.4 |          143.5 |        164.5 |            175.6 |
    64 |         337.7 |        556.1 |          572.7 |        691.1 |            700.2 |
   128 |         636.3 |       1010.7 |         1120.3 |       1367.8 |           1432.4 |
   256 |         984.8 |       1806.7 |         2013.4 |       2796.2 |           2834.6 |
   512 |        1109.0 |       1987.9 |         2380.5 |       3618.0 |           4012.9 |
  1024 |        1211.0 |       2066.1 |         2457.5 |       4431.2 |           4635.4 |

  * native_mx8 CUTLASS kernel is in a separate diff (D99386004). Numbers shown for comparison only.

Output SQNR vs BF16 (dB, higher is better):
     M | native_mx8*  | native_mx8_mx4 | native_nvfp4 | native_mxfp4 |
       | (mx8mx8bf16) | (mx8mx4bf16)   | (f4f4bf16)   | (f4f4bf16)   |
  -----+--------------+----------------+--------------+--------------+
     1 |        28.56 |          18.80 |        17.56 |        16.13 |
    16 |        28.46 |          18.80 |        17.41 |        16.03 |
    64 |        28.45 |          18.79 |        17.43 |        16.06 |
   128 |        28.46 |          18.79 |        17.43 |        16.04 |
   256 |        28.46 |          18.79 |        17.43 |        16.04 |
   512 |        28.46 |          18.79 |        17.43 |        16.05 |
  1024 |        28.46 |          18.79 |        17.44 |        16.04 |

  * native_mx8 numbers from D99386004.

Activation Quantization Overhead (ms):
       |       Per-Format Quant Time          |                                        |
     M | to_mxfp8 | triton_mx4 | triton_nvfp4 | mx8_mx4 total                          |
       |          |            |              | (to_mxfp8 + _to_blocked + triton_mx4)  |
  -----+----------+------------+--------------+----------------------------------------+
     1 |    0.182 |      0.063 |        0.054 |                                  0.310 |
    16 |    0.191 |      0.055 |        0.053 |                                  0.306 |
    64 |    0.173 |      0.053 |        0.054 |                                  0.308 |
   128 |    0.179 |      0.053 |        0.053 |                                  0.272 |
   256 |    0.179 |      0.056 |        0.053 |                                  0.279 |
   512 |    0.183 |      0.071 |        0.063 |                                  0.287 |
  1024 |    0.189 |      0.053 |        0.056 |                                  0.278 |

Key takeaways:
- native_mx8_mx4 SQNR is 18.79 dB — 2.75 dB better than native_mxfp4 (16.04 dB), thanks to MX8 activation fidelity (31.49 dB input A SQNR vs 19.03 dB for MX4)
- mxf8f8 and mxf8f6f4 have the same peak throughput on B200 (~3,600 TFLOPS at 750W), so native_mx8 and native_mx8_mx4 show similar GEMM TFLOPS as expected
- The value of native_mx8_mx4 over native_mx8 is 2x weight compression (4-bit vs 8-bit) with no compute penalty and acceptable quality (18.79 dB SQNR)
- Correctness verified: 6/6 shapes pass, no NaN/Inf, consistent 18.79 dB SQNR

Reviewed By: cthi

Differential Revision: D97844552
@meta-cla meta-cla Bot added the cla signed label Apr 6, 2026
@meta-codesync
Copy link
Copy Markdown

meta-codesync Bot commented Apr 6, 2026

@isratnisa has exported this pull request. If you are a Meta employee, you can view the originating Diff in D97844552.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant