[Feature] Add Turbo MXFP8 Grouped GEMM (gfx950) for MoE#330
Conversation
There was a problem hiding this comment.
Pull request overview
This PR adds a new Turbo MXFP8 (MX_BLOCKWISE) grouped GEMM path targeting gfx950 (MI350/MI355) to accelerate MoE FFN workloads, including fused per-group padding/quantization and dedicated wgrad (variable-K) support.
Changes:
- Added gfx950 Turbo grouped GEMM kernels (fwd/dgrad) plus a variable-K Turbo kernel for wgrad, along with new PyTorch extension entry points.
- Added fused MXFP8 quantization ops for grouped inputs (including per-group M padding) to keep the flow async and avoid host syncs.
- Added MXFP8 Turbo wrapper + parameter-sweep test coverage and a benchmark option.
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/pytorch/ops/test_grouped_gemm_fp8.py | Adds gfx950-only MX_BLOCKWISE Turbo backend test coverage and ensures E8M0 scale dtype configuration. |
| primus_turbo/pytorch/ops/grouped_gemm_fp8.py | Introduces GroupedGemmFP8MXFunc wrapper for MX_BLOCKWISE (handles trans_b=False and per-group padding) and calls new Turbo ops. |
| primus_turbo/pytorch/kernels/quantization/quantization_impl.py | Adds Python wrapper for fused grouped MXFP8 dual quantization with per-group M padding. |
| primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py | Registers Turbo backend entries for MX_BLOCKWISE grouped GEMM and variable-K wgrad dispatch. |
| csrc/pytorch/quantization/quantization.cpp | Implements quantize_mxfp8_dual_perg and quantize_mxfp8_dual_grouped (grouped + padded) C++ entry points. |
| csrc/pytorch/quantization/quantization_meta.cpp | Adds Meta kernels for the new MXFP8 quantization ops to support shape inference. |
| csrc/pytorch/grouped_gemm/turbo_grouped_gemm.cpp | Adds PyTorch bindings for Turbo MXFP8 grouped GEMM and Turbo variable-K wgrad entry points. |
| csrc/pytorch/extensions.h | Declares new Turbo grouped GEMM and MXFP8 quantization APIs. |
| csrc/pytorch/bindings_pytorch.cpp | Registers new ops (quantize_mxfp8_dual_grouped, quantize_mxfp8_dual_perg, Turbo grouped GEMMs) for CUDA and Meta. |
| csrc/kernels/quantization/quantization_mxfp8.cu | Extends MXFP8 quantization kernels for per-group/per-layout variants and adds per-group padded-layout computation. |
| csrc/kernels/grouped_gemm/turbo/turbo_grouped_gemm_mxfp8_kernel.h | Adds Turbo MXFP8 grouped GEMM persistent kernel (fwd/dgrad) implementation for gfx950. |
| csrc/kernels/grouped_gemm/turbo/turbo_grouped_gemm_mxfp8_wgrad_kernel.h | Adds Turbo MXFP8 variable-K grouped GEMM persistent kernel for wgrad on gfx950. |
| csrc/kernels/grouped_gemm/turbo_grouped_gemm.cu | Adds grouped-GEMM-specific preshuffle + workspace handling and launches Turbo grouped GEMM kernels. |
| csrc/include/primus_turbo/quantization.h | Adds constants and declares new MXFP8 grouped quantization + padded-layout helpers. |
| csrc/include/primus_turbo/grouped_gemm.h | Adds public parameter structs and APIs for Turbo MXFP8 grouped GEMM and variable-K wgrad. |
| csrc/include/primus_turbo/device/register.cuh | Strengthens inline-asm clobbering to improve race-freeness/ordering assumptions. |
| csrc/include/primus_turbo/device/mfma.cuh | Adds memory clobbers to MFMA inline asm to match intended synchronization semantics. |
| csrc/include/primus_turbo/device/memory.cuh | Adds VGPR clobbers after LDS reads to align with race-freeness assumptions. |
| benchmark/ops/bench_grouped_gemm_turbo.py | Adds an mxfp8 benchmark configuration and CLI option for comparison runs. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| // ``quantize_mxfp8_dual_perg``: per-group dual quant for the grouped GEMM | ||
| // B-side. Input is 3D ``(G, M, N)``; per-group M/N are uniform. The | ||
| // kernel writes: | ||
| // rowwise_output: ``(G, M_pad, N_pad)`` -- naturally fits ``(G, N_of_B, | ||
| // K_pad)`` for the fwd GEMM. | ||
| // colwise_output: ``(G, N, M_pad)`` -- naturally fits ``(G, K_of_B, | ||
| // N_pad)`` for the dgrad GEMM | ||
| // (no separate regroup). |
| if constexpr (Bytes == 16) { | ||
| clobber_vgpr_one<VDST + 2>(); | ||
| clobber_vgpr_one<VDST + 3>(); | ||
| } |
There was a problem hiding this comment.
Why do we need to add clobber_vgpr_one here for ds_read? The related registers should already be clobbered before the ds_read.
There was a problem hiding this comment.
Tried removing them — 100% race in grouped MXFP8 dgrad (max diff inf) and 0.7% race in dense MXFP8 mxgemm at 4096³, exactly matching the historical race profile we hit before.
The pre-ds_read clobber_vgpr_one<> reserves the dst VGPRs from being reused, but doesn't fence the compiler from sinking the consuming MFMA above the ds_read. The post-asm clobber on the dst VGPR is what creates that ordering edge.
There was a problem hiding this comment.
-
How was the 0.7% race obtained?
-
In the Turbo MXFP8 GEMM, there should be no race in any phase between MFMA and DS. MFMA computes the current block, while DS loads the data for the next block. Any cross-block race has already been constrained using __builtin_amdgcn_s_barrier.
There was a problem hiding this comment.
The role of clobber_vgpr_one is to allocate/register VGPRs. It cannot prevent the compiler from using them.
There was a problem hiding this comment.
Additionally, my personal understanding is that using clobber_vgpr_one on ds_read is not a semantically correct way to enforce ordering, because it only operates at the level of register allocation and does not involve modeling of memory dependencies or execution ordering. Any apparent “correctness effect” is likely just an indirect result of changes in compiler scheduling behavior, rather than a reliable ordering guarantee. Race conditions should instead be addressed using standard synchronization and memory-dependency mechanisms such as s_waitcnt and barriers.
There was a problem hiding this comment.
It seems to be some bug that led to this result. No race in main branch. I will take time to find out.
There was a problem hiding this comment.
The grouped GEMM race is triggered by compiler-emitted SGPR spill: the spill reload uses vmcnt, which gfx950 does not FIFO-order against buffer_load_lds, so a stale m0 can leak into the LDS write address. The latest commits both reduce spill (CK refactor + smem packing + -sink-insts-to-avoid-spills; private_segment drops 44 → 20 bytes) and move the prologue LDS commit onto lgkmcnt (2-step buffer_load → ds_write), decoupling it from the spill's vmcnt traffic — race goes to zero.
Tested on both gemm kernel and grouped gemm kernels, no race condition happened again.
There was a problem hiding this comment.
Why do we need the two-step path? It copies the data through VGPR first and then into SMEM, right? If so, it seems to break the whole design philosophy here:
- it breaks the counter separation model.
vmcnt is supposed to track only buffer_load_lds (GMEM→LDS)
lgkmcnt is supposed to track only ds_read_pinned (LDS→VGPR)
The two-step path introduces ds_write into lgkmcnt, which muddies the semantics of the two counters.
- the two-step path is effectively a reverse optimization.
The original purpose of buffer_load_lds was zero-VGPR buffering plus fully asynchronous transfer. The two-step path regresses into GMEM→VGPR→LDS, which either consumes extra VGPRs or sacrifices asynchronicity.
- SGPR resources on the hardware should already be extremely abundant.
There was a problem hiding this comment.
The 2-step prologue path is removed, and VGPR / SGPR spill is fully eliminated — all kernel specs now show sgpr_spill = 0 / vgpr_spill = 0 / private_segment = 0, and scratch_load / scratch_store are completely gone from disasm. The vmcnt race root cause is fixed at the source.
But clobber_vgpr_one / clobber_agpr_one still need to stay:
- The spill race and the clobber race are two different things. Clobber prevents compiler reordering —
ds_read_pinnedand MFMA both use"n"(VDST)constant operands that encode the register name into the asm string, so the compiler sees no def-use chain. Per GCC: twoasm volatilestatements are not ordered with respect to each other. Without the clobbers, the compiler is free to reorder MFMA before the consumingds_read→ reads stale VGPR. - Verified empirically: with spill at 0, removing both clobbers still races — grouped fwd 5.5%, wgrad 0.05%.
| else | ||
| asm volatile("v_mfma_scale_f32_16x16x128_f8f6f4 v[%0:%1], v[%2:%3], v[%4:%5], v[%0:%1], v[%6], v[%7] op_sel_hi:[0,0,0] cbsz:1 blgp:1" | ||
| : : "n"(PIN_ACC), "n"(PIN_ACC + 3), "n"(PIN_A), "n"(PIN_A + 7), "n"(PIN_B), "n"(PIN_B + 7), "n"(PIN_SA), "n"(PIN_SB)); | ||
| : : "n"(PIN_ACC), "n"(PIN_ACC + 3), "n"(PIN_A), "n"(PIN_A + 7), "n"(PIN_B), "n"(PIN_B + 7), "n"(PIN_SA), "n"(PIN_SB) : "memory"); |
There was a problem hiding this comment.
v_mfma_scale_f32_16x16x128_f8f6f4 only modifies registers, so I don't think we need the "memory" here.
There was a problem hiding this comment.
removed all "memory"
| static_assert(AGPR >= 0 && AGPR <= 255, "AGPR must be in [0, 255]"); | ||
| // clang-format off | ||
| #define CLOBBER_AREG_CASE(N) case N: asm volatile("" ::: "a" #N); break; | ||
| #define CLOBBER_AREG_CASE(N) case N: asm volatile("" ::: "a" #N, "memory"); break; |
| template <int AC> __device__ __forceinline__ void zero_agpr() { | ||
| asm volatile("v_accvgpr_write_b32 a[%0], 0" : : "n"(AC)); | ||
| asm volatile("v_accvgpr_write_b32 a[%0], 0" : : "n"(AC) : "memory"); | ||
| clobber_agpr_one<AC>(); |
| if constexpr (N >= 13) asm volatile("v_accvgpr_read_b32 %0, a[%1]" : "=v"(raw[12]) : "n"(AC + 12) : "memory"); | ||
| if constexpr (N >= 14) asm volatile("v_accvgpr_read_b32 %0, a[%1]" : "=v"(raw[13]) : "n"(AC + 13) : "memory"); | ||
| if constexpr (N >= 15) asm volatile("v_accvgpr_read_b32 %0, a[%1]" : "=v"(raw[14]) : "n"(AC + 14) : "memory"); | ||
| if constexpr (N >= 16) asm volatile("v_accvgpr_read_b32 %0, a[%1]" : "=v"(raw[15]) : "n"(AC + 15) : "memory"); |
| typename GemmTile::AScaleSmemSubtile (*a_s_smem_tile)[4], | ||
| typename GemmTile::BScaleSmemSubtile (*b_s_smem_tile)[4], const AType *a_ptr, | ||
| const BType *b_ptr, const uint32_t *a_s_ptr, const uint32_t *b_s_ptr, CType *c_ptr, | ||
| const int64_t *group_offs_ptr, const int64_t *c_group_offs_ptr, const int32_t group_id, |
There was a problem hiding this comment.
Naming is asymmetric: group_offs_ptr is actually input-side (padded) offsets while c_group_offs_ptr is output-side (compact). Consider renaming to group_offs_in_ptr / group_offs_out_ptr (or a_group_offs_ptr / c_group_offs_ptr) so the asymmetry is visible at the callsite.
There was a problem hiding this comment.
Just realized that in real workloads the inputs probably always need to be padded anyway. In that case, it may be cleaner to stop overloading the presence of c_group_offs_ptr as a mode switch, and instead always pass explicit a_group_offs_ptr, b_group_offs_ptr, and c_group_offs_ptr. The interface becomes much more straightforward.
There was a problem hiding this comment.
Done. Launcher now always passes a non-null c_group_offs_ptr (caller's compact-output offsets when supplied; a_group_offs_ptr as fallback for the in-place layout). Kernel takes a single straight-line extract path with no mode switch on the pointer.
On b_group_offs_ptr: B is laid out as [G, N, K] and indexed by group_id directly — no per-group row offsets for B, so no third pointer to add.
| // per-group region to 128-aligned, so this matches | ||
| // group_offs_padded[g+1] - group_offs_padded[g] without an | ||
| // extra scalar load. | ||
| M_g_in = (M_g + 127) & ~127; |
| static_assert(VGPR >= 0 && VGPR <= 255, "VGPR must be in [0, 255]"); | ||
| // clang-format off | ||
| #define CLOBBER_VREG_CASE(N) case N: asm volatile("" ::: "v" #N); break; | ||
| #define CLOBBER_VREG_CASE(N) case N: asm volatile("" ::: "v" #N, "memory"); break; |
| granularity: ScalingGranularity, | ||
| num_cu: int | None, | ||
| **kwargs, | ||
| ) -> bool: |
There was a problem hiding this comment.
Done — added: n_fwd % 16 == 0 and k_fwd % 16 == 0 (wgrad MFMA tile is 16×16, output axes must clear)
78ccce9 to
847872c
Compare
…ler spill race Address PR #330 review #12 properly: launcher always passes a non-null c_group_offs_ptr (caller's compact-output offsets, or a_group_offs as fallback for the in-place layout); kernel takes a single straight-line extract path with no mode switch. Side effect: removing the if-else collapses the SSA phi-merge that had been sharing c_m_global / M_g_in register slots with m_global / M_g. Peak SGPR pressure across the LDS-direct buffer_load_lds prologue then crosses the per-wave limit (~100), and the compiler spills via scratch_store/load + readfirstlane. The reload's vmcnt-tracked completion can lose to a younger buffer_load on the shared counter (no FIFO guarantee), so a stale m0 leaks in and the buffer writes to the wrong LDS slot. Empirically reproduces at the gpt_oss-20B Expert shape (B=4 M=2048 N=5760 K=2880) at ~5% fwd in 10K reruns; tested 9 variants at the C++ level (assertions, fences, __builtin_assume, sched_barrier, deferral, formulation rewrites) — none recover baseline race rate. Workaround: 4 explicit s_waitcnt vmcnt(0) drains in the prologue between the first 4 buffer_load_lds calls, fencing the spill reload against the m0 setup. Empirically only the first four are needed (SGPR pressure peak); the rest run on already-flushed pressure. 0/20000 fwd reruns at gpt_oss-20B Expert + 0/10000 dense MXFP8. +4% fwd latency vs. baseline (338us vs 325us at the same shape). Filing the underlying compiler issue (vmcnt scoping for spill consumes against buffer_load_lds) separately.
|
@xiaobochen-amd Thanks for the commit! I have updated my code by the comments. Please review again, thanks. |
0df887d to
07e9940
Compare
| "alignment); also required: each per-group M_g % 128 == 0."); | ||
|
|
| // hint we use it; otherwise fall back to the pessimistic upper bound | ||
| // derivable from input metadata alone (no D2H sync of group_lens). | ||
| // Worst case: a single group holds all real rows, then per-group | ||
| // padding to MXFP8_PADDING_ALIGN_SIZE adds (group_num - 1)*ALIGN | ||
| // padding rows on top. | ||
| int32_t grid_x; | ||
| if (grid_x_hint > 0) { | ||
| grid_x = (int32_t) grid_x_hint; | ||
| } else { | ||
| constexpr int64_t ALIGN = primus_turbo::detail::MXFP8_PADDING_ALIGN_SIZE; | ||
| const int64_t total_m_upper = | ||
| ((int64_t) total_m_in + group_num * ALIGN + ALIGN - 1) / ALIGN * ALIGN; | ||
| grid_x = (int32_t) ((total_m_upper + 255) / 256); |
| // rowwise_output: ``(G, M_pad, N_pad)`` -- naturally fits ``(G, N_of_B, | ||
| // K_pad)`` for the fwd GEMM. |
| // | ||
| // Per-group dB[g] = LHS[g] @ RHS[g]^T, with LHS (N, total_M), RHS | ||
| // (K, total_M), dB (group_num, N, K); reduction is over total_M. | ||
| // Constraints: n % 16 == 0, k % 16 == 0, M_g % 32 == 0. |
| # MFMA tile = 16x16; output axes (N_fwd, K_fwd) must clear that. | ||
| if a.dim() == 2 and b.dim() == 2: | ||
| n_fwd = a.shape[0] | ||
| k_fwd = b.shape[1] | ||
| supported &= n_fwd % 16 == 0 and k_fwd % 16 == 0 | ||
| supported &= a.shape[1] == b.shape[0] |
|
|
||
| class GroupedGEMMFP8VariableKKernelDispatcher(BaseGroupedGEMMVariableKKernelDispatcher): | ||
| _backends = { | ||
| BackendType.TURBO: BackendEntry(GroupedGEMMFP8VariableKTurboBackend), |
| @pytest.mark.parametrize("balance", BALANCE_VALUES) | ||
| def test_grouped_gemm_fp8_mx_blockwise(B, M, NK, ori_dtype, format, trans_b, balance): |
| ScalingRecipe colwise_recipe, hipStream_t stream); | ||
|
|
||
| // Per-group dual quant for the MXFP8 grouped GEMM B-side: writes the | ||
| // rowwise output as ``(G, M_pad, N_pad)`` and the colwise output as |
…-K wgrad - Add hand-tuned 256x256x128 NT MXFP8 grouped GEMM kernels for gfx950: forward (turbo_grouped_gemm_mxfp8_kernel.h) and variable-K wgrad (turbo_grouped_gemm_mxfp8_wgrad_kernel.h), plus host-side dispatch and PyTorch bindings (turbo_grouped_gemm_fp8, turbo_grouped_gemm_variable_k_fp8, turbo_preshuffle_mxfp8_scale_16x4). - Wire MX_BLOCKWISE granularity into grouped_gemm_fp8 via a new GroupedGemmFP8MXFunc autograd Function covering fwd + dgrad + wgrad, routed through the turbo backends. - Extend benchmarks and tests with 384 MXFP8 cases (all passing on MI355). Made-with: Cursor
…rf and fix uint32 num_records overflow
This squash combines six iterative changes that bring grouped GEMM mxfp8
(forward + variable-K wgrad) on par with the single-GEMM mxfp8 kernel and
fix a correctness bug at huge total_M, plus a test-hygiene cleanup.
Changes
-------
1. Synchronization trim across single + grouped (fwd & wgrad):
- Drop redundant `wait_vmcnt`/`s_barrier` and AGPR/VGPR clobbers that
had been added defensively but are unnecessary once the device-header
memory clobbers are in place.
- Mainloop end-of-iter drain stays at `wait_vmcnt<12>` (looser was
racy at 100%).
- Drop the inter-tile `s_barrier` between persistent-loop tiles — the
prologue's `wait_vmcnt<0>; s_barrier` already provides cross-tile
ordering for SMEM.
2. C-store simplification: replace the dual-buffer + intermediate
`wait_vmcnt` C-store with the single-buffer pattern used by single
GEMM. Same number of stores, less code.
3. Tile reuse: hoist `GemmTile tile{...}` out of the per-tile compute
body so the persistent loop reuses one instance.
4. Mid-Epi2 split drain (mirror single GEMM): move the buffer-drain
`wait_vmcnt<0>` from end-of-Epi1 to mid-Epi2 (after Epi2 phase 2),
with `wait_vmcnt<6>` at end-of-Epi1 instead. This lets Epi1 phase
3+4 mfma_lds and Epi2 phase 1+2 overlap with the trailing 6
buffer_load_lds GMEM→LDS DMAs from Epi1's prefetch. Counter-intuitively
improves wgrad determinism (race ~0.3% → ~0.01%) by giving the
LDS-write commits an extra sync point before Phase 3 reads `next`.
5. wgrad BufferSRD num_records uint32 overflow fix: `(n - pid_n) *
total_m * sizeof(...)` was computed in uint32; for shapes with
total_M ≥ ~600K and either N or K large (e.g. Kimi-K2 EP=8 →
B=48 M=16384, total_M=786K), the product (5.25 GB) silently
truncated to 1.25 GB. buffer_load_lds then masked valid in-tile
addresses as OOB and returned zeros, dropping b_grad SNR to ~18 dB
on those cases. Compute in uint64 and clamp to UINT32_MAX before
assigning to BufferSRD's 32-bit num_records. FWD does not need the
same fix: per-group M_g ≤ 16384 keeps M_g*k and n*k well below 4 GB.
6. Test hygiene: drop the silent 128-rounding workaround in
`_run_grouped_gemm_fp8_test` that masked the kernel's M_g % 128 == 0
constraint, and restrict `test_grouped_gemm_fp8_mx_blockwise` to
`balance=True` only. Unbalanced groups whose per-group M_g is not a
128-multiple are not a supported configuration for the mxfp8 wgrad
kernel (preshuffled scale layout requires alignment) and should be
enforced at the wrapper level rather than worked around in tests.
Validation
----------
bench_grouped_gemm_turbo --dtype fp8 --granularity mxfp8: 288/288 PASS.
rocprofv3 kernel-only timings (3-run min) on representative shapes:
Single GEMM mxfp8 vs Triton tw, avg mx/tw = 0.948 (5% faster).
Grouped FWD mxfp8 vs Triton tw, avg mx/tw = 0.967 (3% faster).
Grouped WGRAD mxfp8 (no tw baseline; ≤ FWD on per-tile compute).
25 000-iter race stress (serial):
Single GEMM: 147/25000 = 0.588% (baseline)
Grouped FWD: 7/25000 = 0.028% (~21× more deterministic)
Grouped WGRAD: 2/25000 = 0.008% (~73× more deterministic)
Co-authored-by: Cursor <cursoragent@cursor.com>
…mbos, zero D2H sync
Make ``grouped_gemm_fp8(..., granularity=MX_BLOCKWISE)`` accept the same
``@pytest.mark.parametrize`` matrix as the TENSORWISE path (B / M / NK /
ori_dtype / format / trans_b / balance — minus Format.HYBRID), including
the previously-unsupported ``trans_b=False`` and ``balance=False``.
The latter two require per-group M-axis zero-padding to multiples of 128
because the turbo MXFP8 kernel mandates ``M_g % 128 == 0``. All padding,
quantisation and post-extraction logic now lives in C++ / HIP so the
hot path issues *zero* device-to-host syncs (verified via CUDA Graph
capture of the full fwd+bwd loop).
Two new C++ ops, mirroring the single-MXFP8 pair
(``quantize_mxfp8_dual`` + ``quantize_mxfp8``):
* ``quantize_mxfp8_dual_grouped(input, group_lens, group_offs, dtype,
...)`` — single fused kernel that
1. computes ``group_lens_padded`` / ``group_offs_padded`` on GPU
(single-thread layout kernel, fully async),
2. allocates outputs at the host-known upper bound
``ceil((total_M + G*128) / 128) * 128`` (over-alloc by at
most ``G*127`` rows — a few KB),
3. materialises the padded layout directly in the FP8 / scale
outputs (no intermediate ``torch.zeros`` + per-group
``copy_``),
and returns ``[rowwise_fp8, rowwise_scale, colwise_fp8,
colwise_scale, group_lens_padded, group_offs_padded]``.
* ``extract_grouped_rows(x_padded, group_offs_orig, group_offs_padded,
total_M_orig)`` — single-kernel replacement for the
``G x .copy_()`` loop that strips per-group padding from the GEMM
output. Tile shape is ``ROWS_PER_BLOCK=8`` rows × ``num_vec``
float4 to keep the launch grid small on large M.
The fused quant kernel handles out-of-bounds tiles (over-allocated
region, 0-1.6% of CTAs) by routing them through the existing
``amax == 0 -> scale = E8M0_EXPONENT_BIAS, fp8 = 0`` branch already
used for shuffle-padding rows. This keeps the wgrad preshuffle kernel
race-free without modifying the turbo grouped-GEMM kernel itself.
Misc: * ``GroupedGemmFP8MXFunc`` calls ``turbo_grouped_gemm_fp8`` and
``turbo_grouped_gemm_variable_k_fp8`` directly with a host-computed
``grid_x_hint`` (saves ~10-20us / call vs the dispatcher path and
skips the GEMM op's internal ``group_lens.cpu()`` D2H sync).
* Removed the silent ``group_lens % 128`` rounding workaround from
``test_grouped_gemm_fp8_mx_blockwise`` — the wrapper now handles
arbitrary group sizes.
Co-authored-by: Cursor <cursoragent@cursor.com>
Three perf-focused changes folded into one commit, all targeting the
MXFP8 grouped GEMM hot path on gfx950:
1. Chunked LDS-transpose preshuffle.
The original ``turbo::preshuffle_scale_16x4_kernel`` reads each 16-row
input tile with stride-``cols`` 1-byte loads (one warp covers a (16,4)
sub-tile via ``in[(tid%16)*cols + tid/16]``), which clocks at roughly
5-10% of HBM peak. On gpt_oss-style MoE shapes the grouped-GEMM hot
path issues three of these per step (fwd A+B, dgrad ∇y+B, wgrad
LHS+RHS) and they account for ~7-10% of the per-step cost.
Added a drop-in chunked variant
(``preshuffle_scale_16x4_v2_kernel`` /
``preshuffle_scale_16x4_dual_v2_kernel``) that stages a (32 cols × 16
rows) tile through LDS as 32-byte coalesced loads then writes the
transposed (16, 4)-shuffled scale block. This lifts the kernel from
~351us/step to ~95us/step on B=32, M=4096, K=4096 grouped MXFP8.
2. Fused extract pass into MXFP8 grouped GEMM store.
The MXFP8 grouped GEMM previously ran the GEMM on the per-group-
padded input layout, then ran a separate ``extract_grouped_rows``
kernel to strip per-group zero-pad rows back out. This commit folds
the extract into the GEMM store: the kernel now reads inputs from
the padded layout (via ``group_offs_ptr``) but writes outputs
directly to the unpadded ``[total_M_orig, N]`` tensor at row indices
given by the new optional ``c_group_offs_ptr``.
Plumbing:
- ``TurboGroupedGemmMXFP8Params`` gains ``c_group_offs_ptr``;
``total_m`` is the input layout's row count.
- ``turbo_grouped_gemm_fp8`` op signature gains ``c_group_offs`` and
``total_m_out``. Non-MX backends pass ``None`` / ``0`` and behave
identically.
- The MX autograd ``Function`` passes ``c_group_offs=group_offs``
(original unpadded) + ``total_m_out=a.size(0)`` for fwd and dgrad;
wgrad keeps the padded layout (its output is per-group-3D, no
compression).
Correctness: when ``c_group_offs_ptr != nullptr`` the input/scale SRD
bounds use ``M_g_in = ceil(M_g, 128)`` so reads cover full 16-row
blocks of preshuffled scale data; the original ``M_g`` continues to
gate compute early-exit and output writes (no padding rows leak to
the output). ``M_g_in`` is derived with a single ALU bit-mask
instead of a separate ``group_offs_padded[g+1] - [g]`` load, saving
one scalar load + readfirstlane per CTA setup.
Net effect on K=1536 cases (where ``extract_grouped_rows`` previously
contributed ~400us/step): the extract kernel disappears entirely and
most K=4096 / K=1536 grouped MXFP8 cases now hit or beat the
tensorwise+triton baseline.
All 21,728 grouped_gemm_fp8 parametrized test cases pass; CUDAGraph
capture confirms zero D2H sync.
Co-authored-by: Cursor <cursoragent@cursor.com>
The MXFP8 grouped GEMM B-side quant used to take a 2D ``(G*N, K)`` flat input through ``quantize_mxfp8_dual`` (producing rowwise + colwise outputs in flat ``(G*N, K_pad)`` / ``(K, G*N_pad)`` layout) and then reshape the colwise output into the ``(G, K, N_pad)`` layout the dgrad GEMM expects via a torch transpose+contiguous chain. The regroup chained ``view.reshape(K, G, N).transpose(0, 1).contiguous()`` (+ tail zero-pad), landing in torch's elementwise copy at ~30% HBM peak; on gpt-oss-20B Down B=32 N=K=2880 this took ~360us per step. This commit replaces both passes with a single per-group dual quant ``quantize_mxfp8_dual_perg``, taking the 3D ``(G, N, K)`` input directly and writing rowwise as ``(G, N, K_pad)`` and colwise as ``(G, K, N_pad)`` in one kernel. The kernel body is unchanged: we just add five per-blockIdx.z stride parameters to ``quantize_mxfp8_dual_kernel`` (default 0 for existing callers, so this is a no-op for them) and a thin host launcher that emits ``grid_z = G`` with the right per-group strides. Per-group M / N / M_pad / N_pad are uniform across G; the kernel's existing ``global_row < M`` guards on the row write side handle within-group M_pad-M padding rows. Measured impact (rocprofv3, gpt-oss-20B sweep, MX vs tensorwise+triton): - median MX-vs-TW overhead: +18.8% -> +12.8% - M=4096 cases hit or beat tensorwise+triton outright - Down B=32 M=2048: +38.6% -> +29.4% (regroup 360us -> 0) M=2048 cases remain bound by the persistent grouped GEMM kernel itself (MX fwd+dgrad lands ~30% off peak vs tensorwise+triton 57%, mostly because the per-block scale loads inflate the K=2944 inner loop). Co-authored-by: Cursor <cursoragent@cursor.com>
…ted) - #2-#6 drop redundant ': "memory"' clobbers from MFMA + clobber_*_one / zero_agpr / read_agpr asm — registers don't access memory; clobber lists carry the dependency. Verified 0/3000 grouped + 3/3000 dense reruns vs PR baseline 2/3000 + 3/3000 at gpt_oss-20B Expert shape. - #7 grid_x: drop the D2H sync of group_lens. C++ launcher computes the pessimistic upper bound from input metadata (total_m_in + group_num*ALIGN) directly; ~1cycle scalar arith vs ~5us PCIe round-trip. Python wrapper's _mxfp8_grid_x_hint helper deleted with all callers and ctx state. - #8 rename group_offs_ptr -> a_group_offs_ptr in MXFP8 paths to disambiguate from c_group_offs_ptr (output offsets). CK / hipBLASLt struct fields unchanged. - #9 use detail::MXFP8_PADDING_ALIGN_SIZE instead of the literal 127 in the M_g_padded round-up; matches the constant the upstream quant op uses. - #10 GroupedGEMMFP8VariableKTurboBackend.can_handle now gates on the 16-multiple shape requirement of the wgrad MFMA (n_fwd, k_fwd) and contraction-axis agreement, so unsupported shapes fall through to CK instead of asserting inside the kernel. - #11 ctx.trans_b_orig -> ctx.trans_b — the orig was a leftover from an earlier refactor; only one trans_b is now persisted. REJECTED on empirical grounds: - #1 (drop post-asm clobber_vgpr_one<> after ds_read_pinned). Reviewer argued these are redundant since callers reserve the dst VGPR range upfront; in practice removing them races 100% in grouped MXFP8 dgrad (max diff inf) and 0.7% in dense MXFP8 mxgemm at 4096^3 — exact match for the historical race profile. The reservation prevents the dst VGPRs from being reused but does NOT prevent the compiler from reordering ds_read past the consuming MFMA; only the per-asm clobber list does. Keeping the calls. OBSOLETED: - #12 (always-non-null c_group_offs_ptr; derive M_g_in from a_group_offs delta). The delta-load form raced ~3% in fwd at the gpt_oss-20B Expert shape; the simpler ceil(M_g, ALIGN) form (review #9 above) lands the same value with no extra scalar load. The if-nullptr branch in the kernel is one scalar test and stays — the cleanup wasn't worth the delta-load races.
…ed GEMM Combines and supersedes the trio of in-flight review iterations (0df887d, 7af738c, db6152e) into a single coherent change. Net effect relative to PR base (847872c): MXFP8 grouped GEMM (fwd / dgrad / wgrad) is fully deterministic on gfx950 at PR-base perf, with the source-level workarounds dropped in favour of the actual root-cause fix. Code changes: - Drop the c_group_offs == nullptr mode-switch and the clobber_vgpr_one pinning workaround; they hid (not fixed) the race and bloated codegen. - CK-style separation: outer persistent kernel resolves per-group pointers (a_grp_ptr, b_grp_ptr, ..., c_grp_ptr) entirely; compute_tile receives them already group-resolved + per-tile bookkeeping only. Narrows the SGPR live set to the dense single-GEMM kernel level. - Pack the 4 typed smem tile pointers into one char* base; compute_tile reconstructs typed views internally. Saves 6 SGPR args. - Two-step prologue (buffer_load_b128 -> VGPR -> ds_write_b128) for the 16 LDS-direct loads in fwd's prologue; wgrad uses 2-step only for the 4 data loads (full conversion regresses wgrad perf). Puts the LDS write on lgkmcnt instead of vmcnt, separated from compiler-emitted scratch_load (SGPR spill consume). Compiler flags (the actual root-cause fix for the residual race): - `-amdgpu-enable-merge-m0=true` merges adjacent `s_mov_b32 m0` writes, collapsing per-load m0 setup chains in the buffer_load_lds main loop. Without merge, each LDS-direct load has its own m0 update fed from a potentially-stale spill SGPR; with merge, m0 is established once per cluster, breaking the spill->m0->LDS-write race. - `-sink-insts-to-avoid-spills=true` sinks defining instructions toward their uses, halving private_segment from 44 to 20 bytes. This kills the residual cross-call VGPR live range that drove the scratch_store / scratch_load pair — vmcnt counts both scratch and buffer_load_lds, and gfx950 doesn't FIFO across them, so a spill reload could otherwise leak into m0. Race + perf at gpt_oss-20B Expert shape (B=4 M=2048 N=5760 K=2880): - fwd: 0/500K (vs PR HEAD ~3/30K = 0.01%); Fwd 1340 TFLOPS (+1.4%) - dgrad: 0/150K; Dgrad 1258 TFLOPS (+2.1%) - wgrad: 0/90K; Wgrad 1723 TFLOPS (-1.5%) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Previous commit added separate `*_two_step` helper variants and a
TWO_STEP template parameter on phase_mfma_lds_ldg. In the final config
the main loop never uses the 2-step path, so the duplication and the
template branching in phase_mfma_lds_ldg are dead weight.
Cleanup:
- Drop `*_half_srd_two_step` duplicates; collapse the 4 affected
`load_{a,b}{,_scale}_gmem_to_smem_half_srd` helpers behind a
`bool TwoStep = false` template parameter and a single `if constexpr`
switch around the inner load.
- Drop the TWO_STEP template parameter + `if constexpr` chain on
`phase_mfma_lds_ldg`; the main loop is always LDS-direct.
- Collapse the two free-function `load_gmem_to_smem_srd_two_step` impls
into one `<Bytes>` template (b32 + b128).
- Drop the now-unused `clobber_{a,v}gpr_one` `using` decls in
turbo_gemm_mxfp8_kernel{,_hip}.h.
Per-prologue 2-step decisions kept as before:
- fwd: data + scale prologue loads use TwoStep=true (16 calls).
- wgrad: only the 8 data prologue loads use TwoStep=true; scales stay
LDS-direct (full conversion regresses wgrad ~25%).
- inter-iter and main-loop loads stay LDS-direct.
Race + perf at gpt_oss-20B Expert shape (B=4 M=2048 N=5760 K=2880),
chi2761 GPU 6, 50K iters each:
- fwd: 0/50K; Fwd 1337 TFLOPS (≡ 567e578)
- dgrad: 0/50K; Dgrad 1254 TFLOPS (≡ 567e578)
- wgrad: 0/50K; Wgrad 1756 TFLOPS (≡ 567e578)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Reset csrc to PR base 847872c. Reviewer (Chen Xiaobo) rejected the 2-step prologue / mllvm flags / clobber_vgpr_one / reserve_pinned_regs approach as papering over the race instead of fixing it. The rejected commits stay in git log per prior request.
Race source: VGPR-spilled SMEM-subtile addresses go through scratch_load, which shares vmcnt with buffer_load_lds but doesn't FIFO across it on gfx950. s_waitcnt vmcnt(N) can pass while scratch_load is still in flight → stale m0 → LDS write to wrong address → 0.01-1% bit non-det. Three source-level changes (no LLVM flags, no clobber barriers): 1. Outer kernel resolves per-group base ptrs; inner takes resolved ptrs only. i64 group offset chain stays out of inner. c_group_offs_ptr now always-valid via launcher fallback; per-group input padding is signalled branchlessly via a c_padding_align_mask field. 2. SMEM-write helpers force LDS-base into SGPR: compute_sts_offsets drops lane_id*16 (buffer_load_lds auto-strides per lane); scale helpers drop scale_sts_offset = lane_id; sts_warp_base / smem_byte_offset wrapped in __builtin_amdgcn_readfirstlane so the warp-id arithmetic is explicitly scalar. 3. bf16 store uses raw-bit truncation instead of hip_bfloat16(float) software round-to-even (avoids SCC branches that get lane-spilled). SNR unchanged at 28.23 dB. reg notes (sgpr/sgpr_spill/vgpr_spill/scratch): grouped fwd bf16 106 / 7 / 10 / 44 → 106 / 0 / 0 / 0 grouped fwd half 106 / 0 / 17 / 72 → 106 / 0 / 0 / 0 grouped wgrad 106 / 0 / 17 / 72 → 104-105 / 0 / 0 / 0 single GEMM 81 / 0 / 0 / 0 unchanged fwd race: 0.01-1% → 0/200K (4 GPUs × 50K iters). Perf (B=16 M=2048 N=4096 K=7168): Fwd 1435 / Dgrad 1416 / Wgrad 2013 TFLOPS. SNR fwd / dgrad / wgrad: 28.23 / 28.23 / 28.08 dB.
After the spill fix, compute_sts_offsets returned only [0, 1024] and threaded a [2]-array through every load_a/b_gmem_to_smem_half_srd call. Inline the constant: helper writes `base` and `base + 1024` directly (1024 = per-wave LDS write span for buffer_load_dwordx4 = 16 B/lane × 64). Also two small style alignments with the single-GEMM kernel: - pid_n_idx → pid_n_local in fwd compute_tile (consistent with pid_m_local) - scale_cols declaration moved between base ptrs (mxgemm ordering) Verified: reg notes / race / SNR all unchanged.
07e9940 to
8244d7a
Compare
| // ``quantize_mxfp8_dual_perg``: per-group dual quant for the grouped GEMM | ||
| // B-side. Input is 3D ``(G, M, N)``; per-group M/N are uniform. The | ||
| // kernel writes: | ||
| // rowwise_output: ``(G, M_pad, N_pad)`` -- naturally fits ``(G, N_of_B, | ||
| // K_pad)`` for the fwd GEMM. | ||
| // colwise_output: ``(G, N, M_pad)`` -- naturally fits ``(G, K_of_B, | ||
| // N_pad)`` for the dgrad GEMM | ||
| // (no separate regroup). |
| detail::ScalingRecipe colwise_recipe, hipStream_t stream); | ||
|
|
||
| // Per-group dual quant with uniform per-group (M, N): input (G, M, N) -> | ||
| // rowwise (G, M_pad, N_pad), colwise (G, N, M_pad). Used by the grouped |
abaeaab to
e0263a0
Compare
Two wait_lgkmcnt<0> drains (mid-phase WAR + compute_tile prologue) fix an LDS WAR race in the MFMA-pinned-asm path; race verified at 0/N on both single-GEMM and grouped GEMM MX_BLOCKWISE deterministic tests.
e0263a0 to
17df201
Compare
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 21 out of 21 changed files in this pull request and generated 3 comments.
Comments suppressed due to low confidence (2)
tests/pytorch/ops/test_grouped_gemm_fp8.py:554
- Same issue as the deterministic MX_BLOCKWISE test above: this uses an exact (9, 5) compute capability check, which will skip on newer gfx95x devices even if they support MXFP8. Use check_mxfp8_support() or >= (9, 5) instead of equality.
if get_device_compute_capability() != (9, 5):
pytest.skip("MXFP8 grouped GEMM requires gfx950 (MI350/MI355).")
primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py:628
- Same concern as the fwd/dgrad TURBO backend: can_handle() for the variable-K TURBO backend doesn’t verify gfx950/MXFP8 support. Adding a device capability gate would prevent the dispatcher from selecting a kernel that asserts at runtime on unsupported GPUs when invoked with pre-quantized tensors.
supported = True
supported &= a.dim() == 2 and b.dim() == 2
supported &= (a.dtype, b.dtype, out_dtype) in GroupedGEMMFP8VariableKTurboBackend.SUPPORTED_DTYPES
supported &= granularity in GroupedGEMMFP8VariableKTurboBackend.SUPPORTED_GRANULARITIES
supported &= not trans_a and not trans_b and not trans_c
# MFMA tile = 16x16; output axes (N_fwd, K_fwd) must clear that.
if a.dim() == 2 and b.dim() == 2:
n_fwd = a.shape[0]
k_fwd = b.shape[1]
supported &= n_fwd % 16 == 0 and k_fwd % 16 == 0
supported &= a.shape[1] == b.shape[0]
return supported
| @pytest.mark.parametrize("balance", [True, False]) | ||
| @pytest.mark.deterministic | ||
| def test_grouped_gemm_fp8_mx_blockwise_deterministic(B, M, NK, ori_dtype, format, trans_b, balance): | ||
| if get_device_compute_capability() != (9, 5): |
| // ``quantize_mxfp8_dual_perg``: per-group dual quant for the grouped GEMM | ||
| // B-side. Input is 3D ``(G, M, N)``; per-group M/N are uniform. The | ||
| // kernel writes: | ||
| // rowwise_output: ``(G, M_pad, N_pad)`` -- naturally fits ``(G, N_of_B, | ||
| // K_pad)`` for the fwd GEMM. | ||
| // colwise_output: ``(G, N, M_pad)`` -- naturally fits ``(G, K_of_B, | ||
| // N_pad)`` for the dgrad GEMM | ||
| // (no separate regroup). |
| supported = True | ||
| supported &= a.dim() == 2 and b.dim() == 3 | ||
| supported &= (a.dtype, b.dtype, out_dtype) in GroupedGEMMFP8TurboBackend.SUPPORTED_DTYPES | ||
| supported &= granularity in GroupedGEMMFP8TurboBackend.SUPPORTED_GRANULARITIES | ||
| supported &= not trans_a and trans_b | ||
| total_m = a.shape[0] | ||
| n = b.shape[-2] if trans_b else b.shape[-1] | ||
| k = a.shape[1] | ||
| supported &= total_m % 16 == 0 and n % 16 == 0 and k % 128 == 0 and k >= 384 | ||
| return supported |
| return grad_out if grad_out.is_contiguous() else grad_out.contiguous() | ||
|
|
||
|
|
||
| def _turbo_grouped_gemm_mxfp8( |
There was a problem hiding this comment.
The code here does not align with our backend design
| @pytest.mark.parametrize("balance", [True, False]) | ||
| @pytest.mark.deterministic | ||
| def test_grouped_gemm_fp8_mx_blockwise_deterministic(B, M, NK, ori_dtype, format, trans_b, balance): | ||
| if get_device_compute_capability() != (9, 5): |
There was a problem hiding this comment.
check_mxfp8_support
…pads & helpers - ops: delete _turbo_grouped_gemm_mxfp8 / _turbo_grouped_gemm_variable_k_mxfp8 bypass helpers; MX_BLOCKWISE fwd/dgrad/wgrad now go through grouped_gemm_fp8_impl / grouped_gemm_fp8_variable_k_impl with default_backend=BackendType.TURBO so env-var override / autotune / can_handle fallback all engage uniformly with the fp8 tensorwise path. - dispatcher: extend grouped_gemm_fp8_impl (+ register_fake) with c_group_offs / total_m_out / grid_x_hint kwargs to expose MX unpad-on- store and persistent grid hint via the public op signature. Hipblaslt backends gain a **kwargs shield to absorb the new keys. - can_handle: fix GroupedGEMMFP8VariableKTurboBackend shape check (was a.shape[1] == b.shape[0] / k_fwd = b.shape[1]; MX wgrad layout has the variable-M reduction dim as shape[1] on both — correct check is a.shape[1] == b.shape[1], k_fwd = b.shape[0]). Bug never surfaced because the bypass skipped can_handle entirely. - kernel: inline CType(f) (default RNE) at the C store; drop the bf16 raw-bit truncation specialisation (and the float_to_ctype helper). Disasm now shows 0 spill on all MXFP8 kernels, so the workaround is unnecessary; truncation silently diverged from torch's bf16 cast and bites numerics debugging. - kernel: drop the (scale_cols + 3) / 4 * 4 pad in workspace_size, impl, and the compute_tile / launcher scale_cols. Launcher asserts k % 128 == 0, so scale_cols = k/32 is always a multiple of 4 and the pad is dead arithmetic. Also drops the preshuffle's trailing partial col- block branch (dead under the same invariant). - test: switch grouped MX deterministic / mx_blockwise gating from get_device_compute_capability() != (9, 5) to check_mxfp8_support() to match test_gemm_fp8.py / test_quantization.py. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
17df201 to
c5c03ad
Compare
Description
Add a turbo MXFP8 grouped GEMM backend (gfx950 / MI350) for MoE FFN
workloads. Forward, dgrad and wgrad all run on hand-tuned 256x256x128
persistent kernels with E8M0 scales preshuffled into 16x4 col-major
blocks. The wrapper handles
trans_b=Falseand unaligned per-groupM_g transparently (zero-pad along M), so the new path passes the full
tensorwise parametrization (HYBRID excluded for now).
End-to-end overhead vs
tensorwise + Tritonon gpt-oss-20B shapes:median +12.8%, with M=4096 cases matching or beating the tensorwise
baseline.
Type of change
Changes
csrc/kernels/grouped_gemm/):fwd / dgrad share a 256x256x128 persistent kernel with optional
store-side unpad via
c_group_offs; wgrad uses a variable-K (M_greduction) kernel. uint32 BufferSRD num_records is clamped to avoid
the >4 GB silent-truncation that masks valid LDGs as OOB.
to the single-GEMM kernel but stages each 16-row tile through LDS in
256-byte column slabs, restoring fully coalesced HBM traffic on the
fwd / dgrad / wgrad hot paths.
quantize_mxfp8_dual_perg): writesrowwise
(G, M, N_pad)and colwise(G, N, M_pad)directly,eliminating the torch transpose+contiguous regroup that previously
bottlenecked the dgrad B-side.
quantize_mxfp8_dual_grouped):computes the padded layout on GPU (no D2H sync) and materialises the
padded layout straight into the output tensors, allocating outputs
at the host-known upper bound so the call is fully async.
clobber_*gpr/zero_agpr/read_agprinline asm +clobber_vgpr_oneafterds_read_b*, matching the single-GEMM kernel's race profile.GroupedGemmFP8MXFunc): handlestrans_b=False(NT materialisation + grad transpose-back) and unaligned per-group
M_g (128-aligned per-group zero-pad on A / grad_out) transparently.
test_grouped_gemm_fp8_mx_blockwisemirrors the fulltensorwise param sweep (
B,M,NK,ori_dtype,format,trans_b,balance); HYBRID format is intentionally excluded.bench_grouped_gemm_turbo.pyadds anmxfp8granularityfor direct comparison with
tensorwise + Triton.Checklist
MX Grouped GEMM Bench
GPU: MI355 (gfx950) · dtype: bfloat16 · MX_BLOCKWISE (E4M3 + E8M0, block=32) ·
NT layout · CUDA-event timed (warmup 20 / iters 100). All numbers are
TFLOPS (higher is better).
Fwd / Dgrad / Wgrad): pureturbo_grouped_gemm_fp8/turbo_grouped_gemm_variable_k_fp8, inputs pre-quantized outside thetimed region. Includes the in-kernel scale preshuffle.
FLOPs =
2·B·M·N·Kper kernel.Fwd+Q / Bwd+Q): full Python op times.Fwd+Q=grouped_gemm_fp8(...)(A grouped quant + B per-group quant +GEMM, FLOPs =
2·B·M·N·K).Bwd+Q=out.sum().backward()only (grad_out grouped quant + dgrad +wgrad, FLOPs =
4·B·M·N·K).bal= balanced (M_g = M);unb= random per-group split, total =B·M.DeepSeek-V3 (n_routed_experts=256, hidden=7168, ffn=2048)
Kernel-only · Balanced
Kernel-only · Unbalanced
With Quant · Balanced
With Quant · Unbalanced
gpt_oss_20B (n_routed_experts=32, hidden=2880, ffn=2880)
Kernel-only · Balanced
Kernel-only · Unbalanced
With Quant · Balanced
With Quant · Unbalanced