From 8b6ea40f582800763144005811613f1eb62ca557 Mon Sep 17 00:00:00 2001 From: ZelinMa557 <3388706467@qq.com> Date: Mon, 23 Mar 2026 19:39:16 +0800 Subject: [PATCH 1/7] feat: add bf16 group gemm kernel Signed-off-by: ZelinMa557 <3388706467@qq.com> --- hpc/group_gemm.py | 51 ++++++ src/group_gemm/config.h | 69 ++++++++ src/group_gemm/entry.cc | 67 ++++++- src/group_gemm/group_gemm.h | 6 + src/group_gemm/group_gemm_bf16.cu | 138 +++++++++++++++ src/group_gemm/kernels.cuh | 281 ++++++++++++++++++++++++++++++ tests/test_group_gemm_bf16.py | 65 +++++++ 7 files changed, 676 insertions(+), 1 deletion(-) create mode 100644 src/group_gemm/group_gemm_bf16.cu create mode 100644 tests/test_group_gemm_bf16.py diff --git a/hpc/group_gemm.py b/hpc/group_gemm.py index 0631c2e..a292673 100644 --- a/hpc/group_gemm.py +++ b/hpc/group_gemm.py @@ -152,6 +152,52 @@ def group_gemm_blockwise_fp8( ) +def group_gemm_bf16( + x: Tensor, + weight: Tensor, + seqlens: Tensor, + cu_seqlens: Tensor, + num_seq_per_group_avg: int = 32, + output: Tensor = None, + tma_desc: Tensor = None, +) -> Tensor: + """Performs group GEMM operation with BF16 precision. + + This function executes multiple matrix multiplications in a group manner + using BF16 precision for improved performance. + + Args: + x: Input activation tensor + Shape: [total_seq, hidden_size] + Dtype: bfloat16 + weight: Weight tensor for group matrix multiplication + Shape: [num_group, output_dim, hidden_size] + Dtype: bfloat16 + seqlens: Sequence lengths for each group + Shape: [num_group] + Dtype: int32 + cu_seqlens: Cumulative sequence lengths indicating start indices in input tensor + Shape: [num_group + 1] + Dtype: int32 + + Returns: + Tensor: Output tensor after group matrix multiplication + Shape: [total_seq, output_dim] + Dtype: bfloat16 + + Raises: + RuntimeError: If the input tensors have incompatible shapes or types, + or if the CUDA kernel execution fails. 
+ + Note: + - All input tensors must be on CUDA device + + """ + return torch.ops.hpc.group_gemm_bf16( + x, weight, seqlens, cu_seqlens, num_seq_per_group_avg, output, tma_desc + ) + + @torch.library.register_fake("hpc::group_gemm_pertensor_fp8") def group_gemm_pertensor_fp8_fake( x, weight, seqlens, cu_seqlens, y_scale, num_seq_per_group_avg, output, tma_des @@ -164,3 +210,8 @@ def group_gemm_blockwise_fp8_fake( x, weight, seqlens, cu_seqlens, x_scale, w_scale, num_seq_per_group_avg, output, tma_des ): return torch.empty((x.shape[0], weight.shape[1]), dtype=torch.bfloat16) + + +@torch.library.register_fake("hpc::group_gemm_bf16") +def group_gemm_bf16_fake(x, weight, seqlens, cu_seqlens, num_seq_per_group_avg, output, tma_des): + return torch.empty((x.shape[0], weight.shape[1]), dtype=torch.bfloat16) diff --git a/src/group_gemm/config.h b/src/group_gemm/config.h index b0b4656..16462c0 100644 --- a/src/group_gemm/config.h +++ b/src/group_gemm/config.h @@ -165,6 +165,75 @@ struct GroupGEMMBlockWiseFp8Config { auto get_shm_size() { return shm_size; } }; +template +static constexpr auto mma_selector_bf16() { + constexpr auto MajorA = GMMA::Major::K; + constexpr auto MajorB = GMMA::Major::K; + if constexpr (kTileM == 8) { + return cute::SM90_64x8x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 16) { + return cute::SM90_64x16x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 32) { + return cute::SM90_64x32x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 48) { + return cute::SM90_64x48x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 64) { + return cute::SM90_64x64x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 96) { + return cute::SM90_64x96x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 128) { + return cute::SM90_64x128x16_F32BF16BF16_SS{}; + } else { + return cute::SM90_64x64x16_F32BF16BF16_SS{}; + } +} + +template +struct GroupGEMMBF16Config { + using Tin = Tin_; + using Tout = Tout_; + + static constexpr int kTileM = kTileM_; + static constexpr int kTileN = kTileN_; + static constexpr int kTileK = kTileK_; + static constexpr int kStage = kStage_; + static constexpr int kWarpgroupM = kWarpgroupM_; + static constexpr int kWarpgroupN = kWarpgroupN_; + + using SLayoutXAtom = decltype(slayout_selector()); + using SLayoutWAtom = decltype(slayout_selector()); + using SLayoutYAtom = decltype(slayout_selector()); + + using SLayoutX = decltype(tile_to_shape(SLayoutXAtom{}, + make_shape(Int{}, Int{}, Int{}))); + using SLayoutW = decltype(tile_to_shape(SLayoutWAtom{}, + make_shape(Int{}, Int{}, Int{}))); + using SLayoutY = + decltype(tile_to_shape(SLayoutYAtom{}, make_shape(Int{}, Int{}))); + using CopyBoxY = decltype(tile_to_shape(SLayoutYAtom{}, + make_shape(Int{}, Int{}))); + + template + auto get_tma(TX x, TW w, TY y) { + auto tma_x = make_tma_copy(SM90_TMA_LOAD{}, x, take<0, 2>(SLayoutX{})); + auto tma_w = make_tma_copy(SM90_TMA_LOAD{}, w, take<0, 2>(SLayoutW{})); + auto tma_y = make_tma_copy(SM90_TMA_STORE{}, y, CopyBoxY{}); + return std::make_tuple(tma_x, tma_w, tma_y); + } + + using WarpgroupLayout = + decltype(make_layout(make_shape(Int{}, Int{}, Int<1>{}))); + using TiledMma = decltype(make_tiled_mma(mma_selector_bf16(), WarpgroupLayout{})); + + static constexpr int shm_xw = (cosize(SLayoutX{}) + cosize(SLayoutW{})) * sizeof(Tin); + static constexpr int shm_y = cosize(SLayoutY{}) * sizeof(Tout); + static constexpr int shm_size = shm_xw + shm_y; + + auto get_shm_size() { return shm_size; } +}; + } // namespace group_gemm } // namespace hpc diff --git 
a/src/group_gemm/entry.cc b/src/group_gemm/entry.cc index 5c74dff..9cc4c9c 100644 --- a/src/group_gemm/entry.cc +++ b/src/group_gemm/entry.cc @@ -6,7 +6,7 @@ #include #include - +#include #include "src/group_gemm/group_gemm.h" namespace hpc { @@ -137,6 +137,65 @@ torch::Tensor group_gemm_blockwise_fp8_entry( return y; } +torch::Tensor group_gemm_bf16_entry(const torch::Tensor &x, const torch::Tensor &weight, + const torch::Tensor &seqlens, + const torch::Tensor &cu_seqlens, + const int64_t num_seq_per_group_avg, + std::optional output, + std::optional tma_desc) { + auto stream = at::cuda::getCurrentCUDAStream(x.get_device()); + TORCH_CHECK(x.device().is_cuda(), "x tensor must be cuda"); + TORCH_CHECK(weight.device().is_cuda(), "weight tensor must be cuda"); + TORCH_CHECK(seqlens.device().is_cuda(), "seqlens tensor must be cuda"); + TORCH_CHECK(cu_seqlens.device().is_cuda(), "cu_seqlens tensor must be cuda"); + TORCH_CHECK(x.is_contiguous(), "x tensor a must be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor a must be contiguous"); + TORCH_CHECK(seqlens.size(0) == weight.size(0), + "seqlens and weight must share the same num_group"); + TORCH_CHECK(x.size(1) == weight.size(2), "x and weight must share the same k"); + + int m = x.size(0); + int k = x.size(1); + int n = weight.size(1); + int num_group = seqlens.size(0); + + auto options = x.options(); + torch::Tensor y; + if (output.has_value()) { + y = output.value(); + } else { + y = torch::empty({m, n}, options.dtype(torch::kBFloat16)); + } + + torch::Tensor tmas; + bool update_tma = true; + if (tma_desc.has_value()) { + tmas = tma_desc.value(); + update_tma = false; + } else { + tmas = torch::empty({num_group * 2, 128}, options); + } + + torch::Tensor tiles = torch::empty({num_group}, options.dtype(torch::kInt32)); + torch::Tensor cu_tiles = torch::empty({num_group + 1}, options.dtype(torch::kInt32)); + + const auto *x_ptr = x.const_data_ptr(); + const auto *weight_ptr = weight.const_data_ptr(); + const auto *seqlens_ptr = seqlens.const_data_ptr(); + const auto *cu_seqlens_ptr = cu_seqlens.const_data_ptr(); + auto *tmas_ptr = tmas.mutable_data_ptr(); + auto *y_ptr = y.mutable_data_ptr(); + + auto *tiles_ptr = tiles.mutable_data_ptr(); + auto *cu_tiles_ptr = cu_tiles.mutable_data_ptr(); + + group_gemm_bf16_async(y_ptr, x_ptr, weight_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, num_group, m, n, k, + num_seq_per_group_avg, update_tma, stream); + + return y; +} + torch::Tensor reformat_x_scale_entry(const torch::Tensor &x_scale, const torch::Tensor &seqlens, const torch::Tensor &cu_seqlens, std::optional out_x_scale, @@ -197,6 +256,12 @@ TORCH_LIBRARY_FRAGMENT(hpc, m) { m.impl("group_gemm_pertensor_fp8", torch::kCUDA, &hpc::group_gemm::group_gemm_pertensor_fp8_entry); + m.def( + "group_gemm_bf16(Tensor x, Tensor weight, Tensor seqlens, Tensor cu_seqlens, " + "int num_seq_per_group_avg, Tensor? output, Tensor? 
tma_desc) -> (Tensor)"); + m.impl("group_gemm_bf16", torch::kCUDA, + &hpc::group_gemm::group_gemm_bf16_entry); + m.def( "group_gemm_blockwise_fp8(Tensor x, Tensor weight, Tensor seqlens, Tensor cu_seqlens, Tensor " "xscale, Tensor wscale," diff --git a/src/group_gemm/group_gemm.h b/src/group_gemm/group_gemm.h index 6e5752a..e74447d 100644 --- a/src/group_gemm/group_gemm.h +++ b/src/group_gemm/group_gemm.h @@ -28,6 +28,12 @@ void reformat_x_scale_async(void *output_ptr, const void *xscale_ptr, const void const void *cu_seqlens_ptr, int num_group, int m, int n, int tilem, cudaStream_t stream); +void group_gemm_bf16_async(void *y_ptr, const void *x_ptr, const void *w_ptr, + const void *seqlens_ptr, const void *cu_seqlens_ptr, + void *tmas_ptr, void *tiles_ptr, + void *cu_tiles_ptr, int num_group, int m, int n, int k, + int num_seq_per_group_avg, bool update_tma, + cudaStream_t stream); } // namespace group_gemm } // namespace hpc diff --git a/src/group_gemm/group_gemm_bf16.cu b/src/group_gemm/group_gemm_bf16.cu new file mode 100644 index 0000000..6cae49b --- /dev/null +++ b/src/group_gemm/group_gemm_bf16.cu @@ -0,0 +1,138 @@ +// Copyright (C) 2026 Tencent. + +#include +#include + +#include +#include +#include "cute/tensor.hpp" +#include "src/group_gemm/config.h" +#include "src/group_gemm/group_gemm.h" +#include "src/group_gemm/kernels.cuh" + +namespace hpc { +namespace group_gemm { + +template +void launch_group_gemm_bf16(void *y_ptr, const void *x_ptr, const void *w_ptr, + const void *seqlens_ptr, const void *cu_seqlens_ptr, + void *tmas_ptr, void *tiles_ptr, void *cu_tiles_ptr, int num_group, + int m, int n, int k, bool update_tma, cudaStream_t stream) { + using namespace cute; // NOLINT + + using Tin = cute::bfloat16_t; + using Tout = cute::bfloat16_t; + + auto X = make_tensor(make_gmem_ptr(reinterpret_cast(x_ptr)), make_shape(m, k), + make_stride(k, Int<1>{})); + auto W = make_tensor(make_gmem_ptr(reinterpret_cast(w_ptr)), + make_shape(n, k, num_group), make_stride(k, Int<1>{}, n * k)); + auto Y = make_tensor(make_gmem_ptr(reinterpret_cast(y_ptr)), make_shape(n, m), + make_stride(Int<1>{}, n)); + + using Config = GroupGEMMBF16Config; + Config config; + auto [tma_x, tma_w, tma_y] = config.get_tma(X, W, Y); + + auto *tma_xy = static_cast(tmas_ptr); + + // 0. update tma + if (update_tma) { + vec_t td_xy{ + *tma_x.get_tma_descriptor(), + *tma_y.get_tma_descriptor(), + }; + + constexpr int kGroupPerThread = 8; + constexpr int kThreadPerBlock = 32; + kernels::update_grouped_tma + <<>>( + td_xy, tma_xy, (const Tin *)x_ptr, (const Tout *)y_ptr, (const int *)seqlens_ptr, + (const int *)cu_seqlens_ptr, (int *)tiles_ptr, (int *)cu_tiles_ptr, num_group, m, n, k); + } + + // 1. 
group gemm + { + int num_tile_n = (n + kTileN - 1) / kTileN; + cutlass::FastDivmod flat_divider(num_tile_n); + + // dim3 block(size(Config::TiledMma{}) + 128); + dim3 block(384); + dim3 grid(get_sm_count()); + + int shm_seq = sizeof(int) * (num_group + 1); + int shm_size = config.get_shm_size() + shm_seq; + + if (k <= 1024 || n <= 1024) { + constexpr bool IsLoopH = true; + auto kernel = + kernels::group_gemm_bf16_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); + + kernel<<>>(tma_w, tma_xy, (int *)seqlens_ptr, + (int *)tiles_ptr, (int *)cu_tiles_ptr, num_group, m, + n, k, flat_divider); + } else { + constexpr bool IsLoopH = false; + auto kernel = + kernels::group_gemm_bf16_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); + + kernel<<>>(tma_w, tma_xy, (int *)seqlens_ptr, + (int *)tiles_ptr, (int *)cu_tiles_ptr, num_group, m, + n, k, flat_divider); + } + } +} + +void group_gemm_bf16_async(void *y_ptr, const void *x_ptr, const void *w_ptr, + const void *seqlens_ptr, const void *cu_seqlens_ptr, + void *tmas_ptr, void *tiles_ptr, + void *cu_tiles_ptr, int num_group, int m, int n, int k, + int num_seq_per_group_avg, bool update_tma, + cudaStream_t stream) { + constexpr int kTileN = 128; + constexpr int kTileK = 64; + constexpr int kWarpgroupM = 2; + constexpr int kWarpgroupN = 1; + constexpr int kSwizzleX = 128; + constexpr int kSwizzleW = 128; + constexpr int kSwizzleY = 64; + if (num_seq_per_group_avg <= 16) { + constexpr int kTileM = 16; + constexpr int kStage = 8; + launch_group_gemm_bf16(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, + num_group, m, n, k, update_tma, stream); + } else if (num_seq_per_group_avg <= 32) { + constexpr int kTileM = 32; + constexpr int kStage = 8; + launch_group_gemm_bf16(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, + num_group, m, n, k, update_tma, stream); + } else if (num_seq_per_group_avg <= 48) { + constexpr int kTileM = 48; + constexpr int kStage = 8; + launch_group_gemm_bf16(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, + num_group, m, n, k, update_tma, stream); + } else { + constexpr int kTileM = 64; + constexpr int kStage = 8; + launch_group_gemm_bf16(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, + num_group, m, n, k, update_tma, stream); + } +} + +} // namespace group_gemm +} // namespace hpc diff --git a/src/group_gemm/kernels.cuh b/src/group_gemm/kernels.cuh index aa7f06a..6b054f1 100644 --- a/src/group_gemm/kernels.cuh +++ b/src/group_gemm/kernels.cuh @@ -752,6 +752,287 @@ __global__ void __launch_bounds__(384, 1) } } +template +__global__ void __launch_bounds__(384, 1) + group_gemm_bf16_kernel(const __grid_constant__ TmaB tma_b, cute::TmaDescriptor *td_xy, + int *seqlens_ptr, int *tiles_ptr, + int *cu_tiles_ptr, int num_group, int m, int n, int k, + cutlass::FastDivmod flat_divider) { + using namespace cute; // NOLINT + + using Tin = typename Config::Tin; + using Tout = typename Config::Tout; + using TiledMma = typename Config::TiledMma; + using SLayoutA = typename Config::SLayoutX; + using SLayoutB = typename Config::SLayoutW; + using SLayoutCT = typename Config::SLayoutY; + + constexpr int kTileM = Config::kTileM; + constexpr int kTileN = Config::kTileN; + constexpr int kTileK = Config::kTileK; + constexpr int kStage = Config::kStage; + + int idx = threadIdx.x; + + int iwarp = 
__shfl_sync(0xFFFFFFFF, idx / 32, 0); + int elected = cute::elect_one_sync(); + bool is_leader_in_block = (iwarp == 0) && elected; + bool is_leader_in_warpgroup = ((iwarp % 4) == 0) && elected; + + __shared__ uint64_t writable[kStage]; + __shared__ uint64_t readable[kStage]; + + extern __shared__ uint8_t shm_data[] alignas(128); + auto *shm_a = reinterpret_cast(shm_data); + auto *shm_b = shm_a + cosize(SLayoutA{}); + auto *shm_c = reinterpret_cast(shm_b + cosize(SLayoutB{})); + int *shm_tiles = reinterpret_cast(shm_c + cosize(SLayoutCT{})); + + TmaA tma_a; + TmaD tma_d; + + int num_total_warps = blockDim.x / 32; + for (int i = iwarp; i < num_group * 2; i += num_total_warps) { + tma_descriptor_fence_acquire(td_xy + i); + } + + auto sA = make_tensor(make_smem_ptr(shm_a), SLayoutA{}); + auto sB = make_tensor(make_smem_ptr(shm_b), SLayoutB{}); + + auto gA = tma_a.get_tma_tensor(make_shape(m, k)); + auto gB = tma_b.get_tma_tensor(make_shape(n, k, num_group)); + auto gC = + make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_shape(Int{}, Int{}), make_stride(Int{}, Int<1>{})); + + auto btma_a = tma_a.get_slice(0); + auto btma_b = tma_b.get_slice(0); + + auto tAg = btma_a.partition_S(gA); // (TMA, TMA_M, TMA_K) + auto tAs = btma_a.partition_D(sA); // (TMA, _1, _1, kStage) + + auto tBg = btma_b.partition_S(gB); // (TMA, TMA_N, TMA_K, num_group) + auto tBs = btma_b.partition_D(sB); // (TMA, _1, _1, kStage) + + int num_tile_n = size<1>(tBg); + + if (is_leader_in_block) { +#pragma unroll + for (int i = 0; i < kStage; ++i) { + initialize_barrier(readable[i], 1); + initialize_barrier(writable[i], size(TiledMma{}) / 128); + } + } + + // we can also use the following code to initialize the barrier + /* + if (idx < kStage) { + readable[idx] = 0x7ffff800001ffffe; // initialize_barrier(1); + writable[idx] = 0x7ffff000001ffffc; // initialize_barrier(2); + } + */ + + int total_m = cu_tiles_ptr[num_group]; + if (total_m <= 0) { + return; + } + + if constexpr (IsLoopH) { + for (int i = idx; i < num_group; i += blockDim.x) { + shm_tiles[i] = tiles_ptr[i]; + } + } else { + for (int i = idx; i < (num_group + 1); i += blockDim.x) { + shm_tiles[i] = cu_tiles_ptr[i]; + } + } + + // sync to avoid ahead thread use(wait) readable when it is not initizlized yet + __syncthreads(); + + constexpr int kNumThreads = size(TiledMma{}); + // load warpgroup + if (idx >= kNumThreads) { + cutlass::arch::warpgroup_reg_dealloc<24>(); + idx -= kNumThreads; + constexpr int kTransactionBytes = sizeof(Tin) * (kTileM + kTileN) * kTileK; + // sizeof(Tin) * cosize(SLayoutA{}(_, _, 0)) + sizeof(Tin) * cosize(SLayoutB{}(_, _, 0)); + + int iwarp = __shfl_sync(0xFFFFFFFF, idx / 32, 0); + int is_leader_in_load = ((iwarp == 0) && elected); + + if (is_leader_in_load) { + int phase = 1; // start with ok + int ismem_write = __shfl_sync(0xFFFFFFFF, 0, 0); + int iblock = blockIdx.x; + int ntile_k = size<2>(tAg); + + int igroup = 0; + int sum_tile_m = 0; + int itile_m, itile_n; + while (true) { + if constexpr (IsLoopH) { + get_next_tile_horizon(shm_tiles, iblock, num_group, igroup, itile_m, itile_n, sum_tile_m, + flat_divider); + if (igroup < 0) { + break; + } + } else { + get_next_tile_vert(shm_tiles, iblock, num_group, igroup, itile_m, itile_n, total_m); + if (itile_n >= num_tile_n) { + break; + } + } + + iblock += gridDim.x; + + auto *td_x = td_xy + igroup * 2; + +#pragma unroll 1 + for (int itile_k = 0; itile_k < ntile_k; ++itile_k) { + // load a, b + wait_barrier(writable[ismem_write], phase); + + cute::copy(tma_a.with(td_x, 
readable[ismem_write]), tAg(_, itile_m, itile_k), + tAs(_, 0, 0, ismem_write)); + + cute::copy(tma_b.with(readable[ismem_write]), tBg(_, itile_n, itile_k, igroup), + tBs(_, 0, 0, ismem_write)); + + set_barrier_transaction_bytes(readable[ismem_write], kTransactionBytes); + + ++ismem_write; + if (ismem_write == kStage) { + ismem_write = 0; + phase ^= 1; + } + } // ntile_todo + } // while + } // if idx == 0 + + } else { + // math warpgroup + cutlass::arch::warpgroup_reg_alloc<168>(); + + int idx_in_warpgroup = idx % 128; + int iwarpgroup = idx / 128; + int iwarp_in_warpgroup = idx_in_warpgroup / 32; + int elected_idx_in_warpgroup = ((iwarp_in_warpgroup == 0) && elected); + + TiledMma tiled_mma; + + auto thr_mma = tiled_mma.get_slice(idx); + auto tBs4r = thr_mma.partition_A(sB); + auto tAs4r = thr_mma.partition_B(sA); + + auto tBr = thr_mma.make_fragment_A(tBs4r); // (MMA, MMA_N, MMA_K, kStage) + auto tAr = thr_mma.make_fragment_B(tAs4r); // (MMA, MMA_M, MMA_K, kStage) + + auto tCr = thr_mma.partition_fragment_C(gC); + + int ismem_read = 0; + int phase = 0; + + int iblock = blockIdx.x; + int igroup = 0; + int sum_tile_m = 0; + int itile_m, itile_n; + while (true) { + if constexpr (IsLoopH) { + get_next_tile_horizon(shm_tiles, iblock, num_group, igroup, itile_m, itile_n, sum_tile_m, + flat_divider); + if (igroup < 0) { + break; + } + } else { + get_next_tile_vert(shm_tiles, iblock, num_group, igroup, itile_m, itile_n, total_m); + if (itile_n >= num_tile_n) { + break; + } + } + + iblock += gridDim.x; + + auto tDr = make_tensor_like(tCr); + clear(tDr); + + int ntile_k = size<2>(tAg); +#pragma unroll 1 + for (int itile_k = 0; itile_k < ntile_k; ++itile_k) { + wait_barrier(readable[ismem_read], phase); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // mma + warpgroup_fence_operand(tCr); + warpgroup_arrive(); +#pragma unroll + for (int ik = 0; ik < size<2>(tAr); ++ik) { + cute::gemm(tiled_mma, tBr(_, _, ik, ismem_read), tAr(_, _, ik, ismem_read), tCr(_, _, _)); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + warpgroup_wait<0>(); + warpgroup_fence_operand(tCr); + + if (elected_idx_in_warpgroup) { + arrive_barrier(writable[ismem_read]); + } + +#pragma unroll + for (int i = 0; i < size(tCr); ++i) { + tDr(i) = tCr(i) + tDr(i); + } + + ++ismem_read; + if (ismem_read == kStage) { + phase ^= 1; + ismem_read = 0; + } + } + + // float32 -> bfloat16 + auto tCrh = make_tensor_like(tCr); + +#pragma unroll + for (int i = 0; i < size(tCr); ++i) { + tCrh(i) = (Tout)(tDr(i)); + } + + // Epilogue + auto sCT = + make_tensor(make_smem_ptr(reinterpret_cast(shm_c)), SLayoutCT{}); // (M, N) + using R2SCopyAtomC = Copy_Atom; + // using R2SCopyAtomC = Copy_Atom; + auto tiled_copy_c = make_tiled_copy_C(R2SCopyAtomC{}, tiled_mma); + auto thr_copy_c = tiled_copy_c.get_slice(idx); + + auto tCr4s = thr_copy_c.retile_S(tCrh); + auto tCs4r = thr_copy_c.partition_D(sCT); + + tma_store_wait<0>(); + syncwarpgroup(iwarpgroup); + + cute::copy(tiled_copy_c, tCr4s, tCs4r); + syncwarpgroup(iwarpgroup); + cute::tma_store_fence(); + + if (is_leader_in_warpgroup) { + auto gD = tma_d.get_tma_tensor(make_shape(n, m)); + auto btma_d = tma_d.get_slice(0); + + auto tDs = btma_d.partition_S(sCT); // (TMA, _2, _1) + auto tDg = btma_d.partition_D(gD); // (TMA, TMA_M, TMA_N) + + auto *td_y = td_xy + igroup * 2 + 1; + cute::copy(tma_d.with(td_y), tDs(_, iwarpgroup, Int<0>{}), + tDg(_, itile_n * 2 + iwarpgroup, itile_m)); + tma_store_arrive(); + } + } + } +} + } // namespace kernels } // namespace group_gemm 
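The launcher in group_gemm_bf16_async picks the MMA tile height from num_seq_per_group_avg (<=16 selects kTileM=16, <=32 selects 32, <=48 selects 48, anything larger selects 64), so the hint only affects performance, never correctness. Below is a minimal usage sketch of the new Python op, assuming the built hpc extension is importable and shapes follow the docstring in hpc/group_gemm.py:

import torch
import hpc

num_group, seq_per_group, n, k = 8, 32, 4096, 7168
x = torch.randn(num_group * seq_per_group, k, dtype=torch.bfloat16, device="cuda")
w = torch.randn(num_group, n, k, dtype=torch.bfloat16, device="cuda")
seqlens = torch.full((num_group,), seq_per_group, dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(num_group + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

# Rows cu_seqlens[i]:cu_seqlens[i+1] of x are multiplied by w[i].T; the average
# group length (32 here) only steers the kTileM dispatch in the launcher.
y = hpc.group_gemm_bf16(x, w, seqlens, cu_seqlens, num_seq_per_group_avg=seq_per_group)

The test file below checks the same call against a per-group torch.matmul reference.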
diff --git a/tests/test_group_gemm_bf16.py b/tests/test_group_gemm_bf16.py new file mode 100644 index 0000000..34a5b0d --- /dev/null +++ b/tests/test_group_gemm_bf16.py @@ -0,0 +1,65 @@ +import os +import sys +from pathlib import Path + +sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) + +import torch +import pytest +import hpc +from utils import allclose + +torch.manual_seed(41) +torch.cuda.manual_seed(41) + + +def naive_group_gemm_bf16(x, w, seqlens, cu_seqlens): + m, k = x.shape + num_group, n, _ = w.shape + + y = torch.zeros((m, n), dtype=torch.bfloat16, device=x.device) + + for i in range(num_group): + start_idx = int(cu_seqlens[i].item()) + end_idx = int(cu_seqlens[i + 1].item()) + + if seqlens[i].item() == 0: + continue + + x_group = x[start_idx:end_idx] # [M_i, K] + w_group = w[i] # [N, K] + + # y_group = x_group @ w_group.T + y_group = torch.matmul(x_group, w_group.t()) + + y[start_idx:end_idx] = y_group + + return y + + +@pytest.mark.parametrize("num_group", [8]) +@pytest.mark.parametrize("actual_m", [8, 16, 32, 64, 128, 256, 512]) +@pytest.mark.parametrize("m", [512]) +@pytest.mark.parametrize("n", [4096]) +@pytest.mark.parametrize("k", [7168]) +def test_group_gemm_bf16(num_group, actual_m, m, n, k): + dtype = torch.bfloat16 + + seqlens = torch.full((num_group,), actual_m, dtype=torch.int32, device="cuda") + total_seq = torch.sum(seqlens) + mean_seq = int(total_seq / num_group) + + x = torch.randn((total_seq, k), dtype=dtype, device="cuda") + w = torch.randn((num_group, n, k), dtype=dtype, device="cuda") + + cu_seqlens = torch.zeros(num_group + 1, dtype=torch.int32, device="cuda") + cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) + gt = naive_group_gemm_bf16(x, w, seqlens, cu_seqlens) + + my = hpc.group_gemm_bf16(x, w, seqlens, cu_seqlens, num_seq_per_group_avg=mean_seq) + + assert allclose(gt.to(torch.float32), my.to(torch.float32), rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + test_group_gemm_bf16(8, 128, 512, 4096, 7168) From c1dec75b9946df1929cfce89b7cf24394d98a8d6 Mon Sep 17 00:00:00 2001 From: ZelinMa557 <3388706467@qq.com> Date: Wed, 1 Apr 2026 11:39:35 +0800 Subject: [PATCH 2/7] implement bf16 moe op Signed-off-by: ZelinMa557 <3388706467@qq.com> --- hpc/fuse_moe.py | 82 ++++++++++++++++++++++++ src/activation/activation.cu | 52 +++++++++++++++ src/activation/activation.h | 4 ++ src/fuse_moe/count_and_gather.cu | 60 +++++++++++++++--- src/fuse_moe/entry.cc | 105 +++++++++++++++++++++++++++++++ src/fuse_moe/fuse_moe.cu | 43 +++++++++++++ src/fuse_moe/fuse_moe.h | 17 +++++ 7 files changed, 355 insertions(+), 8 deletions(-) diff --git a/hpc/fuse_moe.py b/hpc/fuse_moe.py index e698d21..18fc457 100644 --- a/hpc/fuse_moe.py +++ b/hpc/fuse_moe.py @@ -291,6 +291,74 @@ def fuse_moe_blockwise_fp8( ) +def fuse_moe_bf16( + x: Tensor, + gate_up_weight: Tensor, + down_weight: Tensor, + topk_ids: Tensor, + topk_scale: Tensor, + rank_ep: int, + num_expert_total: int, + shared_output: Tensor = None, +) -> Tensor: + """Performs Mixture of Experts (MoE) forward operation with BF16 precision. + + This function executes the MoE computation with all matrix multiplications + performed in BF16 precision. The gate and up projections are fused into + a single matrix multiplication. 
+ + Args: + x: Input activation tensor + Shape: [num_seq, hidden_size] + Dtype: bfloat16 + gate_up_weight: Combined weight tensor for gate and up projections + Shape: [num_expert_local, intermediate_size * 2, hidden_size] + Dtype: bfloat16 + down_weight: Weight tensor for down projection + Shape: [num_expert_local, hidden_size, intermediate_size] + Dtype: bfloat16 + topk_ids: Token indices assigned to each expert + Shape: [num_seq, num_topk] + Dtype: int32 + topk_scale: Weighting factors for each token-expert assignment + Shape: [num_seq, num_topk] + Dtype: float32 + rank_ep: Expert parallel rank (for distributed training) + Dtype: int32 + num_expert_total: the total number of expert + Dtype: int32 + shared_output: output for shared experts, default is None + Shape: [num_seq, hidden_size] + Dtype: bfloat16 + + Returns: + torch.Tensor: Output tensor after MoE computation + Shape: [num_seq, hidden_size] + Dtype: bfloat16 + + Raises: + RuntimeError: If the input tensors have incompatible shapes or types, + or if CUDA kernel execution fails. + + Note: + - All input tensors must be on CUDA device + - The gate and up projections are combined into a single matrix multiplication + - BF16 precision is used for all matrix operations + - Activation function used is SiLU (Swish) + - Token routing is determined by topk_ids and weighted by topk_scale + """ + return torch.ops.hpc.fuse_moe_bf16( + x, + gate_up_weight, + down_weight, + topk_ids, + topk_scale, + shared_output, + rank_ep, + num_expert_total, + ) + + @torch.library.register_fake("hpc::count_and_gather") def count_and_gather_fake( x, topk_ids, num_expert, rank_ep, intermediate_size, num_seq_per_group_avg @@ -346,3 +414,17 @@ def fuse_moe_blockwise_fp8_fake( use_bf16_mul: bool = True, ): return torch.empty((x.shape[0], x.shape[1]), dtype=torch.bfloat16) + + +@torch.library.register_fake("hpc::fuse_moe_bf16") +def fuse_moe_bf16_fake( + x: Tensor, + gate_up_weight: Tensor, + down_weight: Tensor, + topk_ids: Tensor, + topk_scale: Tensor, + shared_output: Tensor, + rank_ep: int, + num_expert_total: int, +): + return torch.empty((x.shape[0], x.shape[1]), dtype=torch.bfloat16) diff --git a/src/activation/activation.cu b/src/activation/activation.cu index b1000e4..fb5b141 100644 --- a/src/activation/activation.cu +++ b/src/activation/activation.cu @@ -64,6 +64,46 @@ __global__ void act_mul_and_quant_kernel(__nv_fp8_e4m3 *out_ptr, const __nv_bflo } } +__global__ void act_mul_bf16_kernel(__nv_bfloat16 *out_ptr, const __nv_bfloat16 *gate_up_ptr, + const int *valid_row_range, const int num_row, + const int num_col, cutlass::FastDivmod block1D22D) { + int iblockx; + int iblocky; + + block1D22D(iblocky, iblockx, blockIdx.x); + int it = threadIdx.x + iblockx * blockDim.x; + + int irow = iblocky; + int my_valid_row_end_exclusive = valid_row_range ? 
valid_row_range[0] : num_row; + if (irow >= my_valid_row_end_exclusive) { + return; + } + + using T = __nv_bfloat162; + + const auto *gate_row_ptr = gate_up_ptr + irow * num_col * 2; + const auto *up_row_ptr = gate_row_ptr + num_col; + auto *out_row_ptr = out_ptr + irow * num_col; + + int icol = it * 8; + + if (icol < num_col) { + auto gate = to(load(gate_row_ptr + icol)); + auto up = to(load(up_row_ptr + icol)); + + vec_t<__nv_bfloat16, 8> out; +#pragma unroll + for (int i = 0; i < size(out); ++i) { + auto g = gate[i]; + auto u = up[i]; + auto result = silu(g) * u; + out[i] = __float2bfloat16_rn(result); + } + + store(out_row_ptr + icol, out); + } +} + // input : gate + up __global__ void masked_act_mul_and_quant_kernel( __nv_fp8_e4m3 *output_ptr, const __nv_bfloat16 *input_ptr, const float *scale_ptr, @@ -391,5 +431,17 @@ void masked_act_mul_and_blockwise_quant_async(__nv_fp8_e4m3 *output_ptr, float * Row2EandT, num_block_row); } +void act_mul_bf16_async(__nv_bfloat16 *y_ptr, const __nv_bfloat16 *x_ptr, + const int *valid_row_range, const int num_row, const int num_col, + cudaStream_t stream) { + dim3 block(256); + int num_block_per_row = (num_col / 8 + block.x - 1) / block.x; + cutlass::FastDivmod block1D22D(num_block_per_row); + dim3 grid(num_row * num_block_per_row); + + kernels::act_mul_bf16_kernel<<>>( + y_ptr, x_ptr, valid_row_range, num_row, num_col, block1D22D); +} + } // namespace activation } // namespace hpc diff --git a/src/activation/activation.h b/src/activation/activation.h index 784dc27..bd72ceb 100644 --- a/src/activation/activation.h +++ b/src/activation/activation.h @@ -38,6 +38,10 @@ void masked_act_mul_and_blockwise_quant_async(__nv_fp8_e4m3 *output_ptr, float * int num_intermediate_size, int num_tokens_per_expert, cudaStream_t stream); +void act_mul_bf16_async(__nv_bfloat16 *y_ptr, const __nv_bfloat16 *x_ptr, + const int *valid_row_range, const int num_row, const int num_col, + cudaStream_t stream); + } // namespace activation } // namespace hpc diff --git a/src/fuse_moe/count_and_gather.cu b/src/fuse_moe/count_and_gather.cu index 0c4781c..9aba97e 100644 --- a/src/fuse_moe/count_and_gather.cu +++ b/src/fuse_moe/count_and_gather.cu @@ -262,7 +262,7 @@ __global__ void gather_kernel(const vec_t td_xy, } // namespace kernels -template +template void launch_count_and_gather(void *gate_up_input_ptr, void *gate_up_output_ptr, void *down_input_ptr, void *down_output_ptr, const void *x_ptr, const void *topk_ids_ptr, void *topk_pos_ptr, void *seqlens_ptr, @@ -272,9 +272,6 @@ void launch_count_and_gather(void *gate_up_input_ptr, void *gate_up_output_ptr, int num_seq_per_group_avg, cudaStream_t stream) { using namespace cute; // NOLINT - using Tin = cute::float_e4m3_t; - using Tout = cute::bfloat16_t; - int m = num_seq; int n = intermediate_size; int k = hidden_size; @@ -378,11 +375,58 @@ void count_and_gather_async(void *gate_up_input_ptr, void *gate_up_output_ptr, v cudaStream_t stream) { constexpr int kTileN = 128; constexpr int kTileK = 128; + using Tin = cute::float_e4m3_t; + using Tout = cute::bfloat16_t; + if (num_seq_per_group_avg <= 16) { + constexpr int kTileM = 16; + constexpr int kStage = 8; + launch_count_and_gather( + gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, + topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, + cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, + num_seq_per_group_avg, stream); + } else if (num_seq_per_group_avg <= 32) { + 
constexpr int kTileM = 32; + constexpr int kStage = 8; + launch_count_and_gather( + gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, + topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, + cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, + num_seq_per_group_avg, stream); + } else if (num_seq_per_group_avg <= 48) { + constexpr int kTileM = 48; + constexpr int kStage = 8; + launch_count_and_gather( + gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, + topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, + cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, + num_seq_per_group_avg, stream); + } else { + constexpr int kTileM = 64; + constexpr int kStage = 8; + launch_count_and_gather( + gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, + topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, + cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, + num_seq_per_group_avg, stream); + } +} +void count_and_gather_bf16_async(void *gate_up_input_ptr, void *gate_up_output_ptr, void *down_input_ptr, + void *down_output_ptr, const void *x_ptr, const void *topk_ids_ptr, + void *topk_pos_ptr, void *seqlens_ptr, void *cu_seqlens_ptr, + void *gate_up_tmas_ptr, void *down_tmas_ptr, void *tiles_ptr, + void *cu_tiles_ptr, int num_seq, int hidden_size, int intermediate_size, + int num_topk, int num_expert, int eprank, int num_seq_per_group_avg, + cudaStream_t stream) { + constexpr int kTileN = 128; + constexpr int kTileK = 64; + using Tin = cute::bfloat16_t; + using Tout = cute::bfloat16_t; if (num_seq_per_group_avg <= 16) { constexpr int kTileM = 16; constexpr int kStage = 8; - launch_count_and_gather( + launch_count_and_gather( gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, @@ -390,7 +434,7 @@ void count_and_gather_async(void *gate_up_input_ptr, void *gate_up_output_ptr, v } else if (num_seq_per_group_avg <= 32) { constexpr int kTileM = 32; constexpr int kStage = 8; - launch_count_and_gather( + launch_count_and_gather( gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, @@ -398,7 +442,7 @@ void count_and_gather_async(void *gate_up_input_ptr, void *gate_up_output_ptr, v } else if (num_seq_per_group_avg <= 48) { constexpr int kTileM = 48; constexpr int kStage = 8; - launch_count_and_gather( + launch_count_and_gather( gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, @@ -406,7 +450,7 @@ void count_and_gather_async(void *gate_up_input_ptr, void *gate_up_output_ptr, v } else { constexpr int kTileM = 64; constexpr int kStage = 8; - launch_count_and_gather( + launch_count_and_gather( gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, 
x_ptr, topk_ids_ptr, topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, diff --git a/src/fuse_moe/entry.cc b/src/fuse_moe/entry.cc index 9e97c19..9dafe99 100644 --- a/src/fuse_moe/entry.cc +++ b/src/fuse_moe/entry.cc @@ -363,6 +363,105 @@ torch::Tensor fuse_moe_blockwise_fp8_entry( return y; } +torch::Tensor fuse_moe_bf16_entry( + const torch::Tensor &x, const torch::Tensor &gate_up_weight, const torch::Tensor &down_weight, + const torch::Tensor &topk_ids, const torch::Tensor &topk_scale, + const std::optional &shared_output, int64_t rank_ep, int64_t num_expert_total) { + auto stream = at::cuda::getCurrentCUDAStream(x.get_device()); + + TORCH_CHECK(x.device().is_cuda(), "x tensor must be cuda"); + TORCH_CHECK(gate_up_weight.device().is_cuda(), "gate_up_weight tensor must be cuda"); + TORCH_CHECK(down_weight.device().is_cuda(), "down_weight tensor must be cuda"); + TORCH_CHECK(topk_ids.device().is_cuda(), "topk_ids tensor must be cuda"); + TORCH_CHECK(topk_scale.device().is_cuda(), "topk_scale tensor must be cuda"); + + TORCH_CHECK(x.is_contiguous(), "x tensor must be contiguous"); + TORCH_CHECK(gate_up_weight.is_contiguous(), "gate_up_weight tensor must be contiguous"); + TORCH_CHECK(down_weight.is_contiguous(), "down_weight tensor must be contiguous"); + TORCH_CHECK(topk_ids.is_contiguous(), "topk_ids tensor must be contiguous"); + TORCH_CHECK(topk_scale.is_contiguous(), "topk_scale tensor must be contiguous"); + + TORCH_CHECK(x.dtype() == torch::kBFloat16, "x tensor dtype must be bfloat16"); + TORCH_CHECK(gate_up_weight.dtype() == torch::kBFloat16, "gate_up_weight tensor dtype must be bfloat16"); + TORCH_CHECK(down_weight.dtype() == torch::kBFloat16, "down_weight tensor dtype must be bfloat16"); + + TORCH_CHECK(x.size(0) == topk_ids.size(0), "x and topk_ids must share the same num_seq"); + TORCH_CHECK(topk_ids.size(0) == topk_scale.size(0), + "topk_ids and topk_scale must share the same num_seq"); + TORCH_CHECK(topk_ids.size(1) == topk_scale.size(1), + "topk_ids and topk_scale must share the same num_topk"); + TORCH_CHECK(x.size(1) == gate_up_weight.size(2), "x and weight must share the same k"); + TORCH_CHECK(gate_up_weight.size(0) == down_weight.size(0), + "gate_up_weight and down_weight must share the same num_expert"); + + const void *shared_output_ptr = nullptr; + if (shared_output.has_value()) { + const auto shared_output_tensor = shared_output.value(); + TORCH_CHECK(shared_output_tensor.device().is_cuda(), "shared_output tensor must be cuda"); + TORCH_CHECK(shared_output_tensor.is_contiguous(), "shared_output tensor must be contiguous"); + TORCH_CHECK(shared_output_tensor.dtype() == torch::kBFloat16, + "shared_output tensor dtype must be bfloat16"); + TORCH_CHECK( + shared_output_tensor.size(0) == x.size(0) && shared_output_tensor.size(1) == x.size(1), + "shared_output tensor shape must be same as x tensor"); + shared_output_ptr = shared_output_tensor.const_data_ptr(); + } + + int num_seq = x.size(0); + int hidden_size = x.size(1); + int num_expert = gate_up_weight.size(0); + int intermediate_size = gate_up_weight.size(1); + int num_topk = topk_ids.size(1); + TORCH_CHECK(num_topk <= 128, "num_topk must less than or equal to 128"); + + auto options = x.options(); + torch::Tensor y = torch::empty({num_seq, hidden_size}, options.dtype(torch::kBFloat16)); + + torch::Tensor gate_up_input = torch::empty({num_seq * num_topk, hidden_size}, options); + torch::Tensor 
gate_up_output = + torch::empty({num_seq * num_topk, intermediate_size}, options.dtype(torch::kBFloat16)); + torch::Tensor gate_up_tmas = torch::empty({num_expert * 2, 128}, options.dtype(torch::kInt8)); + torch::Tensor down_input = torch::empty({num_seq * num_topk, intermediate_size / 2}, options); + torch::Tensor down_output = + torch::empty({num_seq * num_topk, hidden_size}, options.dtype(torch::kBFloat16)); + torch::Tensor down_tmas = torch::empty({num_expert * 2, 128}, options.dtype(torch::kInt8)); + + torch::Tensor topk_pos = torch::empty({num_seq, num_topk}, options.dtype(torch::kInt32)); + torch::Tensor seqlens = torch::zeros({num_expert}, options.dtype(torch::kInt32)); + torch::Tensor cu_seqlens = torch::empty({num_expert + 1}, options.dtype(torch::kInt32)); + torch::Tensor tiles = torch::empty({num_expert}, options.dtype(torch::kInt32)); + torch::Tensor cu_tiles = torch::empty({num_expert + 1}, options.dtype(torch::kInt32)); + + const auto *x_ptr = x.const_data_ptr(); + const auto *topk_ids_ptr = topk_ids.const_data_ptr(); + const auto *topk_scale_ptr = topk_scale.const_data_ptr(); + const auto *gate_up_weight_ptr = gate_up_weight.const_data_ptr(); + const auto *down_weight_ptr = down_weight.const_data_ptr(); + + auto *y_ptr = y.mutable_data_ptr(); + auto *topk_pos_ptr = topk_pos.mutable_data_ptr(); + auto *seqlens_ptr = seqlens.mutable_data_ptr(); + auto *cu_seqlens_ptr = cu_seqlens.mutable_data_ptr(); + auto *tiles_ptr = tiles.mutable_data_ptr(); + auto *cu_tiles_ptr = cu_tiles.mutable_data_ptr(); + auto *gate_up_input_ptr = gate_up_input.mutable_data_ptr(); + auto *gate_up_output_ptr = gate_up_output.mutable_data_ptr(); + auto *gate_up_tmas_ptr = gate_up_tmas.mutable_data_ptr(); + auto *down_input_ptr = down_input.mutable_data_ptr(); + auto *down_output_ptr = down_output.mutable_data_ptr(); + auto *down_tmas_ptr = down_tmas.mutable_data_ptr(); + + fuse_moe_bf16_async( + y_ptr, x_ptr, gate_up_input_ptr, gate_up_output_ptr, gate_up_weight_ptr, gate_up_tmas_ptr, + down_input_ptr, down_output_ptr, down_weight_ptr, down_tmas_ptr, + topk_ids_ptr, topk_scale_ptr, topk_pos_ptr, seqlens_ptr, + cu_seqlens_ptr, tiles_ptr, cu_tiles_ptr, shared_output_ptr, num_seq, hidden_size, + intermediate_size, num_topk, num_expert_total, num_expert, rank_ep, stream); + + return y; +} + + } // namespace fuse_moe } // namespace hpc @@ -393,4 +492,10 @@ TORCH_LIBRARY_FRAGMENT(hpc, m) { "shared_output, " "int rank_ep, int num_expert_total) -> (Tensor)"); m.impl("fuse_moe_blockwise_fp8", torch::kCUDA, &hpc::fuse_moe::fuse_moe_blockwise_fp8_entry); + + m.def( + "fuse_moe_bf16(Tensor x, Tensor gate_up_weight, Tensor down_weight, " + "Tensor topk_ids, Tensor topk_scale, Tensor ? 
shared_output, " + "int rank_ep, int num_expert_total) -> (Tensor)"); + m.impl("fuse_moe_bf16", torch::kCUDA, &hpc::fuse_moe::fuse_moe_bf16_entry); } diff --git a/src/fuse_moe/fuse_moe.cu b/src/fuse_moe/fuse_moe.cu index fe73750..16d69be 100644 --- a/src/fuse_moe/fuse_moe.cu +++ b/src/fuse_moe/fuse_moe.cu @@ -107,5 +107,48 @@ void fuse_moe_blockwise_fp8_async( reduce_async(output_ptr, down_output_ptr, topk_pos_ptr, topk_scale_ptr, shared_output_ptr, total_num_tokens, num_tokens, hidden_size, num_topk, stream); } + +void fuse_moe_bf16_async( + void *output_ptr, const void *input_ptr, void *gate_up_input_ptr, void *gate_up_output_ptr, + const void *gate_up_weight_ptr, void *gate_up_tmas_ptr, void *down_input_ptr, + void *down_output_ptr, const void *down_weight_ptr, void *down_tmas_ptr, + const void *topk_ids_ptr, const void *topk_scale_ptr, void *topk_pos_ptr, void *seqlens_ptr, + void *cu_seqlens_ptr, void *tiles_ptr, void *cu_tiles_ptr, const void *shared_output_ptr, + int num_seq, int hidden_size, int intermediate_size, int num_topk, int num_expert_total, + int num_expert_local, int rank_ep, cudaStream_t stream) { + int total_num_seq = num_seq * num_topk; + int num_seq_per_group_avg = total_num_seq / num_expert_total; + using T1 = __nv_bfloat16; + + // 0. call count_and_gather_bf16_async (fills TMA descriptors for bf16 group gemm) + count_and_gather_bf16_async(gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, + input_ptr, topk_ids_ptr, topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, + gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_seq, + hidden_size, intermediate_size, num_topk, num_expert_local, rank_ep, + num_seq_per_group_avg, stream); + + // 1. call gate_up linear (bf16), TMA descriptors pre-filled by count_and_gather_bf16_async + group_gemm::group_gemm_bf16_async( + gate_up_output_ptr, gate_up_input_ptr, gate_up_weight_ptr, seqlens_ptr, cu_seqlens_ptr, + gate_up_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_expert_local, total_num_seq, + intermediate_size, hidden_size, num_seq_per_group_avg, false, stream); + + // 2. call act and mul (bf16 activation) + const int *valid_row_range_ptr = + (int *)cu_seqlens_ptr + num_expert_local; // get last number as valid row + activation::act_mul_bf16_async((T1 *)down_input_ptr, (const T1 *)gate_up_output_ptr, + valid_row_range_ptr, total_num_seq, intermediate_size / 2, stream); + + // 3. call down linear (bf16), TMA descriptors pre-filled by count_and_gather_bf16_async + group_gemm::group_gemm_bf16_async( + down_output_ptr, down_input_ptr, down_weight_ptr, seqlens_ptr, cu_seqlens_ptr, + down_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_expert_local, total_num_seq, hidden_size, + intermediate_size / 2, num_seq_per_group_avg, false, stream); + + // 4. 
call reduce + reduce_async(output_ptr, down_output_ptr, topk_pos_ptr, topk_scale_ptr, shared_output_ptr, + total_num_seq, num_seq, hidden_size, num_topk, stream); +} + } // namespace fuse_moe } // namespace hpc diff --git a/src/fuse_moe/fuse_moe.h b/src/fuse_moe/fuse_moe.h index a388878..85a844b 100644 --- a/src/fuse_moe/fuse_moe.h +++ b/src/fuse_moe/fuse_moe.h @@ -20,6 +20,14 @@ void count_and_gather_async(void *gate_up_input_ptr, void *gata_up_output_ptr, v int num_topk, int num_expert, int rank_ep, int num_seq_per_group_avg, cudaStream_t stream); +void count_and_gather_bf16_async(void *gate_up_input_ptr, void *gate_up_output_ptr, + void *down_input_ptr, void *down_output_ptr, const void *x_ptr, + const void *topk_ids_ptr, void *topk_pos_ptr, void *seqlens_ptr, + void *cu_seqlens_ptr, void *gate_up_tmas_ptr, void *down_tmas_ptr, + void *tiles_ptr, void *cu_tiles_ptr, int num_seq, int hidden_size, + int intermediate_size, int num_topk, int num_expert, int rank_ep, + int num_seq_per_group_avg, cudaStream_t stream); + void blockwise_count_and_gather_async( const void *input_ptr, const void *input_scale_ptr, void *gate_up_input_ptr, void *gate_up_output_ptr, void *gate_up_input_scale_ptr, void *down_input_ptr, @@ -56,6 +64,15 @@ void fuse_moe_blockwise_fp8_async( int gate_up_weight_scale_lastdim_pad4, int down_weight_scale_lastdim_pad4, int rank_ep, cudaStream_t stream); +void fuse_moe_bf16_async( + void *output_ptr, const void *input_ptr, void *gate_up_input_ptr, void *gate_up_output_ptr, + const void *gate_up_weight_ptr, void *gate_up_tmas_ptr, void *down_input_ptr, + void *down_output_ptr, const void *down_weight_ptr, void *down_tmas_ptr, + const void *topk_ids_ptr, const void *topk_scale_ptr, void *topk_pos_ptr, void *seqlens_ptr, + void *cu_seqlens_ptr, void *tiles_ptr, void *cu_tiles_ptr, const void *shared_output_ptr, + int num_seq, int hidden_size, int intermediate_size, int num_topk, int num_expert_total, + int num_expert_local, int rank_ep, cudaStream_t stream); + } // namespace fuse_moe } // namespace hpc From b3ecebf5556a809aa18bd811da1d21757bb0a525 Mon Sep 17 00:00:00 2001 From: ZelinMa557 <3388706467@qq.com> Date: Wed, 1 Apr 2026 11:40:04 +0800 Subject: [PATCH 3/7] update test and bench code Signed-off-by: ZelinMa557 <3388706467@qq.com> --- tests/bench_fuse_moe_bf16.py | 342 +++++++++++++++++++++++++++++++++ tests/bench_group_gemm_bf16.py | 257 +++++++++++++++++++++++++ tests/test_fuse_moe_bf16.py | 201 +++++++++++++++++++ 3 files changed, 800 insertions(+) create mode 100644 tests/bench_fuse_moe_bf16.py create mode 100644 tests/bench_group_gemm_bf16.py create mode 100644 tests/test_fuse_moe_bf16.py diff --git a/tests/bench_fuse_moe_bf16.py b/tests/bench_fuse_moe_bf16.py new file mode 100644 index 0000000..554a018 --- /dev/null +++ b/tests/bench_fuse_moe_bf16.py @@ -0,0 +1,342 @@ +""" +Benchmark script for hpc.fuse_moe_bf16 vs sglang fused_moe (BF16). 
+ +Default model: Qwen3-235B-A22B, TP=8 (EP=8 simulation on a single GPU) + hidden_size = 4096 + moe_intermediate_size = 1536 + num_experts = 128 -> num_experts_local = 128 // 8 = 16 per GPU + num_experts_per_tok = 8 (topk) + +Usage: + python tests/bench_fuse_moe_bf16.py + python tests/bench_fuse_moe_bf16.py --hidden-size 4096 --intermediate-size 1536 \\ + --num-experts 128 --topk 8 --tp-size 8 + python tests/bench_fuse_moe_bf16.py --warmup 10 --iters 100 +""" + +import argparse +import os +import sys +from pathlib import Path + +sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) + +import triton.language as tl +import torch +import hpc +from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import invoke_fused_moe_kernel +from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size +from sgl_kernel import silu_and_mul + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--hidden-size", type=int, default=4096) + p.add_argument("--intermediate-size", type=int, default=1536, + help="moe_intermediate_size per expert (before gate/up split)") + p.add_argument("--num-experts", type=int, default=128) + p.add_argument("--topk", type=int, default=8) + p.add_argument("--tp-size", type=int, default=8, + help="Tensor/Expert parallelism size. num_experts_local = num_experts // tp_size") + p.add_argument("--warmup", type=int, default=5, + help="Warmup iterations before graph capture") + p.add_argument("--iters", type=int, default=100, + help="Graph replay iterations for timing") + return p.parse_args() + + +# --------------------------------------------------------------------------- +# Batch sizes to sweep +# --------------------------------------------------------------------------- + +BATCH_SIZES = [i for i in range(1, 17)] + + +# --------------------------------------------------------------------------- +# Input helpers +# --------------------------------------------------------------------------- + + +def make_inputs(batch_size, hidden_size, intermediate_size, num_experts_local, topk, + device="cuda"): + """ + Build BF16 inputs for a single-GPU EP benchmark. + topk_ids are sampled uniformly from local experts [0, num_experts_local). 
+ """ + dtype = torch.bfloat16 + + x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device) + + # gate+up fused: [E_local, inter*2, hidden] + gate_up_weight = torch.randn( + (num_experts_local, intermediate_size * 2, hidden_size), dtype=dtype, device=device + ) + # down: [E_local, hidden, inter] + down_weight = torch.randn( + (num_experts_local, hidden_size, intermediate_size), dtype=dtype, device=device + ) + + # topk_ids in [0, num_experts_local); topk_scale positive, sum-normalised + topk_ids = torch.randint( + 0, num_experts_local, (batch_size, topk), dtype=torch.int32, device=device + ) + raw_scale = torch.rand((batch_size, topk), dtype=torch.float32, device=device) + topk_scale = raw_scale / raw_scale.sum(dim=1, keepdim=True) + + return x, gate_up_weight, down_weight, topk_ids, topk_scale + + +# --------------------------------------------------------------------------- +# FLOPs estimate (two GEMMs: gate_up and down, per token-expert assignment) +# --------------------------------------------------------------------------- + + +def tflops(batch_size, topk, hidden_size, intermediate_size, elapsed_ms): + # gate_up: batch*topk × (inter*2) × hidden (factor 2 for gate+up) + # down: batch*topk × hidden × inter + tokens = batch_size * topk + flops = 2 * tokens * intermediate_size * 2 * hidden_size # gate_up gemm + flops += 2 * tokens * hidden_size * intermediate_size # down gemm + return flops / (elapsed_ms * 1e-3) / 1e12 + + +# --------------------------------------------------------------------------- +# CUDA-graph benchmarking helper +# --------------------------------------------------------------------------- + + +def bench_cuda_graph(fn, warmup, iters): + """Warm up, capture a CUDA graph, replay `iters` times, return avg ms.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + g = torch.cuda.CUDAGraph() + capture_stream = torch.cuda.Stream() + with torch.cuda.stream(capture_stream): + fn() + torch.cuda.synchronize() + with torch.cuda.graph(g, stream=capture_stream): + fn() + + torch.cuda.synchronize() + + t0 = torch.cuda.Event(enable_timing=True) + t1 = torch.cuda.Event(enable_timing=True) + t0.record() + for _ in range(iters): + g.replay() + t1.record() + torch.cuda.synchronize() + return t0.elapsed_time(t1) / iters + + +# --------------------------------------------------------------------------- +# sglang default BF16 config +# --------------------------------------------------------------------------- + + +def sglang_bf16_config(total_m, E): + """ + Static BF16 config mirroring sglang's get_default_config logic: + total_m <= E -> small-batch config (BLOCK_SIZE_M=16) + total_m > E -> regular config (BLOCK_SIZE_M=64) + """ + if total_m <= E: + return { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + } + else: + return { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 3, + } + + +# --------------------------------------------------------------------------- +# Per-kernel bench helpers +# --------------------------------------------------------------------------- + + +def bench_hpc(x, gate_up_weight, down_weight, topk_ids, topk_scale, + num_experts_local, warmup, iters): + """Benchmark hpc.fuse_moe_bf16 (full pipeline).""" + # rank_ep=0, num_expert_total=num_experts_local: all experts are local + def fn(): + hpc.fuse_moe_bf16( + x, gate_up_weight, down_weight, + topk_ids, topk_scale, + rank_ep=0, + 
num_expert_total=num_experts_local, + ) + + return bench_cuda_graph(fn, warmup, iters) + + +def bench_sglang(x, gate_up_weight, down_weight, topk_ids, topk_scale, + num_experts_local, warmup, iters): + """ + Benchmark sglang full fused_moe pipeline (BF16) via direct kernel calls. + All steps are inside the CUDA graph for a fair comparison with hpc.fuse_moe_bf16: + 0. sgl_moe_align_block_size token sorting (= hpc count_and_gather) + 1. invoke_fused_moe_kernel gate_up GEMM (mul_routed_weight=False) + 2. silu_and_mul SiLU activation + 3. invoke_fused_moe_kernel down GEMM (mul_routed_weight=True, top_k=1) + 4. torch.sum over topk dim weighted reduce + """ + batch_size = x.shape[0] + hidden_size = x.shape[1] + inter_x2 = gate_up_weight.shape[1] # intermediate_size * 2 + inter = inter_x2 // 2 + topk = topk_ids.shape[1] + total_tokens = batch_size * topk + + config = sglang_bf16_config(total_tokens, num_experts_local) + block_size = config["BLOCK_SIZE_M"] + + # Pre-allocate routing buffers at max possible size (shapes are static for CUDA graph) + if topk_ids.numel() < num_experts_local + 1: + max_padded = topk_ids.numel() * block_size + else: + max_padded = topk_ids.numel() + (num_experts_local + 1) * (block_size - 1) + max_m_blocks = (max_padded + block_size - 1) // block_size + + sorted_ids = torch.empty((max_padded,), dtype=torch.int32, device=x.device) + expert_ids = torch.empty((max_m_blocks,), dtype=torch.int32, device=x.device) + num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=x.device) + cumsum_buf = torch.empty((num_experts_local + 2,), dtype=torch.int32, device=x.device) + + # Intermediate compute buffers (max padded size) + cache1 = torch.empty((max_padded, inter_x2), dtype=torch.bfloat16, device=x.device) + cache2 = torch.empty((max_padded, inter), dtype=torch.bfloat16, device=x.device) + cache3 = torch.empty((batch_size, topk, hidden_size), dtype=torch.bfloat16, device=x.device) + out = torch.empty((batch_size, hidden_size), dtype=torch.bfloat16, device=x.device) + + def fn(): + # 0. token sorting + sgl_moe_align_block_size( + topk_ids, num_experts_local + 1, block_size, + sorted_ids, expert_ids, num_tokens_post_pad, cumsum_buf, True, + ) + # 1. gate_up GEMM + invoke_fused_moe_kernel( + x, gate_up_weight, None, cache1, + None, None, None, + topk_scale, topk_ids, + sorted_ids, expert_ids, num_tokens_post_pad, + False, # mul_routed_weight + topk, # top_k + config, + compute_type=tl.bfloat16, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=None, + ) + # 2. SiLU activation: cache1 [max_padded, inter*2] -> cache2 [max_padded, inter] + silu_and_mul(cache1, cache2) + # 3. down GEMM – writes weighted results to cache3[batch, topk, hidden] + invoke_fused_moe_kernel( + cache2, down_weight, None, cache3, + None, None, None, + topk_scale, topk_ids, + sorted_ids, expert_ids, num_tokens_post_pad, + True, # mul_routed_weight (applies routing weights during scatter) + 1, # top_k=1: each sorted row maps to one (batch, topk_slot) in cache3 + config, + compute_type=tl.bfloat16, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=None, + ) + # 4. 
Sum-reduce over topk slots -> [batch, hidden] + torch.sum(cache3, dim=1, out=out) + + return bench_cuda_graph(fn, warmup, iters) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + args = parse_args() + + hidden_size = args.hidden_size + intermediate_size = args.intermediate_size + num_experts = args.num_experts + topk = args.topk + tp_size = args.tp_size + num_experts_local = num_experts // tp_size + + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + prop = torch.cuda.get_device_properties(0) + print(f"\nDevice : {prop.name}") + print(f"Model config : hidden={hidden_size}, inter={intermediate_size}, " + f"experts={num_experts}, topk={topk}, tp={tp_size}") + print(f"Local experts/GPU : {num_experts_local}") + print(f"Weight shapes : gate_up=[{num_experts_local}, {intermediate_size*2}, {hidden_size}], " + f"down=[{num_experts_local}, {hidden_size}, {intermediate_size}]") + print(f"Timing : warmup={args.warmup}, iters={args.iters} (CUDA graph replay)\n") + + hdr = ( + f"{'batch':>7} {'tokens':>8} " + f"{'hpc(ms)':>10} {'hpc(TF)':>9} " + f"{'sgl(ms)':>10} {'sgl(TF)':>9} " + f"{'speedup':>8}" + ) + sep = "-" * len(hdr) + print(hdr) + print(sep) + + for bs in BATCH_SIZES: + x, gate_up_weight, down_weight, topk_ids, topk_scale = make_inputs( + bs, hidden_size, intermediate_size, num_experts_local, topk + ) + + hpc_ms = bench_hpc( + x, gate_up_weight, down_weight, topk_ids, topk_scale, + num_experts_local, args.warmup, args.iters + ) + sgl_ms = bench_sglang( + x, gate_up_weight, down_weight, topk_ids, topk_scale, + num_experts_local, args.warmup, args.iters + ) + + hpc_tf = tflops(bs, topk, hidden_size, intermediate_size, hpc_ms) + sgl_tf = tflops(bs, topk, hidden_size, intermediate_size, sgl_ms) + + print( + f"{bs:>7} {bs*topk:>8} " + f"{hpc_ms:>10.4f} {hpc_tf:>9.2f} " + f"{sgl_ms:>10.4f} {sgl_tf:>9.2f} " + f"{sgl_ms/hpc_ms:>8.2f}x" + ) + + print(sep) + print() + + +if __name__ == "__main__": + main() diff --git a/tests/bench_group_gemm_bf16.py b/tests/bench_group_gemm_bf16.py new file mode 100644 index 0000000..5061a03 --- /dev/null +++ b/tests/bench_group_gemm_bf16.py @@ -0,0 +1,257 @@ +""" +Benchmark script for hpc.group_gemm_bf16 vs sglang fused_moe_kernel (BF16). 
+ +Usage: + python tests/bench_group_gemm_bf16.py + python tests/bench_group_gemm_bf16.py --E 128 --N 384 --K 4096 + python tests/bench_group_gemm_bf16.py --E 8 --N 4096 --K 7168 --warmup 10 --iters 100 +""" + +import argparse +import os +import sys +from pathlib import Path + +sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) + +import torch +import triton.language as tl +import hpc +from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import invoke_fused_moe_kernel +from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import moe_align_block_size + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--E", type=int, default=8, help="Number of groups (experts)") + p.add_argument("--N", type=int, default=4096, help="Output dimension per group") + p.add_argument("--K", type=int, default=7168, help="Input dimension (hidden size)") + p.add_argument("--warmup", type=int, default=5, help="Warmup iterations before graph capture") + p.add_argument("--iters", type=int, default=100, help="Graph replay iterations for timing") + return p.parse_args() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +M_VALUES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + + +def make_inputs(E, m_per_group, N, K, device="cuda"): + dtype = torch.bfloat16 + total_m = E * m_per_group + x = torch.randn((total_m, K), dtype=dtype, device=device) + w = torch.randn((E, N, K), dtype=dtype, device=device) + seqlens = torch.full((E,), m_per_group, dtype=torch.int32, device=device) + cu_seqlens = torch.zeros(E + 1, dtype=torch.int32, device=device) + cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) + output = torch.empty((total_m, N), dtype=dtype, device=device) + return x, w, seqlens, cu_seqlens, output + + +def tflops(E, m_per_group, N, K, elapsed_ms): + flops = 2 * E * m_per_group * N * K + return flops / (elapsed_ms * 1e-3) / 1e12 + + +# --------------------------------------------------------------------------- +# sglang default BF16 config +# (mirrors get_default_config logic for dtype=None, no server args needed) +# --------------------------------------------------------------------------- + + +def sglang_bf16_config(total_m, E): + """ + Replicates sglang's get_default_config logic for plain BF16 (dtype=None). + M <= E -> small-batch config (BLOCK_SIZE_M=16) + M > E -> regular config (BLOCK_SIZE_M=64) + num_warps / num_stages are typical Hopper-friendly defaults. 
+ """ + if total_m <= E: + return { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + } + else: + return { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 3, + } + + +# --------------------------------------------------------------------------- +# Benchmark with CUDA graph +# --------------------------------------------------------------------------- + + +def bench_cuda_graph(fn, warmup, iters): + """Warm up, capture a CUDA graph, replay `iters` times, return avg ms.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + g = torch.cuda.CUDAGraph() + capture_stream = torch.cuda.Stream() + with torch.cuda.stream(capture_stream): + fn() # dry run on capture stream + torch.cuda.synchronize() + with torch.cuda.graph(g, stream=capture_stream): + fn() + + torch.cuda.synchronize() + + t0 = torch.cuda.Event(enable_timing=True) + t1 = torch.cuda.Event(enable_timing=True) + t0.record() + for _ in range(iters): + g.replay() + t1.record() + torch.cuda.synchronize() + return t0.elapsed_time(t1) / iters + + +# --------------------------------------------------------------------------- +# Per-kernel bench helpers +# --------------------------------------------------------------------------- + + +def bench_hpc(x, w, seqlens, cu_seqlens, output, mean_seq, warmup, iters): + def fn(): + hpc.group_gemm_bf16( + x, w, seqlens, cu_seqlens, num_seq_per_group_avg=mean_seq, output=output + ) + + return bench_cuda_graph(fn, warmup, iters) + + +def bench_torch_ref(x, w, seqlens, cu_seqlens, output, warmup, iters): + E = seqlens.shape[0] + slices = [(int(cu_seqlens[i].item()), int(cu_seqlens[i + 1].item())) for i in range(E)] + + def fn(): + for i, (s, e) in enumerate(slices): + output[s:e] = torch.matmul(x[s:e], w[i].t()) + + return bench_cuda_graph(fn, warmup, iters) + + +def bench_sglang(x, w, E, m_per_group, N, warmup, iters): + """ + Benchmark sglang's invoke_fused_moe_kernel for a single up/gate projection + (top_k=1, each token routes to its group, no routed-weight multiply). + + moe_align_block_size is called once outside the graph; only the GEMM + kernel itself is captured and replayed. 
+ """ + total_m = E * m_per_group + config = sglang_bf16_config(total_m, E) + + # top_k=1: token i belongs to expert floor(i / m_per_group) + topk_ids = ( + (torch.arange(total_m, device="cuda") // m_per_group).view(-1, 1).to(torch.int32) + ) # [total_m, 1] + topk_weights = torch.ones(total_m, 1, dtype=torch.bfloat16, device="cuda") + + # Routing tensors – computed once, reused by every graph replay + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], E + ) + + # Output shape follows sglang convention: [padded_tokens, N] + sgl_out = torch.empty((sorted_token_ids.shape[0], N), dtype=torch.bfloat16, device="cuda") + + def fn(): + invoke_fused_moe_kernel( + A=x, + B=w, + bias=None, + C=sgl_out, + A_scale=None, + B_scale=None, + B_zp=None, + topk_weights=topk_weights, + topk_ids=topk_ids, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + mul_routed_weight=False, + top_k=1, + config=config, + compute_type=tl.bfloat16, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=None, + ) + + return bench_cuda_graph(fn, warmup, iters) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + args = parse_args() + E, N, K = args.E, args.N, args.K + + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + prop = torch.cuda.get_device_properties(0) + print(f"\nDevice : {prop.name}") + print(f"Config : E={E}, N={N}, K={K}") + print(f"Timing : warmup={args.warmup}, iters={args.iters} (CUDA graph replay)\n") + + hdr = ( + f"{'M/group':>8} {'total_M':>8} " + f"{'hpc(ms)':>10} {'hpc(TF)':>9} " + f"{'sgl(ms)':>10} {'sgl(TF)':>9} " + f"{'ref(ms)':>10} {'ref(TF)':>9} " + f"{'hpc/sgl':>8}" + ) + sep = "-" * len(hdr) + print(hdr) + print(sep) + + for m in M_VALUES: + x, w, seqlens, cu_seqlens, output = make_inputs(E, m, N, K) + + hpc_ms = bench_hpc(x, w, seqlens, cu_seqlens, output, m, args.warmup, args.iters) + sgl_ms = bench_sglang(x, w, E, m, N, args.warmup, args.iters) + ref_ms = bench_torch_ref(x, w, seqlens, cu_seqlens, output, args.warmup, args.iters) + + hpc_tf = tflops(E, m, N, K, hpc_ms) + sgl_tf = tflops(E, m, N, K, sgl_ms) + ref_tf = tflops(E, m, N, K, ref_ms) + + print( + f"{m:>8} {E*m:>8} " + f"{hpc_ms:>10.4f} {hpc_tf:>9.2f} " + f"{sgl_ms:>10.4f} {sgl_tf:>9.2f} " + f"{ref_ms:>10.4f} {ref_tf:>9.2f} " + f"{sgl_ms/hpc_ms:>8.2f}x" + ) + + print(sep) + print() + + +if __name__ == "__main__": + main() diff --git a/tests/test_fuse_moe_bf16.py b/tests/test_fuse_moe_bf16.py new file mode 100644 index 0000000..7874733 --- /dev/null +++ b/tests/test_fuse_moe_bf16.py @@ -0,0 +1,201 @@ +import os +import sys +from pathlib import Path + +sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) + +import math +import pytest +import torch + +import hpc +from utils import allclose + +# Set random seed for reproducibility +torch.manual_seed(41) +torch.cuda.manual_seed(41) + + +def naive_gather_expert_inputs(x, topk_ids, num_expert, rank_ep): + num_tokens, num_topk = topk_ids.shape + num_seq, hidden_size = x.shape + + unique_values, num_tokens_per_expert_partial = torch.unique( + topk_ids.flatten(), return_counts=True, sorted=True + ) + start_expert = rank_ep * num_expert + end_expert = (rank_ep + 1) * num_expert + mask = (unique_values >= start_expert) & 
(unique_values < end_expert) + unique_values = unique_values[mask] + num_tokens_per_expert_partial = num_tokens_per_expert_partial[mask] + num_tokens_per_expert = torch.full([num_expert], 0, dtype=torch.int32, device="cuda") + + for i in range(unique_values.numel()): + num_tokens_per_expert[unique_values[i] - start_expert] = num_tokens_per_expert_partial[i] + + cu_num_tokens_per_expert = torch.cumsum( + torch.cat([torch.tensor([0], dtype=torch.int32, device="cuda"), num_tokens_per_expert]), + dim=0, + ).to(torch.int32) + + y = torch.zeros((num_tokens * num_topk, hidden_size), dtype=x.dtype, device=x.device) + token_pos = torch.zeros((num_tokens, num_topk), dtype=torch.int32, device=x.device) + token_pos.fill_(-1) + + # reset + num_tokens_per_expert.fill_(0) + + for idx, iexpert in enumerate(topk_ids.flatten()): + itoken = idx // num_topk + icol = idx % num_topk + if iexpert >= start_expert and iexpert < end_expert: + pos = ( + cu_num_tokens_per_expert[iexpert - start_expert] + + num_tokens_per_expert[iexpert - start_expert] + ) + y[pos] = x[itoken] + token_pos[itoken, icol] = pos.item() + num_tokens_per_expert[iexpert - start_expert] += 1 + + return ( + y, + token_pos, + num_tokens_per_expert, + cu_num_tokens_per_expert, + unique_values, + ) + + +def naive_group_gemm_bf16(x, w, cu_seqlens): + m, k = x.shape + num_group, n, _ = w.shape + + y = torch.zeros((m, n), dtype=torch.bfloat16, device=x.device) + + start_idx = 0 + for i in range(num_group): + start_idx = cu_seqlens[i].item() + end_idx = cu_seqlens[i + 1].item() + + x_group = x[start_idx:end_idx] + w_group = w[i] + + y_group = x_group @ w_group.t() + y[start_idx:end_idx] = y_group + + return y + + +def naive_act_mul_bf16(gate_up): + def silu(x): + return x / (1 + (-x).exp()) + + gate, up = torch.chunk(gate_up, 2, dim=1) + out = silu(gate) * up + return out + + +def naive_reduce(x_bf16, topk_pos, topk_scale, shared_output=None): + num_seq, num_topk = topk_pos.shape + total_num_seq, hidden_size = x_bf16.shape + + y_bf16 = torch.zeros((num_seq, hidden_size), dtype=torch.bfloat16, device=x_bf16.device) + for i in range(num_seq): + y_bf16[i] = torch.sum(x_bf16[topk_pos[i]] * topk_scale[i].unsqueeze(1), dim=0) + if shared_output is not None: + y_bf16[i] += shared_output[i] + + return y_bf16 + + +def naive_fuse_moe_bf16( + x, + gate_up_weight, + down_weight, + topk_ids, + topk_scale, + rank_ep, + shared_output=None, +): + num_expert = gate_up_weight.size(0) + # count_and_gather + gate_up_input, topk_pos, seqlens, cu_seqlens, expert_ids = naive_gather_expert_inputs( + x, topk_ids, num_expert, rank_ep + ) + + # gate_up_proj + gate_up_output = naive_group_gemm_bf16(gate_up_input, gate_up_weight, cu_seqlens) + + # act_and_mul + down_input = naive_act_mul_bf16(gate_up_output) + + # down_proj + down_output = naive_group_gemm_bf16(down_input, down_weight, cu_seqlens) + + # reduce + y = naive_reduce(down_output, topk_pos, topk_scale, shared_output) + + return y + + +@pytest.mark.parametrize("num_seq", [128]) +@pytest.mark.parametrize("num_topk", [8]) +@pytest.mark.parametrize("hidden_size", [512]) +@pytest.mark.parametrize("intermediate_size", [512]) +@pytest.mark.parametrize("num_expert", [128]) +@pytest.mark.parametrize("rank_ep", [0, 1]) +@pytest.mark.parametrize("size_ep", [1, 4, 8]) +@pytest.mark.parametrize("has_shared_output", [False, True]) +def test_fuse_moe_bf16( + num_seq, + num_topk, + hidden_size, + intermediate_size, + num_expert, + rank_ep, + size_ep, + has_shared_output, +): + dtype = torch.bfloat16 + + topk_ids = torch.randint(0, 
num_expert, (num_seq, num_topk), dtype=torch.int32, device="cuda") + topk_ids, _ = torch.sort(topk_ids, dim=1) + + x = torch.randn((num_seq, hidden_size), dtype=dtype, device="cuda") / 100 + gate_up_weight = torch.randn( + (num_expert // size_ep, intermediate_size * 2, hidden_size), + dtype=dtype, + device="cuda", + ) + down_weight = torch.randn( + (num_expert // size_ep, hidden_size, intermediate_size), + dtype=dtype, + device="cuda", + ) + topk_scale = torch.randn((num_seq, num_topk), dtype=torch.float, device="cuda") / num_topk + if has_shared_output: + shared_output = torch.randn((num_seq, hidden_size), dtype=dtype, device="cuda") + else: + shared_output = None + + my = hpc.fuse_moe_bf16( + x, + gate_up_weight, + down_weight, + topk_ids, + topk_scale, + rank_ep, + num_expert // size_ep, + shared_output=shared_output, + ) + gt = naive_fuse_moe_bf16( + x, + gate_up_weight, + down_weight, + topk_ids, + topk_scale, + rank_ep, + shared_output, + ) + + assert allclose(gt.to(torch.float32), my.to(torch.float32), rtol=0.1, atol=0.1) From a95f67b3d7919a3cdaac9934bc1618b7c98f0d30 Mon Sep 17 00:00:00 2001 From: ZelinMa557 <3388706467@qq.com> Date: Wed, 1 Apr 2026 16:30:01 +0800 Subject: [PATCH 4/7] fix bench fuse moe script Signed-off-by: ZelinMa557 <3388706467@qq.com> --- tests/bench_fuse_moe_bf16.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/bench_fuse_moe_bf16.py b/tests/bench_fuse_moe_bf16.py index 554a018..b09233a 100644 --- a/tests/bench_fuse_moe_bf16.py +++ b/tests/bench_fuse_moe_bf16.py @@ -53,7 +53,7 @@ def parse_args(): # Batch sizes to sweep # --------------------------------------------------------------------------- -BATCH_SIZES = [i for i in range(1, 17)] +BATCH_SIZES = [i for i in range(1, 65)] + [4096, 8192] # --------------------------------------------------------------------------- @@ -280,13 +280,12 @@ def fn(): def main(): args = parse_args() - + tp_size = args.tp_size hidden_size = args.hidden_size - intermediate_size = args.intermediate_size + intermediate_size = args.intermediate_size // tp_size num_experts = args.num_experts topk = args.topk - tp_size = args.tp_size - num_experts_local = num_experts // tp_size + num_experts_local = num_experts torch.manual_seed(0) torch.cuda.manual_seed(0) From 4152ab107ac0cab857c4633e3040ea1140a9916c Mon Sep 17 00:00:00 2001 From: ZelinMa557 <3388706467@qq.com> Date: Fri, 10 Apr 2026 00:59:51 +0800 Subject: [PATCH 5/7] remove dead code Signed-off-by: ZelinMa557 <3388706467@qq.com> --- src/group_gemm/group_gemm_bf16.cu | 1 - tests/bench_fuse_moe_bf16.py | 341 ------------------------------ tests/bench_group_gemm_bf16.py | 257 ---------------------- 3 files changed, 599 deletions(-) delete mode 100644 tests/bench_fuse_moe_bf16.py delete mode 100644 tests/bench_group_gemm_bf16.py diff --git a/src/group_gemm/group_gemm_bf16.cu b/src/group_gemm/group_gemm_bf16.cu index 6cae49b..9417007 100644 --- a/src/group_gemm/group_gemm_bf16.cu +++ b/src/group_gemm/group_gemm_bf16.cu @@ -4,7 +4,6 @@ #include #include -#include #include "cute/tensor.hpp" #include "src/group_gemm/config.h" #include "src/group_gemm/group_gemm.h" diff --git a/tests/bench_fuse_moe_bf16.py b/tests/bench_fuse_moe_bf16.py deleted file mode 100644 index b09233a..0000000 --- a/tests/bench_fuse_moe_bf16.py +++ /dev/null @@ -1,341 +0,0 @@ -""" -Benchmark script for hpc.fuse_moe_bf16 vs sglang fused_moe (BF16). 
- -Default model: Qwen3-235B-A22B, TP=8 (EP=8 simulation on a single GPU) - hidden_size = 4096 - moe_intermediate_size = 1536 - num_experts = 128 -> num_experts_local = 128 // 8 = 16 per GPU - num_experts_per_tok = 8 (topk) - -Usage: - python tests/bench_fuse_moe_bf16.py - python tests/bench_fuse_moe_bf16.py --hidden-size 4096 --intermediate-size 1536 \\ - --num-experts 128 --topk 8 --tp-size 8 - python tests/bench_fuse_moe_bf16.py --warmup 10 --iters 100 -""" - -import argparse -import os -import sys -from pathlib import Path - -sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) - -import triton.language as tl -import torch -import hpc -from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import invoke_fused_moe_kernel -from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size -from sgl_kernel import silu_and_mul - -# --------------------------------------------------------------------------- -# CLI -# --------------------------------------------------------------------------- - - -def parse_args(): - p = argparse.ArgumentParser() - p.add_argument("--hidden-size", type=int, default=4096) - p.add_argument("--intermediate-size", type=int, default=1536, - help="moe_intermediate_size per expert (before gate/up split)") - p.add_argument("--num-experts", type=int, default=128) - p.add_argument("--topk", type=int, default=8) - p.add_argument("--tp-size", type=int, default=8, - help="Tensor/Expert parallelism size. num_experts_local = num_experts // tp_size") - p.add_argument("--warmup", type=int, default=5, - help="Warmup iterations before graph capture") - p.add_argument("--iters", type=int, default=100, - help="Graph replay iterations for timing") - return p.parse_args() - - -# --------------------------------------------------------------------------- -# Batch sizes to sweep -# --------------------------------------------------------------------------- - -BATCH_SIZES = [i for i in range(1, 65)] + [4096, 8192] - - -# --------------------------------------------------------------------------- -# Input helpers -# --------------------------------------------------------------------------- - - -def make_inputs(batch_size, hidden_size, intermediate_size, num_experts_local, topk, - device="cuda"): - """ - Build BF16 inputs for a single-GPU EP benchmark. - topk_ids are sampled uniformly from local experts [0, num_experts_local). 
- """ - dtype = torch.bfloat16 - - x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device) - - # gate+up fused: [E_local, inter*2, hidden] - gate_up_weight = torch.randn( - (num_experts_local, intermediate_size * 2, hidden_size), dtype=dtype, device=device - ) - # down: [E_local, hidden, inter] - down_weight = torch.randn( - (num_experts_local, hidden_size, intermediate_size), dtype=dtype, device=device - ) - - # topk_ids in [0, num_experts_local); topk_scale positive, sum-normalised - topk_ids = torch.randint( - 0, num_experts_local, (batch_size, topk), dtype=torch.int32, device=device - ) - raw_scale = torch.rand((batch_size, topk), dtype=torch.float32, device=device) - topk_scale = raw_scale / raw_scale.sum(dim=1, keepdim=True) - - return x, gate_up_weight, down_weight, topk_ids, topk_scale - - -# --------------------------------------------------------------------------- -# FLOPs estimate (two GEMMs: gate_up and down, per token-expert assignment) -# --------------------------------------------------------------------------- - - -def tflops(batch_size, topk, hidden_size, intermediate_size, elapsed_ms): - # gate_up: batch*topk × (inter*2) × hidden (factor 2 for gate+up) - # down: batch*topk × hidden × inter - tokens = batch_size * topk - flops = 2 * tokens * intermediate_size * 2 * hidden_size # gate_up gemm - flops += 2 * tokens * hidden_size * intermediate_size # down gemm - return flops / (elapsed_ms * 1e-3) / 1e12 - - -# --------------------------------------------------------------------------- -# CUDA-graph benchmarking helper -# --------------------------------------------------------------------------- - - -def bench_cuda_graph(fn, warmup, iters): - """Warm up, capture a CUDA graph, replay `iters` times, return avg ms.""" - for _ in range(warmup): - fn() - torch.cuda.synchronize() - - g = torch.cuda.CUDAGraph() - capture_stream = torch.cuda.Stream() - with torch.cuda.stream(capture_stream): - fn() - torch.cuda.synchronize() - with torch.cuda.graph(g, stream=capture_stream): - fn() - - torch.cuda.synchronize() - - t0 = torch.cuda.Event(enable_timing=True) - t1 = torch.cuda.Event(enable_timing=True) - t0.record() - for _ in range(iters): - g.replay() - t1.record() - torch.cuda.synchronize() - return t0.elapsed_time(t1) / iters - - -# --------------------------------------------------------------------------- -# sglang default BF16 config -# --------------------------------------------------------------------------- - - -def sglang_bf16_config(total_m, E): - """ - Static BF16 config mirroring sglang's get_default_config logic: - total_m <= E -> small-batch config (BLOCK_SIZE_M=16) - total_m > E -> regular config (BLOCK_SIZE_M=64) - """ - if total_m <= E: - return { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3, - } - else: - return { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - "num_warps": 4, - "num_stages": 3, - } - - -# --------------------------------------------------------------------------- -# Per-kernel bench helpers -# --------------------------------------------------------------------------- - - -def bench_hpc(x, gate_up_weight, down_weight, topk_ids, topk_scale, - num_experts_local, warmup, iters): - """Benchmark hpc.fuse_moe_bf16 (full pipeline).""" - # rank_ep=0, num_expert_total=num_experts_local: all experts are local - def fn(): - hpc.fuse_moe_bf16( - x, gate_up_weight, down_weight, - topk_ids, topk_scale, - rank_ep=0, - 
num_expert_total=num_experts_local, - ) - - return bench_cuda_graph(fn, warmup, iters) - - -def bench_sglang(x, gate_up_weight, down_weight, topk_ids, topk_scale, - num_experts_local, warmup, iters): - """ - Benchmark sglang full fused_moe pipeline (BF16) via direct kernel calls. - All steps are inside the CUDA graph for a fair comparison with hpc.fuse_moe_bf16: - 0. sgl_moe_align_block_size token sorting (= hpc count_and_gather) - 1. invoke_fused_moe_kernel gate_up GEMM (mul_routed_weight=False) - 2. silu_and_mul SiLU activation - 3. invoke_fused_moe_kernel down GEMM (mul_routed_weight=True, top_k=1) - 4. torch.sum over topk dim weighted reduce - """ - batch_size = x.shape[0] - hidden_size = x.shape[1] - inter_x2 = gate_up_weight.shape[1] # intermediate_size * 2 - inter = inter_x2 // 2 - topk = topk_ids.shape[1] - total_tokens = batch_size * topk - - config = sglang_bf16_config(total_tokens, num_experts_local) - block_size = config["BLOCK_SIZE_M"] - - # Pre-allocate routing buffers at max possible size (shapes are static for CUDA graph) - if topk_ids.numel() < num_experts_local + 1: - max_padded = topk_ids.numel() * block_size - else: - max_padded = topk_ids.numel() + (num_experts_local + 1) * (block_size - 1) - max_m_blocks = (max_padded + block_size - 1) // block_size - - sorted_ids = torch.empty((max_padded,), dtype=torch.int32, device=x.device) - expert_ids = torch.empty((max_m_blocks,), dtype=torch.int32, device=x.device) - num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=x.device) - cumsum_buf = torch.empty((num_experts_local + 2,), dtype=torch.int32, device=x.device) - - # Intermediate compute buffers (max padded size) - cache1 = torch.empty((max_padded, inter_x2), dtype=torch.bfloat16, device=x.device) - cache2 = torch.empty((max_padded, inter), dtype=torch.bfloat16, device=x.device) - cache3 = torch.empty((batch_size, topk, hidden_size), dtype=torch.bfloat16, device=x.device) - out = torch.empty((batch_size, hidden_size), dtype=torch.bfloat16, device=x.device) - - def fn(): - # 0. token sorting - sgl_moe_align_block_size( - topk_ids, num_experts_local + 1, block_size, - sorted_ids, expert_ids, num_tokens_post_pad, cumsum_buf, True, - ) - # 1. gate_up GEMM - invoke_fused_moe_kernel( - x, gate_up_weight, None, cache1, - None, None, None, - topk_scale, topk_ids, - sorted_ids, expert_ids, num_tokens_post_pad, - False, # mul_routed_weight - topk, # top_k - config, - compute_type=tl.bfloat16, - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False, - block_shape=None, - ) - # 2. SiLU activation: cache1 [max_padded, inter*2] -> cache2 [max_padded, inter] - silu_and_mul(cache1, cache2) - # 3. down GEMM – writes weighted results to cache3[batch, topk, hidden] - invoke_fused_moe_kernel( - cache2, down_weight, None, cache3, - None, None, None, - topk_scale, topk_ids, - sorted_ids, expert_ids, num_tokens_post_pad, - True, # mul_routed_weight (applies routing weights during scatter) - 1, # top_k=1: each sorted row maps to one (batch, topk_slot) in cache3 - config, - compute_type=tl.bfloat16, - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False, - block_shape=None, - ) - # 4. 
Sum-reduce over topk slots -> [batch, hidden] - torch.sum(cache3, dim=1, out=out) - - return bench_cuda_graph(fn, warmup, iters) - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - - -def main(): - args = parse_args() - tp_size = args.tp_size - hidden_size = args.hidden_size - intermediate_size = args.intermediate_size // tp_size - num_experts = args.num_experts - topk = args.topk - num_experts_local = num_experts - - torch.manual_seed(0) - torch.cuda.manual_seed(0) - - prop = torch.cuda.get_device_properties(0) - print(f"\nDevice : {prop.name}") - print(f"Model config : hidden={hidden_size}, inter={intermediate_size}, " - f"experts={num_experts}, topk={topk}, tp={tp_size}") - print(f"Local experts/GPU : {num_experts_local}") - print(f"Weight shapes : gate_up=[{num_experts_local}, {intermediate_size*2}, {hidden_size}], " - f"down=[{num_experts_local}, {hidden_size}, {intermediate_size}]") - print(f"Timing : warmup={args.warmup}, iters={args.iters} (CUDA graph replay)\n") - - hdr = ( - f"{'batch':>7} {'tokens':>8} " - f"{'hpc(ms)':>10} {'hpc(TF)':>9} " - f"{'sgl(ms)':>10} {'sgl(TF)':>9} " - f"{'speedup':>8}" - ) - sep = "-" * len(hdr) - print(hdr) - print(sep) - - for bs in BATCH_SIZES: - x, gate_up_weight, down_weight, topk_ids, topk_scale = make_inputs( - bs, hidden_size, intermediate_size, num_experts_local, topk - ) - - hpc_ms = bench_hpc( - x, gate_up_weight, down_weight, topk_ids, topk_scale, - num_experts_local, args.warmup, args.iters - ) - sgl_ms = bench_sglang( - x, gate_up_weight, down_weight, topk_ids, topk_scale, - num_experts_local, args.warmup, args.iters - ) - - hpc_tf = tflops(bs, topk, hidden_size, intermediate_size, hpc_ms) - sgl_tf = tflops(bs, topk, hidden_size, intermediate_size, sgl_ms) - - print( - f"{bs:>7} {bs*topk:>8} " - f"{hpc_ms:>10.4f} {hpc_tf:>9.2f} " - f"{sgl_ms:>10.4f} {sgl_tf:>9.2f} " - f"{sgl_ms/hpc_ms:>8.2f}x" - ) - - print(sep) - print() - - -if __name__ == "__main__": - main() diff --git a/tests/bench_group_gemm_bf16.py b/tests/bench_group_gemm_bf16.py deleted file mode 100644 index 5061a03..0000000 --- a/tests/bench_group_gemm_bf16.py +++ /dev/null @@ -1,257 +0,0 @@ -""" -Benchmark script for hpc.group_gemm_bf16 vs sglang fused_moe_kernel (BF16). 
- -Usage: - python tests/bench_group_gemm_bf16.py - python tests/bench_group_gemm_bf16.py --E 128 --N 384 --K 4096 - python tests/bench_group_gemm_bf16.py --E 8 --N 4096 --K 7168 --warmup 10 --iters 100 -""" - -import argparse -import os -import sys -from pathlib import Path - -sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) - -import torch -import triton.language as tl -import hpc -from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import invoke_fused_moe_kernel -from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import moe_align_block_size - -# --------------------------------------------------------------------------- -# CLI -# --------------------------------------------------------------------------- - - -def parse_args(): - p = argparse.ArgumentParser() - p.add_argument("--E", type=int, default=8, help="Number of groups (experts)") - p.add_argument("--N", type=int, default=4096, help="Output dimension per group") - p.add_argument("--K", type=int, default=7168, help="Input dimension (hidden size)") - p.add_argument("--warmup", type=int, default=5, help="Warmup iterations before graph capture") - p.add_argument("--iters", type=int, default=100, help="Graph replay iterations for timing") - return p.parse_args() - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -M_VALUES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] - - -def make_inputs(E, m_per_group, N, K, device="cuda"): - dtype = torch.bfloat16 - total_m = E * m_per_group - x = torch.randn((total_m, K), dtype=dtype, device=device) - w = torch.randn((E, N, K), dtype=dtype, device=device) - seqlens = torch.full((E,), m_per_group, dtype=torch.int32, device=device) - cu_seqlens = torch.zeros(E + 1, dtype=torch.int32, device=device) - cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) - output = torch.empty((total_m, N), dtype=dtype, device=device) - return x, w, seqlens, cu_seqlens, output - - -def tflops(E, m_per_group, N, K, elapsed_ms): - flops = 2 * E * m_per_group * N * K - return flops / (elapsed_ms * 1e-3) / 1e12 - - -# --------------------------------------------------------------------------- -# sglang default BF16 config -# (mirrors get_default_config logic for dtype=None, no server args needed) -# --------------------------------------------------------------------------- - - -def sglang_bf16_config(total_m, E): - """ - Replicates sglang's get_default_config logic for plain BF16 (dtype=None). - M <= E -> small-batch config (BLOCK_SIZE_M=16) - M > E -> regular config (BLOCK_SIZE_M=64) - num_warps / num_stages are typical Hopper-friendly defaults. 
- """ - if total_m <= E: - return { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3, - } - else: - return { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - "num_warps": 4, - "num_stages": 3, - } - - -# --------------------------------------------------------------------------- -# Benchmark with CUDA graph -# --------------------------------------------------------------------------- - - -def bench_cuda_graph(fn, warmup, iters): - """Warm up, capture a CUDA graph, replay `iters` times, return avg ms.""" - for _ in range(warmup): - fn() - torch.cuda.synchronize() - - g = torch.cuda.CUDAGraph() - capture_stream = torch.cuda.Stream() - with torch.cuda.stream(capture_stream): - fn() # dry run on capture stream - torch.cuda.synchronize() - with torch.cuda.graph(g, stream=capture_stream): - fn() - - torch.cuda.synchronize() - - t0 = torch.cuda.Event(enable_timing=True) - t1 = torch.cuda.Event(enable_timing=True) - t0.record() - for _ in range(iters): - g.replay() - t1.record() - torch.cuda.synchronize() - return t0.elapsed_time(t1) / iters - - -# --------------------------------------------------------------------------- -# Per-kernel bench helpers -# --------------------------------------------------------------------------- - - -def bench_hpc(x, w, seqlens, cu_seqlens, output, mean_seq, warmup, iters): - def fn(): - hpc.group_gemm_bf16( - x, w, seqlens, cu_seqlens, num_seq_per_group_avg=mean_seq, output=output - ) - - return bench_cuda_graph(fn, warmup, iters) - - -def bench_torch_ref(x, w, seqlens, cu_seqlens, output, warmup, iters): - E = seqlens.shape[0] - slices = [(int(cu_seqlens[i].item()), int(cu_seqlens[i + 1].item())) for i in range(E)] - - def fn(): - for i, (s, e) in enumerate(slices): - output[s:e] = torch.matmul(x[s:e], w[i].t()) - - return bench_cuda_graph(fn, warmup, iters) - - -def bench_sglang(x, w, E, m_per_group, N, warmup, iters): - """ - Benchmark sglang's invoke_fused_moe_kernel for a single up/gate projection - (top_k=1, each token routes to its group, no routed-weight multiply). - - moe_align_block_size is called once outside the graph; only the GEMM - kernel itself is captured and replayed. 
- """ - total_m = E * m_per_group - config = sglang_bf16_config(total_m, E) - - # top_k=1: token i belongs to expert floor(i / m_per_group) - topk_ids = ( - (torch.arange(total_m, device="cuda") // m_per_group).view(-1, 1).to(torch.int32) - ) # [total_m, 1] - topk_weights = torch.ones(total_m, 1, dtype=torch.bfloat16, device="cuda") - - # Routing tensors – computed once, reused by every graph replay - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config["BLOCK_SIZE_M"], E - ) - - # Output shape follows sglang convention: [padded_tokens, N] - sgl_out = torch.empty((sorted_token_ids.shape[0], N), dtype=torch.bfloat16, device="cuda") - - def fn(): - invoke_fused_moe_kernel( - A=x, - B=w, - bias=None, - C=sgl_out, - A_scale=None, - B_scale=None, - B_zp=None, - topk_weights=topk_weights, - topk_ids=topk_ids, - sorted_token_ids=sorted_token_ids, - expert_ids=expert_ids, - num_tokens_post_padded=num_tokens_post_padded, - mul_routed_weight=False, - top_k=1, - config=config, - compute_type=tl.bfloat16, - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False, - block_shape=None, - ) - - return bench_cuda_graph(fn, warmup, iters) - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - - -def main(): - args = parse_args() - E, N, K = args.E, args.N, args.K - - torch.manual_seed(0) - torch.cuda.manual_seed(0) - - prop = torch.cuda.get_device_properties(0) - print(f"\nDevice : {prop.name}") - print(f"Config : E={E}, N={N}, K={K}") - print(f"Timing : warmup={args.warmup}, iters={args.iters} (CUDA graph replay)\n") - - hdr = ( - f"{'M/group':>8} {'total_M':>8} " - f"{'hpc(ms)':>10} {'hpc(TF)':>9} " - f"{'sgl(ms)':>10} {'sgl(TF)':>9} " - f"{'ref(ms)':>10} {'ref(TF)':>9} " - f"{'hpc/sgl':>8}" - ) - sep = "-" * len(hdr) - print(hdr) - print(sep) - - for m in M_VALUES: - x, w, seqlens, cu_seqlens, output = make_inputs(E, m, N, K) - - hpc_ms = bench_hpc(x, w, seqlens, cu_seqlens, output, m, args.warmup, args.iters) - sgl_ms = bench_sglang(x, w, E, m, N, args.warmup, args.iters) - ref_ms = bench_torch_ref(x, w, seqlens, cu_seqlens, output, args.warmup, args.iters) - - hpc_tf = tflops(E, m, N, K, hpc_ms) - sgl_tf = tflops(E, m, N, K, sgl_ms) - ref_tf = tflops(E, m, N, K, ref_ms) - - print( - f"{m:>8} {E*m:>8} " - f"{hpc_ms:>10.4f} {hpc_tf:>9.2f} " - f"{sgl_ms:>10.4f} {sgl_tf:>9.2f} " - f"{ref_ms:>10.4f} {ref_tf:>9.2f} " - f"{sgl_ms/hpc_ms:>8.2f}x" - ) - - print(sep) - print() - - -if __name__ == "__main__": - main() From 4fc4d799e9692894831b47fd6a1d9aa04644e4fd Mon Sep 17 00:00:00 2001 From: ZelinMa557 <3388706467@qq.com> Date: Tue, 28 Apr 2026 16:14:33 +0800 Subject: [PATCH 6/7] remove additional CUDA core precision correction Signed-off-by: ZelinMa557 <3388706467@qq.com> --- src/group_gemm/kernels.cuh | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/group_gemm/kernels.cuh b/src/group_gemm/kernels.cuh index 6b054f1..9c3dbda 100644 --- a/src/group_gemm/kernels.cuh +++ b/src/group_gemm/kernels.cuh @@ -953,15 +953,12 @@ __global__ void __launch_bounds__(384, 1) iblock += gridDim.x; - auto tDr = make_tensor_like(tCr); - clear(tDr); - int ntile_k = size<2>(tAg); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; #pragma unroll 1 for (int itile_k = 0; itile_k < ntile_k; ++itile_k) { wait_barrier(readable[ismem_read], phase); - tiled_mma.accumulate_ = 
GMMA::ScaleOut::Zero; // mma warpgroup_fence_operand(tCr); warpgroup_arrive(); @@ -979,11 +976,6 @@ __global__ void __launch_bounds__(384, 1) arrive_barrier(writable[ismem_read]); } -#pragma unroll - for (int i = 0; i < size(tCr); ++i) { - tDr(i) = tCr(i) + tDr(i); - } - ++ismem_read; if (ismem_read == kStage) { phase ^= 1; @@ -996,7 +988,7 @@ __global__ void __launch_bounds__(384, 1) #pragma unroll for (int i = 0; i < size(tCr); ++i) { - tCrh(i) = (Tout)(tDr(i)); + tCrh(i) = (Tout)(tCr(i)); } // Epilogue From 1c8e9f1de8a69e9c86bbc9aef1f61c8a8a0ed7b0 Mon Sep 17 00:00:00 2001 From: ZelinMa557 <3388706467@qq.com> Date: Sat, 9 May 2026 17:17:31 +0800 Subject: [PATCH 7/7] better wgmma pipeline for bf16 group gemm Signed-off-by: ZelinMa557 <3388706467@qq.com> --- src/group_gemm/kernels.cuh | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/group_gemm/kernels.cuh b/src/group_gemm/kernels.cuh index 9c3dbda..84973e4 100644 --- a/src/group_gemm/kernels.cuh +++ b/src/group_gemm/kernels.cuh @@ -955,27 +955,38 @@ __global__ void __launch_bounds__(384, 1) int ntile_k = size<2>(tAg); tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + // Pipelined K-loop: keep 1 WGMMA batch in flight (wait<1> instead of + // wait<0> per iter). The writable barrier of stage k is released only + // after wait<1> in iter k+1, so the MMA of iter k+1 can overlap with + // the producer's reload of stage k. The accumulator tCr only needs to + // be fenced once before the first WGMMA and once after the last drain. + int prev_ismem_read = 0; + warpgroup_fence_operand(tCr); #pragma unroll 1 for (int itile_k = 0; itile_k < ntile_k; ++itile_k) { wait_barrier(readable[ismem_read], phase); // mma - warpgroup_fence_operand(tCr); warpgroup_arrive(); #pragma unroll for (int ik = 0; ik < size<2>(tAr); ++ik) { cute::gemm(tiled_mma, tBr(_, _, ik, ismem_read), tAr(_, _, ik, ismem_read), tCr(_, _, _)); tiled_mma.accumulate_ = GMMA::ScaleOut::One; } - warpgroup_commit_batch(); - warpgroup_wait<0>(); - warpgroup_fence_operand(tCr); - if (elected_idx_in_warpgroup) { - arrive_barrier(writable[ismem_read]); + if (itile_k >= 1) { + // Previous batch is now done; release the stage it consumed so + // the producer can refill it while the current batch is still + // in flight. + warpgroup_wait<1>(); + if (elected_idx_in_warpgroup) { + arrive_barrier(writable[prev_ismem_read]); + } } + prev_ismem_read = ismem_read; ++ismem_read; if (ismem_read == kStage) { phase ^= 1; @@ -983,6 +994,13 @@ __global__ void __launch_bounds__(384, 1) } } + // Drain the last in-flight batch and release its stage. + warpgroup_wait<0>(); + warpgroup_fence_operand(tCr); + if (elected_idx_in_warpgroup) { + arrive_barrier(writable[prev_ismem_read]); + } + // float32 -> bfloat16 auto tCrh = make_tensor_like(tCr);
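
Note on the PATCH 7 pipelining idiom: the comments above describe keeping one
WGMMA group in flight so the producer can refill a shared-memory stage while
the next group's MMAs are issuing. A minimal consumer-side sketch of that
idiom follows. It assumes CuTe's SM90 warpgroup primitives and reuses the
kernel's own mbarrier helpers (wait_barrier / arrive_barrier, readable /
writable); the helper name pipelined_k_loop and the explicit
warpgroup_commit_batch() per iteration are illustrative assumptions for the
sketch, not code taken from this patch.

    // Illustrative sketch only (assumes CUTLASS/CuTe 3.x SM90 headers and the
    // same barrier helpers as kernels.cuh); not a drop-in replacement for the
    // consumer loop in this patch.
    #include <cute/tensor.hpp>
    using namespace cute;

    template <int kStage, class TiledMma, class TA, class TB, class TC,
              class Barriers>
    __device__ void pipelined_k_loop(TiledMma& tiled_mma,
                                     TA const& tAr, TB const& tBr, TC& tCr,
                                     Barriers& readable, Barriers& writable,
                                     int ntile_k, int phase, bool elected) {
      int ismem_read = 0;
      int prev_stage = 0;

      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
      warpgroup_fence_operand(tCr);        // fence the accumulator once, up front

      for (int itile_k = 0; itile_k < ntile_k; ++itile_k) {
        wait_barrier(readable[ismem_read], phase);   // stage filled by producer

        warpgroup_arrive();
        for (int ik = 0; ik < size<2>(tAr); ++ik) {
          cute::gemm(tiled_mma, tBr(_, _, ik, ismem_read),
                     tAr(_, _, ik, ismem_read), tCr);
          tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
        warpgroup_commit_batch();          // close this iteration's WGMMA group

        if (itile_k >= 1) {
          warpgroup_wait<1>();             // only the current group may remain in flight
          if (elected) arrive_barrier(writable[prev_stage]);  // hand stage back
        }

        prev_stage = ismem_read;
        if (++ismem_read == kStage) { ismem_read = 0; phase ^= 1; }
      }

      warpgroup_wait<0>();                 // drain the last group
      warpgroup_fence_operand(tCr);
      if (elected) arrive_barrier(writable[prev_stage]);
    }

warpgroup_wait<1> leaves exactly one committed group outstanding, so the stage
consumed by iteration k can be released to the TMA producer while iteration
k+1's MMAs are still issuing; the final warpgroup_wait<0> drains everything
before the accumulator is converted to bf16 in the epilogue.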