cuda: revive 2 dropped kernels with FMA-contraction fixes (stacked on #2) #3
Closed
TrevorS wants to merge 2 commits into
This was referenced May 23, 2026
TrevorS
added a commit
that referenced
this pull request
May 24, 2026
…rp dispatch

PR #3's K4 commit (584de5e on the v2 branch) added a cuBLAS-cache-availability gate to the share-warp Q8 matmul dispatch in cuda_matmul_q8_0_tensor_labeled. The intent was correctness: prevent share-warp from displacing cuBLAS where cuBLAS would have handled the weight.

On DGX Spark this gate is empirically always-false: every Q8 weight has a cached F16 copy by the time the dispatcher runs, so share-warp NEVER fires at n_tok=2..4 with blocks<=32 -- which means the Q_B and output_proj_b matmuls (the only V4-Flash matmuls with blocks=32) silently route through cuBLAS' small-M tensor-core path which pads M=2..4 to M=16 and wastes ~7/8 of the inner-product work.

The actual correctness concern is narrower: under DS4_MTP_STRICT (or --quality), users require byte-equality with plain decode. Share-warp is not bit-identical to cuBLAS Gemm at small M (different reduction order), so strict-mode must fall through to cuBLAS. In non-strict mode this drift is acceptable -- it matches PR #6's combined-forward Option-B pattern (same env knob selects strict vs perf).

Replace the cuBLAS-cache-availability check with `!strict_mtp_env`. Same opt-out shape as the combined-forward gate in ds4_session_eval_speculative_argmax.

The `blocks <= 32u` constraint is preserved (share-warp is bit-equal to N=1 warp8 only at blocks<=32; larger block counts drift from the batch_warp8 reference and would fail ds4_test --all long-context tensor equivalence -- verified empirically during bisect).

Bench impact (DGX Spark, ds4flash, n=256, "knight" prompt):
  Default `--mtp` (combined-forward fires)   15.6  -> 16.20 t/s (+3.8%)
  `DS4_MTP_STRICT=1 --mtp` (canonical)       13.83 -> 13.83 unchanged
  Plain decode                               16.60 -> 16.60 unchanged

Strict-mode byte-equality vs PR #6 baseline confirmed (diff empty). `./ds4_test --all` shows the same 1 pre-existing failure as PR #6 (`logprob-vectors short_code_completion`, also fails on upstream/main).
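A minimal host-side sketch of the revised gate and of the padding arithmetic this commit message describes. The function names are hypothetical and are not taken from `ds4_cuda.cu`; only the three conditions (n_tok in 2..4, blocks <= 32, not strict) and the pad-to-16 figure come from the message itself.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the revised dispatch decision. The real check
// lives in cuda_matmul_q8_0_tensor_labeled and may be structured differently.
bool use_share_warp(std::uint32_t n_tok, std::uint32_t blocks, bool strict_mtp_env) {
    const bool small_batch = n_tok >= 2u && n_tok <= 4u;
    const bool blocks_ok = blocks <= 32u;  // share-warp is bit-equal to warp8 only up to 32 blocks
    return small_batch && blocks_ok && !strict_mtp_env;
}

// Fraction of a tile's rows that are padding when m real rows are padded up
// to tile_m rows (tile_m = 16 in the message's description of the small-M path).
double padded_waste_fraction(std::uint32_t m, std::uint32_t tile_m) {
    assert(m >= 1u && m <= tile_m);
    return static_cast<double>(tile_m - m) / static_cast<double>(tile_m);
}
```

With `tile_m = 16`, `padded_waste_fraction(2, 16)` is 7/8, the figure quoted above, and `padded_waste_fraction(4, 16)` is 3/4.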
At small batch (n_tok = 2..4) the per-token batch_warp8 kernel re-reads
each weight row N times. This commit adds
matmul_q8_0_preq_batch_share_warp_kernel<N_TOK> -- a templated warp kernel
that reads each row of weights exactly ONCE per warp and computes N_TOK
partial dot products against N_TOK token inputs, amortizing weight
bandwidth N-fold. N_TOK = {2, 3, 4} are instantiated.
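The weight-sharing idea can be sketched on the host as follows. This is a scalar C++ reference under stated assumptions, not the CUDA kernel: `Q8Block`, `make_block`, `block_dot`, `dot_row_one_token` and `dot_row_shared` are hypothetical names, the block layout (one fp32 scale plus 32 int8 values) is assumed, and the warp-level reduction is omitted. It only illustrates that each weight block is fetched once and reused for all N_TOK inputs while keeping the per-block `wscale * xs * dot` operand order.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int QK = 32;  // values per block (assumed, as in ggml's Q8_0)

struct Q8Block {
    float scale;
    std::int8_t q[QK];
};

// Helper for examples: a block whose quants all share one value.
Q8Block make_block(float scale, std::int8_t value) {
    Q8Block b{};
    b.scale = scale;
    for (int i = 0; i < QK; ++i) b.q[i] = value;
    return b;
}

// Exact integer dot product of two blocks.
std::int32_t block_dot(const Q8Block& w, const Q8Block& x) {
    std::int32_t dot = 0;
    for (int i = 0; i < QK; ++i) dot += std::int32_t(w.q[i]) * std::int32_t(x.q[i]);
    return dot;
}

// Per-token reference: one pass over the weight row for a single token.
float dot_row_one_token(const std::vector<Q8Block>& w, const std::vector<Q8Block>& x) {
    float acc = 0.0f;
    for (std::size_t b = 0; b < w.size(); ++b)
        acc += w[b].scale * x[b].scale * float(block_dot(w[b], x[b]));
    return acc;
}

// Shared-weight variant: each weight block is read once and applied to
// all N_TOK token inputs, with the same per-block operand order.
template <std::size_t N_TOK>
std::array<float, N_TOK> dot_row_shared(const std::vector<Q8Block>& w,
                                        const std::array<std::vector<Q8Block>, N_TOK>& x) {
    std::array<float, N_TOK> acc{};
    for (std::size_t b = 0; b < w.size(); ++b) {
        const Q8Block& wb = w[b];  // single fetch of this weight block
        for (std::size_t t = 0; t < N_TOK; ++t)
            acc[t] += wb.scale * x[t][b].scale * float(block_dot(wb, x[t][b]));
    }
    return acc;
}
```

Because each token's accumulator sees the same sequence of operations in both variants, the outputs agree per token; the only thing that changes is how often the weight row is read.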
Gating (cuda_matmul_q8_0_tensor_labeled): the share kernel only replaces
the existing batch_warp8 fallback path; it does NOT replace cuBLAS Gemm
on cached F32/F16 weights. Conditions to fire:
  n_tok in [2, 4]                          AND
  blocks <= 32                             AND  (same blocks cap as batch_warp8)
  no F32 cuBLAS cache hit for this weight  AND
  no F16 cuBLAS cache hit for this weight
When any cuBLAS cache is present the reference path stays in charge of
that weight, preserving byte-equality with upstream/main on all q8
matmuls that hit cuBLAS (attn_output_a/b, ffn_*_shexp, attn_q_b, the
4096x{2048,1024,512} / 2048x4096 shapes). The share kernel reads the
same weights, performs the same per-block FMA (wscale * xs * dot in the
same operand order), and uses the same warp_sum_f32 reduction as
batch_warp8, so it is bit-identical to batch_warp8 on the weights it
serves.
Disable with DS4_CUDA_NO_Q8_SHARE_BATCH=1. Also respects
DS4_CUDA_NO_Q8_BATCH_WARP=1 (since the fallback this replaces is gated
by the same flag).
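The gating conditions and the two opt-out variables can be read as one predicate. The sketch below is a hypothetical host-side restatement, not the code in `ds4_cuda.cu`: `Q8Dispatch`, `env_flag_set` and `share_warp_fires` are invented names, and the rule for what counts as a "set" environment flag is an assumption.

```cpp
#include <cstdint>
#include <cstdlib>

// Hypothetical snapshot of what the dispatcher knows about one Q8 matmul.
struct Q8Dispatch {
    std::uint32_t n_tok;
    std::uint32_t blocks;
    bool f32_cache_hit;  // a cached F32 copy exists, so cuBLAS would serve it
    bool f16_cache_hit;  // a cached F16 copy exists, so cuBLAS would serve it
};

// Assumed convention: a variable that is set, non-empty and does not start
// with '0' counts as enabled. The project's real parsing may differ.
bool env_flag_set(const char* name) {
    const char* v = std::getenv(name);
    return v != nullptr && v[0] != '\0' && v[0] != '0';
}

// True only when the share kernel would replace the batch_warp8 fallback.
bool share_warp_fires(const Q8Dispatch& d, bool share_disabled, bool batch_warp_disabled) {
    if (share_disabled || batch_warp_disabled) return false;
    if (d.n_tok < 2u || d.n_tok > 4u) return false;
    if (d.blocks > 32u) return false;
    return !d.f32_cache_hit && !d.f16_cache_hit;
}
```

A caller would pass `env_flag_set("DS4_CUDA_NO_Q8_SHARE_BATCH")` and `env_flag_set("DS4_CUDA_NO_Q8_BATCH_WARP")` as the two disable arguments.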
Per-layer impact (DS4_METAL_LAYER_STAGE_PROFILE at n_tokens=2): per-layer
total ~2.10ms -> ~1.60ms (-24%) on stages where cuBLAS is not used.
Bench (DGX Spark, ds4flash.gguf + MTP-Q4K, n=256, --mtp-draft 2,
3-run avg, "knight" prompt): K=2 batched verifier ~8-9 t/s with this
change, up from ~8.0 baseline.
Byte-equality
-------------
Share-warp kernel is mathematically and numerically identical to the
batch_warp8 kernel for the weights it serves. Plain decode and
MTP K=2 batched-verifier STRICT outputs are byte-equal to upstream/main.
LOC
---
ds4_cuda.cu: +102/-7 (one new templated kernel, one dispatch branch,
three template instantiations).
Swaps the back-to-back ds4_gpu_head_rms_norm_tensor + ds4_gpu_rope_tail_tensor pair on the Q tensor for the existing ds4_gpu_head_rms_norm_rope_tail_tensor fused kernel. Mainline already defined the fused kernel in ds4_cuda.cu but only the standalone function - no callers were using it on the decode hot path.

Sites updated:
- decode q_path: single-token raw-SWA decode path
- batched q_path: batched verifier / prefill path

Savings per call: one DRAM round trip on the Q tensor and one kernel launch per layer.

Numerical-parity fix vs sequential reference
--------------------------------------------
The naive fused kernel reads tail[i] from memory, multiplies by the rms scale in-register, then immediately combines with the RoPE cos/sin into the rotated output. Under --use_fast_math nvcc was folding the scale multiply into the c/s multiply via FMA, which differs from the sequential reference (head_rms_norm writes scale*tail[i] back to fp32 memory, then rope_tail reads it back). This ULP-scale per-pair drift compounds across 30 layers and high pos0 values, flipping argmax decisions on long-context prompts.

Fix: wrap the scale multiply in __fmul_rn (single-rounded fp32 multiply, hard barrier to FMA contraction). x0 = __fmul_rn(tail[i], scale) and x1 = __fmul_rn(tail[i+1], scale) reproduce the bit pattern that the sequential path's memory store-then-load produces, eliminating the long-context drift.

Verification
------------
- Plain decode 64-token "knight" output: byte-equal to upstream/main.
- MTP K=2 batched-verifier STRICT 64-token output: byte-equal to upstream/main.
- ds4_test --all: tensor-equivalence summary matches PR2 base distribution (pre-existing intermittent long_memory_archive atomicAdd non-determinism in MoE prefill at n_tokens>=128 is unchanged; not introduced by this commit).

Header: added ds4_gpu_head_rms_norm_rope_tail_tensor declaration to ds4_gpu.h (the wrapper existed in ds4_cuda.cu but was not declared).

Bench (DGX Spark, ds4flash.gguf, --temp 0, 32 gen tokens):
Plain decode: ~16.3 t/s, no regression vs PR2 base.
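Why rounding the scaled value first matters can be shown on the host with a tiny example. This is an illustration of the general effect, not the kernel code: `mul_rounded`, `mul_then_add` and `mul_add_fused` are hypothetical helpers, and the volatile store is only a host-side stand-in for what `__fmul_rn` guarantees on the device (a product rounded to fp32 that is not merged into an FMA).

```cpp
#include <cmath>

// Rounds the product to fp32 by forcing it through a volatile store, the
// host-side analogue of writing scale * tail[i] to memory and reading it back.
float mul_rounded(float a, float b) {
    volatile float t = a * b;
    return t;
}

// Sequential style: the product is rounded first, then combined.
float mul_then_add(float a, float b, float c) {
    return mul_rounded(a, b) + c;
}

// Contracted style: a * b + c is evaluated with a single rounding, so the
// low bits of the exact product still influence the result.
float mul_add_fused(float a, float b, float c) {
    return std::fma(a, b, c);
}
```

For `a = b = 1 + 2^-12` and `c = -(1 + 2^-11)`, the exact product is `1 + 2^-11 + 2^-24`, which rounds to `1 + 2^-11` in fp32. The sequential form therefore returns 0, while the fused form returns `2^-24`: same inputs, different bits.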
b16124b to
86d01b9
Compare
af16691 to
65d8182
Compare
TrevorS
added a commit
that referenced
this pull request
May 24, 2026
Two related changes packed together.
1. Combined-forward K=2 wiring
-----------------------------
`spec_argmax_combined` now also handles draft_cap=2 (N=3 batched verify
over [first_token, drafts[0], drafts[1]]) with prefix-2 commit dispatch
for the commit ∈ {0, 1, 2} cases.
Gated behind DS4_MTP_COMBINED_K2=1 because measurement shows the K=2
variant is currently a loss on mainline:
Combined K=1 (N=2 batched): 9.51 t/s (no flag, /dev/null)
Combined K=2 (N=3 batched): 7.34 t/s (DS4_MTP_COMBINED_K2=1)
Why K=2 loses: `drafts[1]` cascades from `drafts[0]`'s MTP-state, but
`drafts[0]` itself comes from `combined_prev_hc` (= post-previous-iter-
last-token HC), not from the fresh post-`first_token` main-HC the
canonical eval(first_token) would produce. So `drafts[0]` is "one
position stale" already, and `drafts[1]` cascades further off-target.
The target verifier rejects `drafts[1]` in the vast majority of iters,
so the extra batched-N=3 row costs more than it pays.
Keeping the K=2 path as opt-in because the prefix-2 wiring is correct
and reusable when the staleness fix lands (interleaved MTP-block inside
batched main forward, spark-style). See PHASE4.md item #1.
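The commit ∈ {0, 1, 2} dispatch can be pictured as a longest-accepted-prefix computation. The sketch below assumes greedy (argmax) verification, where row i of the batched forward yields the target's top token for the position that `drafts[i]` was guessing; `accepted_prefix` is a hypothetical helper, and the actual logic inside `spec_argmax_combined` is not shown in this message.

```cpp
#include <cstddef>
#include <vector>

// Number of leading draft tokens the target model agrees with.
// row_tops[i] is assumed to be the target argmax produced by batched row i,
// i.e. the token the target itself would emit where drafts[i] was proposed.
std::size_t accepted_prefix(const std::vector<int>& drafts, const std::vector<int>& row_tops) {
    std::size_t commit = 0;
    while (commit < drafts.size() && commit < row_tops.size() &&
           drafts[commit] == row_tops[commit]) {
        ++commit;
    }
    return commit;
}
```

With two drafts the result is 0, 1 or 2. The staleness described above means the second comparison rarely succeeds, so the third batched row is mostly wasted work.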
2. Session-cached spec_row_logits buffer
----------------------------------------
Adds `s->spec_row_logits_buf` (3 * VOCAB f32 = ~1.5 MiB) and
`s->spec_row_tops_buf` (3 * int) allocated at session creation,
replacing the per-spec-call xmalloc/free pattern in
`ds4_session_eval_speculative_argmax_combined`.
Measurement impact: small (~0-3%, within noise). The hypothesis that
malloc overhead explained the 73 ms per-call gap between the
component-timed cost (~95 ms) and the observed wall-clock cost
(~168 ms) of the combined path was wrong. Documented in PHASE4.md
item #3 -- the actual source of that gap is still unidentified after
this attempt.
Effect on default combined K=1: 9.51 -> 9.48 t/s (within noise).
Foundation for future xmalloc cleanup in the canonical path's
decode2_exact branch (still allocates per-call).
LOC
---
ds4.c: +67/-35 (combined K=2 dispatch + session buf fields + alloc/free
sites + caller changes). Two new session fields, two new env gates.
NO github push. jj change oxmoztuq.
PR3 draft: cuda: revive 2 previously-dropped kernels with FMA-contraction fixes (stacked on #2)
Summary
Two new commits stacked on `mtp-beats-plain-kernels` (PR #2). Revives kernel optimizations that had been dropped from PR #2 due to numerical drift, after root-causing each to a specific FMA-contraction issue and applying a targeted fix.

The new commits:
- `584de5e`: cuda: shared-weight Q8 matmul kernel for small-N batched (MTP verifier), originally dropped from PR #2 (cuda: small-N batched kernel polish, stacked on #1) for MTP-strict drift, now correctly gated so it only replaces the bit-identical warp8 fallback (not the cuBLAS path).
- `af16691`: cuda: fuse head_rms_norm + rope_tail on Q (decode + batched paths), originally dropped for long-context `long_memory_archive` greedy drift, now uses `__fmul_rn` as an FMA-contraction barrier to reproduce the sequential reference's bit pattern.

What's still NOT in this PR (and why)
- `routed_moe small-N decode-LUT path` (the K5 candidate from the PR #2 investigation): drift root-caused but fix has unacceptable blast radius. The K5 down path uses `dev_dot_q2_K_q8_K_block` (single-pair variant: pre-computes `dall = y_d * x_d` then `dall * isum - dmin * summs`), while the upstream reference at n_tokens=2 uses `dev_dot_q2_K_q8_K_block8` (8-pair batched variant: inlines `yd * xd * isum[p] - yd * xmin * summs[p]`). Under `--use_fast_math`, nvcc emits different FMA contractions for these two formulations, producing different fp32 outputs from the same INT32 partial sums. Closing the gap would require either (a) a new K5-only down kernel inlining the 8-pair math but looping over the 6 expert slots, or (b) modifying the shared single-pair kernel, but `dev_dot_q2_K_q8_K_block` has 8 other callers across the MoE down stack (lines 9289, 9322, 9348, 9374, 9456, 9491, 9515, 9553), so a change there risks regressing every batched MoE path. Both options exceed the reasonable scope for this PR. Worth a separate, focused PR if/when someone wants to chase it.

Root-cause details
K4 — shared-weight Q8 matmul
Originally fired unconditionally at n_tok=2..4, replacing both the cuBLAS Gemm/Sgemm path and the `matmul_q8_0_preq_batch_warp8_kernel` fallback. The cuBLAS path uses tensor-core tiling that is numerically different from any custom Q8 warp kernel; the warp8 fallback uses the same per-block `wscale * xs * dot` operand order as the share-warp kernel, so K4 IS bit-identical to warp8 but NOT to cuBLAS.

Fix (7-line gating predicate at `ds4_cuda.cu:5996-6053`): the cache-pointer checks short-circuit on label/dim mismatch and respect the same allowed-cache predicates as the cuBLAS dispatch, so the gate fires precisely on the weights that cuBLAS would have skipped (cache miss → fallback to warp8). On those weights, share-warp matches warp8 bit-for-bit. Weights that hit cuBLAS stay on cuBLAS.
K1 — head_rms_norm + rope_tail fuse
The naive fused kernel applied the rms scale to `tail[i]`/`tail[i+1]` inline in the same register-level expression as the RoPE cos/sin multiply. Under `--use_fast_math` nvcc fused `scale * tail[i] * c` and `scale * tail[i+1] * sin` to single FMAs, which differs from the sequential reference path where `head_rms_norm_kernel` writes `xr[i] *= scale` to fp32 memory and `rope_tail_kernel` reads it back. The ULP-scale per-pair drift compounds across 30 layers × high `pos0` to flip argmaxes on long-context tests.

Fix (2 inline operand changes at `ds4_cuda.cu:2416-2425`): `__fmul_rn` is a hard FMA-contraction barrier that forces a single-rounded fp32 multiply, reproducing the bit pattern of the sequential reference's memory store-then-load.

Headline numbers (DGX Spark / GB10, ds4flash IQ2XXS-w2Q2K)
Single-prompt smoke benches (`./ds4 -n 256 -p "knight" --temp 0`), 3-run mean:

Marginal absolute gains because the kernels target paths that don't dominate single-token decode. The engineering value is the root-causing and fixing, not the t/s number: the K1 and K4 fixes are general-purpose techniques (FMA barriers, dispatch gating on cache-pointer availability) that future kernel work can reuse.
Tested against
- `make clean && make cuda-spark`: clean, no new warnings
- `./ds4_test --all`: only pre-existing `logprob-vectors short_code_completion` failure (also on `upstream/main`); tensor-equivalence summary: `capture_fail=0 logits_fail=0 greedy_fail=0 top1_mismatch=0`
- `make cuda-regression`: pre-existing build error (also on `upstream/main`), unchanged
- `make cpu`: clean build, no new warnings
- Output diff against `upstream/main` for plain decode (n=64, "knight"): 0 bytes
- Output diff against `upstream/main` for MTP `BATCH_VERIFY=1 STRICT=1` (n=64, "knight"): 0 bytes
- `DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf` (80.76 GiB)

Pre-existing non-determinism observation
Independent of this PR: the `mtp-beats-plain-kernels` (#2) baseline and `upstream/main` both show intermittent `long_memory_archive` greedy_fail in roughly 1 of 4 `ds4_test --all` runs. Source: MoE down path uses `atomicAdd` at `n_tokens >= 128` (prefill chunk size), which is not float-deterministic. The 5 dropped-kernel runs and 3 PR3-final runs show this same distribution: neither K1 nor K4 introduces it. Flagged here because it surfaced during the careful investigation that produced this PR; it's a separate concern worth its own ticket.

Out of scope / follow-ups
- `routed_moe small-N decode-LUT path` revival: would need a focused PR on the `dev_dot_q2_K_q8_K_block` / `block8` family
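The K5 drift comes from one expression of the form `a * b - c * d` being contracted differently in two kernels. The host-side sketch below illustrates that effect in isolation; it is not the `dev_dot_q2_K_q8_K_block` math, and `rounded_product`, `diff_unfused`, `diff_fma_first` and `diff_fma_second` are hypothetical helpers written for this example.

```cpp
#include <cmath>

// Product rounded to fp32 (the volatile store keeps it out of any FMA).
float rounded_product(float a, float b) {
    volatile float t = a * b;
    return t;
}

// No contraction: both products are rounded before the subtraction.
float diff_unfused(float a, float b, float c, float d) {
    return rounded_product(a, b) - rounded_product(c, d);
}

// Contraction choice 1: a * b stays exact inside the FMA, c * d is rounded.
float diff_fma_first(float a, float b, float c, float d) {
    return std::fma(a, b, -rounded_product(c, d));
}

// Contraction choice 2: c * d stays exact inside the FMA, a * b is rounded.
float diff_fma_second(float a, float b, float c, float d) {
    return std::fma(-c, d, rounded_product(a, b));
}
```

With all four inputs equal to `1 + 2^-12`, the mathematically exact answer is 0. The unfused form returns 0, the first contraction returns `2^-24`, and the second returns `-2^-24`, which is the kind of last-bit disagreement that separates the single-pair and 8-pair formulations.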