From 486dd5b5807a4e5a39ff3d90e742d0b57feffd3b Mon Sep 17 00:00:00 2001 From: yadaish Date: Fri, 22 May 2026 08:44:18 +0000 Subject: [PATCH 1/6] [MOE] Add fused dynamic MXFP8 quant + moe_sort HIP path Adds fused_dynamic_mxfp8_quant_moe_sort_hip that quantizes activations to fp8 with e8m0 group scales and writes the swizzled scale layout consumed by the FlyDSL a8w4 stage1/stage2 GEMM. Wires it into fused_moe_2stages to replace the Triton fused_quant_fp8_sort path, with AITER_MOE_A8W4_BYPASS_QUANT to fall back to the prior unquantized behavior. For e8m0-scaled fp8 output, the divisor is switched to 1/floor_pow2(max) so row_scale stays a pure power of two and matches the 2^(byte-127) dequant scale encoded in the e8m0 byte. Co-Authored-By: Claude Opus 4 --- aiter/fused_moe.py | 41 +++++++++++++++++-------- aiter/ops/quant.py | 55 +++++++++++++++++++++++++++++++++ csrc/include/quant.h | 9 ++++++ csrc/include/rocm_ops.hpp | 10 ++++++ csrc/kernels/quant_kernels.cu | 57 ++++++++++++++++++++++++++++++++--- 5 files changed, 154 insertions(+), 18 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 131a92d3c7..4c4b736273 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -18,7 +18,11 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.flydsl.utils import is_flydsl_available from aiter.ops.flydsl.moe_common import GateMode -from aiter import fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd +from aiter import ( + fused_dynamic_mxfp4_quant_moe_sort, + fused_dynamic_mxfp8_quant_moe_sort, + mxfp4_moe_sort_fwd, +) BLOCK_SIZE_M = 32 @@ -26,6 +30,7 @@ _USE_CK_MOE_SORTING = os.environ.get("AITER_USE_CK_MOE_SORTING", "0") == "1" _ACT_TYPE_DISABLED_KEY = "__ignore__" _SWIGLU_MXFP4_BF16_BOUND = int(os.environ.get("GPTOSS_SWIGLU_MXFP4_BF16_BOUND", "256")) +_MOE_A8W4_BYPASS_QUANT = os.environ.get("AITER_MOE_A8W4_BYPASS_QUANT", "0") == "1" def _moe_sorting_impl( @@ -1555,13 +1560,25 @@ def fused_moe_2stages( and w1.dtype == dtypes.fp4x2 # and activation == aiter.ActivationType.Swiglu ): - a1 = hidden_states.to(dtypes.fp8) - M = sorted_ids.shape[0] - N = a1.shape[-1] - if metadata.fuse_quant == "fp8": - a1_scale = torch.empty(0, dtype=torch.uint8, device=a1.device) + if not _MOE_A8W4_BYPASS_QUANT: + # stage1 input is not topk-replicated, so M==token_num and the + # HIP launcher infers TOPK=1 from input.numel() / (cols * token_num). + a1, a1_scale = fused_dynamic_mxfp8_quant_moe_sort( + hidden_states, + sorted_ids=sorted_ids, + num_valid_ids=num_valid_ids, + token_num=token_num, + topk=topk, + block_size=block_size_M, + ) else: - a1_scale = torch.ones([M, N // 32], dtype=dtypes.fp8_e8m0, device=a1.device) + a1 = hidden_states.to(dtypes.fp8) + M = sorted_ids.shape[0] + N = a1.shape[-1] + if metadata.fuse_quant == "fp8": + a1_scale = torch.empty(0, dtype=torch.uint8, device=a1.device) + else: + a1_scale = torch.ones([M, N // 32], dtype=dtypes.fp8_e8m0, device=a1.device) elif quant_type == QuantType.per_1x32 and w1.dtype == dtypes.i4x2: # a16wi4: bf16 activations, int4 weights; no activation quantization needed @@ -1676,17 +1693,15 @@ def fused_moe_2stages( and q_dtype_a == dtypes.fp8 and w1.dtype == dtypes.fp4x2 ): - if activation == ActivationType.Silu and swiglu_limit == 0.0: - from aiter.ops.triton.quant.fused_mxfp4_quant import fused_quant_fp8_sort - + if not _MOE_A8W4_BYPASS_QUANT: a2 = a2.view(-1, inter_dim) - a2, a2_scale = fused_quant_fp8_sort( + a2, a2_scale = fused_dynamic_mxfp8_quant_moe_sort( a2, sorted_ids=sorted_ids, num_valid_ids=num_valid_ids, token_num=token_num, - block_size=32, - quant_dtype=dtypes.fp8, + topk=topk, + block_size=block_size_M, ) a2 = a2.view(token_num, topk, -1) else: diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index f9996d1e0b..b0edc096ce 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -684,6 +684,23 @@ def fused_dynamic_mxfp4_quant_moe_sort_hip( ... +@compile_ops("module_quant", develop=True) +def fused_dynamic_mxfp8_quant_moe_sort_hip( + out: torch.Tensor, + scales: torch.Tensor, + input: torch.Tensor, + sorted_ids: torch.Tensor, + num_valid_ids: torch.Tensor, + token_num: int, + block_m: int, + group_size: int = 32, +) -> None: + """ + HIP path for fused dynamic MXFP8 quantization and MoE scale sorting. + """ + ... + + @compile_ops("module_quant", develop=True) def quant_mxfp4( inp: torch.Tensor, @@ -797,6 +814,44 @@ def fused_dynamic_mxfp4_quant_moe_sort( return out, scale +def fused_dynamic_mxfp8_quant_moe_sort( + input: torch.Tensor, + sorted_ids: torch.Tensor, + num_valid_ids: torch.Tensor, + token_num: int, + topk: int, + block_size: int, + group_size: int = 32, +) -> Tuple[torch.Tensor, torch.Tensor]: + """HIP replacement for Triton fused_quant_fp8_sort. + + Returns (fp8_out, e8m0_scale) with the scale tensor laid out as + (pad32(M_o), N_o) fp8_e8m0 — the same byte layout the FlyDSL stage1/stage2 + GEMM consumes. ``topk`` is accepted for parity with the fp4 wrapper but + is inferred inside the HIP launcher from ``input.numel() / (cols * token_num)``. + """ + M, N = input.view(-1, input.shape[-1]).shape + N_o = (N + 31) // 32 + out = torch.empty(M, N, dtype=dtypes.fp8, device=input.device) + scale = torch.empty( + (sorted_ids.shape[0] + 31) // 32 * 32, + N_o, + dtype=dtypes.fp8_e8m0, + device=input.device, + ) + fused_dynamic_mxfp8_quant_moe_sort_hip( + out, + scale, + input, + sorted_ids, + num_valid_ids, + token_num, + block_size, + group_size, + ) + return out, scale + + @compile_ops("module_quant", develop=True) def partial_transpose( out: Tensor, diff --git a/csrc/include/quant.h b/csrc/include/quant.h index 90d093428f..9d5bfbfaec 100644 --- a/csrc/include/quant.h +++ b/csrc/include/quant.h @@ -77,6 +77,15 @@ void fused_dynamic_mxfp4_quant_moe_sort_hip(aiter_tensor_t& out, // [tok int block_m, int group_size = 32); +void fused_dynamic_mxfp8_quant_moe_sort_hip(aiter_tensor_t& out, // [token_num * topk, d] + aiter_tensor_t& scales, // swizzled e8m0 bytes + const aiter_tensor_t& input, // [token_num * topk, d] + const aiter_tensor_t& sorted_ids, + const aiter_tensor_t& num_valid_ids, + int token_num, + int block_m, + int group_size = 32); + void mxfp4_moe_sort_hip(aiter_tensor_t& out_scale, const aiter_tensor_t& scale, const aiter_tensor_t& sorted_ids, diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 8ebcfcf92f..54bd18d196 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1421,6 +1421,16 @@ namespace py = pybind11; py::arg("token_num"), \ py::arg("block_m"), \ py::arg("group_size") = 32); \ + m.def("fused_dynamic_mxfp8_quant_moe_sort_hip", \ + &aiter::fused_dynamic_mxfp8_quant_moe_sort_hip, \ + py::arg("out"), \ + py::arg("scales"), \ + py::arg("input"), \ + py::arg("sorted_ids"), \ + py::arg("num_valid_ids"), \ + py::arg("token_num"), \ + py::arg("block_m"), \ + py::arg("group_size") = 32); \ m.def("mxfp4_moe_sort_hip", \ &aiter::mxfp4_moe_sort_hip, \ py::arg("out_scale"), \ diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index 3b73ffae18..fd150f5e0f 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -1596,10 +1596,18 @@ __global__ void mxfp4_quant_moe_sort_kernel( : (sizeof(DTYPE_I) * vec_size_i % 8 == 0 ? 8 : 4)); using vec_i = opus::vector_t; using vec_f = opus::vector_t; + // For e8m0-scaled dtypes (fp4, fp8) use 1 / floor_pow2(DTYPE_MAX) so that + // `row_scale = pow2(absMax) * inverted_DTYPE_MAX` is itself a pure power of 2 — + // that keeps the quant divisor consistent with the dequant scale `2^(byte-127)` + // we encode in the e8m0 byte (the `>> 23` extraction below otherwise discards + // mantissa bits and breaks accuracy). For other dtypes fall back to the exact + // 1 / DTYPE_MAX divisor. const float inverted_DTYPE_MAX = std::is_same_v - ? 0.25 - : (1. / static_cast(opus::finfo::max())); + ? 0.25f /* 1/4, fp4 max=6 */ + : (std::is_same_v + ? 1.0f / 256.0f /* fp8 e4m3 max=448 */ + : 1.0f / static_cast(opus::finfo::max())); const int32_t scaleN_valid = (cols + group_size - 1) / group_size; const int32_t scaleN_pad = ((scaleN_valid + 7) / 8) * 8; @@ -1647,9 +1655,12 @@ __global__ void mxfp4_quant_moe_sort_kernel( } absMax = multithread_reduce(absMax, hipcub::Max(), num_thread_per_group); - float row_scale = std::is_same_v - ? aiter::fp4_f32_to_e8m0_scale(absMax) * inverted_DTYPE_MAX - : absMax * inverted_DTYPE_MAX; + float row_scale = + std::is_same_v + ? aiter::fp4_f32_to_e8m0_scale(absMax) * inverted_DTYPE_MAX + : (std::is_same_v + ? aiter::fp4_f32_to_e8m0_scale(absMax) * inverted_DTYPE_MAX + : absMax * inverted_DTYPE_MAX); const int sorted_row = sorted_ids_base + i * tgs_per_block_m; if(threadIdx.x % num_thread_per_group == 0 && scale_k < scaleN_valid) @@ -1761,6 +1772,42 @@ void fused_dynamic_mxfp4_quant_moe_sort_hip( #endif } +void fused_dynamic_mxfp8_quant_moe_sort_hip( + aiter_tensor_t& output, + aiter_tensor_t& scale, + const aiter_tensor_t& input, + const aiter_tensor_t& sorted_ids, + const aiter_tensor_t& num_valid_ids, + int token_num, + int block_m, + int group_size +) +{ + int cols = input.size(-1); + int topk = input.numel() / (cols * token_num); + int num_experts = (sorted_ids.size(0) + topk - topk * token_num) / block_m; + + const int num_cu = get_num_cu_func(); + int sub_block_m = (token_num * topk) > (num_cu * 8) || num_experts < 64 ? 2 : 4; + AITER_CHECK(block_m % sub_block_m == 0, __func__, " block_m is not divisible by sub_block_m"); + int tgs_per_block_m = block_m / sub_block_m; + int num_blocks = (sorted_ids.size(0) + sub_block_m - 1) / sub_block_m; + const bool persistent_mode = false; + const int input_stride = input.stride(-2); + + HipDeviceGuard device_guard(input.device_id); + const hipStream_t stream = aiter::getCurrentHIPStream(); + + if(output.dtype() == AITER_DTYPE_fp8) + { + MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(opus::fp8_t, cols); + } + else + { + AITER_CHECK(false, __func__, ": not support output type: ", AiterDtype_to_str(output.dtype())); + } +} + template __global__ void mxfp4_moe_sort_kernel( uint8_t* __restrict__ out_scale, From 07e4337a81474dd8df854e45655a9f19de26afde Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 25 May 2026 06:47:37 +0000 Subject: [PATCH 2/6] fix --- csrc/kernels/quant_kernels.cu | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index fd150f5e0f..6b5e6c632d 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -1602,11 +1602,18 @@ __global__ void mxfp4_quant_moe_sort_kernel( // we encode in the e8m0 byte (the `>> 23` extraction below otherwise discards // mantissa bits and breaks accuracy). For other dtypes fall back to the exact // 1 / DTYPE_MAX divisor. +#if defined(__gfx942__) + /* gfx942 fp8 e4m3 fnuz max=240, floor_pow2(240)=128 */ + constexpr float fp8_power2_limit = 1.0f / 128.0f; +#else + /* gfx950 fp8 e4m3 max=448, floor_pow2(448)=256 */ + constexpr float fp8_power2_limit = 1.0f / 256.0f; +#endif const float inverted_DTYPE_MAX = std::is_same_v ? 0.25f /* 1/4, fp4 max=6 */ : (std::is_same_v - ? 1.0f / 256.0f /* fp8 e4m3 max=448 */ + ? fp8_power2_limit : 1.0f / static_cast(opus::finfo::max())); const int32_t scaleN_valid = (cols + group_size - 1) / group_size; const int32_t scaleN_pad = ((scaleN_valid + 7) / 8) * 8; From b7e559aaea5ceaa6bf32a9cac15fd17c32049e5b Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 25 May 2026 07:29:43 +0000 Subject: [PATCH 3/6] [MOE] Unify fp4/fp8 fused dynamic mx quant moe_sort + gfx942/950 fp8 limit Merge fused_dynamic_mxfp4/mxfp8_quant_moe_sort_hip into a single fused_dynamic_mx_quant_moe_sort_hip that dispatches on out tensor dtype. Also select fp8 power-of-two divisor per arch: 1/128 on gfx942 (fp8 e4m3 fnuz max=240), 1/256 on gfx950 (fp8 e4m3 max=448). Co-Authored-By: Claude Opus 4 --- aiter/ops/quant.py | 27 ++++-------------- csrc/include/quant.h | 11 +------ csrc/include/rocm_ops.hpp | 14 ++------- csrc/kernels/quant_kernels.cu | 46 ++++-------------------------- op_tests/test_moe_sorting_mxfp4.py | 2 +- 5 files changed, 16 insertions(+), 84 deletions(-) diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index b0edc096ce..bea0a49616 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -668,7 +668,7 @@ def mxfp4_moe_sort_fwd( @compile_ops("module_quant", develop=True) -def fused_dynamic_mxfp4_quant_moe_sort_hip( +def fused_dynamic_mx_quant_moe_sort_hip( out: torch.Tensor, scales: torch.Tensor, input: torch.Tensor, @@ -679,24 +679,9 @@ def fused_dynamic_mxfp4_quant_moe_sort_hip( group_size: int = 32, ) -> None: """ - HIP path for fused dynamic MXFP4 quantization and MoE scale sorting. - """ - ... - - -@compile_ops("module_quant", develop=True) -def fused_dynamic_mxfp8_quant_moe_sort_hip( - out: torch.Tensor, - scales: torch.Tensor, - input: torch.Tensor, - sorted_ids: torch.Tensor, - num_valid_ids: torch.Tensor, - token_num: int, - block_m: int, - group_size: int = 32, -) -> None: - """ - HIP path for fused dynamic MXFP8 quantization and MoE scale sorting. + HIP path for fused dynamic MX (fp4 or fp8) quantization and MoE scale + sorting. The output dtype of ``out`` selects the quant target: fp4x2/uint8 + for MXFP4, fp8 for MXFP8. """ ... @@ -796,7 +781,7 @@ def fused_dynamic_mxfp4_quant_moe_sort( or group_size != 32 ): out = torch.empty(M, N // 2, dtype=dtypes.fp4x2, device=input.device) - fused_dynamic_mxfp4_quant_moe_sort_hip( + fused_dynamic_mx_quant_moe_sort_hip( out, scale, input, @@ -839,7 +824,7 @@ def fused_dynamic_mxfp8_quant_moe_sort( dtype=dtypes.fp8_e8m0, device=input.device, ) - fused_dynamic_mxfp8_quant_moe_sort_hip( + fused_dynamic_mx_quant_moe_sort_hip( out, scale, input, diff --git a/csrc/include/quant.h b/csrc/include/quant.h index 9d5bfbfaec..3825aedddd 100644 --- a/csrc/include/quant.h +++ b/csrc/include/quant.h @@ -68,16 +68,7 @@ void moe_smooth_per_token_scaled_quant_v2(aiter_tensor_t& out, // [..., bool shuffle_scale = false, bool transpose_out = false); -void fused_dynamic_mxfp4_quant_moe_sort_hip(aiter_tensor_t& out, // [token_num * topk, d / 2] - aiter_tensor_t& scales, // swizzled e8m0 bytes - const aiter_tensor_t& input, // [token_num * topk, d] - const aiter_tensor_t& sorted_ids, - const aiter_tensor_t& num_valid_ids, - int token_num, - int block_m, - int group_size = 32); - -void fused_dynamic_mxfp8_quant_moe_sort_hip(aiter_tensor_t& out, // [token_num * topk, d] +void fused_dynamic_mx_quant_moe_sort_hip(aiter_tensor_t& out, // [token_num * topk, d] for fp8 or [token_num * topk, d / 2] for fp4 aiter_tensor_t& scales, // swizzled e8m0 bytes const aiter_tensor_t& input, // [token_num * topk, d] const aiter_tensor_t& sorted_ids, diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 54bd18d196..7e18349022 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1411,18 +1411,8 @@ namespace py = pybind11; py::arg("block_m"), \ py::arg("shuffle_scale") = false, \ py::arg("transpose_out") = false); \ - m.def("fused_dynamic_mxfp4_quant_moe_sort_hip", \ - &aiter::fused_dynamic_mxfp4_quant_moe_sort_hip, \ - py::arg("out"), \ - py::arg("scales"), \ - py::arg("input"), \ - py::arg("sorted_ids"), \ - py::arg("num_valid_ids"), \ - py::arg("token_num"), \ - py::arg("block_m"), \ - py::arg("group_size") = 32); \ - m.def("fused_dynamic_mxfp8_quant_moe_sort_hip", \ - &aiter::fused_dynamic_mxfp8_quant_moe_sort_hip, \ + m.def("fused_dynamic_mx_quant_moe_sort_hip", \ + &aiter::fused_dynamic_mx_quant_moe_sort_hip, \ py::arg("out"), \ py::arg("scales"), \ py::arg("input"), \ diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index 6b5e6c632d..f75f72bf0c 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -1739,7 +1739,7 @@ __global__ void mxfp4_quant_moe_sort_kernel( AITER_CHECK(false, "input last dim has exceeded the maximum value ", 32 * BlockSize); \ } -void fused_dynamic_mxfp4_quant_moe_sort_hip( +void fused_dynamic_mx_quant_moe_sort_hip( aiter_tensor_t& output, aiter_tensor_t& scale, const aiter_tensor_t& input, @@ -1765,50 +1765,16 @@ void fused_dynamic_mxfp4_quant_moe_sort_hip( HipDeviceGuard device_guard(input.device_id); const hipStream_t stream = aiter::getCurrentHIPStream(); -#if defined(__Float4_e2m1fn_x2) - if(output.dtype() == AITER_DTYPE_fp4x2 || output.dtype() == AITER_DTYPE_u8) + if(output.dtype() == AITER_DTYPE_fp8) { - MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(opus::fp4_t, cols); + MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(opus::fp8_t, cols); } - else +#if defined(__Float4_e2m1fn_x2) + else if(output.dtype() == AITER_DTYPE_fp4x2 || output.dtype() == AITER_DTYPE_u8) { - AITER_CHECK(false, __func__, ": not support output type: ", AiterDtype_to_str(output.dtype())); + MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(opus::fp4_t, cols); } -#else - AITER_CHECK(false, __func__, ": not support fp4x2 on this device"); #endif -} - -void fused_dynamic_mxfp8_quant_moe_sort_hip( - aiter_tensor_t& output, - aiter_tensor_t& scale, - const aiter_tensor_t& input, - const aiter_tensor_t& sorted_ids, - const aiter_tensor_t& num_valid_ids, - int token_num, - int block_m, - int group_size -) -{ - int cols = input.size(-1); - int topk = input.numel() / (cols * token_num); - int num_experts = (sorted_ids.size(0) + topk - topk * token_num) / block_m; - - const int num_cu = get_num_cu_func(); - int sub_block_m = (token_num * topk) > (num_cu * 8) || num_experts < 64 ? 2 : 4; - AITER_CHECK(block_m % sub_block_m == 0, __func__, " block_m is not divisible by sub_block_m"); - int tgs_per_block_m = block_m / sub_block_m; - int num_blocks = (sorted_ids.size(0) + sub_block_m - 1) / sub_block_m; - const bool persistent_mode = false; - const int input_stride = input.stride(-2); - - HipDeviceGuard device_guard(input.device_id); - const hipStream_t stream = aiter::getCurrentHIPStream(); - - if(output.dtype() == AITER_DTYPE_fp8) - { - MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(opus::fp8_t, cols); - } else { AITER_CHECK(false, __func__, ": not support output type: ", AiterDtype_to_str(output.dtype())); diff --git a/op_tests/test_moe_sorting_mxfp4.py b/op_tests/test_moe_sorting_mxfp4.py index 2beb79a674..0f7fa20a5b 100644 --- a/op_tests/test_moe_sorting_mxfp4.py +++ b/op_tests/test_moe_sorting_mxfp4.py @@ -176,7 +176,7 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, device=input.device, ) _, hip_us = run_perftest( - aiter.fused_dynamic_mxfp4_quant_moe_sort_hip, + aiter.fused_dynamic_mx_quant_moe_sort_hip, hip_out, hip_scale, input, From 9540b9943b3077358075f94743b246d328fcb7f4 Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 25 May 2026 07:30:32 +0000 Subject: [PATCH 4/6] fix black format Co-Authored-By: Claude Opus 4 --- aiter/fused_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 4c4b736273..796308d7b1 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -1578,7 +1578,9 @@ def fused_moe_2stages( if metadata.fuse_quant == "fp8": a1_scale = torch.empty(0, dtype=torch.uint8, device=a1.device) else: - a1_scale = torch.ones([M, N // 32], dtype=dtypes.fp8_e8m0, device=a1.device) + a1_scale = torch.ones( + [M, N // 32], dtype=dtypes.fp8_e8m0, device=a1.device + ) elif quant_type == QuantType.per_1x32 and w1.dtype == dtypes.i4x2: # a16wi4: bf16 activations, int4 weights; no activation quantization needed From cb480e883eb8dd460d2771a3e39621f1c22df1f8 Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 25 May 2026 11:58:11 +0000 Subject: [PATCH 5/6] [MOE] Skip fused HIP fp8 quant when stage1 kernel fuses fp8 quant When metadata.fuse_quant == "fp8" the FlyDSL stage1 kernel quantizes its fp8 activation input internally (the bypass branch already sets a1_scale to an empty tensor in this case). The new fused_dynamic_mxfp8_quant_moe_sort path was pre-quantizing a1 and handing a real e8m0 scale to the kernel, which then re-applied quant and produced ~10^9-magnitude garbage in test_moe_2stage -q 7. Gate the HIP path on fuse_quant != "fp8" so the bypass-equivalent path is used. Co-Authored-By: Claude Opus 4 --- aiter/fused_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 796308d7b1..86f4224fa7 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -1560,9 +1560,11 @@ def fused_moe_2stages( and w1.dtype == dtypes.fp4x2 # and activation == aiter.ActivationType.Swiglu ): - if not _MOE_A8W4_BYPASS_QUANT: + if not _MOE_A8W4_BYPASS_QUANT and metadata.fuse_quant != "fp8": # stage1 input is not topk-replicated, so M==token_num and the # HIP launcher infers TOPK=1 from input.numel() / (cols * token_num). + # When fuse_quant == "fp8" the FlyDSL stage1 kernel quantizes a1 + # internally, so we must NOT pre-quantize here. a1, a1_scale = fused_dynamic_mxfp8_quant_moe_sort( hidden_states, sorted_ids=sorted_ids, From cbab14edea6bee215f9bdca9d26a329977b45929 Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 25 May 2026 12:53:36 +0000 Subject: [PATCH 6/6] [MOE] Cast a1 to fp8 dtype when stage1 fuses fp8 quant The FlyDSL stage1 kernel fuses fp8 quant of a1 internally, but its dispatch still requires an fp8-typed input tensor. Passing raw bf16 caused the kernel to reinterpret bytes as fp8 and produce NaN logits. Co-Authored-By: Claude Opus 4 --- aiter/fused_moe.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 86f4224fa7..9e578cfa09 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -1554,17 +1554,27 @@ def fused_moe_2stages( a1 = hidden_states.to(dtype) a1_scale = None elif ( - quant_type == aiter.QuantType.per_1x32 + quant_type == QuantType.per_1x32 and dtype in [dtypes.bf16, dtypes.fp16] and q_dtype_a == dtypes.fp8 and w1.dtype == dtypes.fp4x2 - # and activation == aiter.ActivationType.Swiglu ): - if not _MOE_A8W4_BYPASS_QUANT and metadata.fuse_quant != "fp8": - # stage1 input is not topk-replicated, so M==token_num and the - # HIP launcher infers TOPK=1 from input.numel() / (cols * token_num). - # When fuse_quant == "fp8" the FlyDSL stage1 kernel quantizes a1 - # internally, so we must NOT pre-quantize here. + # a8w4 mxfp8 activations + mxfp4 weights. + if metadata.fuse_quant == "fp8": + # FlyDSL stage1 fuses the fp8 quant of a1 internally, but the + # kernel dispatch requires an fp8-typed tensor. + a1 = hidden_states.to(dtypes.fp8) + a1_scale = torch.empty(0, dtype=torch.uint8, device=a1.device) + elif _MOE_A8W4_BYPASS_QUANT: + # Debug bypass: skip real quant, feed unit scales. + a1 = hidden_states.to(dtypes.fp8) + M = sorted_ids.shape[0] + a1_scale = torch.ones( + [M, a1.shape[-1] // 32], dtype=dtypes.fp8_e8m0, device=a1.device + ) + else: + # stage1 input is not topk-replicated, so M==token_num and the HIP + # launcher infers TOPK=1 from input.numel() / (cols * token_num). a1, a1_scale = fused_dynamic_mxfp8_quant_moe_sort( hidden_states, sorted_ids=sorted_ids, @@ -1573,16 +1583,6 @@ def fused_moe_2stages( topk=topk, block_size=block_size_M, ) - else: - a1 = hidden_states.to(dtypes.fp8) - M = sorted_ids.shape[0] - N = a1.shape[-1] - if metadata.fuse_quant == "fp8": - a1_scale = torch.empty(0, dtype=torch.uint8, device=a1.device) - else: - a1_scale = torch.ones( - [M, N // 32], dtype=dtypes.fp8_e8m0, device=a1.device - ) elif quant_type == QuantType.per_1x32 and w1.dtype == dtypes.i4x2: # a16wi4: bf16 activations, int4 weights; no activation quantization needed