From 97ac3d8d6d6d001121876319e67273689336e68e Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 19 May 2026 21:32:43 +0000 Subject: [PATCH 1/9] integration --- atom/model_ops/layernorm.py | 23 +++++ atom/model_ops/linear.py | 171 ++++++++++++++++++++++++++++++++++-- atom/models/deepseek_v4.py | 79 +++++++++++++++-- 3 files changed, 259 insertions(+), 14 deletions(-) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 14898b200..138cd3e2f 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import os from typing import Optional, Tuple import aiter @@ -318,6 +319,28 @@ def forward( and x_scale is None and self.quant_type.value in _AITER_RMS_QUANT_TYPE_VALUES ): + # MXFP8 path: when downstream Linear consumes MXFP8 1x32 e8m0 + # scales (ATOM_FP8_BLOCKSCALE_USE_MXFP8=1), emit those directly + # to skip the dequant+requant cascade. Limited to per_1x128 and + # the no-residual case (most q_norm callers). + if ( + self.quant_type.value == _QV_PER_1X128 + and residual is None + and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + ): + from aiter.ops.triton.quant.quant_mxfp8 import ( + rmsnorm_mxfp8_quant, + ) + + if x.dim() != 2: + x2 = x.reshape(-1, x.shape[-1]) + else: + x2 = x + y, s = rmsnorm_mxfp8_quant(x2, self.weight, self.eps) + if x.dim() != 2: + y = y.view(*x.shape[:-1], x.shape[-1]) + s = s.view(*x.shape[:-1], s.shape[-1]) + return y, s # Dynamic-scale fused RMSNorm + quant via aiter HIP kernels. # Static FP8 (x_scale provided) stays on the branch above. x, x_scale, residual_out = _aiter_rms_quant( diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 5505b5942..5b1fcc1ac 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -2,6 +2,7 @@ # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. import logging +import os from functools import partial as functools_partial from typing import Callable, Optional @@ -286,11 +287,16 @@ def __init__( torch.empty(self.output_size, 1, dtype=dtypes.fp32) ) elif quant_type == QuantType.per_1x128: + # When MXFP8 a8w8 is active, allocate weight_scale directly as + # float8_e8m0fnu (1 byte) to match the on-disk dtype - avoids + # the lossy fp32 upcast/round-trip at load. + _mxfp8_env = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + _scale_dtype = torch.float8_e8m0fnu if _mxfp8_env else dtypes.fp32 self.weight_scale = atom_parameter( torch.empty( (self.output_size + 127) // 128, (self.input_size + 127) // 128, - dtype=dtypes.fp32, + dtype=_scale_dtype, ) ) elif quant_type == QuantType.per_1x32: @@ -510,9 +516,22 @@ def process_weights_after_loading(self): self.quant_type == QuantType.per_Token and self.params_dtype == dtypes.fp8 ) or self.quant_type == QuantType.per_1x32 - # per_1x128 only needs shuffle when using the preshuffle GEMM path + # per_1x128 only needs shuffle when using a preshuffle GEMM path. + # MXFP8 path: only shuffle if gemm_mxfp8_preshuffle will consume the + # shuffled weight (gated by ATOM_MXFP8_USE_PRESHUFFLE, default on). + # The non-preshuffle gemm_mxfp8 reads scrambled rows otherwise. + _mxfp8_active = ( + self.quant_type == QuantType.per_1x128 + and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + ) + _mxfp8_preshuffle = _mxfp8_active and ( + os.environ.get("ATOM_MXFP8_USE_PRESHUFFLE", "1") == "1" + ) if not need_shuffle and self.quant_type == QuantType.per_1x128: - need_shuffle = envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE + if _mxfp8_active: + need_shuffle = _mxfp8_preshuffle + else: + need_shuffle = envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE if need_shuffle: if self.weight.dim() == 2: shuffle_weights(self.weight) @@ -521,6 +540,30 @@ def process_weights_after_loading(self): if self.quant_type == QuantType.per_1x32: self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data) + # ---- Optional: MXFP8 a8w8 path (dormant, only active when env set) ---- + # On disk V4-Flash stores per_1x128 weight scales as float8_e8m0fnu + # (uint8 powers-of-2) and weights as float8_e4m3fn. With MXFP8 active, + # __init__ allocated weight_scale as float8_e8m0fnu so the loader copied + # bits without any upcast/round-trip. Just expose as uint8 view here + # (gemm_mxfp8 expects uint8) and convert fp8 weight to fn-bits if the + # loader landed it as fnuz. + self._mxfp8_active = ( + self.quant_type == QuantType.per_1x128 + and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + ) + if self._mxfp8_active: + if self.weight_scale.data.dtype != torch.uint8: + self.weight_scale = atom_parameter( + self.weight_scale.data.view(torch.uint8).contiguous() + ) + if self.weight.data.dtype == torch.float8_e4m3fnuz: + # On-disk format is float8_e4m3fn but the param was allocated + # as fnuz (AMD default). The bit patterns differ - convert via + # value to preserve semantics. + w_val = self.weight.data.float() + w_fn = w_val.clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + self.weight = atom_parameter(w_fn) + @mark_trace def forward( self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16 @@ -533,7 +576,31 @@ def forward( otype=otype, ) else: - if x_scale is None: + # MXFP8 path: route per_1x128 weights through MXFP8 GEMM. + _use_mxfp8 = self.quant_type.value == QuantType.per_1x128.value and getattr( + self, "_mxfp8_active", False + ) + if _use_mxfp8: + from aiter.ops.triton.quant.quant_mxfp8 import ( + per_1x32_mxfp8_quant_triton, + ) + + if x_scale is None: + x, x_scale = per_1x32_mxfp8_quant_triton(x) + elif x_scale.dtype == torch.uint8: + pass # caller already emitted MXFP8 1x32 + else: + # legacy caller emitted FP8 + fp32 1x128 scale. Dequant to + # bf16, then re-quantize to MXFP8 1x32. + Mx, Kx = x.shape + sM, sCols = x_scale.shape + group = Kx // sCols + x_dq = x.to(torch.float32).view(Mx, sCols, group) * x_scale.to( + torch.float32 + ).view(sM, sCols, 1) + x_bf16 = x_dq.view(Mx, Kx).to(torch.bfloat16) + x, x_scale = per_1x32_mxfp8_quant_triton(x_bf16) + elif x_scale is None: quant_func = self.quant_func if self.quant_type.value == QuantType.per_1x128.value: # preshuffle GEMM expects column-major x_scale; @@ -578,7 +645,101 @@ def forward( if self.bias is not None: y += self.bias elif self.quant_type.value == QuantType.per_1x128.value: - if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: + if _use_mxfp8: + # One-shot dump: overwrite each call so the final saved file + # is the most recent call (curl-driven, post-warmup). + # _dump_env = os.environ.get("ATOM_MXFP8_DUMP", "") + # if _dump_env: + # if _dump_env in getattr(self, "prefix", ""): + # try: + # import torch as _t + # _t.save( + # { + # "prefix": self.prefix, + # "x": x.detach().cpu(), + # "weight": self.weight.data.detach().cpu(), + # "x_scale": x_scale.detach().cpu() + # if x_scale is not None + # else None, + # "weight_scale": self.weight_scale.data.detach().cpu(), + # "otype": str(otype), + # }, + # f"/tmp/mxfp8_dump_{self.prefix.replace('.', '_')}.pt", + # ) + # print( + # f"[MXFP8_DUMP] saved /tmp/mxfp8_dump_{self.prefix.replace('.', '_')}.pt " + # f"x={tuple(x.shape)}/{x.dtype} " + # f"w={tuple(self.weight.shape)}/{self.weight.dtype} " + # f"xs={tuple(x_scale.shape) if x_scale is not None else None}/{x_scale.dtype if x_scale is not None else None} " + # f"ws={tuple(self.weight_scale.shape)}/{self.weight_scale.dtype}", + # flush=True, + # ) + # self._mxfp8_dumped = True + # except Exception as e: + # print(f"[MXFP8_DUMP_FAIL] {e}", flush=True) + # try: + if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: + from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( + gemm_afp8wfp8_preshuffle, + ) + + y = gemm_afp8wfp8_preshuffle( + x, + self.weight, + x_scale, + self.weight_scale, + dtype=otype, + ) + else: + from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( + gemm_afp8wfp8, + ) + + y = gemm_afp8wfp8( + x, + self.weight, + x_scale, + self.weight_scale, + dtype=otype, + ) + # except Exception as e: + # print( + # f"[MXFP8_FAIL] x={tuple(x.shape)}/{x.dtype} " + # f"w={tuple(self.weight.shape)}/{self.weight.dtype} " + # f"xs={tuple(x_scale.shape)}/{x_scale.dtype} " + # f"ws={tuple(self.weight_scale.shape)}/{self.weight_scale.dtype} " + # f"err={type(e).__name__}: {str(e)[:200]}", + # flush=True, + # ) + # raise + elif envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: + # If MXFP8 loader was active but kernel was skipped, the + # weight_scale is uint8 (e8m0). Legacy preshuffle GEMM wants + # fp32 - decode on the fly. + # _ws = self.weight_scale + # if _ws.dtype == torch.uint8: + # _ws = fp4_utils.e8m0_to_f32(_ws.view(torch.float8_e8m0fnu)) + # Legacy dump hook for A/B comparison vs MXFP8 path. + # _dump_env = os.environ.get("ATOM_MXFP8_DUMP", "") + # if _dump_env and _dump_env in getattr(self, "prefix", ""): + # try: + # import torch as _t + # _t.save( + # { + # "prefix": self.prefix, + # "x": x.detach().cpu(), + # "weight": self.weight.data.detach().cpu(), + # "x_scale": x_scale.detach().cpu() + # if x_scale is not None + # else None, + # "weight_scale": self.weight_scale.data.detach().cpu(), + # "otype": str(otype), + # "path": "legacy", + # }, + # f"/tmp/legacy_dump_{self.prefix.replace('.', '_')}.pt", + # ) + # except Exception as e: + # print(f"[LEGACY_DUMP_FAIL] {e}", flush=True) y = gemm_a8w8_blockscale_preshuffle_impl( x, self.weight, diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index 7bb238941..0a1552a2a 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -126,6 +126,46 @@ # Fused-kernel switches. Default off; flip via env to A/B against the eager path. _V4_USE_TRITON_FUSION = os.environ.get("ATOM_V4_USE_TRITON_FUSION", "0") == "1" ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION +# MXFP8 a8w8 GEMM path (Task #77). When on, q_norm RMSNorm emits FP8 e4m3fn + +# uint8 e8m0 1x32 scales directly (via the Triton rmsnorm_mxfp8_quant path in +# atom/model_ops/layernorm.py) so wq_b's MXFP8 GEMM consumes them with no +# transcode. The helper below is reserved for call sites that need to fuse +# the Q-side MXFP8 quant with the K-side bf16 RMSNorm in a single launch. +_V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + + +def _fuse_rmsnorm_mxfp8_quant( + q_lora: torch.Tensor, + q_norm_weight: torch.Tensor, + q_norm_eps: float, + kv_pre: torch.Tensor, + kv_norm_weight: torch.Tensor, + kv_norm_eps: float, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Dual RMSNorm with MXFP8 emit on Q side, bf16 on K side, single launch. + + Q: RMSNorm + per-1x32 MXFP8 quant -- emits FP8 e4m3fn + uint8 e8m0 scale + directly so the downstream wq_b MXFP8 GEMM consumes them with no + dequant+requant cascade. + K: RMSNorm only, bf16 output, consumed by the downstream RoPE / SWA-write + fused kernel which expects bf16 K. + + Both halves run in a single Triton launch (dual_rmsnorm_mxfp8_quant) to + avoid the +6us/layer per-launch overhead of two separate kernels. + + Returns (qr_fp8, qr_scale_e8m0, kv_bf16). + """ + from aiter.ops.triton.quant.quant_mxfp8 import dual_rmsnorm_mxfp8_quant + + qr, qr_scale, kv = dual_rmsnorm_mxfp8_quant( + q_lora, + kv_pre, + q_norm_weight, + kv_norm_weight, + q_norm_eps, + kv_norm_eps, + ) + return qr, qr_scale, kv def _rmsnorm_nw(x: torch.Tensor, eps: float, dim: int) -> torch.Tensor: @@ -1972,7 +2012,11 @@ def __init__( # Switch: route clamp + silu(gate)*up [+ weights] + per-token FP8 1x128 # quant through a single aiter triton kernel. The fused kernel emits # FP8 + scale; w2 accepts `x_scale` and skips its own quant step. - self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION + # ATOM_V4_DISABLE_FCAM=1 forces the unfused eager path for A/B vs the + # Triton fused-clamp-act-mul kernel. + self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION and ( + os.environ.get("ATOM_V4_DISABLE_FCAM", "0") != "1" + ) def forward( self, @@ -1987,14 +2031,31 @@ def forward( # feed the bf16 GEMM output directly. combined = self.gate_up_proj(x) # [num_tokens, 2*inter_dim_per_tp] if self.use_fused_clamp_act_mul: - x_fp8, x_scale = fused_clamp_act_mul( - combined, - swiglu_limit=self.swiglu_limit, - activation="silu", - weights=weights, - dtype_quant=dtypes.fp8, - transpose_scale=True, - ) + # When MXFP8 is enabled, ATOM_MXFP8_BYPASS_FCAM=1 forces the + # legacy fp32 1x128 emit + linear.py's dequant+requant fallback, + # isolating whether fused_clamp_act_mul's MXFP8 emit has a bug. + _mxfp8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + _bypass_fcam = os.environ.get("ATOM_MXFP8_BYPASS_FCAM", "0") == "1" + if _mxfp8 and not _bypass_fcam: + x_fp8, x_scale = fused_clamp_act_mul( + combined, + swiglu_limit=self.swiglu_limit, + activation="silu", + weights=weights, + dtype_quant=torch.float8_e4m3fn, + quant_block_size=32, + scale_dtype_fmt="ue8m0", + transpose_scale=False, + ) + else: + x_fp8, x_scale = fused_clamp_act_mul( + combined, + swiglu_limit=self.swiglu_limit, + activation="silu", + weights=weights, + dtype_quant=dtypes.fp8, + transpose_scale=True, + ) return self.w2(x_fp8, x_scale=x_scale) out = torch.empty( (combined.shape[0], combined.shape[-1] // 2), From 995ec0e5f9d6e761e6c0216316beb246afbee2d9 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Wed, 20 May 2026 17:11:54 +0000 Subject: [PATCH 2/9] clean up --- atom/model_ops/layernorm.py | 9 ++--- atom/model_ops/linear.py | 81 ++----------------------------------- atom/models/deepseek_v4.py | 10 ++--- 3 files changed, 12 insertions(+), 88 deletions(-) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 138cd3e2f..cde67fcd2 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -18,6 +18,7 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.gated_rmsnorm_fp8_group_quant import gated_rmsnorm_fp8_group_quant from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad +from aiter.ops.triton.quant.quant_mxfp8 import rmsnorm_mxfp8_quant from atom.config import QuantizationConfig from atom.model_ops.utils import atom_parameter from atom.quant_spec import LayerQuantConfig @@ -26,6 +27,8 @@ from torch import Tensor, nn from torch.overrides import handle_torch_function, has_torch_function_unary +_V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + def silu(input: Tensor, inplace: bool = False) -> Tensor: r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise. @@ -326,12 +329,8 @@ def forward( if ( self.quant_type.value == _QV_PER_1X128 and residual is None - and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + and _V4_USE_MXFP8 ): - from aiter.ops.triton.quant.quant_mxfp8 import ( - rmsnorm_mxfp8_quant, - ) - if x.dim() != 2: x2 = x.reshape(-1, x.shape[-1]) else: diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 5b1fcc1ac..6bee6e4c7 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -23,6 +23,10 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.tuned_gemm import tgemm from aiter.utility import fp4_utils +from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( + gemm_afp8wfp8, + gemm_afp8wfp8_preshuffle, +) from atom.config import QuantizationConfig, get_current_atom_config from atom.quant_spec import LayerQuantConfig from atom.model_ops.utils import ( @@ -646,43 +650,7 @@ def forward( y += self.bias elif self.quant_type.value == QuantType.per_1x128.value: if _use_mxfp8: - # One-shot dump: overwrite each call so the final saved file - # is the most recent call (curl-driven, post-warmup). - # _dump_env = os.environ.get("ATOM_MXFP8_DUMP", "") - # if _dump_env: - # if _dump_env in getattr(self, "prefix", ""): - # try: - # import torch as _t - # _t.save( - # { - # "prefix": self.prefix, - # "x": x.detach().cpu(), - # "weight": self.weight.data.detach().cpu(), - # "x_scale": x_scale.detach().cpu() - # if x_scale is not None - # else None, - # "weight_scale": self.weight_scale.data.detach().cpu(), - # "otype": str(otype), - # }, - # f"/tmp/mxfp8_dump_{self.prefix.replace('.', '_')}.pt", - # ) - # print( - # f"[MXFP8_DUMP] saved /tmp/mxfp8_dump_{self.prefix.replace('.', '_')}.pt " - # f"x={tuple(x.shape)}/{x.dtype} " - # f"w={tuple(self.weight.shape)}/{self.weight.dtype} " - # f"xs={tuple(x_scale.shape) if x_scale is not None else None}/{x_scale.dtype if x_scale is not None else None} " - # f"ws={tuple(self.weight_scale.shape)}/{self.weight_scale.dtype}", - # flush=True, - # ) - # self._mxfp8_dumped = True - # except Exception as e: - # print(f"[MXFP8_DUMP_FAIL] {e}", flush=True) - # try: if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: - from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( - gemm_afp8wfp8_preshuffle, - ) - y = gemm_afp8wfp8_preshuffle( x, self.weight, @@ -691,10 +659,6 @@ def forward( dtype=otype, ) else: - from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( - gemm_afp8wfp8, - ) - y = gemm_afp8wfp8( x, self.weight, @@ -702,44 +666,7 @@ def forward( self.weight_scale, dtype=otype, ) - # except Exception as e: - # print( - # f"[MXFP8_FAIL] x={tuple(x.shape)}/{x.dtype} " - # f"w={tuple(self.weight.shape)}/{self.weight.dtype} " - # f"xs={tuple(x_scale.shape)}/{x_scale.dtype} " - # f"ws={tuple(self.weight_scale.shape)}/{self.weight_scale.dtype} " - # f"err={type(e).__name__}: {str(e)[:200]}", - # flush=True, - # ) - # raise elif envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: - # If MXFP8 loader was active but kernel was skipped, the - # weight_scale is uint8 (e8m0). Legacy preshuffle GEMM wants - # fp32 - decode on the fly. - # _ws = self.weight_scale - # if _ws.dtype == torch.uint8: - # _ws = fp4_utils.e8m0_to_f32(_ws.view(torch.float8_e8m0fnu)) - # Legacy dump hook for A/B comparison vs MXFP8 path. - # _dump_env = os.environ.get("ATOM_MXFP8_DUMP", "") - # if _dump_env and _dump_env in getattr(self, "prefix", ""): - # try: - # import torch as _t - # _t.save( - # { - # "prefix": self.prefix, - # "x": x.detach().cpu(), - # "weight": self.weight.data.detach().cpu(), - # "x_scale": x_scale.detach().cpu() - # if x_scale is not None - # else None, - # "weight_scale": self.weight_scale.data.detach().cpu(), - # "otype": str(otype), - # "path": "legacy", - # }, - # f"/tmp/legacy_dump_{self.prefix.replace('.', '_')}.pt", - # ) - # except Exception as e: - # print(f"[LEGACY_DUMP_FAIL] {e}", flush=True) y = gemm_a8w8_blockscale_preshuffle_impl( x, self.weight, diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index 0a1552a2a..b284e98a8 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -132,6 +132,8 @@ # transcode. The helper below is reserved for call sites that need to fuse # the Q-side MXFP8 quant with the K-side bf16 RMSNorm in a single launch. _V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" +_V4_DISABLE_FCAM = os.environ.get("ATOM_V4_DISABLE_FCAM", "0") == "1" +_MXFP8_BYPASS_FCAM = os.environ.get("ATOM_MXFP8_BYPASS_FCAM", "0") == "1" def _fuse_rmsnorm_mxfp8_quant( @@ -2014,9 +2016,7 @@ def __init__( # FP8 + scale; w2 accepts `x_scale` and skips its own quant step. # ATOM_V4_DISABLE_FCAM=1 forces the unfused eager path for A/B vs the # Triton fused-clamp-act-mul kernel. - self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION and ( - os.environ.get("ATOM_V4_DISABLE_FCAM", "0") != "1" - ) + self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION and not _V4_DISABLE_FCAM def forward( self, @@ -2034,9 +2034,7 @@ def forward( # When MXFP8 is enabled, ATOM_MXFP8_BYPASS_FCAM=1 forces the # legacy fp32 1x128 emit + linear.py's dequant+requant fallback, # isolating whether fused_clamp_act_mul's MXFP8 emit has a bug. - _mxfp8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" - _bypass_fcam = os.environ.get("ATOM_MXFP8_BYPASS_FCAM", "0") == "1" - if _mxfp8 and not _bypass_fcam: + if _V4_USE_MXFP8 and not _MXFP8_BYPASS_FCAM: x_fp8, x_scale = fused_clamp_act_mul( combined, swiglu_limit=self.swiglu_limit, From 623fd026878c6d304a65e101b7356e66518c3c72 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Wed, 20 May 2026 16:07:37 +0000 Subject: [PATCH 3/9] add a8w4 triton moe --- atom/model_ops/fused_moe_triton.py | 215 ++++++++++++++++++++++++++++- atom/model_ops/moe.py | 74 ++++++++++ atom/models/deepseek_v4.py | 26 +++- 3 files changed, 312 insertions(+), 3 deletions(-) diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py index 8bb1be4c3..24648a0b8 100644 --- a/atom/model_ops/fused_moe_triton.py +++ b/atom/model_ops/fused_moe_triton.py @@ -218,6 +218,191 @@ def routing_from_topk(topk_weights, topk_ids, n_expts_tot): return routing_data, gather_indx, scatter_indx +def bridge_tk_to_aiter_routing(tk_rd, block_m: int): + """Bridge triton_kernels RoutingData -> aiter RoutingData (for moe_gemm_a8w4). + + triton_kernels carries per-block-m specializations of token_offs_pad / + block_pid_map as dicts keyed by block_m; the aiter variant is specialized + for a single block_m and stores those fields as plain tensors. + """ + from aiter.ops.triton.moe.moe_routing.routing import RoutingData, ExptData + + tk_ed = tk_rd.expt_data + if block_m not in tk_ed.token_offs_pad: + raise ValueError( + f"bridge_tk_to_aiter_routing: block_m={block_m} not present in " + f"tk_rd.expt_data.token_offs_pad. Keys: {sorted(tk_ed.token_offs_pad.keys())}" + ) + aiter_expt_data = ExptData( + hist=tk_ed.hist, + token_offs_raw=tk_ed.token_offs_raw, + token_offs_pad=tk_ed.token_offs_pad[block_m], + block_pid_map=tk_ed.block_pid_map[block_m], + ) + return RoutingData( + block_m=block_m, + gate_scal=tk_rd.gate_scal, + expt_hist=tk_rd.expt_hist, + n_expts_tot=tk_rd.n_expts_tot, + n_expts_act=tk_rd.n_expts_act, + expt_data=aiter_expt_data, + ) + + +def _a8w4_fused_experts( + output_tensor: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, # a8w4 layout: uint8 [E, K/2, 2N], stride(-2)==1 + w2: torch.Tensor, # a8w4 layout: uint8 [E, N/2, K], stride(-2)==1 + w1_scale: torch.Tensor, # swizzled ue8m0 uint8 + w2_scale: torch.Tensor, # swizzled ue8m0 uint8 + routing_data, + gather_indx, + scatter_indx, + topk: int, + swiglu_limit: float, + w1_bias: torch.Tensor | None, + w2_bias: torch.Tensor | None, + apply_router_weight_on_input: bool, + pre_built_aiter_routing=None, # if set, skip bridge and use this aiter routing directly + # Step 9: shared-expert output to fold into GEMM2 reduce_grouped writeback. + residual: torch.Tensor | None = None, +) -> torch.Tensor: + """Step 1 minimum a8w4 fused experts. + + Pipeline: downcast_to_mxfp -> moe_gemm_a8w4 (G1, no swiglu) -> + fused_clamp_act_mul (silu_mul, no quant) -> downcast_to_mxfp -> + moe_gemm_a8w4 (G2). No apply_swiglu fold, no out_mx_quant fold, + no residual fold — those are later steps. + """ + from aiter.ops.triton.moe.moe_op_gemm_a8w4 import moe_gemm_a8w4, recommend_block_m + from aiter.ops.triton.moe.quant_moe import downcast_to_mxfp + + M, K = hidden_states.shape[-2:] + BLOCK_M = recommend_block_m(M) + # Step 3: apply_swiglu fold into GEMM1. Requires interleaved W13 (done at weight load). + _swiglu_fold = os.environ.get("ATOM_A8W4_SWIGLU_FOLD", "0") == "1" + + # Pre-MoE quant: bf16 -> fp8 e4m3 + ue8m0 per-1x32. + x_fp8, x_scale = downcast_to_mxfp(hidden_states, torch.float8_e4m3fn, axis=-1) + + # Step 4: skip bridge when pre_built_aiter_routing is provided. + if pre_built_aiter_routing is not None: + aiter_routing = pre_built_aiter_routing + gammas = aiter_routing.gate_scal + gather_src_indx = gather_indx # already raw int32 tensor + scatter_src_indx = scatter_indx + else: + aiter_routing = bridge_tk_to_aiter_routing(routing_data, BLOCK_M) + gammas = routing_data.gate_scal if routing_data else None + gather_src_indx = gather_indx.src_indx + scatter_src_indx = scatter_indx.src_indx + + if w1_scale.dtype != torch.uint8: + w1_scale = w1_scale.view(torch.uint8) + if w2_scale.dtype != torch.uint8: + w2_scale = w2_scale.view(torch.uint8) + + if _swiglu_fold: + # Step 5: optional fold of MXFP8 (fp8 e4m3 + ue8m0 per-1×32) emit into + # GEMM1's apply_swiglu writeback. Saves one `downcast_to_mxfp` launch. + _mx_emit = os.environ.get("ATOM_A8W4_GEMM1_MX_EMIT", "0") == "1" + if _mx_emit: + inter_fp8, inter_scale = moe_gemm_a8w4( + x=x_fp8, + w=w1, + x_scales=x_scale, + w_scales=w1_scale, + bias=w1_bias, + routing_data=aiter_routing, + gather_indx=gather_src_indx, + gammas=gammas if apply_router_weight_on_input else None, + swizzle_mx_scale="CDNA4_SCALE", + apply_swiglu=True, + alpha=1.0, + limit=swiglu_limit, + add_residual=False, + out_dtype=torch.bfloat16, # ignored when out_mx_quant=True + out_mx_quant=True, + ) + else: + # Step 3: GEMM1 with apply_swiglu=True. Reads interleaved W13. Returns bf16 [M*topk, N]. + # add_residual=False = V4-Flash semantics (silu(gate) * up). True would compute s*(up+1). + inter_bf16 = moe_gemm_a8w4( + x=x_fp8, + w=w1, + x_scales=x_scale, + w_scales=w1_scale, + bias=w1_bias, + routing_data=aiter_routing, + gather_indx=gather_src_indx, + gammas=gammas if apply_router_weight_on_input else None, + swizzle_mx_scale="CDNA4_SCALE", + apply_swiglu=True, + alpha=1.0, + limit=swiglu_limit, + add_residual=False, + out_dtype=torch.bfloat16, + ) + inter_bf16 = inter_bf16.view(inter_bf16.shape[-2], inter_bf16.shape[-1]) + inter_fp8, inter_scale = downcast_to_mxfp( + inter_bf16, torch.float8_e4m3fn, axis=-1 + ) + else: + # GEMM1: fp8 act x fp4 weight, bf16 output [M*topk, 2N]. + raw_intermediate = moe_gemm_a8w4( + x=x_fp8, + w=w1, + x_scales=x_scale, + w_scales=w1_scale, + bias=w1_bias, + routing_data=aiter_routing, + gather_indx=gather_src_indx, + gammas=gammas if apply_router_weight_on_input else None, + swizzle_mx_scale="CDNA4_SCALE", + apply_swiglu=False, + out_dtype=torch.bfloat16, + ) + raw_2d = raw_intermediate.view( + raw_intermediate.shape[-2], raw_intermediate.shape[-1] + ) + two_N = raw_2d.shape[-1] + inter_bf16 = torch.empty( + (raw_2d.shape[0], two_N // 2), + device=raw_2d.device, + dtype=torch.bfloat16, + ) + fused_clamp_act_mul( + raw_2d, + out=inter_bf16, + swiglu_limit=swiglu_limit, + activation="silu", + dtype_quant=None, + ) + inter_fp8, inter_scale = downcast_to_mxfp( + inter_bf16, torch.float8_e4m3fn, axis=-1 + ) + + # GEMM2: fp8 act x fp4 weight, scatter+combine. + # Step 9: pass `residual` so reduce_grouped folds the routed+shared add + # into the writeback (saves the standalone elementwise add launch). + y = moe_gemm_a8w4( + x=inter_fp8, + w=w2, + x_scales=inter_scale, + w_scales=w2_scale, + bias=w2_bias, + routing_data=aiter_routing, + scatter_indx=scatter_src_indx, + gammas=None if apply_router_weight_on_input else gammas, + swizzle_mx_scale="CDNA4_SCALE", + apply_swiglu=False, + out_dtype=torch.bfloat16, + residual=residual, + ) + return y + + def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: """ Shrink the given tensor and apply the given view to it. This is @@ -293,13 +478,41 @@ def triton_kernel_fused_experts( expert_map: torch.Tensor | None = None, intermediate_cache: torch.Tensor | None = None, a1q_scale: torch.Tensor | None = None, + w13_weight_scale_a8w4: torch.Tensor | None = None, + w2_weight_scale_a8w4: torch.Tensor | None = None, + swiglu_limit_a8w4: float = 7.0, ) -> torch.Tensor: # type check, uint8 means mxfp4 assert hidden_states.dtype == torch.bfloat16 assert w1_bias is None or w1_bias.dtype == torch.float32 assert w2_bias is None or w2_bias.dtype == torch.float32 - # Shape check, only check non-mxfp4 + # a8w4 path: dispatched when env is set AND a8w4 scale tensors were plumbed + # through. matmul_ogs path below stays untouched. + _use_a8w4 = ( + os.environ.get("ATOM_MOE_BACKEND", "matmul_ogs") == "a8w4" + and w13_weight_scale_a8w4 is not None + and w2_weight_scale_a8w4 is not None + ) + if _use_a8w4: + return _a8w4_fused_experts( + output_tensor, + hidden_states, + w1, + w2, + w13_weight_scale_a8w4, + w2_weight_scale_a8w4, + routing_data, + gather_indx, + scatter_indx, + topk, + swiglu_limit=swiglu_limit_a8w4, + w1_bias=w1_bias, + w2_bias=w2_bias, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + # Shape check for matmul_ogs path (a8w4 has different layout) assert hidden_states.ndim == 2 assert hidden_states.shape[-1] == w1.shape[-2] assert w2.shape[-1] == w1.shape[1] diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 92d4698f8..f96a6f87a 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -830,6 +830,80 @@ def process_weights_after_loading(self, layer): if os.environ.get("ATOM_V4_TORCH_MOE"): return + # Step 1: a8w4 MoE backend weight-layout branch. Reorders W13/W2 into + # the layout moe_gemm_a8w4 expects and swizzles scales. + _use_a8w4_moe = os.environ.get("ATOM_MOE_BACKEND", "matmul_ogs") == "a8w4" + if _use_a8w4_moe and self.use_triton: + from aiter.ops.triton.moe.moe_op_gemm_a8w4 import swizzle_scales + + # Step 3: gate/up interleave for apply_swiglu fold. Off by default. + _interleave_w13 = os.environ.get("ATOM_A8W4_SWIGLU_FOLD", "0") == "1" + + def _interleave_gateup(t): + """t: [E, 2N, *] → reorder dim=1 from [g..u..] to [g,u,g,u,...]""" + E, two_N = t.shape[0], t.shape[1] + N = two_N // 2 + tail = t.shape[2:] + return ( + t.view(E, 2, N, *tail) + .permute(0, 2, 1, *range(3, 3 + len(tail))) + .contiguous() + .view(E, two_N, *tail) + ) + + # W1 weight: view-transpose to [E, K/2, 2N] (no .contiguous()) + w1_w = layer.w13_weight.data.view(torch.uint8) # [E, 2N, K/2] + if _interleave_w13: + w1_w = _interleave_gateup(w1_w) + w1_w_kernel = w1_w.transpose(-1, -2) # [E, K/2, 2N], view + assert ( + w1_w_kernel.stride(-2) == 1 + ), "W1: K must be contiguous (do NOT call .contiguous())" + + # W1 scale: transpose + contiguous + swizzle + w1_s = layer.w13_weight_scale.data # [E, 2N, K/32] + if _interleave_w13: + w1_s = _interleave_gateup(w1_s) + w1_s_swz_in = w1_s.transpose(-1, -2).contiguous() # [E, K/32, 2N] + w1_s_E, w1_s_SCALE_K, w1_s_N = w1_s_swz_in.shape + w1_s_K = w1_s_SCALE_K * 32 + if w1_s_N % 32 == 0 and w1_s_K % 256 == 0: + w1_s_kernel = swizzle_scales(w1_s_swz_in) + else: + w1_s_kernel = w1_s_swz_in + + # W2 weight: view-transpose to [E, N_per/2, hidden] + w2_w = layer.w2_weight.data.view(torch.uint8) + w2_w_kernel = w2_w.transpose(-1, -2) + assert ( + w2_w_kernel.stride(-2) == 1 + ), "W2: K (=N_per/2 packed) must be contiguous" + + # W2 scale: transpose + contiguous + swizzle + w2_s = layer.w2_weight_scale.data + w2_s_swz_in = w2_s.transpose(-1, -2).contiguous() + w2_s_E, w2_s_SCALE_K, w2_s_N = w2_s_swz_in.shape + w2_s_K = w2_s_SCALE_K * 32 + if w2_s_N % 32 == 0 and w2_s_K % 256 == 0: + w2_s_kernel = swizzle_scales(w2_s_swz_in) + else: + w2_s_kernel = w2_s_swz_in + + del layer.w13_weight + del layer.w2_weight + del layer.w13_weight_scale + del layer.w2_weight_scale + layer.w13_weight = w1_w_kernel + layer.w2_weight = w2_w_kernel + layer.w13_weight_scale = w1_s_kernel + layer.w2_weight_scale = w2_s_kernel + + self.w13_precision_config = None + self.w2_precision_config = None + self.moe_backend = "a8w4" + self.a8w4_swiglu_fold = _interleave_w13 + return + if self.use_triton: from atom.model_ops.fused_moe_triton import _swizzle_mxfp4 from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index b284e98a8..a22677f82 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -2282,9 +2282,31 @@ def combine_outputs( def single_stream_moe_forward( self, x: torch.Tensor # [num_tokens, dim] ) -> torch.Tensor: # [num_tokens, dim] - """Sequential: shared_experts → routed_experts → combine.""" + """Sequential: shared_experts → routed_experts → combine. + + Step 9: when the MoE method supports it (a8w4 triton-routing path), + the `routed + shared` add is folded into reduce_grouped's writeback by + stashing `shared` on `self.experts` before `routed_expert_forward`. + Saves the standalone elementwise add kernel. Gated by + ATOM_A8W4_FUSE_RESIDUAL (default 1) for easy A/B. + """ shared = self.shared_experts(x) if self.shared_experts is not None else None - routed = self.routed_expert_forward(x) + fuse_residual = ( + shared is not None and os.environ.get("ATOM_A8W4_FUSE_RESIDUAL", "1") == "1" + ) + if fuse_residual: + self.experts._moe_residual_to_fold = shared + try: + routed = self.routed_expert_forward(x) + finally: + if fuse_residual: + self.experts._moe_residual_to_fold = None + folded = getattr(self.experts, "_moe_residual_was_folded", False) + self.experts._moe_residual_was_folded = False # one-shot + if folded: + if self.tp_size > 1: + routed = tensor_model_parallel_all_reduce(routed) + return routed return self.combine_outputs(routed, shared) def dual_stream_moe_forward( From 22f9a029c8f83ac82f399cd6d9abe058ed1315a9 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Wed, 20 May 2026 17:02:47 +0000 Subject: [PATCH 4/9] bug fix --- atom/model_ops/moe.py | 202 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 201 insertions(+), 1 deletion(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index f96a6f87a..a188e70c6 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -834,7 +834,9 @@ def process_weights_after_loading(self, layer): # the layout moe_gemm_a8w4 expects and swizzles scales. _use_a8w4_moe = os.environ.get("ATOM_MOE_BACKEND", "matmul_ogs") == "a8w4" if _use_a8w4_moe and self.use_triton: - from aiter.ops.triton.moe.moe_op_gemm_a8w4 import swizzle_scales + from aiter.ops.triton.moe.moe_op_gemm_a8w4 import ( + swizzle_scales_gfx950 as swizzle_scales, + ) # Step 3: gate/up interleave for apply_swiglu fold. Off by default. _interleave_w13 = os.environ.get("ATOM_A8W4_SWIGLU_FOLD", "0") == "1" @@ -931,6 +933,7 @@ def _interleave_gateup(t): layer.w2_weight = w2_weight layer.w13_weight_scale = None layer.w2_weight_scale = None + self.moe_backend = "matmul_ogs" return # shuffle weight @@ -1009,6 +1012,149 @@ def apply( ) if needs_custom_routing: + _a8w4 = getattr(self, "moe_backend", "matmul_ogs") == "a8w4" + + # DeepSeek-V4 hash-layer fully-fused fast path: replaces the + # Python `_hash_topk` + multi-kernel `fused_routing_from_topk` + # counting-sort + `compute_expt_data` (with memset) chain with + # ONE Triton kernel (hash_routing) + sort_tokens_fused. Same + # 2-kernel shape as the non-hash routing_a8w4 path. + # + # Detection: custom_routing_function is the bound method of a + # MoeLayer whose gate has the tid2eid lookup table. + _hash_layer = ( + _a8w4 + and custom_routing_function is not None + and hasattr(custom_routing_function, "__self__") + and hasattr( + getattr(custom_routing_function.__self__, "gate", None), + "tid2eid", + ) + ) + _a8w4_hash_fused = ( + _hash_layer + and os.environ.get("ATOM_A8W4_TRITON_ROUTING", "0") == "1" + and os.environ.get("ATOM_A8W4_HASH_FAST_ROUTING", "0") == "1" + and layer.num_fused_shared_experts == 0 + ) + if _a8w4_hash_fused: + from aiter.ops.triton.moe.moe_routing.routing import ( + routing_a8w4_from_hash, + ) + from aiter.ops.triton.moe.moe_op_gemm_a8w4 import recommend_block_m + from atom.model_ops.fused_moe_triton import _a8w4_fused_experts + + moe_layer = custom_routing_function.__self__ + tid2eid = moe_layer.gate.tid2eid + + # Get input_ids (with DP gather, mirror _hash_topk semantics). + fwd_ctx = get_forward_context() + ids = fwd_ctx.context.input_ids.flatten() + num_tokens = router_logits.shape[0] + if ids.shape[0] < num_tokens: + ids_2d = ids.unsqueeze(-1) + ids_2d, _ = pad_for_all_gather(ids_2d) + from aiter.dist.parallel_state import get_dp_group + + ids_2d = get_dp_group().all_gather(ids_2d, dim=0) + ids = ids_2d[:num_tokens].flatten() + ids = ids.clamp(0, tid2eid.shape[0] - 1) + + n_expts_tot = router_logits.shape[-1] + if global_num_experts > 0: + n_expts_tot = global_num_experts + n_expts_tot = n_expts_tot + layer.num_fused_shared_experts + + block_m = recommend_block_m(x.shape[-2]) + aiter_routing, gather_src, scatter_src = routing_a8w4_from_hash( + router_logits, + tid2eid, + ids, + n_expts_act=top_k, + block_m=block_m, + score_mode="sqrtsoftplus", + renorm=renormalize, + routed_scaling_factor=moe_layer.routed_scaling_factor, + ) + moe_residual = getattr(layer, "_moe_residual_to_fold", None) + if moe_residual is not None: + layer._moe_residual_was_folded = True + output = torch.empty_like(x) + return _a8w4_fused_experts( + output, + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + routing_data=None, + gather_indx=gather_src, + scatter_indx=scatter_src, + topk=top_k, + swiglu_limit=getattr(layer, "swiglu_limit", 0.0), + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + pre_built_aiter_routing=aiter_routing, + residual=moe_residual, + ) + + _a8w4_triton_routing = ( + _a8w4 + and os.environ.get("ATOM_A8W4_TRITON_ROUTING", "0") == "1" + and scoring_func == "sqrtsoftplus" + and not use_grouped_topk + and custom_routing_function is None + and layer.num_fused_shared_experts == 0 + ) + if _a8w4_triton_routing: + # Step 4: aiter `routing_a8w4` does V4 math + topk + sort + ExptData + # in one Triton pipeline. Skip FusedMoE.select_experts + bridge. + from aiter.ops.triton.moe.moe_routing.routing import routing_a8w4 + from aiter.ops.triton.moe.moe_op_gemm_a8w4 import recommend_block_m + from atom.model_ops.fused_moe_triton import _a8w4_fused_experts + + n_expts_tot = router_logits.shape[-1] + if global_num_experts > 0: + n_expts_tot = global_num_experts + n_expts_tot = n_expts_tot + layer.num_fused_shared_experts + M = x.shape[-2] + block_m = recommend_block_m(M) + aiter_routing, gather_src, scatter_src = routing_a8w4( + router_logits, + n_expts_act=top_k, + block_m=block_m, + score_mode="sqrtsoftplus", + bias=e_score_correction_bias, + renorm=renormalize, + routed_scaling_factor=layer.routed_scaling_factor, + ) + # Step 9: V4 single-stream stashes shared-experts output on + # the layer; fold it into reduce_grouped's writeback. None + # for callers that didn't pre-compute shared. + moe_residual = getattr(layer, "_moe_residual_to_fold", None) + if moe_residual is not None: + layer._moe_residual_was_folded = True + output = torch.empty_like(x) + return _a8w4_fused_experts( + output, + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + routing_data=None, + gather_indx=gather_src, + scatter_indx=scatter_src, + topk=top_k, + swiglu_limit=getattr(layer, "swiglu_limit", 0.0), + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + pre_built_aiter_routing=aiter_routing, + residual=moe_residual, + ) + # Use ATOM's full-featured select_experts for routing, # then triton matmul_ogs for the actual MoE computation. topk_weights, topk_ids = FusedMoE.select_experts( @@ -1033,10 +1179,61 @@ def apply( n_expts_tot = global_num_experts n_expts_tot = n_expts_tot + layer.num_fused_shared_experts + n_gates_pad = topk_weights.shape[0] * topk_weights.shape[1] + _a8w4_topk_fast = ( + _a8w4 + and os.environ.get("ATOM_A8W4_TRITON_ROUTING", "0") == "1" + and os.environ.get("ATOM_A8W4_HASH_FAST_ROUTING", "0") == "1" + and layer.num_fused_shared_experts == 0 + and n_gates_pad <= 4096 + ) + if _a8w4_topk_fast: + # Custom-routing a8w4 fast path (DeepSeek-V4 hash layers): + # build aiter RoutingData directly from (topk_weights, topk_ids) + # via `routing_a8w4_from_topk`. Skips matmul_ogs `compute_expt_data`, + # eliminating its histogram memset (`_expt_data_memset`). + # Bounded by `fused_routing_from_topk`'s 4096-NK single-CTA budget; + # prefill exceeds this and falls through to the matmul_ogs path + # where routing overhead is amortised over much larger GEMM work. + from aiter.ops.triton.moe.moe_routing.routing import ( + routing_a8w4_from_topk, + ) + from aiter.ops.triton.moe.moe_op_gemm_a8w4 import recommend_block_m + from atom.model_ops.fused_moe_triton import _a8w4_fused_experts + + block_m = recommend_block_m(x.shape[-2]) + aiter_routing, gather_src, scatter_src = routing_a8w4_from_topk( + topk_weights, topk_ids, n_expts_tot, block_m + ) + moe_residual = getattr(layer, "_moe_residual_to_fold", None) + if moe_residual is not None: + layer._moe_residual_was_folded = True + output = torch.empty_like(x) + return _a8w4_fused_experts( + output, + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + routing_data=None, + gather_indx=gather_src, + scatter_indx=scatter_src, + topk=n_expts_act, + swiglu_limit=getattr(layer, "swiglu_limit", 0.0), + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + pre_built_aiter_routing=aiter_routing, + residual=moe_residual, + ) + routing_data, gather_idx, scatter_idx = fused_routing_from_topk_triton( topk_weights, topk_ids, n_expts_tot ) + _a8w4 = getattr(self, "moe_backend", "matmul_ogs") == "a8w4" + output = torch.empty_like(x) _moe_result = triton_kernel_fused_experts( output, @@ -1056,6 +1253,9 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=n_expts_tot, expert_map=expert_map, + w13_weight_scale_a8w4=(layer.w13_weight_scale if _a8w4 else None), + w2_weight_scale_a8w4=(layer.w2_weight_scale if _a8w4 else None), + swiglu_limit_a8w4=getattr(layer, "swiglu_limit", 0.0), ) return _moe_result From 903493c804a6a991c0961780a1e1615073d6744b Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Wed, 20 May 2026 20:53:22 +0000 Subject: [PATCH 5/9] clean up --- atom/model_ops/fused_moe_triton.py | 17 +++++--- atom/model_ops/moe.py | 63 +++++++++++++++++------------- 2 files changed, 47 insertions(+), 33 deletions(-) diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py index 24648a0b8..e79dca2ef 100644 --- a/atom/model_ops/fused_moe_triton.py +++ b/atom/model_ops/fused_moe_triton.py @@ -31,6 +31,11 @@ ) from aiter.ops.triton.fusions.fused_clamp_act_mul import fused_clamp_act_mul from atom.model_ops.utils import has_triton_kernels +from atom.utils import envs + +_A8W4_TRITON_MOE = os.environ.get( + "ATOM_MOE_BACKEND", "matmul_ogs" +) == "a8w4" and envs.is_set("ATOM_USE_TRITON_MOE") logger = logging.getLogger("atom") @@ -267,6 +272,8 @@ def _a8w4_fused_experts( pre_built_aiter_routing=None, # if set, skip bridge and use this aiter routing directly # Step 9: shared-expert output to fold into GEMM2 reduce_grouped writeback. residual: torch.Tensor | None = None, + # Step 3: apply_swiglu fold into GEMM1. Requires interleaved W13 (done at weight load). + swiglu_fold: bool = False, ) -> torch.Tensor: """Step 1 minimum a8w4 fused experts. @@ -280,8 +287,6 @@ def _a8w4_fused_experts( M, K = hidden_states.shape[-2:] BLOCK_M = recommend_block_m(M) - # Step 3: apply_swiglu fold into GEMM1. Requires interleaved W13 (done at weight load). - _swiglu_fold = os.environ.get("ATOM_A8W4_SWIGLU_FOLD", "0") == "1" # Pre-MoE quant: bf16 -> fp8 e4m3 + ue8m0 per-1x32. x_fp8, x_scale = downcast_to_mxfp(hidden_states, torch.float8_e4m3fn, axis=-1) @@ -303,10 +308,10 @@ def _a8w4_fused_experts( if w2_scale.dtype != torch.uint8: w2_scale = w2_scale.view(torch.uint8) - if _swiglu_fold: + if swiglu_fold: # Step 5: optional fold of MXFP8 (fp8 e4m3 + ue8m0 per-1×32) emit into # GEMM1's apply_swiglu writeback. Saves one `downcast_to_mxfp` launch. - _mx_emit = os.environ.get("ATOM_A8W4_GEMM1_MX_EMIT", "0") == "1" + _mx_emit = os.environ.get("ATOM_A8W4_TRITON_MOE_GEMM1_MX_EMIT", "1") == "1" if _mx_emit: inter_fp8, inter_scale = moe_gemm_a8w4( x=x_fp8, @@ -481,6 +486,7 @@ def triton_kernel_fused_experts( w13_weight_scale_a8w4: torch.Tensor | None = None, w2_weight_scale_a8w4: torch.Tensor | None = None, swiglu_limit_a8w4: float = 7.0, + swiglu_fold: bool = False, ) -> torch.Tensor: # type check, uint8 means mxfp4 assert hidden_states.dtype == torch.bfloat16 @@ -490,7 +496,7 @@ def triton_kernel_fused_experts( # a8w4 path: dispatched when env is set AND a8w4 scale tensors were plumbed # through. matmul_ogs path below stays untouched. _use_a8w4 = ( - os.environ.get("ATOM_MOE_BACKEND", "matmul_ogs") == "a8w4" + _A8W4_TRITON_MOE and w13_weight_scale_a8w4 is not None and w2_weight_scale_a8w4 is not None ) @@ -510,6 +516,7 @@ def triton_kernel_fused_experts( w1_bias=w1_bias, w2_bias=w2_bias, apply_router_weight_on_input=apply_router_weight_on_input, + swiglu_fold=swiglu_fold, ) # Shape check for matmul_ogs path (a8w4 has different layout) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index a188e70c6..8e734f0d5 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -58,6 +58,20 @@ from atom.plugin.moe import FusedMoEDecoratorForPluginMode from atom.quantization.quark.utils import weight_dequant_fp8 +_A8W4_TRITON_MOE = os.environ.get( + "ATOM_MOE_BACKEND", "matmul_ogs" +) == "a8w4" and envs.is_set("ATOM_USE_TRITON_MOE") +_A8W4_TRITON_MOE_ROUTING = ( + os.environ.get("ATOM_A8W4_TRITON_MOE_ROUTING", "1") == "1" and _A8W4_TRITON_MOE +) +_A8W4_TRITON_MOE_HASH_FAST_ROUTING = ( + os.environ.get("ATOM_A8W4_TRITON_MOE_HASH_FAST_ROUTING", "1") == "1" + and _A8W4_TRITON_MOE +) +_A8W4_TRITON_MOE_INTERLEAVE_W13 = ( + os.environ.get("ATOM_A8W4_TRITON_MOE_SWIGLU_FOLD", "1") == "1" and _A8W4_TRITON_MOE +) + class FusedMoeWeightScaleSupported(Enum): """Supported quantization strategies for MoE weight scales.""" @@ -832,15 +846,12 @@ def process_weights_after_loading(self, layer): # Step 1: a8w4 MoE backend weight-layout branch. Reorders W13/W2 into # the layout moe_gemm_a8w4 expects and swizzles scales. - _use_a8w4_moe = os.environ.get("ATOM_MOE_BACKEND", "matmul_ogs") == "a8w4" - if _use_a8w4_moe and self.use_triton: + if _A8W4_TRITON_MOE and self.use_triton: + # Step 3: gate/up interleave for apply_swiglu fold. Off by default. from aiter.ops.triton.moe.moe_op_gemm_a8w4 import ( swizzle_scales_gfx950 as swizzle_scales, ) - # Step 3: gate/up interleave for apply_swiglu fold. Off by default. - _interleave_w13 = os.environ.get("ATOM_A8W4_SWIGLU_FOLD", "0") == "1" - def _interleave_gateup(t): """t: [E, 2N, *] → reorder dim=1 from [g..u..] to [g,u,g,u,...]""" E, two_N = t.shape[0], t.shape[1] @@ -855,7 +866,7 @@ def _interleave_gateup(t): # W1 weight: view-transpose to [E, K/2, 2N] (no .contiguous()) w1_w = layer.w13_weight.data.view(torch.uint8) # [E, 2N, K/2] - if _interleave_w13: + if _A8W4_TRITON_MOE_INTERLEAVE_W13: w1_w = _interleave_gateup(w1_w) w1_w_kernel = w1_w.transpose(-1, -2) # [E, K/2, 2N], view assert ( @@ -864,7 +875,7 @@ def _interleave_gateup(t): # W1 scale: transpose + contiguous + swizzle w1_s = layer.w13_weight_scale.data # [E, 2N, K/32] - if _interleave_w13: + if _A8W4_TRITON_MOE_INTERLEAVE_W13: w1_s = _interleave_gateup(w1_s) w1_s_swz_in = w1_s.transpose(-1, -2).contiguous() # [E, K/32, 2N] w1_s_E, w1_s_SCALE_K, w1_s_N = w1_s_swz_in.shape @@ -903,7 +914,7 @@ def _interleave_gateup(t): self.w13_precision_config = None self.w2_precision_config = None self.moe_backend = "a8w4" - self.a8w4_swiglu_fold = _interleave_w13 + self.swiglu_fold = _A8W4_TRITON_MOE_INTERLEAVE_W13 return if self.use_triton: @@ -997,10 +1008,17 @@ def apply( ) -> torch.Tensor: if self.use_triton: from atom.model_ops.fused_moe_triton import ( + _a8w4_fused_experts, triton_kernel_moe_forward, triton_kernel_fused_experts, fused_routing_from_topk_triton, ) + from aiter.ops.triton.moe.moe_op_gemm_a8w4 import recommend_block_m + from aiter.ops.triton.moe.moe_routing.routing import ( + routing_a8w4, + routing_a8w4_from_hash, + routing_a8w4_from_topk, + ) # Check if the model needs custom routing that triton routing() # does not support (grouped topk, sigmoid scoring, bias correction). @@ -1033,17 +1051,11 @@ def apply( ) _a8w4_hash_fused = ( _hash_layer - and os.environ.get("ATOM_A8W4_TRITON_ROUTING", "0") == "1" - and os.environ.get("ATOM_A8W4_HASH_FAST_ROUTING", "0") == "1" + and _A8W4_TRITON_MOE_ROUTING + and _A8W4_TRITON_MOE_HASH_FAST_ROUTING and layer.num_fused_shared_experts == 0 ) if _a8w4_hash_fused: - from aiter.ops.triton.moe.moe_routing.routing import ( - routing_a8w4_from_hash, - ) - from aiter.ops.triton.moe.moe_op_gemm_a8w4 import recommend_block_m - from atom.model_ops.fused_moe_triton import _a8w4_fused_experts - moe_layer = custom_routing_function.__self__ tid2eid = moe_layer.gate.tid2eid @@ -1097,11 +1109,12 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, pre_built_aiter_routing=aiter_routing, residual=moe_residual, + swiglu_fold=getattr(self, "swiglu_fold", False), ) _a8w4_triton_routing = ( _a8w4 - and os.environ.get("ATOM_A8W4_TRITON_ROUTING", "0") == "1" + and _A8W4_TRITON_MOE_ROUTING and scoring_func == "sqrtsoftplus" and not use_grouped_topk and custom_routing_function is None @@ -1110,10 +1123,6 @@ def apply( if _a8w4_triton_routing: # Step 4: aiter `routing_a8w4` does V4 math + topk + sort + ExptData # in one Triton pipeline. Skip FusedMoE.select_experts + bridge. - from aiter.ops.triton.moe.moe_routing.routing import routing_a8w4 - from aiter.ops.triton.moe.moe_op_gemm_a8w4 import recommend_block_m - from atom.model_ops.fused_moe_triton import _a8w4_fused_experts - n_expts_tot = router_logits.shape[-1] if global_num_experts > 0: n_expts_tot = global_num_experts @@ -1153,6 +1162,7 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, pre_built_aiter_routing=aiter_routing, residual=moe_residual, + swiglu_fold=getattr(self, "swiglu_fold", False), ) # Use ATOM's full-featured select_experts for routing, @@ -1182,8 +1192,8 @@ def apply( n_gates_pad = topk_weights.shape[0] * topk_weights.shape[1] _a8w4_topk_fast = ( _a8w4 - and os.environ.get("ATOM_A8W4_TRITON_ROUTING", "0") == "1" - and os.environ.get("ATOM_A8W4_HASH_FAST_ROUTING", "0") == "1" + and _A8W4_TRITON_MOE_ROUTING + and _A8W4_TRITON_MOE_HASH_FAST_ROUTING and layer.num_fused_shared_experts == 0 and n_gates_pad <= 4096 ) @@ -1195,11 +1205,6 @@ def apply( # Bounded by `fused_routing_from_topk`'s 4096-NK single-CTA budget; # prefill exceeds this and falls through to the matmul_ogs path # where routing overhead is amortised over much larger GEMM work. - from aiter.ops.triton.moe.moe_routing.routing import ( - routing_a8w4_from_topk, - ) - from aiter.ops.triton.moe.moe_op_gemm_a8w4 import recommend_block_m - from atom.model_ops.fused_moe_triton import _a8w4_fused_experts block_m = recommend_block_m(x.shape[-2]) aiter_routing, gather_src, scatter_src = routing_a8w4_from_topk( @@ -1226,6 +1231,7 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, pre_built_aiter_routing=aiter_routing, residual=moe_residual, + swiglu_fold=getattr(self, "swiglu_fold", False), ) routing_data, gather_idx, scatter_idx = fused_routing_from_topk_triton( @@ -1256,6 +1262,7 @@ def apply( w13_weight_scale_a8w4=(layer.w13_weight_scale if _a8w4 else None), w2_weight_scale_a8w4=(layer.w2_weight_scale if _a8w4 else None), swiglu_limit_a8w4=getattr(layer, "swiglu_limit", 0.0), + swiglu_fold=getattr(self, "swiglu_fold", False), ) return _moe_result From 52e8a7de504d67bed031902ae785cf2be0ffb1e6 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Thu, 21 May 2026 17:48:19 +0000 Subject: [PATCH 6/9] add fused_router_gate_a16w16_quant --- atom/model_ops/fused_moe_triton.py | 12 +++--- atom/model_ops/moe.py | 40 ++++++++++++++++--- atom/models/deepseek_v4.py | 63 +++++++++++++++++++++++++++++- 3 files changed, 103 insertions(+), 12 deletions(-) diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py index e79dca2ef..b3c2b2d1c 100644 --- a/atom/model_ops/fused_moe_triton.py +++ b/atom/model_ops/fused_moe_triton.py @@ -274,6 +274,7 @@ def _a8w4_fused_experts( residual: torch.Tensor | None = None, # Step 3: apply_swiglu fold into GEMM1. Requires interleaved W13 (done at weight load). swiglu_fold: bool = False, + x_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Step 1 minimum a8w4 fused experts. @@ -289,7 +290,10 @@ def _a8w4_fused_experts( BLOCK_M = recommend_block_m(M) # Pre-MoE quant: bf16 -> fp8 e4m3 + ue8m0 per-1x32. - x_fp8, x_scale = downcast_to_mxfp(hidden_states, torch.float8_e4m3fn, axis=-1) + if x_scale is None: + x_fp8, x_scale = downcast_to_mxfp(hidden_states, torch.float8_e4m3fn, axis=-1) + else: + x_fp8 = hidden_states # Step 4: skip bridge when pre_built_aiter_routing is provided. if pre_built_aiter_routing is not None: @@ -303,10 +307,8 @@ def _a8w4_fused_experts( gather_src_indx = gather_indx.src_indx scatter_src_indx = scatter_indx.src_indx - if w1_scale.dtype != torch.uint8: - w1_scale = w1_scale.view(torch.uint8) - if w2_scale.dtype != torch.uint8: - w2_scale = w2_scale.view(torch.uint8) + w1_scale = w1_scale.view(torch.uint8) + w2_scale = w2_scale.view(torch.uint8) if swiglu_fold: # Step 5: optional fold of MXFP8 (fp8 e4m3 + ue8m0 per-1×32) emit into diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 8e734f0d5..6d7e779cc 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -283,6 +283,7 @@ def apply( fused_shared_experts_scoring_func: Optional[str] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + x_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError @@ -1005,6 +1006,7 @@ def apply( apply_router_weight_on_input: bool = False, fused_shared_experts_scoring_func: Optional[str] = None, activation: ActivationType = ActivationType.Silu, + x_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.use_triton: from atom.model_ops.fused_moe_triton import ( @@ -1110,6 +1112,7 @@ def apply( pre_built_aiter_routing=aiter_routing, residual=moe_residual, swiglu_fold=getattr(self, "swiglu_fold", False), + x_scale=x_scale, ) _a8w4_triton_routing = ( @@ -1163,6 +1166,7 @@ def apply( pre_built_aiter_routing=aiter_routing, residual=moe_residual, swiglu_fold=getattr(self, "swiglu_fold", False), + x_scale=x_scale, ) # Use ATOM's full-featured select_experts for routing, @@ -1685,6 +1689,7 @@ def apply( apply_router_weight_on_input: bool = False, fused_shared_experts_scoring_func: Optional[str] = None, activation: ActivationType = ActivationType.Silu, + x_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Apply compressed-tensors FP8 MoE computation.""" # Select top-k experts using router logits @@ -2041,6 +2046,7 @@ def apply( apply_router_weight_on_input: bool = False, fused_shared_experts_scoring_func: Optional[str] = None, activation: ActivationType = ActivationType.Silu, + x_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -2145,16 +2151,18 @@ def moe_forward( hidden_states: torch.Tensor, router_logits: torch.Tensor, layer_name: str, + hidden_states_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: atom_config = get_current_atom_config() self = atom_config.compilation_config.static_forward_context[layer_name] - return self.forward_impl(hidden_states, router_logits) + return self.forward_impl(hidden_states, router_logits, hidden_states_scale) def moe_forward_fake( hidden_states: torch.Tensor, router_logits: torch.Tensor, layer_name: str, + hidden_states_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -3192,13 +3200,24 @@ def select_experts( return topk_weights, topk_ids - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + hidden_states_scale: torch.Tensor = None, + ): return torch.ops.aiter.moe_forward( - hidden_states, router_logits, self.layer_name + hidden_states, + router_logits, + self.layer_name, + hidden_states_scale=hidden_states_scale, ) def forward_impl_graph( - self, hidden_states: torch.Tensor, router_logits: torch.Tensor + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + hidden_states_scale: torch.Tensor = None, ): # There are three mode # 1. Pure DP mode: only DP is used @@ -3246,6 +3265,7 @@ def forward_impl_graph( fused_shared_experts_scoring_func=self.shared_expert_scoring_func, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, + x_scale=hidden_states_scale, ) # Use reduce_scatter when DP > 1 but not using mori all2all kernels @@ -3266,11 +3286,18 @@ def forward_impl_graph( return final_hidden_states - def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + def forward_impl( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + hidden_states_scale: torch.Tensor = None, + ): assert self.quant_method is not None # cuda graph not supported forward with combine and dispatch if self.use_chunked: - return self.forward_impl_graph(hidden_states, router_logits) + return self.forward_impl_graph( + hidden_states, router_logits, hidden_states_scale=hidden_states_scale + ) # return self.forward_impl_chunked(hidden_states, router_logits) dp_group = get_dp_group() @@ -3300,6 +3327,7 @@ def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) fused_shared_experts_scoring_func=self.shared_expert_scoring_func, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, + x_scale=hidden_states_scale, ) dp_group = get_dp_group() diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index a22677f82..3b03b0207 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -47,10 +47,15 @@ from aiter.ops.triton.fusions.fused_clamp_act_mul import ( fused_clamp_act_mul, ) +from aiter.ops.triton.gemm.fused.fused_gemm_a16w16_quant_x import ( + fused_gemm_a16w16_quant_x, +) from aiter.ops.triton.fusions.fused_reduce_qk_norm_rope_swa_write import ( fused_reduce_qk_norm_rope_swa_write, ) from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits +from aiter.tuned_gemm import tgemm +from aiter.ops.triton.moe.quant_moe import downcast_to_mxfp from atom.config import ( Config, LayerQuantConfig, @@ -69,7 +74,7 @@ ReplicatedLinear, RowParallelLinear, ) -from atom.model_ops.moe import FusedMoE +from atom.model_ops.moe import FusedMoE, Mxfp4MoEMethod from atom.model_ops.topK import is_rocm_aiter_fusion_shared_expert_enabled from aiter import rope_rotate_activation from atom.model_ops.quant_v4 import act_quant_inplace @@ -134,6 +139,9 @@ _V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" _V4_DISABLE_FCAM = os.environ.get("ATOM_V4_DISABLE_FCAM", "0") == "1" _MXFP8_BYPASS_FCAM = os.environ.get("ATOM_MXFP8_BYPASS_FCAM", "0") == "1" +_A8W4_TRITON_MOE = os.environ.get( + "ATOM_MOE_BACKEND", "matmul_ogs" +) == "a8w4" and envs.is_set("ATOM_USE_TRITON_MOE") def _fuse_rmsnorm_mxfp8_quant( @@ -282,6 +290,44 @@ def fused_qk_norm_rope_swa_write( return q_out +def _fused_router_gate_a16w16_quant_fake( + x: torch.Tensor, + gate_weight: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_tokens, K = x.shape + n_routed_experts = gate_weight.shape[0] + router_logits = torch.empty( + (num_tokens, n_routed_experts), dtype=torch.bfloat16, device=x.device + ) + x_out = torch.empty((num_tokens, K), dtype=dtypes.fp8, device=x.device) + x_scales = torch.empty((num_tokens, K // 32), dtype=torch.uint8, device=x.device) + return router_logits, x_out, x_scales + + +@torch_compile_guard(gen_fake=_fused_router_gate_a16w16_quant_fake, mutates_args=[]) +def fused_router_gate_a16w16_quant( + x: torch.Tensor, + gate_weight: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """A16W16 router-gate GEMM, optionally fused with a BF16->FP8 downcast + copy of x. + + For decode-shaped batches (num_tokens <= 64) the GEMM and the activation + downcast run in a single triton kernel; the FP8 copy is consumed by the + downstream fused_moe kernel, replacing its internal + `hidden_states.to(dtypes.fp8)` cast with a copy done inline with the + gate GEMM's A-loads. For larger batches the fusion's per-element copy + overhead is not worth the savings, so we fall back to tgemm.mm and let + aiter re-trigger the `.to(fp8)` cast on its side. + """ + num_tokens = x.shape[0] + if num_tokens <= 64: + return fused_gemm_a16w16_quant_x(x, gate_weight, quant_dtype=dtypes.fp8) + router_logits = tgemm.mm(x, gate_weight, None, otype=torch.bfloat16) + x, x_scale = downcast_to_mxfp(x, dtypes.fp8, axis=-1) + return router_logits, x, x_scale + + def _make_weightless_rmsnorm(dim: int, eps: float) -> RMSNorm: """Build an `RMSNorm(dim, eps)` whose `.weight` is `None`. @@ -2203,6 +2249,12 @@ def __init__( prefix ] = self + self._use_a8w4_triton_moe = False + if self.experts.quant_method.__class__ is Mxfp4MoEMethod and getattr( + self.experts.quant_method, "use_triton", False + ): + self._use_a8w4_triton_moe = _A8W4_TRITON_MOE + def _hash_topk( self, hidden_states: torch.Tensor, @@ -2263,6 +2315,15 @@ def routed_expert_forward( `_hash_topk` (FusedMoE's custom_routing_function) reads it there. """ router_logits = self.gate(x) # [num_tokens, n_routed_experts] + if self._use_a8w4_triton_moe: + router_logits, x, x_scale = fused_router_gate_a16w16_quant( + x, self.gate.weight + ) + return self.experts( + hidden_states=x, + router_logits=router_logits, + hidden_states_scale=x_scale, + ) return self.experts(hidden_states=x, router_logits=router_logits) def combine_outputs( From 700e19ec9b0bf5a1592b293fb3dc4c967f551533 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Thu, 21 May 2026 19:27:41 +0000 Subject: [PATCH 7/9] clean up, add flatten mxfp8 quant --- atom/model_ops/fused_moe_triton.py | 99 ++--------- atom/model_ops/moe.py | 263 +++++++++-------------------- atom/models/deepseek_v4.py | 11 +- 3 files changed, 101 insertions(+), 272 deletions(-) diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py index b3c2b2d1c..184da383e 100644 --- a/atom/model_ops/fused_moe_triton.py +++ b/atom/model_ops/fused_moe_triton.py @@ -31,11 +31,6 @@ ) from aiter.ops.triton.fusions.fused_clamp_act_mul import fused_clamp_act_mul from atom.model_ops.utils import has_triton_kernels -from atom.utils import envs - -_A8W4_TRITON_MOE = os.environ.get( - "ATOM_MOE_BACKEND", "matmul_ogs" -) == "a8w4" and envs.is_set("ATOM_USE_TRITON_MOE") logger = logging.getLogger("atom") @@ -272,9 +267,8 @@ def _a8w4_fused_experts( pre_built_aiter_routing=None, # if set, skip bridge and use this aiter routing directly # Step 9: shared-expert output to fold into GEMM2 reduce_grouped writeback. residual: torch.Tensor | None = None, - # Step 3: apply_swiglu fold into GEMM1. Requires interleaved W13 (done at weight load). - swiglu_fold: bool = False, x_scale: torch.Tensor | None = None, + swiglu_fold: bool = True, ) -> torch.Tensor: """Step 1 minimum a8w4 fused experts. @@ -311,50 +305,23 @@ def _a8w4_fused_experts( w2_scale = w2_scale.view(torch.uint8) if swiglu_fold: - # Step 5: optional fold of MXFP8 (fp8 e4m3 + ue8m0 per-1×32) emit into - # GEMM1's apply_swiglu writeback. Saves one `downcast_to_mxfp` launch. - _mx_emit = os.environ.get("ATOM_A8W4_TRITON_MOE_GEMM1_MX_EMIT", "1") == "1" - if _mx_emit: - inter_fp8, inter_scale = moe_gemm_a8w4( - x=x_fp8, - w=w1, - x_scales=x_scale, - w_scales=w1_scale, - bias=w1_bias, - routing_data=aiter_routing, - gather_indx=gather_src_indx, - gammas=gammas if apply_router_weight_on_input else None, - swizzle_mx_scale="CDNA4_SCALE", - apply_swiglu=True, - alpha=1.0, - limit=swiglu_limit, - add_residual=False, - out_dtype=torch.bfloat16, # ignored when out_mx_quant=True - out_mx_quant=True, - ) - else: - # Step 3: GEMM1 with apply_swiglu=True. Reads interleaved W13. Returns bf16 [M*topk, N]. - # add_residual=False = V4-Flash semantics (silu(gate) * up). True would compute s*(up+1). - inter_bf16 = moe_gemm_a8w4( - x=x_fp8, - w=w1, - x_scales=x_scale, - w_scales=w1_scale, - bias=w1_bias, - routing_data=aiter_routing, - gather_indx=gather_src_indx, - gammas=gammas if apply_router_weight_on_input else None, - swizzle_mx_scale="CDNA4_SCALE", - apply_swiglu=True, - alpha=1.0, - limit=swiglu_limit, - add_residual=False, - out_dtype=torch.bfloat16, - ) - inter_bf16 = inter_bf16.view(inter_bf16.shape[-2], inter_bf16.shape[-1]) - inter_fp8, inter_scale = downcast_to_mxfp( - inter_bf16, torch.float8_e4m3fn, axis=-1 - ) + inter_fp8, inter_scale = moe_gemm_a8w4( + x=x_fp8, + w=w1, + x_scales=x_scale, + w_scales=w1_scale, + bias=w1_bias, + routing_data=aiter_routing, + gather_indx=gather_src_indx, + gammas=gammas if apply_router_weight_on_input else None, + swizzle_mx_scale="CDNA4_SCALE", + apply_swiglu=True, + alpha=1.0, + limit=swiglu_limit, + add_residual=False, + out_dtype=torch.bfloat16, # ignored when out_mx_quant=True + out_mx_quant=True, + ) else: # GEMM1: fp8 act x fp4 weight, bf16 output [M*topk, 2N]. raw_intermediate = moe_gemm_a8w4( @@ -485,42 +452,12 @@ def triton_kernel_fused_experts( expert_map: torch.Tensor | None = None, intermediate_cache: torch.Tensor | None = None, a1q_scale: torch.Tensor | None = None, - w13_weight_scale_a8w4: torch.Tensor | None = None, - w2_weight_scale_a8w4: torch.Tensor | None = None, - swiglu_limit_a8w4: float = 7.0, - swiglu_fold: bool = False, ) -> torch.Tensor: # type check, uint8 means mxfp4 assert hidden_states.dtype == torch.bfloat16 assert w1_bias is None or w1_bias.dtype == torch.float32 assert w2_bias is None or w2_bias.dtype == torch.float32 - # a8w4 path: dispatched when env is set AND a8w4 scale tensors were plumbed - # through. matmul_ogs path below stays untouched. - _use_a8w4 = ( - _A8W4_TRITON_MOE - and w13_weight_scale_a8w4 is not None - and w2_weight_scale_a8w4 is not None - ) - if _use_a8w4: - return _a8w4_fused_experts( - output_tensor, - hidden_states, - w1, - w2, - w13_weight_scale_a8w4, - w2_weight_scale_a8w4, - routing_data, - gather_indx, - scatter_indx, - topk, - swiglu_limit=swiglu_limit_a8w4, - w1_bias=w1_bias, - w2_bias=w2_bias, - apply_router_weight_on_input=apply_router_weight_on_input, - swiglu_fold=swiglu_fold, - ) - # Shape check for matmul_ogs path (a8w4 has different layout) assert hidden_states.ndim == 2 assert hidden_states.shape[-1] == w1.shape[-2] diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 6d7e779cc..2be269863 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -58,20 +58,6 @@ from atom.plugin.moe import FusedMoEDecoratorForPluginMode from atom.quantization.quark.utils import weight_dequant_fp8 -_A8W4_TRITON_MOE = os.environ.get( - "ATOM_MOE_BACKEND", "matmul_ogs" -) == "a8w4" and envs.is_set("ATOM_USE_TRITON_MOE") -_A8W4_TRITON_MOE_ROUTING = ( - os.environ.get("ATOM_A8W4_TRITON_MOE_ROUTING", "1") == "1" and _A8W4_TRITON_MOE -) -_A8W4_TRITON_MOE_HASH_FAST_ROUTING = ( - os.environ.get("ATOM_A8W4_TRITON_MOE_HASH_FAST_ROUTING", "1") == "1" - and _A8W4_TRITON_MOE -) -_A8W4_TRITON_MOE_INTERLEAVE_W13 = ( - os.environ.get("ATOM_A8W4_TRITON_MOE_SWIGLU_FOLD", "1") == "1" and _A8W4_TRITON_MOE -) - class FusedMoeWeightScaleSupported(Enum): """Supported quantization strategies for MoE weight scales.""" @@ -716,6 +702,14 @@ def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): assert has_triton_kernels(), "triton_kernels is not installed" + self.use_triton_backend = None + if self.use_triton: + self.use_triton_backend = os.environ.get("ATOM_MOE_BACKEND", "a8w4") + assert self.use_triton_backend in ( + "matmul_ogs", + "a8w4", + ), f"ATOM_MOE_BACKEND={self.use_triton_backend} is not supported in Mxfp4MoEMethod, set ATOM_MOE_BACKEND to matmul_ogs or a8w4" + def create_weights( self, layer: torch.nn.Module, @@ -847,7 +841,7 @@ def process_weights_after_loading(self, layer): # Step 1: a8w4 MoE backend weight-layout branch. Reorders W13/W2 into # the layout moe_gemm_a8w4 expects and swizzles scales. - if _A8W4_TRITON_MOE and self.use_triton: + if self.use_triton_backend == "a8w4": # Step 3: gate/up interleave for apply_swiglu fold. Off by default. from aiter.ops.triton.moe.moe_op_gemm_a8w4 import ( swizzle_scales_gfx950 as swizzle_scales, @@ -867,8 +861,7 @@ def _interleave_gateup(t): # W1 weight: view-transpose to [E, K/2, 2N] (no .contiguous()) w1_w = layer.w13_weight.data.view(torch.uint8) # [E, 2N, K/2] - if _A8W4_TRITON_MOE_INTERLEAVE_W13: - w1_w = _interleave_gateup(w1_w) + w1_w = _interleave_gateup(w1_w) w1_w_kernel = w1_w.transpose(-1, -2) # [E, K/2, 2N], view assert ( w1_w_kernel.stride(-2) == 1 @@ -876,8 +869,7 @@ def _interleave_gateup(t): # W1 scale: transpose + contiguous + swizzle w1_s = layer.w13_weight_scale.data # [E, 2N, K/32] - if _A8W4_TRITON_MOE_INTERLEAVE_W13: - w1_s = _interleave_gateup(w1_s) + w1_s = _interleave_gateup(w1_s) w1_s_swz_in = w1_s.transpose(-1, -2).contiguous() # [E, K/32, 2N] w1_s_E, w1_s_SCALE_K, w1_s_N = w1_s_swz_in.shape w1_s_K = w1_s_SCALE_K * 32 @@ -914,11 +906,9 @@ def _interleave_gateup(t): self.w13_precision_config = None self.w2_precision_config = None - self.moe_backend = "a8w4" - self.swiglu_fold = _A8W4_TRITON_MOE_INTERLEAVE_W13 return - if self.use_triton: + elif self.use_triton_backend == "matmul_ogs": from atom.model_ops.fused_moe_triton import _swizzle_mxfp4 from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig @@ -945,7 +935,6 @@ def _interleave_gateup(t): layer.w2_weight = w2_weight layer.w13_weight_scale = None layer.w2_weight_scale = None - self.moe_backend = "matmul_ogs" return # shuffle weight @@ -1010,17 +999,10 @@ def apply( ) -> torch.Tensor: if self.use_triton: from atom.model_ops.fused_moe_triton import ( - _a8w4_fused_experts, triton_kernel_moe_forward, triton_kernel_fused_experts, fused_routing_from_topk_triton, ) - from aiter.ops.triton.moe.moe_op_gemm_a8w4 import recommend_block_m - from aiter.ops.triton.moe.moe_routing.routing import ( - routing_a8w4, - routing_a8w4_from_hash, - routing_a8w4_from_topk, - ) # Check if the model needs custom routing that triton routing() # does not support (grouped topk, sigmoid scoring, bias correction). @@ -1032,122 +1014,81 @@ def apply( ) if needs_custom_routing: - _a8w4 = getattr(self, "moe_backend", "matmul_ogs") == "a8w4" - - # DeepSeek-V4 hash-layer fully-fused fast path: replaces the - # Python `_hash_topk` + multi-kernel `fused_routing_from_topk` - # counting-sort + `compute_expt_data` (with memset) chain with - # ONE Triton kernel (hash_routing) + sort_tokens_fused. Same - # 2-kernel shape as the non-hash routing_a8w4 path. - # - # Detection: custom_routing_function is the bound method of a - # MoeLayer whose gate has the tid2eid lookup table. - _hash_layer = ( - _a8w4 - and custom_routing_function is not None - and hasattr(custom_routing_function, "__self__") - and hasattr( - getattr(custom_routing_function.__self__, "gate", None), - "tid2eid", + if self.use_triton_backend == "a8w4": + assert ( + layer.num_fused_shared_experts == 0 + ), f"A8W4 Triton MOE does not support fused_shared_experts mode, please set ATOM_MOE_BACKEND=matmul_ogs or ATOM_USE_TRITON_MOE=0" + # DeepSeek-V4 hash-layer fully-fused fast path: replaces the + # Python `_hash_topk` + multi-kernel `fused_routing_from_topk` + # counting-sort + `compute_expt_data` (with memset) chain with + # ONE Triton kernel (hash_routing) + sort_tokens_fused. Same + # 2-kernel shape as the non-hash routing_a8w4 path. + # + # Detection: custom_routing_function is the bound method of a + # MoeLayer whose gate has the tid2eid lookup table. + from atom.model_ops.fused_moe_triton import _a8w4_fused_experts + from aiter.ops.triton.moe.moe_op_gemm_a8w4 import recommend_block_m + from aiter.ops.triton.moe.moe_routing.routing import ( + routing_a8w4, + routing_a8w4_from_hash, ) - ) - _a8w4_hash_fused = ( - _hash_layer - and _A8W4_TRITON_MOE_ROUTING - and _A8W4_TRITON_MOE_HASH_FAST_ROUTING - and layer.num_fused_shared_experts == 0 - ) - if _a8w4_hash_fused: - moe_layer = custom_routing_function.__self__ - tid2eid = moe_layer.gate.tid2eid - - # Get input_ids (with DP gather, mirror _hash_topk semantics). - fwd_ctx = get_forward_context() - ids = fwd_ctx.context.input_ids.flatten() - num_tokens = router_logits.shape[0] - if ids.shape[0] < num_tokens: - ids_2d = ids.unsqueeze(-1) - ids_2d, _ = pad_for_all_gather(ids_2d) - from aiter.dist.parallel_state import get_dp_group - - ids_2d = get_dp_group().all_gather(ids_2d, dim=0) - ids = ids_2d[:num_tokens].flatten() - ids = ids.clamp(0, tid2eid.shape[0] - 1) + _hash_layer = ( + custom_routing_function is not None + and hasattr(custom_routing_function, "__self__") + and hasattr( + getattr(custom_routing_function.__self__, "gate", None), + "tid2eid", + ) + ) + M = x.shape[-2] + block_m = recommend_block_m(M) n_expts_tot = router_logits.shape[-1] if global_num_experts > 0: n_expts_tot = global_num_experts n_expts_tot = n_expts_tot + layer.num_fused_shared_experts + if _hash_layer: + moe_layer = custom_routing_function.__self__ + tid2eid = moe_layer.gate.tid2eid + fwd_ctx = get_forward_context() + ids = fwd_ctx.context.input_ids.flatten() + num_tokens = router_logits.shape[0] + if ids.shape[0] < num_tokens: + ids_2d = ids.unsqueeze(-1) + ids_2d, _ = pad_for_all_gather(ids_2d) + from aiter.dist.parallel_state import get_dp_group + + ids_2d = get_dp_group().all_gather(ids_2d, dim=0) + ids = ids_2d[:num_tokens].flatten() + ids = ids.clamp(0, tid2eid.shape[0] - 1) + + aiter_routing, gather_src, scatter_src = routing_a8w4_from_hash( + router_logits, + tid2eid, + ids, + n_expts_act=top_k, + block_m=block_m, + score_mode="sqrtsoftplus", + renorm=renormalize, + routed_scaling_factor=moe_layer.routed_scaling_factor, + ) + else: + aiter_routing, gather_src, scatter_src = routing_a8w4( + router_logits, + n_expts_act=top_k, + block_m=block_m, + score_mode="sqrtsoftplus", + bias=e_score_correction_bias, + renorm=renormalize, + routed_scaling_factor=layer.routed_scaling_factor, + ) - block_m = recommend_block_m(x.shape[-2]) - aiter_routing, gather_src, scatter_src = routing_a8w4_from_hash( - router_logits, - tid2eid, - ids, - n_expts_act=top_k, - block_m=block_m, - score_mode="sqrtsoftplus", - renorm=renormalize, - routed_scaling_factor=moe_layer.routed_scaling_factor, - ) moe_residual = getattr(layer, "_moe_residual_to_fold", None) if moe_residual is not None: layer._moe_residual_was_folded = True - output = torch.empty_like(x) - return _a8w4_fused_experts( - output, - x, - layer.w13_weight, - layer.w2_weight, - layer.w13_weight_scale, - layer.w2_weight_scale, - routing_data=None, - gather_indx=gather_src, - scatter_indx=scatter_src, - topk=top_k, - swiglu_limit=getattr(layer, "swiglu_limit", 0.0), - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - pre_built_aiter_routing=aiter_routing, - residual=moe_residual, - swiglu_fold=getattr(self, "swiglu_fold", False), - x_scale=x_scale, - ) - - _a8w4_triton_routing = ( - _a8w4 - and _A8W4_TRITON_MOE_ROUTING - and scoring_func == "sqrtsoftplus" - and not use_grouped_topk - and custom_routing_function is None - and layer.num_fused_shared_experts == 0 - ) - if _a8w4_triton_routing: - # Step 4: aiter `routing_a8w4` does V4 math + topk + sort + ExptData - # in one Triton pipeline. Skip FusedMoE.select_experts + bridge. - n_expts_tot = router_logits.shape[-1] - if global_num_experts > 0: - n_expts_tot = global_num_experts - n_expts_tot = n_expts_tot + layer.num_fused_shared_experts - M = x.shape[-2] - block_m = recommend_block_m(M) - aiter_routing, gather_src, scatter_src = routing_a8w4( - router_logits, - n_expts_act=top_k, - block_m=block_m, - score_mode="sqrtsoftplus", - bias=e_score_correction_bias, - renorm=renormalize, - routed_scaling_factor=layer.routed_scaling_factor, + output = torch.empty( + *x.shape, dtype=torch.bfloat16, device=x.device ) - # Step 9: V4 single-stream stashes shared-experts output on - # the layer; fold it into reduce_grouped's writeback. None - # for callers that didn't pre-compute shared. - moe_residual = getattr(layer, "_moe_residual_to_fold", None) - if moe_residual is not None: - layer._moe_residual_was_folded = True - output = torch.empty_like(x) return _a8w4_fused_experts( output, x, @@ -1165,7 +1106,6 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, pre_built_aiter_routing=aiter_routing, residual=moe_residual, - swiglu_fold=getattr(self, "swiglu_fold", False), x_scale=x_scale, ) @@ -1193,57 +1133,10 @@ def apply( n_expts_tot = global_num_experts n_expts_tot = n_expts_tot + layer.num_fused_shared_experts - n_gates_pad = topk_weights.shape[0] * topk_weights.shape[1] - _a8w4_topk_fast = ( - _a8w4 - and _A8W4_TRITON_MOE_ROUTING - and _A8W4_TRITON_MOE_HASH_FAST_ROUTING - and layer.num_fused_shared_experts == 0 - and n_gates_pad <= 4096 - ) - if _a8w4_topk_fast: - # Custom-routing a8w4 fast path (DeepSeek-V4 hash layers): - # build aiter RoutingData directly from (topk_weights, topk_ids) - # via `routing_a8w4_from_topk`. Skips matmul_ogs `compute_expt_data`, - # eliminating its histogram memset (`_expt_data_memset`). - # Bounded by `fused_routing_from_topk`'s 4096-NK single-CTA budget; - # prefill exceeds this and falls through to the matmul_ogs path - # where routing overhead is amortised over much larger GEMM work. - - block_m = recommend_block_m(x.shape[-2]) - aiter_routing, gather_src, scatter_src = routing_a8w4_from_topk( - topk_weights, topk_ids, n_expts_tot, block_m - ) - moe_residual = getattr(layer, "_moe_residual_to_fold", None) - if moe_residual is not None: - layer._moe_residual_was_folded = True - output = torch.empty_like(x) - return _a8w4_fused_experts( - output, - x, - layer.w13_weight, - layer.w2_weight, - layer.w13_weight_scale, - layer.w2_weight_scale, - routing_data=None, - gather_indx=gather_src, - scatter_indx=scatter_src, - topk=n_expts_act, - swiglu_limit=getattr(layer, "swiglu_limit", 0.0), - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - pre_built_aiter_routing=aiter_routing, - residual=moe_residual, - swiglu_fold=getattr(self, "swiglu_fold", False), - ) - routing_data, gather_idx, scatter_idx = fused_routing_from_topk_triton( topk_weights, topk_ids, n_expts_tot ) - _a8w4 = getattr(self, "moe_backend", "matmul_ogs") == "a8w4" - output = torch.empty_like(x) _moe_result = triton_kernel_fused_experts( output, @@ -1263,10 +1156,6 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=n_expts_tot, expert_map=expert_map, - w13_weight_scale_a8w4=(layer.w13_weight_scale if _a8w4 else None), - w2_weight_scale_a8w4=(layer.w2_weight_scale if _a8w4 else None), - swiglu_limit_a8w4=getattr(layer, "swiglu_limit", 0.0), - swiglu_fold=getattr(self, "swiglu_fold", False), ) return _moe_result diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index 3b03b0207..baaaa38ba 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -53,6 +53,7 @@ from aiter.ops.triton.fusions.fused_reduce_qk_norm_rope_swa_write import ( fused_reduce_qk_norm_rope_swa_write, ) +from aiter.ops.triton.quant.quant_mxfp8 import fused_flatten_mxfp8_quant from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits from aiter.tuned_gemm import tgemm from aiter.ops.triton.moe.quant_moe import downcast_to_mxfp @@ -139,9 +140,6 @@ _V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" _V4_DISABLE_FCAM = os.environ.get("ATOM_V4_DISABLE_FCAM", "0") == "1" _MXFP8_BYPASS_FCAM = os.environ.get("ATOM_MXFP8_BYPASS_FCAM", "0") == "1" -_A8W4_TRITON_MOE = os.environ.get( - "ATOM_MOE_BACKEND", "matmul_ogs" -) == "a8w4" and envs.is_set("ATOM_USE_TRITON_MOE") def _fuse_rmsnorm_mxfp8_quant( @@ -1956,6 +1954,9 @@ def forward_impl( o = o.view(num_tokens, self.n_local_groups, -1) wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) o = torch.einsum("sgd,grd->sgr", o, wo_a) + if _V4_USE_MXFP8: + x_fp8, x_scale = fused_flatten_mxfp8_quant(o) + return self.wo_b(x_fp8, x_scale=x_scale) x = self.wo_b(o.flatten(1)) return x @@ -2253,7 +2254,9 @@ def __init__( if self.experts.quant_method.__class__ is Mxfp4MoEMethod and getattr( self.experts.quant_method, "use_triton", False ): - self._use_a8w4_triton_moe = _A8W4_TRITON_MOE + self._use_a8w4_triton_moe = ( + getattr(self.experts.quant_method, "use_triton_backend", None) == "a8w4" + ) def _hash_topk( self, From e9fd00e6bfd6d2d95bca39391f3a898c7fccd2b9 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Thu, 21 May 2026 22:09:10 +0000 Subject: [PATCH 8/9] ruff --- atom/model_ops/moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 2be269863..109d59185 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1017,7 +1017,7 @@ def apply( if self.use_triton_backend == "a8w4": assert ( layer.num_fused_shared_experts == 0 - ), f"A8W4 Triton MOE does not support fused_shared_experts mode, please set ATOM_MOE_BACKEND=matmul_ogs or ATOM_USE_TRITON_MOE=0" + ), "A8W4 Triton MOE does not support fused_shared_experts mode, please set ATOM_MOE_BACKEND=matmul_ogs or ATOM_USE_TRITON_MOE=0" # DeepSeek-V4 hash-layer fully-fused fast path: replaces the # Python `_hash_topk` + multi-kernel `fused_routing_from_topk` # counting-sort + `compute_expt_data` (with memset) chain with From 7040aa49a12a3bd66cbf1a0c8440b1879e595006 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 26 May 2026 17:32:55 +0000 Subject: [PATCH 9/9] refactor import --- atom/model_ops/layernorm.py | 4 ++-- atom/model_ops/linear.py | 8 ++++---- atom/models/deepseek_v4.py | 10 +++++----- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index cde67fcd2..40348924f 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -18,7 +18,7 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.gated_rmsnorm_fp8_group_quant import gated_rmsnorm_fp8_group_quant from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad -from aiter.ops.triton.quant.quant_mxfp8 import rmsnorm_mxfp8_quant +from aiter.ops.triton.quant.fused_mxfp8_quant import fused_rms_mxfp8_quant from atom.config import QuantizationConfig from atom.model_ops.utils import atom_parameter from atom.quant_spec import LayerQuantConfig @@ -335,7 +335,7 @@ def forward( x2 = x.reshape(-1, x.shape[-1]) else: x2 = x - y, s = rmsnorm_mxfp8_quant(x2, self.weight, self.eps) + y, s = fused_rms_mxfp8_quant(x2, self.weight, self.eps) if x.dim() != 2: y = y.view(*x.shape[:-1], x.shape[-1]) s = s.view(*x.shape[:-1], s.shape[-1]) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 6bee6e4c7..e27f00bf8 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -585,12 +585,12 @@ def forward( self, "_mxfp8_active", False ) if _use_mxfp8: - from aiter.ops.triton.quant.quant_mxfp8 import ( - per_1x32_mxfp8_quant_triton, + from aiter.ops.triton.quant.quant import ( + dynamic_mxfp8_quant, ) if x_scale is None: - x, x_scale = per_1x32_mxfp8_quant_triton(x) + x, x_scale = dynamic_mxfp8_quant(x) elif x_scale.dtype == torch.uint8: pass # caller already emitted MXFP8 1x32 else: @@ -603,7 +603,7 @@ def forward( torch.float32 ).view(sM, sCols, 1) x_bf16 = x_dq.view(Mx, Kx).to(torch.bfloat16) - x, x_scale = per_1x32_mxfp8_quant_triton(x_bf16) + x, x_scale = dynamic_mxfp8_quant(x_bf16) elif x_scale is None: quant_func = self.quant_func if self.quant_type.value == QuantType.per_1x128.value: diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index baaaa38ba..743170646 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -53,7 +53,7 @@ from aiter.ops.triton.fusions.fused_reduce_qk_norm_rope_swa_write import ( fused_reduce_qk_norm_rope_swa_write, ) -from aiter.ops.triton.quant.quant_mxfp8 import fused_flatten_mxfp8_quant +from aiter.ops.triton.quant.fused_mxfp8_quant import fused_flatten_mxfp8_quant from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits from aiter.tuned_gemm import tgemm from aiter.ops.triton.moe.quant_moe import downcast_to_mxfp @@ -133,7 +133,7 @@ _V4_USE_TRITON_FUSION = os.environ.get("ATOM_V4_USE_TRITON_FUSION", "0") == "1" ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION # MXFP8 a8w8 GEMM path (Task #77). When on, q_norm RMSNorm emits FP8 e4m3fn + -# uint8 e8m0 1x32 scales directly (via the Triton rmsnorm_mxfp8_quant path in +# uint8 e8m0 1x32 scales directly (via the Triton fused_rms_mxfp8_quant path in # atom/model_ops/layernorm.py) so wq_b's MXFP8 GEMM consumes them with no # transcode. The helper below is reserved for call sites that need to fuse # the Q-side MXFP8 quant with the K-side bf16 RMSNorm in a single launch. @@ -158,14 +158,14 @@ def _fuse_rmsnorm_mxfp8_quant( K: RMSNorm only, bf16 output, consumed by the downstream RoPE / SWA-write fused kernel which expects bf16 K. - Both halves run in a single Triton launch (dual_rmsnorm_mxfp8_quant) to + Both halves run in a single Triton launch (fused_dual_rmsnorm_mxfp8_quant) to avoid the +6us/layer per-launch overhead of two separate kernels. Returns (qr_fp8, qr_scale_e8m0, kv_bf16). """ - from aiter.ops.triton.quant.quant_mxfp8 import dual_rmsnorm_mxfp8_quant + from aiter.ops.triton.quant.fused_mxfp8_quant import fused_dual_rmsnorm_mxfp8_quant - qr, qr_scale, kv = dual_rmsnorm_mxfp8_quant( + qr, qr_scale, kv = fused_dual_rmsnorm_mxfp8_quant( q_lora, kv_pre, q_norm_weight,