cuda: replace K4 cuBLAS-cache gate with strict-mode gate (stacked on #6) #7
Closed
TrevorS wants to merge 1 commit into
…rp dispatch

PR #3's K4 commit (584de5e on the v2 branch) added a cuBLAS-cache-availability gate to the share-warp Q8 matmul dispatch in cuda_matmul_q8_0_tensor_labeled. The intent was correctness: prevent share-warp from displacing cuBLAS where cuBLAS would have handled the weight.

On DGX Spark this gate is empirically always-false: every Q8 weight has a cached F16 copy by the time the dispatcher runs, so share-warp NEVER fires at n_tok=2..4 with blocks<=32. As a result, the Q_B and output_proj_b matmuls (the only V4-Flash matmuls with blocks=32) silently route through cuBLAS' small-M tensor-core path, which pads M=2..4 to M=16 and wastes ~7/8 of the inner-product work.

The actual correctness concern is narrower: under DS4_MTP_STRICT (or --quality), users require byte-equality with plain decode. Share-warp is not bit-identical to cuBLAS Gemm at small M (different reduction order), so strict mode must fall through to cuBLAS. In non-strict mode this drift is acceptable; it matches PR #6's combined-forward Option-B pattern (same env knob selects strict vs perf).

Replace the cuBLAS-cache-availability check with `!strict_mtp_env`. Same opt-out shape as the combined-forward gate in ds4_session_eval_speculative_argmax.

The `blocks <= 32u` constraint is preserved: share-warp is bit-equal to N=1 warp8 only at blocks<=32; larger block counts drift from the batch_warp8 reference and would fail ds4_test --all long-context tensor equivalence (verified empirically during bisect).

Bench impact (DGX Spark, ds4flash, n=256, "knight" prompt):

  Default `--mtp` (combined-forward fires)   15.6  -> 16.20 t/s (+3.8%)
  `DS4_MTP_STRICT=1 --mtp` (canonical)       13.83 -> 13.83     unchanged
  Plain decode                               16.60 -> 16.60     unchanged

Strict-mode byte-equality vs PR #6 baseline confirmed (diff empty). `./ds4_test --all` shows the same 1 pre-existing failure as PR #6 (`logprob-vectors short_code_completion`, also fails on upstream/main).
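The gate described in the commit message can be sketched as host-side C. This is an illustration under assumptions, not the code in `cuda_matmul_q8_0_tensor_labeled`: the function names, the way the environment variable is parsed, and the exact ordering of checks are hypothetical. Only the conditions themselves (n_tok in 2..4, `blocks <= 32u`, the strict-mode opt-out, and the `DS4_CUDA_NO_Q8_SHARE_BATCH=1` kill switch mentioned in the PR description) come from the text. The second helper reproduces the padding arithmetic behind the "~7/8" figure.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>

/* Hypothetical helper: treats DS4_MTP_STRICT as enabled when it is set to any
 * non-empty value other than "0". The real parsing rules are not shown in the PR. */
bool strict_mtp_from_env(void) {
    const char *v = getenv("DS4_MTP_STRICT");
    return v != NULL && v[0] != '\0' && !(v[0] == '0' && v[1] == '\0');
}

/* Sketch of the new dispatch decision: share-warp is allowed only for small
 * token batches and small block counts, and never in strict mode. */
bool use_share_warp(uint32_t n_tok, uint32_t blocks,
                    bool strict_mtp, bool share_batch_disabled) {
    if (share_batch_disabled) return false;      /* diagnostic kill switch */
    if (strict_mtp) return false;                /* strict mode falls through to cuBLAS */
    if (n_tok < 2u || n_tok > 4u) return false;  /* only the n_tok=2..4 case */
    return blocks <= 32u;                        /* bit-equal to the N=1 reference only here */
}

/* Fraction of a padded M dimension that is padding when m rows are rounded
 * up to a multiple of tile. */
double padded_waste(uint32_t m, uint32_t tile) {
    assert(m > 0u && tile > 0u);
    uint32_t padded = ((m + tile - 1u) / tile) * tile;
    return (double)(padded - m) / (double)padded;
}
```

With a tile of 16, `padded_waste(2, 16)` is 14/16 = 7/8 and `padded_waste(4, 16)` is 12/16 = 3/4, so the "~7/8" in the message corresponds to the M=2 end of the range.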
Force-pushed from bb85073 to dcb5cc3.
Force-pushed from bb27595 to 28ff7ce.
PR7: cuda: replace K4 cuBLAS-cache gate with strict-mode gate (stacked on #6)
Summary
PR #3's K4 commit (`584de5e` on the v2 branch) added a cuBLAS-cache-availability gate to the share-warp Q8 matmul dispatcher. The intent was correctness: prevent share-warp from displacing cuBLAS where cuBLAS would have handled the weight. On DGX Spark this gate is empirically always-false: every Q8 weight has a cached F16 copy by the time the dispatcher runs, so share-warp NEVER fires at n_tok=2..4. This silently routes the Q_B and output_proj_b matmuls (the only V4-Flash matmuls with blocks=32) through cuBLAS' small-M tensor-core path, which pads M=2..4 → M=16 and wastes up to ~7/8 of the inner-product work (14 of 16 rows at M=2).

The actual correctness concern is narrower: under `DS4_MTP_STRICT` (or `--quality`), users require byte-equality with plain decode. Share-warp is not bit-identical to cuBLAS Gemm at small M (different reduction order), so strict mode must fall through to cuBLAS. In non-strict mode this drift is acceptable; it follows the same Option-B pattern as the combined-forward gate in PR #6.

Replaces the cuBLAS-cache-availability check with `!strict_mtp_env`. The `blocks <= 32u` constraint is preserved: share-warp is only bit-equal to the N=1 reference at blocks ≤ 32 (verified empirically during bisect: dropping that constraint causes `ds4_test --all` to fail `long_memory_archive` greedy-equivalence).

Bench impact (DGX Spark, ds4flash, n=256, "knight")
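| Configuration | Before (t/s) | After (t/s) | Change |
| --- | --- | --- | --- |
| Default `--mtp` (combined-forward fires) | 15.6 | 16.20 | +3.8% |
| `DS4_MTP_STRICT=1 --mtp` (canonical) | 13.83 | 13.83 | unchanged |
| Plain decode (no `--mtp`) | 16.60 | 16.60 | unchanged |

How this PR was discovered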
Bisect investigation triggered by noticing that `mtp-beats-plain @ 45ba7613` (downstream source) hits 19.7 t/s under combined-forward, while PR #6 (faithful cherry-pick port) only hits 15.6. Bisect localized 73% of the regression to this gate. The remaining 5% (16.2 → 19.7) is base-tree drift not localized to any single commit, likely from the captured-graph subsystem absent on this stack (which isn't a behavior we want to chase without proper investigation).

Why not drop `blocks <= 32u` too?

The bisect agent's full revert of K4 hit 18.6 t/s but didn't run `ds4_test --all`. Verified that dropping `blocks <= 32u` here causes share-warp to fire for matmuls where it's NOT bit-equal to the warp8 reference (different block-loop reduction order), failing `ds4_test --all` `long_memory_archive` with `greedy_fail=4 top1_mismatch=1`. The +2.5 t/s from dropping that constraint requires a kernel-level rewrite to make share-warp bit-equal at large block counts, which belongs in a separate PR.

Tested against
- `make clean && make cuda-spark`: clean, no warnings
- `make cpu`: clean
- `./ds4_test --all`: only the pre-existing `--logprob-vectors short_code_completion` failure (same as upstream/main, PR #5, and PR #6). Tensor-equivalence: `capture_fail=0 logits_fail=0 greedy_fail=0 top1_mismatch=0`
- Strict mode (`DS4_MTP_STRICT=1`) byte-identical to PR #6 strict output (diff empty)
- Model: `DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf`

AGENT.md compliance
- Uses the `DS4_MTP_STRICT` env knob to gate the fast path
- `DS4_CUDA_NO_Q8_SHARE_BATCH=1` opt-out is preserved as the diagnostic kill switch

Out of scope / follow-ups