diff --git a/aiter/configs/model_configs/hunyuan3_fp8_per_tensor_tuned_fmoe.csv b/aiter/configs/model_configs/hunyuan3_fp8_per_tensor_tuned_fmoe.csv new file mode 100644 index 0000000000..917e03da47 --- /dev/null +++ b/aiter/configs/model_configs/hunyuan3_fp8_per_tensor_tuned_fmoe.csv @@ -0,0 +1,19 @@ +cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1,block_m,ksplit,us1,kernelName1,err1,us2,kernelName2,err2,us,run_1stage,xbf16,tflops,bw,_tag +80,1,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,15.986867346938894,fused_moe_asmjit_aot__16_True_False_False,6.03%,0.0,,0%,15.986867346938894,1,0,2.36,28335.58, +80,2,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,31.46978571428563,fused_moe_asmjit_aot__16_True_False_False,4.33%,0.0,,0%,31.46978571428563,1,0,2.4,14395.06, +80,4,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,42.93495959595968,fused_moe_asmjit_aot__16_True_False_False,3.25%,0.0,,0%,42.93495959595968,1,0,3.52,10551.63, +80,8,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,60.95569000000001,fused_moe_asmjit_aot__16_True_False_False,3.65%,0.0,,0%,60.95569000000001,1,0,4.95,7432.99, +80,16,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,90.33211111111105,fused_moe_asmjit_aot__16_True_False_False,3.68%,0.0,,0%,90.33211111111105,1,0,6.69,5016.84, +80,32,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,138.59771717171733,fused_moe_asmjit_aot__16_True_False_False,3.49%,0.0,,0%,138.59771717171733,1,0,8.72,3271.18, +80,64,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,223.5444946236561,fused_moe_asmjit_aot__64_True_True_False,2.21%,0.0,,0%,223.5444946236561,1,0,10.81,2029.89, +80,128,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,240.1024555555547,fused_moe_asmjit_aot__64_True_True_False,2.52%,0.0,,0%,240.1024555555547,1,0,20.12,1893.18, +80,256,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,247.60494117647124,fused_moe_asmjit_aot__64_True_True_False,2.73%,0.0,,0%,247.60494117647124,1,0,39.03,1842.17, +80,512,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,284.450091836734,fused_moe_asmjit_aot__128_True_True_False,2.53%,0.0,,0%,284.450091836734,1,0,67.95,1614.61, +80,1024,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,307.1181894736858,fused_moe_asmjit_aot__64_True_True_False,2.35%,0.0,,0%,307.1181894736858,1,0,125.86,1515.92, +80,2048,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,576.2535684210527,fused_moe_asmjit_aot__128_True_True_False,2.04%,0.0,,0%,576.2535684210527,1,0,134.16,829.76, +80,4096,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,1083.558147368421,fused_moe_asmjit_aot__128_True_True_False,2.14%,0.0,,0%,1083.558147368421,1,0,142.7,464.5, +80,8192,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,1696.2595900000003,fused_moe_asmjit_aot__128_True_True_False,2.20%,0.0,,0%,1696.2595900000003,1,0,182.31,326.39, +80,16384,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,3294.693249999998,fused_moe_asmjit_aot__128_True_True_False,2.19%,0.0,,0%,3294.693249999998,1,0,187.72,198.6, +80,32768,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,6371.982838709681,fused_moe_asmjit_aot__128_True_True_False,2.24%,0.0,,0%,6371.982838709681,1,0,194.12,134.28, +80,65536,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,12793.395297872323,fused_moe_asmjit_aot__128_True_True_False,2.28%,0.0,,0%,12793.395297872323,1,0,193.37,98.35, +80,131072,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,16,0,25726.181630952433,fused_moe_asmjit_aot__128_True_True_False,2.28%,0.0,,0%,25726.181630952433,1,0,192.33,80.21, diff --git a/aiter/configs/model_configs/hunyuan3_fp8_per_tensor_untuned_fmoe.csv b/aiter/configs/model_configs/hunyuan3_fp8_per_tensor_untuned_fmoe.csv new file mode 100644 index 0000000000..d1d589dc79 --- /dev/null +++ b/aiter/configs/model_configs/hunyuan3_fp8_per_tensor_untuned_fmoe.csv @@ -0,0 +1,19 @@ +token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1 +1,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +2,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +4,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +8,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +16,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +32,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +64,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +128,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +256,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +512,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +1024,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +2048,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +4096,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +8192,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +16384,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +32768,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +65536,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 +131072,4096,192,192,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0 diff --git a/aiter/configs/model_configs/qwen3_5_397b_fp8_ptpc_tuned_fmoe.csv b/aiter/configs/model_configs/qwen3_5_397b_fp8_ptpc_tuned_fmoe.csv new file mode 100644 index 0000000000..4635279cbe --- /dev/null +++ b/aiter/configs/model_configs/qwen3_5_397b_fp8_ptpc_tuned_fmoe.csv @@ -0,0 +1,18 @@ +cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1,block_m,ksplit,us1,kernelName1,err1,us2,kernelName2,err2,us,run_1stage,xbf16,tflops,bw,_tag +80,1,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,13.677915789473445,fused_moe_asmjit_aot__16_True_False_False,0.42%,0.0,,0%,13.677915789473445,1,0,2.3,58877.29, +80,2,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,35.211979999999954,fused_moe_asmjit_aot__16_True_False_False,0.56%,0.0,,0%,35.211979999999954,1,0,1.79,22870.94, +80,4,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,47.743629999999555,fused_moe_asmjit_aot__16_True_False_False,0.58%,0.0,,0%,47.743629999999555,1,0,2.64,16868.33, +80,8,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,63.788901098900865,fused_moe_asmjit_aot__16_True_False_False,0.70%,0.0,,0%,63.788901098900865,1,0,3.95,12626.09, +80,16,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,92.46276842105215,fused_moe_asmjit_aot__16_True_False_False,0.24%,0.0,,0%,92.46276842105215,1,0,5.44,8711.65, +80,32,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,135.92059999999984,fused_moe_asmjit_aot__16_True_False_False,0.26%,0.0,,0%,135.92059999999984,1,0,7.41,5927.72, +80,128,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,355.59850999999986,fused_moe_asmjit_aot__64_True_True_False,2.37%,0.0,,0%,355.59850999999986,1,0,11.32,2269.07, +80,256,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,396.20228999999995,fused_moe_asmjit_aot__64_True_True_False,2.33%,0.0,,0%,396.20228999999995,1,0,20.33,2040.5, +80,512,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,411.33022448979534,fused_moe_asmjit_aot__64_True_True_False,2.01%,0.0,,0%,411.33022448979534,1,0,39.16,1973.11, +80,1024,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,446.35456701031023,fused_moe_asmjit_aot__64_True_True_False,1.87%,0.0,,0%,446.35456701031023,1,0,72.17,1832.38, +80,2048,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,576.2807849462362,fused_moe_asmjit_aot__64_True_True_False,1.68%,0.0,,0%,576.2807849462362,1,0,111.79,1441.09, +80,4096,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,993.6361145833362,fused_moe_asmjit_aot__128_True_True_False,1.81%,0.0,,0%,993.6361145833362,1,0,129.67,861.12, +80,8192,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,1695.620755102034,fused_moe_asmjit_aot__128_True_True_True,1.82%,0.0,,0%,1695.620755102034,1,0,151.98,534.3, +80,16384,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,3004.3378800000028,fused_moe_asmjit_aot__128_True_True_True,1.82%,0.0,,0%,3004.3378800000028,1,0,171.55,335.06, +80,32768,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,5748.212340425532,fused_moe_asmjit_aot__128_True_True_True,1.83%,0.0,,0%,5748.212340425532,1,0,179.32,210.15, +80,65536,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,11247.049708333347,fused_moe_asmjit_aot__128_True_True_True,1.85%,0.0,,0%,11247.049708333347,1,0,183.3,143.2, +80,131072,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,16,0,22489.665710000027,fused_moe_asmjit_aot__128_True_True_True,1.85%,0.0,,0%,22489.665710000027,1,0,183.34,107.42, diff --git a/aiter/configs/model_configs/qwen3_5_397b_fp8_ptpc_untuned_fmoe.csv b/aiter/configs/model_configs/qwen3_5_397b_fp8_ptpc_untuned_fmoe.csv new file mode 100644 index 0000000000..76a4be4677 --- /dev/null +++ b/aiter/configs/model_configs/qwen3_5_397b_fp8_ptpc_untuned_fmoe.csv @@ -0,0 +1,19 @@ +token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1 +1,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +2,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +4,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +8,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +16,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +32,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +64,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +128,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +256,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +512,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +1024,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +2048,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +4096,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +8192,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +16384,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +32768,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +65536,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 +131072,4096,128,512,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0 diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 131a92d3c7..e7cc14dd30 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -346,6 +346,22 @@ def fused_moe_( gate_mode, ) + if metadata.stage0 is not None: + return metadata.stage0( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + activation, + quant_type, + w1_scale, + w2_scale, + expert_mask, + num_local_tokens, + moe_sorting_dispatch_policy, + ) + block_size_M = metadata.block_m if block_size_M is None else block_size_M # Ensure block_size_M is int (metadata.block_m from CSV may be float) if block_size_M is not None: @@ -719,6 +735,7 @@ class MOEMetadata: use_non_temporal_load: bool = True fuse_quant: str = "" stage2_has_bias: bool = False + stage0: Callable = None def _needs_swiglu_bias_support(dtype, quant_type): @@ -1131,6 +1148,20 @@ def _lookup_cfg(c2s): f"[fused_moe] using {'1stage' if run_1stage else '2stage'}{' xbf16' if run_1stage_xbf16 else ''} {'default' if cfg is None else tag} for {keys} " ) + if kernelName1.startswith("fused_moe_asmjit_aot"): + from aiter.fused_moe_asmjit_aot import fused_moe_asmjit_aot + + return MOEMetadata( + None, + None, + block_m, + ksplit, + run_1stage, + stage0=functools.partial( + fused_moe_asmjit_aot, config_string=kernelName1.split("__")[1] + ), + ) + def get_block_m() -> int: if q_dtype_a == dtypes.fp8: return 32 diff --git a/aiter/fused_moe_asmjit_aot.py b/aiter/fused_moe_asmjit_aot.py new file mode 100644 index 0000000000..6ec3f73bf3 --- /dev/null +++ b/aiter/fused_moe_asmjit_aot.py @@ -0,0 +1,360 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +from typing import Any, Optional + +import torch + +import aiter +from aiter import ActivationType, QuantType +from aiter.jit.utils.chip_info import get_gfx +from aiter.fused_moe import moe_sorting +from csrc.cpp_itfs.hsaco_tools import hsaco + +from dataclasses import dataclass + + +@dataclass +class Config: + BLOCK_M: int + use_down_loopn: bool + use_prefill: bool + use_dyn_sched: bool + + def to_string(self): + return ( + str(self.BLOCK_M) + + "_" + + str(self.use_down_loopn) + + "_" + + str(self.use_prefill) + + "_" + + str(self.use_dyn_sched) + ) + + @classmethod + def from_string(cls, data: str): + parts = data.split("_") + return cls(*[eval(p) for p in parts]) + + +def get_tune_space(): + return [ + Config(16, True, False, False).to_string(), + Config(64, True, True, False).to_string(), + Config(128, True, True, False).to_string(), + Config(128, True, True, True).to_string(), + ] + + +def fused_moe_asmjit_aot( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + activation: ActivationType, + quant_type: QuantType, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + expert_mask: Any, + num_local_tokens: Any, + moe_sorting_dispatch_policy: int, + config_string: str, +) -> Optional[torch.Tensor]: + + # decode kernel configs from kernel name + kcfgs = Config.from_string(config_string) + + B = int(hidden_states.shape[0]) + if ( + hidden_states.dtype != torch.bfloat16 + or expert_mask is not None + or activation != ActivationType.Silu + or w1.dtype != torch.float8_e4m3fnuz + or w2.dtype != torch.float8_e4m3fnuz + ): + raise Exception("Unsupported input") + if get_gfx() != "gfx942": + raise Exception("Unsupported platform") + + if quant_type != QuantType.per_Token and quant_type != QuantType.per_Tensor: + raise Exception(f"Unsupported quant_type:{quant_type}") + + qtype_str = str(quant_type).split(".")[1] + + E, N1, K1 = w1.shape + N2, K2 = w2.shape[1], w2.shape[2] + TOPK = topk_ids.shape[1] + fp8_ptpc = w1.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz) and ( + quant_type == QuantType.per_Token + ) + num_CU = torch.cuda.get_device_properties( + hidden_states.device + ).multi_processor_count + assert N1 == 2 * K2 + + topk_w_f32 = ( + topk_weight if topk_weight.dtype == torch.float32 else topk_weight.float() + ) + + gemm1_out = torch.empty( + [B, TOPK, N1 // 2], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + if kcfgs.use_prefill: + sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, cur_out = ( + moe_sorting( + topk_ids, + topk_weight, + E, + N2, # reduce dim is same with output dim + hidden_states.dtype, + kcfgs.BLOCK_M, + None, + None, + 0, + ) + ) + quant_func = aiter.get_hip_quant(aiter.QuantType.per_Token) + hidden_states_q, hidden_states_scale = quant_func( + hidden_states, + scale=None, + quant_dtype=w1.dtype, + num_rows=None, + ) + if kcfgs.use_dyn_sched: + dyn_buf1 = torch.zeros(64, dtype=torch.int32, device=hidden_states_q.device) + dyn_buf2 = torch.zeros(64, dtype=torch.int32, device=hidden_states_q.device) + grid_gate_up = torch.cuda.get_device_properties().multi_processor_count + grid_down = torch.cuda.get_device_properties().multi_processor_count * 2 # occupancy is 2 + GATEUP_BLOCK_TILE_SIZE_N = 256 + DOWN_BLOCK_TILE_SIZE_N = 128 + else: + GATEUP_BLOCK_TILE_SIZE_N = 128 + DOWN_BLOCK_TILE_SIZE_N = 128 + dyn_buf1 = None + dyn_buf2 = None + grid_gate_up = N1 // GATEUP_BLOCK_TILE_SIZE_N * sorted_expert_ids.shape[0] + grid_down = sorted_expert_ids.shape[0] + + hsaco.fmoe_asmjit.moe_2stage_gateup( + [grid_gate_up], + [256], + dyn_buf1, + hidden_states_q, + w1, + gemm1_out, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + hidden_states_scale, + w1_scale, + B, + N1 // GATEUP_BLOCK_TILE_SIZE_N * sorted_expert_ids.shape[0], + weight_dtype=str(w1.dtype), + TOPK=TOPK, + K=K1, + N=N1, + BLOCK_TILE_SIZE_M=kcfgs.BLOCK_M, + BLOCK_TILE_SIZE_N=GATEUP_BLOCK_TILE_SIZE_N, + quant_type_w=f"QuantType.{qtype_str}", + dyn=kcfgs.use_dyn_sched, + ) + gemm1_out_q, gemm1_out_scale = quant_func( + gemm1_out.view(B * TOPK, -1), + scale=None, + quant_dtype=w2.dtype, + num_rows=None, + ) + gemm2_out = torch.empty( + B, TOPK, N2, dtype=torch.bfloat16, device=gemm1_out_q.device + ) + hsaco.fmoe_asmjit.moe_2stage_down( + [grid_down], + [256], + dyn_buf2, + gemm1_out_q, + w2, + gemm2_out, # cur_out, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + gemm1_out_scale, + w2_scale, + B, + sorted_expert_ids.shape[0], + weight_dtype=str(w2.dtype), + TOPK=TOPK, + K=K2, + N=N2, + with_silu=False, + BLOCK_TILE_SIZE_M=kcfgs.BLOCK_M, + BLOCK_TILE_SIZE_N=DOWN_BLOCK_TILE_SIZE_N, + quant_type_w=f"QuantType.{qtype_str}", + dyn=kcfgs.use_dyn_sched, + ) + num_WG = num_CU * 4 + num_tokens_wg = B // num_WG + num_extra_tokens = B % num_WG + hsaco.fmoe_asmjit.moe_gemm_final_reduce_bf16( + [num_WG], + [64], + gemm2_out, + cur_out, + num_tokens_wg, + num_extra_tokens, + B, + TOPK=TOPK, + OC=N2, + ) + return cur_out + + if B == 1: + assert N1 == 2 * K2 + cur_out = torch.zeros( + [1, N2], dtype=hidden_states.dtype, device=hidden_states.device + ) + hsaco.fmoe_asmjit.moe_gemm_batch1( + [N1 // 32, TOPK], + [256], + hidden_states, + w1, + gemm1_out, + topk_ids, + topk_w_f32, + w1_scale, + 1, + N1, + K1, + weight_dtype=torch.float8_e4m3fnuz, + with_silu=True, + quant_type_str=qtype_str, + ) + hsaco.fmoe_asmjit.moe_gemm_batch1( + [N2 // 32, TOPK], + [64], + gemm1_out, + w2, + cur_out, + topk_ids, + topk_w_f32, + w2_scale, + 1, + N2, + K2, + weight_dtype=torch.float8_e4m3fnuz, + with_silu=False, + quant_type_str=qtype_str, + ) + elif 2 <= B <= 32: + # Stage 1: Shared ``moe_sorting`` + ``moe_gemm_batch``; + # stage 2: Choose between ``moe_2stage_down_loopn`` and ``moe_2stage_splitk`` based on ``use_down_loopn`` condition. + BLOCK_M = kcfgs.BLOCK_M + sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, cur_out = ( + moe_sorting( + topk_ids, + topk_weight, + E, + K1, + hidden_states.dtype, + BLOCK_M, + expert_mask, + num_local_tokens, + moe_sorting_dispatch_policy, + ) + ) + grid = int(sorted_expert_ids.shape[0]) + if B * TOPK <= E: + grid = B * TOPK + + hsaco.fmoe_asmjit.moe_gemm_batch( + [N1 // 32, grid], + [256], + hidden_states, + w1, + gemm1_out, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + w1_scale, + B, + N1, + K1, + TOPK, + weight_dtype=torch.float8_e4m3fnuz, + with_silu=True, + quant_type_str=qtype_str, + ) + + BLOCK_N = 1024 + if kcfgs.use_down_loopn: + # extra checks + use_down_loopn = ( + fp8_ptpc + and (N2 // BLOCK_N) * grid >= num_CU + and N2 % BLOCK_N == 0 + and 16 <= B <= 32 + ) + else: + use_down_loopn = False + + if use_down_loopn: + gemm2_out = torch.empty( + [B, TOPK, N2], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + hsaco.fmoe_asmjit.moe_2stage_down_loopn( + [N2 // BLOCK_N, grid], + [256], + gemm1_out, + w2, + gemm2_out, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + w2_scale, + B, + weight_dtype=torch.float8_e4m3fnuz, + TOPK=TOPK, + K=K2, + N=N2, + BLOCK_TILE_SIZE_M=16, + BLOCK_TILE_SIZE_N=16, + fp8_ptpc=True, + BLOCK_N=BLOCK_N, + atomic_write=False, + STAGES=3, + ) + cur_out = torch.sum(gemm2_out, dim=1) + else: + BLOCK_TILE_SIZE_N = 64 + hsaco.fmoe_asmjit.moe_2stage_splitk( + [N2 // BLOCK_TILE_SIZE_N, grid], + [64], + gemm1_out, + w2, + cur_out, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + w2_scale, + B, + weight_dtype=torch.float8_e4m3fnuz, + TOPK=TOPK, + K=K2, + N=N2, + with_silu=False, + BLOCK_TILE_SIZE_M=16, + BLOCK_TILE_SIZE_N=BLOCK_TILE_SIZE_N, + quant_type_str=qtype_str, + ) + else: + raise Exception(f"Unsupported batch-size {B}") + return cur_out diff --git a/aiter/utility/base_tuner.py b/aiter/utility/base_tuner.py index f887451651..1890ab1bff 100644 --- a/aiter/utility/base_tuner.py +++ b/aiter/utility/base_tuner.py @@ -201,6 +201,12 @@ def _setup_common_arguments(self): "If a tuned CSV path is given, read shapes and kernels from it; " "otherwise read shapes from -i and run with default kernels.", ) + self.parser.add_argument( + "--e2e_tune", + action="store_true", + required=False, + help="Run an extra round of e2e tuning after main tuning is done, using production-op benchmark as the indicator", + ) self.parser.add_argument( "--compare", action="store_true", diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py index f2a40f9c92..382914e39b 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py @@ -23,6 +23,8 @@ cktile_moe_stage1, cktile_moe_stage2, ) +from aiter.fused_moe_asmjit_aot import fused_moe_asmjit_aot +from aiter.fused_moe_asmjit_aot import get_tune_space from aiter import ck_moe_stage1_fwd, ck_moe_stage2_fwd, dtype2str_dict from aiter.ops.shuffle import ( shuffle_weight, @@ -2812,10 +2814,12 @@ def gen_flydsl_i4_2stages_task(self, info, blockMs): return tasks_flydsl - def run_config(self, args): + def run_config(self, args, target_fused_moe=None, try_extra_ref=False): from aiter.fused_moe import fused_moe, fused_topk from aiter.test_common import run_perftest, checkAllclose + if target_fused_moe is None: + target_fused_moe = fused_moe untunedf = self.untunedf results = [] for i in range(len(untunedf)): @@ -2985,7 +2989,7 @@ def run_config(self, args): a1_qt, a1_scale = torch_quant(hidden, quant_dtype=q_dtype_a) out, us = run_perftest( - fused_moe, + target_fused_moe, hidden, w1_qt_fmoe, w2_qt_fmoe, @@ -3019,6 +3023,62 @@ def run_config(self, args): err_ratio = 1.0 else: err_ratio = checkAllclose(out, ref, msg=f"run_config {shape_str}") + if try_extra_ref: + # try compare with extra references (due to different implementations) + try: + # use weight-decompression only algorithm as second reference + w1_deq = w1_qt.to(dtype=hidden.dtype) * w1_scale.view( + w1_scale.shape[0], -1, 1 + ).to(dtype=hidden.dtype) + w2_deq = w2_qt.to(dtype=hidden.dtype) * w2_scale.view( + w2_scale.shape[0], -1, 1 + ).to(dtype=hidden.dtype) + + ref2 = self.torch_moe_2stages( + hidden, + w1_deq, + w2_deq, + topk_weights, + topk_ids, + dtype=dtype, + activation=act_type, + quant_type=QuantType.No, + doweight_stage1=doweight_stage1, + ) + err_ratio2 = checkAllclose( + out, ref2, msg=f"run_config {shape_str}" + ) + err_ratio = min(err_ratio, err_ratio2) + except Exception: + pass + + if q_type == QuantType.per_Tensor: + try: + # inputs are quantized per-Token while weights are quantized per-Tensor + a1_qt, a1_scale = aiter.get_torch_quant( + QuantType.per_Token + )(hidden, quant_dtype=q_dtype_a) + ref2 = self.torch_moe_2stages( + a1_qt, + w1_qt, + w2_qt, + topk_weights, + topk_ids, + a1_scale=a1_scale, + w1_scale=w1_scale, + w2_scale=w2_scale, + dtype=dtype, + activation=act_type, + quant_type=QuantType.per_Token, + doweight_stage1=doweight_stage1, + ) + err_ratio2 = checkAllclose( + out, ref2, msg=f"run_config {shape_str}" + ) + err_ratio = min(err_ratio, err_ratio2) + except Exception: + pass + if err_ratio <= args.errRatio: status = "ok" else: @@ -3029,6 +3089,7 @@ def run_config(self, args): "e2e_us": us, "kernel_us": kernel_us, "status": status, + "err_ratio": err_ratio, } ) except Exception as e: @@ -3038,6 +3099,7 @@ def run_config(self, args): "e2e_us": -1, "kernel_us": kernel_us, "status": f"error:{e}", + "err_ratio": 1, } ) finally: @@ -3666,6 +3728,178 @@ def pre_process(self, args): ) self.untunedf = self.untunedf[~mask] + def e2e_tune(self, args): + """ + choosing best kernels based on (stage1_us + stage2_us) or (single_stage_us) + may overlook some overheads between stages, and this e2e tune is a complement. + """ + results_base = self.run_config(args, target_fused_moe=None, try_extra_ref=True) + better_kernels = {} + cu_num = self.get_cu_num() + + for i in range(len(self.untunedf)): + e2e_us = results_base[i]["e2e_us"] + err_ratio = results_base[i]["err_ratio"] + row = self.untunedf.iloc[i] + cu_num = int(row["cu_num"]) + token = int(row["token"]) + model_dim = int(row["model_dim"]) + inter_dim = int(row["inter_dim"]) + expert = int(row["expert"]) + topk = int(row["topk"]) + act_type = eval(row["act_type"]) + dtype = eval(row["dtype"]) + q_dtype_a = eval(row["q_dtype_a"]) + q_dtype_w = eval(row["q_dtype_w"]) + q_type = eval(row["q_type"]) + q_type = QuantType.per_1x128 if q_type == QuantType.per_128x128 else q_type + use_g1u1 = bool(row["use_g1u1"]) + doweight_stage1 = bool(row["doweight_stage1"]) + key = ( + cu_num, + token, + model_dim, + inter_dim, + expert, + topk, + act_type, + dtype, + q_dtype_a, + q_dtype_w, + q_type, + use_g1u1, + doweight_stage1, + ) + keyname = " ".join(map(str, row[self.keys].values)) + better_kernels[i] = { + "name": keyname, + "key": key, + "row": row, + "kernel_name": None, + "e2e_us": e2e_us, + "err_ratio": err_ratio, + "e2e_us_base": e2e_us, + "err_ratio_base": err_ratio, + } + print(keyname, e2e_us, err_ratio) + + from functools import partial + + def target_fused_moe( + hidden_states, + w1, # [expert(local_expert:EP), inter_dim*2, dim] N,K + w2, # [expert(local_expert:EP), dim, inter_dim] + topk_weight, + topk_ids, + expert_mask=None, + activation=ActivationType.Silu, + quant_type=QuantType.No, + doweight_stage1=False, + w1_scale=None, + w2_scale=None, + num_local_tokens=None, + moe_sorting_dispatch_policy=0, + dtype=None, + config_string="", + ): + return fused_moe_asmjit_aot( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + activation, + quant_type, + w1_scale, + w2_scale, + expert_mask, + num_local_tokens, + moe_sorting_dispatch_policy, + config_string=config_string, + ) + + GREEN = "\033[0;32m" + YELLOW = "\033[1;33m" + RED = "\033[0;31m" + END = "\033[0m" + for config_string in get_tune_space(): + results_cur = self.run_config( + args, + target_fused_moe=partial(target_fused_moe, config_string=config_string), + try_extra_ref=True, + ) + block_m = 16 + ksplit = 0 + run_1stage = 1 + err1 = "0%" + err2 = "0%" + kernelName1 = "fused_moe_asmjit_aot__" + config_string + kernelName2 = "" + xbf16 = 0 + for i in range(len(self.untunedf)): + k = better_kernels[i] + e2e_us = results_cur[i]["e2e_us"] + status = results_cur[i]["status"] + err_ratio = results_cur[i]["err_ratio"] + # skip invalid kernel + if e2e_us < 0 or status != "ok": + print( + f"{k['name']} {RED} {e2e_us=:.3f} {status=} {END} {kernelName1}" + ) + continue + row = self.untunedf.iloc[i] + print( + f"{k['name']} {YELLOW} {float(k['e2e_us_base']):.3f}us -> {float(e2e_us):.3f}us (err: {k['err_ratio']*100:.0f}%) {END} {kernelName1}" + ) + if e2e_us < k["e2e_us"]: + k["e2e_us"] = e2e_us + k["err_ratio"] = err_ratio + k["kernel_name"] = kernelName1 + tflops, bw = self.calculate( + (k["key"], "stage1", kernelName1, block_m, e2e_us, err1) + ) + k["results"] = ( + block_m, + ksplit, + e2e_us, + kernelName1, + f"{err_ratio*100:.2f}%", + 0.0, + kernelName2, + err2, + e2e_us, + run_1stage, + xbf16, + tflops, + bw, + ) + + tune_results = [] + + for i, k in better_kernels.items(): + if k["kernel_name"] is None: + continue + tune_results.append([*k["row"].values, *k["results"]]) + print( + f"{k['name']} {GREEN} {float(k['e2e_us_base']):.3f}us -> {float(k['e2e_us']):.3f}us (err: {k['err_ratio_base']*100:.0f}% -> {k['err_ratio']*100:.0f}%) {END} {k['kernel_name']}" + ) + + new_tunedf = pd.DataFrame(tune_results, columns=self.columns) + output_file = self.get_out_file(args.tune_file) + old_tunedf = self.get_tuned_gemm_list(output_file) + + if "_tag" == old_tunedf.columns[-1]: + new_tunedf["_tag"] = "" + self.columns.append("_tag") + + resultdf = self.update_tunedf(old_tunedf, new_tunedf) + + if "_tag" == old_tunedf.columns[-1]: + self.columns.pop(-1) + + resultdf.to_csv(output_file, index=False) + print(f"{args.tune_file} has been updated!") + if __name__ == "__main__": key = [ @@ -3701,4 +3935,8 @@ def pre_process(self, args): tuner = FmoeTuner("fmoeTuner", key, resultList, "fmoe tuner") args = tuner.parse_args() - tuner.run(args, False) + if args.e2e_tune: + tuner.pre_process(args) + tuner.e2e_tune(args) + else: + tuner.run(args, False) diff --git a/csrc/cpp_itfs/hsaco_tools.py b/csrc/cpp_itfs/hsaco_tools.py new file mode 100644 index 0000000000..6435e2b1b1 --- /dev/null +++ b/csrc/cpp_itfs/hsaco_tools.py @@ -0,0 +1,261 @@ +import ctypes +from ctypes.util import find_library +import functools +import torch +import os +import subprocess + +from aiter.jit.utils.chip_info import get_gfx +from csrc.cpp_itfs.utils import AITER_CORE_DIR + +_is_hip_library_api_supported_ = False + + +@functools.cache +def get_amdhip(): + global _is_hip_library_api_supported_ + + try: + lib = ctypes.CDLL(find_library("amdhip64")) + except Exception as e: + print(e) + torch_amdhip64 = os.path.join(torch.__path__[0], "lib", "libamdhip64.so") + print(f"Try {torch_amdhip64} instead...") + lib = ctypes.CDLL(torch_amdhip64) + lib.hipModuleLoad.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_char_p] + lib.hipModuleLoad.restype = ctypes.c_int32 + lib.hipModuleGetFunction.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + ctypes.c_void_p, + ctypes.c_char_p, + ] + lib.hipModuleGetFunction.restype = ctypes.c_int32 + lib.hipModuleLaunchKernel.argtypes = [ + ctypes.c_void_p, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, # unsigned int sharedMemBytes + ctypes.c_void_p, # hipStream_t stream + ctypes.c_void_p, # void **kernelParams + ctypes.c_void_p, # void **extra + ] + lib.hipModuleLaunchKernel.restype = ctypes.c_int32 + lib.hipGetErrorString.argtypes = [ctypes.c_int32] + lib.hipGetErrorString.restype = ctypes.c_char_p + + try: + lib.hipLibraryLoadFromFile.restype = ctypes.c_int32 + lib.hipLibraryLoadFromFile.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + ctypes.c_char_p, + ctypes.c_void_p, # hipJitOption *jitOptions + ctypes.c_void_p, # void **jitOptionsValues + ctypes.c_uint32, # unsigned int numJitOptions, + ctypes.c_void_p, # hipLibraryOption *libraryOptions + ctypes.c_void_p, # void **libraryOptionValues + ctypes.c_uint32, # unsigned int numLibraryOptions + ] + + lib.hipLibraryGetKernelCount.restype = ctypes.c_int32 + lib.hipLibraryGetKernelCount.argtypes = [ + ctypes.POINTER(ctypes.c_uint32), # unsigned int *count, + ctypes.c_void_p, # hipLibrary_t library + ] + + lib.hipLibraryEnumerateKernels.restype = ctypes.c_int32 + lib.hipLibraryEnumerateKernels.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), # hipKernel_t *kernels + ctypes.c_uint32, # unsigned int numKernels, + ctypes.c_void_p, # hipLibrary_t library + ] + + lib.hipKernelGetName.restype = ctypes.c_int32 + lib.hipKernelGetName.argtypes = [ + ctypes.POINTER(ctypes.c_char_p), # const char **name + ctypes.c_void_p, # hipKernel_t kernel + ] + _is_hip_library_api_supported_ = True + except Exception: + _is_hip_library_api_supported_ = False + + return lib + + +def hip_check_error(err, *args): + if err != 0: + raise Exception( + "HIP error:" + + get_amdhip().hipGetErrorString(err).decode("utf-8") + + repr(args) + ) + + +@functools.cache +def get_lib(lib_fpath): + hip = get_amdhip() + p_lib = ctypes.c_void_p() + hip_check_error( + ( + hip.hipLibraryLoadFromFile( + ctypes.byref(p_lib), + lib_fpath.encode("utf-8"), + None, + None, + 0, + None, + None, + 0, + ) + if _is_hip_library_api_supported_ + else hip.hipModuleLoad(ctypes.byref(p_lib), lib_fpath.encode("utf-8")) + ), + lib_fpath, + ) + return p_lib + + +@functools.cache +def get_all_kernel_names(co_path): + # we need both demangle & symbol name for loading & argtype parsing + dynamic_syms_raw = subprocess.check_output( + ["/opt/rocm/llvm/bin/llvm-objdump", "--dynamic-syms", co_path] + ).decode("utf-8") + kernel_names = [] + for line_raw in dynamic_syms_raw.splitlines(): + ls = line_raw.split() + if len(ls) < 7: + continue + if ls[3] != ".text": + continue + symbol_name = line_raw.split()[6] + kernel_names.append(symbol_name) + return kernel_names + + +@functools.cache +def get_kernel(kernel_path_prefix, constexpr_args: tuple = ()): + """ + constexpr_args is compile-time args which are part of co-file name + """ + hip = get_amdhip() + + co_suffix = "" + for k, v in constexpr_args: + co_suffix += f"-{k}={v}" + co_suffix += ".co" + + if ":" in kernel_path_prefix: + # file contain many kernels, filename is not started with kernel name + kernel_path_base, kernel_name = kernel_path_prefix.split(":") + lib_fpath = kernel_path_base + co_suffix + else: + # file contain only one kernel, filename starts with kernel name + _, kernel_name = os.path.split(kernel_path_prefix) + lib_fpath = kernel_path_prefix + co_suffix + + p_lib = get_lib(lib_fpath) + + if _is_hip_library_api_supported_: + kernel_cnt = ctypes.c_uint32() + hip_check_error(hip.hipLibraryGetKernelCount(ctypes.byref(kernel_cnt), p_lib)) + + assert kernel_cnt.value > 0 + kernels = (ctypes.c_void_p * kernel_cnt.value)() + + hip_check_error(hip.hipLibraryEnumerateKernels(kernels, kernel_cnt, p_lib)) + + p_func = None + for k in kernels: + p_name = ctypes.c_char_p() + hip_check_error(hip.hipKernelGetName(ctypes.byref(p_name), k)) + assert p_name.value is not None + cur_kernel_name = p_name.value.decode("utf-8") + if kernel_name in cur_kernel_name: + p_func = k + break + else: + p_func = None + for cur_kernel_name in get_all_kernel_names(lib_fpath): + if kernel_name in cur_kernel_name: + p_func = ctypes.c_void_p() + hip_check_error( + hip.hipModuleGetFunction( + ctypes.byref(p_func), p_lib, cur_kernel_name.encode("utf-8") + ) + ) + break + + assert p_func is not None, f"kernel {kernel_name} is not found in {lib_fpath}" + + def CallableKernel( + gridDims: list[int], + blockDims: list[int], + *args, + sharedMemBytes=0, + ): + fields = [] + for i, arg in enumerate(args): + if arg is None or isinstance(arg, torch.Tensor): + fields.append((f"arg_{i}", ctypes.c_void_p)) + elif isinstance(arg, int): + # ctypes.c_uint/ctypes.c_ulong + fields.append((f"arg_{i}", ctypes.c_int)) + elif isinstance(arg, float): + fields.append((f"arg_{i}", ctypes.c_float)) + else: + raise Exception(f"Unsupported arg type: {arg}") + + class Args(ctypes.Structure): + _fields_ = fields + + kernel_args = Args() + for i, a in enumerate(args): + setattr( + kernel_args, + f"arg_{i}", + a.data_ptr() if isinstance(a, torch.Tensor) else a, + ) + ExtraType = ctypes.c_void_p * 5 + kernel_args_size = ctypes.c_uint64(ctypes.sizeof(kernel_args)) + kernel_config = ExtraType( + 1, ctypes.addressof(kernel_args), 2, ctypes.addressof(kernel_args_size), 3 + ) + stream = ctypes.cast(torch.cuda.current_stream(), ctypes.c_void_p) + while len(gridDims) < 3: + gridDims.append(1) + while len(blockDims) < 3: + blockDims.append(1) + hip_check_error( + hip.hipModuleLaunchKernel( + p_func, + *gridDims, + *blockDims, + sharedMemBytes, + stream, + 0, + ctypes.byref(kernel_config), + ) + ) + + return CallableKernel + + +class HSACO: + def __init__(self, base=None): + self.base = f"{AITER_CORE_DIR}/hsa/{get_gfx()}" if base is None else base + + def __getattr__(self, name): + return HSACO(f"{self.base}/{name}") + + def __call__(self, *args, **kwargs): + # kwargs is hsaco file name + # args is runtime-args + kernel = get_kernel(self.base, tuple(kwargs.items())) + kernel(*args) + + +hsaco = HSACO() diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=False.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=False.co new file mode 100755 index 0000000000..f625780018 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=False.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=True.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=True.co new file mode 100755 index 0000000000..47c06f5efd Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=True.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=False.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=False.co new file mode 100755 index 0000000000..a7c1acb6f3 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=False.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=True.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=True.co new file mode 100755 index 0000000000..96c13a694e Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=True.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=192-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Tensor-dyn=False.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=192-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Tensor-dyn=False.co new file mode 100755 index 0000000000..41a2c37d9a Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=192-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Tensor-dyn=False.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=192-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Tensor-dyn=False.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=192-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Tensor-dyn=False.co new file mode 100755 index 0000000000..2f4f425764 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_down-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=192-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Tensor-dyn=False.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_down_loopn-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-BLOCK_TILE_SIZE_M=16-BLOCK_TILE_SIZE_N=16-fp8_ptpc=True-BLOCK_N=1024-atomic_write=False-STAGES=3.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_down_loopn-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-BLOCK_TILE_SIZE_M=16-BLOCK_TILE_SIZE_N=16-fp8_ptpc=True-BLOCK_N=1024-atomic_write=False-STAGES=3.co new file mode 100755 index 0000000000..b9240e5705 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_down_loopn-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-BLOCK_TILE_SIZE_M=16-BLOCK_TILE_SIZE_N=16-fp8_ptpc=True-BLOCK_N=1024-atomic_write=False-STAGES=3.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=4096-N=256-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=False.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=4096-N=256-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=False.co new file mode 100755 index 0000000000..9790f6f377 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=4096-N=256-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=False.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=4096-N=256-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=256-quant_type_w=QuantType.per_Token-dyn=True.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=4096-N=256-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=256-quant_type_w=QuantType.per_Token-dyn=True.co new file mode 100755 index 0000000000..bb3079a022 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=4096-N=256-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=256-quant_type_w=QuantType.per_Token-dyn=True.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=4096-N=256-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=False.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=4096-N=256-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=False.co new file mode 100755 index 0000000000..c9ec44d3d3 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=4096-N=256-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Token-dyn=False.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=4096-N=256-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=256-quant_type_w=QuantType.per_Token-dyn=True.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=4096-N=256-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=256-quant_type_w=QuantType.per_Token-dyn=True.co new file mode 100755 index 0000000000..8fa7076ea8 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=4096-N=256-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=256-quant_type_w=QuantType.per_Token-dyn=True.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=4096-N=384-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Tensor-dyn=False.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=4096-N=384-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Tensor-dyn=False.co new file mode 100755 index 0000000000..9954e2437a Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=4096-N=384-BLOCK_TILE_SIZE_M=128-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Tensor-dyn=False.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=4096-N=384-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Tensor-dyn=False.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=4096-N=384-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Tensor-dyn=False.co new file mode 100755 index 0000000000..e84334bca7 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_gateup-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=4096-N=384-BLOCK_TILE_SIZE_M=64-BLOCK_TILE_SIZE_N=128-quant_type_w=QuantType.per_Tensor-dyn=False.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_splitk-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=16-BLOCK_TILE_SIZE_N=64-quant_type_str=per_Token.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_splitk-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=16-BLOCK_TILE_SIZE_N=64-quant_type_str=per_Token.co new file mode 100755 index 0000000000..e33e69a3e0 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_splitk-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=16-BLOCK_TILE_SIZE_N=64-quant_type_str=per_Token.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_2stage_splitk-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=192-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=16-BLOCK_TILE_SIZE_N=64-quant_type_str=per_Tensor.co b/hsa/gfx942/fmoe_asmjit/moe_2stage_splitk-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=192-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=16-BLOCK_TILE_SIZE_N=64-quant_type_str=per_Tensor.co new file mode 100755 index 0000000000..eaf944d1fb Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_2stage_splitk-weight_dtype=torch.float8_e4m3fnuz-TOPK=8-K=192-N=4096-with_silu=False-BLOCK_TILE_SIZE_M=16-BLOCK_TILE_SIZE_N=64-quant_type_str=per_Tensor.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_gemm_batch-weight_dtype=torch.float8_e4m3fnuz-with_silu=True-quant_type_str=per_Tensor.co b/hsa/gfx942/fmoe_asmjit/moe_gemm_batch-weight_dtype=torch.float8_e4m3fnuz-with_silu=True-quant_type_str=per_Tensor.co new file mode 100755 index 0000000000..1426c2b049 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_gemm_batch-weight_dtype=torch.float8_e4m3fnuz-with_silu=True-quant_type_str=per_Tensor.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_gemm_batch-weight_dtype=torch.float8_e4m3fnuz-with_silu=True-quant_type_str=per_Token.co b/hsa/gfx942/fmoe_asmjit/moe_gemm_batch-weight_dtype=torch.float8_e4m3fnuz-with_silu=True-quant_type_str=per_Token.co new file mode 100755 index 0000000000..954345ac16 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_gemm_batch-weight_dtype=torch.float8_e4m3fnuz-with_silu=True-quant_type_str=per_Token.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_gemm_batch1-weight_dtype=torch.float8_e4m3fnuz-with_silu=False-quant_type_str=per_Tensor.co b/hsa/gfx942/fmoe_asmjit/moe_gemm_batch1-weight_dtype=torch.float8_e4m3fnuz-with_silu=False-quant_type_str=per_Tensor.co new file mode 100755 index 0000000000..0f8933247e Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_gemm_batch1-weight_dtype=torch.float8_e4m3fnuz-with_silu=False-quant_type_str=per_Tensor.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_gemm_batch1-weight_dtype=torch.float8_e4m3fnuz-with_silu=False-quant_type_str=per_Token.co b/hsa/gfx942/fmoe_asmjit/moe_gemm_batch1-weight_dtype=torch.float8_e4m3fnuz-with_silu=False-quant_type_str=per_Token.co new file mode 100755 index 0000000000..ae88ad1399 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_gemm_batch1-weight_dtype=torch.float8_e4m3fnuz-with_silu=False-quant_type_str=per_Token.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_gemm_batch1-weight_dtype=torch.float8_e4m3fnuz-with_silu=True-quant_type_str=per_Tensor.co b/hsa/gfx942/fmoe_asmjit/moe_gemm_batch1-weight_dtype=torch.float8_e4m3fnuz-with_silu=True-quant_type_str=per_Tensor.co new file mode 100755 index 0000000000..b7670ee71b Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_gemm_batch1-weight_dtype=torch.float8_e4m3fnuz-with_silu=True-quant_type_str=per_Tensor.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_gemm_batch1-weight_dtype=torch.float8_e4m3fnuz-with_silu=True-quant_type_str=per_Token.co b/hsa/gfx942/fmoe_asmjit/moe_gemm_batch1-weight_dtype=torch.float8_e4m3fnuz-with_silu=True-quant_type_str=per_Token.co new file mode 100755 index 0000000000..301183d02e Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_gemm_batch1-weight_dtype=torch.float8_e4m3fnuz-with_silu=True-quant_type_str=per_Token.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_gemm_final_reduce_bf16-TOPK=10-OC=4096.co b/hsa/gfx942/fmoe_asmjit/moe_gemm_final_reduce_bf16-TOPK=10-OC=4096.co new file mode 100755 index 0000000000..422ce62298 Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_gemm_final_reduce_bf16-TOPK=10-OC=4096.co differ diff --git a/hsa/gfx942/fmoe_asmjit/moe_gemm_final_reduce_bf16-TOPK=8-OC=4096.co b/hsa/gfx942/fmoe_asmjit/moe_gemm_final_reduce_bf16-TOPK=8-OC=4096.co new file mode 100755 index 0000000000..6cf62637fb Binary files /dev/null and b/hsa/gfx942/fmoe_asmjit/moe_gemm_final_reduce_bf16-TOPK=8-OC=4096.co differ