diff --git a/aiter/ops/triton/_triton_kernels/fusions/fused_clamp_act_mul.py b/aiter/ops/triton/_triton_kernels/fusions/fused_clamp_act_mul.py index f5887fdcb5..b7d0c62c1a 100644 --- a/aiter/ops/triton/_triton_kernels/fusions/fused_clamp_act_mul.py +++ b/aiter/ops/triton/_triton_kernels/fusions/fused_clamp_act_mul.py @@ -23,6 +23,7 @@ [ "BLOCK_SIZE_N", "QUANT_BLOCK_SIZE", + "SCALE_FMT", "HAVE_WEIGHTS", "WEIGHT_BROADCAST", "HAVE_SWIGLU_CLAMP", @@ -50,6 +51,7 @@ def _fused_clamp_silu_mul_kernel( swiglu_limit, BLOCK_SIZE_N: tl.constexpr, QUANT_BLOCK_SIZE: tl.constexpr, + SCALE_FMT: tl.constexpr, DTYPE_MAX: tl.constexpr, DTYPE_MIN: tl.constexpr, HAVE_WEIGHTS: tl.constexpr, @@ -57,6 +59,8 @@ def _fused_clamp_silu_mul_kernel( HAVE_SWIGLU_CLAMP: tl.constexpr, HAS_QUANT: tl.constexpr, ACTIVATION: tl.constexpr, + SHUFFLE: tl.constexpr, + SCALE_N_PAD: tl.constexpr, ): m_pid = tl.program_id(0) n_offs = tl.arange(0, BLOCK_SIZE_N) @@ -95,11 +99,34 @@ def _fused_clamp_silu_mul_kernel( out = out * w if HAS_QUANT: - out_q, block_scales = _fp8_quant_op( - out, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN - ) - out_q = tl.ravel(out_q) - block_scales = tl.ravel(block_scales) + if SCALE_FMT == "ue8m0": + # Per-1×QUANT_BLOCK_SIZE MXFP8 emit: fp8 e4m3 values + uint8 ue8m0 + # biased-exponent scales. Mirrors the ue8m0 path used by moe_gemm_a8w4. + NUM_QB: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE + out_3d = tl.reshape(out, [1, NUM_QB, QUANT_BLOCK_SIZE]) + abs_3d = tl.abs(out_3d) + max_val = tl.max(abs_3d, axis=2, keep_dims=True) + dequant_scale = max_val / DTYPE_MAX + # ROUND_UP via exponent: 2 ** ceil(log2(dequant_scale)) + dequant_scale_exp = ( + dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF + ) & 0x7F800000 + dequant_scale_rounded = dequant_scale_exp.to(tl.float32, bitcast=True) + quant_scale = tl.where( + dequant_scale_rounded == 0, 0.0, 1.0 / dequant_scale_rounded + ) + quant_tensor = out_3d * quant_scale + quant_2d = tl.reshape(quant_tensor, [1, BLOCK_SIZE_N]) + out_q = tl.ravel(quant_2d) + scale_exp = (dequant_scale_exp >> 23).to(tl.uint8) + scale_exp_2d = tl.reshape(scale_exp, [1, NUM_QB]) + block_scales = tl.ravel(scale_exp_2d) + else: + out_q, block_scales = _fp8_quant_op( + out, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN + ) + out_q = tl.ravel(out_q) + block_scales = tl.ravel(block_scales) tl.store( out_ptr + m_pid * out_stride_m + n_offs * out_stride_n, @@ -108,13 +135,36 @@ def _fused_clamp_silu_mul_kernel( ) num_bs = tl.cdiv(n_half, QUANT_BLOCK_SIZE) - NUM_QB: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE - g_offs = tl.arange(0, NUM_QB) - tl.store( - scale_ptr + m_pid * scale_stride_m + g_offs * scale_stride_n, - block_scales.to(scale_ptr.dtype.element_ty), - mask=g_offs < num_bs, - ) + NUM_QB_S: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE + g_offs = tl.arange(0, NUM_QB_S) + if SHUFFLE: + bs_offs_0 = m_pid // 32 + bs_offs_1 = m_pid % 32 + bs_offs_2 = bs_offs_1 % 16 + bs_offs_1 = bs_offs_1 // 16 + bs_offs_3 = g_offs // 8 + bs_offs_4 = g_offs % 8 + bs_offs_5 = bs_offs_4 % 4 + bs_offs_4 = bs_offs_4 // 4 + bs_offs = ( + bs_offs_1 + + bs_offs_4 * 2 + + bs_offs_2 * 2 * 2 + + bs_offs_5 * 2 * 2 * 16 + + bs_offs_3 * 2 * 2 * 16 * 4 + + bs_offs_0 * 2 * 16 * SCALE_N_PAD + ) + tl.store( + scale_ptr + bs_offs, + block_scales.to(scale_ptr.dtype.element_ty), + mask=g_offs < num_bs, + ) + else: + tl.store( + scale_ptr + m_pid * scale_stride_m + g_offs * scale_stride_n, + block_scales.to(scale_ptr.dtype.element_ty), + mask=g_offs < num_bs, + ) else: tl.store( out_ptr + m_pid * out_stride_m + n_offs * out_stride_n, diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py new file mode 100644 index 0000000000..f9e28bbcea --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py @@ -0,0 +1,469 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl +from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr +from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd +from aiter.ops.triton.utils.gemm_config_utils import get_gemm_config + +_gemm_afp8wfp8_repr = make_kernel_repr( + "_gemm_afp8wfp8_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "num_warps", + "num_stages", + "waves_per_eu", + "matrix_instr_nonkdim", + "cache_modifier", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + ], +) + + +@triton.heuristics( + { + "EVEN_K": lambda args: (args["K"] % args["BLOCK_SIZE_K"] == 0), + } +) +@triton.jit(repr=_gemm_afp8wfp8_repr) +def _gemm_afp8wfp8_kernel( + a_ptr, + b_ptr, + c_ptr, + a_scales_ptr, + b_scales_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_ck, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bsn, + stride_bsk, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, + EVEN_K: tl.constexpr, + num_warps: tl.constexpr, + num_stages: tl.constexpr, + waves_per_eu: tl.constexpr, + matrix_instr_nonkdim: tl.constexpr, + cache_modifier: tl.constexpr, +): + """ + Kernel for computing the matmul C = A x B. + A and B inputs are FP8 e4m3 (1 byte per element). + A_scales are e8m0 (uint8) with shape (M, K // 32). + B_scales are stored compact e8m0 (uint8) with shape (N // 128, K // 128), + representing 128x128 weight blocks. Broadcast inside kernel to (N, K // 32). + A has shape (M, K), B has shape (K, N) and C has shape (M, N). + Output dtype is determined by c_ptr (bf16 or fp16). + When NUM_KSPLIT > 1, K is split into NUM_KSPLIT partitions of + SPLITK_BLOCK_SIZE elements and the partial result for partition pid_k is + written to c_ptr + pid_k * stride_ck; a downstream reduce kernel sums them. + """ + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + tl.assume(stride_asm > 0) + tl.assume(stride_ask > 0) + tl.assume(stride_bsk > 0) + tl.assume(stride_bsn > 0) + + GRID_MN = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + + pid_unified = tl.program_id(axis=0) + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + if NUM_KSPLIT == 1: + pid = remap_xcd(pid, GRID_MN, NUM_XCDS=8) + pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(pid_k >= 0) + + # Scale group sizes + SCALE_GROUP_SIZE: tl.constexpr = 32 # A: per 32 elements along K + B_SCALE_K_GROUP: tl.constexpr = 128 # B: per 128 along K + B_SCALE_N_GROUP: tl.constexpr = 128 # B: per 128 along N + + if (pid_k * SPLITK_BLOCK_SIZE) < K: + # K-block iteration range for this split (absolute block indices). + num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE, BLOCK_SIZE_K) + + # Create pointers for first block of A and B input matrices. The K + # offset is the absolute start of this split's K range. + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k_split = pid_k * SPLITK_BLOCK_SIZE + offs_k + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak + ) + b_ptrs = b_ptr + ( + offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn + ) + + # A-scale pointers: per-row (M) and per scale group (K // 32). Shift + # along the K-scale axis by the split's start in scale groups. + offs_ks_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + offs_ks_a_split = pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE) + offs_ks_a + a_scale_ptrs = ( + a_scales_ptr + + offs_am[:, None] * stride_asm + + offs_ks_a_split[None, :] * stride_ask + ) + + # B-scale pointers: compact (N // 128, K // 128) — broadcast inside the kernel + # Each scale covers a 128(N) x 128(K) block. Computed per-iteration below + # using absolute K (so split-K naturally addresses the right b-scale block). + offs_bsn = offs_bn // B_SCALE_N_GROUP # (BLOCK_SIZE_N,) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + offs_scale_k_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + + for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): + # K base for this iteration (in elements, absolute). + k_base = k * BLOCK_SIZE_K + + # ---- Load A scales (M, BLOCK_SIZE_K // 32) ---- + if EVEN_K: + a_scales = tl.load(a_scale_ptrs) + else: + a_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + a_scales = tl.load(a_scale_ptrs, mask=a_scale_mask, other=127) + + # ---- Load and broadcast B scales (BLOCK_SIZE_N, BLOCK_SIZE_K // 32) ---- + offs_bsk = ( + k_base + offs_scale_k_a * SCALE_GROUP_SIZE + ) // B_SCALE_K_GROUP # (BLOCK_SIZE_K // 32,) + b_scale_ptrs = ( + b_scales_ptr + + offs_bsn[:, None] * stride_bsn + + offs_bsk[None, :] * stride_bsk + ) + if EVEN_K: + b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) + else: + # OOB along K: load with the same mask as a-scales + b_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + b_scales = tl.load( + b_scale_ptrs, + mask=b_scale_mask, + other=127, + cache_modifier=cache_modifier, + ) + + # ---- Load A, B data ---- + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0 + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0, + cache_modifier=cache_modifier, + ) + + accumulator = tl.dot_scaled( + a, a_scales, "e4m3", b, b_scales, "e4m3", accumulator + ) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask + + c = accumulator.to(c_ptr.type.element_ty) + + # Write back the block of the output matrix C with masks. For + # NUM_KSPLIT > 1, each pid_k writes to a separate slab of c_ptr. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + c_ptrs = ( + c_ptr + + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :] + + pid_k * stride_ck + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +_gemm_afp8wfp8_preshuffle_repr = make_kernel_repr( + "_gemm_afp8wfp8_preshuffle_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "num_warps", + "num_stages", + "waves_per_eu", + "matrix_instr_nonkdim", + "cache_modifier", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + ], +) + + +@triton.heuristics( + { + "EVEN_K": lambda args: (args["K"] % args["BLOCK_SIZE_K"] == 0), + } +) +@triton.jit(repr=_gemm_afp8wfp8_preshuffle_repr) +def _gemm_afp8wfp8_preshuffle_kernel( + a_ptr, + b_ptr, + c_ptr, + a_scales_ptr, + b_scales_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bn, + stride_bk, + stride_ck, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bsn, + stride_bsk, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, + EVEN_K: tl.constexpr, + num_warps: tl.constexpr, + num_stages: tl.constexpr, + waves_per_eu: tl.constexpr, + matrix_instr_nonkdim: tl.constexpr, + cache_modifier: tl.constexpr, +): + """ + Preshuffle variant of _gemm_afp8wfp8_kernel. Weight tensor has been shuffled + via aiter.ops.shuffle.shuffle_weight(layout=(16, 16)) so that 16-row N tiles + are interleaved with their 32-col K chunks in storage. The kernel loads the + shuffled tile in storage order (BLOCK_SIZE_N // 16, BLOCK_SIZE_K * 16) then + reshape+permute+trans inside the kernel to restore logical (K, N) layout + before tl.dot_scaled. Scales remain in the unshuffled compact 128x128 layout. + When NUM_KSPLIT > 1, K is split into NUM_KSPLIT partitions of + SPLITK_BLOCK_SIZE elements; each pid_k writes to c_ptr + pid_k * stride_ck. + """ + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + tl.assume(stride_asm > 0) + tl.assume(stride_ask > 0) + tl.assume(stride_bsk > 0) + tl.assume(stride_bsn > 0) + + GRID_MN = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + + pid_unified = tl.program_id(axis=0) + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + if NUM_KSPLIT == 1: + pid = remap_xcd(pid, GRID_MN, NUM_XCDS=8) + pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(pid_k >= 0) + + SCALE_GROUP_SIZE: tl.constexpr = 32 # A: per-32 along K + B_SCALE_K_GROUP: tl.constexpr = 128 # B compact: per-128 along K + B_SCALE_N_GROUP: tl.constexpr = 128 # B compact: per-128 along N + + if (pid_k * SPLITK_BLOCK_SIZE) < K: + num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE, BLOCK_SIZE_K) + + # A pointers (offset by this split's K start). + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k_split = pid_k * SPLITK_BLOCK_SIZE + offs_k + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak + ) + + # B pointers for preshuffled layout. The shuffled storage is viewed as + # (N // 16, K * 16) elements. pid_n indexes BLOCK_SIZE_N // 16 N-tiles per + # step; the K dimension is expanded by 16x in byte addresses. The split + # offsets the K-byte axis by pid_k * SPLITK_BLOCK_SIZE * 16. + offs_bn_shuffle = pid_n * (BLOCK_SIZE_N // 16) + tl.arange( + 0, BLOCK_SIZE_N // 16 + ) + offs_k_shuffle_arr = tl.arange(0, BLOCK_SIZE_K * 16) + offs_k_shuffle = pid_k * SPLITK_BLOCK_SIZE * 16 + offs_k_shuffle_arr + b_ptrs = b_ptr + ( + offs_bn_shuffle[:, None] * stride_bn + offs_k_shuffle[None, :] * stride_bk + ) + + # A-scale pointers: per-row M, per 32-K group. Shift along the K-scale + # axis by the split's start in scale groups. + offs_ks_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + offs_ks_a_split = pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE) + offs_ks_a + a_scale_ptrs = ( + a_scales_ptr + + offs_am[:, None] * stride_asm + + offs_ks_a_split[None, :] * stride_ask + ) + + # B-scale pointers: compact (N // 128, K // 128). The N index needs the + # ORIGINAL (logical) row, not the shuffled row index. + offs_bn_logical = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bsn = offs_bn_logical // B_SCALE_N_GROUP + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + offs_scale_k_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + + for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): + k_base = k * BLOCK_SIZE_K # absolute K base + + # Load A scales. + if EVEN_K: + a_scales = tl.load(a_scale_ptrs) + else: + a_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + a_scales = tl.load(a_scale_ptrs, mask=a_scale_mask, other=127) + + # Load and broadcast B scales (computed from absolute K). + offs_bsk = (k_base + offs_scale_k_a * SCALE_GROUP_SIZE) // B_SCALE_K_GROUP + b_scale_ptrs = ( + b_scales_ptr + + offs_bsn[:, None] * stride_bsn + + offs_bsk[None, :] * stride_bsk + ) + if EVEN_K: + b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) + else: + b_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + b_scales = tl.load( + b_scale_ptrs, + mask=b_scale_mask, + other=127, + cache_modifier=cache_modifier, + ) + + # Load A and B (preshuffled). + if EVEN_K: + a = tl.load(a_ptrs) + b_shuf = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0 + ) + b_shuf = tl.load( + b_ptrs, + mask=offs_k_shuffle_arr[None, :] < (K - k * BLOCK_SIZE_K) * 16, + other=0, + cache_modifier=cache_modifier, + ) + + # Unshuffle B in-kernel. Inverse of the shuffle_weight permute: + # shuffle: (N // 16, 16, K // 32, 2, 16) --[perm 0,1,3,4,2,5]-> (N // 16, K // 32, 2, 16, 16) + # unshuffle: (N // 16, K // 32, 2, 16, 16) --[perm 0,1,4,2,3,5]-> (N // 16, 16, K // 32, 2, 16) + # then flatten to (N, K) and trans to (K, N). + b = ( + b_shuf.reshape( + 1, + BLOCK_SIZE_N // 16, + BLOCK_SIZE_K // 32, + 2, + 16, + 16, + ) + .permute(0, 1, 4, 2, 3, 5) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K) + .trans(1, 0) + ) + + accumulator = tl.dot_scaled( + a, a_scales, "e4m3", b, b_scales, "e4m3", accumulator + ) + + # Advance pointers. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * 16 * stride_bk + a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask + + c = accumulator.to(c_ptr.type.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + c_ptrs = ( + c_ptr + + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :] + + pid_k * stride_ck + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def _get_config( + M: int, + N: int, + K: int, + shuffle: bool = False, +): + if shuffle: + return get_gemm_config("GEMM-AFP8WFP8_PRESHUFFLED", M, N, K) + else: + return get_gemm_config("GEMM-AFP8WFP8", M, N, K) diff --git a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_quant_x.py b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_quant_x.py new file mode 100644 index 0000000000..13459fa76c --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_quant_x.py @@ -0,0 +1,258 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton.language as tl +from aiter.ops.triton._triton_kernels.quant.quant import _mxfp8_quant_op +from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr +from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd +from aiter.ops.triton.utils.gemm_config_utils import ( + compute_splitk_params, + get_gemm_config, +) + +import triton + +_fused_gemm_a16w16_quant_x_repr = make_kernel_repr( + "_fused_gemm_a16w16_quant_x_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "QUANT_BLOCK_SIZE", + "EVEN_K", + "EVEN_MN", + "cache_modifier", + "activation", + "use_activation", + "ADD_BIAS", + "SKIP_REDUCE", + ], +) + + +@triton.heuristics( + { + "EVEN_K": lambda args: (args["K"] % (args["SPLITK_BLOCK_SIZE"]) == 0) + and (args["SPLITK_BLOCK_SIZE"] % args["BLOCK_SIZE_K"] == 0), + "EVEN_MN": lambda args: (args["M"] % args["BLOCK_SIZE_M"] == 0) + and (args["N"] % args["BLOCK_SIZE_N"] == 0), + } +) +@triton.jit( + repr=_fused_gemm_a16w16_quant_x_repr, + do_not_specialize=["M", "N"], +) +def _fused_gemm_a16w16_quant_x_kernel( + a_ptr, + b_ptr, + bias_ptr, + c_ptr, + a_quant_ptr, + a_scale_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_ck, + stride_cm, + stride_cn, + stride_a_quant_m, + stride_a_quant_k, + stride_a_scale_m, + stride_a_scale_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + EVEN_K: tl.constexpr, + EVEN_MN: tl.constexpr, + cache_modifier: tl.constexpr, + activation: tl.constexpr, + use_activation: tl.constexpr, + ADD_BIAS: tl.constexpr, + SKIP_REDUCE: tl.constexpr, +): + """Kernel that computes C = A x B and also emits an MXFP8-quantized A. + + The grid is laid out as a single 1D program-id space split into two + contiguous regions: + + * `[0, NUM_KSPLIT * num_pid_m * num_pid_n)` runs the GEMM (identical to + the unfused a16w16 kernel). + * `[NUM_KSPLIT * num_pid_m * num_pid_n, GEMM_GRID + num_pid_m * num_pid_k_copy)` + runs the per-1x32 MXFP8 quantization of A: each program handles one + `BLOCK_SIZE_M x BLOCK_SIZE_K` tile of A, derives per-1x32 e8m0 scales, + and writes both the FP8 values and the uint8 scales. + + BLOCK_SIZE_K must be a multiple of QUANT_BLOCK_SIZE (=32) so that each tile + contains a whole number of MXFP8 groups per row. + """ + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_ck > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + tl.assume(stride_a_quant_m > 0) + tl.assume(stride_a_quant_k > 0) + tl.assume(stride_a_scale_m > 0) + tl.assume(stride_a_scale_n > 0) + + pid_unified = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k_copy = tl.cdiv(K, BLOCK_SIZE_K) + GEMM_GRID = num_pid_m * num_pid_n * NUM_KSPLIT + + if pid_unified < GEMM_GRID: + # ---- GEMM branch ---------------------------------------------------- + pid_unified = remap_xcd(pid_unified, GEMM_GRID, NUM_XCDS=8) + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT + + if NUM_KSPLIT == 1: + pid_m, pid_n = pid_grid( + pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M + ) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(pid_k >= 0) + + split_k_start = pid_k * SPLITK_BLOCK_SIZE + if split_k_start < K: + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k_split = split_k_start + offs_k + if EVEN_MN: + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + else: + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak + ) + b_ptrs = b_ptr + ( + offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn + ) + + acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32 + if ADD_BIAS: + if NUM_KSPLIT == 1 or (SKIP_REDUCE and pid_k == 0): + accumulator = tl.load(bias_ptr + offs_bn).to(dtype=acc_dtype) + accumulator = tl.broadcast_to( + accumulator[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N) + ) + else: + accumulator = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype + ) + else: + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + split_k_end = tl.minimum(split_k_start + SPLITK_BLOCK_SIZE, K) + k_span = split_k_end - split_k_start + num_k_iter = tl.cdiv(k_span, BLOCK_SIZE_K) + + for k in range(num_k_iter): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < k_span - k * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < k_span - k * BLOCK_SIZE_K, + other=0.0, + cache_modifier=cache_modifier, + ) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if use_activation and NUM_KSPLIT == 1: + accumulator = activation(accumulator) + + c = accumulator.to(c_ptr.type.element_ty) + offs_cm = pid_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = ( + c_ptr + + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :] + + pid_k * stride_ck + ) + if EVEN_MN: + tl.store(c_ptrs, c) + else: + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + else: + # ---- MXFP8 quant branch -------------------------------------------- + pid_copy = pid_unified - GEMM_GRID + pid_m = pid_copy // num_pid_k_copy + pid_k = pid_copy % num_pid_k_copy + + tl.assume(pid_m >= 0) + tl.assume(pid_k >= 0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + mask = (offs_m[:, None] < M) & (offs_k[None, :] < K) + a = tl.load(a_ptrs, mask=mask, other=0.0).to(tl.float32) + + # Per-1x32 MXFP8 quant. Group along K within this tile. + n_groups: tl.constexpr = BLOCK_SIZE_K // QUANT_BLOCK_SIZE + a_2d = tl.reshape(a, (BLOCK_SIZE_M, n_groups, QUANT_BLOCK_SIZE)) + scale_e8m0, quant_scale = _mxfp8_quant_op(a_2d, QUANT_AXIS=2) + + qa_2d = a_2d * quant_scale + qa = tl.reshape(qa_2d, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + a_quant_ptrs = a_quant_ptr + ( + offs_m[:, None] * stride_a_quant_m + offs_k[None, :] * stride_a_quant_k + ) + tl.store(a_quant_ptrs, qa.to(a_quant_ptr.type.element_ty), mask=mask) + + # Store scales: shape (M, K // QUANT_BLOCK_SIZE). + offs_s_n = pid_k * n_groups + tl.arange(0, n_groups) + scale_2d = tl.reshape(scale_e8m0, (BLOCK_SIZE_M, n_groups)) + a_scale_ptrs = a_scale_ptr + ( + offs_m[:, None] * stride_a_scale_m + offs_s_n[None, :] * stride_a_scale_n + ) + scale_mask = (offs_m[:, None] < M) & ( + offs_s_n[None, :] < (K // QUANT_BLOCK_SIZE) + ) + tl.store(a_scale_ptrs, scale_2d, mask=scale_mask) + + +def _get_config( + M: int, + N: int, + K: int, +): + # Use the same tuning portal as the unfused gemm_a16w16 — the extra + # MXFP8 quant is assumed not to shift the optimal config. + config, is_tunned = get_gemm_config("GEMM-A16W16", M, N, K) + return compute_splitk_params(config, K), is_tunned diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w4.py index b71090c143..a368dba2ce 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w4.py @@ -162,6 +162,15 @@ def _moe_gemm_a8w4( SPLIT_K: tl.constexpr, W_CACHE_MODIFIER: tl.constexpr, UPCAST_INDICES: tl.constexpr = False, + # Idea 1: fold per-1×32 MXFP8 (ue8m0 scale) group quant into write-back. + # When HAS_MX_OUT=True, Y is fp8 e4m3 and YMxScale receives uint8 ue8m0 + # exponents at [m, n // 32]. Eliminates the standalone `downcast_to_mxfp` + # launch between GEMM1 and GEMM2. Requires SPLIT_K==1 and OUT_BLOCK_N % 32 == 0. + YMxScale=None, + stride_y_mx_m=0, + stride_y_mx_n=0, + HAS_MX_OUT: tl.constexpr = False, + MX_OUT_DTYPE_MAX: tl.constexpr = 448.0, ): tl.assume(stride_y_k >= 0) tl.assume(stride_y_m >= 0) @@ -421,4 +430,45 @@ def _moe_gemm_a8w4( + offs_y_n.to(index_type)[None, :] * stride_y_n ) mask = mask_m[:, None] & mask_n[None, :] - tl.store(YPtrs, out, mask=mask) + if HAS_MX_OUT and SPLIT_K == 1: + # Per-1×32 MXFP8 group quant emit. Mirrors the ue8m0 path in + # `_fused_clamp_silu_mul_kernel`. OUT_BLOCK_N is the post-swiglu output + # block size (BLOCK_N if no swiglu, BLOCK_N//2 with apply_swiglu). + tl.static_assert( + OUT_BLOCK_N % 32 == 0, + "HAS_MX_OUT requires OUT_BLOCK_N % 32 == 0", + ) + NUM_QB: tl.constexpr = OUT_BLOCK_N // 32 + out_safe = tl.where(mask, out, 0.0) + out_3d = tl.reshape(out_safe, [BLOCK_M, NUM_QB, 32]) + abs_3d = tl.abs(out_3d) + max_val = tl.max(abs_3d, axis=2, keep_dims=True) + dequant_scale = max_val / MX_OUT_DTYPE_MAX + # ROUND_UP via exponent: 2 ** ceil(log2(dequant_scale)). + dequant_scale_exponent = ( + dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF + ) & 0x7F800000 + dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True) + quant_scale = tl.where( + dequant_scale_rounded == 0, 0.0, 1.0 / dequant_scale_rounded + ) + quant_tensor = out_3d * quant_scale + quant_2d = tl.reshape(quant_tensor, [BLOCK_M, OUT_BLOCK_N]) + tl.store(YPtrs, quant_2d.to(Y.dtype.element_ty), mask=mask) + # Extract biased exponent (top 8 bits >> 23) as uint8 ue8m0 scale. + scale_exp_3d = (dequant_scale_exponent >> 23).to(tl.uint8) + scale_exp_2d = tl.reshape(scale_exp_3d, [BLOCK_M, NUM_QB]) + offs_s_n = NUM_QB * pid_n + tl.arange(0, NUM_QB) + mask_s_n = offs_s_n < tl.cdiv(yN, 32) + YMxScalePtrs = ( + YMxScale + + (start_m + offs_y_m).to(index_type)[:, None] * stride_y_mx_m + + offs_s_n.to(index_type)[None, :] * stride_y_mx_n + ) + tl.store( + YMxScalePtrs, + scale_exp_2d, + mask=mask_m[:, None] & mask_s_n[None, :], + ) + else: + tl.store(YPtrs, out, mask=mask) diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_routing/expt_data.py b/aiter/ops/triton/_triton_kernels/moe/moe_routing/expt_data.py index ff0b32ac70..70fa3b6f88 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_routing/expt_data.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_routing/expt_data.py @@ -90,3 +90,37 @@ def _expt_data_compute_stage2_fused(expt_id, Hist, TileStart, TileInfo): return TileInfo += tl.load(TileStart + expt_id) tl.store(TileInfo, expt_id) + + +@triton.jit +def _expt_data_only_kernel( + Hist, + n_expts_tot, + TokenStart, + TileStart, + MDTileInfo, + max_num_tiles, + n_gates, + tile_dim_log2: tl.constexpr, + BLOCK: tl.constexpr, + EQUAL_BLOCK: tl.constexpr, +): + """Standalone stage1+stage2 launch — builds ExptData from a precomputed + histogram with no memset. Grid: (n_expts_tot,).""" + pid = tl.program_id(0) + + _expt_data_compute_stage1( + pid, + Hist, + n_expts_tot, + TokenStart, + TileStart, + MDTileInfo, + max_num_tiles, + n_gates, + tile_dim_log2, + BLOCK, + EQUAL_BLOCK, + ) + + _expt_data_compute_stage2(pid, Hist, TileStart, MDTileInfo, tile_dim_log2) diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py b/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py index b75562b091..539380f240 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py @@ -26,6 +26,25 @@ def key_to_fpval(x): return x ^ tl.where((x & tm) == 0, fm, tm) +@triton.jit +def _apply_score_mode(x, SCORE_MODE: tl.constexpr): + """Pre-transform raw logits before topk selection. + + SCORE_MODE values: + - "softmax": no pre-transform (caller may apply softmax to selected + values via APPLY_SOFTMAX flag). + - "sqrtsoftplus": x → sqrt(softplus(x)) using numerically stable + softplus(x) = max(x, 0) + log(1 + exp(-|x|)). Matches + torch.nn.functional.softplus for the DeepSeek-V4 sqrtsoftplus router. + """ + if SCORE_MODE == "sqrtsoftplus": + x_f = x.to(tl.float32) + softplus_x = tl.maximum(x_f, 0.0) + tl.log(1.0 + tl.exp(-tl.abs(x_f))) + return tl.sqrt(softplus_x).to(x.dtype) + # "softmax" (and default): identity + return x + + @triton.jit def streaming_topk( X, @@ -37,45 +56,78 @@ def streaming_topk( N_EXPTS_ACT: tl.constexpr, N_EXPTS_ACT_PAD: tl.constexpr, BLOCK_N: tl.constexpr, + Bias=None, + SCORE_MODE: tl.constexpr = "softmax", + HAS_BIAS: tl.constexpr = False, ): x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}") if x_nbits < 16: - # this ensures that we leave at least 16 bits for expert index - # even if the input dtype is smaller than 16 bits: y_nbits: tl.constexpr = 32 else: y_nbits: tl.constexpr = x_nbits * 2 x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}") x_dtype: tl.constexpr = X.dtype.element_ty - # subtract 1 from loop iterations because we peel the first (masked) iteration: loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1 offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N) mask_n = offs_x_n[None, :] < n_expts_tot - # first iteration: + # First iteration (peeled, may have masked lanes). For SCORE_MODE="softmax" + # (legacy), keep the exact original sequence (load with -inf placeholder, + # no transform). For SCORE_MODE="sqrtsoftplus", load with 0 placeholder, + # apply transform + optional bias, then explicitly mask invalid lanes to + # -inf — because the transform of -inf is NOT -inf (sqrt(softplus(-inf)) + # = 0) and would incorrectly win topk against small valid scores. X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :] - x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf")) + if SCORE_MODE == "softmax": + x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf")) + else: + x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=0.0) + x = _apply_score_mode(x, SCORE_MODE) + if HAS_BIAS: + bias_col_mask = offs_x_n < n_expts_tot + b = tl.load(Bias + offs_x_n, mask=bias_col_mask, other=0.0) + x = x + b[None, :].to(x_dtype) + x = tl.where(mask_m & mask_n, x, float("-inf")) x = fpval_to_key(x.to(x_utype, bitcast=True)) x = (x.to(x_ultype) << 16) | offs_x_n[None, :] acc = tl.topk(x, N_EXPTS_ACT_PAD, dim=1) - # subsequent iterations: + # subsequent iterations: full blocks within n_expts_tot, no col mask for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations): acc = tl.bitonic_merge(acc) # ensure sorted ascending for the merge X_ptrs -= BLOCK_N offs_x_n -= BLOCK_N - x = tl.load(X_ptrs, mask=mask_m, other=float("-inf")) + if SCORE_MODE == "softmax": + x = tl.load(X_ptrs, mask=mask_m, other=float("-inf")) + else: + x = tl.load(X_ptrs, mask=mask_m, other=0.0) + x = _apply_score_mode(x, SCORE_MODE) + if HAS_BIAS: + b = tl.load(Bias + offs_x_n) + x = x + b[None, :].to(x_dtype) + x = tl.where(mask_m, x, float("-inf")) x = fpval_to_key(x.to(x_utype, bitcast=True)) x = (x.to(x_ultype) << 16) | offs_x_n[None, :] acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT_PAD, dim=1)) + # Pre-existing bug fix: after the streaming merge loop, acc is not + # guaranteed to be sorted by value (tl.maximum of an ASC and the new + # tl.topk output is bitonic, not sorted). The mask `arange < K` only + # works if acc is already sorted descending by value with the top-K at + # positions 0..K-1. For K_PAD > K (e.g. K=6 with K_PAD=8), this drops + # arbitrary entries, including real top-K entries. Fix: sort by value + # ascending first, then mask the smallest (K_PAD - K) positions. + if N_EXPTS_ACT != N_EXPTS_ACT_PAD: + acc = tl.sort(acc, dim=1) # rotate expert index into upper 16 bits: # 0000vvvvvvvviiii --> iiii0000vvvvvvvv acc = (acc << (y_nbits - 16)) | (acc >> 16) if N_EXPTS_ACT != N_EXPTS_ACT_PAD: - mask_expts_act = tl.arange(0, N_EXPTS_ACT_PAD)[None, :] < N_EXPTS_ACT + mask_expts_act = tl.arange(0, N_EXPTS_ACT_PAD)[None, :] >= ( + N_EXPTS_ACT_PAD - N_EXPTS_ACT + ) acc = tl.where(mask_expts_act, acc, N_EXPTS_PAD << (y_nbits - 16)) # sort in ascending order of expert (descending order of key) acc = tl.sort(acc, dim=1) @@ -115,7 +167,20 @@ def _topk( N_EXPTS_ACT: tl.constexpr, N_EXPTS_ACT_PAD: tl.constexpr, BLOCK_N: tl.constexpr, + Bias=None, + SCORE_MODE: tl.constexpr = "softmax", + HAS_BIAS: tl.constexpr = False, + APPLY_RENORM: tl.constexpr = False, + ROUTED_SCALING: tl.constexpr = 1.0, ): + # Backward-compat sanity. APPLY_SOFTMAX = post-selection softmax over the + # K selected logits (legacy behavior). It only makes sense when no + # pre-transform is applied; for SCORE_MODE="sqrtsoftplus" the caller is + # expected to use APPLY_RENORM + ROUTED_SCALING instead. + tl.static_assert( + (not APPLY_SOFTMAX) or (SCORE_MODE == "softmax"), + "APPLY_SOFTMAX is only valid when SCORE_MODE='softmax'", + ) pid = tl.program_id(0) if isinstance(n_rows, tl.tensor) and n_rows.dtype.is_ptr(): @@ -130,14 +195,12 @@ def _topk( tl.store(SP + offs, tl.zeros([BLOCK_SP], tl.int32), mask=offs < sp_size) if pid * BLOCK_M >= n_rows: - # early exit: return tl.static_assert(BLOCK_N % 32 == 0) tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0) x_dtype: tl.constexpr = X.dtype.element_ty - # load logits offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M) offs_y_n = tl.arange(0, N_EXPTS_ACT_PAD) mask_m = offs_m[:, None] < n_rows @@ -151,13 +214,41 @@ def _topk( N_EXPTS_ACT, N_EXPTS_ACT_PAD, BLOCK_N, + Bias=Bias, + SCORE_MODE=SCORE_MODE, + HAS_BIAS=HAS_BIAS, + ) + + # Real-entry mask: padded (sentinel) entries are flagged by + # y_indices == N_EXPTS_PAD inside streaming_topk and have y_values = -inf. + # Post-selection ops (bias-subtract, renorm, scaling) must operate on + # real entries only — otherwise -inf poisons the renorm sum, and the + # sentinel y_indices (== N_EXPTS_PAD) would OOB the Bias array. + real_mask = ( + y_indices != N_EXPTS_PAD if N_EXPTS_ACT != N_EXPTS_ACT_PAD else (y_indices >= 0) ) + # For SCORE_MODE="sqrtsoftplus" with HAS_BIAS, the y_values returned by + # streaming_topk are biased scores (sqrt(softplus(x)) + bias) — used for + # selection. We want the unbiased sqrt(softplus(x)) values as the gathered + # weights (the "noaux_tc" pattern from V4). Subtract bias[y_indices]. + if SCORE_MODE == "sqrtsoftplus" and HAS_BIAS: + safe_idx = tl.where(real_mask, y_indices, 0).to(tl.int32) + b_at_idx = tl.load(Bias + safe_idx) + y_unbiased = y_values.to(tl.float32) - b_at_idx + y_values = tl.where(real_mask, y_unbiased, 0.0).to(x_dtype) + # normalize selected values if APPLY_SOFTMAX: y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to( x_dtype ) + elif APPLY_RENORM: + y_f = tl.where(real_mask, y_values.to(tl.float32), 0.0) + s = tl.sum(y_f, axis=1, keep_dims=True) + y_values = (y_f / (s + 1e-20) * ROUTED_SCALING).to(x_dtype) + elif ROUTED_SCALING != 1.0: + y_values = (y_values.to(tl.float32) * ROUTED_SCALING).to(x_dtype) # write back Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :] @@ -182,3 +273,156 @@ def _topk( r = tl.reduce_or(y2, axis=1) BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn tl.store(BitsPtrs, r, mask=mask_m) + + +@triton.jit +def _hash_routing( + InputIds, # int32 [n_rows] — token-id per row + Tid2Eid, # int32 [vocab_size, K] — per-token-id top-K expert table + stride_t2e_v, # row stride of Tid2Eid (= K) + X, # [n_rows, n_expts_tot] router logits (bf16/fp32) + stride_xm, + Yv, # output expt_scal [n_rows, N_EXPTS_ACT_PAD] + Yi, # output expt_indx [n_rows, N_EXPTS_ACT_PAD] (int16) + stride_ym, + Bits, # bitmatrix data + stride_rm, + stride_rn, + n_rows, + n_expts_tot, + S, # bitmatrix scratchpad (must memset to 0) + BLOCK_S: tl.constexpr, + s_blocks, + SP, # bitmatrix partials (must memset to 0) + BLOCK_SP: tl.constexpr, + sp_blocks, + sp_size, + BLOCK_M: tl.constexpr, + N_EXPTS_PAD: tl.constexpr, # padded n_expts_tot (power of 2 ≥ n_expts_tot) + N_EXPTS_ACT: tl.constexpr, + N_EXPTS_ACT_PAD: tl.constexpr, # next pow2 of K + BLOCK_N: tl.constexpr, + SCORE_MODE: tl.constexpr = "sqrtsoftplus", + APPLY_RENORM: tl.constexpr = True, + ROUTED_SCALING: tl.constexpr = 1.0, +): + """Fused hash routing for DeepSeek-V4 hash layers. + + Replaces _hash_topk (Python: tid2eid lookup + softplus + sqrt + gather + + renorm + scale) AND fused_routing_from_topk (3-kernel counting sort) AND + bitmatrix construction with ONE Triton kernel. Output contract matches + `_topk` so downstream `sort_tokens_fused` consumes it unchanged. + + Pipeline per row: + 1. expt_indx = Tid2Eid[input_id, :K] # tid2eid lookup + 2. raw_scores = sqrt(softplus(X[row, :])) # apply score transform + 3. expt_scal = raw_scores[expt_indx] # gather K weights + 4. (optional) renorm: expt_scal /= expt_scal.sum() ; clamp_min(1e-20) + 5. expt_scal *= routed_scaling_factor + 6. Pack expt_indx into bitmatrix. + + No topk selection — expt_indx is fully determined by Tid2Eid lookup. + """ + pid = tl.program_id(0) + + # Memset bitmatrix scratchpads (mirror _topk pattern) + if pid < s_blocks: + tl.store( + S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32) + ) + elif pid < s_blocks + sp_blocks: + offs = BLOCK_SP * (pid - s_blocks) + tl.arange(0, BLOCK_SP) + tl.store(SP + offs, tl.zeros([BLOCK_SP], tl.int32), mask=offs < sp_size) + + if pid * BLOCK_M >= n_rows: + return + + tl.static_assert(BLOCK_N % 32 == 0) + tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0) + x_dtype: tl.constexpr = X.dtype.element_ty + + offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = offs_m < n_rows + + # 1. Load input_ids for this BLOCK_M, then tid2eid[input_ids[i], :K] + input_ids = tl.load(InputIds + offs_m, mask=mask_m, other=0).to(tl.int32) + offs_k = tl.arange(0, N_EXPTS_ACT_PAD) + mask_k = ( + offs_k < N_EXPTS_ACT + if N_EXPTS_ACT != N_EXPTS_ACT_PAD + else tl.full([N_EXPTS_ACT_PAD], 1, tl.int1) + ) + # Gather Tid2Eid[input_ids[m], k] for m in BLOCK_M, k in K + t2e_offs = input_ids[:, None] * stride_t2e_v + offs_k[None, :] + expt_indx = tl.load( + Tid2Eid + t2e_offs, + mask=mask_m[:, None] & mask_k[None, :], + other=0, + ).to(tl.int32) + + # 2-3. Apply score transform to full row + gather at expt_indx. + # Streaming load of X row in BLOCK_N chunks; accumulate scores per expert + # at expt_indx (each row has K << n_expts_tot, so we test which chunk + # holds each expert index). + y_scores = tl.zeros([BLOCK_M, N_EXPTS_ACT_PAD], dtype=tl.float32) + loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N + for i in range(loop_iterations): + offs_x_n = i * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_x_n < n_expts_tot + X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :] + x = tl.load(X_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0.0).to( + tl.float32 + ) + # sqrt(softplus(x)) — numerically stable + if SCORE_MODE == "sqrtsoftplus": + softplus_x = tl.maximum(x, 0.0) + tl.log(1.0 + tl.exp(-tl.abs(x))) + scores = tl.sqrt(softplus_x) + else: + scores = x + # For each (m, k): if expt_indx[m, k] is in [i*BLOCK_N, (i+1)*BLOCK_N), pick scores[m, expt_indx[m, k] - i*BLOCK_N] + # Implement via expand: scores[m, n] vs expt_indx[m, k] mapping. + # gate_mask[m, k, n] = (expt_indx[m, k] == offs_x_n[n]) + match = expt_indx[:, :, None] == offs_x_n[None, None, :] + # picked[m, k] = sum_n match[m, k, n] * scores[m, n] + scores_3d = scores[:, None, :].broadcast_to(BLOCK_M, N_EXPTS_ACT_PAD, BLOCK_N) + picked = tl.where(match, scores_3d, 0.0) + y_scores += tl.sum(picked, axis=2) + + # Real-entry mask (handles N_EXPTS_ACT_PAD > N_EXPTS_ACT padding lanes) + real_mask = mask_k[None, :] + y_f = tl.where(mask_m[:, None] & real_mask, y_scores, 0.0) + + # 4-5. Renorm + scale + if APPLY_RENORM: + s = tl.sum(y_f, axis=1, keep_dims=True) + y_f = y_f / (s + 1e-20) * ROUTED_SCALING + elif ROUTED_SCALING != 1.0: + y_f = y_f * ROUTED_SCALING + + y_values = y_f.to(x_dtype) + + # Write outputs + Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_k[None, :] + Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_k[None, :] + write_mask = mask_m[:, None] & real_mask + tl.store(Yv_ptrs, y_values, mask=write_mask) + tl.store(Yi_ptrs, expt_indx, mask=write_mask) + + # Pack into bitmatrix (mirror _topk pattern; sentinels (indx=0 in padded + # lanes) safely OR with a real bit since y_values=0 in those lanes won't + # affect downstream — but to be safe, mask out padded lanes from packing). + safe_indx = tl.where(real_mask, expt_indx, 0).to(tl.int32) + y_div = safe_indx // 32 + y_rem = safe_indx % 32 + bm_iters: tl.constexpr = N_EXPTS_PAD // BLOCK_N + for i in range(bm_iters): + offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32) + # Only contribute bits from real lanes + y2 = tl.where( + (y_div[:, :, None] == offs_r_n[None, None, :]) & real_mask[:, :, None], + (1 << y_rem)[:, :, None], + 0, + ) + r = tl.reduce_or(y2, axis=1) + BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn + tl.store(BitsPtrs, r, mask=mask_m[:, None]) diff --git a/aiter/ops/triton/_triton_kernels/moe/reduce.py b/aiter/ops/triton/_triton_kernels/moe/reduce.py index 22d983a767..e318176181 100644 --- a/aiter/ops/triton/_triton_kernels/moe/reduce.py +++ b/aiter/ops/triton/_triton_kernels/moe/reduce.py @@ -27,6 +27,12 @@ def _reduce_grouped( EVEN_N: tl.constexpr, ADD_RESIDUAL: tl.constexpr, USE_TDM: tl.constexpr, + # Step 9: external residual fold-in. When HAS_EXT_RESIDUAL=True, + # Residual[token, :] is added to `acc` before the writeback. + Residual, + stride_extres_m: tl.uint64, + stride_extres_n, + HAS_EXT_RESIDUAL: tl.constexpr, ): pid = tl.program_id(0) pid_t = pid // num_blocks @@ -76,6 +82,18 @@ def _reduce_grouped( # Compute per-32-col MXFP scales for this tile if requested Nrem = N // ACTIVATION_REDUCTION_N + # Step 9: optional external residual fold-in: load residual at this + # tile and add to acc before writeback. Same per-token-row layout as Out. + if HAS_EXT_RESIDUAL: + res_offs_n = pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT) + res_ptr = Residual + pid_t * stride_extres_m + res_offs_n * stride_extres_n + if EVEN_N: + res = tl.load(res_ptr).to(tl.float32) + acc = acc + res + else: + res_mask = res_offs_n < Nrem + res = tl.load(res_ptr, mask=res_mask, other=0.0).to(tl.float32) + acc = acc + res # write-back for this tile out_ptr = OutPtrs + pid_t * stride_om if EVEN_N: diff --git a/aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py b/aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py index 5acc36b81f..4f02dbf4e5 100644 --- a/aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py +++ b/aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py @@ -287,10 +287,16 @@ def _fused_flatten_fp8_group_quant_kernel( n1 = tl.program_id(1) NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N2 // QUANT_BLOCK_SIZE + # In the flattened (M, N1 * N2) output, each n1 segment is exactly N2 wide + # (not BLOCK_SIZE_N2), so stride between n1 segments must use N2 — otherwise + # non-power-of-2 N2 (e.g. 7168) over-strides the output (and at the last n1 + # walks past the row boundary, causing OOB writes). + n2_groups = tl.cdiv(N2, QUANT_BLOCK_SIZE) n2_offs = tl.arange(0, BLOCK_SIZE_N2) + x_mask = n2_offs < N2 x_offs = m * x_stride_m + n1 * x_stride_n1 + n2_offs * x_stride_n2 - x = tl.load(x_ptr + x_offs, mask=n2_offs < N2) + x = tl.load(x_ptr + x_offs, mask=x_mask, other=0.0) out, out_block_scales = _fp8_quant_op( x, 1, BLOCK_SIZE_N2, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN @@ -299,17 +305,17 @@ def _fused_flatten_fp8_group_quant_kernel( out_block_scales = tl.ravel(out_block_scales) tl.store( - out_ptr + m * out_stride_m + (n1 * BLOCK_SIZE_N2 + n2_offs) * out_stride_n, + out_ptr + m * out_stride_m + (n1 * N2 + n2_offs) * out_stride_n, out.to(out_ptr.dtype.element_ty), - mask=n2_offs < N2, + mask=x_mask, ) block_scale_offs = tl.arange(0, NUM_QUANT_BLOCKS) tl.store( out_scales_ptr + m * out_scales_stride_m - + (n1 * NUM_QUANT_BLOCKS + block_scale_offs) * out_scales_stride_n, + + (n1 * n2_groups + block_scale_offs) * out_scales_stride_n, out_block_scales.to(out_scales_ptr.dtype.element_ty), - mask=block_scale_offs < tl.cdiv(N2, QUANT_BLOCK_SIZE), + mask=block_scale_offs < n2_groups, ) diff --git a/aiter/ops/triton/_triton_kernels/quant/fused_mxfp8_quant.py b/aiter/ops/triton/_triton_kernels/quant/fused_mxfp8_quant.py new file mode 100644 index 0000000000..5b63a01f0e --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/quant/fused_mxfp8_quant.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl + +from .quant import _mxfp8_quant_op + +# Fused RMSNorm + MXFP8 (1x32 e8m0) quant. Replaces the separate +# rmsnorm_quant(fp8 fnuz + fp32 1x128) + transcode-to-MXFP8 sequence used +# upstream of MXFP8-aware GEMMs (e.g. V4 q_norm -> wq_b). +# +# One program per row. Holds the full row in registers, so K is constrained +# by the BLOCK_SIZE_K constexpr (must be a power of two >= K). +# +# In: x (M, K) bf16 or fp16 +# g (K,) bf16 or fp16 weight +# Out: y (M, K) fp8 e4m3fn +# scale (M, K // 32) uint8 e8m0 + + +@triton.jit +def _fused_rms_mxfp8_kernel( + x_ptr, + g_ptr, + y_ptr, + s_ptr, + M, + K, + stride_xm, + stride_xk, + stride_ym, + stride_yk, + stride_sm, + stride_sn, + epsilon, + BLOCK_SIZE_K: tl.constexpr, # power-of-2 covering full K + QUANT_BLOCK_SIZE: tl.constexpr, # =32 + NUM_PRGMS: tl.constexpr, # for persistent-loop variant; usually =M +): + """One program processes one row: rmsnorm then MXFP8 quant in registers.""" + row_start = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE_K) + mask = col_offsets < K + + for row_idx in tl.range(row_start, M, NUM_PRGMS, num_stages=2): + # Load full row, cast to fp32 + x = tl.load( + x_ptr + row_idx * stride_xm + col_offsets * stride_xk, + mask=mask, + other=0.0, + ).to(tl.float32) + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + + # RMS norm + ss = tl.sum(x * x, axis=-1) + norm_factor = tl.math.rsqrt((ss / K) + epsilon) + y_fp32 = x * norm_factor * g # (BLOCK_SIZE_K,) + + # Reshape into (K // QUANT_BLOCK_SIZE, QUANT_BLOCK_SIZE) groups for amax. + # BLOCK_SIZE_K is the power-of-2 padded size; we keep OOB lanes masked to 0 + # via the load above, so amax over them is 0 (won't affect the in-bounds max). + y_2d = tl.reshape(y_fp32, (BLOCK_SIZE_K // QUANT_BLOCK_SIZE, QUANT_BLOCK_SIZE)) + scale_e8m0, quant_scale = _mxfp8_quant_op(y_2d, QUANT_AXIS=1) + + # Quantize: y_quant = y_fp32 * quant_scale (broadcast along inner 32). + qx_2d = y_2d * quant_scale + qx = tl.reshape(qx_2d, (BLOCK_SIZE_K,)) + y_fp8 = qx.to(y_ptr.type.element_ty) + + # Store y (mask OOB). + tl.store( + y_ptr + row_idx * stride_ym + col_offsets * stride_yk, + y_fp8, + mask=mask, + ) + + # Store scales: G entries for this row. + n_groups: tl.constexpr = BLOCK_SIZE_K // QUANT_BLOCK_SIZE + group_offsets = tl.arange(0, n_groups) + group_mask = group_offsets < (K // QUANT_BLOCK_SIZE) + scale_flat = tl.reshape(scale_e8m0, (n_groups,)) + tl.store( + s_ptr + row_idx * stride_sm + group_offsets * stride_sn, + scale_flat, + mask=group_mask, + ) + + +# Dual fused RMSNorm: Q-side (MXFP8 quant + e8m0 scale emit) + K-side (bf16 out). +# Replaces the CK `fused_qk_rmsnorm_group_quant` semantics in one Triton launch +# for the MXFP8 GEMM path (Task #77). The two halves are independent (different +# weight, different K dim) so they're packed into one program per row to amortize +# launch overhead: same kernel launch loads both rows, normalizes both, stores Q +# fp8 + scale, stores K bf16. Each row's Q and K are independently RMSNorm'd +# (separate weights, separate eps, separate K dim) -- this kernel does NOT fuse +# their normalization arithmetic, only their launch. +# +# In: q (M, KQ) bf16 or fp16 +# kv (M, KK) bf16 or fp16 +# gq (KQ,) bf16 or fp16 Q-RMSNorm weight +# gk (KK,) bf16 or fp16 K-RMSNorm weight +# Out: yq (M, KQ) fp8 e4m3fn +# sq (M, KQ // 32) uint8 e8m0 +# yk (M, KK) bf16 + + +@triton.jit +def _fused_dual_rmsnorm_mxfp8_quant_kernel( + q_ptr, + k_ptr, + gq_ptr, + gk_ptr, + yq_ptr, + sq_ptr, + yk_ptr, + M, + KQ, + KK, + stride_qm, + stride_qn, + stride_km, + stride_kn, + stride_yqm, + stride_yqn, + stride_sqm, + stride_sqn, + stride_ykm, + stride_ykn, + eps_q, + eps_k, + BLOCK_SIZE_KQ: tl.constexpr, # power-of-2 covering full KQ + BLOCK_SIZE_KK: tl.constexpr, # power-of-2 covering full KK + QUANT_BLOCK_SIZE: tl.constexpr, # =32 (MXFP8 group size) + NUM_PRGMS: tl.constexpr, # row-loop bound (usually =M) +): + """One program per row: do Q-side RMSNorm+MXFP8 quant AND K-side RMSNorm + (bf16 out) in one launch. Mirrors the CK `fused_qk_rmsnorm_group_quant` + fusion topology but emits MXFP8 1x32 (e8m0) scales for Q directly.""" + row_start = tl.program_id(0) + + q_col_offsets = tl.arange(0, BLOCK_SIZE_KQ) + q_mask = q_col_offsets < KQ + k_col_offsets = tl.arange(0, BLOCK_SIZE_KK) + k_mask = k_col_offsets < KK + + n_q_groups: tl.constexpr = BLOCK_SIZE_KQ // QUANT_BLOCK_SIZE + + for row_idx in tl.range(row_start, M, NUM_PRGMS, num_stages=2): + # ===== Q side: RMSNorm + MXFP8 quant ===== + x_q = tl.load( + q_ptr + row_idx * stride_qm + q_col_offsets * stride_qn, + mask=q_mask, + other=0.0, + ).to(tl.float32) + g_q = tl.load(gq_ptr + q_col_offsets, mask=q_mask, other=0.0).to(tl.float32) + + ss_q = tl.sum(x_q * x_q, axis=-1) + norm_q = tl.math.rsqrt((ss_q / KQ) + eps_q) + y_q_fp32 = x_q * norm_q * g_q + + y_q_2d = tl.reshape(y_q_fp32, (n_q_groups, QUANT_BLOCK_SIZE)) + scale_q_e8m0, quant_scale_q = _mxfp8_quant_op(y_q_2d, QUANT_AXIS=1) + + qx_q_2d = y_q_2d * quant_scale_q + qx_q = tl.reshape(qx_q_2d, (BLOCK_SIZE_KQ,)) + y_q_fp8 = qx_q.to(yq_ptr.type.element_ty) + + tl.store( + yq_ptr + row_idx * stride_yqm + q_col_offsets * stride_yqn, + y_q_fp8, + mask=q_mask, + ) + + q_group_offsets = tl.arange(0, n_q_groups) + q_group_mask = q_group_offsets < (KQ // QUANT_BLOCK_SIZE) + scale_q_flat = tl.reshape(scale_q_e8m0, (n_q_groups,)) + tl.store( + sq_ptr + row_idx * stride_sqm + q_group_offsets * stride_sqn, + scale_q_flat, + mask=q_group_mask, + ) + + # ===== K side: RMSNorm only, bf16 out ===== + x_k = tl.load( + k_ptr + row_idx * stride_km + k_col_offsets * stride_kn, + mask=k_mask, + other=0.0, + ).to(tl.float32) + g_k = tl.load(gk_ptr + k_col_offsets, mask=k_mask, other=0.0).to(tl.float32) + + ss_k = tl.sum(x_k * x_k, axis=-1) + norm_k = tl.math.rsqrt((ss_k / KK) + eps_k) + y_k_fp32 = x_k * norm_k * g_k + y_k_out = y_k_fp32.to(yk_ptr.type.element_ty) + + tl.store( + yk_ptr + row_idx * stride_ykm + k_col_offsets * stride_ykn, + y_k_out, + mask=k_mask, + ) + + +# Flatten-then-MXFP8 quant. Takes (M, N1, N2) input, flattens the trailing two +# dims into N = N1 * N2, and emits per-1x32 MXFP8 (FP8 e4m3fn values + uint8 +# e8m0 scales) along the flattened axis. One program per (m, n1); each program +# handles a row of N2 elements that contributes BLOCK_SIZE_N2 // 32 groups to +# the M-th row of the (M, N) flattened output. + + +@triton.jit +def _fused_flatten_mxfp8_quant_kernel( + x_ptr, + out_ptr, + out_scales_ptr, + x_stride_m, + x_stride_n1, + x_stride_n2, + out_stride_m, + out_stride_n, + out_scales_stride_m, + out_scales_stride_n, + N2, + BLOCK_SIZE_N2: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, +): + m = tl.program_id(0) + n1 = tl.program_id(1) + + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N2 // QUANT_BLOCK_SIZE + # In the flattened (M, N1 * N2) output, each n1 segment is exactly N2 wide + # (not BLOCK_SIZE_N2), so stride between n1 segments must use N2 — otherwise + # non-power-of-2 N2 (e.g. 7168) would gap-write the output. + n2_groups = N2 // QUANT_BLOCK_SIZE + + n2_offs = tl.arange(0, BLOCK_SIZE_N2) + x_mask = n2_offs < N2 + x_offs = m * x_stride_m + n1 * x_stride_n1 + n2_offs * x_stride_n2 + x = tl.load(x_ptr + x_offs, mask=x_mask, other=0.0).to(tl.float32) + + x_2d = tl.reshape(x, (NUM_QUANT_BLOCKS, QUANT_BLOCK_SIZE)) + scale_e8m0, quant_scale = _mxfp8_quant_op(x_2d, QUANT_AXIS=1) + + qx_2d = x_2d * quant_scale + qx = tl.reshape(qx_2d, (BLOCK_SIZE_N2,)) + tl.store( + out_ptr + m * out_stride_m + (n1 * N2 + n2_offs) * out_stride_n, + qx.to(out_ptr.type.element_ty), + mask=x_mask, + ) + + block_scale_offs = tl.arange(0, NUM_QUANT_BLOCKS) + scale_flat = tl.reshape(scale_e8m0, (NUM_QUANT_BLOCKS,)) + tl.store( + out_scales_ptr + + m * out_scales_stride_m + + (n1 * n2_groups + block_scale_offs) * out_scales_stride_n, + scale_flat, + mask=block_scale_offs < n2_groups, + ) diff --git a/aiter/ops/triton/_triton_kernels/quant/quant.py b/aiter/ops/triton/_triton_kernels/quant/quant.py index 3b88c8b2b2..5e24277375 100644 --- a/aiter/ops/triton/_triton_kernels/quant/quant.py +++ b/aiter/ops/triton/_triton_kernels/quant/quant.py @@ -270,3 +270,160 @@ def _dynamic_mxfp4_quant_kernel( bs_e8m0, mask=bs_mask, ) + + +# MXFP8 (1x32 e8m0) quant: derives a per-block uint8 e8m0 scale + FP8 e4m3 +# values. The bit-trick (bitcast amax to int32, add 0x200000, mask 0xFF800000, +# bitcast back to fp32) rounds amax up to a power of 2; log2(amax).floor() - 8 +# is the unbiased e8m0 exponent (dtypeMax = 2**8). + + +@triton.jit +def _mxfp8_quant_op(x_grouped, QUANT_AXIS: tl.constexpr): + """Shared MXFP8 (1x32 e8m0) scale derivation. + + Given a fp32 tile where the QUANT_AXIS dim is sized QUANT_BLOCK_SIZE (=32), + returns (scale_e8m0, quant_scale): the per-group uint8 e8m0 scale and the + matching fp32 multiplicative scale. Both outputs keep QUANT_AXIS with size 1 + so they broadcast against the input for in-place quantization. + """ + amax = tl.max(tl.abs(x_grouped), axis=QUANT_AXIS, keep_dims=True) + amax_i32 = amax.to(tl.int32, bitcast=True) + amax_i32 = (amax_i32 + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_p2 = amax_i32.to(tl.float32, bitcast=True) + scale_unbiased = tl.log2(amax_p2).floor() - 8 + scale_unbiased = tl.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(tl.int32) + 127).to(tl.uint8) + quant_scale = tl.exp2(-scale_unbiased) + return scale_e8m0, quant_scale + + +@triton.jit +def _dynamic_mxfp8_quant_kernel( + x_ptr, + y_ptr, + s_ptr, + M, + N, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + stride_sm, + stride_sn, + BLOCK_SIZE_N: tl.constexpr, # power-of-2 covering full N + QUANT_BLOCK_SIZE: tl.constexpr, # =32 + NUM_PRGMS: tl.constexpr, # row-loop range (usually =M) +): + """ + Per-1x32 MXFP8 quant. One program per row, holding the full row in + registers so a single launch handles all K-groups. Mirrors + _fused_rms_mxfp8_kernel shape (in fused_mxfp8_quant.py) and minimizes + grid overhead. + """ + row_start = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE_N) + mask = col_offsets < N + n_groups: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE + + for row_idx in tl.range(row_start, M, NUM_PRGMS, num_stages=2): + x = tl.load( + x_ptr + row_idx * stride_xm + col_offsets * stride_xn, + mask=mask, + other=0.0, + ).to(tl.float32) + + # (BLOCK_SIZE_N,) -> (n_groups, QUANT_BLOCK_SIZE) + x_2d = tl.reshape(x, (n_groups, QUANT_BLOCK_SIZE)) + scale_e8m0, quant_scale = _mxfp8_quant_op(x_2d, QUANT_AXIS=1) + + qx_2d = x_2d * quant_scale + qx = tl.reshape(qx_2d, (BLOCK_SIZE_N,)) + y = qx.to(y_ptr.type.element_ty) + + tl.store( + y_ptr + row_idx * stride_ym + col_offsets * stride_yn, + y, + mask=mask, + ) + + group_offsets = tl.arange(0, n_groups) + group_mask = group_offsets < (N // QUANT_BLOCK_SIZE) + scale_flat = tl.reshape(scale_e8m0, (n_groups,)) + tl.store( + s_ptr + row_idx * stride_sm + group_offsets * stride_sn, + scale_flat, + mask=group_mask, + ) + + +# Transcoder: (FP8 fnuz, fp32 1x128 scale) -> (FP8 fn, e8m0 1x32 scale). +# Replaces the Python dequant+requant cascade (fp32 cast + multiply + bf16 cast +# + per_1x32_mxfp8 quant) used in linear.py's MXFP8 fallback path for MLA wq_b +# when q_norm emits the legacy fp8 fnuz + fp32 1x128 format. +# +# In: x_fp8_fnuz (M, N) — fp8 e4m3fnuz bits (interpreted with bias 8 -> value) +# x_scale_fp32 (M, N//128) — fp32 per-token-block scale +# Out: y_fp8_fn (M, N) — fp8 e4m3fn bits (NV format, bias 7) +# y_scale_e8m0 (M, N//32) — uint8 e8m0 (1x32 MX scale) + + +@triton.jit +def _fp8_legacy_to_mxfp8_kernel( + x_fnuz_ptr, + x_scale_fp32_ptr, + y_fn_ptr, + y_scale_e8m0_ptr, + M, + N, + stride_xm, + stride_xn, + stride_xsm, + stride_xsn, + stride_ym, + stride_yn, + stride_ysm, + stride_ysn, + BLOCK_SIZE_M: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, # =32 (MXFP8 group) + LEGACY_BLOCK_SIZE: tl.constexpr, # =128 (input scale group) +): + """ + One program per (BLOCK_SIZE_M rows, QUANT_BLOCK_SIZE-element column window). + For each 1x32 block, dequantize fnuz fp8 values using the corresponding + 1x128 fp32 scale, derive the e8m0 (1x32) scale, then re-quantize to fp8 fn. + """ + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * QUANT_BLOCK_SIZE + tl.arange(0, QUANT_BLOCK_SIZE) + + x_offs = offs_m[:, None] * stride_xm + offs_n[None, :] * stride_xn + x_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + + # Load fp8 fnuz values; .to(fp32) decodes via fnuz bias 8 semantically. + x_fnuz = tl.load(x_fnuz_ptr + x_offs, mask=x_mask, other=0.0).to(tl.float32) + + # Which legacy 1x128 group does this 1x32 block fall into? + legacy_n = (pid_n * QUANT_BLOCK_SIZE) // LEGACY_BLOCK_SIZE + xs_offs = offs_m * stride_xsm + legacy_n * stride_xsn + xs_mask = offs_m < M + x_scale = tl.load(x_scale_fp32_ptr + xs_offs, mask=xs_mask, other=1.0) + + # Dequantize: bf16-equivalent reconstruction. + x_dq = x_fnuz * x_scale[:, None] + + # Derive new e8m0 (1x32) scale from x_dq amax. + scale_e8m0, quant_scale = _mxfp8_quant_op(x_dq, QUANT_AXIS=1) + + # Re-quantize to fp8 fn. + qx = x_dq * quant_scale + y = qx.to(y_fn_ptr.type.element_ty) + + y_offs = offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn + tl.store(y_fn_ptr + y_offs, y, mask=x_mask) + + s_offs = offs_m[:, None] * stride_ysm + pid_n * stride_ysn + s_mask = offs_m[:, None] < M + tl.store(y_scale_e8m0_ptr + s_offs, scale_e8m0, mask=s_mask) diff --git a/aiter/ops/triton/configs/gemm/gfx942-GEMM-AFP8WFP8.json b/aiter/ops/triton/configs/gemm/gfx942-GEMM-AFP8WFP8.json new file mode 100644 index 0000000000..c7271acf94 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx942-GEMM-AFP8WFP8.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx942-GEMM-AFP8WFP8_PRESHUFFLED.json b/aiter/ops/triton/configs/gemm/gfx942-GEMM-AFP8WFP8_PRESHUFFLED.json new file mode 100644 index 0000000000..0008e2ab97 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx942-GEMM-AFP8WFP8_PRESHUFFLED.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16W16-N=384-K=7168.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16W16-N=384-K=7168.json new file mode 100644 index 0000000000..9b9cce2832 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16W16-N=384-K=7168.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 1024, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8.json new file mode 100644 index 0000000000..c7271acf94 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=1536-K=4096.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=1536-K=4096.json new file mode 100644 index 0000000000..6b2567e194 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=1536-K=4096.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=2048-K=7168.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=2048-K=7168.json new file mode 100644 index 0000000000..a07c6c2d26 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=2048-K=7168.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=1024.json new file mode 100644 index 0000000000..5ba817728b --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=1024.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=256.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=256.json new file mode 100644 index 0000000000..e4b178dda7 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=256.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=512-K=4096.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=512-K=4096.json new file mode 100644 index 0000000000..91361fa902 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=512-K=4096.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 4 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=2048.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=2048.json new file mode 100644 index 0000000000..6f74cb9e7f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=2048.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=384.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=384.json new file mode 100644 index 0000000000..2a3b281530 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=384.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=768-K=7168.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=768-K=7168.json new file mode 100644 index 0000000000..2d88b2f30c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=768-K=7168.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1024.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1024.json new file mode 100644 index 0000000000..cb2547fbbd --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1024.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1536.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1536.json new file mode 100644 index 0000000000..defffda52d --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1536.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED.json new file mode 100644 index 0000000000..0008e2ab97 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/fusions/fused_clamp_act_mul.py b/aiter/ops/triton/fusions/fused_clamp_act_mul.py index 0b1b8a1a63..af3afe0ad5 100644 --- a/aiter/ops/triton/fusions/fused_clamp_act_mul.py +++ b/aiter/ops/triton/fusions/fused_clamp_act_mul.py @@ -25,6 +25,9 @@ def fused_clamp_act_mul( weights: Optional[torch.Tensor] = None, dtype_quant: torch.dtype | None = None, transpose_scale: bool = False, + quant_block_size: int = 128, + scale_dtype_fmt: Literal["fp32", "ue8m0"] = "fp32", + shuffle_scale: bool = False, ): """ Fused clamp (SwiGLU-style) + act(gate) * up + optional weights, with optional FP8 group quant. @@ -54,6 +57,26 @@ def fused_clamp_act_mul( HAS_QUANT = dtype_quant is not None + assert scale_dtype_fmt in ("fp32", "ue8m0") + if scale_dtype_fmt == "ue8m0": + assert HAS_QUANT, "scale_dtype_fmt='ue8m0' requires dtype_quant" + assert ( + quant_block_size == 32 + ), f"ue8m0 requires quant_block_size=32 got {quant_block_size}" + assert dtype_quant in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ), f"ue8m0 requires fp8 e4m3, got {dtype_quant}" + assert not ( + shuffle_scale and transpose_scale + ), "shuffle_scale incompatible with transpose_scale" + _scale_storage_dtype = torch.uint8 + else: + assert ( + not shuffle_scale + ), "shuffle_scale only valid with scale_dtype_fmt='ue8m0'" + _scale_storage_dtype = torch.float32 + if HAS_QUANT: if out is None: out = torch.empty((M, n_half), dtype=dtype_quant, device=inp.device) @@ -65,15 +88,29 @@ def fused_clamp_act_mul( dtype_quant, out.dtype, ) - num_blocks = (n_half + 127) // 128 - if scale is None: + num_blocks = (n_half + quant_block_size - 1) // quant_block_size + if shuffle_scale: + # Scales are preshuffled inside the kernel (see e8m0_shuffle / + # aiter.ops.shuffle.shuffle_scale): rows padded to a multiple of 256 + # and block-cols to a multiple of 8, written in the tiled layout. + scale_m_pad = (M + 255) // 256 * 256 + scale_n_pad = (num_blocks + 7) // 8 * 8 + if scale is None: + scale = torch.empty( + (scale_m_pad, scale_n_pad), + dtype=_scale_storage_dtype, + device=inp.device, + ) + else: + assert scale.shape == (scale_m_pad, scale_n_pad) + elif scale is None: if transpose_scale: scale = torch.empty( - (num_blocks, M), dtype=torch.float32, device=inp.device + (num_blocks, M), dtype=_scale_storage_dtype, device=inp.device ) else: scale = torch.empty( - (M, num_blocks), dtype=torch.float32, device=inp.device + (M, num_blocks), dtype=_scale_storage_dtype, device=inp.device ) else: if transpose_scale: @@ -121,8 +158,16 @@ def fused_clamp_act_mul( HAVE_SWIGLU_CLAMP = swiglu_limit > 0 + scale_n_pad = 0 if HAS_QUANT: - if transpose_scale: + if shuffle_scale: + # Kernel writes directly into the (scale_m_pad, scale_n_pad) buffer + # using the shuffled offset, so the plain row/col strides are unused. + scale_row_stride = scale.stride(0) + scale_col_stride = scale.stride(1) + num_bs_cols = scale.shape[1] + scale_n_pad = scale.shape[1] + elif transpose_scale: scale_row_stride = scale.stride(1) scale_col_stride = scale.stride(0) num_bs_cols = scale.shape[0] @@ -153,7 +198,8 @@ def fused_clamp_act_mul( weights.stride(1) if HAVE_WEIGHTS else 0, swiglu_limit, BLOCK_SIZE_N=BLOCK_SIZE_N, - QUANT_BLOCK_SIZE=128, + QUANT_BLOCK_SIZE=quant_block_size, + SCALE_FMT=scale_dtype_fmt, DTYPE_MAX=DTYPE_MAX, DTYPE_MIN=-DTYPE_MAX, HAVE_WEIGHTS=HAVE_WEIGHTS, @@ -161,6 +207,8 @@ def fused_clamp_act_mul( HAVE_SWIGLU_CLAMP=HAVE_SWIGLU_CLAMP, HAS_QUANT=HAS_QUANT, ACTIVATION=activation, + SHUFFLE=shuffle_scale, + SCALE_N_PAD=scale_n_pad, num_warps=num_warps, ) diff --git a/aiter/ops/triton/gemm/basic/gemm_afp8wfp8.py b/aiter/ops/triton/gemm/basic/gemm_afp8wfp8.py new file mode 100644 index 0000000000..2fc13b0c24 --- /dev/null +++ b/aiter/ops/triton/gemm/basic/gemm_afp8wfp8.py @@ -0,0 +1,267 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Optional + +import torch +import triton + +from aiter.ops.triton._triton_kernels.gemm.basic.gemm_afp8wfp8 import ( + _gemm_afp8wfp8_kernel, + _gemm_afp8wfp8_preshuffle_kernel, + _get_config, +) +from aiter.ops.triton._triton_kernels.common.splitk_reduce import ( + _gemm_splitk_reduce_kernel, +) + + +def gemm_afp8wfp8( + x: torch.Tensor, + w: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, + skip_reduce: Optional[bool] = False, +) -> torch.Tensor: + """ + Computes matrix multiplication Y = X @ W^T with MXFP8 activations and FP8 + weights (1x32 e8m0 act scales, 128x128 e8m0 weight scales). + + Args: + x: FP8 e4m3 (or uint8 view) input matrix with shape (M, K). + w: FP8 e4m3 (or uint8 view) weight matrix with shape (N, K) — internally + transposed to (K, N) before the kernel call. + x_scales: e8m0 (uint8) per-group scale for x with shape (M, K // 32). + w_scales: e8m0 (uint8) per-block scale for w with shape (N // 128, K // 128). + dtype: Output dtype (BF16 or FP16). Default bf16. + y: Optional pre-allocated output tensor with shape (M, N). + config: Optional kernel-tuning dict. If None uses defaults. + + Returns: + torch.Tensor: Output with shape (M, N). + """ + M, K = x.shape + N, K_w = w.shape + assert K == K_w, f"K mismatch: x has K={K}, w has K={K_w}" + + # Transpose w to (K, N) for the kernel. + w_t = w.T + + # tl.dot_scaled with format "e4m3" expects uint8-typed operands; reinterpret + # the FP8 buffers as uint8 (bit-identical view). + if x.dtype != torch.uint8: + x = x.view(torch.uint8) + if w_t.dtype != torch.uint8: + w_t = w_t.view(torch.uint8) + + if config is None: + config, _ = _get_config(M, N, K) + + if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + y = torch.empty((M, N), dtype=dtype, device=x.device) + + config["SPLITK_BLOCK_SIZE"] = triton.cdiv( + K, config["NUM_KSPLIT"] + ) # How big each split_k partition is + if config["NUM_KSPLIT"] > 1: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), + dtype=torch.float32, + device=x.device, + ) + else: + y_pp = None + + grid = lambda META: ( # noqa: E731 + ( + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]) + ), + ) + + _gemm_afp8wfp8_kernel[grid]( + x, + w_t, + y if config["NUM_KSPLIT"] == 1 else y_pp, + x_scales, + w_scales, + M, + N, + K, + x.stride(0), + x.stride(1), + w_t.stride(0), + w_t.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), + x_scales.stride(0), + x_scales.stride(1), + w_scales.stride(0), + w_scales.stride(1), + **config, + ) + + if config["NUM_KSPLIT"] > 1: + if skip_reduce: + return y_pp + + REDUCE_BLOCK_SIZE_M = 32 + REDUCE_BLOCK_SIZE_N = 32 + ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + triton.cdiv(N, REDUCE_BLOCK_SIZE_N), + ) + _gemm_splitk_reduce_kernel[grid_reduce]( + y_pp, + y, + None, + M, + N, + y_pp.stride(0), + y_pp.stride(1), + y_pp.stride(2), + y.stride(0), + y.stride(1), + BLOCK_SIZE_M=REDUCE_BLOCK_SIZE_M, + BLOCK_SIZE_N=REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT=ACTUAL_KSPLIT, + MAX_KSPLIT=triton.next_power_of_2(config["NUM_KSPLIT"]), + ADD_BIAS=False, + activation=None, + use_activation=False, + KERNEL_NAME="_gemm_afp8wfp8_reduce_kernel", + ) + + return y + + +def gemm_afp8wfp8_preshuffle( + x: torch.Tensor, + w_shuffled: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, + skip_reduce: Optional[bool] = False, +) -> torch.Tensor: + """ + Preshuffle variant of gemm_afp8wfp8. The weight tensor has already been + permuted via aiter.ops.shuffle.shuffle_weight(..., layout=(16, 16)). Scales + are left unshuffled in the compact 128x128 layout. + + Args: + x: FP8 e4m3 activations with shape (M, K). + w_shuffled: FP8 e4m3 weights, shuffled in place to (N, K) storage + (same total bytes; bytes rearranged for the kernel's read pattern). + x_scales: e8m0 (uint8) per-token scale with shape (M, K // 32). + w_scales: e8m0 (uint8) per-block weight scale with shape (N // 128, K // 128). + dtype: Output dtype. + y: Optional pre-allocated output (M, N). + config: Optional kernel-tuning dict. + + Returns: + torch.Tensor: Output with shape (M, N). + """ + M, K = x.shape + N, K_w = w_shuffled.shape + assert K == K_w, f"K mismatch: x={K}, w={K_w}" + assert N % 16 == 0, f"N must be divisible by 16 for preshuffle, got {N}" + + # The kernel expects to address the shuffled tensor as (N//16, K*16). + w_view = w_shuffled.view(N // 16, K * 16) + + if x.dtype != torch.uint8: + x = x.view(torch.uint8) + if w_view.dtype != torch.uint8: + w_view = w_view.view(torch.uint8) + + if config is None: + config, _ = _get_config(M, N, K, shuffle=True) + + if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + y = torch.empty((M, N), dtype=dtype, device=x.device) + + config["SPLITK_BLOCK_SIZE"] = triton.cdiv( + K, config["NUM_KSPLIT"] + ) # How big each split_k partition is + if config["NUM_KSPLIT"] > 1: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), + dtype=torch.float32, + device=x.device, + ) + else: + y_pp = None + + grid = lambda META: ( # noqa: E731 + ( + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]) + ), + ) + _gemm_afp8wfp8_preshuffle_kernel[grid]( + x, + w_view, + y if config["NUM_KSPLIT"] == 1 else y_pp, + x_scales, + w_scales, + M, + N, + K, + x.stride(0), + x.stride(1), + w_view.stride(0), + w_view.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), + x_scales.stride(0), + x_scales.stride(1), + w_scales.stride(0), + w_scales.stride(1), + **config, + ) + + if config["NUM_KSPLIT"] > 1: + if skip_reduce: + return y_pp + + REDUCE_BLOCK_SIZE_M = 32 + REDUCE_BLOCK_SIZE_N = 32 + ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + triton.cdiv(N, REDUCE_BLOCK_SIZE_N), + ) + _gemm_splitk_reduce_kernel[grid_reduce]( + y_pp, + y, + None, + M, + N, + y_pp.stride(0), + y_pp.stride(1), + y_pp.stride(2), + y.stride(0), + y.stride(1), + BLOCK_SIZE_M=REDUCE_BLOCK_SIZE_M, + BLOCK_SIZE_N=REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT=ACTUAL_KSPLIT, + MAX_KSPLIT=triton.next_power_of_2(config["NUM_KSPLIT"]), + ADD_BIAS=False, + activation=None, + use_activation=False, + KERNEL_NAME="_gemm_afp8wfp8_preshuffle_reduce_kernel", + ) + + return y diff --git a/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_quant_x.py b/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_quant_x.py new file mode 100644 index 0000000000..d580687ef5 --- /dev/null +++ b/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_quant_x.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Optional, Tuple +import torch +import triton +from aiter.ops.triton._triton_kernels.gemm.fused.fused_gemm_a16w16_quant_x import ( + _fused_gemm_a16w16_quant_x_kernel, + _get_config, +) +from aiter.ops.triton._triton_kernels.common.splitk_reduce import ( + _gemm_splitk_reduce_kernel, +) +from aiter.ops.triton._triton_kernels.activation import _get_activation_from_str +from aiter.ops.triton.utils.logger import AiterTritonLogger + +_LOGGER = AiterTritonLogger() + +_QUANT_BLOCK_SIZE = 32 + + +def fused_gemm_a16w16_quant_x( + x, + w, + bias: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = torch.bfloat16, + quant_dtype: Optional[torch.dtype] = None, + y: Optional[torch.Tensor] = None, + x_quant: Optional[torch.Tensor] = None, + x_scales: Optional[torch.Tensor] = None, + config: Optional[dict] = None, + activation: Optional[str] = None, + skip_reduce: Optional[bool] = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes 16-bit matmul Y = X @ W^T and also emits an MXFP8-quantized X. + + This fuses the GEMM with the activation per-1x32 MXFP8 quantization that + immediately follows the router-gate GEMM in MoE flows (e.g. DSv4), + avoiding a separate kernel/DRAM pass to cast X from BF16 to FP8 + e8m0 + scales. + + The fused kernel uses a single 1D grid split into two regions: GEMM tiles + occupy `NUM_KSPLIT * cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)` programs and an + additional `cdiv(M, BLOCK_M) * cdiv(K, BLOCK_K)` programs handle the + MXFP8 quant of X. + + Args: + x (torch.Tensor): Input matrix with shape (M, K). K must be a multiple + of 32. + w (torch.Tensor): Weight matrix with shape (N, K), internally transposed. + bias (Optional[torch.Tensor]): Bias vector with shape (N,). + dtype (Optional[torch.dtype]): Output Y datatype (BF16 or FP16). + quant_dtype (Optional[torch.dtype]): FP8 dtype for the MXFP8 quantized + X values (defaults to torch.float8_e4m3fn). + y (Optional[torch.Tensor]): Pre-allocated output with shape (M, N). + x_quant (Optional[torch.Tensor]): Pre-allocated MXFP8 quantized X with + shape (M, K). + x_scales (Optional[torch.Tensor]): Pre-allocated uint8 e8m0 scales with + shape (M, K // 32). + config (Optional[dict]): Kernel tuning parameters. + activation (Optional[str]): Activation fused into Y. + skip_reduce (Optional[bool]): Skip split-K reduction for Y. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: (Y, x_quant, x_scales). + When skip_reduce=True and NUM_KSPLIT > 1, Y has shape (NUM_KSPLIT, M, N). + """ + + _LOGGER.info(f"FUSED_GEMM_A16W16_QUANT_X: x={tuple(x.shape)} w={tuple(w.shape)}") + # Shape checks + assert x.shape[1] == w.shape[1], "Incompatible matrix shapes." + assert ( + x.shape[1] % _QUANT_BLOCK_SIZE == 0 + ), f"K={x.shape[1]} must be a multiple of {_QUANT_BLOCK_SIZE} for MXFP8 quant" + + if quant_dtype is None: + quant_dtype = torch.float8_e4m3fn + + M, K = x.shape + N, K = w.shape + w = w.T + + if config is None: + config, _ = _get_config(M, N, K) + + assert ( + config["BLOCK_SIZE_K"] % _QUANT_BLOCK_SIZE == 0 + ), f"BLOCK_SIZE_K={config['BLOCK_SIZE_K']} must be a multiple of {_QUANT_BLOCK_SIZE}" + + if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + y = torch.empty((M, N), dtype=dtype, device=x.device) + + if x_quant is None: + x_quant = torch.empty((M, K), dtype=quant_dtype, device=x.device) + + if x_scales is None: + x_scales = torch.empty( + (M, K // _QUANT_BLOCK_SIZE), dtype=torch.uint8, device=x.device + ) + + if config["NUM_KSPLIT"] > 1: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), + dtype=torch.float32, + device=y.device if y is not None else x.device, + ) + else: + y_pp = None + + grid = lambda META: ( # noqa: E731 + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]) + + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(K, META["BLOCK_SIZE_K"]), + ) + _fused_gemm_a16w16_quant_x_kernel[grid]( + x, + w, + bias, + y if config["NUM_KSPLIT"] == 1 else y_pp, + x_quant, + x_scales, + M, + N, + K, + x.stride(0), + x.stride(1), + w.stride(0), + w.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), + x_quant.stride(0), + x_quant.stride(1), + x_scales.stride(0), + x_scales.stride(1), + activation=_get_activation_from_str(activation) if activation else "", + use_activation=activation is not None, + ADD_BIAS=(bias is not None), + SKIP_REDUCE=skip_reduce, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + **config, + ) + + if config["NUM_KSPLIT"] > 1: + if skip_reduce: + return y_pp, x_quant, x_scales + + REDUCE_BLOCK_SIZE_M = 32 + REDUCE_BLOCK_SIZE_N = 32 + ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + triton.cdiv(N, REDUCE_BLOCK_SIZE_N), + ) + _gemm_splitk_reduce_kernel[grid_reduce]( + y_pp, + y, + bias, + M, + N, + y_pp.stride(0), + y_pp.stride(1), + y_pp.stride(2), + y.stride(0), + y.stride(1), + REDUCE_BLOCK_SIZE_M, + REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT, + triton.next_power_of_2(config["NUM_KSPLIT"]), + ADD_BIAS=(bias is not None), + activation=_get_activation_from_str(activation) if activation else "", + use_activation=activation is not None, + KERNEL_NAME="_fused_gemm_a16w16_quant_x_reduce_kernel", + ) + + return y, x_quant, x_scales diff --git a/aiter/ops/triton/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/moe/moe_op_gemm_a8w4.py index 1407ba85c5..bad8eaa8f3 100644 --- a/aiter/ops/triton/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/moe/moe_op_gemm_a8w4.py @@ -2,6 +2,7 @@ # original code https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/matmul_ogs.py import itertools +import os import torch import triton from aiter.ops.triton.moe.moe_routing.routing import RoutingData @@ -29,8 +30,8 @@ def should_upcast_indices(*args): def allocate_output( - M, - N, + x, + w, out_dtype, reduction_n_matmul, reduction_n_reduction, @@ -41,6 +42,14 @@ def allocate_output( split_k, device, ): + # ---- output ------ + N = w.shape[-1] + # by default - M is number of rows in the activations + M = x.shape[-2] + # if the activations are gathered, then M is number of gather indices + if gather_indx is not None: + M = gather_indx.shape[0] + # final output if routing_data.n_expts_act == 1 or scatter_indx is None: y_rows = M else: @@ -57,6 +66,14 @@ def allocate_output( return matmul_output, final_output +def recommend_block_m(m: int) -> int: + """ + Recommend block_m for moe_gemm_a8w4 based on M. + Prefill (M >= 256) → 64. Decode (M < 256) → 32. + """ + return 64 if m >= 256 else 16 + + def get_kernel_config_triton(m, n, k, routing_data): block_m = routing_data.block_m group_m = 4 @@ -75,12 +92,18 @@ def get_kernel_config_triton(m, n, k, routing_data): grid_m = routing_data.n_blocks(m, block_m) grid_n = triton.cdiv(n, block_n) grid = grid_m * grid_n * split_k - while block_n >= 64 and grid < get_num_sms(): + # Floor at 64 (was 32): out_mx_quant=True with apply_swiglu requires + # OUT_BLOCK_N = BLOCK_N // 2 >= 32. Loop boundary changed to keep + # block_n >= 64 for both MX and non-MX paths. + while block_n >= 128 and grid < get_num_sms(): block_n = block_n // 2 grid_m = routing_data.n_blocks(m, block_m) grid_n = triton.cdiv(n, block_n) grid = grid_m * grid_n * split_k + if k >= 512: + block_k = 512 + elif block_m == 32: if n <= 1024: block_n = 128 @@ -92,6 +115,14 @@ def get_kernel_config_triton(m, n, k, routing_data): block_n = 512 num_warps = 4 + elif block_m == 64: + # V4-Flash prefill-tuned (rocprof brute force v2): for block_m=64, + # (bn=128, nw=4, ns=1) gives 2-4x speedup over the previous bn=512/nw=8 + # default on all four V4-Flash prefill shapes. + block_n = 128 + num_warps = 4 + num_stages = 1 + else: block_n = 512 # routing caps block_m at 128; nw=4 wins ~2x at block_m=128 on gpt-oss @@ -112,6 +143,28 @@ def get_kernel_config_triton(m, n, k, routing_data): "matrix_instr_nonkdim": 16, "kpack": 1, } + # Env-driven overrides split by regime: block_m>=64 → PREFILL_*, else DECODE_*. + # Generic AITER_A8W4_* still works as a fallback when the regime-specific var is unset. + _regime = "PREFILL" if block_m >= 64 else "DECODE" + _knobs = ( + "block_n", + "block_k", + "num_warps", + "num_stages", + "group_m", + "waves_per_eu", + "matrix_instr_nonkdim", + "split_k", + ) + for key in _knobs: + env_specific = f"AITER_A8W4_{_regime}_{key.upper()}" + env_generic = f"AITER_A8W4_{key.upper()}" + v = os.environ.get(env_specific, os.environ.get(env_generic)) + if v is not None: + try: + ret[key] = int(v) + except ValueError: + pass return ret @@ -210,6 +263,13 @@ def moe_gemm_a8w4( add_residual=True, unpadded_N=None, unpadded_K=None, + # Idea 1: emit (fp8 e4m3, ue8m0 per-1×32 scale) directly from the GEMM + # write-back. When out_mx_quant=True, returns (y_fp8, y_scale_ue8m0). + # Requires SPLIT_K==1 and no scatter_indx (GEMM1-style). + out_mx_quant: bool = False, + # External residual to fold into reduce_grouped writeback (saves the + # standalone routed+shared elementwise add). + residual=None, ): """ Y[:, :] = 0. @@ -259,11 +319,21 @@ def moe_gemm_a8w4( reduction_n_matmul = 1 apply_swiglu_reduction = False reduction_n_reduction = 1 - # allocate output memory + # allocate output memory. With out_mx_quant=True, the kernel writes fp8 e4m3 + # into y; otherwise the requested out_dtype (bf16). + if out_mx_quant: + assert config["split_k"] == 1, "out_mx_quant requires split_k == 1" + assert scatter_indx is None, ( + "out_mx_quant currently only supported for GEMM1-style (no scatter); " + "scatter+combine would need fp8-aware reduce_grouped" + ) + out_dtype_actual = torch.float8_e4m3fn + else: + out_dtype_actual = out_dtype y, y_final = allocate_output( - M, - N, - out_dtype, + x, + w, + out_dtype_actual, reduction_n_matmul, reduction_n_reduction, routing_data, @@ -273,6 +343,18 @@ def moe_gemm_a8w4( config["split_k"], x.device, ) + # Companion ue8m0 scale buffer for the MXFP8 emit path. + if out_mx_quant: + n_out = w.shape[-1] // reduction_n_matmul # post-swiglu width + assert n_out % 32 == 0, "out_mx_quant requires N_out % 32 == 0" + m_out = y.shape[-2] + y_scale = torch.empty((m_out, n_out // 32), dtype=torch.uint8, device=x.device) + stride_y_mx_m = y_scale.stride(0) + stride_y_mx_n = y_scale.stride(1) + else: + y_scale = None + stride_y_mx_m = 0 + stride_y_mx_n = 0 stride_bias = None if bias is None else bias.stride(0) # moe metadata expt_data = routing_data.expt_data @@ -398,14 +480,23 @@ def moe_gemm_a8w4( waves_per_eu=config["waves_per_eu"], matrix_instr_nonkdim=config["matrix_instr_nonkdim"], kpack=config["kpack"], + YMxScale=y_scale, + stride_y_mx_m=stride_y_mx_m, + stride_y_mx_n=stride_y_mx_n, + HAS_MX_OUT=out_mx_quant, ) + # MXFP8 emit path: scatter_indx is None and split_k==1, so we bypass + # reduce_grouped and return (fp8 values, ue8m0 scales) directly. + if out_mx_quant: + return y.squeeze(0), y_scale # Build grouped reduction inputs in a uniform way group_indx = ( None if scatter_indx is None else scatter_indx.view(-1, routing_data.n_expts_act) ) + # Step 9: external residual fold-in is now wired into reduce_grouped. y_final = reduce_grouped( y, group_indx, @@ -416,6 +507,7 @@ def moe_gemm_a8w4( reduction_n_reduction, out_dtype=out_dtype, add_residual=add_residual, + residual=residual, ) return y_final diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index 6f5ed8084f..2ea56e8a04 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -5,6 +5,12 @@ _combined_routing, _combined_routing_fused, ) +from aiter.ops.triton.fusions.fused_routing_from_topk import ( + fused_routing_from_topk, +) +from aiter.ops.triton._triton_kernels.moe.moe_routing.expt_data import ( + _expt_data_only_kernel, +) from aiter.ops.triton.utils._triton.arch_info import is_tdm_avail @@ -300,6 +306,195 @@ def routing(logits, n_expts_act, sm_first=False): ) +def routing_a8w4( + logits: torch.Tensor, + n_expts_act: int, + block_m: int, + *, + score_mode: str = "sqrtsoftplus", + bias: torch.Tensor | None = None, + renorm: bool = True, + routed_scaling_factor: float = 1.0, +): + """All-Triton routing for the a8w4 path: fused V4 routing math + sort. + + One-shot pipeline: + 1. aiter `_topk` (extended): pre-transform (sqrtsoftplus) + bias + topk + + bitmatrix + renorm + scale — single Triton kernel. + 2. aiter `sort_tokens` (or `sort_tokens_fused` for tiny M): sort tokens by + expert and produce ExptData specialized for the given ``block_m``. + + Returns (RoutingData, gather_indx, scatter_indx) where gather_indx and + scatter_indx are raw int32 tensors (no GatherIndx/ScatterIndx wrappers) — + consumed directly by ``moe_gemm_a8w4``. + + No multi-block_m dict, no triton_kernels wrapper, no Python bridge step. + """ + from .topk import topk + + n_tokens, n_expts_tot = logits.shape + + # Step 1: extended topk does sqrtsoftplus + bias + topk + bitmatrix + renorm + scale. + expt_scal, expt_indx, bitmatrix = topk( + logits, + n_expts_act, + apply_softmax=False, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=routed_scaling_factor, + HIST_BLOCK_M=32, + ) + + # Step 2: sort tokens by expert and build ExptData for the chosen block_m. + if n_tokens <= 16: + HIST_BLOCK_M = triton.next_power_of_2(max(n_tokens, 1)) + sort_fn = sort_tokens_fused + else: + HIST_BLOCK_M = 32 + sort_fn = sort_tokens + ( + hist, + topk_indx, + gate_indx, + gate_scal, + token_offs_raw, + token_offs_pad, + block_pid_map, + ) = sort_fn(expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M) + expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map) + routing_data = RoutingData( + block_m=block_m, + gate_scal=gate_scal, + expt_hist=hist, + n_expts_tot=n_expts_tot, + n_expts_act=n_expts_act, + expt_data=expt_data, + ) + return routing_data, topk_indx, gate_indx + + +def routing_a8w4_from_hash( + router_logits: torch.Tensor, + tid2eid: torch.Tensor, + input_ids: torch.Tensor, + n_expts_act: int, + block_m: int, + *, + score_mode: str = "sqrtsoftplus", + renorm: bool = True, + routed_scaling_factor: float = 1.0, +): + """All-Triton routing for the a8w4 path on DeepSeek-V4 hash layers. + + Single fused kernel ``hash_routing`` does tid2eid lookup + score transform + + gather + renorm + scale + bitmatrix in one launch, then + ``sort_tokens_fused`` (same as :func:`routing_a8w4`) produces ExptData. + + Replaces the Python ``_hash_topk`` + multi-kernel ``fused_routing_from_topk`` + counting-sort + ``compute_expt_data`` (with memset) chain entirely. + """ + from .topk import hash_routing + + n_tokens, n_expts_tot = router_logits.shape + + expt_scal, expt_indx, bitmatrix = hash_routing( + router_logits, + tid2eid, + input_ids, + n_expts_act=n_expts_act, + HIST_BLOCK_M=32, + score_mode=score_mode, + renorm=renorm, + routed_scaling_factor=routed_scaling_factor, + ) + + if n_tokens <= 16: + HIST_BLOCK_M = triton.next_power_of_2(max(n_tokens, 1)) + sort_fn = sort_tokens_fused + else: + HIST_BLOCK_M = 32 + sort_fn = sort_tokens + ( + hist, + topk_indx, + gate_indx, + gate_scal, + token_offs_raw, + token_offs_pad, + block_pid_map, + ) = sort_fn(expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M) + expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map) + routing_data = RoutingData( + block_m=block_m, + gate_scal=gate_scal, + expt_hist=hist, + n_expts_tot=n_expts_tot, + n_expts_act=n_expts_act, + expt_data=expt_data, + ) + return routing_data, topk_indx, gate_indx + + +def routing_a8w4_from_topk( + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + n_expts_tot: int, + block_m: int, +): + """Routing for the a8w4 path when topk has been pre-computed externally + (e.g. DeepSeek-V4 hash layers with tid2eid lookup). + + Mirrors ``routing_a8w4`` but skips the score+topk math step. Pipeline: + 1. aiter ``fused_routing_from_topk``: 3-kernel counting-sort over the + supplied ``(topk_weights, topk_ids)``. Allocates only via + ``torch.empty`` — no histogram memset. + 2. aiter ``_expt_data_only_kernel``: standalone stage1+stage2 launch + that materialises ExptData (token_offs_raw, token_offs_pad, + block_pid_map) from the histogram for the chosen ``block_m``. + + Returns ``(RoutingData, gather_indx, scatter_indx)`` where ``gather_indx`` + and ``scatter_indx`` are raw int32 tensors — same contract as + ``routing_a8w4`` — so ``_a8w4_fused_experts`` consumes them unchanged. + """ + + n_tokens, n_expts_act = topk_weights.shape + n_gates = n_tokens * n_expts_act + + hist, topk_indx, gate_indx, gate_scal = fused_routing_from_topk( + topk_weights, topk_ids, n_expts_tot + ) + + token_offs_raw, token_offs_pad, block_pid_map, blocks1a, BLOCK_A, block_m_log2 = ( + _compute_expt_data_internal(n_expts_tot, n_gates, block_m, topk_weights.device) + ) + + _expt_data_only_kernel[(blocks1a,)]( + hist, + n_expts_tot, + token_offs_raw, + token_offs_pad, + block_pid_map, + block_pid_map.shape[0], + n_gates, + block_m_log2, + BLOCK_A, + (hist.shape[0] == BLOCK_A), + num_warps=1, + ) + + expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map) + routing_data = RoutingData( + block_m=block_m, + gate_scal=gate_scal, + expt_hist=hist, + n_expts_tot=n_expts_tot, + n_expts_act=n_expts_act, + expt_data=expt_data, + ) + return routing_data, topk_indx, gate_indx + + # -------------------------- # torch reference # -------------------------- diff --git a/aiter/ops/triton/moe/moe_routing/topk.py b/aiter/ops/triton/moe/moe_routing/topk.py index c9c43e1ba1..4306184f7e 100644 --- a/aiter/ops/triton/moe/moe_routing/topk.py +++ b/aiter/ops/triton/moe/moe_routing/topk.py @@ -1,50 +1,89 @@ import triton import torch -from aiter.ops.triton._triton_kernels.moe.moe_routing.topk import _topk +from aiter.ops.triton._triton_kernels.moe.moe_routing.topk import _topk, _hash_routing from aiter.ops.triton.moe.moe_routing.bitmatrix import Bitmatrix -def topk(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, HIST_BLOCK_M=32): - x_shape = [x.shape[0], x.shape[1]] +def topk( + x, + k, + apply_softmax=True, + dim=1, + return_bitmatrix=True, + HIST_BLOCK_M=32, + score_mode: str = "softmax", + bias=None, + renorm: bool = False, + routed_scaling_factor: float = 1.0, +): + """Top-k expert selection with bitmatrix. - def cdiv(a, b): - return (a + b - 1) // b + score_mode: + - "softmax" (default): no pre-transform; APPLY_SOFTMAX may renormalize. + - "sqrtsoftplus": pre-transform `scores = sqrt(softplus(logits))` before + adding the optional `bias` and running topk. Selected weights are the + UNBIASED sqrt(softplus(logits)). DeepSeek-V4 noaux_tc router. - BLOCK_M = 32 - BLOCK_N = 128 # triton.next_power_of_2(x_shape[1]) # 128 + bias (fp32, [n_expts_tot]): added to scores for selection only, not for + returned weights. Only meaningful with score_mode='sqrtsoftplus'. + + renorm: renormalize weights to sum=1 per row before multiplying by + routed_scaling_factor. + """ + assert len(x.shape) == 2 + n_rows, n_cols = x.shape + + # BLOCK_M=1 for small n_rows keeps the grid wide enough to overlap with + BLOCK_M = 1 if n_rows <= 256 else 32 + BLOCK_N = 128 BLOCK_S = 128 BLOCK_SP = 128 - assert len(x.shape) == 2 - assert x_shape[-1] < 32768 + assert n_cols < 32768 assert dim == 1 assert return_bitmatrix - n_rows, n_cols = x_shape + assert score_mode in ( + "softmax", + "sqrtsoftplus", + ), f"score_mode must be 'softmax' or 'sqrtsoftplus', got {score_mode!r}" + if score_mode != "softmax": + assert not apply_softmax, "apply_softmax only valid with score_mode='softmax'" + has_bias = bias is not None + if has_bias: + assert bias.dim() == 1 + assert bias.shape[0] == x.shape[-1] + assert bias.dtype == torch.float32 + assert ( + score_mode == "sqrtsoftplus" + ), "bias currently only supported with score_mode='sqrtsoftplus'" dev = x.device # scratchpad tensors # NOTE: these are not returned y_vals = torch.empty((n_rows, k), dtype=x.dtype, device=dev) y_indx = torch.empty((n_rows, k), dtype=torch.int16, device=dev) - k_pow2 = triton.next_power_of_2(k) + # Triton's tl.topk fails to compile for k=1 (log_k=0 reduces the hypercube + # to a 0-D tensor; the final reshape hits dtype.numel). Pad to ≥ 2 — the + # kernel already masks N_EXPTS_ACT < N_EXPTS_ACT_PAD on store. + k_pow2 = max(2, triton.next_power_of_2(k)) # create bitmatrix in transposed memory layout: - n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N + n_cols_pad = triton.cdiv(n_cols, BLOCK_N) * BLOCK_N n_cols_words = n_cols_pad // 32 bitmatrix = torch.empty( - (n_cols_words, cdiv(n_rows, 32) * 32), dtype=torch.uint32, device=dev + (n_cols_words, triton.cdiv(n_rows, 32) * 32), dtype=torch.uint32, device=dev ) bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows] - s_blocks = cdiv(n_cols, BLOCK_S) + s_blocks = triton.cdiv(n_cols, BLOCK_S) s_cols = s_blocks * BLOCK_S scratchpad = torch.empty((s_cols,), dtype=torch.int32, device=dev) TILE_SIZE = 8 BLOCK_MM = HIST_BLOCK_M * TILE_SIZE - pids_x = cdiv(n_rows, BLOCK_MM) + pids_x = triton.cdiv(n_rows, BLOCK_MM) scratchpad_partials = torch.empty( (n_cols_pad, pids_x * TILE_SIZE), device=dev, dtype=torch.int32 ) scratchpad_partials = torch.transpose(scratchpad_partials, 0, 1) sp_size = torch.numel(scratchpad_partials) - sp_blocks = cdiv(sp_size, BLOCK_SP) - pids = max(cdiv(n_rows, BLOCK_M), s_blocks + sp_blocks) + sp_blocks = triton.cdiv(sp_size, BLOCK_SP) + pids = max(triton.cdiv(n_rows, BLOCK_M), s_blocks + sp_blocks) _topk[(pids,)]( x, x.stride(0), # inputs @@ -70,6 +109,11 @@ def cdiv(a, b): N_EXPTS_ACT=k, # constants N_EXPTS_ACT_PAD=k_pow2, num_warps=8, + Bias=bias, + SCORE_MODE=score_mode, + HAS_BIAS=has_bias, + APPLY_RENORM=renorm, + ROUTED_SCALING=routed_scaling_factor, ) bitmatrix_shape = [n_rows, n_cols_words * 32] bitmatrix = Bitmatrix( @@ -79,3 +123,112 @@ def cdiv(a, b): scratchpad_partials=scratchpad_partials, ) return y_vals, y_indx, bitmatrix + + +def hash_routing( + router_logits: torch.Tensor, # [n_rows, n_expts_tot] bf16/fp32 + tid2eid: torch.Tensor, # [vocab_size, K] int32 per-token-id expert table + input_ids: torch.Tensor, # [n_rows] int32 token ids (post DP gather, clamped) + n_expts_act: int, + HIST_BLOCK_M: int = 32, + score_mode: str = "sqrtsoftplus", + renorm: bool = True, + routed_scaling_factor: float = 1.0, +): + """Fused hash routing: tid2eid lookup + score transform + gather + renorm + + scale + bitmatrix construction. Output contract matches :func:`topk` so + downstream :func:`sort_tokens_fused` consumes it unchanged. + + Replaces the Python ``_hash_topk`` + ``fused_routing_from_topk`` + counting-sort + bitmatrix-build chain with one Triton kernel launch. + """ + + BLOCK_M = 32 + BLOCK_N = 128 + BLOCK_S = 128 + BLOCK_SP = 128 + assert router_logits.dim() == 2 + assert input_ids.dim() == 1 + assert tid2eid.dim() == 2 + assert input_ids.shape[0] == router_logits.shape[0] + assert ( + tid2eid.shape[1] == n_expts_act + ), f"tid2eid second dim {tid2eid.shape[1]} must equal n_expts_act {n_expts_act}" + assert tid2eid.dtype == torch.int32 + assert input_ids.dtype in (torch.int32, torch.int64) + assert score_mode in ("sqrtsoftplus",) + + n_rows, n_cols = router_logits.shape + dev = router_logits.device + k = n_expts_act + + y_vals = torch.empty((n_rows, k), dtype=router_logits.dtype, device=dev) + y_indx = torch.empty((n_rows, k), dtype=torch.int16, device=dev) + # See note in topk(): pad to ≥ 2 to dodge tl.topk(k=1) compile bug. + k_pow2 = max(2, triton.next_power_of_2(k)) + + n_cols_pad = triton.cdiv(n_cols, BLOCK_N) * BLOCK_N + n_cols_words = n_cols_pad // 32 + bitmatrix = torch.empty( + (n_cols_words, triton.cdiv(n_rows, 32) * 32), dtype=torch.uint32, device=dev + ) + bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows] + s_blocks = triton.cdiv(n_cols, BLOCK_S) + s_cols = s_blocks * BLOCK_S + scratchpad = torch.empty((s_cols,), dtype=torch.int32, device=dev) + TILE_SIZE = 8 + BLOCK_MM = HIST_BLOCK_M * TILE_SIZE + pids_x = triton.cdiv(n_rows, BLOCK_MM) + scratchpad_partials = torch.empty( + (n_cols_pad, pids_x * TILE_SIZE), device=dev, dtype=torch.int32 + ) + scratchpad_partials = torch.transpose(scratchpad_partials, 0, 1) + sp_size = torch.numel(scratchpad_partials) + sp_blocks = triton.cdiv(sp_size, BLOCK_SP) + pids = max(triton.cdiv(n_rows, BLOCK_M), s_blocks + sp_blocks) + + # int32 cast for input_ids if int64 + input_ids_i32 = ( + input_ids.to(torch.int32) if input_ids.dtype != torch.int32 else input_ids + ) + + _hash_routing[(pids,)]( + input_ids_i32, + tid2eid, + tid2eid.stride(0), + router_logits, + router_logits.stride(0), + y_vals, + y_indx, + y_vals.stride(0), + bitmatrix, + bitmatrix.stride(0), + bitmatrix.stride(1), + n_rows, + n_cols, + scratchpad, + BLOCK_S, + s_blocks, + scratchpad_partials, + BLOCK_SP, + sp_blocks, + sp_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + N_EXPTS_PAD=n_cols_pad, + N_EXPTS_ACT=k, + N_EXPTS_ACT_PAD=k_pow2, + SCORE_MODE=score_mode, + APPLY_RENORM=renorm, + ROUTED_SCALING=routed_scaling_factor, + num_warps=8, + ) + + bitmatrix_shape = [n_rows, n_cols_words * 32] + bitmatrix = Bitmatrix( + bitmatrix, + shape=bitmatrix_shape, + scratchpad=scratchpad, + scratchpad_partials=scratchpad_partials, + ) + return y_vals, y_indx, bitmatrix diff --git a/aiter/ops/triton/moe/reduce.py b/aiter/ops/triton/moe/reduce.py index 4f8d07043e..656beba2cb 100644 --- a/aiter/ops/triton/moe/reduce.py +++ b/aiter/ops/triton/moe/reduce.py @@ -1,3 +1,4 @@ +from typing import Optional import torch import triton from aiter.ops.triton._triton_kernels.moe.reduce import _reduce_grouped @@ -14,6 +15,7 @@ def reduce_grouped( reduction_n=1, out_dtype=None, add_residual: bool = True, + residual: Optional[torch.Tensor] = None, ): """ Grouped row reduction used during moe scatter and also compatible with split-k reduce. @@ -38,6 +40,10 @@ def reduce_grouped( """ if indx is None and x.shape[0] == 1: + assert residual is None, ( + "reduce_grouped early-return path can't apply external residual; " + "either rebuild routing with K>=1 or skip residual fold for this call" + ) return x.squeeze(0) if indx is not None: num_groups = indx.shape[0] @@ -49,6 +55,18 @@ def reduce_grouped( BLOCK_N = 512 num_blocks = triton.cdiv(x.shape[-1], BLOCK_N) + # Step 9: prep external residual buffer + strides for the kernel. + if residual is not None: + assert ( + residual.shape == out.shape + ), f"residual.shape {tuple(residual.shape)} must match out.shape {tuple(out.shape)}" + res_stride_m = residual.stride(0) + res_stride_n = residual.stride(1) + has_ext_residual = True + else: + res_stride_m = 0 + res_stride_n = 0 + has_ext_residual = False _reduce_grouped[(num_blocks * num_groups,)]( x, x.stride(0), @@ -71,6 +89,10 @@ def reduce_grouped( K=K, # ADD_RESIDUAL=add_residual, USE_TDM=is_tdm_avail(), + Residual=residual, + stride_extres_m=res_stride_m, + stride_extres_n=res_stride_n, + HAS_EXT_RESIDUAL=has_ext_residual, num_warps=2, # ) return out diff --git a/aiter/ops/triton/quant/__init__.py b/aiter/ops/triton/quant/__init__.py index 5a1eec8cac..8f2568a7cb 100644 --- a/aiter/ops/triton/quant/__init__.py +++ b/aiter/ops/triton/quant/__init__.py @@ -4,6 +4,9 @@ dynamic_per_token_quant_fp8_i8, dynamic_mxfp4_quant, _mxfp4_quant_op, + dynamic_mxfp8_quant, + fp8_legacy_to_mxfp8, + _mxfp8_quant_op, ) from .fused_fp8_quant import ( @@ -25,6 +28,12 @@ fused_dynamic_mxfp4_quant_moe_sort, ) +from .fused_mxfp8_quant import ( + fused_rms_mxfp8_quant, + fused_dual_rmsnorm_mxfp8_quant, + fused_flatten_mxfp8_quant, +) + __all__ = [ # quant.py exports "static_per_tensor_quant_fp8_i8", @@ -32,6 +41,9 @@ "dynamic_per_token_quant_fp8_i8", "dynamic_mxfp4_quant", "_mxfp4_quant_op", + "dynamic_mxfp8_quant", + "fp8_legacy_to_mxfp8", + "_mxfp8_quant_op", # fused_fp8_quant.py exports "calc_rows_per_block", "get_fp8_min_max_bounds", @@ -47,4 +59,8 @@ "fused_reduce_act_mul_and_mxfp4_quant", "fused_reduce_rms_mxfp4_quant", "fused_dynamic_mxfp4_quant_moe_sort", + # fused_mxfp8_quant.py exports + "fused_rms_mxfp8_quant", + "fused_dual_rmsnorm_mxfp8_quant", + "fused_flatten_mxfp8_quant", ] diff --git a/aiter/ops/triton/quant/fused_mxfp8_quant.py b/aiter/ops/triton/quant/fused_mxfp8_quant.py new file mode 100644 index 0000000000..7c0860e2c3 --- /dev/null +++ b/aiter/ops/triton/quant/fused_mxfp8_quant.py @@ -0,0 +1,218 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Optional, Tuple +import torch +import triton + +from aiter.ops.triton._triton_kernels.quant.fused_mxfp8_quant import ( + _fused_rms_mxfp8_kernel, + _fused_dual_rmsnorm_mxfp8_quant_kernel, + _fused_flatten_mxfp8_quant_kernel, +) + +__all__ = [ + "fused_rms_mxfp8_quant", + "fused_dual_rmsnorm_mxfp8_quant", + "fused_flatten_mxfp8_quant", +] + +_QUANT_BLOCK_SIZE = 32 + + +def fused_rms_mxfp8_quant( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, + y: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused RMSNorm + MXFP8 (1x32 e8m0) quant in a single Triton launch. + + Args: + x: (M, K) bf16 or fp16. + weight: (K,) bf16 or fp16 RMSNorm weight. + eps: RMSNorm epsilon. + y: optional preallocated FP8 e4m3fn output (M, K). + scale: optional preallocated uint8 e8m0 output (M, K // 32). + + Returns: + y (M, K) fp8 e4m3fn, scale (M, K // 32) uint8. + """ + assert x.dim() == 2, f"x must be 2D, got {x.dim()}" + M, K = x.shape + assert weight.shape == (K,), f"weight shape {weight.shape} != ({K},)" + assert K % _QUANT_BLOCK_SIZE == 0 + Ns = K // _QUANT_BLOCK_SIZE + BLOCK_SIZE_K = triton.next_power_of_2(K) + + if y is None: + y = torch.empty((M, K), dtype=torch.float8_e4m3fn, device=x.device) + if scale is None: + scale = torch.empty((M, Ns), dtype=torch.uint8, device=x.device) + + NUM_PRGMS = M + grid = (NUM_PRGMS,) + + _fused_rms_mxfp8_kernel[grid]( + x, + weight, + y, + scale, + M, + K, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + scale.stride(0), + scale.stride(1), + eps, + BLOCK_SIZE_K=BLOCK_SIZE_K, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS, + ) + return y, scale + + +def fused_dual_rmsnorm_mxfp8_quant( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + eps_q: float, + eps_k: Optional[float] = None, + yq: Optional[torch.Tensor] = None, + sq: Optional[torch.Tensor] = None, + yk: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fused dual RMSNorm in a single Triton launch. + + - Q side: RMSNorm(q, q_weight, eps_q) -> MXFP8 (FP8 e4m3fn + uint8 e8m0 1x32). + - K side: RMSNorm(k, k_weight, eps_k) -> bf16. + + Replaces the CK `fused_qk_rmsnorm_group_quant` kernel for the MXFP8 GEMM + path on V4 (Task #77): one launch instead of two (fused_rms_mxfp8_quant + + rmsnorm2d_fwd_), eliminating the ~6us/layer launch-overhead regression. + + Args: + q: (M, KQ) bf16 or fp16 — Q-side input (e.g. q_lora). + k: (M, KK) bf16 or fp16 — K-side input (e.g. kv_pre). + q_weight: (KQ,) bf16 or fp16 — Q RMSNorm weight. + k_weight: (KK,) bf16 or fp16 — K RMSNorm weight. + eps_q: Q RMSNorm epsilon. + eps_k: K RMSNorm epsilon; defaults to eps_q. + yq, sq, yk: optional pre-allocated outputs. + + Returns: + yq (M, KQ) fp8 e4m3fn, sq (M, KQ // 32) uint8 e8m0, yk (M, KK) bf16. + """ + assert q.dim() == 2, f"q must be 2D, got {q.dim()}" + assert k.dim() == 2, f"k must be 2D, got {k.dim()}" + M, KQ = q.shape + Mk, KK = k.shape + assert M == Mk, f"q rows {M} != k rows {Mk}" + assert q_weight.shape == (KQ,), f"q_weight shape {q_weight.shape} != ({KQ},)" + assert k_weight.shape == (KK,), f"k_weight shape {k_weight.shape} != ({KK},)" + assert ( + KQ % _QUANT_BLOCK_SIZE == 0 + ), f"KQ={KQ} must be a multiple of {_QUANT_BLOCK_SIZE}" + if eps_k is None: + eps_k = eps_q + + Ns = KQ // _QUANT_BLOCK_SIZE + BLOCK_SIZE_KQ = triton.next_power_of_2(KQ) + BLOCK_SIZE_KK = triton.next_power_of_2(KK) + + if yq is None: + yq = torch.empty((M, KQ), dtype=torch.float8_e4m3fn, device=q.device) + if sq is None: + sq = torch.empty((M, Ns), dtype=torch.uint8, device=q.device) + if yk is None: + yk = torch.empty((M, KK), dtype=k.dtype, device=k.device) + + NUM_PRGMS = M + grid = (NUM_PRGMS,) + + _fused_dual_rmsnorm_mxfp8_quant_kernel[grid]( + q, + k, + q_weight, + k_weight, + yq, + sq, + yk, + M, + KQ, + KK, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + yq.stride(0), + yq.stride(1), + sq.stride(0), + sq.stride(1), + yk.stride(0), + yk.stride(1), + eps_q, + eps_k, + BLOCK_SIZE_KQ=BLOCK_SIZE_KQ, + BLOCK_SIZE_KK=BLOCK_SIZE_KK, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS, + ) + return yq, sq, yk + + +def fused_flatten_mxfp8_quant( + x: torch.Tensor, + quant_dtype: torch.dtype = torch.float8_e4m3fn, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Flatten the last two dimensions of x and apply per-1x32 MXFP8 quant along + the flattened axis (FP8 e4m3 values + uint8 e8m0 scales). + + Equivalent in shape to `fused_flatten_fp8_group_quant` but emits MXFP8 + 1x32 (e8m0) scales using the same recipe as `dynamic_mxfp8_quant`. + + Args: + x: Input tensor of shape (M, N1, N2). N2 must be a multiple of 32. + quant_dtype: FP8 dtype to cast quantized values to (defaults to + torch.float8_e4m3fn). + + Returns: + Tuple of: + out: FP8 tensor of shape (M, N1 * N2). + out_scales: e8m0 (uint8) scale tensor of shape + (M, (N1 * N2) // 32). + """ + assert x.dim() == 3, f"x must be 3D, got {x.dim()}" + M, N1, N2 = x.shape + assert ( + N2 % _QUANT_BLOCK_SIZE == 0 + ), f"N2={N2} must be a multiple of {_QUANT_BLOCK_SIZE}" + + BLOCK_SIZE_N2 = max(triton.next_power_of_2(N2), _QUANT_BLOCK_SIZE) + N = N1 * N2 + + out = torch.empty((M, N), dtype=quant_dtype, device=x.device) + out_scales = torch.empty( + (M, N // _QUANT_BLOCK_SIZE), dtype=torch.uint8, device=x.device + ) + + grid = (M, N1) + _fused_flatten_mxfp8_quant_kernel[grid]( + x, + out, + out_scales, + *x.stride(), + *out.stride(), + *out_scales.stride(), + N2, + BLOCK_SIZE_N2=BLOCK_SIZE_N2, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + ) + + return out, out_scales diff --git a/aiter/ops/triton/quant/quant.py b/aiter/ops/triton/quant/quant.py index 0883d78df0..0258fa5e43 100644 --- a/aiter/ops/triton/quant/quant.py +++ b/aiter/ops/triton/quant/quant.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +from typing import Optional, Tuple + import triton import torch from aiter.ops.triton._triton_kernels.quant.quant import ( @@ -9,6 +11,9 @@ _dynamic_per_token_quant_fp8_i8_kernel, _dynamic_mxfp4_quant_kernel, _mxfp4_quant_op, + _dynamic_mxfp8_quant_kernel, + _mxfp8_quant_op, + _fp8_legacy_to_mxfp8_kernel, ) from aiter.ops.triton.utils.logger import AiterTritonLogger @@ -18,8 +23,14 @@ "dynamic_per_token_quant_fp8_i8", "dynamic_mxfp4_quant", "_mxfp4_quant_op", + "dynamic_mxfp8_quant", + "fp8_legacy_to_mxfp8", + "_mxfp8_quant_op", ] +_MXFP8_QUANT_BLOCK_SIZE = 32 +_MXFP8_LEGACY_BLOCK_SIZE = 128 + _LOGGER = AiterTritonLogger() @@ -214,3 +225,128 @@ def dynamic_mxfp4_quant( ) return (x_fp4, blockscale_e8m0) + + +def dynamic_mxfp8_quant( + x: torch.Tensor, + scale: Optional[torch.Tensor] = None, + quant_dtype: torch.dtype = torch.float8_e4m3fn, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Per-1x32 MXFP8 quantization (e8m0 scale + FP8 e4m3 values). + + Args: + x: Input tensor (..., K). Typically bf16 or fp16. K % 32 == 0. + scale: Pre-allocated scale tensor (M, K // 32) uint8. Optional. + quant_dtype: FP8 dtype to cast quantized values to. On MI3xx + torch.float8_e4m3fnuz is the canonical FP8 e4m3 type. torch.float8_e4m3fn + is acceptable on hardware that supports it. + + Returns: + Tuple of: + y: FP8 tensor of shape x.shape. + s: e8m0 (uint8) scale tensor of shape (..., K // 32). + """ + assert x.dim() >= 2, f"x must be at least 2D, got {x.dim()}" + orig_shape = x.shape + K = orig_shape[-1] + assert ( + K % _MXFP8_QUANT_BLOCK_SIZE == 0 + ), f"last dim K={K} must be a multiple of {_MXFP8_QUANT_BLOCK_SIZE}" + + x2d = x.reshape(-1, K).contiguous() + M = x2d.shape[0] + Ns = K // _MXFP8_QUANT_BLOCK_SIZE # number of scales per row + + y = torch.empty((M, K), dtype=quant_dtype, device=x.device) + if scale is None: + scale = torch.empty((M, Ns), dtype=torch.uint8, device=x.device) + else: + assert scale.shape == (M, Ns), f"scale shape {scale.shape} != ({M},{Ns})" + assert scale.dtype == torch.uint8 + + BLOCK_SIZE_N = triton.next_power_of_2(K) + NUM_PRGMS = M + grid = (NUM_PRGMS,) + + _dynamic_mxfp8_quant_kernel[grid]( + x2d, + y, + scale, + M, + K, + x2d.stride(0), + x2d.stride(1), + y.stride(0), + y.stride(1), + scale.stride(0), + scale.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + QUANT_BLOCK_SIZE=_MXFP8_QUANT_BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS, + ) + + y = y.view(*orig_shape[:-1], K) + s = scale.view(*orig_shape[:-1], Ns) + return y, s + + +def fp8_legacy_to_mxfp8( + x_fnuz: torch.Tensor, + x_scale_fp32: torch.Tensor, + y_fn: Optional[torch.Tensor] = None, + y_scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Transcode (FP8 e4m3fnuz, fp32 1x128 scale) -> (FP8 e4m3fn, e8m0 1x32 scale) + in a single Triton launch. Replaces the Python dequant+requant cascade + used when MXFP8 path receives legacy-formatted (FP8 + fp32 1x128) inputs. + + Args: + x_fnuz: FP8 e4m3fnuz tensor of shape (M, N), N % 32 == 0. + x_scale_fp32: fp32 scale of shape (M, N // 128). + y_fn: optional preallocated output FP8 e4m3fn tensor. + y_scale: optional preallocated uint8 e8m0 scale tensor. + + Returns: + y_fn (M, N) fp8 e4m3fn, y_scale (M, N // 32) uint8 e8m0. + """ + assert x_fnuz.dim() == 2, f"x must be 2D, got {x_fnuz.dim()}" + M, N = x_fnuz.shape + assert N % _MXFP8_QUANT_BLOCK_SIZE == 0 + assert N % _MXFP8_LEGACY_BLOCK_SIZE == 0 + assert x_scale_fp32.shape == ( + M, + N // _MXFP8_LEGACY_BLOCK_SIZE, + ), f"x_scale_fp32 shape {x_scale_fp32.shape} != ({M},{N // _MXFP8_LEGACY_BLOCK_SIZE})" + + Ns = N // _MXFP8_QUANT_BLOCK_SIZE + if y_fn is None: + y_fn = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x_fnuz.device) + if y_scale is None: + y_scale = torch.empty((M, Ns), dtype=torch.uint8, device=x_fnuz.device) + + BLOCK_SIZE_M = 1 + grid = (triton.cdiv(M, BLOCK_SIZE_M), Ns) + + _fp8_legacy_to_mxfp8_kernel[grid]( + x_fnuz, + x_scale_fp32, + y_fn, + y_scale, + M, + N, + x_fnuz.stride(0), + x_fnuz.stride(1), + x_scale_fp32.stride(0), + x_scale_fp32.stride(1), + y_fn.stride(0), + y_fn.stride(1), + y_scale.stride(0), + y_scale.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + QUANT_BLOCK_SIZE=_MXFP8_QUANT_BLOCK_SIZE, + LEGACY_BLOCK_SIZE=_MXFP8_LEGACY_BLOCK_SIZE, + ) + + return y_fn, y_scale diff --git a/aiter/ops/triton/utils/_triton/tunning/ut_afp8wfp8_gemm_preshuffle.py b/aiter/ops/triton/utils/_triton/tunning/ut_afp8wfp8_gemm_preshuffle.py new file mode 100644 index 0000000000..8f13ac4b3e --- /dev/null +++ b/aiter/ops/triton/utils/_triton/tunning/ut_afp8wfp8_gemm_preshuffle.py @@ -0,0 +1,44 @@ +import sys +from _utils import ( + run_profile, + get_input_shape_and_config_list, +) + +############################################################ +# +import torch +from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import gemm_afp8wfp8_preshuffle +from op_tests.triton_tests.gemm.basic.test_gemm_afp8wfp8 import ( + generate_inputs, +) +from aiter.ops.triton.utils.types import get_fp8_dtypes +from aiter.ops.triton.utils.gemm_config_utils import compute_splitk_params + +############################################################ + +input_shape, config_list = get_input_shape_and_config_list(sys.argv, shape_size=3) +M, N, K = input_shape + +############################################################ +# +_, e4m3_type = get_fp8_dtypes() +dtype = torch.bfloat16 +x_fp8, w_fp8, w_kernel, x_scales, w_scales = generate_inputs( + *input_shape, + shuffle=True, +) +############################################################ + +for config in config_list: + if config is not None: + compute_splitk_params(config, K) + + def fn(): + ############################################################ + # + gemm_afp8wfp8_preshuffle( + x_fp8, w_kernel, x_scales, w_scales, dtype=dtype, config=config + ) + ############################################################ + + run_profile(fn) diff --git a/op_tests/triton_tests/fusions/test_fused_clamp_act_mul.py b/op_tests/triton_tests/fusions/test_fused_clamp_act_mul.py index 7d0e0c4ff2..056b3ceae7 100644 --- a/op_tests/triton_tests/fusions/test_fused_clamp_act_mul.py +++ b/op_tests/triton_tests/fusions/test_fused_clamp_act_mul.py @@ -6,10 +6,12 @@ from aiter.ops.triton.fusions.fused_clamp_act_mul import ( fused_clamp_act_mul, ) +from aiter.utility import fp4_utils from op_tests.triton_tests.quant.test_fused_fp8_quant import ( per_token_fp8_group_quant, upcast, ) +from op_tests.triton_tests.gemm.basic.test_gemm_afp4wfp4 import un_shuffle_scales def _torch_reference(inp, swiglu_limit, weights, dtype_quant): @@ -101,3 +103,85 @@ def test_fused_clamp_act_mul( assert out.dtype == inp.dtype torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2) + + +def _torch_reference_ue8m0(inp, swiglu_limit, weights, dtype_quant, quant_block_size): + """Bit-exact torch model of the kernel's ue8m0 path: the exp2-based SiLU + (matching ``_silu_exp2``) in fp32 followed by per-group MXFP8 quant with + round-up e8m0 scales. Returns ``(out_q, unshuffled_scale)``.""" + gate, up = inp.chunk(2, dim=-1) + gate = gate.float() + up = up.float() + if swiglu_limit > 0: + up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit) + gate = torch.clamp(gate, max=swiglu_limit) + y = (gate / (1.0 + torch.exp2(-(gate * 1.44269504089)))) * up + if weights is not None: + y = weights * y + + M, N = y.shape + QB = quant_block_size + dtype_max = torch.finfo(dtype_quant).max + num_blocks = (N + QB - 1) // QB + y = y.view(M, num_blocks, QB) + max_val = y.abs().amax(dim=2, keepdim=True) + dequant_scale = max_val / dtype_max + # Round dequant_scale up to a power of two via the fp32 exponent field. + exp = (dequant_scale.view(torch.int32) + 0x007FFFFF) & 0x7F800000 + rounded = exp.view(torch.float32) + quant_scale = torch.where(rounded == 0, torch.zeros_like(rounded), 1.0 / rounded) + out_q = (y * quant_scale).view(M, N).to(dtype_quant) + scale = (exp >> 23).to(torch.uint8).view(M, num_blocks) + return out_q, scale + + +@pytest.mark.parametrize("M", [1, 2, 7, 32, 100, 257]) +@pytest.mark.parametrize("D", [2048, 3072]) +@pytest.mark.parametrize("swiglu_limit", [0.0, 7.0]) +@pytest.mark.parametrize("with_weights", [False, True]) +@pytest.mark.parametrize("shuffle_scale", [False, True]) +def test_fused_clamp_act_mul_ue8m0(M, D, swiglu_limit, with_weights, shuffle_scale): + """ue8m0 group quant. The fp8 output and e8m0 scales must match the torch + reference; when ``shuffle_scale`` is set the kernel must lay the scales out + exactly like ``fp4_utils.e8m0_shuffle`` applied to the unshuffled scales.""" + torch.manual_seed(42) + N = D // 2 + quant_block_size = 32 + dtype_quant = torch.float8_e4m3fn + w = ( + torch.randn(M, 1, device="cuda", dtype=torch.float32) * 0.5 + if with_weights + else None + ) + inp = torch.randn(M, D, device="cuda", dtype=torch.bfloat16) + + out_q, scale = fused_clamp_act_mul( + inp, + swiglu_limit=swiglu_limit, + weights=w, + activation="silu", + dtype_quant=dtype_quant, + quant_block_size=quant_block_size, + scale_dtype_fmt="ue8m0", + shuffle_scale=shuffle_scale, + ) + + ref_out, ref_scale = _torch_reference_ue8m0( + inp, swiglu_limit, w, dtype_quant, quant_block_size + ) + assert torch.equal(out_q.view(torch.uint8), ref_out.view(torch.uint8)) + + num_blocks = (N + quant_block_size - 1) // quant_block_size + if shuffle_scale: + # Kernel preshuffles in place; the reference shuffles with e8m0_shuffle. + # Both leave padding undefined, so undo the shuffle and compare the valid + # region (which also confirms the kernel layout matches e8m0_shuffle). + expected = fp4_utils.e8m0_shuffle(ref_scale) + assert scale.shape == expected.shape + sm = scale.shape[0] + got = un_shuffle_scales(scale.view(sm // 32, -1))[:M, :num_blocks] + exp = un_shuffle_scales(expected.view(sm // 32, -1))[:M, :num_blocks] + assert torch.equal(got, exp) + assert torch.equal(got, ref_scale) + else: + assert torch.equal(scale[:M, :num_blocks], ref_scale) diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py b/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py new file mode 100644 index 0000000000..8b3bf5eae8 --- /dev/null +++ b/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import pytest +import torch + +from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( + gemm_afp8wfp8, + gemm_afp8wfp8_preshuffle, +) +from aiter.ops.shuffle import shuffle_weight +import aiter.ops.triton.utils._triton.arch_info as arch_info + +SCALE_GROUP_SIZE = 32 # A: 1x32 e8m0 scale group +W_SCALE_K_GROUP = 128 # B: 128 in K direction +W_SCALE_N_GROUP = 128 # B: 128 in N direction +FP8_MAX = 448.0 # e4m3 max + + +def e8m0_to_f32(x: torch.Tensor) -> torch.Tensor: + """Decode unsigned-biased e8m0 (uint8) to fp32. Bias 127, value = 2^(b-127).""" + return torch.exp2((x.to(torch.int32) - 127).to(torch.float32)) + + +def generate_inputs(M: int, N: int, K: int, shuffle: bool = False): + """Returns ``(x_fp8, w_fp8, w_kernel, x_scales, w_scales)``. + + ``w_fp8`` is always the unshuffled weight (for use by the fp32 reference). + ``w_kernel`` is the weight to pass to the kernel: identical to ``w_fp8`` + when ``shuffle=False``, or shuffled via ``shuffle_weight(layout=(16, 16))`` + when ``shuffle=True``. + """ + # Small random fp32 → fp8 e4m3fn, kept inside e4m3 range so the cast is exact-ish. + x_f32 = torch.randn((M, K), dtype=torch.float32, device="cuda") + w_f32 = torch.randn((N, K), dtype=torch.float32, device="cuda") + x_f32 = torch.clamp(x_f32, -FP8_MAX, FP8_MAX) + w_f32 = torch.clamp(w_f32, -FP8_MAX, FP8_MAX) + x_fp8 = x_f32.to(torch.float8_e4m3fn) + w_fp8 = w_f32.to(torch.float8_e4m3fn) + + # e8m0 scales near 127 (== 1.0) so the dequant has unit-ish magnitude. + x_scales = torch.randint( + 125, 130, (M, K // SCALE_GROUP_SIZE), dtype=torch.uint8, device="cuda" + ) + w_scales = torch.randint( + 125, + 130, + (N // W_SCALE_N_GROUP, K // W_SCALE_K_GROUP), + dtype=torch.uint8, + device="cuda", + ) + + if shuffle: + # shuffle_weight operates on raw bytes; view as uint8 to avoid dtype quirks. + w_kernel = shuffle_weight(w_fp8.view(torch.uint8), layout=(16, 16)) + else: + w_kernel = w_fp8 + + return x_fp8, w_fp8, w_kernel, x_scales, w_scales + + +def run_torch_gemm_afp8wfp8( + x_fp8: torch.Tensor, + w_fp8: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: + """Reference: dequant both operands to fp32 and run torch.mm.""" + M, K = x_fp8.shape + N, _ = w_fp8.shape + + x_view = x_fp8 if x_fp8.dtype != torch.uint8 else x_fp8.view(torch.float8_e4m3fn) + w_view = w_fp8 if w_fp8.dtype != torch.uint8 else w_fp8.view(torch.float8_e4m3fn) + x_f32 = x_view.to(torch.float32) + w_f32 = w_view.to(torch.float32) + + x_s_f32 = e8m0_to_f32(x_scales).repeat_interleave(SCALE_GROUP_SIZE, dim=1) + assert x_s_f32.shape == (M, K) + + w_s_f32 = e8m0_to_f32(w_scales) + w_s_f32 = w_s_f32.repeat_interleave(W_SCALE_N_GROUP, dim=0).repeat_interleave( + W_SCALE_K_GROUP, dim=1 + ) + assert w_s_f32.shape == (N, K) + + x_dq = x_f32 * x_s_f32 + w_dq = w_f32 * w_s_f32 + return torch.mm(x_dq, w_dq.T).to(out_dtype) + + +def get_shapes(): + # (M, N, K), with N % 128 == 0 and K % 128 == 0 to fit the 128x128 W-scale layout. + return [ + (m, n, k) + for m in [1, 4, 8, 16, 32, 64, 128] + for n, k in [ + (1536, 4096), + (4096, 1024), + (512, 4096), + (8192, 1024), + (2048, 7168), + (7168, 2048), + (768, 7168), + (7168, 384), + (8192, 1536), + ] + ] + + +@pytest.mark.parametrize("M, N, K", get_shapes()) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_gemm_afp8wfp8(M: int, N: int, K: int, dtype: torch.dtype): + torch.manual_seed(0) + if not arch_info.is_fp8_avail(): + pytest.skip("MXFP8 GEMM requires FP8-capable arch") + torch.cuda.empty_cache() + + x_fp8, w_fp8, w_kernel, x_scales, w_scales = generate_inputs(M, N, K, shuffle=False) + + torch_out = run_torch_gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype) + triton_out = gemm_afp8wfp8(x_fp8, w_kernel, x_scales, w_scales, dtype=dtype) + + torch.testing.assert_close(triton_out, torch_out, atol=0.03, rtol=1e-2) + + +@pytest.mark.parametrize("M, N, K", get_shapes()) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_gemm_afp8wfp8_preshuffle(M: int, N: int, K: int, dtype: torch.dtype): + torch.manual_seed(0) + if not arch_info.is_fp8_avail(): + pytest.skip("MXFP8 GEMM requires FP8-capable arch") + if N % 16 != 0 or K % 32 != 0: + pytest.skip("Preshuffle requires N % 16 == 0 and K % 32 == 0") + torch.cuda.empty_cache() + + x_fp8, w_fp8, w_kernel, x_scales, w_scales = generate_inputs(M, N, K, shuffle=True) + + torch_out = run_torch_gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype) + triton_out = gemm_afp8wfp8_preshuffle( + x_fp8, w_kernel, x_scales, w_scales, dtype=dtype + ) + + torch.testing.assert_close(triton_out, torch_out, atol=0.03, rtol=1e-2) diff --git a/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_quant_x.py b/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_quant_x.py new file mode 100644 index 0000000000..4d74d04b8d --- /dev/null +++ b/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_quant_x.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import torch +import torch.nn.functional as F +import pytest + +from aiter.ops.triton.gemm.fused.fused_gemm_a16w16_quant_x import ( + fused_gemm_a16w16_quant_x, +) +from op_tests.triton_tests.gemm.basic.test_gemm_a16w16 import ( + generate_gemm_a16w16_inputs, +) + +_QUANT_BLOCK_SIZE = 32 +# 0xFF800000 in two's complement int32. Mask keeps sign + 8-bit exponent + top mantissa bit. +_E8M0_MASK_INT32 = -8388608 + + +def torch_mxfp8_quant_from_fp32(x_fp32: torch.Tensor): + """Bit-faithful port of `_dynamic_mxfp8_quant_kernel` quant logic, taking fp32 input. + + Computes per-1x32 e8m0 scale (uint8) and FP8 e4m3fn values. + """ + assert x_fp32.dim() == 2, f"x_fp32 must be 2D, got {x_fp32.dim()}" + M, K = x_fp32.shape + assert K % _QUANT_BLOCK_SIZE == 0 + Ng = K // _QUANT_BLOCK_SIZE + x_2d = x_fp32.reshape(M, Ng, _QUANT_BLOCK_SIZE).to(torch.float32) + amax = torch.amax(torch.abs(x_2d), dim=-1, keepdim=True) # (M, Ng, 1) + + amax_i32 = amax.contiguous().view(torch.int32) + amax_i32 = (amax_i32 + 0x200000) & _E8M0_MASK_INT32 + amax_p2 = amax_i32.view(torch.float32) + + scale_unbiased = torch.log2(amax_p2).floor() - 8 + scale_unbiased = torch.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(torch.int32) + 127).to(torch.uint8) + quant_scale = torch.exp2(-scale_unbiased) + + qx_2d = x_2d * quant_scale + qx = qx_2d.reshape(M, K) + y_fp8 = qx.to(torch.float8_e4m3fn) + s = scale_e8m0.reshape(M, Ng) + return y_fp8, s + + +def get_x_vals(): + x_vals = [(1024, 1024, 1024)] + x_vals += [(2048, 2048, 2048)] + # DSv4 router gate: num_tokens x 384 x 7168 + x_vals += [(2**i, 384, 7168) for i in range(5, 9)] + # DSR1 router GEMM + x_vals += [(2**i, 256, 7168) for i in range(5, 9)] + return x_vals + + +def _assert_quant_close(triton_x_quant, triton_x_scales, x): + ref_x_quant, ref_x_scales = torch_mxfp8_quant_from_fp32(x.to(torch.float32)) + # e8m0 scales: bit-exact (integer-only after fp32 cast). + torch.testing.assert_close(triton_x_scales, ref_x_scales) + # Quantized values: compare via uint8 view (allow off-by-1 for any rounding + # subtlety in the fp32->fp8 cast). + torch.testing.assert_close( + triton_x_quant.view(torch.uint8).to(torch.int32), + ref_x_quant.view(torch.uint8).to(torch.int32), + atol=1, + rtol=0, + ) + + +@pytest.mark.parametrize("M, N, K", get_x_vals()) +def test_fused_gemm_a16w16_quant_x(M: int, N: int, K: int): + torch.cuda.empty_cache() + x, w, _, _, _ = generate_gemm_a16w16_inputs( + M, N, K, dtype=torch.bfloat16, output=False + ) + + torch_y = F.linear(x, w, bias=None) + + triton_y, triton_x_quant, triton_x_scales = fused_gemm_a16w16_quant_x(x, w) + + torch.testing.assert_close(triton_y, torch_y, atol=1e-1, rtol=1e-2) + _assert_quant_close(triton_x_quant, triton_x_scales, x) + + +def get_fewer_x_vals(): + x_vals = [(16, 1024, 1024)] + x_vals += [(128, 8192, 512)] + x_vals += [(256, 512, 8192)] + x_vals += [(1024, 1024, 1024)] + return x_vals + + +@pytest.mark.parametrize("activation", ["gelu", "gelu_tanh", "silu"]) +@pytest.mark.parametrize("M, N, K", get_fewer_x_vals()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("output", [True, False]) +def test_fused_gemm_a16w16_quant_x_activation( + M: int, N: int, K: int, dtype, output, activation +): + x, w, _, _, y = generate_gemm_a16w16_inputs(M, N, K, dtype, output=output) + + torch_y = F.linear(x, w, bias=None) + if activation == "gelu": + torch_y = F.gelu(torch_y) + elif activation == "gelu_tanh": + torch_y = F.gelu(torch_y, approximate="tanh") + elif activation == "silu": + torch_y = F.silu(torch_y) + + triton_y, triton_x_quant, triton_x_scales = fused_gemm_a16w16_quant_x( + x, + w, + bias=None, + dtype=dtype, + y=y, + activation=activation, + ) + + torch.testing.assert_close(triton_y, torch_y, atol=1e-1, rtol=1e-2) + _assert_quant_close(triton_x_quant, triton_x_scales, x) + + +@pytest.mark.parametrize("M, N, K", get_fewer_x_vals()) +@pytest.mark.parametrize("skip_reduce", [True, False]) +def test_fused_gemm_a16w16_quant_x_skip_reduce(M: int, N: int, K: int, skip_reduce): + torch.cuda.empty_cache() + x, w, _, _, _ = generate_gemm_a16w16_inputs( + M, N, K, dtype=torch.bfloat16, output=False + ) + + torch_y = F.linear(x, w, bias=None) + + triton_y, triton_x_quant, triton_x_scales = fused_gemm_a16w16_quant_x( + x, w, skip_reduce=skip_reduce + ) + + if triton_y.dim() == 3: + triton_y = triton_y.sum(axis=0).to(torch.bfloat16) + + torch.testing.assert_close(triton_y, torch_y, atol=1e-3, rtol=1e-2) + _assert_quant_close(triton_x_quant, triton_x_scales, x) diff --git a/op_tests/triton_tests/moe/test_moe_routing.py b/op_tests/triton_tests/moe/test_moe_routing.py index 02afc8824d..6208fed867 100644 --- a/op_tests/triton_tests/moe/test_moe_routing.py +++ b/op_tests/triton_tests/moe/test_moe_routing.py @@ -1,12 +1,22 @@ import pytest import torch -from aiter.ops.triton.moe.moe_routing.routing import routing, routing_torch +import torch.nn.functional as F +from aiter.ops.triton.moe.moe_routing.routing import ( + routing, + routing_a8w4, + routing_a8w4_from_hash, + routing_a8w4_from_topk, + routing_torch, + compute_expt_data_torch, +) from aiter.ops.triton.utils._triton.arch_info import get_arch def assert_equal(ref, tri): if isinstance(ref, torch.Tensor): - assert torch.all(ref == tri) + # CI may be failing using this: + # assert torch.all(ref == tri) + assert ((ref.cpu().numpy() - tri.cpu().numpy()) ** 2).sum() == 0 else: assert ref == tri @@ -96,7 +106,7 @@ def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"): [(128, 4), (128, 6), (128, 32), (1500, 8), (256, 8), (8, 2)], ) @pytest.mark.parametrize("sm_first", [True, False]) -def test_op(n_tokens, n_expts_tot, n_expts_act, sm_first): +def test_routing(n_tokens, n_expts_tot, n_expts_act, sm_first): if get_arch() not in ["gfx950", "gfx1250"]: pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.") @@ -139,6 +149,430 @@ def _assert_indx_equal(ref, tri): _assert_indx_equal(ref_scatter, tri_scatter) +# -------------------------- +# Reference implementations for routing_a8w4* paths +# -------------------------- + + +def _score_transform_torch(logits, score_mode): + if score_mode == "sqrtsoftplus": + return torch.sqrt(F.softplus(logits.to(torch.float32))).to(logits.dtype) + # "softmax" mode in the kernel means "no pre-transform" (identity) + return logits + + +def _sort_and_build_torch(expt_scal, expt_indx, n_expts_tot, block_m): + """Mirror of the post-topk sort_tokens + ExptData build, in pytorch. + + expt_scal, expt_indx: shape (n_tokens, n_expts_act) — per-row order is + preserved (we do NOT sort experts per row here; that's the caller's + responsibility if needed). + Returns (hist, topk_indx, gate_indx, gate_scal, expt_data) matching the + triton sort_tokens contract. + """ + n_tokens, n_expts_act = expt_scal.shape + n_gates = n_tokens * n_expts_act + scal_flat = expt_scal.reshape(-1) + indx_flat = expt_indx.reshape(-1).to(torch.int32) + topk_indx = torch.argsort(indx_flat, stable=True).to(torch.int32) + gate_indx = torch.argsort(topk_indx, stable=True).to(torch.int32) + gate_scal = scal_flat[topk_indx.long()] + hist = torch.histc( + indx_flat.float(), bins=n_expts_tot, min=0, max=n_expts_tot - 1 + ).int() + expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates, block_m) + return hist, topk_indx, gate_indx, gate_scal, expt_data + + +def routing_a8w4_torch( + logits, + n_expts_act, + block_m, + *, + score_mode="sqrtsoftplus", + bias=None, + renorm=True, + routed_scaling_factor=1.0, +): + n_tokens, n_expts_tot = logits.shape + + # 1. Score transform; bias added only for selection. + transformed_f32 = _score_transform_torch(logits, score_mode).to(torch.float32) + if bias is not None: + biased = transformed_f32 + bias.to(torch.float32) + else: + biased = transformed_f32 + + # 2. Top-k selection (by biased score), then sort experts ascending per row + # — this matches streaming_topk's final per-row sort. + _, topk_ids = torch.topk(biased, n_expts_act, dim=1) + topk_ids, _ = torch.sort(topk_ids, dim=1) + + # 3. Gather the UNBIASED transformed value at the selected positions. + expt_scal = torch.gather(transformed_f32, 1, topk_ids) + + # 4. Renorm + scale (or just scale). + if renorm: + s = expt_scal.sum(dim=1, keepdim=True) + expt_scal = expt_scal / (s + 1e-20) * routed_scaling_factor + elif routed_scaling_factor != 1.0: + expt_scal = expt_scal * routed_scaling_factor + + expt_scal = expt_scal.to(logits.dtype) + topk_ids = topk_ids.to(torch.int16) + return _sort_and_build_torch(expt_scal, topk_ids, n_expts_tot, block_m) + + +def routing_a8w4_from_hash_torch( + router_logits, + tid2eid, + input_ids, + n_expts_act, + block_m, + *, + score_mode="sqrtsoftplus", + renorm=True, + routed_scaling_factor=1.0, +): + n_tokens, n_expts_tot = router_logits.shape + iid = input_ids.to(torch.int64) + # Expert ids come straight from the table — no per-row sort. + expt_indx = tid2eid[iid, :n_expts_act].to(torch.int32) + + # Score transform on the full row, then gather the K weights. + transformed_f32 = _score_transform_torch(router_logits, score_mode).to( + torch.float32 + ) + expt_scal = torch.gather(transformed_f32, 1, expt_indx.to(torch.int64)) + + if renorm: + s = expt_scal.sum(dim=1, keepdim=True) + expt_scal = expt_scal / (s + 1e-20) * routed_scaling_factor + elif routed_scaling_factor != 1.0: + expt_scal = expt_scal * routed_scaling_factor + + expt_scal = expt_scal.to(router_logits.dtype) + expt_indx = expt_indx.to(torch.int16) + return _sort_and_build_torch(expt_scal, expt_indx, n_expts_tot, block_m) + + +def routing_a8w4_from_topk_torch(topk_weights, topk_ids, n_expts_tot, block_m): + return _sort_and_build_torch( + topk_weights, + topk_ids.to(torch.int16), + n_expts_tot, + block_m, + ) + + +def _check_routing_data(ref_pack, tri_routing_data, tri_gather, tri_scatter): + """Strict equality check: works when the triton sort and stable argsort + agree on intra-bucket order (the sort_tokens / sort_tokens_fused path).""" + ref_hist, ref_topk_indx, ref_gate_indx, ref_gate_scal, ref_expt_data = ref_pack + assert_close(ref_gate_scal, tri_routing_data.gate_scal, 2e-2, 4e-3) + assert_equal(ref_hist, tri_routing_data.expt_hist) + assert_equal(ref_expt_data.hist, tri_routing_data.expt_data.hist) + assert_equal( + ref_expt_data.token_offs_raw, tri_routing_data.expt_data.token_offs_raw + ) + assert_equal( + ref_expt_data.token_offs_pad, tri_routing_data.expt_data.token_offs_pad + ) + assert_equal(ref_expt_data.block_pid_map, tri_routing_data.expt_data.block_pid_map) + assert_equal(ref_topk_indx, tri_gather) + assert_equal(ref_gate_indx, tri_scatter) + + +def _check_routing_data_bucket( + ref_pack, + tri_routing_data, + tri_gather, + tri_scatter, + topk_weights, + topk_ids, +): + """Bucket-multiset check for the fused_routing_from_topk sort path, which + uses a different stable tie-breaking than torch.argsort. Validates the + histogram + ExptData strictly, then compares per-expert (token, weight) + multisets and the inverse-permutation invariant. + """ + ref_hist, _, _, _, ref_expt_data = ref_pack + assert_equal(ref_hist, tri_routing_data.expt_hist) + assert_equal(ref_expt_data.hist, tri_routing_data.expt_data.hist) + assert_equal( + ref_expt_data.token_offs_raw, tri_routing_data.expt_data.token_offs_raw + ) + assert_equal( + ref_expt_data.token_offs_pad, tri_routing_data.expt_data.token_offs_pad + ) + assert_equal(ref_expt_data.block_pid_map, tri_routing_data.expt_data.block_pid_map) + + n_tokens, n_expts_act = topk_ids.shape + n_gates = n_tokens * n_expts_act + n_expts_tot = ref_hist.numel() + + # Inverse permutation invariant: gate_indx[topk_indx[j]] == j. + iota = torch.arange(n_gates, dtype=torch.int32, device=tri_gather.device) + assert torch.equal(tri_scatter[tri_gather.long()], iota), "scatter[gather[j]] != j" + + # Per-expert (token, weight) multisets. + flat_ids = topk_ids.reshape(-1).cpu().tolist() + flat_w = topk_weights.reshape(-1).float().cpu().tolist() + src = tri_gather.cpu().tolist() + scal = tri_routing_data.gate_scal.float().cpu().tolist() + cum = torch.cumsum(ref_hist, dim=0).cpu().tolist() + + ground = {e: [] for e in range(n_expts_tot)} + for i, e in enumerate(flat_ids): + token = i // n_expts_act + ground[e].append((token, flat_w[i])) + for e in ground: + ground[e].sort() + + got = {e: [] for e in range(n_expts_tot)} + e = 0 + for j in range(n_gates): + while e < n_expts_tot and j >= cum[e]: + e += 1 + token = src[j] // n_expts_act + # Bucket invariant: at expert-sorted position j inside expert e's + # slice, the source (token, slot) must reference expert e. + assert flat_ids[src[j]] == e, ( + f"bucket-invariant violated at pos {j}: source flat={src[j]} " + f"has expert {flat_ids[src[j]]}, expected {e}" + ) + got[e].append((token, scal[j])) + for e in got: + got[e].sort() + + for e in range(n_expts_tot): + rb, tb = ground[e], got[e] + assert len(rb) == len(tb), f"expert {e}: ref={len(rb)} test={len(tb)}" + for (tt_r, w_r), (tt_t, w_t) in zip(rb, tb): + assert tt_r == tt_t, f"expert {e}: token ref={tt_r} test={tt_t}" + assert ( + abs(w_r - w_t) <= 1e-6 + ), f"expert {e} token {tt_r}: weight ref={w_r} test={w_t}" + + +# -------------------------- +# routing_a8w4 +# -------------------------- + + +@pytest.mark.parametrize( + "n_tokens, n_expts_tot, n_expts_act", + [ + (8, 128, 4), # tiny: hits sort_tokens_fused path (n_tokens <= 16) + (16, 128, 4), # boundary + (64, 128, 4), + (1024, 128, 4), + (1024, 256, 8), + ], +) +@pytest.mark.parametrize( + "score_mode, has_bias, renorm, routed_scaling_factor", + [ + ("sqrtsoftplus", True, True, 2.5), # full V4 noaux_tc path + ("sqrtsoftplus", True, False, 1.0), # bias, no renorm + ("sqrtsoftplus", False, True, 1.0), # no bias + ("softmax", False, False, 1.0), # identity transform, no renorm + ], +) +@pytest.mark.parametrize("block_m", [16, 32]) +def test_routing_a8w4( + n_tokens, + n_expts_tot, + n_expts_act, + score_mode, + has_bias, + renorm, + routed_scaling_factor, + block_m, +): + if get_arch() not in ["gfx950", "gfx1250"]: + pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.") + + device = "cuda" + torch.manual_seed(2) + logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + bias = ( + torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 + if has_bias + else None + ) + + ref_pack = routing_a8w4_torch( + logits.clone(), + n_expts_act, + block_m, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=routed_scaling_factor, + ) + tri_routing_data, tri_gather, tri_scatter = routing_a8w4( + logits, + n_expts_act, + block_m, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=routed_scaling_factor, + ) + + _check_routing_data(ref_pack, tri_routing_data, tri_gather, tri_scatter) + assert tri_routing_data.n_expts_tot == n_expts_tot + assert tri_routing_data.n_expts_act == n_expts_act + assert tri_routing_data.block_m == block_m + + +# -------------------------- +# routing_a8w4_from_hash +# -------------------------- + + +@pytest.mark.parametrize( + "n_tokens, n_expts_tot, n_expts_act", + [ + (8, 128, 4), + (64, 128, 4), + (1024, 256, 8), + ], +) +@pytest.mark.parametrize( + "renorm, routed_scaling_factor", + [ + (True, 2.5), # production V4 hash config + (True, 1.0), + (False, 1.0), + ], +) +@pytest.mark.parametrize("block_m", [16, 32]) +def test_routing_a8w4_from_hash( + n_tokens, + n_expts_tot, + n_expts_act, + renorm, + routed_scaling_factor, + block_m, +): + if get_arch() not in ["gfx950", "gfx1250"]: + pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.") + + device = "cuda" + torch.manual_seed(2) + vocab_size = 512 + router_logits = torch.randn( + n_tokens, n_expts_tot, dtype=torch.float32, device=device + ) + # Distinct experts per vocab entry (production V4 hash table contract). + # Avoids within-row duplicates that would make intra-bucket ordering + # implementation-defined between the triton sort and torch.argsort. + tid2eid = torch.stack( + [ + torch.randperm(n_expts_tot, device=device)[:n_expts_act] + for _ in range(vocab_size) + ], + dim=0, + ).to(torch.int32) + input_ids = torch.randint( + 0, vocab_size, (n_tokens,), dtype=torch.int32, device=device + ) + + ref_pack = routing_a8w4_from_hash_torch( + router_logits.clone(), + tid2eid, + input_ids, + n_expts_act, + block_m, + score_mode="sqrtsoftplus", + renorm=renorm, + routed_scaling_factor=routed_scaling_factor, + ) + tri_routing_data, tri_gather, tri_scatter = routing_a8w4_from_hash( + router_logits, + tid2eid, + input_ids, + n_expts_act, + block_m, + score_mode="sqrtsoftplus", + renorm=renorm, + routed_scaling_factor=routed_scaling_factor, + ) + + _check_routing_data(ref_pack, tri_routing_data, tri_gather, tri_scatter) + assert tri_routing_data.n_expts_tot == n_expts_tot + assert tri_routing_data.n_expts_act == n_expts_act + assert tri_routing_data.block_m == block_m + + +# -------------------------- +# routing_a8w4_from_topk +# -------------------------- + + +# fused_routing_from_topk requires n_tokens * n_expts_act <= 4096. +@pytest.mark.parametrize( + "n_tokens, n_expts_tot, n_expts_act", + [ + (8, 128, 4), + (64, 128, 4), + (256, 128, 4), + (256, 256, 8), + (512, 128, 8), + ], +) +@pytest.mark.parametrize("block_m", [16, 32]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_routing_a8w4_from_topk( + n_tokens, + n_expts_tot, + n_expts_act, + block_m, + dtype, +): + if get_arch() not in ["gfx950", "gfx1250"]: + pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.") + + device = "cuda" + torch.manual_seed(2) + topk_weights = torch.randn(n_tokens, n_expts_act, dtype=dtype, device=device).abs() + # Per-row unique expert ids (the natural V4 case). + topk_ids = torch.stack( + [ + torch.randperm(n_expts_tot, device=device)[:n_expts_act] + for _ in range(n_tokens) + ], + dim=0, + ).to(torch.int32) + + ref_pack = routing_a8w4_from_topk_torch( + topk_weights.clone(), + topk_ids.clone(), + n_expts_tot, + block_m, + ) + tri_routing_data, tri_gather, tri_scatter = routing_a8w4_from_topk( + topk_weights, + topk_ids, + n_expts_tot, + block_m, + ) + + _check_routing_data_bucket( + ref_pack, + tri_routing_data, + tri_gather, + tri_scatter, + topk_weights, + topk_ids, + ) + assert tri_routing_data.n_expts_tot == n_expts_tot + assert tri_routing_data.n_expts_act == n_expts_act + assert tri_routing_data.block_m == block_m + + def bench_routing(): import triton.profiler as proton diff --git a/op_tests/triton_tests/quant/test_fused_fp8_quant.py b/op_tests/triton_tests/quant/test_fused_fp8_quant.py index 96d4f31fd0..c483b63b76 100644 --- a/op_tests/triton_tests/quant/test_fused_fp8_quant.py +++ b/op_tests/triton_tests/quant/test_fused_fp8_quant.py @@ -340,7 +340,7 @@ def run_torch_flatten_fp8_group_quant(x, dtype_quant, group_size): @pytest.mark.parametrize("M", [1, 32, 256]) -@pytest.mark.parametrize("N1, N2", [(16, 128)]) +@pytest.mark.parametrize("N1, N2", [(16, 128), (16, 7168)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_fused_flatten_fp8_group_quant(M: int, N1: int, N2: int, dtype): group_size = 128 diff --git a/op_tests/triton_tests/quant/test_quant_mxfp8.py b/op_tests/triton_tests/quant/test_quant_mxfp8.py new file mode 100644 index 0000000000..5bdeea8fc7 --- /dev/null +++ b/op_tests/triton_tests/quant/test_quant_mxfp8.py @@ -0,0 +1,439 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import pytest +import torch + +from aiter.ops.triton.quant.quant import ( + dynamic_mxfp8_quant, + fp8_legacy_to_mxfp8, +) +from aiter.ops.triton.quant.fused_mxfp8_quant import ( + fused_rms_mxfp8_quant, + fused_dual_rmsnorm_mxfp8_quant, + fused_flatten_mxfp8_quant, +) +import aiter.ops.triton.utils._triton.arch_info as arch_info + +QUANT_BLOCK_SIZE = 32 +LEGACY_BLOCK_SIZE = 128 +# 0xFF800000 in two's complement int32. Mask keeps sign + 8-bit exponent + top mantissa bit. +_E8M0_MASK_INT32 = -8388608 + + +def torch_mxfp8_quant_from_fp32(x_fp32: torch.Tensor): + """Bit-faithful port of `_dynamic_mxfp8_quant_kernel` quant logic, taking fp32 input. + + Computes per-1x32 e8m0 scale (uint8) and FP8 e4m3fn values. + """ + assert x_fp32.dim() == 2, f"x_fp32 must be 2D, got {x_fp32.dim()}" + M, K = x_fp32.shape + assert K % QUANT_BLOCK_SIZE == 0 + Ng = K // QUANT_BLOCK_SIZE + x_2d = x_fp32.reshape(M, Ng, QUANT_BLOCK_SIZE).to(torch.float32) + amax = torch.amax(torch.abs(x_2d), dim=-1, keepdim=True) # (M, Ng, 1) + + # Same bit-level "round up to e8m0-representable pow-2" as the kernel. + amax_i32 = amax.contiguous().view(torch.int32) + amax_i32 = (amax_i32 + 0x200000) & _E8M0_MASK_INT32 + amax_p2 = amax_i32.view(torch.float32) + + scale_unbiased = torch.log2(amax_p2).floor() - 8 + scale_unbiased = torch.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(torch.int32) + 127).to(torch.uint8) + quant_scale = torch.exp2(-scale_unbiased) + + qx_2d = x_2d * quant_scale # broadcast over inner-32 + qx = qx_2d.reshape(M, K) + y_fp8 = qx.to(torch.float8_e4m3fn) + s = scale_e8m0.reshape(M, Ng) + return y_fp8, s + + +def e8m0_to_f32(x: torch.Tensor) -> torch.Tensor: + return torch.exp2((x.to(torch.int32) - 127).to(torch.float32)) + + +# ----------------------------------------------------------------------------- +# dynamic_mxfp8_quant +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "M, K", + [ + (1, 32), + (1, 64), + (1, 128), + (2, 32), + (8, 64), + (16, 128), + (32, 256), + (64, 512), + (128, 1024), + (137, 64), # non-power-of-2 M + (256, 32), + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_per_1x32_mxfp8_quant(M: int, K: int, dtype: torch.dtype): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(20) + + x = torch.randn((M, K), dtype=dtype, device="cuda") * 4.0 + + # Reference path: emulate the kernel in fp32 (matching its precision). + x_fp32 = x.to(torch.float32) + y_ref, s_ref = torch_mxfp8_quant_from_fp32(x_fp32) + + # Triton path. + y_kern, s_kern = dynamic_mxfp8_quant(x) + + # Scales must be bit-exact: the e8m0 derivation is integer-only after + # the fp32 cast, and amax is order-independent. + torch.testing.assert_close(s_kern, s_ref) + + # Quantized values: compare via the uint8 view (allow off-by-1 for any + # rounding-mode subtlety in the fp32→fp8 cast). + torch.testing.assert_close( + y_kern.view(torch.uint8).to(torch.int32), + y_ref.view(torch.uint8).to(torch.int32), + atol=1, + rtol=0, + ) + + +def test_per_1x32_mxfp8_quant_preallocated_scale(): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(20) + + M, K = 64, 256 + x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda") + scale_pre = torch.empty( + (M, K // QUANT_BLOCK_SIZE), dtype=torch.uint8, device="cuda" + ) + y, s = dynamic_mxfp8_quant(x, scale=scale_pre) + assert s.data_ptr() == scale_pre.data_ptr() + + y_ref, s_ref = torch_mxfp8_quant_from_fp32(x.to(torch.float32)) + torch.testing.assert_close(s, s_ref) + + +def test_per_1x32_mxfp8_quant_multidim(): + """Wrapper folds higher dims into M; sanity-check 3D input.""" + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(0) + + B, M, K = 4, 8, 128 + x = torch.randn((B, M, K), dtype=torch.bfloat16, device="cuda") + y, s = dynamic_mxfp8_quant(x) + assert y.shape == (B, M, K) + assert s.shape == (B, M, K // QUANT_BLOCK_SIZE) + + y_ref, s_ref = torch_mxfp8_quant_from_fp32(x.reshape(-1, K).to(torch.float32)) + torch.testing.assert_close(s.reshape(-1, K // QUANT_BLOCK_SIZE), s_ref) + + +# ----------------------------------------------------------------------------- +# fp8_legacy_to_mxfp8 +# ----------------------------------------------------------------------------- + + +def torch_fp8_legacy_to_mxfp8(x_fnuz: torch.Tensor, x_scale_fp32: torch.Tensor): + """Reference: dequantize fnuz fp8 with the 1x128 fp32 scale, then run + the standard mxfp8 1x32 quant on the result.""" + M, N = x_fnuz.shape + x_dq = x_fnuz.to(torch.float32) * x_scale_fp32.repeat_interleave( + LEGACY_BLOCK_SIZE, dim=1 + ) + return torch_mxfp8_quant_from_fp32(x_dq) + + +@pytest.mark.parametrize( + "M, N", + [ + (1, 128), + (8, 128), + (16, 256), + (32, 512), + (64, 1024), + (128, 256), + (37, 256), # non-pow-2 M + ], +) +def test_fp8_legacy_to_mxfp8(M: int, N: int): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(5) + + # Random values within e4m3fnuz range, then cast to fnuz fp8. + x_f32 = (torch.randn((M, N), dtype=torch.float32, device="cuda")).clamp(-200, 200) + x_fnuz = x_f32.to(torch.float8_e4m3fnuz) + # Random fp32 1x128 scales in a moderate range so the dequant stays within fp8. + x_scale_fp32 = ( + torch.rand((M, N // LEGACY_BLOCK_SIZE), dtype=torch.float32, device="cuda") + * 0.5 + + 0.25 + ) + + y_ref, s_ref = torch_fp8_legacy_to_mxfp8(x_fnuz, x_scale_fp32) + y_kern, s_kern = fp8_legacy_to_mxfp8(x_fnuz, x_scale_fp32) + + torch.testing.assert_close(s_kern, s_ref) + torch.testing.assert_close( + y_kern.view(torch.uint8).to(torch.int32), + y_ref.view(torch.uint8).to(torch.int32), + atol=1, + rtol=0, + ) + + +def test_fp8_legacy_to_mxfp8_preallocated(): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(5) + + M, N = 16, 256 + x_fnuz = (torch.randn((M, N), device="cuda") * 4).to(torch.float8_e4m3fnuz) + x_scale_fp32 = torch.rand((M, N // LEGACY_BLOCK_SIZE), device="cuda") * 0.5 + 0.25 + y_pre = torch.empty((M, N), dtype=torch.float8_e4m3fn, device="cuda") + s_pre = torch.empty((M, N // QUANT_BLOCK_SIZE), dtype=torch.uint8, device="cuda") + y, s = fp8_legacy_to_mxfp8(x_fnuz, x_scale_fp32, y_fn=y_pre, y_scale=s_pre) + assert y.data_ptr() == y_pre.data_ptr() + assert s.data_ptr() == s_pre.data_ptr() + + y_ref, s_ref = torch_fp8_legacy_to_mxfp8(x_fnuz, x_scale_fp32) + torch.testing.assert_close(s, s_ref) + + +# ----------------------------------------------------------------------------- +# fused_rms_mxfp8_quant +# ----------------------------------------------------------------------------- + + +def torch_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: + x_f32 = x.to(torch.float32) + g_f32 = weight.to(torch.float32) + rstd = torch.rsqrt(x_f32.pow(2).mean(-1, keepdim=True) + eps) + return x_f32 * rstd * g_f32 + + +def torch_rmsnorm_mxfp8_quant(x, weight, eps): + y_fp32 = torch_rmsnorm(x, weight, eps) + return torch_mxfp8_quant_from_fp32(y_fp32) + + +@pytest.mark.parametrize( + "M, K", + [ + (1, 32), + (1, 128), + (8, 128), + (16, 256), + (32, 512), + (64, 1024), + (128, 2048), + (97, 64), # non-pow-2 M, K=64 + (200, 192), # non-pow-2 K (still multiple of 32) + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_rmsnorm_mxfp8_quant(M: int, K: int, dtype: torch.dtype): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(11) + + x = torch.randn((M, K), dtype=dtype, device="cuda") + weight = torch.randn((K,), dtype=dtype, device="cuda") * 0.5 + 1.0 + eps = 1e-5 + + y_ref, s_ref = torch_rmsnorm_mxfp8_quant(x, weight, eps) + y_kern, s_kern = fused_rms_mxfp8_quant(x, weight, eps) + + # Hardware rsqrt vs torch.rsqrt can disagree by a ULP; that may flip a single + # e8m0 bin near a power-of-2 boundary. Compare dequantized values instead. + s_ref_f32 = e8m0_to_f32(s_ref).repeat_interleave(QUANT_BLOCK_SIZE, dim=1) + s_kern_f32 = e8m0_to_f32(s_kern).repeat_interleave(QUANT_BLOCK_SIZE, dim=1) + y_ref_dq = y_ref.to(torch.float32) * s_ref_f32 + y_kern_dq = y_kern.to(torch.float32) * s_kern_f32 + + torch.testing.assert_close(y_kern_dq, y_ref_dq, atol=5e-2, rtol=5e-2) + + +def test_rmsnorm_mxfp8_quant_preallocated(): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(11) + + M, K = 32, 256 + x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda") + weight = torch.randn((K,), dtype=torch.bfloat16, device="cuda") + y_pre = torch.empty((M, K), dtype=torch.float8_e4m3fn, device="cuda") + s_pre = torch.empty((M, K // QUANT_BLOCK_SIZE), dtype=torch.uint8, device="cuda") + y, s = fused_rms_mxfp8_quant(x, weight, 1e-5, y=y_pre, scale=s_pre) + assert y.data_ptr() == y_pre.data_ptr() + assert s.data_ptr() == s_pre.data_ptr() + + +# ----------------------------------------------------------------------------- +# fused_dual_rmsnorm_mxfp8_quant +# ----------------------------------------------------------------------------- + + +def torch_dual_rmsnorm_mxfp8_quant(q, k, q_weight, k_weight, eps_q, eps_k): + yq_fp32 = torch_rmsnorm(q, q_weight, eps_q) + yq, sq = torch_mxfp8_quant_from_fp32(yq_fp32) + yk_fp32 = torch_rmsnorm(k, k_weight, eps_k) + yk = yk_fp32.to(k.dtype) + return yq, sq, yk + + +@pytest.mark.parametrize( + "M, KQ, KK", + [ + (1, 32, 32), + (1, 128, 64), + (8, 256, 128), + (16, 512, 256), + (32, 1024, 512), + (64, 2048, 1024), + (47, 96, 80), # non-pow-2 sizes + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_dual_rmsnorm_mxfp8_quant(M: int, KQ: int, KK: int, dtype: torch.dtype): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(13) + + q = torch.randn((M, KQ), dtype=dtype, device="cuda") + k = torch.randn((M, KK), dtype=dtype, device="cuda") + q_weight = torch.randn((KQ,), dtype=dtype, device="cuda") * 0.5 + 1.0 + k_weight = torch.randn((KK,), dtype=dtype, device="cuda") * 0.5 + 1.0 + eps_q, eps_k = 1e-5, 2e-5 + + yq_ref, sq_ref, yk_ref = torch_dual_rmsnorm_mxfp8_quant( + q, k, q_weight, k_weight, eps_q, eps_k + ) + yq_kern, sq_kern, yk_kern = fused_dual_rmsnorm_mxfp8_quant( + q, k, q_weight, k_weight, eps_q, eps_k + ) + + # Q side: compare dequantized values (rsqrt jitter -> tolerate e8m0 ULP flips). + sq_ref_f32 = e8m0_to_f32(sq_ref).repeat_interleave(QUANT_BLOCK_SIZE, dim=1) + sq_kern_f32 = e8m0_to_f32(sq_kern).repeat_interleave(QUANT_BLOCK_SIZE, dim=1) + yq_ref_dq = yq_ref.to(torch.float32) * sq_ref_f32 + yq_kern_dq = yq_kern.to(torch.float32) * sq_kern_f32 + torch.testing.assert_close(yq_kern_dq, yq_ref_dq, atol=5e-2, rtol=5e-2) + + # K side: bf16/fp16 RMSNorm output. + torch.testing.assert_close(yk_kern, yk_ref, atol=5e-3, rtol=5e-3) + + +def test_dual_rmsnorm_mxfp8_quant_default_eps_k(): + """eps_k defaults to eps_q when not provided.""" + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(13) + + M, KQ, KK = 16, 128, 96 + dtype = torch.bfloat16 + q = torch.randn((M, KQ), dtype=dtype, device="cuda") + k = torch.randn((M, KK), dtype=dtype, device="cuda") + q_weight = torch.randn((KQ,), dtype=dtype, device="cuda") + k_weight = torch.randn((KK,), dtype=dtype, device="cuda") + eps = 1e-5 + + yq_a, sq_a, yk_a = fused_dual_rmsnorm_mxfp8_quant(q, k, q_weight, k_weight, eps) + yq_b, sq_b, yk_b = fused_dual_rmsnorm_mxfp8_quant( + q, k, q_weight, k_weight, eps, eps_k=eps + ) + torch.testing.assert_close(yq_a.view(torch.uint8), yq_b.view(torch.uint8)) + torch.testing.assert_close(sq_a, sq_b) + torch.testing.assert_close(yk_a, yk_b) + + +# ----------------------------------------------------------------------------- +# fused_flatten_mxfp8_quant +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "M, N1, N2", + [ + (1, 1, 32), + (1, 4, 64), + (8, 2, 128), + (16, 3, 256), + (32, 4, 512), + (64, 1, 1024), + (37, 5, 64), # non-pow-2 M + (128, 8, 32), + (64, 8, 7168), + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_fused_flatten_mxfp8_quant(M: int, N1: int, N2: int, dtype: torch.dtype): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(17) + + # x = torch.randn((M, N1, N2), dtype=dtype, device="cuda") * 4.0 + x = torch.randn((N1, M, N2), dtype=dtype, device="cuda").transpose(0, 1) * 4.0 + + # Reference: flatten (M, N1, N2) -> (M, N1 * N2), then MXFP8 quant in fp32. + x_flat_fp32 = x.reshape(M, N1 * N2).to(torch.float32) + y_ref, s_ref = torch_mxfp8_quant_from_fp32(x_flat_fp32) + + y_kern, s_kern = fused_flatten_mxfp8_quant(x) + + assert y_kern.shape == (M, N1 * N2) + assert s_kern.shape == (M, (N1 * N2) // QUANT_BLOCK_SIZE) + + # Scales must be bit-exact (integer-only after fp32 cast). + torch.testing.assert_close(s_kern, s_ref) + + # Quantized values: compare via uint8 view, allow off-by-1 for fp32->fp8 + # rounding-mode subtlety. + torch.testing.assert_close( + y_kern.view(torch.uint8).to(torch.int32), + y_ref.view(torch.uint8).to(torch.int32), + atol=1, + rtol=0, + ) + + +def test_fused_flatten_mxfp8_quant_matches_per_1x32_after_flatten(): + """Sanity: the flatten+quant path should match dynamic_mxfp8_quant + applied to the pre-flattened (M, N1 * N2) tensor.""" + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(19) + + M, N1, N2 = 16, 4, 128 + x = torch.randn((M, N1, N2), dtype=torch.bfloat16, device="cuda") + + y_flat, s_flat = fused_flatten_mxfp8_quant(x) + y_ref, s_ref = dynamic_mxfp8_quant(x.reshape(M, N1 * N2).contiguous()) + + torch.testing.assert_close(s_flat, s_ref) + torch.testing.assert_close( + y_flat.view(torch.uint8).to(torch.int32), + y_ref.view(torch.uint8).to(torch.int32), + atol=1, + rtol=0, + )