Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions aiter/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.ops.flydsl.utils import is_flydsl_available
from aiter.ops.flydsl.moe_common import GateMode
from aiter import fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd
from aiter import (
fused_dynamic_mxfp4_quant_moe_sort,
fused_dynamic_mxfp8_quant_moe_sort,
mxfp4_moe_sort_fwd,
)

# Module-wide M block size constant.
BLOCK_SIZE_M = 32

# Default to Opus unless CK sorting is explicitly requested
# (AITER_USE_CK_MOE_SORTING=1).
_USE_CK_MOE_SORTING = os.getenv("AITER_USE_CK_MOE_SORTING", "0") == "1"
_ACT_TYPE_DISABLED_KEY = "__ignore__"
# Integer bound read from GPTOSS_SWIGLU_MXFP4_BF16_BOUND (defaults to 256).
_SWIGLU_MXFP4_BF16_BOUND = int(os.getenv("GPTOSS_SWIGLU_MXFP4_BF16_BOUND", "256"))
# Debug switch (AITER_MOE_A8W4_BYPASS_QUANT=1): the a8w4 path skips the real
# activation quantization when this is set.
_MOE_A8W4_BYPASS_QUANT = os.getenv("AITER_MOE_A8W4_BYPASS_QUANT", "0") == "1"


