diff --git a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a16wfp4.py b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a16wfp4.py index 4c433c4ab6..9a1a55dd2f 100755 --- a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a16wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a16wfp4.py @@ -317,4 +317,4 @@ def _get_config( ): # Note: Config files use K=2*K in their naming because FP4 weights are packed, # so the actual K dimension in the config file corresponds to 2*K unpacked elements - return get_gemm_config("BATCHED_GEMM_PREQUANT-AFP4WFP4", M, N, 2 * K) + return get_gemm_config("BATCHED_GEMM-A16WFP4", M, N, 2 * K) diff --git a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_bf16.py b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_bf16.py index ebcbe250e5..7b3df81a37 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_bf16.py +++ b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_bf16.py @@ -179,4 +179,5 @@ def _get_config( K: int, ): + # BF16 uses the shared 16-bit activation / 16-bit weight batched GEMM config. return get_gemm_config("BATCHED_GEMM-A16W16", M, N, K) diff --git a/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A16WFP4-N=128-K=512.json b/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A16WFP4-N=128-K=512.json new file mode 100644 index 0000000000..8899026fdf --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A16WFP4-N=128-K=512.json @@ -0,0 +1,80 @@ +{ + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A16WFP4.json b/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A16WFP4.json new file mode 100644 index 0000000000..284df43967 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A16WFP4.json @@ -0,0 +1,80 @@ +{ + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=128-K=512.json b/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=128-K=512.json index 6d5ec6e3ca..ae0a6463ae 100644 --- a/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=128-K=512.json +++ b/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=128-K=512.json @@ -33,26 +33,26 @@ "cache_modifier": ".cg" }, "M_LEQ_128": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 16, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, "kpack": 1, - "cache_modifier": ".cg" + "cache_modifier": null }, "M_LEQ_256": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 16, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, "kpack": 1, - "cache_modifier": ".cg" + "cache_modifier": null }, "any": { "BLOCK_SIZE_M": 32, diff --git a/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=512-K=128.json b/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=512-K=128.json index 222f27b6d1..ed1007a4f0 100644 --- a/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=512-K=128.json +++ b/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=512-K=128.json @@ -11,46 +11,46 @@ "cache_modifier": ".cg" }, "M_LEQ_32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 2, + "num_stages": 1, "waves_per_eu": 1, - "matrix_instr_nonkdim": 16, + "matrix_instr_nonkdim": 32, "kpack": 1, "cache_modifier": ".cg" }, "M_LEQ_64": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 1, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 16, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, "kpack": 1, "cache_modifier": ".cg" }, "M_LEQ_128": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 1, - "waves_per_eu": 6, - "matrix_instr_nonkdim": 16, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, "kpack": 1, "cache_modifier": ".cg" }, "M_LEQ_256": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 1, - "waves_per_eu": 6, - "matrix_instr_nonkdim": 16, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, "kpack": 1, "cache_modifier": ".cg" }, diff --git a/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT.json b/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT.json index 43f7876e18..50ecbc9f5d 100644 --- a/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT.json +++ b/aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT.json @@ -22,34 +22,34 @@ "cache_modifier": ".cg" }, "M_LEQ_64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 2, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 1, "cache_modifier": ".cg" }, "M_LEQ_128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 2, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 1, "cache_modifier": ".cg" }, "M_LEQ_256": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 2, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 1, "cache_modifier": ".cg" diff --git a/op_tests/op_benchmarks/triton/bench_batched_gemm_a16wfp4.py b/op_tests/op_benchmarks/triton/bench_batched_gemm_a16wfp4.py index 80f775b78c..ae60ebc581 100644 --- a/op_tests/op_benchmarks/triton/bench_batched_gemm_a16wfp4.py +++ b/op_tests/op_benchmarks/triton/bench_batched_gemm_a16wfp4.py @@ -3,8 +3,8 @@ import triton import math import aiter.ops.triton.utils._triton.arch_info as arch_info -from aiter.ops.triton.gemm.batched.batched_gemm_afp4wfp4_pre_quant import ( - batched_gemm_afp4wfp4_pre_quant, +from aiter.ops.triton.gemm.batched.batched_gemm_a16wfp4 import ( + batched_gemm_a16wfp4, ) from op_tests.triton_tests.gemm.batched.test_batched_gemm_a16wfp4 import ( generate_batched_gemm_a16wfp4_inputs, @@ -47,7 +47,7 @@ def bench_gemm_fn( mem = mem_read + mem_write ms = triton.testing.do_bench( - lambda: batched_gemm_afp4wfp4_pre_quant(x, w, w_scale, c_dtype, y), + lambda: batched_gemm_a16wfp4(x, w, w_scale, c_dtype, y), warmup=25, rep=100, ) @@ -74,7 +74,7 @@ def run_model_benchmark(args): ) @triton.testing.perf_report([benchmark]) - def bench_batched_gemm_afp4wfp4_pre_quant( + def bench_batched_gemm_a16wfp4( M, hidden_dim, intermediate_dim, batch, metric, layer, **kwargs ): if layer == "fc1": @@ -92,7 +92,7 @@ def bench_batched_gemm_afp4wfp4_pre_quant( return bench_gemm_fn(batch, M, N, K, metric, args.layout) - bench_batched_gemm_afp4wfp4_pre_quant.run( + bench_batched_gemm_a16wfp4.run( save_path="." if args.o else None, print_data=True ) @@ -105,7 +105,7 @@ def run_shape_benchmark(args): ) @triton.testing.perf_report([benchmark]) - def bench_batched_gemm_afp4wfp4_pre_quant( + def bench_batched_gemm_a16wfp4( batch, M, N, @@ -115,7 +115,7 @@ def bench_batched_gemm_afp4wfp4_pre_quant( ): return bench_gemm_fn(batch, M, N, K, metric, args.layout) - bench_batched_gemm_afp4wfp4_pre_quant.run( + bench_batched_gemm_a16wfp4.run( save_path="." if args.o else None, print_data=True ) diff --git a/op_tests/op_benchmarks/triton/bench_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py b/op_tests/op_benchmarks/triton/bench_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py new file mode 100644 index 0000000000..f33517a9d4 --- /dev/null +++ b/op_tests/op_benchmarks/triton/bench_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py @@ -0,0 +1,198 @@ +import math +import torch +import triton +from aiter.ops.triton.gemm.batched.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant, +) +from op_tests.triton_tests.gemm.batched.test_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + generate_batched_gemm_a16w8_inputs as generate_batched_gemm_a8w8_per_token_group_inputs, +) +from op_tests.op_benchmarks.triton.utils.argparse import ( + get_parser, + add_argparse_ff, + get_ff_args, +) +from op_tests.op_benchmarks.triton.utils.benchmark_utils import ( + get_model_benchmark_object, + get_shape_benchmark_object, + batched_model_benchmark_shapes, + print_vgpr, + get_caller_name_no_ext, +) + + +def bench_gemm_fn( + batch: int, + M: int, + N: int, + K: int, + metric: str, + layout: str, + group_size: int, + has_bias: bool, + transpose_bm: bool, +): + c_dtype = torch.bfloat16 + x, weight, w_scale, bias, y = generate_batched_gemm_a8w8_per_token_group_inputs( + batch, + M, + N, + K, + c_dtype, + has_bias=has_bias, + output=True, + layout=layout, + transpose_bm=transpose_bm, + ) + # flops + flops = 2.0 * batch * M * N * K + # memory transfer + mem = ( + x.numel() * x.element_size() + + weight.numel() * weight.element_size() + + w_scale.numel() * w_scale.element_size() + + (bias.numel() * bias.element_size() if bias is not None else 0) + + y.numel() * y.element_size() + ) + + fn = ( + lambda: batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( + x, + weight, + w_scale, + group_size=group_size, + bias=bias, + dtype=c_dtype, + YQ=y, + transpose_bm=transpose_bm, + ) + ) + + ms = triton.testing.do_bench(fn, warmup=25, rep=100) + + # Return exactly one scalar depending on which metric is active + if metric == "time": + return ms + elif metric == "throughput": + return flops / ms * 1e-9 + elif metric == "bandwidth": + return mem / ms * 1e-6 + else: + raise ValueError(f"Unsupported metric: {metric}") + + +def run_model_benchmark(args): + plot_name = get_caller_name_no_ext() + x_names = ["M", "hidden_dim", "intermediate_dim", "batch", "model_name"] + benchmark = get_model_benchmark_object( + plot_name, + args, + x_names=x_names, + model_benchmark_shapes_fn=batched_model_benchmark_shapes, + ) + + @triton.testing.perf_report([benchmark]) + def bench_batched_gemm_a8w8_per_token_group_prequant_w_per_batched_tensor_quant( + M, hidden_dim, intermediate_dim, batch, metric, layer, **kwargs + ): + if layer == "fc1": + if args.no_glu: + N, K = intermediate_dim, hidden_dim + else: + N, K = intermediate_dim * 2, hidden_dim + N = math.ceil(N / args.tp) + elif layer == "fc2": + N, K = hidden_dim, intermediate_dim + K = math.ceil(K / args.tp) + else: + raise ValueError(f"Unsupported layer: {layer}") + + return bench_gemm_fn( + batch, + M, + N, + K, + metric, + args.layout, + args.group_size, + not args.no_bias, + args.transpose_bm, + ) + + bench_batched_gemm_a8w8_per_token_group_prequant_w_per_batched_tensor_quant.run( + save_path="." if args.o else None, print_data=True + ) + + +def run_shape_benchmark(args): + plot_name = get_caller_name_no_ext() + x_names = ["batch", "M", "N", "K"] + benchmark = get_shape_benchmark_object(plot_name, args, x_names=x_names) + + @triton.testing.perf_report([benchmark]) + def bench_batched_gemm_a8w8_per_token_group_prequant_w_per_batched_tensor_quant( + batch, M, N, K, metric, **kwargs + ): + return bench_gemm_fn( + batch, + M, + N, + K, + metric, + args.layout, + args.group_size, + not args.no_bias, + args.transpose_bm, + ) + + bench_batched_gemm_a8w8_per_token_group_prequant_w_per_batched_tensor_quant.run( + save_path="." if args.o else None, print_data=True + ) + + +def run_benchmark(args, defaults): + if args.model: + run_model_benchmark(args) + else: + run_shape_benchmark(args) + + +def parse_args(args: list[str] | None = None): + parser = get_parser( + "Batched A8W8 GEMM (A per-token-group pre-quant, W per-batched-tensor quant)" + ) + parser = add_argparse_ff(parser) + parser.add_argument("-B", type=int, default=None, help="Batch size") + parser.add_argument( + "--group-size", + type=int, + default=128, + dest="group_size", + help="Per-token group size for X quantization (default: 128).", + ) + parser.add_argument( + "--no-bias", + action="store_true", + default=False, + help="Disable bias.", + ) + parser.add_argument( + "--transpose-bm", + action="store_true", + default=False, + dest="transpose_bm", + help="Transpose batch and M dimensions in the output tensor.", + ) + return get_ff_args(parser, args=args) + + +def main(args: list[str] | None = None) -> None: + parsed_args, defaults = parse_args(args=args) + if parsed_args.print_vgpr: + print_vgpr(lambda: run_benchmark(parsed_args, defaults)) + return + run_benchmark(parsed_args, defaults) + + +if __name__ == "__main__": + main() diff --git a/op_tests/op_benchmarks/triton/bench_batched_gemm_a16w16.py b/op_tests/op_benchmarks/triton/bench_batched_gemm_bf16.py similarity index 74% rename from op_tests/op_benchmarks/triton/bench_batched_gemm_a16w16.py rename to op_tests/op_benchmarks/triton/bench_batched_gemm_bf16.py index 86fbcc3dd9..a99659e9b7 100644 --- a/op_tests/op_benchmarks/triton/bench_batched_gemm_a16w16.py +++ b/op_tests/op_benchmarks/triton/bench_batched_gemm_bf16.py @@ -3,7 +3,7 @@ import triton import math from op_tests.triton_tests.gemm.batched.test_batched_gemm_bf16 import ( - generate_batched_gemm_a16w16_inputs, + generate_batched_gemm_a16w16_inputs as generate_batched_gemm_bf16_inputs, ) from op_tests.op_benchmarks.triton.utils.argparse import ( get_parser, @@ -22,19 +22,38 @@ def bench_gemm_fn(batch: int, M: int, N: int, K: int, metric: str, layout: str): c_dtype = torch.bfloat16 - x, w, bias, y = generate_batched_gemm_a16w16_inputs( - batch, M, N, K, dtype=c_dtype, layout=layout, output=True + + x, w, bias, y = generate_batched_gemm_bf16_inputs( + batch, + M, + N, + K, + dtype=c_dtype, + layout=layout, + output=True, ) - # print(f"M: {M}, N: {N}, K: {K}, x.shape: {x.shape}, x.stride(): {x.stride()}, w.shape: {w.shape}, w.stride(): {w.stride()}") - # flops - flops = 2.0 * M * N * K * batch - # memory transfer - mem_read = x.numel() * x.element_size() + w.numel() * w.element_size() - mem_write = (M * N) * 2 # TODO: Fix for c_dtype != bf16 + + # FLOPs for batched GEMM: + # C[B, M, N] = A[B, M, K] @ W[B, N, K]^T + flops = 2.0 * batch * M * N * K + + # Memory traffic. + mem_read = ( + x.numel() * x.element_size() + + w.numel() * w.element_size() + + bias.numel() * bias.element_size() + ) + mem_write = y.numel() * y.element_size() mem = mem_read + mem_write ms = triton.testing.do_bench( - lambda: batched_gemm_bf16(x, w, bias, c_dtype, YQ=y), + lambda: batched_gemm_bf16( + x, + w, + bias=bias, + dtype=c_dtype, + YQ=y, + ), warmup=25, rep=100, ) @@ -61,8 +80,14 @@ def run_model_benchmark(args): ) @triton.testing.perf_report([benchmark]) - def bench_batched_gemm_a8w8( - M, hidden_dim, intermediate_dim, batch, metric, layer, **kwargs + def bench_batched_gemm_bf16( + M, + hidden_dim, + intermediate_dim, + batch, + metric, + layer, + **kwargs, ): if layer == "fc1": if args.no_glu: @@ -71,15 +96,21 @@ def bench_batched_gemm_a8w8( N, K = intermediate_dim * 2, hidden_dim # Divide N by tensor parallel N = math.ceil(N / args.tp) + elif layer == "fc2": N, K = hidden_dim, intermediate_dim # Divide K by tensor parallel K = math.ceil(K / args.tp) - # print(f"Layer: {layer}, B: {batch}, M: {M}, N: {N}, K: {K}, hidden_dim: {hidden_dim}, intermediate_dim: {intermediate_dim}") + + else: + raise ValueError(f"Unsupported layer: {layer}") return bench_gemm_fn(batch, M, N, K, metric, args.layout) - bench_batched_gemm_a8w8.run(save_path="." if args.o else None, print_data=True) + bench_batched_gemm_bf16.run( + save_path="." if args.o else None, + print_data=True, + ) def run_shape_benchmark(args): @@ -90,10 +121,13 @@ def run_shape_benchmark(args): ) @triton.testing.perf_report([benchmark]) - def bench_batched_gemm_a8w8(batch, M, N, K, metric, **kwargs): + def bench_batched_gemm_bf16(batch, M, N, K, metric, **kwargs): return bench_gemm_fn(batch, M, N, K, metric, args.layout) - bench_batched_gemm_a8w8.run(save_path="." if args.o else None, print_data=True) + bench_batched_gemm_bf16.run( + save_path="." if args.o else None, + print_data=True, + ) def run_benchmark(args, defaults): @@ -125,7 +159,7 @@ def run_benchmark(args, defaults): def parse_args(): - parser = get_parser("Batched A16W16 GEMM") + parser = get_parser("Batched BF16 GEMM") parser = add_argparse_ff(parser) parser.add_argument( "-B", @@ -138,11 +172,13 @@ def parse_args(): def main(): args, defaults = parse_args() + if args.print_vgpr: print("Retrieving VGPR usage for Triton kernels...") fun = lambda: run_benchmark(args, defaults) # noqa: E731 print_vgpr(fun, get_caller_name_no_ext()) return 0 + run_benchmark(args, defaults)