hpc/group_gemm.py (51 additions, 0 deletions)
@@ -152,6 +152,52 @@ def group_gemm_blockwise_fp8(
)


def group_gemm_bf16(
x: Tensor,
weight: Tensor,
seqlens: Tensor,
cu_seqlens: Tensor,
num_seq_per_group_avg: int = 32,
output: Tensor = None,
tma_desc: Tensor = None,
) -> Tensor:
"""Performs group GEMM operation with BF16 precision.

This function executes multiple matrix multiplications in a group manner
using BF16 precision for improved performance.

Args:
x: Input activation tensor
Shape: [total_seq, hidden_size]
Dtype: bfloat16
weight: Weight tensor for group matrix multiplication
Shape: [num_group, output_dim, hidden_size]
Dtype: bfloat16
seqlens: Sequence lengths for each group
Shape: [num_group]
Dtype: int32
        cu_seqlens: Cumulative sequence lengths giving each group's start offset in x
            Shape: [num_group + 1]
            Dtype: int32
        num_seq_per_group_avg: Expected average number of sequences per group;
            used to select the kernel tile size along M. Defaults to 32.
        output: Optional preallocated output tensor of shape [total_seq, output_dim]
            and dtype bfloat16. A new tensor is allocated when None.
        tma_desc: Optional tensor holding precomputed TMA descriptors. When
            provided, descriptor initialization is skipped, so it must already
            contain valid descriptors for the current shapes.

Returns:
Tensor: Output tensor after group matrix multiplication
Shape: [total_seq, output_dim]
Dtype: bfloat16

Raises:
RuntimeError: If the input tensors have incompatible shapes or types,
or if the CUDA kernel execution fails.

    Note:
        - All input tensors must be on a CUDA device
        - x and weight must be contiguous
"""
return torch.ops.hpc.group_gemm_bf16(
x, weight, seqlens, cu_seqlens, num_seq_per_group_avg, output, tma_desc
)


@torch.library.register_fake("hpc::group_gemm_pertensor_fp8")
def group_gemm_pertensor_fp8_fake(
x, weight, seqlens, cu_seqlens, y_scale, num_seq_per_group_avg, output, tma_des
@@ -164,3 +210,8 @@ def group_gemm_blockwise_fp8_fake(
x, weight, seqlens, cu_seqlens, x_scale, w_scale, num_seq_per_group_avg, output, tma_des
):
return torch.empty((x.shape[0], weight.shape[1]), dtype=torch.bfloat16)


@torch.library.register_fake("hpc::group_gemm_bf16")
def group_gemm_bf16_fake(x, weight, seqlens, cu_seqlens, num_seq_per_group_avg, output, tma_desc):
    return torch.empty((x.shape[0], weight.shape[1]), dtype=torch.bfloat16, device=x.device)
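For reference, a minimal usage sketch of the new op. It assumes the compiled `hpc` extension is importable as `hpc.group_gemm` (matching the file path in this diff), a Hopper GPU, and illustrative sizes; the loose tolerances account for BF16 rounding of the outputs.

import torch

from hpc.group_gemm import group_gemm_bf16  # assumed import path

device = "cuda"
num_group, hidden_size, output_dim = 4, 2048, 4096

# Variable-length row segments of x; one weight matrix per group.
seqlens = torch.tensor([5, 17, 64, 2], dtype=torch.int32, device=device)
cu_seqlens = torch.cat([
    torch.zeros(1, dtype=torch.int32, device=device),
    torch.cumsum(seqlens, dim=0).to(torch.int32),
])
total_seq = int(seqlens.sum())

x = torch.randn(total_seq, hidden_size, dtype=torch.bfloat16, device=device)
weight = torch.randn(num_group, output_dim, hidden_size,
                     dtype=torch.bfloat16, device=device)

y = group_gemm_bf16(x, weight, seqlens, cu_seqlens,
                    num_seq_per_group_avg=total_seq // num_group)

# Reference: an independent matmul over each segment.
cu = cu_seqlens.tolist()
ref = torch.cat([x[cu[g]:cu[g + 1]] @ weight[g].T for g in range(num_group)])
torch.testing.assert_close(y, ref, rtol=2e-2, atol=2e-2)

The `register_fake` entry above gives the op a shape- and dtype-only implementation, so graphs containing it can be traced by `torch.compile` without launching the CUDA kernel.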
src/group_gemm/config.h (69 additions, 0 deletions)
@@ -165,6 +165,75 @@ struct GroupGEMMBlockWiseFp8Config {
auto get_shm_size() { return shm_size; }
};

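// Note: kTileM here selects the *N* extent of the SM90 GMMA atom. The output
// tile is staged n-major as [kTileN, kTileM] (see SLayoutY below), so the
// atom's fixed 64-wide M dimension spans kTileN across the kWarpgroupM MMA
// warpgroups.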
template <int kTileM>
static constexpr auto mma_selector_bf16() {
constexpr auto MajorA = GMMA::Major::K;
constexpr auto MajorB = GMMA::Major::K;
if constexpr (kTileM == 8) {
return cute::SM90_64x8x16_F32BF16BF16_SS<MajorA, MajorB>{};
} else if constexpr (kTileM == 16) {
return cute::SM90_64x16x16_F32BF16BF16_SS<MajorA, MajorB>{};
} else if constexpr (kTileM == 32) {
return cute::SM90_64x32x16_F32BF16BF16_SS<MajorA, MajorB>{};
} else if constexpr (kTileM == 48) {
return cute::SM90_64x48x16_F32BF16BF16_SS<MajorA, MajorB>{};
} else if constexpr (kTileM == 64) {
return cute::SM90_64x64x16_F32BF16BF16_SS<MajorA, MajorB>{};
} else if constexpr (kTileM == 96) {
return cute::SM90_64x96x16_F32BF16BF16_SS<MajorA, MajorB>{};
} else if constexpr (kTileM == 128) {
return cute::SM90_64x128x16_F32BF16BF16_SS<MajorA, MajorB>{};
} else {
return cute::SM90_64x64x16_F32BF16BF16_SS<MajorA, MajorB>{};
}
}

template <typename Tin_, typename Tout_, int kTileM_, int kTileN_, int kTileK_, int kStage_,
int kWarpgroupM_ = 2, int kWarpgroupN_ = 1, int kSwizzleX = 128, int kSwizzleW = 128,
int kSwizzleY = 128>
struct GroupGEMMBF16Config {
using Tin = Tin_;
using Tout = Tout_;

static constexpr int kTileM = kTileM_;
static constexpr int kTileN = kTileN_;
static constexpr int kTileK = kTileK_;
static constexpr int kStage = kStage_;
static constexpr int kWarpgroupM = kWarpgroupM_;
static constexpr int kWarpgroupN = kWarpgroupN_;

using SLayoutXAtom = decltype(slayout_selector<kSwizzleX, Tin>());
using SLayoutWAtom = decltype(slayout_selector<kSwizzleW, Tin>());
using SLayoutYAtom = decltype(slayout_selector<kSwizzleY, Tout, false>());

using SLayoutX = decltype(tile_to_shape(SLayoutXAtom{},
make_shape(Int<kTileM>{}, Int<kTileK>{}, Int<kStage>{})));
using SLayoutW = decltype(tile_to_shape(SLayoutWAtom{},
make_shape(Int<kTileN>{}, Int<kTileK>{}, Int<kStage>{})));
using SLayoutY =
decltype(tile_to_shape(SLayoutYAtom{}, make_shape(Int<kTileN>{}, Int<kTileM>{})));
using CopyBoxY = decltype(tile_to_shape(SLayoutYAtom{},
make_shape(Int<kTileN / kWarpgroupM>{}, Int<kTileM>{})));

template <typename TX, typename TW, typename TY>
auto get_tma(TX x, TW w, TY y) {
auto tma_x = make_tma_copy(SM90_TMA_LOAD{}, x, take<0, 2>(SLayoutX{}));
auto tma_w = make_tma_copy(SM90_TMA_LOAD{}, w, take<0, 2>(SLayoutW{}));
auto tma_y = make_tma_copy(SM90_TMA_STORE{}, y, CopyBoxY{});
return std::make_tuple(tma_x, tma_w, tma_y);
}

using WarpgroupLayout =
decltype(make_layout(make_shape(Int<kWarpgroupM>{}, Int<kWarpgroupN>{}, Int<1>{})));
using TiledMma = decltype(make_tiled_mma(mma_selector_bf16<kTileM>(), WarpgroupLayout{}));

static constexpr int shm_xw = (cosize(SLayoutX{}) + cosize(SLayoutW{})) * sizeof(Tin);
static constexpr int shm_y = cosize(SLayoutY{}) * sizeof(Tout);
static constexpr int shm_size = shm_xw + shm_y;

auto get_shm_size() { return shm_size; }
};

} // namespace group_gemm
} // namespace hpc

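As a sanity check on the configuration's shared-memory footprint, the following back-of-envelope sketch mirrors the `shm_xw + shm_y` computation above. It assumes `cosize(SLayout*)` equals the plain tile footprint (no padding from the swizzle atoms) and `sizeof(bf16) == 2`; the launcher in group_gemm_bf16.cu additionally reserves `sizeof(int) * (num_group + 1)` bytes of sequence metadata on top of this.

# Mirrors GroupGEMMBF16Config::shm_size for the tile shapes dispatched in
# group_gemm_bf16_async (kTileN=128, kTileK=64, kStage=8, bf16 elements).
def shm_bytes(tile_m, tile_n=128, tile_k=64, stage=8, elt=2):
    shm_x = tile_m * tile_k * stage * elt  # staged X tiles
    shm_w = tile_n * tile_k * stage * elt  # staged W tiles
    shm_y = tile_n * tile_m * elt          # one n-major Y tile
    return shm_x + shm_w + shm_y

for tile_m in (16, 32, 48, 64):
    print(tile_m, shm_bytes(tile_m))  # 151552, 172032, 192512, 212992 bytes

All four configurations stay under the 227 KB (232448-byte) dynamic shared-memory cap of sm_90, which is why the launcher raises `cudaFuncAttributeMaxDynamicSharedMemorySize` before each launch.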
src/group_gemm/entry.cc (66 additions, 1 deletion)
@@ -6,7 +6,7 @@
#include <torch/library.h>

#include <tuple>

#include "src/group_gemm/group_gemm.h"

namespace hpc {
@@ -137,6 +137,65 @@ torch::Tensor group_gemm_blockwise_fp8_entry(
return y;
}

torch::Tensor group_gemm_bf16_entry(const torch::Tensor &x, const torch::Tensor &weight,
const torch::Tensor &seqlens,
const torch::Tensor &cu_seqlens,
const int64_t num_seq_per_group_avg,
std::optional<torch::Tensor> output,
std::optional<torch::Tensor> tma_desc) {
auto stream = at::cuda::getCurrentCUDAStream(x.get_device());
TORCH_CHECK(x.device().is_cuda(), "x tensor must be cuda");
TORCH_CHECK(weight.device().is_cuda(), "weight tensor must be cuda");
TORCH_CHECK(seqlens.device().is_cuda(), "seqlens tensor must be cuda");
TORCH_CHECK(cu_seqlens.device().is_cuda(), "cu_seqlens tensor must be cuda");
  TORCH_CHECK(x.is_contiguous(), "x tensor must be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight tensor must be contiguous");
TORCH_CHECK(seqlens.size(0) == weight.size(0),
"seqlens and weight must share the same num_group");
TORCH_CHECK(x.size(1) == weight.size(2), "x and weight must share the same k");

int m = x.size(0);
int k = x.size(1);
int n = weight.size(1);
int num_group = seqlens.size(0);

auto options = x.options();
torch::Tensor y;
if (output.has_value()) {
y = output.value();
} else {
y = torch::empty({m, n}, options.dtype(torch::kBFloat16));
}

torch::Tensor tmas;
bool update_tma = true;
if (tma_desc.has_value()) {
tmas = tma_desc.value();
update_tma = false;
} else {
tmas = torch::empty({num_group * 2, 128}, options);
}

torch::Tensor tiles = torch::empty({num_group}, options.dtype(torch::kInt32));
torch::Tensor cu_tiles = torch::empty({num_group + 1}, options.dtype(torch::kInt32));

const auto *x_ptr = x.const_data_ptr();
const auto *weight_ptr = weight.const_data_ptr();
const auto *seqlens_ptr = seqlens.const_data_ptr();
const auto *cu_seqlens_ptr = cu_seqlens.const_data_ptr();
auto *tmas_ptr = tmas.mutable_data_ptr();
auto *y_ptr = y.mutable_data_ptr();

auto *tiles_ptr = tiles.mutable_data_ptr();
auto *cu_tiles_ptr = cu_tiles.mutable_data_ptr();

group_gemm_bf16_async(y_ptr, x_ptr, weight_ptr, seqlens_ptr, cu_seqlens_ptr,
tmas_ptr, tiles_ptr, cu_tiles_ptr, num_group, m, n, k,
num_seq_per_group_avg, update_tma, stream);

return y;
}

torch::Tensor reformat_x_scale_entry(const torch::Tensor &x_scale, const torch::Tensor &seqlens,
const torch::Tensor &cu_seqlens,
std::optional<torch::Tensor> out_x_scale,
@@ -197,6 +256,12 @@ TORCH_LIBRARY_FRAGMENT(hpc, m) {
m.impl("group_gemm_pertensor_fp8", torch::kCUDA,
&hpc::group_gemm::group_gemm_pertensor_fp8_entry);

m.def(
"group_gemm_bf16(Tensor x, Tensor weight, Tensor seqlens, Tensor cu_seqlens, "
"int num_seq_per_group_avg, Tensor? output, Tensor? tma_desc) -> (Tensor)");
m.impl("group_gemm_bf16", torch::kCUDA,
&hpc::group_gemm::group_gemm_bf16_entry);

m.def(
"group_gemm_blockwise_fp8(Tensor x, Tensor weight, Tensor seqlens, Tensor cu_seqlens, Tensor "
"xscale, Tensor wscale,"
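Two usage notes on the optional arguments, with a small sketch below (`y_buf` is a hypothetical name; the other tensors are from the earlier usage sketch). `output`, when supplied, is written in place, which lets callers reuse one buffer across steps. `tma_desc` is stricter than it may look: supplying it sets `update_tma = false`, so the tensor must already hold valid descriptors for the current shapes, and since the entry point never returns the internally built descriptors, they cannot be captured from a first call through this API alone.

# Reusing a preallocated output buffer across calls.
y_buf = torch.empty(total_seq, output_dim, dtype=torch.bfloat16, device=device)
y = group_gemm_bf16(x, weight, seqlens, cu_seqlens,
                    num_seq_per_group_avg=32, output=y_buf)
assert y.data_ptr() == y_buf.data_ptr()  # the kernel wrote into y_buf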
src/group_gemm/group_gemm.h (6 additions, 0 deletions)
@@ -28,6 +28,12 @@ void reformat_x_scale_async(void *output_ptr, const void *xscale_ptr, const void
const void *cu_seqlens_ptr, int num_group, int m, int n, int tilem,
cudaStream_t stream);

void group_gemm_bf16_async(void *y_ptr, const void *x_ptr, const void *w_ptr,
const void *seqlens_ptr, const void *cu_seqlens_ptr,
void *tmas_ptr, void *tiles_ptr,
void *cu_tiles_ptr, int num_group, int m, int n, int k,
int num_seq_per_group_avg, bool update_tma,
cudaStream_t stream);
} // namespace group_gemm
} // namespace hpc

src/group_gemm/group_gemm_bf16.cu (138 additions, 0 deletions)
@@ -0,0 +1,138 @@
// Copyright (C) 2026 Tencent.

#include <cuda.h>
#include <stdio.h>

#include <cub/cub.cuh>
#include "cute/tensor.hpp"
#include "src/group_gemm/config.h"
#include "src/group_gemm/group_gemm.h"
#include "src/group_gemm/kernels.cuh"

namespace hpc {
namespace group_gemm {

template <int kTileM, int kTileN, int kTileK, int kStage, int kWarpgroupM, int kWarpgroupN,
int kSwizzleX, int kSwizzleW, int kSwizzleY>
void launch_group_gemm_bf16(void *y_ptr, const void *x_ptr, const void *w_ptr,
const void *seqlens_ptr, const void *cu_seqlens_ptr,
void *tmas_ptr, void *tiles_ptr, void *cu_tiles_ptr, int num_group,
int m, int n, int k, bool update_tma, cudaStream_t stream) {
using namespace cute; // NOLINT

using Tin = cute::bfloat16_t;
using Tout = cute::bfloat16_t;

auto X = make_tensor(make_gmem_ptr(reinterpret_cast<const Tin *>(x_ptr)), make_shape(m, k),
make_stride(k, Int<1>{}));
auto W = make_tensor(make_gmem_ptr(reinterpret_cast<const Tin *>(w_ptr)),
make_shape(n, k, num_group), make_stride(k, Int<1>{}, n * k));
auto Y = make_tensor(make_gmem_ptr(reinterpret_cast<Tout *>(y_ptr)), make_shape(n, m),
make_stride(Int<1>{}, n));

using Config = GroupGEMMBF16Config<Tin, Tout, kTileM, kTileN, kTileK, kStage, kWarpgroupM,
kWarpgroupN, kSwizzleX, kSwizzleW, kSwizzleY>;
Config config;
auto [tma_x, tma_w, tma_y] = config.get_tma(X, W, Y);

auto *tma_xy = static_cast<cute::TmaDescriptor *>(tmas_ptr);

  // 0. build per-group TMA descriptors and tile counts
  //    (skipped when the caller supplies cached descriptors via tma_desc)
if (update_tma) {
vec_t<cute::TmaDescriptor, 2> td_xy{
*tma_x.get_tma_descriptor(),
*tma_y.get_tma_descriptor(),
};

constexpr int kGroupPerThread = 8;
constexpr int kThreadPerBlock = 32;
kernels::update_grouped_tma<Tin, Tout, decltype(tma_x), decltype(tma_y), kTileM,
kGroupPerThread, kThreadPerBlock>
<<<num_group + 1, kThreadPerBlock, 0, stream>>>(
td_xy, tma_xy, (const Tin *)x_ptr, (const Tout *)y_ptr, (const int *)seqlens_ptr,
(const int *)cu_seqlens_ptr, (int *)tiles_ptr, (int *)cu_tiles_ptr, num_group, m, n, k);
}

  // 1. group gemm: persistent kernel (grid = SM count); the flat tile index is
  //    split by num_tile_n via flat_divider
{
int num_tile_n = (n + kTileN - 1) / kTileN;
cutlass::FastDivmod flat_divider(num_tile_n);

    // 384 threads: size(Config::TiledMma{}) = 256 (two MMA warpgroups, given
    // kWarpgroupM = 2) plus an extra 128-thread warpgroup feeding the TMA pipeline.
    dim3 block(384);
dim3 grid(get_sm_count());

int shm_seq = sizeof(int) * (num_group + 1);
int shm_size = config.get_shm_size() + shm_seq;

if (k <= 1024 || n <= 1024) {
constexpr bool IsLoopH = true;
auto kernel =
kernels::group_gemm_bf16_kernel<decltype(config), decltype(tma_x),
decltype(tma_w), decltype(tma_y), IsLoopH>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);

kernel<<<grid, block, shm_size, stream>>>(tma_w, tma_xy, (int *)seqlens_ptr,
(int *)tiles_ptr, (int *)cu_tiles_ptr, num_group, m,
n, k, flat_divider);
} else {
constexpr bool IsLoopH = false;
auto kernel =
kernels::group_gemm_bf16_kernel<decltype(config), decltype(tma_x),
decltype(tma_w), decltype(tma_y), IsLoopH>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);

kernel<<<grid, block, shm_size, stream>>>(tma_w, tma_xy, (int *)seqlens_ptr,
(int *)tiles_ptr, (int *)cu_tiles_ptr, num_group, m,
n, k, flat_divider);
}
}
}

void group_gemm_bf16_async(void *y_ptr, const void *x_ptr, const void *w_ptr,
const void *seqlens_ptr, const void *cu_seqlens_ptr,
void *tmas_ptr, void *tiles_ptr,
void *cu_tiles_ptr, int num_group, int m, int n, int k,
int num_seq_per_group_avg, bool update_tma,
cudaStream_t stream) {
constexpr int kTileN = 128;
constexpr int kTileK = 64;
constexpr int kWarpgroupM = 2;
constexpr int kWarpgroupN = 1;
constexpr int kSwizzleX = 128;
constexpr int kSwizzleW = 128;
constexpr int kSwizzleY = 64;
if (num_seq_per_group_avg <= 16) {
constexpr int kTileM = 16;
constexpr int kStage = 8;
launch_group_gemm_bf16<kTileM, kTileN, kTileK, kStage, kWarpgroupM, kWarpgroupN, kSwizzleX,
kSwizzleW, kSwizzleY>(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr,
tmas_ptr, tiles_ptr, cu_tiles_ptr,
num_group, m, n, k, update_tma, stream);
} else if (num_seq_per_group_avg <= 32) {
constexpr int kTileM = 32;
constexpr int kStage = 8;
launch_group_gemm_bf16<kTileM, kTileN, kTileK, kStage, kWarpgroupM, kWarpgroupN, kSwizzleX,
kSwizzleW, kSwizzleY>(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr,
tmas_ptr, tiles_ptr, cu_tiles_ptr,
num_group, m, n, k, update_tma, stream);
} else if (num_seq_per_group_avg <= 48) {
constexpr int kTileM = 48;
constexpr int kStage = 8;
launch_group_gemm_bf16<kTileM, kTileN, kTileK, kStage, kWarpgroupM, kWarpgroupN, kSwizzleX,
kSwizzleW, kSwizzleY>(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr,
tmas_ptr, tiles_ptr, cu_tiles_ptr,
num_group, m, n, k, update_tma, stream);
} else {
constexpr int kTileM = 64;
constexpr int kStage = 8;
launch_group_gemm_bf16<kTileM, kTileN, kTileK, kStage, kWarpgroupM, kWarpgroupN, kSwizzleX,
kSwizzleW, kSwizzleY>(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr,
tmas_ptr, tiles_ptr, cu_tiles_ptr,
num_group, m, n, k, update_tma, stream);
}
}

} // namespace group_gemm
} // namespace hpc
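To summarize the dispatch: `num_seq_per_group_avg` selects `kTileM` with thresholds at 16/32/48, and `update_grouped_tma` appears to derive per-group tile counts from `seqlens`. The sketch below mirrors the threshold logic exactly; the `tiles`/`cu_tiles` reference is an inference from the names and the `kTileM` template argument, not verified against `kernels.cuh`.

# Host-side mirror of the kTileM selection in group_gemm_bf16_async.
def pick_tile_m(num_seq_per_group_avg):
    for bound, tile_m in ((16, 16), (32, 32), (48, 48)):
        if num_seq_per_group_avg <= bound:
            return tile_m
    return 64

# Plausible reference for what update_grouped_tma stores in tiles/cu_tiles:
# one entry per group, counting ceil(seqlen / kTileM) M-tiles (assumption).
def tiles_reference(seqlens, tile_m):
    tiles = [(s + tile_m - 1) // tile_m for s in seqlens]
    cu_tiles = [0]
    for t in tiles:
        cu_tiles.append(cu_tiles[-1] + t)
    return tiles, cu_tiles

assert pick_tile_m(22) == 32
print(tiles_reference([5, 17, 64, 2], 32))  # ([1, 1, 2, 1], [0, 1, 2, 4, 5])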