Skip to content

[Feat] Add FlyDSL MoE sorting kernel#540

Open
amd-weisun wants to merge 14 commits into
ROCm:mainfrom
amd-weisun:moe_sorting
Open

[Feat] Add FlyDSL MoE sorting kernel#540
amd-weisun wants to merge 14 commits into
ROCm:mainfrom
amd-weisun:moe_sorting

Conversation

@amd-weisun
Copy link
Copy Markdown
Contributor

@amd-weisun amd-weisun commented May 18, 2026

Drop-in replacement for OPUS/CK moe_sorting in aiter's fused_moe. (aiter uses OPUS by default)

Kernel paths:

  • T <= 16: single-block LDS histogram + DPP prefix sum
  • 16 < T <= 2048: p0v2 + p23 — per-expert scatter + parallel prefix sum
  • T > 2048: 4-kernel — K1 clear + K2 scatter + K3 count + p23

Correctness: 32 CI tests + 14 large_shape (46 total), covering all 11 production MoE models (E=8..513, topk=2..9).

Kernel Performance (moe sorting only)

Decode (T≤256): CUDA graph capture + 200 replays.
Prefill (T>256): CUDA events (eager mode), 100 iterations.
MI350X (gfx950), single GPU.

DeepSeek-R1: E=257, topk=9, unit_size=32

     T    Path    CK (us)  OPUS (us)  FlyDSL (us)  vs CK    vs OPUS
     1   graph     13.5      13.5         9.7    +28.5%    +28.2%
     2   graph     13.5      13.5         9.8    +27.3%    +27.3%
     4   graph     13.9      13.9         9.8    +29.4%    +29.3%
     8   graph     14.0      14.0         9.8    +29.7%    +29.6%
    16   graph     17.8      17.8        10.7    +39.8%    +39.9%
    32   graph     13.5      13.6        13.5     -0.1%     +0.2%
    64   graph     13.6      13.5        13.5     +0.2%     -0.1%
   128   graph     13.6      13.6        13.5     +0.9%     +0.9%
   256   graph     13.9      13.8        13.8     +0.5%     -0.1%
   512  events     10.1      23.5        12.0    -18.3%    +49.1%
  1024  events     12.1      23.8        11.8     +2.6%    +50.5%
  2048  events     24.0      47.9        17.7    +26.1%    +63.0%
  4096  events     20.9      30.2        17.2    +17.8%    +43.1%
  8192  events     31.1      31.3        25.0    +19.7%    +20.2%
 16384  events     52.9      53.1        40.9    +22.6%    +22.8%

DeepSeek-V4: E=385, topk=7, unit_size=32

     T    Path    CK (us)  OPUS (us)  FlyDSL (us)  vs CK    vs OPUS
     1   graph     16.0      16.0         9.7    +39.4%    +39.3%
     2   graph     16.3      16.3         9.8    +39.7%    +39.7%
     4   graph     16.6      17.0         9.9    +40.5%    +41.8%
     8   graph     17.1      17.1         9.9    +42.1%    +42.1%
    16   graph     21.1      21.1        11.9    +43.9%    +43.8%
    32   graph     13.7      13.6        13.8     -1.4%     -1.5%
    64   graph     13.8      13.7        13.9     -0.9%     -1.7%
   128   graph     13.9      13.9        13.9     +0.3%     -0.1%
   256   graph     14.1      14.0        14.1     -0.2%     -0.6%
   512  events     13.1      24.4        12.6     +3.6%    +48.2%
  1024  events     12.4      24.6        12.7     -2.2%    +48.5%
  2048  events     17.7      30.1        14.6    +17.4%    +51.3%
  4096  events     21.1      30.2        17.3    +18.1%    +42.8%
  8192  events     30.5      30.6        26.9    +11.7%    +12.2%
 16384  events     49.7      50.3        43.6    +12.3%    +13.4%

Qwen3-MoE: E=128, topk=4, unit_size=128

     T    Path    CK (us)  OPUS (us)  FlyDSL (us)  vs CK    vs OPUS
     1   graph     10.9      10.9         9.7    +11.2%    +11.2%
     2   graph     11.3      11.2         9.7    +14.2%    +14.1%
     4   graph     11.5      11.4         9.7    +15.3%    +15.0%
     8   graph     11.5      11.5         9.7    +15.5%    +15.9%
    16   graph     13.8      13.9        10.4    +25.0%    +25.3%
    32   graph     17.3      17.3        12.9    +25.4%    +25.4%
    64   graph     12.8      12.8        12.9     -0.7%     -1.1%
   128   graph     12.8      12.8        12.9     -0.2%     -0.3%
   256   graph     12.9      13.0        12.9     -0.3%     +0.4%
   512  events     11.7      24.6        13.0    -11.0%    +47.2%
  1024  events     10.3      24.6        13.0    -26.6%    +47.1%
  2048  events     16.2      30.3        12.7    +21.7%    +58.1%
  4096  events     17.7      30.2        18.1     -2.3%    +40.1%
  8192  events     22.3      30.3        20.6     +7.7%    +31.9%
 16384  events     33.5      33.5        30.8     +8.3%     +8.1%

ATOM E2E Benchmark Result

DeepSeek-R1-0528 FP8, TP8, FP8 KV Cache, 8×MI350X
ISL=8192, OSL=1024, CONC=4, 40 prompts, 3 runs each

                          Baseline(OPUS)                            FlyDSL
Metric            Run1     Run2     Run3    Avg      Run1     Run2     Run3    Avg      Delta
─────────────────────────────────────────────────────────────────────────────────────────────
Output tok/s     329.09   329.37   328.99  329.15   336.36   336.75   337.00  336.70   +2.3%
Total tok/s     2958.34  2960.80  2957.46 2958.87  3023.69  3027.20  3029.42 3026.77   +2.3%
Mean TPOT (ms)    11.56    11.55    11.56   11.56    11.31    11.29    11.29   11.30   -2.2%
Median TPOT (ms)  11.62    11.60    11.62   11.61    11.36    11.34    11.33   11.34   -2.3%
P99 TPOT (ms)     12.00    11.99    12.00   12.00    11.76    11.74    11.72   11.74   -2.2%
Mean TTFT (ms)   277.28   279.32   278.18  278.26   276.19   277.24   275.73  276.39   -0.7%
Median TTFT (ms) 250.76   250.52   250.66  250.65   249.14   249.54   248.88  249.19   -0.6%
P99 TTFT (ms)    677.23   687.90   683.20  682.78   684.32   686.03   681.03  683.79   +0.1%
Mean ITL (ms)     11.55    11.54    11.55   11.55    11.30    11.28    11.27   11.28   -2.3%
Mean E2EL (ms) 10903.89 10897.20 10908.16 10903.08 10670.28 10655.88 10649.77 10658.64  -2.2%
P99 E2EL (ms)  12329.03 12330.07 12352.11 12337.07 12095.14 12091.76 12051.80 12079.57  -2.1%

Accuracy validated using llm accuracy validation steps in ATOM

@amd-weisun amd-weisun force-pushed the moe_sorting branch 3 times, most recently from 2e21d3c to ba693c7 Compare May 18, 2026 13:48
@amd-weisun amd-weisun marked this pull request as ready for review May 18, 2026 16:00
Copilot AI review requested due to automatic review settings May 18, 2026 16:00
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new FlyDSL implementation of the MoE token sorting kernel (decode + prefill paths) intended as a drop-in replacement for aiter/OPUS/CK, along with a comprehensive GPU correctness test suite and optional benchmarking helpers.

Changes:

  • Introduce kernels/moe_sorting_kernel.py implementing decode (single-kernel LDS) and prefill (HBM workspace; p0v2+p23 or 4-kernel fused) sorting paths, including optional expert-mask (EP) support.
  • Add tests/kernels/test_moe_sorting.py validating outputs vs a Python reference and (optionally) aiter, plus EP-mode coverage and benchmark utilities.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

File Description
kernels/moe_sorting_kernel.py New FlyDSL MoE sorting kernel implementation with decode/prefill dispatch, workspace handling, and EP support.
tests/kernels/test_moe_sorting.py New GPU test suite comparing against a reference/aiter and covering multiple shapes + EP mode; includes optional benchmarking.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread kernels/moe_sorting_kernel.py Outdated
Comment thread kernels/moe_sorting_kernel.py Outdated
Comment thread kernels/moe_sorting_kernel.py
Comment thread tests/kernels/test_moe_sorting.py Outdated
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

Comment thread kernels/moe_sorting_kernel.py Outdated
Comment thread kernels/moe_sorting_kernel.py Outdated
Comment thread tests/kernels/test_moe_sorting.py Outdated
Comment thread tests/kernels/test_moe_sorting.py Outdated
Comment thread tests/kernels/test_moe_sorting.py Outdated
Comment thread kernels/moe_sorting_kernel.py Outdated

# =================== MOE_BUF ZEROING (blocks > 0 only) ===============
is_zero_block = bid != c_zero_i32
_if_zero = scf.IfOp(is_zero_block.ir_value())
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use if not scf?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

converted to if

Comment thread kernels/moe_sorting_kernel.py Outdated

# For E > DECODE_BLOCK: thread 0 serially extends
if E > DECODE_BLOCK:
_if_t0_ext = scf.IfOp(is_t0.ir_value())
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scf -> if

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

converted to if

safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32)
ml_mask = buffer_ops.buffer_load(mask_rsrc, safe_ml_eid, vec_width=1, dtype=T.i32)
ml_val = ml_valid.select(ml_mask, c_zero_i32)
ml_ix = ArithValue(ml_valid.select(ml_eid + c_one_i32, c_zero_i32)).index_cast(T.index)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fx.Int64() ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed explicit cast

eid_wr_valid = eid_wr < c_E
safe_eid_wr = eid_wr_valid.select(eid_wr, c_zero_i32)

cs_start_ix = ArithValue(safe_eid_wr).index_cast(T.index)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fx.Int64?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed explicit cast

Comment thread kernels/moe_sorting_kernel.py Outdated
# FlyDSL GPU kernels — prefill path (4 kernels, large T via HBM workspace)
# ---------------------------------------------------------------------------
@functools.lru_cache(maxsize=256)
def compile_moe_sorting_prefill(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reuse the code with decode?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merged into a single compile_moe_sorting() entry point . Use shared helpers function to improve code reuse. Also renamed decode to oneshot and prefill to multiphase to avoid confusion. As kernel paths are selected based on num of tokens .

Drop-in replacement for OPUS/CK moe_sorting in aiter's fused_moe.

Kernel paths:
- T <= 16: decode — single-block LDS histogram + DPP prefix sum
- 16 < T <= 2048: p0v2 + p23 — per-expert scatter + parallel prefix sum
- T > 2048: 4-kernel fused — K1 clear + K2 scatter + K3 count + p23

Correctness: 32 CI tests + 14 large_shape (46 total), covering
all 11 production MoE models (E=8..513, topk=2..9).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants