diff --git a/aiter/aot/flydsl/gemm.py b/aiter/aot/flydsl/gemm.py index d77ed02104..82839f3cf4 100644 --- a/aiter/aot/flydsl/gemm.py +++ b/aiter/aot/flydsl/gemm.py @@ -210,6 +210,7 @@ def _compile_hgemm_to_cache( split_k: int, block_m_warps: int, block_n_warps: int, + block_k_warps: int, n_tile_repeat: int = 1, persistent_n_tiles: int = 1, waves_per_eu: int = 0, @@ -257,6 +258,7 @@ def _compile_hgemm_to_cache( split_k=split_k, block_m_warps=block_m_warps, block_n_warps=block_n_warps, + block_k_warps=block_k_warps, n_tile_repeat=n_tile_repeat, persistent_n_tiles=persistent_n_tiles, waves_per_eu=waves_per_eu, diff --git a/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv b/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv index 3fe9c8ba0c..cfd75d131d 100644 --- a/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv +++ b/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv @@ -1,90 +1,90 @@ gfx,cu_num,M,N,K,bias,dtype,outdtype,scaleAB,bpreshuffle,libtype,solidx,splitK,us,kernelName,err_ratio,tflops,bw -gfx950,256,1,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.9558,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0234,0.15,149.99 -gfx950,256,2,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,4.9466,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0195,0.3,151.48 -gfx950,256,4,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.9687,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0156,0.59,153.23 -gfx950,256,8,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.9927,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0176,1.18,157.31 -gfx950,256,16,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,5.031,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0171,2.34,165.68 -gfx950,256,32,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,4.6354,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0203,5.09,200.59 -gfx950,256,48,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,14,5.2547,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0212,6.73,195.26 -gfx950,256,64,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,13,5.3561,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.021,8.81,209.54 -gfx950,256,80,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,13,5.6419,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0218,10.45,215.98 -gfx950,256,96,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,9,5.7166,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0163,12.38,230.0 -gfx950,256,112,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,9,5.9183,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0165,13.95,238.43 -gfx950,256,128,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,9,5.97,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0172,15.81,252.48 -gfx950,256,256,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,6,7.0187,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0126,26.89,324.47 -gfx950,256,1,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.6772,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0105,1.52,1524.87 -gfx950,256,2,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.8371,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0107,3.0,1501.19 -gfx950,256,4,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.8551,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0118,5.98,1500.66 -gfx950,256,8,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.9035,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0101,11.91,1497.72 -gfx950,256,16,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,10.0897,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.01,23.38,1478.7 -gfx950,256,32,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,10.2307,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.01,46.12,1475.34 -gfx950,256,48,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.7334,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0073,65.94,1422.46 -gfx950,256,64,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.6753,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0074,88.4,1446.51 -gfx950,256,128,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,11.6504,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0074,162.01,1385.21 -gfx950,256,256,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,2,15.3742,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0041,245.53,1140.28 -gfx950,256,1,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.0917,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0083,1.3,1298.58 -gfx950,256,1,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,10.0871,auto,0.0,2.34,2340.31 -gfx950,256,2,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,9.2607,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0056,2.55,1275.95 -gfx950,256,2,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,9.6663,auto,0.0,4.88,2443.63 -gfx950,256,4,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.1797,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0091,5.14,1289.36 -gfx950,256,4,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,10.1637,auto,0.0,9.29,2326.79 -gfx950,256,8,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.2315,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0092,10.22,1286.39 -gfx950,256,8,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,11.4653,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0132,16.46,2067.51 -gfx950,256,16,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,9.3253,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0062,20.24,1281.91 -gfx950,256,16,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,11.591,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0136,32.57,2054.71 -gfx950,256,32,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,9.2671,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0065,40.73,1306.98 -gfx950,256,32,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,11.7821,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0137,64.08,2040.33 -gfx950,256,48,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,4,9.8145,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.009,57.69,1250.15 -gfx950,256,48,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,5,12.1519,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0136,93.19,1996.61 -gfx950,256,64,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,4,10.1075,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.009,74.69,1229.51 -gfx950,256,64,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,5,12.6689,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0137,119.19,1932.76 -gfx950,256,80,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,5,13.1447,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0137,143.59,1879.78 -gfx950,256,96,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,5,13.5787,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0136,166.8,1836.14 -gfx950,256,112,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.4374,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,151.54,1442.62 -gfx950,256,128,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,11.3635,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0035,132.88,1149.12 -gfx950,256,128,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.3233,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,174.33,1465.01 -gfx950,256,256,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,6,1,15.5453,_ZN5aiter37bf16gemm_fp32bf16_tn_64x64_pf3_splitkE,0.0,194.26,921.15 -gfx950,256,1,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,11.5513,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0078,2.55,2554.45 -gfx950,256,2,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,11.4061,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0075,5.17,2588.37 -gfx950,256,4,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.0536,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0075,9.79,2451.98 -gfx950,256,8,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,11.96,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0077,19.73,2476.52 -gfx950,256,16,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.2044,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0077,38.66,2437.42 -gfx950,256,32,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.3457,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0078,76.44,2430.26 -gfx950,256,48,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,3,13.0203,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0076,108.72,2324.0 -gfx950,256,64,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,13.4435,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0077,140.4,2269.89 -gfx950,256,80,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,3,14.2374,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0076,165.71,2161.29 -gfx950,256,96,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,3,14.4747,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0076,195.59,2143.55 -gfx950,256,112,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.1801,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,172.21,1631.02 -gfx950,256,128,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.3026,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,195.56,1633.94 -gfx950,256,1,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,6,7.8743,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0141,0.47,469.05 -gfx950,256,2,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,6,7.8652,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0148,0.94,470.49 -gfx950,256,4,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,6,7.7813,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0141,1.9,477.37 -gfx950,256,8,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,7,7.9497,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0166,3.71,470.8 -gfx950,256,16,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,6,7.9534,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0144,7.42,477.66 -gfx950,256,32,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,6,8.0289,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0145,14.69,487.2 -gfx950,256,48,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,6,8.2966,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0147,21.33,485.06 -gfx950,256,64,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,8.6344,auto,0.0,27.32,479.13 -gfx950,256,80,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,8,9.0799,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0177,32.48,468.02 -gfx950,256,96,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,9.1847,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0126,38.53,474.95 -gfx950,256,112,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.4523,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0107,43.68,473.42 -gfx950,256,128,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.4726,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0105,49.81,484.29 -gfx950,256,256,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.1055,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0075,93.39,543.13 -gfx950,256,80,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,3,11.2541,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0073,104.82,1387.58 -gfx950,256,96,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,3,11.2805,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0074,125.49,1399.77 -gfx950,256,112,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,3,11.8435,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0075,139.44,1347.93 -gfx950,256,1,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,3.6698,auto,0.0,0.8,805.47 -gfx950,256,2,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,4.2591,auto,0.0,1.38,695.61 -gfx950,256,4,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,4.3198,auto,0.0,2.73,688.98 -gfx950,256,8,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,4.1251,auto,0.0,5.72,728.08 -gfx950,256,16,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,4.0261,auto,0.0,11.72,759.46 -gfx950,256,32,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,4.1461,auto,0.0,22.76,763.66 -gfx950,256,48,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,26,1,5.0181,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k1_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0,28.21,652.59 -gfx950,256,64,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,4.9656,auto,0.0,38.01,681.35 -gfx950,256,80,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,67,1,5.0507,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k1_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0,46.71,691.36 -gfx950,256,80,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,4,10.2424,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0088,92.14,1228.71 -gfx950,256,96,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,67,1,5.1203,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k1_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0,55.29,703.16 -gfx950,256,96,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,4,10.752,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.009,105.33,1185.14 -gfx950,256,112,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,67,1,5.3155,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k1_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0,62.14,697.76 -gfx950,256,112,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,torch,0,0,10.7172,native,0.0,123.28,1203.71 -gfx950,256,128,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,4.7044,auto,0.0,80.24,811.47 -gfx950,256,256,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,6.291,auto,0.0,120.01,744.85 +gfx950,256,1,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,skinny,2,0,4.3907,sol2,0.0,0.17,169.29 +gfx950,256,2,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,4.9474,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0195,0.3,151.46 +gfx950,256,4,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,9,4.6178,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0059,0.64,164.87 +gfx950,256,8,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,4.6295,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0166,1.27,169.65 +gfx950,256,16,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.7001,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0259,2.51,177.34 +gfx950,256,32,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,5.02,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0195,4.7,185.22 +gfx950,256,48,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,13,4.851,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0213,7.3,211.51 +gfx950,256,64,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,15,5.0677,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0262,9.31,221.46 +gfx950,256,80,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,14,5.3256,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0227,11.08,228.81 +gfx950,256,96,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,456,9,5.5433,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k9_block_m_warp1_block_n_warp4_block_k_warp1_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0171,12.77,237.19 +gfx950,256,112,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,9,5.7119,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0171,14.46,247.04 +gfx950,256,128,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,9,5.9016,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0177,15.99,255.41 +gfx950,256,256,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,5,7.2795,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0125,25.93,312.85 +gfx950,256,1,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,46,9,6.957,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0125,0.53,530.9 +gfx950,256,2,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,46,9,6.7037,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0164,1.1,552.01 +gfx950,256,4,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,46,9,6.6585,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0168,2.21,557.87 +gfx950,256,8,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,46,9,6.9713,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0154,4.23,536.88 +gfx950,256,16,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,47,9,6.9779,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0188,8.45,544.44 +gfx950,256,32,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,191,9,7.6226,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k9_block_m_warp2_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0184,15.48,513.17 +gfx950,256,48,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,7.7828,auto,0.0,22.74,517.08 +gfx950,256,64,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,7.6829,auto,0.0,30.71,538.46 +gfx950,256,80,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,7,8.9311,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0159,33.02,475.82 +gfx950,256,96,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,9.1914,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0122,38.5,474.6 +gfx950,256,112,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.4149,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0105,43.85,475.3 +gfx950,256,128,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.3631,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0104,50.4,489.96 +gfx950,256,256,640,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,4,10.1489,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0104,92.99,540.81 +gfx950,256,1,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,22,9,7.8998,flydsl_gemm2_abf16_wbf16_bf16_t16x128x64_split_k9_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0223,1.87,1867.96 +gfx950,256,2,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,20,9,7.9753,flydsl_gemm2_abf16_wbf16_bf16_t16x128x64_split_k9_block_m_warp1_block_n_warp1_block_k_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0242,3.7,1851.64 +gfx950,256,4,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,22,9,8.0194,flydsl_gemm2_abf16_wbf16_bf16_t16x128x64_split_k9_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0198,7.35,1844.17 +gfx950,256,8,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,20,9,8.5477,flydsl_gemm2_abf16_wbf16_bf16_t16x128x64_split_k9_block_m_warp1_block_n_warp1_block_k_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0252,13.8,1735.28 +gfx950,256,16,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,22,9,9.1875,flydsl_gemm2_abf16_wbf16_bf16_t16x128x64_split_k9_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0178,25.68,1623.91 +gfx950,256,32,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,10.1658,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0099,46.42,1484.76 +gfx950,256,48,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.7413,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0073,65.89,1421.41 +gfx950,256,64,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.5517,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0074,89.44,1463.45 +gfx950,256,80,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,11.2081,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0073,105.25,1393.28 +gfx950,256,96,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,3,11.3114,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0074,125.15,1395.94 +gfx950,256,112,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,11.8801,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0075,139.01,1343.77 +gfx950,256,128,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,11.8203,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0074,159.68,1365.3 +gfx950,256,256,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,2,15.4704,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0041,244.01,1133.19 +gfx950,256,1,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,4.2083,auto,0.0,0.7,702.4 +gfx950,256,1,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,39,4,6.0971,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k4_block_m_warp1_block_n_warp1_block_k_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0156,1.93,1936.39 +gfx950,256,1,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,47,8,7.827,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0194,3.01,3016.09 +gfx950,256,2,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,4.216,auto,0.0,1.4,702.72 +gfx950,256,2,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,37,4,6.4071,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k4_block_m_warp1_block_n_warp1_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0089,3.68,1844.23 +gfx950,256,2,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,43,8,8.1896,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp1_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.02,5.76,2884.25 +gfx950,256,4,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,3.7233,auto,0.0,3.17,799.36 +gfx950,256,4,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,41,4,6.5059,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k4_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0089,7.25,1819.26 +gfx950,256,4,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,43,8,8.2683,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp1_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0211,11.41,2860.17 +gfx950,256,8,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,4.2713,auto,0.0,5.52,703.16 +gfx950,256,8,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,41,4,6.4413,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k4_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.009,14.65,1843.62 +gfx950,256,8,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,47,8,8.5513,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0202,22.07,2772.04 +gfx950,256,16,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,3.7647,auto,0.0,12.53,812.19 +gfx950,256,16,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,41,4,6.6168,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k4_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0101,28.52,1806.64 +gfx950,256,16,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,torch,0,0,8.8274,native,0.0,42.76,2697.98 +gfx950,256,32,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,3.7529,auto,0.0,25.15,843.67 +gfx950,256,32,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,189,4,7.948,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k4_block_m_warp1_block_n_warp2_block_k_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0167,47.49,1523.89 +gfx950,256,32,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,torch,0,0,9.7821,native,0.0,77.18,2457.49 +gfx950,256,48,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,21,1,4.6087,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k1_block_m_warp1_block_n_warp1_block_k_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0012,30.72,710.56 +gfx950,256,48,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,179,2,9.4277,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k2_block_m_warp2_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0051,60.06,1301.44 +gfx950,256,48,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,5,12.142,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0134,93.27,1998.24 +gfx950,256,64,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,4.3978,auto,0.0,42.92,769.32 +gfx950,256,64,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,torch,0,0,9.7603,native,0.0,77.35,1273.25 +gfx950,256,64,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,5,12.6368,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0136,119.49,1937.67 +gfx950,256,80,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,397,1,5.1248,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k1_block_m_warp2_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0,46.04,681.36 +gfx950,256,80,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,4,10.4093,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0091,90.66,1209.01 +gfx950,256,80,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,5,13.1591,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0138,143.43,1877.72 +gfx950,256,96,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,397,1,5.1325,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k1_block_m_warp2_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0,55.16,701.49 +gfx950,256,96,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,4,10.8637,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0091,104.24,1172.96 +gfx950,256,96,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,5,13.2729,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0138,170.64,1878.44 +gfx950,256,112,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,389,1,5.2924,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k1_block_m_warp1_block_n_warp2_block_k_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0011,62.41,700.8 +gfx950,256,112,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,torch,0,0,9.9645,native,0.0,132.59,1294.63 +gfx950,256,112,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,torch,0,0,15.6168,native,0.0,169.2,1610.8 +gfx950,256,128,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,4.9845,auto,0.0,75.73,765.87 +gfx950,256,128,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,torch,0,0,10.6636,native,0.0,141.6,1224.54 +gfx950,256,128,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,torch,0,0,15.2181,native,0.0,198.44,1667.67 +gfx950,256,256,2880,512,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,6.1785,auto,0.0,122.19,758.41 +gfx950,256,256,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,torch,0,0,13.9059,native,0.0,217.17,1029.75 +gfx950,256,1,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,23,9,9.4214,flydsl_gemm2_abf16_wbf16_bf16_t16x128x64_split_k9_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0197,3.13,3131.93 +gfx950,256,2,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,23,9,9.2631,flydsl_gemm2_abf16_wbf16_bf16_t16x128x64_split_k9_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0181,6.37,3187.18 +gfx950,256,4,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,23,9,9.8631,flydsl_gemm2_abf16_wbf16_bf16_t16x128x64_split_k9_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0192,11.96,2996.54 +gfx950,256,8,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,23,9,10.3673,flydsl_gemm2_abf16_wbf16_bf16_t16x128x64_split_k9_block_m_warp1_block_n_warp2_block_k_warp1_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0199,22.76,2856.98 +gfx950,256,16,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.0709,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0077,39.09,2464.37 +gfx950,256,32,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,12.1987,auto,0.0,77.36,2459.54 +gfx950,256,48,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,torch,0,0,13.0914,native,0.0,108.13,2311.38 +gfx950,256,64,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,13.5011,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0077,139.8,2260.2 +gfx950,256,80,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,3,13.9425,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0076,169.22,2207.01 +gfx950,256,96,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,3,14.5578,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0076,194.48,2131.31 +gfx950,256,112,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,torch,0,0,17.0959,native,0.0,193.21,1829.87 +gfx950,256,128,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,torch,0,0,17.1302,native,0.0,220.36,1841.15 diff --git a/aiter/ops/flydsl/gemm_kernels.py b/aiter/ops/flydsl/gemm_kernels.py index d624e7cce8..1e3b861347 100644 --- a/aiter/ops/flydsl/gemm_kernels.py +++ b/aiter/ops/flydsl/gemm_kernels.py @@ -48,6 +48,7 @@ def _get_dtypes(): r"split_k(?P\d+)_" r"block_m_warp(?P\d+)_" r"block_n_warp(?P\d+)_" + r"block_k_warp(?P\d+)_" r"async_copy(?PTrue|False)_" r"b_to_lds(?PTrue|False)_" r"b_preshuffle(?PTrue|False)_" @@ -73,35 +74,28 @@ def _get_dtypes(): HGEMM_TILE_N_OPTIONS = (64, 128, 256) HGEMM_TILE_K_OPTIONS = (64, 128, 256) HGEMM_TILE_M_OPTIONS = (16, 32, 48, 64, 96, 128, 256) -HGEMM_BASE_SPLIT_K_OPTIONS = tuple(range(1, 9)) -HGEMM_MAX_SPLIT_K = 8 -KERNEL_CONFIG_VARIANTS = ( - { - "block_m_warps": 1, - "block_n_warps": 2, - "b_to_lds": False, - }, - { - "block_m_warps": 1, - "block_n_warps": 4, - "b_to_lds": False, - }, +HGEMM_BASE_SPLIT_K_OPTIONS = tuple(range(1, 14)) +HGEMM_MAX_SPLIT_K = 13 +HGEMM_WARP_SHAPE_OPTIONS = [ + (wm, wn, wk) for wm, wn, wk in product([1, 2, 4], repeat=3) if wm * wn * wk <= 8 +] +KERNEL_CONFIG_VARIANTS = [ { - "block_m_warps": 2, - "block_n_warps": 2, + "block_m_warps": wm, + "block_n_warps": wn, + "block_k_warps": wk, "b_to_lds": False, - }, - { - "block_m_warps": 1, - "block_n_warps": 4, - "b_to_lds": True, - }, + } + for wm, wn, wk in HGEMM_WARP_SHAPE_OPTIONS +] + [ { - "block_m_warps": 2, - "block_n_warps": 2, + "block_m_warps": wm, + "block_n_warps": wn, + "block_k_warps": wk, "b_to_lds": True, - }, -) + } + for wm, wn, wk in HGEMM_WARP_SHAPE_OPTIONS +] _SPLITK_HGEMM_KERNELS: Dict[str, Dict] = {} @@ -141,6 +135,7 @@ def flydsl_kernel_name( split_k: int, block_m_warp: int, block_n_warp: int, + block_k_warp: int, async_copy: bool, b_to_lds: bool, b_preshuffle: bool = False, @@ -167,7 +162,7 @@ def flydsl_kernel_name( name = ( f"flydsl_gemm{stage}_a{dtype}_w{dtype}_{out_dtype}_t{tile_m}x{tile_n}x{tile_k}" ) - name += f"_split_k{split_k}_block_m_warp{block_m_warp}_block_n_warp{block_n_warp}" + name += f"_split_k{split_k}_block_m_warp{block_m_warp}_block_n_warp{block_n_warp}_block_k_warp{block_k_warp}" name += ( f"_async_copy{async_copy}_b_to_lds{b_to_lds}_b_preshuffle{b_preshuffle}" f"_c_to_lds{c_to_lds}" @@ -332,6 +327,7 @@ def _validate_hgemm_tiling( stages: int, block_m_warps: int, block_n_warps: int, + block_k_warps: int, b_to_lds: bool, ) -> None: del m @@ -340,10 +336,10 @@ def _validate_hgemm_tiling( raise ValueError( f"Tile sizes must be positive, got tile_m={tile_m}, tile_n={tile_n}, tile_k={tile_k}" ) - if block_m_warps < 1 or block_n_warps < 1: + if block_m_warps < 1 or block_n_warps < 1 or block_k_warps < 1: raise ValueError( "Warp tiling must be positive, got " - f"block_m_warps={block_m_warps}, block_n_warps={block_n_warps}" + f"block_m_warps={block_m_warps}, block_n_warps={block_n_warps}, block_k_warps={block_k_warps}" ) if tile_k < 32: raise ValueError( @@ -396,7 +392,7 @@ def _validate_hgemm_tiling( f">= tile_k={tile_k} and % tile_k == 0" ) - block_threads = block_m_warps * block_n_warps * 64 + block_threads = block_m_warps * block_n_warps * block_k_warps * 64 ldg_vec_size = 8 block_vecs = ldg_vec_size * block_threads block_mk_size = tile_m * tile_k @@ -457,6 +453,7 @@ def _normalize_registry_config( split_k: int, block_m_warps: int, block_n_warps: int, + block_k_warps: int, b_to_lds: bool, ) -> Optional[Dict]: config = { @@ -468,6 +465,7 @@ def _normalize_registry_config( "split_k": int(split_k), "block_m_warps": int(block_m_warps), "block_n_warps": int(block_n_warps), + "block_k_warps": int(block_k_warps), "async_copy": KERNEL_ASYNC_COPY, "b_to_lds": bool(b_to_lds), "b_preshuffle": False, @@ -490,6 +488,7 @@ def _normalize_registry_config( stages=FIXED_STAGE, block_m_warps=config["block_m_warps"], block_n_warps=config["block_n_warps"], + block_k_warps=config["block_k_warps"], b_to_lds=config["b_to_lds"], ) except ValueError: @@ -519,6 +518,7 @@ def _parse_hgemm_kernel_params(name: str) -> Optional[Dict]: "split_k": int(m.group("split_k")), "block_m_warps": int(m.group("block_m_warps")), "block_n_warps": int(m.group("block_n_warps")), + "block_k_warps": int(m.group("block_k_warps")), "async_copy": m.group("async_copy") == "True", "b_to_lds": m.group("b_to_lds") == "True", "b_preshuffle": m.group("b_preshuffle") == "True", @@ -582,6 +582,7 @@ def get_flydsl_splitk_hgemm_kernels( split_k=split_k, block_m_warps=variant["block_m_warps"], block_n_warps=variant["block_n_warps"], + block_k_warps=variant["block_k_warps"], b_to_lds=variant["b_to_lds"], ) if config is None: @@ -599,6 +600,7 @@ def get_flydsl_splitk_hgemm_kernels( config["split_k"], config["block_m_warps"], config["block_n_warps"], + config["block_k_warps"], config["async_copy"], config["b_to_lds"], config["b_preshuffle"], @@ -699,8 +701,9 @@ def _compile_flydsl_hgemm( k: int, *, tile_k: int = 64, - block_m_warps: int = 1, - block_n_warps: int = 4, + block_m_warps: int = 2, + block_n_warps: int = 2, + block_k_warps: int = 1, tile_m: int = 128, tile_n: int = 128, pack_n: int = 1, @@ -740,6 +743,7 @@ def _compile_flydsl_hgemm( stages=stages, block_m_warps=block_m_warps, block_n_warps=block_n_warps, + block_k_warps=block_k_warps, b_to_lds=b_to_lds, ) elif kernel_family == KERNEL_FAMILY_SMALL_M: @@ -776,6 +780,7 @@ def _compile_flydsl_hgemm( split_k=split_k, block_m_warps=block_m_warps, block_n_warps=block_n_warps, + block_k_warps=block_k_warps, n_tile_repeat=n_tile_repeat, persistent_n_tiles=persistent_n_tiles, waves_per_eu=waves_per_eu, @@ -835,8 +840,9 @@ def flydsl_hgemm( tile_k: int = 64, pack_n: int = 1, split_k: int = 1, - block_m_warps: int = 1, - block_n_warps: int = 4, + block_m_warps: int = 2, + block_n_warps: int = 2, + block_k_warps: int = 1, n_tile_repeat: int = 1, persistent_n_tiles: int = 1, waves_per_eu: int = 0, @@ -887,6 +893,7 @@ def flydsl_hgemm( tile_k=tile_k, block_m_warps=block_m_warps, block_n_warps=block_n_warps, + block_k_warps=block_k_warps, tile_m=tile_m, tile_n=tile_n, pack_n=pack_n, diff --git a/aiter/ops/flydsl/kernels/hgemm_dispatch.py b/aiter/ops/flydsl/kernels/hgemm_dispatch.py index c4110d73ff..caeb3f9281 100644 --- a/aiter/ops/flydsl/kernels/hgemm_dispatch.py +++ b/aiter/ops/flydsl/kernels/hgemm_dispatch.py @@ -20,8 +20,9 @@ def compile_flydsl_hgemm_kernel( tile_k: int = 64, pack_n: int = 1, split_k: int = 1, - block_m_warps: int = 1, - block_n_warps: int = 4, + block_m_warps: int = 2, + block_n_warps: int = 2, + block_k_warps: int = 1, n_tile_repeat: int = 1, persistent_n_tiles: int = 1, waves_per_eu: int = 0, @@ -52,6 +53,7 @@ def compile_flydsl_hgemm_kernel( SPLIT_K=split_k, BLOCK_M_WARPS=block_m_warps, BLOCK_N_WARPS=block_n_warps, + BLOCK_K_WARPS=block_k_warps, B_TO_LDS=b_to_lds, HAS_BIAS=has_bias, ) diff --git a/aiter/ops/flydsl/kernels/splitk_hgemm.py b/aiter/ops/flydsl/kernels/splitk_hgemm.py index 5621753497..6070f7717b 100644 --- a/aiter/ops/flydsl/kernels/splitk_hgemm.py +++ b/aiter/ops/flydsl/kernels/splitk_hgemm.py @@ -113,24 +113,25 @@ def compile_hgemm_kernel( TILE_N: int = 128, TILE_K: int = 64, SPLIT_K: int = 1, - BLOCK_M_WARPS: int = 1, - BLOCK_N_WARPS: int = 4, + BLOCK_M_WARPS: int = 2, + BLOCK_N_WARPS: int = 2, + BLOCK_K_WARPS: int = 1, B_TO_LDS: bool = False, HAS_BIAS: bool = False, ): - assert BLOCK_M_WARPS * BLOCK_N_WARPS <= 4 + assert BLOCK_M_WARPS * BLOCK_N_WARPS * BLOCK_K_WARPS <= 8 assert TILE_M * TILE_N * TILE_K <= 256 * 256 * 64 if (TILE_M == 256) and (TILE_N == 256): assert (TILE_K == 64) and (SPLIT_K == 1) N_BLOCKS = n // TILE_N assert (N_BLOCKS >= 1) and (n % TILE_N == 0) IS_SPLIT_K = SPLIT_K > 1 + IS_SLICE_K = BLOCK_K_WARPS > 1 BLOCK_K = TILE_K assert (k % SPLIT_K == 0) and (k // SPLIT_K >= 1) ks = k // SPLIT_K assert (ks % BLOCK_K == 0) and (ks // BLOCK_K >= 1) assert BLOCK_K >= 32 - GPU_ARCH = get_rocm_arch() if GPU_ARCH == "gfx942": WMMA_IMPL = WmmaHalf_m16n16k16(dtype) @@ -143,11 +144,13 @@ def compile_hgemm_kernel( MFMA_PER_WARP_K = 1 ASYNC_COPY = True + # Fixed parameters: WARP_SIZE = 64 DTYPE_BYTES = 2 LDG_VEC_SIZE = 8 STAGES = 2 + # Propagated parameters: WMMA_M = WMMA_IMPL.WMMA_M WMMA_N = WMMA_IMPL.WMMA_N WMMA_K = WMMA_IMPL.WMMA_K @@ -158,9 +161,13 @@ def compile_hgemm_kernel( WARP_ATOM_N = WMMA_N WARP_ATOM_K = WMMA_K * MFMA_PER_WARP_K BLOCK_K_LOOPS = ks // BLOCK_K - WARP_K_STEPS = BLOCK_K // WARP_ATOM_K - assert (BLOCK_K % WARP_ATOM_K == 0) and (WARP_K_STEPS >= 1) - BLOCK_THREADS = BLOCK_M_WARPS * BLOCK_N_WARPS * WARP_SIZE + WARP_GROUP_K = BLOCK_K_WARPS * WARP_ATOM_K + WARP_K_STEPS = BLOCK_K // WARP_GROUP_K + assert (BLOCK_K % WARP_GROUP_K == 0) and (WARP_K_STEPS >= 1) + K_SLICE = BLOCK_K // BLOCK_K_WARPS + assert K_SLICE % WARP_ATOM_K == 0 + BLOCK_THREADS = BLOCK_M_WARPS * BLOCK_N_WARPS * BLOCK_K_WARPS * WARP_SIZE + BLOCK_MN_WARPS = BLOCK_M_WARPS * BLOCK_N_WARPS WARP_M_STEPS = TILE_M // BLOCK_M_WARPS // WARP_ATOM_M WARP_N_STEPS = TILE_N // BLOCK_N_WARPS // WARP_ATOM_N assert (WARP_M_STEPS >= 1) and (WARP_N_STEPS >= 1) @@ -187,15 +194,7 @@ def compile_hgemm_kernel( assert BLOCK_MN_SIZE % BLOCK_VECS == 0 BLOCK_K_BYTES = BLOCK_K * DTYPE_BYTES - KERNEL_NAME = ( - f"hgemm_{dtype}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}" - f"_W{BLOCK_M_WARPS}x{BLOCK_N_WARPS}_S{STAGES}_BT_BLDS{int(B_TO_LDS)}" - ) - KERNEL_NAME += "_AS0" if not ASYNC_COPY else "_AS1" - KERNEL_NAME += f"_SPK{SPLIT_K}" - if HAS_BIAS: - KERNEL_NAME += "_BIAS" - + # LDS parameters: allocator = SmemAllocator(None, arch=GPU_ARCH, global_sym_name="smem") smem_a_offset = allocator._align(allocator.ptr, 16) AS_BYTES = STAGES * BLOCK_M * BLOCK_K * DTYPE_BYTES @@ -205,14 +204,21 @@ def compile_hgemm_kernel( smem_b_offset = allocator._align(allocator.ptr, 16) allocator.ptr = smem_b_offset + STAGES * BLOCK_N * BLOCK_K * DTYPE_BYTES SMEM_USE += STAGES * BLOCK_N * BLOCK_K * DTYPE_BYTES - SMEM_USE = max(SMEM_USE, BLOCK_M * BLOCK_N * DTYPE_BYTES) - assert SMEM_USE <= SMEM_CAPACITY_MAP[GPU_ARCH] + SMEM_USE_ = max(SMEM_USE, BLOCK_K_WARPS * BLOCK_M * BLOCK_N * DTYPE_BYTES) + allocator.ptr += SMEM_USE_ - SMEM_USE + assert SMEM_USE_ <= SMEM_CAPACITY_MAP[GPU_ARCH] LDG_ASYNC_VEC_SIZE = DMA_BYTES // DTYPE_BYTES LDG_A_X_THREADS_AS = BLOCK_K // LDG_ASYNC_VEC_SIZE LDG_REG_A_COUNT_AS = BLOCK_MK_SIZE // LDG_ASYNC_VEC_SIZE // BLOCK_THREADS LDG_B_X_THREADS_AS = BLOCK_K // LDG_ASYNC_VEC_SIZE LDG_REG_B_COUNT_AS = BLOCK_NK_SIZE // LDG_ASYNC_VEC_SIZE // BLOCK_THREADS + KERNEL_NAME = f"hgemm_{dtype}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}_W{BLOCK_M_WARPS}x{BLOCK_N_WARPS}x{BLOCK_K_WARPS}_S{STAGES}_BT_BLDS{int(B_TO_LDS)}" + KERNEL_NAME += "_AS0" if not ASYNC_COPY else "_AS1" + KERNEL_NAME += f"_SPK{SPLIT_K}" + if HAS_BIAS: + KERNEL_NAME += "_BIAS" + @flyc.kernel(known_block_size=[BLOCK_THREADS, 1, 1]) def hgemm_kernel( C: fx.Tensor, @@ -224,8 +230,6 @@ def hgemm_kernel( signal: fx.Tensor, ): dtype_ = get_dtype_in_kernel(dtype) - _ptr_type = ir.Type.parse("!llvm.ptr<1>") - _i64_type = T.i64 c_zero_d = arith.constant(0.0, type=dtype_) acc_init = arith.constant_vector(0.0, T.vec(WMMA_C_FRAG_VALUES, T.f32)) @@ -245,21 +249,22 @@ def hgemm_kernel( ) bs_ = STensor(smem_b_ptr, dtype_, shape=(STAGES, BLOCK_N, BLOCK_K)) smem_c_ptr = SmemPtr( - base_ptr, smem_a_offset, dtype_, shape=(BLOCK_M * BLOCK_N,) + base_ptr, smem_a_offset, dtype_, shape=(BLOCK_K_WARPS * BLOCK_M * BLOCK_N,) ) - cs_ = STensor(smem_c_ptr, dtype_, shape=(BLOCK_M, BLOCK_N)) + cs_ = STensor(smem_c_ptr, dtype_, shape=(BLOCK_K_WARPS, BLOCK_M, BLOCK_N)) if const_expr(IS_SPLIT_K): - smem_bc_ptr = SmemPtr(base_ptr, smem_a_offset, T.i32, shape=(1,)) - bc_ = STensor(smem_bc_ptr, T.i32, shape=(1,)) semaphore_ = GTensor(semaphore, dtype=T.i32, shape=(-1,)) signal_ = GTensor(signal, dtype=T.i32, shape=(-1,)) signal_idx = fx.Int32(fx.block_idx.x) - tid = fx.Int32(fx.thread_idx.x) + tid = fx.thread_idx.x wid = tid // WARP_SIZE + wid_mn = wid % BLOCK_MN_WARPS + wid_k = wid // BLOCK_MN_WARPS w_tid = tid % WARP_SIZE def swizzle_for_cache_reuse(pid): + # Do nothing currently return pid // N_BLOCKS, pid % N_BLOCKS block_m_idx, block_n_idx = swizzle_for_cache_reuse(fx.block_idx.x) @@ -270,52 +275,39 @@ def swizzle_for_cache_reuse(pid): n_offset = fx.Index(block_n_idx * BLOCK_N) k_blocks16 = fx.Int32(BLOCK_K_BYTES // 16) - warp_m_idx = wid // BLOCK_N_WARPS * WARP_M - warp_n_idx = wid % BLOCK_N_WARPS * WARP_N + warp_m_idx = wid_mn // BLOCK_N_WARPS * WARP_M + warp_n_idx = wid_mn % BLOCK_N_WARPS * WARP_N ldmatrix_a_m_idx = w_tid % WMMA_M ldmatrix_a_k_vec_idx = w_tid // WMMA_M * WMMA_A_FRAG_VALUES * MFMA_PER_WARP_K ldmatrix_b_n_idx = w_tid % WMMA_N ldmatrix_b_k_vec_idx = w_tid // WMMA_N * WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K + warp_k_slice_base = wid_k * K_SLICE A_FRAGS_LEN = WARP_K_STEPS * WARP_M_STEPS B_FRAGS_LEN = WARP_K_STEPS * WARP_N_STEPS C_FRAGS_LEN = WARP_M_STEPS * WARP_N_STEPS c_frags = [acc_init] * C_FRAGS_LEN - def get_llvm_ptr(ptr, offset, dtype_bytes): - base_ptr = fly.extract_aligned_pointer_as_index(_ptr_type, ptr) - base_ptr = llvm.PtrToIntOp(_i64_type, base_ptr).result + def get_llvm_ptr( + ptr, offset, dtype_bytes, ptr_type=ir.Type.parse("!llvm.ptr<1>") + ): + base_ptr = fly.extract_aligned_pointer_as_index(ptr_type, ptr) + base_ptr = llvm.PtrToIntOp(T.i64, base_ptr).result byte_offset = arith.index_cast( T.i64, fx.Index(offset) * fx.Index(dtype_bytes) ) llvm_ptr = llvm.AddOp( base_ptr, byte_offset, llvm.IntegerOverflowFlags(0) ).result - llvm_ptr = llvm.IntToPtrOp(_ptr_type, llvm_ptr).result + llvm_ptr = llvm.IntToPtrOp(ptr_type, llvm_ptr).result ptr_v = ( llvm_ptr._value if const_expr(hasattr(llvm_ptr, "_value")) else llvm_ptr ) return ptr_v def zero_c(): - # get arrive index within split-k group + # zero c if current block is the first block is_t0_cond = arith.cmpi(arith.CmpIPredicate.eq, fx.Index(tid), fx.Index(0)) - is_t0_cond_if = scf.IfOp(is_t0_cond, results_=[], has_else=False) - with ir.InsertionPoint(is_t0_cond_if.then_block): - semaphore_ptr = get_llvm_ptr(semaphore, signal_idx, 4) - prev = llvm.AtomicRMWOp( - llvm.AtomicBinOp.add, - semaphore_ptr, - arith.constant(1, type=T.i32), - llvm.AtomicOrdering.monotonic, - syncscope="agent", - alignment=4, - ).result - bc_[0] = prev - scf.YieldOp([]) - gpu.barrier() - arrive_idx = fx.Index(bc_[0]) - # zero c if current block is the first arrived block - cond_ks0 = arith.cmpi(arith.CmpIPredicate.eq, arrive_idx, fx.Index(0)) + cond_ks0 = arith.cmpi(arith.CmpIPredicate.eq, ks_idx, fx.Index(0)) cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) with ir.InsertionPoint(cond_ks0_if.then_block): zero_vec = vector.broadcast(T.vec(LDG_VEC_SIZE, dtype_), c_zero_d) @@ -349,13 +341,6 @@ def zero_c(): has_side_effects=True, ) scf.YieldOp([]) - llvm.InlineAsmOp( - None, - [], - "s_waitcnt vmcnt(0)", - "", - has_side_effects=True, - ) gpu.barrier() # trigger signal when zeroc is done by the first arrived block is_t0_cond_if = scf.IfOp(is_t0_cond, results_=[], has_else=False) @@ -414,9 +399,7 @@ def split_k_barrier(): alignment=4, ).result cond_ksl = arith.cmpi( - arith.CmpIPredicate.eq, - fx.Index(arrive_idx), - fx.Index(2 * SPLIT_K - 1), + arith.CmpIPredicate.eq, fx.Index(arrive_idx), fx.Index(SPLIT_K - 1) ) cond_ksl_if = scf.IfOp(cond_ksl, results_=[], has_else=False) with ir.InsertionPoint(cond_ksl_if.then_block): @@ -592,7 +575,7 @@ def lds_matrix_a(lds_stage): for ii in range_constexpr(WARP_M_STEPS): warp_atom_m_idx = warp_m_idx + ii * WARP_ATOM_M for kk in range_constexpr(WARP_K_STEPS): - warp_atom_k_idx = kk * WARP_ATOM_K + warp_atom_k_idx = warp_k_slice_base + kk * WARP_ATOM_K row = warp_atom_m_idx + ldmatrix_a_m_idx col_in_bytes = ( warp_atom_k_idx + ldmatrix_a_k_vec_idx @@ -611,7 +594,7 @@ def lds_matrix_b(lds_stage): for ii in range_constexpr(WARP_N_STEPS): warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N for kk in range_constexpr(WARP_K_STEPS): - warp_atom_k_idx = kk * WARP_ATOM_K + warp_atom_k_idx = warp_k_slice_base + kk * WARP_ATOM_K row = warp_atom_n_idx + ldmatrix_b_n_idx col_in_bytes = ( warp_atom_k_idx + ldmatrix_b_k_vec_idx @@ -629,7 +612,7 @@ def ldg_matrix_b(k_offset): for kk in range_constexpr(WARP_K_STEPS): for ii in range_constexpr(WARP_N_STEPS): warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N - warp_atom_k_idx = kk * WARP_ATOM_K + warp_atom_k_idx = warp_k_slice_base + kk * WARP_ATOM_K n_idx = n_offset + warp_atom_n_idx + ldmatrix_b_n_idx k_idx = k_offset + warp_atom_k_idx + ldmatrix_b_k_vec_idx vec = B_.vec_load( @@ -857,14 +840,15 @@ def hot_loop_scheduler(): static_position=[kk], dynamic_position=[], ) - cs_[lds_m_idx, lds_n_idx] = val.truncf(dtype_) + val = val.truncf(dtype_) + if const_expr(IS_SLICE_K): + cs_[wid_k, lds_m_idx, lds_n_idx] = val + else: + cs_[0, lds_m_idx, lds_n_idx] = val # write back to global if const_expr(IS_SPLIT_K): split_k_barrier() - out_raw = C - out_base_ptr = fly.extract_aligned_pointer_as_index(_ptr_type, out_raw) - out_base_int = llvm.PtrToIntOp(_i64_type, out_base_ptr).result for i in range_constexpr(LDG_REG_C_COUNT): global_tid = BLOCK_THREADS * i + tid m_local_idx = fx.Index(global_tid // LDG_C_X_THREADS) @@ -876,10 +860,12 @@ def hot_loop_scheduler(): ) cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) with ir.InsertionPoint(cond_boundary_if.then_block): - pk_val = cs_.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) - linear_bytes_offset = ( - C_.linear_offset((m_global_idx, n_global_idx)) * DTYPE_BYTES - ) + pk_val = cs_.vec_load((0, m_local_idx, n_local_idx), LDG_VEC_SIZE) + for ksi in range_constexpr(1, BLOCK_K_WARPS): + pk_val += cs_.vec_load( + (ksi, m_local_idx, n_local_idx), LDG_VEC_SIZE + ) + linear_offset_c = C_.linear_offset((m_global_idx, n_global_idx)) # split to vec2s vec2_ty = T.vec(2, dtype_) for vec_idx in range_constexpr(LDG_VEC_SIZE // 2): @@ -892,22 +878,12 @@ def hot_loop_scheduler(): dynamic_position=[], ) pair = vector.from_elements(vec2_ty, [e0, e1]) - pair_byte_offset = arith.index_cast( - T.i64, - linear_bytes_offset + fx.Index(vec_idx * 2 * DTYPE_BYTES), - ) - pair_addr_i64 = llvm.AddOp( - out_base_int, pair_byte_offset, llvm.IntegerOverflowFlags(0) - ).result - pair_ptr = llvm.IntToPtrOp(_ptr_type, pair_addr_i64).result - pair_ptr_v = ( - pair_ptr._value - if const_expr(hasattr(pair_ptr, "_value")) - else pair_ptr - ) pair_v = ( pair._value if const_expr(hasattr(pair, "_value")) else pair ) + pair_ptr_v = get_llvm_ptr( + C, fx.Int32(linear_offset_c + vec_idx * 2), DTYPE_BYTES + ) llvm.AtomicRMWOp( llvm.AtomicBinOp.fadd, pair_ptr_v, @@ -929,7 +905,11 @@ def hot_loop_scheduler(): ) cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) with ir.InsertionPoint(cond_boundary_if.then_block): - vec = cs_.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) + vec = cs_.vec_load((0, m_local_idx, n_local_idx), LDG_VEC_SIZE) + for ksi in range_constexpr(1, BLOCK_K_WARPS): + vec += cs_.vec_load( + (ksi, m_local_idx, n_local_idx), LDG_VEC_SIZE + ) if const_expr(HAS_BIAS): bias_vec = BIAS_.vec_load( (n_offset + n_local_idx,), LDG_VEC_SIZE @@ -960,9 +940,7 @@ def launch_hgemm_kernel( bm = (m + BLOCK_M - 1) // BLOCK_M hgemm_kernel._func.__name__ = KERNEL_NAME hgemm_kernel(C, A, B, BIAS, m, semaphore, signal).launch( - grid=(bm * N_BLOCKS, SPLIT_K, 1), - block=(BLOCK_THREADS, 1, 1), - stream=stream, + grid=(bm * N_BLOCKS, SPLIT_K, 1), block=(BLOCK_THREADS, 1, 1), stream=stream ) return launch_hgemm_kernel diff --git a/gradlib/gradlib/GemmTuner.py b/gradlib/gradlib/GemmTuner.py index 0ee6e9f6b0..5d2963e273 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/gradlib/GemmTuner.py @@ -155,6 +155,7 @@ def run_flydsl_gemm_bf16(input, weight, bias=None, otype=dtypes.bf16, config=Non split_k=config["split_k"], block_m_warps=config["block_m_warps"], block_n_warps=config["block_n_warps"], + block_k_warps=config["block_k_warps"], n_tile_repeat=config.get("n_tile_repeat", 1), persistent_n_tiles=config.get("persistent_n_tiles", 1), waves_per_eu=config.get("waves_per_eu", 0),