diff --git a/aiter/ops/triton/_triton_kernels/fusions/fused_clamp_act_mul.py b/aiter/ops/triton/_triton_kernels/fusions/fused_clamp_act_mul.py index f5887fdcb5..589cc372f0 100644 --- a/aiter/ops/triton/_triton_kernels/fusions/fused_clamp_act_mul.py +++ b/aiter/ops/triton/_triton_kernels/fusions/fused_clamp_act_mul.py @@ -23,6 +23,7 @@ [ "BLOCK_SIZE_N", "QUANT_BLOCK_SIZE", + "SCALE_FMT", "HAVE_WEIGHTS", "WEIGHT_BROADCAST", "HAVE_SWIGLU_CLAMP", @@ -50,6 +51,7 @@ def _fused_clamp_silu_mul_kernel( swiglu_limit, BLOCK_SIZE_N: tl.constexpr, QUANT_BLOCK_SIZE: tl.constexpr, + SCALE_FMT: tl.constexpr, DTYPE_MAX: tl.constexpr, DTYPE_MIN: tl.constexpr, HAVE_WEIGHTS: tl.constexpr, @@ -95,11 +97,34 @@ def _fused_clamp_silu_mul_kernel( out = out * w if HAS_QUANT: - out_q, block_scales = _fp8_quant_op( - out, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN - ) - out_q = tl.ravel(out_q) - block_scales = tl.ravel(block_scales) + if SCALE_FMT == "ue8m0": + # Per-1×QUANT_BLOCK_SIZE MXFP8 emit: fp8 e4m3 values + uint8 ue8m0 + # biased-exponent scales. Mirrors the ue8m0 path used by moe_gemm_a8w4. + NUM_QB: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE + out_3d = tl.reshape(out, [1, NUM_QB, QUANT_BLOCK_SIZE]) + abs_3d = tl.abs(out_3d) + max_val = tl.max(abs_3d, axis=2, keep_dims=True) + dequant_scale = max_val / DTYPE_MAX + # ROUND_UP via exponent: 2 ** ceil(log2(dequant_scale)) + dequant_scale_exp = ( + dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF + ) & 0x7F800000 + dequant_scale_rounded = dequant_scale_exp.to(tl.float32, bitcast=True) + quant_scale = tl.where( + dequant_scale_rounded == 0, 0.0, 1.0 / dequant_scale_rounded + ) + quant_tensor = out_3d * quant_scale + quant_2d = tl.reshape(quant_tensor, [1, BLOCK_SIZE_N]) + out_q = tl.ravel(quant_2d) + scale_exp = (dequant_scale_exp >> 23).to(tl.uint8) + scale_exp_2d = tl.reshape(scale_exp, [1, NUM_QB]) + block_scales = tl.ravel(scale_exp_2d) + else: + out_q, block_scales = _fp8_quant_op( + out, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN + ) + out_q = tl.ravel(out_q) + block_scales = tl.ravel(block_scales) tl.store( out_ptr + m_pid * out_stride_m + n_offs * out_stride_n, @@ -108,8 +133,8 @@ def _fused_clamp_silu_mul_kernel( ) num_bs = tl.cdiv(n_half, QUANT_BLOCK_SIZE) - NUM_QB: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE - g_offs = tl.arange(0, NUM_QB) + NUM_QB_S: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE + g_offs = tl.arange(0, NUM_QB_S) tl.store( scale_ptr + m_pid * scale_stride_m + g_offs * scale_stride_n, block_scales.to(scale_ptr.dtype.element_ty), diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py new file mode 100644 index 0000000000..1b549157a8 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp8wfp8.py @@ -0,0 +1,472 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl +from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr +from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd +from aiter.ops.triton.utils.gemm_config_utils import get_gemm_config + +_gemm_afp8wfp8_repr = make_kernel_repr( + "_gemm_afp8wfp8_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "num_warps", + "num_stages", + "waves_per_eu", + "matrix_instr_nonkdim", + "cache_modifier", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + ], +) + + +@triton.heuristics( + { + "EVEN_K": lambda args: (args["K"] % args["BLOCK_SIZE_K"] == 0), + } +) +@triton.jit(repr=_gemm_afp8wfp8_repr) +def _gemm_afp8wfp8_kernel( + a_ptr, + b_ptr, + c_ptr, + a_scales_ptr, + b_scales_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_ck, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bsn, + stride_bsk, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, + EVEN_K: tl.constexpr, + num_warps: tl.constexpr, + num_stages: tl.constexpr, + waves_per_eu: tl.constexpr, + matrix_instr_nonkdim: tl.constexpr, + cache_modifier: tl.constexpr, +): + """ + Kernel for computing the matmul C = A x B. + A and B inputs are FP8 e4m3 (1 byte per element). + A_scales are e8m0 (uint8) with shape (M, K // 32). + B_scales are stored compact e8m0 (uint8) with shape (N // 128, K // 128), + representing 128x128 weight blocks. Broadcast inside kernel to (N, K // 32). + A has shape (M, K), B has shape (K, N) and C has shape (M, N). + Output dtype is determined by c_ptr (bf16 or fp16). + When NUM_KSPLIT > 1, K is split into NUM_KSPLIT partitions of + SPLITK_BLOCK_SIZE elements and the partial result for partition pid_k is + written to c_ptr + pid_k * stride_ck; a downstream reduce kernel sums them. + """ + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + tl.assume(stride_asm > 0) + tl.assume(stride_ask > 0) + tl.assume(stride_bsk > 0) + tl.assume(stride_bsn > 0) + + GRID_MN = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + + pid_unified = tl.program_id(axis=0) + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + if NUM_KSPLIT == 1: + pid = remap_xcd(pid, GRID_MN, NUM_XCDS=8) + pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(pid_k >= 0) + + # Scale group sizes + SCALE_GROUP_SIZE: tl.constexpr = 32 # A: per 32 elements along K + B_SCALE_K_GROUP: tl.constexpr = 128 # B: per 128 along K + B_SCALE_N_GROUP: tl.constexpr = 128 # B: per 128 along N + + if (pid_k * SPLITK_BLOCK_SIZE) < K: + # K-block iteration range for this split (absolute block indices). + num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE, BLOCK_SIZE_K) + + # Create pointers for first block of A and B input matrices. The K + # offset is the absolute start of this split's K range. + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k_split = pid_k * SPLITK_BLOCK_SIZE + offs_k + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak + ) + b_ptrs = b_ptr + ( + offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn + ) + + # A-scale pointers: per-row (M) and per scale group (K // 32). Shift + # along the K-scale axis by the split's start in scale groups. + offs_ks_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + offs_ks_a_split = pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE) + offs_ks_a + a_scale_ptrs = ( + a_scales_ptr + + offs_am[:, None] * stride_asm + + offs_ks_a_split[None, :] * stride_ask + ) + + # B-scale pointers: compact (N // 128, K // 128) — broadcast inside the kernel + # Each scale covers a 128(N) x 128(K) block. Computed per-iteration below + # using absolute K (so split-K naturally addresses the right b-scale block). + offs_bsn = offs_bn // B_SCALE_N_GROUP # (BLOCK_SIZE_N,) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + offs_scale_k_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + + for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): + # K base for this iteration (in elements, absolute). + k_base = k * BLOCK_SIZE_K + + # ---- Load A scales (M, BLOCK_SIZE_K // 32) ---- + if EVEN_K: + a_scales = tl.load(a_scale_ptrs) + else: + a_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + a_scales = tl.load(a_scale_ptrs, mask=a_scale_mask, other=127) + + # ---- Load and broadcast B scales (BLOCK_SIZE_N, BLOCK_SIZE_K // 32) ---- + offs_bsk = ( + k_base + offs_scale_k_a * SCALE_GROUP_SIZE + ) // B_SCALE_K_GROUP # (BLOCK_SIZE_K // 32,) + b_scale_ptrs = ( + b_scales_ptr + + offs_bsn[:, None] * stride_bsn + + offs_bsk[None, :] * stride_bsk + ) + if EVEN_K: + b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) + else: + # OOB along K: load with the same mask as a-scales + b_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + b_scales = tl.load( + b_scale_ptrs, + mask=b_scale_mask, + other=127, + cache_modifier=cache_modifier, + ) + + # ---- Load A, B data ---- + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0 + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0, + cache_modifier=cache_modifier, + ) + + accumulator = tl.dot_scaled( + a, a_scales, "e4m3", b, b_scales, "e4m3", accumulator + ) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask + + c = accumulator.to(c_ptr.type.element_ty) + + # Write back the block of the output matrix C with masks. For + # NUM_KSPLIT > 1, each pid_k writes to a separate slab of c_ptr. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + c_ptrs = ( + c_ptr + + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :] + + pid_k * stride_ck + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +_gemm_afp8wfp8_preshuffle_repr = make_kernel_repr( + "_gemm_afp8wfp8_preshuffle_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "num_warps", + "num_stages", + "waves_per_eu", + "matrix_instr_nonkdim", + "cache_modifier", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + ], +) + + +@triton.heuristics( + { + "EVEN_K": lambda args: (args["K"] % args["BLOCK_SIZE_K"] == 0), + } +) +@triton.jit(repr=_gemm_afp8wfp8_preshuffle_repr) +def _gemm_afp8wfp8_preshuffle_kernel( + a_ptr, + b_ptr, + c_ptr, + a_scales_ptr, + b_scales_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bn, + stride_bk, + stride_ck, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bsn, + stride_bsk, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, + EVEN_K: tl.constexpr, + num_warps: tl.constexpr, + num_stages: tl.constexpr, + waves_per_eu: tl.constexpr, + matrix_instr_nonkdim: tl.constexpr, + cache_modifier: tl.constexpr, +): + """ + Preshuffle variant of _gemm_afp8wfp8_kernel. Weight tensor has been shuffled + via aiter.ops.shuffle.shuffle_weight(layout=(16, 16)) so that 16-row N tiles + are interleaved with their 32-col K chunks in storage. The kernel loads the + shuffled tile in storage order (BLOCK_SIZE_N // 16, BLOCK_SIZE_K * 16) then + reshape+permute+trans inside the kernel to restore logical (K, N) layout + before tl.dot_scaled. Scales remain in the unshuffled compact 128x128 layout. + When NUM_KSPLIT > 1, K is split into NUM_KSPLIT partitions of + SPLITK_BLOCK_SIZE elements; each pid_k writes to c_ptr + pid_k * stride_ck. + """ + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + tl.assume(stride_asm > 0) + tl.assume(stride_ask > 0) + tl.assume(stride_bsk > 0) + tl.assume(stride_bsn > 0) + + GRID_MN = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + + pid_unified = tl.program_id(axis=0) + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + if NUM_KSPLIT == 1: + pid = remap_xcd(pid, GRID_MN, NUM_XCDS=8) + pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(pid_k >= 0) + + SCALE_GROUP_SIZE: tl.constexpr = 32 # A: per-32 along K + B_SCALE_K_GROUP: tl.constexpr = 128 # B compact: per-128 along K + B_SCALE_N_GROUP: tl.constexpr = 128 # B compact: per-128 along N + + if (pid_k * SPLITK_BLOCK_SIZE) < K: + num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE, BLOCK_SIZE_K) + + # A pointers (offset by this split's K start). + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k_split = pid_k * SPLITK_BLOCK_SIZE + offs_k + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak + ) + + # B pointers for preshuffled layout. The shuffled storage is viewed as + # (N // 16, K * 16) elements. pid_n indexes BLOCK_SIZE_N // 16 N-tiles per + # step; the K dimension is expanded by 16x in byte addresses. The split + # offsets the K-byte axis by pid_k * SPLITK_BLOCK_SIZE * 16. + offs_bn_shuffle = pid_n * (BLOCK_SIZE_N // 16) + tl.arange( + 0, BLOCK_SIZE_N // 16 + ) + offs_k_shuffle_arr = tl.arange(0, BLOCK_SIZE_K * 16) + offs_k_shuffle = pid_k * SPLITK_BLOCK_SIZE * 16 + offs_k_shuffle_arr + b_ptrs = b_ptr + ( + offs_bn_shuffle[:, None] * stride_bn + offs_k_shuffle[None, :] * stride_bk + ) + + # A-scale pointers: per-row M, per 32-K group. Shift along the K-scale + # axis by the split's start in scale groups. + offs_ks_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + offs_ks_a_split = pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE) + offs_ks_a + a_scale_ptrs = ( + a_scales_ptr + + offs_am[:, None] * stride_asm + + offs_ks_a_split[None, :] * stride_ask + ) + + # B-scale pointers: compact (N // 128, K // 128). The N index needs the + # ORIGINAL (logical) row, not the shuffled row index. + offs_bn_logical = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bsn = offs_bn_logical // B_SCALE_N_GROUP + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + offs_scale_k_a = tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + + for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): + k_base = k * BLOCK_SIZE_K # absolute K base + + # Load A scales. + if EVEN_K: + a_scales = tl.load(a_scale_ptrs) + else: + a_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + a_scales = tl.load(a_scale_ptrs, mask=a_scale_mask, other=127) + + # Load and broadcast B scales (computed from absolute K). + offs_bsk = (k_base + offs_scale_k_a * SCALE_GROUP_SIZE) // B_SCALE_K_GROUP + b_scale_ptrs = ( + b_scales_ptr + + offs_bsn[:, None] * stride_bsn + + offs_bsk[None, :] * stride_bsk + ) + if EVEN_K: + b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) + else: + b_scale_mask = offs_scale_k_a[None, :] < ( + K // SCALE_GROUP_SIZE - k * (BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + b_scales = tl.load( + b_scale_ptrs, + mask=b_scale_mask, + other=127, + cache_modifier=cache_modifier, + ) + + # Load A and B (preshuffled). + if EVEN_K: + a = tl.load(a_ptrs) + b_shuf = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0 + ) + b_shuf = tl.load( + b_ptrs, + mask=offs_k_shuffle_arr[None, :] < (K - k * BLOCK_SIZE_K) * 16, + other=0, + cache_modifier=cache_modifier, + ) + + # Unshuffle B in-kernel. Inverse of the shuffle_weight permute: + # shuffle: (N // 16, 16, K // 32, 2, 16) --[perm 0,1,3,4,2,5]-> (N // 16, K // 32, 2, 16, 16) + # unshuffle: (N // 16, K // 32, 2, 16, 16) --[perm 0,1,4,2,3,5]-> (N // 16, 16, K // 32, 2, 16) + # then flatten to (N, K) and trans to (K, N). + b = ( + b_shuf.reshape( + 1, + BLOCK_SIZE_N // 16, + BLOCK_SIZE_K // 32, + 2, + 16, + 16, + ) + .permute(0, 1, 4, 2, 3, 5) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K) + .trans(1, 0) + ) + + accumulator = tl.dot_scaled( + a, a_scales, "e4m3", b, b_scales, "e4m3", accumulator + ) + + # Advance pointers. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * 16 * stride_bk + a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask + + c = accumulator.to(c_ptr.type.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + c_ptrs = ( + c_ptr + + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :] + + pid_k * stride_ck + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def _get_config( + M: int, + N: int, + K: int, + shuffle: bool = False, +): + """Load the best tuned config for (M, N, K) via the standard aiter JSON + config mechanism. Falls back to the generic fallback file or to + _DEFAULT_CONFIG if no JSON is available.""" + if shuffle: + return get_gemm_config("GEMM-AFP8WFP8_PRESHUFFLED", M, N, K) + else: + return get_gemm_config("GEMM-AFP8WFP8", M, N, K) diff --git a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_quant_x.py b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_quant_x.py new file mode 100644 index 0000000000..16de00ed92 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_quant_x.py @@ -0,0 +1,265 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton.language as tl +from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr +from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd +from aiter.ops.triton.utils.gemm_config_utils import ( + compute_splitk_params, + get_gemm_config, +) + +import triton + +_fused_gemm_a16w16_quant_x_repr = make_kernel_repr( + "_fused_gemm_a16w16_quant_x_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "QUANT_BLOCK_SIZE", + "EVEN_K", + "EVEN_MN", + "cache_modifier", + "activation", + "use_activation", + "ADD_BIAS", + "SKIP_REDUCE", + ], +) + + +@triton.heuristics( + { + "EVEN_K": lambda args: (args["K"] % (args["SPLITK_BLOCK_SIZE"]) == 0) + and (args["SPLITK_BLOCK_SIZE"] % args["BLOCK_SIZE_K"] == 0), + "EVEN_MN": lambda args: (args["M"] % args["BLOCK_SIZE_M"] == 0) + and (args["N"] % args["BLOCK_SIZE_N"] == 0), + } +) +@triton.jit( + repr=_fused_gemm_a16w16_quant_x_repr, + do_not_specialize=["M", "N"], +) +def _fused_gemm_a16w16_quant_x_kernel( + a_ptr, + b_ptr, + bias_ptr, + c_ptr, + a_quant_ptr, + a_scale_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_ck, + stride_cm, + stride_cn, + stride_a_quant_m, + stride_a_quant_k, + stride_a_scale_m, + stride_a_scale_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + EVEN_K: tl.constexpr, + EVEN_MN: tl.constexpr, + cache_modifier: tl.constexpr, + activation: tl.constexpr, + use_activation: tl.constexpr, + ADD_BIAS: tl.constexpr, + SKIP_REDUCE: tl.constexpr, +): + """Kernel that computes C = A x B and also emits an MXFP8-quantized A. + + The grid is laid out as a single 1D program-id space split into two + contiguous regions: + + * `[0, NUM_KSPLIT * num_pid_m * num_pid_n)` runs the GEMM (identical to + the unfused a16w16 kernel). + * `[NUM_KSPLIT * num_pid_m * num_pid_n, GEMM_GRID + num_pid_m * num_pid_k_copy)` + runs the per-1x32 MXFP8 quantization of A: each program handles one + `BLOCK_SIZE_M x BLOCK_SIZE_K` tile of A, derives per-1x32 e8m0 scales, + and writes both the FP8 values and the uint8 scales. + + BLOCK_SIZE_K must be a multiple of QUANT_BLOCK_SIZE (=32) so that each tile + contains a whole number of MXFP8 groups per row. + """ + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_ck > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + tl.assume(stride_a_quant_m > 0) + tl.assume(stride_a_quant_k > 0) + tl.assume(stride_a_scale_m > 0) + tl.assume(stride_a_scale_n > 0) + + pid_unified = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k_copy = tl.cdiv(K, BLOCK_SIZE_K) + GEMM_GRID = num_pid_m * num_pid_n * NUM_KSPLIT + + if pid_unified < GEMM_GRID: + # ---- GEMM branch ---------------------------------------------------- + pid_unified = remap_xcd(pid_unified, GEMM_GRID, NUM_XCDS=8) + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT + + if NUM_KSPLIT == 1: + pid_m, pid_n = pid_grid( + pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M + ) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(pid_k >= 0) + + split_k_start = pid_k * SPLITK_BLOCK_SIZE + if split_k_start < K: + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k_split = split_k_start + offs_k + if EVEN_MN: + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + else: + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak + ) + b_ptrs = b_ptr + ( + offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn + ) + + acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32 + if ADD_BIAS: + if NUM_KSPLIT == 1 or (SKIP_REDUCE and pid_k == 0): + accumulator = tl.load(bias_ptr + offs_bn).to(dtype=acc_dtype) + accumulator = tl.broadcast_to( + accumulator[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N) + ) + else: + accumulator = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype + ) + else: + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + split_k_end = tl.minimum(split_k_start + SPLITK_BLOCK_SIZE, K) + k_span = split_k_end - split_k_start + num_k_iter = tl.cdiv(k_span, BLOCK_SIZE_K) + + for k in range(num_k_iter): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < k_span - k * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < k_span - k * BLOCK_SIZE_K, + other=0.0, + cache_modifier=cache_modifier, + ) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if use_activation and NUM_KSPLIT == 1: + accumulator = activation(accumulator) + + c = accumulator.to(c_ptr.type.element_ty) + offs_cm = pid_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = ( + c_ptr + + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :] + + pid_k * stride_ck + ) + if EVEN_MN: + tl.store(c_ptrs, c) + else: + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + else: + # ---- MXFP8 quant branch -------------------------------------------- + pid_copy = pid_unified - GEMM_GRID + pid_m = pid_copy // num_pid_k_copy + pid_k = pid_copy % num_pid_k_copy + + tl.assume(pid_m >= 0) + tl.assume(pid_k >= 0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + mask = (offs_m[:, None] < M) & (offs_k[None, :] < K) + a = tl.load(a_ptrs, mask=mask, other=0.0).to(tl.float32) + + # Per-1x32 MXFP8 quant. Group along K within this tile. + n_groups: tl.constexpr = BLOCK_SIZE_K // QUANT_BLOCK_SIZE + a_2d = tl.reshape(a, (BLOCK_SIZE_M, n_groups, QUANT_BLOCK_SIZE)) + amax = tl.max(tl.abs(a_2d), axis=2, keep_dims=True) # (M, G, 1) + + amax_i32 = amax.to(tl.int32, bitcast=True) + amax_i32 = (amax_i32 + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_p2 = amax_i32.to(tl.float32, bitcast=True) + scale_unbiased = tl.log2(amax_p2).floor() - 8 + scale_unbiased = tl.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(tl.int32) + 127).to(tl.uint8) # (M, G, 1) + quant_scale = tl.exp2(-scale_unbiased) # (M, G, 1) + + qa_2d = a_2d * quant_scale + qa = tl.reshape(qa_2d, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + a_quant_ptrs = a_quant_ptr + ( + offs_m[:, None] * stride_a_quant_m + offs_k[None, :] * stride_a_quant_k + ) + tl.store(a_quant_ptrs, qa.to(a_quant_ptr.type.element_ty), mask=mask) + + # Store scales: shape (M, K // QUANT_BLOCK_SIZE). + offs_s_n = pid_k * n_groups + tl.arange(0, n_groups) + scale_2d = tl.reshape(scale_e8m0, (BLOCK_SIZE_M, n_groups)) + a_scale_ptrs = a_scale_ptr + ( + offs_m[:, None] * stride_a_scale_m + offs_s_n[None, :] * stride_a_scale_n + ) + scale_mask = (offs_m[:, None] < M) & ( + offs_s_n[None, :] < (K // QUANT_BLOCK_SIZE) + ) + tl.store(a_scale_ptrs, scale_2d, mask=scale_mask) + + +def _get_config( + M: int, + N: int, + K: int, +): + # Use the same tuning portal as the unfused gemm_a16w16 — the extra + # MXFP8 quant is assumed not to shift the optimal config. + config, is_tunned = get_gemm_config("GEMM-A16W16", M, N, K) + return compute_splitk_params(config, K), is_tunned diff --git a/aiter/ops/triton/_triton_kernels/quant/quant_mxfp8.py b/aiter/ops/triton/_triton_kernels/quant/quant_mxfp8.py new file mode 100644 index 0000000000..14de6ca80f --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/quant/quant_mxfp8.py @@ -0,0 +1,379 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl + +# MXFP8 activation quant: per-1x32 e8m0 scale + FP8 e4m3 values. +# Follows aiter.ops.quant.per_1x32_f8_scale_f8_quant: +# MAX_POW2 = int(log2(448)) = 8 +# dtypeMax = 2 ** 8 = 256.0 +# scale_f32 = max_abs / dtypeMax +# scale_e8m0 = round_up_to_pow2(scale_f32) → e8m0 biased +# y = round(x_fp32 / e8m0_to_f32(scale_e8m0)) cast to fp8 e4m3 +# +# Per-block e8m0 derivation done with the same trick as the existing mxfp4 quant: +# - bitcast amax to int32 +# - add 0x200000 (round up to a power of 2 with respect to fp4-style rounding) +# - mask 0xFF800000 (keep only sign+exponent bits) +# - bitcast back to fp32 +# This delivers a pure power-of-2 amax. Then log2(amax).floor() - 8 gives the +# unbiased e8m0 exponent for MXFP8 (since dtypeMax = 2**8). + + +@triton.jit +def _mxfp8_quant_kernel( + x_ptr, + y_ptr, + s_ptr, + M, + N, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + stride_sm, + stride_sn, + BLOCK_SIZE_N: tl.constexpr, # power-of-2 covering full N + QUANT_BLOCK_SIZE: tl.constexpr, # =32 + NUM_PRGMS: tl.constexpr, # row-loop range (usually =M) +): + """ + Per-1x32 MXFP8 quant. One program per row, holding the full row in + registers so a single launch handles all K-groups. Mirrors + _rmsnorm_mxfp8_quant_kernel shape and minimizes grid overhead. + """ + row_start = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE_N) + mask = col_offsets < N + n_groups: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE + + for row_idx in tl.range(row_start, M, NUM_PRGMS, num_stages=2): + x = tl.load( + x_ptr + row_idx * stride_xm + col_offsets * stride_xn, + mask=mask, + other=0.0, + ).to(tl.float32) + + # (BLOCK_SIZE_N,) -> (n_groups, QUANT_BLOCK_SIZE) + x_2d = tl.reshape(x, (n_groups, QUANT_BLOCK_SIZE)) + amax = tl.max(tl.abs(x_2d), axis=1, keep_dims=True) + + amax_i32 = amax.to(tl.int32, bitcast=True) + amax_i32 = (amax_i32 + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_p2 = amax_i32.to(tl.float32, bitcast=True) + scale_unbiased = tl.log2(amax_p2).floor() - 8 + scale_unbiased = tl.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(tl.int32) + 127).to(tl.uint8) + quant_scale = tl.exp2(-scale_unbiased) + + qx_2d = x_2d * quant_scale + qx = tl.reshape(qx_2d, (BLOCK_SIZE_N,)) + y = qx.to(y_ptr.type.element_ty) + + tl.store( + y_ptr + row_idx * stride_ym + col_offsets * stride_yn, + y, + mask=mask, + ) + + group_offsets = tl.arange(0, n_groups) + group_mask = group_offsets < (N // QUANT_BLOCK_SIZE) + scale_flat = tl.reshape(scale_e8m0, (n_groups,)) + tl.store( + s_ptr + row_idx * stride_sm + group_offsets * stride_sn, + scale_flat, + mask=group_mask, + ) + + +# Transcoder: (FP8 fnuz, fp32 1x128 scale) -> (FP8 fn, e8m0 1x32 scale). +# Replaces the Python dequant+requant cascade (fp32 cast + multiply + bf16 cast +# + per_1x32_mxfp8 quant) used in linear.py's MXFP8 fallback path for MLA wq_b +# when q_norm emits the legacy fp8 fnuz + fp32 1x128 format. +# +# In: x_fp8_fnuz (M, N) — fp8 e4m3fnuz bits (interpreted with bias 8 -> value) +# x_scale_fp32 (M, N//128) — fp32 per-token-block scale +# Out: y_fp8_fn (M, N) — fp8 e4m3fn bits (NV format, bias 7) +# y_scale_e8m0 (M, N//32) — uint8 e8m0 (1x32 MX scale) + + +@triton.jit +def _fp8_legacy_to_mxfp8_kernel( + x_fnuz_ptr, + x_scale_fp32_ptr, + y_fn_ptr, + y_scale_e8m0_ptr, + M, + N, + stride_xm, + stride_xn, + stride_xsm, + stride_xsn, + stride_ym, + stride_yn, + stride_ysm, + stride_ysn, + BLOCK_SIZE_M: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, # =32 (MXFP8 group) + LEGACY_BLOCK_SIZE: tl.constexpr, # =128 (input scale group) +): + """ + One program per (BLOCK_SIZE_M rows, QUANT_BLOCK_SIZE-element column window). + For each 1x32 block, dequantize fnuz fp8 values using the corresponding + 1x128 fp32 scale, derive the e8m0 (1x32) scale, then re-quantize to fp8 fn. + """ + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * QUANT_BLOCK_SIZE + tl.arange(0, QUANT_BLOCK_SIZE) + + x_offs = offs_m[:, None] * stride_xm + offs_n[None, :] * stride_xn + x_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + + # Load fp8 fnuz values; .to(fp32) decodes via fnuz bias 8 semantically. + x_fnuz = tl.load(x_fnuz_ptr + x_offs, mask=x_mask, other=0.0).to(tl.float32) + + # Which legacy 1x128 group does this 1x32 block fall into? + legacy_n = (pid_n * QUANT_BLOCK_SIZE) // LEGACY_BLOCK_SIZE + xs_offs = offs_m * stride_xsm + legacy_n * stride_xsn + xs_mask = offs_m < M + x_scale = tl.load(x_scale_fp32_ptr + xs_offs, mask=xs_mask, other=1.0) + + # Dequantize: bf16-equivalent reconstruction. + x_dq = x_fnuz * x_scale[:, None] + + # Derive new e8m0 (1x32) scale from x_dq amax. Same recipe as + # _mxfp8_quant_kernel above. + amax = tl.max(tl.abs(x_dq), axis=1, keep_dims=True) + amax_i32 = amax.to(tl.int32, bitcast=True) + amax_i32 = (amax_i32 + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_p2 = amax_i32.to(tl.float32, bitcast=True) + scale_unbiased = tl.log2(amax_p2).floor() - 8 + scale_unbiased = tl.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(tl.int32) + 127).to(tl.uint8) + quant_scale = tl.exp2(-scale_unbiased) + + # Re-quantize to fp8 fn. + qx = x_dq * quant_scale + y = qx.to(y_fn_ptr.type.element_ty) + + y_offs = offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn + tl.store(y_fn_ptr + y_offs, y, mask=x_mask) + + s_offs = offs_m[:, None] * stride_ysm + pid_n * stride_ysn + s_mask = offs_m[:, None] < M + tl.store(y_scale_e8m0_ptr + s_offs, scale_e8m0, mask=s_mask) + + +# Fused RMSNorm + MXFP8 (1x32 e8m0) quant. Replaces the separate +# rmsnorm_quant(fp8 fnuz + fp32 1x128) + transcode-to-MXFP8 sequence used +# upstream of MXFP8-aware GEMMs (e.g. V4 q_norm -> wq_b). +# +# One program per row. Holds the full row in registers, so K is constrained +# by the BLOCK_SIZE_K constexpr (must be a power of two >= K). +# +# In: x (M, K) bf16 or fp16 +# g (K,) bf16 or fp16 weight +# Out: y (M, K) fp8 e4m3fn +# scale (M, K // 32) uint8 e8m0 + + +@triton.jit +def _rmsnorm_mxfp8_quant_kernel( + x_ptr, + g_ptr, + y_ptr, + s_ptr, + M, + K, + stride_xm, + stride_xk, + stride_ym, + stride_yk, + stride_sm, + stride_sn, + epsilon, + BLOCK_SIZE_K: tl.constexpr, # power-of-2 covering full K + QUANT_BLOCK_SIZE: tl.constexpr, # =32 + NUM_PRGMS: tl.constexpr, # for persistent-loop variant; usually =M +): + """One program processes one row: rmsnorm then MXFP8 quant in registers.""" + row_start = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE_K) + mask = col_offsets < K + + for row_idx in tl.range(row_start, M, NUM_PRGMS, num_stages=2): + # Load full row, cast to fp32 + x = tl.load( + x_ptr + row_idx * stride_xm + col_offsets * stride_xk, + mask=mask, + other=0.0, + ).to(tl.float32) + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + + # RMS norm + ss = tl.sum(x * x, axis=-1) + norm_factor = tl.math.rsqrt((ss / K) + epsilon) + y_fp32 = x * norm_factor * g # (BLOCK_SIZE_K,) + + # Reshape into (K // QUANT_BLOCK_SIZE, QUANT_BLOCK_SIZE) groups for amax. + # BLOCK_SIZE_K is the power-of-2 padded size; we keep OOB lanes masked to 0 + # via the load above, so amax over them is 0 (won't affect the in-bounds max). + y_2d = tl.reshape(y_fp32, (BLOCK_SIZE_K // QUANT_BLOCK_SIZE, QUANT_BLOCK_SIZE)) + amax = tl.max(tl.abs(y_2d), axis=1, keep_dims=True) # (G, 1) + + # e8m0 scale derivation (same recipe as _mxfp8_quant_kernel). + amax_i32 = amax.to(tl.int32, bitcast=True) + amax_i32 = (amax_i32 + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_p2 = amax_i32.to(tl.float32, bitcast=True) + scale_unbiased = tl.log2(amax_p2).floor() - 8 + scale_unbiased = tl.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(tl.int32) + 127).to(tl.uint8) # (G, 1) + quant_scale = tl.exp2(-scale_unbiased) # (G, 1) + + # Quantize: y_quant = y_fp32 * quant_scale (broadcast along inner 32). + qx_2d = y_2d * quant_scale + qx = tl.reshape(qx_2d, (BLOCK_SIZE_K,)) + y_fp8 = qx.to(y_ptr.type.element_ty) + + # Store y (mask OOB). + tl.store( + y_ptr + row_idx * stride_ym + col_offsets * stride_yk, + y_fp8, + mask=mask, + ) + + # Store scales: G entries for this row. + n_groups: tl.constexpr = BLOCK_SIZE_K // QUANT_BLOCK_SIZE + group_offsets = tl.arange(0, n_groups) + group_mask = group_offsets < (K // QUANT_BLOCK_SIZE) + scale_flat = tl.reshape(scale_e8m0, (n_groups,)) + tl.store( + s_ptr + row_idx * stride_sm + group_offsets * stride_sn, + scale_flat, + mask=group_mask, + ) + + +# Dual fused RMSNorm: Q-side (MXFP8 quant + e8m0 scale emit) + K-side (bf16 out). +# Replaces the CK `fused_qk_rmsnorm_group_quant` semantics in one Triton launch +# for the MXFP8 GEMM path (Task #77). The two halves are independent (different +# weight, different K dim) so they're packed into one program per row to amortize +# launch overhead: same kernel launch loads both rows, normalizes both, stores Q +# fp8 + scale, stores K bf16. Each row's Q and K are independently RMSNorm'd +# (separate weights, separate eps, separate K dim) -- this kernel does NOT fuse +# their normalization arithmetic, only their launch. +# +# In: q (M, KQ) bf16 or fp16 +# kv (M, KK) bf16 or fp16 +# gq (KQ,) bf16 or fp16 Q-RMSNorm weight +# gk (KK,) bf16 or fp16 K-RMSNorm weight +# Out: yq (M, KQ) fp8 e4m3fn +# sq (M, KQ // 32) uint8 e8m0 +# yk (M, KK) bf16 + + +@triton.jit +def _dual_rmsnorm_mxfp8_quant_kernel( + q_ptr, + k_ptr, + gq_ptr, + gk_ptr, + yq_ptr, + sq_ptr, + yk_ptr, + M, + KQ, + KK, + stride_qm, + stride_qn, + stride_km, + stride_kn, + stride_yqm, + stride_yqn, + stride_sqm, + stride_sqn, + stride_ykm, + stride_ykn, + eps_q, + eps_k, + BLOCK_SIZE_KQ: tl.constexpr, # power-of-2 covering full KQ + BLOCK_SIZE_KK: tl.constexpr, # power-of-2 covering full KK + QUANT_BLOCK_SIZE: tl.constexpr, # =32 (MXFP8 group size) + NUM_PRGMS: tl.constexpr, # row-loop bound (usually =M) +): + """One program per row: do Q-side RMSNorm+MXFP8 quant AND K-side RMSNorm + (bf16 out) in one launch. Mirrors the CK `fused_qk_rmsnorm_group_quant` + fusion topology but emits MXFP8 1x32 (e8m0) scales for Q directly.""" + row_start = tl.program_id(0) + + q_col_offsets = tl.arange(0, BLOCK_SIZE_KQ) + q_mask = q_col_offsets < KQ + k_col_offsets = tl.arange(0, BLOCK_SIZE_KK) + k_mask = k_col_offsets < KK + + n_q_groups: tl.constexpr = BLOCK_SIZE_KQ // QUANT_BLOCK_SIZE + + for row_idx in tl.range(row_start, M, NUM_PRGMS, num_stages=2): + # ===== Q side: RMSNorm + MXFP8 quant ===== + x_q = tl.load( + q_ptr + row_idx * stride_qm + q_col_offsets * stride_qn, + mask=q_mask, + other=0.0, + ).to(tl.float32) + g_q = tl.load(gq_ptr + q_col_offsets, mask=q_mask, other=0.0).to(tl.float32) + + ss_q = tl.sum(x_q * x_q, axis=-1) + norm_q = tl.math.rsqrt((ss_q / KQ) + eps_q) + y_q_fp32 = x_q * norm_q * g_q + + y_q_2d = tl.reshape(y_q_fp32, (n_q_groups, QUANT_BLOCK_SIZE)) + amax_q = tl.max(tl.abs(y_q_2d), axis=1, keep_dims=True) + + amax_qi32 = amax_q.to(tl.int32, bitcast=True) + amax_qi32 = (amax_qi32 + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_qp2 = amax_qi32.to(tl.float32, bitcast=True) + scale_q_unbiased = tl.log2(amax_qp2).floor() - 8 + scale_q_unbiased = tl.clamp(scale_q_unbiased, min=-127, max=127) + scale_q_e8m0 = (scale_q_unbiased.to(tl.int32) + 127).to(tl.uint8) + quant_scale_q = tl.exp2(-scale_q_unbiased) + + qx_q_2d = y_q_2d * quant_scale_q + qx_q = tl.reshape(qx_q_2d, (BLOCK_SIZE_KQ,)) + y_q_fp8 = qx_q.to(yq_ptr.type.element_ty) + + tl.store( + yq_ptr + row_idx * stride_yqm + q_col_offsets * stride_yqn, + y_q_fp8, + mask=q_mask, + ) + + q_group_offsets = tl.arange(0, n_q_groups) + q_group_mask = q_group_offsets < (KQ // QUANT_BLOCK_SIZE) + scale_q_flat = tl.reshape(scale_q_e8m0, (n_q_groups,)) + tl.store( + sq_ptr + row_idx * stride_sqm + q_group_offsets * stride_sqn, + scale_q_flat, + mask=q_group_mask, + ) + + # ===== K side: RMSNorm only, bf16 out ===== + x_k = tl.load( + k_ptr + row_idx * stride_km + k_col_offsets * stride_kn, + mask=k_mask, + other=0.0, + ).to(tl.float32) + g_k = tl.load(gk_ptr + k_col_offsets, mask=k_mask, other=0.0).to(tl.float32) + + ss_k = tl.sum(x_k * x_k, axis=-1) + norm_k = tl.math.rsqrt((ss_k / KK) + eps_k) + y_k_fp32 = x_k * norm_k * g_k + y_k_out = y_k_fp32.to(yk_ptr.type.element_ty) + + tl.store( + yk_ptr + row_idx * stride_ykm + k_col_offsets * stride_ykn, + y_k_out, + mask=k_mask, + ) diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16W16-N=384-K=7168.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16W16-N=384-K=7168.json new file mode 100644 index 0000000000..9b9cce2832 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16W16-N=384-K=7168.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 1024, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8.json new file mode 100644 index 0000000000..c7271acf94 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=1536-K=4096.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=1536-K=4096.json new file mode 100644 index 0000000000..6b2567e194 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=1536-K=4096.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=2048-K=7168.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=2048-K=7168.json new file mode 100644 index 0000000000..a07c6c2d26 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=2048-K=7168.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=1024.json new file mode 100644 index 0000000000..5ba817728b --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=1024.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=256.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=256.json new file mode 100644 index 0000000000..e4b178dda7 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=4096-K=256.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=512-K=4096.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=512-K=4096.json new file mode 100644 index 0000000000..91361fa902 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=512-K=4096.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 4 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=2048.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=2048.json new file mode 100644 index 0000000000..6f74cb9e7f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=2048.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=384.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=384.json new file mode 100644 index 0000000000..2a3b281530 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=7168-K=384.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=768-K=7168.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=768-K=7168.json new file mode 100644 index 0000000000..2d88b2f30c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=768-K=7168.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1024.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1024.json new file mode 100644 index 0000000000..cb2547fbbd --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1024.json @@ -0,0 +1,170 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_GEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1536.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1536.json new file mode 100644 index 0000000000..defffda52d --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED-N=8192-K=1536.json @@ -0,0 +1,98 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED.json new file mode 100644 index 0000000000..0008e2ab97 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP8WFP8_PRESHUFFLED.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/fusions/fused_clamp_act_mul.py b/aiter/ops/triton/fusions/fused_clamp_act_mul.py index 0b1b8a1a63..a5baa533e7 100644 --- a/aiter/ops/triton/fusions/fused_clamp_act_mul.py +++ b/aiter/ops/triton/fusions/fused_clamp_act_mul.py @@ -25,6 +25,9 @@ def fused_clamp_act_mul( weights: Optional[torch.Tensor] = None, dtype_quant: torch.dtype | None = None, transpose_scale: bool = False, + quant_block_size: int = 128, + scale_dtype_fmt: Literal["fp32", "ue8m0"] = "fp32", + shuffle_scale: bool = False, ): """ Fused clamp (SwiGLU-style) + act(gate) * up + optional weights, with optional FP8 group quant. @@ -54,6 +57,25 @@ def fused_clamp_act_mul( HAS_QUANT = dtype_quant is not None + # Step 5 ue8m0 mode: per-1×32 group quant, uint8 scale. + assert scale_dtype_fmt in ("fp32", "ue8m0") + if scale_dtype_fmt == "ue8m0": + assert HAS_QUANT, "scale_dtype_fmt='ue8m0' requires dtype_quant" + assert ( + quant_block_size == 32 + ), f"ue8m0 requires quant_block_size=32 got {quant_block_size}" + assert dtype_quant in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ), f"ue8m0 requires fp8 e4m3, got {dtype_quant}" + if shuffle_scale and transpose_scale: + raise ValueError("shuffle_scale incompatible with transpose_scale") + _scale_storage_dtype = torch.uint8 + else: + if shuffle_scale: + raise ValueError("shuffle_scale only valid with scale_dtype_fmt='ue8m0'") + _scale_storage_dtype = torch.float32 + if HAS_QUANT: if out is None: out = torch.empty((M, n_half), dtype=dtype_quant, device=inp.device) @@ -65,15 +87,15 @@ def fused_clamp_act_mul( dtype_quant, out.dtype, ) - num_blocks = (n_half + 127) // 128 + num_blocks = (n_half + quant_block_size - 1) // quant_block_size if scale is None: if transpose_scale: scale = torch.empty( - (num_blocks, M), dtype=torch.float32, device=inp.device + (num_blocks, M), dtype=_scale_storage_dtype, device=inp.device ) else: scale = torch.empty( - (M, num_blocks), dtype=torch.float32, device=inp.device + (M, num_blocks), dtype=_scale_storage_dtype, device=inp.device ) else: if transpose_scale: @@ -153,7 +175,8 @@ def fused_clamp_act_mul( weights.stride(1) if HAVE_WEIGHTS else 0, swiglu_limit, BLOCK_SIZE_N=BLOCK_SIZE_N, - QUANT_BLOCK_SIZE=128, + QUANT_BLOCK_SIZE=quant_block_size, + SCALE_FMT=scale_dtype_fmt, DTYPE_MAX=DTYPE_MAX, DTYPE_MIN=-DTYPE_MAX, HAVE_WEIGHTS=HAVE_WEIGHTS, @@ -167,5 +190,9 @@ def fused_clamp_act_mul( if HAS_QUANT: if transpose_scale: scale = scale.view(M, num_bs_cols) + if shuffle_scale: + from aiter.utility import fp4_utils + + scale = fp4_utils.e8m0_shuffle(scale) return out, scale return out diff --git a/aiter/ops/triton/gemm/basic/gemm_afp8wfp8.py b/aiter/ops/triton/gemm/basic/gemm_afp8wfp8.py new file mode 100644 index 0000000000..9a124dda71 --- /dev/null +++ b/aiter/ops/triton/gemm/basic/gemm_afp8wfp8.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Optional + +import torch +import triton + +from aiter.ops.triton._triton_kernels.gemm.basic.gemm_afp8wfp8 import ( + _gemm_afp8wfp8_kernel, + _gemm_afp8wfp8_preshuffle_kernel, + _get_config, +) +from aiter.ops.triton._triton_kernels.common.splitk_reduce import ( + _gemm_splitk_reduce_kernel, +) + +# ----------------------------------------------------------------------------- +# Tuned-config lookup for gemm_afp8wfp8_preshuffle. +# +# Configs live under aiter.ops.triton.configs.gemm using the standard aiter +# naming: gfx{arch}-GEMM-MXFP8-PRESHUFFLE-N={N}-K={K}.json, keyed by +# M_LEQ_x / M_GEQ_x / any per STANDARD_M_BOUNDS. A generic fallback +# gfx{arch}-GEMM-MXFP8-PRESHUFFLE.json covers untuned (N, K) shapes. +# ----------------------------------------------------------------------------- + +# Shapes that have OOM'd at runtime with the tuned config. Future calls +# bypass the tuned lookup for these and use _DEFAULT_CONFIG. The tuned +# benches don't run under HIP-graph capture, so they can pick configs that +# look fine in isolation but blow LDS once captured. This cache survives +# across requests in the same process. +_OOM_SHAPES: set = set() + + +def _mark_oom(M: int, N: int, K: int): + _OOM_SHAPES.add((M, N, K)) + + +def gemm_afp8wfp8( + x: torch.Tensor, + w: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, + skip_reduce: Optional[bool] = False, +) -> torch.Tensor: + """ + Computes matrix multiplication Y = X @ W^T with MXFP8 activations and FP8 + weights (1x32 e8m0 act scales, 128x128 e8m0 weight scales). + + Args: + x: FP8 e4m3 (or uint8 view) input matrix with shape (M, K). + w: FP8 e4m3 (or uint8 view) weight matrix with shape (N, K) — internally + transposed to (K, N) before the kernel call. + x_scales: e8m0 (uint8) per-group scale for x with shape (M, K // 32). + w_scales: e8m0 (uint8) per-block scale for w with shape (N // 128, K // 128). + dtype: Output dtype (BF16 or FP16). Default bf16. + y: Optional pre-allocated output tensor with shape (M, N). + config: Optional kernel-tuning dict. If None uses defaults. + + Returns: + torch.Tensor: Output with shape (M, N). + """ + M, K = x.shape + N, K_w = w.shape + assert K == K_w, f"K mismatch: x has K={K}, w has K={K_w}" + + # Transpose w to (K, N) for the kernel. + w_t = w.T + + # tl.dot_scaled with format "e4m3" expects uint8-typed operands; reinterpret + # the FP8 buffers as uint8 (bit-identical view). + if x.dtype != torch.uint8: + x = x.view(torch.uint8) + if w_t.dtype != torch.uint8: + w_t = w_t.view(torch.uint8) + + if config is None: + config, _ = _get_config(M, N, K) + + if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + y = torch.empty((M, N), dtype=dtype, device=x.device) + + config["SPLITK_BLOCK_SIZE"] = triton.cdiv( + K, config["NUM_KSPLIT"] + ) # How big each split_k partition is + if config["NUM_KSPLIT"] > 1: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), + dtype=torch.float32, + device=x.device, + ) + else: + y_pp = None + + grid = lambda META: ( # noqa: E731 + ( + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]) + ), + ) + + _gemm_afp8wfp8_kernel[grid]( + x, + w_t, + y if config["NUM_KSPLIT"] == 1 else y_pp, + x_scales, + w_scales, + M, + N, + K, + x.stride(0), + x.stride(1), + w_t.stride(0), + w_t.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), + x_scales.stride(0), + x_scales.stride(1), + w_scales.stride(0), + w_scales.stride(1), + **config, + ) + + if config["NUM_KSPLIT"] > 1: + if skip_reduce: + return y_pp + + REDUCE_BLOCK_SIZE_M = 32 + REDUCE_BLOCK_SIZE_N = 32 + ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + triton.cdiv(N, REDUCE_BLOCK_SIZE_N), + ) + _gemm_splitk_reduce_kernel[grid_reduce]( + y_pp, + y, + None, + M, + N, + y_pp.stride(0), + y_pp.stride(1), + y_pp.stride(2), + y.stride(0), + y.stride(1), + BLOCK_SIZE_M=REDUCE_BLOCK_SIZE_M, + BLOCK_SIZE_N=REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT=ACTUAL_KSPLIT, + MAX_KSPLIT=triton.next_power_of_2(config["NUM_KSPLIT"]), + ADD_BIAS=False, + activation=None, + use_activation=False, + KERNEL_NAME="_gemm_afp8wfp8_reduce_kernel", + ) + + return y + + +def gemm_afp8wfp8_preshuffle( + x: torch.Tensor, + w_shuffled: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, + skip_reduce: Optional[bool] = False, +) -> torch.Tensor: + """ + Preshuffle variant of gemm_afp8wfp8. The weight tensor has already been + permuted via aiter.ops.shuffle.shuffle_weight(..., layout=(16, 16)). Scales + are left unshuffled in the compact 128x128 layout. + + Args: + x: FP8 e4m3 activations with shape (M, K). + w_shuffled: FP8 e4m3 weights, shuffled in place to (N, K) storage + (same total bytes; bytes rearranged for the kernel's read pattern). + x_scales: e8m0 (uint8) per-token scale with shape (M, K // 32). + w_scales: e8m0 (uint8) per-block weight scale with shape (N // 128, K // 128). + dtype: Output dtype. + y: Optional pre-allocated output (M, N). + config: Optional kernel-tuning dict. + + Returns: + torch.Tensor: Output with shape (M, N). + """ + M, K = x.shape + N, K_w = w_shuffled.shape + assert K == K_w, f"K mismatch: x={K}, w={K_w}" + assert N % 16 == 0, f"N must be divisible by 16 for preshuffle, got {N}" + + # The kernel expects to address the shuffled tensor as (N//16, K*16). + w_view = w_shuffled.view(N // 16, K * 16) + + if x.dtype != torch.uint8: + x = x.view(torch.uint8) + if w_view.dtype != torch.uint8: + w_view = w_view.view(torch.uint8) + + if config is None: + config, _ = _get_config(M, N, K, shuffle=True) + + if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + y = torch.empty((M, N), dtype=dtype, device=x.device) + + config["SPLITK_BLOCK_SIZE"] = triton.cdiv( + K, config["NUM_KSPLIT"] + ) # How big each split_k partition is + if config["NUM_KSPLIT"] > 1: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), + dtype=torch.float32, + device=x.device, + ) + else: + y_pp = None + + grid = lambda META: ( # noqa: E731 + ( + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]) + ), + ) + _gemm_afp8wfp8_preshuffle_kernel[grid]( + x, + w_view, + y if config["NUM_KSPLIT"] == 1 else y_pp, + x_scales, + w_scales, + M, + N, + K, + x.stride(0), + x.stride(1), + w_view.stride(0), + w_view.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), + x_scales.stride(0), + x_scales.stride(1), + w_scales.stride(0), + w_scales.stride(1), + **config, + ) + + if config["NUM_KSPLIT"] > 1: + if skip_reduce: + return y_pp + + REDUCE_BLOCK_SIZE_M = 32 + REDUCE_BLOCK_SIZE_N = 32 + ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + triton.cdiv(N, REDUCE_BLOCK_SIZE_N), + ) + _gemm_splitk_reduce_kernel[grid_reduce]( + y_pp, + y, + None, + M, + N, + y_pp.stride(0), + y_pp.stride(1), + y_pp.stride(2), + y.stride(0), + y.stride(1), + BLOCK_SIZE_M=REDUCE_BLOCK_SIZE_M, + BLOCK_SIZE_N=REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT=ACTUAL_KSPLIT, + MAX_KSPLIT=triton.next_power_of_2(config["NUM_KSPLIT"]), + ADD_BIAS=False, + activation=None, + use_activation=False, + KERNEL_NAME="_gemm_afp8wfp8_preshuffle_reduce_kernel", + ) + + return y diff --git a/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_quant_x.py b/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_quant_x.py new file mode 100644 index 0000000000..d580687ef5 --- /dev/null +++ b/aiter/ops/triton/gemm/fused/fused_gemm_a16w16_quant_x.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Optional, Tuple +import torch +import triton +from aiter.ops.triton._triton_kernels.gemm.fused.fused_gemm_a16w16_quant_x import ( + _fused_gemm_a16w16_quant_x_kernel, + _get_config, +) +from aiter.ops.triton._triton_kernels.common.splitk_reduce import ( + _gemm_splitk_reduce_kernel, +) +from aiter.ops.triton._triton_kernels.activation import _get_activation_from_str +from aiter.ops.triton.utils.logger import AiterTritonLogger + +_LOGGER = AiterTritonLogger() + +_QUANT_BLOCK_SIZE = 32 + + +def fused_gemm_a16w16_quant_x( + x, + w, + bias: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = torch.bfloat16, + quant_dtype: Optional[torch.dtype] = None, + y: Optional[torch.Tensor] = None, + x_quant: Optional[torch.Tensor] = None, + x_scales: Optional[torch.Tensor] = None, + config: Optional[dict] = None, + activation: Optional[str] = None, + skip_reduce: Optional[bool] = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes 16-bit matmul Y = X @ W^T and also emits an MXFP8-quantized X. + + This fuses the GEMM with the activation per-1x32 MXFP8 quantization that + immediately follows the router-gate GEMM in MoE flows (e.g. DSv4), + avoiding a separate kernel/DRAM pass to cast X from BF16 to FP8 + e8m0 + scales. + + The fused kernel uses a single 1D grid split into two regions: GEMM tiles + occupy `NUM_KSPLIT * cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)` programs and an + additional `cdiv(M, BLOCK_M) * cdiv(K, BLOCK_K)` programs handle the + MXFP8 quant of X. + + Args: + x (torch.Tensor): Input matrix with shape (M, K). K must be a multiple + of 32. + w (torch.Tensor): Weight matrix with shape (N, K), internally transposed. + bias (Optional[torch.Tensor]): Bias vector with shape (N,). + dtype (Optional[torch.dtype]): Output Y datatype (BF16 or FP16). + quant_dtype (Optional[torch.dtype]): FP8 dtype for the MXFP8 quantized + X values (defaults to torch.float8_e4m3fn). + y (Optional[torch.Tensor]): Pre-allocated output with shape (M, N). + x_quant (Optional[torch.Tensor]): Pre-allocated MXFP8 quantized X with + shape (M, K). + x_scales (Optional[torch.Tensor]): Pre-allocated uint8 e8m0 scales with + shape (M, K // 32). + config (Optional[dict]): Kernel tuning parameters. + activation (Optional[str]): Activation fused into Y. + skip_reduce (Optional[bool]): Skip split-K reduction for Y. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: (Y, x_quant, x_scales). + When skip_reduce=True and NUM_KSPLIT > 1, Y has shape (NUM_KSPLIT, M, N). + """ + + _LOGGER.info(f"FUSED_GEMM_A16W16_QUANT_X: x={tuple(x.shape)} w={tuple(w.shape)}") + # Shape checks + assert x.shape[1] == w.shape[1], "Incompatible matrix shapes." + assert ( + x.shape[1] % _QUANT_BLOCK_SIZE == 0 + ), f"K={x.shape[1]} must be a multiple of {_QUANT_BLOCK_SIZE} for MXFP8 quant" + + if quant_dtype is None: + quant_dtype = torch.float8_e4m3fn + + M, K = x.shape + N, K = w.shape + w = w.T + + if config is None: + config, _ = _get_config(M, N, K) + + assert ( + config["BLOCK_SIZE_K"] % _QUANT_BLOCK_SIZE == 0 + ), f"BLOCK_SIZE_K={config['BLOCK_SIZE_K']} must be a multiple of {_QUANT_BLOCK_SIZE}" + + if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + y = torch.empty((M, N), dtype=dtype, device=x.device) + + if x_quant is None: + x_quant = torch.empty((M, K), dtype=quant_dtype, device=x.device) + + if x_scales is None: + x_scales = torch.empty( + (M, K // _QUANT_BLOCK_SIZE), dtype=torch.uint8, device=x.device + ) + + if config["NUM_KSPLIT"] > 1: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), + dtype=torch.float32, + device=y.device if y is not None else x.device, + ) + else: + y_pp = None + + grid = lambda META: ( # noqa: E731 + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]) + + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(K, META["BLOCK_SIZE_K"]), + ) + _fused_gemm_a16w16_quant_x_kernel[grid]( + x, + w, + bias, + y if config["NUM_KSPLIT"] == 1 else y_pp, + x_quant, + x_scales, + M, + N, + K, + x.stride(0), + x.stride(1), + w.stride(0), + w.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), + x_quant.stride(0), + x_quant.stride(1), + x_scales.stride(0), + x_scales.stride(1), + activation=_get_activation_from_str(activation) if activation else "", + use_activation=activation is not None, + ADD_BIAS=(bias is not None), + SKIP_REDUCE=skip_reduce, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + **config, + ) + + if config["NUM_KSPLIT"] > 1: + if skip_reduce: + return y_pp, x_quant, x_scales + + REDUCE_BLOCK_SIZE_M = 32 + REDUCE_BLOCK_SIZE_N = 32 + ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + triton.cdiv(N, REDUCE_BLOCK_SIZE_N), + ) + _gemm_splitk_reduce_kernel[grid_reduce]( + y_pp, + y, + bias, + M, + N, + y_pp.stride(0), + y_pp.stride(1), + y_pp.stride(2), + y.stride(0), + y.stride(1), + REDUCE_BLOCK_SIZE_M, + REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT, + triton.next_power_of_2(config["NUM_KSPLIT"]), + ADD_BIAS=(bias is not None), + activation=_get_activation_from_str(activation) if activation else "", + use_activation=activation is not None, + KERNEL_NAME="_fused_gemm_a16w16_quant_x_reduce_kernel", + ) + + return y, x_quant, x_scales diff --git a/aiter/ops/triton/quant/quant_mxfp8.py b/aiter/ops/triton/quant/quant_mxfp8.py new file mode 100644 index 0000000000..24a0136150 --- /dev/null +++ b/aiter/ops/triton/quant/quant_mxfp8.py @@ -0,0 +1,289 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Optional, Tuple +import torch +import triton + +from aiter.ops.triton._triton_kernels.quant.quant_mxfp8 import ( + _mxfp8_quant_kernel, + _fp8_legacy_to_mxfp8_kernel, + _rmsnorm_mxfp8_quant_kernel, + _dual_rmsnorm_mxfp8_quant_kernel, +) + +_QUANT_BLOCK_SIZE = 32 + + +def per_1x32_mxfp8_quant_triton( + x: torch.Tensor, + scale: Optional[torch.Tensor] = None, + quant_dtype: torch.dtype = torch.float8_e4m3fn, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Per-1x32 MXFP8 quantization (e8m0 scale + FP8 e4m3 values). + + Args: + x: Input tensor (..., K). Typically bf16 or fp16. K % 32 == 0. + scale: Pre-allocated scale tensor (M, K // 32) uint8. Optional. + quant_dtype: FP8 dtype to cast quantized values to. On MI3xx + torch.float8_e4m3fnuz is the canonical FP8 e4m3 type. torch.float8_e4m3fn + is acceptable on hardware that supports it. + + Returns: + Tuple of: + y: FP8 tensor of shape x.shape. + s: e8m0 (uint8) scale tensor of shape (..., K // 32). + """ + assert x.dim() >= 2, f"x must be at least 2D, got {x.dim()}" + orig_shape = x.shape + K = orig_shape[-1] + assert ( + K % _QUANT_BLOCK_SIZE == 0 + ), f"last dim K={K} must be a multiple of {_QUANT_BLOCK_SIZE}" + + x2d = x.reshape(-1, K).contiguous() + M = x2d.shape[0] + Ns = K // _QUANT_BLOCK_SIZE # number of scales per row + + y = torch.empty((M, K), dtype=quant_dtype, device=x.device) + if scale is None: + scale = torch.empty((M, Ns), dtype=torch.uint8, device=x.device) + else: + assert scale.shape == (M, Ns), f"scale shape {scale.shape} != ({M},{Ns})" + assert scale.dtype == torch.uint8 + + BLOCK_SIZE_N = triton.next_power_of_2(K) + NUM_PRGMS = M + grid = (NUM_PRGMS,) + + _mxfp8_quant_kernel[grid]( + x2d, + y, + scale, + M, + K, + x2d.stride(0), + x2d.stride(1), + y.stride(0), + y.stride(1), + scale.stride(0), + scale.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS, + ) + + y = y.view(*orig_shape[:-1], K) + s = scale.view(*orig_shape[:-1], Ns) + return y, s + + +_LEGACY_BLOCK_SIZE = 128 + + +def fp8_legacy_to_mxfp8( + x_fnuz: torch.Tensor, + x_scale_fp32: torch.Tensor, + y_fn: Optional[torch.Tensor] = None, + y_scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Transcode (FP8 e4m3fnuz, fp32 1x128 scale) -> (FP8 e4m3fn, e8m0 1x32 scale) + in a single Triton launch. Replaces the Python dequant+requant cascade + used when MXFP8 path receives legacy-formatted (FP8 + fp32 1x128) inputs. + + Args: + x_fnuz: FP8 e4m3fnuz tensor of shape (M, N), N % 32 == 0. + x_scale_fp32: fp32 scale of shape (M, N // 128). + y_fn: optional preallocated output FP8 e4m3fn tensor. + y_scale: optional preallocated uint8 e8m0 scale tensor. + + Returns: + y_fn (M, N) fp8 e4m3fn, y_scale (M, N // 32) uint8 e8m0. + """ + assert x_fnuz.dim() == 2, f"x must be 2D, got {x_fnuz.dim()}" + M, N = x_fnuz.shape + assert N % _QUANT_BLOCK_SIZE == 0 + assert N % _LEGACY_BLOCK_SIZE == 0 + assert x_scale_fp32.shape == ( + M, + N // _LEGACY_BLOCK_SIZE, + ), f"x_scale_fp32 shape {x_scale_fp32.shape} != ({M},{N // _LEGACY_BLOCK_SIZE})" + + Ns = N // _QUANT_BLOCK_SIZE + if y_fn is None: + y_fn = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x_fnuz.device) + if y_scale is None: + y_scale = torch.empty((M, Ns), dtype=torch.uint8, device=x_fnuz.device) + + BLOCK_SIZE_M = 1 + grid = (triton.cdiv(M, BLOCK_SIZE_M), Ns) + + _fp8_legacy_to_mxfp8_kernel[grid]( + x_fnuz, + x_scale_fp32, + y_fn, + y_scale, + M, + N, + x_fnuz.stride(0), + x_fnuz.stride(1), + x_scale_fp32.stride(0), + x_scale_fp32.stride(1), + y_fn.stride(0), + y_fn.stride(1), + y_scale.stride(0), + y_scale.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + LEGACY_BLOCK_SIZE=_LEGACY_BLOCK_SIZE, + ) + + return y_fn, y_scale + + +def rmsnorm_mxfp8_quant( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, + y: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused RMSNorm + MXFP8 (1x32 e8m0) quant in a single Triton launch. + + Args: + x: (M, K) bf16 or fp16. + weight: (K,) bf16 or fp16 RMSNorm weight. + eps: RMSNorm epsilon. + y: optional preallocated FP8 e4m3fn output (M, K). + scale: optional preallocated uint8 e8m0 output (M, K // 32). + + Returns: + y (M, K) fp8 e4m3fn, scale (M, K // 32) uint8. + """ + assert x.dim() == 2, f"x must be 2D, got {x.dim()}" + M, K = x.shape + assert weight.shape == (K,), f"weight shape {weight.shape} != ({K},)" + assert K % _QUANT_BLOCK_SIZE == 0 + Ns = K // _QUANT_BLOCK_SIZE + BLOCK_SIZE_K = triton.next_power_of_2(K) + + if y is None: + y = torch.empty((M, K), dtype=torch.float8_e4m3fn, device=x.device) + if scale is None: + scale = torch.empty((M, Ns), dtype=torch.uint8, device=x.device) + + NUM_PRGMS = M + grid = (NUM_PRGMS,) + + _rmsnorm_mxfp8_quant_kernel[grid]( + x, + weight, + y, + scale, + M, + K, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + scale.stride(0), + scale.stride(1), + eps, + BLOCK_SIZE_K=BLOCK_SIZE_K, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS, + ) + return y, scale + + +def dual_rmsnorm_mxfp8_quant( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + eps_q: float, + eps_k: Optional[float] = None, + yq: Optional[torch.Tensor] = None, + sq: Optional[torch.Tensor] = None, + yk: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fused dual RMSNorm in a single Triton launch. + + - Q side: RMSNorm(q, q_weight, eps_q) -> MXFP8 (FP8 e4m3fn + uint8 e8m0 1x32). + - K side: RMSNorm(k, k_weight, eps_k) -> bf16. + + Replaces the CK `fused_qk_rmsnorm_group_quant` kernel for the MXFP8 GEMM + path on V4 (Task #77): one launch instead of two (rmsnorm_mxfp8_quant + + rmsnorm2d_fwd_), eliminating the ~6us/layer launch-overhead regression. + + Args: + q: (M, KQ) bf16 or fp16 — Q-side input (e.g. q_lora). + k: (M, KK) bf16 or fp16 — K-side input (e.g. kv_pre). + q_weight: (KQ,) bf16 or fp16 — Q RMSNorm weight. + k_weight: (KK,) bf16 or fp16 — K RMSNorm weight. + eps_q: Q RMSNorm epsilon. + eps_k: K RMSNorm epsilon; defaults to eps_q. + yq, sq, yk: optional pre-allocated outputs. + + Returns: + yq (M, KQ) fp8 e4m3fn, sq (M, KQ // 32) uint8 e8m0, yk (M, KK) bf16. + """ + assert q.dim() == 2, f"q must be 2D, got {q.dim()}" + assert k.dim() == 2, f"k must be 2D, got {k.dim()}" + M, KQ = q.shape + Mk, KK = k.shape + assert M == Mk, f"q rows {M} != k rows {Mk}" + assert q_weight.shape == (KQ,), f"q_weight shape {q_weight.shape} != ({KQ},)" + assert k_weight.shape == (KK,), f"k_weight shape {k_weight.shape} != ({KK},)" + assert ( + KQ % _QUANT_BLOCK_SIZE == 0 + ), f"KQ={KQ} must be a multiple of {_QUANT_BLOCK_SIZE}" + if eps_k is None: + eps_k = eps_q + + Ns = KQ // _QUANT_BLOCK_SIZE + BLOCK_SIZE_KQ = triton.next_power_of_2(KQ) + BLOCK_SIZE_KK = triton.next_power_of_2(KK) + + if yq is None: + yq = torch.empty((M, KQ), dtype=torch.float8_e4m3fn, device=q.device) + if sq is None: + sq = torch.empty((M, Ns), dtype=torch.uint8, device=q.device) + if yk is None: + yk = torch.empty((M, KK), dtype=k.dtype, device=k.device) + + NUM_PRGMS = M + grid = (NUM_PRGMS,) + + _dual_rmsnorm_mxfp8_quant_kernel[grid]( + q, + k, + q_weight, + k_weight, + yq, + sq, + yk, + M, + KQ, + KK, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + yq.stride(0), + yq.stride(1), + sq.stride(0), + sq.stride(1), + yk.stride(0), + yk.stride(1), + eps_q, + eps_k, + BLOCK_SIZE_KQ=BLOCK_SIZE_KQ, + BLOCK_SIZE_KK=BLOCK_SIZE_KK, + QUANT_BLOCK_SIZE=_QUANT_BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS, + ) + return yq, sq, yk diff --git a/aiter/ops/triton/utils/_triton/tunning/ut_afp8wfp8_gemm_preshuffle.py b/aiter/ops/triton/utils/_triton/tunning/ut_afp8wfp8_gemm_preshuffle.py new file mode 100644 index 0000000000..8f13ac4b3e --- /dev/null +++ b/aiter/ops/triton/utils/_triton/tunning/ut_afp8wfp8_gemm_preshuffle.py @@ -0,0 +1,44 @@ +import sys +from _utils import ( + run_profile, + get_input_shape_and_config_list, +) + +############################################################ +# +import torch +from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import gemm_afp8wfp8_preshuffle +from op_tests.triton_tests.gemm.basic.test_gemm_afp8wfp8 import ( + generate_inputs, +) +from aiter.ops.triton.utils.types import get_fp8_dtypes +from aiter.ops.triton.utils.gemm_config_utils import compute_splitk_params + +############################################################ + +input_shape, config_list = get_input_shape_and_config_list(sys.argv, shape_size=3) +M, N, K = input_shape + +############################################################ +# +_, e4m3_type = get_fp8_dtypes() +dtype = torch.bfloat16 +x_fp8, w_fp8, w_kernel, x_scales, w_scales = generate_inputs( + *input_shape, + shuffle=True, +) +############################################################ + +for config in config_list: + if config is not None: + compute_splitk_params(config, K) + + def fn(): + ############################################################ + # + gemm_afp8wfp8_preshuffle( + x_fp8, w_kernel, x_scales, w_scales, dtype=dtype, config=config + ) + ############################################################ + + run_profile(fn) diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py b/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py new file mode 100644 index 0000000000..8b3bf5eae8 --- /dev/null +++ b/op_tests/triton_tests/gemm/basic/test_gemm_afp8wfp8.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import pytest +import torch + +from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( + gemm_afp8wfp8, + gemm_afp8wfp8_preshuffle, +) +from aiter.ops.shuffle import shuffle_weight +import aiter.ops.triton.utils._triton.arch_info as arch_info + +SCALE_GROUP_SIZE = 32 # A: 1x32 e8m0 scale group +W_SCALE_K_GROUP = 128 # B: 128 in K direction +W_SCALE_N_GROUP = 128 # B: 128 in N direction +FP8_MAX = 448.0 # e4m3 max + + +def e8m0_to_f32(x: torch.Tensor) -> torch.Tensor: + """Decode unsigned-biased e8m0 (uint8) to fp32. Bias 127, value = 2^(b-127).""" + return torch.exp2((x.to(torch.int32) - 127).to(torch.float32)) + + +def generate_inputs(M: int, N: int, K: int, shuffle: bool = False): + """Returns ``(x_fp8, w_fp8, w_kernel, x_scales, w_scales)``. + + ``w_fp8`` is always the unshuffled weight (for use by the fp32 reference). + ``w_kernel`` is the weight to pass to the kernel: identical to ``w_fp8`` + when ``shuffle=False``, or shuffled via ``shuffle_weight(layout=(16, 16))`` + when ``shuffle=True``. + """ + # Small random fp32 → fp8 e4m3fn, kept inside e4m3 range so the cast is exact-ish. + x_f32 = torch.randn((M, K), dtype=torch.float32, device="cuda") + w_f32 = torch.randn((N, K), dtype=torch.float32, device="cuda") + x_f32 = torch.clamp(x_f32, -FP8_MAX, FP8_MAX) + w_f32 = torch.clamp(w_f32, -FP8_MAX, FP8_MAX) + x_fp8 = x_f32.to(torch.float8_e4m3fn) + w_fp8 = w_f32.to(torch.float8_e4m3fn) + + # e8m0 scales near 127 (== 1.0) so the dequant has unit-ish magnitude. + x_scales = torch.randint( + 125, 130, (M, K // SCALE_GROUP_SIZE), dtype=torch.uint8, device="cuda" + ) + w_scales = torch.randint( + 125, + 130, + (N // W_SCALE_N_GROUP, K // W_SCALE_K_GROUP), + dtype=torch.uint8, + device="cuda", + ) + + if shuffle: + # shuffle_weight operates on raw bytes; view as uint8 to avoid dtype quirks. + w_kernel = shuffle_weight(w_fp8.view(torch.uint8), layout=(16, 16)) + else: + w_kernel = w_fp8 + + return x_fp8, w_fp8, w_kernel, x_scales, w_scales + + +def run_torch_gemm_afp8wfp8( + x_fp8: torch.Tensor, + w_fp8: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: + """Reference: dequant both operands to fp32 and run torch.mm.""" + M, K = x_fp8.shape + N, _ = w_fp8.shape + + x_view = x_fp8 if x_fp8.dtype != torch.uint8 else x_fp8.view(torch.float8_e4m3fn) + w_view = w_fp8 if w_fp8.dtype != torch.uint8 else w_fp8.view(torch.float8_e4m3fn) + x_f32 = x_view.to(torch.float32) + w_f32 = w_view.to(torch.float32) + + x_s_f32 = e8m0_to_f32(x_scales).repeat_interleave(SCALE_GROUP_SIZE, dim=1) + assert x_s_f32.shape == (M, K) + + w_s_f32 = e8m0_to_f32(w_scales) + w_s_f32 = w_s_f32.repeat_interleave(W_SCALE_N_GROUP, dim=0).repeat_interleave( + W_SCALE_K_GROUP, dim=1 + ) + assert w_s_f32.shape == (N, K) + + x_dq = x_f32 * x_s_f32 + w_dq = w_f32 * w_s_f32 + return torch.mm(x_dq, w_dq.T).to(out_dtype) + + +def get_shapes(): + # (M, N, K), with N % 128 == 0 and K % 128 == 0 to fit the 128x128 W-scale layout. + return [ + (m, n, k) + for m in [1, 4, 8, 16, 32, 64, 128] + for n, k in [ + (1536, 4096), + (4096, 1024), + (512, 4096), + (8192, 1024), + (2048, 7168), + (7168, 2048), + (768, 7168), + (7168, 384), + (8192, 1536), + ] + ] + + +@pytest.mark.parametrize("M, N, K", get_shapes()) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_gemm_afp8wfp8(M: int, N: int, K: int, dtype: torch.dtype): + torch.manual_seed(0) + if not arch_info.is_fp8_avail(): + pytest.skip("MXFP8 GEMM requires FP8-capable arch") + torch.cuda.empty_cache() + + x_fp8, w_fp8, w_kernel, x_scales, w_scales = generate_inputs(M, N, K, shuffle=False) + + torch_out = run_torch_gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype) + triton_out = gemm_afp8wfp8(x_fp8, w_kernel, x_scales, w_scales, dtype=dtype) + + torch.testing.assert_close(triton_out, torch_out, atol=0.03, rtol=1e-2) + + +@pytest.mark.parametrize("M, N, K", get_shapes()) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_gemm_afp8wfp8_preshuffle(M: int, N: int, K: int, dtype: torch.dtype): + torch.manual_seed(0) + if not arch_info.is_fp8_avail(): + pytest.skip("MXFP8 GEMM requires FP8-capable arch") + if N % 16 != 0 or K % 32 != 0: + pytest.skip("Preshuffle requires N % 16 == 0 and K % 32 == 0") + torch.cuda.empty_cache() + + x_fp8, w_fp8, w_kernel, x_scales, w_scales = generate_inputs(M, N, K, shuffle=True) + + torch_out = run_torch_gemm_afp8wfp8(x_fp8, w_fp8, x_scales, w_scales, dtype) + triton_out = gemm_afp8wfp8_preshuffle( + x_fp8, w_kernel, x_scales, w_scales, dtype=dtype + ) + + torch.testing.assert_close(triton_out, torch_out, atol=0.03, rtol=1e-2) diff --git a/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_quant_x.py b/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_quant_x.py new file mode 100644 index 0000000000..2a7ccf5624 --- /dev/null +++ b/op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_quant_x.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import torch +import torch.nn.functional as F +import pytest + +from aiter.ops.triton.gemm.fused.fused_gemm_a16w16_quant_x import ( + fused_gemm_a16w16_quant_x, +) +from op_tests.triton_tests.gemm.basic.test_gemm_a16w16 import ( + generate_gemm_a16w16_inputs, +) + +_QUANT_BLOCK_SIZE = 32 +# 0xFF800000 in two's complement int32. Mask keeps sign + 8-bit exponent + top mantissa bit. +_E8M0_MASK_INT32 = -8388608 + + +def torch_mxfp8_quant_from_fp32(x_fp32: torch.Tensor): + """Bit-faithful port of `_mxfp8_quant_kernel` quant logic, taking fp32 input. + + Computes per-1x32 e8m0 scale (uint8) and FP8 e4m3fn values. + """ + assert x_fp32.dim() == 2, f"x_fp32 must be 2D, got {x_fp32.dim()}" + M, K = x_fp32.shape + assert K % _QUANT_BLOCK_SIZE == 0 + Ng = K // _QUANT_BLOCK_SIZE + x_2d = x_fp32.reshape(M, Ng, _QUANT_BLOCK_SIZE).to(torch.float32) + amax = torch.amax(torch.abs(x_2d), dim=-1, keepdim=True) # (M, Ng, 1) + + amax_i32 = amax.contiguous().view(torch.int32) + amax_i32 = (amax_i32 + 0x200000) & _E8M0_MASK_INT32 + amax_p2 = amax_i32.view(torch.float32) + + scale_unbiased = torch.log2(amax_p2).floor() - 8 + scale_unbiased = torch.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(torch.int32) + 127).to(torch.uint8) + quant_scale = torch.exp2(-scale_unbiased) + + qx_2d = x_2d * quant_scale + qx = qx_2d.reshape(M, K) + y_fp8 = qx.to(torch.float8_e4m3fn) + s = scale_e8m0.reshape(M, Ng) + return y_fp8, s + + +def get_x_vals(): + x_vals = [(1024, 1024, 1024)] + x_vals += [(2048, 2048, 2048)] + # DSv4 router gate: num_tokens x 384 x 7168 + x_vals += [(2**i, 384, 7168) for i in range(5, 9)] + # DSR1 router GEMM + x_vals += [(2**i, 256, 7168) for i in range(5, 9)] + return x_vals + + +def _assert_quant_close(triton_x_quant, triton_x_scales, x): + ref_x_quant, ref_x_scales = torch_mxfp8_quant_from_fp32(x.to(torch.float32)) + # e8m0 scales: bit-exact (integer-only after fp32 cast). + torch.testing.assert_close(triton_x_scales, ref_x_scales) + # Quantized values: compare via uint8 view (allow off-by-1 for any rounding + # subtlety in the fp32->fp8 cast). + torch.testing.assert_close( + triton_x_quant.view(torch.uint8).to(torch.int32), + ref_x_quant.view(torch.uint8).to(torch.int32), + atol=1, + rtol=0, + ) + + +@pytest.mark.parametrize("M, N, K", get_x_vals()) +def test_fused_gemm_a16w16_quant_x(M: int, N: int, K: int): + torch.cuda.empty_cache() + x, w, _, _, _ = generate_gemm_a16w16_inputs( + M, N, K, dtype=torch.bfloat16, output=False + ) + + torch_y = F.linear(x, w, bias=None) + + triton_y, triton_x_quant, triton_x_scales = fused_gemm_a16w16_quant_x(x, w) + + torch.testing.assert_close(triton_y, torch_y, atol=1e-1, rtol=1e-2) + _assert_quant_close(triton_x_quant, triton_x_scales, x) + + +def get_fewer_x_vals(): + x_vals = [(16, 1024, 1024)] + x_vals += [(128, 8192, 512)] + x_vals += [(256, 512, 8192)] + x_vals += [(1024, 1024, 1024)] + return x_vals + + +@pytest.mark.parametrize("activation", ["gelu", "gelu_tanh", "silu"]) +@pytest.mark.parametrize("M, N, K", get_fewer_x_vals()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("output", [True, False]) +def test_fused_gemm_a16w16_quant_x_activation( + M: int, N: int, K: int, dtype, output, activation +): + x, w, _, _, y = generate_gemm_a16w16_inputs(M, N, K, dtype, output=output) + + torch_y = F.linear(x, w, bias=None) + if activation == "gelu": + torch_y = F.gelu(torch_y) + elif activation == "gelu_tanh": + torch_y = F.gelu(torch_y, approximate="tanh") + elif activation == "silu": + torch_y = F.silu(torch_y) + + triton_y, triton_x_quant, triton_x_scales = fused_gemm_a16w16_quant_x( + x, + w, + bias=None, + dtype=dtype, + y=y, + activation=activation, + ) + + torch.testing.assert_close(triton_y, torch_y, atol=1e-1, rtol=1e-2) + _assert_quant_close(triton_x_quant, triton_x_scales, x) + + +@pytest.mark.parametrize("M, N, K", get_fewer_x_vals()) +@pytest.mark.parametrize("skip_reduce", [True, False]) +def test_fused_gemm_a16w16_quant_x_skip_reduce(M: int, N: int, K: int, skip_reduce): + torch.cuda.empty_cache() + x, w, _, _, _ = generate_gemm_a16w16_inputs( + M, N, K, dtype=torch.bfloat16, output=False + ) + + torch_y = F.linear(x, w, bias=None) + + triton_y, triton_x_quant, triton_x_scales = fused_gemm_a16w16_quant_x( + x, w, skip_reduce=skip_reduce + ) + + if triton_y.dim() == 3: + triton_y = triton_y.sum(axis=0).to(torch.bfloat16) + + torch.testing.assert_close(triton_y, torch_y, atol=1e-3, rtol=1e-2) + _assert_quant_close(triton_x_quant, triton_x_scales, x) diff --git a/op_tests/triton_tests/quant/test_quant_mxfp8.py b/op_tests/triton_tests/quant/test_quant_mxfp8.py new file mode 100644 index 0000000000..9433eefa87 --- /dev/null +++ b/op_tests/triton_tests/quant/test_quant_mxfp8.py @@ -0,0 +1,362 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import pytest +import torch + +from aiter.ops.triton.quant.quant_mxfp8 import ( + per_1x32_mxfp8_quant_triton, + fp8_legacy_to_mxfp8, + rmsnorm_mxfp8_quant, + dual_rmsnorm_mxfp8_quant, +) +import aiter.ops.triton.utils._triton.arch_info as arch_info + +QUANT_BLOCK_SIZE = 32 +LEGACY_BLOCK_SIZE = 128 +# 0xFF800000 in two's complement int32. Mask keeps sign + 8-bit exponent + top mantissa bit. +_E8M0_MASK_INT32 = -8388608 + + +def torch_mxfp8_quant_from_fp32(x_fp32: torch.Tensor): + """Bit-faithful port of `_mxfp8_quant_kernel` quant logic, taking fp32 input. + + Computes per-1x32 e8m0 scale (uint8) and FP8 e4m3fn values. + """ + assert x_fp32.dim() == 2, f"x_fp32 must be 2D, got {x_fp32.dim()}" + M, K = x_fp32.shape + assert K % QUANT_BLOCK_SIZE == 0 + Ng = K // QUANT_BLOCK_SIZE + x_2d = x_fp32.reshape(M, Ng, QUANT_BLOCK_SIZE).to(torch.float32) + amax = torch.amax(torch.abs(x_2d), dim=-1, keepdim=True) # (M, Ng, 1) + + # Same bit-level "round up to e8m0-representable pow-2" as the kernel. + amax_i32 = amax.contiguous().view(torch.int32) + amax_i32 = (amax_i32 + 0x200000) & _E8M0_MASK_INT32 + amax_p2 = amax_i32.view(torch.float32) + + scale_unbiased = torch.log2(amax_p2).floor() - 8 + scale_unbiased = torch.clamp(scale_unbiased, min=-127, max=127) + scale_e8m0 = (scale_unbiased.to(torch.int32) + 127).to(torch.uint8) + quant_scale = torch.exp2(-scale_unbiased) + + qx_2d = x_2d * quant_scale # broadcast over inner-32 + qx = qx_2d.reshape(M, K) + y_fp8 = qx.to(torch.float8_e4m3fn) + s = scale_e8m0.reshape(M, Ng) + return y_fp8, s + + +def e8m0_to_f32(x: torch.Tensor) -> torch.Tensor: + return torch.exp2((x.to(torch.int32) - 127).to(torch.float32)) + + +# ----------------------------------------------------------------------------- +# per_1x32_mxfp8_quant_triton +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "M, K", + [ + (1, 32), + (1, 64), + (1, 128), + (2, 32), + (8, 64), + (16, 128), + (32, 256), + (64, 512), + (128, 1024), + (137, 64), # non-power-of-2 M + (256, 32), + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_per_1x32_mxfp8_quant(M: int, K: int, dtype: torch.dtype): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(20) + + x = torch.randn((M, K), dtype=dtype, device="cuda") * 4.0 + + # Reference path: emulate the kernel in fp32 (matching its precision). + x_fp32 = x.to(torch.float32) + y_ref, s_ref = torch_mxfp8_quant_from_fp32(x_fp32) + + # Triton path. + y_kern, s_kern = per_1x32_mxfp8_quant_triton(x) + + # Scales must be bit-exact: the e8m0 derivation is integer-only after + # the fp32 cast, and amax is order-independent. + torch.testing.assert_close(s_kern, s_ref) + + # Quantized values: compare via the uint8 view (allow off-by-1 for any + # rounding-mode subtlety in the fp32→fp8 cast). + torch.testing.assert_close( + y_kern.view(torch.uint8).to(torch.int32), + y_ref.view(torch.uint8).to(torch.int32), + atol=1, + rtol=0, + ) + + +def test_per_1x32_mxfp8_quant_preallocated_scale(): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(20) + + M, K = 64, 256 + x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda") + scale_pre = torch.empty( + (M, K // QUANT_BLOCK_SIZE), dtype=torch.uint8, device="cuda" + ) + y, s = per_1x32_mxfp8_quant_triton(x, scale=scale_pre) + assert s.data_ptr() == scale_pre.data_ptr() + + y_ref, s_ref = torch_mxfp8_quant_from_fp32(x.to(torch.float32)) + torch.testing.assert_close(s, s_ref) + + +def test_per_1x32_mxfp8_quant_multidim(): + """Wrapper folds higher dims into M; sanity-check 3D input.""" + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(0) + + B, M, K = 4, 8, 128 + x = torch.randn((B, M, K), dtype=torch.bfloat16, device="cuda") + y, s = per_1x32_mxfp8_quant_triton(x) + assert y.shape == (B, M, K) + assert s.shape == (B, M, K // QUANT_BLOCK_SIZE) + + y_ref, s_ref = torch_mxfp8_quant_from_fp32(x.reshape(-1, K).to(torch.float32)) + torch.testing.assert_close(s.reshape(-1, K // QUANT_BLOCK_SIZE), s_ref) + + +# ----------------------------------------------------------------------------- +# fp8_legacy_to_mxfp8 +# ----------------------------------------------------------------------------- + + +def torch_fp8_legacy_to_mxfp8(x_fnuz: torch.Tensor, x_scale_fp32: torch.Tensor): + """Reference: dequantize fnuz fp8 with the 1x128 fp32 scale, then run + the standard mxfp8 1x32 quant on the result.""" + M, N = x_fnuz.shape + x_dq = x_fnuz.to(torch.float32) * x_scale_fp32.repeat_interleave( + LEGACY_BLOCK_SIZE, dim=1 + ) + return torch_mxfp8_quant_from_fp32(x_dq) + + +@pytest.mark.parametrize( + "M, N", + [ + (1, 128), + (8, 128), + (16, 256), + (32, 512), + (64, 1024), + (128, 256), + (37, 256), # non-pow-2 M + ], +) +def test_fp8_legacy_to_mxfp8(M: int, N: int): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(5) + + # Random values within e4m3fnuz range, then cast to fnuz fp8. + x_f32 = (torch.randn((M, N), dtype=torch.float32, device="cuda")).clamp(-200, 200) + x_fnuz = x_f32.to(torch.float8_e4m3fnuz) + # Random fp32 1x128 scales in a moderate range so the dequant stays within fp8. + x_scale_fp32 = ( + torch.rand((M, N // LEGACY_BLOCK_SIZE), dtype=torch.float32, device="cuda") + * 0.5 + + 0.25 + ) + + y_ref, s_ref = torch_fp8_legacy_to_mxfp8(x_fnuz, x_scale_fp32) + y_kern, s_kern = fp8_legacy_to_mxfp8(x_fnuz, x_scale_fp32) + + torch.testing.assert_close(s_kern, s_ref) + torch.testing.assert_close( + y_kern.view(torch.uint8).to(torch.int32), + y_ref.view(torch.uint8).to(torch.int32), + atol=1, + rtol=0, + ) + + +def test_fp8_legacy_to_mxfp8_preallocated(): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(5) + + M, N = 16, 256 + x_fnuz = (torch.randn((M, N), device="cuda") * 4).to(torch.float8_e4m3fnuz) + x_scale_fp32 = torch.rand((M, N // LEGACY_BLOCK_SIZE), device="cuda") * 0.5 + 0.25 + y_pre = torch.empty((M, N), dtype=torch.float8_e4m3fn, device="cuda") + s_pre = torch.empty((M, N // QUANT_BLOCK_SIZE), dtype=torch.uint8, device="cuda") + y, s = fp8_legacy_to_mxfp8(x_fnuz, x_scale_fp32, y_fn=y_pre, y_scale=s_pre) + assert y.data_ptr() == y_pre.data_ptr() + assert s.data_ptr() == s_pre.data_ptr() + + y_ref, s_ref = torch_fp8_legacy_to_mxfp8(x_fnuz, x_scale_fp32) + torch.testing.assert_close(s, s_ref) + + +# ----------------------------------------------------------------------------- +# rmsnorm_mxfp8_quant +# ----------------------------------------------------------------------------- + + +def torch_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: + x_f32 = x.to(torch.float32) + g_f32 = weight.to(torch.float32) + rstd = torch.rsqrt(x_f32.pow(2).mean(-1, keepdim=True) + eps) + return x_f32 * rstd * g_f32 + + +def torch_rmsnorm_mxfp8_quant(x, weight, eps): + y_fp32 = torch_rmsnorm(x, weight, eps) + return torch_mxfp8_quant_from_fp32(y_fp32) + + +@pytest.mark.parametrize( + "M, K", + [ + (1, 32), + (1, 128), + (8, 128), + (16, 256), + (32, 512), + (64, 1024), + (128, 2048), + (97, 64), # non-pow-2 M, K=64 + (200, 192), # non-pow-2 K (still multiple of 32) + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_rmsnorm_mxfp8_quant(M: int, K: int, dtype: torch.dtype): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(11) + + x = torch.randn((M, K), dtype=dtype, device="cuda") + weight = torch.randn((K,), dtype=dtype, device="cuda") * 0.5 + 1.0 + eps = 1e-5 + + y_ref, s_ref = torch_rmsnorm_mxfp8_quant(x, weight, eps) + y_kern, s_kern = rmsnorm_mxfp8_quant(x, weight, eps) + + # Hardware rsqrt vs torch.rsqrt can disagree by a ULP; that may flip a single + # e8m0 bin near a power-of-2 boundary. Compare dequantized values instead. + s_ref_f32 = e8m0_to_f32(s_ref).repeat_interleave(QUANT_BLOCK_SIZE, dim=1) + s_kern_f32 = e8m0_to_f32(s_kern).repeat_interleave(QUANT_BLOCK_SIZE, dim=1) + y_ref_dq = y_ref.to(torch.float32) * s_ref_f32 + y_kern_dq = y_kern.to(torch.float32) * s_kern_f32 + + torch.testing.assert_close(y_kern_dq, y_ref_dq, atol=5e-2, rtol=5e-2) + + +def test_rmsnorm_mxfp8_quant_preallocated(): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(11) + + M, K = 32, 256 + x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda") + weight = torch.randn((K,), dtype=torch.bfloat16, device="cuda") + y_pre = torch.empty((M, K), dtype=torch.float8_e4m3fn, device="cuda") + s_pre = torch.empty((M, K // QUANT_BLOCK_SIZE), dtype=torch.uint8, device="cuda") + y, s = rmsnorm_mxfp8_quant(x, weight, 1e-5, y=y_pre, scale=s_pre) + assert y.data_ptr() == y_pre.data_ptr() + assert s.data_ptr() == s_pre.data_ptr() + + +# ----------------------------------------------------------------------------- +# dual_rmsnorm_mxfp8_quant +# ----------------------------------------------------------------------------- + + +def torch_dual_rmsnorm_mxfp8_quant(q, k, q_weight, k_weight, eps_q, eps_k): + yq_fp32 = torch_rmsnorm(q, q_weight, eps_q) + yq, sq = torch_mxfp8_quant_from_fp32(yq_fp32) + yk_fp32 = torch_rmsnorm(k, k_weight, eps_k) + yk = yk_fp32.to(k.dtype) + return yq, sq, yk + + +@pytest.mark.parametrize( + "M, KQ, KK", + [ + (1, 32, 32), + (1, 128, 64), + (8, 256, 128), + (16, 512, 256), + (32, 1024, 512), + (64, 2048, 1024), + (47, 96, 80), # non-pow-2 sizes + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_dual_rmsnorm_mxfp8_quant(M: int, KQ: int, KK: int, dtype: torch.dtype): + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(13) + + q = torch.randn((M, KQ), dtype=dtype, device="cuda") + k = torch.randn((M, KK), dtype=dtype, device="cuda") + q_weight = torch.randn((KQ,), dtype=dtype, device="cuda") * 0.5 + 1.0 + k_weight = torch.randn((KK,), dtype=dtype, device="cuda") * 0.5 + 1.0 + eps_q, eps_k = 1e-5, 2e-5 + + yq_ref, sq_ref, yk_ref = torch_dual_rmsnorm_mxfp8_quant( + q, k, q_weight, k_weight, eps_q, eps_k + ) + yq_kern, sq_kern, yk_kern = dual_rmsnorm_mxfp8_quant( + q, k, q_weight, k_weight, eps_q, eps_k + ) + + # Q side: compare dequantized values (rsqrt jitter -> tolerate e8m0 ULP flips). + sq_ref_f32 = e8m0_to_f32(sq_ref).repeat_interleave(QUANT_BLOCK_SIZE, dim=1) + sq_kern_f32 = e8m0_to_f32(sq_kern).repeat_interleave(QUANT_BLOCK_SIZE, dim=1) + yq_ref_dq = yq_ref.to(torch.float32) * sq_ref_f32 + yq_kern_dq = yq_kern.to(torch.float32) * sq_kern_f32 + torch.testing.assert_close(yq_kern_dq, yq_ref_dq, atol=5e-2, rtol=5e-2) + + # K side: bf16/fp16 RMSNorm output. + torch.testing.assert_close(yk_kern, yk_ref, atol=5e-3, rtol=5e-3) + + +def test_dual_rmsnorm_mxfp8_quant_default_eps_k(): + """eps_k defaults to eps_q when not provided.""" + if not arch_info.is_fp8_avail(): + pytest.skip("FP8 not supported on this arch") + torch.cuda.empty_cache() + torch.manual_seed(13) + + M, KQ, KK = 16, 128, 96 + dtype = torch.bfloat16 + q = torch.randn((M, KQ), dtype=dtype, device="cuda") + k = torch.randn((M, KK), dtype=dtype, device="cuda") + q_weight = torch.randn((KQ,), dtype=dtype, device="cuda") + k_weight = torch.randn((KK,), dtype=dtype, device="cuda") + eps = 1e-5 + + yq_a, sq_a, yk_a = dual_rmsnorm_mxfp8_quant(q, k, q_weight, k_weight, eps) + yq_b, sq_b, yk_b = dual_rmsnorm_mxfp8_quant( + q, k, q_weight, k_weight, eps, eps_k=eps + ) + torch.testing.assert_close(yq_a.view(torch.uint8), yq_b.view(torch.uint8)) + torch.testing.assert_close(sq_a, sq_b) + torch.testing.assert_close(yk_a, yk_b)