From 97ac3d8d6d6d001121876319e67273689336e68e Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 19 May 2026 21:32:43 +0000 Subject: [PATCH 1/2] integration --- atom/model_ops/layernorm.py | 23 +++++ atom/model_ops/linear.py | 171 ++++++++++++++++++++++++++++++++++-- atom/models/deepseek_v4.py | 79 +++++++++++++++-- 3 files changed, 259 insertions(+), 14 deletions(-) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 14898b200..138cd3e2f 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import os from typing import Optional, Tuple import aiter @@ -318,6 +319,28 @@ def forward( and x_scale is None and self.quant_type.value in _AITER_RMS_QUANT_TYPE_VALUES ): + # MXFP8 path: when downstream Linear consumes MXFP8 1x32 e8m0 + # scales (ATOM_FP8_BLOCKSCALE_USE_MXFP8=1), emit those directly + # to skip the dequant+requant cascade. Limited to per_1x128 and + # the no-residual case (most q_norm callers). + if ( + self.quant_type.value == _QV_PER_1X128 + and residual is None + and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + ): + from aiter.ops.triton.quant.quant_mxfp8 import ( + rmsnorm_mxfp8_quant, + ) + + if x.dim() != 2: + x2 = x.reshape(-1, x.shape[-1]) + else: + x2 = x + y, s = rmsnorm_mxfp8_quant(x2, self.weight, self.eps) + if x.dim() != 2: + y = y.view(*x.shape[:-1], x.shape[-1]) + s = s.view(*x.shape[:-1], s.shape[-1]) + return y, s # Dynamic-scale fused RMSNorm + quant via aiter HIP kernels. # Static FP8 (x_scale provided) stays on the branch above. x, x_scale, residual_out = _aiter_rms_quant( diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 5505b5942..5b1fcc1ac 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -2,6 +2,7 @@ # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. import logging +import os from functools import partial as functools_partial from typing import Callable, Optional @@ -286,11 +287,16 @@ def __init__( torch.empty(self.output_size, 1, dtype=dtypes.fp32) ) elif quant_type == QuantType.per_1x128: + # When MXFP8 a8w8 is active, allocate weight_scale directly as + # float8_e8m0fnu (1 byte) to match the on-disk dtype - avoids + # the lossy fp32 upcast/round-trip at load. + _mxfp8_env = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + _scale_dtype = torch.float8_e8m0fnu if _mxfp8_env else dtypes.fp32 self.weight_scale = atom_parameter( torch.empty( (self.output_size + 127) // 128, (self.input_size + 127) // 128, - dtype=dtypes.fp32, + dtype=_scale_dtype, ) ) elif quant_type == QuantType.per_1x32: @@ -510,9 +516,22 @@ def process_weights_after_loading(self): self.quant_type == QuantType.per_Token and self.params_dtype == dtypes.fp8 ) or self.quant_type == QuantType.per_1x32 - # per_1x128 only needs shuffle when using the preshuffle GEMM path + # per_1x128 only needs shuffle when using a preshuffle GEMM path. + # MXFP8 path: only shuffle if gemm_mxfp8_preshuffle will consume the + # shuffled weight (gated by ATOM_MXFP8_USE_PRESHUFFLE, default on). + # The non-preshuffle gemm_mxfp8 reads scrambled rows otherwise. + _mxfp8_active = ( + self.quant_type == QuantType.per_1x128 + and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + ) + _mxfp8_preshuffle = _mxfp8_active and ( + os.environ.get("ATOM_MXFP8_USE_PRESHUFFLE", "1") == "1" + ) if not need_shuffle and self.quant_type == QuantType.per_1x128: - need_shuffle = envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE + if _mxfp8_active: + need_shuffle = _mxfp8_preshuffle + else: + need_shuffle = envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE if need_shuffle: if self.weight.dim() == 2: shuffle_weights(self.weight) @@ -521,6 +540,30 @@ def process_weights_after_loading(self): if self.quant_type == QuantType.per_1x32: self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data) + # ---- Optional: MXFP8 a8w8 path (dormant, only active when env set) ---- + # On disk V4-Flash stores per_1x128 weight scales as float8_e8m0fnu + # (uint8 powers-of-2) and weights as float8_e4m3fn. With MXFP8 active, + # __init__ allocated weight_scale as float8_e8m0fnu so the loader copied + # bits without any upcast/round-trip. Just expose as uint8 view here + # (gemm_mxfp8 expects uint8) and convert fp8 weight to fn-bits if the + # loader landed it as fnuz. + self._mxfp8_active = ( + self.quant_type == QuantType.per_1x128 + and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + ) + if self._mxfp8_active: + if self.weight_scale.data.dtype != torch.uint8: + self.weight_scale = atom_parameter( + self.weight_scale.data.view(torch.uint8).contiguous() + ) + if self.weight.data.dtype == torch.float8_e4m3fnuz: + # On-disk format is float8_e4m3fn but the param was allocated + # as fnuz (AMD default). The bit patterns differ - convert via + # value to preserve semantics. + w_val = self.weight.data.float() + w_fn = w_val.clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + self.weight = atom_parameter(w_fn) + @mark_trace def forward( self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16 @@ -533,7 +576,31 @@ def forward( otype=otype, ) else: - if x_scale is None: + # MXFP8 path: route per_1x128 weights through MXFP8 GEMM. + _use_mxfp8 = self.quant_type.value == QuantType.per_1x128.value and getattr( + self, "_mxfp8_active", False + ) + if _use_mxfp8: + from aiter.ops.triton.quant.quant_mxfp8 import ( + per_1x32_mxfp8_quant_triton, + ) + + if x_scale is None: + x, x_scale = per_1x32_mxfp8_quant_triton(x) + elif x_scale.dtype == torch.uint8: + pass # caller already emitted MXFP8 1x32 + else: + # legacy caller emitted FP8 + fp32 1x128 scale. Dequant to + # bf16, then re-quantize to MXFP8 1x32. + Mx, Kx = x.shape + sM, sCols = x_scale.shape + group = Kx // sCols + x_dq = x.to(torch.float32).view(Mx, sCols, group) * x_scale.to( + torch.float32 + ).view(sM, sCols, 1) + x_bf16 = x_dq.view(Mx, Kx).to(torch.bfloat16) + x, x_scale = per_1x32_mxfp8_quant_triton(x_bf16) + elif x_scale is None: quant_func = self.quant_func if self.quant_type.value == QuantType.per_1x128.value: # preshuffle GEMM expects column-major x_scale; @@ -578,7 +645,101 @@ def forward( if self.bias is not None: y += self.bias elif self.quant_type.value == QuantType.per_1x128.value: - if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: + if _use_mxfp8: + # One-shot dump: overwrite each call so the final saved file + # is the most recent call (curl-driven, post-warmup). + # _dump_env = os.environ.get("ATOM_MXFP8_DUMP", "") + # if _dump_env: + # if _dump_env in getattr(self, "prefix", ""): + # try: + # import torch as _t + # _t.save( + # { + # "prefix": self.prefix, + # "x": x.detach().cpu(), + # "weight": self.weight.data.detach().cpu(), + # "x_scale": x_scale.detach().cpu() + # if x_scale is not None + # else None, + # "weight_scale": self.weight_scale.data.detach().cpu(), + # "otype": str(otype), + # }, + # f"/tmp/mxfp8_dump_{self.prefix.replace('.', '_')}.pt", + # ) + # print( + # f"[MXFP8_DUMP] saved /tmp/mxfp8_dump_{self.prefix.replace('.', '_')}.pt " + # f"x={tuple(x.shape)}/{x.dtype} " + # f"w={tuple(self.weight.shape)}/{self.weight.dtype} " + # f"xs={tuple(x_scale.shape) if x_scale is not None else None}/{x_scale.dtype if x_scale is not None else None} " + # f"ws={tuple(self.weight_scale.shape)}/{self.weight_scale.dtype}", + # flush=True, + # ) + # self._mxfp8_dumped = True + # except Exception as e: + # print(f"[MXFP8_DUMP_FAIL] {e}", flush=True) + # try: + if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: + from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( + gemm_afp8wfp8_preshuffle, + ) + + y = gemm_afp8wfp8_preshuffle( + x, + self.weight, + x_scale, + self.weight_scale, + dtype=otype, + ) + else: + from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( + gemm_afp8wfp8, + ) + + y = gemm_afp8wfp8( + x, + self.weight, + x_scale, + self.weight_scale, + dtype=otype, + ) + # except Exception as e: + # print( + # f"[MXFP8_FAIL] x={tuple(x.shape)}/{x.dtype} " + # f"w={tuple(self.weight.shape)}/{self.weight.dtype} " + # f"xs={tuple(x_scale.shape)}/{x_scale.dtype} " + # f"ws={tuple(self.weight_scale.shape)}/{self.weight_scale.dtype} " + # f"err={type(e).__name__}: {str(e)[:200]}", + # flush=True, + # ) + # raise + elif envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: + # If MXFP8 loader was active but kernel was skipped, the + # weight_scale is uint8 (e8m0). Legacy preshuffle GEMM wants + # fp32 - decode on the fly. + # _ws = self.weight_scale + # if _ws.dtype == torch.uint8: + # _ws = fp4_utils.e8m0_to_f32(_ws.view(torch.float8_e8m0fnu)) + # Legacy dump hook for A/B comparison vs MXFP8 path. + # _dump_env = os.environ.get("ATOM_MXFP8_DUMP", "") + # if _dump_env and _dump_env in getattr(self, "prefix", ""): + # try: + # import torch as _t + # _t.save( + # { + # "prefix": self.prefix, + # "x": x.detach().cpu(), + # "weight": self.weight.data.detach().cpu(), + # "x_scale": x_scale.detach().cpu() + # if x_scale is not None + # else None, + # "weight_scale": self.weight_scale.data.detach().cpu(), + # "otype": str(otype), + # "path": "legacy", + # }, + # f"/tmp/legacy_dump_{self.prefix.replace('.', '_')}.pt", + # ) + # except Exception as e: + # print(f"[LEGACY_DUMP_FAIL] {e}", flush=True) y = gemm_a8w8_blockscale_preshuffle_impl( x, self.weight, diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index 7bb238941..0a1552a2a 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -126,6 +126,46 @@ # Fused-kernel switches. Default off; flip via env to A/B against the eager path. _V4_USE_TRITON_FUSION = os.environ.get("ATOM_V4_USE_TRITON_FUSION", "0") == "1" ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION +# MXFP8 a8w8 GEMM path (Task #77). When on, q_norm RMSNorm emits FP8 e4m3fn + +# uint8 e8m0 1x32 scales directly (via the Triton rmsnorm_mxfp8_quant path in +# atom/model_ops/layernorm.py) so wq_b's MXFP8 GEMM consumes them with no +# transcode. The helper below is reserved for call sites that need to fuse +# the Q-side MXFP8 quant with the K-side bf16 RMSNorm in a single launch. +_V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + + +def _fuse_rmsnorm_mxfp8_quant( + q_lora: torch.Tensor, + q_norm_weight: torch.Tensor, + q_norm_eps: float, + kv_pre: torch.Tensor, + kv_norm_weight: torch.Tensor, + kv_norm_eps: float, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Dual RMSNorm with MXFP8 emit on Q side, bf16 on K side, single launch. + + Q: RMSNorm + per-1x32 MXFP8 quant -- emits FP8 e4m3fn + uint8 e8m0 scale + directly so the downstream wq_b MXFP8 GEMM consumes them with no + dequant+requant cascade. + K: RMSNorm only, bf16 output, consumed by the downstream RoPE / SWA-write + fused kernel which expects bf16 K. + + Both halves run in a single Triton launch (dual_rmsnorm_mxfp8_quant) to + avoid the +6us/layer per-launch overhead of two separate kernels. + + Returns (qr_fp8, qr_scale_e8m0, kv_bf16). + """ + from aiter.ops.triton.quant.quant_mxfp8 import dual_rmsnorm_mxfp8_quant + + qr, qr_scale, kv = dual_rmsnorm_mxfp8_quant( + q_lora, + kv_pre, + q_norm_weight, + kv_norm_weight, + q_norm_eps, + kv_norm_eps, + ) + return qr, qr_scale, kv def _rmsnorm_nw(x: torch.Tensor, eps: float, dim: int) -> torch.Tensor: @@ -1972,7 +2012,11 @@ def __init__( # Switch: route clamp + silu(gate)*up [+ weights] + per-token FP8 1x128 # quant through a single aiter triton kernel. The fused kernel emits # FP8 + scale; w2 accepts `x_scale` and skips its own quant step. - self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION + # ATOM_V4_DISABLE_FCAM=1 forces the unfused eager path for A/B vs the + # Triton fused-clamp-act-mul kernel. + self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION and ( + os.environ.get("ATOM_V4_DISABLE_FCAM", "0") != "1" + ) def forward( self, @@ -1987,14 +2031,31 @@ def forward( # feed the bf16 GEMM output directly. combined = self.gate_up_proj(x) # [num_tokens, 2*inter_dim_per_tp] if self.use_fused_clamp_act_mul: - x_fp8, x_scale = fused_clamp_act_mul( - combined, - swiglu_limit=self.swiglu_limit, - activation="silu", - weights=weights, - dtype_quant=dtypes.fp8, - transpose_scale=True, - ) + # When MXFP8 is enabled, ATOM_MXFP8_BYPASS_FCAM=1 forces the + # legacy fp32 1x128 emit + linear.py's dequant+requant fallback, + # isolating whether fused_clamp_act_mul's MXFP8 emit has a bug. + _mxfp8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + _bypass_fcam = os.environ.get("ATOM_MXFP8_BYPASS_FCAM", "0") == "1" + if _mxfp8 and not _bypass_fcam: + x_fp8, x_scale = fused_clamp_act_mul( + combined, + swiglu_limit=self.swiglu_limit, + activation="silu", + weights=weights, + dtype_quant=torch.float8_e4m3fn, + quant_block_size=32, + scale_dtype_fmt="ue8m0", + transpose_scale=False, + ) + else: + x_fp8, x_scale = fused_clamp_act_mul( + combined, + swiglu_limit=self.swiglu_limit, + activation="silu", + weights=weights, + dtype_quant=dtypes.fp8, + transpose_scale=True, + ) return self.w2(x_fp8, x_scale=x_scale) out = torch.empty( (combined.shape[0], combined.shape[-1] // 2), From 995ec0e5f9d6e761e6c0216316beb246afbee2d9 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Wed, 20 May 2026 17:11:54 +0000 Subject: [PATCH 2/2] clean up --- atom/model_ops/layernorm.py | 9 ++--- atom/model_ops/linear.py | 81 ++----------------------------------- atom/models/deepseek_v4.py | 10 ++--- 3 files changed, 12 insertions(+), 88 deletions(-) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 138cd3e2f..cde67fcd2 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -18,6 +18,7 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.gated_rmsnorm_fp8_group_quant import gated_rmsnorm_fp8_group_quant from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad +from aiter.ops.triton.quant.quant_mxfp8 import rmsnorm_mxfp8_quant from atom.config import QuantizationConfig from atom.model_ops.utils import atom_parameter from atom.quant_spec import LayerQuantConfig @@ -26,6 +27,8 @@ from torch import Tensor, nn from torch.overrides import handle_torch_function, has_torch_function_unary +_V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + def silu(input: Tensor, inplace: bool = False) -> Tensor: r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise. @@ -326,12 +329,8 @@ def forward( if ( self.quant_type.value == _QV_PER_1X128 and residual is None - and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + and _V4_USE_MXFP8 ): - from aiter.ops.triton.quant.quant_mxfp8 import ( - rmsnorm_mxfp8_quant, - ) - if x.dim() != 2: x2 = x.reshape(-1, x.shape[-1]) else: diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 5b1fcc1ac..6bee6e4c7 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -23,6 +23,10 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.tuned_gemm import tgemm from aiter.utility import fp4_utils +from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( + gemm_afp8wfp8, + gemm_afp8wfp8_preshuffle, +) from atom.config import QuantizationConfig, get_current_atom_config from atom.quant_spec import LayerQuantConfig from atom.model_ops.utils import ( @@ -646,43 +650,7 @@ def forward( y += self.bias elif self.quant_type.value == QuantType.per_1x128.value: if _use_mxfp8: - # One-shot dump: overwrite each call so the final saved file - # is the most recent call (curl-driven, post-warmup). - # _dump_env = os.environ.get("ATOM_MXFP8_DUMP", "") - # if _dump_env: - # if _dump_env in getattr(self, "prefix", ""): - # try: - # import torch as _t - # _t.save( - # { - # "prefix": self.prefix, - # "x": x.detach().cpu(), - # "weight": self.weight.data.detach().cpu(), - # "x_scale": x_scale.detach().cpu() - # if x_scale is not None - # else None, - # "weight_scale": self.weight_scale.data.detach().cpu(), - # "otype": str(otype), - # }, - # f"/tmp/mxfp8_dump_{self.prefix.replace('.', '_')}.pt", - # ) - # print( - # f"[MXFP8_DUMP] saved /tmp/mxfp8_dump_{self.prefix.replace('.', '_')}.pt " - # f"x={tuple(x.shape)}/{x.dtype} " - # f"w={tuple(self.weight.shape)}/{self.weight.dtype} " - # f"xs={tuple(x_scale.shape) if x_scale is not None else None}/{x_scale.dtype if x_scale is not None else None} " - # f"ws={tuple(self.weight_scale.shape)}/{self.weight_scale.dtype}", - # flush=True, - # ) - # self._mxfp8_dumped = True - # except Exception as e: - # print(f"[MXFP8_DUMP_FAIL] {e}", flush=True) - # try: if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: - from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( - gemm_afp8wfp8_preshuffle, - ) - y = gemm_afp8wfp8_preshuffle( x, self.weight, @@ -691,10 +659,6 @@ def forward( dtype=otype, ) else: - from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( - gemm_afp8wfp8, - ) - y = gemm_afp8wfp8( x, self.weight, @@ -702,44 +666,7 @@ def forward( self.weight_scale, dtype=otype, ) - # except Exception as e: - # print( - # f"[MXFP8_FAIL] x={tuple(x.shape)}/{x.dtype} " - # f"w={tuple(self.weight.shape)}/{self.weight.dtype} " - # f"xs={tuple(x_scale.shape)}/{x_scale.dtype} " - # f"ws={tuple(self.weight_scale.shape)}/{self.weight_scale.dtype} " - # f"err={type(e).__name__}: {str(e)[:200]}", - # flush=True, - # ) - # raise elif envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: - # If MXFP8 loader was active but kernel was skipped, the - # weight_scale is uint8 (e8m0). Legacy preshuffle GEMM wants - # fp32 - decode on the fly. - # _ws = self.weight_scale - # if _ws.dtype == torch.uint8: - # _ws = fp4_utils.e8m0_to_f32(_ws.view(torch.float8_e8m0fnu)) - # Legacy dump hook for A/B comparison vs MXFP8 path. - # _dump_env = os.environ.get("ATOM_MXFP8_DUMP", "") - # if _dump_env and _dump_env in getattr(self, "prefix", ""): - # try: - # import torch as _t - # _t.save( - # { - # "prefix": self.prefix, - # "x": x.detach().cpu(), - # "weight": self.weight.data.detach().cpu(), - # "x_scale": x_scale.detach().cpu() - # if x_scale is not None - # else None, - # "weight_scale": self.weight_scale.data.detach().cpu(), - # "otype": str(otype), - # "path": "legacy", - # }, - # f"/tmp/legacy_dump_{self.prefix.replace('.', '_')}.pt", - # ) - # except Exception as e: - # print(f"[LEGACY_DUMP_FAIL] {e}", flush=True) y = gemm_a8w8_blockscale_preshuffle_impl( x, self.weight, diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index 0a1552a2a..b284e98a8 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -132,6 +132,8 @@ # transcode. The helper below is reserved for call sites that need to fuse # the Q-side MXFP8 quant with the K-side bf16 RMSNorm in a single launch. _V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" +_V4_DISABLE_FCAM = os.environ.get("ATOM_V4_DISABLE_FCAM", "0") == "1" +_MXFP8_BYPASS_FCAM = os.environ.get("ATOM_MXFP8_BYPASS_FCAM", "0") == "1" def _fuse_rmsnorm_mxfp8_quant( @@ -2014,9 +2016,7 @@ def __init__( # FP8 + scale; w2 accepts `x_scale` and skips its own quant step. # ATOM_V4_DISABLE_FCAM=1 forces the unfused eager path for A/B vs the # Triton fused-clamp-act-mul kernel. - self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION and ( - os.environ.get("ATOM_V4_DISABLE_FCAM", "0") != "1" - ) + self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION and not _V4_DISABLE_FCAM def forward( self, @@ -2034,9 +2034,7 @@ def forward( # When MXFP8 is enabled, ATOM_MXFP8_BYPASS_FCAM=1 forces the # legacy fp32 1x128 emit + linear.py's dequant+requant fallback, # isolating whether fused_clamp_act_mul's MXFP8 emit has a bug. - _mxfp8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" - _bypass_fcam = os.environ.get("ATOM_MXFP8_BYPASS_FCAM", "0") == "1" - if _mxfp8 and not _bypass_fcam: + if _V4_USE_MXFP8 and not _MXFP8_BYPASS_FCAM: x_fp8, x_scale = fused_clamp_act_mul( combined, swiglu_limit=self.swiglu_limit,