From 9048f3e987207925d3ff8849f9bc9ddb9ffedbc0 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Tue, 19 May 2026 12:16:55 +0100 Subject: [PATCH 1/5] [FlyDSL]add FlyDSL MoE sorting kernel --- aiter/fused_moe.py | 37 +- .../ops/flydsl/kernels/moe_sorting_kernel.py | 2066 +++++++++++++++++ aiter/ops/flydsl/moe_sorting.py | 64 + 3 files changed, 2166 insertions(+), 1 deletion(-) create mode 100644 aiter/ops/flydsl/kernels/moe_sorting_kernel.py create mode 100644 aiter/ops/flydsl/moe_sorting.py diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index c755e2bf04..f23ddc5135 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -22,8 +22,9 @@ BLOCK_SIZE_M = 32 -# Default to Opus unless CK sorting is explicitly requested. +# Default to Opus unless CK or FlyDSL sorting is explicitly requested. _USE_CK_MOE_SORTING = os.environ.get("AITER_USE_CK_MOE_SORTING", "0") == "1" +_USE_FLYDSL_MOE_SORTING = os.environ.get("AITER_USE_FLYDSL_MOE_SORTING", "0") == "1" _ACT_TYPE_DISABLED_KEY = "__ignore__" _SWIGLU_MXFP4_BF16_BOUND = int(os.environ.get("GPTOSS_SWIGLU_MXFP4_BF16_BOUND", "256")) @@ -105,6 +106,35 @@ def _moe_sorting_impl( return ret +def _flydsl_moe_sorting( + topk_ids, topk_weights, num_experts, model_dim, + moebuf_dtype, block_size, expert_mask, num_local_tokens, +): + """FlyDSL sorting dispatch — called outside torch_compile_guard.""" + from aiter.ops.flydsl.moe_sorting import flydsl_moe_sorting_fwd + + device = topk_ids.device + M, topk = topk_ids.shape + max_num_tokens_padded = int(topk_ids.numel() + num_experts * block_size - topk) + max_num_m_blocks = int((max_num_tokens_padded + block_size - 1) // block_size) + sorted_ids = torch.empty(max_num_tokens_padded, dtype=dtypes.i32, device=device) + sorted_weights = torch.empty( + max_num_tokens_padded, dtype=dtypes.fp32, device=device + ) + sorted_expert_ids = torch.empty(max_num_m_blocks, dtype=dtypes.i32, device=device) + num_valid_ids = torch.empty(2, dtype=dtypes.i32, device=device) + moe_buf = torch.empty((M, model_dim), dtype=moebuf_dtype, device=device) + + flydsl_moe_sorting_fwd( + topk_ids, topk_weights, + sorted_ids, sorted_weights, sorted_expert_ids, + num_valid_ids, moe_buf, + num_experts, int(block_size), + expert_mask, num_local_tokens, + ) + return sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf + + def moe_sorting( topk_ids, topk_weights, @@ -117,6 +147,11 @@ def moe_sorting( dispatch_policy=0, return_local_topk_ids=False, ): + if _USE_FLYDSL_MOE_SORTING: + return _flydsl_moe_sorting( + topk_ids, topk_weights, num_experts, model_dim, + moebuf_dtype, block_size, expert_mask, num_local_tokens, + ) try: return _moe_sorting_impl( topk_ids, diff --git a/aiter/ops/flydsl/kernels/moe_sorting_kernel.py b/aiter/ops/flydsl/kernels/moe_sorting_kernel.py new file mode 100644 index 0000000000..dc403c631a --- /dev/null +++ b/aiter/ops/flydsl/kernels/moe_sorting_kernel.py @@ -0,0 +1,2066 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""MoE token sorting kernel (FlyDSL). + +Implements the MoE sorting operation used in DeepSeek R1 and similar MoE models. +Given router top-k selections (topk_ids, topk_weights), reorganizes tokens by expert +for efficient batched expert GEMM execution. + +Algorithm: counting sort in LDS (histogram → prefix-sum → scatter). + +Two paths: + - Decode (small T): single kernel, all phases in LDS. + - Prefill (large T): 4 kernels via HBM workspace (ClearWS → P0 scatter → P1 count → P23 prefix-sum+scatter). + +Packed token ID format: (topk_position << 24) | token_id + - Upper 8 bits: topk slot (0..topk-1) + - Lower 24 bits: token index (0..M-1) + - Padding sentinel: (topk << 24) | M +""" + +import functools +from contextlib import contextmanager + +import torch + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl._mlir import ir +from flydsl._mlir.dialects import memref as memref_ops +from flydsl._mlir.dialects import scf +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.expr import buffer_ops, gpu, range_constexpr +from flydsl.expr import rocdl as fly_rocdl +from flydsl.expr.arith import ArithValue +from flydsl.expr.typing import T +from flydsl.expr.typing import Vector as Vec +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr +from .kernels_common import get_warp_size + +BLOCK_SIZE = 256 +UNIT_SIZE = 32 # GEMM tile-M, aka block_size in CK +WARP_SIZE = get_warp_size() + +# DPP constants for prefix sum (used by decode and prefill) +DPP_ROW_SHR_1 = 0x111 +DPP_ROW_SHR_2 = 0x112 +DPP_ROW_SHR_4 = 0x114 +DPP_ROW_SHR_8 = 0x118 +DPP_ROW_MASK = 0xF +DPP_BANK_MASK = 0xF + + +def _unwrap_val(v): + """Unwrap DSL value to raw MLIR ir.Value.""" + return v.ir_value() if hasattr(v, "ir_value") else v + + +def _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE): + """inclusive prefix sum within a single wave using DPP. + + Performs 4 DPP row_shr steps (1, 2, 4, 8) for intra-row scan, then + 2 ds_bpermute steps (16, 32) for cross-row accumulation within the wave. + Returns the inclusive prefix sum value for each lane. + + Call inside @flyc.kernel only — emits MLIR ops during tracing. + """ + val_raw = _unwrap_val(val) + zero_raw = _unwrap_val(fx.Int32(0)) + + for shift, dpp_op, threshold in [ + (1, DPP_ROW_SHR_1, 1), + (2, DPP_ROW_SHR_2, 2), + (4, DPP_ROW_SHR_4, 4), + (8, DPP_ROW_SHR_8, 8), + ]: + remote = fly_rocdl.update_dpp(T.i32, zero_raw, val_raw, dpp_op, DPP_ROW_MASK, DPP_BANK_MASK, True) + val = (lane >= fx.Int32(threshold)).select(val + fx.Int32(remote), val) + val_raw = _unwrap_val(val) + + src_lane_16 = (lane & fx.Int32(0x30)) - fx.Int32(1) + remote16 = fly_rocdl.ds_bpermute(T.i32, src_lane_16 * fx.Int32(4), val) + val = (lane >= fx.Int32(16)).select(val + fx.Int32(remote16), val) + + if WARP_SIZE > 32: + src_lane_32 = (lane & fx.Int32(0x30)) - fx.Int32(17) + remote32 = fly_rocdl.ds_bpermute(T.i32, src_lane_32 * fx.Int32(4), val) + val = (lane >= fx.Int32(32)).select(val + fx.Int32(remote32), val) + + return val + + +# --------------------------------------------------------------------------- +# AOT-compiled dispatch caches — keyed by constexpr values. +# After the first JIT call (which compiles the kernel), flyc.compile() +# returns a CompiledFunction whose __call__ skips inspect.Signature.bind, +# _make_cache_key, and dict lookup, reducing dispatch from ~70 us to ~5 us. +# --------------------------------------------------------------------------- +_decode_cf_cache = {} # (num_experts, topk, max_tokens, unit_size, has_mask, device) -> CompiledFunction +_prefill_cf_cache = {} # (num_experts, topk, unit_size, kernel_name, *constexpr_vals) -> CompiledFunction +_dummy_mask_cache = {} # device -> torch.Tensor(1, dtype=i32, value=1) + + +@contextmanager +def _if_then(if_op): + """Context manager for scf.IfOp then-region (from moe_gemm_2stage.py).""" + with ir.InsertionPoint(if_op.then_block): + try: + yield if_op.then_block + finally: + blk = if_op.then_block + if (not blk.operations) or not isinstance(blk.operations[-1], scf.YieldOp): + scf.YieldOp([]) + + +# --------------------------------------------------------------------------- +# FlyDSL GPU kernel — decode path (single kernel, SubTokenOneShot) +# --------------------------------------------------------------------------- +@functools.lru_cache(maxsize=256) +def compile_moe_sorting_decode( + *, + num_experts: int, + topk: int, + max_tokens: int = 128, + unit_size: int = UNIT_SIZE, + has_mask: bool = False, +): + """Compile the decode-path MoE sorting kernel. + + Parameters + ---------- + num_experts : int + Number of routed experts (e.g. 256 for DeepSeek R1). + topk : int + Experts per token (e.g. 8 for DeepSeek R1). + max_tokens : int + Upper bound on T for LDS sizing. Actual T is passed at runtime. + unit_size : int + GEMM tile-M for padding alignment (default 32). + """ + arch = get_hip_arch() + E = num_experts + # CDNA (warp64): 512 threads = 8 waves, affordable cross-wave reduction. + # RDNA (warp32): cap at 256 to avoid 16-wave overhead. + max_decode_block = 512 if WARP_SIZE == 64 else 256 + DECODE_BLOCK = 256 if E <= 256 else min(512, max_decode_block) + NUM_WAVES = DECODE_BLOCK // WARP_SIZE + smem_cols = E + 1 + + # LDS sizing: sub_tokens rows for the token×expert histogram + # Match CK's sizing: total LDS / occupancy / smem_cols, rounded to 8 + if arch in ("gfx942",) or str(arch).startswith("gfx94"): + lds_capacity_bytes = 65536 + elif str(arch).startswith("gfx95"): + lds_capacity_bytes = 163840 + else: + lds_capacity_bytes = 65536 # conservative default + + lds_capacity_ints = lds_capacity_bytes // 4 + target_occupancy = 2 + r = lds_capacity_ints // target_occupancy // smem_cols + sub_unroll = 8 + cumsum_bufs = 2 + if r < (cumsum_bufs + sub_unroll): + raise ValueError( + f"LDS too small for E={E}: need at least " f"{(cumsum_bufs + sub_unroll) * smem_cols * 4} bytes" + ) + r_for_sub = ((r - cumsum_bufs) // sub_unroll) * sub_unroll + r_token_min = ((max_tokens + sub_unroll - 1) // sub_unroll) * sub_unroll + r_for_sub = min(r_for_sub, r_token_min) + sub_tokens = r_for_sub + + # SmemAllocator for the 3 LDS regions + allocator = SmemAllocator(None, arch=arch) + + # Region 0: cumsum[E+1] (exclusive prefix sums per expert) + cumsum_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = cumsum_offset + smem_cols * 4 + + # Region 1: cumdup[E+1] (duplicate of cumsum for scatter phase) + cumdup_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = cumdup_offset + smem_cols * 4 + + # Region 2: tokens_mesh[sub_tokens, smem_cols] + mesh_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = mesh_offset + sub_tokens * smem_cols * 4 + + # Region 3: cross-wave scratch for all-wave parallel prefix sum [NUM_WAVES] + scratch_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = scratch_offset + NUM_WAVES * 4 + + # Helpers for raw memref LDS access (used inside scf.for instead of SmemPtr) + def _to_index(v): + """Convert i32 or index DSL value to raw MLIR index value.""" + raw = v.ir_value() if hasattr(v, "ir_value") else v + if isinstance(raw.type, ir.IndexType): + return raw + return ArithValue(v).index_cast(T.index) + + def _lds_load(raw_mr, idx): + """Load i32 from LDS raw memref. idx can be i32 or index.""" + return fx.Int32(memref_ops.load(raw_mr, [_to_index(idx)])) + + def _lds_store(raw_mr, val, idx): + """Store i32 to LDS raw memref. idx can be i32 or index.""" + v = val.ir_value() if hasattr(val, "ir_value") else val + memref_ops.store(v, raw_mr, [_to_index(idx)]) + + def _unwrap(v): + """Unwrap DSL value to raw MLIR ir.Value for scf.for init.""" + return v.ir_value() if hasattr(v, "ir_value") else v + + @flyc.kernel(known_block_size=[DECODE_BLOCK, 1, 1]) + def moe_sorting_decode_kernel( + topk_ids_tensor: fx.Tensor, + topk_weights_tensor: fx.Tensor, + sorted_token_ids: fx.Tensor, + sorted_weights_out: fx.Tensor, + sorted_expert_ids: fx.Tensor, + num_valid_ids: fx.Tensor, + moe_buf: fx.Tensor, + expert_mask_tensor: fx.Tensor, + i32_tokens: fx.Int32, + i32_moe_buf_elems: fx.Int32, + ): + bid = gpu.block_idx.x + tid = gpu.thread_idx.x + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + tokens = i32_tokens + c_zero_i32 = fx.Int32(0) + c_one_i32 = fx.Int32(1) + c_oob_idx = fx.Int32(0x7FFFFFFF) + c4_i32 = fx.Int32(4) + + # Buffer resources (needed by both paths, defined at top level) + moe_buf_rsrc = buffer_ops.create_buffer_resource(moe_buf, max_size=True) + topk_ids_rsrc = buffer_ops.create_buffer_resource(topk_ids_tensor, max_size=True) + weights_rsrc = buffer_ops.create_buffer_resource(topk_weights_tensor, max_size=True) + sorted_ids_rsrc = buffer_ops.create_buffer_resource(sorted_token_ids, max_size=True) + sorted_w_rsrc = buffer_ops.create_buffer_resource(sorted_weights_out, max_size=True) + sorted_e_rsrc = buffer_ops.create_buffer_resource(sorted_expert_ids, max_size=True) + nvalid_rsrc = buffer_ops.create_buffer_resource(num_valid_ids, max_size=True) + mask_rsrc = buffer_ops.create_buffer_resource(expert_mask_tensor, max_size=True) + + # LDS: get RAW memrefs ONCE — dominates all child scf.for/scf.if regions. + base_ptr = allocator.get_base() + cumsum_mr = SmemPtr(base_ptr, cumsum_offset, T.i32, shape=(smem_cols,)).get() + cumdup_mr = SmemPtr(base_ptr, cumdup_offset, T.i32, shape=(smem_cols,)).get() + mesh_mr = SmemPtr(base_ptr, mesh_offset, T.i32, shape=(sub_tokens * smem_cols,)).get() + + c_topk = fx.Int32(topk) + c_E = fx.Int32(E) + c_unit = fx.Int32(unit_size) + c_sub_tokens = fx.Int32(sub_tokens) + c_smem_cols = fx.Int32(smem_cols) + c_sentinel = fx.Int32((topk << 24)) + + # =================== MOE_BUF ZEROING (blocks > 0 only) =============== + is_zero_block = bid != c_zero_i32 + _if_zero = scf.IfOp(is_zero_block.ir_value()) + with _if_then(_if_zero): + zero_gid_v4 = (bid - c_one_i32) * fx.Int32(DECODE_BLOCK) + tid + num_zero_blocks = gpu.grid_dim.x - c_one_i32 + zero_stride_v4 = num_zero_blocks * fx.Int32(DECODE_BLOCK) + i32_moe_buf_v4 = i32_moe_buf_elems >> fx.Int32(2) + zero_niters = (i32_moe_buf_v4 + zero_stride_v4 - c_one_i32) // zero_stride_v4 + _zs = fx.Index(0) + _ze = ArithValue(zero_niters).index_cast(T.index) + _z1 = fx.Index(1) + c_zero_v4 = fx.Vector.filled(4, 0, fx.Int32) + c4_i32 = fx.Int32(4) + for _z in range(_zs, _ze, _z1): + z_idx_v4 = zero_gid_v4 + fx.Int32(_z) * zero_stride_v4 + z_valid = z_idx_v4 < i32_moe_buf_v4 + z_elem = z_valid.select(z_idx_v4 * c4_i32, c_oob_idx) + buffer_ops.buffer_store(c_zero_v4, moe_buf_rsrc, z_elem) + + # =================== SORTING (block 0 only) ========================== + is_sort_block = bid == c_zero_i32 + _if_sort = scf.IfOp(is_sort_block.ir_value()) + with _if_then(_if_sort): + # ========================= PHASE 1: Histogram ========================= + # Clear mesh region — unconditional store to safe index when out of bounds + for i_clear in range_constexpr(0, sub_tokens * smem_cols, DECODE_BLOCK): + idx = fx.Int32(i_clear) + tid + is_valid = idx < fx.Int32(sub_tokens * smem_cols) + safe_idx = is_valid.select(idx, c_zero_i32) + safe_idx_ix = ArithValue(safe_idx).index_cast(T.index) + # Always store; out-of-bounds threads harmlessly write to index 0 + _lds_store(mesh_mr, c_zero_i32, safe_idx_ix) + gpu.barrier() + + # Fill mesh: for each (token, topk_slot), write topk_slot+1 to mesh[token, expert_id] + total_assignments = tokens * c_topk + for i_assign in range_constexpr(0, max_tokens * topk, DECODE_BLOCK): + flat_idx = fx.Int32(i_assign) + tid + is_valid = flat_idx < total_assignments + safe_flat = is_valid.select(flat_idx, c_zero_i32) + + token_id = safe_flat // c_topk + topk_slot = safe_flat % c_topk + + global_idx = token_id * c_topk + topk_slot + eid = buffer_ops.buffer_load(topk_ids_rsrc, global_idx, vec_width=1, dtype=T.i32) + + # mesh[token_id, eid] = topk_slot + 1 (valid threads only). + # Invalid threads must NOT write to mesh[0] — that would race + # with a valid write to (token=0, expert=0). + mesh_addr = token_id * c_smem_cols + eid + last_mesh_idx = fx.Int32(sub_tokens * smem_cols - 1) + safe_mesh_addr = is_valid.select(mesh_addr, last_mesh_idx) + safe_mesh_ix = ArithValue(safe_mesh_addr).index_cast(T.index) + val = is_valid.select(topk_slot + c_one_i32, c_zero_i32) + _lds_store(mesh_mr, val, safe_mesh_ix) + gpu.barrier() + + # ===================== PHASE 2: Count + Prefix Sum ===================== + c_lane_group_sz = fx.Int32(8) + lane_group_id = tid // c_lane_group_sz + lane_group_os = tid % c_lane_group_sz + width8_i32 = fx.Int32(8) + + is_t0 = tid == c_zero_i32 + + # Initialize cumsum[0] = 0. All threads write 0 so there's no + # read-modify-write race across waves. + _lds_store(cumsum_mr, c_zero_i32, c_zero_i32) + gpu.barrier() + + for i_e in range_constexpr(0, E, DECODE_BLOCK // 8): + eid_local = fx.Int32(i_e) + lane_group_id + eid_valid = eid_local < c_E + + cnt = c_zero_i32 + for i_sub in range_constexpr(0, sub_tokens, 8): + sub_idx = fx.Int32(i_sub) + lane_group_os + sub_valid = sub_idx < c_sub_tokens + combined_valid = eid_valid & sub_valid + + safe_sub = combined_valid.select(sub_idx, c_zero_i32) + safe_eid = combined_valid.select(eid_local, c_zero_i32) + mesh_rd_addr = safe_sub * c_smem_cols + safe_eid + mesh_rd_ix = ArithValue(mesh_rd_addr).index_cast(T.index) + mesh_val = _lds_load(mesh_mr, mesh_rd_ix) + + has_token = combined_valid.select( + (mesh_val != c_zero_i32).select(c_one_i32, c_zero_i32), + c_zero_i32, + ) + + # Reduce within lane-group of 8 + reduced = has_token + for sh in range_constexpr(3): + off = fx.Int32(1 << sh) + peer = reduced.shuffle_xor(off, width8_i32) + reduced = reduced + peer + cnt = cnt + reduced + + # Only lane 0 of each valid lane-group writes the count to cumsum[eid+1]. + # Invalid threads: write_valid is false, cs_idx = 0, and we write 0 to + # cumsum[0] which is harmless (cumsum[0] is always 0). + write_valid = eid_valid & (lane_group_os == c_zero_i32) + cs_idx = write_valid.select(eid_local + c_one_i32, c_zero_i32) + cs_ix = ArithValue(cs_idx).index_cast(T.index) + cs_val = write_valid.select(cnt, c_zero_i32) + _lds_store(cumsum_mr, cs_val, cs_ix) + gpu.barrier() + + # Phase 2b: Prefix sum over expert counts. + # Step 1: Each thread converts its expert's raw count → padded block size. + for i_cvt in range_constexpr(0, E, DECODE_BLOCK): + cvt_eid = fx.Int32(i_cvt) + tid + cvt_valid = cvt_eid < c_E + # Safe index: valid → cumsum[eid+1], invalid → cumsum[0] (write 0, harmless) + safe_cvt_idx = cvt_valid.select(cvt_eid + c_one_i32, c_zero_i32) + cvt_ix = ArithValue(safe_cvt_idx).index_cast(T.index) + raw_cnt_cvt = _lds_load(cumsum_mr, cvt_ix) + blocks_cvt = (raw_cnt_cvt + c_unit - c_one_i32) // c_unit + padded_cvt = (raw_cnt_cvt == c_zero_i32).select(c_zero_i32, blocks_cvt * c_unit) + # Valid threads write padded value; invalid threads write 0 to cumsum[0] + _lds_store(cumsum_mr, cvt_valid.select(padded_cvt, c_zero_i32), cvt_ix) + gpu.barrier() + + if has_mask: + # EP: zero padded count for masked experts in a separate pass. + # Loading from mask buffer inside the padded-count loop above interfered + # with expert 0 (MLIR codegen issue). Separate pass avoids this. + for i_ep in range_constexpr(0, E, DECODE_BLOCK): + ep_eid = fx.Int32(i_ep) + tid + ep_valid = ep_eid < c_E + ep_safe_eid = ep_valid.select(ep_eid, c_zero_i32) + ep_m = buffer_ops.buffer_load(mask_rsrc, ep_safe_eid, vec_width=1, dtype=T.i32) + should_zero = ep_valid & (ep_m == c_zero_i32) + ep_cs_ix = ArithValue(ep_valid.select(ep_eid + c_one_i32, c_zero_i32)).index_cast(T.index) + _lds_store(cumsum_mr, should_zero.select(c_zero_i32, _lds_load(cumsum_mr, ep_cs_ix)), ep_cs_ix) + gpu.barrier() + + # Step 2: All-wave parallel prefix sum (cumsum → cumdup). + scratch_mr = SmemPtr(base_ptr, scratch_offset, T.i32, shape=(NUM_WAVES,)).get() + + # All threads read cumsum[tid+1] (in chunks for E > DECODE_BLOCK) + for _ps_chunk in range_constexpr(0, E, DECODE_BLOCK): + ps_eid = fx.Int32(_ps_chunk) + tid + ps_valid = ps_eid < c_E + ps_safe_ix = ArithValue(ps_valid.select(ps_eid + c_one_i32, c_zero_i32)).index_cast(T.index) + ps_val = ps_valid.select(_lds_load(cumsum_mr, ps_safe_ix), c_zero_i32) + _lds_store(cumdup_mr, ps_val, ps_safe_ix) + _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) + gpu.barrier() + + # DPP prefix sum — all NUM_WAVES waves active + ps_tid_valid = tid < c_E + val = ps_tid_valid.select(_lds_load(cumdup_mr, ArithValue(tid + c_one_i32).index_cast(T.index)), c_zero_i32) + val = _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE) + + # Cross-wave accumulation via scratch LDS + is_last_lane_ps = lane == fx.Int32(WARP_SIZE - 1) + _if_ll_ps = scf.IfOp(is_last_lane_ps.ir_value()) + with _if_then(_if_ll_ps): + _lds_store_raw(scratch_mr, val, ArithValue(wave).index_cast(T.index)) + gpu.barrier() + cross_ps = c_zero_i32 + for _w_ps in range_constexpr(NUM_WAVES - 1): + w_ps_total = _lds_load_raw(scratch_mr, ArithValue(fx.Int32(_w_ps)).index_cast(T.index)) + cross_ps = (wave > fx.Int32(_w_ps)).select(cross_ps + w_ps_total, cross_ps) + + inclusive_ps = val + cross_ps + _lds_store( + cumdup_mr, + ps_tid_valid.select(inclusive_ps, c_zero_i32), + ArithValue(ps_tid_valid.select(tid + c_one_i32, c_zero_i32)).index_cast(T.index), + ) + gpu.barrier() + + # For E > DECODE_BLOCK: thread 0 serially extends + if E > DECODE_BLOCK: + _if_t0_ext = scf.IfOp(is_t0.ir_value()) + with _if_then(_if_t0_ext): + prev_ext = _lds_load(cumdup_mr, ArithValue(fx.Int32(DECODE_BLOCK)).index_cast(T.index)) + for _ext in range_constexpr(DECODE_BLOCK, E): + cur_ext = _lds_load(cumdup_mr, ArithValue(fx.Int32(_ext + 1)).index_cast(T.index)) + new_ext = prev_ext + cur_ext + _lds_store(cumdup_mr, new_ext, ArithValue(fx.Int32(_ext + 1)).index_cast(T.index)) + prev_ext = new_ext + gpu.barrier() + + # cumdup[0] = 0 + _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) + gpu.barrier() + + # Write num_valid_ids from cumdup[E] + cs_E_ix_ps = ArithValue(c_E).index_cast(T.index) + total_padded = _lds_load(cumdup_mr, cs_E_ix_ps) + buffer_ops.buffer_store(total_padded, nvalid_rsrc, c_zero_i32) + buffer_ops.buffer_store(tokens, nvalid_rsrc, c_one_i32) + gpu.barrier() + + # Copy cumdup → cumsum (all threads, one expert per thread) + for i_cp in range_constexpr(0, E + 1, DECODE_BLOCK): + cp_idx = fx.Int32(i_cp) + tid + cp_valid = cp_idx <= c_E + safe_cp_idx = cp_valid.select(cp_idx, c_zero_i32) + cp_ix = ArithValue(safe_cp_idx).index_cast(T.index) + cp_val = _lds_load(cumdup_mr, cp_ix) + _lds_store(cumsum_mr, cp_val, cp_ix) + gpu.barrier() + + if has_mask: + # EP: Compute mask cumsum in cumdup for local expert index mapping. + # cumdup[eid] = exclusive prefix sum of mask[0..eid-1] = local expert index. + for i_ml in range_constexpr(0, E, DECODE_BLOCK): + ml_eid = fx.Int32(i_ml) + tid + ml_valid = ml_eid < c_E + safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32) + ml_mask = buffer_ops.buffer_load(mask_rsrc, safe_ml_eid, vec_width=1, dtype=T.i32) + ml_val = ml_valid.select(ml_mask, c_zero_i32) + ml_ix = ArithValue(ml_valid.select(ml_eid + c_one_i32, c_zero_i32)).index_cast(T.index) + _lds_store(cumdup_mr, ml_val, ml_ix) + _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) + gpu.barrier() + + # All-wave DPP prefix sum over mask values in cumdup + m_tid_valid = tid < c_E + mval = m_tid_valid.select( + _lds_load(cumdup_mr, ArithValue(tid + c_one_i32).index_cast(T.index)), c_zero_i32 + ) + mval = _dpp_intra_wave_prefix_sum(mval, lane, WARP_SIZE) + + is_last_lane_m = lane == fx.Int32(WARP_SIZE - 1) + _if_ll_m = scf.IfOp(is_last_lane_m.ir_value()) + with _if_then(_if_ll_m): + _lds_store_raw(scratch_mr, mval, ArithValue(wave).index_cast(T.index)) + gpu.barrier() + cross_m = c_zero_i32 + for _wm in range_constexpr(NUM_WAVES - 1): + wm_total = _lds_load_raw(scratch_mr, ArithValue(fx.Int32(_wm)).index_cast(T.index)) + cross_m = (wave > fx.Int32(_wm)).select(cross_m + wm_total, cross_m) + inclusive_m = mval + cross_m + _lds_store( + cumdup_mr, + m_tid_valid.select(inclusive_m, c_zero_i32), + ArithValue(m_tid_valid.select(tid + c_one_i32, c_zero_i32)).index_cast(T.index), + ) + gpu.barrier() + + if E > DECODE_BLOCK: + _if_t0_ext_m = scf.IfOp(is_t0.ir_value()) + with _if_then(_if_t0_ext_m): + prev_m = _lds_load(cumdup_mr, ArithValue(fx.Int32(DECODE_BLOCK)).index_cast(T.index)) + for _ext_m in range_constexpr(DECODE_BLOCK, E): + cur_m = _lds_load(cumdup_mr, ArithValue(fx.Int32(_ext_m + 1)).index_cast(T.index)) + new_m = prev_m + cur_m + _lds_store(cumdup_mr, new_m, ArithValue(fx.Int32(_ext_m + 1)).index_cast(T.index)) + prev_m = new_m + gpu.barrier() + + _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) + gpu.barrier() + else: + # No mask: cumdup[eid] = eid (identity mapping) + for i_ml in range_constexpr(0, E, DECODE_BLOCK): + ml_eid = fx.Int32(i_ml) + tid + ml_valid = ml_eid < c_E + safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32) + ml_ix = ArithValue(safe_ml_eid).index_cast(T.index) + _lds_store(cumdup_mr, ml_valid.select(safe_ml_eid, c_zero_i32), ml_ix) + gpu.barrier() + + # Write sorted_expert_ids — predicated stores to buffer (safe: buffer_store ignores OOB) + # EP: use cumdup[eid] as local expert index instead of global eid + for i_eid in range_constexpr(0, E, DECODE_BLOCK): + eid_wr = fx.Int32(i_eid) + tid + eid_wr_valid = eid_wr < c_E + safe_eid_wr = eid_wr_valid.select(eid_wr, c_zero_i32) + + cs_start_ix = ArithValue(safe_eid_wr).index_cast(T.index) + cs_end_ix = ArithValue(safe_eid_wr + c_one_i32).index_cast(T.index) + e_start = _lds_load(cumsum_mr, cs_start_ix) + e_end = eid_wr_valid.select(_lds_load(cumsum_mr, cs_end_ix), e_start) + local_eid = _lds_load(cumdup_mr, cs_start_ix) + + # Store cumdup: reuse cumdup for scatter phase position tracking. + # Write e_start to cumdup[eid] (overwriting mask cumsum, no longer needed). + _lds_store(cumdup_mr, e_start, cs_start_ix) + + blk_start = e_start // c_unit + blk_end = e_end // c_unit + n_blks_wr = eid_wr_valid.select(blk_end - blk_start, c_zero_i32) + for _jb in range(fx.Index(0), ArithValue(n_blks_wr).index_cast(T.index), fx.Index(1)): + blk_idx = blk_start + fx.Int32(_jb) + buffer_ops.buffer_store(local_eid, sorted_e_rsrc, blk_idx) + gpu.barrier() + + # Store cumdup[E] = cumsum[E]. + # All threads write cumE to cumdup[E] (all write the same value, no race). + cs_E_ix = ArithValue(c_E).index_cast(T.index) + cumE = _lds_load(cumsum_mr, cs_E_ix) + _lds_store(cumdup_mr, cumE, cs_E_ix) + gpu.barrier() + + # ====================== PRE-FILL: Sentinel fill (cooperative) =========== + # Fill ALL total_padded slots with sentinels BEFORE scatter. + # Scatter will overwrite real token positions. + sentinel_val_pre = c_sentinel | tokens + c_zero_pre = c_zero_i32 + cs_E_ix_pre = ArithValue(c_E).index_cast(T.index) + total_padded_pre = _lds_load(cumdup_mr, cs_E_ix_pre) + n_pre_iters = (total_padded_pre + fx.Int32(DECODE_BLOCK) - c_one_i32) // fx.Int32(DECODE_BLOCK) + for _pre in range(fx.Index(0), ArithValue(n_pre_iters).index_cast(T.index), fx.Index(1)): + pre_slot = fx.Int32(_pre) * fx.Int32(DECODE_BLOCK) + tid + pre_valid = pre_slot < total_padded_pre + safe_pre = pre_valid.select(pre_slot, c_oob_idx) + buffer_ops.buffer_store(sentinel_val_pre, sorted_ids_rsrc, safe_pre) + buffer_ops.buffer_store(c_zero_pre, sorted_w_rsrc, safe_pre) + gpu.barrier() + + # ====================== PHASE 3: Scatter ============================== + # CK uses wave_cumsum<8> (DPP prefix sum) + ds_bpermute (lane broadcast). + # FlyDSL has neither. Instead, each lane reads all 8 mesh values in the + # batch from LDS to compute its own prefix offset. No shuffle needed. + for i_e2 in range_constexpr(0, E, DECODE_BLOCK // 8): + eid_sc = fx.Int32(i_e2) + lane_group_id + eid_sc_valid = eid_sc < c_E + # Invalid lane groups map to cumsum[E] (the total count) instead of + # cumsum[0] to avoid racing with lane_group 0's position write-back. + safe_eid_sc = eid_sc_valid.select(eid_sc, c_E) + + sc_expert_enabled = eid_sc_valid + if has_mask: + # EP: check if this expert is masked (skip scatter for masked experts) + sc_mask_val = buffer_ops.buffer_load( + mask_rsrc, eid_sc_valid.select(eid_sc, c_zero_i32), vec_width=1, dtype=T.i32 + ) + sc_expert_enabled = eid_sc_valid & (sc_mask_val != c_zero_i32) + + cs_sc_ix = ArithValue(safe_eid_sc).index_cast(T.index) + position = _lds_load(cumsum_mr, cs_sc_ix) + + for i_sub2 in range_constexpr(0, sub_tokens, 8): + # This lane handles sub_token (i_sub2 + lane_group_os). + my_sub = fx.Int32(i_sub2) + lane_group_os + my_sub_valid = sc_expert_enabled & (my_sub < c_sub_tokens) + safe_my_sub = my_sub_valid.select(my_sub, c_zero_i32) + my_mesh_addr = safe_my_sub * c_smem_cols + safe_eid_sc + my_mesh_ix = ArithValue(my_mesh_addr).index_cast(T.index) + my_x = _lds_load(mesh_mr, my_mesh_ix) + my_has_token = my_sub_valid & (my_x != c_zero_i32) + local_cnt = my_has_token.select(c_one_i32, c_zero_i32) + + # 8-lane group prefix sum (NOT full-wave — uses lane_group_os, + # only shifts 1,2,4, no cross-row bpermute needed). + cnt_raw = _unwrap(local_cnt) + zero_raw = _unwrap(c_zero_i32) + + # row_shr:1 + remote = fly_rocdl.update_dpp( + T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_1, DPP_ROW_MASK, DPP_BANK_MASK, True + ) + should_add = lane_group_os >= c_one_i32 + local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt) + + # row_shr:2 + cnt_raw = _unwrap(local_cnt) + remote = fly_rocdl.update_dpp( + T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_2, DPP_ROW_MASK, DPP_BANK_MASK, True + ) + should_add = lane_group_os >= fx.Int32(2) + local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt) + + # row_shr:4 + cnt_raw = _unwrap(local_cnt) + remote = fly_rocdl.update_dpp( + T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_4, DPP_ROW_MASK, DPP_BANK_MASK, True + ) + should_add = lane_group_os >= fx.Int32(4) + local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt) + + # Broadcast batch total from last lane of group via ds_bpermute + last_lane_of_group = tid | fx.Int32(7) # tid with lower 3 bits set + last_addr = last_lane_of_group * c4_i32 + batch_total = fly_rocdl.ds_bpermute(T.i32, last_addr, local_cnt) + batch_total = fx.Int32(batch_total) + + # Scatter this lane's token + slot = position + local_cnt - c_one_i32 + safe_x = my_has_token.select(my_x, c_one_i32) + topk_slot_sc = safe_x - c_one_i32 + packed_id = (topk_slot_sc << fx.Int32(24)) | my_sub + safe_slot = my_has_token.select(slot, c_oob_idx) + buffer_ops.buffer_store(packed_id, sorted_ids_rsrc, safe_slot) + + w_addr = my_has_token.select(my_sub * c_topk + topk_slot_sc, c_zero_i32) + w_val_i32 = buffer_ops.buffer_load(weights_rsrc, w_addr, vec_width=1, dtype=T.i32) + buffer_ops.buffer_store(w_val_i32, sorted_w_rsrc, safe_slot) + + # Advance position by batch total + position = position + batch_total + + # Write back updated position (for padding phase). + # Invalid lane groups write position (=0+0=0) to cumsum[0] which is harmless. + _lds_store(cumsum_mr, position, cs_sc_ix) + gpu.barrier() + + # Padding already filled by PRE-FILL phase above (before scatter). + + @flyc.jit + def launch_moe_sorting_decode( + topk_ids_tensor: fx.Tensor, + topk_weights_tensor: fx.Tensor, + sorted_token_ids: fx.Tensor, + sorted_weights_out: fx.Tensor, + sorted_expert_ids: fx.Tensor, + num_valid_ids_out: fx.Tensor, + moe_buf: fx.Tensor, + expert_mask_tensor: fx.Tensor, + i32_tokens: fx.Int32, + i32_moe_buf_elems: fx.Int32, + n_grid_blocks: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = moe_sorting_decode_kernel( + topk_ids_tensor, + topk_weights_tensor, + sorted_token_ids, + sorted_weights_out, + sorted_expert_ids, + num_valid_ids_out, + moe_buf, + expert_mask_tensor, + i32_tokens, + i32_moe_buf_elems, + ) + launcher.launch( + grid=(n_grid_blocks, 1, 1), + block=(DECODE_BLOCK, 1, 1), + stream=stream, + ) + + return launch_moe_sorting_decode + + +# --------------------------------------------------------------------------- +# LDS helpers for prefill kernels (module-level, used inside @flyc.kernel) +# --------------------------------------------------------------------------- +def _lds_load_raw(raw_mr, idx): + """Load i32 from LDS raw memref. idx can be i32 or index.""" + raw_idx = idx.ir_value() if hasattr(idx, "ir_value") else idx + if not isinstance(raw_idx.type, ir.IndexType): + raw_idx = ArithValue(idx).index_cast(T.index) + raw_idx = raw_idx.ir_value() if hasattr(raw_idx, "ir_value") else raw_idx + return fx.Int32(memref_ops.load(raw_mr, [raw_idx])) + + +def _lds_store_raw(raw_mr, val, idx): + """Store i32 to LDS raw memref. idx can be i32 or index.""" + v = val.ir_value() if hasattr(val, "ir_value") else val + raw_idx = idx.ir_value() if hasattr(idx, "ir_value") else idx + if not isinstance(raw_idx.type, ir.IndexType): + raw_idx = ArithValue(idx).index_cast(T.index) + raw_idx = raw_idx.ir_value() if hasattr(raw_idx, "ir_value") else raw_idx + memref_ops.store(v, raw_mr, [raw_idx]) + + +def _unwrap_raw(v): + """Unwrap DSL value to raw MLIR ir.Value.""" + return v.ir_value() if hasattr(v, "ir_value") else v + + +# --------------------------------------------------------------------------- +# FlyDSL GPU kernels — prefill path (4 kernels, large T via HBM workspace) +# --------------------------------------------------------------------------- +@functools.lru_cache(maxsize=256) +def compile_moe_sorting_prefill( + *, + num_experts: int, + topk: int, + unit_size: int = UNIT_SIZE, + has_mask: bool = False, +): + """Compile the prefill-path MoE sorting kernels. + + For token counts exceeding LDS capacity, uses HBM workspace: + K1: ClearWorkspace — zero the workspace buffer + K2: P0 scatter — scatter topk_ids into expert mesh in HBM + K3: P1 count — one block per expert, count non-zero mesh cells + K4: P23 prefix-sum + scatter — prefix-sum on counts, scatter tokens, + fill sorted_expert_ids, zero moe_buf + P0_v2: Fused clear+scatter+count — replaces K1+K2+K3 for small T (<=512) + + Workspace layout (i32 elements): + [0 .. ws_mesh_i32) : uint8 expert mesh (E rows x mesh_stride bytes, packed into i32) + [ws_mesh_i32 .. ws_mesh_i32 + E+1): expert_cumsum (E+1 i32 entries) + + Parameters + ---------- + num_experts : int + Number of routed experts (e.g. 256 for DeepSeek R1). + topk : int + Experts per token (e.g. 8). + unit_size : int + GEMM tile-M for padding alignment (default 32). + """ + arch = get_hip_arch() + E = num_experts + + # --- K1: ClearWorkspace kernel ------------------------------------------- + # CK uses grid=262144, block=1024 (1 store per thread, no loop). + # Match that: block=1024, grid=ceil(ws_total/1024). + K1_BLOCK = 1024 + + @flyc.kernel(known_block_size=[K1_BLOCK, 1, 1]) + def clear_workspace_kernel( + workspace: fx.Tensor, + i32_total_elems: fx.Int32, + ): + gid = gpu.block_idx.x * fx.Int32(K1_BLOCK) + gpu.thread_idx.x + ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True) + c_zero = fx.Int32(0) + + # Each thread stores exactly one element (no loop needed). + valid = gid < i32_total_elems + buffer_ops.buffer_store(c_zero, ws_rsrc, valid.select(gid, c_zero)) + + @flyc.jit + def launch_clear_ws( + workspace: fx.Tensor, + i32_total_elems: fx.Int32, + n_grid: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + launcher = clear_workspace_kernel(workspace, i32_total_elems) + launcher.launch(grid=(n_grid, 1, 1), block=(K1_BLOCK, 1, 1), stream=stream) + + # --- K2: P0 scatter kernel ----------------------------------------------- + # uint8 mesh: stores topk_slot+1 (max 9) as a single byte directly. + # mesh_stride is in bytes; byte_offset = eid * mesh_stride + token_id. + # No two threads write the same byte (unique experts per token). + K2_BLOCK = 256 + + @flyc.kernel + def p0_scatter_kernel( + topk_ids: fx.Tensor, + workspace: fx.Tensor, + i32_tokens: fx.Int32, + i32_mesh_stride: fx.Int32, + i32_niters: fx.Int32, + ): + gid = gpu.block_idx.x * fx.Int32(K2_BLOCK) + gpu.thread_idx.x + stride = gpu.grid_dim.x * fx.Int32(K2_BLOCK) + topk_rsrc = buffer_ops.create_buffer_resource(topk_ids, max_size=True) + ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True) + c_zero = fx.Int32(0) + c_topk = fx.Int32(topk) + c_one = fx.Int32(1) + + total = i32_tokens * c_topk + + _s = fx.Index(0) + _e = ArithValue(i32_niters).index_cast(T.index) + _one = fx.Index(1) + for _i in range(_s, _e, _one): + flat = gid + fx.Int32(_i) * stride + valid = flat < total + safe_flat = valid.select(flat, c_zero) + token_id = safe_flat // c_topk + topk_slot = safe_flat % c_topk + eid = buffer_ops.buffer_load(topk_rsrc, safe_flat, vec_width=1, dtype=T.i32) + byte_offset = eid * i32_mesh_stride + token_id + val_i8 = ArithValue(topk_slot + c_one).trunci(T.i8) + _if_valid_k2 = scf.IfOp(valid.ir_value()) + with _if_then(_if_valid_k2): + buffer_ops.buffer_store(val_i8, ws_rsrc, byte_offset, offset_is_bytes=True) + + @flyc.jit + def launch_p0( + topk_ids: fx.Tensor, + workspace: fx.Tensor, + i32_tokens: fx.Int32, + i32_mesh_stride: fx.Int32, + i32_niters: fx.Int32, + n_grid: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + launcher = p0_scatter_kernel(topk_ids, workspace, i32_tokens, i32_mesh_stride, i32_niters) + launcher.launch(grid=(n_grid, 1, 1), block=(K2_BLOCK, 1, 1), stream=stream) + + # --- K3: P1 count kernel ------------------------------------------------- + # 256 threads (4 waves), vec_width=4: each thread loads 4 i32 words (16 + # mesh cells) per iteration. 4 waves provide 4x memory-level parallelism + # vs the old 1-wave (64-thread) design, matching CK P1's block size. + # Cross-warp reduction via LDS (4 partial sums, one per warp). + K3_BLOCK = 256 + K3_NUM_WAVES = K3_BLOCK // WARP_SIZE + K3_VEC_WIDTH = 4 + K3_WORDS_PER_ITER = K3_BLOCK * K3_VEC_WIDTH + K3_WORDS_PER_ITER_LOG2 = (K3_WORDS_PER_ITER).bit_length() - 1 + + k3_allocator = SmemAllocator(None, arch=arch, global_sym_name="smem_storage_p1") + k3_reduce_offset = k3_allocator._align(k3_allocator.ptr, 16) + k3_allocator.ptr = k3_reduce_offset + K3_NUM_WAVES * 4 + + @flyc.kernel + def p1_count_kernel( + workspace: fx.Tensor, + expert_mask_tensor: fx.Tensor, + i32_mesh_stride: fx.Int32, + i32_mesh_size: fx.Int32, + ): + eid = gpu.block_idx.x + tid = gpu.thread_idx.x + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + + ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True) + c_zero = fx.Int32(0) + c_one = fx.Int32(1) + c_ff = fx.Int32(0xFF) + + base_ptr = k3_allocator.get_base() + reduce_mr = SmemPtr(base_ptr, k3_reduce_offset, T.i32, shape=(K3_NUM_WAVES,)).get() + + mesh_row_i32_base = (eid * i32_mesh_stride) >> fx.Int32(2) + i32_words_per_row = i32_mesh_stride >> fx.Int32(2) + n_iters = (i32_words_per_row + fx.Int32(K3_WORDS_PER_ITER - 1)) >> fx.Int32(K3_WORDS_PER_ITER_LOG2) + + if has_mask: + mask_rsrc = buffer_ops.create_buffer_resource(expert_mask_tensor, max_size=True) + p1_mask = buffer_ops.buffer_load(mask_rsrc, eid, vec_width=1, dtype=T.i32) + p1_is_local = p1_mask != c_zero + p1_should_zero = (~p1_is_local) & (tid == c_zero) + buffer_ops.buffer_store(c_zero, ws_rsrc, p1_should_zero.select(i32_mesh_size + eid, fx.Int32(0x7FFFFFFF))) + n_iters = p1_is_local.select(n_iters, c_zero) + + for _i, state in range(fx.Index(0), ArithValue(n_iters).index_cast(T.index), fx.Index(1), init=[c_zero]): + cnt_so_far = state[0] + + word_base = fx.Int32(_i) * fx.Int32(K3_WORDS_PER_ITER) + tid * fx.Int32(K3_VEC_WIDTH) + valid = word_base < i32_words_per_row + safe_addr = mesh_row_i32_base + valid.select(word_base, c_zero) + vec4 = buffer_ops.buffer_load(ws_rsrc, safe_addr, vec_width=4, dtype=T.i32) + + iter_cnt = c_zero + for _wi in range_constexpr(K3_VEC_WIDTH): + word = Vec(vec4)[_wi] + word_valid = valid & ((word_base + fx.Int32(_wi)) < i32_words_per_row) + b0 = word & c_ff + b1 = (word >> fx.Int32(8)) & c_ff + b2 = (word >> fx.Int32(16)) & c_ff + b3 = (word >> fx.Int32(24)) & c_ff + nz0 = word_valid.select((b0 != c_zero).select(c_one, c_zero), c_zero) + nz1 = word_valid.select((b1 != c_zero).select(c_one, c_zero), c_zero) + nz2 = word_valid.select((b2 != c_zero).select(c_one, c_zero), c_zero) + nz3 = word_valid.select((b3 != c_zero).select(c_one, c_zero), c_zero) + iter_cnt = iter_cnt + nz0 + nz1 + nz2 + nz3 + + new_cnt = cnt_so_far + iter_cnt + results = yield [new_cnt] + cnt = results + + # Intra-warp reduce via shuffle_xor + width_ws = fx.Int32(WARP_SIZE) + for sh in range_constexpr(int.bit_length(WARP_SIZE) - 1): + off = fx.Int32(1 << sh) + peer = cnt.shuffle_xor(off, width_ws) + cnt = cnt + peer + + # Cross-warp reduce via LDS: lane 0 of each warp writes partial sum + is_lane0 = lane == c_zero + _if_l0 = scf.IfOp(is_lane0.ir_value()) + with _if_then(_if_l0): + wave_ix = ArithValue(wave).index_cast(T.index) + _lds_store_raw(reduce_mr, cnt, wave_ix) + gpu.barrier() + + # Thread 0 sums all warp partials and writes to HBM + is_t0 = tid == c_zero + total = c_zero + for _w in range_constexpr(K3_NUM_WAVES): + total = total + _lds_load_raw(reduce_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + + cs_offset = i32_mesh_size + eid + c_oob_idx = fx.Int32(0x7FFFFFFF) + safe_cs = is_t0.select(cs_offset, c_oob_idx) + buffer_ops.buffer_store(total, ws_rsrc, safe_cs) + + @flyc.jit + def launch_p1( + workspace: fx.Tensor, + expert_mask_tensor: fx.Tensor, + i32_mesh_stride: fx.Int32, + i32_mesh_size: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + k3_allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + k3_allocator.finalize() + + launcher = p1_count_kernel(workspace, expert_mask_tensor, i32_mesh_stride, i32_mesh_size) + launcher.launch(grid=(E, 1, 1), block=(K3_BLOCK, 1, 1), stream=stream) + + # --- P0_v2: Fused clear+scatter+count kernel (for small T) --------------- + # Replaces K1+K2+K3 with a single kernel launch. + # Grid: E blocks (one per expert), Block: 512 threads (matching CK P0_v2). + # Phase 1: clear this expert's mesh row + # Phase 2: scan all T*topk assignments, filter by expert, byte stores + # Phase 3: popcount + warp reduce + cross-wave LDS reduce -> expert_cumsum + P0V2_BLOCK = 512 + P0V2_NUM_WAVES = P0V2_BLOCK // WARP_SIZE + + # Power-of-2 topk: use shift to avoid division + _p0v2_topk_is_po2 = (topk & (topk - 1)) == 0 and topk > 0 + _p0v2_topk_log2 = topk.bit_length() - 1 if _p0v2_topk_is_po2 else 0 + + # LDS for cross-wave reduction (same layout as K3) + p0v2_allocator = SmemAllocator(None, arch=arch, global_sym_name="smem_storage_p0v2") + p0v2_reduce_offset = p0v2_allocator._align(p0v2_allocator.ptr, 16) + p0v2_allocator.ptr = p0v2_reduce_offset + P0V2_NUM_WAVES * 4 + + @flyc.kernel(known_block_size=[P0V2_BLOCK, 1, 1]) + def p0v2_kernel( + topk_ids: fx.Tensor, + workspace: fx.Tensor, + expert_mask_tensor: fx.Tensor, + i32_tokens: fx.Int32, + i32_mesh_stride: fx.Int32, + i32_mesh_size: fx.Int32, + ): + eid = gpu.block_idx.x + tid = gpu.thread_idx.x + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + + ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True) + mask_rsrc = buffer_ops.create_buffer_resource(expert_mask_tensor, max_size=True) + topk_rsrc = buffer_ops.create_buffer_resource(topk_ids, max_size=True) + c_zero = fx.Int32(0) + c_oob = fx.Int32(0x7FFFFFFF) + c_one = fx.Int32(1) + c_ff = fx.Int32(0xFF) + c_topk = fx.Int32(topk) + c_block = fx.Int32(P0V2_BLOCK) + + base_ptr = p0v2_allocator.get_base() + reduce_mr = SmemPtr(base_ptr, p0v2_reduce_offset, T.i32, shape=(P0V2_NUM_WAVES,)).get() + + # Precompute mesh row base (in i32 words) and words per row + mesh_row_i32_base = (eid * i32_mesh_stride) >> fx.Int32(2) + i32_words_per_row = i32_mesh_stride >> fx.Int32(2) + + clear_niters = (i32_words_per_row + fx.Int32(P0V2_BLOCK - 1)) >> fx.Int32(9) + total_assignments = i32_tokens * c_topk + scatter_niters = (total_assignments + fx.Int32(P0V2_BLOCK - 1)) >> fx.Int32(9) + + # EP: load mask, write cumsum=0 for masked experts, set loop bounds to 0 + if has_mask: + m_val = buffer_ops.buffer_load(mask_rsrc, eid, vec_width=1, dtype=T.i32) + is_local_expert = m_val != c_zero + should_write_zero = (~is_local_expert) & (tid == c_zero) + buffer_ops.buffer_store(c_zero, ws_rsrc, should_write_zero.select(i32_mesh_size + eid, c_oob)) + clear_niters = is_local_expert.select(clear_niters, c_zero) + scatter_niters = is_local_expert.select(scatter_niters, c_zero) + + # ---- Phase 1: Clear this expert's mesh row ---- + for _ci in range(fx.Index(0), ArithValue(clear_niters).index_cast(T.index), fx.Index(1)): + word_idx = fx.Int32(_ci) * c_block + tid + valid = word_idx < i32_words_per_row + safe_idx = mesh_row_i32_base + valid.select(word_idx, c_zero) + buffer_ops.buffer_store(c_zero, ws_rsrc, valid.select(safe_idx, c_oob)) + + gpu.barrier() + + # ---- Phase 2: Scatter (scan all T*topk, filter by expert) ---- + for _si in range(fx.Index(0), ArithValue(scatter_niters).index_cast(T.index), fx.Index(1)): + flat = fx.Int32(_si) * c_block + tid + valid = flat < total_assignments + safe_flat = valid.select(flat, c_zero) + + token_id = safe_flat >> fx.Int32(_p0v2_topk_log2) if _p0v2_topk_is_po2 else safe_flat // c_topk + topk_slot = safe_flat & fx.Int32(topk - 1) if _p0v2_topk_is_po2 else safe_flat % c_topk + + expert_id = buffer_ops.buffer_load(topk_rsrc, safe_flat, vec_width=1, dtype=T.i32) + + is_mine = valid & (expert_id == eid) + byte_offset = eid * i32_mesh_stride + token_id + val_i8 = ArithValue(is_mine.select(topk_slot + c_one, c_zero)).trunci(T.i8) + # Byte-mode buffer_store with OOB offset crashes on AMD GPUs. + # Use conditional branch to skip the store for non-matching threads. + _if_mine = scf.IfOp(is_mine.ir_value()) + with _if_then(_if_mine): + buffer_ops.buffer_store(val_i8, ws_rsrc, byte_offset, offset_is_bytes=True) + + gpu.barrier() + + # ---- Phase 3: Count non-zero bytes + warp/cross-wave reduce ---- + count_niters = clear_niters # same loop structure, reuse (already EP-gated) + for _ki, state in range(fx.Index(0), ArithValue(count_niters).index_cast(T.index), fx.Index(1), init=[c_zero]): + cnt_so_far = state[0] + + word_base = fx.Int32(_ki) * c_block + tid + valid = word_base < i32_words_per_row + safe_addr = mesh_row_i32_base + valid.select(word_base, c_zero) + word = buffer_ops.buffer_load(ws_rsrc, safe_addr, vec_width=1, dtype=T.i32) + + b0 = word & c_ff + b1 = (word >> fx.Int32(8)) & c_ff + b2 = (word >> fx.Int32(16)) & c_ff + b3 = (word >> fx.Int32(24)) & c_ff + nz0 = valid.select((b0 != c_zero).select(c_one, c_zero), c_zero) + nz1 = valid.select((b1 != c_zero).select(c_one, c_zero), c_zero) + nz2 = valid.select((b2 != c_zero).select(c_one, c_zero), c_zero) + nz3 = valid.select((b3 != c_zero).select(c_one, c_zero), c_zero) + iter_cnt = nz0 + nz1 + nz2 + nz3 + + new_cnt = cnt_so_far + iter_cnt + results = yield [new_cnt] + cnt = results + + # Intra-warp reduce via shuffle_xor + width_ws = fx.Int32(WARP_SIZE) + for sh in range_constexpr(int.bit_length(WARP_SIZE) - 1): + off = fx.Int32(1 << sh) + peer = cnt.shuffle_xor(off, width_ws) + cnt = cnt + peer + + # Cross-warp reduce via LDS: lane 0 of each warp writes partial sum + is_lane0 = lane == c_zero + _if_l0 = scf.IfOp(is_lane0.ir_value()) + with _if_then(_if_l0): + wave_ix = ArithValue(wave).index_cast(T.index) + _lds_store_raw(reduce_mr, cnt, wave_ix) + gpu.barrier() + + # Thread 0 sums all warp partials and writes to HBM + is_t0 = tid == c_zero + total = c_zero + for _w in range_constexpr(P0V2_NUM_WAVES): + total = total + _lds_load_raw(reduce_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + + cs_offset = i32_mesh_size + eid + c_oob_idx = fx.Int32(0x7FFFFFFF) + safe_cs = is_t0.select(cs_offset, c_oob_idx) + buffer_ops.buffer_store(total, ws_rsrc, safe_cs) + + @flyc.jit + def launch_p0v2( + topk_ids: fx.Tensor, + workspace: fx.Tensor, + expert_mask_tensor: fx.Tensor, + i32_tokens: fx.Int32, + i32_mesh_stride: fx.Int32, + i32_mesh_size: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + p0v2_allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + p0v2_allocator.finalize() + + launcher = p0v2_kernel(topk_ids, workspace, expert_mask_tensor, i32_tokens, i32_mesh_stride, i32_mesh_size) + launcher.launch(grid=(E, 1, 1), block=(P0V2_BLOCK, 1, 1), stream=stream) + + # --- K4: P23 prefix-sum + scatter + moe_buf zeroing --------------------- + # Parallel design (matching CK P23): each block [0, E) independently + # computes the SAME prefix sum, then scatters ONLY for expert blockIdx.x. + # No inter-block barrier needed — redundant prefix sums are deterministic. + K4_BLOCK = 256 if E <= 256 else 512 + + # LDS: cumsum[E+1] for prefix sums + cross-wave scratch for DPP scan + K4_NUM_WAVES = K4_BLOCK // WARP_SIZE + k4_allocator = SmemAllocator(None, arch=arch) + k4_smem_cols = max(E + 1, K4_BLOCK + 1) + k4_cumsum_offset = k4_allocator._align(k4_allocator.ptr, 16) + k4_allocator.ptr = k4_cumsum_offset + k4_smem_cols * 4 + k4_scatter_offset = k4_allocator._align(k4_allocator.ptr, 16) + k4_allocator.ptr = k4_scatter_offset + K4_NUM_WAVES * 4 + + @flyc.kernel(known_block_size=[K4_BLOCK, 1, 1]) + def p23_kernel( + workspace: fx.Tensor, + topk_weights_tensor: fx.Tensor, + sorted_token_ids: fx.Tensor, + sorted_weights_out: fx.Tensor, + sorted_expert_ids: fx.Tensor, + num_valid_ids: fx.Tensor, + moe_buf: fx.Tensor, + expert_mask_tensor: fx.Tensor, + i32_tokens: fx.Int32, + i32_mesh_stride: fx.Int32, + i32_mesh_size: fx.Int32, + i32_moe_buf_elems: fx.Int32, + ): + bid = gpu.block_idx.x + tid = gpu.thread_idx.x + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + c_zero = fx.Int32(0) + c_one = fx.Int32(1) + c4 = fx.Int32(4) + c_E = fx.Int32(E) + c_unit = fx.Int32(unit_size) + c_topk = fx.Int32(topk) + c_sentinel = fx.Int32(topk << 24) + c_oob_idx = fx.Int32(0x7FFFFFFF) + c_ff = fx.Int32(0xFF) + + # Buffer resources + ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True) + weights_rsrc = buffer_ops.create_buffer_resource(topk_weights_tensor, max_size=True) + sorted_ids_rsrc = buffer_ops.create_buffer_resource(sorted_token_ids, max_size=True) + sorted_w_rsrc = buffer_ops.create_buffer_resource(sorted_weights_out, max_size=True) + mask_rsrc = buffer_ops.create_buffer_resource(expert_mask_tensor, max_size=True) + + # LDS: cumsum[E+1] for prefix sums + cross-wave scratch + base_ptr = k4_allocator.get_base() + cumsum_mr = SmemPtr(base_ptr, k4_cumsum_offset, T.i32, shape=(k4_smem_cols,)).get() + scatter_mr = SmemPtr(base_ptr, k4_scatter_offset, T.i32, shape=(K4_NUM_WAVES,)).get() + + is_sort_block = bid < c_E + is_zero_block = bid >= c_E + + # ================ MOE_BUF ZEROING (blocks >= E) ================== + _if_zero = scf.IfOp(is_zero_block.ir_value()) + with _if_then(_if_zero): + moe_buf_rsrc = buffer_ops.create_buffer_resource(moe_buf, max_size=True) + zero_base_bid = bid - c_E + zero_gid_v4 = zero_base_bid * fx.Int32(K4_BLOCK) + tid + num_zero_blocks = gpu.grid_dim.x - c_E + zero_stride_v4 = num_zero_blocks * fx.Int32(K4_BLOCK) + i32_moe_buf_v4 = i32_moe_buf_elems >> fx.Int32(2) + zero_niters = (i32_moe_buf_v4 + zero_stride_v4 - c_one) // zero_stride_v4 + c_zero_v4 = fx.Vector.filled(4, 0, fx.Int32) + for _z in range(fx.Index(0), ArithValue(zero_niters).index_cast(T.index), fx.Index(1)): + z_idx_v4 = zero_gid_v4 + fx.Int32(_z) * zero_stride_v4 + z_valid = z_idx_v4 < i32_moe_buf_v4 + z_elem = z_valid.select(z_idx_v4 * c4, c_oob_idx) + buffer_ops.buffer_store(c_zero_v4, moe_buf_rsrc, z_elem) + + # ================ PARALLEL PREFIX-SUM + MESH SCATTER (blocks 0..E-1) == + # Each block independently: prefix sum (redundant), scatter for its expert only. + _if_sort = scf.IfOp(is_sort_block.ir_value()) + with _if_then(_if_sort): + my_expert = bid + + # Step 1: Load expert counts from workspace -> pad to unit_size -> LDS cumsum + # Process E experts in chunks of K4_BLOCK (256). Most models have + # E <= 256, so the extra chunk is only needed for E > 256 + # (e.g. DeepSeek-R1 with 256 routed + 1 shared = 257). + is_t0_init = tid == c_zero + _if_init_cs = scf.IfOp(is_t0_init.ir_value()) + with _if_then(_if_init_cs): + _lds_store_raw(cumsum_mr, c_zero, ArithValue(c_zero).index_cast(T.index)) + + # EP: load this thread's own mask value BEFORE the chunked loop. + # The chunked loop overwrites p23_mask_val in later chunks, so we + # need a stable copy for the mask prefix sum computed after the loop. + my_mask_val = c_one + if has_mask: + tid_has_expert = tid < c_E + my_mask_val = buffer_ops.buffer_load( + mask_rsrc, tid_has_expert.select(tid, c_zero), vec_width=1, dtype=T.i32 + ) + my_mask_val = tid_has_expert.select(my_mask_val, c_zero) + + for _chunk in range_constexpr(0, E, K4_BLOCK): + expert_idx = fx.Int32(_chunk) + tid + tid_valid_expert = expert_idx < c_E + ws_cs_addr = i32_mesh_size + tid_valid_expert.select(expert_idx, c_zero) + raw_cnt = buffer_ops.buffer_load(ws_rsrc, ws_cs_addr, vec_width=1, dtype=T.i32) + raw_cnt = tid_valid_expert.select(raw_cnt, c_zero) + blocks = (raw_cnt + c_unit - c_one) // c_unit + padded = (raw_cnt == c_zero).select(c_zero, blocks * c_unit) + if has_mask: + chunk_mask = buffer_ops.buffer_load( + mask_rsrc, tid_valid_expert.select(expert_idx, c_zero), vec_width=1, dtype=T.i32 + ) + chunk_mask = tid_valid_expert.select(chunk_mask, c_zero) + padded = (chunk_mask == c_zero).select(c_zero, padded) + raw_store_idx = expert_idx + c_one + oob = raw_store_idx >= fx.Int32(k4_smem_cols) + safe_store_idx = oob.select(c_zero, raw_store_idx) + safe_store_val = oob.select(c_zero, padded) + _lds_store_raw(cumsum_mr, safe_store_val, ArithValue(safe_store_idx).index_cast(T.index)) + gpu.barrier() + + # Step 2: Prefix sum over cumsum LDS. When E <= K4_BLOCK (256), + # a single DPP pass covers all experts. When E > K4_BLOCK, we + # do the DPP pass for the first K4_BLOCK elements, then serially + # accumulate the remaining entries from thread 0. + val = _lds_load_raw(cumsum_mr, ArithValue(tid + c_one).index_cast(T.index)) + val = _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE) + + is_last_lane_ps = lane == fx.Int32(WARP_SIZE - 1) + _if_ll_ps = scf.IfOp(is_last_lane_ps.ir_value()) + with _if_then(_if_ll_ps): + _lds_store_raw(scatter_mr, val, ArithValue(wave).index_cast(T.index)) + gpu.barrier() + cross_offset = c_zero + for _w in range_constexpr(K4_NUM_WAVES - 1): + w_total = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + cross_offset = (wave > fx.Int32(_w)).select(cross_offset + w_total, cross_offset) + total_padded = c_zero + for _w in range_constexpr(K4_NUM_WAVES): + total_padded = total_padded + _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + + inclusive_prefix = val + cross_offset + _lds_store_raw(cumsum_mr, inclusive_prefix, ArithValue(tid + c_one).index_cast(T.index)) + gpu.barrier() + + # For E > K4_BLOCK: thread 0 serially extends the prefix sum + # for experts K4_BLOCK..E-1 (at most a few iterations). + if E > K4_BLOCK: + _if_t0_extra = scf.IfOp(is_t0_init.ir_value()) + with _if_then(_if_t0_extra): + prev_sum = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(K4_BLOCK)).index_cast(T.index)) + for _extra in range_constexpr(K4_BLOCK, E): + cur = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(_extra + 1)).index_cast(T.index)) + new_sum = prev_sum + cur + _lds_store_raw(cumsum_mr, new_sum, ArithValue(fx.Int32(_extra + 1)).index_cast(T.index)) + prev_sum = new_sum + total_padded = prev_sum + gpu.barrier() + total_padded = _lds_load_raw(cumsum_mr, ArithValue(c_E).index_cast(T.index)) + + # Read my_start and my_end from cumsum LDS + my_start = _lds_load_raw(cumsum_mr, ArithValue(my_expert).index_cast(T.index)) + my_end = _lds_load_raw(cumsum_mr, ArithValue(my_expert + c_one).index_cast(T.index)) + + local_idx_p23 = tid + if has_mask: + # EP: Compute mask cumsum for local expert index (register-only DPP scan). + # Uses my_mask_val (loaded before the chunked loop) to avoid overwrite. + p23_mv = _dpp_intra_wave_prefix_sum(my_mask_val, lane, WARP_SIZE) + + # Cross-wave via scratch_mr + is_last_lane_pm = lane == fx.Int32(WARP_SIZE - 1) + _if_ll_pm = scf.IfOp(is_last_lane_pm.ir_value()) + with _if_then(_if_ll_pm): + _lds_store_raw(scatter_mr, p23_mv, ArithValue(wave).index_cast(T.index)) + gpu.barrier() + cross_pm = c_zero + for _w in range_constexpr(K4_NUM_WAVES - 1): + wtp = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + cross_pm = (wave > fx.Int32(_w)).select(cross_pm + wtp, cross_pm) + p23_mask_inclusive = p23_mv + cross_pm + local_idx_p23 = p23_mask_inclusive - my_mask_val + else: + local_idx_p23 = tid + + # Block 0, thread 0 writes num_valid_ids + is_b0 = bid == c_zero + is_t0 = tid == c_zero + is_b0_t0 = is_b0 & is_t0 + _if_nv = scf.IfOp(is_b0_t0.ir_value()) + with _if_then(_if_nv): + nvalid_rsrc = buffer_ops.create_buffer_resource(num_valid_ids, max_size=True) + buffer_ops.buffer_store(total_padded, nvalid_rsrc, c_zero) + buffer_ops.buffer_store(i32_tokens, nvalid_rsrc, c_one) + + # Step 3: Write sorted_expert_ids for THIS expert (using local_idx_p23 for EP) + # Store local_idx to LDS cumsum[tid], barrier, read cumsum[my_expert] + _lds_store_raw(cumsum_mr, local_idx_p23, ArithValue(tid).index_cast(T.index)) + # For E > K4_BLOCK: thread 0 writes local_idx for extra experts + if E > K4_BLOCK: + is_t0_extra3 = tid == c_zero + _if_t0_e3 = scf.IfOp(is_t0_extra3.ir_value()) + with _if_then(_if_t0_e3): + if has_mask: + # EP: serially extend mask prefix sum for experts >= K4_BLOCK + prev_local = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(K4_BLOCK - 1)).index_cast(T.index)) + prev_mask = buffer_ops.buffer_load( + mask_rsrc, fx.Int32(K4_BLOCK - 1), vec_width=1, dtype=T.i32 + ) + prev_local = prev_local + prev_mask + for _e3 in range_constexpr(K4_BLOCK, E): + e3_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(_e3), vec_width=1, dtype=T.i32) + _lds_store_raw( + cumsum_mr, prev_local, ArithValue(fx.Int32(_e3)).index_cast(T.index) + ) + prev_local = prev_local + e3_mask + else: + for _e3 in range_constexpr(K4_BLOCK, E): + _lds_store_raw( + cumsum_mr, fx.Int32(_e3), ArithValue(fx.Int32(_e3)).index_cast(T.index) + ) + gpu.barrier() + my_local_idx = _lds_load_raw(cumsum_mr, ArithValue(my_expert).index_cast(T.index)) + + sorted_e_rsrc = buffer_ops.create_buffer_resource(sorted_expert_ids, max_size=True) + blk_start = my_start // c_unit + blk_end = my_end // c_unit + n_blks = blk_end - blk_start + n_eid_iters = (n_blks + fx.Int32(K4_BLOCK) - c_one) // fx.Int32(K4_BLOCK) + for _eii in range(fx.Index(0), ArithValue(n_eid_iters).index_cast(T.index), fx.Index(1)): + blk_idx = blk_start + fx.Int32(_eii) * fx.Int32(K4_BLOCK) + tid + buffer_ops.buffer_store(my_local_idx, sorted_e_rsrc, (blk_idx < blk_end).select(blk_idx, c_oob_idx)) + + # Step 4: Mesh-based scatter — read uint8 mesh from HBM, extract tokens, + # DPP prefix sum over counts, cross-wave LDS reduction, scatter stores. + p23_bid_enabled = c_one != c_zero + if has_mask: + # EP: skip scatter for masked experts (my_start == my_end, but mesh has data) + p23_bid_mask = buffer_ops.buffer_load(mask_rsrc, my_expert, vec_width=1, dtype=T.i32) + p23_bid_enabled = p23_bid_mask != c_zero + + i32_words_per_row = i32_mesh_stride >> fx.Int32(2) + n_mesh_iters_raw = (i32_words_per_row + fx.Int32(K4_BLOCK - 1)) // fx.Int32(K4_BLOCK) + has_work = my_start != my_end + n_mesh_iters = has_work.select(n_mesh_iters_raw, c_zero) + mesh_row_i32_base = (my_expert * i32_mesh_stride) >> fx.Int32(2) + + for _si, state in range( + fx.Index(0), ArithValue(n_mesh_iters).index_cast(T.index), fx.Index(1), init=[my_start] + ): + position = state[0] + + word_idx = fx.Int32(_si) * fx.Int32(K4_BLOCK) + tid + col_valid = p23_bid_enabled & (word_idx < i32_words_per_row) + safe_word_idx = col_valid.select(word_idx, c_zero) + word = buffer_ops.buffer_load(ws_rsrc, mesh_row_i32_base + safe_word_idx, vec_width=1, dtype=T.i32) + + # Extract 4 bytes from the i32 word + x0 = word & c_ff + x1 = (word >> fx.Int32(8)) & c_ff + x2 = (word >> fx.Int32(16)) & c_ff + x3 = (word >> fx.Int32(24)) & c_ff + base_col = word_idx * c4 + + h0 = col_valid & (x0 != c_zero) + h1 = col_valid & (x1 != c_zero) + h2 = col_valid & (x2 != c_zero) + h3 = col_valid & (x3 != c_zero) + + my_cnt = ( + h0.select(c_one, c_zero) + + h1.select(c_one, c_zero) + + h2.select(c_one, c_zero) + + h3.select(c_one, c_zero) + ) + + my_cnt = _dpp_intra_wave_prefix_sum(my_cnt, lane, WARP_SIZE) + + # my_cnt is now intra-wave inclusive prefix sum of per-thread token counts. + # Cross-wave reduction via LDS scratch. + is_last_lane_sc = lane == fx.Int32(WARP_SIZE - 1) + _if_ll_sc = scf.IfOp(is_last_lane_sc.ir_value()) + with _if_then(_if_ll_sc): + _lds_store_raw(scatter_mr, my_cnt, ArithValue(wave).index_cast(T.index)) + gpu.barrier() + + wave_offset = c_zero + for _w in range_constexpr(K4_NUM_WAVES - 1): + w_total = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + wave_offset = (wave > fx.Int32(_w)).select(wave_offset + w_total, wave_offset) + batch_total = c_zero + for _w in range_constexpr(K4_NUM_WAVES): + batch_total = batch_total + _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + gpu.barrier() + + # Convert to exclusive prefix: my_exclusive = my_cnt - my_thread_count + my_thread_count = ( + h0.select(c_one, c_zero) + + h1.select(c_one, c_zero) + + h2.select(c_one, c_zero) + + h3.select(c_one, c_zero) + ) + my_exclusive = my_cnt - my_thread_count + wave_offset + + # Scatter: compute all addresses, batch-load weights, then batch-store. + scatter_base = position + my_exclusive + + # Compute packed IDs and output slots for all 4 tokens + token_id_0 = base_col + topk_slot_0 = h0.select(x0 - c_one, c_zero) + pid_0 = (topk_slot_0 << fx.Int32(24)) | token_id_0 + safe_slot_0 = h0.select(scatter_base, c_oob_idx) + + off1 = scatter_base + h0.select(c_one, c_zero) + token_id_1 = base_col + c_one + topk_slot_1 = h1.select(x1 - c_one, c_zero) + pid_1 = (topk_slot_1 << fx.Int32(24)) | token_id_1 + safe_slot_1 = h1.select(off1, c_oob_idx) + + off2 = off1 + h1.select(c_one, c_zero) + token_id_2 = base_col + fx.Int32(2) + topk_slot_2 = h2.select(x2 - c_one, c_zero) + pid_2 = (topk_slot_2 << fx.Int32(24)) | token_id_2 + safe_slot_2 = h2.select(off2, c_oob_idx) + + off3 = off2 + h2.select(c_one, c_zero) + token_id_3 = base_col + fx.Int32(3) + topk_slot_3 = h3.select(x3 - c_one, c_zero) + pid_3 = (topk_slot_3 << fx.Int32(24)) | token_id_3 + safe_slot_3 = h3.select(off3, c_oob_idx) + + # Batch-issue all 4 weight loads (increases load-use distance) + w_addr_0 = h0.select(token_id_0 * c_topk + topk_slot_0, c_zero) + w_addr_1 = h1.select(token_id_1 * c_topk + topk_slot_1, c_zero) + w_addr_2 = h2.select(token_id_2 * c_topk + topk_slot_2, c_zero) + w_addr_3 = h3.select(token_id_3 * c_topk + topk_slot_3, c_zero) + w_val_0 = buffer_ops.buffer_load(weights_rsrc, w_addr_0, vec_width=1, dtype=T.i32) + w_val_1 = buffer_ops.buffer_load(weights_rsrc, w_addr_1, vec_width=1, dtype=T.i32) + w_val_2 = buffer_ops.buffer_load(weights_rsrc, w_addr_2, vec_width=1, dtype=T.i32) + w_val_3 = buffer_ops.buffer_load(weights_rsrc, w_addr_3, vec_width=1, dtype=T.i32) + + # Batch-store: all packed IDs, then all weights + buffer_ops.buffer_store(pid_0, sorted_ids_rsrc, safe_slot_0) + buffer_ops.buffer_store(pid_1, sorted_ids_rsrc, safe_slot_1) + buffer_ops.buffer_store(pid_2, sorted_ids_rsrc, safe_slot_2) + buffer_ops.buffer_store(pid_3, sorted_ids_rsrc, safe_slot_3) + buffer_ops.buffer_store(w_val_0, sorted_w_rsrc, safe_slot_0) + buffer_ops.buffer_store(w_val_1, sorted_w_rsrc, safe_slot_1) + buffer_ops.buffer_store(w_val_2, sorted_w_rsrc, safe_slot_2) + buffer_ops.buffer_store(w_val_3, sorted_w_rsrc, safe_slot_3) + + pos_next = position + batch_total + results = yield [pos_next] + scatter_end_pos_t0 = results + + # Step 5: Fill padding with sentinel for THIS expert (parallel) + sentinel_val = c_sentinel | i32_tokens + pad_count = my_end - scatter_end_pos_t0 + pad_niters = (pad_count + fx.Int32(K4_BLOCK) - c_one) // fx.Int32(K4_BLOCK) + for _pi in range(fx.Index(0), ArithValue(pad_niters).index_cast(T.index), fx.Index(1)): + pad_slot = scatter_end_pos_t0 + fx.Int32(_pi) * fx.Int32(K4_BLOCK) + tid + pad_valid = pad_slot < my_end + buffer_ops.buffer_store(sentinel_val, sorted_ids_rsrc, pad_valid.select(pad_slot, c_oob_idx)) + buffer_ops.buffer_store(c_zero, sorted_w_rsrc, pad_valid.select(pad_slot, c_oob_idx)) + + @flyc.jit + def launch_p23( + workspace: fx.Tensor, + topk_weights_tensor: fx.Tensor, + sorted_token_ids: fx.Tensor, + sorted_weights_out: fx.Tensor, + sorted_expert_ids: fx.Tensor, + num_valid_ids_out: fx.Tensor, + moe_buf: fx.Tensor, + expert_mask_tensor: fx.Tensor, + i32_tokens: fx.Int32, + i32_mesh_stride: fx.Int32, + i32_mesh_size: fx.Int32, + i32_moe_buf_elems: fx.Int32, + n_grid: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + k4_allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + k4_allocator.finalize() + + launcher = p23_kernel( + workspace, + topk_weights_tensor, + sorted_token_ids, + sorted_weights_out, + sorted_expert_ids, + num_valid_ids_out, + moe_buf, + expert_mask_tensor, + i32_tokens, + i32_mesh_stride, + i32_mesh_size, + i32_moe_buf_elems, + ) + launcher.launch(grid=(n_grid, 1, 1), block=(K4_BLOCK, 1, 1), stream=stream) + + @flyc.jit + def launch_p0v2_p23( + topk_ids: fx.Tensor, + workspace: fx.Tensor, + topk_weights_tensor: fx.Tensor, + sorted_token_ids: fx.Tensor, + sorted_weights_out: fx.Tensor, + sorted_expert_ids: fx.Tensor, + num_valid_ids_out: fx.Tensor, + moe_buf: fx.Tensor, + expert_mask_tensor: fx.Tensor, + i32_tokens: fx.Int32, + i32_mesh_stride: fx.Int32, + i32_mesh_size: fx.Int32, + i32_moe_buf_elems: fx.Int32, + n_grid_p23: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + p0v2_allocator.finalized = False + k4_allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + p0v2_allocator.finalize() + k4_allocator.finalize() + + l1 = p0v2_kernel(topk_ids, workspace, expert_mask_tensor, i32_tokens, i32_mesh_stride, i32_mesh_size) + l1.launch(grid=(E, 1, 1), block=(P0V2_BLOCK, 1, 1), stream=stream) + + l2 = p23_kernel( + workspace, + topk_weights_tensor, + sorted_token_ids, + sorted_weights_out, + sorted_expert_ids, + num_valid_ids_out, + moe_buf, + expert_mask_tensor, + i32_tokens, + i32_mesh_stride, + i32_mesh_size, + i32_moe_buf_elems, + ) + l2.launch(grid=(n_grid_p23, 1, 1), block=(K4_BLOCK, 1, 1), stream=stream) + + @flyc.jit + def launch_4k_fused( + topk_ids: fx.Tensor, + workspace: fx.Tensor, + topk_weights_tensor: fx.Tensor, + sorted_token_ids: fx.Tensor, + sorted_weights_out: fx.Tensor, + sorted_expert_ids: fx.Tensor, + num_valid_ids_out: fx.Tensor, + moe_buf: fx.Tensor, + expert_mask_tensor: fx.Tensor, + i32_tokens: fx.Int32, + i32_mesh_stride: fx.Int32, + i32_mesh_size: fx.Int32, + i32_moe_buf_elems: fx.Int32, + i32_ws_total: fx.Int32, + i32_p0_niters: fx.Int32, + n_grid_k1: fx.Int32, + n_grid_k2: fx.Int32, + n_grid_p23: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + k3_allocator.finalized = False + k4_allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + k3_allocator.finalize() + k4_allocator.finalize() + + l1 = clear_workspace_kernel(workspace, i32_ws_total) + l1.launch(grid=(n_grid_k1, 1, 1), block=(K1_BLOCK, 1, 1), stream=stream) + + l2 = p0_scatter_kernel(topk_ids, workspace, i32_tokens, i32_mesh_stride, i32_p0_niters) + l2.launch(grid=(n_grid_k2, 1, 1), block=(K2_BLOCK, 1, 1), stream=stream) + + l3 = p1_count_kernel(workspace, expert_mask_tensor, i32_mesh_stride, i32_mesh_size) + l3.launch(grid=(E, 1, 1), block=(K3_BLOCK, 1, 1), stream=stream) + + l4 = p23_kernel( + workspace, + topk_weights_tensor, + sorted_token_ids, + sorted_weights_out, + sorted_expert_ids, + num_valid_ids_out, + moe_buf, + expert_mask_tensor, + i32_tokens, + i32_mesh_stride, + i32_mesh_size, + i32_moe_buf_elems, + ) + l4.launch(grid=(n_grid_p23, 1, 1), block=(K4_BLOCK, 1, 1), stream=stream) + + return launch_clear_ws, launch_p0, launch_p1, launch_p23, launch_p0v2, launch_p0v2_p23, launch_4k_fused + + +# Host-side entry point +# --------------------------------------------------------------------------- +@functools.lru_cache(maxsize=64) +def _compute_sub_tokens(num_experts, arch=None): + """Compute the LDS-capacity threshold (sub_tokens) for decode vs prefill decision. + + Returns the max T that fits in LDS for the decode (single-kernel) path. + Same formula as compile_moe_sorting_decode. + """ + if arch is None: + arch = get_hip_arch() + E = num_experts + smem_cols = E + 1 + if arch in ("gfx942",) or str(arch).startswith("gfx94"): + lds_capacity_bytes = 65536 + elif str(arch).startswith("gfx95"): + lds_capacity_bytes = 163840 + else: + lds_capacity_bytes = 65536 + lds_capacity_ints = lds_capacity_bytes // 4 + target_occupancy = 2 + r = lds_capacity_ints // target_occupancy // smem_cols + sub_unroll = 8 + cumsum_bufs = 2 + if r < (cumsum_bufs + sub_unroll): + return 0 # LDS too small — always use prefill + r_for_sub = ((r - cumsum_bufs) // sub_unroll) * sub_unroll + return r_for_sub + + +def moe_sorting_get_workspace_size(M, num_experts, topk, unit_size=UNIT_SIZE): + """Return workspace size (in i32 elements) needed for the prefill path. + Returns 0 if the decode path will be used.""" + sub_tokens = _compute_sub_tokens(num_experts) + DECODE_MAX_T = 16 + if M <= min(sub_tokens, DECODE_MAX_T): + return 0 + mesh_stride = ((M + unit_size - 1) // unit_size) * unit_size + ws_mesh_bytes = num_experts * mesh_stride + ws_mesh_i32 = (ws_mesh_bytes + 3) // 4 + return ws_mesh_i32 + (num_experts + 1) + + +def moe_sorting_flydsl( + topk_ids, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf, + num_experts, + unit_size=UNIT_SIZE, + expert_mask=None, + num_local_tokens=None, + workspace=None, +): + """MoE sorting using FlyDSL kernel (decode + prefill paths). + + API matches aiter.moe_sorting_fwd for drop-in replacement: + moe_sorting_flydsl(topk_ids, topk_weights, + sorted_ids, sorted_weights, sorted_expert_ids, + num_valid_ids, moe_buf, + num_experts, unit_size, expert_mask, + num_local_tokens, workspace) + + All output tensors (sorted_ids, sorted_weights, sorted_expert_ids, + num_valid_ids, moe_buf) must be pre-allocated by the caller. + + Returns + ------- + sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf + """ + topk = topk_ids.shape[1] + if num_local_tokens is not None: + M = num_local_tokens.item() if isinstance(num_local_tokens, torch.Tensor) else int(num_local_tokens) + else: + M = topk_ids.shape[0] + + sub_tokens = _compute_sub_tokens(num_experts) + + device = topk_ids.device + moe_buf_i32 = moe_buf.view(torch.int32) + moe_buf_elems = moe_buf_i32.numel() + + # EP: prepare mask tensor and flag. + has_mask = expert_mask is not None + if not has_mask: + mask_tensor = _dummy_mask_cache.get(device) + if mask_tensor is None: + mask_tensor = torch.ones(1, dtype=torch.int32, device=device) + _dummy_mask_cache[device] = mask_tensor + else: + mask_tensor = expert_mask + + DECODE_MAX_T = 16 + + if M <= min(sub_tokens, DECODE_MAX_T): + max_tokens = max(M, 8) + max_tokens = ((max_tokens + 7) // 8) * 8 + + target_occupancy = 2 + num_cu = torch.cuda.get_device_properties(topk_ids.device).multi_processor_count + n_zero_blocks = min((moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy) + n_grid_blocks = 1 + n_zero_blocks + + launch_moe_sorting_decode_path( + topk_ids, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + mask_tensor, + M, + moe_buf_elems, + n_grid_blocks, + num_experts=num_experts, + topk=topk, + max_tokens=max_tokens, + unit_size=unit_size, + has_mask=has_mask, + ) + else: + # Prefill path: multiple kernels via HBM workspace + mesh_stride = ((M + unit_size - 1) // unit_size) * unit_size + ws_mesh_bytes = num_experts * mesh_stride + ws_mesh_i32 = (ws_mesh_bytes + 3) // 4 + ws_total = ws_mesh_i32 + (num_experts + 1) + if workspace is None: + workspace = torch.empty(ws_total, dtype=torch.int32, device=device) + + launch_moe_sorting_prefill_path( + topk_ids, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + workspace, + mask_tensor, + M, + moe_buf_elems, + mesh_stride, + ws_mesh_i32, + ws_total, + num_experts=num_experts, + topk=topk, + unit_size=unit_size, + has_mask=has_mask, + ) + + return sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf + + +def launch_moe_sorting_decode_path( + topk_ids, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + expert_mask, + i32_tokens, + i32_moe_buf_elems, + n_grid_blocks, + *, + num_experts, + topk, + max_tokens=128, + unit_size=UNIT_SIZE, + has_mask=False, +): + """Low-level launcher for decode path: single kernel. + + This is the hot-path entry point — no torch ops, just JIT dispatch. + Uses AOT-compiled dispatch after the first call to bypass the ~70 us + JIT overhead (inspect.Signature.bind + cache key + dict lookup). + """ + + cache_key = (num_experts, topk, max_tokens, unit_size, has_mask, topk_ids.device.index) + cf = _decode_cf_cache.get(cache_key) + if cf is not None: + stream = torch.cuda.current_stream() + cf( + topk_ids, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + expert_mask, + i32_tokens, + i32_moe_buf_elems, + n_grid_blocks, + fx.Stream(stream), + ) + return + + launch_fn = compile_moe_sorting_decode( + num_experts=num_experts, + topk=topk, + max_tokens=max_tokens, + unit_size=unit_size, + has_mask=has_mask, + ) + stream = torch.cuda.current_stream() + launch_fn( + topk_ids, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + expert_mask, + i32_tokens, + i32_moe_buf_elems, + n_grid_blocks, + stream=stream, + ) + + cf = flyc.compile( + launch_fn, + topk_ids, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + expert_mask, + i32_tokens, + i32_moe_buf_elems, + n_grid_blocks, + fx.Stream(stream), + ) + _decode_cf_cache[cache_key] = cf + + +def launch_moe_sorting_prefill_path( + topk_ids, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + workspace, + expert_mask, + i32_tokens, + i32_moe_buf_elems, + mesh_stride, + mesh_size, + ws_total, + *, + num_experts, + topk, + unit_size=UNIT_SIZE, + has_mask=False, +): + """Low-level launcher for prefill path via HBM workspace. + + For small T (<=512): fused P0_v2 (clear+scatter+count) + K4. + For large T: 4 separate kernels K1+K2+K3+K4. + + Uses AOT-compiled dispatch after the first call for each sub-kernel + to bypass JIT overhead. + """ + + launch_clear_ws, launch_p0, launch_p1, launch_p23, launch_p0v2, launch_p0v2_p23, launch_4k_fused = ( + compile_moe_sorting_prefill( + num_experts=num_experts, + topk=topk, + unit_size=unit_size, + has_mask=has_mask, + ) + ) + + stream = torch.cuda.current_stream() + stream_arg = fx.Stream(stream) + + num_cu = torch.cuda.get_device_properties(topk_ids.device).multi_processor_count + + base_key = (num_experts, topk, unit_size, has_mask, topk_ids.device.index) + use_p0v2 = i32_tokens <= 2048 + + target_occupancy = 2 + n_zero_blocks = min((i32_moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy) + k4_grid = num_experts + n_zero_blocks + + if use_p0v2: + ck = base_key + ("p0v2_p23",) + cf = _prefill_cf_cache.get(ck) + if cf is not None: + cf( + topk_ids, + workspace, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + expert_mask, + i32_tokens, + mesh_stride, + mesh_size, + i32_moe_buf_elems, + k4_grid, + stream_arg, + ) + else: + launch_p0v2_p23( + topk_ids, + workspace, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + expert_mask, + i32_tokens, + mesh_stride, + mesh_size, + i32_moe_buf_elems, + k4_grid, + stream=stream, + ) + cf = flyc.compile( + launch_p0v2_p23, + topk_ids, + workspace, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + expert_mask, + i32_tokens, + mesh_stride, + mesh_size, + i32_moe_buf_elems, + k4_grid, + stream_arg, + ) + _prefill_cf_cache[ck] = cf + return + + # 4-kernel path (T > 2048): fused clear+p0+p1+p23 + k1_grid = (ws_total + 1023) // 1024 + k2_grid = min(num_cu * target_occupancy, (i32_tokens * topk + 255) // 256) + k2_total = i32_tokens * topk + k2_stride = k2_grid * 256 + k2_niters = (k2_total + k2_stride - 1) // k2_stride + + ck = base_key + ("4k_fused",) + cf = _prefill_cf_cache.get(ck) + if cf is not None: + cf( + topk_ids, + workspace, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + expert_mask, + i32_tokens, + mesh_stride, + mesh_size, + i32_moe_buf_elems, + ws_total, + k2_niters, + k1_grid, + k2_grid, + k4_grid, + stream_arg, + ) + else: + launch_4k_fused( + topk_ids, + workspace, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + expert_mask, + i32_tokens, + mesh_stride, + mesh_size, + i32_moe_buf_elems, + ws_total, + k2_niters, + k1_grid, + k2_grid, + k4_grid, + stream=stream, + ) + cf = flyc.compile( + launch_4k_fused, + topk_ids, + workspace, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf_i32, + expert_mask, + i32_tokens, + mesh_stride, + mesh_size, + i32_moe_buf_elems, + ws_total, + k2_niters, + k1_grid, + k2_grid, + k4_grid, + stream_arg, + ) + _prefill_cf_cache[ck] = cf diff --git a/aiter/ops/flydsl/moe_sorting.py b/aiter/ops/flydsl/moe_sorting.py new file mode 100644 index 0000000000..bb3894e283 --- /dev/null +++ b/aiter/ops/flydsl/moe_sorting.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""FlyDSL MoE sorting kernel — drop-in replacement for CK/Opus moe_sorting_fwd. + +Provides `flydsl_moe_sorting_fwd()` with the same signature as +`aiter.moe_sorting_fwd()` so it can be used as a direct dispatch target +in `_moe_sorting_impl()`. + +Workspace is pre-allocated here (not inside the kernel) so that CUDA graph +capture sees deterministic allocations. +""" + +import torch + +_workspace_cache = {} + + +def flydsl_moe_sorting_fwd( + topk_ids, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf, + num_experts, + unit_size, + expert_mask=None, + num_local_tokens=None, +): + from .kernels.moe_sorting_kernel import ( + moe_sorting_flydsl, + moe_sorting_get_workspace_size, + ) + + M = topk_ids.shape[0] + topk = topk_ids.shape[1] + device = topk_ids.device + + # Pre-allocate workspace (cached per config for CUDA graph compatibility) + ws_size = moe_sorting_get_workspace_size(M, num_experts, topk, unit_size) + workspace = None + if ws_size > 0: + cache_key = (ws_size, device) + workspace = _workspace_cache.get(cache_key) + if workspace is None or workspace.numel() < ws_size: + workspace = torch.empty(ws_size, dtype=torch.int32, device=device) + _workspace_cache[cache_key] = workspace + + moe_sorting_flydsl( + topk_ids, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf, + num_experts, + unit_size, + expert_mask, + num_local_tokens, + workspace, + ) From 13e93cf0e71dff11f63418150388b4f5dfe72695 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Tue, 19 May 2026 12:35:16 +0100 Subject: [PATCH 2/5] fix python style --- aiter/fused_moe.py | 36 +- .../ops/flydsl/kernels/moe_sorting_kernel.py | 594 ++++++++++++++---- 2 files changed, 486 insertions(+), 144 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index f23ddc5135..f7f908882b 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -107,8 +107,14 @@ def _moe_sorting_impl( def _flydsl_moe_sorting( - topk_ids, topk_weights, num_experts, model_dim, - moebuf_dtype, block_size, expert_mask, num_local_tokens, + topk_ids, + topk_weights, + num_experts, + model_dim, + moebuf_dtype, + block_size, + expert_mask, + num_local_tokens, ): """FlyDSL sorting dispatch — called outside torch_compile_guard.""" from aiter.ops.flydsl.moe_sorting import flydsl_moe_sorting_fwd @@ -126,11 +132,17 @@ def _flydsl_moe_sorting( moe_buf = torch.empty((M, model_dim), dtype=moebuf_dtype, device=device) flydsl_moe_sorting_fwd( - topk_ids, topk_weights, - sorted_ids, sorted_weights, sorted_expert_ids, - num_valid_ids, moe_buf, - num_experts, int(block_size), - expert_mask, num_local_tokens, + topk_ids, + topk_weights, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + moe_buf, + num_experts, + int(block_size), + expert_mask, + num_local_tokens, ) return sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf @@ -149,8 +161,14 @@ def moe_sorting( ): if _USE_FLYDSL_MOE_SORTING: return _flydsl_moe_sorting( - topk_ids, topk_weights, num_experts, model_dim, - moebuf_dtype, block_size, expert_mask, num_local_tokens, + topk_ids, + topk_weights, + num_experts, + model_dim, + moebuf_dtype, + block_size, + expert_mask, + num_local_tokens, ) try: return _moe_sorting_impl( diff --git a/aiter/ops/flydsl/kernels/moe_sorting_kernel.py b/aiter/ops/flydsl/kernels/moe_sorting_kernel.py index dc403c631a..9bc077d6d3 100644 --- a/aiter/ops/flydsl/kernels/moe_sorting_kernel.py +++ b/aiter/ops/flydsl/kernels/moe_sorting_kernel.py @@ -75,7 +75,9 @@ def _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE): (4, DPP_ROW_SHR_4, 4), (8, DPP_ROW_SHR_8, 8), ]: - remote = fly_rocdl.update_dpp(T.i32, zero_raw, val_raw, dpp_op, DPP_ROW_MASK, DPP_BANK_MASK, True) + remote = fly_rocdl.update_dpp( + T.i32, zero_raw, val_raw, dpp_op, DPP_ROW_MASK, DPP_BANK_MASK, True + ) val = (lane >= fx.Int32(threshold)).select(val + fx.Int32(remote), val) val_raw = _unwrap_val(val) @@ -97,8 +99,12 @@ def _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE): # returns a CompiledFunction whose __call__ skips inspect.Signature.bind, # _make_cache_key, and dict lookup, reducing dispatch from ~70 us to ~5 us. # --------------------------------------------------------------------------- -_decode_cf_cache = {} # (num_experts, topk, max_tokens, unit_size, has_mask, device) -> CompiledFunction -_prefill_cf_cache = {} # (num_experts, topk, unit_size, kernel_name, *constexpr_vals) -> CompiledFunction +_decode_cf_cache = ( + {} +) # (num_experts, topk, max_tokens, unit_size, has_mask, device) -> CompiledFunction +_prefill_cf_cache = ( + {} +) # (num_experts, topk, unit_size, kernel_name, *constexpr_vals) -> CompiledFunction _dummy_mask_cache = {} # device -> torch.Tensor(1, dtype=i32, value=1) @@ -164,7 +170,8 @@ def compile_moe_sorting_decode( cumsum_bufs = 2 if r < (cumsum_bufs + sub_unroll): raise ValueError( - f"LDS too small for E={E}: need at least " f"{(cumsum_bufs + sub_unroll) * smem_cols * 4} bytes" + f"LDS too small for E={E}: need at least " + f"{(cumsum_bufs + sub_unroll) * smem_cols * 4} bytes" ) r_for_sub = ((r - cumsum_bufs) // sub_unroll) * sub_unroll r_token_min = ((max_tokens + sub_unroll - 1) // sub_unroll) * sub_unroll @@ -236,11 +243,21 @@ def moe_sorting_decode_kernel( # Buffer resources (needed by both paths, defined at top level) moe_buf_rsrc = buffer_ops.create_buffer_resource(moe_buf, max_size=True) - topk_ids_rsrc = buffer_ops.create_buffer_resource(topk_ids_tensor, max_size=True) - weights_rsrc = buffer_ops.create_buffer_resource(topk_weights_tensor, max_size=True) - sorted_ids_rsrc = buffer_ops.create_buffer_resource(sorted_token_ids, max_size=True) - sorted_w_rsrc = buffer_ops.create_buffer_resource(sorted_weights_out, max_size=True) - sorted_e_rsrc = buffer_ops.create_buffer_resource(sorted_expert_ids, max_size=True) + topk_ids_rsrc = buffer_ops.create_buffer_resource( + topk_ids_tensor, max_size=True + ) + weights_rsrc = buffer_ops.create_buffer_resource( + topk_weights_tensor, max_size=True + ) + sorted_ids_rsrc = buffer_ops.create_buffer_resource( + sorted_token_ids, max_size=True + ) + sorted_w_rsrc = buffer_ops.create_buffer_resource( + sorted_weights_out, max_size=True + ) + sorted_e_rsrc = buffer_ops.create_buffer_resource( + sorted_expert_ids, max_size=True + ) nvalid_rsrc = buffer_ops.create_buffer_resource(num_valid_ids, max_size=True) mask_rsrc = buffer_ops.create_buffer_resource(expert_mask_tensor, max_size=True) @@ -248,7 +265,9 @@ def moe_sorting_decode_kernel( base_ptr = allocator.get_base() cumsum_mr = SmemPtr(base_ptr, cumsum_offset, T.i32, shape=(smem_cols,)).get() cumdup_mr = SmemPtr(base_ptr, cumdup_offset, T.i32, shape=(smem_cols,)).get() - mesh_mr = SmemPtr(base_ptr, mesh_offset, T.i32, shape=(sub_tokens * smem_cols,)).get() + mesh_mr = SmemPtr( + base_ptr, mesh_offset, T.i32, shape=(sub_tokens * smem_cols,) + ).get() c_topk = fx.Int32(topk) c_E = fx.Int32(E) @@ -265,7 +284,9 @@ def moe_sorting_decode_kernel( num_zero_blocks = gpu.grid_dim.x - c_one_i32 zero_stride_v4 = num_zero_blocks * fx.Int32(DECODE_BLOCK) i32_moe_buf_v4 = i32_moe_buf_elems >> fx.Int32(2) - zero_niters = (i32_moe_buf_v4 + zero_stride_v4 - c_one_i32) // zero_stride_v4 + zero_niters = ( + i32_moe_buf_v4 + zero_stride_v4 - c_one_i32 + ) // zero_stride_v4 _zs = fx.Index(0) _ze = ArithValue(zero_niters).index_cast(T.index) _z1 = fx.Index(1) @@ -303,7 +324,9 @@ def moe_sorting_decode_kernel( topk_slot = safe_flat % c_topk global_idx = token_id * c_topk + topk_slot - eid = buffer_ops.buffer_load(topk_ids_rsrc, global_idx, vec_width=1, dtype=T.i32) + eid = buffer_ops.buffer_load( + topk_ids_rsrc, global_idx, vec_width=1, dtype=T.i32 + ) # mesh[token_id, eid] = topk_slot + 1 (valid threads only). # Invalid threads must NOT write to mesh[0] — that would race @@ -378,7 +401,9 @@ def moe_sorting_decode_kernel( cvt_ix = ArithValue(safe_cvt_idx).index_cast(T.index) raw_cnt_cvt = _lds_load(cumsum_mr, cvt_ix) blocks_cvt = (raw_cnt_cvt + c_unit - c_one_i32) // c_unit - padded_cvt = (raw_cnt_cvt == c_zero_i32).select(c_zero_i32, blocks_cvt * c_unit) + padded_cvt = (raw_cnt_cvt == c_zero_i32).select( + c_zero_i32, blocks_cvt * c_unit + ) # Valid threads write padded value; invalid threads write 0 to cumsum[0] _lds_store(cumsum_mr, cvt_valid.select(padded_cvt, c_zero_i32), cvt_ix) gpu.barrier() @@ -391,20 +416,32 @@ def moe_sorting_decode_kernel( ep_eid = fx.Int32(i_ep) + tid ep_valid = ep_eid < c_E ep_safe_eid = ep_valid.select(ep_eid, c_zero_i32) - ep_m = buffer_ops.buffer_load(mask_rsrc, ep_safe_eid, vec_width=1, dtype=T.i32) + ep_m = buffer_ops.buffer_load( + mask_rsrc, ep_safe_eid, vec_width=1, dtype=T.i32 + ) should_zero = ep_valid & (ep_m == c_zero_i32) - ep_cs_ix = ArithValue(ep_valid.select(ep_eid + c_one_i32, c_zero_i32)).index_cast(T.index) - _lds_store(cumsum_mr, should_zero.select(c_zero_i32, _lds_load(cumsum_mr, ep_cs_ix)), ep_cs_ix) + ep_cs_ix = ArithValue( + ep_valid.select(ep_eid + c_one_i32, c_zero_i32) + ).index_cast(T.index) + _lds_store( + cumsum_mr, + should_zero.select(c_zero_i32, _lds_load(cumsum_mr, ep_cs_ix)), + ep_cs_ix, + ) gpu.barrier() # Step 2: All-wave parallel prefix sum (cumsum → cumdup). - scratch_mr = SmemPtr(base_ptr, scratch_offset, T.i32, shape=(NUM_WAVES,)).get() + scratch_mr = SmemPtr( + base_ptr, scratch_offset, T.i32, shape=(NUM_WAVES,) + ).get() # All threads read cumsum[tid+1] (in chunks for E > DECODE_BLOCK) for _ps_chunk in range_constexpr(0, E, DECODE_BLOCK): ps_eid = fx.Int32(_ps_chunk) + tid ps_valid = ps_eid < c_E - ps_safe_ix = ArithValue(ps_valid.select(ps_eid + c_one_i32, c_zero_i32)).index_cast(T.index) + ps_safe_ix = ArithValue( + ps_valid.select(ps_eid + c_one_i32, c_zero_i32) + ).index_cast(T.index) ps_val = ps_valid.select(_lds_load(cumsum_mr, ps_safe_ix), c_zero_i32) _lds_store(cumdup_mr, ps_val, ps_safe_ix) _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) @@ -412,7 +449,10 @@ def moe_sorting_decode_kernel( # DPP prefix sum — all NUM_WAVES waves active ps_tid_valid = tid < c_E - val = ps_tid_valid.select(_lds_load(cumdup_mr, ArithValue(tid + c_one_i32).index_cast(T.index)), c_zero_i32) + val = ps_tid_valid.select( + _lds_load(cumdup_mr, ArithValue(tid + c_one_i32).index_cast(T.index)), + c_zero_i32, + ) val = _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE) # Cross-wave accumulation via scratch LDS @@ -423,14 +463,20 @@ def moe_sorting_decode_kernel( gpu.barrier() cross_ps = c_zero_i32 for _w_ps in range_constexpr(NUM_WAVES - 1): - w_ps_total = _lds_load_raw(scratch_mr, ArithValue(fx.Int32(_w_ps)).index_cast(T.index)) - cross_ps = (wave > fx.Int32(_w_ps)).select(cross_ps + w_ps_total, cross_ps) + w_ps_total = _lds_load_raw( + scratch_mr, ArithValue(fx.Int32(_w_ps)).index_cast(T.index) + ) + cross_ps = (wave > fx.Int32(_w_ps)).select( + cross_ps + w_ps_total, cross_ps + ) inclusive_ps = val + cross_ps _lds_store( cumdup_mr, ps_tid_valid.select(inclusive_ps, c_zero_i32), - ArithValue(ps_tid_valid.select(tid + c_one_i32, c_zero_i32)).index_cast(T.index), + ArithValue(ps_tid_valid.select(tid + c_one_i32, c_zero_i32)).index_cast( + T.index + ), ) gpu.barrier() @@ -438,11 +484,21 @@ def moe_sorting_decode_kernel( if E > DECODE_BLOCK: _if_t0_ext = scf.IfOp(is_t0.ir_value()) with _if_then(_if_t0_ext): - prev_ext = _lds_load(cumdup_mr, ArithValue(fx.Int32(DECODE_BLOCK)).index_cast(T.index)) + prev_ext = _lds_load( + cumdup_mr, + ArithValue(fx.Int32(DECODE_BLOCK)).index_cast(T.index), + ) for _ext in range_constexpr(DECODE_BLOCK, E): - cur_ext = _lds_load(cumdup_mr, ArithValue(fx.Int32(_ext + 1)).index_cast(T.index)) + cur_ext = _lds_load( + cumdup_mr, + ArithValue(fx.Int32(_ext + 1)).index_cast(T.index), + ) new_ext = prev_ext + cur_ext - _lds_store(cumdup_mr, new_ext, ArithValue(fx.Int32(_ext + 1)).index_cast(T.index)) + _lds_store( + cumdup_mr, + new_ext, + ArithValue(fx.Int32(_ext + 1)).index_cast(T.index), + ) prev_ext = new_ext gpu.barrier() @@ -474,9 +530,13 @@ def moe_sorting_decode_kernel( ml_eid = fx.Int32(i_ml) + tid ml_valid = ml_eid < c_E safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32) - ml_mask = buffer_ops.buffer_load(mask_rsrc, safe_ml_eid, vec_width=1, dtype=T.i32) + ml_mask = buffer_ops.buffer_load( + mask_rsrc, safe_ml_eid, vec_width=1, dtype=T.i32 + ) ml_val = ml_valid.select(ml_mask, c_zero_i32) - ml_ix = ArithValue(ml_valid.select(ml_eid + c_one_i32, c_zero_i32)).index_cast(T.index) + ml_ix = ArithValue( + ml_valid.select(ml_eid + c_one_i32, c_zero_i32) + ).index_cast(T.index) _lds_store(cumdup_mr, ml_val, ml_ix) _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() @@ -484,35 +544,54 @@ def moe_sorting_decode_kernel( # All-wave DPP prefix sum over mask values in cumdup m_tid_valid = tid < c_E mval = m_tid_valid.select( - _lds_load(cumdup_mr, ArithValue(tid + c_one_i32).index_cast(T.index)), c_zero_i32 + _lds_load( + cumdup_mr, ArithValue(tid + c_one_i32).index_cast(T.index) + ), + c_zero_i32, ) mval = _dpp_intra_wave_prefix_sum(mval, lane, WARP_SIZE) is_last_lane_m = lane == fx.Int32(WARP_SIZE - 1) _if_ll_m = scf.IfOp(is_last_lane_m.ir_value()) with _if_then(_if_ll_m): - _lds_store_raw(scratch_mr, mval, ArithValue(wave).index_cast(T.index)) + _lds_store_raw( + scratch_mr, mval, ArithValue(wave).index_cast(T.index) + ) gpu.barrier() cross_m = c_zero_i32 for _wm in range_constexpr(NUM_WAVES - 1): - wm_total = _lds_load_raw(scratch_mr, ArithValue(fx.Int32(_wm)).index_cast(T.index)) + wm_total = _lds_load_raw( + scratch_mr, ArithValue(fx.Int32(_wm)).index_cast(T.index) + ) cross_m = (wave > fx.Int32(_wm)).select(cross_m + wm_total, cross_m) inclusive_m = mval + cross_m _lds_store( cumdup_mr, m_tid_valid.select(inclusive_m, c_zero_i32), - ArithValue(m_tid_valid.select(tid + c_one_i32, c_zero_i32)).index_cast(T.index), + ArithValue( + m_tid_valid.select(tid + c_one_i32, c_zero_i32) + ).index_cast(T.index), ) gpu.barrier() if E > DECODE_BLOCK: _if_t0_ext_m = scf.IfOp(is_t0.ir_value()) with _if_then(_if_t0_ext_m): - prev_m = _lds_load(cumdup_mr, ArithValue(fx.Int32(DECODE_BLOCK)).index_cast(T.index)) + prev_m = _lds_load( + cumdup_mr, + ArithValue(fx.Int32(DECODE_BLOCK)).index_cast(T.index), + ) for _ext_m in range_constexpr(DECODE_BLOCK, E): - cur_m = _lds_load(cumdup_mr, ArithValue(fx.Int32(_ext_m + 1)).index_cast(T.index)) + cur_m = _lds_load( + cumdup_mr, + ArithValue(fx.Int32(_ext_m + 1)).index_cast(T.index), + ) new_m = prev_m + cur_m - _lds_store(cumdup_mr, new_m, ArithValue(fx.Int32(_ext_m + 1)).index_cast(T.index)) + _lds_store( + cumdup_mr, + new_m, + ArithValue(fx.Int32(_ext_m + 1)).index_cast(T.index), + ) prev_m = new_m gpu.barrier() @@ -525,7 +604,9 @@ def moe_sorting_decode_kernel( ml_valid = ml_eid < c_E safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32) ml_ix = ArithValue(safe_ml_eid).index_cast(T.index) - _lds_store(cumdup_mr, ml_valid.select(safe_ml_eid, c_zero_i32), ml_ix) + _lds_store( + cumdup_mr, ml_valid.select(safe_ml_eid, c_zero_i32), ml_ix + ) gpu.barrier() # Write sorted_expert_ids — predicated stores to buffer (safe: buffer_store ignores OOB) @@ -548,7 +629,9 @@ def moe_sorting_decode_kernel( blk_start = e_start // c_unit blk_end = e_end // c_unit n_blks_wr = eid_wr_valid.select(blk_end - blk_start, c_zero_i32) - for _jb in range(fx.Index(0), ArithValue(n_blks_wr).index_cast(T.index), fx.Index(1)): + for _jb in range( + fx.Index(0), ArithValue(n_blks_wr).index_cast(T.index), fx.Index(1) + ): blk_idx = blk_start + fx.Int32(_jb) buffer_ops.buffer_store(local_eid, sorted_e_rsrc, blk_idx) gpu.barrier() @@ -567,8 +650,12 @@ def moe_sorting_decode_kernel( c_zero_pre = c_zero_i32 cs_E_ix_pre = ArithValue(c_E).index_cast(T.index) total_padded_pre = _lds_load(cumdup_mr, cs_E_ix_pre) - n_pre_iters = (total_padded_pre + fx.Int32(DECODE_BLOCK) - c_one_i32) // fx.Int32(DECODE_BLOCK) - for _pre in range(fx.Index(0), ArithValue(n_pre_iters).index_cast(T.index), fx.Index(1)): + n_pre_iters = ( + total_padded_pre + fx.Int32(DECODE_BLOCK) - c_one_i32 + ) // fx.Int32(DECODE_BLOCK) + for _pre in range( + fx.Index(0), ArithValue(n_pre_iters).index_cast(T.index), fx.Index(1) + ): pre_slot = fx.Int32(_pre) * fx.Int32(DECODE_BLOCK) + tid pre_valid = pre_slot < total_padded_pre safe_pre = pre_valid.select(pre_slot, c_oob_idx) @@ -591,7 +678,10 @@ def moe_sorting_decode_kernel( if has_mask: # EP: check if this expert is masked (skip scatter for masked experts) sc_mask_val = buffer_ops.buffer_load( - mask_rsrc, eid_sc_valid.select(eid_sc, c_zero_i32), vec_width=1, dtype=T.i32 + mask_rsrc, + eid_sc_valid.select(eid_sc, c_zero_i32), + vec_width=1, + dtype=T.i32, ) sc_expert_enabled = eid_sc_valid & (sc_mask_val != c_zero_i32) @@ -616,26 +706,50 @@ def moe_sorting_decode_kernel( # row_shr:1 remote = fly_rocdl.update_dpp( - T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_1, DPP_ROW_MASK, DPP_BANK_MASK, True + T.i32, + zero_raw, + cnt_raw, + DPP_ROW_SHR_1, + DPP_ROW_MASK, + DPP_BANK_MASK, + True, ) should_add = lane_group_os >= c_one_i32 - local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt) + local_cnt = should_add.select( + local_cnt + fx.Int32(remote), local_cnt + ) # row_shr:2 cnt_raw = _unwrap(local_cnt) remote = fly_rocdl.update_dpp( - T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_2, DPP_ROW_MASK, DPP_BANK_MASK, True + T.i32, + zero_raw, + cnt_raw, + DPP_ROW_SHR_2, + DPP_ROW_MASK, + DPP_BANK_MASK, + True, ) should_add = lane_group_os >= fx.Int32(2) - local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt) + local_cnt = should_add.select( + local_cnt + fx.Int32(remote), local_cnt + ) # row_shr:4 cnt_raw = _unwrap(local_cnt) remote = fly_rocdl.update_dpp( - T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_4, DPP_ROW_MASK, DPP_BANK_MASK, True + T.i32, + zero_raw, + cnt_raw, + DPP_ROW_SHR_4, + DPP_ROW_MASK, + DPP_BANK_MASK, + True, ) should_add = lane_group_os >= fx.Int32(4) - local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt) + local_cnt = should_add.select( + local_cnt + fx.Int32(remote), local_cnt + ) # Broadcast batch total from last lane of group via ds_bpermute last_lane_of_group = tid | fx.Int32(7) # tid with lower 3 bits set @@ -651,8 +765,12 @@ def moe_sorting_decode_kernel( safe_slot = my_has_token.select(slot, c_oob_idx) buffer_ops.buffer_store(packed_id, sorted_ids_rsrc, safe_slot) - w_addr = my_has_token.select(my_sub * c_topk + topk_slot_sc, c_zero_i32) - w_val_i32 = buffer_ops.buffer_load(weights_rsrc, w_addr, vec_width=1, dtype=T.i32) + w_addr = my_has_token.select( + my_sub * c_topk + topk_slot_sc, c_zero_i32 + ) + w_val_i32 = buffer_ops.buffer_load( + weights_rsrc, w_addr, vec_width=1, dtype=T.i32 + ) buffer_ops.buffer_store(w_val_i32, sorted_w_rsrc, safe_slot) # Advance position by batch total @@ -752,7 +870,7 @@ def compile_moe_sorting_prefill( K3: P1 count — one block per expert, count non-zero mesh cells K4: P23 prefix-sum + scatter — prefix-sum on counts, scatter tokens, fill sorted_expert_ids, zero moe_buf - P0_v2: Fused clear+scatter+count — replaces K1+K2+K3 for small T (<=512) + P0_v2: Fused clear+scatter+count — replaces K1+K2+K3 for T <= 2048 Workspace layout (i32 elements): [0 .. ws_mesh_i32) : uint8 expert mesh (E rows x mesh_stride bytes, packed into i32) @@ -836,7 +954,9 @@ def p0_scatter_kernel( val_i8 = ArithValue(topk_slot + c_one).trunci(T.i8) _if_valid_k2 = scf.IfOp(valid.ir_value()) with _if_then(_if_valid_k2): - buffer_ops.buffer_store(val_i8, ws_rsrc, byte_offset, offset_is_bytes=True) + buffer_ops.buffer_store( + val_i8, ws_rsrc, byte_offset, offset_is_bytes=True + ) @flyc.jit def launch_p0( @@ -848,7 +968,9 @@ def launch_p0( n_grid: fx.Int32, stream: fx.Stream = fx.Stream(None), ): - launcher = p0_scatter_kernel(topk_ids, workspace, i32_tokens, i32_mesh_stride, i32_niters) + launcher = p0_scatter_kernel( + topk_ids, workspace, i32_tokens, i32_mesh_stride, i32_niters + ) launcher.launch(grid=(n_grid, 1, 1), block=(K2_BLOCK, 1, 1), stream=stream) # --- K3: P1 count kernel ------------------------------------------------- @@ -884,24 +1006,41 @@ def p1_count_kernel( c_ff = fx.Int32(0xFF) base_ptr = k3_allocator.get_base() - reduce_mr = SmemPtr(base_ptr, k3_reduce_offset, T.i32, shape=(K3_NUM_WAVES,)).get() + reduce_mr = SmemPtr( + base_ptr, k3_reduce_offset, T.i32, shape=(K3_NUM_WAVES,) + ).get() mesh_row_i32_base = (eid * i32_mesh_stride) >> fx.Int32(2) i32_words_per_row = i32_mesh_stride >> fx.Int32(2) - n_iters = (i32_words_per_row + fx.Int32(K3_WORDS_PER_ITER - 1)) >> fx.Int32(K3_WORDS_PER_ITER_LOG2) + n_iters = (i32_words_per_row + fx.Int32(K3_WORDS_PER_ITER - 1)) >> fx.Int32( + K3_WORDS_PER_ITER_LOG2 + ) if has_mask: - mask_rsrc = buffer_ops.create_buffer_resource(expert_mask_tensor, max_size=True) + mask_rsrc = buffer_ops.create_buffer_resource( + expert_mask_tensor, max_size=True + ) p1_mask = buffer_ops.buffer_load(mask_rsrc, eid, vec_width=1, dtype=T.i32) p1_is_local = p1_mask != c_zero p1_should_zero = (~p1_is_local) & (tid == c_zero) - buffer_ops.buffer_store(c_zero, ws_rsrc, p1_should_zero.select(i32_mesh_size + eid, fx.Int32(0x7FFFFFFF))) + buffer_ops.buffer_store( + c_zero, + ws_rsrc, + p1_should_zero.select(i32_mesh_size + eid, fx.Int32(0x7FFFFFFF)), + ) n_iters = p1_is_local.select(n_iters, c_zero) - for _i, state in range(fx.Index(0), ArithValue(n_iters).index_cast(T.index), fx.Index(1), init=[c_zero]): + for _i, state in range( + fx.Index(0), + ArithValue(n_iters).index_cast(T.index), + fx.Index(1), + init=[c_zero], + ): cnt_so_far = state[0] - word_base = fx.Int32(_i) * fx.Int32(K3_WORDS_PER_ITER) + tid * fx.Int32(K3_VEC_WIDTH) + word_base = fx.Int32(_i) * fx.Int32(K3_WORDS_PER_ITER) + tid * fx.Int32( + K3_VEC_WIDTH + ) valid = word_base < i32_words_per_row safe_addr = mesh_row_i32_base + valid.select(word_base, c_zero) vec4 = buffer_ops.buffer_load(ws_rsrc, safe_addr, vec_width=4, dtype=T.i32) @@ -943,7 +1082,9 @@ def p1_count_kernel( is_t0 = tid == c_zero total = c_zero for _w in range_constexpr(K3_NUM_WAVES): - total = total + _lds_load_raw(reduce_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + total = total + _lds_load_raw( + reduce_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) + ) cs_offset = i32_mesh_size + eid c_oob_idx = fx.Int32(0x7FFFFFFF) @@ -963,10 +1104,12 @@ def launch_p1( with ir.InsertionPoint(ctx.gpu_module_body): k3_allocator.finalize() - launcher = p1_count_kernel(workspace, expert_mask_tensor, i32_mesh_stride, i32_mesh_size) + launcher = p1_count_kernel( + workspace, expert_mask_tensor, i32_mesh_stride, i32_mesh_size + ) launcher.launch(grid=(E, 1, 1), block=(K3_BLOCK, 1, 1), stream=stream) - # --- P0_v2: Fused clear+scatter+count kernel (for small T) --------------- + # --- P0_v2: Fused clear+scatter+count kernel (for T <= 2048) -------------- # Replaces K1+K2+K3 with a single kernel launch. # Grid: E blocks (one per expert), Block: 512 threads (matching CK P0_v2). # Phase 1: clear this expert's mesh row @@ -1009,7 +1152,9 @@ def p0v2_kernel( c_block = fx.Int32(P0V2_BLOCK) base_ptr = p0v2_allocator.get_base() - reduce_mr = SmemPtr(base_ptr, p0v2_reduce_offset, T.i32, shape=(P0V2_NUM_WAVES,)).get() + reduce_mr = SmemPtr( + base_ptr, p0v2_reduce_offset, T.i32, shape=(P0V2_NUM_WAVES,) + ).get() # Precompute mesh row base (in i32 words) and words per row mesh_row_i32_base = (eid * i32_mesh_stride) >> fx.Int32(2) @@ -1024,12 +1169,16 @@ def p0v2_kernel( m_val = buffer_ops.buffer_load(mask_rsrc, eid, vec_width=1, dtype=T.i32) is_local_expert = m_val != c_zero should_write_zero = (~is_local_expert) & (tid == c_zero) - buffer_ops.buffer_store(c_zero, ws_rsrc, should_write_zero.select(i32_mesh_size + eid, c_oob)) + buffer_ops.buffer_store( + c_zero, ws_rsrc, should_write_zero.select(i32_mesh_size + eid, c_oob) + ) clear_niters = is_local_expert.select(clear_niters, c_zero) scatter_niters = is_local_expert.select(scatter_niters, c_zero) # ---- Phase 1: Clear this expert's mesh row ---- - for _ci in range(fx.Index(0), ArithValue(clear_niters).index_cast(T.index), fx.Index(1)): + for _ci in range( + fx.Index(0), ArithValue(clear_niters).index_cast(T.index), fx.Index(1) + ): word_idx = fx.Int32(_ci) * c_block + tid valid = word_idx < i32_words_per_row safe_idx = mesh_row_i32_base + valid.select(word_idx, c_zero) @@ -1038,15 +1187,27 @@ def p0v2_kernel( gpu.barrier() # ---- Phase 2: Scatter (scan all T*topk, filter by expert) ---- - for _si in range(fx.Index(0), ArithValue(scatter_niters).index_cast(T.index), fx.Index(1)): + for _si in range( + fx.Index(0), ArithValue(scatter_niters).index_cast(T.index), fx.Index(1) + ): flat = fx.Int32(_si) * c_block + tid valid = flat < total_assignments safe_flat = valid.select(flat, c_zero) - token_id = safe_flat >> fx.Int32(_p0v2_topk_log2) if _p0v2_topk_is_po2 else safe_flat // c_topk - topk_slot = safe_flat & fx.Int32(topk - 1) if _p0v2_topk_is_po2 else safe_flat % c_topk + token_id = ( + safe_flat >> fx.Int32(_p0v2_topk_log2) + if _p0v2_topk_is_po2 + else safe_flat // c_topk + ) + topk_slot = ( + safe_flat & fx.Int32(topk - 1) + if _p0v2_topk_is_po2 + else safe_flat % c_topk + ) - expert_id = buffer_ops.buffer_load(topk_rsrc, safe_flat, vec_width=1, dtype=T.i32) + expert_id = buffer_ops.buffer_load( + topk_rsrc, safe_flat, vec_width=1, dtype=T.i32 + ) is_mine = valid & (expert_id == eid) byte_offset = eid * i32_mesh_stride + token_id @@ -1055,13 +1216,20 @@ def p0v2_kernel( # Use conditional branch to skip the store for non-matching threads. _if_mine = scf.IfOp(is_mine.ir_value()) with _if_then(_if_mine): - buffer_ops.buffer_store(val_i8, ws_rsrc, byte_offset, offset_is_bytes=True) + buffer_ops.buffer_store( + val_i8, ws_rsrc, byte_offset, offset_is_bytes=True + ) gpu.barrier() # ---- Phase 3: Count non-zero bytes + warp/cross-wave reduce ---- count_niters = clear_niters # same loop structure, reuse (already EP-gated) - for _ki, state in range(fx.Index(0), ArithValue(count_niters).index_cast(T.index), fx.Index(1), init=[c_zero]): + for _ki, state in range( + fx.Index(0), + ArithValue(count_niters).index_cast(T.index), + fx.Index(1), + init=[c_zero], + ): cnt_so_far = state[0] word_base = fx.Int32(_ki) * c_block + tid @@ -1102,7 +1270,9 @@ def p0v2_kernel( is_t0 = tid == c_zero total = c_zero for _w in range_constexpr(P0V2_NUM_WAVES): - total = total + _lds_load_raw(reduce_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + total = total + _lds_load_raw( + reduce_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) + ) cs_offset = i32_mesh_size + eid c_oob_idx = fx.Int32(0x7FFFFFFF) @@ -1124,7 +1294,14 @@ def launch_p0v2( with ir.InsertionPoint(ctx.gpu_module_body): p0v2_allocator.finalize() - launcher = p0v2_kernel(topk_ids, workspace, expert_mask_tensor, i32_tokens, i32_mesh_stride, i32_mesh_size) + launcher = p0v2_kernel( + topk_ids, + workspace, + expert_mask_tensor, + i32_tokens, + i32_mesh_stride, + i32_mesh_size, + ) launcher.launch(grid=(E, 1, 1), block=(P0V2_BLOCK, 1, 1), stream=stream) # --- K4: P23 prefix-sum + scatter + moe_buf zeroing --------------------- @@ -1173,15 +1350,25 @@ def p23_kernel( # Buffer resources ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True) - weights_rsrc = buffer_ops.create_buffer_resource(topk_weights_tensor, max_size=True) - sorted_ids_rsrc = buffer_ops.create_buffer_resource(sorted_token_ids, max_size=True) - sorted_w_rsrc = buffer_ops.create_buffer_resource(sorted_weights_out, max_size=True) + weights_rsrc = buffer_ops.create_buffer_resource( + topk_weights_tensor, max_size=True + ) + sorted_ids_rsrc = buffer_ops.create_buffer_resource( + sorted_token_ids, max_size=True + ) + sorted_w_rsrc = buffer_ops.create_buffer_resource( + sorted_weights_out, max_size=True + ) mask_rsrc = buffer_ops.create_buffer_resource(expert_mask_tensor, max_size=True) # LDS: cumsum[E+1] for prefix sums + cross-wave scratch base_ptr = k4_allocator.get_base() - cumsum_mr = SmemPtr(base_ptr, k4_cumsum_offset, T.i32, shape=(k4_smem_cols,)).get() - scatter_mr = SmemPtr(base_ptr, k4_scatter_offset, T.i32, shape=(K4_NUM_WAVES,)).get() + cumsum_mr = SmemPtr( + base_ptr, k4_cumsum_offset, T.i32, shape=(k4_smem_cols,) + ).get() + scatter_mr = SmemPtr( + base_ptr, k4_scatter_offset, T.i32, shape=(K4_NUM_WAVES,) + ).get() is_sort_block = bid < c_E is_zero_block = bid >= c_E @@ -1197,7 +1384,9 @@ def p23_kernel( i32_moe_buf_v4 = i32_moe_buf_elems >> fx.Int32(2) zero_niters = (i32_moe_buf_v4 + zero_stride_v4 - c_one) // zero_stride_v4 c_zero_v4 = fx.Vector.filled(4, 0, fx.Int32) - for _z in range(fx.Index(0), ArithValue(zero_niters).index_cast(T.index), fx.Index(1)): + for _z in range( + fx.Index(0), ArithValue(zero_niters).index_cast(T.index), fx.Index(1) + ): z_idx_v4 = zero_gid_v4 + fx.Int32(_z) * zero_stride_v4 z_valid = z_idx_v4 < i32_moe_buf_v4 z_elem = z_valid.select(z_idx_v4 * c4, c_oob_idx) @@ -1216,7 +1405,9 @@ def p23_kernel( is_t0_init = tid == c_zero _if_init_cs = scf.IfOp(is_t0_init.ir_value()) with _if_then(_if_init_cs): - _lds_store_raw(cumsum_mr, c_zero, ArithValue(c_zero).index_cast(T.index)) + _lds_store_raw( + cumsum_mr, c_zero, ArithValue(c_zero).index_cast(T.index) + ) # EP: load this thread's own mask value BEFORE the chunked loop. # The chunked loop overwrites p23_mask_val in later chunks, so we @@ -1225,7 +1416,10 @@ def p23_kernel( if has_mask: tid_has_expert = tid < c_E my_mask_val = buffer_ops.buffer_load( - mask_rsrc, tid_has_expert.select(tid, c_zero), vec_width=1, dtype=T.i32 + mask_rsrc, + tid_has_expert.select(tid, c_zero), + vec_width=1, + dtype=T.i32, ) my_mask_val = tid_has_expert.select(my_mask_val, c_zero) @@ -1233,13 +1427,18 @@ def p23_kernel( expert_idx = fx.Int32(_chunk) + tid tid_valid_expert = expert_idx < c_E ws_cs_addr = i32_mesh_size + tid_valid_expert.select(expert_idx, c_zero) - raw_cnt = buffer_ops.buffer_load(ws_rsrc, ws_cs_addr, vec_width=1, dtype=T.i32) + raw_cnt = buffer_ops.buffer_load( + ws_rsrc, ws_cs_addr, vec_width=1, dtype=T.i32 + ) raw_cnt = tid_valid_expert.select(raw_cnt, c_zero) blocks = (raw_cnt + c_unit - c_one) // c_unit padded = (raw_cnt == c_zero).select(c_zero, blocks * c_unit) if has_mask: chunk_mask = buffer_ops.buffer_load( - mask_rsrc, tid_valid_expert.select(expert_idx, c_zero), vec_width=1, dtype=T.i32 + mask_rsrc, + tid_valid_expert.select(expert_idx, c_zero), + vec_width=1, + dtype=T.i32, ) chunk_mask = tid_valid_expert.select(chunk_mask, c_zero) padded = (chunk_mask == c_zero).select(c_zero, padded) @@ -1247,7 +1446,11 @@ def p23_kernel( oob = raw_store_idx >= fx.Int32(k4_smem_cols) safe_store_idx = oob.select(c_zero, raw_store_idx) safe_store_val = oob.select(c_zero, padded) - _lds_store_raw(cumsum_mr, safe_store_val, ArithValue(safe_store_idx).index_cast(T.index)) + _lds_store_raw( + cumsum_mr, + safe_store_val, + ArithValue(safe_store_idx).index_cast(T.index), + ) gpu.barrier() # Step 2: Prefix sum over cumsum LDS. When E <= K4_BLOCK (256), @@ -1264,14 +1467,22 @@ def p23_kernel( gpu.barrier() cross_offset = c_zero for _w in range_constexpr(K4_NUM_WAVES - 1): - w_total = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) - cross_offset = (wave > fx.Int32(_w)).select(cross_offset + w_total, cross_offset) + w_total = _lds_load_raw( + scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) + ) + cross_offset = (wave > fx.Int32(_w)).select( + cross_offset + w_total, cross_offset + ) total_padded = c_zero for _w in range_constexpr(K4_NUM_WAVES): - total_padded = total_padded + _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + total_padded = total_padded + _lds_load_raw( + scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) + ) inclusive_prefix = val + cross_offset - _lds_store_raw(cumsum_mr, inclusive_prefix, ArithValue(tid + c_one).index_cast(T.index)) + _lds_store_raw( + cumsum_mr, inclusive_prefix, ArithValue(tid + c_one).index_cast(T.index) + ) gpu.barrier() # For E > K4_BLOCK: thread 0 serially extends the prefix sum @@ -1279,19 +1490,34 @@ def p23_kernel( if E > K4_BLOCK: _if_t0_extra = scf.IfOp(is_t0_init.ir_value()) with _if_then(_if_t0_extra): - prev_sum = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(K4_BLOCK)).index_cast(T.index)) + prev_sum = _lds_load_raw( + cumsum_mr, ArithValue(fx.Int32(K4_BLOCK)).index_cast(T.index) + ) for _extra in range_constexpr(K4_BLOCK, E): - cur = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(_extra + 1)).index_cast(T.index)) + cur = _lds_load_raw( + cumsum_mr, + ArithValue(fx.Int32(_extra + 1)).index_cast(T.index), + ) new_sum = prev_sum + cur - _lds_store_raw(cumsum_mr, new_sum, ArithValue(fx.Int32(_extra + 1)).index_cast(T.index)) + _lds_store_raw( + cumsum_mr, + new_sum, + ArithValue(fx.Int32(_extra + 1)).index_cast(T.index), + ) prev_sum = new_sum total_padded = prev_sum gpu.barrier() - total_padded = _lds_load_raw(cumsum_mr, ArithValue(c_E).index_cast(T.index)) + total_padded = _lds_load_raw( + cumsum_mr, ArithValue(c_E).index_cast(T.index) + ) # Read my_start and my_end from cumsum LDS - my_start = _lds_load_raw(cumsum_mr, ArithValue(my_expert).index_cast(T.index)) - my_end = _lds_load_raw(cumsum_mr, ArithValue(my_expert + c_one).index_cast(T.index)) + my_start = _lds_load_raw( + cumsum_mr, ArithValue(my_expert).index_cast(T.index) + ) + my_end = _lds_load_raw( + cumsum_mr, ArithValue(my_expert + c_one).index_cast(T.index) + ) local_idx_p23 = tid if has_mask: @@ -1303,11 +1529,15 @@ def p23_kernel( is_last_lane_pm = lane == fx.Int32(WARP_SIZE - 1) _if_ll_pm = scf.IfOp(is_last_lane_pm.ir_value()) with _if_then(_if_ll_pm): - _lds_store_raw(scatter_mr, p23_mv, ArithValue(wave).index_cast(T.index)) + _lds_store_raw( + scatter_mr, p23_mv, ArithValue(wave).index_cast(T.index) + ) gpu.barrier() cross_pm = c_zero for _w in range_constexpr(K4_NUM_WAVES - 1): - wtp = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + wtp = _lds_load_raw( + scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) + ) cross_pm = (wave > fx.Int32(_w)).select(cross_pm + wtp, cross_pm) p23_mask_inclusive = p23_mv + cross_pm local_idx_p23 = p23_mask_inclusive - my_mask_val @@ -1320,13 +1550,17 @@ def p23_kernel( is_b0_t0 = is_b0 & is_t0 _if_nv = scf.IfOp(is_b0_t0.ir_value()) with _if_then(_if_nv): - nvalid_rsrc = buffer_ops.create_buffer_resource(num_valid_ids, max_size=True) + nvalid_rsrc = buffer_ops.create_buffer_resource( + num_valid_ids, max_size=True + ) buffer_ops.buffer_store(total_padded, nvalid_rsrc, c_zero) buffer_ops.buffer_store(i32_tokens, nvalid_rsrc, c_one) # Step 3: Write sorted_expert_ids for THIS expert (using local_idx_p23 for EP) # Store local_idx to LDS cumsum[tid], barrier, read cumsum[my_expert] - _lds_store_raw(cumsum_mr, local_idx_p23, ArithValue(tid).index_cast(T.index)) + _lds_store_raw( + cumsum_mr, local_idx_p23, ArithValue(tid).index_cast(T.index) + ) # For E > K4_BLOCK: thread 0 writes local_idx for extra experts if E > K4_BLOCK: is_t0_extra3 = tid == c_zero @@ -1334,57 +1568,85 @@ def p23_kernel( with _if_then(_if_t0_e3): if has_mask: # EP: serially extend mask prefix sum for experts >= K4_BLOCK - prev_local = _lds_load_raw(cumsum_mr, ArithValue(fx.Int32(K4_BLOCK - 1)).index_cast(T.index)) + prev_local = _lds_load_raw( + cumsum_mr, + ArithValue(fx.Int32(K4_BLOCK - 1)).index_cast(T.index), + ) prev_mask = buffer_ops.buffer_load( mask_rsrc, fx.Int32(K4_BLOCK - 1), vec_width=1, dtype=T.i32 ) prev_local = prev_local + prev_mask for _e3 in range_constexpr(K4_BLOCK, E): - e3_mask = buffer_ops.buffer_load(mask_rsrc, fx.Int32(_e3), vec_width=1, dtype=T.i32) + e3_mask = buffer_ops.buffer_load( + mask_rsrc, fx.Int32(_e3), vec_width=1, dtype=T.i32 + ) _lds_store_raw( - cumsum_mr, prev_local, ArithValue(fx.Int32(_e3)).index_cast(T.index) + cumsum_mr, + prev_local, + ArithValue(fx.Int32(_e3)).index_cast(T.index), ) prev_local = prev_local + e3_mask else: for _e3 in range_constexpr(K4_BLOCK, E): _lds_store_raw( - cumsum_mr, fx.Int32(_e3), ArithValue(fx.Int32(_e3)).index_cast(T.index) + cumsum_mr, + fx.Int32(_e3), + ArithValue(fx.Int32(_e3)).index_cast(T.index), ) gpu.barrier() - my_local_idx = _lds_load_raw(cumsum_mr, ArithValue(my_expert).index_cast(T.index)) + my_local_idx = _lds_load_raw( + cumsum_mr, ArithValue(my_expert).index_cast(T.index) + ) - sorted_e_rsrc = buffer_ops.create_buffer_resource(sorted_expert_ids, max_size=True) + sorted_e_rsrc = buffer_ops.create_buffer_resource( + sorted_expert_ids, max_size=True + ) blk_start = my_start // c_unit blk_end = my_end // c_unit n_blks = blk_end - blk_start n_eid_iters = (n_blks + fx.Int32(K4_BLOCK) - c_one) // fx.Int32(K4_BLOCK) - for _eii in range(fx.Index(0), ArithValue(n_eid_iters).index_cast(T.index), fx.Index(1)): + for _eii in range( + fx.Index(0), ArithValue(n_eid_iters).index_cast(T.index), fx.Index(1) + ): blk_idx = blk_start + fx.Int32(_eii) * fx.Int32(K4_BLOCK) + tid - buffer_ops.buffer_store(my_local_idx, sorted_e_rsrc, (blk_idx < blk_end).select(blk_idx, c_oob_idx)) + buffer_ops.buffer_store( + my_local_idx, + sorted_e_rsrc, + (blk_idx < blk_end).select(blk_idx, c_oob_idx), + ) # Step 4: Mesh-based scatter — read uint8 mesh from HBM, extract tokens, # DPP prefix sum over counts, cross-wave LDS reduction, scatter stores. p23_bid_enabled = c_one != c_zero if has_mask: # EP: skip scatter for masked experts (my_start == my_end, but mesh has data) - p23_bid_mask = buffer_ops.buffer_load(mask_rsrc, my_expert, vec_width=1, dtype=T.i32) + p23_bid_mask = buffer_ops.buffer_load( + mask_rsrc, my_expert, vec_width=1, dtype=T.i32 + ) p23_bid_enabled = p23_bid_mask != c_zero i32_words_per_row = i32_mesh_stride >> fx.Int32(2) - n_mesh_iters_raw = (i32_words_per_row + fx.Int32(K4_BLOCK - 1)) // fx.Int32(K4_BLOCK) + n_mesh_iters_raw = (i32_words_per_row + fx.Int32(K4_BLOCK - 1)) // fx.Int32( + K4_BLOCK + ) has_work = my_start != my_end n_mesh_iters = has_work.select(n_mesh_iters_raw, c_zero) mesh_row_i32_base = (my_expert * i32_mesh_stride) >> fx.Int32(2) for _si, state in range( - fx.Index(0), ArithValue(n_mesh_iters).index_cast(T.index), fx.Index(1), init=[my_start] + fx.Index(0), + ArithValue(n_mesh_iters).index_cast(T.index), + fx.Index(1), + init=[my_start], ): position = state[0] word_idx = fx.Int32(_si) * fx.Int32(K4_BLOCK) + tid col_valid = p23_bid_enabled & (word_idx < i32_words_per_row) safe_word_idx = col_valid.select(word_idx, c_zero) - word = buffer_ops.buffer_load(ws_rsrc, mesh_row_i32_base + safe_word_idx, vec_width=1, dtype=T.i32) + word = buffer_ops.buffer_load( + ws_rsrc, mesh_row_i32_base + safe_word_idx, vec_width=1, dtype=T.i32 + ) # Extract 4 bytes from the i32 word x0 = word & c_ff @@ -1412,16 +1674,24 @@ def p23_kernel( is_last_lane_sc = lane == fx.Int32(WARP_SIZE - 1) _if_ll_sc = scf.IfOp(is_last_lane_sc.ir_value()) with _if_then(_if_ll_sc): - _lds_store_raw(scatter_mr, my_cnt, ArithValue(wave).index_cast(T.index)) + _lds_store_raw( + scatter_mr, my_cnt, ArithValue(wave).index_cast(T.index) + ) gpu.barrier() wave_offset = c_zero for _w in range_constexpr(K4_NUM_WAVES - 1): - w_total = _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) - wave_offset = (wave > fx.Int32(_w)).select(wave_offset + w_total, wave_offset) + w_total = _lds_load_raw( + scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) + ) + wave_offset = (wave > fx.Int32(_w)).select( + wave_offset + w_total, wave_offset + ) batch_total = c_zero for _w in range_constexpr(K4_NUM_WAVES): - batch_total = batch_total + _lds_load_raw(scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index)) + batch_total = batch_total + _lds_load_raw( + scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) + ) gpu.barrier() # Convert to exclusive prefix: my_exclusive = my_cnt - my_thread_count @@ -1465,10 +1735,18 @@ def p23_kernel( w_addr_1 = h1.select(token_id_1 * c_topk + topk_slot_1, c_zero) w_addr_2 = h2.select(token_id_2 * c_topk + topk_slot_2, c_zero) w_addr_3 = h3.select(token_id_3 * c_topk + topk_slot_3, c_zero) - w_val_0 = buffer_ops.buffer_load(weights_rsrc, w_addr_0, vec_width=1, dtype=T.i32) - w_val_1 = buffer_ops.buffer_load(weights_rsrc, w_addr_1, vec_width=1, dtype=T.i32) - w_val_2 = buffer_ops.buffer_load(weights_rsrc, w_addr_2, vec_width=1, dtype=T.i32) - w_val_3 = buffer_ops.buffer_load(weights_rsrc, w_addr_3, vec_width=1, dtype=T.i32) + w_val_0 = buffer_ops.buffer_load( + weights_rsrc, w_addr_0, vec_width=1, dtype=T.i32 + ) + w_val_1 = buffer_ops.buffer_load( + weights_rsrc, w_addr_1, vec_width=1, dtype=T.i32 + ) + w_val_2 = buffer_ops.buffer_load( + weights_rsrc, w_addr_2, vec_width=1, dtype=T.i32 + ) + w_val_3 = buffer_ops.buffer_load( + weights_rsrc, w_addr_3, vec_width=1, dtype=T.i32 + ) # Batch-store: all packed IDs, then all weights buffer_ops.buffer_store(pid_0, sorted_ids_rsrc, safe_slot_0) @@ -1488,11 +1766,17 @@ def p23_kernel( sentinel_val = c_sentinel | i32_tokens pad_count = my_end - scatter_end_pos_t0 pad_niters = (pad_count + fx.Int32(K4_BLOCK) - c_one) // fx.Int32(K4_BLOCK) - for _pi in range(fx.Index(0), ArithValue(pad_niters).index_cast(T.index), fx.Index(1)): + for _pi in range( + fx.Index(0), ArithValue(pad_niters).index_cast(T.index), fx.Index(1) + ): pad_slot = scatter_end_pos_t0 + fx.Int32(_pi) * fx.Int32(K4_BLOCK) + tid pad_valid = pad_slot < my_end - buffer_ops.buffer_store(sentinel_val, sorted_ids_rsrc, pad_valid.select(pad_slot, c_oob_idx)) - buffer_ops.buffer_store(c_zero, sorted_w_rsrc, pad_valid.select(pad_slot, c_oob_idx)) + buffer_ops.buffer_store( + sentinel_val, sorted_ids_rsrc, pad_valid.select(pad_slot, c_oob_idx) + ) + buffer_ops.buffer_store( + c_zero, sorted_w_rsrc, pad_valid.select(pad_slot, c_oob_idx) + ) @flyc.jit def launch_p23( @@ -1557,7 +1841,14 @@ def launch_p0v2_p23( p0v2_allocator.finalize() k4_allocator.finalize() - l1 = p0v2_kernel(topk_ids, workspace, expert_mask_tensor, i32_tokens, i32_mesh_stride, i32_mesh_size) + l1 = p0v2_kernel( + topk_ids, + workspace, + expert_mask_tensor, + i32_tokens, + i32_mesh_stride, + i32_mesh_size, + ) l1.launch(grid=(E, 1, 1), block=(P0V2_BLOCK, 1, 1), stream=stream) l2 = p23_kernel( @@ -1608,10 +1899,14 @@ def launch_4k_fused( l1 = clear_workspace_kernel(workspace, i32_ws_total) l1.launch(grid=(n_grid_k1, 1, 1), block=(K1_BLOCK, 1, 1), stream=stream) - l2 = p0_scatter_kernel(topk_ids, workspace, i32_tokens, i32_mesh_stride, i32_p0_niters) + l2 = p0_scatter_kernel( + topk_ids, workspace, i32_tokens, i32_mesh_stride, i32_p0_niters + ) l2.launch(grid=(n_grid_k2, 1, 1), block=(K2_BLOCK, 1, 1), stream=stream) - l3 = p1_count_kernel(workspace, expert_mask_tensor, i32_mesh_stride, i32_mesh_size) + l3 = p1_count_kernel( + workspace, expert_mask_tensor, i32_mesh_stride, i32_mesh_size + ) l3.launch(grid=(E, 1, 1), block=(K3_BLOCK, 1, 1), stream=stream) l4 = p23_kernel( @@ -1630,7 +1925,15 @@ def launch_4k_fused( ) l4.launch(grid=(n_grid_p23, 1, 1), block=(K4_BLOCK, 1, 1), stream=stream) - return launch_clear_ws, launch_p0, launch_p1, launch_p23, launch_p0v2, launch_p0v2_p23, launch_4k_fused + return ( + launch_clear_ws, + launch_p0, + launch_p1, + launch_p23, + launch_p0v2, + launch_p0v2_p23, + launch_4k_fused, + ) # Host-side entry point @@ -1708,7 +2011,11 @@ def moe_sorting_flydsl( """ topk = topk_ids.shape[1] if num_local_tokens is not None: - M = num_local_tokens.item() if isinstance(num_local_tokens, torch.Tensor) else int(num_local_tokens) + M = ( + num_local_tokens.item() + if isinstance(num_local_tokens, torch.Tensor) + else int(num_local_tokens) + ) else: M = topk_ids.shape[0] @@ -1736,7 +2043,9 @@ def moe_sorting_flydsl( target_occupancy = 2 num_cu = torch.cuda.get_device_properties(topk_ids.device).multi_processor_count - n_zero_blocks = min((moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy) + n_zero_blocks = min( + (moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy + ) n_grid_blocks = 1 + n_zero_blocks launch_moe_sorting_decode_path( @@ -1816,7 +2125,14 @@ def launch_moe_sorting_decode_path( JIT overhead (inspect.Signature.bind + cache key + dict lookup). """ - cache_key = (num_experts, topk, max_tokens, unit_size, has_mask, topk_ids.device.index) + cache_key = ( + num_experts, + topk, + max_tokens, + unit_size, + has_mask, + topk_ids.device.index, + ) cf = _decode_cf_cache.get(cache_key) if cf is not None: stream = torch.cuda.current_stream() @@ -1900,20 +2216,26 @@ def launch_moe_sorting_prefill_path( ): """Low-level launcher for prefill path via HBM workspace. - For small T (<=512): fused P0_v2 (clear+scatter+count) + K4. - For large T: 4 separate kernels K1+K2+K3+K4. + For small T (<=2048): fused P0_v2 (clear+scatter+count) + K4. + For large T (>2048): 4 separate kernels K1+K2+K3+K4. Uses AOT-compiled dispatch after the first call for each sub-kernel to bypass JIT overhead. """ - launch_clear_ws, launch_p0, launch_p1, launch_p23, launch_p0v2, launch_p0v2_p23, launch_4k_fused = ( - compile_moe_sorting_prefill( - num_experts=num_experts, - topk=topk, - unit_size=unit_size, - has_mask=has_mask, - ) + ( + launch_clear_ws, + launch_p0, + launch_p1, + launch_p23, + launch_p0v2, + launch_p0v2_p23, + launch_4k_fused, + ) = compile_moe_sorting_prefill( + num_experts=num_experts, + topk=topk, + unit_size=unit_size, + has_mask=has_mask, ) stream = torch.cuda.current_stream() @@ -1925,7 +2247,9 @@ def launch_moe_sorting_prefill_path( use_p0v2 = i32_tokens <= 2048 target_occupancy = 2 - n_zero_blocks = min((i32_moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy) + n_zero_blocks = min( + (i32_moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy + ) k4_grid = num_experts + n_zero_blocks if use_p0v2: From 0c602efbf7618fabfa22620ac891359428c1726e Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Tue, 19 May 2026 14:27:57 +0100 Subject: [PATCH 3/5] update hardcoded parameter --- aiter/ops/flydsl/kernels/moe_sorting_kernel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiter/ops/flydsl/kernels/moe_sorting_kernel.py b/aiter/ops/flydsl/kernels/moe_sorting_kernel.py index 9bc077d6d3..f568fdd24c 100644 --- a/aiter/ops/flydsl/kernels/moe_sorting_kernel.py +++ b/aiter/ops/flydsl/kernels/moe_sorting_kernel.py @@ -1970,7 +1970,7 @@ def moe_sorting_get_workspace_size(M, num_experts, topk, unit_size=UNIT_SIZE): """Return workspace size (in i32 elements) needed for the prefill path. Returns 0 if the decode path will be used.""" sub_tokens = _compute_sub_tokens(num_experts) - DECODE_MAX_T = 16 + DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) if M <= min(sub_tokens, DECODE_MAX_T): return 0 mesh_stride = ((M + unit_size - 1) // unit_size) * unit_size @@ -2035,7 +2035,7 @@ def moe_sorting_flydsl( else: mask_tensor = expert_mask - DECODE_MAX_T = 16 + DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) if M <= min(sub_tokens, DECODE_MAX_T): max_tokens = max(M, 8) From 1b7f30e0e606de0f2ec1cdda7b3bba220cf87a29 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Tue, 19 May 2026 14:38:03 +0100 Subject: [PATCH 4/5] add barrier for large E --- aiter/ops/flydsl/kernels/moe_sorting_kernel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aiter/ops/flydsl/kernels/moe_sorting_kernel.py b/aiter/ops/flydsl/kernels/moe_sorting_kernel.py index f568fdd24c..131899e641 100644 --- a/aiter/ops/flydsl/kernels/moe_sorting_kernel.py +++ b/aiter/ops/flydsl/kernels/moe_sorting_kernel.py @@ -1561,8 +1561,10 @@ def p23_kernel( _lds_store_raw( cumsum_mr, local_idx_p23, ArithValue(tid).index_cast(T.index) ) - # For E > K4_BLOCK: thread 0 writes local_idx for extra experts + # For E > K4_BLOCK: thread 0 extends local_idx using cumsum[K4_BLOCK-1]. + # Barrier ensures all threads have written before thread 0 reads. if E > K4_BLOCK: + gpu.barrier() is_t0_extra3 = tid == c_zero _if_t0_e3 = scf.IfOp(is_t0_extra3.ir_value()) with _if_then(_if_t0_e3): From 7f2b38f55258fbd8dbb45bb0b82e6f11c45755c1 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Tue, 26 May 2026 11:03:37 +0100 Subject: [PATCH 5/5] update the changes in FlyDSL --- .../ops/flydsl/kernels/moe_sorting_kernel.py | 1434 +++++++---------- 1 file changed, 544 insertions(+), 890 deletions(-) diff --git a/aiter/ops/flydsl/kernels/moe_sorting_kernel.py b/aiter/ops/flydsl/kernels/moe_sorting_kernel.py index 131899e641..622d251c44 100644 --- a/aiter/ops/flydsl/kernels/moe_sorting_kernel.py +++ b/aiter/ops/flydsl/kernels/moe_sorting_kernel.py @@ -9,9 +9,10 @@ Algorithm: counting sort in LDS (histogram → prefix-sum → scatter). -Two paths: - - Decode (small T): single kernel, all phases in LDS. - - Prefill (large T): 4 kernels via HBM workspace (ClearWS → P0 scatter → P1 count → P23 prefix-sum+scatter). +Three paths (selected by T vs ONESHOT_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, E//8)))): + - Oneshot (T <= ONESHOT_MAX_T): single kernel, all phases in LDS. + - Multiphase/2k (ONESHOT_MAX_T < T <= 2048): 2 kernels (fused P0v2 + P23) via HBM workspace. + - Multiphase/4k (T > 2048): 4 kernels (ClearWS → P0 scatter → P1 count → P23) via HBM workspace. Packed token ID format: (topk_position << 24) | token_id - Upper 8 bits: topk slot (0..topk-1) @@ -20,7 +21,6 @@ """ import functools -from contextlib import contextmanager import torch @@ -28,7 +28,6 @@ import flydsl.expr as fx from flydsl._mlir import ir from flydsl._mlir.dialects import memref as memref_ops -from flydsl._mlir.dialects import scf from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import buffer_ops, gpu, range_constexpr from flydsl.expr import rocdl as fly_rocdl @@ -43,7 +42,7 @@ UNIT_SIZE = 32 # GEMM tile-M, aka block_size in CK WARP_SIZE = get_warp_size() -# DPP constants for prefix sum (used by decode and prefill) +# DPP constants for prefix sum (used by oneshot and multiphase) DPP_ROW_SHR_1 = 0x111 DPP_ROW_SHR_2 = 0x112 DPP_ROW_SHR_4 = 0x114 @@ -93,38 +92,120 @@ def _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE): return val +@flyc.jit +def _allwave_inclusive_prefix_sum(val, lane, wave, scratch_mr, NUM_WAVES, WARP_SIZE): + """DPP intra-wave prefix sum + cross-wave LDS accumulation. + + Returns (intra_wave_val, inclusive) where intra_wave_val is the per-wave + result (needed for total_padded computation) and inclusive is the full + cross-wave inclusive prefix sum. + """ + val = _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE) + if lane == fx.Int32(WARP_SIZE - 1): + _lds_store_raw(scratch_mr, val, wave) + gpu.barrier() + cross = fx.Int32(0) + for _w in range_constexpr(NUM_WAVES - 1): + wt = _lds_load_raw(scratch_mr, fx.Int32(_w)) + cross = (wave > fx.Int32(_w)).select(cross + wt, cross) + return val, val + cross + + +@flyc.jit +def _zero_moe_buf_grid_stride(moe_buf_rsrc, gid_v4, stride_v4, total_v4, oob_idx): + """Grid-stride loop zeroing moe_buf via vectorized buffer_store.""" + c_one = fx.Int32(1) + niters = (total_v4 + stride_v4 - c_one) // stride_v4 + c_zero_v4 = fx.Vector.filled(4, 0, fx.Int32) + c4 = fx.Int32(4) + for _z in range(fx.Index(0), ArithValue(niters).index_cast(T.index), fx.Index(1)): + idx = gid_v4 + fx.Int32(_z) * stride_v4 + valid = idx < total_v4 + buffer_ops.buffer_store( + c_zero_v4, moe_buf_rsrc, valid.select(idx * c4, oob_idx) + ) + + +def _extend_prefix_sum_serial(mr, start_block, E, load_fn, store_fn): + """Thread-0 serial extension of prefix sum for experts >= start_block. + + Reads mr[start_block], then accumulates mr[start_block+1..E] in place. + Returns the final accumulated value (mr[E]). + """ + prev = load_fn(mr, fx.Int32(start_block)) + for _ext in range_constexpr(start_block, E): + cur = load_fn(mr, fx.Int32(_ext + 1)) + new_val = prev + cur + store_fn(mr, new_val, fx.Int32(_ext + 1)) + prev = new_val + return prev + + +@flyc.jit +def _write_expert_id_blocks(sorted_e_rsrc, local_eid, blk_start, n_blks): + """Write local_eid to sorted_expert_ids[blk_start .. blk_start+n_blks).""" + for _jb in range(fx.Index(0), ArithValue(n_blks).index_cast(T.index), fx.Index(1)): + blk_idx = blk_start + fx.Int32(_jb) + buffer_ops.buffer_store(local_eid, sorted_e_rsrc, blk_idx) + + +@flyc.jit +def _fill_sentinel_slots( + sorted_ids_rsrc, sorted_w_rsrc, start, count, sentinel, block_size, tid, oob_idx +): + """Cooperative sentinel fill: threads fill [start, start+count) with sentinels.""" + c_zero = fx.Int32(0) + end = start + count + niters = (count + fx.Int32(block_size) - fx.Int32(1)) // fx.Int32(block_size) + for _p in range(fx.Index(0), ArithValue(niters).index_cast(T.index), fx.Index(1)): + slot = start + fx.Int32(_p) * fx.Int32(block_size) + tid + safe = (slot < end).select(slot, oob_idx) + buffer_ops.buffer_store(sentinel, sorted_ids_rsrc, safe) + buffer_ops.buffer_store(c_zero, sorted_w_rsrc, safe) + + +# --------------------------------------------------------------------------- +# LDS helpers for multiphase kernels (module-level, used inside @flyc.kernel) +# --------------------------------------------------------------------------- +def _lds_load_raw(raw_mr, idx): + """Load i32 from LDS raw memref. idx can be i32 or index.""" + raw_idx = idx.ir_value() if hasattr(idx, "ir_value") else idx + if not isinstance(raw_idx.type, ir.IndexType): + raw_idx = ArithValue(idx).index_cast(T.index) + raw_idx = raw_idx.ir_value() if hasattr(raw_idx, "ir_value") else raw_idx + return fx.Int32(memref_ops.load(raw_mr, [raw_idx])) + + +def _lds_store_raw(raw_mr, val, idx): + """Store i32 to LDS raw memref. idx can be i32 or index.""" + v = val.ir_value() if hasattr(val, "ir_value") else val + raw_idx = idx.ir_value() if hasattr(idx, "ir_value") else idx + if not isinstance(raw_idx.type, ir.IndexType): + raw_idx = ArithValue(idx).index_cast(T.index) + raw_idx = raw_idx.ir_value() if hasattr(raw_idx, "ir_value") else raw_idx + memref_ops.store(v, raw_mr, [raw_idx]) + + # --------------------------------------------------------------------------- # AOT-compiled dispatch caches — keyed by constexpr values. # After the first JIT call (which compiles the kernel), flyc.compile() # returns a CompiledFunction whose __call__ skips inspect.Signature.bind, # _make_cache_key, and dict lookup, reducing dispatch from ~70 us to ~5 us. # --------------------------------------------------------------------------- -_decode_cf_cache = ( +_oneshot_cf_cache = ( {} ) # (num_experts, topk, max_tokens, unit_size, has_mask, device) -> CompiledFunction -_prefill_cf_cache = ( +_multiphase_cf_cache = ( {} ) # (num_experts, topk, unit_size, kernel_name, *constexpr_vals) -> CompiledFunction _dummy_mask_cache = {} # device -> torch.Tensor(1, dtype=i32, value=1) -@contextmanager -def _if_then(if_op): - """Context manager for scf.IfOp then-region (from moe_gemm_2stage.py).""" - with ir.InsertionPoint(if_op.then_block): - try: - yield if_op.then_block - finally: - blk = if_op.then_block - if (not blk.operations) or not isinstance(blk.operations[-1], scf.YieldOp): - scf.YieldOp([]) - - # --------------------------------------------------------------------------- -# FlyDSL GPU kernel — decode path (single kernel, SubTokenOneShot) +# FlyDSL GPU kernel — oneshot path (single kernel, all phases in LDS) # --------------------------------------------------------------------------- @functools.lru_cache(maxsize=256) -def compile_moe_sorting_decode( +def _compile_moe_sorting_oneshot( *, num_experts: int, topk: int, @@ -132,7 +213,7 @@ def compile_moe_sorting_decode( unit_size: int = UNIT_SIZE, has_mask: bool = False, ): - """Compile the decode-path MoE sorting kernel. + """Compile the oneshot MoE sorting kernel (single kernel, all phases in LDS). Parameters ---------- @@ -148,10 +229,9 @@ def compile_moe_sorting_decode( arch = get_hip_arch() E = num_experts # CDNA (warp64): 512 threads = 8 waves, affordable cross-wave reduction. - # RDNA (warp32): cap at 256 to avoid 16-wave overhead. - max_decode_block = 512 if WARP_SIZE == 64 else 256 - DECODE_BLOCK = 256 if E <= 256 else min(512, max_decode_block) - NUM_WAVES = DECODE_BLOCK // WARP_SIZE + max_oneshot_block = 512 if WARP_SIZE == 64 else 256 + ONESHOT_BLOCK = 256 if E <= 256 else min(512, max_oneshot_block) + NUM_WAVES = ONESHOT_BLOCK // WARP_SIZE smem_cols = E + 1 # LDS sizing: sub_tokens rows for the token×expert histogram @@ -170,8 +250,7 @@ def compile_moe_sorting_decode( cumsum_bufs = 2 if r < (cumsum_bufs + sub_unroll): raise ValueError( - f"LDS too small for E={E}: need at least " - f"{(cumsum_bufs + sub_unroll) * smem_cols * 4} bytes" + f"LDS too small for E={E}: need at least {(cumsum_bufs + sub_unroll) * smem_cols * 4} bytes" ) r_for_sub = ((r - cumsum_bufs) // sub_unroll) * sub_unroll r_token_min = ((max_tokens + sub_unroll - 1) // sub_unroll) * sub_unroll @@ -197,29 +276,8 @@ def compile_moe_sorting_decode( scratch_offset = allocator._align(allocator.ptr, 16) allocator.ptr = scratch_offset + NUM_WAVES * 4 - # Helpers for raw memref LDS access (used inside scf.for instead of SmemPtr) - def _to_index(v): - """Convert i32 or index DSL value to raw MLIR index value.""" - raw = v.ir_value() if hasattr(v, "ir_value") else v - if isinstance(raw.type, ir.IndexType): - return raw - return ArithValue(v).index_cast(T.index) - - def _lds_load(raw_mr, idx): - """Load i32 from LDS raw memref. idx can be i32 or index.""" - return fx.Int32(memref_ops.load(raw_mr, [_to_index(idx)])) - - def _lds_store(raw_mr, val, idx): - """Store i32 to LDS raw memref. idx can be i32 or index.""" - v = val.ir_value() if hasattr(val, "ir_value") else val - memref_ops.store(v, raw_mr, [_to_index(idx)]) - - def _unwrap(v): - """Unwrap DSL value to raw MLIR ir.Value for scf.for init.""" - return v.ir_value() if hasattr(v, "ir_value") else v - - @flyc.kernel(known_block_size=[DECODE_BLOCK, 1, 1]) - def moe_sorting_decode_kernel( + @flyc.kernel(known_block_size=[ONESHOT_BLOCK, 1, 1]) + def moe_sorting_oneshot_kernel( topk_ids_tensor: fx.Tensor, topk_weights_tensor: fx.Tensor, sorted_token_ids: fx.Tensor, @@ -277,45 +335,34 @@ def moe_sorting_decode_kernel( c_sentinel = fx.Int32((topk << 24)) # =================== MOE_BUF ZEROING (blocks > 0 only) =============== - is_zero_block = bid != c_zero_i32 - _if_zero = scf.IfOp(is_zero_block.ir_value()) - with _if_then(_if_zero): - zero_gid_v4 = (bid - c_one_i32) * fx.Int32(DECODE_BLOCK) + tid + if bid != c_zero_i32: + zero_gid_v4 = (bid - c_one_i32) * fx.Int32(ONESHOT_BLOCK) + tid num_zero_blocks = gpu.grid_dim.x - c_one_i32 - zero_stride_v4 = num_zero_blocks * fx.Int32(DECODE_BLOCK) - i32_moe_buf_v4 = i32_moe_buf_elems >> fx.Int32(2) - zero_niters = ( - i32_moe_buf_v4 + zero_stride_v4 - c_one_i32 - ) // zero_stride_v4 - _zs = fx.Index(0) - _ze = ArithValue(zero_niters).index_cast(T.index) - _z1 = fx.Index(1) - c_zero_v4 = fx.Vector.filled(4, 0, fx.Int32) - c4_i32 = fx.Int32(4) - for _z in range(_zs, _ze, _z1): - z_idx_v4 = zero_gid_v4 + fx.Int32(_z) * zero_stride_v4 - z_valid = z_idx_v4 < i32_moe_buf_v4 - z_elem = z_valid.select(z_idx_v4 * c4_i32, c_oob_idx) - buffer_ops.buffer_store(c_zero_v4, moe_buf_rsrc, z_elem) + zero_stride_v4 = num_zero_blocks * fx.Int32(ONESHOT_BLOCK) + _zero_moe_buf_grid_stride( + moe_buf_rsrc, + zero_gid_v4, + zero_stride_v4, + i32_moe_buf_elems >> fx.Int32(2), + c_oob_idx, + ) # =================== SORTING (block 0 only) ========================== - is_sort_block = bid == c_zero_i32 - _if_sort = scf.IfOp(is_sort_block.ir_value()) - with _if_then(_if_sort): + if bid == c_zero_i32: # ========================= PHASE 1: Histogram ========================= # Clear mesh region — unconditional store to safe index when out of bounds - for i_clear in range_constexpr(0, sub_tokens * smem_cols, DECODE_BLOCK): + for i_clear in range_constexpr(0, sub_tokens * smem_cols, ONESHOT_BLOCK): idx = fx.Int32(i_clear) + tid is_valid = idx < fx.Int32(sub_tokens * smem_cols) safe_idx = is_valid.select(idx, c_zero_i32) safe_idx_ix = ArithValue(safe_idx).index_cast(T.index) # Always store; out-of-bounds threads harmlessly write to index 0 - _lds_store(mesh_mr, c_zero_i32, safe_idx_ix) + _lds_store_raw(mesh_mr, c_zero_i32, safe_idx_ix) gpu.barrier() # Fill mesh: for each (token, topk_slot), write topk_slot+1 to mesh[token, expert_id] total_assignments = tokens * c_topk - for i_assign in range_constexpr(0, max_tokens * topk, DECODE_BLOCK): + for i_assign in range_constexpr(0, max_tokens * topk, ONESHOT_BLOCK): flat_idx = fx.Int32(i_assign) + tid is_valid = flat_idx < total_assignments safe_flat = is_valid.select(flat_idx, c_zero_i32) @@ -336,7 +383,7 @@ def moe_sorting_decode_kernel( safe_mesh_addr = is_valid.select(mesh_addr, last_mesh_idx) safe_mesh_ix = ArithValue(safe_mesh_addr).index_cast(T.index) val = is_valid.select(topk_slot + c_one_i32, c_zero_i32) - _lds_store(mesh_mr, val, safe_mesh_ix) + _lds_store_raw(mesh_mr, val, safe_mesh_ix) gpu.barrier() # ===================== PHASE 2: Count + Prefix Sum ===================== @@ -349,10 +396,10 @@ def moe_sorting_decode_kernel( # Initialize cumsum[0] = 0. All threads write 0 so there's no # read-modify-write race across waves. - _lds_store(cumsum_mr, c_zero_i32, c_zero_i32) + _lds_store_raw(cumsum_mr, c_zero_i32, c_zero_i32) gpu.barrier() - for i_e in range_constexpr(0, E, DECODE_BLOCK // 8): + for i_e in range_constexpr(0, E, ONESHOT_BLOCK // 8): eid_local = fx.Int32(i_e) + lane_group_id eid_valid = eid_local < c_E @@ -366,7 +413,7 @@ def moe_sorting_decode_kernel( safe_eid = combined_valid.select(eid_local, c_zero_i32) mesh_rd_addr = safe_sub * c_smem_cols + safe_eid mesh_rd_ix = ArithValue(mesh_rd_addr).index_cast(T.index) - mesh_val = _lds_load(mesh_mr, mesh_rd_ix) + mesh_val = _lds_load_raw(mesh_mr, mesh_rd_ix) has_token = combined_valid.select( (mesh_val != c_zero_i32).select(c_one_i32, c_zero_i32), @@ -388,31 +435,33 @@ def moe_sorting_decode_kernel( cs_idx = write_valid.select(eid_local + c_one_i32, c_zero_i32) cs_ix = ArithValue(cs_idx).index_cast(T.index) cs_val = write_valid.select(cnt, c_zero_i32) - _lds_store(cumsum_mr, cs_val, cs_ix) + _lds_store_raw(cumsum_mr, cs_val, cs_ix) gpu.barrier() # Phase 2b: Prefix sum over expert counts. # Step 1: Each thread converts its expert's raw count → padded block size. - for i_cvt in range_constexpr(0, E, DECODE_BLOCK): + for i_cvt in range_constexpr(0, E, ONESHOT_BLOCK): cvt_eid = fx.Int32(i_cvt) + tid cvt_valid = cvt_eid < c_E # Safe index: valid → cumsum[eid+1], invalid → cumsum[0] (write 0, harmless) safe_cvt_idx = cvt_valid.select(cvt_eid + c_one_i32, c_zero_i32) cvt_ix = ArithValue(safe_cvt_idx).index_cast(T.index) - raw_cnt_cvt = _lds_load(cumsum_mr, cvt_ix) + raw_cnt_cvt = _lds_load_raw(cumsum_mr, cvt_ix) blocks_cvt = (raw_cnt_cvt + c_unit - c_one_i32) // c_unit padded_cvt = (raw_cnt_cvt == c_zero_i32).select( c_zero_i32, blocks_cvt * c_unit ) # Valid threads write padded value; invalid threads write 0 to cumsum[0] - _lds_store(cumsum_mr, cvt_valid.select(padded_cvt, c_zero_i32), cvt_ix) + _lds_store_raw( + cumsum_mr, cvt_valid.select(padded_cvt, c_zero_i32), cvt_ix + ) gpu.barrier() if has_mask: # EP: zero padded count for masked experts in a separate pass. # Loading from mask buffer inside the padded-count loop above interfered # with expert 0 (MLIR codegen issue). Separate pass avoids this. - for i_ep in range_constexpr(0, E, DECODE_BLOCK): + for i_ep in range_constexpr(0, E, ONESHOT_BLOCK): ep_eid = fx.Int32(i_ep) + tid ep_valid = ep_eid < c_E ep_safe_eid = ep_valid.select(ep_eid, c_zero_i32) @@ -423,9 +472,11 @@ def moe_sorting_decode_kernel( ep_cs_ix = ArithValue( ep_valid.select(ep_eid + c_one_i32, c_zero_i32) ).index_cast(T.index) - _lds_store( + _lds_store_raw( cumsum_mr, - should_zero.select(c_zero_i32, _lds_load(cumsum_mr, ep_cs_ix)), + should_zero.select( + c_zero_i32, _lds_load_raw(cumsum_mr, ep_cs_ix) + ), ep_cs_ix, ) gpu.barrier() @@ -435,43 +486,29 @@ def moe_sorting_decode_kernel( base_ptr, scratch_offset, T.i32, shape=(NUM_WAVES,) ).get() - # All threads read cumsum[tid+1] (in chunks for E > DECODE_BLOCK) - for _ps_chunk in range_constexpr(0, E, DECODE_BLOCK): + # All threads read cumsum[tid+1] (in chunks for E > ONESHOT_BLOCK) + for _ps_chunk in range_constexpr(0, E, ONESHOT_BLOCK): ps_eid = fx.Int32(_ps_chunk) + tid ps_valid = ps_eid < c_E ps_safe_ix = ArithValue( ps_valid.select(ps_eid + c_one_i32, c_zero_i32) ).index_cast(T.index) - ps_val = ps_valid.select(_lds_load(cumsum_mr, ps_safe_ix), c_zero_i32) - _lds_store(cumdup_mr, ps_val, ps_safe_ix) - _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) + ps_val = ps_valid.select( + _lds_load_raw(cumsum_mr, ps_safe_ix), c_zero_i32 + ) + _lds_store_raw(cumdup_mr, ps_val, ps_safe_ix) + _lds_store_raw(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() # DPP prefix sum — all NUM_WAVES waves active ps_tid_valid = tid < c_E val = ps_tid_valid.select( - _lds_load(cumdup_mr, ArithValue(tid + c_one_i32).index_cast(T.index)), - c_zero_i32, + _lds_load_raw(cumdup_mr, tid + c_one_i32), c_zero_i32 ) - val = _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE) - - # Cross-wave accumulation via scratch LDS - is_last_lane_ps = lane == fx.Int32(WARP_SIZE - 1) - _if_ll_ps = scf.IfOp(is_last_lane_ps.ir_value()) - with _if_then(_if_ll_ps): - _lds_store_raw(scratch_mr, val, ArithValue(wave).index_cast(T.index)) - gpu.barrier() - cross_ps = c_zero_i32 - for _w_ps in range_constexpr(NUM_WAVES - 1): - w_ps_total = _lds_load_raw( - scratch_mr, ArithValue(fx.Int32(_w_ps)).index_cast(T.index) - ) - cross_ps = (wave > fx.Int32(_w_ps)).select( - cross_ps + w_ps_total, cross_ps - ) - - inclusive_ps = val + cross_ps - _lds_store( + _, inclusive_ps = _allwave_inclusive_prefix_sum( + val, lane, wave, scratch_mr, NUM_WAVES, WARP_SIZE + ) + _lds_store_raw( cumdup_mr, ps_tid_valid.select(inclusive_ps, c_zero_i32), ArithValue(ps_tid_valid.select(tid + c_one_i32, c_zero_i32)).index_cast( @@ -480,53 +517,39 @@ def moe_sorting_decode_kernel( ) gpu.barrier() - # For E > DECODE_BLOCK: thread 0 serially extends - if E > DECODE_BLOCK: - _if_t0_ext = scf.IfOp(is_t0.ir_value()) - with _if_then(_if_t0_ext): - prev_ext = _lds_load( - cumdup_mr, - ArithValue(fx.Int32(DECODE_BLOCK)).index_cast(T.index), + # For E > ONESHOT_BLOCK: thread 0 serially extends + if E > ONESHOT_BLOCK: + if is_t0: + _extend_prefix_sum_serial( + cumdup_mr, ONESHOT_BLOCK, E, _lds_load_raw, _lds_store_raw ) - for _ext in range_constexpr(DECODE_BLOCK, E): - cur_ext = _lds_load( - cumdup_mr, - ArithValue(fx.Int32(_ext + 1)).index_cast(T.index), - ) - new_ext = prev_ext + cur_ext - _lds_store( - cumdup_mr, - new_ext, - ArithValue(fx.Int32(_ext + 1)).index_cast(T.index), - ) - prev_ext = new_ext gpu.barrier() # cumdup[0] = 0 - _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) + _lds_store_raw(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() # Write num_valid_ids from cumdup[E] cs_E_ix_ps = ArithValue(c_E).index_cast(T.index) - total_padded = _lds_load(cumdup_mr, cs_E_ix_ps) + total_padded = _lds_load_raw(cumdup_mr, cs_E_ix_ps) buffer_ops.buffer_store(total_padded, nvalid_rsrc, c_zero_i32) buffer_ops.buffer_store(tokens, nvalid_rsrc, c_one_i32) gpu.barrier() # Copy cumdup → cumsum (all threads, one expert per thread) - for i_cp in range_constexpr(0, E + 1, DECODE_BLOCK): + for i_cp in range_constexpr(0, E + 1, ONESHOT_BLOCK): cp_idx = fx.Int32(i_cp) + tid cp_valid = cp_idx <= c_E safe_cp_idx = cp_valid.select(cp_idx, c_zero_i32) cp_ix = ArithValue(safe_cp_idx).index_cast(T.index) - cp_val = _lds_load(cumdup_mr, cp_ix) - _lds_store(cumsum_mr, cp_val, cp_ix) + cp_val = _lds_load_raw(cumdup_mr, cp_ix) + _lds_store_raw(cumsum_mr, cp_val, cp_ix) gpu.barrier() if has_mask: # EP: Compute mask cumsum in cumdup for local expert index mapping. # cumdup[eid] = exclusive prefix sum of mask[0..eid-1] = local expert index. - for i_ml in range_constexpr(0, E, DECODE_BLOCK): + for i_ml in range_constexpr(0, E, ONESHOT_BLOCK): ml_eid = fx.Int32(i_ml) + tid ml_valid = ml_eid < c_E safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32) @@ -537,35 +560,19 @@ def moe_sorting_decode_kernel( ml_ix = ArithValue( ml_valid.select(ml_eid + c_one_i32, c_zero_i32) ).index_cast(T.index) - _lds_store(cumdup_mr, ml_val, ml_ix) - _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) + _lds_store_raw(cumdup_mr, ml_val, ml_ix) + _lds_store_raw(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() # All-wave DPP prefix sum over mask values in cumdup m_tid_valid = tid < c_E mval = m_tid_valid.select( - _lds_load( - cumdup_mr, ArithValue(tid + c_one_i32).index_cast(T.index) - ), - c_zero_i32, + _lds_load_raw(cumdup_mr, tid + c_one_i32), c_zero_i32 ) - mval = _dpp_intra_wave_prefix_sum(mval, lane, WARP_SIZE) - - is_last_lane_m = lane == fx.Int32(WARP_SIZE - 1) - _if_ll_m = scf.IfOp(is_last_lane_m.ir_value()) - with _if_then(_if_ll_m): - _lds_store_raw( - scratch_mr, mval, ArithValue(wave).index_cast(T.index) - ) - gpu.barrier() - cross_m = c_zero_i32 - for _wm in range_constexpr(NUM_WAVES - 1): - wm_total = _lds_load_raw( - scratch_mr, ArithValue(fx.Int32(_wm)).index_cast(T.index) - ) - cross_m = (wave > fx.Int32(_wm)).select(cross_m + wm_total, cross_m) - inclusive_m = mval + cross_m - _lds_store( + _, inclusive_m = _allwave_inclusive_prefix_sum( + mval, lane, wave, scratch_mr, NUM_WAVES, WARP_SIZE + ) + _lds_store_raw( cumdup_mr, m_tid_valid.select(inclusive_m, c_zero_i32), ArithValue( @@ -574,100 +581,77 @@ def moe_sorting_decode_kernel( ) gpu.barrier() - if E > DECODE_BLOCK: - _if_t0_ext_m = scf.IfOp(is_t0.ir_value()) - with _if_then(_if_t0_ext_m): - prev_m = _lds_load( - cumdup_mr, - ArithValue(fx.Int32(DECODE_BLOCK)).index_cast(T.index), + if E > ONESHOT_BLOCK: + if is_t0: + _extend_prefix_sum_serial( + cumdup_mr, ONESHOT_BLOCK, E, _lds_load_raw, _lds_store_raw ) - for _ext_m in range_constexpr(DECODE_BLOCK, E): - cur_m = _lds_load( - cumdup_mr, - ArithValue(fx.Int32(_ext_m + 1)).index_cast(T.index), - ) - new_m = prev_m + cur_m - _lds_store( - cumdup_mr, - new_m, - ArithValue(fx.Int32(_ext_m + 1)).index_cast(T.index), - ) - prev_m = new_m gpu.barrier() - _lds_store(cumdup_mr, c_zero_i32, c_zero_i32) + _lds_store_raw(cumdup_mr, c_zero_i32, c_zero_i32) gpu.barrier() else: # No mask: cumdup[eid] = eid (identity mapping) - for i_ml in range_constexpr(0, E, DECODE_BLOCK): + for i_ml in range_constexpr(0, E, ONESHOT_BLOCK): ml_eid = fx.Int32(i_ml) + tid ml_valid = ml_eid < c_E safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32) ml_ix = ArithValue(safe_ml_eid).index_cast(T.index) - _lds_store( + _lds_store_raw( cumdup_mr, ml_valid.select(safe_ml_eid, c_zero_i32), ml_ix ) gpu.barrier() # Write sorted_expert_ids — predicated stores to buffer (safe: buffer_store ignores OOB) # EP: use cumdup[eid] as local expert index instead of global eid - for i_eid in range_constexpr(0, E, DECODE_BLOCK): + for i_eid in range_constexpr(0, E, ONESHOT_BLOCK): eid_wr = fx.Int32(i_eid) + tid eid_wr_valid = eid_wr < c_E safe_eid_wr = eid_wr_valid.select(eid_wr, c_zero_i32) cs_start_ix = ArithValue(safe_eid_wr).index_cast(T.index) cs_end_ix = ArithValue(safe_eid_wr + c_one_i32).index_cast(T.index) - e_start = _lds_load(cumsum_mr, cs_start_ix) - e_end = eid_wr_valid.select(_lds_load(cumsum_mr, cs_end_ix), e_start) - local_eid = _lds_load(cumdup_mr, cs_start_ix) + e_start = _lds_load_raw(cumsum_mr, cs_start_ix) + e_end = eid_wr_valid.select( + _lds_load_raw(cumsum_mr, cs_end_ix), e_start + ) + local_eid = _lds_load_raw(cumdup_mr, cs_start_ix) # Store cumdup: reuse cumdup for scatter phase position tracking. # Write e_start to cumdup[eid] (overwriting mask cumsum, no longer needed). - _lds_store(cumdup_mr, e_start, cs_start_ix) + _lds_store_raw(cumdup_mr, e_start, cs_start_ix) blk_start = e_start // c_unit blk_end = e_end // c_unit n_blks_wr = eid_wr_valid.select(blk_end - blk_start, c_zero_i32) - for _jb in range( - fx.Index(0), ArithValue(n_blks_wr).index_cast(T.index), fx.Index(1) - ): - blk_idx = blk_start + fx.Int32(_jb) - buffer_ops.buffer_store(local_eid, sorted_e_rsrc, blk_idx) + _write_expert_id_blocks(sorted_e_rsrc, local_eid, blk_start, n_blks_wr) gpu.barrier() # Store cumdup[E] = cumsum[E]. # All threads write cumE to cumdup[E] (all write the same value, no race). cs_E_ix = ArithValue(c_E).index_cast(T.index) - cumE = _lds_load(cumsum_mr, cs_E_ix) - _lds_store(cumdup_mr, cumE, cs_E_ix) + cumE = _lds_load_raw(cumsum_mr, cs_E_ix) + _lds_store_raw(cumdup_mr, cumE, cs_E_ix) gpu.barrier() # ====================== PRE-FILL: Sentinel fill (cooperative) =========== - # Fill ALL total_padded slots with sentinels BEFORE scatter. - # Scatter will overwrite real token positions. - sentinel_val_pre = c_sentinel | tokens - c_zero_pre = c_zero_i32 - cs_E_ix_pre = ArithValue(c_E).index_cast(T.index) - total_padded_pre = _lds_load(cumdup_mr, cs_E_ix_pre) - n_pre_iters = ( - total_padded_pre + fx.Int32(DECODE_BLOCK) - c_one_i32 - ) // fx.Int32(DECODE_BLOCK) - for _pre in range( - fx.Index(0), ArithValue(n_pre_iters).index_cast(T.index), fx.Index(1) - ): - pre_slot = fx.Int32(_pre) * fx.Int32(DECODE_BLOCK) + tid - pre_valid = pre_slot < total_padded_pre - safe_pre = pre_valid.select(pre_slot, c_oob_idx) - buffer_ops.buffer_store(sentinel_val_pre, sorted_ids_rsrc, safe_pre) - buffer_ops.buffer_store(c_zero_pre, sorted_w_rsrc, safe_pre) + total_padded_pre = _lds_load_raw( + cumdup_mr, ArithValue(c_E).index_cast(T.index) + ) + _fill_sentinel_slots( + sorted_ids_rsrc, + sorted_w_rsrc, + c_zero_i32, + total_padded_pre, + c_sentinel | tokens, + ONESHOT_BLOCK, + tid, + c_oob_idx, + ) gpu.barrier() # ====================== PHASE 3: Scatter ============================== - # CK uses wave_cumsum<8> (DPP prefix sum) + ds_bpermute (lane broadcast). - # FlyDSL has neither. Instead, each lane reads all 8 mesh values in the - # batch from LDS to compute its own prefix offset. No shuffle needed. - for i_e2 in range_constexpr(0, E, DECODE_BLOCK // 8): + for i_e2 in range_constexpr(0, E, ONESHOT_BLOCK // 8): eid_sc = fx.Int32(i_e2) + lane_group_id eid_sc_valid = eid_sc < c_E # Invalid lane groups map to cumsum[E] (the total count) instead of @@ -686,7 +670,7 @@ def moe_sorting_decode_kernel( sc_expert_enabled = eid_sc_valid & (sc_mask_val != c_zero_i32) cs_sc_ix = ArithValue(safe_eid_sc).index_cast(T.index) - position = _lds_load(cumsum_mr, cs_sc_ix) + position = _lds_load_raw(cumsum_mr, cs_sc_ix) for i_sub2 in range_constexpr(0, sub_tokens, 8): # This lane handles sub_token (i_sub2 + lane_group_os). @@ -695,14 +679,14 @@ def moe_sorting_decode_kernel( safe_my_sub = my_sub_valid.select(my_sub, c_zero_i32) my_mesh_addr = safe_my_sub * c_smem_cols + safe_eid_sc my_mesh_ix = ArithValue(my_mesh_addr).index_cast(T.index) - my_x = _lds_load(mesh_mr, my_mesh_ix) + my_x = _lds_load_raw(mesh_mr, my_mesh_ix) my_has_token = my_sub_valid & (my_x != c_zero_i32) local_cnt = my_has_token.select(c_one_i32, c_zero_i32) # 8-lane group prefix sum (NOT full-wave — uses lane_group_os, # only shifts 1,2,4, no cross-row bpermute needed). - cnt_raw = _unwrap(local_cnt) - zero_raw = _unwrap(c_zero_i32) + cnt_raw = _unwrap_val(local_cnt) + zero_raw = _unwrap_val(c_zero_i32) # row_shr:1 remote = fly_rocdl.update_dpp( @@ -720,7 +704,7 @@ def moe_sorting_decode_kernel( ) # row_shr:2 - cnt_raw = _unwrap(local_cnt) + cnt_raw = _unwrap_val(local_cnt) remote = fly_rocdl.update_dpp( T.i32, zero_raw, @@ -736,7 +720,7 @@ def moe_sorting_decode_kernel( ) # row_shr:4 - cnt_raw = _unwrap(local_cnt) + cnt_raw = _unwrap_val(local_cnt) remote = fly_rocdl.update_dpp( T.i32, zero_raw, @@ -778,13 +762,13 @@ def moe_sorting_decode_kernel( # Write back updated position (for padding phase). # Invalid lane groups write position (=0+0=0) to cumsum[0] which is harmless. - _lds_store(cumsum_mr, position, cs_sc_ix) + _lds_store_raw(cumsum_mr, position, cs_sc_ix) gpu.barrier() # Padding already filled by PRE-FILL phase above (before scatter). @flyc.jit - def launch_moe_sorting_decode( + def launch_moe_sorting_oneshot( topk_ids_tensor: fx.Tensor, topk_weights_tensor: fx.Tensor, sorted_token_ids: fx.Tensor, @@ -803,7 +787,7 @@ def launch_moe_sorting_decode( with ir.InsertionPoint(ctx.gpu_module_body): allocator.finalize() - launcher = moe_sorting_decode_kernel( + launcher = moe_sorting_oneshot_kernel( topk_ids_tensor, topk_weights_tensor, sorted_token_ids, @@ -817,52 +801,25 @@ def launch_moe_sorting_decode( ) launcher.launch( grid=(n_grid_blocks, 1, 1), - block=(DECODE_BLOCK, 1, 1), + block=(ONESHOT_BLOCK, 1, 1), stream=stream, ) - return launch_moe_sorting_decode + return launch_moe_sorting_oneshot # --------------------------------------------------------------------------- -# LDS helpers for prefill kernels (module-level, used inside @flyc.kernel) -# --------------------------------------------------------------------------- -def _lds_load_raw(raw_mr, idx): - """Load i32 from LDS raw memref. idx can be i32 or index.""" - raw_idx = idx.ir_value() if hasattr(idx, "ir_value") else idx - if not isinstance(raw_idx.type, ir.IndexType): - raw_idx = ArithValue(idx).index_cast(T.index) - raw_idx = raw_idx.ir_value() if hasattr(raw_idx, "ir_value") else raw_idx - return fx.Int32(memref_ops.load(raw_mr, [raw_idx])) - - -def _lds_store_raw(raw_mr, val, idx): - """Store i32 to LDS raw memref. idx can be i32 or index.""" - v = val.ir_value() if hasattr(val, "ir_value") else val - raw_idx = idx.ir_value() if hasattr(idx, "ir_value") else idx - if not isinstance(raw_idx.type, ir.IndexType): - raw_idx = ArithValue(idx).index_cast(T.index) - raw_idx = raw_idx.ir_value() if hasattr(raw_idx, "ir_value") else raw_idx - memref_ops.store(v, raw_mr, [raw_idx]) - - -def _unwrap_raw(v): - """Unwrap DSL value to raw MLIR ir.Value.""" - return v.ir_value() if hasattr(v, "ir_value") else v - - -# --------------------------------------------------------------------------- -# FlyDSL GPU kernels — prefill path (4 kernels, large T via HBM workspace) +# FlyDSL GPU kernels — multiphase path (2 or 4 kernels, large T via HBM workspace) # --------------------------------------------------------------------------- @functools.lru_cache(maxsize=256) -def compile_moe_sorting_prefill( +def _compile_moe_sorting_multiphase( *, num_experts: int, topk: int, unit_size: int = UNIT_SIZE, has_mask: bool = False, ): - """Compile the prefill-path MoE sorting kernels. + """Compile the multiphase MoE sorting kernels (2 or 4 kernels via HBM workspace). For token counts exceeding LDS capacity, uses HBM workspace: K1: ClearWorkspace — zero the workspace buffer @@ -888,6 +845,159 @@ def compile_moe_sorting_prefill( arch = get_hip_arch() E = num_experts + @flyc.jit + def _extend_local_idx_for_extra_experts( + cumsum_mr, mask_rsrc, K4_BLOCK, E, has_mask + ): + """Thread-0: write local expert indices for experts >= K4_BLOCK to cumsum_mr.""" + if has_mask: + prev_local = _lds_load_raw(cumsum_mr, fx.Int32(K4_BLOCK - 1)) + prev_mask = buffer_ops.buffer_load( + mask_rsrc, fx.Int32(K4_BLOCK - 1), vec_width=1, dtype=T.i32 + ) + prev_local = prev_local + prev_mask + for _e3 in range_constexpr(K4_BLOCK, E): + e3_mask = buffer_ops.buffer_load( + mask_rsrc, fx.Int32(_e3), vec_width=1, dtype=T.i32 + ) + _lds_store_raw(cumsum_mr, prev_local, fx.Int32(_e3)) + prev_local = prev_local + e3_mask + else: + for _e3 in range_constexpr(K4_BLOCK, E): + _lds_store_raw(cumsum_mr, fx.Int32(_e3), fx.Int32(_e3)) + + @flyc.jit + def _p23_scatter_mesh( + tid, + scatter_mr, + ws_rsrc, + weights_rsrc, + sorted_ids_rsrc, + sorted_w_rsrc, + mask_rsrc, + my_expert, + my_start, + my_end, + i32_mesh_stride, + c_topk, + K4_BLOCK, + has_mask, + ): + """P23 Step 4: EP mask check, read uint8 mesh, DPP prefix sum, scatter tokens.""" + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + K4_NUM_WAVES = K4_BLOCK // WARP_SIZE + c_zero, c_one, c4 = fx.Int32(0), fx.Int32(1), fx.Int32(4) + c_ff, c_oob_idx = fx.Int32(0xFF), fx.Int32(0x7FFFFFFF) + p23_bid_enabled = c_one != c_zero + if has_mask: + p23_bid_mask = buffer_ops.buffer_load( + mask_rsrc, my_expert, vec_width=1, dtype=T.i32 + ) + p23_bid_enabled = p23_bid_mask != c_zero + i32_words_per_row = i32_mesh_stride >> fx.Int32(2) + n_mesh_iters = (my_start != my_end).select( + (i32_words_per_row + fx.Int32(K4_BLOCK - 1)) // fx.Int32(K4_BLOCK), c_zero + ) + mesh_row_i32_base = (my_expert * i32_mesh_stride) >> fx.Int32(2) + for _si, state in range( + fx.Index(0), + ArithValue(n_mesh_iters).index_cast(T.index), + fx.Index(1), + init=[my_start], + ): + position = state[0] + word_idx = fx.Int32(_si) * fx.Int32(K4_BLOCK) + tid + col_valid = p23_bid_enabled & (word_idx < i32_words_per_row) + safe_word_idx = col_valid.select(word_idx, c_zero) + word = buffer_ops.buffer_load( + ws_rsrc, mesh_row_i32_base + safe_word_idx, vec_width=1, dtype=T.i32 + ) + x0 = word & c_ff + x1 = (word >> fx.Int32(8)) & c_ff + x2 = (word >> fx.Int32(16)) & c_ff + x3 = (word >> fx.Int32(24)) & c_ff + base_col = word_idx * c4 + h0 = col_valid & (x0 != c_zero) + h1 = col_valid & (x1 != c_zero) + h2 = col_valid & (x2 != c_zero) + h3 = col_valid & (x3 != c_zero) + my_cnt = ( + h0.select(c_one, c_zero) + + h1.select(c_one, c_zero) + + h2.select(c_one, c_zero) + + h3.select(c_one, c_zero) + ) + my_pre_scan = my_cnt + my_cnt, my_cnt_inclusive = _allwave_inclusive_prefix_sum( + my_cnt, lane, wave, scatter_mr, K4_NUM_WAVES, WARP_SIZE + ) + wave_offset = my_cnt_inclusive - my_cnt + batch_total = c_zero + for _w in range_constexpr(K4_NUM_WAVES): + batch_total = batch_total + _lds_load_raw(scatter_mr, fx.Int32(_w)) + gpu.barrier() + my_exclusive = my_cnt - my_pre_scan + wave_offset + scatter_base = position + my_exclusive + pid_0 = (h0.select(x0 - c_one, c_zero) << fx.Int32(24)) | base_col + pid_1 = (h1.select(x1 - c_one, c_zero) << fx.Int32(24)) | (base_col + c_one) + pid_2 = (h2.select(x2 - c_one, c_zero) << fx.Int32(24)) | ( + base_col + fx.Int32(2) + ) + pid_3 = (h3.select(x3 - c_one, c_zero) << fx.Int32(24)) | ( + base_col + fx.Int32(3) + ) + safe_slot_0 = h0.select(scatter_base, c_oob_idx) + off1 = scatter_base + h0.select(c_one, c_zero) + safe_slot_1 = h1.select(off1, c_oob_idx) + off2 = off1 + h1.select(c_one, c_zero) + safe_slot_2 = h2.select(off2, c_oob_idx) + off3 = off2 + h2.select(c_one, c_zero) + safe_slot_3 = h3.select(off3, c_oob_idx) + w_val_0 = buffer_ops.buffer_load( + weights_rsrc, + h0.select(base_col * c_topk + h0.select(x0 - c_one, c_zero), c_zero), + vec_width=1, + dtype=T.i32, + ) + w_val_1 = buffer_ops.buffer_load( + weights_rsrc, + h1.select( + (base_col + c_one) * c_topk + h1.select(x1 - c_one, c_zero), c_zero + ), + vec_width=1, + dtype=T.i32, + ) + w_val_2 = buffer_ops.buffer_load( + weights_rsrc, + h2.select( + (base_col + fx.Int32(2)) * c_topk + h2.select(x2 - c_one, c_zero), + c_zero, + ), + vec_width=1, + dtype=T.i32, + ) + w_val_3 = buffer_ops.buffer_load( + weights_rsrc, + h3.select( + (base_col + fx.Int32(3)) * c_topk + h3.select(x3 - c_one, c_zero), + c_zero, + ), + vec_width=1, + dtype=T.i32, + ) + buffer_ops.buffer_store(pid_0, sorted_ids_rsrc, safe_slot_0) + buffer_ops.buffer_store(pid_1, sorted_ids_rsrc, safe_slot_1) + buffer_ops.buffer_store(pid_2, sorted_ids_rsrc, safe_slot_2) + buffer_ops.buffer_store(pid_3, sorted_ids_rsrc, safe_slot_3) + buffer_ops.buffer_store(w_val_0, sorted_w_rsrc, safe_slot_0) + buffer_ops.buffer_store(w_val_1, sorted_w_rsrc, safe_slot_1) + buffer_ops.buffer_store(w_val_2, sorted_w_rsrc, safe_slot_2) + buffer_ops.buffer_store(w_val_3, sorted_w_rsrc, safe_slot_3) + pos_next = position + batch_total + results = yield [pos_next] + return results + # --- K1: ClearWorkspace kernel ------------------------------------------- # CK uses grid=262144, block=1024 (1 store per thread, no loop). # Match that: block=1024, grid=ceil(ws_total/1024). @@ -952,8 +1062,7 @@ def p0_scatter_kernel( eid = buffer_ops.buffer_load(topk_rsrc, safe_flat, vec_width=1, dtype=T.i32) byte_offset = eid * i32_mesh_stride + token_id val_i8 = ArithValue(topk_slot + c_one).trunci(T.i8) - _if_valid_k2 = scf.IfOp(valid.ir_value()) - with _if_then(_if_valid_k2): + if valid: buffer_ops.buffer_store( val_i8, ws_rsrc, byte_offset, offset_is_bytes=True ) @@ -1072,8 +1181,7 @@ def p1_count_kernel( # Cross-warp reduce via LDS: lane 0 of each warp writes partial sum is_lane0 = lane == c_zero - _if_l0 = scf.IfOp(is_lane0.ir_value()) - with _if_then(_if_l0): + if is_lane0: wave_ix = ArithValue(wave).index_cast(T.index) _lds_store_raw(reduce_mr, cnt, wave_ix) gpu.barrier() @@ -1082,9 +1190,7 @@ def p1_count_kernel( is_t0 = tid == c_zero total = c_zero for _w in range_constexpr(K3_NUM_WAVES): - total = total + _lds_load_raw( - reduce_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) - ) + total = total + _lds_load_raw(reduce_mr, fx.Int32(_w)) cs_offset = i32_mesh_size + eid c_oob_idx = fx.Int32(0x7FFFFFFF) @@ -1214,8 +1320,7 @@ def p0v2_kernel( val_i8 = ArithValue(is_mine.select(topk_slot + c_one, c_zero)).trunci(T.i8) # Byte-mode buffer_store with OOB offset crashes on AMD GPUs. # Use conditional branch to skip the store for non-matching threads. - _if_mine = scf.IfOp(is_mine.ir_value()) - with _if_then(_if_mine): + if is_mine: buffer_ops.buffer_store( val_i8, ws_rsrc, byte_offset, offset_is_bytes=True ) @@ -1260,8 +1365,7 @@ def p0v2_kernel( # Cross-warp reduce via LDS: lane 0 of each warp writes partial sum is_lane0 = lane == c_zero - _if_l0 = scf.IfOp(is_lane0.ir_value()) - with _if_then(_if_l0): + if is_lane0: wave_ix = ArithValue(wave).index_cast(T.index) _lds_store_raw(reduce_mr, cnt, wave_ix) gpu.barrier() @@ -1270,9 +1374,7 @@ def p0v2_kernel( is_t0 = tid == c_zero total = c_zero for _w in range_constexpr(P0V2_NUM_WAVES): - total = total + _lds_load_raw( - reduce_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) - ) + total = total + _lds_load_raw(reduce_mr, fx.Int32(_w)) cs_offset = i32_mesh_size + eid c_oob_idx = fx.Int32(0x7FFFFFFF) @@ -1340,13 +1442,11 @@ def p23_kernel( wave = tid // WARP_SIZE c_zero = fx.Int32(0) c_one = fx.Int32(1) - c4 = fx.Int32(4) c_E = fx.Int32(E) c_unit = fx.Int32(unit_size) c_topk = fx.Int32(topk) c_sentinel = fx.Int32(topk << 24) c_oob_idx = fx.Int32(0x7FFFFFFF) - c_ff = fx.Int32(0xFF) # Buffer resources ws_rsrc = buffer_ops.create_buffer_resource(workspace, max_size=True) @@ -1374,40 +1474,29 @@ def p23_kernel( is_zero_block = bid >= c_E # ================ MOE_BUF ZEROING (blocks >= E) ================== - _if_zero = scf.IfOp(is_zero_block.ir_value()) - with _if_then(_if_zero): + if is_zero_block: moe_buf_rsrc = buffer_ops.create_buffer_resource(moe_buf, max_size=True) - zero_base_bid = bid - c_E - zero_gid_v4 = zero_base_bid * fx.Int32(K4_BLOCK) + tid - num_zero_blocks = gpu.grid_dim.x - c_E - zero_stride_v4 = num_zero_blocks * fx.Int32(K4_BLOCK) - i32_moe_buf_v4 = i32_moe_buf_elems >> fx.Int32(2) - zero_niters = (i32_moe_buf_v4 + zero_stride_v4 - c_one) // zero_stride_v4 - c_zero_v4 = fx.Vector.filled(4, 0, fx.Int32) - for _z in range( - fx.Index(0), ArithValue(zero_niters).index_cast(T.index), fx.Index(1) - ): - z_idx_v4 = zero_gid_v4 + fx.Int32(_z) * zero_stride_v4 - z_valid = z_idx_v4 < i32_moe_buf_v4 - z_elem = z_valid.select(z_idx_v4 * c4, c_oob_idx) - buffer_ops.buffer_store(c_zero_v4, moe_buf_rsrc, z_elem) + zero_gid_v4 = (bid - c_E) * fx.Int32(K4_BLOCK) + tid + zero_stride_v4 = (gpu.grid_dim.x - c_E) * fx.Int32(K4_BLOCK) + _zero_moe_buf_grid_stride( + moe_buf_rsrc, + zero_gid_v4, + zero_stride_v4, + i32_moe_buf_elems >> fx.Int32(2), + c_oob_idx, + ) # ================ PARALLEL PREFIX-SUM + MESH SCATTER (blocks 0..E-1) == # Each block independently: prefix sum (redundant), scatter for its expert only. - _if_sort = scf.IfOp(is_sort_block.ir_value()) - with _if_then(_if_sort): + if is_sort_block: my_expert = bid # Step 1: Load expert counts from workspace -> pad to unit_size -> LDS cumsum # Process E experts in chunks of K4_BLOCK (256). Most models have # E <= 256, so the extra chunk is only needed for E > 256 # (e.g. DeepSeek-R1 with 256 routed + 1 shared = 257). - is_t0_init = tid == c_zero - _if_init_cs = scf.IfOp(is_t0_init.ir_value()) - with _if_then(_if_init_cs): - _lds_store_raw( - cumsum_mr, c_zero, ArithValue(c_zero).index_cast(T.index) - ) + if tid == c_zero: + _lds_store_raw(cumsum_mr, c_zero, c_zero) # EP: load this thread's own mask value BEFORE the chunked loop. # The chunked loop overwrites p23_mask_val in later chunks, so we @@ -1446,110 +1535,47 @@ def p23_kernel( oob = raw_store_idx >= fx.Int32(k4_smem_cols) safe_store_idx = oob.select(c_zero, raw_store_idx) safe_store_val = oob.select(c_zero, padded) - _lds_store_raw( - cumsum_mr, - safe_store_val, - ArithValue(safe_store_idx).index_cast(T.index), - ) + _lds_store_raw(cumsum_mr, safe_store_val, safe_store_idx) gpu.barrier() # Step 2: Prefix sum over cumsum LDS. When E <= K4_BLOCK (256), # a single DPP pass covers all experts. When E > K4_BLOCK, we # do the DPP pass for the first K4_BLOCK elements, then serially # accumulate the remaining entries from thread 0. - val = _lds_load_raw(cumsum_mr, ArithValue(tid + c_one).index_cast(T.index)) - val = _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE) - - is_last_lane_ps = lane == fx.Int32(WARP_SIZE - 1) - _if_ll_ps = scf.IfOp(is_last_lane_ps.ir_value()) - with _if_then(_if_ll_ps): - _lds_store_raw(scatter_mr, val, ArithValue(wave).index_cast(T.index)) - gpu.barrier() - cross_offset = c_zero - for _w in range_constexpr(K4_NUM_WAVES - 1): - w_total = _lds_load_raw( - scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) - ) - cross_offset = (wave > fx.Int32(_w)).select( - cross_offset + w_total, cross_offset - ) + val = _lds_load_raw(cumsum_mr, tid + c_one) + val, inclusive_prefix = _allwave_inclusive_prefix_sum( + val, lane, wave, scatter_mr, K4_NUM_WAVES, WARP_SIZE + ) total_padded = c_zero for _w in range_constexpr(K4_NUM_WAVES): - total_padded = total_padded + _lds_load_raw( - scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) - ) - - inclusive_prefix = val + cross_offset - _lds_store_raw( - cumsum_mr, inclusive_prefix, ArithValue(tid + c_one).index_cast(T.index) - ) + total_padded = total_padded + _lds_load_raw(scatter_mr, fx.Int32(_w)) + _lds_store_raw(cumsum_mr, inclusive_prefix, tid + c_one) gpu.barrier() # For E > K4_BLOCK: thread 0 serially extends the prefix sum - # for experts K4_BLOCK..E-1 (at most a few iterations). if E > K4_BLOCK: - _if_t0_extra = scf.IfOp(is_t0_init.ir_value()) - with _if_then(_if_t0_extra): - prev_sum = _lds_load_raw( - cumsum_mr, ArithValue(fx.Int32(K4_BLOCK)).index_cast(T.index) + if tid == c_zero: + total_padded = _extend_prefix_sum_serial( + cumsum_mr, K4_BLOCK, E, _lds_load_raw, _lds_store_raw ) - for _extra in range_constexpr(K4_BLOCK, E): - cur = _lds_load_raw( - cumsum_mr, - ArithValue(fx.Int32(_extra + 1)).index_cast(T.index), - ) - new_sum = prev_sum + cur - _lds_store_raw( - cumsum_mr, - new_sum, - ArithValue(fx.Int32(_extra + 1)).index_cast(T.index), - ) - prev_sum = new_sum - total_padded = prev_sum gpu.barrier() - total_padded = _lds_load_raw( - cumsum_mr, ArithValue(c_E).index_cast(T.index) - ) + total_padded = _lds_load_raw(cumsum_mr, c_E) # Read my_start and my_end from cumsum LDS - my_start = _lds_load_raw( - cumsum_mr, ArithValue(my_expert).index_cast(T.index) - ) - my_end = _lds_load_raw( - cumsum_mr, ArithValue(my_expert + c_one).index_cast(T.index) - ) + my_start = _lds_load_raw(cumsum_mr, my_expert) + my_end = _lds_load_raw(cumsum_mr, my_expert + c_one) + # Hoist before if/else: AST rewriter extracts branches into + # separate functions, so variables must be defined in outer scope. local_idx_p23 = tid if has_mask: - # EP: Compute mask cumsum for local expert index (register-only DPP scan). - # Uses my_mask_val (loaded before the chunked loop) to avoid overwrite. - p23_mv = _dpp_intra_wave_prefix_sum(my_mask_val, lane, WARP_SIZE) - - # Cross-wave via scratch_mr - is_last_lane_pm = lane == fx.Int32(WARP_SIZE - 1) - _if_ll_pm = scf.IfOp(is_last_lane_pm.ir_value()) - with _if_then(_if_ll_pm): - _lds_store_raw( - scatter_mr, p23_mv, ArithValue(wave).index_cast(T.index) - ) - gpu.barrier() - cross_pm = c_zero - for _w in range_constexpr(K4_NUM_WAVES - 1): - wtp = _lds_load_raw( - scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) - ) - cross_pm = (wave > fx.Int32(_w)).select(cross_pm + wtp, cross_pm) - p23_mask_inclusive = p23_mv + cross_pm + _, p23_mask_inclusive = _allwave_inclusive_prefix_sum( + my_mask_val, lane, wave, scatter_mr, K4_NUM_WAVES, WARP_SIZE + ) local_idx_p23 = p23_mask_inclusive - my_mask_val - else: - local_idx_p23 = tid # Block 0, thread 0 writes num_valid_ids - is_b0 = bid == c_zero - is_t0 = tid == c_zero - is_b0_t0 = is_b0 & is_t0 - _if_nv = scf.IfOp(is_b0_t0.ir_value()) - with _if_then(_if_nv): + if (bid == c_zero) & (tid == c_zero): nvalid_rsrc = buffer_ops.create_buffer_resource( num_valid_ids, max_size=True ) @@ -1558,227 +1584,56 @@ def p23_kernel( # Step 3: Write sorted_expert_ids for THIS expert (using local_idx_p23 for EP) # Store local_idx to LDS cumsum[tid], barrier, read cumsum[my_expert] - _lds_store_raw( - cumsum_mr, local_idx_p23, ArithValue(tid).index_cast(T.index) - ) + _lds_store_raw(cumsum_mr, local_idx_p23, tid) # For E > K4_BLOCK: thread 0 extends local_idx using cumsum[K4_BLOCK-1]. # Barrier ensures all threads have written before thread 0 reads. if E > K4_BLOCK: gpu.barrier() - is_t0_extra3 = tid == c_zero - _if_t0_e3 = scf.IfOp(is_t0_extra3.ir_value()) - with _if_then(_if_t0_e3): - if has_mask: - # EP: serially extend mask prefix sum for experts >= K4_BLOCK - prev_local = _lds_load_raw( - cumsum_mr, - ArithValue(fx.Int32(K4_BLOCK - 1)).index_cast(T.index), - ) - prev_mask = buffer_ops.buffer_load( - mask_rsrc, fx.Int32(K4_BLOCK - 1), vec_width=1, dtype=T.i32 - ) - prev_local = prev_local + prev_mask - for _e3 in range_constexpr(K4_BLOCK, E): - e3_mask = buffer_ops.buffer_load( - mask_rsrc, fx.Int32(_e3), vec_width=1, dtype=T.i32 - ) - _lds_store_raw( - cumsum_mr, - prev_local, - ArithValue(fx.Int32(_e3)).index_cast(T.index), - ) - prev_local = prev_local + e3_mask - else: - for _e3 in range_constexpr(K4_BLOCK, E): - _lds_store_raw( - cumsum_mr, - fx.Int32(_e3), - ArithValue(fx.Int32(_e3)).index_cast(T.index), - ) + if tid == c_zero: + _extend_local_idx_for_extra_experts( + cumsum_mr, mask_rsrc, K4_BLOCK, E, has_mask + ) gpu.barrier() - my_local_idx = _lds_load_raw( - cumsum_mr, ArithValue(my_expert).index_cast(T.index) - ) + my_local_idx = _lds_load_raw(cumsum_mr, my_expert) sorted_e_rsrc = buffer_ops.create_buffer_resource( sorted_expert_ids, max_size=True ) blk_start = my_start // c_unit blk_end = my_end // c_unit - n_blks = blk_end - blk_start - n_eid_iters = (n_blks + fx.Int32(K4_BLOCK) - c_one) // fx.Int32(K4_BLOCK) - for _eii in range( - fx.Index(0), ArithValue(n_eid_iters).index_cast(T.index), fx.Index(1) - ): - blk_idx = blk_start + fx.Int32(_eii) * fx.Int32(K4_BLOCK) + tid - buffer_ops.buffer_store( - my_local_idx, - sorted_e_rsrc, - (blk_idx < blk_end).select(blk_idx, c_oob_idx), - ) - - # Step 4: Mesh-based scatter — read uint8 mesh from HBM, extract tokens, - # DPP prefix sum over counts, cross-wave LDS reduction, scatter stores. - p23_bid_enabled = c_one != c_zero - if has_mask: - # EP: skip scatter for masked experts (my_start == my_end, but mesh has data) - p23_bid_mask = buffer_ops.buffer_load( - mask_rsrc, my_expert, vec_width=1, dtype=T.i32 - ) - p23_bid_enabled = p23_bid_mask != c_zero - - i32_words_per_row = i32_mesh_stride >> fx.Int32(2) - n_mesh_iters_raw = (i32_words_per_row + fx.Int32(K4_BLOCK - 1)) // fx.Int32( - K4_BLOCK + _write_expert_id_blocks( + sorted_e_rsrc, my_local_idx, blk_start, blk_end - blk_start ) - has_work = my_start != my_end - n_mesh_iters = has_work.select(n_mesh_iters_raw, c_zero) - mesh_row_i32_base = (my_expert * i32_mesh_stride) >> fx.Int32(2) - - for _si, state in range( - fx.Index(0), - ArithValue(n_mesh_iters).index_cast(T.index), - fx.Index(1), - init=[my_start], - ): - position = state[0] - - word_idx = fx.Int32(_si) * fx.Int32(K4_BLOCK) + tid - col_valid = p23_bid_enabled & (word_idx < i32_words_per_row) - safe_word_idx = col_valid.select(word_idx, c_zero) - word = buffer_ops.buffer_load( - ws_rsrc, mesh_row_i32_base + safe_word_idx, vec_width=1, dtype=T.i32 - ) - - # Extract 4 bytes from the i32 word - x0 = word & c_ff - x1 = (word >> fx.Int32(8)) & c_ff - x2 = (word >> fx.Int32(16)) & c_ff - x3 = (word >> fx.Int32(24)) & c_ff - base_col = word_idx * c4 - - h0 = col_valid & (x0 != c_zero) - h1 = col_valid & (x1 != c_zero) - h2 = col_valid & (x2 != c_zero) - h3 = col_valid & (x3 != c_zero) - - my_cnt = ( - h0.select(c_one, c_zero) - + h1.select(c_one, c_zero) - + h2.select(c_one, c_zero) - + h3.select(c_one, c_zero) - ) - - my_cnt = _dpp_intra_wave_prefix_sum(my_cnt, lane, WARP_SIZE) - - # my_cnt is now intra-wave inclusive prefix sum of per-thread token counts. - # Cross-wave reduction via LDS scratch. - is_last_lane_sc = lane == fx.Int32(WARP_SIZE - 1) - _if_ll_sc = scf.IfOp(is_last_lane_sc.ir_value()) - with _if_then(_if_ll_sc): - _lds_store_raw( - scatter_mr, my_cnt, ArithValue(wave).index_cast(T.index) - ) - gpu.barrier() - - wave_offset = c_zero - for _w in range_constexpr(K4_NUM_WAVES - 1): - w_total = _lds_load_raw( - scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) - ) - wave_offset = (wave > fx.Int32(_w)).select( - wave_offset + w_total, wave_offset - ) - batch_total = c_zero - for _w in range_constexpr(K4_NUM_WAVES): - batch_total = batch_total + _lds_load_raw( - scatter_mr, ArithValue(fx.Int32(_w)).index_cast(T.index) - ) - gpu.barrier() - # Convert to exclusive prefix: my_exclusive = my_cnt - my_thread_count - my_thread_count = ( - h0.select(c_one, c_zero) - + h1.select(c_one, c_zero) - + h2.select(c_one, c_zero) - + h3.select(c_one, c_zero) - ) - my_exclusive = my_cnt - my_thread_count + wave_offset - - # Scatter: compute all addresses, batch-load weights, then batch-store. - scatter_base = position + my_exclusive - - # Compute packed IDs and output slots for all 4 tokens - token_id_0 = base_col - topk_slot_0 = h0.select(x0 - c_one, c_zero) - pid_0 = (topk_slot_0 << fx.Int32(24)) | token_id_0 - safe_slot_0 = h0.select(scatter_base, c_oob_idx) - - off1 = scatter_base + h0.select(c_one, c_zero) - token_id_1 = base_col + c_one - topk_slot_1 = h1.select(x1 - c_one, c_zero) - pid_1 = (topk_slot_1 << fx.Int32(24)) | token_id_1 - safe_slot_1 = h1.select(off1, c_oob_idx) - - off2 = off1 + h1.select(c_one, c_zero) - token_id_2 = base_col + fx.Int32(2) - topk_slot_2 = h2.select(x2 - c_one, c_zero) - pid_2 = (topk_slot_2 << fx.Int32(24)) | token_id_2 - safe_slot_2 = h2.select(off2, c_oob_idx) - - off3 = off2 + h2.select(c_one, c_zero) - token_id_3 = base_col + fx.Int32(3) - topk_slot_3 = h3.select(x3 - c_one, c_zero) - pid_3 = (topk_slot_3 << fx.Int32(24)) | token_id_3 - safe_slot_3 = h3.select(off3, c_oob_idx) - - # Batch-issue all 4 weight loads (increases load-use distance) - w_addr_0 = h0.select(token_id_0 * c_topk + topk_slot_0, c_zero) - w_addr_1 = h1.select(token_id_1 * c_topk + topk_slot_1, c_zero) - w_addr_2 = h2.select(token_id_2 * c_topk + topk_slot_2, c_zero) - w_addr_3 = h3.select(token_id_3 * c_topk + topk_slot_3, c_zero) - w_val_0 = buffer_ops.buffer_load( - weights_rsrc, w_addr_0, vec_width=1, dtype=T.i32 - ) - w_val_1 = buffer_ops.buffer_load( - weights_rsrc, w_addr_1, vec_width=1, dtype=T.i32 - ) - w_val_2 = buffer_ops.buffer_load( - weights_rsrc, w_addr_2, vec_width=1, dtype=T.i32 - ) - w_val_3 = buffer_ops.buffer_load( - weights_rsrc, w_addr_3, vec_width=1, dtype=T.i32 - ) - - # Batch-store: all packed IDs, then all weights - buffer_ops.buffer_store(pid_0, sorted_ids_rsrc, safe_slot_0) - buffer_ops.buffer_store(pid_1, sorted_ids_rsrc, safe_slot_1) - buffer_ops.buffer_store(pid_2, sorted_ids_rsrc, safe_slot_2) - buffer_ops.buffer_store(pid_3, sorted_ids_rsrc, safe_slot_3) - buffer_ops.buffer_store(w_val_0, sorted_w_rsrc, safe_slot_0) - buffer_ops.buffer_store(w_val_1, sorted_w_rsrc, safe_slot_1) - buffer_ops.buffer_store(w_val_2, sorted_w_rsrc, safe_slot_2) - buffer_ops.buffer_store(w_val_3, sorted_w_rsrc, safe_slot_3) - - pos_next = position + batch_total - results = yield [pos_next] - scatter_end_pos_t0 = results + # Step 4: Mesh-based scatter (EP mask + uint8 mesh read + DPP prefix sum + scatter) + scatter_end_pos_t0 = _p23_scatter_mesh( + tid, + scatter_mr, + ws_rsrc, + weights_rsrc, + sorted_ids_rsrc, + sorted_w_rsrc, + mask_rsrc, + my_expert, + my_start, + my_end, + i32_mesh_stride, + c_topk, + K4_BLOCK, + has_mask, + ) # Step 5: Fill padding with sentinel for THIS expert (parallel) - sentinel_val = c_sentinel | i32_tokens - pad_count = my_end - scatter_end_pos_t0 - pad_niters = (pad_count + fx.Int32(K4_BLOCK) - c_one) // fx.Int32(K4_BLOCK) - for _pi in range( - fx.Index(0), ArithValue(pad_niters).index_cast(T.index), fx.Index(1) - ): - pad_slot = scatter_end_pos_t0 + fx.Int32(_pi) * fx.Int32(K4_BLOCK) + tid - pad_valid = pad_slot < my_end - buffer_ops.buffer_store( - sentinel_val, sorted_ids_rsrc, pad_valid.select(pad_slot, c_oob_idx) - ) - buffer_ops.buffer_store( - c_zero, sorted_w_rsrc, pad_valid.select(pad_slot, c_oob_idx) - ) + _fill_sentinel_slots( + sorted_ids_rsrc, + sorted_w_rsrc, + scatter_end_pos_t0, + my_end - scatter_end_pos_t0, + c_sentinel | i32_tokens, + K4_BLOCK, + tid, + c_oob_idx, + ) @flyc.jit def launch_p23( @@ -1942,10 +1797,10 @@ def launch_4k_fused( # --------------------------------------------------------------------------- @functools.lru_cache(maxsize=64) def _compute_sub_tokens(num_experts, arch=None): - """Compute the LDS-capacity threshold (sub_tokens) for decode vs prefill decision. + """Compute the LDS-capacity threshold (sub_tokens) for oneshot vs multiphase decision. - Returns the max T that fits in LDS for the decode (single-kernel) path. - Same formula as compile_moe_sorting_decode. + Returns the max T that fits in LDS for the oneshot (single-kernel) path. + Same formula as _compile_moe_sorting_oneshot. """ if arch is None: arch = get_hip_arch() @@ -1963,17 +1818,17 @@ def _compute_sub_tokens(num_experts, arch=None): sub_unroll = 8 cumsum_bufs = 2 if r < (cumsum_bufs + sub_unroll): - return 0 # LDS too small — always use prefill + return 0 # LDS too small — always use multiphase r_for_sub = ((r - cumsum_bufs) // sub_unroll) * sub_unroll return r_for_sub def moe_sorting_get_workspace_size(M, num_experts, topk, unit_size=UNIT_SIZE): - """Return workspace size (in i32 elements) needed for the prefill path. - Returns 0 if the decode path will be used.""" + """Return workspace size (in i32 elements) needed for the multiphase path. + Returns 0 if the oneshot path will be used.""" sub_tokens = _compute_sub_tokens(num_experts) - DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) - if M <= min(sub_tokens, DECODE_MAX_T): + ONESHOT_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) + if M <= min(sub_tokens, ONESHOT_MAX_T): return 0 mesh_stride = ((M + unit_size - 1) // unit_size) * unit_size ws_mesh_bytes = num_experts * mesh_stride @@ -1981,6 +1836,39 @@ def moe_sorting_get_workspace_size(M, num_experts, topk, unit_size=UNIT_SIZE): return ws_mesh_i32 + (num_experts + 1) +def compile_moe_sorting( + *, num_experts, topk, max_tokens=128, unit_size=UNIT_SIZE, has_mask=False +): + """Compile MoE sorting kernels for all paths (oneshot + multiphase). + + Returns (launch_oneshot, launch_p0v2_p23, launch_4k_fused) covering all T ranges. + Oneshot compilation depends on max_tokens (LDS sizing); multiphase is independent. + """ + launch_oneshot = _compile_moe_sorting_oneshot( + num_experts=num_experts, + topk=topk, + max_tokens=max_tokens, + unit_size=unit_size, + has_mask=has_mask, + ) + _, _, _, _, _, launch_p0v2_p23, launch_4k_fused = _compile_moe_sorting_multiphase( + num_experts=num_experts, topk=topk, unit_size=unit_size, has_mask=has_mask + ) + return launch_oneshot, launch_p0v2_p23, launch_4k_fused + + +def _launch_cached(cache, key, launch_fn, args, stream): + """AOT-compiled dispatch: first call JITs, subsequent calls use cached CompiledFunction.""" + cf = cache.get(key) + stream_arg = fx.Stream(stream) + if cf is not None: + cf(*args, stream_arg) + else: + launch_fn(*args, stream=stream) + cf = flyc.compile(launch_fn, *args, stream_arg) + cache[key] = cf + + def moe_sorting_flydsl( topk_ids, topk_weights, @@ -1995,7 +1883,7 @@ def moe_sorting_flydsl( num_local_tokens=None, workspace=None, ): - """MoE sorting using FlyDSL kernel (decode + prefill paths). + """MoE sorting using FlyDSL kernel (oneshot + multiphase paths). API matches aiter.moe_sorting_fwd for drop-in replacement: moe_sorting_flydsl(topk_ids, topk_weights, @@ -2037,20 +1925,28 @@ def moe_sorting_flydsl( else: mask_tensor = expert_mask - DECODE_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) + ONESHOT_MAX_T = min(sub_tokens, max(16, BLOCK_SIZE // max(topk, num_experts // 8))) - if M <= min(sub_tokens, DECODE_MAX_T): + target_occupancy = 2 + num_cu = torch.cuda.get_device_properties(device).multi_processor_count + + if M <= min(sub_tokens, ONESHOT_MAX_T): max_tokens = max(M, 8) max_tokens = ((max_tokens + 7) // 8) * 8 - target_occupancy = 2 - num_cu = torch.cuda.get_device_properties(topk_ids.device).multi_processor_count n_zero_blocks = min( (moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy ) n_grid_blocks = 1 + n_zero_blocks - launch_moe_sorting_decode_path( + launch_oneshot, _, _ = compile_moe_sorting( + num_experts=num_experts, + topk=topk, + max_tokens=max_tokens, + unit_size=unit_size, + has_mask=has_mask, + ) + oneshot_args = ( topk_ids, topk_weights, sorted_ids, @@ -2062,14 +1958,16 @@ def moe_sorting_flydsl( M, moe_buf_elems, n_grid_blocks, - num_experts=num_experts, - topk=topk, - max_tokens=max_tokens, - unit_size=unit_size, - has_mask=has_mask, + ) + cache_key = (num_experts, topk, max_tokens, unit_size, has_mask, device.index) + _launch_cached( + _oneshot_cf_cache, + cache_key, + launch_oneshot, + oneshot_args, + torch.cuda.current_stream(), ) else: - # Prefill path: multiple kernels via HBM workspace mesh_stride = ((M + unit_size - 1) // unit_size) * unit_size ws_mesh_bytes = num_experts * mesh_stride ws_mesh_i32 = (ws_mesh_bytes + 3) // 4 @@ -2077,188 +1975,18 @@ def moe_sorting_flydsl( if workspace is None: workspace = torch.empty(ws_total, dtype=torch.int32, device=device) - launch_moe_sorting_prefill_path( - topk_ids, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - workspace, - mask_tensor, - M, - moe_buf_elems, - mesh_stride, - ws_mesh_i32, - ws_total, - num_experts=num_experts, - topk=topk, - unit_size=unit_size, - has_mask=has_mask, + _, launch_p0v2_p23, launch_4k_fused = compile_moe_sorting( + num_experts=num_experts, topk=topk, unit_size=unit_size, has_mask=has_mask ) - - return sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf - - -def launch_moe_sorting_decode_path( - topk_ids, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - i32_moe_buf_elems, - n_grid_blocks, - *, - num_experts, - topk, - max_tokens=128, - unit_size=UNIT_SIZE, - has_mask=False, -): - """Low-level launcher for decode path: single kernel. - - This is the hot-path entry point — no torch ops, just JIT dispatch. - Uses AOT-compiled dispatch after the first call to bypass the ~70 us - JIT overhead (inspect.Signature.bind + cache key + dict lookup). - """ - - cache_key = ( - num_experts, - topk, - max_tokens, - unit_size, - has_mask, - topk_ids.device.index, - ) - cf = _decode_cf_cache.get(cache_key) - if cf is not None: stream = torch.cuda.current_stream() - cf( - topk_ids, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - i32_moe_buf_elems, - n_grid_blocks, - fx.Stream(stream), + n_zero_blocks = min( + (moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy ) - return - - launch_fn = compile_moe_sorting_decode( - num_experts=num_experts, - topk=topk, - max_tokens=max_tokens, - unit_size=unit_size, - has_mask=has_mask, - ) - stream = torch.cuda.current_stream() - launch_fn( - topk_ids, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - i32_moe_buf_elems, - n_grid_blocks, - stream=stream, - ) - - cf = flyc.compile( - launch_fn, - topk_ids, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - i32_moe_buf_elems, - n_grid_blocks, - fx.Stream(stream), - ) - _decode_cf_cache[cache_key] = cf - + k4_grid = num_experts + n_zero_blocks + base_key = (num_experts, topk, unit_size, has_mask, device.index) -def launch_moe_sorting_prefill_path( - topk_ids, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - workspace, - expert_mask, - i32_tokens, - i32_moe_buf_elems, - mesh_stride, - mesh_size, - ws_total, - *, - num_experts, - topk, - unit_size=UNIT_SIZE, - has_mask=False, -): - """Low-level launcher for prefill path via HBM workspace. - - For small T (<=2048): fused P0_v2 (clear+scatter+count) + K4. - For large T (>2048): 4 separate kernels K1+K2+K3+K4. - - Uses AOT-compiled dispatch after the first call for each sub-kernel - to bypass JIT overhead. - """ - - ( - launch_clear_ws, - launch_p0, - launch_p1, - launch_p23, - launch_p0v2, - launch_p0v2_p23, - launch_4k_fused, - ) = compile_moe_sorting_prefill( - num_experts=num_experts, - topk=topk, - unit_size=unit_size, - has_mask=has_mask, - ) - - stream = torch.cuda.current_stream() - stream_arg = fx.Stream(stream) - - num_cu = torch.cuda.get_device_properties(topk_ids.device).multi_processor_count - - base_key = (num_experts, topk, unit_size, has_mask, topk_ids.device.index) - use_p0v2 = i32_tokens <= 2048 - - target_occupancy = 2 - n_zero_blocks = min( - (i32_moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE, num_cu * target_occupancy - ) - k4_grid = num_experts + n_zero_blocks - - if use_p0v2: - ck = base_key + ("p0v2_p23",) - cf = _prefill_cf_cache.get(ck) - if cf is not None: - cf( + if M <= 2048: + p0v2_args = ( topk_ids, workspace, topk_weights, @@ -2267,16 +1995,27 @@ def launch_moe_sorting_prefill_path( sorted_expert_ids, num_valid_ids, moe_buf_i32, - expert_mask, - i32_tokens, + mask_tensor, + M, mesh_stride, - mesh_size, - i32_moe_buf_elems, + ws_mesh_i32, + moe_buf_elems, k4_grid, - stream_arg, + ) + _launch_cached( + _multiphase_cf_cache, + base_key + ("p0v2_p23",), + launch_p0v2_p23, + p0v2_args, + stream, ) else: - launch_p0v2_p23( + k1_grid = (ws_total + 1023) // 1024 + k2_grid = num_cu * target_occupancy + k2_total = M * topk + k2_stride = k2_grid * 256 + k2_niters = (k2_total + k2_stride - 1) // k2_stride + k4_args = ( topk_ids, workspace, topk_weights, @@ -2285,108 +2024,23 @@ def launch_moe_sorting_prefill_path( sorted_expert_ids, num_valid_ids, moe_buf_i32, - expert_mask, - i32_tokens, + mask_tensor, + M, mesh_stride, - mesh_size, - i32_moe_buf_elems, + ws_mesh_i32, + moe_buf_elems, + ws_total, + k2_niters, + k1_grid, + k2_grid, k4_grid, - stream=stream, ) - cf = flyc.compile( - launch_p0v2_p23, - topk_ids, - workspace, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - mesh_stride, - mesh_size, - i32_moe_buf_elems, - k4_grid, - stream_arg, + _launch_cached( + _multiphase_cf_cache, + base_key + ("4k_fused",), + launch_4k_fused, + k4_args, + stream, ) - _prefill_cf_cache[ck] = cf - return - - # 4-kernel path (T > 2048): fused clear+p0+p1+p23 - k1_grid = (ws_total + 1023) // 1024 - k2_grid = min(num_cu * target_occupancy, (i32_tokens * topk + 255) // 256) - k2_total = i32_tokens * topk - k2_stride = k2_grid * 256 - k2_niters = (k2_total + k2_stride - 1) // k2_stride - - ck = base_key + ("4k_fused",) - cf = _prefill_cf_cache.get(ck) - if cf is not None: - cf( - topk_ids, - workspace, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - mesh_stride, - mesh_size, - i32_moe_buf_elems, - ws_total, - k2_niters, - k1_grid, - k2_grid, - k4_grid, - stream_arg, - ) - else: - launch_4k_fused( - topk_ids, - workspace, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - mesh_stride, - mesh_size, - i32_moe_buf_elems, - ws_total, - k2_niters, - k1_grid, - k2_grid, - k4_grid, - stream=stream, - ) - cf = flyc.compile( - launch_4k_fused, - topk_ids, - workspace, - topk_weights, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - moe_buf_i32, - expert_mask, - i32_tokens, - mesh_stride, - mesh_size, - i32_moe_buf_elems, - ws_total, - k2_niters, - k1_grid, - k2_grid, - k4_grid, - stream_arg, - ) - _prefill_cf_cache[ck] = cf + + return sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf