From d0bb1fa357cb4d3e475481e6bfd62380f409d47f Mon Sep 17 00:00:00 2001 From: Emily Martins Date: Wed, 21 Jan 2026 18:03:00 +0000 Subject: [PATCH 1/2] Fix alignment issue in Stream-K workspace buffer In CK Tile Stream-K, the workspace buffer is used to hold flags and partials, where the first i bytes hold the flags and the remaining bytes hold partials. This change adds padding to the flags prefix of the workspace buffer to ensure the number of bytes is 128B-aligned. Without this alignment, since workgroups do not skip cache when reading from partials, they may read stale partials data from the cache, leading to incorrect results. The added padding prevents these stale reads. This change also re-enables the test_ck_tile_streamk_reduction tests. --- .../streamk_gemm_tile_partitioner.hpp | 3 +- .../streamk_gemm_tile_partitioner_impl.hpp | 5 +- test/ck_tile/gemm_streamk/CMakeLists.txt | 7 ++- .../test_streamk_tile_partitioner.cpp | 37 ++++++++++++++- .../test_streamk_tile_partitioner_common.hpp | 47 +++++++++++++++++-- 5 files changed, 89 insertions(+), 10 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp index 0b0f6c18ef2..f028ba0c626 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp @@ -42,7 +42,8 @@ struct StreamKTilePartitionerBase CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept; /** - * @brief Calculates the total space needed for the flags buffer. + * @brief Calculates the total space needed for the flags buffer whose total byte size is + * 128B-aligned. * * @return index_t The number of bytes needed for the flags buffer. 
*/ diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp index 1764a1ce838..f80eec844cc 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp @@ -58,7 +58,10 @@ CK_TILE_HOST_DEVICE index_t StreamKTilePartitionerBase::get_flags_buffer_size() const noexcept { - return sizeof(index_t) * sk_ctas_; + constexpr index_t alignment = 128; + const index_t required_bytes = sizeof(index_t) * sk_ctas_; + const index_t padded_bytes = ck_tile::integer_least_multiple(required_bytes, alignment); + return padded_bytes; } template diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index 6aaa145c7d5..1390e5ee07f 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -23,10 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") #TODO: support all arches #TODO: current c-shuffle only supports C layout as R add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp) - # TODO: Renable once transient bug for reduction is resolved. 
- # add_gtest_executable(test_ck_tile_streamk_reduction - # ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp - # test_gemm_streamk_util.cpp) + add_gtest_executable(test_ck_tile_streamk_reduction + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp + test_gemm_streamk_util.cpp) add_gtest_executable(test_ck_tile_streamk_smoke ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp index 637f71c04fa..30b1b878c5d 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -51,6 +51,39 @@ TEST(StreamKTilePartitionerBaseConstructor, EdgeCase) validate_streamk_base_constructor(expected_values, tile_partitioner); } +TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsLessThan128Bytes) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitionerBase + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128); +} + +TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsEqual128Bytes) +{ + using Config = StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes; + + ck_tile::StreamKTilePartitionerBase + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128); +} + +TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsGreaterThan128Bytes) +{ + using Config = StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes; + + ck_tile::StreamKTilePartitionerBase + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 256); +} + 
TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy) { using Config = StreamKTilePartitionerBaseConfigDP2TileSK; @@ -71,7 +104,9 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy) ck_tile::index_t expected_partials_size = sizeof(float) * Config::M_TILE * Config::N_TILE * Config::GRID; - ck_tile::index_t expected_flags_size = sizeof(ck_tile::index_t) * Config::GRID; + // Since GRID is 3, the final padded flags array must be 128B to ensure the total byte size of + // the flags array is 128B-aligned. + ck_tile::index_t expected_flags_size = 128; EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)), expected_partials_size + expected_flags_size); diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp index 3daec049a77..31217ba1014 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp @@ -198,9 +198,11 @@ struct StreamKTilePartitionerBaseConfig struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitionerBaseConfig { - static constexpr ck_tile::index_t M = 28; - static constexpr ck_tile::index_t N = 4; - static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t M = 28; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 16; + // The minimum number of bytes needed for the flags array is GRID * 4B = 3 * 4B = 12B. To ensure + // the total byte size of the array is 128B-aligned, the flags array must be 128B. 
static constexpr ck_tile::index_t GRID = 3; static constexpr ck_tile::index_t M_TILE = 4; @@ -212,6 +214,45 @@ struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitioner ck_tile::sequence>; }; +struct StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes + : public StreamKTilePartitionerBaseConfig +{ + static constexpr ck_tile::index_t M = 28; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 32; + // The minimum number of bytes needed for the flags array is GRID * 4B = 32 * 4B = 128B. So, the + // number of bytes for the flags array should be 128B. + static constexpr ck_tile::index_t GRID = 32; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 4; + static constexpr ck_tile::index_t K_TILE = 1; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; +}; + +struct StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes + : public StreamKTilePartitionerBaseConfig +{ + static constexpr ck_tile::index_t M = 28; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 33; + // The minimum number of bytes needed for the flags array is GRID * 4B = 33 * 4B = 132B. So, the + // number of bytes for the flags array should be 2 * 128B = 256B to ensure the total byte size + // of the array is 128B-aligned. 
+ static constexpr ck_tile::index_t GRID = 33; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 4; + static constexpr ck_tile::index_t K_TILE = 1; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; +}; + struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile : public StreamKTilePartitionerBaseConfig { From 309c2539895f29606bd840d930c780a5ec1b96c2 Mon Sep 17 00:00:00 2001 From: Emily Martins Date: Thu, 22 Jan 2026 23:27:40 +0000 Subject: [PATCH 2/2] Compute reference GEMM on GPU for test verification to decrease testing time --- .../gemm_streamk/test_gemm_streamk_util.hpp | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp index 237dc24c3bd..96f90a5c2d5 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -262,20 +262,40 @@ class TestCkTileStreamK : public ::testing::Test c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - ck_tile::HostTensor c_m_n_host_ref( + // Calculate reference GEMM on the GPU + ck_tile::HostTensor c_m_n_dev_ref( f_host_tensor_descriptor(M, N, stride_C, CLayout{})); - c_m_n_host_ref.SetZero(); - - ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_host_ref); + ck_tile::DeviceMem ref_c_m_n_dev_buf(c_m_n_dev_ref.get_element_space_size_in_bytes()); + ref_c_m_n_dev_buf.SetZero(); + + ADataType* a_m_k_dev_ref_ptr = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* b_k_n_dev_ref_ptr = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* c_m_n_dev_ref_ptr = static_cast(ref_c_m_n_dev_buf.GetDeviceBuffer()); + ck_tile::reference_gemm_gpu(a_m_k_dev_ref_ptr, + b_k_n_dev_ref_ptr, + c_m_n_dev_ref_ptr, + M, + N, + K, + stride_A, + stride_B, + stride_C); + ref_c_m_n_dev_buf.FromDevice(c_m_n_dev_ref.data()); const float max_accumulated_value = 
- *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + *std::max_element(c_m_n_dev_ref.mData.begin(), c_m_n_dev_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( K, num_accumulations_per_tile, max_accumulated_value); bool pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_host_ref, + c_m_n_dev_ref, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{}));