From 76d97f75fc400d240f4ca0f14f3e91efe1065339 Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Fri, 15 May 2026 16:03:20 +0800 Subject: [PATCH 1/5] [Enh] Add fused-add, quant kernels and tests --- kernels/layernorm_kernel.py | 653 ++++++++++++++++++++++++++++++++ tests/kernels/test_layernorm.py | 395 ++++++++++++------- 2 files changed, 916 insertions(+), 132 deletions(-) diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index 390f26832..0f0c2052a 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -19,6 +19,9 @@ from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import arith, const_expr, gpu, range_constexpr from flydsl.expr import math as fmath +from flydsl.expr.arith import ArithValue +from flydsl.expr.numeric import Float32, Numeric, Uint32 +from flydsl.expr.typing import Int32, T from flydsl.expr.vector import ReductionOp, full from flydsl.runtime.device import get_rocm_arch as get_hip_arch from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr @@ -326,3 +329,653 @@ def launch_layernorm( ) return launch_layernorm + + +def _quant_dtype_to_elem_type(dtype_str: str): + if dtype_str in ("i8", "int8"): + return T.i8 + raise ValueError(f"unsupported quant dtype: {dtype_str!r} (expected 'i8' or 'int8')") + + +def _quant_dtype_max(dtype_str: str) -> float: + if dtype_str in ("i8", "int8"): + return 127.0 + raise ValueError(f"unsupported quant dtype: {dtype_str!r} (expected 'i8' or 'int8')") + + +def build_fused_add_layernorm_module(M: int, N: int, dtype_str: str): + arch = get_hip_arch() + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) + elem_bits = 32 if dtype_str == "f32" else 16 + + allocator = SmemAllocator(None, arch=arch) + f32_bytes = 4 + sum_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = sum_offset + RED_SLOTS * f32_bytes + sumsq_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = sumsq_offset + RED_SLOTS * f32_bytes + + @flyc.kernel + def fused_add_layernorm_kernel( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + elem_dtype = dtype_to_elem_type(dtype_str) + fm_fast = arith.FastMathFlags.fast + eps_c = EPS + + base_ptr = allocator.get_base() + s_sum = SmemPtr(base_ptr, sum_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_sumsq = SmemPtr(base_ptr, sumsq_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_sum.get() + s_sumsq.get() + + def wave_reduce_add(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) + w = w.addf(peer, fastmath=fm_fast) + return w + + def block_reduce_add2(val0, val1): + if const_expr(RED_SLOTS == 1): + return wave_reduce_add(val0), wave_reduce_add(val1) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + w0 = wave_reduce_add(val0) + w1 = wave_reduce_add(val1) + + if lane == 0: + SmemPtr.store(s_sum, w0, [wave]) + SmemPtr.store(s_sumsq, w1, [wave]) + gpu.barrier() + + if wave == 0: + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, 0) + v0 = SmemPtr.load(s_sum, [lane_safe]) + v1 = SmemPtr.load(s_sumsq, [lane_safe]) + ww0 = in_range.select(v0, 0.0) + ww1 = in_range.select(v1, 0.0) + ww0 = wave_reduce_add(ww0) + ww1 = wave_reduce_add(ww1) + + if lane == 0: + SmemPtr.store(s_sum, ww0, [0]) + SmemPtr.store(s_sumsq, ww1, [0]) + gpu.barrier() + + return SmemPtr.load(s_sum, [0]), SmemPtr.load(s_sumsq, [0]) + + def compute_mean_rstd(sum_val, sumsq_val): + inv_n = 1.0 / float(N) + mean = sum_val * inv_n + mean_sq = sumsq_val * inv_n + var = mean_sq - mean * mean + var = (var < 0.0).select(0.0, var) + return mean, fmath.rsqrt(var + eps_c, fastmath=fm_fast) + + Input_buf = fx.rocdl.make_buffer_tensor(Input) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Beta_buf = fx.rocdl.make_buffer_tensor(Beta) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + + row_in = fx.slice(Input_buf, (bid, None)) + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + + in_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + beta_div = fx.logical_divide(Beta_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) + + def _load_scalar(divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.make_rmem_tensor(1, elem_dtype) + fx.copy_atom_call(copy_atom_s, view, r) + return fx.memref_load_vec(r)[0] + + def _store_scalar(divided_tensor, index, val): + r = fx.make_rmem_tensor(1, elem_dtype) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom_s, r, view) + + c_zero_f = fx.Float32(0.0) + thread_sum = c_zero_f + thread_sumsq = c_zero_f + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(in_div, idx_safe) + r_e = _load_scalar(residual_in_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + residual = r_e if dtype_str == "f32" else r_e.to(fx.Float32) + added = x + residual + added_safe = is_valid.select(added, c_zero_f) + thread_sum = thread_sum + added_safe + thread_sumsq = thread_sumsq + is_valid.select(added * added, c_zero_f) + if idx < N: + _store_scalar(residual_out_div, idx, added) + + sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) + mean, rstd = compute_mean_rstd(sum_val, sumsq_val) + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + if idx < N: + x_e = _load_scalar(in_div, idx) + r_e = _load_scalar(residual_in_div, idx) + g_e = _load_scalar(gamma_div, idx) + b_e = _load_scalar(beta_div, idx) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + residual = r_e if dtype_str == "f32" else r_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + b = b_e if dtype_str == "f32" else b_e.to(fx.Float32) + y = ((x + residual) - mean) * rstd + y = y * g + b + if const_expr(dtype_str == "f32"): + y_e = y + else: + y_e = y.to(elem_dtype) + _store_scalar(out_div, idx, y_e) + + @flyc.jit + def launch_fused_add_layernorm( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = fused_add_layernorm_kernel(Input, ResidualIn, Gamma, Beta, Output, ResidualOut) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_fused_add_layernorm + + +def _build_layernorm_quant_module( + M: int, + N: int, + dtype_str: str, + *, + is_smooth: bool, + is_fused_add: bool, + quant_dtype_str: str = "i8", +): + arch = get_hip_arch() + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) + elem_bits = 32 if dtype_str == "f32" else 16 + quant_dtype_max = _quant_dtype_max(quant_dtype_str) + + allocator = SmemAllocator(None, arch=arch) + f32_bytes = 4 + sum_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = sum_offset + RED_SLOTS * f32_bytes + sumsq_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = sumsq_offset + RED_SLOTS * f32_bytes + + @flyc.kernel + def layernorm_quant_kernel( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + XScale: fx.Tensor, + YScale: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + elem_dtype = dtype_to_elem_type(dtype_str) + quant_dtype = Numeric.from_ir_type(_quant_dtype_to_elem_type(quant_dtype_str)) + compute_type = T.f32 + + fm_fast = arith.FastMathFlags.fast + eps_c = arith.constant(EPS, type=compute_type) + n_float = arith.constant(float(N), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_one_f = arith.constant(1.0, type=compute_type) + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_dtype_max = arith.constant(quant_dtype_max, type=compute_type) + + base_ptr = allocator.get_base() + s_sum = SmemPtr(base_ptr, sum_offset, T.f32, shape=(RED_SLOTS,)) + s_sumsq = SmemPtr(base_ptr, sumsq_offset, T.f32, shape=(RED_SLOTS,)) + s_sum.get() + s_sumsq.get() + + YScale_buf = fx.rocdl.make_buffer_tensor(YScale) + yscale_div = fx.logical_divide(YScale_buf, fx.make_layout(1, 1)) + scale_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + + def _store_yscale(index, val): + r = fx.make_rmem_tensor(1, Float32) + ts = full(1, Float32(val), Float32) + fx.memref_store_vec(ts, r) + fx.copy_atom_call(scale_copy_atom, r, fx.slice(yscale_div, (None, index))) + + def wave_reduce_add(x): + width_i32 = fx.Int32(WARP_SIZE) + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = fx.Int32(WARP_SIZE // (2 << _sh_exp)) + peer = w.shuffle_xor(off, width_i32) + w = w.addf(peer, fastmath=fm_fast) + return w + + def wave_reduce_max(x): + width_i32 = fx.Int32(WARP_SIZE) + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = fx.Int32(WARP_SIZE // (2 << _sh_exp)) + peer = w.shuffle_xor(off, width_i32) + w = w.maximumf(peer) + return w + + def block_reduce_add2(val0, val1): + if const_expr(RED_SLOTS == 1): + return wave_reduce_add(val0), wave_reduce_add(val1) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + w0 = wave_reduce_add(val0) + w1 = wave_reduce_add(val1) + + if lane == fx.Int32(0): + wave_idx = ArithValue(wave).index_cast(T.index) + SmemPtr.store(s_sum, w0, [wave_idx]) + SmemPtr.store(s_sumsq, w1, [wave_idx]) + gpu.barrier() + + if wave == fx.Int32(0): + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, fx.Int32(0)) + lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) + v0 = SmemPtr.load(s_sum, [lane_safe_idx]) + v1 = SmemPtr.load(s_sumsq, [lane_safe_idx]) + ww0 = in_range.select(v0, c_zero_f) + ww1 = in_range.select(v1, c_zero_f) + ww0 = wave_reduce_add(ww0) + ww1 = wave_reduce_add(ww1) + if lane == fx.Int32(0): + c0_idx = fx.Index(0) + SmemPtr.store(s_sum, ww0, [c0_idx]) + SmemPtr.store(s_sumsq, ww1, [c0_idx]) + gpu.barrier() + + c0_idx = fx.Index(0) + return SmemPtr.load(s_sum, [c0_idx]), SmemPtr.load(s_sumsq, [c0_idx]) + + def block_reduce_max(val): + if const_expr(RED_SLOTS == 1): + return wave_reduce_max(val) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + w = wave_reduce_max(val) + if lane == fx.Int32(0): + wave_idx = ArithValue(wave).index_cast(T.index) + SmemPtr.store(s_sum, w, [wave_idx]) + gpu.barrier() + + if wave == fx.Int32(0): + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, fx.Int32(0)) + lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) + v = SmemPtr.load(s_sum, [lane_safe_idx]) + ww = in_range.select(v, c_neg_inf) + ww = wave_reduce_max(ww) + if lane == fx.Int32(0): + c0_idx = fx.Index(0) + SmemPtr.store(s_sum, ww, [c0_idx]) + gpu.barrier() + + c0_idx = fx.Index(0) + return SmemPtr.load(s_sum, [c0_idx]) + + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Beta_buf = fx.rocdl.make_buffer_tensor(Beta) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + if const_expr(is_fused_add): + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + if const_expr(is_smooth): + XScale_buf = fx.rocdl.make_buffer_tensor(XScale) + + row_in = fx.slice(Input_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + if const_expr(is_fused_add): + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + copy_atom_qs = fx.make_copy_atom(fx.rocdl.BufferCopy(8), 8) + + in_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + beta_div = fx.logical_divide(Beta_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + if const_expr(is_fused_add): + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(1, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) + if const_expr(is_smooth): + xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(1, 1)) + + def _load_scalar(divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.make_rmem_tensor(1, elem_dtype) + fx.copy_atom_call(copy_atom_s, view, r) + return fx.memref_load_vec(r)[0].ir_value() + + def _store_elem_scalar(divided_tensor, index, val): + r = fx.make_rmem_tensor(1, elem_dtype) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom_s, r, view) + + def _store_quant_scalar(divided_tensor, index, val): + r = fx.make_rmem_tensor(1, quant_dtype) + ts = full(1, quant_dtype(val), quant_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom_qs, r, view) + + def _abs_scalar(val): + is_neg = val < c_zero_f + neg_val = c_zero_f - ArithValue(val) + return is_neg.select(neg_val, val) + + def _load_input_value(index): + x_e = _load_scalar(in_div, index) + x = x_e if dtype_str == "f32" else x_e.extf(compute_type) + if const_expr(is_fused_add): + r_e = _load_scalar(residual_in_div, index) + residual = r_e if dtype_str == "f32" else r_e.extf(compute_type) + return ArithValue(x) + ArithValue(residual) + return x + + thread_sum = c_zero_f + thread_sumsq = c_zero_f + c_N_i32 = Int32(N) + c0_i = Int32(0) + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < c_N_i32 + idx_safe = is_valid.select(idx, c0_i) + x = _load_input_value(idx_safe) + x2 = ArithValue(x) * ArithValue(x) + thread_sum = ArithValue(thread_sum) + is_valid.select(x, c_zero_f) + thread_sumsq = ArithValue(thread_sumsq) + is_valid.select(x2, c_zero_f) + if const_expr(is_fused_add): + if arith.cmpi(arith.CmpIPredicate.ult, idx, c_N_i32): + _store_elem_scalar(residual_out_div, idx, x) + + sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) + mean = ArithValue(sum_val) / n_float + var = ArithValue(sumsq_val) / n_float - ArithValue(mean) * ArithValue(mean) + var = (var < c_zero_f).select(c_zero_f, var) + rstd = (var + eps_c).rsqrt(fastmath=fm_fast) + + thread_row_max = c_zero_f + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < c_N_i32 + idx_safe = is_valid.select(idx, c0_i) + x = _load_input_value(idx_safe) + g_e = _load_scalar(gamma_div, idx_safe) + b_e = _load_scalar(beta_div, idx_safe) + g = g_e if dtype_str == "f32" else g_e.extf(compute_type) + b = b_e if dtype_str == "f32" else b_e.extf(compute_type) + y = (ArithValue(x) - ArithValue(mean)) * ArithValue(rstd) + y = ArithValue(y) * ArithValue(g) + ArithValue(b) + if const_expr(is_smooth): + s_e = _load_scalar(xscale_div, idx_safe) + s = s_e if dtype_str == "f32" else s_e.extf(compute_type) + y = ArithValue(y) * ArithValue(s) + y_abs = _abs_scalar(y) + thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) + + row_max = block_reduce_max(thread_row_max) + scale = ArithValue(row_max) / c_dtype_max + final_scale = (scale == c_zero_f).select(c_one_f, scale) + + if tid == fx.Int32(0): + _store_yscale(bid, final_scale) + + inv_scale = ArithValue(c_one_f) / ArithValue(final_scale) + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + if arith.cmpi(arith.CmpIPredicate.ult, idx, c_N_i32): + x = _load_input_value(idx) + g_e = _load_scalar(gamma_div, idx) + b_e = _load_scalar(beta_div, idx) + g = g_e if dtype_str == "f32" else g_e.extf(compute_type) + b = b_e if dtype_str == "f32" else b_e.extf(compute_type) + y = (ArithValue(x) - ArithValue(mean)) * ArithValue(rstd) + y = ArithValue(y) * ArithValue(g) + ArithValue(b) + if const_expr(is_smooth): + s_e = _load_scalar(xscale_div, idx) + s = s_e if dtype_str == "f32" else s_e.extf(compute_type) + y = ArithValue(y) * ArithValue(s) + q = ArithValue(y) * ArithValue(inv_scale) + q_i8 = quant_dtype(q) + _store_quant_scalar(out_div, idx, q_i8) + + if is_fused_add: + if is_smooth: + + @flyc.jit + def launch_fused_add_layernorm_smoothquant( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + XScale: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = layernorm_quant_kernel( + Input, ResidualIn, Gamma, Beta, XScale, YScale, Output, ResidualOut + ) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_fused_add_layernorm_smoothquant + + @flyc.jit + def launch_fused_add_layernorm_dynamicquant( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = layernorm_quant_kernel( + Input, ResidualIn, Gamma, Beta, Gamma, YScale, Output, ResidualOut + ) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_fused_add_layernorm_dynamicquant + + if is_smooth: + + @flyc.jit + def launch_layernorm_smoothquant( + Input: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + XScale: fx.Tensor, + Output: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = layernorm_quant_kernel(Input, Input, Gamma, Beta, XScale, YScale, Output, Output) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_layernorm_smoothquant + + @flyc.jit + def launch_layernorm_dynamicquant( + Input: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + Output: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = layernorm_quant_kernel(Input, Input, Gamma, Beta, Gamma, YScale, Output, Output) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_layernorm_dynamicquant + + +def build_layernorm_dynamicquant_module( + M: int, + N: int, + dtype_str: str, + quant_dtype_str: str = "i8", +): + return _build_layernorm_quant_module( + M, + N, + dtype_str, + is_smooth=False, + is_fused_add=False, + quant_dtype_str=quant_dtype_str, + ) + + +def build_layernorm_smoothquant_module( + M: int, + N: int, + dtype_str: str, + quant_dtype_str: str = "i8", +): + return _build_layernorm_quant_module( + M, + N, + dtype_str, + is_smooth=True, + is_fused_add=False, + quant_dtype_str=quant_dtype_str, + ) + + +def build_fused_add_layernorm_dynamicquant_module( + M: int, + N: int, + dtype_str: str, + quant_dtype_str: str = "i8", +): + return _build_layernorm_quant_module( + M, + N, + dtype_str, + is_smooth=False, + is_fused_add=True, + quant_dtype_str=quant_dtype_str, + ) + + +def build_fused_add_layernorm_smoothquant_module( + M: int, + N: int, + dtype_str: str, + quant_dtype_str: str = "i8", +): + return _build_layernorm_quant_module( + M, + N, + dtype_str, + is_smooth=True, + is_fused_add=True, + quant_dtype_str=quant_dtype_str, + ) diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 76e66df0b..502fcec9f 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -3,15 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors -""" -LayerNorm Operator Test -Implementation of a Block-wise LayerNorm: -- Grid: (M, 1, 1) -> One block per row -- Block: (N, 1, 1) -> Threads handle columns -- Shared Memory: Used for reduction (mean and variance) - -LayerNorm(x) = (x - mean) / sqrt(var + eps) * gamma + beta -""" +"""LayerNorm operator tests, including AIter/Triton-aligned variants.""" import os @@ -39,176 +31,315 @@ EPS: float = 1e-5 from kernels.layernorm_kernel import ( + build_fused_add_layernorm_dynamicquant_module, + build_fused_add_layernorm_module, + build_fused_add_layernorm_smoothquant_module, + build_layernorm_dynamicquant_module, build_layernorm_module, - KERNEL_NAME as LAYERNORM_KERNEL_NAME, + build_layernorm_smoothquant_module, BLOCK_THREADS, + KERNEL_NAME as LAYERNORM_KERNEL_NAME, ) WARMUP_ITERS = 10 BENCH_ITERS = 100 -def run_test(M: int, N: int, dtype: str = "f32"): - print(f"\nTesting LayerNorm (M={M}, N={N}, dtype={dtype})") - try: - launch_fn = build_layernorm_module(M, N, dtype) - except ValueError as e: - print(f"[FAIL] Compile failed: {e}") - return False, None +def _torch_dtype(dtype: str): + if dtype == "f32": + return DTYPE_FP32 + if dtype == "f16": + return DTYPE_FP16 + if dtype == "bf16": + return DTYPE_BF16 + raise ValueError(f"unsupported dtype: {dtype}") + + +def _atol(dtype: str) -> float: + if dtype == "f32": + return 1e-4 + if dtype == "f16": + return 1e-2 + if dtype == "bf16": + return 2e-2 + raise ValueError(f"unsupported dtype: {dtype}") + + +def _get_layernorm_configs(): + shapes_env = os.environ.get("ROCDSL_LAYERNORM_SHAPES", "").strip() + if shapes_env: + configs = [] + for part in shapes_env.split(";"): + p = part.strip() + if not p: + continue + m_s, n_s, dt = [x.strip() for x in p.split(",")] + configs.append((int(m_s), int(n_s), dt)) + return configs + + return [ + (4, 128, "f32"), + (16, 2000, "f16"), + (32, 8192, "bf16"), + ] + +def _make_inputs(M: int, N: int, dtype: str): + torch_dtype = _torch_dtype(dtype) torch.manual_seed(42) input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) beta_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) - - if dtype == "f32": - input_dev = input_t.contiguous() - gamma_dev = gamma_t.contiguous() - beta_dev = beta_t.contiguous() - output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) - input_ref = input_dev.to(DTYPE_FP32) - gamma_ref = gamma_dev.to(DTYPE_FP32) - beta_ref = beta_dev.to(DTYPE_FP32) - atol = 1e-4 - elif dtype == "f16": - input_dev = input_t.to(DTYPE_FP16).contiguous() - gamma_dev = gamma_t.to(DTYPE_FP16).contiguous() - beta_dev = beta_t.to(DTYPE_FP16).contiguous() - output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) - input_ref = input_dev.to(DTYPE_FP32) - gamma_ref = gamma_dev.to(DTYPE_FP32) - beta_ref = beta_dev.to(DTYPE_FP32) - atol = 1e-2 - elif dtype == "bf16": - input_dev = input_t.to(DTYPE_BF16).contiguous() - gamma_dev = gamma_t.to(DTYPE_BF16).contiguous() - beta_dev = beta_t.to(DTYPE_BF16).contiguous() - output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) - input_ref = input_dev.to(DTYPE_FP32) - gamma_ref = gamma_dev.to(DTYPE_FP32) - beta_ref = beta_dev.to(DTYPE_FP32) - atol = 2e-2 - else: - raise ValueError(f"unsupported dtype: {dtype}") - - # PyTorch CPU Reference (variance uses unbiased=False) - x = input_ref - gamma = gamma_ref - beta = beta_ref + residual_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + xscale_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5 + + return ( + input_t.to(torch_dtype).contiguous(), + gamma_t.to(torch_dtype).contiguous(), + beta_t.to(torch_dtype).contiguous(), + residual_t.to(torch_dtype).contiguous(), + xscale_t.to(torch_dtype).contiguous(), + ) + + +def _reference_layernorm(input_dev, gamma_dev, beta_dev, *, residual_dev=None, xscale_dev=None): + x = input_dev.to(DTYPE_FP32) + residual_out = None + if residual_dev is not None: + residual_out = x + residual_dev.to(DTYPE_FP32) + x = residual_out + gamma = gamma_dev.to(DTYPE_FP32) + beta = beta_dev.to(DTYPE_FP32) mean = x.mean(dim=1, keepdim=True) var = x.var(dim=1, keepdim=True, unbiased=False) expected = (x - mean) / torch.sqrt(var + EPS) * gamma + beta - expected = expected.to(DTYPE_FP32) + if xscale_dev is not None: + expected = expected * xscale_dev.to(DTYPE_FP32) + return expected, residual_out + + +def _reference_quant(input_dev, gamma_dev, beta_dev, *, residual_dev=None, xscale_dev=None): + expected, residual_out = _reference_layernorm( + input_dev, + gamma_dev, + beta_dev, + residual_dev=residual_dev, + xscale_dev=xscale_dev, + ) + yscale = expected.abs().amax(dim=1) / 127.0 + yscale = torch.where(yscale == 0, torch.ones_like(yscale), yscale) + q = torch.clamp(torch.trunc(expected / yscale.unsqueeze(1)), -127, 127).to(torch.int8) + return expected, residual_out, q, yscale + + +def _bench_aiter(M: int, N: int, dtype: str, mode: str): + torch_dtype = _torch_dtype(dtype) + try: + from aiter.ops.triton.normalization.norm import ( + layer_norm, + layernorm2d_fwd_with_add, + layernorm2d_fwd_with_add_dynamicquant, + layernorm2d_fwd_with_add_smoothquant, + layernorm2d_fwd_with_dynamicquant, + layernorm2d_fwd_with_smoothquant, + ) + except Exception as e: + print(f"[Perf] AIter layernorm {mode} skipped: {type(e).__name__}: {e!r}") + return None + + x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + b = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + residual = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + residual_out = torch.empty_like(x) + xscale = (torch.rand((N,), device="cuda", dtype=torch_dtype) + 0.5).contiguous() + q_out = torch.empty((M, N), device="cuda", dtype=torch.int8) + yscale = torch.empty((M, 1), device="cuda", dtype=torch.float32) + + if mode == "base": + run = lambda: layer_norm(x, w, b, EPS) + elif mode == "fused_add": + out = torch.empty_like(x) + run = lambda: layernorm2d_fwd_with_add(out, x, residual, residual_out, w, b, EPS) + elif mode == "dynamicquant": + run = lambda: layernorm2d_fwd_with_dynamicquant(q_out, x, yscale, w, b, EPS) + elif mode == "smoothquant": + run = lambda: layernorm2d_fwd_with_smoothquant(q_out, x, xscale, yscale, w, b, EPS) + elif mode == "fused_add_dynamicquant": + run = lambda: layernorm2d_fwd_with_add_dynamicquant(q_out, x, residual, residual_out, yscale, w, b, EPS) + elif mode == "fused_add_smoothquant": + run = lambda: layernorm2d_fwd_with_add_smoothquant(q_out, x, residual, residual_out, xscale, yscale, w, b, EPS) + else: + raise ValueError(f"unsupported mode: {mode}") - print("Launching kernel...") + aiter_us = bench_gpu_us_torch(run, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f"[Perf] AIter layernorm {mode} gpu: {aiter_us:.1f} us") + return aiter_us + + +def run_test(M: int, N: int, dtype: str = "f32"): + print(f"\nTesting LayerNorm (M={M}, N={N}, dtype={dtype})") + launch_fn = build_layernorm_module(M, N, dtype) + input_dev, gamma_dev, beta_dev, _, _ = _make_inputs(M, N, dtype) + output_dev = torch.empty((M, N), device="cuda", dtype=_torch_dtype(dtype)) + expected, _ = _reference_layernorm(input_dev, gamma_dev, beta_dev) + atol = _atol(dtype) stream = torch.cuda.current_stream() def kernel_launch(): launch_fn(input_dev, gamma_dev, beta_dev, output_dev, M, stream=stream) - # One run for correctness visibility, then benchmark via shared harness. kernel_launch() torch.cuda.synchronize() - _, avg_us = run_perftest(lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS) - torch.cuda.synchronize() flydsl_gpu_us = None if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) - avg_ms = avg_us / 1000.0 - elem_bytes = 4 if dtype == "f32" else 2 - total_bytes = (2 * M * N + 2 * N) * elem_bytes # read input + write output + (gamma+beta) - bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9 - print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") - print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") - if flydsl_gpu_us is not None: - print(f"[Perf] FlyDSL layernorm gpu: {flydsl_gpu_us:.1f} us") - - # Verification (pure torch style; compute max error in torch) - output_ref = output_dev.to(DTYPE_FP32) - - error = (output_ref - expected).abs().max().item() + error = (output_dev.to(DTYPE_FP32) - expected).abs().max().item() + print(f"Kernel avg time: {avg_us / 1000.0:.4f} ms") print(f"Max absolute error: {error:.2e} (atol={atol})") + return error < atol, flydsl_gpu_us - if error < atol: - print("PASSED") - ok = True - else: - print("FAILED") - print("First row Expected:") - print(expected[0, :5]) - print("First row Actual:") - print(output_ref[0, :5]) - ok = False - return ok, flydsl_gpu_us +def run_fused_add_test(M: int, N: int, dtype: str): + print(f"\nTesting LayerNorm fused_add (M={M}, N={N}, dtype={dtype})") + launch_fn = build_fused_add_layernorm_module(M, N, dtype) + input_dev, gamma_dev, beta_dev, residual_dev, _ = _make_inputs(M, N, dtype) + output_dev = torch.empty((M, N), device="cuda", dtype=_torch_dtype(dtype)) + residual_out_dev = torch.empty_like(output_dev) + expected, residual_expected = _reference_layernorm(input_dev, gamma_dev, beta_dev, residual_dev=residual_dev) + atol = _atol(dtype) + stream = torch.cuda.current_stream() -def test_all(): - print("="*80) - print("Running LayerNorm Tests") - print("="*80) + def kernel_launch(): + launch_fn(input_dev, residual_dev, gamma_dev, beta_dev, output_dev, residual_out_dev, M, stream=stream) - shapes_env = os.environ.get("ROCDSL_LAYERNORM_SHAPES", "").strip() - if shapes_env: - configs = [] - for part in shapes_env.split(";"): - p = part.strip() - if not p: - continue - m_s, n_s, dt = [x.strip() for x in p.split(",")] - configs.append((int(m_s), int(n_s), dt)) + _, avg_us = run_perftest(lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS) + flydsl_gpu_us = None + if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": + flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + out_err = (output_dev.to(DTYPE_FP32) - expected).abs().max().item() + residual_err = (residual_out_dev.to(DTYPE_FP32) - residual_expected).abs().max().item() + print(f"Kernel avg time: {avg_us / 1000.0:.4f} ms") + print(f"Max output error: {out_err:.2e} (atol={atol})") + print(f"Max residual error: {residual_err:.2e} (atol={atol})") + return out_err < atol and residual_err < atol, flydsl_gpu_us + + +def run_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool, is_fused_add: bool): + mode = "" + if is_fused_add: + mode += "fused_add_" + mode += "smoothquant" if is_smooth else "dynamicquant" + print(f"\nTesting LayerNorm {mode} (M={M}, N={N}, dtype={dtype})") + + if is_fused_add and is_smooth: + launch_fn = build_fused_add_layernorm_smoothquant_module(M, N, dtype) + elif is_fused_add: + launch_fn = build_fused_add_layernorm_dynamicquant_module(M, N, dtype) + elif is_smooth: + launch_fn = build_layernorm_smoothquant_module(M, N, dtype) else: - configs = [ - # (64, 256, "f32"), # Aligned - # (128, 1024, "f32"), # Aligned - # (32, 128, "f16"), # Aligned - # (64, 2000, "f32"), # Unaligned (tail handling) - # (16, 512, "bf16"), # BF16 - # (1024, 8192, "bf16"), # BF16 - (32768, 8192, "bf16"), - ] + launch_fn = build_layernorm_dynamicquant_module(M, N, dtype) + + input_dev, gamma_dev, beta_dev, residual_dev, xscale_dev = _make_inputs(M, N, dtype) + output_dev = torch.empty((M, N), device="cuda", dtype=torch.int8) + yscale_dev = torch.empty((M,), device="cuda", dtype=DTYPE_FP32) + residual_out_dev = torch.empty((M, N), device="cuda", dtype=_torch_dtype(dtype)) + + expected, residual_expected, q_ref, yscale_ref = _reference_quant( + input_dev, + gamma_dev, + beta_dev, + residual_dev=residual_dev if is_fused_add else None, + xscale_dev=xscale_dev if is_smooth else None, + ) + + stream = torch.cuda.current_stream() + + def kernel_launch(): + if is_fused_add and is_smooth: + launch_fn(input_dev, residual_dev, gamma_dev, beta_dev, xscale_dev, output_dev, residual_out_dev, yscale_dev, M, stream=stream) + elif is_fused_add: + launch_fn(input_dev, residual_dev, gamma_dev, beta_dev, output_dev, residual_out_dev, yscale_dev, M, stream=stream) + elif is_smooth: + launch_fn(input_dev, gamma_dev, beta_dev, xscale_dev, output_dev, yscale_dev, M, stream=stream) + else: + launch_fn(input_dev, gamma_dev, beta_dev, output_dev, yscale_dev, M, stream=stream) + + _, avg_us = run_perftest(lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS) + flydsl_gpu_us = None + if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": + flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + + q_diff = (output_dev.to(torch.int16) - q_ref.to(torch.int16)).abs().max().item() + scale_diff = (yscale_dev.cpu() - yscale_ref.cpu()).abs().max().item() + recon = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1) + recon_err = (recon - expected).abs().max().item() + ok = q_diff <= 1 and scale_diff < 1e-2 and recon_err < 0.3 + + if is_fused_add: + residual_err = (residual_out_dev.to(DTYPE_FP32) - residual_expected).abs().max().item() + ok = ok and residual_err < _atol(dtype) + print(f"Max residual error: {residual_err:.2e} (atol={_atol(dtype)})") + + print(f"Kernel avg time: {avg_us / 1000.0:.4f} ms") + print(f"Max quant diff: {q_diff}") + print(f"Max scale diff: {scale_diff:.2e}") + print(f"Max recon error: {recon_err:.2e}") + return ok, flydsl_gpu_us + +def _run_configs(op: str, runner): do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" perf_rows = [] - failures = 0 - for M, N, dtype in configs: - ok, flydsl_gpu_us = run_test(M, N, dtype) + for M, N, dtype in _get_layernorm_configs(): + ok, flydsl_gpu_us = runner(M, N, dtype) if not ok: failures += 1 - if do_compare: - import torch aiter_us = None if maybe_enable_aiter(): - try: - from aiter.ops.triton.norm import layer_norm as aiter_layer_norm - x = torch.randn((M, N), device="cuda", dtype=DTYPE_BF16 if dtype == "bf16" else (DTYPE_FP16 if dtype == "f16" else DTYPE_FP32)) - w = torch.rand((N,), device="cuda", dtype=x.dtype) - b = torch.rand((N,), device="cuda", dtype=x.dtype) + aiter_us = _bench_aiter(M, N, dtype, op) + perf_rows.append(PerfRow(op=f"layernorm_{op}", shape=f"{M}x{N}", dtype=dtype, flydsl_gpu_us=flydsl_gpu_us, aiter_gpu_us=aiter_us)) + if do_compare and perf_rows: + print_perf_table(perf_rows) + if failures != 0: + raise SystemExit(f"{failures} {op} tests failed") - def run_aiter(): - aiter_layer_norm(x, w, b, EPS) - aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) - print(f"[Perf] AIter layernorm gpu: {aiter_us:.1f} us") - except Exception as e: - print(f"[Perf] AIter layernorm skipped: {type(e).__name__}: {e!r}") +def test_all(): + print("=" * 80) + print("Running LayerNorm Tests") + print("=" * 80) + _run_configs("base", run_test) - perf_rows.append(PerfRow(op="layernorm", shape=f"{M}x{N}", dtype=dtype, flydsl_gpu_us=flydsl_gpu_us, aiter_gpu_us=aiter_us)) - print("\n" + "="*80) - if failures == 0: - print("ALL TESTS PASSED") - else: - print(f"{failures} TESTS FAILED") - print("="*80) - if do_compare and perf_rows: - print_perf_table(perf_rows) - # Ensure a non-zero exit code on failure for shell wrappers. - if failures != 0: - raise SystemExit(1) +def test_fused_add_layernorm(): + _run_configs("fused_add", run_fused_add_test) + + +def test_layernorm_dynamicquant(): + _run_configs("dynamicquant", lambda M, N, dtype: run_quant_test(M, N, dtype, is_smooth=False, is_fused_add=False)) + + +def test_layernorm_smoothquant(): + _run_configs("smoothquant", lambda M, N, dtype: run_quant_test(M, N, dtype, is_smooth=True, is_fused_add=False)) + + +def test_fused_add_layernorm_dynamicquant(): + _run_configs("fused_add_dynamicquant", lambda M, N, dtype: run_quant_test(M, N, dtype, is_smooth=False, is_fused_add=True)) + + +def test_fused_add_layernorm_smoothquant(): + _run_configs("fused_add_smoothquant", lambda M, N, dtype: run_quant_test(M, N, dtype, is_smooth=True, is_fused_add=True)) + if __name__ == "__main__": test_all() - + test_fused_add_layernorm() + test_layernorm_dynamicquant() + test_layernorm_smoothquant() + test_fused_add_layernorm_dynamicquant() + test_fused_add_layernorm_smoothquant() From 394339cf149f6e738971aabef7abe3832c14d095 Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Tue, 19 May 2026 16:18:51 +0800 Subject: [PATCH 2/5] Align layernorm quant paths with current register API Use the current fx.* numeric and register helper style in layernorm quant variants so they stay consistent with main's RMSNorm cleanup. --- kernels/layernorm_kernel.py | 160 ++++++++++++++++-------------------- 1 file changed, 70 insertions(+), 90 deletions(-) diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index 0f0c2052a..d4af782f0 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -19,9 +19,6 @@ from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import arith, const_expr, gpu, range_constexpr from flydsl.expr import math as fmath -from flydsl.expr.arith import ArithValue -from flydsl.expr.numeric import Float32, Numeric, Uint32 -from flydsl.expr.typing import Int32, T from flydsl.expr.vector import ReductionOp, full from flydsl.runtime.device import get_rocm_arch as get_hip_arch from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr @@ -333,7 +330,7 @@ def launch_layernorm( def _quant_dtype_to_elem_type(dtype_str: str): if dtype_str in ("i8", "int8"): - return T.i8 + return fx.Int8 raise ValueError(f"unsupported quant dtype: {dtype_str!r} (expected 'i8' or 'int8')") @@ -564,20 +561,19 @@ def layernorm_quant_kernel( tid = fx.thread_idx.x elem_dtype = dtype_to_elem_type(dtype_str) - quant_dtype = Numeric.from_ir_type(_quant_dtype_to_elem_type(quant_dtype_str)) - compute_type = T.f32 + quant_dtype = _quant_dtype_to_elem_type(quant_dtype_str) fm_fast = arith.FastMathFlags.fast - eps_c = arith.constant(EPS, type=compute_type) - n_float = arith.constant(float(N), type=compute_type) - c_zero_f = arith.constant(0.0, type=compute_type) - c_one_f = arith.constant(1.0, type=compute_type) - c_neg_inf = arith.constant(float("-inf"), type=compute_type) - c_dtype_max = arith.constant(quant_dtype_max, type=compute_type) + eps_c = EPS + n_float = float(N) + c_zero_f = fx.Float32(0.0) + c_one_f = fx.Float32(1.0) + c_neg_inf = fx.Float32(float("-inf")) + c_dtype_max = fx.Float32(quant_dtype_max) base_ptr = allocator.get_base() - s_sum = SmemPtr(base_ptr, sum_offset, T.f32, shape=(RED_SLOTS,)) - s_sumsq = SmemPtr(base_ptr, sumsq_offset, T.f32, shape=(RED_SLOTS,)) + s_sum = SmemPtr(base_ptr, sum_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_sumsq = SmemPtr(base_ptr, sumsq_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) s_sum.get() s_sumsq.get() @@ -586,26 +582,24 @@ def layernorm_quant_kernel( scale_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) def _store_yscale(index, val): - r = fx.make_rmem_tensor(1, Float32) - ts = full(1, Float32(val), Float32) + r = fx.make_rmem_tensor(1, fx.Float32) + ts = full(1, fx.Float32(val), fx.Float32) fx.memref_store_vec(ts, r) fx.copy_atom_call(scale_copy_atom, r, fx.slice(yscale_div, (None, index))) def wave_reduce_add(x): - width_i32 = fx.Int32(WARP_SIZE) w = x for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): - off = fx.Int32(WARP_SIZE // (2 << _sh_exp)) - peer = w.shuffle_xor(off, width_i32) + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) w = w.addf(peer, fastmath=fm_fast) return w def wave_reduce_max(x): - width_i32 = fx.Int32(WARP_SIZE) w = x for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): - off = fx.Int32(WARP_SIZE // (2 << _sh_exp)) - peer = w.shuffle_xor(off, width_i32) + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) w = w.maximumf(peer) return w @@ -618,30 +612,26 @@ def block_reduce_add2(val0, val1): w0 = wave_reduce_add(val0) w1 = wave_reduce_add(val1) - if lane == fx.Int32(0): - wave_idx = ArithValue(wave).index_cast(T.index) - SmemPtr.store(s_sum, w0, [wave_idx]) - SmemPtr.store(s_sumsq, w1, [wave_idx]) + if lane == 0: + SmemPtr.store(s_sum, w0, [wave]) + SmemPtr.store(s_sumsq, w1, [wave]) gpu.barrier() - if wave == fx.Int32(0): + if wave == 0: in_range = lane < RED_SLOTS - lane_safe = in_range.select(lane, fx.Int32(0)) - lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) - v0 = SmemPtr.load(s_sum, [lane_safe_idx]) - v1 = SmemPtr.load(s_sumsq, [lane_safe_idx]) + lane_safe = in_range.select(lane, 0) + v0 = SmemPtr.load(s_sum, [lane_safe]) + v1 = SmemPtr.load(s_sumsq, [lane_safe]) ww0 = in_range.select(v0, c_zero_f) ww1 = in_range.select(v1, c_zero_f) ww0 = wave_reduce_add(ww0) ww1 = wave_reduce_add(ww1) - if lane == fx.Int32(0): - c0_idx = fx.Index(0) - SmemPtr.store(s_sum, ww0, [c0_idx]) - SmemPtr.store(s_sumsq, ww1, [c0_idx]) + if lane == 0: + SmemPtr.store(s_sum, ww0, [0]) + SmemPtr.store(s_sumsq, ww1, [0]) gpu.barrier() - c0_idx = fx.Index(0) - return SmemPtr.load(s_sum, [c0_idx]), SmemPtr.load(s_sumsq, [c0_idx]) + return SmemPtr.load(s_sum, [0]), SmemPtr.load(s_sumsq, [0]) def block_reduce_max(val): if const_expr(RED_SLOTS == 1): @@ -650,25 +640,21 @@ def block_reduce_max(val): lane = tid % WARP_SIZE wave = tid // WARP_SIZE w = wave_reduce_max(val) - if lane == fx.Int32(0): - wave_idx = ArithValue(wave).index_cast(T.index) - SmemPtr.store(s_sum, w, [wave_idx]) + if lane == 0: + SmemPtr.store(s_sum, w, [wave]) gpu.barrier() - if wave == fx.Int32(0): + if wave == 0: in_range = lane < RED_SLOTS - lane_safe = in_range.select(lane, fx.Int32(0)) - lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) - v = SmemPtr.load(s_sum, [lane_safe_idx]) + lane_safe = in_range.select(lane, 0) + v = SmemPtr.load(s_sum, [lane_safe]) ww = in_range.select(v, c_neg_inf) ww = wave_reduce_max(ww) - if lane == fx.Int32(0): - c0_idx = fx.Index(0) - SmemPtr.store(s_sum, ww, [c0_idx]) + if lane == 0: + SmemPtr.store(s_sum, ww, [0]) gpu.barrier() - c0_idx = fx.Index(0) - return SmemPtr.load(s_sum, [c0_idx]) + return SmemPtr.load(s_sum, [0]) Input_buf = fx.rocdl.make_buffer_tensor(Input) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) @@ -706,7 +692,7 @@ def _load_scalar(divided_tensor, index): view = fx.slice(divided_tensor, (None, index)) r = fx.make_rmem_tensor(1, elem_dtype) fx.copy_atom_call(copy_atom_s, view, r) - return fx.memref_load_vec(r)[0].ir_value() + return fx.memref_load_vec(r)[0] def _store_elem_scalar(divided_tensor, index, val): r = fx.make_rmem_tensor(1, elem_dtype) @@ -724,85 +710,83 @@ def _store_quant_scalar(divided_tensor, index, val): def _abs_scalar(val): is_neg = val < c_zero_f - neg_val = c_zero_f - ArithValue(val) + neg_val = c_zero_f - val return is_neg.select(neg_val, val) def _load_input_value(index): x_e = _load_scalar(in_div, index) - x = x_e if dtype_str == "f32" else x_e.extf(compute_type) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) if const_expr(is_fused_add): r_e = _load_scalar(residual_in_div, index) - residual = r_e if dtype_str == "f32" else r_e.extf(compute_type) - return ArithValue(x) + ArithValue(residual) + residual = r_e if dtype_str == "f32" else r_e.to(fx.Float32) + return x + residual return x thread_sum = c_zero_f thread_sumsq = c_zero_f - c_N_i32 = Int32(N) - c0_i = Int32(0) for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int - is_valid = idx < c_N_i32 - idx_safe = is_valid.select(idx, c0_i) + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) x = _load_input_value(idx_safe) - x2 = ArithValue(x) * ArithValue(x) - thread_sum = ArithValue(thread_sum) + is_valid.select(x, c_zero_f) - thread_sumsq = ArithValue(thread_sumsq) + is_valid.select(x2, c_zero_f) + x2 = x * x + thread_sum = thread_sum + is_valid.select(x, c_zero_f) + thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f) if const_expr(is_fused_add): - if arith.cmpi(arith.CmpIPredicate.ult, idx, c_N_i32): + if idx < N: _store_elem_scalar(residual_out_div, idx, x) sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) - mean = ArithValue(sum_val) / n_float - var = ArithValue(sumsq_val) / n_float - ArithValue(mean) * ArithValue(mean) + mean = sum_val / n_float + var = sumsq_val / n_float - mean * mean var = (var < c_zero_f).select(c_zero_f, var) rstd = (var + eps_c).rsqrt(fastmath=fm_fast) thread_row_max = c_zero_f for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int - is_valid = idx < c_N_i32 - idx_safe = is_valid.select(idx, c0_i) + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) x = _load_input_value(idx_safe) g_e = _load_scalar(gamma_div, idx_safe) b_e = _load_scalar(beta_div, idx_safe) - g = g_e if dtype_str == "f32" else g_e.extf(compute_type) - b = b_e if dtype_str == "f32" else b_e.extf(compute_type) - y = (ArithValue(x) - ArithValue(mean)) * ArithValue(rstd) - y = ArithValue(y) * ArithValue(g) + ArithValue(b) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + b = b_e if dtype_str == "f32" else b_e.to(fx.Float32) + y = (x - mean) * rstd + y = y * g + b if const_expr(is_smooth): s_e = _load_scalar(xscale_div, idx_safe) - s = s_e if dtype_str == "f32" else s_e.extf(compute_type) - y = ArithValue(y) * ArithValue(s) + s = s_e if dtype_str == "f32" else s_e.to(fx.Float32) + y = y * s y_abs = _abs_scalar(y) thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) row_max = block_reduce_max(thread_row_max) - scale = ArithValue(row_max) / c_dtype_max + scale = row_max / c_dtype_max final_scale = (scale == c_zero_f).select(c_one_f, scale) - if tid == fx.Int32(0): + if tid == 0: _store_yscale(bid, final_scale) - inv_scale = ArithValue(c_one_f) / ArithValue(final_scale) + inv_scale = c_one_f / final_scale for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int - if arith.cmpi(arith.CmpIPredicate.ult, idx, c_N_i32): + if idx < N: x = _load_input_value(idx) g_e = _load_scalar(gamma_div, idx) b_e = _load_scalar(beta_div, idx) - g = g_e if dtype_str == "f32" else g_e.extf(compute_type) - b = b_e if dtype_str == "f32" else b_e.extf(compute_type) - y = (ArithValue(x) - ArithValue(mean)) * ArithValue(rstd) - y = ArithValue(y) * ArithValue(g) + ArithValue(b) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + b = b_e if dtype_str == "f32" else b_e.to(fx.Float32) + y = (x - mean) * rstd + y = y * g + b if const_expr(is_smooth): s_e = _load_scalar(xscale_div, idx) - s = s_e if dtype_str == "f32" else s_e.extf(compute_type) - y = ArithValue(y) * ArithValue(s) - q = ArithValue(y) * ArithValue(inv_scale) - q_i8 = quant_dtype(q) + s = s_e if dtype_str == "f32" else s_e.to(fx.Float32) + y = y * s + q = y * inv_scale + q_i8 = q.to(quant_dtype) _store_quant_scalar(out_div, idx, q_i8) if is_fused_add: @@ -826,9 +810,7 @@ def launch_fused_add_layernorm_smoothquant( with InsertionPoint(ctx.gpu_module_body): allocator.finalize() - launcher = layernorm_quant_kernel( - Input, ResidualIn, Gamma, Beta, XScale, YScale, Output, ResidualOut - ) + launcher = layernorm_quant_kernel(Input, ResidualIn, Gamma, Beta, XScale, YScale, Output, ResidualOut) launcher.launch( grid=(m_in, 1, 1), block=(BLOCK_THREADS, 1, 1), @@ -854,9 +836,7 @@ def launch_fused_add_layernorm_dynamicquant( with InsertionPoint(ctx.gpu_module_body): allocator.finalize() - launcher = layernorm_quant_kernel( - Input, ResidualIn, Gamma, Beta, Gamma, YScale, Output, ResidualOut - ) + launcher = layernorm_quant_kernel(Input, ResidualIn, Gamma, Beta, Gamma, YScale, Output, ResidualOut) launcher.launch( grid=(m_in, 1, 1), block=(BLOCK_THREADS, 1, 1), From fec4f9d26a6fcae6652d106c58a6c6d344f4b704 Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Wed, 20 May 2026 18:47:56 +0800 Subject: [PATCH 3/5] Add layernorm variant kernels and tests --- kernels/layernorm_kernel.py | 40 ++++++---------- tests/kernels/test_layernorm.py | 85 +++++++++++++++++++++++---------- 2 files changed, 75 insertions(+), 50 deletions(-) diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index d4af782f0..155510468 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -15,13 +15,10 @@ import flydsl.compiler as flyc import flydsl.expr as fx -from flydsl._mlir.ir import InsertionPoint -from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import arith, const_expr, gpu, range_constexpr from flydsl.expr import math as fmath from flydsl.expr.vector import ReductionOp, full from flydsl.runtime.device import get_rocm_arch as get_hip_arch -from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr from kernels.kernels_common import dtype_to_elem_type, get_warp_size KERNEL_NAME = "layernorm" @@ -44,12 +41,10 @@ def build_layernorm_module(M: int, N: int, dtype_str: str): elem_bits = 32 if dtype_str == "f32" else 16 # ── Shared-memory allocation for block reductions ───────────────────── - allocator = SmemAllocator(None, arch=arch) - f32_bytes = 4 - sum_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = sum_offset + RED_SLOTS * f32_bytes - sumsq_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = sumsq_offset + RED_SLOTS * f32_bytes + @fx.struct + class SharedStorage: + s_sum: fx.Array[fx.Float32, RED_SLOTS, 16] + s_sumsq: fx.Array[fx.Float32, RED_SLOTS, 16] # ── GPU kernel ──────────────────────────────────────────────────────── @flyc.kernel @@ -66,11 +61,9 @@ def layernorm_kernel( fm_fast = arith.FastMathFlags.fast eps_c = EPS - base_ptr = allocator.get_base() - s_sum = SmemPtr(base_ptr, sum_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) - s_sumsq = SmemPtr(base_ptr, sumsq_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) - s_sum.get() - s_sumsq.get() + lds = fx.SharedAllocator().allocate(SharedStorage).peek() + s_sum = lds.s_sum.view(fx.make_layout(RED_SLOTS, 1)) + s_sumsq = lds.s_sumsq.view(fx.make_layout(RED_SLOTS, 1)) # ── helpers: wave / block reduction ─────────────────────────────── def wave_reduce_add(x): @@ -92,26 +85,26 @@ def block_reduce_add2(val0, val1): w1 = wave_reduce_add(val1) if lane == 0: - SmemPtr.store(s_sum, w0, [wave]) - SmemPtr.store(s_sumsq, w1, [wave]) + fx.memref_store(w0, s_sum, wave) + fx.memref_store(w1, s_sumsq, wave) gpu.barrier() if wave == 0: in_range = lane < RED_SLOTS lane_safe = in_range.select(lane, 0) - v0 = SmemPtr.load(s_sum, [lane_safe]) - v1 = SmemPtr.load(s_sumsq, [lane_safe]) + v0 = fx.memref_load(s_sum, lane_safe) + v1 = fx.memref_load(s_sumsq, lane_safe) ww0 = in_range.select(v0, 0.0) ww1 = in_range.select(v1, 0.0) ww0 = wave_reduce_add(ww0) ww1 = wave_reduce_add(ww1) if lane == 0: - SmemPtr.store(s_sum, ww0, [0]) - SmemPtr.store(s_sumsq, ww1, [0]) + fx.memref_store(ww0, s_sum, 0) + fx.memref_store(ww1, s_sumsq, 0) gpu.barrier() - return SmemPtr.load(s_sum, [0]), SmemPtr.load(s_sumsq, [0]) + return fx.memref_load(s_sum, 0), fx.memref_load(s_sumsq, 0) def compute_mean_rstd(sum_val, sumsq_val): inv_n = 1.0 / float(N) @@ -313,11 +306,6 @@ def launch_layernorm( m_in: fx.Int32, stream: fx.Stream = fx.Stream(None), ): - allocator.finalized = False - ctx = CompilationContext.get_current() - with InsertionPoint(ctx.gpu_module_body): - allocator.finalize() - launcher = layernorm_kernel(Input, Gamma, Beta, Output) launcher.launch( grid=(m_in, 1, 1), diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 502fcec9f..733f448e0 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -7,6 +7,18 @@ import os +import pytest + +from kernels.layernorm_kernel import ( + BLOCK_THREADS, + KERNEL_NAME as LAYERNORM_KERNEL_NAME, + build_fused_add_layernorm_dynamicquant_module, + build_fused_add_layernorm_module, + build_fused_add_layernorm_smoothquant_module, + build_layernorm_dynamicquant_module, + build_layernorm_module, + build_layernorm_smoothquant_module, +) from tests.test_common import run_perftest from tests.kernels.benchmark_common import ( PerfRow, @@ -14,7 +26,6 @@ maybe_enable_aiter, print_perf_table, ) -import pytest pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] @@ -30,16 +41,6 @@ DTYPE_BF16 = torch.bfloat16 EPS: float = 1e-5 -from kernels.layernorm_kernel import ( - build_fused_add_layernorm_dynamicquant_module, - build_fused_add_layernorm_module, - build_fused_add_layernorm_smoothquant_module, - build_layernorm_dynamicquant_module, - build_layernorm_module, - build_layernorm_smoothquant_module, - BLOCK_THREADS, - KERNEL_NAME as LAYERNORM_KERNEL_NAME, -) WARMUP_ITERS = 10 BENCH_ITERS = 100 @@ -77,10 +78,15 @@ def _get_layernorm_configs(): configs.append((int(m_s), int(n_s), dt)) return configs + # Prefer N multiples of 2048 to exercise the fast path. return [ - (4, 128, "f32"), - (16, 2000, "f16"), - (32, 8192, "bf16"), + # (64, 256, "f32"), # Aligned + # (128, 1024, "f32"), # Aligned + # (32, 128, "f16"), # Aligned + # (64, 2000, "f32"), # Unaligned (tail handling) + # (16, 512, "bf16"), # BF16 + # (1024, 8192, "bf16"), # BF16 + (32768, 8192, "bf16"), ] @@ -191,7 +197,9 @@ def kernel_launch(): kernel_launch() torch.cuda.synchronize() - _, avg_us = run_perftest(lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS) + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS + ) flydsl_gpu_us = None if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) @@ -214,7 +222,9 @@ def run_fused_add_test(M: int, N: int, dtype: str): def kernel_launch(): launch_fn(input_dev, residual_dev, gamma_dev, beta_dev, output_dev, residual_out_dev, M, stream=stream) - _, avg_us = run_perftest(lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS) + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS + ) flydsl_gpu_us = None if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) @@ -259,15 +269,30 @@ def run_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool, is_fused_add: def kernel_launch(): if is_fused_add and is_smooth: - launch_fn(input_dev, residual_dev, gamma_dev, beta_dev, xscale_dev, output_dev, residual_out_dev, yscale_dev, M, stream=stream) + launch_fn( + input_dev, + residual_dev, + gamma_dev, + beta_dev, + xscale_dev, + output_dev, + residual_out_dev, + yscale_dev, + M, + stream=stream, + ) elif is_fused_add: - launch_fn(input_dev, residual_dev, gamma_dev, beta_dev, output_dev, residual_out_dev, yscale_dev, M, stream=stream) + launch_fn( + input_dev, residual_dev, gamma_dev, beta_dev, output_dev, residual_out_dev, yscale_dev, M, stream=stream + ) elif is_smooth: launch_fn(input_dev, gamma_dev, beta_dev, xscale_dev, output_dev, yscale_dev, M, stream=stream) else: launch_fn(input_dev, gamma_dev, beta_dev, output_dev, yscale_dev, M, stream=stream) - _, avg_us = run_perftest(lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS) + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS + ) flydsl_gpu_us = None if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) @@ -302,14 +327,22 @@ def _run_configs(op: str, runner): aiter_us = None if maybe_enable_aiter(): aiter_us = _bench_aiter(M, N, dtype, op) - perf_rows.append(PerfRow(op=f"layernorm_{op}", shape=f"{M}x{N}", dtype=dtype, flydsl_gpu_us=flydsl_gpu_us, aiter_gpu_us=aiter_us)) + perf_rows.append( + PerfRow( + op=f"layernorm_{op}", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) if do_compare and perf_rows: print_perf_table(perf_rows) if failures != 0: raise SystemExit(f"{failures} {op} tests failed") -def test_all(): +def test_layernorm_base(): print("=" * 80) print("Running LayerNorm Tests") print("=" * 80) @@ -329,15 +362,19 @@ def test_layernorm_smoothquant(): def test_fused_add_layernorm_dynamicquant(): - _run_configs("fused_add_dynamicquant", lambda M, N, dtype: run_quant_test(M, N, dtype, is_smooth=False, is_fused_add=True)) + _run_configs( + "fused_add_dynamicquant", lambda M, N, dtype: run_quant_test(M, N, dtype, is_smooth=False, is_fused_add=True) + ) def test_fused_add_layernorm_smoothquant(): - _run_configs("fused_add_smoothquant", lambda M, N, dtype: run_quant_test(M, N, dtype, is_smooth=True, is_fused_add=True)) + _run_configs( + "fused_add_smoothquant", lambda M, N, dtype: run_quant_test(M, N, dtype, is_smooth=True, is_fused_add=True) + ) if __name__ == "__main__": - test_all() + test_layernorm_base() test_fused_add_layernorm() test_layernorm_dynamicquant() test_layernorm_smoothquant() From efc477c13c8712f5526d77100e4f1cee721eb892 Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Thu, 21 May 2026 18:56:23 +0800 Subject: [PATCH 4/5] Migrate layernorm variants to SharedAllocator --- kernels/layernorm_kernel.py | 105 ++++++++++++------------------------ 1 file changed, 35 insertions(+), 70 deletions(-) diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index 155510468..79e59209c 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -61,9 +61,9 @@ def layernorm_kernel( fm_fast = arith.FastMathFlags.fast eps_c = EPS - lds = fx.SharedAllocator().allocate(SharedStorage).peek() - s_sum = lds.s_sum.view(fx.make_layout(RED_SLOTS, 1)) - s_sumsq = lds.s_sumsq.view(fx.make_layout(RED_SLOTS, 1)) + lds = fx.SharedAllocator().allocate(SharedStorage) + s_sum = lds.s_sum.peek().view(fx.make_layout(RED_SLOTS, 1)) + s_sumsq = lds.s_sumsq.peek().view(fx.make_layout(RED_SLOTS, 1)) # ── helpers: wave / block reduction ─────────────────────────────── def wave_reduce_add(x): @@ -329,16 +329,13 @@ def _quant_dtype_max(dtype_str: str) -> float: def build_fused_add_layernorm_module(M: int, N: int, dtype_str: str): - arch = get_hip_arch() RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) elem_bits = 32 if dtype_str == "f32" else 16 - allocator = SmemAllocator(None, arch=arch) - f32_bytes = 4 - sum_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = sum_offset + RED_SLOTS * f32_bytes - sumsq_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = sumsq_offset + RED_SLOTS * f32_bytes + @fx.struct + class SharedStorage: + s_sum: fx.Array[fx.Float32, RED_SLOTS, 16] + s_sumsq: fx.Array[fx.Float32, RED_SLOTS, 16] @flyc.kernel def fused_add_layernorm_kernel( @@ -356,11 +353,9 @@ def fused_add_layernorm_kernel( fm_fast = arith.FastMathFlags.fast eps_c = EPS - base_ptr = allocator.get_base() - s_sum = SmemPtr(base_ptr, sum_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) - s_sumsq = SmemPtr(base_ptr, sumsq_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) - s_sum.get() - s_sumsq.get() + lds = fx.SharedAllocator().allocate(SharedStorage) + s_sum = lds.s_sum.peek().view(fx.make_layout(RED_SLOTS, 1)) + s_sumsq = lds.s_sumsq.peek().view(fx.make_layout(RED_SLOTS, 1)) def wave_reduce_add(x): w = x @@ -380,26 +375,26 @@ def block_reduce_add2(val0, val1): w1 = wave_reduce_add(val1) if lane == 0: - SmemPtr.store(s_sum, w0, [wave]) - SmemPtr.store(s_sumsq, w1, [wave]) + fx.memref_store(w0, s_sum, wave) + fx.memref_store(w1, s_sumsq, wave) gpu.barrier() if wave == 0: in_range = lane < RED_SLOTS lane_safe = in_range.select(lane, 0) - v0 = SmemPtr.load(s_sum, [lane_safe]) - v1 = SmemPtr.load(s_sumsq, [lane_safe]) + v0 = fx.memref_load(s_sum, lane_safe) + v1 = fx.memref_load(s_sumsq, lane_safe) ww0 = in_range.select(v0, 0.0) ww1 = in_range.select(v1, 0.0) ww0 = wave_reduce_add(ww0) ww1 = wave_reduce_add(ww1) if lane == 0: - SmemPtr.store(s_sum, ww0, [0]) - SmemPtr.store(s_sumsq, ww1, [0]) + fx.memref_store(ww0, s_sum, 0) + fx.memref_store(ww1, s_sumsq, 0) gpu.barrier() - return SmemPtr.load(s_sum, [0]), SmemPtr.load(s_sumsq, [0]) + return fx.memref_load(s_sum, 0), fx.memref_load(s_sumsq, 0) def compute_mean_rstd(sum_val, sumsq_val): inv_n = 1.0 / float(N) @@ -498,11 +493,6 @@ def launch_fused_add_layernorm( m_in: fx.Int32, stream: fx.Stream = fx.Stream(None), ): - allocator.finalized = False - ctx = CompilationContext.get_current() - with InsertionPoint(ctx.gpu_module_body): - allocator.finalize() - launcher = fused_add_layernorm_kernel(Input, ResidualIn, Gamma, Beta, Output, ResidualOut) launcher.launch( grid=(m_in, 1, 1), @@ -522,17 +512,14 @@ def _build_layernorm_quant_module( is_fused_add: bool, quant_dtype_str: str = "i8", ): - arch = get_hip_arch() RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) elem_bits = 32 if dtype_str == "f32" else 16 quant_dtype_max = _quant_dtype_max(quant_dtype_str) - allocator = SmemAllocator(None, arch=arch) - f32_bytes = 4 - sum_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = sum_offset + RED_SLOTS * f32_bytes - sumsq_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = sumsq_offset + RED_SLOTS * f32_bytes + @fx.struct + class SharedStorage: + s_sum: fx.Array[fx.Float32, RED_SLOTS, 16] + s_sumsq: fx.Array[fx.Float32, RED_SLOTS, 16] @flyc.kernel def layernorm_quant_kernel( @@ -559,11 +546,9 @@ def layernorm_quant_kernel( c_neg_inf = fx.Float32(float("-inf")) c_dtype_max = fx.Float32(quant_dtype_max) - base_ptr = allocator.get_base() - s_sum = SmemPtr(base_ptr, sum_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) - s_sumsq = SmemPtr(base_ptr, sumsq_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) - s_sum.get() - s_sumsq.get() + lds = fx.SharedAllocator().allocate(SharedStorage) + s_sum = lds.s_sum.peek().view(fx.make_layout(RED_SLOTS, 1)) + s_sumsq = lds.s_sumsq.peek().view(fx.make_layout(RED_SLOTS, 1)) YScale_buf = fx.rocdl.make_buffer_tensor(YScale) yscale_div = fx.logical_divide(YScale_buf, fx.make_layout(1, 1)) @@ -601,25 +586,25 @@ def block_reduce_add2(val0, val1): w1 = wave_reduce_add(val1) if lane == 0: - SmemPtr.store(s_sum, w0, [wave]) - SmemPtr.store(s_sumsq, w1, [wave]) + fx.memref_store(w0, s_sum, wave) + fx.memref_store(w1, s_sumsq, wave) gpu.barrier() if wave == 0: in_range = lane < RED_SLOTS lane_safe = in_range.select(lane, 0) - v0 = SmemPtr.load(s_sum, [lane_safe]) - v1 = SmemPtr.load(s_sumsq, [lane_safe]) + v0 = fx.memref_load(s_sum, lane_safe) + v1 = fx.memref_load(s_sumsq, lane_safe) ww0 = in_range.select(v0, c_zero_f) ww1 = in_range.select(v1, c_zero_f) ww0 = wave_reduce_add(ww0) ww1 = wave_reduce_add(ww1) if lane == 0: - SmemPtr.store(s_sum, ww0, [0]) - SmemPtr.store(s_sumsq, ww1, [0]) + fx.memref_store(ww0, s_sum, 0) + fx.memref_store(ww1, s_sumsq, 0) gpu.barrier() - return SmemPtr.load(s_sum, [0]), SmemPtr.load(s_sumsq, [0]) + return fx.memref_load(s_sum, 0), fx.memref_load(s_sumsq, 0) def block_reduce_max(val): if const_expr(RED_SLOTS == 1): @@ -629,20 +614,20 @@ def block_reduce_max(val): wave = tid // WARP_SIZE w = wave_reduce_max(val) if lane == 0: - SmemPtr.store(s_sum, w, [wave]) + fx.memref_store(w, s_sum, wave) gpu.barrier() if wave == 0: in_range = lane < RED_SLOTS lane_safe = in_range.select(lane, 0) - v = SmemPtr.load(s_sum, [lane_safe]) + v = fx.memref_load(s_sum, lane_safe) ww = in_range.select(v, c_neg_inf) ww = wave_reduce_max(ww) if lane == 0: - SmemPtr.store(s_sum, ww, [0]) + fx.memref_store(ww, s_sum, 0) gpu.barrier() - return SmemPtr.load(s_sum, [0]) + return fx.memref_load(s_sum, 0) Input_buf = fx.rocdl.make_buffer_tensor(Input) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) @@ -793,11 +778,6 @@ def launch_fused_add_layernorm_smoothquant( m_in: fx.Int32, stream: fx.Stream = fx.Stream(None), ): - allocator.finalized = False - ctx = CompilationContext.get_current() - with InsertionPoint(ctx.gpu_module_body): - allocator.finalize() - launcher = layernorm_quant_kernel(Input, ResidualIn, Gamma, Beta, XScale, YScale, Output, ResidualOut) launcher.launch( grid=(m_in, 1, 1), @@ -819,11 +799,6 @@ def launch_fused_add_layernorm_dynamicquant( m_in: fx.Int32, stream: fx.Stream = fx.Stream(None), ): - allocator.finalized = False - ctx = CompilationContext.get_current() - with InsertionPoint(ctx.gpu_module_body): - allocator.finalize() - launcher = layernorm_quant_kernel(Input, ResidualIn, Gamma, Beta, Gamma, YScale, Output, ResidualOut) launcher.launch( grid=(m_in, 1, 1), @@ -846,11 +821,6 @@ def launch_layernorm_smoothquant( m_in: fx.Int32, stream: fx.Stream = fx.Stream(None), ): - allocator.finalized = False - ctx = CompilationContext.get_current() - with InsertionPoint(ctx.gpu_module_body): - allocator.finalize() - launcher = layernorm_quant_kernel(Input, Input, Gamma, Beta, XScale, YScale, Output, Output) launcher.launch( grid=(m_in, 1, 1), @@ -870,11 +840,6 @@ def launch_layernorm_dynamicquant( m_in: fx.Int32, stream: fx.Stream = fx.Stream(None), ): - allocator.finalized = False - ctx = CompilationContext.get_current() - with InsertionPoint(ctx.gpu_module_body): - allocator.finalize() - launcher = layernorm_quant_kernel(Input, Input, Gamma, Beta, Gamma, YScale, Output, Output) launcher.launch( grid=(m_in, 1, 1), From cef8eef3b2819a0ace24666933b849a5aca82e94 Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Fri, 22 May 2026 13:36:03 +0800 Subject: [PATCH 5/5] Align layernorm kernels with struct-level SharedAllocator access --- kernels/layernorm_kernel.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index 79e59209c..f3790f64f 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -61,9 +61,9 @@ def layernorm_kernel( fm_fast = arith.FastMathFlags.fast eps_c = EPS - lds = fx.SharedAllocator().allocate(SharedStorage) - s_sum = lds.s_sum.peek().view(fx.make_layout(RED_SLOTS, 1)) - s_sumsq = lds.s_sumsq.peek().view(fx.make_layout(RED_SLOTS, 1)) + lds = fx.SharedAllocator().allocate(SharedStorage).peek() + s_sum = lds.s_sum.view(fx.make_layout(RED_SLOTS, 1)) + s_sumsq = lds.s_sumsq.view(fx.make_layout(RED_SLOTS, 1)) # ── helpers: wave / block reduction ─────────────────────────────── def wave_reduce_add(x): @@ -353,9 +353,9 @@ def fused_add_layernorm_kernel( fm_fast = arith.FastMathFlags.fast eps_c = EPS - lds = fx.SharedAllocator().allocate(SharedStorage) - s_sum = lds.s_sum.peek().view(fx.make_layout(RED_SLOTS, 1)) - s_sumsq = lds.s_sumsq.peek().view(fx.make_layout(RED_SLOTS, 1)) + lds = fx.SharedAllocator().allocate(SharedStorage).peek() + s_sum = lds.s_sum.view(fx.make_layout(RED_SLOTS, 1)) + s_sumsq = lds.s_sumsq.view(fx.make_layout(RED_SLOTS, 1)) def wave_reduce_add(x): w = x @@ -546,9 +546,9 @@ def layernorm_quant_kernel( c_neg_inf = fx.Float32(float("-inf")) c_dtype_max = fx.Float32(quant_dtype_max) - lds = fx.SharedAllocator().allocate(SharedStorage) - s_sum = lds.s_sum.peek().view(fx.make_layout(RED_SLOTS, 1)) - s_sumsq = lds.s_sumsq.peek().view(fx.make_layout(RED_SLOTS, 1)) + lds = fx.SharedAllocator().allocate(SharedStorage).peek() + s_sum = lds.s_sum.view(fx.make_layout(RED_SLOTS, 1)) + s_sumsq = lds.s_sumsq.view(fx.make_layout(RED_SLOTS, 1)) YScale_buf = fx.rocdl.make_buffer_tensor(YScale) yscale_div = fx.logical_divide(YScale_buf, fx.make_layout(1, 1))