mtp: combined-forward default + Option-B strict fallback (stacked on #5)#6
Closed
TrevorS wants to merge 7 commits into
Closed
mtp: combined-forward default + Option-B strict fallback (stacked on #5)#6TrevorS wants to merge 7 commits into
TrevorS wants to merge 7 commits into
Conversation
TrevorS
added a commit
that referenced
this pull request
May 24, 2026
…rp dispatch PR #3's K4 commit (584de5e on the v2 branch) added a cuBLAS-cache- availability gate to the share-warp Q8 matmul dispatch in cuda_matmul_q8_0_tensor_labeled. The intent was correctness: prevent share-warp from displacing cuBLAS where cuBLAS would have handled the weight. On DGX Spark this gate is empirically always-false: every Q8 weight has a cached F16 copy by the time the dispatcher runs, so share-warp NEVER fires at n_tok=2..4 with blocks<=32 -- which means the Q_B and output_proj_b matmuls (the only V4-Flash matmuls with blocks=32) silently route through cuBLAS' small-M tensor-core path which pads M=2..4 to M=16 and wastes ~7/8 of the inner-product work. The actual correctness concern is narrower: under DS4_MTP_STRICT (or --quality), users require byte-equality with plain decode. Share-warp is not bit-identical to cuBLAS Gemm at small M (different reduction order), so strict-mode must fall through to cuBLAS. In non-strict mode this drift is acceptable -- it matches PR #6's combined-forward Option-B pattern (same env knob selects strict vs perf). Replace the cuBLAS-cache-availability check with `!strict_mtp_env`. Same opt-out shape as the combined-forward gate in ds4_session_eval_speculative_argmax. The `blocks <= 32u` constraint is preserved (share-warp is bit-equal to N=1 warp8 only at blocks<=32; larger block counts drift from the batch_warp8 reference and would fail ds4_test --all long-context tensor equivalence -- verified empirically during bisect). Bench impact (DGX Spark, ds4flash, n=256, "knight" prompt): Default `--mtp` (combined-forward fires) 15.6 -> 16.20 t/s (+3.8%) `DS4_MTP_STRICT=1 --mtp` (canonical) 13.83 -> 13.83 unchanged Plain decode 16.60 -> 16.60 unchanged Strict-mode byte-equality vs PR #6 baseline confirmed (diff empty). `./ds4_test --all` shows the same 1 pre-existing failure as PR #6 (`logprob-vectors short_code_completion`, also fails on upstream/main).
Adds Phase 3 (prefix-2 compressor-state capture) as load-bearing
infrastructure for cheap K>=2 partial-accept rollback, plus a working
Phase 1 combined N=2 forward path (opt-in via DS4_MTP_COMBINED_FORWARD)
that folds eval(first_token) into a single batched 2-row forward.
Default behavior unchanged.
Phase 3: prefix-2 capture (load-bearing)
- 4 new ds4_gpu_graph fields: spec_prefix2_attn_state_kv/score[L],
spec_prefix2_index_state_kv/score[L], counters spec_prefix2_n_comp/
n_index_comp[L], bool spec_capture_prefix2.
- Alloc/free mirror prefix-1 alongside existing init/cleanup loops.
- metal_graph_capture_prefix2_{attn,index}_state functions; mirror
prefix-1 and trigger at t==1 in encode_layer_batch.
- spec_frontier_commit_prefix2 cheap commit (~0.5-1ms per Spark
numbers vs ~7-12ms for full snapshot+restore).
- metal_graph_verify_suffix_tops grows capture_prefix2 param;
existing callers pass false.
- spec_argmax_snapshot_combined_prev_hc helper + graph field
combined_prev_hc + session field combined_prev_hc_valid. Canonical
full-accept and 1-accept(prefix1) paths now snapshot the post-commit
HC row so a future combined iter can derive drafts[0] correctly.
Phase 1: combined N=2 forward (opt-in, off by default)
- ds4_session_eval_speculative_argmax_combined helper: forwards
[first_token, drafts[0]] in a single batched 2-row pass. Effectively
K=1 spec with first_token folded into the verifier call (saving
eval's separate raw_swa forward). Gated by DS4_MTP_COMBINED_FORWARD
env; engages only when prior iter's combined_prev_hc is valid.
- drafts[0] derivation: MTP-block(prev_hc=combined_prev_hc,
token=first_token, pos=base_pos) produces the correct prediction
for position base_pos+1 given that combined_prev_hc holds
post-position-(base_pos-1) HC (= post-last-committed-token HC from
prior iter).
- Cold start (no combined_prev_hc_valid) falls back to canonical so
the first iter primes the snapshot for the next combined iter.
- Partial-accept rollback via spec_frontier_commit_prefix1 for
commit=0 (drafts[0] rejected, first_token only committed).
Path tried and reverted: N=3 combined with chained drafts[1]
An earlier draft used N=3 verifier on [first_token, drafts[0..1]]
and chained MTP for drafts[1]. Bench showed drafts[1] systematically
rejected (100% miss) because the chained MTP call inherits a
mtp_state_hc from drafts[0]'s MTP-block, which used prev_hc =
combined_prev_hc (= post-(p-1) HC). The MTP layer's "informed"
prediction semantic expects prev_hc = post-p HC (= post-first_token
HC, which canonical obtains via raw_swa-forward of first_token).
Cascade: drafts[1] uses a slightly-off baseline → diverges from
target's row_tops[1]. Cannot be fixed on mainline without either
(a) forwarding first_token first (= the eval cost we are trying to
skip), or (b) fused main-forward + MTP-block kernels (out of scope,
not present in Spark either). N=2 mode skips drafts[1] and pays
only for the necessary 2-row verifier.
Bench (DGX Spark, ds4flash.gguf + MTP, --temp 0 --mtp-draft 2):
Plain (no MTP): 16.78 t/s — no regression.
Canonical K=2 full-accept iter: ~124ms / 3 tokens = ~24 t/s peak.
Canonical K=2 margin-skip iter: ~68ms / 2 tokens = ~29 t/s peak.
Combined N=2 commit=1 iter: ~125ms / 2 tokens = ~16 t/s.
Combined N=2 commit=0 iter: ~125ms / 1 token = ~8 t/s.
Combined acceptance rate ~83% on test prompt; average ~14.7 t/s.
Output byte-equal to canonical for 16-token sweep.
Architectural note on K>=2 combined ambition:
Reaching canonical's full-accept 3-tokens-per-iter ceiling in the
combined path requires fixing the drafts[1+] cascade, which requires
kernel-level work to either (i) make batched-N=1 byte-equal to
raw_swa (Spark's unresolved dense_step_n drift territory) or (ii)
fuse MTP-block into the main forward kernel. Phase 3 infrastructure
(prefix-2 capture, combined_prev_hc snapshot) is reusable for those
follow-ups.
NO github push. 4-commit jj stack on top of main.
Adds spec_argmax_bootstrap_combined_prev_hc which copies the post-final- layer HC of the just-evaluated token (g->cur_hc) into combined_prev_hc after every ds4_session_eval inside spec_argmax. Cost is one device-to- device tensor copy of N_HC * N_EMBD * 4 bytes = 112 KiB; negligible in the 200ms-per-iter speculative path. Why --- The combined-N=K speculative path (DS4_MTP_COMBINED_FORWARD) returns -2 on cold start when !s->combined_prev_hc_valid. Without this bootstrap combined_prev_hc is only populated by combined itself (line 18223), so the cold-start bailout is permanent: spec_argmax_combined falls through to canonical every iter, and canonical never sets combined_prev_hc, so combined is structurally dead. After this commit any canonical iter populates combined_prev_hc with the post-final-layer HC of first_token, which is the correct prev_hc for the NEXT iter's MTP-block call: drafts[0] = mtp_block(prev_hc=post-first_token-HC, token=NEW_first_token) This is what spark already does in its forward_mtp_step path. Effect ------ - Canonical K=2 strict accept rate unchanged (83.9%). - Bench (n=256, 3-run avg): canonical K=2 strict 7.59 t/s, combined K=1 effective 8.06 t/s. Combined no longer cold-starts every iter. - Combined-N=2 path still does NOT beat plain decode (15.83 t/s) -- it costs ~130ms for batched-N=2 forward vs ~65ms for single decode, and K=1 effective limits accepted tokens. Phase 4 batched-kernel speedup work is the next layer. LOC --- ds4.c: +24/-1 (one new helper + one call site). No header changes. DS4_MTP_COMBINED_FORWARD is still opt-in. NO github push. jj change oxqszsyw.
…hot)
Before this change, ds4_session_eval_speculative_argmax forced
spec_frontier_snapshot in DS4_MTP_STRICT mode at K=2. That snapshot
copies per-layer compressor state (raw + indexer KV/score tensors for
the 21 layers that run compression) which totals ~6 MiB of GPU memory
per spec iteration. At 39 iters per 128 generated tokens that's
~240 MiB of HBM copying purely for rollback insurance.
The capture_prefix1 path (already used by default in non-strict mode
and by combined-forward) does the same rollback work via cheap per-layer
counter resets in spec_frontier_commit_prefix1, with no GPU tensor
copies. It is byte-correct on accept and on partial-accept because
the compressor caches are append-only -- rejected rows become invisible
once the counters rewind.
This commit removes the strict-mode special case and uses
capture_prefix1 unconditionally at K=2. DS4_MTP_FORCE_SNAPSHOT still
forces the old full-snapshot path for measurement / regression
checking; DS4_MTP_CAPTURE_PREFIX1 becomes a no-op (the default).
Bench
-----
DGX Spark (GB10), ds4flash.gguf + MTP-Q4K, n=128, --mtp-draft 2,
DS4_MTP_STRICT=1, "knight" prompt, 3-run avg:
Pre (force snapshot): 5.20 t/s gen, 80.8% accept
Post (capture_prefix1): 5.49 t/s gen, 80.8% accept
Delta: +5.6%
Accept rate is identical (80.8% in both modes) -- the rollback
semantics are equivalent at K=2. The 5.6% win comes purely from
eliminating the per-iter HBM snapshot copy.
Non-strict K=2 (margin=3 default): unchanged -- capture_prefix1 was
already true there. Plain decode: 16.0 t/s, unchanged.
Tests: ds4_test passes (long-context, tool-call-quality, metal-kernels,
server). Pre-existing logprob-vectors failure unchanged.
What this doesn't fix
---------------------
Even after this fix, mainline MTP K=2 strict mode is ~5.5 t/s vs spark's
~15.9 t/s -- still 3x slower than plain decode. The remaining gap is
not in snapshot/restore (now removed) nor in the verifier compute path
(scout audited: kernels are tight). Most likely candidates for the
residual gap:
- Per-iter setup overhead (token_vec_push, checkpoint mgmt) that
accumulates across many short spec iters
- The `ds4_session_eval(first_token)` full-decode call at line 18312
happens BEFORE the spec phase, paying full plain-decode cost per
iter just to advance state with the verifier's known-good token
- Spark batches first_token into the K-batched forward (combined
mode), avoiding the separate eval call; mainline's combined mode
is opt-in via DS4_MTP_COMBINED_FORWARD and has its own issues
LOC
---
ds4.c: +4/-3 (one boolean simplification + clarifying comment).
NO github push. jj change kmssvuql -> zuxwytmo.
Two related changes packed together.
1. Combined-forward K=2 wiring
-----------------------------
`spec_argmax_combined` now also handles draft_cap=2 (N=3 batched verify
over [first_token, drafts[0], drafts[1]]) with prefix-2 commit dispatch
for the commit ∈ {0, 1, 2} cases.
Gated behind DS4_MTP_COMBINED_K2=1 because measurement shows the K=2
variant is currently a loss on mainline:
Combined K=1 (N=2 batched): 9.51 t/s (no flag, /dev/null)
Combined K=2 (N=3 batched): 7.34 t/s (DS4_MTP_COMBINED_K2=1)
Why K=2 loses: `drafts[1]` cascades from `drafts[0]`'s MTP-state, but
`drafts[0]` itself comes from `combined_prev_hc` (= post-previous-iter-
last-token HC), not from the fresh post-`first_token` main-HC the
canonical eval(first_token) would produce. So `drafts[0]` is "one
position stale" already, and `drafts[1]` cascades further off-target.
The target verifier rejects `drafts[1]` in the vast majority of iters,
so the extra batched-N=3 row costs more than it pays.
Keeping the K=2 path as opt-in because the prefix-2 wiring is correct
and reusable when the staleness fix lands (interleaved MTP-block inside
batched main forward, spark-style). See PHASE4.md item #1.
2. Session-cached spec_row_logits buffer
----------------------------------------
Adds `s->spec_row_logits_buf` (3 * VOCAB f32 = ~1.5 MiB) and
`s->spec_row_tops_buf` (3 * int) allocated at session creation,
replacing the per-spec-call xmalloc/free pattern in
`ds4_session_eval_speculative_argmax_combined`.
Measurement impact: small (~0-3% in noise). The malloc overhead
hypothesis was a wrong guess at what was producing the 73 ms per-call
overhead between component-timed (~95 ms) and observed wall
(~168 ms) combined cost. Documented in PHASE4.md item #3 -- the
actual source of that overhead is still unidentified after this
attempt.
Effect on default combined K=1: 9.51 -> 9.48 t/s (within noise).
Foundation for future xmalloc cleanup in the canonical path's
decode2_exact branch (still allocates per-call).
LOC
---
ds4.c: +67/-35 (combined K=2 dispatch + session buf fields + alloc/free
sites + caller changes). Two new session fields, two new env gates.
NO github push. jj change oxmoztuq.
…canonical) Replaces the env-opt-in (DS4_MTP_COMBINED_FORWARD) gate with a default-on gate for non-strict speculative decode. Strict mode (DS4_MTP_STRICT=1 or e->quality) now falls back to the canonical decode2_exact path so byte equality with plain decode is preserved. Batched-N MoE / attention is not yet bit-identical to N=1 raw_swa (would need DS4_CUDA_STRICT_BATCHED + closing residual attention divergence), so combined-forward can drift one token from canonical. Output remains coherent; strict callers who want canonical bytes opt in via DS4_MTP_STRICT=1. No new env flag introduced: this just defaults the existing path on for non-strict and routes strict to the existing canonical path.
ds4-bench measures plain (non-MTP) decode so numbers are within noise of upstream — combined-forward only fires under --mtp. Refreshed for provenance under the new tip.
ae95ce7 to
4ef4054
Compare
bb85073 to
dcb5cc3
Compare
TrevorS
added a commit
that referenced
this pull request
May 24, 2026
…rp dispatch PR #3's K4 commit (584de5e on the v2 branch) added a cuBLAS-cache- availability gate to the share-warp Q8 matmul dispatch in cuda_matmul_q8_0_tensor_labeled. The intent was correctness: prevent share-warp from displacing cuBLAS where cuBLAS would have handled the weight. On DGX Spark this gate is empirically always-false: every Q8 weight has a cached F16 copy by the time the dispatcher runs, so share-warp NEVER fires at n_tok=2..4 with blocks<=32 -- which means the Q_B and output_proj_b matmuls (the only V4-Flash matmuls with blocks=32) silently route through cuBLAS' small-M tensor-core path which pads M=2..4 to M=16 and wastes ~7/8 of the inner-product work. The actual correctness concern is narrower: under DS4_MTP_STRICT (or --quality), users require byte-equality with plain decode. Share-warp is not bit-identical to cuBLAS Gemm at small M (different reduction order), so strict-mode must fall through to cuBLAS. In non-strict mode this drift is acceptable -- it matches PR #6's combined-forward Option-B pattern (same env knob selects strict vs perf). Replace the cuBLAS-cache-availability check with `!strict_mtp_env`. Same opt-out shape as the combined-forward gate in ds4_session_eval_speculative_argmax. The `blocks <= 32u` constraint is preserved (share-warp is bit-equal to N=1 warp8 only at blocks<=32; larger block counts drift from the batch_warp8 reference and would fail ds4_test --all long-context tensor equivalence -- verified empirically during bisect). Bench impact (DGX Spark, ds4flash, n=256, "knight" prompt): Default `--mtp` (combined-forward fires) 15.6 -> 16.20 t/s (+3.8%) `DS4_MTP_STRICT=1 --mtp` (canonical) 13.83 -> 13.83 unchanged Plain decode 16.60 -> 16.60 unchanged Strict-mode byte-equality vs PR #6 baseline confirmed (diff empty). `./ds4_test --all` shows the same 1 pre-existing failure as PR #6 (`logprob-vectors short_code_completion`, also fails on upstream/main).
Owner
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
PR6: mtp: combined-forward default + Option-B strict fallback (stacked on #5)
Summary
Makes combined-forward MTP the default speculative decode path. Folds
eval(first_token)+ verify into a single batched-N=2 forward, saving one decode worth of compute per spec iter. Faithful port of the same path on the downstreammtp-beats-plainbranch; provides architectural symmetry with how upstream already handles MTP-related dispatch viaDS4_MTP_BATCH_VERIFY.Strict mode (
DS4_MTP_STRICT=1or--quality) falls back internally to canonicaldecode2_exact, preserving byte-equality to plain decode.No new env knobs.
DS4_MTP_STRICTis pre-existing. The previous opt-in flagDS4_MTP_COMBINED_FORWARDis removed entirely. Combined becomes the default code path; strict-mode fallback is internal logic.Commit stack
d1c7078— mtp: extract metal_graph_eval_mtp_draft_n_from_hc — batched MTP-draft primitive0660011— mtp: prefix-2 capture + combined N=2 forward138d93a— mtp: bootstrap combined_prev_hc from canonical evale004169— mtp: use capture_prefix1 even in strict mode for K=2 (skip full snapshot)a5141a3— mtp: extend combined-forward to K=2 + cache spec_row_logits buffer33bb699— mtp: default combined-forward on (Option-B strict-mode falls back to canonical) — the gate changebb85073— speed-bench: refresh gb10.csvHonest performance numbers (DGX Spark / GB10, n=256, "knight" prompt)
--mtp)--mtpDS4_MTP_STRICT=1 --mtpReference: the downstream
origin/mtp-beats-plainbranch (which has the same combined-forward path) benches at 16.11 t/s for default--mtpon this hardware — this PR is parity with the source.Important caveats:
--mtpis currently slightly slower than plain decode (15.56 vs 16.60). The combined-forward design saves compute per iter but the batched-N forward has overhead that, on this hardware/kernel-mix, doesn't fully amortize against plain decode's simpler single-token path.What this PR is and isn't
upstream/mainfor plain decode and for strict-mode MTPAGENT.md compliance
DS4_MTP_COMBINED_FORWARDknob is removed. Default behavior changes (combined becomes the path); strict-mode fallback is internal logic gated on the pre-existingDS4_MTP_STRICT.Tested against
make clean && make cuda-spark— clean, no warningsmake cpu— clean build./ds4_test --all— only pre-existing--logprob-vectors short_code_completionfailure (also onupstream/main)--mtpcombined fires every iter (verified viaDS4_MTP_TIMING=1)DS4_MTP_STRICT=1): byte-equal to PR cuda: DS4_CUDA_STRICT_BATCHED — bit-equal batched-N infrastructure (stacked on #4) #5 strict baseline (combined gate refuses, falls back to decode2_exact)--mtp): byte-equal to PR cuda: DS4_CUDA_STRICT_BATCHED — bit-equal batched-N infrastructure (stacked on #4) #5make cuda-regression— pre-existing build error intests/cuda_long_context_smoke.c(also onupstream/main), unchangedDeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.ggufDeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.ggufspeed-bench/gb10.csvrefreshed via canonical ds4-bench sweep with PR6 tip numbersOut of scope / follow-ups
DS4_CUDA_STRICT_BATCHED=1+ combined-K=1 strict be byte-equal AND fast)DS4_MTP_COMBINED_K2=1— currently loses to K=1 due to drafts[1] staleness)Note on PR #5 history
PR #5 was force-pushed earlier today to drop a 7th commit (
48d2666 cuda: gate Q8 share-warp kernel under DS4_CUDA_STRICT_BATCHED) that was identified as a correctness regression during PR6 investigation. The 6 remaining gates in PR #5 are orthogonal to this PR's correctness story; combined-forward here runs over the existing default kernels.