Skip to content

[FlyDSL]add FlyDSL MoE sorting kernel#3266

Draft
amd-weisun wants to merge 4 commits into
ROCm:mainfrom
amd-weisun:flydsl-moe-sorting
Draft

[FlyDSL]add FlyDSL MoE sorting kernel#3266
amd-weisun wants to merge 4 commits into
ROCm:mainfrom
amd-weisun:flydsl-moe-sorting

Conversation

@amd-weisun
Copy link
Copy Markdown

@amd-weisun amd-weisun commented May 19, 2026

Motivation

Add FlyDSL MoE sorting kernel.
The kernel implementation is from FlyDSL PR ROCm/FlyDSL#540

This PR includes the integration codes.
FlyDSL MoE can be enabled by AITER_USE_FLYDSL_MOE_SORTING=1 , following the style of CK kernel : AITER_USE_CK_MOE_SORTING=1. The default kernel now is OPUS.

E2E Benchmark Result

DeepSeek-R1-0528 FP8 TP8

DeepSeek-R1-0528 FP8, TP8, FP8 KV Cache, 8×MI350X
ISL=8192, OSL=1024, CONC=4, 40 prompts, 3 runs each

                  Baseline (OPUS)                          FlyDSL
Metric            Run1     Run2     Run3    Avg      Run1     Run2     Run3    Avg      Delta
─────────────────────────────────────────────────────────────────────────────────────────────
Output tok/s     329.09   329.37   328.99  329.15   336.36   336.75   337.00  336.70   +2.3%
Total tok/s     2958.34  2960.80  2957.46 2958.87  3023.69  3027.20  3029.42 3026.77   +2.3%
Mean TPOT (ms)    11.56    11.55    11.56   11.56    11.31    11.29    11.29   11.30   -2.2%
Median TPOT (ms)  11.62    11.60    11.62   11.61    11.36    11.34    11.33   11.34   -2.3%
P99 TPOT (ms)     12.00    11.99    12.00   12.00    11.76    11.74    11.72   11.74   -2.2%
Mean TTFT (ms)   277.28   279.32   278.18  278.26   276.19   277.24   275.73  276.39   -0.7%
Median TTFT (ms) 250.76   250.52   250.66  250.65   249.14   249.54   248.88  249.19   -0.6%
P99 TTFT (ms)    677.23   687.90   683.20  682.78   684.32   686.03   681.03  683.79   +0.1%
Mean ITL (ms)     11.55    11.54    11.55   11.55    11.30    11.28    11.27   11.28   -2.3%
Mean E2EL (ms) 10903.89 10897.20 10908.16 10903.08 10670.28 10655.88 10649.77 10658.64  -2.2%
P99 E2EL (ms)  12329.03 12330.07 12352.11 12337.07 12095.14 12091.76 12051.80 12079.57  -2.1%

Accuracy

                    Baseline (OPUS)          FlyDSL
flexible-extract    0.9560               0.9575 

DeepSeek-R1-0528 MXFP4 TP8

DeepSeek-R1-0528 MXFP4, TP8, FP8 KV Cache, 8×MI350X
ISL=8192, OSL=1024, CONC=4, 40 prompts, 3 runs each

                  Baseline (OPUS)                          FlyDSL
Metric            Run1     Run2     Run3    Avg      Run1     Run2     Run3    Avg      Delta
─────────────────────────────────────────────────────────────────────────────────────────────
Output tok/s     414.56   414.57   415.15  414.76   427.66   429.40   427.68  428.25   +3.3%
Total tok/s     3726.68  3726.76  3731.94 3728.46  3844.41  3860.04  3844.54 3849.66   +3.3%
Mean TPOT (ms)     9.17     9.17     9.16    9.17     8.88     8.84     8.89    8.87   -3.3%
Median TPOT (ms)   9.21     9.21     9.20    9.21     8.93     8.88     8.93    8.91   -3.3%
P99 TPOT (ms)      9.55     9.52     9.53    9.53     9.26     9.21     9.27    9.25   -2.9%
Mean TTFT (ms)   233.05   233.73   232.32  233.03   232.56   232.93   233.56  233.02   -0.0%
Median TTFT (ms) 209.85   210.47   208.41  209.58   209.78   211.99   210.06  210.61   +0.5%
P99 TTFT (ms)    583.87   583.11   583.02  583.33   585.43   586.92   584.77  585.71   +0.4%
Mean ITL (ms)      9.16     9.16     9.15    9.16     8.87     8.83     8.88    8.86   -3.3%
Mean E2EL (ms)  8663.62  8662.04  8650.62 8658.76  8397.61  8362.14  8400.81 8386.85   -3.1%
P99 E2EL (ms)   9858.62  9853.45  9832.53 9848.20  9559.02  9511.07  9558.54 9542.88   -3.1%

Accuracy

                    Baseline (OPUS)          FlyDSL
flexible-extract    0.9484            0.9492

@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3266 --add-label <label>

@amd-weisun amd-weisun requested a review from coderfeli May 19, 2026 12:43
# =================== MOE_BUF ZEROING (blocks > 0 only) ===============
is_zero_block = bid != c_zero_i32
_if_zero = scf.IfOp(is_zero_block.ir_value())
with _if_then(_if_zero):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can use python if directly here?

mesh_addr = token_id * c_smem_cols + eid
last_mesh_idx = fx.Int32(sub_tokens * smem_cols - 1)
safe_mesh_addr = is_valid.select(mesh_addr, last_mesh_idx)
safe_mesh_ix = ArithValue(safe_mesh_addr).index_cast(T.index)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use Fx.int64() directly

p0v2_allocator.ptr = p0v2_reduce_offset + P0V2_NUM_WAVES * 4

@flyc.kernel(known_block_size=[P0V2_BLOCK, 1, 1])
def p0v2_kernel(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too many duplicated codes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants