mtp: combined-forward speculative decode beats plain on GB10 (+2.4 t/s) (stacked on #11)#12
Closed
TrevorS wants to merge 1 commit into
Closed
mtp: combined-forward speculative decode beats plain on GB10 (+2.4 t/s) (stacked on #11)#12TrevorS wants to merge 1 commit into
TrevorS wants to merge 1 commit into
Conversation
9a81c1e to
c82328b
Compare
c82328b to
4b9587c
Compare
2330d1d to
36c1735
Compare
4b9587c to
9c55968
Compare
Makes `--mtp` faster than plain decode on DGX Spark / GB10 by replacing
the canonical MTP draft+verify sequence (eval first_token, then run the
verifier separately) with a single batched-N=2 forward over
[first_token, drafts[0]]. The verifier reads both tokens' logits in one
graph eval, so the cost amortizes.
Strict mode (--quality or DS4_MTP_STRICT=1) falls back to canonical
decode2_exact for byte-equality with plain decode.
Speed (ds4-bench standard sweep, promessi_sposi.txt, gen=128, GB10):
ctx plain --mtp delta
2048 14.24 16.13 +1.89
10240 14.04 15.56 +1.52
18432 13.88 14.82 +0.94
26624 13.59 14.77 +1.18
34816 12.95 14.91 +1.96
43008 12.79 13.79 +1.00
51200 12.57 12.95 +0.38
59392 12.32 12.91 +0.59
MTP is faster at every context; the margin tracks the speculative
accept rate, which varies with how predictable the continuation is
(prose like promessi_sposi sits on the lower end; chat-style prompts
accept more drafts and see +2 to +6 t/s). Full sweeps for both paths
are in speed-bench/gb10.csv (plain) and speed-bench/gb10_mtp.csv (--mtp).
What's in this commit:
1. Small-N matmul kernel polish (cuda):
- Pair-fuse Q_A + KV_A in qkv_rms_fused decode (one weight load, two
outputs).
- Fuse head_rms_norm + rope_tail on Q (decode + batched paths).
2. Q8 share-weight batched matmul (cuda):
- matmul_q8_0_preq_batch_share_warp_kernel<N_TOK>: one Q8 weight row
per warp, N_TOK F32 dot-products against N_TOK token activations.
Bit-equal to N=1 matmul_q8_0_preq_warp8_kernel at any block count
(same per-lane stride, same warp_sum_f32, explicit __fmaf_rn locks
the FMA contraction to match the N=1 SASS). Dispatched at
n_tok=2..4 under !DS4_MTP_STRICT; cuBLAS Q8 path under strict.
3. Combined-forward verifier (mtp):
- metal_graph_eval_mtp_draft_n_from_hc: batched MTP-draft primitive.
- ds4_session_eval_speculative_argmax_combined: single batched
verifier forward over [first_token, drafts[0]], accept drafts that
match target argmax, strict-mode fallback to decode2_exact.
- combined_prev_hc bootstrap from canonical eval for the first iter.
4. Dispatch + default (cli):
- K4 share-warp dispatch gate: replace the cuBLAS-cache-availability
check (always-false on Spark, where every Q8 weight has a cached
F16 copy) with a !DS4_MTP_STRICT gate, so the share-warp kernel
actually fires at decode time.
- mtp_draft_tokens default 1 -> 2. Combined-forward needs
mtp_draft_tokens==2 to fire; the previous default of 1 routed
--mtp through canonical decode2 with no measurable win.
5. Bench tooling (ds4-bench):
- Add --mtp FILE / --mtp-draft N so ds4-bench can drive the
speculative decode path (mirrors the CLI decode loop) and report
real --mtp throughput. Logs the chosen decode path. Default
--mtp-draft 2. speed-bench/gb10_mtp.csv generated with it.
Plain decode is unchanged by this commit (byte-identical to upstream
on the chat-formatted test prompts; the ds4-bench plain sweep matches
the no-MTP baseline at every context).
Tested:
- make clean && make cuda-spark clean
- make cpu clean
- ./ds4_test --long-context, --tool-call-quality, --server,
--metal-kernels OK
- ds4-bench plain + --mtp full 2048..65536 sweeps (CSVs included)
- Two ds4_test checks fail identically on raw upstream/main (bfe070a),
not introduced here:
--logprob-vectors short_code_completion: fixture continuation is
one greedy token off.
--metal-tensor-equivalence: intrinsically flaky on GB10 (batched-
prefill run-to-run non-determinism; raw upstream and the no-MTP
PR1 branch show the same behavior, so combined-forward does not
worsen it).
Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
MTP: DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf
9c55968 to
00644e0
Compare
This was referenced May 24, 2026
Owner
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
mtp: combined-forward speculative decode beats plain on GB10
Stacked on #11 (
clean-spark-backend).Summary
Makes
--mtpfaster than plain decode on DGX Spark / GB10 by replacing the canonical MTP draft+verify sequence (evalfirst_token, then run the verifier separately) with a single batched-N=2 forward over[first_token, drafts[0]]. The verifier reads both tokens' logits in one graph eval, so the cost amortizes.Strict mode (
--qualityorDS4_MTP_STRICT=1) falls back to canonicaldecode2_exactfor byte-equality with plain decode.Speed —
ds4-benchstandard sweep (promessi_sposi.txt, gen=128, GB10)ds4-benchnow takes--mtp(see below), so this is the same harness CONTRIBUTING.md specifies, just with speculative decode enabled. Full CSVs:speed-bench/gb10.csv(plain),speed-bench/gb10_mtp.csv(--mtp).--mtpMTP is faster at every context (+0.4 to +2.0 t/s on this prompt). The margin tracks the speculative accept rate, which depends on how predictable the continuation is — prose like
promessi_sposisits on the lower end.Prompt-dependence (chat-style prompts accept more drafts)
Measured with
ds4 -p ... -n 256 --temp 0(plain vs default--mtp), not in the CSV:--mtp"knight"(short)These are higher-accept-rate cases; the CSV table above is the conservative standard-prompt number.
What's in the PR
1. Small-N matmul kernel polish
qkv_rms_fuseddecode (one weight load, two outputs).head_rms_norm+rope_tailon Q.2. Q8 share-weight batched matmul
matmul_q8_0_preq_batch_share_warp_kernel<N_TOK>: one Q8 weight row per warp, N_TOK F32 dot-products against N_TOK token activations.matmul_q8_0_preq_warp8_kernelat any block count: same per-lane stride, samewarp_sum_f32, explicit__fmaf_rnlocks the FMA contraction to the N=1 SASS. Dispatched atn_tok=2..4under!DS4_MTP_STRICT.3. Combined-forward verifier
metal_graph_eval_mtp_draft_n_from_hc: batched MTP-draft primitive.ds4_session_eval_speculative_argmax_combined: single batched verifier forward over[first_token, drafts[0]], accept drafts matching target argmax, strict-mode fallback todecode2_exact.combined_prev_hcbootstrap from canonical eval for the first iter.4. Dispatch + default
!DS4_MTP_STRICTgate, so the share-warp kernel fires at decode time.mtp_draft_tokensdefault 1 → 2 (combined-forward needs==2to fire; the old default of 1 routed--mtpthrough canonicaldecode2with no win).5. Bench tooling
ds4-benchgains--mtp FILE/--mtp-draft N, mirroring the CLI decode loop so the standard harness can measure real--mtpthroughput. Logs the chosen decode path.speed-bench/gb10_mtp.csvgenerated with it.Correctness
DS4_MTP_STRICT=1/--quality: combined-forward declines, path is canonicaldecode2_exact. Plain decode byte-identical toupstream/main(chat-formatted test prompts). Theds4-benchplain sweep matches the no-MTP baseline at every context.Tested
make clean && make cuda-spark— cleanmake cpu— clean./ds4_test --long-context,--tool-call-quality,--server,--metal-kernels— OKds4-benchplain +--mtpfull 2048→65536 sweeps (both CSVs included)ds4_testchecks fail identically on rawupstream/main(bfe070a) — not introduced here:--logprob-vectors short_code_completion— fixture's official continuation is one greedy token off--metal-tensor-equivalence— intrinsically flaky on GB10 (batched-prefill run-to-run non-determinism; reproduces on raw upstream). Verified the combined-forward strip/changes don't worsen it: PR1 (no MTP changes) and raw upstream show the same behavior.Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model:
DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.ggufMTP:
DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.ggufAGENT.md compliance
--mtp-draftis pre-existing (only its default changes),DS4_MTP_STRICTis pre-existing.DS4_CUDA_NO_Q8_SHARE_BATCH=1opt-out preserved as the kill switch.