diff --git a/hpc/fuse_moe.py b/hpc/fuse_moe.py index e698d21..18fc457 100644 --- a/hpc/fuse_moe.py +++ b/hpc/fuse_moe.py @@ -291,6 +291,74 @@ def fuse_moe_blockwise_fp8( ) +def fuse_moe_bf16( + x: Tensor, + gate_up_weight: Tensor, + down_weight: Tensor, + topk_ids: Tensor, + topk_scale: Tensor, + rank_ep: int, + num_expert_total: int, + shared_output: Tensor = None, +) -> Tensor: + """Performs Mixture of Experts (MoE) forward operation with BF16 precision. + + This function executes the MoE computation with all matrix multiplications + performed in BF16 precision. The gate and up projections are fused into + a single matrix multiplication. + + Args: + x: Input activation tensor + Shape: [num_seq, hidden_size] + Dtype: bfloat16 + gate_up_weight: Combined weight tensor for gate and up projections + Shape: [num_expert_local, intermediate_size * 2, hidden_size] + Dtype: bfloat16 + down_weight: Weight tensor for down projection + Shape: [num_expert_local, hidden_size, intermediate_size] + Dtype: bfloat16 + topk_ids: Token indices assigned to each expert + Shape: [num_seq, num_topk] + Dtype: int32 + topk_scale: Weighting factors for each token-expert assignment + Shape: [num_seq, num_topk] + Dtype: float32 + rank_ep: Expert parallel rank (for distributed training) + Dtype: int32 + num_expert_total: the total number of expert + Dtype: int32 + shared_output: output for shared experts, default is None + Shape: [num_seq, hidden_size] + Dtype: bfloat16 + + Returns: + torch.Tensor: Output tensor after MoE computation + Shape: [num_seq, hidden_size] + Dtype: bfloat16 + + Raises: + RuntimeError: If the input tensors have incompatible shapes or types, + or if CUDA kernel execution fails. + + Note: + - All input tensors must be on CUDA device + - The gate and up projections are combined into a single matrix multiplication + - BF16 precision is used for all matrix operations + - Activation function used is SiLU (Swish) + - Token routing is determined by topk_ids and weighted by topk_scale + """ + return torch.ops.hpc.fuse_moe_bf16( + x, + gate_up_weight, + down_weight, + topk_ids, + topk_scale, + shared_output, + rank_ep, + num_expert_total, + ) + + @torch.library.register_fake("hpc::count_and_gather") def count_and_gather_fake( x, topk_ids, num_expert, rank_ep, intermediate_size, num_seq_per_group_avg @@ -346,3 +414,17 @@ def fuse_moe_blockwise_fp8_fake( use_bf16_mul: bool = True, ): return torch.empty((x.shape[0], x.shape[1]), dtype=torch.bfloat16) + + +@torch.library.register_fake("hpc::fuse_moe_bf16") +def fuse_moe_bf16_fake( + x: Tensor, + gate_up_weight: Tensor, + down_weight: Tensor, + topk_ids: Tensor, + topk_scale: Tensor, + shared_output: Tensor, + rank_ep: int, + num_expert_total: int, +): + return torch.empty((x.shape[0], x.shape[1]), dtype=torch.bfloat16) diff --git a/hpc/group_gemm.py b/hpc/group_gemm.py index 0631c2e..a292673 100644 --- a/hpc/group_gemm.py +++ b/hpc/group_gemm.py @@ -152,6 +152,52 @@ def group_gemm_blockwise_fp8( ) +def group_gemm_bf16( + x: Tensor, + weight: Tensor, + seqlens: Tensor, + cu_seqlens: Tensor, + num_seq_per_group_avg: int = 32, + output: Tensor = None, + tma_desc: Tensor = None, +) -> Tensor: + """Performs group GEMM operation with BF16 precision. + + This function executes multiple matrix multiplications in a group manner + using BF16 precision for improved performance. 
+ + Args: + x: Input activation tensor + Shape: [total_seq, hidden_size] + Dtype: bfloat16 + weight: Weight tensor for group matrix multiplication + Shape: [num_group, output_dim, hidden_size] + Dtype: bfloat16 + seqlens: Sequence lengths for each group + Shape: [num_group] + Dtype: int32 + cu_seqlens: Cumulative sequence lengths indicating start indices in input tensor + Shape: [num_group + 1] + Dtype: int32 + + Returns: + Tensor: Output tensor after group matrix multiplication + Shape: [total_seq, output_dim] + Dtype: bfloat16 + + Raises: + RuntimeError: If the input tensors have incompatible shapes or types, + or if the CUDA kernel execution fails. + + Note: + - All input tensors must be on CUDA device + + """ + return torch.ops.hpc.group_gemm_bf16( + x, weight, seqlens, cu_seqlens, num_seq_per_group_avg, output, tma_desc + ) + + @torch.library.register_fake("hpc::group_gemm_pertensor_fp8") def group_gemm_pertensor_fp8_fake( x, weight, seqlens, cu_seqlens, y_scale, num_seq_per_group_avg, output, tma_des @@ -164,3 +210,8 @@ def group_gemm_blockwise_fp8_fake( x, weight, seqlens, cu_seqlens, x_scale, w_scale, num_seq_per_group_avg, output, tma_des ): return torch.empty((x.shape[0], weight.shape[1]), dtype=torch.bfloat16) + + +@torch.library.register_fake("hpc::group_gemm_bf16") +def group_gemm_bf16_fake(x, weight, seqlens, cu_seqlens, num_seq_per_group_avg, output, tma_des): + return torch.empty((x.shape[0], weight.shape[1]), dtype=torch.bfloat16) diff --git a/src/activation/activation.cu b/src/activation/activation.cu index b1000e4..fb5b141 100644 --- a/src/activation/activation.cu +++ b/src/activation/activation.cu @@ -64,6 +64,46 @@ __global__ void act_mul_and_quant_kernel(__nv_fp8_e4m3 *out_ptr, const __nv_bflo } } +__global__ void act_mul_bf16_kernel(__nv_bfloat16 *out_ptr, const __nv_bfloat16 *gate_up_ptr, + const int *valid_row_range, const int num_row, + const int num_col, cutlass::FastDivmod block1D22D) { + int iblockx; + int iblocky; + + block1D22D(iblocky, iblockx, blockIdx.x); + int it = threadIdx.x + iblockx * blockDim.x; + + int irow = iblocky; + int my_valid_row_end_exclusive = valid_row_range ? 
valid_row_range[0] : num_row; + if (irow >= my_valid_row_end_exclusive) { + return; + } + + using T = __nv_bfloat162; + + const auto *gate_row_ptr = gate_up_ptr + irow * num_col * 2; + const auto *up_row_ptr = gate_row_ptr + num_col; + auto *out_row_ptr = out_ptr + irow * num_col; + + int icol = it * 8; + + if (icol < num_col) { + auto gate = to(load(gate_row_ptr + icol)); + auto up = to(load(up_row_ptr + icol)); + + vec_t<__nv_bfloat16, 8> out; +#pragma unroll + for (int i = 0; i < size(out); ++i) { + auto g = gate[i]; + auto u = up[i]; + auto result = silu(g) * u; + out[i] = __float2bfloat16_rn(result); + } + + store(out_row_ptr + icol, out); + } +} + // input : gate + up __global__ void masked_act_mul_and_quant_kernel( __nv_fp8_e4m3 *output_ptr, const __nv_bfloat16 *input_ptr, const float *scale_ptr, @@ -391,5 +431,17 @@ void masked_act_mul_and_blockwise_quant_async(__nv_fp8_e4m3 *output_ptr, float * Row2EandT, num_block_row); } +void act_mul_bf16_async(__nv_bfloat16 *y_ptr, const __nv_bfloat16 *x_ptr, + const int *valid_row_range, const int num_row, const int num_col, + cudaStream_t stream) { + dim3 block(256); + int num_block_per_row = (num_col / 8 + block.x - 1) / block.x; + cutlass::FastDivmod block1D22D(num_block_per_row); + dim3 grid(num_row * num_block_per_row); + + kernels::act_mul_bf16_kernel<<>>( + y_ptr, x_ptr, valid_row_range, num_row, num_col, block1D22D); +} + } // namespace activation } // namespace hpc diff --git a/src/activation/activation.h b/src/activation/activation.h index 784dc27..bd72ceb 100644 --- a/src/activation/activation.h +++ b/src/activation/activation.h @@ -38,6 +38,10 @@ void masked_act_mul_and_blockwise_quant_async(__nv_fp8_e4m3 *output_ptr, float * int num_intermediate_size, int num_tokens_per_expert, cudaStream_t stream); +void act_mul_bf16_async(__nv_bfloat16 *y_ptr, const __nv_bfloat16 *x_ptr, + const int *valid_row_range, const int num_row, const int num_col, + cudaStream_t stream); + } // namespace activation } // namespace hpc diff --git a/src/fuse_moe/count_and_gather.cu b/src/fuse_moe/count_and_gather.cu index 0c4781c..9aba97e 100644 --- a/src/fuse_moe/count_and_gather.cu +++ b/src/fuse_moe/count_and_gather.cu @@ -262,7 +262,7 @@ __global__ void gather_kernel(const vec_t td_xy, } // namespace kernels -template +template void launch_count_and_gather(void *gate_up_input_ptr, void *gate_up_output_ptr, void *down_input_ptr, void *down_output_ptr, const void *x_ptr, const void *topk_ids_ptr, void *topk_pos_ptr, void *seqlens_ptr, @@ -272,9 +272,6 @@ void launch_count_and_gather(void *gate_up_input_ptr, void *gate_up_output_ptr, int num_seq_per_group_avg, cudaStream_t stream) { using namespace cute; // NOLINT - using Tin = cute::float_e4m3_t; - using Tout = cute::bfloat16_t; - int m = num_seq; int n = intermediate_size; int k = hidden_size; @@ -378,11 +375,58 @@ void count_and_gather_async(void *gate_up_input_ptr, void *gate_up_output_ptr, v cudaStream_t stream) { constexpr int kTileN = 128; constexpr int kTileK = 128; + using Tin = cute::float_e4m3_t; + using Tout = cute::bfloat16_t; + if (num_seq_per_group_avg <= 16) { + constexpr int kTileM = 16; + constexpr int kStage = 8; + launch_count_and_gather( + gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, + topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, + cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, + num_seq_per_group_avg, stream); + } else if (num_seq_per_group_avg <= 32) { + 
constexpr int kTileM = 32; + constexpr int kStage = 8; + launch_count_and_gather( + gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, + topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, + cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, + num_seq_per_group_avg, stream); + } else if (num_seq_per_group_avg <= 48) { + constexpr int kTileM = 48; + constexpr int kStage = 8; + launch_count_and_gather( + gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, + topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, + cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, + num_seq_per_group_avg, stream); + } else { + constexpr int kTileM = 64; + constexpr int kStage = 8; + launch_count_and_gather( + gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, + topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, + cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, + num_seq_per_group_avg, stream); + } +} +void count_and_gather_bf16_async(void *gate_up_input_ptr, void *gate_up_output_ptr, void *down_input_ptr, + void *down_output_ptr, const void *x_ptr, const void *topk_ids_ptr, + void *topk_pos_ptr, void *seqlens_ptr, void *cu_seqlens_ptr, + void *gate_up_tmas_ptr, void *down_tmas_ptr, void *tiles_ptr, + void *cu_tiles_ptr, int num_seq, int hidden_size, int intermediate_size, + int num_topk, int num_expert, int eprank, int num_seq_per_group_avg, + cudaStream_t stream) { + constexpr int kTileN = 128; + constexpr int kTileK = 64; + using Tin = cute::bfloat16_t; + using Tout = cute::bfloat16_t; if (num_seq_per_group_avg <= 16) { constexpr int kTileM = 16; constexpr int kStage = 8; - launch_count_and_gather( + launch_count_and_gather( gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, @@ -390,7 +434,7 @@ void count_and_gather_async(void *gate_up_input_ptr, void *gate_up_output_ptr, v } else if (num_seq_per_group_avg <= 32) { constexpr int kTileM = 32; constexpr int kStage = 8; - launch_count_and_gather( + launch_count_and_gather( gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, @@ -398,7 +442,7 @@ void count_and_gather_async(void *gate_up_input_ptr, void *gate_up_output_ptr, v } else if (num_seq_per_group_avg <= 48) { constexpr int kTileM = 48; constexpr int kStage = 8; - launch_count_and_gather( + launch_count_and_gather( gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, @@ -406,7 +450,7 @@ void count_and_gather_async(void *gate_up_input_ptr, void *gate_up_output_ptr, v } else { constexpr int kTileM = 64; constexpr int kStage = 8; - launch_count_and_gather( + launch_count_and_gather( gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, 
x_ptr, topk_ids_ptr, topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_seq, hidden_size, intermediate_size, num_topk, num_expert, eprank, diff --git a/src/fuse_moe/entry.cc b/src/fuse_moe/entry.cc index 9e97c19..9dafe99 100644 --- a/src/fuse_moe/entry.cc +++ b/src/fuse_moe/entry.cc @@ -363,6 +363,105 @@ torch::Tensor fuse_moe_blockwise_fp8_entry( return y; } +torch::Tensor fuse_moe_bf16_entry( + const torch::Tensor &x, const torch::Tensor &gate_up_weight, const torch::Tensor &down_weight, + const torch::Tensor &topk_ids, const torch::Tensor &topk_scale, + const std::optional &shared_output, int64_t rank_ep, int64_t num_expert_total) { + auto stream = at::cuda::getCurrentCUDAStream(x.get_device()); + + TORCH_CHECK(x.device().is_cuda(), "x tensor must be cuda"); + TORCH_CHECK(gate_up_weight.device().is_cuda(), "gate_up_weight tensor must be cuda"); + TORCH_CHECK(down_weight.device().is_cuda(), "down_weight tensor must be cuda"); + TORCH_CHECK(topk_ids.device().is_cuda(), "topk_ids tensor must be cuda"); + TORCH_CHECK(topk_scale.device().is_cuda(), "topk_scale tensor must be cuda"); + + TORCH_CHECK(x.is_contiguous(), "x tensor must be contiguous"); + TORCH_CHECK(gate_up_weight.is_contiguous(), "gate_up_weight tensor must be contiguous"); + TORCH_CHECK(down_weight.is_contiguous(), "down_weight tensor must be contiguous"); + TORCH_CHECK(topk_ids.is_contiguous(), "topk_ids tensor must be contiguous"); + TORCH_CHECK(topk_scale.is_contiguous(), "topk_scale tensor must be contiguous"); + + TORCH_CHECK(x.dtype() == torch::kBFloat16, "x tensor dtype must be bfloat16"); + TORCH_CHECK(gate_up_weight.dtype() == torch::kBFloat16, "gate_up_weight tensor dtype must be bfloat16"); + TORCH_CHECK(down_weight.dtype() == torch::kBFloat16, "down_weight tensor dtype must be bfloat16"); + + TORCH_CHECK(x.size(0) == topk_ids.size(0), "x and topk_ids must share the same num_seq"); + TORCH_CHECK(topk_ids.size(0) == topk_scale.size(0), + "topk_ids and topk_scale must share the same num_seq"); + TORCH_CHECK(topk_ids.size(1) == topk_scale.size(1), + "topk_ids and topk_scale must share the same num_topk"); + TORCH_CHECK(x.size(1) == gate_up_weight.size(2), "x and weight must share the same k"); + TORCH_CHECK(gate_up_weight.size(0) == down_weight.size(0), + "gate_up_weight and down_weight must share the same num_expert"); + + const void *shared_output_ptr = nullptr; + if (shared_output.has_value()) { + const auto shared_output_tensor = shared_output.value(); + TORCH_CHECK(shared_output_tensor.device().is_cuda(), "shared_output tensor must be cuda"); + TORCH_CHECK(shared_output_tensor.is_contiguous(), "shared_output tensor must be contiguous"); + TORCH_CHECK(shared_output_tensor.dtype() == torch::kBFloat16, + "shared_output tensor dtype must be bfloat16"); + TORCH_CHECK( + shared_output_tensor.size(0) == x.size(0) && shared_output_tensor.size(1) == x.size(1), + "shared_output tensor shape must be same as x tensor"); + shared_output_ptr = shared_output_tensor.const_data_ptr(); + } + + int num_seq = x.size(0); + int hidden_size = x.size(1); + int num_expert = gate_up_weight.size(0); + int intermediate_size = gate_up_weight.size(1); + int num_topk = topk_ids.size(1); + TORCH_CHECK(num_topk <= 128, "num_topk must less than or equal to 128"); + + auto options = x.options(); + torch::Tensor y = torch::empty({num_seq, hidden_size}, options.dtype(torch::kBFloat16)); + + torch::Tensor gate_up_input = torch::empty({num_seq * num_topk, hidden_size}, options); + torch::Tensor 
gate_up_output = + torch::empty({num_seq * num_topk, intermediate_size}, options.dtype(torch::kBFloat16)); + torch::Tensor gate_up_tmas = torch::empty({num_expert * 2, 128}, options.dtype(torch::kInt8)); + torch::Tensor down_input = torch::empty({num_seq * num_topk, intermediate_size / 2}, options); + torch::Tensor down_output = + torch::empty({num_seq * num_topk, hidden_size}, options.dtype(torch::kBFloat16)); + torch::Tensor down_tmas = torch::empty({num_expert * 2, 128}, options.dtype(torch::kInt8)); + + torch::Tensor topk_pos = torch::empty({num_seq, num_topk}, options.dtype(torch::kInt32)); + torch::Tensor seqlens = torch::zeros({num_expert}, options.dtype(torch::kInt32)); + torch::Tensor cu_seqlens = torch::empty({num_expert + 1}, options.dtype(torch::kInt32)); + torch::Tensor tiles = torch::empty({num_expert}, options.dtype(torch::kInt32)); + torch::Tensor cu_tiles = torch::empty({num_expert + 1}, options.dtype(torch::kInt32)); + + const auto *x_ptr = x.const_data_ptr(); + const auto *topk_ids_ptr = topk_ids.const_data_ptr(); + const auto *topk_scale_ptr = topk_scale.const_data_ptr(); + const auto *gate_up_weight_ptr = gate_up_weight.const_data_ptr(); + const auto *down_weight_ptr = down_weight.const_data_ptr(); + + auto *y_ptr = y.mutable_data_ptr(); + auto *topk_pos_ptr = topk_pos.mutable_data_ptr(); + auto *seqlens_ptr = seqlens.mutable_data_ptr(); + auto *cu_seqlens_ptr = cu_seqlens.mutable_data_ptr(); + auto *tiles_ptr = tiles.mutable_data_ptr(); + auto *cu_tiles_ptr = cu_tiles.mutable_data_ptr(); + auto *gate_up_input_ptr = gate_up_input.mutable_data_ptr(); + auto *gate_up_output_ptr = gate_up_output.mutable_data_ptr(); + auto *gate_up_tmas_ptr = gate_up_tmas.mutable_data_ptr(); + auto *down_input_ptr = down_input.mutable_data_ptr(); + auto *down_output_ptr = down_output.mutable_data_ptr(); + auto *down_tmas_ptr = down_tmas.mutable_data_ptr(); + + fuse_moe_bf16_async( + y_ptr, x_ptr, gate_up_input_ptr, gate_up_output_ptr, gate_up_weight_ptr, gate_up_tmas_ptr, + down_input_ptr, down_output_ptr, down_weight_ptr, down_tmas_ptr, + topk_ids_ptr, topk_scale_ptr, topk_pos_ptr, seqlens_ptr, + cu_seqlens_ptr, tiles_ptr, cu_tiles_ptr, shared_output_ptr, num_seq, hidden_size, + intermediate_size, num_topk, num_expert_total, num_expert, rank_ep, stream); + + return y; +} + + } // namespace fuse_moe } // namespace hpc @@ -393,4 +492,10 @@ TORCH_LIBRARY_FRAGMENT(hpc, m) { "shared_output, " "int rank_ep, int num_expert_total) -> (Tensor)"); m.impl("fuse_moe_blockwise_fp8", torch::kCUDA, &hpc::fuse_moe::fuse_moe_blockwise_fp8_entry); + + m.def( + "fuse_moe_bf16(Tensor x, Tensor gate_up_weight, Tensor down_weight, " + "Tensor topk_ids, Tensor topk_scale, Tensor ? 
shared_output, " + "int rank_ep, int num_expert_total) -> (Tensor)"); + m.impl("fuse_moe_bf16", torch::kCUDA, &hpc::fuse_moe::fuse_moe_bf16_entry); } diff --git a/src/fuse_moe/fuse_moe.cu b/src/fuse_moe/fuse_moe.cu index fe73750..16d69be 100644 --- a/src/fuse_moe/fuse_moe.cu +++ b/src/fuse_moe/fuse_moe.cu @@ -107,5 +107,48 @@ void fuse_moe_blockwise_fp8_async( reduce_async(output_ptr, down_output_ptr, topk_pos_ptr, topk_scale_ptr, shared_output_ptr, total_num_tokens, num_tokens, hidden_size, num_topk, stream); } + +void fuse_moe_bf16_async( + void *output_ptr, const void *input_ptr, void *gate_up_input_ptr, void *gate_up_output_ptr, + const void *gate_up_weight_ptr, void *gate_up_tmas_ptr, void *down_input_ptr, + void *down_output_ptr, const void *down_weight_ptr, void *down_tmas_ptr, + const void *topk_ids_ptr, const void *topk_scale_ptr, void *topk_pos_ptr, void *seqlens_ptr, + void *cu_seqlens_ptr, void *tiles_ptr, void *cu_tiles_ptr, const void *shared_output_ptr, + int num_seq, int hidden_size, int intermediate_size, int num_topk, int num_expert_total, + int num_expert_local, int rank_ep, cudaStream_t stream) { + int total_num_seq = num_seq * num_topk; + int num_seq_per_group_avg = total_num_seq / num_expert_total; + using T1 = __nv_bfloat16; + + // 0. call count_and_gather_bf16_async (fills TMA descriptors for bf16 group gemm) + count_and_gather_bf16_async(gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, + input_ptr, topk_ids_ptr, topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, + gate_up_tmas_ptr, down_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_seq, + hidden_size, intermediate_size, num_topk, num_expert_local, rank_ep, + num_seq_per_group_avg, stream); + + // 1. call gate_up linear (bf16), TMA descriptors pre-filled by count_and_gather_bf16_async + group_gemm::group_gemm_bf16_async( + gate_up_output_ptr, gate_up_input_ptr, gate_up_weight_ptr, seqlens_ptr, cu_seqlens_ptr, + gate_up_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_expert_local, total_num_seq, + intermediate_size, hidden_size, num_seq_per_group_avg, false, stream); + + // 2. call act and mul (bf16 activation) + const int *valid_row_range_ptr = + (int *)cu_seqlens_ptr + num_expert_local; // get last number as valid row + activation::act_mul_bf16_async((T1 *)down_input_ptr, (const T1 *)gate_up_output_ptr, + valid_row_range_ptr, total_num_seq, intermediate_size / 2, stream); + + // 3. call down linear (bf16), TMA descriptors pre-filled by count_and_gather_bf16_async + group_gemm::group_gemm_bf16_async( + down_output_ptr, down_input_ptr, down_weight_ptr, seqlens_ptr, cu_seqlens_ptr, + down_tmas_ptr, tiles_ptr, cu_tiles_ptr, num_expert_local, total_num_seq, hidden_size, + intermediate_size / 2, num_seq_per_group_avg, false, stream); + + // 4. 
call reduce + reduce_async(output_ptr, down_output_ptr, topk_pos_ptr, topk_scale_ptr, shared_output_ptr, + total_num_seq, num_seq, hidden_size, num_topk, stream); +} + } // namespace fuse_moe } // namespace hpc diff --git a/src/fuse_moe/fuse_moe.h b/src/fuse_moe/fuse_moe.h index a388878..85a844b 100644 --- a/src/fuse_moe/fuse_moe.h +++ b/src/fuse_moe/fuse_moe.h @@ -20,6 +20,14 @@ void count_and_gather_async(void *gate_up_input_ptr, void *gata_up_output_ptr, v int num_topk, int num_expert, int rank_ep, int num_seq_per_group_avg, cudaStream_t stream); +void count_and_gather_bf16_async(void *gate_up_input_ptr, void *gate_up_output_ptr, + void *down_input_ptr, void *down_output_ptr, const void *x_ptr, + const void *topk_ids_ptr, void *topk_pos_ptr, void *seqlens_ptr, + void *cu_seqlens_ptr, void *gate_up_tmas_ptr, void *down_tmas_ptr, + void *tiles_ptr, void *cu_tiles_ptr, int num_seq, int hidden_size, + int intermediate_size, int num_topk, int num_expert, int rank_ep, + int num_seq_per_group_avg, cudaStream_t stream); + void blockwise_count_and_gather_async( const void *input_ptr, const void *input_scale_ptr, void *gate_up_input_ptr, void *gate_up_output_ptr, void *gate_up_input_scale_ptr, void *down_input_ptr, @@ -56,6 +64,15 @@ void fuse_moe_blockwise_fp8_async( int gate_up_weight_scale_lastdim_pad4, int down_weight_scale_lastdim_pad4, int rank_ep, cudaStream_t stream); +void fuse_moe_bf16_async( + void *output_ptr, const void *input_ptr, void *gate_up_input_ptr, void *gate_up_output_ptr, + const void *gate_up_weight_ptr, void *gate_up_tmas_ptr, void *down_input_ptr, + void *down_output_ptr, const void *down_weight_ptr, void *down_tmas_ptr, + const void *topk_ids_ptr, const void *topk_scale_ptr, void *topk_pos_ptr, void *seqlens_ptr, + void *cu_seqlens_ptr, void *tiles_ptr, void *cu_tiles_ptr, const void *shared_output_ptr, + int num_seq, int hidden_size, int intermediate_size, int num_topk, int num_expert_total, + int num_expert_local, int rank_ep, cudaStream_t stream); + } // namespace fuse_moe } // namespace hpc diff --git a/src/group_gemm/config.h b/src/group_gemm/config.h index b0b4656..16462c0 100644 --- a/src/group_gemm/config.h +++ b/src/group_gemm/config.h @@ -165,6 +165,75 @@ struct GroupGEMMBlockWiseFp8Config { auto get_shm_size() { return shm_size; } }; +template +static constexpr auto mma_selector_bf16() { + constexpr auto MajorA = GMMA::Major::K; + constexpr auto MajorB = GMMA::Major::K; + if constexpr (kTileM == 8) { + return cute::SM90_64x8x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 16) { + return cute::SM90_64x16x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 32) { + return cute::SM90_64x32x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 48) { + return cute::SM90_64x48x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 64) { + return cute::SM90_64x64x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 96) { + return cute::SM90_64x96x16_F32BF16BF16_SS{}; + } else if constexpr (kTileM == 128) { + return cute::SM90_64x128x16_F32BF16BF16_SS{}; + } else { + return cute::SM90_64x64x16_F32BF16BF16_SS{}; + } +} + +template +struct GroupGEMMBF16Config { + using Tin = Tin_; + using Tout = Tout_; + + static constexpr int kTileM = kTileM_; + static constexpr int kTileN = kTileN_; + static constexpr int kTileK = kTileK_; + static constexpr int kStage = kStage_; + static constexpr int kWarpgroupM = kWarpgroupM_; + static constexpr int kWarpgroupN = kWarpgroupN_; + + using SLayoutXAtom = decltype(slayout_selector()); + using SLayoutWAtom = 
decltype(slayout_selector()); + using SLayoutYAtom = decltype(slayout_selector()); + + using SLayoutX = decltype(tile_to_shape(SLayoutXAtom{}, + make_shape(Int{}, Int{}, Int{}))); + using SLayoutW = decltype(tile_to_shape(SLayoutWAtom{}, + make_shape(Int{}, Int{}, Int{}))); + using SLayoutY = + decltype(tile_to_shape(SLayoutYAtom{}, make_shape(Int{}, Int{}))); + using CopyBoxY = decltype(tile_to_shape(SLayoutYAtom{}, + make_shape(Int{}, Int{}))); + + template + auto get_tma(TX x, TW w, TY y) { + auto tma_x = make_tma_copy(SM90_TMA_LOAD{}, x, take<0, 2>(SLayoutX{})); + auto tma_w = make_tma_copy(SM90_TMA_LOAD{}, w, take<0, 2>(SLayoutW{})); + auto tma_y = make_tma_copy(SM90_TMA_STORE{}, y, CopyBoxY{}); + return std::make_tuple(tma_x, tma_w, tma_y); + } + + using WarpgroupLayout = + decltype(make_layout(make_shape(Int{}, Int{}, Int<1>{}))); + using TiledMma = decltype(make_tiled_mma(mma_selector_bf16(), WarpgroupLayout{})); + + static constexpr int shm_xw = (cosize(SLayoutX{}) + cosize(SLayoutW{})) * sizeof(Tin); + static constexpr int shm_y = cosize(SLayoutY{}) * sizeof(Tout); + static constexpr int shm_size = shm_xw + shm_y; + + auto get_shm_size() { return shm_size; } +}; + } // namespace group_gemm } // namespace hpc diff --git a/src/group_gemm/entry.cc b/src/group_gemm/entry.cc index 5c74dff..9cc4c9c 100644 --- a/src/group_gemm/entry.cc +++ b/src/group_gemm/entry.cc @@ -6,7 +6,7 @@ #include #include - +#include #include "src/group_gemm/group_gemm.h" namespace hpc { @@ -137,6 +137,65 @@ torch::Tensor group_gemm_blockwise_fp8_entry( return y; } +torch::Tensor group_gemm_bf16_entry(const torch::Tensor &x, const torch::Tensor &weight, + const torch::Tensor &seqlens, + const torch::Tensor &cu_seqlens, + const int64_t num_seq_per_group_avg, + std::optional output, + std::optional tma_desc) { + auto stream = at::cuda::getCurrentCUDAStream(x.get_device()); + TORCH_CHECK(x.device().is_cuda(), "x tensor must be cuda"); + TORCH_CHECK(weight.device().is_cuda(), "weight tensor must be cuda"); + TORCH_CHECK(seqlens.device().is_cuda(), "seqlens tensor must be cuda"); + TORCH_CHECK(cu_seqlens.device().is_cuda(), "cu_seqlens tensor must be cuda"); + TORCH_CHECK(x.is_contiguous(), "x tensor a must be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor a must be contiguous"); + TORCH_CHECK(seqlens.size(0) == weight.size(0), + "seqlens and weight must share the same num_group"); + TORCH_CHECK(x.size(1) == weight.size(2), "x and weight must share the same k"); + + int m = x.size(0); + int k = x.size(1); + int n = weight.size(1); + int num_group = seqlens.size(0); + + auto options = x.options(); + torch::Tensor y; + if (output.has_value()) { + y = output.value(); + } else { + y = torch::empty({m, n}, options.dtype(torch::kBFloat16)); + } + + torch::Tensor tmas; + bool update_tma = true; + if (tma_desc.has_value()) { + tmas = tma_desc.value(); + update_tma = false; + } else { + tmas = torch::empty({num_group * 2, 128}, options); + } + + torch::Tensor tiles = torch::empty({num_group}, options.dtype(torch::kInt32)); + torch::Tensor cu_tiles = torch::empty({num_group + 1}, options.dtype(torch::kInt32)); + + const auto *x_ptr = x.const_data_ptr(); + const auto *weight_ptr = weight.const_data_ptr(); + const auto *seqlens_ptr = seqlens.const_data_ptr(); + const auto *cu_seqlens_ptr = cu_seqlens.const_data_ptr(); + auto *tmas_ptr = tmas.mutable_data_ptr(); + auto *y_ptr = y.mutable_data_ptr(); + + auto *tiles_ptr = tiles.mutable_data_ptr(); + auto *cu_tiles_ptr = cu_tiles.mutable_data_ptr(); + + 
group_gemm_bf16_async(y_ptr, x_ptr, weight_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, num_group, m, n, k, + num_seq_per_group_avg, update_tma, stream); + + return y; +} + torch::Tensor reformat_x_scale_entry(const torch::Tensor &x_scale, const torch::Tensor &seqlens, const torch::Tensor &cu_seqlens, std::optional out_x_scale, @@ -197,6 +256,12 @@ TORCH_LIBRARY_FRAGMENT(hpc, m) { m.impl("group_gemm_pertensor_fp8", torch::kCUDA, &hpc::group_gemm::group_gemm_pertensor_fp8_entry); + m.def( + "group_gemm_bf16(Tensor x, Tensor weight, Tensor seqlens, Tensor cu_seqlens, " + "int num_seq_per_group_avg, Tensor? output, Tensor? tma_desc) -> (Tensor)"); + m.impl("group_gemm_bf16", torch::kCUDA, + &hpc::group_gemm::group_gemm_bf16_entry); + m.def( "group_gemm_blockwise_fp8(Tensor x, Tensor weight, Tensor seqlens, Tensor cu_seqlens, Tensor " "xscale, Tensor wscale," diff --git a/src/group_gemm/group_gemm.h b/src/group_gemm/group_gemm.h index 6e5752a..e74447d 100644 --- a/src/group_gemm/group_gemm.h +++ b/src/group_gemm/group_gemm.h @@ -28,6 +28,12 @@ void reformat_x_scale_async(void *output_ptr, const void *xscale_ptr, const void const void *cu_seqlens_ptr, int num_group, int m, int n, int tilem, cudaStream_t stream); +void group_gemm_bf16_async(void *y_ptr, const void *x_ptr, const void *w_ptr, + const void *seqlens_ptr, const void *cu_seqlens_ptr, + void *tmas_ptr, void *tiles_ptr, + void *cu_tiles_ptr, int num_group, int m, int n, int k, + int num_seq_per_group_avg, bool update_tma, + cudaStream_t stream); } // namespace group_gemm } // namespace hpc diff --git a/src/group_gemm/group_gemm_bf16.cu b/src/group_gemm/group_gemm_bf16.cu new file mode 100644 index 0000000..9417007 --- /dev/null +++ b/src/group_gemm/group_gemm_bf16.cu @@ -0,0 +1,137 @@ +// Copyright (C) 2026 Tencent. + +#include +#include + +#include +#include "cute/tensor.hpp" +#include "src/group_gemm/config.h" +#include "src/group_gemm/group_gemm.h" +#include "src/group_gemm/kernels.cuh" + +namespace hpc { +namespace group_gemm { + +template +void launch_group_gemm_bf16(void *y_ptr, const void *x_ptr, const void *w_ptr, + const void *seqlens_ptr, const void *cu_seqlens_ptr, + void *tmas_ptr, void *tiles_ptr, void *cu_tiles_ptr, int num_group, + int m, int n, int k, bool update_tma, cudaStream_t stream) { + using namespace cute; // NOLINT + + using Tin = cute::bfloat16_t; + using Tout = cute::bfloat16_t; + + auto X = make_tensor(make_gmem_ptr(reinterpret_cast(x_ptr)), make_shape(m, k), + make_stride(k, Int<1>{})); + auto W = make_tensor(make_gmem_ptr(reinterpret_cast(w_ptr)), + make_shape(n, k, num_group), make_stride(k, Int<1>{}, n * k)); + auto Y = make_tensor(make_gmem_ptr(reinterpret_cast(y_ptr)), make_shape(n, m), + make_stride(Int<1>{}, n)); + + using Config = GroupGEMMBF16Config; + Config config; + auto [tma_x, tma_w, tma_y] = config.get_tma(X, W, Y); + + auto *tma_xy = static_cast(tmas_ptr); + + // 0. update tma + if (update_tma) { + vec_t td_xy{ + *tma_x.get_tma_descriptor(), + *tma_y.get_tma_descriptor(), + }; + + constexpr int kGroupPerThread = 8; + constexpr int kThreadPerBlock = 32; + kernels::update_grouped_tma + <<>>( + td_xy, tma_xy, (const Tin *)x_ptr, (const Tout *)y_ptr, (const int *)seqlens_ptr, + (const int *)cu_seqlens_ptr, (int *)tiles_ptr, (int *)cu_tiles_ptr, num_group, m, n, k); + } + + // 1. 
group gemm + { + int num_tile_n = (n + kTileN - 1) / kTileN; + cutlass::FastDivmod flat_divider(num_tile_n); + + // dim3 block(size(Config::TiledMma{}) + 128); + dim3 block(384); + dim3 grid(get_sm_count()); + + int shm_seq = sizeof(int) * (num_group + 1); + int shm_size = config.get_shm_size() + shm_seq; + + if (k <= 1024 || n <= 1024) { + constexpr bool IsLoopH = true; + auto kernel = + kernels::group_gemm_bf16_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); + + kernel<<>>(tma_w, tma_xy, (int *)seqlens_ptr, + (int *)tiles_ptr, (int *)cu_tiles_ptr, num_group, m, + n, k, flat_divider); + } else { + constexpr bool IsLoopH = false; + auto kernel = + kernels::group_gemm_bf16_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); + + kernel<<>>(tma_w, tma_xy, (int *)seqlens_ptr, + (int *)tiles_ptr, (int *)cu_tiles_ptr, num_group, m, + n, k, flat_divider); + } + } +} + +void group_gemm_bf16_async(void *y_ptr, const void *x_ptr, const void *w_ptr, + const void *seqlens_ptr, const void *cu_seqlens_ptr, + void *tmas_ptr, void *tiles_ptr, + void *cu_tiles_ptr, int num_group, int m, int n, int k, + int num_seq_per_group_avg, bool update_tma, + cudaStream_t stream) { + constexpr int kTileN = 128; + constexpr int kTileK = 64; + constexpr int kWarpgroupM = 2; + constexpr int kWarpgroupN = 1; + constexpr int kSwizzleX = 128; + constexpr int kSwizzleW = 128; + constexpr int kSwizzleY = 64; + if (num_seq_per_group_avg <= 16) { + constexpr int kTileM = 16; + constexpr int kStage = 8; + launch_group_gemm_bf16(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, + num_group, m, n, k, update_tma, stream); + } else if (num_seq_per_group_avg <= 32) { + constexpr int kTileM = 32; + constexpr int kStage = 8; + launch_group_gemm_bf16(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, + num_group, m, n, k, update_tma, stream); + } else if (num_seq_per_group_avg <= 48) { + constexpr int kTileM = 48; + constexpr int kStage = 8; + launch_group_gemm_bf16(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, + num_group, m, n, k, update_tma, stream); + } else { + constexpr int kTileM = 64; + constexpr int kStage = 8; + launch_group_gemm_bf16(y_ptr, x_ptr, w_ptr, seqlens_ptr, cu_seqlens_ptr, + tmas_ptr, tiles_ptr, cu_tiles_ptr, + num_group, m, n, k, update_tma, stream); + } +} + +} // namespace group_gemm +} // namespace hpc diff --git a/src/group_gemm/kernels.cuh b/src/group_gemm/kernels.cuh index aa7f06a..84973e4 100644 --- a/src/group_gemm/kernels.cuh +++ b/src/group_gemm/kernels.cuh @@ -752,6 +752,297 @@ __global__ void __launch_bounds__(384, 1) } } +template +__global__ void __launch_bounds__(384, 1) + group_gemm_bf16_kernel(const __grid_constant__ TmaB tma_b, cute::TmaDescriptor *td_xy, + int *seqlens_ptr, int *tiles_ptr, + int *cu_tiles_ptr, int num_group, int m, int n, int k, + cutlass::FastDivmod flat_divider) { + using namespace cute; // NOLINT + + using Tin = typename Config::Tin; + using Tout = typename Config::Tout; + using TiledMma = typename Config::TiledMma; + using SLayoutA = typename Config::SLayoutX; + using SLayoutB = typename Config::SLayoutW; + using SLayoutCT = typename Config::SLayoutY; + + constexpr int kTileM = Config::kTileM; + constexpr int kTileN = Config::kTileN; + constexpr int kTileK = Config::kTileK; + constexpr int kStage = Config::kStage; + + int idx = threadIdx.x; + + int iwarp = 
__shfl_sync(0xFFFFFFFF, idx / 32, 0); + int elected = cute::elect_one_sync(); + bool is_leader_in_block = (iwarp == 0) && elected; + bool is_leader_in_warpgroup = ((iwarp % 4) == 0) && elected; + + __shared__ uint64_t writable[kStage]; + __shared__ uint64_t readable[kStage]; + + extern __shared__ uint8_t shm_data[] alignas(128); + auto *shm_a = reinterpret_cast(shm_data); + auto *shm_b = shm_a + cosize(SLayoutA{}); + auto *shm_c = reinterpret_cast(shm_b + cosize(SLayoutB{})); + int *shm_tiles = reinterpret_cast(shm_c + cosize(SLayoutCT{})); + + TmaA tma_a; + TmaD tma_d; + + int num_total_warps = blockDim.x / 32; + for (int i = iwarp; i < num_group * 2; i += num_total_warps) { + tma_descriptor_fence_acquire(td_xy + i); + } + + auto sA = make_tensor(make_smem_ptr(shm_a), SLayoutA{}); + auto sB = make_tensor(make_smem_ptr(shm_b), SLayoutB{}); + + auto gA = tma_a.get_tma_tensor(make_shape(m, k)); + auto gB = tma_b.get_tma_tensor(make_shape(n, k, num_group)); + auto gC = + make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_shape(Int{}, Int{}), make_stride(Int{}, Int<1>{})); + + auto btma_a = tma_a.get_slice(0); + auto btma_b = tma_b.get_slice(0); + + auto tAg = btma_a.partition_S(gA); // (TMA, TMA_M, TMA_K) + auto tAs = btma_a.partition_D(sA); // (TMA, _1, _1, kStage) + + auto tBg = btma_b.partition_S(gB); // (TMA, TMA_N, TMA_K, num_group) + auto tBs = btma_b.partition_D(sB); // (TMA, _1, _1, kStage) + + int num_tile_n = size<1>(tBg); + + if (is_leader_in_block) { +#pragma unroll + for (int i = 0; i < kStage; ++i) { + initialize_barrier(readable[i], 1); + initialize_barrier(writable[i], size(TiledMma{}) / 128); + } + } + + // we can also use the following code to initialize the barrier + /* + if (idx < kStage) { + readable[idx] = 0x7ffff800001ffffe; // initialize_barrier(1); + writable[idx] = 0x7ffff000001ffffc; // initialize_barrier(2); + } + */ + + int total_m = cu_tiles_ptr[num_group]; + if (total_m <= 0) { + return; + } + + if constexpr (IsLoopH) { + for (int i = idx; i < num_group; i += blockDim.x) { + shm_tiles[i] = tiles_ptr[i]; + } + } else { + for (int i = idx; i < (num_group + 1); i += blockDim.x) { + shm_tiles[i] = cu_tiles_ptr[i]; + } + } + + // sync to avoid ahead thread use(wait) readable when it is not initizlized yet + __syncthreads(); + + constexpr int kNumThreads = size(TiledMma{}); + // load warpgroup + if (idx >= kNumThreads) { + cutlass::arch::warpgroup_reg_dealloc<24>(); + idx -= kNumThreads; + constexpr int kTransactionBytes = sizeof(Tin) * (kTileM + kTileN) * kTileK; + // sizeof(Tin) * cosize(SLayoutA{}(_, _, 0)) + sizeof(Tin) * cosize(SLayoutB{}(_, _, 0)); + + int iwarp = __shfl_sync(0xFFFFFFFF, idx / 32, 0); + int is_leader_in_load = ((iwarp == 0) && elected); + + if (is_leader_in_load) { + int phase = 1; // start with ok + int ismem_write = __shfl_sync(0xFFFFFFFF, 0, 0); + int iblock = blockIdx.x; + int ntile_k = size<2>(tAg); + + int igroup = 0; + int sum_tile_m = 0; + int itile_m, itile_n; + while (true) { + if constexpr (IsLoopH) { + get_next_tile_horizon(shm_tiles, iblock, num_group, igroup, itile_m, itile_n, sum_tile_m, + flat_divider); + if (igroup < 0) { + break; + } + } else { + get_next_tile_vert(shm_tiles, iblock, num_group, igroup, itile_m, itile_n, total_m); + if (itile_n >= num_tile_n) { + break; + } + } + + iblock += gridDim.x; + + auto *td_x = td_xy + igroup * 2; + +#pragma unroll 1 + for (int itile_k = 0; itile_k < ntile_k; ++itile_k) { + // load a, b + wait_barrier(writable[ismem_write], phase); + + cute::copy(tma_a.with(td_x, 
readable[ismem_write]), tAg(_, itile_m, itile_k), + tAs(_, 0, 0, ismem_write)); + + cute::copy(tma_b.with(readable[ismem_write]), tBg(_, itile_n, itile_k, igroup), + tBs(_, 0, 0, ismem_write)); + + set_barrier_transaction_bytes(readable[ismem_write], kTransactionBytes); + + ++ismem_write; + if (ismem_write == kStage) { + ismem_write = 0; + phase ^= 1; + } + } // ntile_todo + } // while + } // if idx == 0 + + } else { + // math warpgroup + cutlass::arch::warpgroup_reg_alloc<168>(); + + int idx_in_warpgroup = idx % 128; + int iwarpgroup = idx / 128; + int iwarp_in_warpgroup = idx_in_warpgroup / 32; + int elected_idx_in_warpgroup = ((iwarp_in_warpgroup == 0) && elected); + + TiledMma tiled_mma; + + auto thr_mma = tiled_mma.get_slice(idx); + auto tBs4r = thr_mma.partition_A(sB); + auto tAs4r = thr_mma.partition_B(sA); + + auto tBr = thr_mma.make_fragment_A(tBs4r); // (MMA, MMA_N, MMA_K, kStage) + auto tAr = thr_mma.make_fragment_B(tAs4r); // (MMA, MMA_M, MMA_K, kStage) + + auto tCr = thr_mma.partition_fragment_C(gC); + + int ismem_read = 0; + int phase = 0; + + int iblock = blockIdx.x; + int igroup = 0; + int sum_tile_m = 0; + int itile_m, itile_n; + while (true) { + if constexpr (IsLoopH) { + get_next_tile_horizon(shm_tiles, iblock, num_group, igroup, itile_m, itile_n, sum_tile_m, + flat_divider); + if (igroup < 0) { + break; + } + } else { + get_next_tile_vert(shm_tiles, iblock, num_group, igroup, itile_m, itile_n, total_m); + if (itile_n >= num_tile_n) { + break; + } + } + + iblock += gridDim.x; + + int ntile_k = size<2>(tAg); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + // Pipelined K-loop: keep 1 WGMMA batch in flight (wait<1> instead of + // wait<0> per iter). The writable barrier of stage k is released only + // after wait<1> in iter k+1, so the MMA of iter k+1 can overlap with + // the producer's reload of stage k. The accumulator tCr only needs to + // be fenced once before the first WGMMA and once after the last drain. + int prev_ismem_read = 0; + warpgroup_fence_operand(tCr); +#pragma unroll 1 + for (int itile_k = 0; itile_k < ntile_k; ++itile_k) { + wait_barrier(readable[ismem_read], phase); + + // mma + warpgroup_arrive(); +#pragma unroll + for (int ik = 0; ik < size<2>(tAr); ++ik) { + cute::gemm(tiled_mma, tBr(_, _, ik, ismem_read), tAr(_, _, ik, ismem_read), tCr(_, _, _)); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + if (itile_k >= 1) { + // Previous batch is now done; release the stage it consumed so + // the producer can refill it while the current batch is still + // in flight. + warpgroup_wait<1>(); + if (elected_idx_in_warpgroup) { + arrive_barrier(writable[prev_ismem_read]); + } + } + + prev_ismem_read = ismem_read; + ++ismem_read; + if (ismem_read == kStage) { + phase ^= 1; + ismem_read = 0; + } + } + + // Drain the last in-flight batch and release its stage. 
+ warpgroup_wait<0>(); + warpgroup_fence_operand(tCr); + if (elected_idx_in_warpgroup) { + arrive_barrier(writable[prev_ismem_read]); + } + + // float32 -> bfloat16 + auto tCrh = make_tensor_like(tCr); + +#pragma unroll + for (int i = 0; i < size(tCr); ++i) { + tCrh(i) = (Tout)(tCr(i)); + } + + // Epilogue + auto sCT = + make_tensor(make_smem_ptr(reinterpret_cast(shm_c)), SLayoutCT{}); // (M, N) + using R2SCopyAtomC = Copy_Atom; + // using R2SCopyAtomC = Copy_Atom; + auto tiled_copy_c = make_tiled_copy_C(R2SCopyAtomC{}, tiled_mma); + auto thr_copy_c = tiled_copy_c.get_slice(idx); + + auto tCr4s = thr_copy_c.retile_S(tCrh); + auto tCs4r = thr_copy_c.partition_D(sCT); + + tma_store_wait<0>(); + syncwarpgroup(iwarpgroup); + + cute::copy(tiled_copy_c, tCr4s, tCs4r); + syncwarpgroup(iwarpgroup); + cute::tma_store_fence(); + + if (is_leader_in_warpgroup) { + auto gD = tma_d.get_tma_tensor(make_shape(n, m)); + auto btma_d = tma_d.get_slice(0); + + auto tDs = btma_d.partition_S(sCT); // (TMA, _2, _1) + auto tDg = btma_d.partition_D(gD); // (TMA, TMA_M, TMA_N) + + auto *td_y = td_xy + igroup * 2 + 1; + cute::copy(tma_d.with(td_y), tDs(_, iwarpgroup, Int<0>{}), + tDg(_, itile_n * 2 + iwarpgroup, itile_m)); + tma_store_arrive(); + } + } + } +} + } // namespace kernels } // namespace group_gemm diff --git a/tests/test_fuse_moe_bf16.py b/tests/test_fuse_moe_bf16.py new file mode 100644 index 0000000..7874733 --- /dev/null +++ b/tests/test_fuse_moe_bf16.py @@ -0,0 +1,201 @@ +import os +import sys +from pathlib import Path + +sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) + +import math +import pytest +import torch + +import hpc +from utils import allclose + +# Set random seed for reproducibility +torch.manual_seed(41) +torch.cuda.manual_seed(41) + + +def naive_gather_expert_inputs(x, topk_ids, num_expert, rank_ep): + num_tokens, num_topk = topk_ids.shape + num_seq, hidden_size = x.shape + + unique_values, num_tokens_per_expert_partial = torch.unique( + topk_ids.flatten(), return_counts=True, sorted=True + ) + start_expert = rank_ep * num_expert + end_expert = (rank_ep + 1) * num_expert + mask = (unique_values >= start_expert) & (unique_values < end_expert) + unique_values = unique_values[mask] + num_tokens_per_expert_partial = num_tokens_per_expert_partial[mask] + num_tokens_per_expert = torch.full([num_expert], 0, dtype=torch.int32, device="cuda") + + for i in range(unique_values.numel()): + num_tokens_per_expert[unique_values[i] - start_expert] = num_tokens_per_expert_partial[i] + + cu_num_tokens_per_expert = torch.cumsum( + torch.cat([torch.tensor([0], dtype=torch.int32, device="cuda"), num_tokens_per_expert]), + dim=0, + ).to(torch.int32) + + y = torch.zeros((num_tokens * num_topk, hidden_size), dtype=x.dtype, device=x.device) + token_pos = torch.zeros((num_tokens, num_topk), dtype=torch.int32, device=x.device) + token_pos.fill_(-1) + + # reset + num_tokens_per_expert.fill_(0) + + for idx, iexpert in enumerate(topk_ids.flatten()): + itoken = idx // num_topk + icol = idx % num_topk + if iexpert >= start_expert and iexpert < end_expert: + pos = ( + cu_num_tokens_per_expert[iexpert - start_expert] + + num_tokens_per_expert[iexpert - start_expert] + ) + y[pos] = x[itoken] + token_pos[itoken, icol] = pos.item() + num_tokens_per_expert[iexpert - start_expert] += 1 + + return ( + y, + token_pos, + num_tokens_per_expert, + cu_num_tokens_per_expert, + unique_values, + ) + + +def naive_group_gemm_bf16(x, w, cu_seqlens): + m, k = x.shape + num_group, n, _ = w.shape 
+ + y = torch.zeros((m, n), dtype=torch.bfloat16, device=x.device) + + start_idx = 0 + for i in range(num_group): + start_idx = cu_seqlens[i].item() + end_idx = cu_seqlens[i + 1].item() + + x_group = x[start_idx:end_idx] + w_group = w[i] + + y_group = x_group @ w_group.t() + y[start_idx:end_idx] = y_group + + return y + + +def naive_act_mul_bf16(gate_up): + def silu(x): + return x / (1 + (-x).exp()) + + gate, up = torch.chunk(gate_up, 2, dim=1) + out = silu(gate) * up + return out + + +def naive_reduce(x_bf16, topk_pos, topk_scale, shared_output=None): + num_seq, num_topk = topk_pos.shape + total_num_seq, hidden_size = x_bf16.shape + + y_bf16 = torch.zeros((num_seq, hidden_size), dtype=torch.bfloat16, device=x_bf16.device) + for i in range(num_seq): + y_bf16[i] = torch.sum(x_bf16[topk_pos[i]] * topk_scale[i].unsqueeze(1), dim=0) + if shared_output is not None: + y_bf16[i] += shared_output[i] + + return y_bf16 + + +def naive_fuse_moe_bf16( + x, + gate_up_weight, + down_weight, + topk_ids, + topk_scale, + rank_ep, + shared_output=None, +): + num_expert = gate_up_weight.size(0) + # count_and_gather + gate_up_input, topk_pos, seqlens, cu_seqlens, expert_ids = naive_gather_expert_inputs( + x, topk_ids, num_expert, rank_ep + ) + + # gate_up_proj + gate_up_output = naive_group_gemm_bf16(gate_up_input, gate_up_weight, cu_seqlens) + + # act_and_mul + down_input = naive_act_mul_bf16(gate_up_output) + + # down_proj + down_output = naive_group_gemm_bf16(down_input, down_weight, cu_seqlens) + + # reduce + y = naive_reduce(down_output, topk_pos, topk_scale, shared_output) + + return y + + +@pytest.mark.parametrize("num_seq", [128]) +@pytest.mark.parametrize("num_topk", [8]) +@pytest.mark.parametrize("hidden_size", [512]) +@pytest.mark.parametrize("intermediate_size", [512]) +@pytest.mark.parametrize("num_expert", [128]) +@pytest.mark.parametrize("rank_ep", [0, 1]) +@pytest.mark.parametrize("size_ep", [1, 4, 8]) +@pytest.mark.parametrize("has_shared_output", [False, True]) +def test_fuse_moe_bf16( + num_seq, + num_topk, + hidden_size, + intermediate_size, + num_expert, + rank_ep, + size_ep, + has_shared_output, +): + dtype = torch.bfloat16 + + topk_ids = torch.randint(0, num_expert, (num_seq, num_topk), dtype=torch.int32, device="cuda") + topk_ids, _ = torch.sort(topk_ids, dim=1) + + x = torch.randn((num_seq, hidden_size), dtype=dtype, device="cuda") / 100 + gate_up_weight = torch.randn( + (num_expert // size_ep, intermediate_size * 2, hidden_size), + dtype=dtype, + device="cuda", + ) + down_weight = torch.randn( + (num_expert // size_ep, hidden_size, intermediate_size), + dtype=dtype, + device="cuda", + ) + topk_scale = torch.randn((num_seq, num_topk), dtype=torch.float, device="cuda") / num_topk + if has_shared_output: + shared_output = torch.randn((num_seq, hidden_size), dtype=dtype, device="cuda") + else: + shared_output = None + + my = hpc.fuse_moe_bf16( + x, + gate_up_weight, + down_weight, + topk_ids, + topk_scale, + rank_ep, + num_expert // size_ep, + shared_output=shared_output, + ) + gt = naive_fuse_moe_bf16( + x, + gate_up_weight, + down_weight, + topk_ids, + topk_scale, + rank_ep, + shared_output, + ) + + assert allclose(gt.to(torch.float32), my.to(torch.float32), rtol=0.1, atol=0.1) diff --git a/tests/test_group_gemm_bf16.py b/tests/test_group_gemm_bf16.py new file mode 100644 index 0000000..34a5b0d --- /dev/null +++ b/tests/test_group_gemm_bf16.py @@ -0,0 +1,65 @@ +import os +import sys +from pathlib import Path + +sys.path.insert(0, 
os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) + +import torch +import pytest +import hpc +from utils import allclose + +torch.manual_seed(41) +torch.cuda.manual_seed(41) + + +def naive_group_gemm_bf16(x, w, seqlens, cu_seqlens): + m, k = x.shape + num_group, n, _ = w.shape + + y = torch.zeros((m, n), dtype=torch.bfloat16, device=x.device) + + for i in range(num_group): + start_idx = int(cu_seqlens[i].item()) + end_idx = int(cu_seqlens[i + 1].item()) + + if seqlens[i].item() == 0: + continue + + x_group = x[start_idx:end_idx] # [M_i, K] + w_group = w[i] # [N, K] + + # y_group = x_group @ w_group.T + y_group = torch.matmul(x_group, w_group.t()) + + y[start_idx:end_idx] = y_group + + return y + + +@pytest.mark.parametrize("num_group", [8]) +@pytest.mark.parametrize("actual_m", [8, 16, 32, 64, 128, 256, 512]) +@pytest.mark.parametrize("m", [512]) +@pytest.mark.parametrize("n", [4096]) +@pytest.mark.parametrize("k", [7168]) +def test_group_gemm_bf16(num_group, actual_m, m, n, k): + dtype = torch.bfloat16 + + seqlens = torch.full((num_group,), actual_m, dtype=torch.int32, device="cuda") + total_seq = torch.sum(seqlens) + mean_seq = int(total_seq / num_group) + + x = torch.randn((total_seq, k), dtype=dtype, device="cuda") + w = torch.randn((num_group, n, k), dtype=dtype, device="cuda") + + cu_seqlens = torch.zeros(num_group + 1, dtype=torch.int32, device="cuda") + cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) + gt = naive_group_gemm_bf16(x, w, seqlens, cu_seqlens) + + my = hpc.group_gemm_bf16(x, w, seqlens, cu_seqlens, num_seq_per_group_avg=mean_seq) + + assert allclose(gt.to(torch.float32), my.to(torch.float32), rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + test_group_gemm_bf16(8, 128, 512, 4096, 7168)
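Reviewer note (not part of the diff): for anyone wanting to exercise the two new public ops end to end, a minimal usage sketch follows. It only calls the Python wrappers added in this diff, hpc.fuse_moe_bf16 and hpc.group_gemm_bf16, and the tensor shapes mirror the new unit tests (tests/test_fuse_moe_bf16.py and tests/test_group_gemm_bf16.py); the concrete sizes are illustrative, not requirements beyond the dtype/device/contiguity checks enforced in entry.cc.

# Hypothetical usage sketch for the new BF16 ops; sizes mirror the unit tests.
import torch
import hpc

device = "cuda"

# ---- fused MoE forward (gather -> gate/up GEMM -> SiLU*mul -> down GEMM -> reduce) ----
num_seq, num_topk = 128, 8
hidden_size, intermediate_size = 512, 512
num_expert_local = 128   # experts owned by this expert-parallel rank
rank_ep = 0

x = torch.randn(num_seq, hidden_size, dtype=torch.bfloat16, device=device)
gate_up_w = torch.randn(num_expert_local, 2 * intermediate_size, hidden_size,
                        dtype=torch.bfloat16, device=device)
down_w = torch.randn(num_expert_local, hidden_size, intermediate_size,
                     dtype=torch.bfloat16, device=device)
topk_ids, _ = torch.sort(
    torch.randint(0, num_expert_local, (num_seq, num_topk),
                  dtype=torch.int32, device=device), dim=1)
topk_scale = torch.rand(num_seq, num_topk, dtype=torch.float32, device=device) / num_topk

y = hpc.fuse_moe_bf16(x, gate_up_w, down_w, topk_ids, topk_scale,
                      rank_ep, num_expert_local)  # [num_seq, hidden_size], bf16

# ---- stand-alone grouped GEMM over variable-length groups ----
num_group, n, k = 8, 4096, 7168          # same N/K as test_group_gemm_bf16.py
seqlens = torch.full((num_group,), 32, dtype=torch.int32, device=device)
cu_seqlens = torch.zeros(num_group + 1, dtype=torch.int32, device=device)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
xg = torch.randn(int(seqlens.sum()), k, dtype=torch.bfloat16, device=device)
wg = torch.randn(num_group, n, k, dtype=torch.bfloat16, device=device)

out = hpc.group_gemm_bf16(xg, wg, seqlens, cu_seqlens,
                          num_seq_per_group_avg=32)  # [sum(seqlens), n], bf16

With a single expert-parallel rank the local and total expert counts coincide, which is why num_expert_local is passed as the final positional argument of hpc.fuse_moe_bf16 here, matching what the new test does with num_expert // size_ep.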