Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions atom/model_ops/layernorm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import os
from typing import Optional, Tuple

import aiter
Expand All @@ -17,6 +18,7 @@
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.ops.gated_rmsnorm_fp8_group_quant import gated_rmsnorm_fp8_group_quant
from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
from aiter.ops.triton.quant.quant_mxfp8 import rmsnorm_mxfp8_quant
from atom.config import QuantizationConfig
from atom.model_ops.utils import atom_parameter
from atom.quant_spec import LayerQuantConfig
Expand All @@ -25,6 +27,8 @@
from torch import Tensor, nn
from torch.overrides import handle_torch_function, has_torch_function_unary

_V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1"


def silu(input: Tensor, inplace: bool = False) -> Tensor:
r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise.
Expand Down Expand Up @@ -318,6 +322,24 @@ def forward(
and x_scale is None
and self.quant_type.value in _AITER_RMS_QUANT_TYPE_VALUES
):
# MXFP8 path: when downstream Linear consumes MXFP8 1x32 e8m0
# scales (ATOM_FP8_BLOCKSCALE_USE_MXFP8=1), emit those directly
# to skip the dequant+requant cascade. Limited to per_1x128 and
# the no-residual case (most q_norm callers).
if (
self.quant_type.value == _QV_PER_1X128
and residual is None
and _V4_USE_MXFP8
):
if x.dim() != 2:
x2 = x.reshape(-1, x.shape[-1])
else:
x2 = x
y, s = rmsnorm_mxfp8_quant(x2, self.weight, self.eps)
if x.dim() != 2:
y = y.view(*x.shape[:-1], x.shape[-1])
s = s.view(*x.shape[:-1], s.shape[-1])
return y, s
# Dynamic-scale fused RMSNorm + quant via aiter HIP kernels.
# Static FP8 (x_scale provided) stays on the branch above.
x, x_scale, residual_out = _aiter_rms_quant(
Expand Down
98 changes: 93 additions & 5 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import logging
import os
from functools import partial as functools_partial
from typing import Callable, Optional

Expand All @@ -22,6 +23,10 @@
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.tuned_gemm import tgemm
from aiter.utility import fp4_utils
from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import (
gemm_afp8wfp8,
gemm_afp8wfp8_preshuffle,
)
from atom.config import QuantizationConfig, get_current_atom_config
from atom.quant_spec import LayerQuantConfig
from atom.model_ops.utils import (
Expand Down Expand Up @@ -286,11 +291,16 @@ def __init__(
torch.empty(self.output_size, 1, dtype=dtypes.fp32)
)
elif quant_type == QuantType.per_1x128:
# When MXFP8 a8w8 is active, allocate weight_scale directly as
# float8_e8m0fnu (1 byte) to match the on-disk dtype - avoids
# the lossy fp32 upcast/round-trip at load.
_mxfp8_env = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1"
_scale_dtype = torch.float8_e8m0fnu if _mxfp8_env else dtypes.fp32
self.weight_scale = atom_parameter(
torch.empty(
(self.output_size + 127) // 128,
(self.input_size + 127) // 128,
dtype=dtypes.fp32,
dtype=_scale_dtype,
)
)
elif quant_type == QuantType.per_1x32:
Expand Down Expand Up @@ -510,9 +520,22 @@ def process_weights_after_loading(self):
self.quant_type == QuantType.per_Token
and self.params_dtype == dtypes.fp8
) or self.quant_type == QuantType.per_1x32
# per_1x128 only needs shuffle when using the preshuffle GEMM path
# per_1x128 only needs shuffle when using a preshuffle GEMM path.
# MXFP8 path: only shuffle if gemm_mxfp8_preshuffle will consume the
# shuffled weight (gated by ATOM_MXFP8_USE_PRESHUFFLE, default on).
# The non-preshuffle gemm_mxfp8 reads scrambled rows otherwise.
_mxfp8_active = (
self.quant_type == QuantType.per_1x128
and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1"
)
_mxfp8_preshuffle = _mxfp8_active and (
os.environ.get("ATOM_MXFP8_USE_PRESHUFFLE", "1") == "1"
)
if not need_shuffle and self.quant_type == QuantType.per_1x128:
need_shuffle = envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE
if _mxfp8_active:
need_shuffle = _mxfp8_preshuffle
else:
need_shuffle = envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE
if need_shuffle:
if self.weight.dim() == 2:
shuffle_weights(self.weight)
Expand All @@ -521,6 +544,30 @@ def process_weights_after_loading(self):
if self.quant_type == QuantType.per_1x32:
self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data)

