[Feat] Add FlyDSL MoE sorting kernel #540
Pull request overview
Adds a new FlyDSL implementation of the MoE token sorting kernel (decode + prefill paths) intended as a drop-in replacement for aiter/OPUS/CK, along with a comprehensive GPU correctness test suite and optional benchmarking helpers.
Changes:
- Introduce kernels/moe_sorting_kernel.py implementing decode (single-kernel LDS) and prefill (HBM workspace; p0v2+p23 or 4-kernel fused) sorting paths, including optional expert-mask (EP) support.
- Add tests/kernels/test_moe_sorting.py validating outputs vs a Python reference and (optionally) aiter, plus EP-mode coverage and benchmark utilities.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| kernels/moe_sorting_kernel.py | New FlyDSL MoE sorting kernel implementation with decode/prefill dispatch, workspace handling, and EP support. |
| tests/kernels/test_moe_sorting.py | New GPU test suite comparing against a reference/aiter and covering multiple shapes + EP mode; includes optional benchmarking. |
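For readers unfamiliar with what a MoE sorting reference computes, here is a minimal pure-Python sketch. It is not the reference in tests/kernels/test_moe_sorting.py: the function name `reference_moe_sort`, the `PAD` sentinel, the `(token, slot)` output layout, and the choice to emit no blocks for empty experts are all assumptions made for illustration. It groups each token's top-k routing choices by expert and pads every expert's group up to a multiple of `unit_size`.

```python
from typing import List, Tuple

# Hypothetical padding sentinel; the real kernel's padding encoding may differ.
PAD: Tuple[int, int] = (-1, -1)


def reference_moe_sort(
    topk_ids: List[List[int]], num_experts: int, unit_size: int
) -> Tuple[List[Tuple[int, int]], List[int], int]:
    """Group (token, slot) pairs by expert and pad each group to unit_size.

    Returns the padded, expert-ordered pairs, the expert id owning each
    unit_size-sized block, and the total padded length.
    """
    # Histogram step: collect the (token, slot) pairs routed to each expert.
    buckets: List[List[Tuple[int, int]]] = [[] for _ in range(num_experts)]
    for token, experts in enumerate(topk_ids):
        for slot, expert in enumerate(experts):
            buckets[expert].append((token, slot))

    sorted_pairs: List[Tuple[int, int]] = []
    block_expert_ids: List[int] = []
    for expert, bucket in enumerate(buckets):
        if not bucket:
            continue  # assumption: experts with no tokens emit no blocks
        num_blocks = (len(bucket) + unit_size - 1) // unit_size  # ceil division
        padding = num_blocks * unit_size - len(bucket)
        sorted_pairs.extend(bucket + [PAD] * padding)
        block_expert_ids.extend([expert] * num_blocks)

    return sorted_pairs, block_expert_ids, len(sorted_pairs)
```

A GPU kernel would be compared against such a reference by checking that the same set of pairs appears in each expert's padded range, since the order within an expert may legitimately differ between implementations.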
| # =================== MOE_BUF ZEROING (blocks > 0 only) =============== | ||
| is_zero_block = bid != c_zero_i32 | ||
| _if_zero = scf.IfOp(is_zero_block.ir_value()) |
| # For E > DECODE_BLOCK: thread 0 serially extends | ||
| if E > DECODE_BLOCK: | ||
| _if_t0_ext = scf.IfOp(is_t0.ir_value()) |
| safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32) | ||
| ml_mask = buffer_ops.buffer_load(mask_rsrc, safe_ml_eid, vec_width=1, dtype=T.i32) | ||
| ml_val = ml_valid.select(ml_mask, c_zero_i32) | ||
| ml_ix = ArithValue(ml_valid.select(ml_eid + c_one_i32, c_zero_i32)).index_cast(T.index) |
Removed the explicit cast.
| eid_wr_valid = eid_wr < c_E | ||
| safe_eid_wr = eid_wr_valid.select(eid_wr, c_zero_i32) | ||
| cs_start_ix = ArithValue(safe_eid_wr).index_cast(T.index) |
Removed the explicit cast.
| # FlyDSL GPU kernels — prefill path (4 kernels, large T via HBM workspace) | ||
| # --------------------------------------------------------------------------- | ||
| @functools.lru_cache(maxsize=256) | ||
| def compile_moe_sorting_prefill( |
Can this reuse code from the decode path?
Merged into a single compile_moe_sorting() entry point, with shared helper functions to improve code reuse. Also renamed decode to oneshot and prefill to multiphase to avoid confusion, since the kernel path is selected by the number of tokens.
Drop-in replacement for OPUS/CK moe_sorting in aiter's fused_moe (aiter uses OPUS by default).
Kernel paths:
- T <= 16 (decode): single-block LDS histogram + DPP prefix sum
- 16 < T <= 2048 (p0v2 + p23): per-expert scatter + parallel prefix sum
- T > 2048 (4-kernel fused): K1 clear + K2 scatter + K3 count + p23
Correctness: 32 CI tests + 14 large_shape (46 total), covering all 11 production MoE models (E=8..513, topk=2..9).
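The token-count thresholds in the description can be read as a simple dispatch rule. The sketch below only illustrates that rule: the function name `select_sorting_path` and the constant names are hypothetical, the returned strings are just the labels used in the description, and the actual entry point in kernels/moe_sorting_kernel.py may take other parameters into account.

```python
# Thresholds taken from the PR description; the names are hypothetical.
DECODE_MAX_TOKENS = 16
TWO_KERNEL_MAX_TOKENS = 2048


def select_sorting_path(num_tokens: int) -> str:
    """Return the label of the kernel path the description assigns to T tokens."""
    if num_tokens <= 0:
        raise ValueError("num_tokens must be positive")
    if num_tokens <= DECODE_MAX_TOKENS:
        return "decode"  # single-block LDS histogram + DPP prefix sum
    if num_tokens <= TWO_KERNEL_MAX_TOKENS:
        return "p0v2+p23"  # per-expert scatter + parallel prefix sum
    return "4-kernel fused"  # K1 clear + K2 scatter + K3 count + p23
```

For example, `select_sorting_path(2048)` falls in the middle band, while `select_sorting_path(2049)` moves to the 4-kernel path.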
Kernel Performance (moe sorting only)
Decode (T≤256): CUDA graph capture + 200 replays.
Prefill (T>256): CUDA events (eager mode), 100 iterations.
MI350X (gfx950), single GPU.
DeepSeek-R1: E=257, topk=9, unit_size=32
DeepSeek-V4: E=385, topk=7, unit_size=32
Qwen3-MoE: E=128, topk=4, unit_size=128
ATOM E2E Benchmark Result
Accuracy was validated using the LLM accuracy validation steps in ATOM.