PA: route uniform shuffled-FP8 MTP decode (qlen 2/3/4) to pa_fwd_asm#3302

Open
vstakhov-amd wants to merge 1 commit into main from vstakhov/mtp_asm_route_shuffled_fp8

@vstakhov-amd

vLLM MTP decode (GLM-4.7-FP8 TP4, DeepSeek-style speculative decoding) calls PagedAttention.forward_decode with bf16
Q and shuffled FP8 KV cache (value_cache.dim() == 5) and mtp ∈ {2, 3, 4}. Today that lands on the generated HIP
shuffled-MTP path, which is both slower and noticeably less accurate than the existing ASM kernel for this shape.

The ASM Mtp=1 msk=1 kernel (pa_bf16_pertokenFp8_gqa{8,16}_1tg_4w_mtp_msk1) is already a generic QTP-driven MTP
kernel: asm_pa.cu encodes mtp = max_qlen + 10, msk = 1 when qo_indptr != nullptr && max_qlen > 1, and the
kernel branches on QTP[b+1] - QTP[b] for qlen 2/3/4. The CSV manifests for both gfx942 and gfx950 already ship
this kernel — there is nothing new to compile.

Motivation

JIRA

Technical Details

PagedAttention.forward_decode now routes to ops.pa_fwd_asm when all of these hold:

  • mtp ∈ {2, 3, 4}
  • gqa_ratio ∈ {8, 16}
  • get_gfx() ∈ {gfx942, gfx950}
  • block_size == 16, head_size == 128
  • value_cache.dim() == 5 (shuffled layout)
  • kv_cache_dtype ∈ {fp8, fp8_e4m3}
  • q_scale is None and fp8_out_scale is None
  • num_seqs == batch * mtp (uniform qlen per batch)
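
The gate above can be read as one predicate. The following is a minimal sketch, not the code in this PR: `DecodeShape` and `can_route_to_asm` are hypothetical names, and the fields stand in for values that the real forward_decode would derive from its tensors and from get_gfx().

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class DecodeShape:
    """Hypothetical bundle of the values the routing gate inspects."""
    mtp: int
    gqa_ratio: int
    gfx: str
    block_size: int
    head_size: int
    value_cache_ndim: int        # stands in for value_cache.dim()
    kv_cache_dtype: str
    num_seqs: int
    batch: int
    q_scale: Optional[object] = None
    fp8_out_scale: Optional[object] = None


def can_route_to_asm(s: DecodeShape) -> bool:
    """Return True only when every condition listed above holds."""
    return (
        s.mtp in (2, 3, 4)
        and s.gqa_ratio in (8, 16)
        and s.gfx in ("gfx942", "gfx950")
        and s.block_size == 16
        and s.head_size == 128
        and s.value_cache_ndim == 5              # shuffled layout
        and s.kv_cache_dtype in ("fp8", "fp8_e4m3")
        and s.q_scale is None
        and s.fp8_out_scale is None
        and s.num_seqs == s.batch * s.mtp        # uniform qlen per batch
    )


# Usage: a shape that passes, then the same shape with gqa_ratio == 10.
shape = DecodeShape(mtp=2, gqa_ratio=8, gfx="gfx942", block_size=16,
                    head_size=128, value_cache_ndim=5, kv_cache_dtype="fp8",
                    num_seqs=8, batch=4)
routed = can_route_to_asm(shape)
routed_gqa10 = can_route_to_asm(replace(shape, gqa_ratio=10))
```

Any single failing condition sends the call down the existing HIP path, so the gate is purely additive.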

On the fast path we synthesize qo_indptr = arange(0, num_seqs+1, mtp) and call
pa_fwd_asm(..., block_tables.stride(0), mtp, k_scale, v_scale, output, qo_indptr). The fast-path block is placed
before the partition / tmp_output / exp_sums / max_logits allocations, so the routed call does not pay for buffers
it never uses.
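
To make the synthesized qo_indptr concrete, here is a small pure-Python sketch of the same arithmetic. It uses `range` in place of a tensor `arange`, and `uniform_qo_indptr` and `qlens_from_indptr` are hypothetical helper names, not functions from this PR.

```python
from typing import List


def uniform_qo_indptr(num_seqs: int, mtp: int) -> List[int]:
    """Prefix offsets for a batch in which every request has qlen == mtp."""
    if mtp <= 0 or num_seqs % mtp != 0:
        raise ValueError("num_seqs must be a positive multiple of mtp")
    return list(range(0, num_seqs + 1, mtp))


def qlens_from_indptr(qo_indptr: List[int]) -> List[int]:
    """Per-request query length, i.e. indptr[b + 1] - indptr[b]."""
    return [hi - lo for lo, hi in zip(qo_indptr, qo_indptr[1:])]


# batch = 2 requests, mtp = 3 query tokens each, so num_seqs = 6
indptr = uniform_qo_indptr(6, 3)    # [0, 3, 6]
qlens = qlens_from_indptr(indptr)   # [3, 3]
```

The differences between consecutive offsets recover the per-request qlen, which is the quantity the kernel is described as branching on.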

gqa_ratio == 10 deliberately stays on the HIP path: asm_pa.cu:get_heuristic_kernel falls gqa=10 back to the
gqa=16 kernel via gqa_flags = {gqa, ceil(gqa/8)*8}, which then runs with mismatched args.GQA = 10 and produces
incorrect output. Excluding it from the gate keeps the existing HIP behavior intact.
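
The rounding behind that fallback is easy to check by hand. The sketch below reproduces only the expression quoted above, gqa_flags = {gqa, ceil(gqa/8)*8}; `gqa_candidates` is a hypothetical name, and the real lookup in asm_pa.cu is C++ and may differ in detail.

```python
import math
from typing import List


def gqa_candidates(gqa: int) -> List[int]:
    """Exact gqa first, then gqa rounded up to the next multiple of 8."""
    return [gqa, math.ceil(gqa / 8) * 8]


exact_8 = gqa_candidates(8)      # [8, 8]: exact match, no rounding
exact_16 = gqa_candidates(16)    # [16, 16]: exact match, no rounding
rounded_10 = gqa_candidates(10)  # [10, 16]: falls back to the gqa=16 kernel
```

For gqa 8 and 16 the rounded value equals the requested one, so the kernel and args.GQA agree; only gqa 10 ends up on a kernel built for a different ratio.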

Test Plan

op_tests/test_mtp_routing.py (new) builds vLLM-shaped inputs (bf16 Q, per-token-quant FP8 KV, shuffled V) and
verifies the routed forward_decode matches a torch reference within atol = rtol = 0.02 across:

  • ASM fast path: gqa ∈ {8, 16} × qlen ∈ {2, 3, 4} × ctx ∈ {128, 4097, 16384} (18 cases)
  • HIP fallback: gqa = 10 × qlen ∈ {2, 3, 4} × ctx ∈ {128, 4097} (6 cases)

Skips on non-gfx942/gfx950 archs.
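
The case counts follow from the cross products listed above. A minimal sketch of the enumeration, assuming nothing about how op_tests/test_mtp_routing.py actually parametrizes its tests (the tuple layout and variable names are illustrative only):

```python
from itertools import product

QLENS = (2, 3, 4)

# ASM fast path: gqa in {8, 16} x qlen in {2, 3, 4} x ctx in {128, 4097, 16384}
asm_cases = [("asm", gqa, qlen, ctx)
             for gqa, qlen, ctx in product((8, 16), QLENS, (128, 4097, 16384))]

# HIP fallback: gqa = 10 x qlen in {2, 3, 4} x ctx in {128, 4097}
hip_cases = [("hip", gqa, qlen, ctx)
             for gqa, qlen, ctx in product((10,), QLENS, (128, 4097))]

all_cases = asm_cases + hip_cases
counts = (len(asm_cases), len(hip_cases), len(all_cases))  # (18, 6, 24)
```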

Test Result

| Arch | Cases | Pass | Notes |
|---|---|---|---|
| gfx950 | 24 | 24 | ASM gqa∈{8,16} + HIP gqa=10 |
| gfx942 | 24 | 24 | ASM gqa∈{8,16} + HIP gqa=10 |

End-to-end: with GLM-4.7-FP8 TP4, the ASM-routed path matches the reference and is faster than the HIP fallback on
both archs. The perf comparison was a local scratch run and is not committed.

Submission Checklist

@vstakhov-amd vstakhov-amd requested review from a team, JohnNikolay84 and valarLip May 21, 2026 11:13
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|---|---|
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3302 --add-label <label>

@vstakhov-amd vstakhov-amd force-pushed the vstakhov/mtp_asm_route_shuffled_fp8 branch from 6e81f70 to 1ad0061 Compare May 21, 2026 11:15
vLLM-style MTP decode with bf16 Q and shuffled FP8 KV (value_cache.dim()==5)
previously took the generated HIP shuffled-MTP path. The ASM Mtp=1 generic
kernel pa_bf16_pertokenFp8_gqa{8,16}_1tg_4w_mtp_msk1 already handles qlen
2/3/4 via qo_indptr (asm_pa.cu encodes mtp = max_qlen + 10, msk = 1 when
qo_indptr != nullptr and max_qlen > 1), and is more accurate and faster than
the HIP path for this shape.

PagedAttention.forward_decode now routes to pa_fwd_asm when every gate holds:
mtp in {2,3,4}, gqa_ratio in {8,16}, gfx in {gfx942,gfx950}, block_size==16,
head_size==128, value_cache.dim()==5, num_seqs == batch*mtp, kv_cache_dtype
in {fp8, fp8_e4m3}, and q_scale / fp8_out_scale unset. gqa=10 deliberately
stays on HIP: asm_pa.cu:get_heuristic_kernel falls gqa=10 back to the gqa=16
kernel via gqa_flags={10,16}, which runs with mismatched args.GQA and
produces incorrect output.

The fast-path block is placed before the partition/tmp_output allocations so
the routed call does not pay for buffers it never uses.
@vstakhov-amd vstakhov-amd force-pushed the vstakhov/mtp_asm_route_shuffled_fp8 branch from 1ad0061 to a1ca7de Compare May 21, 2026 11:16