diff --git a/hpc/group_gemm.py b/hpc/group_gemm.py index 0631c2e..a292673 100644 --- a/hpc/group_gemm.py +++ b/hpc/group_gemm.py @@ -152,6 +152,52 @@ def group_gemm_blockwise_fp8( ) +def group_gemm_bf16( + x: Tensor, + weight: Tensor, + seqlens: Tensor, + cu_seqlens: Tensor, + num_seq_per_group_avg: int = 32, + output: Tensor = None, + tma_desc: Tensor = None, +) -> Tensor: + """Performs group GEMM operation with BF16 precision. + + This function executes multiple matrix multiplications in a group manner + using BF16 precision for improved performance. + + Args: + x: Input activation tensor + Shape: [total_seq, hidden_size] + Dtype: bfloat16 + weight: Weight tensor for group matrix multiplication + Shape: [num_group, output_dim, hidden_size] + Dtype: bfloat16 + seqlens: Sequence lengths for each group + Shape: [num_group] + Dtype: int32 + cu_seqlens: Cumulative sequence lengths indicating start indices in input tensor + Shape: [num_group + 1] + Dtype: int32 + + Returns: + Tensor: Output tensor after group matrix multiplication + Shape: [total_seq, output_dim] + Dtype: bfloat16 + + Raises: + RuntimeError: If the input tensors have incompatible shapes or types, + or if the CUDA kernel execution fails. 
+ + Note: + - All input tensors must be on CUDA device + + """ + return torch.ops.hpc.group_gemm_bf16( + x, weight, seqlens, cu_seqlens, num_seq_per_group_avg, output, tma_desc + ) + + @torch.library.register_fake("hpc::group_gemm_pertensor_fp8") def group_gemm_pertensor_fp8_fake( x, weight, seqlens, cu_seqlens, y_scale, num_seq_per_group_avg, output, tma_des @@ -164,3 +210,8 @@ def group_gemm_blockwise_fp8_fake( x, weight, seqlens, cu_seqlens, x_scale, w_scale, num_seq_per_group_avg, output, tma_des ): return torch.empty((x.shape[0], weight.shape[1]), dtype=torch.bfloat16) + + +@torch.library.register_fake("hpc::group_gemm_bf16") +def group_gemm_bf16_fake(x, weight, seqlens, cu_seqlens, num_seq_per_group_avg, output, tma_des): + return torch.empty((x.shape[0], weight.shape[1]), dtype=torch.bfloat16) diff --git a/src/group_gemm/config.h b/src/group_gemm/config.h index b0b4656..16462c0 100644 --- a/src/group_gemm/config.h +++ b/src/group_gemm/config.h @@ -165,6 +165,75 @@ struct GroupGEMMBlockWiseFp8Config { auto get_shm_size() { return shm_size; } }; +template +static constexpr auto mma_selector_bf16() { + constexpr auto MajorA = GMMA::Major::K; + constexpr auto MajorB = GMMA::Major::K; + if constexpr (kTileM == 8) { + return cute::SM90_64x8x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 16) { + return cute::SM90_64x16x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 32) { + return cute::SM90_64x32x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 48) { + return cute::SM90_64x48x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 64) { + return cute::SM90_64x64x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 96) { + return cute::SM90_64x96x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 128) { + return cute::SM90_64x128x16_F32BF16BF16_SS{}; + } else { + return cute::SM90_64x64x16_F32BF16BF16_SS{}; + } +} + +template +struct GroupGEMMBF16Config { + using Tin = Tin_; + using Tout = Tout_; + + static constexpr int kTileM = kTileM_; + 
static constexpr int kTileN = kTileN_; + static constexpr int kTileK = kTileK_; + static constexpr int kStage = kStage_; + static constexpr int kWarpgroupM = kWarpgroupM_; + static constexpr int kWarpgroupN = kWarpgroupN_; + + using SLayoutXAtom = decltype(slayout_selector()); + using SLayoutWAtom = decltype(slayout_selector()); + using SLayoutYAtom = decltype(slayout_selector()); + + using SLayoutX = decltype(tile_to_shape(SLayoutXAtom{}, + make_shape(Int{}, Int{}, Int{}))); + using SLayoutW = decltype(tile_to_shape(SLayoutWAtom{}, + make_shape(Int{}, Int{}, Int{}))); + using SLayoutY = + decltype(tile_to_shape(SLayoutYAtom{}, make_shape(Int{}, Int{}))); + using CopyBoxY = decltype(tile_to_shape(SLayoutYAtom{}, + make_shape(Int{}, Int{}))); + + template + auto get_tma(TX x, TW w, TY y) { + auto tma_x = make_tma_copy(SM90_TMA_LOAD{}, x, take<0, 2>(SLayoutX{})); + auto tma_w = make_tma_copy(SM90_TMA_LOAD{}, w, take<0, 2>(SLayoutW{})); + auto tma_y = make_tma_copy(SM90_TMA_STORE{}, y, CopyBoxY{}); + return std::make_tuple(tma_x, tma_w, tma_y); + } + + using WarpgroupLayout = + decltype(make_layout(make_shape(Int{}, Int{}, Int<1>{}))); + using TiledMma = decltype(make_tiled_mma(mma_selector_bf16(), WarpgroupLayout{})); + + static constexpr int shm_xw = (cosize(SLayoutX{}) + cosize(SLayoutW{})) * sizeof(Tin); + static constexpr int shm_y = cosize(SLayoutY{}) * sizeof(Tout); + static constexpr int shm_size = shm_xw + shm_y; + + auto get_shm_size() { return shm_size; } +}; + } // namespace group_gemm } // namespace hpc diff --git a/src/group_gemm/entry.cc b/src/group_gemm/entry.cc index 5c74dff..9cc4c9c 100644 --- a/src/group_gemm/entry.cc +++ b/src/group_gemm/entry.cc @@ -6,7 +6,7 @@ #include #include - +#include #include "src/group_gemm/group_gemm.h" namespace hpc { @@ -137,6 +137,65 @@ torch::Tensor group_gemm_blockwise_fp8_entry( return y; } +torch::Tensor group_gemm_bf16_entry(const torch::Tensor &x, const torch::Tensor &weight, + const torch::Tensor &seqlens, + 
const torch::Tensor &cu_seqlens, + const int64_t num_seq_per_group_avg, + std::optional output, + std::optional tma_desc) { + auto stream = at::cuda::getCurrentCUDAStream(x.get_device()); + TORCH_CHECK(x.device().is_cuda(), "x tensor must be cuda"); + TORCH_CHECK(weight.device().is_cuda(), "weight tensor must be cuda"); + TORCH_CHECK(seqlens.device().is_cuda(), "seqlens tensor must be cuda"); + TORCH_CHECK(cu_seqlens.device().is_cuda(), "cu_seqlens tensor must be cuda"); + TORCH_CHECK(x.is_contiguous(), "x tensor a must be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor a must be contiguous"); + TORCH_CHECK(seqlens.size(0) == weight.size(0), + "seqlens and weight must share the same num_group"); + TORCH_CHECK(x.size(1) == weight.size(2), "x and weight must share the same k"); + + int m = x.size(0); + int k = x.size(1); + int n = weight.size(1); + int num_group = seqlens.size(0); + + auto options = x.options(); + torch::Tensor y; + if (output.has_value()) { + y = output.value(); + } else { + y = torch::empty({m, n}, options.dtype(torch::kBFloat16)); + } + + torch::Tensor tmas; + bool update_tma = true; + if (tma_desc.has_value()) { + tmas = tma_desc.value(); + update_tma = false; + } else { + tmas = torch::empty({num_group * 2, 128}, options); + } + + torch::Tensor tiles = torch::empty({num_group}, options.dtype(torch::kInt32)); + torch::Tensor cu_tiles = torch::empty({num_group + 1}, options.dtype(torch::kInt32)); + + const auto *x_ptr = x.const_data_ptr(); + const auto *weight_ptr = weight.const_data_ptr(); + const auto *seqlens_ptr = seqlens.const_data_ptr(); + const auto *cu_seqlens_ptr = cu_seqlens.const_data_ptr(); + auto *tmas_ptr = tmas.mutable_data_ptr(); + auto *y_ptr = y.mutable_data_ptr(); + + auto *tiles_ptr = tiles.mutable_data_ptr(); + auto *cu_tiles_ptr = cu_tiles.mutable_data_ptr(); + + group_gemm_bf16_async(y_ptr, x_ptr, weight_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, num_group, m, n, k, + 
num_seq_per_group_avg, update_tma, stream); + + return y; +} + torch::Tensor reformat_x_scale_entry(const torch::Tensor &x_scale, const torch::Tensor &seqlens, const torch::Tensor &cu_seqlens, std::optional out_x_scale, @@ -197,6 +256,12 @@ TORCH_LIBRARY_FRAGMENT(hpc, m) { m.impl("group_gemm_pertensor_fp8", torch::kCUDA, &hpc::group_gemm::group_gemm_pertensor_fp8_entry); + m.def( + "group_gemm_bf16(Tensor x, Tensor weight, Tensor seqlens, Tensor cu_seqlens, " + "int num_seq_per_group_avg, Tensor? output, Tensor? tma_desc) -> (Tensor)"); + m.impl("group_gemm_bf16", torch::kCUDA, + &hpc::group_gemm::group_gemm_bf16_entry); + m.def( "group_gemm_blockwise_fp8(Tensor x, Tensor weight, Tensor seqlens, Tensor cu_seqlens, Tensor " "xscale, Tensor wscale," diff --git a/src/group_gemm/group_gemm.h b/src/group_gemm/group_gemm.h index 6e5752a..e74447d 100644 --- a/src/group_gemm/group_gemm.h +++ b/src/group_gemm/group_gemm.h @@ -28,6 +28,12 @@ void reformat_x_scale_async(void *output_ptr, const void *xscale_ptr, const void const void *cu_seqlens_ptr, int num_group, int m, int n, int tilem, cudaStream_t stream); +void group_gemm_bf16_async(void *y_ptr, const void *x_ptr, const void *w_ptr, + const void *seqlens_ptr, const void *cu_seqlens_ptr, + void *tmas_ptr, void *tiles_ptr, + void *cu_tiles_ptr, int num_group, int m, int n, int k, + int num_seq_per_group_avg, bool update_tma, + cudaStream_t stream); } // namespace group_gemm } // namespace hpc diff --git a/src/group_gemm/group_gemm_bf16.cu b/src/group_gemm/group_gemm_bf16.cu new file mode 100644 index 0000000..6cae49b --- /dev/null +++ b/src/group_gemm/group_gemm_bf16.cu @@ -0,0 +1,138 @@ +// Copyright (C) 2026 Tencent. 
+ +#include +#include + +#include +#include +#include "cute/tensor.hpp" +#include "src/group_gemm/config.h" +#include "src/group_gemm/group_gemm.h" +#include "src/group_gemm/kernels.cuh" + +namespace hpc { +namespace group_gemm { + +template +void launch_group_gemm_bf16(void *y_ptr, const void *x_ptr, const void *w_ptr, + const void *seqlens_ptr, const void *cu_seqlens_ptr, + void *tmas_ptr, void *tiles_ptr, void *cu_tiles_ptr, int num_group, + int m, int n, int k, bool update_tma, cudaStream_t stream) { + using namespace cute; // NOLINT + + using Tin = cute::bfloat16_t; + using Tout = cute::bfloat16_t; + + auto X = make_tensor(make_gmem_ptr(reinterpret_cast(x_ptr)), make_shape(m, k), + make_stride(k, Int<1>{})); + auto W = make_tensor(make_gmem_ptr(reinterpret_cast(w_ptr)), + make_shape(n, k, num_group), make_stride(k, Int<1>{}, n * k)); + auto Y = make_tensor(make_gmem_ptr(reinterpret_cast(y_ptr)), make_shape(n, m), + make_stride(Int<1>{}, n)); + + using Config = GroupGEMMBF16Config; + Config config; + auto [tma_x, tma_w, tma_y] = config.get_tma(X, W, Y); + + auto *tma_xy = static_cast(tmas_ptr); + + // 0. update tma + if (update_tma) { + vec_t td_xy{ + *tma_x.get_tma_descriptor(), + *tma_y.get_tma_descriptor(), + }; + + constexpr int kGroupPerThread = 8; + constexpr int kThreadPerBlock = 32; + kernels::update_grouped_tma + <<>>( + td_xy, tma_xy, (const Tin *)x_ptr, (const Tout *)y_ptr, (const int *)seqlens_ptr, + (const int *)cu_seqlens_ptr, (int *)tiles_ptr, (int *)cu_tiles_ptr, num_group, m, n, k); + } + + // 1. 
group gemm + { + int num_tile_n = (n + kTileN - 1) / kTileN; + cutlass::FastDivmod flat_divider(num_tile_n); + + // dim3 block(size(Config::TiledMma{}) + 128); + dim3 block(384); + dim3 grid(get_sm_count()); + + int shm_seq = sizeof(int) * (num_group + 1); + int shm_size = config.get_shm_size() + shm_seq; + + if (k <= 1024 || n <= 1024) { + constexpr bool IsLoopH = true; + auto kernel = + kernels::group_gemm_bf16_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); + + kernel<<>>(tma_w, tma_xy, (int *)seqlens_ptr, + (int *)tiles_ptr, (int *)cu_tiles_ptr, num_group, m, + n, k, flat_divider); + } else { + constexpr bool IsLoopH = false; + auto kernel = + kernels::group_gemm_bf16_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); + + kernel<<>>(tma_w, tma_xy, (int *)seqlens_ptr, + (int *)tiles_ptr, (int *)cu_tiles_ptr, num_group, m, + n, k, flat_divider); + } + } +} + +void group_gemm_bf16_async(void *y_ptr, const void *x_ptr, const void *w_ptr, + const void *seqlens_ptr, const void *cu_seqlens_ptr, + void *tmas_ptr, void *tiles_ptr, + void *cu_tiles_ptr, int num_group, int m, int n, int k, + int num_seq_per_group_avg, bool update_tma, + cudaStream_t stream) { + constexpr int kTileN = 128; + constexpr int kTileK = 64; + constexpr int kWarpgroupM = 2; + constexpr int kWarpgroupN = 1; + constexpr int kSwizzleX = 128; + constexpr int kSwizzleW = 128; + constexpr int kSwizzleY = 64; + if (num_seq_per_group_avg <= 16) { + constexpr int kTileM = 16; + constexpr int kStage = 8; + launch_group_gemm_bf16(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, + num_group, m, n, k, update_tma, stream); + } else if (num_seq_per_group_avg <= 32) { + constexpr int kTileM = 32; + constexpr int kStage = 8; + launch_group_gemm_bf16(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, + num_group, m, n, k, update_tma, stream); + } else 
if (num_seq_per_group_avg <= 48) { + constexpr int kTileM = 48; + constexpr int kStage = 8; + launch_group_gemm_bf16(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, + num_group, m, n, k, update_tma, stream); + } else { + constexpr int kTileM = 64; + constexpr int kStage = 8; + launch_group_gemm_bf16(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, + num_group, m, n, k, update_tma, stream); + } +} + +} // namespace group_gemm +} // namespace hpc diff --git a/src/group_gemm/kernels.cuh b/src/group_gemm/kernels.cuh index aa7f06a..6b054f1 100644 --- a/src/group_gemm/kernels.cuh +++ b/src/group_gemm/kernels.cuh @@ -752,6 +752,287 @@ __global__ void __launch_bounds__(384, 1) } } +template +__global__ void __launch_bounds__(384, 1) + group_gemm_bf16_kernel(const __grid_constant__ TmaB tma_b, cute::TmaDescriptor *td_xy, + int *seqlens_ptr, int *tiles_ptr, + int *cu_tiles_ptr, int num_group, int m, int n, int k, + cutlass::FastDivmod flat_divider) { + using namespace cute; // NOLINT + + using Tin = typename Config::Tin; + using Tout = typename Config::Tout; + using TiledMma = typename Config::TiledMma; + using SLayoutA = typename Config::SLayoutX; + using SLayoutB = typename Config::SLayoutW; + using SLayoutCT = typename Config::SLayoutY; + + constexpr int kTileM = Config::kTileM; + constexpr int kTileN = Config::kTileN; + constexpr int kTileK = Config::kTileK; + constexpr int kStage = Config::kStage; + + int idx = threadIdx.x; + + int iwarp = __shfl_sync(0xFFFFFFFF, idx / 32, 0); + int elected = cute::elect_one_sync(); + bool is_leader_in_block = (iwarp == 0) && elected; + bool is_leader_in_warpgroup = ((iwarp % 4) == 0) && elected; + + __shared__ uint64_t writable[kStage]; + __shared__ uint64_t readable[kStage]; + + extern __shared__ uint8_t shm_data[] alignas(128); + auto *shm_a = reinterpret_cast(shm_data); + auto *shm_b = shm_a + cosize(SLayoutA{}); + auto *shm_c = reinterpret_cast(shm_b + 
cosize(SLayoutB{})); + int *shm_tiles = reinterpret_cast(shm_c + cosize(SLayoutCT{})); + + TmaA tma_a; + TmaD tma_d; + + int num_total_warps = blockDim.x / 32; + for (int i = iwarp; i < num_group * 2; i += num_total_warps) { + tma_descriptor_fence_acquire(td_xy + i); + } + + auto sA = make_tensor(make_smem_ptr(shm_a), SLayoutA{}); + auto sB = make_tensor(make_smem_ptr(shm_b), SLayoutB{}); + + auto gA = tma_a.get_tma_tensor(make_shape(m, k)); + auto gB = tma_b.get_tma_tensor(make_shape(n, k, num_group)); + auto gC = + make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_shape(Int{}, Int{}), make_stride(Int{}, Int<1>{})); + + auto btma_a = tma_a.get_slice(0); + auto btma_b = tma_b.get_slice(0); + + auto tAg = btma_a.partition_S(gA); // (TMA, TMA_M, TMA_K) + auto tAs = btma_a.partition_D(sA); // (TMA, _1, _1, kStage) + + auto tBg = btma_b.partition_S(gB); // (TMA, TMA_N, TMA_K, num_group) + auto tBs = btma_b.partition_D(sB); // (TMA, _1, _1, kStage) + + int num_tile_n = size<1>(tBg); + + if (is_leader_in_block) { +#pragma unroll + for (int i = 0; i < kStage; ++i) { + initialize_barrier(readable[i], 1); + initialize_barrier(writable[i], size(TiledMma{}) / 128); + } + } + + // we can also use the following code to initialize the barrier + /* + if (idx < kStage) { + readable[idx] = 0x7ffff800001ffffe; // initialize_barrier(1); + writable[idx] = 0x7ffff000001ffffc; // initialize_barrier(2); + } + */ + + int total_m = cu_tiles_ptr[num_group]; + if (total_m <= 0) { + return; + } + + if constexpr (IsLoopH) { + for (int i = idx; i < num_group; i += blockDim.x) { + shm_tiles[i] = tiles_ptr[i]; + } + } else { + for (int i = idx; i < (num_group + 1); i += blockDim.x) { + shm_tiles[i] = cu_tiles_ptr[i]; + } + } + + // sync so that no thread runs ahead and uses (waits on) readable before it is initialized + __syncthreads(); + + constexpr int kNumThreads = size(TiledMma{}); + // load warpgroup + if (idx >= kNumThreads) { + cutlass::arch::warpgroup_reg_dealloc<24>(); + idx -= kNumThreads; +
constexpr int kTransactionBytes = sizeof(Tin) * (kTileM + kTileN) * kTileK; + // sizeof(Tin) * cosize(SLayoutA{}(_, _, 0)) + sizeof(Tin) * cosize(SLayoutB{}(_, _, 0)); + + int iwarp = __shfl_sync(0xFFFFFFFF, idx / 32, 0); + int is_leader_in_load = ((iwarp == 0) && elected); + + if (is_leader_in_load) { + int phase = 1; // start with ok + int ismem_write = __shfl_sync(0xFFFFFFFF, 0, 0); + int iblock = blockIdx.x; + int ntile_k = size<2>(tAg); + + int igroup = 0; + int sum_tile_m = 0; + int itile_m, itile_n; + while (true) { + if constexpr (IsLoopH) { + get_next_tile_horizon(shm_tiles, iblock, num_group, igroup, itile_m, itile_n, sum_tile_m, + flat_divider); + if (igroup < 0) { + break; + } + } else { + get_next_tile_vert(shm_tiles, iblock, num_group, igroup, itile_m, itile_n, total_m); + if (itile_n >= num_tile_n) { + break; + } + } + + iblock += gridDim.x; + + auto *td_x = td_xy + igroup * 2; + +#pragma unroll 1 + for (int itile_k = 0; itile_k < ntile_k; ++itile_k) { + // load a, b + wait_barrier(writable[ismem_write], phase); + + cute::copy(tma_a.with(td_x, readable[ismem_write]), tAg(_, itile_m, itile_k), + tAs(_, 0, 0, ismem_write)); + + cute::copy(tma_b.with(readable[ismem_write]), tBg(_, itile_n, itile_k, igroup), + tBs(_, 0, 0, ismem_write)); + + set_barrier_transaction_bytes(readable[ismem_write], kTransactionBytes); + + ++ismem_write; + if (ismem_write == kStage) { + ismem_write = 0; + phase ^= 1; + } + } // for itile_k + } // while + } // if is_leader_in_load + + } else { + // math warpgroup + cutlass::arch::warpgroup_reg_alloc<168>(); + + int idx_in_warpgroup = idx % 128; + int iwarpgroup = idx / 128; + int iwarp_in_warpgroup = idx_in_warpgroup / 32; + int elected_idx_in_warpgroup = ((iwarp_in_warpgroup == 0) && elected); + + TiledMma tiled_mma; + + auto thr_mma = tiled_mma.get_slice(idx); + auto tBs4r = thr_mma.partition_A(sB); + auto tAs4r = thr_mma.partition_B(sA); + + auto tBr = thr_mma.make_fragment_A(tBs4r); // (MMA, MMA_N, MMA_K, kStage) + auto tAr =
thr_mma.make_fragment_B(tAs4r); // (MMA, MMA_M, MMA_K, kStage) + + auto tCr = thr_mma.partition_fragment_C(gC); + + int ismem_read = 0; + int phase = 0; + + int iblock = blockIdx.x; + int igroup = 0; + int sum_tile_m = 0; + int itile_m, itile_n; + while (true) { + if constexpr (IsLoopH) { + get_next_tile_horizon(shm_tiles, iblock, num_group, igroup, itile_m, itile_n, sum_tile_m, + flat_divider); + if (igroup < 0) { + break; + } + } else { + get_next_tile_vert(shm_tiles, iblock, num_group, igroup, itile_m, itile_n, total_m); + if (itile_n >= num_tile_n) { + break; + } + } + + iblock += gridDim.x; + + auto tDr = make_tensor_like(tCr); + clear(tDr); + + int ntile_k = size<2>(tAg); +#pragma unroll 1 + for (int itile_k = 0; itile_k < ntile_k; ++itile_k) { + wait_barrier(readable[ismem_read], phase); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // mma + warpgroup_fence_operand(tCr); + warpgroup_arrive(); +#pragma unroll + for (int ik = 0; ik < size<2>(tAr); ++ik) { + cute::gemm(tiled_mma, tBr(_, _, ik, ismem_read), tAr(_, _, ik, ismem_read), tCr(_, _, _)); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + warpgroup_wait<0>(); + warpgroup_fence_operand(tCr); + + if (elected_idx_in_warpgroup) { + arrive_barrier(writable[ismem_read]); + } + +#pragma unroll + for (int i = 0; i < size(tCr); ++i) { + tDr(i) = tCr(i) + tDr(i); + } + + ++ismem_read; + if (ismem_read == kStage) { + phase ^= 1; + ismem_read = 0; + } + } + + // float32 -> bfloat16 + auto tCrh = make_tensor_like(tCr); + +#pragma unroll + for (int i = 0; i < size(tCr); ++i) { + tCrh(i) = (Tout)(tDr(i)); + } + + // Epilogue + auto sCT = + make_tensor(make_smem_ptr(reinterpret_cast(shm_c)), SLayoutCT{}); // (M, N) + using R2SCopyAtomC = Copy_Atom; + // using R2SCopyAtomC = Copy_Atom; + auto tiled_copy_c = make_tiled_copy_C(R2SCopyAtomC{}, tiled_mma); + auto thr_copy_c = tiled_copy_c.get_slice(idx); + + auto tCr4s = thr_copy_c.retile_S(tCrh); + auto tCs4r = 
thr_copy_c.partition_D(sCT); + + tma_store_wait<0>(); + syncwarpgroup(iwarpgroup); + + cute::copy(tiled_copy_c, tCr4s, tCs4r); + syncwarpgroup(iwarpgroup); + cute::tma_store_fence(); + + if (is_leader_in_warpgroup) { + auto gD = tma_d.get_tma_tensor(make_shape(n, m)); + auto btma_d = tma_d.get_slice(0); + + auto tDs = btma_d.partition_S(sCT); // (TMA, _2, _1) + auto tDg = btma_d.partition_D(gD); // (TMA, TMA_M, TMA_N) + + auto *td_y = td_xy + igroup * 2 + 1; + cute::copy(tma_d.with(td_y), tDs(_, iwarpgroup, Int<0>{}), + tDg(_, itile_n * 2 + iwarpgroup, itile_m)); + tma_store_arrive(); + } + } + } +} + } // namespace kernels } // namespace group_gemm diff --git a/tests/test_group_gemm_bf16.py b/tests/test_group_gemm_bf16.py new file mode 100644 index 0000000..34a5b0d --- /dev/null +++ b/tests/test_group_gemm_bf16.py @@ -0,0 +1,65 @@ +import os +import sys +from pathlib import Path + +sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) + +import torch +import pytest +import hpc +from utils import allclose + +torch.manual_seed(41) +torch.cuda.manual_seed(41) + + +def naive_group_gemm_bf16(x, w, seqlens, cu_seqlens): + m, k = x.shape + num_group, n, _ = w.shape + + y = torch.zeros((m, n), dtype=torch.bfloat16, device=x.device) + + for i in range(num_group): + start_idx = int(cu_seqlens[i].item()) + end_idx = int(cu_seqlens[i + 1].item()) + + if seqlens[i].item() == 0: + continue + + x_group = x[start_idx:end_idx] # [M_i, K] + w_group = w[i] # [N, K] + + # y_group = x_group @ w_group.T + y_group = torch.matmul(x_group, w_group.t()) + + y[start_idx:end_idx] = y_group + + return y + + +@pytest.mark.parametrize("num_group", [8]) +@pytest.mark.parametrize("actual_m", [8, 16, 32, 64, 128, 256, 512]) +@pytest.mark.parametrize("m", [512]) +@pytest.mark.parametrize("n", [4096]) +@pytest.mark.parametrize("k", [7168]) +def test_group_gemm_bf16(num_group, actual_m, m, n, k): + dtype = torch.bfloat16 + + seqlens = torch.full((num_group,), 
actual_m, dtype=torch.int32, device="cuda") + total_seq = torch.sum(seqlens) + mean_seq = int(total_seq / num_group) + + x = torch.randn((total_seq, k), dtype=dtype, device="cuda") + w = torch.randn((num_group, n, k), dtype=dtype, device="cuda") + + cu_seqlens = torch.zeros(num_group + 1, dtype=torch.int32, device="cuda") + cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) + gt = naive_group_gemm_bf16(x, w, seqlens, cu_seqlens) + + my = hpc.group_gemm_bf16(x, w, seqlens, cu_seqlens, num_seq_per_group_avg=mean_seq) + + assert allclose(gt.to(torch.float32), my.to(torch.float32), rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + test_group_gemm_bf16(8, 128, 512, 4096, 7168)