PA: route uniform shuffled-FP8 MTP decode (qlen 2/3/4) to pa_fwd_asm #3302
Open
vstakhov-amd wants to merge 1 commit into
vLLM-style MTP decode with bf16 Q and shuffled FP8 KV (value_cache.dim()==5)
previously took the generated HIP shuffled-MTP path. The ASM Mtp=1 generic
kernel pa_bf16_pertokenFp8_gqa{8,16}_1tg_4w_mtp_msk1 already handles qlen
2/3/4 via qo_indptr (asm_pa.cu encodes mtp = max_qlen + 10, msk = 1 when
qo_indptr != nullptr and max_qlen > 1), and is more accurate and faster than
the HIP path for this shape.
PagedAttention.forward_decode now routes to pa_fwd_asm when every gate holds:
mtp in {2,3,4}, gqa_ratio in {8,16}, gfx in {gfx942,gfx950}, block_size==16,
head_size==128, value_cache.dim()==5, num_seqs == batch*mtp, kv_cache_dtype
in {fp8, fp8_e4m3}, and q_scale / fp8_out_scale unset. gqa=10 deliberately
stays on HIP: asm_pa.cu:get_heuristic_kernel falls gqa=10 back to the gqa=16
kernel via gqa_flags={10,16}, which runs with mismatched args.GQA and
produces incorrect output.
The fast-path block is placed before the partition/tmp_output allocations so
the routed call does not pay for buffers it never uses.
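
The gate list above can be read as a single predicate. The following is a minimal sketch, not the code from this PR: the `DecodeShape` container and the `routes_to_asm` helper are hypothetical names introduced only for illustration, and the fields stand in for values that the real forward_decode derives from its tensor arguments.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DecodeShape:
    """Hypothetical bundle of the values the routing gate inspects."""
    mtp: int
    gqa_ratio: int
    gfx: str
    block_size: int
    head_size: int
    value_cache_ndim: int          # stands in for value_cache.dim()
    kv_cache_dtype: str
    num_seqs: int
    batch: int
    q_scale: Optional[float] = None
    fp8_out_scale: Optional[float] = None


def routes_to_asm(s: DecodeShape) -> bool:
    """True only when every gate from the description holds."""
    return (
        s.mtp in (2, 3, 4)
        and s.gqa_ratio in (8, 16)            # gqa=10 is excluded on purpose
        and s.gfx in ("gfx942", "gfx950")
        and s.block_size == 16
        and s.head_size == 128
        and s.value_cache_ndim == 5           # shuffled FP8 layout
        and s.num_seqs == s.batch * s.mtp     # uniform qlen per batch
        and s.kv_cache_dtype in ("fp8", "fp8_e4m3")
        and s.q_scale is None
        and s.fp8_out_scale is None
    )


base = DecodeShape(mtp=3, gqa_ratio=8, gfx="gfx942", block_size=16,
                   head_size=128, value_cache_ndim=5,
                   kv_cache_dtype="fp8", num_seqs=12, batch=4)
print(routes_to_asm(base))
```

Evaluating a predicate like this before any workspace allocation is what lets the routed call return early without paying for unused buffers.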
vLLM MTP decode (GLM-4.7-FP8 TP4, DeepSeek-style speculative decoding) calls PagedAttention.forward_decode with bf16 Q and shuffled FP8 KV cache (value_cache.dim() == 5) and mtp ∈ {2, 3, 4}. Today that lands on the generated HIP shuffled-MTP path, which is both slower and noticeably less accurate than the existing ASM kernel for this shape.

The ASM Mtp=1 msk=1 kernel (pa_bf16_pertokenFp8_gqa{8,16}_1tg_4w_mtp_msk1) is already a generic QTP-driven MTP kernel: asm_pa.cu encodes mtp = max_qlen + 10, msk = 1 when qo_indptr != nullptr && max_qlen > 1, and the kernel branches on QTP[b+1] - QTP[b] for qlen 2/3/4. The CSV manifests for both gfx942 and gfx950 already ship this kernel; there is nothing new to compile.
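
To make the index-pointer arithmetic concrete, here is a small pure-Python sketch under the assumption of a uniform qlen per batch. The helper names (`uniform_qo_indptr`, `qlens_from_indptr`, `encode_mtp_msk`) are hypothetical and are not functions in asm_pa.cu; the sketch only mirrors the two facts stated above, namely that per-batch qlen is the difference of adjacent pointer entries and that mtp is encoded as max_qlen + 10 with msk = 1 when an index pointer is present and max_qlen > 1. Other cases are not described in this PR, so the sketch returns None for them.

```python
def uniform_qo_indptr(batch, qlen):
    """List equivalent of arange(0, num_seqs + 1, qlen), num_seqs = batch * qlen."""
    num_seqs = batch * qlen
    return list(range(0, num_seqs + 1, qlen))


def qlens_from_indptr(qo_indptr):
    """Per-batch query length: QTP[b+1] - QTP[b]."""
    return [qo_indptr[b + 1] - qo_indptr[b] for b in range(len(qo_indptr) - 1)]


def encode_mtp_msk(qo_indptr, max_qlen):
    """Mirror of the encoding quoted in the text; None means 'not covered here'."""
    if qo_indptr is not None and max_qlen > 1:
        return max_qlen + 10, 1
    return None


indptr = uniform_qo_indptr(batch=4, qlen=3)
print(indptr)                      # [0, 3, 6, 9, 12]
print(qlens_from_indptr(indptr))   # [3, 3, 3, 3]
print(encode_mtp_msk(indptr, 3))   # (13, 1)
```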
Motivation
JIRA
Technical Details
PagedAttention.forward_decode now routes to ops.pa_fwd_asm when all of these hold:
- mtp ∈ {2, 3, 4}
- gqa_ratio ∈ {8, 16}
- get_gfx() ∈ {gfx942, gfx950}
- block_size == 16, head_size == 128
- value_cache.dim() == 5 (shuffled layout)
- kv_cache_dtype ∈ {fp8, fp8_e4m3}
- q_scale is None and fp8_out_scale is None
- num_seqs == batch * mtp (uniform qlen per batch)

On the fast path we synthesize qo_indptr = arange(0, num_seqs+1, mtp) and call pa_fwd_asm(..., block_tables.stride(0), mtp, k_scale, v_scale, output, qo_indptr). The fast-path block is placed before the partition/tmp_output/exp_sums/max_logits allocations so the routed call doesn't pay for buffers it never uses.

gqa_ratio == 10 deliberately stays on the HIP path: asm_pa.cu:get_heuristic_kernel falls gqa=10 back to the gqa=16 kernel via gqa_flags = {gqa, ceil(gqa/8)*8}, which then runs with mismatched args.GQA = 10 and produces incorrect output. Excluding it from the gate keeps the existing HIP behavior intact.
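
The gqa=10 exclusion follows from the rounding rule quoted above. As a worked illustration only (the function name `gqa_flag_candidates` is hypothetical, and this is plain arithmetic rather than the kernel-selection code), the pair {gqa, ceil(gqa/8)*8} can be computed for the three ratios discussed here:

```python
import math


def gqa_flag_candidates(gqa):
    """Return (exact, rounded_up) following gqa_flags = {gqa, ceil(gqa/8)*8}."""
    return gqa, math.ceil(gqa / 8) * 8


for gqa in (8, 10, 16):
    exact, rounded = gqa_flag_candidates(gqa)
    # A mismatch means the fallback kernel was built for a different ratio
    # than the one passed at launch.
    print(f"gqa={gqa}: candidates=({exact}, {rounded}), mismatch={exact != rounded}")
```

For 8 and 16 both candidates coincide, so the selected kernel matches the launch argument; for 10 the rounded candidate is 16, which is the mismatch the gate avoids.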
Test Plan
op_tests/test_mtp_routing.py (new) builds vLLM-shaped inputs (bf16 Q, per-token-quant FP8 KV, shuffled V) and verifies the routed forward_decode matches a torch reference within atol = rtol = 0.02 across:
- gqa ∈ {8, 16} × qlen ∈ {2, 3, 4} × ctx ∈ {128, 4097, 16384} (18 cases)
- gqa = 10 × qlen ∈ {2, 3, 4} × ctx ∈ {128, 4097} (6 cases)

Skips on non-gfx942/gfx950 archs.
Test Result
End-to-end: GLM-4.7-FP8 TP4 routed-to-ASM matches reference and is faster than the HIP fallback on both archs (perf
scratch run locally, not committed).
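
For reference, the case counts listed in the Test Plan can be reproduced by enumerating the two parameter grids. This is a sketch of the matrix only, assuming a simple Cartesian product; it is not taken from op_tests/test_mtp_routing.py, and the variable names are illustrative.

```python
from itertools import product

# ASM-routed grid: gqa x qlen x ctx
asm_cases = list(product((8, 16), (2, 3, 4), (128, 4097, 16384)))

# gqa=10 grid, which the gate leaves on the HIP path
hip_cases = list(product((10,), (2, 3, 4), (128, 4097)))

print(len(asm_cases), len(hip_cases), len(asm_cases) + len(hip_cases))
```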
Submission Checklist
op_tests/