diff --git a/csrc/py_itfs_cu/asm_pa.cu b/csrc/py_itfs_cu/asm_pa.cu index eebe86dcab..782e2c7f2f 100644 --- a/csrc/py_itfs_cu/asm_pa.cu +++ b/csrc/py_itfs_cu/asm_pa.cu @@ -37,10 +37,12 @@ struct __attribute__((packed)) KernelArgs p3 _p16; unsigned int KVs; p3 _p17; - unsigned int GQA; + unsigned int mtp; p3 _p18; + unsigned int GQA; + p3 _p19; void* ptr_QTP; - p2 _p19; + p2 _p20; }; @@ -198,7 +200,7 @@ void pa_fwd(aiter_tensor_t* Q, // [num_seqs, num_heads, head_size float k_scalar = sqrt(dim); k_scalar = (float)((double)k_log2e / (double)k_scalar); - KernelArgs args; + KernelArgs args = {}; size_t arg_size = sizeof(args); args.ptr_O = out_->data_ptr(); args.ptr_Q = Q->data_ptr(); @@ -222,6 +224,7 @@ void pa_fwd(aiter_tensor_t* Q, // [num_seqs, num_heads, head_size args.Qs = stride_Q; args.Bs = stride_KV_blk; args.KVs = stride_KV_head; + args.mtp = max_qlen - 1; args.GQA = gqa_ratio; args.ptr_QTP = (qo_indptr != nullptr) ? qo_indptr->data_ptr() : nullptr; diff --git a/hsa/gfx942/pa/pa_bf16_noquant_gqa16_1tg_4w.co b/hsa/gfx942/pa/pa_bf16_noquant_gqa16_1tg_4w.co index 8dcd222508..5861fd741c 100755 Binary files a/hsa/gfx942/pa/pa_bf16_noquant_gqa16_1tg_4w.co and b/hsa/gfx942/pa/pa_bf16_noquant_gqa16_1tg_4w.co differ diff --git a/hsa/gfx942/pa/pa_bf16_noquant_gqa8_1tg_4w.co b/hsa/gfx942/pa/pa_bf16_noquant_gqa8_1tg_4w.co index 1b15561485..f48aa9aebc 100755 Binary files a/hsa/gfx942/pa/pa_bf16_noquant_gqa8_1tg_4w.co and b/hsa/gfx942/pa/pa_bf16_noquant_gqa8_1tg_4w.co differ diff --git a/hsa/gfx942/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk0.co b/hsa/gfx942/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk0.co index 7019e9c533..28f96875c2 100755 Binary files a/hsa/gfx942/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk0.co and b/hsa/gfx942/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx942/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk1.co b/hsa/gfx942/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk1.co index 3f6012a25e..b8d853aea8 100755 Binary files 
a/hsa/gfx942/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk1.co and b/hsa/gfx942/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co index f4c760940c..a5a378fd5b 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co and b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co index 58a114a355..f2ba9ef9bc 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co and b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa16_2tg_4w.co b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa16_2tg_4w.co index d63cfb67f9..b3cffead82 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa16_2tg_4w.co and b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa16_2tg_4w.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co index 3ba9bb4b82..715ed9ded8 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co and b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co index f1a2d85982..50ed428f85 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co and b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w.co b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w.co index d720eafaa3..c3624198d6 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w.co and b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_hp.co 
b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_hp.co index bd6f281396..4360d0e280 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_hp.co and b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_hp.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_uhp.co b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_uhp.co index a2cdab2b47..72ae9983c0 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_uhp.co and b/hsa/gfx942/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_uhp.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co b/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co index a4f44572a0..40235bd0e7 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co and b/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co b/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co index 980d663a01..04249483b9 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co and b/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa16_2tg_4w.co b/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa16_2tg_4w.co index 7808645e2a..5753abe74e 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa16_2tg_4w.co and b/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa16_2tg_4w.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co b/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co index d822ecbba3..17fd9d0118 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co and b/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co b/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co index 1d56b231d8..080a1ea8b9 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co and 
b/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa8_2tg_4w.co b/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa8_2tg_4w.co index 23f3b033d1..78f0c7752c 100755 Binary files a/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa8_2tg_4w.co and b/hsa/gfx942/pa/pa_bf16_pertokenInt8_gqa8_2tg_4w.co differ diff --git a/hsa/gfx942/pa/pa_fp16_noquant_gqa16_1tg_4w.co b/hsa/gfx942/pa/pa_fp16_noquant_gqa16_1tg_4w.co index 420805f007..2d31179792 100755 Binary files a/hsa/gfx942/pa/pa_fp16_noquant_gqa16_1tg_4w.co and b/hsa/gfx942/pa/pa_fp16_noquant_gqa16_1tg_4w.co differ diff --git a/hsa/gfx942/pa/pa_fp16_noquant_gqa8_1tg_4w.co b/hsa/gfx942/pa/pa_fp16_noquant_gqa8_1tg_4w.co index 7603a30737..d52a375589 100755 Binary files a/hsa/gfx942/pa/pa_fp16_noquant_gqa8_1tg_4w.co and b/hsa/gfx942/pa/pa_fp16_noquant_gqa8_1tg_4w.co differ diff --git a/hsa/gfx942/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk0.co b/hsa/gfx942/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk0.co index c5c8e53249..88ce4e8972 100755 Binary files a/hsa/gfx942/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk0.co and b/hsa/gfx942/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx942/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk1.co b/hsa/gfx942/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk1.co index 687031fb26..40e8ef5309 100755 Binary files a/hsa/gfx942/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk1.co and b/hsa/gfx942/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co index 28f7002b67..f35baeefd4 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co and b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co index 89458b5a3d..d98dae9987 100755 Binary files 
a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co and b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa16_2tg_4w.co b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa16_2tg_4w.co index 3d36feb8ce..646865616f 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa16_2tg_4w.co and b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa16_2tg_4w.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co index 20da3b15c0..0779ee56de 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co and b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co index dae61631e4..e8c54190cd 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co and b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w.co b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w.co index ef2de5c5f2..9353d04b40 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w.co and b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_hp.co b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_hp.co index 92ed555015..0c310c3895 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_hp.co and b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_hp.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_uhp.co b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_uhp.co index eea7a948cd..86ac7c899f 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_uhp.co and b/hsa/gfx942/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_uhp.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co 
b/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co index ced740d521..9b8020fc58 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co and b/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co b/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co index 352b18f7bc..c68accb11e 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co and b/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa16_2tg_4w.co b/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa16_2tg_4w.co index bf09eb56f6..47a6566bd6 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa16_2tg_4w.co and b/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa16_2tg_4w.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co b/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co index 3cde2b9c0a..f6212253cf 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co and b/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co b/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co index eba6f4ad43..7fce62f6f8 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co and b/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa8_2tg_4w.co b/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa8_2tg_4w.co index 1ae740b970..f0ca6e9aa2 100755 Binary files a/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa8_2tg_4w.co and b/hsa/gfx942/pa/pa_fp16_pertokenInt8_gqa8_2tg_4w.co differ diff --git a/hsa/gfx950/pa/pa_bf16_noquant_gqa16_1tg_4w.co b/hsa/gfx950/pa/pa_bf16_noquant_gqa16_1tg_4w.co index 15efd8eea5..f6a11f0068 100755 Binary files a/hsa/gfx950/pa/pa_bf16_noquant_gqa16_1tg_4w.co and 
b/hsa/gfx950/pa/pa_bf16_noquant_gqa16_1tg_4w.co differ diff --git a/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w.co b/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w.co old mode 100755 new mode 100644 index 3e97ffe2f2..b9c97f8944 Binary files a/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w.co and b/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w.co differ diff --git a/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w.co.orig b/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w.co.orig new file mode 100755 index 0000000000..6b9e9b4486 Binary files /dev/null and b/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w.co.orig differ diff --git a/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w.co.poc_kl_merg b/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w.co.poc_kl_merg new file mode 100755 index 0000000000..6a5809ec41 Binary files /dev/null and b/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w.co.poc_kl_merg differ diff --git a/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk0.co b/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk0.co index 623b638837..9c08b4df5f 100755 Binary files a/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk0.co and b/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk1.co b/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk1.co index 73f01d6438..68f81e24c6 100755 Binary files a/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk1.co and b/hsa/gfx950/pa/pa_bf16_noquant_gqa8_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co index a558cf55dc..1a8b13e117 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co and b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co index 99742d0057..206020d612 100755 Binary files 
a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co and b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa16_2tg_4w.co b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa16_2tg_4w.co index 605998b54a..844fd03871 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa16_2tg_4w.co and b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa16_2tg_4w.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co index 8eb0bb26b1..662c463991 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co and b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co index 0268190f1b..6c6d75514b 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co and b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w.co b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w.co index cb841a6bb3..f33746320f 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w.co and b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_hp.co b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_hp.co index d0bf341cf0..2b04347bcb 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_hp.co and b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_hp.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_uhp.co b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_uhp.co index 40a8c3555b..03a96abb4c 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_uhp.co and b/hsa/gfx950/pa/pa_bf16_pertokenFp8_gqa8_2tg_4w_uhp.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co 
b/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co index 4cd7b44221..e7579ce652 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co and b/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co b/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co index 57303aa5e3..6acd65b85c 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co and b/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa16_2tg_4w.co b/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa16_2tg_4w.co index 53cd196323..e6f60d8720 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa16_2tg_4w.co and b/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa16_2tg_4w.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co b/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co index 5e266b49f5..54f2dd5d5d 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co and b/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co b/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co index d0f5671d7e..21fa21b5e4 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co and b/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa8_2tg_4w.co b/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa8_2tg_4w.co index 911cefd182..4113d925b3 100755 Binary files a/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa8_2tg_4w.co and b/hsa/gfx950/pa/pa_bf16_pertokenInt8_gqa8_2tg_4w.co differ diff --git a/hsa/gfx950/pa/pa_fp16_noquant_gqa16_1tg_4w.co b/hsa/gfx950/pa/pa_fp16_noquant_gqa16_1tg_4w.co index 97ab9c9f3b..00fc842731 100755 Binary files a/hsa/gfx950/pa/pa_fp16_noquant_gqa16_1tg_4w.co and 
b/hsa/gfx950/pa/pa_fp16_noquant_gqa16_1tg_4w.co differ diff --git a/hsa/gfx950/pa/pa_fp16_noquant_gqa8_1tg_4w.co b/hsa/gfx950/pa/pa_fp16_noquant_gqa8_1tg_4w.co index 1ae323697c..c417b763e0 100755 Binary files a/hsa/gfx950/pa/pa_fp16_noquant_gqa8_1tg_4w.co and b/hsa/gfx950/pa/pa_fp16_noquant_gqa8_1tg_4w.co differ diff --git a/hsa/gfx950/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk0.co b/hsa/gfx950/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk0.co index 2a2ffd0f0d..bfcfe5e35a 100755 Binary files a/hsa/gfx950/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk0.co and b/hsa/gfx950/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx950/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk1.co b/hsa/gfx950/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk1.co index a1af49e2ca..fe1bc50a14 100755 Binary files a/hsa/gfx950/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk1.co and b/hsa/gfx950/pa/pa_fp16_noquant_gqa8_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co index fb5a3e6fbe..914085792e 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co and b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co index abad7b93d9..44055defdb 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co and b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa16_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa16_2tg_4w.co b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa16_2tg_4w.co index a0e75556a2..9110f6e800 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa16_2tg_4w.co and b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa16_2tg_4w.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co index 4255ff8406..2d66d2b5e9 100755 Binary files 
a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co and b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co index aabe09b4b2..0828e434f6 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co and b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w.co b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w.co index 57867426d6..ba6315d518 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w.co and b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_hp.co b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_hp.co index 6b51117b1f..4210f2d693 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_hp.co and b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_hp.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_uhp.co b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_uhp.co index cb9649c5ac..845a97a809 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_uhp.co and b/hsa/gfx950/pa/pa_fp16_pertokenFp8_gqa8_2tg_4w_uhp.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co b/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co index 03aaba5996..6a92da410f 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co and b/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co b/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co index c6d8b1afd2..b69af73f7e 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co and b/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa16_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa16_2tg_4w.co 
b/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa16_2tg_4w.co index 40a330e09d..b317fa40b5 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa16_2tg_4w.co and b/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa16_2tg_4w.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co b/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co index a6ead50fc1..cfcea28779 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co and b/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk0.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co b/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co index 6c9d764507..c5f4570b19 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co and b/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa8_1tg_4w_mtp_msk1.co differ diff --git a/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa8_2tg_4w.co b/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa8_2tg_4w.co index 8717a1915f..8ce62e2e23 100755 Binary files a/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa8_2tg_4w.co and b/hsa/gfx950/pa/pa_fp16_pertokenInt8_gqa8_2tg_4w.co differ diff --git a/op_tests/repros/README.md b/op_tests/repros/README.md new file mode 100644 index 0000000000..52977a4280 --- /dev/null +++ b/op_tests/repros/README.md @@ -0,0 +1,127 @@ +# aiter ASM PA crash — standalone reproducer + +A minimal aiter-only reproducer for the HIP illegal-memory crash observed in +production (Kimi-K2.5-MXFP4 + Eagle3 spec-decode on 8x MI355 / gfx950) when +ATOM uses ASM-force paged-attention. 
+ +## TL;DR fingerprint + +| | value | +|---|---| +| **Kernel** | `pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co` (gfx950) | +| **Trigger** | `batch_size == 128` AND `qlen == 3` (specifically — not a boundary, not `total_qo`) | +| **Shape** | GQA=8 (8 Q heads / 1 KV head), head_size=128, block_size=16 | +| **KV dtype** | **fp8 per-token quant only** — bf16 (`pa_bf16_noquant_gqa8_1tg_4w_mtp_msk1.co`) does NOT crash at the same shape | +| **Failure** | HIP illegal memory access, reported asynchronously. Surfaces at the next `hipModuleLaunchKernel` call (so call N+1 errors when call N was the offender). | +| **Min repeats to crash** | 2–3 invocations of the bad shape. Single call sometimes survives, 2nd or 3rd reliably trips it. | +| **Sequence dependency** | None — pure shape bug. Calling the bad shape in a fresh process triggers it. | +| **Concurrency dependency** | None — single-stream, `AMD_SERIALIZE_KERNEL=3 HIP_LAUNCH_BLOCKING=1` still crashes. | + +## How to reproduce in ≤30 seconds (inside the eagle3 container) + +```bash +cd /app/aiter-test # aiter checkout root; adjust to your location +AMD_SERIALIZE_KERNEL=3 HIP_LAUNCH_BLOCKING=1 \ + python op_tests/repros/pa_asm_fp8_repeat_call.py \ + --bs 128 --ctx 1024 --qlen 3 --kv-dtype fp8 --n-repeat 5 +``` + +Expected output on a buggy build: +``` +[aiter] LoadKernel: _ZN5aiter40pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1E + hsaco: /app/aiter-test/hsa//gfx950/pa/pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co +[AITER] /app/aiter-test/csrc/include/aiter_hip_common.h:244 + fail to call hipModuleLaunchKernel(...)
---> [HIP error](an illegal memory access) +Aborted (core dumped) +``` + +Negative controls (all pass on the same build): +```bash +# bf16 KV — same bs=128 qlen=3, different kernel, no crash: +python pa_asm_fp8_repeat_call.py --bs 128 --ctx 1024 --qlen 3 --kv-dtype bf16 + +# fp8 KV, bs off by one — no crash: +python pa_asm_fp8_repeat_call.py --bs 127 --ctx 1024 --qlen 3 --kv-dtype fp8 +python pa_asm_fp8_repeat_call.py --bs 129 --ctx 1024 --qlen 3 --kv-dtype fp8 + +# fp8 KV, qlen off by one — no crash: +python pa_asm_fp8_repeat_call.py --bs 128 --ctx 1024 --qlen 2 --kv-dtype fp8 +python pa_asm_fp8_repeat_call.py --bs 128 --ctx 1024 --qlen 4 --kv-dtype fp8 + +# same total_qo=384 via other (bs, qlen) — no crash: +python pa_asm_fp8_repeat_call.py --bs 192 --ctx 1024 --qlen 2 --kv-dtype fp8 # 192*2=384 +python pa_asm_fp8_repeat_call.py --bs 96 --ctx 1024 --qlen 4 --kv-dtype fp8 # 96*4=384 +``` + +## Sweep data (each cell = 5 repeated calls, fresh process, `AMD_SERIALIZE_KERNEL=3 HIP_LAUNCH_BLOCKING=1`) + +KV dtype = fp8, head_size=128, block_size=16, num_blocks=8192, GQA=8 (num_q_heads=8, num_kv_heads=1). + +qlen=3 sweep over batch_size (ctx_len=1024, fp8 KV): + +| bs | 32 | 64 | 96 | 124 | 125 | 126 | 127 | **128** | 129 | 130 | 144 | 192 | 256 | 512 | +|----|----|----|----|----|----|----|----|----|----|----|----|----|----|----| +| result | OK | OK | OK | OK | OK | OK | OK | **CRASH** | OK | OK | OK | OK | OK | OK | + +bs=128 sweep over qlen (ctx_len=1024, fp8 KV): + +| qlen | 1 | 2 | **3** | 4 | 5 | 6 | 7 | 8 | +|------|---|---|---|---|---|---|---|---| +| result | OK | OK | **CRASH** | OK | OK | OK | OK | OK | + +bs=128, qlen=3, ctx_len sweep: + +| ctx_len | 128 | 1024 | 2048 | 4096 | 6724 | 8192 | 16384 | +|---------|---|---|---|---|---|---|---| +| result | CRASH | CRASH | CRASH | CRASH | CRASH | CRASH | CRASH | + +→ ctx_len has no effect on the outcome; bs=128 ∧ qlen=3 is the entire trigger.
+ +## What's in this directory + +| file | purpose | +|---|---| +| `pa_asm_fp8_repeat_call.py` | **The minimal reproducer.** Repeats one `(bs, ctx, qlen)` call N times. Use for bisection. | +| `pa_asm_fp8_min_repro.py` | Single-call variant (does NOT crash by itself — bug needs ≥2 calls). Useful for checking shapes that are individually safe. | +| `pa_asm_fp8_seq_repro.py` | Replays a fixed call sequence from the stress driver and lets you also `--repeat-only-bad` to confirm sequence-independence. | +| `pa_asm_crash_repro.py` | Original stress driver that mimics ATOM's call pattern (random shape mix, multi-stream). Useful for end-to-end "would this build crash under prod-like load?" | +| `pa_asm_fp8_shape_sweep.py` | Sweep wrapper (4 qlens × 9 ctx × 5 bs). Each cell forks a fresh process. | +| `README.md` | This file. | + +## Production correlation + +- ATOM ASM-force path calls `aiter.pa_fwd_asm` with exactly these shapes during + the Eagle3 spec-decode target step (Kimi MLA absorbed → 1 KV head, TP=8 → + 8 Q heads per rank → GQA=8). Draft tokens are 3 per step (eagle3 emits 3), + so `qlen=3`. +- The bench runs at concurrency=128. When the scheduler packs 128 in-flight + requests into one step, the resulting `attn_metadata` is exactly + `batch_size=128, max_seqlen_q=3` → triggers this kernel. +- Production crash signature (`event.synchronize() → HIP illegal memory`) is + the same async-reported error this reproducer surfaces. +- Crash req# varies wildly in production (131, 857, ~900, ~3945) because it + depends on when the scheduler first assembles a step matching `bs=128 ∧ + qlen=3` — not on cumulative state. + +## Build versions used + +- aiter: `pr3211-on-main @ aff40475d` (also reproduces on `main @ ee28d47ac`) +- ROCm/HIP: per the `rocm/atom-dev:latest` container +- Target arch: gfx950 (MI355) +- Container: eagle3 (podman) + +## Suggested next steps for the ASM team + +1. 
Disassemble `pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co` with + `roc-obj-extract` / `llvm-objdump -d`, then look at how `batch_size` + and `max_qlen=3` parameterize the kernarg block. The branch that's hit + only at `(bs=128, qlen=3)` is the suspect. +2. Compare with the bf16 sibling `pa_bf16_noquant_gqa8_1tg_4w_mtp_msk1.co` + (which does NOT crash on the same shape) to find the extra fp8 code path. +3. Confirm with `rocm-debug-agent`: run + ``` + AMD_LOG_LEVEL=4 ROCM_DEBUG_AGENT=on \ + AMD_SERIALIZE_KERNEL=3 HIP_LAUNCH_BLOCKING=1 \ + python pa_asm_fp8_repeat_call.py --bs 128 --ctx 1024 --qlen 3 --n-repeat 3 + ``` + to capture the wave dump, faulting PC, and offending V# descriptor. diff --git a/op_tests/repros/pa_asm_crash_repro.py b/op_tests/repros/pa_asm_crash_repro.py new file mode 100644 index 0000000000..6498ed4a00 --- /dev/null +++ b/op_tests/repros/pa_asm_crash_repro.py @@ -0,0 +1,423 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Standalone aiter reproducer for ASM paged-attention crash observed in ATOM +# (Kimi-K2.5-MXFP4 + Eagle3 spec-decode) at 30k req x 128 conc. +# +# Background: +# - Inside ATOM, attention_mha.py calls aiter.pa_fwd_asm via ASM-force path. +# At 30k req x 128 concurrency, this crashes with HIP illegal memory access, +# async-reported via event.synchronize(). Crash req# is highly variable +# (observed 131 / 857 / ~900 / ~3945 across runs). +# - The crash signature (wave dump) names kernels of the form +# pa_bf16_pertokenFp8_gqa8_1tg_4w_mtp_msk1.co (fp8 KV) +# pa_bf16_gqa8_1tg_4w_mtp_msk1.co (bf16 KV) +# so the offending kernel is the ASM PA backend, GQA ratio 8, with mtp. +# - The crash reproduces with both fp8 and bf16 KV cache. Forcing +# self.kv_scale to be a 131072-element array did NOT fix it. +# - Gluon-attention path under identical workload (30k x 128) is stable +# (30000/30000 PASS). So root cause is in the ASM PA backend, not in ATOM. 
+# +# This script is a self-contained aiter-only stress driver that: +# * matches ATOM's pa_fwd_asm call signature (incl. qo_indptr, K_QScale, +# V_QScale, max_qlen, high_precision=0). +# * uses GQA-8 (num_q_heads=8, num_kv_heads=1), block_size=16, head_dim=128 +# -> selects the same ASM .co binary as production. +# * mixes mtp qlen in {1,2,3,4} per iteration to exercise the same kernel +# variants the bench hits. +# * varies batch_size (mostly large, ~128) and ctx_len each iteration +# to imitate the request mix. +# * launches many calls on multiple CUDA streams without sync between them, +# because the bug is async-reported and a strict launch-and-sync loop may +# not race the way 128-conc inflight requests do. +# * periodically forces a sync, catches the HIP error, and prints the iter, +# batch shape, and current call params so the ASM team can inspect. +# +# Usage: +# # bf16 KV (no quant), default: +# python pa_asm_crash_repro.py +# +# # fp8 KV (matches the wave-dumped kernel from crash note): +# python pa_asm_crash_repro.py --kv-dtype fp8 +# +# # tweak shape mix: +# python pa_asm_crash_repro.py --kv-dtype fp8 --max-iters 200000 \ +# --streams 16 --sync-every 32 + +import argparse +import os +import random +import sys +import time +import traceback +from typing import List, Optional, Tuple + +import torch + +import aiter +from aiter import dtypes +from aiter import pertoken_quant + + +# --------------------------------------------------------------------------- # +# KV cache helpers (same layout as test_pa_mtp.py and aiter PA convention) +# --------------------------------------------------------------------------- # +def make_kv_cache( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: str = "cuda", +) -> Tuple[torch.Tensor, torch.Tensor]: + """Allocate K/V cache in aiter PA layout. 
def asm_v_shuffle(v_cache: torch.Tensor) -> torch.Tensor:
    """Re-tile a V cache into the layout passed to the ASM PA kernels.

    Args:
        v_cache: tensor shaped [num_blocks, num_kv_heads, head_size, block_size].

    Returns:
        A contiguous tensor shaped
        [num_blocks, num_kv_heads, block_size // x, head_size, x], where x is
        the number of elements of ``v_cache``'s dtype that fit in 16 bytes.
    """
    lanes = 16 // v_cache.element_size()
    n_blk, n_head, d_head, blk_len = v_cache.shape
    # Split the block_size axis into (chunks, lanes) ...
    tiled = v_cache.view(n_blk, n_head, d_head, blk_len // lanes, lanes)
    # ... then move the chunk axis in front of head_size and materialize.
    return tiled.transpose(2, 3).contiguous()
+ + Returns: + k_quant: same layout as K cache, in `quant_dtype` + v_quant: same layout as V cache, in `quant_dtype` + k_scale_asm: [num_blocks, num_kv_heads, block_size, 1] (ASM-friendly) + v_scale_asm: same shape + """ + num_blocks, num_kv_heads = k_cache.shape[0], k_cache.shape[1] + head_dim = v_cache.shape[2] + block_size = v_cache.shape[3] + + k_perm = ( + k_cache.permute(0, 1, 3, 2, 4) + .reshape(num_blocks, num_kv_heads, block_size, -1) + .contiguous() + ) + v_perm = ( + v_cache.permute(0, 1, 3, 2) + .reshape(num_blocks, num_kv_heads, block_size, -1) + .contiguous() + ) + + k_quant, k_scale_asm = pertoken_quant(k_perm, quant_dtype=quant_dtype) + v_quant, v_scale_asm = pertoken_quant(v_perm, quant_dtype=quant_dtype) + + quant_x = 16 // quant_dtype.itemsize + k_quant = ( + k_quant.view(num_blocks, num_kv_heads, block_size, head_dim // quant_x, quant_x) + .permute(0, 1, 3, 2, 4) + .contiguous() + ) + v_quant = ( + v_quant.view(num_blocks, num_kv_heads, block_size, head_dim) + .permute(0, 1, 3, 2) + .contiguous() + ) + return k_quant, v_quant, k_scale_asm, v_scale_asm + + +# --------------------------------------------------------------------------- # +# Iteration: pick a random request mix, build inputs, fire pa_fwd_asm +# --------------------------------------------------------------------------- # +def build_iter_inputs( + rng: random.Random, + num_kv_heads: int, + num_q_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + device: str = "cuda", +): + """Construct one paged-attention call's inputs. + + Mirrors the production mix at 128 conc: + - batch_size: most often near 128, sometimes 32/64/96/256 to exercise + edge cases (matches a serving step where some seqs finished). + - ctx_len: random in [64, 8192], occasionally up to 16384. + - qlen: in {1, 2, 3, 4} (eagle3 draft / MTP draft tokens). 
+ """ + # batch size mix: biased toward 128 (matches CONCURRENCY=128) + batch_size = rng.choice([32, 64, 96, 128, 128, 128, 128, 128, 192, 256]) + + # ctx_len mix: most short-to-medium, sometimes long + if rng.random() < 0.08: + ctx_len = rng.randint(8192, 16384) + elif rng.random() < 0.4: + ctx_len = rng.randint(64, 1024) + else: + ctx_len = rng.randint(1024, 8192) + ctx_len = max(ctx_len, block_size) + + # MTP qlen distribution: weight toward 3-4 since eagle3 emits 3 draft tokens + qlen = rng.choice([1, 2, 3, 3, 3, 4, 4]) + + max_num_blocks_per_seq = (16384 + block_size - 1) // block_size + num_blocks_per_seq = (ctx_len + block_size - 1) // block_size + + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + seq_lens_qo = torch.full((batch_size,), qlen, dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(seq_lens_qo, dim=0) + total_qo = int(qo_indptr[-1].item()) + max_qlen = qlen + + query = torch.empty( + (total_qo, num_q_heads, head_size), dtype=dtype, device=device + ).uniform_(-1, 1) + + seq_lens = torch.full( + (batch_size,), ctx_len, dtype=torch.int32, device=device + ) + + # block_tables: random page assignments per request, padded to + # max_num_blocks_per_seq. 
def run_one(
    iter_idx: int,
    rng: random.Random,
    cfg,
    persistent,
    stream: torch.cuda.Stream,
    log_shape: bool = False,
) -> dict:
    """Build one random request mix and launch ``aiter.pa_fwd_asm`` on `stream`.

    Args:
        iter_idx: Iteration number; used for logging and stored in the record.
        rng: Python RNG that drives the shape mix in ``build_iter_inputs``.
        cfg: Object exposing ``num_kv_heads``, ``num_q_heads``, ``head_size``,
            ``block_size``, ``num_blocks``, ``compute_dtype`` and ``kv_dtype``.
        persistent: Dict with the long-lived KV tensors: ``"k_cache"`` and
            ``"v_cache_asm"`` always, plus ``"k_quant"``, ``"v_quant_asm"``,
            ``"k_scale_asm"`` and ``"v_scale_asm"`` when ``cfg.kv_dtype`` is
            ``"fp8"``.
        stream: Stream the attention call is issued on.
        log_shape: When true, print the call shape before launching.

    Returns:
        Dict holding the call shape (``iter``, ``batch_size``, ``ctx_len``,
        ``qlen``, ``total_qo``), the output tensor (``out``) and references to
        every per-call input tensor (``_query``, ``_block_tables``,
        ``_seq_lens``, ``_qo_indptr``).
    """
    # Build the per-call inputs on the same device as the persistent KV pool.
    # build_iter_inputs defaults to plain "cuda" (the current device), which is
    # not necessarily the device the KV pool and the streams were created on.
    kv_device = persistent["k_cache"].device
    inp = build_iter_inputs(
        rng,
        cfg.num_kv_heads,
        cfg.num_q_heads,
        cfg.head_size,
        cfg.block_size,
        cfg.num_blocks,
        cfg.compute_dtype,
        device=kv_device,
    )
    if log_shape:
        print(f"[repro] iter={iter_idx:>6} bs={inp['batch_size']:>4} "
              f"ctx={inp['ctx_len']:>5} qlen={inp['qlen']} "
              f"total_qo={inp['total_qo']:>5} -> calling pa_fwd_asm", flush=True)

    if cfg.kv_dtype == "fp8":
        k_cache = persistent["k_quant"]
        v_cache_asm = persistent["v_quant_asm"]
        k_scale = persistent["k_scale_asm"]
        v_scale = persistent["v_scale_asm"]
    else:
        k_cache = persistent["k_cache"]
        v_cache_asm = persistent["v_cache_asm"]
        k_scale = None
        v_scale = None

    # The inputs above were filled by work queued on the current stream, while
    # the attention call below is issued on `stream`. Without an explicit
    # dependency the kernel may start before query/block_tables are fully
    # written and read stale or uninitialized block ids. This is a device-side
    # dependency only; the host still does not block between launches.
    stream.wait_stream(torch.cuda.current_stream(kv_device))

    with torch.cuda.stream(stream):
        out = aiter.pa_fwd_asm(
            inp["query"],
            k_cache,
            v_cache_asm,
            inp["block_tables"],
            inp["seq_lens"],
            inp["block_tables"].stride(0),
            max_qlen=inp["max_qlen"],
            K_QScale=k_scale,
            V_QScale=v_scale,
            out_=None,
            qo_indptr=inp["qo_indptr"],
            high_precision=0,
        )

    # The inputs were allocated on the producer stream but are consumed on
    # `stream`. Tell the caching allocator so their memory is not handed out
    # again until the work queued on `stream` has finished, even if the caller
    # drops the record below before the next device-wide sync.
    for tensor in (
        inp["query"],
        inp["block_tables"],
        inp["seq_lens"],
        inp["qo_indptr"],
    ):
        tensor.record_stream(stream)

    # Still return refs to *all* input tensors: the caller keeps them alive as
    # an additional safeguard and uses the shape fields for crash reports.
    rec = dict(
        iter=iter_idx,
        batch_size=inp["batch_size"],
        ctx_len=inp["ctx_len"],
        qlen=inp["qlen"],
        total_qo=inp["total_qo"],
        out=out,
        _query=inp["query"],
        _block_tables=inp["block_tables"],
        _seq_lens=inp["seq_lens"],
        _qo_indptr=inp["qo_indptr"],
    )
    return rec
+ + print(f"[repro] device={args.device}") + print(f"[repro] kv_dtype={cfg.kv_dtype} num_q_heads={cfg.num_q_heads} " + f"num_kv_heads={cfg.num_kv_heads} GQA={cfg.num_q_heads // cfg.num_kv_heads}") + print(f"[repro] head_size={cfg.head_size} block_size={cfg.block_size} " + f"num_blocks={cfg.num_blocks}") + print(f"[repro] max_iters={args.max_iters} streams={args.streams} " + f"sync_every={args.sync_every}") + + # --- allocate persistent KV cache (re-used across iterations, like a + # paged KV pool in production) + k_cache, v_cache = make_kv_cache( + cfg.num_blocks, cfg.block_size, cfg.num_kv_heads, cfg.head_size, + cfg.compute_dtype, device=args.device, + ) + persistent = { + "k_cache": k_cache, + "v_cache_asm": asm_v_shuffle(v_cache), + } + if cfg.kv_dtype == "fp8": + k_q, v_q, k_s_asm, v_s_asm = pertoken_quant_kvcache_symm( + k_cache, v_cache, quant_dtype=aiter.dtypes.fp8 + ) + persistent["k_quant"] = k_q + persistent["v_quant_asm"] = asm_v_shuffle(v_q) + persistent["k_scale_asm"] = k_s_asm + persistent["v_scale_asm"] = v_s_asm + + torch.cuda.synchronize() + print(f"[repro] KV cache allocated: K {k_cache.shape} {k_cache.dtype} " + f"V {v_cache.shape} {v_cache.dtype}") + if cfg.kv_dtype == "fp8": + print(f"[repro] quant K {persistent['k_quant'].shape} " + f"{persistent['k_quant'].dtype} " + f"K_QScale {persistent['k_scale_asm'].shape} " + f"{persistent['k_scale_asm'].dtype}") + + streams = [torch.cuda.Stream(device=args.device) for _ in range(args.streams)] + rng = random.Random(args.seed) + + t0 = time.time() + last_log_t = t0 + last_log_iter = 0 + keepalive: List[dict] = [] # keep refs so async kernels don't free inputs + + for i in range(args.max_iters): + s = streams[i % args.streams] + try: + rec = run_one(i, rng, cfg, persistent, s, log_shape=args.log_each_call) + keepalive.append(rec) + # bound memory: only hold last ~2 batches per stream + if len(keepalive) > args.streams * 4: + keepalive.pop(0) + except Exception as e: + print(f"\n[repro] !! 
exception at iter {i} (sync-launch): {e}") + traceback.print_exc() + return _report_crash(i, t0, args) + + if (i + 1) % args.sync_every == 0: + try: + torch.cuda.synchronize() + except Exception as e: + print(f"\n[repro] !! HIP error surfaced at sync after iter {i}: " + f"{type(e).__name__}: {e}") + traceback.print_exc() + # dump last batch shapes to help ASM team + recent = keepalive[-min(len(keepalive), 16):] + print(f"[repro] last {len(recent)} call shapes:") + for r in recent: + print(f" iter={r['iter']:>6} bs={r['batch_size']:>4} " + f"ctx={r['ctx_len']:>5} qlen={r['qlen']} " + f"total_qo={r['total_qo']:>5}") + return _report_crash(i, t0, args) + + now = time.time() + if now - last_log_t >= 5.0: + d_iter = (i + 1) - last_log_iter + d_t = now - last_log_t + ips = d_iter / d_t + print(f"[repro] iter {i + 1:>7}/{args.max_iters} " + f"{ips:>7.1f} iter/s elapsed={now - t0:>7.1f}s") + last_log_t = now + last_log_iter = i + 1 + + torch.cuda.synchronize() + dt = time.time() - t0 + print(f"\n[repro] DONE — {args.max_iters} iters OK in {dt:.1f}s " + f"({args.max_iters / dt:.1f} iter/s). No crash observed.") + return 0 + + +def _report_crash(iter_idx: int, t0: float, args) -> int: + dt = time.time() - t0 + print(f"\n[repro] CRASHED at iter {iter_idx} after {dt:.1f}s " + f"(kv_dtype={args.kv_dtype}, streams={args.streams})") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/op_tests/repros/pa_asm_fp8_min_repro.py b/op_tests/repros/pa_asm_fp8_min_repro.py new file mode 100644 index 0000000000..f604d4f9c5 --- /dev/null +++ b/op_tests/repros/pa_asm_fp8_min_repro.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +# Minimal aiter-only single-call reproducer for the fp8 ASM PA crash. +# +# This is *single call* of aiter.pa_fwd_asm — no streams, no loop, no race. +# Reproduces an HIP illegal memory access purely from one bad shape. 
def main():
    """Issue exactly one ``aiter.pa_fwd_asm`` call with fp8 per-token-quantized KV.

    The shape can be overridden through the PA_REPRO_BS, PA_REPRO_CTX,
    PA_REPRO_QLEN, PA_REPRO_NBLOCKS and PA_REPRO_PAD environment variables.

    Returns:
        0 after the call and the following device synchronize complete. Any
        error raised by the call or by the synchronize propagates to the
        caller; it is not caught here.
    """
    # Fixed seed: the random KV contents, query and block ids are the same on
    # every run. Statement order below therefore matters for reproducibility.
    torch.manual_seed(0)
    device = "cuda:0"
    torch.set_default_device(device)

    # ---- shape (matches Kimi-K2.5 MLA-via-MHA, eagle3 MTP, TP=8 per-rank) ----
    head_size = 128
    block_size = 16
    num_q_heads = 8  # per-rank Q heads (TP=8 on 64 heads)
    num_kv_heads = 1  # MLA absorbs to 1 latent KV head -> GQA=8

    # Environment overrides; each falls back to the shape of the failing call.
    batch_size = int(os.environ.get("PA_REPRO_BS", 128))
    ctx_len = int(os.environ.get("PA_REPRO_CTX", 6724))
    qlen = int(os.environ.get("PA_REPRO_QLEN", 3))
    num_blocks = int(os.environ.get("PA_REPRO_NBLOCKS", 8192))
    pad_bt = bool(int(os.environ.get("PA_REPRO_PAD", "1")))

    max_seq_len = 16384
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size  # 1024
    # Pages actually referenced by each sequence (ceil(ctx_len / block_size)).
    num_blocks_per_seq = (ctx_len + block_size - 1) // block_size

    print(f"[min-repro] batch={batch_size} ctx_len={ctx_len} qlen={qlen} "
          f"num_blocks={num_blocks} pad_block_tables={pad_bt}")
    print(f"[min-repro] num_blocks_per_seq={num_blocks_per_seq} "
          f"max_num_blocks_per_seq={max_num_blocks_per_seq}")

    # ---- allocate KV cache (bf16) and pertoken-quant it to fp8 ASM layout ----
    x = 16 // 2  # bf16 itemsize
    k_cache = torch.empty(
        (num_blocks, num_kv_heads, head_size // x, block_size, x),
        dtype=torch.bfloat16, device=device,
    ).uniform_(-1, 1)
    v_cache = torch.empty(
        (num_blocks, num_kv_heads, head_size, block_size),
        dtype=torch.bfloat16, device=device,
    ).uniform_(-1, 1)

    # pertoken quant -> fp8
    # Both caches are first rearranged to [num_blocks, num_kv_heads,
    # block_size, head_size] so that the last axis is one token's values.
    k_perm = (
        k_cache.permute(0, 1, 3, 2, 4)
        .reshape(num_blocks, num_kv_heads, block_size, -1).contiguous()
    )
    v_perm = (
        v_cache.permute(0, 1, 3, 2)
        .reshape(num_blocks, num_kv_heads, block_size, -1).contiguous()
    )
    k_q, k_scale_asm = pertoken_quant(k_perm, quant_dtype=aiter.dtypes.fp8)
    v_q, v_scale_asm = pertoken_quant(v_perm, quant_dtype=aiter.dtypes.fp8)
    # Quantized data goes back to the K / V cache layouts, now packed with the
    # fp8 element count per 16 bytes instead of the bf16 one.
    quant_x = 16 // aiter.dtypes.fp8.itemsize
    k_quant = (
        k_q.view(num_blocks, num_kv_heads, block_size, head_size // quant_x, quant_x)
        .permute(0, 1, 3, 2, 4).contiguous()
    )
    v_quant = (
        v_q.view(num_blocks, num_kv_heads, block_size, head_size)
        .permute(0, 1, 3, 2).contiguous()
    )
    # ASM V shuffle: [B, KVH, head_size, block_size] -> [B, KVH, block_size/x, head_size, x]
    qx = 16 // v_quant.element_size()
    v_quant_asm = (
        v_quant.view(num_blocks, num_kv_heads, head_size, block_size // qx, qx)
        .permute(0, 1, 3, 2, 4).contiguous()
    )

    # ---- per-iter request inputs ----
    # qo_indptr is the prefix sum of per-sequence query lengths: [0, qlen, 2*qlen, ...].
    qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    seq_lens_qo = torch.full((batch_size,), qlen, dtype=torch.int32, device=device)
    qo_indptr[1:] = torch.cumsum(seq_lens_qo, dim=0)
    total_qo = int(qo_indptr[-1].item())
    query = torch.empty(
        (total_qo, num_q_heads, head_size), dtype=torch.bfloat16, device=device,
    ).uniform_(-1, 1)
    seq_lens = torch.full((batch_size,), ctx_len, dtype=torch.int32, device=device)

    # Padded: every row is max_num_blocks_per_seq wide and the unused tail
    # stays 0. Tight: rows are exactly num_blocks_per_seq wide.
    if pad_bt:
        block_tables = torch.zeros(
            (batch_size, max_num_blocks_per_seq), dtype=torch.int32, device=device,
        )
    else:
        block_tables = torch.zeros(
            (batch_size, num_blocks_per_seq), dtype=torch.int32, device=device,
        )
    # Random page ids in [0, num_blocks) for the pages each sequence uses.
    for i in range(batch_size):
        idx = torch.randint(
            0, num_blocks, (num_blocks_per_seq,), dtype=torch.int32, device=device,
        )
        block_tables[i, :num_blocks_per_seq] = idx

    print(f"[min-repro] query={tuple(query.shape)} qo_indptr[-1]={total_qo}")
    print(f"[min-repro] block_tables={tuple(block_tables.shape)} "
          f"stride0={block_tables.stride(0)}")
    print(f"[min-repro] k_quant={tuple(k_quant.shape)} "
          f"v_quant_asm={tuple(v_quant_asm.shape)}")
    print(f"[min-repro] K_QScale={tuple(k_scale_asm.shape)} {k_scale_asm.dtype}")
    # Finish all setup work first so a later failure is attributable to the
    # attention call itself.
    torch.cuda.synchronize()

    # ---- single call ----
    print(f"[min-repro] calling pa_fwd_asm ...", flush=True)
    out = aiter.pa_fwd_asm(
        query,
        k_quant,
        v_quant_asm,
        block_tables,
        seq_lens,
        block_tables.stride(0),
        max_qlen=qlen,
        K_QScale=k_scale_asm,
        V_QScale=v_scale_asm,
        out_=None,
        qo_indptr=qo_indptr,
        high_precision=0,
    )
    # Asynchronous device errors from the call surface here.
    torch.cuda.synchronize()
    print(f"[min-repro] OK -> out={tuple(out.shape)} {out.dtype}")
    return 0
def build_kv(num_blocks, num_kv_heads, head_size, block_size, device):
    """Create a random bf16 KV pool and return its fp8 per-token-quantized form.

    Returns:
        ``(k_quant, v_quant_asm, k_scale_asm, v_scale_asm)`` where
        ``k_quant`` is [num_blocks, num_kv_heads, head_size // x, block_size, x],
        ``v_quant_asm`` is [num_blocks, num_kv_heads, block_size // x, head_size, x]
        (x = fp8 elements per 16 bytes), and the two scale tensors are exactly
        what ``pertoken_quant`` returned for K and V.
    """
    bf16_pack = 16 // 2
    bf16_on_dev = dict(dtype=torch.bfloat16, device=device)
    per_token = (num_blocks, num_kv_heads, block_size, -1)

    # K is drawn before V so the random stream is consumed in a fixed order.
    k_shape = (num_blocks, num_kv_heads, head_size // bf16_pack, block_size, bf16_pack)
    v_shape = (num_blocks, num_kv_heads, head_size, block_size)
    k_cache = torch.empty(k_shape, **bf16_on_dev).uniform_(-1, 1)
    v_cache = torch.empty(v_shape, **bf16_on_dev).uniform_(-1, 1)

    # Bring both caches to one row per token before quantizing.
    k_rows = k_cache.transpose(2, 3).reshape(*per_token).contiguous()
    v_rows = v_cache.transpose(2, 3).reshape(*per_token).contiguous()
    k_q, k_scale_asm = pertoken_quant(k_rows, quant_dtype=aiter.dtypes.fp8)
    v_q, v_scale_asm = pertoken_quant(v_rows, quant_dtype=aiter.dtypes.fp8)

    # Back to the cache layouts, now packed with the fp8 element count.
    fp8_pack = 16 // aiter.dtypes.fp8.itemsize
    k_quant = (
        k_q.view(num_blocks, num_kv_heads, block_size, head_size // fp8_pack, fp8_pack)
        .transpose(2, 3)
        .contiguous()
    )
    v_quant = (
        v_q.view(num_blocks, num_kv_heads, block_size, head_size)
        .transpose(2, 3)
        .contiguous()
    )

    # V additionally gets the ASM shuffle (block_size split into 16-byte chunks).
    v_pack = 16 // v_quant.element_size()
    v_quant_asm = (
        v_quant.view(num_blocks, num_kv_heads, head_size, block_size // v_pack, v_pack)
        .transpose(2, 3)
        .contiguous()
    )
    return k_quant, v_quant_asm, k_scale_asm, v_scale_asm
def main():
    """Repeat one (bs, ctx, qlen) ``pa_fwd_asm`` call and report the first failure.

    Returns:
        0 when all ``--n-repeat`` calls complete, 1 as soon as building the
        inputs, the call, or the following synchronize raises.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--bs", type=int, default=128)
    ap.add_argument("--ctx", type=int, default=6724)
    ap.add_argument("--qlen", type=int, default=3)
    ap.add_argument("--n-repeat", type=int, default=5)
    ap.add_argument("--num-blocks", type=int, default=8192)
    ap.add_argument("--kv-dtype", choices=["fp8", "bf16"], default="fp8")
    args = ap.parse_args()

    device = "cuda:0"
    torch.set_default_device(device)
    head_size, block_size, num_q_heads, num_kv_heads = 128, 16, 8, 1

    # The KV pool is built once and shared by every repeated call.
    if args.kv_dtype == "fp8":
        k_cache_for_call, v_cache_for_call, k_scale, v_scale = build_kv(
            args.num_blocks, num_kv_heads, head_size, block_size, device,
        )
    else:
        # bf16 path: no quant scales, V still needs ASM shuffle
        x = 16 // 2
        k_cache_for_call = torch.empty(
            (args.num_blocks, num_kv_heads, head_size // x, block_size, x),
            dtype=torch.bfloat16, device=device,
        ).uniform_(-1, 1)
        _v = torch.empty(
            (args.num_blocks, num_kv_heads, head_size, block_size),
            dtype=torch.bfloat16, device=device,
        ).uniform_(-1, 1)
        qx = 16 // _v.element_size()
        v_cache_for_call = (
            _v.view(args.num_blocks, num_kv_heads, head_size, block_size // qx, qx)
            .permute(0, 1, 3, 2, 4).contiguous()
        )
        k_scale = v_scale = None
    torch.cuda.synchronize()

    # Inputs and outputs of every completed call stay referenced here so their
    # memory is not released while the loop is running.
    history = []
    for i in range(args.n_repeat):
        try:
            # Same shape every time; only the seed (i + 1) changes, so the
            # query values and block ids differ from call to call.
            query, seq_lens, qo_indptr, block_tables = build_call(
                args.bs, args.ctx, args.qlen, args.num_blocks,
                num_kv_heads, num_q_heads, head_size, block_size, device,
                seed=i + 1,
            )
            out = aiter.pa_fwd_asm(
                query,
                k_cache_for_call,
                v_cache_for_call,
                block_tables,
                seq_lens,
                block_tables.stride(0),
                max_qlen=args.qlen,
                K_QScale=k_scale,
                V_QScale=v_scale,
                out_=None,
                qo_indptr=qo_indptr,
                high_precision=0,
            )
            # Sync after each call so an asynchronous error is reported on the
            # iteration that issued it (or the next setup step at the latest).
            torch.cuda.synchronize()
            history.append((query, seq_lens, qo_indptr, block_tables, out))
        except Exception as e:
            print(f"CRASH at iter={i} bs={args.bs} ctx={args.ctx} qlen={args.qlen}: "
                  f"{type(e).__name__}: {e}", flush=True)
            return 1

    print(f"ALL OK — {args.n_repeat} calls of bs={args.bs} ctx={args.ctx} "
          f"qlen={args.qlen}", flush=True)
    return 0
def build_call(rng_seed, batch_size, ctx_len, qlen, num_kv_heads, num_q_heads,
               head_size, block_size, num_blocks, device):
    """Build the per-call inputs for one ``pa_fwd_asm`` invocation.

    The torch RNG is re-seeded with ``rng_seed`` first, so the query values and
    block ids are a function of the arguments only. ``num_kv_heads`` is
    accepted but not used here.

    Returns:
        Dict with ``query`` [batch_size * qlen, num_q_heads, head_size] (bf16),
        ``seq_lens`` (all ``ctx_len``), ``qo_indptr`` (prefix sum of ``qlen``),
        ``block_tables`` [batch_size, ceil(16384 / block_size)] whose first
        ceil(ctx_len / block_size) columns hold random ids in
        [0, num_blocks) and whose remaining columns are 0, plus the scalars
        ``max_qlen``, ``bs``, ``ctx`` and ``qlen``.
    """
    torch.manual_seed(rng_seed)

    def _ceil_div(numer, denom):
        return (numer + denom - 1) // denom

    table_width = _ceil_div(16384, block_size)
    used_pages = _ceil_div(ctx_len, block_size)
    int32_on_dev = dict(dtype=torch.int32, device=device)

    qo_indptr = torch.zeros(batch_size + 1, **int32_on_dev)
    q_per_seq = torch.full((batch_size,), qlen, **int32_on_dev)
    qo_indptr[1:] = torch.cumsum(q_per_seq, dim=0)
    total_qo = int(qo_indptr[-1].item())

    # First RNG consumer after seeding: the query values.
    query = torch.empty(
        (total_qo, num_q_heads, head_size), dtype=torch.bfloat16, device=device
    )
    query.uniform_(-1, 1)

    seq_lens = torch.full((batch_size,), ctx_len, **int32_on_dev)

    # Then one randint draw per sequence, in row order.
    block_tables = torch.zeros((batch_size, table_width), **int32_on_dev)
    for row in block_tables:
        row[:used_pages] = torch.randint(
            0, num_blocks, (used_pages,), **int32_on_dev
        )

    return {
        "query": query,
        "seq_lens": seq_lens,
        "qo_indptr": qo_indptr,
        "block_tables": block_tables,
        "max_qlen": qlen,
        "bs": batch_size,
        "ctx": ctx_len,
        "qlen": qlen,
    }
def main():
    """Replay a fixed sequence of ``pa_fwd_asm`` calls against one fp8 KV pool.

    Without flags the hard-coded seven-call sequence is replayed; with
    ``--repeat-only-bad`` the (bs=128, ctx=6724, qlen=3) shape is issued
    ``--n-repeat`` times instead.

    Returns:
        0 when every call completes, 1 on the first exception.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--repeat-only-bad", action="store_true",
                    help="Skip iters 0-2, repeat iter-3 shape N times.")
    ap.add_argument("--n-repeat", type=int, default=20)
    ap.add_argument("--num-blocks", type=int, default=8192)
    args = ap.parse_args()

    device = "cuda:0"
    torch.set_default_device(device)

    head_size, block_size, num_q_heads, num_kv_heads = 128, 16, 8, 1
    print(f"[seq-repro] KV pool: num_blocks={args.num_blocks} GQA=8 "
          f"head_size={head_size} block_size={block_size}")

    # One quantized KV pool shared by all calls in the sequence.
    k_quant, v_quant_asm, k_scale, v_scale = build_kv(
        args.num_blocks, num_kv_heads, head_size, block_size, device,
    )
    torch.cuda.synchronize()

    # Each entry is (seed passed to build_call, batch_size, ctx_len, qlen).
    if args.repeat_only_bad:
        seq = [(3 + i, 128, 6724, 3) for i in range(args.n_repeat)]
        print(f"[seq-repro] mode: repeat-only-bad, "
              f"{args.n_repeat}x (bs=128, ctx=6724, qlen=3)")
    else:
        seq = [
            (0, 96, 1540, 3),
            (1, 64, 6361, 3),
            (2, 128, 919, 3),
            (3, 128, 6724, 3),
            (4, 128, 168, 3),
            (5, 128, 1024, 3),  # extras to keep going
            (6, 96, 4096, 4),
        ]

    history = []  # holds all input tensors so they cannot be freed
    for idx, (seed_for_call, bs, ctx, qlen) in enumerate(seq):
        try:
            call = build_call(
                rng_seed=seed_for_call,
                batch_size=bs, ctx_len=ctx, qlen=qlen,
                num_kv_heads=num_kv_heads, num_q_heads=num_q_heads,
                head_size=head_size, block_size=block_size,
                num_blocks=args.num_blocks, device=device,
            )
            print(f"[seq-repro] iter={idx} bs={call['bs']:>4} "
                  f"ctx={call['ctx']:>5} qlen={call['qlen']} "
                  f"-> pa_fwd_asm", flush=True)
            out = aiter.pa_fwd_asm(
                call["query"],
                k_quant,
                v_quant_asm,
                call["block_tables"],
                call["seq_lens"],
                call["block_tables"].stride(0),
                max_qlen=call["max_qlen"],
                K_QScale=k_scale,
                V_QScale=v_scale,
                out_=None,
                qo_indptr=call["qo_indptr"],
                high_precision=0,
            )
            # Record the call (inputs + output) before syncing, so the tensors
            # stay referenced even if the synchronize below raises.
            call["out"] = out
            history.append(call)
            torch.cuda.synchronize()
            print(f"[seq-repro] ok out={tuple(out.shape)}", flush=True)
        except Exception as e:
            print(f"[seq-repro] !! CRASH at iter={idx} bs={bs} ctx={ctx} "
                  f"qlen={qlen}: {type(e).__name__}: {e}")
            return 1

    print(f"\n[seq-repro] ALL OK — {len(seq)} calls completed.")
    return 0
def run_one(bs, ctx, qlen, n_repeat=5, num_blocks=8192, timeout=60):
    """Run ``pa_asm_fp8_repeat_call.py`` for one shape in a child process.

    The child is started with ``AMD_SERIALIZE_KERNEL=3`` and
    ``HIP_LAUNCH_BLOCKING=1`` added to the current environment.

    Args:
        bs, ctx, qlen: Shape forwarded as ``--bs`` / ``--ctx`` / ``--qlen``.
        n_repeat: Forwarded as ``--n-repeat``.
        num_blocks: Forwarded as ``--num-blocks``.
        timeout: Seconds to wait for the child before giving up.

    Returns:
        ``(tag, output)``. ``output`` is the child's combined stdout/stderr
        (whatever was captured so far on timeout). ``tag`` is one of:

        * ``"OK"``       - the child printed "ALL OK".
        * ``"CRASH@k"``  - the output mentions a crash; ``k`` is parsed from
          "CRASH at iter=k", or -1 when no iteration was printed.
        * ``"SIG<n>"``   - the child was terminated by signal ``n`` without
          printing any of the recognized crash texts (e.g. the runtime aborted
          the process before Python could report an exception).
        * ``"TIMEOUT"``  - the child did not finish within ``timeout`` seconds.
        * ``"UNKNOWN"``  - none of the above.
    """
    import re  # local import: this block is the only user of `re`

    script = os.path.join(os.path.dirname(__file__), "pa_asm_fp8_repeat_call.py")
    cmd = [
        sys.executable,
        script,
        "--bs", str(bs),
        "--ctx", str(ctx),
        "--qlen", str(qlen),
        "--n-repeat", str(n_repeat),
        "--num-blocks", str(num_blocks),
    ]
    env = dict(os.environ)
    env["AMD_SERIALIZE_KERNEL"] = "3"
    env["HIP_LAUNCH_BLOCKING"] = "1"
    try:
        r = subprocess.run(cmd, env=env, timeout=timeout,
                           stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    except subprocess.TimeoutExpired as exc:
        # Keep whatever the child managed to print; it is often the only hint
        # about where it hung.
        return "TIMEOUT", (exc.output or b"").decode(errors="ignore")

    out = r.stdout.decode(errors="ignore")
    if "ALL OK" in out:
        return "OK", out
    if "CRASH" in out or "HIP error" in out or "illegal memory" in out:
        # find first crash iter from "CRASH at iter=N"
        m = re.search(r"CRASH at iter=(\d+)", out)
        crash_iter = int(m.group(1)) if m else -1
        return f"CRASH@{crash_iter}", out
    if r.returncode < 0:
        # Negative return code == killed by a signal. Without this the cell
        # would be indistinguishable from a benign "UNKNOWN".
        return f"SIG{-r.returncode}", out
    return "UNKNOWN", out
CRASH@k = launch error surfaced at call k " + f"(0-indexed; means call k-1 corrupted device).") + print() + + qlens = [1, 2, 3, 4] + ctx_lens = [128, 512, 1024, 2048, 4096, 6724, 8192, 12288, 16384] + batch_sizes = [16, 32, 64, 96, 128] + + for qlen in qlens: + print(f"## qlen={qlen}") + header = "ctx \\ bs |" + "".join(f"{b:>10} |" for b in batch_sizes) + print(header) + print("-" * len(header)) + for ctx in ctx_lens: + row = f"{ctx:>8} |" + for bs in batch_sizes: + tag, _ = run_one(bs, ctx, qlen, n_repeat=args.n_repeat) + row += f"{tag:>10} |" + print(row, flush=True) + print() + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/op_tests/test_pa_block_id_truncation.py b/op_tests/test_pa_block_id_truncation.py new file mode 100644 index 0000000000..994c5a5865 --- /dev/null +++ b/op_tests/test_pa_block_id_truncation.py @@ -0,0 +1,341 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Reproduce the aiter ASM paged-attention block_id truncation issue. + +When the block_id loaded from the block_tables tensor crosses 65,535 +(= 2^16), the aiter precompiled ASM `pa_*.co` family on gfx950/gfx942 +reads from the wrong physical KV slot — consistent with a 16-bit +narrowing (`block_id & 0xFFFF`) of the loaded value before it is used +in slot-address arithmetic. + +Strategy: + * Allocate a KV pool with > 65,535 physical blocks (NUM_BLOCKS = 70,000). + * Fill two specific blocks (one below 65,535, one above) with a + distinctive constant, leave everything else at zero. + * Run pa_fwd_asm on a single sequence whose block_tables points at + each chosen block in turn, with context_lens = block_size. + * Because the chosen block is filled with a constant V, the attention + output equals that constant (softmax over a single block's slots + sums to 1, weighted with constant V). 
+ +If the kernel narrows block_id to 16 bits, the high block_id (= 67,000) +wraps to 67,000 - 65,536 = 1,464, an unfilled block that contains zeros, +so the output collapses to ~0 instead of the expected fingerprint. + +Empirical result on gfx950 (MI355X), aiter built 2026-04-20+: + Both the qlen=1 kernel (`pa_bf16_noquant_gqa8_1tg_4w.co`) and the + qlen=4 MTP kernel (`pa_bf16_noquant_gqa8_1tg_4w_mtp_msk1.co`) return + 0.0000 for block_id = 67,000 instead of the expected 0.7500. The wrap + target (1,464) matches `block_id & 0xFFFF`. + + The reproduction requires NUM_KV_HEADS = 8 to match production + per-block stride (32 KB). With NUM_KV_HEADS = 1 (4 KB stride) the bug + does not surface — likely because some tile-level address calculation + in the kernel only narrows block_id when iterating over enough KV + heads. Either way, this file reproduces the production-relevant + configuration. + +Run: + pytest /root/aiter/op_tests/test_pa_block_id_truncation.py -v -s + +Or as a script: + python /root/aiter/op_tests/test_pa_block_id_truncation.py +""" + +import pytest +import torch + +import aiter + +# ---------- configuration matching the ATOM Eagle3 draft signature ---------- +# Production layout per TP=8 rank: num_q_heads = num_kv_heads = 8 (full MHA). +# aiter's gqa-rounding selects the gqa8 kernel either way. +# +# Critical: per-block stride must match production for the i32-overflow +# hypothesis to be testable. With NUM_KV_HEADS=8, HEAD_DIM=128, BLOCK_SIZE=16, +# bf16 elem_size=2: +# per_block_stride = 16 × 8 × 128 × 2 = 32,768 bytes +# i32 overflow boundary = 2^31 / 32768 = 65,536 +# Lowering NUM_KV_HEADS would shrink the stride and push the overflow +# boundary far above any practical block_id, masking the bug. +NUM_Q_HEADS = 8 +NUM_KV_HEADS = 8 +HEAD_DIM = 128 +BLOCK_SIZE = 16 + +# Need num_blocks > 65535 to trigger the crossing. +NUM_BLOCKS = 70_000 + +# Block IDs to fingerprint and probe. 
Layout: +# 1,000 — safely below the boundary (sanity baseline) +# 65,535 — last value that fits in u16 (= 0xFFFF). Should still read +# correctly even if the kernel does `block_id & 0xFFFF`, +# because that operation is a no-op here. +# 65,536 — first value that overflows u16 (= 0x10000). If the kernel +# narrows to 16 bits, this wraps to 0 and reads block 0. +# 67,000 — well above the boundary; wraps to 67000 - 65536 = 1,464. +SAFE_BLOCK_ID = 1_000 +EDGE_LAST_SAFE = 65_535 +EDGE_FIRST_BUGGY = 65_536 +BUGGY_BLOCK_ID = 67_000 + +# Distinct fingerprint per block — kept small (< 1.0) to stay well within +# bf16 precision after softmax normalization. +SIG_SAFE = 0.50 +SIG_EDGE_LAST = 0.30 +SIG_EDGE_FIRST = 0.40 +SIG_BUGGY = 0.75 + +_FINGERPRINTS = [ + (SAFE_BLOCK_ID, SIG_SAFE, "below_65535"), + (EDGE_LAST_SAFE, SIG_EDGE_LAST, "edge_65535_last_u16"), + (EDGE_FIRST_BUGGY, SIG_EDGE_FIRST, "edge_65536_first_overflow"), + (BUGGY_BLOCK_ID, SIG_BUGGY, "above_65535"), +] + + +def _build_kv_cache(): + """Allocate a sparse bf16 KV pool with one fingerprinted block per _FINGERPRINTS entry.""" + dtype = torch.bfloat16 + x = 16 // dtype.itemsize # = 8 for bf16 + assert HEAD_DIM % x == 0 + + # K layout: [num_blocks, num_kv_heads, head_dim/x, block_size, x] + k_cache = torch.zeros( + NUM_BLOCKS, + NUM_KV_HEADS, + HEAD_DIM // x, + BLOCK_SIZE, + x, + dtype=dtype, + device="cuda", + ) + # V layout: [num_blocks, num_kv_heads, head_dim, block_size] + v_cache = torch.zeros( + NUM_BLOCKS, + NUM_KV_HEADS, + HEAD_DIM, + BLOCK_SIZE, + dtype=dtype, + device="cuda", + ) + + for block_id, sig, _label in _FINGERPRINTS: + k_cache[block_id].fill_(sig) + v_cache[block_id].fill_(sig) + + return k_cache, v_cache + + +def _run_pa_fwd_asm(k_cache, v_cache, target_block_id, max_qlen=1): + """Run pa_fwd_asm with a single sequence whose block-table entries all + point at `target_block_id`. Returns the mean of the attention output. 
+ + `max_qlen` selects the kernel family: + max_qlen=1 → mtp=0 → pa_bf16_noquant_gqa8_1tg_4w.co (non-MTP decode) + max_qlen=4 → mtp=14→1 → pa_bf16_noquant_gqa8_1tg_4w_mtp_msk1.co (MTP) + """ + NUM_PAGES = 16 + block_tables = torch.full( + (1, NUM_PAGES), target_block_id, dtype=torch.int32, device="cuda" + ) + context_lens = torch.full( + (1,), BLOCK_SIZE * NUM_PAGES, dtype=torch.int32, device="cuda" + ) + cu_seqlens_q = torch.tensor([0, max_qlen], dtype=torch.int32, device="cuda") + + # Query: arbitrary nonzero values — softmax will normalize, V is constant. + query = torch.ones( + max_qlen, NUM_Q_HEADS, HEAD_DIM, dtype=torch.bfloat16, device="cuda" + ) + + out = aiter.pa_fwd_asm( + query, + k_cache, + v_cache, + block_tables, + context_lens, + block_tables.stride(0), + max_qlen=max_qlen, + K_QScale=None, + V_QScale=None, + out_=None, + qo_indptr=cu_seqlens_q, + high_precision=0, + ) + # Output shape: [max_qlen, num_q_heads, head_dim] — all elements should + # equal the fingerprint of target_block_id (because V is constant in + # that block and softmax weights sum to 1). + return out.float().mean().item() + + +@pytest.mark.parametrize( + "block_id,expected_sig,label", + _FINGERPRINTS, +) +@pytest.mark.parametrize( + "max_qlen,kernel_label", + [ + (1, "qlen1_non_MTP_kernel"), + (4, "qlen4_MTP_kernel"), + ], +) +def test_pa_fwd_asm_block_id_no_truncation( + block_id, expected_sig, label, max_qlen, kernel_label +): + """Output for a single-block sequence must match that block's fingerprint + regardless of whether block_id is below or above 65,535. Run for both + qlen=1 (non-MTP decode kernel) and qlen=4 (MTP kernel).""" + k_cache, v_cache = _build_kv_cache() + actual = _run_pa_fwd_asm(k_cache, v_cache, block_id, max_qlen=max_qlen) + + msg = ( + f"[{kernel_label}/{label}] block_id={block_id} max_qlen={max_qlen}: " + f"expected output ≈ {expected_sig}, got {actual:.6f}. 
" + ) + if block_id >= 65_536: + wrap = block_id - 65_536 + msg += ( + f"If the kernel narrows block_id to 16 bits, the high block_id " + f"would wrap to block {wrap} (unfilled, = 0), so output collapses " + f"toward ~0. Observed value of ~0 here is the bug signature." + ) + assert actual == pytest.approx(expected_sig, abs=1e-2), msg + + +if __name__ == "__main__": + # Standalone runner for quick repro without pytest infrastructure. + print( + f"Allocating KV pool: {NUM_BLOCKS} blocks × bf16 " + f"× {NUM_KV_HEADS} kv_head × {HEAD_DIM} head_dim × {BLOCK_SIZE} block_size" + ) + k_cache, v_cache = _build_kv_cache() + print(f" K cache {tuple(k_cache.shape)} = {k_cache.numel() * 2 / 1e9:.2f} GB") + print(f" V cache {tuple(v_cache.shape)} = {v_cache.numel() * 2 / 1e9:.2f} GB") + print() + + for max_qlen, kernel_label in [(1, "qlen1_non_MTP"), (4, "qlen4_MTP")]: + print(f"=== {kernel_label} (max_qlen={max_qlen}) ===") + for block_id, expected, label in _FINGERPRINTS: + actual = _run_pa_fwd_asm(k_cache, v_cache, block_id, max_qlen=max_qlen) + status = "OK" if abs(actual - expected) < 1e-2 else "BUG" + print( + f"[{status}] block_id={block_id:>7d} expected={expected:.4f} " + f"actual={actual:.4f} Δ={actual - expected:+.4f} ({label})" + ) + if status == "BUG" and block_id >= 65_536: + wrap = block_id & 0xFFFF + print( + f" → if block_id is narrowed to 16 bits, " + f"reads block {wrap} instead (unfilled = 0)." + ) + print() + + # ---- Performance comparison ---- + # Measure latency across different block_id ranges and batch sizes + # to verify no performance regression from the rebase fix. 
+ + print("=== Performance Comparison ===") + print( + f"{'scenario':<30s} {'batch':>5s} {'ctx_len':>7s} {'max_qlen':>8s} " + f"{'avg_us':>8s} {'std_us':>8s}" + ) + print("-" * 75) + + PERF_NUM_WARMUP = 5 + PERF_NUM_ITERS = 50 + + perf_configs = [ + ("low_block_ids", 1000, 1), + ("high_block_ids", 67000, 1), + ("low_block_ids", 1000, 4), + ("high_block_ids", 67000, 4), + ] + + for num_seqs in [1, 8, 32]: + for label, base_block_id, max_qlen in perf_configs: + num_pages = 16 + block_tables = torch.full( + (num_seqs, num_pages), + base_block_id, + dtype=torch.int32, + device="cuda", + ) + for i in range(num_seqs): + block_tables[i] = base_block_id + i + k_cache[base_block_id + i].fill_(0.25) + v_cache[base_block_id + i].fill_(0.25) + + ctx_len = BLOCK_SIZE * num_pages + context_lens = torch.full( + (num_seqs,), + ctx_len, + dtype=torch.int32, + device="cuda", + ) + total_q = num_seqs * max_qlen + cu_seqlens_q = torch.arange( + 0, + total_q + 1, + max_qlen, + dtype=torch.int32, + device="cuda", + ) + query = torch.randn( + total_q, + NUM_Q_HEADS, + HEAD_DIM, + dtype=torch.bfloat16, + device="cuda", + ) + + def _run(): + return aiter.pa_fwd_asm( + query, + k_cache, + v_cache, + block_tables, + context_lens, + block_tables.stride(0), + max_qlen=max_qlen, + K_QScale=None, + V_QScale=None, + out_=None, + qo_indptr=cu_seqlens_q, + high_precision=0, + ) + + for _ in range(PERF_NUM_WARMUP): + _run() + torch.cuda.synchronize() + + start_events = [ + torch.cuda.Event(enable_timing=True) for _ in range(PERF_NUM_ITERS) + ] + end_events = [ + torch.cuda.Event(enable_timing=True) for _ in range(PERF_NUM_ITERS) + ] + for i in range(PERF_NUM_ITERS): + start_events[i].record() + _run() + end_events[i].record() + torch.cuda.synchronize() + + latencies = [ + s.elapsed_time(e) * 1000 for s, e in zip(start_events, end_events) + ] + avg_us = sum(latencies) / len(latencies) + std_us = (sum((x - avg_us) ** 2 for x in latencies) / len(latencies)) ** 0.5 + + tag = f"{label}_qlen{max_qlen}" + 
print( + f" {tag:<28s} {num_seqs:>5d} {ctx_len:>7d} {max_qlen:>8d} " + f"{avg_us:>8.2f} {std_us:>8.2f}" + ) + print() + + print( + "Note: low_block_ids (<65536) vs high_block_ids (>65536) should show\n" + " similar latency — any significant gap indicates a regression." + )