
[v4-pro] consolidate Q+KV norm/rope into qk_norm_rope_maybe_quant (prefer flydsl)#899

Merged
valarLip merged 2 commits into main from feat/qk-norm-rope-maybe-quant on May 24, 2026
Conversation

@valarLip
Collaborator

Summary

  • New atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py — one wrapper that fuses per-head Q + KV RMSNorm + GPT-J RoPE (+ optional FP8 quant) in a single launch. Auto-picks flydsl_qk_norm_rope_quant for the V4-Pro shape (H=16, D=512, RD=64) and falls back to the existing Triton kernel for everything else.
  • deepseek_v4.py: drop the inline fused_qk_norm_rope_swa_write triton helper, the parallel DualRMSNorm, and the dead q_norm2 / _make_weightless_rmsnorm plumbing. Decode path now calls the unified helper then issues a standalone swa_write; prefill uses the same wrapper without swa_write.
  • scripts/wait_server_ready.sh: snapshot log byte size at start so errors from a prior failed launch can't false-trigger the "Server FAILED" detection on the next start.
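
For readers unfamiliar with the fused math, here is a minimal pure-Python sketch of what the wrapper computes per token: weightless RMSNorm on each Q head, weighted RMSNorm on KV, then GPT-J (interleaved-pair) RoPE on the last `rope_dim` elements. It is not the PR's Triton or flydsl code. The function names, the `eps` value, and the rotation sign convention are assumptions made for illustration.

```python
import math

def rms_norm(x, weight=None, eps=1e-6):
    """RMSNorm over one head vector; weight=None is the weightless variant.
    eps is an assumed value, not taken from the PR."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    if weight is None:
        return [v * inv_rms for v in x]
    return [v * inv_rms * w for v, w in zip(x, weight)]

def gptj_rope_tail(x, cos, sin, rope_dim):
    """Rotate the last rope_dim elements as interleaved (even, odd) pairs.
    Pair i uses cos[i] / sin[i], i.e. both lanes of a pair share one frequency."""
    out = list(x)
    start = len(x) - rope_dim
    for i in range(rope_dim // 2):
        a, b = x[start + 2 * i], x[start + 2 * i + 1]
        out[start + 2 * i] = a * cos[i] - b * sin[i]
        out[start + 2 * i + 1] = a * sin[i] + b * cos[i]
    return out

def norm_rope_reference(q_heads, kv, kv_weight, cos, sin, rope_dim):
    """Hypothetical reference for one token: returns (q_out, kv_out)."""
    q_out = [gptj_rope_tail(rms_norm(h), cos, sin, rope_dim) for h in q_heads]
    kv_out = gptj_rope_tail(rms_norm(kv, kv_weight), cos, sin, rope_dim)
    return q_out, kv_out
```

Sharing one `cos`/`sin` entry per pair mirrors the `rd_offs // 2` indexing that appears in the kernel excerpt quoted in the review; the optional FP8 quant step is left out of this sketch.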

Q/KV outputs are unchanged: the flydsl path was validated bit-exact against the pure-torch reference, with cosine similarity > 0.9999 across all sweep shapes. On V4-Pro, pure-GPU kernel time improves at every T because the combined Q+KV single launch halves per-call overhead.

Test plan

  • aiter op_test op_tests/test_flydsl_qk_norm_rope_quant.py — all sweep configs pass
  • CUDA-graph capture/replay regression (the NULL-stream pitfall fix)
  • V4-Pro GSM8K num_fewshot=3:
    • no MTP: flexible-extract 0.9492 / strict-match 0.9500 (baseline 0.954, within ±1σ)
    • MTP3: 0.9492 / 0.9492 (within ±1σ)
  • V4-Pro 1024/1024 c=64 trace verifies the unified helper fires on both prefill and decode paths.
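
The "within ±1σ" claim can be sanity-checked with a short calculation. The PR does not say how σ was obtained, so the sketch below assumes a binomial standard error over the full GSM8K test split (1319 problems); the evaluation harness may report a slightly different value.

```python
import math

def binomial_stderr(p, n):
    """Standard error of an accuracy p measured on n independent samples."""
    return math.sqrt(p * (1.0 - p) / n)

N_GSM8K_TEST = 1319   # assumption: the run used the full GSM8K test split
baseline = 0.954      # baseline accuracy quoted in the test plan
measured = 0.9492     # lowest score reported in the test plan

sigma = binomial_stderr(baseline, N_GSM8K_TEST)
delta = abs(baseline - measured)
within_one_sigma = delta <= sigma

print(f"sigma={sigma:.4f} delta={delta:.4f} within={within_one_sigma}")
```

Under these assumptions σ is roughly 0.0058, and the largest gap to the baseline (0.0048) sits inside it.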

…efer flydsl)

Add `atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py` — a single
wrapper that fuses per-head Q + KV RMSNorm + GPT-J RoPE (+ optional FP8
quant) in one launch. Auto-dispatches to `flydsl_qk_norm_rope_quant` for
the V4-Pro shape (H=16, D=512, RD=64) and falls back to the existing
Triton kernel on other shapes.
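
The dispatch rule described above can be sketched as follows. This is an illustration only: `load_optional`, `pick_backend`, and the returned backend labels are hypothetical names, and the real wrapper may check additional conditions (dtype, strides, quant mode) that the PR text does not list.

```python
import importlib

# (H, D, RD) for the V4-Pro shape, as stated in the PR description.
V4_PRO_SHAPE = (16, 512, 64)

def load_optional(module_name):
    """Return the imported module, or None when it is not installed."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None

def pick_backend(num_q_heads, head_dim, rope_head_dim, fast_module):
    """Choose the fast path only when it is importable and the shape matches;
    every other case falls back to the Triton kernel."""
    shape = (num_q_heads, head_dim, rope_head_dim)
    if fast_module is not None and shape == V4_PRO_SHAPE:
        return "flydsl"
    return "triton"
```

Keeping the optional import separate from the shape check means a missing flydsl install degrades to the Triton path instead of failing at import time.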

Replaces the inline `fused_qk_norm_rope_swa_write` triton helper in
`deepseek_v4.py`; decode path now calls the unified helper then issues a
standalone `swa_write`, prefill path uses the same wrapper without
swa_write. Removes the parallel `DualRMSNorm` + dead `q_norm2` /
`_make_weightless_rmsnorm` plumbing.

Also fix `scripts/wait_server_ready.sh`: snapshot the log byte size at
start so errors from a prior failed launch can't false-trigger the
"Server FAILED" detection on the next start.
Copilot AI review requested due to automatic review settings May 24, 2026 15:02
Contributor

Copilot AI left a comment


Pull request overview

This PR consolidates DeepSeek-V4’s per-token Q/KV RMSNorm + GPT‑J RoPE (and optional FP8 quant plumbing) behind a single wrapper (qk_norm_rope_maybe_quant), preferring a FlyDSL fast-path when available, and updates the DeepSeek-V4 attention path to use this unified helper. It also hardens the server-wait script against stale log errors from prior runs.

Changes:

  • Add atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py implementing a fused Triton kernel plus optional FlyDSL dispatch and a torch reference implementation.
  • Update atom/models/deepseek_v4.py to use qk_norm_rope_maybe_quant for both decode and prefill, and remove the old inline fused helper / DualRMSNorm plumbing.
  • Update scripts/wait_server_ready.sh to ignore error strings from log content that existed before the current start.
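
The log-snapshot idea in the last bullet can be illustrated in Python (the actual script is shell and uses `tail -c` with `grep -c`). The function names here are hypothetical; the marker strings are the ones in the script excerpt quoted later in this review.

```python
import os

# Error markers taken from the grep pattern in the script excerpt.
ERROR_MARKERS = (b"cluster_dims", b"InductorError", b"SHUTDOWN signal", b"proc died")

def snapshot_log_size(path):
    """Record the log size before the server starts (0 if there is no log yet)."""
    return os.path.getsize(path) if os.path.exists(path) else 0

def count_new_error_lines(path, start_bytes):
    """Count lines containing an error marker, looking only at bytes
    appended after the snapshot."""
    if not os.path.exists(path):
        return 0
    with open(path, "rb") as f:
        f.seek(start_bytes)
        new_data = f.read()
    return sum(
        1 for line in new_data.splitlines()
        if any(marker in line for marker in ERROR_MARKERS)
    )
```

A caller would take the snapshot once before launching the server and pass it to every later poll, so a "proc died" line left over from an earlier failed launch is never counted.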

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 10 comments.

| File | Description |
| --- | --- |
| `scripts/wait_server_ready.sh` | Avoids false “FAILED” detection by only scanning log bytes appended after startup begins. |
| `atom/models/deepseek_v4.py` | Switches Q/KV norm+rope to the unified helper and removes legacy fused/dual-norm code paths. |
| `atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py` | New fused norm+rope (+optional quant) kernel wrapper with FlyDSL/Triton paths and reference implementation. |
| `atom/model_ops/v4_kernels/__init__.py` | Exposes the new helper and reference function from the v4_kernels package. |


Comment on lines +402 to +416
kv,
kv_weight,
cos_cache,
sin_cache,
positions,
num_q_heads=n_local_heads,
head_dim=head_dim,
rope_head_dim=rope_head_dim,
quant=quant_q,
q_out=q_out,
kv_out=kv_out,
)

q_scale = (
torch.empty((T, n_local_heads), dtype=torch.float32, device=q.device)
Comment on lines +212 to +214
q_out_base + d_offs[None, :],
(x * inv_scaled).to(ot),
mask=m_mask[:, None] & nope_d_mask[None, :],
Comment on lines +283 to +285
kv_out_base + d_offs[None, :],
(x * inv_scaled * w[None, :]).to(ot),
mask=m_mask[:, None] & nope_d_mask[None, :],
import triton
import triton.language as tl

# Lazy-imported flydsl path (optional dependency). Set to None when flydsl
Comment on lines +27 to +29
Designed for the decode path only — prefill (large num_tokens) keeps the
3-kernel sequence where fusion savings are amortized over many GEMM-bound
ops anyway.
Comment on lines +1550 to +1554
# Single kernel fuses per-head Q RMSNorm (weightless) + KV RMSNorm
# (weighted) + GPT-J interleaved RoPE on the tail rd dims. Dispatches
# to flydsl when the shape matches (V4-Pro is always V4-Pro shape →
# always flydsl). Microbench shows flydsl wins at every measured T
# from 4 (1.12×) to 32k (1.04×); used for both decode and prefill.
Comment on lines +31 to +33
if [ -f "$LOG_FILE" ]; then
ERR=$(tail -c "+$((LOG_START_BYTES + 1))" "$LOG_FILE" 2>/dev/null \
| grep -c "cluster_dims\|InductorError\|SHUTDOWN signal\|proc died")
Comment on lines +375 to +379
)
assert sin_cache.stride(0) == cos_cache.stride(0), "sin/cos must share row stride"
# Inner-dim stride must be 1 (dense). q.stride(0) and kv.stride(0) may
# exceed H*D / D respectively when the caller passes a strided view of
# a wider tensor (e.g. `kv_pre` from `torch.split(qkv_a, ...)` whose
# from 4 (1.12×) to 32k (1.04×); used for both decode and prefill.
# Optional FP8 quant outputs left off — downstream sparse_attn /
# swa_write are still bf16.
q_sa, kv, q_scale, kv_scale = qk_norm_rope_maybe_quant(
Comment on lines +130 to +133

rd_offs = tl.arange(0, RD).to(tl.int64)
cos_d_offs = rd_offs // 2 # GPT-J + REUSE_FREQS_FRONT_PART: lane duplicate

@valarLip valarLip merged commit d55a233 into main May 24, 2026
24 of 31 checks passed
@valarLip valarLip deleted the feat/qk-norm-rope-maybe-quant branch May 24, 2026 15:45
2 participants