From 0ebf70b1d3606d33e8fd468ca80c3b8d82869144 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 19 May 2026 18:27:07 +0000 Subject: [PATCH 1/8] add all mxfp8 related kernels --- .../_triton_kernels/gemm/basic/gemm_mxfp8.py | 417 ++++++++++++++++++ .../_triton_kernels/quant/quant_mxfp8.py | 384 ++++++++++++++++ ...0-GEMM-MXFP8-PRESHUFFLE-N=1536-K=4096.json | 170 +++++++ ...0-GEMM-MXFP8-PRESHUFFLE-N=4096-K=1024.json | 170 +++++++ ...50-GEMM-MXFP8-PRESHUFFLE-N=4096-K=256.json | 170 +++++++ ...50-GEMM-MXFP8-PRESHUFFLE-N=512-K=4096.json | 170 +++++++ ...0-GEMM-MXFP8-PRESHUFFLE-N=8192-K=1024.json | 170 +++++++ .../gemm/gfx950-GEMM-MXFP8-PRESHUFFLE.json | 14 + aiter/ops/triton/gemm/basic/gemm_mxfp8.py | 256 +++++++++++ aiter/ops/triton/quant/quant_mxfp8.py | 289 ++++++++++++ 10 files changed, 2210 insertions(+) create mode 100644 aiter/ops/triton/_triton_kernels/gemm/basic/gemm_mxfp8.py create mode 100644 aiter/ops/triton/_triton_kernels/quant/quant_mxfp8.py create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=1536-K=4096.json create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=256.json create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=512-K=4096.json create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=8192-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE.json create mode 100644 aiter/ops/triton/gemm/basic/gemm_mxfp8.py create mode 100644 aiter/ops/triton/quant/quant_mxfp8.py diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_mxfp8.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_mxfp8.py new file mode 100644 index 0000000000..beaf16cb23 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_mxfp8.py @@ -0,0 +1,417 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced 
Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl +from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr +from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd + + +_gemm_mxfp8_repr = make_kernel_repr( + "_gemm_mxfp8_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "num_warps", + "num_stages", + "waves_per_eu", + "matrix_instr_nonkdim", + "cache_modifier", + "NUM_KSPLIT", + ], +) + + +@triton.heuristics( + { + "EVEN_K": lambda args: (args["K"] % args["BLOCK_SIZE_K"] == 0), + } +) +@triton.jit(repr=_gemm_mxfp8_repr) +def _gemm_mxfp8_kernel( + a_ptr, + b_ptr, + c_ptr, + a_scales_ptr, + b_scales_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bsn, + stride_bsk, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + EVEN_K: tl.constexpr, + num_warps: tl.constexpr, + num_stages: tl.constexpr, + waves_per_eu: tl.constexpr, + matrix_instr_nonkdim: tl.constexpr, + cache_modifier: tl.constexpr, +): + """ + Kernel for computing the matmul C = A x B. + A and B inputs are FP8 e4m3 (1 byte per element). + A_scales are e8m0 (uint8) with shape (M, K // 32). + B_scales are stored compact e8m0 (uint8) with shape (N // 128, K // 128), + representing 128x128 weight blocks. Broadcast inside kernel to (N, K // 32). + A has shape (M, K), B has shape (K, N) and C has shape (M, N). + Output dtype is determined by c_ptr (bf16 or fp16). + NUM_KSPLIT == 1 only (no split-K in this first version). 
+ """ + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + tl.assume(stride_asm > 0) + tl.assume(stride_ask > 0) + tl.assume(stride_bsk > 0) + tl.assume(stride_bsn > 0) + + GRID_MN = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + + pid_unified = tl.program_id(axis=0) + pid_unified = remap_xcd(pid_unified, GRID_MN, NUM_XCDS=8) + + pid = pid_unified + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + # Scale group sizes + SCALE_GROUP_SIZE: tl.constexpr = 32 # A: per 32 elements along K + B_SCALE_K_GROUP: tl.constexpr = 128 # B: per 128 along K + B_SCALE_N_GROUP: tl.constexpr = 128 # B: per 128 along N + + # Number of K iterations + num_k_iter = tl.cdiv(K, BLOCK_SIZE_K) + + # Create pointers for first block of A and B input matrices + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + ) + b_ptrs = b_ptr + ( + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + ) + + # A-scale pointers: per-row (M) and per scale group (K // 32) + offs_ks_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + a_scale_ptrs = ( + a_scales_ptr + offs_am[:, None] * stride_asm + offs_ks_a[None, :] * stride_ask + ) + + # B-scale pointers: compact (N // 128, K // 128) — broadcast inside the kernel + # Each scale covers a 128(N) x 128(K) block. + # We want a tile of shape (BLOCK_SIZE_N, BLOCK_SIZE_K // 32) per K-iter. + # Broadcast by floor-division: row index n -> n // 128; k-scale index ks -> ks // 4 + # (since ks counts 32-element groups, and 128 / 32 = 4). 
+ offs_bsn = offs_bn // B_SCALE_N_GROUP # (BLOCK_SIZE_N,) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + offs_scale_k_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + + for k in range(0, num_k_iter): + # K base for this iteration (in elements) + k_base = k * BLOCK_SIZE_K + + # ---- Load A scales (M, BLOCK_SIZE_K // 32) ---- + if EVEN_K: + a_scales = tl.load(a_scale_ptrs) + else: + a_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + a_scales = tl.load(a_scale_ptrs, mask=a_scale_mask, other=127) + + # ---- Load and broadcast B scales (BLOCK_SIZE_N, BLOCK_SIZE_K // 32) ---- + # B-scale K index in compact layout (K // 128 stride) + k_scale_idx_b = (k_base + tl.arange(0, BLOCK_SIZE_K) // SCALE_GROUP_SIZE) // ( + B_SCALE_K_GROUP // SCALE_GROUP_SIZE + ) + # NOTE(review): k_scale_idx_b above has shape (BLOCK_SIZE_K,) and is never read; it is a dead leftover. + # The index actually used is offs_bsk below, with one entry per 32-element K group: + offs_bsk = ( + k_base + offs_scale_k_a * SCALE_GROUP_SIZE + ) // B_SCALE_K_GROUP # (BLOCK_SIZE_K // 32,) + b_scale_ptrs = ( + b_scales_ptr + + offs_bsn[:, None] * stride_bsn + + offs_bsk[None, :] * stride_bsk + ) + if EVEN_K: + b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) + else: + # OOB along K: load with the same mask as a-scales + b_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + b_scales = tl.load( + b_scale_ptrs, + mask=b_scale_mask, + other=127, + cache_modifier=cache_modifier, + ) + + # ---- Load A, B data ---- + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0 + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0, + cache_modifier=cache_modifier, + ) + + accumulator = tl.dot_scaled( + a, a_scales, "e4m3", b, b_scales, "e4m3", accumulator + ) + + # Advance the ptrs to the next K block. 
+ a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask + + c = accumulator.to(c_ptr.type.element_ty) + + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +_gemm_mxfp8_preshuffle_repr = make_kernel_repr( + "_gemm_mxfp8_preshuffle_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "num_warps", + "num_stages", + "waves_per_eu", + "matrix_instr_nonkdim", + "cache_modifier", + "NUM_KSPLIT", + ], +) + + +@triton.heuristics( + { + "EVEN_K": lambda args: (args["K"] % args["BLOCK_SIZE_K"] == 0), + } +) +@triton.jit(repr=_gemm_mxfp8_preshuffle_repr) +def _gemm_mxfp8_preshuffle_kernel( + a_ptr, + b_ptr, + c_ptr, + a_scales_ptr, + b_scales_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bn, + stride_bk, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bsn, + stride_bsk, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + EVEN_K: tl.constexpr, + num_warps: tl.constexpr, + num_stages: tl.constexpr, + waves_per_eu: tl.constexpr, + matrix_instr_nonkdim: tl.constexpr, + cache_modifier: tl.constexpr, +): + """ + Preshuffle variant of _gemm_mxfp8_kernel. Weight tensor has been shuffled + via aiter.ops.shuffle.shuffle_weight(layout=(16, 16)) so that 16-row N tiles + are interleaved with their 32-col K chunks in storage. 
The kernel loads the + shuffled tile in storage order (BLOCK_SIZE_N // 16, BLOCK_SIZE_K * 16) then + reshape+permute+trans inside the kernel to restore logical (K, N) layout + before tl.dot_scaled. Scales remain in the unshuffled compact 128x128 layout. + """ + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + tl.assume(stride_asm > 0) + tl.assume(stride_ask > 0) + tl.assume(stride_bsk > 0) + tl.assume(stride_bsn > 0) + + GRID_MN = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + + pid_unified = tl.program_id(axis=0) + pid_unified = remap_xcd(pid_unified, GRID_MN, NUM_XCDS=8) + + pid = pid_unified + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + SCALE_GROUP_SIZE: tl.constexpr = 32 # A: per-32 along K + B_SCALE_K_GROUP: tl.constexpr = 128 # B compact: per-128 along K + B_SCALE_N_GROUP: tl.constexpr = 128 # B compact: per-128 along N + + num_k_iter = tl.cdiv(K, BLOCK_SIZE_K) + + # A pointers (unchanged from non-preshuffle kernel). + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + + # B pointers for preshuffled layout. The shuffled storage is viewed as + # (N // 16, K * 16) elements. pid_n indexes BLOCK_SIZE_N // 16 N-tiles per + # step; the K dimension is expanded by 16x in byte addresses. + offs_bn_shuffle = pid_n * (BLOCK_SIZE_N // 16) + tl.arange(0, BLOCK_SIZE_N // 16) + offs_k_shuffle = tl.arange(0, BLOCK_SIZE_K * 16) + b_ptrs = b_ptr + ( + offs_bn_shuffle[:, None] * stride_bn + offs_k_shuffle[None, :] * stride_bk + ) + + # A-scale pointers: per-row M, per 32-K group. 
+ offs_ks_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + a_scale_ptrs = ( + a_scales_ptr + offs_am[:, None] * stride_asm + offs_ks_a[None, :] * stride_ask + ) + + # B-scale pointers: compact (N // 128, K // 128). The N index needs the + # ORIGINAL (logical) row, not the shuffled row index, so use offs_bn_logical. + offs_bn_logical = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bsn = offs_bn_logical // B_SCALE_N_GROUP + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + offs_scale_k_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + + for k in range(0, num_k_iter): + k_base = k * BLOCK_SIZE_K + + # Load A scales. + if EVEN_K: + a_scales = tl.load(a_scale_ptrs) + else: + a_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + a_scales = tl.load(a_scale_ptrs, mask=a_scale_mask, other=127) + + # Load and broadcast B scales (same as non-preshuffle path). + offs_bsk = ( + k_base + offs_scale_k_a * SCALE_GROUP_SIZE + ) // B_SCALE_K_GROUP + b_scale_ptrs = ( + b_scales_ptr + + offs_bsn[:, None] * stride_bsn + + offs_bsk[None, :] * stride_bsk + ) + if EVEN_K: + b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) + else: + b_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + b_scales = tl.load( + b_scale_ptrs, + mask=b_scale_mask, + other=127, + cache_modifier=cache_modifier, + ) + + # Load A and B (preshuffled). + if EVEN_K: + a = tl.load(a_ptrs) + b_shuf = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0 + ) + b_shuf = tl.load( + b_ptrs, + mask=offs_k_shuffle[None, :] < (K - k * BLOCK_SIZE_K) * 16, + other=0, + cache_modifier=cache_modifier, + ) + + # Unshuffle B in-kernel. 
Inverse of the shuffle_weight permute: + # shuffle: (1, N // 16, 16, K // 32, 2, 16) --[perm 0,1,3,4,2,5]-> (1, N // 16, K // 32, 2, 16, 16) + # unshuffle: (1, N // 16, K // 32, 2, 16, 16) --[perm 0,1,4,2,3,5]-> (1, N // 16, 16, K // 32, 2, 16) + # then flatten to (N, K) and trans to (K, N). + b = ( + b_shuf.reshape( + 1, + BLOCK_SIZE_N // 16, + BLOCK_SIZE_K // 32, + 2, + 16, + 16, + ) + .permute(0, 1, 4, 2, 3, 5) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K) + .trans(1, 0) + ) + + accumulator = tl.dot_scaled( + a, a_scales, "e4m3", b, b_scales, "e4m3", accumulator + ) + + # Advance pointers. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * 16 * stride_bk + a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask + + c = accumulator.to(c_ptr.type.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) diff --git a/aiter/ops/triton/_triton_kernels/quant/quant_mxfp8.py b/aiter/ops/triton/_triton_kernels/quant/quant_mxfp8.py new file mode 100644 index 0000000000..fe83509721 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/quant/quant_mxfp8.py @@ -0,0 +1,384 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl + + +# MXFP8 activation quant: per-1x32 e8m0 scale + FP8 e4m3 values. 
+# Follows aiter.ops.quant.per_1x32_f8_scale_f8_quant: +# MAX_POW2 = int(log2(448)) = 8 +# dtypeMax = 2 ** 8 = 256.0 +# scale_f32 = max_abs / dtypeMax +# scale_e8m0 = round_up_to_pow2(scale_f32) → e8m0 biased +# y = round(x_fp32 / e8m0_to_f32(scale_e8m0)) cast to fp8 e4m3 +# +# Per-block e8m0 derivation done with the same trick as the existing mxfp4 quant: +# - bitcast amax to int32 +# - add 0x200000 (round up to a power of 2 with respect to fp4-style rounding) +# - mask 0xFF800000 (keep only sign+exponent bits) +# - bitcast back to fp32 +# This delivers a pure power-of-2 amax. Then log2(amax).floor() - 8 gives the +# unbiased e8m0 exponent for MXFP8 (since dtypeMax = 2**8). + + +@triton.jit +def _mxfp8_quant_kernel( + x_ptr, + y_ptr, + s_ptr, + M, + N, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + stride_sm, + stride_sn, + BLOCK_SIZE_N: tl.constexpr, # power-of-2 covering full N + QUANT_BLOCK_SIZE: tl.constexpr, # =32 + NUM_PRGMS: tl.constexpr, # row-loop range (usually =M) +): + """ + Per-1x32 MXFP8 quant. One program per row, holding the full row in + registers so a single launch handles all K-groups. Mirrors + _rmsnorm_mxfp8_quant_kernel shape and minimizes grid overhead. 
+ """ + row_start = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE_N) + mask = col_offsets < N + n_groups: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE + + for row_idx in tl.range(row_start, M, NUM_PRGMS, num_stages=2): + x = tl.load( + x_ptr + row_idx * stride_xm + col_offsets * stride_xn, + mask=mask, + other=0.0, + ).to(tl.float32) + + # (BLOCK_SIZE_N,) -> (n_groups, QUANT_BLOCK_SIZE) + x_2d = tl.reshape(x, (n_groups, QUANT_BLOCK_SIZE)) + amax = tl.max(tl.abs(x_2d), axis=1, keep_dims=True) + + amax_i32 = amax.to(tl.int32, bitcast=True) + amax_i32 = (amax_i32 + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_p2 = amax_i32.to(tl.float32, bitcast=True) + scale_unbiased = tl.log2(amax_p2).floor() - 8 + scale_unbiased = tl.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(tl.int32) + 127).to(tl.uint8) + quant_scale = tl.exp2(-scale_unbiased) + + qx_2d = x_2d * quant_scale + qx = tl.reshape(qx_2d, (BLOCK_SIZE_N,)) + y = qx.to(y_ptr.type.element_ty) + + tl.store( + y_ptr + row_idx * stride_ym + col_offsets * stride_yn, + y, + mask=mask, + ) + + group_offsets = tl.arange(0, n_groups) + group_mask = group_offsets < (N // QUANT_BLOCK_SIZE) + scale_flat = tl.reshape(scale_e8m0, (n_groups,)) + tl.store( + s_ptr + row_idx * stride_sm + group_offsets * stride_sn, + scale_flat, + mask=group_mask, + ) + + +# Transcoder: (FP8 fnuz, fp32 1x128 scale) -> (FP8 fn, e8m0 1x32 scale). +# Replaces the Python dequant+requant cascade (fp32 cast + multiply + bf16 cast +# + per_1x32_mxfp8 quant) used in linear.py's MXFP8 fallback path for MLA wq_b +# when q_norm emits the legacy fp8 fnuz + fp32 1x128 format. 
+# +# In: x_fp8_fnuz (M, N) — fp8 e4m3fnuz bits (interpreted with bias 8 -> value) +# x_scale_fp32 (M, N//128) — fp32 per-token-block scale +# Out: y_fp8_fn (M, N) — fp8 e4m3fn bits (NV format, bias 7) +# y_scale_e8m0 (M, N//32) — uint8 e8m0 (1x32 MX scale) + + +@triton.jit +def _fp8_legacy_to_mxfp8_kernel( + x_fnuz_ptr, + x_scale_fp32_ptr, + y_fn_ptr, + y_scale_e8m0_ptr, + M, + N, + stride_xm, + stride_xn, + stride_xsm, + stride_xsn, + stride_ym, + stride_yn, + stride_ysm, + stride_ysn, + BLOCK_SIZE_M: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, # =32 (MXFP8 group) + LEGACY_BLOCK_SIZE: tl.constexpr, # =128 (input scale group) +): + """ + One program per (BLOCK_SIZE_M rows, QUANT_BLOCK_SIZE-element column window). + For each 1x32 block, dequantize fnuz fp8 values using the corresponding + 1x128 fp32 scale, derive the e8m0 (1x32) scale, then re-quantize to fp8 fn. + """ + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * QUANT_BLOCK_SIZE + tl.arange(0, QUANT_BLOCK_SIZE) + + x_offs = offs_m[:, None] * stride_xm + offs_n[None, :] * stride_xn + x_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + + # Load fp8 fnuz values; .to(fp32) decodes via fnuz bias 8 semantically. + x_fnuz = tl.load(x_fnuz_ptr + x_offs, mask=x_mask, other=0.0).to(tl.float32) + + # Which legacy 1x128 group does this 1x32 block fall into? + legacy_n = (pid_n * QUANT_BLOCK_SIZE) // LEGACY_BLOCK_SIZE + xs_offs = offs_m * stride_xsm + legacy_n * stride_xsn + xs_mask = offs_m < M + x_scale = tl.load(x_scale_fp32_ptr + xs_offs, mask=xs_mask, other=1.0) + + # Dequantize: bf16-equivalent reconstruction. + x_dq = x_fnuz * x_scale[:, None] + + # Derive new e8m0 (1x32) scale from x_dq amax. Same recipe as + # _mxfp8_quant_kernel above. 
+ amax = tl.max(tl.abs(x_dq), axis=1, keep_dims=True) + amax_i32 = amax.to(tl.int32, bitcast=True) + amax_i32 = (amax_i32 + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_p2 = amax_i32.to(tl.float32, bitcast=True) + scale_unbiased = tl.log2(amax_p2).floor() - 8 + scale_unbiased = tl.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(tl.int32) + 127).to(tl.uint8) + quant_scale = tl.exp2(-scale_unbiased) + + # Re-quantize to fp8 fn. + qx = x_dq * quant_scale + y = qx.to(y_fn_ptr.type.element_ty) + + y_offs = offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn + tl.store(y_fn_ptr + y_offs, y, mask=x_mask) + + s_offs = offs_m[:, None] * stride_ysm + pid_n * stride_ysn + s_mask = offs_m[:, None] < M + tl.store(y_scale_e8m0_ptr + s_offs, scale_e8m0, mask=s_mask) + + +# Fused RMSNorm + MXFP8 (1x32 e8m0) quant. Replaces the separate +# rmsnorm_quant(fp8 fnuz + fp32 1x128) + transcode-to-MXFP8 sequence used +# upstream of MXFP8-aware GEMMs (e.g. V4 q_norm -> wq_b). +# +# One program per row. Holds the full row in registers, so K is constrained +# by the BLOCK_SIZE_K constexpr (must be a power of two >= K). 
+# +# In: x (M, K) bf16 or fp16 +# g (K,) bf16 or fp16 weight +# Out: y (M, K) fp8 e4m3fn +# scale (M, K // 32) uint8 e8m0 + + +@triton.jit +def _rmsnorm_mxfp8_quant_kernel( + x_ptr, + g_ptr, + y_ptr, + s_ptr, + M, + K, + stride_xm, + stride_xk, + stride_ym, + stride_yk, + stride_sm, + stride_sn, + epsilon, + BLOCK_SIZE_K: tl.constexpr, # power-of-2 covering full K + QUANT_BLOCK_SIZE: tl.constexpr, # =32 + NUM_PRGMS: tl.constexpr, # for persistent-loop variant; usually =M +): + """One program processes one row: rmsnorm then MXFP8 quant in registers.""" + row_start = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE_K) + mask = col_offsets < K + + for row_idx in tl.range(row_start, M, NUM_PRGMS, num_stages=2): + # Load full row, cast to fp32 + x = tl.load( + x_ptr + row_idx * stride_xm + col_offsets * stride_xk, + mask=mask, + other=0.0, + ).to(tl.float32) + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + + # RMS norm + ss = tl.sum(x * x, axis=-1) + norm_factor = tl.math.rsqrt((ss / K) + epsilon) + y_fp32 = x * norm_factor * g # (BLOCK_SIZE_K,) + + # Reshape into (K // QUANT_BLOCK_SIZE, QUANT_BLOCK_SIZE) groups for amax. + # BLOCK_SIZE_K is the power-of-2 padded size; we keep OOB lanes masked to 0 + # via the load above, so amax over them is 0 (won't affect the in-bounds max). + y_2d = tl.reshape(y_fp32, (BLOCK_SIZE_K // QUANT_BLOCK_SIZE, QUANT_BLOCK_SIZE)) + amax = tl.max(tl.abs(y_2d), axis=1, keep_dims=True) # (G, 1) + + # e8m0 scale derivation (same recipe as _mxfp8_quant_kernel). 
+ amax_i32 = amax.to(tl.int32, bitcast=True) + amax_i32 = (amax_i32 + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_p2 = amax_i32.to(tl.float32, bitcast=True) + scale_unbiased = tl.log2(amax_p2).floor() - 8 + scale_unbiased = tl.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(tl.int32) + 127).to(tl.uint8) # (G, 1) + quant_scale = tl.exp2(-scale_unbiased) # (G, 1) + + # Quantize: y_quant = y_fp32 * quant_scale (broadcast along inner 32). + qx_2d = y_2d * quant_scale + qx = tl.reshape(qx_2d, (BLOCK_SIZE_K,)) + y_fp8 = qx.to(y_ptr.type.element_ty) + + # Store y (mask OOB). + tl.store( + y_ptr + row_idx * stride_ym + col_offsets * stride_yk, + y_fp8, + mask=mask, + ) + + # Store scales: G entries for this row. + n_groups: tl.constexpr = BLOCK_SIZE_K // QUANT_BLOCK_SIZE + group_offsets = tl.arange(0, n_groups) + group_mask = group_offsets < (K // QUANT_BLOCK_SIZE) + scale_flat = tl.reshape(scale_e8m0, (n_groups,)) + tl.store( + s_ptr + row_idx * stride_sm + group_offsets * stride_sn, + scale_flat, + mask=group_mask, + ) + + +# Dual fused RMSNorm: Q-side (MXFP8 quant + e8m0 scale emit) + K-side (bf16 out). +# Replaces the CK `fused_qk_rmsnorm_group_quant` semantics in one Triton launch +# for the MXFP8 GEMM path (Task #77). The two halves are independent (different +# weight, different K dim) so they're packed into one program per row to amortize +# launch overhead: same kernel launch loads both rows, normalizes both, stores Q +# fp8 + scale, stores K bf16. Each row's Q and K are independently RMSNorm'd +# (separate weights, separate eps, separate K dim) -- this kernel does NOT fuse +# their normalization arithmetic, only their launch. 
+# +# In: q (M, KQ) bf16 or fp16 +# kv (M, KK) bf16 or fp16 +# gq (KQ,) bf16 or fp16 Q-RMSNorm weight +# gk (KK,) bf16 or fp16 K-RMSNorm weight +# Out: yq (M, KQ) fp8 e4m3fn +# sq (M, KQ // 32) uint8 e8m0 +# yk (M, KK) bf16 + + +@triton.jit +def _dual_rmsnorm_mxfp8_quant_kernel( + q_ptr, + k_ptr, + gq_ptr, + gk_ptr, + yq_ptr, + sq_ptr, + yk_ptr, + M, + KQ, + KK, + stride_qm, + stride_qn, + stride_km, + stride_kn, + stride_yqm, + stride_yqn, + stride_sqm, + stride_sqn, + stride_ykm, + stride_ykn, + eps_q, + eps_k, + BLOCK_SIZE_KQ: tl.constexpr, # power-of-2 covering full KQ + BLOCK_SIZE_KK: tl.constexpr, # power-of-2 covering full KK + QUANT_BLOCK_SIZE: tl.constexpr, # =32 (MXFP8 group size) + NUM_PRGMS: tl.constexpr, # row-loop bound (usually =M) +): + """One program per row: do Q-side RMSNorm+MXFP8 quant AND K-side RMSNorm + (bf16 out) in one launch. Mirrors the CK `fused_qk_rmsnorm_group_quant` + fusion topology but emits MXFP8 1x32 (e8m0) scales for Q directly.""" + row_start = tl.program_id(0) + + q_col_offsets = tl.arange(0, BLOCK_SIZE_KQ) + q_mask = q_col_offsets < KQ + k_col_offsets = tl.arange(0, BLOCK_SIZE_KK) + k_mask = k_col_offsets < KK + + n_q_groups: tl.constexpr = BLOCK_SIZE_KQ // QUANT_BLOCK_SIZE + + for row_idx in tl.range(row_start, M, NUM_PRGMS, num_stages=2): + # ===== Q side: RMSNorm + MXFP8 quant ===== + x_q = tl.load( + q_ptr + row_idx * stride_qm + q_col_offsets * stride_qn, + mask=q_mask, + other=0.0, + ).to(tl.float32) + g_q = tl.load(gq_ptr + q_col_offsets, mask=q_mask, other=0.0).to( + tl.float32 + ) + + ss_q = tl.sum(x_q * x_q, axis=-1) + norm_q = tl.math.rsqrt((ss_q / KQ) + eps_q) + y_q_fp32 = x_q * norm_q * g_q + + y_q_2d = tl.reshape(y_q_fp32, (n_q_groups, QUANT_BLOCK_SIZE)) + amax_q = tl.max(tl.abs(y_q_2d), axis=1, keep_dims=True) + + amax_qi32 = amax_q.to(tl.int32, bitcast=True) + amax_qi32 = (amax_qi32 + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_qp2 = amax_qi32.to(tl.float32, bitcast=True) + scale_q_unbiased = 
tl.log2(amax_qp2).floor() - 8 + scale_q_unbiased = tl.clamp(scale_q_unbiased, min=-127, max=127) + scale_q_e8m0 = (scale_q_unbiased.to(tl.int32) + 127).to(tl.uint8) + quant_scale_q = tl.exp2(-scale_q_unbiased) + + qx_q_2d = y_q_2d * quant_scale_q + qx_q = tl.reshape(qx_q_2d, (BLOCK_SIZE_KQ,)) + y_q_fp8 = qx_q.to(yq_ptr.type.element_ty) + + tl.store( + yq_ptr + row_idx * stride_yqm + q_col_offsets * stride_yqn, + y_q_fp8, + mask=q_mask, + ) + + q_group_offsets = tl.arange(0, n_q_groups) + q_group_mask = q_group_offsets < (KQ // QUANT_BLOCK_SIZE) + scale_q_flat = tl.reshape(scale_q_e8m0, (n_q_groups,)) + tl.store( + sq_ptr + row_idx * stride_sqm + q_group_offsets * stride_sqn, + scale_q_flat, + mask=q_group_mask, + ) + + # ===== K side: RMSNorm only, bf16 out ===== + x_k = tl.load( + k_ptr + row_idx * stride_km + k_col_offsets * stride_kn, + mask=k_mask, + other=0.0, + ).to(tl.float32) + g_k = tl.load(gk_ptr + k_col_offsets, mask=k_mask, other=0.0).to( + tl.float32 + ) + + ss_k = tl.sum(x_k * x_k, axis=-1) + norm_k = tl.math.rsqrt((ss_k / KK) + eps_k) + y_k_fp32 = x_k * norm_k * g_k + y_k_out = y_k_fp32.to(yk_ptr.type.element_ty) + + tl.store( + yk_ptr + row_idx * stride_ykm + k_col_offsets * stride_ykn, + y_k_out, + mask=k_mask, + ) diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=1536-K=4096.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=1536-K=4096.json new file mode 100644 index 0000000000..6b2567e194 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=1536-K=4096.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + 
"matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + 
"NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=1024.json new file mode 100644 index 0000000000..5ba817728b --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=1024.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + 
"waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + 
"cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=256.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=256.json new file mode 100644 index 0000000000..e4b178dda7 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=256.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + 
"num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 0, + 
"matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=512-K=4096.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=512-K=4096.json new file mode 100644 index 0000000000..63ea464784 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=512-K=4096.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 
2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=8192-K=1024.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=8192-K=1024.json new file mode 100644 index 0000000000..cb2547fbbd --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=8192-K=1024.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + 
"num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + 
"waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE.json new file mode 100644 index 0000000000..0008e2ab97 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/gemm/basic/gemm_mxfp8.py b/aiter/ops/triton/gemm/basic/gemm_mxfp8.py new file mode 100644 index 0000000000..56d4d2f44c --- /dev/null +++ b/aiter/ops/triton/gemm/basic/gemm_mxfp8.py @@ -0,0 +1,256 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Optional + +import torch +import triton + +from aiter.ops.triton._triton_kernels.gemm.basic.gemm_mxfp8 import ( + _gemm_mxfp8_kernel, + _gemm_mxfp8_preshuffle_kernel, +) +from aiter.ops.triton.utils.gemm_config_utils import get_gemm_config + + +_DEFAULT_CONFIG = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "NUM_KSPLIT": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", +} + + +def _get_default_config() -> dict: + return dict(_DEFAULT_CONFIG) + + +# ----------------------------------------------------------------------------- +# Tuned-config lookup for gemm_mxfp8_preshuffle. +# +# Configs live under aiter.ops.triton.configs.gemm using the standard aiter +# naming: gfx{arch}-GEMM-MXFP8-PRESHUFFLE-N={N}-K={K}.json, keyed by +# M_LEQ_x / M_GEQ_x / any per STANDARD_M_BOUNDS. 
A generic fallback +# gfx{arch}-GEMM-MXFP8-PRESHUFFLE.json covers untuned (N, K) shapes. +# ----------------------------------------------------------------------------- + +# Shapes that have OOM'd at runtime with the tuned config. Future calls +# bypass the tuned lookup for these and use _DEFAULT_CONFIG. The tuned +# benches don't run under HIP-graph capture, so they can pick configs that +# look fine in isolation but blow LDS once captured. This cache survives +# across requests in the same process. +_OOM_SHAPES: set = set() + + +def _mark_oom(M: int, N: int, K: int): + _OOM_SHAPES.add((M, N, K)) + + +def _get_config(M: int, N: int, K: int) -> dict: + """Load the best tuned config for (M, N, K) via the standard aiter JSON + config mechanism. Falls back to the generic fallback file or to + _DEFAULT_CONFIG if no JSON is available.""" + try: + config, _ = get_gemm_config("GEMM-MXFP8-PRESHUFFLE", M, N, K) + except (AssertionError, KeyError, FileNotFoundError): + config = _get_default_config() + # Always pin NUM_KSPLIT=1 (this version of the kernel doesn't split K). + config["NUM_KSPLIT"] = 1 + # The JSON uses cache_modifier=null which decodes to Python None; the + # kernel accepts None or "", but we keep the empty-string default from + # _DEFAULT_CONFIG to match the legacy code path. + if config.get("cache_modifier") is None: + config["cache_modifier"] = "" + return config + + +def gemm_mxfp8( + x: torch.Tensor, + w: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, +) -> torch.Tensor: + """ + Computes matrix multiplication Y = X @ W^T with MXFP8 activations and FP8 + weights (1x32 e8m0 act scales, 128x128 e8m0 weight scales). + + Args: + x: FP8 e4m3 (or uint8 view) input matrix with shape (M, K). + w: FP8 e4m3 (or uint8 view) weight matrix with shape (N, K) — internally + transposed to (K, N) before the kernel call. 
+ x_scales: e8m0 (uint8) per-group scale for x with shape (M, K // 32). + w_scales: e8m0 (uint8) per-block scale for w with shape (N // 128, K // 128). + dtype: Output dtype (BF16 or FP16). Default bf16. + y: Optional pre-allocated output tensor with shape (M, N). + config: Optional kernel-tuning dict. If None uses defaults. + + Returns: + torch.Tensor: Output with shape (M, N). + """ + M, K = x.shape + N, K_w = w.shape + assert K == K_w, f"K mismatch: x has K={K}, w has K={K_w}" + + # Transpose w to (K, N) for the kernel. + w_t = w.T + + # tl.dot_scaled with format "e4m3" expects uint8-typed operands; reinterpret + # the FP8 buffers as uint8 (bit-identical view). + if x.dtype != torch.uint8: + x = x.view(torch.uint8) + if w_t.dtype != torch.uint8: + w_t = w_t.view(torch.uint8) + + if config is None: + config = _get_default_config() + else: + # Merge with defaults so missing keys are filled. + merged = _get_default_config() + merged.update(config) + config = merged + + # First-version: NUM_KSPLIT must be 1 + config["NUM_KSPLIT"] = 1 + + if y is None: + y = torch.empty((M, N), dtype=dtype, device=x.device) + + grid = lambda META: ( # noqa: E731 + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + _gemm_mxfp8_kernel[grid]( + x, + w_t, + y, + x_scales, + w_scales, + M, + N, + K, + x.stride(0), + x.stride(1), + w_t.stride(0), + w_t.stride(1), + y.stride(0), + y.stride(1), + x_scales.stride(0), + x_scales.stride(1), + w_scales.stride(0), + w_scales.stride(1), + **config, + ) + + return y + + +def gemm_mxfp8_preshuffle( + x: torch.Tensor, + w_shuffled: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, +) -> torch.Tensor: + """ + Preshuffle variant of gemm_mxfp8. The weight tensor has already been + permuted via aiter.ops.shuffle.shuffle_weight(..., layout=(16, 16)). 
Scales + are left unshuffled in the compact 128x128 layout. + + Args: + x: FP8 e4m3 activations with shape (M, K). + w_shuffled: FP8 e4m3 weights, shuffled in place to (N, K) storage + (same total bytes; bytes rearranged for the kernel's read pattern). + x_scales: e8m0 (uint8) per-token scale with shape (M, K // 32). + w_scales: e8m0 (uint8) per-block weight scale with shape (N // 128, K // 128). + dtype: Output dtype. + y: Optional pre-allocated output (M, N). + config: Optional kernel-tuning dict. + + Returns: + torch.Tensor: Output with shape (M, N). + """ + M, K = x.shape + N, K_w = w_shuffled.shape + assert K == K_w, f"K mismatch: x={K}, w={K_w}" + assert N % 16 == 0, f"N must be divisible by 16 for preshuffle, got {N}" + + # The kernel expects to address the shuffled tensor as (N//16, K*16). + w_view = w_shuffled.view(N // 16, K * 16) + + if x.dtype != torch.uint8: + x = x.view(torch.uint8) + if w_view.dtype != torch.uint8: + w_view = w_view.view(torch.uint8) + + if config is None: + if (M, N, K) in _OOM_SHAPES: + # We've previously OOM'd here under HIP-graph capture; skip the + # tuned lookup and use the conservative default. 
+ config = _get_default_config() + else: + config = _get_config(M, N, K) + else: + merged = _get_default_config() + merged.update(config) + config = merged + + config["NUM_KSPLIT"] = 1 + + if y is None: + y = torch.empty((M, N), dtype=dtype, device=x.device) + + grid = lambda META: ( # noqa: E731 + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + def _launch(cfg): + _gemm_mxfp8_preshuffle_kernel[grid]( + x, + w_view, + y, + x_scales, + w_scales, + M, + N, + K, + x.stride(0), + x.stride(1), + w_view.stride(0), + w_view.stride(1), + y.stride(0), + y.stride(1), + x_scales.stride(0), + x_scales.stride(1), + w_scales.stride(0), + w_scales.stride(1), + **cfg, + ) + + try: + _launch(config) + except Exception as e: + # The tuned config may overflow gfx950 LDS (163840 bytes) under HIP + # graph capture even when the standalone bench succeeds. Fall back + # to default and cache the failure to avoid retrying on every call. + if ( + "OutOfResources" in type(e).__name__ + or "shared memory" in str(e) + or "Required" in str(e) + ): + _mark_oom(M, N, K) + _launch(_get_default_config() | {"NUM_KSPLIT": 1}) + else: + raise + + return y diff --git a/aiter/ops/triton/quant/quant_mxfp8.py b/aiter/ops/triton/quant/quant_mxfp8.py new file mode 100644 index 0000000000..ae82cafab8 --- /dev/null +++ b/aiter/ops/triton/quant/quant_mxfp8.py @@ -0,0 +1,289 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
+ +from typing import Optional, Tuple +import torch +import triton + +from aiter.ops.triton._triton_kernels.quant.quant_mxfp8 import ( + _mxfp8_quant_kernel, + _fp8_legacy_to_mxfp8_kernel, + _rmsnorm_mxfp8_quant_kernel, + _dual_rmsnorm_mxfp8_quant_kernel, +) + + +_QUANT_BLOCK_SIZE = 32 + + +def per_1x32_mxfp8_quant_triton( + x: torch.Tensor, + scale: Optional[torch.Tensor] = None, + quant_dtype: torch.dtype = torch.float8_e4m3fn, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Per-1x32 MXFP8 quantization (e8m0 scale + FP8 e4m3 values). + + Args: + x: Input tensor (..., K). Typically bf16 or fp16. K % 32 == 0. + scale: Pre-allocated scale tensor (M, K // 32) uint8. Optional. + quant_dtype: FP8 dtype to cast quantized values to. On MI3xx + torch.float8_e4m3fnuz is the canonical FP8 e4m3 type. torch.float8_e4m3fn + is acceptable on hardware that supports it. + + Returns: + Tuple of: + y: FP8 tensor of shape x.shape. + s: e8m0 (uint8) scale tensor of shape (..., K // 32). + """ + assert x.dim() >= 2, f"x must be at least 2D, got {x.dim()}" + orig_shape = x.shape + K = orig_shape[-1] + assert ( + K % _QUANT_BLOCK_SIZE == 0 + ), f"last dim K={K} must be a multiple of {_QUANT_BLOCK_SIZE}" + + x2d = x.reshape(-1, K).contiguous() + M = x2d.shape[0] + Ns = K // _QUANT_BLOCK_SIZE # number of scales per row + + y = torch.empty((M, K), dtype=quant_dtype, device=x.device) + if scale is None: + scale = torch.empty((M, Ns), dtype=torch.uint8, device=x.device) + else: + assert scale.shape == (M, Ns), f"scale shape {scale.shape} != ({M},{Ns})" + assert scale.dtype == torch.uint8 + + BLOCK_SIZE_N = triton.next_power_of_2(K) + NUM_PRGMS = M + grid = (NUM_PRGMS,) + + _mxfp8_quant_kernel[grid]( + x2d, + y, + scale, + M, + K, + x2d.stride(0), + x2d.stride(1), + y.stride(0), + y.stride(1), + scale.stride(0), + scale.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS, + ) + + y = y.view(*orig_shape[:-1], K) + s = 
scale.view(*orig_shape[:-1], Ns) + return y, s + + +_LEGACY_BLOCK_SIZE = 128 + + +def fp8_legacy_to_mxfp8( + x_fnuz: torch.Tensor, + x_scale_fp32: torch.Tensor, + y_fn: Optional[torch.Tensor] = None, + y_scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Transcode (FP8 e4m3fnuz, fp32 1x128 scale) -> (FP8 e4m3fn, e8m0 1x32 scale) + in a single Triton launch. Replaces the Python dequant+requant cascade + used when MXFP8 path receives legacy-formatted (FP8 + fp32 1x128) inputs. + + Args: + x_fnuz: FP8 e4m3fnuz tensor of shape (M, N), N % 32 == 0. + x_scale_fp32: fp32 scale of shape (M, N // 128). + y_fn: optional preallocated output FP8 e4m3fn tensor. + y_scale: optional preallocated uint8 e8m0 scale tensor. + + Returns: + y_fn (M, N) fp8 e4m3fn, y_scale (M, N // 32) uint8 e8m0. + """ + assert x_fnuz.dim() == 2, f"x must be 2D, got {x_fnuz.dim()}" + M, N = x_fnuz.shape + assert N % _QUANT_BLOCK_SIZE == 0 + assert N % _LEGACY_BLOCK_SIZE == 0 + assert x_scale_fp32.shape == (M, N // _LEGACY_BLOCK_SIZE), ( + f"x_scale_fp32 shape {x_scale_fp32.shape} != ({M},{N // _LEGACY_BLOCK_SIZE})" + ) + + Ns = N // _QUANT_BLOCK_SIZE + if y_fn is None: + y_fn = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x_fnuz.device) + if y_scale is None: + y_scale = torch.empty((M, Ns), dtype=torch.uint8, device=x_fnuz.device) + + BLOCK_SIZE_M = 1 + grid = (triton.cdiv(M, BLOCK_SIZE_M), Ns) + + _fp8_legacy_to_mxfp8_kernel[grid]( + x_fnuz, + x_scale_fp32, + y_fn, + y_scale, + M, + N, + x_fnuz.stride(0), + x_fnuz.stride(1), + x_scale_fp32.stride(0), + x_scale_fp32.stride(1), + y_fn.stride(0), + y_fn.stride(1), + y_scale.stride(0), + y_scale.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + LEGACY_BLOCK_SIZE=_LEGACY_BLOCK_SIZE, + ) + + return y_fn, y_scale + + +def rmsnorm_mxfp8_quant( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, + y: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, +) -> 
Tuple[torch.Tensor, torch.Tensor]: + """ + Fused RMSNorm + MXFP8 (1x32 e8m0) quant in a single Triton launch. + + Args: + x: (M, K) bf16 or fp16. + weight: (K,) bf16 or fp16 RMSNorm weight. + eps: RMSNorm epsilon. + y: optional preallocated FP8 e4m3fn output (M, K). + scale: optional preallocated uint8 e8m0 output (M, K // 32). + + Returns: + y (M, K) fp8 e4m3fn, scale (M, K // 32) uint8. + """ + assert x.dim() == 2, f"x must be 2D, got {x.dim()}" + M, K = x.shape + assert weight.shape == (K,), f"weight shape {weight.shape} != ({K},)" + assert K % _QUANT_BLOCK_SIZE == 0 + Ns = K // _QUANT_BLOCK_SIZE + BLOCK_SIZE_K = triton.next_power_of_2(K) + + if y is None: + y = torch.empty((M, K), dtype=torch.float8_e4m3fn, device=x.device) + if scale is None: + scale = torch.empty((M, Ns), dtype=torch.uint8, device=x.device) + + NUM_PRGMS = M + grid = (NUM_PRGMS,) + + _rmsnorm_mxfp8_quant_kernel[grid]( + x, + weight, + y, + scale, + M, + K, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + scale.stride(0), + scale.stride(1), + eps, + BLOCK_SIZE_K=BLOCK_SIZE_K, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS, + ) + return y, scale + + +def dual_rmsnorm_mxfp8_quant( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + eps_q: float, + eps_k: Optional[float] = None, + yq: Optional[torch.Tensor] = None, + sq: Optional[torch.Tensor] = None, + yk: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fused dual RMSNorm in a single Triton launch. + + - Q side: RMSNorm(q, q_weight, eps_q) -> MXFP8 (FP8 e4m3fn + uint8 e8m0 1x32). + - K side: RMSNorm(k, k_weight, eps_k) -> bf16. + + Replaces the CK `fused_qk_rmsnorm_group_quant` kernel for the MXFP8 GEMM + path on V4 (Task #77): one launch instead of two (rmsnorm_mxfp8_quant + + rmsnorm2d_fwd_), eliminating the ~6us/layer launch-overhead regression. + + Args: + q: (M, KQ) bf16 or fp16 — Q-side input (e.g. q_lora). 
+ k: (M, KK) bf16 or fp16 — K-side input (e.g. kv_pre). + q_weight: (KQ,) bf16 or fp16 — Q RMSNorm weight. + k_weight: (KK,) bf16 or fp16 — K RMSNorm weight. + eps_q: Q RMSNorm epsilon. + eps_k: K RMSNorm epsilon; defaults to eps_q. + yq, sq, yk: optional pre-allocated outputs. + + Returns: + yq (M, KQ) fp8 e4m3fn, sq (M, KQ // 32) uint8 e8m0, yk (M, KK) bf16. + """ + assert q.dim() == 2, f"q must be 2D, got {q.dim()}" + assert k.dim() == 2, f"k must be 2D, got {k.dim()}" + M, KQ = q.shape + Mk, KK = k.shape + assert M == Mk, f"q rows {M} != k rows {Mk}" + assert q_weight.shape == (KQ,), f"q_weight shape {q_weight.shape} != ({KQ},)" + assert k_weight.shape == (KK,), f"k_weight shape {k_weight.shape} != ({KK},)" + assert KQ % _QUANT_BLOCK_SIZE == 0, ( + f"KQ={KQ} must be a multiple of {_QUANT_BLOCK_SIZE}" + ) + if eps_k is None: + eps_k = eps_q + + Ns = KQ // _QUANT_BLOCK_SIZE + BLOCK_SIZE_KQ = triton.next_power_of_2(KQ) + BLOCK_SIZE_KK = triton.next_power_of_2(KK) + + if yq is None: + yq = torch.empty((M, KQ), dtype=torch.float8_e4m3fn, device=q.device) + if sq is None: + sq = torch.empty((M, Ns), dtype=torch.uint8, device=q.device) + if yk is None: + yk = torch.empty((M, KK), dtype=k.dtype, device=k.device) + + NUM_PRGMS = M + grid = (NUM_PRGMS,) + + _dual_rmsnorm_mxfp8_quant_kernel[grid]( + q, + k, + q_weight, + k_weight, + yq, + sq, + yk, + M, + KQ, + KK, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + yq.stride(0), + yq.stride(1), + sq.stride(0), + sq.stride(1), + yk.stride(0), + yk.stride(1), + eps_q, + eps_k, + BLOCK_SIZE_KQ=BLOCK_SIZE_KQ, + BLOCK_SIZE_KK=BLOCK_SIZE_KK, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS, + ) + return yq, sq, yk From cdc749127e890d4d778cea573bc97c88a221bbd5 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 19 May 2026 20:46:04 +0000 Subject: [PATCH 2/8] add ut --- .../fusions/fused_clamp_act_mul.py | 39 +- .../basic/{gemm_mxfp8.py => gemm_afp8wfp8.py} | 57 +-- 
...MM-AFP8WFP8-PRESHUFFLE-N=1536-K=4096.json} | 0 ...MM-AFP8WFP8-PRESHUFFLE-N=4096-K=1024.json} | 0 ...EMM-AFP8WFP8-PRESHUFFLE-N=4096-K=256.json} | 0 ...EMM-AFP8WFP8-PRESHUFFLE-N=512-K=4096.json} | 0 ...MM-AFP8WFP8-PRESHUFFLE-N=8192-K=1024.json} | 0 ...n => gfx950-GEMM-AFP8WFP8-PRESHUFFLE.json} | 0 .../configs/gemm/gfx950-GEMM-AFP8WFP8.json | 14 + .../ops/triton/fusions/fused_clamp_act_mul.py | 35 +- .../basic/{gemm_mxfp8.py => gemm_afp8wfp8.py} | 233 ++++++----- .../gemm/basic/test_gemm_afp8wfp8.py | 150 ++++++++ .../triton_tests/quant/test_quant_mxfp8.py | 362 ++++++++++++++++++ 13 files changed, 748 insertions(+), 142 deletions(-) rename aiter/ops/triton/_triton_kernels/gemm/basic/{gemm_mxfp8.py => gemm_afp8wfp8.py} (90%) rename aiter/ops/triton/configs/gemm/{gfx950-GEMM-MXFP8-PRESHUFFLE-N=1536-K=4096.json => gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=1536-K=4096.json} (100%) rename aiter/ops/triton/configs/gemm/{gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=1024.json => gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=4096-K=1024.json} (100%) rename aiter/ops/triton/configs/gemm/{gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=256.json => gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=4096-K=256.json} (100%) rename aiter/ops/triton/configs/gemm/{gfx950-GEMM-MXFP8-PRESHUFFLE-N=512-K=4096.json => gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=512-K=4096.json} (100%) rename aiter/ops/triton/configs/gemm/{gfx950-GEMM-MXFP8-PRESHUFFLE-N=8192-K=1024.json => gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=8192-K=1024.json} (100%) rename aiter/ops/triton/configs/gemm/{gfx950-GEMM-MXFP8-PRESHUFFLE.json => gfx950-GEMM-AFP8WFP8-PRESHUFFLE.json} (100%) create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8.json rename aiter/ops/triton/gemm/basic/{gemm_mxfp8.py => gemm_afp8wfp8.py} (53%) create mode 100644 op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py create mode 100644 op_tests/triton_tests/quant/test_quant_mxfp8.py diff --git a/aiter/ops/triton/_triton_kernels/fusions/fused_clamp_act_mul.py 
b/aiter/ops/triton/_triton_kernels/fusions/fused_clamp_act_mul.py index f5887fdcb5..589cc372f0 100644 --- a/aiter/ops/triton/_triton_kernels/fusions/fused_clamp_act_mul.py +++ b/aiter/ops/triton/_triton_kernels/fusions/fused_clamp_act_mul.py @@ -23,6 +23,7 @@ [ "BLOCK_SIZE_N", "QUANT_BLOCK_SIZE", + "SCALE_FMT", "HAVE_WEIGHTS", "WEIGHT_BROADCAST", "HAVE_SWIGLU_CLAMP", @@ -50,6 +51,7 @@ def _fused_clamp_silu_mul_kernel( swiglu_limit, BLOCK_SIZE_N: tl.constexpr, QUANT_BLOCK_SIZE: tl.constexpr, + SCALE_FMT: tl.constexpr, DTYPE_MAX: tl.constexpr, DTYPE_MIN: tl.constexpr, HAVE_WEIGHTS: tl.constexpr, @@ -95,11 +97,34 @@ def _fused_clamp_silu_mul_kernel( out = out * w if HAS_QUANT: - out_q, block_scales = _fp8_quant_op( - out, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN - ) - out_q = tl.ravel(out_q) - block_scales = tl.ravel(block_scales) + if SCALE_FMT == "ue8m0": + # Per-1×QUANT_BLOCK_SIZE MXFP8 emit: fp8 e4m3 values + uint8 ue8m0 + # biased-exponent scales. Mirrors the ue8m0 path used by moe_gemm_a8w4. 
+ NUM_QB: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE + out_3d = tl.reshape(out, [1, NUM_QB, QUANT_BLOCK_SIZE]) + abs_3d = tl.abs(out_3d) + max_val = tl.max(abs_3d, axis=2, keep_dims=True) + dequant_scale = max_val / DTYPE_MAX + # ROUND_UP via exponent: 2 ** ceil(log2(dequant_scale)) + dequant_scale_exp = ( + dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF + ) & 0x7F800000 + dequant_scale_rounded = dequant_scale_exp.to(tl.float32, bitcast=True) + quant_scale = tl.where( + dequant_scale_rounded == 0, 0.0, 1.0 / dequant_scale_rounded + ) + quant_tensor = out_3d * quant_scale + quant_2d = tl.reshape(quant_tensor, [1, BLOCK_SIZE_N]) + out_q = tl.ravel(quant_2d) + scale_exp = (dequant_scale_exp >> 23).to(tl.uint8) + scale_exp_2d = tl.reshape(scale_exp, [1, NUM_QB]) + block_scales = tl.ravel(scale_exp_2d) + else: + out_q, block_scales = _fp8_quant_op( + out, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN + ) + out_q = tl.ravel(out_q) + block_scales = tl.ravel(block_scales) tl.store( out_ptr + m_pid * out_stride_m + n_offs * out_stride_n, @@ -108,8 +133,8 @@ def _fused_clamp_silu_mul_kernel( ) num_bs = tl.cdiv(n_half, QUANT_BLOCK_SIZE) - NUM_QB: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE - g_offs = tl.arange(0, NUM_QB) + NUM_QB_S: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE + g_offs = tl.arange(0, NUM_QB_S) tl.store( scale_ptr + m_pid * scale_stride_m + g_offs * scale_stride_n, block_scales.to(scale_ptr.dtype.element_ty), diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_mxfp8.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py similarity index 90% rename from aiter/ops/triton/_triton_kernels/gemm/basic/gemm_mxfp8.py rename to aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py index beaf16cb23..b657c27c9b 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_mxfp8.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py @@ -5,10 +5,10 @@ import triton.language as tl from 
aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd +from aiter.ops.triton.utils.gemm_config_utils import get_gemm_config - -_gemm_mxfp8_repr = make_kernel_repr( - "_gemm_mxfp8_kernel", +_gemm_afp8wfp8_repr = make_kernel_repr( + "_gemm_afp8wfp8_kernel", [ "BLOCK_SIZE_M", "BLOCK_SIZE_N", @@ -29,8 +29,8 @@ "EVEN_K": lambda args: (args["K"] % args["BLOCK_SIZE_K"] == 0), } ) -@triton.jit(repr=_gemm_mxfp8_repr) -def _gemm_mxfp8_kernel( +@triton.jit(repr=_gemm_afp8wfp8_repr) +def _gemm_afp8wfp8_kernel( a_ptr, b_ptr, c_ptr, @@ -43,6 +43,7 @@ def _gemm_mxfp8_kernel( stride_ak, stride_bk, stride_bn, + stride_ck, stride_cm, stride_cn, stride_asm, @@ -110,12 +111,8 @@ def _gemm_mxfp8_kernel( offs_k = tl.arange(0, BLOCK_SIZE_K) offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - a_ptrs = a_ptr + ( - offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - ) - b_ptrs = b_ptr + ( - offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn - ) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) # A-scale pointers: per-row (M) and per scale group (K // 32) offs_ks_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) @@ -180,9 +177,7 @@ def _gemm_mxfp8_kernel( a = tl.load(a_ptrs) b = tl.load(b_ptrs, cache_modifier=cache_modifier) else: - a = tl.load( - a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0 - ) + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0) b = tl.load( b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, @@ -209,8 +204,8 @@ def _gemm_mxfp8_kernel( tl.store(c_ptrs, c, mask=c_mask) -_gemm_mxfp8_preshuffle_repr = make_kernel_repr( - "_gemm_mxfp8_preshuffle_kernel", +_gemm_afp8wfp8_preshuffle_repr = make_kernel_repr( + 
"_gemm_afp8wfp8_preshuffle_kernel", [ "BLOCK_SIZE_M", "BLOCK_SIZE_N", @@ -231,8 +226,8 @@ def _gemm_mxfp8_kernel( "EVEN_K": lambda args: (args["K"] % args["BLOCK_SIZE_K"] == 0), } ) -@triton.jit(repr=_gemm_mxfp8_preshuffle_repr) -def _gemm_mxfp8_preshuffle_kernel( +@triton.jit(repr=_gemm_afp8wfp8_preshuffle_repr) +def _gemm_afp8wfp8_preshuffle_kernel( a_ptr, b_ptr, c_ptr, @@ -245,6 +240,7 @@ def _gemm_mxfp8_preshuffle_kernel( stride_ak, stride_bn, stride_bk, + stride_ck, stride_cm, stride_cn, stride_asm, @@ -265,7 +261,7 @@ def _gemm_mxfp8_preshuffle_kernel( cache_modifier: tl.constexpr, ): """ - Preshuffle variant of _gemm_mxfp8_kernel. Weight tensor has been shuffled + Preshuffle variant of _gemm_afp8wfp8_kernel. Weight tensor has been shuffled via aiter.ops.shuffle.shuffle_weight(layout=(16, 16)) so that 16-row N tiles are interleaved with their 32-col K chunks in storage. The kernel loads the shuffled tile in storage order (BLOCK_SIZE_N // 16, BLOCK_SIZE_K * 16) then @@ -345,9 +341,7 @@ def _gemm_mxfp8_preshuffle_kernel( a_scales = tl.load(a_scale_ptrs, mask=a_scale_mask, other=127) # Load and broadcast B scales (same as non-preshuffle path). 
- offs_bsk = ( - k_base + offs_scale_k_a * SCALE_GROUP_SIZE - ) // B_SCALE_K_GROUP + offs_bsk = (k_base + offs_scale_k_a * SCALE_GROUP_SIZE) // B_SCALE_K_GROUP b_scale_ptrs = ( b_scales_ptr + offs_bsn[:, None] * stride_bsn @@ -371,9 +365,7 @@ def _gemm_mxfp8_preshuffle_kernel( a = tl.load(a_ptrs) b_shuf = tl.load(b_ptrs, cache_modifier=cache_modifier) else: - a = tl.load( - a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0 - ) + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0) b_shuf = tl.load( b_ptrs, mask=offs_k_shuffle[None, :] < (K - k * BLOCK_SIZE_K) * 16, @@ -415,3 +407,18 @@ def _gemm_mxfp8_preshuffle_kernel( c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) + + +def _get_config( + M: int, + N: int, + K: int, + shuffle: bool = False, +): + """Load the best tuned config for (M, N, K) via the standard aiter JSON + config mechanism. 
Falls back to the generic fallback file + when no shape-specific JSON is available.""" + if shuffle: + return get_gemm_config("GEMM-AFP8WFP8-PRESHUFFLE", M, N, K) + else: + return get_gemm_config("GEMM-AFP8WFP8", M, N, K) diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=1536-K=4096.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=1536-K=4096.json similarity index 100% rename from aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=1536-K=4096.json rename to aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=1536-K=4096.json diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=4096-K=1024.json similarity index 100% rename from aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=1024.json rename to aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=4096-K=1024.json diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=256.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=4096-K=256.json similarity index 100% rename from aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=4096-K=256.json rename to aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=4096-K=256.json diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=512-K=4096.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=512-K=4096.json similarity index 100% rename from aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=512-K=4096.json rename to aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=512-K=4096.json diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=8192-K=1024.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=8192-K=1024.json similarity index 100% rename from
aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE-N=8192-K=1024.json rename to aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=8192-K=1024.json diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE.json similarity index 100% rename from aiter/ops/triton/configs/gemm/gfx950-GEMM-MXFP8-PRESHUFFLE.json rename to aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE.json diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8.json new file mode 100644 index 0000000000..c7271acf94 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/fusions/fused_clamp_act_mul.py b/aiter/ops/triton/fusions/fused_clamp_act_mul.py index 0b1b8a1a63..a5baa533e7 100644 --- a/aiter/ops/triton/fusions/fused_clamp_act_mul.py +++ b/aiter/ops/triton/fusions/fused_clamp_act_mul.py @@ -25,6 +25,9 @@ def fused_clamp_act_mul( weights: Optional[torch.Tensor] = None, dtype_quant: torch.dtype | None = None, transpose_scale: bool = False, + quant_block_size: int = 128, + scale_dtype_fmt: Literal["fp32", "ue8m0"] = "fp32", + shuffle_scale: bool = False, ): """ Fused clamp (SwiGLU-style) + act(gate) * up + optional weights, with optional FP8 group quant. @@ -54,6 +57,25 @@ def fused_clamp_act_mul( HAS_QUANT = dtype_quant is not None + # Step 5 ue8m0 mode: per-1×32 group quant, uint8 scale. 
+ assert scale_dtype_fmt in ("fp32", "ue8m0") + if scale_dtype_fmt == "ue8m0": + assert HAS_QUANT, "scale_dtype_fmt='ue8m0' requires dtype_quant" + assert ( + quant_block_size == 32 + ), f"ue8m0 requires quant_block_size=32 got {quant_block_size}" + assert dtype_quant in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ), f"ue8m0 requires fp8 e4m3, got {dtype_quant}" + if shuffle_scale and transpose_scale: + raise ValueError("shuffle_scale incompatible with transpose_scale") + _scale_storage_dtype = torch.uint8 + else: + if shuffle_scale: + raise ValueError("shuffle_scale only valid with scale_dtype_fmt='ue8m0'") + _scale_storage_dtype = torch.float32 + if HAS_QUANT: if out is None: out = torch.empty((M, n_half), dtype=dtype_quant, device=inp.device) @@ -65,15 +87,15 @@ def fused_clamp_act_mul( dtype_quant, out.dtype, ) - num_blocks = (n_half + 127) // 128 + num_blocks = (n_half + quant_block_size - 1) // quant_block_size if scale is None: if transpose_scale: scale = torch.empty( - (num_blocks, M), dtype=torch.float32, device=inp.device + (num_blocks, M), dtype=_scale_storage_dtype, device=inp.device ) else: scale = torch.empty( - (M, num_blocks), dtype=torch.float32, device=inp.device + (M, num_blocks), dtype=_scale_storage_dtype, device=inp.device ) else: if transpose_scale: @@ -153,7 +175,8 @@ def fused_clamp_act_mul( weights.stride(1) if HAVE_WEIGHTS else 0, swiglu_limit, BLOCK_SIZE_N=BLOCK_SIZE_N, - QUANT_BLOCK_SIZE=128, + QUANT_BLOCK_SIZE=quant_block_size, + SCALE_FMT=scale_dtype_fmt, DTYPE_MAX=DTYPE_MAX, DTYPE_MIN=-DTYPE_MAX, HAVE_WEIGHTS=HAVE_WEIGHTS, @@ -167,5 +190,9 @@ def fused_clamp_act_mul( if HAS_QUANT: if transpose_scale: scale = scale.view(M, num_bs_cols) + if shuffle_scale: + from aiter.utility import fp4_utils + + scale = fp4_utils.e8m0_shuffle(scale) return out, scale return out diff --git a/aiter/ops/triton/gemm/basic/gemm_mxfp8.py b/aiter/ops/triton/gemm/basic/gemm_afp8wfp8.py similarity index 53% rename from 
aiter/ops/triton/gemm/basic/gemm_mxfp8.py rename to aiter/ops/triton/gemm/basic/gemm_afp8wfp8.py index 56d4d2f44c..e4d4ae9b0f 100644 --- a/aiter/ops/triton/gemm/basic/gemm_mxfp8.py +++ b/aiter/ops/triton/gemm/basic/gemm_afp8wfp8.py @@ -6,33 +6,17 @@ import torch import triton -from aiter.ops.triton._triton_kernels.gemm.basic.gemm_mxfp8 import ( - _gemm_mxfp8_kernel, - _gemm_mxfp8_preshuffle_kernel, +from aiter.ops.triton._triton_kernels.gemm.basic.gemm_afp8wfp8 import ( + _gemm_afp8wfp8_kernel, + _gemm_afp8wfp8_preshuffle_kernel, + _get_config, +) +from aiter.ops.triton._triton_kernels.gemm.basic.gemm_a8w8_blockscale import ( + _gemm_a8w8_blockscale_reduce_kernel, ) -from aiter.ops.triton.utils.gemm_config_utils import get_gemm_config - - -_DEFAULT_CONFIG = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - "NUM_KSPLIT": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "cache_modifier": "", -} - - -def _get_default_config() -> dict: - return dict(_DEFAULT_CONFIG) - # ----------------------------------------------------------------------------- -# Tuned-config lookup for gemm_mxfp8_preshuffle. +# Tuned-config lookup for gemm_afp8wfp8_preshuffle. # # Configs live under aiter.ops.triton.configs.gemm using the standard aiter # naming: gfx{arch}-GEMM-MXFP8-PRESHUFFLE-N={N}-K={K}.json, keyed by @@ -52,25 +36,7 @@ def _mark_oom(M: int, N: int, K: int): _OOM_SHAPES.add((M, N, K)) -def _get_config(M: int, N: int, K: int) -> dict: - """Load the best tuned config for (M, N, K) via the standard aiter JSON - config mechanism. Falls back to the generic fallback file or to - _DEFAULT_CONFIG if no JSON is available.""" - try: - config, _ = get_gemm_config("GEMM-MXFP8-PRESHUFFLE", M, N, K) - except (AssertionError, KeyError, FileNotFoundError): - config = _get_default_config() - # Always pin NUM_KSPLIT=1 (this version of the kernel doesn't split K). 
- config["NUM_KSPLIT"] = 1 - # The JSON uses cache_modifier=null which decodes to Python None; the - # kernel accepts None or "", but we keep the empty-string default from - # _DEFAULT_CONFIG to match the legacy code path. - if config.get("cache_modifier") is None: - config["cache_modifier"] = "" - return config - - -def gemm_mxfp8( +def gemm_afp8wfp8( x: torch.Tensor, w: torch.Tensor, x_scales: torch.Tensor, @@ -78,6 +44,7 @@ def gemm_mxfp8( dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[dict] = None, + skip_reduce: Optional[bool] = False, ) -> torch.Tensor: """ Computes matrix multiplication Y = X @ W^T with MXFP8 activations and FP8 @@ -111,27 +78,35 @@ def gemm_mxfp8( w_t = w_t.view(torch.uint8) if config is None: - config = _get_default_config() - else: - # Merge with defaults so missing keys are filled. - merged = _get_default_config() - merged.update(config) - config = merged - - # First-version: NUM_KSPLIT must be 1 - config["NUM_KSPLIT"] = 1 + config, _ = _get_config(M, N, K) - if y is None: + if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): y = torch.empty((M, N), dtype=dtype, device=x.device) + config["SPLITK_BLOCK_SIZE"] = triton.cdiv( + K, config["NUM_KSPLIT"] + ) # How big each split_k partition is + if config["NUM_KSPLIT"] > 1: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), + dtype=torch.float32, + device=x.device, + ) + else: + y_pp = None + grid = lambda META: ( # noqa: E731 - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ( + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]) + ), ) - _gemm_mxfp8_kernel[grid]( + _gemm_afp8wfp8_kernel[grid]( x, w_t, - y, + y if config["NUM_KSPLIT"] == 1 else y_pp, x_scales, w_scales, M, @@ -141,8 +116,9 @@ def gemm_mxfp8( x.stride(1), w_t.stride(0), w_t.stride(1), - y.stride(0), - y.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + y.stride(0) 
if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), x_scales.stride(0), x_scales.stride(1), w_scales.stride(0), @@ -150,10 +126,38 @@ def gemm_mxfp8( **config, ) + if config["NUM_KSPLIT"] > 1: + if skip_reduce: + return y_pp + + REDUCE_BLOCK_SIZE_M = 32 + REDUCE_BLOCK_SIZE_N = 32 + ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + triton.cdiv(N, REDUCE_BLOCK_SIZE_N), + ) + _gemm_a8w8_blockscale_reduce_kernel[grid_reduce]( + y_pp, + y, + M, + N, + y_pp.stride(0), + y_pp.stride(1), + y_pp.stride(2), + y.stride(0), + y.stride(1), + REDUCE_BLOCK_SIZE_M, + REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT, + triton.next_power_of_2(config["NUM_KSPLIT"]), + ) + return y -def gemm_mxfp8_preshuffle( +def gemm_afp8wfp8_preshuffle( x: torch.Tensor, w_shuffled: torch.Tensor, x_scales: torch.Tensor, @@ -161,9 +165,10 @@ def gemm_mxfp8_preshuffle( dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[dict] = None, + skip_reduce: Optional[bool] = False, ) -> torch.Tensor: """ - Preshuffle variant of gemm_mxfp8. The weight tensor has already been + Preshuffle variant of gemm_afp8wfp8. The weight tensor has already been permuted via aiter.ops.shuffle.shuffle_weight(..., layout=(16, 16)). Scales are left unshuffled in the compact 128x128 layout. @@ -194,63 +199,79 @@ def gemm_mxfp8_preshuffle( w_view = w_view.view(torch.uint8) if config is None: - if (M, N, K) in _OOM_SHAPES: - # We've previously OOM'd here under HIP-graph capture; skip the - # tuned lookup and use the conservative default. 
- config = _get_default_config() - else: - config = _get_config(M, N, K) - else: - merged = _get_default_config() - merged.update(config) - config = merged + config, _ = _get_config(M, N, K, shuffle=True) - config["NUM_KSPLIT"] = 1 - - if y is None: + if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): y = torch.empty((M, N), dtype=dtype, device=x.device) + config["SPLITK_BLOCK_SIZE"] = triton.cdiv( + K, config["NUM_KSPLIT"] + ) # How big each split_k partition is + if config["NUM_KSPLIT"] > 1: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), + dtype=torch.float32, + device=x.device, + ) + else: + y_pp = None + grid = lambda META: ( # noqa: E731 - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ( + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]) + ), + ) + _gemm_afp8wfp8_preshuffle_kernel[grid]( + x, + w_view, + y if config["NUM_KSPLIT"] == 1 else y_pp, + x_scales, + w_scales, + M, + N, + K, + x.stride(0), + x.stride(1), + w_view.stride(0), + w_view.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), + x_scales.stride(0), + x_scales.stride(1), + w_scales.stride(0), + w_scales.stride(1), + **config, ) - def _launch(cfg): - _gemm_mxfp8_preshuffle_kernel[grid]( - x, - w_view, + if config["NUM_KSPLIT"] > 1: + if skip_reduce: + return y_pp + + REDUCE_BLOCK_SIZE_M = 32 + REDUCE_BLOCK_SIZE_N = 32 + ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + triton.cdiv(N, REDUCE_BLOCK_SIZE_N), + ) + _gemm_a8w8_blockscale_reduce_kernel[grid_reduce]( + y_pp, y, - x_scales, - w_scales, M, N, - K, - x.stride(0), - x.stride(1), - w_view.stride(0), - w_view.stride(1), + y_pp.stride(0), + y_pp.stride(1), + y_pp.stride(2), y.stride(0), y.stride(1), - x_scales.stride(0), - 
x_scales.stride(1), - w_scales.stride(0), - w_scales.stride(1), - **cfg, + REDUCE_BLOCK_SIZE_M, + REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT, + triton.next_power_of_2(config["NUM_KSPLIT"]), ) - try: - _launch(config) - except Exception as e: - # The tuned config may overflow gfx950 LDS (163840 bytes) under HIP - # graph capture even when the standalone bench succeeds. Fall back - # to default and cache the failure to avoid retrying on every call. - if ( - "OutOfResources" in type(e).__name__ - or "shared memory" in str(e) - or "Required" in str(e) - ): - _mark_oom(M, N, K) - _launch(_get_default_config() | {"NUM_KSPLIT": 1}) - else: - raise - return y diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py b/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py new file mode 100644 index 0000000000..a1b5c1a20a --- /dev/null +++ b/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import pytest +import torch + +from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( + gemm_afp8wfp8, + gemm_afp8wfp8_preshuffle, +) +from aiter.ops.shuffle import shuffle_weight +import aiter.ops.triton.utils._triton.arch_info as arch_info + +SCALE_GROUP_SIZE = 32 # A: 1x32 e8m0 scale group +W_SCALE_K_GROUP = 128 # B: 128 in K direction +W_SCALE_N_GROUP = 128 # B: 128 in N direction +FP8_MAX = 448.0 # e4m3 max + + +def e8m0_to_f32(x: torch.Tensor) -> torch.Tensor: + """Decode unsigned-biased e8m0 (uint8) to fp32. Bias 127, value = 2^(b-127).""" + return torch.exp2((x.to(torch.int32) - 127).to(torch.float32)) + + +def generate_inputs(M: int, N: int, K: int, seed: int = 0): + torch.manual_seed(seed) + # Small random fp32 → fp8 e4m3fn, kept inside e4m3 range so the cast is exact-ish. 
+ x_f32 = torch.randn((M, K), dtype=torch.float32, device="cuda") + w_f32 = torch.randn((N, K), dtype=torch.float32, device="cuda") + x_f32 = torch.clamp(x_f32, -FP8_MAX, FP8_MAX) + w_f32 = torch.clamp(w_f32, -FP8_MAX, FP8_MAX) + x_fp8 = x_f32.to(torch.float8_e4m3fn) + w_fp8 = w_f32.to(torch.float8_e4m3fn) + + # e8m0 scales near 127 (== 1.0) so the dequant has unit-ish magnitude. + x_scales = torch.randint( + 125, 130, (M, K // SCALE_GROUP_SIZE), dtype=torch.uint8, device="cuda" + ) + w_scales = torch.randint( + 125, + 130, + (N // W_SCALE_N_GROUP, K // W_SCALE_K_GROUP), + dtype=torch.uint8, + device="cuda", + ) + return x_fp8, w_fp8, x_scales, w_scales + + +def run_torch_gemm_afp8wfp8( + x_fp8: torch.Tensor, + w_fp8: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: + """Reference: dequant both operands to fp32 and run torch.mm.""" + M, K = x_fp8.shape + N, _ = w_fp8.shape + + x_view = x_fp8 if x_fp8.dtype != torch.uint8 else x_fp8.view(torch.float8_e4m3fn) + w_view = w_fp8 if w_fp8.dtype != torch.uint8 else w_fp8.view(torch.float8_e4m3fn) + x_f32 = x_view.to(torch.float32) + w_f32 = w_view.to(torch.float32) + + x_s_f32 = e8m0_to_f32(x_scales).repeat_interleave(SCALE_GROUP_SIZE, dim=1) + assert x_s_f32.shape == (M, K) + + w_s_f32 = e8m0_to_f32(w_scales) + w_s_f32 = w_s_f32.repeat_interleave(W_SCALE_N_GROUP, dim=0).repeat_interleave( + W_SCALE_K_GROUP, dim=1 + ) + assert w_s_f32.shape == (N, K) + + x_dq = x_f32 * x_s_f32 + w_dq = w_f32 * w_s_f32 + return torch.mm(x_dq, w_dq.T).to(out_dtype) + + +def get_shapes(): + # (M, N, K), with N % 128 == 0 and K % 128 == 0 to fit the 128x128 W-scale layout. 
+ return [ + (m, n, k) + for m in [1, 4, 8, 16, 32, 64, 128] + for n, k in [ + (1536, 4096), + (4096, 1024), + (512, 4096), + (8192, 1024), + (2048, 7168), + (7168, 2048), + (768, 7168), + (7168, 384), + (8192, 1536), + ] + ] + + +@pytest.mark.parametrize("M, N, K", get_shapes()) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_gemm_afp8wfp8(M: int, N: int, K: int, dtype: torch.dtype): + if not arch_info.is_fp8_avail(): + pytest.skip("MXFP8 GEMM requires FP8-capable arch") + torch.cuda.empty_cache() + + x_fp8, w_fp8, x_scales, w_scales = generate_inputs(M, N, K) + + torch_out = run_torch_gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype) + triton_out = gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype=dtype) + + torch.testing.assert_close(triton_out, torch_out, atol=0.5, rtol=5e-2) + + +@pytest.mark.parametrize("M, N, K", get_shapes()) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_gemm_afp8wfp8_preshuffle(M: int, N: int, K: int, dtype: torch.dtype): + if not arch_info.is_fp8_avail(): + pytest.skip("MXFP8 GEMM requires FP8-capable arch") + if N % 16 != 0 or K % 32 != 0: + pytest.skip("Preshuffle requires N % 16 == 0 and K % 32 == 0") + torch.cuda.empty_cache() + + x_fp8, w_fp8, x_scales, w_scales = generate_inputs(M, N, K) + + # Shuffle the weight tensor in place (layout=(16,16)). The shuffler operates on + # raw bytes; view as uint8 to avoid any dtype quirks, then view back. 
+ w_uint8 = w_fp8.view(torch.uint8) + w_shuffled = shuffle_weight(w_uint8, layout=(16, 16)) + + torch_out = run_torch_gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype) + triton_out = gemm_afp8wfp8_preshuffle( + x_fp8, w_shuffled, x_scales, w_scales, dtype=dtype + ) + + torch.testing.assert_close(triton_out, torch_out, atol=0.5, rtol=5e-2) + + +@pytest.mark.parametrize("M, N, K", [(128, 128, 128), (256, 256, 256)]) +def test_gemm_afp8wfp8_preallocated_output(M: int, N: int, K: int): + if not arch_info.is_fp8_avail(): + pytest.skip("MXFP8 GEMM requires FP8-capable arch") + torch.cuda.empty_cache() + + dtype = torch.bfloat16 + x_fp8, w_fp8, x_scales, w_scales = generate_inputs(M, N, K) + y = torch.empty((M, N), dtype=dtype, device="cuda") + out = gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype=dtype, y=y) + assert out.data_ptr() == y.data_ptr() + + torch_out = run_torch_gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype) + torch.testing.assert_close(out, torch_out, atol=0.5, rtol=5e-2) diff --git a/op_tests/triton_tests/quant/test_quant_mxfp8.py b/op_tests/triton_tests/quant/test_quant_mxfp8.py new file mode 100644 index 0000000000..9433eefa87 --- /dev/null +++ b/op_tests/triton_tests/quant/test_quant_mxfp8.py @@ -0,0 +1,362 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import pytest +import torch + +from aiter.ops.triton.quant.quant_mxfp8 import ( + per_1x32_mxfp8_quant_triton, + fp8_legacy_to_mxfp8, + rmsnorm_mxfp8_quant, + dual_rmsnorm_mxfp8_quant, +) +import aiter.ops.triton.utils._triton.arch_info as arch_info + +QUANT_BLOCK_SIZE = 32 +LEGACY_BLOCK_SIZE = 128 +# 0xFF800000 in two's complement int32. Mask keeps sign + 8-bit exponent; clears the 23-bit mantissa. +_E8M0_MASK_INT32 = -8388608 + + +def torch_mxfp8_quant_from_fp32(x_fp32: torch.Tensor): + """Bit-faithful port of `_mxfp8_quant_kernel` quant logic, taking fp32 input. + + Computes per-1x32 e8m0 scale (uint8) and FP8 e4m3fn values. 
+ """ + assert x_fp32.dim() == 2, f"x_fp32 must be 2D, got {x_fp32.dim()}" + M, K = x_fp32.shape + assert K % QUANT_BLOCK_SIZE == 0 + Ng = K // QUANT_BLOCK_SIZE + x_2d = x_fp32.reshape(M, Ng, QUANT_BLOCK_SIZE).to(torch.float32) + amax = torch.amax(torch.abs(x_2d), dim=-1, keepdim=True) # (M, Ng, 1) + + # Same bit-level "round up to e8m0-representable pow-2" as the kernel. + amax_i32 = amax.contiguous().view(torch.int32) + amax_i32 = (amax_i32 + 0x200000) & _E8M0_MASK_INT32 + amax_p2 = amax_i32.view(torch.float32) + + scale_unbiased = torch.log2(amax_p2).floor() - 8 + scale_unbiased = torch.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(torch.int32) + 127).to(torch.uint8) + quant_scale = torch.exp2(-scale_unbiased) + + qx_2d = x_2d * quant_scale # broadcast over inner-32 + qx = qx_2d.reshape(M, K) + y_fp8 = qx.to(torch.float8_e4m3fn) + s = scale_e8m0.reshape(M, Ng) + return y_fp8, s + + +def e8m0_to_f32(x: torch.Tensor) -> torch.Tensor: + return torch.exp2((x.to(torch.int32) - 127).to(torch.float32)) + + +# ----------------------------------------------------------------------------- +# per_1x32_mxfp8_quant_triton +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "M, K", + [ + (1, 32), + (1, 64), + (1, 128), + (2, 32), + (8, 64), + (16, 128), + (32, 256), + (64, 512), + (128, 1024), + (137, 64), # non-power-of-2 M + (256, 32), + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_per_1x32_mxfp8_quant(M: int, K: int, dtype: torch.dtype): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(20) + + x = torch.randn((M, K), dtype=dtype, device="cuda") * 4.0 + + # Reference path: emulate the kernel in fp32 (matching its precision). + x_fp32 = x.to(torch.float32) + y_ref, s_ref = torch_mxfp8_quant_from_fp32(x_fp32) + + # Triton path. 
+ y_kern, s_kern = per_1x32_mxfp8_quant_triton(x) + + # Scales must be bit-exact: the e8m0 derivation is integer-only after + # the fp32 cast, and amax is order-independent. + torch.testing.assert_close(s_kern, s_ref) + + # Quantized values: compare via the uint8 view (allow off-by-1 for any + # rounding-mode subtlety in the fp32→fp8 cast). + torch.testing.assert_close( + y_kern.view(torch.uint8).to(torch.int32), + y_ref.view(torch.uint8).to(torch.int32), + atol=1, + rtol=0, + ) + + +def test_per_1x32_mxfp8_quant_preallocated_scale(): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(20) + + M, K = 64, 256 + x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda") + scale_pre = torch.empty( + (M, K // QUANT_BLOCK_SIZE), dtype=torch.uint8, device="cuda" + ) + y, s = per_1x32_mxfp8_quant_triton(x, scale=scale_pre) + assert s.data_ptr() == scale_pre.data_ptr() + + y_ref, s_ref = torch_mxfp8_quant_from_fp32(x.to(torch.float32)) + torch.testing.assert_close(s, s_ref) + + +def test_per_1x32_mxfp8_quant_multidim(): + """Wrapper folds higher dims into M; sanity-check 3D input.""" + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(0) + + B, M, K = 4, 8, 128 + x = torch.randn((B, M, K), dtype=torch.bfloat16, device="cuda") + y, s = per_1x32_mxfp8_quant_triton(x) + assert y.shape == (B, M, K) + assert s.shape == (B, M, K // QUANT_BLOCK_SIZE) + + y_ref, s_ref = torch_mxfp8_quant_from_fp32(x.reshape(-1, K).to(torch.float32)) + torch.testing.assert_close(s.reshape(-1, K // QUANT_BLOCK_SIZE), s_ref) + + +# ----------------------------------------------------------------------------- +# fp8_legacy_to_mxfp8 +# ----------------------------------------------------------------------------- + + +def torch_fp8_legacy_to_mxfp8(x_fnuz: torch.Tensor, x_scale_fp32: torch.Tensor): + """Reference: dequantize fnuz fp8 with the 1x128 
fp32 scale, then run + the standard mxfp8 1x32 quant on the result.""" + M, N = x_fnuz.shape + x_dq = x_fnuz.to(torch.float32) * x_scale_fp32.repeat_interleave( + LEGACY_BLOCK_SIZE, dim=1 + ) + return torch_mxfp8_quant_from_fp32(x_dq) + + +@pytest.mark.parametrize( + "M, N", + [ + (1, 128), + (8, 128), + (16, 256), + (32, 512), + (64, 1024), + (128, 256), + (37, 256), # non-pow-2 M + ], +) +def test_fp8_legacy_to_mxfp8(M: int, N: int): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(5) + + # Random values within e4m3fnuz range, then cast to fnuz fp8. + x_f32 = (torch.randn((M, N), dtype=torch.float32, device="cuda")).clamp(-200, 200) + x_fnuz = x_f32.to(torch.float8_e4m3fnuz) + # Random fp32 1x128 scales in a moderate range so the dequant stays within fp8. + x_scale_fp32 = ( + torch.rand((M, N // LEGACY_BLOCK_SIZE), dtype=torch.float32, device="cuda") + * 0.5 + + 0.25 + ) + + y_ref, s_ref = torch_fp8_legacy_to_mxfp8(x_fnuz, x_scale_fp32) + y_kern, s_kern = fp8_legacy_to_mxfp8(x_fnuz, x_scale_fp32) + + torch.testing.assert_close(s_kern, s_ref) + torch.testing.assert_close( + y_kern.view(torch.uint8).to(torch.int32), + y_ref.view(torch.uint8).to(torch.int32), + atol=1, + rtol=0, + ) + + +def test_fp8_legacy_to_mxfp8_preallocated(): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(5) + + M, N = 16, 256 + x_fnuz = (torch.randn((M, N), device="cuda") * 4).to(torch.float8_e4m3fnuz) + x_scale_fp32 = torch.rand((M, N // LEGACY_BLOCK_SIZE), device="cuda") * 0.5 + 0.25 + y_pre = torch.empty((M, N), dtype=torch.float8_e4m3fn, device="cuda") + s_pre = torch.empty((M, N // QUANT_BLOCK_SIZE), dtype=torch.uint8, device="cuda") + y, s = fp8_legacy_to_mxfp8(x_fnuz, x_scale_fp32, y_fn=y_pre, y_scale=s_pre) + assert y.data_ptr() == y_pre.data_ptr() + assert s.data_ptr() == s_pre.data_ptr() + + y_ref, s_ref = 
torch_fp8_legacy_to_mxfp8(x_fnuz, x_scale_fp32) + torch.testing.assert_close(s, s_ref) + + +# ----------------------------------------------------------------------------- +# rmsnorm_mxfp8_quant +# ----------------------------------------------------------------------------- + + +def torch_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: + x_f32 = x.to(torch.float32) + g_f32 = weight.to(torch.float32) + rstd = torch.rsqrt(x_f32.pow(2).mean(-1, keepdim=True) + eps) + return x_f32 * rstd * g_f32 + + +def torch_rmsnorm_mxfp8_quant(x, weight, eps): + y_fp32 = torch_rmsnorm(x, weight, eps) + return torch_mxfp8_quant_from_fp32(y_fp32) + + +@pytest.mark.parametrize( + "M, K", + [ + (1, 32), + (1, 128), + (8, 128), + (16, 256), + (32, 512), + (64, 1024), + (128, 2048), + (97, 64), # non-pow-2 M, K=64 + (200, 192), # non-pow-2 K (still multiple of 32) + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_rmsnorm_mxfp8_quant(M: int, K: int, dtype: torch.dtype): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(11) + + x = torch.randn((M, K), dtype=dtype, device="cuda") + weight = torch.randn((K,), dtype=dtype, device="cuda") * 0.5 + 1.0 + eps = 1e-5 + + y_ref, s_ref = torch_rmsnorm_mxfp8_quant(x, weight, eps) + y_kern, s_kern = rmsnorm_mxfp8_quant(x, weight, eps) + + # Hardware rsqrt vs torch.rsqrt can disagree by a ULP; that may flip a single + # e8m0 bin near a power-of-2 boundary. Compare dequantized values instead. 
+ s_ref_f32 = e8m0_to_f32(s_ref).repeat_interleave(QUANT_BLOCK_SIZE, dim=1) + s_kern_f32 = e8m0_to_f32(s_kern).repeat_interleave(QUANT_BLOCK_SIZE, dim=1) + y_ref_dq = y_ref.to(torch.float32) * s_ref_f32 + y_kern_dq = y_kern.to(torch.float32) * s_kern_f32 + + torch.testing.assert_close(y_kern_dq, y_ref_dq, atol=5e-2, rtol=5e-2) + + +def test_rmsnorm_mxfp8_quant_preallocated(): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(11) + + M, K = 32, 256 + x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda") + weight = torch.randn((K,), dtype=torch.bfloat16, device="cuda") + y_pre = torch.empty((M, K), dtype=torch.float8_e4m3fn, device="cuda") + s_pre = torch.empty((M, K // QUANT_BLOCK_SIZE), dtype=torch.uint8, device="cuda") + y, s = rmsnorm_mxfp8_quant(x, weight, 1e-5, y=y_pre, scale=s_pre) + assert y.data_ptr() == y_pre.data_ptr() + assert s.data_ptr() == s_pre.data_ptr() + + +# ----------------------------------------------------------------------------- +# dual_rmsnorm_mxfp8_quant +# ----------------------------------------------------------------------------- + + +def torch_dual_rmsnorm_mxfp8_quant(q, k, q_weight, k_weight, eps_q, eps_k): + yq_fp32 = torch_rmsnorm(q, q_weight, eps_q) + yq, sq = torch_mxfp8_quant_from_fp32(yq_fp32) + yk_fp32 = torch_rmsnorm(k, k_weight, eps_k) + yk = yk_fp32.to(k.dtype) + return yq, sq, yk + + +@pytest.mark.parametrize( + "M, KQ, KK", + [ + (1, 32, 32), + (1, 128, 64), + (8, 256, 128), + (16, 512, 256), + (32, 1024, 512), + (64, 2048, 1024), + (47, 96, 80), # non-pow-2 sizes + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_dual_rmsnorm_mxfp8_quant(M: int, KQ: int, KK: int, dtype: torch.dtype): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(13) + + q = torch.randn((M, KQ), dtype=dtype, device="cuda") + k = torch.randn((M, KK), 
dtype=dtype, device="cuda") + q_weight = torch.randn((KQ,), dtype=dtype, device="cuda") * 0.5 + 1.0 + k_weight = torch.randn((KK,), dtype=dtype, device="cuda") * 0.5 + 1.0 + eps_q, eps_k = 1e-5, 2e-5 + + yq_ref, sq_ref, yk_ref = torch_dual_rmsnorm_mxfp8_quant( + q, k, q_weight, k_weight, eps_q, eps_k + ) + yq_kern, sq_kern, yk_kern = dual_rmsnorm_mxfp8_quant( + q, k, q_weight, k_weight, eps_q, eps_k + ) + + # Q side: compare dequantized values (rsqrt jitter -> tolerate e8m0 ULP flips). + sq_ref_f32 = e8m0_to_f32(sq_ref).repeat_interleave(QUANT_BLOCK_SIZE, dim=1) + sq_kern_f32 = e8m0_to_f32(sq_kern).repeat_interleave(QUANT_BLOCK_SIZE, dim=1) + yq_ref_dq = yq_ref.to(torch.float32) * sq_ref_f32 + yq_kern_dq = yq_kern.to(torch.float32) * sq_kern_f32 + torch.testing.assert_close(yq_kern_dq, yq_ref_dq, atol=5e-2, rtol=5e-2) + + # K side: bf16/fp16 RMSNorm output. + torch.testing.assert_close(yk_kern, yk_ref, atol=5e-3, rtol=5e-3) + + +def test_dual_rmsnorm_mxfp8_quant_default_eps_k(): + """eps_k defaults to eps_q when not provided.""" + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(13) + + M, KQ, KK = 16, 128, 96 + dtype = torch.bfloat16 + q = torch.randn((M, KQ), dtype=dtype, device="cuda") + k = torch.randn((M, KK), dtype=dtype, device="cuda") + q_weight = torch.randn((KQ,), dtype=dtype, device="cuda") + k_weight = torch.randn((KK,), dtype=dtype, device="cuda") + eps = 1e-5 + + yq_a, sq_a, yk_a = dual_rmsnorm_mxfp8_quant(q, k, q_weight, k_weight, eps) + yq_b, sq_b, yk_b = dual_rmsnorm_mxfp8_quant( + q, k, q_weight, k_weight, eps, eps_k=eps + ) + torch.testing.assert_close(yq_a.view(torch.uint8), yq_b.view(torch.uint8)) + torch.testing.assert_close(sq_a, sq_b) + torch.testing.assert_close(yk_a, yk_b) From 9f2a6ed5f2d7ef08a20091b9b935b80dc2f7812d Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 19 May 2026 21:32:00 +0000 Subject: [PATCH 3/8] add splitk --- 
.../gemm/basic/gemm_afp8wfp8.py | 452 ++++++++++-------- .../tunning/ut_afp8wfp8_gemm_preshuffle.py | 44 ++ .../gemm/basic/test_gemm_afp8wfp8.py | 54 +-- 3 files changed, 318 insertions(+), 232 deletions(-) create mode 100644 aiter/ops/triton/utils/_triton/tunning/ut_afp8wfp8_gemm_preshuffle.py diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py index b657c27c9b..27b0c121ad 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py @@ -20,6 +20,7 @@ "matrix_instr_nonkdim", "cache_modifier", "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", ], ) @@ -56,6 +57,7 @@ def _gemm_afp8wfp8_kernel( BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, EVEN_K: tl.constexpr, num_warps: tl.constexpr, num_stages: tl.constexpr, @@ -71,7 +73,9 @@ def _gemm_afp8wfp8_kernel( representing 128x128 weight blocks. Broadcast inside kernel to (N, K // 32). A has shape (M, K), B has shape (K, N) and C has shape (M, N). Output dtype is determined by c_ptr (bf16 or fp16). - NUM_KSPLIT == 1 only (no split-K in this first version). + When NUM_KSPLIT > 1, K is split into NUM_KSPLIT partitions of + SPLITK_BLOCK_SIZE elements and the partial result for partition pid_k is + written to c_ptr + pid_k * stride_ck; a downstream reduce kernel sums them. 
""" tl.assume(stride_am > 0) @@ -88,120 +92,136 @@ def _gemm_afp8wfp8_kernel( GRID_MN = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) pid_unified = tl.program_id(axis=0) - pid_unified = remap_xcd(pid_unified, GRID_MN, NUM_XCDS=8) - - pid = pid_unified + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + if NUM_KSPLIT == 1: + pid = remap_xcd(pid, GRID_MN, NUM_XCDS=8) + pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) + tl.assume(pid_k >= 0) # Scale group sizes SCALE_GROUP_SIZE: tl.constexpr = 32 # A: per 32 elements along K B_SCALE_K_GROUP: tl.constexpr = 128 # B: per 128 along K B_SCALE_N_GROUP: tl.constexpr = 128 # B: per 128 along N - # Number of K iterations - num_k_iter = tl.cdiv(K, BLOCK_SIZE_K) - - # Create pointers for first block of A and B input matrices - offs_k = tl.arange(0, BLOCK_SIZE_K) - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - # A-scale pointers: per-row (M) and per scale group (K // 32) - offs_ks_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) - a_scale_ptrs = ( - a_scales_ptr + offs_am[:, None] * stride_asm + offs_ks_a[None, :] * stride_ask - ) - - # B-scale pointers: compact (N // 128, K // 128) — broadcast inside the kernel - # Each scale covers a 128(N) x 128(K) block. - # We want a tile of shape (BLOCK_SIZE_N, BLOCK_SIZE_K // 32) per K-iter. 
- # Broadcast by floor-division: row index n -> n // 128; k-scale index ks -> ks // 4 - # (since ks counts 32-element groups, and 128 / 32 = 4). - offs_bsn = offs_bn // B_SCALE_N_GROUP # (BLOCK_SIZE_N,) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - offs_scale_k_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) - - for k in range(0, num_k_iter): - # K base for this iteration (in elements) - k_base = k * BLOCK_SIZE_K - - # ---- Load A scales (M, BLOCK_SIZE_K // 32) ---- - if EVEN_K: - a_scales = tl.load(a_scale_ptrs) - else: - a_scale_mask = offs_scale_k_a[None, :] < ( - K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) - ) - a_scales = tl.load(a_scale_ptrs, mask=a_scale_mask, other=127) - - # ---- Load and broadcast B scales (BLOCK_SIZE_N, BLOCK_SIZE_K // 32) ---- - # B-scale K index in compact layout (K // 128 stride) - k_scale_idx_b = (k_base + tl.arange(0, BLOCK_SIZE_K) // SCALE_GROUP_SIZE) // ( - B_SCALE_K_GROUP // SCALE_GROUP_SIZE + if (pid_k * SPLITK_BLOCK_SIZE) < K: + # K-block iteration range for this split (absolute block indices). + num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE, BLOCK_SIZE_K) + + # Create pointers for first block of A and B input matrices. The K + # offset is the absolute start of this split's K range. + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k_split = pid_k * SPLITK_BLOCK_SIZE + offs_k + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak ) - # k_scale_idx_b: shape (BLOCK_SIZE_K,), but we want one per 32-group. 
- # Recompute correctly: - offs_bsk = ( - k_base + offs_scale_k_a * SCALE_GROUP_SIZE - ) // B_SCALE_K_GROUP # (BLOCK_SIZE_K // 32,) - b_scale_ptrs = ( - b_scales_ptr - + offs_bsn[:, None] * stride_bsn - + offs_bsk[None, :] * stride_bsk + b_ptrs = b_ptr + ( + offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn ) - if EVEN_K: - b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) - else: - # OOB along K: load with the same mask as a-scales - b_scale_mask = offs_scale_k_a[None, :] < ( - K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) - ) - b_scales = tl.load( - b_scale_ptrs, - mask=b_scale_mask, - other=127, - cache_modifier=cache_modifier, - ) - - # ---- Load A, B data ---- - if EVEN_K: - a = tl.load(a_ptrs) - b = tl.load(b_ptrs, cache_modifier=cache_modifier) - else: - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0) - b = tl.load( - b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0, - cache_modifier=cache_modifier, - ) - accumulator = tl.dot_scaled( - a, a_scales, "e4m3", b, b_scales, "e4m3", accumulator + # A-scale pointers: per-row (M) and per scale group (K // 32). Shift + # along the K-scale axis by the split's start in scale groups. + offs_ks_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + offs_ks_a_split = pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE) + offs_ks_a + a_scale_ptrs = ( + a_scales_ptr + + offs_am[:, None] * stride_asm + + offs_ks_a_split[None, :] * stride_ask ) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask - - c = accumulator.to(c_ptr.type.element_ty) + # B-scale pointers: compact (N // 128, K // 128) — broadcast inside the kernel + # Each scale covers a 128(N) x 128(K) block. Computed per-iteration below + # using absolute K (so split-K naturally addresses the right b-scale block). 
+ offs_bsn = offs_bn // B_SCALE_N_GROUP # (BLOCK_SIZE_N,) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + offs_scale_k_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + + for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): + # K base for this iteration (in elements, absolute). + k_base = k * BLOCK_SIZE_K + + # ---- Load A scales (M, BLOCK_SIZE_K // 32) ---- + if EVEN_K: + a_scales = tl.load(a_scale_ptrs) + else: + a_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + a_scales = tl.load(a_scale_ptrs, mask=a_scale_mask, other=127) + + # ---- Load and broadcast B scales (BLOCK_SIZE_N, BLOCK_SIZE_K // 32) ---- + offs_bsk = ( + k_base + offs_scale_k_a * SCALE_GROUP_SIZE + ) // B_SCALE_K_GROUP # (BLOCK_SIZE_K // 32,) + b_scale_ptrs = ( + b_scales_ptr + + offs_bsn[:, None] * stride_bsn + + offs_bsk[None, :] * stride_bsk + ) + if EVEN_K: + b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) + else: + # OOB along K: load with the same mask as a-scales + b_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + b_scales = tl.load( + b_scale_ptrs, + mask=b_scale_mask, + other=127, + cache_modifier=cache_modifier, + ) + + # ---- Load A, B data ---- + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0 + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0, + cache_modifier=cache_modifier, + ) + + accumulator = tl.dot_scaled( + a, a_scales, "e4m3", b, b_scales, "e4m3", accumulator + ) - # Write back the block of the output matrix C with masks. 
- offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask + + c = accumulator.to(c_ptr.type.element_ty) + + # Write back the block of the output matrix C with masks. For + # NUM_KSPLIT > 1, each pid_k writes to a separate slab of c_ptr. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + c_ptrs = ( + c_ptr + + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :] + + pid_k * stride_ck + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) _gemm_afp8wfp8_preshuffle_repr = make_kernel_repr( @@ -217,6 +237,7 @@ def _gemm_afp8wfp8_kernel( "matrix_instr_nonkdim", "cache_modifier", "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", ], ) @@ -253,6 +274,7 @@ def _gemm_afp8wfp8_preshuffle_kernel( BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, EVEN_K: tl.constexpr, num_warps: tl.constexpr, num_stages: tl.constexpr, @@ -267,6 +289,8 @@ def _gemm_afp8wfp8_preshuffle_kernel( shuffled tile in storage order (BLOCK_SIZE_N // 16, BLOCK_SIZE_K * 16) then reshape+permute+trans inside the kernel to restore logical (K, N) layout before tl.dot_scaled. Scales remain in the unshuffled compact 128x128 layout. + When NUM_KSPLIT > 1, K is split into NUM_KSPLIT partitions of + SPLITK_BLOCK_SIZE elements; each pid_k writes to c_ptr + pid_k * stride_ck. 
""" tl.assume(stride_am > 0) @@ -283,130 +307,154 @@ def _gemm_afp8wfp8_preshuffle_kernel( GRID_MN = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) pid_unified = tl.program_id(axis=0) - pid_unified = remap_xcd(pid_unified, GRID_MN, NUM_XCDS=8) - - pid = pid_unified + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + if NUM_KSPLIT == 1: + pid = remap_xcd(pid, GRID_MN, NUM_XCDS=8) + pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) + tl.assume(pid_k >= 0) SCALE_GROUP_SIZE: tl.constexpr = 32 # A: per-32 along K B_SCALE_K_GROUP: tl.constexpr = 128 # B compact: per-128 along K B_SCALE_N_GROUP: tl.constexpr = 128 # B compact: per-128 along N - num_k_iter = tl.cdiv(K, BLOCK_SIZE_K) - - # A pointers (unchanged from non-preshuffle kernel). - offs_k = tl.arange(0, BLOCK_SIZE_K) - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - - # B pointers for preshuffled layout. The shuffled storage is viewed as - # (N // 16, K * 16) elements. pid_n indexes BLOCK_SIZE_N // 16 N-tiles per - # step; the K dimension is expanded by 16x in byte addresses. - offs_bn_shuffle = pid_n * (BLOCK_SIZE_N // 16) + tl.arange(0, BLOCK_SIZE_N // 16) - offs_k_shuffle = tl.arange(0, BLOCK_SIZE_K * 16) - b_ptrs = b_ptr + ( - offs_bn_shuffle[:, None] * stride_bn + offs_k_shuffle[None, :] * stride_bk - ) - - # A-scale pointers: per-row M, per 32-K group. - offs_ks_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) - a_scale_ptrs = ( - a_scales_ptr + offs_am[:, None] * stride_asm + offs_ks_a[None, :] * stride_ask - ) - - # B-scale pointers: compact (N // 128, K // 128). 
The N index needs the - # ORIGINAL (logical) row, not the shuffled row index, so use offs_bn_logical. - offs_bn_logical = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_bsn = offs_bn_logical // B_SCALE_N_GROUP - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - offs_scale_k_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) - - for k in range(0, num_k_iter): - k_base = k * BLOCK_SIZE_K - - # Load A scales. - if EVEN_K: - a_scales = tl.load(a_scale_ptrs) - else: - a_scale_mask = offs_scale_k_a[None, :] < ( - K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) - ) - a_scales = tl.load(a_scale_ptrs, mask=a_scale_mask, other=127) - - # Load and broadcast B scales (same as non-preshuffle path). - offs_bsk = (k_base + offs_scale_k_a * SCALE_GROUP_SIZE) // B_SCALE_K_GROUP - b_scale_ptrs = ( - b_scales_ptr - + offs_bsn[:, None] * stride_bsn - + offs_bsk[None, :] * stride_bsk - ) - if EVEN_K: - b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) - else: - b_scale_mask = offs_scale_k_a[None, :] < ( - K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) - ) - b_scales = tl.load( - b_scale_ptrs, - mask=b_scale_mask, - other=127, - cache_modifier=cache_modifier, - ) + if (pid_k * SPLITK_BLOCK_SIZE) < K: + num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE, BLOCK_SIZE_K) - # Load A and B (preshuffled). - if EVEN_K: - a = tl.load(a_ptrs) - b_shuf = tl.load(b_ptrs, cache_modifier=cache_modifier) - else: - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0) - b_shuf = tl.load( - b_ptrs, - mask=offs_k_shuffle[None, :] < (K - k * BLOCK_SIZE_K) * 16, - other=0, - cache_modifier=cache_modifier, - ) + # A pointers (offset by this split's K start). 
+ offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k_split = pid_k * SPLITK_BLOCK_SIZE + offs_k + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak + ) - # Unshuffle B in-kernel. Inverse of the shuffle_weight permute: - # shuffle: (N // 16, 16, K // 32, 2, 16) --[perm 0,1,3,4,2,5]-> (N // 16, K // 32, 2, 16, 16) - # unshuffle: (N // 16, K // 32, 2, 16, 16) --[perm 0,1,4,2,3,5]-> (N // 16, 16, K // 32, 2, 16) - # then flatten to (N, K) and trans to (K, N). - b = ( - b_shuf.reshape( - 1, - BLOCK_SIZE_N // 16, - BLOCK_SIZE_K // 32, - 2, - 16, - 16, - ) - .permute(0, 1, 4, 2, 3, 5) - .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K) - .trans(1, 0) + # B pointers for preshuffled layout. The shuffled storage is viewed as + # (N // 16, K * 16) elements. pid_n indexes BLOCK_SIZE_N // 16 N-tiles per + # step; the K dimension is expanded by 16x in byte addresses. The split + # offsets the K-byte axis by pid_k * SPLITK_BLOCK_SIZE * 16. + offs_bn_shuffle = pid_n * (BLOCK_SIZE_N // 16) + tl.arange( + 0, BLOCK_SIZE_N // 16 + ) + offs_k_shuffle_arr = tl.arange(0, BLOCK_SIZE_K * 16) + offs_k_shuffle = pid_k * SPLITK_BLOCK_SIZE * 16 + offs_k_shuffle_arr + b_ptrs = b_ptr + ( + offs_bn_shuffle[:, None] * stride_bn + offs_k_shuffle[None, :] * stride_bk ) - accumulator = tl.dot_scaled( - a, a_scales, "e4m3", b, b_scales, "e4m3", accumulator + # A-scale pointers: per-row M, per 32-K group. Shift along the K-scale + # axis by the split's start in scale groups. + offs_ks_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + offs_ks_a_split = pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE) + offs_ks_a + a_scale_ptrs = ( + a_scales_ptr + + offs_am[:, None] * stride_asm + + offs_ks_a_split[None, :] * stride_ask ) - # Advance pointers. 
- a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * 16 * stride_bk - a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask + # B-scale pointers: compact (N // 128, K // 128). The N index needs the + # ORIGINAL (logical) row, not the shuffled row index. + offs_bn_logical = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bsn = offs_bn_logical // B_SCALE_N_GROUP + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + offs_scale_k_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + + for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): + k_base = k * BLOCK_SIZE_K # absolute K base + + # Load A scales. + if EVEN_K: + a_scales = tl.load(a_scale_ptrs) + else: + a_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + a_scales = tl.load(a_scale_ptrs, mask=a_scale_mask, other=127) + + # Load and broadcast B scales (computed from absolute K). + offs_bsk = (k_base + offs_scale_k_a * SCALE_GROUP_SIZE) // B_SCALE_K_GROUP + b_scale_ptrs = ( + b_scales_ptr + + offs_bsn[:, None] * stride_bsn + + offs_bsk[None, :] * stride_bsk + ) + if EVEN_K: + b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) + else: + b_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + b_scales = tl.load( + b_scale_ptrs, + mask=b_scale_mask, + other=127, + cache_modifier=cache_modifier, + ) + + # Load A and B (preshuffled). + if EVEN_K: + a = tl.load(a_ptrs) + b_shuf = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0 + ) + b_shuf = tl.load( + b_ptrs, + mask=offs_k_shuffle_arr[None, :] < (K - k * BLOCK_SIZE_K) * 16, + other=0, + cache_modifier=cache_modifier, + ) + + # Unshuffle B in-kernel. 
Inverse of the shuffle_weight permute: + # shuffle: (N // 16, 16, K // 32, 2, 16) --[perm 0,1,3,4,2,5]-> (N // 16, K // 32, 2, 16, 16) + # unshuffle: (N // 16, K // 32, 2, 16, 16) --[perm 0,1,4,2,3,5]-> (N // 16, 16, K // 32, 2, 16) + # then flatten to (N, K) and trans to (K, N). + b = ( + b_shuf.reshape( + 1, + BLOCK_SIZE_N // 16, + BLOCK_SIZE_K // 32, + 2, + 16, + 16, + ) + .permute(0, 1, 4, 2, 3, 5) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K) + .trans(1, 0) + ) - c = accumulator.to(c_ptr.type.element_ty) + accumulator = tl.dot_scaled( + a, a_scales, "e4m3", b, b_scales, "e4m3", accumulator + ) - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) + # Advance pointers. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * 16 * stride_bk + a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask + + c = accumulator.to(c_ptr.type.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + c_ptrs = ( + c_ptr + + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :] + + pid_k * stride_ck + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) def _get_config( diff --git a/aiter/ops/triton/utils/_triton/tunning/ut_afp8wfp8_gemm_preshuffle.py b/aiter/ops/triton/utils/_triton/tunning/ut_afp8wfp8_gemm_preshuffle.py new file mode 100644 index 0000000000..8f13ac4b3e --- /dev/null +++ b/aiter/ops/triton/utils/_triton/tunning/ut_afp8wfp8_gemm_preshuffle.py @@ -0,0 +1,44 @@ +import sys +from _utils import ( + run_profile, + get_input_shape_and_config_list, +) + +############################################################ +# +import torch +from 
aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import gemm_afp8wfp8_preshuffle +from op_tests.triton_tests.gemm.basic.test_gemm_afp8wfp8 import ( + generate_inputs, +) +from aiter.ops.triton.utils.types import get_fp8_dtypes +from aiter.ops.triton.utils.gemm_config_utils import compute_splitk_params + +############################################################ + +input_shape, config_list = get_input_shape_and_config_list(sys.argv, shape_size=3) +M, N, K = input_shape + +############################################################ +# +_, e4m3_type = get_fp8_dtypes() +dtype = torch.bfloat16 +x_fp8, w_fp8, w_kernel, x_scales, w_scales = generate_inputs( + *input_shape, + shuffle=True, +) +############################################################ + +for config in config_list: + if config is not None: + compute_splitk_params(config, K) + + def fn(): + ############################################################ + # + gemm_afp8wfp8_preshuffle( + x_fp8, w_kernel, x_scales, w_scales, dtype=dtype, config=config + ) + ############################################################ + + run_profile(fn) diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py b/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py index a1b5c1a20a..8b3bf5eae8 100644 --- a/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py +++ b/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py @@ -22,8 +22,14 @@ def e8m0_to_f32(x: torch.Tensor) -> torch.Tensor: return torch.exp2((x.to(torch.int32) - 127).to(torch.float32)) -def generate_inputs(M: int, N: int, K: int, seed: int = 0): - torch.manual_seed(seed) +def generate_inputs(M: int, N: int, K: int, shuffle: bool = False): + """Returns ``(x_fp8, w_fp8, w_kernel, x_scales, w_scales)``. + + ``w_fp8`` is always the unshuffled weight (for use by the fp32 reference). + ``w_kernel`` is the weight to pass to the kernel: identical to ``w_fp8`` + when ``shuffle=False``, or shuffled via ``shuffle_weight(layout=(16, 16))`` + when ``shuffle=True``. 
+ """ # Small random fp32 → fp8 e4m3fn, kept inside e4m3 range so the cast is exact-ish. x_f32 = torch.randn((M, K), dtype=torch.float32, device="cuda") w_f32 = torch.randn((N, K), dtype=torch.float32, device="cuda") @@ -43,7 +49,14 @@ def generate_inputs(M: int, N: int, K: int, seed: int = 0): dtype=torch.uint8, device="cuda", ) - return x_fp8, w_fp8, x_scales, w_scales + + if shuffle: + # shuffle_weight operates on raw bytes; view as uint8 to avoid dtype quirks. + w_kernel = shuffle_weight(w_fp8.view(torch.uint8), layout=(16, 16)) + else: + w_kernel = w_fp8 + + return x_fp8, w_fp8, w_kernel, x_scales, w_scales def run_torch_gemm_afp8wfp8( @@ -98,53 +111,34 @@ def get_shapes(): @pytest.mark.parametrize("M, N, K", get_shapes()) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) def test_gemm_afp8wfp8(M: int, N: int, K: int, dtype: torch.dtype): + torch.manual_seed(0) if not arch_info.is_fp8_avail(): pytest.skip("MXFP8 GEMM requires FP8-capable arch") torch.cuda.empty_cache() - x_fp8, w_fp8, x_scales, w_scales = generate_inputs(M, N, K) + x_fp8, w_fp8, w_kernel, x_scales, w_scales = generate_inputs(M, N, K, shuffle=False) torch_out = run_torch_gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype) - triton_out = gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype=dtype) + triton_out = gemm_afp8wfp8(x_fp8, w_kernel, x_scales, w_scales, dtype=dtype) - torch.testing.assert_close(triton_out, torch_out, atol=0.5, rtol=5e-2) + torch.testing.assert_close(triton_out, torch_out, atol=0.03, rtol=1e-2) @pytest.mark.parametrize("M, N, K", get_shapes()) @pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_gemm_afp8wfp8_preshuffle(M: int, N: int, K: int, dtype: torch.dtype): + torch.manual_seed(0) if not arch_info.is_fp8_avail(): pytest.skip("MXFP8 GEMM requires FP8-capable arch") if N % 16 != 0 or K % 32 != 0: pytest.skip("Preshuffle requires N % 16 == 0 and K % 32 == 0") torch.cuda.empty_cache() - x_fp8, w_fp8, x_scales, w_scales = generate_inputs(M, 
N, K) - - # Shuffle the weight tensor in place (layout=(16,16)). The shuffler operates on - # raw bytes; view as uint8 to avoid any dtype quirks, then view back. - w_uint8 = w_fp8.view(torch.uint8) - w_shuffled = shuffle_weight(w_uint8, layout=(16, 16)) + x_fp8, w_fp8, w_kernel, x_scales, w_scales = generate_inputs(M, N, K, shuffle=True) torch_out = run_torch_gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype) triton_out = gemm_afp8wfp8_preshuffle( - x_fp8, w_shuffled, x_scales, w_scales, dtype=dtype + x_fp8, w_kernel, x_scales, w_scales, dtype=dtype ) - torch.testing.assert_close(triton_out, torch_out, atol=0.5, rtol=5e-2) - - -@pytest.mark.parametrize("M, N, K", [(128, 128, 128), (256, 256, 256)]) -def test_gemm_afp8wfp8_preallocated_output(M: int, N: int, K: int): - if not arch_info.is_fp8_avail(): - pytest.skip("MXFP8 GEMM requires FP8-capable arch") - torch.cuda.empty_cache() - - dtype = torch.bfloat16 - x_fp8, w_fp8, x_scales, w_scales = generate_inputs(M, N, K) - y = torch.empty((M, N), dtype=dtype, device="cuda") - out = gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype=dtype, y=y) - assert out.data_ptr() == y.data_ptr() - - torch_out = run_torch_gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype) - torch.testing.assert_close(out, torch_out, atol=0.5, rtol=5e-2) + torch.testing.assert_close(triton_out, torch_out, atol=0.03, rtol=1e-2) From 68731f632a2384d76dfd6fbf71d685b988a44e84 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 19 May 2026 21:32:14 +0000 Subject: [PATCH 4/8] update --- .../gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=512-K=4096.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=512-K=4096.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=512-K=4096.json index 63ea464784..91361fa902 100644 --- a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=512-K=4096.json +++ 
b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=512-K=4096.json @@ -9,7 +9,7 @@ "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "cache_modifier": null, - "NUM_KSPLIT": 1 + "NUM_KSPLIT": 4 }, "M_LEQ_8": { "BLOCK_SIZE_M": 16, From 73d8719a21d723b3fb19b7478e543b0742e32d0f Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Wed, 20 May 2026 15:00:49 +0000 Subject: [PATCH 5/8] add tunned config --- .../gemm/basic/gemm_afp8wfp8.py | 2 +- ...M-AFP8WFP8_PRESHUFFLED-N=1536-K=4096.json} | 0 ...MM-AFP8WFP8_PRESHUFFLED-N=2048-K=7168.json | 98 +++++++++++++++++++ ...M-AFP8WFP8_PRESHUFFLED-N=4096-K=1024.json} | 0 ...MM-AFP8WFP8_PRESHUFFLED-N=4096-K=256.json} | 0 ...MM-AFP8WFP8_PRESHUFFLED-N=512-K=4096.json} | 0 ...MM-AFP8WFP8_PRESHUFFLED-N=7168-K=2048.json | 98 +++++++++++++++++++ ...EMM-AFP8WFP8_PRESHUFFLED-N=7168-K=384.json | 98 +++++++++++++++++++ ...EMM-AFP8WFP8_PRESHUFFLED-N=768-K=7168.json | 98 +++++++++++++++++++ ...M-AFP8WFP8_PRESHUFFLED-N=8192-K=1024.json} | 0 ...MM-AFP8WFP8_PRESHUFFLED-N=8192-K=1536.json | 98 +++++++++++++++++++ ... 
=> gfx950-GEMM-AFP8WFP8_PRESHUFFLED.json} | 0 12 files changed, 491 insertions(+), 1 deletion(-) rename aiter/ops/triton/configs/gemm/{gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=1536-K=4096.json => gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=1536-K=4096.json} (100%) create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=2048-K=7168.json rename aiter/ops/triton/configs/gemm/{gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=4096-K=1024.json => gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=1024.json} (100%) rename aiter/ops/triton/configs/gemm/{gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=4096-K=256.json => gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=256.json} (100%) rename aiter/ops/triton/configs/gemm/{gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=512-K=4096.json => gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=512-K=4096.json} (100%) create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=2048.json create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=384.json create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=768-K=7168.json rename aiter/ops/triton/configs/gemm/{gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=8192-K=1024.json => gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1024.json} (100%) create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1536.json rename aiter/ops/triton/configs/gemm/{gfx950-GEMM-AFP8WFP8-PRESHUFFLE.json => gfx950-GEMM-AFP8WFP8_PRESHUFFLED.json} (100%) diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py index 27b0c121ad..1b549157a8 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py @@ -467,6 +467,6 @@ def _get_config( config mechanism. 
Falls back to the generic fallback file or to _DEFAULT_CONFIG if no JSON is available.""" if shuffle: - return get_gemm_config("GEMM-AFP8WFP8-PRESHUFFLE", M, N, K) + return get_gemm_config("GEMM-AFP8WFP8_PRESHUFFLED", M, N, K) else: return get_gemm_config("GEMM-AFP8WFP8", M, N, K) diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=1536-K=4096.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=1536-K=4096.json similarity index 100% rename from aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=1536-K=4096.json rename to aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=1536-K=4096.json diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=2048-K=7168.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=2048-K=7168.json new file mode 100644 index 0000000000..a07c6c2d26 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=2048-K=7168.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=1024.json similarity index 100% rename from aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=4096-K=1024.json rename to aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=1024.json diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=4096-K=256.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=256.json similarity index 100% rename from aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=4096-K=256.json rename to aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=256.json diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=512-K=4096.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=512-K=4096.json similarity index 100% rename from 
aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=512-K=4096.json rename to aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=512-K=4096.json diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=2048.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=2048.json new file mode 100644 index 0000000000..6f74cb9e7f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=2048.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=384.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=384.json new file mode 100644 index 0000000000..2a3b281530 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=384.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + 
"NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=768-K=7168.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=768-K=7168.json new file mode 100644 index 0000000000..2d88b2f30c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=768-K=7168.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + 
"matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=8192-K=1024.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1024.json similarity index 100% rename from aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE-N=8192-K=1024.json rename to aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1024.json diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1536.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1536.json new file mode 100644 index 0000000000..defffda52d --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1536.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + 
"NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED.json similarity index 100% rename from aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8-PRESHUFFLE.json rename to 
aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED.json From 13e7517357de5b425952b969ca64d3efd0fd71e8 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Wed, 20 May 2026 15:15:07 +0000 Subject: [PATCH 6/8] black format --- .../triton/_triton_kernels/quant/quant_mxfp8.py | 9 ++------- aiter/ops/triton/quant/quant_mxfp8.py | 14 +++++++------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/quant/quant_mxfp8.py b/aiter/ops/triton/_triton_kernels/quant/quant_mxfp8.py index fe83509721..14de6ca80f 100644 --- a/aiter/ops/triton/_triton_kernels/quant/quant_mxfp8.py +++ b/aiter/ops/triton/_triton_kernels/quant/quant_mxfp8.py @@ -4,7 +4,6 @@ import triton import triton.language as tl - # MXFP8 activation quant: per-1x32 e8m0 scale + FP8 e4m3 values. # Follows aiter.ops.quant.per_1x32_f8_scale_f8_quant: # MAX_POW2 = int(log2(448)) = 8 @@ -324,9 +323,7 @@ def _dual_rmsnorm_mxfp8_quant_kernel( mask=q_mask, other=0.0, ).to(tl.float32) - g_q = tl.load(gq_ptr + q_col_offsets, mask=q_mask, other=0.0).to( - tl.float32 - ) + g_q = tl.load(gq_ptr + q_col_offsets, mask=q_mask, other=0.0).to(tl.float32) ss_q = tl.sum(x_q * x_q, axis=-1) norm_q = tl.math.rsqrt((ss_q / KQ) + eps_q) @@ -368,9 +365,7 @@ def _dual_rmsnorm_mxfp8_quant_kernel( mask=k_mask, other=0.0, ).to(tl.float32) - g_k = tl.load(gk_ptr + k_col_offsets, mask=k_mask, other=0.0).to( - tl.float32 - ) + g_k = tl.load(gk_ptr + k_col_offsets, mask=k_mask, other=0.0).to(tl.float32) ss_k = tl.sum(x_k * x_k, axis=-1) norm_k = tl.math.rsqrt((ss_k / KK) + eps_k) diff --git a/aiter/ops/triton/quant/quant_mxfp8.py b/aiter/ops/triton/quant/quant_mxfp8.py index ae82cafab8..24a0136150 100644 --- a/aiter/ops/triton/quant/quant_mxfp8.py +++ b/aiter/ops/triton/quant/quant_mxfp8.py @@ -12,7 +12,6 @@ _dual_rmsnorm_mxfp8_quant_kernel, ) - _QUANT_BLOCK_SIZE = 32 @@ -107,9 +106,10 @@ def fp8_legacy_to_mxfp8( M, N = x_fnuz.shape assert N % _QUANT_BLOCK_SIZE == 0 assert N % _LEGACY_BLOCK_SIZE 
== 0 - assert x_scale_fp32.shape == (M, N // _LEGACY_BLOCK_SIZE), ( - f"x_scale_fp32 shape {x_scale_fp32.shape} != ({M},{N // _LEGACY_BLOCK_SIZE})" - ) + assert x_scale_fp32.shape == ( + M, + N // _LEGACY_BLOCK_SIZE, + ), f"x_scale_fp32 shape {x_scale_fp32.shape} != ({M},{N // _LEGACY_BLOCK_SIZE})" Ns = N // _QUANT_BLOCK_SIZE if y_fn is None: @@ -238,9 +238,9 @@ def dual_rmsnorm_mxfp8_quant( assert M == Mk, f"q rows {M} != k rows {Mk}" assert q_weight.shape == (KQ,), f"q_weight shape {q_weight.shape} != ({KQ},)" assert k_weight.shape == (KK,), f"k_weight shape {k_weight.shape} != ({KK},)" - assert KQ % _QUANT_BLOCK_SIZE == 0, ( - f"KQ={KQ} must be a multiple of {_QUANT_BLOCK_SIZE}" - ) + assert ( + KQ % _QUANT_BLOCK_SIZE == 0 + ), f"KQ={KQ} must be a multiple of {_QUANT_BLOCK_SIZE}" if eps_k is None: eps_k = eps_q From 8494fb5c743026d1e4dece408adc482d69b92f25 Mon Sep 17 00:00:00 2001 From: Shao-Chun Lee Date: Thu, 21 May 2026 10:22:53 -0500 Subject: [PATCH 7/8] [Triton] merge fused gate + downcast kernel (#3306) * add fused_gemm_a16w16_copy_x, remove near-moe torch.to and torch.zeros kernels * tunning * revert fill(0) change --- .../gemm/fused/fused_gemm_a16w16_copy_x.py | 231 ++++++++++++++++++ .../gemm/gfx950-GEMM-A16W16-N=384-K=7168.json | 98 ++++++++ .../gemm/fused/fused_gemm_a16w16_copy_x.py | 156 ++++++++++++ .../fused/test_fused_gemm_a16w16_copy_x.py | 119 +++++++++ 4 files changed, 604 insertions(+) create mode 100644 aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_copy_x.py create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-A16W16-N=384-K=7168.json create mode 100644 aiter/ops/triton/gemm/fused/fused_gemm_a16w16_copy_x.py create mode 100644 op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_copy_x.py diff --git a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_copy_x.py b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_copy_x.py new file mode 100644 index 0000000000..4b37f90b7c --- /dev/null +++ 
b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_copy_x.py @@ -0,0 +1,231 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton.language as tl +from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr +from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd +from aiter.ops.triton.utils.gemm_config_utils import ( + compute_splitk_params, + get_gemm_config, +) + +import triton + +_fused_gemm_a16w16_copy_x_repr = make_kernel_repr( + "_fused_gemm_a16w16_copy_x_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "EVEN_K", + "EVEN_MN", + "cache_modifier", + "activation", + "use_activation", + "ADD_BIAS", + "SKIP_REDUCE", + ], +) + + +@triton.heuristics( + { + "EVEN_K": lambda args: (args["K"] % (args["SPLITK_BLOCK_SIZE"]) == 0) + and (args["SPLITK_BLOCK_SIZE"] % args["BLOCK_SIZE_K"] == 0), + "EVEN_MN": lambda args: (args["M"] % args["BLOCK_SIZE_M"] == 0) + and (args["N"] % args["BLOCK_SIZE_N"] == 0), + } +) +@triton.jit( + repr=_fused_gemm_a16w16_copy_x_repr, + do_not_specialize=["M", "N"], +) +def _fused_gemm_a16w16_copy_x_kernel( + a_ptr, + b_ptr, + bias_ptr, + c_ptr, + a_copy_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_ck, + stride_cm, + stride_cn, + stride_a_copy_m, + stride_a_copy_k, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, + EVEN_K: tl.constexpr, + EVEN_MN: tl.constexpr, + cache_modifier: tl.constexpr, + activation: tl.constexpr, + use_activation: tl.constexpr, + ADD_BIAS: tl.constexpr, + SKIP_REDUCE: tl.constexpr, +): + """Kernel that computes C = A x B and also emits a downcasted copy of A. 
+ + The grid is laid out as a single 1D program-id space split into two + contiguous regions: + + * `[0, NUM_KSPLIT * num_pid_m * num_pid_n)` runs the GEMM (identical to + the unfused a16w16 kernel). + * `[NUM_KSPLIT * num_pid_m * num_pid_n, GEMM_GRID + num_pid_m * num_pid_k_copy)` + runs the pure downcast-copy of A: each program handles one + `BLOCK_SIZE_M x BLOCK_SIZE_K` tile of A and writes it to `a_copy_ptr`. + + Decoupling the copy from the GEMM body keeps the GEMM hot loop unchanged + and lets the copy programs schedule independently. + """ + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_ck > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + tl.assume(stride_a_copy_m > 0) + tl.assume(stride_a_copy_k > 0) + + pid_unified = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k_copy = tl.cdiv(K, BLOCK_SIZE_K) + GEMM_GRID = num_pid_m * num_pid_n * NUM_KSPLIT + + if pid_unified < GEMM_GRID: + # ---- GEMM branch ---------------------------------------------------- + pid_unified = remap_xcd(pid_unified, GEMM_GRID, NUM_XCDS=8) + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT + + if NUM_KSPLIT == 1: + pid_m, pid_n = pid_grid( + pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M + ) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(pid_k >= 0) + + split_k_start = pid_k * SPLITK_BLOCK_SIZE + if split_k_start < K: + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k_split = split_k_start + offs_k + if EVEN_MN: + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + else: + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + 
offs_k_split[None, :] * stride_ak + ) + b_ptrs = b_ptr + ( + offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn + ) + + acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32 + if ADD_BIAS: + if NUM_KSPLIT == 1 or (SKIP_REDUCE and pid_k == 0): + accumulator = tl.load(bias_ptr + offs_bn).to(dtype=acc_dtype) + accumulator = tl.broadcast_to( + accumulator[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N) + ) + else: + accumulator = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype + ) + else: + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + split_k_end = tl.minimum(split_k_start + SPLITK_BLOCK_SIZE, K) + k_span = split_k_end - split_k_start + num_k_iter = tl.cdiv(k_span, BLOCK_SIZE_K) + + for k in range(num_k_iter): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < k_span - k * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < k_span - k * BLOCK_SIZE_K, + other=0.0, + cache_modifier=cache_modifier, + ) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if use_activation and NUM_KSPLIT == 1: + accumulator = activation(accumulator) + + c = accumulator.to(c_ptr.type.element_ty) + offs_cm = pid_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = ( + c_ptr + + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :] + + pid_k * stride_ck + ) + if EVEN_MN: + tl.store(c_ptrs, c) + else: + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + else: + # ---- Downcast-copy branch ------------------------------------------ + pid_copy = pid_unified - GEMM_GRID + pid_m = pid_copy // num_pid_k_copy + pid_k = pid_copy % num_pid_k_copy + + tl.assume(pid_m >= 0) + tl.assume(pid_k >= 0) + + offs_m = pid_m * 
BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + a_copy_ptrs = a_copy_ptr + ( + offs_m[:, None] * stride_a_copy_m + offs_k[None, :] * stride_a_copy_k + ) + + mask = (offs_m[:, None] < M) & (offs_k[None, :] < K) + a = tl.load(a_ptrs, mask=mask, other=0.0) + tl.store(a_copy_ptrs, a.to(a_copy_ptr.type.element_ty), mask=mask) + + +def _get_config( + M: int, + N: int, + K: int, +): + # Use the same tuning portal as the unfused gemm_a16w16 — the extra + # downcast copy is assumed not to shift the optimal config. + config, is_tunned = get_gemm_config("GEMM-A16W16", M, N, K) + return compute_splitk_params(config, K), is_tunned diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16W16-N=384-K=7168.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16W16-N=384-K=7168.json new file mode 100644 index 0000000000..9b9cce2832 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16W16-N=384-K=7168.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 1024, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + 
"cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_copy_x.py b/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_copy_x.py new file mode 100644 index 0000000000..f932c480f2 --- /dev/null +++ b/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_copy_x.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
+ +from typing import Optional, Tuple +import torch +import triton +from aiter.ops.triton._triton_kernels.gemm.fused.fused_gemm_a16w16_copy_x import ( + _fused_gemm_a16w16_copy_x_kernel, + _get_config, +) +from aiter.ops.triton._triton_kernels.gemm.basic.gemm_a16w16 import ( + _gemm_a16w16_reduce_kernel, +) +from aiter.ops.triton._triton_kernels.activation import _get_activation_from_str +from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter import dtypes + +_LOGGER = AiterTritonLogger() + + +def fused_gemm_a16w16_copy_x( + x, + w, + bias: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = torch.bfloat16, + copy_dtype: Optional[torch.dtype] = None, + y: Optional[torch.Tensor] = None, + x_copy: Optional[torch.Tensor] = None, + config: Optional[dict] = None, + activation: Optional[str] = None, + skip_reduce: Optional[bool] = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes 16-bit matmul Y = X @ W^T and also emits a downcasted copy of X. + + This fuses the GEMM with the activation-quantization downcast that + immediately follows the router-gate GEMM in MoE flows (e.g. DSv4), + avoiding a separate kernel/DRAM pass to cast X from BF16 to FP8. + + The fused kernel uses a single 1D grid split into two regions: GEMM tiles + occupy `NUM_KSPLIT * cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)` programs and an + additional `cdiv(M, BLOCK_M) * cdiv(K, BLOCK_K)` programs handle the + downcast copy. + + Args: + x (torch.Tensor): Input matrix with shape (M, K). + w (torch.Tensor): Weight matrix with shape (N, K), internally transposed. + bias (Optional[torch.Tensor]): Bias vector with shape (N,). + dtype (Optional[torch.dtype]): Output Y datatype (BF16 or FP16). + copy_dtype (Optional[torch.dtype]): Output X-copy datatype + (defaults to aiter.dtypes.fp8). + y (Optional[torch.Tensor]): Pre-allocated output with shape (M, N). + x_copy (Optional[torch.Tensor]): Pre-allocated downcasted X copy with + shape (M, K). 
+ config (Optional[dict]): Kernel tuning parameters. + activation (Optional[str]): Activation fused into Y. + skip_reduce (Optional[bool]): Skip split-K reduction for Y. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: (Y, x_copy). When skip_reduce=True + and NUM_KSPLIT > 1, Y has shape (NUM_KSPLIT, M, N). + """ + + _LOGGER.info(f"FUSED_GEMM_A16W16_COPY_X: x={tuple(x.shape)} w={tuple(w.shape)}") + # Shape checks + assert x.shape[1] == w.shape[1], "Incompatible matrix shapes." + + if copy_dtype is None: + copy_dtype = dtypes.fp8 + + M, K = x.shape + N, K = w.shape + w = w.T + + if config is None: + config, _ = _get_config(M, N, K) + + if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + y = torch.empty((M, N), dtype=dtype, device=x.device) + + if x_copy is None: + x_copy = torch.empty((M, K), dtype=copy_dtype, device=x.device) + + if config["NUM_KSPLIT"] > 1: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), + dtype=torch.float32, + device=y.device if y is not None else x.device, + ) + else: + y_pp = None + + grid = lambda META: ( # noqa: E731 + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]) + + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(K, META["BLOCK_SIZE_K"]), + ) + _fused_gemm_a16w16_copy_x_kernel[grid]( + x, + w, + bias, + y if config["NUM_KSPLIT"] == 1 else y_pp, + x_copy, + M, + N, + K, + x.stride(0), + x.stride(1), + w.stride(0), + w.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), + x_copy.stride(0), + x_copy.stride(1), + activation=_get_activation_from_str(activation) if activation else "", + use_activation=activation is not None, + ADD_BIAS=(bias is not None), + SKIP_REDUCE=skip_reduce, + **config, + ) + + if config["NUM_KSPLIT"] > 1: + if skip_reduce: + return y_pp, x_copy + + REDUCE_BLOCK_SIZE_M = 32 + REDUCE_BLOCK_SIZE_N = 32 + 
ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + triton.cdiv(N, REDUCE_BLOCK_SIZE_N), + ) + _gemm_a16w16_reduce_kernel[grid_reduce]( + bias, + y_pp, + y, + M, + N, + y_pp.stride(0), + y_pp.stride(1), + y_pp.stride(2), + y.stride(0), + y.stride(1), + REDUCE_BLOCK_SIZE_M, + REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT, + triton.next_power_of_2(config["NUM_KSPLIT"]), + activation=_get_activation_from_str(activation) if activation else "", + use_activation=activation is not None, + ADD_BIAS=(bias is not None), + ) + + return y, x_copy diff --git a/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_copy_x.py b/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_copy_x.py new file mode 100644 index 0000000000..24a61caa2c --- /dev/null +++ b/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_copy_x.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
+ +import torch +import torch.nn.functional as F +import pytest + +from aiter import dtypes +from aiter.ops.triton.gemm.fused.fused_gemm_a16w16_copy_x import ( + fused_gemm_a16w16_copy_x, +) +from op_tests.triton_tests.gemm.basic.test_gemm_a16w16 import ( + generate_gemm_a16w16_inputs, +) + + +def get_x_vals(): + x_vals = [(1, 1, 1)] + x_vals += [(3, 5, 2)] + x_vals += [(1024, 1024, 1024)] + x_vals += [(2048, 2048, 2048)] + # DSv4 router gate: num_tokens x 384 x 7168 + x_vals += [(2**i, 384, 7168) for i in range(5, 9)] + # DSR1 router GEMM + x_vals += [(2**i, 256, 7168) for i in range(5, 9)] + return x_vals + + +@pytest.mark.parametrize("M, N, K", get_x_vals()) +def test_fused_gemm_a16w16_copy_x(M: int, N: int, K: int): + torch.cuda.empty_cache() + x, w, _, _, _ = generate_gemm_a16w16_inputs( + M, N, K, dtype=torch.bfloat16, output=False + ) + + torch_y = F.linear(x, w, bias=None) + torch_x_copy = x.to(dtypes.fp8) + + triton_y, triton_x_copy = fused_gemm_a16w16_copy_x(x, w) + + torch.testing.assert_close(triton_y, torch_y, atol=1e-1, rtol=1e-2) + # x_copy comparison: both are produced by an identical BF16 -> FP8 cast, + # so we expect bitwise equality. Compare via BF16 to avoid FP8 dtype + # ambiguities across gfx targets. 
+ torch.testing.assert_close( + triton_x_copy.to(torch.bfloat16), + torch_x_copy.to(torch.bfloat16), + atol=0.0, + rtol=0.0, + ) + + +def get_fewer_x_vals(): + x_vals = [(16, 1024, 1024)] + x_vals += [(128, 8192, 512)] + x_vals += [(256, 512, 8192)] + x_vals += [(1024, 1024, 1024)] + return x_vals + + +@pytest.mark.parametrize("activation", ["gelu", "gelu_tanh", "silu"]) +@pytest.mark.parametrize("M, N, K", get_fewer_x_vals()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("output", [True, False]) +def test_fused_gemm_a16w16_copy_x_activation( + M: int, N: int, K: int, dtype, output, activation +): + x, w, _, _, y = generate_gemm_a16w16_inputs(M, N, K, dtype, output=output) + + torch_y = F.linear(x, w, bias=None) + if activation == "gelu": + torch_y = F.gelu(torch_y) + elif activation == "gelu_tanh": + torch_y = F.gelu(torch_y, approximate="tanh") + elif activation == "silu": + torch_y = F.silu(torch_y) + torch_x_copy = x.to(dtypes.fp8) + + triton_y, triton_x_copy = fused_gemm_a16w16_copy_x( + x, + w, + bias=None, + dtype=dtype, + y=y, + activation=activation, + ) + + torch.testing.assert_close(triton_y, torch_y, atol=1e-1, rtol=1e-2) + torch.testing.assert_close( + triton_x_copy.to(torch.bfloat16), + torch_x_copy.to(torch.bfloat16), + atol=0.0, + rtol=0.0, + ) + + +@pytest.mark.parametrize("M, N, K", get_fewer_x_vals()) +@pytest.mark.parametrize("skip_reduce", [True, False]) +def test_fused_gemm_a16w16_copy_x_skip_reduce(M: int, N: int, K: int, skip_reduce): + torch.cuda.empty_cache() + x, w, _, _, _ = generate_gemm_a16w16_inputs( + M, N, K, dtype=torch.bfloat16, output=False + ) + + torch_y = F.linear(x, w, bias=None) + torch_x_copy = x.to(dtypes.fp8) + + triton_y, triton_x_copy = fused_gemm_a16w16_copy_x(x, w, skip_reduce=skip_reduce) + + if triton_y.dim() == 3: + triton_y = triton_y.sum(axis=0).to(torch.bfloat16) + + torch.testing.assert_close(triton_y, torch_y, atol=1e-1, rtol=1e-2) + torch.testing.assert_close( + 
triton_x_copy.to(torch.bfloat16), + torch_x_copy.to(torch.bfloat16), + atol=0.0, + rtol=0.0, + ) From b86707c76a3598ba90bbc2285589fd6fd116f7e6 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Thu, 21 May 2026 15:47:05 +0000 Subject: [PATCH 8/8] fused_gemm_a16w16_quant_x --- ...copy_x.py => fused_gemm_a16w16_quant_x.py} | 76 +++++++--- ...copy_x.py => fused_gemm_a16w16_quant_x.py} | 89 +++++++---- .../fused/test_fused_gemm_a16w16_copy_x.py | 119 --------------- .../fused/test_fused_gemm_a16w16_quant_x.py | 143 ++++++++++++++++++ 4 files changed, 254 insertions(+), 173 deletions(-) rename aiter/ops/triton/_triton_kernels/gemm/fused/{fused_gemm_a16w16_copy_x.py => fused_gemm_a16w16_quant_x.py} (72%) rename aiter/ops/triton/gemm/fused/{fused_gemm_a16w16_copy_x.py => fused_gemm_a16w16_quant_x.py} (62%) delete mode 100644 op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_copy_x.py create mode 100644 op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_quant_x.py diff --git a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_copy_x.py b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_quant_x.py similarity index 72% rename from aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_copy_x.py rename to aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_quant_x.py index 4b37f90b7c..16de00ed92 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_copy_x.py +++ b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_quant_x.py @@ -11,8 +11,8 @@ import triton -_fused_gemm_a16w16_copy_x_repr = make_kernel_repr( - "_fused_gemm_a16w16_copy_x_kernel", +_fused_gemm_a16w16_quant_x_repr = make_kernel_repr( + "_fused_gemm_a16w16_quant_x_kernel", [ "BLOCK_SIZE_M", "BLOCK_SIZE_N", @@ -20,6 +20,7 @@ "GROUP_SIZE_M", "NUM_KSPLIT", "SPLITK_BLOCK_SIZE", + "QUANT_BLOCK_SIZE", "EVEN_K", "EVEN_MN", "cache_modifier", @@ -40,15 +41,16 @@ } ) @triton.jit( - repr=_fused_gemm_a16w16_copy_x_repr, + 
repr=_fused_gemm_a16w16_quant_x_repr, do_not_specialize=["M", "N"], ) -def _fused_gemm_a16w16_copy_x_kernel( +def _fused_gemm_a16w16_quant_x_kernel( a_ptr, b_ptr, bias_ptr, c_ptr, - a_copy_ptr, + a_quant_ptr, + a_scale_ptr, M, N, K, @@ -59,8 +61,10 @@ def _fused_gemm_a16w16_copy_x_kernel( stride_ck, stride_cm, stride_cn, - stride_a_copy_m, - stride_a_copy_k, + stride_a_quant_m, + stride_a_quant_k, + stride_a_scale_m, + stride_a_scale_n, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -68,6 +72,7 @@ def _fused_gemm_a16w16_copy_x_kernel( GROUP_SIZE_M: tl.constexpr, NUM_KSPLIT: tl.constexpr, SPLITK_BLOCK_SIZE: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, EVEN_K: tl.constexpr, EVEN_MN: tl.constexpr, cache_modifier: tl.constexpr, @@ -76,7 +81,7 @@ def _fused_gemm_a16w16_copy_x_kernel( ADD_BIAS: tl.constexpr, SKIP_REDUCE: tl.constexpr, ): - """Kernel that computes C = A x B and also emits a downcasted copy of A. + """Kernel that computes C = A x B and also emits an MXFP8-quantized A. The grid is laid out as a single 1D program-id space split into two contiguous regions: @@ -84,11 +89,12 @@ def _fused_gemm_a16w16_copy_x_kernel( * `[0, NUM_KSPLIT * num_pid_m * num_pid_n)` runs the GEMM (identical to the unfused a16w16 kernel). * `[NUM_KSPLIT * num_pid_m * num_pid_n, GEMM_GRID + num_pid_m * num_pid_k_copy)` - runs the pure downcast-copy of A: each program handles one - `BLOCK_SIZE_M x BLOCK_SIZE_K` tile of A and writes it to `a_copy_ptr`. + runs the per-1x32 MXFP8 quantization of A: each program handles one + `BLOCK_SIZE_M x BLOCK_SIZE_K` tile of A, derives per-1x32 e8m0 scales, + and writes both the FP8 values and the uint8 scales. - Decoupling the copy from the GEMM body keeps the GEMM hot loop unchanged - and lets the copy programs schedule independently. + BLOCK_SIZE_K must be a multiple of QUANT_BLOCK_SIZE (=32) so that each tile + contains a whole number of MXFP8 groups per row. 
""" tl.assume(stride_am > 0) @@ -98,8 +104,10 @@ def _fused_gemm_a16w16_copy_x_kernel( tl.assume(stride_ck > 0) tl.assume(stride_cm > 0) tl.assume(stride_cn > 0) - tl.assume(stride_a_copy_m > 0) - tl.assume(stride_a_copy_k > 0) + tl.assume(stride_a_quant_m > 0) + tl.assume(stride_a_quant_k > 0) + tl.assume(stride_a_scale_m > 0) + tl.assume(stride_a_scale_n > 0) pid_unified = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -199,7 +207,7 @@ def _fused_gemm_a16w16_copy_x_kernel( c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) else: - # ---- Downcast-copy branch ------------------------------------------ + # ---- MXFP8 quant branch -------------------------------------------- pid_copy = pid_unified - GEMM_GRID pid_m = pid_copy // num_pid_k_copy pid_k = pid_copy % num_pid_k_copy @@ -211,13 +219,39 @@ def _fused_gemm_a16w16_copy_x_kernel( offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) - a_copy_ptrs = a_copy_ptr + ( - offs_m[:, None] * stride_a_copy_m + offs_k[None, :] * stride_a_copy_k + mask = (offs_m[:, None] < M) & (offs_k[None, :] < K) + a = tl.load(a_ptrs, mask=mask, other=0.0).to(tl.float32) + + # Per-1x32 MXFP8 quant. Group along K within this tile. 
+ n_groups: tl.constexpr = BLOCK_SIZE_K // QUANT_BLOCK_SIZE + a_2d = tl.reshape(a, (BLOCK_SIZE_M, n_groups, QUANT_BLOCK_SIZE)) + amax = tl.max(tl.abs(a_2d), axis=2, keep_dims=True) # (M, G, 1) + + amax_i32 = amax.to(tl.int32, bitcast=True) + amax_i32 = (amax_i32 + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_p2 = amax_i32.to(tl.float32, bitcast=True) + scale_unbiased = tl.log2(amax_p2).floor() - 8 + scale_unbiased = tl.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(tl.int32) + 127).to(tl.uint8) # (M, G, 1) + quant_scale = tl.exp2(-scale_unbiased) # (M, G, 1) + + qa_2d = a_2d * quant_scale + qa = tl.reshape(qa_2d, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + a_quant_ptrs = a_quant_ptr + ( + offs_m[:, None] * stride_a_quant_m + offs_k[None, :] * stride_a_quant_k ) + tl.store(a_quant_ptrs, qa.to(a_quant_ptr.type.element_ty), mask=mask) - mask = (offs_m[:, None] < M) & (offs_k[None, :] < K) - a = tl.load(a_ptrs, mask=mask, other=0.0) - tl.store(a_copy_ptrs, a.to(a_copy_ptr.type.element_ty), mask=mask) + # Store scales: shape (M, K // QUANT_BLOCK_SIZE). + offs_s_n = pid_k * n_groups + tl.arange(0, n_groups) + scale_2d = tl.reshape(scale_e8m0, (BLOCK_SIZE_M, n_groups)) + a_scale_ptrs = a_scale_ptr + ( + offs_m[:, None] * stride_a_scale_m + offs_s_n[None, :] * stride_a_scale_n + ) + scale_mask = (offs_m[:, None] < M) & ( + offs_s_n[None, :] < (K // QUANT_BLOCK_SIZE) + ) + tl.store(a_scale_ptrs, scale_2d, mask=scale_mask) def _get_config( @@ -226,6 +260,6 @@ def _get_config( K: int, ): # Use the same tuning portal as the unfused gemm_a16w16 — the extra - # downcast copy is assumed not to shift the optimal config. + # MXFP8 quant is assumed not to shift the optimal config. 
config, is_tunned = get_gemm_config("GEMM-A16W16", M, N, K) return compute_splitk_params(config, K), is_tunned diff --git a/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_copy_x.py b/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_quant_x.py similarity index 62% rename from aiter/ops/triton/gemm/fused/fused_gemm_a16w16_copy_x.py rename to aiter/ops/triton/gemm/fused/fused_gemm_a16w16_quant_x.py index f932c480f2..d580687ef5 100644 --- a/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_copy_x.py +++ b/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_quant_x.py @@ -4,69 +4,78 @@ from typing import Optional, Tuple import torch import triton -from aiter.ops.triton._triton_kernels.gemm.fused.fused_gemm_a16w16_copy_x import ( - _fused_gemm_a16w16_copy_x_kernel, +from aiter.ops.triton._triton_kernels.gemm.fused.fused_gemm_a16w16_quant_x import ( + _fused_gemm_a16w16_quant_x_kernel, _get_config, ) -from aiter.ops.triton._triton_kernels.gemm.basic.gemm_a16w16 import ( - _gemm_a16w16_reduce_kernel, +from aiter.ops.triton._triton_kernels.common.splitk_reduce import ( + _gemm_splitk_reduce_kernel, ) from aiter.ops.triton._triton_kernels.activation import _get_activation_from_str from aiter.ops.triton.utils.logger import AiterTritonLogger -from aiter import dtypes _LOGGER = AiterTritonLogger() +_QUANT_BLOCK_SIZE = 32 -def fused_gemm_a16w16_copy_x( + +def fused_gemm_a16w16_quant_x( x, w, bias: Optional[torch.Tensor] = None, dtype: Optional[torch.dtype] = torch.bfloat16, - copy_dtype: Optional[torch.dtype] = None, + quant_dtype: Optional[torch.dtype] = None, y: Optional[torch.Tensor] = None, - x_copy: Optional[torch.Tensor] = None, + x_quant: Optional[torch.Tensor] = None, + x_scales: Optional[torch.Tensor] = None, config: Optional[dict] = None, activation: Optional[str] = None, skip_reduce: Optional[bool] = False, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Computes 16-bit matmul Y = X @ W^T and also emits a downcasted copy of X. 
+ Computes 16-bit matmul Y = X @ W^T and also emits an MXFP8-quantized X. - This fuses the GEMM with the activation-quantization downcast that + This fuses the GEMM with the activation per-1x32 MXFP8 quantization that immediately follows the router-gate GEMM in MoE flows (e.g. DSv4), - avoiding a separate kernel/DRAM pass to cast X from BF16 to FP8. + avoiding a separate kernel/DRAM pass to cast X from BF16 to FP8 + e8m0 + scales. The fused kernel uses a single 1D grid split into two regions: GEMM tiles occupy `NUM_KSPLIT * cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)` programs and an additional `cdiv(M, BLOCK_M) * cdiv(K, BLOCK_K)` programs handle the - downcast copy. + MXFP8 quant of X. Args: - x (torch.Tensor): Input matrix with shape (M, K). + x (torch.Tensor): Input matrix with shape (M, K). K must be a multiple + of 32. w (torch.Tensor): Weight matrix with shape (N, K), internally transposed. bias (Optional[torch.Tensor]): Bias vector with shape (N,). dtype (Optional[torch.dtype]): Output Y datatype (BF16 or FP16). - copy_dtype (Optional[torch.dtype]): Output X-copy datatype - (defaults to aiter.dtypes.fp8). + quant_dtype (Optional[torch.dtype]): FP8 dtype for the MXFP8 quantized + X values (defaults to torch.float8_e4m3fn). y (Optional[torch.Tensor]): Pre-allocated output with shape (M, N). - x_copy (Optional[torch.Tensor]): Pre-allocated downcasted X copy with + x_quant (Optional[torch.Tensor]): Pre-allocated MXFP8 quantized X with shape (M, K). + x_scales (Optional[torch.Tensor]): Pre-allocated uint8 e8m0 scales with + shape (M, K // 32). config (Optional[dict]): Kernel tuning parameters. activation (Optional[str]): Activation fused into Y. skip_reduce (Optional[bool]): Skip split-K reduction for Y. Returns: - Tuple[torch.Tensor, torch.Tensor]: (Y, x_copy). When skip_reduce=True - and NUM_KSPLIT > 1, Y has shape (NUM_KSPLIT, M, N). + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: (Y, x_quant, x_scales). 
+ When skip_reduce=True and NUM_KSPLIT > 1, Y has shape (NUM_KSPLIT, M, N). """ - _LOGGER.info(f"FUSED_GEMM_A16W16_COPY_X: x={tuple(x.shape)} w={tuple(w.shape)}") + _LOGGER.info(f"FUSED_GEMM_A16W16_QUANT_X: x={tuple(x.shape)} w={tuple(w.shape)}") # Shape checks assert x.shape[1] == w.shape[1], "Incompatible matrix shapes." + assert ( + x.shape[1] % _QUANT_BLOCK_SIZE == 0 + ), f"K={x.shape[1]} must be a multiple of {_QUANT_BLOCK_SIZE} for MXFP8 quant" - if copy_dtype is None: - copy_dtype = dtypes.fp8 + if quant_dtype is None: + quant_dtype = torch.float8_e4m3fn M, K = x.shape N, K = w.shape @@ -75,11 +84,20 @@ def fused_gemm_a16w16_copy_x( if config is None: config, _ = _get_config(M, N, K) + assert ( + config["BLOCK_SIZE_K"] % _QUANT_BLOCK_SIZE == 0 + ), f"BLOCK_SIZE_K={config['BLOCK_SIZE_K']} must be a multiple of {_QUANT_BLOCK_SIZE}" + if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): y = torch.empty((M, N), dtype=dtype, device=x.device) - if x_copy is None: - x_copy = torch.empty((M, K), dtype=copy_dtype, device=x.device) + if x_quant is None: + x_quant = torch.empty((M, K), dtype=quant_dtype, device=x.device) + + if x_scales is None: + x_scales = torch.empty( + (M, K // _QUANT_BLOCK_SIZE), dtype=torch.uint8, device=x.device + ) if config["NUM_KSPLIT"] > 1: y_pp = torch.empty( @@ -96,12 +114,13 @@ def fused_gemm_a16w16_copy_x( * triton.cdiv(N, META["BLOCK_SIZE_N"]) + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(K, META["BLOCK_SIZE_K"]), ) - _fused_gemm_a16w16_copy_x_kernel[grid]( + _fused_gemm_a16w16_quant_x_kernel[grid]( x, w, bias, y if config["NUM_KSPLIT"] == 1 else y_pp, - x_copy, + x_quant, + x_scales, M, N, K, @@ -112,18 +131,21 @@ def fused_gemm_a16w16_copy_x( 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), - x_copy.stride(0), - x_copy.stride(1), + x_quant.stride(0), + x_quant.stride(1), + 
x_scales.stride(0), + x_scales.stride(1), activation=_get_activation_from_str(activation) if activation else "", use_activation=activation is not None, ADD_BIAS=(bias is not None), SKIP_REDUCE=skip_reduce, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, **config, ) if config["NUM_KSPLIT"] > 1: if skip_reduce: - return y_pp, x_copy + return y_pp, x_quant, x_scales REDUCE_BLOCK_SIZE_M = 32 REDUCE_BLOCK_SIZE_N = 32 @@ -133,10 +155,10 @@ def fused_gemm_a16w16_copy_x( triton.cdiv(M, REDUCE_BLOCK_SIZE_M), triton.cdiv(N, REDUCE_BLOCK_SIZE_N), ) - _gemm_a16w16_reduce_kernel[grid_reduce]( - bias, + _gemm_splitk_reduce_kernel[grid_reduce]( y_pp, y, + bias, M, N, y_pp.stride(0), @@ -148,9 +170,10 @@ def fused_gemm_a16w16_copy_x( REDUCE_BLOCK_SIZE_N, ACTUAL_KSPLIT, triton.next_power_of_2(config["NUM_KSPLIT"]), + ADD_BIAS=(bias is not None), activation=_get_activation_from_str(activation) if activation else "", use_activation=activation is not None, - ADD_BIAS=(bias is not None), + KERNEL_NAME="_fused_gemm_a16w16_quant_x_reduce_kernel", ) - return y, x_copy + return y, x_quant, x_scales diff --git a/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_copy_x.py b/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_copy_x.py deleted file mode 100644 index 24a61caa2c..0000000000 --- a/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_copy_x.py +++ /dev/null @@ -1,119 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
- -import torch -import torch.nn.functional as F -import pytest - -from aiter import dtypes -from aiter.ops.triton.gemm.fused.fused_gemm_a16w16_copy_x import ( - fused_gemm_a16w16_copy_x, -) -from op_tests.triton_tests.gemm.basic.test_gemm_a16w16 import ( - generate_gemm_a16w16_inputs, -) - - -def get_x_vals(): - x_vals = [(1, 1, 1)] - x_vals += [(3, 5, 2)] - x_vals += [(1024, 1024, 1024)] - x_vals += [(2048, 2048, 2048)] - # DSv4 router gate: num_tokens x 384 x 7168 - x_vals += [(2**i, 384, 7168) for i in range(5, 9)] - # DSR1 router GEMM - x_vals += [(2**i, 256, 7168) for i in range(5, 9)] - return x_vals - - -@pytest.mark.parametrize("M, N, K", get_x_vals()) -def test_fused_gemm_a16w16_copy_x(M: int, N: int, K: int): - torch.cuda.empty_cache() - x, w, _, _, _ = generate_gemm_a16w16_inputs( - M, N, K, dtype=torch.bfloat16, output=False - ) - - torch_y = F.linear(x, w, bias=None) - torch_x_copy = x.to(dtypes.fp8) - - triton_y, triton_x_copy = fused_gemm_a16w16_copy_x(x, w) - - torch.testing.assert_close(triton_y, torch_y, atol=1e-1, rtol=1e-2) - # x_copy comparison: both are produced by an identical BF16 -> FP8 cast, - # so we expect bitwise equality. Compare via BF16 to avoid FP8 dtype - # ambiguities across gfx targets. 
- torch.testing.assert_close( - triton_x_copy.to(torch.bfloat16), - torch_x_copy.to(torch.bfloat16), - atol=0.0, - rtol=0.0, - ) - - -def get_fewer_x_vals(): - x_vals = [(16, 1024, 1024)] - x_vals += [(128, 8192, 512)] - x_vals += [(256, 512, 8192)] - x_vals += [(1024, 1024, 1024)] - return x_vals - - -@pytest.mark.parametrize("activation", ["gelu", "gelu_tanh", "silu"]) -@pytest.mark.parametrize("M, N, K", get_fewer_x_vals()) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("output", [True, False]) -def test_fused_gemm_a16w16_copy_x_activation( - M: int, N: int, K: int, dtype, output, activation -): - x, w, _, _, y = generate_gemm_a16w16_inputs(M, N, K, dtype, output=output) - - torch_y = F.linear(x, w, bias=None) - if activation == "gelu": - torch_y = F.gelu(torch_y) - elif activation == "gelu_tanh": - torch_y = F.gelu(torch_y, approximate="tanh") - elif activation == "silu": - torch_y = F.silu(torch_y) - torch_x_copy = x.to(dtypes.fp8) - - triton_y, triton_x_copy = fused_gemm_a16w16_copy_x( - x, - w, - bias=None, - dtype=dtype, - y=y, - activation=activation, - ) - - torch.testing.assert_close(triton_y, torch_y, atol=1e-1, rtol=1e-2) - torch.testing.assert_close( - triton_x_copy.to(torch.bfloat16), - torch_x_copy.to(torch.bfloat16), - atol=0.0, - rtol=0.0, - ) - - -@pytest.mark.parametrize("M, N, K", get_fewer_x_vals()) -@pytest.mark.parametrize("skip_reduce", [True, False]) -def test_fused_gemm_a16w16_copy_x_skip_reduce(M: int, N: int, K: int, skip_reduce): - torch.cuda.empty_cache() - x, w, _, _, _ = generate_gemm_a16w16_inputs( - M, N, K, dtype=torch.bfloat16, output=False - ) - - torch_y = F.linear(x, w, bias=None) - torch_x_copy = x.to(dtypes.fp8) - - triton_y, triton_x_copy = fused_gemm_a16w16_copy_x(x, w, skip_reduce=skip_reduce) - - if triton_y.dim() == 3: - triton_y = triton_y.sum(axis=0).to(torch.bfloat16) - - torch.testing.assert_close(triton_y, torch_y, atol=1e-1, rtol=1e-2) - torch.testing.assert_close( - 
triton_x_copy.to(torch.bfloat16), - torch_x_copy.to(torch.bfloat16), - atol=0.0, - rtol=0.0, - ) diff --git a/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_quant_x.py b/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_quant_x.py new file mode 100644 index 0000000000..2a7ccf5624 --- /dev/null +++ b/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_quant_x.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import torch +import torch.nn.functional as F +import pytest + +from aiter.ops.triton.gemm.fused.fused_gemm_a16w16_quant_x import ( + fused_gemm_a16w16_quant_x, +) +from op_tests.triton_tests.gemm.basic.test_gemm_a16w16 import ( + generate_gemm_a16w16_inputs, +) + +_QUANT_BLOCK_SIZE = 32 +# 0xFF800000 in two's complement int32. Mask keeps sign + 8-bit exponent; clears the mantissa. +_E8M0_MASK_INT32 = -8388608 + + +def torch_mxfp8_quant_from_fp32(x_fp32: torch.Tensor): + """Bit-faithful port of `_mxfp8_quant_kernel` quant logic, taking fp32 input. + + Computes per-1x32 e8m0 scale (uint8) and FP8 e4m3fn values. 
+ """ + assert x_fp32.dim() == 2, f"x_fp32 must be 2D, got {x_fp32.dim()}" + M, K = x_fp32.shape + assert K % _QUANT_BLOCK_SIZE == 0 + Ng = K // _QUANT_BLOCK_SIZE + x_2d = x_fp32.reshape(M, Ng, _QUANT_BLOCK_SIZE).to(torch.float32) + amax = torch.amax(torch.abs(x_2d), dim=-1, keepdim=True) # (M, Ng, 1) + + amax_i32 = amax.contiguous().view(torch.int32) + amax_i32 = (amax_i32 + 0x200000) & _E8M0_MASK_INT32 + amax_p2 = amax_i32.view(torch.float32) + + scale_unbiased = torch.log2(amax_p2).floor() - 8 + scale_unbiased = torch.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(torch.int32) + 127).to(torch.uint8) + quant_scale = torch.exp2(-scale_unbiased) + + qx_2d = x_2d * quant_scale + qx = qx_2d.reshape(M, K) + y_fp8 = qx.to(torch.float8_e4m3fn) + s = scale_e8m0.reshape(M, Ng) + return y_fp8, s + + +def get_x_vals(): + x_vals = [(1024, 1024, 1024)] + x_vals += [(2048, 2048, 2048)] + # DSv4 router gate: num_tokens x 384 x 7168 + x_vals += [(2**i, 384, 7168) for i in range(5, 9)] + # DSR1 router GEMM + x_vals += [(2**i, 256, 7168) for i in range(5, 9)] + return x_vals + + +def _assert_quant_close(triton_x_quant, triton_x_scales, x): + ref_x_quant, ref_x_scales = torch_mxfp8_quant_from_fp32(x.to(torch.float32)) + # e8m0 scales: bit-exact (integer-only after fp32 cast). + torch.testing.assert_close(triton_x_scales, ref_x_scales) + # Quantized values: compare via uint8 view (allow off-by-1 for any rounding + # subtlety in the fp32->fp8 cast). 
+ torch.testing.assert_close( + triton_x_quant.view(torch.uint8).to(torch.int32), + ref_x_quant.view(torch.uint8).to(torch.int32), + atol=1, + rtol=0, + ) + + +@pytest.mark.parametrize("M, N, K", get_x_vals()) +def test_fused_gemm_a16w16_quant_x(M: int, N: int, K: int): + torch.cuda.empty_cache() + x, w, _, _, _ = generate_gemm_a16w16_inputs( + M, N, K, dtype=torch.bfloat16, output=False + ) + + torch_y = F.linear(x, w, bias=None) + + triton_y, triton_x_quant, triton_x_scales = fused_gemm_a16w16_quant_x(x, w) + + torch.testing.assert_close(triton_y, torch_y, atol=1e-1, rtol=1e-2) + _assert_quant_close(triton_x_quant, triton_x_scales, x) + + +def get_fewer_x_vals(): + x_vals = [(16, 1024, 1024)] + x_vals += [(128, 8192, 512)] + x_vals += [(256, 512, 8192)] + x_vals += [(1024, 1024, 1024)] + return x_vals + + +@pytest.mark.parametrize("activation", ["gelu", "gelu_tanh", "silu"]) +@pytest.mark.parametrize("M, N, K", get_fewer_x_vals()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("output", [True, False]) +def test_fused_gemm_a16w16_quant_x_activation( + M: int, N: int, K: int, dtype, output, activation +): + x, w, _, _, y = generate_gemm_a16w16_inputs(M, N, K, dtype, output=output) + + torch_y = F.linear(x, w, bias=None) + if activation == "gelu": + torch_y = F.gelu(torch_y) + elif activation == "gelu_tanh": + torch_y = F.gelu(torch_y, approximate="tanh") + elif activation == "silu": + torch_y = F.silu(torch_y) + + triton_y, triton_x_quant, triton_x_scales = fused_gemm_a16w16_quant_x( + x, + w, + bias=None, + dtype=dtype, + y=y, + activation=activation, + ) + + torch.testing.assert_close(triton_y, torch_y, atol=1e-1, rtol=1e-2) + _assert_quant_close(triton_x_quant, triton_x_scales, x) + + +@pytest.mark.parametrize("M, N, K", get_fewer_x_vals()) +@pytest.mark.parametrize("skip_reduce", [True, False]) +def test_fused_gemm_a16w16_quant_x_skip_reduce(M: int, N: int, K: int, skip_reduce): + torch.cuda.empty_cache() + x, w, 
_, _, _ = generate_gemm_a16w16_inputs( + M, N, K, dtype=torch.bfloat16, output=False + ) + + torch_y = F.linear(x, w, bias=None) + + triton_y, triton_x_quant, triton_x_scales = fused_gemm_a16w16_quant_x( + x, w, skip_reduce=skip_reduce + ) + + if triton_y.dim() == 3: + triton_y = triton_y.sum(axis=0).to(torch.bfloat16) + + torch.testing.assert_close(triton_y, torch_y, atol=1e-3, rtol=1e-2) + _assert_quant_close(triton_x_quant, triton_x_scales, x)