From 972f10873f1c142076300f7e88a9671d5c253538 Mon Sep 17 00:00:00 2001 From: Kumar Date: Tue, 18 Nov 2025 15:31:23 +0530 Subject: [PATCH 01/20] Part1. Added helper files for enabling persistent async --- example/ck_tile/03_gemm/CMakeLists.txt | 2 + .../ck_tile/03_gemm/gemm_persistent_async.cpp | 220 +++++++++++++++ .../03_gemm/gemm_persistent_async_invoker.hpp | 256 ++++++++++++++++++ example/ck_tile/03_gemm/gemm_utils.hpp | 141 ++++++++++ .../03_gemm/persistent_async_scheduler.hpp | 243 +++++++++++++++++ example/ck_tile/03_gemm/run_gemm_example.inc | 7 - 6 files changed, 862 insertions(+), 7 deletions(-) create mode 100644 example/ck_tile/03_gemm/gemm_persistent_async.cpp create mode 100644 example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp create mode 100644 example/ck_tile/03_gemm/persistent_async_scheduler.hpp diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index d2112a67bf5..2cc6b227e2d 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -3,6 +3,7 @@ add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) add_executable(tile_example_gemm_reduce EXCLUDE_FROM_ALL gemm_splitk_two_stage_reduce.cpp) add_executable(tile_example_gemm_splitk_two_stage EXCLUDE_FROM_ALL gemm_splitk_two_stage.cpp) +add_executable(tile_example_gemm_persistent_async EXCLUDE_FROM_ALL gemm_persistent_async.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) @@ -18,3 +19,4 @@ target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPIL target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_splitk_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_gemm_persistent_async PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/gemm_persistent_async.cpp b/example/ck_tile/03_gemm/gemm_persistent_async.cpp new file mode 100644 index 00000000000..f629d5b880d --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_persistent_async.cpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +/** + * @file gemm_persistent_async.cpp + * @brief Example demonstrating persistent GEMM with async input readiness + * + * This example shows how to use the PersistentAsyncScheduler for GEMM operations + * where input data becomes ready asynchronously in chunks. This is particularly + * useful in distributed computing scenarios where data arrives incrementally. + * + * Features demonstrated: + * - Chunk-based async input signaling + * - Producer-consumer synchronization + * - Pivot-based tile traversal for hotspot spreading + * - Persistent kernel execution + */ + +#include "gemm_utils.hpp" +#include "run_gemm_example.inc" +#include "run_gemm_example_common.hpp" +#include "gemm_persistent_async_invoker.hpp" +#include "persistent_async_scheduler.hpp" +#include "ck_tile/core/utility/gemm_validation.hpp" +#include + +/** + * @brief Helper to allocate and initialize chunk signals + * + * @param num_chunks Number of chunks to allocate signals for + * @param stream HIP stream for async operations + * @return Device pointer to chunk signals array + */ +static uint32_t* allocate_chunk_signals(int num_chunks, hipStream_t stream) +{ + uint32_t* signals_device = nullptr; + + // Allocate device memory for signals + ck_tile::hip_check_error(hipMalloc(&signals_device, num_chunks * sizeof(uint32_t))); + + // Initialize all signals to 0 (not ready) + ck_tile::hip_check_error( + hipMemsetAsync(signals_device, 0, num_chunks * sizeof(uint32_t), stream)); + + return signals_device; +} + +/** + * @brief Helper to signal chunk readiness + * + * @param signals Device pointer to signals array + * @param chunk_idx Index of chunk to signal + * @param stream HIP stream for async operations + */ +static void signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t stream) +{ + uint32_t ready = 1; + ck_tile::hip_check_error(hipMemcpyAsync( + &signals[chunk_idx], &ready, sizeof(uint32_t), hipMemcpyHostToDevice, stream)); +} + +/** + * @brief Simulate async data arrival by signaling chunks progressively + * + * In a real application, this would be triggered by actual data arrival events + * (e.g., network communication, file I/O, etc.) + */ +static void +simulate_async_data_arrival(uint32_t* signals, int num_chunks, hipStream_t stream, int delay_ms = 1) +{ + // Signal chunks one by one with a small delay + // In practice, this would be event-driven based on actual data availability + for(int i = 0; i < num_chunks; ++i) + { + // Simulate delay in data arrival + if(delay_ms > 0 && i > 0) + { + std::this_thread::sleep_for(std::chrono::milliseconds(delay_ms)); + } + + signal_chunk_ready(signals, i, stream); + } +} + +int run_gemm_example(ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + std::string c_layout = arg_parser.get_str("c_layout"); + + std::tuple gemm_sizes = + parse_gemm_size(arg_parser); + + int m = std::get<0>(gemm_sizes); + int n = std::get<1>(gemm_sizes); + int k = std::get<2>(gemm_sizes); + + int stride_a = arg_parser.get_int("stride_a"); + int stride_b = arg_parser.get_int("stride_b"); + int stride_c = arg_parser.get_int("stride_c"); + + // Async-specific parameters + int tiles_per_chunk_m = arg_parser.get_int("tiles_per_chunk_m"); + int tile_idx_pivot_m = arg_parser.get_int("tile_idx_pivot_m"); + bool enable_async = arg_parser.get_int("enable_async") != 0; + + // using GemmConfig = GemmConfigMemoryInterwave; + using DefaultGemmConfig = GemmConfigMemoryInterwave; + using Invoker = PersistentAsyncInvoker; + + ck_tile::validate_gemm_stride( + a_layout, b_layout, c_layout, m, n, k, stride_a, stride_b, stride_c); + + std::cout << "=== Persistent Async GEMM Example ===\n"; + std::cout << "Matrix dimensions: M=" << m << ", N=" << n << ", K=" << k << '\n'; + std::cout << "Async parameters:\n"; + std::cout << " tiles_per_chunk_m: " << tiles_per_chunk_m << '\n'; + std::cout << " tile_idx_pivot_m: " << tile_idx_pivot_m << '\n'; + std::cout << " async_enabled: " << (enable_async ? "yes" : "no") << '\n'; + std::cout << "====================================\n\n"; + + // Calculate number of chunks + const int tiles_m = (m + DefaultGemmConfig::M_Tile - 1) / DefaultGemmConfig::M_Tile; + const int num_chunks = (tiles_m + tiles_per_chunk_m - 1) / tiles_per_chunk_m; + + if(tiles_m % tiles_per_chunk_m != 0) + { + std::cerr << "Warning: tiles_per_chunk_m (" << tiles_per_chunk_m + << ") does not evenly divide total M tiles (" << tiles_m << ")\n"; + } + + // Allocate and initialize chunk signals + uint32_t* chunk_signals_device = nullptr; + if(enable_async) + { + chunk_signals_device = allocate_chunk_signals(num_chunks, hipStreamDefault); + std::cout << "Allocated " << num_chunks << " chunk signals\n"; + } + + // Create async args + ck_tile::PersistentAsyncArgs async_args( + enable_async ? tiles_per_chunk_m : 0, chunk_signals_device, tile_idx_pivot_m); + + // Launch async data arrival simulation in background thread if enabled + std::thread data_arrival_thread; + if(enable_async) + { + data_arrival_thread = std::thread([&]() { + // Small delay before starting to simulate initial data latency + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + simulate_async_data_arrival(chunk_signals_device, num_chunks, hipStreamDefault, 5); + }); + } + + int result = 0; + try + { + if(data_type == "fp16") + { + using GemmConfig = GemmConfigMemoryInterwave; + result = + run_gemm_example_prec_type_persistent_async( + a_layout, b_layout, arg_parser, async_args); + } + else if(data_type == "bf16") + { + using GemmConfig = GemmConfigMemoryInterwave; + result = + run_gemm_example_prec_type_persistent_async( + a_layout, b_layout, arg_parser, async_args); + } + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << std::endl; + result = -1; + } + + // Wait for data arrival thread to complete + if(data_arrival_thread.joinable()) + { + data_arrival_thread.join(); + } + + // Clean up + if(chunk_signals_device) + { + ck_tile::hip_check_error(hipFree(chunk_signals_device)); + } + + return result; +} + +int main(int argc, char* argv[]) +{ + auto arg_parser = create_args(); + + // Add async-specific arguments + arg_parser.insert( + "tiles_per_chunk_m", "1", "Number of M tiles per chunk (granularity of async readiness)"); + arg_parser.insert( + "tile_idx_pivot_m", "0", "Pivot offset for M dimension (for hotspot spreading)"); + arg_parser.insert("enable_async", "1", "Enable async input signaling (0=disabled, 1=enabled)"); + + auto result = arg_parser.parse(argc, argv); + + if(!result) + return -1; + + try + { + return !run_gemm_example(arg_parser); + } + catch(const std::exception& e) + { + std::cerr << "Fatal error: " << e.what() << std::endl; + return -1; + } +} diff --git a/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp b/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp new file mode 100644 index 00000000000..7ee8bd2ed11 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "gemm_utils.hpp" +#include "persistent_async_scheduler.hpp" + +/** + * @brief Invoker for Persistent Async GEMM + * + * This invoker implements persistent GEMM with asynchronous input readiness. + * It extends the standard GEMM with support for: + * - Chunk-based async input signaling + * - Producer-consumer synchronization + * - Pivot-based tile traversal + */ +struct PersistentAsyncInvoker +{ + template + static float gemm(const ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s, + const ck_tile::PersistentAsyncArgs& async_args) + { + static_assert(Persistent, "PersistentAsyncInvoker requires persistent kernel mode"); + + // Tile configuration + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GemmKernel; + + // Create standard kernel args + auto kargs = Kernel::MakeKernelArgs(args); + + // Use max occupancy grid for persistent kernel + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error( + "Wrong! Arguments not supported for persistent async GEMM!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching Persistent Async GEMM kernel:\n" + << " Kernel: " << Kernel::GetName() << '\n' + << " Shape: " << GemmShape::GetName() << '\n' + << " Problem: " << UniversalGemmProblem::GetName() << '\n' + << " Pipeline: " << GemmPipeline::GetName() << '\n' + << " Grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}\n" + << " Blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}\n" + << " Async Args:\n" + << " tiles_per_chunk_m: " << async_args.tiles_per_chunk_m << '\n' + << " tile_idx_pivot_m: " << async_args.tile_idx_pivot_m << '\n' + << " chunk_signals: " + << (async_args.chunk_signals ? "enabled" : "disabled") << std::endl; + } + + // Validation: tiles_per_chunk_m must divide tiles_m evenly + ck_tile::index_t tiles_m = (args.M + GemmConfig::M_Tile - 1) / GemmConfig::M_Tile; + if(async_args.tiles_per_chunk_m > 0 && tiles_m % async_args.tiles_per_chunk_m != 0) + { + throw std::runtime_error("tiles_per_chunk_m must divide total M tiles evenly!"); + } + + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + + // Prepare preprocessing + std::function preprocess = clear_gemm_output; + + /* + // Custom kernel wrapper that includes async scheduler + + ck_tile::index_t tiles_n; + + ck_tile::index_t grid_size; + auto persistent_async_kernel = [&](auto... kernel_args) { + // Get tiles info + tiles_m = + (args.M + GemmConfig::M_Tile - 1) / GemmConfig::M_Tile; + tiles_n = + (args.N + GemmConfig::N_Tile - 1) / GemmConfig::N_Tile; + grid_size = grids.x * grids.y; + + // Create persistent async scheduler + ck_tile::PersistentAsyncScheduler persistent_scheduler( + async_args, tiles_m, tiles_n, grid_size); + + // Persistent tile loop + while(true) + { + auto work_tile = persistent_scheduler.GetNextWorkTile(); + if(!work_tile.IsValid()) + break; + + // Execute GEMM for this tile + // This would call the actual kernel implementation + Kernel{}(kernel_args...); + + // Fence before next iteration + scheduler.IterationBoundaryFence(); + + // Advance to next tile + scheduler.AdvanceToNextTile(); + } + }; + */ + + // Note: The PersistentAsyncScheduler is integrated into the kernel itself + // (device-side), not managed from the host. For full async support, a custom kernel + // implementation would be needed that integrates PersistentAsyncScheduler in its tile + // loop. + // + // TODO: Integrate async_args into kernel arguments and modify the kernel implementation + // to use PersistentAsyncScheduler for work distribution with async signaling. + // For now, this launches the standard persistent kernel without async signaling. + + // TODO: Integrate async scheduler into the kernel + // The async_args parameter is currently not used by the kernel launch. + // To fully implement async input scheduling, we need to: + // 1. Create a custom kernel that extends GemmKernel + // 2. Pass async_args through kernel arguments (KernelArgs) + // 3. Integrate PersistentAsyncScheduler::GetNextWorkTile() into the + // persistent tile loop inside the kernel's operator() + // 4. Call wait_signal() for chunk readiness before processing tiles + // + // For now, suppress unused variable warning + + (void)async_args; + + ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); + } + else + { + return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + } + }; + + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + } +}; diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 6d833fbd7a5..085a620fb43 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -562,3 +562,144 @@ template float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); + +// Internal implementation with compile-time layouts +template +int run_gemm_example_prec_type_persistent_async_impl(ck_tile::ArgParser& arg_parser, + const ck_tile::PersistentAsyncArgs& async_args, + const ALayout a_layout, + const BLayout b_layout, + const CLayout c_layout) +{ + using AccDataType = float; + + std::tuple gemm_sizes = + parse_gemm_size(arg_parser); + + ck_tile::GemmHostArgs args; + args.M = std::get<0>(gemm_sizes); + args.N = std::get<1>(gemm_sizes); + args.K = std::get<2>(gemm_sizes); + + args.k_batch = arg_parser.get_int("split_k"); + + // Get strides from arguments + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + // Apply default strides + args.stride_A = ck_tile::get_default_stride(args.M, args.K, stride_A, is_row_major(a_layout)); + args.stride_B = ck_tile::get_default_stride(args.K, args.N, stride_B, is_row_major(b_layout)); + args.stride_C = ck_tile::get_default_stride(args.M, args.N, stride_C, is_row_major(c_layout)); + + // Prepare host tensors + ck_tile::HostTensor a_m( + ck_tile::host_tensor_descriptor(args.M, args.K, args.stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_n( + ck_tile::host_tensor_descriptor(args.K, args.N, args.stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_host( + ck_tile::host_tensor_descriptor(args.M, args.N, args.stride_C, is_row_major(c_layout))); + ck_tile::HostTensor c_m_device( + ck_tile::host_tensor_descriptor(args.M, args.N, args.stride_C, is_row_major(c_layout))); + + // Initialize tensors + int init_method = arg_parser.get_int("init"); + switch(init_method) + { + case 0: ck_tile::FillUniformDistribution{0.f, 1.f}(a_m); break; + case 1: ck_tile::FillConstant{static_cast(1)}(a_m); break; + } + switch(init_method) + { + case 0: ck_tile::FillUniformDistribution{0.f, 1.f}(b_n); break; + case 1: ck_tile::FillConstant{static_cast(1.f)}(b_n); break; + } + + // Allocate device memory + ck_tile::DeviceMem a_m_dev(a_m.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_n_dev(b_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_dev(c_m_device.get_element_space_size_in_bytes()); + + a_m_dev.ToDevice(a_m.data()); + b_n_dev.ToDevice(b_n.data()); + + args.a_ptr = a_m_dev.GetDeviceBuffer(); + args.b_ptr = b_n_dev.GetDeviceBuffer(); + args.c_ptr = c_m_dev.GetDeviceBuffer(); + + // Setup stream config - extract parameters from arg_parser + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + bool flush_cache = arg_parser.get_bool("flush_cache"); + int rotating_count = arg_parser.get_int("rotating_count"); + + ck_tile::stream_config stream{ + nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count}; + + // Run persistent async GEMM + constexpr bool kPersistent = true; + using ElementWise = ck_tile::element_wise::PassThrough; + float ave_time = Invoker::template gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout, + kPersistent, + ElementWise>(args, stream, async_args); + + if(stream.log_level_ > 0) + { + std::cout << "Persistent Async GEMM completed in " << ave_time << " ms" << std::endl; + } + + // Validation + int validation_mode = arg_parser.get_int("v"); + if(validation_mode > 0) + { + c_m_dev.FromDevice(c_m_device.data()); + + auto err = ck_tile::get_relative_threshold(); + + return err > 0 ? 0 : -1; + } + + return 0; +} + +// Public wrapper that accepts string layouts and dispatches to appropriate implementation +template +int run_gemm_example_prec_type_persistent_async(const std::string& a_layout, + const std::string& b_layout, + ck_tile::ArgParser& arg_parser, + const ck_tile::PersistentAsyncArgs& async_args) +{ + (void)a_layout; + (void)b_layout; + return run_gemm_example_prec_type_persistent_async_impl( + arg_parser, + async_args, + ck_tile::tensor_layout::gemm::RowMajor{}, + ck_tile::tensor_layout::gemm::RowMajor{}, + ck_tile::tensor_layout::gemm::RowMajor{}); +} diff --git a/example/ck_tile/03_gemm/persistent_async_scheduler.hpp b/example/ck_tile/03_gemm/persistent_async_scheduler.hpp new file mode 100644 index 00000000000..5a48037d1b3 --- /dev/null +++ b/example/ck_tile/03_gemm/persistent_async_scheduler.hpp @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +/** + * @file persistent_async_scheduler.hpp + * @brief HIP-based Persistent Async Input Scheduler for CK Tile GEMM + * + * This file implements a persistent async input scheduler similar to CUTLASS's + * PersistentAsyncInputScheduler, adapted for AMD GPUs using HIP. + * + * Features: + * - tiles_per_chunk_m: Granularity at which data becomes ready + * - chunk_signals: Global memory flags indicating chunk readiness + * - tile_idx_pivot_m: Post-swizzle pivot to spread hotspots + * - Producer-consumer wait mechanism for async data readiness + */ + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +namespace ck_tile { + +/** + * @brief Wait for signal to become ready using HIP atomic operations + * + * @param addr Pointer to signal in global memory + * + * This function implements a busy-wait on a global memory flag using + * volatile loads with acquire semantics for AMD GPUs. + */ +__device__ void wait_signal(uint32_t* addr) +{ + uint32_t ready = __atomic_load_n(addr, __ATOMIC_ACQUIRE); + while(!ready) + { + // Use volatile load to prevent compiler optimization + asm volatile("flat_load_dword %0, %1 glc\n" + "s_waitcnt vmcnt(0)" + : "=v"(ready) + : "v"(addr) + : "memory"); + + // Add a small delay to reduce memory traffic + __builtin_amdgcn_s_sleep(1); + + ready = __atomic_load_n(addr, __ATOMIC_ACQUIRE); + } +} + +/** + * @brief Iteration boundary fence for persistent kernels + * + * Ensures all memory operations complete before moving to next tile. + * Required for safe LDS reuse in persistent loops. + */ +__device__ void iteration_boundary_fence() +{ + // Wait for all vector memory operations + __builtin_amdgcn_s_waitcnt(/*vmcnt*/ 0 | (/*lgkmcnt*/ 0 << 8)); + // Workgroup barrier + __builtin_amdgcn_s_barrier(); +} + +/** + * @brief Arguments for persistent async input scheduler + * + * This structure extends the standard GEMM arguments with async-specific + * parameters for controlling data readiness signaling and tile traversal. + */ +struct PersistentAsyncArgs +{ + /// @brief Number of M tiles in each chunk (granularity of data readiness) + index_t tiles_per_chunk_m = 0; + + /// @brief Pointer to chunk readiness signals in global memory + /// chunk_signals[i] == 1 indicates chunk i is ready for processing + uint32_t* chunk_signals = nullptr; + + /// @brief Pivot offset for M dimension after swizzling + /// Allows different ranks to process different M tiles simultaneously + index_t tile_idx_pivot_m = 0; + + CK_TILE_HOST PersistentAsyncArgs() = default; + + CK_TILE_HOST PersistentAsyncArgs(index_t tiles_per_chunk_m_, + uint32_t* chunk_signals_, + index_t tile_idx_pivot_m_) + : tiles_per_chunk_m(tiles_per_chunk_m_), + chunk_signals(chunk_signals_), + tile_idx_pivot_m(tile_idx_pivot_m_) + { + } +}; + +/** + * @brief Persistent Async Tile Scheduler Implementation + * + * This scheduler manages work distribution for persistent GEMM kernels + * with asynchronous input readiness. It extends the basic persistent + * scheduler with: + * - Chunk-based data readiness signaling + * - Producer-consumer synchronization + * - Pivot-based tile traversal for hotspot spreading + * + * @tparam TilePartitioner_ The tile partitioner type + */ +template +struct PersistentAsyncScheduler +{ + using TilePartitioner = remove_cvref_t; + + struct WorkTileInfo + { + index_t tile_idx_m; + index_t tile_idx_n; + index_t batch_idx; + bool is_valid; + + CK_TILE_DEVICE bool IsValid() const { return is_valid; } + + CK_TILE_DEVICE static WorkTileInfo InvalidTile() { return WorkTileInfo{-1, -1, -1, false}; } + }; + + struct SchedulerState + { + index_t current_tile_linear; + index_t total_tiles_m; + index_t total_tiles_n; + index_t total_grid_size; + bool is_mainloop_producer; + + // Async-specific state + index_t tiles_per_chunk_m; + uint32_t* chunk_signals; + index_t tile_idx_pivot_m; + }; + + CK_TILE_DEVICE PersistentAsyncScheduler(const PersistentAsyncArgs& async_args, + index_t tiles_m, + index_t tiles_n, + index_t grid_size) + { + state_.current_tile_linear = blockIdx.x + blockIdx.y * gridDim.x; + state_.total_tiles_m = tiles_m; + state_.total_tiles_n = tiles_n; + state_.total_grid_size = grid_size; + + // Determine if this wave is a mainloop producer + // Only the first wave in the first wave group is the producer + const index_t warp_id = threadIdx.x / warpSize; + const index_t wave_group_id = warp_id / 4; // 4 waves per wave group + state_.is_mainloop_producer = (wave_group_id == 0) && (warp_id % 4 == 0); + + // Async parameters + state_.tiles_per_chunk_m = async_args.tiles_per_chunk_m; + state_.chunk_signals = async_args.chunk_signals; + state_.tile_idx_pivot_m = async_args.tile_idx_pivot_m; + } + + /** + * @brief Get the next work tile for this workgroup + * + * This function: + * 1. Calculates the tile indices from linear work index + * 2. Applies pivot to M dimension for hotspot spreading + * 3. Waits for chunk signal if async mode is enabled + * 4. Returns tile info + */ + CK_TILE_DEVICE WorkTileInfo GetNextWorkTile() + { + const index_t linear_idx = state_.current_tile_linear; + + // Check if we've processed all tiles + const index_t total_tiles = state_.total_tiles_m * state_.total_tiles_n; + if(linear_idx >= total_tiles) + { + return WorkTileInfo::InvalidTile(); + } + + // Map linear index to 2D tile coordinates + // Using row-major traversal (can be extended for different patterns) + index_t tile_m = linear_idx / state_.total_tiles_n; + index_t tile_n = linear_idx % state_.total_tiles_n; + + // Apply pivot to M dimension after basic mapping + if(state_.tile_idx_pivot_m > 0) + { + tile_m = (tile_m + state_.tile_idx_pivot_m) % state_.total_tiles_m; + } + + // Wait for async input readiness if enabled + if(state_.chunk_signals != nullptr && state_.tiles_per_chunk_m > 0) + { + const index_t chunk_idx = tile_m / state_.tiles_per_chunk_m; + + // Producer lane waits for signal + if(state_.is_mainloop_producer && threadIdx.x == 0) + { + wait_signal(state_.chunk_signals + chunk_idx); + } + + // Synchronize all threads in workgroup after producer receives signal + __builtin_amdgcn_s_barrier(); + } + + return WorkTileInfo{tile_m, tile_n, 0, true}; + } + + /** + * @brief Advance to next work tile in persistent loop + */ + CK_TILE_DEVICE void AdvanceToNextTile() + { + state_.current_tile_linear += state_.total_grid_size; + } + + /** + * @brief Check if this is the last tile for this workgroup + */ + CK_TILE_DEVICE bool IsLastTile() const + { + const index_t total_tiles = state_.total_tiles_m * state_.total_tiles_n; + return (state_.current_tile_linear + state_.total_grid_size) >= total_tiles; + } + + /** + * @brief Synchronization fence between persistent loop iterations + * + * Must be called before moving to next tile to ensure: + * - All memory operations complete + * - LDS can be safely reused + */ + CK_TILE_DEVICE void IterationBoundaryFence() { iteration_boundary_fence(); } + + CK_TILE_DEVICE const SchedulerState& GetState() const { return state_; } + + private: + SchedulerState state_; +}; + +} // namespace ck_tile diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 1c57a03c972..11aff7319ab 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -4,13 +4,6 @@ #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" -template -static constexpr inline auto is_row_major(Layout layout_) -{ - return ck_tile::bool_constant, - ck_tile::tensor_layout::gemm::RowMajor>>{}; -} - template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, From dc2c613a9cec565e385cb24ac16fee7f7011a732 Mon Sep 17 00:00:00 2001 From: Kumar Date: Thu, 20 Nov 2025 11:28:11 +0530 Subject: [PATCH 02/20] Refactor code + cleaning work --- .../ck_tile/03_gemm/gemm_persistent_async.cpp | 143 +----------- .../03_gemm/gemm_persistent_async_invoker.hpp | 196 ++++++++-------- example/ck_tile/03_gemm/gemm_utils.hpp | 141 ------------ .../03_gemm/persistent_async_scheduler.hpp | 212 +----------------- example/ck_tile/03_gemm/run_gemm_example.inc | 7 + 5 files changed, 124 insertions(+), 575 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_persistent_async.cpp b/example/ck_tile/03_gemm/gemm_persistent_async.cpp index f629d5b880d..a8d58850b10 100644 --- a/example/ck_tile/03_gemm/gemm_persistent_async.cpp +++ b/example/ck_tile/03_gemm/gemm_persistent_async.cpp @@ -59,139 +59,6 @@ static void signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t str &signals[chunk_idx], &ready, sizeof(uint32_t), hipMemcpyHostToDevice, stream)); } -/** - * @brief Simulate async data arrival by signaling chunks progressively - * - * In a real application, this would be triggered by actual data arrival events - * (e.g., network communication, file I/O, etc.) - */ -static void -simulate_async_data_arrival(uint32_t* signals, int num_chunks, hipStream_t stream, int delay_ms = 1) -{ - // Signal chunks one by one with a small delay - // In practice, this would be event-driven based on actual data availability - for(int i = 0; i < num_chunks; ++i) - { - // Simulate delay in data arrival - if(delay_ms > 0 && i > 0) - { - std::this_thread::sleep_for(std::chrono::milliseconds(delay_ms)); - } - - signal_chunk_ready(signals, i, stream); - } -} - -int run_gemm_example(ck_tile::ArgParser& arg_parser) -{ - std::string data_type = arg_parser.get_str("prec"); - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); - std::string c_layout = arg_parser.get_str("c_layout"); - - std::tuple gemm_sizes = - parse_gemm_size(arg_parser); - - int m = std::get<0>(gemm_sizes); - int n = std::get<1>(gemm_sizes); - int k = std::get<2>(gemm_sizes); - - int stride_a = arg_parser.get_int("stride_a"); - int stride_b = arg_parser.get_int("stride_b"); - int stride_c = arg_parser.get_int("stride_c"); - - // Async-specific parameters - int tiles_per_chunk_m = arg_parser.get_int("tiles_per_chunk_m"); - int tile_idx_pivot_m = arg_parser.get_int("tile_idx_pivot_m"); - bool enable_async = arg_parser.get_int("enable_async") != 0; - - // using GemmConfig = GemmConfigMemoryInterwave; - using DefaultGemmConfig = GemmConfigMemoryInterwave; - using Invoker = PersistentAsyncInvoker; - - ck_tile::validate_gemm_stride( - a_layout, b_layout, c_layout, m, n, k, stride_a, stride_b, stride_c); - - std::cout << "=== Persistent Async GEMM Example ===\n"; - std::cout << "Matrix dimensions: M=" << m << ", N=" << n << ", K=" << k << '\n'; - std::cout << "Async parameters:\n"; - std::cout << " tiles_per_chunk_m: " << tiles_per_chunk_m << '\n'; - std::cout << " tile_idx_pivot_m: " << tile_idx_pivot_m << '\n'; - std::cout << " async_enabled: " << (enable_async ? "yes" : "no") << '\n'; - std::cout << "====================================\n\n"; - - // Calculate number of chunks - const int tiles_m = (m + DefaultGemmConfig::M_Tile - 1) / DefaultGemmConfig::M_Tile; - const int num_chunks = (tiles_m + tiles_per_chunk_m - 1) / tiles_per_chunk_m; - - if(tiles_m % tiles_per_chunk_m != 0) - { - std::cerr << "Warning: tiles_per_chunk_m (" << tiles_per_chunk_m - << ") does not evenly divide total M tiles (" << tiles_m << ")\n"; - } - - // Allocate and initialize chunk signals - uint32_t* chunk_signals_device = nullptr; - if(enable_async) - { - chunk_signals_device = allocate_chunk_signals(num_chunks, hipStreamDefault); - std::cout << "Allocated " << num_chunks << " chunk signals\n"; - } - - // Create async args - ck_tile::PersistentAsyncArgs async_args( - enable_async ? tiles_per_chunk_m : 0, chunk_signals_device, tile_idx_pivot_m); - - // Launch async data arrival simulation in background thread if enabled - std::thread data_arrival_thread; - if(enable_async) - { - data_arrival_thread = std::thread([&]() { - // Small delay before starting to simulate initial data latency - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - simulate_async_data_arrival(chunk_signals_device, num_chunks, hipStreamDefault, 5); - }); - } - - int result = 0; - try - { - if(data_type == "fp16") - { - using GemmConfig = GemmConfigMemoryInterwave; - result = - run_gemm_example_prec_type_persistent_async( - a_layout, b_layout, arg_parser, async_args); - } - else if(data_type == "bf16") - { - using GemmConfig = GemmConfigMemoryInterwave; - result = - run_gemm_example_prec_type_persistent_async( - a_layout, b_layout, arg_parser, async_args); - } - } - catch(const std::exception& e) - { - std::cerr << "Error: " << e.what() << std::endl; - result = -1; - } - - // Wait for data arrival thread to complete - if(data_arrival_thread.joinable()) - { - data_arrival_thread.join(); - } - - // Clean up - if(chunk_signals_device) - { - ck_tile::hip_check_error(hipFree(chunk_signals_device)); - } - - return result; -} - int main(int argc, char* argv[]) { auto arg_parser = create_args(); @@ -208,13 +75,5 @@ int main(int argc, char* argv[]) if(!result) return -1; - try - { - return !run_gemm_example(arg_parser); - } - catch(const std::exception& e) - { - std::cerr << "Fatal error: " << e.what() << std::endl; - return -1; - } + return 0; } diff --git a/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp b/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp index 7ee8bd2ed11..b00d6965737 100644 --- a/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp @@ -15,7 +15,8 @@ * - Producer-consumer synchronization * - Pivot-based tile traversal */ -struct PersistentAsyncInvoker + +struct GemmPersistentAsyncInvoker { template , ck_tile::sequence, @@ -45,15 +46,15 @@ struct PersistentAsyncInvoker using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + GemmConfig::TilePartitionerGroupNum, + GemmConfig::TilePartitionerM01>; using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = @@ -81,7 +82,6 @@ struct PersistentAsyncInvoker const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; const auto Run = [&](const auto has_hot_loop_, @@ -104,14 +104,16 @@ struct PersistentAsyncInvoker using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; + using WorkspaceType = ck_tile::remove_cvref_t; + using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; + using GemmKernel = ck_tile::GemmKernel; - // Create standard kernel args - auto kargs = Kernel::MakeKernelArgs(args); + ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType)); + ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); + auto c_ptr = ws_args.c_ptr; + ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); + auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); - // Use max occupancy grid for persistent kernel - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) + : GemmKernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) + if(!GemmKernel::IsSupportedArgument(gemm_kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using BlockTile = ck_tile::sequence<2048>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; + + using ElementwiseShape = + ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + using ElementwiseKernel = + ck_tile::ElementWiseKernel; + + ck_tile::index_t total_elements = 1; + std::vector shape = {args.M, args.N}; + + for(auto d : shape) + total_elements *= d; + + const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = + (total_elements + elements_per_block - 1) / elements_per_block; + + auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.c_ptr)); + auto input_size = ck_tile::make_tuple(args.M, args.N); + + // Check if the kernel configuration is supported + if(!ElementwiseKernel::IsSupportedArgument(input_size)) { throw std::runtime_error( - "Wrong! Arguments not supported for persistent async GEMM!\n"); + "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); } if(s.log_level_ > 0) @@ -156,91 +198,67 @@ struct PersistentAsyncInvoker << (async_args.chunk_signals ? "enabled" : "disabled") << std::endl; } - // Validation: tiles_per_chunk_m must divide tiles_m evenly - ck_tile::index_t tiles_m = (args.M + GemmConfig::M_Tile - 1) / GemmConfig::M_Tile; - if(async_args.tiles_per_chunk_m > 0 && tiles_m % async_args.tiles_per_chunk_m != 0) - { - throw std::runtime_error("tiles_per_chunk_m must divide total M tiles evenly!"); - } + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; auto clear_gemm_output = [&]() { if(args.k_batch > 1) hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_)); }; - // Prepare preprocessing - std::function preprocess = clear_gemm_output; - - /* - // Custom kernel wrapper that includes async scheduler - - ck_tile::index_t tiles_n; - - ck_tile::index_t grid_size; - auto persistent_async_kernel = [&](auto... kernel_args) { - // Get tiles info - tiles_m = - (args.M + GemmConfig::M_Tile - 1) / GemmConfig::M_Tile; - tiles_n = - (args.N + GemmConfig::N_Tile - 1) / GemmConfig::N_Tile; - grid_size = grids.x * grids.y; - - // Create persistent async scheduler - ck_tile::PersistentAsyncScheduler persistent_scheduler( - async_args, tiles_m, tiles_n, grid_size); - - // Persistent tile loop - while(true) - { - auto work_tile = persistent_scheduler.GetNextWorkTile(); - if(!work_tile.IsValid()) - break; - - // Execute GEMM for this tile - // This would call the actual kernel implementation - Kernel{}(kernel_args...); - - // Fence before next iteration - scheduler.IterationBoundaryFence(); - - // Advance to next tile - scheduler.AdvanceToNextTile(); - } - }; - */ - - // Note: The PersistentAsyncScheduler is integrated into the kernel itself - // (device-side), not managed from the host. For full async support, a custom kernel - // implementation would be needed that integrates PersistentAsyncScheduler in its tile - // loop. - // - // TODO: Integrate async_args into kernel arguments and modify the kernel implementation - // to use PersistentAsyncScheduler for work distribution with async signaling. - // For now, this launches the standard persistent kernel without async signaling. - - // TODO: Integrate async scheduler into the kernel - // The async_args parameter is currently not used by the kernel launch. - // To fully implement async input scheduling, we need to: - // 1. Create a custom kernel that extends GemmKernel - // 2. Pass async_args through kernel arguments (KernelArgs) - // 3. Integrate PersistentAsyncScheduler::GetNextWorkTile() into the - // persistent tile loop inside the kernel's operator() - // 4. Call wait_signal() for chunk readiness before processing tiles - // - // For now, suppress unused variable warning - - (void)async_args; + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = + std::make_unique>( + gemm_kargs.as_ptr[0], + gemm_kargs.bs_ptr[0], + s.rotating_count_, + size_a_buffer, + size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; + } + else + { + preprocess = clear_gemm_output; + } ave_time = ck_tile::launch_kernel_time_mask( s, preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel( + GemmKernel{}, grids, blocks, 0, gemm_kargs), + ck_tile::make_kernel(ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(args.N, 1), // Input Stride + ck_tile::make_tuple(args.N, 1), // Output Stride + input_tensors, + static_cast(c_ptr))); return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + const auto RunSplitK = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); @@ -251,6 +269,6 @@ struct PersistentAsyncInvoker } }; - return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time = BaseGemmPipeline::TailHandler(RunSplitK, has_hot_loop, tail_num); } }; diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 085a620fb43..6d833fbd7a5 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -562,144 +562,3 @@ template float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); - -// Internal implementation with compile-time layouts -template -int run_gemm_example_prec_type_persistent_async_impl(ck_tile::ArgParser& arg_parser, - const ck_tile::PersistentAsyncArgs& async_args, - const ALayout a_layout, - const BLayout b_layout, - const CLayout c_layout) -{ - using AccDataType = float; - - std::tuple gemm_sizes = - parse_gemm_size(arg_parser); - - ck_tile::GemmHostArgs args; - args.M = std::get<0>(gemm_sizes); - args.N = std::get<1>(gemm_sizes); - args.K = std::get<2>(gemm_sizes); - - args.k_batch = arg_parser.get_int("split_k"); - - // Get strides from arguments - ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - - // Apply default strides - args.stride_A = ck_tile::get_default_stride(args.M, args.K, stride_A, is_row_major(a_layout)); - args.stride_B = ck_tile::get_default_stride(args.K, args.N, stride_B, is_row_major(b_layout)); - args.stride_C = ck_tile::get_default_stride(args.M, args.N, stride_C, is_row_major(c_layout)); - - // Prepare host tensors - ck_tile::HostTensor a_m( - ck_tile::host_tensor_descriptor(args.M, args.K, args.stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_n( - ck_tile::host_tensor_descriptor(args.K, args.N, args.stride_B, is_row_major(b_layout))); - ck_tile::HostTensor c_m_host( - ck_tile::host_tensor_descriptor(args.M, args.N, args.stride_C, is_row_major(c_layout))); - ck_tile::HostTensor c_m_device( - ck_tile::host_tensor_descriptor(args.M, args.N, args.stride_C, is_row_major(c_layout))); - - // Initialize tensors - int init_method = arg_parser.get_int("init"); - switch(init_method) - { - case 0: ck_tile::FillUniformDistribution{0.f, 1.f}(a_m); break; - case 1: ck_tile::FillConstant{static_cast(1)}(a_m); break; - } - switch(init_method) - { - case 0: ck_tile::FillUniformDistribution{0.f, 1.f}(b_n); break; - case 1: ck_tile::FillConstant{static_cast(1.f)}(b_n); break; - } - - // Allocate device memory - ck_tile::DeviceMem a_m_dev(a_m.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_n_dev(b_n.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_m_dev(c_m_device.get_element_space_size_in_bytes()); - - a_m_dev.ToDevice(a_m.data()); - b_n_dev.ToDevice(b_n.data()); - - args.a_ptr = a_m_dev.GetDeviceBuffer(); - args.b_ptr = b_n_dev.GetDeviceBuffer(); - args.c_ptr = c_m_dev.GetDeviceBuffer(); - - // Setup stream config - extract parameters from arg_parser - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); - bool flush_cache = arg_parser.get_bool("flush_cache"); - int rotating_count = arg_parser.get_int("rotating_count"); - - ck_tile::stream_config stream{ - nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count}; - - // Run persistent async GEMM - constexpr bool kPersistent = true; - using ElementWise = ck_tile::element_wise::PassThrough; - float ave_time = Invoker::template gemm, - AccDataType, - CDataType, - ALayout, - BLayout, - ck_tile::tuple<>, - CLayout, - kPersistent, - ElementWise>(args, stream, async_args); - - if(stream.log_level_ > 0) - { - std::cout << "Persistent Async GEMM completed in " << ave_time << " ms" << std::endl; - } - - // Validation - int validation_mode = arg_parser.get_int("v"); - if(validation_mode > 0) - { - c_m_dev.FromDevice(c_m_device.data()); - - auto err = ck_tile::get_relative_threshold(); - - return err > 0 ? 0 : -1; - } - - return 0; -} - -// Public wrapper that accepts string layouts and dispatches to appropriate implementation -template -int run_gemm_example_prec_type_persistent_async(const std::string& a_layout, - const std::string& b_layout, - ck_tile::ArgParser& arg_parser, - const ck_tile::PersistentAsyncArgs& async_args) -{ - (void)a_layout; - (void)b_layout; - return run_gemm_example_prec_type_persistent_async_impl( - arg_parser, - async_args, - ck_tile::tensor_layout::gemm::RowMajor{}, - ck_tile::tensor_layout::gemm::RowMajor{}, - ck_tile::tensor_layout::gemm::RowMajor{}); -} diff --git a/example/ck_tile/03_gemm/persistent_async_scheduler.hpp b/example/ck_tile/03_gemm/persistent_async_scheduler.hpp index 5a48037d1b3..82ea927ab52 100644 --- a/example/ck_tile/03_gemm/persistent_async_scheduler.hpp +++ b/example/ck_tile/03_gemm/persistent_async_scheduler.hpp @@ -1,20 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -/** - * @file persistent_async_scheduler.hpp - * @brief HIP-based Persistent Async Input Scheduler for CK Tile GEMM - * - * This file implements a persistent async input scheduler similar to CUTLASS's - * PersistentAsyncInputScheduler, adapted for AMD GPUs using HIP. - * - * Features: - * - tiles_per_chunk_m: Granularity at which data becomes ready - * - chunk_signals: Global memory flags indicating chunk readiness - * - tile_idx_pivot_m: Post-swizzle pivot to spread hotspots - * - Producer-consumer wait mechanism for async data readiness - */ - #pragma once #include "ck_tile/core.hpp" @@ -22,19 +8,12 @@ namespace ck_tile { -/** - * @brief Wait for signal to become ready using HIP atomic operations - * - * @param addr Pointer to signal in global memory - * - * This function implements a busy-wait on a global memory flag using - * volatile loads with acquire semantics for AMD GPUs. - */ -__device__ void wait_signal(uint32_t* addr) +__device__ wait_signal(uint32_t* signal_addr) { - uint32_t ready = __atomic_load_n(addr, __ATOMIC_ACQUIRE); + uint32_t ready = __atomic_load_n(signal_addr, __ATOMIC_ACQUIRE); while(!ready) { + // Use volatile load to prevent compiler optimization asm volatile("flat_load_dword %0, %1 glc\n" "s_waitcnt vmcnt(0)" @@ -45,199 +24,26 @@ __device__ void wait_signal(uint32_t* addr) // Add a small delay to reduce memory traffic __builtin_amdgcn_s_sleep(1); - ready = __atomic_load_n(addr, __ATOMIC_ACQUIRE); + ready = __atomic_load_n(signal_addr, __ATOMIC_ACQUIRE); } } -/** - * @brief Iteration boundary fence for persistent kernels - * - * Ensures all memory operations complete before moving to next tile. - * Required for safe LDS reuse in persistent loops. - */ -__device__ void iteration_boundary_fence() -{ - // Wait for all vector memory operations - __builtin_amdgcn_s_waitcnt(/*vmcnt*/ 0 | (/*lgkmcnt*/ 0 << 8)); - // Workgroup barrier - __builtin_amdgcn_s_barrier(); -} - -/** - * @brief Arguments for persistent async input scheduler - * - * This structure extends the standard GEMM arguments with async-specific - * parameters for controlling data readiness signaling and tile traversal. - */ struct PersistentAsyncArgs { - /// @brief Number of M tiles in each chunk (granularity of data readiness) index_t tiles_per_chunk_m = 0; - /// @brief Pointer to chunk readiness signals in global memory - /// chunk_signals[i] == 1 indicates chunk i is ready for processing uint32_t* chunk_signals = nullptr; - /// @brief Pivot offset for M dimension after swizzling - /// Allows different ranks to process different M tiles simultaneously index_t tile_idx_pivot_m = 0; - CK_TILE_HOST PersistentAsyncArgs() = default; - - CK_TILE_HOST PersistentAsyncArgs(index_t tiles_per_chunk_m_, - uint32_t* chunk_signals_, - index_t tile_idx_pivot_m_) + PersistentAsyncArgs(index_t tiles_per_chunk_m_, + uint32_t* chunk_signals_, + index_t tile_idx_pivot_m_, + bool enable_async_) : tiles_per_chunk_m(tiles_per_chunk_m_), chunk_signals(chunk_signals_), tile_idx_pivot_m(tile_idx_pivot_m_) { } }; - -/** - * @brief Persistent Async Tile Scheduler Implementation - * - * This scheduler manages work distribution for persistent GEMM kernels - * with asynchronous input readiness. It extends the basic persistent - * scheduler with: - * - Chunk-based data readiness signaling - * - Producer-consumer synchronization - * - Pivot-based tile traversal for hotspot spreading - * - * @tparam TilePartitioner_ The tile partitioner type - */ -template -struct PersistentAsyncScheduler -{ - using TilePartitioner = remove_cvref_t; - - struct WorkTileInfo - { - index_t tile_idx_m; - index_t tile_idx_n; - index_t batch_idx; - bool is_valid; - - CK_TILE_DEVICE bool IsValid() const { return is_valid; } - - CK_TILE_DEVICE static WorkTileInfo InvalidTile() { return WorkTileInfo{-1, -1, -1, false}; } - }; - - struct SchedulerState - { - index_t current_tile_linear; - index_t total_tiles_m; - index_t total_tiles_n; - index_t total_grid_size; - bool is_mainloop_producer; - - // Async-specific state - index_t tiles_per_chunk_m; - uint32_t* chunk_signals; - index_t tile_idx_pivot_m; - }; - - CK_TILE_DEVICE PersistentAsyncScheduler(const PersistentAsyncArgs& async_args, - index_t tiles_m, - index_t tiles_n, - index_t grid_size) - { - state_.current_tile_linear = blockIdx.x + blockIdx.y * gridDim.x; - state_.total_tiles_m = tiles_m; - state_.total_tiles_n = tiles_n; - state_.total_grid_size = grid_size; - - // Determine if this wave is a mainloop producer - // Only the first wave in the first wave group is the producer - const index_t warp_id = threadIdx.x / warpSize; - const index_t wave_group_id = warp_id / 4; // 4 waves per wave group - state_.is_mainloop_producer = (wave_group_id == 0) && (warp_id % 4 == 0); - - // Async parameters - state_.tiles_per_chunk_m = async_args.tiles_per_chunk_m; - state_.chunk_signals = async_args.chunk_signals; - state_.tile_idx_pivot_m = async_args.tile_idx_pivot_m; - } - - /** - * @brief Get the next work tile for this workgroup - * - * This function: - * 1. Calculates the tile indices from linear work index - * 2. Applies pivot to M dimension for hotspot spreading - * 3. Waits for chunk signal if async mode is enabled - * 4. Returns tile info - */ - CK_TILE_DEVICE WorkTileInfo GetNextWorkTile() - { - const index_t linear_idx = state_.current_tile_linear; - - // Check if we've processed all tiles - const index_t total_tiles = state_.total_tiles_m * state_.total_tiles_n; - if(linear_idx >= total_tiles) - { - return WorkTileInfo::InvalidTile(); - } - - // Map linear index to 2D tile coordinates - // Using row-major traversal (can be extended for different patterns) - index_t tile_m = linear_idx / state_.total_tiles_n; - index_t tile_n = linear_idx % state_.total_tiles_n; - - // Apply pivot to M dimension after basic mapping - if(state_.tile_idx_pivot_m > 0) - { - tile_m = (tile_m + state_.tile_idx_pivot_m) % state_.total_tiles_m; - } - - // Wait for async input readiness if enabled - if(state_.chunk_signals != nullptr && state_.tiles_per_chunk_m > 0) - { - const index_t chunk_idx = tile_m / state_.tiles_per_chunk_m; - - // Producer lane waits for signal - if(state_.is_mainloop_producer && threadIdx.x == 0) - { - wait_signal(state_.chunk_signals + chunk_idx); - } - - // Synchronize all threads in workgroup after producer receives signal - __builtin_amdgcn_s_barrier(); - } - - return WorkTileInfo{tile_m, tile_n, 0, true}; - } - - /** - * @brief Advance to next work tile in persistent loop - */ - CK_TILE_DEVICE void AdvanceToNextTile() - { - state_.current_tile_linear += state_.total_grid_size; - } - - /** - * @brief Check if this is the last tile for this workgroup - */ - CK_TILE_DEVICE bool IsLastTile() const - { - const index_t total_tiles = state_.total_tiles_m * state_.total_tiles_n; - return (state_.current_tile_linear + state_.total_grid_size) >= total_tiles; - } - - /** - * @brief Synchronization fence between persistent loop iterations - * - * Must be called before moving to next tile to ensure: - * - All memory operations complete - * - LDS can be safely reused - */ - CK_TILE_DEVICE void IterationBoundaryFence() { iteration_boundary_fence(); } - - CK_TILE_DEVICE const SchedulerState& GetState() const { return state_; } - - private: - SchedulerState state_; -}; - -} // namespace ck_tile +} // namespace ck_tile \ No newline at end of file diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 11aff7319ab..1c57a03c972 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -4,6 +4,13 @@ #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, From c38a4ecc263f8c469194b7aced9616c001bce4bd Mon Sep 17 00:00:00 2001 From: Kumar Date: Thu, 20 Nov 2025 12:27:10 +0530 Subject: [PATCH 03/20] Add safe iteration --- .../ck_tile/03_gemm/gemm_persistent_async.cpp | 17 +--- .../03_gemm/gemm_persistent_async_invoker.hpp | 5 ++ .../03_gemm/persistent_async_scheduler.hpp | 26 ++---- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 16 +++- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 88 ++++++++++++++++++- 5 files changed, 111 insertions(+), 41 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_persistent_async.cpp b/example/ck_tile/03_gemm/gemm_persistent_async.cpp index a8d58850b10..ecaa1d5d0b8 100644 --- a/example/ck_tile/03_gemm/gemm_persistent_async.cpp +++ b/example/ck_tile/03_gemm/gemm_persistent_async.cpp @@ -1,21 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -/** - * @file gemm_persistent_async.cpp - * @brief Example demonstrating persistent GEMM with async input readiness - * - * This example shows how to use the PersistentAsyncScheduler for GEMM operations - * where input data becomes ready asynchronously in chunks. This is particularly - * useful in distributed computing scenarios where data arrives incrementally. - * - * Features demonstrated: - * - Chunk-based async input signaling - * - Producer-consumer synchronization - * - Pivot-based tile traversal for hotspot spreading - * - Persistent kernel execution - */ - #include "gemm_utils.hpp" #include "run_gemm_example.inc" #include "run_gemm_example_common.hpp" @@ -72,6 +57,8 @@ int main(int argc, char* argv[]) auto result = arg_parser.parse(argc, argv); + // TO-DO Add example + if(!result) return -1; diff --git a/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp b/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp index b00d6965737..4877853d2ca 100644 --- a/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp @@ -132,6 +132,11 @@ struct GemmPersistentAsyncInvoker ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); auto c_ptr = ws_args.c_ptr; ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); + + // Add persistent async arguments to ws_args + ws_args.chunk_signals = async_args.chunk_signals; + ws_args.tiles_per_chunk_m = async_args.tiles_per_chunk_m; + auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) diff --git a/example/ck_tile/03_gemm/persistent_async_scheduler.hpp b/example/ck_tile/03_gemm/persistent_async_scheduler.hpp index 82ea927ab52..b691d333066 100644 --- a/example/ck_tile/03_gemm/persistent_async_scheduler.hpp +++ b/example/ck_tile/03_gemm/persistent_async_scheduler.hpp @@ -8,26 +8,12 @@ namespace ck_tile { -__device__ wait_signal(uint32_t* signal_addr) -{ - uint32_t ready = __atomic_load_n(signal_addr, __ATOMIC_ACQUIRE); - while(!ready) - { - - // Use volatile load to prevent compiler optimization - asm volatile("flat_load_dword %0, %1 glc\n" - "s_waitcnt vmcnt(0)" - : "=v"(ready) - : "v"(addr) - : "memory"); - - // Add a small delay to reduce memory traffic - __builtin_amdgcn_s_sleep(1); - - ready = __atomic_load_n(signal_addr, __ATOMIC_ACQUIRE); - } -} - +/** + * @brief Arguments for Persistent Async GEMM scheduling + * + * This structure contains parameters for producer-consumer synchronization + * in persistent GEMM kernels with asynchronous input readiness. + */ struct PersistentAsyncArgs { index_t tiles_per_chunk_m = 0; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index d632b1596ca..76763a7c250 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -37,7 +37,9 @@ struct GemmHostArgs index_t K_, index_t stride_A_, index_t stride_B_, - index_t stride_E_) + index_t stride_E_, + uint32_t* chunk_signals_ = nullptr, + index_t tiles_per_chunk_m_ = 0) : a_ptr(a_ptr_), b_ptr(b_ptr_), e_ptr(e_ptr_), @@ -47,7 +49,9 @@ struct GemmHostArgs stride_A(stride_A_), stride_B(stride_B_), stride_E(stride_E_), - k_batch(k_batch_) + k_batch(k_batch_), + chunk_signals(chunk_signals_), + tiles_per_chunk_m(tiles_per_chunk_m_) { } @@ -72,6 +76,10 @@ struct GemmHostArgs }; index_t k_batch; + + // Persistent async arguments + uint32_t* chunk_signals; + index_t tiles_per_chunk_m; }; template @@ -153,7 +161,9 @@ struct GemmKernel {hostArgs.stride_A}, {hostArgs.stride_B}, {/*hostArgs.stride_Ds*/}, - hostArgs.stride_E)); + hostArgs.stride_E, + hostArgs.chunk_signals, + hostArgs.tiles_per_chunk_m)); } CK_TILE_HOST static auto diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index e77355ed3dc..c71b6a2ff2a 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -16,6 +16,60 @@ namespace ck_tile { +/** + * @brief Wait for a signal to become ready with acquire semantics + * + * Producer-only wait: One lane polls chunk_signals[chunk_idx] with acquire semantics, + * then a workgroup barrier releases everyone. + * + * @param signal_addr Pointer to the signal location in device memory + */ +CK_TILE_DEVICE static inline void wait_signal(uint32_t* signal_addr) +{ + // Only one thread in the workgroup polls the signal + if(threadIdx.x == 0) + { + uint32_t ready = 0; + while(!ready) + { + // Load with acquire semantics using AMD intrinsics + // glc (globally coherent) ensures visibility across the system + asm volatile("flat_load_dword %0, %1 glc\n\t" + "s_waitcnt vmcnt(0)" + : "=v"(ready) + : "v"(signal_addr) + : "memory"); + + // Add a small delay to reduce memory traffic + if(!ready) + { + __builtin_amdgcn_s_sleep(1); + } + } + } + + // Workgroup barrier to release all threads after signal is ready + __builtin_amdgcn_s_barrier(); +} + +/** + * @brief Fence for safe iteration boundaries in persistent loops + * + * Ensures all memory operations are complete before reusing LDS or moving to next tile. + * Uses s_waitcnt vmcnt=0, lgkmcnt=0 + s_barrier. + */ +CK_TILE_DEVICE static inline void iteration_boundary_fence() +{ + // Wait for all vector memory operations (global memory loads/stores) + __builtin_amdgcn_s_waitcnt_vmcnt(0); + + // Wait for all LDS operations + __builtin_amdgcn_s_waitcnt_lgkmcnt(0); + + // Synchronize all threads in the workgroup + __builtin_amdgcn_s_barrier(); +} + /// @brief The Universal GEMM kernel host arguments. /// /// @par Overview @@ -41,7 +95,9 @@ struct UniversalGemmHostArgs const std::array& stride_As_, const std::array& stride_Bs_, const std::array& stride_Ds_, - index_t stride_E_) + index_t stride_E_, + uint32_t* chunk_signals_ = nullptr, + index_t tiles_per_chunk_m_ = 0) : as_ptr(as_ptr_), bs_ptr(bs_ptr_), ds_ptr(ds_ptr_), @@ -53,7 +109,9 @@ struct UniversalGemmHostArgs stride_Bs(stride_Bs_), stride_Ds(stride_Ds_), stride_E(stride_E_), - k_batch(k_batch_) + k_batch(k_batch_), + chunk_signals(chunk_signals_), + tiles_per_chunk_m(tiles_per_chunk_m_) { } @@ -78,6 +136,10 @@ struct UniversalGemmHostArgs }; index_t k_batch; + + // Persistent async arguments + uint32_t* chunk_signals; + index_t tiles_per_chunk_m; }; /// @brief The GEMM kernel device arguments. @@ -111,6 +173,12 @@ struct UniversalGemmKernelArgs /// (in memory) of E tensor. index_t stride_E; index_t k_batch; + + /// @brief Pointer to chunk signals for async producer-consumer synchronization. + /// chunk_signals[i] == 1 indicates that chunk i is ready. + uint32_t* chunk_signals; + /// @brief Number of M tiles per chunk for async input signaling. + index_t tiles_per_chunk_m; }; /// @brief The Universal GEMM kernel template. @@ -313,7 +381,9 @@ struct UniversalGemmKernel hostArgs.stride_Bs, hostArgs.stride_Ds, hostArgs.stride_E, - hostArgs.k_batch}; + hostArgs.k_batch, + hostArgs.chunk_signals, + hostArgs.tiles_per_chunk_m}; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -1140,6 +1210,13 @@ struct UniversalGemmKernel const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); + + // Producer-consumer synchronization: wait for chunk to be ready + if(kargs.chunk_signals != nullptr && kargs.tiles_per_chunk_m > 0) + { + const index_t chunk_idx = iM / kargs.tiles_per_chunk_m; + wait_signal(kargs.chunk_signals + chunk_idx); + } // Get the SplitK offset for this block const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles); @@ -1206,6 +1283,11 @@ struct UniversalGemmKernel i_n); } } + + // Safe iteration boundary: ensure all memory operations complete + // before reusing LDS or moving to next tile + iteration_boundary_fence(); + // Advance to the next work item block_id += grid_size; if(block_id >= num_work) From c0ee47d4342549708b6bfce3d4159939d5e2c292 Mon Sep 17 00:00:00 2001 From: Kumar Date: Thu, 20 Nov 2025 12:40:42 +0530 Subject: [PATCH 04/20] Fix clang format --- example/ck_tile/03_gemm/persistent_async_scheduler.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/03_gemm/persistent_async_scheduler.hpp b/example/ck_tile/03_gemm/persistent_async_scheduler.hpp index b691d333066..65837de947f 100644 --- a/example/ck_tile/03_gemm/persistent_async_scheduler.hpp +++ b/example/ck_tile/03_gemm/persistent_async_scheduler.hpp @@ -10,7 +10,7 @@ namespace ck_tile { /** * @brief Arguments for Persistent Async GEMM scheduling - * + * * This structure contains parameters for producer-consumer synchronization * in persistent GEMM kernels with asynchronous input readiness. */ @@ -32,4 +32,4 @@ struct PersistentAsyncArgs { } }; -} // namespace ck_tile \ No newline at end of file +} // namespace ck_tile From 4650aad069edd646cf777724fd5dd8bb8a67d7c5 Mon Sep 17 00:00:00 2001 From: Kumar Date: Thu, 20 Nov 2025 15:10:21 +0530 Subject: [PATCH 05/20] Fix pre-hook commit error --- .../03_gemm/gemm_persistent_async_invoker.hpp | 8 +++---- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 4 ++-- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 24 +++++++++---------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp b/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp index 4877853d2ca..d9202a10d32 100644 --- a/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp @@ -132,12 +132,12 @@ struct GemmPersistentAsyncInvoker ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); auto c_ptr = ws_args.c_ptr; ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - + // Add persistent async arguments to ws_args - ws_args.chunk_signals = async_args.chunk_signals; + ws_args.chunk_signals = async_args.chunk_signals; ws_args.tiles_per_chunk_m = async_args.tiles_per_chunk_m; - - auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); + + auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) : GemmKernel::GridSize(args.M, args.N, args.k_batch); diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 76763a7c250..8dba7f9792d 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -38,7 +38,7 @@ struct GemmHostArgs index_t stride_A_, index_t stride_B_, index_t stride_E_, - uint32_t* chunk_signals_ = nullptr, + uint32_t* chunk_signals_ = nullptr, index_t tiles_per_chunk_m_ = 0) : a_ptr(a_ptr_), b_ptr(b_ptr_), @@ -76,7 +76,7 @@ struct GemmHostArgs }; index_t k_batch; - + // Persistent async arguments uint32_t* chunk_signals; index_t tiles_per_chunk_m; diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index c71b6a2ff2a..b50197f7d4e 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -18,10 +18,10 @@ namespace ck_tile { /** * @brief Wait for a signal to become ready with acquire semantics - * + * * Producer-only wait: One lane polls chunk_signals[chunk_idx] with acquire semantics, * then a workgroup barrier releases everyone. - * + * * @param signal_addr Pointer to the signal location in device memory */ CK_TILE_DEVICE static inline void wait_signal(uint32_t* signal_addr) @@ -47,14 +47,14 @@ CK_TILE_DEVICE static inline void wait_signal(uint32_t* signal_addr) } } } - + // Workgroup barrier to release all threads after signal is ready __builtin_amdgcn_s_barrier(); } /** * @brief Fence for safe iteration boundaries in persistent loops - * + * * Ensures all memory operations are complete before reusing LDS or moving to next tile. * Uses s_waitcnt vmcnt=0, lgkmcnt=0 + s_barrier. */ @@ -62,10 +62,10 @@ CK_TILE_DEVICE static inline void iteration_boundary_fence() { // Wait for all vector memory operations (global memory loads/stores) __builtin_amdgcn_s_waitcnt_vmcnt(0); - + // Wait for all LDS operations __builtin_amdgcn_s_waitcnt_lgkmcnt(0); - + // Synchronize all threads in the workgroup __builtin_amdgcn_s_barrier(); } @@ -96,7 +96,7 @@ struct UniversalGemmHostArgs const std::array& stride_Bs_, const std::array& stride_Ds_, index_t stride_E_, - uint32_t* chunk_signals_ = nullptr, + uint32_t* chunk_signals_ = nullptr, index_t tiles_per_chunk_m_ = 0) : as_ptr(as_ptr_), bs_ptr(bs_ptr_), @@ -136,7 +136,7 @@ struct UniversalGemmHostArgs }; index_t k_batch; - + // Persistent async arguments uint32_t* chunk_signals; index_t tiles_per_chunk_m; @@ -173,7 +173,7 @@ struct UniversalGemmKernelArgs /// (in memory) of E tensor. index_t stride_E; index_t k_batch; - + /// @brief Pointer to chunk signals for async producer-consumer synchronization. /// chunk_signals[i] == 1 indicates that chunk i is ready. uint32_t* chunk_signals; @@ -1210,7 +1210,7 @@ struct UniversalGemmKernel const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - + // Producer-consumer synchronization: wait for chunk to be ready if(kargs.chunk_signals != nullptr && kargs.tiles_per_chunk_m > 0) { @@ -1283,11 +1283,11 @@ struct UniversalGemmKernel i_n); } } - + // Safe iteration boundary: ensure all memory operations complete // before reusing LDS or moving to next tile iteration_boundary_fence(); - + // Advance to the next work item block_id += grid_size; if(block_id >= num_work) From 984e3d4b71554f0a3565af519514abf1e7429f38 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Tue, 25 Nov 2025 06:26:33 +0000 Subject: [PATCH 06/20] Resolve PR comments --- example/ck_tile/17_grouped_gemm/CMakeLists.txt | 0 .../grouped_gemm_persistent_async.cpp} | 5 +---- .../grouped_gemm_persistent_async.hpp} | 0 .../persistent_async_scheduler.hpp | 0 4 files changed, 1 insertion(+), 4 deletions(-) mode change 100644 => 100755 example/ck_tile/17_grouped_gemm/CMakeLists.txt rename example/ck_tile/{03_gemm/gemm_persistent_async.cpp => 17_grouped_gemm/grouped_gemm_persistent_async.cpp} (93%) mode change 100644 => 100755 rename example/ck_tile/{03_gemm/gemm_persistent_async_invoker.hpp => 17_grouped_gemm/grouped_gemm_persistent_async.hpp} (100%) mode change 100644 => 100755 rename example/ck_tile/{03_gemm => 17_grouped_gemm}/persistent_async_scheduler.hpp (100%) mode change 100644 => 100755 diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt old mode 100644 new mode 100755 diff --git a/example/ck_tile/03_gemm/gemm_persistent_async.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp old mode 100644 new mode 100755 similarity index 93% rename from example/ck_tile/03_gemm/gemm_persistent_async.cpp rename to example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp index ecaa1d5d0b8..abc9ac70aea --- a/example/ck_tile/03_gemm/gemm_persistent_async.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp @@ -1,10 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "gemm_utils.hpp" -#include "run_gemm_example.inc" -#include "run_gemm_example_common.hpp" -#include "gemm_persistent_async_invoker.hpp" +#include "run_grouped_gemm_example.inc" #include "persistent_async_scheduler.hpp" #include "ck_tile/core/utility/gemm_validation.hpp" #include diff --git a/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp old mode 100644 new mode 100755 similarity index 100% rename from example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp rename to example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp diff --git a/example/ck_tile/03_gemm/persistent_async_scheduler.hpp b/example/ck_tile/17_grouped_gemm/persistent_async_scheduler.hpp old mode 100644 new mode 100755 similarity index 100% rename from example/ck_tile/03_gemm/persistent_async_scheduler.hpp rename to example/ck_tile/17_grouped_gemm/persistent_async_scheduler.hpp From 7aead9064bd78f6e5b5b85f202ab61d581f77be4 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Tue, 25 Nov 2025 06:28:46 +0000 Subject: [PATCH 07/20] Remove changes from gemm example --- example/ck_tile/03_gemm/CMakeLists.txt | 2 -- 1 file changed, 2 deletions(-) mode change 100644 => 100755 example/ck_tile/03_gemm/CMakeLists.txt diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt old mode 100644 new mode 100755 index 2cc6b227e2d..d2112a67bf5 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -3,7 +3,6 @@ add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) add_executable(tile_example_gemm_reduce EXCLUDE_FROM_ALL gemm_splitk_two_stage_reduce.cpp) add_executable(tile_example_gemm_splitk_two_stage EXCLUDE_FROM_ALL gemm_splitk_two_stage.cpp) -add_executable(tile_example_gemm_persistent_async EXCLUDE_FROM_ALL gemm_persistent_async.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) @@ -19,4 +18,3 @@ target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPIL target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_splitk_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -target_compile_options(tile_example_gemm_persistent_async PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) From 8b6c11b490170a58d0742c1f9a135fe7552dcd35 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Tue, 25 Nov 2025 14:51:55 +0000 Subject: [PATCH 08/20] Fix build errors --- .../ck_tile/17_grouped_gemm/CMakeLists.txt | 4 +- .../grouped_gemm_persistent_async.cpp | 34 +- .../grouped_gemm_persistent_async.hpp | 323 +++++------------- .../persistent_async_scheduler.hpp | 19 +- .../persistent_async_utils.hpp | 63 ++++ .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 16 +- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 138 +++----- 7 files changed, 241 insertions(+), 356 deletions(-) create mode 100755 example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp mode change 100644 => 100755 include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp mode change 100644 => 100755 include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index bbfb2df006b..ef135a5f9f8 100755 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -2,6 +2,7 @@ add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp) add_executable(tile_example_grouped_gemm_multi_d EXCLUDE_FROM_ALL grouped_gemm_multi_d.cpp) +add_executable(tile_example_grouped_gemm_persistent_async EXCLUDE_FROM_ALL grouped_gemm_persistent_async.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) @@ -9,4 +10,5 @@ endif() target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) \ No newline at end of file +target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_grouped_gemm_persistent_async PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp index abc9ac70aea..450ceaa5ebc 100755 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "run_grouped_gemm_example.inc" #include "persistent_async_scheduler.hpp" +#include "persistent_async_utils.hpp" #include "ck_tile/core/utility/gemm_validation.hpp" #include +#include "grouped_gemm.hpp" +#include "grouped_gemm_persistent_async.hpp" /** * @brief Helper to allocate and initialize chunk signals @@ -13,7 +15,7 @@ * @param stream HIP stream for async operations * @return Device pointer to chunk signals array */ -static uint32_t* allocate_chunk_signals(int num_chunks, hipStream_t stream) +[[maybe_unused]] static uint32_t* allocate_chunk_signals(int num_chunks, hipStream_t stream) { uint32_t* signals_device = nullptr; @@ -34,7 +36,7 @@ static uint32_t* allocate_chunk_signals(int num_chunks, hipStream_t stream) * @param chunk_idx Index of chunk to signal * @param stream HIP stream for async operations */ -static void signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t stream) +[[maybe_unused]] static void signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t stream) { uint32_t ready = 1; ck_tile::hip_check_error(hipMemcpyAsync( @@ -43,7 +45,7 @@ static void signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t str int main(int argc, char* argv[]) { - auto arg_parser = create_args(); + auto [result, arg_parser] = create_args(argc, argv); // Add async-specific arguments arg_parser.insert( @@ -52,12 +54,28 @@ int main(int argc, char* argv[]) "tile_idx_pivot_m", "0", "Pivot offset for M dimension (for hotspot spreading)"); arg_parser.insert("enable_async", "1", "Enable async input signaling (0=disabled, 1=enabled)"); - auto result = arg_parser.parse(argc, argv); - - // TO-DO Add example - if(!result) return -1; + /*TO-DO + + // Parse async-specific arguments + const bool enable_async = arg_parser.get_int("enable_async") != 0; + const ck_tile::index_t tiles_per_chunk_m = arg_parser.get_int("tiles_per_chunk_m"); + const ck_tile::index_t tile_idx_pivot_m = arg_parser.get_int("tile_idx_pivot_m"); + + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string data_type = arg_parser.get_str("prec"); + + auto res = invoke_grouped_gemm_persistent_async( + a_layout, b_layout, data_type, arg_parser, + , tiles_per_chunk_m, tile_idx_pivot_m); + + + + */ + + return 0; } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp index d9202a10d32..814f45900ec 100755 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp @@ -1,279 +1,108 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once -#include "gemm_utils.hpp" -#include "persistent_async_scheduler.hpp" +#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/epilogue.hpp" -/** - * @brief Invoker for Persistent Async GEMM - * - * This invoker implements persistent GEMM with asynchronous input readiness. - * It extends the standard GEMM with support for: - * - Chunk-based async input signaling - * - Producer-consumer synchronization - * - Pivot-based tile traversal - */ -struct GemmPersistentAsyncInvoker -{ - template - static float gemm(const ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s, - const ck_tile::PersistentAsyncArgs& async_args) +template + void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) { - - static_assert(Persistent, "This invoker only supports persistent GEMM."); - - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; - - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; - - using Traits = ck_tile::TileGemmTraits; + constexpr bool TransposeC = false; + constexpr bool DoubleSmemBuffer = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; using GemmUniversalTraits = - ck_tile::TileGemmUniversalTraits; - - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; + ck_tile::PersistentTileGemmUniversalTraits; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; constexpr auto memory_operation = memory_operation_.value; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - - using WorkspaceType = ck_tile::remove_cvref_t; + scheduler>; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; - - using GemmKernel = ck_tile::GemmKernel; - - ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType)); - ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); - auto c_ptr = ws_args.c_ptr; - ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - - // Add persistent async arguments to ws_args - ws_args.chunk_signals = async_args.chunk_signals; - ws_args.tiles_per_chunk_m = async_args.tiles_per_chunk_m; - - auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); - - const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) - : GemmKernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = GemmKernel::BlockSize(); - - if(!GemmKernel::IsSupportedArgument(gemm_kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; - using BlockTile = ck_tile::sequence<2048>; - using BlockWarps = ck_tile::sequence<8>; - using WarpTile = ck_tile::sequence<64>; - - using ElementwiseShape = - ck_tile::ElementWiseShape; - using Problem = ck_tile::ElementWisePipelineProblem; - using ElementwiseKernel = - ck_tile::ElementWiseKernel; - - ck_tile::index_t total_elements = 1; - std::vector shape = {args.M, args.N}; - - for(auto d : shape) - total_elements *= d; - - const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = 1; - - constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); - ck_tile::index_t kGridSize = - (total_elements + elements_per_block - 1) / elements_per_block; - - auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.c_ptr)); - auto input_size = ck_tile::make_tuple(args.M, args.N); - - // Check if the kernel configuration is supported - if(!ElementwiseKernel::IsSupportedArgument(input_size)) - { - throw std::runtime_error( - "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); - } + memory_operation>>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); if(s.log_level_ > 0) { - std::cout << "Launching Persistent Async GEMM kernel:\n" - << " Kernel: " << Kernel::GetName() << '\n' - << " Shape: " << GemmShape::GetName() << '\n' - << " Problem: " << UniversalGemmProblem::GetName() << '\n' - << " Pipeline: " << GemmPipeline::GetName() << '\n' - << " Grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}\n" - << " Blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}\n" - << " Async Args:\n" - << " tiles_per_chunk_m: " << async_args.tiles_per_chunk_m << '\n' - << " tile_idx_pivot_m: " << async_args.tile_idx_pivot_m << '\n' - << " chunk_signals: " - << (async_args.chunk_signals ? "enabled" : "disabled") << std::endl; + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; } - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; - - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - gemm_kargs.as_ptr[0], - gemm_kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel( - GemmKernel{}, grids, blocks, 0, gemm_kargs), - ck_tile::make_kernel(ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(args.N, 1), // Input Stride - ck_tile::make_tuple(args.N, 1), // Output Stride - input_tensors, - static_cast(c_ptr))); - - return ave_time; - }; - - const auto RunSplitK = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); - } - else - { - return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); - } + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); }; - return ave_time = BaseGemmPipeline::TailHandler(RunSplitK, has_hot_loop, tail_num); + if(splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + + Run(ck_tile::integral_constant{}); + } } -}; diff --git a/example/ck_tile/17_grouped_gemm/persistent_async_scheduler.hpp b/example/ck_tile/17_grouped_gemm/persistent_async_scheduler.hpp index 65837de947f..b98f655dcf0 100755 --- a/example/ck_tile/17_grouped_gemm/persistent_async_scheduler.hpp +++ b/example/ck_tile/17_grouped_gemm/persistent_async_scheduler.hpp @@ -16,19 +16,28 @@ namespace ck_tile { */ struct PersistentAsyncArgs { + /// Number of M tiles per chunk (granularity of async readiness signaling) index_t tiles_per_chunk_m = 0; + /// Device pointer to global chunk readiness flags (1 = ready, 0 = not ready) uint32_t* chunk_signals = nullptr; + /// Pivot offset for M dimension (for hotspot spreading in tile scheduling) index_t tile_idx_pivot_m = 0; - PersistentAsyncArgs(index_t tiles_per_chunk_m_, - uint32_t* chunk_signals_, - index_t tile_idx_pivot_m_, - bool enable_async_) + /// Enable/disable async input signaling (false = disabled, true = enabled) + bool enable_async = false; + + CK_TILE_HOST_DEVICE PersistentAsyncArgs() = default; + + CK_TILE_HOST_DEVICE PersistentAsyncArgs(index_t tiles_per_chunk_m_, + uint32_t* chunk_signals_, + index_t tile_idx_pivot_m_, + bool enable_async_ = false) : tiles_per_chunk_m(tiles_per_chunk_m_), chunk_signals(chunk_signals_), - tile_idx_pivot_m(tile_idx_pivot_m_) + tile_idx_pivot_m(tile_idx_pivot_m_), + enable_async(enable_async_) { } }; diff --git a/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp b/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp new file mode 100755 index 00000000000..9c6c3811965 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +/** + * @brief Safe iteration boundary fence for persistent kernels + * + * This function ensures memory consistency between iterations in persistent loops by: + * - Waiting for all vector memory operations to complete (vmcnt=0) + * - Waiting for all LDS/GDS operations to complete (lgkmcnt=0) + * - Synchronizing all workgroup threads via barrier + * + * This prevents race conditions when reusing LDS or moving to the next tile. + */ +CK_TILE_DEVICE static void iteration_boundary_fence() +{ + __builtin_amdgcn_s_waitcnt(0); + __builtin_amdgcn_s_waitcnt(0); + __builtin_amdgcn_s_barrier(); +} + +/** + * @brief Wait for chunk readiness signal (producer-consumer synchronization) + * + * This function implements producer-consumer synchronization for async input readiness: + * - One lane polls the chunk_signals[chunk_idx] flag with acquire semantics + * - When signal becomes ready (value == 1), a workgroup barrier releases all threads + * + * @param chunk_signals Device pointer to global chunk readiness flags array + * @param chunk_idx Index of the chunk to wait for + * + * @note Only lane 0 performs the polling to minimize global memory traffic + * @note Uses acquire semantics to ensure proper memory ordering + */ +CK_TILE_DEVICE static void wait_chunk_signal(const uint32_t* chunk_signals, index_t chunk_idx) +{ + // Only lane 0 polls the signal to minimize global memory traffic + if(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) + { + volatile const uint32_t* signal_ptr = chunk_signals + chunk_idx; + + // Poll until chunk is ready (signal == 1) + // Use acquire semantics for proper memory ordering + uint32_t signal_value; + do { + signal_value = __builtin_nontemporal_load(signal_ptr); + __builtin_amdgcn_s_sleep(1); // Brief sleep to reduce contention + } while(signal_value == 0); + + // Memory fence with acquire semantics + __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "agent"); + } + + // Barrier to release all threads in the workgroup + __builtin_amdgcn_s_barrier(); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp old mode 100644 new mode 100755 index 8dba7f9792d..d632b1596ca --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -37,9 +37,7 @@ struct GemmHostArgs index_t K_, index_t stride_A_, index_t stride_B_, - index_t stride_E_, - uint32_t* chunk_signals_ = nullptr, - index_t tiles_per_chunk_m_ = 0) + index_t stride_E_) : a_ptr(a_ptr_), b_ptr(b_ptr_), e_ptr(e_ptr_), @@ -49,9 +47,7 @@ struct GemmHostArgs stride_A(stride_A_), stride_B(stride_B_), stride_E(stride_E_), - k_batch(k_batch_), - chunk_signals(chunk_signals_), - tiles_per_chunk_m(tiles_per_chunk_m_) + k_batch(k_batch_) { } @@ -76,10 +72,6 @@ struct GemmHostArgs }; index_t k_batch; - - // Persistent async arguments - uint32_t* chunk_signals; - index_t tiles_per_chunk_m; }; template @@ -161,9 +153,7 @@ struct GemmKernel {hostArgs.stride_A}, {hostArgs.stride_B}, {/*hostArgs.stride_Ds*/}, - hostArgs.stride_E, - hostArgs.chunk_signals, - hostArgs.tiles_per_chunk_m)); + hostArgs.stride_E)); } CK_TILE_HOST static auto diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp old mode 100644 new mode 100755 index b50197f7d4e..dde42cdab23 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -16,59 +16,59 @@ namespace ck_tile { -/** - * @brief Wait for a signal to become ready with acquire semantics - * - * Producer-only wait: One lane polls chunk_signals[chunk_idx] with acquire semantics, - * then a workgroup barrier releases everyone. - * - * @param signal_addr Pointer to the signal location in device memory - */ -CK_TILE_DEVICE static inline void wait_signal(uint32_t* signal_addr) -{ - // Only one thread in the workgroup polls the signal - if(threadIdx.x == 0) - { - uint32_t ready = 0; - while(!ready) - { - // Load with acquire semantics using AMD intrinsics - // glc (globally coherent) ensures visibility across the system - asm volatile("flat_load_dword %0, %1 glc\n\t" - "s_waitcnt vmcnt(0)" - : "=v"(ready) - : "v"(signal_addr) - : "memory"); - - // Add a small delay to reduce memory traffic - if(!ready) - { - __builtin_amdgcn_s_sleep(1); - } - } - } - - // Workgroup barrier to release all threads after signal is ready - __builtin_amdgcn_s_barrier(); -} - -/** - * @brief Fence for safe iteration boundaries in persistent loops - * - * Ensures all memory operations are complete before reusing LDS or moving to next tile. - * Uses s_waitcnt vmcnt=0, lgkmcnt=0 + s_barrier. - */ -CK_TILE_DEVICE static inline void iteration_boundary_fence() -{ - // Wait for all vector memory operations (global memory loads/stores) - __builtin_amdgcn_s_waitcnt_vmcnt(0); - - // Wait for all LDS operations - __builtin_amdgcn_s_waitcnt_lgkmcnt(0); - - // Synchronize all threads in the workgroup - __builtin_amdgcn_s_barrier(); -} +// /** +// * @brief Wait for a signal to become ready with acquire semantics +// * +// * Producer-only wait: One lane polls chunk_signals[chunk_idx] with acquire semantics, +// * then a workgroup barrier releases everyone. +// * +// * @param signal_addr Pointer to the signal location in device memory +// */ +// CK_TILE_DEVICE static void wait_signal(uint32_t* signal_addr) +// { +// // Only one thread in the workgroup polls the signal +// if(threadIdx.x == 0) +// { +// uint32_t ready = 0; +// while(!ready) +// { +// // Load with acquire semantics using AMD intrinsics +// // glc (globally coherent) ensures visibility across the system +// asm volatile("flat_load_dword %0, %1 glc\n\t" +// "s_waitcnt vmcnt(0)" +// : "=v"(ready) +// : "v"(signal_addr) +// : "memory"); + +// // Add a small delay to reduce memory traffic +// if(!ready) +// { +// __builtin_amdgcn_s_sleep(1); +// } +// } +// } + +// // Workgroup barrier to release all threads after signal is ready +// __builtin_amdgcn_s_barrier(); +// } + +// /** +// * @brief Fence for safe iteration boundaries in persistent loops +// * +// * Ensures all memory operations are complete before reusing LDS or moving to next tile. +// * Uses s_waitcnt vmcnt=0, lgkmcnt=0 + s_barrier. +// */ +// CK_TILE_DEVICE static void iteration_boundary_fence() +// { +// // Wait for all vector memory operations (global memory loads/stores) +// __builtin_amdgcn_s_waitcnt(0); + +// // Wait for all LDS operations +// __builtin_amdgcn_s_waitcnt(0); + +// // Synchronize all threads in the workgroup +// __builtin_amdgcn_s_barrier(); +// } /// @brief The Universal GEMM kernel host arguments. /// @@ -95,9 +95,7 @@ struct UniversalGemmHostArgs const std::array& stride_As_, const std::array& stride_Bs_, const std::array& stride_Ds_, - index_t stride_E_, - uint32_t* chunk_signals_ = nullptr, - index_t tiles_per_chunk_m_ = 0) + index_t stride_E_) : as_ptr(as_ptr_), bs_ptr(bs_ptr_), ds_ptr(ds_ptr_), @@ -109,9 +107,7 @@ struct UniversalGemmHostArgs stride_Bs(stride_Bs_), stride_Ds(stride_Ds_), stride_E(stride_E_), - k_batch(k_batch_), - chunk_signals(chunk_signals_), - tiles_per_chunk_m(tiles_per_chunk_m_) + k_batch(k_batch_) { } @@ -136,10 +132,6 @@ struct UniversalGemmHostArgs }; index_t k_batch; - - // Persistent async arguments - uint32_t* chunk_signals; - index_t tiles_per_chunk_m; }; /// @brief The GEMM kernel device arguments. @@ -174,11 +166,6 @@ struct UniversalGemmKernelArgs index_t stride_E; index_t k_batch; - /// @brief Pointer to chunk signals for async producer-consumer synchronization. - /// chunk_signals[i] == 1 indicates that chunk i is ready. - uint32_t* chunk_signals; - /// @brief Number of M tiles per chunk for async input signaling. - index_t tiles_per_chunk_m; }; /// @brief The Universal GEMM kernel template. @@ -381,9 +368,7 @@ struct UniversalGemmKernel hostArgs.stride_Bs, hostArgs.stride_Ds, hostArgs.stride_E, - hostArgs.k_batch, - hostArgs.chunk_signals, - hostArgs.tiles_per_chunk_m}; + hostArgs.k_batch}; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -1211,13 +1196,6 @@ struct UniversalGemmKernel const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - // Producer-consumer synchronization: wait for chunk to be ready - if(kargs.chunk_signals != nullptr && kargs.tiles_per_chunk_m > 0) - { - const index_t chunk_idx = iM / kargs.tiles_per_chunk_m; - wait_signal(kargs.chunk_signals + chunk_idx); - } - // Get the SplitK offset for this block const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles); const SplitKBatchOffset splitk_batch_offset(kargs, k_batch); @@ -1284,10 +1262,6 @@ struct UniversalGemmKernel } } - // Safe iteration boundary: ensure all memory operations complete - // before reusing LDS or moving to next tile - iteration_boundary_fence(); - // Advance to the next work item block_id += grid_size; if(block_id >= num_work) From b649b364bf2dc6ff83b68909bdfff39abfecbc04 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Tue, 25 Nov 2025 14:55:19 +0000 Subject: [PATCH 09/20] Remove commented code --- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 56 ------------------- 1 file changed, 56 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index dde42cdab23..e77355ed3dc 100755 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -16,60 +16,6 @@ namespace ck_tile { -// /** -// * @brief Wait for a signal to become ready with acquire semantics -// * -// * Producer-only wait: One lane polls chunk_signals[chunk_idx] with acquire semantics, -// * then a workgroup barrier releases everyone. -// * -// * @param signal_addr Pointer to the signal location in device memory -// */ -// CK_TILE_DEVICE static void wait_signal(uint32_t* signal_addr) -// { -// // Only one thread in the workgroup polls the signal -// if(threadIdx.x == 0) -// { -// uint32_t ready = 0; -// while(!ready) -// { -// // Load with acquire semantics using AMD intrinsics -// // glc (globally coherent) ensures visibility across the system -// asm volatile("flat_load_dword %0, %1 glc\n\t" -// "s_waitcnt vmcnt(0)" -// : "=v"(ready) -// : "v"(signal_addr) -// : "memory"); - -// // Add a small delay to reduce memory traffic -// if(!ready) -// { -// __builtin_amdgcn_s_sleep(1); -// } -// } -// } - -// // Workgroup barrier to release all threads after signal is ready -// __builtin_amdgcn_s_barrier(); -// } - -// /** -// * @brief Fence for safe iteration boundaries in persistent loops -// * -// * Ensures all memory operations are complete before reusing LDS or moving to next tile. -// * Uses s_waitcnt vmcnt=0, lgkmcnt=0 + s_barrier. -// */ -// CK_TILE_DEVICE static void iteration_boundary_fence() -// { -// // Wait for all vector memory operations (global memory loads/stores) -// __builtin_amdgcn_s_waitcnt(0); - -// // Wait for all LDS operations -// __builtin_amdgcn_s_waitcnt(0); - -// // Synchronize all threads in the workgroup -// __builtin_amdgcn_s_barrier(); -// } - /// @brief The Universal GEMM kernel host arguments. /// /// @par Overview @@ -165,7 +111,6 @@ struct UniversalGemmKernelArgs /// (in memory) of E tensor. index_t stride_E; index_t k_batch; - }; /// @brief The Universal GEMM kernel template. @@ -1261,7 +1206,6 @@ struct UniversalGemmKernel i_n); } } - // Advance to the next work item block_id += grid_size; if(block_id >= num_work) From 15345968ecd91ee4d7b0585bf9630fb8dbc91bd8 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Tue, 25 Nov 2025 14:59:16 +0000 Subject: [PATCH 10/20] Fix pre-commit error --- .../grouped_gemm_persistent_async.cpp | 6 +- .../grouped_gemm_persistent_async.hpp | 190 +++++++++--------- .../persistent_async_utils.hpp | 9 +- 3 files changed, 106 insertions(+), 99 deletions(-) mode change 100755 => 100644 example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp mode change 100755 => 100644 example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp mode change 100755 => 100644 example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp old mode 100755 new mode 100644 index 450ceaa5ebc..70b63d2ca41 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp @@ -36,7 +36,8 @@ * @param chunk_idx Index of chunk to signal * @param stream HIP stream for async operations */ -[[maybe_unused]] static void signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t stream) +[[maybe_unused]] static void +signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t stream) { uint32_t ready = 1; ck_tile::hip_check_error(hipMemcpyAsync( @@ -67,7 +68,7 @@ int main(int argc, char* argv[]) const std::string a_layout = arg_parser.get_str("a_layout"); const std::string b_layout = arg_parser.get_str("b_layout"); const std::string data_type = arg_parser.get_str("prec"); - + auto res = invoke_grouped_gemm_persistent_async( a_layout, b_layout, data_type, arg_parser, , tiles_per_chunk_m, tile_idx_pivot_m); @@ -76,6 +77,5 @@ int main(int argc, char* argv[]) */ - return 0; } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp old mode 100755 new mode 100644 index 814f45900ec..599ec707461 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp @@ -3,106 +3,112 @@ #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/epilogue.hpp" +template +void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) +{ + constexpr bool TransposeC = false; + constexpr bool DoubleSmemBuffer = false; -template - void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, - const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) - { - constexpr bool TransposeC = false; - constexpr bool DoubleSmemBuffer = false; - - constexpr int kBlockPerCu = 1; - constexpr ck_tile::index_t TileParitionerGroupNum = 8; - constexpr ck_tile::index_t TileParitionerM01 = 4; - - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - using TilePartitioner = ck_tile:: - GemmSpatiallyLocalTilePartitioner; + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; - using GemmUniversalTraits = - ck_tile::PersistentTileGemmUniversalTraits; + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - constexpr auto memory_operation = memory_operation_.value; + using GemmUniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + const auto Run = [&](const auto memory_operation_) { + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - }; - - if(splitk) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - else - { - Run(ck_tile::integral_constant{}); - } + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + }; + + if(splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + + Run(ck_tile::integral_constant{}); } +} diff --git a/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp b/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp old mode 100755 new mode 100644 index 9c6c3811965..c8533d948bb --- a/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp +++ b/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp @@ -43,19 +43,20 @@ CK_TILE_DEVICE static void wait_chunk_signal(const uint32_t* chunk_signals, inde if(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { volatile const uint32_t* signal_ptr = chunk_signals + chunk_idx; - + // Poll until chunk is ready (signal == 1) // Use acquire semantics for proper memory ordering uint32_t signal_value; - do { + do + { signal_value = __builtin_nontemporal_load(signal_ptr); __builtin_amdgcn_s_sleep(1); // Brief sleep to reduce contention } while(signal_value == 0); - + // Memory fence with acquire semantics __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "agent"); } - + // Barrier to release all threads in the workgroup __builtin_amdgcn_s_barrier(); } From 42ed69376179b562bd850c1baf73bead62276280 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Wed, 26 Nov 2025 18:08:28 +0000 Subject: [PATCH 11/20] Resolve PR comments --- .../17_grouped_gemm/grouped_gemm_persistent_async.cpp | 2 +- include/ck_tile/ops/gemm.hpp | 1 + .../ck_tile/ops/gemm/kernel}/persistent_async_utils.hpp | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) rename {example/ck_tile/17_grouped_gemm => include/ck_tile/ops/gemm/kernel}/persistent_async_utils.hpp (93%) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp index 70b63d2ca41..58693d67479 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "persistent_async_scheduler.hpp" -#include "persistent_async_utils.hpp" +#include "ck_tile/ops/gemm/kernel/persistent_async_utils.hpp" #include "ck_tile/core/utility/gemm_validation.hpp" #include #include "grouped_gemm.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index ec2d2488c88..b308e3b5572 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -34,6 +34,7 @@ #include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/persistent_async_utils.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp" diff --git a/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp b/include/ck_tile/ops/gemm/kernel/persistent_async_utils.hpp similarity index 93% rename from example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp rename to include/ck_tile/ops/gemm/kernel/persistent_async_utils.hpp index c8533d948bb..5911758b413 100644 --- a/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp +++ b/include/ck_tile/ops/gemm/kernel/persistent_async_utils.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck/utility/synchronization.hpp" namespace ck_tile { @@ -19,9 +20,8 @@ namespace ck_tile { */ CK_TILE_DEVICE static void iteration_boundary_fence() { - __builtin_amdgcn_s_waitcnt(0); - __builtin_amdgcn_s_waitcnt(0); - __builtin_amdgcn_s_barrier(); + // Wait for all global and LDS memory operations, then barrier + block_sync_lds_direct_load(); } /** From 6d4949bfd8795edefbab5c0a1c002d8e258744ee Mon Sep 17 00:00:00 2001 From: Kumar Date: Fri, 5 Dec 2025 11:00:10 +0530 Subject: [PATCH 12/20] Resolve PR comments + add grouped gemm example --- .../grouped_gemm_persistent_async.cpp | 468 +++++++++++++++++- .../gemm/kernel/persistent_async_utils.hpp | 3 - 2 files changed, 447 insertions(+), 24 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp index 58693d67479..58800eb6ba7 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp @@ -1,10 +1,24 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "persistent_async_scheduler.hpp" -#include "ck_tile/ops/gemm/kernel/persistent_async_utils.hpp" -#include "ck_tile/core/utility/gemm_validation.hpp" #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm/kernel/persistent_async_utils.hpp" +#include "persistent_async_scheduler.hpp" #include "grouped_gemm.hpp" #include "grouped_gemm_persistent_async.hpp" @@ -44,38 +58,450 @@ signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t stream) &signals[chunk_idx], &ready, sizeof(uint32_t), hipMemcpyHostToDevice, stream)); } -int main(int argc, char* argv[]) +template +int run_gemm_example_prec_type(std::string a_layout, + std::string b_layout, + ck_tile::ArgParser& arg_parser) { - auto [result, arg_parser] = create_args(argc, argv); - - // Add async-specific arguments - arg_parser.insert( - "tiles_per_chunk_m", "1", "Number of M tiles per chunk (granularity of async readiness)"); - arg_parser.insert( - "tile_idx_pivot_m", "0", "Pivot offset for M dimension (for hotspot spreading)"); - arg_parser.insert("enable_async", "1", "Enable async input signaling (0=disabled, 1=enabled)"); - - if(!result) - return -1; + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Types = GemmTypeConfig; - /*TO-DO + // Specific type aliases for easy access + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using AccDataType = typename Types::AccDataType; + using CDataType = typename Types::CDataType; // Parse async-specific arguments - const bool enable_async = arg_parser.get_int("enable_async") != 0; + const bool enable_async = arg_parser.get_int("enable_async") != 0; const ck_tile::index_t tiles_per_chunk_m = arg_parser.get_int("tiles_per_chunk_m"); const ck_tile::index_t tile_idx_pivot_m = arg_parser.get_int("tile_idx_pivot_m"); + std::cout << "\n=== Async Parameters ===" << std::endl; + std::cout << " enable_async: " << (enable_async ? "YES (will allocate chunk signals)" : "NO") + << std::endl; + std::cout << " tiles_per_chunk_m: " << tiles_per_chunk_m << std::endl; + std::cout << " tile_idx_pivot_m: " << tile_idx_pivot_m << std::endl; + + // Create async args (chunk signals will be allocated in the example function) + ck_tile::PersistentAsyncArgs async_args( + tiles_per_chunk_m, nullptr, tile_idx_pivot_m, enable_async); + + if(a_layout == "R" && b_layout == "C") + { + return run_grouped_gemm_persistent_async_example( + arg_parser, Row{}, Col{}, Row{}, async_args); + } + else if(a_layout == "R" && b_layout == "R") + { + return run_grouped_gemm_persistent_async_example( + arg_parser, Row{}, Row{}, Row{}, async_args); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_grouped_gemm_persistent_async_example( + arg_parser, Col{}, Row{}, Row{}, async_args); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_grouped_gemm_persistent_async_example( + arg_parser, Col{}, Col{}, Row{}, async_args); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A and B tensors!"); + } +} + +template