feat: bf16 fused moe kernel #41
Signed-off-by: ZelinMa557 <3388706467@qq.com>
benchmark script:
Thank you very much for this contribution! Based on your benchmarks, both group GEMM and MoE show strong performance at small batch sizes. For large batch sizes, group GEMM still has some room for improvement. Before merging, we’d like to address two points (we’ll take a look internally @lhtin @weishengying @VAthree, but any input from you is also very welcome):
Thanks again!
Thanks very much for your feedback! @reed-lau I agree that we need to improve performance at large batch sizes. I made some attempts:
It shows improvement. In the tp scenario, both
By the way, I found that simplifying the epilogue logic only shows a performance gain when
Thanks again for your feedback. I would really appreciate it if you could help improve the performance at large batch sizes!
@reed-lau Hi reed, I updated the code with:
Now, for gate_up_proj, the bf16 group gemm is almost as fast as sglang when M/group == 1024:
The down_proj is 5% faster than sglang's triton kernel in the tp scenario:
Before this update, end-to-end bf16_fused_moe was 5% to 10% slower than sglang at large batch sizes. With this patch, bf16_fused_moe now achieves 0.99x to 1x sglang's performance:
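As background for the group GEMM figures in this comment, here is a minimal pure-Python sketch of what a grouped GEMM computes. It is a readability aid only, not the kernel in this PR. It assumes that "M/group" means the number of token rows assigned to each expert group; the thread does not define the term. The names `matmul` and `grouped_gemm` and the nested-list layout are hypothetical and chosen for illustration.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain (M x K) @ (K x N) product on nested lists."""
    k = len(b)
    n = len(b[0])
    assert all(len(row) == k for row in a), "inner dimensions must match"
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)]
            for row in a]


def grouped_gemm(x: Matrix, group_sizes: List[int],
                 weights: List[Matrix]) -> Matrix:
    """Reference grouped GEMM (illustrative assumption, not the PR's code).

    Rows of x are assumed to be sorted by expert: group g owns
    group_sizes[g] consecutive rows and is multiplied by weights[g].
    """
    assert len(group_sizes) == len(weights)
    assert sum(group_sizes) == len(x)
    out: Matrix = []
    start = 0
    for size, w in zip(group_sizes, weights):
        if size:  # an expert may receive no tokens at all
            out.extend(matmul(x[start:start + size], w))
        start += size
    return out


# Three token rows, two experts: the first expert owns two rows, the second one.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # expert 0: identity
    [[2.0, 0.0], [0.0, 2.0]],  # expert 1: scale by 2
]
result = grouped_gemm(x, [2, 1], weights)
print(result)
```

Under this reading, M/group == 1024 corresponds to every group in `group_sizes` holding 1024 rows, which is the large-batch case discussed above.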
Based on #34, I implemented a bf16 fused moe kernel. The API of this kernel follows the style of fuse_moe_pertensor_fp8:
This kernel significantly outperforms SGLang's Triton version when most experts are activated or the average number of tokens per expert is <= 32. It provides a substantial speedup for the decoding and target-verify stages in concurrent inference, addressing the lack of low-latency BF16 MoE decoding kernels in the open-source ecosystem.
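To make the "average tokens per expert <= 32" threshold concrete, the following sketch estimates that average from the batch size. It assumes perfectly uniform routing, which real routers do not guarantee, and the expert count and top-k below are hypothetical example values rather than numbers taken from this PR. The helper name `avg_tokens_per_expert` is also illustrative.

```python
def avg_tokens_per_expert(num_tokens: int, top_k: int, num_experts: int) -> float:
    """Average token rows per expert, assuming uniform routing.

    Each token is sent to top_k experts, so num_tokens * top_k rows are
    spread over num_experts experts.
    """
    if num_experts <= 0:
        raise ValueError("num_experts must be positive")
    return num_tokens * top_k / num_experts


# Hypothetical MoE configuration for illustration only.
NUM_EXPERTS = 128
TOP_K = 8

for batch in (1, 16, 64, 512, 4096):
    avg = avg_tokens_per_expert(batch, TOP_K, NUM_EXPERTS)
    regime = "<= 32" if avg <= 32 else "> 32"
    print(f"tokens={batch:5d}  avg tokens/expert={avg:8.2f}  ({regime})")
```

With these example values, batches of up to 512 tokens stay at or below 32 tokens per expert, which is the regime typical of decoding and target verification with moderate concurrency.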
performance test:
Qwen/Qwen3-235B-A22B tp8:
Qwen/Qwen3.5-122B-A10B tp4: