Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 58 additions & 41 deletions aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,27 +165,22 @@ def _moe_gemm_a8w4(
yN = N // ACTIVATION_REDUCTION_N

pid = gl.program_id(0)
if ExptOffsSum is not None:
# Determine how much padding there is on the expert data. This allows us to
# know the true grid size and avoid processing padding tiles.
padding_m = grid_m - gl.load(ExptOffsSum)
else:
padding_m: tl.constexpr = 0

index_type: tl.constexpr = gl.int64 if UPCAST_INDICES else gl.int32

unpadded_m = grid_m - padding_m
total_actual_tiles = unpadded_m * grid_n
if padding_m > 0 and pid >= total_actual_tiles:
return

pid_mn = pid % (unpadded_m * grid_n)
if XCD_SWIZZLE != 1:
pid_mn = remap_xcd(pid_mn, total_actual_tiles, XCD_SWIZZLE)
pid_m, pid_n = pid_grid(pid_mn, unpadded_m, grid_n, 1)
padding_m = grid_m - gl.load(ExptOffsSum)
unpadded_m = grid_m - padding_m
total_actual_tiles = unpadded_m * grid_n
if padding_m > 0 and pid >= total_actual_tiles:
return
pid = remap_xcd(pid, total_actual_tiles, XCD_SWIZZLE)
else:
unpadded_m = grid_m
pid_m, pid_n = pid_grid(pid, unpadded_m, grid_n, 1)
# unpack expert data
expt_data = gl.load(ExptData + pid_m)
if expt_data == -1:
if XCD_SWIZZLE == 1 and expt_data == -1:
return
expt_id = expt_data & 0x0000FFFF
block_id = expt_data >> 16
Expand Down Expand Up @@ -248,6 +243,14 @@ def _moe_gemm_a8w4(
SHARED_LAYOUT_X_SCALES: gl.constexpr = gl.PaddedSharedLayout.with_identity_for(
[[256, 16]], [BLOCK_M, MX_SCALE_BLOCK_K], [1, 0]
)
if Quant_static_scale is not None:
SHARED_LAYOUT_Y: gl.constexpr = gl.PaddedSharedLayout.with_identity_for(
[[OUT_BLOCK_N, 16]], [BLOCK_M, OUT_BLOCK_N], [1, 0]
)
else:
SHARED_LAYOUT_Y: gl.constexpr = gl.PaddedSharedLayout.with_identity_for(
[[OUT_BLOCK_N, 8]], [BLOCK_M, OUT_BLOCK_N], [1, 0]
)

if GatherIndx is None:
x_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor(
Expand Down Expand Up @@ -305,18 +308,18 @@ def _moe_gemm_a8w4(
layout=SHARED_LAYOUT_X_SCALES,
)

if SWIZZLE_MX_SCALE == "GFX1250_SCALE":
if BLOCK_M == 16:
WMMA_LAYOUT: gl.constexpr = gl.amd.AMDWMMALayout(
3,
transposed=True,
warp_bases=[[0, 1], [1, 0]],
warp_bases=[[0, 1], [0, 2]],
reg_bases=[],
instr_shape=[16, 16, 128],
)
WMMA_LAYOUT_PACKED: gl.constexpr = gl.amd.AMDWMMALayout(
3,
transposed=True,
warp_bases=[[0, 1], [1, 0]],
warp_bases=[[0, 1], [0, 2]],
reg_bases=[],
instr_shape=[16, 16, 64],
)
Expand Down Expand Up @@ -512,9 +515,21 @@ def _moe_gemm_a8w4(
cur_x_scales = next_x_scales
read_idx += 1

# bias
offs_m = BLOCK_M * block_id + gl.arange(
0, BLOCK_M, layout=gl.SliceLayout(1, WMMA_LAYOUT)
)
offs_y_n = BLOCK_N * pid_n + gl.arange(
0, BLOCK_N, layout=gl.SliceLayout(0, WMMA_LAYOUT)
)
mask_m = offs_m < M
mask_n = offs_y_n < N
if B is not None:
BPtrs = B + expt_id * stride_b_e
bias = gl.amd.gfx1250.buffer_load(BPtrs, offs_y_n, mask=mask_n)

# Epilogue: drain remaining pipeline stages (no new TDM loads).
# The first NUM_BUFFERS-1 iterations still use the pre-load / WMMA pattern.

for k_ep in gl.static_range(NUM_BUFFERS - 1):
if is_x_microscaled:
acc = gl.amd.gfx1250.wmma_scaled(
Expand Down Expand Up @@ -567,49 +582,51 @@ def _moe_gemm_a8w4(
# scalar fp8 scale
if X_static_scale is not None:
acc = acc * gl.load(X_static_scale)
# bias
offs_m = BLOCK_M * block_id + gl.arange(
0, BLOCK_M, layout=gl.SliceLayout(1, WMMA_LAYOUT)
)
offs_y_n = BLOCK_N * pid_n + gl.arange(
0, BLOCK_N, layout=gl.SliceLayout(0, WMMA_LAYOUT)
)
mask_m = offs_m < M
mask_n = offs_y_n < N

if B is not None:
BPtrs = B + expt_id * stride_b_e
bias = gl.amd.gfx1250.buffer_load(BPtrs, offs_y_n, mask=mask_n)
acc = acc + bias[None, :]

if APPLY_SWIGLU:
out = _swiglu(acc, alpha, limit, ADD_RESIDUAL=ADD_RESIDUAL)
tl.static_assert(
out.shape[1] == OUT_BLOCK_N,
f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})",
)
offs_m = BLOCK_M * block_id + gl.arange(0, BLOCK_M)
offs_y_n = OUT_BLOCK_N * pid_n + gl.arange(0, OUT_BLOCK_N)
mask_m = offs_m < M
mask_n = offs_y_n < yN
else:
tl.static_assert(
ACTIVATION_REDUCTION_N == 1,
"Activation reduction must be 1 if no activation fn is provided",
)
out = acc

if Gammas is not None:
gammas = gl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0)
out *= gammas[:, None]

# quant
if Quant_static_scale is not None:
out = _compute_static_fp8_quant(out, gl.load(Quant_static_scale))
# write-back
else:
out = out.to(tl.bfloat16)

# TDM Store: accumulator → shared memory → global memory
Y += start_m * stride_y_m
offs_y_m = offs_m
offs_y = (
offs_y_m.to(index_type)[:, None] * stride_y_m
+ offs_y_n.to(index_type)[None, :] * stride_y_n
y_buffer = gl.allocate_shared_memory(
Y.type.element_ty,
shape=[BLOCK_M, OUT_BLOCK_N],
layout=SHARED_LAYOUT_Y,
)
mask = mask_m[:, None] & mask_n[None, :]
if Quant_static_scale is None:
out = out.to(tl.bfloat16)
gl.amd.gfx1250.buffer_store(out, Y, offs_y, mask=mask)
y_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor(
base=Y,
shape=(M, yN),
strides=(stride_y_m, stride_y_n),
block_shape=(BLOCK_M, OUT_BLOCK_N),
layout=SHARED_LAYOUT_Y,
)
y_buffer.store(out)
gl.amd.gfx1250.tdm.async_store(
y_desc, [block_id * BLOCK_M, pid_n * OUT_BLOCK_N], y_buffer
)
gl.amd.gfx1250.tdm.async_wait(0)
Loading