Skip to content

add asmjit AOT kernels for qwen35/Hunyuan3#3309

Open
tingqli wants to merge 1 commit into
ROCm:mainfrom
tingqli:fmoe-aot
Open

add asmjit AOT kernels for qwen35/Hunyuan3#3309
tingqli wants to merge 1 commit into
ROCm:mainfrom
tingqli:fmoe-aot

Conversation

@tingqli
Copy link
Copy Markdown

@tingqli tingqli commented May 22, 2026

Motivation

These kernels are specially designed/optimized for fused MOE (TP8, FP8-per-tensor, FP8-ptpc) problems:

  • avoid overheads of quantization in small batch cases, by dequantize weights on-the-fly and do computation in bf16 precision.
  • avoid overheads of moe_sorting in single-token case
  • fine pipelined 4-wave gate-up kernel
  • in down-kernel, use store_dwordx4 + reduce-sum instead of atomic_add
  • in down-kernel, loads A matrix and make it resident in register, loop over output-channel dimensions in 1x4 warps, pipelined MFMAs with store_dwordx4

Technical Details

current tune methods didn't take extra overheads (moe-sorting/quant) into consideration, but fmoe_asmjit_aot introduced some optimizations regarding to these overheads, thus we introduced a new --e2e_tune flag into gemm_moe_tune.py, this mode directly compares fmoe_asmjit_aot's performance against current best fused_moe()'s performance, if former is better, it will be recorded into destination tuned_fmoe file under model_configs.

python3 csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py -i aiter/configs/model_configs/qwen3_5_397b_fp8_ptpc_untuned_fmoe.csv -o aiter/configs/model_configs/qwen3_5_397b_fp8_ptpc_tuned_fmoe.csv --timeout 300 -v --all --e2e_tune

Test Plan

Test Result

performance improve in e2e tune: Hunyuan fp8-per-tensor (TP8)
Screenshot 2026-05-22 103351

performance improve in e2e tune: qwen3.5 fp8-ptpc (TP8)
Screenshot 2026-05-22 103635

Submission Checklist

Co-authors: Cheng.Luo@amd.com Luwei.Zhou@amd.com

@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3309 --add-label <label>

@tingqli tingqli marked this pull request as ready for review May 22, 2026 02:57
@tingqli tingqli requested review from a team and Copilot May 22, 2026 02:57
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds an asmjit AOT MoE fused implementation and an end-to-end tuning mode so tuning decisions account for real production overheads (sorting/quant/etc.), and wires tuned configs to select the new implementation at runtime.

Changes:

  • Introduces fused_moe_asmjit_aot and a lightweight HSACO loader/launcher to run AOT kernels.
  • Adds --e2e_tune path in gemm_moe_tune.py to compare current best fused_moe() vs asmjit-AOT variants and write winners into tuned CSVs.
  • Extends fused_moe config lookup to dispatch to asmjit-AOT when tuned CSV selects it; adds new model config CSVs for Qwen3.5/Hunyuan3.

Reviewed changes

Copilot reviewed 9 out of 29 changed files in this pull request and generated 11 comments.

Show a summary per file
File Description
csrc/cpp_itfs/hsaco_tools.py New ctypes-based HSACO loader/launcher and kernel symbol discovery.
csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py Adds e2e tuning mode and hooks to benchmark asmjit-AOT configs.
aiter/utility/base_tuner.py Adds --e2e_tune CLI flag to the shared tuner argument set.
aiter/fused_moe.py Enables tuned-config dispatch into the asmjit-AOT fused MoE implementation.
aiter/fused_moe_asmjit_aot.py New asmjit AOT fused MoE implementation + tuning space definition.
aiter/configs/model_configs/qwen3_5_397b_fp8_ptpc_untuned_fmoe.csv New untuned shape list for Qwen3.5 fp8-ptpc MoE.
aiter/configs/model_configs/qwen3_5_397b_fp8_ptpc_tuned_fmoe.csv New tuned results selecting asmjit-AOT kernels for Qwen3.5.
aiter/configs/model_configs/hunyuan3_fp8_per_tensor_untuned_fmoe.csv New untuned shape list for Hunyuan3 fp8-per-tensor MoE.
aiter/configs/model_configs/hunyuan3_fp8_per_tensor_tuned_fmoe.csv New tuned results selecting asmjit-AOT kernels for Hunyuan3.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

ExtraType = ctypes.c_void_p * 5
kernel_args_size = ctypes.c_uint64(ctypes.sizeof(kernel_args))
kernel_config = ExtraType(
1, ctypes.addressof(kernel_args), 2, ctypes.addressof(kernel_args_size), 3
Copy link
Copy Markdown
Author

@tingqli tingqli May 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel_config = ExtraType(
1, ctypes.addressof(kernel_args), 2, ctypes.addressof(kernel_args_size), 3
)
stream = ctypes.cast(torch.cuda.current_stream(), ctypes.c_void_p)
Comment on lines +228 to +236
while len(gridDims) < 3:
gridDims.append(1)
while len(blockDims) < 3:
blockDims.append(1)
hip_check_error(
hip.hipModuleLaunchKernel(
p_func,
*gridDims,
*blockDims,
Comment on lines +124 to +126
dynamic_syms_raw = subprocess.check_output(
["/opt/rocm/llvm/bin/llvm-objdump", "--dynamic-syms", co_path]
).decode("utf-8")
Comment on lines +166 to +167
assert kernel_cnt.value > 0
kernels = (ctypes.c_void_p * kernel_cnt.value)()
Comment on lines +34 to +36
return cls(*[eval(p) for p in parts])


Comment on lines +26 to +27
from aiter.fused_moe_asmjit_aot import fused_moe_asmjit_aot
from aiter.fused_moe_asmjit_aot import get_tune_space
"--e2e_tune",
action="store_true",
required=False,
help="Run an extra round of e2e tuning after main tuning is done, using production-op benchmark as the indicator",
Comment thread aiter/fused_moe.py
Comment on lines +1151 to +1162
if kernelName1.startswith("fused_moe_asmjit_aot"):
from aiter.fused_moe_asmjit_aot import fused_moe_asmjit_aot

return MOEMetadata(
None,
None,
block_m,
ksplit,
run_1stage,
stage0=functools.partial(
fused_moe_asmjit_aot, config_string=kernelName1.split("__")[1]
),
E, N1, K1 = w1.shape
N2, K2 = w2.shape[1], w2.shape[2]
TOPK = topk_ids.shape[1]
fp8_ptpc = w1.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz) and (
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants