From c1f8760bf6fefa35abf26bf0ae71cad59cb70f84 Mon Sep 17 00:00:00 2001 From: Christian Gilli <272772515+amd-cgilli@users.noreply.github.com> Date: Wed, 20 May 2026 14:07:12 +0000 Subject: [PATCH 1/2] Support dynamic shapes (M, N) --- kernels/fp8_gemm_4wave.py | 42 ++++++++++++++----------- kernels/fp8_gemm_8wave.py | 27 +++++++++++----- kernels/fp8_gemm_utils.py | 8 +++++ tests/kernels/test_fp8_gemm_rowscale.py | 6 ++-- 4 files changed, 53 insertions(+), 30 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index a6dd590e..b7e7b958 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -25,7 +25,9 @@ Mfma16x16x128, S2RLoader, StoreC, + ceildiv, compute_global_swizzle, + divmod, make_fp8_buffer_tensor, pack_i32x4_i32x8, swizzle_128, @@ -33,10 +35,6 @@ ) -def _divmod(a, b): - return (a // b, a % b) - - def _min(a, b): return arith.select(a < b, a, b) @@ -51,24 +49,25 @@ def _xcd_swizzle(num_pid_m, num_pid_n): num_wg = num_pid_m * num_pid_n - if num_wg <= SWIZZLE_THRESHOLD or num_wg % NUM_XCDS != 0: - return _divmod(wgid, num_pid_n) + # Simple path: no XCD remapping. + simple_m, simple_n = divmod(wgid, num_pid_n) - intra_xcd, xcd = _divmod(wgid, NUM_XCDS) - wgid = xcd * (num_wg // NUM_XCDS) + intra_xcd + # XCD-remapped path. 
+ intra_xcd, xcd = divmod(wgid, NUM_XCDS) + wgid_remap = xcd * (num_wg // NUM_XCDS) + intra_xcd num_wgid_in_group = WGM * num_pid_n - group_id, intra_group = _divmod(wgid, num_wgid_in_group) + group_id, intra_group = divmod(wgid_remap, num_wgid_in_group) first_pid_m = group_id * WGM group_size_m = _min(num_pid_m - first_pid_m, WGM) - pid_n, intra_group_m = _divmod(intra_group, group_size_m) + pid_n, intra_group_m = divmod(intra_group, group_size_m) pid_m = first_pid_m + intra_group_m - return (pid_m, pid_n) + + use_simple = (num_wg <= SWIZZLE_THRESHOLD) | (num_wg % NUM_XCDS != 0) + return (arith.select(use_simple, simple_m, pid_m), arith.select(use_simple, simple_n, pid_n)) def compile_fp8_gemm_4w( *, - M: int, - N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int = 256, @@ -80,11 +79,9 @@ def compile_fp8_gemm_4w( LDS_BLOCK_M = BLOCK_M // 2 LDS_BLOCK_N = BLOCK_N // 2 - assert M >= 1 and N >= 1 assert BLOCK_M >= 64 and BLOCK_M % 64 == 0 and BLOCK_N >= 64 and BLOCK_N % 64 == 0 assert K % BLOCK_K == 0 - N_BLOCKS = (N + BLOCK_N - 1) // BLOCK_N K_ITERS = K // BLOCK_K # Number of 16-row 16x128 tiles per wave per A/B partition. 
N_TILES_A = BLOCK_M // 4 // 16 @@ -117,6 +114,8 @@ def kernel_gemm( C: fx.Tensor, A_scale: fx.Tensor, B_scale: fx.Tensor, + c_m: fx.Int32, + c_n: fx.Int32 ): F8_IR_t = fx.Float8E4M3FN.ir_type @@ -133,10 +132,11 @@ def kernel_gemm( lane_id = fx.thread_idx.x % 64 wave_id = fx.thread_idx.x // 64 + n_blocks = ceildiv(c_n, BLOCK_N) if const_expr(use_xcd_remap): - tile_i, tile_j = _xcd_swizzle((M + BLOCK_M - 1) // BLOCK_M, N_BLOCKS) + tile_i, tile_j = _xcd_swizzle(ceildiv(c_m, BLOCK_M), n_blocks) else: - tile_i, tile_j = _divmod(fx.block_idx.x, N_BLOCKS) + tile_i, tile_j = divmod(fx.block_idx.x, n_blocks) wave_i = wave_id // 2 wave_j = wave_id % 2 @@ -299,7 +299,7 @@ def _compute_block( b_g2s = G2SLoader(gb_div, gl_off_b, N_TILES_B, F8_IR_t, wave_id) a_s2r = S2RLoader(wave_i, N_TILES_A) b_s2r = S2RLoader(wave_j, N_TILES_B) - store_c = StoreC(A_scale, B_scale, C, M, N, mfma.idx, N_TILES_A, N_TILES_B) + store_c = StoreC(A_scale, B_scale, C, c_m, c_n, mfma.idx, N_TILES_A, N_TILES_B) # Prologue: 8-buffer LDS pipeline pre-fill. 
a_g2s.load(a_cur0, A0_gl_offset + 0 * A_K_STEP) @@ -416,15 +416,19 @@ def launch_gemm( C: fx.Tensor, A_scale: fx.Tensor, B_scale: fx.Tensor, + c_m: fx.Int32, + c_n: fx.Int32, stream: fx.Stream, ): - grid_x = ((M + BLOCK_M - 1) // BLOCK_M) * N_BLOCKS + grid_x = ceildiv(c_m, BLOCK_M) * ceildiv(c_n, BLOCK_N) kernel_gemm( A, B_T, C, A_scale, B_scale, + c_m, + c_n, value_attrs={"rocdl.waves_per_eu": 1, "rocdl.flat_work_group_size": "256,256"}, ).launch(grid=(grid_x, 1, 1), block=(256, 1, 1), stream=stream) diff --git a/kernels/fp8_gemm_8wave.py b/kernels/fp8_gemm_8wave.py index 4a35347d..20080ac8 100644 --- a/kernels/fp8_gemm_8wave.py +++ b/kernels/fp8_gemm_8wave.py @@ -15,20 +15,26 @@ Mfma16x16x128, S2RLoader, StoreC, + ceildiv, compute_global_swizzle, + divmod, make_fp8_buffer_tensor, wait_barrier, ) -def compile_fp8_gemm_8w(*, M: int, N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int = 256, b_preshuffled: bool = False): +def compile_fp8_gemm_8w( + *, + K: int, + BLOCK_M: int = 256, + BLOCK_N: int = 256, + b_preshuffled: bool = False +): BLOCK_K = 128 - assert M >= 1 and N >= 1 assert BLOCK_M >= 128 and BLOCK_N >= 256 and BLOCK_M % 128 == 0 and BLOCK_N % 256 == 0 assert K % BLOCK_K == 0 - N_BLOCKS = (N + BLOCK_N - 1) // BLOCK_N K_ITERS = K // BLOCK_K N_TILES_A = BLOCK_M // 64 @@ -65,9 +71,13 @@ def kernel_gemm( C: fx.Tensor, A_scale: fx.Tensor, B_scale: fx.Tensor, + c_m: fx.Int32, + c_n: fx.Int32, ): F8_IR_t = fx.Float8E4M3FN.ir_type + n_blocks = ceildiv(c_n, BLOCK_N) + lds = fx.SharedAllocator().allocate(SharedStorage).peek() a_cur0 = lds.A_lds_cur_0 a_cur1 = lds.A_lds_cur_1 @@ -82,8 +92,7 @@ def kernel_gemm( wave_id = fx.thread_idx.x // 64 wave_m = wave_id // 4 wave_n = wave_id % 4 - block_m = fx.block_idx.x // N_BLOCKS - block_n = fx.block_idx.x % N_BLOCKS + block_m, block_n = divmod(fx.block_idx.x, n_blocks) A0_gl_offset = (block_m * BLOCK_M) * K A1_gl_offset = (block_m * BLOCK_M + LDS_BLOCK_M) * K @@ -105,7 +114,7 @@ def kernel_gemm( b_g2s = G2SLoader(b_div, 
gl_off_b, N_LDS_STEPS_B, F8_IR_t, wave_id) a_s2r = S2RLoader(wave_m, N_TILES_A) b_s2r = S2RLoader(wave_n, N_TILES_B) - store_c = StoreC(A_scale, B_scale, C, M, N, mfma.idx, N_TILES_A, N_TILES_B) + store_c = StoreC(A_scale, B_scale, C, c_m, c_n, mfma.idx, N_TILES_A, N_TILES_B) # 2x2 config of 4x2 (instead of 4x4 in 4wave) 16x16 sub-tiles c00_frag = [mfma.zero_value] * N_ACCUMS @@ -257,15 +266,19 @@ def launch_gemm( C: fx.Tensor, A_scale: fx.Tensor, B_scale: fx.Tensor, + c_m: fx.Int32, + c_n: fx.Int32, stream: fx.Stream, ): - grid_x = ((M + BLOCK_M - 1) // BLOCK_M) * N_BLOCKS + grid_x = ceildiv(c_m, BLOCK_M) * ceildiv(c_n, BLOCK_N) kernel_gemm( A, B_T, C, A_scale, B_scale, + c_m, + c_n, value_attrs={"rocdl.waves_per_eu": 2, "rocdl.flat_work_group_size": "512,512"}, ).launch(grid=(grid_x, 1, 1), block=(512, 1, 1), stream=stream) diff --git a/kernels/fp8_gemm_utils.py b/kernels/fp8_gemm_utils.py index 0d24149f..a291c0f9 100644 --- a/kernels/fp8_gemm_utils.py +++ b/kernels/fp8_gemm_utils.py @@ -16,6 +16,14 @@ def preshuffle_b(b_t): return b_t.reshape(n // 16, 16, k // 64, 4, 16).permute(0, 2, 3, 1, 4).contiguous() +def ceildiv(a: int, b: int) -> int: + return (a + b - 1) // b + + +def divmod(a: int, b: int) -> tuple[int, int]: + return (a // b, a % b) + + def make_fp8_buffer_tensor(arg_i8, fp8_ir_t): t_i8 = fx.rocdl.make_buffer_tensor(arg_i8, max_size=False) iter_i8 = fx.get_iter(t_i8) diff --git a/tests/kernels/test_fp8_gemm_rowscale.py b/tests/kernels/test_fp8_gemm_rowscale.py index 81b60877..64b7f902 100644 --- a/tests/kernels/test_fp8_gemm_rowscale.py +++ b/tests/kernels/test_fp8_gemm_rowscale.py @@ -102,8 +102,6 @@ def _bench_fp8_gemm( if use_8w: launch_fn = compile_fp8_gemm_8w( - M=M, - N=N, K=K, BLOCK_M=tile_m, BLOCK_N=tile_n, @@ -112,8 +110,6 @@ def _bench_fp8_gemm( print(f"\n[fp8_gemm_8wave] M={M} N={N} K={K} BLOCK_M={tile_m} BLOCK_N={tile_n} preshuffle_b={b_preshuffled}") else: launch_fn = compile_fp8_gemm_4w( - M=M, - N=N, K=K, BLOCK_M=tile_m, BLOCK_N=tile_n, 
@@ -133,6 +129,8 @@ def _args(c, a, b, sa, sb): c.contiguous().view(-1), sa.contiguous().view(-1), sb.contiguous().view(-1), + M, + N, torch.cuda.current_stream(), ) From affe3ad5250d2db66f6846ec2b8e108a814e5d77 Mon Sep 17 00:00:00 2001 From: Christian Gilli <272772515+amd-cgilli@users.noreply.github.com> Date: Wed, 20 May 2026 14:15:04 +0000 Subject: [PATCH 2/2] Fix formatting --- kernels/fp8_gemm_4wave.py | 8 +------- kernels/fp8_gemm_8wave.py | 8 +------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index b7e7b958..1f32a435 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -109,13 +109,7 @@ class SharedStorage: @flyc.kernel def kernel_gemm( - A: fx.Tensor, - B_T: fx.Tensor, - C: fx.Tensor, - A_scale: fx.Tensor, - B_scale: fx.Tensor, - c_m: fx.Int32, - c_n: fx.Int32 + A: fx.Tensor, B_T: fx.Tensor, C: fx.Tensor, A_scale: fx.Tensor, B_scale: fx.Tensor, c_m: fx.Int32, c_n: fx.Int32 ): F8_IR_t = fx.Float8E4M3FN.ir_type diff --git a/kernels/fp8_gemm_8wave.py b/kernels/fp8_gemm_8wave.py index 20080ac8..e8e84e3e 100644 --- a/kernels/fp8_gemm_8wave.py +++ b/kernels/fp8_gemm_8wave.py @@ -23,13 +23,7 @@ ) -def compile_fp8_gemm_8w( - *, - K: int, - BLOCK_M: int = 256, - BLOCK_N: int = 256, - b_preshuffled: bool = False -): +def compile_fp8_gemm_8w(*, K: int, BLOCK_M: int = 256, BLOCK_N: int = 256, b_preshuffled: bool = False): BLOCK_K = 128 assert BLOCK_M >= 128 and BLOCK_N >= 256 and BLOCK_M % 128 == 0 and BLOCK_N % 256 == 0