[Triton-Gluon-MLA-GFX950] add decode kernel for nhead<=16#3311
Open
Dewei-Wang-sh wants to merge 3 commits into
… + FP8 KV)
Adds support for nhead in (4, 8, 16), batch=1, fp8 KV to the existing
gluon MLA decode kernel, targeting the single-query long-context decode
regime (NUM_KV_SPLITS=256, kv_scale dequant, NHEAD<BLOCK_H masking).
Implemented as a constexpr REGIME ('bh64' | 'bh16bn128') gate inside
_mla_decode_gluon: Gluon layouts, BLOCK_H/BLOCK_N, grid mapping, and
launch parameters branch at compile time on REGIME, while the algorithm
skeleton (3-stage SW-pipelined async-copy MLA) is shared. The wrapper
dispatches by nhead. The stage-2 reduce kernel is unchanged.
Tested on MI350 (gfx950):
python op_tests/test_mla.py -c 10000000 -b 1 -n 16,1 4,1 8,1 -d bf16 -kvd fp8
python op_tests/test_mla.py -c 10000 -b 128 -n 128,1 -d bf16 -kvd bf16
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
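The NHEAD<BLOCK_H masking mentioned in this commit can be pictured with a small pure-Python sketch. It assumes BLOCK_H = 16 for the new regime (inferred from the name bh16bn128, not confirmed in this thread), and head_mask / keep_valid_rows are hypothetical helpers, not functions from the kernel.

```python
# Illustrative only: plain Python standing in for the Gluon masking logic.
BLOCK_H = 16  # assumed head-tile height, inferred from the regime name


def head_mask(nhead, block_h=BLOCK_H):
    """True for tile rows that map to a real head, False for padding rows."""
    if not 0 < nhead <= block_h:
        raise ValueError("nhead must be in 1..block_h")
    return [h < nhead for h in range(block_h)]


def keep_valid_rows(tile, nhead):
    """Drop padding rows so only the first nhead rows would be stored."""
    mask = head_mask(nhead, len(tile))
    return [row for row, valid in zip(tile, mask) if valid]


# One dummy row per tile slot; only the first nhead rows are real heads.
tile = [[float(h)] * 2 for h in range(BLOCK_H)]
for nhead in (4, 8, 16):
    print(nhead, len(keep_valid_rows(tile, nhead)))
```

With nhead = 4, twelve of the sixteen tile rows are padding, so the mask matters most for the smallest head counts.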
Pull request overview
Adds new mla_decode_gluon compile-time regimes targeting single-query, long-context decode on gfx950 for batch_size=1 and nhead<=16, including FP8-KV and BF16-KV variants, plus corresponding test harness updates and documentation.
Changes:
- Extend `mla_decode_gluon` with two new regimes: `bh16bn128` (BF16 Q + FP8 KV) and `bh16bn64` (BF16 Q + BF16 KV), alongside existing `bh64`.
- Update `op_tests/test_mla.py` to gate expensive O(N²) prefill reference paths and add decode runners/metrics for the new regimes.
- Update Gluon README to document regime dispatch, constraints, and performance numbers.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `aiter/ops/triton/gluon/mla_decode_gluon.py` | Adds bh16bn{64,128} regimes, new grid mapping, masking for nhead < BLOCK_H, and kv_scale plumbing. |
| `op_tests/test_mla.py` | Adds new decode test runners for bh16bn{64,128}, ref gating for large ctx prefill, and additional perf counters. |
| `aiter/ops/triton/gluon/README.md` | Documents the three regimes, constraints, invocation examples, and perf tables. |
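As a rough picture of the regime dispatch described in this overview, the sketch below maps a (batch, nhead, KV dtype) shape to a regime name. The shape constraints are taken from the PR summary; `select_regime` and the string dtype labels are hypothetical and are not the wrapper's actual API.

```python
def select_regime(batch: int, nhead: int, kv_dtype: str) -> str:
    """Hypothetical dispatch helper; the real wrapper may differ in detail."""
    # Single-query long-context decode: batch=1, nhead <= 16.
    if batch == 1 and 0 < nhead <= 16:
        if kv_dtype == "fp8":
            return "bh16bn128"  # BF16 Q + FP8 KV
        if kv_dtype == "bf16":
            return "bh16bn64"   # BF16 Q + BF16 KV
        raise ValueError(f"unsupported KV dtype for bh16 regimes: {kv_dtype}")
    # Existing regime: larger batches and head counts, BF16 KV only.
    if batch in (64, 128, 256) and nhead in (64, 128) and kv_dtype == "bf16":
        return "bh64"
    raise ValueError(f"no regime for batch={batch}, nhead={nhead}, kv={kv_dtype}")


print(select_regime(1, 16, "fp8"))     # bh16bn128
print(select_regime(1, 4, "bf16"))     # bh16bn64
print(select_regime(128, 128, "bf16")) # bh64
```

Raising on unmatched shapes mirrors the note that the bh16 dispatch rejects unsupported kv_c dtypes; whether the real wrapper raises or falls back to another path is not stated here.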
mla_decode_gluon.py: stage-1 split-KV switches to floor-with-remainder so every split has num_iter >= 3; fold kv_scale into qk_scale (fp8 softmax correctness); bh16bn128 dtype assertions on q/kv; bh16 regime dispatch rejects unsupported kv_c dtypes.

op_tests/test_mla.py: cal_diff(use_fp8=...) is now derived per-regime (bh16bn128 only); tighten the absorb-prefill gate with ctx_lens <= 16384 to mirror the normal-prefill gate.

.github/scripts/aiter_test.sh: shard-0 extra invocations guard bh16bn64 (-c 49152 -b 1 -n 16,1 -kvd bf16) and bh16bn128 (-c 98304 -b 1 -n 16,1 -kvd fp8). Neither config is reachable from the test_mla.py default sweep (caps ctxLen at 8192, never passes -kvd fp8).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
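Two of the kernel changes in this commit can be illustrated in plain Python. The first function distributes KV blocks across splits by floor-with-remainder; the second pair checks that folding kv_scale into the softmax scale gives the same scores as dequantizing K first. All names (`split_kv_blocks`, `sm_scale`, the sample values) are hypothetical, and the minimum of 3 iterations presumably relates to the 3-stage pipeline, which is an inference rather than something the commit states.

```python
import math


def split_kv_blocks(num_blocks, num_splits):
    """Floor-with-remainder: the first `rem` splits get one extra block."""
    base, rem = divmod(num_blocks, num_splits)
    ranges, start = [], 0
    for i in range(num_splits):
        n = base + (1 if i < rem else 0)
        ranges.append((start, start + n))
        start += n
    return ranges


def scores_dequant_first(q, k_rows, kv_scale, sm_scale):
    """Dequantize each K element, then apply the softmax scale."""
    return [sm_scale * sum(qi * (ki * kv_scale) for qi, ki in zip(q, row))
            for row in k_rows]


def scores_folded(q, k_rows, kv_scale, sm_scale):
    """Fold kv_scale into a single qk_scale applied after the dot product."""
    qk_scale = sm_scale * kv_scale
    return [qk_scale * sum(qi * ki for qi, ki in zip(q, row))
            for row in k_rows]


# With at least 3 blocks per split on average, no split drops below 3.
ranges = split_kv_blocks(num_blocks=770, num_splits=256)
sizes = [end - start for start, end in ranges]
print(min(sizes), max(sizes), sum(sizes))

# Stand-in values for raw (already decoded) FP8 K entries.
q = [0.5, -1.0, 2.0]
k_rows = [[1.0, 2.0, -0.5], [0.25, 0.0, 1.5]]
a = scores_dequant_first(q, k_rows, kv_scale=0.125, sm_scale=0.1)
b = scores_folded(q, k_rows, kv_scale=0.125, sm_scale=0.1)
print(all(math.isclose(x, y, rel_tol=1e-9) for x, y in zip(a, b)))
```

The guarantee in the sketch only holds when num_blocks >= 3 * num_splits; how the kernel handles shorter contexts is not covered by this commit message.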
906358a to 759fd6d
Summary
Adds two new regimes to `mla_decode_gluon` for the single-query long-context decode shape (batch=1, nhead ≤ 16): `bh16bn128` (BF16 Q + FP8 KV) and `bh16bn64` (BF16 Q + BF16 KV). The existing `bh64` regime (batch ∈ {64,128,256}, nhead ∈ {64,128}, BF16 KV) is unchanged.
Constraint and Usage
See the README.
Perf (MI350, gfx950)
Test plan
- `python op_tests/test_mla.py -c 10000000 -b 1 -n 16,1 4,1 8,1 -d bf16 -kvd bf16` (bh16bn64)
- `python op_tests/test_mla.py -c 10000000 -b 1 -n 16,1 4,1 8,1 -d bf16 -kvd fp8` (bh16bn128)
- `checkAllclose` passes against the golden reference for all of the above