Skip to content

feat: bf16 fused moe kernel#41

Open
ZelinMa557 wants to merge 7 commits into
Tencent:mainfrom
ZelinMa557:bf16_moe
Open

feat: bf16 fused moe kernel#41
ZelinMa557 wants to merge 7 commits into
Tencent:mainfrom
ZelinMa557:bf16_moe

Conversation

@ZelinMa557
Copy link
Copy Markdown

Based on #34, I implemented bf16 fused moe kernel. The API of this kernel follows the style of fuse_moe_pertensor_fp8:

def fuse_moe_bf16(
    x: Tensor,
    gate_up_weight: Tensor,
    down_weight: Tensor,
    topk_ids: Tensor,
    topk_scale: Tensor,
    rank_ep: int,
    num_expert_total: int,
    shared_output: Tensor = None,
) -> Tensor:

This kernel significantly outperforms SGLang's Triton version when most experts are activated or the average tokens per expert <= 32. It provides a substantial speedup for the decoding and target verify stage in concurrent inference, addressing the lack of low-latency BF16 MoE decoding kernels in the open-source ecosystem.

performance test:
Qwen/Qwen3-235B-A22B tp8:

python tests/bench_fuse_moe_bf16.py  --hidden-size 4096 --intermediate-size 1536 --num-experts 128 --topk 8 --tp-size 8
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.

Device              : NVIDIA H20
Model config        : hidden=4096, inter=192, experts=128, topk=8, tp=8
Local experts/GPU   : 128
Weight shapes       : gate_up=[128, 384, 4096], down=[128, 4096, 192]
Timing              : warmup=5, iters=100 (CUDA graph replay)

  batch    tokens     hpc(ms)    hpc(TF)     sgl(ms)    sgl(TF)   speedup
-------------------------------------------------------------------------
     16       128      0.1534       3.94      0.1664       3.63      1.08x
     32       256      0.1933       6.25      0.2913       4.15      1.51x
     48       384      0.2042       8.87      0.3112       5.82      1.52x
     64       512      0.2093      11.54      0.3221       7.50      1.54x
     80       640      0.2160      13.98      0.3326       9.08      1.54x
     96       768      0.2170      16.70      0.3333      10.87      1.54x
    112       896      0.2192      19.29      0.3321      12.73      1.51x
    128      1024      0.2235      21.62      0.3322      14.54      1.49x
    144      1152      0.2221      24.48      0.3344      16.26      1.51x
    160      1280      0.2309      26.16      0.3343      18.07      1.45x
    176      1408      0.2313      28.73      0.3364      19.75      1.45x
    192      1536      0.2321      31.23      0.3367      21.53      1.45x
    208      1664      0.2348      33.44      0.3384      23.20      1.44x
    224      1792      0.2415      35.01      0.3392      24.93      1.40x
    240      1920      0.2459      36.84      0.3405      26.61      1.38x
    256      2048      0.2474      39.06      0.3391      28.50      1.37x
    272      2176      0.2546      40.32      0.3428      29.95      1.35x
    288      2304      0.2565      42.39      0.3431      31.69      1.34x
    304      2432      0.2577      44.53      0.3456      33.20      1.34x
    320      2560      0.2574      46.94      0.3433      35.19      1.33x
    336      2688      0.2605      48.69      0.3453      36.73      1.33x
    352      2816      0.2610      50.92      0.3461      38.39      1.33x
    368      2944      0.2698      51.48      0.3500      39.69      1.30x
    384      3072      0.2767      52.38      0.3481      41.64      1.26x
    400      3200      0.2783      54.25      0.3502      43.12      1.26x
    416      3328      0.2795      56.17      0.3495      44.93      1.25x
    432      3456      0.2814      57.95      0.3534      46.14      1.26x
    448      3584      0.2853      59.27      0.3521      48.02      1.23x
    464      3712      0.2873      60.98      0.3565      49.13      1.24x
    480      3840      0.3083      58.78      0.3556      50.96      1.15x
    496      3968      0.3132      59.77      0.3584      52.25      1.14x
    512      4096      0.3156      61.24      0.3565      54.21      1.13x
-------------------------------------------------------------------------

Qwen/Qwen3.5-122B-A10B tp4:

python tests/bench_fuse_moe_bf16.py  --hidden-size 3072 --intermediate-size 1024 --num-experts 128 --topk 8 --tp-size 4
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.

Device              : NVIDIA H20
Model config        : hidden=3072, inter=256, experts=128, topk=8, tp=4
Local experts/GPU   : 128
Weight shapes       : gate_up=[128, 512, 3072], down=[128, 3072, 256]
Timing              : warmup=5, iters=100 (CUDA graph replay)

  batch    tokens     hpc(ms)    hpc(TF)     sgl(ms)    sgl(TF)   speedup
-------------------------------------------------------------------------
     16       128      0.1488       4.06      0.1595       3.79      1.07x
     32       256      0.1864       6.48      0.2939       4.11      1.58x
     48       384      0.2028       8.93      0.3001       6.04      1.48x
     64       512      0.2058      11.74      0.3058       7.90      1.49x
     80       640      0.2117      14.26      0.3211       9.41      1.52x
     96       768      0.2121      17.09      0.3221      11.25      1.52x
    112       896      0.2141      19.75      0.3232      13.08      1.51x
    128      1024      0.2153      22.45      0.3260      14.82      1.51x
    144      1152      0.2165      25.10      0.3250      16.73      1.50x
    160      1280      0.2179      27.71      0.3279      18.42      1.50x
    176      1408      0.2189      30.36      0.3268      20.33      1.49x
    192      1536      0.2265      32.00      0.3270      22.16      1.44x
    208      1664      0.2266      34.65      0.3261      24.08      1.44x
    224      1792      0.2348      36.02      0.3280      25.78      1.40x
    240      1920      0.2377      38.12      0.3273      27.68      1.38x
    256      2048      0.2421      39.92      0.3288      29.39      1.36x
    272      2176      0.2462      41.70      0.3296      31.15      1.34x
    288      2304      0.2480      43.84      0.3306      32.89      1.33x
    304      2432      0.2491      46.07      0.3321      34.55      1.33x
    320      2560      0.2499      48.34      0.3316      36.42      1.33x
    336      2688      0.2505      50.64      0.3315      38.26      1.32x
    352      2816      0.2517      52.79      0.3339      39.79      1.33x
    368      2944      0.2525      55.01      0.3357      41.38      1.33x
    384      3072      0.2537      57.13      0.3363      43.10      1.33x
    400      3200      0.2665      56.67      0.3355      45.01      1.26x
    416      3328      0.2688      58.42      0.3372      46.57      1.25x
    432      3456      0.2679      60.86      0.3370      48.39      1.26x
    448      3584      0.2719      62.19      0.3383      49.99      1.24x
    464      3712      0.2843      61.62      0.3403      51.47      1.20x
    480      3840      0.2903      62.41      0.3410      53.13      1.17x
    496      3968      0.3028      61.83      0.3411      54.89      1.13x
    512      4096      0.3090      62.55      0.3430      56.35      1.11x
-------------------------------------------------------------------------

Signed-off-by: ZelinMa557 <3388706467@qq.com>
Signed-off-by: ZelinMa557 <3388706467@qq.com>
Signed-off-by: ZelinMa557 <3388706467@qq.com>
Signed-off-by: ZelinMa557 <3388706467@qq.com>
Signed-off-by: ZelinMa557 <3388706467@qq.com>
@ZelinMa557
Copy link
Copy Markdown
Author

benchmark script:

"""
Benchmark script for hpc.fuse_moe_bf16 vs sglang fused_moe (BF16).

Default model: Qwen3-235B-A22B, TP=8 (EP=8 simulation on a single GPU)
  hidden_size        = 4096
  moe_intermediate_size = 1536
  num_experts        = 128  ->  num_experts_local = 128 // 8 = 16 per GPU
  num_experts_per_tok = 8 (topk)

Usage:
    python tests/bench_fuse_moe_bf16.py
    python tests/bench_fuse_moe_bf16.py --hidden-size 4096 --intermediate-size 1536 \\
        --num-experts 128 --topk 8 --tp-size 8
    python tests/bench_fuse_moe_bf16.py --warmup 10 --iters 100
"""

import argparse
import os
import sys
from pathlib import Path

sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0]))

import triton.language as tl
import torch
import hpc
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import invoke_fused_moe_kernel
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
from sgl_kernel import silu_and_mul

# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--hidden-size", type=int, default=4096)
    p.add_argument("--intermediate-size", type=int, default=1536,
                   help="moe_intermediate_size per expert (before gate/up split)")
    p.add_argument("--num-experts", type=int, default=128)
    p.add_argument("--topk", type=int, default=8)
    p.add_argument("--tp-size", type=int, default=8,
                   help="Tensor/Expert parallelism size. num_experts_local = num_experts // tp_size")
    p.add_argument("--warmup", type=int, default=5,
                   help="Warmup iterations before graph capture")
    p.add_argument("--iters", type=int, default=100,
                   help="Graph replay iterations for timing")
    return p.parse_args()


# ---------------------------------------------------------------------------
# Batch sizes to sweep
# ---------------------------------------------------------------------------

BATCH_SIZES = [i for i in range(16, 513, 16)]


# ---------------------------------------------------------------------------
# Input helpers
# ---------------------------------------------------------------------------


def make_inputs(batch_size, hidden_size, intermediate_size, num_experts_local, topk,
                device="cuda"):
    """
    Build BF16 inputs for a single-GPU EP benchmark.
    topk_ids are sampled uniformly from local experts [0, num_experts_local).
    """
    dtype = torch.bfloat16

    x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device)

    # gate+up fused: [E_local, inter*2, hidden]
    gate_up_weight = torch.randn(
        (num_experts_local, intermediate_size * 2, hidden_size), dtype=dtype, device=device
    )
    # down:          [E_local, hidden, inter]
    down_weight = torch.randn(
        (num_experts_local, hidden_size, intermediate_size), dtype=dtype, device=device
    )

    # topk_ids in [0, num_experts_local); topk_scale positive, sum-normalised
    topk_ids = torch.randint(
        0, num_experts_local, (batch_size, topk), dtype=torch.int32, device=device
    )
    raw_scale = torch.rand((batch_size, topk), dtype=torch.float32, device=device)
    topk_scale = raw_scale / raw_scale.sum(dim=1, keepdim=True)

    return x, gate_up_weight, down_weight, topk_ids, topk_scale


# ---------------------------------------------------------------------------
# FLOPs estimate (two GEMMs: gate_up and down, per token-expert assignment)
# ---------------------------------------------------------------------------


def tflops(batch_size, topk, hidden_size, intermediate_size, elapsed_ms):
    # gate_up:  batch*topk × (inter*2) × hidden  (factor 2 for gate+up)
    # down:     batch*topk × hidden × inter
    tokens = batch_size * topk
    flops = 2 * tokens * intermediate_size * 2 * hidden_size  # gate_up gemm
    flops += 2 * tokens * hidden_size * intermediate_size       # down gemm
    return flops / (elapsed_ms * 1e-3) / 1e12


# ---------------------------------------------------------------------------
# CUDA-graph benchmarking helper
# ---------------------------------------------------------------------------


def bench_cuda_graph(fn, warmup, iters):
    """Warm up, capture a CUDA graph, replay `iters` times, return avg ms."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    capture_stream = torch.cuda.Stream()
    with torch.cuda.stream(capture_stream):
        fn()
        torch.cuda.synchronize()
        with torch.cuda.graph(g, stream=capture_stream):
            fn()

    torch.cuda.synchronize()

    t0 = torch.cuda.Event(enable_timing=True)
    t1 = torch.cuda.Event(enable_timing=True)
    t0.record()
    for _ in range(iters):
        g.replay()
    t1.record()
    torch.cuda.synchronize()
    return t0.elapsed_time(t1) / iters


# ---------------------------------------------------------------------------
# sglang default BF16 config
# ---------------------------------------------------------------------------


def sglang_bf16_config(total_m, E):
    """
    Static BF16 config mirroring sglang's get_default_config logic:
      total_m <= E  ->  small-batch config (BLOCK_SIZE_M=16)
      total_m >  E  ->  regular config    (BLOCK_SIZE_M=64)
    """
    if total_m <= E:
        return {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
            "num_warps": 4,
            "num_stages": 3,
        }
    else:
        return {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
            "num_warps": 4,
            "num_stages": 3,
        }


# ---------------------------------------------------------------------------
# Per-kernel bench helpers
# ---------------------------------------------------------------------------


def bench_hpc(x, gate_up_weight, down_weight, topk_ids, topk_scale,
              num_experts_local, warmup, iters):
    """Benchmark hpc.fuse_moe_bf16 (full pipeline)."""
    # rank_ep=0, num_expert_total=num_experts_local: all experts are local
    def fn():
        hpc.fuse_moe_bf16(
            x, gate_up_weight, down_weight,
            topk_ids, topk_scale,
            rank_ep=0,
            num_expert_total=num_experts_local,
        )

    return bench_cuda_graph(fn, warmup, iters)


def bench_sglang(x, gate_up_weight, down_weight, topk_ids, topk_scale,
                 num_experts_local, warmup, iters):
    """
    Benchmark sglang full fused_moe pipeline (BF16) via direct kernel calls.
    All steps are inside the CUDA graph for a fair comparison with hpc.fuse_moe_bf16:
      0. sgl_moe_align_block_size  token sorting  (= hpc count_and_gather)
      1. invoke_fused_moe_kernel   gate_up GEMM   (mul_routed_weight=False)
      2. silu_and_mul              SiLU activation
      3. invoke_fused_moe_kernel   down GEMM      (mul_routed_weight=True, top_k=1)
      4. torch.sum over topk dim   weighted reduce
    """
    batch_size = x.shape[0]
    hidden_size = x.shape[1]
    inter_x2 = gate_up_weight.shape[1]   # intermediate_size * 2
    inter = inter_x2 // 2
    topk = topk_ids.shape[1]
    total_tokens = batch_size * topk

    config = sglang_bf16_config(total_tokens, num_experts_local)
    block_size = config["BLOCK_SIZE_M"]

    # Pre-allocate routing buffers at max possible size (shapes are static for CUDA graph)
    if topk_ids.numel() < num_experts_local + 1:
        max_padded = topk_ids.numel() * block_size
    else:
        max_padded = topk_ids.numel() + (num_experts_local + 1) * (block_size - 1)
    max_m_blocks = (max_padded + block_size - 1) // block_size

    sorted_ids          = torch.empty((max_padded,),              dtype=torch.int32, device=x.device)
    expert_ids          = torch.empty((max_m_blocks,),            dtype=torch.int32, device=x.device)
    num_tokens_post_pad = torch.empty((1,),                       dtype=torch.int32, device=x.device)
    cumsum_buf          = torch.empty((num_experts_local + 2,),   dtype=torch.int32, device=x.device)

    # Intermediate compute buffers (max padded size)
    cache1 = torch.empty((max_padded, inter_x2),        dtype=torch.bfloat16, device=x.device)
    cache2 = torch.empty((max_padded, inter),            dtype=torch.bfloat16, device=x.device)
    cache3 = torch.empty((batch_size, topk, hidden_size), dtype=torch.bfloat16, device=x.device)
    out    = torch.empty((batch_size, hidden_size),       dtype=torch.bfloat16, device=x.device)

    def fn():
        # 0. token sorting
        sgl_moe_align_block_size(
            topk_ids, num_experts_local + 1, block_size,
            sorted_ids, expert_ids, num_tokens_post_pad, cumsum_buf, True,
        )
        # 1. gate_up GEMM
        invoke_fused_moe_kernel(
            x, gate_up_weight, None, cache1,
            None, None, None,
            topk_scale, topk_ids,
            sorted_ids, expert_ids, num_tokens_post_pad,
            False,  # mul_routed_weight
            topk,   # top_k
            config,
            compute_type=tl.bfloat16,
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False,
            per_channel_quant=False,
            block_shape=None,
        )
        # 2. SiLU activation: cache1 [max_padded, inter*2] -> cache2 [max_padded, inter]
        silu_and_mul(cache1, cache2)
        # 3. down GEMM – writes weighted results to cache3[batch, topk, hidden]
        invoke_fused_moe_kernel(
            cache2, down_weight, None, cache3,
            None, None, None,
            topk_scale, topk_ids,
            sorted_ids, expert_ids, num_tokens_post_pad,
            True,  # mul_routed_weight (applies routing weights during scatter)
            1,     # top_k=1: each sorted row maps to one (batch, topk_slot) in cache3
            config,
            compute_type=tl.bfloat16,
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False,
            per_channel_quant=False,
            block_shape=None,
        )
        # 4. Sum-reduce over topk slots -> [batch, hidden]
        torch.sum(cache3, dim=1, out=out)

    return bench_cuda_graph(fn, warmup, iters)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main():
    args = parse_args()
    tp_size = args.tp_size
    hidden_size = args.hidden_size
    intermediate_size = args.intermediate_size // tp_size
    num_experts = args.num_experts
    topk = args.topk
    num_experts_local = num_experts

    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    prop = torch.cuda.get_device_properties(0)
    print(f"\nDevice              : {prop.name}")
    print(f"Model config        : hidden={hidden_size}, inter={intermediate_size}, "
          f"experts={num_experts}, topk={topk}, tp={tp_size}")
    print(f"Local experts/GPU   : {num_experts_local}")
    print(f"Weight shapes       : gate_up=[{num_experts_local}, {intermediate_size*2}, {hidden_size}], "
          f"down=[{num_experts_local}, {hidden_size}, {intermediate_size}]")
    print(f"Timing              : warmup={args.warmup}, iters={args.iters} (CUDA graph replay)\n")

    hdr = (
        f"{'batch':>7}  {'tokens':>8}  "
        f"{'hpc(ms)':>10}  {'hpc(TF)':>9}  "
        f"{'sgl(ms)':>10}  {'sgl(TF)':>9}  "
        f"{'speedup':>8}"
    )
    sep = "-" * len(hdr)
    print(hdr)
    print(sep)

    for bs in BATCH_SIZES:
        x, gate_up_weight, down_weight, topk_ids, topk_scale = make_inputs(
            bs, hidden_size, intermediate_size, num_experts_local, topk
        )

        hpc_ms = bench_hpc(
            x, gate_up_weight, down_weight, topk_ids, topk_scale,
            num_experts_local, args.warmup, args.iters
        )
        sgl_ms = bench_sglang(
            x, gate_up_weight, down_weight, topk_ids, topk_scale,
            num_experts_local, args.warmup, args.iters
        )

        hpc_tf = tflops(bs, topk, hidden_size, intermediate_size, hpc_ms)
        sgl_tf = tflops(bs, topk, hidden_size, intermediate_size, sgl_ms)

        print(
            f"{bs:>7}  {bs*topk:>8}  "
            f"{hpc_ms:>10.4f}  {hpc_tf:>9.2f}  "
            f"{sgl_ms:>10.4f}  {sgl_tf:>9.2f}  "
            f"{sgl_ms/hpc_ms:>8.2f}x"
        )

    print(sep)
    print()


if __name__ == "__main__":
    main()

@reed-lau
Copy link
Copy Markdown
Collaborator

Thank you very much for this contribution! Based on your benchmarks, both group GEMM and MoE show strong performance at small batch sizes. For large batch sizes, group GEMM still has some room for improvement.

Before merging, we’d like to address two points (we’ll take a look internally @lhtin @weishengying @VAthree, but any input from you is also very welcome):

  1. For BF16, TensorCore precision is sufficient on its own—no additional CUDA core precision correction is needed, so that part can be removed.

  2. We’d like to optimize the large‑batch scenario for group MoE to ensure that group GEMM performance is no worse than the SGLang Triton implementation.

Thanks again!

@ZelinMa557
Copy link
Copy Markdown
Author

Thanks very much for your feedback! @reed-lau

I agree we need to improve the performance on large batch size. I made some attempts:

  1. increasekTileM to 128 and reduce kStage to 4 on large batch size. (kStage > 4 will exceed the shared memory limit)
  2. simplify the epilogue logic, as you said:

For BF16, TensorCore precision is sufficient on its own—no additional CUDA core precision correction is needed, so that part can be removed.

It shows improvement. In tp scenario, both gate_up_proj and down_proj are 0.99~1.03x compare to sglang's triton kernel. However, when I run the fused moe benchmark, the benchmark script stuck. I failed to fix it.

By the way, I found that simplify the epilogue logic only shows performance gain when kTileM = 128 in down_proj.

Thanks for your feedback again, I would really appreciate it if you could help improve the performance on large batch size!

Signed-off-by: ZelinMa557 <3388706467@qq.com>
Signed-off-by: ZelinMa557 <3388706467@qq.com>
@ZelinMa557
Copy link
Copy Markdown
Author

@reed-lau Hi reed, I updated the code with:

  1. remove additional CUDA core precision correction
  2. better wgmma pipelining: use warpgroup_wait<1> instead of warpgroup_wait<0> in math warp group to allow an inflight batch; issue another warpgroup_wait<0> outside the while loop to ensure correctness.

now for gate_up_proj, bf16 group gemm is almost as fast as sglang when M/group == 1024:

root@mxk-sgl-dev-0:/sgl-workspace/hpc-ops# python tests/bench_group_gemm_bf16.py --E 128 --N 384 --K 4096
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.

Device : NVIDIA H20
Config : E=128, N=384, K=4096
Timing : warmup=5, iters=100 (CUDA graph replay)

 M/group   total_M     hpc(ms)    hpc(TF)     sgl(ms)    sgl(TF)     ref(ms)    ref(TF)   hpc/sgl
-------------------------------------------------------------------------------------------------
       1       128      0.1231       3.27      0.1414       2.85      1.0531       0.38      1.15x
       2       256      0.1243       6.48      0.2092       3.85      1.0527       0.77      1.68x
       4       512      0.1246      12.93      0.2090       7.71      1.0711       1.50      1.68x
       8      1024      0.1271      25.34      0.2100      15.34      1.0845       2.97      1.65x
      16      2048      0.1294      49.78      0.2101      30.67      1.1086       5.81      1.62x
      32      4096      0.1434      89.86      0.2117      60.87      1.2463      10.34      1.48x
      64      8192      0.2052     125.57      0.2321     111.01      1.2813      20.11      1.13x
     128     16384      0.3922     131.41      0.3928     131.22      1.3157      39.17      1.00x
     256     32768      0.7708     133.73      0.7532     136.85      1.7936      57.47      0.98x
     512     65536      1.5160     135.99      1.4679     140.45      2.7625      74.63      0.97x
    1024    131072      2.9454     139.99      2.9215     141.13      4.2998      95.89      0.99x
-------------------------------------------------------------------------------------------------

the down_proj is 5% faster than sglang's triton kernel in tp scenario:

root@mxk-sgl-dev-0:/sgl-workspace/hpc-ops# python tests/bench_group_gemm_bf16.py --E 128 --N 4096 --K 192
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.

Device : NVIDIA H20
Config : E=128, N=4096, K=192
Timing : warmup=5, iters=100 (CUDA graph replay)

 M/group   total_M     hpc(ms)    hpc(TF)     sgl(ms)    sgl(TF)     ref(ms)    ref(TF)   hpc/sgl
-------------------------------------------------------------------------------------------------
       1       128      0.0684       2.94      0.0843       2.39      0.6095       0.33      1.23x
       2       256      0.0702       5.73      0.0969       4.16      0.6144       0.66      1.38x
       4       512      0.0713      11.29      0.0977       8.25      0.6205       1.30      1.37x
       8      1024      0.0730      22.06      0.0976      16.50      0.6327       2.55      1.34x
      16      2048      0.0754      42.74      0.1001      32.17      0.6405       5.03      1.33x
      32      4096      0.0855      75.32      0.1068      60.33      0.6781       9.50      1.25x
      64      8192      0.1114     115.66      0.1210     106.45      0.7466      17.26      1.09x
     128     16384      0.1967     130.99      0.2148     119.98      0.8714      29.57      1.09x
     256     32768      0.3817     135.04      0.4011     128.49      1.1856      43.47      1.05x
     512     65536      0.7482     137.76      0.7839     131.50      1.6880      61.07      1.05x
    1024    131072      1.4831     139.00      1.5548     132.59      2.9124      70.79      1.05x
-------------------------------------------------------------------------------------------------

Before this update, end to end bf16_fused_moe is 5% to 10% slower than sglang at large batch. With this patch, now bf16_fused_moe achieves 0.99x to 1x sglang's performance:

root@mxk-sgl-dev-0:/sgl-workspace/hpc-ops# python tests/bench_fuse_moe_bf16.py --hidden-size 4096 --intermediate-size 1536 --num-experts 128 --topk 8 --tp-size 8
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.

Device              : NVIDIA H20
Model config        : hidden=4096, inter=192, experts=128, topk=8, tp=8
Local experts/GPU   : 128
Weight shapes       : gate_up=[128, 384, 4096], down=[128, 4096, 192]
Timing              : warmup=5, iters=100 (CUDA graph replay)

  batch    tokens     hpc(ms)    hpc(TF)     sgl(ms)    sgl(TF)   speedup
-------------------------------------------------------------------------
    512      4096      0.3174      60.89      0.3590      53.84      1.13x
   1024      8192      0.5309      72.81      0.5255      73.56      0.99x
   1536     12288      0.6776      85.57      0.6723      86.24      0.99x
   2048     16384      0.8568      90.23      0.8343      92.66      0.97x
   2560     20480      1.0109      95.59      1.0000      96.64      0.99x
   3072     24576      1.1860      97.78      1.1578     100.15      0.98x
   3584     28672      1.3393     101.02      1.3276     101.90      0.99x
   4096     32768      1.5198     101.74      1.4898     103.79      0.98x
   4608     36864      1.6757     103.80      1.6659     104.42      0.99x
   5120     40960      1.8748     103.09      1.8209     106.14      0.97x
   5632     45056      2.0092     105.81      1.9874     106.98      0.99x
   6144     49152      2.1553     107.61      2.1349     108.63      0.99x
   6656     53248      2.3411     107.32      2.3080     108.86      0.99x
   7168     57344      2.5167     107.52      2.4826     108.99      0.99x
   7680     61440      2.6752     108.37      2.6478     109.49      0.99x
   8192     65536      2.8218     109.59      2.7965     110.58      0.99x
   8704     69632      3.0020     109.45      2.9554     111.18      0.98x
   9216     73728      3.1309     111.12      3.1102     111.86      0.99x
   9728     77824      3.3192     110.63      3.2870     111.72      0.99x
  10240     81920      3.4835     110.97      3.4512     112.01      0.99x
  10752     86016      3.6400     111.50      3.5953     112.89      0.99x
  11264     90112      3.7951     112.04      3.7788     112.52      1.00x
  11776     94208      3.9410     112.80      3.9227     113.32      1.00x
  12288     98304      4.1342     112.20      4.1211     112.56      1.00x
  12800    102400      4.2906     112.61      4.2480     113.74      0.99x
  13312    106496      4.4589     112.70      4.4150     113.82      0.99x
  13824    110592      4.6114     113.16      4.5729     114.12      0.99x
  14336    114688      4.8008     112.72      4.7524     113.87      0.99x
  14848    118784      4.9320     113.64      4.9100     114.15      1.00x
  15360    122880      5.1384     112.84      5.0693     114.38      0.99x
  15872    126976      5.2819     113.44      5.2354     114.44      0.99x
-------------------------------------------------------------------------

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants