Skip to content

[CK] Add new FMHA batch prefill kernel with FP8 per-tensor and per-block KV quantization on gfx950#6054

Open
poyenc wants to merge 63 commits into
developfrom
users/poyenc/ck/batch-prefill-v3
Open

[CK] Add new FMHA batch prefill kernel with FP8 per-tensor and per-block KV quantization on gfx950#6054
poyenc wants to merge 63 commits into
developfrom
users/poyenc/ck/batch-prefill-v3

Conversation

@poyenc
Copy link
Copy Markdown
Contributor

@poyenc poyenc commented Mar 31, 2026

Summary

V3 FMHA pipeline for batch prefill with paged KV cache on gfx950 (FP8 only). Uses the same 8-warp, 256x64 tile, double-buffered LDS architecture as the contiguous fmha_fwd V3 kernel, extended with scatter-gather page table support.

Supported configurations:

  • FP8 per-tensor and KV block-scale quantization
  • SGLang (1D) and vLLM (2D) page tables
  • LINEAR KV layout only (VECTORIZED falls back to V2 via trait matching)
  • No mask / causal mask
  • Page sizes: 1, 16, 1024

Performance vs V2 (FP8 batch prefill, paged KV, stock ROCm 7.1.1, avg runs 2–6, MI355X):

Per-tensor (ps=1, sglang):

Problem V2 TFlops V3 TFlops V3/V2
h=6/1 sq=1k c 43.1 47.7 1.11x
h=6/1 sq=2k c 95.4 120.1 1.26x
h=6/1 sq=4k c 197.5 259.5 1.31x
h=6/1 sq=8k c 396.3 549.4 1.39x
h=6/1 sq=16k c 778.7 1066.1 1.37x
h=6/1 sq=32k c 1000.2 1282.3 1.28x
h=6/1 sq=65k c 1045.0 1375.5 1.32x
h=6/1 sq=131k c 1072.9 1390.2 1.30x
h=16/1 sq=65k c 1072.3 1381.8 1.29x
h=40/40 sq=37k nc 930.1 1304.3 1.40x

KV block-scale (ps=1024, sglang):

Problem V2 TFlops V3 TFlops V3/V2
h=6/1 sq=1k c 42.6 45.2 1.06x
h=6/1 sq=2k c 95.2 111.6 1.17x
h=6/1 sq=4k c 197.8 241.1 1.22x
h=6/1 sq=8k c 389.9 509.7 1.31x
h=6/1 sq=16k c 745.6 1013.1 1.36x
h=6/1 sq=32k c 1016.1 1205.4 1.19x
h=6/1 sq=65k c 1057.0 1313.5 1.24x
h=6/1 sq=131k c 1069.6 1330.3 1.24x
h=16/1 sq=65k c 1084.4 1324.2 1.22x
h=40/40 sq=37k nc 945.0 1251.6 1.32x

Key design decisions

Branchless page_id clamping — The scatter-gather path bypasses buffer descriptor NUM_RECORDS protection, so auto-advancing past seqlen_k causes XNACK faults. Solved with min(page_id, max_page_table_idx) instead of branch guards, which would fragment sched_group_barrier-scheduled basic blocks.

Page advance issue/consume split — Page table lookups (global_load_dword) are issued BEFORE buffer_loads so they sit oldest in the vmcnt FIFO. At consume, s_waitcnt(N) drains only the page lookup while keeping N buffer_loads in flight. sched_barrier(0) prevents reordering.

KV block-scale k_descale fold — k_descale is folded into the scalar row_max via the FP8 shift trick, eliminating a full-tile VALU pass. v_descale is merged into the softmax rescale factor (o_acc_scale), with the final correction applied in the epilogue.

No no-packed-fp32-ops kernel attribute — V3 does NOT use kernel_attr<true> (which adds target("no-packed-fp32-ops")) despite it giving a +2–4% benefit on per-tensor workloads. The attribute conflicts with explicit v_pk_mul_f32 inline asm used in the KV block-scale descale path (pk_mul_f32 helper): when the asm is inlined into the attributed kernel entry, the assembler rejects the instruction. Benchmarks show the attribute is neutral-to-positive for block-scale (removing it actually gives +1–3%), so we accept the 2–4% per-tensor trade-off to avoid the asm conflict. A future per-variant solution (attribute on for per-tensor, off for block-scale) could recover this, but the codegen currently uses a single F_kernel_attr per kernel.

New files

File Description
fmha_batch_prefill_v3_kernel.hpp V3 kernel with paged KV DRAM views, page table Kargs, batch/group mode
block_fmha_batch_prefill_v3_pipeline.hpp V3 pipeline with scatter-gather loads, KV block-scale support
block_fmha_batch_prefill_v3_pipeline_default_policy.hpp FP8 LDS descriptors, SRD rebasing for large page sizes
block_fmha_fwd_v3_detail.hpp Shared V3 helpers: CoreLoopSchedulingParams, VALU intrinsics, macros

Codegen changes (fmha_batch_prefill.py)

Refactored to support per-architecture kernel generation (matching fmha_fwd.py patterns):

  • ArchTrait / factory hierarchy: gfx9 -> gfx950 with arch-specific tiles, pipelines, and compatibility rules
  • Separate fmha_batch_prefill_v2() / fmha_batch_prefill_v3() dispatch functions (V3 tried first, falls back to V2)
  • 80 V3 kernels generated for gfx950

Bug fix in fmha_fwd.py

Fixed missing return in check_hdim — bare False was a no-op, causing unnecessary bias/dropout kernels for hdim=(192,128).

Dependencies

Based on #4437 (users/poyenc/fa-v3-fp8-pertensor). That PR must be merged first.

Test plan

  • Full test_batch_prefill.py suite

poyenc added 30 commits March 13, 2026 02:35
Add FP8BF16 per-tensor quantization path to the FMHA forward V3
pipeline on gfx950. This includes:

- FP8 32x32x32 warp gemm with C-transposed distribution
- FP8 warp gemm dispatcher entries
- V3 kernel support for per-tensor descale (q/k/v descale pointers)
- V3 pipeline FP8 data path with asm volatile for P conversion
- FP8 instruction scheduling optimization in CoreLoopScheduler
- Codegen: FP8BF16 V3 tile size (256x64x128) and pipeline variants
- Codegen: V3 dispatch condition extended for fp8bf16+pertensor
- LLVM scheduler TRANS mask for scheduling control
- Fix mask_info default initialization for no_mask case

Note: V3 dispatch is disabled by default pending further validation.
Remove debug macros (ENABLE_DEBUG_STMTS, DEBUG_STMTS, WARP_ID, LANE_ID),
debug lambdas (print_dist_tensor, print_lds, print_lds_1d), unused LDS
windows (s/p/o/m_lds_window), their helper methods (MakeSimpleLdsDesc,
MakeSimpleLdsDesc1D), and unused KPack variables in the policy file.

Assembly verified identical (sched_diff=0), 176/176 fp8 tests pass.
The P matrix (attention weights) lives entirely in registers (VGPRs)
via sp_compute_type, not in LDS. Remove the P buffer terms from
GetSmemSize() so it reports only the actual KV buffer usage.

Assembly-verified: before/after diff shows identical GPU and host code.
- Separate smem_k[2]/smem_v[2] pointers for explicit buffer control
- Use async_load_tile_raw / init_raw for raw async copies
- Remove dead P buffer from LDS size calculation
- Reformat operator() signatures for readability
Revert changes from debug commit that swapped NumWarps and LaneGroups
in MakeVDramTileDistribution(), MakeVLdsStoreBlockDescriptor(), and
MakeVLdsLoadBlockDescriptor(). These were unrelated to the 4-buffer
LDS architecture refactor.

Restores the original dimension ordering:
- MakeVDramTileDistribution: N1=LaneGroups, N2=NumWarps
- MakeVLdsStoreBlockDescriptor: shape (NumIssues, LaneGroups, NumWarps, ...)
- MakeVLdsLoadBlockDescriptor: merge sequence<0, 2, 1> for correct reorder

Testing: 176/176 FP8 MHA tests pass
Remove fine-grained can_dispatch_v3 runtime guard. Try V3 first when
enabled; unsupported configs return -1 and fall back to V2.
…ution

Remove duplicate plain using definitions of WarpGemmMfma_f32_32x32x32_fp8_fp8,
WarpGemmMfma_f32_32x32x32_bf8_bf8 that conflicted with the templated
#if gfx950/#else versions, and deduplicate the corresponding dispatcher entry.
…rning

Drop the epilogue shared-memory buffer and smem_ptr parameter that were
left over after prior refactoring, and silence the -Wunreachable-code
diagnostic in the V3/V2 dispatch fallback.
The bare `False` statement was a no-op, causing bias/dropout kernels
to be generated for (192, 128) hdim configurations instead of being
filtered out.
Add kernel_attr_for<ArchTag, Attrs...> to kernel_launch.hpp that
composes an architecture tag with kernel attributes. When no attributes
are provided, kernel_attr_for<ArchTag> is an identity alias for ArchTag
itself (is_same_v is true). With attributes, it creates a unique type
that inherits both the arch tag and attribute mixins.

The existing kattr_no_packed_fp32_ops_v SFINAE detection works
transparently through the inheritance chain.

Usage:
  kernel_attr_for<gfx950_t>                       -> gfx950_t
  kernel_attr_for<gfx950_t, kernel_attr<true>>    -> unique type
Refactor fmha_batch_prefill.py to match fmha_fwd.py patterns, preparing
for V3 pipeline integration:

- Add ArchTrait to FmhaFwdApiTrait and FmhaFwdKernel with arch
  preprocessor guards
- Refactor FmhaFwdApiPool to hierarchical OrderedDict[arch][dtype][hdim]
  with render() method
- Split API template into HEADER/FUNC_TEMPLATE/PER_ARCH/FOOTER
- Add ProblemContext, KernelContext, CompatibilityRule, is_compatible(),
  create_kernel() abstractions
- Extract inline filtering into CompatibilityRuleFactory and Product
- Add factory hierarchy with get_factories_for_targets()
- Add extensible _get_cpp_kernel_class_name(),
  _get_cpp_kargs_creator_func_name(),
  _get_cpp_pipeline_problem_name() methods to FmhaFwdKernel
- Filename includes arch suffix: {name}_{arch}.cpp

Tested: 14848 passed, 0 failed across 4 combinations
(stock/custom compiler x before/after refactoring).
Add batch prefill V3 pipeline and kernel with scatter-gather paged KV
support, simplified dispatch that relies on trait matching for fallback.

- V3 pipeline: 4-phase double warp group, async buffer loads, 4-buffer LDS
- V3 kernel: LINEAR layout only, SGLang + vLLM page tables
- Codegen: 80 V3 kernels (bf16/fp8, no/causal mask, page sizes 1/16/1024)
- Dispatch: try V3 first when enabled, fall back to V2 via trait matching
- Static asserts enforce V3 constraints (LINEAR, no bias/dropout/kv_blockscale)
…odegen

V3 batch prefill is only needed for fp8bf16. Remove the bf16/fp16 V3
tile (256x64, 8 warps) and pipeline (qr_async_trload_v3) entries from
KernelComponentFactoryGfx950. bf16/fp16 continue to use V2 (qr_async).
Skip page table lookup in K_mem_load/V_mem_load when the next sequence
position exceeds seqlen_k_end. Without this guard, the auto-advance
reads page indices from the padding region of kv_page_indices and
computes scatter-gather offsets that produce buffer_load addresses
mapping to unmapped GPU pages, causing XNACK faults on gfx950.

The contiguous V3 fwd pipeline doesn't have this issue because its
move_tile_window is simple pointer arithmetic protected by the buffer
descriptor's NUM_RECORDS field. The scatter-gather path computes
physical offsets from the page table, bypassing NUM_RECORDS protection.
Separate K/V page offset updates from K/V_mem_load into dedicated
K_page_advance/V_page_advance lambdas, called at the very end of each
phase after Scheduler::schedule. This keeps async_load_tile + ds_read
in one uninterrupted basic block, preventing the XNACK guard branch
from fragmenting the load+ds_read scheduling.

Recovers 6-13% of the 14-26% guard overhead (measured on FP8
batch_prefill sweep). The guard branch still exists but only fragments
the tail of each phase, not the critical load/ds_read interleaving.
Replace the branched XNACK guard (if/s_cbranch) with branchless
min(page_id, max_page_table_idx) inside load_physical_pages(). The
max_page_table_idx is computed as (seqlen_k - 1) / kPageBlockSize in
the kernel and threaded through to all load_physical_pages() call sites.

This eliminates the 14-26% guard overhead that was caused by:
- s_cbranch fragmenting sched_group_barrier-scheduled basic blocks
- serialized global_load_dword + s_waitcnt at conditional join points
- +14 VGPRs from extended live ranges across branch boundaries

V3 FP8 batch_prefill is now 8-23% faster than V2 (was 4-30% slower).
Paged KV overhead reduced from 20-47% to 7-17% vs contiguous varlen.
…overlap

Split K_page_advance/V_page_advance into issue/consume pairs so the
global_load_dword (page table lookup) is issued BEFORE the buffer_loads
from cl_load, placing it oldest in the vmcnt FIFO. At consume time,
s_waitcnt(N) drains only the oldest global_load while keeping the N
buffer_loads in flight.

sched_barrier(0) brackets prevent the compiler from reordering the
global_load_dword across the buffer_loads, which would undo the FIFO
ordering.

Applied to all 4 core loop load phases (WG0 phases 1/3, WG1 phases
0/2), the pre-stage, and WG0 pre-loop setup.

Correctness: 16480 passed, 5376 skipped, 0 failed (matches baseline)
Performance (avg of 5 runs, FP8 batch_prefill, paged KV):
  s=4k-8k: +8-9%, s=16k: +6%, s=32k+: +2-5%, MHA h=40/40: +5%
Add per-page FP8 K/V dequantization scale support to the V3 batch
prefill pipeline, matching the existing V2 implementation.

Kernel: add FmhaFwdKVBlockScaleKargs with nblock/nhead strides,
update Kargs type selection and MakeKargs for both batch/group mode.
scale_s uses q_descale only (k_descale deferred to pipeline).

Pipeline: FP8 shift trick in fmha_alu0 (subtract 8.0/7.0 from row
max to implicitly scale P), k_descale applied after GEMM0, v_descale
rescale trick around GEMM1 (divide before, multiply after). Double-
buffered saved_k/v_descale indexed by LDS buffer slot, saved before
each K_page_issue.

Codegen: add "kv_blockscale" to V3 pipeline generation. Existing
check_page_size filter enforces page_size >= kN0.

Performance tuning not yet done (extra element-wise passes for
v_descale rescale not merged into fmha_alu_D_upd).
…asses

Reduce KV_BLOCKSCALE overhead from 28% to 11% vs pertensor by:

1. Merge v_descale into o_acc_scale: maintain o_acc in v_descale-scaled
   space, folding v_descale_prev/v_descale_cur ratio into the existing
   softmax rescale factor. Eliminates 2 full-tile VALU passes (divide
   before GEMM1 + multiply after GEMM1). Final v_descale applied in
   epilogue normalization.

2. Replace scalar k_descale multiply with v_pk_mul_f32: halves
   instruction count for the remaining k_descale pass by operating on
   float2 pairs.

Both changes are guarded by if constexpr(KV_BLOCKSCALE) — pertensor
assembly is structurally identical before/after.
…elop

- Add block_fmha_batch_prefill_v3_pipeline_default_policy.hpp with
  FP8-specific LDS descriptors (separate from shared V3 fwd policy)
- Add SRD rebasing (rebase_k/v_window) to V3 pipeline for
  kPageBlockSize >= kN0, enabling page_size=1024 V3 dispatch
- Fix load_tile_transpose_with_offset to use develop's void API
- Remove codegen page_size >= kN0 restriction for V3 pipeline
Add BatchPrefillCoreLoopScheduler with FP8-tuned VALU budgets (6/6 per
MFMA half) matching the feature branch's CoreLoopScheduler tuning.
The develop branch had VALU:4/3 which undershoots actual VALU work
because v_pk_mul_f32 asm volatile is invisible to the compiler.

Relax fma_impl_vsv from asm volatile("v_fma_f32") to plain C++ FMA.
The asm volatile anchor prevented compiler reordering across
sched_barrier(0) boundaries, causing s_nop 7+3 stalls (22 extra NOP
cycles per phase2). Plain C++ gives the compiler freedom to fill MFMA
latency gaps.
Use kernel_attr_for<gfx950_t, kernel_attr<true>> for V3 batch prefill
kernels to apply target("no-packed-fp32-ops") to the kernel entry point.
This prevents the compiler from generating v_pk_mul_f32 for non-asm
FP32 operations, allowing separated v_mul_f32 to co-execute with MFMA.

V2 kernels use the plain arch tag (unchanged behavior).

Replaces the -mllvm --amdgpu-disable-packed-fp32=1 flag which was
silently skipped by the stock ROCm 7.1.1 compiler. The target attribute
is honored by both stock and custom compilers.

Measured +2-5% over stock compiler without the attribute.
…aders

Extract CoreLoopSchedulingParams, block_gemm_mfma_count_v, detail::
VALU helpers (fma_impl_vsv, add_impl_vv, mul_impl_vv, cvt_pk_*,
pk_mul_f32), and macros (CK_TILE_FMHA_V3_ASM_MARKER,
CK_TILE_FMHA_V3_ADD_SBARRIER_FOR_PHASE0) into a shared header
block_fmha_fwd_v3_detail.hpp. Both fmha_fwd V3 and batch_prefill V3
pipelines now include this header instead of batch_prefill including
the full fwd V3 pipeline header.

Also:
- Remove __gfx950__ guards inside V3 pipelines (keep only
  permlane32_swap path; entire pipeline is gfx950-only)
- Add top-level gfx950 guard: operator() returns empty output on
  non-gfx950 device
- Remove CK_TILE_DISABLE_PACKED_FP32 macro (always 0; V3 uses
  kernel_attr_for instead of -mllvm flag)
- Add CK_TILE_ prefix to all custom macros
- Add design comments for quant/LSC-unaware scheduler

Assembly output verified identical (only __hip_cuid differs).
V3 batch prefill generates only LINEAR KV layout kernels. VECTORIZED
layout requires sub-dword async loads that violate V3's buffer
addressing constraints, and KV layout optimization has not been done
for V3. VECTORIZED requests fall back to V2 via trait matching.
The attribute conflicts with explicit v_pk_mul_f32 inline asm in the
KV_BLOCKSCALE descale path. Benchmarks (6 sweeps, avg runs 2-6) show
removing it costs 2-4% on pertensor but is neutral-to-positive on
blockscale, making the trade-off acceptable vs the asm conflict.
@poyenc
Copy link
Copy Markdown
Contributor Author

poyenc commented Apr 3, 2026

still be blocked by unrelated compilation errors

poyenc added 12 commits April 3, 2026 09:58
Preserve extraction refactoring (v3_detail.hpp shared header) while
adapting post-review fixes from PR 6051 (fmha_fwd V3) into both the
fwd and batch_prefill V3 pipelines:

1. Move s_waitcnt/s_barrier outside `if(2 < num_total_loop)` guard in
   pre-stage to ensure K1+V0 async loads are drained before core_loop
   reads K1 from LDS (bug fix for num_total_loop==2 case).

2. Replace manual __builtin_amdgcn_permlane32_swap intrinsic calls with
   block_tile_reduce/block_tile_reduce_sync in fmha_alu0 (rowmax) and
   fmha_alu1 (rowsum), preserving kFoldKDescale logic in batch_prefill.

3. Split fmha_alu_D_upd into unpack/pack with interleaved scheduling.

4. Add CK_TILE_DISABLE_PACKED_FP32 guard on schedule_gemm1_compute().

5. Add fmha_alu_D_reg_cnt % 2 == 0 assertion.
Lines 174-175 duplicated the specializations already at lines 125-126,
causing redefinition errors. The duplicates were introduced by the merge
(both branches added the same entries independently).
The merge conflict resolution dropped the Dispatcher<fp8_t, fp8_t, float,
32, 32, 32, false> specialization, causing compilation errors when
cshuffle_epilogue instantiates WarpGemmDispatcher with fp8 types and
isCTransposed=false.
@poyenc
Copy link
Copy Markdown
Contributor Author

poyenc commented Apr 18, 2026

the CI always failed to pull CK image

