Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 160 additions & 1 deletion atom/model_ops/fused_moe_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,165 @@ def routing_from_topk(topk_weights, topk_ids, n_expts_tot):
return routing_data, gather_indx, scatter_indx


def bridge_tk_to_aiter_routing(tk_rd, block_m: int):
"""Bridge triton_kernels RoutingData -> aiter RoutingData (for moe_gemm_a8w4).

triton_kernels carries per-block-m specializations of token_offs_pad /
block_pid_map as dicts keyed by block_m; the aiter variant is specialized
for a single block_m and stores those fields as plain tensors.
"""
from aiter.ops.triton.moe.moe_routing.routing import RoutingData, ExptData

tk_ed = tk_rd.expt_data
if block_m not in tk_ed.token_offs_pad:
raise ValueError(
f"bridge_tk_to_aiter_routing: block_m={block_m} not present in "
f"tk_rd.expt_data.token_offs_pad. Keys: {sorted(tk_ed.token_offs_pad.keys())}"
)
aiter_expt_data = ExptData(
hist=tk_ed.hist,
token_offs_raw=tk_ed.token_offs_raw,
token_offs_pad=tk_ed.token_offs_pad[block_m],
block_pid_map=tk_ed.block_pid_map[block_m],
)
return RoutingData(
block_m=block_m,
gate_scal=tk_rd.gate_scal,
expt_hist=tk_rd.expt_hist,
n_expts_tot=tk_rd.n_expts_tot,
n_expts_act=tk_rd.n_expts_act,
expt_data=aiter_expt_data,
)


def _a8w4_fused_experts(
output_tensor: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor, # a8w4 layout: uint8 [E, K/2, 2N], stride(-2)==1
w2: torch.Tensor, # a8w4 layout: uint8 [E, N/2, K], stride(-2)==1
w1_scale: torch.Tensor, # swizzled ue8m0 uint8
w2_scale: torch.Tensor, # swizzled ue8m0 uint8
routing_data,
gather_indx,
scatter_indx,
topk: int,
swiglu_limit: float,
w1_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None,
apply_router_weight_on_input: bool,
pre_built_aiter_routing=None, # if set, skip bridge and use this aiter routing directly
# Step 9: shared-expert output to fold into GEMM2 reduce_grouped writeback.
residual: torch.Tensor | None = None,
x_scale: torch.Tensor | None = None,
swiglu_fold: bool = True,
) -> torch.Tensor:
"""Step 1 minimum a8w4 fused experts.

Pipeline: downcast_to_mxfp -> moe_gemm_a8w4 (G1, no swiglu) ->
fused_clamp_act_mul (silu_mul, no quant) -> downcast_to_mxfp ->
moe_gemm_a8w4 (G2). No apply_swiglu fold, no out_mx_quant fold,
no residual fold — those are later steps.
"""
from aiter.ops.triton.moe.moe_op_gemm_a8w4 import moe_gemm_a8w4, recommend_block_m
from aiter.ops.triton.moe.quant_moe import downcast_to_mxfp

M, K = hidden_states.shape[-2:]
BLOCK_M = recommend_block_m(M)

# Pre-MoE quant: bf16 -> fp8 e4m3 + ue8m0 per-1x32.
if x_scale is None:
x_fp8, x_scale = downcast_to_mxfp(hidden_states, torch.float8_e4m3fn, axis=-1)
else:
x_fp8 = hidden_states

# Step 4: skip bridge when pre_built_aiter_routing is provided.
if pre_built_aiter_routing is not None:
aiter_routing = pre_built_aiter_routing
gammas = aiter_routing.gate_scal
gather_src_indx = gather_indx # already raw int32 tensor
scatter_src_indx = scatter_indx
else:
aiter_routing = bridge_tk_to_aiter_routing(routing_data, BLOCK_M)
gammas = routing_data.gate_scal if routing_data else None
gather_src_indx = gather_indx.src_indx
scatter_src_indx = scatter_indx.src_indx

w1_scale = w1_scale.view(torch.uint8)
w2_scale = w2_scale.view(torch.uint8)

if swiglu_fold:
inter_fp8, inter_scale = moe_gemm_a8w4(
x=x_fp8,
w=w1,
x_scales=x_scale,
w_scales=w1_scale,
bias=w1_bias,
routing_data=aiter_routing,
gather_indx=gather_src_indx,
gammas=gammas if apply_router_weight_on_input else None,
swizzle_mx_scale="CDNA4_SCALE",
apply_swiglu=True,
alpha=1.0,
limit=swiglu_limit,
add_residual=False,
out_dtype=torch.bfloat16, # ignored when out_mx_quant=True
out_mx_quant=True,
)
else:
# GEMM1: fp8 act x fp4 weight, bf16 output [M*topk, 2N].
raw_intermediate = moe_gemm_a8w4(
x=x_fp8,
w=w1,
x_scales=x_scale,
w_scales=w1_scale,
bias=w1_bias,
routing_data=aiter_routing,
gather_indx=gather_src_indx,
gammas=gammas if apply_router_weight_on_input else None,
swizzle_mx_scale="CDNA4_SCALE",
apply_swiglu=False,
out_dtype=torch.bfloat16,
)
raw_2d = raw_intermediate.view(
raw_intermediate.shape[-2], raw_intermediate.shape[-1]
)
two_N = raw_2d.shape[-1]
inter_bf16 = torch.empty(
(raw_2d.shape[0], two_N // 2),
device=raw_2d.device,
dtype=torch.bfloat16,
)
fused_clamp_act_mul(
raw_2d,
out=inter_bf16,
swiglu_limit=swiglu_limit,
activation="silu",
dtype_quant=None,
)
inter_fp8, inter_scale = downcast_to_mxfp(
inter_bf16, torch.float8_e4m3fn, axis=-1
)

# GEMM2: fp8 act x fp4 weight, scatter+combine.
# Step 9: pass `residual` so reduce_grouped folds the routed+shared add
# into the writeback (saves the standalone elementwise add launch).
y = moe_gemm_a8w4(
x=inter_fp8,
w=w2,
x_scales=inter_scale,
w_scales=w2_scale,
bias=w2_bias,
routing_data=aiter_routing,
scatter_indx=scatter_src_indx,
gammas=None if apply_router_weight_on_input else gammas,
swizzle_mx_scale="CDNA4_SCALE",
apply_swiglu=False,
out_dtype=torch.bfloat16,
residual=residual,
)
return y


def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
"""
Shrink the given tensor and apply the given view to it. This is
Expand Down Expand Up @@ -299,7 +458,7 @@ def triton_kernel_fused_experts(
assert w1_bias is None or w1_bias.dtype == torch.float32
assert w2_bias is None or w2_bias.dtype == torch.float32

# Shape check, only check non-mxfp4
# Shape check for matmul_ogs path (a8w4 has different layout)
assert hidden_states.ndim == 2
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]
Expand Down
22 changes: 22 additions & 0 deletions atom/model_ops/layernorm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import os
from typing import Optional, Tuple

import aiter
Expand All @@ -17,6 +18,7 @@
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.ops.gated_rmsnorm_fp8_group_quant import gated_rmsnorm_fp8_group_quant
from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
from aiter.ops.triton.quant.fused_mxfp8_quant import fused_rms_mxfp8_quant
from atom.config import QuantizationConfig
from atom.model_ops.utils import atom_parameter
from atom.quant_spec import LayerQuantConfig
Expand All @@ -25,6 +27,8 @@
from torch import Tensor, nn
from torch.overrides import handle_torch_function, has_torch_function_unary

_V4_USE_MXFP8 = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1"


def silu(input: Tensor, inplace: bool = False) -> Tensor:
r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise.
Expand Down Expand Up @@ -318,6 +322,24 @@ def forward(
and x_scale is None
and self.quant_type.value in _AITER_RMS_QUANT_TYPE_VALUES
):
# MXFP8 path: when downstream Linear consumes MXFP8 1x32 e8m0
# scales (ATOM_FP8_BLOCKSCALE_USE_MXFP8=1), emit those directly
# to skip the dequant+requant cascade. Limited to per_1x128 and
# the no-residual case (most q_norm callers).
if (
self.quant_type.value == _QV_PER_1X128
and residual is None
and _V4_USE_MXFP8
):
if x.dim() != 2:
x2 = x.reshape(-1, x.shape[-1])
else:
x2 = x
y, s = fused_rms_mxfp8_quant(x2, self.weight, self.eps)
if x.dim() != 2:
y = y.view(*x.shape[:-1], x.shape[-1])
s = s.view(*x.shape[:-1], s.shape[-1])
return y, s
# Dynamic-scale fused RMSNorm + quant via aiter HIP kernels.
# Static FP8 (x_scale provided) stays on the branch above.
x, x_scale, residual_out = _aiter_rms_quant(
Expand Down
98 changes: 93 additions & 5 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import logging
import os
from functools import partial as functools_partial
from typing import Callable, Optional

Expand All @@ -22,6 +23,10 @@
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.tuned_gemm import tgemm
from aiter.utility import fp4_utils
from aiter.ops.triton.gemm.basic.gemm_afp8wfp8 import (
gemm_afp8wfp8,
gemm_afp8wfp8_preshuffle,
)
from atom.config import QuantizationConfig, get_current_atom_config
from atom.quant_spec import LayerQuantConfig
from atom.model_ops.utils import (
Expand Down Expand Up @@ -286,11 +291,16 @@ def __init__(
torch.empty(self.output_size, 1, dtype=dtypes.fp32)
)
elif quant_type == QuantType.per_1x128:
# When MXFP8 a8w8 is active, allocate weight_scale directly as
# float8_e8m0fnu (1 byte) to match the on-disk dtype - avoids
# the lossy fp32 upcast/round-trip at load.
_mxfp8_env = os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1"
_scale_dtype = torch.float8_e8m0fnu if _mxfp8_env else dtypes.fp32
self.weight_scale = atom_parameter(
torch.empty(
(self.output_size + 127) // 128,
(self.input_size + 127) // 128,
dtype=dtypes.fp32,
dtype=_scale_dtype,
)
)
elif quant_type == QuantType.per_1x32:
Expand Down Expand Up @@ -510,9 +520,22 @@ def process_weights_after_loading(self):
self.quant_type == QuantType.per_Token
and self.params_dtype == dtypes.fp8
) or self.quant_type == QuantType.per_1x32
# per_1x128 only needs shuffle when using the preshuffle GEMM path
# per_1x128 only needs shuffle when using a preshuffle GEMM path.
# MXFP8 path: only shuffle if gemm_mxfp8_preshuffle will consume the
# shuffled weight (gated by ATOM_MXFP8_USE_PRESHUFFLE, default on).
# The non-preshuffle gemm_mxfp8 reads scrambled rows otherwise.
_mxfp8_active = (
self.quant_type == QuantType.per_1x128
and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1"
)
_mxfp8_preshuffle = _mxfp8_active and (
os.environ.get("ATOM_MXFP8_USE_PRESHUFFLE", "1") == "1"
)
if not need_shuffle and self.quant_type == QuantType.per_1x128:
need_shuffle = envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE
if _mxfp8_active:
need_shuffle = _mxfp8_preshuffle
else:
need_shuffle = envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE
if need_shuffle:
if self.weight.dim() == 2:
shuffle_weights(self.weight)
Expand All @@ -521,6 +544,30 @@ def process_weights_after_loading(self):
if self.quant_type == QuantType.per_1x32:
self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data)

# ---- Optional: MXFP8 a8w8 path (dormant, only active when env set) ----
# On disk V4-Flash stores per_1x128 weight scales as float8_e8m0fnu
# (uint8 powers-of-2) and weights as float8_e4m3fn. With MXFP8 active,
# __init__ allocated weight_scale as float8_e8m0fnu so the loader copied
# bits without any upcast/round-trip. Just expose as uint8 view here
# (gemm_mxfp8 expects uint8) and convert fp8 weight to fn-bits if the
# loader landed it as fnuz.
self._mxfp8_active = (
self.quant_type == QuantType.per_1x128
and os.environ.get("ATOM_FP8_BLOCKSCALE_USE_MXFP8", "0") == "1"
)
if self._mxfp8_active:
if self.weight_scale.data.dtype != torch.uint8:
self.weight_scale = atom_parameter(
self.weight_scale.data.view(torch.uint8).contiguous()
)
if self.weight.data.dtype == torch.float8_e4m3fnuz:
# On-disk format is float8_e4m3fn but the param was allocated
# as fnuz (AMD default). The bit patterns differ - convert via
# value to preserve semantics.
w_val = self.weight.data.float()
w_fn = w_val.clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
self.weight = atom_parameter(w_fn)

@mark_trace
def forward(
self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16
Expand All @@ -533,7 +580,31 @@ def forward(
otype=otype,
)
else:
if x_scale is None:
# MXFP8 path: route per_1x128 weights through MXFP8 GEMM.
_use_mxfp8 = self.quant_type.value == QuantType.per_1x128.value and getattr(
self, "_mxfp8_active", False
)
if _use_mxfp8:
from aiter.ops.triton.quant.quant import (
dynamic_mxfp8_quant,
)

if x_scale is None:
x, x_scale = dynamic_mxfp8_quant(x)
elif x_scale.dtype == torch.uint8:
pass # caller already emitted MXFP8 1x32
else:
# legacy caller emitted FP8 + fp32 1x128 scale. Dequant to
# bf16, then re-quantize to MXFP8 1x32.
Mx, Kx = x.shape
sM, sCols = x_scale.shape
group = Kx // sCols
x_dq = x.to(torch.float32).view(Mx, sCols, group) * x_scale.to(
torch.float32
).view(sM, sCols, 1)
x_bf16 = x_dq.view(Mx, Kx).to(torch.bfloat16)
x, x_scale = dynamic_mxfp8_quant(x_bf16)
elif x_scale is None:
quant_func = self.quant_func
if self.quant_type.value == QuantType.per_1x128.value:
# preshuffle GEMM expects column-major x_scale;
Expand Down Expand Up @@ -578,7 +649,24 @@ def forward(
if self.bias is not None:
y += self.bias
elif self.quant_type.value == QuantType.per_1x128.value:
if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE:
if _use_mxfp8:
if envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE:
y = gemm_afp8wfp8_preshuffle(
x,
self.weight,
x_scale,
self.weight_scale,
dtype=otype,
)
else:
y = gemm_afp8wfp8(
x,
self.weight,
x_scale,
self.weight_scale,
dtype=otype,
)
elif envs.ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE:
y = gemm_a8w8_blockscale_preshuffle_impl(
x,
self.weight,
Expand Down
Loading