diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index f8005804..3bdefc98 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -573,7 +573,13 @@ def moe_gemm1( x_nbytes_i32 = arith.index_cast(T.i32, x_nbytes_idx) x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes_i32) - w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) + # W: [experts, 2*inter_dim, model_dim]; fp4 packs 2 elements per byte. + w_nbytes_s1 = ( + (experts * (2 * inter_dim) * model_dim) // 2 + if is_f4_b + else (experts * (2 * inter_dim) * model_dim * b_elem_bytes) + ) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes_s1) # Out: [tokens*topk, inter_dim] numids_rsrc = buffer_ops.create_buffer_resource( @@ -621,14 +627,30 @@ def moe_gemm1( expert_rsrc = buffer_ops.create_buffer_resource( arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 ) - bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None + # bias: [experts, 2*inter_dim] f32 -> bytes = experts * 2*inter_dim * 4 + bias_nbytes_s1 = experts * (2 * inter_dim) * 4 + bias_rsrc = ( + buffer_ops.create_buffer_resource(arg_bias, max_size=False, num_records_bytes=bias_nbytes_s1) + if enable_bias + else None + ) # Sorted-scale buffer resource for fused mxfp4 quantization _sorted_scale_cols = inter_dim // 32 _sorted_scale_cols_i32 = arith.constant(_sorted_scale_cols, type=T.i32) sorted_scale_rsrc = None if const_expr(_need_sort): - sorted_scale_rsrc = buffer_ops.create_buffer_resource(arg_out_scale_sorted, max_size=False) + _sort_rows_idx = size_expert_ids_in * arith.constant(sort_block_m, index=True) + _sort_padded_rows = ( + (_sort_rows_idx + arith.constant(255, index=True)) + / arith.constant(256, index=True) + * arith.constant(256, index=True) + ) + _sort_padded_cols = arith.constant(((_sorted_scale_cols + 7) // 8) * 8, index=True) + _sort_scale_nbytes = arith.index_cast(T.i32, _sort_padded_rows * _sort_padded_cols) + sorted_scale_rsrc = buffer_ops.create_buffer_resource( + arg_out_scale_sorted, max_size=False, num_records_bytes=_sort_scale_nbytes + ) # ---- persist_m loop (same pattern as stage2) ---- _PERSIST_M = persist_m @@ -2746,7 +2768,11 @@ def check_c_k_valid_gate(base_k): x_nbytes_i32 = arith.index_cast(T.i32, x_nbytes_idx) x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes_i32) - w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) + # W: [experts, model_dim, inter_dim]; fp4 packs 2 elements per byte. + w_nbytes_s2 = ( + (experts * model_dim * inter_dim) // 2 if is_f4_b else (experts * model_dim * inter_dim * b_elem_bytes) + ) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes_s2) # OUT: [tokens, model_dim] -> clamp to descriptor max (i32 bytes) to avoid overflow on huge tokens. out_elem_bytes = 4 if out_is_f32 else 2 @@ -2824,7 +2850,13 @@ def check_c_k_valid_gate(base_k): expert_rsrc = buffer_ops.create_buffer_resource( arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 ) - bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None + # bias: [experts, model_dim] f32 -> bytes = experts * model_dim * 4 + bias_nbytes_s2 = experts * model_dim * 4 + bias_rsrc = ( + buffer_ops.create_buffer_resource(arg_bias, max_size=False, num_records_bytes=bias_nbytes_s2) + if enable_bias + else None + ) # ---- persist loop ---- _c0_p = arith.constant(0, index=True) diff --git a/kernels/moe_blockscale_2stage.py b/kernels/moe_blockscale_2stage.py index f749f17a..2c5eb635 100644 --- a/kernels/moe_blockscale_2stage.py +++ b/kernels/moe_blockscale_2stage.py @@ -128,7 +128,9 @@ def compile_moe_blockscale_gemm1( sb_per_tile_s1 = tile_k // scale_block_k # scale blocks per tile (in K dim) ku_per_sb_s1 = scale_block_k // 64 # K64-steps per scale block = 2 nblk_k_w1 = model_dim // scale_block_k # K-blocks in W1 (=scale_k) - (2 * inter_dim) // 128 # N-blocks in W1 (ScaleBlockN=128) + nblk_n_w1 = (2 * inter_dim) // 128 # N-blocks in W1 (ScaleBlockN=128) + # scale_w: [experts, nblk_n_w1, nblk_k_w1] f32 (per-block scale) + sw_nbytes = experts * nblk_n_w1 * nblk_k_w1 * 4 mfma_i32_k32 = None if is_int8: @@ -140,7 +142,11 @@ def compile_moe_blockscale_gemm1( ir.ShapedType.get_dynamic_size() # W is packed int4 for W4A8: 2 values per byte. - (experts * (2 * inter_dim) * model_dim) // 2 if is_int4 else (experts * (2 * inter_dim) * model_dim) + w_nbytes = ( + (experts * (2 * inter_dim) * model_dim) // 2 + if is_int4 + else (experts * (2 * inter_dim) * model_dim * elem_bytes) + ) total_threads = 256 bytes_x_per_tile = int(tile_m) * int(tile_k) * int(elem_bytes) @@ -301,7 +307,7 @@ def silu(x): arg_x, max_size=False, num_records_bytes=arith.index_cast(T.i64, x_nbytes_idx) ) - w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes) # OUT: [tokens, topk, inter] f16/bf16 -> bytes = tokens*topk*inter*out_elem_bytes out_elem_bytes = 2 # f16/bf16 @@ -321,10 +327,15 @@ def silu(x): sx_rsrc = buffer_ops.create_buffer_resource( arg_scale_x, max_size=False, num_records_bytes=arith.index_cast(T.i64, sx_nbytes_idx) ) - sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) + sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False, num_records_bytes=sw_nbytes) - sorted_rsrc = buffer_ops.create_buffer_resource(arg_sorted_token_ids, max_size=False) - sorted_w_rsrc = buffer_ops.create_buffer_resource(arg_sorted_weights, max_size=False) + sorted_nbytes_idx = size_expert_ids_in * fx.Index(tile_m) * fx.Index(4) + sorted_rsrc = buffer_ops.create_buffer_resource( + arg_sorted_token_ids, max_size=False, num_records_bytes=sorted_nbytes_idx + ) + sorted_w_rsrc = buffer_ops.create_buffer_resource( + arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_idx + ) # expert ids: [blocks] i32 -> bytes = size_expert_ids_in*4 expert_rsrc = buffer_ops.create_buffer_resource( @@ -1236,7 +1247,9 @@ def compile_moe_blockscale_gemm2( sb_per_tile_s2 = tile_k // scale_block_k # scale blocks per tile (in K dim) ku_per_sb_s2 = scale_block_k // 64 # K64-steps per scale block = 2 nblk_k_w2 = inter_dim // scale_block_k # K-blocks in W2 (=scale_k) - model_dim // 128 # N-blocks in W2 (ScaleBlockN=128) + nblk_n_w2 = model_dim // 128 # N-blocks in W2 (ScaleBlockN=128) + # scale_w: [experts, nblk_n_w2, nblk_k_w2] f32 (per-block scale) + sw_nbytes = experts * nblk_n_w2 * nblk_k_w2 * 4 mfma_i32_k32 = None if is_int8: @@ -1248,7 +1261,7 @@ def compile_moe_blockscale_gemm2( ir.ShapedType.get_dynamic_size() # W is packed int4 for W4A8: 2 values per byte. - (experts * model_dim * inter_dim) // 2 if is_int4 else (experts * model_dim * inter_dim) + w_nbytes = (experts * model_dim * inter_dim) // 2 if is_int4 else (experts * model_dim * inter_dim * elem_bytes) total_threads = 256 tile_k_bytes = int(tile_k) * int(elem_bytes) @@ -1431,7 +1444,7 @@ def moe_blockscale_gemm2( arg_x, max_size=False, num_records_bytes=arith.index_cast(T.i64, x_nbytes_idx) ) - w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes) # OUT: [tokens, model_dim] -> clamp to descriptor max (i32 bytes) to avoid overflow on huge tokens. out_elem_bytes = 4 if out_is_f32 else 2 @@ -1450,8 +1463,8 @@ def moe_blockscale_gemm2( sx_rsrc = buffer_ops.create_buffer_resource( arg_scale_x, max_size=False, num_records_bytes=arith.index_cast(T.i64, sx_nbytes_idx) ) - # scale_w: [experts*model_dim] f32 (static shape in practice) - sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) + # scale_w: [experts, nblk_n_w2, nblk_k_w2] f32 (per-block scale) + sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False, num_records_bytes=sw_nbytes) # sorted_token_ids / sorted_weights: [blocks*tile_m] (CK-style padded length) sorted_nbytes_idx = size_expert_ids_in * fx.Index(tile_m) * fx.Index(4) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 5182de6a..f4ecfca6 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -167,7 +167,7 @@ def compile_moe_gemm1( is_int4_bf16_groupwise = is_int4_bf16 and use_groupwise_scale num_groups = model_dim // group_size if use_groupwise_scale else 1 _scale_is_bf16 = scale_is_bf16 and use_groupwise_scale - experts * (2 * inter_dim) * num_groups + sw_nbytes = experts * (2 * inter_dim) * num_groups * (2 if _scale_is_bf16 else 4) if needs_scale_w else 0 # For groupwise scale, weight scale is applied per-group in the K loop, # so epilogue can skip weight scale multiplication (uses 1.0 for sw). @@ -213,7 +213,11 @@ def compile_moe_gemm1( ir.ShapedType.get_dynamic_size() # W is packed int4 for W4A8/W4A16/W4A_FP8: 2 values per byte. - (experts * (2 * inter_dim) * model_dim) // 2 if w_is_int4 else (experts * (2 * inter_dim) * model_dim) + w_nbytes = ( + (experts * (2 * inter_dim) * model_dim) // 2 + if w_is_int4 + else (experts * (2 * inter_dim) * model_dim * elem_bytes) + ) total_threads = 256 bytes_x_per_tile = int(tile_m) * int(tile_k) * int(elem_bytes) @@ -400,7 +404,7 @@ def silu(x): x_nbytes_idx = x_rows * k_in * arith.index(int(elem_bytes)) x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes_idx) - w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes) # OUT: normal=[tokens, topk, inter] f16/bf16, split-K=[tokens*topk, 2*inter] f32 out_elem_bytes = 4 if _is_splitk else 2 @@ -422,10 +426,17 @@ def silu(x): # scale_w: fp16/bf16 (non-int4) path ignores; int4_bf16 needs dequant scale. sw_rsrc = -1 if const_expr(needs_scale_w): - sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) + sw_rsrc = buffer_ops.create_buffer_resource( + arg_scale_w, max_size=False, num_records_bytes=sw_nbytes + ) - sorted_rsrc = buffer_ops.create_buffer_resource(arg_sorted_token_ids, max_size=False) - sorted_w_rsrc = buffer_ops.create_buffer_resource(arg_sorted_weights, max_size=False) + sorted_nbytes_idx = size_expert_ids_in * fx.Index(tile_m) * fx.Index(4) + sorted_rsrc = buffer_ops.create_buffer_resource( + arg_sorted_token_ids, max_size=False, num_records_bytes=sorted_nbytes_idx + ) + sorted_w_rsrc = buffer_ops.create_buffer_resource( + arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_idx + ) # expert ids: [blocks] i32 -> bytes = size_expert_ids_in*4 expert_rsrc = buffer_ops.create_buffer_resource( @@ -1738,7 +1749,7 @@ def compile_moe_gemm2( # Stage2 K dimension is inter_dim (weight shape: [E, model_dim, inter_dim]) num_groups = inter_dim // group_size if use_groupwise_scale else 1 _scale_is_bf16 = scale_is_bf16 and use_groupwise_scale - experts * model_dim * num_groups + sw_nbytes = experts * model_dim * num_groups * (2 if _scale_is_bf16 else 4) if needs_scale_w else 0 _is_gfx950 = "gfx95" in get_hip_arch() use_gfx950_cvt = is_int4_bf16 and _is_gfx950 @@ -1767,7 +1778,7 @@ def compile_moe_gemm2( ir.ShapedType.get_dynamic_size() # W is packed int4 for W4A8/W4A16/W4A_FP8: 2 values per byte. - (experts * model_dim * inter_dim) // 2 if w_is_int4 else (experts * model_dim * inter_dim) + w_nbytes = (experts * model_dim * inter_dim) // 2 if w_is_int4 else (experts * model_dim * inter_dim * elem_bytes) total_threads = 256 tile_k_bytes = int(tile_k) * int(elem_bytes) @@ -1958,7 +1969,7 @@ def moe_gemm2( x_nbytes_idx = (tokens_in * c_topk) * k_in * arith.index(int(elem_bytes)) x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes_idx) - w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes) # OUT: [tokens, model_dim] -> clamp to descriptor max (i32 bytes) to avoid overflow on huge tokens. out_elem_bytes = 4 if out_is_f32 else 2 @@ -1978,7 +1989,7 @@ def moe_gemm2( sw_rsrc = -1 if const_expr(needs_scale_w): # scale_w: [experts*model_dim] f32 (static shape in practice) - sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) + sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False, num_records_bytes=sw_nbytes) # sorted_token_ids / sorted_weights: [blocks*tile_m] (CK-style padded length) sorted_nbytes_idx = size_expert_ids_in * fx.Index(tile_m) * fx.Index(4) diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index cdd81ef6..72c0caa8 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -396,7 +396,17 @@ def kernel_gemm( a_rsrc = buffer_ops.create_buffer_resource(arg_a, max_size=False, num_records_bytes=_a_nrec) c_rsrc = buffer_ops.create_buffer_resource(arg_c, max_size=False, num_records_bytes=_c_nrec) _needs_per_token_scale = not is_f16_or_bf16 and not is_fp4 - scale_a_rsrc = None if (is_f16_or_bf16) else buffer_ops.create_buffer_resource(arg_scale_a, max_size=False) + scale_a_rsrc = None + if const_expr(not is_f16_or_bf16): + if const_expr(is_fp4): + _scale_a_rows = (c_m + fx.Index(31)) // fx.Index(32) + _scale_a_stride_elems = fx.Index((K // (32 * 4 * 2)) * 64) + _scale_a_nrec = fx.Int64(_scale_a_rows * _scale_a_stride_elems * fx.Index(4)) + else: + _scale_a_nrec = fx.Int64(c_m * fx.Index(4)) + scale_a_rsrc = buffer_ops.create_buffer_resource( + arg_scale_a, max_size=False, num_records_bytes=_scale_a_nrec + ) # ---- Bias buffer resource (for fused epilogue) ---- # Use max_size=True so the buffer descriptor's size is taken from the