82 changes: 82 additions & 0 deletions hpc/fuse_moe.py
@@ -291,6 +291,74 @@ def fuse_moe_blockwise_fp8(
)


def fuse_moe_bf16(
x: Tensor,
gate_up_weight: Tensor,
down_weight: Tensor,
topk_ids: Tensor,
topk_scale: Tensor,
rank_ep: int,
num_expert_total: int,
shared_output: Tensor = None,
) -> Tensor:
"""Performs Mixture of Experts (MoE) forward operation with BF16 precision.

This function executes the MoE computation with all matrix multiplications
performed in BF16 precision. The gate and up projections are fused into
a single matrix multiplication.

Args:
x: Input activation tensor
Shape: [num_seq, hidden_size]
Dtype: bfloat16
gate_up_weight: Combined weight tensor for gate and up projections
Shape: [num_expert_local, intermediate_size * 2, hidden_size]
Dtype: bfloat16
down_weight: Weight tensor for down projection
Shape: [num_expert_local, hidden_size, intermediate_size]
Dtype: bfloat16
topk_ids: Expert indices selected for each token
Shape: [num_seq, num_topk]
Dtype: int32
topk_scale: Weighting factors for each token-expert assignment
Shape: [num_seq, num_topk]
Dtype: float32
rank_ep: Expert-parallel rank of the current process
Dtype: int
num_expert_total: Total number of experts across all expert-parallel ranks
Dtype: int
shared_output: Optional output tensor from shared experts; defaults to None
Shape: [num_seq, hidden_size]
Dtype: bfloat16

Returns:
torch.Tensor: Output tensor after MoE computation
Shape: [num_seq, hidden_size]
Dtype: bfloat16

Raises:
RuntimeError: If the input tensors have incompatible shapes or types,
or if CUDA kernel execution fails.

Note:
- All input tensors must be on CUDA device
- The gate and up projections are combined into a single matrix multiplication
- BF16 precision is used for all matrix operations
- Activation function used is SiLU (Swish)
- Token routing is determined by topk_ids and weighted by topk_scale
"""
return torch.ops.hpc.fuse_moe_bf16(
x,
gate_up_weight,
down_weight,
topk_ids,
topk_scale,
shared_output,
rank_ep,
num_expert_total,
)
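
# Usage sketch (illustrative only; the shapes, expert counts, and random
# routing below are assumptions, not values from this module). Requires a
# CUDA device and the compiled hpc extension.
#
#     num_seq, hidden, inter = 8, 4096, 14336
#     num_expert_local, num_expert_total, num_topk = 4, 32, 8
#     x = torch.randn(num_seq, hidden, dtype=torch.bfloat16, device="cuda")
#     gate_up_w = torch.randn(num_expert_local, 2 * inter, hidden,
#                             dtype=torch.bfloat16, device="cuda")
#     down_w = torch.randn(num_expert_local, hidden, inter,
#                          dtype=torch.bfloat16, device="cuda")
#     topk_ids = torch.randint(0, num_expert_total, (num_seq, num_topk),
#                              dtype=torch.int32, device="cuda")
#     topk_scale = torch.rand(num_seq, num_topk, dtype=torch.float32,
#                             device="cuda")
#     y = fuse_moe_bf16(x, gate_up_w, down_w, topk_ids, topk_scale,
#                       rank_ep=0, num_expert_total=num_expert_total)
#     # y: [num_seq, hidden], bfloat16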


@torch.library.register_fake("hpc::count_and_gather")
def count_and_gather_fake(
x, topk_ids, num_expert, rank_ep, intermediate_size, num_seq_per_group_avg
@@ -346,3 +414,17 @@ def fuse_moe_blockwise_fp8_fake(
use_bf16_mul: bool = True,
):
return torch.empty((x.shape[0], x.shape[1]), dtype=torch.bfloat16)


@torch.library.register_fake("hpc::fuse_moe_bf16")
def fuse_moe_bf16_fake(
x: Tensor,
gate_up_weight: Tensor,
down_weight: Tensor,
topk_ids: Tensor,
topk_scale: Tensor,
shared_output: Tensor,
rank_ep: int,
num_expert_total: int,
):
return torch.empty((x.shape[0], x.shape[1]), dtype=torch.bfloat16)
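
# Note: with the fake impl registered, fake-tensor tracing (and therefore
# torch.compile) can propagate output shape/dtype without launching the CUDA
# kernel. A minimal sketch, assuming the hpc extension is importable:
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#     with FakeTensorMode():
#         x = torch.empty(8, 4096, dtype=torch.bfloat16)
#         ...  # build the remaining arguments as fake tensors
#         y = torch.ops.hpc.fuse_moe_bf16(x, gate_up_w, down_w, topk_ids,
#                                         topk_scale, None, 0, 32)
#     # y is a fake tensor of shape (8, 4096) and dtype torch.bfloat16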
51 changes: 51 additions & 0 deletions hpc/group_gemm.py
@@ -152,6 +152,52 @@ def group_gemm_blockwise_fp8(
)


def group_gemm_bf16(
x: Tensor,
weight: Tensor,
seqlens: Tensor,
cu_seqlens: Tensor,
num_seq_per_group_avg: int = 32,
output: Tensor = None,
tma_desc: Tensor = None,
) -> Tensor:
"""Performs group GEMM operation with BF16 precision.

This function executes one matrix multiplication per group in a single
grouped launch, with all operands in BF16 precision.

Args:
x: Input activation tensor
Shape: [total_seq, hidden_size]
Dtype: bfloat16
weight: Weight tensor for group matrix multiplication
Shape: [num_group, output_dim, hidden_size]
Dtype: bfloat16
seqlens: Sequence lengths for each group
Shape: [num_group]
Dtype: int32
cu_seqlens: Cumulative sequence lengths giving each group's start offset in the input tensor
Shape: [num_group + 1]
Dtype: int32
num_seq_per_group_avg: Average number of sequences per group, used as a
kernel tuning hint; defaults to 32
output: Optional preallocated output tensor; defaults to None
Shape: [total_seq, output_dim]
Dtype: bfloat16
tma_desc: Optional preallocated TMA descriptor tensor; defaults to None

Returns:
Tensor: Output tensor after group matrix multiplication
Shape: [total_seq, output_dim]
Dtype: bfloat16

Raises:
RuntimeError: If the input tensors have incompatible shapes or types,
or if the CUDA kernel execution fails.

Note:
- All input tensors must be on CUDA device

"""
return torch.ops.hpc.group_gemm_bf16(
x, weight, seqlens, cu_seqlens, num_seq_per_group_avg, output, tma_desc
)
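
# Reference semantics (a pure-PyTorch sketch inferred from the documented
# shapes, useful for validation; it is not the kernel implementation):
#
#     def group_gemm_bf16_ref(x, weight, seqlens, cu_seqlens):
#         out = x.new_empty(x.shape[0], weight.shape[1])
#         for g in range(weight.shape[0]):
#             lo, hi = cu_seqlens[g], cu_seqlens[g + 1]
#             out[lo:hi] = x[lo:hi] @ weight[g].t()
#         return out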


@torch.library.register_fake("hpc::group_gemm_pertensor_fp8")
def group_gemm_pertensor_fp8_fake(
x, weight, seqlens, cu_seqlens, y_scale, num_seq_per_group_avg, output, tma_des
@@ -164,3 +210,8 @@ def group_gemm_blockwise_fp8_fake(
x, weight, seqlens, cu_seqlens, x_scale, w_scale, num_seq_per_group_avg, output, tma_des
):
return torch.empty((x.shape[0], weight.shape[1]), dtype=torch.bfloat16)


@torch.library.register_fake("hpc::group_gemm_bf16")
def group_gemm_bf16_fake(x, weight, seqlens, cu_seqlens, num_seq_per_group_avg, output, tma_des):
return torch.empty((x.shape[0], weight.shape[1]), dtype=torch.bfloat16)
52 changes: 52 additions & 0 deletions src/activation/activation.cu
@@ -64,6 +64,46 @@ __global__ void act_mul_and_quant_kernel(__nv_fp8_e4m3 *out_ptr, const __nv_bflo
}
}

__global__ void act_mul_bf16_kernel(__nv_bfloat16 *out_ptr, const __nv_bfloat16 *gate_up_ptr,
const int *valid_row_range, const int num_row,
const int num_col, cutlass::FastDivmod block1D22D) {
int iblockx;
int iblocky;

block1D22D(iblocky, iblockx, blockIdx.x);
int it = threadIdx.x + iblockx * blockDim.x;

int irow = iblocky;
int my_valid_row_end_exclusive = valid_row_range ? valid_row_range[0] : num_row;
if (irow >= my_valid_row_end_exclusive) {
return;
}

using T = __nv_bfloat162;

const auto *gate_row_ptr = gate_up_ptr + irow * num_col * 2;
const auto *up_row_ptr = gate_row_ptr + num_col;
auto *out_row_ptr = out_ptr + irow * num_col;

int icol = it * 8;

if (icol < num_col) {
auto gate = to<float>(load<T, 4>(gate_row_ptr + icol));
auto up = to<float>(load<T, 4>(up_row_ptr + icol));

vec_t<__nv_bfloat16, 8> out;
#pragma unroll
for (int i = 0; i < size(out); ++i) {
auto g = gate[i];
auto u = up[i];
auto result = silu(g) * u;
out[i] = __float2bfloat16_rn(result);
}

store(out_row_ptr + icol, out);
}
}
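
// Reference semantics (descriptive note, not from the source): each row of
// gate_up_ptr holds [gate | up], each num_col wide; the kernel writes
//   out[r][c] = bf16(silu(float(gate[r][c])) * float(up[r][c]))
// for rows below valid_row_range[0] (all rows when valid_row_range is null).
// Each thread handles one 16-byte vector of 8 bf16 columns, so num_col is
// assumed to be a multiple of 8.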

// input : gate + up
__global__ void masked_act_mul_and_quant_kernel(
__nv_fp8_e4m3 *output_ptr, const __nv_bfloat16 *input_ptr, const float *scale_ptr,
@@ -391,5 +431,17 @@ void masked_act_mul_and_blockwise_quant_async(__nv_fp8_e4m3 *output_ptr, float *
Row2EandT, num_block_row);
}

void act_mul_bf16_async(__nv_bfloat16 *y_ptr, const __nv_bfloat16 *x_ptr,
const int *valid_row_range, const int num_row, const int num_col,
cudaStream_t stream) {
dim3 block(256);
int num_block_per_row = (num_col / 8 + block.x - 1) / block.x;
cutlass::FastDivmod block1D22D(num_block_per_row);
dim3 grid(num_row * num_block_per_row);

kernels::act_mul_bf16_kernel<<<grid, block, 0, stream>>>(
y_ptr, x_ptr, valid_row_range, num_row, num_col, block1D22D);
}
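
// Launch-shape example (illustrative numbers, not from this PR): with
// num_row = 1024 and num_col = 4096, each thread covers 8 bf16 columns, so a
// 256-thread block spans 2048 columns, num_block_per_row = ceil(512 / 256) = 2,
// and grid.x = 1024 * 2 = 2048. FastDivmod then recovers (row, column-block)
// from blockIdx.x in the kernel without a hardware integer division.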

} // namespace activation
} // namespace hpc
4 changes: 4 additions & 0 deletions src/activation/activation.h
@@ -38,6 +38,10 @@ void masked_act_mul_and_blockwise_quant_async(__nv_fp8_e4m3 *output_ptr, float *
int num_intermediate_size, int num_tokens_per_expert,
cudaStream_t stream);

void act_mul_bf16_async(__nv_bfloat16 *y_ptr, const __nv_bfloat16 *x_ptr,
const int *valid_row_range, const int num_row, const int num_col,
cudaStream_t stream);

} // namespace activation
} // namespace hpc

60 changes: 52 additions & 8 deletions src/fuse_moe/count_and_gather.cu
@@ -262,7 +262,7 @@ __global__ void gather_kernel(const vec_t<cute::TmaDescriptor, 4> td_xy,

} // namespace kernels

template <int kTileM, int kTileN, int kTileK, int kStage>
template <typename Tin, typename Tout, int kTileM, int kTileN, int kTileK, int kStage>
void launch_count_and_gather(void *gate_up_input_ptr, void *gate_up_output_ptr,
void *down_input_ptr, void *down_output_ptr, const void *x_ptr,
const void *topk_ids_ptr, void *topk_pos_ptr, void *seqlens_ptr,
@@ -272,9 +272,6 @@ void launch_count_and_gather(void *gate_up_input_ptr, void *gate_up_output_ptr,
int num_seq_per_group_avg, cudaStream_t stream) {
using namespace cute; // NOLINT

using Tin = cute::float_e4m3_t;
using Tout = cute::bfloat16_t;

int m = num_seq;
int n = intermediate_size;
int k = hidden_size;
@@ -378,35 +375,82 @@ void count_and_gather_async(void *gate_up_input_ptr, void *gate_up_output_ptr, v
cudaStream_t stream) {
constexpr int kTileN = 128;
constexpr int kTileK = 128;
using Tin = cute::float_e4m3_t;
using Tout = cute::bfloat16_t;
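// The tile-M dispatch below picks the smallest kTileM (16/32/48/64) that
// covers the average group size, presumably so small expert groups do not
// pay for oversized M-tiles.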
if (num_seq_per_group_avg <= 16) {
constexpr int kTileM = 16;
constexpr int kStage = 8;
launch_count_and_gather<Tin, Tout, kTileM, kTileN, kTileK, kStage>(
gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr,
topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr,
cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank,
num_seq_per_group_avg, stream);
} else if (num_seq_per_group_avg <= 32) {
constexpr int kTileM = 32;
constexpr int kStage = 8;
launch_count_and_gather<Tin, Tout, kTileM, kTileN, kTileK, kStage>(
gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr,
topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr,
cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank,
num_seq_per_group_avg, stream);
} else if (num_seq_per_group_avg <= 48) {
constexpr int kTileM = 48;
constexpr int kStage = 8;
launch_count_and_gather<Tin, Tout, kTileM, kTileN, kTileK, kStage>(
gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr,
topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr,
cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank,
num_seq_per_group_avg, stream);
} else {
constexpr int kTileM = 64;
constexpr int kStage = 8;
launch_count_and_gather<Tin, Tout, kTileM, kTileN, kTileK, kStage>(
gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr,
topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr,
cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank,
num_seq_per_group_avg, stream);
}
}

void count_and_gather_bf16_async(void *gate_up_input_ptr, void *gate_up_output_ptr, void *down_input_ptr,
void *down_output_ptr, const void *x_ptr, const void *topk_ids_ptr,
void *topk_pos_ptr, void *seqlens_ptr, void *cu_seqlens_ptr,
void *gate_up_tmas_ptr, void *down_tmas_ptr, void *tiles_ptr,
void *cu_tiles_ptr, int num_seq, int hidden_size, int intermediate_size,
int num_topk, int num_expert, int eprank, int num_seq_per_group_avg,
cudaStream_t stream) {
constexpr int kTileN = 128;
constexpr int kTileK = 64;
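// Note (reviewer inference, not from the source): kTileK is 64 here versus
// 128 in the fp8 path above, which keeps the K-tile byte footprint unchanged
// since bf16 elements are twice as wide as fp8 e4m3.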
using Tin = cute::bfloat16_t;
using Tout = cute::bfloat16_t;
if (num_seq_per_group_avg <= 16) {
constexpr int kTileM = 16;
constexpr int kStage = 8;
launch_count_and_gather<kTileM, kTileN, kTileK, kStage>(
launch_count_and_gather<Tin, Tout, kTileM, kTileN, kTileK, kStage>(
gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr,
topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr,
cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank,
num_seq_per_group_avg, stream);
} else if (num_seq_per_group_avg <= 32) {
constexpr int kTileM = 32;
constexpr int kStage = 8;
launch_count_and_gather<kTileM, kTileN, kTileK, kStage>(
launch_count_and_gather<Tin, Tout, kTileM, kTileN, kTileK, kStage>(
gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr,
topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr,
cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank,
num_seq_per_group_avg, stream);
} else if (num_seq_per_group_avg <= 48) {
constexpr int kTileM = 48;
constexpr int kStage = 8;
launch_count_and_gather<kTileM, kTileN, kTileK, kStage>(
launch_count_and_gather<Tin, Tout, kTileM, kTileN, kTileK, kStage>(
gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr,
topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr,
cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank,
num_seq_per_group_avg, stream);
} else {
constexpr int kTileM = 64;
constexpr int kStage = 8;
launch_count_and_gather<kTileM, kTileN, kTileK, kStage>(
launch_count_and_gather<Tin, Tout, kTileM, kTileN, kTileK, kStage>(
gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr,
topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr,
cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank,