[Feature] Add Turbo MXFP8 Grouped GEMM (gfx950) for MoE #330

Open
kyle-256 wants to merge 15 commits into main from dev/kyle_mxfp8_gg_pr

kyle-256 commented May 7, 2026

Description

Add a turbo MXFP8 grouped GEMM backend (gfx950 / MI350) for MoE FFN
workloads. Forward, dgrad and wgrad all run on hand-tuned 256x256x128
persistent kernels with E8M0 scales preshuffled into 16x4 col-major
blocks. The wrapper handles trans_b=False and unaligned per-group
M_g transparently (zero-pad along M), so the new path passes the full
tensorwise parametrization (HYBRID excluded for now).

End-to-end overhead vs tensorwise + Triton on gpt-oss-20B shapes:
median +12.8%, with M=4096 cases matching or beating the tensorwise
baseline.

Type of change

  • Documentation change
  • Bug fix
  • New feature
  • Breaking change
  • Infra/Build change
  • Code refactoring

Changes

  • Turbo MXFP8 grouped GEMM kernels (csrc/kernels/grouped_gemm/):
    fwd / dgrad share a 256x256x128 persistent kernel with optional
    store-side unpad via c_group_offs; wgrad uses a variable-K (M_g
    reduction) kernel. uint32 BufferSRD num_records is clamped to avoid
    the >4 GB silent-truncation that masks valid LDGs as OOB.
  • Chunked-LDS dual-tensor preshuffle for E8M0 scales: bit-identical
    to the single-GEMM kernel but stages each 16-row tile through LDS in
    256-byte column slabs, restoring fully coalesced HBM traffic on the
    fwd / dgrad / wgrad hot paths.
  • Per-group dual MXFP8 quant (quantize_mxfp8_dual_perg): writes
    rowwise (G, M, N_pad) and colwise (G, N, M_pad) directly,
    eliminating the torch transpose+contiguous regroup that previously
    bottlenecked the dgrad B-side.
  • Fused per-group M zero-pad quant (quantize_mxfp8_dual_grouped):
    computes the padded layout on GPU (no D2H sync) and materialises the
    padded layout straight into the output tensors, allocating outputs
    at the host-known upper bound so the call is fully async.
  • Race-freeness: memory clobbers on MFMA / clobber_*gpr /
    zero_agpr / read_agpr inline asm + clobber_vgpr_one after
    ds_read_b*, matching the single-GEMM kernel's race profile.
  • Python wrapper (GroupedGemmFP8MXFunc): handles trans_b=False
    (NT materialisation + grad transpose-back) and unaligned per-group
    M_g (128-aligned per-group zero-pad on A / grad_out) transparently.
  • Tests: test_grouped_gemm_fp8_mx_blockwise mirrors the full
    tensorwise param sweep (B, M, NK, ori_dtype, format,
    trans_b, balance); HYBRID format is intentionally excluded.
  • Bench: bench_grouped_gemm_turbo.py adds an mxfp8 granularity
    for direct comparison with tensorwise + Triton.
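
The per-group zero-pad along M described in the list above can be sketched on the host in plain Python. This is only an illustration of the layout arithmetic, not the fused GPU kernel: `align_up` and `padded_group_layout` are hypothetical names, and the 128-row alignment is taken from the wrapper bullet.

```python
ALIGN = 128  # per-group M alignment stated in the PR description


def align_up(x: int, align: int = ALIGN) -> int:
    """Round x up to the next multiple of align (align must be a power of two)."""
    return (x + align - 1) & ~(align - 1)


def padded_group_layout(group_lens):
    """Return (padded_lens, padded_offs) for a list of per-group row counts.

    padded_offs has len(group_lens) + 1 entries. Group g owns the slot
    [padded_offs[g], padded_offs[g + 1]); its first group_lens[g] rows hold
    real data and the remainder of the slot is zero padding.
    """
    padded_lens = [align_up(m) for m in group_lens]
    padded_offs = [0]
    for m in padded_lens:
        padded_offs.append(padded_offs[-1] + m)
    return padded_lens, padded_offs


lens, offs = padded_group_layout([100, 128, 300, 0])
print(lens)  # [128, 128, 384, 0]
print(offs)  # [0, 128, 256, 640, 640]
```

Keeping the alignment in one named constant also avoids repeating the literal 127 mask at each call site.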

Checklist

  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

MX Grouped GEMM Bench

GPU: MI355 (gfx950) · dtype: bfloat16 · MX_BLOCKWISE (E4M3 + E8M0, block=32) ·
NT layout · CUDA-event timed (warmup 20 / iters 100). All numbers are
TFLOPS (higher is better).

  • Kernel-only (Fwd / Dgrad / Wgrad): pure turbo_grouped_gemm_fp8 /
    turbo_grouped_gemm_variable_k_fp8, inputs pre-quantized outside the
    timed region. Includes the in-kernel scale preshuffle.
    FLOPs = 2·B·M·N·K per kernel.
  • End-to-end (Fwd+Q / Bwd+Q): full Python op times.
    Fwd+Q = grouped_gemm_fp8(...) (A grouped quant + B per-group quant +
    GEMM, FLOPs = 2·B·M·N·K).
    Bwd+Q = out.sum().backward() only (grad_out grouped quant + dgrad +
    wgrad, FLOPs = 4·B·M·N·K).
  • bal = balanced (M_g = M); unb = random per-group split, total = B·M.
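
As a worked instance of the FLOP accounting above, the following sketch turns a measured runtime into TFLOPS. The helper names are hypothetical, and the 1 ms runtime is a made-up value used only to show the arithmetic; it is not a measurement from this PR.

```python
def grouped_gemm_flops(b: int, m: int, n: int, k: int, gemms: int = 1) -> int:
    """FLOPs = 2*B*M*N*K per GEMM; use gemms=2 for Bwd+Q (dgrad + wgrad)."""
    return 2 * b * m * n * k * gemms


def tflops(flops: int, seconds: float) -> float:
    """Convert a FLOP count and a wall-clock time in seconds to TFLOPS."""
    return flops / seconds / 1e12


# GateUp shape from the DeepSeek-V3 tables: B=16, M=2048, N=4096, K=7168.
fwd_flops = grouped_gemm_flops(16, 2048, 4096, 7168)
bwd_flops = grouped_gemm_flops(16, 2048, 4096, 7168, gemms=2)
print(fwd_flops)                          # 1924145348608
print(bwd_flops == 2 * fwd_flops)         # True
print(round(tflops(fwd_flops, 1e-3), 1))  # 1924.1 for the made-up 1 ms runtime
```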

DeepSeek-V3 (n_routed_experts=256, hidden=7168, ffn=2048)

Kernel-only · Balanced

Layer B M N K Fwd Dgrad Wgrad
GateUp 16 2048 4096 7168 2178.9 2080.4 1731.7
GateUp 16 4096 4096 7168 2342.9 2165.2 2211.0
GateUp 32 2048 4096 7168 1327.1 1224.0 1708.1
GateUp 32 4096 4096 7168 2359.5 2147.5 2052.6
Down 16 2048 7168 2048 1620.2 2162.9 1752.1
Down 16 4096 7168 2048 1695.0 2275.7 2139.4
Down 32 2048 7168 2048 917.9 1314.5 1716.6
Down 32 4096 7168 2048 1636.0 2279.1 2025.1

Kernel-only · Unbalanced

Layer B M N K Fwd Dgrad Wgrad
GateUp 16 2048 4096 7168 1602.9 1563.8 1744.8
GateUp 16 4096 4096 7168 1735.1 1695.7 2145.8
GateUp 32 2048 4096 7168 1328.6 1282.6 1670.2
GateUp 32 4096 4096 7168 1585.9 1442.0 2026.7
Down 16 2048 7168 2048 1267.1 1461.9 1756.6
Down 16 4096 7168 2048 1473.0 1766.4 2063.4
Down 32 2048 7168 2048 1008.0 1356.0 1674.3
Down 32 4096 7168 2048 1232.4 1522.5 1992.2

With Quant · Balanced

Layer B M N K Fwd+Q Bwd+Q
GateUp 16 2048 4096 7168 1106.1 1685.8
GateUp 16 4096 4096 7168 1345.9 1948.0
GateUp 32 2048 4096 7168 810.2 1280.4
GateUp 32 4096 4096 7168 1336.6 1830.8
Down 16 2048 7168 2048 1011.1 1346.5
Down 16 4096 7168 2048 1210.8 1519.5
Down 32 2048 7168 2048 677.8 1102.3
Down 32 4096 7168 2048 1168.6 1474.2

With Quant · Unbalanced

Layer B M N K Fwd+Q Bwd+Q
GateUp 16 2048 4096 7168 921.5 1452.3
GateUp 16 4096 4096 7168 1129.8 1721.4
GateUp 32 2048 4096 7168 821.9 1322.5
GateUp 32 4096 4096 7168 1037.7 1562.3
Down 16 2048 7168 2048 854.6 1159.3
Down 16 4096 7168 2048 1093.2 1372.3
Down 32 2048 7168 2048 724.3 1122.7
Down 32 4096 7168 2048 935.9 1262.4

gpt_oss_20B (n_routed_experts=32, hidden=2880, ffn=2880)

Kernel-only · Balanced

Layer B M N K Fwd Dgrad Wgrad
GateUp 4 2048 5760 2880 1494.2 1064.2 1423.0
GateUp 4 4096 5760 2880 1699.8 1718.0 1797.0
GateUp 32 2048 5760 2880 1037.5 1258.5 1547.6
GateUp 32 4096 5760 2880 1822.4 2144.3 1835.3
Down 4 2048 2880 2880 855.2 856.8 1291.8
Down 4 4096 2880 2880 1451.9 1438.0 1634.3
Down 32 2048 2880 2880 988.1 991.3 1525.2
Down 32 4096 2880 2880 1751.7 1770.5 1853.2

Kernel-only · Unbalanced

Layer B M N K Fwd Dgrad Wgrad
GateUp 4 2048 5760 2880 1213.4 1288.7 1417.8
GateUp 4 4096 5760 2880 1504.1 1652.2 1748.4
GateUp 32 2048 5760 2880 1123.9 1335.5 1508.3
GateUp 32 4096 5760 2880 1350.7 1521.0 1779.7
Down 4 2048 2880 2880 1035.8 1034.2 1233.2
Down 4 4096 2880 2880 1378.9 1378.2 1540.8
Down 32 2048 2880 2880 1072.2 1072.6 1461.0
Down 32 4096 2880 2880 1390.0 1387.7 1746.3

With Quant · Balanced

Layer B M N K Fwd+Q Bwd+Q
GateUp 4 2048 5760 2880 829.4 929.5
GateUp 4 4096 5760 2880 1145.8 1283.8
GateUp 32 2048 5760 2880 703.3 1120.6
GateUp 32 4096 5760 2880 1161.3 1546.9
Down 4 2048 2880 2880 534.6 733.8
Down 4 4096 2880 2880 860.4 1126.4
Down 32 2048 2880 2880 638.2 989.3
Down 32 4096 2880 2880 1030.8 1389.9

With Quant · Unbalanced

Layer B M N K Fwd+Q Bwd+Q
GateUp 4 2048 5760 2880 743.8 998.3
GateUp 4 4096 5760 2880 1011.5 1283.1
GateUp 32 2048 5760 2880 724.2 1103.7
GateUp 32 4096 5760 2880 943.7 1318.1
Down 4 2048 2880 2880 602.4 769.9
Down 4 4096 2880 2880 846.4 1069.4
Down 32 2048 2880 2880 652.1 996.8
Down 32 4096 2880 2880 870.5 1204.0

Copilot AI review requested due to automatic review settings May 7, 2026 15:33
Copilot AI left a comment

Pull request overview

This PR adds a new Turbo MXFP8 (MX_BLOCKWISE) grouped GEMM path targeting gfx950 (MI350/MI355) to accelerate MoE FFN workloads, including fused per-group padding/quantization and dedicated wgrad (variable-K) support.

Changes:

  • Added gfx950 Turbo grouped GEMM kernels (fwd/dgrad) plus a variable-K Turbo kernel for wgrad, along with new PyTorch extension entry points.
  • Added fused MXFP8 quantization ops for grouped inputs (including per-group M padding) to keep the flow async and avoid host syncs.
  • Added MXFP8 Turbo wrapper + parameter-sweep test coverage and a benchmark option.

Reviewed changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 2 comments.

File Description
tests/pytorch/ops/test_grouped_gemm_fp8.py Adds gfx950-only MX_BLOCKWISE Turbo backend test coverage and ensures E8M0 scale dtype configuration.
primus_turbo/pytorch/ops/grouped_gemm_fp8.py Introduces GroupedGemmFP8MXFunc wrapper for MX_BLOCKWISE (handles trans_b=False and per-group padding) and calls new Turbo ops.
primus_turbo/pytorch/kernels/quantization/quantization_impl.py Adds Python wrapper for fused grouped MXFP8 dual quantization with per-group M padding.
primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py Registers Turbo backend entries for MX_BLOCKWISE grouped GEMM and variable-K wgrad dispatch.
csrc/pytorch/quantization/quantization.cpp Implements quantize_mxfp8_dual_perg and quantize_mxfp8_dual_grouped (grouped + padded) C++ entry points.
csrc/pytorch/quantization/quantization_meta.cpp Adds Meta kernels for the new MXFP8 quantization ops to support shape inference.
csrc/pytorch/grouped_gemm/turbo_grouped_gemm.cpp Adds PyTorch bindings for Turbo MXFP8 grouped GEMM and Turbo variable-K wgrad entry points.
csrc/pytorch/extensions.h Declares new Turbo grouped GEMM and MXFP8 quantization APIs.
csrc/pytorch/bindings_pytorch.cpp Registers new ops (quantize_mxfp8_dual_grouped, quantize_mxfp8_dual_perg, Turbo grouped GEMMs) for CUDA and Meta.
csrc/kernels/quantization/quantization_mxfp8.cu Extends MXFP8 quantization kernels for per-group/per-layout variants and adds per-group padded-layout computation.
csrc/kernels/grouped_gemm/turbo/turbo_grouped_gemm_mxfp8_kernel.h Adds Turbo MXFP8 grouped GEMM persistent kernel (fwd/dgrad) implementation for gfx950.
csrc/kernels/grouped_gemm/turbo/turbo_grouped_gemm_mxfp8_wgrad_kernel.h Adds Turbo MXFP8 variable-K grouped GEMM persistent kernel for wgrad on gfx950.
csrc/kernels/grouped_gemm/turbo_grouped_gemm.cu Adds grouped-GEMM-specific preshuffle + workspace handling and launches Turbo grouped GEMM kernels.
csrc/include/primus_turbo/quantization.h Adds constants and declares new MXFP8 grouped quantization + padded-layout helpers.
csrc/include/primus_turbo/grouped_gemm.h Adds public parameter structs and APIs for Turbo MXFP8 grouped GEMM and variable-K wgrad.
csrc/include/primus_turbo/device/register.cuh Strengthens inline-asm clobbering to improve race-freeness/ordering assumptions.
csrc/include/primus_turbo/device/mfma.cuh Adds memory clobbers to MFMA inline asm to match intended synchronization semantics.
csrc/include/primus_turbo/device/memory.cuh Adds VGPR clobbers after LDS reads to align with race-freeness assumptions.
benchmark/ops/bench_grouped_gemm_turbo.py Adds an mxfp8 benchmark configuration and CLI option for comparison runs.

Comment thread primus_turbo/pytorch/ops/grouped_gemm_fp8.py Outdated
Comment on lines +502 to +509
// ``quantize_mxfp8_dual_perg``: per-group dual quant for the grouped GEMM
// B-side. Input is 3D ``(G, M, N)``; per-group M/N are uniform. The
// kernel writes:
// rowwise_output: ``(G, M_pad, N_pad)`` -- naturally fits ``(G, N_of_B,
// K_pad)`` for the fwd GEMM.
// colwise_output: ``(G, N, M_pad)`` -- naturally fits ``(G, K_of_B,
// N_pad)`` for the dgrad GEMM
// (no separate regroup).
if constexpr (Bytes == 16) {
clobber_vgpr_one<VDST + 2>();
clobber_vgpr_one<VDST + 3>();
}
Collaborator

Why do we need to add clobber_vgpr_one here for ds_read? The related registers should already be clobbered before the ds_read.

Collaborator Author

Tried removing them — 100% race in grouped MXFP8 dgrad (max diff inf) and 0.7% race in dense MXFP8 mxgemm at 4096³, exactly matching the historical race profile we hit before.

The pre-ds_read clobber_vgpr_one<> reserves the dst VGPRs from being reused, but doesn't fence the compiler from sinking the consuming MFMA above the ds_read. The post-asm clobber on the dst VGPR is what creates that ordering edge.

Collaborator

  1. How was the 0.7% race obtained?

  2. In the Turbo MXFP8 GEMM, there should be no race in any phase between MFMA and DS. MFMA computes the current block, while DS loads the data for the next block. Any cross-block race has already been constrained using __builtin_amdgcn_s_barrier.

Collaborator

The role of clobber_vgpr_one is to allocate/register VGPRs. It cannot prevent the compiler from using them.

Collaborator

Additionally, my personal understanding is that using clobber_vgpr_one on ds_read is not a semantically correct way to enforce ordering, because it only operates at the level of register allocation and does not involve modeling of memory dependencies or execution ordering. Any apparent “correctness effect” is likely just an indirect result of changes in compiler scheduling behavior, rather than a reliable ordering guarantee. Race conditions should instead be addressed using standard synchronization and memory-dependency mechanisms such as s_waitcnt and barriers.

Collaborator Author

A bug seems to have caused this result; the main branch shows no race. I will take time to track it down.

Collaborator Author

The grouped GEMM race is triggered by compiler-emitted SGPR spill: the spill reload uses vmcnt, which gfx950 does not FIFO-order against buffer_load_lds, so a stale m0 can leak into the LDS write address. The latest commits both reduce spill (CK refactor + smem packing + -sink-insts-to-avoid-spills; private_segment drops 44 → 20 bytes) and move the prologue LDS commit onto lgkmcnt (2-step buffer_load → ds_write), decoupling it from the spill's vmcnt traffic — race goes to zero.

Tested on both the GEMM kernel and the grouped GEMM kernels; the race condition no longer occurs.

Collaborator

Why do we need the two-step path? It copies the data through VGPR first and then into SMEM, right? If so, it seems to break the whole design philosophy here:

  1. It breaks the counter separation model.

vmcnt is supposed to track only buffer_load_lds (GMEM→LDS)
lgkmcnt is supposed to track only ds_read_pinned (LDS→VGPR)

The two-step path introduces ds_write into lgkmcnt, which muddies the semantics of the two counters.

  2. The two-step path is effectively a reverse optimization.

The original purpose of buffer_load_lds was zero-VGPR buffering plus fully asynchronous transfer. The two-step path regresses into GMEM→VGPR→LDS, which either consumes extra VGPRs or sacrifices asynchronicity.

  3. SGPR resources on the hardware should already be extremely abundant.

Collaborator

“CK refactor” ?

Collaborator Author

The 2-step prologue path is removed, and VGPR / SGPR spill is fully eliminated — all kernel specs now show sgpr_spill = 0 / vgpr_spill = 0 / private_segment = 0, and scratch_load / scratch_store are completely gone from disasm. The vmcnt race root cause is fixed at the source.

But clobber_vgpr_one / clobber_agpr_one still need to stay:

  • The spill race and the clobber race are two different things. Clobber prevents compiler reordering: ds_read_pinned and MFMA both use "n"(VDST) constant operands that encode the register name into the asm string, so the compiler sees no def-use chain. Per GCC: two asm volatile statements are not ordered with respect to each other. Without the clobbers, the compiler is free to hoist the MFMA above the ds_read that feeds it, so the MFMA reads a stale VGPR.
  • Verified empirically: with spill at 0, removing both clobbers still races — grouped fwd 5.5%, wgrad 0.05%.

else
asm volatile("v_mfma_scale_f32_16x16x128_f8f6f4 v[%0:%1], v[%2:%3], v[%4:%5], v[%0:%1], v[%6], v[%7] op_sel_hi:[0,0,0] cbsz:1 blgp:1"
- : : "n"(PIN_ACC), "n"(PIN_ACC + 3), "n"(PIN_A), "n"(PIN_A + 7), "n"(PIN_B), "n"(PIN_B + 7), "n"(PIN_SA), "n"(PIN_SB));
+ : : "n"(PIN_ACC), "n"(PIN_ACC + 3), "n"(PIN_A), "n"(PIN_A + 7), "n"(PIN_B), "n"(PIN_B + 7), "n"(PIN_SA), "n"(PIN_SB) : "memory");
Collaborator

v_mfma_scale_f32_16x16x128_f8f6f4 only modifies registers, so I don't think we need the "memory" here.

Collaborator Author

Removed all "memory" clobbers.

static_assert(AGPR >= 0 && AGPR <= 255, "AGPR must be in [0, 255]");
// clang-format off
- #define CLOBBER_AREG_CASE(N) case N: asm volatile("" ::: "a" #N); break;
+ #define CLOBBER_AREG_CASE(N) case N: asm volatile("" ::: "a" #N, "memory"); break;
Collaborator

Same as above.

template <int AC> __device__ __forceinline__ void zero_agpr() {
- asm volatile("v_accvgpr_write_b32 a[%0], 0" : : "n"(AC));
+ asm volatile("v_accvgpr_write_b32 a[%0], 0" : : "n"(AC) : "memory");
clobber_agpr_one<AC>();
Collaborator

Same as above.

if constexpr (N >= 13) asm volatile("v_accvgpr_read_b32 %0, a[%1]" : "=v"(raw[12]) : "n"(AC + 12) : "memory");
if constexpr (N >= 14) asm volatile("v_accvgpr_read_b32 %0, a[%1]" : "=v"(raw[13]) : "n"(AC + 13) : "memory");
if constexpr (N >= 15) asm volatile("v_accvgpr_read_b32 %0, a[%1]" : "=v"(raw[14]) : "n"(AC + 14) : "memory");
if constexpr (N >= 16) asm volatile("v_accvgpr_read_b32 %0, a[%1]" : "=v"(raw[15]) : "n"(AC + 15) : "memory");
Collaborator

Same as above.

typename GemmTile::AScaleSmemSubtile (*a_s_smem_tile)[4],
typename GemmTile::BScaleSmemSubtile (*b_s_smem_tile)[4], const AType *a_ptr,
const BType *b_ptr, const uint32_t *a_s_ptr, const uint32_t *b_s_ptr, CType *c_ptr,
const int64_t *group_offs_ptr, const int64_t *c_group_offs_ptr, const int32_t group_id,
Collaborator

Naming is asymmetric: group_offs_ptr is actually input-side (padded) offsets while c_group_offs_ptr is output-side (compact). Consider renaming to group_offs_in_ptr / group_offs_out_ptr (or a_group_offs_ptr / c_group_offs_ptr) so the asymmetry is visible at the callsite.

Collaborator

Just realized that in real workloads the inputs probably always need to be padded anyway. In that case, it may be cleaner to stop overloading the presence of c_group_offs_ptr as a mode switch, and instead always pass explicit a_group_offs_ptr, b_group_offs_ptr, and c_group_offs_ptr. The interface becomes much more straightforward.

Collaborator Author

Done. Launcher now always passes a non-null c_group_offs_ptr (caller's compact-output offsets when supplied; a_group_offs_ptr as fallback for the in-place layout). Kernel takes a single straight-line extract path with no mode switch on the pointer.

On b_group_offs_ptr: B is laid out as [G, N, K] and indexed by group_id directly — no per-group row offsets for B, so no third pointer to add.

// per-group region to 128-aligned, so this matches
// group_offs_padded[g+1] - group_offs_padded[g] without an
// extra scalar load.
M_g_in = (M_g + 127) & ~127;
Collaborator

Is 127 hard-coded here?

static_assert(VGPR >= 0 && VGPR <= 255, "VGPR must be in [0, 255]");
// clang-format off
- #define CLOBBER_VREG_CASE(N) case N: asm volatile("" ::: "v" #N); break;
+ #define CLOBBER_VREG_CASE(N) case N: asm volatile("" ::: "v" #N, "memory"); break;
Collaborator

Same as above.

granularity: ScalingGranularity,
num_cu: int | None,
**kwargs,
) -> bool:
Collaborator

Do we need a shape check here?

Collaborator Author

Done. Added n_fwd % 16 == 0 and k_fwd % 16 == 0 (the wgrad MFMA tile is 16×16, so both output axes must be multiples of 16).

Comment thread primus_turbo/pytorch/ops/grouped_gemm_fp8.py Outdated
Copilot AI review requested due to automatic review settings May 12, 2026 05:06
@kyle-256 kyle-256 force-pushed the dev/kyle_mxfp8_gg_pr branch from 78ccce9 to 847872c Compare May 12, 2026 05:12
Copilot AI left a comment

Pull request overview

Copilot reviewed 18 out of 18 changed files in this pull request and generated 3 comments.

Comment thread csrc/pytorch/quantization/quantization.cpp
Comment thread csrc/pytorch/quantization/quantization_meta.cpp
Comment thread csrc/pytorch/quantization/quantization_meta.cpp
kyle-256 pushed a commit that referenced this pull request May 12, 2026
…ler spill race

Address PR #330 review #12 properly: launcher always passes a non-null
c_group_offs_ptr (caller's compact-output offsets, or a_group_offs as
fallback for the in-place layout); kernel takes a single straight-line
extract path with no mode switch.

Side effect: removing the if-else collapses the SSA phi-merge that had
been sharing c_m_global / M_g_in register slots with m_global / M_g.
Peak SGPR pressure across the LDS-direct buffer_load_lds prologue then
crosses the per-wave limit (~100), and the compiler spills via
scratch_store/load + readfirstlane.  The reload's vmcnt-tracked
completion can lose to a younger buffer_load on the shared counter
(no FIFO guarantee), so a stale m0 leaks in and the buffer writes to
the wrong LDS slot.  Empirically reproduces at the gpt_oss-20B Expert
shape (B=4 M=2048 N=5760 K=2880) at ~5% fwd in 10K reruns; tested 9
variants at the C++ level (assertions, fences, __builtin_assume,
sched_barrier, deferral, formulation rewrites) — none recover baseline
race rate.

Workaround: 4 explicit s_waitcnt vmcnt(0) drains in the prologue
between the first 4 buffer_load_lds calls, fencing the spill reload
against the m0 setup.  Empirically only the first four are needed
(SGPR pressure peak); the rest run on already-flushed pressure.
0/20000 fwd reruns at gpt_oss-20B Expert + 0/10000 dense MXFP8.
+4% fwd latency vs. baseline (338us vs 325us at the same shape).

Filing the underlying compiler issue (vmcnt scoping for spill consumes
against buffer_load_lds) separately.
@kyle-256
Collaborator Author

@xiaobochen-amd Thanks for the commit! I have updated the code to address the review comments. Please review again, thanks.

Copilot AI review requested due to automatic review settings May 14, 2026 06:55
@kyle-256 kyle-256 force-pushed the dev/kyle_mxfp8_gg_pr branch from 0df887d to 07e9940 Compare May 14, 2026 06:55
Copilot AI left a comment

Pull request overview

Copilot reviewed 20 out of 20 changed files in this pull request and generated 10 comments.

Comment on lines +151 to +152
"alignment); also required: each per-group M_g % 128 == 0.");

Comment on lines +74 to +86
// hint we use it; otherwise fall back to the pessimistic upper bound
// derivable from input metadata alone (no D2H sync of group_lens).
// Worst case: a single group holds all real rows, then per-group
// padding to MXFP8_PADDING_ALIGN_SIZE adds (group_num - 1)*ALIGN
// padding rows on top.
int32_t grid_x;
if (grid_x_hint > 0) {
grid_x = (int32_t) grid_x_hint;
} else {
constexpr int64_t ALIGN = primus_turbo::detail::MXFP8_PADDING_ALIGN_SIZE;
const int64_t total_m_upper =
((int64_t) total_m_in + group_num * ALIGN + ALIGN - 1) / ALIGN * ALIGN;
grid_x = (int32_t) ((total_m_upper + 255) / 256);
Comment on lines +505 to +506
// rowwise_output: ``(G, M_pad, N_pad)`` -- naturally fits ``(G, N_of_B,
// K_pad)`` for the fwd GEMM.
//
// Per-group dB[g] = LHS[g] @ RHS[g]^T, with LHS (N, total_M), RHS
// (K, total_M), dB (group_num, N, K); reduction is over total_M.
// Constraints: n % 16 == 0, k % 16 == 0, M_g % 32 == 0.
Comment on lines +622 to +627
# MFMA tile = 16x16; output axes (N_fwd, K_fwd) must clear that.
if a.dim() == 2 and b.dim() == 2:
n_fwd = a.shape[0]
k_fwd = b.shape[1]
supported &= n_fwd % 16 == 0 and k_fwd % 16 == 0
supported &= a.shape[1] == b.shape[0]

class GroupedGEMMFP8VariableKKernelDispatcher(BaseGroupedGEMMVariableKKernelDispatcher):
_backends = {
BackendType.TURBO: BackendEntry(GroupedGEMMFP8VariableKTurboBackend),
Comment on lines +510 to +511
@pytest.mark.parametrize("balance", BALANCE_VALUES)
def test_grouped_gemm_fp8_mx_blockwise(B, M, NK, ori_dtype, format, trans_b, balance):
ScalingRecipe colwise_recipe, hipStream_t stream);

// Per-group dual quant for the MXFP8 grouped GEMM B-side: writes the
// rowwise output as ``(G, M_pad, N_pad)`` and the colwise output as
kyle-256 and others added 12 commits May 16, 2026 01:02
…-K wgrad

- Add hand-tuned 256x256x128 NT MXFP8 grouped GEMM kernels for gfx950:
  forward (turbo_grouped_gemm_mxfp8_kernel.h) and variable-K wgrad
  (turbo_grouped_gemm_mxfp8_wgrad_kernel.h), plus host-side dispatch and
  PyTorch bindings (turbo_grouped_gemm_fp8, turbo_grouped_gemm_variable_k_fp8,
  turbo_preshuffle_mxfp8_scale_16x4).
- Wire MX_BLOCKWISE granularity into grouped_gemm_fp8 via a new
  GroupedGemmFP8MXFunc autograd Function covering fwd + dgrad + wgrad,
  routed through the turbo backends.
- Extend benchmarks and tests with 384 MXFP8 cases (all passing on MI355).

Made-with: Cursor
…rf and fix uint32 num_records overflow

This squash combines six iterative changes that bring grouped GEMM mxfp8
(forward + variable-K wgrad) on par with the single-GEMM mxfp8 kernel and
fix a correctness bug at huge total_M, plus a test-hygiene cleanup.

Changes
-------
1. Synchronization trim across single + grouped (fwd & wgrad):
   - Drop redundant `wait_vmcnt`/`s_barrier` and AGPR/VGPR clobbers that
     had been added defensively but are unnecessary once the device-header
     memory clobbers are in place.
   - Mainloop end-of-iter drain stays at `wait_vmcnt<12>` (looser was
     racy at 100%).
   - Drop the inter-tile `s_barrier` between persistent-loop tiles — the
     prologue's `wait_vmcnt<0>; s_barrier` already provides cross-tile
     ordering for SMEM.

2. C-store simplification: replace the dual-buffer + intermediate
   `wait_vmcnt` C-store with the single-buffer pattern used by single
   GEMM.  Same number of stores, less code.

3. Tile reuse: hoist `GemmTile tile{...}` out of the per-tile compute
   body so the persistent loop reuses one instance.

4. Mid-Epi2 split drain (mirror single GEMM): move the buffer-drain
   `wait_vmcnt<0>` from end-of-Epi1 to mid-Epi2 (after Epi2 phase 2),
   with `wait_vmcnt<6>` at end-of-Epi1 instead.  This lets Epi1 phase
   3+4 mfma_lds and Epi2 phase 1+2 overlap with the trailing 6
   buffer_load_lds GMEM→LDS DMAs from Epi1's prefetch.  Counter-intuitively
   improves wgrad determinism (race ~0.3% → ~0.01%) by giving the
   LDS-write commits an extra sync point before Phase 3 reads `next`.

5. wgrad BufferSRD num_records uint32 overflow fix: `(n - pid_n) *
   total_m * sizeof(...)` was computed in uint32; for shapes with
   total_M ≥ ~600K and either N or K large (e.g. Kimi-K2 EP=8 →
   B=48 M=16384, total_M=786K), the product (5.25 GB) silently
   truncated to 1.25 GB.  buffer_load_lds then masked valid in-tile
   addresses as OOB and returned zeros, dropping b_grad SNR to ~18 dB
   on those cases.  Compute in uint64 and clamp to UINT32_MAX before
   assigning to BufferSRD's 32-bit num_records.  FWD does not need the
   same fix: per-group M_g ≤ 16384 keeps M_g*k and n*k well below 4 GB.

6. Test hygiene: drop the silent 128-rounding workaround in
   `_run_grouped_gemm_fp8_test` that masked the kernel's M_g % 128 == 0
   constraint, and restrict `test_grouped_gemm_fp8_mx_blockwise` to
   `balance=True` only.  Unbalanced groups whose per-group M_g is not a
   128-multiple are not a supported configuration for the mxfp8 wgrad
   kernel (preshuffled scale layout requires alignment) and should be
   enforced at the wrapper level rather than worked around in tests.
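
The truncation in item 5 can be reproduced with integer arithmetic alone. The sketch below models the uint32 wraparound and the uint64-plus-clamp fix; the function names are hypothetical, and `rows = 7168` with 1-byte FP8 elements is an assumption chosen because it reproduces the 5.25 GB product quoted above, not a value stated in the commit.

```python
UINT32_MAX = 0xFFFF_FFFF


def num_records_u32_wrapped(rows: int, total_m: int, elem_bytes: int = 1) -> int:
    """Model the old behavior: the product is computed in uint32 and wraps."""
    return (rows * total_m * elem_bytes) & UINT32_MAX


def num_records_clamped(rows: int, total_m: int, elem_bytes: int = 1) -> int:
    """Model the fix: compute in 64 bits, then clamp to the 32-bit field."""
    return min(rows * total_m * elem_bytes, UINT32_MAX)


total_m = 48 * 16384  # B=48, M=16384 from the commit message -> 786432
rows = 7168           # assumed value of (n - pid_n); see the note above

full = rows * total_m
print(full / 2**30)                                    # 5.25
print(num_records_u32_wrapped(rows, total_m) / 2**30)  # 1.25
print(num_records_clamped(rows, total_m))              # 4294967295
```

With the wrapped value, any address past 1.25 GB into the buffer would be treated as out of bounds even though it is valid, which matches the zero-fill symptom described in item 5.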

Validation
----------
bench_grouped_gemm_turbo --dtype fp8 --granularity mxfp8: 288/288 PASS.

rocprofv3 kernel-only timings (3-run min) on representative shapes:
  Single GEMM mxfp8 vs Triton tw, avg mx/tw = 0.948 (5% faster).
  Grouped FWD  mxfp8 vs Triton tw, avg mx/tw = 0.967 (3% faster).
  Grouped WGRAD mxfp8 (no tw baseline; ≤ FWD on per-tile compute).

25 000-iter race stress (serial):
  Single GEMM:  147/25000 = 0.588% (baseline)
  Grouped FWD:    7/25000 = 0.028%   (~21× more deterministic)
  Grouped WGRAD:  2/25000 = 0.008%   (~73× more deterministic)

Co-authored-by: Cursor <cursoragent@cursor.com>
…mbos, zero D2H sync

Make ``grouped_gemm_fp8(..., granularity=MX_BLOCKWISE)`` accept the same
``@pytest.mark.parametrize`` matrix as the TENSORWISE path (B / M / NK /
ori_dtype / format / trans_b / balance — minus Format.HYBRID), including
the previously-unsupported ``trans_b=False`` and ``balance=False``.

The latter two require per-group M-axis zero-padding to multiples of 128
because the turbo MXFP8 kernel mandates ``M_g % 128 == 0``.  All padding,
quantisation and post-extraction logic now lives in C++ / HIP so the
hot path issues *zero* device-to-host syncs (verified via CUDA Graph
capture of the full fwd+bwd loop).

Two new C++ ops, mirroring the single-MXFP8 pair
(``quantize_mxfp8_dual`` + ``quantize_mxfp8``):

  * ``quantize_mxfp8_dual_grouped(input, group_lens, group_offs, dtype,
    ...)`` — single fused kernel that
        1. computes ``group_lens_padded`` / ``group_offs_padded`` on GPU
           (single-thread layout kernel, fully async),
        2. allocates outputs at the host-known upper bound
           ``ceil((total_M + G*128) / 128) * 128`` (over-alloc by at
           most ``G*127`` rows — a few KB),
        3. materialises the padded layout directly in the FP8 / scale
           outputs (no intermediate ``torch.zeros`` + per-group
           ``copy_``),
       and returns ``[rowwise_fp8, rowwise_scale, colwise_fp8,
       colwise_scale, group_lens_padded, group_offs_padded]``.

  * ``extract_grouped_rows(x_padded, group_offs_orig, group_offs_padded,
    total_M_orig)`` — single-kernel replacement for the
    ``G x .copy_()`` loop that strips per-group padding from the GEMM
    output.  Tile shape is ``ROWS_PER_BLOCK=8`` rows × ``num_vec``
    float4 to keep the launch grid small on large M.
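
The layout arithmetic behind the two ops above can be sketched in plain Python. This is an illustrative host-side reference only, not the HIP implementation: the helper names mirror the op names for readability, the row payloads are arbitrary Python objects, and the 128-row alignment is taken from the ``M_g % 128 == 0`` constraint stated earlier.

```python
from itertools import accumulate

ALIGN = 128  # per-group M alignment required by the turbo MXFP8 kernel


def round_up(x, align=ALIGN):
    """Round x up to the next multiple of align."""
    return (x + align - 1) // align * align


def padded_layout(group_lens, align=ALIGN):
    """Return (group_lens_padded, group_offs_padded) for the given group sizes."""
    lens_padded = [round_up(m, align) for m in group_lens]
    offs_padded = [0] + list(accumulate(lens_padded))
    return lens_padded, offs_padded


def alloc_upper_bound(total_m, num_groups, align=ALIGN):
    """Host-known bound: ceil((total_M + G*align) / align) * align."""
    return round_up(total_m + num_groups * align, align)


def extract_grouped_rows(x_padded, group_offs_orig, group_offs_padded):
    """Strip per-group padding rows; reference for the single-kernel extract."""
    out = []
    for g in range(len(group_offs_orig) - 1):
        m_g = group_offs_orig[g + 1] - group_offs_orig[g]
        start = group_offs_padded[g]
        out.extend(x_padded[start:start + m_g])
    return out
```

For ``group_lens = [100, 128, 0, 300]`` the padded lengths are ``[128, 128, 0, 384]``, and the host-side bound only needs ``total_M`` and ``G``, which is why no device-to-host copy of ``group_lens`` is required to size the outputs.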

The fused quant kernel handles out-of-bounds tiles (over-allocated
region, 0-1.6% of CTAs) by routing them through the existing
``amax == 0 -> scale = E8M0_EXPONENT_BIAS, fp8 = 0`` branch already
used for shuffle-padding rows.  This keeps the wgrad preshuffle kernel
race-free without modifying the turbo grouped-GEMM kernel itself.
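
A minimal sketch of the ``amax == 0`` branch, assuming the scale follows the OCP MX recipe (shared exponent = floor(log2(amax)) minus the element format's maximum exponent, stored with a bias of 127). The function names and the E4M3 ``emax = 8`` default are illustrative assumptions; the real kernel's rounding may differ.

```python
import math

E8M0_EXPONENT_BIAS = 127  # biased exponent that encodes a scale of 2**0
E4M3_EMAX = 8             # largest E4M3 exponent (448 = 1.75 * 2**8)


def e8m0_scale_byte(amax, elem_emax=E4M3_EMAX):
    """Biased E8M0 exponent for a block with the given finite abs-max."""
    if amax == 0.0:
        # All-zero block (padding or out-of-bounds tile): neutral scale.
        return E8M0_EXPONENT_BIAS
    _, e = math.frexp(amax)           # amax = m * 2**e with 0.5 <= m < 1
    shared_exp = (e - 1) - elem_emax  # floor(log2(amax)) - emax
    return max(0, min(254, shared_exp + E8M0_EXPONENT_BIAS))


def e8m0_decode(byte):
    """Scale value represented by a biased E8M0 exponent."""
    return 2.0 ** (byte - E8M0_EXPONENT_BIAS)
```

Under these assumptions an all-zero block decodes to a scale of 1.0 with zero payload, so over-allocated tiles contribute nothing to the GEMM.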

Misc:

  * ``GroupedGemmFP8MXFunc`` calls ``turbo_grouped_gemm_fp8`` and
    ``turbo_grouped_gemm_variable_k_fp8`` directly with a host-computed
    ``grid_x_hint`` (saves ~10-20us / call vs the dispatcher path and
    skips the GEMM op's internal ``group_lens.cpu()`` D2H sync).
  * Removed the silent ``group_lens % 128`` rounding workaround from
    ``test_grouped_gemm_fp8_mx_blockwise`` — the wrapper now handles
    arbitrary group sizes.
Co-authored-by: Cursor <cursoragent@cursor.com>
Three perf-focused changes folded into one commit, all targeting the
MXFP8 grouped GEMM hot path on gfx950:

1. Chunked LDS-transpose preshuffle.

   The original ``turbo::preshuffle_scale_16x4_kernel`` reads each 16-row
   input tile with stride-``cols`` 1-byte loads (one warp covers a (16,4)
   sub-tile via ``in[(tid%16)*cols + tid/16]``), which clocks at roughly
   5-10% of HBM peak.  On gpt_oss-style MoE shapes the grouped-GEMM hot
   path issues three of these per step (fwd A+B, dgrad ∇y+B, wgrad
   LHS+RHS) and they account for ~7-10% of the per-step cost.

   Added a drop-in chunked variant
   (``preshuffle_scale_16x4_v2_kernel`` /
   ``preshuffle_scale_16x4_dual_v2_kernel``) that stages a (32 cols × 16
   rows) tile through LDS as 32-byte coalesced loads then writes the
   transposed (16, 4)-shuffled scale block.  This lifts the kernel from
   ~351us/step to ~95us/step on B=32, M=4096, K=4096 grouped MXFP8.

2. Fused extract pass into MXFP8 grouped GEMM store.

   The MXFP8 grouped GEMM previously ran the GEMM on the per-group-
   padded input layout, then ran a separate ``extract_grouped_rows``
   kernel to strip per-group zero-pad rows back out.  This commit folds
   the extract into the GEMM store: the kernel now reads inputs from
   the padded layout (via ``group_offs_ptr``) but writes outputs
   directly to the unpadded ``[total_M_orig, N]`` tensor at row indices
   given by the new optional ``c_group_offs_ptr``.

   Plumbing:
   - ``TurboGroupedGemmMXFP8Params`` gains ``c_group_offs_ptr``;
     ``total_m`` is the input layout's row count.
   - ``turbo_grouped_gemm_fp8`` op signature gains ``c_group_offs`` and
     ``total_m_out``.  Non-MX backends pass ``None`` / ``0`` and behave
     identically.
   - The MX autograd ``Function`` passes ``c_group_offs=group_offs``
     (original unpadded) + ``total_m_out=a.size(0)`` for fwd and dgrad;
     wgrad keeps the padded layout (its output is per-group-3D, no
     compression).

   Correctness: when ``c_group_offs_ptr != nullptr`` the input/scale SRD
   bounds use ``M_g_in = ceil(M_g, 128)`` so reads cover full 16-row
   blocks of preshuffled scale data; the original ``M_g`` continues to
   gate compute early-exit and output writes (no padding rows leak to
   the output).  ``M_g_in`` is derived with a single ALU bit-mask
   instead of a separate ``group_offs_padded[g+1] - [g]`` load, saving
   one scalar load + readfirstlane per CTA setup.

   Net effect on K=1536 cases (where ``extract_grouped_rows`` previously
   contributed ~400us/step): the extract kernel disappears entirely and
   most K=4096 / K=1536 grouped MXFP8 cases now hit or beat the
   tensorwise+triton baseline.
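
The ``M_g_in`` derivation and the store-side gating described in item 2 can be checked with a small Python model. This is a sketch under the assumption that the alignment is a power of two (128 here); the function names are hypothetical and do not appear in the kernel.

```python
def round_up_mask(m_g, align=128):
    """Round m_g up to a multiple of align using one add and one bit-mask."""
    assert align & (align - 1) == 0, "mask form needs a power-of-two align"
    return (m_g + align - 1) & ~(align - 1)


def padded_offsets(group_lens, align=128):
    """Prefix sums of the padded group lengths (input layout offsets)."""
    offs = [0]
    for m_g in group_lens:
        offs.append(offs[-1] + round_up_mask(m_g, align))
    return offs


def output_row(g, r, group_lens, c_group_offs):
    """Unpadded output row for local row r of group g, or None for padding."""
    if r >= group_lens[g]:
        return None  # padding row: computed on the padded input, never stored
    return c_group_offs[g] + r
```

The mask form yields the same value as the ``group_offs_padded[g+1] - group_offs_padded[g]`` delta, so the kernel can size its input bounds from ``M_g`` alone.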

All 21,728 grouped_gemm_fp8 parametrized test cases pass; CUDAGraph
capture confirms zero D2H sync.

Co-authored-by: Cursor <cursoragent@cursor.com>
The MXFP8 grouped GEMM B-side quant used to take a 2D ``(G*N, K)`` flat
input through ``quantize_mxfp8_dual`` (producing rowwise + colwise
outputs in flat ``(G*N, K_pad)`` / ``(K, G*N_pad)`` layout) and then
reshape the colwise output into the ``(G, K, N_pad)`` layout the dgrad
GEMM expects via a torch transpose+contiguous chain.  The regroup
chained ``view.reshape(K, G, N).transpose(0, 1).contiguous()`` (+ tail
zero-pad), landing in torch's elementwise copy at ~30% HBM peak; on
gpt-oss-20B Down B=32 N=K=2880 this took ~360us per step.

This commit replaces both passes with a single per-group dual quant
``quantize_mxfp8_dual_perg``, taking the 3D ``(G, N, K)`` input
directly and writing rowwise as ``(G, N, K_pad)`` and colwise as
``(G, K, N_pad)`` in one kernel.  The kernel body is unchanged: we
just add five per-blockIdx.z stride parameters to
``quantize_mxfp8_dual_kernel`` (default 0 for existing callers, so this
is a no-op for them) and a thin host launcher that emits ``grid_z = G``
with the right per-group strides.  Per-group M / N / M_pad / N_pad are
uniform across G; the kernel's existing ``global_row < M`` guards on
the row write side handle within-group M_pad-M padding rows.
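
The equivalence between the old regroup chain and the per-group write can be illustrated on nested lists. This sketch leaves out quantisation and padding (it assumes N and K are already aligned) and only checks the index mapping; all function names are illustrative.

```python
def flat_colwise(x):
    """Old pass 1: flatten (G, N, K) to (G*N, K), then transpose to (K, G*N)."""
    rows = [row for x_g in x for row in x_g]
    return [list(col) for col in zip(*rows)]


def regroup_colwise(flat, G, N):
    """Old pass 2: (K, G*N) viewed as (K, G, N), transposed to (G, K, N)."""
    K = len(flat)
    return [[[flat[k][g * N + n] for n in range(N)] for k in range(K)]
            for g in range(G)]


def per_group_colwise(x):
    """New path: transpose each (N, K) group in place of the regroup."""
    return [[list(col) for col in zip(*x_g)] for x_g in x]
```

Both paths produce element ``[g][k][n] = x[g][n][k]``, which is why the separate transpose+contiguous copy can be dropped once the kernel indexes by group.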

Measured impact (rocprofv3, gpt-oss-20B sweep, MX vs tensorwise+triton):
  - median MX-vs-TW overhead: +18.8% -> +12.8%
  - M=4096 cases hit or beat tensorwise+triton outright
  - Down B=32 M=2048: +38.6% -> +29.4% (regroup 360us -> 0)

M=2048 cases remain bound by the persistent grouped GEMM kernel itself
(MX fwd+dgrad lands ~30% off peak vs tensorwise+triton 57%, mostly
because the per-block scale loads inflate the K=2944 inner loop).

Co-authored-by: Cursor <cursoragent@cursor.com>
…ted)

- #2-#6 drop redundant ': "memory"' clobbers from MFMA + clobber_*_one /
  zero_agpr / read_agpr asm — registers don't access memory; clobber lists
  carry the dependency.  Verified 0/3000 grouped + 3/3000 dense reruns vs
  PR baseline 2/3000 + 3/3000 at gpt_oss-20B Expert shape.

- #7 grid_x: drop the D2H sync of group_lens.  C++ launcher computes the
  pessimistic upper bound from input metadata (total_m_in + group_num*ALIGN)
  directly; ~1 cycle of scalar arithmetic vs a ~5us PCIe round-trip.  Python wrapper's
  _mxfp8_grid_x_hint helper deleted with all callers and ctx state.

- #8 rename group_offs_ptr -> a_group_offs_ptr in MXFP8 paths to disambiguate
  from c_group_offs_ptr (output offsets).  CK / hipBLASLt struct fields
  unchanged.

- #9 use detail::MXFP8_PADDING_ALIGN_SIZE instead of the literal 127 in the
  M_g_padded round-up; matches the constant the upstream quant op uses.

- #10 GroupedGEMMFP8VariableKTurboBackend.can_handle now gates on the
  16-multiple shape requirement of the wgrad MFMA (n_fwd, k_fwd) and
  contraction-axis agreement, so unsupported shapes fall through to CK
  instead of asserting inside the kernel.

- #11 ctx.trans_b_orig -> ctx.trans_b — the orig was a leftover from an
  earlier refactor; only one trans_b is now persisted.

REJECTED on empirical grounds:
- #1 (drop post-asm clobber_vgpr_one<> after ds_read_pinned).  Reviewer
  argued these are redundant since callers reserve the dst VGPR range
  upfront; in practice removing them races 100% in grouped MXFP8 dgrad
  (max diff inf) and 0.7% in dense MXFP8 mxgemm at 4096^3 — exact match
  for the historical race profile.  The reservation prevents the dst
  VGPRs from being reused but does NOT prevent the compiler from
  reordering ds_read past the consuming MFMA; only the per-asm clobber
  list does.  Keeping the calls.

OBSOLETED:
- #12 (always-non-null c_group_offs_ptr; derive M_g_in from
  a_group_offs delta).  The delta-load form raced ~3% in fwd at the
  gpt_oss-20B Expert shape; the simpler ceil(M_g, ALIGN) form (review
  #9 above) lands the same value with no extra scalar load.  The
  if-nullptr branch in the kernel is one scalar test and stays — the
  cleanup wasn't worth the delta-load races.
…ed GEMM

Combines and supersedes the trio of in-flight review iterations
(0df887d, 7af738c, db6152e) into a single coherent change.  Net effect
relative to PR base (847872c): MXFP8 grouped GEMM (fwd / dgrad / wgrad)
is fully deterministic on gfx950 at PR-base perf, with the source-level
workarounds dropped in favour of the actual root-cause fix.

Code changes:

- Drop the c_group_offs == nullptr mode-switch and the clobber_vgpr_one
  pinning workaround; they hid (not fixed) the race and bloated codegen.
- CK-style separation: outer persistent kernel resolves per-group
  pointers (a_grp_ptr, b_grp_ptr, ..., c_grp_ptr) entirely; compute_tile
  receives them already group-resolved + per-tile bookkeeping only.
  Narrows the SGPR live set to the dense single-GEMM kernel level.
- Pack the 4 typed smem tile pointers into one char* base; compute_tile
  reconstructs typed views internally.  Saves 6 SGPR args.
- Two-step prologue (buffer_load_b128 -> VGPR -> ds_write_b128) for the
  16 LDS-direct loads in fwd's prologue; wgrad uses 2-step only for the
  4 data loads (full conversion regresses wgrad perf).  Puts the LDS
  write on lgkmcnt instead of vmcnt, separated from compiler-emitted
  scratch_load (SGPR spill consume).

Compiler flags (the actual root-cause fix for the residual race):

- `-amdgpu-enable-merge-m0=true` merges adjacent `s_mov_b32 m0` writes,
  collapsing per-load m0 setup chains in the buffer_load_lds main loop.
  Without merge, each LDS-direct load has its own m0 update fed from a
  potentially-stale spill SGPR; with merge, m0 is established once per
  cluster, breaking the spill->m0->LDS-write race.
- `-sink-insts-to-avoid-spills=true` sinks defining instructions toward
  their uses, halving private_segment from 44 to 20 bytes.  This kills
  the residual cross-call VGPR live range that drove the scratch_store /
  scratch_load pair — vmcnt counts both scratch and buffer_load_lds, and
  gfx950 doesn't FIFO across them, so a spill reload could otherwise
  leak into m0.

Race + perf at gpt_oss-20B Expert shape (B=4 M=2048 N=5760 K=2880):
- fwd:   0/500K  (vs PR HEAD ~3/30K = 0.01%);  Fwd  1340 TFLOPS (+1.4%)
- dgrad: 0/150K;                                Dgrad 1258 TFLOPS (+2.1%)
- wgrad: 0/90K;                                 Wgrad 1723 TFLOPS (-1.5%)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Previous commit added separate `*_two_step` helper variants and a
TWO_STEP template parameter on phase_mfma_lds_ldg.  In the final config
the main loop never uses the 2-step path, so the duplication and the
template branching in phase_mfma_lds_ldg are dead weight.

Cleanup:
- Drop `*_half_srd_two_step` duplicates; collapse the 4 affected
  `load_{a,b}{,_scale}_gmem_to_smem_half_srd` helpers behind a
  `bool TwoStep = false` template parameter and a single `if constexpr`
  switch around the inner load.
- Drop the TWO_STEP template parameter + `if constexpr` chain on
  `phase_mfma_lds_ldg`; the main loop is always LDS-direct.
- Collapse the two free-function `load_gmem_to_smem_srd_two_step` impls
  into one `<Bytes>` template (b32 + b128).
- Drop the now-unused `clobber_{a,v}gpr_one` `using` decls in
  turbo_gemm_mxfp8_kernel{,_hip}.h.

Per-prologue 2-step decisions kept as before:
- fwd: data + scale prologue loads use TwoStep=true (16 calls).
- wgrad: only the 8 data prologue loads use TwoStep=true; scales stay
  LDS-direct (full conversion regresses wgrad ~25%).
- inter-iter and main-loop loads stay LDS-direct.

Race + perf at gpt_oss-20B Expert shape (B=4 M=2048 N=5760 K=2880),
chi2761 GPU 6, 50K iters each:
- fwd:   0/50K;  Fwd  1337 TFLOPS (≡ 567e578)
- dgrad: 0/50K;  Dgrad 1254 TFLOPS (≡ 567e578)
- wgrad: 0/50K;  Wgrad 1756 TFLOPS (≡ 567e578)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Reset csrc to PR base 847872c.  Reviewer (Chen Xiaobo) rejected the
2-step prologue / mllvm flags / clobber_vgpr_one / reserve_pinned_regs
approach as papering over the race instead of fixing it.  The rejected
commits stay in git log per prior request.
Race source: VGPR-spilled SMEM-subtile addresses go through scratch_load,
which shares vmcnt with buffer_load_lds but doesn't FIFO across it on
gfx950.  s_waitcnt vmcnt(N) can pass while scratch_load is still in
flight → stale m0 → LDS write to wrong address → 0.01-1% bitwise
non-determinism.

Three source-level changes (no LLVM flags, no clobber barriers):

1. Outer kernel resolves per-group base ptrs; inner takes resolved ptrs
   only.  i64 group offset chain stays out of inner.  c_group_offs_ptr
   now always-valid via launcher fallback; per-group input padding is
   signalled branchlessly via a c_padding_align_mask field.

2. SMEM-write helpers force LDS-base into SGPR: compute_sts_offsets
   drops lane_id*16 (buffer_load_lds auto-strides per lane); scale
   helpers drop scale_sts_offset = lane_id; sts_warp_base /
   smem_byte_offset wrapped in __builtin_amdgcn_readfirstlane so the
   warp-id arithmetic is explicitly scalar.

3. bf16 store uses raw-bit truncation instead of hip_bfloat16(float)
   software round-to-even (avoids SCC branches that get lane-spilled).
   SNR unchanged at 28.23 dB.
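
For reference, the difference between raw-bit truncation and round-to-nearest-even when narrowing fp32 to bf16 can be reproduced with the standard library. This is a bit-level sketch for finite inputs only (no NaN handling); the helper names are illustrative.

```python
import struct


def f32_bits(x):
    """Bit pattern of x as an IEEE-754 binary32 value."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bf16_truncate(x):
    """bf16 bits by dropping the low 16 mantissa bits (round toward zero)."""
    return f32_bits(x) >> 16


def bf16_rne(x):
    """bf16 bits with round-to-nearest, ties-to-even (finite inputs only)."""
    bits = f32_bits(x)
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return (bits + rounding_bias) >> 16
```

The two conversions agree whenever the discarded low 16 bits are below half an ulp, and can differ by one bf16 ulp otherwise, for example on ``1.005859375`` (``1 + 2**-8 + 2**-9``).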

reg notes (sgpr/sgpr_spill/vgpr_spill/scratch):
  grouped fwd bf16  106 / 7 / 10 / 44  →  106 / 0 / 0 / 0
  grouped fwd half  106 / 0 / 17 / 72  →  106 / 0 / 0 / 0
  grouped wgrad     106 / 0 / 17 / 72  →  104-105 / 0 / 0 / 0
  single GEMM       81 / 0 / 0 / 0     unchanged

fwd race: 0.01-1% → 0/200K (4 GPUs × 50K iters).
Perf (B=16 M=2048 N=4096 K=7168): Fwd 1435 / Dgrad 1416 / Wgrad 2013 TFLOPS.
SNR fwd / dgrad / wgrad: 28.23 / 28.23 / 28.08 dB.
After the spill fix, compute_sts_offsets returned only [0, 1024] and
threaded a [2]-array through every load_a/b_gmem_to_smem_half_srd call.
Inline the constant: helper writes `base` and `base + 1024` directly
(1024 = per-wave LDS write span for buffer_load_dwordx4 = 16 B/lane × 64).

Also two small style alignments with the single-GEMM kernel:
- pid_n_idx → pid_n_local in fwd compute_tile (consistent with pid_m_local)
- scale_cols declaration moved between base ptrs (mxgemm ordering)

Verified: reg notes / race / SNR all unchanged.
@kyle-256 kyle-256 force-pushed the dev/kyle_mxfp8_gg_pr branch from 07e9940 to 8244d7a Compare May 16, 2026 02:19
Copilot AI review requested due to automatic review settings May 19, 2026 10:05

Copilot AI left a comment

Pull request overview

Copilot reviewed 21 out of 21 changed files in this pull request and generated 2 comments.

Comment on lines +502 to +509
// ``quantize_mxfp8_dual_perg``: per-group dual quant for the grouped GEMM
// B-side. Input is 3D ``(G, M, N)``; per-group M/N are uniform. The
// kernel writes:
// rowwise_output: ``(G, M_pad, N_pad)`` -- naturally fits ``(G, N_of_B,
// K_pad)`` for the fwd GEMM.
// colwise_output: ``(G, N, M_pad)`` -- naturally fits ``(G, K_of_B,
// N_pad)`` for the dgrad GEMM
// (no separate regroup).
detail::ScalingRecipe colwise_recipe, hipStream_t stream);

// Per-group dual quant with uniform per-group (M, N): input (G, M, N) ->
// rowwise (G, M_pad, N_pad), colwise (G, N, M_pad). Used by the grouped
@kyle-256 kyle-256 force-pushed the dev/kyle_mxfp8_gg_pr branch from abaeaab to e0263a0 Compare May 19, 2026 10:26
Two wait_lgkmcnt<0> drains (mid-phase WAR + compute_tile prologue) fix
an LDS WAR race in the MFMA-pinned-asm path; race verified at 0/N on
both single-GEMM and grouped GEMM MX_BLOCKWISE deterministic tests.
Copilot AI review requested due to automatic review settings May 19, 2026 13:01
@kyle-256 kyle-256 force-pushed the dev/kyle_mxfp8_gg_pr branch from e0263a0 to 17df201 Compare May 19, 2026 13:01

Copilot AI left a comment

Pull request overview

Copilot reviewed 21 out of 21 changed files in this pull request and generated 3 comments.

Comments suppressed due to low confidence (2)

tests/pytorch/ops/test_grouped_gemm_fp8.py:554

  • Same issue as the deterministic MX_BLOCKWISE test above: this uses an exact (9, 5) compute capability check, which will skip on newer gfx95x devices even if they support MXFP8. Use check_mxfp8_support() or >= (9, 5) instead of equality.
    if get_device_compute_capability() != (9, 5):
        pytest.skip("MXFP8 grouped GEMM requires gfx950 (MI350/MI355).")

primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py:628

  • Same concern as the fwd/dgrad TURBO backend: can_handle() for the variable-K TURBO backend doesn’t verify gfx950/MXFP8 support. Adding a device capability gate would prevent the dispatcher from selecting a kernel that asserts at runtime on unsupported GPUs when invoked with pre-quantized tensors.
        supported = True
        supported &= a.dim() == 2 and b.dim() == 2
        supported &= (a.dtype, b.dtype, out_dtype) in GroupedGEMMFP8VariableKTurboBackend.SUPPORTED_DTYPES
        supported &= granularity in GroupedGEMMFP8VariableKTurboBackend.SUPPORTED_GRANULARITIES
        supported &= not trans_a and not trans_b and not trans_c
        # MFMA tile = 16x16; output axes (N_fwd, K_fwd) must clear that.
        if a.dim() == 2 and b.dim() == 2:
            n_fwd = a.shape[0]
            k_fwd = b.shape[1]
            supported &= n_fwd % 16 == 0 and k_fwd % 16 == 0
            supported &= a.shape[1] == b.shape[0]
        return supported

@pytest.mark.parametrize("balance", [True, False])
@pytest.mark.deterministic
def test_grouped_gemm_fp8_mx_blockwise_deterministic(B, M, NK, ori_dtype, format, trans_b, balance):
    if get_device_compute_capability() != (9, 5):
Comment on lines +502 to +509
// ``quantize_mxfp8_dual_perg``: per-group dual quant for the grouped GEMM
// B-side. Input is 3D ``(G, M, N)``; per-group M/N are uniform. The
// kernel writes:
// rowwise_output: ``(G, M_pad, N_pad)`` -- naturally fits ``(G, N_of_B,
// K_pad)`` for the fwd GEMM.
// colwise_output: ``(G, N, M_pad)`` -- naturally fits ``(G, K_of_B,
// N_pad)`` for the dgrad GEMM
// (no separate regroup).
Comment on lines +422 to +431
supported = True
supported &= a.dim() == 2 and b.dim() == 3
supported &= (a.dtype, b.dtype, out_dtype) in GroupedGEMMFP8TurboBackend.SUPPORTED_DTYPES
supported &= granularity in GroupedGEMMFP8TurboBackend.SUPPORTED_GRANULARITIES
supported &= not trans_a and trans_b
total_m = a.shape[0]
n = b.shape[-2] if trans_b else b.shape[-1]
k = a.shape[1]
supported &= total_m % 16 == 0 and n % 16 == 0 and k % 128 == 0 and k >= 384
return supported
return grad_out if grad_out.is_contiguous() else grad_out.contiguous()


def _turbo_grouped_gemm_mxfp8(

The code here does not align with our backend design

@pytest.mark.parametrize("balance", [True, False])
@pytest.mark.deterministic
def test_grouped_gemm_fp8_mx_blockwise_deterministic(B, M, NK, ori_dtype, format, trans_b, balance):
    if get_device_compute_capability() != (9, 5):

check_mxfp8_support

kyle-256 and others added 2 commits May 25, 2026 11:25
…pads & helpers

- ops: delete _turbo_grouped_gemm_mxfp8 / _turbo_grouped_gemm_variable_k_mxfp8
  bypass helpers; MX_BLOCKWISE fwd/dgrad/wgrad now go through
  grouped_gemm_fp8_impl / grouped_gemm_fp8_variable_k_impl with
  default_backend=BackendType.TURBO so env-var override / autotune /
  can_handle fallback all engage uniformly with the fp8 tensorwise path.
- dispatcher: extend grouped_gemm_fp8_impl (+ register_fake) with
  c_group_offs / total_m_out / grid_x_hint kwargs to expose MX unpad-on-
  store and persistent grid hint via the public op signature. The hipBLASLt
  backends gain a **kwargs shield to absorb the new keys.
- can_handle: fix GroupedGEMMFP8VariableKTurboBackend shape check (was
  a.shape[1] == b.shape[0] / k_fwd = b.shape[1]; MX wgrad layout has the
  variable-M reduction dim as shape[1] on both — correct check is
  a.shape[1] == b.shape[1], k_fwd = b.shape[0]). Bug never surfaced
  because the bypass skipped can_handle entirely.
- kernel: inline CType(f) (default RNE) at the C store; drop the bf16
  raw-bit truncation specialisation (and the float_to_ctype helper).
  Disasm now shows 0 spill on all MXFP8 kernels, so the workaround is
  unnecessary; truncation also silently diverged from torch's bf16 cast,
  which complicated numerics debugging.
- kernel: drop the (scale_cols + 3) / 4 * 4 pad in workspace_size, impl,
  and the compute_tile / launcher scale_cols. Launcher asserts k % 128
  == 0, so scale_cols = k/32 is always a multiple of 4 and the pad is
  dead arithmetic. Also drops the preshuffle's trailing partial col-
  block branch (dead under the same invariant).
- test: switch grouped MX deterministic / mx_blockwise gating from
  get_device_compute_capability() != (9, 5) to check_mxfp8_support() to
  match test_gemm_fp8.py / test_quantization.py.
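
The corrected variable-K shape gate from the can_handle item above can be sketched as a standalone predicate. This is an assumption-labelled illustration: ``variable_k_shapes_ok`` is a hypothetical helper, not the backend's method, and it covers only the shape conditions (16-multiple output axes and matching reduction dim on ``shape[1]`` of both operands).

```python
def variable_k_shapes_ok(a_shape, b_shape, mfma_tile=16):
    """Shape-only gate for the MX wgrad (variable-K) layout.

    Both operands are 2D with the variable-M reduction dim on axis 1;
    the output axes are n_fwd = a_shape[0] and k_fwd = b_shape[0].
    """
    if len(a_shape) != 2 or len(b_shape) != 2:
        return False
    n_fwd, m_a = a_shape
    k_fwd, m_b = b_shape
    return (
        n_fwd % mfma_tile == 0
        and k_fwd % mfma_tile == 0
        and m_a == m_b  # contraction-axis agreement
    )
```

With this form a ``(5760, 8192)`` by ``(2880, 8192)`` pair is accepted, while a ``(8192, 2880)`` second operand is rejected; shapes that fail the gate fall through to another backend instead of asserting inside the kernel.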

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@kyle-256 kyle-256 force-pushed the dev/kyle_mxfp8_gg_pr branch from 17df201 to c5c03ad Compare May 26, 2026 03:38