
add pageattention with sliding window #3297

Open
ChengYao-amd wants to merge 1 commit into ROCm:main from ChengYao-amd:dev/yaoc/add-swa-isa-kernel

@ChengYao-amd

Motivation

_paged_attention_kernel and _paged_attention_kernel_EXPERIMENTAL already implement
sliding-window attention (SWA) at the token level — masking K/V positions that fall
outside [context_len - sliding_window, context_len) inside the inner loop. The kernel
still launches one workgroup per partition (T_PAR_SIZE = 256 tokens) and runs the
full QK / softmax / KV-output GEMMs for every partition, even when an entire partition
lies before the window's lower bound and contributes nothing to the result.
For long contexts paired with a small SWA window, the wasted partitions are a
non-trivial fraction of the total work: with ctx_len = 9216 and
sliding_window = 8192, the first 1024 tokens (4 of 36 partitions) lie entirely
before the window. This PR adds a partition-level early-out so those workgroups
exit before any matmul, while keeping the reduce kernel correct.

Technical Details

Kernel changes — csrc/cpp_itfs/pa/pa_kernels.cuh

Both _paged_attention_kernel and _paged_attention_kernel_EXPERIMENTAL get a new
if constexpr (SLIDING_WINDOW_ENABLED) block right after the wg_start_head_idx /
total_num_heads computation:

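if constexpr (SLIDING_WINDOW_ENABLED) {
    // First K/V position that is still inside the window.
    const int kv_lo = context_len - sliding_window;
    // This partition covers [partition_start_token_idx, partition_start_token_idx + T_PAR_SIZE).
    if (kv_lo > 0 && partition_start_token_idx + T_PAR_SIZE <= kv_lo) {
        // Partition is entirely outside the SWA window: write the sentinels
        // described below and return before any GEMM.
    }
}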
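To make the skip condition concrete, the following Python sketch replays the same inequality on the host and counts how many partitions the early-out would skip for a given (ctx_len, sliding_window) pair. The helper name count_skipped_partitions is made up for illustration and is not part of this PR; it assumes partitions are laid out contiguously from token 0 in chunks of T_PAR_SIZE = 256, and it models the compile-time SLIDING_WINDOW_ENABLED gate with a runtime check on sliding_window.

```python
T_PAR_SIZE = 256  # partition size in tokens, as stated in the PR


def count_skipped_partitions(context_len: int, sliding_window: int,
                             par_size: int = T_PAR_SIZE) -> tuple[int, int]:
    """Return (skipped, total) partitions for one sequence.

    Hypothetical helper, not part of the PR; it only replays the
    kernel's early-out inequality on the host.
    """
    total = (context_len + par_size - 1) // par_size  # ceil division
    kv_lo = context_len - sliding_window  # first position inside the window
    if sliding_window <= 0 or kv_lo <= 0:
        # SWA disabled (assumed to mirror the compile-time gate) or the
        # window already covers the whole context: nothing is skipped.
        return 0, total
    skipped = sum(
        1
        for p in range(total)
        if p * par_size + par_size <= kv_lo  # partition ends at or before kv_lo
    )
    return skipped, total


for ctx, win in [(9216, 8192), (4097, 2048), (2048, 512)]:
    skipped, total = count_skipped_partitions(ctx, win)
    print(f"ctx={ctx} window={win}: {skipped}/{total} partitions skipped")
```

For the three shapes above this reports 4 of 36, 8 of 17, and 6 of 8 partitions skipped, so the benefit grows as the window shrinks relative to the context.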

The reduce kernel always iterates over every partition slot and computes
tmp_out[..] * shared_exp_sums[..], so leaving the skipped partition's outputs
uninitialised would propagate NaN (NaN * 0 = NaN). The early-out therefore still
writes well-defined sentinels for every head/MTP slot it owns, then returns:

  1. max_logits[seq, head, partition] = -FLT_MAX
  2. exp_sums[seq, head, partition] = 0.0f
  3. tmp_out[seq, head, partition, 0..HEAD_SIZE) = 0

With these sentinels the reduce path's logsumexp and weighted sum naturally drop
the skipped partition without any branches on its side. The if constexpr gate
ensures non-SWA specializations compile to the original code path with zero
runtime overhead.
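The effect of the sentinels can be checked with a small host-side model of the reduce step. This is a sketch under an assumption: the reduce is taken to rescale each partition's exp sum to the global max logit and then form a weighted average of the per-partition outputs, which is the usual split-softmax recombination; the actual kernel code is not reproduced here, and reduce_partitions is a hypothetical name.

```python
import math

FLT_MAX = 3.4028234663852886e38  # largest finite float32 value


def reduce_partitions(max_logits, exp_sums, tmp_out):
    """Combine per-partition partial results for one (seq, head) pair.

    Assumed form of the reduce step (not copied from the kernel):
    rescale each partition's exp sum to the global max, then take the
    weighted average of the per-partition outputs.
    """
    global_max = max(max_logits)
    # Rescaled weights; exp() underflows to exactly 0.0 for -FLT_MAX.
    weights = [s * math.exp(m - global_max)
               for m, s in zip(max_logits, exp_sums)]
    denom = sum(weights)
    head_size = len(tmp_out[0])
    return [
        sum(tmp_out[p][d] * weights[p] for p in range(len(weights))) / denom
        for d in range(head_size)
    ]


# Two live partitions, each with a normalised partial output.
live_max, live_sum = [1.5, 2.0], [3.0, 4.0]
live_out = [[0.1, 0.2], [0.3, 0.4]]
baseline = reduce_partitions(live_max, live_sum, live_out)

# Prepend a skipped partition carrying the sentinels from the PR.
with_skip = reduce_partitions([-FLT_MAX] + live_max,
                              [0.0] + live_sum,
                              [[0.0, 0.0]] + live_out)

# Same skipped slot, but tmp_out left "uninitialised" (modelled as NaN).
poisoned = reduce_partitions([-FLT_MAX] + live_max,
                             [0.0] + live_sum,
                             [[math.nan, math.nan]] + live_out)

print(baseline, with_skip, poisoned)
```

In this model the sentinel partition gets weight 0 and leaves the result unchanged, whereas a NaN in tmp_out survives the multiplication by 0 and poisons every output element. That is why the early-out zeroes tmp_out rather than relying on exp_sums = 0 alone.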

Tests — op_tests/test_pa_v1_swa.py (new)

Targets torch.ops.aiter.paged_attention_v1. Components:

  • run_torch_swa(...): from-scratch PyTorch reference that gathers paged K/V from
    block_tables, builds the full QK^T, applies an explicit
    position < max(0, seq_len - window_size) mask, then softmax and V product.
  • run_aiter_pa_v1_swa(...): thin wrapper that allocates the workspace buffer,
    sets max_num_partitions from _PARTITION_SIZE_ROCM = 256, and dispatches the
    v1 op with the sliding_window argument (0 = disabled).
  • test_pa_v1_swa_correctness: parametrized over dtype ∈ {bf16, fp16},
    num_seqs ∈ {1, 8, 16}, (num_q_heads, num_kv_heads) ∈ {(8,1), (16,4), (32,4)},
    head_size ∈ {64, 128}, block_size ∈ {16, 32}, and (ctx_len, window) pairs
    covering: window > ctx (no masking), window == ctx, ctx == window + 1
    (one masked position), ctx >> window to exercise many fully-skipped partitions
    (e.g. (9216, 8192), (4097, 2048)), and window <= 0 (disabled).
  • test_pa_v1_swa_matches_full_causal_when_window_ge_context: equivalence check —
    when window >= ctx_len, SWA output must match the full-causal baseline.

Tolerances: atol = 1e-2 (bf16) / 5e-3 (fp16), rtol = 1e-2, with an additional
error_ratio < 0.01 gate via aiter.test_common.checkAllclose.
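As a rough illustration of the masking rule the reference applies, here is a single-query, pure-Python stand-in. It is not the run_torch_swa implementation from the PR: it skips the paged K/V gather and batching, and the function name swa_attention and the toy inputs are invented for the example. It follows the stated rule that positions below max(0, seq_len - window_size) are masked and that window_size <= 0 disables the window.

```python
import math


def swa_attention(q, keys, values, window_size, scale=None):
    """Single-query attention over `keys`/`values` with a sliding window.

    Illustrative stand-in for the PyTorch reference described above;
    positions below max(0, seq_len - window_size) are masked out, and
    window_size <= 0 disables the window.
    """
    seq_len = len(keys)
    if scale is None:
        scale = 1.0 / math.sqrt(len(q))
    lo = max(0, seq_len - window_size) if window_size > 0 else 0
    # Scaled dot products for the positions that survive the mask.
    logits = [scale * sum(a * b for a, b in zip(q, keys[t]))
              for t in range(lo, seq_len)]
    m = max(logits)  # subtract the max for a numerically stable softmax
    probs = [math.exp(x - m) for x in logits]
    total = sum(probs)
    head_size = len(values[0])
    return [
        sum(p * values[lo + i][d] for i, p in enumerate(probs)) / total
        for d in range(head_size)
    ]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

full = swa_attention(q, keys, values, window_size=0)      # window disabled
wide = swa_attention(q, keys, values, window_size=8)      # window > ctx
last_two = swa_attention(q, keys, values, window_size=2)  # masks positions 0, 1
print(full, wide, last_two)
```

The window > ctx and disabled cases agree with each other, which is the same property the equivalence test checks against the full-causal baseline, while window_size=2 attends only to the last two positions.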

Test Plan

pytest op_tests/test_pa_v1_swa.py -v
# Single-shape CLI for quick smoke checks:
python op_tests/test_pa_v1_swa.py --ctx 9216 --window 8192 --num_seqs 8
python op_tests/test_pa_v1_swa.py --ctx 4096 --window 1024 --dtype fp16

Test Result

Hardware: AMD Instinct MI300X (gfx942:sramecc+:xnack-), ROCm, PyTorch 2.9.1.

============================= test session starts ==============================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0
collected 793 items
op_tests/test_pa_v1_swa.py ...........................................  [100%]
======================= 793 passed in 238.08s (0:03:58) ========================
  • test_pa_v1_swa_correctness: 792 / 792 passed across all parametrized
    (dtype, num_seqs, num_heads, head_size, block_size, ctx_len, window_size)
    combinations, including the long-context / short-window shapes that exercise
    the new partition-level early-out ((ctx=9216, window=8192), (ctx=4097, window=2048), (ctx=2048, window=512) …).
  • test_pa_v1_swa_matches_full_causal_when_window_ge_context: passed —
    SWA output with window >= ctx_len matches the full-causal baseline within
    atol = rtol = 1e-3.
  • 0 failed, 0 errored.
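The reported count can be sanity-checked against the parametrization listed in the test description. The sketch below assumes the correctness test is a full cross product of the listed values, which the PR text suggests but does not state outright. Under that assumption, 792 cases would correspond to 11 (ctx_len, window) pairs.

```python
from itertools import product

# Parameter values as listed in the test description.
dtypes = ["bf16", "fp16"]
num_seqs = [1, 8, 16]
head_configs = [(8, 1), (16, 4), (32, 4)]
head_sizes = [64, 128]
block_sizes = [16, 32]

# Number of combinations per (ctx_len, window) pair, assuming a full
# cross product of the parameters above.
combos_per_pair = len(list(product(dtypes, num_seqs, head_configs,
                                   head_sizes, block_sizes)))

reported = 792  # test_pa_v1_swa_correctness cases in the pytest log
pairs, remainder = divmod(reported, combos_per_pair)
print(combos_per_pair, pairs, remainder)
```

Together with the single equivalence test, this accounts for the 793 collected items.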

Submission Checklist

@ChengYao-amd ChengYao-amd requested a review from a team May 21, 2026 05:08
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label            Tests
ci:triton-300x   Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang        SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom          ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full     ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm          vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all           All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3297 --add-label <label>

@ChengYao-amd ChengYao-amd force-pushed the dev/yaoc/add-swa-isa-kernel branch from 698a944 to de783a2 Compare May 21, 2026 05:39
