From 5bfda1f662e886eff6a67135ffa4b7d2999e725d Mon Sep 17 00:00:00 2001 From: Mohsen Saffari Date: Wed, 19 Nov 2025 14:02:24 +0000 Subject: [PATCH 1/5] Add validity checks for MoE FlatMM scatter and enable bf16 hardware atomic --- example/ck_tile/18_flatmm/run_moe_flatmm_example.inc | 2 +- include/ck_tile/core/arch/generic_memory_space_atomic.hpp | 4 ++++ include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp | 8 ++++++-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index d898ed2f294..c92f90359d0 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -304,7 +304,7 @@ int run_moe_gemm_example_with_layouts(int argc, const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( + [[maybe_unused]] const auto rtol_atol = calculate_rtol_atol( K, 1 /*kbatch*/, max_accumulated_value); c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp index e56bcadcba2..0ff97bb9a79 100644 --- a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -102,6 +102,9 @@ CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x); template <> CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) { +#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN + __builtin_amdgcn_global_atomic_fadd_v2bf16(c_style_pointer_cast(p_dst), x); +#else union U32BF162_ADDR { uint32_t* u32_a; @@ -128,6 +131,7 @@ CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) new_v = new_.u32; cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); } while(cur_v.u32 != old_v); +#endif } template <> diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 8a9aa3cdd3a..ceaec5ff0e6 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -623,7 +623,7 @@ struct MoeFlatmmKernel { return make_naive_tensor_view( e_ptr, - make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken, + make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens, IsGateUp ? kargs.N / 2 : kargs.N), make_tuple(1, kargs.stride_C), number<1>{}, @@ -1250,6 +1250,8 @@ struct MoeFlatmmKernel constexpr int MPerThread = TileEncodingPattern::Y2; statically_indexed_array, NumMEpiTile> c_scatter_offsets; + statically_indexed_array, NumMEpiTile> + c_scatter_valids; auto c_coord = dram_tile_distribution.calculate_index(); static_for<0, NumMEpiTile, 1>{}([&](auto mIter) { static_for<0, MPerThread, 1>{}([&](auto m0) { @@ -1262,6 +1264,7 @@ struct MoeFlatmmKernel scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; + c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); }); }); @@ -1302,7 +1305,8 @@ struct MoeFlatmmKernel c_block_window.get_window_lengths(), c_block_window.get_window_origin(), dram_tile_distribution, - c_scatter_offsets[mIter]); + c_scatter_offsets[mIter], + c_scatter_valids[mIter]); if constexpr(!IsInputGemm || EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add) From 5b4ee5518571c0775cb87b7001b9fac3082c2487 Mon Sep 17 00:00:00 2001 From: mohsen saffari Date: Wed, 19 Nov 2025 15:32:39 +0100 Subject: [PATCH 2/5] correct clang-format --- example/ck_tile/18_flatmm/run_moe_flatmm_example.inc | 5 +++-- include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index c92f90359d0..f85b051da52 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -304,8 +304,9 @@ int run_moe_gemm_example_with_layouts(int argc, const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - [[maybe_unused]] const auto rtol_atol = calculate_rtol_atol( - K, 1 /*kbatch*/, max_accumulated_value); + [[maybe_unused]] const auto rtol_atol = + calculate_rtol_atol( + K, 1 /*kbatch*/, max_accumulated_value); c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); const float rtol = std::is_same_v && IsInputGemm ? 1e-3 : 1e-2; diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index ceaec5ff0e6..fb98a71b0f4 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1264,7 +1264,7 @@ struct MoeFlatmmKernel scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; - c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); + c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); }); }); From 8d56b0af6fa7ad4a10028a79028225698b52bb96 Mon Sep 17 00:00:00 2001 From: mohsen saffari Date: Fri, 21 Nov 2025 14:36:24 +0100 Subject: [PATCH 3/5] removed unused rtol_atol variable from example code --- example/ck_tile/18_flatmm/run_moe_flatmm_example.inc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index f85b051da52..ef6a6dba90b 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -304,9 +304,7 @@ int run_moe_gemm_example_with_layouts(int argc, const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - [[maybe_unused]] const auto rtol_atol = - calculate_rtol_atol( - K, 1 /*kbatch*/, max_accumulated_value); + c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); const float rtol = std::is_same_v && IsInputGemm ? 1e-3 : 1e-2; From 122a981483357552870301581f6637c93be48538 Mon Sep 17 00:00:00 2001 From: mohsen saffari Date: Fri, 21 Nov 2025 14:39:32 +0100 Subject: [PATCH 4/5] clang format correction --- example/ck_tile/18_flatmm/run_moe_flatmm_example.inc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index ef6a6dba90b..b4aeb9c59b1 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -304,7 +304,7 @@ int run_moe_gemm_example_with_layouts(int argc, const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - + c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); const float rtol = std::is_same_v && IsInputGemm ? 1e-3 : 1e-2; From caf82d1a4fb5d1f01a474437fe4cd8bf9a17024f Mon Sep 17 00:00:00 2001 From: mohsen saffari Date: Fri, 21 Nov 2025 14:51:39 +0100 Subject: [PATCH 5/5] remove unused varable max_accumulated_value from example --- example/ck_tile/18_flatmm/run_moe_flatmm_example.inc | 3 --- 1 file changed, 3 deletions(-) diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index b4aeb9c59b1..4303acec5a0 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -302,9 +302,6 @@ int run_moe_gemm_example_with_layouts(int argc, static_cast(per_token_scale_dev_buf.GetDeviceBuffer()), static_cast(per_channel_scale_dev_buf.GetDeviceBuffer())); - const float max_accumulated_value = - *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); const float rtol = std::is_same_v && IsInputGemm ? 1e-3 : 1e-2;