mtp: combined-forward default + Option-B strict fallback (stacked on #5) #6

Closed
TrevorS wants to merge 7 commits into mtp-beats-plain-kernels-v4 from mtp-beats-plain-kernels-v5

@TrevorS
Owner

@TrevorS TrevorS commented May 24, 2026

PR6: mtp: combined-forward default + Option-B strict fallback (stacked on #5)

Summary

Makes combined-forward MTP the default speculative decode path. Folds eval(first_token) + verify into a single batched-N=2 forward, saving one decode worth of compute per spec iter. Faithful port of the same path on the downstream mtp-beats-plain branch; provides architectural symmetry with how upstream already handles MTP-related dispatch via DS4_MTP_BATCH_VERIFY.

Strict mode (DS4_MTP_STRICT=1 or --quality) falls back internally to canonical decode2_exact, preserving byte-equality to plain decode.

No new env knobs. DS4_MTP_STRICT is pre-existing. The previous opt-in flag DS4_MTP_COMBINED_FORWARD is removed entirely. Combined becomes the default code path; strict-mode fallback is internal logic.
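
The gating described above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the code in ds4.c: the enum, the function names, and passing `--quality` as a boolean are hypothetical, and the rule that any value other than empty or "0" enables strict mode is a guess at how `DS4_MTP_STRICT` is parsed. Only the variable name `DS4_MTP_STRICT` comes from the PR.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical names for illustration; not taken from ds4.c. */
typedef enum {
    SPEC_PATH_CANONICAL = 0, /* decode2_exact-style path, byte-equal to plain decode */
    SPEC_PATH_COMBINED  = 1  /* batched N=2 forward over first_token + drafts[0] */
} spec_path;

/* Assumption: strict mode is on when the value is set to anything but "" or "0". */
bool strict_from_env(const char *value) {
    return value != NULL && value[0] != '\0' && strcmp(value, "0") != 0;
}

/* Strict callers get the canonical path; everyone else gets combined-forward. */
spec_path choose_spec_path(bool strict_env, bool quality_flag) {
    if (strict_env || quality_flag) return SPEC_PATH_CANONICAL;
    return SPEC_PATH_COMBINED;
}

/* Convenience wrapper that reads the pre-existing knob from the process environment. */
spec_path choose_spec_path_from_process_env(bool quality_flag) {
    return choose_spec_path(strict_from_env(getenv("DS4_MTP_STRICT")), quality_flag);
}
```

The point of the shape is that no new knob appears: the only input besides `--quality` is the pre-existing strict variable, and the combined path is what you get when neither is set.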

Commit stack

  1. d1c7078 — mtp: extract metal_graph_eval_mtp_draft_n_from_hc — batched MTP-draft primitive
  2. 0660011 — mtp: prefix-2 capture + combined N=2 forward
  3. 138d93a — mtp: bootstrap combined_prev_hc from canonical eval
  4. e004169 — mtp: use capture_prefix1 even in strict mode for K=2 (skip full snapshot)
  5. a5141a3 — mtp: extend combined-forward to K=2 + cache spec_row_logits buffer
  6. 33bb699 — mtp: default combined-forward on (Option-B strict-mode falls back to canonical) — the gate change
  7. bb85073 — speed-bench: refresh gb10.csv

Honest performance numbers (DGX Spark / GB10, n=256, "knight" prompt)

| Mode | PR #5 baseline (t/s) | This PR (t/s) | Δ |
|---|---|---|---|
| Plain decode (no --mtp) | 16.60 | 16.60 | unchanged |
| Default --mtp | 15.42 | 15.56 | +0.14 t/s |
| DS4_MTP_STRICT=1 --mtp | 13.84 | 13.83 | unchanged (strict falls back to canonical) |

Reference: the downstream origin/mtp-beats-plain branch (which has the same combined-forward path) benches at 16.11 t/s for default --mtp on this hardware — this PR is parity with the source.

Important caveats:

  • The marquee win is the architectural simplification (one default path for spec decode, strict-mode internal fallback), not the t/s delta. The +0.14 t/s is barely above bench noise.
  • Default combined --mtp is currently slightly slower than plain decode (15.56 vs 16.60). The combined-forward design saves compute per iter but the batched-N forward has overhead that, on this hardware/kernel-mix, doesn't fully amortize against plain decode's simpler single-token path.
  • An earlier session anecdotally measured combined at ~19.58 t/s on a different stack state; that number is not currently achievable on this branch tip. Closing the gap would require fixing the residual batched-N kernel divergences (the work that PR #5 began, and that PR #6's investigation showed is deeper than initial estimates).

What this PR is and isn't

  • Is: a clean architectural change that makes combined-forward the default, with strict-mode preserved as a byte-equal fallback. Removes the opt-in env knob. Faithful port of the downstream path.
  • Is: byte-equal to upstream/main for plain decode and for strict-mode MTP.
  • Is NOT: a major perf win. The +0.14 t/s improvement is within bench noise.
  • Is NOT: a path to "MTP > Plain" — default combined remains below plain on this kernel stack

AGENT.md compliance

  • "Preserve correctness before speed": the drift between combined-forward and canonical (in non-strict mode only) is explained. At small batched N, MoE/attention paths differ in FMA accumulation order from N=1. Documented in PR #5's strict-batched infrastructure description.
  • "Do not add permanent semantic variants behind flags": no new flag added. The opt-in DS4_MTP_COMBINED_FORWARD knob is removed. Default behavior changes (combined becomes the path); strict-mode fallback is internal logic gated on the pre-existing DS4_MTP_STRICT.
  • Default behavior change is disclosed explicitly above: combined produces subtly different (still coherent) output than canonical in non-strict mode.
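
Why a change in accumulation order alone can move bits is easy to show in isolation. The sketch below is not the MoE or attention kernel, and it demonstrates plain floating-point non-associativity rather than FMA contraction specifically; the function names are made up for the example. It assumes IEEE-754 single precision and a build without fast-math style reassociation.

```c
#include <assert.h>
#include <stddef.h>

/* Sum in index order: ((v[0] + v[1]) + v[2]) + ... */
float sum_forward(const float *v, size_t n) {
    assert(v != NULL || n == 0);
    volatile float acc = 0.0f; /* volatile: round to float after every add */
    for (size_t i = 0; i < n; i++) acc = acc + v[i];
    return acc;
}

/* Same values, opposite order: ((v[n-1] + v[n-2]) + ...) + v[0] */
float sum_backward(const float *v, size_t n) {
    assert(v != NULL || n == 0);
    volatile float acc = 0.0f;
    for (size_t i = n; i > 0; i--) acc = acc + v[i - 1];
    return acc;
}
```

With the input `{1e8f, -1e8f, 1.0f}`, the forward sum cancels the two large terms first and keeps the 1.0f, while the backward sum absorbs the 1.0f into -1e8f (float spacing near 1e8 is 8) and ends at zero. A batched kernel that reduces in a different order than the N=1 kernel is exposed to the same effect, just at much smaller magnitudes.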

Tested against

  • make clean && make cuda-spark — clean, no warnings
  • make cpu — clean build
  • ./ds4_test --all — only pre-existing --logprob-vectors short_code_completion failure (also on upstream/main)
  • Default non-strict --mtp combined fires every iter (verified via DS4_MTP_TIMING=1)
  • 3-run stability: byte-identical coherent output across runs
  • Strict mode (DS4_MTP_STRICT=1): byte-equal to PR #5 strict baseline (combined gate refuses, falls back to decode2_exact)
  • Plain decode (no --mtp): byte-equal to PR #5
  • make cuda-regression — pre-existing build error in tests/cuda_long_context_smoke.c (also on upstream/main), unchanged
  • Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
  • Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
  • MTP: DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf
  • speed-bench/gb10.csv refreshed via canonical ds4-bench sweep with PR6 tip numbers

Out of scope / follow-ups

  • Closing residual batched-attention divergence (would let DS4_CUDA_STRICT_BATCHED=1 + combined-K=1 strict be byte-equal AND fast)
  • Captured-graph spec decode subsystem
  • K=2 combined cascading (opt-in via DS4_MTP_COMBINED_K2=1 — currently loses to K=1 due to drafts[1] staleness)

Note on PR #5 history

PR #5 was force-pushed earlier today to drop a 7th commit (48d2666 cuda: gate Q8 share-warp kernel under DS4_CUDA_STRICT_BATCHED) that was identified as a correctness regression during PR6 investigation. The 6 remaining gates in PR #5 are orthogonal to this PR's correctness story; combined-forward here runs over the existing default kernels.

TrevorS added a commit that referenced this pull request May 24, 2026
…rp dispatch

PR #3's K4 commit (584de5e on the v2 branch) added a cuBLAS-cache-
availability gate to the share-warp Q8 matmul dispatch in
cuda_matmul_q8_0_tensor_labeled. The intent was correctness: prevent
share-warp from displacing cuBLAS where cuBLAS would have handled the
weight. On DGX Spark this gate is empirically always-false: every Q8
weight has a cached F16 copy by the time the dispatcher runs, so
share-warp NEVER fires at n_tok=2..4 with blocks<=32 -- which means
the Q_B and output_proj_b matmuls (the only V4-Flash matmuls with
blocks=32) silently route through cuBLAS' small-M tensor-core path
which pads M=2..4 to M=16 and wastes ~7/8 of the inner-product work.
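
The ~7/8 figure can be checked with a small calculation. This sketch assumes the wasted work is simply proportional to the padded rows that carry no real token; the function name is hypothetical and nothing here calls cuBLAS.

```c
#include <assert.h>

/* Fraction of a padded tile's rows that carry no real tokens. */
double padded_waste_fraction(int m, int tile_m) {
    assert(m > 0 && tile_m > 0);
    int padded = ((m + tile_m - 1) / tile_m) * tile_m; /* round m up to a tile multiple */
    return (double)(padded - m) / (double)padded;
}
```

Under that assumption M=2 padded to 16 wastes 14/16 = 7/8 of the rows, and M=4 wastes 12/16 = 3/4, so the quoted ~7/8 is the worst case in the 2..4 range.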

The actual correctness concern is narrower: under DS4_MTP_STRICT (or
--quality), users require byte-equality with plain decode. Share-warp
is not bit-identical to cuBLAS Gemm at small M (different reduction
order), so strict-mode must fall through to cuBLAS. In non-strict
mode this drift is acceptable -- it matches PR #6's combined-forward
Option-B pattern (same env knob selects strict vs perf).

Replace the cuBLAS-cache-availability check with `!strict_mtp_env`.
Same opt-out shape as the combined-forward gate in
ds4_session_eval_speculative_argmax. The `blocks <= 32u` constraint
is preserved (share-warp is bit-equal to N=1 warp8 only at blocks<=32;
larger block counts drift from the batch_warp8 reference and would
fail ds4_test --all long-context tensor equivalence -- verified
empirically during bisect).
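
A rough sketch of the resulting dispatch condition, for illustration only. The predicate name is hypothetical, and the `n_tok` range of 2..4 is inferred from the message above rather than copied from cuda_matmul_q8_0_tensor_labeled; only `strict_mtp_env` and the `blocks <= 32u` bound are named in the commit.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical predicate: true when the share-warp Q8 kernel would be chosen. */
bool use_share_warp_q8(bool strict_mtp_env, uint32_t n_tok, uint32_t blocks) {
    if (strict_mtp_env) return false;           /* strict: fall through to cuBLAS */
    if (n_tok < 2u || n_tok > 4u) return false; /* assumed small-batch window */
    return blocks <= 32u;                       /* bound kept from the original gate */
}
```

Compared with the old cache-availability check, the only input that can turn the kernel off for an eligible shape is the strict flag, which is the same opt-out shape as the combined-forward gate.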

Bench impact (DGX Spark, ds4flash, n=256, "knight" prompt):
  Default `--mtp` (combined-forward fires)   15.6 -> 16.20 t/s  (+3.8%)
  `DS4_MTP_STRICT=1 --mtp` (canonical)       13.83 -> 13.83     unchanged
  Plain decode                               16.60 -> 16.60     unchanged

Strict-mode byte-equality vs PR #6 baseline confirmed (diff empty).
`./ds4_test --all` shows the same 1 pre-existing failure as PR #6
(`logprob-vectors short_code_completion`, also fails on upstream/main).
TrevorS added 7 commits May 24, 2026 10:03
Adds Phase 3 (prefix-2 compressor-state capture) as load-bearing
infrastructure for cheap K>=2 partial-accept rollback, plus a working
Phase 1 combined N=2 forward path (opt-in via DS4_MTP_COMBINED_FORWARD)
that folds eval(first_token) into a single batched 2-row forward.
Default behavior unchanged.

Phase 3: prefix-2 capture (load-bearing)
  - 4 new ds4_gpu_graph fields: spec_prefix2_attn_state_kv/score[L],
    spec_prefix2_index_state_kv/score[L], counters spec_prefix2_n_comp/
    n_index_comp[L], bool spec_capture_prefix2.
  - Alloc/free mirror prefix-1 alongside existing init/cleanup loops.
  - metal_graph_capture_prefix2_{attn,index}_state functions; mirror
    prefix-1 and trigger at t==1 in encode_layer_batch.
  - spec_frontier_commit_prefix2 cheap commit (~0.5-1ms per Spark
    numbers vs ~7-12ms for full snapshot+restore).
  - metal_graph_verify_suffix_tops grows capture_prefix2 param;
    existing callers pass false.
  - spec_argmax_snapshot_combined_prev_hc helper + graph field
    combined_prev_hc + session field combined_prev_hc_valid.  Canonical
    full-accept and 1-accept(prefix1) paths now snapshot the post-commit
    HC row so a future combined iter can derive drafts[0] correctly.

Phase 1: combined N=2 forward (opt-in, off by default)
  - ds4_session_eval_speculative_argmax_combined helper: forwards
    [first_token, drafts[0]] in a single batched 2-row pass.  Effectively
    K=1 spec with first_token folded into the verifier call (saving
    eval's separate raw_swa forward).  Gated by DS4_MTP_COMBINED_FORWARD
    env; engages only when prior iter's combined_prev_hc is valid.
  - drafts[0] derivation: MTP-block(prev_hc=combined_prev_hc,
    token=first_token, pos=base_pos) produces the correct prediction
    for position base_pos+1 given that combined_prev_hc holds
    post-position-(base_pos-1) HC (= post-last-committed-token HC from
    prior iter).
  - Cold start (no combined_prev_hc_valid) falls back to canonical so
    the first iter primes the snapshot for the next combined iter.
  - Partial-accept rollback via spec_frontier_commit_prefix1 for
    commit=0 (drafts[0] rejected, first_token only committed).
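
The accept/reject decision for one combined N=2 iteration can be sketched as below. This is a simplified model of greedy verification, not the helper in ds4.c: the struct, the function name, and the choice of which verifier row seeds the next iteration are assumptions made for the example.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical result record for one combined iteration. */
typedef struct {
    int commit;     /* accepted drafts: 0 or 1 */
    int tokens_out; /* tokens emitted this iteration */
    int next_token; /* verifier token assumed to seed the next iteration */
} combined_result;

/* row_tops[0]: target argmax after first_token.
 * row_tops[1]: target argmax after drafts[0]. */
combined_result combined_n2_accept(int draft0, const int row_tops[2]) {
    assert(row_tops != NULL);
    combined_result r;
    if (row_tops[0] == draft0) {
        r.commit = 1;        /* drafts[0] accepted */
        r.tokens_out = 2;    /* first_token + drafts[0] */
        r.next_token = row_tops[1];
    } else {
        r.commit = 0;        /* drafts[0] rejected, roll back to prefix-1 */
        r.tokens_out = 1;    /* first_token only */
        r.next_token = row_tops[0];
    }
    return r;
}
```

In the commit=0 case the second row of the batched forward has already been written into the caches, which is why the cheap prefix-1 rollback is needed there.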

Path tried and reverted: N=3 combined with chained drafts[1]
  An earlier draft used N=3 verifier on [first_token, drafts[0..1]]
  and chained MTP for drafts[1].  Bench showed drafts[1] systematically
  rejected (100% miss) because the chained MTP call inherits a
  mtp_state_hc from drafts[0]'s MTP-block, which used prev_hc =
  combined_prev_hc (= post-(p-1) HC).  The MTP layer's "informed"
  prediction semantic expects prev_hc = post-p HC (= post-first_token
  HC, which canonical obtains via raw_swa-forward of first_token).
  Cascade: drafts[1] uses a slightly-off baseline → diverges from
  target's row_tops[1].  Cannot be fixed on mainline without either
  (a) forwarding first_token first (= the eval cost we are trying to
  skip), or (b) fused main-forward + MTP-block kernels (out of scope,
  not present in Spark either).  N=2 mode skips drafts[1] and pays
  only for the necessary 2-row verifier.

Bench (DGX Spark, ds4flash.gguf + MTP, --temp 0 --mtp-draft 2):
  Plain (no MTP):                  16.78 t/s  — no regression.
  Canonical K=2 full-accept iter:  ~124ms / 3 tokens = ~24 t/s peak.
  Canonical K=2 margin-skip iter:  ~68ms  / 2 tokens = ~29 t/s peak.
  Combined N=2 commit=1 iter:      ~125ms / 2 tokens = ~16 t/s.
  Combined N=2 commit=0 iter:      ~125ms / 1 token  = ~8 t/s.
  Combined acceptance rate ~83% on test prompt; average ~14.7 t/s.
  Output byte-equal to canonical for 16-token sweep.
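
The ~14.7 t/s average follows from the per-iteration numbers above. A back-of-the-envelope model, assuming every combined iteration costs the same wall time regardless of outcome (the function name is hypothetical):

```c
#include <assert.h>

/* Expected tokens per second for combined N=2: one guaranteed token
 * (first_token) plus drafts[0] with probability accept_rate. */
double combined_n2_expected_tps(double accept_rate, double iter_ms) {
    assert(accept_rate >= 0.0 && accept_rate <= 1.0 && iter_ms > 0.0);
    double tokens_per_iter = 1.0 + accept_rate;
    return tokens_per_iter * 1000.0 / iter_ms;
}
```

With an 83% accept rate and ~125 ms per iteration this gives roughly 14.6 t/s, and the two extremes (always accept, never accept) reproduce the ~16 t/s and ~8 t/s rows. Even a perfect accept rate at this iteration cost would not pass plain decode.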

Architectural note on K>=2 combined ambition:
  Reaching canonical's full-accept 3-tokens-per-iter ceiling in the
  combined path requires fixing the drafts[1+] cascade, which requires
  kernel-level work to either (i) make batched-N=1 byte-equal to
  raw_swa (Spark's unresolved dense_step_n drift territory) or (ii)
  fuse MTP-block into the main forward kernel.  Phase 3 infrastructure
  (prefix-2 capture, combined_prev_hc snapshot) is reusable for those
  follow-ups.

NO github push.  4-commit jj stack on top of main.

Adds spec_argmax_bootstrap_combined_prev_hc which copies the post-final-
layer HC of the just-evaluated token (g->cur_hc) into combined_prev_hc
after every ds4_session_eval inside spec_argmax.  Cost is one device-to-
device tensor copy of N_HC * N_EMBD * 4 bytes = 112 KiB; negligible in
the 200ms-per-iter speculative path.

Why
---
The combined-N=K speculative path (DS4_MTP_COMBINED_FORWARD) returns -2
on cold start when !s->combined_prev_hc_valid.  Without this bootstrap
combined_prev_hc is only populated by combined itself (line 18223), so
the cold-start bailout is permanent: spec_argmax_combined falls through
to canonical every iter, and canonical never sets combined_prev_hc, so
combined is structurally dead.

After this commit any canonical iter populates combined_prev_hc with the
post-final-layer HC of first_token, which is the correct prev_hc for
the NEXT iter's MTP-block call:
  drafts[0] = mtp_block(prev_hc=post-first_token-HC, token=NEW_first_token)

This is what spark already does in its forward_mtp_step path.
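
The "structurally dead" argument can be reproduced with a toy state machine. Everything here is a simplified model: the struct and function are hypothetical, and the only thing borrowed from the commit is the `combined_prev_hc_valid` flag and the rule that combined bails out when it is false.

```c
#include <assert.h>
#include <stdbool.h>

/* Toy session: just the validity flag for the saved HC row. */
typedef struct { bool combined_prev_hc_valid; } toy_session;

/* Returns how many of n_iters take the combined path. */
int count_combined_iters(int n_iters, bool bootstrap_from_canonical) {
    assert(n_iters >= 0);
    toy_session s = { .combined_prev_hc_valid = false };
    int combined = 0;
    for (int i = 0; i < n_iters; i++) {
        if (s.combined_prev_hc_valid) {
            combined++;                      /* combined iter refreshes its own snapshot */
            s.combined_prev_hc_valid = true;
        } else if (bootstrap_from_canonical) {
            s.combined_prev_hc_valid = true; /* canonical iter primes the snapshot */
        }
        /* without the bootstrap, a canonical iter leaves the flag false */
    }
    return combined;
}
```

Without the bootstrap the count stays at zero for any number of iterations; with it, only the first iteration is canonical and every later one can be combined.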

Effect
------
- Canonical K=2 strict accept rate unchanged (83.9%).
- Bench (n=256, 3-run avg): canonical K=2 strict 7.59 t/s, combined K=1
  effective 8.06 t/s.  Combined no longer cold-starts every iter.
- Combined-N=2 path still does NOT beat plain decode (15.83 t/s) -- it
  costs ~130ms for batched-N=2 forward vs ~65ms for single decode, and
  K=1 effective limits accepted tokens.  Phase 4 batched-kernel speedup
  work is the next layer.

LOC
---
ds4.c: +24/-1 (one new helper + one call site).  No header changes.
DS4_MTP_COMBINED_FORWARD is still opt-in.

NO github push.  jj change oxqszsyw.

…hot)

Before this change, ds4_session_eval_speculative_argmax forced
spec_frontier_snapshot in DS4_MTP_STRICT mode at K=2.  That snapshot
copies per-layer compressor state (raw + indexer KV/score tensors for
the 21 layers that run compression) which totals ~6 MiB of GPU memory
per spec iteration.  At 39 iters per 128 generated tokens that's
~240 MiB of device-memory copying purely for rollback insurance.

The capture_prefix1 path (already used by default in non-strict mode
and by combined-forward) does the same rollback work via cheap per-layer
counter resets in spec_frontier_commit_prefix1, with no GPU tensor
copies.  It is byte-correct on accept and on partial-accept because
the compressor caches are append-only -- rejected rows become invisible
once the counters rewind.
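
The counter-rewind idea can be illustrated with a toy append-only buffer. This is a sketch of the general technique under stated assumptions, not the compressor cache in ds4.c: the struct, the capacity, and the function names are all hypothetical.

```c
#include <assert.h>
#include <stddef.h>

#define TOY_CAP 8

/* Append-only store: rows past n_comp are never read. */
typedef struct {
    float  rows[TOY_CAP]; /* stand-in for compressed KV/score rows */
    size_t n_comp;        /* number of rows currently visible */
} toy_cache;

/* Append a row and return the new visible count (usable as a saved prefix). */
size_t toy_append(toy_cache *c, float row) {
    assert(c != NULL && c->n_comp < TOY_CAP);
    c->rows[c->n_comp] = row;
    return ++c->n_comp;
}

/* Roll back by rewinding the counter; no row data is copied or cleared. */
void toy_rewind(toy_cache *c, size_t saved_n_comp) {
    assert(c != NULL && saved_n_comp <= c->n_comp);
    c->n_comp = saved_n_comp;
}
```

Saving the counter after the first row of a batch and rewinding to it on reject makes the rejected row invisible, and the next append simply overwrites it. That is why no tensor snapshot is needed as long as readers respect the counter.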

This commit removes the strict-mode special case and uses
capture_prefix1 unconditionally at K=2.  DS4_MTP_FORCE_SNAPSHOT still
forces the old full-snapshot path for measurement / regression
checking; DS4_MTP_CAPTURE_PREFIX1 becomes a no-op (the default).

Bench
-----
DGX Spark (GB10), ds4flash.gguf + MTP-Q4K, n=128, --mtp-draft 2,
DS4_MTP_STRICT=1, "knight" prompt, 3-run avg:

  Pre  (force snapshot): 5.20 t/s gen, 80.8% accept
  Post (capture_prefix1): 5.49 t/s gen, 80.8% accept
  Delta: +5.6%

Accept rate is identical (80.8% in both modes) -- the rollback
semantics are equivalent at K=2.  The 5.6% win comes purely from
eliminating the per-iter device-memory snapshot copy.

Non-strict K=2 (margin=3 default): unchanged -- capture_prefix1 was
already true there.  Plain decode: 16.0 t/s, unchanged.

Tests: ds4_test passes (long-context, tool-call-quality, metal-kernels,
server).  Pre-existing logprob-vectors failure unchanged.

What this doesn't fix
---------------------
Even after this fix, mainline MTP K=2 strict mode is ~5.5 t/s vs spark's
~15.9 t/s -- still 3x slower than plain decode.  The remaining gap is
not in snapshot/restore (now removed) nor in the verifier compute path
(scout audited: kernels are tight).  Most likely candidates for the
residual gap:
  - Per-iter setup overhead (token_vec_push, checkpoint mgmt) that
    accumulates across many short spec iters
  - The `ds4_session_eval(first_token)` full-decode call at line 18312
    happens BEFORE the spec phase, paying full plain-decode cost per
    iter just to advance state with the verifier's known-good token
  - Spark batches first_token into the K-batched forward (combined
    mode), avoiding the separate eval call; mainline's combined mode
    is opt-in via DS4_MTP_COMBINED_FORWARD and has its own issues

LOC
---
ds4.c: +4/-3 (one boolean simplification + clarifying comment).

NO github push.  jj change kmssvuql -> zuxwytmo.

Two related changes packed together.

1. Combined-forward K=2 wiring
-----------------------------
`spec_argmax_combined` now also handles draft_cap=2 (N=3 batched verify
over [first_token, drafts[0], drafts[1]]) with prefix-2 commit dispatch
for the commit ∈ {0, 1, 2} cases.

Gated behind DS4_MTP_COMBINED_K2=1 because measurement shows the K=2
variant is currently a loss on mainline:

  Combined K=1 (N=2 batched): 9.51 t/s (no flag, /dev/null)
  Combined K=2 (N=3 batched): 7.34 t/s (DS4_MTP_COMBINED_K2=1)

Why K=2 loses: `drafts[1]` cascades from `drafts[0]`'s MTP-state, but
`drafts[0]` itself comes from `combined_prev_hc` (= post-previous-iter-
last-token HC), not from the fresh post-`first_token` main-HC the
canonical eval(first_token) would produce.  So `drafts[0]` is "one
position stale" already, and `drafts[1]` cascades further off-target.
The target verifier rejects `drafts[1]` in the vast majority of iters,
so the extra batched-N=3 row costs more than it pays.

Keeping the K=2 path as opt-in because the prefix-2 wiring is correct
and reusable when the staleness fix lands (interleaved MTP-block inside
batched main forward, spark-style).  See PHASE4.md item #1.

2. Session-cached spec_row_logits buffer
----------------------------------------
Adds `s->spec_row_logits_buf` (3 * VOCAB f32 = ~1.5 MiB) and
`s->spec_row_tops_buf` (3 * int) allocated at session creation,
replacing the per-spec-call xmalloc/free pattern in
`ds4_session_eval_speculative_argmax_combined`.
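
The allocation pattern described here, one buffer sized at session creation and reused by every speculative call, might look like the sketch below. The `toy_` types and functions are hypothetical; only the two buffer field names and the 3-row sizing come from the commit message.

```c
#include <assert.h>
#include <stdlib.h>

#define TOY_MAX_ROWS 3 /* first_token + up to two drafts */

/* Hypothetical session holding buffers that outlive individual spec calls. */
typedef struct {
    float *spec_row_logits_buf;             /* TOY_MAX_ROWS * vocab floats */
    int    spec_row_tops_buf[TOY_MAX_ROWS]; /* argmax per verifier row */
    size_t vocab;
} toy_spec_session;

/* Allocate once at session creation instead of once per spec call. */
toy_spec_session *toy_spec_session_create(size_t vocab) {
    assert(vocab > 0);
    toy_spec_session *s = calloc(1, sizeof *s);
    if (s == NULL) return NULL;
    s->spec_row_logits_buf = malloc(TOY_MAX_ROWS * vocab * sizeof(float));
    if (s->spec_row_logits_buf == NULL) { free(s); return NULL; }
    s->vocab = vocab;
    return s;
}

/* Pointer to the logits of one verifier row inside the shared buffer. */
float *toy_spec_row_logits(toy_spec_session *s, size_t row) {
    assert(s != NULL && row < TOY_MAX_ROWS);
    return s->spec_row_logits_buf + row * s->vocab;
}

void toy_spec_session_free(toy_spec_session *s) {
    if (s == NULL) return;
    free(s->spec_row_logits_buf);
    free(s);
}
```

The per-call path then only indexes into the buffer, so no allocator call sits between the forward pass and the argmax comparison.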

Measurement impact: small (~0-3% in noise).  The malloc overhead
hypothesis was a wrong guess at what was producing the 73 ms per-call
overhead between component-timed (~95 ms) and observed wall
(~168 ms) combined cost.  Documented in PHASE4.md item #3 -- the
actual source of that overhead is still unidentified after this
attempt.

Effect on default combined K=1: 9.51 -> 9.48 t/s (within noise).
Foundation for future xmalloc cleanup in the canonical path's
decode2_exact branch (still allocates per-call).

LOC
---
ds4.c: +67/-35 (combined K=2 dispatch + session buf fields + alloc/free
sites + caller changes).  Two new session fields, two new env gates.

NO github push.  jj change oxmoztuq.

…canonical)

Replaces the env-opt-in (DS4_MTP_COMBINED_FORWARD) gate with a default-on
gate for non-strict speculative decode.  Strict mode (DS4_MTP_STRICT=1 or
e->quality) now falls back to the canonical decode2_exact path so byte
equality with plain decode is preserved.

Batched-N MoE / attention is not yet bit-identical to N=1 raw_swa
(would need DS4_CUDA_STRICT_BATCHED + closing residual attention
divergence), so combined-forward can drift one token from canonical.
Output remains coherent; strict callers who want canonical bytes opt in
via DS4_MTP_STRICT=1.

No new env flag introduced: this just defaults the existing path on for
non-strict and routes strict to the existing canonical path.

ds4-bench measures plain (non-MTP) decode so numbers are within noise of
upstream — combined-forward only fires under --mtp.  Refreshed for
provenance under the new tip.
TrevorS force-pushed the mtp-beats-plain-kernels-v4 branch from ae95ce7 to 4ef4054 May 24, 2026 17:14
TrevorS force-pushed the mtp-beats-plain-kernels-v5 branch from bb85073 to dcb5cc3 May 24, 2026 17:14
@TrevorS
Owner Author

TrevorS commented May 24, 2026

Superseded by the reframed 2-PR stack (#11 + #12), which tells the same Spark/GB10 + MTP combined-forward story more concisely, rebased on current upstream/main, with the exploratory paths dropped.

@TrevorS TrevorS closed this May 24, 2026
@TrevorS TrevorS deleted the mtp-beats-plain-kernels-v5 branch May 24, 2026 22:43