From 0bea0f6b72b4af0bd7aa96c1889df6aae708cd9c Mon Sep 17 00:00:00 2001 From: yanboshao Date: Sun, 17 May 2026 14:25:05 +0000 Subject: [PATCH 1/4] refactor(dispatch_combine): minimize FlyDSL core deltas, inline helpers Trim the FlyDSL Python helper surface introduced by the dispatch/combine kernel down to what is strictly necessary, by leaning on existing main-branch idioms and pushing small kernel-only wrappers into the kernel file itself. FlyDSL helper modules - python/flydsl/expr/arith.py: revert to origin/main. Drop the unused divui/remui/select_by_index extensions, and remove zext_i64 in favor of a kernel-local _to_i64 helper that wraps arith.extui(_lv_unwrap(...)). - python/flydsl/expr/vector.py: revert to origin/main. Drop the bitcast_i32_to_v2bf16/bitcast_v2bf16_to_i32 helpers; the kernel now uses the standard vector.from_elements + vector.bitcast + vector.extract idiom (mirrors kernels/hgemm_splitk.py:578-585). - python/flydsl/expr/rocdl/__init__.py: replace the bespoke ballot_i64 / readlane wrappers with generic ballot(res, pred, **kw) and readlane(res, src, lane, **kw) functions, aligned with the existing readfirstlane(res, src, **kw) style: capture the ODS-generated symbols as _ods_ballot / _ods_readlane up top, and use _to_ir coercion in the wrappers. Lets call sites pick the lane-mask width (i32 on wave32, i64 on wave64) explicitly. Kernel - kernels/dispatch_combine_intranode_kernel.py: - Add three file-local helpers: _to_i64, _i32_to_vec_bitcast, _vec_to_i32_bitcast (with docstrings pointing at the main-branch idioms they mirror). - Replace 31 arith.zext_i64(x) call sites with _to_i64(x); collapse two arith.zext_i64(arith.constant(rank)) sites into arith.constant(rank, type=T.i64()). - Update the 4 llvm_bitcast call sites to use the new _i32_to_vec_bitcast / _vec_to_i32_bitcast helpers. - Update ballot_i64(...) / readlane(...) call sites to the new generic APIs: ballot(T.i64(), pred), readlane(T.i32(), src, lane). 
Net effect vs origin/main: arith.py and vector.py are now untouched; rocdl/__init__.py keeps a +22 line delta (generic ballot/readlane wrappers). All complexity that used to live in FlyDSL core has moved into the kernel file where it belongs. Verified - torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py --mode verify -> ALL PASS (diff=0 on dispatch + combine) - torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py --mode verify --enable-std-moe -> ALL PASS (max_diff=0.015625 within StdMoE weighted tolerance) --- kernels/dispatch_combine_intranode_kernel.py | 1571 ++++++++++++++++++ python/flydsl/expr/arith.py | 21 +- python/flydsl/expr/rocdl/__init__.py | 27 +- 3 files changed, 1607 insertions(+), 12 deletions(-) create mode 100644 kernels/dispatch_combine_intranode_kernel.py diff --git a/kernels/dispatch_combine_intranode_kernel.py b/kernels/dispatch_combine_intranode_kernel.py new file mode 100644 index 000000000..d17b8330e --- /dev/null +++ b/kernels/dispatch_combine_intranode_kernel.py @@ -0,0 +1,1571 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""FlyDSL intranode dispatch / combine kernels for expert-parallel MoE. + +This module exposes :func:`make_dispatch_jit` and :func:`make_combine_jit`, +which generate ``@flyc.jit`` launchers wrapping the two intranode kernels +(``ep_dispatch_intranode`` / ``ep_combine_intranode``). The kernels are +implemented mostly with FlyDSL high-level syntax (operator overloading, +``if``/``for`` lowered to ``scf.if``/``scf.for`` by the AST rewriter, +``buffer_load``/``buffer_store`` for global access); a few low-level MLIR +helpers below provide system-scope atomics and pointer-cast intrinsics that +do not yet have a high-level wrapper. 
+""" + +from __future__ import annotations + +import flydsl.compiler as flyc +import flydsl.expr as fx +import torch + +import mori.ir.flydsl as mori_shmem + +from flydsl.expr import T, arith, const_expr, range_constexpr +from flydsl.expr.buffer_ops import ( + buffer_load, + buffer_store, + create_buffer_resource_from_addr, +) +from flydsl.expr.rocdl import ballot, readlane +from flydsl.expr.typing import Stream +from flydsl.expr import vector +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr + +# Low-level MLIR escape hatches: system-scope atomics, pointer casts, and +# raw LLVM ops used inside the SIMD micro-kernel helpers +# (``_accum_experts`` / ``_weighted_accum_experts``). +from flydsl._mlir import ir +from flydsl._mlir import ir as _ir +from flydsl._mlir.dialects import llvm as _llvm_d +from flydsl._mlir.ir import IntegerAttr as _IntAttr, IntegerType as _IntTy + + +def _lv_unwrap(v): + """Return the underlying ``ir.Value`` for *v*. + + Accepts an ``ir.Value`` directly, any FlyDSL Numeric wrapper + (Int32/Int64/Float32/..., exposing ``ir_value()`` or the + ``_extract_to_ir_values`` protocol), or a Python ``int`` literal + (materialized as an ``i32`` constant). + """ + if isinstance(v, _ir.Value): + return v + if hasattr(v, "ir_value"): + return v.ir_value() + if hasattr(v, "_extract_to_ir_values"): + vals = v._extract_to_ir_values() + if len(vals) == 1: + return vals[0] + raise ValueError(f"Expected 1 ir.Value, got {len(vals)}") + if isinstance(v, int): + _i32 = _IntTy.get_signless(32) + return _llvm_d.ConstantOp(_i32, _IntAttr.get(_i32, v)).result + raise TypeError(f"Cannot convert {type(v).__name__} to ir.Value") + + +def _to_i64(v): + """Zero-extend i32 (Numeric / ArithValue / ir.Value) to i64 ``ArithValue``. + + Thin wrapper over ``arith.extui`` so call sites can pass DSL Numeric + types directly without manual unwrap. Used pervasively to widen i32 + indices/offsets into i64 byte offsets for P2P address arithmetic. 
+ """ + return arith.extui(T.i64(), _lv_unwrap(v)) + + +def _i32_to_vec_bitcast(target_vec_type, i32_scalar): + """Bitcast an i32 scalar to ``target_vec_type`` (e.g. ``vector<2xbf16>``). + + ``vector.bitcast`` only handles vector-to-vector reinterpretation, so we + first lift the i32 scalar to ``vector<1xi32>`` via ``vector.from_elements`` + and then let ``vector.bitcast`` widen the element count. Mirrors the + idiom used in ``kernels/hgemm_splitk.py`` (see L578-585) for splitting + i64 fragments into f16x4 WMMA operands. + """ + return vector.bitcast( + target_vec_type, + vector.from_elements(T.VectorType.get([1], T.i32()), [i32_scalar]), + ) + + +def _vec_to_i32_bitcast(vec_val): + """Bitcast a 32-bit vector (e.g. ``vector<2xbf16>``) back to an i32 scalar. + + Inverse of :func:`_i32_to_vec_bitcast`: ``vector.bitcast`` to + ``vector<1xi32>`` first, then ``vector.extract([0])`` to peel the lone + lane out as a scalar. + """ + return vector.extract( + vector.bitcast(T.VectorType.get([1], T.i32()), vec_val), + static_position=[0], + ) + + +def _to_ptr_global(v): + """Cast an i64 address to ``!llvm.ptr<1>`` (global address space).""" + return _llvm_d.IntToPtrOp( + _llvm_d.PointerType.get(address_space=1), _lv_unwrap(v)).result + + +def store_i32_system(addr_i64, offset, val): + """System-scope monotonic i32 store at ``addr_i64 + offset*4`` bytes.""" + base = _lv_unwrap(addr_i64) + off = _lv_unwrap(offset) + val_ = _lv_unwrap(val) + _i64 = _IntTy.get_signless(64) + _i32 = _IntTy.get_signless(32) + _nuw = _ir.Attribute.parse("#llvm.overflow<nuw>") + off64 = _llvm_d.ZExtOp(_i64, off).res if off.type == _i32 else off + byte_off = _llvm_d.MulOp( + off64, _llvm_d.ConstantOp(_i64, _IntAttr.get(_i64, 4)).result, _nuw).result + addr = _llvm_d.AddOp(base, byte_off, _nuw).result + gptr = _llvm_d.IntToPtrOp( + _llvm_d.PointerType.get(address_space=1), addr).result + _llvm_d.StoreOp(val_, gptr, alignment=4, + ordering=_llvm_d.AtomicOrdering.monotonic, + syncscope="one-as") + + +def 
store_i64_global_system(addr_i64, val): + """System-scope monotonic i64 store to ``addr_i64``.""" + gptr = _to_ptr_global(addr_i64) + _llvm_d.StoreOp(_lv_unwrap(val), gptr, alignment=8, + ordering=_llvm_d.AtomicOrdering.monotonic, + syncscope="one-as") + + +def load_i64_global(addr_i64): + """Relaxed global i64 load from ``addr_i64``.""" + ptr = _to_ptr_global(addr_i64) + _i64 = _IntTy.get_signless(64) + return _llvm_d.LoadOp(_i64, ptr, alignment=8).result + + +def atomic_add_global_at(addr_i64, val): + """Monotonic global ``atomic fetch-and-add``; returns the old value.""" + ptr = _to_ptr_global(addr_i64) + return _llvm_d.AtomicRMWOp( + _llvm_d.AtomicBinOp.add, ptr, _lv_unwrap(val), + _llvm_d.AtomicOrdering.monotonic).res + + +# NOTE: explicit ``index``/``i32`` casts that used to wrap every for-loop bound +# and induction-variable in this file have been removed. FlyDSL's +# ``scf_for_dispatch`` accepts i32/Python-int bounds directly and yields an +# i32 IV; ``SmemPtr.{load,store}`` runs each index through +# ``get_index_value`` which materializes/casts to ``index`` on demand. So +# ``arith.index_cast(T.index(), x)`` / ``arith.index_cast(T.i32(), iv)`` +# everywhere was a no-op IR-wise and pure boilerplate. + + +def make_dispatch_kernel( + *, + rank: int, + npes: int, + experts_per_rank: int, + experts_per_token: int, + hidden_dim: int, + hidden_elem_size: int, + max_tok_per_rank: int, + block_num: int, + warp_num_per_block: int, + scale_dim: int = 0, + scale_type_size: int = 0, + enable_std_moe: bool = False, + data_type=None, +): + """Build the intranode dispatch ``@flyc.kernel``. + + Schedules ``cur_tok * experts_per_token`` work items across all + ``block_num * warp_num_per_block`` warps. Each warp: + + 1. resolves the (dest_pe, dest_tok) slot via atomic counters on the + remote rank's ``shmem_tok_off``, + 2. P2P-writes token embedding, weights and indices into the + destination's symmetric shmem buffers, + 3. 
publishes a "send done" signal to every peer and waits for the + dual signal from each peer so it can finalize ``total_recv``, + 4. when ``enable_std_moe`` is set, performs ConvertDispatchOutput + (per-expert packing for the std-MoE expert path). + """ + max_recv = npes * max_tok_per_rank + _is_fp4 = (data_type == torch.float4_e2m1fn_x2) + if _is_fp4: + n_i32 = hidden_dim // 8 # 8 fp4 values per i32 (4 bytes) + nbytes = hidden_dim // 2 # 2 fp4 values per byte + else: + n_i32 = (hidden_dim * hidden_elem_size) // 4 + nbytes = hidden_dim * hidden_elem_size + scale_bytes = scale_dim * scale_type_size + scale_n_i32 = (scale_bytes + 3) // 4 if scale_bytes > 0 else 0 + enable_scales = scale_bytes > 0 + max_tokens_per_expert = npes * max_tok_per_rank # per-expert bucket capacity + + @flyc.kernel + def ep_dispatch_intranode( + addr_inp_tok: fx.Int64, # [cur_tok, hidden_dim] bf16 + addr_idx: fx.Int64, # [cur_tok, k] i32 (token_indices) + addr_wts: fx.Int64, # [cur_tok, k] f32 (weights_buf) + addr_out_tok: fx.Int64, # shmem_out_tok + addr_out_wts: fx.Int64, # shmem_out_wts + addr_out_idx: fx.Int64, # shmem_out_idx + addr_tok_off: fx.Int64, # shmem_tok_off (i32[1]) + addr_recv_num: fx.Int64, # recv_tok_num (i32[npes]) + addr_dest_ctr: fx.Int64, # dest_pe_ctr (i32[npes]) + addr_disp_bar: fx.Int64, # dispatch_bar (i32[1]) + addr_tok_map: fx.Int64, # dest_tok_map (i32[cur_tok*k]) + addr_tis: fx.Int64, # tok_id_to_src (i32[max_recv]) + addr_total_rv: fx.Int64, # total_recv (i32[1]) + # Pre-resolved P2P address arrays (i64[npes]): the remote-side base + # of each symmetric shmem buffer on every peer PE. 
+ addr_p2p_tok_off: fx.Int64, + addr_p2p_tis: fx.Int64, + addr_p2p_out_wts: fx.Int64, + addr_p2p_out_idx: fx.Int64, + addr_p2p_out_tok: fx.Int64, + addr_p2p_recv_num: fx.Int64, + addr_scales: fx.Int64, # input scales buffer + addr_p2p_out_scales: fx.Int64, # i64[npes] P2P addresses of scales buffer + # ── StdMoE ConvertDispatchOutput parameters ── + addr_packed_recv_x: fx.Int64, # expert-major token buffer + addr_packed_recv_count: fx.Int64, # per-expert token count (i32[experts_per_rank]) + addr_packed_recv_src_info: fx.Int64, # source info (i32[experts_per_rank * max_tok_per_expert]) + addr_disp_tok_map: fx.Int64, # slot mapping (i64[max_recv * top_k]) + addr_disp_grid_bar: fx.Int64, # grid barrier (i32[1]) + cur_tok: fx.Int32, # runtime token count for the current batch + ): + tid = fx.thread_idx.x # thread id within the block + bid = fx.block_idx.x # block id within the grid + lane = tid & 63 # lane id within the warp (0..63) + warp = tid >> 6 # warp id within the block + global_warp_id = bid * warp_num_per_block + warp # warp id across the grid + global_warp_num = block_num * warp_num_per_block # total warps in the grid + work_limit = cur_tok * experts_per_token # total (token, k-slot) pairs + _r_idx = create_buffer_resource_from_addr(addr_idx) + _r_wts = create_buffer_resource_from_addr(addr_wts) + _r_tok_map = create_buffer_resource_from_addr(addr_tok_map) + _r_tok_off = create_buffer_resource_from_addr(addr_tok_off) + _r_dest_ctr = create_buffer_resource_from_addr(addr_dest_ctr) + _r_disp_bar = create_buffer_resource_from_addr(addr_disp_bar) + _r_total_rv = create_buffer_resource_from_addr(addr_total_rv) + _r_p2p_tok_off = create_buffer_resource_from_addr(addr_p2p_tok_off) + _r_p2p_tis = create_buffer_resource_from_addr(addr_p2p_tis) + _r_p2p_out_wts = create_buffer_resource_from_addr(addr_p2p_out_wts) + _r_p2p_out_idx = create_buffer_resource_from_addr(addr_p2p_out_idx) + _r_p2p_out_tok = create_buffer_resource_from_addr(addr_p2p_out_tok) + _r_p2p_recv_num 
= create_buffer_resource_from_addr(addr_p2p_recv_num) + + # Phase 1: P2P-scatter tokens to their destination PEs. + # Iteration space: every (src_tok, k_slot) pair, distributed across + # all grid-wide warps. ``k_slot`` is the per-token expert slot index + # (i.e. which of the top-k experts this work-item handles). + for work_idx in range(global_warp_id, work_limit, global_warp_num): + src_tok = (work_idx // experts_per_token) + k_slot = (work_idx % experts_per_token) + # Issue the two idx loads in parallel; divui is deferred so the + # loads do not block on the integer divide. + dest_expert = buffer_load(_r_idx, work_idx, vec_width=1, dtype=T.i32()) + safe_lane = arith.select(lane < k_slot, lane, 0) + lane_expert = buffer_load(_r_idx, src_tok * experts_per_token + safe_lane, vec_width=1, dtype=T.i32()) + dest_pe = (dest_expert // experts_per_rank) + lane_dest_pe = (lane_expert // experts_per_rank) + # Per-lane "is this lane a duplicate destPE assignment for some + # k_slot earlier than the current one?" (sentinel 64 = no). + dup_per_lane = arith.select( + lane_dest_pe == dest_pe, + arith.select(lane < k_slot, lane, 64), + 64) + dup_ballot = ballot(T.i64(), dup_per_lane < 64) + is_dup = dup_ballot != 0 + + # Atomically allocate dest_tok_id on lane 0, then readlane-broadcast. + dest_tok_lane0 = arith.constant(0) + if lane == 0: + if dup_ballot == 0: + dest_tok_lane0 = atomic_add_global_at( + buffer_load(_r_p2p_tok_off, dest_pe, vec_width=1, dtype=T.i64()), + arith.constant(1)) + dest_tok_id = readlane(T.i32(), dest_tok_lane0, 0) + + # Per-(token, k_slot) entry stored into dest_tok_map: encoded + # global slot id, or sentinel ``npes * max_recv`` for dup-slots + # which the combine kernel will treat as "no source". 
+ sentinel_val = npes * max_recv + tok_map_entry = arith.select( + is_dup, + sentinel_val, + dest_pe * max_recv + dest_tok_id) + if lane == 0: + buffer_store(tok_map_entry, _r_tok_map, work_idx) + + if lane == 0: + if dup_ballot == 0: + # Publish the (src_pe, src_lid) origin so the dest PE + # can later route the token back during combine. + src_tok_enc = rank * max_tok_per_rank + src_tok + _r_tis_remote = create_buffer_resource_from_addr( + buffer_load(_r_p2p_tis, dest_pe, vec_width=1, dtype=T.i64())) + buffer_store(src_tok_enc, _r_tis_remote, dest_tok_id) + dest_ctr_addr = addr_dest_ctr + _to_i64(dest_pe) * 4 + atomic_add_global_at(dest_ctr_addr, arith.constant(1)) + + # Each lane writes one (weight, expert_idx) entry to the dest + # PE's symmetric weights / idx buffers, parallel over k_slot. + if lane < experts_per_token: + if dup_ballot == 0: + wt_src_off = src_tok * experts_per_token + lane + wt_val = buffer_load(_r_wts, wt_src_off, vec_width=1, dtype=T.f32()) + idx_val = buffer_load(_r_idx, wt_src_off, vec_width=1, dtype=T.i32()) + dest_slot = dest_tok_id * experts_per_token + lane + _r_wts_remote = create_buffer_resource_from_addr( + buffer_load(_r_p2p_out_wts, dest_pe, vec_width=1, dtype=T.i64())) + buffer_store(arith.bitcast(T.i32(), wt_val), _r_wts_remote, dest_slot) + _r_idx_remote = create_buffer_resource_from_addr( + buffer_load(_r_p2p_out_idx, dest_pe, vec_width=1, dtype=T.i64())) + buffer_store(idx_val, _r_idx_remote, dest_slot) + + if const_expr(enable_scales): + if lane < scale_n_i32: + if dup_ballot == 0: + _r_scales = create_buffer_resource_from_addr(addr_scales) + sc_src_off = src_tok * scale_n_i32 + lane + sc_val = buffer_load(_r_scales, sc_src_off, vec_width=1, dtype=T.i32()) + sc_dst_off = dest_tok_id * scale_n_i32 + lane + _r_sc_remote = create_buffer_resource_from_addr( + buffer_load( + create_buffer_resource_from_addr(_lv_unwrap(addr_p2p_out_scales)), + dest_pe, vec_width=1, dtype=T.i64())) + buffer_store(sc_val, _r_sc_remote, sc_dst_off) 
+ + # Token-embedding scatter: when ``is_dup`` the copy_end equals + # ``lane_i32_off`` and the loop trips zero iterations. + # + # ``lane_i32_off`` - this lane's starting i32 offset (each lane + # owns 4 consecutive i32 = 16 bytes). + # ``chunk_i32_off`` - sliding i32 offset within the token's + # hidden-dim chunk being copied this step. + remote_tok_addr = buffer_load(_r_p2p_out_tok, dest_pe, vec_width=1, dtype=T.i64()) + \ + _to_i64(dest_tok_id) * nbytes + local_tok_addr = addr_inp_tok + _to_i64(src_tok) * nbytes + rsrc_src = create_buffer_resource_from_addr(local_tok_addr) + rsrc_dst = create_buffer_resource_from_addr(remote_tok_addr) + lane_i32_off = lane * 4 + safe_end_i32 = (n_i32 // 512) * 512 # largest multiple of 512 that fits + if const_expr(n_i32 >= 512 and safe_end_i32 > 0): + copy_end_main = arith.select(is_dup, lane_i32_off, safe_end_i32) + for chunk_i32_off in range(lane_i32_off, copy_end_main, 512): + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + vec_b = buffer_load(rsrc_src, chunk_i32_off + 256, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + buffer_store(vec_b, rsrc_dst, chunk_i32_off + 256) + if const_expr(safe_end_i32 < n_i32): + copy_end_tail = arith.select(is_dup, lane_i32_off, n_i32) + for chunk_i32_off in range(lane_i32_off + safe_end_i32, copy_end_tail, 256): + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + elif const_expr(n_i32 < 512): + copy_end_small = arith.select(is_dup, lane_i32_off, n_i32) + for chunk_i32_off in range(lane_i32_off, copy_end_small, 256): + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + + # Phase 2: grid barrier + publish per-peer token-count signal. + # ``recv_num`` is a symmetric ``i32[npes]`` array: index ``src_pe`` + # on dest holds the count of tokens that ``src_pe`` will send. 
+ fx.barrier() + if tid == 0: + atomic_add_global_at(addr_disp_bar, arith.constant(1)) + + recv_num_local_byte_off = arith.constant(rank, type=T.i64()) * 4 + for dest_pe in range(lane, npes, 64): + if global_warp_id == 0: + mori_shmem.int32_wait_until_equals(addr_disp_bar, block_num) + buffer_store(arith.constant(0), _r_disp_bar, 0) + # +1 because 0 is the "unset" sentinel that consumers wait on. + signal_value = buffer_load(_r_dest_ctr, dest_pe, vec_width=1, dtype=T.i32()) + 1 + recv_num_remote_addr = buffer_load( + _r_p2p_recv_num, dest_pe, vec_width=1, dtype=T.i64()) + recv_num_local_byte_off + mori_shmem.int32_wait_until_equals(recv_num_remote_addr, 0) + store_i32_system(recv_num_remote_addr, arith.constant(0), signal_value) + + # Phase 3: wait for each peer's count signal and accumulate total_recv. + for src_pe in range(lane, npes, 64): + if global_warp_id == 0: + recv_num_src_addr = addr_recv_num + _to_i64(src_pe) * 4 + signal_value = mori_shmem.int32_wait_until_greater_than(recv_num_src_addr, 0) + peer_recv_count = signal_value - 1 # undo the +1 sentinel offset + store_i32_system(recv_num_src_addr, arith.constant(0), arith.constant(0)) + atomic_add_global_at(addr_total_rv, peer_recv_count) + buffer_store(arith.constant(0), _r_dest_ctr, src_pe) + + if global_warp_id == 0: + if lane == 0: + buffer_store(arith.constant(0), _r_tok_off, 0) + + # Phase 4: ConvertDispatchOutput (StdMoE). + # Repack received tokens into per-expert buckets indexed by + # ``local_expert_id``. Each (received_tok, k_slot) pair allocates a + # slot in ``packed_recv_x[local_expert_id]`` if the expert is local. 
+ if const_expr(enable_std_moe): + fx.barrier() + if tid == 0: + atomic_add_global_at(addr_disp_grid_bar, arith.constant(1)) + fx.barrier() + if tid == 0: + mori_shmem.int32_wait_until_equals(addr_disp_grid_bar, block_num) + fx.barrier() + + _r_out_idx_local = create_buffer_resource_from_addr(addr_out_idx) + _r_tis_local = create_buffer_resource_from_addr(addr_tis) + _r_out_tok_local = create_buffer_resource_from_addr(addr_out_tok) + total_recv = buffer_load(_r_total_rv, 0, vec_width=1, dtype=T.i32()) + smoe_work_limit = total_recv * experts_per_token + + for smoe_idx in range(global_warp_id, smoe_work_limit, global_warp_num): + smoe_tok_id = (smoe_idx // experts_per_token) + + expert_id = buffer_load(_r_out_idx_local, smoe_idx, vec_width=1, dtype=T.i32()) + local_expert_id = expert_id - rank * experts_per_rank + # MUST be unsigned ``ult``: when ``expert_id`` is NOT this + # rank's expert, ``local_expert_id`` is negative; the + # signed-overload form ``local_expert_id < experts_per_rank`` + # lowers to ``arith.cmpi slt`` and would mis-classify negative + # values as local (-> illegal global access in WarpCopy). + is_local = arith.cmpi(arith.CmpIPredicate.ult, local_expert_id, + arith.constant(experts_per_rank)) + + # Atomically allocate the per-expert packing slot on lane 0. + packed_slot_lane0 = arith.constant(0) + if lane == 0: + if is_local: + count_addr = addr_packed_recv_count + _to_i64(local_expert_id) * 4 + packed_slot_lane0 = atomic_add_global_at(count_addr, arith.constant(1)) + packed_slot = readlane(T.i32(), packed_slot_lane0, 0) + + safe_local_expert = arith.select(is_local, local_expert_id, 0) + # Linear slot in the flat ``packed_recv_x[experts_per_rank, max_tokens_per_expert]`` buffer. 
+ packed_linear_idx = safe_local_expert * max_tokens_per_expert + packed_slot + slot_val_i64 = arith.select(is_local, + _to_i64(packed_linear_idx), + -1) # false_value materialized as i64 from true_value's type; -1 = not a local expert + if lane == 0: + slot_map_addr = addr_disp_tok_map + _to_i64(smoe_idx) * 8 + store_i64_global_system(slot_map_addr, slot_val_i64) + + if lane == 0: + if is_local: + src_pos_enc = buffer_load(_r_tis_local, smoe_tok_id, + vec_width=1, dtype=T.i32()) + store_i32_system(addr_packed_recv_src_info, + packed_linear_idx, src_pos_enc) + + # WarpCopy token data from shmem_out_tok into the packed + # per-expert buffer at slot ``packed_linear_idx``. + src_tok_base = addr_out_tok + _to_i64(smoe_tok_id) * nbytes + dst_tok_base = addr_packed_recv_x + _to_i64(packed_linear_idx) * nbytes + rsrc_src = create_buffer_resource_from_addr(src_tok_base) + rsrc_dst = create_buffer_resource_from_addr(dst_tok_base) + lane_i32_off = lane * 4 + safe_end_i32 = (n_i32 // 512) * 512 + if n_i32 >= 512 and safe_end_i32 > 0: + copy_end_main = arith.select(is_local, safe_end_i32, lane_i32_off) + for chunk_i32_off in range(lane_i32_off, copy_end_main, 512): + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + vec_b = buffer_load(rsrc_src, chunk_i32_off + 256, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + buffer_store(vec_b, rsrc_dst, chunk_i32_off + 256) + if safe_end_i32 < n_i32: + copy_end_tail = arith.select(is_local, n_i32, lane_i32_off) + for chunk_i32_off in range(lane_i32_off + safe_end_i32, copy_end_tail, 256): + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + elif n_i32 < 512: + copy_end_small = arith.select(is_local, n_i32, lane_i32_off) + for chunk_i32_off in range(lane_i32_off, copy_end_small, 256): + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + + return 
ep_dispatch_intranode + + +def make_combine_kernel( + *, + rank: int, + npes: int, + experts_per_rank: int = 0, + experts_per_token: int, + hidden_dim: int, + hidden_elem_size: int, + max_tok_per_rank: int, + block_num: int, + warp_num_per_block: int, + data_type=None, + enable_weights: bool = False, + enable_std_moe: bool = False, + use_p2p_read: bool = False, + skip_stage1: bool = False, + inp_data_type=None, +): + """Build the intranode combine ``@flyc.kernel``. + + Stages: + * Stage 1 - P2P-scatter token contributions (and optionally weights) + from each rank's GEMM2 output buffer into every peer's + ``shmem_comb_inp``. + * Stage 2 - CrossDeviceBarrier so every rank has observed Stage 1 + writes from every peer. + * Stage 3 - local read of ``shmem_comb_inp`` plus per-expert WarpAccum + reducing into ``addr_comb_out``. + * Stage 3b - parallel weight accumulation (when ``enable_weights``). + + Parameters: + skip_stage1: + Compile-out the token half of Stage 1 (P2P scatter / + ConvertCombineInput). The caller is expected to have staged token + bytes into ``shmem_comb_inp`` ahead of the launch (e.g. fused + GEMM2-epilogue P2P scatter). Weight scatter is still emitted when + ``enable_weights`` is set, because the 16B weight writes share the + ROCm IPC fabric with the heavy token writes from the upstream stage + and get silently dropped under contention — the combine kernel + therefore owns weight scatter on a quiet fabric. + inp_data_type: + External input dtype. When different from ``data_type`` (currently + only ``bfloat16 + float8_e4m3fn`` is supported) Stage 1 fuses a + bf16 -> fp8 cast inline (``UseFp8DirectCast``-equivalent), and + Stage 3 widens addressing strides for bf16 output writes. 
+ """ + max_recv = npes * max_tok_per_rank + _is_fp4 = (data_type == torch.float4_e2m1fn_x2) + if _is_fp4: + n_i32 = hidden_dim // 8 + nbytes = hidden_dim // 2 + else: + n_i32 = (hidden_dim * hidden_elem_size) // 4 + nbytes = hidden_dim * hidden_elem_size + tok_stride = n_i32 * 4 + + # Mixed-dtype combine: external dtype (kernel input AND output) differs + # from the on-wire/staging dtype used for P2P transport. Currently only + # supports bf16 external + OCP fp8 transport (mori UseFp8DirectCast). + # + # Semantics ("fp8_direct_cast"): + # - kernel reads bf16 input → Stage 1 inline cast bf16→fp8 → P2P fp8 + # - kernel reads fp8 staging → Stage 3 reduce in f32 → cast f32→bf16 + # → kernel writes bf16 output to ``addr_comb_out`` + # + # ``inp_data_type`` is the legacy parameter name; conceptually it now + # represents the external (input/output-shared) dtype. + _xfer_bf16_to_fp8 = ( + inp_data_type is not None and inp_data_type != data_type + and inp_data_type == torch.bfloat16 + and data_type == torch.float8_e4m3fn + ) + if inp_data_type is not None and inp_data_type != data_type and not _xfer_bf16_to_fp8: + raise NotImplementedError( + f"combine_kernel mixed-dtype only supports " + f"inp_data_type=bfloat16 + data_type=float8_e4m3fn, " + f"got inp_data_type={inp_data_type}, data_type={data_type}") + if _xfer_bf16_to_fp8 and enable_std_moe: + raise NotImplementedError( + "combine_kernel mixed-dtype path does not yet support " + "enable_std_moe=True (the std-MoE Stage 1 / Stage 3 use " + "_weighted_accum_experts which has not been retrofitted for " + "asymmetric I/O dtypes)") + + if _xfer_bf16_to_fp8: + # bf16 input stride for Stage 1 source addressing only. The transport + # (P2P-scattered staging) uses ``nbytes`` (= fp8 stride) as before. + # Stage 3 output addressing also uses bf16 stride (= 2 × fp8 stride). + inp_nbytes = hidden_dim * 2 + inp_n_i32 = (hidden_dim * 2) // 4 + # bf16-stride i32 count per token for Stage 3 output offsets. 
+ out_n_i32 = (hidden_dim * 2) // 4 + else: + inp_nbytes = nbytes + inp_n_i32 = n_i32 + out_n_i32 = n_i32 + if _is_fp4: + from flydsl._mlir.dialects import rocdl as _rocdl_d + _v2f32_fp4 = T.VectorType.get([2], T.f32()) + _v8f32_fp4 = T.VectorType.get([8], T.f32()) + + def _to_accum(i32_val): + # ROCDL fp4 lane unpack: i32 (8 packed fp4) -> 4 × vector<2xf32>. + scale_one = arith.constant(1.0, type=T.f32()) + pairs = [ + _rocdl_d.cvt_scalef32_pk_f32_fp4( + res=_v2f32_fp4, src=i32_val, scale=scale_one, + src_sel_index=sel) + for sel in range(4) + ] + # Stitch 4 × v2f32 -> v8f32 via two-stage shuffle. + lo4 = vector.shuffle(pairs[0], pairs[1], [0, 1, 2, 3]) + hi4 = vector.shuffle(pairs[2], pairs[3], [0, 1, 2, 3]) + return vector.shuffle(lo4, hi4, [0, 1, 2, 3, 4, 5, 6, 7]) + + def _from_accum(accum_val): + # Re-pack v8f32 -> i32 via 4 × cvt_scalef32_pk_fp4_f32. + _i32_ty = _IntTy.get_signless(32) + scale_one = arith.constant(1.0, type=T.f32()) + old = arith.constant(0, type=_i32_ty) + for sel in range(4): + f_a = vector.extract(accum_val, static_position=[sel * 2]) + f_b = vector.extract(accum_val, static_position=[sel * 2 + 1]) + old = _rocdl_d.cvt_scalef32_pk_fp4_f32( + res=_i32_ty, old_vdst=old, src0=f_a, src1=f_b, + scale=scale_one, dst_sel_index=sel) + return old + + def _zero_accum(): + return arith.constant_vector(0.0, _v8f32_fp4) + + elif hidden_elem_size == 2: # bf16 + def _to_accum(i32_val): + return _i32_to_vec_bitcast(T.VectorType.get([2], T.bf16()), i32_val).extf( + T.VectorType.get([2], T.f32())) + def _from_accum(accum_val): + return _vec_to_i32_bitcast(accum_val.truncf( + T.VectorType.get([2], T.bf16()))) + def _zero_accum(): + return arith.constant_vector(0.0, T.VectorType.get([2], T.f32())) + elif hidden_elem_size == 4: # f32 + def _to_accum(i32_val): + return arith.bitcast(T.f32(), i32_val) + def _from_accum(accum_val): + return arith.bitcast(T.i32(), accum_val) + def _zero_accum(): + return arith.constant(0.0, type=T.f32()) + elif hidden_elem_size 
== 1: # fp8 + from flydsl._mlir.dialects import rocdl as _rocdl_d + _is_ocp = (data_type == torch.float8_e4m3fn) + _is_fnuz = (data_type == torch.float8_e4m3fnuz) + _cvt_pk_f32 = _rocdl_d.cvt_pk_f32_fp8 + _cvt_pk_f8 = _rocdl_d.cvt_pk_fp8_f32 + _v2f32_fp8 = T.VectorType.get([2], T.f32()) + _v4f32_fp8 = T.VectorType.get([4], T.f32()) + + def _to_accum(i32_val): + # ROCDL fp8 lane unpack: i32 (4 packed fp8) -> 2 × vector<2xf32>. + lo = _cvt_pk_f32(res=_v2f32_fp8, src=i32_val, word_sel=False) + hi = _cvt_pk_f32(res=_v2f32_fp8, src=i32_val, word_sel=True) + # Concatenate lo|hi -> vector<4xf32> (mask picks lo[0,1], hi[0,1]). + vec = vector.shuffle(lo, hi, [0, 1, 2, 3]) + if _is_fnuz: + vec = vec * 0.5 + return vec + + def _from_accum(accum_val): + _i32_ty = _IntTy.get_signless(32) + if _is_fnuz: + accum_val = accum_val * 2.0 + if const_expr(_xfer_bf16_to_fp8): + # Mixed-dtype path: write bf16 output (8 bytes per lane). + # v4f32 -> v4bf16 (truncf) -> v2i32 (bitcast). Caller stores + # via buffer_store(..., vec_width=2, dtype=T.i32()) at an i32 + # offset doubled relative to fp8 mode (2 i32 = 4 bf16 = 8 B). + _v4bf16 = T.VectorType.get([4], T.bf16()) + _v2i32 = T.VectorType.get([2], _i32_ty) + return vector.bitcast(_v2i32, accum_val.truncf(_v4bf16)) + f0 = vector.extract(accum_val, static_position=[0]) + f1 = vector.extract(accum_val, static_position=[1]) + f2 = vector.extract(accum_val, static_position=[2]) + f3 = vector.extract(accum_val, static_position=[3]) + zero = arith.constant(0, type=_i32_ty) + lo = _cvt_pk_f8(res=_i32_ty, src_a=f0, src_b=f1, + old=zero, word_sel=False) + return _cvt_pk_f8(res=_i32_ty, src_a=f2, src_b=f3, + old=lo, word_sel=True) + + def _zero_accum(): + return arith.constant_vector(0.0, _v4f32_fp8) + else: + raise ValueError(f"Unsupported hidden_elem_size={hidden_elem_size}") + + def _accum_experts(vals, vlds, all_vld): + """Reduce the k per-expert i32 partials into one merged i32. 
def _weighted_accum_experts(vals, wts, vlds, all_vld):
    """Weighted variant of ``_accum_experts``: ``sum(wt[k] * val[k])``.

    Used by the StdMoE Stage 1 path where the kernel reduces the k
    per-expert contributions (each multiplied by the dispatch-time
    output weight) into one merged token before the P2P scatter.

    Each packed i32 in ``vals`` is widened to f32 (scalar or vector,
    depending on ``_is_fp4`` / ``hidden_elem_size``), multiplied by the
    matching entry of ``wts``, summed, and packed back to i32.

    Args:
        vals: per-expert raw i32 values (one per k-slot).
        wts: per-expert weights multiplied into the widened values.
        vlds: per-expert validity flags; only read when ``all_vld`` is false.
        all_vld: when true, skip the masking and add every slot.

    Raises:
        ValueError: if no format branch matches ``hidden_elem_size``.
    """
    _i32ty = _IntTy.get_signless(32)
    _f32ty = T.f32()

    def _reduce(widen, zero):
        # Shared core of every format branch: sum_j widen(vals[j]) * wts[j],
        # with dead slots replaced by ``zero`` unless ``all_vld`` is set.
        acc = zero
        for j, raw in enumerate(vals):
            w = widen(raw) * wts[j]  # scalar weight broadcasts over the vector
            if all_vld:
                acc = acc + w
            else:
                acc = acc + arith.select(vlds[j], w, zero)
        return acc

    if _is_fp4:  # fp4 -> v8f32 accum
        from flydsl._mlir.dialects import rocdl as _rocdl_fp4
        _v2f32 = T.VectorType.get([2], T.f32())
        _v8f32 = T.VectorType.get([8], T.f32())
        scale_one = arith.constant(1.0, type=_f32ty)

        def _widen_fp4(raw):
            # ROCDL fp4 lane unpack: i32 (8 packed fp4) -> 4 x vector<2xf32>.
            pairs = [
                _rocdl_fp4.cvt_scalef32_pk_f32_fp4(
                    res=_v2f32, src=raw, scale=scale_one,
                    src_sel_index=sel)
                for sel in range(4)
            ]
            # Stitch 4 x v2f32 -> v8f32 via two-stage shuffle.
            lo4 = vector.shuffle(pairs[0], pairs[1], [0, 1, 2, 3])
            hi4 = vector.shuffle(pairs[2], pairs[3], [0, 1, 2, 3])
            return vector.shuffle(lo4, hi4, [0, 1, 2, 3, 4, 5, 6, 7])

        acc = _reduce(_widen_fp4, arith.constant_vector(0.0, _v8f32))
        # Re-pack v8f32 -> i32 via 4 x cvt_scalef32_pk_fp4_f32, threading the
        # partially filled word through ``old_vdst``.
        packed = arith.constant(0, type=_i32ty)
        for sel in range(4):
            f_a = vector.extract(acc, static_position=[sel * 2])
            f_b = vector.extract(acc, static_position=[sel * 2 + 1])
            packed = _rocdl_fp4.cvt_scalef32_pk_fp4_f32(
                res=_i32ty, old_vdst=packed, src0=f_a, src1=f_b,
                scale=scale_one, dst_sel_index=sel)
        return packed

    if hidden_elem_size == 2:  # bf16 -> v2f32 accum
        _v2bf16 = T.VectorType.get([2], T.bf16())
        _v2f32 = T.VectorType.get([2], T.f32())

        def _widen_bf16(raw):
            # i32 -> vector<2xbf16> (bitcast helper) -> vector<2xf32> (extf).
            return _i32_to_vec_bitcast(_v2bf16, raw).extf(_v2f32)

        acc = _reduce(_widen_bf16, arith.constant_vector(0.0, _v2f32))
        return _vec_to_i32_bitcast(acc.truncf(_v2bf16))

    if hidden_elem_size == 4:  # f32 -> f32 accum
        def _widen_f32(raw):
            return arith.bitcast(_f32ty, raw)

        acc = _reduce(_widen_f32, arith.constant(0.0, type=_f32ty))
        return arith.bitcast(_i32ty, acc)

    if hidden_elem_size == 1:  # fp8 -> v4f32 accum
        from flydsl._mlir.dialects import rocdl as _rocdl
        _pk_f32 = _rocdl.cvt_pk_f32_fp8
        _pk_f8 = _rocdl.cvt_pk_fp8_f32
        _v2f32 = T.VectorType.get([2], T.f32())
        _v4f32 = T.VectorType.get([4], T.f32())

        def _widen_fp8(raw):
            # ROCDL fp8 lane unpack: i32 (4 packed fp8) -> 2 x vector<2xf32>.
            lo = _pk_f32(res=_v2f32, src=raw, word_sel=False)
            hi = _pk_f32(res=_v2f32, src=raw, word_sel=True)
            # Concatenate lo|hi into vector<4xf32>:
            #   mask [0,1,2,3] -> [lo[0], lo[1], hi[0], hi[1]]
            vec = vector.shuffle(lo, hi, [0, 1, 2, 3])
            if _is_fnuz:
                # Same 0.5 unpack scale that ``_to_accum`` applies for fnuz.
                vec = vec * 0.5
            return vec

        acc = _reduce(_widen_fp8, arith.constant_vector(0.0, _v4f32))
        if _is_fnuz:
            # Undo the unpack scale before re-packing (as in ``_from_accum``).
            acc = acc * 2.0
        f0 = vector.extract(acc, static_position=[0])
        f1 = vector.extract(acc, static_position=[1])
        f2 = vector.extract(acc, static_position=[2])
        f3 = vector.extract(acc, static_position=[3])
        zi = arith.constant(0, type=_i32ty)
        lo = _pk_f8(res=_i32ty, src_a=f0, src_b=f1, old=zi, word_sel=False)
        return _pk_f8(res=_i32ty, src_a=f2, src_b=f3, old=lo, word_sel=True)

    # Previously fell off the end and returned None, which would only fail
    # later at the store site; fail here with the file's existing message.
    raise ValueError(f"Unsupported hidden_elem_size={hidden_elem_size}")
+ lo = _pk_f32(res=_v2f32, src=vals[j], word_sel=False) + hi = _pk_f32(res=_v2f32, src=vals[j], word_sel=True) + # Concatenate lo|hi into vector<4xf32>: + # mask [0,1,2,3] -> [lo[0], lo[1], hi[0], hi[1]] + vec = vector.shuffle(lo, hi, [0, 1, 2, 3]) + if _is_fnuz: + vec = vec * 0.5 + w = vec * wts[j] # wts[j] scalar f32, auto-broadcast to v4f32 + if all_vld: + acc = acc + w + else: + acc = acc + arith.select( + vlds[j], w, arith.constant_vector(0.0, _v4f32)) + if _is_fnuz: + acc = acc * 2.0 + f0 = vector.extract(acc, static_position=[0]) + f1 = vector.extract(acc, static_position=[1]) + f2 = vector.extract(acc, static_position=[2]) + f3 = vector.extract(acc, static_position=[3]) + zi = arith.constant(0, type=_i32ty) + lo = _pk_f8(res=_i32ty, src_a=f0, src_b=f1, old=zi, word_sel=False) + return _pk_f8(res=_i32ty, src_a=f2, src_b=f3, old=lo, word_sel=True) + + def _log2_if_pow2(v): + """Return ``log2(v)`` if *v* is a positive power of two, else ``None``.""" + if v > 0 and (v & (v - 1)) == 0: + return v.bit_length() - 1 + return None + # Pow2 fast-paths: when ``max_tok_per_rank`` / ``max_recv`` are powers + # of two, decode ``dest_pe / dest_lid`` and ``dest_pe / dtok`` via + # shift + mask instead of integer divide / mod. + _log2_max_tok = _log2_if_pow2(max_tok_per_rank) + _log2_max_recv = _log2_if_pow2(max_recv) + _mask_max_tok = max_tok_per_rank - 1 if _log2_max_tok is not None else None + _mask_max_recv = max_recv - 1 if _log2_max_recv is not None else None + + # Dispatch deduplicates same-PE assignments at runtime: when more than + # one of a token's k experts fall on the same dest_pe, the duplicate + # tok_map slot is encoded as ``dest_pe = npes`` (sentinel). The combine + # accumulator must skip those invalid lanes, which is exactly what the + # ``_maybe_load`` helper below does (equivalent to mori's + # ``EpCombineIntraNodeKernel`` ``srcPtrs[j] = nullptr`` short-circuit). 
+ _use_compaction = True + + weight_bytes = experts_per_token * 4 if enable_weights else 0 + wt_n_i32 = experts_per_token if enable_weights else 0 + + # LDS layout for the P2P-base tables (i64[npes] for tokens, optionally + # i64[npes] for weights). ``SmemAllocator.finalize()`` is called from the + # JIT launcher to publish the layout to the GPU module. + allocator = SmemAllocator(None, arch="gfx942") + p2p_base_offset = allocator._align(allocator.ptr, 8) + p2p_base_size = npes * 8 + allocator.ptr = p2p_base_offset + p2p_base_size + + if enable_weights: + p2p_wt_base_offset = allocator._align(allocator.ptr, 8) + p2p_wt_base_size = npes * 8 + allocator.ptr = p2p_wt_base_offset + p2p_wt_base_size + + + @flyc.kernel + def ep_combine_intranode( + addr_inp_tok: fx.Int64, # inp_tok base (post-expert token buffer) + addr_comb_inp: fx.Int64, # shmem_comb_inp base (symmetric) + addr_comb_out: fx.Int64, # shmem_comb_out base (symmetric) + addr_xdb_mem: fx.Int64, # xdev_bar_mem (u64[npes]) + addr_xdb_flag: fx.Int64, # xdev_bar_flag (u64[1]) + addr_tok_map: fx.Int64, # dest_tok_map (i32[cur_tok*k]) + addr_comb_bar: fx.Int64, # combine_bar (i32[1]) + addr_trecv: fx.Int64, # total_recv_ptr (i32[1]) + addr_tis: fx.Int64, # tok_id_to_src (i32[max_recv], symmetric) + addr_p2p_comb_inp: fx.Int64, # i64[npes] pre-resolved P2P addresses + addr_p2p_xdb_mem: fx.Int64, # i64[npes] pre-resolved P2P addresses + addr_wts_buf: fx.Int64, # combine input weights f32[max_recv*k] + addr_comb_inp_wts: fx.Int64, # shmem weight P2P buffer (symmetric) + addr_comb_out_wts: fx.Int64, # combine output weights f32[max_tok*k] + addr_p2p_comb_inp_wts: fx.Int64, # i64[npes] weight P2P addresses + # ── StdMoE ConvertCombineInput parameters ── + addr_packed_recv_x: fx.Int64, # expert-major token buffer (post-expert) + addr_disp_tok_map: fx.Int64, # dispTokToEpSlotMap (i64[max_recv * top_k]) + addr_disp_out_wts: fx.Int64, # dispatch output weights (f32[max_recv * top_k]) + cur_rank_num_token: fx.Int32, # this 
PE's output token count (used by Stage 3) + ): + tid = fx.thread_idx.x + bid = fx.block_idx.x + lane = tid & 63 + warp = tid >> 6 + global_warp_id = bid * warp_num_per_block + warp # warp id across the grid + global_warp_num = block_num * warp_num_per_block # total warps in the grid + grid_thread_id = bid * (warp_num_per_block * 64) + tid # grid-wide thread id (used by Stage 2 only) + + # Predicated buffer_load: returns 0 (i32) when vld_flag is false. + # Defined as a nested function so the AST rewriter lowers the Python + # ``if`` to ``scf.if`` for every call site (the rewriter only walks + # function bodies inside ``@flyc.kernel`` and their nested defs). + def _maybe_load(rsrc, offset, vld_flag, **kwargs): + result = arith.constant(0, type=T.i32()) + if vld_flag: + result = buffer_load(rsrc, offset, **kwargs) + return result + + _r_trecv = create_buffer_resource_from_addr(addr_trecv) + _r_xdb_flag = create_buffer_resource_from_addr(addr_xdb_flag) + _r_tis = create_buffer_resource_from_addr(addr_tis) + _r_comb_bar = create_buffer_resource_from_addr(addr_comb_bar) + _r_p2p_comb = create_buffer_resource_from_addr(addr_p2p_comb_inp) + _r_p2p_xdb = create_buffer_resource_from_addr(addr_p2p_xdb_mem) + _rsrc_tok_map = create_buffer_resource_from_addr(addr_tok_map) + + total_recv = buffer_load(_r_trecv, 0, vec_width=1, dtype=T.i32()) + # Per-launch monotonically-incrementing flag value used by Stage 2's + # cross-device barrier (each rank waits to observe this value from + # every peer). + xdb_cur_flag = buffer_load(_r_xdb_flag, 0, vec_width=1, dtype=T.i64()) + + # LDS-resident table of pre-resolved P2P base addresses (i64[npes]). + # Cached once in shared memory so the Stage 1 scatter loop (which + # may visit thousands of tokens per warp) avoids reissuing a global + # load for the same per-peer base on every iteration. 
+ base_ptr = allocator.get_base() + # NOTE: SmemPtr ops are intentionally written as unbound-class calls + # (``SmemPtr.METHOD(instance, ...)`` rather than ``instance.METHOD(...)``) + # to avoid the upstream ast_rewriter heuristic that treats any + # ``var.method(...)`` inside an scf-lowered if/for as a loop-carried + # variable (which then fails because SmemPtr is not an MLIR Value). + # All ``_lds_p2p_*`` and downstream ``SmemPtr.{get,load,store}`` + # call sites follow the same convention. + _lds_p2p_bases = SmemPtr(base_ptr, p2p_base_offset, T.i64(), + shape=(npes,)) + SmemPtr.get(_lds_p2p_bases) + + if lane < npes: + p2p_base_addr = buffer_load(_r_p2p_comb, lane, vec_width=1, dtype=T.i64()) + SmemPtr.store(_lds_p2p_bases, p2p_base_addr, [lane]) + + if const_expr(enable_weights): + _r_p2p_comb_wt = create_buffer_resource_from_addr(addr_p2p_comb_inp_wts) + _lds_p2p_wt_bases = SmemPtr(base_ptr, p2p_wt_base_offset, T.i64(), + shape=(npes,)) + SmemPtr.get(_lds_p2p_wt_bases) + if lane < npes: + p2p_wt_base_addr = buffer_load(_r_p2p_comb_wt, lane, vec_width=1, dtype=T.i64()) + SmemPtr.store(_lds_p2p_wt_bases, p2p_wt_base_addr, [lane]) + + fx.barrier() + + # Stage 1: P2P scatter / ConvertCombineInput. + # When ``skip_stage1`` is set the entire stage is compile-time + # eliminated; the caller is responsible for having pre-staged the + # equivalent P2P writes into shmem_comb_inp[_wts]. + # + # Common per-token decoding from ``shmem_tok_id_to_src[recv_tok_id]``: + # dest_pe - which peer this token must be combined to + # dest_lid - the per-PE local id ``[0, max_tok_per_rank)`` + n_chunks = nbytes // 16 # 16-byte (4-i32) vector chunks per token + + if const_expr(skip_stage1): + if const_expr(enable_weights): + # Weight-only Stage 1: same as default path but only writes + # the small weight slot (no per-token hidden bytes). Used by + # fused_gemm2_combine to keep weight scatter off the heavy + # token-write fabric.
+ for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): + dest_tok_enc = buffer_load(_r_tis, recv_tok_id, vec_width=1, dtype=T.i32()) + if const_expr(_log2_max_tok is not None): + dest_pe = dest_tok_enc >> _log2_max_tok + dest_lid = dest_tok_enc & _mask_max_tok + else: + dest_pe = (dest_tok_enc // max_tok_per_rank) + dest_lid = (dest_tok_enc % max_tok_per_rank) + wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [dest_pe]) + wt_dest_off = _to_i64( + rank * max_tok_per_rank + dest_lid) * weight_bytes + wt_dest_addr = _lv_unwrap(wt_pe_base) + wt_dest_off + wt_src_addr = _lv_unwrap(addr_wts_buf) + _to_i64(recv_tok_id) * weight_bytes + rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) + rsrc_wt_dst = create_buffer_resource_from_addr(wt_dest_addr) + if lane < wt_n_i32: + wt_val = buffer_load(rsrc_wt_src, lane, vec_width=1, dtype=T.i32()) + buffer_store(wt_val, rsrc_wt_dst, lane) + else: + pass + elif const_expr(enable_std_moe): + # Stage 1 StdMoE: read the k-expert partials from + # ``packed_recv_x`` (per-expert buckets), reduce with the + # dispatch-time output weights, and scatter the merged token to + # the destination PE's ``shmem_comb_inp``. + _rsrc_dtm = create_buffer_resource_from_addr(addr_disp_tok_map) + _rsrc_dow = create_buffer_resource_from_addr(addr_disp_out_wts) + smoe_all_vld = False # k-slots may be sentinel (-1) for non-local experts + + for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): + dest_tok_enc = buffer_load(_r_tis, recv_tok_id, vec_width=1, dtype=T.i32()) + if const_expr(_log2_max_tok is not None): + dest_pe = dest_tok_enc >> _log2_max_tok + dest_lid = dest_tok_enc & _mask_max_tok + else: + dest_pe = (dest_tok_enc // max_tok_per_rank) + dest_lid = (dest_tok_enc % max_tok_per_rank) + + if const_expr(use_p2p_read): + # P2P-read mode: write locally; peers will pull from us in Stage 3. 
+ dest_byte_off = _to_i64(recv_tok_id) * nbytes + dest_tok_addr = _lv_unwrap(addr_comb_inp) + dest_byte_off + else: + peer_base = SmemPtr.load(_lds_p2p_bases, [dest_pe]) + dest_byte_off = _to_i64(rank * max_tok_per_rank + dest_lid) * nbytes + dest_tok_addr = _lv_unwrap(peer_base) + dest_byte_off + rsrc_dst = create_buffer_resource_from_addr(dest_tok_addr) + + # Collect resources/valid-flags/weights for each k-expert slot. + expert_rsrcs = [] + expert_vlds = [] + expert_wts = [] + for k_slot in range_constexpr(experts_per_token): + slot_addr = addr_disp_tok_map + _to_i64(recv_tok_id * experts_per_token + k_slot) * 8 + slot_val = load_i64_global(slot_addr) + slot_vld = slot_val != -1 + safe_slot = arith.select(slot_vld, slot_val, 0) + expert_tok_addr = addr_packed_recv_x + safe_slot * nbytes + expert_rsrcs.append(create_buffer_resource_from_addr(expert_tok_addr)) + expert_vlds.append(slot_vld) + wt_k = buffer_load(_rsrc_dow, recv_tok_id * experts_per_token + k_slot, + vec_width=1, dtype=T.f32()) + expert_wts.append(wt_k) + + # Weighted reduce across the k experts, then scatter. 
+ for elem_off in range(lane, n_i32, 64): + expert_vals = [] + for k_slot in range_constexpr(experts_per_token): + expert_vals.append(buffer_load(expert_rsrcs[k_slot], elem_off, + vec_width=1, dtype=T.i32())) + accum = _weighted_accum_experts(expert_vals, expert_wts, + expert_vlds, smoe_all_vld) + buffer_store(accum, rsrc_dst, elem_off) + + if const_expr(enable_weights): + if const_expr(use_p2p_read): + wt_dest_off = _to_i64(recv_tok_id) * weight_bytes + wt_dest_addr = _lv_unwrap(addr_comb_inp_wts) + wt_dest_off + else: + wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [dest_pe]) + wt_dest_off = _to_i64( + rank * max_tok_per_rank + dest_lid) * weight_bytes + wt_dest_addr = _lv_unwrap(wt_pe_base) + wt_dest_off + wt_src_addr = _lv_unwrap(addr_wts_buf) + _to_i64(recv_tok_id) * weight_bytes + rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) + rsrc_wt_dst = create_buffer_resource_from_addr(wt_dest_addr) + if lane < wt_n_i32: + wt_val = buffer_load(rsrc_wt_src, lane, vec_width=1, dtype=T.i32()) + buffer_store(wt_val, rsrc_wt_dst, lane) + + elif const_expr(use_p2p_read): + # Stage 1 P2P-read mode: every rank writes its post-expert + # tokens into its OWN ``shmem_comb_inp`` slot indexed by + # ``recv_tok_id`` (no remote write). Peers will read these + # buffers cross-device during Stage 3. + dual_end_aligned = (n_chunks // 128) * 128 + for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): + # In mixed-mode (bf16 input → fp8 staging), the source uses + # bf16 stride (inp_nbytes) while the dest uses fp8 stride + # (nbytes); in same-dtype mode the two strides are identical. 
+ src_tok_addr = addr_inp_tok + _to_i64(recv_tok_id) * inp_nbytes + dst_tok_addr = addr_comb_inp + _to_i64(recv_tok_id) * nbytes + rsrc_src = create_buffer_resource_from_addr(src_tok_addr) + rsrc_dst = create_buffer_resource_from_addr(dst_tok_addr) + if const_expr(_xfer_bf16_to_fp8): + # Mixed-dtype Stage 1: load bf16 (2 i32 / lane = 4 bf16 + # elems) → ExtF v4f32 → cvt_pk_fp8_f32 ×2 → store 1 + # fp8 i32 (4 fp8 elems) at staging offset ``elem_off``. + from flydsl._mlir.dialects import rocdl as _rocdl_s1a + _v4bf16_a = T.VectorType.get([4], T.bf16()) + _v4f32_a = T.VectorType.get([4], T.f32()) + _i32t_a = T.i32() + for elem_off in range(lane, n_i32, 64): + bf_pair = buffer_load(rsrc_src, elem_off * 2, + vec_width=2, dtype=T.i32()) + v4f = vector.bitcast(_v4bf16_a, bf_pair).extf(_v4f32_a) + f0 = vector.extract(v4f, static_position=[0]) + f1 = vector.extract(v4f, static_position=[1]) + f2 = vector.extract(v4f, static_position=[2]) + f3 = vector.extract(v4f, static_position=[3]) + zi = arith.constant(0, type=_i32t_a) + lo = _rocdl_s1a.cvt_pk_fp8_f32(res=_i32t_a, src_a=f0, src_b=f1, + old=zi, word_sel=False) + fp8_i32 = _rocdl_s1a.cvt_pk_fp8_f32(res=_i32t_a, src_a=f2, src_b=f3, + old=lo, word_sel=True) + buffer_store(fp8_i32, rsrc_dst, elem_off) + else: + # Same-dtype path: 4-i32 vector copy. ``chunk_idx`` is + # the 16-byte-chunk index this lane is currently + # copying; ``chunk_i32_off`` translates it to i32 elems. 
+ if const_expr(dual_end_aligned >= 128): + for chunk_idx in range(lane, dual_end_aligned, 128): + chunk_i32_off = chunk_idx * 4 + chunk_i32_off_alt = (chunk_idx + 64) * 4 + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + vec_b = buffer_load(rsrc_src, chunk_i32_off_alt, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + buffer_store(vec_b, rsrc_dst, chunk_i32_off_alt) + if const_expr(dual_end_aligned < n_chunks): + for chunk_idx in range(lane + dual_end_aligned, n_chunks, 64): + chunk_i32_off = chunk_idx * 4 + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + + if const_expr(enable_weights): + for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): + wt_src_addr = _lv_unwrap(addr_wts_buf) + _to_i64(recv_tok_id) * weight_bytes + wt_dst_addr = _lv_unwrap(addr_comb_inp_wts) + _to_i64(recv_tok_id) * weight_bytes + rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) + rsrc_wt_dst = create_buffer_resource_from_addr(wt_dst_addr) + if lane < wt_n_i32: + wt_val = buffer_load(rsrc_wt_src, lane, vec_width=1, dtype=T.i32()) + buffer_store(wt_val, rsrc_wt_dst, lane) + + else: + # Stage 1 default mode: P2P-write each received token to the + # destination PE's ``shmem_comb_inp`` at slot (rank, dest_lid). + dual_end_aligned = (n_chunks // 128) * 128 + for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): + dest_tok_enc = buffer_load(_r_tis, recv_tok_id, vec_width=1, dtype=T.i32()) + if const_expr(_log2_max_tok is not None): + dest_pe = dest_tok_enc >> _log2_max_tok + dest_lid = dest_tok_enc & _mask_max_tok + else: + dest_pe = (dest_tok_enc // max_tok_per_rank) + dest_lid = (dest_tok_enc % max_tok_per_rank) + peer_base = SmemPtr.load(_lds_p2p_bases, [dest_pe]) + # Dest stride uses ``nbytes`` (staging dtype, fp8 in mixed mode). 
+ dest_off = _to_i64(rank * max_tok_per_rank + dest_lid) * nbytes + dest_tok_addr = _lv_unwrap(peer_base) + dest_off + # Src stride uses ``inp_nbytes`` (input dtype, bf16 in mixed mode). + src_tok_addr = addr_inp_tok + _to_i64(recv_tok_id) * inp_nbytes + rsrc_src = create_buffer_resource_from_addr(src_tok_addr) + rsrc_dst = create_buffer_resource_from_addr(dest_tok_addr) + if const_expr(_xfer_bf16_to_fp8): + # Mixed-dtype Stage 1: load 2 bf16 i32 (=4 bf16 elems) → + # ExtF v4f32 → cvt_pk_fp8_f32 ×2 → store 1 fp8 i32 (=4 + # fp8 elems). Loop unit is 1 fp8-i32 per lane per step. + from flydsl._mlir.dialects import rocdl as _rocdl_s1b + _v4bf16_b = T.VectorType.get([4], T.bf16()) + _v4f32_b = T.VectorType.get([4], T.f32()) + _i32t_b = T.i32() + for elem_off in range(lane, n_i32, 64): + bf_pair = buffer_load(rsrc_src, elem_off * 2, + vec_width=2, dtype=T.i32()) + v4f = vector.bitcast(_v4bf16_b, bf_pair).extf(_v4f32_b) + f0 = vector.extract(v4f, static_position=[0]) + f1 = vector.extract(v4f, static_position=[1]) + f2 = vector.extract(v4f, static_position=[2]) + f3 = vector.extract(v4f, static_position=[3]) + zi = arith.constant(0, type=_i32t_b) + lo = _rocdl_s1b.cvt_pk_fp8_f32(res=_i32t_b, src_a=f0, src_b=f1, + old=zi, word_sel=False) + fp8_i32 = _rocdl_s1b.cvt_pk_fp8_f32(res=_i32t_b, src_a=f2, src_b=f3, + old=lo, word_sel=True) + buffer_store(fp8_i32, rsrc_dst, elem_off) + else: + if const_expr(dual_end_aligned >= 128): + for chunk_idx in range(lane, dual_end_aligned, 128): + chunk_i32_off = chunk_idx * 4 + chunk_i32_off_alt = (chunk_idx + 64) * 4 + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + vec_b = buffer_load(rsrc_src, chunk_i32_off_alt, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + buffer_store(vec_b, rsrc_dst, chunk_i32_off_alt) + if const_expr(dual_end_aligned < n_chunks): + for chunk_idx in range(lane + dual_end_aligned, n_chunks, 64): + chunk_i32_off = chunk_idx * 4 + vec_a = buffer_load(rsrc_src, 
chunk_i32_off, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + + if const_expr(enable_weights): + wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [dest_pe]) + wt_dest_off = _to_i64( + rank * max_tok_per_rank + dest_lid) * weight_bytes + wt_dest_addr = _lv_unwrap(wt_pe_base) + wt_dest_off + wt_src_addr = _lv_unwrap(addr_wts_buf) + _to_i64(recv_tok_id) * weight_bytes + rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) + rsrc_wt_dst = create_buffer_resource_from_addr(wt_dest_addr) + if lane < wt_n_i32: + wt_val = buffer_load(rsrc_wt_src, lane, vec_width=1, dtype=T.i32()) + buffer_store(wt_val, rsrc_wt_dst, lane) + + # Stage 2: CrossDeviceBarrier. + # Every rank publishes ``xdb_cur_flag`` into every peer's + # ``xdev_bar_mem[rank]`` slot, then waits until every peer's + # corresponding slot in our local xdev_bar_mem hits the same flag. + fx.barrier() + if tid == 0: + atomic_add_global_at(addr_comb_bar, arith.constant(1)) + + if grid_thread_id < npes: + mori_shmem.int32_wait_until_equals(addr_comb_bar, block_num) + buffer_store(arith.constant(0), _r_comb_bar, 0) + xdb_remote_addr = buffer_load(_r_p2p_xdb, grid_thread_id, vec_width=1, dtype=T.i64()) + \ + arith.constant(rank, type=T.i64()) * 8 + store_i64_global_system(xdb_remote_addr, xdb_cur_flag) + + if grid_thread_id == 0: + atomic_add_global_at(addr_xdb_flag, arith.constant(1, type=T.i64())) + + if tid < npes: + xdb_peer_slot = addr_xdb_mem + _to_i64(tid) * 8 + mori_shmem.uint64_wait_until_equals(xdb_peer_slot, xdb_cur_flag) + + fx.barrier() + if tid == 0: + buffer_store(arith.constant(0), _r_trecv, 0) + + # Stage 3: local read + WarpAccum. + # Each output token's hidden dimension is split into ``warps_per_tok`` + # partitions; each warp handles one partition (size ``hdim_per_warp``) + # of one output token. 
Inside the partition, the warp reads the k + # per-expert partials from ``shmem_comb_inp``, accumulates them in + # high-precision (f32) and writes back the merged token to + # ``shmem_comb_out``. + SLC_CACHE = 2 # buffer_load/store ``cache_modifier=SLC`` (system-coherent) + rsrc_out = create_buffer_resource_from_addr(addr_comb_out) + + n_elems = n_i32 + # When ``cur_rank_num_token == 0`` the division below would divide by + # zero; clamp the denominator to 1 (loop won't execute anyway). + safe_token_count = arith.select( + cur_rank_num_token == 0, 1, cur_rank_num_token) + warps_per_tok = (global_warp_num + safe_token_count - 1) // safe_token_count + hdim_per_warp = (n_elems + warps_per_tok - 1) // warps_per_tok + s3_total_work = cur_rank_num_token * warps_per_tok + + for s3_work_idx in range(global_warp_id, s3_total_work, global_warp_num): + tok_id = (s3_work_idx // warps_per_tok) + part_id = (s3_work_idx % warps_per_tok) + hdim_off = part_id * hdim_per_warp + + expert_rsrcs = [] + expert_vlds = [] + + if const_expr(skip_stage1): + # Fused-upstream Stage 3: when ``skip_stage1`` is set the + # caller has plain-stored a per-(tok_id, k_slot) partial into + # ``shmem_comb_inp[(tok_id*k + k_slot) * token_bytes]``. Each + # k_slot is unique; there is no tok_map to decode -- the + # accumulator simply reads ``shmem_comb_inp`` for k_slot in + # [0, k). Unrouted (tok_id, k_slot) slots are zero-initialized + # by the caller and therefore contribute zero to the sum. 
+ for k_slot in range_constexpr(experts_per_token): + slot_idx = tok_id * experts_per_token + k_slot + expert_tok_off = _to_i64(slot_idx) * nbytes + expert_tok_addr = _lv_unwrap(addr_comb_inp + expert_tok_off) + expert_rsrcs.append(create_buffer_resource_from_addr(expert_tok_addr)) + expert_vlds.append(arith.constant(1, type=T.bool())) + eff_all_vld = True + else: + # Baseline Stage 3: decode (peer_pe, dest_lid) from + # ``dest_tok_map[tok_id, 0..k)`` and read the per-(peer_pe, + # dest_lid) slot of ``shmem_comb_inp``. Stage 1 has P2P- + # scattered each (src_pe, src_lid) contribution into that + # slot. Two 4-i32 loads cover the 8 k-slots in one round. + tm_base_off = tok_id * experts_per_token + tm_vec_lo = buffer_load(_rsrc_tok_map, tm_base_off, vec_width=4, dtype=T.i32()) + tm_vec_hi = buffer_load(_rsrc_tok_map, tm_base_off + 4, vec_width=4, dtype=T.i32()) + + for k_slot in range_constexpr(experts_per_token): + if const_expr(k_slot < 4): + enc_k = vector.extract(tm_vec_lo, static_position=[k_slot]) + else: + enc_k = vector.extract(tm_vec_hi, static_position=[k_slot - 4]) + if const_expr(_log2_max_recv is not None): + dest_pe_k = enc_k >> _log2_max_recv + else: + dest_pe_k = (enc_k // max_recv) + vld_k = dest_pe_k < npes # sentinel = npes + safe_pe = arith.select(vld_k, dest_pe_k, rank) + if const_expr(use_p2p_read): + dtok_global = (enc_k % max_recv) + safe_dtok = arith.select(vld_k, dtok_global, 0) + peer_base = SmemPtr.load(_lds_p2p_bases, [safe_pe]) + expert_tok_off = _to_i64(safe_dtok) * nbytes + expert_tok_addr = _lv_unwrap(peer_base) + expert_tok_off + else: + expert_tok_off = _to_i64(safe_pe * max_tok_per_rank + tok_id) * nbytes + expert_tok_addr = _lv_unwrap(addr_comb_inp + expert_tok_off) + expert_rsrcs.append(create_buffer_resource_from_addr(expert_tok_addr)) + expert_vlds.append(vld_k) + + all_vld = (npes >= experts_per_token) # without compaction, every k_slot must be valid + eff_all_vld = all_vld or _use_compaction + + # Two paths optimised for the 
per-warp partition size: + # - wide path (hdim_per_warp > 895): step=128 dual or step=256 + # quad unrolled loads, each step covers 256/512/... bytes. + # - narrow path (hdim_per_warp <= 895): plain step=64 loop. + if 895 < hdim_per_warp: + rem_hdim_128 = n_elems - hdim_off + # Effective end of THIS warp's partition, clamped to n_elems. + eff_end_128 = arith.select( + rem_hdim_128 < hdim_per_warp, rem_hdim_128, hdim_per_warp) + + if const_expr(n_i32 % 256 == 0 and warp_num_per_block < 16): + if (hdim_per_warp % 256) < 1: + # Quad-unroll: 4 sub-stores per step (offset 0/256/512/768 B). + quad_end = eff_end_128 - 192 + for ec in range(lane, quad_end, 256): + ec_abs = hdim_off + ec + vals_a, vals_b, vals_c, vals_d = [], [], [], [] + for k_slot in range_constexpr(experts_per_token): + rsrc_k = expert_rsrcs[k_slot] + vld_k = expert_vlds[k_slot] + vals_a.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE)) + vals_b.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE, soffset_bytes=256)) + vals_c.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE, soffset_bytes=512)) + vals_d.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE, soffset_bytes=768)) + acc_a = _accum_experts(vals_a, expert_vlds, eff_all_vld) + acc_b = _accum_experts(vals_b, expert_vlds, eff_all_vld) + acc_c = _accum_experts(vals_c, expert_vlds, eff_all_vld) + acc_d = _accum_experts(vals_d, expert_vlds, eff_all_vld) + if const_expr(_xfer_bf16_to_fp8): + # bf16 output: data is v2i32 (8 B / lane); the + # i32 offset doubles per token and the 4 sub- + # stores use 256->512 byte spacing. 
+ out_off = tok_id * out_n_i32 + ec_abs * 2 + buffer_store(acc_a, rsrc_out, out_off, cache_modifier=SLC_CACHE) + buffer_store(acc_b, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=512) + buffer_store(acc_c, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=1024) + buffer_store(acc_d, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=1536) + else: + out_off = tok_id * n_i32 + ec_abs + buffer_store(acc_a, rsrc_out, out_off, cache_modifier=SLC_CACHE) + buffer_store(acc_b, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=256) + buffer_store(acc_c, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=512) + buffer_store(acc_d, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=768) + else: + # Dual-unroll body + 1-wide tail. + s3_dual_end = (eff_end_128 // 128) * 128 + for ec in range(lane, s3_dual_end, 128): + ec_abs = hdim_off + ec + vals_a, vals_b = [], [] + for k_slot in range_constexpr(experts_per_token): + rsrc_k = expert_rsrcs[k_slot] + vld_k = expert_vlds[k_slot] + vals_a.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE)) + vals_b.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE, soffset_bytes=256)) + acc_a = _accum_experts(vals_a, expert_vlds, eff_all_vld) + acc_b = _accum_experts(vals_b, expert_vlds, eff_all_vld) + if const_expr(_xfer_bf16_to_fp8): + out_off = tok_id * out_n_i32 + ec_abs * 2 + buffer_store(acc_a, rsrc_out, out_off, cache_modifier=SLC_CACHE) + buffer_store(acc_b, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=512) + else: + out_off = tok_id * n_i32 + ec_abs + buffer_store(acc_a, rsrc_out, out_off, cache_modifier=SLC_CACHE) + buffer_store(acc_b, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=256) + for ec in range(lane + s3_dual_end, eff_end_128, 64): + ec_abs = hdim_off + ec + vals_tail = [] + for k_slot in range_constexpr(experts_per_token): + 
vals_tail.append(_maybe_load(expert_rsrcs[k_slot], ec_abs, expert_vlds[k_slot], vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE)) + acc_tail = _accum_experts(vals_tail, expert_vlds, eff_all_vld) + if const_expr(_xfer_bf16_to_fp8): + out_off = tok_id * out_n_i32 + ec_abs * 2 + buffer_store(acc_tail, rsrc_out, out_off, cache_modifier=SLC_CACHE) + else: + out_off = tok_id * n_i32 + ec_abs + buffer_store(acc_tail, rsrc_out, out_off, cache_modifier=SLC_CACHE) + else: + # Narrow path: a single step=64 main loop. + rem_hdim_64 = n_elems - hdim_off + eff_end_64 = arith.select( + rem_hdim_64 < hdim_per_warp, rem_hdim_64, hdim_per_warp) + for ec in range(lane, eff_end_64, 64): + ec_abs = hdim_off + ec + vals_main = [] + for k_slot in range_constexpr(experts_per_token): + vals_main.append(_maybe_load(expert_rsrcs[k_slot], ec_abs, expert_vlds[k_slot], vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE)) + acc = _accum_experts(vals_main, expert_vlds, eff_all_vld) + if const_expr(_xfer_bf16_to_fp8): + out_off = tok_id * out_n_i32 + ec_abs * 2 + buffer_store(acc, rsrc_out, out_off, cache_modifier=SLC_CACHE) + else: + out_off = tok_id * n_i32 + ec_abs + buffer_store(acc, rsrc_out, out_off, cache_modifier=SLC_CACHE) + + # Stage 3b: Weight accumulation. + # Each warp handles one output token; lanes 0..k-1 each pull the + # weight value from one k-expert slot's contribution in + # ``shmem_comb_inp_wts`` (or peer-side via P2P-read), then they + # f32-sum across the k slots and write into ``shmem_comb_out_wts``. 
def make_dispatch_jit(*, rank, npes, experts_per_rank, experts_per_token,
                      hidden_dim, max_tok_per_rank, block_num,
                      warp_num_per_block, data_type,
                      scale_dim=0, scale_type_size=0,
                      enable_std_moe=False):
    """Build the ``@flyc.jit`` launcher for the intranode dispatch kernel.

    The kernel object comes from ``make_dispatch_kernel``, configured with
    the keyword arguments given here plus ``hidden_elem_size`` (the byte
    width of one ``data_type`` element).

    Returns:
        ``dispatch_launch``: forwards its address / count arguments to the
        kernel unchanged and launches it on ``stream`` with a
        ``(block_num, 1, 1)`` grid and ``warp_num_per_block * 64`` threads
        per block.
    """
    # Element width in bytes, read off an empty tensor of the requested dtype.
    elem_bytes = torch.tensor([], dtype=data_type).element_size()
    kernel_config = dict(
        rank=rank,
        npes=npes,
        experts_per_rank=experts_per_rank,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        hidden_elem_size=elem_bytes,
        max_tok_per_rank=max_tok_per_rank,
        block_num=block_num,
        warp_num_per_block=warp_num_per_block,
        scale_dim=scale_dim,
        scale_type_size=scale_type_size,
        enable_std_moe=enable_std_moe,
        data_type=data_type,
    )
    kernel = make_dispatch_kernel(**kernel_config)

    # Closure variables that participate in the JIT cache key.  The launcher
    # closes over them so that two ``@flyc.jit`` invocations with different
    # configs produce distinct cached entries.  Names are kept unchanged.
    _key_rank = rank
    _key_npes = npes
    _key_block_num = block_num
    _key_warp_per_block = warp_num_per_block
    _key_max_tok = max_tok_per_rank
    _key_std_moe = enable_std_moe

    @flyc.jit
    def dispatch_launch(
        addr_inp_tok: fx.Int64,
        addr_idx: fx.Int64,
        addr_wts: fx.Int64,
        addr_out_tok: fx.Int64,
        addr_out_wts: fx.Int64,
        addr_out_idx: fx.Int64,
        addr_tok_off: fx.Int64,
        addr_recv_num: fx.Int64,
        addr_dest_ctr: fx.Int64,
        addr_disp_bar: fx.Int64,
        addr_tok_map: fx.Int64,
        addr_tis: fx.Int64,
        addr_total_rv: fx.Int64,
        addr_p2p_tok_off: fx.Int64,
        addr_p2p_tis: fx.Int64,
        addr_p2p_out_wts: fx.Int64,
        addr_p2p_out_idx: fx.Int64,
        addr_p2p_out_tok: fx.Int64,
        addr_p2p_recv_num: fx.Int64,
        addr_scales: fx.Int64,
        addr_p2p_out_scales: fx.Int64,
        addr_packed_recv_x: fx.Int64,
        addr_packed_recv_count: fx.Int64,
        addr_packed_recv_src_info: fx.Int64,
        addr_disp_tok_map: fx.Int64,
        addr_disp_grid_bar: fx.Int64,
        cur_tok: fx.Int32,
        stream: Stream = Stream(None),
    ):
        # Referencing the key variables is what makes them closure cells of
        # the launcher; the tuple itself is discarded.
        _ = (_key_rank, _key_npes, _key_block_num, _key_warp_per_block,
             _key_max_tok, _key_std_moe)
        kernel(
            addr_inp_tok, addr_idx, addr_wts,
            addr_out_tok, addr_out_wts, addr_out_idx,
            addr_tok_off, addr_recv_num, addr_dest_ctr,
            addr_disp_bar, addr_tok_map, addr_tis,
            addr_total_rv,
            addr_p2p_tok_off, addr_p2p_tis,
            addr_p2p_out_wts, addr_p2p_out_idx,
            addr_p2p_out_tok, addr_p2p_recv_num,
            addr_scales, addr_p2p_out_scales,
            addr_packed_recv_x, addr_packed_recv_count,
            addr_packed_recv_src_info, addr_disp_tok_map,
            addr_disp_grid_bar,
            cur_tok,
        ).launch(
            grid=(block_num, 1, 1),
            block=(warp_num_per_block * 64, 1, 1),
            stream=stream,
        )

    return dispatch_launch
