Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions csrc/cpp_itfs/pa/pa_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,65 @@ _paged_attention_kernel(const int* block_table_seq,
const int wg_start_kv_head_idx = kv_head_idx;
const int total_num_heads = gridDim.z * GQA_RATIO;

// SWA partition-level early-out: if the entire partition lies before the
// sliding-window lower bound, skip the GEMMs entirely. We still must write
// sentinel max_logits = -FLT_MAX and exp_sums = 0 plus tmp_out = 0 for the
// skipped partition, because the reduce kernel always reads every partition
// slot and computes `tmp_out[...] * shared_exp_sums[...]` -- uninitialised
// memory would propagate as NaN (NaN * 0 = NaN).
//
// NOTE(review): kv_lo is derived from context_len alone and is shared by every
// mtp iteration below -- confirm the same lower bound is valid for all MTP
// query tokens handled by this workgroup.
if constexpr (SLIDING_WINDOW_ENABLED) {
    // First KV token index that is still inside the sliding window.
    const int kv_lo = context_len - sliding_window;
    // kv_lo <= 0 means the window covers the whole context: nothing to skip.
    if (kv_lo > 0 && partition_start_token_idx + T_PAR_SIZE <= kv_lo) {
        // (a) sentinel writes for max_logits and exp_sums
        if (threadIdx.x < GQA_RATIO_MTP_PARALLEL) {
            for (int mtp = 0; mtp < MTP_PER_THREAD; mtp++) {
                for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP;
                     gqa_ratio_loop++) {
                    const int qhead_idx =
                        lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP;
                    // Flat index [seq, head, partition]; every factor is
                    // widened to 64 bits before multiplying.
                    const int64_t offset =
                        static_cast<int64_t>(seq_idx + mtp * MTP_PARALLEL_THREADS) *
                            static_cast<int64_t>(total_num_heads) *
                            static_cast<int64_t>(max_num_partitions) +
                        (static_cast<int64_t>(wg_start_head_idx) +
                         static_cast<int64_t>(qhead_idx)) *
                            static_cast<int64_t>(max_num_partitions) +
                        static_cast<int64_t>(partition_idx);
                    max_logits[offset] = -FLT_MAX;
                    exp_sums[offset]   = 0.0f;
                }
            }
        }
        // (b) zero-fill tmp_out for this partition's heads. The reduce
        //     kernel always loads tmp_out[seq, head, partition, :HEAD_SIZE]
        //     for every partition slot regardless of exp_sum.
        // Widen each operand BEFORE multiplying so the per-head stride is
        // computed in 64-bit arithmetic, matching the offset math above
        // (previously the product was formed in 32-bit int and only then cast).
        const int64_t hsz_maxp_mult =
            static_cast<int64_t>(HEAD_SIZE) *
            static_cast<int64_t>(max_num_partitions);
        for (int mtp = 0; mtp < MTP_PER_THREAD; mtp++) {
            // Base of [seq, head = 0, partition_idx, 0] for this mtp slot.
            scalar_t* out_ptr =
                out +
                static_cast<int64_t>(seq_idx + mtp * MTP_PARALLEL_THREADS) *
                    static_cast<int64_t>(total_num_heads) * hsz_maxp_mult +
                static_cast<int64_t>(partition_idx) *
                    static_cast<int64_t>(HEAD_SIZE);
            // NOTE(review): this zeroes GQA_RATIO_LOOP * GQA_RATIO_PER_LOOP
            // heads starting at wg_start_head_idx with no upper bound on the
            // head index -- confirm that product never exceeds the number of
            // query heads owned by this workgroup.
            for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP;
                 gqa_ratio_loop++) {
                for (int h = 0; h < GQA_RATIO_PER_LOOP; h++) {
                    const int64_t out_head_idx = static_cast<int64_t>(
                        wg_start_head_idx + h +
                        gqa_ratio_loop * GQA_RATIO_PER_LOOP);
                    scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult;
                    // All threads of the block cooperate on one head row.
                    for (int idx = threadIdx.x; idx < HEAD_SIZE;
                         idx += NUM_THREADS) {
                        out_ptr2[idx] = from_float<scalar_t>(0.0f);
                    }
                }
            }
        }
        // Partition fully outside the window: nothing else to compute.
        return;
    }
}

/// NOTICE: We don't support mask for this kernel, so just use a placeholder type/object here.
using Mask = ck_tile::SimplifiedGenericAttentionMask</*IsMasking=*/false>;
const Mask mask{/*seqlen_q=*/1, /*seqlen_k=*/context_len};
Expand Down Expand Up @@ -1212,6 +1271,59 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL(
const int wg_start_kv_head_idx = kv_head_idx;
const int total_num_heads = gridDim.z * GQA_RATIO;

// SWA partition-level early-out -- see _paged_attention_kernel for full
// explanation. Mirrors the non-experimental kernel: writes sentinels
// (max_logits=-FLT_MAX, exp_sums=0) and zeros tmp_out for the skipped
// partition so the reduce kernel does not propagate NaN.
//
// NOTE(review): kv_lo depends only on context_len and is reused for every mtp
// iteration -- confirm the bound holds for all MTP query tokens.
if constexpr (SLIDING_WINDOW_ENABLED) {
    // First KV token index that is still inside the sliding window.
    const int kv_lo = context_len - sliding_window;
    // Skip only when a real lower bound exists (kv_lo > 0) and the whole
    // partition ends at or before it.
    if (kv_lo > 0 && partition_start_token_idx + T_PAR_SIZE <= kv_lo) {
        // Sentinel writes for max_logits and exp_sums.
        if (threadIdx.x < GQA_RATIO_MTP_PARALLEL) {
            for (int mtp = 0; mtp < MTP_PER_THREAD; mtp++) {
                for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP;
                     gqa_ratio_loop++) {
                    const int qhead_idx =
                        lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP;
                    // Flat index [seq, head, partition], computed in 64 bits.
                    const int64_t offset =
                        static_cast<int64_t>(seq_idx + mtp * MTP_PARALLEL_THREADS) *
                            static_cast<int64_t>(total_num_heads) *
                            static_cast<int64_t>(max_num_partitions) +
                        (static_cast<int64_t>(wg_start_head_idx) +
                         static_cast<int64_t>(qhead_idx)) *
                            static_cast<int64_t>(max_num_partitions) +
                        static_cast<int64_t>(partition_idx);
                    max_logits[offset] = -FLT_MAX;
                    exp_sums[offset]   = 0.0f;
                }
            }
        }
        // Zero-fill the output rows of the skipped partition.
        // Each operand is widened BEFORE the multiply so the per-head stride
        // is formed in 64-bit arithmetic, consistent with the offset math
        // above (previously the product was computed in 32-bit int first).
        const int64_t hsz_maxp_mult =
            static_cast<int64_t>(HEAD_SIZE) *
            static_cast<int64_t>(max_num_partitions);
        for (int mtp = 0; mtp < MTP_PER_THREAD; mtp++) {
            // Base of [seq, head = 0, partition_idx, 0] for this mtp slot.
            scalar_t* out_ptr =
                out +
                static_cast<int64_t>(seq_idx + mtp * MTP_PARALLEL_THREADS) *
                    static_cast<int64_t>(total_num_heads) * hsz_maxp_mult +
                static_cast<int64_t>(partition_idx) *
                    static_cast<int64_t>(HEAD_SIZE);
            // NOTE(review): GQA_RATIO_LOOP * GQA_RATIO_PER_LOOP heads are
            // zeroed from wg_start_head_idx with no bound on the head index --
            // confirm this never reaches heads owned by another workgroup.
            for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP;
                 gqa_ratio_loop++) {
                for (int h = 0; h < GQA_RATIO_PER_LOOP; h++) {
                    const int64_t out_head_idx = static_cast<int64_t>(
                        wg_start_head_idx + h +
                        gqa_ratio_loop * GQA_RATIO_PER_LOOP);
                    scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult;
                    // All threads of the block cooperate on one head row.
                    for (int idx = threadIdx.x; idx < HEAD_SIZE;
                         idx += NUM_THREADS) {
                        out_ptr2[idx] = from_float<scalar_t>(0.0f);
                    }
                }
            }
        }
        // Partition fully outside the window: nothing else to compute.
        return;
    }
}

// HEAD_SIZE=128, cache_t=bf16, blockSize 16/64/256
constexpr int BYTES_PER_WARP_FETCH = WARP_SIZE * 16; // 1024 bytes
constexpr int TOKEN_PER_WARP_FETCH =
Expand Down
Loading
Loading