Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 22 additions & 24 deletions kernels/fp8_gemm_4wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,16 @@
Mfma16x16x128,
S2RLoader,
StoreC,
ceildiv,
compute_global_swizzle,
divmod,
make_fp8_buffer_tensor,
pack_i32x4_i32x8,
swizzle_128,
wait_barrier,
)


def _divmod(a, b):
return (a // b, a % b)


def _min(a, b):
return arith.select(a < b, a, b)

Expand All @@ -51,24 +49,25 @@ def _xcd_swizzle(num_pid_m, num_pid_n):

num_wg = num_pid_m * num_pid_n

if num_wg <= SWIZZLE_THRESHOLD or num_wg % NUM_XCDS != 0:
return _divmod(wgid, num_pid_n)
# Simple path: no XCD remapping.
simple_m, simple_n = divmod(wgid, num_pid_n)

intra_xcd, xcd = _divmod(wgid, NUM_XCDS)
wgid = xcd * (num_wg // NUM_XCDS) + intra_xcd
# XCD-remapped path.
intra_xcd, xcd = divmod(wgid, NUM_XCDS)
wgid_remap = xcd * (num_wg // NUM_XCDS) + intra_xcd
num_wgid_in_group = WGM * num_pid_n
group_id, intra_group = _divmod(wgid, num_wgid_in_group)
group_id, intra_group = divmod(wgid_remap, num_wgid_in_group)
first_pid_m = group_id * WGM
group_size_m = _min(num_pid_m - first_pid_m, WGM)
pid_n, intra_group_m = _divmod(intra_group, group_size_m)
pid_n, intra_group_m = divmod(intra_group, group_size_m)
pid_m = first_pid_m + intra_group_m
return (pid_m, pid_n)

use_simple = (num_wg <= SWIZZLE_THRESHOLD) | (num_wg % NUM_XCDS != 0)
return (arith.select(use_simple, simple_m, pid_m), arith.select(use_simple, simple_n, pid_n))


def compile_fp8_gemm_4w(
*,
M: int,
N: int,
K: int,
BLOCK_M: int = 256,
BLOCK_N: int = 256,
Expand All @@ -80,11 +79,9 @@ def compile_fp8_gemm_4w(
LDS_BLOCK_M = BLOCK_M // 2
LDS_BLOCK_N = BLOCK_N // 2

assert M >= 1 and N >= 1
assert BLOCK_M >= 64 and BLOCK_M % 64 == 0 and BLOCK_N >= 64 and BLOCK_N % 64 == 0
assert K % BLOCK_K == 0

N_BLOCKS = (N + BLOCK_N - 1) // BLOCK_N
K_ITERS = K // BLOCK_K
# Number of 16-row 16x128 tiles per wave per A/B partition.
N_TILES_A = BLOCK_M // 4 // 16
Expand Down Expand Up @@ -112,11 +109,7 @@ class SharedStorage:

@flyc.kernel
def kernel_gemm(
A: fx.Tensor,
B_T: fx.Tensor,
C: fx.Tensor,
A_scale: fx.Tensor,
B_scale: fx.Tensor,
A: fx.Tensor, B_T: fx.Tensor, C: fx.Tensor, A_scale: fx.Tensor, B_scale: fx.Tensor, c_m: fx.Int32, c_n: fx.Int32
):
F8_IR_t = fx.Float8E4M3FN.ir_type

Expand All @@ -133,10 +126,11 @@ def kernel_gemm(
lane_id = fx.thread_idx.x % 64
wave_id = fx.thread_idx.x // 64

n_blocks = ceildiv(c_n, BLOCK_N)
if const_expr(use_xcd_remap):
tile_i, tile_j = _xcd_swizzle((M + BLOCK_M - 1) // BLOCK_M, N_BLOCKS)
tile_i, tile_j = _xcd_swizzle(ceildiv(c_m, BLOCK_M), n_blocks)
else:
tile_i, tile_j = _divmod(fx.block_idx.x, N_BLOCKS)
tile_i, tile_j = divmod(fx.block_idx.x, n_blocks)

wave_i = wave_id // 2
wave_j = wave_id % 2
Expand Down Expand Up @@ -299,7 +293,7 @@ def _compute_block(
b_g2s = G2SLoader(gb_div, gl_off_b, N_TILES_B, F8_IR_t, wave_id)
a_s2r = S2RLoader(wave_i, N_TILES_A)
b_s2r = S2RLoader(wave_j, N_TILES_B)
store_c = StoreC(A_scale, B_scale, C, M, N, mfma.idx, N_TILES_A, N_TILES_B)
store_c = StoreC(A_scale, B_scale, C, c_m, c_n, mfma.idx, N_TILES_A, N_TILES_B)

# Prologue: 8-buffer LDS pipeline pre-fill.
a_g2s.load(a_cur0, A0_gl_offset + 0 * A_K_STEP)
Expand Down Expand Up @@ -416,15 +410,19 @@ def launch_gemm(
C: fx.Tensor,
A_scale: fx.Tensor,
B_scale: fx.Tensor,
c_m: fx.Int32,
c_n: fx.Int32,
stream: fx.Stream,
):
grid_x = ((M + BLOCK_M - 1) // BLOCK_M) * N_BLOCKS
grid_x = ceildiv(c_m, BLOCK_M) * ceildiv(c_n, BLOCK_N)
kernel_gemm(
A,
B_T,
C,
A_scale,
B_scale,
c_m,
c_n,
value_attrs={"rocdl.waves_per_eu": 1, "rocdl.flat_work_group_size": "256,256"},
).launch(grid=(grid_x, 1, 1), block=(256, 1, 1), stream=stream)

Expand Down
21 changes: 14 additions & 7 deletions kernels/fp8_gemm_8wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@
Mfma16x16x128,
S2RLoader,
StoreC,
ceildiv,
compute_global_swizzle,
divmod,
make_fp8_buffer_tensor,
wait_barrier,
)


def compile_fp8_gemm_8w(*, M: int, N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int = 256, b_preshuffled: bool = False):
def compile_fp8_gemm_8w(*, K: int, BLOCK_M: int = 256, BLOCK_N: int = 256, b_preshuffled: bool = False):
BLOCK_K = 128

assert M >= 1 and N >= 1
assert BLOCK_M >= 128 and BLOCK_N >= 256 and BLOCK_M % 128 == 0 and BLOCK_N % 256 == 0
assert K % BLOCK_K == 0

N_BLOCKS = (N + BLOCK_N - 1) // BLOCK_N
K_ITERS = K // BLOCK_K

N_TILES_A = BLOCK_M // 64
Expand Down Expand Up @@ -65,9 +65,13 @@ def kernel_gemm(
C: fx.Tensor,
A_scale: fx.Tensor,
B_scale: fx.Tensor,
c_m: fx.Int32,
c_n: fx.Int32,
):
F8_IR_t = fx.Float8E4M3FN.ir_type

n_blocks = ceildiv(c_n, BLOCK_N)

lds = fx.SharedAllocator().allocate(SharedStorage).peek()
a_cur0 = lds.A_lds_cur_0
a_cur1 = lds.A_lds_cur_1
Expand All @@ -82,8 +86,7 @@ def kernel_gemm(
wave_id = fx.thread_idx.x // 64
wave_m = wave_id // 4
wave_n = wave_id % 4
block_m = fx.block_idx.x // N_BLOCKS
block_n = fx.block_idx.x % N_BLOCKS
block_m, block_n = divmod(fx.block_idx.x, n_blocks)

A0_gl_offset = (block_m * BLOCK_M) * K
A1_gl_offset = (block_m * BLOCK_M + LDS_BLOCK_M) * K
Expand All @@ -105,7 +108,7 @@ def kernel_gemm(
b_g2s = G2SLoader(b_div, gl_off_b, N_LDS_STEPS_B, F8_IR_t, wave_id)
a_s2r = S2RLoader(wave_m, N_TILES_A)
b_s2r = S2RLoader(wave_n, N_TILES_B)
store_c = StoreC(A_scale, B_scale, C, M, N, mfma.idx, N_TILES_A, N_TILES_B)
store_c = StoreC(A_scale, B_scale, C, c_m, c_n, mfma.idx, N_TILES_A, N_TILES_B)

# 2x2 config of 4x2 (instead of 4x4 in 4wave) 16x16 sub-tiles
c00_frag = [mfma.zero_value] * N_ACCUMS
Expand Down Expand Up @@ -257,15 +260,19 @@ def launch_gemm(
C: fx.Tensor,
A_scale: fx.Tensor,
B_scale: fx.Tensor,
c_m: fx.Int32,
c_n: fx.Int32,
stream: fx.Stream,
):
grid_x = ((M + BLOCK_M - 1) // BLOCK_M) * N_BLOCKS
grid_x = ceildiv(c_m, BLOCK_M) * ceildiv(c_n, BLOCK_N)
kernel_gemm(
A,
B_T,
C,
A_scale,
B_scale,
c_m,
c_n,
value_attrs={"rocdl.waves_per_eu": 2, "rocdl.flat_work_group_size": "512,512"},
).launch(grid=(grid_x, 1, 1), block=(512, 1, 1), stream=stream)

Expand Down
8 changes: 8 additions & 0 deletions kernels/fp8_gemm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ def preshuffle_b(b_t):
return b_t.reshape(n // 16, 16, k // 64, 4, 16).permute(0, 2, 3, 1, 4).contiguous()


def ceildiv(a: int, b: int) -> int:
return (a + b - 1) // b


def divmod(a: int, b: int) -> tuple[int, int]:
return (a // b, a % b)


def make_fp8_buffer_tensor(arg_i8, fp8_ir_t):
t_i8 = fx.rocdl.make_buffer_tensor(arg_i8, max_size=False)
iter_i8 = fx.get_iter(t_i8)
Expand Down
6 changes: 2 additions & 4 deletions tests/kernels/test_fp8_gemm_rowscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ def _bench_fp8_gemm(

if use_8w:
launch_fn = compile_fp8_gemm_8w(
M=M,
N=N,
K=K,
BLOCK_M=tile_m,
BLOCK_N=tile_n,
Expand All @@ -112,8 +110,6 @@ def _bench_fp8_gemm(
print(f"\n[fp8_gemm_8wave] M={M} N={N} K={K} BLOCK_M={tile_m} BLOCK_N={tile_n} preshuffle_b={b_preshuffled}")
else:
launch_fn = compile_fp8_gemm_4w(
M=M,
N=N,
K=K,
BLOCK_M=tile_m,
BLOCK_N=tile_n,
Expand All @@ -133,6 +129,8 @@ def _args(c, a, b, sa, sb):
c.contiguous().view(-1),
sa.contiguous().view(-1),
sb.contiguous().view(-1),
M,
N,
torch.cuda.current_stream(),
)

Expand Down
Loading