From d26892859aa3473623a422c3de20053020d25909 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Mon, 18 May 2026 13:04:35 +0100 Subject: [PATCH 01/14] [Feat] Add FlyDSL MoE sorting kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drop-in replacement for OPUS/CK moe_sorting in aiter's fused_moe. Kernel paths: - T <= 16: decode — single-block LDS histogram + DPP prefix sum - 16 < T <= 2048: p0v2 + p23 — per-expert scatter + parallel prefix sum - T > 2048: 4-kernel fused — K1 clear + K2 scatter + K3 count + p23 Correctness: 32 CI tests + 14 large_shape (46 total), covering all 11 production MoE models (E=8..513, topk=2..9). --- kernels/moe_sorting_kernel.py | 2034 +++++++++++++++++++++++++++++ tests/kernels/test_moe_sorting.py | 837 ++++++++++++ 2 files changed, 2871 insertions(+) create mode 100644 kernels/moe_sorting_kernel.py create mode 100644 tests/kernels/test_moe_sorting.py diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py new file mode 100644 index 00000000..5ce40ff3 --- /dev/null +++ b/kernels/moe_sorting_kernel.py @@ -0,0 +1,2034 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""MoE token sorting kernel (FlyDSL). + +Implements the MoE sorting operation used in DeepSeek R1 and similar MoE models. +Given router top-k selections (topk_ids, topk_weights), reorganizes tokens by expert +for efficient batched expert GEMM execution. + +Algorithm: counting sort in LDS (histogram → prefix-sum → scatter). + +Two paths: + - Decode (small T): single kernel, all phases in LDS. + - Prefill (large T): 4 kernels via HBM workspace (ClearWS → P0 scatter → P1 count → P23 prefix-sum+scatter). 
def _unwrap_val(v):
    """Return the raw MLIR ``ir.Value`` behind *v*.

    Objects that expose an ``ir_value()`` method are unwrapped through it;
    anything else is handed back untouched.
    """
    if hasattr(v, "ir_value"):
        return v.ir_value()
    return v


def _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE):
    """Compute an inclusive prefix sum of ``val`` across the lanes of one wave.

    The scan has two stages:

    1. Four DPP ``row_shr`` steps (shift 1, 2, 4, 8). After each step a lane
       adds the shifted-in value only if its lane id is at least the shift
       amount; other lanes keep their running value.
    2. One ``ds_bpermute`` that pulls the value held by the lane just below
       the current 16-lane group (``(lane & 0x30) - 1``), added by lanes
       >= 16.  When ``WARP_SIZE > 32`` a second ``ds_bpermute`` pulls from
       ``(lane & 0x30) - 17`` and is added by lanes >= 32.

    Must be called while tracing a ``@flyc.kernel`` body: every step emits
    MLIR operations at the current insertion point.

    Parameters
    ----------
    val :
        Per-lane i32 DSL value to scan.
    lane :
        Lane index of the calling thread inside its wave (i32 DSL value).
    WARP_SIZE : int
        Wave width known at trace time; selects whether the second
        cross-row step is emitted.

    Returns
    -------
    The per-lane inclusive prefix sum as a DSL value.
    """
    # (DPP control word, first lane id that accumulates) for each row_shr step.
    row_scan_steps = (
        (DPP_ROW_SHR_1, 1),
        (DPP_ROW_SHR_2, 2),
        (DPP_ROW_SHR_4, 4),
        (DPP_ROW_SHR_8, 8),
    )

    running = val
    running_raw = _unwrap_val(running)
    fill_raw = _unwrap_val(fx.Int32(0))  # value shifted in where no source lane exists

    for dpp_ctrl, first_lane in row_scan_steps:
        shifted_in = fly_rocdl.update_dpp(
            T.i32, fill_raw, running_raw, dpp_ctrl, DPP_ROW_MASK, DPP_BANK_MASK, True
        )
        accumulates = lane >= fx.Int32(first_lane)
        running = accumulates.select(running + fx.Int32(shifted_in), running)
        running_raw = _unwrap_val(running)

    # Cross-row step 1: fetch from the last lane of the preceding 16-lane group.
    # The bpermute address is the source lane index scaled by 4.
    prev_group_tail = (lane & fx.Int32(0x30)) - fx.Int32(1)
    from_prev_group = fly_rocdl.ds_bpermute(T.i32, prev_group_tail * fx.Int32(4), running)
    running = (lane >= fx.Int32(16)).select(running + fx.Int32(from_prev_group), running)

    if WARP_SIZE <= 32:
        return running

    # Cross-row step 2 (wide waves only): fetch from one further group back.
    two_groups_back = (lane & fx.Int32(0x30)) - fx.Int32(17)
    from_far_group = fly_rocdl.ds_bpermute(T.i32, two_groups_back * fx.Int32(4), running)
    return (lane >= fx.Int32(32)).select(running + fx.Int32(from_far_group), running)
@contextmanager
def _if_then(if_op):
    """Enter the then-region of an ``scf.IfOp`` (pattern from moe_gemm_2stage.py).

    Ops created inside the ``with`` body are inserted into ``if_op.then_block``.
    On exit an empty ``scf.YieldOp`` is appended unless the block already ends
    with one, so the region is always properly terminated.
    """
    with ir.InsertionPoint(if_op.then_block):
        try:
            yield if_op.then_block
        finally:
            blk = if_op.then_block
            if (not blk.operations) or not isinstance(blk.operations[-1], scf.YieldOp):
                scf.YieldOp([])


# ---------------------------------------------------------------------------
# FlyDSL GPU kernel — decode path (single kernel, SubTokenOneShot)
# ---------------------------------------------------------------------------
@functools.lru_cache(maxsize=256)
def compile_moe_sorting_decode(
    *,
    num_experts: int,
    topk: int,
    max_tokens: int = 128,
    unit_size: int = UNIT_SIZE,
    has_mask: bool = False,
):
    """Compile the decode-path MoE sorting kernel.

    Block 0 of the launch performs the whole counting sort in LDS
    (histogram -> padded prefix sum -> scatter); every other block zeroes
    ``moe_buf`` in 4-element vector stores.

    Parameters
    ----------
    num_experts : int
        Number of routed experts (e.g. 256 for DeepSeek R1).
    topk : int
        Experts per token (e.g. 8 for DeepSeek R1).
    max_tokens : int
        Upper bound on T for LDS sizing. Actual T is passed at runtime.
        The number of token rows actually held in LDS (``sub_tokens``) is
        additionally capped by the LDS budget; assignments of tokens at or
        beyond that row count are not written to the histogram.
    unit_size : int
        GEMM tile-M for padding alignment (default 32).
    has_mask : bool
        When True the kernel reads ``expert_mask_tensor``: experts whose mask
        entry is 0 get a zero padded count, are skipped during scatter, and
        ``sorted_expert_ids`` holds the index among unmasked experts (prefix
        sum of the mask) instead of the global expert id.

    Returns
    -------
    The ``@flyc.jit`` launcher ``launch_moe_sorting_decode``. Results are
    cached per argument combination by ``functools.lru_cache``.

    Raises
    ------
    ValueError
        If the per-occupancy LDS budget cannot hold the two cumsum buffers
        plus at least one 8-row histogram tile for ``num_experts``.
    """
    arch = get_hip_arch()
    E = num_experts
    # CDNA (warp64): 512 threads = 8 waves, affordable cross-wave reduction.
    # RDNA (warp32): cap at 256 to avoid 16-wave overhead.
    max_decode_block = 512 if WARP_SIZE == 64 else 256
    DECODE_BLOCK = 256 if E <= 256 else min(512, max_decode_block)
    NUM_WAVES = DECODE_BLOCK // WARP_SIZE
    smem_cols = E + 1

    # LDS sizing: sub_tokens rows for the token×expert histogram
    # Match CK's sizing: total LDS / occupancy / smem_cols, rounded to 8.
    # gfx95* gets the larger budget; gfx94* and unknown archs use the
    # conservative 64 KiB value.
    lds_capacity_bytes = 163840 if str(arch).startswith("gfx95") else 65536

    lds_capacity_ints = lds_capacity_bytes // 4
    target_occupancy = 2
    r = lds_capacity_ints // target_occupancy // smem_cols
    sub_unroll = 8
    cumsum_bufs = 2
    if r < (cumsum_bufs + sub_unroll):
        raise ValueError(
            f"LDS too small for E={E}: need at least " f"{(cumsum_bufs + sub_unroll) * smem_cols * 4} bytes"
        )
    r_for_sub = ((r - cumsum_bufs) // sub_unroll) * sub_unroll
    r_token_min = ((max_tokens + sub_unroll - 1) // sub_unroll) * sub_unroll
    r_for_sub = min(r_for_sub, r_token_min)
    sub_tokens = r_for_sub
    # Token rows the histogram fill may touch: never more than were allocated.
    fill_tokens = min(max_tokens, sub_tokens)

    # SmemAllocator for the LDS regions
    allocator = SmemAllocator(None, arch=arch)

    # Region 0: cumsum[E+1] (exclusive prefix sums per expert)
    cumsum_offset = allocator._align(allocator.ptr, 16)
    allocator.ptr = cumsum_offset + smem_cols * 4

    # Region 1: cumdup[E+1] (duplicate of cumsum for scatter phase)
    cumdup_offset = allocator._align(allocator.ptr, 16)
    allocator.ptr = cumdup_offset + smem_cols * 4

    # Region 2: tokens_mesh[sub_tokens, smem_cols]
    mesh_offset = allocator._align(allocator.ptr, 16)
    allocator.ptr = mesh_offset + sub_tokens * smem_cols * 4

    # Region 3: cross-wave scratch for all-wave parallel prefix sum [NUM_WAVES]
    scratch_offset = allocator._align(allocator.ptr, 16)
    allocator.ptr = scratch_offset + NUM_WAVES * 4

    # Helpers for raw memref LDS access (used inside scf.for instead of SmemPtr)
    def _to_index(v):
        """Convert i32 or index DSL value to raw MLIR index value."""
        raw = v.ir_value() if hasattr(v, "ir_value") else v
        if isinstance(raw.type, ir.IndexType):
            return raw
        return ArithValue(v).index_cast(T.index)

    def _lds_load(raw_mr, idx):
        """Load i32 from LDS raw memref. idx can be i32 or index."""
        return fx.Int32(memref_ops.load(raw_mr, [_to_index(idx)]))

    def _lds_store(raw_mr, val, idx):
        """Store i32 to LDS raw memref. idx can be i32 or index."""
        v = val.ir_value() if hasattr(val, "ir_value") else val
        memref_ops.store(v, raw_mr, [_to_index(idx)])

    def _unwrap(v):
        """Unwrap DSL value to raw MLIR ir.Value for scf.for init."""
        return v.ir_value() if hasattr(v, "ir_value") else v

    @flyc.kernel(known_block_size=[DECODE_BLOCK, 1, 1])
    def moe_sorting_decode_kernel(
        topk_ids_tensor: fx.Tensor,
        topk_weights_tensor: fx.Tensor,
        sorted_token_ids: fx.Tensor,
        sorted_weights_out: fx.Tensor,
        sorted_expert_ids: fx.Tensor,
        num_valid_ids: fx.Tensor,
        moe_buf: fx.Tensor,
        expert_mask_tensor: fx.Tensor,
        i32_tokens: fx.Int32,
        i32_moe_buf_elems: fx.Int32,
    ):
        bid = gpu.block_idx.x
        tid = gpu.thread_idx.x
        lane = tid % WARP_SIZE
        wave = tid // WARP_SIZE
        tokens = i32_tokens
        c_zero_i32 = fx.Int32(0)
        c_one_i32 = fx.Int32(1)
        c_oob_idx = fx.Int32(0x7FFFFFFF)
        # Defined once at kernel top level so it dominates BOTH scf.if regions
        # below (zeroing and sorting) — do not re-bind it inside a region.
        c4_i32 = fx.Int32(4)

        # Buffer resources (needed by both paths, defined at top level)
        moe_buf_rsrc = buffer_ops.create_buffer_resource(moe_buf, max_size=True)
        topk_ids_rsrc = buffer_ops.create_buffer_resource(topk_ids_tensor, max_size=True)
        weights_rsrc = buffer_ops.create_buffer_resource(topk_weights_tensor, max_size=True)
        sorted_ids_rsrc = buffer_ops.create_buffer_resource(sorted_token_ids, max_size=True)
        sorted_w_rsrc = buffer_ops.create_buffer_resource(sorted_weights_out, max_size=True)
        sorted_e_rsrc = buffer_ops.create_buffer_resource(sorted_expert_ids, max_size=True)
        nvalid_rsrc = buffer_ops.create_buffer_resource(num_valid_ids, max_size=True)
        mask_rsrc = buffer_ops.create_buffer_resource(expert_mask_tensor, max_size=True)

        # LDS: get RAW memrefs ONCE — dominates all child scf.for/scf.if regions.
        base_ptr = allocator.get_base()
        cumsum_mr = SmemPtr(base_ptr, cumsum_offset, T.i32, shape=(smem_cols,)).get()
        cumdup_mr = SmemPtr(base_ptr, cumdup_offset, T.i32, shape=(smem_cols,)).get()
        mesh_mr = SmemPtr(base_ptr, mesh_offset, T.i32, shape=(sub_tokens * smem_cols,)).get()

        c_topk = fx.Int32(topk)
        c_E = fx.Int32(E)
        c_unit = fx.Int32(unit_size)
        c_sub_tokens = fx.Int32(sub_tokens)
        c_smem_cols = fx.Int32(smem_cols)
        c_sentinel = fx.Int32((topk << 24))

        # =================== MOE_BUF ZEROING (blocks > 0 only) ===============
        is_zero_block = bid != c_zero_i32
        _if_zero = scf.IfOp(is_zero_block.ir_value())
        with _if_then(_if_zero):
            # Work is expressed in 4-element vectors ("v4"); blocks 1..grid-1
            # stride over the buffer together.
            zero_gid_v4 = (bid - c_one_i32) * fx.Int32(DECODE_BLOCK) + tid
            num_zero_blocks = gpu.grid_dim.x - c_one_i32
            zero_stride_v4 = num_zero_blocks * fx.Int32(DECODE_BLOCK)
            i32_moe_buf_v4 = i32_moe_buf_elems >> fx.Int32(2)
            zero_niters = (i32_moe_buf_v4 + zero_stride_v4 - c_one_i32) // zero_stride_v4
            _zs = fx.Index(0)
            _ze = ArithValue(zero_niters).index_cast(T.index)
            _z1 = fx.Index(1)
            c_zero_v4 = fx.Vector.filled(4, 0, fx.Int32)
            for _z in range(_zs, _ze, _z1):
                z_idx_v4 = zero_gid_v4 + fx.Int32(_z) * zero_stride_v4
                z_valid = z_idx_v4 < i32_moe_buf_v4
                # Out-of-range lanes target c_oob_idx so the buffer store is dropped.
                z_elem = z_valid.select(z_idx_v4 * c4_i32, c_oob_idx)
                buffer_ops.buffer_store(c_zero_v4, moe_buf_rsrc, z_elem)

        # =================== SORTING (block 0 only) ==========================
        is_sort_block = bid == c_zero_i32
        _if_sort = scf.IfOp(is_sort_block.ir_value())
        with _if_then(_if_sort):
            # ========================= PHASE 1: Histogram =========================
            # Clear mesh region — unconditional store to safe index when out of bounds
            for i_clear in range_constexpr(0, sub_tokens * smem_cols, DECODE_BLOCK):
                idx = fx.Int32(i_clear) + tid
                is_valid = idx < fx.Int32(sub_tokens * smem_cols)
                safe_idx = is_valid.select(idx, c_zero_i32)
                safe_idx_ix = ArithValue(safe_idx).index_cast(T.index)
                # Always store; out-of-bounds threads harmlessly write to index 0
                _lds_store(mesh_mr, c_zero_i32, safe_idx_ix)
            gpu.barrier()

            # Fill mesh: for each (token, topk_slot), write topk_slot+1 to mesh[token, expert_id].
            # The loop and the validity predicate are bounded by fill_tokens
            # (= min(max_tokens, sub_tokens)) so a token row that was never
            # allocated in LDS is never addressed.
            total_assignments = tokens * c_topk
            c_fill_limit = fx.Int32(fill_tokens * topk)
            for i_assign in range_constexpr(0, fill_tokens * topk, DECODE_BLOCK):
                flat_idx = fx.Int32(i_assign) + tid
                is_valid = (flat_idx < total_assignments) & (flat_idx < c_fill_limit)
                safe_flat = is_valid.select(flat_idx, c_zero_i32)

                token_id = safe_flat // c_topk
                topk_slot = safe_flat % c_topk

                global_idx = token_id * c_topk + topk_slot
                eid = buffer_ops.buffer_load(topk_ids_rsrc, global_idx, vec_width=1, dtype=T.i32)

                # mesh[token_id, eid] = topk_slot + 1 (valid threads only).
                # Invalid threads must NOT write to mesh[0] — that would race
                # with a valid write to (token=0, expert=0).
                mesh_addr = token_id * c_smem_cols + eid
                last_mesh_idx = fx.Int32(sub_tokens * smem_cols - 1)
                safe_mesh_addr = is_valid.select(mesh_addr, last_mesh_idx)
                safe_mesh_ix = ArithValue(safe_mesh_addr).index_cast(T.index)
                val = is_valid.select(topk_slot + c_one_i32, c_zero_i32)
                _lds_store(mesh_mr, val, safe_mesh_ix)
            gpu.barrier()

            # ===================== PHASE 2: Count + Prefix Sum =====================
            # Threads are grouped in 8s: one lane-group per expert, each lane
            # covering every 8th token row.
            c_lane_group_sz = fx.Int32(8)
            lane_group_id = tid // c_lane_group_sz
            lane_group_os = tid % c_lane_group_sz
            width8_i32 = fx.Int32(8)

            is_t0 = tid == c_zero_i32

            # Initialize cumsum[0] = 0. All threads write 0 so there's no
            # read-modify-write race across waves.
            _lds_store(cumsum_mr, c_zero_i32, c_zero_i32)
            gpu.barrier()

            for i_e in range_constexpr(0, E, DECODE_BLOCK // 8):
                eid_local = fx.Int32(i_e) + lane_group_id
                eid_valid = eid_local < c_E

                cnt = c_zero_i32
                for i_sub in range_constexpr(0, sub_tokens, 8):
                    sub_idx = fx.Int32(i_sub) + lane_group_os
                    sub_valid = sub_idx < c_sub_tokens
                    combined_valid = eid_valid & sub_valid

                    safe_sub = combined_valid.select(sub_idx, c_zero_i32)
                    safe_eid = combined_valid.select(eid_local, c_zero_i32)
                    mesh_rd_addr = safe_sub * c_smem_cols + safe_eid
                    mesh_rd_ix = ArithValue(mesh_rd_addr).index_cast(T.index)
                    mesh_val = _lds_load(mesh_mr, mesh_rd_ix)

                    has_token = combined_valid.select(
                        (mesh_val != c_zero_i32).select(c_one_i32, c_zero_i32),
                        c_zero_i32,
                    )

                    # Reduce within lane-group of 8
                    reduced = has_token
                    for sh in range_constexpr(3):
                        off = fx.Int32(1 << sh)
                        peer = reduced.shuffle_xor(off, width8_i32)
                        reduced = reduced + peer
                    cnt = cnt + reduced

                # Only lane 0 of each valid lane-group writes the count to cumsum[eid+1].
                # Invalid threads: write_valid is false, cs_idx = 0, and we write 0 to
                # cumsum[0] which is harmless (cumsum[0] is always 0).
                write_valid = eid_valid & (lane_group_os == c_zero_i32)
                cs_idx = write_valid.select(eid_local + c_one_i32, c_zero_i32)
                cs_ix = ArithValue(cs_idx).index_cast(T.index)
                cs_val = write_valid.select(cnt, c_zero_i32)
                _lds_store(cumsum_mr, cs_val, cs_ix)
            gpu.barrier()

            # Phase 2b: Prefix sum over expert counts.
            # Step 1: Each thread converts its expert's raw count → padded block size.
            for i_cvt in range_constexpr(0, E, DECODE_BLOCK):
                cvt_eid = fx.Int32(i_cvt) + tid
                cvt_valid = cvt_eid < c_E
                # Safe index: valid → cumsum[eid+1], invalid → cumsum[0] (write 0, harmless)
                safe_cvt_idx = cvt_valid.select(cvt_eid + c_one_i32, c_zero_i32)
                cvt_ix = ArithValue(safe_cvt_idx).index_cast(T.index)
                raw_cnt_cvt = _lds_load(cumsum_mr, cvt_ix)
                blocks_cvt = (raw_cnt_cvt + c_unit - c_one_i32) // c_unit
                padded_cvt = (raw_cnt_cvt == c_zero_i32).select(c_zero_i32, blocks_cvt * c_unit)
                # Valid threads write padded value; invalid threads write 0 to cumsum[0]
                _lds_store(cumsum_mr, cvt_valid.select(padded_cvt, c_zero_i32), cvt_ix)
            gpu.barrier()

            if has_mask:
                # EP: zero padded count for masked experts in a separate pass.
                # Loading from mask buffer inside the padded-count loop above interfered
                # with expert 0 (MLIR codegen issue). Separate pass avoids this.
                for i_ep in range_constexpr(0, E, DECODE_BLOCK):
                    ep_eid = fx.Int32(i_ep) + tid
                    ep_valid = ep_eid < c_E
                    ep_safe_eid = ep_valid.select(ep_eid, c_zero_i32)
                    ep_m = buffer_ops.buffer_load(mask_rsrc, ep_safe_eid, vec_width=1, dtype=T.i32)
                    should_zero = ep_valid & (ep_m == c_zero_i32)
                    ep_cs_ix = ArithValue(ep_valid.select(ep_eid + c_one_i32, c_zero_i32)).index_cast(T.index)
                    _lds_store(cumsum_mr, should_zero.select(c_zero_i32, _lds_load(cumsum_mr, ep_cs_ix)), ep_cs_ix)
                gpu.barrier()

            # Step 2: All-wave parallel prefix sum (cumsum → cumdup).
            scratch_mr = SmemPtr(base_ptr, scratch_offset, T.i32, shape=(NUM_WAVES,)).get()

            # All threads read cumsum[tid+1] (in chunks for E > DECODE_BLOCK)
            for _ps_chunk in range_constexpr(0, E, DECODE_BLOCK):
                ps_eid = fx.Int32(_ps_chunk) + tid
                ps_valid = ps_eid < c_E
                ps_safe_ix = ArithValue(ps_valid.select(ps_eid + c_one_i32, c_zero_i32)).index_cast(T.index)
                ps_val = ps_valid.select(_lds_load(cumsum_mr, ps_safe_ix), c_zero_i32)
                _lds_store(cumdup_mr, ps_val, ps_safe_ix)
            _lds_store(cumdup_mr, is_t0.select(c_zero_i32, _lds_load(cumdup_mr, c_zero_i32)), c_zero_i32)
            gpu.barrier()

            # DPP prefix sum — all NUM_WAVES waves active.
            # The load index is clamped to 0 for tid >= E: cumdup holds only
            # E+1 entries, so an unclamped cumdup[tid+1] would read past it.
            ps_tid_valid = tid < c_E
            ps_rd_ix = ArithValue(ps_tid_valid.select(tid + c_one_i32, c_zero_i32)).index_cast(T.index)
            val = ps_tid_valid.select(_lds_load(cumdup_mr, ps_rd_ix), c_zero_i32)
            val = _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE)

            # Cross-wave accumulation via scratch LDS: the last lane of each
            # wave publishes its wave total, then every wave adds the totals
            # of all lower-numbered waves.
            is_last_lane_ps = lane == fx.Int32(WARP_SIZE - 1)
            _if_ll_ps = scf.IfOp(is_last_lane_ps.ir_value())
            with _if_then(_if_ll_ps):
                _lds_store(scratch_mr, val, ArithValue(wave).index_cast(T.index))
            gpu.barrier()
            cross_ps = c_zero_i32
            for _w_ps in range_constexpr(NUM_WAVES - 1):
                w_ps_total = _lds_load(scratch_mr, ArithValue(fx.Int32(_w_ps)).index_cast(T.index))
                cross_ps = (wave > fx.Int32(_w_ps)).select(cross_ps + w_ps_total, cross_ps)

            inclusive_ps = val + cross_ps
            # Invalid threads write 0 to cumdup[0], which is re-zeroed below anyway.
            _lds_store(cumdup_mr, ps_tid_valid.select(inclusive_ps, c_zero_i32), ps_rd_ix)
            gpu.barrier()

            # For E > DECODE_BLOCK: thread 0 serially extends
            if E > DECODE_BLOCK:
                _if_t0_ext = scf.IfOp(is_t0.ir_value())
                with _if_then(_if_t0_ext):
                    prev_ext = _lds_load(cumdup_mr, ArithValue(fx.Int32(DECODE_BLOCK)).index_cast(T.index))
                    for _ext in range_constexpr(DECODE_BLOCK, E):
                        cur_ext = _lds_load(cumdup_mr, ArithValue(fx.Int32(_ext + 1)).index_cast(T.index))
                        new_ext = prev_ext + cur_ext
                        _lds_store(cumdup_mr, new_ext, ArithValue(fx.Int32(_ext + 1)).index_cast(T.index))
                        prev_ext = new_ext
                gpu.barrier()

            # cumdup[0] = 0
            _lds_store(cumdup_mr, is_t0.select(c_zero_i32, _lds_load(cumdup_mr, c_zero_i32)), c_zero_i32)
            gpu.barrier()

            # Write num_valid_ids from cumdup[E]:
            #   num_valid_ids[0] = total padded slot count, num_valid_ids[1] = token count.
            cs_E_ix_ps = ArithValue(c_E).index_cast(T.index)
            total_padded = _lds_load(cumdup_mr, cs_E_ix_ps)
            buffer_ops.buffer_store(total_padded, nvalid_rsrc, c_zero_i32)
            buffer_ops.buffer_store(tokens, nvalid_rsrc, c_one_i32)
            gpu.barrier()

            # Copy cumdup → cumsum (all threads, one expert per thread)
            for i_cp in range_constexpr(0, E + 1, DECODE_BLOCK):
                cp_idx = fx.Int32(i_cp) + tid
                cp_valid = cp_idx <= c_E
                safe_cp_idx = cp_valid.select(cp_idx, c_zero_i32)
                cp_ix = ArithValue(safe_cp_idx).index_cast(T.index)
                cp_val = _lds_load(cumdup_mr, cp_ix)
                _lds_store(cumsum_mr, cp_val, cp_ix)
            gpu.barrier()

            if has_mask:
                # EP: Compute mask cumsum in cumdup for local expert index mapping.
                # cumdup[eid] = exclusive prefix sum of mask[0..eid-1] = local expert index.
                for i_ml in range_constexpr(0, E, DECODE_BLOCK):
                    ml_eid = fx.Int32(i_ml) + tid
                    ml_valid = ml_eid < c_E
                    safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32)
                    ml_mask = buffer_ops.buffer_load(mask_rsrc, safe_ml_eid, vec_width=1, dtype=T.i32)
                    ml_val = ml_valid.select(ml_mask, c_zero_i32)
                    ml_ix = ArithValue(ml_valid.select(ml_eid + c_one_i32, c_zero_i32)).index_cast(T.index)
                    _lds_store(cumdup_mr, ml_val, ml_ix)
                _lds_store(cumdup_mr, is_t0.select(c_zero_i32, _lds_load(cumdup_mr, c_zero_i32)), c_zero_i32)
                gpu.barrier()

                # All-wave DPP prefix sum over mask values in cumdup
                # (same clamped-index load as the count prefix sum above).
                m_tid_valid = tid < c_E
                m_rd_ix = ArithValue(m_tid_valid.select(tid + c_one_i32, c_zero_i32)).index_cast(T.index)
                mval = m_tid_valid.select(_lds_load(cumdup_mr, m_rd_ix), c_zero_i32)
                mval = _dpp_intra_wave_prefix_sum(mval, lane, WARP_SIZE)

                is_last_lane_m = lane == fx.Int32(WARP_SIZE - 1)
                _if_ll_m = scf.IfOp(is_last_lane_m.ir_value())
                with _if_then(_if_ll_m):
                    _lds_store(scratch_mr, mval, ArithValue(wave).index_cast(T.index))
                gpu.barrier()
                cross_m = c_zero_i32
                for _wm in range_constexpr(NUM_WAVES - 1):
                    wm_total = _lds_load(scratch_mr, ArithValue(fx.Int32(_wm)).index_cast(T.index))
                    cross_m = (wave > fx.Int32(_wm)).select(cross_m + wm_total, cross_m)
                inclusive_m = mval + cross_m
                _lds_store(cumdup_mr, m_tid_valid.select(inclusive_m, c_zero_i32), m_rd_ix)
                gpu.barrier()

                if E > DECODE_BLOCK:
                    _if_t0_ext_m = scf.IfOp(is_t0.ir_value())
                    with _if_then(_if_t0_ext_m):
                        prev_m = _lds_load(cumdup_mr, ArithValue(fx.Int32(DECODE_BLOCK)).index_cast(T.index))
                        for _ext_m in range_constexpr(DECODE_BLOCK, E):
                            cur_m = _lds_load(cumdup_mr, ArithValue(fx.Int32(_ext_m + 1)).index_cast(T.index))
                            new_m = prev_m + cur_m
                            _lds_store(cumdup_mr, new_m, ArithValue(fx.Int32(_ext_m + 1)).index_cast(T.index))
                            prev_m = new_m
                    gpu.barrier()

                _lds_store(cumdup_mr, is_t0.select(c_zero_i32, _lds_load(cumdup_mr, c_zero_i32)), c_zero_i32)
                gpu.barrier()
            else:
                # No mask: cumdup[eid] = eid (identity mapping)
                for i_ml in range_constexpr(0, E, DECODE_BLOCK):
                    ml_eid = fx.Int32(i_ml) + tid
                    ml_valid = ml_eid < c_E
                    safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32)
                    ml_ix = ArithValue(safe_ml_eid).index_cast(T.index)
                    _lds_store(cumdup_mr, ml_valid.select(safe_ml_eid, c_zero_i32), ml_ix)
                gpu.barrier()

            # Write sorted_expert_ids — predicated stores to buffer (safe: buffer_store ignores OOB)
            # EP: use cumdup[eid] as local expert index instead of global eid
            for i_eid in range_constexpr(0, E, DECODE_BLOCK):
                eid_wr = fx.Int32(i_eid) + tid
                eid_wr_valid = eid_wr < c_E
                safe_eid_wr = eid_wr_valid.select(eid_wr, c_zero_i32)

                cs_start_ix = ArithValue(safe_eid_wr).index_cast(T.index)
                cs_end_ix = ArithValue(safe_eid_wr + c_one_i32).index_cast(T.index)
                e_start = _lds_load(cumsum_mr, cs_start_ix)
                e_end = eid_wr_valid.select(_lds_load(cumsum_mr, cs_end_ix), e_start)
                local_eid = _lds_load(cumdup_mr, cs_start_ix)

                # Store cumdup: reuse cumdup for scatter phase position tracking.
                # Write e_start to cumdup[eid] (overwriting mask cumsum, no longer needed).
                _lds_store(cumdup_mr, e_start, cs_start_ix)

                # One sorted_expert_ids entry per unit_size-wide block of this expert.
                blk_start = e_start // c_unit
                blk_end = e_end // c_unit
                n_blks_wr = eid_wr_valid.select(blk_end - blk_start, c_zero_i32)
                for _jb in range(fx.Index(0), ArithValue(n_blks_wr).index_cast(T.index), fx.Index(1)):
                    blk_idx = blk_start + fx.Int32(_jb)
                    buffer_ops.buffer_store(local_eid, sorted_e_rsrc, blk_idx)
            gpu.barrier()

            # Store cumdup[E] = cumsum[E].
            # All threads write cumE to cumdup[E] (all write the same value, no race).
            cs_E_ix = ArithValue(c_E).index_cast(T.index)
            cumE = _lds_load(cumsum_mr, cs_E_ix)
            _lds_store(cumdup_mr, cumE, cs_E_ix)
            gpu.barrier()

            # ====================== PRE-FILL: Sentinel fill (cooperative) ===========
            # Fill ALL total_padded slots with sentinels BEFORE scatter.
            # Scatter will overwrite real token positions.
            # Sentinel = (topk << 24) | tokens, i.e. the padding id of the packed format.
            sentinel_val_pre = c_sentinel | tokens
            c_zero_pre = c_zero_i32
            cs_E_ix_pre = ArithValue(c_E).index_cast(T.index)
            total_padded_pre = _lds_load(cumdup_mr, cs_E_ix_pre)
            n_pre_iters = (total_padded_pre + fx.Int32(DECODE_BLOCK) - c_one_i32) // fx.Int32(DECODE_BLOCK)
            for _pre in range(fx.Index(0), ArithValue(n_pre_iters).index_cast(T.index), fx.Index(1)):
                pre_slot = fx.Int32(_pre) * fx.Int32(DECODE_BLOCK) + tid
                pre_valid = pre_slot < total_padded_pre
                safe_pre = pre_valid.select(pre_slot, c_oob_idx)
                buffer_ops.buffer_store(sentinel_val_pre, sorted_ids_rsrc, safe_pre)
                buffer_ops.buffer_store(c_zero_pre, sorted_w_rsrc, safe_pre)
            gpu.barrier()

            # ====================== PHASE 3: Scatter ==============================
            # One 8-lane group per expert. For each batch of 8 token rows the
            # group runs a 3-step DPP row_shr scan (shifts 1, 2, 4) to get each
            # lane's offset, then a ds_bpermute from the group's last lane to
            # broadcast the batch total.
            for i_e2 in range_constexpr(0, E, DECODE_BLOCK // 8):
                eid_sc = fx.Int32(i_e2) + lane_group_id
                eid_sc_valid = eid_sc < c_E
                # Invalid lane groups map to cumsum[E] (the total count) instead of
                # cumsum[0] to avoid racing with lane_group 0's position write-back.
                safe_eid_sc = eid_sc_valid.select(eid_sc, c_E)

                sc_expert_enabled = eid_sc_valid
                if has_mask:
                    # EP: check if this expert is masked (skip scatter for masked experts)
                    sc_mask_val = buffer_ops.buffer_load(
                        mask_rsrc, eid_sc_valid.select(eid_sc, c_zero_i32), vec_width=1, dtype=T.i32
                    )
                    sc_expert_enabled = eid_sc_valid & (sc_mask_val != c_zero_i32)

                cs_sc_ix = ArithValue(safe_eid_sc).index_cast(T.index)
                position = _lds_load(cumsum_mr, cs_sc_ix)

                for i_sub2 in range_constexpr(0, sub_tokens, 8):
                    # This lane handles sub_token (i_sub2 + lane_group_os).
                    my_sub = fx.Int32(i_sub2) + lane_group_os
                    my_sub_valid = sc_expert_enabled & (my_sub < c_sub_tokens)
                    safe_my_sub = my_sub_valid.select(my_sub, c_zero_i32)
                    my_mesh_addr = safe_my_sub * c_smem_cols + safe_eid_sc
                    my_mesh_ix = ArithValue(my_mesh_addr).index_cast(T.index)
                    my_x = _lds_load(mesh_mr, my_mesh_ix)
                    my_has_token = my_sub_valid & (my_x != c_zero_i32)
                    local_cnt = my_has_token.select(c_one_i32, c_zero_i32)

                    # 8-lane group prefix sum (NOT full-wave — uses lane_group_os,
                    # only shifts 1,2,4, no cross-row bpermute needed).
                    cnt_raw = _unwrap(local_cnt)
                    zero_raw = _unwrap(c_zero_i32)

                    # row_shr:1
                    remote = fly_rocdl.update_dpp(
                        T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_1, DPP_ROW_MASK, DPP_BANK_MASK, True
                    )
                    should_add = lane_group_os >= c_one_i32
                    local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt)

                    # row_shr:2
                    cnt_raw = _unwrap(local_cnt)
                    remote = fly_rocdl.update_dpp(
                        T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_2, DPP_ROW_MASK, DPP_BANK_MASK, True
                    )
                    should_add = lane_group_os >= fx.Int32(2)
                    local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt)

                    # row_shr:4
                    cnt_raw = _unwrap(local_cnt)
                    remote = fly_rocdl.update_dpp(
                        T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_4, DPP_ROW_MASK, DPP_BANK_MASK, True
                    )
                    should_add = lane_group_os >= fx.Int32(4)
                    local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt)

                    # Broadcast batch total from last lane of group via ds_bpermute
                    last_lane_of_group = tid | fx.Int32(7)  # tid with lower 3 bits set
                    last_addr = last_lane_of_group * c4_i32
                    batch_total = fly_rocdl.ds_bpermute(T.i32, last_addr, local_cnt)
                    batch_total = fx.Int32(batch_total)

                    # Scatter this lane's token.
                    # Packed id = (topk_slot << 24) | token_row; lanes without a
                    # token aim at c_oob_idx so their stores are dropped.
                    slot = position + local_cnt - c_one_i32
                    safe_x = my_has_token.select(my_x, c_one_i32)
                    topk_slot_sc = safe_x - c_one_i32
                    packed_id = (topk_slot_sc << fx.Int32(24)) | my_sub
                    safe_slot = my_has_token.select(slot, c_oob_idx)
                    buffer_ops.buffer_store(packed_id, sorted_ids_rsrc, safe_slot)

                    w_addr = my_has_token.select(my_sub * c_topk + topk_slot_sc, c_zero_i32)
                    w_val_i32 = buffer_ops.buffer_load(weights_rsrc, w_addr, vec_width=1, dtype=T.i32)
                    buffer_ops.buffer_store(w_val_i32, sorted_w_rsrc, safe_slot)

                    # Advance position by batch total
                    position = position + batch_total

                # Write back updated position.
                # Invalid lane groups loaded cumsum[E] and never advanced it
                # (their lanes count no tokens), so they store the same value
                # back to cumsum[E] — harmless.
                _lds_store(cumsum_mr, position, cs_sc_ix)
            gpu.barrier()

            # Padding already filled by PRE-FILL phase above (before scatter).

    @flyc.jit
    def launch_moe_sorting_decode(
        topk_ids_tensor: fx.Tensor,
        topk_weights_tensor: fx.Tensor,
        sorted_token_ids: fx.Tensor,
        sorted_weights_out: fx.Tensor,
        sorted_expert_ids: fx.Tensor,
        num_valid_ids_out: fx.Tensor,
        moe_buf: fx.Tensor,
        expert_mask_tensor: fx.Tensor,
        i32_tokens: fx.Int32,
        i32_moe_buf_elems: fx.Int32,
        n_grid_blocks: fx.Int32,
        stream: fx.Stream = fx.Stream(None),
    ):
        # Reset the finalized flag so the allocator can be finalized again
        # on each trace of this launcher.
        allocator.finalized = False
        ctx = CompilationContext.get_current()
        with ir.InsertionPoint(ctx.gpu_module_body):
            allocator.finalize()

        launcher = moe_sorting_decode_kernel(
            topk_ids_tensor,
            topk_weights_tensor,
            sorted_token_ids,
            sorted_weights_out,
            sorted_expert_ids,
            num_valid_ids_out,
            moe_buf,
            expert_mask_tensor,
            i32_tokens,
            i32_moe_buf_elems,
        )
        # Block 0 sorts; blocks 1..n_grid_blocks-1 zero moe_buf.
        launcher.launch(
            grid=(n_grid_blocks, 1, 1),
            block=(DECODE_BLOCK, 1, 1),
            stream=stream,
        )

    return launch_moe_sorting_decode
def _lds_load_raw(raw_mr, idx):
    """Read one i32 element from the LDS memref ``raw_mr`` at ``idx``.

    ``idx`` may be a DSL wrapper or a raw MLIR value, of either i32 or index
    type; anything that is not already index-typed is cast with
    ``index_cast`` first.  The loaded element is returned as ``fx.Int32``.
    """
    index_val = idx.ir_value() if hasattr(idx, "ir_value") else idx
    if not isinstance(index_val.type, ir.IndexType):
        casted = ArithValue(idx).index_cast(T.index)
        index_val = casted.ir_value() if hasattr(casted, "ir_value") else casted
    loaded = memref_ops.load(raw_mr, [index_val])
    return fx.Int32(loaded)


def _lds_store_raw(raw_mr, val, idx):
    """Write the i32 ``val`` into the LDS memref ``raw_mr`` at ``idx``.

    ``val`` and ``idx`` may each be a DSL wrapper or a raw MLIR value; an
    index that is not already index-typed is cast with ``index_cast``.
    """
    # Unwrap the value before the index, keeping the original emission order.
    payload = val.ir_value() if hasattr(val, "ir_value") else val
    index_val = idx.ir_value() if hasattr(idx, "ir_value") else idx
    if not isinstance(index_val.type, ir.IndexType):
        casted = ArithValue(idx).index_cast(T.index)
        index_val = casted.ir_value() if hasattr(casted, "ir_value") else casted
    memref_ops.store(payload, raw_mr, [index_val])


def _unwrap_raw(v):
    """Return the raw MLIR ``ir.Value`` behind ``v`` (``v`` itself if not wrapped)."""
    if hasattr(v, "ir_value"):
        return v.ir_value()
    return v
@functools.lru_cache(maxsize=256)
def compile_moe_sorting_prefill(
    *,
    num_experts: int,
    topk: int,
    unit_size: int = UNIT_SIZE,
    has_mask: bool = False,
):
    """Compile the prefill-path MoE sorting kernels.

    For token counts exceeding LDS capacity, uses HBM workspace:
      K1: ClearWorkspace — zero the workspace buffer
      K2: P0 scatter — scatter topk_ids into expert mesh in HBM
      K3: P1 count — one block per expert, count non-zero mesh cells
      K4: P23 prefix-sum + scatter — prefix-sum on counts, scatter tokens,
          fill sorted_expert_ids, zero moe_buf
      P0_v2: Fused clear+scatter+count — replaces K1+K2+K3 for small T

    Workspace layout (i32 elements):
      [0 .. ws_mesh_i32)               : uint8 expert mesh (E rows x mesh_stride bytes, packed into i32)
      [ws_mesh_i32 .. ws_mesh_i32 + E+1): expert_cumsum (E+1 i32 entries)

    The result is memoised by ``functools.lru_cache`` on the keyword
    arguments, so every argument must be hashable.

    Parameters
    ----------
    num_experts : int
        Number of routed experts (e.g. 256 for DeepSeek R1).
    topk : int
        Experts per token (e.g. 8).
    unit_size : int
        GEMM tile-M for padding alignment (default 32).
    has_mask : bool
        When True the kernels read ``expert_mask_tensor`` and treat experts
        whose mask entry is zero as disabled (count forced to 0, no scatter).

    Returns
    -------
    tuple
        ``(launch_clear_ws, launch_p0, launch_p1, launch_p23, launch_p0v2,
        launch_p0v2_p23, launch_4k_fused)`` — the ``@flyc.jit`` launchers
        defined below, in that order.
    """
    arch = get_hip_arch()
    E = num_experts

    # --- K1: ClearWorkspace kernel -------------------------------------------
    # CK uses grid=262144, block=1024 (1 store per thread, no loop).
    # Match that: block=1024, grid=ceil(ws_total/1024).
    K1_BLOCK = 1024

    @flyc.kernel(known_block_size=[K1_BLOCK, 1, 1])
    def clear_workspace_kernel(
        workspace: fx.Tensor,
        i32_total_elems: fx.Int32,
    ):
        gid = gpu.block_idx.x * fx.Int32(K1_BLOCK) + gpu.thread_idx.x
        ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True)
        c_zero = fx.Int32(0)

        # Each thread stores exactly one element (no loop needed).
        # Out-of-range threads are redirected to element 0, which is being
        # zeroed anyway, so the duplicate store is harmless.
        valid = gid < i32_total_elems
        buffer_ops.buffer_store(c_zero, ws_rsrc, valid.select(gid, c_zero))

    @flyc.jit
    def launch_clear_ws(
        workspace: fx.Tensor,
        i32_total_elems: fx.Int32,
        n_grid: fx.Int32,
        stream: fx.Stream = fx.Stream(None),
    ):
        launcher = clear_workspace_kernel(workspace, i32_total_elems)
        launcher.launch(grid=(n_grid, 1, 1), block=(K1_BLOCK, 1, 1), stream=stream)

    # --- K2: P0 scatter kernel -----------------------------------------------
    # uint8 mesh: stores topk_slot+1 (max 9) as a single byte directly.
    # mesh_stride is in bytes; byte_offset = eid * mesh_stride + token_id.
    # No two threads write the same byte (unique experts per token).
    K2_BLOCK = 256

    @flyc.kernel
    def p0_scatter_kernel(
        topk_ids: fx.Tensor,
        workspace: fx.Tensor,
        i32_tokens: fx.Int32,
        i32_mesh_stride: fx.Int32,
        i32_niters: fx.Int32,
    ):
        gid = gpu.block_idx.x * fx.Int32(K2_BLOCK) + gpu.thread_idx.x
        stride = gpu.grid_dim.x * fx.Int32(K2_BLOCK)
        topk_rsrc = buffer_ops.create_buffer_resource(topk_ids, max_size=True)
        ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True)
        c_zero = fx.Int32(0)
        c_topk = fx.Int32(topk)
        c_one = fx.Int32(1)

        total = i32_tokens * c_topk

        _s = fx.Index(0)
        _e = ArithValue(i32_niters).index_cast(T.index)
        _one = fx.Index(1)
        for _i in range(_s, _e, _one):
            flat = gid + fx.Int32(_i) * stride
            valid = flat < total
            # Invalid threads read element 0 so the load stays in bounds; the
            # store below is skipped for them.
            safe_flat = valid.select(flat, c_zero)
            token_id = safe_flat // c_topk
            topk_slot = safe_flat % c_topk
            eid = buffer_ops.buffer_load(topk_rsrc, safe_flat, vec_width=1, dtype=T.i32)
            byte_offset = eid * i32_mesh_stride + token_id
            val_i8 = ArithValue(topk_slot + c_one).trunci(T.i8)
            _if_valid_k2 = scf.IfOp(valid.ir_value())
            with _if_then(_if_valid_k2):
                buffer_ops.buffer_store(val_i8, ws_rsrc, byte_offset, offset_is_bytes=True)

    @flyc.jit
    def launch_p0(
        topk_ids: fx.Tensor,
        workspace: fx.Tensor,
        i32_tokens: fx.Int32,
        i32_mesh_stride: fx.Int32,
        i32_niters: fx.Int32,
        n_grid: fx.Int32,
        stream: fx.Stream = fx.Stream(None),
    ):
        launcher = p0_scatter_kernel(topk_ids, workspace, i32_tokens, i32_mesh_stride, i32_niters)
        launcher.launch(grid=(n_grid, 1, 1), block=(K2_BLOCK, 1, 1), stream=stream)

    # --- K3: P1 count kernel -------------------------------------------------
    # 256 threads (4 waves), vec_width=4: each thread loads 4 i32 words (16
    # mesh cells) per iteration. 4 waves provide 4x memory-level parallelism
    # vs the old 1-wave (64-thread) design, matching CK P1's block size.
    # Cross-warp reduction via LDS (4 partial sums, one per warp).
    K3_BLOCK = 256
    K3_NUM_WAVES = K3_BLOCK // WARP_SIZE
    K3_VEC_WIDTH = 4
    K3_WORDS_PER_ITER = K3_BLOCK * K3_VEC_WIDTH
    K3_WORDS_PER_ITER_LOG2 = (K3_WORDS_PER_ITER).bit_length() - 1

    k3_allocator = SmemAllocator(None, arch=arch, global_sym_name="smem_storage_p1")
    k3_reduce_offset = k3_allocator._align(k3_allocator.ptr, 16)
    k3_allocator.ptr = k3_reduce_offset + K3_NUM_WAVES * 4

    @flyc.kernel
    def p1_count_kernel(
        workspace: fx.Tensor,
        expert_mask_tensor: fx.Tensor,
        i32_mesh_stride: fx.Int32,
        i32_mesh_size: fx.Int32,
    ):
        # One block per expert (block index == expert id).
        eid = gpu.block_idx.x
        tid = gpu.thread_idx.x
        lane = tid % WARP_SIZE
        wave = tid // WARP_SIZE

        ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True)
        c_zero = fx.Int32(0)
        c_one = fx.Int32(1)
        c_ff = fx.Int32(0xFF)

        base_ptr = k3_allocator.get_base()
        reduce_mr = SmemPtr(base_ptr, k3_reduce_offset, T.i32, shape=(K3_NUM_WAVES,)).get()

        mesh_row_i32_base = (eid * i32_mesh_stride) >> fx.Int32(2)
        i32_words_per_row = i32_mesh_stride >> fx.Int32(2)
        n_iters = (i32_words_per_row + fx.Int32(K3_WORDS_PER_ITER - 1)) >> fx.Int32(K3_WORDS_PER_ITER_LOG2)

        if has_mask:
            # EP: a masked-out expert gets count 0 and skips the scan entirely.
            mask_rsrc = buffer_ops.create_buffer_resource(expert_mask_tensor, max_size=True)
            p1_mask = buffer_ops.buffer_load(mask_rsrc, eid, vec_width=1, dtype=T.i32)
            p1_is_local = p1_mask != c_zero
            p1_should_zero = (~p1_is_local) & (tid == c_zero)
            buffer_ops.buffer_store(c_zero, ws_rsrc, p1_should_zero.select(i32_mesh_size + eid, fx.Int32(0x7FFFFFFF)))
            n_iters = p1_is_local.select(n_iters, c_zero)

        for _i, state in range(fx.Index(0), ArithValue(n_iters).index_cast(T.index), fx.Index(1), init=[c_zero]):
            cnt_so_far = state[0]

            word_base = fx.Int32(_i) * fx.Int32(K3_WORDS_PER_ITER) + tid * fx.Int32(K3_VEC_WIDTH)
            valid = word_base < i32_words_per_row
            safe_addr = mesh_row_i32_base + valid.select(word_base, c_zero)
            vec4 = buffer_ops.buffer_load(ws_rsrc, safe_addr, vec_width=4, dtype=T.i32)

            iter_cnt = c_zero
            for _wi in range_constexpr(K3_VEC_WIDTH):
                word = Vec(vec4)[_wi]
                word_valid = valid & ((word_base + fx.Int32(_wi)) < i32_words_per_row)
                # Each i32 word packs four uint8 mesh cells; count the non-zero ones.
                b0 = word & c_ff
                b1 = (word >> fx.Int32(8)) & c_ff
                b2 = (word >> fx.Int32(16)) & c_ff
                b3 = (word >> fx.Int32(24)) & c_ff
                nz0 = word_valid.select((b0 != c_zero).select(c_one, c_zero), c_zero)
                nz1 = word_valid.select((b1 != c_zero).select(c_one, c_zero), c_zero)
                nz2 = word_valid.select((b2 != c_zero).select(c_one, c_zero), c_zero)
                nz3 = word_valid.select((b3 != c_zero).select(c_one, c_zero), c_zero)
                iter_cnt = iter_cnt + nz0 + nz1 + nz2 + nz3

            new_cnt = cnt_so_far + iter_cnt
            results = yield [new_cnt]
        cnt = results

        # Intra-warp reduce via shuffle_xor
        width_ws = fx.Int32(WARP_SIZE)
        for sh in range_constexpr(int.bit_length(WARP_SIZE) - 1):
            off = fx.Int32(1 << sh)
            peer = cnt.shuffle_xor(off, width_ws)
            cnt = cnt + peer

        # Cross-warp reduce via LDS: lane 0 of each warp writes partial sum
        is_lane0 = lane == c_zero
        _if_l0 = scf.IfOp(is_lane0.ir_value())
        with _if_then(_if_l0):
            wave_ix = ArithValue(wave).index_cast(T.index)
            _lds_store_raw(reduce_mr, cnt, wave_ix)
        gpu.barrier()

        # Thread 0 sums all warp partials and writes to HBM
        is_t0 = tid == c_zero
        total = c_zero
        for _w in range_constexpr(K3_NUM_WAVES):
            total = total + _lds_load_raw(reduce_mr, ArithValue(fx.Int32(_w)).index_cast(T.index))

        # Non-zero threads are pointed at an out-of-range index so their store is dropped.
        cs_offset = i32_mesh_size + eid
        c_oob_idx = fx.Int32(0x7FFFFFFF)
        safe_cs = is_t0.select(cs_offset, c_oob_idx)
        buffer_ops.buffer_store(total, ws_rsrc, safe_cs)

    @flyc.jit
    def launch_p1(
        workspace: fx.Tensor,
        expert_mask_tensor: fx.Tensor,
        i32_mesh_stride: fx.Int32,
        i32_mesh_size: fx.Int32,
        stream: fx.Stream = fx.Stream(None),
    ):
        k3_allocator.finalized = False
        ctx = CompilationContext.get_current()
        with ir.InsertionPoint(ctx.gpu_module_body):
            k3_allocator.finalize()

        launcher = p1_count_kernel(workspace, expert_mask_tensor, i32_mesh_stride, i32_mesh_size)
        launcher.launch(grid=(E, 1, 1), block=(K3_BLOCK, 1, 1), stream=stream)

    # --- P0_v2: Fused clear+scatter+count kernel (for small T) ---------------
    # Replaces K1+K2+K3 with a single kernel launch.
    # Grid: E blocks (one per expert), Block: 512 threads (matching CK P0_v2).
    # Phase 1: clear this expert's mesh row
    # Phase 2: scan all T*topk assignments, filter by expert, byte stores
    # Phase 3: popcount + warp reduce + cross-wave LDS reduce -> expert_cumsum
    P0V2_BLOCK = 512
    P0V2_NUM_WAVES = P0V2_BLOCK // WARP_SIZE
    # Derived from P0V2_BLOCK so the ceil-div shifts below cannot drift out of
    # sync with the block size (P0V2_BLOCK must stay a power of two).
    P0V2_BLOCK_LOG2 = P0V2_BLOCK.bit_length() - 1

    # Power-of-2 topk: use shift to avoid division
    _p0v2_topk_is_po2 = (topk & (topk - 1)) == 0 and topk > 0
    _p0v2_topk_log2 = topk.bit_length() - 1 if _p0v2_topk_is_po2 else 0

    # LDS for cross-wave reduction (same layout as K3)
    p0v2_allocator = SmemAllocator(None, arch=arch, global_sym_name="smem_storage_p0v2")
    p0v2_reduce_offset = p0v2_allocator._align(p0v2_allocator.ptr, 16)
    p0v2_allocator.ptr = p0v2_reduce_offset + P0V2_NUM_WAVES * 4

    @flyc.kernel(known_block_size=[P0V2_BLOCK, 1, 1])
    def p0v2_kernel(
        topk_ids: fx.Tensor,
        workspace: fx.Tensor,
        expert_mask_tensor: fx.Tensor,
        i32_tokens: fx.Int32,
        i32_mesh_stride: fx.Int32,
        i32_mesh_size: fx.Int32,
    ):
        eid = gpu.block_idx.x
        tid = gpu.thread_idx.x
        lane = tid % WARP_SIZE
        wave = tid // WARP_SIZE

        ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True)
        mask_rsrc = buffer_ops.create_buffer_resource(expert_mask_tensor, max_size=True)
        topk_rsrc = buffer_ops.create_buffer_resource(topk_ids, max_size=True)
        c_zero = fx.Int32(0)
        c_oob = fx.Int32(0x7FFFFFFF)
        c_one = fx.Int32(1)
        c_ff = fx.Int32(0xFF)
        c_topk = fx.Int32(topk)
        c_block = fx.Int32(P0V2_BLOCK)

        base_ptr = p0v2_allocator.get_base()
        reduce_mr = SmemPtr(base_ptr, p0v2_reduce_offset, T.i32, shape=(P0V2_NUM_WAVES,)).get()

        # Precompute mesh row base (in i32 words) and words per row
        mesh_row_i32_base = (eid * i32_mesh_stride) >> fx.Int32(2)
        i32_words_per_row = i32_mesh_stride >> fx.Int32(2)

        clear_niters = (i32_words_per_row + fx.Int32(P0V2_BLOCK - 1)) >> fx.Int32(P0V2_BLOCK_LOG2)
        total_assignments = i32_tokens * c_topk
        scatter_niters = (total_assignments + fx.Int32(P0V2_BLOCK - 1)) >> fx.Int32(P0V2_BLOCK_LOG2)

        # EP: load mask, write cumsum=0 for masked experts, set loop bounds to 0
        if has_mask:
            m_val = buffer_ops.buffer_load(mask_rsrc, eid, vec_width=1, dtype=T.i32)
            is_local_expert = m_val != c_zero
            should_write_zero = (~is_local_expert) & (tid == c_zero)
            buffer_ops.buffer_store(c_zero, ws_rsrc, should_write_zero.select(i32_mesh_size + eid, c_oob))
            clear_niters = is_local_expert.select(clear_niters, c_zero)
            scatter_niters = is_local_expert.select(scatter_niters, c_zero)

        # ---- Phase 1: Clear this expert's mesh row ----
        for _ci in range(fx.Index(0), ArithValue(clear_niters).index_cast(T.index), fx.Index(1)):
            word_idx = fx.Int32(_ci) * c_block + tid
            valid = word_idx < i32_words_per_row
            safe_idx = mesh_row_i32_base + valid.select(word_idx, c_zero)
            buffer_ops.buffer_store(c_zero, ws_rsrc, valid.select(safe_idx, c_oob))

        gpu.barrier()

        # ---- Phase 2: Scatter (scan all T*topk, filter by expert) ----
        for _si in range(fx.Index(0), ArithValue(scatter_niters).index_cast(T.index), fx.Index(1)):
            flat = fx.Int32(_si) * c_block + tid
            valid = flat < total_assignments
            safe_flat = valid.select(flat, c_zero)

            token_id = safe_flat >> fx.Int32(_p0v2_topk_log2) if _p0v2_topk_is_po2 else safe_flat // c_topk
            topk_slot = safe_flat & fx.Int32(topk - 1) if _p0v2_topk_is_po2 else safe_flat % c_topk

            expert_id = buffer_ops.buffer_load(topk_rsrc, safe_flat, vec_width=1, dtype=T.i32)

            is_mine = valid & (expert_id == eid)
            byte_offset = eid * i32_mesh_stride + token_id
            val_i8 = ArithValue(is_mine.select(topk_slot + c_one, c_zero)).trunci(T.i8)
            # Byte-mode buffer_store with OOB offset crashes on AMD GPUs.
            # Use conditional branch to skip the store for non-matching threads.
            _if_mine = scf.IfOp(is_mine.ir_value())
            with _if_then(_if_mine):
                buffer_ops.buffer_store(val_i8, ws_rsrc, byte_offset, offset_is_bytes=True)

        gpu.barrier()

        # ---- Phase 3: Count non-zero bytes + warp/cross-wave reduce ----
        count_niters = clear_niters  # same loop structure, reuse (already EP-gated)
        for _ki, state in range(fx.Index(0), ArithValue(count_niters).index_cast(T.index), fx.Index(1), init=[c_zero]):
            cnt_so_far = state[0]

            word_base = fx.Int32(_ki) * c_block + tid
            valid = word_base < i32_words_per_row
            safe_addr = mesh_row_i32_base + valid.select(word_base, c_zero)
            word = buffer_ops.buffer_load(ws_rsrc, safe_addr, vec_width=1, dtype=T.i32)

            b0 = word & c_ff
            b1 = (word >> fx.Int32(8)) & c_ff
            b2 = (word >> fx.Int32(16)) & c_ff
            b3 = (word >> fx.Int32(24)) & c_ff
            nz0 = valid.select((b0 != c_zero).select(c_one, c_zero), c_zero)
            nz1 = valid.select((b1 != c_zero).select(c_one, c_zero), c_zero)
            nz2 = valid.select((b2 != c_zero).select(c_one, c_zero), c_zero)
            nz3 = valid.select((b3 != c_zero).select(c_one, c_zero), c_zero)
            iter_cnt = nz0 + nz1 + nz2 + nz3

            new_cnt = cnt_so_far + iter_cnt
            results = yield [new_cnt]
        cnt = results

        # Intra-warp reduce via shuffle_xor
        width_ws = fx.Int32(WARP_SIZE)
        for sh in range_constexpr(int.bit_length(WARP_SIZE) - 1):
            off = fx.Int32(1 << sh)
            peer = cnt.shuffle_xor(off, width_ws)
            cnt = cnt + peer

        # Cross-warp reduce via LDS: lane 0 of each warp writes partial sum
        is_lane0 = lane == c_zero
        _if_l0 = scf.IfOp(is_lane0.ir_value())
        with _if_then(_if_l0):
            wave_ix = ArithValue(wave).index_cast(T.index)
            _lds_store_raw(reduce_mr, cnt, wave_ix)
        gpu.barrier()

        # Thread 0 sums all warp partials and writes to HBM
        is_t0 = tid == c_zero
        total = c_zero
        for _w in range_constexpr(P0V2_NUM_WAVES):
            total = total + _lds_load_raw(reduce_mr, ArithValue(fx.Int32(_w)).index_cast(T.index))

        cs_offset = i32_mesh_size + eid
        c_oob_idx = fx.Int32(0x7FFFFFFF)
        safe_cs = is_t0.select(cs_offset, c_oob_idx)
        buffer_ops.buffer_store(total, ws_rsrc, safe_cs)

    @flyc.jit
    def launch_p0v2(
        topk_ids: fx.Tensor,
        workspace: fx.Tensor,
        expert_mask_tensor: fx.Tensor,
        i32_tokens: fx.Int32,
        i32_mesh_stride: fx.Int32,
        i32_mesh_size: fx.Int32,
        stream: fx.Stream = fx.Stream(None),
    ):
        p0v2_allocator.finalized = False
        ctx = CompilationContext.get_current()
        with ir.InsertionPoint(ctx.gpu_module_body):
            p0v2_allocator.finalize()

        launcher = p0v2_kernel(topk_ids, workspace, expert_mask_tensor, i32_tokens, i32_mesh_stride, i32_mesh_size)
        launcher.launch(grid=(E, 1, 1), block=(P0V2_BLOCK, 1, 1), stream=stream)

    # --- K4: P23 prefix-sum + scatter + moe_buf zeroing ---------------------
    # Parallel design (matching CK P23): each block [0, E) independently
    # computes the SAME prefix sum, then scatters ONLY for expert blockIdx.x.
    # No inter-block barrier needed — redundant prefix sums are deterministic.
    K4_BLOCK = 256 if E <= 256 else 512

    # LDS: cumsum[...] for prefix sums + cross-wave scratch for DPP scan.
    # Step 1 of the kernel stores cumsum[expert_idx + 1] for EVERY thread of
    # EVERY K4_BLOCK-sized chunk (invalid experts store 0), so the buffer must
    # hold num_chunks * K4_BLOCK + 1 entries. Sizing it as max(E + 1,
    # K4_BLOCK + 1) is too small once E > K4_BLOCK (e.g. E = 513) and lets the
    # last chunk store past the end of the allocation.
    K4_NUM_WAVES = K4_BLOCK // WARP_SIZE
    k4_allocator = SmemAllocator(None, arch=arch)
    k4_num_chunks = max(1, (E + K4_BLOCK - 1) // K4_BLOCK)
    k4_smem_cols = k4_num_chunks * K4_BLOCK + 1
    k4_cumsum_offset = k4_allocator._align(k4_allocator.ptr, 16)
    k4_allocator.ptr = k4_cumsum_offset + k4_smem_cols * 4
    k4_scatter_offset = k4_allocator._align(k4_allocator.ptr, 16)
    k4_allocator.ptr = k4_scatter_offset + K4_NUM_WAVES * 4

    @flyc.kernel(known_block_size=[K4_BLOCK, 1, 1])
    def p23_kernel(
        workspace: fx.Tensor,
        topk_weights_tensor: fx.Tensor,
        sorted_token_ids: fx.Tensor,
        sorted_weights_out: fx.Tensor,
        sorted_expert_ids: fx.Tensor,
        num_valid_ids: fx.Tensor,
        moe_buf: fx.Tensor,
        expert_mask_tensor: fx.Tensor,
        i32_tokens: fx.Int32,
        i32_mesh_stride: fx.Int32,
        i32_mesh_size: fx.Int32,
        i32_moe_buf_elems: fx.Int32,
    ):
        bid = gpu.block_idx.x
        tid = gpu.thread_idx.x
        lane = tid % WARP_SIZE
        wave = tid // WARP_SIZE
        c_zero = fx.Int32(0)
        c_one = fx.Int32(1)
        c4 = fx.Int32(4)
        c_E = fx.Int32(E)
        c_unit = fx.Int32(unit_size)
        c_topk = fx.Int32(topk)
        c_sentinel = fx.Int32(topk << 24)
        c_oob_idx = fx.Int32(0x7FFFFFFF)
        c_ff = fx.Int32(0xFF)

        # Buffer resources
        ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True)
        weights_rsrc = buffer_ops.create_buffer_resource(topk_weights_tensor, max_size=True)
        sorted_ids_rsrc = buffer_ops.create_buffer_resource(sorted_token_ids, max_size=True)
        sorted_w_rsrc = buffer_ops.create_buffer_resource(sorted_weights_out, max_size=True)
        mask_rsrc = buffer_ops.create_buffer_resource(expert_mask_tensor, max_size=True)

        # LDS: cumsum[] for prefix sums + cross-wave scratch
        base_ptr = k4_allocator.get_base()
        cumsum_mr = SmemPtr(base_ptr, k4_cumsum_offset, T.i32, shape=(k4_smem_cols,)).get()
        scatter_mr = SmemPtr(base_ptr, k4_scatter_offset, T.i32, shape=(K4_NUM_WAVES,)).get()

        is_sort_block = bid < c_E
        is_zero_block = bid >= c_E

        # ================ MOE_BUF ZEROING (blocks >= E) ==================
        _if_zero = scf.IfOp(is_zero_block.ir_value())
        with _if_then(_if_zero):
            moe_buf_rsrc = buffer_ops.create_buffer_resource(moe_buf, max_size=True)
            zero_base_bid = bid - c_E
            zero_gid_v4 = zero_base_bid * fx.Int32(K4_BLOCK) + tid
            num_zero_blocks = gpu.grid_dim.x - c_E
            zero_stride_v4 = num_zero_blocks * fx.Int32(K4_BLOCK)
            i32_moe_buf_v4 = i32_moe_buf_elems >> fx.Int32(2)
            zero_niters = (i32_moe_buf_v4 + zero_stride_v4 - c_one) // zero_stride_v4
            c_zero_v4 = fx.Vector.filled(4, 0, fx.Int32)
            for _z in range(fx.Index(0), ArithValue(zero_niters).index_cast(T.index), fx.Index(1)):
                z_idx_v4 = zero_gid_v4 + fx.Int32(_z) * zero_stride_v4
                z_valid = z_idx_v4 < i32_moe_buf_v4
                z_elem = z_valid.select(z_idx_v4 * c4, c_oob_idx)
                buffer_ops.buffer_store(c_zero_v4, moe_buf_rsrc, z_elem)

            # Tail: the vec4 loop covers only (elems >> 2) * 4 elements. Zero
            # the remaining 0..3 i32 elements with scalar stores; zero_gid_v4
            # is unique per thread across the zeroing blocks, so at most three
            # threads pass the bound check and the rest hit the OOB index.
            tail_idx = i32_moe_buf_v4 * c4 + zero_gid_v4
            tail_valid = tail_idx < i32_moe_buf_elems
            buffer_ops.buffer_store(c_zero, moe_buf_rsrc, tail_valid.select(tail_idx, c_oob_idx))

        # ================ PARALLEL PREFIX-SUM + MESH SCATTER (blocks 0..E-1) ==
        # Each block independently: prefix sum (redundant), scatter for its expert only.
        _if_sort = scf.IfOp(is_sort_block.ir_value())
        with _if_then(_if_sort):
            my_expert = bid

            # Step 1: Load expert counts from workspace -> pad to unit_size -> LDS cumsum
            # Experts are processed in chunks of K4_BLOCK threads. K4_BLOCK is
            # 256 for E <= 256 and 512 otherwise, so a second chunk is only
            # unrolled when E > 512.
            is_t0_init = tid == c_zero
            _if_init_cs = scf.IfOp(is_t0_init.ir_value())
            with _if_then(_if_init_cs):
                _lds_store_raw(cumsum_mr, c_zero, ArithValue(c_zero).index_cast(T.index))
            for _chunk in range_constexpr(0, E, K4_BLOCK):
                expert_idx = fx.Int32(_chunk) + tid
                tid_valid_expert = expert_idx < c_E
                ws_cs_addr = i32_mesh_size + tid_valid_expert.select(expert_idx, c_zero)
                raw_cnt = buffer_ops.buffer_load(ws_rsrc, ws_cs_addr, vec_width=1, dtype=T.i32)
                raw_cnt = tid_valid_expert.select(raw_cnt, c_zero)
                blocks = (raw_cnt + c_unit - c_one) // c_unit
                padded = (raw_cnt == c_zero).select(c_zero, blocks * c_unit)
                p23_mask_val = c_one
                if has_mask:
                    p23_mask_val = buffer_ops.buffer_load(
                        mask_rsrc, tid_valid_expert.select(expert_idx, c_zero), vec_width=1, dtype=T.i32
                    )
                    p23_mask_val = tid_valid_expert.select(p23_mask_val, c_zero)
                    padded = (p23_mask_val == c_zero).select(c_zero, padded)
                # Stored for every thread (0 for non-existent experts) so the
                # scan in Step 2 never reads uninitialised LDS.
                _lds_store_raw(cumsum_mr, padded, ArithValue(expert_idx + c_one).index_cast(T.index))
            gpu.barrier()

            # Step 2: Prefix sum over cumsum LDS. When E <= K4_BLOCK a single
            # DPP pass covers all experts. When E > K4_BLOCK, we do the DPP
            # pass for the first K4_BLOCK elements, then serially accumulate
            # the remaining entries from thread 0.
            val = _lds_load_raw(cumsum_mr, ArithValue(tid + c_one).index_cast(T.index))
            val = _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE)

            is_last_lane_ps = lane == fx.Int32(WARP_SIZE - 1)
            _if_ll_ps = scf.IfOp(is_last_lane_ps.ir_value())
            with _if_then(_if_ll_ps):
                _lds_store_raw(scatter_mr, val, ArithValue(wave).index_cast(T.index))
            gpu.barrier()
            cross_offset = c_zero
            for _w in range_constexpr(K4_NUM_WAVES - 1):
                w_total = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index))
                cross_offset = (wave > fx.Int32(_w)).select(cross_offset + w_total, cross_offset)
            total_padded = c_zero
            for _w in range_constexpr(K4_NUM_WAVES):
                total_padded = total_padded + _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index))

            inclusive_prefix = val + cross_offset
            _lds_store_raw(cumsum_mr, inclusive_prefix, ArithValue(tid + c_one).index_cast(T.index))
            gpu.barrier()

            # For E > K4_BLOCK: thread 0 serially extends the prefix sum
            # for experts K4_BLOCK..E-1 (at most a few iterations).
            if E > K4_BLOCK:
                _if_t0_extra = scf.IfOp(is_t0_init.ir_value())
                with _if_then(_if_t0_extra):
                    prev_sum = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(K4_BLOCK)).index_cast(T.index))
                    for _extra in range_constexpr(K4_BLOCK, E):
                        cur = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(_extra + 1)).index_cast(T.index))
                        new_sum = prev_sum + cur
                        _lds_store_raw(cumsum_mr, new_sum, ArithValue(fx.Int32(_extra + 1)).index_cast(T.index))
                        prev_sum = new_sum
                gpu.barrier()
                # Every thread re-reads the grand total that thread 0 published.
                total_padded = _lds_load_raw(cumsum_mr, ArithValue(c_E).index_cast(T.index))

            # Read my_start and my_end from cumsum LDS
            my_start = _lds_load_raw(cumsum_mr, ArithValue(my_expert).index_cast(T.index))
            my_end = _lds_load_raw(cumsum_mr, ArithValue(my_expert + c_one).index_cast(T.index))

            local_idx_p23 = tid
            if has_mask:
                # EP: Compute mask cumsum for local expert index (register-only DPP scan).
                # NOTE(review): p23_mask_val is whatever the LAST unrolled
                # chunk of Step 1 left behind. With a single chunk
                # (E <= K4_BLOCK) that is the mask of expert `tid`; with more
                # chunks it is the mask of expert `last_chunk + tid`, and the
                # extra experts below get their global index. The EP local
                # index therefore looks wrong for E > K4_BLOCK — confirm
                # whether has_mask with E > 512 is a supported configuration.
                p23_mv = _dpp_intra_wave_prefix_sum(p23_mask_val, lane, WARP_SIZE)

                # Cross-wave via scatter_mr
                is_last_lane_pm = lane == fx.Int32(WARP_SIZE - 1)
                _if_ll_pm = scf.IfOp(is_last_lane_pm.ir_value())
                with _if_then(_if_ll_pm):
                    _lds_store_raw(scatter_mr, p23_mv, ArithValue(wave).index_cast(T.index))
                gpu.barrier()
                cross_pm = c_zero
                for _w in range_constexpr(K4_NUM_WAVES - 1):
                    wtp = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index))
                    cross_pm = (wave > fx.Int32(_w)).select(cross_pm + wtp, cross_pm)
                p23_mask_inclusive = p23_mv + cross_pm
                local_idx_p23 = p23_mask_inclusive - p23_mask_val
            else:
                local_idx_p23 = tid

            # Block 0, thread 0 writes num_valid_ids
            is_b0 = bid == c_zero
            is_t0 = tid == c_zero
            is_b0_t0 = is_b0 & is_t0
            _if_nv = scf.IfOp(is_b0_t0.ir_value())
            with _if_then(_if_nv):
                nvalid_rsrc = buffer_ops.create_buffer_resource(num_valid_ids, max_size=True)
                buffer_ops.buffer_store(total_padded, nvalid_rsrc, c_zero)
                buffer_ops.buffer_store(i32_tokens, nvalid_rsrc, c_one)

            # Step 3: Write sorted_expert_ids for THIS expert (using local_idx_p23 for EP)
            # Store local_idx to LDS cumsum[tid], barrier, read cumsum[my_expert].
            #
            # The stores below recycle the cumsum array. Without a barrier a
            # fast wave could overwrite cumsum[my_expert] / cumsum[my_expert+1]
            # before a slower wave has loaded my_start / my_end from it (the
            # non-mask path has no other barrier in between), so synchronise
            # first. The branch is block-uniform, so every thread reaches it.
            gpu.barrier()
            _lds_store_raw(cumsum_mr, local_idx_p23, ArithValue(tid).index_cast(T.index))
            # For E > K4_BLOCK: thread 0 writes local_idx for extra experts
            if E > K4_BLOCK:
                is_t0_extra3 = tid == c_zero
                _if_t0_e3 = scf.IfOp(is_t0_extra3.ir_value())
                with _if_then(_if_t0_e3):
                    for _e3 in range_constexpr(K4_BLOCK, E):
                        _lds_store_raw(cumsum_mr, fx.Int32(_e3), ArithValue(fx.Int32(_e3)).index_cast(T.index))
            gpu.barrier()
            my_local_idx = _lds_load_raw(cumsum_mr, ArithValue(my_expert).index_cast(T.index))

            sorted_e_rsrc = buffer_ops.create_buffer_resource(sorted_expert_ids, max_size=True)
            blk_start = my_start // c_unit
            blk_end = my_end // c_unit
            n_blks = blk_end - blk_start
            n_eid_iters = (n_blks + fx.Int32(K4_BLOCK) - c_one) // fx.Int32(K4_BLOCK)
            for _eii in range(fx.Index(0), ArithValue(n_eid_iters).index_cast(T.index), fx.Index(1)):
                blk_idx = blk_start + fx.Int32(_eii) * fx.Int32(K4_BLOCK) + tid
                buffer_ops.buffer_store(my_local_idx, sorted_e_rsrc, (blk_idx < blk_end).select(blk_idx, c_oob_idx))

            # Step 4: Mesh-based scatter — read uint8 mesh from HBM, extract tokens,
            # DPP prefix sum over counts, cross-wave LDS reduction, scatter stores.
            p23_bid_enabled = c_one != c_zero
            if has_mask:
                # EP: skip scatter for masked experts (my_start == my_end, but mesh has data)
                p23_bid_mask = buffer_ops.buffer_load(mask_rsrc, my_expert, vec_width=1, dtype=T.i32)
                p23_bid_enabled = p23_bid_mask != c_zero

            i32_words_per_row = i32_mesh_stride >> fx.Int32(2)
            n_mesh_iters_raw = (i32_words_per_row + fx.Int32(K4_BLOCK - 1)) // fx.Int32(K4_BLOCK)
            has_work = my_start != my_end
            n_mesh_iters = has_work.select(n_mesh_iters_raw, c_zero)
            mesh_row_i32_base = (my_expert * i32_mesh_stride) >> fx.Int32(2)

            for _si, state in range(
                fx.Index(0), ArithValue(n_mesh_iters).index_cast(T.index), fx.Index(1), init=[my_start]
            ):
                position = state[0]

                word_idx = fx.Int32(_si) * fx.Int32(K4_BLOCK) + tid
                col_valid = p23_bid_enabled & (word_idx < i32_words_per_row)
                safe_word_idx = col_valid.select(word_idx, c_zero)
                word = buffer_ops.buffer_load(ws_rsrc, mesh_row_i32_base + safe_word_idx, vec_width=1, dtype=T.i32)

                # Extract 4 bytes from the i32 word
                x0 = word & c_ff
                x1 = (word >> fx.Int32(8)) & c_ff
                x2 = (word >> fx.Int32(16)) & c_ff
                x3 = (word >> fx.Int32(24)) & c_ff
                base_col = word_idx * c4

                h0 = col_valid & (x0 != c_zero)
                h1 = col_valid & (x1 != c_zero)
                h2 = col_valid & (x2 != c_zero)
                h3 = col_valid & (x3 != c_zero)

                # Number of tokens this thread contributes (0..4).
                my_thread_count = (
                    h0.select(c_one, c_zero)
                    + h1.select(c_one, c_zero)
                    + h2.select(c_one, c_zero)
                    + h3.select(c_one, c_zero)
                )

                # Intra-wave inclusive prefix sum of per-thread token counts.
                my_cnt = _dpp_intra_wave_prefix_sum(my_thread_count, lane, WARP_SIZE)

                # Cross-wave reduction via LDS scratch.
                is_last_lane_sc = lane == fx.Int32(WARP_SIZE - 1)
                _if_ll_sc = scf.IfOp(is_last_lane_sc.ir_value())
                with _if_then(_if_ll_sc):
                    _lds_store_raw(scatter_mr, my_cnt, ArithValue(wave).index_cast(T.index))
                gpu.barrier()

                wave_offset = c_zero
                for _w in range_constexpr(K4_NUM_WAVES - 1):
                    w_total = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index))
                    wave_offset = (wave > fx.Int32(_w)).select(wave_offset + w_total, wave_offset)
                batch_total = c_zero
                for _w in range_constexpr(K4_NUM_WAVES):
                    batch_total = batch_total + _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index))
                # Second barrier: scatter_mr is rewritten by the next iteration.
                gpu.barrier()

                # Convert to exclusive prefix: my_exclusive = my_cnt - my_thread_count
                my_exclusive = my_cnt - my_thread_count + wave_offset

                # Scatter: compute all addresses, batch-load weights, then batch-store.
                scatter_base = position + my_exclusive

                # Compute packed IDs and output slots for all 4 tokens.
                # Packed ID = (topk_slot << 24) | token_id; the mesh byte holds topk_slot + 1.
                token_id_0 = base_col
                topk_slot_0 = h0.select(x0 - c_one, c_zero)
                pid_0 = (topk_slot_0 << fx.Int32(24)) | token_id_0
                safe_slot_0 = h0.select(scatter_base, c_oob_idx)

                off1 = scatter_base + h0.select(c_one, c_zero)
                token_id_1 = base_col + c_one
                topk_slot_1 = h1.select(x1 - c_one, c_zero)
                pid_1 = (topk_slot_1 << fx.Int32(24)) | token_id_1
                safe_slot_1 = h1.select(off1, c_oob_idx)

                off2 = off1 + h1.select(c_one, c_zero)
                token_id_2 = base_col + fx.Int32(2)
                topk_slot_2 = h2.select(x2 - c_one, c_zero)
                pid_2 = (topk_slot_2 << fx.Int32(24)) | token_id_2
                safe_slot_2 = h2.select(off2, c_oob_idx)

                off3 = off2 + h2.select(c_one, c_zero)
                token_id_3 = base_col + fx.Int32(3)
                topk_slot_3 = h3.select(x3 - c_one, c_zero)
                pid_3 = (topk_slot_3 << fx.Int32(24)) | token_id_3
                safe_slot_3 = h3.select(off3, c_oob_idx)

                # Batch-issue all 4 weight loads (increases load-use distance)
                w_addr_0 = h0.select(token_id_0 * c_topk + topk_slot_0, c_zero)
                w_addr_1 = h1.select(token_id_1 * c_topk + topk_slot_1, c_zero)
                w_addr_2 = h2.select(token_id_2 * c_topk + topk_slot_2, c_zero)
                w_addr_3 = h3.select(token_id_3 * c_topk + topk_slot_3, c_zero)
                w_val_0 = buffer_ops.buffer_load(weights_rsrc, w_addr_0, vec_width=1, dtype=T.i32)
                w_val_1 = buffer_ops.buffer_load(weights_rsrc, w_addr_1, vec_width=1, dtype=T.i32)
                w_val_2 = buffer_ops.buffer_load(weights_rsrc, w_addr_2, vec_width=1, dtype=T.i32)
                w_val_3 = buffer_ops.buffer_load(weights_rsrc, w_addr_3, vec_width=1, dtype=T.i32)

                # Batch-store: all packed IDs, then all weights
                buffer_ops.buffer_store(pid_0, sorted_ids_rsrc, safe_slot_0)
                buffer_ops.buffer_store(pid_1, sorted_ids_rsrc, safe_slot_1)
                buffer_ops.buffer_store(pid_2, sorted_ids_rsrc, safe_slot_2)
                buffer_ops.buffer_store(pid_3, sorted_ids_rsrc, safe_slot_3)
                buffer_ops.buffer_store(w_val_0, sorted_w_rsrc, safe_slot_0)
                buffer_ops.buffer_store(w_val_1, sorted_w_rsrc, safe_slot_1)
                buffer_ops.buffer_store(w_val_2, sorted_w_rsrc, safe_slot_2)
                buffer_ops.buffer_store(w_val_3, sorted_w_rsrc, safe_slot_3)

                pos_next = position + batch_total
                results = yield [pos_next]
            scatter_end_pos_t0 = results

            # Step 5: Fill padding with sentinel for THIS expert (parallel)
            sentinel_val = c_sentinel | i32_tokens
            pad_count = my_end - scatter_end_pos_t0
            pad_niters = (pad_count + fx.Int32(K4_BLOCK) - c_one) // fx.Int32(K4_BLOCK)
            for _pi in range(fx.Index(0), ArithValue(pad_niters).index_cast(T.index), fx.Index(1)):
                pad_slot = scatter_end_pos_t0 + fx.Int32(_pi) * fx.Int32(K4_BLOCK) + tid
                pad_valid = pad_slot < my_end
                buffer_ops.buffer_store(sentinel_val, sorted_ids_rsrc, pad_valid.select(pad_slot, c_oob_idx))
                buffer_ops.buffer_store(c_zero, sorted_w_rsrc, pad_valid.select(pad_slot, c_oob_idx))

    @flyc.jit
    def launch_p23(
        workspace: fx.Tensor,
        topk_weights_tensor: fx.Tensor,
        sorted_token_ids: fx.Tensor,
        sorted_weights_out: fx.Tensor,
        sorted_expert_ids: fx.Tensor,
        num_valid_ids_out: fx.Tensor,
        moe_buf: fx.Tensor,
        expert_mask_tensor: fx.Tensor,
        i32_tokens: fx.Int32,
        i32_mesh_stride: fx.Int32,
        i32_mesh_size: fx.Int32,
        i32_moe_buf_elems: fx.Int32,
        n_grid: fx.Int32,
        stream: fx.Stream = fx.Stream(None),
    ):
        k4_allocator.finalized = False
        ctx = CompilationContext.get_current()
        with ir.InsertionPoint(ctx.gpu_module_body):
            k4_allocator.finalize()

        launcher = p23_kernel(
            workspace,
            topk_weights_tensor,
            sorted_token_ids,
            sorted_weights_out,
            sorted_expert_ids,
            num_valid_ids_out,
            moe_buf,
            expert_mask_tensor,
            i32_tokens,
            i32_mesh_stride,
            i32_mesh_size,
            i32_moe_buf_elems,
        )
        launcher.launch(grid=(n_grid, 1, 1), block=(K4_BLOCK, 1, 1), stream=stream)

    # Two-kernel pipeline: fused P0_v2 followed by P23 on the same stream.
    @flyc.jit
    def launch_p0v2_p23(
        topk_ids: fx.Tensor,
        workspace: fx.Tensor,
        topk_weights_tensor: fx.Tensor,
        sorted_token_ids: fx.Tensor,
        sorted_weights_out: fx.Tensor,
        sorted_expert_ids: fx.Tensor,
        num_valid_ids_out: fx.Tensor,
        moe_buf: fx.Tensor,
        expert_mask_tensor: fx.Tensor,
        i32_tokens: fx.Int32,
        i32_mesh_stride: fx.Int32,
        i32_mesh_size: fx.Int32,
        i32_moe_buf_elems: fx.Int32,
        n_grid_p23: fx.Int32,
        stream: fx.Stream = fx.Stream(None),
    ):
        p0v2_allocator.finalized = False
        k4_allocator.finalized = False
        ctx = CompilationContext.get_current()
        with ir.InsertionPoint(ctx.gpu_module_body):
            p0v2_allocator.finalize()
            k4_allocator.finalize()

        l1 = p0v2_kernel(topk_ids, workspace, expert_mask_tensor, i32_tokens, i32_mesh_stride, i32_mesh_size)
        l1.launch(grid=(E, 1, 1), block=(P0V2_BLOCK, 1, 1), stream=stream)

        l2 = p23_kernel(
            workspace,
            topk_weights_tensor,
            sorted_token_ids,
            sorted_weights_out,
            sorted_expert_ids,
            num_valid_ids_out,
            moe_buf,
            expert_mask_tensor,
            i32_tokens,
            i32_mesh_stride,
            i32_mesh_size,
            i32_moe_buf_elems,
        )
        l2.launch(grid=(n_grid_p23, 1, 1), block=(K4_BLOCK, 1, 1), stream=stream)

    # Four-kernel pipeline: K1 clear -> K2 scatter -> K3 count -> P23, same stream.
    @flyc.jit
    def launch_4k_fused(
        topk_ids: fx.Tensor,
        workspace: fx.Tensor,
        topk_weights_tensor: fx.Tensor,
        sorted_token_ids: fx.Tensor,
        sorted_weights_out: fx.Tensor,
        sorted_expert_ids: fx.Tensor,
        num_valid_ids_out: fx.Tensor,
        moe_buf: fx.Tensor,
        expert_mask_tensor: fx.Tensor,
        i32_tokens: fx.Int32,
        i32_mesh_stride: fx.Int32,
        i32_mesh_size: fx.Int32,
        i32_moe_buf_elems: fx.Int32,
        i32_ws_total: fx.Int32,
        i32_p0_niters: fx.Int32,
        n_grid_k1: fx.Int32,
        n_grid_k2: fx.Int32,
        n_grid_p23: fx.Int32,
        stream: fx.Stream = fx.Stream(None),
    ):
        k3_allocator.finalized = False
        k4_allocator.finalized = False
        ctx = CompilationContext.get_current()
        with ir.InsertionPoint(ctx.gpu_module_body):
            k3_allocator.finalize()
            k4_allocator.finalize()

        l1 = clear_workspace_kernel(workspace, i32_ws_total)
        l1.launch(grid=(n_grid_k1, 1, 1), block=(K1_BLOCK, 1, 1), stream=stream)

        l2 = p0_scatter_kernel(topk_ids, workspace, i32_tokens, i32_mesh_stride, i32_p0_niters)
        l2.launch(grid=(n_grid_k2, 1, 1), block=(K2_BLOCK, 1, 1), stream=stream)

        l3 = p1_count_kernel(workspace, expert_mask_tensor, i32_mesh_stride, i32_mesh_size)
        l3.launch(grid=(E, 1, 1), block=(K3_BLOCK, 1, 1), stream=stream)

        l4 = p23_kernel(
            workspace,
            topk_weights_tensor,
            sorted_token_ids,
            sorted_weights_out,
            sorted_expert_ids,
            num_valid_ids_out,
            moe_buf,
            expert_mask_tensor,
            i32_tokens,
            i32_mesh_stride,
            i32_mesh_size,
            i32_moe_buf_elems,
        )
        l4.launch(grid=(n_grid_p23, 1, 1), block=(K4_BLOCK, 1, 1), stream=stream)

    return launch_clear_ws, launch_p0, launch_p1, launch_p23, launch_p0v2, launch_p0v2_p23, launch_4k_fused
i32_mesh_stride, i32_p0_niters) + l2.launch(grid=(n_grid_k2, 1, 1), block=(K2_BLOCK, 1, 1), stream=stream) + + l3 = p1_count_kernel(workspace, expert_mask_tensor, i32_mesh_stride, i32_mesh_size) + l3.launch(grid=(E, 1, 1), block=(K3_BLOCK, 1, 1), stream=stream) + + l4 = p23_kernel( + workspace, + topk_weights_tensor, + sorted_token_ids, + sorted_weights_out, + sorted_expert_ids, + num_valid_ids_out, + moe_buf, + expert_mask_tensor, + i32_tokens, + i32_mesh_stride, + i32_mesh_size, + i32_moe_buf_elems, + ) + l4.launch(grid=(n_grid_p23, 1, 1), block=(K4_BLOCK, 1, 1), stream=stream) + + return launch_clear_ws, launch_p0, launch_p1, launch_p23, launch_p0v2, launch_p0v2_p23, launch_4k_fused + + +# Host-side entry point +# --------------------------------------------------------------------------- +@functools.lru_cache(maxsize=64) +def _compute_sub_tokens(num_experts, arch=None): + """Compute the LDS-capacity threshold (sub_tokens) for decode vs prefill decision. + + Returns the max T that fits in LDS for the decode (single-kernel) path. + Same formula as compile_moe_sorting_decode. + """ + if arch is None: + arch = get_hip_arch() + E = num_experts + smem_cols = E + 1 + if arch in ("gfx942",) or str(arch).startswith("gfx94"): + lds_capacity_bytes = 65536 + elif str(arch).startswith("gfx95"): + lds_capacity_bytes = 163840 + else: + lds_capacity_bytes = 65536 + lds_capacity_ints = lds_capacity_bytes // 4 + target_occupancy = 2 + r = lds_capacity_ints // target_occupancy // smem_cols + sub_unroll = 8 + cumsum_bufs = 2 + if r < (cumsum_bufs + sub_unroll): + return 0 # LDS too small — always use prefill + r_for_sub = ((r - cumsum_bufs) // sub_unroll) * sub_unroll + return r_for_sub + + +def moe_sorting_get_workspace_size(M, num_experts, topk, unit_size=UNIT_SIZE): + """Return workspace size (in i32 elements) needed for the prefill path. 
def moe_sorting_flydsl(
    topk_ids,
    topk_weights,
    sorted_ids,
    sorted_weights,
    sorted_expert_ids,
    num_valid_ids,
    moe_buf,
    num_experts,
    unit_size=UNIT_SIZE,
    expert_mask=None,
    num_local_tokens=None,
    workspace=None,
):
    """Sort MoE tokens by expert with the FlyDSL kernels (decode or prefill).

    The argument order follows aiter.moe_sorting_fwd so this can be used as a
    drop-in replacement. All output tensors (sorted_ids, sorted_weights,
    sorted_expert_ids, num_valid_ids, moe_buf) must be pre-allocated by the
    caller.

    Dispatch: the decode launcher is used when the token count is at most
    ``min(_compute_sub_tokens(num_experts), 16)``; otherwise the prefill
    launcher runs through an int32 HBM workspace.

    Args:
        topk_ids: 2-D tensor; ``shape[0]`` is the token count and
            ``shape[1]`` is topk.
        topk_weights: Router weights, forwarded to the launcher unchanged.
        sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids:
            Pre-allocated outputs, forwarded to the launcher.
        moe_buf: Pre-allocated buffer; it is reinterpreted as int32 for the
            launcher.
        num_experts: Number of experts.
        unit_size: Padding granularity for each expert's token list.
        expert_mask: Optional expert mask. When ``None`` a cached one-element
            int32 tensor of ones is passed instead and ``has_mask`` is False.
        num_local_tokens: Optional token-count override (int or tensor; a
            tensor is read with ``.item()``).
        workspace: Optional pre-allocated workspace for the prefill path.
            NOTE(review): its size and dtype are not checked here.

    Returns
    -------
    sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf
    """
    device = topk_ids.device
    topk = topk_ids.shape[1]

    if num_local_tokens is None:
        M = topk_ids.shape[0]
    elif isinstance(num_local_tokens, torch.Tensor):
        M = num_local_tokens.item()
    else:
        M = int(num_local_tokens)

    sub_tokens = _compute_sub_tokens(num_experts)

    moe_buf_i32 = moe_buf.view(torch.int32)
    moe_buf_elems = moe_buf_i32.numel()

    # EP: choose the mask tensor handed to the kernels.
    has_mask = expert_mask is not None
    if has_mask:
        mask_tensor = expert_mask
    else:
        mask_tensor = _dummy_mask_cache.get(device)
        if mask_tensor is None:
            mask_tensor = torch.ones(1, dtype=torch.int32, device=device)
            _dummy_mask_cache[device] = mask_tensor

    outputs = (sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf)

    DECODE_MAX_T = 16
    if M > min(sub_tokens, DECODE_MAX_T):
        # Prefill path: multiple kernels via HBM workspace.
        mesh_stride = (M + unit_size - 1) // unit_size * unit_size
        ws_mesh_i32 = (num_experts * mesh_stride + 3) // 4
        ws_total = ws_mesh_i32 + (num_experts + 1)
        if workspace is None:
            workspace = torch.empty(ws_total, dtype=torch.int32, device=device)

        launch_moe_sorting_prefill_path(
            topk_ids,
            topk_weights,
            sorted_ids,
            sorted_weights,
            sorted_expert_ids,
            num_valid_ids,
            moe_buf_i32,
            workspace,
            mask_tensor,
            M,
            moe_buf_elems,
            mesh_stride,
            ws_mesh_i32,
            ws_total,
            num_experts=num_experts,
            topk=topk,
            unit_size=unit_size,
            has_mask=has_mask,
        )
        return outputs

    # Decode path: token capacity is at least 8 and rounded up to a multiple of 8.
    max_tokens = (max(M, 8) + 7) // 8 * 8

    target_occupancy = 2
    num_cu = torch.cuda.get_device_properties(topk_ids.device).multi_processor_count
    blocks_for_buf = (moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE
    n_zero_blocks = min(blocks_for_buf, num_cu * target_occupancy)
    n_grid_blocks = 1 + n_zero_blocks

    launch_moe_sorting_decode_path(
        topk_ids,
        topk_weights,
        sorted_ids,
        sorted_weights,
        sorted_expert_ids,
        num_valid_ids,
        moe_buf_i32,
        mask_tensor,
        M,
        moe_buf_elems,
        n_grid_blocks,
        num_experts=num_experts,
        topk=topk,
        max_tokens=max_tokens,
        unit_size=unit_size,
        has_mask=has_mask,
    )
    return outputs
def launch_moe_sorting_prefill_path(
    topk_ids,
    topk_weights,
    sorted_ids,
    sorted_weights,
    sorted_expert_ids,
    num_valid_ids,
    moe_buf_i32,
    workspace,
    expert_mask,
    i32_tokens,
    i32_moe_buf_elems,
    mesh_stride,
    mesh_size,
    ws_total,
    *,
    num_experts,
    topk,
    unit_size=UNIT_SIZE,
    has_mask=False,
):
    """Low-level launcher for the prefill path via HBM workspace.

    For ``i32_tokens <= 2048`` the ``launch_p0v2_p23`` launcher is used; for
    larger token counts the ``launch_4k_fused`` launcher (clear + p0 + p1 +
    p23) is used.

    On the first call for a given (num_experts, topk, unit_size, has_mask,
    device index, variant) key the JIT launcher is invoked and the result of
    ``flyc.compile`` is stored in ``_prefill_cf_cache``; later calls invoke the
    cached compiled function directly.
    """
    _, _, _, _, _, launch_p0v2_p23, launch_4k_fused = compile_moe_sorting_prefill(
        num_experts=num_experts,
        topk=topk,
        unit_size=unit_size,
        has_mask=has_mask,
    )

    stream = torch.cuda.current_stream()
    stream_arg = fx.Stream(stream)

    num_cu = torch.cuda.get_device_properties(topk_ids.device).multi_processor_count
    target_occupancy = 2
    n_zero_blocks = min((i32_moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy)
    k4_grid = num_experts + n_zero_blocks

    # Positional arguments shared by both launchers, in launcher order.
    shared_args = (
        topk_ids,
        workspace,
        topk_weights,
        sorted_ids,
        sorted_weights,
        sorted_expert_ids,
        num_valid_ids,
        moe_buf_i32,
        expert_mask,
        i32_tokens,
        mesh_stride,
        mesh_size,
        i32_moe_buf_elems,
    )

    if i32_tokens <= 2048:
        variant = "p0v2_p23"
        launcher = launch_p0v2_p23
        call_args = shared_args + (k4_grid,)
    else:
        # 4-kernel path (T > 2048): fused clear+p0+p1+p23
        k1_grid = (ws_total + 1023) // 1024
        k2_total = i32_tokens * topk
        k2_grid = min(num_cu * target_occupancy, (k2_total + 255) // 256)
        k2_stride = k2_grid * 256
        k2_niters = (k2_total + k2_stride - 1) // k2_stride
        variant = "4k_fused"
        launcher = launch_4k_fused
        call_args = shared_args + (ws_total, k2_niters, k1_grid, k2_grid, k4_grid)

    ck = (num_experts, topk, unit_size, has_mask, topk_ids.device.index, variant)
    cf = _prefill_cf_cache.get(ck)
    if cf is not None:
        cf(*call_args, stream_arg)
        return

    # First call for this key: run through the JIT launcher, then cache the
    # compiled function for subsequent calls.
    launcher(*call_args, stream=stream)
    _prefill_cf_cache[ck] = flyc.compile(launcher, *call_args, stream_arg)
def _call_flydsl(topk_ids, topk_weights, E, model_dim=4096, topk=None, unit_size=UNIT_SIZE, expert_mask=None):
    """Test helper: allocate the output tensors and call moe_sorting_flydsl.

    Output sizes:
      - sorted ids / weights: ``T * topk + E * unit_size - topk`` entries
      - sorted expert ids: that count divided by ``unit_size``, rounded up
      - num_valid_ids: 2 int32 entries
      - moe_buf: ``(T, model_dim)`` bfloat16

    Returns whatever moe_sorting_flydsl returns.
    """
    device = topk_ids.device
    num_tokens = topk_ids.shape[0]
    if topk is None:
        topk = topk_ids.shape[1]

    max_padded = num_tokens * topk + E * unit_size - topk
    max_blocks = (max_padded + unit_size - 1) // unit_size

    def _empty(shape, dtype):
        return torch.empty(shape, dtype=dtype, device=device)

    s_ids = _empty(max_padded, torch.int32)
    s_w = _empty(max_padded, torch.float32)
    s_eids = _empty(max_blocks, torch.int32)
    nv = _empty(2, torch.int32)
    buf = _empty((num_tokens, model_dim), torch.bfloat16)
    return moe_sorting_flydsl(topk_ids, topk_weights, s_ids, s_w, s_eids, nv, buf, E, unit_size, expert_mask)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def generate_topk_ids(T, E, topk, device="cuda"):
    """Generate random topk_ids and topk_weights for testing.

    Each token gets *unique* expert assignments (no duplicate expert IDs per
    token), matching the real MoE router constraint. The mesh can only store
    one topk_slot per (token, expert) pair, so duplicates would silently drop
    assignments.

    All rows are drawn in one vectorised step: the argsort of a row of
    independent uniform samples is a random permutation of ``0..E-1``, and its
    first ``topk`` entries are therefore distinct expert IDs. This replaces a
    per-token ``randperm`` loop, which issued a separate device operation for
    every token.

    Args:
        T: Number of tokens (rows).
        E: Number of experts; IDs are drawn from ``0..E-1``.
        topk: Experts per token; must not exceed ``E``.
        device: Device on which the tensors are created.

    Returns:
        ``(topk_ids, topk_weights)``: an int32 ``(T, topk)`` tensor of expert
        IDs and a float32 ``(T, topk)`` tensor of uniform weights in [0, 1).

    Raises:
        AssertionError: If ``topk > E``.
    """
    assert topk <= E, f"topk={topk} must be <= E={E}"
    scores = torch.rand(T, E, device=device)
    # .contiguous() because the column slice is a strided view.
    topk_ids = scores.argsort(dim=1)[:, :topk].to(torch.int32).contiguous()
    topk_weights = torch.rand(T, topk, dtype=torch.float32, device=device)
    return topk_ids, topk_weights
def check_sorted_weights(
    ref_w, gpu_w, ref_ids, topk, M, atol=1e-5, label="sorted_weights", gpu_ids=None, num_padded=None
):
    """Compare sorted_weights, masking padding entries.

    The first pass compares position by position over the entries whose
    reference packed ID is not the padding sentinel ``(topk << 24) | M``.
    If that fails and ``gpu_ids`` is provided, a second pass looks up each
    non-sentinel GPU packed ID in the reference and compares the paired
    weights, which tolerates a different ordering (e.g. from atomic scatter).

    Args:
        ref_w: Reference weights.
        gpu_w: Weights under test.
        ref_ids: Reference packed IDs, used to locate padding.
        topk: Top-k value used to build the sentinel.
        M: Token count used to build the sentinel.
        atol: Absolute tolerance; an error must be strictly below it.
        label: Tag printed in the report line.
        gpu_ids: Optional packed IDs under test; enables the per-entry pass.
        num_padded: Optional number of leading entries to check (entries
            beyond are uninitialized); defaults to ``len(ref_ids)``.

    Returns:
        True if the weights match under either pass (or there are no valid
        entries), otherwise False.
    """
    sentinel = (topk << 24) | M
    # Limit to num_padded if provided (entries beyond are uninitialized)
    check_range = num_padded if num_padded is not None else len(ref_ids)
    ref_slice = ref_ids[:check_range]
    mask = ref_slice != sentinel
    n_valid = mask.sum().item()
    if n_valid == 0:
        return True

    max_err = (ref_w[:check_range][mask] - gpu_w[:check_range][mask]).abs().max().item()
    if max_err < atol:
        print(f" [{label}] max_err={max_err:.2e} (OK)")
        return True

    # Position-by-position failed; try per-entry validation if gpu_ids provided
    if gpu_ids is not None:
        # One bulk transfer per tensor instead of one .item() sync per element.
        ref_id_list = ref_slice.cpu().tolist()
        ref_w_list = ref_w[:check_range].cpu().tolist()
        gpu_id_list = gpu_ids[:check_range].cpu().tolist()
        gpu_w_list = gpu_w[:check_range].cpu().tolist()

        # packed_id -> expected weight from the reference
        ref_lut = {pid: w for pid, w in zip(ref_id_list, ref_w_list) if pid != sentinel}

        pairs_ok = True
        max_pair_err = 0.0
        n_pair_checked = 0
        for gpid, gw in zip(gpu_id_list, gpu_w_list):
            if gpid == sentinel:
                continue
            n_pair_checked += 1
            if gpid not in ref_lut:
                pairs_ok = False
                break
            err = abs(gw - ref_lut[gpid])
            # "not err < atol" also rejects NaN, which max() would silently drop.
            if not err < atol:
                pairs_ok = False
                break
            max_pair_err = max(max_pair_err, err)

        if pairs_ok and n_pair_checked == n_valid:
            print(f" [{label}] max_pair_err={max_pair_err:.2e} (OK, order differs)")
            return True

    print(f" [{label}] max_err={max_err:.2e} (FAIL)")
    return False
# ---------------------------------------------------------------------------
# Single test case
# ---------------------------------------------------------------------------
def run_test(T, E, topk, unit_size=UNIT_SIZE, max_tokens=None):
    """Run a single MoE sorting test case.

    Generates seeded router outputs, computes the Python reference, runs the
    FlyDSL kernel through ``_call_flydsl`` and compares num_valid_ids,
    sorted_ids, sorted_weights, sorted_expert_ids and the zeroing of moe_buf.
    When all checks pass, the kernel call is also timed with CUDA events.

    Args:
        T: Token count.
        E: Number of experts.
        topk: Experts per token.
        unit_size: Padding granularity passed to reference and kernel.
        max_tokens: Accepted for backward compatibility; unused, because
            moe_sorting_flydsl selects the decode/prefill path itself.

    Returns (passed: bool, gpu_time_us: float or None).
    """
    from kernels.moe_sorting_kernel import _compute_sub_tokens

    # Label the path with the same rule the dispatcher applies
    # (M <= min(sub_tokens, DECODE_MAX_T)), so the report and perf line name
    # the path that actually ran.
    DECODE_MAX_T = 16
    path = "decode" if T <= min(_compute_sub_tokens(E), DECODE_MAX_T) else "prefill"

    print(f"\n{'='*60}")
    print(f"Test: T={T}, E={E}, topk={topk}, unit_size={unit_size}, path={path}")
    print(f"{'='*60}")

    torch.manual_seed(42 + T * 1000 + E * 10 + topk)
    topk_ids, topk_weights = generate_topk_ids(T, E, topk)

    # --- Reference ---
    ref_ids, ref_w, ref_eids, ref_nvalid = moe_sorting_reference(topk_ids, topk_weights, E, unit_size)

    # --- FlyDSL GPU kernel ---
    try:
        gpu_ids, gpu_w, gpu_eids, gpu_nvalid, gpu_moe_buf = _call_flydsl(
            topk_ids,
            topk_weights,
            E,
            model_dim=4096,
            topk=topk,
            unit_size=unit_size,
        )
    except Exception as e:
        # Deliberately broad: any launch/compile failure is reported as a
        # failed case (with traceback) instead of aborting the whole sweep.
        print(f" [FAIL] Kernel launch failed: {e}")
        import traceback

        traceback.print_exc()
        return False, None

    torch.cuda.synchronize()

    # --- Validate ---
    passed = True

    # 1. num_valid_ids
    nv_ok = torch.equal(ref_nvalid, gpu_nvalid)
    print(f" [num_valid_ids] ref={ref_nvalid.tolist()} gpu={gpu_nvalid.tolist()} ({'OK' if nv_ok else 'FAIL'})")
    passed &= nv_ok

    num_padded = ref_nvalid[0].item()

    # 2. sorted_ids
    passed &= check_sorted_ids(ref_ids, gpu_ids, num_padded, topk, T)

    # 3. sorted_weights
    passed &= check_sorted_weights(ref_w, gpu_w, ref_ids, topk, T, gpu_ids=gpu_ids, num_padded=num_padded)

    # 4. sorted_expert_ids
    passed &= check_expert_ids(ref_eids, gpu_eids)

    # 5. moe_buf should be zeroed
    moe_buf_zero = (gpu_moe_buf.view(torch.int32) == 0).all().item()
    print(f" [moe_buf_zeroed] {'OK' if moe_buf_zero else 'FAIL'}")
    passed &= moe_buf_zero

    # --- Benchmark ---
    gpu_time_us = None
    if passed:
        # Warmup
        for _ in range(WARMUP_ITERS):
            _call_flydsl(topk_ids, topk_weights, E, model_dim=4096, topk=topk, unit_size=unit_size)
        torch.cuda.synchronize()

        # Timed runs
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(BENCH_ITERS):
            _call_flydsl(topk_ids, topk_weights, E, model_dim=4096, topk=topk, unit_size=unit_size)
        end.record()
        torch.cuda.synchronize()
        gpu_time_us = start.elapsed_time(end) * 1000.0 / BENCH_ITERS  # ms → us
        print(f" [perf] {gpu_time_us:.2f} us/call ({path})")

    status = "PASSED" if passed else "FAILED"
    print(f" >>> {status}")
    return passed, gpu_time_us
+ torch.cuda.synchronize() + + # Compare + nv_ok = torch.equal(aiter_nvalid, fly_nvalid) + num_padded = aiter_nvalid[0].item() + num_valid_blocks = num_padded // unit_size + ids_ok = check_sorted_ids(aiter_ids, fly_ids, num_padded, topk, T, "sorted_ids(vs_aiter)") + w_ok = check_sorted_weights( + aiter_w, fly_w, aiter_ids, topk, T, label="sorted_weights(vs_aiter)", gpu_ids=fly_ids, num_padded=num_padded + ) + e_ok = check_expert_ids(aiter_eids, fly_eids, "sorted_expert_ids(vs_aiter)", num_valid_blocks=num_valid_blocks) + + passed = nv_ok and ids_ok and w_ok and e_ok + return passed, None + + +# --------------------------------------------------------------------------- +# Pytest entry points +# --------------------------------------------------------------------------- +DECODE_CONFIGS = [ + # (T, E, topk) — decode path (small T) + (1, 256, 8), + (1, 32, 5), + (4, 256, 8), + (8, 256, 8), + (16, 256, 8), + (32, 256, 8), + (64, 256, 8), + # Edge cases + (1, 8, 2), + (7, 32, 5), # odd T, topk not power of 2 + (31, 64, 6), # prime T, topk not power of 2 + # Production E > 256 (DECODE_BLOCK=512) — core coverage + (1, 257, 9), # DeepSeek-R1 (256 routed + 1 shared) + (16, 257, 9), + (16, 513, 9), # Qwen3.5 (512 routed + 1 shared) +] + +DECODE_CONFIGS_FULL = DECODE_CONFIGS + [ + # Extended production coverage (large_shape — CI skips by default) + (8, 257, 9), + (1, 385, 7), # DeepSeek-V4 (384 routed + 1 shared) + (16, 385, 7), + (1, 513, 9), # Qwen3.5 + (1, 128, 4), # Qwen3-MoE + (16, 129, 7), # Qwen3-Next (128 + 1 shared) + (16, 161, 7), # GLM-4-MoE (160 + 1 shared) +] + + +PREFILL_CONFIGS = [ + # (T, E, topk) — prefill path (large T, HBM workspace) + (128, 256, 8), + (512, 256, 8), + (1024, 256, 8), + (2048, 256, 8), + # Production E > 256 — core coverage + (1024, 257, 9), # DeepSeek-R1 + (1024, 513, 9), # Qwen3.5 +] + +PREFILL_CONFIGS_FULL = PREFILL_CONFIGS + [ + # Extended (large_shape — CI skips by default) + (4096, 256, 8), + (8192, 256, 8), + (16384, 256, 8), + 
def run_test_ep(T, E, topk, mask_ratio=0.5, unit_size=UNIT_SIZE):
    """Run one MoE sorting test with an expert_mask (EP mode).

    ``mask_ratio`` of exactly 0.0 disables every expert and exactly 1.0
    enables every expert. Any other value enables each expert when an
    independent uniform draw is below the ratio, and expert 0 is forced on if
    that leaves none enabled.

    Returns True when num_valid_ids, sorted_ids, sorted_weights,
    sorted_expert_ids and the moe_buf zeroing all match the Python reference.
    """
    from kernels.moe_sorting_kernel import _compute_sub_tokens

    DECODE_MAX_T = 16
    decode_limit = min(_compute_sub_tokens(E), DECODE_MAX_T)
    path = "decode" if T <= decode_limit else "prefill"

    banner = f"{'='*60}"
    print(f"\n{banner}")
    print(f"EP Test: T={T}, E={E}, topk={topk}, mask_ratio={mask_ratio}, path={path}")
    print(banner)

    torch.manual_seed(42 + T * 1000 + E * 10 + topk + int(mask_ratio * 100))
    topk_ids, topk_weights = generate_topk_ids(T, E, topk)

    if mask_ratio == 0.0:
        expert_mask = torch.zeros(E, dtype=torch.int32, device="cuda")
    elif mask_ratio == 1.0:
        expert_mask = torch.ones(E, dtype=torch.int32, device="cuda")
    else:
        expert_mask = (torch.rand(E, device="cuda") < mask_ratio).to(torch.int32)
        if expert_mask.sum() == 0:
            expert_mask[0] = 1

    n_enabled = expert_mask.sum().item()
    print(f" expert_mask: {n_enabled}/{E} experts enabled")

    ref_ids, ref_w, ref_eids, ref_nvalid = moe_sorting_reference(
        topk_ids, topk_weights, E, unit_size, expert_mask=expert_mask
    )

    try:
        gpu_ids, gpu_w, gpu_eids, gpu_nvalid, gpu_moe_buf = _call_flydsl(
            topk_ids,
            topk_weights,
            E,
            model_dim=4096,
            topk=topk,
            unit_size=unit_size,
            expert_mask=expert_mask,
        )
    except Exception as e:
        print(f" [FAIL] Kernel launch failed: {e}")
        import traceback

        traceback.print_exc()
        return False

    torch.cuda.synchronize()

    nv_ok = torch.equal(ref_nvalid, gpu_nvalid)
    print(f" [num_valid_ids] ref={ref_nvalid.tolist()} gpu={gpu_nvalid.tolist()} ({'OK' if nv_ok else 'FAIL'})")

    num_padded = ref_nvalid[0].item()
    # Every check runs (and prints its own report line) regardless of earlier
    # failures, so the results are gathered first and combined afterwards.
    ids_ok = check_sorted_ids(ref_ids, gpu_ids, num_padded, topk, T)
    w_ok = check_sorted_weights(ref_w, gpu_w, ref_ids, topk, T, gpu_ids=gpu_ids, num_padded=num_padded)
    eids_ok = check_expert_ids(ref_eids, gpu_eids)

    moe_buf_zero = (gpu_moe_buf.view(torch.int32) == 0).all().item()
    print(f" [moe_buf_zeroed] {'OK' if moe_buf_zero else 'FAIL'}")

    passed = bool(nv_ok and ids_ok and w_ok and eids_ok and moe_buf_zero)

    status = "PASSED" if passed else "FAILED"
    print(f" >>> {status}")
    return passed
# ---------------------------------------------------------------------------
# Benchmark utilities
# ---------------------------------------------------------------------------
def bench_eager_us(fn, warmup=BENCH_WARMUP, iters=BENCH_MEASURE, flush_l2=True):
    """Time ``fn`` per iteration with CUDA events and return a median in microseconds.

    When ``flush_l2`` is set, a uint8 buffer of ``max(2 * L2 size, 8 MiB)``
    bytes is zeroed before every call (warmup and measured); the L2 size
    falls back to 4 MiB if the device properties do not expose
    ``L2_cache_size``. With at least 8 samples, values outside
    ``[Q1 - 1.5*IQR, Q3 + 1.5*IQR]`` are dropped before the middle element of
    the sorted samples is returned.
    """
    flush_buf = None
    if flush_l2:
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        l2_bytes = getattr(props, "L2_cache_size", 4 * 1024 * 1024)
        flush_buf = torch.empty(max(l2_bytes * 2, 8 * 1024 * 1024), dtype=torch.uint8, device="cuda")

    def _flush():
        if flush_buf is not None:
            flush_buf.zero_()

    for _ in range(warmup):
        _flush()
        fn()
    torch.cuda.synchronize()

    events = [
        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) for _ in range(iters)
    ]
    for begin, finish in events:
        _flush()  # flushed before the start event, so it is not timed
        begin.record()
        fn()
        finish.record()
    torch.cuda.synchronize()

    # elapsed_time is in milliseconds; convert to microseconds.
    samples = sorted(begin.elapsed_time(finish) * 1e3 for begin, finish in events)
    count = len(samples)
    if count >= 8:
        q1 = samples[count // 4]
        q3 = samples[3 * count // 4]
        iqr = q3 - q1
        lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr
        kept = [x for x in samples if lo <= x <= hi]
        if kept:
            samples = kept
    return samples[len(samples) // 2]
def run_bench_comparison(token_sweep=None):
    """Benchmark FlyDSL vs CK (aiter) across T values in eager and graph modes.

    Uses E=256, topk=8, model_dim=4096. For each T, the FlyDSL outputs are
    pre-allocated once and both implementations are timed with
    ``bench_eager_us`` and ``bench_graph_us``. The CK columns show N/A when
    aiter cannot be imported.

    The "Path" column and the printed decode threshold use the dispatcher's
    rule (``T <= min(sub_tokens, 16)``), so they name the path that
    moe_sorting_flydsl actually takes.

    Args:
        token_sweep: Iterable of token counts; defaults to a fixed sweep from
            1 to 16384.
    """
    try:
        from aiter.fused_moe import moe_sorting as aiter_moe_sorting
    except ImportError:
        print(" aiter not available, skipping CK comparison")
        aiter_moe_sorting = None

    E, topk, model_dim = 256, 8, 4096
    if token_sweep is None:
        token_sweep = [1, 4, 8, 16, 32, 64, 128, 512, 2048, 4096, 8192, 16384]

    from kernels.moe_sorting_kernel import _compute_sub_tokens

    # Same rule as moe_sorting_flydsl: decode iff T <= min(sub_tokens, DECODE_MAX_T).
    DECODE_MAX_T = 16
    decode_limit = min(_compute_sub_tokens(E), DECODE_MAX_T)

    print(f"\n{'=' * 110}")
    print(f" MoE Sorting Benchmark: FlyDSL vs CK (E={E}, topk={topk}, unit_size={UNIT_SIZE})")
    print(f" Device: {torch.cuda.get_device_name(0)}")
    props = torch.cuda.get_device_properties(0)
    print(f" CUs: {props.multi_processor_count}, decode threshold: T<={decode_limit}")
    print(f" Modes: eager (with L2 flush, median of {BENCH_MEASURE}), graph ({BENCH_MEASURE} replays)")
    print(f"{'=' * 110}")
    print(
        f"{'T':>6s} | {'Path':>7s} | {'FLY eager':>10s} | {'FLY graph':>10s} | "
        f"{'CK eager':>10s} | {'CK graph':>10s} | {'Eager':>7s} | {'Graph':>7s}"
    )
    print("-" * 110)

    def fmt(v):
        return f"{v:8.1f}us" if v is not None else " N/A"

    def ratio(a, b):
        if a is None or b is None or b == 0:
            return " N/A"
        r = a / b
        return f" {r:.2f}x"

    for T in token_sweep:
        torch.manual_seed(42)
        topk_ids = torch.stack([torch.randperm(E, device="cuda")[:topk] for _ in range(T)]).to(torch.int32)
        topk_weights = torch.rand(T, topk, dtype=torch.float32, device="cuda")

        path = "decode" if T <= decode_limit else "prefill"

        # Pre-allocate outputs to avoid per-call torch.empty overhead
        max_num_tokens_padded = T * topk + E * UNIT_SIZE - topk
        max_num_m_blocks = (max_num_tokens_padded + UNIT_SIZE - 1) // UNIT_SIZE
        fly_sorted_ids = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda")
        fly_sorted_w = torch.empty(max_num_tokens_padded, dtype=torch.float32, device="cuda")
        fly_sorted_eids = torch.empty(max_num_m_blocks, dtype=torch.int32, device="cuda")
        fly_nvalid = torch.empty(2, dtype=torch.int32, device="cuda")

        fly_moe_buf_2d = torch.empty((T, model_dim), dtype=torch.bfloat16, device="cuda")

        def fly_fn():
            moe_sorting_flydsl(
                topk_ids,
                topk_weights,
                fly_sorted_ids,
                fly_sorted_w,
                fly_sorted_eids,
                fly_nvalid,
                fly_moe_buf_2d,
                E,
                UNIT_SIZE,
            )

        fly_eager = bench_eager_us(fly_fn)
        fly_graph = bench_graph_us(fly_fn)

        ck_eager, ck_graph = None, None
        if aiter_moe_sorting is not None:

            def ck_fn():
                aiter_moe_sorting(
                    topk_ids, topk_weights, E, model_dim=model_dim, moebuf_dtype=torch.bfloat16, block_size=UNIT_SIZE
                )

            ck_eager = bench_eager_us(ck_fn)
            ck_graph = bench_graph_us(ck_fn)

        print(
            f"{T:>6d} | {path:>7s} | {fmt(fly_eager)} | {fmt(fly_graph)} | "
            f"{fmt(ck_eager)} | {fmt(ck_graph)} | "
            f"{ratio(fly_eager, ck_eager)} | {ratio(fly_graph, ck_graph)}"
        )

    print("=" * 110)
    print(" Ratio < 1.0 = FlyDSL faster. Eager includes launch overhead. Graph amortizes it.")
    print()
f"{r['us']:.1f}us" if r["us"] else "N/A" + status = "PASS" if r["passed"] else "FAIL" + print(f" T={r['T']:>6d} E={r['E']:>3d} topk={r['topk']} {status} {t_str}") + + sys.exit(1 if failures else 0) + + +if __name__ == "__main__": + main() From 1b524551be5204f87f81d70098864061af3847f3 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Mon, 18 May 2026 16:06:39 +0100 Subject: [PATCH 02/14] fix CI failure --- kernels/moe_sorting_kernel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index 5ce40ff3..8a98eeee 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -407,7 +407,7 @@ def moe_sorting_decode_kernel( ps_safe_ix = ArithValue(ps_valid.select(ps_eid + c_one_i32, c_zero_i32)).index_cast(T.index) ps_val = ps_valid.select(_lds_load(cumsum_mr, ps_safe_ix), c_zero_i32) _lds_store(cumdup_mr, ps_val, ps_safe_ix) - _lds_store(cumdup_mr, is_t0.select(c_zero_i32, _lds_load(cumdup_mr, c_zero_i32)), c_zero_i32) + _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() # DPP prefix sum — all NUM_WAVES waves active @@ -447,7 +447,7 @@ def moe_sorting_decode_kernel( gpu.barrier() # cumdup[0] = 0 - _lds_store(cumdup_mr, is_t0.select(c_zero_i32, _lds_load(cumdup_mr, c_zero_i32)), c_zero_i32) + _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() # Write num_valid_ids from cumdup[E] @@ -478,7 +478,7 @@ def moe_sorting_decode_kernel( ml_val = ml_valid.select(ml_mask, c_zero_i32) ml_ix = ArithValue(ml_valid.select(ml_eid + c_one_i32, c_zero_i32)).index_cast(T.index) _lds_store(cumdup_mr, ml_val, ml_ix) - _lds_store(cumdup_mr, is_t0.select(c_zero_i32, _lds_load(cumdup_mr, c_zero_i32)), c_zero_i32) + _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() # All-wave DPP prefix sum over mask values in cumdup @@ -516,7 +516,7 @@ def moe_sorting_decode_kernel( prev_m = new_m gpu.barrier() - _lds_store(cumdup_mr, is_t0.select(c_zero_i32, _lds_load(cumdup_mr, 
c_zero_i32)), c_zero_i32) + _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() else: # No mask: cumdup[eid] = eid (identity mapping) From 51a1ecb8e13bf5ba67621309cffce57786214b4c Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Mon, 18 May 2026 17:14:48 +0100 Subject: [PATCH 03/14] address copilot comments --- kernels/moe_sorting_kernel.py | 48 ++++++++++++++++++++++++------- tests/kernels/test_moe_sorting.py | 4 ++- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index 8a98eeee..9d55b135 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -752,7 +752,7 @@ def compile_moe_sorting_prefill( K3: P1 count — one block per expert, count non-zero mesh cells K4: P23 prefix-sum + scatter — prefix-sum on counts, scatter tokens, fill sorted_expert_ids, zero moe_buf - P0_v2: Fused clear+scatter+count — replaces K1+K2+K3 for small T (<=512) + P0_v2: Fused clear+scatter+count — replaces K1+K2+K3 for T <= 2048 Workspace layout (i32 elements): [0 .. ws_mesh_i32) : uint8 expert mesh (E rows x mesh_stride bytes, packed into i32) @@ -966,7 +966,7 @@ def launch_p1( launcher = p1_count_kernel(workspace, expert_mask_tensor, i32_mesh_stride, i32_mesh_size) launcher.launch(grid=(E, 1, 1), block=(K3_BLOCK, 1, 1), stream=stream) - # --- P0_v2: Fused clear+scatter+count kernel (for small T) --------------- + # --- P0_v2: Fused clear+scatter+count kernel (for T <= 2048) -------------- # Replaces K1+K2+K3 with a single kernel launch. # Grid: E blocks (one per expert), Block: 512 threads (matching CK P0_v2). # Phase 1: clear this expert's mesh row @@ -1217,6 +1217,18 @@ def p23_kernel( _if_init_cs = scf.IfOp(is_t0_init.ir_value()) with _if_then(_if_init_cs): _lds_store_raw(cumsum_mr, c_zero, ArithValue(c_zero).index_cast(T.index)) + + # EP: load this thread's own mask value BEFORE the chunked loop. 
+ # The chunked loop overwrites p23_mask_val in later chunks, so we + # need a stable copy for the mask prefix sum computed after the loop. + my_mask_val = c_one + if has_mask: + tid_has_expert = tid < c_E + my_mask_val = buffer_ops.buffer_load( + mask_rsrc, tid_has_expert.select(tid, c_zero), vec_width=1, dtype=T.i32 + ) + my_mask_val = tid_has_expert.select(my_mask_val, c_zero) + for _chunk in range_constexpr(0, E, K4_BLOCK): expert_idx = fx.Int32(_chunk) + tid tid_valid_expert = expert_idx < c_E @@ -1225,13 +1237,12 @@ def p23_kernel( raw_cnt = tid_valid_expert.select(raw_cnt, c_zero) blocks = (raw_cnt + c_unit - c_one) // c_unit padded = (raw_cnt == c_zero).select(c_zero, blocks * c_unit) - p23_mask_val = c_one if has_mask: - p23_mask_val = buffer_ops.buffer_load( + chunk_mask = buffer_ops.buffer_load( mask_rsrc, tid_valid_expert.select(expert_idx, c_zero), vec_width=1, dtype=T.i32 ) - p23_mask_val = tid_valid_expert.select(p23_mask_val, c_zero) - padded = (p23_mask_val == c_zero).select(c_zero, padded) + chunk_mask = tid_valid_expert.select(chunk_mask, c_zero) + padded = (chunk_mask == c_zero).select(c_zero, padded) _lds_store_raw(cumsum_mr, padded, ArithValue(expert_idx + c_one).index_cast(T.index)) gpu.barrier() @@ -1281,7 +1292,8 @@ def p23_kernel( local_idx_p23 = tid if has_mask: # EP: Compute mask cumsum for local expert index (register-only DPP scan). - p23_mv = _dpp_intra_wave_prefix_sum(p23_mask_val, lane, WARP_SIZE) + # Uses my_mask_val (loaded before the chunked loop) to avoid overwrite. 
+ p23_mv = _dpp_intra_wave_prefix_sum(my_mask_val, lane, WARP_SIZE) # Cross-wave via scratch_mr is_last_lane_pm = lane == fx.Int32(WARP_SIZE - 1) @@ -1294,7 +1306,7 @@ def p23_kernel( wtp = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) cross_pm = (wave > fx.Int32(_w)).select(cross_pm + wtp, cross_pm) p23_mask_inclusive = p23_mv + cross_pm - local_idx_p23 = p23_mask_inclusive - p23_mask_val + local_idx_p23 = p23_mask_inclusive - my_mask_val else: local_idx_p23 = tid @@ -1316,8 +1328,24 @@ def p23_kernel( is_t0_extra3 = tid == c_zero _if_t0_e3 = scf.IfOp(is_t0_extra3.ir_value()) with _if_then(_if_t0_e3): - for _e3 in range_constexpr(K4_BLOCK, E): - _lds_store_raw(cumsum_mr, fx.Int32(_e3), ArithValue(fx.Int32(_e3)).index_cast(T.index)) + if has_mask: + # EP: serially extend mask prefix sum for experts >= K4_BLOCK + prev_local = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(K4_BLOCK - 1)).index_cast(T.index)) + prev_mask = buffer_ops.buffer_load( + mask_rsrc, fx.Int32(K4_BLOCK - 1), vec_width=1, dtype=T.i32 + ) + prev_local = prev_local + prev_mask + for _e3 in range_constexpr(K4_BLOCK, E): + e3_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(_e3), vec_width=1, dtype=T.i32) + _lds_store_raw( + cumsum_mr, prev_local, ArithValue(fx.Int32(_e3)).index_cast(T.index) + ) + prev_local = prev_local + e3_mask + else: + for _e3 in range_constexpr(K4_BLOCK, E): + _lds_store_raw( + cumsum_mr, fx.Int32(_e3), ArithValue(fx.Int32(_e3)).index_cast(T.index) + ) gpu.barrier() my_local_idx = _lds_load_raw(cumsum_mr, ArithValue(my_expert).index_cast(T.index)) diff --git a/tests/kernels/test_moe_sorting.py b/tests/kernels/test_moe_sorting.py index 7943afa1..52b3620d 100644 --- a/tests/kernels/test_moe_sorting.py +++ b/tests/kernels/test_moe_sorting.py @@ -271,7 +271,8 @@ def run_test(T, E, topk, unit_size=UNIT_SIZE, max_tokens=None): from kernels.moe_sorting_kernel import _compute_sub_tokens sub_tokens = _compute_sub_tokens(E) - path = "decode" if T <= sub_tokens 
else "prefill" + DECODE_MAX_T = 16 + path = "decode" if T <= min(sub_tokens, DECODE_MAX_T) else "prefill" if max_tokens is None and path == "decode": max_tokens = max(T, 8) @@ -574,6 +575,7 @@ def run_test_ep(T, E, topk, mask_ratio=0.5, unit_size=UNIT_SIZE): (8, 257, 9, 0.5), # DeepSeek-R1 decode + EP (1024, 257, 9, 0.5), # DeepSeek-R1 prefill + EP (8, 513, 9, 0.5), # Qwen3.5 decode + EP + (1024, 513, 9, 0.5), # Qwen3.5 prefill + EP (E > K4_BLOCK) ] From e5c6b1b7e9f0a51cdb379b892900292ef1685f5c Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Mon, 18 May 2026 17:19:30 +0100 Subject: [PATCH 04/14] fix python format --- kernels/moe_sorting_kernel.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index 9d55b135..12ffd7f3 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -1331,21 +1331,15 @@ def p23_kernel( if has_mask: # EP: serially extend mask prefix sum for experts >= K4_BLOCK prev_local = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(K4_BLOCK - 1)).index_cast(T.index)) - prev_mask = buffer_ops.buffer_load( - mask_rsrc, fx.Int32(K4_BLOCK - 1), vec_width=1, dtype=T.i32 - ) + prev_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(K4_BLOCK - 1), vec_width=1, dtype=T.i32) prev_local = prev_local + prev_mask for _e3 in range_constexpr(K4_BLOCK, E): e3_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(_e3), vec_width=1, dtype=T.i32) - _lds_store_raw( - cumsum_mr, prev_local, ArithValue(fx.Int32(_e3)).index_cast(T.index) - ) + _lds_store_raw(cumsum_mr, prev_local, ArithValue(fx.Int32(_e3)).index_cast(T.index)) prev_local = prev_local + e3_mask else: for _e3 in range_constexpr(K4_BLOCK, E): - _lds_store_raw( - cumsum_mr, fx.Int32(_e3), ArithValue(fx.Int32(_e3)).index_cast(T.index) - ) + _lds_store_raw(cumsum_mr, fx.Int32(_e3), ArithValue(fx.Int32(_e3)).index_cast(T.index)) gpu.barrier() my_local_idx = _lds_load_raw(cumsum_mr, 
ArithValue(my_expert).index_cast(T.index)) From 1c84b90edfa67886a60e48c3c3fd5ccdcacfcc0a Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Mon, 18 May 2026 17:52:50 +0100 Subject: [PATCH 05/14] address copilot comments --- kernels/moe_sorting_kernel.py | 10 +++- tests/kernels/test_moe_sorting.py | 94 +++++++++++++++++++++++-------- 2 files changed, 77 insertions(+), 27 deletions(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index 12ffd7f3..bcf9d5e0 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -1243,7 +1243,11 @@ def p23_kernel( ) chunk_mask = tid_valid_expert.select(chunk_mask, c_zero) padded = (chunk_mask == c_zero).select(c_zero, padded) - _lds_store_raw(cumsum_mr, padded, ArithValue(expert_idx + c_one).index_cast(T.index)) + raw_store_idx = expert_idx + c_one + oob = raw_store_idx >= fx.Int32(k4_smem_cols) + safe_store_idx = oob.select(c_zero, raw_store_idx) + safe_store_val = oob.select(c_zero, padded) + _lds_store_raw(cumsum_mr, safe_store_val, ArithValue(safe_store_idx).index_cast(T.index)) gpu.barrier() # Step 2: Prefix sum over cumsum LDS. When E <= K4_BLOCK (256), @@ -1890,8 +1894,8 @@ def launch_moe_sorting_prefill_path( ): """Low-level launcher for prefill path via HBM workspace. - For small T (<=512): fused P0_v2 (clear+scatter+count) + K4. - For large T: 4 separate kernels K1+K2+K3+K4. + For small T (<=2048): fused P0_v2 (clear+scatter+count) + K4. + For large T (>2048): 4 separate kernels K1+K2+K3+K4. Uses AOT-compiled dispatch after the first call for each sub-kernel to bypass JIT overhead. 
diff --git a/tests/kernels/test_moe_sorting.py b/tests/kernels/test_moe_sorting.py index 52b3620d..dce82f74 100644 --- a/tests/kernels/test_moe_sorting.py +++ b/tests/kernels/test_moe_sorting.py @@ -14,6 +14,7 @@ """ import argparse +import os import sys import pytest @@ -38,6 +39,7 @@ ) WARMUP_ITERS = 3 +RUN_BENCH = os.environ.get("MOE_SORTING_BENCH", "0") == "1" def _call_flydsl(topk_ids, topk_weights, E, model_dim=4096, topk=None, unit_size=UNIT_SIZE, expert_mask=None): @@ -77,11 +79,13 @@ def moe_sorting_reference(topk_ids, topk_weights, num_experts, unit_size=UNIT_SI sorted_expert_ids = torch.full((max_num_m_blocks,), -1, dtype=torch.int32, device=device) num_valid_ids = torch.zeros(2, dtype=torch.int32, device=device) + enabled = expert_mask.cpu().tolist() if expert_mask is not None else None + ids_cursor = 0 expert_ids_cursor = 0 skip_expert_num = 0 for eid in range(num_experts): - if expert_mask is not None and expert_mask[eid].item() == 0: + if enabled is not None and not enabled[eid]: skip_expert_num += 1 continue token_id, topk_pos = torch.where(topk_ids == eid) @@ -121,8 +125,15 @@ def generate_topk_ids(T, E, topk, device="cuda"): return topk_ids, topk_weights -def check_sorted_ids(ref_ids, gpu_ids, num_padded, topk, M, label="sorted_ids"): - """Compare sorted_ids up to num_padded, ignoring padding sentinels.""" +def check_sorted_ids( + ref_ids, gpu_ids, num_padded, topk, M, label="sorted_ids", topk_ids=None, gpu_eids=None, unit_size=UNIT_SIZE +): + """Compare sorted_ids up to num_padded, ignoring padding sentinels. + + When topk_ids and gpu_eids are provided, falls back to per-expert-block + validation: verifies each non-sentinel packed ID in a block maps to the + expert declared by sorted_expert_ids (catches cross-expert permutations). 
+ """ sentinel = (topk << 24) | M ref_slice = ref_ids[:num_padded] gpu_slice = gpu_ids[:num_padded] @@ -137,32 +148,65 @@ def check_sorted_ids(ref_ids, gpu_ids, num_padded, topk, M, label="sorted_ids"): ref_valid = ref_slice[mask] gpu_valid = gpu_slice[mask] - # Per-expert block comparison: within each expert's padded block, the set of - # packed IDs must match (order within block may differ between implementations) if torch.equal(ref_valid, gpu_valid): print(f" [{label}] exact match ({n_valid} valid entries)") return True - # Fallback: set-equality check per expert block mismatch = (ref_valid != gpu_valid).sum().item() - print(f" [{label}] WARNING: {mismatch}/{n_valid} entries differ (checking set equality)") + print(f" [{label}] WARNING: {mismatch}/{n_valid} entries differ (checking per-expert blocks)") + + # Per-expert-block validation: verify each packed ID is in the correct expert block + if topk_ids is not None and gpu_eids is not None: + n_blocks = num_padded // unit_size + topk_ids_cpu = topk_ids.cpu() + gpu_slice_cpu = gpu_slice.cpu() + gpu_eids_cpu = gpu_eids.cpu() + ref_slice_cpu = ref_slice.cpu() + bad_blocks = [] + for blk in range(n_blocks): + start = blk * unit_size + end = start + unit_size + expert_id = gpu_eids_cpu[blk].item() + if expert_id < 0: + continue + blk_gpu = set() + blk_ref = set() + for i in range(start, end): + g = gpu_slice_cpu[i].item() + r = ref_slice_cpu[i].item() + if g != sentinel: + tok = g & 0xFFFFFF + topk_pos = g >> 24 + if tok < M and topk_pos < topk: + assigned_expert = topk_ids_cpu[tok, topk_pos].item() + if assigned_expert != expert_id: + bad_blocks.append((blk, expert_id, tok, topk_pos, assigned_expert)) + blk_gpu.add(g) + if r != sentinel: + blk_ref.add(r) + if blk_gpu != blk_ref and not bad_blocks: + bad_blocks.append((blk, expert_id, -1, -1, -1)) + if not bad_blocks: + print(f" [{label}] per-expert-block validated ({n_blocks} blocks) — OK") + return True + print(f" [{label}] FAIL: {len(bad_blocks)} block(s) have 
cross-expert errors") + for blk, eid, tok, tpos, actual in bad_blocks[:5]: + if tok >= 0: + print(f" block {blk}: expert_id={eid}, token {tok} topk_pos {tpos} -> expert {actual}") + else: + print(f" block {blk}: expert_id={eid}, set mismatch") + return False + # Fallback: global set equality (no topk_ids/gpu_eids provided) ref_set = set(ref_valid.cpu().tolist()) gpu_set = set(gpu_valid.cpu().tolist()) if ref_set == gpu_set: print(f" [{label}] set-equal (order differs) — OK") return True - # Per-expert-block validation: for EP/atomic-scatter, within-expert ordering may differ. - # Check that each expert's block contains the same set of packed IDs. - missing_from_gpu = ref_set - gpu_set - extra_in_gpu = gpu_set - ref_set - if not missing_from_gpu and not extra_in_gpu: - print(f" [{label}] multiset-equal — OK") - return True - - print(f" [{label}] MISMATCH (missing={len(missing_from_gpu)}, extra={len(extra_in_gpu)})") - # Print first few diffs + missing = ref_set - gpu_set + extra = gpu_set - ref_set + print(f" [{label}] MISMATCH (missing={len(missing)}, extra={len(extra)})") diff_mask = ref_valid != gpu_valid diff_indices = diff_mask.nonzero(as_tuple=True)[0][:10] for idx in diff_indices: @@ -317,8 +361,10 @@ def run_test(T, E, topk, unit_size=UNIT_SIZE, max_tokens=None): num_padded = ref_nvalid[0].item() - # 2. sorted_ids - passed &= check_sorted_ids(ref_ids, gpu_ids, num_padded, topk, T) + # 2. sorted_ids (per-expert-block validation) + passed &= check_sorted_ids( + ref_ids, gpu_ids, num_padded, topk, T, topk_ids=topk_ids, gpu_eids=gpu_eids, unit_size=unit_size + ) # 3. 
sorted_weights passed &= check_sorted_weights(ref_w, gpu_w, ref_ids, topk, T, gpu_ids=gpu_ids, num_padded=num_padded) @@ -331,15 +377,13 @@ def run_test(T, E, topk, unit_size=UNIT_SIZE, max_tokens=None): print(f" [moe_buf_zeroed] {'OK' if moe_buf_zero else 'FAIL'}") passed &= moe_buf_zero - # --- Benchmark --- + # --- Benchmark (opt-in via MOE_SORTING_BENCH=1) --- gpu_time_us = None - if passed: - # Warmup + if passed and RUN_BENCH: for _ in range(WARMUP_ITERS): _call_flydsl(topk_ids, topk_weights, E, model_dim=4096, topk=topk, unit_size=unit_size) torch.cuda.synchronize() - # Timed runs start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() @@ -548,7 +592,9 @@ def run_test_ep(T, E, topk, mask_ratio=0.5, unit_size=UNIT_SIZE): passed &= nv_ok num_padded = ref_nvalid[0].item() - passed &= check_sorted_ids(ref_ids, gpu_ids, num_padded, topk, T) + passed &= check_sorted_ids( + ref_ids, gpu_ids, num_padded, topk, T, topk_ids=topk_ids, gpu_eids=gpu_eids, unit_size=unit_size + ) passed &= check_sorted_weights(ref_w, gpu_w, ref_ids, topk, T, gpu_ids=gpu_ids, num_padded=num_padded) passed &= check_expert_ids(ref_eids, gpu_eids) From 281324cb8111eb0e28f2a5252e1a6966995b4854 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Tue, 19 May 2026 14:27:20 +0100 Subject: [PATCH 06/14] update hardcoded parameter --- kernels/moe_sorting_kernel.py | 4 ++-- tests/kernels/test_moe_sorting.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index bcf9d5e0..64dcde60 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -1661,7 +1661,7 @@ def moe_sorting_get_workspace_size(M, num_experts, topk, unit_size=UNIT_SIZE): """Return workspace size (in i32 elements) needed for the prefill path. 
Returns 0 if the decode path will be used.""" sub_tokens = _compute_sub_tokens(num_experts) - DECODE_MAX_T = 16 + DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) if M <= min(sub_tokens, DECODE_MAX_T): return 0 mesh_stride = ((M + unit_size - 1) // unit_size) * unit_size @@ -1722,7 +1722,7 @@ def moe_sorting_flydsl( else: mask_tensor = expert_mask - DECODE_MAX_T = 16 + DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) if M <= min(sub_tokens, DECODE_MAX_T): max_tokens = max(M, 8) diff --git a/tests/kernels/test_moe_sorting.py b/tests/kernels/test_moe_sorting.py index dce82f74..e0928a58 100644 --- a/tests/kernels/test_moe_sorting.py +++ b/tests/kernels/test_moe_sorting.py @@ -312,10 +312,10 @@ def run_test(T, E, topk, unit_size=UNIT_SIZE, max_tokens=None): """ # Let moe_sorting_flydsl auto-select decode/prefill path. # max_tokens is only needed for explicit decode-path override. - from kernels.moe_sorting_kernel import _compute_sub_tokens + from kernels.moe_sorting_kernel import BLOCK_SIZE, _compute_sub_tokens sub_tokens = _compute_sub_tokens(E) - DECODE_MAX_T = 16 + DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, E // 8))) path = "decode" if T <= min(sub_tokens, DECODE_MAX_T) else "prefill" if max_tokens is None and path == "decode": @@ -535,10 +535,10 @@ def test_moe_sorting_prefill_full(T, E, topk): def run_test_ep(T, E, topk, mask_ratio=0.5, unit_size=UNIT_SIZE): """Run MoE sorting test with expert_mask (EP mode).""" - from kernels.moe_sorting_kernel import _compute_sub_tokens + from kernels.moe_sorting_kernel import BLOCK_SIZE, _compute_sub_tokens sub_tokens = _compute_sub_tokens(E) - DECODE_MAX_T = 16 + DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, E // 8))) if T <= min(sub_tokens, DECODE_MAX_T): path = "decode" else: From 105bcc5ff7e0cc0a94e9382b04525a69c06b5169 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Tue, 19 May 2026 14:37:27 +0100 Subject: [PATCH 07/14] 
add barrier for large E --- kernels/moe_sorting_kernel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index 64dcde60..f207b237 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -1327,8 +1327,10 @@ def p23_kernel( # Step 3: Write sorted_expert_ids for THIS expert (using local_idx_p23 for EP) # Store local_idx to LDS cumsum[tid], barrier, read cumsum[my_expert] _lds_store_raw(cumsum_mr, local_idx_p23, ArithValue(tid).index_cast(T.index)) - # For E > K4_BLOCK: thread 0 writes local_idx for extra experts + # For E > K4_BLOCK: thread 0 extends local_idx using cumsum[K4_BLOCK-1]. + # Barrier ensures all threads have written before thread 0 reads. if E > K4_BLOCK: + gpu.barrier() is_t0_extra3 = tid == c_zero _if_t0_e3 = scf.IfOp(is_t0_extra3.ir_value()) with _if_then(_if_t0_e3): From c2be90e6e96b3820e657bca36eed8fb777c8dec5 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Wed, 20 May 2026 16:26:25 +0100 Subject: [PATCH 08/14] adress some feedback --- kernels/moe_sorting_kernel.py | 679 ++++++++++------------------------ 1 file changed, 205 insertions(+), 474 deletions(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index f207b237..70bef216 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -91,6 +91,90 @@ def _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE): return val +@flyc.jit +def _allwave_inclusive_prefix_sum(val, lane, wave, scratch_mr, NUM_WAVES, WARP_SIZE): + """DPP intra-wave prefix sum + cross-wave LDS accumulation. + + Returns (intra_wave_val, inclusive) where intra_wave_val is the per-wave + result (needed for total_padded computation) and inclusive is the full + cross-wave inclusive prefix sum. 
+ """ + val = _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE) + if lane == fx.Int32(WARP_SIZE - 1): + _lds_store_raw(scratch_mr, val, wave) + gpu.barrier() + cross = fx.Int32(0) + for _w in range_constexpr(NUM_WAVES - 1): + wt = _lds_load_raw(scratch_mr, fx.Int32(_w)) + cross = (wave > fx.Int32(_w)).select(cross + wt, cross) + return val, val + cross + + +@flyc.jit +def _zero_moe_buf_grid_stride(moe_buf_rsrc, gid_v4, stride_v4, total_v4, oob_idx): + """Grid-stride loop zeroing moe_buf via vectorized buffer_store.""" + c_one = fx.Int32(1) + niters = (total_v4 + stride_v4 - c_one) // stride_v4 + c_zero_v4 = fx.Vector.filled(4, 0, fx.Int32) + c4 = fx.Int32(4) + for _z in range(fx.Index(0), ArithValue(niters).index_cast(T.index), fx.Index(1)): + idx = gid_v4 + fx.Int32(_z) * stride_v4 + valid = idx < total_v4 + buffer_ops.buffer_store(c_zero_v4, moe_buf_rsrc, valid.select(idx * c4, oob_idx)) + + +def _extend_prefix_sum_serial(mr, start_block, E, load_fn, store_fn): + """Thread-0 serial extension of prefix sum for experts >= start_block. + + Reads mr[start_block], then accumulates mr[start_block+1..E] in place. + Returns the final accumulated value (mr[E]). + """ + prev = load_fn(mr, fx.Int32(start_block)) + for _ext in range_constexpr(start_block, E): + cur = load_fn(mr, fx.Int32(_ext + 1)) + new_val = prev + cur + store_fn(mr, new_val, fx.Int32(_ext + 1)) + prev = new_val + return prev + + +@flyc.jit +def _write_expert_id_blocks(sorted_e_rsrc, local_eid, blk_start, n_blks): + """Write local_eid to sorted_expert_ids[blk_start .. 
blk_start+n_blks).""" + for _jb in range(fx.Index(0), ArithValue(n_blks).index_cast(T.index), fx.Index(1)): + blk_idx = blk_start + fx.Int32(_jb) + buffer_ops.buffer_store(local_eid, sorted_e_rsrc, blk_idx) + + +@flyc.jit +def _extend_local_idx_for_extra_experts(cumsum_mr, mask_rsrc, K4_BLOCK, E, has_mask): + """Thread-0: write local expert indices for experts >= K4_BLOCK to cumsum_mr.""" + if has_mask: + prev_local = _lds_load_raw(cumsum_mr, fx.Int32(K4_BLOCK - 1)) + prev_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(K4_BLOCK - 1), vec_width=1, dtype=T.i32) + prev_local = prev_local + prev_mask + for _e3 in range_constexpr(K4_BLOCK, E): + e3_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(_e3), vec_width=1, dtype=T.i32) + _lds_store_raw(cumsum_mr, prev_local, fx.Int32(_e3)) + prev_local = prev_local + e3_mask + else: + for _e3 in range_constexpr(K4_BLOCK, E): + _lds_store_raw(cumsum_mr, fx.Int32(_e3), fx.Int32(_e3)) + + +@flyc.jit +def _fill_sentinel_slots(sorted_ids_rsrc, sorted_w_rsrc, start, count, sentinel, block_size, tid, oob_idx): + """Cooperative sentinel fill: threads fill [start, start+count) with sentinels.""" + c_zero = fx.Int32(0) + end = start + count + niters = (count + fx.Int32(block_size) - fx.Int32(1)) // fx.Int32(block_size) + for _p in range(fx.Index(0), ArithValue(niters).index_cast(T.index), fx.Index(1)): + slot = start + fx.Int32(_p) * fx.Int32(block_size) + tid + safe = (slot < end).select(slot, oob_idx) + buffer_ops.buffer_store(sentinel, sorted_ids_rsrc, safe) + buffer_ops.buffer_store(c_zero, sorted_w_rsrc, safe) + + # --------------------------------------------------------------------------- # AOT-compiled dispatch caches — keyed by constexpr values. 
# After the first JIT call (which compiles the kernel), flyc.compile() @@ -258,29 +342,16 @@ def moe_sorting_decode_kernel( c_sentinel = fx.Int32((topk << 24)) # =================== MOE_BUF ZEROING (blocks > 0 only) =============== - is_zero_block = bid != c_zero_i32 - _if_zero = scf.IfOp(is_zero_block.ir_value()) - with _if_then(_if_zero): + if bid != c_zero_i32: zero_gid_v4 = (bid - c_one_i32) * fx.Int32(DECODE_BLOCK) + tid num_zero_blocks = gpu.grid_dim.x - c_one_i32 zero_stride_v4 = num_zero_blocks * fx.Int32(DECODE_BLOCK) - i32_moe_buf_v4 = i32_moe_buf_elems >> fx.Int32(2) - zero_niters = (i32_moe_buf_v4 + zero_stride_v4 - c_one_i32) // zero_stride_v4 - _zs = fx.Index(0) - _ze = ArithValue(zero_niters).index_cast(T.index) - _z1 = fx.Index(1) - c_zero_v4 = fx.Vector.filled(4, 0, fx.Int32) - c4_i32 = fx.Int32(4) - for _z in range(_zs, _ze, _z1): - z_idx_v4 = zero_gid_v4 + fx.Int32(_z) * zero_stride_v4 - z_valid = z_idx_v4 < i32_moe_buf_v4 - z_elem = z_valid.select(z_idx_v4 * c4_i32, c_oob_idx) - buffer_ops.buffer_store(c_zero_v4, moe_buf_rsrc, z_elem) + _zero_moe_buf_grid_stride( + moe_buf_rsrc, zero_gid_v4, zero_stride_v4, i32_moe_buf_elems >> fx.Int32(2), c_oob_idx + ) # =================== SORTING (block 0 only) ========================== - is_sort_block = bid == c_zero_i32 - _if_sort = scf.IfOp(is_sort_block.ir_value()) - with _if_then(_if_sort): + if bid == c_zero_i32: # ========================= PHASE 1: Histogram ========================= # Clear mesh region — unconditional store to safe index when out of bounds for i_clear in range_constexpr(0, sub_tokens * smem_cols, DECODE_BLOCK): @@ -412,21 +483,8 @@ def moe_sorting_decode_kernel( # DPP prefix sum — all NUM_WAVES waves active ps_tid_valid = tid < c_E - val = ps_tid_valid.select(_lds_load(cumdup_mr, ArithValue(tid + c_one_i32).index_cast(T.index)), c_zero_i32) - val = _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE) - - # Cross-wave accumulation via scratch LDS - is_last_lane_ps = lane == 
fx.Int32(WARP_SIZE - 1) - _if_ll_ps = scf.IfOp(is_last_lane_ps.ir_value()) - with _if_then(_if_ll_ps): - _lds_store_raw(scratch_mr, val, ArithValue(wave).index_cast(T.index)) - gpu.barrier() - cross_ps = c_zero_i32 - for _w_ps in range_constexpr(NUM_WAVES - 1): - w_ps_total = _lds_load_raw(scratch_mr, ArithValue(fx.Int32(_w_ps)).index_cast(T.index)) - cross_ps = (wave > fx.Int32(_w_ps)).select(cross_ps + w_ps_total, cross_ps) - - inclusive_ps = val + cross_ps + val = ps_tid_valid.select(_lds_load(cumdup_mr, tid + c_one_i32), c_zero_i32) + _, inclusive_ps = _allwave_inclusive_prefix_sum(val, lane, wave, scratch_mr, NUM_WAVES, WARP_SIZE) _lds_store( cumdup_mr, ps_tid_valid.select(inclusive_ps, c_zero_i32), @@ -436,14 +494,8 @@ def moe_sorting_decode_kernel( # For E > DECODE_BLOCK: thread 0 serially extends if E > DECODE_BLOCK: - _if_t0_ext = scf.IfOp(is_t0.ir_value()) - with _if_then(_if_t0_ext): - prev_ext = _lds_load(cumdup_mr, ArithValue(fx.Int32(DECODE_BLOCK)).index_cast(T.index)) - for _ext in range_constexpr(DECODE_BLOCK, E): - cur_ext = _lds_load(cumdup_mr, ArithValue(fx.Int32(_ext + 1)).index_cast(T.index)) - new_ext = prev_ext + cur_ext - _lds_store(cumdup_mr, new_ext, ArithValue(fx.Int32(_ext + 1)).index_cast(T.index)) - prev_ext = new_ext + if is_t0: + _extend_prefix_sum_serial(cumdup_mr, DECODE_BLOCK, E, _lds_load, _lds_store) gpu.barrier() # cumdup[0] = 0 @@ -483,21 +535,8 @@ def moe_sorting_decode_kernel( # All-wave DPP prefix sum over mask values in cumdup m_tid_valid = tid < c_E - mval = m_tid_valid.select( - _lds_load(cumdup_mr, ArithValue(tid + c_one_i32).index_cast(T.index)), c_zero_i32 - ) - mval = _dpp_intra_wave_prefix_sum(mval, lane, WARP_SIZE) - - is_last_lane_m = lane == fx.Int32(WARP_SIZE - 1) - _if_ll_m = scf.IfOp(is_last_lane_m.ir_value()) - with _if_then(_if_ll_m): - _lds_store_raw(scratch_mr, mval, ArithValue(wave).index_cast(T.index)) - gpu.barrier() - cross_m = c_zero_i32 - for _wm in range_constexpr(NUM_WAVES - 1): - wm_total = 
_lds_load_raw(scratch_mr, ArithValue(fx.Int32(_wm)).index_cast(T.index)) - cross_m = (wave > fx.Int32(_wm)).select(cross_m + wm_total, cross_m) - inclusive_m = mval + cross_m + mval = m_tid_valid.select(_lds_load(cumdup_mr, tid + c_one_i32), c_zero_i32) + _, inclusive_m = _allwave_inclusive_prefix_sum(mval, lane, wave, scratch_mr, NUM_WAVES, WARP_SIZE) _lds_store( cumdup_mr, m_tid_valid.select(inclusive_m, c_zero_i32), @@ -506,14 +545,8 @@ def moe_sorting_decode_kernel( gpu.barrier() if E > DECODE_BLOCK: - _if_t0_ext_m = scf.IfOp(is_t0.ir_value()) - with _if_then(_if_t0_ext_m): - prev_m = _lds_load(cumdup_mr, ArithValue(fx.Int32(DECODE_BLOCK)).index_cast(T.index)) - for _ext_m in range_constexpr(DECODE_BLOCK, E): - cur_m = _lds_load(cumdup_mr, ArithValue(fx.Int32(_ext_m + 1)).index_cast(T.index)) - new_m = prev_m + cur_m - _lds_store(cumdup_mr, new_m, ArithValue(fx.Int32(_ext_m + 1)).index_cast(T.index)) - prev_m = new_m + if is_t0: + _extend_prefix_sum_serial(cumdup_mr, DECODE_BLOCK, E, _lds_load, _lds_store) gpu.barrier() _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) @@ -548,9 +581,7 @@ def moe_sorting_decode_kernel( blk_start = e_start // c_unit blk_end = e_end // c_unit n_blks_wr = eid_wr_valid.select(blk_end - blk_start, c_zero_i32) - for _jb in range(fx.Index(0), ArithValue(n_blks_wr).index_cast(T.index), fx.Index(1)): - blk_idx = blk_start + fx.Int32(_jb) - buffer_ops.buffer_store(local_eid, sorted_e_rsrc, blk_idx) + _write_expert_id_blocks(sorted_e_rsrc, local_eid, blk_start, n_blks_wr) gpu.barrier() # Store cumdup[E] = cumsum[E]. @@ -561,19 +592,17 @@ def moe_sorting_decode_kernel( gpu.barrier() # ====================== PRE-FILL: Sentinel fill (cooperative) =========== - # Fill ALL total_padded slots with sentinels BEFORE scatter. - # Scatter will overwrite real token positions. 
- sentinel_val_pre = c_sentinel | tokens - c_zero_pre = c_zero_i32 - cs_E_ix_pre = ArithValue(c_E).index_cast(T.index) - total_padded_pre = _lds_load(cumdup_mr, cs_E_ix_pre) - n_pre_iters = (total_padded_pre + fx.Int32(DECODE_BLOCK) - c_one_i32) // fx.Int32(DECODE_BLOCK) - for _pre in range(fx.Index(0), ArithValue(n_pre_iters).index_cast(T.index), fx.Index(1)): - pre_slot = fx.Int32(_pre) * fx.Int32(DECODE_BLOCK) + tid - pre_valid = pre_slot < total_padded_pre - safe_pre = pre_valid.select(pre_slot, c_oob_idx) - buffer_ops.buffer_store(sentinel_val_pre, sorted_ids_rsrc, safe_pre) - buffer_ops.buffer_store(c_zero_pre, sorted_w_rsrc, safe_pre) + total_padded_pre = _lds_load(cumdup_mr, ArithValue(c_E).index_cast(T.index)) + _fill_sentinel_slots( + sorted_ids_rsrc, + sorted_w_rsrc, + c_zero_i32, + total_padded_pre, + c_sentinel | tokens, + DECODE_BLOCK, + tid, + c_oob_idx, + ) gpu.barrier() # ====================== PHASE 3: Scatter ============================== @@ -834,8 +863,7 @@ def p0_scatter_kernel( eid = buffer_ops.buffer_load(topk_rsrc, safe_flat, vec_width=1, dtype=T.i32) byte_offset = eid * i32_mesh_stride + token_id val_i8 = ArithValue(topk_slot + c_one).trunci(T.i8) - _if_valid_k2 = scf.IfOp(valid.ir_value()) - with _if_then(_if_valid_k2): + if valid: buffer_ops.buffer_store(val_i8, ws_rsrc, byte_offset, offset_is_bytes=True) @flyc.jit @@ -933,8 +961,7 @@ def p1_count_kernel( # Cross-warp reduce via LDS: lane 0 of each warp writes partial sum is_lane0 = lane == c_zero - _if_l0 = scf.IfOp(is_lane0.ir_value()) - with _if_then(_if_l0): + if is_lane0: wave_ix = ArithValue(wave).index_cast(T.index) _lds_store_raw(reduce_mr, cnt, wave_ix) gpu.barrier() @@ -943,7 +970,7 @@ def p1_count_kernel( is_t0 = tid == c_zero total = c_zero for _w in range_constexpr(K3_NUM_WAVES): - total = total + _lds_load_raw(reduce_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + total = total + _lds_load_raw(reduce_mr, fx.Int32(_w)) cs_offset = i32_mesh_size + eid c_oob_idx = 
fx.Int32(0x7FFFFFFF) @@ -1053,8 +1080,7 @@ def p0v2_kernel( val_i8 = ArithValue(is_mine.select(topk_slot + c_one, c_zero)).trunci(T.i8) # Byte-mode buffer_store with OOB offset crashes on AMD GPUs. # Use conditional branch to skip the store for non-matching threads. - _if_mine = scf.IfOp(is_mine.ir_value()) - with _if_then(_if_mine): + if is_mine: buffer_ops.buffer_store(val_i8, ws_rsrc, byte_offset, offset_is_bytes=True) gpu.barrier() @@ -1092,8 +1118,7 @@ def p0v2_kernel( # Cross-warp reduce via LDS: lane 0 of each warp writes partial sum is_lane0 = lane == c_zero - _if_l0 = scf.IfOp(is_lane0.ir_value()) - with _if_then(_if_l0): + if is_lane0: wave_ix = ArithValue(wave).index_cast(T.index) _lds_store_raw(reduce_mr, cnt, wave_ix) gpu.barrier() @@ -1102,7 +1127,7 @@ def p0v2_kernel( is_t0 = tid == c_zero total = c_zero for _w in range_constexpr(P0V2_NUM_WAVES): - total = total + _lds_load_raw(reduce_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + total = total + _lds_load_raw(reduce_mr, fx.Int32(_w)) cs_offset = i32_mesh_size + eid c_oob_idx = fx.Int32(0x7FFFFFFF) @@ -1187,21 +1212,13 @@ def p23_kernel( is_zero_block = bid >= c_E # ================ MOE_BUF ZEROING (blocks >= E) ================== - _if_zero = scf.IfOp(is_zero_block.ir_value()) - with _if_then(_if_zero): + if is_zero_block: moe_buf_rsrc = buffer_ops.create_buffer_resource(moe_buf, max_size=True) - zero_base_bid = bid - c_E - zero_gid_v4 = zero_base_bid * fx.Int32(K4_BLOCK) + tid - num_zero_blocks = gpu.grid_dim.x - c_E - zero_stride_v4 = num_zero_blocks * fx.Int32(K4_BLOCK) - i32_moe_buf_v4 = i32_moe_buf_elems >> fx.Int32(2) - zero_niters = (i32_moe_buf_v4 + zero_stride_v4 - c_one) // zero_stride_v4 - c_zero_v4 = fx.Vector.filled(4, 0, fx.Int32) - for _z in range(fx.Index(0), ArithValue(zero_niters).index_cast(T.index), fx.Index(1)): - z_idx_v4 = zero_gid_v4 + fx.Int32(_z) * zero_stride_v4 - z_valid = z_idx_v4 < i32_moe_buf_v4 - z_elem = z_valid.select(z_idx_v4 * c4, c_oob_idx) - 
buffer_ops.buffer_store(c_zero_v4, moe_buf_rsrc, z_elem) + zero_gid_v4 = (bid - c_E) * fx.Int32(K4_BLOCK) + tid + zero_stride_v4 = (gpu.grid_dim.x - c_E) * fx.Int32(K4_BLOCK) + _zero_moe_buf_grid_stride( + moe_buf_rsrc, zero_gid_v4, zero_stride_v4, i32_moe_buf_elems >> fx.Int32(2), c_oob_idx + ) # ================ PARALLEL PREFIX-SUM + MESH SCATTER (blocks 0..E-1) == # Each block independently: prefix sum (redundant), scatter for its expert only. @@ -1213,10 +1230,8 @@ def p23_kernel( # Process E experts in chunks of K4_BLOCK (256). Most models have # E <= 256, so the extra chunk is only needed for E > 256 # (e.g. DeepSeek-R1 with 256 routed + 1 shared = 257). - is_t0_init = tid == c_zero - _if_init_cs = scf.IfOp(is_t0_init.ir_value()) - with _if_then(_if_init_cs): - _lds_store_raw(cumsum_mr, c_zero, ArithValue(c_zero).index_cast(T.index)) + if tid == c_zero: + _lds_store_raw(cumsum_mr, c_zero, c_zero) # EP: load this thread's own mask value BEFORE the chunked loop. # The chunked loop overwrites p23_mask_val in later chunks, so we @@ -1247,107 +1262,60 @@ def p23_kernel( oob = raw_store_idx >= fx.Int32(k4_smem_cols) safe_store_idx = oob.select(c_zero, raw_store_idx) safe_store_val = oob.select(c_zero, padded) - _lds_store_raw(cumsum_mr, safe_store_val, ArithValue(safe_store_idx).index_cast(T.index)) + _lds_store_raw(cumsum_mr, safe_store_val, safe_store_idx) gpu.barrier() # Step 2: Prefix sum over cumsum LDS. When E <= K4_BLOCK (256), # a single DPP pass covers all experts. When E > K4_BLOCK, we # do the DPP pass for the first K4_BLOCK elements, then serially # accumulate the remaining entries from thread 0. 
- val = _lds_load_raw(cumsum_mr, ArithValue(tid + c_one).index_cast(T.index)) - val = _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE) - - is_last_lane_ps = lane == fx.Int32(WARP_SIZE - 1) - _if_ll_ps = scf.IfOp(is_last_lane_ps.ir_value()) - with _if_then(_if_ll_ps): - _lds_store_raw(scatter_mr, val, ArithValue(wave).index_cast(T.index)) - gpu.barrier() - cross_offset = c_zero - for _w in range_constexpr(K4_NUM_WAVES - 1): - w_total = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) - cross_offset = (wave > fx.Int32(_w)).select(cross_offset + w_total, cross_offset) + val = _lds_load_raw(cumsum_mr, tid + c_one) + val, inclusive_prefix = _allwave_inclusive_prefix_sum(val, lane, wave, scatter_mr, K4_NUM_WAVES, WARP_SIZE) total_padded = c_zero for _w in range_constexpr(K4_NUM_WAVES): - total_padded = total_padded + _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) - - inclusive_prefix = val + cross_offset - _lds_store_raw(cumsum_mr, inclusive_prefix, ArithValue(tid + c_one).index_cast(T.index)) + total_padded = total_padded + _lds_load_raw(scatter_mr, fx.Int32(_w)) + _lds_store_raw(cumsum_mr, inclusive_prefix, tid + c_one) gpu.barrier() # For E > K4_BLOCK: thread 0 serially extends the prefix sum - # for experts K4_BLOCK..E-1 (at most a few iterations). 
if E > K4_BLOCK: - _if_t0_extra = scf.IfOp(is_t0_init.ir_value()) - with _if_then(_if_t0_extra): - prev_sum = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(K4_BLOCK)).index_cast(T.index)) - for _extra in range_constexpr(K4_BLOCK, E): - cur = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(_extra + 1)).index_cast(T.index)) - new_sum = prev_sum + cur - _lds_store_raw(cumsum_mr, new_sum, ArithValue(fx.Int32(_extra + 1)).index_cast(T.index)) - prev_sum = new_sum - total_padded = prev_sum + if tid == c_zero: + total_padded = _extend_prefix_sum_serial(cumsum_mr, K4_BLOCK, E, _lds_load_raw, _lds_store_raw) gpu.barrier() - total_padded = _lds_load_raw(cumsum_mr, ArithValue(c_E).index_cast(T.index)) + total_padded = _lds_load_raw(cumsum_mr, c_E) # Read my_start and my_end from cumsum LDS - my_start = _lds_load_raw(cumsum_mr, ArithValue(my_expert).index_cast(T.index)) - my_end = _lds_load_raw(cumsum_mr, ArithValue(my_expert + c_one).index_cast(T.index)) + my_start = _lds_load_raw(cumsum_mr, my_expert) + my_end = _lds_load_raw(cumsum_mr, my_expert + c_one) local_idx_p23 = tid if has_mask: # EP: Compute mask cumsum for local expert index (register-only DPP scan). # Uses my_mask_val (loaded before the chunked loop) to avoid overwrite. 
- p23_mv = _dpp_intra_wave_prefix_sum(my_mask_val, lane, WARP_SIZE) - - # Cross-wave via scratch_mr - is_last_lane_pm = lane == fx.Int32(WARP_SIZE - 1) - _if_ll_pm = scf.IfOp(is_last_lane_pm.ir_value()) - with _if_then(_if_ll_pm): - _lds_store_raw(scatter_mr, p23_mv, ArithValue(wave).index_cast(T.index)) - gpu.barrier() - cross_pm = c_zero - for _w in range_constexpr(K4_NUM_WAVES - 1): - wtp = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) - cross_pm = (wave > fx.Int32(_w)).select(cross_pm + wtp, cross_pm) - p23_mask_inclusive = p23_mv + cross_pm + _, p23_mask_inclusive = _allwave_inclusive_prefix_sum( + my_mask_val, lane, wave, scatter_mr, K4_NUM_WAVES, WARP_SIZE + ) local_idx_p23 = p23_mask_inclusive - my_mask_val else: local_idx_p23 = tid # Block 0, thread 0 writes num_valid_ids - is_b0 = bid == c_zero - is_t0 = tid == c_zero - is_b0_t0 = is_b0 & is_t0 - _if_nv = scf.IfOp(is_b0_t0.ir_value()) - with _if_then(_if_nv): + if (bid == c_zero) & (tid == c_zero): nvalid_rsrc = buffer_ops.create_buffer_resource(num_valid_ids, max_size=True) buffer_ops.buffer_store(total_padded, nvalid_rsrc, c_zero) buffer_ops.buffer_store(i32_tokens, nvalid_rsrc, c_one) # Step 3: Write sorted_expert_ids for THIS expert (using local_idx_p23 for EP) # Store local_idx to LDS cumsum[tid], barrier, read cumsum[my_expert] - _lds_store_raw(cumsum_mr, local_idx_p23, ArithValue(tid).index_cast(T.index)) + _lds_store_raw(cumsum_mr, local_idx_p23, tid) # For E > K4_BLOCK: thread 0 extends local_idx using cumsum[K4_BLOCK-1]. # Barrier ensures all threads have written before thread 0 reads. 
if E > K4_BLOCK: gpu.barrier() - is_t0_extra3 = tid == c_zero - _if_t0_e3 = scf.IfOp(is_t0_extra3.ir_value()) - with _if_then(_if_t0_e3): - if has_mask: - # EP: serially extend mask prefix sum for experts >= K4_BLOCK - prev_local = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(K4_BLOCK - 1)).index_cast(T.index)) - prev_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(K4_BLOCK - 1), vec_width=1, dtype=T.i32) - prev_local = prev_local + prev_mask - for _e3 in range_constexpr(K4_BLOCK, E): - e3_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(_e3), vec_width=1, dtype=T.i32) - _lds_store_raw(cumsum_mr, prev_local, ArithValue(fx.Int32(_e3)).index_cast(T.index)) - prev_local = prev_local + e3_mask - else: - for _e3 in range_constexpr(K4_BLOCK, E): - _lds_store_raw(cumsum_mr, fx.Int32(_e3), ArithValue(fx.Int32(_e3)).index_cast(T.index)) + if tid == c_zero: + _extend_local_idx_for_extra_experts(cumsum_mr, mask_rsrc, K4_BLOCK, E, has_mask) gpu.barrier() - my_local_idx = _lds_load_raw(cumsum_mr, ArithValue(my_expert).index_cast(T.index)) + my_local_idx = _lds_load_raw(cumsum_mr, my_expert) sorted_e_rsrc = buffer_ops.create_buffer_resource(sorted_expert_ids, max_size=True) blk_start = my_start // c_unit @@ -1401,23 +1369,13 @@ def p23_kernel( + h3.select(c_one, c_zero) ) - my_cnt = _dpp_intra_wave_prefix_sum(my_cnt, lane, WARP_SIZE) - - # my_cnt is now intra-wave inclusive prefix sum of per-thread token counts. - # Cross-wave reduction via LDS scratch. 
- is_last_lane_sc = lane == fx.Int32(WARP_SIZE - 1) - _if_ll_sc = scf.IfOp(is_last_lane_sc.ir_value()) - with _if_then(_if_ll_sc): - _lds_store_raw(scatter_mr, my_cnt, ArithValue(wave).index_cast(T.index)) - gpu.barrier() - - wave_offset = c_zero - for _w in range_constexpr(K4_NUM_WAVES - 1): - w_total = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) - wave_offset = (wave > fx.Int32(_w)).select(wave_offset + w_total, wave_offset) + my_cnt, my_cnt_inclusive = _allwave_inclusive_prefix_sum( + my_cnt, lane, wave, scatter_mr, K4_NUM_WAVES, WARP_SIZE + ) + wave_offset = my_cnt_inclusive - my_cnt batch_total = c_zero for _w in range_constexpr(K4_NUM_WAVES): - batch_total = batch_total + _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + batch_total = batch_total + _lds_load_raw(scatter_mr, fx.Int32(_w)) gpu.barrier() # Convert to exclusive prefix: my_exclusive = my_cnt - my_thread_count @@ -1481,14 +1439,16 @@ def p23_kernel( scatter_end_pos_t0 = results # Step 5: Fill padding with sentinel for THIS expert (parallel) - sentinel_val = c_sentinel | i32_tokens - pad_count = my_end - scatter_end_pos_t0 - pad_niters = (pad_count + fx.Int32(K4_BLOCK) - c_one) // fx.Int32(K4_BLOCK) - for _pi in range(fx.Index(0), ArithValue(pad_niters).index_cast(T.index), fx.Index(1)): - pad_slot = scatter_end_pos_t0 + fx.Int32(_pi) * fx.Int32(K4_BLOCK) + tid - pad_valid = pad_slot < my_end - buffer_ops.buffer_store(sentinel_val, sorted_ids_rsrc, pad_valid.select(pad_slot, c_oob_idx)) - buffer_ops.buffer_store(c_zero, sorted_w_rsrc, pad_valid.select(pad_slot, c_oob_idx)) + _fill_sentinel_slots( + sorted_ids_rsrc, + sorted_w_rsrc, + scatter_end_pos_t0, + my_end - scatter_end_pos_t0, + c_sentinel | i32_tokens, + K4_BLOCK, + tid, + c_oob_idx, + ) @flyc.jit def launch_p23( @@ -1672,6 +1632,18 @@ def moe_sorting_get_workspace_size(M, num_experts, topk, unit_size=UNIT_SIZE): return ws_mesh_i32 + (num_experts + 1) +def _launch_cached(cache, key, 
launch_fn, args, stream): + """AOT-compiled dispatch: first call JITs, subsequent calls use cached CompiledFunction.""" + cf = cache.get(key) + stream_arg = fx.Stream(stream) + if cf is not None: + cf(*args, stream_arg) + else: + launch_fn(*args, stream=stream) + cf = flyc.compile(launch_fn, *args, stream_arg) + cache[key] = cf + + def moe_sorting_flydsl( topk_ids, topk_weights, @@ -1735,7 +1707,10 @@ def moe_sorting_flydsl( n_zero_blocks = min((moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy) n_grid_blocks = 1 + n_zero_blocks - launch_moe_sorting_decode_path( + launch_fn = compile_moe_sorting_decode( + num_experts=num_experts, topk=topk, max_tokens=max_tokens, unit_size=unit_size, has_mask=has_mask + ) + decode_args = ( topk_ids, topk_weights, sorted_ids, @@ -1747,12 +1722,9 @@ def moe_sorting_flydsl( M, moe_buf_elems, n_grid_blocks, - num_experts=num_experts, - topk=topk, - max_tokens=max_tokens, - unit_size=unit_size, - has_mask=has_mask, ) + cache_key = (num_experts, topk, max_tokens, unit_size, has_mask, device.index) + _launch_cached(_decode_cf_cache, cache_key, launch_fn, decode_args, torch.cuda.current_stream()) else: # Prefill path: multiple kernels via HBM workspace mesh_stride = ((M + unit_size - 1) // unit_size) * unit_size @@ -1762,173 +1734,18 @@ def moe_sorting_flydsl( if workspace is None: workspace = torch.empty(ws_total, dtype=torch.int32, device=device) - launch_moe_sorting_prefill_path( - topk_ids, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - workspace, - mask_tensor, - M, - moe_buf_elems, - mesh_stride, - ws_mesh_i32, - ws_total, - num_experts=num_experts, - topk=topk, - unit_size=unit_size, - has_mask=has_mask, + _, _, _, _, _, launch_p0v2_p23, launch_4k_fused = compile_moe_sorting_prefill( + num_experts=num_experts, topk=topk, unit_size=unit_size, has_mask=has_mask ) - - return sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf - - -def 
launch_moe_sorting_decode_path( - topk_ids, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - i32_moe_buf_elems, - n_grid_blocks, - *, - num_experts, - topk, - max_tokens=128, - unit_size=UNIT_SIZE, - has_mask=False, -): - """Low-level launcher for decode path: single kernel. - - This is the hot-path entry point — no torch ops, just JIT dispatch. - Uses AOT-compiled dispatch after the first call to bypass the ~70 us - JIT overhead (inspect.Signature.bind + cache key + dict lookup). - """ - - cache_key = (num_experts, topk, max_tokens, unit_size, has_mask, topk_ids.device.index) - cf = _decode_cf_cache.get(cache_key) - if cf is not None: stream = torch.cuda.current_stream() - cf( - topk_ids, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - i32_moe_buf_elems, - n_grid_blocks, - fx.Stream(stream), - ) - return - - launch_fn = compile_moe_sorting_decode( - num_experts=num_experts, - topk=topk, - max_tokens=max_tokens, - unit_size=unit_size, - has_mask=has_mask, - ) - stream = torch.cuda.current_stream() - launch_fn( - topk_ids, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - i32_moe_buf_elems, - n_grid_blocks, - stream=stream, - ) - - cf = flyc.compile( - launch_fn, - topk_ids, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - i32_moe_buf_elems, - n_grid_blocks, - fx.Stream(stream), - ) - _decode_cf_cache[cache_key] = cf - - -def launch_moe_sorting_prefill_path( - topk_ids, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - workspace, - expert_mask, - i32_tokens, - i32_moe_buf_elems, - mesh_stride, - mesh_size, - ws_total, - *, - num_experts, - topk, - unit_size=UNIT_SIZE, - has_mask=False, 
-): - """Low-level launcher for prefill path via HBM workspace. - - For small T (<=2048): fused P0_v2 (clear+scatter+count) + K4. - For large T (>2048): 4 separate kernels K1+K2+K3+K4. - - Uses AOT-compiled dispatch after the first call for each sub-kernel - to bypass JIT overhead. - """ - - launch_clear_ws, launch_p0, launch_p1, launch_p23, launch_p0v2, launch_p0v2_p23, launch_4k_fused = ( - compile_moe_sorting_prefill( - num_experts=num_experts, - topk=topk, - unit_size=unit_size, - has_mask=has_mask, - ) - ) - - stream = torch.cuda.current_stream() - stream_arg = fx.Stream(stream) - - num_cu = torch.cuda.get_device_properties(topk_ids.device).multi_processor_count - - base_key = (num_experts, topk, unit_size, has_mask, topk_ids.device.index) - use_p0v2 = i32_tokens <= 2048 + num_cu = torch.cuda.get_device_properties(device).multi_processor_count + target_occupancy = 2 + n_zero_blocks = min((moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy) + k4_grid = num_experts + n_zero_blocks + base_key = (num_experts, topk, unit_size, has_mask, device.index) - target_occupancy = 2 - n_zero_blocks = min((i32_moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy) - k4_grid = num_experts + n_zero_blocks - - if use_p0v2: - ck = base_key + ("p0v2_p23",) - cf = _prefill_cf_cache.get(ck) - if cf is not None: - cf( + if M <= 2048: + p0v2_args = ( topk_ids, workspace, topk_weights, @@ -1937,34 +1754,21 @@ def launch_moe_sorting_prefill_path( sorted_expert_ids, num_valid_ids, moe_buf_i32, - expert_mask, - i32_tokens, + mask_tensor, + M, mesh_stride, - mesh_size, - i32_moe_buf_elems, + ws_mesh_i32, + moe_buf_elems, k4_grid, - stream_arg, ) + _launch_cached(_prefill_cf_cache, base_key + ("p0v2_p23",), launch_p0v2_p23, p0v2_args, stream) else: - launch_p0v2_p23( - topk_ids, - workspace, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - mesh_stride, - mesh_size, - 
i32_moe_buf_elems, - k4_grid, - stream=stream, - ) - cf = flyc.compile( - launch_p0v2_p23, + k1_grid = (ws_total + 1023) // 1024 + k2_grid = min(num_cu * target_occupancy, (M * topk + 255) // 256) + k2_total = M * topk + k2_stride = k2_grid * 256 + k2_niters = (k2_total + k2_stride - 1) // k2_stride + k4_args = ( topk_ids, workspace, topk_weights, @@ -1973,90 +1777,17 @@ def launch_moe_sorting_prefill_path( sorted_expert_ids, num_valid_ids, moe_buf_i32, - expert_mask, - i32_tokens, + mask_tensor, + M, mesh_stride, - mesh_size, - i32_moe_buf_elems, + ws_mesh_i32, + moe_buf_elems, + ws_total, + k2_niters, + k1_grid, + k2_grid, k4_grid, - stream_arg, ) - _prefill_cf_cache[ck] = cf - return - - # 4-kernel path (T > 2048): fused clear+p0+p1+p23 - k1_grid = (ws_total + 1023) // 1024 - k2_grid = min(num_cu * target_occupancy, (i32_tokens * topk + 255) // 256) - k2_total = i32_tokens * topk - k2_stride = k2_grid * 256 - k2_niters = (k2_total + k2_stride - 1) // k2_stride - - ck = base_key + ("4k_fused",) - cf = _prefill_cf_cache.get(ck) - if cf is not None: - cf( - topk_ids, - workspace, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - mesh_stride, - mesh_size, - i32_moe_buf_elems, - ws_total, - k2_niters, - k1_grid, - k2_grid, - k4_grid, - stream_arg, - ) - else: - launch_4k_fused( - topk_ids, - workspace, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - mesh_stride, - mesh_size, - i32_moe_buf_elems, - ws_total, - k2_niters, - k1_grid, - k2_grid, - k4_grid, - stream=stream, - ) - cf = flyc.compile( - launch_4k_fused, - topk_ids, - workspace, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - mesh_stride, - mesh_size, - i32_moe_buf_elems, - ws_total, - k2_niters, - k1_grid, - k2_grid, - k4_grid, - stream_arg, - ) - 
_prefill_cf_cache[ck] = cf + _launch_cached(_prefill_cf_cache, base_key + ("4k_fused",), launch_4k_fused, k4_args, stream) + + return sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf From f0e4335e43b7963f71ecde157598b1453c266249 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Wed, 20 May 2026 18:32:21 +0100 Subject: [PATCH 09/14] fix scf if usage and remove duplicated lauching codes for prefill and decode --- kernels/moe_sorting_kernel.py | 247 +++++++++++++++++----------------- 1 file changed, 127 insertions(+), 120 deletions(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index 70bef216..f533c722 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -162,6 +162,114 @@ def _extend_local_idx_for_extra_experts(cumsum_mr, mask_rsrc, K4_BLOCK, E, has_m _lds_store_raw(cumsum_mr, fx.Int32(_e3), fx.Int32(_e3)) +@flyc.jit +def _p23_scatter_mesh( + tid, + scatter_mr, + ws_rsrc, + weights_rsrc, + sorted_ids_rsrc, + sorted_w_rsrc, + mask_rsrc, + my_expert, + my_start, + my_end, + i32_mesh_stride, + c_topk, + K4_BLOCK, + has_mask, +): + """P23 Step 4: EP mask check, read uint8 mesh, DPP prefix sum, scatter tokens.""" + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + K4_NUM_WAVES = K4_BLOCK // WARP_SIZE + c_zero, c_one, c4 = fx.Int32(0), fx.Int32(1), fx.Int32(4) + c_ff, c_oob_idx = fx.Int32(0xFF), fx.Int32(0x7FFFFFFF) + p23_bid_enabled = c_one != c_zero + if has_mask: + p23_bid_mask = buffer_ops.buffer_load(mask_rsrc, my_expert, vec_width=1, dtype=T.i32) + p23_bid_enabled = p23_bid_mask != c_zero + i32_words_per_row = i32_mesh_stride >> fx.Int32(2) + n_mesh_iters = (my_start != my_end).select( + (i32_words_per_row + fx.Int32(K4_BLOCK - 1)) // fx.Int32(K4_BLOCK), c_zero + ) + mesh_row_i32_base = (my_expert * i32_mesh_stride) >> fx.Int32(2) + for _si, state in range(fx.Index(0), ArithValue(n_mesh_iters).index_cast(T.index), fx.Index(1), init=[my_start]): + position = state[0] + word_idx = 
fx.Int32(_si) * fx.Int32(K4_BLOCK) + tid + col_valid = p23_bid_enabled & (word_idx < i32_words_per_row) + safe_word_idx = col_valid.select(word_idx, c_zero) + word = buffer_ops.buffer_load(ws_rsrc, mesh_row_i32_base + safe_word_idx, vec_width=1, dtype=T.i32) + x0 = word & c_ff + x1 = (word >> fx.Int32(8)) & c_ff + x2 = (word >> fx.Int32(16)) & c_ff + x3 = (word >> fx.Int32(24)) & c_ff + base_col = word_idx * c4 + h0 = col_valid & (x0 != c_zero) + h1 = col_valid & (x1 != c_zero) + h2 = col_valid & (x2 != c_zero) + h3 = col_valid & (x3 != c_zero) + my_cnt = ( + h0.select(c_one, c_zero) + h1.select(c_one, c_zero) + h2.select(c_one, c_zero) + h3.select(c_one, c_zero) + ) + my_cnt, my_cnt_inclusive = _allwave_inclusive_prefix_sum( + my_cnt, lane, wave, scatter_mr, K4_NUM_WAVES, WARP_SIZE + ) + wave_offset = my_cnt_inclusive - my_cnt + batch_total = c_zero + for _w in range_constexpr(K4_NUM_WAVES): + batch_total = batch_total + _lds_load_raw(scatter_mr, fx.Int32(_w)) + gpu.barrier() + my_thread_count = ( + h0.select(c_one, c_zero) + h1.select(c_one, c_zero) + h2.select(c_one, c_zero) + h3.select(c_one, c_zero) + ) + my_exclusive = my_cnt - my_thread_count + wave_offset + scatter_base = position + my_exclusive + pid_0 = (h0.select(x0 - c_one, c_zero) << fx.Int32(24)) | base_col + pid_1 = (h1.select(x1 - c_one, c_zero) << fx.Int32(24)) | (base_col + c_one) + pid_2 = (h2.select(x2 - c_one, c_zero) << fx.Int32(24)) | (base_col + fx.Int32(2)) + pid_3 = (h3.select(x3 - c_one, c_zero) << fx.Int32(24)) | (base_col + fx.Int32(3)) + safe_slot_0 = h0.select(scatter_base, c_oob_idx) + off1 = scatter_base + h0.select(c_one, c_zero) + safe_slot_1 = h1.select(off1, c_oob_idx) + off2 = off1 + h1.select(c_one, c_zero) + safe_slot_2 = h2.select(off2, c_oob_idx) + off3 = off2 + h2.select(c_one, c_zero) + safe_slot_3 = h3.select(off3, c_oob_idx) + w_val_0 = buffer_ops.buffer_load( + weights_rsrc, h0.select(base_col * c_topk + h0.select(x0 - c_one, c_zero), c_zero), vec_width=1, dtype=T.i32 
+ ) + w_val_1 = buffer_ops.buffer_load( + weights_rsrc, + h1.select((base_col + c_one) * c_topk + h1.select(x1 - c_one, c_zero), c_zero), + vec_width=1, + dtype=T.i32, + ) + w_val_2 = buffer_ops.buffer_load( + weights_rsrc, + h2.select((base_col + fx.Int32(2)) * c_topk + h2.select(x2 - c_one, c_zero), c_zero), + vec_width=1, + dtype=T.i32, + ) + w_val_3 = buffer_ops.buffer_load( + weights_rsrc, + h3.select((base_col + fx.Int32(3)) * c_topk + h3.select(x3 - c_one, c_zero), c_zero), + vec_width=1, + dtype=T.i32, + ) + buffer_ops.buffer_store(pid_0, sorted_ids_rsrc, safe_slot_0) + buffer_ops.buffer_store(pid_1, sorted_ids_rsrc, safe_slot_1) + buffer_ops.buffer_store(pid_2, sorted_ids_rsrc, safe_slot_2) + buffer_ops.buffer_store(pid_3, sorted_ids_rsrc, safe_slot_3) + buffer_ops.buffer_store(w_val_0, sorted_w_rsrc, safe_slot_0) + buffer_ops.buffer_store(w_val_1, sorted_w_rsrc, safe_slot_1) + buffer_ops.buffer_store(w_val_2, sorted_w_rsrc, safe_slot_2) + buffer_ops.buffer_store(w_val_3, sorted_w_rsrc, safe_slot_3) + pos_next = position + batch_total + results = yield [pos_next] + return results + + @flyc.jit def _fill_sentinel_slots(sorted_ids_rsrc, sorted_w_rsrc, start, count, sentinel, block_size, tid, oob_idx): """Cooperative sentinel fill: threads fill [start, start+count) with sentinels.""" @@ -1188,13 +1296,11 @@ def p23_kernel( wave = tid // WARP_SIZE c_zero = fx.Int32(0) c_one = fx.Int32(1) - c4 = fx.Int32(4) c_E = fx.Int32(E) c_unit = fx.Int32(unit_size) c_topk = fx.Int32(topk) c_sentinel = fx.Int32(topk << 24) c_oob_idx = fx.Int32(0x7FFFFFFF) - c_ff = fx.Int32(0xFF) # Buffer resources ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True) @@ -1222,8 +1328,7 @@ def p23_kernel( # ================ PARALLEL PREFIX-SUM + MESH SCATTER (blocks 0..E-1) == # Each block independently: prefix sum (redundant), scatter for its expert only. 
- _if_sort = scf.IfOp(is_sort_block.ir_value()) - with _if_then(_if_sort): + if is_sort_block: my_expert = bid # Step 1: Load expert counts from workspace -> pad to unit_size -> LDS cumsum @@ -1320,123 +1425,25 @@ def p23_kernel( sorted_e_rsrc = buffer_ops.create_buffer_resource(sorted_expert_ids, max_size=True) blk_start = my_start // c_unit blk_end = my_end // c_unit - n_blks = blk_end - blk_start - n_eid_iters = (n_blks + fx.Int32(K4_BLOCK) - c_one) // fx.Int32(K4_BLOCK) - for _eii in range(fx.Index(0), ArithValue(n_eid_iters).index_cast(T.index), fx.Index(1)): - blk_idx = blk_start + fx.Int32(_eii) * fx.Int32(K4_BLOCK) + tid - buffer_ops.buffer_store(my_local_idx, sorted_e_rsrc, (blk_idx < blk_end).select(blk_idx, c_oob_idx)) - - # Step 4: Mesh-based scatter — read uint8 mesh from HBM, extract tokens, - # DPP prefix sum over counts, cross-wave LDS reduction, scatter stores. - p23_bid_enabled = c_one != c_zero - if has_mask: - # EP: skip scatter for masked experts (my_start == my_end, but mesh has data) - p23_bid_mask = buffer_ops.buffer_load(mask_rsrc, my_expert, vec_width=1, dtype=T.i32) - p23_bid_enabled = p23_bid_mask != c_zero - - i32_words_per_row = i32_mesh_stride >> fx.Int32(2) - n_mesh_iters_raw = (i32_words_per_row + fx.Int32(K4_BLOCK - 1)) // fx.Int32(K4_BLOCK) - has_work = my_start != my_end - n_mesh_iters = has_work.select(n_mesh_iters_raw, c_zero) - mesh_row_i32_base = (my_expert * i32_mesh_stride) >> fx.Int32(2) - - for _si, state in range( - fx.Index(0), ArithValue(n_mesh_iters).index_cast(T.index), fx.Index(1), init=[my_start] - ): - position = state[0] - - word_idx = fx.Int32(_si) * fx.Int32(K4_BLOCK) + tid - col_valid = p23_bid_enabled & (word_idx < i32_words_per_row) - safe_word_idx = col_valid.select(word_idx, c_zero) - word = buffer_ops.buffer_load(ws_rsrc, mesh_row_i32_base + safe_word_idx, vec_width=1, dtype=T.i32) - - # Extract 4 bytes from the i32 word - x0 = word & c_ff - x1 = (word >> fx.Int32(8)) & c_ff - x2 = (word >> fx.Int32(16)) 
& c_ff - x3 = (word >> fx.Int32(24)) & c_ff - base_col = word_idx * c4 - - h0 = col_valid & (x0 != c_zero) - h1 = col_valid & (x1 != c_zero) - h2 = col_valid & (x2 != c_zero) - h3 = col_valid & (x3 != c_zero) - - my_cnt = ( - h0.select(c_one, c_zero) - + h1.select(c_one, c_zero) - + h2.select(c_one, c_zero) - + h3.select(c_one, c_zero) - ) - - my_cnt, my_cnt_inclusive = _allwave_inclusive_prefix_sum( - my_cnt, lane, wave, scatter_mr, K4_NUM_WAVES, WARP_SIZE - ) - wave_offset = my_cnt_inclusive - my_cnt - batch_total = c_zero - for _w in range_constexpr(K4_NUM_WAVES): - batch_total = batch_total + _lds_load_raw(scatter_mr, fx.Int32(_w)) - gpu.barrier() + _write_expert_id_blocks(sorted_e_rsrc, my_local_idx, blk_start, blk_end - blk_start) - # Convert to exclusive prefix: my_exclusive = my_cnt - my_thread_count - my_thread_count = ( - h0.select(c_one, c_zero) - + h1.select(c_one, c_zero) - + h2.select(c_one, c_zero) - + h3.select(c_one, c_zero) - ) - my_exclusive = my_cnt - my_thread_count + wave_offset - - # Scatter: compute all addresses, batch-load weights, then batch-store. 
- scatter_base = position + my_exclusive - - # Compute packed IDs and output slots for all 4 tokens - token_id_0 = base_col - topk_slot_0 = h0.select(x0 - c_one, c_zero) - pid_0 = (topk_slot_0 << fx.Int32(24)) | token_id_0 - safe_slot_0 = h0.select(scatter_base, c_oob_idx) - - off1 = scatter_base + h0.select(c_one, c_zero) - token_id_1 = base_col + c_one - topk_slot_1 = h1.select(x1 - c_one, c_zero) - pid_1 = (topk_slot_1 << fx.Int32(24)) | token_id_1 - safe_slot_1 = h1.select(off1, c_oob_idx) - - off2 = off1 + h1.select(c_one, c_zero) - token_id_2 = base_col + fx.Int32(2) - topk_slot_2 = h2.select(x2 - c_one, c_zero) - pid_2 = (topk_slot_2 << fx.Int32(24)) | token_id_2 - safe_slot_2 = h2.select(off2, c_oob_idx) - - off3 = off2 + h2.select(c_one, c_zero) - token_id_3 = base_col + fx.Int32(3) - topk_slot_3 = h3.select(x3 - c_one, c_zero) - pid_3 = (topk_slot_3 << fx.Int32(24)) | token_id_3 - safe_slot_3 = h3.select(off3, c_oob_idx) - - # Batch-issue all 4 weight loads (increases load-use distance) - w_addr_0 = h0.select(token_id_0 * c_topk + topk_slot_0, c_zero) - w_addr_1 = h1.select(token_id_1 * c_topk + topk_slot_1, c_zero) - w_addr_2 = h2.select(token_id_2 * c_topk + topk_slot_2, c_zero) - w_addr_3 = h3.select(token_id_3 * c_topk + topk_slot_3, c_zero) - w_val_0 = buffer_ops.buffer_load(weights_rsrc, w_addr_0, vec_width=1, dtype=T.i32) - w_val_1 = buffer_ops.buffer_load(weights_rsrc, w_addr_1, vec_width=1, dtype=T.i32) - w_val_2 = buffer_ops.buffer_load(weights_rsrc, w_addr_2, vec_width=1, dtype=T.i32) - w_val_3 = buffer_ops.buffer_load(weights_rsrc, w_addr_3, vec_width=1, dtype=T.i32) - - # Batch-store: all packed IDs, then all weights - buffer_ops.buffer_store(pid_0, sorted_ids_rsrc, safe_slot_0) - buffer_ops.buffer_store(pid_1, sorted_ids_rsrc, safe_slot_1) - buffer_ops.buffer_store(pid_2, sorted_ids_rsrc, safe_slot_2) - buffer_ops.buffer_store(pid_3, sorted_ids_rsrc, safe_slot_3) - buffer_ops.buffer_store(w_val_0, sorted_w_rsrc, safe_slot_0) - 
buffer_ops.buffer_store(w_val_1, sorted_w_rsrc, safe_slot_1) - buffer_ops.buffer_store(w_val_2, sorted_w_rsrc, safe_slot_2) - buffer_ops.buffer_store(w_val_3, sorted_w_rsrc, safe_slot_3) - - pos_next = position + batch_total - results = yield [pos_next] - scatter_end_pos_t0 = results + # Step 4: Mesh-based scatter (EP mask + uint8 mesh read + DPP prefix sum + scatter) + scatter_end_pos_t0 = _p23_scatter_mesh( + tid, + scatter_mr, + ws_rsrc, + weights_rsrc, + sorted_ids_rsrc, + sorted_w_rsrc, + mask_rsrc, + my_expert, + my_start, + my_end, + i32_mesh_stride, + c_topk, + K4_BLOCK, + has_mask, + ) # Step 5: Fill padding with sentinel for THIS expert (parallel) _fill_sentinel_slots( From ce6315d9ae09b4c31efc43ccccbb51e5562bf820 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Thu, 21 May 2026 17:38:12 +0100 Subject: [PATCH 10/14] refactoring to improve code reuse --- kernels/moe_sorting_kernel.py | 444 ++++++++++++++++------------------ 1 file changed, 205 insertions(+), 239 deletions(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index f533c722..ae78bde1 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -9,9 +9,12 @@ Algorithm: counting sort in LDS (histogram → prefix-sum → scatter). -Two paths: - - Decode (small T): single kernel, all phases in LDS. - - Prefill (large T): 4 kernels via HBM workspace (ClearWS → P0 scatter → P1 count → P23 prefix-sum+scatter). +Three paths (selected by T and DECODE_MAX_T = min(sub_tokens, BLOCK_SIZE // max(topk, E//8))): + - T <= DECODE_MAX_T: single kernel, all phases in LDS. + - DECODE_MAX_T < T <= 2048: 2 kernels (fused P0v2 + P23) via HBM workspace. + - Prefill/4 kernels (T > 2048): 4 kernels via HBM workspace (ClearWS → P0 scatter → P1 count → P23). + - DECODE_MAX_T = 16 for large models(Qwen3-MoE, MiMo-V2-Flash .. DeepSeekR1 V4) on current LDS sizes. 
+ - DECODE_MAX_T is larger fro small model like 128 for Mixtral-8x7B, 32 for GPT-OSS and MiniMax-M2 Packed token ID format: (topk_position << 24) | token_id - Upper 8 bits: topk slot (0..topk-1) @@ -20,7 +23,6 @@ """ import functools -from contextlib import contextmanager import torch @@ -28,7 +30,6 @@ import flydsl.expr as fx from flydsl._mlir import ir from flydsl._mlir.dialects import memref as memref_ops -from flydsl._mlir.dialects import scf from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import buffer_ops, gpu, range_constexpr from flydsl.expr import rocdl as fly_rocdl @@ -146,130 +147,6 @@ def _write_expert_id_blocks(sorted_e_rsrc, local_eid, blk_start, n_blks): buffer_ops.buffer_store(local_eid, sorted_e_rsrc, blk_idx) -@flyc.jit -def _extend_local_idx_for_extra_experts(cumsum_mr, mask_rsrc, K4_BLOCK, E, has_mask): - """Thread-0: write local expert indices for experts >= K4_BLOCK to cumsum_mr.""" - if has_mask: - prev_local = _lds_load_raw(cumsum_mr, fx.Int32(K4_BLOCK - 1)) - prev_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(K4_BLOCK - 1), vec_width=1, dtype=T.i32) - prev_local = prev_local + prev_mask - for _e3 in range_constexpr(K4_BLOCK, E): - e3_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(_e3), vec_width=1, dtype=T.i32) - _lds_store_raw(cumsum_mr, prev_local, fx.Int32(_e3)) - prev_local = prev_local + e3_mask - else: - for _e3 in range_constexpr(K4_BLOCK, E): - _lds_store_raw(cumsum_mr, fx.Int32(_e3), fx.Int32(_e3)) - - -@flyc.jit -def _p23_scatter_mesh( - tid, - scatter_mr, - ws_rsrc, - weights_rsrc, - sorted_ids_rsrc, - sorted_w_rsrc, - mask_rsrc, - my_expert, - my_start, - my_end, - i32_mesh_stride, - c_topk, - K4_BLOCK, - has_mask, -): - """P23 Step 4: EP mask check, read uint8 mesh, DPP prefix sum, scatter tokens.""" - lane = tid % WARP_SIZE - wave = tid // WARP_SIZE - K4_NUM_WAVES = K4_BLOCK // WARP_SIZE - c_zero, c_one, c4 = fx.Int32(0), fx.Int32(1), fx.Int32(4) - c_ff, c_oob_idx = fx.Int32(0xFF), 
fx.Int32(0x7FFFFFFF) - p23_bid_enabled = c_one != c_zero - if has_mask: - p23_bid_mask = buffer_ops.buffer_load(mask_rsrc, my_expert, vec_width=1, dtype=T.i32) - p23_bid_enabled = p23_bid_mask != c_zero - i32_words_per_row = i32_mesh_stride >> fx.Int32(2) - n_mesh_iters = (my_start != my_end).select( - (i32_words_per_row + fx.Int32(K4_BLOCK - 1)) // fx.Int32(K4_BLOCK), c_zero - ) - mesh_row_i32_base = (my_expert * i32_mesh_stride) >> fx.Int32(2) - for _si, state in range(fx.Index(0), ArithValue(n_mesh_iters).index_cast(T.index), fx.Index(1), init=[my_start]): - position = state[0] - word_idx = fx.Int32(_si) * fx.Int32(K4_BLOCK) + tid - col_valid = p23_bid_enabled & (word_idx < i32_words_per_row) - safe_word_idx = col_valid.select(word_idx, c_zero) - word = buffer_ops.buffer_load(ws_rsrc, mesh_row_i32_base + safe_word_idx, vec_width=1, dtype=T.i32) - x0 = word & c_ff - x1 = (word >> fx.Int32(8)) & c_ff - x2 = (word >> fx.Int32(16)) & c_ff - x3 = (word >> fx.Int32(24)) & c_ff - base_col = word_idx * c4 - h0 = col_valid & (x0 != c_zero) - h1 = col_valid & (x1 != c_zero) - h2 = col_valid & (x2 != c_zero) - h3 = col_valid & (x3 != c_zero) - my_cnt = ( - h0.select(c_one, c_zero) + h1.select(c_one, c_zero) + h2.select(c_one, c_zero) + h3.select(c_one, c_zero) - ) - my_cnt, my_cnt_inclusive = _allwave_inclusive_prefix_sum( - my_cnt, lane, wave, scatter_mr, K4_NUM_WAVES, WARP_SIZE - ) - wave_offset = my_cnt_inclusive - my_cnt - batch_total = c_zero - for _w in range_constexpr(K4_NUM_WAVES): - batch_total = batch_total + _lds_load_raw(scatter_mr, fx.Int32(_w)) - gpu.barrier() - my_thread_count = ( - h0.select(c_one, c_zero) + h1.select(c_one, c_zero) + h2.select(c_one, c_zero) + h3.select(c_one, c_zero) - ) - my_exclusive = my_cnt - my_thread_count + wave_offset - scatter_base = position + my_exclusive - pid_0 = (h0.select(x0 - c_one, c_zero) << fx.Int32(24)) | base_col - pid_1 = (h1.select(x1 - c_one, c_zero) << fx.Int32(24)) | (base_col + c_one) - pid_2 = (h2.select(x2 - 
c_one, c_zero) << fx.Int32(24)) | (base_col + fx.Int32(2)) - pid_3 = (h3.select(x3 - c_one, c_zero) << fx.Int32(24)) | (base_col + fx.Int32(3)) - safe_slot_0 = h0.select(scatter_base, c_oob_idx) - off1 = scatter_base + h0.select(c_one, c_zero) - safe_slot_1 = h1.select(off1, c_oob_idx) - off2 = off1 + h1.select(c_one, c_zero) - safe_slot_2 = h2.select(off2, c_oob_idx) - off3 = off2 + h2.select(c_one, c_zero) - safe_slot_3 = h3.select(off3, c_oob_idx) - w_val_0 = buffer_ops.buffer_load( - weights_rsrc, h0.select(base_col * c_topk + h0.select(x0 - c_one, c_zero), c_zero), vec_width=1, dtype=T.i32 - ) - w_val_1 = buffer_ops.buffer_load( - weights_rsrc, - h1.select((base_col + c_one) * c_topk + h1.select(x1 - c_one, c_zero), c_zero), - vec_width=1, - dtype=T.i32, - ) - w_val_2 = buffer_ops.buffer_load( - weights_rsrc, - h2.select((base_col + fx.Int32(2)) * c_topk + h2.select(x2 - c_one, c_zero), c_zero), - vec_width=1, - dtype=T.i32, - ) - w_val_3 = buffer_ops.buffer_load( - weights_rsrc, - h3.select((base_col + fx.Int32(3)) * c_topk + h3.select(x3 - c_one, c_zero), c_zero), - vec_width=1, - dtype=T.i32, - ) - buffer_ops.buffer_store(pid_0, sorted_ids_rsrc, safe_slot_0) - buffer_ops.buffer_store(pid_1, sorted_ids_rsrc, safe_slot_1) - buffer_ops.buffer_store(pid_2, sorted_ids_rsrc, safe_slot_2) - buffer_ops.buffer_store(pid_3, sorted_ids_rsrc, safe_slot_3) - buffer_ops.buffer_store(w_val_0, sorted_w_rsrc, safe_slot_0) - buffer_ops.buffer_store(w_val_1, sorted_w_rsrc, safe_slot_1) - buffer_ops.buffer_store(w_val_2, sorted_w_rsrc, safe_slot_2) - buffer_ops.buffer_store(w_val_3, sorted_w_rsrc, safe_slot_3) - pos_next = position + batch_total - results = yield [pos_next] - return results - - @flyc.jit def _fill_sentinel_slots(sorted_ids_rsrc, sorted_w_rsrc, start, count, sentinel, block_size, tid, oob_idx): """Cooperative sentinel fill: threads fill [start, start+count) with sentinels.""" @@ -283,6 +160,28 @@ def _fill_sentinel_slots(sorted_ids_rsrc, sorted_w_rsrc, start, 
count, sentinel, buffer_ops.buffer_store(c_zero, sorted_w_rsrc, safe) +# --------------------------------------------------------------------------- +# LDS helpers for prefill kernels (module-level, used inside @flyc.kernel) +# --------------------------------------------------------------------------- +def _lds_load_raw(raw_mr, idx): + """Load i32 from LDS raw memref. idx can be i32 or index.""" + raw_idx = idx.ir_value() if hasattr(idx, "ir_value") else idx + if not isinstance(raw_idx.type, ir.IndexType): + raw_idx = ArithValue(idx).index_cast(T.index) + raw_idx = raw_idx.ir_value() if hasattr(raw_idx, "ir_value") else raw_idx + return fx.Int32(memref_ops.load(raw_mr, [raw_idx])) + + +def _lds_store_raw(raw_mr, val, idx): + """Store i32 to LDS raw memref. idx can be i32 or index.""" + v = val.ir_value() if hasattr(val, "ir_value") else val + raw_idx = idx.ir_value() if hasattr(idx, "ir_value") else idx + if not isinstance(raw_idx.type, ir.IndexType): + raw_idx = ArithValue(idx).index_cast(T.index) + raw_idx = raw_idx.ir_value() if hasattr(raw_idx, "ir_value") else raw_idx + memref_ops.store(v, raw_mr, [raw_idx]) + + # --------------------------------------------------------------------------- # AOT-compiled dispatch caches — keyed by constexpr values. 
# After the first JIT call (which compiles the kernel), flyc.compile() @@ -294,18 +193,6 @@ def _fill_sentinel_slots(sorted_ids_rsrc, sorted_w_rsrc, start, count, sentinel, _dummy_mask_cache = {} # device -> torch.Tensor(1, dtype=i32, value=1) -@contextmanager -def _if_then(if_op): - """Context manager for scf.IfOp then-region (from moe_gemm_2stage.py).""" - with ir.InsertionPoint(if_op.then_block): - try: - yield if_op.then_block - finally: - blk = if_op.then_block - if (not blk.operations) or not isinstance(blk.operations[-1], scf.YieldOp): - scf.YieldOp([]) - - # --------------------------------------------------------------------------- # FlyDSL GPU kernel — decode path (single kernel, SubTokenOneShot) # --------------------------------------------------------------------------- @@ -334,7 +221,6 @@ def compile_moe_sorting_decode( arch = get_hip_arch() E = num_experts # CDNA (warp64): 512 threads = 8 waves, affordable cross-wave reduction. - # RDNA (warp32): cap at 256 to avoid 16-wave overhead. 
max_decode_block = 512 if WARP_SIZE == 64 else 256 DECODE_BLOCK = 256 if E <= 256 else min(512, max_decode_block) NUM_WAVES = DECODE_BLOCK // WARP_SIZE @@ -355,9 +241,7 @@ def compile_moe_sorting_decode( sub_unroll = 8 cumsum_bufs = 2 if r < (cumsum_bufs + sub_unroll): - raise ValueError( - f"LDS too small for E={E}: need at least " f"{(cumsum_bufs + sub_unroll) * smem_cols * 4} bytes" - ) + raise ValueError(f"LDS too small for E={E}: need at least {(cumsum_bufs + sub_unroll) * smem_cols * 4} bytes") r_for_sub = ((r - cumsum_bufs) // sub_unroll) * sub_unroll r_token_min = ((max_tokens + sub_unroll - 1) // sub_unroll) * sub_unroll r_for_sub = min(r_for_sub, r_token_min) @@ -382,27 +266,6 @@ def compile_moe_sorting_decode( scratch_offset = allocator._align(allocator.ptr, 16) allocator.ptr = scratch_offset + NUM_WAVES * 4 - # Helpers for raw memref LDS access (used inside scf.for instead of SmemPtr) - def _to_index(v): - """Convert i32 or index DSL value to raw MLIR index value.""" - raw = v.ir_value() if hasattr(v, "ir_value") else v - if isinstance(raw.type, ir.IndexType): - return raw - return ArithValue(v).index_cast(T.index) - - def _lds_load(raw_mr, idx): - """Load i32 from LDS raw memref. idx can be i32 or index.""" - return fx.Int32(memref_ops.load(raw_mr, [_to_index(idx)])) - - def _lds_store(raw_mr, val, idx): - """Store i32 to LDS raw memref. 
idx can be i32 or index.""" - v = val.ir_value() if hasattr(val, "ir_value") else val - memref_ops.store(v, raw_mr, [_to_index(idx)]) - - def _unwrap(v): - """Unwrap DSL value to raw MLIR ir.Value for scf.for init.""" - return v.ir_value() if hasattr(v, "ir_value") else v - @flyc.kernel(known_block_size=[DECODE_BLOCK, 1, 1]) def moe_sorting_decode_kernel( topk_ids_tensor: fx.Tensor, @@ -468,7 +331,7 @@ def moe_sorting_decode_kernel( safe_idx = is_valid.select(idx, c_zero_i32) safe_idx_ix = ArithValue(safe_idx).index_cast(T.index) # Always store; out-of-bounds threads harmlessly write to index 0 - _lds_store(mesh_mr, c_zero_i32, safe_idx_ix) + _lds_store_raw(mesh_mr, c_zero_i32, safe_idx_ix) gpu.barrier() # Fill mesh: for each (token, topk_slot), write topk_slot+1 to mesh[token, expert_id] @@ -492,7 +355,7 @@ def moe_sorting_decode_kernel( safe_mesh_addr = is_valid.select(mesh_addr, last_mesh_idx) safe_mesh_ix = ArithValue(safe_mesh_addr).index_cast(T.index) val = is_valid.select(topk_slot + c_one_i32, c_zero_i32) - _lds_store(mesh_mr, val, safe_mesh_ix) + _lds_store_raw(mesh_mr, val, safe_mesh_ix) gpu.barrier() # ===================== PHASE 2: Count + Prefix Sum ===================== @@ -505,7 +368,7 @@ def moe_sorting_decode_kernel( # Initialize cumsum[0] = 0. All threads write 0 so there's no # read-modify-write race across waves. 
- _lds_store(cumsum_mr, c_zero_i32, c_zero_i32) + _lds_store_raw(cumsum_mr, c_zero_i32, c_zero_i32) gpu.barrier() for i_e in range_constexpr(0, E, DECODE_BLOCK // 8): @@ -522,7 +385,7 @@ def moe_sorting_decode_kernel( safe_eid = combined_valid.select(eid_local, c_zero_i32) mesh_rd_addr = safe_sub * c_smem_cols + safe_eid mesh_rd_ix = ArithValue(mesh_rd_addr).index_cast(T.index) - mesh_val = _lds_load(mesh_mr, mesh_rd_ix) + mesh_val = _lds_load_raw(mesh_mr, mesh_rd_ix) has_token = combined_valid.select( (mesh_val != c_zero_i32).select(c_one_i32, c_zero_i32), @@ -544,7 +407,7 @@ def moe_sorting_decode_kernel( cs_idx = write_valid.select(eid_local + c_one_i32, c_zero_i32) cs_ix = ArithValue(cs_idx).index_cast(T.index) cs_val = write_valid.select(cnt, c_zero_i32) - _lds_store(cumsum_mr, cs_val, cs_ix) + _lds_store_raw(cumsum_mr, cs_val, cs_ix) gpu.barrier() # Phase 2b: Prefix sum over expert counts. @@ -555,11 +418,11 @@ def moe_sorting_decode_kernel( # Safe index: valid → cumsum[eid+1], invalid → cumsum[0] (write 0, harmless) safe_cvt_idx = cvt_valid.select(cvt_eid + c_one_i32, c_zero_i32) cvt_ix = ArithValue(safe_cvt_idx).index_cast(T.index) - raw_cnt_cvt = _lds_load(cumsum_mr, cvt_ix) + raw_cnt_cvt = _lds_load_raw(cumsum_mr, cvt_ix) blocks_cvt = (raw_cnt_cvt + c_unit - c_one_i32) // c_unit padded_cvt = (raw_cnt_cvt == c_zero_i32).select(c_zero_i32, blocks_cvt * c_unit) # Valid threads write padded value; invalid threads write 0 to cumsum[0] - _lds_store(cumsum_mr, cvt_valid.select(padded_cvt, c_zero_i32), cvt_ix) + _lds_store_raw(cumsum_mr, cvt_valid.select(padded_cvt, c_zero_i32), cvt_ix) gpu.barrier() if has_mask: @@ -573,7 +436,9 @@ def moe_sorting_decode_kernel( ep_m = buffer_ops.buffer_load(mask_rsrc, ep_safe_eid, vec_width=1, dtype=T.i32) should_zero = ep_valid & (ep_m == c_zero_i32) ep_cs_ix = ArithValue(ep_valid.select(ep_eid + c_one_i32, c_zero_i32)).index_cast(T.index) - _lds_store(cumsum_mr, should_zero.select(c_zero_i32, _lds_load(cumsum_mr, ep_cs_ix)), 
ep_cs_ix) + _lds_store_raw( + cumsum_mr, should_zero.select(c_zero_i32, _lds_load_raw(cumsum_mr, ep_cs_ix)), ep_cs_ix + ) gpu.barrier() # Step 2: All-wave parallel prefix sum (cumsum → cumdup). @@ -584,16 +449,16 @@ def moe_sorting_decode_kernel( ps_eid = fx.Int32(_ps_chunk) + tid ps_valid = ps_eid < c_E ps_safe_ix = ArithValue(ps_valid.select(ps_eid + c_one_i32, c_zero_i32)).index_cast(T.index) - ps_val = ps_valid.select(_lds_load(cumsum_mr, ps_safe_ix), c_zero_i32) - _lds_store(cumdup_mr, ps_val, ps_safe_ix) - _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) + ps_val = ps_valid.select(_lds_load_raw(cumsum_mr, ps_safe_ix), c_zero_i32) + _lds_store_raw(cumdup_mr, ps_val, ps_safe_ix) + _lds_store_raw(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() # DPP prefix sum — all NUM_WAVES waves active ps_tid_valid = tid < c_E - val = ps_tid_valid.select(_lds_load(cumdup_mr, tid + c_one_i32), c_zero_i32) + val = ps_tid_valid.select(_lds_load_raw(cumdup_mr, tid + c_one_i32), c_zero_i32) _, inclusive_ps = _allwave_inclusive_prefix_sum(val, lane, wave, scratch_mr, NUM_WAVES, WARP_SIZE) - _lds_store( + _lds_store_raw( cumdup_mr, ps_tid_valid.select(inclusive_ps, c_zero_i32), ArithValue(ps_tid_valid.select(tid + c_one_i32, c_zero_i32)).index_cast(T.index), @@ -603,16 +468,16 @@ def moe_sorting_decode_kernel( # For E > DECODE_BLOCK: thread 0 serially extends if E > DECODE_BLOCK: if is_t0: - _extend_prefix_sum_serial(cumdup_mr, DECODE_BLOCK, E, _lds_load, _lds_store) + _extend_prefix_sum_serial(cumdup_mr, DECODE_BLOCK, E, _lds_load_raw, _lds_store_raw) gpu.barrier() # cumdup[0] = 0 - _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) + _lds_store_raw(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() # Write num_valid_ids from cumdup[E] cs_E_ix_ps = ArithValue(c_E).index_cast(T.index) - total_padded = _lds_load(cumdup_mr, cs_E_ix_ps) + total_padded = _lds_load_raw(cumdup_mr, cs_E_ix_ps) buffer_ops.buffer_store(total_padded, nvalid_rsrc, c_zero_i32) buffer_ops.buffer_store(tokens, 
nvalid_rsrc, c_one_i32) gpu.barrier() @@ -623,8 +488,8 @@ def moe_sorting_decode_kernel( cp_valid = cp_idx <= c_E safe_cp_idx = cp_valid.select(cp_idx, c_zero_i32) cp_ix = ArithValue(safe_cp_idx).index_cast(T.index) - cp_val = _lds_load(cumdup_mr, cp_ix) - _lds_store(cumsum_mr, cp_val, cp_ix) + cp_val = _lds_load_raw(cumdup_mr, cp_ix) + _lds_store_raw(cumsum_mr, cp_val, cp_ix) gpu.barrier() if has_mask: @@ -637,15 +502,15 @@ def moe_sorting_decode_kernel( ml_mask = buffer_ops.buffer_load(mask_rsrc, safe_ml_eid, vec_width=1, dtype=T.i32) ml_val = ml_valid.select(ml_mask, c_zero_i32) ml_ix = ArithValue(ml_valid.select(ml_eid + c_one_i32, c_zero_i32)).index_cast(T.index) - _lds_store(cumdup_mr, ml_val, ml_ix) - _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) + _lds_store_raw(cumdup_mr, ml_val, ml_ix) + _lds_store_raw(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() # All-wave DPP prefix sum over mask values in cumdup m_tid_valid = tid < c_E - mval = m_tid_valid.select(_lds_load(cumdup_mr, tid + c_one_i32), c_zero_i32) + mval = m_tid_valid.select(_lds_load_raw(cumdup_mr, tid + c_one_i32), c_zero_i32) _, inclusive_m = _allwave_inclusive_prefix_sum(mval, lane, wave, scratch_mr, NUM_WAVES, WARP_SIZE) - _lds_store( + _lds_store_raw( cumdup_mr, m_tid_valid.select(inclusive_m, c_zero_i32), ArithValue(m_tid_valid.select(tid + c_one_i32, c_zero_i32)).index_cast(T.index), @@ -654,10 +519,10 @@ def moe_sorting_decode_kernel( if E > DECODE_BLOCK: if is_t0: - _extend_prefix_sum_serial(cumdup_mr, DECODE_BLOCK, E, _lds_load, _lds_store) + _extend_prefix_sum_serial(cumdup_mr, DECODE_BLOCK, E, _lds_load_raw, _lds_store_raw) gpu.barrier() - _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) + _lds_store_raw(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() else: # No mask: cumdup[eid] = eid (identity mapping) @@ -666,7 +531,7 @@ def moe_sorting_decode_kernel( ml_valid = ml_eid < c_E safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32) ml_ix = ArithValue(safe_ml_eid).index_cast(T.index) - 
_lds_store(cumdup_mr, ml_valid.select(safe_ml_eid, c_zero_i32), ml_ix) + _lds_store_raw(cumdup_mr, ml_valid.select(safe_ml_eid, c_zero_i32), ml_ix) gpu.barrier() # Write sorted_expert_ids — predicated stores to buffer (safe: buffer_store ignores OOB) @@ -678,13 +543,13 @@ def moe_sorting_decode_kernel( cs_start_ix = ArithValue(safe_eid_wr).index_cast(T.index) cs_end_ix = ArithValue(safe_eid_wr + c_one_i32).index_cast(T.index) - e_start = _lds_load(cumsum_mr, cs_start_ix) - e_end = eid_wr_valid.select(_lds_load(cumsum_mr, cs_end_ix), e_start) - local_eid = _lds_load(cumdup_mr, cs_start_ix) + e_start = _lds_load_raw(cumsum_mr, cs_start_ix) + e_end = eid_wr_valid.select(_lds_load_raw(cumsum_mr, cs_end_ix), e_start) + local_eid = _lds_load_raw(cumdup_mr, cs_start_ix) # Store cumdup: reuse cumdup for scatter phase position tracking. # Write e_start to cumdup[eid] (overwriting mask cumsum, no longer needed). - _lds_store(cumdup_mr, e_start, cs_start_ix) + _lds_store_raw(cumdup_mr, e_start, cs_start_ix) blk_start = e_start // c_unit blk_end = e_end // c_unit @@ -695,12 +560,12 @@ def moe_sorting_decode_kernel( # Store cumdup[E] = cumsum[E]. # All threads write cumE to cumdup[E] (all write the same value, no race). cs_E_ix = ArithValue(c_E).index_cast(T.index) - cumE = _lds_load(cumsum_mr, cs_E_ix) - _lds_store(cumdup_mr, cumE, cs_E_ix) + cumE = _lds_load_raw(cumsum_mr, cs_E_ix) + _lds_store_raw(cumdup_mr, cumE, cs_E_ix) gpu.barrier() # ====================== PRE-FILL: Sentinel fill (cooperative) =========== - total_padded_pre = _lds_load(cumdup_mr, ArithValue(c_E).index_cast(T.index)) + total_padded_pre = _lds_load_raw(cumdup_mr, ArithValue(c_E).index_cast(T.index)) _fill_sentinel_slots( sorted_ids_rsrc, sorted_w_rsrc, @@ -714,9 +579,6 @@ def moe_sorting_decode_kernel( gpu.barrier() # ====================== PHASE 3: Scatter ============================== - # CK uses wave_cumsum<8> (DPP prefix sum) + ds_bpermute (lane broadcast). - # FlyDSL has neither. 
Instead, each lane reads all 8 mesh values in the - # batch from LDS to compute its own prefix offset. No shuffle needed. for i_e2 in range_constexpr(0, E, DECODE_BLOCK // 8): eid_sc = fx.Int32(i_e2) + lane_group_id eid_sc_valid = eid_sc < c_E @@ -733,7 +595,7 @@ def moe_sorting_decode_kernel( sc_expert_enabled = eid_sc_valid & (sc_mask_val != c_zero_i32) cs_sc_ix = ArithValue(safe_eid_sc).index_cast(T.index) - position = _lds_load(cumsum_mr, cs_sc_ix) + position = _lds_load_raw(cumsum_mr, cs_sc_ix) for i_sub2 in range_constexpr(0, sub_tokens, 8): # This lane handles sub_token (i_sub2 + lane_group_os). @@ -742,14 +604,14 @@ def moe_sorting_decode_kernel( safe_my_sub = my_sub_valid.select(my_sub, c_zero_i32) my_mesh_addr = safe_my_sub * c_smem_cols + safe_eid_sc my_mesh_ix = ArithValue(my_mesh_addr).index_cast(T.index) - my_x = _lds_load(mesh_mr, my_mesh_ix) + my_x = _lds_load_raw(mesh_mr, my_mesh_ix) my_has_token = my_sub_valid & (my_x != c_zero_i32) local_cnt = my_has_token.select(c_one_i32, c_zero_i32) # 8-lane group prefix sum (NOT full-wave — uses lane_group_os, # only shifts 1,2,4, no cross-row bpermute needed). 
- cnt_raw = _unwrap(local_cnt) - zero_raw = _unwrap(c_zero_i32) + cnt_raw = _unwrap_val(local_cnt) + zero_raw = _unwrap_val(c_zero_i32) # row_shr:1 remote = fly_rocdl.update_dpp( @@ -759,7 +621,7 @@ def moe_sorting_decode_kernel( local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt) # row_shr:2 - cnt_raw = _unwrap(local_cnt) + cnt_raw = _unwrap_val(local_cnt) remote = fly_rocdl.update_dpp( T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_2, DPP_ROW_MASK, DPP_BANK_MASK, True ) @@ -767,7 +629,7 @@ def moe_sorting_decode_kernel( local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt) # row_shr:4 - cnt_raw = _unwrap(local_cnt) + cnt_raw = _unwrap_val(local_cnt) remote = fly_rocdl.update_dpp( T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_4, DPP_ROW_MASK, DPP_BANK_MASK, True ) @@ -797,7 +659,7 @@ def moe_sorting_decode_kernel( # Write back updated position (for padding phase). # Invalid lane groups write position (=0+0=0) to cumsum[0] which is harmless. - _lds_store(cumsum_mr, position, cs_sc_ix) + _lds_store_raw(cumsum_mr, position, cs_sc_ix) gpu.barrier() # Padding already filled by PRE-FILL phase above (before scatter). @@ -843,33 +705,6 @@ def launch_moe_sorting_decode( return launch_moe_sorting_decode -# --------------------------------------------------------------------------- -# LDS helpers for prefill kernels (module-level, used inside @flyc.kernel) -# --------------------------------------------------------------------------- -def _lds_load_raw(raw_mr, idx): - """Load i32 from LDS raw memref. idx can be i32 or index.""" - raw_idx = idx.ir_value() if hasattr(idx, "ir_value") else idx - if not isinstance(raw_idx.type, ir.IndexType): - raw_idx = ArithValue(idx).index_cast(T.index) - raw_idx = raw_idx.ir_value() if hasattr(raw_idx, "ir_value") else raw_idx - return fx.Int32(memref_ops.load(raw_mr, [raw_idx])) - - -def _lds_store_raw(raw_mr, val, idx): - """Store i32 to LDS raw memref. 
idx can be i32 or index.""" - v = val.ir_value() if hasattr(val, "ir_value") else val - raw_idx = idx.ir_value() if hasattr(idx, "ir_value") else idx - if not isinstance(raw_idx.type, ir.IndexType): - raw_idx = ArithValue(idx).index_cast(T.index) - raw_idx = raw_idx.ir_value() if hasattr(raw_idx, "ir_value") else raw_idx - memref_ops.store(v, raw_mr, [raw_idx]) - - -def _unwrap_raw(v): - """Unwrap DSL value to raw MLIR ir.Value.""" - return v.ir_value() if hasattr(v, "ir_value") else v - - # --------------------------------------------------------------------------- # FlyDSL GPU kernels — prefill path (4 kernels, large T via HBM workspace) # --------------------------------------------------------------------------- @@ -907,6 +742,139 @@ def compile_moe_sorting_prefill( arch = get_hip_arch() E = num_experts + @flyc.jit + def _extend_local_idx_for_extra_experts(cumsum_mr, mask_rsrc, K4_BLOCK, E, has_mask): + """Thread-0: write local expert indices for experts >= K4_BLOCK to cumsum_mr.""" + if has_mask: + prev_local = _lds_load_raw(cumsum_mr, fx.Int32(K4_BLOCK - 1)) + prev_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(K4_BLOCK - 1), vec_width=1, dtype=T.i32) + prev_local = prev_local + prev_mask + for _e3 in range_constexpr(K4_BLOCK, E): + e3_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(_e3), vec_width=1, dtype=T.i32) + _lds_store_raw(cumsum_mr, prev_local, fx.Int32(_e3)) + prev_local = prev_local + e3_mask + else: + for _e3 in range_constexpr(K4_BLOCK, E): + _lds_store_raw(cumsum_mr, fx.Int32(_e3), fx.Int32(_e3)) + + @flyc.jit + def _p23_scatter_mesh( + tid, + scatter_mr, + ws_rsrc, + weights_rsrc, + sorted_ids_rsrc, + sorted_w_rsrc, + mask_rsrc, + my_expert, + my_start, + my_end, + i32_mesh_stride, + c_topk, + K4_BLOCK, + has_mask, + ): + """P23 Step 4: EP mask check, read uint8 mesh, DPP prefix sum, scatter tokens.""" + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + K4_NUM_WAVES = K4_BLOCK // WARP_SIZE + c_zero, c_one, c4 = fx.Int32(0), fx.Int32(1), 
fx.Int32(4) + c_ff, c_oob_idx = fx.Int32(0xFF), fx.Int32(0x7FFFFFFF) + p23_bid_enabled = c_one != c_zero + if has_mask: + p23_bid_mask = buffer_ops.buffer_load(mask_rsrc, my_expert, vec_width=1, dtype=T.i32) + p23_bid_enabled = p23_bid_mask != c_zero + i32_words_per_row = i32_mesh_stride >> fx.Int32(2) + n_mesh_iters = (my_start != my_end).select( + (i32_words_per_row + fx.Int32(K4_BLOCK - 1)) // fx.Int32(K4_BLOCK), c_zero + ) + mesh_row_i32_base = (my_expert * i32_mesh_stride) >> fx.Int32(2) + for _si, state in range( + fx.Index(0), ArithValue(n_mesh_iters).index_cast(T.index), fx.Index(1), init=[my_start] + ): + position = state[0] + word_idx = fx.Int32(_si) * fx.Int32(K4_BLOCK) + tid + col_valid = p23_bid_enabled & (word_idx < i32_words_per_row) + safe_word_idx = col_valid.select(word_idx, c_zero) + word = buffer_ops.buffer_load(ws_rsrc, mesh_row_i32_base + safe_word_idx, vec_width=1, dtype=T.i32) + x0 = word & c_ff + x1 = (word >> fx.Int32(8)) & c_ff + x2 = (word >> fx.Int32(16)) & c_ff + x3 = (word >> fx.Int32(24)) & c_ff + base_col = word_idx * c4 + h0 = col_valid & (x0 != c_zero) + h1 = col_valid & (x1 != c_zero) + h2 = col_valid & (x2 != c_zero) + h3 = col_valid & (x3 != c_zero) + my_cnt = ( + h0.select(c_one, c_zero) + + h1.select(c_one, c_zero) + + h2.select(c_one, c_zero) + + h3.select(c_one, c_zero) + ) + my_cnt, my_cnt_inclusive = _allwave_inclusive_prefix_sum( + my_cnt, lane, wave, scatter_mr, K4_NUM_WAVES, WARP_SIZE + ) + wave_offset = my_cnt_inclusive - my_cnt + batch_total = c_zero + for _w in range_constexpr(K4_NUM_WAVES): + batch_total = batch_total + _lds_load_raw(scatter_mr, fx.Int32(_w)) + gpu.barrier() + my_thread_count = ( + h0.select(c_one, c_zero) + + h1.select(c_one, c_zero) + + h2.select(c_one, c_zero) + + h3.select(c_one, c_zero) + ) + my_exclusive = my_cnt - my_thread_count + wave_offset + scatter_base = position + my_exclusive + pid_0 = (h0.select(x0 - c_one, c_zero) << fx.Int32(24)) | base_col + pid_1 = (h1.select(x1 - c_one, c_zero) 
<< fx.Int32(24)) | (base_col + c_one) + pid_2 = (h2.select(x2 - c_one, c_zero) << fx.Int32(24)) | (base_col + fx.Int32(2)) + pid_3 = (h3.select(x3 - c_one, c_zero) << fx.Int32(24)) | (base_col + fx.Int32(3)) + safe_slot_0 = h0.select(scatter_base, c_oob_idx) + off1 = scatter_base + h0.select(c_one, c_zero) + safe_slot_1 = h1.select(off1, c_oob_idx) + off2 = off1 + h1.select(c_one, c_zero) + safe_slot_2 = h2.select(off2, c_oob_idx) + off3 = off2 + h2.select(c_one, c_zero) + safe_slot_3 = h3.select(off3, c_oob_idx) + w_val_0 = buffer_ops.buffer_load( + weights_rsrc, + h0.select(base_col * c_topk + h0.select(x0 - c_one, c_zero), c_zero), + vec_width=1, + dtype=T.i32, + ) + w_val_1 = buffer_ops.buffer_load( + weights_rsrc, + h1.select((base_col + c_one) * c_topk + h1.select(x1 - c_one, c_zero), c_zero), + vec_width=1, + dtype=T.i32, + ) + w_val_2 = buffer_ops.buffer_load( + weights_rsrc, + h2.select((base_col + fx.Int32(2)) * c_topk + h2.select(x2 - c_one, c_zero), c_zero), + vec_width=1, + dtype=T.i32, + ) + w_val_3 = buffer_ops.buffer_load( + weights_rsrc, + h3.select((base_col + fx.Int32(3)) * c_topk + h3.select(x3 - c_one, c_zero), c_zero), + vec_width=1, + dtype=T.i32, + ) + buffer_ops.buffer_store(pid_0, sorted_ids_rsrc, safe_slot_0) + buffer_ops.buffer_store(pid_1, sorted_ids_rsrc, safe_slot_1) + buffer_ops.buffer_store(pid_2, sorted_ids_rsrc, safe_slot_2) + buffer_ops.buffer_store(pid_3, sorted_ids_rsrc, safe_slot_3) + buffer_ops.buffer_store(w_val_0, sorted_w_rsrc, safe_slot_0) + buffer_ops.buffer_store(w_val_1, sorted_w_rsrc, safe_slot_1) + buffer_ops.buffer_store(w_val_2, sorted_w_rsrc, safe_slot_2) + buffer_ops.buffer_store(w_val_3, sorted_w_rsrc, safe_slot_3) + pos_next = position + batch_total + results = yield [pos_next] + return results + # --- K1: ClearWorkspace kernel ------------------------------------------- # CK uses grid=262144, block=1024 (1 store per thread, no loop). # Match that: block=1024, grid=ceil(ws_total/1024). 
@@ -1393,16 +1361,14 @@ def p23_kernel( my_start = _lds_load_raw(cumsum_mr, my_expert) my_end = _lds_load_raw(cumsum_mr, my_expert + c_one) + # Hoist before if/else: AST rewriter extracts branches into + # separate functions, so variables must be defined in outer scope. local_idx_p23 = tid if has_mask: - # EP: Compute mask cumsum for local expert index (register-only DPP scan). - # Uses my_mask_val (loaded before the chunked loop) to avoid overwrite. _, p23_mask_inclusive = _allwave_inclusive_prefix_sum( my_mask_val, lane, wave, scatter_mr, K4_NUM_WAVES, WARP_SIZE ) local_idx_p23 = p23_mask_inclusive - my_mask_val - else: - local_idx_p23 = tid # Block 0, thread 0 writes num_valid_ids if (bid == c_zero) & (tid == c_zero): From 36fc4231aac72295960cb65fa42c0411da714bfd Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Thu, 21 May 2026 17:39:04 +0100 Subject: [PATCH 11/14] fix format --- kernels/moe_sorting_kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index ae78bde1..38ed2bbb 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -14,7 +14,7 @@ - DECODE_MAX_T < T <= 2048: 2 kernels (fused P0v2 + P23) via HBM workspace. - Prefill/4 kernels (T > 2048): 4 kernels via HBM workspace (ClearWS → P0 scatter → P1 count → P23). - DECODE_MAX_T = 16 for large models(Qwen3-MoE, MiMo-V2-Flash .. DeepSeekR1 V4) on current LDS sizes. 
- - DECODE_MAX_T is larger fro small model like 128 for Mixtral-8x7B, 32 for GPT-OSS and MiniMax-M2 + - DECODE_MAX_T is larger fro small model like 128 for Mixtral-8x7B, 32 for GPT-OSS and MiniMax-M2 Packed token ID format: (topk_position << 24) | token_id - Upper 8 bits: topk slot (0..topk-1) From 2668d72336cc6d04f8d78cdb7b547c4a0e3a82af Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Thu, 21 May 2026 21:27:53 +0100 Subject: [PATCH 12/14] improve code style --- kernels/moe_sorting_kernel.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index 38ed2bbb..ce274b75 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -812,6 +812,7 @@ def _p23_scatter_mesh( + h2.select(c_one, c_zero) + h3.select(c_one, c_zero) ) + my_pre_scan = my_cnt my_cnt, my_cnt_inclusive = _allwave_inclusive_prefix_sum( my_cnt, lane, wave, scatter_mr, K4_NUM_WAVES, WARP_SIZE ) @@ -820,13 +821,7 @@ def _p23_scatter_mesh( for _w in range_constexpr(K4_NUM_WAVES): batch_total = batch_total + _lds_load_raw(scatter_mr, fx.Int32(_w)) gpu.barrier() - my_thread_count = ( - h0.select(c_one, c_zero) - + h1.select(c_one, c_zero) - + h2.select(c_one, c_zero) - + h3.select(c_one, c_zero) - ) - my_exclusive = my_cnt - my_thread_count + wave_offset + my_exclusive = my_cnt - my_pre_scan + wave_offset scatter_base = position + my_exclusive pid_0 = (h0.select(x0 - c_one, c_zero) << fx.Int32(24)) | base_col pid_1 = (h1.select(x1 - c_one, c_zero) << fx.Int32(24)) | (base_col + c_one) @@ -1737,7 +1732,7 @@ def moe_sorting_flydsl( _launch_cached(_prefill_cf_cache, base_key + ("p0v2_p23",), launch_p0v2_p23, p0v2_args, stream) else: k1_grid = (ws_total + 1023) // 1024 - k2_grid = min(num_cu * target_occupancy, (M * topk + 255) // 256) + k2_grid = num_cu * target_occupancy k2_total = M * topk k2_stride = k2_grid * 256 k2_niters = (k2_total + k2_stride - 1) // k2_stride From 
1e661f9b51ceca1d928dabed1cb24b049d2af913 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Thu, 21 May 2026 21:46:53 +0100 Subject: [PATCH 13/14] make unified interface --- kernels/moe_sorting_kernel.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index ce274b75..0d46a041 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -197,7 +197,7 @@ def _lds_store_raw(raw_mr, val, idx): # FlyDSL GPU kernel — decode path (single kernel, SubTokenOneShot) # --------------------------------------------------------------------------- @functools.lru_cache(maxsize=256) -def compile_moe_sorting_decode( +def _compile_moe_sorting_decode( *, num_experts: int, topk: int, @@ -709,7 +709,7 @@ def launch_moe_sorting_decode( # FlyDSL GPU kernels — prefill path (4 kernels, large T via HBM workspace) # --------------------------------------------------------------------------- @functools.lru_cache(maxsize=256) -def compile_moe_sorting_prefill( +def _compile_moe_sorting_prefill( *, num_experts: int, topk: int, @@ -1564,7 +1564,7 @@ def _compute_sub_tokens(num_experts, arch=None): """Compute the LDS-capacity threshold (sub_tokens) for decode vs prefill decision. Returns the max T that fits in LDS for the decode (single-kernel) path. - Same formula as compile_moe_sorting_decode. + Same formula as _compile_moe_sorting_decode. """ if arch is None: arch = get_hip_arch() @@ -1600,6 +1600,21 @@ def moe_sorting_get_workspace_size(M, num_experts, topk, unit_size=UNIT_SIZE): return ws_mesh_i32 + (num_experts + 1) +def compile_moe_sorting(*, num_experts, topk, max_tokens=128, unit_size=UNIT_SIZE, has_mask=False): + """Compile MoE sorting kernels for all paths (decode + prefill). + + Returns (launch_decode, launch_p0v2_p23, launch_4k_fused) covering all T ranges. + Decode compilation depends on max_tokens (LDS sizing); prefill is independent. 
+ """ + launch_decode = _compile_moe_sorting_decode( + num_experts=num_experts, topk=topk, max_tokens=max_tokens, unit_size=unit_size, has_mask=has_mask + ) + _, _, _, _, _, launch_p0v2_p23, launch_4k_fused = _compile_moe_sorting_prefill( + num_experts=num_experts, topk=topk, unit_size=unit_size, has_mask=has_mask + ) + return launch_decode, launch_p0v2_p23, launch_4k_fused + + def _launch_cached(cache, key, launch_fn, args, stream): """AOT-compiled dispatch: first call JITs, subsequent calls use cached CompiledFunction.""" cf = cache.get(key) @@ -1666,16 +1681,17 @@ def moe_sorting_flydsl( DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) + target_occupancy = 2 + num_cu = torch.cuda.get_device_properties(device).multi_processor_count + if M <= min(sub_tokens, DECODE_MAX_T): max_tokens = max(M, 8) max_tokens = ((max_tokens + 7) // 8) * 8 - target_occupancy = 2 - num_cu = torch.cuda.get_device_properties(topk_ids.device).multi_processor_count n_zero_blocks = min((moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy) n_grid_blocks = 1 + n_zero_blocks - launch_fn = compile_moe_sorting_decode( + launch_decode, _, _ = compile_moe_sorting( num_experts=num_experts, topk=topk, max_tokens=max_tokens, unit_size=unit_size, has_mask=has_mask ) decode_args = ( @@ -1692,9 +1708,8 @@ def moe_sorting_flydsl( n_grid_blocks, ) cache_key = (num_experts, topk, max_tokens, unit_size, has_mask, device.index) - _launch_cached(_decode_cf_cache, cache_key, launch_fn, decode_args, torch.cuda.current_stream()) + _launch_cached(_decode_cf_cache, cache_key, launch_decode, decode_args, torch.cuda.current_stream()) else: - # Prefill path: multiple kernels via HBM workspace mesh_stride = ((M + unit_size - 1) // unit_size) * unit_size ws_mesh_bytes = num_experts * mesh_stride ws_mesh_i32 = (ws_mesh_bytes + 3) // 4 @@ -1702,12 +1717,10 @@ def moe_sorting_flydsl( if workspace is None: workspace = torch.empty(ws_total, dtype=torch.int32, 
device=device) - _, _, _, _, _, launch_p0v2_p23, launch_4k_fused = compile_moe_sorting_prefill( + _, launch_p0v2_p23, launch_4k_fused = compile_moe_sorting( num_experts=num_experts, topk=topk, unit_size=unit_size, has_mask=has_mask ) stream = torch.cuda.current_stream() - num_cu = torch.cuda.get_device_properties(device).multi_processor_count - target_occupancy = 2 n_zero_blocks = min((moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy) k4_grid = num_experts + n_zero_blocks base_key = (num_experts, topk, unit_size, has_mask, device.index) From fc13aa849977eaac5eb9bb79c78919a4695201c1 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Thu, 21 May 2026 22:39:25 +0100 Subject: [PATCH 14/14] renaming decode prefill to avoid confusion --- kernels/moe_sorting_kernel.py | 132 +++++++++++++++--------------- tests/kernels/test_moe_sorting.py | 80 +++++++++--------- 2 files changed, 105 insertions(+), 107 deletions(-) diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index 0d46a041..6916fef6 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -9,12 +9,10 @@ Algorithm: counting sort in LDS (histogram → prefix-sum → scatter). -Three paths (selected by T and DECODE_MAX_T = min(sub_tokens, BLOCK_SIZE // max(topk, E//8))): - - T <= DECODE_MAX_T: single kernel, all phases in LDS. - - DECODE_MAX_T < T <= 2048: 2 kernels (fused P0v2 + P23) via HBM workspace. - - Prefill/4 kernels (T > 2048): 4 kernels via HBM workspace (ClearWS → P0 scatter → P1 count → P23). - - DECODE_MAX_T = 16 for large models(Qwen3-MoE, MiMo-V2-Flash .. DeepSeekR1 V4) on current LDS sizes. - - DECODE_MAX_T is larger fro small model like 128 for Mixtral-8x7B, 32 for GPT-OSS and MiniMax-M2 +Three paths (selected by T vs ONESHOT_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, E//8)))): + - Oneshot (T <= ONESHOT_MAX_T): single kernel, all phases in LDS. 
+ - Multiphase/2k (ONESHOT_MAX_T < T <= 2048): 2 kernels (fused P0v2 + P23) via HBM workspace. + - Multiphase/4k (T > 2048): 4 kernels (ClearWS → P0 scatter → P1 count → P23) via HBM workspace. Packed token ID format: (topk_position << 24) | token_id - Upper 8 bits: topk slot (0..topk-1) @@ -44,7 +42,7 @@ UNIT_SIZE = 32 # GEMM tile-M, aka block_size in CK WARP_SIZE = get_warp_size() -# DPP constants for prefix sum (used by decode and prefill) +# DPP constants for prefix sum (used by oneshot and multiphase) DPP_ROW_SHR_1 = 0x111 DPP_ROW_SHR_2 = 0x112 DPP_ROW_SHR_4 = 0x114 @@ -161,7 +159,7 @@ def _fill_sentinel_slots(sorted_ids_rsrc, sorted_w_rsrc, start, count, sentinel, # --------------------------------------------------------------------------- -# LDS helpers for prefill kernels (module-level, used inside @flyc.kernel) +# LDS helpers for multiphase kernels (module-level, used inside @flyc.kernel) # --------------------------------------------------------------------------- def _lds_load_raw(raw_mr, idx): """Load i32 from LDS raw memref. idx can be i32 or index.""" @@ -188,16 +186,16 @@ def _lds_store_raw(raw_mr, val, idx): # returns a CompiledFunction whose __call__ skips inspect.Signature.bind, # _make_cache_key, and dict lookup, reducing dispatch from ~70 us to ~5 us. 
# --------------------------------------------------------------------------- -_decode_cf_cache = {} # (num_experts, topk, max_tokens, unit_size, has_mask, device) -> CompiledFunction -_prefill_cf_cache = {} # (num_experts, topk, unit_size, kernel_name, *constexpr_vals) -> CompiledFunction +_oneshot_cf_cache = {} # (num_experts, topk, max_tokens, unit_size, has_mask, device) -> CompiledFunction +_multiphase_cf_cache = {} # (num_experts, topk, unit_size, kernel_name, *constexpr_vals) -> CompiledFunction _dummy_mask_cache = {} # device -> torch.Tensor(1, dtype=i32, value=1) # --------------------------------------------------------------------------- -# FlyDSL GPU kernel — decode path (single kernel, SubTokenOneShot) +# FlyDSL GPU kernel — oneshot path (single kernel, all phases in LDS) # --------------------------------------------------------------------------- @functools.lru_cache(maxsize=256) -def _compile_moe_sorting_decode( +def _compile_moe_sorting_oneshot( *, num_experts: int, topk: int, @@ -205,7 +203,7 @@ def _compile_moe_sorting_decode( unit_size: int = UNIT_SIZE, has_mask: bool = False, ): - """Compile the decode-path MoE sorting kernel. + """Compile the oneshot MoE sorting kernel (single kernel, all phases in LDS). Parameters ---------- @@ -221,9 +219,9 @@ def _compile_moe_sorting_decode( arch = get_hip_arch() E = num_experts # CDNA (warp64): 512 threads = 8 waves, affordable cross-wave reduction. 
- max_decode_block = 512 if WARP_SIZE == 64 else 256 - DECODE_BLOCK = 256 if E <= 256 else min(512, max_decode_block) - NUM_WAVES = DECODE_BLOCK // WARP_SIZE + max_oneshot_block = 512 if WARP_SIZE == 64 else 256 + ONESHOT_BLOCK = 256 if E <= 256 else min(512, max_oneshot_block) + NUM_WAVES = ONESHOT_BLOCK // WARP_SIZE smem_cols = E + 1 # LDS sizing: sub_tokens rows for the token×expert histogram @@ -266,8 +264,8 @@ def _compile_moe_sorting_decode( scratch_offset = allocator._align(allocator.ptr, 16) allocator.ptr = scratch_offset + NUM_WAVES * 4 - @flyc.kernel(known_block_size=[DECODE_BLOCK, 1, 1]) - def moe_sorting_decode_kernel( + @flyc.kernel(known_block_size=[ONESHOT_BLOCK, 1, 1]) + def moe_sorting_oneshot_kernel( topk_ids_tensor: fx.Tensor, topk_weights_tensor: fx.Tensor, sorted_token_ids: fx.Tensor, @@ -314,9 +312,9 @@ def moe_sorting_decode_kernel( # =================== MOE_BUF ZEROING (blocks > 0 only) =============== if bid != c_zero_i32: - zero_gid_v4 = (bid - c_one_i32) * fx.Int32(DECODE_BLOCK) + tid + zero_gid_v4 = (bid - c_one_i32) * fx.Int32(ONESHOT_BLOCK) + tid num_zero_blocks = gpu.grid_dim.x - c_one_i32 - zero_stride_v4 = num_zero_blocks * fx.Int32(DECODE_BLOCK) + zero_stride_v4 = num_zero_blocks * fx.Int32(ONESHOT_BLOCK) _zero_moe_buf_grid_stride( moe_buf_rsrc, zero_gid_v4, zero_stride_v4, i32_moe_buf_elems >> fx.Int32(2), c_oob_idx ) @@ -325,7 +323,7 @@ def moe_sorting_decode_kernel( if bid == c_zero_i32: # ========================= PHASE 1: Histogram ========================= # Clear mesh region — unconditional store to safe index when out of bounds - for i_clear in range_constexpr(0, sub_tokens * smem_cols, DECODE_BLOCK): + for i_clear in range_constexpr(0, sub_tokens * smem_cols, ONESHOT_BLOCK): idx = fx.Int32(i_clear) + tid is_valid = idx < fx.Int32(sub_tokens * smem_cols) safe_idx = is_valid.select(idx, c_zero_i32) @@ -336,7 +334,7 @@ def moe_sorting_decode_kernel( # Fill mesh: for each (token, topk_slot), write topk_slot+1 to mesh[token, 
expert_id] total_assignments = tokens * c_topk - for i_assign in range_constexpr(0, max_tokens * topk, DECODE_BLOCK): + for i_assign in range_constexpr(0, max_tokens * topk, ONESHOT_BLOCK): flat_idx = fx.Int32(i_assign) + tid is_valid = flat_idx < total_assignments safe_flat = is_valid.select(flat_idx, c_zero_i32) @@ -371,7 +369,7 @@ def moe_sorting_decode_kernel( _lds_store_raw(cumsum_mr, c_zero_i32, c_zero_i32) gpu.barrier() - for i_e in range_constexpr(0, E, DECODE_BLOCK // 8): + for i_e in range_constexpr(0, E, ONESHOT_BLOCK // 8): eid_local = fx.Int32(i_e) + lane_group_id eid_valid = eid_local < c_E @@ -412,7 +410,7 @@ def moe_sorting_decode_kernel( # Phase 2b: Prefix sum over expert counts. # Step 1: Each thread converts its expert's raw count → padded block size. - for i_cvt in range_constexpr(0, E, DECODE_BLOCK): + for i_cvt in range_constexpr(0, E, ONESHOT_BLOCK): cvt_eid = fx.Int32(i_cvt) + tid cvt_valid = cvt_eid < c_E # Safe index: valid → cumsum[eid+1], invalid → cumsum[0] (write 0, harmless) @@ -429,7 +427,7 @@ def moe_sorting_decode_kernel( # EP: zero padded count for masked experts in a separate pass. # Loading from mask buffer inside the padded-count loop above interfered # with expert 0 (MLIR codegen issue). Separate pass avoids this. - for i_ep in range_constexpr(0, E, DECODE_BLOCK): + for i_ep in range_constexpr(0, E, ONESHOT_BLOCK): ep_eid = fx.Int32(i_ep) + tid ep_valid = ep_eid < c_E ep_safe_eid = ep_valid.select(ep_eid, c_zero_i32) @@ -444,8 +442,8 @@ def moe_sorting_decode_kernel( # Step 2: All-wave parallel prefix sum (cumsum → cumdup). 
scratch_mr = SmemPtr(base_ptr, scratch_offset, T.i32, shape=(NUM_WAVES,)).get() - # All threads read cumsum[tid+1] (in chunks for E > DECODE_BLOCK) - for _ps_chunk in range_constexpr(0, E, DECODE_BLOCK): + # All threads read cumsum[tid+1] (in chunks for E > ONESHOT_BLOCK) + for _ps_chunk in range_constexpr(0, E, ONESHOT_BLOCK): ps_eid = fx.Int32(_ps_chunk) + tid ps_valid = ps_eid < c_E ps_safe_ix = ArithValue(ps_valid.select(ps_eid + c_one_i32, c_zero_i32)).index_cast(T.index) @@ -465,10 +463,10 @@ def moe_sorting_decode_kernel( ) gpu.barrier() - # For E > DECODE_BLOCK: thread 0 serially extends - if E > DECODE_BLOCK: + # For E > ONESHOT_BLOCK: thread 0 serially extends + if E > ONESHOT_BLOCK: if is_t0: - _extend_prefix_sum_serial(cumdup_mr, DECODE_BLOCK, E, _lds_load_raw, _lds_store_raw) + _extend_prefix_sum_serial(cumdup_mr, ONESHOT_BLOCK, E, _lds_load_raw, _lds_store_raw) gpu.barrier() # cumdup[0] = 0 @@ -483,7 +481,7 @@ def moe_sorting_decode_kernel( gpu.barrier() # Copy cumdup → cumsum (all threads, one expert per thread) - for i_cp in range_constexpr(0, E + 1, DECODE_BLOCK): + for i_cp in range_constexpr(0, E + 1, ONESHOT_BLOCK): cp_idx = fx.Int32(i_cp) + tid cp_valid = cp_idx <= c_E safe_cp_idx = cp_valid.select(cp_idx, c_zero_i32) @@ -495,7 +493,7 @@ def moe_sorting_decode_kernel( if has_mask: # EP: Compute mask cumsum in cumdup for local expert index mapping. # cumdup[eid] = exclusive prefix sum of mask[0..eid-1] = local expert index. 
- for i_ml in range_constexpr(0, E, DECODE_BLOCK): + for i_ml in range_constexpr(0, E, ONESHOT_BLOCK): ml_eid = fx.Int32(i_ml) + tid ml_valid = ml_eid < c_E safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32) @@ -517,16 +515,16 @@ def moe_sorting_decode_kernel( ) gpu.barrier() - if E > DECODE_BLOCK: + if E > ONESHOT_BLOCK: if is_t0: - _extend_prefix_sum_serial(cumdup_mr, DECODE_BLOCK, E, _lds_load_raw, _lds_store_raw) + _extend_prefix_sum_serial(cumdup_mr, ONESHOT_BLOCK, E, _lds_load_raw, _lds_store_raw) gpu.barrier() _lds_store_raw(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() else: # No mask: cumdup[eid] = eid (identity mapping) - for i_ml in range_constexpr(0, E, DECODE_BLOCK): + for i_ml in range_constexpr(0, E, ONESHOT_BLOCK): ml_eid = fx.Int32(i_ml) + tid ml_valid = ml_eid < c_E safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32) @@ -536,7 +534,7 @@ def moe_sorting_decode_kernel( # Write sorted_expert_ids — predicated stores to buffer (safe: buffer_store ignores OOB) # EP: use cumdup[eid] as local expert index instead of global eid - for i_eid in range_constexpr(0, E, DECODE_BLOCK): + for i_eid in range_constexpr(0, E, ONESHOT_BLOCK): eid_wr = fx.Int32(i_eid) + tid eid_wr_valid = eid_wr < c_E safe_eid_wr = eid_wr_valid.select(eid_wr, c_zero_i32) @@ -572,14 +570,14 @@ def moe_sorting_decode_kernel( c_zero_i32, total_padded_pre, c_sentinel | tokens, - DECODE_BLOCK, + ONESHOT_BLOCK, tid, c_oob_idx, ) gpu.barrier() # ====================== PHASE 3: Scatter ============================== - for i_e2 in range_constexpr(0, E, DECODE_BLOCK // 8): + for i_e2 in range_constexpr(0, E, ONESHOT_BLOCK // 8): eid_sc = fx.Int32(i_e2) + lane_group_id eid_sc_valid = eid_sc < c_E # Invalid lane groups map to cumsum[E] (the total count) instead of @@ -665,7 +663,7 @@ def moe_sorting_decode_kernel( # Padding already filled by PRE-FILL phase above (before scatter). 
@flyc.jit - def launch_moe_sorting_decode( + def launch_moe_sorting_oneshot( topk_ids_tensor: fx.Tensor, topk_weights_tensor: fx.Tensor, sorted_token_ids: fx.Tensor, @@ -684,7 +682,7 @@ def launch_moe_sorting_decode( with ir.InsertionPoint(ctx.gpu_module_body): allocator.finalize() - launcher = moe_sorting_decode_kernel( + launcher = moe_sorting_oneshot_kernel( topk_ids_tensor, topk_weights_tensor, sorted_token_ids, @@ -698,25 +696,25 @@ def launch_moe_sorting_decode( ) launcher.launch( grid=(n_grid_blocks, 1, 1), - block=(DECODE_BLOCK, 1, 1), + block=(ONESHOT_BLOCK, 1, 1), stream=stream, ) - return launch_moe_sorting_decode + return launch_moe_sorting_oneshot # --------------------------------------------------------------------------- -# FlyDSL GPU kernels — prefill path (4 kernels, large T via HBM workspace) +# FlyDSL GPU kernels — multiphase path (2 or 4 kernels, large T via HBM workspace) # --------------------------------------------------------------------------- @functools.lru_cache(maxsize=256) -def _compile_moe_sorting_prefill( +def _compile_moe_sorting_multiphase( *, num_experts: int, topk: int, unit_size: int = UNIT_SIZE, has_mask: bool = False, ): - """Compile the prefill-path MoE sorting kernels. + """Compile the multiphase MoE sorting kernels (2 or 4 kernels via HBM workspace). For token counts exceeding LDS capacity, uses HBM workspace: K1: ClearWorkspace — zero the workspace buffer @@ -1561,10 +1559,10 @@ def launch_4k_fused( # --------------------------------------------------------------------------- @functools.lru_cache(maxsize=64) def _compute_sub_tokens(num_experts, arch=None): - """Compute the LDS-capacity threshold (sub_tokens) for decode vs prefill decision. + """Compute the LDS-capacity threshold (sub_tokens) for oneshot vs multiphase decision. - Returns the max T that fits in LDS for the decode (single-kernel) path. - Same formula as _compile_moe_sorting_decode. + Returns the max T that fits in LDS for the oneshot (single-kernel) path. 
+ Same formula as _compile_moe_sorting_oneshot. """ if arch is None: arch = get_hip_arch() @@ -1582,17 +1580,17 @@ def _compute_sub_tokens(num_experts, arch=None): sub_unroll = 8 cumsum_bufs = 2 if r < (cumsum_bufs + sub_unroll): - return 0 # LDS too small — always use prefill + return 0 # LDS too small — always use multiphase r_for_sub = ((r - cumsum_bufs) // sub_unroll) * sub_unroll return r_for_sub def moe_sorting_get_workspace_size(M, num_experts, topk, unit_size=UNIT_SIZE): - """Return workspace size (in i32 elements) needed for the prefill path. - Returns 0 if the decode path will be used.""" + """Return workspace size (in i32 elements) needed for the multiphase path. + Returns 0 if the oneshot path will be used.""" sub_tokens = _compute_sub_tokens(num_experts) - DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) - if M <= min(sub_tokens, DECODE_MAX_T): + ONESHOT_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) + if M <= min(sub_tokens, ONESHOT_MAX_T): return 0 mesh_stride = ((M + unit_size - 1) // unit_size) * unit_size ws_mesh_bytes = num_experts * mesh_stride @@ -1601,18 +1599,18 @@ def moe_sorting_get_workspace_size(M, num_experts, topk, unit_size=UNIT_SIZE): def compile_moe_sorting(*, num_experts, topk, max_tokens=128, unit_size=UNIT_SIZE, has_mask=False): - """Compile MoE sorting kernels for all paths (decode + prefill). + """Compile MoE sorting kernels for all paths (oneshot + multiphase). - Returns (launch_decode, launch_p0v2_p23, launch_4k_fused) covering all T ranges. - Decode compilation depends on max_tokens (LDS sizing); prefill is independent. + Returns (launch_oneshot, launch_p0v2_p23, launch_4k_fused) covering all T ranges. + Oneshot compilation depends on max_tokens (LDS sizing); multiphase is independent. 
""" - launch_decode = _compile_moe_sorting_decode( + launch_oneshot = _compile_moe_sorting_oneshot( num_experts=num_experts, topk=topk, max_tokens=max_tokens, unit_size=unit_size, has_mask=has_mask ) - _, _, _, _, _, launch_p0v2_p23, launch_4k_fused = _compile_moe_sorting_prefill( + _, _, _, _, _, launch_p0v2_p23, launch_4k_fused = _compile_moe_sorting_multiphase( num_experts=num_experts, topk=topk, unit_size=unit_size, has_mask=has_mask ) - return launch_decode, launch_p0v2_p23, launch_4k_fused + return launch_oneshot, launch_p0v2_p23, launch_4k_fused def _launch_cached(cache, key, launch_fn, args, stream): @@ -1641,7 +1639,7 @@ def moe_sorting_flydsl( num_local_tokens=None, workspace=None, ): - """MoE sorting using FlyDSL kernel (decode + prefill paths). + """MoE sorting using FlyDSL kernel (oneshot + multiphase paths). API matches aiter.moe_sorting_fwd for drop-in replacement: moe_sorting_flydsl(topk_ids, topk_weights, @@ -1679,22 +1677,22 @@ def moe_sorting_flydsl( else: mask_tensor = expert_mask - DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) + ONESHOT_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) target_occupancy = 2 num_cu = torch.cuda.get_device_properties(device).multi_processor_count - if M <= min(sub_tokens, DECODE_MAX_T): + if M <= min(sub_tokens, ONESHOT_MAX_T): max_tokens = max(M, 8) max_tokens = ((max_tokens + 7) // 8) * 8 n_zero_blocks = min((moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy) n_grid_blocks = 1 + n_zero_blocks - launch_decode, _, _ = compile_moe_sorting( + launch_oneshot, _, _ = compile_moe_sorting( num_experts=num_experts, topk=topk, max_tokens=max_tokens, unit_size=unit_size, has_mask=has_mask ) - decode_args = ( + oneshot_args = ( topk_ids, topk_weights, sorted_ids, @@ -1708,7 +1706,7 @@ def moe_sorting_flydsl( n_grid_blocks, ) cache_key = (num_experts, topk, max_tokens, unit_size, has_mask, device.index) - _launch_cached(_decode_cf_cache, 
cache_key, launch_decode, decode_args, torch.cuda.current_stream()) + _launch_cached(_oneshot_cf_cache, cache_key, launch_oneshot, oneshot_args, torch.cuda.current_stream()) else: mesh_stride = ((M + unit_size - 1) // unit_size) * unit_size ws_mesh_bytes = num_experts * mesh_stride @@ -1742,7 +1740,7 @@ def moe_sorting_flydsl( moe_buf_elems, k4_grid, ) - _launch_cached(_prefill_cf_cache, base_key + ("p0v2_p23",), launch_p0v2_p23, p0v2_args, stream) + _launch_cached(_multiphase_cf_cache, base_key + ("p0v2_p23",), launch_p0v2_p23, p0v2_args, stream) else: k1_grid = (ws_total + 1023) // 1024 k2_grid = num_cu * target_occupancy @@ -1769,6 +1767,6 @@ def moe_sorting_flydsl( k2_grid, k4_grid, ) - _launch_cached(_prefill_cf_cache, base_key + ("4k_fused",), launch_4k_fused, k4_args, stream) + _launch_cached(_multiphase_cf_cache, base_key + ("4k_fused",), launch_4k_fused, k4_args, stream) return sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf diff --git a/tests/kernels/test_moe_sorting.py b/tests/kernels/test_moe_sorting.py index e0928a58..39286db0 100644 --- a/tests/kernels/test_moe_sorting.py +++ b/tests/kernels/test_moe_sorting.py @@ -310,15 +310,15 @@ def run_test(T, E, topk, unit_size=UNIT_SIZE, max_tokens=None): Returns (passed: bool, gpu_time_us: float or None). """ - # Let moe_sorting_flydsl auto-select decode/prefill path. - # max_tokens is only needed for explicit decode-path override. + # Let moe_sorting_flydsl auto-select oneshot/multiphase path. + # max_tokens is only needed for explicit oneshot-path override. 
from kernels.moe_sorting_kernel import BLOCK_SIZE, _compute_sub_tokens sub_tokens = _compute_sub_tokens(E) - DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, E // 8))) - path = "decode" if T <= min(sub_tokens, DECODE_MAX_T) else "prefill" + ONESHOT_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, E // 8))) + path = "oneshot" if T <= min(sub_tokens, ONESHOT_MAX_T) else "multiphase" - if max_tokens is None and path == "decode": + if max_tokens is None and path == "oneshot": max_tokens = max(T, 8) max_tokens = ((max_tokens + 7) // 8) * 8 @@ -425,7 +425,7 @@ def run_test_vs_aiter(T, E, topk, unit_size=UNIT_SIZE, max_tokens=None): block_size=unit_size, ) - # FlyDSL (auto-dispatches decode/prefill) + # FlyDSL (auto-dispatches oneshot/multiphase) fly_ids, fly_w, fly_eids, fly_nvalid, _ = _call_flydsl( topk_ids, topk_weights, @@ -453,8 +453,8 @@ def run_test_vs_aiter(T, E, topk, unit_size=UNIT_SIZE, max_tokens=None): # --------------------------------------------------------------------------- # Pytest entry points # --------------------------------------------------------------------------- -DECODE_CONFIGS = [ - # (T, E, topk) — decode path (small T) +ONESHOT_CONFIGS = [ + # (T, E, topk) — oneshot path (small T) (1, 256, 8), (1, 32, 5), (4, 256, 8), @@ -466,13 +466,13 @@ def run_test_vs_aiter(T, E, topk, unit_size=UNIT_SIZE, max_tokens=None): (1, 8, 2), (7, 32, 5), # odd T, topk not power of 2 (31, 64, 6), # prime T, topk not power of 2 - # Production E > 256 (DECODE_BLOCK=512) — core coverage + # Production E > 256 (ONESHOT_BLOCK=512) — core coverage (1, 257, 9), # DeepSeek-R1 (256 routed + 1 shared) (16, 257, 9), (16, 513, 9), # Qwen3.5 (512 routed + 1 shared) ] -DECODE_CONFIGS_FULL = DECODE_CONFIGS + [ +ONESHOT_CONFIGS_FULL = ONESHOT_CONFIGS + [ # Extended production coverage (large_shape — CI skips by default) (8, 257, 9), (1, 385, 7), # DeepSeek-V4 (384 routed + 1 shared) @@ -484,8 +484,8 @@ def run_test_vs_aiter(T, E, topk, unit_size=UNIT_SIZE, 
max_tokens=None): ] -PREFILL_CONFIGS = [ - # (T, E, topk) — prefill path (large T, HBM workspace) +MULTIPHASE_CONFIGS = [ + # (T, E, topk) — multiphase path (large T, HBM workspace) (128, 256, 8), (512, 256, 8), (1024, 256, 8), @@ -495,7 +495,7 @@ def run_test_vs_aiter(T, E, topk, unit_size=UNIT_SIZE, max_tokens=None): (1024, 513, 9), # Qwen3.5 ] -PREFILL_CONFIGS_FULL = PREFILL_CONFIGS + [ +MULTIPHASE_CONFIGS_FULL = MULTIPHASE_CONFIGS + [ # Extended (large_shape — CI skips by default) (4096, 256, 8), (8192, 256, 8), @@ -507,30 +507,30 @@ def run_test_vs_aiter(T, E, topk, unit_size=UNIT_SIZE, max_tokens=None): ] -@pytest.mark.parametrize("T,E,topk", DECODE_CONFIGS) -def test_moe_sorting_decode(T, E, topk): +@pytest.mark.parametrize("T,E,topk", ONESHOT_CONFIGS) +def test_moe_sorting_oneshot(T, E, topk): passed, _ = run_test(T, E, topk) assert passed, f"MoE sorting failed for T={T}, E={E}, topk={topk}" @pytest.mark.large_shape -@pytest.mark.parametrize("T,E,topk", [c for c in DECODE_CONFIGS_FULL if c not in DECODE_CONFIGS]) -def test_moe_sorting_decode_full(T, E, topk): +@pytest.mark.parametrize("T,E,topk", [c for c in ONESHOT_CONFIGS_FULL if c not in ONESHOT_CONFIGS]) +def test_moe_sorting_oneshot_full(T, E, topk): passed, _ = run_test(T, E, topk) assert passed, f"MoE sorting failed for T={T}, E={E}, topk={topk}" -@pytest.mark.parametrize("T,E,topk", PREFILL_CONFIGS) -def test_moe_sorting_prefill(T, E, topk): +@pytest.mark.parametrize("T,E,topk", MULTIPHASE_CONFIGS) +def test_moe_sorting_multiphase(T, E, topk): passed, _ = run_test(T, E, topk) - assert passed, f"MoE sorting (prefill) failed for T={T}, E={E}, topk={topk}" + assert passed, f"MoE sorting (multiphase) failed for T={T}, E={E}, topk={topk}" @pytest.mark.large_shape -@pytest.mark.parametrize("T,E,topk", [c for c in PREFILL_CONFIGS_FULL if c not in PREFILL_CONFIGS]) -def test_moe_sorting_prefill_full(T, E, topk): +@pytest.mark.parametrize("T,E,topk", [c for c in MULTIPHASE_CONFIGS_FULL if c not in 
MULTIPHASE_CONFIGS]) +def test_moe_sorting_multiphase_full(T, E, topk): passed, _ = run_test(T, E, topk) - assert passed, f"MoE sorting (prefill) failed for T={T}, E={E}, topk={topk}" + assert passed, f"MoE sorting (multiphase) failed for T={T}, E={E}, topk={topk}" def run_test_ep(T, E, topk, mask_ratio=0.5, unit_size=UNIT_SIZE): @@ -538,11 +538,11 @@ def run_test_ep(T, E, topk, mask_ratio=0.5, unit_size=UNIT_SIZE): from kernels.moe_sorting_kernel import BLOCK_SIZE, _compute_sub_tokens sub_tokens = _compute_sub_tokens(E) - DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, E // 8))) - if T <= min(sub_tokens, DECODE_MAX_T): - path = "decode" + ONESHOT_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, E // 8))) + if T <= min(sub_tokens, ONESHOT_MAX_T): + path = "oneshot" else: - path = "prefill" + path = "multiphase" print(f"\n{'='*60}") print(f"EP Test: T={T}, E={E}, topk={topk}, mask_ratio={mask_ratio}, path={path}") @@ -609,19 +609,19 @@ def run_test_ep(T, E, topk, mask_ratio=0.5, unit_size=UNIT_SIZE): EP_CONFIGS = [ # (T, E, topk, mask_ratio) - (4, 256, 8, 0.5), # decode path - (8, 256, 8, 0.3), # decode path, sparse - (64, 256, 8, 0.5), # prefill path - (128, 256, 8, 0.7), # prefill path - (2048, 256, 8, 0.5), # prefill path + (4, 256, 8, 0.5), # oneshot path + (8, 256, 8, 0.3), # oneshot path, sparse + (64, 256, 8, 0.5), # multiphase path + (128, 256, 8, 0.7), # multiphase path + (2048, 256, 8, 0.5), # multiphase path (4, 256, 8, 1.0), # all enabled (should match non-EP) - (64, 256, 8, 1.0), # all enabled, prefill + (64, 256, 8, 1.0), # all enabled, multiphase (4, 256, 8, 0.0), # all masked (empty output) # Production E>256 with EP - (8, 257, 9, 0.5), # DeepSeek-R1 decode + EP - (1024, 257, 9, 0.5), # DeepSeek-R1 prefill + EP - (8, 513, 9, 0.5), # Qwen3.5 decode + EP - (1024, 513, 9, 0.5), # Qwen3.5 prefill + EP (E > K4_BLOCK) + (8, 257, 9, 0.5), # DeepSeek-R1 oneshot + EP + (1024, 257, 9, 0.5), # DeepSeek-R1 multiphase + EP + (8, 513, 9, 
0.5), # Qwen3.5 oneshot + EP + (1024, 513, 9, 0.5), # Qwen3.5 multiphase + EP (E > K4_BLOCK) ] @@ -738,7 +738,7 @@ def run_bench_comparison(token_sweep=None): print(f" MoE Sorting Benchmark: FlyDSL vs CK (E={E}, topk={topk}, unit_size={UNIT_SIZE})") print(f" Device: {torch.cuda.get_device_name(0)}") props = torch.cuda.get_device_properties(0) - print(f" CUs: {props.multi_processor_count}, decode threshold: T<={sub_tokens}") + print(f" CUs: {props.multi_processor_count}, oneshot threshold: T<={sub_tokens}") print(f" Modes: eager (with L2 flush, median of {BENCH_MEASURE}), graph ({BENCH_MEASURE} replays)") print(f"{'=' * 110}") print( @@ -752,7 +752,7 @@ def run_bench_comparison(token_sweep=None): topk_ids = torch.stack([torch.randperm(E, device="cuda")[:topk] for _ in range(T)]).to(torch.int32) topk_weights = torch.rand(T, topk, dtype=torch.float32, device="cuda") - path = "decode" if T <= sub_tokens else "prefill" + path = "oneshot" if T <= sub_tokens else "multiphase" # Pre-allocate outputs to avoid per-call torch.empty overhead max_num_tokens_padded = T * topk + E * UNIT_SIZE - topk @@ -839,7 +839,7 @@ def main(): topk = args.topk or 8 configs = [(args.T, E, topk)] elif args.all: - configs = DECODE_CONFIGS + PREFILL_CONFIGS + configs = ONESHOT_CONFIGS + MULTIPHASE_CONFIGS else: configs = [ (1, 256, 8),