
cli: default mtp_draft_tokens=2; cuda: post-stack cleanup (stacked on #9) #10

Closed
TrevorS wants to merge 1 commit into mtp-beats-plain-kernels-v8 from mtp-beats-plain-kernels-v9


TrevorS (Owner) commented May 24, 2026

PR10: cli: default mtp_draft_tokens=2; cuda: post-stack cleanup pass (stacked on #9)

Summary

Capstone for the mtp-beats-plain-kernels stack. Three changes:

  1. ds4_cli.c: default mtp_draft_tokens flips from 1 → 2. With the share-warp Q8 (PR8) and F16 (PR9) kernels landed, combined-forward MTP K=1 (mtp_draft_tokens=2) delivers +1.0 t/s above plain decode on DGX Spark. Previously users had to pass --mtp-draft 2 explicitly to get the win; the default of 1 routed --mtp through canonical decode2_exact and produced no measurable win over plain.

  2. ds4_cuda.cu: cleanup pass — strip fork-internal "PR3"/"PR5"/"PR7"/"PR8"/"PR9" comment prefixes that reference this stack's own PR numbering. Tighten the share-warp dispatcher comments to reflect the settled final shape (was scoped wider in intermediate revisions). Net -18 LOC with no behavior change to the kernels.

  3. speed-bench/gb10.csv: refreshed with PR-stack-tip plain-decode numbers from the standard CONTRIBUTING.md sweep.
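
For reviewers who want the default flip in code form, here is a minimal sketch of how a CLI could carry it. The `cli_opts` struct and the `cli_opts_default` / `cli_parse` helpers are hypothetical names for illustration only; the actual option handling in ds4_cli.c may be structured differently. Only the `--mtp` and `--mtp-draft` flags and the new default of 2 come from this PR.

```c
#include <assert.h>
#include <limits.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical options struct; the real layout in ds4_cli.c may differ. */
typedef struct {
    int mtp_enabled;       /* set by --mtp */
    int mtp_draft_tokens;  /* set by --mtp-draft N */
} cli_opts;

/* Default described in this PR: 2 draft tokens unless overridden. */
static cli_opts cli_opts_default(void) {
    cli_opts o = { .mtp_enabled = 0, .mtp_draft_tokens = 2 };
    return o;
}

/* Parse only the two flags discussed here. Returns 0 on success and -1 when
 * --mtp-draft has a missing, non-numeric, or non-positive value.
 * Any other argument is ignored in this sketch. */
static int cli_parse(int argc, char **argv, cli_opts *o) {
    assert(o != NULL);
    for (int i = 1; i < argc; i++) {
        if (strcmp(argv[i], "--mtp") == 0) {
            o->mtp_enabled = 1;
        } else if (strcmp(argv[i], "--mtp-draft") == 0) {
            if (i + 1 >= argc) return -1;
            const char *val = argv[++i];
            char *end = NULL;
            long v = strtol(val, &end, 10);
            if (end == val || *end != '\0' || v < 1 || v > INT_MAX) return -1;
            o->mtp_draft_tokens = (int)v;
        }
    }
    return 0;
}
```

Under this sketch, a bare `--mtp` leaves `mtp_draft_tokens` at 2, while `--mtp --mtp-draft 1` still selects the previous value explicitly.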

Bench headline (DGX Spark / GB10, ds4flash, n=256, "knight" prompt, 5-run mean)

| Mode | Before stack (upstream-minimal), t/s | PR10, t/s | Δ, t/s |
|---|---|---|---|
| Plain decode | 13.9 | 16.13 | +2.2 |
| --mtp (default) | 8.8 | 17.14 | +8.3 |
| --mtp vs plain on PR10 | | | +1.01 (MTP > Plain) |

All 5 default `--mtp` runs landed within 0.03 t/s: 17.16 / 17.15 / 17.13 / 17.15 / 17.13.
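
The 5-run mean and the 0.03 t/s spread can be rechecked from the five values listed above. This is a small arithmetic sketch, not part of ds4-bench; the helper names `mean` and `spread` are made up for the example.

```c
#include <assert.h>
#include <stddef.h>

/* The five default --mtp runs quoted in the PR description, in t/s. */
static const double mtp_runs[5] = { 17.16, 17.15, 17.13, 17.15, 17.13 };

/* Arithmetic mean of n samples. */
static double mean(const double *x, size_t n) {
    assert(x != NULL && n > 0);
    double sum = 0.0;
    for (size_t i = 0; i < n; i++) sum += x[i];
    return sum / (double)n;
}

/* Spread, defined here as max minus min. */
static double spread(const double *x, size_t n) {
    assert(x != NULL && n > 0);
    double lo = x[0], hi = x[0];
    for (size_t i = 1; i < n; i++) {
        if (x[i] < lo) lo = x[i];
        if (x[i] > hi) hi = x[i];
    }
    return hi - lo;
}
```

`mean(mtp_runs, 5)` comes out at about 17.144, which rounds to the 17.14 t/s in the headline table, and `spread(mtp_runs, 5)` is about 0.03.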

Other modes (still supported):

| Mode | t/s |
|---|---|
| --mtp --mtp-draft 1 (forces canonical decode2) | 16.18 |
| --mtp --quality (strict canonical) | 12.92 |

CONTRIBUTING.md test sweep

make clean && make cuda-spark — clean, no warnings
make cpu — clean
~make test (= ./ds4_test --all) — 1 pre-existing failure
╰─ --logprob-vectors short_code_completion (same on upstream/main, PR5-9; test fixture's official continuation is one greedy token off; well-documented across this stack)
~make cuda-regression — pre-existing build error
╰─ tests/cuda_long_context_smoke.c has a stale signature for ds4_gpu_attention_decode_heads_tensor (verified same on PR7 base and upstream/main; not introduced here)
./ds4-bench standard 2048..65536 sweep — written to speed-bench/gb10.csv
╰─ no regression vs prior baseline: every measured context size within ±0.05 t/s

What this stack delivers, end-to-end

upstream/main → +9 stacked PRs → PR10

Plain decode:  13.9 → 16.13 t/s   (+16% from HBM-resident model in PR1
                                  + small-N kernel polish in PR2-5)

Default --mtp: 8.8 → 17.14 t/s    (+95% from MTP infra + share-warp
                                  Q8 / F16 kernels in PR8 / PR9)

MTP > Plain:    -5.1 → +1.01 t/s  (MTP went from 37% slower than plain
                                  to 6% faster than plain)
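
The percentages in the summary above follow directly from the measured t/s values. A minimal sketch of the arithmetic, with `pct_change` as a hypothetical helper name:

```c
#include <assert.h>

/* Relative change from `before` to `after`, in percent. */
static double pct_change(double before, double after) {
    assert(before != 0.0);
    return 100.0 * (after - before) / before;
}
```

With the numbers from this PR: `pct_change(13.9, 16.13)` is about +16.0 (plain decode), `pct_change(8.8, 17.14)` is about +94.8 (default --mtp), `pct_change(13.9, 8.8)` is about -36.7 (MTP vs plain before the stack), and `pct_change(16.13, 17.14)` is about +6.3 (MTP vs plain on PR10).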

The 10-PR stack:

TrevorS/ds4
├── #1  upstream-minimal              — HBM-resident model
├── #2  mtp-beats-plain-kernels       — small-N kernel polish
├── #3  mtp-beats-plain-kernels-v2    — K4+K1 revived (FMA + cache gate)
├── #4  mtp-beats-plain-kernels-v3    — ACCEPT_REPORT + SPEC_DISABLE
├── #5  mtp-beats-plain-kernels-v4    — 6 strict-batched gates
├── #6  mtp-beats-plain-kernels-v5    — combined-forward default + Option-B strict gate
├── #7  mtp-beats-plain-kernels-v6    — K4 gate fix (cuBLAS-cache → strict-mode)
├── #8  mtp-beats-plain-kernels-v7    — share-warp Q8 bit-equal at any block count
├── #9  mtp-beats-plain-kernels-v8    — F16 share-warp at n_tok=2 → MTP > Plain ✓
└── #10 mtp-beats-plain-kernels-v9    — default --mtp-draft 2 + cleanup (this PR)

AGENT.md compliance

  • "Preserve correctness before speed" — plain decode bit-identical, strict (--quality / DS4_MTP_STRICT=1) preserved as byte-equal-to-plain canonical path.
  • "Do not add permanent semantic variants behind flags" — no new env knob. The --mtp-draft CLI flag is pre-existing; we change only its default value.
  • "Diagnostic switches are fine when they validate the one release path" — all DS4_CUDA_NO_* opt-outs preserved as kill switches.
  • "Comments instructive and compact: explain why a shape, ordering, cache boundary, or memory choice exists" — cleanup pass keeps the why and drops fork-internal PR labels.

Hardware

  • NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
  • Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
  • MTP: DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf

Out of scope / follow-ups

The scout from PR9's bench analysis (recorded as comments in PR9) identifies the remaining levers if the stack lands and there's appetite for more:

  • Captured-graph subsystem — verifier launches ~1,390 kernels/token; CUDA graphs would cut launch overhead
  • Q8 share-warp tuning at n_tok=2 — currently 24% of GPU time, avg 88 µs/call; bandwidth-bound matmul with room for register-tile rework
  • F16 share-warp at n_tok=3..4 with a combined-forward-context gate (enables K=2 cascading once drafts[1] staleness is fixed)
  • K=2 combined-forward drafts[1] staleness — currently rejects nearly always; fixing unlocks --mtp-draft 3

Makes the combined-forward MTP win available at zero CLI ceremony.
With the share-warp Q8 and F16 kernels landed earlier in this stack,
combined-forward K=1 (mtp_draft_tokens=2) at the verifier reaches
+1.0 t/s above plain decode on DGX Spark.  Previously users had to
pass `--mtp-draft 2` explicitly; the default of 1 routed all `--mtp`
sessions through the canonical decode2_exact path and delivered no
visible win over plain decode.

Default flips from 1 -> 2.  The path through `DS4_MTP_STRICT=1` (or
`--quality`) still falls back to canonical decode2_exact for users
who require byte-equality with plain decode.
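
As a reading aid, the routing described above can be sketched as a small selection function. The enum, the function names, and the exact precedence are assumptions made for illustration; the real dispatch in the ds4 sources may be organized differently. What comes from this PR: strict mode (`DS4_MTP_STRICT=1` or `--quality`) and `--mtp-draft 1` both use canonical decode2_exact, and `--mtp` with the new default of 2 uses combined-forward.

```c
#include <assert.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical labels for the three decode routes discussed in this PR. */
typedef enum {
    PATH_PLAIN,            /* no --mtp */
    PATH_DECODE2_EXACT,    /* canonical path, byte-equal to plain decode */
    PATH_COMBINED_FORWARD  /* combined-forward K=1 (mtp_draft_tokens=2) */
} decode_path;

/* Strict mode is requested by --quality or by DS4_MTP_STRICT=1. */
static int strict_requested(int quality_flag) {
    const char *s = getenv("DS4_MTP_STRICT");
    return quality_flag || (s != NULL && strcmp(s, "1") == 0);
}

/* Assumed precedence: plain first, then strict, then the draft count. */
static decode_path select_path(int mtp_enabled, int mtp_draft_tokens, int strict) {
    assert(mtp_draft_tokens >= 1);
    if (!mtp_enabled) return PATH_PLAIN;
    if (strict || mtp_draft_tokens == 1) return PATH_DECODE2_EXACT;
    return PATH_COMBINED_FORWARD;
}
```

In this sketch, `select_path(1, 2, 0)` yields `PATH_COMBINED_FORWARD`, which is what a bare `--mtp` gets after the default flip, while `select_path(1, 2, 1)` stays on `PATH_DECODE2_EXACT`.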

Cleanup pass on ds4_cuda.cu now that the share-warp design is settled:
  - Strip fork-internal "PR3/PR5/PR7/PR8/PR9" comment prefixes that
    referenced this stack's own PR numbering; they would be noise in
    the upstream history.  The behavioral comments themselves are
    preserved or tightened.
  - Drop the share-warp dispatcher's stale "no perf delta because
    combined-forward routes through F16" note: combined-forward at
    --mtp-draft 2 DOES route through Q8 share-warp 10k+ times per
    64-token gen (confirmed by nsys profile), so PR8's bit-equality
    rewrite is load-bearing for the win this PR makes default.
  - Tighten the F16 share-warp comment block to reflect the
    n_tok==2 final shape (was scoped to 2..4 in intermediate
    revisions).

Refresh `speed-bench/gb10.csv` with current PR-stack tip numbers
(plain decode only; ds4-bench doesn't drive --mtp).  Plain decode is
within +/-0.05 t/s of the prior baseline at every measured context
size from 2048 to 65536 -- no regression, slight improvement at
small contexts.
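
The "within +/-0.05 t/s at every context size" check amounts to a per-row comparison of two sweeps. A minimal sketch follows; `count_out_of_tolerance` is a hypothetical helper, not part of ds4-bench, and it assumes both sweeps are already loaded into arrays in the same context-size order.

```c
#include <assert.h>
#include <stddef.h>

/* Count the rows where the current sweep differs from the baseline by more
 * than `tol` t/s in either direction. Returns 0 when every row is in range. */
static size_t count_out_of_tolerance(const double *baseline,
                                     const double *current,
                                     size_t n, double tol) {
    assert(baseline != NULL && current != NULL && tol >= 0.0);
    size_t bad = 0;
    for (size_t i = 0; i < n; i++) {
        double d = current[i] - baseline[i];
        if (d < 0.0) d = -d;   /* absolute difference */
        if (d > tol) bad++;
    }
    return bad;
}
```

A result of 0 with `tol = 0.05` corresponds to the "no regression" statement above; any nonzero count points at context sizes worth re-measuring.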

Bench (DGX Spark / GB10, ds4flash, n=256, "knight" prompt, 5-run mean):

  Plain decode (no --mtp)                          16.13 t/s
  --mtp (now defaults to combined-forward K=1)     17.14 t/s  (+1.01)
  --mtp --mtp-draft 1 (forces canonical decode2)   16.18 t/s  (parity)
  --mtp --quality (strict canonical)               12.92 t/s

CONTRIBUTING.md test sweep:
  - `make clean && make cuda-spark`               clean
  - `make cpu`                                    clean
  - `make test` (= `./ds4_test --all`)            1 failure, pre-existing
    (`--logprob-vectors short_code_completion`, same as upstream/main,
    PR5-9; the test fixture's official continuation is one greedy
    token off, well-documented across this stack)
  - `make cuda-regression`                        pre-existing build
    error in `tests/cuda_long_context_smoke.c` (stale signature for
    `ds4_gpu_attention_decode_heads_tensor`; same on PR7 base and
    upstream/main, not introduced here)
  - `./ds4-bench` standard 2048..65536 sweep      written to
    speed-bench/gb10.csv; no regression vs prior baseline

Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
MTP:   DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf

This is the capstone for the mtp-beats-plain-kernels stack: 10 PRs,
each minimal, that take MTP from 5.1 t/s below plain to 1.0 t/s
above plain on DGX Spark, with bit-equal plain decode preserved.
@TrevorS TrevorS force-pushed the mtp-beats-plain-kernels-v8 branch from 120c033 to cfddd4b Compare May 24, 2026 17:14
@TrevorS TrevorS force-pushed the mtp-beats-plain-kernels-v9 branch from 4d7eda8 to cf209be Compare May 24, 2026 17:14

TrevorS commented May 24, 2026

Superseded by the reframed 2-PR stack (#11 + #12), which tells the same Spark/GB10 + MTP combined-forward story more concisely, rebased on current upstream/main, with the exploratory paths dropped.

@TrevorS TrevorS closed this May 24, 2026
@TrevorS TrevorS deleted the mtp-beats-plain-kernels-v9 branch May 24, 2026 22:43