def _moe_sorting_impl(
Expand Down Expand Up @@ -1549,19 +1554,35 @@ def fused_moe_2stages(
a1 = hidden_states.to(dtype)
a1_scale = None
elif (
quant_type == aiter.QuantType.per_1x32
quant_type == QuantType.per_1x32
and dtype in [dtypes.bf16, dtypes.fp16]
and q_dtype_a == dtypes.fp8
and w1.dtype == dtypes.fp4x2
# and activation == aiter.ActivationType.Swiglu
):
a1 = hidden_states.to(dtypes.fp8)
M = sorted_ids.shape[0]
N = a1.shape[-1]
# a8w4 mxfp8 activations + mxfp4 weights.
if metadata.fuse_quant == "fp8":
# FlyDSL stage1 fuses the fp8 quant of a1 internally, but the
# kernel dispatch requires an fp8-typed tensor.
a1 = hidden_states.to(dtypes.fp8)
a1_scale = torch.empty(0, dtype=torch.uint8, device=a1.device)
elif _MOE_A8W4_BYPASS_QUANT:
# Debug bypass: skip real quant, feed unit scales.
a1 = hidden_states.to(dtypes.fp8)
M = sorted_ids.shape[0]
a1_scale = torch.ones(
[M, a1.shape[-1] // 32], dtype=dtypes.fp8_e8m0, device=a1.device
)
else:
a1_scale = torch.ones([M, N // 32], dtype=dtypes.fp8_e8m0, device=a1.device)
# stage1 input is not topk-replicated, so M==token_num and the HIP
# launcher infers TOPK=1 from input.numel() / (cols * token_num).
a1, a1_scale = fused_dynamic_mxfp8_quant_moe_sort(
hidden_states,
sorted_ids=sorted_ids,
num_valid_ids=num_valid_ids,
token_num=token_num,
topk=topk,
block_size=block_size_M,
)

elif quant_type == QuantType.per_1x32 and w1.dtype == dtypes.i4x2:
# a16wi4: bf16 activations, int4 weights; no activation quantization needed
Expand Down Expand Up @@ -1676,17 +1697,15 @@ def fused_moe_2stages(
and q_dtype_a == dtypes.fp8
and w1.dtype == dtypes.fp4x2
):
if activation == ActivationType.Silu and swiglu_limit == 0.0:
from aiter.ops.triton.quant.fused_mxfp4_quant import fused_quant_fp8_sort

if not _MOE_A8W4_BYPASS_QUANT:
a2 = a2.view(-1, inter_dim)
a2, a2_scale = fused_quant_fp8_sort(
a2, a2_scale = fused_dynamic_mxfp8_quant_moe_sort(
a2,
sorted_ids=sorted_ids,
num_valid_ids=num_valid_ids,
token_num=token_num,
block_size=32,
quant_dtype=dtypes.fp8,
topk=topk,
block_size=block_size_M,
)
a2 = a2.view(token_num, topk, -1)
else:
Expand Down
46 changes: 43 additions & 3 deletions aiter/ops/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def mxfp4_moe_sort_fwd(


@compile_ops("module_quant", develop=True)
def fused_dynamic_mxfp4_quant_moe_sort_hip(
def fused_dynamic_mx_quant_moe_sort_hip(
out: torch.Tensor,
scales: torch.Tensor,
input: torch.Tensor,
Expand All @@ -679,7 +679,9 @@ def fused_dynamic_mxfp4_quant_moe_sort_hip(
group_size: int = 32,
) -> None:
"""
HIP path for fused dynamic MXFP4 quantization and MoE scale sorting.
HIP path for fused dynamic MX (fp4 or fp8) quantization and MoE scale
sorting. The output dtype of ``out`` selects the quant target: fp4x2/uint8
for MXFP4, fp8 for MXFP8.
"""
...

Expand Down Expand Up @@ -779,7 +781,7 @@ def fused_dynamic_mxfp4_quant_moe_sort(
or group_size != 32
):
out = torch.empty(M, N // 2, dtype=dtypes.fp4x2, device=input.device)
fused_dynamic_mxfp4_quant_moe_sort_hip(
fused_dynamic_mx_quant_moe_sort_hip(
out,
scale,
input,
Expand All @@ -797,6 +799,44 @@ def fused_dynamic_mxfp4_quant_moe_sort(
return out, scale


def fused_dynamic_mxfp8_quant_moe_sort(
    input: torch.Tensor,
    sorted_ids: torch.Tensor,
    num_valid_ids: torch.Tensor,
    token_num: int,
    topk: int,
    block_size: int,
    group_size: int = 32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` to MXFP8 and produce MoE-sorted e8m0 scales.

    HIP replacement for Triton fused_quant_fp8_sort.

    Args:
        input: Activations; flattened to 2-D ``(M, N)`` over the last dim.
        sorted_ids: MoE sorted-id tensor; its first dimension determines the
            number of scale rows (rounded up to a multiple of 32).
        num_valid_ids: Forwarded unchanged to the HIP launcher.
        token_num: Forwarded unchanged to the HIP launcher.
        topk: Accepted for parity with the fp4 wrapper but not forwarded; it
            is inferred inside the HIP launcher from
            ``input.numel() / (cols * token_num)``.
        block_size: Forwarded unchanged to the HIP launcher.
        group_size: Number of columns sharing one scale (default 32).

    Returns:
        ``(fp8_out, e8m0_scale)``: ``fp8_out`` has shape ``(M, N)`` and the
        scale tensor is laid out as ``(pad32(M_o), N_o)`` fp8_e8m0 — the same
        byte layout the FlyDSL stage1/stage2 GEMM consumes.
    """
    M, N = input.view(-1, input.shape[-1]).shape
    # One scale per `group_size` columns. This must follow the group size that
    # is handed to the launcher (which uses ceil(cols / group_size)); a
    # hard-coded 32 would mis-size the buffer for any other group size.
    scale_cols = (N + group_size - 1) // group_size
    # Scale rows are padded up to a multiple of 32 (the pad32(M_o) above).
    scale_rows = (sorted_ids.shape[0] + 31) // 32 * 32

    out = torch.empty(M, N, dtype=dtypes.fp8, device=input.device)
    scale = torch.empty(
        scale_rows,
        scale_cols,
        dtype=dtypes.fp8_e8m0,
        device=input.device,
    )
    # The fp8 dtype of `out` selects the MXFP8 path in the shared HIP op.
    fused_dynamic_mx_quant_moe_sort_hip(
        out,
        scale,
        input,
        sorted_ids,
        num_valid_ids,
        token_num,
        block_size,
        group_size,
    )
    return out, scale


@compile_ops("module_quant", develop=True)
def partial_transpose(
out: Tensor,
Expand Down
2 changes: 1 addition & 1 deletion csrc/include/quant.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void moe_smooth_per_token_scaled_quant_v2(aiter_tensor_t& out, // [...,
bool shuffle_scale = false,
bool transpose_out = false);

void fused_dynamic_mxfp4_quant_moe_sort_hip(aiter_tensor_t& out, // [token_num * topk, d / 2]
void fused_dynamic_mx_quant_moe_sort_hip(aiter_tensor_t& out, // [token_num * topk, d] for fp8 or [token_num * topk, d / 2] for fp4
aiter_tensor_t& scales, // swizzled e8m0 bytes
const aiter_tensor_t& input, // [token_num * topk, d]
const aiter_tensor_t& sorted_ids,
Expand Down
4 changes: 2 additions & 2 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1411,8 +1411,8 @@ namespace py = pybind11;
py::arg("block_m"), \
py::arg("shuffle_scale") = false, \
py::arg("transpose_out") = false); \
m.def("fused_dynamic_mxfp4_quant_moe_sort_hip", \
&aiter::fused_dynamic_mxfp4_quant_moe_sort_hip, \
m.def("fused_dynamic_mx_quant_moe_sort_hip", \
&aiter::fused_dynamic_mx_quant_moe_sort_hip, \
py::arg("out"), \
py::arg("scales"), \
py::arg("input"), \
Expand Down
40 changes: 30 additions & 10 deletions csrc/kernels/quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1596,10 +1596,25 @@ __global__ void mxfp4_quant_moe_sort_kernel(
: (sizeof(DTYPE_I) * vec_size_i % 8 == 0 ? 8 : 4));
using vec_i = opus::vector_t<DTYPE_I, vec_size_i>;
using vec_f = opus::vector_t<float, vec_size_i>;
// For e8m0-scaled dtypes (fp4, fp8) use 1 / floor_pow2(DTYPE_MAX) so that
// `row_scale = pow2(absMax) * inverted_DTYPE_MAX` is itself a pure power of 2 —
// that keeps the quant divisor consistent with the dequant scale `2^(byte-127)`
// we encode in the e8m0 byte (the `>> 23` extraction below otherwise discards
// mantissa bits and breaks accuracy). For other dtypes fall back to the exact
// 1 / DTYPE_MAX divisor.
#if defined(__gfx942__)
/* gfx942 fp8 e4m3 fnuz max=240, floor_pow2(240)=128 */
constexpr float fp8_power2_limit = 1.0f / 128.0f;
#else
/* gfx950 fp8 e4m3 max=448, floor_pow2(448)=256 */
constexpr float fp8_power2_limit = 1.0f / 256.0f;
#endif
const float inverted_DTYPE_MAX =
std::is_same_v<DTYPE_O, opus::fp4_t>
? 0.25
: (1. / static_cast<float>(opus::finfo<DTYPE_O>::max()));
? 0.25f /* 1/4, fp4 max=6 */
: (std::is_same_v<DTYPE_O, opus::fp8_t>
? fp8_power2_limit
: 1.0f / static_cast<float>(opus::finfo<DTYPE_O>::max()));
const int32_t scaleN_valid = (cols + group_size - 1) / group_size;
const int32_t scaleN_pad = ((scaleN_valid + 7) / 8) * 8;

Expand Down Expand Up @@ -1647,9 +1662,12 @@ __global__ void mxfp4_quant_moe_sort_kernel(
}
absMax = multithread_reduce(absMax, hipcub::Max(), num_thread_per_group);

float row_scale = std::is_same_v<DTYPE_O, opus::fp4_t>
? aiter::fp4_f32_to_e8m0_scale(absMax) * inverted_DTYPE_MAX
: absMax * inverted_DTYPE_MAX;
float row_scale =
std::is_same_v<DTYPE_O, opus::fp4_t>
? aiter::fp4_f32_to_e8m0_scale(absMax) * inverted_DTYPE_MAX
: (std::is_same_v<DTYPE_O, opus::fp8_t>
? aiter::fp4_f32_to_e8m0_scale(absMax) * inverted_DTYPE_MAX
: absMax * inverted_DTYPE_MAX);

const int sorted_row = sorted_ids_base + i * tgs_per_block_m;
if(threadIdx.x % num_thread_per_group == 0 && scale_k < scaleN_valid)
Expand Down Expand Up @@ -1721,7 +1739,7 @@ __global__ void mxfp4_quant_moe_sort_kernel(
AITER_CHECK(false, "input last dim has exceeded the maximum value ", 32 * BlockSize); \
}

void fused_dynamic_mxfp4_quant_moe_sort_hip(
void fused_dynamic_mx_quant_moe_sort_hip(
aiter_tensor_t& output,
aiter_tensor_t& scale,
const aiter_tensor_t& input,
Expand All @@ -1747,18 +1765,20 @@ void fused_dynamic_mxfp4_quant_moe_sort_hip(
HipDeviceGuard device_guard(input.device_id);
const hipStream_t stream = aiter::getCurrentHIPStream();

if(output.dtype() == AITER_DTYPE_fp8)
{
MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(opus::fp8_t, cols);
}
#if defined(__Float4_e2m1fn_x2)
if(output.dtype() == AITER_DTYPE_fp4x2 || output.dtype() == AITER_DTYPE_u8)
else if(output.dtype() == AITER_DTYPE_fp4x2 || output.dtype() == AITER_DTYPE_u8)
{
MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(opus::fp4_t, cols);
}
#endif
else
{
AITER_CHECK(false, __func__, ": not support output type: ", AiterDtype_to_str(output.dtype()));
}
#else
AITER_CHECK(false, __func__, ": not support fp4x2 on this device");
#endif
}

template <int block_size, int num_rows, int thread_data_size = 16, int group_size = 32>
Expand Down
2 changes: 1 addition & 1 deletion op_tests/test_moe_sorting_mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size,
device=input.device,
)
_, hip_us = run_perftest(
aiter.fused_dynamic_mxfp4_quant_moe_sort_hip,
aiter.fused_dynamic_mx_quant_moe_sort_hip,
hip_out,
hip_scale,
input,
Expand Down
Loading