diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index e53f4fb4..f3790f64 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -314,3 +314,601 @@ def launch_layernorm( ) return launch_layernorm + + +def _quant_dtype_to_elem_type(dtype_str: str): + if dtype_str in ("i8", "int8"): + return fx.Int8 + raise ValueError(f"unsupported quant dtype: {dtype_str!r} (expected 'i8' or 'int8')") + + +def _quant_dtype_max(dtype_str: str) -> float: + if dtype_str in ("i8", "int8"): + return 127.0 + raise ValueError(f"unsupported quant dtype: {dtype_str!r} (expected 'i8' or 'int8')") + + +def build_fused_add_layernorm_module(M: int, N: int, dtype_str: str): + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) + elem_bits = 32 if dtype_str == "f32" else 16 + + @fx.struct + class SharedStorage: + s_sum: fx.Array[fx.Float32, RED_SLOTS, 16] + s_sumsq: fx.Array[fx.Float32, RED_SLOTS, 16] + + @flyc.kernel + def fused_add_layernorm_kernel( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + elem_dtype = dtype_to_elem_type(dtype_str) + fm_fast = arith.FastMathFlags.fast + eps_c = EPS + + lds = fx.SharedAllocator().allocate(SharedStorage).peek() + s_sum = lds.s_sum.view(fx.make_layout(RED_SLOTS, 1)) + s_sumsq = lds.s_sumsq.view(fx.make_layout(RED_SLOTS, 1)) + + def wave_reduce_add(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) + w = w.addf(peer, fastmath=fm_fast) + return w + + def block_reduce_add2(val0, val1): + if const_expr(RED_SLOTS == 1): + return wave_reduce_add(val0), wave_reduce_add(val1) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + w0 = wave_reduce_add(val0) + w1 = wave_reduce_add(val1) + + if lane == 0: + fx.memref_store(w0, s_sum, wave) + fx.memref_store(w1, s_sumsq, wave) + gpu.barrier() + + if wave == 0: + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, 0) + v0 = fx.memref_load(s_sum, lane_safe) + v1 = fx.memref_load(s_sumsq, lane_safe) + ww0 = in_range.select(v0, 0.0) + ww1 = in_range.select(v1, 0.0) + ww0 = wave_reduce_add(ww0) + ww1 = wave_reduce_add(ww1) + + if lane == 0: + fx.memref_store(ww0, s_sum, 0) + fx.memref_store(ww1, s_sumsq, 0) + gpu.barrier() + + return fx.memref_load(s_sum, 0), fx.memref_load(s_sumsq, 0) + + def compute_mean_rstd(sum_val, sumsq_val): + inv_n = 1.0 / float(N) + mean = sum_val * inv_n + mean_sq = sumsq_val * inv_n + var = mean_sq - mean * mean + var = (var < 0.0).select(0.0, var) + return mean, fmath.rsqrt(var + eps_c, fastmath=fm_fast) + + Input_buf = fx.rocdl.make_buffer_tensor(Input) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Beta_buf = fx.rocdl.make_buffer_tensor(Beta) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + + row_in = fx.slice(Input_buf, (bid, None)) + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + + in_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + beta_div = fx.logical_divide(Beta_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) + + def _load_scalar(divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.make_rmem_tensor(1, elem_dtype) + fx.copy_atom_call(copy_atom_s, view, r) + return fx.memref_load_vec(r)[0] + + def _store_scalar(divided_tensor, index, val): + r = fx.make_rmem_tensor(1, elem_dtype) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom_s, r, view) + + c_zero_f = fx.Float32(0.0) + thread_sum = c_zero_f + thread_sumsq = c_zero_f + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(in_div, idx_safe) + r_e = _load_scalar(residual_in_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + residual = r_e if dtype_str == "f32" else r_e.to(fx.Float32) + added = x + residual + added_safe = is_valid.select(added, c_zero_f) + thread_sum = thread_sum + added_safe + thread_sumsq = thread_sumsq + is_valid.select(added * added, c_zero_f) + if idx < N: + _store_scalar(residual_out_div, idx, added) + + sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) + mean, rstd = compute_mean_rstd(sum_val, sumsq_val) + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + if idx < N: + x_e = _load_scalar(in_div, idx) + r_e = _load_scalar(residual_in_div, idx) + g_e = _load_scalar(gamma_div, idx) + b_e = _load_scalar(beta_div, idx) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + residual = r_e if dtype_str == "f32" else r_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + b = b_e if dtype_str == "f32" else b_e.to(fx.Float32) + y = ((x + residual) - mean) * rstd + y = y * g + b + if const_expr(dtype_str == "f32"): + y_e = y + else: + y_e = y.to(elem_dtype) + _store_scalar(out_div, idx, y_e) + + @flyc.jit + def launch_fused_add_layernorm( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + launcher = fused_add_layernorm_kernel(Input, ResidualIn, Gamma, Beta, Output, ResidualOut) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_fused_add_layernorm + + +def _build_layernorm_quant_module( + M: int, + N: int, + dtype_str: str, + *, + is_smooth: bool, + is_fused_add: bool, + quant_dtype_str: str = "i8", +): + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) + elem_bits = 32 if dtype_str == "f32" else 16 + quant_dtype_max = _quant_dtype_max(quant_dtype_str) + + @fx.struct + class SharedStorage: + s_sum: fx.Array[fx.Float32, RED_SLOTS, 16] + s_sumsq: fx.Array[fx.Float32, RED_SLOTS, 16] + + @flyc.kernel + def layernorm_quant_kernel( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + XScale: fx.Tensor, + YScale: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + elem_dtype = dtype_to_elem_type(dtype_str) + quant_dtype = _quant_dtype_to_elem_type(quant_dtype_str) + + fm_fast = arith.FastMathFlags.fast + eps_c = EPS + n_float = float(N) + c_zero_f = fx.Float32(0.0) + c_one_f = fx.Float32(1.0) + c_neg_inf = fx.Float32(float("-inf")) + c_dtype_max = fx.Float32(quant_dtype_max) + + lds = fx.SharedAllocator().allocate(SharedStorage).peek() + s_sum = lds.s_sum.view(fx.make_layout(RED_SLOTS, 1)) + s_sumsq = lds.s_sumsq.view(fx.make_layout(RED_SLOTS, 1)) + + YScale_buf = fx.rocdl.make_buffer_tensor(YScale) + yscale_div = fx.logical_divide(YScale_buf, fx.make_layout(1, 1)) + scale_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + + def _store_yscale(index, val): + r = fx.make_rmem_tensor(1, fx.Float32) + ts = full(1, fx.Float32(val), fx.Float32) + fx.memref_store_vec(ts, r) + fx.copy_atom_call(scale_copy_atom, r, fx.slice(yscale_div, (None, index))) + + def wave_reduce_add(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) + w = w.addf(peer, fastmath=fm_fast) + return w + + def wave_reduce_max(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) + w = w.maximumf(peer) + return w + + def block_reduce_add2(val0, val1): + if const_expr(RED_SLOTS == 1): + return wave_reduce_add(val0), wave_reduce_add(val1) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + w0 = wave_reduce_add(val0) + w1 = wave_reduce_add(val1) + + if lane == 0: + fx.memref_store(w0, s_sum, wave) + fx.memref_store(w1, s_sumsq, wave) + gpu.barrier() + + if wave == 0: + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, 0) + v0 = fx.memref_load(s_sum, lane_safe) + v1 = fx.memref_load(s_sumsq, lane_safe) + ww0 = in_range.select(v0, c_zero_f) + ww1 = in_range.select(v1, c_zero_f) + ww0 = wave_reduce_add(ww0) + ww1 = wave_reduce_add(ww1) + if lane == 0: + fx.memref_store(ww0, s_sum, 0) + fx.memref_store(ww1, s_sumsq, 0) + gpu.barrier() + + return fx.memref_load(s_sum, 0), fx.memref_load(s_sumsq, 0) + + def block_reduce_max(val): + if const_expr(RED_SLOTS == 1): + return wave_reduce_max(val) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + w = wave_reduce_max(val) + if lane == 0: + fx.memref_store(w, s_sum, wave) + gpu.barrier() + + if wave == 0: + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, 0) + v = fx.memref_load(s_sum, lane_safe) + ww = in_range.select(v, c_neg_inf) + ww = wave_reduce_max(ww) + if lane == 0: + fx.memref_store(ww, s_sum, 0) + gpu.barrier() + + return fx.memref_load(s_sum, 0) + + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Beta_buf = fx.rocdl.make_buffer_tensor(Beta) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + if const_expr(is_fused_add): + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + if const_expr(is_smooth): + XScale_buf = fx.rocdl.make_buffer_tensor(XScale) + + row_in = fx.slice(Input_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + if const_expr(is_fused_add): + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + copy_atom_qs = fx.make_copy_atom(fx.rocdl.BufferCopy(8), 8) + + in_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + beta_div = fx.logical_divide(Beta_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + if const_expr(is_fused_add): + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(1, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) + if const_expr(is_smooth): + xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(1, 1)) + + def _load_scalar(divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.make_rmem_tensor(1, elem_dtype) + fx.copy_atom_call(copy_atom_s, view, r) + return fx.memref_load_vec(r)[0] + + def _store_elem_scalar(divided_tensor, index, val): + r = fx.make_rmem_tensor(1, elem_dtype) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom_s, r, view) + + def _store_quant_scalar(divided_tensor, index, val): + r = fx.make_rmem_tensor(1, quant_dtype) + ts = full(1, quant_dtype(val), quant_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom_qs, r, view) + + def _abs_scalar(val): + is_neg = val < c_zero_f + neg_val = c_zero_f - val + return is_neg.select(neg_val, val) + + def _load_input_value(index): + x_e = _load_scalar(in_div, index) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + if const_expr(is_fused_add): + r_e = _load_scalar(residual_in_div, index) + residual = r_e if dtype_str == "f32" else r_e.to(fx.Float32) + return x + residual + return x + + thread_sum = c_zero_f + thread_sumsq = c_zero_f + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x = _load_input_value(idx_safe) + x2 = x * x + thread_sum = thread_sum + is_valid.select(x, c_zero_f) + thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f) + if const_expr(is_fused_add): + if idx < N: + _store_elem_scalar(residual_out_div, idx, x) + + sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) + mean = sum_val / n_float + var = sumsq_val / n_float - mean * mean + var = (var < c_zero_f).select(c_zero_f, var) + rstd = (var + eps_c).rsqrt(fastmath=fm_fast) + + thread_row_max = c_zero_f + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x = _load_input_value(idx_safe) + g_e = _load_scalar(gamma_div, idx_safe) + b_e = _load_scalar(beta_div, idx_safe) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + b = b_e if dtype_str == "f32" else b_e.to(fx.Float32) + y = (x - mean) * rstd + y = y * g + b + if const_expr(is_smooth): + s_e = _load_scalar(xscale_div, idx_safe) + s = s_e if dtype_str == "f32" else s_e.to(fx.Float32) + y = y * s + y_abs = _abs_scalar(y) + thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) + + row_max = block_reduce_max(thread_row_max) + scale = row_max / c_dtype_max + final_scale = (scale == c_zero_f).select(c_one_f, scale) + + if tid == 0: + _store_yscale(bid, final_scale) + + inv_scale = c_one_f / final_scale + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + if idx < N: + x = _load_input_value(idx) + g_e = _load_scalar(gamma_div, idx) + b_e = _load_scalar(beta_div, idx) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + b = b_e if dtype_str == "f32" else b_e.to(fx.Float32) + y = (x - mean) * rstd + y = y * g + b + if const_expr(is_smooth): + s_e = _load_scalar(xscale_div, idx) + s = s_e if dtype_str == "f32" else s_e.to(fx.Float32) + y = y * s + q = y * inv_scale + q_i8 = q.to(quant_dtype) + _store_quant_scalar(out_div, idx, q_i8) + + if is_fused_add: + if is_smooth: + + @flyc.jit + def launch_fused_add_layernorm_smoothquant( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + XScale: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + launcher = layernorm_quant_kernel(Input, ResidualIn, Gamma, Beta, XScale, YScale, Output, ResidualOut) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_fused_add_layernorm_smoothquant + + @flyc.jit + def launch_fused_add_layernorm_dynamicquant( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + launcher = layernorm_quant_kernel(Input, ResidualIn, Gamma, Beta, Gamma, YScale, Output, ResidualOut) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_fused_add_layernorm_dynamicquant + + if is_smooth: + + @flyc.jit + def launch_layernorm_smoothquant( + Input: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + XScale: fx.Tensor, + Output: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + launcher = layernorm_quant_kernel(Input, Input, Gamma, Beta, XScale, YScale, Output, Output) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_layernorm_smoothquant + + @flyc.jit + def launch_layernorm_dynamicquant( + Input: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + Output: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + launcher = layernorm_quant_kernel(Input, Input, Gamma, Beta, Gamma, YScale, Output, Output) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_layernorm_dynamicquant + + +def build_layernorm_dynamicquant_module( + M: int, + N: int, + dtype_str: str, + quant_dtype_str: str = "i8", +): + return _build_layernorm_quant_module( + M, + N, + dtype_str, + is_smooth=False, + is_fused_add=False, + quant_dtype_str=quant_dtype_str, + ) + + +def build_layernorm_smoothquant_module( + M: int, + N: int, + dtype_str: str, + quant_dtype_str: str = "i8", +): + return _build_layernorm_quant_module( + M, + N, + dtype_str, + is_smooth=True, + is_fused_add=False, + quant_dtype_str=quant_dtype_str, + ) + + +def build_fused_add_layernorm_dynamicquant_module( + M: int, + N: int, + dtype_str: str, + quant_dtype_str: str = "i8", +): + return _build_layernorm_quant_module( + M, + N, + dtype_str, + is_smooth=False, + is_fused_add=True, + quant_dtype_str=quant_dtype_str, + ) + + +def build_fused_add_layernorm_smoothquant_module( + M: int, + N: int, + dtype_str: str, + quant_dtype_str: str = "i8", +): + return _build_layernorm_quant_module( + M, + N, + dtype_str, + is_smooth=True, + is_fused_add=True, + quant_dtype_str=quant_dtype_str, + ) diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 40ee0d21..1671dfe8 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -19,6 +19,11 @@ from kernels.layernorm_kernel import ( build_layernorm_module, + build_fused_add_layernorm_module, + build_layernorm_dynamicquant_module, + build_layernorm_smoothquant_module, + build_fused_add_layernorm_dynamicquant_module, + build_fused_add_layernorm_smoothquant_module, ) from tests.kernels.benchmark_common import ( PerfRow, @@ -40,6 +45,7 @@ DTYPE_FP32 = torch.float32 DTYPE_FP16 = torch.float16 DTYPE_BF16 = torch.bfloat16 +DTYPE_INT8 = torch.int8 EPS: float = 1e-5 @@ -146,7 +152,7 @@ def kernel_launch(): return ok, flydsl_gpu_us -def test_all(): +def test_layernorm(): print("=" * 80) print("Running LayerNorm Tests") print("=" * 80) @@ -223,5 +229,1049 @@ def run_aiter(): raise SystemExit(1) +def run_fused_add_test(M: int, N: int, dtype: str = "f32"): + print(f"\nTesting FusedAdd LayerNorm (M={M}, N={N}, dtype={dtype})") + + try: + launch_fn = build_fused_add_layernorm_module(M, N, dtype) + except Exception as e: + print(f"[FAIL] Compile failed for fused_add layernorm (M={M}, N={N}, dtype={dtype}): {type(e).__name__}: {e}") + return False, None + + torch.manual_seed(42) + input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + residual_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + beta_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + + if dtype == "f32": + input_dev = input_t.contiguous() + residual_dev = residual_t.contiguous() + gamma_dev = gamma_t.contiguous() + beta_dev = beta_t.contiguous() + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) + input_ref = input_dev.to(DTYPE_FP32) + residual_ref = residual_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + atol = 1e-4 + elif dtype == "f16": + input_dev = input_t.to(DTYPE_FP16).contiguous() + residual_dev = residual_t.to(DTYPE_FP16).contiguous() + gamma_dev = gamma_t.to(DTYPE_FP16).contiguous() + beta_dev = beta_t.to(DTYPE_FP16).contiguous() + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) + input_ref = input_dev.to(DTYPE_FP32) + residual_ref = residual_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + atol = 1e-2 + elif dtype == "bf16": + input_dev = input_t.to(DTYPE_BF16).contiguous() + residual_dev = residual_t.to(DTYPE_BF16).contiguous() + gamma_dev = gamma_t.to(DTYPE_BF16).contiguous() + beta_dev = beta_t.to(DTYPE_BF16).contiguous() + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) + input_ref = input_dev.to(DTYPE_FP32) + residual_ref = residual_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + atol = 2e-2 + else: + raise ValueError(f"unsupported dtype: {dtype}") + + residual_expected = input_ref + residual_ref + x = residual_expected + gamma = gamma_ref + beta = beta_ref + mean = x.mean(dim=1, keepdim=True) + var = x.var(dim=1, keepdim=True, unbiased=False) + expected = (x - mean) / torch.sqrt(var + EPS) * gamma + beta + expected = expected.to(DTYPE_FP32) + + print("Launching kernel...") + stream = torch.cuda.current_stream() + + def kernel_launch(): + launch_fn(input_dev, residual_dev, gamma_dev, beta_dev, output_dev, residual_out_dev, M, stream=stream) + + kernel_launch() + torch.cuda.synchronize() + + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS + ) + torch.cuda.synchronize() + flydsl_gpu_us = None + if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": + flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + avg_ms = avg_us / 1000.0 + elem_bytes = 4 if dtype == "f32" else 2 + total_bytes = (4 * M * N + 2 * N) * elem_bytes + bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9 + print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") + print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") + if flydsl_gpu_us is not None: + print(f"[Perf] FlyDSL fused_add layernorm gpu: {flydsl_gpu_us:.1f} us") + + output_ref = output_dev.to(DTYPE_FP32) + residual_out_ref = residual_out_dev.to(DTYPE_FP32) + + output_error = (output_ref - expected).abs().max().item() + residual_error = (residual_out_ref - residual_expected).abs().max().item() + print(f"Max output error: {output_error:.2e} (atol={atol})") + print(f"Max residual error: {residual_error:.2e} (atol={atol})") + + if output_error < atol and residual_error < atol: + print("PASSED") + ok = True + else: + print("FAILED") + print("First row Expected:") + print(expected[0, :5]) + print("First row Actual:") + print(output_ref[0, :5]) + print("First row Residual Expected:") + print(residual_expected[0, :5]) + print("First row Residual Actual:") + print(residual_out_ref[0, :5]) + ok = False + + return ok, flydsl_gpu_us + + +def test_fused_add_layernorm(): + print("=" * 80) + print("Running FusedAdd LayerNorm Tests") + print("=" * 80) + + shapes_env = os.environ.get("ROCDSL_LAYERNORM_SHAPES", "").strip() + if shapes_env: + configs = [] + for part in shapes_env.split(";"): + p = part.strip() + if not p: + continue + m_s, n_s, dt = [x.strip() for x in p.split(",")] + configs.append((int(m_s), int(n_s), dt)) + else: + configs = [ + # (64, 256, "f32"), # Aligned + # (128, 1024, "f32"), # Aligned + # (32, 128, "f16"), # Aligned + # (64, 2000, "f32"), # Unaligned (tail handling) + # (16, 512, "bf16"), # BF16 + # (1024, 8192, "bf16"), # BF16 + (32768, 8192, "bf16"), + ] + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] + failures = 0 + + for M, N, dtype in configs: + ok, flydsl_gpu_us = run_fused_add_test(M, N, dtype) + if not ok: + failures += 1 + + if do_compare: + import torch + + aiter_us = None + if maybe_enable_aiter(): + try: + from aiter.ops.triton.normalization.norm import layernorm2d_fwd_with_add + + torch_dtype = DTYPE_BF16 if dtype == "bf16" else (DTYPE_FP16 if dtype == "f16" else DTYPE_FP32) + x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + residual = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + residual_out = torch.empty_like(x) + out = torch.empty_like(x) + w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + b = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + + def run_aiter(): + layernorm2d_fwd_with_add(out, x, residual, residual_out, w, b, EPS) + + aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f"[Perf] AIter fused_add layernorm gpu: {aiter_us:.1f} us") + except Exception as e: + print(f"[Perf] AIter fused_add layernorm skipped: {type(e).__name__}: {e!r}") + + perf_rows.append( + PerfRow( + op="layernorm_fused_add", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) + + print("\n" + "=" * 80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("=" * 80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + if failures != 0: + raise SystemExit(1) + + +def run_dynamicquant_test(M: int, N: int, dtype: str = "f32"): + print(f"\nTesting LayerNorm DynamicQuant (M={M}, N={N}, dtype={dtype})") + + try: + launch_fn = build_layernorm_dynamicquant_module(M, N, dtype) + except Exception as e: + print(f"[FAIL] Compile failed for dynamicquant layernorm (M={M}, N={N}, dtype={dtype}): {type(e).__name__}: {e}") + return False, None + + torch.manual_seed(42) + input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + beta_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + + if dtype == "f32": + input_dev = input_t.contiguous() + gamma_dev = gamma_t.contiguous() + beta_dev = beta_t.contiguous() + input_ref = input_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + elif dtype == "f16": + input_dev = input_t.to(DTYPE_FP16).contiguous() + gamma_dev = gamma_t.to(DTYPE_FP16).contiguous() + beta_dev = beta_t.to(DTYPE_FP16).contiguous() + input_ref = input_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + elif dtype == "bf16": + input_dev = input_t.to(DTYPE_BF16).contiguous() + gamma_dev = gamma_t.to(DTYPE_BF16).contiguous() + beta_dev = beta_t.to(DTYPE_BF16).contiguous() + input_ref = input_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + else: + raise ValueError(f"unsupported dtype: {dtype}") + + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_INT8) + yscale_dev = torch.empty((M,), device="cuda", dtype=DTYPE_FP32) + + x = input_ref + gamma = gamma_ref + beta = beta_ref + mean = x.mean(dim=1, keepdim=True) + var = x.var(dim=1, keepdim=True, unbiased=False) + expected = (x - mean) / torch.sqrt(var + EPS) * gamma + beta + yscale_expected = expected.abs().amax(dim=1) / 127.0 + yscale_expected = torch.where(yscale_expected == 0, torch.ones_like(yscale_expected), yscale_expected) + q_expected = torch.clamp(torch.trunc(expected / yscale_expected.unsqueeze(1)), -127, 127).to(DTYPE_INT8) + + print("Launching kernel...") + stream = torch.cuda.current_stream() + + def kernel_launch(): + launch_fn(input_dev, gamma_dev, beta_dev, output_dev, yscale_dev, M, stream=stream) + + kernel_launch() + torch.cuda.synchronize() + + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS + ) + torch.cuda.synchronize() + flydsl_gpu_us = None + if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": + flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + avg_ms = avg_us / 1000.0 + elem_bytes = 4 if dtype == "f32" else 2 + total_bytes = (M * N + 2 * N) * elem_bytes + M * N + M * 4 + bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9 + print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") + print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") + if flydsl_gpu_us is not None: + print(f"[Perf] FlyDSL layernorm dynamicquant gpu: {flydsl_gpu_us:.1f} us") + + output_ref = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1) + q_out = output_dev.to(torch.int16) + q_ref = q_expected.to(torch.int16) + yscale_out = yscale_dev.cpu() + yscale_ref = yscale_expected.cpu() + + recon_error = (output_ref - expected).abs().max().item() + scale_diff = (yscale_out - yscale_ref).abs().max().item() + quant_diff = (q_out - q_ref).abs().max().item() + + print(f"Max recon error: {recon_error:.2e} (tol=0.3)") + print(f"Max scale diff: {scale_diff:.2e} (tol=1e-2)") + print(f"Max quant diff: {quant_diff}") + + if recon_error < 0.3 and scale_diff < 1e-2 and quant_diff <= 1: + print("PASSED") + ok = True + else: + print("FAILED") + print("First row Expected:") + print(expected[0, :5]) + print("First row Actual:") + print(output_ref[0, :5]) + print("First row Quant Expected:") + print(q_ref[0, :8]) + print("First row Quant Actual:") + print(q_out[0, :8]) + print("First few YScale Expected:") + print(yscale_ref[:5]) + print("First few YScale Actual:") + print(yscale_out[:5]) + ok = False + + return ok, flydsl_gpu_us + + +def test_layernorm_dynamicquant(): + print("=" * 80) + print("Running LayerNorm DynamicQuant Tests") + print("=" * 80) + + shapes_env = os.environ.get("ROCDSL_LAYERNORM_SHAPES", "").strip() + if shapes_env: + configs = [] + for part in shapes_env.split(";"): + p = part.strip() + if not p: + continue + m_s, n_s, dt = [x.strip() for x in p.split(",")] + configs.append((int(m_s), int(n_s), dt)) + else: + configs = [ + # (64, 256, "f32"), # Aligned + # (128, 1024, "f32"), # Aligned + # (32, 128, "f16"), # Aligned + # (64, 2000, "f32"), # Unaligned (tail handling) + # (16, 512, "bf16"), # BF16 + # (1024, 8192, "bf16"), # BF16 + (32768, 8192, "bf16"), + ] + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] + failures = 0 + + for M, N, dtype in configs: + ok, flydsl_gpu_us = run_dynamicquant_test(M, N, dtype) + if not ok: + failures += 1 + + if do_compare: + import torch + + aiter_us = None + if maybe_enable_aiter(): + try: + from aiter.ops.triton.normalization.norm import layernorm2d_fwd_with_dynamicquant + + torch_dtype = DTYPE_BF16 if dtype == "bf16" else (DTYPE_FP16 if dtype == "f16" else DTYPE_FP32) + x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + b = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + q_out = torch.empty((M, N), device="cuda", dtype=DTYPE_INT8) + yscale = torch.empty((M, 1), device="cuda", dtype=DTYPE_FP32) + + def run_aiter(): + layernorm2d_fwd_with_dynamicquant(q_out, x, yscale, w, b, EPS) + + aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f"[Perf] AIter layernorm dynamicquant gpu: {aiter_us:.1f} us") + except Exception as e: + print(f"[Perf] AIter layernorm dynamicquant skipped: {type(e).__name__}: {e!r}") + + perf_rows.append( + PerfRow( + op="layernorm_dynamicquant", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) + + print("\n" + "=" * 80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("=" * 80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + if failures != 0: + raise SystemExit(1) + + +def run_smoothquant_test(M: int, N: int, dtype: str = "f32"): + print(f"\nTesting LayerNorm SmoothQuant (M={M}, N={N}, dtype={dtype})") + + try: + launch_fn = build_layernorm_smoothquant_module(M, N, dtype) + except Exception as e: + print(f"[FAIL] Compile failed for smoothquant layernorm (M={M}, N={N}, dtype={dtype}): {type(e).__name__}: {e}") + return False, None + + torch.manual_seed(42) + input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + beta_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + xscale_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5 + + if dtype == "f32": + input_dev = input_t.contiguous() + gamma_dev = gamma_t.contiguous() + beta_dev = beta_t.contiguous() + xscale_dev = xscale_t.contiguous() + input_ref = input_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + xscale_ref = xscale_dev.to(DTYPE_FP32) + elif dtype == "f16": + input_dev = input_t.to(DTYPE_FP16).contiguous() + gamma_dev = gamma_t.to(DTYPE_FP16).contiguous() + beta_dev = beta_t.to(DTYPE_FP16).contiguous() + xscale_dev = xscale_t.to(DTYPE_FP16).contiguous() + input_ref = input_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + xscale_ref = xscale_dev.to(DTYPE_FP32) + elif dtype == "bf16": + input_dev = input_t.to(DTYPE_BF16).contiguous() + gamma_dev = gamma_t.to(DTYPE_BF16).contiguous() + beta_dev = beta_t.to(DTYPE_BF16).contiguous() + xscale_dev = xscale_t.to(DTYPE_BF16).contiguous() + input_ref = input_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + xscale_ref = xscale_dev.to(DTYPE_FP32) + else: + raise ValueError(f"unsupported dtype: {dtype}") + + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_INT8) + yscale_dev = torch.empty((M,), device="cuda", dtype=DTYPE_FP32) + + x = input_ref + gamma = gamma_ref + beta = beta_ref + mean = x.mean(dim=1, keepdim=True) + var = x.var(dim=1, keepdim=True, unbiased=False) + expected = (x - mean) / torch.sqrt(var + EPS) * gamma + beta + expected = expected * xscale_ref + yscale_expected = expected.abs().amax(dim=1) / 127.0 + yscale_expected = torch.where(yscale_expected == 0, torch.ones_like(yscale_expected), yscale_expected) + q_expected = torch.clamp(torch.trunc(expected / yscale_expected.unsqueeze(1)), -127, 127).to(DTYPE_INT8) + + print("Launching kernel...") + stream = torch.cuda.current_stream() + + def kernel_launch(): + launch_fn(input_dev, gamma_dev, beta_dev, xscale_dev, output_dev, yscale_dev, M, stream=stream) + + kernel_launch() + torch.cuda.synchronize() + + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS + ) + torch.cuda.synchronize() + flydsl_gpu_us = None + if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": + flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + avg_ms = avg_us / 1000.0 + elem_bytes = 4 if dtype == "f32" else 2 + total_bytes = (M * N + 3 * N) * elem_bytes + M * N + M * 4 + bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9 + print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") + print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") + if flydsl_gpu_us is not None: + print(f"[Perf] FlyDSL layernorm smoothquant gpu: {flydsl_gpu_us:.1f} us") + + output_ref = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1) + q_out = output_dev.to(torch.int16) + q_ref = q_expected.to(torch.int16) + yscale_out = yscale_dev.cpu() + yscale_ref = yscale_expected.cpu() + + recon_error = (output_ref - expected).abs().max().item() + scale_diff = (yscale_out - yscale_ref).abs().max().item() + quant_diff = (q_out - q_ref).abs().max().item() + + print(f"Max recon error: {recon_error:.2e} (tol=0.3)") + print(f"Max scale diff: {scale_diff:.2e} (tol=1e-2)") + print(f"Max quant diff: {quant_diff}") + + if recon_error < 0.3 and scale_diff < 1e-2 and quant_diff <= 1: + print("PASSED") + ok = True + else: + print("FAILED") + print("First row Expected:") + print(expected[0, :5]) + print("First row Actual:") + print(output_ref[0, :5]) + print("First row Quant Expected:") + print(q_ref[0, :8]) + print("First row Quant Actual:") + print(q_out[0, :8]) + print("First few YScale Expected:") + print(yscale_ref[:5]) + print("First few YScale Actual:") + print(yscale_out[:5]) + ok = False + + return ok, flydsl_gpu_us + + +def test_layernorm_smoothquant(): + print("=" * 80) + print("Running LayerNorm SmoothQuant Tests") + print("=" * 80) + + shapes_env = os.environ.get("ROCDSL_LAYERNORM_SHAPES", "").strip() + if shapes_env: + configs = [] + for part in shapes_env.split(";"): + p = part.strip() + if not p: + continue + m_s, n_s, dt = [x.strip() for x in p.split(",")] + configs.append((int(m_s), int(n_s), dt)) + else: + configs = [ + # (64, 256, "f32"), # Aligned + # (128, 1024, "f32"), # Aligned + # (32, 128, "f16"), # Aligned + # (64, 2000, "f32"), # Unaligned (tail handling) + # (16, 512, "bf16"), # BF16 + # (1024, 8192, "bf16"), # BF16 + (32768, 8192, "bf16"), + ] + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] + failures = 0 + + for M, N, dtype in configs: + ok, flydsl_gpu_us = run_smoothquant_test(M, N, dtype) + if not ok: + failures += 1 + + if do_compare: + import torch + + aiter_us = None + if maybe_enable_aiter(): + try: + from aiter.ops.triton.normalization.norm import layernorm2d_fwd_with_smoothquant + + torch_dtype = DTYPE_BF16 if dtype == "bf16" else (DTYPE_FP16 if dtype == "f16" else DTYPE_FP32) + x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + b = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + xscale = (torch.rand((N,), device="cuda", dtype=torch_dtype) + 0.5).contiguous() + q_out = torch.empty((M, N), device="cuda", dtype=DTYPE_INT8) + yscale = torch.empty((M, 1), device="cuda", dtype=DTYPE_FP32) + + def run_aiter(): + layernorm2d_fwd_with_smoothquant(q_out, x, xscale, yscale, w, b, EPS) + + aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f"[Perf] AIter layernorm smoothquant gpu: {aiter_us:.1f} us") + except Exception as e: + print(f"[Perf] AIter layernorm smoothquant skipped: {type(e).__name__}: {e!r}") + + perf_rows.append( + PerfRow( + op="layernorm_smoothquant", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) + + print("\n" + "=" * 80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("=" * 80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + if failures != 0: + raise SystemExit(1) + + +def run_fused_add_dynamicquant_test(M: int, N: int, dtype: str = "f32"): + print(f"\nTesting FusedAdd LayerNorm DynamicQuant (M={M}, N={N}, dtype={dtype})") + + try: + launch_fn = build_fused_add_layernorm_dynamicquant_module(M, N, dtype) + except Exception as e: + print( + f"[FAIL] Compile failed for fused_add dynamicquant layernorm " + f"(M={M}, N={N}, dtype={dtype}): {type(e).__name__}: {e}" + ) + return False, None + + torch.manual_seed(42) + input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + residual_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + beta_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + + if dtype == "f32": + input_dev = input_t.contiguous() + residual_dev = residual_t.contiguous() + gamma_dev = gamma_t.contiguous() + beta_dev = beta_t.contiguous() + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) + input_ref = input_dev.to(DTYPE_FP32) + residual_ref = residual_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + residual_atol = 1e-4 + elif dtype == "f16": + input_dev = input_t.to(DTYPE_FP16).contiguous() + residual_dev = residual_t.to(DTYPE_FP16).contiguous() + gamma_dev = gamma_t.to(DTYPE_FP16).contiguous() + beta_dev = beta_t.to(DTYPE_FP16).contiguous() + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) + input_ref = input_dev.to(DTYPE_FP32) + residual_ref = residual_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + residual_atol = 1e-2 + elif dtype == "bf16": + input_dev = input_t.to(DTYPE_BF16).contiguous() + residual_dev = residual_t.to(DTYPE_BF16).contiguous() + gamma_dev = gamma_t.to(DTYPE_BF16).contiguous() + beta_dev = beta_t.to(DTYPE_BF16).contiguous() + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) + input_ref = input_dev.to(DTYPE_FP32) + residual_ref = residual_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + residual_atol = 2e-2 + else: + raise ValueError(f"unsupported dtype: {dtype}") + + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_INT8) + yscale_dev = torch.empty((M,), device="cuda", dtype=DTYPE_FP32) + + residual_expected = input_ref + residual_ref + x = residual_expected + gamma = gamma_ref + beta = beta_ref + mean = x.mean(dim=1, keepdim=True) + var = x.var(dim=1, keepdim=True, unbiased=False) + expected = (x - mean) / torch.sqrt(var + EPS) * gamma + beta + yscale_expected = expected.abs().amax(dim=1) / 127.0 + yscale_expected = torch.where(yscale_expected == 0, torch.ones_like(yscale_expected), yscale_expected) + q_expected = torch.clamp(torch.trunc(expected / yscale_expected.unsqueeze(1)), -127, 127).to(DTYPE_INT8) + + print("Launching kernel...") + stream = torch.cuda.current_stream() + + def kernel_launch(): + launch_fn(input_dev, residual_dev, gamma_dev, beta_dev, output_dev, residual_out_dev, yscale_dev, M, stream=stream) + + kernel_launch() + torch.cuda.synchronize() + + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS + ) + torch.cuda.synchronize() + flydsl_gpu_us = None + if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": + flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + avg_ms = avg_us / 1000.0 + elem_bytes = 4 if dtype == "f32" else 2 + total_bytes = (3 * M * N + 2 * N) * elem_bytes + M * N + M * 4 + bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9 + print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") + print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") + if flydsl_gpu_us is not None: + print(f"[Perf] FlyDSL fused_add layernorm dynamicquant gpu: {flydsl_gpu_us:.1f} us") + + residual_out_ref = residual_out_dev.to(DTYPE_FP32) + output_ref = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1) + q_out = output_dev.to(torch.int16) + q_ref = q_expected.to(torch.int16) + yscale_out = yscale_dev.cpu() + yscale_ref = yscale_expected.cpu() + + residual_error = (residual_out_ref - residual_expected).abs().max().item() + recon_error = (output_ref - expected).abs().max().item() + scale_diff = (yscale_out - yscale_ref).abs().max().item() + quant_diff = (q_out - q_ref).abs().max().item() + + print(f"Max residual error: {residual_error:.2e} (atol={residual_atol})") + print(f"Max recon error: {recon_error:.2e} (tol=0.3)") + print(f"Max scale diff: {scale_diff:.2e} (tol=1e-2)") + print(f"Max quant diff: {quant_diff}") + + if residual_error < residual_atol and recon_error < 0.3 and scale_diff < 1e-2 and quant_diff <= 1: + print("PASSED") + ok = True + else: + print("FAILED") + print("First row Residual Expected:") + print(residual_expected[0, :5]) + print("First row Residual Actual:") + print(residual_out_ref[0, :5]) + print("First row Expected:") + print(expected[0, :5]) + print("First row Actual:") + print(output_ref[0, :5]) + print("First row Quant Expected:") + print(q_ref[0, :8]) + print("First row Quant Actual:") + print(q_out[0, :8]) + print("First few YScale Expected:") + print(yscale_ref[:5]) + print("First few YScale Actual:") + print(yscale_out[:5]) + ok = False + + return ok, flydsl_gpu_us + + +def test_fused_add_layernorm_dynamicquant(): + print("=" * 80) + print("Running FusedAdd LayerNorm DynamicQuant Tests") + print("=" * 80) + + shapes_env = os.environ.get("ROCDSL_LAYERNORM_SHAPES", "").strip() + if shapes_env: + configs = [] + for part in shapes_env.split(";"): + p = part.strip() + if not p: + continue + m_s, n_s, dt = [x.strip() for x in p.split(",")] + configs.append((int(m_s), int(n_s), dt)) + else: + configs = [ + # (64, 256, "f32"), # Aligned + # (128, 1024, "f32"), # Aligned + # (32, 128, "f16"), # Aligned + # (64, 2000, "f32"), # Unaligned (tail handling) + # (16, 512, "bf16"), # BF16 + # (1024, 8192, "bf16"), # BF16 + (32768, 8192, "bf16"), + ] + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] + failures = 0 + + for M, N, dtype in configs: + ok, flydsl_gpu_us = run_fused_add_dynamicquant_test(M, N, dtype) + if not ok: + failures += 1 + + if do_compare: + import torch + + aiter_us = None + if maybe_enable_aiter(): + try: + from aiter.ops.triton.normalization.norm import layernorm2d_fwd_with_add_dynamicquant + + torch_dtype = DTYPE_BF16 if dtype == "bf16" else (DTYPE_FP16 if dtype == "f16" else DTYPE_FP32) + x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + residual = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + residual_out = torch.empty_like(x) + w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + b = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + q_out = torch.empty((M, N), device="cuda", dtype=DTYPE_INT8) + yscale = torch.empty((M, 1), device="cuda", dtype=DTYPE_FP32) + + def run_aiter(): + layernorm2d_fwd_with_add_dynamicquant(q_out, x, residual, residual_out, yscale, w, b, EPS) + + aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f"[Perf] AIter fused_add layernorm dynamicquant gpu: {aiter_us:.1f} us") + except Exception as e: + print(f"[Perf] AIter fused_add layernorm dynamicquant skipped: {type(e).__name__}: {e!r}") + + perf_rows.append( + PerfRow( + op="layernorm_fused_add_dynamicquant", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) + + print("\n" + "=" * 80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("=" * 80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + if failures != 0: + raise SystemExit(1) + + +def run_fused_add_smoothquant_test(M: int, N: int, dtype: str = "f32"): + print(f"\nTesting FusedAdd LayerNorm SmoothQuant (M={M}, N={N}, dtype={dtype})") + + try: + launch_fn = build_fused_add_layernorm_smoothquant_module(M, N, dtype) + except Exception as e: + print( + f"[FAIL] Compile failed for fused_add smoothquant layernorm " + f"(M={M}, N={N}, dtype={dtype}): {type(e).__name__}: {e}" + ) + return False, None + + torch.manual_seed(42) + input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + residual_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + beta_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + xscale_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5 + + if dtype == "f32": + input_dev = input_t.contiguous() + residual_dev = residual_t.contiguous() + gamma_dev = gamma_t.contiguous() + beta_dev = beta_t.contiguous() + xscale_dev = xscale_t.contiguous() + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) + input_ref = input_dev.to(DTYPE_FP32) + residual_ref = residual_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + xscale_ref = xscale_dev.to(DTYPE_FP32) + residual_atol = 1e-4 + elif dtype == "f16": + input_dev = input_t.to(DTYPE_FP16).contiguous() + residual_dev = residual_t.to(DTYPE_FP16).contiguous() + gamma_dev = gamma_t.to(DTYPE_FP16).contiguous() + beta_dev = beta_t.to(DTYPE_FP16).contiguous() + xscale_dev = xscale_t.to(DTYPE_FP16).contiguous() + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) + input_ref = input_dev.to(DTYPE_FP32) + residual_ref = residual_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + xscale_ref = xscale_dev.to(DTYPE_FP32) + residual_atol = 1e-2 + elif dtype == "bf16": + input_dev = input_t.to(DTYPE_BF16).contiguous() + residual_dev = residual_t.to(DTYPE_BF16).contiguous() + gamma_dev = gamma_t.to(DTYPE_BF16).contiguous() + beta_dev = beta_t.to(DTYPE_BF16).contiguous() + xscale_dev = xscale_t.to(DTYPE_BF16).contiguous() + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) + input_ref = input_dev.to(DTYPE_FP32) + residual_ref = residual_dev.to(DTYPE_FP32) + gamma_ref = gamma_dev.to(DTYPE_FP32) + beta_ref = beta_dev.to(DTYPE_FP32) + xscale_ref = xscale_dev.to(DTYPE_FP32) + residual_atol = 2e-2 + else: + raise ValueError(f"unsupported dtype: {dtype}") + + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_INT8) + yscale_dev = torch.empty((M,), device="cuda", dtype=DTYPE_FP32) + + residual_expected = input_ref + residual_ref + x = residual_expected + gamma = gamma_ref + beta = beta_ref + mean = x.mean(dim=1, keepdim=True) + var = x.var(dim=1, keepdim=True, unbiased=False) + expected = (x - mean) / torch.sqrt(var + EPS) * gamma + beta + expected = expected * xscale_ref + yscale_expected = expected.abs().amax(dim=1) / 127.0 + yscale_expected = torch.where(yscale_expected == 0, torch.ones_like(yscale_expected), yscale_expected) + q_expected = torch.clamp(torch.trunc(expected / yscale_expected.unsqueeze(1)), -127, 127).to(DTYPE_INT8) + + print("Launching kernel...") + stream = torch.cuda.current_stream() + + def kernel_launch(): + launch_fn( + input_dev, + residual_dev, + gamma_dev, + beta_dev, + xscale_dev, + output_dev, + residual_out_dev, + yscale_dev, + M, + stream=stream, + ) + + kernel_launch() + torch.cuda.synchronize() + + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS + ) + torch.cuda.synchronize() + flydsl_gpu_us = None + if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": + flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + avg_ms = avg_us / 1000.0 + elem_bytes = 4 if dtype == "f32" else 2 + total_bytes = (3 * M * N + 3 * N) * elem_bytes + M * N + M * 4 + bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9 + print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") + print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") + if flydsl_gpu_us is not None: + print(f"[Perf] FlyDSL fused_add layernorm smoothquant gpu: {flydsl_gpu_us:.1f} us") + + residual_out_ref = residual_out_dev.to(DTYPE_FP32) + output_ref = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1) + q_out = output_dev.to(torch.int16) + q_ref = q_expected.to(torch.int16) + yscale_out = yscale_dev.cpu() + yscale_ref = yscale_expected.cpu() + + residual_error = (residual_out_ref - residual_expected).abs().max().item() + recon_error = (output_ref - expected).abs().max().item() + scale_diff = (yscale_out - yscale_ref).abs().max().item() + quant_diff = (q_out - q_ref).abs().max().item() + + print(f"Max residual error: {residual_error:.2e} (atol={residual_atol})") + print(f"Max recon error: {recon_error:.2e} (tol=0.3)") + print(f"Max scale diff: {scale_diff:.2e} (tol=1e-2)") + print(f"Max quant diff: {quant_diff}") + + if residual_error < residual_atol and recon_error < 0.3 and scale_diff < 1e-2 and quant_diff <= 1: + print("PASSED") + ok = True + else: + print("FAILED") + print("First row Residual Expected:") + print(residual_expected[0, :5]) + print("First row Residual Actual:") + print(residual_out_ref[0, :5]) + print("First row Expected:") + print(expected[0, :5]) + print("First row Actual:") + print(output_ref[0, :5]) + print("First row Quant Expected:") + print(q_ref[0, :8]) + print("First row Quant Actual:") + print(q_out[0, :8]) + print("First few YScale Expected:") + print(yscale_ref[:5]) + print("First few YScale Actual:") + print(yscale_out[:5]) + ok = False + + return ok, flydsl_gpu_us + + +def test_fused_add_layernorm_smoothquant(): + print("=" * 80) + print("Running FusedAdd LayerNorm SmoothQuant Tests") + print("=" * 80) + + shapes_env = os.environ.get("ROCDSL_LAYERNORM_SHAPES", "").strip() + if shapes_env: + configs = [] + for part in shapes_env.split(";"): + p = part.strip() + if not p: + continue + m_s, n_s, dt = [x.strip() for x in p.split(",")] + configs.append((int(m_s), int(n_s), dt)) + else: + configs = [ + # (64, 256, "f32"), # Aligned + # (128, 1024, "f32"), # Aligned + # (32, 128, "f16"), # Aligned + # (64, 2000, "f32"), # Unaligned (tail handling) + # (16, 512, "bf16"), # BF16 + # (1024, 8192, "bf16"), # BF16 + (32768, 8192, "bf16"), + ] + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] + failures = 0 + + for M, N, dtype in configs: + ok, flydsl_gpu_us = run_fused_add_smoothquant_test(M, N, dtype) + if not ok: + failures += 1 + + if do_compare: + import torch + + aiter_us = None + if maybe_enable_aiter(): + try: + from aiter.ops.triton.normalization.norm import layernorm2d_fwd_with_add_smoothquant + + torch_dtype = DTYPE_BF16 if dtype == "bf16" else (DTYPE_FP16 if dtype == "f16" else DTYPE_FP32) + x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + residual = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + residual_out = torch.empty_like(x) + w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + b = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + xscale = (torch.rand((N,), device="cuda", dtype=torch_dtype) + 0.5).contiguous() + q_out = torch.empty((M, N), device="cuda", dtype=DTYPE_INT8) + yscale = torch.empty((M, 1), device="cuda", dtype=DTYPE_FP32) + + def run_aiter(): + layernorm2d_fwd_with_add_smoothquant(q_out, x, residual, residual_out, xscale, yscale, w, b, EPS) + + aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f"[Perf] AIter fused_add layernorm smoothquant gpu: {aiter_us:.1f} us") + except Exception as e: + print(f"[Perf] AIter fused_add layernorm smoothquant skipped: {type(e).__name__}: {e!r}") + + perf_rows.append( + PerfRow( + op="layernorm_fused_add_smoothquant", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) + + print("\n" + "=" * 80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("=" * 80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + if failures != 0: + raise SystemExit(1) + + if __name__ == "__main__": - test_all() + test_layernorm() + test_fused_add_layernorm() + test_layernorm_dynamicquant() + test_layernorm_smoothquant() + test_fused_add_layernorm_dynamicquant() + test_fused_add_layernorm_smoothquant()