From 8a2ad7563589817a101dd7e5960aa583d1e53924 Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Sat, 23 May 2026 05:56:04 +0000 Subject: [PATCH 1/7] optimizations --- .../gfx1250/moe/moe_op_gemm_a8w4.py | 92 ++++++++++++------- 1 file changed, 60 insertions(+), 32 deletions(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py index 03d70bd38a..e9a06658ed 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py @@ -165,27 +165,22 @@ def _moe_gemm_a8w4( yN = N // ACTIVATION_REDUCTION_N pid = gl.program_id(0) - if ExptOffsSum is not None: - # Determine how much padding there is on the expert data. This allows us to - # know the true grid size and avoid processing padding tiles. - padding_m = grid_m - gl.load(ExptOffsSum) - else: - padding_m: tl.constexpr = 0 index_type: tl.constexpr = gl.int64 if UPCAST_INDICES else gl.int32 - unpadded_m = grid_m - padding_m - total_actual_tiles = unpadded_m * grid_n - if padding_m > 0 and pid >= total_actual_tiles: - return - - pid_mn = pid % (unpadded_m * grid_n) if XCD_SWIZZLE != 1: - pid_mn = remap_xcd(pid_mn, total_actual_tiles, XCD_SWIZZLE) - pid_m, pid_n = pid_grid(pid_mn, unpadded_m, grid_n, 1) + padding_m = grid_m - gl.load(ExptOffsSum) + unpadded_m = grid_m - padding_m + total_actual_tiles = unpadded_m * grid_n + if padding_m > 0 and pid >= total_actual_tiles: + return + pid = remap_xcd(pid, total_actual_tiles, XCD_SWIZZLE) + else: + unpadded_m = grid_m + pid_m, pid_n = pid_grid(pid, unpadded_m, grid_n, 1) # unpack expert data expt_data = gl.load(ExptData + pid_m) - if expt_data == -1: + if XCD_SWIZZLE == 1 and expt_data == -1: return expt_id = expt_data & 0x0000FFFF block_id = expt_data >> 16 @@ -305,18 +300,18 @@ def _moe_gemm_a8w4( layout=SHARED_LAYOUT_X_SCALES, ) - if SWIZZLE_MX_SCALE == "GFX1250_SCALE": + if BLOCK_M == 16: WMMA_LAYOUT: gl.constexpr = gl.amd.AMDWMMALayout( 3, transposed=True, - warp_bases=[[0, 1], [1, 0]], + warp_bases=[[0, 1], [0, 2]], reg_bases=[], instr_shape=[16, 16, 128], ) WMMA_LAYOUT_PACKED: gl.constexpr = gl.amd.AMDWMMALayout( 3, transposed=True, - warp_bases=[[0, 1], [1, 0]], + warp_bases=[[0, 1], [0, 2]], reg_bases=[], instr_shape=[16, 16, 64], ) @@ -512,9 +507,21 @@ def _moe_gemm_a8w4( cur_x_scales = next_x_scales read_idx += 1 + # bias + offs_m = BLOCK_M * block_id + gl.arange( + 0, BLOCK_M, layout=gl.SliceLayout(1, WMMA_LAYOUT) + ) + offs_y_n = BLOCK_N * pid_n + gl.arange( + 0, BLOCK_N, layout=gl.SliceLayout(0, WMMA_LAYOUT) + ) + mask_m = offs_m < M + mask_n = offs_y_n < N + if B is not None: + BPtrs = B + expt_id * stride_b_e + bias = gl.amd.gfx1250.buffer_load(BPtrs, offs_y_n, mask=mask_n) + # Epilogue: drain remaining pipeline stages (no new TDM loads). # The first NUM_BUFFERS-1 iterations still use the pre-load / WMMA pattern. - for k_ep in gl.static_range(NUM_BUFFERS - 1): if is_x_microscaled: acc = gl.amd.gfx1250.wmma_scaled( @@ -567,19 +574,10 @@ def _moe_gemm_a8w4( # scalar fp8 scale if X_static_scale is not None: acc = acc * gl.load(X_static_scale) - # bias - offs_m = BLOCK_M * block_id + gl.arange( - 0, BLOCK_M, layout=gl.SliceLayout(1, WMMA_LAYOUT) - ) - offs_y_n = BLOCK_N * pid_n + gl.arange( - 0, BLOCK_N, layout=gl.SliceLayout(0, WMMA_LAYOUT) - ) - mask_m = offs_m < M - mask_n = offs_y_n < N - if B is not None: - BPtrs = B + expt_id * stride_b_e - bias = gl.amd.gfx1250.buffer_load(BPtrs, offs_y_n, mask=mask_n) + + if bias is not None: acc = acc + bias[None, :] + if APPLY_SWIGLU: out = _swiglu(acc, alpha, limit, ADD_RESIDUAL=ADD_RESIDUAL) tl.static_assert( @@ -612,4 +610,34 @@ def _moe_gemm_a8w4( mask = mask_m[:, None] & mask_n[None, :] if Quant_static_scale is None: out = out.to(tl.bfloat16) - gl.amd.gfx1250.buffer_store(out, Y, offs_y, mask=mask) + + # TDM Store: accumulator → shared memory → global memory + if Quant_static_scale is None: + SHARED_LAYOUT_Y: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( + [[BLOCK_N, 8]], [BLOCK_M, BLOCK_N], [1, 0] + ) + else: + SHARED_LAYOUT_Y: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( + [[BLOCK_N, 16]], [BLOCK_M, BLOCK_N], [1, 0] + ) + y_buffer = gl.allocate_shared_memory( + Y.type.element_ty, + shape=[BLOCK_M, BLOCK_N], + layout=SHARED_LAYOUT_Y, + ) + y_buffer.store(out) + + # Ensure all wavefronts have finished writing to LDS before TDM reads it. + gl.barrier() + + y_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=Y, + shape=(M, N), + strides=(stride_y_m, stride_y_n), + block_shape=(BLOCK_M, BLOCK_N), + layout=SHARED_LAYOUT_Y, + ) + gl.amd.gfx1250.tdm.async_store( + y_desc, [block_id * BLOCK_M, pid_n * BLOCK_N], y_buffer + ) + gl.amd.gfx1250.tdm.async_wait(0) From 35215b4a316522eee723f14204b210e6b7dcf6eb Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Sat, 23 May 2026 06:10:39 +0000 Subject: [PATCH 2/7] clean up --- .../gfx1250/moe/moe_op_gemm_a8w4.py | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py index e9a06658ed..ad4ef21e39 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py @@ -594,32 +594,25 @@ def _moe_gemm_a8w4( "Activation reduction must be 1 if no activation fn is provided", ) out = acc + if Gammas is not None: gammas = gl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0) out *= gammas[:, None] + # quant if Quant_static_scale is not None: - out = _compute_static_fp8_quant(out, gl.load(Quant_static_scale)) - # write-back - Y += start_m * stride_y_m - offs_y_m = offs_m - offs_y = ( - offs_y_m.to(index_type)[:, None] * stride_y_m - + offs_y_n.to(index_type)[None, :] * stride_y_n - ) - mask = mask_m[:, None] & mask_n[None, :] - if Quant_static_scale is None: - out = out.to(tl.bfloat16) - - # TDM Store: accumulator → shared memory → global memory - if Quant_static_scale is None: SHARED_LAYOUT_Y: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( - [[BLOCK_N, 8]], [BLOCK_M, BLOCK_N], [1, 0] + [[BLOCK_N, 16]], [BLOCK_M, BLOCK_N], [1, 0] ) + out = _compute_static_fp8_quant(out, gl.load(Quant_static_scale)) else: SHARED_LAYOUT_Y: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( - [[BLOCK_N, 16]], [BLOCK_M, BLOCK_N], [1, 0] + [[BLOCK_N, 8]], [BLOCK_M, BLOCK_N], [1, 0] ) + out = out.to(tl.bfloat16) + + # TDM Store: accumulator → shared memory → global memory + Y += start_m * stride_y_m y_buffer = gl.allocate_shared_memory( Y.type.element_ty, shape=[BLOCK_M, BLOCK_N], From bc5e5c5ade69f1ec5068cd637b6f88d9464cfd7f Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Sat, 23 May 2026 06:38:01 +0000 Subject: [PATCH 3/7] fix bias --- aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py index ad4ef21e39..324dba0346 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py @@ -575,7 +575,7 @@ def _moe_gemm_a8w4( if X_static_scale is not None: acc = acc * gl.load(X_static_scale) - if bias is not None: + if B is not None: acc = acc + bias[None, :] if APPLY_SWIGLU: From c6cbe4b28763a8de2e6cf0896a980d5c1eb48cf9 Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Sat, 23 May 2026 06:47:33 +0000 Subject: [PATCH 4/7] fix swiglu --- .../gfx1250/moe/moe_op_gemm_a8w4.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py index 324dba0346..bf8674a376 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py @@ -585,15 +585,23 @@ def _moe_gemm_a8w4( f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})", ) offs_m = BLOCK_M * block_id + gl.arange(0, BLOCK_M) - offs_y_n = OUT_BLOCK_N * pid_n + gl.arange(0, OUT_BLOCK_N) mask_m = offs_m < M - mask_n = offs_y_n < yN + y_buffer = gl.allocate_shared_memory( + Y.type.element_ty, + shape=[BLOCK_M, OUT_BLOCK_N], + layout=SHARED_LAYOUT_Y, + ) else: tl.static_assert( ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided", ) out = acc + y_buffer = gl.allocate_shared_memory( + Y.type.element_ty, + shape=[BLOCK_M, BLOCK_N], + layout=SHARED_LAYOUT_Y, + ) if Gammas is not None: gammas = gl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0) @@ -613,11 +621,6 @@ def _moe_gemm_a8w4( # TDM Store: accumulator → shared memory → global memory Y += start_m * stride_y_m - y_buffer = gl.allocate_shared_memory( - Y.type.element_ty, - shape=[BLOCK_M, BLOCK_N], - layout=SHARED_LAYOUT_Y, - ) y_buffer.store(out) # Ensure all wavefronts have finished writing to LDS before TDM reads it. From 5aee53989ab3347d7f063da4b6604f7ffac49d82 Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Sat, 23 May 2026 07:02:49 +0000 Subject: [PATCH 5/7] fix bug --- .../gfx1250/moe/moe_op_gemm_a8w4.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py index bf8674a376..165a833990 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py @@ -243,6 +243,14 @@ def _moe_gemm_a8w4( SHARED_LAYOUT_X_SCALES: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( [[256, 16]], [BLOCK_M, MX_SCALE_BLOCK_K], [1, 0] ) + if Quant_static_scale is not None: + SHARED_LAYOUT_Y: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( + [[BLOCK_N, 16]], [BLOCK_M, BLOCK_N], [1, 0] + ) + else: + SHARED_LAYOUT_Y: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( + [[BLOCK_N, 8]], [BLOCK_M, BLOCK_N], [1, 0] + ) if GatherIndx is None: x_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( @@ -609,23 +617,14 @@ def _moe_gemm_a8w4( # quant if Quant_static_scale is not None: - SHARED_LAYOUT_Y: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( - [[BLOCK_N, 16]], [BLOCK_M, BLOCK_N], [1, 0] - ) out = _compute_static_fp8_quant(out, gl.load(Quant_static_scale)) else: - SHARED_LAYOUT_Y: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( - [[BLOCK_N, 8]], [BLOCK_M, BLOCK_N], [1, 0] - ) out = out.to(tl.bfloat16) # TDM Store: accumulator → shared memory → global memory Y += start_m * stride_y_m y_buffer.store(out) - # Ensure all wavefronts have finished writing to LDS before TDM reads it. - gl.barrier() - y_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( base=Y, shape=(M, N), From 74c613d171f0afcf116f4ca96b96dcba30e7099c Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Sat, 23 May 2026 07:09:15 +0000 Subject: [PATCH 6/7] fix bug --- aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py index 165a833990..943ab7dd82 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py @@ -627,7 +627,7 @@ def _moe_gemm_a8w4( y_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( base=Y, - shape=(M, N), + shape=(M, yN), strides=(stride_y_m, stride_y_n), block_shape=(BLOCK_M, BLOCK_N), layout=SHARED_LAYOUT_Y, From 9fb2339f59ec170ba9ab70346683f4dd37100322 Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Sat, 23 May 2026 07:42:02 +0000 Subject: [PATCH 7/7] fix bug --- .../gfx1250/moe/moe_op_gemm_a8w4.py | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py index 943ab7dd82..a2ee4077f8 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py @@ -245,11 +245,11 @@ def _moe_gemm_a8w4( ) if Quant_static_scale is not None: SHARED_LAYOUT_Y: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( - [[BLOCK_N, 16]], [BLOCK_M, BLOCK_N], [1, 0] + [[OUT_BLOCK_N, 16]], [BLOCK_M, OUT_BLOCK_N], [1, 0] ) else: SHARED_LAYOUT_Y: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( - [[BLOCK_N, 8]], [BLOCK_M, BLOCK_N], [1, 0] + [[OUT_BLOCK_N, 8]], [BLOCK_M, OUT_BLOCK_N], [1, 0] ) if GatherIndx is None: @@ -594,22 +594,12 @@ def _moe_gemm_a8w4( ) offs_m = BLOCK_M * block_id + gl.arange(0, BLOCK_M) mask_m = offs_m < M - y_buffer = gl.allocate_shared_memory( - Y.type.element_ty, - shape=[BLOCK_M, OUT_BLOCK_N], - layout=SHARED_LAYOUT_Y, - ) else: tl.static_assert( ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided", ) out = acc - y_buffer = gl.allocate_shared_memory( - Y.type.element_ty, - shape=[BLOCK_M, BLOCK_N], - layout=SHARED_LAYOUT_Y, - ) if Gammas is not None: gammas = gl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0) @@ -623,16 +613,20 @@ def _moe_gemm_a8w4( # TDM Store: accumulator → shared memory → global memory Y += start_m * stride_y_m - y_buffer.store(out) - + y_buffer = gl.allocate_shared_memory( + Y.type.element_ty, + shape=[BLOCK_M, OUT_BLOCK_N], + layout=SHARED_LAYOUT_Y, + ) y_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( base=Y, shape=(M, yN), strides=(stride_y_m, stride_y_n), - block_shape=(BLOCK_M, BLOCK_N), + block_shape=(BLOCK_M, OUT_BLOCK_N), layout=SHARED_LAYOUT_Y, ) + y_buffer.store(out) gl.amd.gfx1250.tdm.async_store( - y_desc, [block_id * BLOCK_M, pid_n * BLOCK_N], y_buffer + y_desc, [block_id * BLOCK_M, pid_n * OUT_BLOCK_N], y_buffer ) gl.amd.gfx1250.tdm.async_wait(0)