[Gluon][gfx1250] Gemm MXFP4 preshuffled#2332
Merged
Merged
Conversation
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
vgokhale
reviewed
May 13, 2026
vgokhale
reviewed
May 13, 2026
vgokhale
reviewed
May 13, 2026
vgokhale
reviewed
May 13, 2026
vgokhale
previously approved these changes
May 13, 2026
Contributor
There was a problem hiding this comment.
Pull request overview
This PR adds a gfx1250-focused “preshuffled” MXFP4 GEMM path, including new shuffling helpers, a new Gluon kernel for gfx1250, and updated tests/benchmarks to exercise the preshuffled layout.
Changes:
- Added gfx1250-specific weight/scale preshuffle logic and a new Gluon-based preshuffle kernel.
- Updated GEMM preshuffle wrapper to route gfx1250 to the Gluon kernel and adjusted tests/benchmarks accordingly.
- Introduced a new gfx1250 preshuffle tuning config JSON and a new shuffle helper in
aiter/ops/shuffle.py.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py | Adds gfx1250 scale shuffling + a new gfx1250 preshuffled GEMM test and switches to triton.testing.assert_close. |
| op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py | Renames bench flag to --preshuffle (alias --shuffle) and benches gemm_afp4wfp4_preshuffle. |
| aiter/ops/triton/gluon/gemm_afp4wfp4.py | Expands device allow-list to include gfx1250 for Gluon AFP4WFP4 config loading. |
| aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py | Refactors kernel imports, adds gfx1250 Gluon preshuffle dispatch, and modifies preshuffle K/grid handling. |
| aiter/ops/triton/configs/gemm/gfx1250-GEMM-AFP4WFP4_PRESHUFFLED.json | Replaces gfx1250 preshuffled tuning entries and adds NUM_BUFFERS for the new Gluon kernel. |
| aiter/ops/triton/_gluon_kernels/gemm/basic/gemm_mxfp4.py | Adds new Gluon/TDM-based gfx1250 preshuffled MXFP4 GEMM kernel and associated layouts/depreshuffle views. |
| aiter/ops/shuffle.py | Adds shuffle_weight_gfx1250 helper for preshuffling weights into the TDM-friendly layout. |
Comments suppressed due to low confidence (2)
aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py:499
gridno longer multiplies byMETA["NUM_KSPLIT"]. The underlying Triton preshuffle kernel mapspid_kfromprogram_id(axis=0)assuming the launch grid isGRID_MN * NUM_KSPLIT; with the current grid, split-K launches will be incomplete and the reduction path will be wrong. Restore theNUM_KSPLIT * cdiv(M, BM) * cdiv(N, BN)factor for the Triton preshuffle kernel (and handle Gluon separately if needed).
grid = lambda META: ( # noqa: E731
(triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])),
)
aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py:551
- The gfx1250 Gluon preshuffle path drops
NUM_KSPLITfrom the kernel config but still allowsconfig["NUM_KSPLIT"] > 1earlier, and also setsstride_c_kbased on the 2D outputy. Since the Gluon kernel implementation is not split-K aware, this needs an explicit guard (e.g., forceNUM_KSPLIT=1/ skip split-K allocation) or implement proper split-K semantics for the Gluon path.
if use_gluon:
layouts = get_gemm_afp4wfp4_preshuffle_layouts(
config["num_warps"],
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
)
_DROP_KEYS = (
"NUM_KSPLIT",
"SPLITK_BLOCK_SIZE",
"SPLITK_BLOCK",
"GROUP_SIZE_M",
"num_stages",
"waves_per_eu",
"matrix_instr_nonkdim",
"cache_modifier",
)
kernel_config = {k: v for k, v in config.items() if k not in _DROP_KEYS}
# Kernel consumes preshuffled scales directly (address math inverts the shuffle in registers)
assert M >= 32, "gluon mxfp4 preshuffle path requires M >= 32"
x_scales = x_scales.contiguous()
w_scales = w_scales.contiguous()
_gluon_gemm_mxfp4_preshuffle_gfx1250[grid](
x_fp4,
w_preshuf,
y,
x_scales,
w_scales,
M,
N,
K_elems,
x_fp4.stride(0),
x_fp4.stride(1),
w_preshuf.stride(0),
w_preshuf.stride(1),
0 if config["NUM_KSPLIT"] == 1 else y.stride(0),
y.stride(-2),
y.stride(-1),
x_scales.stride(0),
x_scales.stride(1),
w_scales.stride(0),
w_scales.stride(1),
**kernel_config,
**layouts,
)
return y
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
96c77c1 to
c4371d2
Compare
vgokhale
approved these changes
May 22, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist