From 4bc109fe408f46e29c25279a241b3b1c44f95d35 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 25 Dec 2025 01:08:37 -0600 Subject: [PATCH 01/31] Enable gptoss sink Signed-off-by: Linjun-AMD --- example/ck_tile/01_fmha/example_fmha_fwd.cpp | 5 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 29 +++++-- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 78 ++++++++++++++++--- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 18 ++++- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 52 +++++++++---- .../fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 31 +++++--- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 20 +++-- ...ock_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 21 +++-- ...litkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 21 +++-- ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 23 ++++-- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 21 +++-- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 21 +++-- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 28 +++++-- 13 files changed, 285 insertions(+), 83 deletions(-) diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index 6f2616cae56..f5ad6b2bc57 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -114,7 +114,8 @@ auto create_args(int argc, char* argv[]) .insert("kv_eff_lens", "", "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override."); + "Comma-separated list of length 'b'. 
If empty, no override.") + .insert("init_sink", "0", "0: disable attention sink; non-zero: enable a randomly initialized per-head sink for validation"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -157,6 +158,7 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); std::string init_method = arg_parser.get_str("init"); uint32_t seed = arg_parser.get_uint32("seed"); + int init_sink_value = arg_parser.get_int("init_sink"); ck_tile::stream_config stream_config{nullptr, true, @@ -203,6 +205,7 @@ auto run(const ck_tile::ArgParser& arg_parser) init_method, seed, do_validation, + init_sink_value, stream_config, json); } diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ba55d6d722a..ba0615d4a79 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -230,6 +230,7 @@ struct fmha_fwd_args // array [batch + 1]. (Used with padding) const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length // array [batch + 1].
(Used with padding) + const void* sink_ptr = nullptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -317,6 +318,7 @@ struct fmha_fwd_pagedkv_args const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; + const void* sink_ptr = nullptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -400,6 +402,7 @@ struct fmha_fwd_splitkv_args const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; + const void* sink_ptr = nullptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -476,6 +479,7 @@ struct fmha_fwd_appendkv_args ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache) + const void* sink_ptr = nullptr; ck_tile::index_t stride_q; ck_tile::index_t stride_k; @@ -519,6 +523,7 @@ struct fmha_batch_prefill_args // 1) + // kargs.kv_last_page_lens[b] const void* seqstart_q_ptr; + const void* sink_ptr = nullptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -627,7 +632,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.s_randval, args.drop_seed_offset, args.cu_seqlen_q_ptr, - args.cu_seqlen_k_ptr); + args.cu_seqlen_k_ptr, + args.sink_ptr); } else { // create batch mode kernel arguments @@ -677,7 +683,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.s_randval, args.drop_seed_offset, args.cu_seqlen_q_ptr, - args.cu_seqlen_k_ptr); + args.cu_seqlen_k_ptr, + args.sink_ptr); } }(); @@ -837,7 +844,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.window_size_right, args.sink_size, args.mask_type, - args.min_seqlen_q); + args.min_seqlen_q, + args.sink_ptr); } else { // create batch mode kernel arguments @@ -882,7 +890,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.window_size_left, args.window_size_right, args.sink_size, - args.mask_type); + args.mask_type, + args.sink_ptr); } }(); @@ -949,7 +958,8 @@ auto
fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.window_size_left, args.window_size_right, args.sink_size, - args.mask_type); + args.mask_type, + args.sink_ptr); } else { // create batch mode kernel arguments @@ -997,7 +1007,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.window_size_left, args.window_size_right, args.sink_size, - args.mask_type); + args.mask_type, + args.sink_ptr); } }(); @@ -1164,7 +1175,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.sink_ptr); } else { // create batch mode kernel arguments @@ -1220,7 +1232,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.sink_ptr); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 536fcb06922..2d65fe467e6 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -184,6 +184,7 @@ fwd_result fmha_fwd_run(mode_enum mode, std::string init_method, uint32_t seed, int do_validation, + int init_sink_value, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) { @@ -527,6 +528,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor sink_host({nhead}); ck_tile::HostTensor k_host( 0 < page_block_size ? 
get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) @@ -609,6 +611,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}( bias_host); } + else if(init_method == "ni") { ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); @@ -695,10 +698,15 @@ fwd_result fmha_fwd_run(mode_enum mode, iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); - + if(init_sink_value != 0) + { + ck_tile::FillUniformDistributionIntegerValue{30.f, 100.f, next_seed()}( + sink_host); + } ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); @@ -743,6 +751,7 @@ fwd_result fmha_fwd_run(mode_enum mode, q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); v_buf.ToDevice(v_host.data()); + sink_buf.ToDevice(sink_host.data()); knew_buf.ToDevice(knew_host.data()); vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); @@ -971,7 +980,10 @@ fwd_result fmha_fwd_run(mode_enum mode, args.q_ptr = q_buf.GetDeviceBuffer(); args.k_ptr = k_buf.GetDeviceBuffer(); args.v_ptr = v_buf.GetDeviceBuffer(); - + if(init_sink_value != 0) + args.sink_ptr = sink_buf.GetDeviceBuffer(); + else + args.sink_ptr = nullptr; args.batch = batch; args.seqlen_q = shape_seqlen_q; // unused in group mode args.hdim_q = hdim_q; @@ -1675,19 +1687,67 @@ fwd_result fmha_fwd_run(mode_enum mode, mask.type == mask_enum::mask_top_left)); } const 
ck_tile::HostTensor masked_s_host_ref = s_host_ref; - if(lse) + if(init_sink_value != 0) { - ck_tile:: - reference_batched_softmax( - s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref); + // Create extended tensor with sink token + ck_tile::HostTensor s_with_sinks_ref( + {nhead, real_seqlen_q, real_seqlen_k + 1}); + + // Copy original attention scores and append sink values + for(auto i_h = 0; i_h < nhead; i_h++) + { + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c); + } + // Append sink token at the end of each row + s_with_sinks_ref(i_h, i_r, real_seqlen_k) = scale_s_host * sink_host(i_h); + } + } + + // Compute softmax on extended tensor + ck_tile::HostTensor p_extended( + {nhead, real_seqlen_q, real_seqlen_k + 1}); + + if(lse) + { + ck_tile::reference_batched_softmax( + s_with_sinks_ref, p_extended, p_compute_element_func, lse_host_ref); + } + else + { + ck_tile::reference_batched_softmax( + s_with_sinks_ref, p_extended, p_compute_element_func); + } + + // Extract only the original columns (exclude sink token column) + p_host_ref.ForEach( + [&](auto& self, auto idx) { self(idx) = p_extended(idx[0], idx[1], idx[2]); }); } else { - ck_tile:: - reference_batched_softmax( + // No sink tokens - compute softmax directly + if(lse) + { + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref); + } + else + { + ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, p_compute_element_func); + } } - if(p_drop > 0) { ck_tile::HostTensor randval_host_ref( diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 73b6a329d18..e820c7d3642 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -77,6 +77,7 @@ struct 
FmhaBatchPrefillWithPagedKVCacheKernel const void* k_ptr; const void* v_ptr; void* o_ptr; + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -332,12 +333,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel float p_drop, bool s_randval, std::variant, std::pair> - drop_seed_offset) + drop_seed_offset, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, + sink_ptr, seqlen_q, -1, hdim_q, @@ -485,12 +488,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel float p_drop, bool s_randval, std::variant, std::pair> - drop_seed_offset) + drop_seed_offset, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, + sink_ptr, -1, // seqlen will be updated by another pointer -1, // hdim_q, @@ -701,6 +706,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel long_index_t batch_offset_o = 0; const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch]; + const float sink_value = kargs.sink_ptr != nullptr + ? 
*(static_cast(kargs.sink_ptr) + i_nhead) + : static_cast(-numeric::infinity()); #if 0 // we assume page_block_size=1 for now const int32_t last_page_len = kargs.kv_last_page_lens[i_batch]; #endif @@ -1111,7 +1119,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.kv_page_indices, kargs.stride_k, kargs.stride_v, - dropout); + dropout, + sink_value); } else { @@ -1131,7 +1140,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.kv_page_indices, kargs.stride_k, kargs.stride_v, - dropout); + dropout, + sink_value); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 4dd99a6ea96..ba8a1f3f970 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -89,6 +89,7 @@ struct FmhaFwdKernel const void* k_ptr; const void* v_ptr; void* o_ptr; + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -343,12 +344,14 @@ struct FmhaFwdKernel std::variant, std::pair> drop_seed_offset, const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr) + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, + sink_ptr, seqlen_q, seqlen_k, hdim_q, @@ -490,7 +493,8 @@ struct FmhaFwdKernel bool s_randval, const std::tuple& drop_seed_offset, const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr) + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -539,7 +543,8 @@ struct FmhaFwdKernel s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), cu_seqlen_q_ptr, - cu_seqlen_k_ptr); + cu_seqlen_k_ptr, + sink_ptr); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -591,7 +596,8 @@ struct FmhaFwdKernel bool s_randval, const std::tuple& drop_seed_offset, const void* cu_seqlen_q_ptr = nullptr, 
- const void* cu_seqlen_k_ptr = nullptr) + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -640,7 +646,8 @@ struct FmhaFwdKernel s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), cu_seqlen_q_ptr, - cu_seqlen_k_ptr); + cu_seqlen_k_ptr, + sink_ptr); } template @@ -688,12 +695,14 @@ struct FmhaFwdKernel std::variant, std::pair> drop_seed_offset, const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr) + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, + sink_ptr, -1, // seqlen will be updated by another pointer -1, // hdim_q, @@ -833,7 +842,8 @@ struct FmhaFwdKernel bool s_randval, const std::tuple& drop_seed_offset, const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr) + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -878,7 +888,8 @@ struct FmhaFwdKernel s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), cu_seqlen_q_ptr, - cu_seqlen_k_ptr); + cu_seqlen_k_ptr, + sink_ptr); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -926,7 +937,8 @@ struct FmhaFwdKernel bool s_randval, const std::tuple& drop_seed_offset, const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr) + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -971,7 +983,8 @@ struct FmhaFwdKernel s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), cu_seqlen_q_ptr, - cu_seqlen_k_ptr); + cu_seqlen_k_ptr, + sink_ptr); } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, @@ -1093,10 +1106,8 @@ struct FmhaFwdKernel { // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - // divide problem const auto [i_tile_m, i_tile_n, 
i_nhead, i_batch] = GetTileIndex(kargs); - const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); @@ -1107,6 +1118,9 @@ struct FmhaFwdKernel long_index_t batch_offset_randval = 0; long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; + const float sink_value = kargs.sink_ptr != nullptr + ? *(static_cast(kargs.sink_ptr) + i_nhead) + : static_cast(-numeric::infinity()); if constexpr(kIsGroupMode) { @@ -1525,7 +1539,6 @@ struct FmhaFwdKernel }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - auto o_acc_tile = [&]() { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { @@ -1564,7 +1577,8 @@ struct FmhaFwdKernel variant_params, block_indices, smem_ptr, - dropout); + dropout, + sink_value); } else { @@ -1581,7 +1595,8 @@ struct FmhaFwdKernel variant_params, block_indices, smem_ptr, - dropout); + dropout, + sink_value); } }(); @@ -1621,6 +1636,9 @@ struct FmhaFwdKernel constexpr bool PrefillCase = FmhaPipeline::kM0 > 64; // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + const float sink_value = kargs.sink_ptr != nullptr + ? 
*(static_cast(kargs.sink_ptr) + i_nhead) + : static_cast(-numeric::infinity()); const index_t i_m0 = i_tile_m * FmhaPipeline::kM0; const index_t i_n1 = i_tile_n * FmhaPipeline::kN1; @@ -2273,6 +2291,7 @@ struct FmhaFwdKernel mask, position_encoding, kargs.scale_s, + sink_value, smem_ptrk0, smem_ptrk1, smem_ptrv0, @@ -2289,7 +2308,8 @@ struct FmhaFwdKernel mask, position_encoding, kargs.scale_s, - smem_ptr); + smem_ptr, + sink_value); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index b75b35fc1e8..7225db281c6 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -123,6 +123,7 @@ struct FmhaFwdPagedKVKernel const void* k_ptr; const void* v_ptr; void* o_ptr; + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -328,12 +329,14 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, + sink_ptr, seqlen_q, seqlen_k, hdim_q, @@ -457,7 +460,8 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + const void* sink_ptr = nullptr) { return MakeKargsImpl(q_ptr, k_ptr, @@ -500,7 +504,8 @@ struct FmhaFwdPagedKVKernel window_size_left, window_size_right, sink_size, - mask_type); + mask_type, + sink_ptr); } template @@ -543,12 +548,14 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, - ck_tile::index_t min_seqlen_q) + ck_tile::index_t min_seqlen_q, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, + sink_ptr, -1, // seqlen will be 
updated by another pointer -1, // hdim_q, @@ -669,7 +676,8 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, - ck_tile::index_t min_seqlen_q) + ck_tile::index_t min_seqlen_q, + const void* sink_ptr = nullptr) { return MakeKargsImpl(q_ptr, k_ptr, @@ -709,7 +717,8 @@ struct FmhaFwdPagedKVKernel window_size_right, sink_size, mask_type, - min_seqlen_q); + min_seqlen_q, + sink_ptr); } CK_TILE_HOST static void PrintParameters(const Kargs& kargs, int num_batches) @@ -898,7 +907,6 @@ struct FmhaFwdPagedKVKernel // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); @@ -909,6 +917,9 @@ struct FmhaFwdPagedKVKernel long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; index_t kv_l2p_offset = 0; + const float sink_value = kargs.sink_ptr != nullptr + ? 
*(static_cast(kargs.sink_ptr) + i_nhead) + : static_cast(-numeric::infinity()); if constexpr(kIsGroupMode) { @@ -1348,7 +1359,8 @@ struct FmhaFwdPagedKVKernel variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_value); } else { @@ -1366,7 +1378,8 @@ struct FmhaFwdPagedKVKernel variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_value); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index bd5cddb5260..3e0f5205944 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -124,6 +124,7 @@ struct FmhaFwdSplitKVKernel const void* v_ptr; void* lse_acc_ptr; void* o_acc_ptr; + const void* sink_ptr; ck_tile::index_t batch; @@ -327,13 +328,15 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, lse_acc_ptr, o_acc_ptr, + sink_ptr, batch, seqlen_q, seqlen_k, @@ -455,13 +458,15 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, lse_acc_ptr, o_acc_ptr, + sink_ptr, batch, -1, // seqlen_q will be updated by another pointer -1, // seqlen_k will be updated by another pointer @@ -530,7 +535,6 @@ struct FmhaFwdSplitKVKernel { kargs.init_logits_soft_cap(logits_soft_cap); } - return kargs; } @@ -615,6 +619,9 @@ struct FmhaFwdSplitKVKernel long_index_t batch_offset_o_acc = 0; index_t kv_l2p_offset = 0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache + const float sink_value = kargs.sink_ptr != nullptr + ? 
*(static_cast(kargs.sink_ptr) + i_nhead) + : static_cast(-numeric::infinity()); if constexpr(kIsGroupMode) { @@ -698,7 +705,6 @@ struct FmhaFwdSplitKVKernel kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } } - // for simplicity, batch stride we just modify the pointer const index_t i_nhead_k = (kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk); @@ -1082,7 +1088,8 @@ struct FmhaFwdSplitKVKernel variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_value); } else { @@ -1102,7 +1109,8 @@ struct FmhaFwdSplitKVKernel variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_value); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index d55d0d93427..8cc17c87d04 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -163,7 +163,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const + void* smem_ptr, + const float sink_v) const { static_assert( std::is_same_v> && @@ -227,8 +228,16 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + { + set_tile(m, sink_v); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin]() { if constexpr(kHasSink) @@ -788,7 +797,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& 
block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const + void* smem_ptr, + const float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -812,7 +822,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 944d49a8aad..dd5b5fcd245 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -164,7 +164,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const + void* smem_ptr, + float sink_v) const { static_assert( std::is_same_v> && @@ -254,8 +255,16 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if((!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) && i_split == 0) + { + set_tile(m, SMPLComputeDataType{sink_v}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { @@ -879,7 +888,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const + void* smem_ptr, + float sink_v) const { 
return operator()(q_dram_block_window_tmp, identity{}, @@ -905,7 +915,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 26a4cc905c7..73ac966507a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -163,7 +163,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const + void* smem_ptr, + float sink_v) const { static_assert( std::is_same_v> && @@ -227,8 +228,16 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if((!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) && i_split == 0) + { + set_tile(m, SMPLComputeDataType{sink_v}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { @@ -453,6 +462,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + // printf("scale_s1: %f\n", scale_s); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -621,6 +631,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } else { + // printf("here use tmp\n"); auto row_max = scale_s * get_validated_m(m[i_idx]); return exp2(scale_s * m_old[i_idx] - row_max); } @@ -797,7 +808,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const AttentionVariantParams& variant_params, 
const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const + void* smem_ptr, + float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -823,7 +835,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index fe825a370a0..b65084e8b17 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -166,7 +166,8 @@ struct BlockFmhaPipelineQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { static_assert( std::is_same_v> && @@ -230,8 +231,16 @@ struct BlockFmhaPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + { + set_tile(m, sink_v); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } const auto q_origin = q_dram_window.get_window_origin(); @@ -786,7 +795,8 @@ struct BlockFmhaPipelineQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -809,7 +819,8 @@ struct BlockFmhaPipelineQRKSVS variant_params, block_indices, smem_ptr, - dropout); + dropout, + sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 
f57b89cf9dd..6758bffac6f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -188,7 +188,8 @@ struct BlockFmhaPipelineQRKSVSAsync const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { static_assert( std::is_same_v> && @@ -274,8 +275,16 @@ struct BlockFmhaPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + { + set_tile(m, sink_v); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } __builtin_amdgcn_sched_barrier(0); const auto q_origin = q_dram_window.get_window_origin(); @@ -880,7 +889,8 @@ struct BlockFmhaPipelineQRKSVSAsync const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -903,7 +913,8 @@ struct BlockFmhaPipelineQRKSVSAsync variant_params, block_indices, smem_ptr, - dropout); + dropout, + sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 26662dafeb9..372e9fe10d9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -148,7 +148,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + float sink_v) const { static_assert( std::is_same_v> && @@ -193,8 +194,16 @@ struct 
BlockFmhaPipelineQRKSVSAsyncTrload auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + { + set_tile(m, sink_v); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } const auto q_origin = q_dram_block_window_tmp.get_window_origin(); const auto [logical_seqlen_k_start, logical_seqlen_k_end] = @@ -649,6 +658,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload FmhaMask mask, PositionEncoding position_encoding, float scale_s, + float sink_v, void* __restrict__ smem_ptrk0, void* __restrict__ smem_ptrk1, void* __restrict__ smem_ptrv0, @@ -698,8 +708,16 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + { + set_tile(m, sink_v); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } const auto q_origin = q_dram_block_window_tmp.get_window_origin(); const auto [logical_seqlen_k_start, logical_seqlen_k_end] = From 6711cecf7ca251e1b663f51d049b85f2e5466ed0 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 25 Dec 2025 15:15:55 +0800 Subject: [PATCH 02/31] Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 73ac966507a..f351f42bd43 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -631,7 +631,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } else { - // printf("here use tmp\n"); auto row_max = scale_s * get_validated_m(m[i_idx]); return exp2(scale_s * m_old[i_idx] - row_max); } From 1e3c54d28020ccb964b87905b2c13f025a189c13 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 25 Dec 2025 15:16:06 +0800 Subject: [PATCH 03/31] Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index f351f42bd43..ae08279b373 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -462,7 +462,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - // printf("scale_s1: %f\n", scale_s); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = From 7c9cb83ca015a4689bd2f660f1e410d4288e2f8f Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 25 Dec 2025 01:35:32 -0600 Subject: [PATCH 04/31] add gptoss sink test Signed-off-by: Linjun-AMD --- example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh index 664c8254181..6be3a443b95 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh @@ -84,3 +84,12 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l # 1 1 1 1 1 1 1 1 1 
1 # l=2/r=0(br) l=2/r=0/s=2(br) +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 mask=1 + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 mask=0 + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 + +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 mask=0 + +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=2 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 \ No newline at end of file From 5163868319cacf62339ee2f4fde1fd39ba39745a Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 25 Dec 2025 01:51:05 -0600 Subject: [PATCH 05/31] update CHANGELOG.md Signed-off-by: Linjun-AMD --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a9b25b062a..14d91ad195d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,10 +7,11 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". -* Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines. +* Added streamingllm sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines. 
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline. * Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel. * Added FP8 KV cache support for FMHA batch prefill. +* Added gptoss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. ### Changed From 5ab683b02b6b90518e059d343dfafcdc13c75d29 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Mon, 29 Dec 2025 21:25:58 -0600 Subject: [PATCH 06/31] fix test args error Signed-off-by: Linjun-AMD --- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 21 ++++++++++++++----- test/ck_tile/fmha/test_fmha_fwd.cpp | 7 +++++-- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 2102fe768f1..55f2354ff45 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -196,7 +196,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t* page_idx, const index_t stride_k, const index_t stride_v, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { static_assert( std::is_same_v> && @@ -282,8 +283,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + { + set_tile(m, sink_v); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } __builtin_amdgcn_sched_barrier(0); const auto q_origin = q_dram_window.get_window_origin(); @@ -887,7 +896,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t* page_idx, const index_t stride_k, const index_t stride_v, - 
DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -913,7 +923,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync page_idx, stride_k, stride_v, - dropout); + dropout, + sink_v); } }; diff --git a/test/ck_tile/fmha/test_fmha_fwd.cpp b/test/ck_tile/fmha/test_fmha_fwd.cpp index b81fa88aa22..e9bd2549b13 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.cpp +++ b/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -120,8 +120,8 @@ const ck_tile::stream_config stream_config{ 1, // rotating_count_ }; -#define COMMON_ARGS \ - init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \ +#define COMMON_ARGS \ + init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, 0, \ stream_config auto EnableTestIf(bool condition) @@ -255,6 +255,7 @@ TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail) init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 0, + 1, //init_sink stream_config); ASSERT_EQ(result, fwd_result::invalid_args); } @@ -299,6 +300,7 @@ TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail) init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 0, + 1, //init_sink stream_config); ASSERT_EQ(result, fwd_result::invalid_args); } @@ -342,6 +344,7 @@ TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail) init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 0, + 1, //init_sink stream_config); ASSERT_EQ(result, fwd_result::invalid_args); } From 5c0e07abc2e214ec48c4986fd83794313fc93636 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 30 Dec 2025 11:32:52 +0800 Subject: [PATCH 07/31] Update test_fmha_fwd.cpp --- test/ck_tile/fmha/test_fmha_fwd.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ck_tile/fmha/test_fmha_fwd.cpp b/test/ck_tile/fmha/test_fmha_fwd.cpp index e9bd2549b13..c59ee7a67d8 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.cpp +++ 
b/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -255,7 +255,7 @@ TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail) init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 0, - 1, //init_sink + 1, // init_sink stream_config); ASSERT_EQ(result, fwd_result::invalid_args); } @@ -300,7 +300,7 @@ TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail) init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 0, - 1, //init_sink + 1, // init_sink stream_config); ASSERT_EQ(result, fwd_result::invalid_args); } @@ -344,7 +344,7 @@ TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail) init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 0, - 1, //init_sink + 1, // init_sink stream_config); ASSERT_EQ(result, fwd_result::invalid_args); } From 970b4f168636e989e41325d18432fa7e5ff40090 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 30 Dec 2025 03:46:29 -0600 Subject: [PATCH 08/31] update sink test Signed-off-by: Linjun-AMD --- example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh index 6be3a443b95..5d8cea0ce6b 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh @@ -84,12 +84,12 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l # 1 1 1 1 1 1 1 1 1 1 # l=2/r=0(br) l=2/r=0/s=2(br) -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 mask=1 +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink_value=1 mask=1 -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 
-vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 mask=0 +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink_value=1 mask=0 -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink_value=1 -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 mask=0 +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink_value=1 mask=0 -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=2 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 \ No newline at end of file +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=2 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink_value=1 \ No newline at end of file From 0eeedeb105946a75327cbf691477a36853ab4acf Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 30 Dec 2025 03:53:42 -0600 Subject: [PATCH 09/31] Revert "update sink test" This reverts commit 970b4f168636e989e41325d18432fa7e5ff40090. 
--- example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh index 5d8cea0ce6b..6be3a443b95 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh @@ -84,12 +84,12 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l # 1 1 1 1 1 1 1 1 1 1 # l=2/r=0(br) l=2/r=0/s=2(br) -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink_value=1 mask=1 +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 mask=1 -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink_value=1 mask=0 +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 mask=0 -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink_value=1 +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink_value=1 mask=0 +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 
-cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 mask=0 -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=2 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink_value=1 \ No newline at end of file +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=2 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 \ No newline at end of file From 31db4124744d530ca3fb65da18b7ca429e41fc44 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 30 Dec 2025 03:55:05 -0600 Subject: [PATCH 10/31] update sink test Signed-off-by: Linjun-AMD --- example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh index 6be3a443b95..746ff8c0e1e 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh @@ -84,12 +84,12 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l # 1 1 1 1 1 1 1 1 1 1 # l=2/r=0(br) l=2/r=0/s=2(br) -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 mask=1 +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1 -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 mask=0 +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0 $EXE -prec=fp16 -mode=0 
-b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 mask=0 +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1 $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=2 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 \ No newline at end of file From b37b17456ba50ae1311bafae8fbf4e133fd9cf91 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Mon, 5 Jan 2026 00:11:21 -0600 Subject: [PATCH 11/31] update valid sink_v in splitkv pipeline Signed-off-by: Linjun-AMD --- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 2 +- ...k_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 2 +- .../block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 2 +- ...fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 13 +++++++++++-- .../block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 12 +++++++++++- .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 2 +- .../pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 2 +- .../block_fmha_pipeline_qr_ks_vs_async_trload.hpp | 4 ++-- 8 files changed, 29 insertions(+), 10 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 2d65fe467e6..ee3cc5b4d3d 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -700,7 +700,7 @@ fwd_result fmha_fwd_run(mode_enum mode, iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); if(init_sink_value != 0) { 
- ck_tile::FillUniformDistributionIntegerValue{30.f, 100.f, next_seed()}( + ck_tile::FillUniformDistributionIntegerValue{30.f, 60.f, next_seed()}( sink_host); } ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 55f2354ff45..04d52b9ad51 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -283,7 +283,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + if( __builtin_isinf_sign(sink_v) >= 0) { set_tile(m, sink_v); set_tile(l, SMPLComputeDataType{1.0f}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index 8cc17c87d04..12b9c1dbcca 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -228,7 +228,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(m, sink_v); set_tile(l, SMPLComputeDataType{1.0f}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index dd5b5fcd245..08a65c5bfd2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -255,7 +255,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - if((!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) && i_split == 0) + if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) { set_tile(m, SMPLComputeDataType{sink_v}); set_tile(l, SMPLComputeDataType{1.0f}); @@ -308,7 +308,16 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS return o_acc; } } - + if(i_split > 0) + { + auto [start, end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split - 1); + if((__builtin_isinf_sign(sink_v) >= 0) && start >= end) + { + set_tile(m, SMPLComputeDataType{sink_v}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + } const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; // make sure the first tile is completely located in page-block (page-block size should be diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index ae08279b373..13cbcc1304b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -228,7 +228,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - if((!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) && i_split == 0) + if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) { set_tile(m, SMPLComputeDataType{sink_v}); set_tile(l, SMPLComputeDataType{1.0f}); @@ -281,6 +281,16 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } } + if(i_split > 0) + { + auto [start, end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, 
num_splits, i_split - 1); + if((__builtin_isinf_sign(sink_v) >= 0) && start >= end) + { + set_tile(m, SMPLComputeDataType{sink_v}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + } const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; // make sure the first tile is completely located in page-block (page-block size should be diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index b65084e8b17..5c5a26656cb 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -231,7 +231,7 @@ struct BlockFmhaPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(m, sink_v); set_tile(l, SMPLComputeDataType{1.0f}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 6758bffac6f..d8bc7dd40fc 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -275,7 +275,7 @@ struct BlockFmhaPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(m, sink_v); set_tile(l, SMPLComputeDataType{1.0f}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 372e9fe10d9..83b29658efa 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -194,7 +194,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto l = MLBlockTileType{}; clear_tile(o_acc); - if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(m, sink_v); set_tile(l, SMPLComputeDataType{1.0f}); @@ -708,7 +708,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto l = MLBlockTileType{}; clear_tile(o_acc); - if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(m, sink_v); set_tile(l, SMPLComputeDataType{1.0f}); From a20868eb4421327c202c33f3a4a3b266618a667e Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Mon, 5 Jan 2026 16:01:12 +0800 Subject: [PATCH 12/31] Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp --- .../block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 04d52b9ad51..295fe55cd3d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -283,7 +283,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - if( __builtin_isinf_sign(sink_v) >= 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(m, sink_v); set_tile(l, SMPLComputeDataType{1.0f}); From 81b02a639ce7ea14c5b38296af7587450bc3e527 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 6 Jan 2026 20:21:12 +0800 Subject: [PATCH 13/31] Update example_fmha_fwd.cpp --- example/ck_tile/01_fmha/example_fmha_fwd.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp 
b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index f5ad6b2bc57..3d729a272de 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -115,7 +115,7 @@ auto create_args(int argc, char* argv[]) "", "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" "Comma-separated list of length 'b'. If empty, no override.") - .insert("init_sink", "0", "value to init the output tensor sink value for validation"); + .insert("init_sink", "1", "value to init the output tensor sink value for validation"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); From ebf445f65c8f3dbba1dda3ea7f821f43bfa20338 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 6 Jan 2026 23:30:16 -0600 Subject: [PATCH 14/31] fix lse error Signed-off-by: Linjun-AMD --- ...a_batch_prefill_pipeline_qr_ks_vs_async.hpp | 9 ++++++++- ...lock_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 9 ++++++++- ...plitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 9 ++++++++- ...lock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 9 ++++++++- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 9 ++++++++- ...ock_fmha_pipeline_qr_ks_vs_async_trload.hpp | 18 ++++++++++++++++-- 6 files changed, 56 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 295fe55cd3d..b56746805d9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -311,7 +311,14 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, -numeric::infinity()); + if (__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse, 
-numeric::infinity()); + } store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index 12b9c1dbcca..8361f6ee3bc 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -267,7 +267,14 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, -numeric::infinity()); + if (__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse, -numeric::infinity()); + } store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 08a65c5bfd2..211adf56621 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -294,7 +294,14 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, -numeric::infinity()); + if (__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse_acc, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse_acc, -numeric::infinity()); + } if(get_thread_local_1d_id() < kM0) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 13cbcc1304b..971fcbe0a7b 100644 --- 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -269,7 +269,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, -numeric::infinity()); + if (__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse_acc, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse_acc, -numeric::infinity()); + } store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index d8bc7dd40fc..57c4d9c799e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -318,7 +318,14 @@ struct BlockFmhaPipelineQRKSVSAsync auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, -numeric::infinity()); + if (__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse, -numeric::infinity()); + } store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 83b29658efa..725d4de482e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -221,7 +221,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, -numeric::infinity()); + if (__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse_acc, 
SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse_acc, -numeric::infinity()); + } store_tile(lse_acc_dram_window_tmp, lse_acc); } @@ -735,7 +742,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, -numeric::infinity()); + if (__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse_acc, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse_acc, -numeric::infinity()); + } store_tile(lse_acc_dram_window_tmp, lse_acc); } From f2fddfa9e1eb03bd605d031b4354506661837b95 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 7 Jan 2026 00:02:40 -0600 Subject: [PATCH 15/31] fix clangformat error Signed-off-by: Linjun-AMD --- .../block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 2 +- .../pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 2 +- ...lock_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 2 +- .../pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 2 +- .../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 2 +- .../pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp | 4 ++-- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index b56746805d9..df92d5d351d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -311,7 +311,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - if (__builtin_isinf_sign(sink_v) >= 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(lse, SMPLComputeDataType{sink_v}); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index 8361f6ee3bc..e471b8ddc49 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -267,7 +267,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - if (__builtin_isinf_sign(sink_v) >= 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(lse, SMPLComputeDataType{sink_v}); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 211adf56621..720247e8cf1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -294,7 +294,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - if (__builtin_isinf_sign(sink_v) >= 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(lse_acc, SMPLComputeDataType{sink_v}); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 971fcbe0a7b..a28a430c7a5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -269,7 +269,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - if (__builtin_isinf_sign(sink_v) >= 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(lse_acc, SMPLComputeDataType{sink_v}); } diff --git 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 57c4d9c799e..c6771db931b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -318,7 +318,7 @@ struct BlockFmhaPipelineQRKSVSAsync auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - if (__builtin_isinf_sign(sink_v) >= 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(lse, SMPLComputeDataType{sink_v}); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 725d4de482e..0b17c14fc4d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -221,7 +221,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - if (__builtin_isinf_sign(sink_v) >= 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(lse_acc, SMPLComputeDataType{sink_v}); } @@ -742,7 +742,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - if (__builtin_isinf_sign(sink_v) >= 0) + if(__builtin_isinf_sign(sink_v) >= 0) { set_tile(lse_acc, SMPLComputeDataType{sink_v}); } From 1d4e21989bceb559523425637065d8df5efecb23 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 7 Jan 2026 20:28:54 -0600 Subject: [PATCH 16/31] fix aiter scale error Signed-off-by: Linjun-AMD --- ...k_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 5 +++++ .../block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 6 +++++- .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 13 ++++++++++++- .../pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 5 +++++ 
.../block_fmha_pipeline_qr_ks_vs_async_trload.hpp | 8 ++++++++ 5 files changed, 35 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index df92d5d351d..38f5c2e4559 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -92,6 +92,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync #if CK_TILE_FMHA_FWD_FAST_EXP2 static constexpr auto R_LOG2E = 1.0 / log2e_v; + static constexpr auto LOG2E = log2e_v; #endif static constexpr index_t kBlockPerCu = []() { @@ -285,7 +286,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync clear_tile(o_acc); if(__builtin_isinf_sign(sink_v) >= 0) { +#if CK_TILE_FMHA_FWD_FAST_EXP2 + set_tile(m, sink_v * LOG2E); +#else set_tile(m, sink_v); +#endif set_tile(l, SMPLComputeDataType{1.0f}); } else diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index a28a430c7a5..c04dbee01ee 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -230,7 +230,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS clear_tile(o_acc); if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) { - set_tile(m, SMPLComputeDataType{sink_v}); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + set_tile(m, sink_v * C_LOG2E); +#else + set_tile(m, sink_v); +#endif set_tile(l, SMPLComputeDataType{1.0f}); } else diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 5c5a26656cb..1df942abd5f 100644 --- 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -233,7 +233,11 @@ struct BlockFmhaPipelineQRKSVS clear_tile(o_acc); if(__builtin_isinf_sign(sink_v) >= 0) { +#if CK_TILE_FMHA_FWD_FAST_EXP2 + set_tile(m, sink_v * C_LOG2E); +#else set_tile(m, sink_v); +#endif set_tile(l, SMPLComputeDataType{1.0f}); } else @@ -274,7 +278,14 @@ struct BlockFmhaPipelineQRKSVS auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, -numeric::infinity()); + if(__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse_acc, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse_acc, -numeric::infinity()); + } store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index c6771db931b..bddb6db2cbd 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -87,6 +87,7 @@ struct BlockFmhaPipelineQRKSVSAsync #if CK_TILE_FMHA_FWD_FAST_EXP2 static constexpr auto R_LOG2E = 1.0 / log2e_v; + static constexpr auto LOG2E = log2e_v; #endif static constexpr index_t kBlockPerCu = []() { @@ -277,7 +278,11 @@ struct BlockFmhaPipelineQRKSVSAsync clear_tile(o_acc); if(__builtin_isinf_sign(sink_v) >= 0) { +#if CK_TILE_FMHA_FWD_FAST_EXP2 + set_tile(m, sink_v * LOG2E); +#else set_tile(m, sink_v); +#endif set_tile(l, SMPLComputeDataType{1.0f}); } else diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 0b17c14fc4d..95e20d95277 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -196,7 +196,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload clear_tile(o_acc); if(__builtin_isinf_sign(sink_v) >= 0) { +#if CK_TILE_FMHA_FWD_FAST_EXP2 + set_tile(m, sink_v * C_LOG2E); +#else set_tile(m, sink_v); +#endif set_tile(l, SMPLComputeDataType{1.0f}); } else @@ -717,7 +721,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload clear_tile(o_acc); if(__builtin_isinf_sign(sink_v) >= 0) { +#if CK_TILE_FMHA_FWD_FAST_EXP2 + set_tile(m, sink_v * C_LOG2E); +#else set_tile(m, sink_v); +#endif set_tile(l, SMPLComputeDataType{1.0f}); } else From a667752116cc19c26ebbc053a0a78a115a6e9213 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 8 Jan 2026 10:47:01 +0800 Subject: [PATCH 17/31] Update block_fmha_pipeline_qr_ks_vs.hpp --- .../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 1df942abd5f..b8e02038cf1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -280,11 +280,11 @@ struct BlockFmhaPipelineQRKSVS if(__builtin_isinf_sign(sink_v) >= 0) { - set_tile(lse_acc, SMPLComputeDataType{sink_v}); + set_tile(lse, SMPLComputeDataType{sink_v}); } else { - set_tile(lse_acc, -numeric::infinity()); + set_tile(lse, -numeric::infinity()); } store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); From 68ff3e2f2c2bb0a5273f010e4d24614f6178982c Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 7 Jan 2026 21:43:11 -0600 Subject: [PATCH 18/31] div scale_s for sink_value Signed-off-by: Linjun-AMD --- .../ops/fmha/kernel/fmha_batch_prefill_kernel.hpp | 7 ++++--- .../ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 14 ++++++++------ 
.../ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 7 ++++--- .../ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 7 ++++--- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index e820c7d3642..9796991f37b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -706,9 +706,10 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel long_index_t batch_offset_o = 0; const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch]; - const float sink_value = kargs.sink_ptr != nullptr - ? *(static_cast(kargs.sink_ptr) + i_nhead) - : static_cast(-numeric::infinity()); + const float sink_value = + kargs.sink_ptr != nullptr + ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s + : static_cast(-numeric::infinity()); #if 0 // we assume page_block_size=1 for now const int32_t last_page_len = kargs.kv_last_page_lens[i_batch]; #endif diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index ba8a1f3f970..b117b8fbea8 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1118,9 +1118,10 @@ struct FmhaFwdKernel long_index_t batch_offset_randval = 0; long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; - const float sink_value = kargs.sink_ptr != nullptr - ? *(static_cast(kargs.sink_ptr) + i_nhead) - : static_cast(-numeric::infinity()); + const float sink_value = + kargs.sink_ptr != nullptr + ? 
(*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s + : static_cast(-numeric::infinity()); if constexpr(kIsGroupMode) { @@ -1636,9 +1637,10 @@ struct FmhaFwdKernel constexpr bool PrefillCase = FmhaPipeline::kM0 > 64; // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - const float sink_value = kargs.sink_ptr != nullptr - ? *(static_cast(kargs.sink_ptr) + i_nhead) - : static_cast(-numeric::infinity()); + const float sink_value = + kargs.sink_ptr != nullptr + ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s + : static_cast(-numeric::infinity()); const index_t i_m0 = i_tile_m * FmhaPipeline::kM0; const index_t i_n1 = i_tile_n * FmhaPipeline::kN1; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index 7225db281c6..f078f19dc69 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -917,9 +917,10 @@ struct FmhaFwdPagedKVKernel long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; index_t kv_l2p_offset = 0; - const float sink_value = kargs.sink_ptr != nullptr - ? *(static_cast(kargs.sink_ptr) + i_nhead) - : static_cast(-numeric::infinity()); + const float sink_value = + kargs.sink_ptr != nullptr + ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s + : static_cast(-numeric::infinity()); if constexpr(kIsGroupMode) { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 3e0f5205944..25a8ce9c683 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -619,9 +619,10 @@ struct FmhaFwdSplitKVKernel long_index_t batch_offset_o_acc = 0; index_t kv_l2p_offset = 0; // logical-to-physical offset of seqlen_k coordinate. 
only used for paged-kvcache - const float sink_value = kargs.sink_ptr != nullptr - ? *(static_cast(kargs.sink_ptr) + i_nhead) - : static_cast(-numeric::infinity()); + const float sink_value = + kargs.sink_ptr != nullptr + ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s + : static_cast(-numeric::infinity()); if constexpr(kIsGroupMode) { From 1984232f8f31f5508f3a61be501d15ce5c649e6a Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Fri, 9 Jan 2026 16:06:41 +0800 Subject: [PATCH 19/31] Update fmha_fwd_runner.hpp --- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index ee3cc5b4d3d..332a078db7b 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1703,7 +1703,7 @@ fwd_result fmha_fwd_run(mode_enum mode, s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c); } // Append sink token at the end of each row - s_with_sinks_ref(i_h, i_r, real_seqlen_k) = scale_s_host * sink_host(i_h); + s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h); } } From 1db49953087e9740f6b01e29a4775d345abfcfe1 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Sun, 11 Jan 2026 21:05:19 -0600 Subject: [PATCH 20/31] update sink_value with bias Signed-off-by: Linjun-AMD --- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 8 +++++-- ...ock_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 10 +++++++- ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 23 ++++++++++++++++--- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 9 +++++--- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 16 +++++-------- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 14 ++++++++--- 6 files changed, 58 insertions(+), 22 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp 
index 38f5c2e4559..a0d639d1449 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -287,7 +287,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if(__builtin_isinf_sign(sink_v) >= 0) { #if CK_TILE_FMHA_FWD_FAST_EXP2 - set_tile(m, sink_v * LOG2E); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + set_tile(m, sink_v * LOG2E * scale_s); + else + set_tile(m, sink_v * LOG2E); #else set_tile(m, sink_v); #endif @@ -318,7 +322,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if(__builtin_isinf_sign(sink_v) >= 0) { - set_tile(lse, SMPLComputeDataType{sink_v}); + set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index e471b8ddc49..e516fc8eea0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -230,7 +230,15 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS clear_tile(o_acc); if(__builtin_isinf_sign(sink_v) >= 0) { +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + set_tile(m, sink_v * C_LOG2E * scale_s); + else + set_tile(m, sink_v * C_LOG2E); +#else set_tile(m, sink_v); +#endif set_tile(l, SMPLComputeDataType{1.0f}); } else @@ -269,7 +277,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS if(__builtin_isinf_sign(sink_v) >= 0) { - set_tile(lse, SMPLComputeDataType{sink_v}); + set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index c04dbee01ee..e4bd2b0e3d3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -231,7 +231,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) { #if CK_TILE_FMHA_FWD_FAST_EXP2 - set_tile(m, sink_v * C_LOG2E); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + set_tile(m, sink_v * C_LOG2E * scale_s); + else + set_tile(m, sink_v * C_LOG2E); #else set_tile(m, sink_v); #endif @@ -275,7 +279,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS if(__builtin_isinf_sign(sink_v) >= 0) { - set_tile(lse_acc, SMPLComputeDataType{sink_v}); + set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); } else { @@ -298,9 +302,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split - 1); if((__builtin_isinf_sign(sink_v) >= 0) && start >= end) { - set_tile(m, SMPLComputeDataType{sink_v}); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + set_tile(m, sink_v * C_LOG2E * scale_s); + else + set_tile(m, sink_v * C_LOG2E); +#else + set_tile(m, sink_v); +#endif set_tile(l, SMPLComputeDataType{1.0f}); } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } } const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index b8e02038cf1..4c07955ba52 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -234,7 +234,11 @@ struct BlockFmhaPipelineQRKSVS if(__builtin_isinf_sign(sink_v) >= 0) { #if CK_TILE_FMHA_FWD_FAST_EXP2 - set_tile(m, sink_v * C_LOG2E); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI || + BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + set_tile(m, sink_v * scale_s * C_LOG2E); + else + set_tile(m, sink_v * C_LOG2E); #else set_tile(m, sink_v); #endif @@ -245,7 +249,6 @@ struct BlockFmhaPipelineQRKSVS set_tile(m, -numeric::infinity()); clear_tile(l); } - const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin]() { @@ -280,7 +283,7 @@ struct BlockFmhaPipelineQRKSVS if(__builtin_isinf_sign(sink_v) >= 0) { - set_tile(lse, SMPLComputeDataType{sink_v}); + set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index bddb6db2cbd..7224ed3a708 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -279,7 +279,11 @@ struct BlockFmhaPipelineQRKSVSAsync if(__builtin_isinf_sign(sink_v) >= 0) { #if CK_TILE_FMHA_FWD_FAST_EXP2 - set_tile(m, sink_v * LOG2E); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI || + BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + set_tile(m, sink_v * scale_s * LOG2E); + else + set_tile(m, sink_v * LOG2E); #else set_tile(m, sink_v); #endif @@ -290,7 +294,6 @@ struct BlockFmhaPipelineQRKSVSAsync set_tile(m, -numeric::infinity()); clear_tile(l); } - __builtin_amdgcn_sched_barrier(0); const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin]() { @@ -325,7 +328,7 @@ struct BlockFmhaPipelineQRKSVSAsync if(__builtin_isinf_sign(sink_v) >= 0) { - set_tile(lse, 
SMPLComputeDataType{sink_v}); + set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); } else { @@ -496,17 +499,10 @@ struct BlockFmhaPipelineQRKSVSAsync block_indices.qo_head_idx, block_indices.kv_head_idx); }; -#if !CK_TILE_FMHA_FWD_FAST_EXP2 for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) { apply_logits_transform(s_acc.thread_buf_[i]); } -#else - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { - apply_logits_transform(s_acc.thread_buf_[i]); - } -#endif } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 95e20d95277..2be2036e204 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -197,7 +197,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload if(__builtin_isinf_sign(sink_v) >= 0) { #if CK_TILE_FMHA_FWD_FAST_EXP2 - set_tile(m, sink_v * C_LOG2E); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + set_tile(m, sink_v * C_LOG2E * scale_s); + else + set_tile(m, sink_v * C_LOG2E); #else set_tile(m, sink_v); #endif @@ -722,7 +726,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload if(__builtin_isinf_sign(sink_v) >= 0) { #if CK_TILE_FMHA_FWD_FAST_EXP2 - set_tile(m, sink_v * C_LOG2E); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + set_tile(m, sink_v * C_LOG2E * scale_s); + else + set_tile(m, sink_v * C_LOG2E); #else set_tile(m, sink_v); #endif @@ -752,7 +760,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload if(__builtin_isinf_sign(sink_v) >= 0) { - set_tile(lse_acc, SMPLComputeDataType{sink_v}); + set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); } else { From 7ae150ad32e480f5c5df5ee6c9ce4720116acbd5 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Mon, 12 Jan 
2026 11:22:29 +0800 Subject: [PATCH 21/31] Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp --- .../block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 29985894d58..23850d8cdc1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -276,7 +276,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t stride_v, const index_t page_stride_k, const index_t page_stride_v, - DropoutType& dropout + DropoutType& dropout, const float sink_v) const { static_assert( From d9a2b40ce185cd5239e902f92008dab6e43d0138 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Mon, 12 Jan 2026 11:27:27 +0800 Subject: [PATCH 22/31] Fix typo in dropout parameter in fmha_batch_prefill_kernel --- include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 6f41eedff47..aeea5220c55 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -1256,7 +1256,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel stride_v_for_pipeline, kargs.batch_stride_k, kargs.batch_stride_v, - dropout, + dropout, sink_value); } }(); From 1bc51b3d8d998b5928bd3cd655ca4817ebc33fd1 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Mon, 12 Jan 2026 14:01:25 +0800 Subject: [PATCH 23/31] Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp --- .../block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 4 ++-- 1 file changed, 2 insertions(+), 
2 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 23850d8cdc1..d18d3ecf5a6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -1053,7 +1053,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t page_stride_k, const index_t page_stride_v, DropoutType& dropout, - const float sink_v) const + float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -1082,7 +1082,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync page_stride_k, page_stride_v, dropout, - const float sink_v); + sink_v); } }; From 815044c2e4b0e9736e4faadf036217efc510b010 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 13 Jan 2026 08:46:17 +0800 Subject: [PATCH 24/31] Update example_fmha_fwd.cpp --- example/ck_tile/01_fmha/example_fmha_fwd.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index 3d729a272de..f5ad6b2bc57 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -115,7 +115,7 @@ auto create_args(int argc, char* argv[]) "", "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" "Comma-separated list of length 'b'. 
If empty, no override.") - .insert("init_sink", "1", "value to init the output tensor sink value for validation"); + .insert("init_sink", "0", "value to init the output tensor sink value for validation"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); From 9cd4f13beded6b23909099d473c8c08d0b8e6a5e Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 13 Jan 2026 16:52:01 +0800 Subject: [PATCH 25/31] Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 2be2036e204..aab79c52ae9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -231,7 +231,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload if(__builtin_isinf_sign(sink_v) >= 0) { - set_tile(lse_acc, SMPLComputeDataType{sink_v}); + set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); } else { From 3c68266563e33d7238ed12e789f0e3ef4e592fbe Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 13 Jan 2026 16:53:43 +0800 Subject: [PATCH 26/31] Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 720247e8cf1..72b61083688 
100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -296,7 +296,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS if(__builtin_isinf_sign(sink_v) >= 0) { - set_tile(lse_acc, SMPLComputeDataType{sink_v}); + set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); } else { From f0e1d503f3a256c97bbc4a0a915291ed20b1b8a0 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 13 Jan 2026 04:13:28 -0600 Subject: [PATCH 27/31] optimized some code Signed-off-by: Linjun-AMD --- CHANGELOG.md | 2 +- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 1 + include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp | 2 +- include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 4 ++-- include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 2 +- include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 2 +- ...lock_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 2 +- 7 files changed, 8 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 85ba5d8b453..066dc9aa3b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added FP8 KV cache support for FMHA batch prefill. * Added support for gfx1153 target. * Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. -* Added gptoss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. +* Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. 
### Changed diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 332a078db7b..dce5d030422 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -700,6 +700,7 @@ fwd_result fmha_fwd_run(mode_enum mode, iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); if(init_sink_value != 0) { + // sink is initialized to a fixed integer value for easy debugging and use 30 to 60 range for close to rowmax values. ck_tile::FillUniformDistributionIntegerValue{30.f, 60.f, next_seed()}( sink_host); } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index aeea5220c55..95e805cac3f 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -709,7 +709,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel const float sink_value = kargs.sink_ptr != nullptr ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : static_cast(-numeric::infinity()); + : -numeric::infinity(); const index_t seqlen_k = [&]() { if constexpr(kKVLookupTable == BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index b117b8fbea8..15a6529ebd3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1121,7 +1121,7 @@ struct FmhaFwdKernel const float sink_value = kargs.sink_ptr != nullptr ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : static_cast(-numeric::infinity()); + : -numeric::infinity(); if constexpr(kIsGroupMode) { @@ -1640,7 +1640,7 @@ struct FmhaFwdKernel const float sink_value = kargs.sink_ptr != nullptr ? 
(*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : static_cast(-numeric::infinity()); + : -numeric::infinity(); const index_t i_m0 = i_tile_m * FmhaPipeline::kM0; const index_t i_n1 = i_tile_n * FmhaPipeline::kN1; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index f078f19dc69..59d33f287ec 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -920,7 +920,7 @@ struct FmhaFwdPagedKVKernel const float sink_value = kargs.sink_ptr != nullptr ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : static_cast(-numeric::infinity()); + : -numeric::infinity(); if constexpr(kIsGroupMode) { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 25a8ce9c683..a7daa31d4a5 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -622,7 +622,7 @@ struct FmhaFwdSplitKVKernel const float sink_value = kargs.sink_ptr != nullptr ? 
(*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : static_cast(-numeric::infinity()); + : -numeric::infinity(); if constexpr(kIsGroupMode) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 72b61083688..a242fec11ec 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -257,7 +257,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS clear_tile(o_acc); if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) { - set_tile(m, SMPLComputeDataType{sink_v}); + set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E}); set_tile(l, SMPLComputeDataType{1.0f}); } else From 2862c13daa154b4bc3ab19c645ce364b7e551c9e Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 13 Jan 2026 04:36:27 -0600 Subject: [PATCH 28/31] fix splitkv error Signed-off-by: Linjun-AMD --- .../block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 2 +- .../fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index a242fec11ec..adc8ea5a90c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -294,7 +294,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - if(__builtin_isinf_sign(sink_v) >= 0) + if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0) { set_tile(lse_acc, 
SMPLComputeDataType{sink_v * scale_s}); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index e4bd2b0e3d3..c09330f8471 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -277,7 +277,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - if(__builtin_isinf_sign(sink_v) >= 0) + if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0) { set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); } From be921712ec8cd627bdde4b830333f7925cc8c3d8 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 13 Jan 2026 19:43:32 -0600 Subject: [PATCH 29/31] update sink reference Signed-off-by: Linjun-AMD --- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 39 ++++++++++++++------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index dce5d030422..90590dc88fc 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -149,6 +149,28 @@ int override_num_splits_if_necessary( return num_splits; } +template +void copy_attention_scores_with_sink(const ck_tile::HostTensor& s_host_ref, + const ck_tile::HostTensor& sink_host, + ck_tile::HostTensor& s_with_sinks_ref, + ck_tile::index_t nhead, + ck_tile::index_t real_seqlen_q, + ck_tile::index_t real_seqlen_k) +{ + for(auto i_h = 0; i_h < nhead; i_h++) + { + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c); + } + // Append sink token at the end of each row + s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h); + } + } +} + template fwd_result 
fmha_fwd_run(mode_enum mode, ck_tile::index_t batch, @@ -700,7 +722,8 @@ fwd_result fmha_fwd_run(mode_enum mode, iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); if(init_sink_value != 0) { - // sink is initialized to a fixed integer value for easy debugging and use 30 to 60 range for close to rowmax values. + // sink is initialized to a fixed integer value for easy debugging and use 30 to 60 range + // for close to rowmax values. ck_tile::FillUniformDistributionIntegerValue{30.f, 60.f, next_seed()}( sink_host); } @@ -1695,18 +1718,8 @@ fwd_result fmha_fwd_run(mode_enum mode, {nhead, real_seqlen_q, real_seqlen_k + 1}); // Copy original attention scores and append sink values - for(auto i_h = 0; i_h < nhead; i_h++) - { - for(auto i_r = 0; i_r < real_seqlen_q; i_r++) - { - for(auto i_c = 0; i_c < real_seqlen_k; i_c++) - { - s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c); - } - // Append sink token at the end of each row - s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h); - } - } + copy_attention_scores_with_sink( + s_host_ref, sink_host, s_with_sinks_ref, nhead, real_seqlen_q, real_seqlen_k); // Compute softmax on extended tensor ck_tile::HostTensor p_extended( From 2ad9d5eb9db153dc790d9c83c6bcef0b9747932b Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 14 Jan 2026 11:01:32 +0800 Subject: [PATCH 30/31] Update fmha_fwd_runner.hpp --- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 19eb3d4037d..0c988b2acce 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -149,8 +149,8 @@ int override_num_splits_if_necessary( return num_splits; } -template -void copy_attention_scores_with_sink(const ck_tile::HostTensor& s_host_ref, +template +void copy_attention_scores_with_sink(const 
ck_tile::HostTensor& s_host_ref, const ck_tile::HostTensor& sink_host, ck_tile::HostTensor& s_with_sinks_ref, ck_tile::index_t nhead, From fdd767e5d916efb1788694af2c833ae1bfa5b956 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 14 Jan 2026 13:39:47 +0800 Subject: [PATCH 31/31] Update smoke_test_fwd_sink.sh --- example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh index 746ff8c0e1e..5c9d3132b3f 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh @@ -91,5 +91,3 @@ $EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse $EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1 - -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=2 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 \ No newline at end of file