Skip to content

cuda: small-N batched kernel polish (stacked on #1)#2

Closed
TrevorS wants to merge 2 commits into
upstream-minimalfrom
mtp-beats-plain-kernels
Closed

cuda: small-N batched kernel polish (stacked on #1)#2
TrevorS wants to merge 2 commits into
upstream-minimalfrom
mtp-beats-plain-kernels

Conversation

@TrevorS
Copy link
Copy Markdown
Owner

@TrevorS TrevorS commented May 23, 2026

PR2 draft: cuda: incremental kernel polish on top of HBM-resident model (#1)

Summary

Two new commits stacked on upstream-minimal (#1 in this series, tip 3514c55). Builds on the HBM-resident model groundwork from #1; targets per-kernel polish in the decode hot-path.

The new commits:

  1. 63e11becuda: parallelize matmul_q8_0_hc_expand epilogue across n_hc lanes. Spreads the existing per-row reduction across n_hc parallel lanes; pure parallelism gain, no numerical change.
  2. b16124bcuda: pair-fuse Q_A + KV_A matmuls in qkv_rms_fused decode path. Fuses two adjacent Q8 matmuls into one shared kernel call, amortizing launch overhead.

What's NOT in this PR (and why)

Three other kernel candidates from the larger downstream fork were evaluated and dropped:

  • fuse head_rms_norm + rope_tail — dropped due to long-context greedy drift. Short-prompt byte-equality held but ds4_test --all's long_memory_archive tensor-equivalence case showed top-1 token mismatch from step 0, with cascading divergence through generation. RoPE FMA-order interaction at far positions surfaces only on long contexts. Worth pursuing in a future PR if the fusion can be reworked to preserve FMA order.
  • shared-weight Q8 matmul kernel — dropped due to MTP strict-mode drift on the polished cuda: HBM-resident model on DGX Spark / GB10 #1 base (independently confirmed). Byte-diff vs upstream/main for DS4_MTP_BATCH_VERIFY=1 DS4_MTP_STRICT=1 diverges starting at first token.
  • routed_moe small-N decode-LUT path — dropped due to MTP strict-mode drift at n_tokens=2..4. The LUT path produces numerically different output than the cuBLAS path it replaces in the batched verifier regime.

Each dropped commit was tested in isolation against the polished #1 base. Strict-mode byte-equality is the hard correctness bar per AGENT.md ("Do not keep a faster path with unexplained ... drift").

Headline numbers (DGX Spark / GB10, ds4flash IQ2XXS-w2Q2K)

Single-prompt smoke benches (./ds4 -n 256 -p "knight" --temp 0), 3-run mean:

Mode upstream-minimal (#1) this PR Δ
Plain decode 16.45 16.52 +0.4%
MTP batched-verifier strict 13.57 13.74 +1.3%

Gains are small in absolute terms because both kernels target paths that don't dominate single-token decode. The ds4-bench formal sweep (mostly N=1 decode at growing ctx) is within noise of #1 — PR2 doesn't change speed-bench/gb10.csv.

Shipping the two regardless because each is strict-clean and multi-run-stable. The narrow scope (one optimization per commit) keeps review simple.

Tested against

  • make clean && make cuda-spark — clean, no new warnings
  • ./ds4_test --all1 failure (logprob-vectors short_code_completion), same failure on upstream/main — pre-existing fixture drift
  • ./ds4_test --all tensor-equivalence: 0 greedy_fail, 0 top1_mismatch, 0 logits_fail on all 5 test cases (short_italian_fact, short_code_completion, short_reasoning_plain, long_memory_archive, long_code_audit)
  • make cuda-regressionpre-existing build error in tests/cuda_long_context_smoke.c (signature mismatch with ds4_gpu_attention_decode_heads_tensor), same error on upstream/main, not introduced by this PR
  • make cpu — clean build, no new warnings
  • Plain 3-run stability — all 3 runs produce byte-identical coherent output (16.51, 16.52, 16.53 t/s)
  • Byte-equality diff vs upstream/main for plain decode (n=64, "knight") — 0 bytes
  • Byte-equality diff vs upstream/main for MTP BATCH_VERIFY=1 STRICT=1 (n=64, "knight") — 0 bytes
  • Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
  • Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf (80.76 GiB)
  • MTP: DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf (3.5 GiB)

Notes

  • No new env knobs. Behavior changes are unconditional.
  • Both new kernels were validated for strict-mode byte equivalence using a stricter superset of just the short-prompt check: ds4_test --all's tensor-equivalence covers both short and long contexts (up to ~30K tokens), which catches drift the short-prompt diff misses.
  • Bisect path used during development: K3 alone clean → K3+K2 clean → K3+K2+K1 introduces long_memory_archive drift. K1 dropped, K3+K2 kept.

Out of scope / follow-ups

  • The 3 dropped kernel commits can return as future PRs if rewritten to preserve strict equivalence. The head_rms_norm + rope_tail fusion is the most likely salvage — its drift is FMA-ordering on long-position RoPE, recoverable by preserving the original op order inside the fused kernel.
  • Combined-forward MTP path (downstream fork) — has unresolved strict-mode drift; separate follow-up PR.
  • Captured-graph support for spec decode — independent subsystem, separate PR.

TrevorS added 2 commits May 24, 2026 10:03
…arget 1)

Replaces the lane-0-only HC epilogue in
matmul_q8_0_hc_expand_preq_warp8_kernel (ds4_cuda.cu:2193) with a
warp-shuffle parallel epilogue that uses lanes 0..n_hc-1 to share
residual_hc loads and compute n_hc outputs in parallel.

Bottleneck before
-----------------
After warp_sum_f32 reduces the dot product into lane 0, the original
epilogue (lines 2224-2240) collapsed to lane 0 doing:
  - n_hc serial reads of post[]
  - n_hc * n_hc reads of comb[]
  - n_hc * n_hc reads of residual_hc[] (each from a different HBM row)
  - n_hc writes of out_hc[]
For DSV4 (n_hc=4), that is 16 serial HBM reads of residual_hc per row,
plus 4 writes -- 31 of 32 lanes in the warp idle.

Change
------
Broadcast block_v (the post-reduction acc) from lane 0 to all lanes via
__shfl_sync.  Lanes 0..n_hc-1 each load ONE residual value from HBM,
then share residuals across lanes via warp shuffle (no shared memory
needed since n_hc <= 32).  Each of those lanes computes its dst_hc
output independently and writes in parallel.

For n_hc > warp size (not used by DSV4 but allowed by the kernel
signature), the old serial path is preserved as a fallback.

Result
------
Validated on DGX Spark (GB10, sm_121) via DS4_METAL_DECODE_STAGE_PROFILE
profile aggregation:

  attn_output stage (43-layer sum):
    pre:  14.62 ms/iter
    post: 14.42 ms/iter
    delta: -0.20 ms (~1.4% on the stage)

  Plain decode (n=32, --temp 0 --nothink, ds4flash.gguf, 3-run avg):
    pre:  16.01 t/s
    post: 16.11 t/s
    delta: +0.7%

The gain is real but smaller than the initial scout estimate (3-5x on
attn_output) because the lane-0 epilogue is only ~1.5% of the
matmul_q8_0_hc_expand kernel's per-row cost.  The dot product loop (256
Q8_0 blocks per row through dp4a) is bandwidth-bound and dominates the
kernel time (~47% of peak HBM bandwidth, ~98% of the kernel's wall
time).

Net effect on the 14.6 ms attn_output bucket: the kernel is still
bandwidth-bound on the dot product.  Further gains on this stage need
either vectorized Q8 weight loads or a different tile geometry --
larger LOC, deeper kernel work.

Parity
------
Generation output coherent and byte-identical to pre-change for the
deterministic "Once upon a time" prompt at --temp 0:
  "Once upon a time, in a land where the rivers sparkled like liquid
   sapphires and the trees whispered secrets to..."

Per-element output_hc may differ by 1-2 ULP vs the previous lane-0
serial path due to FMA-order differences from the per-lane reduction
vs the single-lane sequential reduction.  No observable token-level
divergence in the test suite (./ds4_test: tool-call-quality,
long-context, metal-kernels, server all pass; pre-existing
logprob-vectors failure unchanged).

LOC
---
ds4_cuda.cu: +30/-12 in matmul_q8_0_hc_expand_preq_warp8_kernel only.
No caller changes, no header changes, no launch-param changes.

NO github push.  jj change spnkxvoo -> ztnmlvyx.
In the qkv_rms_fused branch of metal_graph_encode_decode_layer, the Q_A
matmul (attn_q_a, DS4_N_EMBD -> q_rank) and KV_A matmul (attn_kv,
DS4_N_EMBD -> DS4_N_HEAD_DIM) ran as two back-to-back
ds4_gpu_matmul_q8_0_tensor calls.  Both read the same attn_norm input,
so they each triggered an independent prequantize of attn_norm and an
independent warp-of-8 matmul launch.

This commit replaces the pair with ds4_gpu_matmul_q8_0_pair_tensor (the
same primitive already used by the shared_gate/shared_up fusion), which
at n_tok=1 issues one prequantize_q8_0_f32_kernel + one
matmul_q8_0_pair_preq_warp8_kernel.  The pair kernel handles asymmetric
out0_dim/out1_dim (q_rank=768 vs DS4_N_HEAD_DIM=192) by gridding on
max(out0_dim, out1_dim).

Sites updated:
  - decode qkv_rms_fused branch (ds4.c:9371-9415): pair-fused path
  - decode qkv_rms_fused else  (ds4.c:9416-...): Q_A standalone retained

Sites NOT updated:
  - batched qkv_rms_fused (ds4.c:11331+): at n_tok>1 the pair primitive
    falls back to two sequential matmul_q8_0_preq_batch_warp8 calls,
    yielding identical behavior; touching it would be a no-op pending a
    separate batched-pair kernel rewrite.
  - decode !qkv_rms_fused (rare reference path): KV_A matmul is too far
    downstream of Q_A to share attn_norm prequantize cleanly.

Savings per layer at n_tok=1:
  - 1 quantize_q8_0_f32_kernel launch eliminated (~5us)
  - 1 matmul kernel launch eliminated (~5us)
  - prequantized x reused for both matmuls (in-kernel) -> 1 fewer
    DRAM read of attn_norm + 1 fewer scale-buffer write
At 30 layers, that's ~300-450us per token of theoretical headroom.

Bench (DGX Spark, ds4flash.gguf, --temp 0 --nothink, plain decode):
  Pre-fusion (prxlvzlq):  16.26 t/s (3-run stable)
  Post-fusion (lvlnlxsk): 16.28 t/s (3-run avg @ n=32),
                          15.80 t/s (3-run avg @ n=128)
The delta is within run-to-run variation -- same characterization as
the head_rms_norm + rope_tail fusion in the prior commit.  Launch
reductions compound and become substantially more material once
captured graphs are added (each eliminated launch becomes one fewer
graph node in the captured DAG).

Parity: generation output coherent ("Once upon a time, in a land where
the rivers sparkled like liquid sapp..." @ n=24, --temp 0).  The pair
kernel uses the same Q8_0 prequantize of attn_norm and the same dp4a
warp-of-8 reduction as the standalone matmul_q8_0_preq_warp8 kernel.
Output is byte-equal modulo FMA-reordering-scale differences that the
existing pair primitive has also exhibited in shared_gate/up usage.

Test suite: ds4_test logprob-vectors shows a pre-existing failure on
short_code_completion step 1 (assertion at tests/ds4_test.c:490),
reproduced on the prxlvzlq parent before this change.  Not introduced
here; tracking separately.  metal-kernels and server suites pass.

Header: added ds4_gpu_matmul_q8_0_pair_tensor declaration to ds4_gpu.h.
The wrapper has existed in ds4_cuda.cu since the shared-expert fusion
landed but was not declared in the public header.

NO github push.  jj change prxlvzlq -> lvlnlxsk.
@TrevorS TrevorS force-pushed the upstream-minimal branch from 3514c55 to 6e95cf6 Compare May 24, 2026 17:13
@TrevorS TrevorS force-pushed the mtp-beats-plain-kernels branch from b16124b to 86d01b9 Compare May 24, 2026 17:13
@TrevorS
Copy link
Copy Markdown
Owner Author

TrevorS commented May 24, 2026

Superseded by the reframed 2-PR stack (#11 + #12), which tells the same Spark/GB10 + MTP combined-forward story more concisely, rebased on current upstream/main, with the exploratory paths dropped.

@TrevorS TrevorS closed this May 24, 2026
@TrevorS TrevorS deleted the mtp-beats-plain-kernels branch May 24, 2026 22:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant