diff --git a/aiter/aot/flydsl/moe.py b/aiter/aot/flydsl/moe.py index 5dc751d647..0cdd8ec273 100644 --- a/aiter/aot/flydsl/moe.py +++ b/aiter/aot/flydsl/moe.py @@ -94,6 +94,8 @@ def parse_csv(csv_path: str): experts = int(row["expert"]) topk = int(row["topk"]) doweight_stage1 = bool(int(row.get("doweight_stage1", "0"))) + hidden_pad = int(row.get("hidden_pad", "0") or "0") + intermediate_pad = int(row.get("intermediate_pad", "0") or "0") cu_num = int(row.get("cu_num", "0")) block_m = int(row.get("block_m", "0") or "0") act_type = row.get("act_type", "") @@ -102,19 +104,8 @@ def parse_csv(csv_path: str): if act_type.strip().split(".")[-1].lower() == "swiglu" else "silu" ) - q_type = row.get("q_type", "") - dtype = row.get("dtype", "") - q_dtype_w = row.get("q_dtype_w", "") swiglu_limit = _row_swiglu_limit(row) - # Cover both runtime bias choices for fp4-weight MoE. Model configs - # share kernel families, and runtime bias selection can vary by - # activation dtype/model semantics. - bias_supported = ( - q_type.strip().split(".")[-1] == "per_1x32" - and dtype in ("torch.bfloat16", "torch.float16") - and "float4_e2m1fn_x2" in q_dtype_w - ) - enable_bias_options = [False, True] if bias_supported else [False] + enable_bias_options = [str(row.get("bias", "")).strip() == "True"] # Detect stage1's fuse_quant from kernel suffix to align stage2's # a2_scale shape with what runtime actually passes. @@ -144,6 +135,8 @@ def parse_csv(csv_path: str): "experts": experts, "topk": topk, "doweight_stage1": doweight_stage1, + "hidden_pad": hidden_pad, + "intermediate_pad": intermediate_pad, "cu_num": cu_num, "act": act, "enable_bias": enable_bias, @@ -208,6 +201,8 @@ def _precompile_to_cache( enable_bias: bool = False, stage1_fuse_quant=None, swiglu_limit: float = 0.0, + hidden_pad: int = 0, + intermediate_pad: int = 0, # Stage2-only kernel tuning knobs (registered by the production-variant # entries in `get_flydsl_stage2_kernels`). Forwarded into # `compile_flydsl_moe_stage2` for stage 2 AOT compilation. @@ -567,6 +562,8 @@ def _make_a_user(a_dtype_user_shape): a_scale_one=a_scale_one, xcd_swizzle=xcd_swizzle, swiglu_limit=swiglu_limit, + model_dim_pad=hidden_pad, + inter_dim_pad=intermediate_pad, ) _run_compiled(exe, args) @@ -739,6 +736,8 @@ def _make_a_user(a_dtype_user_shape): b_nt=b_nt, xcd_swizzle=xcd_swizzle, enable_bias=enable_bias, + model_dim_pad=hidden_pad, + inter_dim_pad=intermediate_pad, ) _run_compiled(exe, args) @@ -750,6 +749,8 @@ def compile_one_config( experts: int, topk: int, cu_num: int = 0, + hidden_pad: int = 0, + intermediate_pad: int = 0, **kwargs, ) -> dict: """Compile one MoE kernel configuration and save to cache. @@ -763,6 +764,8 @@ def compile_one_config( shape_str = ( f"{kernel_name} " f"model_dim={model_dim} inter_dim={inter_dim} " + f"hidden_pad={hidden_pad} " + f"intermediate_pad={intermediate_pad} " f"E={experts} topk={topk}" ) result = { @@ -785,6 +788,8 @@ def compile_one_config( experts=experts, topk=topk, cu_num=cu_num, + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, **kwargs, ) elapsed = time.time() - t0 diff --git a/aiter/configs/model_configs/gptoss_fp4_tuned_fmoe.csv b/aiter/configs/model_configs/gptoss_fp4_tuned_fmoe.csv index e1bc608258..3f32cec7eb 100644 --- a/aiter/configs/model_configs/gptoss_fp4_tuned_fmoe.csv +++ b/aiter/configs/model_configs/gptoss_fp4_tuned_fmoe.csv @@ -1,25 +1,25 @@ -cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1,block_m,ksplit,us1,kernelName1,err1,us2,kernelName2,err2,us,run_1stage,xbf16,tflops,bw,_tag -256,256,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,208.4462,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2,0.0%,109.1431,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist,5.0%,317.5893,0,0,182.57,11418.01, -256,512,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,215.3665,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2,0.0%,109.5792,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_bnt2,5.1%,324.9457,0,0,356.87,11166.78, -256,1024,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,208.4256,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_fp4,0.8%,121.3742,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_atomic_bnt2,5.1%,329.7998,0,0,703.24,11016.73, -256,2048,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,289.0776,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_bnt0,0.0%,151.789,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic,5.0%,440.8666,0,0,1052.15,8262.71, -256,4096,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,404.0325,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w2_bnt0,0.0%,210.7843,flydsl_moe2_afp4_wfp4_bf16_t128x128x256_atomic,5.1%,614.8168,0,0,1508.93,5955.64, -256,8192,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,1319.2527,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w4_bnt0,0.0%,637.455,flydsl_moe2_afp4_wfp4_bf16_t128x128x256_atomic,5.0%,1956.7077,0,0,948.24,1890.61, -256,16384,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,1739.6561,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4_bnt0,0.0%,940.8593,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic,5.1%,2680.5154,0,0,1384.38,1408.26, -256,32768,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,1991.1671,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4_bnt0,0.0%,2466.0812,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic,5.0%,4457.2483,0,0,1665.09,880.78, -256,256,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,98.4171,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2_fp4,0.9%,61.0182,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_reduce_bnt2_persist,2.0%,159.4353,0,0,181.84,11379.53, -256,512,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,102.0754,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w4_fp4,0.9%,65.2416,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_reduce_bnt2_persist,2.0%,167.317,0,0,346.54,10857.58, -256,1024,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,108.8592,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4_fp4,0.8%,71.2891,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_bnt2,2.0%,180.1483,0,0,643.71,10110.43, -256,2048,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,149.6562,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w4_bnt0,0.0%,92.8451,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic,3.6%,242.5013,0,0,956.4,7549.71, -256,4096,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,229.8446,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w2_bnt0,0.0%,142.3561,flydsl_moe2_afp4_wfp4_bf16_t128x128x256_reduce,2.0%,372.2007,0,0,1246.25,4969.6, -256,8192,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,340.00350000000003,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w4_bnt0,0.0%,270.0618,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic_persist,3.6%,610.0653,0,0,1520.68,3093.83, -256,16384,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,597.4298,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w3_bnt0,0.0%,451.0174,flydsl_moe2_afp4_wfp4_bf16_t128x128x256_reduce,2.0%,1048.4472,0,0,1769.69,1872.23, -256,32768,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,1065.1228,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w3_bnt0,0.0%,844.3695,flydsl_moe2_afp4_wfp4_bf16_t128x128x256_reduce,2.0%,1909.4923,0,0,1943.37,1107.06, -256,256,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,37.8528,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_fp4,0.8%,28.9955,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce_bnt2,0.6%,66.8483,0,0,144.56,9070.37, -256,512,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,41.4756,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2,0.0%,30.988,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_reduce_bnt2_persist,0.6%,72.4636,0,0,266.72,8400.06, -256,1024,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,44.1622,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3,0.0%,37.4254,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce,0.6%,81.5876,0,0,473.78,7518.51, -256,2048,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,59.0778,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w3_bnt0,0.0%,51.3745,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce,0.6%,110.4523,0,0,699.93,5639.12, -256,4096,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,79.0359,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w2_bnt0_xcd4,0.0%,93.1675,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce,0.6%,172.2034,0,0,897.88,3726.57, -256,8192,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,124.6681,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w2_bnt0,0.0%,160.8729,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce_persist,0.6%,285.541,0,0,1082.99,2379.61, -256,16384,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,206.4827,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4_bnt0_xcd4,0.0%,285.1907,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_xcd4,0.6%,491.6734,0,0,1257.9,1535.52, -256,32768,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,367.6363,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w2_bnt0_xcd4,0.0%,517.2156,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_xcd4,0.6%,884.8519,0,0,1397.92,1023.87, +cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1,bias,hidden_pad,intermediate_pad,block_m,ksplit,us1,kernelName1,err1,us2,kernelName2,err2,us,run_1stage,xbf16,tflops,bw,_tag +256,32768,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152,64,0,398.1113,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4_bnt0_xcd4,0.0%,532.1537,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_xcd4_persist,0.6%,930.265,0,0,1329.68,973.88, +256,256,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192,32,0,194.4934,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2_fp4,0.8%,107.0083,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist,5.0%,301.5017,0,0,192.31,12027.26, +256,512,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192,64,0,201.4751,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4_fp4,0.8%,112.9308,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic_bnt2,5.0%,314.4059,0,0,368.84,11541.12, +256,1024,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192,64,0,210.677,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_fp4,0.8%,116.4949,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_atomic_bnt2,5.1%,327.1719,0,0,708.89,11105.22, +256,2048,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192,64,0,257.677,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_bnt0_fp4,0.8%,152.703,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic,5.0%,410.38,0,0,1130.31,8876.54, +256,4096,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192,128,0,401.2668,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w4_bnt0,0.0%,207.587,flydsl_moe2_afp4_wfp4_bf16_t128x128x256_atomic,5.1%,608.8538,0,0,1523.7,6013.97, +256,8192,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192,64,0,616.8445,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4_bnt0,0.0%,450.2309,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce_persist,3.1%,1067.0754,0,0,1738.8,3466.84, +256,16384,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192,64,0,1074.0444,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4_bnt0,0.0%,782.384,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce_persist,3.1%,1856.4284,0,0,1998.92,2033.41, +256,32768,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192,64,0,2004.8282,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4_bnt0,0.0%,1463.8338,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce,3.1%,3468.662,0,0,2139.64,1131.81, +256,256,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96,32,0,100.2431,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2_fp4,0.9%,58.3019,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist,3.7%,158.545,0,0,182.86,11443.43, +256,512,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96,32,0,101.6388,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2_fp4,0.9%,60.53,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist,3.7%,162.1688,0,0,357.54,11202.27, +256,1024,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96,64,0,105.8889,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_fp4,0.8%,67.0223,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_atomic_bnt2,3.6%,172.9112,0,0,670.66,10533.59, +256,2048,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96,64,0,149.9964,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w4_bnt0,0.0%,86.1496,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic,3.6%,236.146,0,0,982.14,7752.89, +256,4096,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96,128,0,228.0148,flydsl_moe1_afp4_wfp4_bf16_t128x64x256_w2_bnt0,0.0%,142.5471,flydsl_moe2_afp4_wfp4_bf16_t128x128x256_atomic,3.6%,370.5619,0,0,1251.77,4991.58, +256,8192,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96,64,0,342.1093,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4_bnt0,0.0%,267.9577,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic_persist,3.6%,610.067,0,0,1520.67,3093.82, +256,16384,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96,128,0,605.1533999999999,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w3_bnt0,0.0%,450.0292,flydsl_moe2_afp4_wfp4_bf16_t128x128x256_reduce,2.0%,1055.1826,0,0,1758.39,1860.28, +256,32768,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96,128,0,1068.8735,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w2_bnt0,0.0%,842.2079,flydsl_moe2_afp4_wfp4_bf16_t128x128x256_reduce,2.0%,1911.0814,0,0,1941.75,1106.14, +256,256,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152,32,0,36.5525,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w4_fp4,0.8%,27.8074,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_persist,1.5%,64.3599,0,0,150.15,9421.07, +256,512,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152,64,0,38.4215,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w2_fp4,0.8%,30.9836,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic_bnt2,1.6%,69.4051,0,0,278.47,8770.23, +256,1024,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152,64,0,44.9936,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4,0.0%,37.9515,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_bnt2_sbm64,1.6%,82.9451,0,0,466.03,7395.46, +256,2048,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152,64,0,60.532900000000005,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w3_bnt0,0.0%,51.2928,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce,0.6%,111.8257,0,0,691.34,5569.87, +256,4096,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152,64,0,82.8044,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w2_bnt0_xcd4,0.0%,91.2517,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist,0.6%,174.0561,0,0,888.33,3686.91, +256,8192,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152,64,0,122.4926,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w2_bnt0,0.0%,162.8753,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce,0.6%,285.3679,0,0,1083.65,2381.06, +256,16384,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152,64,0,211.1064,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_bnt0,0.0%,288.3477,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_xcd4,0.6%,499.4541,0,0,1238.3,1511.6, diff --git a/aiter/configs/model_configs/gptoss_fp4_untuned_fmoe.csv b/aiter/configs/model_configs/gptoss_fp4_untuned_fmoe.csv index ab6652b026..a4dc27ad9d 100644 --- a/aiter/configs/model_configs/gptoss_fp4_untuned_fmoe.csv +++ b/aiter/configs/model_configs/gptoss_fp4_untuned_fmoe.csv @@ -1,25 +1,25 @@ -token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1 -256,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -512,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -1024,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -2048,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -4096,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -8192,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -16384,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -32768,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -256,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -512,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -1024,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -2048,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -4096,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -8192,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -16384,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -32768,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -256,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -512,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -1024,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -2048,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -4096,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -8192,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -16384,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 -32768,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1,bias,hidden_pad,intermediate_pad +256,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192 +512,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192 +1024,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192 +2048,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192 +4096,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192 +8192,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192 +16384,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192 +32768,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,192 +256,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96 +512,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96 +1024,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96 +2048,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96 +4096,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96 +8192,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96 +16384,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96 +32768,3072,1536,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,96 +256,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152 +512,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152 +1024,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152 +2048,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152 +4096,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152 +8192,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152 +16384,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152 +32768,3072,512,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,True,192,152 diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 44a998cdb7..9f514f6c4a 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -344,6 +344,7 @@ def fused_moe_( intermediate_pad, isShuffled, gate_mode, + bias=(bias1 is not None or bias2 is not None), ) block_size_M = metadata.block_m if block_size_M is None else block_size_M @@ -876,6 +877,7 @@ def get_2stage_cfgs( intermediate_pad, is_shuffled=True, gate_mode=GateMode.SEPARATED.value, + bias=False, ): gate_mode = GateMode(gate_mode) _INDEX_COLS = [ @@ -892,12 +894,27 @@ def get_2stage_cfgs( "q_type", "use_g1u1", "doweight_stage1", + "bias", + "hidden_pad", + "intermediate_pad", ] + def _normalize_lookup_cols(df): + for col in ("hidden_pad", "intermediate_pad"): + if col not in df.columns: + df[col] = 0 + df[col] = df[col].fillna(0).astype(int) + if "bias" in df.columns: + df["bias"] = df["bias"].astype(str).str.strip().eq("True") + else: + df["bias"] = False + return df + def get_cfg_2stages(tune_file): import pandas as pd df = pd.read_csv(tune_file) + df = _normalize_lookup_cols(df) if "_tag" in df.columns: df = df[df["_tag"].fillna("") == ""] @@ -939,6 +956,7 @@ def get_flydsl_fallback_cfgs(tune_file): _flydsl_fallback_cache[tune_file] = {} return {} df = pd.read_csv(tune_file) + df = _normalize_lookup_cols(df) if "_tag" not in df.columns: _flydsl_fallback_cache[tune_file] = {} return {} @@ -967,6 +985,7 @@ def get_flydsl_fallback_cfgs(tune_file): if cfg_2stages is None: cfg_2stages = get_cfg_2stages(tune_file) cu_num = get_cu_num() + bias_key = bool(bias) keys = ( cu_num, token, @@ -981,6 +1000,9 @@ def get_flydsl_fallback_cfgs(tune_file): str(q_type), use_g1u1, doweight_stage1, + bias_key, + hidden_pad, + intermediate_pad, ) keys_disabled = ( cu_num, @@ -996,17 +1018,41 @@ def get_flydsl_fallback_cfgs(tune_file): str(q_type), use_g1u1, doweight_stage1, + bias_key, + hidden_pad, + intermediate_pad, ) def MainFunc(): + header = "token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1,bias,hidden_pad,intermediate_pad" + # Migrate legacy untuned CSV (no `bias` column) so appended rows stay aligned. + if os.path.exists(untune_file) and os.path.getsize(untune_file) > 0: + with open(untune_file, "r") as f: + lines = f.read().splitlines() + if lines and "bias" not in lines[0].split(","): + old_cols = lines[0].split(",") + try: + insert_at = old_cols.index("doweight_stage1") + 1 + except ValueError: + insert_at = len(old_cols) - 2 + new_lines = [ + ",".join(old_cols[:insert_at] + ["bias"] + old_cols[insert_at:]) + ] + for line in lines[1:]: + if not line.strip(): + new_lines.append(line) + continue + parts = line.split(",") + parts = parts[:insert_at] + ["False"] + parts[insert_at:] + new_lines.append(",".join(parts)) + with open(untune_file, "w") as f: + f.write("\n".join(new_lines)) with open(untune_file, "a") as f: if os.path.getsize(untune_file) == 0: - f.write( - "token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1" - ) + f.write(header) q_dtype_ws = q_dtype_w if q_dtype_w != torch.uint32 else "torch.int4" f.write( - f"\n{token},{model_dim},{inter_dim},{expert},{topk},{activation},{dtype},{q_dtype_a},{q_dtype_ws},{q_type},{int(use_g1u1)},{int(doweight_stage1)}" + f"\n{token},{model_dim},{inter_dim},{expert},{topk},{activation},{dtype},{q_dtype_a},{q_dtype_ws},{q_type},{int(use_g1u1)},{int(doweight_stage1)},{bool(bias)},{hidden_pad},{intermediate_pad}" ) logger.info("\033[34m Start tuning fmoe") os.system( @@ -1547,6 +1593,7 @@ def fused_moe_2stages( intermediate_pad, is_shuffled, gate_mode, + bias=(bias1 is not None or bias2 is not None), ) if ( quant_type == QuantType.per_1x32 diff --git a/aiter/jit/core.py b/aiter/jit/core.py index d42155df10..0637139ec8 100644 --- a/aiter/jit/core.py +++ b/aiter/jit/core.py @@ -215,7 +215,12 @@ def update_config_files(self, file_path: str, merge_name: str): f"when merging '{merge_name}'." ) - _FILL_DEFAULTS = {"xbf16": 0, "run_1stage": 0, "ksplit": 0} + _FILL_DEFAULTS = { + "xbf16": 0, + "run_1stage": 0, + "ksplit": 0, + "bias": False, + } all_cols = list(source_pairs[0][1].columns) for _, df in source_pairs[1:]: for c in df.columns: @@ -252,6 +257,9 @@ def update_config_files(self, file_path: str, merge_name: str): keys.append("cu_num") if "gfx" in merge_df.columns and "gfx" not in keys: keys.append("gfx") + for col in ("bias", "hidden_pad", "intermediate_pad"): + if col in merge_df.columns and col not in keys: + keys.append(col) dedup_keys = keys + ["_tag"] if has_tag else keys duplicated_mask = merge_df.duplicated(subset=dedup_keys, keep=False) if duplicated_mask.any(): diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py index f2a40f9c92..9173240f11 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py @@ -1719,6 +1719,9 @@ def calculate(self, results, bpes=(1, 1, 2)): q_type, use_g1u1, doweight_stage1, + bias, + hidden_pad, + intermediate_pad, ) = key if us == self.INVALID_TIME or us == self.INF_TIME: return 0, 0 @@ -1829,6 +1832,9 @@ def gen_1stage_asm_task(self, key): q_type, use_g1u1, doweight_stage1, + bias, + hidden_pad, + intermediate_pad, ) = info ## asm moe 1 stage tuning get_gfx() @@ -2004,6 +2010,9 @@ def gen_2stages_asm1_task(self, key, blockMs): q_type, use_g1u1, doweight_stage1, + bias, + hidden_pad, + intermediate_pad, ) = info kernels_list_csv = f"{get_asm_dir()}/fmoe_2stages/fmoe_stage1_bf16_pertoken{{quantDtype}}{{extraInfo}}_g1u1.csv" extraInfo = "" @@ -2115,6 +2124,9 @@ def gen_2stages_task(self, key, blockMs): q_type, use_g1u1, doweight_stage1, + bias, + hidden_pad, + intermediate_pad, ) = info _is_a8w4 = ( @@ -2306,6 +2318,9 @@ def _gen_2stages_task_cktile(self, info, blockMs): q_type, use_g1u1, doweight_stage1, + bias, + hidden_pad, + intermediate_pad, ) = info _gen_data_args_s1 = ( @@ -2424,6 +2439,9 @@ def gen_flydsl_2stages_task(self, info, blockMs): q_type, use_g1u1, doweight_stage1, + bias, + hidden_pad, + intermediate_pad, ) = info if q_type != QuantType.per_1x32 or q_dtype_w != dtypes.fp4x2: @@ -2644,6 +2662,9 @@ def gen_flydsl_i4_2stages_task(self, info, blockMs): q_type, use_g1u1, doweight_stage1, + bias, + hidden_pad, + intermediate_pad, ) = info if not (q_type == QuantType.per_1x32 and q_dtype_w == dtypes.i4x2): @@ -3073,12 +3094,18 @@ def tune( q_type, use_g1u1, doweight_stage1, + bias, + hidden_pad, + intermediate_pad, ) = line dtype = eval(dtype) q_dtype_a = eval(q_dtype_a) q_dtype_w = eval(q_dtype_w) q_type = eval(q_type) q_type = QuantType.per_1x128 if q_type == QuantType.per_128x128 else q_type + bias = bool(bias) + hidden_pad = int(hidden_pad) + intermediate_pad = int(intermediate_pad) print("\nStart tuning", line) if get_gfx() not in ["gfx950"] and q_type in [aiter.QuantType.per_1x32]: print(f"{q_type} is not supported on {get_gfx()}") @@ -3101,6 +3128,9 @@ def tune( q_type, use_g1u1, doweight_stage1, + bias, + hidden_pad, + intermediate_pad, ) tasks.extend(self.gen_2stages_asm1_task(info, blockMs)) tasks_ck.extend(self.gen_2stages_task(info, blockMs)) @@ -3243,6 +3273,9 @@ def post_process(self, results, args, topk=-1, fast_mode=False): q_type, use_g1u1, doweight_stage1, + bias, + hidden_pad, + intermediate_pad, ) = key import re @@ -3270,6 +3303,9 @@ def post_process(self, results, args, topk=-1, fast_mode=False): q_type, use_g1u1, doweight_stage1, + bias, + hidden_pad, + intermediate_pad, block_m, row_ksplit, us, @@ -3382,6 +3418,9 @@ def post_process(self, results, args, topk=-1, fast_mode=False): "q_type", "use_g1u1", "doweight_stage1", + "bias", + "hidden_pad", + "intermediate_pad", "block_m", ], how="inner", @@ -3409,6 +3448,9 @@ def post_process(self, results, args, topk=-1, fast_mode=False): q_type, use_g1u1, doweight_stage1, + bias, + hidden_pad, + intermediate_pad, 0, 0, self.INVALID_TIME, @@ -3643,6 +3685,22 @@ def _act_to_fp8(x): else: return pd.DataFrame() + # Optional untuned columns: backfilled with these defaults so older untuned + # CSVs (without bias / hidden_pad / intermediate_pad) still load cleanly. + OPTIONAL_UNTUNED_DEFAULTS = { + "bias": False, + "hidden_pad": 0, + "intermediate_pad": 0, + } + + def _backfill_optional_untuned_cols(self, df): + for col, default in self.OPTIONAL_UNTUNED_DEFAULTS.items(): + if col not in df.columns: + df[col] = default + else: + df[col] = df[col].fillna(default) + return df + def pre_process(self, args): if args.all: self.get_retune_gemm_list(args) @@ -3665,6 +3723,8 @@ def pre_process(self, args): self.tunedf[untunedf_cols].apply(tuple, axis=1) ) self.untunedf = self.untunedf[~mask] + if self.untunedf is not None: + self.untunedf = self._backfill_optional_untuned_cols(self.untunedf) if __name__ == "__main__": @@ -3682,6 +3742,9 @@ def pre_process(self, args): "q_type", "use_g1u1", "doweight_stage1", + "bias", + "hidden_pad", + "intermediate_pad", ] resultList = [ "block_m", diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index b249391791..63a3b51d4c 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -71,6 +71,7 @@ def test_fmoe( doweight_stage1=False, hidden_pad=0, intermediate_pad=0, + bias=False, preshuffle=True, strict_accuracy=True, check_aot_cache=True, @@ -78,6 +79,13 @@ def test_fmoe( ): if get_gfx() not in ["gfx950"] and qType in [aiter.QuantType.per_1x32]: return + assert ( + 0 <= hidden_pad < model_dim + ), f"invalid hidden_pad={hidden_pad} for model_dim={model_dim}" + assert ( + 0 <= intermediate_pad < inter_dim + ), f"invalid intermediate_pad={intermediate_pad} for inter_dim={inter_dim}" + torch_quant = aiter.get_torch_quant(qType) input = torch.randn((token, model_dim), dtype=dtype) if use_g1u1: @@ -87,16 +95,39 @@ def test_fmoe( if intermediate_pad != 0: w1[:, -intermediate_pad:, :] = 0 w1[:, inter_dim - intermediate_pad : inter_dim, :] = 0 - exp_bias1 = torch.clamp(torch.randn((E, inter_dim * 2), dtype=dtype), -1.0, 1.0) + if bias: + exp_bias1 = torch.clamp( + torch.randn((E, inter_dim * 2), dtype=dtype), -1.0, 1.0 + ) + # Dense torch reference still evaluates padded lanes; keep padded + # bias zero so invalid lanes do not affect activation quantization. + if intermediate_pad != 0: + exp_bias1[:, -intermediate_pad:] = 0 + exp_bias1[:, inter_dim - intermediate_pad : inter_dim] = 0 + else: + exp_bias1 = None else: w1 = torch.randn((E, inter_dim, model_dim), dtype=dtype) - exp_bias1 = torch.clamp(torch.randn((E * inter_dim), dtype=dtype), -1.0, 1.0) + if bias: + exp_bias1 = torch.clamp( + torch.randn((E * inter_dim), dtype=dtype), -1.0, 1.0 + ) + if intermediate_pad != 0: + exp_bias1.view(E, inter_dim)[:, -intermediate_pad:] = 0 + else: + exp_bias1 = None w2 = torch.randn((E, model_dim, inter_dim), dtype=dtype) if intermediate_pad != 0: w2[:, :, -intermediate_pad:] = 0 if hidden_pad != 0: w2[:, -hidden_pad:, :] = 0 - exp_bias2 = torch.clamp(torch.randn((E, model_dim), dtype=dtype), -1.0, 1.0) + if bias: + exp_bias2 = torch.clamp(torch.randn((E, model_dim), dtype=dtype), -1.0, 1.0) + # The padded hidden tail is outside the logical output dimension. + if hidden_pad != 0: + exp_bias2[:, -hidden_pad:] = 0 + else: + exp_bias2 = None if AITER_MOE_EXPERT_BALANCE: score = torch.zeros((token, E), dtype=dtype) start_col = 0 @@ -186,22 +217,15 @@ def weight_per_128x128_quant(weight, quant_dtype): else: a1_qt, a1_scale = torch_quant(input, quant_dtype=AQDType) - # bias dtype convert - if ( - qType == aiter.QuantType.per_1x32 - and (AQDType in [dtypes.bf16, dtypes.fp16, dtypes.fp8]) - and (WQDType == dtypes.fp4x2) - ): # a16w4 + # bias dtype convert: `bias` flag (from csv) is the source of truth. When + # set, cast to fp32 (kernel ABI). When csv has no bias column, exp_bias1 + # is already None (default False) and this is a no-op. + if exp_bias1 is None: + exp_bias1_aiter = None + exp_bias2_aiter = None + else: exp_bias1_aiter = exp_bias1.to(dtypes.fp32) exp_bias2_aiter = exp_bias2.to(dtypes.fp32) - elif ( - qType == aiter.QuantType.per_1x32 and WQDType == dtypes.i4x2 - ): # a16wi4: no bias - exp_bias1_aiter = exp_bias1 = None - exp_bias2_aiter = exp_bias2 = None - else: - exp_bias1_aiter = exp_bias1 = None - exp_bias2_aiter = exp_bias2 = None # pre-shuffle w1_scale_aiter = w1_scale @@ -286,6 +310,7 @@ def weight_per_128x128_quant(weight, quant_dtype): getattr(w1_qt_aiter, "is_shuffled", False) or getattr(w2_qt_aiter, "is_shuffled", False), gateMode, + bias=exp_bias1_aiter is not None, ) if metadata.fuse_quant == "fp4": # Fused Swiglu MXFP4 quantizes the f32 activation directly. @@ -366,9 +391,12 @@ def weight_per_128x128_quant(weight, quant_dtype): num_iters=5, num_warmup=2, ) + valid_model_dim = model_dim - hidden_pad + out2_ref_check = out2_ref[:, :valid_model_dim] + out2_ck_check = out2_ck[:, :valid_model_dim] err = checkAllclose( - out2_ref, - out2_ck, + out2_ref_check, + out2_ck_check, msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", ) @@ -378,7 +406,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): sim = 2 * (x * y).sum() / denominator return 1 - sim - logits_diff = calc_diff(out2_ref, out2_ck) + logits_diff = calc_diff(out2_ref_check, out2_ck_check) if logits_diff > 1e-3: logging.warning( f"logits_diff: {logits_diff} is too large, please check the implementation" @@ -590,7 +618,14 @@ def _str2enum(s, enum_cls): def _row_to_kwargs(row): - # csv rows store already-effective dims, so pad defaults to 0. + def _row_int(name): + if name not in row: + return 0 + value = row.get(name) + if pd.isna(value) or str(value).strip() == "": + return 0 + return int(value) + q_type = _str2enum(row["q_type"], aiter.QuantType) aq_dtype = _str2dtype(row["q_dtype_a"]) wq_dtype = _str2dtype(row["q_dtype_w"]) @@ -612,8 +647,9 @@ def _row_to_kwargs(row): WQDType=wq_dtype, use_g1u1=dtypes.str2bool(str(row["use_g1u1"])), doweight_stage1=dtypes.str2bool(str(row["doweight_stage1"])), - hidden_pad=0, - intermediate_pad=0, + hidden_pad=_row_int("hidden_pad"), + intermediate_pad=_row_int("intermediate_pad"), + bias=dtypes.str2bool(str(row.get("bias", "False"))), preshuffle=True, ) @@ -767,6 +803,7 @@ def _kw( aiter.ActivationType.Swiglu, hidden_pad=hidden_pad, intermediate_pad=intermediate_pad, + bias=True, ), extras elif triple == _PER1X32_FP4_FP4: for preshuffle in args.preshuffle: