add pageattention with sliding window#3297
Open
ChengYao-amd wants to merge 1 commit into
Open
Conversation
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
698a944 to
de783a2
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
_paged_attention_kerneland_paged_attention_kernel_EXPERIMENTALalready implementsliding-window attention (SWA) at the token level — masking K/V positions that fall
outside
[context_len - sliding_window, context_len)inside the inner loop. The kernelstill launches one workgroup per partition (
T_PAR_SIZE = 256tokens) and runs thefull QK / softmax / KV-output GEMMs for every partition, even when an entire partition
lies before the window's lower bound and contributes nothing to the result.
For long contexts paired with a small SWA window (e.g.
ctx_len = 9216,sliding_window = 8192), the wasted partitions are a non-trivial fraction of thetotal work. This PR adds a partition-level early-out so those workgroups exit
before any matmul, while keeping the reduce kernel correct.
Technical Details
Kernel changes —
csrc/cpp_itfs/pa/pa_kernels.cuhBoth
_paged_attention_kerneland_paged_attention_kernel_EXPERIMENTALget a newif constexpr (SLIDING_WINDOW_ENABLED)block right after thewg_start_head_idx/total_num_headscomputation:The reduce kernel always iterates over every partition slot and computes
tmp_out[..] * shared_exp_sums[..], so leaving the skipped partition's outputsuninitialised would propagate NaN (
NaN * 0 = NaN). The early-out therefore stillwrites well-defined sentinels for every head/MTP slot it owns, then returns:
max_logits[seq, head, partition] = -FLT_MAXexp_sums[seq, head, partition] = 0.0ftmp_out[seq, head, partition, 0..HEAD_SIZE) = 0With these sentinels the reduce path's logsumexp and weighted sum naturally drop
the skipped partition without any branches on its side. The
if constexprgateensures non-SWA specializations compile to the original code path with zero
runtime overhead.
Tests —
op_tests/test_pa_v1_swa.py(new)Targets
torch.ops.aiter.paged_attention_v1. Components:run_torch_swa(...): from-scratch PyTorch reference that gathers paged K/V fromblock_tables, builds the fullQK^T, applies an explicitposition < max(0, seq_len - window_size)mask, then softmax and V product.run_aiter_pa_v1_swa(...): thin wrapper that allocates the workspace buffer,sets
max_num_partitionsfrom_PARTITION_SIZE_ROCM = 256, and dispatches thev1 op with the
sliding_windowargument (0= disabled).test_pa_v1_swa_correctness: parametrized overdtype ∈ {bf16, fp16},num_seqs ∈ {1, 8, 16},(num_q_heads, num_kv_heads) ∈ {(8,1), (16,4), (32,4)},head_size ∈ {64, 128},block_size ∈ {16, 32}, and(ctx_len, window)pairscovering:
window > ctx(no masking),window == ctx,ctx == window + 1(one masked position),
ctx >> windowto exercise many fully-skipped partitions(e.g.
(9216, 8192),(4097, 2048)), andwindow <= 0(disabled).test_pa_v1_swa_matches_full_causal_when_window_ge_context: equivalence check —when
window >= ctx_len, SWA output must match the full-causal baseline.Tolerances:
atol = 1e-2(bf16) /5e-3(fp16),rtol = 1e-2, with an additionalerror_ratio < 0.01gate viaaiter.test_common.checkAllclose.Test Plan
pytest op_tests/test_pa_v1_swa.py -v # Single-shape CLI for quick smoke checks: python op_tests/test_pa_v1_swa.py --ctx 9216 --window 8192 --num_seqs 8 python op_tests/test_pa_v1_swa.py --ctx 4096 --window 1024 --dtype fp16Test Result
Hardware: AMD Instinct MI300X (
gfx942:sramecc+:xnack-), ROCm, PyTorch 2.9.1.test_pa_v1_swa_correctness: 792 / 792 passed across all parametrized(dtype, num_seqs, num_heads, head_size, block_size, ctx_len, window_size)combinations, including the long-context / short-window shapes that exercise
the new partition-level early-out (
(ctx=9216, window=8192),(ctx=4097, window=2048),(ctx=2048, window=512)…).test_pa_v1_swa_matches_full_causal_when_window_ge_context: passed —SWA output with
window >= ctx_lenmatches the full-causal baseline withinatol = rtol = 1e-3.Submission Checklist