
cuda: revive 2 dropped kernels with FMA-contraction fixes (stacked on #2) #3

Closed
TrevorS wants to merge 2 commits into mtp-beats-plain-kernels from mtp-beats-plain-kernels-v2

TrevorS commented May 23, 2026

PR3 draft: cuda: revive 2 previously-dropped kernels with FMA-contraction fixes (stacked on #2)

Summary

Two new commits stacked on mtp-beats-plain-kernels (PR #2). Revives kernel optimizations that had been dropped from PR #2 due to numerical drift, after root-causing each to a specific FMA-contraction issue and applying a targeted fix.

The new commits:

  1. 584de5e cuda: shared-weight Q8 matmul kernel for small-N batched (MTP verifier). Originally dropped from PR #2 (cuda: small-N batched kernel polish (stacked on #1)) for MTP-strict drift; now gated so it only replaces the bit-identical warp8 fallback, not the cuBLAS path.
  2. af16691 cuda: fuse head_rms_norm + rope_tail on Q (decode + batched paths). Originally dropped for greedy drift on the long-context long_memory_archive test; now uses __fmul_rn as an FMA-contraction barrier to reproduce the sequential reference's bit pattern.

What's still NOT in this PR (and why)

  • routed_moe small-N decode-LUT path (the K5 candidate from the PR #2 investigation): drift root-caused, but the fix has an unacceptable blast radius. The K5 down path uses dev_dot_q2_K_q8_K_block (single-pair variant: pre-computes dall = y_d * x_d, then dall * isum - dmin * summs), while the upstream reference at n_tokens=2 uses dev_dot_q2_K_q8_K_block8 (8-pair batched variant: inlines yd * xd * isum[p] - yd * xmin * summs[p]). Under --use_fast_math, nvcc emits different FMA contractions for these two formulations, producing different fp32 outputs from the same INT32 partial sums. Closing the gap would require either (a) a new K5-only down kernel that inlines the 8-pair math but loops over the 6 expert slots, or (b) modifying the shared single-pair kernel. But dev_dot_q2_K_q8_K_block has 8 other callers across the MoE down stack (lines 9289, 9322, 9348, 9374, 9456, 9491, 9515, 9553), so a change there risks regressing every batched MoE path. Both options exceed the reasonable scope for this PR; this is worth a separate, focused PR if and when someone wants to chase it.
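
To illustrate why two algebraically equal formulations can diverge, here is a minimal host-side C++ sketch (not the CUDA kernels themselves). It assumes the compiler is free to contract `a*b - c*d` in either of two ways; which contraction nvcc actually picks for each kernel is not shown in this PR. `std::fma` stands in for a hardware FMA, and the function names are hypothetical.

```cpp
#include <cmath>

// Host-side model only: std::fma stands in for a hardware FMA, and the
// volatile temporaries force one product to be rounded to fp32 first.

// Contraction choice 1: round c*d, then fuse a*b with the subtraction.
float fuse_left_product(float a, float b, float c, float d) {
    volatile float cd = c * d;      // c*d rounded to fp32
    return std::fma(a, b, -cd);     // a*b stays exact inside the FMA
}

// Contraction choice 2: round a*b, then fuse c*d with the subtraction.
float fuse_right_product(float a, float b, float c, float d) {
    volatile float ab = a * b;      // a*b rounded to fp32
    return std::fma(-c, d, ab);     // c*d stays exact inside the FMA
}
```

With a = b = 1 + 2^-12, c = 1 + 2^-11 and d = 1, the first choice returns 2^-24 and the second returns 0, even though both evaluate the same expression on the same inputs.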

Root-cause details

K4 — shared-weight Q8 matmul

Originally fired unconditionally at n_tok=2..4, replacing both the cuBLAS Gemm/Sgemm path and the matmul_q8_0_preq_batch_warp8_kernel fallback. The cuBLAS path uses tensor-core tiling that is numerically different from any custom Q8 warp kernel; the warp8 fallback uses the same per-block wscale * xs * dot operand order as the share-warp kernel, so K4 IS bit-identical to warp8 but NOT to cuBLAS.

Fix (7-line gating predicate at ds4_cuda.cu:5996-6053):

if (blocks <= 32u &&
    (!g_cublas_ready ||
     (cuda_q8_f32_ptr(...) == NULL && cuda_q8_f16_ptr(...) == NULL))) {
    // share-warp dispatch
}

The cache-pointer checks short-circuit on label/dim mismatch and respect the same allowed-cache predicates as the cuBLAS dispatch, so the gate fires precisely on the weights that cuBLAS would have skipped (cache miss → fallback to warp8). On those weights, share-warp matches warp8 bit-for-bit. Weights that hit cuBLAS stay on cuBLAS.
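
The gate can be modelled on the host as a small pure function. This is a sketch under assumptions: `WeightCache` and `share_warp_may_fire` are hypothetical names, the null-pointer fields stand in for the elided `cuda_q8_f32_ptr(...)` / `cuda_q8_f16_ptr(...)` lookups, and the n_tok range comes from the "n_tok=2..4" description above.

```cpp
#include <cstddef>

// Hypothetical stand-in for the per-weight cuBLAS cache lookups;
// a null pointer models a cache miss.
struct WeightCache {
    const void* f32_copy;
    const void* f16_copy;
};

// True when the share-warp kernel may replace the warp8 fallback.
bool share_warp_may_fire(unsigned n_tok, unsigned blocks,
                         bool cublas_ready, const WeightCache& w) {
    const bool small_batch = n_tok >= 2u && n_tok <= 4u;
    // cuBLAS would not serve this weight: either it is unavailable,
    // or neither cached copy exists.
    const bool cublas_would_skip =
        !cublas_ready || (w.f32_copy == nullptr && w.f16_copy == nullptr);
    return small_batch && blocks <= 32u && cublas_would_skip;
}
```

Keeping the predicate free of side effects makes it easy to check that any weight with a cached copy stays on the cuBLAS path.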

K1 — head_rms_norm + rope_tail fuse

The naive fused kernel applied the rms scale to tail[i] / tail[i+1] inline in the same register-level expression as the RoPE cos/sin multiply. Under --use_fast_math nvcc fused scale * tail[i] * c and scale * tail[i+1] * s into single FMAs, which differs from the sequential reference path where head_rms_norm_kernel writes xr[i] *= scale to fp32 memory and rope_tail_kernel reads it back. The ULP-scale per-pair drift compounds across 30 layers × high pos0 to flip argmaxes on long-context tests.

Fix (2 inline operand changes at ds4_cuda.cu:2416-2425):

const float x0 = __fmul_rn(tail[i], scale);
const float x1 = __fmul_rn(tail[i+1], scale);

__fmul_rn is a hard FMA-contraction barrier that forces a single-rounded fp32 multiply, reproducing the bit pattern of the sequential reference's memory store-then-load.
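
The effect of the barrier can be reproduced on the host. The sketch below is a model, not the kernel: `round_product` is a hypothetical helper that plays the role `__fmul_rn` plays on the device (an fp32-rounded product that cannot be contracted), and `std::fma` models the contracted path.

```cpp
#include <cmath>

// Host-side stand-in for __fmul_rn: the volatile store forces the
// product to be rounded to fp32 before anything else uses it.
float round_product(float a, float b) {
    volatile float p = a * b;
    return p;
}

// Sequential-reference behaviour: product rounded, then added.
float mul_then_add(float a, float b, float c) {
    volatile float sum = round_product(a, b) + c;
    return sum;
}

// Contracted behaviour: a single rounding for the whole a*b + c.
float fused_mul_add(float a, float b, float c) {
    return std::fma(a, b, c);
}
```

For a = 3, b = 1 + 2^-23 and c = -3, `mul_then_add` returns 2^-21 while `fused_mul_add` returns 3 * 2^-23. That is the kind of last-bit disagreement the fix removes by rounding the scaled value first.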

Headline numbers (DGX Spark / GB10, ds4flash IQ2XXS-w2Q2K)

Single-prompt smoke benches (./ds4 -n 256 -p "knight" --temp 0), 3-run mean:

Mode                          PR #2 base (t/s)   this PR (t/s)
Plain decode                  16.52              16.55
MTP batched-verifier strict   13.74              13.80

The absolute gains are marginal because the kernels target paths that don't dominate single-token decode. The engineering value is the root-causing and fixing, not the t/s number: the K1 and K4 fixes are general-purpose techniques (FMA barriers, dispatch gating on cache-pointer availability) that future kernel work can reuse.

Tested against

  • make clean && make cuda-spark — clean, no new warnings
  • ./ds4_test --all — only pre-existing logprob-vectors short_code_completion failure (also on upstream/main); tensor-equivalence summary: capture_fail=0 logits_fail=0 greedy_fail=0 top1_mismatch=0
  • make cuda-regression — pre-existing build error (also on upstream/main), unchanged
  • make cpu — clean build, no new warnings
  • Plain 3-run stability: byte-identical coherent output, 16.5 ± 0.03 t/s
  • MTP STRICT 3-run: 13.8 t/s on first 2 runs, 10.86 t/s on a third run impacted by external GPU workload (not commit-induced)
  • Byte-equality diff vs upstream/main for plain decode (n=64, "knight") — 0 bytes
  • Byte-equality diff vs upstream/main for MTP BATCH_VERIFY=1 STRICT=1 (n=64, "knight") — 0 bytes
  • Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
  • Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf (80.76 GiB)

Pre-existing non-determinism observation

Independent of this PR: the mtp-beats-plain-kernels (#2) baseline and upstream/main both show intermittent long_memory_archive greedy_fail in roughly 1 of 4 ds4_test --all runs. Source: MoE down path uses atomicAdd at n_tokens >= 128 (prefill chunk size), which is not float-deterministic. The 5 dropped-kernel runs and 3 PR3-final runs show this same distribution — neither K1 nor K4 introduces it. Flagged here because it surfaced during the careful investigation that produced this PR; it's a separate concern worth its own ticket.
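
Why arrival order matters for an fp32 accumulator can be shown with a small host-side sketch. It is a model only: `sum_in_order` is a hypothetical helper that adds values strictly left to right, standing in for one possible ordering of atomicAdd updates onto a single accumulator.

```cpp
#include <cstddef>
#include <vector>

// Sums fp32 values strictly left to right, modelling one arrival
// order of atomic updates onto a single accumulator.
float sum_in_order(const std::vector<float>& values) {
    volatile float acc = 0.0f;  // volatile keeps every partial sum in fp32
    for (std::size_t i = 0; i < values.size(); ++i) {
        acc = acc + values[i];
    }
    return acc;
}
```

Summing {16777216, 1, 1} in that order gives 16777216, while {1, 1, 16777216} gives 16777218: the same three values, a different order, and a different fp32 result.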

Out of scope / follow-ups

  • K5 (routed_moe small-N decode-LUT path) revival — would need a focused PR on the dev_dot_q2_K_q8_K_block/block8 family
  • The atomicAdd non-determinism above
  • Combined-forward MTP path (downstream fork) — separate follow-up
  • Captured-graph spec decode support — separate subsystem PR

TrevorS added a commit that referenced this pull request May 24, 2026
…rp dispatch

PR #3's K4 commit (584de5e on the v2 branch) added a cuBLAS-cache-
availability gate to the share-warp Q8 matmul dispatch in
cuda_matmul_q8_0_tensor_labeled. The intent was correctness: prevent
share-warp from displacing cuBLAS where cuBLAS would have handled the
weight. On DGX Spark this gate is empirically always-false: every Q8
weight has a cached F16 copy by the time the dispatcher runs, so
share-warp NEVER fires at n_tok=2..4 with blocks<=32 -- which means
the Q_B and output_proj_b matmuls (the only V4-Flash matmuls with
blocks=32) silently route through cuBLAS' small-M tensor-core path,
which pads M=2..4 to M=16 and wastes 3/4 to 7/8 of the inner-product work.
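
The padding cost is simple arithmetic. The sketch below assumes the 16-row tile stated in the commit message; `padded_waste_fraction` is a hypothetical helper, not a cuBLAS API.

```cpp
#include <cassert>

// Fraction of a padded tile that is padding when m real rows are
// rounded up to a multiple of tile_rows. Requires m >= 1, tile_rows >= 1.
double padded_waste_fraction(unsigned m, unsigned tile_rows) {
    assert(m >= 1u && tile_rows >= 1u);
    const unsigned padded = ((m + tile_rows - 1u) / tile_rows) * tile_rows;
    return static_cast<double>(padded - m) / static_cast<double>(padded);
}
```

With a 16-row tile this gives 7/8 at M=2, 13/16 at M=3 and 3/4 at M=4.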

The actual correctness concern is narrower: under DS4_MTP_STRICT (or
--quality), users require byte-equality with plain decode. Share-warp
is not bit-identical to cuBLAS Gemm at small M (different reduction
order), so strict-mode must fall through to cuBLAS. In non-strict
mode this drift is acceptable -- it matches PR #6's combined-forward
Option-B pattern (same env knob selects strict vs perf).

Replace the cuBLAS-cache-availability check with `!strict_mtp_env`.
Same opt-out shape as the combined-forward gate in
ds4_session_eval_speculative_argmax. The `blocks <= 32u` constraint
is preserved (share-warp is bit-equal to N=1 warp8 only at blocks<=32;
larger block counts drift from the batch_warp8 reference and would
fail ds4_test --all long-context tensor equivalence -- verified
empirically during bisect).

Bench impact (DGX Spark, ds4flash, n=256, "knight" prompt):
  Default `--mtp` (combined-forward fires)   15.6 -> 16.20 t/s  (+3.8%)
  `DS4_MTP_STRICT=1 --mtp` (canonical)       13.83 -> 13.83     unchanged
  Plain decode                               16.60 -> 16.60     unchanged

Strict-mode byte-equality vs PR #6 baseline confirmed (diff empty).
`./ds4_test --all` shows the same 1 pre-existing failure as PR #6
(`logprob-vectors short_code_completion`, also fails on upstream/main).
TrevorS added 2 commits May 24, 2026 10:03

cuda: shared-weight Q8 matmul kernel for small-N batched (MTP verifier)

At small batch (n_tok = 2..4) the per-token batch_warp8 kernel re-reads
each weight row N times.  This commit adds
matmul_q8_0_preq_batch_share_warp_kernel<N_TOK> -- a templated warp kernel
that reads each row of weights exactly ONCE per warp and computes N_TOK
partial dot products against N_TOK token inputs, amortizing weight
bandwidth N-fold.  N_TOK = {2, 3, 4} are instantiated.
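
The weight-sharing idea can be sketched on the host as follows. This is a scalar model, not the warp kernel: `Q8Block`, `kBlock` and `dot_row_shared` are hypothetical names, the 32-value block with one scale is an assumed Q8_0-style layout, and there is no warp-level reduction here.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int kBlock = 32;  // assumed block size: 32 int8 values + one scale

struct Q8Block {
    float scale;
    std::int8_t q[kBlock];
};

// One weight row against N_TOK pre-quantized token rows. Each weight
// block is read once and reused for every token.
template <std::size_t N_TOK>
std::array<float, N_TOK> dot_row_shared(
        const std::vector<Q8Block>& w,
        const std::array<const std::vector<Q8Block>*, N_TOK>& x) {
    std::array<float, N_TOK> acc{};
    for (std::size_t b = 0; b < w.size(); ++b) {
        const Q8Block& wb = w[b];  // single read of the weight block
        for (std::size_t t = 0; t < N_TOK; ++t) {
            assert(x[t]->size() == w.size());
            const Q8Block& xb = (*x[t])[b];
            std::int32_t dot = 0;
            for (int i = 0; i < kBlock; ++i) dot += wb.q[i] * xb.q[i];
            // Same operand order as described: wscale * xs * dot.
            acc[t] += wb.scale * xb.scale * static_cast<float>(dot);
        }
    }
    return acc;
}
```

Instantiating with N_TOK = 1 gives the per-token computation, so the shared and per-token paths run the same arithmetic per token; only the number of weight reads changes.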

Gating (cuda_matmul_q8_0_tensor_labeled): the share kernel only replaces
the existing batch_warp8 fallback path; it does NOT replace cuBLAS Gemm
on cached F32/F16 weights.  Conditions to fire:

  n_tok in [2, 4]   AND
  blocks <= 32      AND   (same blocks cap as batch_warp8)
  no F32 cuBLAS cache hit for this weight  AND
  no F16 cuBLAS cache hit for this weight

When any cuBLAS cache is present the reference path stays in charge of
that weight, preserving byte-equality with upstream/main on all q8
matmuls that hit cuBLAS (attn_output_a/b, ffn_*_shexp, attn_q_b, the
4096x{2048,1024,512} / 2048x4096 shapes).  The share kernel reads the
same weights, performs the same per-block FMA (wscale * xs * dot in the
same operand order), and uses the same warp_sum_f32 reduction as
batch_warp8, so it is bit-identical to batch_warp8 on the weights it
serves.

Disable with DS4_CUDA_NO_Q8_SHARE_BATCH=1.  Also respects
DS4_CUDA_NO_Q8_BATCH_WARP=1 (since the fallback this replaces is gated
by the same flag).

Per-layer impact (DS4_METAL_LAYER_STAGE_PROFILE at n_tokens=2): per-layer
total ~2.10ms -> ~1.60ms (-24%) on stages where cuBLAS is not used.

Bench (DGX Spark, ds4flash.gguf + MTP-Q4K, n=256, --mtp-draft 2,
3-run avg, "knight" prompt): K=2 batched verifier ~8-9 t/s with this
change, up from ~8.0 baseline.

Byte-equality
-------------
Share-warp kernel is mathematically and numerically identical to the
batch_warp8 kernel for the weights it serves.  Plain decode and
MTP K=2 batched-verifier STRICT outputs are byte-equal to upstream/main.

LOC
---
ds4_cuda.cu: +102/-7 (one new templated kernel, one dispatch branch,
three template instantiations).

cuda: fuse head_rms_norm + rope_tail on Q (decode + batched paths)

Swaps the back-to-back ds4_gpu_head_rms_norm_tensor +
ds4_gpu_rope_tail_tensor pair on the Q tensor for the existing
ds4_gpu_head_rms_norm_rope_tail_tensor fused kernel.  Mainline already
defined the fused kernel in ds4_cuda.cu, but only as a standalone
function: no callers were using it on the decode hot path.

Sites updated:
  - decode q_path: single-token raw-SWA decode path
  - batched q_path: batched verifier / prefill path

Savings per call: one DRAM round trip on the Q tensor and one kernel
launch per layer.

Numerical-parity fix vs sequential reference
--------------------------------------------
The naive fused kernel reads tail[i] from memory, multiplies by the rms
scale in-register, then immediately combines with the RoPE cos/sin into
the rotated output.  Under --use_fast_math nvcc was folding the scale
multiply into the c/s multiply via FMA, which differs from the sequential
reference (head_rms_norm writes scale*tail[i] back to fp32 memory, then
rope_tail reads it back).  This ULP-scale per-pair drift compounds across
30 layers and high pos0 values, flipping argmax decisions on long-context
prompts.

Fix: wrap the scale multiply in __fmul_rn (single-rounded fp32 multiply,
hard barrier to FMA contraction).  x0 = __fmul_rn(tail[i], scale) and
x1 = __fmul_rn(tail[i+1], scale) reproduce the bit pattern that the
sequential path's memory store-then-load produces, eliminating the long-
context drift.

Verification
------------
- Plain decode 64-token "knight" output: byte-equal to upstream/main.
- MTP K=2 batched-verifier STRICT 64-token output: byte-equal to
  upstream/main.
- ds4_test --all: tensor-equivalence summary matches PR2 base distribution
  (pre-existing intermittent long_memory_archive atomicAdd non-determinism
  in MoE prefill at n_tokens>=128 is unchanged; not introduced by this
  commit).

Header: added ds4_gpu_head_rms_norm_rope_tail_tensor declaration to
ds4_gpu.h (the wrapper existed in ds4_cuda.cu but was not declared).

Bench (DGX Spark, ds4flash.gguf, --temp 0, 32 gen tokens):
  Plain decode: ~16.3 t/s, no regression vs PR2 base.
TrevorS force-pushed the mtp-beats-plain-kernels branch from b16124b to 86d01b9 on May 24, 2026 17:13
TrevorS force-pushed the mtp-beats-plain-kernels-v2 branch from af16691 to 65d8182 on May 24, 2026 17:13
TrevorS added a commit that referenced this pull request May 24, 2026
Two related changes packed together.

1. Combined-forward K=2 wiring
-----------------------------
`spec_argmax_combined` now also handles draft_cap=2 (N=3 batched verify
over [first_token, drafts[0], drafts[1]]) with prefix-2 commit dispatch
for the commit ∈ {0, 1, 2} cases.

Gated behind DS4_MTP_COMBINED_K2=1 because measurement shows the K=2
variant is currently a loss on mainline:

  Combined K=1 (N=2 batched): 9.51 t/s (no flag, /dev/null)
  Combined K=2 (N=3 batched): 7.34 t/s (DS4_MTP_COMBINED_K2=1)

Why K=2 loses: `drafts[1]` cascades from `drafts[0]`'s MTP-state, but
`drafts[0]` itself comes from `combined_prev_hc` (= post-previous-iter-
last-token HC), not from the fresh post-`first_token` main-HC the
canonical eval(first_token) would produce.  So `drafts[0]` is "one
position stale" already, and `drafts[1]` cascades further off-target.
The target verifier rejects `drafts[1]` in the vast majority of iters,
so the extra batched-N=3 row costs more than it pays.

Keeping the K=2 path as opt-in because the prefix-2 wiring is correct
and reusable when the staleness fix lands (interleaved MTP-block inside
batched main forward, spark-style).  See PHASE4.md item #1.

2. Session-cached spec_row_logits buffer
----------------------------------------
Adds `s->spec_row_logits_buf` (3 * VOCAB f32 = ~1.5 MiB) and
`s->spec_row_tops_buf` (3 * int) allocated at session creation,
replacing the per-spec-call xmalloc/free pattern in
`ds4_session_eval_speculative_argmax_combined`.
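
A minimal C++ sketch of the allocate-once pattern follows. It is a model under assumptions: `SpecSession` and `row_logits` are hypothetical, the field names mirror the commit message, and the real code lives in C (ds4.c) with its own allocation helpers.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical session model: both buffers are sized once at creation
// and reused by every speculative call.
struct SpecSession {
    std::size_t vocab;
    std::vector<float> spec_row_logits_buf;  // 3 rows of vocab logits
    std::vector<int>   spec_row_tops_buf;    // top-1 token id per row

    explicit SpecSession(std::size_t vocab_size)
        : vocab(vocab_size),
          spec_row_logits_buf(3 * vocab_size),
          spec_row_tops_buf(3) {}

    // Logits row for verifier position `row` (0..2); no allocation here.
    float* row_logits(std::size_t row) {
        return spec_row_logits_buf.data() + row * vocab;
    }
};
```

The rows are views into one contiguous buffer, so repeated calls return the same addresses and the hot path never touches the allocator.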

Measurement impact: small (~0-3% in noise).  The malloc overhead
hypothesis was a wrong guess at what was producing the 73 ms per-call
overhead between component-timed (~95 ms) and observed wall
(~168 ms) combined cost.  Documented in PHASE4.md item #3 -- the
actual source of that overhead is still unidentified after this
attempt.

Effect on default combined K=1: 9.51 -> 9.48 t/s (within noise).
Foundation for future xmalloc cleanup in the canonical path's
decode2_exact branch (still allocates per-call).

LOC
---
ds4.c: +67/-35 (combined K=2 dispatch + session buf fields + alloc/free
sites + caller changes).  Two new session fields, two new env gates.

NO github push.  jj change oxmoztuq.

TrevorS commented May 24, 2026

Superseded by the reframed 2-PR stack (#11 + #12), which tells the same Spark/GB10 + MTP combined-forward story more concisely, rebased on current upstream/main, with the exploratory paths dropped.

TrevorS closed this May 24, 2026
TrevorS deleted the mtp-beats-plain-kernels-v2 branch May 24, 2026 22:43