Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions aiter/configs/model_configs/hunyuan3_fp8_per_tensor_tuned_fmoe.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1,block_m,ksplit,us1,kernelName1,err1,us2,kernelName2,err2,us,run_1stage,xbf16,tflops,bw,_tag
80,1,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,15.986867346938894,fused_moe_asmjit_aot__16_True_False_False,6.03%,0.0,,0%,15.986867346938894,1,0,2.36,28335.58,
80,2,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,31.46978571428563,fused_moe_asmjit_aot__16_True_False_False,4.33%,0.0,,0%,31.46978571428563,1,0,2.4,14395.06,
80,4,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,42.93495959595968,fused_moe_asmjit_aot__16_True_False_False,3.25%,0.0,,0%,42.93495959595968,1,0,3.52,10551.63,
80,8,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,60.95569000000001,fused_moe_asmjit_aot__16_True_False_False,3.65%,0.0,,0%,60.95569000000001,1,0,4.95,7432.99,
80,16,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,90.33211111111105,fused_moe_asmjit_aot__16_True_False_False,3.68%,0.0,,0%,90.33211111111105,1,0,6.69,5016.84,
80,32,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,138.59771717171733,fused_moe_asmjit_aot__16_True_False_False,3.49%,0.0,,0%,138.59771717171733,1,0,8.72,3271.18,
80,64,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,223.5444946236561,fused_moe_asmjit_aot__64_True_True_False,2.21%,0.0,,0%,223.5444946236561,1,0,10.81,2029.89,
80,128,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,240.1024555555547,fused_moe_asmjit_aot__64_True_True_False,2.52%,0.0,,0%,240.1024555555547,1,0,20.12,1893.18,
80,256,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,247.60494117647124,fused_moe_asmjit_aot__64_True_True_False,2.73%,0.0,,0%,247.60494117647124,1,0,39.03,1842.17,
80,512,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,284.450091836734,fused_moe_asmjit_aot__128_True_True_False,2.53%,0.0,,0%,284.450091836734,1,0,67.95,1614.61,
80,1024,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,307.1181894736858,fused_moe_asmjit_aot__64_True_True_False,2.35%,0.0,,0%,307.1181894736858,1,0,125.86,1515.92,
80,2048,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,576.2535684210527,fused_moe_asmjit_aot__128_True_True_False,2.04%,0.0,,0%,576.2535684210527,1,0,134.16,829.76,
80,4096,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,1083.558147368421,fused_moe_asmjit_aot__128_True_True_False,2.14%,0.0,,0%,1083.558147368421,1,0,142.7,464.5,
80,8192,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,1696.2595900000003,fused_moe_asmjit_aot__128_True_True_False,2.20%,0.0,,0%,1696.2595900000003,1,0,182.31,326.39,
80,16384,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,3294.693249999998,fused_moe_asmjit_aot__128_True_True_False,2.19%,0.0,,0%,3294.693249999998,1,0,187.72,198.6,
80,32768,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,6371.982838709681,fused_moe_asmjit_aot__128_True_True_False,2.24%,0.0,,0%,6371.982838709681,1,0,194.12,134.28,
80,65536,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,12793.395297872323,fused_moe_asmjit_aot__128_True_True_False,2.28%,0.0,,0%,12793.395297872323,1,0,193.37,98.35,
80,131072,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,25726.181630952433,fused_moe_asmjit_aot__128_True_True_False,2.28%,0.0,,0%,25726.181630952433,1,0,192.33,80.21,
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1
1,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
2,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
4,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
8,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
16,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
32,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
64,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
128,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
256,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
512,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
1024,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
2048,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
4096,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
8192,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
16384,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
32768,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
65536,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
131072,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0
18 changes: 18 additions & 0 deletions aiter/configs/model_configs/qwen3_5_397b_fp8_ptpc_tuned_fmoe.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1,block_m,ksplit,us1,kernelName1,err1,us2,kernelName2,err2,us,run_1stage,xbf16,tflops,bw,_tag
80,1,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,13.677915789473445,fused_moe_asmjit_aot__16_True_False_False,0.42%,0.0,,0%,13.677915789473445,1,0,2.3,58877.29,
80,2,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,35.211979999999954,fused_moe_asmjit_aot__16_True_False_False,0.56%,0.0,,0%,35.211979999999954,1,0,1.79,22870.94,
80,4,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,47.743629999999555,fused_moe_asmjit_aot__16_True_False_False,0.58%,0.0,,0%,47.743629999999555,1,0,2.64,16868.33,
80,8,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,63.788901098900865,fused_moe_asmjit_aot__16_True_False_False,0.70%,0.0,,0%,63.788901098900865,1,0,3.95,12626.09,
80,16,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,92.46276842105215,fused_moe_asmjit_aot__16_True_False_False,0.24%,0.0,,0%,92.46276842105215,1,0,5.44,8711.65,
80,32,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,135.92059999999984,fused_moe_asmjit_aot__16_True_False_False,0.26%,0.0,,0%,135.92059999999984,1,0,7.41,5927.72,
80,128,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,355.59850999999986,fused_moe_asmjit_aot__64_True_True_False,2.37%,0.0,,0%,355.59850999999986,1,0,11.32,2269.07,
80,256,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,396.20228999999995,fused_moe_asmjit_aot__64_True_True_False,2.33%,0.0,,0%,396.20228999999995,1,0,20.33,2040.5,
80,512,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,411.33022448979534,fused_moe_asmjit_aot__64_True_True_False,2.01%,0.0,,0%,411.33022448979534,1,0,39.16,1973.11,
80,1024,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,446.35456701031023,fused_moe_asmjit_aot__64_True_True_False,1.87%,0.0,,0%,446.35456701031023,1,0,72.17,1832.38,
80,2048,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,576.2807849462362,fused_moe_asmjit_aot__64_True_True_False,1.68%,0.0,,0%,576.2807849462362,1,0,111.79,1441.09,
80,4096,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,993.6361145833362,fused_moe_asmjit_aot__128_True_True_False,1.81%,0.0,,0%,993.6361145833362,1,0,129.67,861.12,
80,8192,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,1695.620755102034,fused_moe_asmjit_aot__128_True_True_True,1.82%,0.0,,0%,1695.620755102034,1,0,151.98,534.3,
80,16384,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,3004.3378800000028,fused_moe_asmjit_aot__128_True_True_True,1.82%,0.0,,0%,3004.3378800000028,1,0,171.55,335.06,
80,32768,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,5748.212340425532,fused_moe_asmjit_aot__128_True_True_True,1.83%,0.0,,0%,5748.212340425532,1,0,179.32,210.15,
80,65536,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,11247.049708333347,fused_moe_asmjit_aot__128_True_True_True,1.85%,0.0,,0%,11247.049708333347,1,0,183.3,143.2,
80,131072,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,22489.665710000027,fused_moe_asmjit_aot__128_True_True_True,1.85%,0.0,,0%,22489.665710000027,1,0,183.34,107.42,
19 changes: 19 additions & 0 deletions aiter/configs/model_configs/qwen3_5_397b_fp8_ptpc_untuned_fmoe.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1
1,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
2,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
4,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
8,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
16,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
32,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
64,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
128,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
256,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
512,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
1024,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
2048,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
4096,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
8192,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
16384,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
32768,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
65536,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
131072,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0
31 changes: 31 additions & 0 deletions aiter/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,22 @@ def fused_moe_(
gate_mode,
)

if metadata.stage0 is not None:
return metadata.stage0(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
activation,
quant_type,
w1_scale,
w2_scale,
expert_mask,
num_local_tokens,
moe_sorting_dispatch_policy,
)

block_size_M = metadata.block_m if block_size_M is None else block_size_M
# Ensure block_size_M is int (metadata.block_m from CSV may be float)
if block_size_M is not None:
Expand Down Expand Up @@ -719,6 +735,7 @@ class MOEMetadata:
use_non_temporal_load: bool = True
fuse_quant: str = ""
stage2_has_bias: bool = False
stage0: Callable = None


def _needs_swiglu_bias_support(dtype, quant_type):
Expand Down Expand Up @@ -1131,6 +1148,20 @@ def _lookup_cfg(c2s):
f"[fused_moe] using {'1stage' if run_1stage else '2stage'}{' xbf16' if run_1stage_xbf16 else ''} {'default' if cfg is None else tag} for {keys} "
)

if kernelName1.startswith("fused_moe_asmjit_aot"):
from aiter.fused_moe_asmjit_aot import fused_moe_asmjit_aot

return MOEMetadata(
None,
None,
block_m,
ksplit,
run_1stage,
stage0=functools.partial(
fused_moe_asmjit_aot, config_string=kernelName1.split("__")[1]
),
Comment on lines +1151 to +1162
)

def get_block_m() -> int:
if q_dtype_a == dtypes.fp8:
return 32
Expand Down
Loading
Loading