poyenc added 3 commits April 20, 2026 13:47
Resolve merge conflicts from StreamLLM sink token support (#6479)
landing on develop after the V3 batch prefill refactor. Integrate
F_sink parameter into the refactored multi-arch factory codegen.
The refactoring to CompatibilityRuleFactory dropped the original
`if mode != "group": continue` guard, allowing batch-mode kernel
instantiations that hit the static_assert in pipeline problem.
@poyenc poyenc force-pushed the users/poyenc/ck/batch-prefill-v3 branch from 51085d9 to 515cb8f Compare April 23, 2026 03:46
poyenc added 6 commits April 23, 2026 11:47
…ispatch

Add #pragma clang diagnostic push/pop around the v3 dispatch block in the
API footer template, matching the pattern already used in fmha_fwd.py.
Without this, gfx942 builds (which have no v3 kernels) emit if(false)
and fail with -Werror,-Wunreachable-code.
Return None from get_factory for targets without a factory instead of
raising an exception. Filter unsupported targets before dispatching to
get_factories_for_targets so that builds targeting e.g. gfx1101 produce
an empty kernel pool instead of crashing at configure time.
# Conflicts:
#	projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
Fix -Werror,-Wunused-variable compilation error.
@poyenc poyenc requested a review from Jeff-Huang April 29, 2026 06:11
@Jeff-Huang
Copy link
Copy Markdown
Contributor

Jeff-Huang commented Apr 29, 2026

To fix the OOB issue surfaced by AICK-1171, I based the V2 changes on your PR and opened a separate PR — #6932 — for the V2 load_physical_pages part only. Two small tweaks while extracting:

  • max_page_table_idx is mandatory (no INT32_MAX default), so every callsite has to pass the bound explicitly. With the optional default, any unupdated callsite silently no-ops the clamp.
  • Clamp applied to all branches in load_physical_pages (K prefetch + V LINEAR + V crosses-pages + V single-page lane0 broadcast) in V2.

Verified on MI-308X with the AICK-1171 reproducer and the full FMHA batch prefill suite on gfx942/gfx950. Splitting it out lets the OOB fix land on its own timeline; the rest of #6054 stays as-is.

Credit to your original approach — same min(page_id, max_bound) idea, just hardened.

@Jeff-Huang
Copy link
Copy Markdown
Contributor

Hi @poyenc — thanks for the V3 batch_prefill kernel. While reviewing this PR I hit a silent-correctness issue in the new V3 KV_BLOCKSCALE arm. The output shape and magnitude look right (out_normref_norm within ~0.5%), so i
t doesn't crash — it just quietly exceeds the FP8 reference threshold at specific multi-page configurations. I tracked the root cause down and verified a minimal fix; sharing the full picture below.

Reproducer

A targeted regression test is in aiter PR #3032 (branch jeff-huang/fix-batch-prefill-oob-page-table-read).

test_batch_prefill_4gb_boundary_targeted is a small 2-page boundary probe designed to surface > 4 GB KV-cache page-table bugs without dilution. It allocates a 5000-page paged-KV pool (~5 GB for fp8, ~10 GB for bf16) with page_ size=1024 / num_kv_heads=8 / hdim=128, then for the kernel call passes a single-batch, 2-page page table [vh_start, vh_start + 1] with kv_len = 2048. Because only 2 pages are read and kv_len is small, ALL attention math ru
ns through the suspect pages — wrong reads cannot be diluted by surrounding correct pages (which is what hides the bug in the existing test_batch_prefill_large_kvcache long-sequence stress test).

The failing pytest parametrize id is [False-2-kv_blockscale-fp8]:

Param Value Meaning
causal False Plain attention, no causal mask
page_offset_factor 2 Place first page at the 2 × 2³¹ = 2³² byte boundary
quant_mode kv_blockscale Per-page FP32 descale (one descale per page, per K/V)
input_dtype fp8 KV stored as 1 byte/element

For fp8, each page is 1024 × 8 × 128 × 1 byte = 1 MB. With page_offset_factor=2, the test picks page_indices = [4096, 4097] so the first page byte_offset is exactly 2³² (4 GB).

Note: the 2³¹ / 2³² byte boundary is incidental to this bug — test_batch_prefill_4gb_boundary_targeted was originally designed to surface a separate > 4 GB overflow concern, but happens to also trigger the bug des
cribed here. The actual root cause has nothing to do with byte offsets crossing 32-bit boundaries — any cross-page transition with KV_BLOCKSCALE on V3 fp8 will trigger the same hazard. This test case just happens to be a dete
rministic, low-noise way to expose it.

After checking out aiter PR #3032 with this PR's CK:

HIP_VISIBLE_DEVICES=0 python3 -m pytest \
  op_tests/test_batch_prefill.py::test_batch_prefill_4gb_boundary_targeted \
  -k "False-2-kv_blockscale-fp8" -v

Expected:

FAILED test_batch_prefill_4gb_boundary_targeted[False-2-kv_blockscale-fp8]
AssertionError: FP8 kernel vs reference difference too large:
                0.080078125 (threshold: 0.055)

Confirmed via stream_config.log_level_=1 that the failing case dispatches to ..._v3_..._kv_blockscale_..._linear_sglang.

Bug fingerprint (all conditions must hold)

Isolated by 23 page-permutation probes:

Condition Required
arch = gfx950 + ROCm 7.1+ Yes — gfx942 with same source PASSes (V3 not built → V2 fallback)
quant_mode = KV_BLOCKSCALE Yes — V3 PERTENSOR with same shape PASSes (max = 0.005)
kHasLogitsSoftCap = false Yes — activates the kFoldKDescale path
causal = false Yes
Multi-page (≥ 2 distinct pages) Yes
Both pages ≥ overflow_page (= 2048 in this config) Yes
Pages consecutive (gap = 1) Yes
First page index is even Yes (empirical filter — see note below)

Representative probe results — each row's bracket list is the physical page indices passed via kv_page_indices to the kernel. [4096] means a single page (page index 4096 in the paged-KV pool); [4096, 4097] means two pages (4096 then 4097). Note that each individual page reads correctly, only the transition fails:

pages          verdict   max_diff   note
[4096]         PASS      0.0054     single page (no transition)         ← baseline noise
[4097]         PASS      0.0063     single page (no transition)         ← baseline noise
[4096, 4096]   PASS      0.0054     duplicate (no transition)           ← baseline noise
[4096, 4097]   FAIL      0.0801     even-first consecutive — bug fires, OVER threshold
[4098, 4099]   FAIL      0.1108     even-first consecutive — bug fires, OVER threshold
[4097, 4098]   PASS      0.0313     odd-first — bug fires, ~6× baseline, just under threshold
[4096, 4098]   PASS      0.0157     gap=2     — bug fires, ~3× baseline, just under threshold
[4095, 4096]   PASS      0.0247     boundary  — bug fires, ~5× baseline, just under threshold

The "PASS" rows in the multi-page section also have elevated max_diff (3-6× the single-page baseline of ~0.005). All of them are silently affected by the same hazard — they only PASS because the random per-page descale values happen to differ less, keeping the multiplicative-factor error under the 0.055 FP8 threshold. With a different RNG seed or different head_dim/page_size, the same configurations could easily flip to FAIL. The "first page index even" condition in the fingerprint table is therefore an empirical filter for "loud enough to fail under this seed", not a fundamental requirement of the bug. Verification (further down) shows the fix drops every multi-page row back to baseline noise.

Root cause — read-after-write hazard between save_descales and fmha_alu0

The kFoldKDescale path (active for KV_BLOCKSCALE && !kHasLogitsSoftCap) has a 2-slot double-buffer read-after-write hazard within a single core_loop iteration.

All line numbers below are verified against PR HEAD 70ed05b674.

Notation: K_n = the specific K-tile that gemm0 of THIS iteration consumes; K_{n+1}, K_{n+2} = subsequent K-tiles in the kv_len sequence (consumed by later iterations). pi is the compile-time iteration index that alternates between 0 and 1 per core_loop call — see L1206 (iteration = [&](auto pi)) and L1411 (return iteration(number<0>{}) && iteration(number<1>{})).

Slot binding: at L1210 auto K_w0_lds_wr_idx = number<1>{} - pi; so K_w0_lds_wr_idx always evaluates to slot 1-pi. Both save_descales (writer) and fmha_alu0 (reader inside cl_calc(gemm1)) target this same slot in the SAME iteration.

Within one iteration(pi) (Wave0-3 region at L1259-1295; Wave4-7 region at L1357-1408 is structurally identical with K_w4_lds_wr_idx):

iteration(pi):                                        ┌── slot 1-pi state ──┐
  ① cl_calc(p01, gemm0)                               │  desc_for(K_n)       │  consumes K_n (loaded 2 iters ago)
  ② fmha_alu1(p23)                                    │  desc_for(K_n)       │
  ③ fmha_logits_trans(p01)                            │  desc_for(K_n)       │
  ④ save_descales(K_w0_lds_wr_idx)  // = slot 1-pi    │  desc_for(K_{n+2})   │  ⚠️ OVERWRITES K_n's descale
  ⑤ K_page_issue                                      │                      │
  ⑥ cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx)   │  K_{n+2} → LDS slot 1-pi │
  ⑦ K_page_consume                                    │                      │
  ⑧ cl_calc(p23, gemm1)  // calls fmha_alu0(1-pi)     │  reads desc_for(K_{n+2})│  ❌ but applies to K_n!
  ⑨ fmha_alu_D_upd_unpack(p23)                        │                      │
  ⑩ cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx)   │  K_lds_load(pi) → kv_tile.k_tile for next iter │

save_descales at step ④ writes the descale for the K-tile being LOADED THIS iter (K_{n+2}) into slot 1-pi. fmha_alu0 at step ⑧ reads the SAME slot 1-pi, expecting the descale of the K-tile that gemm0 just CONSUMED THIS iter (K_n).

Pipeline lag derivationK_nK_{n+2} because cl_load(memK) writes the LDS slot in step ⑥ but kv_tile.k_tile in step ① was populated by K_lds_load in step ⑩ of an EARLIER iteration. Tracing through pre-stage + 2 iterations of core_loop: the K-tile cl_load'd into LDS at iteration i is consumed by gemm0 at iteration i+2 (the toggling of K_w0_lds_wr_idx = 1 - pi and K_w0_lds_rd_idx = pi swaps roles per iteration). The 2-slot buffer is too small to hold both the about-to-be-loaded descale AND the about-to-be-consumed descale simultaneously, and the current ordering puts the WRITE (step ④) before the READ (step ⑧) in the same iteration — closing the hazard.

Concrete trace at the page boundary for [4096, 4097] (d_a = page-4096 desc = 0.01056, d_b = page-4097 desc = 0.01038):

step iter consumed K  saved_k_descale[1] before/after save  fmha_alu0 reads → applies to
─────────────────────────────────────────────────────────────────────────────────────────
step7  0  K13         d_a → d_a   d_a → K13 ✓
step7  1  K14         (writes slot 0 — slot 1 untouched)
step8  0  K15         d_a → d_b   d_b → K15 ❌  (page 4096, needs d_a)
step8  1  K16         (writes slot 0 with d_b)
step9  0  K17         d_b → d_b   d_b → K17 ✓

The wrong-descale region is exactly the 1-2 iterations spanning the boundary. The wrong descale flows through m.thread_buf_[0] *= saved_k_descale[si] (L918) AND scale_s_k = scale_s * saved_k_descale[si] (L928), producing silent multiplicative-factor corruption in the attention output.

Note on the "first page index EVEN" filter — empirically, only even-first-page configurations push the per-element error above the FP8 threshold; odd-first-page configurations stay under threshold despite the same hazard mechanism (e.g. [4097, 4098] PASSes at max=0.0312, well above the all-page-4096 baseline of 0.0054 but below the 0.055 threshold). The exact reason — likely related to how the wrong descale value happens to combine with the running m_old and rowsum at the specific boundary iteration — is not fully derived here. Treat the parity condition as an empirical filter that selects the most visible failures, not as a load-bearing part of the root cause.

Cross-ROCm: bug reproduces on both ROCm 7.1 and 7.2 → source-level bug, not compiler/codegen.

Suggested fix (verified, but please feel free to take a different approach)

Below is the smallest fix I could find that resolves the hazard without changing the kernel's overall structure or losing the kFoldKDescale perf benefit. You know this kernel far better than I do — there may well be a cleaner approach (e.g. extending saved_*_descale to a 3-slot ring, separating K and V descale buffers, restructuring the iteration body, or something else I'm not seeing). Treat this as a starting point that demonstrates the bug is fixable with a localized change, not as a prescribed patch.

The idea is: defer save_descales to AFTER fmha_alu0 reads the slot, so the slot's previous content (descale of the K-tile being consumed) is preserved long enough. To get the right page index for the deferred save, capture k_physical_pages[number<0>{}] BEFORE K_page_issue advances it.

Properties of this approach:

  • Pre-stage's three save_descales(0/1/0) calls are unchanged — pre-stage is manually unrolled without the pipeline lag, so it has no RAW hazard.
  • kFoldKDescale performance benefit (no full-tile pk_mul_f32) is preserved.
  • Deferred save can overlap with the next iteration's K_page_issue / cl_load, so the global-memory load latency is still hidden.
  • ISA dump shows no register spill; SGPR/VGPR counts within 2 of baseline.

Diff (5 hunks, +20/-2 LOC):

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_v3_pipeline.hpp
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_v3_pipeline.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_v3_pipeline.hpp
@@ -875,6 +875,21 @@
             }
         };

+        // Deferred-save variant of save_descales. Used by the core_loop to
+        // defer the descale write until AFTER the same iteration's fmha_alu0
+        // has read the slot — closing a read-after-write hazard with the
+        // descale of the K-tile being CONSUMED this iteration (which lives
+        // in slot 1-pi from a previous iteration's save).
+        auto save_descales_at = [&](index_t captured_page, auto buf_idx_tag) {
+            if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
+            {
+                constexpr index_t b = decltype(buf_idx_tag)::value;
+                const index_t scale_offset =
+                    captured_page * nblock_stride_kv_block_descale +
+                    block_indices.kv_head_idx * nhead_stride_kv_block_descale;
+                saved_k_descale[b] = k_descale_ptr[scale_offset];
+                saved_v_descale[b] = v_descale_ptr[scale_offset];
+            }
+        };
+
         auto fmha_logits_trans = [&](auto sp_reg_idx) {
             if constexpr(kHasLogitsSoftCap)
             {
@@ -1270,7 +1285,9 @@
                     __builtin_amdgcn_sched_barrier(0);
                     __builtin_amdgcn_s_barrier();
                     __builtin_amdgcn_sched_barrier(0);
-                    save_descales(K_w0_lds_wr_idx);
+                    // Capture page for deferred save below; do NOT save here
+                    // (would overwrite the descale that fmha_alu0 needs to read).
+                    const index_t captured_page_w0 = k_physical_pages[number<0>{}];
                     K_page_issue();                                  // global_load_dword FIRST
                     __builtin_amdgcn_sched_barrier(0);               // prevent reorder
                     cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); // buffer_loads SECOND
@@ -1296,6 +1313,8 @@
                     Scheduler::schedule(cl_p, number<2>{});
                     __builtin_amdgcn_sched_barrier(0);
                     fmha_alu_D_upd_pack();
+                    // Deferred save: fmha_alu0 has already read the slot.
+                    save_descales_at(captured_page_w0, K_w0_lds_wr_idx);

                     __builtin_amdgcn_sched_barrier(0);
                     // phase3
@@ -1367,7 +1386,9 @@
                     CK_TILE_FMHA_V3_ASM_MARKER("phase2 Wave4-7");
                     __builtin_amdgcn_s_barrier();
                     __builtin_amdgcn_sched_barrier(0);
-                    save_descales(K_w4_lds_wr_idx);
+                    // Capture page for deferred save below; do NOT save here
+                    // (would overwrite the descale that fmha_alu0 needs to read).
+                    const index_t captured_page_w4 = k_physical_pages[number<0>{}];
                     K_page_issue();                                  // global_load_dword FIRST
                     __builtin_amdgcn_sched_barrier(0);               // prevent reorder
                     cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx); // buffer_loads SECOND
@@ -1405,6 +1426,8 @@
                     Scheduler::schedule(cl_p, number<3>{});
                     __builtin_amdgcn_sched_barrier(0);
                     fmha_alu_D_upd_pack();
+                    // Deferred save: fmha_alu0 has already read the slot.
+                    save_descales_at(captured_page_w4, K_w4_lds_wr_idx);
                 }
                 return result;
             };

Verification — all 8 page-permutation probes pass after fix

pages baseline with fix
[4096] PASS 0.0061 PASS 0.0061
[4097] PASS 0.0063 PASS 0.0063
[4096, 4096] PASS 0.0054 PASS 0.0054
[4096, 4097] FAIL 0.0801 PASS 0.0084
[4098, 4099] FAIL 0.1108 PASS 0.0088
[4097, 4098] PASS 0.0312 PASS 0.0151
[4096, 4098] PASS 0.0157 PASS 0.0054
[4095, 4096] PASS 0.0247 PASS 0.0072

Note that the 3 already-passing multi-page cases also see significant max_diff reduction (3-4× lower), which independently confirms the fix attribution — all of them were silently affected by the same hazard, just below the FP8 threshold.

test_batch_prefill_4gb_boundary_targeted from aiter PR #3032 is a deterministic regression guard for this class of bug.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants