diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py index 8bb1be4c3..184da383e 100644 --- a/atom/model_ops/fused_moe_triton.py +++ b/atom/model_ops/fused_moe_triton.py @@ -218,6 +218,165 @@ def routing_from_topk(topk_weights, topk_ids, n_expts_tot): return routing_data, gather_indx, scatter_indx +def bridge_tk_to_aiter_routing(tk_rd, block_m: int): + """Bridge triton_kernels RoutingData -> aiter RoutingData (for moe_gemm_a8w4). + + triton_kernels carries per-block-m specializations of token_offs_pad / + block_pid_map as dicts keyed by block_m; the aiter variant is specialized + for a single block_m and stores those fields as plain tensors. + """ + from aiter.ops.triton.moe.moe_routing.routing import RoutingData, ExptData + + tk_ed = tk_rd.expt_data + if block_m not in tk_ed.token_offs_pad: + raise ValueError( + f"bridge_tk_to_aiter_routing: block_m={block_m} not present in " + f"tk_rd.expt_data.token_offs_pad. Keys: {sorted(tk_ed.token_offs_pad.keys())}" + ) + aiter_expt_data = ExptData( + hist=tk_ed.hist, + token_offs_raw=tk_ed.token_offs_raw, + token_offs_pad=tk_ed.token_offs_pad[block_m], + block_pid_map=tk_ed.block_pid_map[block_m], + ) + return RoutingData( + block_m=block_m, + gate_scal=tk_rd.gate_scal, + expt_hist=tk_rd.expt_hist, + n_expts_tot=tk_rd.n_expts_tot, + n_expts_act=tk_rd.n_expts_act, + expt_data=aiter_expt_data, + ) + + +def _a8w4_fused_experts( + output_tensor: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, # a8w4 layout: uint8 [E, K/2, 2N], stride(-2)==1 + w2: torch.Tensor, # a8w4 layout: uint8 [E, N/2, K], stride(-2)==1 + w1_scale: torch.Tensor, # swizzled ue8m0 uint8 + w2_scale: torch.Tensor, # swizzled ue8m0 uint8 + routing_data, + gather_indx, + scatter_indx, + topk: int, + swiglu_limit: float, + w1_bias: torch.Tensor | None, + w2_bias: torch.Tensor | None, + apply_router_weight_on_input: bool, + pre_built_aiter_routing=None, # if set, skip bridge and use this aiter routing directly + # Step 9: shared-expert output to fold into GEMM2 reduce_grouped writeback. + residual: torch.Tensor | None = None, + x_scale: torch.Tensor | None = None, + swiglu_fold: bool = True, +) -> torch.Tensor: + """Step 1 minimum a8w4 fused experts. + + Pipeline: downcast_to_mxfp -> moe_gemm_a8w4 (G1, no swiglu) -> + fused_clamp_act_mul (silu_mul, no quant) -> downcast_to_mxfp -> + moe_gemm_a8w4 (G2). No apply_swiglu fold, no out_mx_quant fold, + no residual fold — those are later steps. + """ + from aiter.ops.triton.moe.moe_op_gemm_a8w4 import moe_gemm_a8w4, recommend_block_m + from aiter.ops.triton.moe.quant_moe import downcast_to_mxfp + + M, K = hidden_states.shape[-2:] + BLOCK_M = recommend_block_m(M) + + # Pre-MoE quant: bf16 -> fp8 e4m3 + ue8m0 per-1x32. + if x_scale is None: + x_fp8, x_scale = downcast_to_mxfp(hidden_states, torch.float8_e4m3fn, axis=-1) + else: + x_fp8 = hidden_states + + # Step 4: skip bridge when pre_built_aiter_routing is provided. + if pre_built_aiter_routing is not None: + aiter_routing = pre_built_aiter_routing + gammas = aiter_routing.gate_scal + gather_src_indx = gather_indx # already raw int32 tensor + scatter_src_indx = scatter_indx + else: + aiter_routing = bridge_tk_to_aiter_routing(routing_data, BLOCK_M) + gammas = routing_data.gate_scal if routing_data else None + gather_src_indx = gather_indx.src_indx + scatter_src_indx = scatter_indx.src_indx + + w1_scale = w1_scale.view(torch.uint8) + w2_scale = w2_scale.view(torch.uint8) + + if swiglu_fold: + inter_fp8, inter_scale = moe_gemm_a8w4( + x=x_fp8, + w=w1, + x_scales=x_scale, + w_scales=w1_scale, + bias=w1_bias, + routing_data=aiter_routing, + gather_indx=gather_src_indx, + gammas=gammas if apply_router_weight_on_input else None, + swizzle_mx_scale="CDNA4_SCALE", + apply_swiglu=True, + alpha=1.0, + limit=swiglu_limit, + add_residual=False, + out_dtype=torch.bfloat16, # ignored when out_mx_quant=True + out_mx_quant=True, + ) + else: + # GEMM1: fp8 act x fp4 weight, bf16 output [M*topk, 2N]. + raw_intermediate = moe_gemm_a8w4( + x=x_fp8, + w=w1, + x_scales=x_scale, + w_scales=w1_scale, + bias=w1_bias, + routing_data=aiter_routing, + gather_indx=gather_src_indx, + gammas=gammas if apply_router_weight_on_input else None, + swizzle_mx_scale="CDNA4_SCALE", + apply_swiglu=False, + out_dtype=torch.bfloat16, + ) + raw_2d = raw_intermediate.view( + raw_intermediate.shape[-2], raw_intermediate.shape[-1] + ) + two_N = raw_2d.shape[-1] + inter_bf16 = torch.empty( + (raw_2d.shape[0], two_N // 2), + device=raw_2d.device, + dtype=torch.bfloat16, + ) + fused_clamp_act_mul( + raw_2d, + out=inter_bf16, + swiglu_limit=swiglu_limit, + activation="silu", + dtype_quant=None, + ) + inter_fp8, inter_scale = downcast_to_mxfp( + inter_bf16, torch.float8_e4m3fn, axis=-1 + ) + + # GEMM2: fp8 act x fp4 weight, scatter+combine. + # Step 9: pass `residual` so reduce_grouped folds the routed+shared add + # into the writeback (saves the standalone elementwise add launch). + y = moe_gemm_a8w4( + x=inter_fp8, + w=w2, + x_scales=inter_scale, + w_scales=w2_scale, + bias=w2_bias, + routing_data=aiter_routing, + scatter_indx=scatter_src_indx, + gammas=None if apply_router_weight_on_input else gammas, + swizzle_mx_scale="CDNA4_SCALE", + apply_swiglu=False, + out_dtype=torch.bfloat16, + residual=residual, + ) + return y + + def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: """ Shrink the given tensor and apply the given view to it. This is @@ -299,7 +458,7 @@ def triton_kernel_fused_experts( assert w1_bias is None or w1_bias.dtype == torch.float32 assert w2_bias is None or w2_bias.dtype == torch.float32 - # Shape check, only check non-mxfp4 + # Shape check for matmul_ogs path (a8w4 has different layout) assert hidden_states.ndim == 2 assert hidden_states.shape[-1] == w1.shape[-2] assert w2.shape[-1] == w1.shape[1] diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 14898b200..40348924f 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import os from typing import Optional, Tuple import aiter @@ -17,6 +18,7 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.gated_rmsnorm_fp8_group_quant import gated_rmsnorm_fp8_group_quant from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad +from aiter.ops.triton.quant.fused_mxfp8_quant import fused_rms_mxfp8_quant from atom.config import QuantizationConfig from atom.model_ops.utils import atom_parameter from atom.quant_spec import LayerQuantConfig @@ -25,6 +27,8 @@ from torch import Tensor, nn from torch.overrides import handle_torch_function, has_torch_function_unary +_V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + def silu(input: Tensor, inplace: bool = False) -> Tensor: r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise. @@ -318,6 +322,24 @@ def forward( and x_scale is None and self.quant_type.value in _AITER_RMS_QUANT_TYPE_VALUES ): + # MXFP8 path: when downstream Linear consumes MXFP8 1x32 e8m0 + # scales (ATOM_FP8_BLOCKSCALE_USE_MXFP8=1), emit those directly + # to skip the dequant+requant cascade. Limited to per_1x128 and + # the no-residual case (most q_norm callers). + if ( + self.quant_type.value == _QV_PER_1X128 + and residual is None + and _V4_USE_MXFP8 + ): + if x.dim() != 2: + x2 = x.reshape(-1, x.shape[-1]) + else: + x2 = x + y, s = fused_rms_mxfp8_quant(x2, self.weight, self.eps) + if x.dim() != 2: + y = y.view(*x.shape[:-1], x.shape[-1]) + s = s.view(*x.shape[:-1], s.shape[-1]) + return y, s # Dynamic-scale fused RMSNorm + quant via aiter HIP kernels. # Static FP8 (x_scale provided) stays on the branch above. x, x_scale, residual_out = _aiter_rms_quant( diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 5505b5942..e27f00bf8 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -2,6 +2,7 @@ # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. import logging +import os from functools import partial as functools_partial from typing import Callable, Optional @@ -22,6 +23,10 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.tuned_gemm import tgemm from aiter.utility import fp4_utils +from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import ( + gemm_afp8wfp8, + gemm_afp8wfp8_preshuffle, +) from atom.config import QuantizationConfig, get_current_atom_config from atom.quant_spec import LayerQuantConfig from atom.model_ops.utils import ( @@ -286,11 +291,16 @@ def __init__( torch.empty(self.output_size, 1, dtype=dtypes.fp32) ) elif quant_type == QuantType.per_1x128: + # When MXFP8 a8w8 is active, allocate weight_scale directly as + # float8_e8m0fnu (1 byte) to match the on-disk dtype - avoids + # the lossy fp32 upcast/round-trip at load. + _mxfp8_env = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + _scale_dtype = torch.float8_e8m0fnu if _mxfp8_env else dtypes.fp32 self.weight_scale = atom_parameter( torch.empty( (self.output_size + 127) // 128, (self.input_size + 127) // 128, - dtype=dtypes.fp32, + dtype=_scale_dtype, ) ) elif quant_type == QuantType.per_1x32: @@ -510,9 +520,22 @@ def process_weights_after_loading(self): self.quant_type == QuantType.per_Token and self.params_dtype == dtypes.fp8 ) or self.quant_type == QuantType.per_1x32 - # per_1x128 only needs shuffle when using the preshuffle GEMM path + # per_1x128 only needs shuffle when using a preshuffle GEMM path. + # MXFP8 path: only shuffle if gemm_mxfp8_preshuffle will consume the + # shuffled weight (gated by ATOM_MXFP8_USE_PRESHUFFLE, default on). + # The non-preshuffle gemm_mxfp8 reads scrambled rows otherwise. + _mxfp8_active = ( + self.quant_type == QuantType.per_1x128 + and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + ) + _mxfp8_preshuffle = _mxfp8_active and ( + os.environ.get("ATOM_MXFP8_USE_PRESHUFFLE", "1") == "1" + ) if not need_shuffle and self.quant_type == QuantType.per_1x128: - need_shuffle = envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE + if _mxfp8_active: + need_shuffle = _mxfp8_preshuffle + else: + need_shuffle = envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE if need_shuffle: if self.weight.dim() == 2: shuffle_weights(self.weight) @@ -521,6 +544,30 @@ def process_weights_after_loading(self): if self.quant_type == QuantType.per_1x32: self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data) + # ---- Optional: MXFP8 a8w8 path (dormant, only active when env set) ---- + # On disk V4-Flash stores per_1x128 weight scales as float8_e8m0fnu + # (uint8 powers-of-2) and weights as float8_e4m3fn. With MXFP8 active, + # __init__ allocated weight_scale as float8_e8m0fnu so the loader copied + # bits without any upcast/round-trip. Just expose as uint8 view here + # (gemm_mxfp8 expects uint8) and convert fp8 weight to fn-bits if the + # loader landed it as fnuz. + self._mxfp8_active = ( + self.quant_type == QuantType.per_1x128 + and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" + ) + if self._mxfp8_active: + if self.weight_scale.data.dtype != torch.uint8: + self.weight_scale = atom_parameter( + self.weight_scale.data.view(torch.uint8).contiguous() + ) + if self.weight.data.dtype == torch.float8_e4m3fnuz: + # On-disk format is float8_e4m3fn but the param was allocated + # as fnuz (AMD default). The bit patterns differ - convert via + # value to preserve semantics. + w_val = self.weight.data.float() + w_fn = w_val.clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + self.weight = atom_parameter(w_fn) + @mark_trace def forward( self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16 @@ -533,7 +580,31 @@ def forward( otype=otype, ) else: - if x_scale is None: + # MXFP8 path: route per_1x128 weights through MXFP8 GEMM. + _use_mxfp8 = self.quant_type.value == QuantType.per_1x128.value and getattr( + self, "_mxfp8_active", False + ) + if _use_mxfp8: + from aiter.ops.triton.quant.quant import ( + dynamic_mxfp8_quant, + ) + + if x_scale is None: + x, x_scale = dynamic_mxfp8_quant(x) + elif x_scale.dtype == torch.uint8: + pass # caller already emitted MXFP8 1x32 + else: + # legacy caller emitted FP8 + fp32 1x128 scale. Dequant to + # bf16, then re-quantize to MXFP8 1x32. + Mx, Kx = x.shape + sM, sCols = x_scale.shape + group = Kx // sCols + x_dq = x.to(torch.float32).view(Mx, sCols, group) * x_scale.to( + torch.float32 + ).view(sM, sCols, 1) + x_bf16 = x_dq.view(Mx, Kx).to(torch.bfloat16) + x, x_scale = dynamic_mxfp8_quant(x_bf16) + elif x_scale is None: quant_func = self.quant_func if self.quant_type.value == QuantType.per_1x128.value: # preshuffle GEMM expects column-major x_scale; @@ -578,7 +649,24 @@ def forward( if self.bias is not None: y += self.bias elif self.quant_type.value == QuantType.per_1x128.value: - if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: + if _use_mxfp8: + if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: + y = gemm_afp8wfp8_preshuffle( + x, + self.weight, + x_scale, + self.weight_scale, + dtype=otype, + ) + else: + y = gemm_afp8wfp8( + x, + self.weight, + x_scale, + self.weight_scale, + dtype=otype, + ) + elif envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE: y = gemm_a8w8_blockscale_preshuffle_impl( x, self.weight, diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 92d4698f8..109d59185 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -269,6 +269,7 @@ def apply( fused_shared_experts_scoring_func: Optional[str] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + x_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError @@ -701,6 +702,14 @@ def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): assert has_triton_kernels(), "triton_kernels is not installed" + self.use_triton_backend = None + if self.use_triton: + self.use_triton_backend = os.environ.get("ATOM_MOE_BACKEND", "a8w4") + assert self.use_triton_backend in ( + "matmul_ogs", + "a8w4", + ), f"ATOM_MOE_BACKEND={self.use_triton_backend} is not supported in Mxfp4MoEMethod, set ATOM_MOE_BACKEND to matmul_ogs or a8w4" + def create_weights( self, layer: torch.nn.Module, @@ -830,7 +839,76 @@ def process_weights_after_loading(self, layer): if os.environ.get("ATOM_V4_TORCH_MOE"): return - if self.use_triton: + # Step 1: a8w4 MoE backend weight-layout branch. Reorders W13/W2 into + # the layout moe_gemm_a8w4 expects and swizzles scales. + if self.use_triton_backend == "a8w4": + # Step 3: gate/up interleave for apply_swiglu fold. Off by default. + from aiter.ops.triton.moe.moe_op_gemm_a8w4 import ( + swizzle_scales_gfx950 as swizzle_scales, + ) + + def _interleave_gateup(t): + """t: [E, 2N, *] → reorder dim=1 from [g..u..] to [g,u,g,u,...]""" + E, two_N = t.shape[0], t.shape[1] + N = two_N // 2 + tail = t.shape[2:] + return ( + t.view(E, 2, N, *tail) + .permute(0, 2, 1, *range(3, 3 + len(tail))) + .contiguous() + .view(E, two_N, *tail) + ) + + # W1 weight: view-transpose to [E, K/2, 2N] (no .contiguous()) + w1_w = layer.w13_weight.data.view(torch.uint8) # [E, 2N, K/2] + w1_w = _interleave_gateup(w1_w) + w1_w_kernel = w1_w.transpose(-1, -2) # [E, K/2, 2N], view + assert ( + w1_w_kernel.stride(-2) == 1 + ), "W1: K must be contiguous (do NOT call .contiguous())" + + # W1 scale: transpose + contiguous + swizzle + w1_s = layer.w13_weight_scale.data # [E, 2N, K/32] + w1_s = _interleave_gateup(w1_s) + w1_s_swz_in = w1_s.transpose(-1, -2).contiguous() # [E, K/32, 2N] + w1_s_E, w1_s_SCALE_K, w1_s_N = w1_s_swz_in.shape + w1_s_K = w1_s_SCALE_K * 32 + if w1_s_N % 32 == 0 and w1_s_K % 256 == 0: + w1_s_kernel = swizzle_scales(w1_s_swz_in) + else: + w1_s_kernel = w1_s_swz_in + + # W2 weight: view-transpose to [E, N_per/2, hidden] + w2_w = layer.w2_weight.data.view(torch.uint8) + w2_w_kernel = w2_w.transpose(-1, -2) + assert ( + w2_w_kernel.stride(-2) == 1 + ), "W2: K (=N_per/2 packed) must be contiguous" + + # W2 scale: transpose + contiguous + swizzle + w2_s = layer.w2_weight_scale.data + w2_s_swz_in = w2_s.transpose(-1, -2).contiguous() + w2_s_E, w2_s_SCALE_K, w2_s_N = w2_s_swz_in.shape + w2_s_K = w2_s_SCALE_K * 32 + if w2_s_N % 32 == 0 and w2_s_K % 256 == 0: + w2_s_kernel = swizzle_scales(w2_s_swz_in) + else: + w2_s_kernel = w2_s_swz_in + + del layer.w13_weight + del layer.w2_weight + del layer.w13_weight_scale + del layer.w2_weight_scale + layer.w13_weight = w1_w_kernel + layer.w2_weight = w2_w_kernel + layer.w13_weight_scale = w1_s_kernel + layer.w2_weight_scale = w2_s_kernel + + self.w13_precision_config = None + self.w2_precision_config = None + return + + elif self.use_triton_backend == "matmul_ogs": from atom.model_ops.fused_moe_triton import _swizzle_mxfp4 from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig @@ -917,6 +995,7 @@ def apply( apply_router_weight_on_input: bool = False, fused_shared_experts_scoring_func: Optional[str] = None, activation: ActivationType = ActivationType.Silu, + x_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.use_triton: from atom.model_ops.fused_moe_triton import ( @@ -935,6 +1014,101 @@ def apply( ) if needs_custom_routing: + if self.use_triton_backend == "a8w4": + assert ( + layer.num_fused_shared_experts == 0 + ), "A8W4 Triton MOE does not support fused_shared_experts mode, please set ATOM_MOE_BACKEND=matmul_ogs or ATOM_USE_TRITON_MOE=0" + # DeepSeek-V4 hash-layer fully-fused fast path: replaces the + # Python `_hash_topk` + multi-kernel `fused_routing_from_topk` + # counting-sort + `compute_expt_data` (with memset) chain with + # ONE Triton kernel (hash_routing) + sort_tokens_fused. Same + # 2-kernel shape as the non-hash routing_a8w4 path. + # + # Detection: custom_routing_function is the bound method of a + # MoeLayer whose gate has the tid2eid lookup table. + from atom.model_ops.fused_moe_triton import _a8w4_fused_experts + from aiter.ops.triton.moe.moe_op_gemm_a8w4 import recommend_block_m + from aiter.ops.triton.moe.moe_routing.routing import ( + routing_a8w4, + routing_a8w4_from_hash, + ) + + _hash_layer = ( + custom_routing_function is not None + and hasattr(custom_routing_function, "__self__") + and hasattr( + getattr(custom_routing_function.__self__, "gate", None), + "tid2eid", + ) + ) + M = x.shape[-2] + block_m = recommend_block_m(M) + n_expts_tot = router_logits.shape[-1] + if global_num_experts > 0: + n_expts_tot = global_num_experts + n_expts_tot = n_expts_tot + layer.num_fused_shared_experts + if _hash_layer: + moe_layer = custom_routing_function.__self__ + tid2eid = moe_layer.gate.tid2eid + fwd_ctx = get_forward_context() + ids = fwd_ctx.context.input_ids.flatten() + num_tokens = router_logits.shape[0] + if ids.shape[0] < num_tokens: + ids_2d = ids.unsqueeze(-1) + ids_2d, _ = pad_for_all_gather(ids_2d) + from aiter.dist.parallel_state import get_dp_group + + ids_2d = get_dp_group().all_gather(ids_2d, dim=0) + ids = ids_2d[:num_tokens].flatten() + ids = ids.clamp(0, tid2eid.shape[0] - 1) + + aiter_routing, gather_src, scatter_src = routing_a8w4_from_hash( + router_logits, + tid2eid, + ids, + n_expts_act=top_k, + block_m=block_m, + score_mode="sqrtsoftplus", + renorm=renormalize, + routed_scaling_factor=moe_layer.routed_scaling_factor, + ) + else: + aiter_routing, gather_src, scatter_src = routing_a8w4( + router_logits, + n_expts_act=top_k, + block_m=block_m, + score_mode="sqrtsoftplus", + bias=e_score_correction_bias, + renorm=renormalize, + routed_scaling_factor=layer.routed_scaling_factor, + ) + + moe_residual = getattr(layer, "_moe_residual_to_fold", None) + if moe_residual is not None: + layer._moe_residual_was_folded = True + output = torch.empty( + *x.shape, dtype=torch.bfloat16, device=x.device + ) + return _a8w4_fused_experts( + output, + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + routing_data=None, + gather_indx=gather_src, + scatter_indx=scatter_src, + topk=top_k, + swiglu_limit=getattr(layer, "swiglu_limit", 0.0), + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + pre_built_aiter_routing=aiter_routing, + residual=moe_residual, + x_scale=x_scale, + ) + # Use ATOM's full-featured select_experts for routing, # then triton matmul_ogs for the actual MoE computation. topk_weights, topk_ids = FusedMoE.select_experts( @@ -1404,6 +1578,7 @@ def apply( apply_router_weight_on_input: bool = False, fused_shared_experts_scoring_func: Optional[str] = None, activation: ActivationType = ActivationType.Silu, + x_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Apply compressed-tensors FP8 MoE computation.""" # Select top-k experts using router logits @@ -1760,6 +1935,7 @@ def apply( apply_router_weight_on_input: bool = False, fused_shared_experts_scoring_func: Optional[str] = None, activation: ActivationType = ActivationType.Silu, + x_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -1864,16 +2040,18 @@ def moe_forward( hidden_states: torch.Tensor, router_logits: torch.Tensor, layer_name: str, + hidden_states_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: atom_config = get_current_atom_config() self = atom_config.compilation_config.static_forward_context[layer_name] - return self.forward_impl(hidden_states, router_logits) + return self.forward_impl(hidden_states, router_logits, hidden_states_scale) def moe_forward_fake( hidden_states: torch.Tensor, router_logits: torch.Tensor, layer_name: str, + hidden_states_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -2911,13 +3089,24 @@ def select_experts( return topk_weights, topk_ids - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + hidden_states_scale: torch.Tensor = None, + ): return torch.ops.aiter.moe_forward( - hidden_states, router_logits, self.layer_name + hidden_states, + router_logits, + self.layer_name, + hidden_states_scale=hidden_states_scale, ) def forward_impl_graph( - self, hidden_states: torch.Tensor, router_logits: torch.Tensor + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + hidden_states_scale: torch.Tensor = None, ): # There are three mode # 1. Pure DP mode: only DP is used @@ -2965,6 +3154,7 @@ def forward_impl_graph( fused_shared_experts_scoring_func=self.shared_expert_scoring_func, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, + x_scale=hidden_states_scale, ) # Use reduce_scatter when DP > 1 but not using mori all2all kernels @@ -2985,11 +3175,18 @@ def forward_impl_graph( return final_hidden_states - def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + def forward_impl( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + hidden_states_scale: torch.Tensor = None, + ): assert self.quant_method is not None # cuda graph not supported forward with combine and dispatch if self.use_chunked: - return self.forward_impl_graph(hidden_states, router_logits) + return self.forward_impl_graph( + hidden_states, router_logits, hidden_states_scale=hidden_states_scale + ) # return self.forward_impl_chunked(hidden_states, router_logits) dp_group = get_dp_group() @@ -3019,6 +3216,7 @@ def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) fused_shared_experts_scoring_func=self.shared_expert_scoring_func, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, + x_scale=hidden_states_scale, ) dp_group = get_dp_group() diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index 7bb238941..743170646 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -47,10 +47,16 @@ from aiter.ops.triton.fusions.fused_clamp_act_mul import ( fused_clamp_act_mul, ) +from aiter.ops.triton.gemm.fused.fused_gemm_a16w16_quant_x import ( + fused_gemm_a16w16_quant_x, +) from aiter.ops.triton.fusions.fused_reduce_qk_norm_rope_swa_write import ( fused_reduce_qk_norm_rope_swa_write, ) +from aiter.ops.triton.quant.fused_mxfp8_quant import fused_flatten_mxfp8_quant from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits +from aiter.tuned_gemm import tgemm +from aiter.ops.triton.moe.quant_moe import downcast_to_mxfp from atom.config import ( Config, LayerQuantConfig, @@ -69,7 +75,7 @@ ReplicatedLinear, RowParallelLinear, ) -from atom.model_ops.moe import FusedMoE +from atom.model_ops.moe import FusedMoE, Mxfp4MoEMethod from atom.model_ops.topK import is_rocm_aiter_fusion_shared_expert_enabled from aiter import rope_rotate_activation from atom.model_ops.quant_v4 import act_quant_inplace @@ -126,6 +132,48 @@ # Fused-kernel switches. Default off; flip via env to A/B against the eager path. _V4_USE_TRITON_FUSION = os.environ.get("ATOM_V4_USE_TRITON_FUSION", "0") == "1" ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION +# MXFP8 a8w8 GEMM path (Task #77). When on, q_norm RMSNorm emits FP8 e4m3fn + +# uint8 e8m0 1x32 scales directly (via the Triton fused_rms_mxfp8_quant path in +# atom/model_ops/layernorm.py) so wq_b's MXFP8 GEMM consumes them with no +# transcode. The helper below is reserved for call sites that need to fuse +# the Q-side MXFP8 quant with the K-side bf16 RMSNorm in a single launch. +_V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1" +_V4_DISABLE_FCAM = os.environ.get("ATOM_V4_DISABLE_FCAM", "0") == "1" +_MXFP8_BYPASS_FCAM = os.environ.get("ATOM_MXFP8_BYPASS_FCAM", "0") == "1" + + +def _fuse_rmsnorm_mxfp8_quant( + q_lora: torch.Tensor, + q_norm_weight: torch.Tensor, + q_norm_eps: float, + kv_pre: torch.Tensor, + kv_norm_weight: torch.Tensor, + kv_norm_eps: float, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Dual RMSNorm with MXFP8 emit on Q side, bf16 on K side, single launch. + + Q: RMSNorm + per-1x32 MXFP8 quant -- emits FP8 e4m3fn + uint8 e8m0 scale + directly so the downstream wq_b MXFP8 GEMM consumes them with no + dequant+requant cascade. + K: RMSNorm only, bf16 output, consumed by the downstream RoPE / SWA-write + fused kernel which expects bf16 K. + + Both halves run in a single Triton launch (fused_dual_rmsnorm_mxfp8_quant) to + avoid the +6us/layer per-launch overhead of two separate kernels. + + Returns (qr_fp8, qr_scale_e8m0, kv_bf16). + """ + from aiter.ops.triton.quant.fused_mxfp8_quant import fused_dual_rmsnorm_mxfp8_quant + + qr, qr_scale, kv = fused_dual_rmsnorm_mxfp8_quant( + q_lora, + kv_pre, + q_norm_weight, + kv_norm_weight, + q_norm_eps, + kv_norm_eps, + ) + return qr, qr_scale, kv def _rmsnorm_nw(x: torch.Tensor, eps: float, dim: int) -> torch.Tensor: @@ -240,6 +288,44 @@ def fused_qk_norm_rope_swa_write( return q_out +def _fused_router_gate_a16w16_quant_fake( + x: torch.Tensor, + gate_weight: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_tokens, K = x.shape + n_routed_experts = gate_weight.shape[0] + router_logits = torch.empty( + (num_tokens, n_routed_experts), dtype=torch.bfloat16, device=x.device + ) + x_out = torch.empty((num_tokens, K), dtype=dtypes.fp8, device=x.device) + x_scales = torch.empty((num_tokens, K // 32), dtype=torch.uint8, device=x.device) + return router_logits, x_out, x_scales + + +@torch_compile_guard(gen_fake=_fused_router_gate_a16w16_quant_fake, mutates_args=[]) +def fused_router_gate_a16w16_quant( + x: torch.Tensor, + gate_weight: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """A16W16 router-gate GEMM, optionally fused with a BF16->FP8 downcast + copy of x. + + For decode-shaped batches (num_tokens <= 64) the GEMM and the activation + downcast run in a single triton kernel; the FP8 copy is consumed by the + downstream fused_moe kernel, replacing its internal + `hidden_states.to(dtypes.fp8)` cast with a copy done inline with the + gate GEMM's A-loads. For larger batches the fusion's per-element copy + overhead is not worth the savings, so we fall back to tgemm.mm and let + aiter re-trigger the `.to(fp8)` cast on its side. + """ + num_tokens = x.shape[0] + if num_tokens <= 64: + return fused_gemm_a16w16_quant_x(x, gate_weight, quant_dtype=dtypes.fp8) + router_logits = tgemm.mm(x, gate_weight, None, otype=torch.bfloat16) + x, x_scale = downcast_to_mxfp(x, dtypes.fp8, axis=-1) + return router_logits, x, x_scale + + def _make_weightless_rmsnorm(dim: int, eps: float) -> RMSNorm: """Build an `RMSNorm(dim, eps)` whose `.weight` is `None`. @@ -1868,6 +1954,9 @@ def forward_impl( o = o.view(num_tokens, self.n_local_groups, -1) wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) o = torch.einsum("sgd,grd->sgr", o, wo_a) + if _V4_USE_MXFP8: + x_fp8, x_scale = fused_flatten_mxfp8_quant(o) + return self.wo_b(x_fp8, x_scale=x_scale) x = self.wo_b(o.flatten(1)) return x @@ -1972,7 +2061,9 @@ def __init__( # Switch: route clamp + silu(gate)*up [+ weights] + per-token FP8 1x128 # quant through a single aiter triton kernel. The fused kernel emits # FP8 + scale; w2 accepts `x_scale` and skips its own quant step. - self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION + # ATOM_V4_DISABLE_FCAM=1 forces the unfused eager path for A/B vs the + # Triton fused-clamp-act-mul kernel. + self.use_fused_clamp_act_mul = _V4_USE_TRITON_FUSION and not _V4_DISABLE_FCAM def forward( self, @@ -1987,14 +2078,29 @@ def forward( # feed the bf16 GEMM output directly. combined = self.gate_up_proj(x) # [num_tokens, 2*inter_dim_per_tp] if self.use_fused_clamp_act_mul: - x_fp8, x_scale = fused_clamp_act_mul( - combined, - swiglu_limit=self.swiglu_limit, - activation="silu", - weights=weights, - dtype_quant=dtypes.fp8, - transpose_scale=True, - ) + # When MXFP8 is enabled, ATOM_MXFP8_BYPASS_FCAM=1 forces the + # legacy fp32 1x128 emit + linear.py's dequant+requant fallback, + # isolating whether fused_clamp_act_mul's MXFP8 emit has a bug. + if _V4_USE_MXFP8 and not _MXFP8_BYPASS_FCAM: + x_fp8, x_scale = fused_clamp_act_mul( + combined, + swiglu_limit=self.swiglu_limit, + activation="silu", + weights=weights, + dtype_quant=torch.float8_e4m3fn, + quant_block_size=32, + scale_dtype_fmt="ue8m0", + transpose_scale=False, + ) + else: + x_fp8, x_scale = fused_clamp_act_mul( + combined, + swiglu_limit=self.swiglu_limit, + activation="silu", + weights=weights, + dtype_quant=dtypes.fp8, + transpose_scale=True, + ) return self.w2(x_fp8, x_scale=x_scale) out = torch.empty( (combined.shape[0], combined.shape[-1] // 2), @@ -2144,6 +2250,14 @@ def __init__( prefix ] = self + self._use_a8w4_triton_moe = False + if self.experts.quant_method.__class__ is Mxfp4MoEMethod and getattr( + self.experts.quant_method, "use_triton", False + ): + self._use_a8w4_triton_moe = ( + getattr(self.experts.quant_method, "use_triton_backend", None) == "a8w4" + ) + def _hash_topk( self, hidden_states: torch.Tensor, @@ -2204,6 +2318,15 @@ def routed_expert_forward( `_hash_topk` (FusedMoE's custom_routing_function) reads it there. """ router_logits = self.gate(x) # [num_tokens, n_routed_experts] + if self._use_a8w4_triton_moe: + router_logits, x, x_scale = fused_router_gate_a16w16_quant( + x, self.gate.weight + ) + return self.experts( + hidden_states=x, + router_logits=router_logits, + hidden_states_scale=x_scale, + ) return self.experts(hidden_states=x, router_logits=router_logits) def combine_outputs( @@ -2223,9 +2346,31 @@ def combine_outputs( def single_stream_moe_forward( self, x: torch.Tensor # [num_tokens, dim] ) -> torch.Tensor: # [num_tokens, dim] - """Sequential: shared_experts → routed_experts → combine.""" + """Sequential: shared_experts → routed_experts → combine. + + Step 9: when the MoE method supports it (a8w4 triton-routing path), + the `routed + shared` add is folded into reduce_grouped's writeback by + stashing `shared` on `self.experts` before `routed_expert_forward`. + Saves the standalone elementwise add kernel. Gated by + ATOM_A8W4_FUSE_RESIDUAL (default 1) for easy A/B. + """ shared = self.shared_experts(x) if self.shared_experts is not None else None - routed = self.routed_expert_forward(x) + fuse_residual = ( + shared is not None and os.environ.get("ATOM_A8W4_FUSE_RESIDUAL", "1") == "1" + ) + if fuse_residual: + self.experts._moe_residual_to_fold = shared + try: + routed = self.routed_expert_forward(x) + finally: + if fuse_residual: + self.experts._moe_residual_to_fold = None + folded = getattr(self.experts, "_moe_residual_was_folded", False) + self.experts._moe_residual_was_folded = False # one-shot + if folded: + if self.tp_size > 1: + routed = tensor_model_parallel_all_reduce(routed) + return routed return self.combine_outputs(routed, shared) def dual_stream_moe_forward(