+ _key_rank, _key_npes, _key_block_num = rank, npes, block_num + _key_warp_per_block = warp_num_per_block + _key_max_tok = max_tok_per_rank + _key_weights = enable_weights + _key_std_moe = enable_std_moe + _key_p2p_read = use_p2p_read + _key_skip_s1 = skip_stage1 + _key_inp_dtype = str(inp_data_type) if inp_data_type is not None else "none" + _allocator = kernel._allocator + + @flyc.jit + def combine_launch( + addr_inp_tok: fx.Int64, addr_comb_inp: fx.Int64, + addr_comb_out: fx.Int64, addr_xdb_mem: fx.Int64, + addr_xdb_flag: fx.Int64, addr_tok_map: fx.Int64, + addr_comb_bar: fx.Int64, addr_trecv: fx.Int64, + addr_tis: fx.Int64, + addr_p2p_comb_inp: fx.Int64, addr_p2p_xdb_mem: fx.Int64, + addr_wts_buf: fx.Int64, + addr_comb_inp_wts: fx.Int64, addr_comb_out_wts: fx.Int64, + addr_p2p_comb_inp_wts: fx.Int64, + addr_packed_recv_x: fx.Int64, addr_disp_tok_map: fx.Int64, + addr_disp_out_wts: fx.Int64, + cur_rank_num_token: fx.Int32, + stream: Stream = Stream(None), + ): + _ = (_key_rank, _key_npes, _key_block_num, _key_warp_per_block, + _key_max_tok, _key_weights, _key_std_moe, _key_p2p_read, + _key_skip_s1, _key_inp_dtype) + from flydsl.compiler.kernel_function import CompilationContext + from flydsl._mlir import ir + _allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + _allocator.finalize() + + kernel(addr_inp_tok, addr_comb_inp, addr_comb_out, + addr_xdb_mem, addr_xdb_flag, addr_tok_map, + addr_comb_bar, addr_trecv, addr_tis, + addr_p2p_comb_inp, addr_p2p_xdb_mem, + addr_wts_buf, addr_comb_inp_wts, + addr_comb_out_wts, addr_p2p_comb_inp_wts, + addr_packed_recv_x, addr_disp_tok_map, + addr_disp_out_wts, + cur_rank_num_token).launch( + grid=(block_num, 1, 1), + block=(warp_num_per_block * 64, 1, 1), + stream=stream, + ) + + return combine_launch diff --git a/python/flydsl/expr/arith.py b/python/flydsl/expr/arith.py index c04998c2e..38403f6c4 100644 --- a/python/flydsl/expr/arith.py +++ 
b/python/flydsl/expr/arith.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors -# ruff: noqa: I001 +from .._mlir.dialects.arith import * # noqa: F401,F403 """Arith dialect API — operator overloading + function-level builders. Usage: @@ -12,29 +12,27 @@ r = arith.select(cond, a, b) # ArithValue operator overloading: c + 1, c * 2, c / 4, c % 16 """ - -from .._mlir.dialects.arith import * # noqa: F401,F403 - -# Override star-import cmpi/cmpf to accept Numeric types (Int32, etc.) -from .._mlir.dialects import arith as _mlir_arith from .meta import traced_op from .utils.arith import ( # noqa: F401 ArithValue, - _to_raw, - andi, constant, constant_vector, index, index_cast, int_to_fp, select, - shli, sitofp, trunc_f, - unwrap, + andi, xori, + shli, + unwrap, + _to_raw, ) +# Override star-import cmpi/cmpf to accept Numeric types (Int32, etc.) +from .._mlir.dialects import arith as _mlir_arith # noqa: E402 + @traced_op def cmpi(predicate, lhs, rhs, **kwargs): @@ -64,3 +62,6 @@ def cmpf(predicate, lhs, rhs, **kwargs): An ``i1`` comparison result. 
""" return _mlir_arith.cmpf(predicate, _to_raw(lhs), _to_raw(rhs), **kwargs) + + + diff --git a/python/flydsl/expr/rocdl/__init__.py b/python/flydsl/expr/rocdl/__init__.py index 53815ae3b..70e821c36 100644 --- a/python/flydsl/expr/rocdl/__init__.py +++ b/python/flydsl/expr/rocdl/__init__.py @@ -31,6 +31,8 @@ _ods_cluster_load_async_to_lds_b128 = cluster_load_async_to_lds_b128 _ods_s_wait_asynccnt = s_wait_asynccnt _ods_readfirstlane = readfirstlane +_ods_ballot = ballot +_ods_readlane = readlane _ods_mfma_f32_32x32x8f16 = globals().get("mfma_f32_32x32x8f16", None) _ods_mfma_f32_32x32x8bf16_1k = globals().get("mfma_f32_32x32x8bf16_1k", None) _ods_mfma_f32_32x32x16_f16 = globals().get("mfma_f32_32x32x16_f16", None) @@ -41,8 +43,9 @@ _ods_mfma_i32_16x16x32_i8 = mfma_i32_16x16x32_i8 _ods_mfma_f32_16x16x32_f16 = globals().get("mfma_f32_16x16x32_f16", None) _ods_mfma_f32_16x16x32_bf16 = globals().get("mfma_f32_16x16x32_bf16", None) -_ods_mfma_scale_f32_16x16x128_f8f6f4 = globals().get("mfma_scale_f32_16x16x128_f8f6f4", None) or globals().get( - "mfma_scale_f32_16x16x128_f8f6f4_", None +_ods_mfma_scale_f32_16x16x128_f8f6f4 = ( + globals().get("mfma_scale_f32_16x16x128_f8f6f4", None) + or globals().get("mfma_scale_f32_16x16x128_f8f6f4_", None) ) mask_mfma = 0x008 mask_vmem_rd = 0x020 @@ -518,3 +521,23 @@ def ds_bpermute(res, index, src, **kw): def readfirstlane(res, src, **kw): return _ods_readfirstlane(res=res, src=_to_ir(src), **kw) + + +def ballot(res, pred, **kw): + """Wrap ROCDL ``ballot``: coerce ``pred`` to ``i1`` if needed. + + ``res`` selects the lane-mask width (``i32`` on wave32, ``i64`` on wave64). 
+ """ + from ..._mlir.ir import IntegerType + from ..._mlir.dialects import llvm as _llvm + + pred_v = _to_ir(pred) + i1 = IntegerType.get_signless(1) + if pred_v.type != i1: + pred_v = _llvm.TruncOp(i1, pred_v).result + return _ods_ballot(res=res, pred=pred_v, **kw) + + +def readlane(res, src, lane, **kw): + """Wrap ROCDL ``readlane`` with ``_to_ir`` coercion (Python ``int`` ok for ``lane``).""" + return _ods_readlane(res=res, src0=_to_ir(src), src1=_to_ir(lane), **kw) From 9f9a4b560fc39c51bc623d6fdf7fc19c7cb60939 Mon Sep 17 00:00:00 2001 From: yanboshao Date: Sun, 17 May 2026 14:45:43 +0000 Subject: [PATCH 2/4] style: satisfy pre-checks (black + ruff) Make the PR pass the `Check Python Code Style` CI step (.github/workflows/ pre-checks.yaml), which runs ``black --check --diff`` and ``ruff check`` on the set of Python files changed by the PR. Auto-fixes (ruff --fix): I001 (5 unsorted-imports), F401 (3 unused-imports), F811 (1 redefined-while-unused), W293 (1 blank-line-with-whitespace). Manual fixes: - F841 (7 unused-variable): drop dead assignments to ``tok_stride`` / ``inp_n_i32`` in dispatch_combine_intranode_kernel.py, and four ``hdim`` + one ``esz`` in dispatch_combine_intranode_op.py. - E702 (23 multiple-statements-on-one-line): split ``a; b; c`` boilerplate in tests/kernels/test_profiler_dispatch_combine.py (mostly ``dist.all_reduce`` aggregation patterns). - E402 (2 module-import-not-at-top): add ``# noqa: E402`` to the two imports that intentionally follow ``sys.path.insert(0, _p)`` in the test script. Formatting: run ``black`` (line-length=120, per pyproject.toml) on the four PR-modified Python files. ast_rewriter.py was already compliant. CI parity locally: ``black --check`` + ``ruff check`` both clean on all PR files. Verified end-to-end (8x GPU, gfx942, bf16) after the style sweep: - torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py --mode verify -> ALL PASS (diff=0 on dispatch + combine). 
- torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py --mode verify --enable-std-moe -> ALL PASS (max_diff=0.015625 within StdMoE weighted tolerance). --- kernels/dispatch_combine_intranode_kernel.py | 910 ++++++----- kernels/dispatch_combine_intranode_op.py | 649 ++++++++ python/flydsl/expr/rocdl/__init__.py | 7 +- .../kernels/test_profiler_dispatch_combine.py | 1442 +++++++++++++++++ 4 files changed, 2609 insertions(+), 399 deletions(-) create mode 100644 kernels/dispatch_combine_intranode_op.py create mode 100644 tests/kernels/test_profiler_dispatch_combine.py diff --git a/kernels/dispatch_combine_intranode_kernel.py b/kernels/dispatch_combine_intranode_kernel.py index d17b8330e..087aa6f8f 100644 --- a/kernels/dispatch_combine_intranode_kernel.py +++ b/kernels/dispatch_combine_intranode_kernel.py @@ -15,13 +15,21 @@ from __future__ import annotations -import flydsl.compiler as flyc -import flydsl.expr as fx +import mori.ir.flydsl as mori_shmem import torch -import mori.ir.flydsl as mori_shmem +import flydsl.compiler as flyc +import flydsl.expr as fx -from flydsl.expr import T, arith, const_expr, range_constexpr +# Low-level MLIR escape hatches: system-scope atomics, pointer casts, and +# raw LLVM ops used inside the SIMD micro-kernel helpers +# (``_accum_experts`` / ``_weighted_accum_experts``). 
+from flydsl._mlir import ir +from flydsl._mlir import ir as _ir +from flydsl._mlir.dialects import llvm as _llvm_d +from flydsl._mlir.ir import IntegerAttr as _IntAttr +from flydsl._mlir.ir import IntegerType as _IntTy +from flydsl.expr import T, arith, const_expr, range_constexpr, vector from flydsl.expr.buffer_ops import ( buffer_load, buffer_store, @@ -29,17 +37,8 @@ ) from flydsl.expr.rocdl import ballot, readlane from flydsl.expr.typing import Stream -from flydsl.expr import vector from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr -# Low-level MLIR escape hatches: system-scope atomics, pointer casts, and -# raw LLVM ops used inside the SIMD micro-kernel helpers -# (``_accum_experts`` / ``_weighted_accum_experts``). -from flydsl._mlir import ir -from flydsl._mlir import ir as _ir -from flydsl._mlir.dialects import llvm as _llvm_d -from flydsl._mlir.ir import IntegerAttr as _IntAttr, IntegerType as _IntTy - def _lv_unwrap(v): """Return the underlying ``ir.Value`` for *v*. 
@@ -104,35 +103,28 @@ def _vec_to_i32_bitcast(vec_val): def _to_ptr_global(v): """Cast an i64 address to ``!llvm.ptr<1>`` (global address space).""" - return _llvm_d.IntToPtrOp( - _llvm_d.PointerType.get(address_space=1), _lv_unwrap(v)).result + return _llvm_d.IntToPtrOp(_llvm_d.PointerType.get(address_space=1), _lv_unwrap(v)).result def store_i32_system(addr_i64, offset, val): """System-scope monotonic i32 store at ``addr_i64 + offset*4`` bytes.""" base = _lv_unwrap(addr_i64) - off = _lv_unwrap(offset) + off = _lv_unwrap(offset) val_ = _lv_unwrap(val) _i64 = _IntTy.get_signless(64) _i32 = _IntTy.get_signless(32) _nuw = _ir.Attribute.parse("#llvm.overflow") off64 = _llvm_d.ZExtOp(_i64, off).res if off.type == _i32 else off - byte_off = _llvm_d.MulOp( - off64, _llvm_d.ConstantOp(_i64, _IntAttr.get(_i64, 4)).result, _nuw).result + byte_off = _llvm_d.MulOp(off64, _llvm_d.ConstantOp(_i64, _IntAttr.get(_i64, 4)).result, _nuw).result addr = _llvm_d.AddOp(base, byte_off, _nuw).result - gptr = _llvm_d.IntToPtrOp( - _llvm_d.PointerType.get(address_space=1), addr).result - _llvm_d.StoreOp(val_, gptr, alignment=4, - ordering=_llvm_d.AtomicOrdering.monotonic, - syncscope="one-as") + gptr = _llvm_d.IntToPtrOp(_llvm_d.PointerType.get(address_space=1), addr).result + _llvm_d.StoreOp(val_, gptr, alignment=4, ordering=_llvm_d.AtomicOrdering.monotonic, syncscope="one-as") def store_i64_global_system(addr_i64, val): """System-scope monotonic i64 store to ``addr_i64``.""" gptr = _to_ptr_global(addr_i64) - _llvm_d.StoreOp(_lv_unwrap(val), gptr, alignment=8, - ordering=_llvm_d.AtomicOrdering.monotonic, - syncscope="one-as") + _llvm_d.StoreOp(_lv_unwrap(val), gptr, alignment=8, ordering=_llvm_d.AtomicOrdering.monotonic, syncscope="one-as") def load_i64_global(addr_i64): @@ -145,9 +137,7 @@ def load_i64_global(addr_i64): def atomic_add_global_at(addr_i64, val): """Monotonic global ``atomic fetch-and-add``; returns the old value.""" ptr = _to_ptr_global(addr_i64) - return 
_llvm_d.AtomicRMWOp( - _llvm_d.AtomicBinOp.add, ptr, _lv_unwrap(val), - _llvm_d.AtomicOrdering.monotonic).res + return _llvm_d.AtomicRMWOp(_llvm_d.AtomicBinOp.add, ptr, _lv_unwrap(val), _llvm_d.AtomicOrdering.monotonic).res # NOTE: explicit ``index``/``i32`` casts that used to wrap every for-loop bound @@ -190,70 +180,70 @@ def make_dispatch_kernel( (per-expert packing for the std-MoE expert path). """ max_recv = npes * max_tok_per_rank - _is_fp4 = (data_type == torch.float4_e2m1fn_x2) + _is_fp4 = data_type == torch.float4_e2m1fn_x2 if _is_fp4: - n_i32 = hidden_dim // 8 # 8 fp4 values per i32 (4 bytes) - nbytes = hidden_dim // 2 # 2 fp4 values per byte + n_i32 = hidden_dim // 8 # 8 fp4 values per i32 (4 bytes) + nbytes = hidden_dim // 2 # 2 fp4 values per byte else: - n_i32 = (hidden_dim * hidden_elem_size) // 4 + n_i32 = (hidden_dim * hidden_elem_size) // 4 nbytes = hidden_dim * hidden_elem_size - scale_bytes = scale_dim * scale_type_size - scale_n_i32 = (scale_bytes + 3) // 4 if scale_bytes > 0 else 0 - enable_scales = scale_bytes > 0 + scale_bytes = scale_dim * scale_type_size + scale_n_i32 = (scale_bytes + 3) // 4 if scale_bytes > 0 else 0 + enable_scales = scale_bytes > 0 max_tokens_per_expert = npes * max_tok_per_rank # per-expert bucket capacity @flyc.kernel def ep_dispatch_intranode( - addr_inp_tok: fx.Int64, # [cur_tok, hidden_dim] bf16 - addr_idx: fx.Int64, # [cur_tok, k] i32 (token_indices) - addr_wts: fx.Int64, # [cur_tok, k] f32 (weights_buf) - addr_out_tok: fx.Int64, # shmem_out_tok - addr_out_wts: fx.Int64, # shmem_out_wts - addr_out_idx: fx.Int64, # shmem_out_idx - addr_tok_off: fx.Int64, # shmem_tok_off (i32[1]) + addr_inp_tok: fx.Int64, # [cur_tok, hidden_dim] bf16 + addr_idx: fx.Int64, # [cur_tok, k] i32 (token_indices) + addr_wts: fx.Int64, # [cur_tok, k] f32 (weights_buf) + addr_out_tok: fx.Int64, # shmem_out_tok + addr_out_wts: fx.Int64, # shmem_out_wts + addr_out_idx: fx.Int64, # shmem_out_idx + addr_tok_off: fx.Int64, # shmem_tok_off 
(i32[1]) addr_recv_num: fx.Int64, # recv_tok_num (i32[npes]) addr_dest_ctr: fx.Int64, # dest_pe_ctr (i32[npes]) addr_disp_bar: fx.Int64, # dispatch_bar (i32[1]) - addr_tok_map: fx.Int64, # dest_tok_map (i32[cur_tok*k]) - addr_tis: fx.Int64, # tok_id_to_src (i32[max_recv]) + addr_tok_map: fx.Int64, # dest_tok_map (i32[cur_tok*k]) + addr_tis: fx.Int64, # tok_id_to_src (i32[max_recv]) addr_total_rv: fx.Int64, # total_recv (i32[1]) # Pre-resolved P2P address arrays (i64[npes]): the remote-side base # of each symmetric shmem buffer on every peer PE. - addr_p2p_tok_off: fx.Int64, - addr_p2p_tis: fx.Int64, - addr_p2p_out_wts: fx.Int64, - addr_p2p_out_idx: fx.Int64, - addr_p2p_out_tok: fx.Int64, + addr_p2p_tok_off: fx.Int64, + addr_p2p_tis: fx.Int64, + addr_p2p_out_wts: fx.Int64, + addr_p2p_out_idx: fx.Int64, + addr_p2p_out_tok: fx.Int64, addr_p2p_recv_num: fx.Int64, - addr_scales: fx.Int64, # input scales buffer + addr_scales: fx.Int64, # input scales buffer addr_p2p_out_scales: fx.Int64, # i64[npes] P2P addresses of scales buffer # ── StdMoE ConvertDispatchOutput parameters ── - addr_packed_recv_x: fx.Int64, # expert-major token buffer - addr_packed_recv_count: fx.Int64, # per-expert token count (i32[experts_per_rank]) + addr_packed_recv_x: fx.Int64, # expert-major token buffer + addr_packed_recv_count: fx.Int64, # per-expert token count (i32[experts_per_rank]) addr_packed_recv_src_info: fx.Int64, # source info (i32[experts_per_rank * max_tok_per_expert]) - addr_disp_tok_map: fx.Int64, # slot mapping (i64[max_recv * top_k]) - addr_disp_grid_bar: fx.Int64, # grid barrier (i32[1]) - cur_tok: fx.Int32, # runtime token count for the current batch + addr_disp_tok_map: fx.Int64, # slot mapping (i64[max_recv * top_k]) + addr_disp_grid_bar: fx.Int64, # grid barrier (i32[1]) + cur_tok: fx.Int32, # runtime token count for the current batch ): - tid = fx.thread_idx.x # thread id within the block - bid = fx.block_idx.x # block id within the grid - lane = tid & 63 # lane id within 
the warp (0..63) - warp = tid >> 6 # warp id within the block - global_warp_id = bid * warp_num_per_block + warp # warp id across the grid - global_warp_num = block_num * warp_num_per_block # total warps in the grid - work_limit = cur_tok * experts_per_token # total (token, k-slot) pairs - _r_idx = create_buffer_resource_from_addr(addr_idx) - _r_wts = create_buffer_resource_from_addr(addr_wts) + tid = fx.thread_idx.x # thread id within the block + bid = fx.block_idx.x # block id within the grid + lane = tid & 63 # lane id within the warp (0..63) + warp = tid >> 6 # warp id within the block + global_warp_id = bid * warp_num_per_block + warp # warp id across the grid + global_warp_num = block_num * warp_num_per_block # total warps in the grid + work_limit = cur_tok * experts_per_token # total (token, k-slot) pairs + _r_idx = create_buffer_resource_from_addr(addr_idx) + _r_wts = create_buffer_resource_from_addr(addr_wts) _r_tok_map = create_buffer_resource_from_addr(addr_tok_map) _r_tok_off = create_buffer_resource_from_addr(addr_tok_off) _r_dest_ctr = create_buffer_resource_from_addr(addr_dest_ctr) _r_disp_bar = create_buffer_resource_from_addr(addr_disp_bar) _r_total_rv = create_buffer_resource_from_addr(addr_total_rv) - _r_p2p_tok_off = create_buffer_resource_from_addr(addr_p2p_tok_off) - _r_p2p_tis = create_buffer_resource_from_addr(addr_p2p_tis) - _r_p2p_out_wts = create_buffer_resource_from_addr(addr_p2p_out_wts) - _r_p2p_out_idx = create_buffer_resource_from_addr(addr_p2p_out_idx) - _r_p2p_out_tok = create_buffer_resource_from_addr(addr_p2p_out_tok) + _r_p2p_tok_off = create_buffer_resource_from_addr(addr_p2p_tok_off) + _r_p2p_tis = create_buffer_resource_from_addr(addr_p2p_tis) + _r_p2p_out_wts = create_buffer_resource_from_addr(addr_p2p_out_wts) + _r_p2p_out_idx = create_buffer_resource_from_addr(addr_p2p_out_idx) + _r_p2p_out_tok = create_buffer_resource_from_addr(addr_p2p_out_tok) _r_p2p_recv_num = create_buffer_resource_from_addr(addr_p2p_recv_num) # Phase 
1: P2P-scatter tokens to their destination PEs. @@ -261,41 +251,35 @@ def ep_dispatch_intranode( # all grid-wide warps. ``k_slot`` is the per-token expert slot index # (i.e. which of the top-k experts this work-item handles). for work_idx in range(global_warp_id, work_limit, global_warp_num): - src_tok = (work_idx // experts_per_token) - k_slot = (work_idx % experts_per_token) + src_tok = work_idx // experts_per_token + k_slot = work_idx % experts_per_token # Issue the two idx loads in parallel; divui is deferred so the # loads do not block on the integer divide. - dest_expert = buffer_load(_r_idx, work_idx, vec_width=1, dtype=T.i32()) - safe_lane = arith.select(lane < k_slot, lane, 0) - lane_expert = buffer_load(_r_idx, src_tok * experts_per_token + safe_lane, vec_width=1, dtype=T.i32()) - dest_pe = (dest_expert // experts_per_rank) - lane_dest_pe = (lane_expert // experts_per_rank) + dest_expert = buffer_load(_r_idx, work_idx, vec_width=1, dtype=T.i32()) + safe_lane = arith.select(lane < k_slot, lane, 0) + lane_expert = buffer_load(_r_idx, src_tok * experts_per_token + safe_lane, vec_width=1, dtype=T.i32()) + dest_pe = dest_expert // experts_per_rank + lane_dest_pe = lane_expert // experts_per_rank # Per-lane "is this lane a duplicate destPE assignment for some # k_slot earlier than the current one?" (sentinel 64 = no). - dup_per_lane = arith.select( - lane_dest_pe == dest_pe, - arith.select(lane < k_slot, lane, 64), - 64) - dup_ballot = ballot(T.i64(), dup_per_lane < 64) - is_dup = dup_ballot != 0 + dup_per_lane = arith.select(lane_dest_pe == dest_pe, arith.select(lane < k_slot, lane, 64), 64) + dup_ballot = ballot(T.i64(), dup_per_lane < 64) + is_dup = dup_ballot != 0 # Atomically allocate dest_tok_id on lane 0, then readlane-broadcast. 
dest_tok_lane0 = arith.constant(0) if lane == 0: if dup_ballot == 0: dest_tok_lane0 = atomic_add_global_at( - buffer_load(_r_p2p_tok_off, dest_pe, vec_width=1, dtype=T.i64()), - arith.constant(1)) + buffer_load(_r_p2p_tok_off, dest_pe, vec_width=1, dtype=T.i64()), arith.constant(1) + ) dest_tok_id = readlane(T.i32(), dest_tok_lane0, 0) # Per-(token, k_slot) entry stored into dest_tok_map: encoded # global slot id, or sentinel ``npes * max_recv`` for dup-slots # which the combine kernel will treat as "no source". sentinel_val = npes * max_recv - tok_map_entry = arith.select( - is_dup, - sentinel_val, - dest_pe * max_recv + dest_tok_id) + tok_map_entry = arith.select(is_dup, sentinel_val, dest_pe * max_recv + dest_tok_id) if lane == 0: buffer_store(tok_map_entry, _r_tok_map, work_idx) @@ -305,7 +289,8 @@ def ep_dispatch_intranode( # can later route the token back during combine. src_tok_enc = rank * max_tok_per_rank + src_tok _r_tis_remote = create_buffer_resource_from_addr( - buffer_load(_r_p2p_tis, dest_pe, vec_width=1, dtype=T.i64())) + buffer_load(_r_p2p_tis, dest_pe, vec_width=1, dtype=T.i64()) + ) buffer_store(src_tok_enc, _r_tis_remote, dest_tok_id) dest_ctr_addr = addr_dest_ctr + _to_i64(dest_pe) * 4 atomic_add_global_at(dest_ctr_addr, arith.constant(1)) @@ -314,15 +299,17 @@ def ep_dispatch_intranode( # PE's symmetric weights / idx buffers, parallel over k_slot. 
if lane < experts_per_token: if dup_ballot == 0: - wt_src_off = src_tok * experts_per_token + lane - wt_val = buffer_load(_r_wts, wt_src_off, vec_width=1, dtype=T.f32()) - idx_val = buffer_load(_r_idx, wt_src_off, vec_width=1, dtype=T.i32()) - dest_slot = dest_tok_id * experts_per_token + lane + wt_src_off = src_tok * experts_per_token + lane + wt_val = buffer_load(_r_wts, wt_src_off, vec_width=1, dtype=T.f32()) + idx_val = buffer_load(_r_idx, wt_src_off, vec_width=1, dtype=T.i32()) + dest_slot = dest_tok_id * experts_per_token + lane _r_wts_remote = create_buffer_resource_from_addr( - buffer_load(_r_p2p_out_wts, dest_pe, vec_width=1, dtype=T.i64())) + buffer_load(_r_p2p_out_wts, dest_pe, vec_width=1, dtype=T.i64()) + ) buffer_store(arith.bitcast(T.i32(), wt_val), _r_wts_remote, dest_slot) _r_idx_remote = create_buffer_resource_from_addr( - buffer_load(_r_p2p_out_idx, dest_pe, vec_width=1, dtype=T.i64())) + buffer_load(_r_p2p_out_idx, dest_pe, vec_width=1, dtype=T.i64()) + ) buffer_store(idx_val, _r_idx_remote, dest_slot) if const_expr(enable_scales): @@ -330,12 +317,16 @@ def ep_dispatch_intranode( if dup_ballot == 0: _r_scales = create_buffer_resource_from_addr(addr_scales) sc_src_off = src_tok * scale_n_i32 + lane - sc_val = buffer_load(_r_scales, sc_src_off, vec_width=1, dtype=T.i32()) + sc_val = buffer_load(_r_scales, sc_src_off, vec_width=1, dtype=T.i32()) sc_dst_off = dest_tok_id * scale_n_i32 + lane _r_sc_remote = create_buffer_resource_from_addr( buffer_load( create_buffer_resource_from_addr(_lv_unwrap(addr_p2p_out_scales)), - dest_pe, vec_width=1, dtype=T.i64())) + dest_pe, + vec_width=1, + dtype=T.i64(), + ) + ) buffer_store(sc_val, _r_sc_remote, sc_dst_off) # Token-embedding scatter: when ``is_dup`` the copy_end equals @@ -345,13 +336,14 @@ def ep_dispatch_intranode( # owns 4 consecutive i32 = 16 bytes). # ``chunk_i32_off`` - sliding i32 offset within the token's # hidden-dim chunk being copied this step. 
- remote_tok_addr = buffer_load(_r_p2p_out_tok, dest_pe, vec_width=1, dtype=T.i64()) + \ - _to_i64(dest_tok_id) * nbytes - local_tok_addr = addr_inp_tok + _to_i64(src_tok) * nbytes - rsrc_src = create_buffer_resource_from_addr(local_tok_addr) - rsrc_dst = create_buffer_resource_from_addr(remote_tok_addr) - lane_i32_off = lane * 4 - safe_end_i32 = (n_i32 // 512) * 512 # largest multiple of 512 that fits + remote_tok_addr = ( + buffer_load(_r_p2p_out_tok, dest_pe, vec_width=1, dtype=T.i64()) + _to_i64(dest_tok_id) * nbytes + ) + local_tok_addr = addr_inp_tok + _to_i64(src_tok) * nbytes + rsrc_src = create_buffer_resource_from_addr(local_tok_addr) + rsrc_dst = create_buffer_resource_from_addr(remote_tok_addr) + lane_i32_off = lane * 4 + safe_end_i32 = (n_i32 // 512) * 512 # largest multiple of 512 that fits if const_expr(n_i32 >= 512 and safe_end_i32 > 0): copy_end_main = arith.select(is_dup, lane_i32_off, safe_end_i32) for chunk_i32_off in range(lane_i32_off, copy_end_main, 512): @@ -384,8 +376,9 @@ def ep_dispatch_intranode( buffer_store(arith.constant(0), _r_disp_bar, 0) # +1 because 0 is the "unset" sentinel that consumers wait on. 
signal_value = buffer_load(_r_dest_ctr, dest_pe, vec_width=1, dtype=T.i32()) + 1 - recv_num_remote_addr = buffer_load( - _r_p2p_recv_num, dest_pe, vec_width=1, dtype=T.i64()) + recv_num_local_byte_off + recv_num_remote_addr = ( + buffer_load(_r_p2p_recv_num, dest_pe, vec_width=1, dtype=T.i64()) + recv_num_local_byte_off + ) mori_shmem.int32_wait_until_equals(recv_num_remote_addr, 0) store_i32_system(recv_num_remote_addr, arith.constant(0), signal_value) @@ -393,8 +386,8 @@ def ep_dispatch_intranode( for src_pe in range(lane, npes, 64): if global_warp_id == 0: recv_num_src_addr = addr_recv_num + _to_i64(src_pe) * 4 - signal_value = mori_shmem.int32_wait_until_greater_than(recv_num_src_addr, 0) - peer_recv_count = signal_value - 1 # undo the +1 sentinel offset + signal_value = mori_shmem.int32_wait_until_greater_than(recv_num_src_addr, 0) + peer_recv_count = signal_value - 1 # undo the +1 sentinel offset store_i32_system(recv_num_src_addr, arith.constant(0), arith.constant(0)) atomic_add_global_at(addr_total_rv, peer_recv_count) buffer_store(arith.constant(0), _r_dest_ctr, src_pe) @@ -417,23 +410,22 @@ def ep_dispatch_intranode( fx.barrier() _r_out_idx_local = create_buffer_resource_from_addr(addr_out_idx) - _r_tis_local = create_buffer_resource_from_addr(addr_tis) + _r_tis_local = create_buffer_resource_from_addr(addr_tis) _r_out_tok_local = create_buffer_resource_from_addr(addr_out_tok) - total_recv = buffer_load(_r_total_rv, 0, vec_width=1, dtype=T.i32()) + total_recv = buffer_load(_r_total_rv, 0, vec_width=1, dtype=T.i32()) smoe_work_limit = total_recv * experts_per_token for smoe_idx in range(global_warp_id, smoe_work_limit, global_warp_num): - smoe_tok_id = (smoe_idx // experts_per_token) + smoe_tok_id = smoe_idx // experts_per_token - expert_id = buffer_load(_r_out_idx_local, smoe_idx, vec_width=1, dtype=T.i32()) + expert_id = buffer_load(_r_out_idx_local, smoe_idx, vec_width=1, dtype=T.i32()) local_expert_id = expert_id - rank * experts_per_rank # MUST be 
unsigned ``ult``: when ``expert_id`` is NOT this # rank's expert, ``local_expert_id`` is negative; the # signed-overload form ``local_expert_id < experts_per_rank`` # lowers to ``arith.cmpi slt`` and would mis-classify negative # values as local (-> illegal global access in WarpCopy). - is_local = arith.cmpi(arith.CmpIPredicate.ult, local_expert_id, - arith.constant(experts_per_rank)) + is_local = arith.cmpi(arith.CmpIPredicate.ult, local_expert_id, arith.constant(experts_per_rank)) # Atomically allocate the per-expert packing slot on lane 0. packed_slot_lane0 = arith.constant(0) @@ -446,26 +438,24 @@ def ep_dispatch_intranode( safe_local_expert = arith.select(is_local, local_expert_id, 0) # Linear slot in the flat ``packed_recv_x[experts_per_rank, max_tokens_per_expert]`` buffer. packed_linear_idx = safe_local_expert * max_tokens_per_expert + packed_slot - slot_val_i64 = arith.select(is_local, - _to_i64(packed_linear_idx), - -1) # false_value materialized as i64 from true_value's type; -1 = not a local expert + slot_val_i64 = arith.select( + is_local, _to_i64(packed_linear_idx), -1 + ) # false_value materialized as i64 from true_value's type; -1 = not a local expert if lane == 0: slot_map_addr = addr_disp_tok_map + _to_i64(smoe_idx) * 8 store_i64_global_system(slot_map_addr, slot_val_i64) if lane == 0: if is_local: - src_pos_enc = buffer_load(_r_tis_local, smoe_tok_id, - vec_width=1, dtype=T.i32()) - store_i32_system(addr_packed_recv_src_info, - packed_linear_idx, src_pos_enc) + src_pos_enc = buffer_load(_r_tis_local, smoe_tok_id, vec_width=1, dtype=T.i32()) + store_i32_system(addr_packed_recv_src_info, packed_linear_idx, src_pos_enc) # WarpCopy token data from shmem_out_tok into the packed # per-expert buffer at slot ``packed_linear_idx``. 
src_tok_base = addr_out_tok + _to_i64(smoe_tok_id) * nbytes dst_tok_base = addr_packed_recv_x + _to_i64(packed_linear_idx) * nbytes - rsrc_src = create_buffer_resource_from_addr(src_tok_base) - rsrc_dst = create_buffer_resource_from_addr(dst_tok_base) + rsrc_src = create_buffer_resource_from_addr(src_tok_base) + rsrc_dst = create_buffer_resource_from_addr(dst_tok_base) lane_i32_off = lane * 4 safe_end_i32 = (n_i32 // 512) * 512 if n_i32 >= 512 and safe_end_i32 > 0: @@ -535,15 +525,14 @@ def make_combine_kernel( bf16 -> fp8 cast inline (``UseFp8DirectCast``-equivalent), and Stage 3 widens addressing strides for bf16 output writes. """ - max_recv = npes * max_tok_per_rank - _is_fp4 = (data_type == torch.float4_e2m1fn_x2) + max_recv = npes * max_tok_per_rank + _is_fp4 = data_type == torch.float4_e2m1fn_x2 if _is_fp4: - n_i32 = hidden_dim // 8 + n_i32 = hidden_dim // 8 nbytes = hidden_dim // 2 else: - n_i32 = (hidden_dim * hidden_elem_size) // 4 + n_i32 = (hidden_dim * hidden_elem_size) // 4 nbytes = hidden_dim * hidden_elem_size - tok_stride = n_i32 * 4 # Mixed-dtype combine: external dtype (kernel input AND output) differs # from the on-wire/staging dtype used for P2P transport. Currently only @@ -557,7 +546,8 @@ def make_combine_kernel( # ``inp_data_type`` is the legacy parameter name; conceptually it now # represents the external (input/output-shared) dtype. 
_xfer_bf16_to_fp8 = ( - inp_data_type is not None and inp_data_type != data_type + inp_data_type is not None + and inp_data_type != data_type and inp_data_type == torch.bfloat16 and data_type == torch.float8_e4m3fn ) @@ -565,28 +555,29 @@ def make_combine_kernel( raise NotImplementedError( f"combine_kernel mixed-dtype only supports " f"inp_data_type=bfloat16 + data_type=float8_e4m3fn, " - f"got inp_data_type={inp_data_type}, data_type={data_type}") + f"got inp_data_type={inp_data_type}, data_type={data_type}" + ) if _xfer_bf16_to_fp8 and enable_std_moe: raise NotImplementedError( "combine_kernel mixed-dtype path does not yet support " "enable_std_moe=True (the std-MoE Stage 1 / Stage 3 use " "_weighted_accum_experts which has not been retrofitted for " - "asymmetric I/O dtypes)") + "asymmetric I/O dtypes)" + ) if _xfer_bf16_to_fp8: # bf16 input stride for Stage 1 source addressing only. The transport # (P2P-scattered staging) uses ``nbytes`` (= fp8 stride) as before. # Stage 3 output addressing also uses bf16 stride (= 2 × fp8 stride). inp_nbytes = hidden_dim * 2 - inp_n_i32 = (hidden_dim * 2) // 4 # bf16-stride i32 count per token for Stage 3 output offsets. - out_n_i32 = (hidden_dim * 2) // 4 + out_n_i32 = (hidden_dim * 2) // 4 else: inp_nbytes = nbytes - inp_n_i32 = n_i32 - out_n_i32 = n_i32 + out_n_i32 = n_i32 if _is_fp4: from flydsl._mlir.dialects import rocdl as _rocdl_d + _v2f32_fp4 = T.VectorType.get([2], T.f32()) _v8f32_fp4 = T.VectorType.get([8], T.f32()) @@ -594,9 +585,7 @@ def _to_accum(i32_val): # ROCDL fp4 lane unpack: i32 (8 packed fp4) -> 4 × vector<2xf32>. scale_one = arith.constant(1.0, type=T.f32()) pairs = [ - _rocdl_d.cvt_scalef32_pk_f32_fp4( - res=_v2f32_fp4, src=i32_val, scale=scale_one, - src_sel_index=sel) + _rocdl_d.cvt_scalef32_pk_f32_fp4(res=_v2f32_fp4, src=i32_val, scale=scale_one, src_sel_index=sel) for sel in range(4) ] # Stitch 4 × v2f32 -> v8f32 via two-stage shuffle. 
@@ -613,35 +602,42 @@ def _from_accum(accum_val): f_a = vector.extract(accum_val, static_position=[sel * 2]) f_b = vector.extract(accum_val, static_position=[sel * 2 + 1]) old = _rocdl_d.cvt_scalef32_pk_fp4_f32( - res=_i32_ty, old_vdst=old, src0=f_a, src1=f_b, - scale=scale_one, dst_sel_index=sel) + res=_i32_ty, old_vdst=old, src0=f_a, src1=f_b, scale=scale_one, dst_sel_index=sel + ) return old def _zero_accum(): return arith.constant_vector(0.0, _v8f32_fp4) elif hidden_elem_size == 2: # bf16 + def _to_accum(i32_val): - return _i32_to_vec_bitcast(T.VectorType.get([2], T.bf16()), i32_val).extf( - T.VectorType.get([2], T.f32())) + return _i32_to_vec_bitcast(T.VectorType.get([2], T.bf16()), i32_val).extf(T.VectorType.get([2], T.f32())) + def _from_accum(accum_val): - return _vec_to_i32_bitcast(accum_val.truncf( - T.VectorType.get([2], T.bf16()))) + return _vec_to_i32_bitcast(accum_val.truncf(T.VectorType.get([2], T.bf16()))) + def _zero_accum(): return arith.constant_vector(0.0, T.VectorType.get([2], T.f32())) + elif hidden_elem_size == 4: # f32 + def _to_accum(i32_val): return arith.bitcast(T.f32(), i32_val) + def _from_accum(accum_val): return arith.bitcast(T.i32(), accum_val) + def _zero_accum(): return arith.constant(0.0, type=T.f32()) + elif hidden_elem_size == 1: # fp8 from flydsl._mlir.dialects import rocdl as _rocdl_d - _is_ocp = (data_type == torch.float8_e4m3fn) - _is_fnuz = (data_type == torch.float8_e4m3fnuz) + + _is_ocp = data_type == torch.float8_e4m3fn + _is_fnuz = data_type == torch.float8_e4m3fnuz _cvt_pk_f32 = _rocdl_d.cvt_pk_f32_fp8 - _cvt_pk_f8 = _rocdl_d.cvt_pk_fp8_f32 + _cvt_pk_f8 = _rocdl_d.cvt_pk_fp8_f32 _v2f32_fp8 = T.VectorType.get([2], T.f32()) _v4f32_fp8 = T.VectorType.get([4], T.f32()) @@ -665,20 +661,19 @@ def _from_accum(accum_val): # via buffer_store(..., vec_width=2, dtype=T.i32()) at an i32 # offset doubled relative to fp8 mode (2 i32 = 4 bf16 = 8 B). 
_v4bf16 = T.VectorType.get([4], T.bf16()) - _v2i32 = T.VectorType.get([2], _i32_ty) + _v2i32 = T.VectorType.get([2], _i32_ty) return vector.bitcast(_v2i32, accum_val.truncf(_v4bf16)) f0 = vector.extract(accum_val, static_position=[0]) f1 = vector.extract(accum_val, static_position=[1]) f2 = vector.extract(accum_val, static_position=[2]) f3 = vector.extract(accum_val, static_position=[3]) zero = arith.constant(0, type=_i32_ty) - lo = _cvt_pk_f8(res=_i32_ty, src_a=f0, src_b=f1, - old=zero, word_sel=False) - return _cvt_pk_f8(res=_i32_ty, src_a=f2, src_b=f3, - old=lo, word_sel=True) + lo = _cvt_pk_f8(res=_i32_ty, src_a=f0, src_b=f1, old=zero, word_sel=False) + return _cvt_pk_f8(res=_i32_ty, src_a=f2, src_b=f3, old=lo, word_sel=True) def _zero_accum(): return arith.constant_vector(0.0, _v4f32_fp8) + else: raise ValueError(f"Unsupported hidden_elem_size={hidden_elem_size}") @@ -701,7 +696,7 @@ def _accum_experts(vals, vlds, all_vld): acc = _zero_accum() for k_slot in range(len(vals)): widened = _to_accum(vals[k_slot]) - zero = _zero_accum() + zero = _zero_accum() vld_raw = _lv_unwrap(vlds[k_slot]) acc = acc + arith.select(vld_raw, widened, zero) return _from_accum(acc) @@ -718,6 +713,7 @@ def _weighted_accum_experts(vals, wts, vlds, all_vld): if _is_fp4: # fp4 → v8f32 accum from flydsl._mlir.dialects import rocdl as _rocdl_fp4 + _v2f32 = T.VectorType.get([2], T.f32()) _v8f32 = T.VectorType.get([8], T.f32()) scale_one = arith.constant(1.0, type=_f32ty) @@ -725,66 +721,62 @@ def _weighted_accum_experts(vals, wts, vlds, all_vld): for j in range(len(vals)): # ROCDL fp4 lane unpack: i32 (8 packed fp4) -> 4 × vector<2xf32>. pairs = [ - _rocdl_fp4.cvt_scalef32_pk_f32_fp4( - res=_v2f32, src=vals[j], scale=scale_one, - src_sel_index=sel) + _rocdl_fp4.cvt_scalef32_pk_f32_fp4(res=_v2f32, src=vals[j], scale=scale_one, src_sel_index=sel) for sel in range(4) ] # Stitch 4 × v2f32 -> v8f32 via two-stage shuffle. 
lo4 = vector.shuffle(pairs[0], pairs[1], [0, 1, 2, 3]) hi4 = vector.shuffle(pairs[2], pairs[3], [0, 1, 2, 3]) vec = vector.shuffle(lo4, hi4, [0, 1, 2, 3, 4, 5, 6, 7]) - w = vec * wts[j] # auto-broadcast scalar to v8f32 + w = vec * wts[j] # auto-broadcast scalar to v8f32 if all_vld: acc = acc + w else: - acc = acc + arith.select( - vlds[j], w, arith.constant_vector(0.0, _v8f32)) + acc = acc + arith.select(vlds[j], w, arith.constant_vector(0.0, _v8f32)) # Re-pack v8f32 -> i32 via 4 × cvt_scalef32_pk_fp4_f32. old = arith.constant(0, type=_i32ty) for sel in range(4): f_a = vector.extract(acc, static_position=[sel * 2]) f_b = vector.extract(acc, static_position=[sel * 2 + 1]) old = _rocdl_fp4.cvt_scalef32_pk_fp4_f32( - res=_i32ty, old_vdst=old, src0=f_a, src1=f_b, - scale=scale_one, dst_sel_index=sel) + res=_i32ty, old_vdst=old, src0=f_a, src1=f_b, scale=scale_one, dst_sel_index=sel + ) return old elif hidden_elem_size == 2: # bf16 → v2f32 accum _v2bf16 = T.VectorType.get([2], T.bf16()) - _v2f32 = T.VectorType.get([2], T.f32()) + _v2f32 = T.VectorType.get([2], T.f32()) acc = arith.constant_vector(0.0, _v2f32) for j in range(len(vals)): # i32 → vector<2xbf16> via from_elements + vector.bitcast # → vector<2xf32> via arith.extf, then broadcast wt and fma. 
vb = _i32_to_vec_bitcast(_v2bf16, vals[j]) vf = vb.extf(_v2f32) - w = vf * wts[j] # wts[j] scalar f32, auto-broadcast to v2f32 + w = vf * wts[j] # wts[j] scalar f32, auto-broadcast to v2f32 if all_vld: acc = acc + w else: - acc = acc + arith.select( - vlds[j], w, arith.constant_vector(0.0, _v2f32)) + acc = acc + arith.select(vlds[j], w, arith.constant_vector(0.0, _v2f32)) return _vec_to_i32_bitcast(acc.truncf(_v2bf16)) elif hidden_elem_size == 4: # f32 → f32 accum acc = arith.constant(0.0, type=_f32ty) for j in range(len(vals)): vf = arith.bitcast(_f32ty, vals[j]) - w = vf * wts[j] + w = vf * wts[j] if all_vld: acc = acc + w else: - acc = acc + arith.select( - vlds[j], w, arith.constant(0.0, type=_f32ty)) + acc = acc + arith.select(vlds[j], w, arith.constant(0.0, type=_f32ty)) return arith.bitcast(_i32ty, acc) elif hidden_elem_size == 1: # fp8 → v4f32 accum from flydsl._mlir.dialects import rocdl as _rocdl + _pk_f32 = _rocdl.cvt_pk_f32_fp8 - _pk_f8 = _rocdl.cvt_pk_fp8_f32 - _v2f32 = T.VectorType.get([2], T.f32()) - _v4f32 = T.VectorType.get([4], T.f32()) + _pk_f8 = _rocdl.cvt_pk_fp8_f32 + _v2f32 = T.VectorType.get([2], T.f32()) + _v4f32 = T.VectorType.get([4], T.f32()) acc = arith.constant_vector(0.0, _v4f32) for j in range(len(vals)): # ROCDL fp8 lane unpack: i32 (4 packed fp8) -> 2 × vector<2xf32>. 
@@ -795,12 +787,11 @@ def _weighted_accum_experts(vals, wts, vlds, all_vld): vec = vector.shuffle(lo, hi, [0, 1, 2, 3]) if _is_fnuz: vec = vec * 0.5 - w = vec * wts[j] # wts[j] scalar f32, auto-broadcast to v4f32 + w = vec * wts[j] # wts[j] scalar f32, auto-broadcast to v4f32 if all_vld: acc = acc + w else: - acc = acc + arith.select( - vlds[j], w, arith.constant_vector(0.0, _v4f32)) + acc = acc + arith.select(vlds[j], w, arith.constant_vector(0.0, _v4f32)) if _is_fnuz: acc = acc * 2.0 f0 = vector.extract(acc, static_position=[0]) @@ -816,13 +807,14 @@ def _log2_if_pow2(v): if v > 0 and (v & (v - 1)) == 0: return v.bit_length() - 1 return None + # Pow2 fast-paths: when ``max_tok_per_rank`` / ``max_recv`` are powers # of two, decode ``dest_pe / dest_lid`` and ``dest_pe / dtok`` via # shift + mask instead of integer divide / mod. - _log2_max_tok = _log2_if_pow2(max_tok_per_rank) + _log2_max_tok = _log2_if_pow2(max_tok_per_rank) _log2_max_recv = _log2_if_pow2(max_recv) - _mask_max_tok = max_tok_per_rank - 1 if _log2_max_tok is not None else None - _mask_max_recv = max_recv - 1 if _log2_max_recv is not None else None + _mask_max_tok = max_tok_per_rank - 1 if _log2_max_tok is not None else None + _mask_max_recv = max_recv - 1 if _log2_max_recv is not None else None # Dispatch deduplicates same-PE assignments at runtime: when more than # one of a token's k experts fall on the same dest_pe, the duplicate @@ -833,7 +825,7 @@ def _log2_if_pow2(v): _use_compaction = True weight_bytes = experts_per_token * 4 if enable_weights else 0 - wt_n_i32 = experts_per_token if enable_weights else 0 + wt_n_i32 = experts_per_token if enable_weights else 0 # LDS layout for the P2P-base tables (i64[npes] for tokens, optionally # i64[npes] for weights). 
``SmemAllocator.finalize()`` is called from the @@ -848,37 +840,36 @@ def _log2_if_pow2(v): p2p_wt_base_size = npes * 8 allocator.ptr = p2p_wt_base_offset + p2p_wt_base_size - @flyc.kernel def ep_combine_intranode( - addr_inp_tok: fx.Int64, # inp_tok base (post-expert token buffer) - addr_comb_inp: fx.Int64, # shmem_comb_inp base (symmetric) - addr_comb_out: fx.Int64, # shmem_comb_out base (symmetric) - addr_xdb_mem: fx.Int64, # xdev_bar_mem (u64[npes]) - addr_xdb_flag: fx.Int64, # xdev_bar_flag (u64[1]) - addr_tok_map: fx.Int64, # dest_tok_map (i32[cur_tok*k]) - addr_comb_bar: fx.Int64, # combine_bar (i32[1]) - addr_trecv: fx.Int64, # total_recv_ptr (i32[1]) - addr_tis: fx.Int64, # tok_id_to_src (i32[max_recv], symmetric) + addr_inp_tok: fx.Int64, # inp_tok base (post-expert token buffer) + addr_comb_inp: fx.Int64, # shmem_comb_inp base (symmetric) + addr_comb_out: fx.Int64, # shmem_comb_out base (symmetric) + addr_xdb_mem: fx.Int64, # xdev_bar_mem (u64[npes]) + addr_xdb_flag: fx.Int64, # xdev_bar_flag (u64[1]) + addr_tok_map: fx.Int64, # dest_tok_map (i32[cur_tok*k]) + addr_comb_bar: fx.Int64, # combine_bar (i32[1]) + addr_trecv: fx.Int64, # total_recv_ptr (i32[1]) + addr_tis: fx.Int64, # tok_id_to_src (i32[max_recv], symmetric) addr_p2p_comb_inp: fx.Int64, # i64[npes] pre-resolved P2P addresses - addr_p2p_xdb_mem: fx.Int64, # i64[npes] pre-resolved P2P addresses - addr_wts_buf: fx.Int64, # combine input weights f32[max_recv*k] + addr_p2p_xdb_mem: fx.Int64, # i64[npes] pre-resolved P2P addresses + addr_wts_buf: fx.Int64, # combine input weights f32[max_recv*k] addr_comb_inp_wts: fx.Int64, # shmem weight P2P buffer (symmetric) addr_comb_out_wts: fx.Int64, # combine output weights f32[max_tok*k] addr_p2p_comb_inp_wts: fx.Int64, # i64[npes] weight P2P addresses # ── StdMoE ConvertCombineInput parameters ── - addr_packed_recv_x: fx.Int64, # expert-major token buffer (post-expert) - addr_disp_tok_map: fx.Int64, # dispTokToEpSlotMap (i64[max_recv * top_k]) - 
addr_disp_out_wts: fx.Int64, # dispatch output weights (f32[max_recv * top_k]) - cur_rank_num_token: fx.Int32, # this PE's output token count (used by Stage 3) + addr_packed_recv_x: fx.Int64, # expert-major token buffer (post-expert) + addr_disp_tok_map: fx.Int64, # dispTokToEpSlotMap (i64[max_recv * top_k]) + addr_disp_out_wts: fx.Int64, # dispatch output weights (f32[max_recv * top_k]) + cur_rank_num_token: fx.Int32, # this PE's output token count (used by Stage 3) ): - tid = fx.thread_idx.x - bid = fx.block_idx.x - lane = tid & 63 - warp = tid >> 6 - global_warp_id = bid * warp_num_per_block + warp # warp id across the grid - global_warp_num = block_num * warp_num_per_block # total warps in the grid - grid_thread_id = bid * (warp_num_per_block * 64) + tid # grid-wide thread id (used by Stage 2 only) + tid = fx.thread_idx.x + bid = fx.block_idx.x + lane = tid & 63 + warp = tid >> 6 + global_warp_id = bid * warp_num_per_block + warp # warp id across the grid + global_warp_num = block_num * warp_num_per_block # total warps in the grid + grid_thread_id = bid * (warp_num_per_block * 64) + tid # grid-wide thread id (used by Stage 2 only) # Predicated buffer_load: returns 0 (i32) when vld_flag is false. 
# Defined as a nested function so the AST rewriter lowers the Python @@ -890,12 +881,12 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): result = buffer_load(rsrc, offset, **kwargs) return result - _r_trecv = create_buffer_resource_from_addr(addr_trecv) - _r_xdb_flag = create_buffer_resource_from_addr(addr_xdb_flag) - _r_tis = create_buffer_resource_from_addr(addr_tis) - _r_comb_bar = create_buffer_resource_from_addr(addr_comb_bar) - _r_p2p_comb = create_buffer_resource_from_addr(addr_p2p_comb_inp) - _r_p2p_xdb = create_buffer_resource_from_addr(addr_p2p_xdb_mem) + _r_trecv = create_buffer_resource_from_addr(addr_trecv) + _r_xdb_flag = create_buffer_resource_from_addr(addr_xdb_flag) + _r_tis = create_buffer_resource_from_addr(addr_tis) + _r_comb_bar = create_buffer_resource_from_addr(addr_comb_bar) + _r_p2p_comb = create_buffer_resource_from_addr(addr_p2p_comb_inp) + _r_p2p_xdb = create_buffer_resource_from_addr(addr_p2p_xdb_mem) _rsrc_tok_map = create_buffer_resource_from_addr(addr_tok_map) total_recv = buffer_load(_r_trecv, 0, vec_width=1, dtype=T.i32()) @@ -916,8 +907,7 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): # variable (which then fails because SmemPtr is not an MLIR Value). # All ``_lds_p2p_*`` and downstream ``SmemPtr.{get,load,store}`` # call sites follow the same convention. 
- _lds_p2p_bases = SmemPtr(base_ptr, p2p_base_offset, T.i64(), - shape=(npes,)) + _lds_p2p_bases = SmemPtr(base_ptr, p2p_base_offset, T.i64(), shape=(npes,)) SmemPtr.get(_lds_p2p_bases) if lane < npes: @@ -926,8 +916,7 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): if const_expr(enable_weights): _r_p2p_comb_wt = create_buffer_resource_from_addr(addr_p2p_comb_inp_wts) - _lds_p2p_wt_bases = SmemPtr(base_ptr, p2p_wt_base_offset, T.i64(), - shape=(npes,)) + _lds_p2p_wt_bases = SmemPtr(base_ptr, p2p_wt_base_offset, T.i64(), shape=(npes,)) SmemPtr.get(_lds_p2p_wt_bases) if lane < npes: p2p_wt_base_addr = buffer_load(_r_p2p_comb_wt, lane, vec_width=1, dtype=T.i64()) @@ -943,7 +932,7 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): # Common per-token decoding from ``shmem_tok_id_to_src[recv_tok_id]``: # dest_pe - which peer this token must be combined to # dest_lid - the per-PE local id ``[0, max_tok_per_rank)`` - n_chunks = nbytes // 16 # 16-byte (4-i32) vector chunks per token + n_chunks = nbytes // 16 # 16-byte (4-i32) vector chunks per token if const_expr(skip_stage1): if const_expr(enable_weights): @@ -954,18 +943,17 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): dest_tok_enc = buffer_load(_r_tis, recv_tok_id, vec_width=1, dtype=T.i32()) if const_expr(_log2_max_tok is not None): - dest_pe = dest_tok_enc >> _log2_max_tok + dest_pe = dest_tok_enc >> _log2_max_tok dest_lid = dest_tok_enc & _mask_max_tok else: - dest_pe = (dest_tok_enc // max_tok_per_rank) - dest_lid = (dest_tok_enc % max_tok_per_rank) - wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [dest_pe]) - wt_dest_off = _to_i64( - rank * max_tok_per_rank + dest_lid) * weight_bytes + dest_pe = dest_tok_enc // max_tok_per_rank + dest_lid = dest_tok_enc % max_tok_per_rank + wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [dest_pe]) + wt_dest_off = _to_i64(rank * max_tok_per_rank + dest_lid) * weight_bytes wt_dest_addr = 
_lv_unwrap(wt_pe_base) + wt_dest_off - wt_src_addr = _lv_unwrap(addr_wts_buf) + _to_i64(recv_tok_id) * weight_bytes - rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) - rsrc_wt_dst = create_buffer_resource_from_addr(wt_dest_addr) + wt_src_addr = _lv_unwrap(addr_wts_buf) + _to_i64(recv_tok_id) * weight_bytes + rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) + rsrc_wt_dst = create_buffer_resource_from_addr(wt_dest_addr) if lane < wt_n_i32: wt_val = buffer_load(rsrc_wt_src, lane, vec_width=1, dtype=T.i32()) buffer_store(wt_val, rsrc_wt_dst, lane) @@ -978,61 +966,57 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): # the destination PE's ``shmem_comb_inp``. _rsrc_dtm = create_buffer_resource_from_addr(addr_disp_tok_map) _rsrc_dow = create_buffer_resource_from_addr(addr_disp_out_wts) - smoe_all_vld = False # k-slots may be sentinel (-1) for non-local experts + smoe_all_vld = False # k-slots may be sentinel (-1) for non-local experts for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): dest_tok_enc = buffer_load(_r_tis, recv_tok_id, vec_width=1, dtype=T.i32()) if const_expr(_log2_max_tok is not None): - dest_pe = dest_tok_enc >> _log2_max_tok + dest_pe = dest_tok_enc >> _log2_max_tok dest_lid = dest_tok_enc & _mask_max_tok else: - dest_pe = (dest_tok_enc // max_tok_per_rank) - dest_lid = (dest_tok_enc % max_tok_per_rank) + dest_pe = dest_tok_enc // max_tok_per_rank + dest_lid = dest_tok_enc % max_tok_per_rank if const_expr(use_p2p_read): # P2P-read mode: write locally; peers will pull from us in Stage 3. 
dest_byte_off = _to_i64(recv_tok_id) * nbytes dest_tok_addr = _lv_unwrap(addr_comb_inp) + dest_byte_off else: - peer_base = SmemPtr.load(_lds_p2p_bases, [dest_pe]) + peer_base = SmemPtr.load(_lds_p2p_bases, [dest_pe]) dest_byte_off = _to_i64(rank * max_tok_per_rank + dest_lid) * nbytes dest_tok_addr = _lv_unwrap(peer_base) + dest_byte_off rsrc_dst = create_buffer_resource_from_addr(dest_tok_addr) # Collect resources/valid-flags/weights for each k-expert slot. expert_rsrcs = [] - expert_vlds = [] - expert_wts = [] + expert_vlds = [] + expert_wts = [] for k_slot in range_constexpr(experts_per_token): slot_addr = addr_disp_tok_map + _to_i64(recv_tok_id * experts_per_token + k_slot) * 8 - slot_val = load_i64_global(slot_addr) - slot_vld = slot_val != -1 + slot_val = load_i64_global(slot_addr) + slot_vld = slot_val != -1 safe_slot = arith.select(slot_vld, slot_val, 0) expert_tok_addr = addr_packed_recv_x + safe_slot * nbytes expert_rsrcs.append(create_buffer_resource_from_addr(expert_tok_addr)) expert_vlds.append(slot_vld) - wt_k = buffer_load(_rsrc_dow, recv_tok_id * experts_per_token + k_slot, - vec_width=1, dtype=T.f32()) + wt_k = buffer_load(_rsrc_dow, recv_tok_id * experts_per_token + k_slot, vec_width=1, dtype=T.f32()) expert_wts.append(wt_k) # Weighted reduce across the k experts, then scatter. 
for elem_off in range(lane, n_i32, 64): expert_vals = [] for k_slot in range_constexpr(experts_per_token): - expert_vals.append(buffer_load(expert_rsrcs[k_slot], elem_off, - vec_width=1, dtype=T.i32())) - accum = _weighted_accum_experts(expert_vals, expert_wts, - expert_vlds, smoe_all_vld) + expert_vals.append(buffer_load(expert_rsrcs[k_slot], elem_off, vec_width=1, dtype=T.i32())) + accum = _weighted_accum_experts(expert_vals, expert_wts, expert_vlds, smoe_all_vld) buffer_store(accum, rsrc_dst, elem_off) if const_expr(enable_weights): if const_expr(use_p2p_read): - wt_dest_off = _to_i64(recv_tok_id) * weight_bytes + wt_dest_off = _to_i64(recv_tok_id) * weight_bytes wt_dest_addr = _lv_unwrap(addr_comb_inp_wts) + wt_dest_off else: - wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [dest_pe]) - wt_dest_off = _to_i64( - rank * max_tok_per_rank + dest_lid) * weight_bytes + wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [dest_pe]) + wt_dest_off = _to_i64(rank * max_tok_per_rank + dest_lid) * weight_bytes wt_dest_addr = _lv_unwrap(wt_pe_base) + wt_dest_off wt_src_addr = _lv_unwrap(addr_wts_buf) + _to_i64(recv_tok_id) * weight_bytes rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) @@ -1051,7 +1035,7 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): # In mixed-mode (bf16 input → fp8 staging), the source uses # bf16 stride (inp_nbytes) while the dest uses fp8 stride # (nbytes); in same-dtype mode the two strides are identical. - src_tok_addr = addr_inp_tok + _to_i64(recv_tok_id) * inp_nbytes + src_tok_addr = addr_inp_tok + _to_i64(recv_tok_id) * inp_nbytes dst_tok_addr = addr_comb_inp + _to_i64(recv_tok_id) * nbytes rsrc_src = create_buffer_resource_from_addr(src_tok_addr) rsrc_dst = create_buffer_resource_from_addr(dst_tok_addr) @@ -1060,22 +1044,20 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): # elems) → ExtF v4f32 → cvt_pk_fp8_f32 ×2 → store 1 # fp8 i32 (4 fp8 elems) at staging offset ``elem_off``. 
from flydsl._mlir.dialects import rocdl as _rocdl_s1a + _v4bf16_a = T.VectorType.get([4], T.bf16()) - _v4f32_a = T.VectorType.get([4], T.f32()) - _i32t_a = T.i32() + _v4f32_a = T.VectorType.get([4], T.f32()) + _i32t_a = T.i32() for elem_off in range(lane, n_i32, 64): - bf_pair = buffer_load(rsrc_src, elem_off * 2, - vec_width=2, dtype=T.i32()) + bf_pair = buffer_load(rsrc_src, elem_off * 2, vec_width=2, dtype=T.i32()) v4f = vector.bitcast(_v4bf16_a, bf_pair).extf(_v4f32_a) f0 = vector.extract(v4f, static_position=[0]) f1 = vector.extract(v4f, static_position=[1]) f2 = vector.extract(v4f, static_position=[2]) f3 = vector.extract(v4f, static_position=[3]) zi = arith.constant(0, type=_i32t_a) - lo = _rocdl_s1a.cvt_pk_fp8_f32(res=_i32t_a, src_a=f0, src_b=f1, - old=zi, word_sel=False) - fp8_i32 = _rocdl_s1a.cvt_pk_fp8_f32(res=_i32t_a, src_a=f2, src_b=f3, - old=lo, word_sel=True) + lo = _rocdl_s1a.cvt_pk_fp8_f32(res=_i32t_a, src_a=f0, src_b=f1, old=zi, word_sel=False) + fp8_i32 = _rocdl_s1a.cvt_pk_fp8_f32(res=_i32t_a, src_a=f2, src_b=f3, old=lo, word_sel=True) buffer_store(fp8_i32, rsrc_dst, elem_off) else: # Same-dtype path: 4-i32 vector copy. ``chunk_idx`` is @@ -1083,9 +1065,9 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): # copying; ``chunk_i32_off`` translates it to i32 elems. 
if const_expr(dual_end_aligned >= 128): for chunk_idx in range(lane, dual_end_aligned, 128): - chunk_i32_off = chunk_idx * 4 - chunk_i32_off_alt = (chunk_idx + 64) * 4 - vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + chunk_i32_off = chunk_idx * 4 + chunk_i32_off_alt = (chunk_idx + 64) * 4 + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) vec_b = buffer_load(rsrc_src, chunk_i32_off_alt, vec_width=4, dtype=T.i32()) buffer_store(vec_a, rsrc_dst, chunk_i32_off) buffer_store(vec_b, rsrc_dst, chunk_i32_off_alt) @@ -1097,7 +1079,7 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): if const_expr(enable_weights): for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): - wt_src_addr = _lv_unwrap(addr_wts_buf) + _to_i64(recv_tok_id) * weight_bytes + wt_src_addr = _lv_unwrap(addr_wts_buf) + _to_i64(recv_tok_id) * weight_bytes wt_dst_addr = _lv_unwrap(addr_comb_inp_wts) + _to_i64(recv_tok_id) * weight_bytes rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) rsrc_wt_dst = create_buffer_resource_from_addr(wt_dst_addr) @@ -1112,17 +1094,17 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): dest_tok_enc = buffer_load(_r_tis, recv_tok_id, vec_width=1, dtype=T.i32()) if const_expr(_log2_max_tok is not None): - dest_pe = dest_tok_enc >> _log2_max_tok + dest_pe = dest_tok_enc >> _log2_max_tok dest_lid = dest_tok_enc & _mask_max_tok else: - dest_pe = (dest_tok_enc // max_tok_per_rank) - dest_lid = (dest_tok_enc % max_tok_per_rank) + dest_pe = dest_tok_enc // max_tok_per_rank + dest_lid = dest_tok_enc % max_tok_per_rank peer_base = SmemPtr.load(_lds_p2p_bases, [dest_pe]) # Dest stride uses ``nbytes`` (staging dtype, fp8 in mixed mode). 
- dest_off = _to_i64(rank * max_tok_per_rank + dest_lid) * nbytes - dest_tok_addr = _lv_unwrap(peer_base) + dest_off + dest_off = _to_i64(rank * max_tok_per_rank + dest_lid) * nbytes + dest_tok_addr = _lv_unwrap(peer_base) + dest_off # Src stride uses ``inp_nbytes`` (input dtype, bf16 in mixed mode). - src_tok_addr = addr_inp_tok + _to_i64(recv_tok_id) * inp_nbytes + src_tok_addr = addr_inp_tok + _to_i64(recv_tok_id) * inp_nbytes rsrc_src = create_buffer_resource_from_addr(src_tok_addr) rsrc_dst = create_buffer_resource_from_addr(dest_tok_addr) if const_expr(_xfer_bf16_to_fp8): @@ -1130,29 +1112,27 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): # ExtF v4f32 → cvt_pk_fp8_f32 ×2 → store 1 fp8 i32 (=4 # fp8 elems). Loop unit is 1 fp8-i32 per lane per step. from flydsl._mlir.dialects import rocdl as _rocdl_s1b + _v4bf16_b = T.VectorType.get([4], T.bf16()) - _v4f32_b = T.VectorType.get([4], T.f32()) - _i32t_b = T.i32() + _v4f32_b = T.VectorType.get([4], T.f32()) + _i32t_b = T.i32() for elem_off in range(lane, n_i32, 64): - bf_pair = buffer_load(rsrc_src, elem_off * 2, - vec_width=2, dtype=T.i32()) + bf_pair = buffer_load(rsrc_src, elem_off * 2, vec_width=2, dtype=T.i32()) v4f = vector.bitcast(_v4bf16_b, bf_pair).extf(_v4f32_b) f0 = vector.extract(v4f, static_position=[0]) f1 = vector.extract(v4f, static_position=[1]) f2 = vector.extract(v4f, static_position=[2]) f3 = vector.extract(v4f, static_position=[3]) zi = arith.constant(0, type=_i32t_b) - lo = _rocdl_s1b.cvt_pk_fp8_f32(res=_i32t_b, src_a=f0, src_b=f1, - old=zi, word_sel=False) - fp8_i32 = _rocdl_s1b.cvt_pk_fp8_f32(res=_i32t_b, src_a=f2, src_b=f3, - old=lo, word_sel=True) + lo = _rocdl_s1b.cvt_pk_fp8_f32(res=_i32t_b, src_a=f0, src_b=f1, old=zi, word_sel=False) + fp8_i32 = _rocdl_s1b.cvt_pk_fp8_f32(res=_i32t_b, src_a=f2, src_b=f3, old=lo, word_sel=True) buffer_store(fp8_i32, rsrc_dst, elem_off) else: if const_expr(dual_end_aligned >= 128): for chunk_idx in range(lane, dual_end_aligned, 128): - chunk_i32_off 
= chunk_idx * 4 + chunk_i32_off = chunk_idx * 4 chunk_i32_off_alt = (chunk_idx + 64) * 4 - vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) vec_b = buffer_load(rsrc_src, chunk_i32_off_alt, vec_width=4, dtype=T.i32()) buffer_store(vec_a, rsrc_dst, chunk_i32_off) buffer_store(vec_b, rsrc_dst, chunk_i32_off_alt) @@ -1163,13 +1143,12 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): buffer_store(vec_a, rsrc_dst, chunk_i32_off) if const_expr(enable_weights): - wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [dest_pe]) - wt_dest_off = _to_i64( - rank * max_tok_per_rank + dest_lid) * weight_bytes + wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [dest_pe]) + wt_dest_off = _to_i64(rank * max_tok_per_rank + dest_lid) * weight_bytes wt_dest_addr = _lv_unwrap(wt_pe_base) + wt_dest_off - wt_src_addr = _lv_unwrap(addr_wts_buf) + _to_i64(recv_tok_id) * weight_bytes - rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) - rsrc_wt_dst = create_buffer_resource_from_addr(wt_dest_addr) + wt_src_addr = _lv_unwrap(addr_wts_buf) + _to_i64(recv_tok_id) * weight_bytes + rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) + rsrc_wt_dst = create_buffer_resource_from_addr(wt_dest_addr) if lane < wt_n_i32: wt_val = buffer_load(rsrc_wt_src, lane, vec_width=1, dtype=T.i32()) buffer_store(wt_val, rsrc_wt_dst, lane) @@ -1185,8 +1164,10 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): if grid_thread_id < npes: mori_shmem.int32_wait_until_equals(addr_comb_bar, block_num) buffer_store(arith.constant(0), _r_comb_bar, 0) - xdb_remote_addr = buffer_load(_r_p2p_xdb, grid_thread_id, vec_width=1, dtype=T.i64()) + \ - arith.constant(rank, type=T.i64()) * 8 + xdb_remote_addr = ( + buffer_load(_r_p2p_xdb, grid_thread_id, vec_width=1, dtype=T.i64()) + + arith.constant(rank, type=T.i64()) * 8 + ) store_i64_global_system(xdb_remote_addr, xdb_cur_flag) if grid_thread_id == 0: @@ -1207,25 +1188,24 
@@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): # per-expert partials from ``shmem_comb_inp``, accumulates them in # high-precision (f32) and writes back the merged token to # ``shmem_comb_out``. - SLC_CACHE = 2 # buffer_load/store ``cache_modifier=SLC`` (system-coherent) - rsrc_out = create_buffer_resource_from_addr(addr_comb_out) + SLC_CACHE = 2 # buffer_load/store ``cache_modifier=SLC`` (system-coherent) + rsrc_out = create_buffer_resource_from_addr(addr_comb_out) n_elems = n_i32 # When ``cur_rank_num_token == 0`` the division below would divide by # zero; clamp the denominator to 1 (loop won't execute anyway). - safe_token_count = arith.select( - cur_rank_num_token == 0, 1, cur_rank_num_token) - warps_per_tok = (global_warp_num + safe_token_count - 1) // safe_token_count - hdim_per_warp = (n_elems + warps_per_tok - 1) // warps_per_tok - s3_total_work = cur_rank_num_token * warps_per_tok + safe_token_count = arith.select(cur_rank_num_token == 0, 1, cur_rank_num_token) + warps_per_tok = (global_warp_num + safe_token_count - 1) // safe_token_count + hdim_per_warp = (n_elems + warps_per_tok - 1) // warps_per_tok + s3_total_work = cur_rank_num_token * warps_per_tok for s3_work_idx in range(global_warp_id, s3_total_work, global_warp_num): - tok_id = (s3_work_idx // warps_per_tok) - part_id = (s3_work_idx % warps_per_tok) - hdim_off = part_id * hdim_per_warp + tok_id = s3_work_idx // warps_per_tok + part_id = s3_work_idx % warps_per_tok + hdim_off = part_id * hdim_per_warp expert_rsrcs = [] - expert_vlds = [] + expert_vlds = [] if const_expr(skip_stage1): # Fused-upstream Stage 3: when ``skip_stage1`` is set the @@ -1236,8 +1216,8 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): # [0, k). Unrouted (tok_id, k_slot) slots are zero-initialized # by the caller and therefore contribute zero to the sum. 
for k_slot in range_constexpr(experts_per_token): - slot_idx = tok_id * experts_per_token + k_slot - expert_tok_off = _to_i64(slot_idx) * nbytes + slot_idx = tok_id * experts_per_token + k_slot + expert_tok_off = _to_i64(slot_idx) * nbytes expert_tok_addr = _lv_unwrap(addr_comb_inp + expert_tok_off) expert_rsrcs.append(create_buffer_resource_from_addr(expert_tok_addr)) expert_vlds.append(arith.constant(1, type=T.bool())) @@ -1249,8 +1229,8 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): # scattered each (src_pe, src_lid) contribution into that # slot. Two 4-i32 loads cover the 8 k-slots in one round. tm_base_off = tok_id * experts_per_token - tm_vec_lo = buffer_load(_rsrc_tok_map, tm_base_off, vec_width=4, dtype=T.i32()) - tm_vec_hi = buffer_load(_rsrc_tok_map, tm_base_off + 4, vec_width=4, dtype=T.i32()) + tm_vec_lo = buffer_load(_rsrc_tok_map, tm_base_off, vec_width=4, dtype=T.i32()) + tm_vec_hi = buffer_load(_rsrc_tok_map, tm_base_off + 4, vec_width=4, dtype=T.i32()) for k_slot in range_constexpr(experts_per_token): if const_expr(k_slot < 4): @@ -1260,22 +1240,22 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): if const_expr(_log2_max_recv is not None): dest_pe_k = enc_k >> _log2_max_recv else: - dest_pe_k = (enc_k // max_recv) - vld_k = dest_pe_k < npes # sentinel = npes + dest_pe_k = enc_k // max_recv + vld_k = dest_pe_k < npes # sentinel = npes safe_pe = arith.select(vld_k, dest_pe_k, rank) if const_expr(use_p2p_read): - dtok_global = (enc_k % max_recv) - safe_dtok = arith.select(vld_k, dtok_global, 0) - peer_base = SmemPtr.load(_lds_p2p_bases, [safe_pe]) + dtok_global = enc_k % max_recv + safe_dtok = arith.select(vld_k, dtok_global, 0) + peer_base = SmemPtr.load(_lds_p2p_bases, [safe_pe]) expert_tok_off = _to_i64(safe_dtok) * nbytes expert_tok_addr = _lv_unwrap(peer_base) + expert_tok_off else: - expert_tok_off = _to_i64(safe_pe * max_tok_per_rank + tok_id) * nbytes + expert_tok_off = _to_i64(safe_pe * max_tok_per_rank + tok_id) * nbytes 
expert_tok_addr = _lv_unwrap(addr_comb_inp + expert_tok_off) expert_rsrcs.append(create_buffer_resource_from_addr(expert_tok_addr)) expert_vlds.append(vld_k) - all_vld = (npes >= experts_per_token) # without compaction, every k_slot must be valid + all_vld = npes >= experts_per_token # without compaction, every k_slot must be valid eff_all_vld = all_vld or _use_compaction # Two paths optimised for the per-warp partition size: @@ -1283,10 +1263,9 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): # quad unrolled loads, each step covers 256/512/... bytes. # - narrow path (hdim_per_warp <= 895): plain step=64 loop. if 895 < hdim_per_warp: - rem_hdim_128 = n_elems - hdim_off + rem_hdim_128 = n_elems - hdim_off # Effective end of THIS warp's partition, clamped to n_elems. - eff_end_128 = arith.select( - rem_hdim_128 < hdim_per_warp, rem_hdim_128, hdim_per_warp) + eff_end_128 = arith.select(rem_hdim_128 < hdim_per_warp, rem_hdim_128, hdim_per_warp) if const_expr(n_i32 % 256 == 0 and warp_num_per_block < 16): if (hdim_per_warp % 256) < 1: @@ -1297,11 +1276,45 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): vals_a, vals_b, vals_c, vals_d = [], [], [], [] for k_slot in range_constexpr(experts_per_token): rsrc_k = expert_rsrcs[k_slot] - vld_k = expert_vlds[k_slot] - vals_a.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE)) - vals_b.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE, soffset_bytes=256)) - vals_c.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE, soffset_bytes=512)) - vals_d.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE, soffset_bytes=768)) + vld_k = expert_vlds[k_slot] + vals_a.append( + _maybe_load( + rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE + ) + ) + vals_b.append( + _maybe_load( + rsrc_k, + ec_abs, + vld_k, + vec_width=1, + 
dtype=T.i32(), + cache_modifier=SLC_CACHE, + soffset_bytes=256, + ) + ) + vals_c.append( + _maybe_load( + rsrc_k, + ec_abs, + vld_k, + vec_width=1, + dtype=T.i32(), + cache_modifier=SLC_CACHE, + soffset_bytes=512, + ) + ) + vals_d.append( + _maybe_load( + rsrc_k, + ec_abs, + vld_k, + vec_width=1, + dtype=T.i32(), + cache_modifier=SLC_CACHE, + soffset_bytes=768, + ) + ) acc_a = _accum_experts(vals_a, expert_vlds, eff_all_vld) acc_b = _accum_experts(vals_b, expert_vlds, eff_all_vld) acc_c = _accum_experts(vals_c, expert_vlds, eff_all_vld) @@ -1329,9 +1342,23 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): vals_a, vals_b = [], [] for k_slot in range_constexpr(experts_per_token): rsrc_k = expert_rsrcs[k_slot] - vld_k = expert_vlds[k_slot] - vals_a.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE)) - vals_b.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE, soffset_bytes=256)) + vld_k = expert_vlds[k_slot] + vals_a.append( + _maybe_load( + rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE + ) + ) + vals_b.append( + _maybe_load( + rsrc_k, + ec_abs, + vld_k, + vec_width=1, + dtype=T.i32(), + cache_modifier=SLC_CACHE, + soffset_bytes=256, + ) + ) acc_a = _accum_experts(vals_a, expert_vlds, eff_all_vld) acc_b = _accum_experts(vals_b, expert_vlds, eff_all_vld) if const_expr(_xfer_bf16_to_fp8): @@ -1346,7 +1373,16 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): ec_abs = hdim_off + ec vals_tail = [] for k_slot in range_constexpr(experts_per_token): - vals_tail.append(_maybe_load(expert_rsrcs[k_slot], ec_abs, expert_vlds[k_slot], vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE)) + vals_tail.append( + _maybe_load( + expert_rsrcs[k_slot], + ec_abs, + expert_vlds[k_slot], + vec_width=1, + dtype=T.i32(), + cache_modifier=SLC_CACHE, + ) + ) acc_tail = _accum_experts(vals_tail, expert_vlds, eff_all_vld) if const_expr(_xfer_bf16_to_fp8): out_off = 
tok_id * out_n_i32 + ec_abs * 2 @@ -1357,13 +1393,21 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): else: # Narrow path: a single step=64 main loop. rem_hdim_64 = n_elems - hdim_off - eff_end_64 = arith.select( - rem_hdim_64 < hdim_per_warp, rem_hdim_64, hdim_per_warp) + eff_end_64 = arith.select(rem_hdim_64 < hdim_per_warp, rem_hdim_64, hdim_per_warp) for ec in range(lane, eff_end_64, 64): ec_abs = hdim_off + ec vals_main = [] for k_slot in range_constexpr(experts_per_token): - vals_main.append(_maybe_load(expert_rsrcs[k_slot], ec_abs, expert_vlds[k_slot], vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE)) + vals_main.append( + _maybe_load( + expert_rsrcs[k_slot], + ec_abs, + expert_vlds[k_slot], + vec_width=1, + dtype=T.i32(), + cache_modifier=SLC_CACHE, + ) + ) acc = _accum_experts(vals_main, expert_vlds, eff_all_vld) if const_expr(_xfer_bf16_to_fp8): out_off = tok_id * out_n_i32 + ec_abs * 2 @@ -1380,8 +1424,8 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): if const_expr(enable_weights): rsrc_out_wts = create_buffer_resource_from_addr(addr_comb_out_wts) for wt_tok_id in range(global_warp_id, cur_rank_num_token, global_warp_num): - wt_tm_off = wt_tok_id * experts_per_token - wt_tm_vec_lo = buffer_load(_rsrc_tok_map, wt_tm_off, vec_width=4, dtype=T.i32()) + wt_tm_off = wt_tok_id * experts_per_token + wt_tm_vec_lo = buffer_load(_rsrc_tok_map, wt_tm_off, vec_width=4, dtype=T.i32()) wt_tm_vec_hi = buffer_load(_rsrc_tok_map, wt_tm_off + 4, vec_width=4, dtype=T.i32()) if lane < experts_per_token: @@ -1394,21 +1438,18 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): if const_expr(_log2_max_recv is not None): wt_pe = wt_enc >> _log2_max_recv else: - wt_pe = (wt_enc // max_recv) - wt_vld = wt_pe < npes + wt_pe = wt_enc // max_recv + wt_vld = wt_pe < npes wt_safe_pe = arith.select(wt_vld, wt_pe, rank) if const_expr(use_p2p_read): - wt_dtok = (wt_enc % max_recv) + wt_dtok = wt_enc % max_recv wt_safe_dtok = arith.select(wt_vld, wt_dtok, 0) - wt_pe_base 
= SmemPtr.load(_lds_p2p_wt_bases, [wt_safe_pe]) - wt_src_off = _to_i64(wt_safe_dtok) * weight_bytes - wt_rsrc = create_buffer_resource_from_addr( - wt_pe_base + wt_src_off) + wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [wt_safe_pe]) + wt_src_off = _to_i64(wt_safe_dtok) * weight_bytes + wt_rsrc = create_buffer_resource_from_addr(wt_pe_base + wt_src_off) else: - wt_src_off = _to_i64( - wt_safe_pe * max_tok_per_rank + wt_tok_id) * weight_bytes - wt_rsrc = create_buffer_resource_from_addr( - addr_comb_inp_wts + wt_src_off) + wt_src_off = _to_i64(wt_safe_pe * max_tok_per_rank + wt_tok_id) * weight_bytes + wt_rsrc = create_buffer_resource_from_addr(addr_comb_inp_wts + wt_src_off) wt_val = buffer_load(wt_rsrc, lane, vec_width=1, dtype=T.f32()) if const_expr(npes >= experts_per_token): wt_acc = wt_acc + wt_val @@ -1421,14 +1462,25 @@ def _maybe_load(rsrc, offset, vld_flag, **kwargs): return ep_combine_intranode -def make_dispatch_jit(*, rank, npes, experts_per_rank, experts_per_token, - hidden_dim, max_tok_per_rank, block_num, - warp_num_per_block, data_type, - scale_dim=0, scale_type_size=0, - enable_std_moe=False): +def make_dispatch_jit( + *, + rank, + npes, + experts_per_rank, + experts_per_token, + hidden_dim, + max_tok_per_rank, + block_num, + warp_num_per_block, + data_type, + scale_dim=0, + scale_type_size=0, + enable_std_moe=False, +): hidden_elem_size = torch.tensor([], dtype=data_type).element_size() kernel = make_dispatch_kernel( - rank=rank, npes=npes, + rank=rank, + npes=npes, experts_per_rank=experts_per_rank, experts_per_token=experts_per_token, hidden_dim=hidden_dim, @@ -1447,42 +1499,70 @@ def make_dispatch_jit(*, rank, npes, experts_per_rank, experts_per_token, # configs produce distinct cached entries. 
_key_rank, _key_npes, _key_block_num = rank, npes, block_num _key_warp_per_block = warp_num_per_block - _key_max_tok = max_tok_per_rank - _key_std_moe = enable_std_moe + _key_max_tok = max_tok_per_rank + _key_std_moe = enable_std_moe @flyc.jit def dispatch_launch( - addr_inp_tok: fx.Int64, addr_idx: fx.Int64, addr_wts: fx.Int64, - addr_out_tok: fx.Int64, addr_out_wts: fx.Int64, addr_out_idx: fx.Int64, - addr_tok_off: fx.Int64, addr_recv_num: fx.Int64, - addr_dest_ctr: fx.Int64, addr_disp_bar: fx.Int64, - addr_tok_map: fx.Int64, addr_tis: fx.Int64, + addr_inp_tok: fx.Int64, + addr_idx: fx.Int64, + addr_wts: fx.Int64, + addr_out_tok: fx.Int64, + addr_out_wts: fx.Int64, + addr_out_idx: fx.Int64, + addr_tok_off: fx.Int64, + addr_recv_num: fx.Int64, + addr_dest_ctr: fx.Int64, + addr_disp_bar: fx.Int64, + addr_tok_map: fx.Int64, + addr_tis: fx.Int64, addr_total_rv: fx.Int64, - addr_p2p_tok_off: fx.Int64, addr_p2p_tis: fx.Int64, - addr_p2p_out_wts: fx.Int64, addr_p2p_out_idx: fx.Int64, - addr_p2p_out_tok: fx.Int64, addr_p2p_recv_num: fx.Int64, - addr_scales: fx.Int64, addr_p2p_out_scales: fx.Int64, - addr_packed_recv_x: fx.Int64, addr_packed_recv_count: fx.Int64, - addr_packed_recv_src_info: fx.Int64, addr_disp_tok_map: fx.Int64, + addr_p2p_tok_off: fx.Int64, + addr_p2p_tis: fx.Int64, + addr_p2p_out_wts: fx.Int64, + addr_p2p_out_idx: fx.Int64, + addr_p2p_out_tok: fx.Int64, + addr_p2p_recv_num: fx.Int64, + addr_scales: fx.Int64, + addr_p2p_out_scales: fx.Int64, + addr_packed_recv_x: fx.Int64, + addr_packed_recv_count: fx.Int64, + addr_packed_recv_src_info: fx.Int64, + addr_disp_tok_map: fx.Int64, addr_disp_grid_bar: fx.Int64, cur_tok: fx.Int32, stream: Stream = Stream(None), ): - _ = (_key_rank, _key_npes, _key_block_num, _key_warp_per_block, - _key_max_tok, _key_std_moe) - kernel(addr_inp_tok, addr_idx, addr_wts, - addr_out_tok, addr_out_wts, addr_out_idx, - addr_tok_off, addr_recv_num, addr_dest_ctr, - addr_disp_bar, addr_tok_map, addr_tis, - addr_total_rv, - 
addr_p2p_tok_off, addr_p2p_tis, - addr_p2p_out_wts, addr_p2p_out_idx, - addr_p2p_out_tok, addr_p2p_recv_num, - addr_scales, addr_p2p_out_scales, - addr_packed_recv_x, addr_packed_recv_count, - addr_packed_recv_src_info, addr_disp_tok_map, - addr_disp_grid_bar, - cur_tok).launch( + _ = (_key_rank, _key_npes, _key_block_num, _key_warp_per_block, _key_max_tok, _key_std_moe) + kernel( + addr_inp_tok, + addr_idx, + addr_wts, + addr_out_tok, + addr_out_wts, + addr_out_idx, + addr_tok_off, + addr_recv_num, + addr_dest_ctr, + addr_disp_bar, + addr_tok_map, + addr_tis, + addr_total_rv, + addr_p2p_tok_off, + addr_p2p_tis, + addr_p2p_out_wts, + addr_p2p_out_idx, + addr_p2p_out_tok, + addr_p2p_recv_num, + addr_scales, + addr_p2p_out_scales, + addr_packed_recv_x, + addr_packed_recv_count, + addr_packed_recv_src_info, + addr_disp_tok_map, + addr_disp_grid_bar, + cur_tok, + ).launch( grid=(block_num, 1, 1), block=(warp_num_per_block * 64, 1, 1), stream=stream, @@ -1491,15 +1571,27 @@ def dispatch_launch( return dispatch_launch -def make_combine_jit(*, rank, npes, experts_per_rank=0, experts_per_token, - hidden_dim, max_tok_per_rank, block_num, - warp_num_per_block, data_type, - enable_weights=False, enable_std_moe=False, - use_p2p_read=False, skip_stage1=False, - inp_data_type=None): +def make_combine_jit( + *, + rank, + npes, + experts_per_rank=0, + experts_per_token, + hidden_dim, + max_tok_per_rank, + block_num, + warp_num_per_block, + data_type, + enable_weights=False, + enable_std_moe=False, + use_p2p_read=False, + skip_stage1=False, + inp_data_type=None, +): hidden_elem_size = torch.tensor([], dtype=data_type).element_size() kernel = make_combine_kernel( - rank=rank, npes=npes, + rank=rank, + npes=npes, experts_per_rank=experts_per_rank, experts_per_token=experts_per_token, hidden_dim=hidden_dim, @@ -1520,49 +1612,77 @@ def make_combine_jit(*, rank, npes, experts_per_rank=0, experts_per_token, # configs produce distinct cached entries. 
_key_rank, _key_npes, _key_block_num = rank, npes, block_num _key_warp_per_block = warp_num_per_block - _key_max_tok = max_tok_per_rank - _key_weights = enable_weights - _key_std_moe = enable_std_moe - _key_p2p_read = use_p2p_read - _key_skip_s1 = skip_stage1 - _key_inp_dtype = str(inp_data_type) if inp_data_type is not None else "none" + _key_max_tok = max_tok_per_rank + _key_weights = enable_weights + _key_std_moe = enable_std_moe + _key_p2p_read = use_p2p_read + _key_skip_s1 = skip_stage1 + _key_inp_dtype = str(inp_data_type) if inp_data_type is not None else "none" _allocator = kernel._allocator @flyc.jit def combine_launch( - addr_inp_tok: fx.Int64, addr_comb_inp: fx.Int64, - addr_comb_out: fx.Int64, addr_xdb_mem: fx.Int64, - addr_xdb_flag: fx.Int64, addr_tok_map: fx.Int64, - addr_comb_bar: fx.Int64, addr_trecv: fx.Int64, + addr_inp_tok: fx.Int64, + addr_comb_inp: fx.Int64, + addr_comb_out: fx.Int64, + addr_xdb_mem: fx.Int64, + addr_xdb_flag: fx.Int64, + addr_tok_map: fx.Int64, + addr_comb_bar: fx.Int64, + addr_trecv: fx.Int64, addr_tis: fx.Int64, - addr_p2p_comb_inp: fx.Int64, addr_p2p_xdb_mem: fx.Int64, + addr_p2p_comb_inp: fx.Int64, + addr_p2p_xdb_mem: fx.Int64, addr_wts_buf: fx.Int64, - addr_comb_inp_wts: fx.Int64, addr_comb_out_wts: fx.Int64, + addr_comb_inp_wts: fx.Int64, + addr_comb_out_wts: fx.Int64, addr_p2p_comb_inp_wts: fx.Int64, - addr_packed_recv_x: fx.Int64, addr_disp_tok_map: fx.Int64, + addr_packed_recv_x: fx.Int64, + addr_disp_tok_map: fx.Int64, addr_disp_out_wts: fx.Int64, cur_rank_num_token: fx.Int32, stream: Stream = Stream(None), ): - _ = (_key_rank, _key_npes, _key_block_num, _key_warp_per_block, - _key_max_tok, _key_weights, _key_std_moe, _key_p2p_read, - _key_skip_s1, _key_inp_dtype) + _ = ( + _key_rank, + _key_npes, + _key_block_num, + _key_warp_per_block, + _key_max_tok, + _key_weights, + _key_std_moe, + _key_p2p_read, + _key_skip_s1, + _key_inp_dtype, + ) from flydsl.compiler.kernel_function import CompilationContext - from 
flydsl._mlir import ir + _allocator.finalized = False ctx = CompilationContext.get_current() with ir.InsertionPoint(ctx.gpu_module_body): _allocator.finalize() - kernel(addr_inp_tok, addr_comb_inp, addr_comb_out, - addr_xdb_mem, addr_xdb_flag, addr_tok_map, - addr_comb_bar, addr_trecv, addr_tis, - addr_p2p_comb_inp, addr_p2p_xdb_mem, - addr_wts_buf, addr_comb_inp_wts, - addr_comb_out_wts, addr_p2p_comb_inp_wts, - addr_packed_recv_x, addr_disp_tok_map, - addr_disp_out_wts, - cur_rank_num_token).launch( + kernel( + addr_inp_tok, + addr_comb_inp, + addr_comb_out, + addr_xdb_mem, + addr_xdb_flag, + addr_tok_map, + addr_comb_bar, + addr_trecv, + addr_tis, + addr_p2p_comb_inp, + addr_p2p_xdb_mem, + addr_wts_buf, + addr_comb_inp_wts, + addr_comb_out_wts, + addr_p2p_comb_inp_wts, + addr_packed_recv_x, + addr_disp_tok_map, + addr_disp_out_wts, + cur_rank_num_token, + ).launch( grid=(block_num, 1, 1), block=(warp_num_per_block * 64, 1, 1), stream=stream, diff --git a/kernels/dispatch_combine_intranode_op.py b/kernels/dispatch_combine_intranode_op.py new file mode 100644 index 000000000..1220f2a9c --- /dev/null +++ b/kernels/dispatch_combine_intranode_op.py @@ -0,0 +1,649 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""FlyDSL DispatchCombine IntraNode 算子包装器。""" + +from __future__ import annotations + +from dataclasses import dataclass + +import mori.shmem as ms +import torch +from mori.shmem import mori_shmem_create_tensor + +import flydsl.compiler as flyc +import flydsl.expr as fx + +from .dispatch_combine_intranode_kernel import ( + make_combine_jit, + make_dispatch_jit, +) + + +@dataclass +class FlyDSLDispatchCombineConfig: + rank: int + world_size: int + hidden_dim: int + max_num_inp_token_per_rank: int + num_experts_per_rank: int + num_experts_per_token: int + data_type: torch.dtype = torch.bfloat16 + warp_num_per_block: int = 16 + block_num: int = 80 + chip: str = "gfx950" + scale_dim: int = 0 + scale_type_size: int = 0 
+ enable_std_moe: bool = False + use_external_inp_buf: bool = True + quant_type: str = "none" + + @property + def is_fp4(self): + return self.data_type == torch.float4_e2m1fn_x2 + + @property + def elem_size(self): + return torch.tensor([], dtype=self.data_type).element_size() + + @property + def token_bytes(self): + if self.is_fp4: + return self.hidden_dim // 2 + return self.hidden_dim * self.elem_size + + @property + def token_view_dim(self): + if self.is_fp4: + return self.hidden_dim // 2 + return self.hidden_dim + + @property + def block_dim(self): + return self.warp_num_per_block * 64 + + @property + def max_recv(self): + return self.world_size * self.max_num_inp_token_per_rank + + @property + def scale_bytes(self): + return self.scale_dim * self.scale_type_size + + +class FlyDSLDispatchCombineIntraNodeOp: + + def __init__(self, config): + self.cfg = config + self._dev = torch.device("cuda", config.rank) + r = config.rank + + self._alloc_buffers() + ms.shmem_barrier_all() + + npes = config.world_size + self._p2p_tok_off = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_tis = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_out_wts = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_out_idx = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_out_tok = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_recv_num = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_out_scales = torch.zeros(npes, dtype=torch.int64, device=self._dev) + for pe in range(npes): + self._p2p_tok_off[pe] = ms.shmem_ptr_p2p(self.shmem_tok_off.data_ptr(), r, pe) + self._p2p_tis[pe] = ms.shmem_ptr_p2p(self.shmem_tok_id_to_src.data_ptr(), r, pe) + self._p2p_out_wts[pe] = ms.shmem_ptr_p2p(self.shmem_disp_out_wts.data_ptr(), r, pe) + self._p2p_out_idx[pe] = ms.shmem_ptr_p2p(self.shmem_disp_out_idx.data_ptr(), r, pe) + self._p2p_out_tok[pe] = ms.shmem_ptr_p2p(self.shmem_disp_out_tok.data_ptr(), 
r, pe) + self._p2p_recv_num[pe] = ms.shmem_ptr_p2p(self.shmem_recv_tok_num.data_ptr(), r, pe) + self._p2p_out_scales[pe] = ms.shmem_ptr_p2p(self.shmem_out_scales.data_ptr(), r, pe) + + self._p2p_comb_inp = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_comb_inp_wts = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_xdb_mem = torch.zeros(npes, dtype=torch.int64, device=self._dev) + for pe in range(npes): + self._p2p_comb_inp[pe] = ms.shmem_ptr_p2p(self.shmem_comb_inp_tok.data_ptr(), r, pe) + self._p2p_comb_inp_wts[pe] = ms.shmem_ptr_p2p(self.shmem_comb_inp_wts.data_ptr(), r, pe) + self._p2p_xdb_mem[pe] = ms.shmem_ptr_p2p(self.shmem_xdev_bar_mem.data_ptr(), r, pe) + + _disp_wpb = config.warp_num_per_block + self._disp_fn = make_dispatch_jit( + rank=r, + npes=config.world_size, + experts_per_rank=config.num_experts_per_rank, + experts_per_token=config.num_experts_per_token, + hidden_dim=config.hidden_dim, + max_tok_per_rank=config.max_num_inp_token_per_rank, + block_num=config.block_num, + warp_num_per_block=_disp_wpb, + data_type=config.data_type, + scale_dim=config.scale_dim, + scale_type_size=config.scale_type_size, + enable_std_moe=config.enable_std_moe, + ) + + _use_fp8_cast = config.quant_type == "fp8_direct_cast" and config.data_type == torch.bfloat16 + _comb_dtype = torch.float8_e4m3fn if _use_fp8_cast else config.data_type + # Mixed-dtype Stage 1 (mori UseFp8DirectCast equivalent): when + # _use_fp8_cast is on, the user feeds bf16 input to ``combine()`` and + # the kernel performs an inline bf16 → fp8 cast in Stage 1 before P2P + # scatter. This avoids an extra ~12μs ``input.to(fp8).contiguous()`` + # PyTorch elementwise kernel that would otherwise sit on the cudagraph + # critical path. Wrapper-side allocation/views remain fp8-stride. 
+ _comb_inp_dt = torch.bfloat16 if _use_fp8_cast else None + self._comb_fn = make_combine_jit( + rank=r, + npes=config.world_size, + experts_per_rank=config.num_experts_per_rank, + experts_per_token=config.num_experts_per_token, + hidden_dim=config.hidden_dim, + max_tok_per_rank=config.max_num_inp_token_per_rank, + block_num=config.block_num, + warp_num_per_block=_disp_wpb, + data_type=_comb_dtype, + enable_weights=True, + enable_std_moe=config.enable_std_moe, + use_p2p_read=not config.use_external_inp_buf, + inp_data_type=_comb_inp_dt, + ) + self._use_fp8_cast = _use_fp8_cast + + # barrier flag 初始值必须为 1, 否则首次 wait_until_equals(slot, 0) 立即满足 + self._xdev_flag = torch.ones(1, dtype=torch.int64, device=self._dev) + + self._fx_out_tok = fx.Int64(self.shmem_disp_out_tok.data_ptr()) + self._fx_out_wts = fx.Int64(self.shmem_disp_out_wts.data_ptr()) + self._fx_out_idx = fx.Int64(self.shmem_disp_out_idx.data_ptr()) + self._fx_tok_off = fx.Int64(self.shmem_tok_off.data_ptr()) + self._fx_recv_num = fx.Int64(self.shmem_recv_tok_num.data_ptr()) + self._fx_dest_ctr = fx.Int64(self.dest_pe_ctr.data_ptr()) + self._fx_disp_bar = fx.Int64(self.disp_bar.data_ptr()) + self._fx_tok_map = fx.Int64(self.dest_tok_map.data_ptr()) + self._fx_tis = fx.Int64(self.shmem_tok_id_to_src.data_ptr()) + self._fx_total_rv = fx.Int64(self.total_recv.data_ptr()) + # combine 固定地址 + self._fx_comb_inp = fx.Int64(self.shmem_comb_inp_tok.data_ptr()) + self._fx_comb_out = fx.Int64(self.shmem_comb_out_tok.data_ptr()) + self._fx_xdb_mem = fx.Int64(self.shmem_xdev_bar_mem.data_ptr()) + self._fx_xdev_flag = fx.Int64(self._xdev_flag.data_ptr()) + self._fx_comb_bar = fx.Int64(self.comb_bar.data_ptr()) + self._fx_trecv = fx.Int64(self.total_recv.data_ptr()) + self._fx_p2p_tok_off = fx.Int64(self._p2p_tok_off.data_ptr()) + self._fx_p2p_tis = fx.Int64(self._p2p_tis.data_ptr()) + self._fx_p2p_out_wts = fx.Int64(self._p2p_out_wts.data_ptr()) + self._fx_p2p_out_idx = fx.Int64(self._p2p_out_idx.data_ptr()) + 
self._fx_p2p_out_tok = fx.Int64(self._p2p_out_tok.data_ptr()) + self._fx_p2p_recv_num = fx.Int64(self._p2p_recv_num.data_ptr()) + self._fx_p2p_out_scales = fx.Int64(self._p2p_out_scales.data_ptr()) + self._fx_out_scales = fx.Int64(self.shmem_out_scales.data_ptr()) + self._fx_p2p_comb_inp = fx.Int64(self._p2p_comb_inp.data_ptr()) + self._fx_p2p_comb_inp_wts = fx.Int64(self._p2p_comb_inp_wts.data_ptr()) + self._fx_p2p_xdb_mem = fx.Int64(self._p2p_xdb_mem.data_ptr()) + self._fx_comb_inp_wts = fx.Int64(self.shmem_comb_inp_wts.data_ptr()) + self._fx_comb_out_wts = fx.Int64(self.shmem_comb_out_wts.data_ptr()) + self._fx_packed_recv_count = fx.Int64(self.packed_recv_count.data_ptr()) + self._fx_packed_recv_src_info = fx.Int64(self.packed_recv_src_info.data_ptr()) + self._fx_disp_tok_map = fx.Int64(self.disp_tok_to_ep_slot_map.data_ptr()) + self._fx_disp_grid_bar = fx.Int64(self.disp_grid_bar.data_ptr()) + self._fx_disp_out_wts = fx.Int64(self.shmem_disp_out_wts.data_ptr()) + + self._disp_compiled = None + self._comb_compiled = None + # combine kernel 的 skip_stage1 变体:给 fused_gemm2_combine 算子使用, + # 此时 fused kernel 已经把 token / 权重 P2P 写入 shmem_comb_inp[_wts], + # combine 只跑 Stage 2 (CrossDeviceBarrier) + Stage 3 (本地 weighted-accum)。 + self._comb_no_s1_fn = None + self._comb_no_s1_compiled = None + + def _alloc_buffers(self): + cfg = self.cfg + npes = cfg.world_size + k = cfg.num_experts_per_token + mt = cfg.max_num_inp_token_per_rank + mr = cfg.max_recv # npes * mt + + tb = cfg.token_bytes + tok_i16_mr = (mr * tb + 1) // 2 + tok_i16_mt = (mt * tb + 1) // 2 + + # Symmetric shmem buffers + self.shmem_disp_out_tok = mori_shmem_create_tensor((tok_i16_mr,), torch.int16) + self.shmem_disp_out_wts = mori_shmem_create_tensor((mr * k,), torch.float32) + self.shmem_disp_out_idx = mori_shmem_create_tensor((mr * k,), torch.int32) + scale_total = mr * cfg.scale_bytes if cfg.scale_bytes > 0 else 1 + self.shmem_out_scales = mori_shmem_create_tensor((scale_total,), torch.int8) + 
self.shmem_tok_off = mori_shmem_create_tensor((1,), torch.int32) + self.shmem_recv_tok_num = mori_shmem_create_tensor((npes,), torch.int32) + self.shmem_tok_id_to_src = mori_shmem_create_tensor((mr,), torch.int32) + self.shmem_comb_inp_tok = mori_shmem_create_tensor((tok_i16_mr,), torch.int16) + self.shmem_comb_out_tok = mori_shmem_create_tensor((tok_i16_mt,), torch.int16) + self.shmem_comb_inp_wts = mori_shmem_create_tensor((mr * k,), torch.float32) + self.shmem_comb_out_wts = mori_shmem_create_tensor((mt * k,), torch.float32) + self.shmem_xdev_bar_mem = mori_shmem_create_tensor((npes,), torch.int64) + + # mori_shmem_create_tensor 走 shmem_malloc,分配的是未初始化的 raw memory。 + # 对 fused MoE-GEMM2 + EP-Combine 路径,GEMM2 需要在 epilogue 用 + # shmem_tok_id_to_src 解码 dest_pe / dest_lid,越界 garbage 会触发 LDS + # OOB → 写到任意全局地址 → 破坏 control state。这里把所有 combine 路径 + # 直接读写的 symmetric buffer 显式清零,保证: + # - shmem_tok_id_to_src[t] 对未被 dispatch 写入的 t 解码为 (pe=0, lid=0), + # P2P scatter 退化成"安全无副作用"(多写同一槽位) + # - shmem_xdev_bar_mem 起始 0,CrossDeviceBarrier 第一次 wait 不会读到 + # 残留值(依赖 cur_flag 单调递增) + # - shmem_comb_inp_{tok,wts} 起始 0,combine_no_stage1 在 stage 3 累加 + # 时不会读到 garbage + self.shmem_tok_id_to_src.zero_() + self.shmem_comb_inp_tok.zero_() + self.shmem_comb_inp_wts.zero_() + self.shmem_xdev_bar_mem.zero_() + + # Local device buffers + self.dest_pe_ctr = torch.zeros(npes, dtype=torch.int32, device=self._dev) + self.disp_bar = torch.zeros(1, dtype=torch.int32, device=self._dev) + self.comb_bar = torch.zeros(1, dtype=torch.int32, device=self._dev) + self.total_recv = torch.zeros(1, dtype=torch.int32, device=self._dev) + sentinel = cfg.world_size * mr + self.dest_tok_map = torch.full((mt * k,), sentinel, dtype=torch.int32, device=self._dev) + + # StdMoE buffers + if cfg.enable_std_moe: + epr = cfg.num_experts_per_rank + max_tok_per_expert = mr # world_size * max_num_inp_token_per_rank + self.packed_recv_count = torch.zeros(epr, dtype=torch.int32, device=self._dev) + 
self.packed_recv_src_info = torch.zeros(epr * max_tok_per_expert, dtype=torch.int32, device=self._dev) + self.disp_tok_to_ep_slot_map = torch.full((mr * k,), -1, dtype=torch.int64, device=self._dev) + self.disp_grid_bar = torch.zeros(1, dtype=torch.int32, device=self._dev) + else: + self.packed_recv_count = torch.zeros(1, dtype=torch.int32, device=self._dev) + self.packed_recv_src_info = torch.zeros(1, dtype=torch.int32, device=self._dev) + self.disp_tok_to_ep_slot_map = torch.zeros(1, dtype=torch.int64, device=self._dev) + self.disp_grid_bar = torch.zeros(1, dtype=torch.int32, device=self._dev) + + def barrier(self): + ms.shmem_barrier_all() + + def reset(self): + self.barrier() + + def dispatch( + self, input, weights, scales, indices, packed_recv_x=None, block_num=-1, rdma_block_num=-1, warp_per_block=-1 + ): + cfg = self.cfg + cur_tok = input.shape[0] + stream = torch.cuda.current_stream() + inp_c = input if input.is_contiguous() else input.contiguous() + wts_c = weights if weights.is_contiguous() else weights.contiguous() + idx_c = ( + indices + if (indices.dtype == torch.int32 and indices.is_contiguous()) + else indices.to(torch.int32).contiguous() + ) + + sc_ptr = scales.data_ptr() if scales is not None else 0 + prx_ptr = packed_recv_x.data_ptr() if packed_recv_x is not None else 0 + + if cfg.enable_std_moe: + self.packed_recv_count.zero_() + + _std_args = ( + self._fx_packed_recv_count if cfg.enable_std_moe else fx.Int64(0), + self._fx_packed_recv_src_info, + self._fx_disp_tok_map, + self._fx_disp_grid_bar, + ) + + if self._disp_compiled is None: + args = ( + fx.Int64(inp_c.data_ptr()), + fx.Int64(idx_c.data_ptr()), + fx.Int64(wts_c.data_ptr()), + self._fx_out_tok, + self._fx_out_wts, + self._fx_out_idx, + self._fx_tok_off, + self._fx_recv_num, + self._fx_dest_ctr, + self._fx_disp_bar, + self._fx_tok_map, + self._fx_tis, + self._fx_total_rv, + self._fx_p2p_tok_off, + self._fx_p2p_tis, + self._fx_p2p_out_wts, + self._fx_p2p_out_idx, + self._fx_p2p_out_tok, + 
self._fx_p2p_recv_num, + fx.Int64(sc_ptr), + self._fx_p2p_out_scales, + fx.Int64(prx_ptr), + *_std_args, + cur_tok, + stream, + ) + self._disp_compiled = flyc.compile(self._disp_fn, *args) + else: + self._disp_compiled( + inp_c.data_ptr(), + idx_c.data_ptr(), + wts_c.data_ptr(), + self._fx_out_tok, + self._fx_out_wts, + self._fx_out_idx, + self._fx_tok_off, + self._fx_recv_num, + self._fx_dest_ctr, + self._fx_disp_bar, + self._fx_tok_map, + self._fx_tis, + self._fx_total_rv, + self._fx_p2p_tok_off, + self._fx_p2p_tis, + self._fx_p2p_out_wts, + self._fx_p2p_out_idx, + self._fx_p2p_out_tok, + self._fx_p2p_recv_num, + sc_ptr, + self._fx_p2p_out_scales, + prx_ptr, + *_std_args, + cur_tok, + stream, + ) + + mr = cfg.max_recv + k = cfg.num_experts_per_token + + out_tok = ( + self.shmem_disp_out_tok.view(torch.int8)[: mr * cfg.token_bytes] + .view(cfg.data_type) + .view(mr, cfg.token_view_dim) + ) + out_wts = self.shmem_disp_out_wts.view(mr, k) + out_idx = self.shmem_disp_out_idx.view(mr, k) + out_scales = None + if cfg.scale_bytes > 0: + out_scales = self.shmem_out_scales[: mr * cfg.scale_bytes].view(mr, cfg.scale_dim * cfg.scale_type_size) + + result = (out_tok, out_wts, out_scales, out_idx, self.total_recv) + if cfg.enable_std_moe: + epr = cfg.num_experts_per_rank + result = result + ( + self.packed_recv_count[:epr], + self.packed_recv_src_info, + ) + return result + + def combine( + self, + input, + weights, + indices, + packed_recv_x=None, + cur_tok=None, + block_num=-1, + rdma_block_num=-1, + warp_per_block=-1, + use_external_inp_buf=-1, + call_reset=False, + ): + cfg = self.cfg + stream = torch.cuda.current_stream() + + # In _use_fp8_cast mode, the combine kernel does the bf16 → fp8 cast + # inline in Stage 1 (mori UseFp8DirectCast equivalent), so the wrapper + # passes bf16 input straight through. Skipping the PyTorch-level + # ``.to(fp8).contiguous()`` saves ~12μs per iter on the cudagraph + # critical path. 
+ inp_c = input if input.is_contiguous() else input.contiguous() + _cur_tok = cur_tok if cur_tok is not None else cfg.max_num_inp_token_per_rank + + wts_ptr = self.shmem_disp_out_wts.data_ptr() if weights is None else weights.data_ptr() + + _prx_ref = None + if self._use_fp8_cast and packed_recv_x is not None: + # std-MoE expert-major buffer (`packed_recv_x`) is produced in bf16 + # by the upstream pipeline; downstream Stage 1 reads it in fp8 + # dtype, so we still cast here. This branch is independent from + # the regular combine input path above. + _prx_ref = packed_recv_x.view(torch.bfloat16).to(torch.float8_e4m3fn).contiguous() + prx_ptr = _prx_ref.data_ptr() + else: + prx_ptr = packed_recv_x.data_ptr() if packed_recv_x is not None else 0 + + _std_args_comb = ( + fx.Int64(prx_ptr), + self._fx_disp_tok_map, + self._fx_disp_out_wts, + ) + + if self._comb_compiled is None: + args = ( + fx.Int64(inp_c.data_ptr()), + self._fx_comb_inp, + self._fx_comb_out, + self._fx_xdb_mem, + self._fx_xdev_flag, + self._fx_tok_map, + self._fx_comb_bar, + self._fx_trecv, + self._fx_tis, + self._fx_p2p_comb_inp, + self._fx_p2p_xdb_mem, + fx.Int64(wts_ptr), + self._fx_comb_inp_wts, + self._fx_comb_out_wts, + self._fx_p2p_comb_inp_wts, + *_std_args_comb, + _cur_tok, + stream, + ) + self._comb_compiled = flyc.compile(self._comb_fn, *args) + else: + self._comb_compiled( + inp_c.data_ptr(), + self._fx_comb_inp, + self._fx_comb_out, + self._fx_xdb_mem, + self._fx_xdev_flag, + self._fx_tok_map, + self._fx_comb_bar, + self._fx_trecv, + self._fx_tis, + self._fx_p2p_comb_inp, + self._fx_p2p_xdb_mem, + wts_ptr, + self._fx_comb_inp_wts, + self._fx_comb_out_wts, + self._fx_p2p_comb_inp_wts, + prx_ptr, + self._fx_disp_tok_map, + self._fx_disp_out_wts, + _cur_tok, + stream, + ) + + mt = cfg.max_num_inp_token_per_rank + k = cfg.num_experts_per_token + + # fp8_direct_cast contract: external dtype is bf16 on both ends; the + # combine kernel itself writes bf16 to ``shmem_comb_out_tok`` (Stage 3 + # 
_from_accum casts v4f32 → v4bf16 inline), so we view the buffer as + # bf16 directly with no extra PyTorch-level cast on the critical path. + out_tok = ( + self.shmem_comb_out_tok.view(torch.int8)[: mt * cfg.token_bytes] + .view(cfg.data_type) + .view(mt, cfg.token_view_dim) + ) + out_wts = self.shmem_comb_out_wts.view(mt, k) + + if call_reset: + self.reset() + return out_tok, out_wts + + def combine_no_stage1( + self, input, weights, indices, packed_recv_x=None, cur_tok=None, call_reset=False, enable_weights: bool = True + ): + """combine 的 stage1-skipped 变体。 + + 语义:跳过 P2P scatter(外部 fused kernel 已把数据写入 shmem_comb_inp[_wts]), + 只执行 Stage 2 (CrossDeviceBarrier) + Stage 3 (本地 weighted-accum)。 + + Parameters + ---------- + enable_weights + ``True`` (默认) 兼容当前 fused-with-weight 链路:在 combine + kernel 内保留 Stage 1 的 weight scatter + Stage 3b 的 weight + accumulate。weight scatter 显式留在 combine kernel 内(而不是 + 放在上游 fused GEMM2 的 epilogue 里),因为 16B 小写若与上游 + token P2P 并发会被 ROCm IPC fabric 静默丢,必须放在静态 fabric + 上由 combine kernel 完成。 + ``False`` 给 weight-free fused 路径(fused MoE 上游已经把 + weight 处理掉了,combine 端不需要 out_wts):完全 DCE 掉 + weight scatter + Stage 3b,省 ~3-5 μs。 + 两种变体走不同的 JIT 缓存,互不污染。 + + 约定:调用前 fused kernel 必须保证: + - shmem_comb_inp_tok 已写入本 PE 应接收的所有 token(按 max_tok_per_rank 槽位) + - shmem_comb_inp_wts 已写入对应权重(仅 enable_weights=True 时需要) + - total_recv 已被 dispatch 设置完毕(Stage 3 用于读 cur_rank_num_token) + """ + cfg = self.cfg + stream = torch.cuda.current_stream() + + # When skip_stage1=True (the only mode this method ever compiles for), + # the combine kernel does NOT read inp_c — Stage 1 is bypassed and the + # kernel reads from shmem_comb_inp_tok directly (already populated by + # the upstream fused GEMM2 epilogue P2P scatter). 
So skip the + # potentially-expensive Python-level fp8 cast (.to(fp8) + .contiguous()) + # if the caller gave us a fp8 input or even a placeholder bf16: the + # cast is a ~12us elementwise kernel that gets captured by cudagraph + # and ends up serially on the chain critical path for nothing. + # Caller (fused op wrapper) already CV-casted in the GEMM2 epilogue. + if self._use_fp8_cast and input.dtype != torch.float8_e4m3fn: + inp_c = input.to(torch.float8_e4m3fn).contiguous() + else: + inp_c = input if input.is_contiguous() else input.contiguous() + _cur_tok = cur_tok if cur_tok is not None else cfg.max_num_inp_token_per_rank + + wts_ptr = self.shmem_disp_out_wts.data_ptr() if weights is None else weights.data_ptr() + + _prx_ref = None + if self._use_fp8_cast and packed_recv_x is not None: + _prx_ref = packed_recv_x.view(torch.bfloat16).to(torch.float8_e4m3fn).contiguous() + prx_ptr = _prx_ref.data_ptr() + else: + prx_ptr = packed_recv_x.data_ptr() if packed_recv_x is not None else 0 + + # JIT 缓存按 enable_weights 区分(两份编译产物)。 + # 历史 self._comb_no_s1_fn / _compiled 升级为 dict[bool, fn]。 + if not isinstance(self._comb_no_s1_fn, dict): + self._comb_no_s1_fn = {} + self._comb_no_s1_compiled = {} + + if enable_weights not in self._comb_no_s1_fn: + from .dispatch_combine_intranode_kernel import make_combine_jit + + _use_fp8_cast = self._use_fp8_cast + _comb_dtype = torch.float8_e4m3fn if _use_fp8_cast else cfg.data_type + # Mixed-dtype contract for fp8_direct_cast: external dtype = bf16, + # transport dtype = fp8. Stage 3 _from_accum will cast f32 → bf16 + # inline so kernel writes bf16 directly to shmem_comb_out_tok and + # the wrapper does NOT need a post .to(bf16) cast. 
+ _comb_inp_dt = torch.bfloat16 if _use_fp8_cast else None + # enable_weights=False 路径(fused MoE 不需要 out_wts): + # weight scatter + Stage 3b weight accumulate 都在 const_expr + # 处被 DCE 掉,省 ~3-5μs。 + # enable_weights=True 路径(兼容 fused-with-weight): + # combine kernel 在 skip_stage1=True 下默认仍跑 weight scatter, + # 因为同 fabric 上与 token P2P 并发的 16B 小写会被静默丢,必须 + # 放在静态 fabric 上由 combine kernel 完成。 + self._comb_no_s1_fn[enable_weights] = make_combine_jit( + rank=cfg.rank, + npes=cfg.world_size, + experts_per_rank=cfg.num_experts_per_rank, + experts_per_token=cfg.num_experts_per_token, + hidden_dim=cfg.hidden_dim, + max_tok_per_rank=cfg.max_num_inp_token_per_rank, + block_num=cfg.block_num, + warp_num_per_block=cfg.warp_num_per_block, + data_type=_comb_dtype, + enable_weights=bool(enable_weights), + enable_std_moe=cfg.enable_std_moe, + use_p2p_read=not cfg.use_external_inp_buf, + skip_stage1=True, + inp_data_type=_comb_inp_dt, + ) + + if enable_weights not in self._comb_no_s1_compiled: + args = ( + fx.Int64(inp_c.data_ptr()), + self._fx_comb_inp, + self._fx_comb_out, + self._fx_xdb_mem, + self._fx_xdev_flag, + self._fx_tok_map, + self._fx_comb_bar, + self._fx_trecv, + self._fx_tis, + self._fx_p2p_comb_inp, + self._fx_p2p_xdb_mem, + fx.Int64(wts_ptr), + self._fx_comb_inp_wts, + self._fx_comb_out_wts, + self._fx_p2p_comb_inp_wts, + fx.Int64(prx_ptr), + self._fx_disp_tok_map, + self._fx_disp_out_wts, + _cur_tok, + stream, + ) + self._comb_no_s1_compiled[enable_weights] = flyc.compile(self._comb_no_s1_fn[enable_weights], *args) + else: + self._comb_no_s1_compiled[enable_weights]( + inp_c.data_ptr(), + self._fx_comb_inp, + self._fx_comb_out, + self._fx_xdb_mem, + self._fx_xdev_flag, + self._fx_tok_map, + self._fx_comb_bar, + self._fx_trecv, + self._fx_tis, + self._fx_p2p_comb_inp, + self._fx_p2p_xdb_mem, + wts_ptr, + self._fx_comb_inp_wts, + self._fx_comb_out_wts, + self._fx_p2p_comb_inp_wts, + prx_ptr, + self._fx_disp_tok_map, + self._fx_disp_out_wts, + _cur_tok, + stream, + ) + + 
mt = cfg.max_num_inp_token_per_rank + k = cfg.num_experts_per_token + + # fp8_direct_cast contract: combine kernel writes bf16 to + # ``shmem_comb_out_tok`` directly (see ``combine`` above for details). + out_tok = ( + self.shmem_comb_out_tok.view(torch.int8)[: mt * cfg.token_bytes] + .view(cfg.data_type) + .view(mt, cfg.token_view_dim) + ) + out_wts = self.shmem_comb_out_wts.view(mt, k) + + if call_reset: + self.reset() + return out_tok, out_wts + + def get_dispatch_src_token_pos(self): + torch.cuda.synchronize() + n = int(self.total_recv[0].item()) + return self.shmem_tok_id_to_src[:n].clone() + + def get_registered_combine_input_buffer(self, dtype, hidden_dim=-1): + h = hidden_dim if hidden_dim > 0 else self.cfg.token_view_dim + dt = dtype if dtype is not None else self.cfg.data_type + return self.shmem_comb_inp_tok.view(torch.int8).view(dt).view(-1, h) diff --git a/python/flydsl/expr/rocdl/__init__.py b/python/flydsl/expr/rocdl/__init__.py index 70e821c36..99e1a94ef 100644 --- a/python/flydsl/expr/rocdl/__init__.py +++ b/python/flydsl/expr/rocdl/__init__.py @@ -43,9 +43,8 @@ _ods_mfma_i32_16x16x32_i8 = mfma_i32_16x16x32_i8 _ods_mfma_f32_16x16x32_f16 = globals().get("mfma_f32_16x16x32_f16", None) _ods_mfma_f32_16x16x32_bf16 = globals().get("mfma_f32_16x16x32_bf16", None) -_ods_mfma_scale_f32_16x16x128_f8f6f4 = ( - globals().get("mfma_scale_f32_16x16x128_f8f6f4", None) - or globals().get("mfma_scale_f32_16x16x128_f8f6f4_", None) +_ods_mfma_scale_f32_16x16x128_f8f6f4 = globals().get("mfma_scale_f32_16x16x128_f8f6f4", None) or globals().get( + "mfma_scale_f32_16x16x128_f8f6f4_", None ) mask_mfma = 0x008 mask_vmem_rd = 0x020 @@ -528,8 +527,8 @@ def ballot(res, pred, **kw): ``res`` selects the lane-mask width (``i32`` on wave32, ``i64`` on wave64). 
""" - from ..._mlir.ir import IntegerType from ..._mlir.dialects import llvm as _llvm + from ..._mlir.ir import IntegerType pred_v = _to_ir(pred) i1 = IntegerType.get_signless(1) diff --git a/tests/kernels/test_profiler_dispatch_combine.py b/tests/kernels/test_profiler_dispatch_combine.py new file mode 100644 index 000000000..16ba7a02f --- /dev/null +++ b/tests/kernels/test_profiler_dispatch_combine.py @@ -0,0 +1,1442 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +""" +FlyDSL 和 mori ref 的 dispatch/combine kernel 性能测试。 + +两个正交维度可自由组合: + --mode 测量方式:profile(torch.profiler 采集)| bench(CUDA Event 计时) + --cudagraph 执行方式:不带此标志 = eager 模式 | 带 = CUDAGraph capture+replay + +四种组合: + 1. profile + eager : torch.profiler 采集 eager 执行的 kernel + E2E + CPU 时间 + 2. bench + eager : CUDA Event 计时 eager dispatch/combine(无 profiler 开销) + 3. profile + cudagraph: torch.profiler 采集 CUDAGraph replay 中的 kernel 时间 + 4. bench + cudagraph: CUDA Event 计时 CUDAGraph replay(零 Python launch 开销) + +启动方式(支持 torchrun 或直接 python): + # profile + eager(默认) + python tests/kernels/test_profiler_dispatch_combine.py --max-tokens 512 + + # bench + eager + python tests/kernels/test_profiler_dispatch_combine.py --mode bench + + # bench + cudagraph + python tests/kernels/test_profiler_dispatch_combine.py --mode bench --cudagraph + + # profile + cudagraph + python tests/kernels/test_profiler_dispatch_combine.py --mode profile --cudagraph + + # 只测 FlyDSL + python tests/kernels/test_profiler_dispatch_combine.py --bench-op flydsl +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys + +import torch +import torch.distributed as dist +from torch.profiler import ProfilerActivity, profile, record_function + +os.environ.setdefault("MORI_SHMEM_HEAP_SIZE", "16G") + +# ── dtype 映射 ── +DTYPE_MAP = { + "bf16": torch.bfloat16, + "f32": torch.float32, + "fp8_ocp": torch.float8_e4m3fn, + "fp8_fnuz": torch.float8_e4m3fnuz, + "fp4": 
torch.float4_e2m1fn_x2, +} + +MORI_KERNEL_SUFFIX = { + "bf16": "bf16", + "f32": "f32", + "fp8_ocp": "fp8_ocp", + "fp8_fnuz": "fp8_fnuz", + "fp4": "fp4", +} + +_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +for _p in [_ROOT, "/home/yashao/FlyDSL/python", "/home/yashao/mori/python"]: + if _p not in sys.path: + sys.path.insert(0, _p) + +import mori.shmem as ms # noqa: E402 + +from kernels.dispatch_combine_intranode_op import ( # noqa: E402 + FlyDSLDispatchCombineConfig, + FlyDSLDispatchCombineIntraNodeOp, +) + + +# ─── 分布式初始化 ───────────────────────────────────────────────────────────── +def setup_distributed(rank, world_size, master_port=29600): + if "LOCAL_RANK" not in os.environ: + os.environ.update( + { + "LOCAL_RANK": str(rank), + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(master_port), + } + ) + local_rank = int(os.environ.get("LOCAL_RANK", rank)) + torch.cuda.set_device(local_rank) + dev = torch.device("cuda", local_rank) + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + rank=rank, + world_size=world_size, + device_id=dev, + ) + import torch._C._distributed_c10d as c10d + + c10d._register_process_group("default", dist.group.WORLD) + ms.shmem_torch_process_group_init("default") + return local_rank, world_size + + +def cleanup(): + try: + ms.shmem_finalize() + except Exception: + pass + if dist.is_initialized(): + try: + dist.barrier() + except Exception: + pass + dist.destroy_process_group() + + +_MORI_SUPPORTED_DTYPES = {torch.bfloat16, torch.float32, torch.float8_e4m3fn, torch.float4_e2m1fn_x2} + + +def build_mori_ref(rank, world_size, cfg, block_num: int = None, warp_per_block: int = None): + if cfg.data_type not in _MORI_SUPPORTED_DTYPES: + raise RuntimeError(f"mori does not support dtype {cfg.data_type} on this platform") + from mori.ops.dispatch_combine import EpDispatchCombineConfig, EpDispatchCombineOp + + elem = torch.tensor([], 
dtype=cfg.data_type).element_size() + mcfg = EpDispatchCombineConfig( + data_type=cfg.data_type, + rank=rank, + world_size=world_size, + hidden_dim=cfg.hidden_dim, + scale_dim=cfg.num_experts_per_token, + scale_type_size=4, + max_token_type_size=elem, + max_num_inp_token_per_rank=cfg.max_num_inp_token_per_rank, + num_experts_per_rank=cfg.num_experts_per_rank, + num_experts_per_token=cfg.num_experts_per_token, + warp_num_per_block=warp_per_block if warp_per_block is not None else cfg.warp_num_per_block, + block_num=block_num if block_num is not None else cfg.block_num, + gpu_per_node=world_size, + use_external_inp_buf=cfg.use_external_inp_buf, + quant_type=cfg.quant_type, + ) + return EpDispatchCombineOp(mcfg) + + +def _save_profile_json(prof, out_path: str, rank: int, op_tag: str, meta: dict): + """将 profiler 结果序列化为 JSON 文件。 + + JSON 结构: + { + "meta": {op_tag, rank, max_tokens, hidden_dim, k, world_size, ...}, + "kernel_stats": [ {name, calls, cuda_time_avg_us, cpu_time_avg_us}, ... ] + } + """ + rows = [] + for evt in prof.key_averages(): + rows.append( + { + "name": evt.key, + "calls": evt.count, + "cuda_time_avg_us": round(evt.device_time, 2), + "cuda_time_total_us": round(evt.device_time * evt.count, 2), + "cpu_time_avg_us": round(evt.cpu_time, 2), + "cpu_time_total_us": round(evt.cpu_time * evt.count, 2), + } + ) + # 按 GPU time 降序 + rows.sort(key=lambda r: r["cuda_time_total_us"], reverse=True) + + payload = { + "meta": {**meta, "op": op_tag, "rank": rank}, + "kernel_stats": rows, + } + os.makedirs(os.path.dirname(out_path), exist_ok=True) + with open(out_path, "w") as f: + json.dump(payload, f, indent=2, ensure_ascii=False) + + trace_path = out_path.replace(".json", "_trace.json") + prof.export_chrome_trace(trace_path) + + +def _allreduce_stats( + prof, + op_tag: str, + rank: int, + world_size: int, + dev: torch.device, + dtype_key: str = "bf16", + quant_type: str = "none", + use_p2p_read: bool = False, +) -> dict: + """从本卡 profiler 提取关键指标,跨卡 all_reduce 后返回 
avg/min/max 字典。 + + 采集 6 项指标(顺序固定,打包成 float64 tensor 做 all_reduce): + 0: dispatch GPU kernel time (μs/call) + 1: combine GPU kernel time (μs/call) + 2: dispatch record_function CUDA time (μs/call) + 3: combine record_function CUDA time (μs/call) + 4: dispatch record_function CPU time (μs/call) + 5: combine record_function CPU time (μs/call) + """ + msuf = MORI_KERNEL_SUFFIX.get(dtype_key, "bf16") + _cast_suf = "_fp8cast" if (quant_type == "fp8_direct_cast" and not use_p2p_read) else "" + _p2p_suf = "_p2p" if use_p2p_read else "_nop2p" + if op_tag == "flydsl": + d_kernel = "ep_dispatch_intranode_0" + c_kernel = "ep_combine_intranode_0" + else: + d_kernel = f"EpDispatchIntraNodeKernel_{msuf}" + c_kernel = f"EpCombineIntraNodeKernel_{msuf}{_p2p_suf}{_cast_suf}" + d_label = f"{op_tag}::dispatch" + c_label = f"{op_tag}::combine" + + ev = {e.key: e for e in prof.key_averages()} + + def gpu_us(key): + e = ev.get(key) + return e.device_time if (e and e.count) else 0.0 + + def cpu_us(key): + e = ev.get(key) + return e.cpu_time if (e and e.count) else 0.0 + + local = torch.tensor( + [ + gpu_us(d_kernel), + gpu_us(c_kernel), + gpu_us(d_label), + gpu_us(c_label), + cpu_us(d_label), + cpu_us(c_label), + ], + dtype=torch.float64, + device=dev, + ) + + s = local.clone() + dist.all_reduce(s, op=dist.ReduceOp.SUM) + mx = local.clone() + dist.all_reduce(mx, op=dist.ReduceOp.MAX) + mn = local.clone() + dist.all_reduce(mn, op=dist.ReduceOp.MIN) + avg = s / world_size + + keys = [ + "dispatch_gpu", + "combine_gpu", + "dispatch_cuda_e2e", + "combine_cuda_e2e", + "dispatch_cpu_e2e", + "combine_cpu_e2e", + ] + return {k: {"avg": avg[i].item(), "min": mn[i].item(), "max": mx[i].item()} for i, k in enumerate(keys)} + + +def _print_aggregated(stats: dict, op_tag: str, world_size: int, meta: dict): + """rank 0 打印全卡聚合统计。""" + sep = "=" * 72 + print(f"\n{sep}") + print( + f" {op_tag.upper()} EP={world_size} bs={meta['max_tokens']} " + f"h={meta['hidden_dim']} k={meta['k']} ({meta['iters']} 
iters)" + ) + print(f" 所有 {world_size} 张卡的 avg / min / max(μs/call)") + print(sep) + hdr = f" {'指标':<36} {'avg':>8} {'min':>8} {'max':>8}" + print(hdr) + print(f" {'-'*60}") + + rows = [ + ("[Device] dispatch kernel GPU time", "dispatch_gpu"), + ("[Device] combine kernel GPU time", "combine_gpu"), + ("[E2E] dispatch CUDA time (含sync)", "dispatch_cuda_e2e"), + ("[E2E] combine CUDA time (含sync)", "combine_cuda_e2e"), + ("[Host] dispatch CPU time", "dispatch_cpu_e2e"), + ("[Host] combine CPU time", "combine_cpu_e2e"), + ] + for label, key in rows: + v = stats[key] + print(f" {label:<36} {v['avg']:>8.1f} {v['min']:>8.1f} {v['max']:>8.1f}") + print() + + +def _allreduce_cudagraph_stats_from_key_averages( + prof, + op_tag: str, + rank: int, + world_size: int, + dev: torch.device, + dtype_key: str = "bf16", + quant_type: str = "none", + use_p2p_read: bool = False, +) -> dict: + """从 key_averages() 提取指标(仅含 active 阶段数据),跨卡 all_reduce。 + + 采集 4 项: + 0: dispatch kernel GPU time + 1: combine kernel GPU time + 2: cudagraph_replay CUDA E2E time + 3: cudagraph_replay CPU E2E time + """ + msuf = MORI_KERNEL_SUFFIX.get(dtype_key, "bf16") + _cast_suf = "_fp8cast" if (quant_type == "fp8_direct_cast" and not use_p2p_read) else "" + _p2p_suf = "_p2p" if use_p2p_read else "_nop2p" + if op_tag == "flydsl": + d_kernel = "ep_dispatch_intranode_0" + c_kernel = "ep_combine_intranode_0" + else: + d_kernel = f"EpDispatchIntraNodeKernel_{msuf}" + c_kernel = f"EpCombineIntraNodeKernel_{msuf}{_p2p_suf}{_cast_suf}" + cg_label = f"{op_tag}::cudagraph_replay" + + ev = {e.key: e for e in prof.key_averages()} + + def gpu_us(key): + e = ev.get(key) + return e.device_time if (e and e.count) else 0.0 + + def cpu_us(key): + e = ev.get(key) + return e.cpu_time if (e and e.count) else 0.0 + + local = torch.tensor( + [ + gpu_us(d_kernel), + gpu_us(c_kernel), + gpu_us(cg_label), + cpu_us(cg_label), + ], + dtype=torch.float64, + device=dev, + ) + + s = local.clone() + dist.all_reduce(s, op=dist.ReduceOp.SUM) + 
mx = local.clone() + dist.all_reduce(mx, op=dist.ReduceOp.MAX) + mn = local.clone() + dist.all_reduce(mn, op=dist.ReduceOp.MIN) + avg = s / world_size + + keys = ["dispatch_gpu", "combine_gpu", "replay_cuda_e2e", "replay_cpu_e2e"] + return {k: {"avg": avg[i].item(), "min": mn[i].item(), "max": mx[i].item()} for i, k in enumerate(keys)} + + +def _cudagraph_stats_from_trace( + trace_path: str, + op_tag: str, + rank: int, + world_size: int, + dev: torch.device, + active_iters: int, + skip_first: int = 5, + dtype_key: str = "bf16", + quant_type: str = "none", + use_p2p_read: bool = False, +) -> dict: + """从 chrome trace JSON 手动统计 kernel 性能,跳过前 skip_first 次 active 调用。 + + 流程:解析 trace → 按时间排序取最后 active_iters 个事件 → 丢弃前 skip_first 个 → 跨卡聚合。 + """ + with open(trace_path) as f: + tr = json.load(f) + + msuf = MORI_KERNEL_SUFFIX.get(dtype_key, "bf16") + _cast_suf = "_fp8cast" if (quant_type == "fp8_direct_cast" and not use_p2p_read) else "" + _p2p_suf = "_p2p" if use_p2p_read else "_nop2p" + if op_tag == "flydsl": + d_name, c_name = "ep_dispatch_intranode_0", "ep_combine_intranode_0" + else: + d_name = f"EpDispatchIntraNodeKernel_{msuf}" + c_name = f"EpCombineIntraNodeKernel_{msuf}{_p2p_suf}{_cast_suf}" + cg_name = f"{op_tag}::cudagraph_replay" + + kernel_events = [e for e in tr["traceEvents"] if e.get("cat") == "kernel"] + d_all = sorted([e for e in kernel_events if d_name in e.get("name", "")], key=lambda e: e["ts"]) + c_all = sorted([e for e in kernel_events if c_name in e.get("name", "")], key=lambda e: e["ts"]) + cg_all = sorted( + [e for e in tr["traceEvents"] if e.get("cat") == "gpu_user_annotation" and cg_name in e.get("name", "")], + key=lambda e: e["ts"], + ) + + d_active = [e["dur"] for e in d_all[-active_iters:]] + c_active = [e["dur"] for e in c_all[-active_iters:]] + cg_active = [e["dur"] for e in cg_all[-active_iters:]] + + d_valid = d_active[skip_first:] + c_valid = c_active[skip_first:] + cg_valid = cg_active[skip_first:] + + valid_n = len(d_valid) + if rank 
== 0: + print( + f"[trace-stats] {op_tag}: trace 中 dispatch={len(d_all)} combine={len(c_all)} 个事件," + f"取最后 {active_iters} 个 active,跳过前 {skip_first},有效 {valid_n} 个" + ) + + d_avg = sum(d_valid) / valid_n if valid_n else 0.0 + c_avg = sum(c_valid) / valid_n if valid_n else 0.0 + cg_avg = sum(cg_valid) / len(cg_valid) if cg_valid else 0.0 + + local = torch.tensor([d_avg, c_avg, cg_avg, 0.0], dtype=torch.float64, device=dev) + s = local.clone() + dist.all_reduce(s, op=dist.ReduceOp.SUM) + mx = local.clone() + dist.all_reduce(mx, op=dist.ReduceOp.MAX) + mn = local.clone() + dist.all_reduce(mn, op=dist.ReduceOp.MIN) + avg = s / world_size + + keys = ["dispatch_gpu", "combine_gpu", "replay_cuda_e2e", "replay_cpu_e2e"] + return {k: {"avg": avg[i].item(), "min": mn[i].item(), "max": mx[i].item()} for i, k in enumerate(keys)} + + +def _print_cudagraph_aggregated(stats: dict, op_tag: str, world_size: int, meta: dict, active_iters: int = None): + """rank 0 打印 cudagraph profiler 全卡聚合统计。""" + n = active_iters if active_iters is not None else meta["iters"] + sep = "=" * 72 + print(f"\n{sep}") + print( + f" {op_tag.upper()} [CUDAGraph+Profiler] EP={world_size} bs={meta['max_tokens']} " + f"h={meta['hidden_dim']} k={meta['k']} ({n} iters)" + ) + print(f" 所有 {world_size} 张卡的 avg / min / max(μs/call)") + print(sep) + hdr = f" {'指标':<36} {'avg':>8} {'min':>8} {'max':>8}" + print(hdr) + print(f" {'-'*60}") + + rows = [ + ("[Device] dispatch kernel GPU time", "dispatch_gpu"), + ("[Device] combine kernel GPU time", "combine_gpu"), + ("[E2E] replay CUDA time (含sync)", "replay_cuda_e2e"), + ("[Host] replay CPU time", "replay_cpu_e2e"), + ] + for label, key in rows: + v = stats[key] + print(f" {label:<36} {v['avg']:>8.1f} {v['min']:>8.1f} {v['max']:>8.1f}") + print() + + +def _make_profiler(active_iters: int = None, prof_warmup: int = 10): + """创建 profiler。 + + 使用 schedule 让前 (1 + prof_warmup) 步不做/轻量追踪, + 减少 ROCTracer 在多 GPU P2P shmem 场景下的累积压力。 + """ + kwargs = dict( + 
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=False, + with_stack=False, + ) + if active_iters is not None and active_iters > 0: + kwargs["schedule"] = torch.profiler.schedule( + wait=1, + warmup=prof_warmup, + active=active_iters, + repeat=1, + ) + return profile(**kwargs) + + +# ─── bench 模式:不用 profiler,用 CUDA Event 计时 ──────────────────────────── +def bench_op( + op, + op_tag: str, + inp, + wts, + idx, + wc_buf, + k, + rank: int, + world_size: int, + dev: torch.device, + warmup: int, + iters: int, + meta: dict, + scales=None, + packed_recv_x=None, +): + """无 profiler 的纯计时模式,输出 dispatch / combine 的 GPU 耗时(avg/min/max)。""" + _dkw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} + _ckw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} + ms.shmem_barrier_all() + if rank == 0: + print(f"\n[bench] {op_tag} 预热 {warmup} 轮...") + for _ in range(warmup): + op.reset() + ret = op.dispatch(inp, wts, scales, idx, **_dkw) + op.combine(ret[0], None, ret[3], **_ckw) + torch.cuda.synchronize() + dist.barrier() + + if rank == 0: + print(f"[bench] {op_tag} 计时 {iters} 轮...") + + d_events = [(torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) for _ in range(iters)] + c_events = [(torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) for _ in range(iters)] + + for i in range(iters): + # op.reset() + dist.barrier() + + d_events[i][0].record() + ret = op.dispatch(inp, wts, scales, idx, **_dkw) + d_events[i][1].record() + + dist.barrier() + + c_events[i][0].record() + op.combine(ret[0], None, ret[3], **_ckw) + c_events[i][1].record() + + torch.cuda.synchronize() + d_list = [d_events[i][0].elapsed_time(d_events[i][1]) * 1000 for i in range(iters)] + c_list = [c_events[i][0].elapsed_time(c_events[i][1]) * 1000 for i in range(iters)] + + # 全卡聚合 avg / min / max + local = torch.tensor( + [ + sum(d_list) / len(d_list), + min(d_list), + max(d_list), + sum(c_list) / 
len(c_list), + min(c_list), + max(c_list), + ], + dtype=torch.float64, + device=dev, + ) + s = local.clone() + dist.all_reduce(s, op=dist.ReduceOp.SUM) + mx = local.clone() + dist.all_reduce(mx, op=dist.ReduceOp.MAX) + mn = local.clone() + dist.all_reduce(mn, op=dist.ReduceOp.MIN) + avg_d = (s[0] / world_size).item() + mn_d = mn[0].item() + mx_d = mx[2].item() + avg_c = (s[3] / world_size).item() + mn_c = mn[3].item() + mx_c = mx[5].item() + + if rank == 0: + sep = "=" * 68 + tag = ( + f"{op_tag.upper()} EP={meta['world_size']} bs={meta['max_tokens']} " + f"h={meta['hidden_dim']} k={meta['k']} ({iters} iters)" + ) + print(f"\n{sep}\n {tag}\n 所有 {world_size} 张卡的 avg / min / max(μs/call)\n{sep}") + print(f" {'指标':<36} {'avg':>8} {'min':>8} {'max':>8}") + print(f" {'-'*58}") + print(f" {'[E2E] dispatch CUDA time':<36} {avg_d:>8.1f} {mn_d:>8.1f} {mx_d:>8.1f}") + print(f" {'[E2E] combine CUDA time':<36} {avg_c:>8.1f} {mn_c:>8.1f} {mx_c:>8.1f}") + print() + + +# ─── cudagraph 模式:CUDA Graph capture + replay 计时 ───────────────────────── +def _cudagraph_capture_flydsl(op, inp, wts, idx, wc_buf, capture_stream, scales=None, packed_recv_x=None): + """FlyDSL:录制 dispatch+combine 到 CUDA Graph。 + + dispatch/combine 均返回全尺寸 tensor(无 .item()、无动态切片)。 + 需要先 eager 调用一次触发 flyc.compile() JIT 编译(编译过程使用 + default stream,不能在 capture 期间执行),之后 capture 中仅录制 + 已编译的 kernel launch。 + """ + _dkw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} + _ckw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} + op.reset() + ret = op.dispatch(inp, wts, scales, idx, **_dkw) + op.combine(ret[0], None, ret[3], **_ckw) + + op.barrier() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, stream=capture_stream): + ret = op.dispatch(inp, wts, scales, idx, **_dkw) + op.combine(ret[0], None, ret[3], **_ckw) + return g, capture_stream + + +def _cudagraph_capture_mori(op, inp, wts, idx, wc_buf, capture_stream, scales=None, packed_recv_x=None): + """Mori 专用:直接在 graph 
capture 中录制 dispatch+combine。 + + Mori 的 dispatch 在 capture 模式下返回真实 tensor,combine kernel + 从 HBM 读取 totalRecvTokenNum,无需 pre-capture eager call。 + 参考 mori/tests/python/ops/bench_dispatch_combine.py stress_graph 写法。 + """ + ms.shmem_barrier_all() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, stream=capture_stream): + ret = op.dispatch(inp, wts, None, idx) + op.combine(ret[0], None, ret[3]) + return g, capture_stream + + +def cudagraph_op( + op, + op_tag: str, + inp, + wts, + idx, + wc_buf, + k, + rank: int, + world_size: int, + dev: torch.device, + warmup: int, + iters: int, + meta: dict, + scales=None, + packed_recv_x=None, +): + """CUDA Graph 模式:capture dispatch+combine kernel,replay 计时。""" + capture_stream = torch.cuda.Stream() + if op_tag == "flydsl": + g, cs = _cudagraph_capture_flydsl( + op, inp, wts, idx, wc_buf, capture_stream, scales=scales, packed_recv_x=packed_recv_x + ) + else: + g, cs = _cudagraph_capture_mori( + op, inp, wts, idx, wc_buf, capture_stream, scales=scales, packed_recv_x=packed_recv_x + ) + + if rank == 0: + print(f"\n[cudagraph] {op_tag} capture done") + + # replay warmup(HIP graph 冷启动 + GPU 缓存预热) + replay_warmup = 10 + if rank == 0: + print(f"[cudagraph] replay warmup {replay_warmup} 轮 + 计时 {iters} 轮(no-reset)...") + for _ in range(replay_warmup): + g.replay() + torch.cuda.synchronize() + + # 计时:预分配 event pairs,循环结束后统一 sync + events = [(torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) for _ in range(iters)] + + for i in range(iters): + events[i][0].record() + g.replay() + events[i][1].record() + + torch.cuda.synchronize() + gpu_times = [events[i][0].elapsed_time(events[i][1]) * 1000 for i in range(iters)] + + # per-replay 诊断 + per_replay_t = torch.tensor(gpu_times, dtype=torch.float64, device=dev) + all_per_replay = [torch.zeros_like(per_replay_t) for _ in range(world_size)] + dist.all_gather(all_per_replay, per_replay_t) + + local = torch.tensor( + [ + sum(gpu_times) / len(gpu_times), + 
min(gpu_times), + max(gpu_times), + ], + dtype=torch.float64, + device=dev, + ) + s = local.clone() + dist.all_reduce(s, op=dist.ReduceOp.SUM) + mx = local.clone() + dist.all_reduce(mx, op=dist.ReduceOp.MAX) + mn = local.clone() + dist.all_reduce(mn, op=dist.ReduceOp.MIN) + avg_g = (s[0] / world_size).item() + mn_g = mn[0].item() + mx_g = mx[2].item() + + if rank == 0: + sep = "=" * 68 + tag = ( + f"{op_tag.upper()} [CUDAGraph] EP={meta['world_size']} " + f"bs={meta['max_tokens']} h={meta['hidden_dim']} k={meta['k']} " + f"({iters} replays)" + ) + print(f"\n{sep}\n {tag}\n 所有 {world_size} 张卡的 avg / min / max(μs/call)\n{sep}") + print(f" {'指标':<36} {'avg':>8} {'min':>8} {'max':>8}") + print(f" {'-'*58}") + print(f" {'[GPU] dispatch+combine (event)':<36} {avg_g:>8.1f} {mn_g:>8.1f} {mx_g:>8.1f}") + + print(f"\n Per-replay GPU time (μs) — all {world_size} ranks:") + hdr = f" {'replay':>6}" + "".join(f" {'R'+str(r):>8}" for r in range(world_size)) + f" {'max':>8}" + print(hdr) + mat = torch.stack(all_per_replay) + for i in range(iters): + vals = [mat[r, i].item() for r in range(world_size)] + mx_i = max(vals) + row = f" {i:>6}" + "".join(f" {v:>8.1f}" for v in vals) + f" {mx_i:>8.1f}" + if mx_i > avg_g * 3: + row += " ← SPIKE" + print(row) + print() + + +# ─── 单算子 profiler 采集 ────────────────────────────────────────────────────── +def profile_op( + op, + op_tag: str, + inp, + wts, + idx, + wc_buf, + k, + rank: int, + world_size: int, + dev: torch.device, + iters: int, + out_dir: str, + meta: dict, + scales=None, + packed_recv_x=None, + dtype_key: str = "bf16", + quant_type: str = "none", + use_p2p_read: bool = False, +): + """对单个算子(FlyDSL 或 mori)独立 profiling,保存 JSON 并打印全卡聚合统计。 + + 使用 schedule(wait=1, warmup=10, active=iters) 让 ROCTracer 在前 11 步 + 不做/轻量追踪,减少与多 GPU P2P shmem 操作的冲突。 + """ + ms.shmem_barrier_all() + prof_warmup = 10 + total_steps = iters + 1 + prof_warmup # wait=1 + warmup=prof_warmup + active=iters + if rank == 0: + print(f"\n[profiler] {op_tag} 
开始采集({iters} 轮 active + {1 + prof_warmup} 轮 ramp-up)...") + + _dkw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} + _ckw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} + with _make_profiler(active_iters=iters, prof_warmup=prof_warmup) as prof: + for step in range(total_steps): + # with record_function(f"{op_tag}::reset"): + # op.reset() + dist.barrier() + + with record_function(f"{op_tag}::dispatch"): + ret = op.dispatch(inp, wts, scales, idx, **_dkw) + + dist.barrier() + + with record_function(f"{op_tag}::combine"): + op.combine(ret[0], None, ret[3], **_ckw) + + # dist.barrier() + + prof.step() + + # 保存 JSON:每张卡各自保存,文件名含 op_tag 和 rank + out_path = os.path.join(out_dir, f"{op_tag}_rank{rank}.json") + _save_profile_json(prof, out_path, rank, op_tag, meta) + if rank == 0: + print(f"[profiler] {op_tag} trace → {out_path}") + + # 跨卡聚合统计(all_reduce),rank 0 打印 + agg_stats = _allreduce_stats( + prof, op_tag, rank, world_size, dev, dtype_key=dtype_key, quant_type=quant_type, use_p2p_read=use_p2p_read + ) + if rank == 0: + _print_aggregated(agg_stats, op_tag, world_size, meta) + return prof + + +# ─── profile + cudagraph 模式 ───────────────────────────────────────────────── +def profile_cudagraph_op( + op, + op_tag: str, + inp, + wts, + idx, + wc_buf, + k, + rank: int, + world_size: int, + dev: torch.device, + warmup: int, + iters: int, + out_dir: str, + meta: dict, + scales=None, + packed_recv_x=None, + dtype_key: str = "bf16", + quant_type: str = "none", + use_p2p_read: bool = False, +): + """torch.profiler 采集 CUDAGraph replay,保存 JSON 并打印全卡聚合统计。 + + 流程:eager warmup → graph capture → replay warmup → profiler 包裹的 replay。 + """ + ms.shmem_barrier_all() + + capture_stream = torch.cuda.Stream() + if op_tag == "flydsl": + g, cs = _cudagraph_capture_flydsl( + op, inp, wts, idx, wc_buf, capture_stream, scales=scales, packed_recv_x=packed_recv_x + ) + else: + g, cs = _cudagraph_capture_mori( + op, inp, wts, idx, wc_buf, 
capture_stream, scales=scales, packed_recv_x=packed_recv_x + ) + + if rank == 0: + print(f"\n[profile+cudagraph] {op_tag} capture done") + + # replay warmup(HIP graph 冷启动 + GPU 缓存预热) + replay_warmup = 10 + for _ in range(replay_warmup): + g.replay() + torch.cuda.synchronize() + + prof_warmup = 5 + active_iters = iters + skip_first = 5 + valid_iters = max(active_iters - skip_first, 1) + total_steps = 1 + prof_warmup + active_iters # wait=1 + warmup + active + if rank == 0: + print( + f"[profile+cudagraph] {op_tag} scheduled profiler: " + f"warmup={prof_warmup}, active={active_iters}, " + f"丢弃前 {skip_first} 次,有效 {valid_iters} 次(no-reset)..." + ) + + with _make_profiler(active_iters=active_iters, prof_warmup=prof_warmup) as prof: + for step in range(total_steps): + with record_function(f"{op_tag}::cudagraph_replay"): + g.replay() + prof.step() + + out_path = os.path.join(out_dir, f"{op_tag}_cudagraph_rank{rank}.json") + _save_profile_json(prof, out_path, rank, op_tag, meta) + trace_path = out_path.replace(".json", "_trace.json") + if rank == 0: + print(f"[profile+cudagraph] {op_tag} trace → {trace_path}") + + agg_stats = _cudagraph_stats_from_trace( + trace_path, + op_tag, + rank, + world_size, + dev, + active_iters=active_iters, + skip_first=skip_first, + dtype_key=dtype_key, + quant_type=quant_type, + use_p2p_read=use_p2p_read, + ) + if rank == 0: + _print_cudagraph_aggregated(agg_stats, op_tag, world_size, meta, active_iters=valid_iters) + return prof + + +# ─── verify 模式:正确性验证 ───────────────────────────────────────────────── +VERIFY_TOL = { + "f32": {"atol": 1e-5, "rtol": 1e-4}, + "bf16": {"atol": 1e-2, "rtol": 1e-2}, + "fp8_ocp": {"atol": 1e-1, "rtol": 5e-2}, + "fp8_fnuz": {"atol": 1e-1, "rtol": 5e-2}, + "fp4": {"atol": 5e-1, "rtol": 1e-1}, +} + + +def _check_close(name, a, b, atol, rtol, rank, cast_to=None): + """Compare two tensors and print PASS/FAIL.""" + if cast_to is not None: + a, b = a.to(cast_to), b.to(cast_to) + ok = torch.allclose(a, b, atol=atol, 
rtol=rtol) + max_diff = (a.float() - b.float()).abs().max().item() + status = "PASS" if ok else "FAIL" + if rank == 0: + print(f" [{status}] {name:40s} max_diff={max_diff:.6g} atol={atol} rtol={rtol}") + return ok + + +def _check_exact(name, a, b, rank): + """Compare two tensors for exact equality.""" + ok = torch.equal(a, b) + if not ok: + diff_count = (a != b).sum().item() + status = "FAIL" + else: + diff_count = 0 + status = "PASS" + if rank == 0: + print(f" [{status}] {name:40s} diff_elements={diff_count}") + return ok + + +def verify_self(op_fly, inp, wts, idx, k, rank, world_size, dev, dtype_key, cfg): + """FlyDSL self-check when mori is unavailable. + + dispatch → combine → verify output ≈ weighted sum of input. + With uniform weights (1/k) and k distinct PEs, combine output should ≈ input. + """ + tol = VERIFY_TOL.get(dtype_key, VERIFY_TOL["bf16"]) + if cfg.quant_type == "fp8_direct_cast": + tol = {"atol": 2.0 * k, "rtol": 0.5} + all_pass = True + + if rank == 0: + print(f"\n{'='*65}") + print( + f" VERIFY (self-check, mori unavailable) dtype={dtype_key} " + f"EP={world_size} bs={inp.shape[0]} h={cfg.hidden_dim} k={k}" + ) + print(f"{'='*65}") + + op_fly.reset() + ms.shmem_barrier_all() + + packed_recv_x = None + if cfg.enable_std_moe: + epr = cfg.num_experts_per_rank + mr = cfg.max_recv + _prx_nbytes = epr * mr * cfg.token_bytes + packed_recv_x = ( + torch.zeros(_prx_nbytes, dtype=torch.uint8, device=dev) + .view(cfg.data_type) + .view(epr * mr, cfg.token_view_dim) + ) + + scales = None + if cfg.scale_dim > 0 and cfg.scale_type_size > 0: + _sc_bytes = cfg.scale_dim * cfg.scale_type_size + scales = torch.randn(inp.shape[0], _sc_bytes // 4, dtype=torch.float32, device=dev).contiguous() + scales = scales.view(torch.uint8).view(inp.shape[0], _sc_bytes) + + ret_f = op_fly.dispatch(inp, wts, scales, idx, packed_recv_x=packed_recv_x) + torch.cuda.synchronize() + dist.barrier() + + if rank == 0: + tr = ret_f[4].item() + print(f"\n total_recv = {tr}") + + cout_f = 
op_fly.combine(ret_f[0], None, ret_f[3], packed_recv_x=packed_recv_x) + torch.cuda.synchronize() + dist.barrier() + + mt = cfg.max_num_inp_token_per_rank + f_tok = cout_f[0][:mt] + + if cfg.enable_std_moe: + scale_factor = 1 + check_label = "out_tok vs inp (StdMoE weighted)" + else: + scale_factor = k + check_label = "out_tok vs k*inp" + + if rank == 0: + print(f"\n ── Self-check: combine output vs {'inp' if scale_factor == 1 else 'k*input'} ──") + if cfg.data_type == torch.float4_e2m1fn_x2: + if k == 1 and not cfg.enable_std_moe: + ok = torch.equal(f_tok.view(torch.uint8), inp.view(torch.uint8)) + status = "PASS" if ok else "FAIL" + print(f" [{status}] out_tok vs inp (byte-level, k=1)") + all_pass &= ok + else: + print(f" [SKIP] fp4 numeric check not supported " f"(k={k}, std_moe={cfg.enable_std_moe})") + else: + cast_to = torch.float32 if cfg.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz) else None + try: + expected = (inp.float() * scale_factor).to(cfg.data_type) + all_pass &= _check_close(check_label, f_tok, expected, tol["atol"], tol["rtol"], rank, cast_to=cast_to) + except Exception as e: + has_nan = torch.isnan(f_tok.float()).any().item() + has_inf = torch.isinf(f_tok.float()).any().item() + print(f" [INFO] Self-check exception (NaN={has_nan}, Inf={has_inf}): {e}") + all_pass &= not has_nan and not has_inf + + if rank == 0: + result = "ALL PASS" if all_pass else "SOME FAILED" + print(f"\n >>> {result} <<<\n") + return all_pass + + +def verify_op(op_fly, op_mori, inp, wts, idx, k, rank, world_size, dev, dtype_key, cfg, args): + """Run FlyDSL and mori dispatch+combine, compare outputs. + + Dispatch output ordering is non-deterministic (atomic fetch-and-add), so + we only compare total_recv. Combine output is the final accumulated result + and should be semantically identical. 
+ """ + tol = VERIFY_TOL.get(dtype_key, VERIFY_TOL["bf16"]) + all_pass = True + + if rank == 0: + print(f"\n{'='*65}") + print(f" VERIFY dtype={dtype_key} EP={world_size} bs={inp.shape[0]} " f"h={cfg.hidden_dim} k={k}") + print(f"{'='*65}") + + # ── Dispatch ── + op_fly.reset() + op_mori.reset() + ms.shmem_barrier_all() + + ret_f = op_fly.dispatch(inp, wts, None, idx) + ret_m = op_mori.dispatch(inp, wts, None, idx) + torch.cuda.synchronize() + + tr_f = ret_f[4].clone() + tr_m = ret_m[4].clone() + dist.barrier() + if rank == 0: + print("\n ── Dispatch 对比(仅 total_recv,token 排列因原子序不同) ──") + + all_pass &= _check_exact("total_recv", tr_f, tr_m, rank) + + # ── Combine ── + ms.shmem_barrier_all() + cout_f = op_fly.combine(ret_f[0], None, ret_f[3]) + cout_m = op_mori.combine(ret_m[0], None, ret_m[3]) + torch.cuda.synchronize() + dist.barrier() + + if rank == 0: + print("\n ── Combine 输出对比 ──") + + mt = cfg.max_num_inp_token_per_rank + cast_to = ( + torch.float32 if cfg.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float4_e2m1fn_x2) else None + ) + + f_tok = cout_f[0][:mt] if cout_f[0] is not None else None + m_tok = cout_m[0][:mt] if cout_m[0] is not None else None + + # Diagnostic: compare both outputs with expected k*input (skip for packed types) + if rank == 0 and f_tok is not None and m_tok is not None: + try: + expected = (inp.float() * k).to(cfg.data_type) + f_vs_exp = (f_tok.float() - expected.float()).abs().max().item() + m_vs_exp = (m_tok.float() - expected.float()).abs().max().item() + f_vs_m = (f_tok.float() - m_tok.float()).abs().max().item() + print(f" [DIAG] fly vs k*inp: {f_vs_exp:.4f} mori vs k*inp: {m_vs_exp:.4f} fly vs mori: {f_vs_m:.4f}") + except Exception: + pass + + if f_tok is not None and m_tok is not None: + all_pass &= _check_close("out_tok[:mt]", f_tok, m_tok, tol["atol"], tol["rtol"], rank, cast_to=cast_to) + elif rank == 0: + print(f" [SKIP] out_tok: fly={f_tok is not None}, mori={m_tok is not None}") + + f_wts = cout_f[1] if 
(len(cout_f) > 1 and cout_f[1] is not None) else None + m_wts = cout_m[1] if (len(cout_m) > 1 and cout_m[1] is not None) else None + if f_wts is not None and m_wts is not None: + all_pass &= _check_close("out_wts[:mt]", f_wts[:mt], m_wts[:mt], 1e-4, 1e-3, rank) + elif rank == 0: + print(f" [SKIP] out_wts: fly={f_wts is not None}, mori={m_wts is not None}") + + if rank == 0: + result = "ALL PASS" if all_pass else "SOME FAILED" + print(f"\n >>> {result} <<<\n") + return all_pass + + +# ─── 主逻辑 ─────────────────────────────────────────────────────────────────── +def run_profiler(rank, world_size, args): + dev = torch.device("cuda", rank) + k = args.k + cur_tok = args.max_tokens + n_exp = world_size * args.num_experts_per_rank + + _dtype = DTYPE_MAP.get(args.dtype, torch.bfloat16) + cfg = FlyDSLDispatchCombineConfig( + rank=rank, + world_size=world_size, + hidden_dim=args.hidden_dim, + max_num_inp_token_per_rank=cur_tok, + num_experts_per_rank=args.num_experts_per_rank, + num_experts_per_token=k, + data_type=_dtype, + warp_num_per_block=args.warp_per_block, + block_num=args.block_num, + chip=args.chip, + use_external_inp_buf=args.use_external_inp_buf, + enable_std_moe=args.enable_std_moe, + scale_dim=args.scale_dim, + scale_type_size=args.scale_type_size, + quant_type=args.quant_type, + ) + + mori_bn = args.mori_block_num if args.mori_block_num > 0 else cfg.block_num + mori_wpb = args.mori_warp_per_block if args.mori_warp_per_block > 0 else cfg.warp_num_per_block + meta = dict( + world_size=world_size, + max_tokens=cur_tok, + hidden_dim=cfg.hidden_dim, + k=k, + num_experts_per_rank=args.num_experts_per_rank, + warmup=args.warmup, + iters=args.iters, + flydsl_block_num=cfg.block_num, + flydsl_warp_per_block=cfg.warp_num_per_block, + mori_block_num=mori_bn, + mori_warp_per_block=mori_wpb, + use_external_inp_buf=cfg.use_external_inp_buf, + enable_std_moe=cfg.enable_std_moe, + scale_dim=cfg.scale_dim, + scale_type_size=cfg.scale_type_size, + quant_type=cfg.quant_type, + ) 
+ + # 输出目录:/tmp/ep{ws}_bs{cur_tok}/ + out_dir = os.path.join(args.output_dir, f"ep{world_size}_bs{cur_tok}") + os.makedirs(out_dir, exist_ok=True) + + # ── 构建算子 ─────────────────────────────────────────────────────────────── + if rank == 0: + print(f"\n{'='*65}") + print(f"[profiler] EP={world_size}, bs={cur_tok}, h={cfg.hidden_dim}, k={k}") + print(f"{'='*65}") + print("[profiler] 构建 FlyDSL...") + op_fly = FlyDSLDispatchCombineIntraNodeOp(cfg) + + op_ref = None + if args.compare and not cfg.enable_std_moe: + mori_bn = args.mori_block_num if args.mori_block_num > 0 else None + mori_wpb = args.mori_warp_per_block if args.mori_warp_per_block > 0 else None + bn_str = mori_bn if mori_bn else cfg.block_num + wpb_str = mori_wpb if mori_wpb else cfg.warp_num_per_block + if rank == 0: + print(f"[profiler] 构建 mori ref (block_num={bn_str}, warp_per_block={wpb_str})...") + try: + op_ref = build_mori_ref(rank, world_size, cfg, block_num=mori_bn, warp_per_block=mori_wpb) + except Exception as e: + if rank == 0: + print(f"[warn] mori ref 不可用: {e}") + elif cfg.enable_std_moe and rank == 0: + print("[info] StdMoE 模式:跳过 mori ref,使用自洽验证") + ms.shmem_barrier_all() + + # ── 准备输入(固定 seed,FlyDSL 和 mori 使用完全相同的输入)──────────────── + torch.manual_seed(42 + rank) + if cfg.data_type == torch.float4_e2m1fn_x2: + inp = torch.randint(0, 256, (cur_tok, cfg.hidden_dim // 2), dtype=torch.uint8, device=dev).view( + torch.float4_e2m1fn_x2 + ) + elif cfg.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): + inp = torch.randn(cur_tok, cfg.hidden_dim, dtype=torch.bfloat16, device=dev).to(cfg.data_type) + else: + inp = torch.randn(cur_tok, cfg.hidden_dim, dtype=cfg.data_type, device=dev) + wts = torch.rand(cur_tok, k, dtype=torch.float32, device=dev) + wts = wts / wts.sum(-1, keepdim=True) + epr = args.num_experts_per_rank + idx = torch.zeros(cur_tok, k, dtype=torch.int32, device=dev) + if args.mode == "verify" and k <= world_size: + # Ensure each token's k experts go to k DISTINCT PEs. 
+ # FlyDSL dispatch deduplicates same-PE assignments, mori does not. + for t in range(cur_tok): + pes = torch.randperm(world_size, device=dev)[:k] + for j in range(k): + idx[t, j] = pes[j] * epr + torch.randint(0, epr, (1,), device=dev) + else: + for t in range(cur_tok): + idx[t] = torch.randperm(n_exp, device=dev)[:k] + + # 预分配 combine 权重 buffer(FlyDSL 和 mori 共用,避免计时窗口内额外 GPU 核) + max_recv = world_size * cur_tok + wc_buf = torch.full((max_recv, k), 1.0 / k, dtype=torch.float32, device=dev) + + # ── 构造 scales / packed_recv_x(所有模式共用)───────────────────────── + packed_recv_x = None + if cfg.enable_std_moe: + _prx_nbytes = cfg.num_experts_per_rank * cfg.max_recv * cfg.token_bytes + packed_recv_x = ( + torch.zeros(_prx_nbytes, dtype=torch.uint8, device=dev) + .view(cfg.data_type) + .view(cfg.num_experts_per_rank * cfg.max_recv, cfg.token_view_dim) + ) + + scales = None + if cfg.scale_dim > 0 and cfg.scale_type_size > 0: + _sc_bytes = cfg.scale_dim * cfg.scale_type_size + scales = torch.randn(cur_tok, _sc_bytes // 4, dtype=torch.float32, device=dev).contiguous() + scales = scales.view(torch.uint8).view(cur_tok, _sc_bytes) + + # profile+eager 模式需要外部预热;其他 3 种组合由各自函数内部处理 + do_warmup = args.mode == "profile" and not args.cudagraph + + if do_warmup: + if rank == 0: + print(f"[setup] 预热 FlyDSL {args.warmup} 轮...") + for _ in range(args.warmup): + op_fly.reset() + ret = op_fly.dispatch(inp, wts, scales, idx, packed_recv_x=packed_recv_x) + op_fly.combine(ret[0], None, ret[3], packed_recv_x=packed_recv_x) + torch.cuda.synchronize() + + if op_ref is not None: + if rank == 0: + print(f"[setup] 预热 mori ref {args.warmup} 轮...") + for _ in range(args.warmup): + op_ref.reset() + ret_r = op_ref.dispatch(inp, wts, None, idx) + op_ref.combine(ret_r[0], None, ret_r[3]) + torch.cuda.synchronize() + + ms.shmem_barrier_all() + + # ── 根据 mode × cudagraph 分发执行 ───────────────────────────────────── + test_flydsl = args.bench_op in ("flydsl", "both") + test_mori = args.bench_op in ("mori", 
"both") and op_ref is not None + + if args.mode == "verify": + if op_ref is not None: + verify_op(op_fly, op_ref, inp, wts, idx, k, rank, world_size, dev, args.dtype, cfg, args) + else: + verify_self(op_fly, inp, wts, idx, k, rank, world_size, dev, args.dtype, cfg) + return + + if args.mode == "bench" and not args.cudagraph: + if test_flydsl: + bench_op( + op_fly, + "flydsl", + inp, + wts, + idx, + wc_buf, + k, + rank, + world_size, + dev, + args.warmup, + args.iters, + meta, + scales=scales, + packed_recv_x=packed_recv_x, + ) + if test_mori: + ms.shmem_barrier_all() + bench_op(op_ref, "mori", inp, wts, idx, wc_buf, k, rank, world_size, dev, args.warmup, args.iters, meta) + + elif args.mode == "bench" and args.cudagraph: + if test_flydsl: + cudagraph_op( + op_fly, + "flydsl", + inp, + wts, + idx, + wc_buf, + k, + rank, + world_size, + dev, + args.warmup, + args.iters, + meta, + scales=scales, + packed_recv_x=packed_recv_x, + ) + if test_mori: + ms.shmem_barrier_all() + cudagraph_op(op_ref, "mori", inp, wts, idx, wc_buf, k, rank, world_size, dev, args.warmup, args.iters, meta) + + elif args.mode == "profile" and not args.cudagraph: + _p2p = not args.use_external_inp_buf + if test_flydsl: + profile_op( + op_fly, + "flydsl", + inp, + wts, + idx, + wc_buf, + k, + rank, + world_size, + dev, + args.iters, + out_dir, + meta, + scales=scales, + packed_recv_x=packed_recv_x, + dtype_key=args.dtype, + quant_type=args.quant_type, + use_p2p_read=_p2p, + ) + if test_mori: + ms.shmem_barrier_all() + profile_op( + op_ref, + "mori", + inp, + wts, + idx, + wc_buf, + k, + rank, + world_size, + dev, + args.iters, + out_dir, + meta, + dtype_key=args.dtype, + quant_type=args.quant_type, + use_p2p_read=_p2p, + ) + if rank == 0: + print(f"\n[profiler] 全部结果已保存到: {out_dir}/") + + elif args.mode == "profile" and args.cudagraph: + _p2p = not args.use_external_inp_buf + if test_flydsl: + profile_cudagraph_op( + op_fly, + "flydsl", + inp, + wts, + idx, + wc_buf, + k, + rank, + world_size, + 
dev, + args.warmup, + args.iters, + out_dir, + meta, + scales=scales, + packed_recv_x=packed_recv_x, + dtype_key=args.dtype, + quant_type=args.quant_type, + use_p2p_read=_p2p, + ) + if test_mori: + ms.shmem_barrier_all() + profile_cudagraph_op( + op_ref, + "mori", + inp, + wts, + idx, + wc_buf, + k, + rank, + world_size, + dev, + args.warmup, + args.iters, + out_dir, + meta, + dtype_key=args.dtype, + quant_type=args.quant_type, + use_p2p_read=_p2p, + ) + if rank == 0: + print(f"\n[profiler] 全部结果已保存到: {out_dir}/") + + +# ─── Worker / 命令行入口 ────────────────────────────────────────────────────── +def _worker(rank, world_size, args, master_port): + setup_distributed(rank, world_size, master_port) + try: + run_profiler(rank, world_size, args) + except Exception as e: + import traceback as tb + + print(f"[rank {rank}] ERROR: {e}") + tb.print_exc() + finally: + cleanup() + + +def _parse_args(): + p = argparse.ArgumentParser(description="torch.profiler 分析 dispatch/combine") + p.add_argument("--world-size", type=int, default=8) + p.add_argument("--max-tokens", type=int, default=512) + p.add_argument("--hidden-dim", type=int, default=7168) + p.add_argument("--num-experts-per-rank", type=int, default=32) + p.add_argument("--k", type=int, default=8) + p.add_argument("--block-num", type=int, default=80) + p.add_argument("--warp-per-block", type=int, default=4) + p.add_argument( + "--mori-block-num", type=int, default=0, help="mori 专用 block_num(0=与FlyDSL相同,mori默认最优=80)" + ) + p.add_argument( + "--mori-warp-per-block", type=int, default=0, help="mori 专用 warp_per_block(0=与FlyDSL相同,mori默认最优=8)" + ) + p.add_argument("--chip", type=str, default="gfx950") + p.add_argument("--dtype", type=str, default="bf16", choices=list(DTYPE_MAP.keys()), help="数据类型(默认 bf16)") + p.add_argument("--warmup", type=int, default=5, help="预热轮次(不进 profiler,确保 JIT 编译完成)") + p.add_argument("--iters", type=int, default=5, help="profiler 采集轮次") + p.add_argument( + "--output-dir", + type=str, + 
default="dispatch_profile", + help="JSON 输出根目录(相对当前目录),子目录按 ep{ws}_bs{tok} 命名", + ) + p.add_argument("--port", type=int, default=29800) + p.add_argument("--no-compare", dest="compare", action="store_false") + # ── 模式选择 ────────────────────────────────────────────────────────────── + p.add_argument( + "--mode", + choices=["profile", "bench", "verify"], + default="profile", + help="测量方式:profile=torch.profiler 采集(默认); bench=CUDA Event 计时; verify=正确性验证", + ) + p.add_argument("--cudagraph", action="store_true", help="使用 CUDAGraph capture+replay 执行(默认 eager)") + p.add_argument("--bench-op", choices=["flydsl", "mori", "both"], default="both", help="测哪个算子(默认 both)") + # ── 功能开关 ────────────────────────────────────────────────────────────── + p.add_argument( + "--no-external-inp-buf", + dest="use_external_inp_buf", + action="store_false", + default=True, + help="使用 P2P Read combine 变体(默认使用 external inp buf)", + ) + p.add_argument("--enable-std-moe", action="store_true", default=False, help="启用 Standard MoE Adapt 模式") + p.add_argument("--scale-dim", type=int, default=0, help="Scale 张量维度(0=不使用 scale)") + p.add_argument("--scale-type-size", type=int, default=0, help="Scale 类型大小(字节,0=不使用 scale)") + p.add_argument( + "--quant-type", + type=str, + default="none", + choices=["none", "fp8_direct_cast"], + help="量化类型(none=默认,fp8_direct_cast=FP8直接转换combine)", + ) + p.set_defaults(compare=True) + return p.parse_args() + + +def main(): + args = _parse_args() + if "LOCAL_RANK" in os.environ: + rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ.get("WORLD_SIZE", args.world_size)) + _worker(rank, world_size, args, master_port=args.port) + else: + ws = min(args.world_size, torch.cuda.device_count()) + if ws < args.world_size: + print(f"[warn] 可用 GPU={torch.cuda.device_count()}, " f"world_size 调整: {args.world_size} → {ws}") + torch.multiprocessing.spawn( + _worker, + args=(ws, args, args.port), + nprocs=ws, + join=True, + ) + + +if __name__ == "__main__": + main() From 
81e859d75298ad5d7398eee9ec5747738a730a9e Mon Sep 17 00:00:00 2001 From: yanboshao Date: Mon, 18 May 2026 02:38:35 +0000 Subject: [PATCH 3/4] ci(dispatch_combine): restrict tests to 8-GPU runners The dispatch/combine intranode test depends on mori shmem, which is only installed on the 8-GPU multi-gpu CI runners. Previously pytest collection on single-GPU / Navi-2-GPU runners would crash because ``import mori`` raises ModuleNotFoundError at module load time. * tests/kernels/test_profiler_dispatch_combine.py: when imported under pytest collection (detected via ``"pytest" in sys.modules``), call ``pytest.importorskip("mori")`` so single/dual-GPU jobs cleanly skip this file. Direct ``torchrun``/``python`` invocations are unaffected and still surface a normal ImportError when mori is genuinely missing. * .github/workflows/flydsl.yaml: add two explicit multi-GPU steps to the multi-gpu job that run the dispatch/combine verify torchrun script (default config + --enable-std-moe). These only execute when the PR carries the ``multi-gpu`` label, providing real 8-GPU coverage for the new kernel. * kernels/dispatch_combine_intranode_op.py: drop unused local ``_disp_wpb`` alias, use ``config.warp_num_per_block`` directly. 
--- .github/workflows/flydsl.yaml | 18 ++ kernels/dispatch_combine_intranode_kernel.py | 1 - kernels/dispatch_combine_intranode_op.py | 112 ++++--- .../kernels/test_profiler_dispatch_combine.py | 300 ++++++++++-------- 4 files changed, 254 insertions(+), 177 deletions(-) diff --git a/.github/workflows/flydsl.yaml b/.github/workflows/flydsl.yaml index 8e4df09cf..19c699149 100644 --- a/.github/workflows/flydsl.yaml +++ b/.github/workflows/flydsl.yaml @@ -424,6 +424,24 @@ jobs: -m multi_gpu -v --no-header --tb=short " + - name: Run multi-GPU dispatch/combine verify (8-GPU only) + timeout-minutes: 10 + run: | + docker exec flydsl_test bash -c " + cd /flydsl-test && + torchrun --nproc-per-node=8 --master-port=29503 \ + tests/kernels/test_profiler_dispatch_combine.py --mode verify + " + + - name: Run multi-GPU dispatch/combine verify --enable-std-moe (8-GPU only) + timeout-minutes: 10 + run: | + docker exec flydsl_test bash -c " + cd /flydsl-test && + torchrun --nproc-per-node=8 --master-port=29504 \ + tests/kernels/test_profiler_dispatch_combine.py --mode verify --enable-std-moe + " + - name: Run multi-GPU allreduce tests timeout-minutes: 30 run: | diff --git a/kernels/dispatch_combine_intranode_kernel.py b/kernels/dispatch_combine_intranode_kernel.py index 087aa6f8f..4ae259519 100644 --- a/kernels/dispatch_combine_intranode_kernel.py +++ b/kernels/dispatch_combine_intranode_kernel.py @@ -283,7 +283,6 @@ def ep_dispatch_intranode( if lane == 0: buffer_store(tok_map_entry, _r_tok_map, work_idx) - if lane == 0: if dup_ballot == 0: # Publish the (src_pe, src_lid) origin so the dest PE # can later route the token back during combine. 
diff --git a/kernels/dispatch_combine_intranode_op.py b/kernels/dispatch_combine_intranode_op.py index 1220f2a9c..021d2a59b 100644 --- a/kernels/dispatch_combine_intranode_op.py +++ b/kernels/dispatch_combine_intranode_op.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors -"""FlyDSL DispatchCombine IntraNode 算子包装器。""" +"""Python wrapper for the FlyDSL intra-node DispatchCombine op.""" from __future__ import annotations @@ -106,7 +106,6 @@ def __init__(self, config): self._p2p_comb_inp_wts[pe] = ms.shmem_ptr_p2p(self.shmem_comb_inp_wts.data_ptr(), r, pe) self._p2p_xdb_mem[pe] = ms.shmem_ptr_p2p(self.shmem_xdev_bar_mem.data_ptr(), r, pe) - _disp_wpb = config.warp_num_per_block self._disp_fn = make_dispatch_jit( rank=r, npes=config.world_size, @@ -115,7 +114,7 @@ def __init__(self, config): hidden_dim=config.hidden_dim, max_tok_per_rank=config.max_num_inp_token_per_rank, block_num=config.block_num, - warp_num_per_block=_disp_wpb, + warp_num_per_block=config.warp_num_per_block, data_type=config.data_type, scale_dim=config.scale_dim, scale_type_size=config.scale_type_size, @@ -139,7 +138,7 @@ def __init__(self, config): hidden_dim=config.hidden_dim, max_tok_per_rank=config.max_num_inp_token_per_rank, block_num=config.block_num, - warp_num_per_block=_disp_wpb, + warp_num_per_block=config.warp_num_per_block, data_type=_comb_dtype, enable_weights=True, enable_std_moe=config.enable_std_moe, @@ -148,7 +147,9 @@ def __init__(self, config): ) self._use_fp8_cast = _use_fp8_cast - # barrier flag 初始值必须为 1, 否则首次 wait_until_equals(slot, 0) 立即满足 + # The cross-device barrier flag must start at 1; otherwise the very + # first wait_until_equals(slot, 0) would be satisfied immediately by + # the zero-initialized memory and skip the actual synchronization. 
self._xdev_flag = torch.ones(1, dtype=torch.int64, device=self._dev) self._fx_out_tok = fx.Int64(self.shmem_disp_out_tok.data_ptr()) @@ -161,11 +162,11 @@ def __init__(self, config): self._fx_tok_map = fx.Int64(self.dest_tok_map.data_ptr()) self._fx_tis = fx.Int64(self.shmem_tok_id_to_src.data_ptr()) self._fx_total_rv = fx.Int64(self.total_recv.data_ptr()) - # combine 固定地址 self._fx_comb_inp = fx.Int64(self.shmem_comb_inp_tok.data_ptr()) self._fx_comb_out = fx.Int64(self.shmem_comb_out_tok.data_ptr()) self._fx_xdb_mem = fx.Int64(self.shmem_xdev_bar_mem.data_ptr()) self._fx_xdev_flag = fx.Int64(self._xdev_flag.data_ptr()) + self._fx_comb_bar = fx.Int64(self.comb_bar.data_ptr()) self._fx_trecv = fx.Int64(self.total_recv.data_ptr()) self._fx_p2p_tok_off = fx.Int64(self._p2p_tok_off.data_ptr()) @@ -189,9 +190,11 @@ def __init__(self, config): self._disp_compiled = None self._comb_compiled = None - # combine kernel 的 skip_stage1 变体:给 fused_gemm2_combine 算子使用, - # 此时 fused kernel 已经把 token / 权重 P2P 写入 shmem_comb_inp[_wts], - # combine 只跑 Stage 2 (CrossDeviceBarrier) + Stage 3 (本地 weighted-accum)。 + # Lazy-compiled skip_stage1 combine variant used by the fused + # GEMM2-combine path: the upstream fused kernel has already P2P- + # scattered tokens / weights into shmem_comb_inp[_wts], so the + # combine kernel only runs Stage 2 (CrossDeviceBarrier) + Stage 3 + # (local weighted-accum). 
self._comb_no_s1_fn = None self._comb_no_s1_compiled = None @@ -221,17 +224,21 @@ def _alloc_buffers(self): self.shmem_comb_out_wts = mori_shmem_create_tensor((mt * k,), torch.float32) self.shmem_xdev_bar_mem = mori_shmem_create_tensor((npes,), torch.int64) - # mori_shmem_create_tensor 走 shmem_malloc,分配的是未初始化的 raw memory。 - # 对 fused MoE-GEMM2 + EP-Combine 路径,GEMM2 需要在 epilogue 用 - # shmem_tok_id_to_src 解码 dest_pe / dest_lid,越界 garbage 会触发 LDS - # OOB → 写到任意全局地址 → 破坏 control state。这里把所有 combine 路径 - # 直接读写的 symmetric buffer 显式清零,保证: - # - shmem_tok_id_to_src[t] 对未被 dispatch 写入的 t 解码为 (pe=0, lid=0), - # P2P scatter 退化成"安全无副作用"(多写同一槽位) - # - shmem_xdev_bar_mem 起始 0,CrossDeviceBarrier 第一次 wait 不会读到 - # 残留值(依赖 cur_flag 单调递增) - # - shmem_comb_inp_{tok,wts} 起始 0,combine_no_stage1 在 stage 3 累加 - # 时不会读到 garbage + # mori_shmem_create_tensor goes through shmem_malloc, which returns + # uninitialized raw memory. In the fused MoE-GEMM2 + EP-Combine + # path the GEMM2 epilogue decodes (dest_pe, dest_lid) from + # shmem_tok_id_to_src; out-of-bounds garbage there would trigger + # LDS OOB and corrupt arbitrary global state, so every symmetric + # buffer that combine touches directly is zeroed up-front to + # guarantee: + # - shmem_tok_id_to_src[t] for slots never written by dispatch + # decodes to (pe=0, lid=0), making the P2P scatter degenerate + # into a harmless duplicate write into a single slot; + # - shmem_xdev_bar_mem starts at 0 so the first + # CrossDeviceBarrier wait never observes stale data (the + # protocol relies on cur_flag monotonic increase); + # - shmem_comb_inp_{tok,wts} start at 0 so combine_no_stage1's + # Stage 3 accumulation never folds garbage into the result. 
self.shmem_tok_id_to_src.zero_() self.shmem_comb_inp_tok.zero_() self.shmem_comb_inp_wts.zero_() @@ -483,29 +490,35 @@ def combine( def combine_no_stage1( self, input, weights, indices, packed_recv_x=None, cur_tok=None, call_reset=False, enable_weights: bool = True ): - """combine 的 stage1-skipped 变体。 + """Skip-Stage1 variant of ``combine``. - 语义:跳过 P2P scatter(外部 fused kernel 已把数据写入 shmem_comb_inp[_wts]), - 只执行 Stage 2 (CrossDeviceBarrier) + Stage 3 (本地 weighted-accum)。 + Semantics: bypass the P2P scatter (the upstream fused kernel has + already populated shmem_comb_inp[_wts]) and only run Stage 2 + (CrossDeviceBarrier) + Stage 3 (local weighted-accum). Parameters ---------- enable_weights - ``True`` (默认) 兼容当前 fused-with-weight 链路:在 combine - kernel 内保留 Stage 1 的 weight scatter + Stage 3b 的 weight - accumulate。weight scatter 显式留在 combine kernel 内(而不是 - 放在上游 fused GEMM2 的 epilogue 里),因为 16B 小写若与上游 - token P2P 并发会被 ROCm IPC fabric 静默丢,必须放在静态 fabric - 上由 combine kernel 完成。 - ``False`` 给 weight-free fused 路径(fused MoE 上游已经把 - weight 处理掉了,combine 端不需要 out_wts):完全 DCE 掉 - weight scatter + Stage 3b,省 ~3-5 μs。 - 两种变体走不同的 JIT 缓存,互不污染。 - - 约定:调用前 fused kernel 必须保证: - - shmem_comb_inp_tok 已写入本 PE 应接收的所有 token(按 max_tok_per_rank 槽位) - - shmem_comb_inp_wts 已写入对应权重(仅 enable_weights=True 时需要) - - total_recv 已被 dispatch 设置完毕(Stage 3 用于读 cur_rank_num_token) + ``True`` (default) keeps the Stage 1 weight scatter and the + Stage 3b weight accumulate inside the combine kernel. The + weight scatter is intentionally kept here (instead of being + folded into the upstream fused GEMM2 epilogue) because the + 16B narrow stores are silently dropped by the ROCm IPC + fabric when they race with the upstream token P2P, so they + must be issued by the combine kernel on the static fabric. 
+ ``False`` is for weight-free fused paths (the upstream fused + MoE has already collapsed the weights, so the combine side + does not need ``out_wts``): both the weight scatter and + Stage 3b are completely DCE'd, saving ~3-5 us. The two + variants use distinct JIT caches. + + Contract: before invocation the fused kernel must guarantee: + - shmem_comb_inp_tok already contains every token this PE + will consume (laid out in max_tok_per_rank slots); + - shmem_comb_inp_wts already contains the matching weights + (only when enable_weights=True); + - total_recv has already been set by dispatch (Stage 3 + reads it as cur_rank_num_token). """ cfg = self.cfg stream = torch.cuda.current_stream() @@ -534,8 +547,7 @@ def combine_no_stage1( else: prx_ptr = packed_recv_x.data_ptr() if packed_recv_x is not None else 0 - # JIT 缓存按 enable_weights 区分(两份编译产物)。 - # 历史 self._comb_no_s1_fn / _compiled 升级为 dict[bool, fn]。 + # Two JIT'd variants are cached, keyed by enable_weights. if not isinstance(self._comb_no_s1_fn, dict): self._comb_no_s1_fn = {} self._comb_no_s1_compiled = {} @@ -545,18 +557,16 @@ def combine_no_stage1( _use_fp8_cast = self._use_fp8_cast _comb_dtype = torch.float8_e4m3fn if _use_fp8_cast else cfg.data_type - # Mixed-dtype contract for fp8_direct_cast: external dtype = bf16, - # transport dtype = fp8. Stage 3 _from_accum will cast f32 → bf16 - # inline so kernel writes bf16 directly to shmem_comb_out_tok and - # the wrapper does NOT need a post .to(bf16) cast. + # Same fp8_direct_cast mixed-dtype contract as in ``combine`` + # above: external dtype = bf16, transport dtype = fp8; Stage 3 + # _from_accum casts f32 -> bf16 inline so the kernel writes + # bf16 straight to shmem_comb_out_tok. 
_comb_inp_dt = torch.bfloat16 if _use_fp8_cast else None - # enable_weights=False 路径(fused MoE 不需要 out_wts): - # weight scatter + Stage 3b weight accumulate 都在 const_expr - # 处被 DCE 掉,省 ~3-5μs。 - # enable_weights=True 路径(兼容 fused-with-weight): - # combine kernel 在 skip_stage1=True 下默认仍跑 weight scatter, - # 因为同 fabric 上与 token P2P 并发的 16B 小写会被静默丢,必须 - # 放在静态 fabric 上由 combine kernel 完成。 + # See the ``enable_weights`` doc above for the rationale of + # the two variants: ``False`` lets const_expr DCE the weight + # scatter + Stage 3b (~3-5 us); ``True`` keeps the weight + # scatter inside combine to dodge the IPC-fabric race against + # the upstream token P2P. self._comb_no_s1_fn[enable_weights] = make_combine_jit( rank=cfg.rank, npes=cfg.world_size, diff --git a/tests/kernels/test_profiler_dispatch_combine.py b/tests/kernels/test_profiler_dispatch_combine.py index 16ba7a02f..df098f406 100644 --- a/tests/kernels/test_profiler_dispatch_combine.py +++ b/tests/kernels/test_profiler_dispatch_combine.py @@ -2,20 +2,25 @@ # Copyright (c) 2025 FlyDSL Project Contributors """ -FlyDSL 和 mori ref 的 dispatch/combine kernel 性能测试。 - -两个正交维度可自由组合: - --mode 测量方式:profile(torch.profiler 采集)| bench(CUDA Event 计时) - --cudagraph 执行方式:不带此标志 = eager 模式 | 带 = CUDAGraph capture+replay - -四种组合: - 1. profile + eager : torch.profiler 采集 eager 执行的 kernel + E2E + CPU 时间 - 2. bench + eager : CUDA Event 计时 eager dispatch/combine(无 profiler 开销) - 3. profile + cudagraph: torch.profiler 采集 CUDAGraph replay 中的 kernel 时间 - 4. bench + cudagraph: CUDA Event 计时 CUDAGraph replay(零 Python launch 开销) - -启动方式(支持 torchrun 或直接 python): - # profile + eager(默认) +Performance harness for FlyDSL and mori-ref dispatch/combine kernels. + +Two orthogonal axes can be freely combined: + --mode measurement: ``profile`` (torch.profiler) | ``bench`` + (CUDA event timing) | ``verify`` (correctness check) + --cudagraph execution: absent = eager mode | present = + CUDAGraph capture+replay + +Four combinations: + 1. 
profile + eager : torch.profiler over eager kernels + E2E + + CPU timing + 2. bench + eager : CUDA event timing of eager dispatch/combine + (no profiler overhead) + 3. profile + cudagraph: torch.profiler over CUDAGraph replay kernels + 4. bench + cudagraph: CUDA event timing of CUDAGraph replay + (zero Python launch overhead) + +Launching (works under torchrun or plain python): + # profile + eager (default) python tests/kernels/test_profiler_dispatch_combine.py --max-tokens 512 # bench + eager @@ -27,7 +32,7 @@ # profile + cudagraph python tests/kernels/test_profiler_dispatch_combine.py --mode profile --cudagraph - # 只测 FlyDSL + # FlyDSL only python tests/kernels/test_profiler_dispatch_combine.py --bench-op flydsl """ @@ -44,7 +49,7 @@ os.environ.setdefault("MORI_SHMEM_HEAP_SIZE", "16G") -# ── dtype 映射 ── +# dtype mapping DTYPE_MAP = { "bf16": torch.bfloat16, "f32": torch.float32, @@ -66,6 +71,20 @@ if _p not in sys.path: sys.path.insert(0, _p) +# Module-level skip when mori is unavailable AND we are being imported by +# pytest collection. This file is a torchrun standalone script, but pytest +# still picks it up because the name matches ``test_*.py`` -- single-GPU CI +# runners don't install mori, so unconditional ``import mori`` would crash +# pytest collection. We only trigger ``pytest.importorskip`` when pytest +# is the orchestrator (``"pytest" in sys.modules``), so direct +# ``python``/``torchrun`` invocations still surface a normal ImportError +# if mori is missing (instead of an opaque pytest Skipped exception). 
+if "pytest" in sys.modules: + sys.modules["pytest"].importorskip( + "mori", + reason="dispatch/combine intranode test requires mori shmem (8-GPU multi-gpu CI only)", + ) + import mori.shmem as ms # noqa: E402 from kernels.dispatch_combine_intranode_op import ( # noqa: E402 @@ -74,7 +93,7 @@ ) -# ─── 分布式初始化 ───────────────────────────────────────────────────────────── +# --- Distributed init --- def setup_distributed(rank, world_size, master_port=29600): if "LOCAL_RANK" not in os.environ: os.environ.update( @@ -145,9 +164,10 @@ def build_mori_ref(rank, world_size, cfg, block_num: int = None, warp_per_block: def _save_profile_json(prof, out_path: str, rank: int, op_tag: str, meta: dict): - """将 profiler 结果序列化为 JSON 文件。 + """Serialize profiler results to a JSON file. + + JSON layout:: - JSON 结构: { "meta": {op_tag, rank, max_tokens, hidden_dim, k, world_size, ...}, "kernel_stats": [ {name, calls, cuda_time_avg_us, cpu_time_avg_us}, ... ] @@ -165,7 +185,6 @@ def _save_profile_json(prof, out_path: str, rank: int, op_tag: str, meta: dict): "cpu_time_total_us": round(evt.cpu_time * evt.count, 2), } ) - # 按 GPU time 降序 rows.sort(key=lambda r: r["cuda_time_total_us"], reverse=True) payload = { @@ -190,15 +209,17 @@ def _allreduce_stats( quant_type: str = "none", use_p2p_read: bool = False, ) -> dict: - """从本卡 profiler 提取关键指标,跨卡 all_reduce 后返回 avg/min/max 字典。 - - 采集 6 项指标(顺序固定,打包成 float64 tensor 做 all_reduce): - 0: dispatch GPU kernel time (μs/call) - 1: combine GPU kernel time (μs/call) - 2: dispatch record_function CUDA time (μs/call) - 3: combine record_function CUDA time (μs/call) - 4: dispatch record_function CPU time (μs/call) - 5: combine record_function CPU time (μs/call) + """Pull per-rank profiler metrics, all-reduce them across ranks, and + return an avg/min/max dict. 
+ + Six metrics, packed into a float64 tensor in this fixed order for + the all-reduce: + 0: dispatch GPU kernel time (us/call) + 1: combine GPU kernel time (us/call) + 2: dispatch record_function CUDA time (us/call) + 3: combine record_function CUDA time (us/call) + 4: dispatch record_function CPU time (us/call) + 5: combine record_function CPU time (us/call) """ msuf = MORI_KERNEL_SUFFIX.get(dtype_key, "bf16") _cast_suf = "_fp8cast" if (quant_type == "fp8_direct_cast" and not use_p2p_read) else "" @@ -255,24 +276,24 @@ def cpu_us(key): def _print_aggregated(stats: dict, op_tag: str, world_size: int, meta: dict): - """rank 0 打印全卡聚合统计。""" + """Print the cross-rank aggregated stats on rank 0.""" sep = "=" * 72 print(f"\n{sep}") print( f" {op_tag.upper()} EP={world_size} bs={meta['max_tokens']} " f"h={meta['hidden_dim']} k={meta['k']} ({meta['iters']} iters)" ) - print(f" 所有 {world_size} 张卡的 avg / min / max(μs/call)") + print(f" avg / min / max across all {world_size} ranks (us/call)") print(sep) - hdr = f" {'指标':<36} {'avg':>8} {'min':>8} {'max':>8}" + hdr = f" {'metric':<36} {'avg':>8} {'min':>8} {'max':>8}" print(hdr) print(f" {'-'*60}") rows = [ ("[Device] dispatch kernel GPU time", "dispatch_gpu"), ("[Device] combine kernel GPU time", "combine_gpu"), - ("[E2E] dispatch CUDA time (含sync)", "dispatch_cuda_e2e"), - ("[E2E] combine CUDA time (含sync)", "combine_cuda_e2e"), + ("[E2E] dispatch CUDA time (w/sync)", "dispatch_cuda_e2e"), + ("[E2E] combine CUDA time (w/sync)", "combine_cuda_e2e"), ("[Host] dispatch CPU time", "dispatch_cpu_e2e"), ("[Host] combine CPU time", "combine_cpu_e2e"), ] @@ -292,9 +313,10 @@ def _allreduce_cudagraph_stats_from_key_averages( quant_type: str = "none", use_p2p_read: bool = False, ) -> dict: - """从 key_averages() 提取指标(仅含 active 阶段数据),跨卡 all_reduce。 + """Pull metrics from ``prof.key_averages()`` (active phase only) and + all-reduce them across ranks. 
- 采集 4 项: + Four metrics: 0: dispatch kernel GPU time 1: combine kernel GPU time 2: cudagraph_replay CUDA E2E time @@ -356,9 +378,12 @@ def _cudagraph_stats_from_trace( quant_type: str = "none", use_p2p_read: bool = False, ) -> dict: - """从 chrome trace JSON 手动统计 kernel 性能,跳过前 skip_first 次 active 调用。 + """Compute kernel stats by parsing the chrome trace JSON, dropping + the first ``skip_first`` active iterations. - 流程:解析 trace → 按时间排序取最后 active_iters 个事件 → 丢弃前 skip_first 个 → 跨卡聚合。 + Pipeline: parse trace -> sort by ts and keep the last + ``active_iters`` events -> drop the first ``skip_first`` -> + all-reduce across ranks. """ with open(trace_path) as f: tr = json.load(f) @@ -392,8 +417,8 @@ def _cudagraph_stats_from_trace( valid_n = len(d_valid) if rank == 0: print( - f"[trace-stats] {op_tag}: trace 中 dispatch={len(d_all)} combine={len(c_all)} 个事件," - f"取最后 {active_iters} 个 active,跳过前 {skip_first},有效 {valid_n} 个" + f"[trace-stats] {op_tag}: trace has dispatch={len(d_all)} combine={len(c_all)} events; " + f"keeping last {active_iters} active, skipping first {skip_first}, {valid_n} valid" ) d_avg = sum(d_valid) / valid_n if valid_n else 0.0 @@ -414,7 +439,7 @@ def _cudagraph_stats_from_trace( def _print_cudagraph_aggregated(stats: dict, op_tag: str, world_size: int, meta: dict, active_iters: int = None): - """rank 0 打印 cudagraph profiler 全卡聚合统计。""" + """Print the cudagraph+profiler aggregated stats on rank 0.""" n = active_iters if active_iters is not None else meta["iters"] sep = "=" * 72 print(f"\n{sep}") @@ -422,16 +447,16 @@ def _print_cudagraph_aggregated(stats: dict, op_tag: str, world_size: int, meta: f" {op_tag.upper()} [CUDAGraph+Profiler] EP={world_size} bs={meta['max_tokens']} " f"h={meta['hidden_dim']} k={meta['k']} ({n} iters)" ) - print(f" 所有 {world_size} 张卡的 avg / min / max(μs/call)") + print(f" avg / min / max across all {world_size} ranks (us/call)") print(sep) - hdr = f" {'指标':<36} {'avg':>8} {'min':>8} {'max':>8}" + hdr = f" {'metric':<36} 
{'avg':>8} {'min':>8} {'max':>8}" print(hdr) print(f" {'-'*60}") rows = [ ("[Device] dispatch kernel GPU time", "dispatch_gpu"), ("[Device] combine kernel GPU time", "combine_gpu"), - ("[E2E] replay CUDA time (含sync)", "replay_cuda_e2e"), + ("[E2E] replay CUDA time (w/sync)", "replay_cuda_e2e"), ("[Host] replay CPU time", "replay_cpu_e2e"), ] for label, key in rows: @@ -441,10 +466,11 @@ def _print_cudagraph_aggregated(stats: dict, op_tag: str, world_size: int, meta: def _make_profiler(active_iters: int = None, prof_warmup: int = 10): - """创建 profiler。 + """Build a torch.profiler. - 使用 schedule 让前 (1 + prof_warmup) 步不做/轻量追踪, - 减少 ROCTracer 在多 GPU P2P shmem 场景下的累积压力。 + The schedule keeps the first (1 + prof_warmup) steps in wait/warmup + so ROCTracer doesn't accumulate state under heavy multi-GPU P2P + shmem traffic. """ kwargs = dict( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], @@ -461,7 +487,7 @@ def _make_profiler(active_iters: int = None, prof_warmup: int = 10): return profile(**kwargs) -# ─── bench 模式:不用 profiler,用 CUDA Event 计时 ──────────────────────────── +# --- bench mode: profiler-free CUDA-event timing --- def bench_op( op, op_tag: str, @@ -479,12 +505,13 @@ def bench_op( scales=None, packed_recv_x=None, ): - """无 profiler 的纯计时模式,输出 dispatch / combine 的 GPU 耗时(avg/min/max)。""" + """Profiler-free CUDA-event timing of dispatch/combine; reports GPU + time avg/min/max.""" _dkw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} _ckw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} ms.shmem_barrier_all() if rank == 0: - print(f"\n[bench] {op_tag} 预热 {warmup} 轮...") + print(f"\n[bench] {op_tag} warmup {warmup} iters...") for _ in range(warmup): op.reset() ret = op.dispatch(inp, wts, scales, idx, **_dkw) @@ -493,7 +520,7 @@ def bench_op( dist.barrier() if rank == 0: - print(f"[bench] {op_tag} 计时 {iters} 轮...") + print(f"[bench] {op_tag} timing {iters} iters...") d_events = 
[(torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) for _ in range(iters)] c_events = [(torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) for _ in range(iters)] @@ -516,7 +543,7 @@ def bench_op( d_list = [d_events[i][0].elapsed_time(d_events[i][1]) * 1000 for i in range(iters)] c_list = [c_events[i][0].elapsed_time(c_events[i][1]) * 1000 for i in range(iters)] - # 全卡聚合 avg / min / max + # Aggregate avg / min / max across ranks. local = torch.tensor( [ sum(d_list) / len(d_list), @@ -548,22 +575,23 @@ def bench_op( f"{op_tag.upper()} EP={meta['world_size']} bs={meta['max_tokens']} " f"h={meta['hidden_dim']} k={meta['k']} ({iters} iters)" ) - print(f"\n{sep}\n {tag}\n 所有 {world_size} 张卡的 avg / min / max(μs/call)\n{sep}") - print(f" {'指标':<36} {'avg':>8} {'min':>8} {'max':>8}") + print(f"\n{sep}\n {tag}\n avg / min / max across all {world_size} ranks (us/call)\n{sep}") + print(f" {'metric':<36} {'avg':>8} {'min':>8} {'max':>8}") print(f" {'-'*58}") print(f" {'[E2E] dispatch CUDA time':<36} {avg_d:>8.1f} {mn_d:>8.1f} {mx_d:>8.1f}") print(f" {'[E2E] combine CUDA time':<36} {avg_c:>8.1f} {mn_c:>8.1f} {mx_c:>8.1f}") print() -# ─── cudagraph 模式:CUDA Graph capture + replay 计时 ───────────────────────── +# --- cudagraph mode: CUDA Graph capture + replay timing --- def _cudagraph_capture_flydsl(op, inp, wts, idx, wc_buf, capture_stream, scales=None, packed_recv_x=None): - """FlyDSL:录制 dispatch+combine 到 CUDA Graph。 + """Capture FlyDSL dispatch+combine into a CUDA Graph. - dispatch/combine 均返回全尺寸 tensor(无 .item()、无动态切片)。 - 需要先 eager 调用一次触发 flyc.compile() JIT 编译(编译过程使用 - default stream,不能在 capture 期间执行),之后 capture 中仅录制 - 已编译的 kernel launch。 + Both dispatch and combine return full-sized tensors (no ``.item()``, + no dynamic slicing). 
We must first run them eagerly once to trigger + the ``flyc.compile()`` JIT (which uses the default stream and can't + run during capture); the capture then records only the already- + compiled kernel launches. """ _dkw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} _ckw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} @@ -580,11 +608,12 @@ def _cudagraph_capture_flydsl(op, inp, wts, idx, wc_buf, capture_stream, scales= def _cudagraph_capture_mori(op, inp, wts, idx, wc_buf, capture_stream, scales=None, packed_recv_x=None): - """Mori 专用:直接在 graph capture 中录制 dispatch+combine。 + """Capture mori dispatch+combine into a CUDA Graph. - Mori 的 dispatch 在 capture 模式下返回真实 tensor,combine kernel - 从 HBM 读取 totalRecvTokenNum,无需 pre-capture eager call。 - 参考 mori/tests/python/ops/bench_dispatch_combine.py stress_graph 写法。 + Mori's dispatch returns a real tensor under capture and the combine + kernel reads ``totalRecvTokenNum`` from HBM, so no pre-capture eager + call is needed. Pattern follows mori's ``stress_graph`` in + ``mori/tests/python/ops/bench_dispatch_combine.py``. """ ms.shmem_barrier_all() g = torch.cuda.CUDAGraph() @@ -611,7 +640,7 @@ def cudagraph_op( scales=None, packed_recv_x=None, ): - """CUDA Graph 模式:capture dispatch+combine kernel,replay 计时。""" + """CUDA Graph mode: capture dispatch+combine, then time replays.""" capture_stream = torch.cuda.Stream() if op_tag == "flydsl": g, cs = _cudagraph_capture_flydsl( @@ -625,15 +654,15 @@ def cudagraph_op( if rank == 0: print(f"\n[cudagraph] {op_tag} capture done") - # replay warmup(HIP graph 冷启动 + GPU 缓存预热) + # Replay warmup (HIP graph cold start + GPU cache warmup). 
replay_warmup = 10 if rank == 0: - print(f"[cudagraph] replay warmup {replay_warmup} 轮 + 计时 {iters} 轮(no-reset)...") + print(f"[cudagraph] replay warmup {replay_warmup} + timing {iters} iters (no-reset)...") for _ in range(replay_warmup): g.replay() torch.cuda.synchronize() - # 计时:预分配 event pairs,循环结束后统一 sync + # Timing: pre-allocate event pairs, sync once after the loop. events = [(torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) for _ in range(iters)] for i in range(iters): @@ -644,7 +673,7 @@ def cudagraph_op( torch.cuda.synchronize() gpu_times = [events[i][0].elapsed_time(events[i][1]) * 1000 for i in range(iters)] - # per-replay 诊断 + # Per-replay diagnostics. per_replay_t = torch.tensor(gpu_times, dtype=torch.float64, device=dev) all_per_replay = [torch.zeros_like(per_replay_t) for _ in range(world_size)] dist.all_gather(all_per_replay, per_replay_t) @@ -675,8 +704,8 @@ def cudagraph_op( f"bs={meta['max_tokens']} h={meta['hidden_dim']} k={meta['k']} " f"({iters} replays)" ) - print(f"\n{sep}\n {tag}\n 所有 {world_size} 张卡的 avg / min / max(μs/call)\n{sep}") - print(f" {'指标':<36} {'avg':>8} {'min':>8} {'max':>8}") + print(f"\n{sep}\n {tag}\n avg / min / max across all {world_size} ranks (us/call)\n{sep}") + print(f" {'metric':<36} {'avg':>8} {'min':>8} {'max':>8}") print(f" {'-'*58}") print(f" {'[GPU] dispatch+combine (event)':<36} {avg_g:>8.1f} {mn_g:>8.1f} {mx_g:>8.1f}") @@ -694,7 +723,7 @@ def cudagraph_op( print() -# ─── 单算子 profiler 采集 ────────────────────────────────────────────────────── +# --- Per-op profiler capture --- def profile_op( op, op_tag: str, @@ -715,16 +744,18 @@ def profile_op( quant_type: str = "none", use_p2p_read: bool = False, ): - """对单个算子(FlyDSL 或 mori)独立 profiling,保存 JSON 并打印全卡聚合统计。 + """Profile a single op (FlyDSL or mori) standalone; save the JSON and + print cross-rank aggregated stats. 
- 使用 schedule(wait=1, warmup=10, active=iters) 让 ROCTracer 在前 11 步 - 不做/轻量追踪,减少与多 GPU P2P shmem 操作的冲突。 + Uses ``schedule(wait=1, warmup=10, active=iters)`` so ROCTracer + skips / light-traces the first 11 steps and avoids races with + multi-GPU P2P shmem. """ ms.shmem_barrier_all() prof_warmup = 10 total_steps = iters + 1 + prof_warmup # wait=1 + warmup=prof_warmup + active=iters if rank == 0: - print(f"\n[profiler] {op_tag} 开始采集({iters} 轮 active + {1 + prof_warmup} 轮 ramp-up)...") + print(f"\n[profiler] {op_tag} capturing ({iters} active + {1 + prof_warmup} ramp-up)...") _dkw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} _ckw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} @@ -746,13 +777,13 @@ def profile_op( prof.step() - # 保存 JSON:每张卡各自保存,文件名含 op_tag 和 rank + # Save JSON: one file per rank, named by op_tag and rank. out_path = os.path.join(out_dir, f"{op_tag}_rank{rank}.json") _save_profile_json(prof, out_path, rank, op_tag, meta) if rank == 0: - print(f"[profiler] {op_tag} trace → {out_path}") + print(f"[profiler] {op_tag} trace -> {out_path}") - # 跨卡聚合统计(all_reduce),rank 0 打印 + # Cross-rank aggregation via all_reduce; rank 0 prints. agg_stats = _allreduce_stats( prof, op_tag, rank, world_size, dev, dtype_key=dtype_key, quant_type=quant_type, use_p2p_read=use_p2p_read ) @@ -761,7 +792,7 @@ def profile_op( return prof -# ─── profile + cudagraph 模式 ───────────────────────────────────────────────── +# --- profile + cudagraph mode --- def profile_cudagraph_op( op, op_tag: str, @@ -783,9 +814,11 @@ def profile_cudagraph_op( quant_type: str = "none", use_p2p_read: bool = False, ): - """torch.profiler 采集 CUDAGraph replay,保存 JSON 并打印全卡聚合统计。 + """Profile CUDAGraph replays with torch.profiler; save JSON and + print cross-rank aggregated stats. - 流程:eager warmup → graph capture → replay warmup → profiler 包裹的 replay。 + Pipeline: eager warmup -> graph capture -> replay warmup -> + profiled replay loop. 
""" ms.shmem_barrier_all() @@ -802,7 +835,7 @@ def profile_cudagraph_op( if rank == 0: print(f"\n[profile+cudagraph] {op_tag} capture done") - # replay warmup(HIP graph 冷启动 + GPU 缓存预热) + # Replay warmup (HIP graph cold start + GPU cache warmup). replay_warmup = 10 for _ in range(replay_warmup): g.replay() @@ -817,7 +850,7 @@ def profile_cudagraph_op( print( f"[profile+cudagraph] {op_tag} scheduled profiler: " f"warmup={prof_warmup}, active={active_iters}, " - f"丢弃前 {skip_first} 次,有效 {valid_iters} 次(no-reset)..." + f"dropping first {skip_first}, {valid_iters} valid (no-reset)..." ) with _make_profiler(active_iters=active_iters, prof_warmup=prof_warmup) as prof: @@ -830,7 +863,7 @@ def profile_cudagraph_op( _save_profile_json(prof, out_path, rank, op_tag, meta) trace_path = out_path.replace(".json", "_trace.json") if rank == 0: - print(f"[profile+cudagraph] {op_tag} trace → {trace_path}") + print(f"[profile+cudagraph] {op_tag} trace -> {trace_path}") agg_stats = _cudagraph_stats_from_trace( trace_path, @@ -849,7 +882,7 @@ def profile_cudagraph_op( return prof -# ─── verify 模式:正确性验证 ───────────────────────────────────────────────── +# --- verify mode: correctness check --- VERIFY_TOL = { "f32": {"atol": 1e-5, "rtol": 1e-4}, "bf16": {"atol": 1e-2, "rtol": 1e-2}, @@ -988,7 +1021,7 @@ def verify_op(op_fly, op_mori, inp, wts, idx, k, rank, world_size, dev, dtype_ke print(f" VERIFY dtype={dtype_key} EP={world_size} bs={inp.shape[0]} " f"h={cfg.hidden_dim} k={k}") print(f"{'='*65}") - # ── Dispatch ── + # -- Dispatch -- op_fly.reset() op_mori.reset() ms.shmem_barrier_all() @@ -1001,11 +1034,11 @@ def verify_op(op_fly, op_mori, inp, wts, idx, k, rank, world_size, dev, dtype_ke tr_m = ret_m[4].clone() dist.barrier() if rank == 0: - print("\n ── Dispatch 对比(仅 total_recv,token 排列因原子序不同) ──") + print("\n -- Dispatch comparison (total_recv only; token order varies by atomic race) --") all_pass &= _check_exact("total_recv", tr_f, tr_m, rank) - # ── Combine ── + # -- Combine -- 
ms.shmem_barrier_all() cout_f = op_fly.combine(ret_f[0], None, ret_f[3]) cout_m = op_mori.combine(ret_m[0], None, ret_m[3]) @@ -1013,7 +1046,7 @@ def verify_op(op_fly, op_mori, inp, wts, idx, k, rank, world_size, dev, dtype_ke dist.barrier() if rank == 0: - print("\n ── Combine 输出对比 ──") + print("\n -- Combine output comparison --") mt = cfg.max_num_inp_token_per_rank cast_to = ( @@ -1052,7 +1085,7 @@ def verify_op(op_fly, op_mori, inp, wts, idx, k, rank, world_size, dev, dtype_ke return all_pass -# ─── 主逻辑 ─────────────────────────────────────────────────────────────────── +# --- Main entry --- def run_profiler(rank, world_size, args): dev = torch.device("cuda", rank) k = args.k @@ -1099,16 +1132,16 @@ def run_profiler(rank, world_size, args): quant_type=cfg.quant_type, ) - # 输出目录:/tmp/ep{ws}_bs{cur_tok}/ + # Output dir layout: /ep{ws}_bs{cur_tok}/ out_dir = os.path.join(args.output_dir, f"ep{world_size}_bs{cur_tok}") os.makedirs(out_dir, exist_ok=True) - # ── 构建算子 ─────────────────────────────────────────────────────────────── + # Build ops. 
if rank == 0: print(f"\n{'='*65}") print(f"[profiler] EP={world_size}, bs={cur_tok}, h={cfg.hidden_dim}, k={k}") print(f"{'='*65}") - print("[profiler] 构建 FlyDSL...") + print("[profiler] building FlyDSL...") op_fly = FlyDSLDispatchCombineIntraNodeOp(cfg) op_ref = None @@ -1118,17 +1151,17 @@ def run_profiler(rank, world_size, args): bn_str = mori_bn if mori_bn else cfg.block_num wpb_str = mori_wpb if mori_wpb else cfg.warp_num_per_block if rank == 0: - print(f"[profiler] 构建 mori ref (block_num={bn_str}, warp_per_block={wpb_str})...") + print(f"[profiler] building mori ref (block_num={bn_str}, warp_per_block={wpb_str})...") try: op_ref = build_mori_ref(rank, world_size, cfg, block_num=mori_bn, warp_per_block=mori_wpb) except Exception as e: if rank == 0: - print(f"[warn] mori ref 不可用: {e}") + print(f"[warn] mori ref unavailable: {e}") elif cfg.enable_std_moe and rank == 0: - print("[info] StdMoE 模式:跳过 mori ref,使用自洽验证") + print("[info] StdMoE mode: skipping mori ref, using self-check") ms.shmem_barrier_all() - # ── 准备输入(固定 seed,FlyDSL 和 mori 使用完全相同的输入)──────────────── + # Prepare inputs (fixed seed so FlyDSL and mori see identical data). torch.manual_seed(42 + rank) if cfg.data_type == torch.float4_e2m1fn_x2: inp = torch.randint(0, 256, (cur_tok, cfg.hidden_dim // 2), dtype=torch.uint8, device=dev).view( @@ -1153,11 +1186,12 @@ def run_profiler(rank, world_size, args): for t in range(cur_tok): idx[t] = torch.randperm(n_exp, device=dev)[:k] - # 预分配 combine 权重 buffer(FlyDSL 和 mori 共用,避免计时窗口内额外 GPU 核) + # Pre-allocate the combine weight buffer (shared by FlyDSL and mori + # so no extra GPU kernel sneaks into the timing window). max_recv = world_size * cur_tok wc_buf = torch.full((max_recv, k), 1.0 / k, dtype=torch.float32, device=dev) - # ── 构造 scales / packed_recv_x(所有模式共用)───────────────────────── + # Build scales / packed_recv_x (shared across modes). 
packed_recv_x = None if cfg.enable_std_moe: _prx_nbytes = cfg.num_experts_per_rank * cfg.max_recv * cfg.token_bytes @@ -1173,12 +1207,13 @@ def run_profiler(rank, world_size, args): scales = torch.randn(cur_tok, _sc_bytes // 4, dtype=torch.float32, device=dev).contiguous() scales = scales.view(torch.uint8).view(cur_tok, _sc_bytes) - # profile+eager 模式需要外部预热;其他 3 种组合由各自函数内部处理 + # profile+eager needs an external warmup; the other three combos + # warm up inside their own functions. do_warmup = args.mode == "profile" and not args.cudagraph if do_warmup: if rank == 0: - print(f"[setup] 预热 FlyDSL {args.warmup} 轮...") + print(f"[setup] warming up FlyDSL for {args.warmup} iters...") for _ in range(args.warmup): op_fly.reset() ret = op_fly.dispatch(inp, wts, scales, idx, packed_recv_x=packed_recv_x) @@ -1187,7 +1222,7 @@ def run_profiler(rank, world_size, args): if op_ref is not None: if rank == 0: - print(f"[setup] 预热 mori ref {args.warmup} 轮...") + print(f"[setup] warming up mori ref for {args.warmup} iters...") for _ in range(args.warmup): op_ref.reset() ret_r = op_ref.dispatch(inp, wts, None, idx) @@ -1196,7 +1231,7 @@ def run_profiler(rank, world_size, args): ms.shmem_barrier_all() - # ── 根据 mode × cudagraph 分发执行 ───────────────────────────────────── + # Dispatch by mode x cudagraph. 
test_flydsl = args.bench_op in ("flydsl", "both") test_mori = args.bench_op in ("mori", "both") and op_ref is not None @@ -1297,7 +1332,7 @@ def run_profiler(rank, world_size, args): use_p2p_read=_p2p, ) if rank == 0: - print(f"\n[profiler] 全部结果已保存到: {out_dir}/") + print(f"\n[profiler] all results saved to: {out_dir}/") elif args.mode == "profile" and args.cudagraph: _p2p = not args.use_external_inp_buf @@ -1345,10 +1380,10 @@ def run_profiler(rank, world_size, args): use_p2p_read=_p2p, ) if rank == 0: - print(f"\n[profiler] 全部结果已保存到: {out_dir}/") + print(f"\n[profiler] all results saved to: {out_dir}/") -# ─── Worker / 命令行入口 ────────────────────────────────────────────────────── +# --- Worker / CLI entry --- def _worker(rank, world_size, args, master_port): setup_distributed(rank, world_size, master_port) try: @@ -1363,7 +1398,7 @@ def _worker(rank, world_size, args, master_port): def _parse_args(): - p = argparse.ArgumentParser(description="torch.profiler 分析 dispatch/combine") + p = argparse.ArgumentParser(description="torch.profiler analysis of dispatch/combine") p.add_argument("--world-size", type=int, default=8) p.add_argument("--max-tokens", type=int, default=512) p.add_argument("--hidden-dim", type=int, default=7168) @@ -1372,49 +1407,64 @@ def _parse_args(): p.add_argument("--block-num", type=int, default=80) p.add_argument("--warp-per-block", type=int, default=4) p.add_argument( - "--mori-block-num", type=int, default=0, help="mori 专用 block_num(0=与FlyDSL相同,mori默认最优=80)" + "--mori-block-num", + type=int, + default=0, + help="mori-only block_num (0 = same as FlyDSL; mori's tuned default is 80)", ) p.add_argument( - "--mori-warp-per-block", type=int, default=0, help="mori 专用 warp_per_block(0=与FlyDSL相同,mori默认最优=8)" + "--mori-warp-per-block", + type=int, + default=0, + help="mori-only warp_per_block (0 = same as FlyDSL; mori's tuned default is 8)", ) p.add_argument("--chip", type=str, default="gfx950") - p.add_argument("--dtype", type=str, default="bf16", 
choices=list(DTYPE_MAP.keys()), help="数据类型(默认 bf16)") - p.add_argument("--warmup", type=int, default=5, help="预热轮次(不进 profiler,确保 JIT 编译完成)") - p.add_argument("--iters", type=int, default=5, help="profiler 采集轮次") + p.add_argument( + "--dtype", type=str, default="bf16", choices=list(DTYPE_MAP.keys()), help="data type (default: bf16)" + ) + p.add_argument( + "--warmup", + type=int, + default=5, + help="warmup iters outside the profiler (ensures JIT compilation completes)", + ) + p.add_argument("--iters", type=int, default=5, help="profiler active iters") p.add_argument( "--output-dir", type=str, default="dispatch_profile", - help="JSON 输出根目录(相对当前目录),子目录按 ep{ws}_bs{tok} 命名", + help="JSON output root (relative to cwd); per-shape subdir is named ep{ws}_bs{tok}", ) p.add_argument("--port", type=int, default=29800) p.add_argument("--no-compare", dest="compare", action="store_false") - # ── 模式选择 ────────────────────────────────────────────────────────────── + # Mode selection p.add_argument( "--mode", choices=["profile", "bench", "verify"], default="profile", - help="测量方式:profile=torch.profiler 采集(默认); bench=CUDA Event 计时; verify=正确性验证", + help="measurement: profile=torch.profiler (default); bench=CUDA event timing; verify=correctness check", + ) + p.add_argument("--cudagraph", action="store_true", help="use CUDAGraph capture+replay (default: eager)") + p.add_argument( + "--bench-op", choices=["flydsl", "mori", "both"], default="both", help="which op to measure (default: both)" ) - p.add_argument("--cudagraph", action="store_true", help="使用 CUDAGraph capture+replay 执行(默认 eager)") - p.add_argument("--bench-op", choices=["flydsl", "mori", "both"], default="both", help="测哪个算子(默认 both)") - # ── 功能开关 ────────────────────────────────────────────────────────────── + # Feature switches p.add_argument( "--no-external-inp-buf", dest="use_external_inp_buf", action="store_false", default=True, - help="使用 P2P Read combine 变体(默认使用 external inp buf)", + help="use the P2P-read combine 
variant (default: external inp buf)", ) - p.add_argument("--enable-std-moe", action="store_true", default=False, help="启用 Standard MoE Adapt 模式") - p.add_argument("--scale-dim", type=int, default=0, help="Scale 张量维度(0=不使用 scale)") - p.add_argument("--scale-type-size", type=int, default=0, help="Scale 类型大小(字节,0=不使用 scale)") + p.add_argument("--enable-std-moe", action="store_true", default=False, help="enable Standard MoE adapt mode") + p.add_argument("--scale-dim", type=int, default=0, help="scale tensor dim (0 = disable scales)") + p.add_argument("--scale-type-size", type=int, default=0, help="scale element size in bytes (0 = disable scales)") p.add_argument( "--quant-type", type=str, default="none", choices=["none", "fp8_direct_cast"], - help="量化类型(none=默认,fp8_direct_cast=FP8直接转换combine)", + help="quantization type (none = default; fp8_direct_cast = inline fp8 cast in combine)", ) p.set_defaults(compare=True) return p.parse_args() @@ -1429,7 +1479,7 @@ def main(): else: ws = min(args.world_size, torch.cuda.device_count()) if ws < args.world_size: - print(f"[warn] 可用 GPU={torch.cuda.device_count()}, " f"world_size 调整: {args.world_size} → {ws}") + print(f"[warn] available GPUs={torch.cuda.device_count()}, world_size adjusted: {args.world_size} -> {ws}") torch.multiprocessing.spawn( _worker, args=(ws, args, args.port), From 495da61f08c48d55b6d93ffd8e247ff82df3c856 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 May 2026 06:47:47 +0000 Subject: [PATCH 4/4] fix(ci): restore arith.py style-compliant import layout Align arith module ordering/ruff pragmas with mainline formatting so the Python style pre-check passes reliably in PR CI. 
--- python/flydsl/expr/arith.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/python/flydsl/expr/arith.py b/python/flydsl/expr/arith.py index 38403f6c4..c04998c2e 100644 --- a/python/flydsl/expr/arith.py +++ b/python/flydsl/expr/arith.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors +# ruff: noqa: I001 -from .._mlir.dialects.arith import * # noqa: F401,F403 """Arith dialect API — operator overloading + function-level builders. Usage: @@ -12,27 +12,29 @@ r = arith.select(cond, a, b) # ArithValue operator overloading: c + 1, c * 2, c / 4, c % 16 """ + +from .._mlir.dialects.arith import * # noqa: F401,F403 + +# Override star-import cmpi/cmpf to accept Numeric types (Int32, etc.) +from .._mlir.dialects import arith as _mlir_arith from .meta import traced_op from .utils.arith import ( # noqa: F401 ArithValue, + _to_raw, + andi, constant, constant_vector, index, index_cast, int_to_fp, select, + shli, sitofp, trunc_f, - andi, - xori, - shli, unwrap, - _to_raw, + xori, ) -# Override star-import cmpi/cmpf to accept Numeric types (Int32, etc.) -from .._mlir.dialects import arith as _mlir_arith # noqa: E402 - @traced_op def cmpi(predicate, lhs, rhs, **kwargs): @@ -62,6 +64,3 @@ def cmpf(predicate, lhs, rhs, **kwargs): An ``i1`` comparison result. """ return _mlir_arith.cmpf(predicate, _to_raw(lhs), _to_raw(rhs), **kwargs) - - -