Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions kernels/mixed_moe_gemm_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,13 @@ def moe_gemm1(
x_nbytes_i32 = arith.index_cast(T.i32, x_nbytes_idx)
x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes_i32)

w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False)
# W: [experts, 2*inter_dim, model_dim]; fp4 packs 2 elements per byte.
w_nbytes_s1 = (
(experts * (2 * inter_dim) * model_dim) // 2
if is_f4_b
else (experts * (2 * inter_dim) * model_dim * b_elem_bytes)
)
w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes_s1)

# Out: [tokens*topk, inter_dim]
numids_rsrc = buffer_ops.create_buffer_resource(
Expand Down Expand Up @@ -621,14 +627,30 @@ def moe_gemm1(
expert_rsrc = buffer_ops.create_buffer_resource(
arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32
)
bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None
# bias: [experts, 2*inter_dim] f32 -> bytes = experts * 2*inter_dim * 4
bias_nbytes_s1 = experts * (2 * inter_dim) * 4
bias_rsrc = (
buffer_ops.create_buffer_resource(arg_bias, max_size=False, num_records_bytes=bias_nbytes_s1)
if enable_bias
else None
)

# Sorted-scale buffer resource for fused mxfp4 quantization
_sorted_scale_cols = inter_dim // 32
_sorted_scale_cols_i32 = arith.constant(_sorted_scale_cols, type=T.i32)
sorted_scale_rsrc = None
if const_expr(_need_sort):
sorted_scale_rsrc = buffer_ops.create_buffer_resource(arg_out_scale_sorted, max_size=False)
_sort_rows_idx = size_expert_ids_in * arith.constant(sort_block_m, index=True)
_sort_padded_rows = (
(_sort_rows_idx + arith.constant(255, index=True))
/ arith.constant(256, index=True)
* arith.constant(256, index=True)
)
_sort_padded_cols = arith.constant(((_sorted_scale_cols + 7) // 8) * 8, index=True)
_sort_scale_nbytes = arith.index_cast(T.i32, _sort_padded_rows * _sort_padded_cols)
sorted_scale_rsrc = buffer_ops.create_buffer_resource(
arg_out_scale_sorted, max_size=False, num_records_bytes=_sort_scale_nbytes
)

# ---- persist_m loop (same pattern as stage2) ----
_PERSIST_M = persist_m
Expand Down Expand Up @@ -2746,7 +2768,11 @@ def check_c_k_valid_gate(base_k):
x_nbytes_i32 = arith.index_cast(T.i32, x_nbytes_idx)
x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes_i32)

w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False)
# W: [experts, model_dim, inter_dim]; fp4 packs 2 elements per byte.
w_nbytes_s2 = (
(experts * model_dim * inter_dim) // 2 if is_f4_b else (experts * model_dim * inter_dim * b_elem_bytes)
)
w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes_s2)

# OUT: [tokens, model_dim] -> clamp to descriptor max (i32 bytes) to avoid overflow on huge tokens.
out_elem_bytes = 4 if out_is_f32 else 2
Expand Down Expand Up @@ -2824,7 +2850,13 @@ def check_c_k_valid_gate(base_k):
expert_rsrc = buffer_ops.create_buffer_resource(
arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32
)
bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None
# bias: [experts, model_dim] f32 -> bytes = experts * model_dim * 4
bias_nbytes_s2 = experts * model_dim * 4
bias_rsrc = (
buffer_ops.create_buffer_resource(arg_bias, max_size=False, num_records_bytes=bias_nbytes_s2)
if enable_bias
else None
)

# ---- persist loop ----
_c0_p = arith.constant(0, index=True)
Expand Down
35 changes: 24 additions & 11 deletions kernels/moe_blockscale_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ def compile_moe_blockscale_gemm1(
sb_per_tile_s1 = tile_k // scale_block_k # scale blocks per tile (in K dim)
ku_per_sb_s1 = scale_block_k // 64 # K64-steps per scale block = 2
nblk_k_w1 = model_dim // scale_block_k # K-blocks in W1 (=scale_k)
(2 * inter_dim) // 128 # N-blocks in W1 (ScaleBlockN=128)
nblk_n_w1 = (2 * inter_dim) // 128 # N-blocks in W1 (ScaleBlockN=128)
# scale_w: [experts, nblk_n_w1, nblk_k_w1] f32 (per-block scale)
sw_nbytes = experts * nblk_n_w1 * nblk_k_w1 * 4

mfma_i32_k32 = None
if is_int8:
Expand All @@ -140,7 +142,11 @@ def compile_moe_blockscale_gemm1(

ir.ShapedType.get_dynamic_size()
# W is packed int4 for W4A8: 2 values per byte.
(experts * (2 * inter_dim) * model_dim) // 2 if is_int4 else (experts * (2 * inter_dim) * model_dim)
w_nbytes = (
(experts * (2 * inter_dim) * model_dim) // 2
if is_int4
else (experts * (2 * inter_dim) * model_dim * elem_bytes)
)

total_threads = 256
bytes_x_per_tile = int(tile_m) * int(tile_k) * int(elem_bytes)
Expand Down Expand Up @@ -301,7 +307,7 @@ def silu(x):
arg_x, max_size=False, num_records_bytes=arith.index_cast(T.i64, x_nbytes_idx)
)

w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False)
w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes)

# OUT: [tokens, topk, inter] f16/bf16 -> bytes = tokens*topk*inter*out_elem_bytes
out_elem_bytes = 2 # f16/bf16
Expand All @@ -321,10 +327,15 @@ def silu(x):
sx_rsrc = buffer_ops.create_buffer_resource(
arg_scale_x, max_size=False, num_records_bytes=arith.index_cast(T.i64, sx_nbytes_idx)
)
sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False)
sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False, num_records_bytes=sw_nbytes)

sorted_rsrc = buffer_ops.create_buffer_resource(arg_sorted_token_ids, max_size=False)
sorted_w_rsrc = buffer_ops.create_buffer_resource(arg_sorted_weights, max_size=False)
sorted_nbytes_idx = size_expert_ids_in * fx.Index(tile_m) * fx.Index(4)
sorted_rsrc = buffer_ops.create_buffer_resource(
arg_sorted_token_ids, max_size=False, num_records_bytes=sorted_nbytes_idx
)
sorted_w_rsrc = buffer_ops.create_buffer_resource(
arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_idx
)

# expert ids: [blocks] i32 -> bytes = size_expert_ids_in*4
expert_rsrc = buffer_ops.create_buffer_resource(
Expand Down Expand Up @@ -1236,7 +1247,9 @@ def compile_moe_blockscale_gemm2(
sb_per_tile_s2 = tile_k // scale_block_k # scale blocks per tile (in K dim)
ku_per_sb_s2 = scale_block_k // 64 # K64-steps per scale block = 2
nblk_k_w2 = inter_dim // scale_block_k # K-blocks in W2 (=scale_k)
model_dim // 128 # N-blocks in W2 (ScaleBlockN=128)
nblk_n_w2 = model_dim // 128 # N-blocks in W2 (ScaleBlockN=128)
# scale_w: [experts, nblk_n_w2, nblk_k_w2] f32 (per-block scale)
sw_nbytes = experts * nblk_n_w2 * nblk_k_w2 * 4

mfma_i32_k32 = None
if is_int8:
Expand All @@ -1248,7 +1261,7 @@ def compile_moe_blockscale_gemm2(

ir.ShapedType.get_dynamic_size()
# W is packed int4 for W4A8: 2 values per byte.
(experts * model_dim * inter_dim) // 2 if is_int4 else (experts * model_dim * inter_dim)
w_nbytes = (experts * model_dim * inter_dim) // 2 if is_int4 else (experts * model_dim * inter_dim * elem_bytes)

total_threads = 256
tile_k_bytes = int(tile_k) * int(elem_bytes)
Expand Down Expand Up @@ -1431,7 +1444,7 @@ def moe_blockscale_gemm2(
arg_x, max_size=False, num_records_bytes=arith.index_cast(T.i64, x_nbytes_idx)
)

w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False)
w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes)

# OUT: [tokens, model_dim] -> clamp to descriptor max (i32 bytes) to avoid overflow on huge tokens.
out_elem_bytes = 4 if out_is_f32 else 2
Expand All @@ -1450,8 +1463,8 @@ def moe_blockscale_gemm2(
sx_rsrc = buffer_ops.create_buffer_resource(
arg_scale_x, max_size=False, num_records_bytes=arith.index_cast(T.i64, sx_nbytes_idx)
)
# scale_w: [experts*model_dim] f32 (static shape in practice)
sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False)
# scale_w: [experts, nblk_n_w2, nblk_k_w2] f32 (per-block scale)
sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False, num_records_bytes=sw_nbytes)

# sorted_token_ids / sorted_weights: [blocks*tile_m] (CK-style padded length)
sorted_nbytes_idx = size_expert_ids_in * fx.Index(tile_m) * fx.Index(4)
Expand Down
31 changes: 21 additions & 10 deletions kernels/moe_gemm_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def compile_moe_gemm1(
is_int4_bf16_groupwise = is_int4_bf16 and use_groupwise_scale
num_groups = model_dim // group_size if use_groupwise_scale else 1
_scale_is_bf16 = scale_is_bf16 and use_groupwise_scale
experts * (2 * inter_dim) * num_groups
sw_nbytes = experts * (2 * inter_dim) * num_groups * (2 if _scale_is_bf16 else 4) if needs_scale_w else 0
# For groupwise scale, weight scale is applied per-group in the K loop,
# so epilogue can skip weight scale multiplication (uses 1.0 for sw).

Expand Down Expand Up @@ -213,7 +213,11 @@ def compile_moe_gemm1(

ir.ShapedType.get_dynamic_size()
# W is packed int4 for W4A8/W4A16/W4A_FP8: 2 values per byte.
(experts * (2 * inter_dim) * model_dim) // 2 if w_is_int4 else (experts * (2 * inter_dim) * model_dim)
w_nbytes = (
(experts * (2 * inter_dim) * model_dim) // 2
if w_is_int4
else (experts * (2 * inter_dim) * model_dim * elem_bytes)
)

total_threads = 256
bytes_x_per_tile = int(tile_m) * int(tile_k) * int(elem_bytes)
Expand Down Expand Up @@ -400,7 +404,7 @@ def silu(x):
x_nbytes_idx = x_rows * k_in * arith.index(int(elem_bytes))
x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes_idx)

w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False)
w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes)

# OUT: normal=[tokens, topk, inter] f16/bf16, split-K=[tokens*topk, 2*inter] f32
out_elem_bytes = 4 if _is_splitk else 2
Expand All @@ -422,10 +426,17 @@ def silu(x):
# scale_w: ignored on the fp16/bf16 (non-int4) path; the int4_bf16 path needs it as the dequant scale.
sw_rsrc = -1
if const_expr(needs_scale_w):
sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False)
sw_rsrc = buffer_ops.create_buffer_resource(
arg_scale_w, max_size=False, num_records_bytes=sw_nbytes
)

sorted_rsrc = buffer_ops.create_buffer_resource(arg_sorted_token_ids, max_size=False)
sorted_w_rsrc = buffer_ops.create_buffer_resource(arg_sorted_weights, max_size=False)
sorted_nbytes_idx = size_expert_ids_in * fx.Index(tile_m) * fx.Index(4)
sorted_rsrc = buffer_ops.create_buffer_resource(
arg_sorted_token_ids, max_size=False, num_records_bytes=sorted_nbytes_idx
)
sorted_w_rsrc = buffer_ops.create_buffer_resource(
arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_idx
)

# expert ids: [blocks] i32 -> bytes = size_expert_ids_in*4
expert_rsrc = buffer_ops.create_buffer_resource(
Expand Down Expand Up @@ -1738,7 +1749,7 @@ def compile_moe_gemm2(
# Stage2 K dimension is inter_dim (weight shape: [E, model_dim, inter_dim])
num_groups = inter_dim // group_size if use_groupwise_scale else 1
_scale_is_bf16 = scale_is_bf16 and use_groupwise_scale
experts * model_dim * num_groups
sw_nbytes = experts * model_dim * num_groups * (2 if _scale_is_bf16 else 4) if needs_scale_w else 0

_is_gfx950 = "gfx95" in get_hip_arch()
use_gfx950_cvt = is_int4_bf16 and _is_gfx950
Expand Down Expand Up @@ -1767,7 +1778,7 @@ def compile_moe_gemm2(

ir.ShapedType.get_dynamic_size()
# W is packed int4 for W4A8/W4A16/W4A_FP8: 2 values per byte.
(experts * model_dim * inter_dim) // 2 if w_is_int4 else (experts * model_dim * inter_dim)
w_nbytes = (experts * model_dim * inter_dim) // 2 if w_is_int4 else (experts * model_dim * inter_dim * elem_bytes)

total_threads = 256
tile_k_bytes = int(tile_k) * int(elem_bytes)
Expand Down Expand Up @@ -1958,7 +1969,7 @@ def moe_gemm2(
x_nbytes_idx = (tokens_in * c_topk) * k_in * arith.index(int(elem_bytes))
x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes_idx)

w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False)
w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes)

# OUT: [tokens, model_dim] -> clamp to descriptor max (i32 bytes) to avoid overflow on huge tokens.
out_elem_bytes = 4 if out_is_f32 else 2
Expand All @@ -1978,7 +1989,7 @@ def moe_gemm2(
sw_rsrc = -1
if const_expr(needs_scale_w):
# scale_w: experts * model_dim * num_groups elements (f32, or bf16 for groupwise bf16 scales); static size.
sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False)
sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False, num_records_bytes=sw_nbytes)

# sorted_token_ids / sorted_weights: [blocks*tile_m] (CK-style padded length)
sorted_nbytes_idx = size_expert_ids_in * fx.Index(tile_m) * fx.Index(4)
Expand Down
12 changes: 11 additions & 1 deletion kernels/preshuffle_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,17 @@ def kernel_gemm(
a_rsrc = buffer_ops.create_buffer_resource(arg_a, max_size=False, num_records_bytes=_a_nrec)
c_rsrc = buffer_ops.create_buffer_resource(arg_c, max_size=False, num_records_bytes=_c_nrec)
_needs_per_token_scale = not is_f16_or_bf16 and not is_fp4
scale_a_rsrc = None if (is_f16_or_bf16) else buffer_ops.create_buffer_resource(arg_scale_a, max_size=False)
scale_a_rsrc = None
if const_expr(not is_f16_or_bf16):
if const_expr(is_fp4):
_scale_a_rows = (c_m + fx.Index(31)) // fx.Index(32)
_scale_a_stride_elems = fx.Index((K // (32 * 4 * 2)) * 64)
_scale_a_nrec = fx.Int64(_scale_a_rows * _scale_a_stride_elems * fx.Index(4))
else:
_scale_a_nrec = fx.Int64(c_m * fx.Index(4))
scale_a_rsrc = buffer_ops.create_buffer_resource(
arg_scale_a, max_size=False, num_records_bytes=_scale_a_nrec
)

# ---- Bias buffer resource (for fused epilogue) ----
# Use max_size=True so the buffer descriptor's size is taken from the
Expand Down
Loading