Skip to content

[Triton-Gluon-MLA-GFX950] add decode kernel for nhead<=16#3311

Open
Dewei-Wang-sh wants to merge 3 commits into
ROCm:mainfrom
Dewei-Wang-sh:add_mla_gluon_bh16
Open

[Triton-Gluon-MLA-GFX950] add decode kernel for nhead<=16#3311
Dewei-Wang-sh wants to merge 3 commits into
ROCm:mainfrom
Dewei-Wang-sh:add_mla_gluon_bh16

Conversation

@Dewei-Wang-sh
Copy link
Copy Markdown
Contributor

@Dewei-Wang-sh Dewei-Wang-sh commented May 22, 2026

Summary

  • Adds two new compile-time regimes to mla_decode_gluon for the single-query long-context decode shape (batch=1, nhead ≤ 16): bh16bn128 (BF16 Q + FP8 KV)
    and bh16bn64 (BF16 Q + BF16 KV). Existing bh64 regime (batch ∈ {64,128,256}, nhead ∈ {64,128}, BF16 KV) is unchanged.

Constraint and Usage

check the README

Perf (MI350, gfx950)

  • bh16bn64 (ctx=10M, BF16 KV): 5.33 TB/s wall time
  • bh16bn128 (ctx=10M, FP8 KV): 4.58 TB/s wall time

Test plan

  • python op_tests/test_mla.py -c 10000000 -b 1 -n 16,1 4,1 8,1 -d bf16 -kvd bf16 (bh16bn64)
  • python op_tests/test_mla.py -c 10000000 -b 1 -n 16,1 4,1 8,1 -d bf16 -kvd fp8 (bh16bn128)
  • checkAllclose passes against the golden reference for all of the above

Dewei-Wang-sh and others added 2 commits May 21, 2026 07:25
… + FP8 KV)

Adds support for nhead in (4, 8, 16), batch=1, fp8 KV to the existing
gluon MLA decode kernel, targeting the single-query long-context decode
regime (NUM_KV_SPLITS=256, kv_scale dequant, NHEAD<BLOCK_H masking).

Implemented as a constexpr REGIME ('bh64' | 'bh16bn128') gate inside
_mla_decode_gluon: Gluon layouts, BLOCK_H/BLOCK_N, grid mapping, and
launch parameters branch at compile time on REGIME, while the algorithm
skeleton (3-stage SW-pipelined async-copy MLA) is shared. The wrapper
dispatches by nhead. The stage-2 reduce kernel is unchanged.

Tested on MI350 (gfx950):
  python op_tests/test_mla.py -c 10000000 -b 1 -n 16,1 4,1 8,1 -d bf16 -kvd fp8
  python op_tests/test_mla.py -c 10000 -b 128 -n 128,1 -d bf16 -kvd bf16

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Dewei-Wang-sh Dewei-Wang-sh requested review from a team and Copilot May 22, 2026 07:18
@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3311 --add-label <label>

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds new mla_decode_gluon compile-time regimes targeting single-query, long-context decode on gfx950 for batch_size=1 and nhead<=16, including FP8-KV and BF16-KV variants, plus corresponding test harness updates and documentation.

Changes:

  • Extend mla_decode_gluon with two new regimes: bh16bn128 (BF16 Q + FP8 KV) and bh16bn64 (BF16 Q + BF16 KV), alongside existing bh64.
  • Update op_tests/test_mla.py to gate expensive O(N²) prefill reference paths and add decode runners/metrics for the new regimes.
  • Update Gluon README to document regime dispatch, constraints, and performance numbers.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

File Description
aiter/ops/triton/gluon/mla_decode_gluon.py Adds bh16bn{64,128} regimes, new grid mapping, masking for nhead < BLOCK_H, and kv_scale plumbing.
op_tests/test_mla.py Adds new decode test runners for bh16bn{64,128}, ref gating for large ctx prefill, and additional perf counters.
aiter/ops/triton/gluon/README.md Documents the three regimes, constraints, invocation examples, and perf tables.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread aiter/ops/triton/gluon/mla_decode_gluon.py
Comment thread aiter/ops/triton/gluon/mla_decode_gluon.py Outdated
Comment thread op_tests/test_mla.py Outdated
Comment thread aiter/ops/triton/gluon/README.md Outdated
Comment thread aiter/ops/triton/gluon/mla_decode_gluon.py
mla_decode_gluon.py: stage-1 split-KV switches to floor-with-remainder
so every split has num_iter >= 3; fold kv_scale into qk_scale (fp8
softmax correctness); bh16bn128 dtype assertions on q/kv; bh16 regime
dispatch rejects unsupported kv_c dtypes.

op_tests/test_mla.py: cal_diff(use_fp8=...) is now derived per-regime
(bh16bn128 only); tighten the absorb-prefill gate with ctx_lens <= 16384
to mirror the normal-prefill gate.

.github/scripts/aiter_test.sh: shard-0 extra invocations guard bh16bn64
(-c 49152 -b 1 -n 16,1 -kvd bf16) and bh16bn128 (-c 98304 -b 1 -n 16,1
-kvd fp8). Neither config is reachable from the test_mla.py default
sweep (caps ctxLen at 8192, never passes -kvd fp8).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Dewei-Wang-sh Dewei-Wang-sh self-assigned this May 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants