diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 131a92d3c7..9e578cfa09 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -18,7 +18,11 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.flydsl.utils import is_flydsl_available from aiter.ops.flydsl.moe_common import GateMode -from aiter import fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd +from aiter import ( + fused_dynamic_mxfp4_quant_moe_sort, + fused_dynamic_mxfp8_quant_moe_sort, + mxfp4_moe_sort_fwd, +) BLOCK_SIZE_M = 32 @@ -26,6 +30,7 @@ _USE_CK_MOE_SORTING = os.environ.get("AITER_USE_CK_MOE_SORTING", "0") == "1" _ACT_TYPE_DISABLED_KEY = "__ignore__" _SWIGLU_MXFP4_BF16_BOUND = int(os.environ.get("GPTOSS_SWIGLU_MXFP4_BF16_BOUND", "256")) +_MOE_A8W4_BYPASS_QUANT = os.environ.get("AITER_MOE_A8W4_BYPASS_QUANT", "0") == "1" def _moe_sorting_impl( @@ -1549,19 +1554,35 @@ def fused_moe_2stages( a1 = hidden_states.to(dtype) a1_scale = None elif ( - quant_type == aiter.QuantType.per_1x32 + quant_type == QuantType.per_1x32 and dtype in [dtypes.bf16, dtypes.fp16] and q_dtype_a == dtypes.fp8 and w1.dtype == dtypes.fp4x2 - # and activation == aiter.ActivationType.Swiglu ): - a1 = hidden_states.to(dtypes.fp8) - M = sorted_ids.shape[0] - N = a1.shape[-1] + # a8w4 mxfp8 activations + mxfp4 weights. if metadata.fuse_quant == "fp8": + # FlyDSL stage1 fuses the fp8 quant of a1 internally, but the + # kernel dispatch requires an fp8-typed tensor. + a1 = hidden_states.to(dtypes.fp8) a1_scale = torch.empty(0, dtype=torch.uint8, device=a1.device) + elif _MOE_A8W4_BYPASS_QUANT: + # Debug bypass: skip real quant, feed unit scales. + a1 = hidden_states.to(dtypes.fp8) + M = sorted_ids.shape[0] + a1_scale = torch.ones( + [M, a1.shape[-1] // 32], dtype=dtypes.fp8_e8m0, device=a1.device + ) else: - a1_scale = torch.ones([M, N // 32], dtype=dtypes.fp8_e8m0, device=a1.device) + # stage1 input is not topk-replicated, so M==token_num and the HIP + # launcher infers TOPK=1 from input.numel() / (cols * token_num). + a1, a1_scale = fused_dynamic_mxfp8_quant_moe_sort( + hidden_states, + sorted_ids=sorted_ids, + num_valid_ids=num_valid_ids, + token_num=token_num, + topk=topk, + block_size=block_size_M, + ) elif quant_type == QuantType.per_1x32 and w1.dtype == dtypes.i4x2: # a16wi4: bf16 activations, int4 weights; no activation quantization needed @@ -1676,17 +1697,15 @@ def fused_moe_2stages( and q_dtype_a == dtypes.fp8 and w1.dtype == dtypes.fp4x2 ): - if activation == ActivationType.Silu and swiglu_limit == 0.0: - from aiter.ops.triton.quant.fused_mxfp4_quant import fused_quant_fp8_sort - + if not _MOE_A8W4_BYPASS_QUANT: a2 = a2.view(-1, inter_dim) - a2, a2_scale = fused_quant_fp8_sort( + a2, a2_scale = fused_dynamic_mxfp8_quant_moe_sort( a2, sorted_ids=sorted_ids, num_valid_ids=num_valid_ids, token_num=token_num, - block_size=32, - quant_dtype=dtypes.fp8, + topk=topk, + block_size=block_size_M, ) a2 = a2.view(token_num, topk, -1) else: diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index f9996d1e0b..bea0a49616 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -668,7 +668,7 @@ def mxfp4_moe_sort_fwd( @compile_ops("module_quant", develop=True) -def fused_dynamic_mxfp4_quant_moe_sort_hip( +def fused_dynamic_mx_quant_moe_sort_hip( out: torch.Tensor, scales: torch.Tensor, input: torch.Tensor, @@ -679,7 +679,9 @@ def fused_dynamic_mxfp4_quant_moe_sort_hip( group_size: int = 32, ) -> None: """ - HIP path for fused dynamic MXFP4 quantization and MoE scale sorting. + HIP path for fused dynamic MX (fp4 or fp8) quantization and MoE scale + sorting. The output dtype of ``out`` selects the quant target: fp4x2/uint8 + for MXFP4, fp8 for MXFP8. """ ... @@ -779,7 +781,7 @@ def fused_dynamic_mxfp4_quant_moe_sort( or group_size != 32 ): out = torch.empty(M, N // 2, dtype=dtypes.fp4x2, device=input.device) - fused_dynamic_mxfp4_quant_moe_sort_hip( + fused_dynamic_mx_quant_moe_sort_hip( out, scale, input, @@ -797,6 +799,44 @@ def fused_dynamic_mxfp4_quant_moe_sort( return out, scale +def fused_dynamic_mxfp8_quant_moe_sort( + input: torch.Tensor, + sorted_ids: torch.Tensor, + num_valid_ids: torch.Tensor, + token_num: int, + topk: int, + block_size: int, + group_size: int = 32, +) -> Tuple[torch.Tensor, torch.Tensor]: + """HIP replacement for Triton fused_quant_fp8_sort. + + Returns (fp8_out, e8m0_scale) with the scale tensor laid out as + (pad32(M_o), N_o) fp8_e8m0 — the same byte layout the FlyDSL stage1/stage2 + GEMM consumes. ``topk`` is accepted for parity with the fp4 wrapper but + is inferred inside the HIP launcher from ``input.numel() / (cols * token_num)``. + """ + M, N = input.view(-1, input.shape[-1]).shape + N_o = (N + 31) // 32 + out = torch.empty(M, N, dtype=dtypes.fp8, device=input.device) + scale = torch.empty( + (sorted_ids.shape[0] + 31) // 32 * 32, + N_o, + dtype=dtypes.fp8_e8m0, + device=input.device, + ) + fused_dynamic_mx_quant_moe_sort_hip( + out, + scale, + input, + sorted_ids, + num_valid_ids, + token_num, + block_size, + group_size, + ) + return out, scale + + @compile_ops("module_quant", develop=True) def partial_transpose( out: Tensor, diff --git a/csrc/include/quant.h b/csrc/include/quant.h index 90d093428f..3825aedddd 100644 --- a/csrc/include/quant.h +++ b/csrc/include/quant.h @@ -68,7 +68,7 @@ void moe_smooth_per_token_scaled_quant_v2(aiter_tensor_t& out, // [..., bool shuffle_scale = false, bool transpose_out = false); -void fused_dynamic_mxfp4_quant_moe_sort_hip(aiter_tensor_t& out, // [token_num * topk, d / 2] +void fused_dynamic_mx_quant_moe_sort_hip(aiter_tensor_t& out, // [token_num * topk, d] for fp8 or [token_num * topk, d / 2] for fp4 aiter_tensor_t& scales, // swizzled e8m0 bytes const aiter_tensor_t& input, // [token_num * topk, d] const aiter_tensor_t& sorted_ids, diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 8ebcfcf92f..7e18349022 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1411,8 +1411,8 @@ namespace py = pybind11; py::arg("block_m"), \ py::arg("shuffle_scale") = false, \ py::arg("transpose_out") = false); \ - m.def("fused_dynamic_mxfp4_quant_moe_sort_hip", \ - &aiter::fused_dynamic_mxfp4_quant_moe_sort_hip, \ + m.def("fused_dynamic_mx_quant_moe_sort_hip", \ + &aiter::fused_dynamic_mx_quant_moe_sort_hip, \ py::arg("out"), \ py::arg("scales"), \ py::arg("input"), \ diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index 3b73ffae18..f75f72bf0c 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -1596,10 +1596,25 @@ __global__ void mxfp4_quant_moe_sort_kernel( : (sizeof(DTYPE_I) * vec_size_i % 8 == 0 ? 8 : 4)); using vec_i = opus::vector_t; using vec_f = opus::vector_t; + // For e8m0-scaled dtypes (fp4, fp8) use 1 / floor_pow2(DTYPE_MAX) so that + // `row_scale = pow2(absMax) * inverted_DTYPE_MAX` is itself a pure power of 2 — + // that keeps the quant divisor consistent with the dequant scale `2^(byte-127)` + // we encode in the e8m0 byte (the `>> 23` extraction below otherwise discards + // mantissa bits and breaks accuracy). For other dtypes fall back to the exact + // 1 / DTYPE_MAX divisor. +#if defined(__gfx942__) + /* gfx942 fp8 e4m3 fnuz max=240, floor_pow2(240)=128 */ + constexpr float fp8_power2_limit = 1.0f / 128.0f; +#else + /* gfx950 fp8 e4m3 max=448, floor_pow2(448)=256 */ + constexpr float fp8_power2_limit = 1.0f / 256.0f; +#endif const float inverted_DTYPE_MAX = std::is_same_v - ? 0.25 - : (1. / static_cast(opus::finfo::max())); + ? 0.25f /* 1/4, fp4 max=6 */ + : (std::is_same_v + ? fp8_power2_limit + : 1.0f / static_cast(opus::finfo::max())); const int32_t scaleN_valid = (cols + group_size - 1) / group_size; const int32_t scaleN_pad = ((scaleN_valid + 7) / 8) * 8; @@ -1647,9 +1662,12 @@ __global__ void mxfp4_quant_moe_sort_kernel( } absMax = multithread_reduce(absMax, hipcub::Max(), num_thread_per_group); - float row_scale = std::is_same_v - ? aiter::fp4_f32_to_e8m0_scale(absMax) * inverted_DTYPE_MAX - : absMax * inverted_DTYPE_MAX; + float row_scale = + std::is_same_v + ? aiter::fp4_f32_to_e8m0_scale(absMax) * inverted_DTYPE_MAX + : (std::is_same_v + ? aiter::fp4_f32_to_e8m0_scale(absMax) * inverted_DTYPE_MAX + : absMax * inverted_DTYPE_MAX); const int sorted_row = sorted_ids_base + i * tgs_per_block_m; if(threadIdx.x % num_thread_per_group == 0 && scale_k < scaleN_valid) @@ -1721,7 +1739,7 @@ __global__ void mxfp4_quant_moe_sort_kernel( AITER_CHECK(false, "input last dim has exceeded the maximum value ", 32 * BlockSize); \ } -void fused_dynamic_mxfp4_quant_moe_sort_hip( +void fused_dynamic_mx_quant_moe_sort_hip( aiter_tensor_t& output, aiter_tensor_t& scale, const aiter_tensor_t& input, @@ -1747,18 +1765,20 @@ void fused_dynamic_mxfp4_quant_moe_sort_hip( HipDeviceGuard device_guard(input.device_id); const hipStream_t stream = aiter::getCurrentHIPStream(); + if(output.dtype() == AITER_DTYPE_fp8) + { + MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(opus::fp8_t, cols); + } #if defined(__Float4_e2m1fn_x2) - if(output.dtype() == AITER_DTYPE_fp4x2 || output.dtype() == AITER_DTYPE_u8) + else if(output.dtype() == AITER_DTYPE_fp4x2 || output.dtype() == AITER_DTYPE_u8) { MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(opus::fp4_t, cols); } +#endif else { AITER_CHECK(false, __func__, ": not support output type: ", AiterDtype_to_str(output.dtype())); } -#else - AITER_CHECK(false, __func__, ": not support fp4x2 on this device"); -#endif } template diff --git a/op_tests/test_moe_sorting_mxfp4.py b/op_tests/test_moe_sorting_mxfp4.py index 2beb79a674..0f7fa20a5b 100644 --- a/op_tests/test_moe_sorting_mxfp4.py +++ b/op_tests/test_moe_sorting_mxfp4.py @@ -176,7 +176,7 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, device=input.device, ) _, hip_us = run_perftest( - aiter.fused_dynamic_mxfp4_quant_moe_sort_hip, + aiter.fused_dynamic_mx_quant_moe_sort_hip, hip_out, hip_scale, input,