[v4-pro] consolidate Q+KV norm/rope into qk_norm_rope_maybe_quant (prefer flydsl)#899
Merged
Conversation
…efer flydsl) Add `atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py` — a single wrapper that fuses per-head Q + KV RMSNorm + GPT-J RoPE (+ optional FP8 quant) in one launch. Auto-dispatches to `flydsl_qk_norm_rope_quant` for the V4-Pro shape (H=16, D=512, RD=64) and falls back to the existing Triton kernel on other shapes. Replaces the inline `fused_qk_norm_rope_swa_write` triton helper in `deepseek_v4.py`; decode path now calls the unified helper then issues a standalone `swa_write`, prefill path uses the same wrapper without swa_write. Removes the parallel `DualRMSNorm` + dead `q_norm2` / `_make_weightless_rmsnorm` plumbing. Also fix `scripts/wait_server_ready.sh`: snapshot the log byte size at start so errors from a prior failed launch can't false-trigger the "Server FAILED" detection on the next start.
Contributor
There was a problem hiding this comment.
Pull request overview
This PR consolidates DeepSeek-V4’s per-token Q/KV RMSNorm + GPT‑J RoPE (and optional FP8 quant plumbing) behind a single wrapper (qk_norm_rope_maybe_quant), preferring a FlyDSL fast-path when available, and updates the DeepSeek-V4 attention path to use this unified helper. It also hardens the server-wait script against stale log errors from prior runs.
Changes:
- Add
atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.pyimplementing a fused Triton kernel plus optional FlyDSL dispatch and a torch reference implementation. - Update
atom/models/deepseek_v4.pyto useqk_norm_rope_maybe_quantfor both decode and prefill, and remove the old inline fused helper / DualRMSNorm plumbing. - Update
scripts/wait_server_ready.shto ignore error strings from log content that existed before the current start.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| scripts/wait_server_ready.sh | Avoids false “FAILED” detection by only scanning log bytes appended after startup begins. |
| atom/models/deepseek_v4.py | Switches Q/KV norm+rope to the unified helper and removes legacy fused/dual-norm code paths. |
| atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py | New fused norm+rope (+optional quant) kernel wrapper with FlyDSL/Triton paths and reference implementation. |
| atom/model_ops/v4_kernels/init.py | Exposes the new helper and reference function from the v4_kernels package. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+402
to
+416
| kv, | ||
| kv_weight, | ||
| cos_cache, | ||
| sin_cache, | ||
| positions, | ||
| num_q_heads=n_local_heads, | ||
| head_dim=head_dim, | ||
| rope_head_dim=rope_head_dim, | ||
| quant=quant_q, | ||
| q_out=q_out, | ||
| kv_out=kv_out, | ||
| ) | ||
|
|
||
| q_scale = ( | ||
| torch.empty((T, n_local_heads), dtype=torch.float32, device=q.device) |
Comment on lines
+212
to
+214
| q_out_base + d_offs[None, :], | ||
| (x * inv_scaled).to(ot), | ||
| mask=m_mask[:, None] & nope_d_mask[None, :], |
Comment on lines
+283
to
+285
| kv_out_base + d_offs[None, :], | ||
| (x * inv_scaled * w[None, :]).to(ot), | ||
| mask=m_mask[:, None] & nope_d_mask[None, :], |
| import triton | ||
| import triton.language as tl | ||
|
|
||
| # Lazy-imported flydsl path (optional dependency). Set to None when flydsl |
Comment on lines
+27
to
+29
| Designed for the decode path only — prefill (large num_tokens) keeps the | ||
| 3-kernel sequence where fusion savings are amortized over many GEMM-bound | ||
| ops anyway. |
Comment on lines
+1550
to
+1554
| # Single kernel fuses per-head Q RMSNorm (weightless) + KV RMSNorm | ||
| # (weighted) + GPT-J interleaved RoPE on the tail rd dims. Dispatches | ||
| # to flydsl when the shape matches (V4-Pro is always V4-Pro shape → | ||
| # always flydsl). Microbench shows flydsl wins at every measured T | ||
| # from 4 (1.12×) to 32k (1.04×); used for both decode and prefill. |
Comment on lines
+31
to
+33
| if [ -f "$LOG_FILE" ]; then | ||
| ERR=$(tail -c "+$((LOG_START_BYTES + 1))" "$LOG_FILE" 2>/dev/null \ | ||
| | grep -c "cluster_dims\|InductorError\|SHUTDOWN signal\|proc died") |
Comment on lines
+375
to
+379
| ) | ||
| assert sin_cache.stride(0) == cos_cache.stride(0), "sin/cos must share row stride" | ||
| # Inner-dim stride must be 1 (dense). q.stride(0) and kv.stride(0) may | ||
| # exceed H*D / D respectively when the caller passes a strided view of | ||
| # a wider tensor (e.g. `kv_pre` from `torch.split(qkv_a, ...)` whose |
| # from 4 (1.12×) to 32k (1.04×); used for both decode and prefill. | ||
| # Optional FP8 quant outputs left off — downstream sparse_attn / | ||
| # swa_write are still bf16. | ||
| q_sa, kv, q_scale, kv_scale = qk_norm_rope_maybe_quant( |
Comment on lines
+130
to
+133
|
|
||
| rd_offs = tl.arange(0, RD).to(tl.int64) | ||
| cos_d_offs = rd_offs // 2 # GPT-J + REUSE_FREQS_FRONT_PART: lane duplicate | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py— one wrapper that fuses per-head Q + KV RMSNorm + GPT-J RoPE (+ optional FP8 quant) in a single launch. Auto-picksflydsl_qk_norm_rope_quantfor the V4-Pro shape (H=16, D=512, RD=64) and falls back to the existing Triton kernel for everything else.deepseek_v4.py: drop the inlinefused_qk_norm_rope_swa_writetriton helper, the parallelDualRMSNorm, and the deadq_norm2/_make_weightless_rmsnormplumbing. Decode path now calls the unified helper then issues a standaloneswa_write; prefill uses the same wrapper without swa_write.scripts/wait_server_ready.sh: snapshot log byte size at start so errors from a prior failed launch can't false-trigger the "Server FAILED" detection on the next start.Same Q/KV outputs as before (flydsl path validated bit-exact vs the pure-torch reference; cos > 0.9999 across all sweep shapes). Pure-GPU kernel time wins at every T on V4-Pro since the combined Q+KV single-launch halves per-call overhead.
Test plan
op_tests/test_flydsl_qk_norm_rope_quant.py— all sweep configs pass