cuda: DS4_CUDA_STRICT_BATCHED — bit-equal batched-N infrastructure (stacked on #4)#5
Closed
TrevorS wants to merge 6 commits into
Closed
cuda: DS4_CUDA_STRICT_BATCHED — bit-equal batched-N infrastructure (stacked on #4)#5TrevorS wants to merge 6 commits into
TrevorS wants to merge 6 commits into
Conversation
19568a3 to
ae95ce7
Compare
This was referenced May 24, 2026
…CT_BATCHED)
When DS4_CUDA_STRICT_BATCHED is set, skip the cuBLAS Sgemm (f32 cache) and
cublasGemmEx (f16 cache) branches at n_tok > 1 in cuda_matmul_q8_0_tensor_labeled
and fall through to the per-token Q8 batch_warp8 kernel.
The cuBLAS path takes a different numerical route -- cached F32 or F16 dequant
followed by tensor-core / TF32 matmul -- than the per-token warp8 kernel that
N=1 plain decode uses. This makes the batched-N MTP verifier path not bit-
identical to N=1 raw_swa decode, which blocks an upcoming combined-forward
strict mode that would sample next-iter s->logits from the verifier output.
Opt-in via env so existing non-strict callers still get the cuBLAS perf win
during prefill and the existing non-strict speculative verifier.
Validation:
baseline (DS4_CUDA_STRICT_BATCHED unset, MTP strict, n=32):
plain vs verify_suffix_tops diverges from pos=10, max_abs grows to ~38
with DS4_CUDA_STRICT_BATCHED=1 (this commit + later commits):
text output matches plain for n=32; max_abs drops to ~2-5; greedy
argmax matches at every dumped position.
Test: ds4_test unchanged (1 pre-existing failure: long_memory_archive top1).
NO github push.
…A_STRICT_BATCHED) When DS4_CUDA_STRICT_BATCHED is set, skip the cublasGemmEx F16 path at n_tok > 1 in ds4_gpu_matmul_f16_tensor and fall through to the per-token matmul_f16_kernel that N=1 plain decode uses. Same motivation as the q8 dispatcher gate (preceding commit): cuBLAS uses tensor cores with F16 accumulation, the per-token kernel uses scalar F32 accumulation -- the two paths diverge at ULP scale and the divergence amplifies through the layer stack. Opt-in env preserves cuBLAS perf for the non-strict default. Test: ds4_test unchanged. NO github push.
…_STRICT_BATCHED) When DS4_CUDA_STRICT_BATCHED is set, skip the cublasSgemm path at n_tok > 1 in ds4_gpu_matmul_f32_tensor and fall through to the per-token matmul_f32_kernel that N=1 plain decode uses. cuBLAS Sgemm picks tensor-core paths (TF32 by default; pedantic in quality mode) that compute with different rounding than the scalar warp kernel. Opt-in so the non-strict default keeps cuBLAS perf. Test: ds4_test unchanged. NO github push.
When DS4_CUDA_STRICT_BATCHED is set, skip the indexer_scores_wmma128/64/32_kernel selection in indexer_scores_launch and fall through to the scalar indexer_scores_kernel. The wmma kernels accumulate in F16 inside tensor cores, which diverges at ulp scale from the F32 scalar accumulation that indexer_score_one_direct_kernel uses for N=1. Routing batched-N through the scalar kernel keeps the indexer numerics bit-equal to N=1. This is already gated off via g_quality_mode; DS4_CUDA_STRICT_BATCHED adds the same gate without forcing full --quality (which also turns off several non-divergent perf caches). Test: ds4_test unchanged. NO github push.
…CT_BATCHED When DS4_CUDA_STRICT_BATCHED is set, skip the attention_indexed_mixed_heads8_online_kernel<8,16> branch in ds4_gpu_attention_indexed_mixed_batch_heads_tensor and fall through to the scalar attention_indexed_mixed_kernel that N=1 uses. The online kernel computes running softmax with recomputed normalizers (flash-attention style) which diverges at ulp scale from the canonical materialized-scores softmax used by the scalar kernel. Test: ds4_test unchanged. NO github push.
When DS4_CUDA_STRICT_BATCHED is set, skip the F16 strided-batched cuBLAS path in ds4_gpu_attention_output_q8_batch_tensor and fall through to grouped_q8_0_a_preq_warp8_kernel, the same kernel N=1 plain decode uses (at N=1 the cuBLAS branch is already bypassed via the n_tokens >= out_a_cublas_min_tokens gate, default 2). The cuBLAS strided-batched path packs heads to F16 and does tensor-core GemmStridedBatchedEx; the per-token warp8 path keeps Q8 throughout with F32 accumulators -- the two paths diverge at ulp scale even on identical weights. This is already gated off via g_quality_mode and the existing DS4_CUDA_NO_CUBLAS_ATTENTION_OUTPUT_A env; DS4_CUDA_STRICT_BATCHED adds the same gate without forcing full --quality. Test: ds4_test unchanged. NO github push.
d3513a6 to
ed98f3e
Compare
ae95ce7 to
4ef4054
Compare
Owner
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
PR5 draft: cuda: DS4_CUDA_STRICT_BATCHED — bit-equal batched-N forward (stacked on #4)
Summary
Adds a
DS4_CUDA_STRICT_BATCHEDenv knob that routes batched-N (n_tokens ≥ 2) forward through bit-equivalent N=1 fallbacks in 7 kernel dispatchers. Default behavior is unchanged — the gates only fire when the env is set. This is enabling infrastructure: it makes the batched-N forward path numerically equivalent to plain decode for the row 0 output, which is a prerequisite for any future combined-forward MTP path to be byte-equal under strict mode.The 7 gated dispatchers (single env knob covers all):
fa4a479cuda_matmul_q8_0_tensor_labeledcuBLAS Sgemm/GemmEx branch0594500ds4_gpu_matmul_f16_tensorcuBLAS GemmEx branche381515ds4_gpu_matmul_f32_tensorcuBLAS Sgemm branch48d2666cuda_matmul_q8_0_tensor_labeledQ8 share-warp kernel (n_tok 2-4)c7d4b56indexer_scores_launchwmma128/64/32 selection9ae83b1ds4_gpu_attention_indexed_mixed_batch_heads_tensorheads8 online flash19568a3ds4_gpu_attention_output_q8_batch_tensorcuBLAS GemmStridedBatchedEx F16Why
At n_tok > 1, several matmul / attention dispatchers route to cuBLAS / tensor-core / WMMA / online-flash kernels that use different reduction structures, FMA orderings, and accumulation precisions than the N=1 warp kernels. The differences are legitimate (the batched paths are faster); but for any caller that wants the row-0 output of a batched-N forward to be bit-equal to what a plain N=1 decode would produce, every divergent kernel needs an opt-out path.
The intended consumer is a future combined-forward MTP path under strict mode — but that consumer doesn't exist on
upstream/mainor in this PR stack. This PR establishes the gating infrastructure first; the consumer can land separately and reference these gates.What this PR is NOT
decode2_exactstrict path, which is already byte-identical to plain and remains the strict-mode defaultWhat this PR IS
DS4_CUDA_STRICT_BATCHED) opt-in that makes the batched-N≥2 forward produce N=1-equivalent output for row 0n=63tokens (residual divergence at n=64+ traceable to routed-MoE dispatcher + Q8 pair-fused matmul scheduling)Outstanding gap (residual)
With all 7 gates set, max-abs logit divergence vs plain decode is ~2-5 at every position (no longer cascading), and argmax matches for n=32. Two sources remain:
moe_gate_up_mid_decode_lut_qwarp32_kernel, N>1 usesmoe_gate_up_mid_sorted_qwarp32_kernel. Different reduction patterns. Would need a new strict-mode dispatch that routes N≥2 through the same LUT path (or makes the sorted path bit-equivalent).matmul_q8_0_pair_preq_warp8_kernel(one warp handles both gate+up against shared activation); N≥2 uses two separatematmul_q8_0_tensorcalls. Different FMA scheduling.Both would need new code (not just env gates) and are out of scope for this PR.
Tested against
make clean && make cuda-spark— clean, no warnings (7 builds, one per commit)./ds4 -p "knight" -n 16 --temp 0 --nothink -sys ""smoke — coherent output, 17.41 t/sDS4_CUDA_STRICT_BATCHED=1 DS4_MTP_BATCH_VERIFY=1 DS4_MTP_STRICT=1: text matches plain decode through n=63; argmax matches for n=32make cpu— clean build./ds4_test --all— only pre-existing--logprob-vectors short_code_completionfailure (also onupstream/main); the intermittentlong_memory_archivetest failure noted in PR cuda: revive 2 dropped kernels with FMA-contraction fixes (stacked on #2) #3 also reproduces on this branch (and on PR4 base, and on upstream/main) — pre-existing, not introducedDeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.ggufMethodology notes
The kernel divergence inventory was produced via an instrumented logit-dump (
DS4_MTP_DUMP_LOGITS=path— a temporary scaffold that was reverted before commit). Captured plain vs batched-N logits at each iter position, computed max-abs delta. Used an ablation matrix (disable one kernel at a time via env, measure logit-diff reduction) to localize which kernels contributed how much drift, ranked, then gated the top 7.The 5 scout-ranked divergences expanded to 7 after observing: (a) matmul_f32 dispatcher also routed to cuBLAS at n_tok>1 (added gate); (b) Q8 share-warp at small N was a separate divergence from cuBLAS Q8 path (already in scout but realized as distinct dispatch).
Out of scope / follow-ups
b4e5a8e,8208b3b,11094ccfrom downstream, plus an Option-B strict gate, then strict-batched env knob unlocks byte-equality)