# ---- Optional: MXFP8 a8w8 path (dormant, only active when env set) ----
# On disk V4-Flash stores per_1x128 weight scales as float8_e8m0fnu
# (uint8 powers-of-2) and weights as float8_e4m3fn. With MXFP8 active,
# __init__ allocated weight_scale as float8_e8m0fnu so the loader copied
# bits without any upcast/round-trip. Just expose as uint8 view here
# (gemm_mxfp8 expects uint8) and convert fp8 weight to fn-bits if the
# loader landed it as fnuz.
self._mxfp8_active = (
self.quant_type == QuantType.per_1x128
and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1"
)
if self._mxfp8_active:
if self.weight_scale.data.dtype != torch.uint8:
self.weight_scale = atom_parameter(
self.weight_scale.data.view(torch.uint8).contiguous()
)
if self.weight.data.dtype == torch.float8_e4m3fnuz:
# On-disk format is float8_e4m3fn but the param was allocated
# as fnuz (AMD default). The bit patterns differ - convert via
# value to preserve semantics.
w_val = self.weight.data.float()
w_fn = w_val.clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
self.weight = atom_parameter(w_fn)

@mark_trace
def forward(
self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16
Expand All @@ -533,7 +580,31 @@ def forward(
otype=otype,
)
else:
if x_scale is None:
# MXFP8 path: route per_1x128 weights through MXFP8 GEMM.
_use_mxfp8 = self.quant_type.value == QuantType.per_1x128.value and getattr(
self, "_mxfp8_active", False
)
if _use_mxfp8:
from aiter.ops.triton.quant.quant_mxfp8 import (
per_1x32_mxfp8_quant_triton,
)

if x_scale is None:
x, x_scale = per_1x32_mxfp8_quant_triton(x)
elif x_scale.dtype == torch.uint8:
pass # caller already emitted MXFP8 1x32
else:
# legacy caller emitted FP8 + fp32 1x128 scale. Dequant to
# bf16, then re-quantize to MXFP8 1x32.
Mx, Kx = x.shape
sM, sCols = x_scale.shape
group = Kx // sCols
x_dq = x.to(torch.float32).view(Mx, sCols, group) * x_scale.to(
torch.float32
).view(sM, sCols, 1)
x_bf16 = x_dq.view(Mx, Kx).to(torch.bfloat16)
x, x_scale = per_1x32_mxfp8_quant_triton(x_bf16)
elif x_scale is None:
quant_func = self.quant_func
if self.quant_type.value == QuantType.per_1x128.value:
# preshuffle GEMM expects column-major x_scale;
Expand Down Expand Up @@ -578,7 +649,24 @@ def forward(
if self.bias is not None:
y += self.bias
elif self.quant_type.value == QuantType.per_1x128.value:
if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE:
if _use_mxfp8:
if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE:
y = gemm_afp8wfp8_preshuffle(
x,
self.weight,
x_scale,
self.weight_scale,
dtype=otype,
)
else:
y = gemm_afp8wfp8(
x,
self.weight,
x_scale,
self.weight_scale,
dtype=otype,
)
elif envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE:
y = gemm_a8w8_blockscale_preshuffle_impl(
x,
self.weight,
Expand Down
77 changes: 68 additions & 9 deletions atom/models/deepseek_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,48 @@
# Fused-kernel switches. Default off; flip via env to A/B against the eager path.
_V4_USE_TRITON_FUSION = os.environ.get("ATOM_V4_USE_TRITON_FUSION", "0") == "1"
ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION
# MXFP8 a8w8 GEMM path (Task #77). When on, q_norm RMSNorm emits FP8 e4m3fn +
# uint8 e8m0 1x32 scales directly (via the Triton rmsnorm_mxfp8_quant path in
# atom/model_ops/layernorm.py) so wq_b's MXFP8 GEMM consumes them with no
# transcode. The helper below is reserved for call sites that need to fuse
# the Q-side MXFP8 quant with the K-side bf16 RMSNorm in a single launch.
_V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1"
_V4_DISABLE_FCAM = os.environ.get("ATOM_V4_DISABLE_FCAM", "0") == "1"
_MXFP8_BYPASS_FCAM = os.environ.get("ATOM_MXFP8_BYPASS_FCAM", "0") == "1"


def _fuse_rmsnorm_mxfp8_quant(
q_lora: torch.Tensor,
q_norm_weight: torch.Tensor,
q_norm_eps: float,
kv_pre: torch.Tensor,
kv_norm_weight: torch.Tensor,
kv_norm_eps: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Dual RMSNorm with MXFP8 emit on Q side, bf16 on K side, single launch.

Q: RMSNorm + per-1x32 MXFP8 quant -- emits FP8 e4m3fn + uint8 e8m0 scale
directly so the downstream wq_b MXFP8 GEMM consumes them with no
dequant+requant cascade.
K: RMSNorm only, bf16 output, consumed by the downstream RoPE / SWA-write
fused kernel which expects bf16 K.

Both halves run in a single Triton launch (dual_rmsnorm_mxfp8_quant) to
avoid the +6us/layer per-launch overhead of two separate kernels.

Returns (qr_fp8, qr_scale_e8m0, kv_bf16).
"""
from aiter.ops.triton.quant.quant_mxfp8 import dual_rmsnorm_mxfp8_quant

qr, qr_scale, kv = dual_rmsnorm_mxfp8_quant(
q_lora,
kv_pre,
q_norm_weight,
kv_norm_weight,
q_norm_eps,
kv_norm_eps,
)
return qr, qr_scale, kv


def _rmsnorm_nw(x: torch.Tensor, eps: float, dim: int) -> torch.Tensor:
Expand Down Expand Up @@ -1972,7 +2014,9 @@ def __init__(
# Switch: route clamp + silu(gate)*up [+ weights] + per-token FP8 1x128
# quant through a single aiter triton kernel. The fused kernel emits
# FP8 + scale; w2 accepts `x_scale` and skips its own quant step.
self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION
# ATOM_V4_DISABLE_FCAM=1 forces the unfused eager path for A/B vs the
# Triton fused-clamp-act-mul kernel.
self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION and not _V4_DISABLE_FCAM

def forward(
self,
Expand All @@ -1987,14 +2031,29 @@ def forward(
# feed the bf16 GEMM output directly.
combined = self.gate_up_proj(x) # [num_tokens, 2*inter_dim_per_tp]
if self.use_fused_clamp_act_mul:
x_fp8, x_scale = fused_clamp_act_mul(
combined,
swiglu_limit=self.swiglu_limit,
activation="silu",
weights=weights,
dtype_quant=dtypes.fp8,
transpose_scale=True,
)
# When MXFP8 is enabled, ATOM_MXFP8_BYPASS_FCAM=1 forces the
# legacy fp32 1x128 emit + linear.py's dequant+requant fallback,
# isolating whether fused_clamp_act_mul's MXFP8 emit has a bug.
if _V4_USE_MXFP8 and not _MXFP8_BYPASS_FCAM:
x_fp8, x_scale = fused_clamp_act_mul(
combined,
swiglu_limit=self.swiglu_limit,
activation="silu",
weights=weights,
dtype_quant=torch.float8_e4m3fn,
quant_block_size=32,
scale_dtype_fmt="ue8m0",
transpose_scale=False,
)
else:
x_fp8, x_scale = fused_clamp_act_mul(
combined,
swiglu_limit=self.swiglu_limit,
activation="silu",
weights=weights,
dtype_quant=dtypes.fp8,
transpose_scale=True,
)
return self.w2(x_fp8, x_scale=x_scale)
out = torch.empty(
(combined.shape[0], combined.shape[-1] // 2),
Expand Down
Loading