Skip to content

[Gluon][gfx1250] Gemm MXFP4 preshuffled#2332

Merged
Boss2002n merged 53 commits into
mainfrom
satya/gfx12_mxfp4_gemm
May 23, 2026
Merged

[Gluon][gfx1250] Gemm MXFP4 preshuffled#2332
Boss2002n merged 53 commits into
mainfrom
satya/gfx12_mxfp4_gemm

Conversation

@Boss2002n
Copy link
Copy Markdown
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:sglang SGLang integration tests
ci:atom ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm vLLM benchmark
ci:all All of the above

Add labels via the sidebar or gh pr edit 2332 --add-label <label>

@Boss2002n Boss2002n changed the title PR to main [Gluon][gfx1250] Gemm MXFP4 preshuffled Mar 24, 2026
@Boss2002n Boss2002n self-assigned this Mar 24, 2026
Comment thread aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py Outdated
Comment thread aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py
Comment thread op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py Outdated
Comment thread op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py Outdated
vgokhale
vgokhale previously approved these changes May 13, 2026
@Boss2002n Boss2002n marked this pull request as ready for review May 13, 2026 20:15
@Boss2002n Boss2002n requested review from a team and Copilot May 13, 2026 20:15
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds a gfx1250-focused “preshuffled” MXFP4 GEMM path, including new shuffling helpers, a new Gluon kernel for gfx1250, and updated tests/benchmarks to exercise the preshuffled layout.

Changes:

  • Added gfx1250-specific weight/scale preshuffle logic and a new Gluon-based preshuffle kernel.
  • Updated GEMM preshuffle wrapper to route gfx1250 to the Gluon kernel and adjusted tests/benchmarks accordingly.
  • Introduced a new gfx1250 preshuffle tuning config JSON and a new shuffle helper in aiter/ops/shuffle.py.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py Adds gfx1250 scale shuffling + a new gfx1250 preshuffled GEMM test and switches to triton.testing.assert_close.
op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py Renames bench flag to --preshuffle (alias --shuffle) and benches gemm_afp4wfp4_preshuffle.
aiter/ops/triton/gluon/gemm_afp4wfp4.py Expands device allow-list to include gfx1250 for Gluon AFP4WFP4 config loading.
aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py Refactors kernel imports, adds gfx1250 Gluon preshuffle dispatch, and modifies preshuffle K/grid handling.
aiter/ops/triton/configs/gemm/gfx1250-GEMM-AFP4WFP4_PRESHUFFLED.json Replaces gfx1250 preshuffled tuning entries and adds NUM_BUFFERS for the new Gluon kernel.
aiter/ops/triton/_gluon_kernels/gemm/basic/gemm_mxfp4.py Adds new Gluon/TDM-based gfx1250 preshuffled MXFP4 GEMM kernel and associated layouts/depreshuffle views.
aiter/ops/shuffle.py Adds shuffle_weight_gfx1250 helper for preshuffling weights into the TDM-friendly layout.
Comments suppressed due to low confidence (2)

aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py:499

  • grid no longer multiplies by META["NUM_KSPLIT"]. The underlying Triton preshuffle kernel maps pid_k from program_id(axis=0) assuming the launch grid is GRID_MN * NUM_KSPLIT; with the current grid, split-K launches will be incomplete and the reduction path will be wrong. Restore the NUM_KSPLIT * cdiv(M, BM) * cdiv(N, BN) factor for the Triton preshuffle kernel (and handle Gluon separately if needed).
    grid = lambda META: (  # noqa: E731
        (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])),
    )

aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py:551

  • The gfx1250 Gluon preshuffle path drops NUM_KSPLIT from the kernel config but still allows config["NUM_KSPLIT"] > 1 earlier, and also sets stride_c_k based on the 2D output y. Since the Gluon kernel implementation is not split-K aware, this needs an explicit guard (e.g., force NUM_KSPLIT=1 / skip split-K allocation) or implement proper split-K semantics for the Gluon path.
    if use_gluon:
        layouts = get_gemm_afp4wfp4_preshuffle_layouts(
            config["num_warps"],
            config["BLOCK_SIZE_M"],
            config["BLOCK_SIZE_N"],
            config["BLOCK_SIZE_K"],
        )

        _DROP_KEYS = (
            "NUM_KSPLIT",
            "SPLITK_BLOCK_SIZE",
            "SPLITK_BLOCK",
            "GROUP_SIZE_M",
            "num_stages",
            "waves_per_eu",
            "matrix_instr_nonkdim",
            "cache_modifier",
        )
        kernel_config = {k: v for k, v in config.items() if k not in _DROP_KEYS}
        # Kernel consumes preshuffled scales directly (address math inverts the shuffle in registers)
        assert M >= 32, "gluon mxfp4 preshuffle path requires M >= 32"
        x_scales = x_scales.contiguous()
        w_scales = w_scales.contiguous()
        _gluon_gemm_mxfp4_preshuffle_gfx1250[grid](
            x_fp4,
            w_preshuf,
            y,
            x_scales,
            w_scales,
            M,
            N,
            K_elems,
            x_fp4.stride(0),
            x_fp4.stride(1),
            w_preshuf.stride(0),
            w_preshuf.stride(1),
            0 if config["NUM_KSPLIT"] == 1 else y.stride(0),
            y.stride(-2),
            y.stride(-1),
            x_scales.stride(0),
            x_scales.stride(1),
            w_scales.stride(0),
            w_scales.stride(1),
            **kernel_config,
            **layouts,
        )
        return y

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py Outdated
Comment thread aiter/ops/triton/gluon/gemm_afp4wfp4.py
Comment thread op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py
Comment thread aiter/ops/shuffle.py
Comment thread op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py
Comment thread aiter/ops/triton/_gluon_kernels/gemm/basic/gemm_mxfp4.py
@Boss2002n Boss2002n force-pushed the satya/gfx12_mxfp4_gemm branch from 96c77c1 to c4371d2 Compare May 20, 2026 21:30
@Boss2002n Boss2002n merged commit e00cf3e into main May 23, 2026
36 of 43 checks passed
@Boss2002n Boss2002n deleted the satya/gfx12_mxfp4_gemm branch May 23, 2026 00:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants