diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 14898b200..cde67fcd2 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import os from typing import Optional, Tuple import aiter @@ -17,6 +18,7 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.gated_rmsnorm_fp8_group_quant import gated_rmsnorm_fp8_group_quant from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad +from aiter.ops.triton.quant.quant_mxfp8 import rmsnorm_mxfp8_quant from atom.config import QuantizationConfig from atom.model_ops.utils import atom_parameter from atom.quant_spec import LayerQuantConfig @@ -25,6 +27,8 @@ from torch import Tensor, nn from torch.overrides import handle_torch_function, has_torch_function_unary +_V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + def silu(input: Tensor, inplace: bool = False) -> Tensor: r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise. @@ -318,6 +322,24 @@ def forward( and x_scale is None and self.quant_type.value in _AITER_RMS_QUANT_TYPE_VALUES ): + # MXFP8 path: when downstream Linear consumes MXFP8 1x32 e8m0 + # scales (ATOM_FP8_BLOCKSCALE_USE_MXFP8=1), emit those directly + # to skip the dequant+requant cascade. Limited to per_1x128 and + # the no-residual case (most q_norm callers). + if ( + self.quant_type.value == _QV_PER_1X128 + and residual is None + and _V4_USE_MXFP8 + ): + if x.dim() != 2: + x2 = x.reshape(-1, x.shape[-1]) + else: + x2 = x + y, s = rmsnorm_mxfp8_quant(x2, self.weight, self.eps) + if x.dim() != 2: + y = y.view(*x.shape[:-1], x.shape[-1]) + s = s.view(*x.shape[:-1], s.shape[-1]) + return y, s # Dynamic-scale fused RMSNorm + quant via aiter HIP kernels. # Static FP8 (x_scale provided) stays on the branch above. x, x_scale, residual_out = _aiter_rms_quant( diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 5505b5942..6bee6e4c7 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -2,6 +2,7 @@ # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. import logging +import os from functools import partial as functools_partial from typing import Callable, Optional @@ -22,6 +23,10 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.tuned_gemm import tgemm from aiter.utility import fp4_utils +from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( + gemm_afp8wfp8, + gemm_afp8wfp8_preshuffle, +) from atom.config import QuantizationConfig, get_current_atom_config from atom.quant_spec import LayerQuantConfig from atom.model_ops.utils import ( @@ -286,11 +291,16 @@ def __init__( torch.empty(self.output_size, 1, dtype=dtypes.fp32) ) elif quant_type == QuantType.per_1x128: + # When MXFP8 a8w8 is active, allocate weight_scale directly as + # float8_e8m0fnu (1 byte) to match the on-disk dtype - avoids + # the lossy fp32 upcast/round-trip at load. + _mxfp8_env = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + _scale_dtype = torch.float8_e8m0fnu if _mxfp8_env else dtypes.fp32 self.weight_scale = atom_parameter( torch.empty( (self.output_size + 127) // 128, (self.input_size + 127) // 128, - dtype=dtypes.fp32, + dtype=_scale_dtype, ) ) elif quant_type == QuantType.per_1x32: @@ -510,9 +520,22 @@ def process_weights_after_loading(self): self.quant_type == QuantType.per_Token and self.params_dtype == dtypes.fp8 ) or self.quant_type == QuantType.per_1x32 - # per_1x128 only needs shuffle when using the preshuffle GEMM path + # per_1x128 only needs shuffle when using a preshuffle GEMM path. + # MXFP8 path: only shuffle if gemm_mxfp8_preshuffle will consume the + # shuffled weight (gated by ATOM_MXFP8_USE_PRESHUFFLE, default on). + # The non-preshuffle gemm_mxfp8 reads scrambled rows otherwise. + _mxfp8_active = ( + self.quant_type == QuantType.per_1x128 + and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + ) + _mxfp8_preshuffle = _mxfp8_active and ( + os.environ.get("ATOM_MXFP8_USE_PRESHUFFLE", "1") == "1" + ) if not need_shuffle and self.quant_type == QuantType.per_1x128: - need_shuffle = envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE + if _mxfp8_active: + need_shuffle = _mxfp8_preshuffle + else: + need_shuffle = envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE if need_shuffle: if self.weight.dim() == 2: shuffle_weights(self.weight) @@ -521,6 +544,30 @@ def process_weights_after_loading(self): if self.quant_type == QuantType.per_1x32: self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data) + # ---- Optional: MXFP8 a8w8 path (dormant, only active when env set) ---- + # On disk V4-Flash stores per_1x128 weight scales as float8_e8m0fnu + # (uint8 powers-of-2) and weights as float8_e4m3fn. With MXFP8 active, + # __init__ allocated weight_scale as float8_e8m0fnu so the loader copied + # bits without any upcast/round-trip. Just expose as uint8 view here + # (gemm_mxfp8 expects uint8) and convert fp8 weight to fn-bits if the + # loader landed it as fnuz. + self._mxfp8_active = ( + self.quant_type == QuantType.per_1x128 + and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + ) + if self._mxfp8_active: + if self.weight_scale.data.dtype != torch.uint8: + self.weight_scale = atom_parameter( + self.weight_scale.data.view(torch.uint8).contiguous() + ) + if self.weight.data.dtype == torch.float8_e4m3fnuz: + # On-disk format is float8_e4m3fn but the param was allocated + # as fnuz (AMD default). The bit patterns differ - convert via + # value to preserve semantics. + w_val = self.weight.data.float() + w_fn = w_val.clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + self.weight = atom_parameter(w_fn) + @mark_trace def forward( self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16 @@ -533,7 +580,31 @@ def forward( otype=otype, ) else: - if x_scale is None: + # MXFP8 path: route per_1x128 weights through MXFP8 GEMM. + _use_mxfp8 = self.quant_type.value == QuantType.per_1x128.value and getattr( + self, "_mxfp8_active", False + ) + if _use_mxfp8: + from aiter.ops.triton.quant.quant_mxfp8 import ( + per_1x32_mxfp8_quant_triton, + ) + + if x_scale is None: + x, x_scale = per_1x32_mxfp8_quant_triton(x) + elif x_scale.dtype == torch.uint8: + pass # caller already emitted MXFP8 1x32 + else: + # legacy caller emitted FP8 + fp32 1x128 scale. Dequant to + # bf16, then re-quantize to MXFP8 1x32. + Mx, Kx = x.shape + sM, sCols = x_scale.shape + group = Kx // sCols + x_dq = x.to(torch.float32).view(Mx, sCols, group) * x_scale.to( + torch.float32 + ).view(sM, sCols, 1) + x_bf16 = x_dq.view(Mx, Kx).to(torch.bfloat16) + x, x_scale = per_1x32_mxfp8_quant_triton(x_bf16) + elif x_scale is None: quant_func = self.quant_func if self.quant_type.value == QuantType.per_1x128.value: # preshuffle GEMM expects column-major x_scale; @@ -578,7 +649,24 @@ def forward( if self.bias is not None: y += self.bias elif self.quant_type.value == QuantType.per_1x128.value: - if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: + if _use_mxfp8: + if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: + y = gemm_afp8wfp8_preshuffle( + x, + self.weight, + x_scale, + self.weight_scale, + dtype=otype, + ) + else: + y = gemm_afp8wfp8( + x, + self.weight, + x_scale, + self.weight_scale, + dtype=otype, + ) + elif envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: y = gemm_a8w8_blockscale_preshuffle_impl( x, self.weight, diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index 7bb238941..b284e98a8 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -126,6 +126,48 @@ # Fused-kernel switches. Default off; flip via env to A/B against the eager path. _V4_USE_TRITON_FUSION = os.environ.get("ATOM_V4_USE_TRITON_FUSION", "0") == "1" ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION +# MXFP8 a8w8 GEMM path (Task #77). When on, q_norm RMSNorm emits FP8 e4m3fn + +# uint8 e8m0 1x32 scales directly (via the Triton rmsnorm_mxfp8_quant path in +# atom/model_ops/layernorm.py) so wq_b's MXFP8 GEMM consumes them with no +# transcode. The helper below is reserved for call sites that need to fuse +# the Q-side MXFP8 quant with the K-side bf16 RMSNorm in a single launch. +_V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" +_V4_DISABLE_FCAM = os.environ.get("ATOM_V4_DISABLE_FCAM", "0") == "1" +_MXFP8_BYPASS_FCAM = os.environ.get("ATOM_MXFP8_BYPASS_FCAM", "0") == "1" + + +def _fuse_rmsnorm_mxfp8_quant( + q_lora: torch.Tensor, + q_norm_weight: torch.Tensor, + q_norm_eps: float, + kv_pre: torch.Tensor, + kv_norm_weight: torch.Tensor, + kv_norm_eps: float, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Dual RMSNorm with MXFP8 emit on Q side, bf16 on K side, single launch. + + Q: RMSNorm + per-1x32 MXFP8 quant -- emits FP8 e4m3fn + uint8 e8m0 scale + directly so the downstream wq_b MXFP8 GEMM consumes them with no + dequant+requant cascade. + K: RMSNorm only, bf16 output, consumed by the downstream RoPE / SWA-write + fused kernel which expects bf16 K. + + Both halves run in a single Triton launch (dual_rmsnorm_mxfp8_quant) to + avoid the +6us/layer per-launch overhead of two separate kernels. + + Returns (qr_fp8, qr_scale_e8m0, kv_bf16). + """ + from aiter.ops.triton.quant.quant_mxfp8 import dual_rmsnorm_mxfp8_quant + + qr, qr_scale, kv = dual_rmsnorm_mxfp8_quant( + q_lora, + kv_pre, + q_norm_weight, + kv_norm_weight, + q_norm_eps, + kv_norm_eps, + ) + return qr, qr_scale, kv def _rmsnorm_nw(x: torch.Tensor, eps: float, dim: int) -> torch.Tensor: @@ -1972,7 +2014,9 @@ def __init__( # Switch: route clamp + silu(gate)*up [+ weights] + per-token FP8 1x128 # quant through a single aiter triton kernel. The fused kernel emits # FP8 + scale; w2 accepts `x_scale` and skips its own quant step. - self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION + # ATOM_V4_DISABLE_FCAM=1 forces the unfused eager path for A/B vs the + # Triton fused-clamp-act-mul kernel. + self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION and not _V4_DISABLE_FCAM def forward( self, @@ -1987,14 +2031,29 @@ def forward( # feed the bf16 GEMM output directly. combined = self.gate_up_proj(x) # [num_tokens, 2*inter_dim_per_tp] if self.use_fused_clamp_act_mul: - x_fp8, x_scale = fused_clamp_act_mul( - combined, - swiglu_limit=self.swiglu_limit, - activation="silu", - weights=weights, - dtype_quant=dtypes.fp8, - transpose_scale=True, - ) + # When MXFP8 is enabled, ATOM_MXFP8_BYPASS_FCAM=1 forces the + # legacy fp32 1x128 emit + linear.py's dequant+requant fallback, + # isolating whether fused_clamp_act_mul's MXFP8 emit has a bug. + if _V4_USE_MXFP8 and not _MXFP8_BYPASS_FCAM: + x_fp8, x_scale = fused_clamp_act_mul( + combined, + swiglu_limit=self.swiglu_limit, + activation="silu", + weights=weights, + dtype_quant=torch.float8_e4m3fn, + quant_block_size=32, + scale_dtype_fmt="ue8m0", + transpose_scale=False, + ) + else: + x_fp8, x_scale = fused_clamp_act_mul( + combined, + swiglu_limit=self.swiglu_limit, + activation="silu", + weights=weights, + dtype_quant=dtypes.fp8, + transpose_scale=True, + ) return self.w2(x_fp8, x_scale=x_scale) out = torch.empty( (combined.shape[0], combined.shape[-1] // 2),