From 996f1ea500ecdc1faec1b92403a69ca886d1c3e2 Mon Sep 17 00:00:00 2001 From: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Date: Sun, 24 May 2026 15:01:38 +0000 Subject: [PATCH 1/2] [v4-pro] consolidate Q+KV norm/rope into qk_norm_rope_maybe_quant (prefer flydsl) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py` — a single wrapper that fuses per-head Q + KV RMSNorm + GPT-J RoPE (+ optional FP8 quant) in one launch. Auto-dispatches to `flydsl_qk_norm_rope_quant` for the V4-Pro shape (H=16, D=512, RD=64) and falls back to the existing Triton kernel on other shapes. Replaces the inline `fused_qk_norm_rope_swa_write` triton helper in `deepseek_v4.py`; decode path now calls the unified helper then issues a standalone `swa_write`, prefill path uses the same wrapper without swa_write. Removes the parallel `DualRMSNorm` + dead `q_norm2` / `_make_weightless_rmsnorm` plumbing. Also fix `scripts/wait_server_ready.sh`: snapshot the log byte size at start so errors from a prior failed launch can't false-trigger the "Server FAILED" detection on the next start. --- atom/model_ops/v4_kernels/__init__.py | 6 + .../v4_kernels/qk_norm_rope_maybe_quant.py | 568 ++++++++++++++++++ atom/models/deepseek_v4.py | 217 ++----- scripts/wait_server_ready.sh | 19 +- 4 files changed, 626 insertions(+), 184 deletions(-) create mode 100644 atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py diff --git a/atom/model_ops/v4_kernels/__init__.py b/atom/model_ops/v4_kernels/__init__.py index f39b3afdb..0b86fff07 100644 --- a/atom/model_ops/v4_kernels/__init__.py +++ b/atom/model_ops/v4_kernels/__init__.py @@ -39,6 +39,10 @@ write_v4_paged_prefill_indices, write_v4_paged_prefill_indices_reference, ) +from atom.model_ops.v4_kernels.qk_norm_rope_maybe_quant import ( + qk_norm_rope_maybe_quant, + qk_norm_rope_maybe_quant_reference, +) from atom.model_ops.v4_kernels.state_writes import update_compressor_states, swa_write __all__ = [ @@ -60,4 +64,6 @@ "write_v4_paged_decode_indices_reference", "write_v4_paged_prefill_indices", "write_v4_paged_prefill_indices_reference", + "qk_norm_rope_maybe_quant", + "qk_norm_rope_maybe_quant_reference", ] diff --git a/atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py b/atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py new file mode 100644 index 000000000..70ff6cd28 --- /dev/null +++ b/atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py @@ -0,0 +1,568 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""Fused per-token RMSNorm + RoPE (+ optional FP8 per-row quant). + +Replaces this 3-kernel sequence on the V4 decode path:: + + q_flat, kv = qk_norm(q, kv_pre) # triton: fused_qk_norm + q = q_flat.view(T, H, D) + rotary_emb(positions, q[..., -rd:], # aiter: rope_cached_positions_2c + kv[..., -rd:]) + +with a single Triton kernel. The Q-side norm is *weightless* (V4's +``q_norm2`` has ``weight=None``) — the kernel hardcodes 1.0 on that side +and only loads ``kv_weight``. RoPE uses ``rotate_style=1`` (GPT-J +interleaved pairs) with ``reuse_freqs_front_part=True`` and +``nope_first=False`` to match ``_V4RoPE.forward``. + +Optional FP8 outputs (``quant_q`` / ``quant_k``) emit per-row e4m3 + a +single fp32 ``amax/FP8_MAX`` scale per row. "1x128" blockscale = one +scale per (token, head) for Q, one per token for KV — head_dim is the +only contracted dim. Default off; plumbing for a future FP8 consumer +(sparse_attn FP8 path / FP8 swa_write). When the corresponding flag is +off the wrapper returns ``None`` for that scale and the fp8 output +buffer is not allocated. + +Designed for the decode path only — prefill (large num_tokens) keeps the +3-kernel sequence where fusion savings are amortized over many GEMM-bound +ops anyway. +""" + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from atom.utils import envs + +# Lazy-imported flydsl path (optional dependency). Set to None when flydsl +# is unavailable; the dispatch in ``qk_norm_rope_maybe_quant`` will fall +# back to the Triton kernel. +try: + from aiter.ops.flydsl import flydsl_qk_norm_rope_quant + + _FLYDSL_AVAILABLE = True +except Exception: + _FLYDSL_AVAILABLE = False + + +# AMD MI3 native e4m3 variant. aiter's a8w8 path and the existing +# act_quant_inplace consumer agree on this dtype. +_FP8_DTYPE = torch.float8_e4m3fnuz +_FP8_MAX = float(torch.finfo(_FP8_DTYPE).max) +# Precomputed constants used by the fp8 quant fast-path. With the √2 +# upper-bound for amax, scaling x_n by 1/scale algebraically equals +# scaling x (pre-norm) by FP8_MAX / (abs_max_x * √2) — `rstd` cancels. +# Folding into a single constant saves a multiply per row. +_SQRT2 = 1.4142135623730951 +_INV_FP8_MAX_SQRT2 = _SQRT2 / _FP8_MAX +_FP8_MAX_OVER_SQRT2 = _FP8_MAX / _SQRT2 + + +@triton.jit +def _gptj_rotate(x, x_rot_mask, BLOCK_M: tl.constexpr, RD: tl.constexpr): + """GPT-J interleaved rotation on a [BLOCK_M, RD] tile. + + Returns ``(-x[2i+1], x[2i], -x[2i+3], x[2i+2], ...)`` so that + ``x * cos + rotated * sin`` realizes the per-pair RoPE + ``(e*c - o*s, e*s + o*c)``. cos/sin must be lane-duplicated + (``cache[i]`` at lanes 2i and 2i+1), produced via + ``d_cos_offs = d_pe_offs // 2``. + """ + x_rot = tl.where(x_rot_mask, x, -x) + x_rot = tl.reshape(x_rot, (BLOCK_M, RD // 2, 2)) + x_rot = tl.flip(x_rot, 2) + return tl.reshape(x_rot, (BLOCK_M, RD)) + + +@triton.jit +def _qk_norm_rope_maybe_quant_kernel( + q_in_ptr, # [T, H*D] bf16 (post wq_b, heads packed) + kv_ptr, # [T, D] bf16 (post wkv_a split) + kv_weight_ptr, # [D] bf16 (KV RMSNorm weight; Q weightless) + cos_ptr, # [..., rd/2] (REUSE_FREQS_FRONT_PART=True) + sin_ptr, # [..., rd/2] + positions_ptr, # [T] int64 + q_out_ptr, # [T, H, D] bf16 or e4m3 + kv_out_ptr, # [T, D] bf16 or e4m3 + q_scale_ptr, # [T, H] fp32 (only when QUANT_Q) + kv_scale_ptr, # [T] fp32 (only when QUANT_K) + eps: tl.constexpr, + T, + q_in_row_stride, + kv_in_row_stride, + cos_row_stride: tl.constexpr, + H: tl.constexpr, + D: tl.constexpr, # head_dim — must be power of 2 (loaded as single tile) + RD: tl.constexpr, # rope_head_dim + NOPE: tl.constexpr, # D - RD + NUM_PE_CHUNKS: tl.constexpr, # D // RD — requires D % RD == 0 (V4: 512/64=8) + FP8_MAX: tl.constexpr, + INV_FP8_MAX_SQRT2: tl.constexpr, # √2 / FP8_MAX, for `scale` compute + FP8_MAX_OVER_SQRT2: tl.constexpr, # FP8_MAX / √2, for `inv_scaled` (rstd-cancelled) + BLOCK_M: tl.constexpr, + QUANT_Q: tl.constexpr, + QUANT_K: tl.constexpr, +): + """Grid: ``(cdiv(T, BLOCK_M), H + 1)``. + + - ``pid_h < H`` → process Q-head ``pid_h`` (weightless RMSNorm + RoPE tail). + - ``pid_h == H`` → process KV row (weighted RMSNorm + RoPE tail). + + Each program handles a ``BLOCK_M``-token tile. We load the full + ``[BLOCK_M, D]`` tile, RMSNorm it, then extract the RoPE tail by + ``tl.where(d >= NOPE, normed, 0)`` → reshape ``(BLOCK_M, NUM_PE_CHUNKS, RD)`` + → ``sum(axis=1)`` (only the last chunk is nonzero so the sum just + selects that chunk). Stores nope to output positions, RoPE to tail. + + Q early-returns so the KV-only stores below see a single, non-divergent + type for ``kv_out_ptr.dtype.element_ty`` (triton's IR cannot unify + bf16 vs e4m3 store ops across an ``if/else`` branch). + """ + pid_m = tl.program_id(0).to(tl.int64) + pid_h = tl.program_id(1).to(tl.int64) + + m_offs = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64) + m_mask = m_offs < T + + d_offs = tl.arange(0, D) + nope_d_mask = d_offs < NOPE + rope_d_mask = d_offs >= NOPE + + rd_offs = tl.arange(0, RD).to(tl.int64) + cos_d_offs = rd_offs // 2 # GPT-J + REUSE_FREQS_FRONT_PART: lane duplicate + + # positions/cos/sin are reused across all H+1 programs sharing this pid_m. + # Tag them evict_last so the L2 keeps them hot for sibling head-tiles. + pos = tl.load( + positions_ptr + m_offs, mask=m_mask, other=0, eviction_policy="evict_last" + ).to(tl.int64) + cos_addr = pos[:, None] * cos_row_stride + cos_d_offs[None, :] + cos = tl.load( + cos_ptr + cos_addr, + mask=m_mask[:, None], + other=0, + eviction_policy="evict_last", + ).to(tl.float32) + sin = tl.load( + sin_ptr + cos_addr, + mask=m_mask[:, None], + other=0, + eviction_policy="evict_last", + ).to(tl.float32) + # Rotation mask: evens get +x, odds get -x → after pair-flip realizes + # the (-o, e) pattern needed for x*c + rot*s == (e*c-o*s, e*s+o*c). + x_rot_mask = (rd_offs % 2 == 0)[None, :] + + # ---- Q path (pid_h < H) ---- + if pid_h < H: + h = pid_h.to(tl.int32) + q_base = q_in_ptr + m_offs[:, None] * q_in_row_stride + h * D + # Q tile is one-shot (no other program loads this head): evict_first. + x = tl.load( + q_base + d_offs[None, :], + mask=m_mask[:, None], + other=0.0, + eviction_policy="evict_first", + ).to(tl.float32) + + # Single pass over x: variance + (when quanting) input amax. + # Triton fuses both reductions onto the same scan of x, so amax is + # essentially free vs a second pass over `x_n`. + sq = tl.sum(x * x, axis=1) + if QUANT_Q: + abs_max_x = tl.max(tl.abs(x), axis=1) + rstd = tl.rsqrt(sq / D + eps) + + # RoPE input: re-load just the [BM, RD] rope tail (L2-hot from the + # full-row load above) instead of extracting it via + # `tl.where + reshape + sum` on [BM, D]. The extract path costs ~3us + # at BM=8 D=512 because it touches the full 4096-elem tile; the + # re-load hits L2 and is essentially free. + pe_in = tl.load(q_base + NOPE + rd_offs[None, :], mask=m_mask[:, None]).to( + tl.float32 + ) + + q_out_base = q_out_ptr + m_offs[:, None] * (H * D) + h * D + ot = q_out_ptr.dtype.element_ty + if QUANT_Q: + # Conservative √2 amax bound for fp8 scale (Q is weightless): + # |x_n[d]| = |x[d]| * rstd ≤ abs_max_x * rstd + # |pe[d]| ≤ |x_rope[d]| * rstd * √2 (GPT-J rotation: + # |pe_even/odd| ≤ √(e²+o²) ≤ √2·max(|e|,|o|)) + # Bounded by `abs_max_x * rstd * √2`. Skipping the second-pass + # `tl.max(tl.abs(x_n))` AND the [BM, RD] pe reduction. Cost: + # ≤ 0.5 bits of fp8 precision (over-scale by ≤ √2). + # + # Algebraic fast-path: `x_n * inv == x * (rstd/scale)`, and + # rstd/scale = rstd / (abs_max_x*rstd*INV_FP8_MAX_SQRT2) + # = FP8_MAX_OVER_SQRT2 / abs_max_x (rstd cancels!) + # So we skip materializing `x_n = x * rstd` as a separate fp32 + # tile, and apply a single multiplier directly to x before cast. + # Same trick on pe via linearity of RoPE rotation: + # pe * inv = (pe_in * rstd * cos + rotate(...) * sin) * inv + # = (pe_in * inv_scaled) * cos + rotate(pe_in * inv_scaled) * sin + inv_scaled = (FP8_MAX_OVER_SQRT2 / tl.maximum(abs_max_x, 1e-12))[:, None] + pe_scaled = pe_in * inv_scaled + pe_quant = ( + pe_scaled * cos + _gptj_rotate(pe_scaled, x_rot_mask, BLOCK_M, RD) * sin + ) + # Scale to store (downstream consumer reconstructs via fp8*scale). + scale = abs_max_x * rstd * INV_FP8_MAX_SQRT2 + tl.store( + q_out_base + d_offs[None, :], + (x * inv_scaled).to(ot), + mask=m_mask[:, None] & nope_d_mask[None, :], + ) + tl.store( + q_out_base + NOPE + rd_offs[None, :], + pe_quant.to(ot), + mask=m_mask[:, None], + ) + tl.store(q_scale_ptr + m_offs * H + h, scale, mask=m_mask) + else: + # bf16 path: still need to materialize x_n and pe in fp32. + x_n = x * rstd[:, None] + pe = pe_in * rstd[:, None] + pe = pe * cos + _gptj_rotate(pe, x_rot_mask, BLOCK_M, RD) * sin + tl.store( + q_out_base + d_offs[None, :], + x_n.to(ot), + mask=m_mask[:, None] & nope_d_mask[None, :], + ) + tl.store( + q_out_base + NOPE + rd_offs[None, :], + pe.to(ot), + mask=m_mask[:, None], + ) + return + + # ---- KV path (pid_h == H) ---- + kv_base = kv_ptr + m_offs[:, None] * kv_in_row_stride + # KV tile is one-shot; weight is reused across all M-tiles. + x = tl.load( + kv_base + d_offs[None, :], + mask=m_mask[:, None], + other=0.0, + eviction_policy="evict_first", + ).to(tl.float32) + w = tl.load(kv_weight_ptr + d_offs, eviction_policy="evict_last").to(tl.float32) + + sq = tl.sum(x * x, axis=1) + if QUANT_K: + # Weighted amax: |x_n[d]| = |x[d]| * rstd * |w[d]|. + # Pre-multiply x by abs(w) elementwise then take row-max. + abs_max_xw = tl.max(tl.abs(x) * tl.abs(w)[None, :], axis=1) + rstd = tl.rsqrt(sq / D + eps) + + # Reload rope tail from L2 (hot after the full-row load above) and apply + # the per-rope-tail weight slice directly. + pe_in = tl.load(kv_base + NOPE + rd_offs[None, :], mask=m_mask[:, None]).to( + tl.float32 + ) + w_rope = tl.load(kv_weight_ptr + NOPE + rd_offs, eviction_policy="evict_last").to( + tl.float32 + ) + + kv_out_base = kv_out_ptr + m_offs[:, None] * D + ot = kv_out_ptr.dtype.element_ty + if QUANT_K: + # Same √2 bound + rstd-cancellation fast-path as Q (see Q-path + # comment). For KV with weighted norm: + # x_n_out = x * rstd * w * inv = (x * w) * (rstd / scale) + # = (x * w) * FP8_MAX_OVER_SQRT2 / abs_max_xw (rstd cancels) + # And pe_out via rope linearity: + # pe_out = (pe_in * inv_scaled * w_rope) * cos + # + rotate(pe_in * inv_scaled * w_rope) * sin + inv_scaled = (FP8_MAX_OVER_SQRT2 / tl.maximum(abs_max_xw, 1e-12))[:, None] + pe_scaled = pe_in * inv_scaled * w_rope[None, :] + pe_quant = ( + pe_scaled * cos + _gptj_rotate(pe_scaled, x_rot_mask, BLOCK_M, RD) * sin + ) + scale = abs_max_xw * rstd * INV_FP8_MAX_SQRT2 + tl.store( + kv_out_base + d_offs[None, :], + (x * inv_scaled * w[None, :]).to(ot), + mask=m_mask[:, None] & nope_d_mask[None, :], + ) + tl.store( + kv_out_base + NOPE + rd_offs[None, :], + pe_quant.to(ot), + mask=m_mask[:, None], + ) + tl.store(kv_scale_ptr + m_offs, scale, mask=m_mask) + else: + # bf16 path: materialize x_n and pe in fp32. + x_n = x * rstd[:, None] * w[None, :] + pe = pe_in * rstd[:, None] * w_rope[None, :] + pe = pe * cos + _gptj_rotate(pe, x_rot_mask, BLOCK_M, RD) * sin + tl.store( + kv_out_base + d_offs[None, :], + x_n.to(ot), + mask=m_mask[:, None] & nope_d_mask[None, :], + ) + tl.store(kv_out_base + NOPE + rd_offs[None, :], pe.to(ot), mask=m_mask[:, None]) + + +def qk_norm_rope_maybe_quant( + q: torch.Tensor, + kv: torch.Tensor, + kv_weight: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + positions: torch.Tensor, + n_local_heads: int, + head_dim: int, + rope_head_dim: int, + eps: float, + quant_q: bool = False, + quant_k: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """Fused per-token RMSNorm + GPT-J interleaved RoPE (+ optional FP8 quant). + + Args: + q: ``[T, H*D]`` bf16 — post-``wq_b`` Q (heads packed in last dim). + kv: ``[T, D]`` bf16 — post-``wkv_a`` split KV row. + kv_weight: ``[D]`` bf16 — KV-side RMSNorm weight. Q-side is weightless + (kernel hardcodes 1.0). + cos_cache, sin_cache: rope tables with ``rd/2`` columns on the inner- + most axis (``reuse_freqs_front_part=True`` layout from + ``_build_cos_sin_cache``). Higher-rank caches like + ``[max_pos, 1, 1, rd/2]`` are tolerated — only the last-dim width + and row-stride to (max_pos's) next index are read. + positions: ``[T]`` int64 — absolute token positions. + eps: RMSNorm epsilon. + quant_q, quant_k: independently emit per-row FP8 + per-row fp32 scale. + ``False`` keeps the bf16 output and returns ``None`` for that scale. + + Returns: + ``(q_out, kv_out, q_scale_or_None, k_scale_or_None)``: + - ``q_out`` shape ``[T, H, D]``, dtype = ``float8_e4m3fnuz`` if + ``quant_q`` else ``bf16``. + - ``kv_out`` shape ``[T, D]``, dtype = ``float8_e4m3fnuz`` if + ``quant_k`` else ``bf16``. + - ``q_scale`` shape ``[T, H]`` fp32 if ``quant_q`` else ``None``. + - ``k_scale`` shape ``[T]`` fp32 if ``quant_k`` else ``None``. + """ + assert ( + q.dim() == 2 and kv.dim() == 2 + ), f"q/kv must be 2-D; got q={tuple(q.shape)} kv={tuple(kv.shape)}" + T = q.shape[0] + assert ( + q.shape[1] == n_local_heads * head_dim + ), f"q last dim {q.shape[1]} != H*D = {n_local_heads * head_dim}" + assert kv.shape == ( + T, + head_dim, + ), f"kv must be [T={T}, D={head_dim}]; got {tuple(kv.shape)}" + assert ( + rope_head_dim <= head_dim and rope_head_dim % 2 == 0 + ), f"rope_head_dim must be even and ≤ head_dim; got {rope_head_dim}" + # head_dim must be a power of 2 (loaded as a single triton tile) AND + # divisible by rope_head_dim (the reshape+sum pe-extract trick requires + # the rope tail to be the last `head_dim/rope_head_dim`-th chunk). + assert ( + head_dim & (head_dim - 1) + ) == 0, f"head_dim must be a power of 2; got {head_dim}" + assert ( + head_dim % rope_head_dim == 0 + ), f"head_dim {head_dim} must be divisible by rope_head_dim {rope_head_dim}" + assert ( + q.dtype == torch.bfloat16 and kv.dtype == torch.bfloat16 + ), f"q/kv must be bf16; got q={q.dtype} kv={kv.dtype}" + assert cos_cache.shape[-1] == rope_head_dim // 2, ( + f"cos_cache last-dim {cos_cache.shape[-1]} != rope_head_dim/2 " + f"{rope_head_dim // 2}" + ) + assert sin_cache.stride(0) == cos_cache.stride(0), "sin/cos must share row stride" + # Inner-dim stride must be 1 (dense). q.stride(0) and kv.stride(0) may + # exceed H*D / D respectively when the caller passes a strided view of + # a wider tensor (e.g. `kv_pre` from `torch.split(qkv_a, ...)` whose + # row stride is `q_lora_rank + head_dim`). + assert q.stride(-1) == 1 and kv.stride(-1) == 1, ( + f"q/kv must be dense in the last dim; got q.stride={q.stride()} " + f"kv.stride={kv.stride()}" + ) + + q_out_dtype = _FP8_DTYPE if quant_q else torch.bfloat16 + kv_out_dtype = _FP8_DTYPE if quant_k else torch.bfloat16 + q_out = torch.empty( + (T, n_local_heads, head_dim), dtype=q_out_dtype, device=q.device + ) + kv_out = torch.empty((T, head_dim), dtype=kv_out_dtype, device=kv.device) + + # ------------------------------------------------------------------ + # flydsl dispatch (MVP hardcoded for V4-Pro decode shape). The combined + # Q+KV single-launch kernel wins at all T (large for small T due to + # halved launch overhead, large for big T due to better occupancy), so + # "auto" picks flydsl whenever the shape matches. + # ------------------------------------------------------------------ + if _FLYDSL_AVAILABLE: + return flydsl_qk_norm_rope_quant( + q, + kv, + kv_weight, + cos_cache, + sin_cache, + positions, + num_q_heads=n_local_heads, + head_dim=head_dim, + rope_head_dim=rope_head_dim, + quant=quant_q, + q_out=q_out, + kv_out=kv_out, + ) + + q_scale = ( + torch.empty((T, n_local_heads), dtype=torch.float32, device=q.device) + if quant_q + else None + ) + kv_scale = ( + torch.empty((T,), dtype=torch.float32, device=kv.device) if quant_k else None + ) + + # 1-element dummies so triton has concrete pointers when the QUANT_* + # constexpr branch is off (kernel won't touch them). + q_scale_arg = ( + q_scale if q_scale is not None else q.new_empty(1, dtype=torch.float32) + ) + kv_scale_arg = ( + kv_scale if kv_scale is not None else q.new_empty(1, dtype=torch.float32) + ) + + # Tuned on V4-Pro decode shape (H=16, D=512, RD=64) on MI355. After + # ditching the `tl.where + reshape + sum` pe-extract in favor of a direct + # L2-hot reload of the rope tail, BM=8 NW=8 is within 0.1us of optimal + # across the full T range (4..1024). The trailing shrink handles T 1 and block_m > T: + block_m //= 2 + + grid = (triton.cdiv(T, block_m), n_local_heads + 1) + _qk_norm_rope_maybe_quant_kernel[grid]( + q, + kv, + kv_weight, + cos_cache, + sin_cache, + positions, + q_out, + kv_out, + q_scale_arg, + kv_scale_arg, + eps=float(eps), + T=T, + q_in_row_stride=q.stride(0), + kv_in_row_stride=kv.stride(0), + cos_row_stride=cos_cache.stride(0), + H=n_local_heads, + D=head_dim, + RD=rope_head_dim, + NOPE=head_dim - rope_head_dim, + NUM_PE_CHUNKS=head_dim // rope_head_dim, + FP8_MAX=_FP8_MAX, + INV_FP8_MAX_SQRT2=_INV_FP8_MAX_SQRT2, + FP8_MAX_OVER_SQRT2=_FP8_MAX_OVER_SQRT2, + BLOCK_M=block_m, + QUANT_Q=quant_q, + QUANT_K=quant_k, + num_warps=num_warps, + waves_per_eu=1, + ) + return q_out, kv_out, q_scale, kv_scale + + +def qk_norm_rope_maybe_quant_reference( + q: torch.Tensor, + kv: torch.Tensor, + kv_weight: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + positions: torch.Tensor, + n_local_heads: int, + head_dim: int, + rope_head_dim: int, + eps: float, + quant_q: bool = False, + quant_k: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """Pure-torch reference. Matches the kernel modulo bf16 reduction-order + noise. Performs RMSNorm (Q weightless, KV weighted), then a manual GPT-J + interleaved RoPE on the tail ``rope_head_dim``, then optional per-row + amax-based e4m3 quant. + """ + T = q.shape[0] + rd = rope_head_dim + nope = head_dim - rd + + q_h = q.view(T, n_local_heads, head_dim).to(torch.float32) + kv_f = kv.to(torch.float32) + + rstd_q = torch.rsqrt(q_h.pow(2).mean(-1, keepdim=True) + eps) + q_h = q_h * rstd_q # weightless + rstd_kv = torch.rsqrt(kv_f.pow(2).mean(-1, keepdim=True) + eps) + kv_f = kv_f * rstd_kv * kv_weight.to(torch.float32) + + cos = cos_cache.index_select(0, positions).view(T, rd // 2).to(torch.float32) + sin = sin_cache.index_select(0, positions).view(T, rd // 2).to(torch.float32) + + def _rope_tail(x: torch.Tensor) -> torch.Tensor: + head_shape = x.shape[:-1] + tail = x[..., nope:].reshape(*head_shape, rd // 2, 2) + c = cos.reshape((T,) + (1,) * (tail.ndim - 3) + (rd // 2,)) + s = sin.reshape((T,) + (1,) * (tail.ndim - 3) + (rd // 2,)) + even, odd = tail[..., 0], tail[..., 1] + new_even = even * c - odd * s + new_odd = even * s + odd * c + tail_new = torch.stack([new_even, new_odd], dim=-1).reshape(*head_shape, rd) + return torch.cat([x[..., :nope], tail_new], dim=-1) + + # Compute amax for quant BEFORE applying rope: the kernel uses the + # `abs_max_x * rstd * √2` upper bound (saves a full-tile reduction). + # Reproduce that bound here so kernel and reference quantize to the same + # values bit-for-bit (modulo bf16 noise). + SQRT2 = 1.4142135623730951 + if quant_q: + # Q is weightless: x_n = x * rstd. amax bound from input. + x_q_in = q.view(T, n_local_heads, head_dim).to(torch.float32) + abs_max_x_q = x_q_in.abs().amax(dim=-1, keepdim=True) + amax_q = abs_max_x_q * rstd_q * SQRT2 + q_scale_t = (amax_q / _FP8_MAX).clamp_min(1e-12) + + if quant_k: + # KV is weighted: x_n = x * rstd * w. amax bound from |x*w|. + x_kv_in = kv.to(torch.float32) + abs_max_xw_kv = (x_kv_in.abs() * kv_weight.to(torch.float32).abs()).amax( + dim=-1, keepdim=True + ) + amax_k = abs_max_xw_kv * rstd_kv * SQRT2 + kv_scale_t = (amax_k / _FP8_MAX).clamp_min(1e-12) + + q_h = _rope_tail(q_h) + kv_f = _rope_tail(kv_f) + + if quant_q: + q_out = (q_h / q_scale_t).to(_FP8_DTYPE) + q_scale = q_scale_t.squeeze(-1).contiguous() + else: + q_out = q_h.to(torch.bfloat16) + q_scale = None + + if quant_k: + kv_out = (kv_f / kv_scale_t).to(_FP8_DTYPE) + kv_scale = kv_scale_t.squeeze(-1).contiguous() + else: + kv_out = kv_f.to(torch.bfloat16) + kv_scale = None + + return q_out, kv_out, q_scale, kv_scale diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index 18af2fc39..2a304ef6e 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -47,9 +47,6 @@ from aiter.ops.triton.fusions.fused_clamp_act_mul import ( fused_clamp_act_mul, ) -from aiter.ops.triton.fusions.fused_reduce_qk_norm_rope_swa_write import ( - fused_reduce_qk_norm_rope_swa_write, -) from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits from atom.config import ( Config, @@ -60,7 +57,7 @@ ) from atom.model_loader.loader import WeightsMapper from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding -from atom.model_ops.layernorm import DualRMSNorm, RMSNorm, rmsnorm2d_fwd_ +from atom.model_ops.layernorm import RMSNorm, rmsnorm2d_fwd_ from atom.model_ops.triton_rmsnorm_nw import rmsnorm_nw from atom.model_ops.linear import ( ColumnParallelLinear, @@ -90,6 +87,7 @@ csa_translate_pack, fused_compress_attn, inverse_rope_inplace, + qk_norm_rope_maybe_quant, scale_indexer_weights, sparse_attn_v4_paged_decode, sparse_attn_v4_paged_prefill, @@ -135,114 +133,6 @@ def _rmsnorm_nw(x: torch.Tensor, eps: float, dim: int) -> torch.Tensor: return rmsnorm2d_fwd_(x, ones, eps, dim) -def _fused_qk_norm_rope_swa_write_fake( - q: torch.Tensor, - kv: torch.Tensor, - cos_cache: torch.Tensor, - sin_cache: torch.Tensor, - positions: torch.Tensor, - n_local_heads: int, - head_dim: int, - rope_head_dim: int, - kv_weight: torch.Tensor, - eps: float, - win: int, - batch_id_per_token: Optional[torch.Tensor] = None, - state_slot_mapping: Optional[torch.Tensor] = None, - swa_kv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - M = q.shape[0] - return torch.empty( - (M, n_local_heads, head_dim), - dtype=torch.bfloat16, - device=q.device, - ) - - -@torch_compile_guard(gen_fake=_fused_qk_norm_rope_swa_write_fake) -def fused_qk_norm_rope_swa_write( - q: torch.Tensor, - kv: torch.Tensor, - cos_cache: torch.Tensor, - sin_cache: torch.Tensor, - positions: torch.Tensor, - n_local_heads: int, - head_dim: int, - rope_head_dim: int, - kv_weight: torch.Tensor, - eps: float, - win: int, - batch_id_per_token: Optional[torch.Tensor] = None, - state_slot_mapping: Optional[torch.Tensor] = None, - swa_kv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """Fused wq_b GEMM (a8w8 1x128 blockscale) + per-head RMSNorm-nw + RoPE on - q/kv tail in a single triton kernel. - - `kv` must already be kv_norm-applied; the kernel does not weight-norm kv. - `kv` is RoPE-mutated in place. SWA write is not performed here — callers - that need it must invoke the standalone `swa_write` after this call. - """ - num_tokens = q.shape[0] - if num_tokens <= 64: - q_out = torch.empty( - (num_tokens, n_local_heads, head_dim), - dtype=torch.bfloat16, - device=q.device, - ) - fused_reduce_qk_norm_rope_swa_write( - q, - kv, - None, - kv_weight, - eps, - eps, - rope_head_dim, - cos_cache, - sin_cache, - positions, - q_out=q_out, - is_neox=False, - dtype=torch.bfloat16, - write_indices=None, - batch_id_per_token=batch_id_per_token, - state_slot_mapping=state_slot_mapping, - swa_kv=swa_kv, - win=win, - ) - else: - q = q.view(num_tokens, n_local_heads, head_dim) - q_out = _rmsnorm_nw(q, eps, head_dim) - kv = rmsnorm2d_fwd_(kv, kv_weight, eps, kv.shape[-1]) - aiter.rope_cached_positions_2c_fwd_inplace( - q_out[..., -rope_head_dim:].view(1, num_tokens, -1, rope_head_dim), - kv[..., -rope_head_dim:].view(1, num_tokens, -1, rope_head_dim), - cos_cache, - sin_cache, - positions.view(1, num_tokens), - 1, - reuse_freqs_front_part=True, - nope_first=False, - ) - return q_out - - -def _make_weightless_rmsnorm(dim: int, eps: float) -> RMSNorm: - """Build an `RMSNorm(dim, eps)` whose `.weight` is `None`. - - Drops the learnable Parameter so `state_dict()` is empty for this - submodule and `load_model` doesn't expect a weight from disk (no - "unloaded parameter" warning). `DualRMSNorm` recognizes the - sentinel `weight is None` and instructs the fused kernel to skip - both the q_weight load and the multiply (Q_HAS_WEIGHT=False), - matching the prior `_rmsnorm_nw(x, eps, dim)` behavior. - """ - norm = RMSNorm(dim, eps) - del norm.weight # remove Parameter from `_parameters` - norm.weight = None # sentinel for downstream consumers - return norm - - def _hc_head_reduce_fake( x: torch.Tensor, hc_fn: torch.Tensor, @@ -1375,12 +1265,6 @@ def __init__( quant_config=qc, prefix=f"{p}.q_norm", ) - # q_norm2: per-head Q normalization. The checkpoint has no - # `q_norm2.weight` entry, so the module is constructed with - # `weight=None` (no learnable Parameter) — `DualRMSNorm` reads the - # sentinel and tells the kernel to skip the q_weight load entirely. - # Behavior is identical to the prior `_rmsnorm_nw(q, eps, head_dim)`. - self.q_norm2 = _make_weightless_rmsnorm(self.head_dim, self.eps) self.wq_b = ColumnParallelLinear( self.q_lora_rank, self.n_heads * self.head_dim, @@ -1389,19 +1273,6 @@ def __init__( prefix=f"{p}.wq_b", ) self.kv_norm = RMSNorm(self.head_dim, self.eps) - # Fused per-head Q RMSNorm (identity weight) + KV RMSNorm — both have - # feature dim = head_dim, single Triton kernel via DualRMSNorm. - # `q_norm2.weight is None` tells the kernel's Q_HAS_WEIGHT=False - # path to skip the q_weight load (Q side has no learnable weight; - # equivalent to the prior `_rmsnorm_nw` helper). - self.qk_norm = DualRMSNorm( - self.q_norm2, - self.kv_norm, - num_q_heads=self.n_local_heads, - num_kv_heads=1, - head_dim=self.head_dim, - prefix=f"{p}.qk_norm", - ) # wo_a: grouped LoRA — V4QuantConfig forces this BF16 even though disk is FP8. # The grouped einsum (`bsgd,grd->bsgr`) needs BF16 weights; aiter has no FP8 einsum. self.wo_a = ColumnParallelLinear( @@ -1496,8 +1367,6 @@ def __init__( self.alt_stream is not None and self.compressor is not None ) - self.use_fuse_qk_norm_rope_swa_write = _V4_USE_TRITON_FUSION - self.layer_name = prefix atom_config = get_current_atom_config() atom_config.compilation_config.static_forward_context[self.layer_name] = self @@ -1678,49 +1547,39 @@ def forward_impl( qr, qr_scale = self.q_norm(q_lora) q = self.wq_b(qr, x_scale=qr_scale) is_decode = attn_md.state is AttnState.DECODE - if is_decode and self.use_fuse_qk_norm_rope_swa_write: - # Fused: wq_b GEMM (a8w8 1x128 blockscale) + per-head RMSNorm-nw - # + RoPE on q tail + RoPE on kv tail (+ SWA write) in one triton - # kernel. KV RMSNorm stays out (kernel doesn't apply weighted norm - # to kv); the standalone swa_write below is gated off when this - # path runs since the kernel already wrote the window slots. - cos_cache, sin_cache = self.rotary_emb.cos_cache, self.rotary_emb.sin_cache - q = fused_qk_norm_rope_swa_write( - q, - kv_pre, - cos_cache, - sin_cache, + # Single kernel fuses per-head Q RMSNorm (weightless) + KV RMSNorm + # (weighted) + GPT-J interleaved RoPE on the tail rd dims. Dispatches + # to flydsl when the shape matches (V4-Pro is always V4-Pro shape → + # always flydsl). Microbench shows flydsl wins at every measured T + # from 4 (1.12×) to 32k (1.04×); used for both decode and prefill. + # Optional FP8 quant outputs left off — downstream sparse_attn / + # swa_write are still bf16. + q_sa, kv, q_scale, kv_scale = qk_norm_rope_maybe_quant( + q, + kv_pre, + self.kv_norm.weight, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + positions, + self.n_local_heads, + self.head_dim, + rd, + self.eps, + quant_q=False, + quant_k=False, + ) + if is_decode: + # SWA write per-token in decode (prefill writes after sparse_attn + # below so the in-chunk SWA tail is captured post-attention). + swa_write( + kv, positions, - self.n_local_heads, - self.head_dim, - rd, - self.kv_norm.weight, - self.eps, + attn_md.cu_seqlens_q, + state_slot_mapping, + self.swa_kv, cache_size, - batch_id_per_token=v4_batch_id_per_token, - state_slot_mapping=state_slot_mapping, - swa_kv=self.swa_kv, + min(attn_md.max_seqlen_q, cache_size), ) - else: - # Flat q_flat is [num_tokens, n_local_heads * head_dim]; DualRMSNorm - # views internally to per-head shape and returns flat. Single Triton - # launch fuses per-head Q RMSNorm (weightless) + KV RMSNorm (both - # head_dim=128), replacing the prior `_rmsnorm_nw + kv_norm` pair. - q_flat, kv = self.qk_norm(q, kv_pre) - q = q_flat.view(num_tokens, self.n_local_heads, self.head_dim) - # q [S, H, D] / kv [S, head_dim] — rotary_emb internally unsqueezes - # to (1, num_tokens, ...) for aiter's per-position rope kernel. - self.rotary_emb(positions, q[..., -rd:], kv[..., -rd:]) - if is_decode: - swa_write( - kv, - positions, - attn_md.cu_seqlens_q, - state_slot_mapping, - self.swa_kv, - cache_size, - min(attn_md.max_seqlen_q, cache_size), - ) if _V4_USE_REF_QUANT: act_quant_inplace(kv[..., :-rd], 64, self.scale_fmt) @@ -1749,13 +1608,11 @@ def forward_impl( ) # ===== Sparse attention dispatch ===== - # Decode SWA write fires upstream of this dispatch — either inside - # `fused_qk_norm_rope_swa_write` (fused path above) or via the - # explicit `swa_write` call in the non-fused `else` branch — so - # `paged_decode` always sees the current token's K in the ring. - # Prefill does NOT call swa_write from this layer (prior-chunk K is - # read from `unified_kv` ring via the kv_indices_prefix_swa region). - q_sa = q.contiguous() + # Decode SWA write fires upstream of this dispatch via the + # ``swa_write`` call in the decode branch — so ``paged_decode`` + # always sees the current token's K in the ring. Prefill does NOT + # call swa_write from this layer (prior-chunk K is read from + # ``unified_kv`` ring via the kv_indices_prefix_swa region). if is_decode: if ratio == 0: kv_indices = attn_md.kv_indices_swa diff --git a/scripts/wait_server_ready.sh b/scripts/wait_server_ready.sh index 1f489a919..5e233716d 100755 --- a/scripts/wait_server_ready.sh +++ b/scripts/wait_server_ready.sh @@ -18,11 +18,22 @@ POLL="${3:-30}" LOG_FILE="${4:-/app/logs_claude/atom_server.log}" ITERS=$(( MAX_MIN * 60 / POLL )) +# Snapshot log size at start — only scan content APPENDED after this point. +# Prevents false-positive matches against errors left by a prior failed launch. +LOG_START_BYTES=0 +if [ -f "$LOG_FILE" ]; then + LOG_START_BYTES=$(stat -c %s "$LOG_FILE" 2>/dev/null || echo 0) +fi + for ((i=1; i<=ITERS; i++)); do sleep "$POLL" READY=$(curl -s -m 3 "http://localhost:${PORT}/v1/models" 2>/dev/null | head -c 60) - ERR=$(grep -c "cluster_dims\|InductorError\|SHUTDOWN signal\|proc died" \ - "$LOG_FILE" 2>/dev/null | head -1) + if [ -f "$LOG_FILE" ]; then + ERR=$(tail -c "+$((LOG_START_BYTES + 1))" "$LOG_FILE" 2>/dev/null \ + | grep -c "cluster_dims\|InductorError\|SHUTDOWN signal\|proc died") + else + ERR=0 + fi ERR="${ERR:-0}" echo "[t=$((i*POLL))s] ready=${READY:-(empty)} err=$ERR" if [ -n "$READY" ]; then @@ -31,11 +42,11 @@ for ((i=1; i<=ITERS; i++)); do fi if [ "$ERR" -gt 0 ]; then echo "Server FAILED to start (errors detected)" - tail -30 "$LOG_FILE" + tail -c "+$((LOG_START_BYTES + 1))" "$LOG_FILE" | tail -30 exit 1 fi done echo "Server NOT ready after ${MAX_MIN} min" -tail -30 "$LOG_FILE" +tail -c "+$((LOG_START_BYTES + 1))" "$LOG_FILE" | tail -30 exit 1 From 35afaaecea963d242e42ed7f1446fe404c5e16ac Mon Sep 17 00:00:00 2001 From: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Date: Sun, 24 May 2026 15:05:53 +0000 Subject: [PATCH 2/2] [v4-pro] fix ruff: drop unused envs import + rope_d_mask local --- atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py b/atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py index 70ff6cd28..e5875688b 100644 --- a/atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py +++ b/atom/model_ops/v4_kernels/qk_norm_rope_maybe_quant.py @@ -35,8 +35,6 @@ import triton import triton.language as tl -from atom.utils import envs - # Lazy-imported flydsl path (optional dependency). Set to None when flydsl # is unavailable; the dispatch in ``qk_norm_rope_maybe_quant`` will fall # back to the Triton kernel. @@ -129,7 +127,6 @@ def _qk_norm_rope_maybe_quant_kernel( d_offs = tl.arange(0, D) nope_d_mask = d_offs < NOPE - rope_d_mask = d_offs >= NOPE rd_offs = tl.arange(0, RD).to(tl.int64) cos_d_offs = rd_offs // 2 # GPT-J + REUSE_FREQS_FRONT_PART: lane duplicate