cuda: small-N batched kernel polish (stacked on #1)#2
Closed
TrevorS wants to merge 2 commits into
…arget 1)
Replaces the lane-0-only HC epilogue in
matmul_q8_0_hc_expand_preq_warp8_kernel (ds4_cuda.cu:2193) with a
warp-shuffle parallel epilogue that uses lanes 0..n_hc-1 to share
residual_hc loads and compute n_hc outputs in parallel.
Bottleneck before
-----------------
After warp_sum_f32 reduces the dot product into lane 0, the original
epilogue (lines 2224-2240) collapsed to lane 0 doing:
- n_hc serial reads of post[]
- n_hc * n_hc reads of comb[]
- n_hc * n_hc reads of residual_hc[] (each from a different HBM row)
- n_hc writes of out_hc[]
For DSV4 (n_hc=4), that is 16 serial HBM reads of residual_hc per row,
plus 4 writes -- 31 of 32 lanes in the warp idle.
Change
------
Broadcast block_v (the post-reduction acc) from lane 0 to all lanes via
__shfl_sync. Lanes 0..n_hc-1 each load ONE residual value from HBM,
then share residuals across lanes via warp shuffle (no shared memory
needed since n_hc <= 32). Each of those lanes computes its dst_hc
output independently and writes in parallel.
For n_hc > warp size (not used by DSV4 but allowed by the kernel
signature), the old serial path is preserved as a fallback.
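The change above can be modelled on the host in plain C. This is a sketch, not the kernel: the epilogue formula (`out[i] = post[i] * v + sum_j comb[i * n_hc + j] * residual[j]`), the row-major `comb` layout, and both function names are assumptions made for illustration. Reading `r_reg[j]` stands in for a `__shfl_sync` from lane `j`.

```c
#include <assert.h>

#define WARP_SIZE 32

/* Reference model of the old path: lane 0 does everything serially.
 * ASSUMED formula and comb layout; the real kernel's indexing may differ. */
void hc_epilogue_serial(float v, const float *post, const float *comb,
                        const float *residual, int n_hc, float *out) {
    for (int i = 0; i < n_hc; i++) {
        float acc = post[i] * v;
        for (int j = 0; j < n_hc; j++) {
            acc += comb[i * n_hc + j] * residual[j];
        }
        out[i] = acc;
    }
}

/* Host-side model of the warp-parallel path. Each array slot models one
 * register in one lane; reading slot j models a shuffle from lane j. */
void hc_epilogue_warp_model(float v_lane0, const float *post,
                            const float *comb, const float *residual,
                            int n_hc, float *out) {
    assert(n_hc >= 1 && n_hc <= WARP_SIZE);
    float v_reg[WARP_SIZE];
    float r_reg[WARP_SIZE];

    /* Step 1: broadcast the reduced dot product from lane 0 to all lanes. */
    for (int lane = 0; lane < WARP_SIZE; lane++) {
        v_reg[lane] = v_lane0;
    }
    /* Step 2: lanes 0..n_hc-1 each load exactly one residual value. */
    for (int lane = 0; lane < WARP_SIZE; lane++) {
        r_reg[lane] = lane < n_hc ? residual[lane] : 0.0f;
    }
    /* Step 3: each active lane computes and writes its own output. */
    for (int lane = 0; lane < n_hc; lane++) {
        float acc = post[lane] * v_reg[lane];
        for (int j = 0; j < n_hc; j++) {
            acc += comb[lane * n_hc + j] * r_reg[j];
        }
        out[lane] = acc;
    }
}
```

In this model the residual loads drop from n_hc * n_hc (all issued by one lane) to n_hc (one per lane), which is the saving the change describes.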
Result
------
Validated on DGX Spark (GB10, sm_121) via DS4_METAL_DECODE_STAGE_PROFILE
profile aggregation:
attn_output stage (43-layer sum):
pre: 14.62 ms/iter
post: 14.42 ms/iter
delta: -0.20 ms (~1.4% on the stage)
Plain decode (n=32, --temp 0 --nothink, ds4flash.gguf, 3-run avg):
pre: 16.01 t/s
post: 16.11 t/s
delta: +0.10 t/s (~0.6%)
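The percentages can be rechecked from the raw numbers with a one-line relative-change helper. `pct_change` is a hypothetical name used only for this check.

```c
#include <assert.h>

/* Relative change from `pre` to `post`, in percent. */
double pct_change(double pre, double post) {
    assert(pre != 0.0);
    return (post - pre) / pre * 100.0;
}
```

For the figures above, `pct_change(14.62, 14.42)` is about -1.37 (the attn_output stage) and `pct_change(16.01, 16.11)` is about +0.62 (plain decode).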
The gain is real but smaller than the initial scout estimate (3-5x on
attn_output) because the lane-0 epilogue is only ~1.5% of the
matmul_q8_0_hc_expand kernel's per-row cost. The dot product loop (256
Q8_0 blocks per row through dp4a) is bandwidth-bound and dominates the
kernel time (~47% of peak HBM bandwidth, ~98% of the kernel's wall
time).
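Amdahl's law makes the ceiling explicit. The sketch below takes the ~1.5% epilogue share quoted above as the accelerated fraction `f`. Treating the stage time as dominated by this one kernel is an assumption, and the function names are illustrative.

```c
#include <assert.h>

/* Amdahl's law: overall speedup when a fraction f of the runtime is
 * accelerated by a factor s. */
double amdahl_speedup(double f, double s) {
    assert(f >= 0.0 && f <= 1.0 && s > 0.0);
    return 1.0 / ((1.0 - f) + f / s);
}

/* Upper bound as s grows without limit: the accelerated part costs nothing. */
double amdahl_limit(double f) {
    assert(f >= 0.0 && f < 1.0);
    return 1.0 / (1.0 - f);
}
```

With f = 0.015, a 4x faster epilogue gives about 1.011x overall, and even a free epilogue caps out near 1.015x. The measured ~1.4% stage improvement sits just under that bound.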
Net effect on the 14.6 ms attn_output bucket: the kernel is still
bandwidth-bound on the dot product. Further gains on this stage need
either vectorized Q8 weight loads or a different tile geometry --
larger LOC, deeper kernel work.
Parity
------
Generation output coherent and byte-identical to pre-change for the
deterministic "Once upon a time" prompt at --temp 0:
"Once upon a time, in a land where the rivers sparkled like liquid
sapphires and the trees whispered secrets to..."
Per-element output_hc may differ by 1-2 ULP vs the previous lane-0
serial path due to FMA-order differences from the per-lane reduction
vs the single-lane sequential reduction. No observable token-level
divergence in the test suite (./ds4_test: tool-call-quality,
long-context, metal-kernels, server all pass; pre-existing
logprob-vectors failure unchanged).
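A 1-2 ULP tolerance can be checked mechanically. The helper below is a minimal sketch, assuming IEEE-754 binary32 floats and finite, non-NaN inputs. It is not part of ds4_test, and the names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Map a float's sign-magnitude bit pattern to a monotonically ordered
 * integer so that adjacent floats differ by exactly 1. */
static int64_t ordered_bits(float x) {
    uint32_t u;
    memcpy(&u, &x, sizeof u);
    return (u & 0x80000000u) ? -(int64_t)(u & 0x7fffffffu) : (int64_t)u;
}

/* Number of representable floats between a and b (0 if equal). */
int64_t ulp_distance(float a, float b) {
    int64_t d = ordered_bits(a) - ordered_bits(b);
    return d < 0 ? -d : d;
}

/* Worst-case ULP distance between a reference and a candidate buffer. */
int64_t max_ulp_distance(const float *ref, const float *got, int n) {
    assert(n >= 0);
    int64_t worst = 0;
    for (int i = 0; i < n; i++) {
        int64_t d = ulp_distance(ref[i], got[i]);
        if (d > worst) worst = d;
    }
    return worst;
}
```

A parity check in this style would compare the outputs of the old and new epilogues and require `max_ulp_distance(...) <= 2`.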
LOC
---
ds4_cuda.cu: +30/-12 in matmul_q8_0_hc_expand_preq_warp8_kernel only.
No caller changes, no header changes, no launch-param changes.
NO github push. jj change spnkxvoo -> ztnmlvyx.
In the qkv_rms_fused branch of metal_graph_encode_decode_layer, the Q_A
matmul (attn_q_a, DS4_N_EMBD -> q_rank) and KV_A matmul (attn_kv,
DS4_N_EMBD -> DS4_N_HEAD_DIM) ran as two back-to-back
ds4_gpu_matmul_q8_0_tensor calls. Both read the same attn_norm input,
so they each triggered an independent prequantize of attn_norm and an
independent warp-of-8 matmul launch.
This commit replaces the pair with ds4_gpu_matmul_q8_0_pair_tensor (the
same primitive already used by the shared_gate/shared_up fusion), which
at n_tok=1 issues one prequantize_q8_0_f32_kernel + one
matmul_q8_0_pair_preq_warp8_kernel. The pair kernel handles asymmetric
out0_dim/out1_dim (q_rank=768 vs DS4_N_HEAD_DIM=192) by gridding on
max(out0_dim, out1_dim).
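The asymmetric gridding can be illustrated with a float-only host model. This is a sketch under stated assumptions: the real kernel works on Q8_0 blocks with dp4a and one warp per row, and `matvec_pair_model` and its row-major weight layout are hypothetical.

```c
#include <assert.h>

/* Host model of a pair mat-vec that shares one input vector.
 * Iterates over max(out0_dim, out1_dim) "grid rows"; a row past the end
 * of the smaller matrix skips that output. Returns the grid size. */
int matvec_pair_model(const float *w0, int out0_dim,
                      const float *w1, int out1_dim,
                      const float *x, int in_dim,
                      float *out0, float *out1) {
    assert(out0_dim >= 0 && out1_dim >= 0 && in_dim >= 0);
    int grid_rows = out0_dim > out1_dim ? out0_dim : out1_dim;
    for (int row = 0; row < grid_rows; row++) {
        float acc0 = 0.0f;
        float acc1 = 0.0f;
        for (int k = 0; k < in_dim; k++) {
            float xk = x[k]; /* x is read once and feeds both products */
            if (row < out0_dim) acc0 += w0[row * in_dim + k] * xk;
            if (row < out1_dim) acc1 += w1[row * in_dim + k] * xk;
        }
        if (row < out0_dim) out0[row] = acc0;
        if (row < out1_dim) out1[row] = acc1;
    }
    return grid_rows;
}
```

With the dimensions quoted above (q_rank=768, DS4_N_HEAD_DIM=192) the grid has 768 rows, and rows 192..767 only produce the Q_A output.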
Sites updated:
- decode qkv_rms_fused branch (ds4.c:9371-9415): pair-fused path
- decode qkv_rms_fused else (ds4.c:9416-...): Q_A standalone retained
Sites NOT updated:
- batched qkv_rms_fused (ds4.c:11331+): at n_tok>1 the pair primitive
falls back to two sequential matmul_q8_0_preq_batch_warp8 calls,
yielding identical behavior; touching it would be a no-op pending a
separate batched-pair kernel rewrite.
- decode !qkv_rms_fused (rare reference path): KV_A matmul is too far
downstream of Q_A to share attn_norm prequantize cleanly.
Savings per layer at n_tok=1:
- 1 quantize_q8_0_f32_kernel launch eliminated (~5us)
- 1 matmul kernel launch eliminated (~5us)
- prequantized x reused for both matmuls (in-kernel) -> 1 fewer
DRAM read of attn_norm + 1 fewer scale-buffer write
At 30 layers, that's ~300-450us per token of theoretical headroom.
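The launch-overhead part of that estimate is simple multiplication. The helper below is an illustrative sketch with a hypothetical name. It counts only launch overhead and ignores the DRAM read and scale-buffer write that are also saved.

```c
#include <assert.h>

/* Launch overhead saved per token, in microseconds. */
double launch_savings_us(int n_layers, int launches_saved_per_layer,
                         double us_per_launch) {
    assert(n_layers >= 0 && launches_saved_per_layer >= 0);
    assert(us_per_launch >= 0.0);
    return (double)n_layers * launches_saved_per_layer * us_per_launch;
}
```

Two launches of ~5 us each give ~10 us per layer: `launch_savings_us(30, 2, 5.0)` is 300 us, the lower end of the quoted range, and the 43-layer sum used in the attn_output profile of the earlier commit would give 430 us.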
Bench (DGX Spark, ds4flash.gguf, --temp 0 --nothink, plain decode):
Pre-fusion (prxlvzlq): 16.26 t/s (3-run stable)
Post-fusion (lvlnlxsk): 16.28 t/s (3-run avg @ n=32),
15.80 t/s (3-run avg @ n=128)
The delta is within run-to-run variation -- same characterization as
the head_rms_norm + rope_tail fusion in the prior commit. Launch
reductions compound and become substantially more material once
captured graphs are added (each eliminated launch becomes one fewer
graph node in the captured DAG).
Parity: generation output coherent ("Once upon a time, in a land where
the rivers sparkled like liquid sapp..." @ n=24, --temp 0). The pair
kernel uses the same Q8_0 prequantize of attn_norm and the same dp4a
warp-of-8 reduction as the standalone matmul_q8_0_preq_warp8 kernel.
Output matches the standalone path up to FMA-reordering-scale
differences, which the existing pair primitive already exhibits in its
shared_gate/shared_up usage.
Test suite: ds4_test logprob-vectors shows a pre-existing failure on
short_code_completion step 1 (assertion at tests/ds4_test.c:490),
reproduced on the prxlvzlq parent before this change. Not introduced
here; tracking separately. metal-kernels and server suites pass.
Header: added ds4_gpu_matmul_q8_0_pair_tensor declaration to ds4_gpu.h.
The wrapper has existed in ds4_cuda.cu since the shared-expert fusion
landed but was not declared in the public header.
NO github push. jj change prxlvzlq -> lvlnlxsk.
PR2 draft: cuda: incremental kernel polish on top of HBM-resident model (#1)
Summary
Two new commits stacked on upstream-minimal (#1 in this series, tip 3514c55). They build on the HBM-resident model groundwork from #1 and target per-kernel polish in the decode hot path.
The new commits:
- 63e11be - cuda: parallelize matmul_q8_0_hc_expand epilogue across n_hc lanes. Spreads the existing lane-0 epilogue across n_hc parallel lanes. Pure parallelism gain: per-element results may differ by 1-2 ULP from FMA reordering, with no token-level change.
- b16124b - cuda: pair-fuse Q_A + KV_A matmuls in qkv_rms_fused decode path. Fuses two adjacent Q8 matmuls into one shared kernel call, amortizing launch overhead.
What's NOT in this PR (and why)
Three other kernel candidates from the larger downstream fork were evaluated and dropped:
- fuse head_rms_norm + rope_tail: dropped due to long-context greedy drift. Short-prompt byte-equality held, but the long_memory_archive tensor-equivalence case in ds4_test --all showed a top-1 token mismatch from step 0, with cascading divergence through generation. The RoPE FMA-order interaction at far positions surfaces only on long contexts. Worth pursuing in a future PR if the fusion can be reworked to preserve FMA order.
- shared-weight Q8 matmul kernel: dropped due to MTP strict-mode drift on the polished #1 base (independently confirmed). The byte-diff vs upstream/main for DS4_MTP_BATCH_VERIFY=1 DS4_MTP_STRICT=1 diverges starting at the first token.
- routed_moe small-N decode-LUT path: dropped due to MTP strict-mode drift at n_tokens=2..4. The LUT path produces numerically different output from the cuBLAS path it replaces in the batched verifier regime.
Each dropped commit was tested in isolation against the polished #1 base. Strict-mode byte-equality is the hard correctness bar per AGENT.md ("Do not keep a faster path with unexplained ... drift").
Headline numbers (DGX Spark / GB10, ds4flash IQ2XXS-w2Q2K)
Single-prompt smoke benches (./ds4 -n 256 -p "knight" --temp 0), 3-run mean:
Gains are small in absolute terms because both kernels target paths that don't dominate single-token decode. The ds4-bench formal sweep (mostly N=1 decode at growing ctx) is within noise of #1, so PR2 doesn't change speed-bench/gb10.csv.
Shipping the two regardless because each is strict-clean and multi-run-stable. The narrow scope (one optimization per commit) keeps review simple.
Tested against
- make clean && make cuda-spark: clean, no new warnings
- ./ds4_test --all: 1 failure (logprob-vectors short_code_completion), same failure on upstream/main (pre-existing fixture drift)
- ./ds4_test --all tensor-equivalence: 0 greedy_fail, 0 top1_mismatch, 0 logits_fail on all 5 test cases (short_italian_fact, short_code_completion, short_reasoning_plain, long_memory_archive, long_code_audit)
- make cuda-regression: pre-existing build error in tests/cuda_long_context_smoke.c (signature mismatch with ds4_gpu_attention_decode_heads_tensor), same error on upstream/main, not introduced by this PR
- make cpu: clean build, no new warnings
- Byte-diff vs upstream/main for plain decode (n=64, "knight"): 0 bytes
- Byte-diff vs upstream/main for MTP BATCH_VERIFY=1 STRICT=1 (n=64, "knight"): 0 bytes
- DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf (80.76 GiB)
- DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf (3.5 GiB)
Notes
The tensor-equivalence cases in ds4_test --all cover both short and long contexts (up to ~30K tokens), which catches drift the short-prompt diff misses.
Out of scope / follow-ups
The head_rms_norm + rope_tail fusion is the most likely salvage: its drift is FMA-ordering on long-position RoPE, recoverable by preserving the original op order inside the fused kernel.