From 595423d6c69e701c690880fe592126bc904e2a61 Mon Sep 17 00:00:00 2001 From: neoblizz Date: Tue, 3 Feb 2026 17:14:03 +0000 Subject: [PATCH 01/31] Add benchmark capabilities for ops. --- benchmark/ops/all_gather_matmul/benchmark.py | 376 ++++++++++++++++ benchmark/ops/matmul_all_gather/benchmark.py | 367 +++++++++++++++ benchmark/ops/matmul_all_reduce/benchmark.py | 378 ++++++++++++++++ .../ops/matmul_reduce_scatter/benchmark.py | 421 ++++++++++++++++++ iris/ops/__init__.py | 12 +- 5 files changed, 1547 insertions(+), 7 deletions(-) create mode 100644 benchmark/ops/all_gather_matmul/benchmark.py create mode 100644 benchmark/ops/matmul_all_gather/benchmark.py create mode 100644 benchmark/ops/matmul_all_reduce/benchmark.py create mode 100644 benchmark/ops/matmul_reduce_scatter/benchmark.py diff --git a/benchmark/ops/all_gather_matmul/benchmark.py b/benchmark/ops/all_gather_matmul/benchmark.py new file mode 100644 index 000000000..3bc45579e --- /dev/null +++ b/benchmark/ops/all_gather_matmul/benchmark.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops all_gather_matmul fused operation. + +This benchmark showcases the fused All-Gather + GEMM operation where each rank +has a sharded A matrix that gets gathered, then multiplied with B. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark all_gather_matmul fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension total (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="all_gather_matmul.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (all_gather_into_tensor + matmul) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29530", help="Initialization URL for distributed setup" + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + K = args["k"] + K_local = K // world_size # Sharded K dimension + + # Create config with parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "all_gather_matmul") + json_writer.add_field("k_local", K_local) + json_writer.add_field("k_total", K) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + + # Create input and output tensors + # A_sharded is M x K_local, B is K x N, output is M x N + A_sharded = shmem.zeros((M, K_local), dtype=datatype) + B = shmem.zeros((K, N), dtype=datatype) + C = shmem.zeros((M, N), dtype=datatype) + expected_tensor = None + + # Fill inputs with deterministic values + # Each rank has different A_sharded, same B + torch.manual_seed(123 + rank) + A_sharded_data = torch.randn((M, K_local), dtype=datatype, device=f"cuda:{rank}") + A_sharded.copy_(A_sharded_data) + + torch.manual_seed(456) # Same B for all ranks + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # For validation: compute expected result + if args["validate"]: + # Gather all A_sharded matrices and compute expected result + A_sharded_list = [torch.zeros((M, K_local), dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_sharded_list, A_sharded_data) + + # Concatenate along K dimension: A_gathered = [A_0 | A_1 | ... | A_n] + A_gathered = torch.cat(A_sharded_list, dim=1) # (M, K) + + # Expected: A_gathered @ B + expected_tensor = shmem.zeros((M, N), dtype=datatype) + expected_result = torch.matmul(A_gathered, B_data) + expected_tensor.copy_(expected_result) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "all_gather_matmul": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + workspace = None + + def run_experiment(): + nonlocal kernel_timing, workspace + + # Preamble if available + if hasattr(shmem.ops, "all_gather_matmul_preamble"): + workspace = shmem.ops.all_gather_matmul_preamble( + C, + A_sharded, + B, + config=config, + workspace=workspace, + ) + + shmem.barrier() + + torch.cuda.nvtx.range_push("All-Gather-Matmul") + with torch.cuda.stream(comm_stream): + kernel_timing["all_gather_matmul"]["start_event"].record() + shmem.ops.all_gather_matmul( + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + ) + kernel_timing["all_gather_matmul"]["end_event"].record() + kernel_timing["all_gather_matmul"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["all_gather_matmul"]["start_event"].elapsed_time( + kernel_timing["all_gather_matmul"]["end_event"] + ) + kernel_timing["all_gather_matmul"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-1 if datatype == torch.float16 else 1e-3 + success = torch.allclose(C, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("All-gather-matmul validation passed!") + else: + shmem.error("All-gather-matmul validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + for k in ["all_gather_matmul"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=25, n_repeat=1) + + for k in ["all_gather_matmul"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + tflops = total_tflops_unit / ( + (kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"]) * 1e-3 + ) + + # Calculate bandwidth for all-gather part + # All-gather moves (world_size - 1) * M * K_local * element_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + input_bytes = M * K_local * element_size + total_bytes = input_bytes * (world_size - 1) + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"All-gather-matmul (M={M}, K_local={K_local}, K_total={K}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "all_gather_matmul_ms", + kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"], + ) + json_writer.add_field("all_gather_matmul_experiments", kernel_timing["all_gather_matmul"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (all_gather_into_tensor + matmul) for comparison + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (all_gather_into_tensor + matmul)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A_sharded = torch.randn(M, K_local, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_A_gathered = torch.zeros(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + dist.all_gather_into_tensor(pytorch_A_gathered, pytorch_A_sharded) + pytorch_C = torch.matmul(pytorch_A_gathered, pytorch_B) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + def run_pytorch_experiment(): + dist.all_gather_into_tensor(pytorch_A_gathered, pytorch_A_sharded) + pytorch_C = torch.matmul(pytorch_A_gathered, pytorch_B) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch all_gather_into_tensor+matmul (M={M}, K_local={K_local}, K_total={K}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = args["init_url"] + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_all_gather/benchmark.py b/benchmark/ops/matmul_all_gather/benchmark.py new file mode 100644 index 000000000..22c914e8d --- /dev/null +++ b/benchmark/ops/matmul_all_gather/benchmark.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops matmul_all_gather fused operation. + +This benchmark showcases the fused GEMM + All-Gather operation where each rank +computes a local matmul and then gathers results along M dimension. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark matmul_all_gather fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows per rank in matrix A (M_local)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="matmul_all_gather.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (matmul + all_gather_into_tensor) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29529", help="Initialization URL for distributed setup" + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M_local = args["m"] # Local M dimension + M = M_local * world_size # Total M after gather + N = args["n"] + K = args["k"] + + # Create config with parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "matmul_all_gather") + json_writer.add_field("m_local", M_local) + json_writer.add_field("m_total", M) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + + # Create input and output tensors + # A_local is M_local x K, output is M x N (gathered) + A_local = shmem.zeros((M_local, K), dtype=datatype) + B = shmem.zeros((K, N), dtype=datatype) + C = shmem.zeros((M, N), dtype=datatype) + expected_tensor = None + + # Fill inputs with deterministic values + # Each rank has different A_local, same B + torch.manual_seed(123 + rank) + A_local_data = torch.randn((M_local, K), dtype=datatype, device=f"cuda:{rank}") + A_local.copy_(A_local_data) + + torch.manual_seed(456) # Same B for all ranks + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # For validation: compute expected result + if args["validate"]: + # Gather all A_local matrices and compute expected result + A_local_list = [torch.zeros((M_local, K), dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_local_list, A_local_data) + + # Expected: [A_0 @ B; A_1 @ B; ...; A_n @ B] stacked along M + expected_tensor = shmem.zeros((M, N), dtype=datatype) + expected_parts = [] + for i, A_rank_local in enumerate(A_local_list): + C_rank_local = torch.matmul(A_rank_local, B_data) + expected_parts.append(C_rank_local) + expected_result = torch.cat(expected_parts, dim=0) + expected_tensor.copy_(expected_result) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "matmul_all_gather": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + workspace = None + + def run_experiment(): + nonlocal kernel_timing, workspace + + shmem.barrier() + + torch.cuda.nvtx.range_push("Matmul-All-Gather") + with torch.cuda.stream(comm_stream): + kernel_timing["matmul_all_gather"]["start_event"].record() + shmem.ops.matmul_all_gather( + C, + A_local, + B, + config=config, + async_op=False, + workspace=workspace, + ) + kernel_timing["matmul_all_gather"]["end_event"].record() + kernel_timing["matmul_all_gather"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["matmul_all_gather"]["start_event"].elapsed_time( + kernel_timing["matmul_all_gather"]["end_event"] + ) + kernel_timing["matmul_all_gather"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-1 if datatype == torch.float16 else 1e-3 + success = torch.allclose(C, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("Matmul-all-gather validation passed!") + else: + shmem.error("Matmul-all-gather validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + for k in ["matmul_all_gather"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=25, n_repeat=1) + + for k in ["matmul_all_gather"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M_local*N*K flops per rank (but total is same across all ranks) + total_flops = 2 * M_local * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + tflops = total_tflops_unit / ( + (kernel_timing["matmul_all_gather"]["ms"] / kernel_timing["matmul_all_gather"]["experiments"]) * 1e-3 + ) + + # Calculate bandwidth for all-gather part + # All-gather moves (world_size - 1) * M_local * N * element_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + output_bytes = M_local * N * element_size + total_bytes = output_bytes * (world_size - 1) + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["matmul_all_gather"]["ms"] / kernel_timing["matmul_all_gather"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"Matmul-all-gather (M_local={M_local}, M_total={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "matmul_all_gather_ms", + kernel_timing["matmul_all_gather"]["ms"] / kernel_timing["matmul_all_gather"]["experiments"], + ) + json_writer.add_field("matmul_all_gather_experiments", kernel_timing["matmul_all_gather"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (matmul + all_gather_into_tensor) for comparison + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (matmul + all_gather_into_tensor)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A_local = torch.randn(M_local, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_C_local = torch.zeros(M_local, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + pytorch_C_local = torch.matmul(pytorch_A_local, pytorch_B) + dist.all_gather_into_tensor(pytorch_C, pytorch_C_local) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + def run_pytorch_experiment(): + pytorch_C_local = torch.matmul(pytorch_A_local, pytorch_B) + dist.all_gather_into_tensor(pytorch_C, pytorch_C_local) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch matmul+all_gather_into_tensor (M_local={M_local}, M_total={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = args["init_url"] + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_all_reduce/benchmark.py b/benchmark/ops/matmul_all_reduce/benchmark.py new file mode 100644 index 000000000..fd923e051 --- /dev/null +++ b/benchmark/ops/matmul_all_reduce/benchmark.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops matmul_all_reduce fused operation. + +This benchmark showcases the fused GEMM + All-Reduce operation and reports +achieved TFLOPS and communication bandwidth. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark matmul_all_reduce fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="matmul_all_reduce.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (matmul + all_reduce) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--all_reduce_variant", + type=str, + default="two_shot", + choices=["atomic", "ring", "two_shot", "one_shot", "spinlock"], + help="All-reduce variant to use", + ) + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29528", help="Initialization URL for distributed setup" + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + K = args["k"] + + # Create config with parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + "all_reduce_variant": args["all_reduce_variant"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "matmul_all_reduce") + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + json_writer.add_field("all_reduce_variant", config.all_reduce_variant) + + # Create input and output tensors + # Must use shmem.zeros() to allocate on Iris symmetric heap + A = shmem.zeros((M, K), dtype=datatype) + B = shmem.zeros((K, N), dtype=datatype) + C = shmem.zeros((M, N), dtype=datatype) + expected_tensor = None + + # Fill inputs with deterministic values + # Each rank has different A, same B + torch.manual_seed(123 + rank) + A_local_data = torch.randn((M, K), dtype=datatype, device=f"cuda:{rank}") + A.copy_(A_local_data) + + torch.manual_seed(456) # Same B for all ranks + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # For validation: compute expected result + # Reference: each rank computes local C = A @ B, then all_reduce + if args["validate"]: + expected_tensor = shmem.zeros((M, N), dtype=datatype) + C_local_ref = torch.matmul(A_local_data, B_data) + pytorch_output = C_local_ref.clone() + shmem.barrier() + dist.all_reduce(pytorch_output, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + expected_tensor.copy_(pytorch_output) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "matmul_all_reduce": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + workspace = None + + def run_experiment(): + nonlocal kernel_timing, workspace + + # Preamble if available + if hasattr(shmem.ops, "matmul_all_reduce_preamble"): + workspace = shmem.ops.matmul_all_reduce_preamble( + C, + A, + B, + config=config, + workspace=workspace, + ) + + shmem.barrier() + + torch.cuda.nvtx.range_push("Matmul-All-Reduce") + with torch.cuda.stream(comm_stream): + kernel_timing["matmul_all_reduce"]["start_event"].record() + shmem.ops.matmul_all_reduce( + C, + A, + B, + config=config, + async_op=False, + workspace=workspace, + ) + kernel_timing["matmul_all_reduce"]["end_event"].record() + kernel_timing["matmul_all_reduce"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["matmul_all_reduce"]["start_event"].elapsed_time( + kernel_timing["matmul_all_reduce"]["end_event"] + ) + kernel_timing["matmul_all_reduce"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 0.2 if datatype == torch.float16 else 0.3 + success = torch.allclose(C, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("Matmul-all-reduce validation passed!") + else: + shmem.error("Matmul-all-reduce validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + for k in ["matmul_all_reduce"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=25, n_repeat=1) + + for k in ["matmul_all_reduce"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + tflops = total_tflops_unit / ( + (kernel_timing["matmul_all_reduce"]["ms"] / kernel_timing["matmul_all_reduce"]["experiments"]) * 1e-3 + ) + + # Calculate bandwidth for all-reduce part + # All-reduce moves 2 * (world_size - 1) / world_size * data_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + output_bytes = M * N * element_size + total_bytes = output_bytes * (2 * (world_size - 1)) / world_size + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["matmul_all_reduce"]["ms"] / kernel_timing["matmul_all_reduce"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"Matmul-all-reduce (M={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}, variant={args['all_reduce_variant']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "matmul_all_reduce_ms", + kernel_timing["matmul_all_reduce"]["ms"] / kernel_timing["matmul_all_reduce"]["experiments"], + ) + json_writer.add_field("matmul_all_reduce_experiments", kernel_timing["matmul_all_reduce"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (matmul + all_reduce) for comparison + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (matmul + all_reduce)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A = torch.randn(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + pytorch_C = torch.matmul(pytorch_A, pytorch_B) + dist.all_reduce(pytorch_C, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + def run_pytorch_experiment(): + pytorch_C = torch.matmul(pytorch_A, pytorch_B) + dist.all_reduce(pytorch_C, op=dist.ReduceOp.SUM) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch matmul+all_reduce (M={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = args["init_url"] + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_reduce_scatter/benchmark.py b/benchmark/ops/matmul_reduce_scatter/benchmark.py new file mode 100644 index 000000000..301444f25 --- /dev/null +++ b/benchmark/ops/matmul_reduce_scatter/benchmark.py @@ -0,0 +1,421 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops matmul_reduce_scatter fused operation. + +This benchmark showcases the fused GEMM + Reduce-Scatter operation where each rank +computes a local matmul, reduces across all ranks, and scatters tiles to ranks. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark matmul_reduce_scatter fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="matmul_reduce_scatter.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (matmul + all_reduce) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29531", help="Initialization URL for distributed setup" + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + K = args["k"] + + # Create config with parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "matmul_reduce_scatter") + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + + # Calculate tile distribution + num_pid_m = (M + config.block_size_m - 1) // config.block_size_m + num_pid_n = (N + config.block_size_n - 1) // config.block_size_n + total_tiles = num_pid_m * num_pid_n + tiles_per_rank = total_tiles // world_size + start_tile = rank * tiles_per_rank + if rank == world_size - 1: + tiles_per_rank = total_tiles - start_tile + + json_writer.add_field("total_tiles", total_tiles) + json_writer.add_field("tiles_per_rank", tiles_per_rank) + + # Create input and output tensors + # Each rank computes full A @ B, but only keeps its assigned tiles + A = shmem.zeros((M, K), dtype=datatype) + B = shmem.zeros((K, N), dtype=datatype) + C = shmem.zeros((M, N), dtype=datatype) + expected_tiles = [] + + # Fill inputs with deterministic values + # Each rank has different A, same B + torch.manual_seed(123 + rank) + A_local_data = torch.randn((M, K), dtype=datatype, device=f"cuda:{rank}") + A.copy_(A_local_data) + + torch.manual_seed(456) # Same B for all ranks + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # For validation: compute expected result for this rank's tiles + if args["validate"]: + # Gather all A matrices to compute expected result + A_list = [torch.zeros((M, K), dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_list, A_local_data) + + # Expected: sum of all (A_i @ B) for each rank i, but only for this rank's tiles + expected_full = torch.zeros((M, N), dtype=datatype, device=f"cuda:{rank}") + for A_rank in A_list: + expected_full += torch.matmul(A_rank, B_data) + + # Extract only this rank's tiles + for local_tile_idx in range(tiles_per_rank): + tile_id = start_tile + local_tile_idx + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + + m_start = pid_m * config.block_size_m + m_end = min(m_start + config.block_size_m, M) + n_start = pid_n * config.block_size_n + n_end = min(n_start + config.block_size_n, N) + + expected_tiles.append( + { + "tile_id": tile_id, + "pid_m": pid_m, + "pid_n": pid_n, + "m_start": m_start, + "m_end": m_end, + "n_start": n_start, + "n_end": n_end, + "data": expected_full[m_start:m_end, n_start:n_end].clone(), + } + ) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "matmul_reduce_scatter": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + workspace = None + + def run_experiment(): + nonlocal kernel_timing, workspace + + # Preamble if available + if hasattr(shmem.ops, "matmul_reduce_scatter_preamble"): + workspace = shmem.ops.matmul_reduce_scatter_preamble( + C, + A, + B, + config=config, + workspace=workspace, + ) + + shmem.barrier() + + torch.cuda.nvtx.range_push("Matmul-Reduce-Scatter") + with torch.cuda.stream(comm_stream): + kernel_timing["matmul_reduce_scatter"]["start_event"].record() + shmem.ops.matmul_reduce_scatter( + C, + A, + B, + async_op=False, + config=config, + workspace=workspace, + ) + kernel_timing["matmul_reduce_scatter"]["end_event"].record() + kernel_timing["matmul_reduce_scatter"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["matmul_reduce_scatter"]["start_event"].elapsed_time( + kernel_timing["matmul_reduce_scatter"]["end_event"] + ) + kernel_timing["matmul_reduce_scatter"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 2e-1 if datatype == torch.float16 else 1e-1 + success = True + + # Validate each tile assigned to this rank + for tile_info in expected_tiles: + C_tile = C[tile_info["m_start"] : tile_info["m_end"], tile_info["n_start"] : tile_info["n_end"]] + expected_tile = tile_info["data"] + + tile_match = torch.allclose(C_tile, expected_tile, atol=atol) + if not tile_match: + max_diff = torch.abs(C_tile - expected_tile).max().item() + shmem.error( + f"Rank {rank}, tile {tile_info['tile_id']} ({tile_info['pid_m']},{tile_info['pid_n']}): " + f"Validation failed, max diff: {max_diff}" + ) + success = False + + if success: + shmem.info("Matmul-reduce-scatter validation passed!") + else: + shmem.error("Matmul-reduce-scatter validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + for k in ["matmul_reduce_scatter"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=25, n_repeat=1) + + for k in ["matmul_reduce_scatter"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + tflops = total_tflops_unit / ( + (kernel_timing["matmul_reduce_scatter"]["ms"] / kernel_timing["matmul_reduce_scatter"]["experiments"]) + * 1e-3 + ) + + # Calculate bandwidth for reduce-scatter part + # Similar to all-reduce: 2 * (world_size - 1) / world_size * data_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + output_bytes = M * N * element_size + total_bytes = output_bytes * (2 * (world_size - 1)) / world_size + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["matmul_reduce_scatter"]["ms"] / kernel_timing["matmul_reduce_scatter"]["experiments"]) + * 1e-3 + ) + + shmem.info( + f"Matmul-reduce-scatter (M={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "matmul_reduce_scatter_ms", + kernel_timing["matmul_reduce_scatter"]["ms"] / kernel_timing["matmul_reduce_scatter"]["experiments"], + ) + json_writer.add_field( + "matmul_reduce_scatter_experiments", kernel_timing["matmul_reduce_scatter"]["experiments"] + ) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (matmul + all_reduce) for comparison + # Note: We use all_reduce since PyTorch's reduce_scatter has different semantics + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (matmul + all_reduce)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A = torch.randn(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + pytorch_C = torch.matmul(pytorch_A, pytorch_B) + dist.all_reduce(pytorch_C, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + def run_pytorch_experiment(): + pytorch_C = torch.matmul(pytorch_A, pytorch_B) + dist.all_reduce(pytorch_C, op=dist.ReduceOp.SUM) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch matmul+all_reduce (M={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = args["init_url"] + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/iris/ops/__init__.py b/iris/ops/__init__.py index e0d12ba51..a6ed4a659 100644 --- a/iris/ops/__init__.py +++ b/iris/ops/__init__.py @@ -141,17 +141,16 @@ def matmul_all_gather(self, output_tensor, A, B, bias=None, async_op=False, conf """ return matmul_all_gather(self._shmem, output_tensor, A, B, bias, async_op, config, workspace) - def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False, config=None, workspace=None): + def matmul_reduce_scatter(self, output_tensor, A, B, async_op=False, config=None, workspace=None): """ Fused matrix multiplication and reduce-scatter. - Computes: output = reduce_scatter(A @ B + bias) along N dimension + Computes: output = reduce_scatter(A @ B) where each rank keeps assigned tiles Args: - output_tensor: Output tensor (M, N_local) where N_local = N / world_size + output_tensor: Output tensor (M, N) - will contain reduced tiles for this rank A: Input matrix A (M, K) B: Input matrix B (K, N) - bias: Optional bias vector (M,) or (N,) async_op: If False, performs barrier at end config: Optional FusedConfig for tuning workspace: Optional pre-allocated workspace @@ -160,11 +159,10 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False, workspace: Updated workspace object Example: - >>> N_local = N // world_size - >>> output = shmem.zeros((M, N_local), dtype=torch.float16) + >>> output = shmem.zeros((M, N), dtype=torch.float16) >>> shmem.ops.matmul_reduce_scatter(output, A, B) """ - return matmul_reduce_scatter(self._shmem, output_tensor, A, B, bias, async_op, config, workspace) + return matmul_reduce_scatter(self._shmem, output_tensor, A, B, async_op, config, workspace) # Export public API From ef227b08acacc7534f96349e3845064db09589ea Mon Sep 17 00:00:00 2001 From: neoblizz Date: Sat, 7 Feb 2026 19:14:58 +0000 Subject: [PATCH 02/31] Merge conflicts. --- benchmark/ops/all_gather_matmul/benchmark.py | 8 + iris/iris.py | 15 +- iris/iris.py.backup | 2255 ++++++++++++++++++ iris/ops/all_gather_matmul.py.with_chunked | 521 ++++ iris/ops/config.py | 26 +- iris/ops/workspace.py | 4 + iris/x/gather.py | 2 +- tests/ops/test_all_gather_matmul.py | 21 +- 8 files changed, 2831 insertions(+), 21 deletions(-) create mode 100644 iris/iris.py.backup create mode 100644 iris/ops/all_gather_matmul.py.with_chunked diff --git a/benchmark/ops/all_gather_matmul/benchmark.py b/benchmark/ops/all_gather_matmul/benchmark.py index 3bc45579e..20ff0c536 100644 --- a/benchmark/ops/all_gather_matmul/benchmark.py +++ b/benchmark/ops/all_gather_matmul/benchmark.py @@ -61,6 +61,13 @@ def parse_args(): parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--variant", + type=str, + default="pull", + choices=["pull", "chunked"], + help="All-gather matmul variant (pull or chunked)", + ) parser.add_argument( "--init_url", type=str, default="tcp://127.0.0.1:29530", help="Initialization URL for distributed setup" ) @@ -106,6 +113,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): "block_size_n": args["block_size_n"], "block_size_k": args["block_size_k"], "group_size_m": args["group_size_m"], + "all_gather_matmul_variant": args["variant"], } if args["comm_sms"] is not None: config_kwargs["num_sms"] = args["comm_sms"] diff --git a/iris/iris.py b/iris/iris.py index 5032a640e..9b8a3d35a 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1793,17 +1793,12 @@ def __translate(ptr, from_rank, to_rank, heap_bases): # Cast to_base back to pointer type translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) - # Optimization to vectorize the load/store - # We can't do this in general because we don't know the shape of the tensor or block sizes - # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) + # Vectorization hints: must be <= minimum block size used by any caller. + # (32, 32) is safe since all supported block sizes are multiples of 32. + # Largest vectorized load instruction is dwordx4 (128-bits = 8 x fp16). + translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) + translated_ptr = tl.max_contiguous(translated_ptr, (32, 32)) - # 0 You can use this if your block sizes are multiples of 32. - # Largest vectorized load instruction is dwordx4 (128-bits) - # translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) - # translated_ptr = tl.max_contiguous(translated_ptr, (1, 32)) - - # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) - # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) return translated_ptr diff --git a/iris/iris.py.backup b/iris/iris.py.backup new file mode 100644 index 000000000..e8932c3c8 --- /dev/null +++ b/iris/iris.py.backup @@ -0,0 +1,2255 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Iris: Multi-GPU Communication and Memory Management Framework + +Iris is a high-performance framework that enables seamless multi-GPU programming in Triton, +enabling fine-grained communication and compute overlap natively in Triton +across multiple GPUs with SHMEM-like Remote Memory Access (RMA) capabilities. + +Key Features: +- Symmetric heap management across multiple GPUs +- High-performance atomic operations (add, cas, xchg, xor, and, or, min, max) +- Efficient load/store operations with rank-to-rank communication +- Memory allocation and deallocation utilities +- Built-in logging with rank information +- PyTorch distributed integration for distributed computing + +Example: + >>> import iris + >>> ctx = iris.iris(heap_size=2**30) # 1GB heap + >>> tensor = ctx.zeros(1024, 1024, dtype=torch.float32) +""" + +import triton +import triton.language as tl + +from iris._distributed_helpers import ( + init_distributed, + distributed_barrier, + distributed_broadcast_scalar, + distributed_broadcast_tensor, +) +from iris.hip import ( + set_device, + get_cu_count, + count_devices, +) +from iris.symmetric_heap import SymmetricHeap +import numpy as np +import math +import torch +import logging + +# Import logging functionality from the separate logging module +from .logging import logger + + +class Iris: + """ + Main Iris class for multi-GPU communication and memory management. + + This class provides a unified interface for distributed GPU operations including + memory allocation, atomic operations, and inter-rank communication. + + Args: + heap_size (int): Size of the symmetric heap in bytes. Default: 1GB (2^30) + + Example: + >>> ctx = iris.iris(heap_size=2**31) # 2GB heap + >>> print(f"Rank {ctx.cur_rank} of {ctx.num_ranks}") # Rank 0 of 1 + >>> tensor = ctx.zeros(1000, 1000, dtype=torch.float32) + """ + + def __init__(self, heap_size=1 << 30): + # Initialize distributed environment + comm, cur_rank, num_ranks = init_distributed() + num_gpus = count_devices() + + gpu_id = cur_rank % num_gpus + set_device(gpu_id) + + self.comm = comm + self.num_ranks = num_ranks + self.cur_rank = cur_rank + self.gpu_id = gpu_id + self.heap_size = heap_size + + # Initialize symmetric heap + self.heap = SymmetricHeap(heap_size, gpu_id, cur_rank, num_ranks) + self.device = f"cuda:{gpu_id}" + self.heap_bases = self.heap.get_heap_bases() + + for i in range(num_ranks): + self.debug(f"GPU {i}: Heap base {hex(int(self.heap_bases[i].item()))}") + + distributed_barrier() + + # Initialize CCL interface + self.ccl = self.CCL(self) + + # Lazy initialization for ops interface + self._ops = None + + def _log_with_rank(self, level, message): + """Helper method to log with rank information injected into the record.""" + if logger.isEnabledFor(level): + record = logging.LogRecord( + name=logger.name, level=level, pathname="", lineno=0, msg=message, args=(), exc_info=None + ) + # Inject rank information into the record + record.iris_rank = self.cur_rank + record.iris_num_ranks = self.num_ranks + logger.handle(record) + + def debug(self, message): + """ + Log a debug message with rank information. + + Args: + message (str): Human-readable message to log at debug level. + + Notes: + The log record is enriched with ``iris_rank`` and ``iris_num_ranks`` so + formatters can display the originating rank and world size. + + Example: + >>> ctx = iris.iris() + >>> iris.set_logger_level(iris.DEBUG) + >>> ctx.debug("Allocating buffers") # [Iris] [0/1] Allocating buffers + """ + self._log_with_rank(logging.DEBUG, message) + + def info(self, message): + """ + Log an info message with rank information. + + Args: + message (str): Human-readable message to log at info level. + + Example: + >>> ctx = iris.iris() + >>> ctx.info("Starting iteration 0") # [Iris] [0/1] Starting iteration 0 + """ + self._log_with_rank(logging.INFO, message) + + def warning(self, message): + """ + Log a warning message with rank information. + + Args: + message (str): Human-readable message to log at warning level. + + Example: + >>> ctx = iris.iris() + >>> ctx.warning("Memory usage is high") # [Iris] [0/1] Memory usage is high + """ + self._log_with_rank(logging.WARNING, message) + + def error(self, message): + """ + Log an error message with rank information. + + Args: + message (str): Human-readable message to log at error level. + + Example: + >>> ctx = iris.iris() + >>> ctx.error("Failed to allocate memory") # [Iris] [0/1] Failed to allocate memory + """ + self._log_with_rank(logging.ERROR, message) + + @property + def ops(self): + """ + Access fused GEMM+CCL operations. + + This property provides a namespace for high-level fused operations that combine + matrix multiplication with collective communication. Operations automatically infer + dimensions, strides, and hardware parameters from input tensors. + + Available operations: + - matmul_all_reduce: GEMM + All-Reduce + - all_gather_matmul: All-Gather + GEMM + - matmul_all_gather: GEMM + All-Gather + - matmul_reduce_scatter: GEMM + Reduce-Scatter + + Returns: + OpsNamespace: Namespace with fused operation methods + + Raises: + ImportError: If tritonBLAS is not available + + Example: + >>> ctx = iris.iris() + >>> A = ctx.randn((1024, 512), dtype=torch.float16) + >>> B = ctx.randn((512, 2048), dtype=torch.float16) + >>> output = ctx.zeros((1024, 2048), dtype=torch.float16) + >>> ctx.ops.matmul_all_reduce(output, A, B, ctx) + """ + if self._ops is None: + from iris.ops import OpsNamespace + + self._ops = OpsNamespace(self) + return self._ops + + def broadcast(self, value, source_rank=0): + """ + Broadcast a value from one rank to all ranks. + + This method automatically detects the type of value and uses the appropriate + broadcast mechanism: + - For tensors and arrays: uses efficient PyTorch distributed tensor collectives + - For scalars and other objects: uses object broadcast + + Args: + value (Any): The value to broadcast. Can be a scalar, tensor, numpy array, + or any picklable object. Only the ``source_rank`` value is used; + other ranks should pass a placeholder (e.g., ``None``). + source_rank (int): Rank id that holds the authoritative value. + + Returns: + Any: The value broadcast to all ranks. Tensors and arrays are returned as + numpy arrays; scalars and objects are returned in their original type. + + Examples: + >>> ctx = iris.iris() + >>> # Broadcasting a scalar + >>> value = 42 if ctx.cur_rank == 0 else None + >>> value = ctx.broadcast(value, source_rank=0) # All ranks get 42 + >>> + >>> # Broadcasting a tensor + >>> if ctx.cur_rank == 0: + >>> data = torch.randn(10, 10) + >>> else: + >>> data = None + >>> data = ctx.broadcast(data, source_rank=0) # All ranks get the same array + """ + # Check if the value on source_rank is a tensor or array-like + if self.cur_rank == source_rank and value is not None: + # Explicitly exclude strings and non-numeric types + if isinstance(value, (str, dict, bool)): + is_tensor = False + elif isinstance(value, torch.Tensor): + is_tensor = True + elif isinstance(value, np.ndarray): + is_tensor = True + elif isinstance(value, (list, tuple)): + # Try to convert list/tuple to tensor to check if it's numeric + try: + torch.as_tensor(value) + is_tensor = True + except (TypeError, ValueError): + is_tensor = False + else: + # For other types, try to convert and check + try: + test_array = np.asarray(value) + # Check if it's a numeric dtype that torch can handle + if np.issubdtype(test_array.dtype, np.number): + torch.as_tensor(test_array) + is_tensor = True + else: + is_tensor = False + except (TypeError, ValueError): + is_tensor = False + else: + is_tensor = False + + # Broadcast the type decision to all ranks + is_tensor = distributed_broadcast_scalar(is_tensor, source_rank) + + if is_tensor: + return distributed_broadcast_tensor(value, root=source_rank) + else: + return distributed_broadcast_scalar(value, source_rank) + + def __allocate(self, num_elements, dtype): + """Allocate memory using the symmetric heap.""" + self.debug(f"allocate: num_elements = {num_elements}, dtype = {dtype}") + return self.heap.allocate(num_elements, dtype) + + def __parse_size(self, size): + # Handle nested tuples/lists by flattening them recursively + while len(size) == 1 and isinstance(size[0], (tuple, list)): + size = size[0] + num_elements = math.prod(size) + return size, num_elements + + def zeros_like( + self, input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format + ): + """ + Returns a tensor filled with the scalar value 0, with the same size as input, allocated on the Iris symmetric heap. + + Args: + input (Tensor): the size of input will determine size of the output tensor. + + Keyword Arguments: + dtype (torch.dtype, optional): the desired data type of returned Tensor. + Default: if None, defaults to the dtype of input. + layout (torch.layout, optional): the desired layout of returned tensor. + Default: if None, defaults to the layout of input. Note: Iris tensors are always contiguous (strided). + device (torch.device, optional): the desired device of returned tensor. + Default: if None, defaults to the device of input. Must be compatible with this Iris instance. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. + Default: torch.preserve_format. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> input_tensor = ctx.ones(2, 3) + >>> zeros_tensor = ctx.zeros_like(input_tensor) + >>> print(zeros_tensor.shape) # torch.Size([2, 3]) + """ + self.debug( + f"zeros_like: input_shape = {input.shape}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + ) + + # Use input's properties as defaults if not specified + if dtype is None: + dtype = input.dtype + if layout is None: + layout = input.layout + if device is None: + device = input.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Get the size from input tensor + size = input.size() + num_elements = input.numel() + + # Allocate new tensor with the same size + new_tensor = self.__allocate(num_elements, dtype) + new_tensor.zero_() + + # Reshape to match input size + new_tensor = new_tensor.reshape(size) + + # Apply the requested memory format + new_tensor = self.__apply_memory_format(new_tensor, size, memory_format, input) + + # Apply the requested layout + new_tensor = self.__apply_layout(new_tensor, layout) + + # Set requires_grad if specified + if requires_grad: + new_tensor.requires_grad_() + + return new_tensor + + def arange( + self, start=0, end=None, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False + ): + """ + Returns a 1-D tensor of size ⌈(end - start) / step⌉ with values from the interval [start, end) + taken with common difference step beginning from start. The tensor is allocated on the symmetric heap. + + Note: When using floating-point dtypes (especially reduced precision types like bfloat16), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer step is subject to floating point rounding errors when comparing + against end; to avoid inconsistency, we advise subtracting a small epsilon from end in such cases. + + Args: + start (Number, optional): the starting value for the set of points. Default: 0. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: 1. + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.get_default_dtype()). + If dtype is not given, infer the data type from the other input arguments. + If any of start, end, or step are floating-point, the dtype is inferred + be the default dtype, see get_default_dtype(). Otherwise, the dtype is inferred + to be torch.int64. + layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. + Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.arange(0, 10, 2) # [0, 2, 4, 6, 8] + >>> print(tensor.shape) # torch.Size([5]) + """ + self.debug(f"arange: start = {start}, end = {end}, step = {step}, dtype = {dtype}, device = {device}") + + # Handle the case where only one argument is provided (end) + if end is None: + end = start + start = 0 + + # Validate inputs + if step == 0: + raise ValueError("step must be non-zero") + + # Validate step direction consistency + if step > 0 and start >= end: + raise ValueError(f"Invalid range: start >= end with positive step (start={start}, end={end}, step={step})") + elif step < 0 and start <= end: + raise ValueError(f"Invalid range: start <= end with negative step (start={start}, end={end}, step={step})") + + # Calculate the number of elements + num_elements = math.ceil((end - start) / step) + + # Infer dtype if not provided + if dtype is None: + if any(isinstance(x, float) for x in [start, end, step]): + dtype = torch.get_default_dtype() + else: + dtype = torch.int64 + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + tensor = out + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + + target_device = tensor.device + arange_tensor = torch.arange(start, end, step, dtype=dtype, device=target_device) + + tensor[:] = arange_tensor + + tensor = self.__apply_layout(tensor, layout) + + if requires_grad: + tensor.requires_grad_() + + return tensor + + def zeros(self, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Returns a tensor filled with the scalar value 0, with the shape defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.zeros(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0., 0., 0.], device='cuda:0') + """ + self.debug(f"zeros: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Fill with zeros + out.zero_() + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Fill with zeros + tensor.zero_() + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def randn( + self, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + ): + """ + Returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1 + (also called the standard normal distribution). The tensor is allocated on the Iris symmetric heap. + + .. math:: + \\text{out}_i \\sim \\mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a complex normal distribution with zero mean + and unit variance as + + .. math:: + \\text{out}_i \\sim \\mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\\text{Re})` and imaginary :math:`(\\text{Im})` + part of :math:`\\text{out}_i` as + + .. math:: + \\text{Re}(\\text{out}_i) \\sim \\mathcal{N}(0, \\frac{1}{2}), \\quad \\text{Im}(\\text{out}_i) \\sim \\mathcal{N}(0, \\frac{1}{2}) + + The shape of the tensor is defined by the variable argument size. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + generator (torch.Generator, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type (see torch.set_default_device()). + device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. + Works only for CPU tensors. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.randn(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([ 0.3982, -0.0059, -0.4365], device='cuda:0') + """ + self.debug( + f"randn: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" + ) + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Generate random data and copy to out tensor + random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) + out.copy_(random_data) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Generate random data and copy to tensor + random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) + tensor.copy_(random_data) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def ones(self, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.ones(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([1., 1., 1.], device='cuda:0') + """ + self.debug(f"ones: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Fill with ones + out.fill_(1) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Fill with ones + tensor.fill_(1) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value. + The tensor is allocated on the Iris symmetric heap. + + Args: + size (int...): a list, tuple, or torch.Size of integers defining the shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.full((2, 3), 3.14) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([3.1400, 3.1400, 3.1400], device='cuda:0') + """ + self.debug( + f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + ) + + # Infer dtype from fill_value if not provided + if dtype is None: + if isinstance(fill_value, (int, float)): + if isinstance(fill_value, float): + dtype = torch.get_default_dtype() + else: + dtype = torch.int64 + else: + # For other types (like tensors), use their dtype + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Fill with the specified value + out.fill_(fill_value) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Fill with the specified value + tensor.fill_(fill_value) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def uniform(self, size, low=0.0, high=1.0, dtype=torch.float): + """ + Returns a tensor filled with random numbers from a uniform distribution, allocated on the Iris symmetric heap. + + Args: + size (int or tuple of ints): the size of the output tensor. + low (float, optional): the lower bound of the uniform distribution. Default: 0.0. + high (float, optional): the upper bound of the uniform distribution. Default: 1.0. + dtype (torch.dtype, optional): the desired data type of returned tensor. Default: torch.float. + + Returns: + Tensor: A tensor filled with random numbers from a uniform distribution. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.uniform((2, 3), low=0.0, high=1.0) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') + """ + self.debug(f"uniform: size = {size}, low = {low}, high = {high}, dtype = {dtype}") + size, num_elements = self.__parse_size(size) + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + tensor.uniform_(low, high) + return tensor.reshape(size) + + def empty( + self, + *size, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + ): + """ + Returns a tensor filled with uninitialized data. The shape of the tensor is defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Note: + If torch.use_deterministic_algorithms() and torch.utils.deterministic.fill_uninitialized_memory are both set to True, + the output tensor is initialized to prevent any possible nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors are filled with the maximum value. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. + Works only for CPU tensors. Default: False. Note: Iris tensors are always on GPU. + memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. + Default: torch.contiguous_format. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.empty(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + """ + self.debug( + f"empty: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" + ) + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested memory format + tensor = self.__apply_memory_format(tensor, size, memory_format) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def randint( + self, *args, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False + ): + """ + Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive). + The shape of the tensor is defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Note: + With the global dtype default (torch.float32), this function returns a tensor with dtype torch.int64. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword Arguments: + generator (torch.Generator, optional): a pseudorandom number generator for sampling. + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): if None, this function returns a tensor with dtype torch.int64. + layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. + device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.randint(0, 10, (2, 3)) # Random integers [0, 10) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([7, 2, 9], device='cuda:0') + """ + self.debug(f"randint: args = {args}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + + # Parse arguments to determine low, high, and size + # PyTorch randint signatures: + # randint(high, size) - where high is the upper bound and size is the shape + # randint(low, high, size) - where low and high are bounds, size is the shape + if len(args) == 2: + # randint(high, size) + high, size = args + low = 0 + elif len(args) == 3: + # randint(low, high, size) + low, high, size = args + else: + raise ValueError(f"randint expects 2 or 3 positional arguments, got {len(args)}") + + # Use default dtype if None is provided + if dtype is None: + dtype = torch.int64 + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Generate random integers using PyTorch's randint + # Use specified device or fall back to current device + target_device = device if device is not None else self.device + + # Handle generator parameter + if generator is not None: + torch.randint(low, high, size, generator=generator, out=tensor, dtype=dtype, device=target_device) + else: + torch.randint(low, high, size, out=tensor, dtype=dtype, device=target_device) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def linspace(self, start, end, steps, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Creates a one-dimensional tensor of size steps whose values are evenly spaced from start to end, inclusive. + The tensor is allocated on the Iris symmetric heap. + + The values are: + (start, start + (end-start)/(steps-1), ..., start + (steps-2)*(end-start)/(steps-1), end) + + Args: + start (float or Tensor): the starting value for the set of points. If Tensor, it must be 0-dimensional. + end (float or Tensor): the ending value for the set of points. If Tensor, it must be 0-dimensional. + steps (int): size of the constructed tensor. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype when both start and end are real, + and corresponding complex dtype when either is complex. + layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. + device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.linspace(0, 10, 5) # [0, 2.5, 5, 7.5, 10] + >>> print(tensor) # tensor([ 0.0000, 2.5000, 5.0000, 7.5000, 10.0000], device='cuda:0') + """ + self.debug( + f"linspace: start = {start}, end = {end}, steps = {steps}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + ) + + # Use global default dtype if None is provided + if dtype is None: + # Check if start or end are complex numbers + start_is_complex = isinstance(start, complex) or (hasattr(start, "dtype") and torch.is_complex(start)) + end_is_complex = isinstance(end, complex) or (hasattr(end, "dtype") and torch.is_complex(end)) + + if start_is_complex or end_is_complex: + # Infer complex dtype based on default dtype + dtype = torch.complex64 if torch.get_default_dtype() == torch.float32 else torch.complex128 + else: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse steps and extract the integer value + if isinstance(steps, (tuple, list)): + if len(steps) == 1: + # Single-element tuple/list like (5,) or [5] + steps_int = steps[0] + # Handle nested tuples like ((5,),) + if isinstance(steps_int, (tuple, list)): + steps_int = steps_int[0] + else: + # Multi-element tuple/list - use __parse_size for compatibility + size, num_elements = self.__parse_size(steps) + steps_int = num_elements + else: + # steps is a single integer + steps_int = steps + + # Ensure steps_int is an integer + steps_int = int(steps_int) + size = (steps_int,) + num_elements = steps_int + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Generate linspace using PyTorch's linspace + # Use specified device or fall back to current device + target_device = device if device is not None else self.device + torch.linspace(start, end, steps_int, out=tensor, dtype=dtype, device=target_device) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def rand( + self, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + ): + """ + Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1). + The tensor is allocated on the Iris symmetric heap. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + generator (torch.Generator, optional): a pseudorandom number generator for sampling. + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. + Works only for CPU tensors. Default: False. Note: Iris tensors are always on GPU. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.rand(2, 3) # Random values in [0, 1) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') + """ + self.debug( + f"rand: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" + ) + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Generate random numbers using PyTorch's rand + # Use specified device (already validated and set above) + + # Handle generator parameter + if generator is not None: + torch.rand(size, generator=generator, out=tensor, dtype=dtype, device=device) + else: + torch.rand(size, out=tensor, dtype=dtype, device=device) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def __deallocate(self, pointer): + pass + + def get_heap_bases(self): + """ + Return the tensor of symmetric heap base addresses for all ranks. + + Returns: + torch.Tensor: A 1D tensor of ``uint64`` heap base addresses of size ``num_ranks`` + on the Iris device. Pass this to device-side Triton kernels that require + heap translation. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> heap_bases = ctx.get_heap_bases() + >>> print(heap_bases.shape) # torch.Size([num_ranks]) + """ + return self.heap_bases + + def barrier(self, stream=None, group=None): + """ + Synchronize ranks within the specified group and their CUDA devices. + + This first calls ``torch.cuda.synchronize()`` or ``stream.synchronize()`` to ensure the local GPU has + finished all queued work, then performs a distributed barrier so that all + ranks in the group reach the same point before proceeding. + + Args: + stream: If stream is given: wait only for that stream before barrier. If stream is None: legacy behavior (device-wide sync). + group (ProcessGroup, optional): The process group to synchronize. + If None, uses the default process group (all ranks). + + Example: + >>> ctx = iris.iris(1 << 20) + >>> ctx.barrier() # Synchronize all ranks + >>> ctx.barrier(group=my_group) # Synchronize only ranks in my_group + """ + # Wait for all GPUs to finish work + if stream is None: + torch.cuda.synchronize() + else: + stream.synchronize() + + # Distributed barrier + distributed_barrier(group=group) + + def get_device(self): + """ + Get the underlying device where the Iris symmetric heap resides. + + Returns: + torch.device: The CUDA device of Iris-managed memory. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> device = ctx.get_device() + >>> print(device) # cuda:0 + """ + return self.heap.get_device() + + def get_cu_count(self): + """ + Get the number of compute units (CUs) for the current GPU. + + Returns: + int: Number of compute units on this rank's GPU. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> cu_count = ctx.get_cu_count() + >>> print(f"GPU has {cu_count} CUs") # GPU has 304 CUs + """ + return get_cu_count(self.gpu_id) + + def get_rank(self): + """ + Get this process's rank id in the distributed communicator. + + Returns: + int: Zero-based rank id of the current process. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> rank = ctx.get_rank() + >>> print(f"This is rank {rank}") # This is rank 0 + """ + return self.cur_rank + + def get_num_ranks(self): + """ + Get the total number of ranks in the distributed communicator. + + Returns: + int: World size (number of ranks). + + Example: + >>> ctx = iris.iris(1 << 20) + >>> num_ranks = ctx.get_num_ranks() + >>> print(f"Total ranks: {num_ranks}") # Total ranks: 1 + """ + return self.num_ranks + + def __throw_if_invalid_output_tensor(self, tensor: torch.Tensor, num_elements: int, dtype: torch.dtype): + if not self.__tensor_on_device(tensor): + raise RuntimeError( + f"The output tensor is not on the same device as the Iris instance. The Iris instance is on device {self.device} but the output tensor is on device {tensor.device}" + ) + if not self.__on_symmetric_heap(tensor): + raise RuntimeError( + f"The output tensor is not on the symmetric heap. The Iris instance is on heap base {self.heap_bases[self.cur_rank]} but the output tensor is on heap base {tensor.data_ptr()}" + ) + if tensor.numel() != num_elements: + raise RuntimeError(f"The output tensor has {tensor.numel()} elements, but {num_elements} are required") + if tensor.dtype != dtype: + raise RuntimeError(f"The output tensor has dtype {tensor.dtype}, but {dtype} is required") + + def __throw_if_invalid_device(self, device): + """ + Throw a RuntimeError if the requested device is not compatible with this Iris instance. + + Args: + device: The requested device (can be string, torch.device, or None) + + Raises: + RuntimeError: If the device is not compatible + """ + if not self.__is_valid_device(device): + raise RuntimeError( + f"Device mismatch: requested device {device} but Iris instance is on device {self.device}. " + f"Iris only supports tensors on its own device." + ) + + def __apply_memory_format( + self, tensor: torch.Tensor, size: tuple, memory_format: torch.memory_format, input_tensor: torch.Tensor = None + ): + """ + Apply the requested memory format to a tensor by setting appropriate strides. + This keeps the tensor on the symmetric heap while changing how PyTorch interprets the memory layout. + + Args: + tensor: The tensor to modify + size: The tensor's size/dimensions + memory_format: The desired memory format + input_tensor: The original input tensor (needed for preserve_format detection) + """ + if memory_format == torch.contiguous_format: + # Default format, no changes needed + return tensor + elif memory_format == torch.channels_last and len(size) == 4: + # For channels_last format: preserve shape (N, C, H, W) but change strides + # channels_last strides: [C*H*W, 1, C*W, C] for shape (N, C, H, W) + N, C, H, W = size[0], size[1], size[2], size[3] + # Keep the original shape (N, C, H, W) but use channels_last strides + tensor = self.__create_tensor_with_strides(tensor, size, (C * H * W, 1, C * W, C)) + return tensor + elif memory_format == torch.channels_last_3d and len(size) == 5: + # For channels_last_3d format: preserve shape (N, C, D, H, W) but change strides + # channels_last_3d strides: [C*D*H*W, 1, C*D*W, C*W, C] for shape (N, C, D, H, W) + N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] + # Keep the original shape (N, C, D, H, W) but use channels_last_3d strides + tensor = self.__create_tensor_with_strides(tensor, size, (C * D * H * W, 1, C * D * W, C * W, C)) + return tensor + elif memory_format == torch.preserve_format: + # For preserve_format, we need to detect the input tensor's memory format + # and apply the same format to the output + if input_tensor is not None: + # Check the actual memory format of the input tensor + if len(size) == 4: + # Check if input tensor is in channels_last format by examining strides + # channels_last format has strides[1] == 1 (channels dimension is contiguous) + input_strides = input_tensor.stride() + if len(input_strides) == 4 and input_strides[1] == 1: + # Input is in channels_last format, preserve it + # Use the input tensor's actual shape, not the size parameter + input_shape = input_tensor.shape + if len(input_shape) == 4: + # Input is already in channels_last format (N, H, W, C) + new_size = input_shape + # Use the input tensor's strides directly + tensor = self.__create_tensor_with_strides(tensor, new_size, input_strides) + return tensor + elif len(size) == 5: + # Check if input tensor is in channels_last_3d format + input_strides = input_tensor.stride() + if len(input_strides) == 5 and input_strides[1] == 1: + # Input is in channels_last_3d format, preserve it + # Use the input tensor's actual shape, not the size parameter + input_shape = input_tensor.shape + if len(input_shape) == 5: + # Input is already in channels_last_3d format (N, D, H, W, C) + new_size = input_shape + # Use the input tensor's strides directly + tensor = self.__create_tensor_with_strides(tensor, new_size, input_strides) + return tensor + # If no special format detected or no input tensor provided, use contiguous format + return tensor + else: + # Unsupported format or dimension combination + self.debug( + f"Warning: Memory format {memory_format} not supported for {len(size)}D tensor, using contiguous format" + ) + # For unsupported formats, return the tensor as-is (contiguous) + return tensor + + def __create_tensor_with_strides(self, original_tensor: torch.Tensor, size: tuple, strides: tuple) -> torch.Tensor: + """ + Create a new tensor with the specified strides while keeping the data on the symmetric heap. + + Args: + original_tensor: The original tensor (source of data and heap allocation) + size: The tensor's size/dimensions + strides: The desired strides for the new memory format + + Returns: + A new tensor with the specified strides, data copied from original, on the same heap + """ + + # First, create a temporary tensor with the correct strides using PyTorch + temp_tensor = torch.empty_strided(size, strides, dtype=original_tensor.dtype, device=original_tensor.device) + + # Handle different cases based on whether size changes and what the strides indicate + if size != original_tensor.shape: + # Size is different - this might be a format change that requires permutation + # Check if this is a channels_last format by comparing strides + if len(size) == 4: + # For channels_last: expected strides are [H*W*C, 1, W*C, C] for shape (N, H, W, C) + N, H, W, C = size[0], size[1], size[2], size[3] + expected_strides = (H * W * C, 1, W * C, C) + if strides == expected_strides: + permuted = original_tensor.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + else: + # If the size differs for other reasons, do not permute; just reshape if possible + try: + permuted = original_tensor.reshape(size) + except Exception: + raise ValueError( + "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." + ) + elif len(size) == 5: + # For channels_last_3d: expected strides are [D*H*W*C, 1, H*W*C, W*C, C] for shape (N, D, H, W, C) + N, D, H, W, C = size[0], size[1], size[2], size[3], size[4] + expected_strides = (D * H * W * C, 1, H * W * C, W * C, C) + if strides == expected_strides: + permuted = original_tensor.permute(0, 2, 3, 4, 1) # (N, C, D, H, W) -> (N, D, H, W, C) + else: + # If the size differs for other reasons, do not permute; just reshape if possible + try: + permuted = original_tensor.reshape(size) + except Exception: + raise ValueError( + "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." + ) + else: + # For other dimensions, just try to reshape + try: + permuted = original_tensor.reshape(size) + except Exception: + raise ValueError( + "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." + ) + else: + # Size is the same - this is a stride-only change (like channels_last with preserved shape) + # We need to reorder the data to match the new stride pattern + if len(size) == 4: + # Check if this is channels_last format with preserved shape + N, C, H, W = size[0], size[1], size[2], size[3] + expected_strides = (C * H * W, 1, C * W, C) + if strides == expected_strides: + permuted = original_tensor + else: + permuted = original_tensor + elif len(size) == 5: + # Check if this is channels_last_3d format with preserved shape + N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] + expected_strides = (C * D * H * W, 1, C * D * W, C * W, C) + if strides == expected_strides: + permuted = original_tensor + else: + permuted = original_tensor + else: + permuted = original_tensor + + # Copy the permuted data to the temporary tensor + temp_tensor.copy_(permuted) + + # Now allocate a new tensor on our symmetric heap + num_elements = math.prod(size) + heap_tensor = self.__allocate(num_elements, original_tensor.dtype) + + # Reshape to the desired size + heap_tensor = heap_tensor.reshape(size) + + # Copy the data from the temporary tensor to our heap tensor + heap_tensor.copy_(temp_tensor) + + # Clean up the temporary tensor + del temp_tensor + + # Now we need to create a view with the correct strides + # We can't use as_strided directly on our heap tensor, but we can + # create a new tensor with the right strides and copy the data again + final_tensor = torch.as_strided(heap_tensor, size, strides) + + return final_tensor + + def __apply_layout(self, tensor: torch.Tensor, layout: torch.layout) -> torch.Tensor: + """ + Apply the requested layout to a tensor. + + Args: + tensor: The tensor to modify + layout: The desired layout + + Returns: + Tensor with the requested layout + """ + + if layout == torch.strided: + # Strided layout is the default - no changes needed + return tensor + else: + # Only support strided layout for now + raise ValueError(f"Layout {layout} not supported. Only torch.strided is currently supported.") + + def __tensor_on_device(self, tensor: torch.Tensor): + # Get the Iris device from memory_pool.device + iris_device = self.get_device() + tensor_device = tensor.device + + # For CUDA devices, check if they're compatible + if tensor_device.type == "cuda" and iris_device.type == "cuda": + if iris_device.index is None: + return True + return tensor_device.index == iris_device.index + + # For non-CUDA devices, they must be exactly equal + return tensor_device == iris_device + + def __on_symmetric_heap(self, tensor: torch.Tensor): + """Check if a tensor is allocated on the symmetric heap.""" + return self.heap.on_symmetric_heap(tensor) + + def __is_valid_device(self, device) -> bool: + """ + Check if the requested device is compatible with this Iris instance. + + Args: + device: The requested device (can be string, torch.device, or None) + + Returns: + bool: True if the device is compatible, False otherwise + """ + if device is None: + return True # None means use default device + + # Convert device strings to torch.device objects for proper comparison + requested_device = torch.device(device) if isinstance(device, str) else device + iris_device = self.get_device() + + # Check if both are CUDA devices + if requested_device.type == "cuda" and iris_device.type == "cuda": + # Check if index matches or if requested is "cuda" (any index) + if requested_device.index is None: + return True + else: + return requested_device.index == iris_device.index + + # For non-CUDA devices, always return False + return False + + class CCL: + """ + Collective Communication Library (CCL) interface for Iris. + + Provides collective operations that can be called as methods on the Iris instance. + Example usage: + >>> shmem = iris.iris() + >>> shmem.ccl.all_to_all(output_tensor, input_tensor) + """ + + def __init__(self, iris_instance): + """ + Initialize CCL with a reference to the parent Iris instance. + + Args: + iris_instance: The parent Iris instance + """ + self._iris = iris_instance + + def all_to_all(self, output_tensor, input_tensor, group=None, async_op=False, config=None): + """ + All-to-all collective operation. + + Each rank sends a tensor chunk to each other rank and receives + a tensor chunk from each other rank. Input/output tensors should have + shape (M, N * world_size) where each chunk of N columns corresponds to one rank. + + Args: + output_tensor: Output tensor of shape (M, N * world_size) + input_tensor: Input tensor of shape (M, N * world_size) + group: ProcessGroup or None. If None, uses all ranks in shmem context. + Default: None. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + + Example: + >>> shmem = iris.iris() + >>> shmem.ccl.all_to_all(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(block_size_m=128, block_size_n=32) + >>> shmem.ccl.all_to_all(output_tensor, input_tensor, config=config) + + >>> # Async operation (no barrier) + >>> shmem.ccl.all_to_all(output_tensor, input_tensor, async_op=True) + """ + from iris.ccl.all_to_all import all_to_all as _all_to_all + + _all_to_all(output_tensor, input_tensor, self._iris, group=group, async_op=async_op, config=config) + + def all_gather(self, output_tensor, input_tensor, group=None, async_op=False, config=None): + """ + All-gather collective operation. + + Each rank sends its input tensor to all ranks, and all ranks receive + and concatenate all input tensors along dimension 0 (rows), matching + torch.distributed.all_gather_into_tensor behavior. + + Args: + output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs + input_tensor: Input tensor of shape (M, N) - local rank's data to send + group: ProcessGroup or None. If None, uses all ranks in shmem context. + Default: None. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + + Example: + >>> shmem = iris.iris() + >>> # Input: (M, N), Output: (world_size * M, N) + >>> shmem.ccl.all_gather(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(block_size_m=128, block_size_n=32) + >>> shmem.ccl.all_gather(output_tensor, input_tensor, config=config) + + >>> # Async operation (no barrier) + >>> shmem.ccl.all_gather(output_tensor, input_tensor, async_op=True) + """ + from iris.ccl.all_gather import all_gather as _all_gather + + _all_gather(output_tensor, input_tensor, self._iris, group=group, async_op=async_op, config=config) + + def all_reduce_preamble(self, output_tensor, input_tensor, config=None, workspace=None): + """ + Prepare reusable workspace for all-reduce. + + Args: + output_tensor: Output tensor that will receive the reduced data. + input_tensor: Input tensor providing the local contribution. + config: Optional Config describing variant parameters. + workspace: Optional existing workspace to update/reuse. + + Returns: + Workspace object that can be passed to ``all_reduce``. + """ + from iris.ccl.all_reduce import all_reduce_preamble as _all_reduce_preamble + + return _all_reduce_preamble( + output_tensor, + input_tensor, + self._iris, + config=config, + workspace=workspace, + ) + + def all_reduce( + self, output_tensor, input_tensor, op=None, group=None, async_op=False, config=None, workspace=None + ): + """ + All-reduce collective operation. + + Each rank has a local input tensor, and all ranks compute the sum of all + input tensors. The result is written to output_tensor on all ranks. + + Args: + output_tensor: Output tensor of shape (M, N) - will contain sum of all inputs + input_tensor: Input tensor of shape (M, N) - local rank's partial data + op: Reduction operation to apply. Currently only ReduceOp.SUM is supported. + Default: ReduceOp.SUM. + group: ProcessGroup or None. If None, uses all ranks in shmem context. + Default: None. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + Set config.all_reduce_variant to choose variant: "atomic", "ring", or "two_shot" + workspace: Optional workspace prepared by ``all_reduce_preamble`` to + reuse internal buffers across invocations. + + Example: + >>> shmem = iris.iris() + >>> shmem.ccl.all_reduce(output_tensor, input_tensor) + + >>> # Custom configuration with ring variant + >>> from iris.ccl import Config + >>> config = Config(all_reduce_variant="ring") + >>> shmem.ccl.all_reduce(output_tensor, input_tensor, config=config) + + >>> # Two-shot variant with block distribution + >>> config = Config(all_reduce_variant="two_shot", all_reduce_distribution=1) + >>> shmem.ccl.all_reduce(output_tensor, input_tensor, config=config) + + >>> # Async operation (no barrier) + >>> shmem.ccl.all_reduce(output_tensor, input_tensor, async_op=True) + """ + from iris.ccl.all_reduce import all_reduce as _all_reduce + from iris.ccl import ReduceOp + + # Default to SUM if not specified + if op is None: + op = ReduceOp.SUM + + return _all_reduce( + output_tensor, + input_tensor, + self._iris, + op=op, + group=group, + async_op=async_op, + config=config, + workspace=workspace, + ) + + def reduce_scatter(self, output_tensor, input_tensor, op=None, group=None, async_op=False, config=None): + """ + Reduce-scatter collective operation. + + Each rank reduces its assigned tiles from all ranks' inputs and stores + the result only to its own output tensor. This is similar to all-reduce + but without broadcasting the result to all ranks. + + Args: + output_tensor: Output tensor of shape (M, N) - will contain reduced tiles for this rank + input_tensor: Input tensor of shape (M, N) - local rank's partial data + op: Reduction operation to apply. Currently only ReduceOp.SUM is supported. + Default: ReduceOp.SUM. + group: ProcessGroup or None. If None, uses all ranks in shmem context. + Default: None. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + Only supports reduce_scatter_variant="two_shot". + + Example: + >>> shmem = iris.iris() + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) + """ + from iris.ccl.reduce_scatter import reduce_scatter as _reduce_scatter + from iris.ccl import ReduceOp + + # Default to SUM if not specified + if op is None: + op = ReduceOp.SUM + + _reduce_scatter( + output_tensor, input_tensor, self._iris, op=op, group=group, async_op=async_op, config=config + ) + + +@triton.jit +def __translate(ptr, from_rank, to_rank, heap_bases): + from_base = tl.load(heap_bases + from_rank) + to_base = tl.load(heap_bases + to_rank) + # convert to int to compute difference + ptr_int = tl.cast(ptr, tl.uint64) + # Find the offset from from_rank heap + offset = ptr_int - from_base + # Byte cast for byte offset addition + to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) + # Find the offset into the to_rank heap + translated_ptr_byte = to_base_byte + offset + # Cast to_base back to pointer type + translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) + + # Optimization to vectorize the load/store + # We can't do this in general because we don't know the shape of the tensor or block sizes + # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) + + # 0 You can use this if your block sizes are multiples of 32. + # Largest vectorized load instruction is dwordx4 (128-bits) + translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) + translated_ptr = tl.max_contiguous(translated_ptr, (32, 32)) + + # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) + # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) + return translated_ptr + + +@triton.jit +def load(pointer, to_rank, from_rank, heap_bases, mask=None): + """ + Loads a value from the specified rank's memory location. + + This function performs a memory read operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and loading + data from the target memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local load operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the pointer will be translated. Must be the current rank where the pointer is local. + from_rank (int): The rank ID from which to read the data. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + + Returns: + Block: The loaded value from the target memory location. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Load data from rank 1's memory into the current rank + >>> cur_rank = 0 # Current rank + >>> remote_rank = 1 # Remote rank to load from + >>> data = iris.load(ptr, cur_rank, remote_rank, heap_bases) + >>> return data + """ + translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) + result = tl.load(translated_ptr, mask=mask) + return result + + +@triton.jit +def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): + """ + Writes data to the specified rank's memory location. + + This function performs a memory write operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and storing + the provided data to the target memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local store operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + value (Block): The tensor of elements to be stored. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the data will be written. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Store value 42 into rank 1's heap from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> value = 42 + >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + tl.store(translated_ptr, value, mask=mask) + + +@triton.jit +def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): + """ + Copies data from the specified rank's memory into the destination rank's memory. + This function performs the transfer by translating `src_ptr` from the `from_rank`'s address + space to the `to_rank`'s address space, performing a masked load from the translated + source, and storing the loaded data to `dst_ptr` in the `to_rank` memory location. + If `from_rank` and `to_rank` are the same, this function performs a local copy operation. + It is undefined behaviour if neither `from_rank` nor `to_rank` is the `cur_rank`. + + Args: + src_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s local memory from which to read data. + dst_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `to_rank`'s local memory where the data will be written. + from_rank (int): The rank ID that owns `src_ptr` (source rank). + to_rank (int): The rank ID that will receive the data (destination rank). + cur_rank (int): The rank ID issuing the copy operation. Must be either `from_rank` or `to_rank`. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(remote_ptr, local_ptr, heap_bases): + >>> from_rank = 1 + >>> to_rank = 0 + >>> iris.copy(remote_ptr, local_ptr, from_rank, to_rank, to_rank, heap_bases) + """ + + cur_base = tl.load(heap_bases + cur_rank) + + from_base = tl.load(heap_bases + from_rank) + to_base = tl.load(heap_bases + to_rank) + + src_ptr_int = tl.cast(src_ptr, tl.uint64) + src_offset = src_ptr_int - cur_base + + dst_ptr_int = tl.cast(dst_ptr, tl.uint64) + dst_offset = dst_ptr_int - cur_base + + from_base_byte = tl.cast(from_base, tl.pointer_type(tl.int8)) + to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) + + translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) + translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) + + data = tl.load(translated_src, mask=mask) + tl.store(translated_dst, data, mask=mask) + + +@triton.jit +def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): + """ + Copies data from the specified rank's memory to the current rank's local memory. + + This function performs a memory read operation by translating the `from_ptr` + from the current rank's address space to the `from_rank`'s address space, loading data + from the `from_rank` memory location, and storing it to the local `to_ptr`. + If the `from_rank` is the same as the current rank, this function performs a local copy operation. + + Args: + from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `from_rank`'s address space. Must be the current rank where the pointer is local. + to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory where the data will be stored. + from_rank (int): The `from_rank` ID from which to read the data. + to_rank (int): The current rank ID where the data will be stored. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(remote_ptr, local_ptr, heap_bases): + >>> from_rank = 1 + >>> to_rank = 0 + >>> iris.get(remote_ptr, local_ptr, from_rank, to_rank, heap_bases) + """ + translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases) + + data = tl.load(translated_from_ptr, mask=mask) + + tl.store(to_ptr, data, mask=mask) + + +@triton.jit +def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): + """ + Copies data from the current rank's local memory to the specified rank's memory. + This function performs a memory write operation by loading data from the current + rank's `from_ptr`, translating the `to_ptr` from the current rank's address + space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. + If the `to_rank` is the same as the current rank, this function performs a local copy operation. + + Args: + from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. + to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + from_rank (int): The current rank ID from which to read the data. + to_rank (int): The `to_rank` ID to which the data will be written. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(local_ptr, remote_ptr, heap_bases): + >>> from_rank = 0 + >>> to_rank = 1 + >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) + """ + translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) + + data = tl.load(from_ptr, mask=mask) + + tl.store(translated_to_ptr, data, mask=mask) + + +@triton.jit +def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic add at the specified rank's memory location. + + This function performs an atomic addition operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + adding the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic addition operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically add 5 to rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> increment = 5 + >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Atomically subtracts data from the specified rank's memory location. + + This function performs an atomic subtraction operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + subtracting the provided data from the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic subtraction operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block): The tensor of elements to be subtracted atomically. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". + + Returns: + Block: The value at the memory location before the atomic subtraction. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically subtract 3 from rank 2's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 2 # Remote rank (destination) + >>> decrement = 3 + >>> old_val = iris.atomic_sub(ptr, decrement, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None): + """ + Atomically compares and exchanges the specified rank's memory location. + + This function performs an atomic compare-and-swap operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + comparing the current value with the expected value, then writing the new value if they match. + If the `from_rank` and `to_rank` are the same, this function performs a local atomic compare-and-swap operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + cmp (Block): The expected value to be compared with the current value at the memory location. + val (Block): The new value to be written if the compare succeeds. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". + + Returns: + Block: The value contained at the memory location before the atomic operation attempt. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Compare-and-swap on rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> expected = 0 + >>> new_val = 42 + >>> old_val = iris.atomic_cas(ptr, expected, new_val, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) + + +@triton.jit +def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic exchange at the specified rank's memory location. + + This function performs an atomic exchange operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + exchanging the current value with the provided new value. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic exchange operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Exchange value with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> new_value = 99 + >>> old_val = iris.atomic_xchg(ptr, new_value, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic xor at the specified rank's memory location. + + This function performs an atomic xor operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + xoring the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic xor operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically XOR with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> mask_val = 0xFF + >>> old_val = iris.atomic_xor(ptr, mask_val, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic and at the specified rank's memory location. + + This function performs an atomic and operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + anding the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic and operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically AND with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> mask_val = 0x0F + >>> old_val = iris.atomic_and(ptr, mask_val, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic or at the specified rank's memory location. + + This function performs an atomic or operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + oring the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic or operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically OR with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> mask_val = 0xF0 + >>> old_val = iris.atomic_or(ptr, mask_val, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic min at the specified rank's memory location. + + This function performs an atomic min operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + performing the min on the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic min operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically find minimum with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> new_val = 10 + >>> old_val = iris.atomic_min(ptr, new_val, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic max at the specified rank's memory location. + + This function performs an atomic max operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + performing the max on the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic max operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically find maximum with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> new_val = 100 + >>> old_val = iris.atomic_max(ptr, new_val, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +def iris(heap_size=1 << 30): + """ + Create and return an Iris instance with the specified heap size. + + Args: + heap_size (int): Size of the heap in bytes. Defaults to 1GB. + + Returns: + Iris: An initialized Iris instance. + + Example: + >>> import iris + >>> iris_ctx = iris.iris(2**30) # 1GB heap + >>> tensor = iris_ctx.zeros(1024, 1024) + """ + return Iris(heap_size) diff --git a/iris/ops/all_gather_matmul.py.with_chunked b/iris/ops/all_gather_matmul.py.with_chunked new file mode 100644 index 000000000..ddc03d027 --- /dev/null +++ b/iris/ops/all_gather_matmul.py.with_chunked @@ -0,0 +1,521 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Fused All-Gather + GEMM operation using pull pattern. + +Each rank has a column-sharded input A_sharded (M x K_local). +This operation computes C = all_gather(A_sharded) @ B by pulling +tiles from remote ranks on-demand during GEMM computation. +""" + +from typing import Optional +import torch +import triton +import triton.language as tl +import iris +import iris.x + +from tritonblas.kernels.stages.algorithms.binary import add_vector +from tritonblas.kernels.stages.algorithms.unary import convert_dtype + +from .config import FusedConfig +from .workspace import FusedWorkspace + + +@triton.jit() +def _fused_all_gather_matmul_kernel( + A_sharded, + B, + C, + bias_ptr, + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + K_local: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_bias: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + ALLOW_TF32: tl.constexpr, +): + """Fused all-gather + GEMM kernel using pull pattern.""" + pid = tl.program_id(0) + + # Handle multi-XCD devices + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 + + # Persistent loop over output tiles + for tile_id in range(pid, total_tiles, NUM_SMS): + # Compute tile coordinates with swizzling + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + # Compute row and column indices + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Initialize accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + # Create DeviceContext and TensorView for gather operations + ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases) + src_view = iris.x.TensorView(A_sharded, M, K_local, stride_am, stride_ak) + + # Loop over all ranks to pull and accumulate + for source_rank_id in range(world_size): + loop_k_local = tl.cdiv(K_local, BLOCK_SIZE_K) + if not EVEN_K: + loop_k_local -= 1 + + # Loop over K dimension for this rank's shard + for k_block_idx in range(0, loop_k_local): + k_offset = k_block_idx * BLOCK_SIZE_K + + # Create tile view for this K block + tile_k = k_offset // BLOCK_SIZE_K + k_tile = iris.x.TileView(pid_m, tile_k, BLOCK_SIZE_M, BLOCK_SIZE_K) + + # Pull A tile from source_rank_id using gather primitive + a = iris.x.gather(k_tile, src_view, source_rank_id, ctx) + + # Load B tile + rk_local = k_offset + tl.arange(0, BLOCK_SIZE_K) + rk_global = (source_rank_id * K_local) + rk_local + B_ptr = B + rk_global[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(tl.multiple_of(B_ptr, (16, 1))) + + # Accumulate + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + + # Handle remaining K elements if not evenly divisible + if not EVEN_K: + k_offset = loop_k_local * BLOCK_SIZE_K + tile_k = k_offset // BLOCK_SIZE_K + k_tile = iris.x.TileView(pid_m, tile_k, BLOCK_SIZE_M, BLOCK_SIZE_K) + + # Pull A tile from source_rank_id using gather primitive + a = iris.x.gather(k_tile, src_view, source_rank_id, ctx) + + rk_local = k_offset + tl.arange(0, BLOCK_SIZE_K) + rk_global = (source_rank_id * K_local) + rk_local + rk_global_mask = rk_global < K + B_ptr = B + rk_global[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(tl.multiple_of(B_ptr, (16, 1)), mask=rk_global_mask[:, None], other=0.0) + + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + + # Add bias if provided using tritonBLAS + if BIAS: + bias_vector = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0) + acc = add_vector(acc, bias_vector, QUANTIZED=False) + + # Convert to output dtype using tritonBLAS + c = convert_dtype(acc, C.type.element_ty) + + # Store result (manual for now, tritonBLAS store has issues with our indices) + C_ptr = ( + C + + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] * stride_cm + + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_cn + ) + mask = ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] < M) & ( + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] < N + ) + tl.store(C_ptr, c, mask=mask) + + +@triton.jit() +def _fused_chunked_all_gather_matmul_kernel( + A_sharded, + B, + C, + bias_ptr, + temp_buffer, # Temporary buffer: BLOCK_M x K x num_tiles + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + K_local: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_bias: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + ALLOW_TF32: tl.constexpr, +): + """ + Fused all-gather + GEMM kernel using chunked/buffered pattern. + + This variant pre-gathers all of A into a temporary buffer before computing GEMM. + Eliminates the world_size loop by using iris.x.all_gather upfront. + + Memory layout: + - temp_buffer: BLOCK_M x K x num_tiles (stores gathered A for each tile) + - Each program gathers its M-tile of A, then does GEMM + """ + pid = tl.program_id(0) + + # Handle multi-XCD devices + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 + + # Persistent loop over output tiles + for tile_id in range(pid, total_tiles, NUM_SMS): + # Compute tile coordinates with swizzling + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + # Compute row and column indices + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Buffer pointer for this tile: BLOCK_M x K for this pid_m + buffer_ptr = temp_buffer + tile_id * BLOCK_SIZE_M * K + + # Step 1: Pre-gather entire M-tile of A (BLOCK_M x K) + # Create DeviceContext and TensorView for gather operations + ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases) + src_view = iris.x.TensorView(A_sharded, M, K_local, stride_am, stride_ak) + + # Gather K-tiles from all ranks + for source_rank_id in range(world_size): + k_start = source_rank_id * K_local + # Loop over K dimension in blocks + for k_local_idx in range(0, K_local, BLOCK_SIZE_K): + k_global = k_start + k_local_idx + rk = k_global + tl.arange(0, BLOCK_SIZE_K) + rk_mask = rk < K + + tile_k = k_local_idx // BLOCK_SIZE_K + k_tile = iris.x.TileView(pid_m, tile_k, BLOCK_SIZE_M, BLOCK_SIZE_K) + + # Pull A tile from source_rank_id + a = iris.x.gather(k_tile, src_view, source_rank_id, ctx) + + # Store in buffer + buffer_A_ptr = buffer_ptr + rm[:, None] * K + rk[None, :] + tl.store(buffer_A_ptr, a, mask=rk_mask[None, :]) + + # Step 2: Standard GEMM from buffer + # Initialize accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + # Loop over K dimension + loop_k = tl.cdiv(K, BLOCK_SIZE_K) + if EVEN_K: + for k_block_idx in range(loop_k): + k_offset = k_block_idx * BLOCK_SIZE_K + + # Load A from temp buffer + rk = k_offset + tl.arange(0, BLOCK_SIZE_K) + buffer_A_ptr = buffer_ptr + rm[:, None] * K + rk[None, :] + a = tl.load(buffer_A_ptr) + + # Load B tile + B_ptr = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(tl.multiple_of(B_ptr, (16, 1))) + + # Accumulate + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + else: + # Handle case where K is not evenly divisible by BLOCK_SIZE_K + for k_block_idx in range(loop_k): + k_offset = k_block_idx * BLOCK_SIZE_K + + # Load A from temp buffer + rk = k_offset + tl.arange(0, BLOCK_SIZE_K) + rk_mask = rk < K + buffer_A_ptr = buffer_ptr + rm[:, None] * K + rk[None, :] + a = tl.load(buffer_A_ptr, mask=rk_mask[None, :], other=0.0) + + # Load B tile + B_ptr = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(tl.multiple_of(B_ptr, (16, 1)), mask=rk_mask[:, None], other=0.0) + + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + + # Convert accumulator and add bias + c = convert_dtype(acc, C.type.element_ty) + if BIAS: + bias_offset = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) * stride_bias + bias_val = tl.load(bias_ptr + bias_offset) + c = add_vector(c, bias_val, 0) + + # Store result + C_ptr = ( + C + + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] * stride_cm + + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_cn + ) + mask = ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] < M) & ( + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] < N + ) + tl.store(C_ptr, c, mask=mask) + + +def all_gather_matmul_preamble( + shmem, + A_sharded: torch.Tensor, + B: torch.Tensor, + config: Optional[FusedConfig] = None, +) -> FusedWorkspace: + """Allocate workspace for all_gather_matmul (buffer needed for chunked variant).""" + if config is None: + config = FusedConfig() + + M, K_local = A_sharded.shape + K, N = B.shape + world_size = shmem.get_num_ranks() + + expected_K = world_size * K_local + assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" + + # Detect hardware configuration + device = A_sharded.device + if config.num_sms is None: + import iris.hip + num_sms = iris.hip.get_cu_count(device.index) + else: + num_sms = config.num_sms + + if config.num_xcds == 1: + # Auto-detect XCDs if default value is used + import iris.hip + num_xcds = iris.hip.get_num_xcc(device.index) + else: + num_xcds = config.num_xcds + + # Allocate temporary buffer for chunked variant + aux_buffer = None + if config.all_gather_matmul_variant == "chunked": + # Calculate grid size to determine buffer size + num_tiles_m = (M + config.block_size_m - 1) // config.block_size_m + num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n + num_tiles = num_tiles_m * num_tiles_n + + # Allocate buffer: BLOCK_M x K x num_tiles + buffer_size = config.block_size_m * K * num_tiles + aux_buffer = torch.empty(buffer_size, dtype=A_sharded.dtype, device=device) + + return FusedWorkspace( + operation="all_gather_matmul", + shape=(M, N, K), + dtype=A_sharded.dtype, + world_size=world_size, + num_sms=num_sms, + num_xcds=num_xcds, + variant=config.all_gather_matmul_variant, + aux_buffer=aux_buffer, + prepared=True, + ) + + +def all_gather_matmul( + shmem, + output_tensor: torch.Tensor, + A_sharded: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor] = None, + async_op: bool = False, + config: Optional[FusedConfig] = None, + workspace: Optional[FusedWorkspace] = None, +) -> FusedWorkspace: + """Fused all-gather and matrix multiplication using pull pattern.""" + if config is None: + config = FusedConfig() + + M, K_local = A_sharded.shape + K, N = B.shape + world_size = shmem.get_num_ranks() + rank = shmem.get_rank() + + expected_K = world_size * K_local + assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" + assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" + + # Validate problem size against block sizes + assert M >= config.block_size_m, ( + f"M ({M}) must be >= block_size_m ({config.block_size_m}). Use smaller block sizes for small problems." + ) + assert K_local >= config.block_size_k, ( + f"K_local ({K_local}) must be >= block_size_k ({config.block_size_k}). " + f"Use smaller block sizes for small problems." + ) + assert N >= config.block_size_n, ( + f"N ({N}) must be >= block_size_n ({config.block_size_n}). Use smaller block sizes for small problems." + ) + + if workspace is None: + workspace = all_gather_matmul_preamble(shmem, A_sharded, B, config) + + stride_am, stride_ak = A_sharded.stride() + stride_bk, stride_bn = B.stride() + stride_cm, stride_cn = output_tensor.stride() + + if bias is not None: + assert bias.shape[0] == M + bias_ptr = bias + stride_bias = bias.stride()[0] if bias.dim() > 0 else 1 + use_bias = True + else: + bias_ptr = output_tensor + stride_bias = 1 + use_bias = False + + # Get hardware configuration from workspace + num_sms = workspace.num_sms + num_xcds = workspace.num_xcds + + even_k = K_local % config.block_size_k == 0 + + # Use SM-based grid (persistent kernels) + grid = (num_sms,) + + # Select kernel variant based on config + if config.all_gather_matmul_variant == "chunked": + # Chunked variant: pre-gather into buffer, then GEMM + assert workspace.aux_buffer is not None, "Chunked variant requires aux_buffer in workspace" + _fused_chunked_all_gather_matmul_kernel[grid]( + A_sharded, + B, + output_tensor, + bias_ptr, + workspace.aux_buffer, # Temporary buffer + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + shmem.heap_bases, + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_sms, + num_xcds, + use_bias, + even_k, + config.allow_tf32, + ) + else: + # Pull variant (default): on-demand pull from remote ranks + _fused_all_gather_matmul_kernel[grid]( + A_sharded, + B, + output_tensor, + bias_ptr, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + shmem.heap_bases, + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_sms, + num_xcds, + use_bias, + even_k, + config.allow_tf32, + ) + + if not async_op: + shmem.barrier() + + return workspace diff --git a/iris/ops/config.py b/iris/ops/config.py index 3ca085c31..77c0b5ab9 100644 --- a/iris/ops/config.py +++ b/iris/ops/config.py @@ -19,10 +19,10 @@ class FusedConfig: but users can override specific settings for performance tuning. GEMM Parameters: - block_size_m: Block size for M dimension (rows). Default: 256. - block_size_n: Block size for N dimension (columns). Default: 64. + block_size_m: Block size for M dimension (rows). Default: 128. + block_size_n: Block size for N dimension (columns). Default: 256. block_size_k: Block size for K dimension (reduction). Default: 64. - group_size_m: Group size for M dimension tiling. Default: 1. + group_size_m: Group size for M dimension tiling. Default: 4. num_sms: Number of SMs to use. If None, auto-detects from device. Default: None. num_xcds: Number of XCDs (chiplets). Default: 1. chunk_size: Chunk size for chiplet transform. Default: 1. @@ -32,8 +32,12 @@ class FusedConfig: CCL Parameters (for operations that need collective communication): all_reduce_variant: All-reduce algorithm variant. Options: "atomic", "ring", - "one_shot", "two_shot", "spinlock". Default: "one_shot". + "one_shot", "two_shot", "spinlock". Default: "two_shot". all_reduce_num_rings: Number of concurrent rings (for ring variant). Default: 1. + all_gather_matmul_variant: All-gather + matmul algorithm variant. Options: + "pull" (on-demand pull from remote ranks), + "chunked" (pre-gather into buffer then GEMM). + Default: "pull". Example: >>> # Use defaults @@ -47,10 +51,10 @@ class FusedConfig: """ # GEMM parameters - block_size_m: int = 256 - block_size_n: int = 64 + block_size_m: int = 128 + block_size_n: int = 256 block_size_k: int = 64 - group_size_m: int = 1 + group_size_m: int = 4 num_sms: Optional[int] = None # Auto-detect if None num_xcds: int = 1 chunk_size: int = 1 @@ -61,6 +65,7 @@ class FusedConfig: # CCL-specific parameters all_reduce_variant: str = "two_shot" # atomic, ring, one_shot, two_shot, spinlock all_reduce_num_rings: int = 1 + all_gather_matmul_variant: str = "pull" # pull, chunked def validate(self, world_size: Optional[int] = None): """ @@ -102,3 +107,10 @@ def validate(self, world_size: Optional[int] = None): if self.all_reduce_num_rings <= 0: raise ValueError(f"all_reduce_num_rings must be positive, got {self.all_reduce_num_rings}") + + # Validate all_gather_matmul_variant + valid_ag_variants = ["pull", "chunked"] + if self.all_gather_matmul_variant not in valid_ag_variants: + raise ValueError( + f"all_gather_matmul_variant must be one of {valid_ag_variants}, got {self.all_gather_matmul_variant}" + ) diff --git a/iris/ops/workspace.py b/iris/ops/workspace.py index a9c7cb616..9328e9f9e 100644 --- a/iris/ops/workspace.py +++ b/iris/ops/workspace.py @@ -38,6 +38,10 @@ class FusedWorkspace: world_size: int = 1 variant: str = "" + # Hardware configuration (detected in preamble) + num_sms: Optional[int] = None # Number of streaming multiprocessors + num_xcds: int = 1 # Number of XCDs/chiplets + # Temporary buffers (allocated as needed) aux_buffer: Optional[torch.Tensor] = None # Generic buffer for intermediate results locks: Optional[torch.Tensor] = None # Synchronization primitives diff --git a/iris/x/gather.py b/iris/x/gather.py index ca8bd4f9c..51f489a03 100644 --- a/iris/x/gather.py +++ b/iris/x/gather.py @@ -52,7 +52,7 @@ def gather( if source_rank == ctx.rank: # Local load - tile_data = tl.load(src_tile_ptr, mask=mask, other=0.0) + tile_data = tl.load(src_tile_ptr, mask=mask) else: # Remote load using RMA tile_data = iris.load( diff --git a/tests/ops/test_all_gather_matmul.py b/tests/ops/test_all_gather_matmul.py index 193505011..7dceea126 100644 --- a/tests/ops/test_all_gather_matmul.py +++ b/tests/ops/test_all_gather_matmul.py @@ -28,7 +28,14 @@ (256, 64, 128), ], ) -def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N): +@pytest.mark.parametrize( + "variant", + [ + "pull", + "chunked", + ], +) +def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N, variant): """Test all_gather_matmul against torch all_gather + matmul.""" if not dist.is_initialized(): pytest.skip("torch.distributed not initialized") @@ -77,12 +84,20 @@ def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N): # Run fused all_gather + matmul using shmem.ops API from iris.ops.config import FusedConfig + if rank == 0: + print(f"\n[Test] Testing variant={variant}, M={M}, K_local={K_local}, N={N}, dtype={dtype}") + # Use appropriate block sizes based on problem size # For small problems, use smaller blocks if M <= 256 or K_local <= 64 or N <= 128: - config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) + config = FusedConfig( + block_size_m=64, + block_size_n=64, + block_size_k=32, + all_gather_matmul_variant=variant, + ) else: - config = FusedConfig() + config = FusedConfig(all_gather_matmul_variant=variant) # Validate config against problem size assert M >= config.block_size_m, f"M ({M}) must be >= block_size_m ({config.block_size_m})" From f132cebf3c4202d56da4e81e973f5811fb33d7c5 Mon Sep 17 00:00:00 2001 From: neoblizz Date: Sat, 7 Feb 2026 20:13:20 +0000 Subject: [PATCH 03/31] Up the tritonBLAS commit. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 18e71badb..025337641 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "numpy", "requests", "ruff", - "tritonblas @ git+https://github.com/ROCm/tritonBLAS.git@df58476a4520b72495a3f03f911368a184126568", + "tritonblas @ git+https://github.com/ROCm/tritonBLAS.git@cd119279f3df543a558aa6d2cd4a3daed0b1ec7a", ] From 1628a6192b72f5120d3ec78665c7f9f5430fd646 Mon Sep 17 00:00:00 2001 From: neoblizz Date: Tue, 10 Feb 2026 00:03:37 +0000 Subject: [PATCH 04/31] ... --- benchmark/ops/all_gather_matmul/benchmark.py | 20 ++++--------- iris/iris.py | 4 +-- iris/ops/all_gather_matmul.py | 31 ++++++++++++++++---- iris/ops/config.py | 6 ++-- iris/ops/workspace.py | 6 ++++ 5 files changed, 42 insertions(+), 25 deletions(-) diff --git a/benchmark/ops/all_gather_matmul/benchmark.py b/benchmark/ops/all_gather_matmul/benchmark.py index 20ff0c536..ae0443e6d 100644 --- a/benchmark/ops/all_gather_matmul/benchmark.py +++ b/benchmark/ops/all_gather_matmul/benchmark.py @@ -18,6 +18,7 @@ from examples.common.utils import JSONWriter import iris +from iris.ops.all_gather_matmul import all_gather_matmul_preamble from iris.ops import FusedConfig torch.manual_seed(123) @@ -65,8 +66,8 @@ def parse_args(): "--variant", type=str, default="pull", - choices=["pull", "chunked"], - help="All-gather matmul variant (pull or chunked)", + choices=["pull", "chunked", "push", "pipelined_pull"], + help="All-gather matmul variant", ) parser.add_argument( "--init_url", type=str, default="tcp://127.0.0.1:29530", help="Initialization URL for distributed setup" @@ -181,20 +182,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): }, } - workspace = None + # Pre-allocate workspace once (important for push variant which needs large buffers) + workspace = all_gather_matmul_preamble(shmem, A_sharded, B, config) def run_experiment(): - nonlocal kernel_timing, workspace - - # Preamble if available - if hasattr(shmem.ops, "all_gather_matmul_preamble"): - workspace = shmem.ops.all_gather_matmul_preamble( - C, - A_sharded, - B, - config=config, - workspace=workspace, - ) + nonlocal kernel_timing shmem.barrier() diff --git a/iris/iris.py b/iris/iris.py index 9b8a3d35a..21aaddd8a 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1796,8 +1796,8 @@ def __translate(ptr, from_rank, to_rank, heap_bases): # Vectorization hints: must be <= minimum block size used by any caller. # (32, 32) is safe since all supported block sizes are multiples of 32. # Largest vectorized load instruction is dwordx4 (128-bits = 8 x fp16). - translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) - translated_ptr = tl.max_contiguous(translated_ptr, (32, 32)) + # translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) + # translated_ptr = tl.max_contiguous(translated_ptr, (32, 32)) return translated_ptr diff --git a/iris/ops/all_gather_matmul.py b/iris/ops/all_gather_matmul.py index 5d700206c..0dad98aee 100644 --- a/iris/ops/all_gather_matmul.py +++ b/iris/ops/all_gather_matmul.py @@ -17,6 +17,7 @@ import iris.x from tritonblas.kernels.stages import GemmContext, ScheduleContext +from tritonblas.kernels.stages.indexing.pid_transforms import chiplet_transform_chunked from .config import FusedConfig from .workspace import FusedWorkspace @@ -164,7 +165,7 @@ def all_gather_matmul_preamble( B: torch.Tensor, config: Optional[FusedConfig] = None, ) -> FusedWorkspace: - """Allocate workspace for all_gather_matmul (none needed for pull pattern).""" + """Allocate workspace for all_gather_matmul.""" if config is None: config = FusedConfig() @@ -175,14 +176,27 @@ def all_gather_matmul_preamble( expected_K = world_size * K_local assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" - return FusedWorkspace( + ws = FusedWorkspace( operation="all_gather_matmul", shape=(M, N, K), dtype=A_sharded.dtype, world_size=world_size, + variant=config.all_gather_matmul_variant, prepared=True, ) + # Allocate push variant workspace + if config.all_gather_matmul_variant == "push": + num_m_tiles = (M + config.block_size_m - 1) // config.block_size_m + num_k_tiles = (K_local + config.block_size_k - 1) // config.block_size_k + ws.a_inbox = shmem.zeros((world_size, M, K_local), dtype=A_sharded.dtype) + ws.signal_flags = shmem.zeros( + (world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32 + ) + shmem.barrier() + + return ws + def all_gather_matmul( shmem, @@ -245,10 +259,15 @@ def all_gather_matmul( even_k = K_local % config.block_size_k == 0 num_k_blocks_local = (K_local + config.block_size_k - 1) // config.block_size_k - # Launch single fused kernel - grid = (num_sms,) - _fused_all_gather_matmul_kernel[grid]( - A_sharded, + variant = config.all_gather_matmul_variant + + if variant == "pull": + num_tiles_m = (M + config.block_size_m - 1) // config.block_size_m + num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n + num_tiles = num_tiles_m * num_tiles_n + # grid = (num_tiles,) + grid = (num_sms,) + _fused_all_gather_matmul_kernel[grid](A_sharded, B, output_tensor, bias_ptr, diff --git a/iris/ops/config.py b/iris/ops/config.py index 77c0b5ab9..a92925035 100644 --- a/iris/ops/config.py +++ b/iris/ops/config.py @@ -54,9 +54,9 @@ class FusedConfig: block_size_m: int = 128 block_size_n: int = 256 block_size_k: int = 64 - group_size_m: int = 4 + group_size_m: int = 1 num_sms: Optional[int] = None # Auto-detect if None - num_xcds: int = 1 + num_xcds: int = 8 chunk_size: int = 1 cache_modifier_a: str = ".ca" cache_modifier_b: str = ".ca" @@ -109,7 +109,7 @@ def validate(self, world_size: Optional[int] = None): raise ValueError(f"all_reduce_num_rings must be positive, got {self.all_reduce_num_rings}") # Validate all_gather_matmul_variant - valid_ag_variants = ["pull", "chunked"] + valid_ag_variants = ["pull"] if self.all_gather_matmul_variant not in valid_ag_variants: raise ValueError( f"all_gather_matmul_variant must be one of {valid_ag_variants}, got {self.all_gather_matmul_variant}" diff --git a/iris/ops/workspace.py b/iris/ops/workspace.py index 9328e9f9e..e519f0823 100644 --- a/iris/ops/workspace.py +++ b/iris/ops/workspace.py @@ -46,6 +46,10 @@ class FusedWorkspace: aux_buffer: Optional[torch.Tensor] = None # Generic buffer for intermediate results locks: Optional[torch.Tensor] = None # Synchronization primitives + # Push variant workspace + a_inbox: Optional[torch.Tensor] = None # (world_size, M, K_local) inbox buffer + signal_flags: Optional[torch.Tensor] = None # (world_size, world_size, m_tiles, k_tiles) + prepared: bool = False def matches( @@ -86,4 +90,6 @@ def clear(self): """Free all allocated buffers.""" self.aux_buffer = None self.locks = None + self.a_inbox = None + self.signal_flags = None self.prepared = False From c26e87275043e996c9dca78e44c60fc34d6d2eac Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 10 Feb 2026 00:04:25 +0000 Subject: [PATCH 05/31] Apply Ruff auto-fixes --- iris/ops/all_gather_matmul.py | 64 +++++++++++++++++------------------ 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/iris/ops/all_gather_matmul.py b/iris/ops/all_gather_matmul.py index 0dad98aee..6000f50ef 100644 --- a/iris/ops/all_gather_matmul.py +++ b/iris/ops/all_gather_matmul.py @@ -17,7 +17,6 @@ import iris.x from tritonblas.kernels.stages import GemmContext, ScheduleContext -from tritonblas.kernels.stages.indexing.pid_transforms import chiplet_transform_chunked from .config import FusedConfig from .workspace import FusedWorkspace @@ -190,9 +189,7 @@ def all_gather_matmul_preamble( num_m_tiles = (M + config.block_size_m - 1) // config.block_size_m num_k_tiles = (K_local + config.block_size_k - 1) // config.block_size_k ws.a_inbox = shmem.zeros((world_size, M, K_local), dtype=A_sharded.dtype) - ws.signal_flags = shmem.zeros( - (world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32 - ) + ws.signal_flags = shmem.zeros((world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32) shmem.barrier() return ws @@ -267,35 +264,36 @@ def all_gather_matmul( num_tiles = num_tiles_m * num_tiles_n # grid = (num_tiles,) grid = (num_sms,) - _fused_all_gather_matmul_kernel[grid](A_sharded, - B, - output_tensor, - bias_ptr, - M, - N, - K, - K_local, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_bias, - shmem.get_device_context(), - rank, - world_size, - config.block_size_m, - config.block_size_n, - config.block_size_k, - config.group_size_m, - num_sms, - config.num_xcds, - num_k_blocks_local, - use_bias, - even_k, - config.allow_tf32, - ) + _fused_all_gather_matmul_kernel[grid]( + A_sharded, + B, + output_tensor, + bias_ptr, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + shmem.get_device_context(), + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_sms, + config.num_xcds, + num_k_blocks_local, + use_bias, + even_k, + config.allow_tf32, + ) if not async_op: shmem.barrier() From 3d4c7d7fc3129cfaf3125247bc9496c1a04bcaa8 Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Wed, 11 Feb 2026 12:01:43 -0500 Subject: [PATCH 06/31] Fix load vectorization and transpose config --- benchmark/ops/all_gather_matmul/benchmark.py | 50 +- .../all_gather_matmul/benchmark_torchrun.py | 487 ++++++++++++++++++ .../ops/all_gather_matmul/profile_att.sh | 344 +++++++++++++ benchmark/ops/all_gather_matmul/test.sh | 16 + iris/iris.py | 56 +- iris/ops/all_gather_matmul.py | 1 + iris/x/core.py | 5 +- iris/x/gather.py | 29 +- 8 files changed, 965 insertions(+), 23 deletions(-) create mode 100755 benchmark/ops/all_gather_matmul/benchmark_torchrun.py create mode 100755 benchmark/ops/all_gather_matmul/profile_att.sh create mode 100755 benchmark/ops/all_gather_matmul/test.sh diff --git a/benchmark/ops/all_gather_matmul/benchmark.py b/benchmark/ops/all_gather_matmul/benchmark.py index ae0443e6d..b9d40118d 100644 --- a/benchmark/ops/all_gather_matmul/benchmark.py +++ b/benchmark/ops/all_gather_matmul/benchmark.py @@ -72,6 +72,16 @@ def parse_args(): parser.add_argument( "--init_url", type=str, default="tcp://127.0.0.1:29530", help="Initialization URL for distributed setup" ) + parser.add_argument( + "--b_col_major", + action="store_true", + help="Store B matrix in column-major order (K-contiguous) to reduce LDS transpose overhead", + ) + parser.add_argument( + "--a_col_major", + action="store_true", + help="Store A matrix in column-major order (M-contiguous). Default is row-major (K-contiguous).", + ) return vars(parser.parse_args()) @@ -142,11 +152,45 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Create input and output tensors # A_sharded is M x K_local, B is K x N, output is M x N - A_sharded = shmem.zeros((M, K_local), dtype=datatype) - B = shmem.zeros((K, N), dtype=datatype) C = shmem.zeros((M, N), dtype=datatype) expected_tensor = None + # Create A_sharded matrix with optional column-major layout + # When a_col_major=True, M becomes the contiguous dimension + # Default (row-major): K is contiguous (stride_ak=1, stride_am=K_local) + if args["a_col_major"]: + # Allocate storage as (K_local, M) row-major, then transpose to get (M, K_local) with M-contiguous + # This means stride_am=1 and stride_ak=M + A_storage = shmem.zeros((K_local, M), dtype=datatype) + A_sharded = A_storage.T # View as (M, K_local) with M-contiguous strides + shmem.info(f"Using column-major A: shape={A_sharded.shape}, strides={A_sharded.stride()} (M-contiguous)") + else: + # Standard row-major (M, K_local) - K is contiguous + A_sharded = shmem.zeros((M, K_local), dtype=datatype) + shmem.info(f"Using row-major A: shape={A_sharded.shape}, strides={A_sharded.stride()} (K-contiguous)") + + json_writer.add_field("a_col_major", args["a_col_major"]) + json_writer.add_field("a_stride_m", A_sharded.stride()[0]) + json_writer.add_field("a_stride_k", A_sharded.stride()[1]) + + # Create B matrix with optional column-major layout for K-contiguous access + # When b_col_major=True, we store B such that K is the contiguous dimension + # This reduces LDS transpose overhead when loading B tiles along the K dimension + if args["b_col_major"]: + # Allocate storage as (N, K) row-major, then transpose to get (K, N) with K-contiguous + # This means stride_bk=1 and stride_bn=K + B_storage = shmem.zeros((N, K), dtype=datatype) + B = B_storage.T # View as (K, N) with K-contiguous strides + shmem.info(f"Using column-major B: shape={B.shape}, strides={B.stride()} (K-contiguous)") + else: + # Standard row-major (K, N) - N is contiguous + B = shmem.zeros((K, N), dtype=datatype) + shmem.info(f"Using row-major B: shape={B.shape}, strides={B.stride()} (N-contiguous)") + + json_writer.add_field("b_col_major", args["b_col_major"]) + json_writer.add_field("b_stride_k", B.stride()[0]) + json_writer.add_field("b_stride_n", B.stride()[1]) + # Fill inputs with deterministic values # Each rank has different A_sharded, same B torch.manual_seed(123 + rank) @@ -154,7 +198,9 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): A_sharded.copy_(A_sharded_data) torch.manual_seed(456) # Same B for all ranks + # Generate B data in standard (K, N) layout for consistency B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + # Copy to B (handles both row-major and column-major storage) B.copy_(B_data) # For validation: compute expected result diff --git a/benchmark/ops/all_gather_matmul/benchmark_torchrun.py b/benchmark/ops/all_gather_matmul/benchmark_torchrun.py new file mode 100755 index 000000000..f4526410c --- /dev/null +++ b/benchmark/ops/all_gather_matmul/benchmark_torchrun.py @@ -0,0 +1,487 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops all_gather_matmul fused operation. + +This benchmark showcases the fused All-Gather + GEMM operation where each rank +has a sharded A matrix that gets gathered, then multiplied with B. + +This version is compatible with torchrun for use with profiling tools like rocprofv3/att. + +Usage with torchrun: + torchrun --nproc_per_node=8 benchmark_torchrun.py -m 16384 -n 2048 -k 131072 --benchmark + +Usage with rocprofv3: + torchrun --nproc_per_node=8 rocprofv3 --att benchmark_torchrun.py -m 16384 -n 2048 -k 131072 --benchmark +""" + +import os +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops.all_gather_matmul import all_gather_matmul_preamble +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark all_gather_matmul fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension total (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="all_gather_matmul.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (all_gather_into_tensor + matmul) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--variant", + type=str, + default="pull", + choices=["pull", "chunked", "push", "pipelined_pull"], + help="All-gather matmul variant", + ) + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29530", help="Initialization URL for distributed setup" + ) + parser.add_argument( + "--single-run", + action="store_true", + help="Run only one iteration (no warmup, 1 repeat) - useful for profiling", + ) + parser.add_argument( + "--b_col_major", + action="store_true", + help="Store B matrix in column-major order (K-contiguous) to reduce LDS transpose overhead", + ) + parser.add_argument( + "--a_col_major", + action="store_true", + help="Store A matrix in column-major order (M-contiguous). Default is row-major (K-contiguous).", + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int = None, world_size: int = None, init_url: str = None, args: dict = None): + """Worker function for PyTorch distributed execution.""" + # Support torchrun: read from environment variables if available + if local_rank is None: + local_rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0))) + if world_size is None: + world_size = int(os.environ.get("WORLD_SIZE", 1)) + if init_url is None: + # torchrun sets MASTER_ADDR and MASTER_PORT + master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") + master_port = os.environ.get("MASTER_PORT", "29500") + init_url = f"tcp://{master_addr}:{master_port}" + + # Use nccl backend - gloo doesn't support uint64 tensors used by Iris + backend = "nccl" if torch.cuda.is_available() else "gloo" + print(f"Rank {local_rank}: Using backend: {backend}") + + # Use environment-based initialization if torchrun is detected + if "RANK" in os.environ or "LOCAL_RANK" in os.environ: + # For torchrun, use env:// initialization with device_id for nccl + dist.init_process_group( + backend=backend, + init_method="env://", + device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None, + ) + else: + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None, + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + K = args["k"] + K_local = K // world_size # Sharded K dimension + + # Create config with parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + "all_gather_matmul_variant": args["variant"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "all_gather_matmul") + json_writer.add_field("k_local", K_local) + json_writer.add_field("k_total", K) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + + # Create input and output tensors + # A_sharded is M x K_local, B is K x N, output is M x N + C = shmem.zeros((M, N), dtype=datatype) + expected_tensor = None + + # Create A_sharded matrix with optional column-major layout + # When a_col_major=True, M becomes the contiguous dimension + # Default (row-major): K is contiguous (stride_ak=1, stride_am=K_local) + if args["a_col_major"]: + # Allocate storage as (K_local, M) row-major, then transpose to get (M, K_local) with M-contiguous + # This means stride_am=1 and stride_ak=M + A_storage = shmem.zeros((K_local, M), dtype=datatype) + A_sharded = A_storage.T # View as (M, K_local) with M-contiguous strides + shmem.info(f"Using column-major A: shape={A_sharded.shape}, strides={A_sharded.stride()} (M-contiguous)") + else: + # Standard row-major (M, K_local) - K is contiguous + A_sharded = shmem.zeros((M, K_local), dtype=datatype) + shmem.info(f"Using row-major A: shape={A_sharded.shape}, strides={A_sharded.stride()} (K-contiguous)") + + json_writer.add_field("a_col_major", args["a_col_major"]) + json_writer.add_field("a_stride_m", A_sharded.stride()[0]) + json_writer.add_field("a_stride_k", A_sharded.stride()[1]) + + # Create B matrix with optional column-major layout for K-contiguous access + # When b_col_major=True, we store B such that K is the contiguous dimension + # This reduces LDS transpose overhead when loading B tiles along the K dimension + if args["b_col_major"]: + # Allocate storage as (N, K) row-major, then transpose to get (K, N) with K-contiguous + # This means stride_bk=1 and stride_bn=K + B_storage = shmem.zeros((N, K), dtype=datatype) + B = B_storage.T # View as (K, N) with K-contiguous strides + shmem.info(f"Using column-major B: shape={B.shape}, strides={B.stride()} (K-contiguous)") + else: + # Standard row-major (K, N) - N is contiguous + B = shmem.zeros((K, N), dtype=datatype) + shmem.info(f"Using row-major B: shape={B.shape}, strides={B.stride()} (N-contiguous)") + + json_writer.add_field("b_col_major", args["b_col_major"]) + json_writer.add_field("b_stride_k", B.stride()[0]) + json_writer.add_field("b_stride_n", B.stride()[1]) + + # Fill inputs with deterministic values + # Each rank has different A_sharded, same B + torch.manual_seed(123 + rank) + A_sharded_data = torch.randn((M, K_local), dtype=datatype, device=f"cuda:{rank}") + A_sharded.copy_(A_sharded_data) + + torch.manual_seed(456) # Same B for all ranks + # Generate B data in standard (K, N) layout for consistency + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + # Copy to B (handles both row-major and column-major storage) + B.copy_(B_data) + + # For validation: compute expected result + if args["validate"]: + # Gather all A_sharded matrices and compute expected result + A_sharded_list = [torch.zeros((M, K_local), dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_sharded_list, A_sharded_data) + + # Concatenate along K dimension: A_gathered = [A_0 | A_1 | ... | A_n] + A_gathered = torch.cat(A_sharded_list, dim=1) # (M, K) + + # Expected: A_gathered @ B + expected_tensor = shmem.zeros((M, N), dtype=datatype) + expected_result = torch.matmul(A_gathered, B_data) + expected_tensor.copy_(expected_result) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "all_gather_matmul": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + # Pre-allocate workspace once (important for push variant which needs large buffers) + workspace = all_gather_matmul_preamble(shmem, A_sharded, B, config) + + def run_experiment(): + nonlocal kernel_timing + + shmem.barrier() + + torch.cuda.nvtx.range_push("All-Gather-Matmul") + with torch.cuda.stream(comm_stream): + kernel_timing["all_gather_matmul"]["start_event"].record() + shmem.ops.all_gather_matmul( + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + ) + kernel_timing["all_gather_matmul"]["end_event"].record() + kernel_timing["all_gather_matmul"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["all_gather_matmul"]["start_event"].elapsed_time( + kernel_timing["all_gather_matmul"]["end_event"] + ) + kernel_timing["all_gather_matmul"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-1 if datatype == torch.float16 else 1e-3 + success = torch.allclose(C, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("All-gather-matmul validation passed!") + else: + shmem.error("All-gather-matmul validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Determine warmup and repeat counts + if args.get("single_run", False): + n_warmup = 0 + n_repeat = 1 + shmem.info("Single-run mode: no warmup, 1 repeat") + else: + n_warmup = 25 + n_repeat = 100 # default from iris.do_bench + + # Warmup for benchmarking (skip if single-run) + if not args.get("single_run", False): + for k in ["all_gather_matmul"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=n_warmup, n_repeat=1) + + for k in ["all_gather_matmul"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier, n_warmup=n_warmup, n_repeat=n_repeat) + tflops = total_tflops_unit / ( + (kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"]) * 1e-3 + ) + + # Calculate bandwidth for all-gather part + # All-gather moves (world_size - 1) * M * K_local * element_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + input_bytes = M * K_local * element_size + total_bytes = input_bytes * (world_size - 1) + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"All-gather-matmul (M={M}, K_local={K_local}, K_total={K}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "all_gather_matmul_ms", + kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"], + ) + json_writer.add_field("all_gather_matmul_experiments", kernel_timing["all_gather_matmul"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (all_gather_into_tensor + matmul) for comparison + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (all_gather_into_tensor + matmul)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A_sharded = torch.randn(M, K_local, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_A_gathered = torch.zeros(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + dist.all_gather_into_tensor(pytorch_A_gathered, pytorch_A_sharded) + pytorch_C = torch.matmul(pytorch_A_gathered, pytorch_B) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + # Calculate bandwidth for all-gather part + element_size = torch.tensor([], dtype=datatype).element_size() + input_bytes = M * K_local * element_size + total_bytes = input_bytes * (world_size - 1) + total_bytes_gb = total_bytes / (1024**3) + + def run_pytorch_experiment(): + dist.all_gather_into_tensor(pytorch_A_gathered, pytorch_A_sharded) + pytorch_C = torch.matmul(pytorch_A_gathered, pytorch_B) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch all_gather_into_tensor+matmul (M={M}, K_local={K_local}, K_total={K}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + print("Starting all_gather_matmul benchmark...") + args = parse_args() + + # Check if running with torchrun (detected by environment variables) + if "RANK" in os.environ or "LOCAL_RANK" in os.environ: + # torchrun handles process spawning, so call _worker directly + print("Detected torchrun execution mode") + _worker(args=args) + else: + # Use multiprocessing spawn for backward compatibility + num_ranks = args["num_ranks"] + init_url = args["init_url"] + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/all_gather_matmul/profile_att.sh b/benchmark/ops/all_gather_matmul/profile_att.sh new file mode 100755 index 000000000..21f6f21fe --- /dev/null +++ b/benchmark/ops/all_gather_matmul/profile_att.sh @@ -0,0 +1,344 @@ +#!/bin/bash +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +# ATT (Advanced Thread Trace) Profiling Script for all_gather_matmul benchmark +# Uses rocprofv3 with thread trace to profile the benchmark at ISA instruction level. +# +# Usage: +# ./profile_att.sh [OPTIONS] +# +# Options: +# -r, --ranks NUM_RANKS Number of ranks/GPUs (default: 8) +# -m, --m-dim M M dimension (default: 2048) +# -n, --n-dim N N dimension (default: 16384) +# -k, --k-dim K K dimension (default: 131072) +# -v, --variant VARIANT Variant: pull, chunked, push, pipelined_pull (default: pull) +# --block-m SIZE Block size for M dimension (default: 256) +# --block-n SIZE Block size for N dimension (default: 256) +# --block-k SIZE Block size for K dimension (default: 64) +# --group-m SIZE Group size for M dimension tiling (default: 1) +# --num-xcds NUM Number of XCDs (default: 8) +# --validate Enable validation mode +# --benchmark-pytorch Also benchmark PyTorch for comparison +# -o, --output-dir DIR Base output directory (default: ./att_profiles) +# --att-target-cu CU Target CU for thread trace (default: 1) +# --att-buffer-size SIZE Trace buffer size in hex (default: 0x6000000 = 96MB) +# --att-activity LEVEL Perfcounter streaming level 1-16 (default: 8) +# --kernel-regex REGEX Kernel name regex filter (optional) +# --single-run Run only one iteration (no warmup, no repeat) +# --k-contiguous Use K-contiguous layout for both A and B matrices +# (default A is row-major/K-contiguous, adds --b_col_major) +# --a-col-major Store A matrix in column-major order (M-contiguous) +# --b-col-major Store B matrix in column-major order (K-contiguous) +# -h, --help Show this help message + +set -e + +# Default values +NUM_RANKS=8 +M_DIM=2048 +N_DIM=16384 +K_DIM=131072 +VARIANT="pull" +BASE_OUTPUT_DIR="./att_profiles" +ATT_TARGET_CU=1 +ATT_BUFFER_SIZE="0x6000000" # 96MB +ATT_ACTIVITY=8 +KERNEL_REGEX="" +SINGLE_RUN=true +K_CONTIGUOUS=true # Default to K-contiguous layout for both matrices +A_COL_MAJOR=false +B_COL_MAJOR=false +BLOCK_M=256 +BLOCK_N=256 +BLOCK_K=64 +GROUP_M=1 +NUM_XCDS=8 +VALIDATE=true +BENCHMARK_PYTORCH=true + +# Script directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCHMARK_SCRIPT="${SCRIPT_DIR}/benchmark_torchrun.py" + +# Parse arguments +while [[ $# -gt 0 ]]; do + case $1 in + -r|--ranks) + NUM_RANKS="$2" + shift 2 + ;; + -m|--m-dim) + M_DIM="$2" + shift 2 + ;; + -n|--n-dim) + N_DIM="$2" + shift 2 + ;; + -k|--k-dim) + K_DIM="$2" + shift 2 + ;; + -v|--variant) + VARIANT="$2" + shift 2 + ;; + -o|--output-dir) + BASE_OUTPUT_DIR="$2" + shift 2 + ;; + --att-target-cu) + ATT_TARGET_CU="$2" + shift 2 + ;; + --att-buffer-size) + ATT_BUFFER_SIZE="$2" + shift 2 + ;; + --att-activity) + ATT_ACTIVITY="$2" + shift 2 + ;; + --kernel-regex) + KERNEL_REGEX="$2" + shift 2 + ;; + --single-run) + SINGLE_RUN=true + shift + ;; + --k-contiguous) + K_CONTIGUOUS=true + shift + ;; + --a-col-major) + A_COL_MAJOR=true + shift + ;; + --b-col-major) + B_COL_MAJOR=true + shift + ;; + --block-m) + BLOCK_M="$2" + shift 2 + ;; + --block-n) + BLOCK_N="$2" + shift 2 + ;; + --block-k) + BLOCK_K="$2" + shift 2 + ;; + --group-m) + GROUP_M="$2" + shift 2 + ;; + --num-xcds) + NUM_XCDS="$2" + shift 2 + ;; + --validate) + VALIDATE=true + shift + ;; + --no-validate) + VALIDATE=false + shift + ;; + --benchmark-pytorch) + BENCHMARK_PYTORCH=true + shift + ;; + --no-benchmark-pytorch) + BENCHMARK_PYTORCH=false + shift + ;; + -h|--help) + head -30 "$0" | tail -n +2 | sed 's/^# //' | sed 's/^#//' + exit 0 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Generate timestamp for output directory +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") +OUTPUT_DIR="${BASE_OUTPUT_DIR}/att_${VARIANT}_m${M_DIM}_n${N_DIM}_k${K_DIM}_${TIMESTAMP}" + +# Create output directory +mkdir -p "${OUTPUT_DIR}" + +# Log file with timestamp +LOG_FILE="${OUTPUT_DIR}/profile_${TIMESTAMP}.log" + +echo "==============================================" | tee "${LOG_FILE}" +echo "ATT Profiling for all_gather_matmul benchmark" | tee -a "${LOG_FILE}" +echo "==============================================" | tee -a "${LOG_FILE}" +echo "Timestamp: $(date)" | tee -a "${LOG_FILE}" +echo "Output directory: ${OUTPUT_DIR}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" +echo "Configuration:" | tee -a "${LOG_FILE}" +echo " NUM_RANKS: ${NUM_RANKS}" | tee -a "${LOG_FILE}" +echo " M: ${M_DIM}" | tee -a "${LOG_FILE}" +echo " N: ${N_DIM}" | tee -a "${LOG_FILE}" +echo " K: ${K_DIM}" | tee -a "${LOG_FILE}" +echo " Variant: ${VARIANT}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" +echo "ATT Parameters:" | tee -a "${LOG_FILE}" +echo " att-target-cu: ${ATT_TARGET_CU}" | tee -a "${LOG_FILE}" +echo " att-buffer-size: ${ATT_BUFFER_SIZE}" | tee -a "${LOG_FILE}" +echo " att-activity: ${ATT_ACTIVITY}" | tee -a "${LOG_FILE}" +if [[ -n "${KERNEL_REGEX}" ]]; then + echo " kernel-include-regex: ${KERNEL_REGEX}" | tee -a "${LOG_FILE}" +fi +echo " single-run: ${SINGLE_RUN}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" +echo "Matrix Layout:" | tee -a "${LOG_FILE}" +echo " k-contiguous: ${K_CONTIGUOUS}" | tee -a "${LOG_FILE}" +echo " a-col-major: ${A_COL_MAJOR}" | tee -a "${LOG_FILE}" +echo " b-col-major: ${B_COL_MAJOR}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" +echo "Block Sizes:" | tee -a "${LOG_FILE}" +echo " block-m: ${BLOCK_M}" | tee -a "${LOG_FILE}" +echo " block-n: ${BLOCK_N}" | tee -a "${LOG_FILE}" +echo " block-k: ${BLOCK_K}" | tee -a "${LOG_FILE}" +echo " group-m: ${GROUP_M}" | tee -a "${LOG_FILE}" +echo " num-xcds: ${NUM_XCDS}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" +echo "Benchmark Options:" | tee -a "${LOG_FILE}" +echo " validate: ${VALIDATE}" | tee -a "${LOG_FILE}" +echo " benchmark-pytorch: ${BENCHMARK_PYTORCH}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" + +# Build rocprofv3 ATT options +ROCPROF_OPTS="--att" +ROCPROF_OPTS="${ROCPROF_OPTS} --att-target-cu ${ATT_TARGET_CU}" +ROCPROF_OPTS="${ROCPROF_OPTS} --att-buffer-size ${ATT_BUFFER_SIZE}" +ROCPROF_OPTS="${ROCPROF_OPTS} --att-activity ${ATT_ACTIVITY}" + +if [[ -n "${KERNEL_REGEX}" ]]; then + ROCPROF_OPTS="${ROCPROF_OPTS} --kernel-include-regex \"${KERNEL_REGEX}\"" +fi + +# Build benchmark args +BENCH_ARGS="-m ${M_DIM} -n ${N_DIM} -k ${K_DIM} --variant ${VARIANT} --benchmark -r ${NUM_RANKS}" +BENCH_ARGS="${BENCH_ARGS} --block_size_m ${BLOCK_M} --block_size_n ${BLOCK_N} --block_size_k ${BLOCK_K}" +BENCH_ARGS="${BENCH_ARGS} --group_size_m ${GROUP_M} --num_xcds ${NUM_XCDS}" + +if [[ "${SINGLE_RUN}" == "true" ]]; then + BENCH_ARGS="${BENCH_ARGS} --single-run" +fi + +if [[ "${VALIDATE}" == "true" ]]; then + BENCH_ARGS="${BENCH_ARGS} -v" +fi + +if [[ "${BENCHMARK_PYTORCH}" == "true" ]]; then + BENCH_ARGS="${BENCH_ARGS} --benchmark_pytorch" +fi + +# Add K-contiguous layout options +# --k-contiguous: Both A and B become K-contiguous +# - A is already K-contiguous in default row-major layout +# - B needs --b_col_major to become K-contiguous +if [[ "${K_CONTIGUOUS}" == "true" ]]; then + BENCH_ARGS="${BENCH_ARGS} --b_col_major" +fi + +# Individual matrix layout overrides +if [[ "${A_COL_MAJOR}" == "true" ]]; then + BENCH_ARGS="${BENCH_ARGS} --a_col_major" +fi +if [[ "${B_COL_MAJOR}" == "true" ]]; then + BENCH_ARGS="${BENCH_ARGS} --b_col_major" +fi + +# Full command +# rocprofv3 wraps the entire torchrun command, not the other way around +# HSA_NO_SCRATCH_RECLAIM=1 prevents scratch memory reclaim issues +FULL_CMD="HSA_NO_SCRATCH_RECLAIM=1 rocprofv3 ${ROCPROF_OPTS} -d ${OUTPUT_DIR} -- torchrun --nproc_per_node=${NUM_RANKS} ${BENCHMARK_SCRIPT} ${BENCH_ARGS}" + +echo "Command:" | tee -a "${LOG_FILE}" +echo "${FULL_CMD}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" + +# Save configuration to JSON for reference +cat > "${OUTPUT_DIR}/config.json" << EOF +{ + "timestamp": "${TIMESTAMP}", + "num_ranks": ${NUM_RANKS}, + "m_dim": ${M_DIM}, + "n_dim": ${N_DIM}, + "k_dim": ${K_DIM}, + "variant": "${VARIANT}", + "att_target_cu": ${ATT_TARGET_CU}, + "att_buffer_size": "${ATT_BUFFER_SIZE}", + "att_activity": ${ATT_ACTIVITY}, + "kernel_regex": "${KERNEL_REGEX}", + "single_run": ${SINGLE_RUN}, + "k_contiguous": ${K_CONTIGUOUS}, + "a_col_major": ${A_COL_MAJOR}, + "b_col_major": ${B_COL_MAJOR}, + "block_m": ${BLOCK_M}, + "block_n": ${BLOCK_N}, + "block_k": ${BLOCK_K}, + "group_m": ${GROUP_M}, + "num_xcds": ${NUM_XCDS}, + "validate": ${VALIDATE}, + "benchmark_pytorch": ${BENCHMARK_PYTORCH}, + "command": "${FULL_CMD}" +} +EOF + +echo "Starting profiling..." | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" + +# Run the profiling command +START_TIME=$(date +%s) + +# Execute the command and capture output +eval "${FULL_CMD}" 2>&1 | tee -a "${LOG_FILE}" +EXIT_CODE=${PIPESTATUS[0]} + +END_TIME=$(date +%s) +DURATION=$((END_TIME - START_TIME)) + +echo "" | tee -a "${LOG_FILE}" +echo "==============================================" | tee -a "${LOG_FILE}" +echo "Profiling completed" | tee -a "${LOG_FILE}" +echo "Exit code: ${EXIT_CODE}" | tee -a "${LOG_FILE}" +echo "Duration: ${DURATION} seconds" | tee -a "${LOG_FILE}" +echo "End time: $(date)" | tee -a "${LOG_FILE}" +echo "==============================================" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" + +# List output files +echo "Output files:" | tee -a "${LOG_FILE}" +ls -la "${OUTPUT_DIR}" 2>&1 | tee -a "${LOG_FILE}" + +# Check for stats CSV files +if ls "${OUTPUT_DIR}"/stats_*.csv 1> /dev/null 2>&1; then + echo "" | tee -a "${LOG_FILE}" + echo "Stats CSV files found:" | tee -a "${LOG_FILE}" + ls -la "${OUTPUT_DIR}"/stats_*.csv 2>&1 | tee -a "${LOG_FILE}" +fi + +# Check for ui_output directories (ROCprof Compute Viewer compatible) +if ls -d "${OUTPUT_DIR}"/ui_output_* 1> /dev/null 2>&1; then + echo "" | tee -a "${LOG_FILE}" + echo "UI output directories (for ROCprof Compute Viewer):" | tee -a "${LOG_FILE}" + ls -d "${OUTPUT_DIR}"/ui_output_* 2>&1 | tee -a "${LOG_FILE}" +fi + +echo "" | tee -a "${LOG_FILE}" +echo "Profile output saved to: ${OUTPUT_DIR}" | tee -a "${LOG_FILE}" +echo "Log file: ${LOG_FILE}" | tee -a "${LOG_FILE}" + +exit ${EXIT_CODE} diff --git a/benchmark/ops/all_gather_matmul/test.sh b/benchmark/ops/all_gather_matmul/test.sh new file mode 100755 index 000000000..7d5ef1a98 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/test.sh @@ -0,0 +1,16 @@ +HSA_NO_SCRATCH_RECLAIM=1 \ +python3 $(pwd)/benchmark.py \ + -m 2048 \ + -n 16384 \ + -k 131072 \ + --num_ranks 8 \ + --num_xcds 8 \ + --datatype fp16 \ + --block_size_m 512 \ + --block_size_n 128 \ + --block_size_k 64 \ + --group_size_m 1 \ + --benchmark \ + --b_col_major \ + -v \ + --benchmark_pytorch \ No newline at end of file diff --git a/iris/iris.py b/iris/iris.py index 21aaddd8a..50063a55e 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1780,6 +1780,10 @@ def reduce_scatter(self, output_tensor, input_tensor, op=None, group=None, async @triton.jit def __translate(ptr, from_rank, to_rank, heap_bases): + """ + Basic pointer translation without vectorization hints. + Used for atomic operations which may receive scalar pointers. + """ from_base = tl.load(heap_bases + from_rank) to_base = tl.load(heap_bases + to_rank) # convert to int to compute difference @@ -1793,11 +1797,30 @@ def __translate(ptr, from_rank, to_rank, heap_bases): # Cast to_base back to pointer type translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) - # Vectorization hints: must be <= minimum block size used by any caller. - # (32, 32) is safe since all supported block sizes are multiples of 32. - # Largest vectorized load instruction is dwordx4 (128-bits = 8 x fp16). - # translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) - # translated_ptr = tl.max_contiguous(translated_ptr, (32, 32)) + return translated_ptr + + + +@triton.jit +def __translate_block_2d(ptr, from_rank, to_rank, heap_bases): + """ + Pointer translation for block load/store operations. + + Note: Vectorization hints should be applied in the tile_ptr computation (core.py) + where the 2D block shape is actually created, not here in the translation. + """ + from_base = tl.load(heap_bases + from_rank) + to_base = tl.load(heap_bases + to_rank) + # convert to int to compute difference + ptr_int = tl.cast(ptr, tl.uint64) + # Find the offset from from_rank heap + offset = ptr_int - from_base + # Byte cast for byte offset addition + to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) + # Find the offset into the to_rank heap + translated_ptr_byte = to_base_byte + offset + # Cast to_base back to pointer type + translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) return translated_ptr @@ -1976,9 +1999,16 @@ def initialize(context_tensor, rank, world_size, tracing: tl.constexpr = False): @triton.jit def _translate(self, ptr, from_rank, to_rank): - """Internal pointer translation between rank address spaces.""" + """Internal pointer translation between rank address spaces. + Used for atomic operations which may receive scalar pointers.""" return __translate(ptr, from_rank, to_rank, self.heap_bases) + @triton.jit + def _translate_block_2d(self, ptr, from_rank, to_rank): + """Internal pointer translation with 2D vectorization hints. + Used for block load/store operations with 2D block pointers.""" + return __translate_block_2d(ptr, from_rank, to_rank, self.heap_bases) + @triton.jit def load(self, pointer, from_rank, mask=None): """ @@ -2000,7 +2030,7 @@ def load(self, pointer, from_rank, mask=None): Example: >>> data = ctx.load(buffer + offsets, from_rank=1, mask=mask) """ - translated_ptr = self._translate(pointer, self.rank, from_rank) + translated_ptr = self._translate_block_2d(pointer, self.rank, from_rank) result = tl.load(translated_ptr, mask=mask) return result @@ -2026,7 +2056,7 @@ def store(self, pointer, value, to_rank, mask=None): Example: >>> ctx.store(buffer + offsets, values, to_rank=1, mask=mask) """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate_block_2d(pointer, self.rank, to_rank) tl.store(translated_ptr, value, mask=mask) @triton.jit @@ -2356,6 +2386,9 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): data from the target memory location. If the `from_rank` and `to_rank` are the same, this function performs a local load operation. + This function uses 2D vectorization hints for optimal performance with block pointers. + Minimum block size in each dimension should be >= 16. + Args: pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. to_rank (int): The rank ID to which the pointer will be translated. Must be the current rank where the pointer is local. @@ -2375,7 +2408,7 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): >>> data = iris.load(ptr, cur_rank, remote_rank, heap_bases) >>> return data """ - translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) + translated_ptr = __translate_block_2d(pointer, to_rank, from_rank, heap_bases) result = tl.load(translated_ptr, mask=mask) return result @@ -2390,6 +2423,9 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): the provided data to the target memory location. If the `from_rank` and `to_rank` are the same, this function performs a local store operation. + This function uses 2D vectorization hints for optimal performance with block pointers. + Minimum block size in each dimension should be >= 16. + Args: pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. value (Block): The tensor of elements to be stored. @@ -2410,7 +2446,7 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): >>> value = 42 >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate_block_2d(pointer, from_rank, to_rank, heap_bases) tl.store(translated_ptr, value, mask=mask) diff --git a/iris/ops/all_gather_matmul.py b/iris/ops/all_gather_matmul.py index 0dad98aee..ed4d72b8a 100644 --- a/iris/ops/all_gather_matmul.py +++ b/iris/ops/all_gather_matmul.py @@ -295,6 +295,7 @@ def all_gather_matmul( use_bias, even_k, config.allow_tf32, + matrix_instr_nonkdim=16, ) if not async_op: diff --git a/iris/x/core.py b/iris/x/core.py index fee50918e..58786e79e 100644 --- a/iris/x/core.py +++ b/iris/x/core.py @@ -80,7 +80,10 @@ def tile_ptr(ptr, M, N, stride_m, stride_n, pid_m, pid_n, BLOCK_SIZE_M: tl.const rm, rn, mask = tile_layout(pid_m, pid_n, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N) offset = rm[:, None] * stride_m + rn[None, :] * stride_n tile_ptr = ptr + offset - tile_ptr = tl.multiple_of(tile_ptr, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + # NOTE: Vectorization hints are applied at the call site (e.g., gather.py) + # rather than here, because the caller knows the block dimensions. + # Alignment IS preserved through pointer translation since symmetric heaps + # are all page-aligned, so relative offsets within the heap are maintained. return tile_ptr, mask diff --git a/iris/x/gather.py b/iris/x/gather.py index 51f489a03..d94e85a93 100644 --- a/iris/x/gather.py +++ b/iris/x/gather.py @@ -51,16 +51,25 @@ def gather( src_tile_ptr, mask = src_view.tile_ptr(tile) if source_rank == ctx.rank: - # Local load - tile_data = tl.load(src_tile_ptr, mask=mask) + # Local load - can use vectorization hints since alignment is guaranteed + local_ptr = tl.multiple_of(src_tile_ptr, (1, tile.block_n)) + local_ptr = tl.max_contiguous(local_ptr, (1, tile.block_n)) + tile_data = tl.load(local_ptr, mask=mask) else: - # Remote load using RMA - tile_data = iris.load( - src_tile_ptr, - ctx.rank, # to_rank (current rank) - source_rank, # from_rank (source rank) - ctx.heap_bases, - mask=mask, - ) + # Remote load using RMA - inline translation and apply hints AFTER translation + # Hints must be applied to the translated pointer because pointer arithmetic + # (cast to uint64, subtract, add, cast back) destroys hint metadata. + # Alignment IS preserved because symmetric heaps are all page-aligned. + from_base = tl.load(ctx.heap_bases + ctx.rank) + to_base = tl.load(ctx.heap_bases + source_rank) + ptr_int = tl.cast(src_tile_ptr, tl.uint64) + offset = ptr_int - from_base + to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) + translated_ptr_byte = to_base_byte + offset + translated_ptr = tl.cast(translated_ptr_byte, src_tile_ptr.dtype) + # Apply vectorization hints AFTER translation + translated_ptr = tl.multiple_of(translated_ptr, (1, tile.block_n)) + translated_ptr = tl.max_contiguous(translated_ptr, (1, tile.block_n)) + tile_data = tl.load(translated_ptr, mask=mask) return tile_data From 5b022114ca303e354308b499b868f639dd6d8498 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 11 Feb 2026 17:02:36 +0000 Subject: [PATCH 07/31] Apply Ruff auto-fixes --- iris/iris.py | 1 - iris/ops/all_gather_matmul.py | 66 +++++++++++++++++------------------ iris/x/gather.py | 1 - 3 files changed, 32 insertions(+), 36 deletions(-) diff --git a/iris/iris.py b/iris/iris.py index 50063a55e..94cd0ae6e 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1800,7 +1800,6 @@ def __translate(ptr, from_rank, to_rank, heap_bases): return translated_ptr - @triton.jit def __translate_block_2d(ptr, from_rank, to_rank, heap_bases): """ diff --git a/iris/ops/all_gather_matmul.py b/iris/ops/all_gather_matmul.py index ed4d72b8a..e72d0ef68 100644 --- a/iris/ops/all_gather_matmul.py +++ b/iris/ops/all_gather_matmul.py @@ -17,7 +17,6 @@ import iris.x from tritonblas.kernels.stages import GemmContext, ScheduleContext -from tritonblas.kernels.stages.indexing.pid_transforms import chiplet_transform_chunked from .config import FusedConfig from .workspace import FusedWorkspace @@ -190,9 +189,7 @@ def all_gather_matmul_preamble( num_m_tiles = (M + config.block_size_m - 1) // config.block_size_m num_k_tiles = (K_local + config.block_size_k - 1) // config.block_size_k ws.a_inbox = shmem.zeros((world_size, M, K_local), dtype=A_sharded.dtype) - ws.signal_flags = shmem.zeros( - (world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32 - ) + ws.signal_flags = shmem.zeros((world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32) shmem.barrier() return ws @@ -267,36 +264,37 @@ def all_gather_matmul( num_tiles = num_tiles_m * num_tiles_n # grid = (num_tiles,) grid = (num_sms,) - _fused_all_gather_matmul_kernel[grid](A_sharded, - B, - output_tensor, - bias_ptr, - M, - N, - K, - K_local, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_bias, - shmem.get_device_context(), - rank, - world_size, - config.block_size_m, - config.block_size_n, - config.block_size_k, - config.group_size_m, - num_sms, - config.num_xcds, - num_k_blocks_local, - use_bias, - even_k, - config.allow_tf32, - matrix_instr_nonkdim=16, - ) + _fused_all_gather_matmul_kernel[grid]( + A_sharded, + B, + output_tensor, + bias_ptr, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + shmem.get_device_context(), + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_sms, + config.num_xcds, + num_k_blocks_local, + use_bias, + even_k, + config.allow_tf32, + matrix_instr_nonkdim=16, + ) if not async_op: shmem.barrier() diff --git a/iris/x/gather.py b/iris/x/gather.py index d94e85a93..bb3fb637a 100644 --- a/iris/x/gather.py +++ b/iris/x/gather.py @@ -13,7 +13,6 @@ import triton import triton.language as tl -import iris from iris.iris import DeviceContext from .core import Tile, TensorView From 4c3b3f429e7abde4f0c5f37dca787e027841445c Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Wed, 11 Feb 2026 14:17:51 -0500 Subject: [PATCH 08/31] Add HBM buffered version --- .../all_gather_matmul/benchmark_hbm_buffer.py | 334 ++++++++++++++++ iris/ops/all_gather_matmul_hbm_buffer.py | 366 ++++++++++++++++++ 2 files changed, 700 insertions(+) create mode 100644 benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py create mode 100644 iris/ops/all_gather_matmul_hbm_buffer.py diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py new file mode 100644 index 000000000..8a2dbae21 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for the HBM-buffered all_gather_matmul variant. + +This variant cooperatively gathers A into a local HBM buffer with per-tile +ready flags, then runs GEMM from local memory. No global barriers -- CUs +that finish gathering early start GEMM immediately, spinning on flags for +any tile not yet available. + +Usage with torchrun: + torchrun --nproc_per_node=8 benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py \\ + -m 2048 -n 16384 -k 131072 --benchmark + + torchrun --nproc_per_node=8 benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py \\ + -m 2048 -n 16384 -k 131072 --benchmark --benchmark_pytorch --b_col_major +""" + +import os +import time +import torch +import torch.distributed as dist +import random +import argparse + +import iris +from iris.ops.all_gather_matmul_hbm_buffer import ( + all_gather_matmul_hbm_buffer, + all_gather_matmul_hbm_buffer_preamble, +) +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark HBM-buffered all_gather_matmul (per-tile flags).", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=2048, help="M dimension") + parser.add_argument("-n", type=int, default=16384, help="N dimension") + parser.add_argument("-k", type=int, default=131072, help="K dimension (total)") + parser.add_argument("-v", "--validate", action="store_true", help="Validate correctness") + parser.add_argument("-b", "--benchmark", action="store_true", help="Run benchmark") + parser.add_argument( + "--datatype", type=str, default="fp16", + choices=["fp16", "fp32", "bf16"], help="Tensor datatype", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs (auto if None)") + parser.add_argument( + "--benchmark_pytorch", action="store_true", + help="Also benchmark PyTorch (all_gather_into_tensor + matmul)", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size M") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size N") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size K") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size M") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto if None)") + parser.add_argument("--b_col_major", action="store_true", help="B col-major (K-contiguous)") + parser.add_argument("--a_col_major", action="store_true", help="A col-major (M-contiguous)") + parser.add_argument("--single-run", action="store_true", help="1 iteration (for profiling)") + return vars(parser.parse_args()) + + +def _worker(args): + """Worker function for torchrun.""" + local_rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0))) + world_size_env = int(os.environ.get("WORLD_SIZE", 1)) + + backend = "nccl" if torch.cuda.is_available() else "gloo" + + if "RANK" in os.environ or "LOCAL_RANK" in os.environ: + dist.init_process_group( + backend=backend, init_method="env://", + device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None, + ) + else: + dist.init_process_group( + backend=backend, init_method="tcp://127.0.0.1:29530", + world_size=world_size_env, rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None, + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + datatype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + datatype = datatype_map.get(args["datatype"], torch.float16) + + M = args["m"] + N = args["n"] + K = args["k"] + K_local = K // world_size + + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + config = FusedConfig(**config_kwargs) + + buffer_mb = M * K * torch.tensor([], dtype=datatype).element_size() / (1024 ** 2) + num_m_tiles = M // config.block_size_m + num_k_blocks = K // config.block_size_k + shmem.info( + f"HBM-Buffer variant: M={M} N={N} K={K} K_local={K_local} " + f"block=({config.block_size_m},{config.block_size_n},{config.block_size_k}) " + f"buffer={buffer_mb:.0f}MB flags={num_m_tiles}x{num_k_blocks}" + ) + + # ── Allocate tensors ───────────────────────────────────────────────── + C = shmem.zeros((M, N), dtype=datatype) + + if args["a_col_major"]: + A_storage = shmem.zeros((K_local, M), dtype=datatype) + A_sharded = A_storage.T + else: + A_sharded = shmem.zeros((M, K_local), dtype=datatype) + + if args["b_col_major"]: + B_storage = shmem.zeros((N, K), dtype=datatype) + B = B_storage.T + else: + B = shmem.zeros((K, N), dtype=datatype) + + shmem.info(f"A strides={A_sharded.stride()}, B strides={B.stride()}") + + # Fill + torch.manual_seed(123 + rank) + A_data = torch.randn((M, K_local), dtype=datatype, device=f"cuda:{rank}") + A_sharded.copy_(A_data) + + torch.manual_seed(456) + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # Expected + expected_tensor = None + if args["validate"]: + A_list = [torch.zeros((M, K_local), dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_list, A_data) + A_gathered = torch.cat(A_list, dim=1) + expected_tensor = shmem.zeros((M, N), dtype=datatype) + expected_tensor.copy_(torch.matmul(A_gathered, B_data)) + + # Pre-allocate workspace + workspace = all_gather_matmul_hbm_buffer_preamble(shmem, A_sharded, B, config) + + # ── Timing ─────────────────────────────────────────────────────────── + comm_stream = torch.cuda.Stream() + start_ev = torch.cuda.Event(enable_timing=True) + end_ev = torch.cuda.Event(enable_timing=True) + total_ms = 0.0 + num_experiments = 0 + + def run_experiment(): + nonlocal total_ms, num_experiments + shmem.barrier() + with torch.cuda.stream(comm_stream): + start_ev.record() + all_gather_matmul_hbm_buffer( + shmem, C, A_sharded, B, + config=config, async_op=False, workspace=workspace, + ) + end_ev.record() + num_experiments += 1 + shmem.barrier() + total_ms += start_ev.elapsed_time(end_ev) + + shmem.barrier() + + # ── Validate ───────────────────────────────────────────────────────── + if args["validate"]: + shmem.info("Validating...") + C.zero_() + shmem.barrier() + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-1 if datatype == torch.float16 else 1e-3 + success = torch.allclose(C, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation FAILED, max diff: {max_diff}") + else: + shmem.info("Validation PASSED!") + shmem.barrier() + + # ── Benchmark ──────────────────────────────────────────────────────── + if args["benchmark"]: + if args.get("single_run"): + n_warmup, n_repeat = 0, 1 + else: + n_warmup, n_repeat = 25, 100 + + # Warmup + total_ms = 0.0 + num_experiments = 0 + if n_warmup > 0: + iris.do_bench(run_experiment, shmem.barrier, n_warmup=n_warmup, n_repeat=1) + + total_ms = 0.0 + num_experiments = 0 + C.zero_() + shmem.barrier() + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=0, n_repeat=n_repeat) + avg_ms = total_ms / num_experiments if num_experiments > 0 else 0 + + total_flops = 2 * M * N * K + tflops = (total_flops * 1e-12) / (avg_ms * 1e-3) if avg_ms > 0 else 0 + element_size = torch.tensor([], dtype=datatype).element_size() + total_bytes = M * K_local * element_size * (world_size - 1) + bw_gbps = (total_bytes / (1024 ** 3)) / (avg_ms * 1e-3) if avg_ms > 0 else 0 + + shmem.info( + f"HBM-Buffer (M={M}, K_local={K_local}, K={K}, N={N}, " + f"ws={world_size}, dtype={args['datatype']}): " + f"{avg_ms:.3f} ms, {tflops:.3f} TFLOPS, {bw_gbps:.3f} GB/s" + ) + shmem.barrier() + + # ── Per-rank finish time measurement ───────────────────────────── + # Run a single iteration and record wall-clock finish time per rank + # to see if ranks complete at different times (load imbalance). + shmem.barrier() + torch.cuda.synchronize() + dist.barrier() + + # Synchronized start + dist.barrier() + t_start = time.perf_counter() + + all_gather_matmul_hbm_buffer( + shmem, C, A_sharded, B, + config=config, async_op=False, workspace=workspace, + ) + torch.cuda.synchronize() + t_end = time.perf_counter() + + finish_ms = (t_end - t_start) * 1000.0 + + # Gather all finish times to rank 0 for display + finish_tensor = torch.tensor([finish_ms], dtype=torch.float64, device=f"cuda:{rank}") + all_finish = [torch.zeros(1, dtype=torch.float64, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(all_finish, finish_tensor) + + if rank == 0: + times = [t.item() for t in all_finish] + min_t = min(times) + max_t = max(times) + print(f"\n Per-rank finish times (single run):") + print(f" {'Rank':>6} {'Finish ms':>10} {'Delta ms':>10}") + print(f" {'-' * 30}") + for r, t in enumerate(times): + delta = t - min_t + print(f" {r:>6} {t:>10.3f} {delta:>+10.3f}") + print(f" {'-' * 30}") + print(f" Spread (max - min): {max_t - min_t:.3f} ms") + print() + + shmem.barrier() + + # ── PyTorch baseline ───────────────────────────────────────────────── + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (all_gather_into_tensor + matmul)...") + + pt_A = torch.randn(M, K_local, dtype=datatype, device=f"cuda:{rank}") + pt_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pt_Ag = torch.zeros(M, K, dtype=datatype, device=f"cuda:{rank}") + + for _ in range(10): + dist.all_gather_into_tensor(pt_Ag, pt_A) + _ = torch.matmul(pt_Ag, pt_B) + torch.cuda.synchronize() + dist.barrier() + + def run_pt(): + dist.all_gather_into_tensor(pt_Ag, pt_A) + _ = torch.matmul(pt_Ag, pt_B) + + total_flops = 2 * M * N * K + element_size = torch.tensor([], dtype=datatype).element_size() + total_bytes = M * K_local * element_size * (world_size - 1) + + pt_ms = iris.do_bench(run_pt, dist.barrier) + pt_tflops = (total_flops * 1e-12) / (pt_ms * 1e-3) if pt_ms > 0 else 0 + pt_bw = (total_bytes / (1024 ** 3)) / (pt_ms * 1e-3) if pt_ms > 0 else 0 + + shmem.info( + f"PyTorch (M={M}, K_local={K_local}, K={K}, N={N}, ws={world_size}, " + f"dtype={args['datatype']}): " + f"{pt_ms:.3f} ms, {pt_tflops:.3f} TFLOPS, {pt_bw:.3f} GB/s" + ) + + if args["benchmark"]: + avg_ms = total_ms / num_experiments if num_experiments > 0 else 0 + iris_tflops = (total_flops * 1e-12) / (avg_ms * 1e-3) if avg_ms > 0 else 0 + speedup = iris_tflops / pt_tflops if pt_tflops > 0 else 0 + shmem.info(f"Speedup (HBM-Buffer / PyTorch): {speedup:.2f}x") + + shmem.barrier() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + print("Starting HBM-buffer all_gather_matmul benchmark...") + args = parse_args() + if "RANK" in os.environ or "LOCAL_RANK" in os.environ: + _worker(args) + else: + print( + "Please run with torchrun:\n" + " torchrun --nproc_per_node=N " + "benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py [OPTIONS]" + ) + + +if __name__ == "__main__": + main() diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py new file mode 100644 index 000000000..a0233b6bb --- /dev/null +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -0,0 +1,366 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Fused All-Gather + GEMM using a local HBM staging buffer with per-tile flags. + +Each rank has a column-sharded input A_sharded (M x K_local). +This operation computes C = all_gather(A_sharded) @ B by: + 1. All CUs cooperate to gather A into a local HBM buffer, setting a ready + flag for each (m_tile, k_block) as it lands. + 2. Each CU then runs GEMM from the local buffer. Before consuming a tile, + it checks the ready flag; if not yet set, it spins until the gathering + CU writes it. + +No global barriers are needed. The per-tile flags provide fine-grained +producer-consumer synchronization: a CU that finishes gathering early can +start GEMM immediately, consuming any tile whose flag is already set. +""" + +from typing import Optional +import torch +import triton +import triton.language as tl +import iris +import iris.x + +from .config import FusedConfig +from .workspace import FusedWorkspace + + +# ========================================================================== +# Kernel +# ========================================================================== + + +@triton.jit +def _hbm_buffer_all_gather_matmul_kernel( + A_sharded, + B, + C, + bias_ptr, + staged_a, # Local HBM buffer: (M, K) fp16 + flags_ptr, # int32[NUM_M_TILES * NUM_K_BLOCKS] per-tile ready flags + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + context_tensor: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + NUM_M_TILES: tl.constexpr, + NUM_K_BLOCKS: tl.constexpr, # K // BLOCK_SIZE_K (global) + NUM_K_BLOCKS_LOCAL: tl.constexpr, # K_local // BLOCK_SIZE_K + BIAS: tl.constexpr, + ALLOW_TF32: tl.constexpr, +): + """ + HBM-buffered all-gather + GEMM with per-tile ready flags. + + Each CU executes two phases back-to-back (no global barrier): + + Phase 1 (gather): The CU is assigned a slice of the (m_tile, src_rank, + k_block_local) gather work. For each assigned tile it pulls from remote + via iris.x.gather, writes to staged_a, and atomically sets the ready + flag. Local rank tiles are copied via a fast local load. + + Phase 2 (GEMM): The CU iterates over its assigned output tiles + (pid_m, pid_n). For each K-block in the accumulation loop it checks the + ready flag; if not yet set, it spins until the producing CU posts it. + A tiles are loaded from staged_a (local HBM) and B tiles from B. + """ + pid = tl.program_id(0) + + # XCD-aware PID remapping + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + + acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 + + # DeviceContext and TensorView for gather + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) + src_view = iris.x.make_tensor_view(A_sharded, M, K_local, stride_am, stride_ak) + + # ================================================================== + # Phase 1: Cooperative gather into staged_a, set per-tile flags + # ================================================================== + # Total gather work = NUM_M_TILES * world_size * NUM_K_BLOCKS_LOCAL + # Each tile is BLOCK_SIZE_M x BLOCK_SIZE_K elements. + total_gather_tiles = NUM_M_TILES * world_size * NUM_K_BLOCKS_LOCAL + + for gather_idx in range(pid, total_gather_tiles, NUM_SMS): + # Decompose flat index -> (m_tile, src_rank_idx, k_block_local) + m_tile = gather_idx // (world_size * NUM_K_BLOCKS_LOCAL) + remainder = gather_idx % (world_size * NUM_K_BLOCKS_LOCAL) + src_rank_idx = remainder // NUM_K_BLOCKS_LOCAL + k_block_local = remainder % NUM_K_BLOCKS_LOCAL + + # Global k-block index in the full K dimension + k_block_global = src_rank_idx * NUM_K_BLOCKS_LOCAL + k_block_local + + # Gather the tile from the source rank, store to buffer, set flag. + # source_rank must be constexpr for iris.x.gather, so we iterate + # over all ranks at compile time and select at runtime. + # The store and flag-set are inside the branch so that a_tile is + # always defined when used. + zero = tl.program_id(0) * 0 + pid_m_t = zero + m_tile + tile_k_t = zero + k_block_local + k_tile = iris.x.TileView(pid_m_t, tile_k_t, BLOCK_SIZE_M, BLOCK_SIZE_K) + + rm = m_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rk = k_block_global * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + staged_ptrs = staged_a + rm[:, None] * K + rk[None, :] + flag_idx = m_tile * NUM_K_BLOCKS + k_block_global + + for compile_rank in range(world_size): + if src_rank_idx == compile_rank: + a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx) + tl.store(staged_ptrs, a_tile) + tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") + + # ================================================================== + # Phase 2: GEMM from staged_a (local) x B, checking flags + # ================================================================== + num_tiles_n = tl.cdiv(N, BLOCK_SIZE_N) + total_gemm_tiles = NUM_M_TILES * num_tiles_n + + for gemm_tile_id in range(pid, total_gemm_tiles, NUM_SMS): + # Tile scheduling with swizzle (GROUP_SIZE_M grouping) + num_pid_in_group = GROUP_SIZE_M * num_tiles_n + group_id = gemm_tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_sz = min(NUM_M_TILES - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((gemm_tile_id % num_pid_in_group) % group_sz) + pid_n = (gemm_tile_id % num_pid_in_group) // group_sz + + # Row / column indices + rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Initialize accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + # K-reduction loop + for k_block in range(NUM_K_BLOCKS): + # Wait for the (pid_m, k_block) tile to be ready. + # acquire semantics ensure subsequent loads see the stored data. + flag_idx = pid_m * NUM_K_BLOCKS + k_block + while tl.atomic_add(flags_ptr + flag_idx, 0, sem="acquire", scope="gpu") == 0: + pass + + # Load A from staged_a (purely local HBM) + rk = k_block * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) + a_ptrs = staged_a + rm[:, None] * K + rk[None, :] + a = tl.load(a_ptrs) + + # Load B + B_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(B_ptrs) + + # Accumulate + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + + # Add bias if provided + if BIAS: + bias_val = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0) + acc = acc + bias_val[:, None] + + # Convert to output dtype and store + c = acc.to(C.type.element_ty) + C_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm[:, None] < M) & (rn[None, :] < N) + tl.store(C_ptrs, c, mask=mask) + + +# ========================================================================== +# Python API +# ========================================================================== + + +def all_gather_matmul_hbm_buffer_preamble( + shmem, + A_sharded: torch.Tensor, + B: torch.Tensor, + config: Optional[FusedConfig] = None, +) -> FusedWorkspace: + """ + Allocate workspace for the HBM-buffered all_gather_matmul. + + Allocates: + - staged_a: (M, K) local HBM buffer for the gathered A matrix. + - flags: int32[num_m_tiles * num_k_blocks] per-tile ready flags. + """ + if config is None: + config = FusedConfig() + + M, K_local = A_sharded.shape + K, N = B.shape + world_size = shmem.get_num_ranks() + + expected_K = world_size * K_local + assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" + assert K_local % config.block_size_k == 0, ( + f"K_local ({K_local}) must be divisible by block_size_k ({config.block_size_k})" + ) + assert K % config.block_size_k == 0, ( + f"K ({K}) must be divisible by block_size_k ({config.block_size_k})" + ) + assert M % config.block_size_m == 0, ( + f"M ({M}) must be divisible by block_size_m ({config.block_size_m})" + ) + + num_m_tiles = M // config.block_size_m + num_k_blocks = K // config.block_size_k + + ws = FusedWorkspace( + operation="all_gather_matmul_hbm_buffer", + shape=(M, N, K), + dtype=A_sharded.dtype, + world_size=world_size, + variant="hbm_buffer", + prepared=True, + ) + + # (M, K) staging buffer in local HBM + ws.aux_buffer = shmem.zeros((M, K), dtype=A_sharded.dtype) + # Per-tile ready flags + ws.locks = shmem.zeros((num_m_tiles * num_k_blocks,), dtype=torch.int32) + + buffer_mb = M * K * A_sharded.element_size() / (1024 ** 2) + shmem.info(f"HBM buffer workspace: staged_a=({M},{K}) [{buffer_mb:.1f} MB], " + f"flags=[{num_m_tiles}x{num_k_blocks}={num_m_tiles * num_k_blocks}]") + + shmem.barrier() + return ws + + +def all_gather_matmul_hbm_buffer( + shmem, + output_tensor: torch.Tensor, + A_sharded: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor] = None, + async_op: bool = False, + config: Optional[FusedConfig] = None, + workspace: Optional[FusedWorkspace] = None, +) -> FusedWorkspace: + """ + All-gather + matmul using a local HBM staging buffer with per-tile flags. + + Computes C = all_gather(A_sharded) @ B + bias. + + Each CU first gathers its assigned slice of A tiles into the local buffer + (setting per-tile ready flags), then runs GEMM from the buffer, spinning + on flags for any tile not yet available. + """ + if config is None: + config = FusedConfig() + + M, K_local = A_sharded.shape + K, N = B.shape + world_size = shmem.get_num_ranks() + rank = shmem.get_rank() + + expected_K = world_size * K_local + assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" + assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" + assert M % config.block_size_m == 0, ( + f"M ({M}) must be divisible by block_size_m ({config.block_size_m})" + ) + assert K % config.block_size_k == 0, ( + f"K ({K}) must be divisible by block_size_k ({config.block_size_k})" + ) + assert K_local % config.block_size_k == 0, ( + f"K_local ({K_local}) must be divisible by block_size_k ({config.block_size_k})" + ) + + if workspace is None: + workspace = all_gather_matmul_hbm_buffer_preamble(shmem, A_sharded, B, config) + + # Reset flags to 0 before each launch + workspace.locks.zero_() + + stride_am, stride_ak = A_sharded.stride() + stride_bk, stride_bn = B.stride() + stride_cm, stride_cn = output_tensor.stride() + + if bias is not None: + assert bias.shape[0] == M + bias_ptr = bias + stride_bias = bias.stride()[0] if bias.dim() > 0 else 1 + use_bias = True + else: + bias_ptr = output_tensor # dummy, won't be read + stride_bias = 1 + use_bias = False + + device = A_sharded.device + num_sms = config.num_sms + if num_sms is None: + props = torch.cuda.get_device_properties(device) + num_sms = props.multi_processor_count + + num_m_tiles = M // config.block_size_m + num_k_blocks = K // config.block_size_k + num_k_blocks_local = K_local // config.block_size_k + + grid = (num_sms,) + _hbm_buffer_all_gather_matmul_kernel[grid]( + A_sharded, + B, + output_tensor, + bias_ptr, + workspace.aux_buffer, # staged_a + workspace.locks, # flags + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + shmem.get_device_context(), + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_sms, + config.num_xcds, + num_m_tiles, + num_k_blocks, + num_k_blocks_local, + use_bias, + config.allow_tf32, + matrix_instr_nonkdim=16, + ) + + if not async_op: + shmem.barrier() + + return workspace From 1f3b9ef87b218e6b405d317fd1194fded4099c20 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 11 Feb 2026 19:18:56 +0000 Subject: [PATCH 09/31] Apply Ruff auto-fixes --- .../all_gather_matmul/benchmark_hbm_buffer.py | 45 +++++++++++++------ iris/ops/all_gather_matmul_hbm_buffer.py | 36 +++++++-------- 2 files changed, 46 insertions(+), 35 deletions(-) diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py index 8a2dbae21..0529ebb46 100644 --- a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -47,13 +47,17 @@ def parse_args(): parser.add_argument("-v", "--validate", action="store_true", help="Validate correctness") parser.add_argument("-b", "--benchmark", action="store_true", help="Run benchmark") parser.add_argument( - "--datatype", type=str, default="fp16", - choices=["fp16", "fp32", "bf16"], help="Tensor datatype", + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Tensor datatype", ) parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs (auto if None)") parser.add_argument( - "--benchmark_pytorch", action="store_true", + "--benchmark_pytorch", + action="store_true", help="Also benchmark PyTorch (all_gather_into_tensor + matmul)", ) parser.add_argument("--block_size_m", type=int, default=256, help="Block size M") @@ -76,13 +80,16 @@ def _worker(args): if "RANK" in os.environ or "LOCAL_RANK" in os.environ: dist.init_process_group( - backend=backend, init_method="env://", + backend=backend, + init_method="env://", device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None, ) else: dist.init_process_group( - backend=backend, init_method="tcp://127.0.0.1:29530", - world_size=world_size_env, rank=local_rank, + backend=backend, + init_method="tcp://127.0.0.1:29530", + world_size=world_size_env, + rank=local_rank, device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None, ) @@ -110,7 +117,7 @@ def _worker(args): config_kwargs["num_xcds"] = args["num_xcds"] config = FusedConfig(**config_kwargs) - buffer_mb = M * K * torch.tensor([], dtype=datatype).element_size() / (1024 ** 2) + buffer_mb = M * K * torch.tensor([], dtype=datatype).element_size() / (1024**2) num_m_tiles = M // config.block_size_m num_k_blocks = K // config.block_size_k shmem.info( @@ -170,8 +177,13 @@ def run_experiment(): with torch.cuda.stream(comm_stream): start_ev.record() all_gather_matmul_hbm_buffer( - shmem, C, A_sharded, B, - config=config, async_op=False, workspace=workspace, + shmem, + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, ) end_ev.record() num_experiments += 1 @@ -223,7 +235,7 @@ def run_experiment(): tflops = (total_flops * 1e-12) / (avg_ms * 1e-3) if avg_ms > 0 else 0 element_size = torch.tensor([], dtype=datatype).element_size() total_bytes = M * K_local * element_size * (world_size - 1) - bw_gbps = (total_bytes / (1024 ** 3)) / (avg_ms * 1e-3) if avg_ms > 0 else 0 + bw_gbps = (total_bytes / (1024**3)) / (avg_ms * 1e-3) if avg_ms > 0 else 0 shmem.info( f"HBM-Buffer (M={M}, K_local={K_local}, K={K}, N={N}, " @@ -244,8 +256,13 @@ def run_experiment(): t_start = time.perf_counter() all_gather_matmul_hbm_buffer( - shmem, C, A_sharded, B, - config=config, async_op=False, workspace=workspace, + shmem, + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, ) torch.cuda.synchronize() t_end = time.perf_counter() @@ -261,7 +278,7 @@ def run_experiment(): times = [t.item() for t in all_finish] min_t = min(times) max_t = max(times) - print(f"\n Per-rank finish times (single run):") + print("\n Per-rank finish times (single run):") print(f" {'Rank':>6} {'Finish ms':>10} {'Delta ms':>10}") print(f" {'-' * 30}") for r, t in enumerate(times): @@ -297,7 +314,7 @@ def run_pt(): pt_ms = iris.do_bench(run_pt, dist.barrier) pt_tflops = (total_flops * 1e-12) / (pt_ms * 1e-3) if pt_ms > 0 else 0 - pt_bw = (total_bytes / (1024 ** 3)) / (pt_ms * 1e-3) if pt_ms > 0 else 0 + pt_bw = (total_bytes / (1024**3)) / (pt_ms * 1e-3) if pt_ms > 0 else 0 shmem.info( f"PyTorch (M={M}, K_local={K_local}, K={K}, N={N}, ws={world_size}, " diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index a0233b6bb..daeec0e1b 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -39,8 +39,8 @@ def _hbm_buffer_all_gather_matmul_kernel( B, C, bias_ptr, - staged_a, # Local HBM buffer: (M, K) fp16 - flags_ptr, # int32[NUM_M_TILES * NUM_K_BLOCKS] per-tile ready flags + staged_a, # Local HBM buffer: (M, K) fp16 + flags_ptr, # int32[NUM_M_TILES * NUM_K_BLOCKS] per-tile ready flags M, N, K, @@ -62,8 +62,8 @@ def _hbm_buffer_all_gather_matmul_kernel( NUM_SMS: tl.constexpr, NUM_XCDS: tl.constexpr, NUM_M_TILES: tl.constexpr, - NUM_K_BLOCKS: tl.constexpr, # K // BLOCK_SIZE_K (global) - NUM_K_BLOCKS_LOCAL: tl.constexpr, # K_local // BLOCK_SIZE_K + NUM_K_BLOCKS: tl.constexpr, # K // BLOCK_SIZE_K (global) + NUM_K_BLOCKS_LOCAL: tl.constexpr, # K_local // BLOCK_SIZE_K BIAS: tl.constexpr, ALLOW_TF32: tl.constexpr, ): @@ -222,12 +222,8 @@ def all_gather_matmul_hbm_buffer_preamble( assert K_local % config.block_size_k == 0, ( f"K_local ({K_local}) must be divisible by block_size_k ({config.block_size_k})" ) - assert K % config.block_size_k == 0, ( - f"K ({K}) must be divisible by block_size_k ({config.block_size_k})" - ) - assert M % config.block_size_m == 0, ( - f"M ({M}) must be divisible by block_size_m ({config.block_size_m})" - ) + assert K % config.block_size_k == 0, f"K ({K}) must be divisible by block_size_k ({config.block_size_k})" + assert M % config.block_size_m == 0, f"M ({M}) must be divisible by block_size_m ({config.block_size_m})" num_m_tiles = M // config.block_size_m num_k_blocks = K // config.block_size_k @@ -246,9 +242,11 @@ def all_gather_matmul_hbm_buffer_preamble( # Per-tile ready flags ws.locks = shmem.zeros((num_m_tiles * num_k_blocks,), dtype=torch.int32) - buffer_mb = M * K * A_sharded.element_size() / (1024 ** 2) - shmem.info(f"HBM buffer workspace: staged_a=({M},{K}) [{buffer_mb:.1f} MB], " - f"flags=[{num_m_tiles}x{num_k_blocks}={num_m_tiles * num_k_blocks}]") + buffer_mb = M * K * A_sharded.element_size() / (1024**2) + shmem.info( + f"HBM buffer workspace: staged_a=({M},{K}) [{buffer_mb:.1f} MB], " + f"flags=[{num_m_tiles}x{num_k_blocks}={num_m_tiles * num_k_blocks}]" + ) shmem.barrier() return ws @@ -284,12 +282,8 @@ def all_gather_matmul_hbm_buffer( expected_K = world_size * K_local assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" - assert M % config.block_size_m == 0, ( - f"M ({M}) must be divisible by block_size_m ({config.block_size_m})" - ) - assert K % config.block_size_k == 0, ( - f"K ({K}) must be divisible by block_size_k ({config.block_size_k})" - ) + assert M % config.block_size_m == 0, f"M ({M}) must be divisible by block_size_m ({config.block_size_m})" + assert K % config.block_size_k == 0, f"K ({K}) must be divisible by block_size_k ({config.block_size_k})" assert K_local % config.block_size_k == 0, ( f"K_local ({K_local}) must be divisible by block_size_k ({config.block_size_k})" ) @@ -330,8 +324,8 @@ def all_gather_matmul_hbm_buffer( B, output_tensor, bias_ptr, - workspace.aux_buffer, # staged_a - workspace.locks, # flags + workspace.aux_buffer, # staged_a + workspace.locks, # flags M, N, K, From 45288ff39a32924339707f48180b9c27c0ec1bef Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Thu, 12 Feb 2026 21:16:06 -0500 Subject: [PATCH 10/31] Use workgroup specialized variant --- iris/ops/all_gather_matmul_hbm_buffer.py | 340 +++++++++++------------ 1 file changed, 164 insertions(+), 176 deletions(-) diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index daeec0e1b..936a9a9a4 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -2,19 +2,11 @@ # Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. """ -Fused All-Gather + GEMM using a local HBM staging buffer with per-tile flags. - -Each rank has a column-sharded input A_sharded (M x K_local). -This operation computes C = all_gather(A_sharded) @ B by: - 1. All CUs cooperate to gather A into a local HBM buffer, setting a ready - flag for each (m_tile, k_block) as it lands. - 2. Each CU then runs GEMM from the local buffer. Before consuming a tile, - it checks the ready flag; if not yet set, it spins until the gathering - CU writes it. - -No global barriers are needed. The per-tile flags provide fine-grained -producer-consumer synchronization: a CU that finishes gathering early can -start GEMM immediately, consuming any tile whose flag is already set. +Fused All-Gather + GEMM using a local HBM staging buffer with dedicated +fetcher and GEMM workgroups, launched data-parallel. + +Supports configurable staged_a buffer layout (M-contiguous or K-contiguous) +and B layout to match optimal tritonblas conventions (TN, TT, NT, NN). """ from typing import Optional @@ -28,19 +20,14 @@ from .workspace import FusedWorkspace -# ========================================================================== -# Kernel -# ========================================================================== - - @triton.jit def _hbm_buffer_all_gather_matmul_kernel( A_sharded, B, C, bias_ptr, - staged_a, # Local HBM buffer: (M, K) fp16 - flags_ptr, # int32[NUM_M_TILES * NUM_K_BLOCKS] per-tile ready flags + staged_a, + flags_ptr, M, N, K, @@ -51,6 +38,8 @@ def _hbm_buffer_all_gather_matmul_kernel( stride_bn, stride_cm, stride_cn, + stride_sa_m, # staged_a stride in M dim + stride_sa_k, # staged_a stride in K dim stride_bias, context_tensor: tl.tensor, cur_rank: tl.constexpr, @@ -59,137 +48,116 @@ def _hbm_buffer_all_gather_matmul_kernel( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - NUM_SMS: tl.constexpr, - NUM_XCDS: tl.constexpr, + NUM_FETCH_SMS: tl.constexpr, NUM_M_TILES: tl.constexpr, - NUM_K_BLOCKS: tl.constexpr, # K // BLOCK_SIZE_K (global) - NUM_K_BLOCKS_LOCAL: tl.constexpr, # K_local // BLOCK_SIZE_K + NUM_TILES_N: tl.constexpr, + NUM_K_BLOCKS: tl.constexpr, + NUM_K_BLOCKS_LOCAL: tl.constexpr, + K_PER_FLAG: tl.constexpr, + NUM_FLAG_GROUPS_K: tl.constexpr, + TOTAL_GATHER_TILES: tl.constexpr, BIAS: tl.constexpr, ALLOW_TF32: tl.constexpr, ): - """ - HBM-buffered all-gather + GEMM with per-tile ready flags. + pid = tl.program_id(0) + acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 + zero = tl.program_id(0) * 0 - Each CU executes two phases back-to-back (no global barrier): + if pid < NUM_FETCH_SMS: + # ============================================================== + # FETCHER + # ============================================================== + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) + src_view = iris.x.make_tensor_view(A_sharded, M, K_local, stride_am, stride_ak) - Phase 1 (gather): The CU is assigned a slice of the (m_tile, src_rank, - k_block_local) gather work. For each assigned tile it pulls from remote - via iris.x.gather, writes to staged_a, and atomically sets the ready - flag. Local rank tiles are copied via a fast local load. + num_m_groups = (NUM_M_TILES + GROUP_SIZE_M - 1) // GROUP_SIZE_M + tiles_per_m_group = NUM_FLAG_GROUPS_K * GROUP_SIZE_M + total_flag_groups = NUM_FLAG_GROUPS_K * NUM_M_TILES - Phase 2 (GEMM): The CU iterates over its assigned output tiles - (pid_m, pid_n). For each K-block in the accumulation loop it checks the - ready flag; if not yet set, it spins until the producing CU posts it. - A tiles are loaded from staged_a (local HBM) and B tiles from B. - """ - pid = tl.program_id(0) + for fg_idx in range(pid, total_flag_groups, NUM_FETCH_SMS): + m_group = fg_idx // tiles_per_m_group + within_group = fg_idx % tiles_per_m_group + k_flag_group = within_group // GROUP_SIZE_M + m_in_group = within_group % GROUP_SIZE_M + m_tile = m_group * GROUP_SIZE_M + m_in_group + m_tile = min(m_tile, NUM_M_TILES - 1) + k_block_start = k_flag_group * K_PER_FLAG - # XCD-aware PID remapping - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + rm = m_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 + for k_off in range(K_PER_FLAG): + k_block_global = k_block_start + k_off + + src_rank_idx = k_block_global // NUM_K_BLOCKS_LOCAL + k_block_local = k_block_global % NUM_K_BLOCKS_LOCAL + + pid_m_t = zero + m_tile + tile_k_t = zero + k_block_local + k_tile = iris.x.TileView(pid_m_t, tile_k_t, BLOCK_SIZE_M, BLOCK_SIZE_K) + + rk = k_block_global * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + # Use parameterized strides for staged_a + staged_ptrs = staged_a + rm[:, None] * stride_sa_m + rk[None, :] * stride_sa_k + + for compile_rank in range(world_size): + if src_rank_idx == compile_rank: + a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx) + tl.store(staged_ptrs, a_tile) - # DeviceContext and TensorView for gather - ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) - src_view = iris.x.make_tensor_view(A_sharded, M, K_local, stride_am, stride_ak) - - # ================================================================== - # Phase 1: Cooperative gather into staged_a, set per-tile flags - # ================================================================== - # Total gather work = NUM_M_TILES * world_size * NUM_K_BLOCKS_LOCAL - # Each tile is BLOCK_SIZE_M x BLOCK_SIZE_K elements. - total_gather_tiles = NUM_M_TILES * world_size * NUM_K_BLOCKS_LOCAL - - for gather_idx in range(pid, total_gather_tiles, NUM_SMS): - # Decompose flat index -> (m_tile, src_rank_idx, k_block_local) - m_tile = gather_idx // (world_size * NUM_K_BLOCKS_LOCAL) - remainder = gather_idx % (world_size * NUM_K_BLOCKS_LOCAL) - src_rank_idx = remainder // NUM_K_BLOCKS_LOCAL - k_block_local = remainder % NUM_K_BLOCKS_LOCAL - - # Global k-block index in the full K dimension - k_block_global = src_rank_idx * NUM_K_BLOCKS_LOCAL + k_block_local - - # Gather the tile from the source rank, store to buffer, set flag. - # source_rank must be constexpr for iris.x.gather, so we iterate - # over all ranks at compile time and select at runtime. - # The store and flag-set are inside the branch so that a_tile is - # always defined when used. - zero = tl.program_id(0) * 0 - pid_m_t = zero + m_tile - tile_k_t = zero + k_block_local - k_tile = iris.x.TileView(pid_m_t, tile_k_t, BLOCK_SIZE_M, BLOCK_SIZE_K) - - rm = m_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - rk = k_block_global * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - staged_ptrs = staged_a + rm[:, None] * K + rk[None, :] - flag_idx = m_tile * NUM_K_BLOCKS + k_block_global - - for compile_rank in range(world_size): - if src_rank_idx == compile_rank: - a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx) - tl.store(staged_ptrs, a_tile) - tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") - - # ================================================================== - # Phase 2: GEMM from staged_a (local) x B, checking flags - # ================================================================== - num_tiles_n = tl.cdiv(N, BLOCK_SIZE_N) - total_gemm_tiles = NUM_M_TILES * num_tiles_n - - for gemm_tile_id in range(pid, total_gemm_tiles, NUM_SMS): - # Tile scheduling with swizzle (GROUP_SIZE_M grouping) - num_pid_in_group = GROUP_SIZE_M * num_tiles_n + flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group + tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") + + else: + # ============================================================== + # GEMM + # ============================================================== + gemm_tile_id = pid - NUM_FETCH_SMS + + num_pid_in_group = GROUP_SIZE_M * NUM_TILES_N group_id = gemm_tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_sz = min(NUM_M_TILES - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + ((gemm_tile_id % num_pid_in_group) % group_sz) pid_n = (gemm_tile_id % num_pid_in_group) // group_sz - # Row / column indices rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_SIZE_N), BLOCK_SIZE_N) - # Initialize accumulator acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - # K-reduction loop - for k_block in range(NUM_K_BLOCKS): - # Wait for the (pid_m, k_block) tile to be ready. - # acquire semantics ensure subsequent loads see the stored data. - flag_idx = pid_m * NUM_K_BLOCKS + k_block + for k_fg in range(NUM_FLAG_GROUPS_K): + flag_idx = pid_m * NUM_FLAG_GROUPS_K + k_fg while tl.atomic_add(flags_ptr + flag_idx, 0, sem="acquire", scope="gpu") == 0: pass - # Load A from staged_a (purely local HBM) - rk = k_block * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) - a_ptrs = staged_a + rm[:, None] * K + rk[None, :] - a = tl.load(a_ptrs) + k_block_base = k_fg * K_PER_FLAG + for k_off in range(K_PER_FLAG): + k_block = k_block_base + k_off + rk = k_block * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) + + # Use parameterized strides for staged_a + a_ptrs = staged_a + rm[:, None] * stride_sa_m + rk[None, :] * stride_sa_k + a = tl.load(a_ptrs) - # Load B - B_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - b = tl.load(B_ptrs) + B_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(B_ptrs) - # Accumulate - if ALLOW_TF32: - acc = tl.dot(a, b, acc, allow_tf32=True) - else: - acc += tl.dot(a, b, allow_tf32=False) + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) - # Add bias if provided if BIAS: bias_val = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0) acc = acc + bias_val[:, None] - # Convert to output dtype and store c = acc.to(C.type.element_ty) C_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - mask = (rm[:, None] < M) & (rn[None, :] < N) - tl.store(C_ptrs, c, mask=mask) + c_mask = (rm[:, None] < M) & (rn[None, :] < N) + tl.store(C_ptrs, c, mask=c_mask) # ========================================================================== @@ -202,13 +170,15 @@ def all_gather_matmul_hbm_buffer_preamble( A_sharded: torch.Tensor, B: torch.Tensor, config: Optional[FusedConfig] = None, + k_per_flag: int = 1, + staged_a_layout: str = "k_contiguous", ) -> FusedWorkspace: """ - Allocate workspace for the HBM-buffered all_gather_matmul. + Allocate workspace. - Allocates: - - staged_a: (M, K) local HBM buffer for the gathered A matrix. - - flags: int32[num_m_tiles * num_k_blocks] per-tile ready flags. + Args: + staged_a_layout: "k_contiguous" (default, row-major (M,K)) or + "m_contiguous" (col-major, stored as (K,M) transposed). """ if config is None: config = FusedConfig() @@ -217,35 +187,41 @@ def all_gather_matmul_hbm_buffer_preamble( K, N = B.shape world_size = shmem.get_num_ranks() - expected_K = world_size * K_local - assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" - assert K_local % config.block_size_k == 0, ( - f"K_local ({K_local}) must be divisible by block_size_k ({config.block_size_k})" - ) - assert K % config.block_size_k == 0, f"K ({K}) must be divisible by block_size_k ({config.block_size_k})" - assert M % config.block_size_m == 0, f"M ({M}) must be divisible by block_size_m ({config.block_size_m})" + assert world_size * K_local == K + assert K_local % config.block_size_k == 0 + assert K % config.block_size_k == 0 + assert M % config.block_size_m == 0 num_m_tiles = M // config.block_size_m num_k_blocks = K // config.block_size_k + assert num_k_blocks % k_per_flag == 0 + num_flag_groups_k = num_k_blocks // k_per_flag ws = FusedWorkspace( operation="all_gather_matmul_hbm_buffer", shape=(M, N, K), dtype=A_sharded.dtype, world_size=world_size, - variant="hbm_buffer", + variant=f"hbm_buffer_{staged_a_layout}", prepared=True, ) - # (M, K) staging buffer in local HBM - ws.aux_buffer = shmem.zeros((M, K), dtype=A_sharded.dtype) - # Per-tile ready flags - ws.locks = shmem.zeros((num_m_tiles * num_k_blocks,), dtype=torch.int32) + if staged_a_layout == "m_contiguous": + # Allocate (K, M) row-major, .T gives (M, K) with stride_m=1, stride_k=M + storage = shmem.zeros((K, M), dtype=A_sharded.dtype) + ws.aux_buffer = storage.T # (M, K) view, M-contiguous + else: + # Default: (M, K) row-major, stride_m=K, stride_k=1 + ws.aux_buffer = shmem.zeros((M, K), dtype=A_sharded.dtype) + + ws.locks = shmem.zeros((num_m_tiles * num_flag_groups_k,), dtype=torch.int32) - buffer_mb = M * K * A_sharded.element_size() / (1024**2) + buffer_mb = M * K * A_sharded.element_size() / (1024 ** 2) + sa_stride_m, sa_stride_k = ws.aux_buffer.stride() shmem.info( - f"HBM buffer workspace: staged_a=({M},{K}) [{buffer_mb:.1f} MB], " - f"flags=[{num_m_tiles}x{num_k_blocks}={num_m_tiles * num_k_blocks}]" + f"HBM buffer: staged_a=({M},{K}) [{buffer_mb:.1f} MB] " + f"layout={staged_a_layout} strides=({sa_stride_m},{sa_stride_k}), " + f"flags={num_m_tiles}x{num_flag_groups_k}, k_per_flag={k_per_flag}" ) shmem.barrier() @@ -261,15 +237,19 @@ def all_gather_matmul_hbm_buffer( async_op: bool = False, config: Optional[FusedConfig] = None, workspace: Optional[FusedWorkspace] = None, + num_fetch_sms: Optional[int] = None, + k_per_flag: int = 1, + fetch_block_m: Optional[int] = None, + fetch_block_k: Optional[int] = None, + staged_a_layout: str = "k_contiguous", ) -> FusedWorkspace: """ - All-gather + matmul using a local HBM staging buffer with per-tile flags. + All-gather + matmul with dedicated fetcher/GEMM workgroups. - Computes C = all_gather(A_sharded) @ B + bias. - - Each CU first gathers its assigned slice of A tiles into the local buffer - (setting per-tile ready flags), then runs GEMM from the buffer, spinning - on flags for any tile not yet available. + Args: + staged_a_layout: Buffer layout for gathered A. + "k_contiguous" — (M,K) row-major, K is fast dim. Matches NN convention. + "m_contiguous" — (M,K) with M as fast dim. Matches TN convention (best for tritonblas). """ if config is None: config = FusedConfig() @@ -279,24 +259,31 @@ def all_gather_matmul_hbm_buffer( world_size = shmem.get_num_ranks() rank = shmem.get_rank() - expected_K = world_size * K_local - assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" - assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" - assert M % config.block_size_m == 0, f"M ({M}) must be divisible by block_size_m ({config.block_size_m})" - assert K % config.block_size_k == 0, f"K ({K}) must be divisible by block_size_k ({config.block_size_k})" - assert K_local % config.block_size_k == 0, ( - f"K_local ({K_local}) must be divisible by block_size_k ({config.block_size_k})" - ) + assert world_size * K_local == K + assert output_tensor.shape == (M, N) + assert M % config.block_size_m == 0 + assert K % config.block_size_k == 0 + assert K_local % config.block_size_k == 0 + + if fetch_block_m is None: + fetch_block_m = config.block_size_m + if fetch_block_k is None: + fetch_block_k = config.block_size_k + + num_k_blocks = K // config.block_size_k + assert num_k_blocks % k_per_flag == 0 if workspace is None: - workspace = all_gather_matmul_hbm_buffer_preamble(shmem, A_sharded, B, config) + workspace = all_gather_matmul_hbm_buffer_preamble( + shmem, A_sharded, B, config, k_per_flag, staged_a_layout + ) - # Reset flags to 0 before each launch workspace.locks.zero_() stride_am, stride_ak = A_sharded.stride() stride_bk, stride_bn = B.stride() stride_cm, stride_cn = output_tensor.stride() + stride_sa_m, stride_sa_k = workspace.aux_buffer.stride() if bias is not None: assert bias.shape[0] == M @@ -304,7 +291,7 @@ def all_gather_matmul_hbm_buffer( stride_bias = bias.stride()[0] if bias.dim() > 0 else 1 use_bias = True else: - bias_ptr = output_tensor # dummy, won't be read + bias_ptr = output_tensor stride_bias = 1 use_bias = False @@ -315,40 +302,41 @@ def all_gather_matmul_hbm_buffer( num_sms = props.multi_processor_count num_m_tiles = M // config.block_size_m - num_k_blocks = K // config.block_size_k + num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n + total_gemm_tiles = num_m_tiles * num_tiles_n num_k_blocks_local = K_local // config.block_size_k - - grid = (num_sms,) - _hbm_buffer_all_gather_matmul_kernel[grid]( - A_sharded, - B, - output_tensor, - bias_ptr, - workspace.aux_buffer, # staged_a - workspace.locks, # flags - M, - N, - K, - K_local, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, + num_flag_groups_k = num_k_blocks // k_per_flag + total_gather_tiles = num_m_tiles * num_k_blocks + + if num_fetch_sms is None: + num_fetch_sms = max(1, num_sms // 10) + assert 0 < num_fetch_sms + + grid_size = num_fetch_sms + total_gemm_tiles + + _hbm_buffer_all_gather_matmul_kernel[(grid_size,)]( + A_sharded, B, output_tensor, bias_ptr, + workspace.aux_buffer, workspace.locks, + M, N, K, K_local, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_sa_m, stride_sa_k, stride_bias, shmem.get_device_context(), - rank, - world_size, + rank, world_size, config.block_size_m, config.block_size_n, config.block_size_k, config.group_size_m, - num_sms, - config.num_xcds, + num_fetch_sms, num_m_tiles, + num_tiles_n, num_k_blocks, num_k_blocks_local, + k_per_flag, + num_flag_groups_k, + total_gather_tiles, use_bias, config.allow_tf32, matrix_instr_nonkdim=16, From b2aadcd5dc8f7d28d8833fe3f4ed5834e58740cc Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 13 Feb 2026 02:16:39 +0000 Subject: [PATCH 11/31] Apply Ruff auto-fixes --- iris/ops/all_gather_matmul_hbm_buffer.py | 38 +++++++++++++++--------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index 936a9a9a4..ab8e9d4f8 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -38,8 +38,8 @@ def _hbm_buffer_all_gather_matmul_kernel( stride_bn, stride_cm, stride_cn, - stride_sa_m, # staged_a stride in M dim - stride_sa_k, # staged_a stride in K dim + stride_sa_m, # staged_a stride in M dim + stride_sa_k, # staged_a stride in K dim stride_bias, context_tensor: tl.tensor, cur_rank: tl.constexpr, @@ -216,7 +216,7 @@ def all_gather_matmul_hbm_buffer_preamble( ws.locks = shmem.zeros((num_m_tiles * num_flag_groups_k,), dtype=torch.int32) - buffer_mb = M * K * A_sharded.element_size() / (1024 ** 2) + buffer_mb = M * K * A_sharded.element_size() / (1024**2) sa_stride_m, sa_stride_k = ws.aux_buffer.stride() shmem.info( f"HBM buffer: staged_a=({M},{K}) [{buffer_mb:.1f} MB] " @@ -274,9 +274,7 @@ def all_gather_matmul_hbm_buffer( assert num_k_blocks % k_per_flag == 0 if workspace is None: - workspace = all_gather_matmul_hbm_buffer_preamble( - shmem, A_sharded, B, config, k_per_flag, staged_a_layout - ) + workspace = all_gather_matmul_hbm_buffer_preamble(shmem, A_sharded, B, config, k_per_flag, staged_a_layout) workspace.locks.zero_() @@ -315,16 +313,28 @@ def all_gather_matmul_hbm_buffer( grid_size = num_fetch_sms + total_gemm_tiles _hbm_buffer_all_gather_matmul_kernel[(grid_size,)]( - A_sharded, B, output_tensor, bias_ptr, - workspace.aux_buffer, workspace.locks, - M, N, K, K_local, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_sa_m, stride_sa_k, + A_sharded, + B, + output_tensor, + bias_ptr, + workspace.aux_buffer, + workspace.locks, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_sa_m, + stride_sa_k, stride_bias, shmem.get_device_context(), - rank, world_size, + rank, + world_size, config.block_size_m, config.block_size_n, config.block_size_k, From 7b2321eac0f31f65ec13c3ca259f49d629d8169b Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Mon, 16 Feb 2026 15:56:37 -0600 Subject: [PATCH 12/31] Update hbm buffered all gather matmul --- .../all_gather_matmul/benchmark_hbm_buffer.py | 21 ++++++++++++++++++- iris/ops/all_gather_matmul_hbm_buffer.py | 15 ++++++++++--- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py index 0529ebb46..1a6ca502a 100644 --- a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -68,6 +68,10 @@ def parse_args(): parser.add_argument("--b_col_major", action="store_true", help="B col-major (K-contiguous)") parser.add_argument("--a_col_major", action="store_true", help="A col-major (M-contiguous)") parser.add_argument("--single-run", action="store_true", help="1 iteration (for profiling)") + parser.add_argument("--num_fetch_sms", type=int, default=None, help="Fetcher SMs (auto if None)") + parser.add_argument("--k_per_flag", type=int, default=1, help="K-blocks per ready flag") + parser.add_argument("--num_warps", type=int, default=None, help="Triton num_warps (auto if None)") + parser.add_argument("--num_stages", type=int, default=None, help="Triton num_stages (auto if None)") return vars(parser.parse_args()) @@ -162,7 +166,10 @@ def _worker(args): expected_tensor.copy_(torch.matmul(A_gathered, B_data)) # Pre-allocate workspace - workspace = all_gather_matmul_hbm_buffer_preamble(shmem, A_sharded, B, config) + k_per_flag = args["k_per_flag"] + workspace = all_gather_matmul_hbm_buffer_preamble( + shmem, A_sharded, B, config, k_per_flag=k_per_flag + ) # ── Timing ─────────────────────────────────────────────────────────── comm_stream = torch.cuda.Stream() @@ -171,6 +178,10 @@ def _worker(args): total_ms = 0.0 num_experiments = 0 + num_fetch_sms = args["num_fetch_sms"] + num_warps = args["num_warps"] + num_stages = args["num_stages"] + def run_experiment(): nonlocal total_ms, num_experiments shmem.barrier() @@ -184,6 +195,10 @@ def run_experiment(): config=config, async_op=False, workspace=workspace, + num_fetch_sms=num_fetch_sms, + k_per_flag=k_per_flag, + num_warps=num_warps, + num_stages=num_stages, ) end_ev.record() num_experiments += 1 @@ -263,6 +278,10 @@ def run_experiment(): config=config, async_op=False, workspace=workspace, + num_fetch_sms=num_fetch_sms, + k_per_flag=k_per_flag, + num_warps=num_warps, + num_stages=num_stages, ) torch.cuda.synchronize() t_end = time.perf_counter() diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index 936a9a9a4..8ab69704b 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -102,10 +102,11 @@ def _hbm_buffer_all_gather_matmul_kernel( for compile_rank in range(world_size): if src_rank_idx == compile_rank: a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx) - tl.store(staged_ptrs, a_tile) + tl.store(staged_ptrs, a_tile,cache_modifier=".wt") flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group - tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") + #tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") + tl.store(flags_ptr + flag_idx, 1) else: # ============================================================== @@ -242,6 +243,8 @@ def all_gather_matmul_hbm_buffer( fetch_block_m: Optional[int] = None, fetch_block_k: Optional[int] = None, staged_a_layout: str = "k_contiguous", + num_warps: Optional[int] = None, + num_stages: Optional[int] = None, ) -> FusedWorkspace: """ All-gather + matmul with dedicated fetcher/GEMM workgroups. @@ -314,6 +317,12 @@ def all_gather_matmul_hbm_buffer( grid_size = num_fetch_sms + total_gemm_tiles + launch_kwargs = {"matrix_instr_nonkdim": 16} + if num_warps is not None: + launch_kwargs["num_warps"] = num_warps + if num_stages is not None: + launch_kwargs["num_stages"] = num_stages + _hbm_buffer_all_gather_matmul_kernel[(grid_size,)]( A_sharded, B, output_tensor, bias_ptr, workspace.aux_buffer, workspace.locks, @@ -339,7 +348,7 @@ def all_gather_matmul_hbm_buffer( total_gather_tiles, use_bias, config.allow_tf32, - matrix_instr_nonkdim=16, + **launch_kwargs, ) if not async_op: From 9692222bfb74930ef9fb50028c3554b3181c35ee Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 16 Feb 2026 21:57:42 +0000 Subject: [PATCH 13/31] Apply Ruff auto-fixes --- benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py | 4 +--- iris/ops/all_gather_matmul_hbm_buffer.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py index 1a6ca502a..aa8221d60 100644 --- a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -167,9 +167,7 @@ def _worker(args): # Pre-allocate workspace k_per_flag = args["k_per_flag"] - workspace = all_gather_matmul_hbm_buffer_preamble( - shmem, A_sharded, B, config, k_per_flag=k_per_flag - ) + workspace = all_gather_matmul_hbm_buffer_preamble(shmem, A_sharded, B, config, k_per_flag=k_per_flag) # ── Timing ─────────────────────────────────────────────────────────── comm_stream = torch.cuda.Stream() diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index 329516bc0..4f8de5044 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -102,10 +102,10 @@ def _hbm_buffer_all_gather_matmul_kernel( for compile_rank in range(world_size): if src_rank_idx == compile_rank: a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx) - tl.store(staged_ptrs, a_tile,cache_modifier=".wt") + tl.store(staged_ptrs, a_tile, cache_modifier=".wt") flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group - #tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") + # tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") tl.store(flags_ptr + flag_idx, 1) else: From 44ebc976f983e0e0573fbbdca17ad0a6c8b78231 Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Mon, 16 Feb 2026 17:59:20 -0600 Subject: [PATCH 14/31] Add tracing --- .../all_gather_matmul/benchmark_hbm_buffer.py | 192 ++++++++++++++++++ iris/ops/all_gather_matmul_hbm_buffer.py | 66 +++++- 2 files changed, 254 insertions(+), 4 deletions(-) diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py index 1a6ca502a..dc74c7fa3 100644 --- a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -24,6 +24,7 @@ import torch.distributed as dist import random import argparse +import numpy as np import iris from iris.ops.all_gather_matmul_hbm_buffer import ( @@ -35,6 +36,135 @@ torch.manual_seed(123) random.seed(123) +TICKS_PER_US = 100 # s_memrealtime runs at 100 MHz: 1 tick = 10 ns = 0.01 us + + +def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): + """Generate a tall Gantt chart showing per-workgroup activity over time. + + Y-axis: workgroup (sorted by start time) + X-axis: time in microseconds + Colors: fetcher (blue), GEMM wait (red), GEMM compute (green) + """ + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + from matplotlib.lines import Line2D + + starts = trace_data["start"].numpy().astype(np.int64) + ends = trace_data["end"].numpy().astype(np.int64) + waits = trace_data["wait"].numpy().astype(np.int64) + xcds = trace_data["xcd"].numpy().astype(np.int32) + grid_size = trace_data["grid_size"] + n_fetch = trace_data["num_fetch_sms"] + + # Convert to microseconds relative to earliest start + t_min = starts.min() + starts_us = (starts - t_min) / TICKS_PER_US + ends_us = (ends - t_min) / TICKS_PER_US + waits_us = waits / TICKS_PER_US + + # Build role array: 0=fetcher, 1=GEMM + roles = np.array([0 if i < n_fetch else 1 for i in range(grid_size)]) + + # Sort by start time + order = np.argsort(starts_us) + + # Compute figure height: ~0.012 inches per row, min 12 inches + row_h = 0.012 + fig_h = max(12, grid_size * row_h + 2) + fig, ax = plt.subplots(figsize=(18, fig_h)) + + fetch_color = "#2196F3" # blue + wait_color = "#F44336" # red + compute_color = "#4CAF50" # green + + for y_idx, wg_idx in enumerate(order): + s = starts_us[wg_idx] + e = ends_us[wg_idx] + dur = e - s + role = roles[wg_idx] + + if role == 0: + # Fetcher: solid blue bar + ax.barh(y_idx, dur, left=s, height=0.8, color=fetch_color, + edgecolor="none", linewidth=0) + else: + # GEMM: split into wait (red) and compute (green) + w = waits_us[wg_idx] + c = max(0, dur - w) + # Show wait portion first, then compute + ax.barh(y_idx, w, left=s, height=0.8, color=wait_color, + edgecolor="none", linewidth=0) + ax.barh(y_idx, c, left=s + w, height=0.8, color=compute_color, + edgecolor="none", linewidth=0) + + # XCD annotations on the right margin + xcd_set = sorted(set(xcds.tolist())) + xcd_cmap = {} + if len(xcd_set) > 1: + cmap = matplotlib.colormaps.get_cmap("tab10").resampled(len(xcd_set)) + for i, x in enumerate(xcd_set): + xcd_cmap[x] = cmap(i) + + x_max = ends_us.max() * 1.02 + for y_idx, wg_idx in enumerate(order): + xcd_id = xcds[wg_idx] + if xcd_id in xcd_cmap: + ax.plot(x_max, y_idx, marker="s", markersize=1.5, + color=xcd_cmap[xcd_id], clip_on=False) + + ax.set_xlabel("Time (us)", fontsize=12) + ax.set_ylabel("Workgroup (sorted by start time)", fontsize=12) + ax.set_title( + f"Rank {rank} | All-Gather GEMM Trace | " + f"M={M} N={N} K={K} | " + f"{n_fetch} fetchers + {grid_size - n_fetch} GEMM workgroups", + fontsize=13, + ) + ax.set_ylim(-1, grid_size + 1) + ax.set_xlim(0, x_max) + + # Invert y so earliest-starting workgroups are at top + ax.invert_yaxis() + + # Legend + legend_elements = [ + Line2D([0], [0], color=fetch_color, lw=6, label="Fetcher (all-gather)"), + Line2D([0], [0], color=wait_color, lw=6, label="GEMM: waiting on data"), + Line2D([0], [0], color=compute_color, lw=6, label="GEMM: compute"), + ] + ax.legend(handles=legend_elements, loc="upper right", fontsize=10) + + # Summary stats + fetch_mask = roles == 0 + gemm_mask = roles == 1 + fetch_dur = (ends_us - starts_us)[fetch_mask] + gemm_dur = (ends_us - starts_us)[gemm_mask] + gemm_wait = waits_us[gemm_mask] + gemm_compute = gemm_dur - gemm_wait + + stats_text = ( + f"Fetcher: {fetch_dur.mean():.1f} us avg ({fetch_dur.min():.1f}-{fetch_dur.max():.1f})\n" + f"GEMM total: {gemm_dur.mean():.1f} us avg ({gemm_dur.min():.1f}-{gemm_dur.max():.1f})\n" + f" wait: {gemm_wait.mean():.1f} us avg ({gemm_wait.min():.1f}-{gemm_wait.max():.1f})\n" + f" compute: {gemm_compute.mean():.1f} us avg ({gemm_compute.min():.1f}-{gemm_compute.max():.1f})\n" + f" wait%: {100 * gemm_wait.sum() / gemm_dur.sum():.1f}%\n" + f"Wall time: {ends_us.max():.1f} us" + ) + ax.text( + 0.01, 0.99, stats_text, transform=ax.transAxes, + fontsize=9, verticalalignment="top", fontfamily="monospace", + bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.85), + ) + + plt.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" [Rank {rank}] Trace plot saved to: {output_path}") + print(f" {stats_text}") + def parse_args(): parser = argparse.ArgumentParser( @@ -72,6 +202,8 @@ def parse_args(): parser.add_argument("--k_per_flag", type=int, default=1, help="K-blocks per ready flag") parser.add_argument("--num_warps", type=int, default=None, help="Triton num_warps (auto if None)") parser.add_argument("--num_stages", type=int, default=None, help="Triton num_stages (auto if None)") + parser.add_argument("--trace", action="store_true", help="Collect per-workgroup trace and save Gantt chart PNG") + parser.add_argument("--trace_output", type=str, default="trace_gantt.png", help="Output path for trace plot") return vars(parser.parse_args()) @@ -80,6 +212,8 @@ def _worker(args): local_rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0))) world_size_env = int(os.environ.get("WORLD_SIZE", 1)) + t0 = time.perf_counter() + backend = "nccl" if torch.cuda.is_available() else "gloo" if "RANK" in os.environ or "LOCAL_RANK" in os.environ: @@ -97,10 +231,18 @@ def _worker(args): device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None, ) + t1 = time.perf_counter() + shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() + t2 = time.perf_counter() + shmem.info( + f"Startup: dist.init={t1 - t0:.1f}s, iris.init={t2 - t1:.1f}s, " + f"total={t2 - t0:.1f}s" + ) + datatype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} datatype = datatype_map.get(args["datatype"], torch.float16) @@ -309,6 +451,56 @@ def run_experiment(): shmem.barrier() + # ── Trace ──────────────────────────────────────────────────────────── + if args["trace"]: + # Warmup: compile the TRACE=True kernel variant before the real run + shmem.info("Trace warmup (compiling traced kernel variant)...") + C.zero_() + workspace.locks.zero_() + shmem.barrier() + all_gather_matmul_hbm_buffer( + shmem, C, A_sharded, B, + config=config, async_op=False, workspace=workspace, + num_fetch_sms=num_fetch_sms, k_per_flag=k_per_flag, + num_warps=num_warps, num_stages=num_stages, + trace=True, + ) + torch.cuda.synchronize() + shmem.barrier() + + # Actual traced run (post-compilation, clean state) + shmem.info("Running single traced iteration...") + C.zero_() + workspace.locks.zero_() + shmem.barrier() + + all_gather_matmul_hbm_buffer( + shmem, + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + num_fetch_sms=num_fetch_sms, + k_per_flag=k_per_flag, + num_warps=num_warps, + num_stages=num_stages, + trace=True, + ) + torch.cuda.synchronize() + shmem.barrier() + + if rank == 0 and hasattr(workspace, "trace_data"): + trace_out = args.get("trace_output", "trace_gantt.png") + try: + _plot_trace(workspace.trace_data, trace_out, rank, M, N, K, num_fetch_sms) + except ImportError: + print(" (matplotlib not available -- skipping trace plot)") + except Exception as e: + print(f" (Trace plot failed: {e})") + shmem.barrier() + # ── PyTorch baseline ───────────────────────────────────────────────── if args["benchmark_pytorch"]: shmem.info("Benchmarking PyTorch (all_gather_into_tensor + matmul)...") diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index 329516bc0..c797c939a 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -16,6 +16,7 @@ import iris import iris.x +from iris.device_utils import read_realtime, get_xcc_id from .config import FusedConfig from .workspace import FusedWorkspace @@ -58,11 +59,20 @@ def _hbm_buffer_all_gather_matmul_kernel( TOTAL_GATHER_TILES: tl.constexpr, BIAS: tl.constexpr, ALLOW_TF32: tl.constexpr, + trace_start_ptr, + trace_end_ptr, + trace_wait_ptr, + trace_xcd_ptr, + TRACE: tl.constexpr, ): pid = tl.program_id(0) acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 zero = tl.program_id(0) * 0 + if TRACE: + tl.store(trace_start_ptr + pid, read_realtime()) + tl.store(trace_xcd_ptr + pid, get_xcc_id()) + if pid < NUM_FETCH_SMS: # ============================================================== # FETCHER @@ -102,11 +112,15 @@ def _hbm_buffer_all_gather_matmul_kernel( for compile_rank in range(world_size): if src_rank_idx == compile_rank: a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx) - tl.store(staged_ptrs, a_tile,cache_modifier=".wt") + tl.store(staged_ptrs, a_tile,cache_modifier=".cg") flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group - #tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") - tl.store(flags_ptr + flag_idx, 1) + tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") + #tl.store(flags_ptr + flag_idx, 1,cache_modifier=".wt") + + if TRACE: + tl.store(trace_wait_ptr + pid, zero.to(tl.int64),cache_modifier=".wt") + tl.store(trace_end_ptr + pid, read_realtime(),cache_modifier=".wt") else: # ============================================================== @@ -128,11 +142,20 @@ def _hbm_buffer_all_gather_matmul_kernel( acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + if TRACE: + _wt = zero.to(tl.int64) + for k_fg in range(NUM_FLAG_GROUPS_K): + if TRACE: + _ws = read_realtime() + flag_idx = pid_m * NUM_FLAG_GROUPS_K + k_fg while tl.atomic_add(flags_ptr + flag_idx, 0, sem="acquire", scope="gpu") == 0: pass + if TRACE: + _wt = _wt + (read_realtime() - _ws) + k_block_base = k_fg * K_PER_FLAG for k_off in range(K_PER_FLAG): k_block = k_block_base + k_off @@ -158,7 +181,11 @@ def _hbm_buffer_all_gather_matmul_kernel( c = acc.to(C.type.element_ty) C_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn c_mask = (rm[:, None] < M) & (rn[None, :] < N) - tl.store(C_ptrs, c, mask=c_mask) + tl.store(C_ptrs, c, mask=c_mask,cache_modifier=".wt") + + if TRACE: + tl.store(trace_wait_ptr + pid, _wt) + tl.store(trace_end_ptr + pid, read_realtime(),cache_modifier=".wt") # ========================================================================== @@ -245,6 +272,7 @@ def all_gather_matmul_hbm_buffer( staged_a_layout: str = "k_contiguous", num_warps: Optional[int] = None, num_stages: Optional[int] = None, + trace: bool = False, ) -> FusedWorkspace: """ All-gather + matmul with dedicated fetcher/GEMM workgroups. @@ -315,6 +343,18 @@ def all_gather_matmul_hbm_buffer( grid_size = num_fetch_sms + total_gemm_tiles + # Trace buffers + if trace: + trace_start = torch.zeros(grid_size, dtype=torch.int64, device=device) + trace_end = torch.zeros(grid_size, dtype=torch.int64, device=device) + trace_wait = torch.zeros(grid_size, dtype=torch.int64, device=device) + trace_xcd = torch.zeros(grid_size, dtype=torch.int32, device=device) + else: + trace_start = torch.empty(1, dtype=torch.int64, device=device) + trace_end = torch.empty(1, dtype=torch.int64, device=device) + trace_wait = torch.empty(1, dtype=torch.int64, device=device) + trace_xcd = torch.empty(1, dtype=torch.int32, device=device) + launch_kwargs = {"matrix_instr_nonkdim": 16} if num_warps is not None: launch_kwargs["num_warps"] = num_warps @@ -358,10 +398,28 @@ def all_gather_matmul_hbm_buffer( total_gather_tiles, use_bias, config.allow_tf32, + trace_start, + trace_end, + trace_wait, + trace_xcd, + trace, **launch_kwargs, ) if not async_op: shmem.barrier() + if trace: + torch.cuda.synchronize() + workspace.trace_data = { + "start": trace_start.cpu(), + "end": trace_end.cpu(), + "wait": trace_wait.cpu(), + "xcd": trace_xcd.cpu(), + "grid_size": grid_size, + "num_fetch_sms": num_fetch_sms, + "num_m_tiles": num_m_tiles, + "num_tiles_n": num_tiles_n, + } + return workspace From 11d017aa8f1bdeefe5da27ba72a85113c3b5784c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 17 Feb 2026 00:00:52 +0000 Subject: [PATCH 15/31] Apply Ruff auto-fixes --- .../all_gather_matmul/benchmark_hbm_buffer.py | 49 ++++++++++--------- iris/ops/all_gather_matmul_hbm_buffer.py | 12 ++--- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py index 62dfa9acb..3bf2edf92 100644 --- a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -47,9 +47,9 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): Colors: fetcher (blue), GEMM wait (red), GEMM compute (green) """ import matplotlib + matplotlib.use("Agg") import matplotlib.pyplot as plt - from matplotlib.patches import Rectangle from matplotlib.lines import Line2D starts = trace_data["start"].numpy().astype(np.int64) @@ -76,9 +76,9 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): fig_h = max(12, grid_size * row_h + 2) fig, ax = plt.subplots(figsize=(18, fig_h)) - fetch_color = "#2196F3" # blue - wait_color = "#F44336" # red - compute_color = "#4CAF50" # green + fetch_color = "#2196F3" # blue + wait_color = "#F44336" # red + compute_color = "#4CAF50" # green for y_idx, wg_idx in enumerate(order): s = starts_us[wg_idx] @@ -88,17 +88,14 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): if role == 0: # Fetcher: solid blue bar - ax.barh(y_idx, dur, left=s, height=0.8, color=fetch_color, - edgecolor="none", linewidth=0) + ax.barh(y_idx, dur, left=s, height=0.8, color=fetch_color, edgecolor="none", linewidth=0) else: # GEMM: split into wait (red) and compute (green) w = waits_us[wg_idx] c = max(0, dur - w) # Show wait portion first, then compute - ax.barh(y_idx, w, left=s, height=0.8, color=wait_color, - edgecolor="none", linewidth=0) - ax.barh(y_idx, c, left=s + w, height=0.8, color=compute_color, - edgecolor="none", linewidth=0) + ax.barh(y_idx, w, left=s, height=0.8, color=wait_color, edgecolor="none", linewidth=0) + ax.barh(y_idx, c, left=s + w, height=0.8, color=compute_color, edgecolor="none", linewidth=0) # XCD annotations on the right margin xcd_set = sorted(set(xcds.tolist())) @@ -112,8 +109,7 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): for y_idx, wg_idx in enumerate(order): xcd_id = xcds[wg_idx] if xcd_id in xcd_cmap: - ax.plot(x_max, y_idx, marker="s", markersize=1.5, - color=xcd_cmap[xcd_id], clip_on=False) + ax.plot(x_max, y_idx, marker="s", markersize=1.5, color=xcd_cmap[xcd_id], clip_on=False) ax.set_xlabel("Time (us)", fontsize=12) ax.set_ylabel("Workgroup (sorted by start time)", fontsize=12) @@ -154,8 +150,13 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): f"Wall time: {ends_us.max():.1f} us" ) ax.text( - 0.01, 0.99, stats_text, transform=ax.transAxes, - fontsize=9, verticalalignment="top", fontfamily="monospace", + 0.01, + 0.99, + stats_text, + transform=ax.transAxes, + fontsize=9, + verticalalignment="top", + fontfamily="monospace", bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.85), ) @@ -238,10 +239,7 @@ def _worker(args): world_size = shmem.get_num_ranks() t2 = time.perf_counter() - shmem.info( - f"Startup: dist.init={t1 - t0:.1f}s, iris.init={t2 - t1:.1f}s, " - f"total={t2 - t0:.1f}s" - ) + shmem.info(f"Startup: dist.init={t1 - t0:.1f}s, iris.init={t2 - t1:.1f}s, total={t2 - t0:.1f}s") datatype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} datatype = datatype_map.get(args["datatype"], torch.float16) @@ -457,10 +455,17 @@ def run_experiment(): workspace.locks.zero_() shmem.barrier() all_gather_matmul_hbm_buffer( - shmem, C, A_sharded, B, - config=config, async_op=False, workspace=workspace, - num_fetch_sms=num_fetch_sms, k_per_flag=k_per_flag, - num_warps=num_warps, num_stages=num_stages, + shmem, + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + num_fetch_sms=num_fetch_sms, + k_per_flag=k_per_flag, + num_warps=num_warps, + num_stages=num_stages, trace=True, ) torch.cuda.synchronize() diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index c797c939a..e7f3b11bd 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -112,15 +112,15 @@ def _hbm_buffer_all_gather_matmul_kernel( for compile_rank in range(world_size): if src_rank_idx == compile_rank: a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx) - tl.store(staged_ptrs, a_tile,cache_modifier=".cg") + tl.store(staged_ptrs, a_tile, cache_modifier=".cg") flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") - #tl.store(flags_ptr + flag_idx, 1,cache_modifier=".wt") + # tl.store(flags_ptr + flag_idx, 1,cache_modifier=".wt") if TRACE: - tl.store(trace_wait_ptr + pid, zero.to(tl.int64),cache_modifier=".wt") - tl.store(trace_end_ptr + pid, read_realtime(),cache_modifier=".wt") + tl.store(trace_wait_ptr + pid, zero.to(tl.int64), cache_modifier=".wt") + tl.store(trace_end_ptr + pid, read_realtime(), cache_modifier=".wt") else: # ============================================================== @@ -181,11 +181,11 @@ def _hbm_buffer_all_gather_matmul_kernel( c = acc.to(C.type.element_ty) C_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn c_mask = (rm[:, None] < M) & (rn[None, :] < N) - tl.store(C_ptrs, c, mask=c_mask,cache_modifier=".wt") + tl.store(C_ptrs, c, mask=c_mask, cache_modifier=".wt") if TRACE: tl.store(trace_wait_ptr + pid, _wt) - tl.store(trace_end_ptr + pid, read_realtime(),cache_modifier=".wt") + tl.store(trace_end_ptr + pid, read_realtime(), cache_modifier=".wt") # ========================================================================== From ace40d0df894098dbc64f91ce7ac344dba11b4f8 Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Mon, 16 Feb 2026 18:48:50 -0600 Subject: [PATCH 16/31] Add stages to all_gather_matmul_hbm_buffer --- .../all_gather_matmul/benchmark_hbm_buffer.py | 93 ++++++++++----- iris/ops/all_gather_matmul_hbm_buffer.py | 109 +++++++++++------- 2 files changed, 134 insertions(+), 68 deletions(-) diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py index 62dfa9acb..1991111c7 100644 --- a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -44,12 +44,11 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): Y-axis: workgroup (sorted by start time) X-axis: time in microseconds - Colors: fetcher (blue), GEMM wait (red), GEMM compute (green) + Colors: fetcher stages (blue shades), GEMM wait (red), GEMM compute (green) """ import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt - from matplotlib.patches import Rectangle from matplotlib.lines import Line2D starts = trace_data["start"].numpy().astype(np.int64) @@ -57,7 +56,10 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): waits = trace_data["wait"].numpy().astype(np.int64) xcds = trace_data["xcd"].numpy().astype(np.int32) grid_size = trace_data["grid_size"] - n_fetch = trace_data["num_fetch_sms"] + n_fetch_per_stage = trace_data["num_fetch_sms"] + n_stages = trace_data.get("num_fetch_stages", 1) + total_fetch = trace_data.get("total_fetch_wgs", n_fetch_per_stage) + wgs_per_stage = trace_data.get("wgs_per_stage", grid_size) # Convert to microseconds relative to earliest start t_min = starts.min() @@ -65,8 +67,16 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): ends_us = (ends - t_min) / TICKS_PER_US waits_us = waits / TICKS_PER_US - # Build role array: 0=fetcher, 1=GEMM - roles = np.array([0 if i < n_fetch else 1 for i in range(grid_size)]) + # Build role array: stage index for fetchers (0..S-1), S for GEMM + # Interleaved layout: [fetch0 | gemm0 | fetch1 | gemm1 | ...] + roles = np.empty(grid_size, dtype=np.int32) + for i in range(grid_size): + stage = i // wgs_per_stage + local = i % wgs_per_stage + if local < n_fetch_per_stage: + roles[i] = stage # fetcher for this stage + else: + roles[i] = n_stages # GEMM # Sort by start time order = np.argsort(starts_us) @@ -76,7 +86,8 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): fig_h = max(12, grid_size * row_h + 2) fig, ax = plt.subplots(figsize=(18, fig_h)) - fetch_color = "#2196F3" # blue + # One color per fetch stage (blue palette), plus GEMM colors + fetch_blues = ["#1565C0", "#42A5F5", "#90CAF9", "#BBDEFB"] wait_color = "#F44336" # red compute_color = "#4CAF50" # green @@ -86,18 +97,18 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): dur = e - s role = roles[wg_idx] - if role == 0: - # Fetcher: solid blue bar - ax.barh(y_idx, dur, left=s, height=0.8, color=fetch_color, + if role < n_stages: + # Fetcher: color by stage + c = fetch_blues[role % len(fetch_blues)] + ax.barh(y_idx, dur, left=s, height=0.8, color=c, edgecolor="none", linewidth=0) else: # GEMM: split into wait (red) and compute (green) w = waits_us[wg_idx] - c = max(0, dur - w) - # Show wait portion first, then compute + comp = max(0, dur - w) ax.barh(y_idx, w, left=s, height=0.8, color=wait_color, edgecolor="none", linewidth=0) - ax.barh(y_idx, c, left=s + w, height=0.8, color=compute_color, + ax.barh(y_idx, comp, left=s + w, height=0.8, color=compute_color, edgecolor="none", linewidth=0) # XCD annotations on the right margin @@ -115,12 +126,15 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): ax.plot(x_max, y_idx, marker="s", markersize=1.5, color=xcd_cmap[xcd_id], clip_on=False) + n_gemm = grid_size - total_fetch + stage_info = (f"{n_stages}x{n_fetch_per_stage}" if n_stages > 1 + else str(n_fetch_per_stage)) ax.set_xlabel("Time (us)", fontsize=12) ax.set_ylabel("Workgroup (sorted by start time)", fontsize=12) ax.set_title( f"Rank {rank} | All-Gather GEMM Trace | " f"M={M} N={N} K={K} | " - f"{n_fetch} fetchers + {grid_size - n_fetch} GEMM workgroups", + f"{stage_info} fetchers + {n_gemm} GEMM workgroups", fontsize=13, ) ax.set_ylim(-1, grid_size + 1) @@ -130,29 +144,45 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): ax.invert_yaxis() # Legend - legend_elements = [ - Line2D([0], [0], color=fetch_color, lw=6, label="Fetcher (all-gather)"), - Line2D([0], [0], color=wait_color, lw=6, label="GEMM: waiting on data"), - Line2D([0], [0], color=compute_color, lw=6, label="GEMM: compute"), - ] + legend_elements = [] + for s_idx in range(min(n_stages, len(fetch_blues))): + legend_elements.append( + Line2D([0], [0], color=fetch_blues[s_idx], lw=6, + label=f"Fetch stage {s_idx}") + ) + legend_elements.append( + Line2D([0], [0], color=wait_color, lw=6, label="GEMM: waiting on data")) + legend_elements.append( + Line2D([0], [0], color=compute_color, lw=6, label="GEMM: compute")) ax.legend(handles=legend_elements, loc="upper right", fontsize=10) # Summary stats - fetch_mask = roles == 0 - gemm_mask = roles == 1 + fetch_mask = roles < n_stages + gemm_mask = roles == n_stages fetch_dur = (ends_us - starts_us)[fetch_mask] gemm_dur = (ends_us - starts_us)[gemm_mask] gemm_wait = waits_us[gemm_mask] gemm_compute = gemm_dur - gemm_wait - stats_text = ( - f"Fetcher: {fetch_dur.mean():.1f} us avg ({fetch_dur.min():.1f}-{fetch_dur.max():.1f})\n" - f"GEMM total: {gemm_dur.mean():.1f} us avg ({gemm_dur.min():.1f}-{gemm_dur.max():.1f})\n" - f" wait: {gemm_wait.mean():.1f} us avg ({gemm_wait.min():.1f}-{gemm_wait.max():.1f})\n" - f" compute: {gemm_compute.mean():.1f} us avg ({gemm_compute.min():.1f}-{gemm_compute.max():.1f})\n" - f" wait%: {100 * gemm_wait.sum() / gemm_dur.sum():.1f}%\n" - f"Wall time: {ends_us.max():.1f} us" - ) + stats_lines = [] + for s_idx in range(n_stages): + s_mask = roles == s_idx + s_dur = (ends_us - starts_us)[s_mask] + s_start = starts_us[s_mask] + if len(s_dur) > 0: + stats_lines.append( + f"Fetch stg{s_idx}: {s_dur.mean():.1f} us avg " + f"({s_dur.min():.1f}-{s_dur.max():.1f}) " + f"first@{s_start.min():.0f}us" + ) + stats_lines += [ + f"GEMM total: {gemm_dur.mean():.1f} us avg ({gemm_dur.min():.1f}-{gemm_dur.max():.1f})", + f" wait: {gemm_wait.mean():.1f} us avg ({gemm_wait.min():.1f}-{gemm_wait.max():.1f})", + f" compute: {gemm_compute.mean():.1f} us avg ({gemm_compute.min():.1f}-{gemm_compute.max():.1f})", + f" wait%: {100 * gemm_wait.sum() / gemm_dur.sum():.1f}%", + f"Wall time: {ends_us.max():.1f} us", + ] + stats_text = "\n".join(stats_lines) ax.text( 0.01, 0.99, stats_text, transform=ax.transAxes, fontsize=9, verticalalignment="top", fontfamily="monospace", @@ -202,6 +232,7 @@ def parse_args(): parser.add_argument("--k_per_flag", type=int, default=1, help="K-blocks per ready flag") parser.add_argument("--num_warps", type=int, default=None, help="Triton num_warps (auto if None)") parser.add_argument("--num_stages", type=int, default=None, help="Triton num_stages (auto if None)") + parser.add_argument("--num_fetch_stages", type=int, default=1, help="Number of fetch stages (1=all at once, 2=top/bottom half, etc.)") parser.add_argument("--trace", action="store_true", help="Collect per-workgroup trace and save Gantt chart PNG") parser.add_argument("--trace_output", type=str, default="trace_gantt.png", help="Output path for trace plot") return vars(parser.parse_args()) @@ -321,6 +352,7 @@ def _worker(args): num_fetch_sms = args["num_fetch_sms"] num_warps = args["num_warps"] num_stages = args["num_stages"] + num_fetch_stages = args["num_fetch_stages"] def run_experiment(): nonlocal total_ms, num_experiments @@ -339,6 +371,7 @@ def run_experiment(): k_per_flag=k_per_flag, num_warps=num_warps, num_stages=num_stages, + num_fetch_stages=num_fetch_stages, ) end_ev.record() num_experiments += 1 @@ -422,6 +455,7 @@ def run_experiment(): k_per_flag=k_per_flag, num_warps=num_warps, num_stages=num_stages, + num_fetch_stages=num_fetch_stages, ) torch.cuda.synchronize() t_end = time.perf_counter() @@ -461,7 +495,7 @@ def run_experiment(): config=config, async_op=False, workspace=workspace, num_fetch_sms=num_fetch_sms, k_per_flag=k_per_flag, num_warps=num_warps, num_stages=num_stages, - trace=True, + num_fetch_stages=num_fetch_stages, trace=True, ) torch.cuda.synchronize() shmem.barrier() @@ -484,6 +518,7 @@ def run_experiment(): k_per_flag=k_per_flag, num_warps=num_warps, num_stages=num_stages, + num_fetch_stages=num_fetch_stages, trace=True, ) torch.cuda.synchronize() diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index c797c939a..e9db8f5e7 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -59,6 +59,8 @@ def _hbm_buffer_all_gather_matmul_kernel( TOTAL_GATHER_TILES: tl.constexpr, BIAS: tl.constexpr, ALLOW_TF32: tl.constexpr, + NUM_FETCH_STAGES: tl.constexpr, + GEMM_TILES_PER_STAGE: tl.constexpr, trace_start_ptr, trace_end_ptr, trace_wait_ptr, @@ -73,67 +75,83 @@ def _hbm_buffer_all_gather_matmul_kernel( tl.store(trace_start_ptr + pid, read_realtime()) tl.store(trace_xcd_ptr + pid, get_xcc_id()) - if pid < NUM_FETCH_SMS: + # Interleaved layout: [fetch0 | gemm0 | fetch1 | gemm1 | ...] + WGS_PER_STAGE: tl.constexpr = NUM_FETCH_SMS + GEMM_TILES_PER_STAGE + M_PER_STAGE: tl.constexpr = (NUM_M_TILES + NUM_FETCH_STAGES - 1) // NUM_FETCH_STAGES + + local_pid = pid % WGS_PER_STAGE + + if local_pid < NUM_FETCH_SMS: # ============================================================== - # FETCHER + # FETCHER — interleaved: stage determined by pid // WGS_PER_STAGE # ============================================================== + my_stage = pid // WGS_PER_STAGE + stage_pid = local_pid + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) src_view = iris.x.make_tensor_view(A_sharded, M, K_local, stride_am, stride_ak) - num_m_groups = (NUM_M_TILES + GROUP_SIZE_M - 1) // GROUP_SIZE_M tiles_per_m_group = NUM_FLAG_GROUPS_K * GROUP_SIZE_M - total_flag_groups = NUM_FLAG_GROUPS_K * NUM_M_TILES - for fg_idx in range(pid, total_flag_groups, NUM_FETCH_SMS): - m_group = fg_idx // tiles_per_m_group - within_group = fg_idx % tiles_per_m_group - k_flag_group = within_group // GROUP_SIZE_M - m_in_group = within_group % GROUP_SIZE_M - m_tile = m_group * GROUP_SIZE_M + m_in_group - m_tile = min(m_tile, NUM_M_TILES - 1) - k_block_start = k_flag_group * K_PER_FLAG + for const_stage in range(NUM_FETCH_STAGES): + if my_stage == const_stage: + stage_m_start = const_stage * M_PER_STAGE + stage_m_count = min(M_PER_STAGE, NUM_M_TILES - stage_m_start) + total_fg_stage = NUM_FLAG_GROUPS_K * stage_m_count - rm = m_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + for fg_idx in range(stage_pid, total_fg_stage, NUM_FETCH_SMS): + m_group = fg_idx // tiles_per_m_group + within_group = fg_idx % tiles_per_m_group + k_flag_group = within_group // GROUP_SIZE_M + m_in_group = within_group % GROUP_SIZE_M + m_tile = stage_m_start + m_group * GROUP_SIZE_M + m_in_group + m_tile = min(m_tile, NUM_M_TILES - 1) + k_block_start = k_flag_group * K_PER_FLAG - for k_off in range(K_PER_FLAG): - k_block_global = k_block_start + k_off + rm = m_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - src_rank_idx = k_block_global // NUM_K_BLOCKS_LOCAL - k_block_local = k_block_global % NUM_K_BLOCKS_LOCAL + for k_off in range(K_PER_FLAG): + k_block_global = k_block_start + k_off - pid_m_t = zero + m_tile - tile_k_t = zero + k_block_local - k_tile = iris.x.TileView(pid_m_t, tile_k_t, BLOCK_SIZE_M, BLOCK_SIZE_K) + src_rank_idx = k_block_global // NUM_K_BLOCKS_LOCAL + k_block_local = k_block_global % NUM_K_BLOCKS_LOCAL - rk = k_block_global * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - # Use parameterized strides for staged_a - staged_ptrs = staged_a + rm[:, None] * stride_sa_m + rk[None, :] * stride_sa_k + pid_m_t = zero + m_tile + tile_k_t = zero + k_block_local + k_tile = iris.x.TileView(pid_m_t, tile_k_t, BLOCK_SIZE_M, BLOCK_SIZE_K) + + rk = k_block_global * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + staged_ptrs = staged_a + rm[:, None] * stride_sa_m + rk[None, :] * stride_sa_k - for compile_rank in range(world_size): - if src_rank_idx == compile_rank: - a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx) - tl.store(staged_ptrs, a_tile,cache_modifier=".cg") + for compile_rank in range(world_size): + if src_rank_idx == compile_rank: + a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx) + tl.store(staged_ptrs, a_tile, cache_modifier=".cg") - flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group - tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") - #tl.store(flags_ptr + flag_idx, 1,cache_modifier=".wt") + flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group + tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") if TRACE: - tl.store(trace_wait_ptr + pid, zero.to(tl.int64),cache_modifier=".wt") - tl.store(trace_end_ptr + pid, read_realtime(),cache_modifier=".wt") + tl.store(trace_wait_ptr + pid, zero.to(tl.int64), cache_modifier=".wt") + tl.store(trace_end_ptr + pid, read_realtime(), cache_modifier=".wt") else: # ============================================================== - # GEMM + # GEMM — interleaved: stage determined by pid // WGS_PER_STAGE + # gemm_local_id indexes into this stage's M-tile range # ============================================================== - gemm_tile_id = pid - NUM_FETCH_SMS + my_stage = pid // WGS_PER_STAGE + gemm_local_id = local_pid - NUM_FETCH_SMS + stage_m_start = my_stage * M_PER_STAGE num_pid_in_group = GROUP_SIZE_M * NUM_TILES_N - group_id = gemm_tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M + group_id = gemm_local_id // num_pid_in_group + first_pid_m = stage_m_start + group_id * GROUP_SIZE_M + first_pid_m = min(first_pid_m, NUM_M_TILES - 1) group_sz = min(NUM_M_TILES - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((gemm_tile_id % num_pid_in_group) % group_sz) - pid_n = (gemm_tile_id % num_pid_in_group) // group_sz + pid_m = first_pid_m + ((gemm_local_id % num_pid_in_group) % group_sz) + pid_n = (gemm_local_id % num_pid_in_group) // group_sz + pid_m = min(pid_m, NUM_M_TILES - 1) rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) @@ -272,6 +290,7 @@ def all_gather_matmul_hbm_buffer( staged_a_layout: str = "k_contiguous", num_warps: Optional[int] = None, num_stages: Optional[int] = None, + num_fetch_stages: int = 1, trace: bool = False, ) -> FusedWorkspace: """ @@ -340,8 +359,14 @@ def all_gather_matmul_hbm_buffer( if num_fetch_sms is None: num_fetch_sms = max(1, num_sms // 10) assert 0 < num_fetch_sms + assert num_fetch_stages >= 1 - grid_size = num_fetch_sms + total_gemm_tiles + # Interleaved layout: [fetch0 | gemm0 | fetch1 | gemm1 | ...] + m_per_stage = (num_m_tiles + num_fetch_stages - 1) // num_fetch_stages + gemm_tiles_per_stage = m_per_stage * num_tiles_n + wgs_per_stage = num_fetch_sms + gemm_tiles_per_stage + total_fetch_wgs = num_fetch_sms * num_fetch_stages + grid_size = wgs_per_stage * num_fetch_stages # Trace buffers if trace: @@ -398,6 +423,8 @@ def all_gather_matmul_hbm_buffer( total_gather_tiles, use_bias, config.allow_tf32, + num_fetch_stages, + gemm_tiles_per_stage, trace_start, trace_end, trace_wait, @@ -418,8 +445,12 @@ def all_gather_matmul_hbm_buffer( "xcd": trace_xcd.cpu(), "grid_size": grid_size, "num_fetch_sms": num_fetch_sms, + "num_fetch_stages": num_fetch_stages, + "total_fetch_wgs": total_fetch_wgs, "num_m_tiles": num_m_tiles, "num_tiles_n": num_tiles_n, + "wgs_per_stage": wgs_per_stage, + "gemm_tiles_per_stage": gemm_tiles_per_stage, } return workspace From f7612bd17ee2e4b683e300c02189f5a3d0924a27 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 17 Feb 2026 00:51:39 +0000 Subject: [PATCH 17/31] Apply Ruff auto-fixes --- .../all_gather_matmul/benchmark_hbm_buffer.py | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py index b54aadc02..90aacc056 100644 --- a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -89,8 +89,8 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): # One color per fetch stage (blue palette), plus GEMM colors fetch_blues = ["#1565C0", "#42A5F5", "#90CAF9", "#BBDEFB"] - wait_color = "#F44336" # red - compute_color = "#4CAF50" # green + wait_color = "#F44336" # red + compute_color = "#4CAF50" # green for y_idx, wg_idx in enumerate(order): s = starts_us[wg_idx] @@ -101,16 +101,13 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): if role < n_stages: # Fetcher: color by stage c = fetch_blues[role % len(fetch_blues)] - ax.barh(y_idx, dur, left=s, height=0.8, color=c, - edgecolor="none", linewidth=0) + ax.barh(y_idx, dur, left=s, height=0.8, color=c, edgecolor="none", linewidth=0) else: # GEMM: split into wait (red) and compute (green) w = waits_us[wg_idx] comp = max(0, dur - w) - ax.barh(y_idx, w, left=s, height=0.8, color=wait_color, - edgecolor="none", linewidth=0) - ax.barh(y_idx, comp, left=s + w, height=0.8, color=compute_color, - edgecolor="none", linewidth=0) + ax.barh(y_idx, w, left=s, height=0.8, color=wait_color, edgecolor="none", linewidth=0) + ax.barh(y_idx, comp, left=s + w, height=0.8, color=compute_color, edgecolor="none", linewidth=0) # XCD annotations on the right margin xcd_set = sorted(set(xcds.tolist())) @@ -127,8 +124,7 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): ax.plot(x_max, y_idx, marker="s", markersize=1.5, color=xcd_cmap[xcd_id], clip_on=False) n_gemm = grid_size - total_fetch - stage_info = (f"{n_stages}x{n_fetch_per_stage}" if n_stages > 1 - else str(n_fetch_per_stage)) + stage_info = f"{n_stages}x{n_fetch_per_stage}" if n_stages > 1 else str(n_fetch_per_stage) ax.set_xlabel("Time (us)", fontsize=12) ax.set_ylabel("Workgroup (sorted by start time)", fontsize=12) ax.set_title( @@ -146,14 +142,9 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): # Legend legend_elements = [] for s_idx in range(min(n_stages, len(fetch_blues))): - legend_elements.append( - Line2D([0], [0], color=fetch_blues[s_idx], lw=6, - label=f"Fetch stage {s_idx}") - ) - legend_elements.append( - Line2D([0], [0], color=wait_color, lw=6, label="GEMM: waiting on data")) - legend_elements.append( - Line2D([0], [0], color=compute_color, lw=6, label="GEMM: compute")) + legend_elements.append(Line2D([0], [0], color=fetch_blues[s_idx], lw=6, label=f"Fetch stage {s_idx}")) + legend_elements.append(Line2D([0], [0], color=wait_color, lw=6, label="GEMM: waiting on data")) + legend_elements.append(Line2D([0], [0], color=compute_color, lw=6, label="GEMM: compute")) ax.legend(handles=legend_elements, loc="upper right", fontsize=10) # Summary stats @@ -237,7 +228,12 @@ def parse_args(): parser.add_argument("--k_per_flag", type=int, default=1, help="K-blocks per ready flag") parser.add_argument("--num_warps", type=int, default=None, help="Triton num_warps (auto if None)") parser.add_argument("--num_stages", type=int, default=None, help="Triton num_stages (auto if None)") - parser.add_argument("--num_fetch_stages", type=int, default=1, help="Number of fetch stages (1=all at once, 2=top/bottom half, etc.)") + parser.add_argument( + "--num_fetch_stages", + type=int, + default=1, + help="Number of fetch stages (1=all at once, 2=top/bottom half, etc.)", + ) parser.add_argument("--trace", action="store_true", help="Collect per-workgroup trace and save Gantt chart PNG") parser.add_argument("--trace_output", type=str, default="trace_gantt.png", help="Output path for trace plot") return vars(parser.parse_args()) @@ -493,11 +489,19 @@ def run_experiment(): workspace.locks.zero_() shmem.barrier() all_gather_matmul_hbm_buffer( - shmem, C, A_sharded, B, - config=config, async_op=False, workspace=workspace, - num_fetch_sms=num_fetch_sms, k_per_flag=k_per_flag, - num_warps=num_warps, num_stages=num_stages, - num_fetch_stages=num_fetch_stages, trace=True, + shmem, + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + num_fetch_sms=num_fetch_sms, + k_per_flag=k_per_flag, + num_warps=num_warps, + num_stages=num_stages, + num_fetch_stages=num_fetch_stages, + trace=True, ) torch.cuda.synchronize() shmem.barrier() From 51bccb5eeaa9b7ec1f6878accf9fd6d897a3e4c6 Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Tue, 17 Feb 2026 15:04:50 -0600 Subject: [PATCH 18/31] Updates to benchmark and kernel --- .../all_gather_matmul/benchmark_hbm_buffer.py | 35 +- .../ops/all_gather_matmul/tune_hbm_buffer.py | 576 ++++++++++++++++++ iris/iris.py | 20 +- iris/ops/all_gather_matmul_hbm_buffer.py | 59 +- 4 files changed, 657 insertions(+), 33 deletions(-) create mode 100644 benchmark/ops/all_gather_matmul/tune_hbm_buffer.py diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py index b54aadc02..ba6afaa70 100644 --- a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -60,7 +60,9 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): n_fetch_per_stage = trace_data["num_fetch_sms"] n_stages = trace_data.get("num_fetch_stages", 1) total_fetch = trace_data.get("total_fetch_wgs", n_fetch_per_stage) - wgs_per_stage = trace_data.get("wgs_per_stage", grid_size) + first_stage_fetch = trace_data.get("first_stage_fetch_sms", n_fetch_per_stage) + first_stage_size = trace_data.get("first_stage_size", grid_size) + rest_stage_size = trace_data.get("rest_stage_size", grid_size) # Convert to microseconds relative to earliest start t_min = starts.min() @@ -69,12 +71,19 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): waits_us = waits / TICKS_PER_US # Build role array: stage index for fetchers (0..S-1), S for GEMM - # Interleaved layout: [fetch0 | gemm0 | fetch1 | gemm1 | ...] + # Asymmetric layout: [fetch0 (P)] [gemm0] [fetch1 (F)] [gemm1] ... roles = np.empty(grid_size, dtype=np.int32) for i in range(grid_size): - stage = i // wgs_per_stage - local = i % wgs_per_stage - if local < n_fetch_per_stage: + if i < first_stage_size: + stage = 0 + local = i + fetch_thresh = first_stage_fetch + else: + adjusted = i - first_stage_size + stage = 1 + adjusted // rest_stage_size + local = adjusted % rest_stage_size + fetch_thresh = n_fetch_per_stage + if local < fetch_thresh: roles[i] = stage # fetcher for this stage else: roles[i] = n_stages # GEMM @@ -127,8 +136,12 @@ def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): ax.plot(x_max, y_idx, marker="s", markersize=1.5, color=xcd_cmap[xcd_id], clip_on=False) n_gemm = grid_size - total_fetch - stage_info = (f"{n_stages}x{n_fetch_per_stage}" if n_stages > 1 - else str(n_fetch_per_stage)) + if n_stages > 1 and first_stage_fetch != n_fetch_per_stage: + stage_info = f"{first_stage_fetch}+{n_stages - 1}x{n_fetch_per_stage}" + elif n_stages > 1: + stage_info = f"{n_stages}x{n_fetch_per_stage}" + else: + stage_info = str(first_stage_fetch) ax.set_xlabel("Time (us)", fontsize=12) ax.set_ylabel("Workgroup (sorted by start time)", fontsize=12) ax.set_title( @@ -238,6 +251,7 @@ def parse_args(): parser.add_argument("--num_warps", type=int, default=None, help="Triton num_warps (auto if None)") parser.add_argument("--num_stages", type=int, default=None, help="Triton num_stages (auto if None)") parser.add_argument("--num_fetch_stages", type=int, default=1, help="Number of fetch stages (1=all at once, 2=top/bottom half, etc.)") + parser.add_argument("--first_stage_fetch_sms", type=int, default=None, help="Fetcher WGs for stage 0 (fills first GPU wave; defaults to num_fetch_sms)") parser.add_argument("--trace", action="store_true", help="Collect per-workgroup trace and save Gantt chart PNG") parser.add_argument("--trace_output", type=str, default="trace_gantt.png", help="Output path for trace plot") return vars(parser.parse_args()) @@ -355,6 +369,7 @@ def _worker(args): num_warps = args["num_warps"] num_stages = args["num_stages"] num_fetch_stages = args["num_fetch_stages"] + first_stage_fetch_sms = args["first_stage_fetch_sms"] def run_experiment(): nonlocal total_ms, num_experiments @@ -374,6 +389,7 @@ def run_experiment(): num_warps=num_warps, num_stages=num_stages, num_fetch_stages=num_fetch_stages, + first_stage_fetch_sms=first_stage_fetch_sms, ) end_ev.record() num_experiments += 1 @@ -458,6 +474,7 @@ def run_experiment(): num_warps=num_warps, num_stages=num_stages, num_fetch_stages=num_fetch_stages, + first_stage_fetch_sms=first_stage_fetch_sms, ) torch.cuda.synchronize() t_end = time.perf_counter() @@ -497,7 +514,8 @@ def run_experiment(): config=config, async_op=False, workspace=workspace, num_fetch_sms=num_fetch_sms, k_per_flag=k_per_flag, num_warps=num_warps, num_stages=num_stages, - num_fetch_stages=num_fetch_stages, trace=True, + num_fetch_stages=num_fetch_stages, + first_stage_fetch_sms=first_stage_fetch_sms, trace=True, ) torch.cuda.synchronize() shmem.barrier() @@ -521,6 +539,7 @@ def run_experiment(): num_warps=num_warps, num_stages=num_stages, num_fetch_stages=num_fetch_stages, + first_stage_fetch_sms=first_stage_fetch_sms, trace=True, ) torch.cuda.synchronize() diff --git a/benchmark/ops/all_gather_matmul/tune_hbm_buffer.py b/benchmark/ops/all_gather_matmul/tune_hbm_buffer.py new file mode 100644 index 000000000..7a5243eba --- /dev/null +++ b/benchmark/ops/all_gather_matmul/tune_hbm_buffer.py @@ -0,0 +1,576 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Parameter tuning script for HBM-buffered all_gather_matmul. + +Sweeps parameters around a baseline configuration, collecting traces, TFLOPs, +PyTorch baseline, and validation for every configuration. + +This script does NOT modify benchmark_hbm_buffer.py — it invokes it via +``torchrun`` as a subprocess for each parameter set. + +Usage: + # Default one-at-a-time sweep (each param varied independently): + python benchmark/ops/all_gather_matmul/tune_hbm_buffer.py + + # Custom matrix size: + python benchmark/ops/all_gather_matmul/tune_hbm_buffer.py -m 8192 -n 4096 -k 131072 + + # Only sweep specific parameters: + python benchmark/ops/all_gather_matmul/tune_hbm_buffer.py --params num_fetch_sms k_per_flag + + # Full cartesian product (warning: combinatorial explosion): + python benchmark/ops/all_gather_matmul/tune_hbm_buffer.py --mode full + + # Dry run — just print what would be tested: + python benchmark/ops/all_gather_matmul/tune_hbm_buffer.py --dry_run +""" + +import argparse +import json +import os +import re +import subprocess +import sys +import time +from datetime import datetime +from itertools import product +from pathlib import Path + +# ───────────────────────────────────────────────────────────────────────────── +# Baseline configuration — the centre point of every sweep. +# Edit these to match your current best-known config. +# ───────────────────────────────────────────────────────────────────────────── +BASELINE = { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 4, + "num_fetch_sms": 64, + "k_per_flag": 64, + "num_warps": 8, + "num_fetch_stages": 4, + "first_stage_fetch_sms": 304, +} + +# ───────────────────────────────────────────────────────────────────────────── +# Sweep ranges — values to try for each parameter. +# In ``oneatatime`` mode only one parameter deviates from the baseline at a +# time; in ``full`` mode the cartesian product is taken (use with care). +# ───────────────────────────────────────────────────────────────────────────── +SWEEP_RANGES = { + "block_size_m": [64, 128, 256], + "block_size_n": [64, 128, 256], + "block_size_k": [64], + "group_size_m": [1, 2, 4, 8], + "num_fetch_sms": [64, 128, 192, 256], + "k_per_flag": [16, 32, 64, 128], + "num_warps": [4, 8], + "num_fetch_stages": [2, 4, 8], + "first_stage_fetch_sms": [128, 192, 256, 304], +} + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers +# ───────────────────────────────────────────────────────────────────────────── + +def make_label(cfg): + """Short human-readable label for a config.""" + parts = [ + f"bm{cfg['block_size_m']}", + f"bn{cfg['block_size_n']}", + f"bk{cfg['block_size_k']}", + f"gm{cfg['group_size_m']}", + f"nf{cfg['num_fetch_sms']}", + f"kpf{cfg['k_per_flag']}", + f"nw{cfg['num_warps']}", + f"fs{cfg['num_fetch_stages']}", + ] + if cfg["num_fetch_stages"] > 1: + parts.append(f"fsf{cfg['first_stage_fetch_sms']}") + return "_".join(parts) + + +def validate_config(cfg, M, N, K, world_size=8): + """Return a list of error strings; empty list means valid.""" + errors = [] + K_local = K // world_size + bm, bn, bk = cfg["block_size_m"], cfg["block_size_n"], cfg["block_size_k"] + kpf = cfg["k_per_flag"] + + if M % bm != 0: + errors.append(f"M={M} not divisible by block_size_m={bm}") + if N % bn != 0: + errors.append(f"N={N} not divisible by block_size_n={bn}") + if K % bk != 0: + errors.append(f"K={K} not divisible by block_size_k={bk}") + if K_local % bk != 0: + errors.append(f"K_local={K_local} not divisible by block_size_k={bk}") + + num_k_blocks = K // bk + if num_k_blocks % kpf != 0: + errors.append(f"num_k_blocks={num_k_blocks} not divisible by k_per_flag={kpf}") + + if cfg["num_warps"] not in (1, 2, 4, 8, 16): + errors.append(f"num_warps={cfg['num_warps']} must be a power of 2 in [1..16]") + + return errors + + +def build_command(cfg, M, N, K, trace_path, nproc=8, + validate=True, benchmark=True, benchmark_pytorch=False): + """Build the ``torchrun`` CLI for one configuration.""" + cmd = [ + "torchrun", "--nproc_per_node", str(nproc), + "benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py", + "-m", str(M), + "-n", str(N), + "-k", str(K), + "--block_size_m", str(cfg["block_size_m"]), + "--block_size_n", str(cfg["block_size_n"]), + "--block_size_k", str(cfg["block_size_k"]), + "--group_size_m", str(cfg["group_size_m"]), + "--num_fetch_sms", str(cfg["num_fetch_sms"]), + "--k_per_flag", str(cfg["k_per_flag"]), + "--num_warps", str(cfg["num_warps"]), + "--num_fetch_stages", str(cfg["num_fetch_stages"]), + ] + + if cfg["num_fetch_stages"] > 1 and cfg.get("first_stage_fetch_sms") is not None: + cmd.extend(["--first_stage_fetch_sms", str(cfg["first_stage_fetch_sms"])]) + + if validate: + cmd.append("-v") + if benchmark: + cmd.append("-b") + if benchmark_pytorch: + cmd.append("--benchmark_pytorch") + + cmd.extend(["--trace", "--trace_output", trace_path]) + return cmd + + +# ── Output parsing ──────────────────────────────────────────────────────────── + +_RE_IRIS = re.compile( + r"HBM-Buffer\s*\([^)]*\):\s*([\d.]+)\s*ms,\s*([\d.]+)\s*TFLOPS,\s*([\d.]+)\s*GB/s" +) +_RE_PYTORCH = re.compile( + r"PyTorch\s*\([^)]*\):\s*([\d.]+)\s*ms,\s*([\d.]+)\s*TFLOPS,\s*([\d.]+)\s*GB/s" +) +_RE_SPEEDUP = re.compile(r"Speedup.*?:\s*([\d.]+)x") +_RE_VALID_FAIL = re.compile(r"Validation FAILED.*?max diff:\s*([\d.eE+-]+)") + + +def parse_output(output): + """Extract metrics from benchmark stdout+stderr.""" + result = { + "iris_ms": None, + "iris_tflops": None, + "iris_bw_gbps": None, + "pytorch_ms": None, + "pytorch_tflops": None, + "pytorch_bw_gbps": None, + "validation": None, + "speedup": None, + } + + m = _RE_IRIS.search(output) + if m: + result["iris_ms"] = float(m.group(1)) + result["iris_tflops"] = float(m.group(2)) + result["iris_bw_gbps"] = float(m.group(3)) + + m = _RE_PYTORCH.search(output) + if m: + result["pytorch_ms"] = float(m.group(1)) + result["pytorch_tflops"] = float(m.group(2)) + result["pytorch_bw_gbps"] = float(m.group(3)) + + if "Validation PASSED" in output: + result["validation"] = "PASSED" + elif "Validation FAILED" in output: + fm = _RE_VALID_FAIL.search(output) + result["validation"] = f"FAILED (diff={fm.group(1)})" if fm else "FAILED" + + m = _RE_SPEEDUP.search(output) + if m: + result["speedup"] = float(m.group(1)) + + return result + + +# ── Sweep generation ────────────────────────────────────────────────────────── + +def generate_configs(baseline, sweep_ranges, mode="oneatatime", params=None): + """ + Generate the list of configs to evaluate. + + Args: + baseline: dict of default values + sweep_ranges: dict mapping param name -> list of values + mode: "oneatatime" or "full" + params: optional list of param names to sweep (None = all) + """ + configs = [] + seen = set() + + def _add(cfg): + label = make_label(cfg) + if label not in seen: + configs.append(dict(cfg)) + seen.add(label) + + # Always include baseline first + _add(baseline) + + active_params = params if params else list(sweep_ranges.keys()) + + if mode == "oneatatime": + for param in active_params: + if param not in sweep_ranges: + print(f" WARNING: unknown param '{param}', skipping") + continue + for val in sweep_ranges[param]: + cfg = dict(baseline) + cfg[param] = val + # When num_fetch_stages == 1, first_stage_fetch_sms is irrelevant + if cfg["num_fetch_stages"] == 1: + cfg["first_stage_fetch_sms"] = cfg["num_fetch_sms"] + _add(cfg) + + elif mode == "full": + active_ranges = {p: sweep_ranges[p] for p in active_params if p in sweep_ranges} + names = list(active_ranges.keys()) + values = [active_ranges[n] for n in names] + for combo in product(*values): + cfg = dict(baseline) + for n, v in zip(names, combo): + cfg[n] = v + if cfg["num_fetch_stages"] == 1: + cfg["first_stage_fetch_sms"] = cfg["num_fetch_sms"] + _add(cfg) + + return configs + + +# ── Main ────────────────────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser( + description="Parameter tuning for HBM-buffered all_gather_matmul.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + # ── Matrix dimensions ──────────────────────────────────────────────── + parser.add_argument("-m", type=int, default=16384, help="M dimension") + parser.add_argument("-n", type=int, default=2048, help="N dimension") + parser.add_argument("-k", type=int, default=131072, help="K dimension (total)") + parser.add_argument("--nproc", type=int, default=8, help="Number of GPUs") + + # ── Baseline overrides (non-swept params use these values) ──────── + parser.add_argument("--block_size_m", type=int, default=None, + help=f"Baseline block_size_m (default: {BASELINE['block_size_m']})") + parser.add_argument("--block_size_n", type=int, default=None, + help=f"Baseline block_size_n (default: {BASELINE['block_size_n']})") + parser.add_argument("--block_size_k", type=int, default=None, + help=f"Baseline block_size_k (default: {BASELINE['block_size_k']})") + parser.add_argument("--group_size_m", type=int, default=None, + help=f"Baseline group_size_m (default: {BASELINE['group_size_m']})") + parser.add_argument("--num_fetch_sms", type=int, default=None, + help=f"Baseline num_fetch_sms (default: {BASELINE['num_fetch_sms']})") + parser.add_argument("--k_per_flag", type=int, default=None, + help=f"Baseline k_per_flag (default: {BASELINE['k_per_flag']})") + parser.add_argument("--num_warps", type=int, default=None, + help=f"Baseline num_warps (default: {BASELINE['num_warps']})") + parser.add_argument("--num_fetch_stages", type=int, default=None, + help=f"Baseline num_fetch_stages (default: {BASELINE['num_fetch_stages']})") + parser.add_argument("--first_stage_fetch_sms", type=int, default=None, + help=f"Baseline first_stage_fetch_sms (default: {BASELINE['first_stage_fetch_sms']})") + + # ── Sweep control ───────────────────────────────────────────────── + parser.add_argument( + "--mode", choices=["oneatatime", "full"], default="oneatatime", + help="'oneatatime' varies one param at a time; 'full' = cartesian product", + ) + parser.add_argument( + "--params", nargs="+", default=None, + help="Only sweep these parameters (default: all). " + "Choices: " + ", ".join(SWEEP_RANGES.keys()), + ) + parser.add_argument("--output_dir", type=str, default=None, + help="Output directory (auto-generated if unset)") + parser.add_argument("--dry_run", action="store_true", + help="Print configs and exit without running") + parser.add_argument("--skip_validation", action="store_true", + help="Skip validation (faster, no correctness check)") + parser.add_argument("--timeout", type=int, default=600, + help="Per-config timeout in seconds (default: 600)") + + args = parser.parse_args() + M, N, K = args.m, args.n, args.k + + # Apply any CLI baseline overrides + baseline = dict(BASELINE) + for key in baseline: + cli_val = getattr(args, key, None) + if cli_val is not None: + baseline[key] = cli_val + + # Output directory + if args.output_dir: + output_dir = Path(args.output_dir) + else: + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"benchmark/ops/all_gather_matmul/tune_results_{ts}") + output_dir.mkdir(parents=True, exist_ok=True) + trace_dir = output_dir / "traces" + trace_dir.mkdir(exist_ok=True) + + # Generate configs + configs = generate_configs(baseline, SWEEP_RANGES, + mode=args.mode, params=args.params) + + # Pre-validate all configs + valid_configs = [] + skipped = [] + for cfg in configs: + errs = validate_config(cfg, M, N, K, world_size=args.nproc) + if errs: + skipped.append((cfg, errs)) + else: + valid_configs.append(cfg) + + # Banner + print(f"\n{'='*100}") + print(f" HBM-Buffer All-Gather MatMul — Parameter Tuning") + print(f" M={M} N={N} K={K} nproc={args.nproc} mode={args.mode}") + print(f" Baseline: {make_label(baseline)}") + print(f" Configs to run: {len(valid_configs)} (skipped: {len(skipped)})") + print(f" Output dir: {output_dir}") + print(f" Validation: {'OFF' if args.skip_validation else 'ON'}") + print(f"{'='*100}") + + if skipped: + print(f"\n Skipped (invalid for M={M}, N={N}, K={K}):") + for cfg, errs in skipped: + print(f" {make_label(cfg)}: {'; '.join(errs)}") + + if args.dry_run: + print(f"\n Configs that would be run:") + for i, cfg in enumerate(valid_configs): + label = make_label(cfg) + is_baseline = (cfg == baseline) + tag = " [BASELINE]" if is_baseline else "" + print(f" [{i+1:>3}] {label}{tag}") + print(f"\n Total: {len(valid_configs)} configs") + return + + # ── Run sweep ───────────────────────────────────────────────────────── + results = [] + pytorch_baseline = None + env = os.environ.copy() + env["HSA_NO_SCRATCH_RECLAIM"] = "1" + + total_start = time.time() + + for i, cfg in enumerate(valid_configs): + label = make_label(cfg) + trace_path = str(trace_dir / f"trace_{label}.png") + is_first = (i == 0) + + sep = "-" * 80 + print(f"\n{sep}") + print(f"[{i+1}/{len(valid_configs)}] {label}") + if is_first: + print(f" (includes PyTorch baseline benchmark)") + print(sep) + + cmd = build_command( + cfg, M, N, K, trace_path, + nproc=args.nproc, + validate=not args.skip_validation, + benchmark=True, + benchmark_pytorch=is_first, + ) + cmd_str = " ".join(cmd) + print(f" $ HSA_NO_SCRATCH_RECLAIM=1 {cmd_str}") + + t0 = time.time() + try: + proc = subprocess.run( + cmd, env=env, + capture_output=True, text=True, + timeout=args.timeout, + ) + elapsed = time.time() - t0 + full_output = proc.stdout + "\n" + proc.stderr + + parsed = parse_output(full_output) + + # Capture PyTorch baseline on first run + if is_first and parsed["pytorch_tflops"] is not None: + pytorch_baseline = { + "ms": parsed["pytorch_ms"], + "tflops": parsed["pytorch_tflops"], + "bw_gbps": parsed["pytorch_bw_gbps"], + } + + trace_exists = os.path.exists(trace_path) + results.append({ + "label": label, + "config": cfg, + "iris_ms": parsed["iris_ms"], + "iris_tflops": parsed["iris_tflops"], + "iris_bw_gbps": parsed["iris_bw_gbps"], + "validation": parsed["validation"], + "trace_path": trace_path if trace_exists else None, + "elapsed_s": round(elapsed, 1), + "returncode": proc.returncode, + }) + + # Print summary line + parts = [] + if parsed["iris_tflops"] is not None: + parts.append(f"{parsed['iris_tflops']:.2f} TFLOPS") + parts.append(f"{parsed['iris_ms']:.3f} ms") + if parsed["iris_bw_gbps"] is not None: + parts.append(f"{parsed['iris_bw_gbps']:.1f} GB/s") + if parsed["validation"]: + parts.append(f"valid={parsed['validation']}") + if trace_exists: + parts.append(f"trace=OK") + else: + parts.append(f"trace=MISSING") + if proc.returncode != 0: + parts.append(f"EXIT={proc.returncode}") + print(f" => {' | '.join(parts)} ({elapsed:.0f}s)") + + if is_first and pytorch_baseline: + print(f" => PyTorch baseline: {pytorch_baseline['tflops']:.2f} TFLOPS" + f" {pytorch_baseline['ms']:.3f} ms") + + # Save full log for debugging + log_path = output_dir / f"log_{label}.txt" + with open(log_path, "w") as f: + f.write(f"COMMAND: HSA_NO_SCRATCH_RECLAIM=1 {cmd_str}\n") + f.write(f"EXIT CODE: {proc.returncode}\n") + f.write(f"ELAPSED: {elapsed:.1f}s\n\n") + f.write("=== STDOUT ===\n") + f.write(proc.stdout) + f.write("\n=== STDERR ===\n") + f.write(proc.stderr) + + except subprocess.TimeoutExpired: + elapsed = time.time() - t0 + results.append({ + "label": label, + "config": cfg, + "iris_ms": None, + "iris_tflops": None, + "iris_bw_gbps": None, + "validation": "TIMEOUT", + "trace_path": None, + "elapsed_s": round(elapsed, 1), + "returncode": -1, + }) + print(f" => TIMEOUT after {args.timeout}s") + + except Exception as e: + elapsed = time.time() - t0 + results.append({ + "label": label, + "config": cfg, + "iris_ms": None, + "iris_tflops": None, + "iris_bw_gbps": None, + "validation": f"ERROR: {e}", + "trace_path": None, + "elapsed_s": round(elapsed, 1), + "returncode": -1, + }) + print(f" => ERROR: {e}") + + total_elapsed = time.time() - total_start + + # ── Summary table ───────────────────────────────────────────────────── + W = 130 + print(f"\n\n{'='*W}") + print(f" TUNING RESULTS | M={M} N={N} K={K} | nproc={args.nproc} | " + f"{len(valid_configs)} configs in {total_elapsed:.0f}s") + if pytorch_baseline: + print(f" PyTorch baseline: {pytorch_baseline['ms']:.3f} ms | " + f"{pytorch_baseline['tflops']:.2f} TFLOPS | " + f"{pytorch_baseline['bw_gbps']:.1f} GB/s") + print(f"{'='*W}") + + col_label_w = 65 + print(f" {'#':>3} {'Configuration':<{col_label_w}} {'ms':>8} {'TFLOPS':>8} " + f"{'vs PT':>7} {'Valid':>8} {'Trace':>5}") + print(f" {'-'*(W-4)}") + + for i, r in enumerate(results): + ms_s = f"{r['iris_ms']:.3f}" if r["iris_ms"] is not None else "--" + tf_s = f"{r['iris_tflops']:.2f}" if r["iris_tflops"] is not None else "--" + + if pytorch_baseline and r["iris_tflops"] is not None and pytorch_baseline["tflops"] > 0: + vs_pt = f"{r['iris_tflops'] / pytorch_baseline['tflops']:.2f}x" + else: + vs_pt = "--" + + valid_s = (r["validation"] or "--")[:8] + trace_s = "Y" if r.get("trace_path") else "N" + + tag = " *" if (r["iris_tflops"] is not None and + r["iris_tflops"] == max((x["iris_tflops"] for x in results + if x["iris_tflops"] is not None), default=0)) else "" + + print(f" {i+1:>3} {r['label']:<{col_label_w}} {ms_s:>8} {tf_s:>8} " + f"{vs_pt:>7} {valid_s:>8} {trace_s:>5}{tag}") + + # Best config + valid_results = [r for r in results if r["iris_tflops"] is not None] + if valid_results: + best = max(valid_results, key=lambda r: r["iris_tflops"]) + worst = min(valid_results, key=lambda r: r["iris_tflops"]) + print(f"\n {'BEST':>6}: {best['label']}") + print(f" {best['iris_ms']:.3f} ms | {best['iris_tflops']:.2f} TFLOPS | " + f"valid={best['validation']}") + if pytorch_baseline and pytorch_baseline["tflops"] > 0: + print(f" {best['iris_tflops'] / pytorch_baseline['tflops']:.2f}x vs PyTorch") + if best.get("trace_path"): + print(f" trace: {best['trace_path']}") + print(f" {'WORST':>6}: {worst['label']}") + print(f" {worst['iris_ms']:.3f} ms | {worst['iris_tflops']:.2f} TFLOPS") + if best["iris_tflops"] > 0 and worst["iris_tflops"] > 0: + print(f" SPREAD: {best['iris_tflops'] / worst['iris_tflops']:.2f}x " + f"({worst['iris_tflops']:.2f} → {best['iris_tflops']:.2f} TFLOPS)") + + print(f"{'='*W}") + + # ── Save results JSON ───────────────────────────────────────────────── + results_path = output_dir / "results.json" + with open(results_path, "w") as f: + json.dump({ + "meta": { + "M": M, "N": N, "K": K, + "nproc": args.nproc, + "mode": args.mode, + "baseline": baseline, + "sweep_ranges": SWEEP_RANGES, + "timestamp": datetime.now().isoformat(), + "total_elapsed_s": round(total_elapsed, 1), + "pytorch_baseline": pytorch_baseline, + }, + "results": results, + }, f, indent=2, default=str) + + print(f"\n Results JSON : {results_path}") + print(f" Trace PNGs : {trace_dir}/") + print(f" Per-run logs : {output_dir}/log_*.txt") + print() + + +if __name__ == "__main__": + main() diff --git a/iris/iris.py b/iris/iris.py index 94cd0ae6e..c283abf24 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1780,10 +1780,6 @@ def reduce_scatter(self, output_tensor, input_tensor, op=None, group=None, async @triton.jit def __translate(ptr, from_rank, to_rank, heap_bases): - """ - Basic pointer translation without vectorization hints. - Used for atomic operations which may receive scalar pointers. - """ from_base = tl.load(heap_bases + from_rank) to_base = tl.load(heap_bases + to_rank) # convert to int to compute difference @@ -1797,9 +1793,21 @@ def __translate(ptr, from_rank, to_rank, heap_bases): # Cast to_base back to pointer type translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) + # Optimization to vectorize the load/store + # We can't do this in general because we don't know the shape of the tensor or block sizes + # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) + + # 0 You can use this if your block sizes are multiples of 32. + # Largest vectorized load instruction is dwordx4 (128-bits) + translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) + translated_ptr = tl.max_contiguous(translated_ptr, (1, 32)) + + # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) + # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) return translated_ptr + @triton.jit def __translate_block_2d(ptr, from_rank, to_rank, heap_bases): """ @@ -2029,7 +2037,7 @@ def load(self, pointer, from_rank, mask=None): Example: >>> data = ctx.load(buffer + offsets, from_rank=1, mask=mask) """ - translated_ptr = self._translate_block_2d(pointer, self.rank, from_rank) + translated_ptr = self.__translate(pointer, self.rank, from_rank) result = tl.load(translated_ptr, mask=mask) return result @@ -2055,7 +2063,7 @@ def store(self, pointer, value, to_rank, mask=None): Example: >>> ctx.store(buffer + offsets, values, to_rank=1, mask=mask) """ - translated_ptr = self._translate_block_2d(pointer, self.rank, to_rank) + translated_ptr = self.__translate(pointer, self.rank, to_rank) tl.store(translated_ptr, value, mask=mask) @triton.jit diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index 4fdc1b067..e9a5b6b0b 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -61,6 +61,7 @@ def _hbm_buffer_all_gather_matmul_kernel( ALLOW_TF32: tl.constexpr, NUM_FETCH_STAGES: tl.constexpr, GEMM_TILES_PER_STAGE: tl.constexpr, + FIRST_STAGE_FETCH_SMS: tl.constexpr, trace_start_ptr, trace_end_ptr, trace_wait_ptr, @@ -71,21 +72,34 @@ def _hbm_buffer_all_gather_matmul_kernel( acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 zero = tl.program_id(0) * 0 + if TRACE: tl.store(trace_start_ptr + pid, read_realtime()) tl.store(trace_xcd_ptr + pid, get_xcc_id()) - # Interleaved layout: [fetch0 | gemm0 | fetch1 | gemm1 | ...] - WGS_PER_STAGE: tl.constexpr = NUM_FETCH_SMS + GEMM_TILES_PER_STAGE + # Interleaved layout with asymmetric first stage: + # [fetch0 (P)] [gemm0 (G)] [fetch1 (F)] [gemm1 (G)] ... + # P = FIRST_STAGE_FETCH_SMS, F = NUM_FETCH_SMS, G = GEMM_TILES_PER_STAGE + FIRST_STAGE_SIZE: tl.constexpr = FIRST_STAGE_FETCH_SMS + GEMM_TILES_PER_STAGE + REST_STAGE_SIZE: tl.constexpr = NUM_FETCH_SMS + GEMM_TILES_PER_STAGE M_PER_STAGE: tl.constexpr = (NUM_M_TILES + NUM_FETCH_STAGES - 1) // NUM_FETCH_STAGES - local_pid = pid % WGS_PER_STAGE + # Two-phase decode: stage 0 has a different size than subsequent stages + if pid < FIRST_STAGE_SIZE: + my_stage = zero + local_pid = pid + fetch_threshold = zero + FIRST_STAGE_FETCH_SMS + else: + adjusted = pid - FIRST_STAGE_SIZE + my_stage = 1 + adjusted // REST_STAGE_SIZE + local_pid = adjusted % REST_STAGE_SIZE + fetch_threshold = zero + NUM_FETCH_SMS - if local_pid < NUM_FETCH_SMS: + if local_pid < fetch_threshold: # ============================================================== - # FETCHER — interleaved: stage determined by pid // WGS_PER_STAGE + # FETCHER — stage 0 uses FIRST_STAGE_FETCH_SMS WGs, + # later stages use NUM_FETCH_SMS WGs # ============================================================== - my_stage = pid // WGS_PER_STAGE stage_pid = local_pid ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) @@ -95,11 +109,12 @@ def _hbm_buffer_all_gather_matmul_kernel( for const_stage in range(NUM_FETCH_STAGES): if my_stage == const_stage: + stage_fetch_sms = FIRST_STAGE_FETCH_SMS if const_stage == 0 else NUM_FETCH_SMS stage_m_start = const_stage * M_PER_STAGE stage_m_count = min(M_PER_STAGE, NUM_M_TILES - stage_m_start) total_fg_stage = NUM_FLAG_GROUPS_K * stage_m_count - for fg_idx in range(stage_pid, total_fg_stage, NUM_FETCH_SMS): + for fg_idx in range(stage_pid, total_fg_stage, stage_fetch_sms): m_group = fg_idx // tiles_per_m_group within_group = fg_idx % tiles_per_m_group k_flag_group = within_group // GROUP_SIZE_M @@ -129,21 +144,18 @@ def _hbm_buffer_all_gather_matmul_kernel( tl.store(staged_ptrs, a_tile, cache_modifier=".cg") flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group - tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") + #tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") + tl.store(flags_ptr + flag_idx, 1, cache_modifier=".wt") if TRACE: tl.store(trace_wait_ptr + pid, zero.to(tl.int64), cache_modifier=".wt") tl.store(trace_end_ptr + pid, read_realtime(), cache_modifier=".wt") - tl.store(trace_wait_ptr + pid, zero.to(tl.int64), cache_modifier=".wt") - tl.store(trace_end_ptr + pid, read_realtime(), cache_modifier=".wt") else: # ============================================================== - # GEMM — interleaved: stage determined by pid // WGS_PER_STAGE - # gemm_local_id indexes into this stage's M-tile range + # GEMM — gemm_local_id indexes into this stage's M-tile range # ============================================================== - my_stage = pid // WGS_PER_STAGE - gemm_local_id = local_pid - NUM_FETCH_SMS + gemm_local_id = local_pid - fetch_threshold stage_m_start = my_stage * M_PER_STAGE num_pid_in_group = GROUP_SIZE_M * NUM_TILES_N @@ -293,6 +305,7 @@ def all_gather_matmul_hbm_buffer( num_warps: Optional[int] = None, num_stages: Optional[int] = None, num_fetch_stages: int = 1, + first_stage_fetch_sms: Optional[int] = None, trace: bool = False, ) -> FusedWorkspace: """ @@ -363,12 +376,17 @@ def all_gather_matmul_hbm_buffer( assert 0 < num_fetch_sms assert num_fetch_stages >= 1 - # Interleaved layout: [fetch0 | gemm0 | fetch1 | gemm1 | ...] + # First stage can use more fetcher WGs to fill the first GPU wave + if first_stage_fetch_sms is None: + first_stage_fetch_sms = num_fetch_sms + + # Interleaved layout: [fetch0 (P)] [gemm0 (G)] [fetch1 (F)] [gemm1 (G)] ... m_per_stage = (num_m_tiles + num_fetch_stages - 1) // num_fetch_stages gemm_tiles_per_stage = m_per_stage * num_tiles_n - wgs_per_stage = num_fetch_sms + gemm_tiles_per_stage - total_fetch_wgs = num_fetch_sms * num_fetch_stages - grid_size = wgs_per_stage * num_fetch_stages + first_stage_size = first_stage_fetch_sms + gemm_tiles_per_stage + rest_stage_size = num_fetch_sms + gemm_tiles_per_stage + total_fetch_wgs = first_stage_fetch_sms + num_fetch_sms * max(0, num_fetch_stages - 1) + grid_size = first_stage_size + rest_stage_size * max(0, num_fetch_stages - 1) # Trace buffers if trace: @@ -427,6 +445,7 @@ def all_gather_matmul_hbm_buffer( config.allow_tf32, num_fetch_stages, gemm_tiles_per_stage, + first_stage_fetch_sms, trace_start, trace_end, trace_wait, @@ -451,7 +470,9 @@ def all_gather_matmul_hbm_buffer( "total_fetch_wgs": total_fetch_wgs, "num_m_tiles": num_m_tiles, "num_tiles_n": num_tiles_n, - "wgs_per_stage": wgs_per_stage, + "first_stage_fetch_sms": first_stage_fetch_sms, + "first_stage_size": first_stage_size, + "rest_stage_size": rest_stage_size, "gemm_tiles_per_stage": gemm_tiles_per_stage, } From cbe2aff3d8bd08bef7cd5ce31fce7d1bd298dcad Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 17 Feb 2026 21:05:48 +0000 Subject: [PATCH 19/31] Apply Ruff auto-fixes --- .../all_gather_matmul/benchmark_hbm_buffer.py | 32 +- .../ops/all_gather_matmul/tune_hbm_buffer.py | 342 ++++++++++-------- iris/iris.py | 1 - iris/ops/all_gather_matmul_hbm_buffer.py | 3 +- 4 files changed, 226 insertions(+), 152 deletions(-) diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py index 9b5e6c265..666a37a79 100644 --- a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -242,8 +242,18 @@ def parse_args(): parser.add_argument("--k_per_flag", type=int, default=1, help="K-blocks per ready flag") parser.add_argument("--num_warps", type=int, default=None, help="Triton num_warps (auto if None)") parser.add_argument("--num_stages", type=int, default=None, help="Triton num_stages (auto if None)") - parser.add_argument("--num_fetch_stages", type=int, default=1, help="Number of fetch stages (1=all at once, 2=top/bottom half, etc.)") - parser.add_argument("--first_stage_fetch_sms", type=int, default=None, help="Fetcher WGs for stage 0 (fills first GPU wave; defaults to num_fetch_sms)") + parser.add_argument( + "--num_fetch_stages", + type=int, + default=1, + help="Number of fetch stages (1=all at once, 2=top/bottom half, etc.)", + ) + parser.add_argument( + "--first_stage_fetch_sms", + type=int, + default=None, + help="Fetcher WGs for stage 0 (fills first GPU wave; defaults to num_fetch_sms)", + ) parser.add_argument("--trace", action="store_true", help="Collect per-workgroup trace and save Gantt chart PNG") parser.add_argument("--trace_output", type=str, default="trace_gantt.png", help="Output path for trace plot") return vars(parser.parse_args()) @@ -502,12 +512,20 @@ def run_experiment(): workspace.locks.zero_() shmem.barrier() all_gather_matmul_hbm_buffer( - shmem, C, A_sharded, B, - config=config, async_op=False, workspace=workspace, - num_fetch_sms=num_fetch_sms, k_per_flag=k_per_flag, - num_warps=num_warps, num_stages=num_stages, + shmem, + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + num_fetch_sms=num_fetch_sms, + k_per_flag=k_per_flag, + num_warps=num_warps, + num_stages=num_stages, num_fetch_stages=num_fetch_stages, - first_stage_fetch_sms=first_stage_fetch_sms, trace=True, + first_stage_fetch_sms=first_stage_fetch_sms, + trace=True, ) torch.cuda.synchronize() shmem.barrier() diff --git a/benchmark/ops/all_gather_matmul/tune_hbm_buffer.py b/benchmark/ops/all_gather_matmul/tune_hbm_buffer.py index 7a5243eba..db9cc56f2 100644 --- a/benchmark/ops/all_gather_matmul/tune_hbm_buffer.py +++ b/benchmark/ops/all_gather_matmul/tune_hbm_buffer.py @@ -33,7 +33,6 @@ import os import re import subprocess -import sys import time from datetime import datetime from itertools import product @@ -61,14 +60,14 @@ # time; in ``full`` mode the cartesian product is taken (use with care). # ───────────────────────────────────────────────────────────────────────────── SWEEP_RANGES = { - "block_size_m": [64, 128, 256], - "block_size_n": [64, 128, 256], - "block_size_k": [64], - "group_size_m": [1, 2, 4, 8], - "num_fetch_sms": [64, 128, 192, 256], - "k_per_flag": [16, 32, 64, 128], - "num_warps": [4, 8], - "num_fetch_stages": [2, 4, 8], + "block_size_m": [64, 128, 256], + "block_size_n": [64, 128, 256], + "block_size_k": [64], + "group_size_m": [1, 2, 4, 8], + "num_fetch_sms": [64, 128, 192, 256], + "k_per_flag": [16, 32, 64, 128], + "num_warps": [4, 8], + "num_fetch_stages": [2, 4, 8], "first_stage_fetch_sms": [128, 192, 256, 304], } @@ -76,6 +75,7 @@ # Helpers # ───────────────────────────────────────────────────────────────────────────── + def make_label(cfg): """Short human-readable label for a config.""" parts = [ @@ -119,23 +119,35 @@ def validate_config(cfg, M, N, K, world_size=8): return errors -def build_command(cfg, M, N, K, trace_path, nproc=8, - validate=True, benchmark=True, benchmark_pytorch=False): +def build_command(cfg, M, N, K, trace_path, nproc=8, validate=True, benchmark=True, benchmark_pytorch=False): """Build the ``torchrun`` CLI for one configuration.""" cmd = [ - "torchrun", "--nproc_per_node", str(nproc), + "torchrun", + "--nproc_per_node", + str(nproc), "benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py", - "-m", str(M), - "-n", str(N), - "-k", str(K), - "--block_size_m", str(cfg["block_size_m"]), - "--block_size_n", str(cfg["block_size_n"]), - "--block_size_k", str(cfg["block_size_k"]), - "--group_size_m", str(cfg["group_size_m"]), - "--num_fetch_sms", str(cfg["num_fetch_sms"]), - "--k_per_flag", str(cfg["k_per_flag"]), - "--num_warps", str(cfg["num_warps"]), - "--num_fetch_stages", str(cfg["num_fetch_stages"]), + "-m", + str(M), + "-n", + str(N), + "-k", + str(K), + "--block_size_m", + str(cfg["block_size_m"]), + "--block_size_n", + str(cfg["block_size_n"]), + "--block_size_k", + str(cfg["block_size_k"]), + "--group_size_m", + str(cfg["group_size_m"]), + "--num_fetch_sms", + str(cfg["num_fetch_sms"]), + "--k_per_flag", + str(cfg["k_per_flag"]), + "--num_warps", + str(cfg["num_warps"]), + "--num_fetch_stages", + str(cfg["num_fetch_stages"]), ] if cfg["num_fetch_stages"] > 1 and cfg.get("first_stage_fetch_sms") is not None: @@ -154,12 +166,8 @@ def build_command(cfg, M, N, K, trace_path, nproc=8, # ── Output parsing ──────────────────────────────────────────────────────────── -_RE_IRIS = re.compile( - r"HBM-Buffer\s*\([^)]*\):\s*([\d.]+)\s*ms,\s*([\d.]+)\s*TFLOPS,\s*([\d.]+)\s*GB/s" -) -_RE_PYTORCH = re.compile( - r"PyTorch\s*\([^)]*\):\s*([\d.]+)\s*ms,\s*([\d.]+)\s*TFLOPS,\s*([\d.]+)\s*GB/s" -) +_RE_IRIS = re.compile(r"HBM-Buffer\s*\([^)]*\):\s*([\d.]+)\s*ms,\s*([\d.]+)\s*TFLOPS,\s*([\d.]+)\s*GB/s") +_RE_PYTORCH = re.compile(r"PyTorch\s*\([^)]*\):\s*([\d.]+)\s*ms,\s*([\d.]+)\s*TFLOPS,\s*([\d.]+)\s*GB/s") _RE_SPEEDUP = re.compile(r"Speedup.*?:\s*([\d.]+)x") _RE_VALID_FAIL = re.compile(r"Validation FAILED.*?max diff:\s*([\d.eE+-]+)") @@ -204,6 +212,7 @@ def parse_output(output): # ── Sweep generation ────────────────────────────────────────────────────────── + def generate_configs(baseline, sweep_ranges, mode="oneatatime", params=None): """ Generate the list of configs to evaluate. @@ -258,6 +267,7 @@ def _add(cfg): # ── Main ────────────────────────────────────────────────────────────────────── + def main(): parser = argparse.ArgumentParser( description="Parameter tuning for HBM-buffered all_gather_matmul.", @@ -270,43 +280,57 @@ def main(): parser.add_argument("--nproc", type=int, default=8, help="Number of GPUs") # ── Baseline overrides (non-swept params use these values) ──────── - parser.add_argument("--block_size_m", type=int, default=None, - help=f"Baseline block_size_m (default: {BASELINE['block_size_m']})") - parser.add_argument("--block_size_n", type=int, default=None, - help=f"Baseline block_size_n (default: {BASELINE['block_size_n']})") - parser.add_argument("--block_size_k", type=int, default=None, - help=f"Baseline block_size_k (default: {BASELINE['block_size_k']})") - parser.add_argument("--group_size_m", type=int, default=None, - help=f"Baseline group_size_m (default: {BASELINE['group_size_m']})") - parser.add_argument("--num_fetch_sms", type=int, default=None, - help=f"Baseline num_fetch_sms (default: {BASELINE['num_fetch_sms']})") - parser.add_argument("--k_per_flag", type=int, default=None, - help=f"Baseline k_per_flag (default: {BASELINE['k_per_flag']})") - parser.add_argument("--num_warps", type=int, default=None, - help=f"Baseline num_warps (default: {BASELINE['num_warps']})") - parser.add_argument("--num_fetch_stages", type=int, default=None, - help=f"Baseline num_fetch_stages (default: {BASELINE['num_fetch_stages']})") - parser.add_argument("--first_stage_fetch_sms", type=int, default=None, - help=f"Baseline first_stage_fetch_sms (default: {BASELINE['first_stage_fetch_sms']})") + parser.add_argument( + "--block_size_m", type=int, default=None, help=f"Baseline block_size_m (default: {BASELINE['block_size_m']})" + ) + parser.add_argument( + "--block_size_n", type=int, default=None, help=f"Baseline block_size_n (default: {BASELINE['block_size_n']})" + ) + parser.add_argument( + "--block_size_k", type=int, default=None, help=f"Baseline block_size_k (default: {BASELINE['block_size_k']})" + ) + parser.add_argument( + "--group_size_m", type=int, default=None, help=f"Baseline group_size_m (default: {BASELINE['group_size_m']})" + ) + parser.add_argument( + "--num_fetch_sms", type=int, default=None, help=f"Baseline num_fetch_sms (default: {BASELINE['num_fetch_sms']})" + ) + parser.add_argument( + "--k_per_flag", type=int, default=None, help=f"Baseline k_per_flag (default: {BASELINE['k_per_flag']})" + ) + parser.add_argument( + "--num_warps", type=int, default=None, help=f"Baseline num_warps (default: {BASELINE['num_warps']})" + ) + parser.add_argument( + "--num_fetch_stages", + type=int, + default=None, + help=f"Baseline num_fetch_stages (default: {BASELINE['num_fetch_stages']})", + ) + parser.add_argument( + "--first_stage_fetch_sms", + type=int, + default=None, + help=f"Baseline first_stage_fetch_sms (default: {BASELINE['first_stage_fetch_sms']})", + ) # ── Sweep control ───────────────────────────────────────────────── parser.add_argument( - "--mode", choices=["oneatatime", "full"], default="oneatatime", + "--mode", + choices=["oneatatime", "full"], + default="oneatatime", help="'oneatatime' varies one param at a time; 'full' = cartesian product", ) parser.add_argument( - "--params", nargs="+", default=None, - help="Only sweep these parameters (default: all). " - "Choices: " + ", ".join(SWEEP_RANGES.keys()), + "--params", + nargs="+", + default=None, + help="Only sweep these parameters (default: all). Choices: " + ", ".join(SWEEP_RANGES.keys()), ) - parser.add_argument("--output_dir", type=str, default=None, - help="Output directory (auto-generated if unset)") - parser.add_argument("--dry_run", action="store_true", - help="Print configs and exit without running") - parser.add_argument("--skip_validation", action="store_true", - help="Skip validation (faster, no correctness check)") - parser.add_argument("--timeout", type=int, default=600, - help="Per-config timeout in seconds (default: 600)") + parser.add_argument("--output_dir", type=str, default=None, help="Output directory (auto-generated if unset)") + parser.add_argument("--dry_run", action="store_true", help="Print configs and exit without running") + parser.add_argument("--skip_validation", action="store_true", help="Skip validation (faster, no correctness check)") + parser.add_argument("--timeout", type=int, default=600, help="Per-config timeout in seconds (default: 600)") args = parser.parse_args() M, N, K = args.m, args.n, args.k @@ -329,8 +353,7 @@ def main(): trace_dir.mkdir(exist_ok=True) # Generate configs - configs = generate_configs(baseline, SWEEP_RANGES, - mode=args.mode, params=args.params) + configs = generate_configs(baseline, SWEEP_RANGES, mode=args.mode, params=args.params) # Pre-validate all configs valid_configs = [] @@ -343,14 +366,14 @@ def main(): valid_configs.append(cfg) # Banner - print(f"\n{'='*100}") - print(f" HBM-Buffer All-Gather MatMul — Parameter Tuning") + print(f"\n{'=' * 100}") + print(" HBM-Buffer All-Gather MatMul — Parameter Tuning") print(f" M={M} N={N} K={K} nproc={args.nproc} mode={args.mode}") print(f" Baseline: {make_label(baseline)}") print(f" Configs to run: {len(valid_configs)} (skipped: {len(skipped)})") print(f" Output dir: {output_dir}") print(f" Validation: {'OFF' if args.skip_validation else 'ON'}") - print(f"{'='*100}") + print(f"{'=' * 100}") if skipped: print(f"\n Skipped (invalid for M={M}, N={N}, K={K}):") @@ -358,12 +381,12 @@ def main(): print(f" {make_label(cfg)}: {'; '.join(errs)}") if args.dry_run: - print(f"\n Configs that would be run:") + print("\n Configs that would be run:") for i, cfg in enumerate(valid_configs): label = make_label(cfg) - is_baseline = (cfg == baseline) + is_baseline = cfg == baseline tag = " [BASELINE]" if is_baseline else "" - print(f" [{i+1:>3}] {label}{tag}") + print(f" [{i + 1:>3}] {label}{tag}") print(f"\n Total: {len(valid_configs)} configs") return @@ -378,17 +401,21 @@ def main(): for i, cfg in enumerate(valid_configs): label = make_label(cfg) trace_path = str(trace_dir / f"trace_{label}.png") - is_first = (i == 0) + is_first = i == 0 sep = "-" * 80 print(f"\n{sep}") - print(f"[{i+1}/{len(valid_configs)}] {label}") + print(f"[{i + 1}/{len(valid_configs)}] {label}") if is_first: - print(f" (includes PyTorch baseline benchmark)") + print(" (includes PyTorch baseline benchmark)") print(sep) cmd = build_command( - cfg, M, N, K, trace_path, + cfg, + M, + N, + K, + trace_path, nproc=args.nproc, validate=not args.skip_validation, benchmark=True, @@ -400,8 +427,10 @@ def main(): t0 = time.time() try: proc = subprocess.run( - cmd, env=env, - capture_output=True, text=True, + cmd, + env=env, + capture_output=True, + text=True, timeout=args.timeout, ) elapsed = time.time() - t0 @@ -418,17 +447,19 @@ def main(): } trace_exists = os.path.exists(trace_path) - results.append({ - "label": label, - "config": cfg, - "iris_ms": parsed["iris_ms"], - "iris_tflops": parsed["iris_tflops"], - "iris_bw_gbps": parsed["iris_bw_gbps"], - "validation": parsed["validation"], - "trace_path": trace_path if trace_exists else None, - "elapsed_s": round(elapsed, 1), - "returncode": proc.returncode, - }) + results.append( + { + "label": label, + "config": cfg, + "iris_ms": parsed["iris_ms"], + "iris_tflops": parsed["iris_tflops"], + "iris_bw_gbps": parsed["iris_bw_gbps"], + "validation": parsed["validation"], + "trace_path": trace_path if trace_exists else None, + "elapsed_s": round(elapsed, 1), + "returncode": proc.returncode, + } + ) # Print summary line parts = [] @@ -440,16 +471,17 @@ def main(): if parsed["validation"]: parts.append(f"valid={parsed['validation']}") if trace_exists: - parts.append(f"trace=OK") + parts.append("trace=OK") else: - parts.append(f"trace=MISSING") + parts.append("trace=MISSING") if proc.returncode != 0: parts.append(f"EXIT={proc.returncode}") print(f" => {' | '.join(parts)} ({elapsed:.0f}s)") if is_first and pytorch_baseline: - print(f" => PyTorch baseline: {pytorch_baseline['tflops']:.2f} TFLOPS" - f" {pytorch_baseline['ms']:.3f} ms") + print( + f" => PyTorch baseline: {pytorch_baseline['tflops']:.2f} TFLOPS {pytorch_baseline['ms']:.3f} ms" + ) # Save full log for debugging log_path = output_dir / f"log_{label}.txt" @@ -464,51 +496,61 @@ def main(): except subprocess.TimeoutExpired: elapsed = time.time() - t0 - results.append({ - "label": label, - "config": cfg, - "iris_ms": None, - "iris_tflops": None, - "iris_bw_gbps": None, - "validation": "TIMEOUT", - "trace_path": None, - "elapsed_s": round(elapsed, 1), - "returncode": -1, - }) + results.append( + { + "label": label, + "config": cfg, + "iris_ms": None, + "iris_tflops": None, + "iris_bw_gbps": None, + "validation": "TIMEOUT", + "trace_path": None, + "elapsed_s": round(elapsed, 1), + "returncode": -1, + } + ) print(f" => TIMEOUT after {args.timeout}s") except Exception as e: elapsed = time.time() - t0 - results.append({ - "label": label, - "config": cfg, - "iris_ms": None, - "iris_tflops": None, - "iris_bw_gbps": None, - "validation": f"ERROR: {e}", - "trace_path": None, - "elapsed_s": round(elapsed, 1), - "returncode": -1, - }) + results.append( + { + "label": label, + "config": cfg, + "iris_ms": None, + "iris_tflops": None, + "iris_bw_gbps": None, + "validation": f"ERROR: {e}", + "trace_path": None, + "elapsed_s": round(elapsed, 1), + "returncode": -1, + } + ) print(f" => ERROR: {e}") total_elapsed = time.time() - total_start # ── Summary table ───────────────────────────────────────────────────── W = 130 - print(f"\n\n{'='*W}") - print(f" TUNING RESULTS | M={M} N={N} K={K} | nproc={args.nproc} | " - f"{len(valid_configs)} configs in {total_elapsed:.0f}s") + print(f"\n\n{'=' * W}") + print( + f" TUNING RESULTS | M={M} N={N} K={K} | nproc={args.nproc} | " + f"{len(valid_configs)} configs in {total_elapsed:.0f}s" + ) if pytorch_baseline: - print(f" PyTorch baseline: {pytorch_baseline['ms']:.3f} ms | " - f"{pytorch_baseline['tflops']:.2f} TFLOPS | " - f"{pytorch_baseline['bw_gbps']:.1f} GB/s") - print(f"{'='*W}") + print( + f" PyTorch baseline: {pytorch_baseline['ms']:.3f} ms | " + f"{pytorch_baseline['tflops']:.2f} TFLOPS | " + f"{pytorch_baseline['bw_gbps']:.1f} GB/s" + ) + print(f"{'=' * W}") col_label_w = 65 - print(f" {'#':>3} {'Configuration':<{col_label_w}} {'ms':>8} {'TFLOPS':>8} " - f"{'vs PT':>7} {'Valid':>8} {'Trace':>5}") - print(f" {'-'*(W-4)}") + print( + f" {'#':>3} {'Configuration':<{col_label_w}} {'ms':>8} {'TFLOPS':>8} " + f"{'vs PT':>7} {'Valid':>8} {'Trace':>5}" + ) + print(f" {'-' * (W - 4)}") for i, r in enumerate(results): ms_s = f"{r['iris_ms']:.3f}" if r["iris_ms"] is not None else "--" @@ -522,12 +564,20 @@ def main(): valid_s = (r["validation"] or "--")[:8] trace_s = "Y" if r.get("trace_path") else "N" - tag = " *" if (r["iris_tflops"] is not None and - r["iris_tflops"] == max((x["iris_tflops"] for x in results - if x["iris_tflops"] is not None), default=0)) else "" + tag = ( + " *" + if ( + r["iris_tflops"] is not None + and r["iris_tflops"] + == max((x["iris_tflops"] for x in results if x["iris_tflops"] is not None), default=0) + ) + else "" + ) - print(f" {i+1:>3} {r['label']:<{col_label_w}} {ms_s:>8} {tf_s:>8} " - f"{vs_pt:>7} {valid_s:>8} {trace_s:>5}{tag}") + print( + f" {i + 1:>3} {r['label']:<{col_label_w}} {ms_s:>8} {tf_s:>8} " + f"{vs_pt:>7} {valid_s:>8} {trace_s:>5}{tag}" + ) # Best config valid_results = [r for r in results if r["iris_tflops"] is not None] @@ -535,8 +585,7 @@ def main(): best = max(valid_results, key=lambda r: r["iris_tflops"]) worst = min(valid_results, key=lambda r: r["iris_tflops"]) print(f"\n {'BEST':>6}: {best['label']}") - print(f" {best['iris_ms']:.3f} ms | {best['iris_tflops']:.2f} TFLOPS | " - f"valid={best['validation']}") + print(f" {best['iris_ms']:.3f} ms | {best['iris_tflops']:.2f} TFLOPS | valid={best['validation']}") if pytorch_baseline and pytorch_baseline["tflops"] > 0: print(f" {best['iris_tflops'] / pytorch_baseline['tflops']:.2f}x vs PyTorch") if best.get("trace_path"): @@ -544,27 +593,36 @@ def main(): print(f" {'WORST':>6}: {worst['label']}") print(f" {worst['iris_ms']:.3f} ms | {worst['iris_tflops']:.2f} TFLOPS") if best["iris_tflops"] > 0 and worst["iris_tflops"] > 0: - print(f" SPREAD: {best['iris_tflops'] / worst['iris_tflops']:.2f}x " - f"({worst['iris_tflops']:.2f} → {best['iris_tflops']:.2f} TFLOPS)") + print( + f" SPREAD: {best['iris_tflops'] / worst['iris_tflops']:.2f}x " + f"({worst['iris_tflops']:.2f} → {best['iris_tflops']:.2f} TFLOPS)" + ) - print(f"{'='*W}") + print(f"{'=' * W}") # ── Save results JSON ───────────────────────────────────────────────── results_path = output_dir / "results.json" with open(results_path, "w") as f: - json.dump({ - "meta": { - "M": M, "N": N, "K": K, - "nproc": args.nproc, - "mode": args.mode, - "baseline": baseline, - "sweep_ranges": SWEEP_RANGES, - "timestamp": datetime.now().isoformat(), - "total_elapsed_s": round(total_elapsed, 1), - "pytorch_baseline": pytorch_baseline, + json.dump( + { + "meta": { + "M": M, + "N": N, + "K": K, + "nproc": args.nproc, + "mode": args.mode, + "baseline": baseline, + "sweep_ranges": SWEEP_RANGES, + "timestamp": datetime.now().isoformat(), + "total_elapsed_s": round(total_elapsed, 1), + "pytorch_baseline": pytorch_baseline, + }, + "results": results, }, - "results": results, - }, f, indent=2, default=str) + f, + indent=2, + default=str, + ) print(f"\n Results JSON : {results_path}") print(f" Trace PNGs : {trace_dir}/") diff --git a/iris/iris.py b/iris/iris.py index c283abf24..d061f09ea 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1807,7 +1807,6 @@ def __translate(ptr, from_rank, to_rank, heap_bases): return translated_ptr - @triton.jit def __translate_block_2d(ptr, from_rank, to_rank, heap_bases): """ diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index e9a5b6b0b..8c2d94159 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -72,7 +72,6 @@ def _hbm_buffer_all_gather_matmul_kernel( acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 zero = tl.program_id(0) * 0 - if TRACE: tl.store(trace_start_ptr + pid, read_realtime()) tl.store(trace_xcd_ptr + pid, get_xcc_id()) @@ -144,7 +143,7 @@ def _hbm_buffer_all_gather_matmul_kernel( tl.store(staged_ptrs, a_tile, cache_modifier=".cg") flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group - #tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") + # tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") tl.store(flags_ptr + flag_idx, 1, cache_modifier=".wt") if TRACE: From 11d90019dc440b04ae78b7dcbe3623c40f9896c7 Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Tue, 3 Mar 2026 11:59:24 -0500 Subject: [PATCH 20/31] Add predictive params, fix pointer overflows, fix race conditions --- .../all_gather_matmul/benchmark_hbm_buffer.py | 90 ++- .../ops/all_gather_matmul/derive_params.py | 683 ++++++++++++++++++ iris/ops/all_gather_matmul_hbm_buffer.py | 8 +- 3 files changed, 766 insertions(+), 15 deletions(-) create mode 100644 benchmark/ops/all_gather_matmul/derive_params.py diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py index 666a37a79..7978c0682 100644 --- a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -33,6 +33,35 @@ ) from iris.ops import FusedConfig +_DERIVE_AVAILABLE = False +try: + import sys as _sys + _script_dir = os.path.dirname(os.path.abspath(__file__)) + if _script_dir not in _sys.path: + _sys.path.insert(0, _script_dir) + from derive_params import ( + derive as _derive_params, + DEFAULT_NUM_CUS, + DEFAULT_PEAK_TFLOPS_FP16, + DEFAULT_HBM_BW_GBPS, + DEFAULT_L2_SIZE_BYTES, + DEFAULT_SCHEDULING_FACTOR, + ) + _DERIVE_AVAILABLE = True +except Exception: + pass + +_MODEL_PARAMS = ( + "block_size_m", "block_size_n", "block_size_k", "group_size_m", + "num_fetch_sms", "k_per_flag", "num_warps", + "num_fetch_stages", "first_stage_fetch_sms", +) + +_FALLBACK_DEFAULTS = { + "block_size_m": 256, "block_size_n": 64, "block_size_k": 64, + "group_size_m": 1, "k_per_flag": 1, "num_fetch_stages": 1, +} + torch.manual_seed(123) random.seed(123) @@ -230,23 +259,23 @@ def parse_args(): action="store_true", help="Also benchmark PyTorch (all_gather_into_tensor + matmul)", ) - parser.add_argument("--block_size_m", type=int, default=256, help="Block size M") - parser.add_argument("--block_size_n", type=int, default=64, help="Block size N") - parser.add_argument("--block_size_k", type=int, default=64, help="Block size K") - parser.add_argument("--group_size_m", type=int, default=1, help="Group size M") + parser.add_argument("--block_size_m", type=int, default=None, help="Block size M (model-derived if omitted)") + parser.add_argument("--block_size_n", type=int, default=None, help="Block size N (model-derived if omitted)") + parser.add_argument("--block_size_k", type=int, default=None, help="Block size K (model-derived if omitted)") + parser.add_argument("--group_size_m", type=int, default=None, help="Group size M (model-derived if omitted)") parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto if None)") parser.add_argument("--b_col_major", action="store_true", help="B col-major (K-contiguous)") parser.add_argument("--a_col_major", action="store_true", help="A col-major (M-contiguous)") parser.add_argument("--single-run", action="store_true", help="1 iteration (for profiling)") parser.add_argument("--num_fetch_sms", type=int, default=None, help="Fetcher SMs (auto if None)") - parser.add_argument("--k_per_flag", type=int, default=1, help="K-blocks per ready flag") + parser.add_argument("--k_per_flag", type=int, default=None, help="K-blocks per ready flag (model-derived if omitted)") parser.add_argument("--num_warps", type=int, default=None, help="Triton num_warps (auto if None)") parser.add_argument("--num_stages", type=int, default=None, help="Triton num_stages (auto if None)") parser.add_argument( "--num_fetch_stages", type=int, - default=1, - help="Number of fetch stages (1=all at once, 2=top/bottom half, etc.)", + default=None, + help="Number of fetch stages (model-derived if omitted)", ) parser.add_argument( "--first_stage_fetch_sms", @@ -254,11 +283,43 @@ def parse_args(): default=None, help="Fetcher WGs for stage 0 (fills first GPU wave; defaults to num_fetch_sms)", ) - parser.add_argument("--trace", action="store_true", help="Collect per-workgroup trace and save Gantt chart PNG") - parser.add_argument("--trace_output", type=str, default="trace_gantt.png", help="Output path for trace plot") + parser.add_argument("--trace", action=argparse.BooleanOptionalAction, default=True, help="Collect per-workgroup trace and save Gantt chart PNG") + parser.add_argument("--trace_output", type=str, default="trace.png", help="Output path for trace plot") return vars(parser.parse_args()) +def _apply_model_defaults(args, world_size, dtype_bytes=2): + """Fill None-valued kernel parameters with model-derived predictions. + + Returns a list of parameter names that were set by the model. + """ + applied = [] + if _DERIVE_AVAILABLE: + try: + p = _derive_params( + args["m"], args["n"], args["k"], world_size, + link_bw=50.0, + num_cus=DEFAULT_NUM_CUS, + peak_tflops=DEFAULT_PEAK_TFLOPS_FP16, + hbm_bw_gbps=DEFAULT_HBM_BW_GBPS, + l2_size=DEFAULT_L2_SIZE_BYTES, + scheduling_factor=DEFAULT_SCHEDULING_FACTOR, + dtype_bytes=dtype_bytes, + ) + for name in _MODEL_PARAMS: + if args.get(name) is None and name in p: + args[name] = p[name] + applied.append(name) + except Exception: + pass + + for name, fallback in _FALLBACK_DEFAULTS.items(): + if args.get(name) is None: + args[name] = fallback + + return applied + + def _worker(args): """Worker function for torchrun.""" local_rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0))) @@ -294,6 +355,14 @@ def _worker(args): datatype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} datatype = datatype_map.get(args["datatype"], torch.float16) + dtype_bytes = torch.tensor([], dtype=datatype).element_size() + + model_applied = _apply_model_defaults(args, world_size, dtype_bytes) + if rank == 0 and model_applied: + shmem.info(f"Model-derived defaults: {', '.join(model_applied)}") + if rank == 0: + param_summary = " ".join(f"{k}={args[k]}" for k in _MODEL_PARAMS) + shmem.info(f"Kernel params: {param_summary}") M = args["m"] N = args["n"] @@ -410,7 +479,8 @@ def run_experiment(): shmem.barrier() atol = 1e-1 if datatype == torch.float16 else 1e-3 - success = torch.allclose(C, expected_tensor, atol=atol) + rtol = 1e-2 if datatype == torch.float16 else 1e-5 + success = torch.allclose(C, expected_tensor, atol=atol, rtol=rtol) if not success: max_diff = torch.abs(C - expected_tensor).max().item() shmem.error(f"Rank {rank}: Validation FAILED, max diff: {max_diff}") diff --git a/benchmark/ops/all_gather_matmul/derive_params.py b/benchmark/ops/all_gather_matmul/derive_params.py new file mode 100644 index 000000000..539b298a5 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/derive_params.py @@ -0,0 +1,683 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Parameter derivation for the HBM-buffered all_gather_matmul kernel. + +Given a problem size (M, N, K), world size, and per-link XGMI bandwidth, +derives kernel parameters that balance communication and computation in +the device-level pipeline. + +The kernel fuses all-gather with GEMM using two workgroup roles: + - Fetcher WGs: gather remote A tiles into an HBM staging buffer, + setting per-tile ready flags as data arrives. + - GEMM WGs: poll flags, then compute C += A_staged @ B tile-by-tile. + +The M dimension is split into `num_fetch_stages` pipeline stages. Each +stage's fetchers and GEMM WGs are interleaved in the launch grid so that +stage N+1's fetch overlaps with stage N's compute. + +Pipeline timeline (S stages): + |-- fetch stage 0 --|-- max(fetch, compute) * (S-1) --|-- compute last --| + +Usage: + python derive_params.py -m 131072 -n 2048 -k 16384 + python derive_params.py -m 196608 -n 2304 -k 16384 --link_bw 50 + python derive_params.py -m 196608 -n 2304 -k 16384 -v -b --trace + +When --link_bw is omitted the script automatically profiles the XGMI +link bandwidth by timing GPU-to-GPU copies across all peer pairs visible +from GPU 0. +""" + +import argparse +import math +import time + +# ── MI300X hardware defaults ────────────────────────────────────────────── +DEFAULT_NUM_CUS = 304 +DEFAULT_PEAK_TFLOPS_FP16 = 1300.0 +DEFAULT_HBM_BW_GBPS = 5300.0 +DEFAULT_L2_SIZE_BYTES = 256 * 1024 * 1024 +DEFAULT_NUM_XCDS = 8 +DEFAULT_WORLD_SIZE = 8 + +# Calibrated from MI300X trace data: the ratio of measured wall time to +# the CU-work-queue lower bound. Captures WG dispatch overhead, +# cross-XCD coherence latency, and pipeline bubble effects. +DEFAULT_SCHEDULING_FACTOR = 4.5 + + +def profile_link_bandwidth(world_size=DEFAULT_WORLD_SIZE): + """Measure per-link unidirectional XGMI bandwidth. + + Copies a 256 MB fp16 tensor from GPU 0 to every other visible GPU, + times the transfers with host-side timing after explicit device syncs, + and returns the conservative (min) per-link bandwidth. + """ + import torch + + n_gpus = torch.cuda.device_count() + if n_gpus < 2: + raise RuntimeError( + f"Need >= 2 visible GPUs for bandwidth profiling, found {n_gpus}. " + f"Pass --link_bw explicitly instead." + ) + + n_peers = min(world_size, n_gpus) - 1 + size_bytes = 256 * 1024 * 1024 + numel = size_bytes // 2 + warmup_iters = 10 + timed_iters = 40 + + print(f"\n── Link Bandwidth Profiling {'─' * 43}") + print(f" GPUs visible: {n_gpus}") + print(f" Testing: GPU 0 → GPUs 1..{n_peers}") + print(f" Transfer size: {size_bytes // (1024**2)} MB × {timed_iters} iterations\n") + + src = torch.empty(numel, dtype=torch.float16, device="cuda:0").normal_() + bandwidths = [] + + for peer in range(1, n_peers + 1): + dst = torch.empty(numel, dtype=torch.float16, device=f"cuda:{peer}") + + for _ in range(warmup_iters): + dst.copy_(src) + torch.cuda.synchronize(0) + torch.cuda.synchronize(peer) + + t_start = time.perf_counter() + for _ in range(timed_iters): + dst.copy_(src) + torch.cuda.synchronize(peer) + elapsed_s = time.perf_counter() - t_start + + bw = size_bytes * timed_iters / elapsed_s / 1e9 + bandwidths.append(bw) + print(f" GPU 0 → GPU {peer}: {bw:6.1f} GB/s") + + del dst + + del src + torch.cuda.empty_cache() + + bw_min = min(bandwidths) + bw_max = max(bandwidths) + bw_avg = sum(bandwidths) / len(bandwidths) + print(f"\n min = {bw_min:.1f} avg = {bw_avg:.1f} max = {bw_max:.1f} GB/s") + print(f" Using conservative (min): {bw_min:.1f} GB/s per link") + + return bw_min + + +# ── Tile / block size heuristics ────────────────────────────────────────── + +def _choose_block_sizes(M, N, K, K_local): + """Heuristic tile-size selection for MI300X MFMA.""" + bk = 64 + + bm = 256 if M >= 8192 else 128 + while M % bm != 0 and bm > 64: + bm //= 2 + + if N >= 512: + bn = 256 + elif N >= 256: + bn = 256 if N % 256 == 0 else 128 + else: + bn = 128 + while N % bn != 0 and bn > 32: + bn //= 2 + + while K % bk != 0 and bk > 16: + bk //= 2 + while K_local % bk != 0 and bk > 16: + bk //= 2 + + nw = 8 if bm * bn >= 256 * 256 else 4 + return bm, bn, bk, nw + + +def _choose_k_per_flag(num_k_blocks, num_k_blocks_local, target_groups=8): + """Pick k_per_flag so that flag groups align to rank boundaries when + possible, falling back to the largest divisor near the target.""" + if num_k_blocks % num_k_blocks_local == 0: + candidate = num_k_blocks_local + groups = num_k_blocks // candidate + if groups >= 4: + return candidate + + kpf = max(1, num_k_blocks // target_groups) + while num_k_blocks % kpf != 0 and kpf > 1: + kpf -= 1 + return kpf + + +# ── Per-tile roofline model ────────────────────────────────────────────── + +def _tile_roofline(bm, bn, bk, M, K, N, dtype_bytes, + peak_tflops, hbm_bw_gbps, l2_size): + """Compute achievable per-CU TFLOPS from tile arithmetic intensity. + + staged_a is always >> L2, so A tiles come from HBM. B may fit in L2 + only when staged_a is small enough that reads don't evict B. + Returns (roofline_tflops, tile_intensity, ridge_point, b_in_l2). + """ + tile_flops = 2 * bm * bn * bk + a_bytes = bm * bk * dtype_bytes + b_bytes = bk * bn * dtype_bytes + + b_total = K * N * dtype_bytes + staged_a_total = M * K * dtype_bytes + # When staged_a exceeds L2, streaming GEMM reads evict B regardless + # of B's absolute size. + b_in_l2 = (staged_a_total <= l2_size) and (b_total <= l2_size) + + hbm_bytes = a_bytes + (0 if b_in_l2 else b_bytes) + intensity = tile_flops / max(hbm_bytes, 1) + + ridge = peak_tflops * 1e3 / hbm_bw_gbps + if intensity >= ridge: + roofline = peak_tflops + else: + roofline = hbm_bw_gbps * intensity / 1e3 + + return roofline, intensity, ridge, b_in_l2 + + +# ── Per-WG execution time models ──────────────────────────────────────── + +def _gemm_wg_time_us(bm, bn, bk, K, num_flag_groups, + roofline_tflops, num_cus): + """Estimate per-WG GEMM execution time in microseconds. + + Uses the per-tile roofline to get the per-CU throughput, then applies + a calibrated overhead for memory-latency hiding and instruction + scheduling at single-WG occupancy (large tiles). + """ + total_flops = 2 * bm * bn * K + per_cu_tflops = roofline_tflops / num_cus + + # Roofline-ideal per-WG time + ideal_us = total_flops / (per_cu_tflops * 1e6) + + # Single-occupancy overhead: imperfect latency hiding, instruction + # scheduling gaps, cross-XCD coherence on staged_a reads. + # Calibrated from MI300X traces: actual/ideal ≈ 1.2-1.3. + occupancy_factor = 1.25 if bm * bn >= 256 * 256 else 1.10 + + # Flag polling: acquire-semantics atomic per flag group + flag_us = num_flag_groups * 2.5 + + return ideal_us * occupancy_factor + flag_us + + +def _fetch_wg_time_us(bm, bk, kpf, world_size, link_bw, + dtype_bytes, num_fgs_per_wg): + """Estimate per-fetcher-WG execution time in microseconds. + + Each flag group fetches kpf K-blocks (each BM × BK) from one rank. + Remote data traverses XGMI; local data uses HBM. + """ + bytes_per_fg = bm * kpf * bk * dtype_bytes + remote_frac = (world_size - 1) / world_size + + # XGMI gather: raw transfer + iris.x.gather software overhead + remote_bytes = bytes_per_fg * remote_frac + gather_overhead = 1.5 + xgmi_us = remote_bytes / (link_bw * 1e3) * gather_overhead + + # HBM write to staged_a (.cg → L2/HBM, per-WG share of bandwidth) + write_bw = 15.0 # GB/s effective per fetcher WG (calibrated from traces) + write_us = bytes_per_fg / (write_bw * 1e3) + + # Read and write overlap within each tile; dominant cost + flag-store + per_fg_us = max(xgmi_us, write_us) + 5.0 + + return num_fgs_per_wg * per_fg_us + + +# ── Kernel time estimation ─────────────────────────────────────────────── + +def _estimate_kernel_time(total_gemm_wgs, gemm_wg_us, + total_fetch_wgs, fetch_wg_us, + num_cus, scheduling_factor): + """Estimate kernel wall-clock time from the CU work queue model. + + total_CU_work / num_CUs gives the ideal (work-conserving) lower + bound. The scheduling_factor captures GPU dispatch overhead, + cross-XCD coherence, and pipeline bubble effects measured on MI300X. + """ + total_cu_work_us = (total_gemm_wgs * gemm_wg_us + + total_fetch_wgs * fetch_wg_us) + + ideal_ms = total_cu_work_us / num_cus / 1e3 + estimated_ms = ideal_ms * scheduling_factor + return estimated_ms, ideal_ms + + +# ── Pipeline stage selection ───────────────────────────────────────────── + +def _choose_fetch_stages(num_m_tiles, num_tiles_n, group_size_m, + comm_time_ms, compute_time_ms, num_cus): + """Choose num_fetch_stages for good pipeline efficiency while keeping + m_per_stage divisible by group_size_m.""" + ratio = comm_time_ms / compute_time_ms if compute_time_ms > 0 else 999 + + if ratio > 1.5: + ideal_stages = 32 + elif ratio > 0.8: + ideal_stages = 16 + elif ratio > 0.3: + ideal_stages = 8 + else: + ideal_stages = 4 + + min_gemm_tiles = max(num_cus // 4, 32) + min_m_per_stage = max(group_size_m, + math.ceil(min_gemm_tiles / max(num_tiles_n, 1))) + max_stages = max(1, num_m_tiles // min_m_per_stage) + num_stages = min(ideal_stages, max_stages) + num_stages = max(1, num_stages) + + m_per_stage = math.ceil(num_m_tiles / num_stages) + if m_per_stage % group_size_m != 0: + m_per_stage = ((m_per_stage + group_size_m - 1) + // group_size_m) * group_size_m + num_stages = max(1, math.ceil(num_m_tiles / m_per_stage)) + + m_per_stage = math.ceil(num_m_tiles / num_stages) + return num_stages, m_per_stage + + +# ── num_fetch_sms optimisation ─────────────────────────────────────────── + +def _choose_num_fetch_sms(m_per_stage, group_size_m, num_flag_groups_k, + link_bw, world_size, num_cus, + bm, bk, kpf, dtype_bytes, + gemm_wg_us, gemm_tiles_per_stage): + """Choose num_fetch_sms for good pipeline overlap. + + Balances three constraints: + 1. Flag delivery parallelism: ≥ m_per_stage so every M-tile gets + a fetcher early (good for reducing GEMM flag-poll stalls). + 2. Link saturation: enough concurrent fetchers to use the XGMI + aggregate bandwidth. + 3. CU budget: leave enough CUs for GEMM in the first dispatch wave. + + Returns (num_fetch_sms, per-WG timing info dict). + """ + total_fg_per_stage = num_flag_groups_k * m_per_stage + + # Constraint 1: one fetcher per M-group for broad flag delivery + parallel_min = max(1, m_per_stage // group_size_m) + + # Constraint 2: enough fetchers to keep XGMI links busy + per_fg_bytes = bm * kpf * bk * dtype_bytes + per_fg_remote = per_fg_bytes * (world_size - 1) / world_size + per_fg_xgmi_us = per_fg_remote / (link_bw * 1e3) * 1.5 + per_fg_write_us = per_fg_bytes / (15.0 * 1e3) + per_fg_us = max(per_fg_xgmi_us, per_fg_write_us) + 5.0 + + # Total flag groups per stage should finish within the stage GEMM time + gemm_waves = math.ceil(gemm_tiles_per_stage / num_cus) + stage_gemm_us = gemm_waves * gemm_wg_us + if per_fg_us > 0: + balance_min = max(1, math.ceil( + total_fg_per_stage * per_fg_us / stage_gemm_us)) + else: + balance_min = 1 + + nf = max(parallel_min, balance_min, 64) + nf = min(nf, num_cus // 2) + nf = max(1, nf) + + return nf + + +# ── Main derivation ────────────────────────────────────────────────────── + +def derive(M, N, K, world_size, link_bw, num_cus, peak_tflops, + hbm_bw_gbps, l2_size, scheduling_factor, dtype_bytes): + K_local = K // world_size + + # 1. Tile sizes + bm, bn, bk, nw = _choose_block_sizes(M, N, K, K_local) + gm = 4 + num_m_tiles = M // bm + num_tiles_n = math.ceil(N / bn) + num_k_blocks = K // bk + num_k_blocks_local = K_local // bk + + # 2. Per-tile roofline + roofline_tflops, intensity, ridge, b_in_l2 = _tile_roofline( + bm, bn, bk, M, K, N, dtype_bytes, peak_tflops, hbm_bw_gbps, l2_size) + + # 3. Communication model (link-limited) + total_remote_bytes = M * K_local * (world_size - 1) * dtype_bytes + total_link_bw = link_bw * (world_size - 1) + comm_time_ms = total_remote_bytes / (total_link_bw * 1e9) * 1e3 + + # 4. Compute model (roofline-limited) + total_flops = 2 * M * N * K + compute_time_ms = total_flops / (roofline_tflops * 1e12) * 1e3 + + ratio = comm_time_ms / compute_time_ms if compute_time_ms > 0 else 999 + + # 5. k_per_flag + kpf = _choose_k_per_flag(num_k_blocks, num_k_blocks_local) + num_flag_groups_k = num_k_blocks // kpf + + # 6. Pipeline stages + num_stages, m_per_stage = _choose_fetch_stages( + num_m_tiles, num_tiles_n, gm, comm_time_ms, compute_time_ms, num_cus) + gemm_tiles_per_stage = m_per_stage * num_tiles_n + + # 7. first_stage_fetch_sms: use all CUs to fill the pipeline ASAP + fsf = num_cus + + # 8. Per-WG timing + gemm_wg_us_val = _gemm_wg_time_us(bm, bn, bk, K, num_flag_groups_k, + roofline_tflops, num_cus) + + # 9. Choose num_fetch_sms + nf = _choose_num_fetch_sms( + m_per_stage, gm, num_flag_groups_k, + link_bw, world_size, num_cus, + bm, bk, kpf, dtype_bytes, + gemm_wg_us_val, gemm_tiles_per_stage) + + # 10. Compute per-WG fetch times + total_fg_per_stage = num_flag_groups_k * m_per_stage + fgs_per_wg_stg0 = max(1, math.ceil(total_fg_per_stage / fsf)) + fgs_per_wg_rest = max(1, math.ceil(total_fg_per_stage / nf)) + fetch_us_stg0 = _fetch_wg_time_us(bm, bk, kpf, world_size, + link_bw, dtype_bytes, fgs_per_wg_stg0) + fetch_us_rest = _fetch_wg_time_us(bm, bk, kpf, world_size, + link_bw, dtype_bytes, fgs_per_wg_rest) + + # 11. Grid geometry + first_stage_size = fsf + gemm_tiles_per_stage + rest_stage_size = nf + gemm_tiles_per_stage + grid_size = first_stage_size + rest_stage_size * max(0, num_stages - 1) + total_fetch_wgs = fsf + nf * max(0, num_stages - 1) + total_gemm_wgs = gemm_tiles_per_stage * num_stages + + # 12. Kernel time estimate (CU-work model) + avg_fetch_us = (fsf * fetch_us_stg0 + nf * max(0, num_stages - 1) * fetch_us_rest) + avg_fetch_us /= max(total_fetch_wgs, 1) + est_kernel_ms, est_ideal_ms = _estimate_kernel_time( + total_gemm_wgs, gemm_wg_us_val, + total_fetch_wgs, avg_fetch_us, + num_cus, scheduling_factor) + + # 13. Link-limited pipeline estimate (simple model for comparison) + stage_m = m_per_stage * bm + stage_comm_ms = (stage_m * K_local * (world_size - 1) * dtype_bytes + / (total_link_bw * 1e9) * 1e3) + stage_compute_ms = (2 * stage_m * N * K + / (roofline_tflops * 1e12) * 1e3) + startup_ms = stage_comm_ms + steady_ms = max(stage_comm_ms, stage_compute_ms) * max(0, num_stages - 1) + drain_ms = stage_compute_ms + pipeline_ms = startup_ms + steady_ms + drain_ms + sequential_ms = comm_time_ms + compute_time_ms + + # 14. Standalone GEMM estimate (rocBLAS-class efficiency for comparison) + standalone_gemm_eff = 0.30 + standalone_tflops = roofline_tflops * standalone_gemm_eff + standalone_gemm_ms = total_flops / (standalone_tflops * 1e12) * 1e3 + pytorch_est_ms = comm_time_ms + standalone_gemm_ms + + staged_a_gb = M * K * dtype_bytes / (1024**3) + + return dict( + block_size_m=bm, block_size_n=bn, block_size_k=bk, + group_size_m=gm, num_warps=nw, + num_fetch_sms=nf, k_per_flag=kpf, + num_fetch_stages=num_stages, first_stage_fetch_sms=fsf, + # derived + K_local=K_local, num_m_tiles=num_m_tiles, num_tiles_n=num_tiles_n, + num_k_blocks=num_k_blocks, num_flag_groups_k=num_flag_groups_k, + m_per_stage=m_per_stage, gemm_tiles_per_stage=gemm_tiles_per_stage, + grid_size=grid_size, total_fetch_wgs=total_fetch_wgs, + total_gemm_wgs=total_gemm_wgs, + # roofline + roofline_tflops=roofline_tflops, tile_intensity=intensity, + ridge_point=ridge, b_in_l2=b_in_l2, + # per-WG timing + gemm_wg_us=gemm_wg_us_val, + fetch_wg_us_stg0=fetch_us_stg0, + fetch_wg_us_rest=fetch_us_rest, + # estimates + total_remote_bytes=total_remote_bytes, total_link_bw=total_link_bw, + comm_time_ms=comm_time_ms, total_flops=total_flops, + compute_time_ms=compute_time_ms, ratio=ratio, + stage_comm_ms=stage_comm_ms, stage_compute_ms=stage_compute_ms, + pipeline_ms=pipeline_ms, sequential_ms=sequential_ms, + est_kernel_ms=est_kernel_ms, + est_ideal_ms=est_ideal_ms, + standalone_gemm_ms=standalone_gemm_ms, + pytorch_est_ms=pytorch_est_ms, + staged_a_gb=staged_a_gb, + scheduling_factor=scheduling_factor, + ) + + +# ── Formatting helpers ─────────────────────────────────────────────────── + +def _fmt_bytes(n): + if n >= 1024**3: + return f"{n / 1024**3:.2f} GB" + if n >= 1024**2: + return f"{n / 1024**2:.1f} MB" + return f"{n / 1024:.1f} KB" + + +def _fmt_flops(n): + if n >= 1e15: + return f"{n / 1e15:.2f} PFLOPs" + return f"{n / 1e12:.2f} TFLOPs" + + +def _fmt_tflops(t): + return f"{t:.0f} TFLOPS" + + +# ── Analysis output ────────────────────────────────────────────────────── + +def print_analysis(M, N, K, world_size, link_bw, p, passthrough_args, + bw_profiled=False): + K_local = p["K_local"] + dtype_bytes = 2 + bound = "COMM-BOUND" if p["ratio"] > 1.0 else "COMPUTE-BOUND" + + print("=" * 72) + print(" All-Gather Matmul HBM-Buffer — Parameter Derivation") + print("=" * 72) + + # ── Problem ─────────────────────────────────────────────────────── + print(f"\n{'Problem':>14}: C({M}, {N}) = all_gather(A_shard({M}, {K_local})) @ B({K}, {N})") + print(f"{'World size':>14}: {world_size} GPUs") + print(f"{'Dtype':>14}: fp16 ({dtype_bytes}B)") + + # ── Data sizes ──────────────────────────────────────────────────── + a_shard = M * K_local * dtype_bytes + b_size = K * N * dtype_bytes + c_size = M * N * dtype_bytes + staged = M * K * dtype_bytes + print(f"\n{'A_shard':>14}: ({M}, {K_local}) {_fmt_bytes(a_shard)}") + print(f"{'B':>14}: ({K}, {N}) {_fmt_bytes(b_size)}") + print(f"{'C':>14}: ({M}, {N}) {_fmt_bytes(c_size)}") + print(f"{'staged_a':>14}: ({M}, {K}) {_fmt_bytes(staged)}") + if staged > 4 * 1024**3: + print(f"{'':>14} *** > 4 GB: requires int64 pointer arithmetic ***") + + # ── Per-tile roofline ───────────────────────────────────────────── + print(f"\n── Roofline {'─' * 59}") + print(f"{'Tile':>14}: ({p['block_size_m']}, {p['block_size_n']}, {p['block_size_k']})") + print(f"{'Intensity':>14}: {p['tile_intensity']:.0f} FLOPs/byte " + f"{'(B in L2)' if p['b_in_l2'] else '(B from HBM)'}") + print(f"{'Ridge point':>14}: {p['ridge_point']:.0f} FLOPs/byte") + region = "COMPUTE" if p["tile_intensity"] >= p["ridge_point"] else "MEMORY" + print(f"{'Roofline':>14}: {_fmt_tflops(p['roofline_tflops'])} ({region}-bound tiles)") + + # ── Communication ───────────────────────────────────────────────── + print(f"\n── Communication {'─' * 54}") + print(f"{'Remote bytes':>14}: {_fmt_bytes(p['total_remote_bytes'])} " + f"(from {world_size - 1} peers)") + bw_src = "profiled" if bw_profiled else "user" + print(f"{'Link BW':>14}: {link_bw:.1f} GB/s/link × {world_size - 1} links " + f"= {p['total_link_bw']:.0f} GB/s aggregate ({bw_src})") + print(f"{'Comm time':>14}: {p['comm_time_ms']:.3f} ms (link-limited)") + + # ── Compute ─────────────────────────────────────────────────────── + print(f"\n── Compute {'─' * 60}") + print(f"{'Total FLOPs':>14}: {_fmt_flops(p['total_flops'])}") + print(f"{'Roofline time':>14}: {p['compute_time_ms']:.3f} ms " + f"(at {_fmt_tflops(p['roofline_tflops'])})") + print(f"{'Comm/Compute':>14}: {p['ratio']:.2f}x → {bound}") + + # ── Per-WG timing ───────────────────────────────────────────────── + print(f"\n── Per-WG Model {'─' * 55}") + print(f"{'GEMM WG':>14}: {p['gemm_wg_us']:.0f} us " + f"({p['total_flops'] / p['total_gemm_wgs'] / 1e9:.2f} GFLOPs/WG)") + print(f"{'Fetch WG stg0':>14}: {p['fetch_wg_us_stg0']:.0f} us") + if p["num_fetch_stages"] > 1: + print(f"{'Fetch WG rest':>14}: {p['fetch_wg_us_rest']:.0f} us") + + # ── Pipeline ────────────────────────────────────────────────────── + S = p["num_fetch_stages"] + print(f"\n── Pipeline {'─' * 59}") + print(f"{'Stages (S)':>14}: {S}") + print(f"{'M tiles/stage':>14}: {p['m_per_stage']} ({p['m_per_stage'] * p['block_size_m']} rows)") + print(f"{'GEMM WGs/stg':>14}: {p['gemm_tiles_per_stage']} " + f"({p['m_per_stage']} m-tiles × {p['num_tiles_n']} n-tiles)") + print(f"{'K flag groups':>14}: {p['num_flag_groups_k']} " + f"(k_per_flag={p['k_per_flag']})") + print(f"{'Stage comm':>14}: {p['stage_comm_ms']:.3f} ms") + print(f"{'Stage compute':>14}: {p['stage_compute_ms']:.3f} ms") + + # ── Grid ────────────────────────────────────────────────────────── + print(f"\n── Grid Layout {'─' * 56}") + print(f"{'Stage 0':>14}: {p['first_stage_fetch_sms']} fetchers + " + f"{p['gemm_tiles_per_stage']} GEMM = " + f"{p['first_stage_fetch_sms'] + p['gemm_tiles_per_stage']} WGs") + if S > 1: + print(f"{'Stages 1..{}'.format(S - 1):>14}: {p['num_fetch_sms']} fetchers + " + f"{p['gemm_tiles_per_stage']} GEMM = " + f"{p['num_fetch_sms'] + p['gemm_tiles_per_stage']} WGs (×{S - 1})") + print(f"{'Total grid':>14}: {p['grid_size']} WGs " + f"({p['total_fetch_wgs']} fetch + {p['total_gemm_wgs']} GEMM)") + + # ── Time estimates ──────────────────────────────────────────────── + print(f"\n── Time Estimates {'─' * 53}") + print(f"{'CU-work lower':>14}: {p['est_ideal_ms']:.1f} ms " + f"(total WG time / {DEFAULT_NUM_CUS} CUs)") + print(f"{'Fused kernel':>14}: {p['est_kernel_ms']:.1f} ms " + f"(×{p['scheduling_factor']:.1f} scheduling overhead)") + est_tflops = p["total_flops"] / (p["est_kernel_ms"] * 1e-3) / 1e12 + print(f"{'Est. TFLOPS':>14}: {est_tflops:.0f} TFLOPS " + f"({est_tflops / p['roofline_tflops'] * 100:.0f}% of roofline)") + print(f"{'':>14}") + print(f"{'PyTorch est.':>14}: {p['pytorch_est_ms']:.1f} ms " + f"(all_gather {p['comm_time_ms']:.1f} + matmul {p['standalone_gemm_ms']:.1f})") + if p["est_kernel_ms"] < p["pytorch_est_ms"]: + speedup = p["pytorch_est_ms"] / p["est_kernel_ms"] + print(f"{'Fused speedup':>14}: {speedup:.2f}x over sequential PyTorch") + else: + slowdown = p["est_kernel_ms"] / p["pytorch_est_ms"] + print(f"{'Fused speedup':>14}: {1/slowdown:.2f}x (slower than sequential by {slowdown:.2f}x)") + + # ── Recommended parameters ──────────────────────────────────────── + print(f"\n── Recommended Kernel Parameters {'─' * 38}") + params = [ + ("block_size_m", p["block_size_m"]), + ("block_size_n", p["block_size_n"]), + ("block_size_k", p["block_size_k"]), + ("group_size_m", p["group_size_m"]), + ("num_fetch_sms", p["num_fetch_sms"]), + ("k_per_flag", p["k_per_flag"]), + ("num_warps", p["num_warps"]), + ("num_fetch_stages", p["num_fetch_stages"]), + ("first_stage_fetch_sms", p["first_stage_fetch_sms"]), + ] + for name, val in params: + print(f" --{name:30s} {val}") + + # ── Command line ────────────────────────────────────────────────── + extra = " ".join(passthrough_args) + if extra: + extra = " " + extra + cmd = ( + f"HSA_NO_SCRATCH_RECLAIM=1 torchrun --nproc_per_node {world_size} " + f"benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py " + f"-m {M} -n {N} -k {K} " + f"--block_size_m {p['block_size_m']} " + f"--block_size_n {p['block_size_n']} " + f"--block_size_k {p['block_size_k']} " + f"--group_size_m {p['group_size_m']} " + f"--num_fetch_sms {p['num_fetch_sms']} " + f"--k_per_flag {p['k_per_flag']} " + f"--num_warps {p['num_warps']} " + f"--num_fetch_stages {p['num_fetch_stages']} " + f"--first_stage_fetch_sms {p['first_stage_fetch_sms']}" + f"{extra}" + ) + print(f"\n── Command {'─' * 60}") + print(f" {cmd}") + print() + + +def main(): + parser = argparse.ArgumentParser( + description="Derive parameters for HBM-buffered all_gather_matmul.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument("-m", type=int, required=True, help="M dimension (rows of output)") + parser.add_argument("-n", type=int, required=True, help="N dimension (cols of output)") + parser.add_argument("-k", type=int, required=True, help="K dimension (total reduction dim)") + parser.add_argument("--world_size", type=int, default=DEFAULT_WORLD_SIZE, + help="Number of GPUs") + parser.add_argument("--link_bw", type=float, default=None, + help="Per-link XGMI bandwidth in GB/s (one direction). " + "Omit to auto-profile via GPU-to-GPU copies.") + parser.add_argument("--num_cus", type=int, default=DEFAULT_NUM_CUS, + help="Number of compute units") + parser.add_argument("--peak_tflops", type=float, default=DEFAULT_PEAK_TFLOPS_FP16, + help="Peak fp16 TFLOPS") + parser.add_argument("--hbm_bw", type=float, default=DEFAULT_HBM_BW_GBPS, + help="HBM bandwidth in GB/s") + parser.add_argument("--scheduling_factor", type=float, + default=DEFAULT_SCHEDULING_FACTOR, + help="CU scheduling overhead factor (calibrated from traces)") + + args, passthrough = parser.parse_known_args() + + if args.k % args.world_size != 0: + parser.error(f"K ({args.k}) must be divisible by world_size ({args.world_size})") + + link_bw = args.link_bw + bw_profiled = False + if link_bw is None: + try: + link_bw = profile_link_bandwidth(args.world_size) + bw_profiled = True + except Exception as e: + print(f"\n Auto-profiling failed: {e}") + print(" Falling back to --link_bw 50 (MI300X default)\n") + link_bw = 50.0 + + p = derive(args.m, args.n, args.k, args.world_size, link_bw, + args.num_cus, args.peak_tflops, args.hbm_bw, + DEFAULT_L2_SIZE_BYTES, args.scheduling_factor, + dtype_bytes=2) + + print_analysis(args.m, args.n, args.k, args.world_size, link_bw, + p, passthrough, bw_profiled=bw_profiled) + + +if __name__ == "__main__": + main() diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index 8c2d94159..36010e24f 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -135,7 +135,7 @@ def _hbm_buffer_all_gather_matmul_kernel( k_tile = iris.x.TileView(pid_m_t, tile_k_t, BLOCK_SIZE_M, BLOCK_SIZE_K) rk = k_block_global * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - staged_ptrs = staged_a + rm[:, None] * stride_sa_m + rk[None, :] * stride_sa_k + staged_ptrs = staged_a + rm.to(tl.int64)[:, None] * stride_sa_m + rk[None, :] * stride_sa_k for compile_rank in range(world_size): if src_rank_idx == compile_rank: @@ -143,8 +143,7 @@ def _hbm_buffer_all_gather_matmul_kernel( tl.store(staged_ptrs, a_tile, cache_modifier=".cg") flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group - # tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") - tl.store(flags_ptr + flag_idx, 1, cache_modifier=".wt") + tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") if TRACE: tl.store(trace_wait_ptr + pid, zero.to(tl.int64), cache_modifier=".wt") @@ -193,8 +192,7 @@ def _hbm_buffer_all_gather_matmul_kernel( rk = k_block * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) - # Use parameterized strides for staged_a - a_ptrs = staged_a + rm[:, None] * stride_sa_m + rk[None, :] * stride_sa_k + a_ptrs = staged_a + rm.to(tl.int64)[:, None] * stride_sa_m + rk[None, :] * stride_sa_k a = tl.load(a_ptrs) B_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn From 3c4cb4dfa02cb5be71f84c32101011165bd57015 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 3 Mar 2026 17:00:39 +0000 Subject: [PATCH 21/31] Apply Ruff auto-fixes --- .../all_gather_matmul/benchmark_hbm_buffer.py | 38 ++- .../ops/all_gather_matmul/derive_params.py | 270 ++++++++++-------- 2 files changed, 184 insertions(+), 124 deletions(-) diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py index 7978c0682..190799986 100644 --- a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -36,6 +36,7 @@ _DERIVE_AVAILABLE = False try: import sys as _sys + _script_dir = os.path.dirname(os.path.abspath(__file__)) if _script_dir not in _sys.path: _sys.path.insert(0, _script_dir) @@ -47,19 +48,30 @@ DEFAULT_L2_SIZE_BYTES, DEFAULT_SCHEDULING_FACTOR, ) + _DERIVE_AVAILABLE = True except Exception: pass _MODEL_PARAMS = ( - "block_size_m", "block_size_n", "block_size_k", "group_size_m", - "num_fetch_sms", "k_per_flag", "num_warps", - "num_fetch_stages", "first_stage_fetch_sms", + "block_size_m", + "block_size_n", + "block_size_k", + "group_size_m", + "num_fetch_sms", + "k_per_flag", + "num_warps", + "num_fetch_stages", + "first_stage_fetch_sms", ) _FALLBACK_DEFAULTS = { - "block_size_m": 256, "block_size_n": 64, "block_size_k": 64, - "group_size_m": 1, "k_per_flag": 1, "num_fetch_stages": 1, + "block_size_m": 256, + "block_size_n": 64, + "block_size_k": 64, + "group_size_m": 1, + "k_per_flag": 1, + "num_fetch_stages": 1, } torch.manual_seed(123) @@ -268,7 +280,9 @@ def parse_args(): parser.add_argument("--a_col_major", action="store_true", help="A col-major (M-contiguous)") parser.add_argument("--single-run", action="store_true", help="1 iteration (for profiling)") parser.add_argument("--num_fetch_sms", type=int, default=None, help="Fetcher SMs (auto if None)") - parser.add_argument("--k_per_flag", type=int, default=None, help="K-blocks per ready flag (model-derived if omitted)") + parser.add_argument( + "--k_per_flag", type=int, default=None, help="K-blocks per ready flag (model-derived if omitted)" + ) parser.add_argument("--num_warps", type=int, default=None, help="Triton num_warps (auto if None)") parser.add_argument("--num_stages", type=int, default=None, help="Triton num_stages (auto if None)") parser.add_argument( @@ -283,7 +297,12 @@ def parse_args(): default=None, help="Fetcher WGs for stage 0 (fills first GPU wave; defaults to num_fetch_sms)", ) - parser.add_argument("--trace", action=argparse.BooleanOptionalAction, default=True, help="Collect per-workgroup trace and save Gantt chart PNG") + parser.add_argument( + "--trace", + action=argparse.BooleanOptionalAction, + default=True, + help="Collect per-workgroup trace and save Gantt chart PNG", + ) parser.add_argument("--trace_output", type=str, default="trace.png", help="Output path for trace plot") return vars(parser.parse_args()) @@ -297,7 +316,10 @@ def _apply_model_defaults(args, world_size, dtype_bytes=2): if _DERIVE_AVAILABLE: try: p = _derive_params( - args["m"], args["n"], args["k"], world_size, + args["m"], + args["n"], + args["k"], + world_size, link_bw=50.0, num_cus=DEFAULT_NUM_CUS, peak_tflops=DEFAULT_PEAK_TFLOPS_FP16, diff --git a/benchmark/ops/all_gather_matmul/derive_params.py b/benchmark/ops/all_gather_matmul/derive_params.py index 539b298a5..cf4acd9fe 100644 --- a/benchmark/ops/all_gather_matmul/derive_params.py +++ b/benchmark/ops/all_gather_matmul/derive_params.py @@ -61,8 +61,7 @@ def profile_link_bandwidth(world_size=DEFAULT_WORLD_SIZE): n_gpus = torch.cuda.device_count() if n_gpus < 2: raise RuntimeError( - f"Need >= 2 visible GPUs for bandwidth profiling, found {n_gpus}. " - f"Pass --link_bw explicitly instead." + f"Need >= 2 visible GPUs for bandwidth profiling, found {n_gpus}. Pass --link_bw explicitly instead." ) n_peers = min(world_size, n_gpus) - 1 @@ -113,6 +112,7 @@ def profile_link_bandwidth(world_size=DEFAULT_WORLD_SIZE): # ── Tile / block size heuristics ────────────────────────────────────────── + def _choose_block_sizes(M, N, K, K_local): """Heuristic tile-size selection for MI300X MFMA.""" bk = 64 @@ -156,8 +156,8 @@ def _choose_k_per_flag(num_k_blocks, num_k_blocks_local, target_groups=8): # ── Per-tile roofline model ────────────────────────────────────────────── -def _tile_roofline(bm, bn, bk, M, K, N, dtype_bytes, - peak_tflops, hbm_bw_gbps, l2_size): + +def _tile_roofline(bm, bn, bk, M, K, N, dtype_bytes, peak_tflops, hbm_bw_gbps, l2_size): """Compute achievable per-CU TFLOPS from tile arithmetic intensity. staged_a is always >> L2, so A tiles come from HBM. B may fit in L2 @@ -188,8 +188,8 @@ def _tile_roofline(bm, bn, bk, M, K, N, dtype_bytes, # ── Per-WG execution time models ──────────────────────────────────────── -def _gemm_wg_time_us(bm, bn, bk, K, num_flag_groups, - roofline_tflops, num_cus): + +def _gemm_wg_time_us(bm, bn, bk, K, num_flag_groups, roofline_tflops, num_cus): """Estimate per-WG GEMM execution time in microseconds. Uses the per-tile roofline to get the per-CU throughput, then applies @@ -213,8 +213,7 @@ def _gemm_wg_time_us(bm, bn, bk, K, num_flag_groups, return ideal_us * occupancy_factor + flag_us -def _fetch_wg_time_us(bm, bk, kpf, world_size, link_bw, - dtype_bytes, num_fgs_per_wg): +def _fetch_wg_time_us(bm, bk, kpf, world_size, link_bw, dtype_bytes, num_fgs_per_wg): """Estimate per-fetcher-WG execution time in microseconds. Each flag group fetches kpf K-blocks (each BM × BK) from one rank. @@ -240,17 +239,15 @@ def _fetch_wg_time_us(bm, bk, kpf, world_size, link_bw, # ── Kernel time estimation ─────────────────────────────────────────────── -def _estimate_kernel_time(total_gemm_wgs, gemm_wg_us, - total_fetch_wgs, fetch_wg_us, - num_cus, scheduling_factor): + +def _estimate_kernel_time(total_gemm_wgs, gemm_wg_us, total_fetch_wgs, fetch_wg_us, num_cus, scheduling_factor): """Estimate kernel wall-clock time from the CU work queue model. total_CU_work / num_CUs gives the ideal (work-conserving) lower bound. The scheduling_factor captures GPU dispatch overhead, cross-XCD coherence, and pipeline bubble effects measured on MI300X. """ - total_cu_work_us = (total_gemm_wgs * gemm_wg_us + - total_fetch_wgs * fetch_wg_us) + total_cu_work_us = total_gemm_wgs * gemm_wg_us + total_fetch_wgs * fetch_wg_us ideal_ms = total_cu_work_us / num_cus / 1e3 estimated_ms = ideal_ms * scheduling_factor @@ -259,8 +256,8 @@ def _estimate_kernel_time(total_gemm_wgs, gemm_wg_us, # ── Pipeline stage selection ───────────────────────────────────────────── -def _choose_fetch_stages(num_m_tiles, num_tiles_n, group_size_m, - comm_time_ms, compute_time_ms, num_cus): + +def _choose_fetch_stages(num_m_tiles, num_tiles_n, group_size_m, comm_time_ms, compute_time_ms, num_cus): """Choose num_fetch_stages for good pipeline efficiency while keeping m_per_stage divisible by group_size_m.""" ratio = comm_time_ms / compute_time_ms if compute_time_ms > 0 else 999 @@ -275,16 +272,14 @@ def _choose_fetch_stages(num_m_tiles, num_tiles_n, group_size_m, ideal_stages = 4 min_gemm_tiles = max(num_cus // 4, 32) - min_m_per_stage = max(group_size_m, - math.ceil(min_gemm_tiles / max(num_tiles_n, 1))) + min_m_per_stage = max(group_size_m, math.ceil(min_gemm_tiles / max(num_tiles_n, 1))) max_stages = max(1, num_m_tiles // min_m_per_stage) num_stages = min(ideal_stages, max_stages) num_stages = max(1, num_stages) m_per_stage = math.ceil(num_m_tiles / num_stages) if m_per_stage % group_size_m != 0: - m_per_stage = ((m_per_stage + group_size_m - 1) - // group_size_m) * group_size_m + m_per_stage = ((m_per_stage + group_size_m - 1) // group_size_m) * group_size_m num_stages = max(1, math.ceil(num_m_tiles / m_per_stage)) m_per_stage = math.ceil(num_m_tiles / num_stages) @@ -293,10 +288,21 @@ def _choose_fetch_stages(num_m_tiles, num_tiles_n, group_size_m, # ── num_fetch_sms optimisation ─────────────────────────────────────────── -def _choose_num_fetch_sms(m_per_stage, group_size_m, num_flag_groups_k, - link_bw, world_size, num_cus, - bm, bk, kpf, dtype_bytes, - gemm_wg_us, gemm_tiles_per_stage): + +def _choose_num_fetch_sms( + m_per_stage, + group_size_m, + num_flag_groups_k, + link_bw, + world_size, + num_cus, + bm, + bk, + kpf, + dtype_bytes, + gemm_wg_us, + gemm_tiles_per_stage, +): """Choose num_fetch_sms for good pipeline overlap. Balances three constraints: @@ -324,8 +330,7 @@ def _choose_num_fetch_sms(m_per_stage, group_size_m, num_flag_groups_k, gemm_waves = math.ceil(gemm_tiles_per_stage / num_cus) stage_gemm_us = gemm_waves * gemm_wg_us if per_fg_us > 0: - balance_min = max(1, math.ceil( - total_fg_per_stage * per_fg_us / stage_gemm_us)) + balance_min = max(1, math.ceil(total_fg_per_stage * per_fg_us / stage_gemm_us)) else: balance_min = 1 @@ -338,8 +343,8 @@ def _choose_num_fetch_sms(m_per_stage, group_size_m, num_flag_groups_k, # ── Main derivation ────────────────────────────────────────────────────── -def derive(M, N, K, world_size, link_bw, num_cus, peak_tflops, - hbm_bw_gbps, l2_size, scheduling_factor, dtype_bytes): + +def derive(M, N, K, world_size, link_bw, num_cus, peak_tflops, hbm_bw_gbps, l2_size, scheduling_factor, dtype_bytes): K_local = K // world_size # 1. Tile sizes @@ -352,7 +357,8 @@ def derive(M, N, K, world_size, link_bw, num_cus, peak_tflops, # 2. Per-tile roofline roofline_tflops, intensity, ridge, b_in_l2 = _tile_roofline( - bm, bn, bk, M, K, N, dtype_bytes, peak_tflops, hbm_bw_gbps, l2_size) + bm, bn, bk, M, K, N, dtype_bytes, peak_tflops, hbm_bw_gbps, l2_size + ) # 3. Communication model (link-limited) total_remote_bytes = M * K_local * (world_size - 1) * dtype_bytes @@ -370,32 +376,37 @@ def derive(M, N, K, world_size, link_bw, num_cus, peak_tflops, num_flag_groups_k = num_k_blocks // kpf # 6. Pipeline stages - num_stages, m_per_stage = _choose_fetch_stages( - num_m_tiles, num_tiles_n, gm, comm_time_ms, compute_time_ms, num_cus) + num_stages, m_per_stage = _choose_fetch_stages(num_m_tiles, num_tiles_n, gm, comm_time_ms, compute_time_ms, num_cus) gemm_tiles_per_stage = m_per_stage * num_tiles_n # 7. first_stage_fetch_sms: use all CUs to fill the pipeline ASAP fsf = num_cus # 8. Per-WG timing - gemm_wg_us_val = _gemm_wg_time_us(bm, bn, bk, K, num_flag_groups_k, - roofline_tflops, num_cus) + gemm_wg_us_val = _gemm_wg_time_us(bm, bn, bk, K, num_flag_groups_k, roofline_tflops, num_cus) # 9. Choose num_fetch_sms nf = _choose_num_fetch_sms( - m_per_stage, gm, num_flag_groups_k, - link_bw, world_size, num_cus, - bm, bk, kpf, dtype_bytes, - gemm_wg_us_val, gemm_tiles_per_stage) + m_per_stage, + gm, + num_flag_groups_k, + link_bw, + world_size, + num_cus, + bm, + bk, + kpf, + dtype_bytes, + gemm_wg_us_val, + gemm_tiles_per_stage, + ) # 10. Compute per-WG fetch times total_fg_per_stage = num_flag_groups_k * m_per_stage fgs_per_wg_stg0 = max(1, math.ceil(total_fg_per_stage / fsf)) fgs_per_wg_rest = max(1, math.ceil(total_fg_per_stage / nf)) - fetch_us_stg0 = _fetch_wg_time_us(bm, bk, kpf, world_size, - link_bw, dtype_bytes, fgs_per_wg_stg0) - fetch_us_rest = _fetch_wg_time_us(bm, bk, kpf, world_size, - link_bw, dtype_bytes, fgs_per_wg_rest) + fetch_us_stg0 = _fetch_wg_time_us(bm, bk, kpf, world_size, link_bw, dtype_bytes, fgs_per_wg_stg0) + fetch_us_rest = _fetch_wg_time_us(bm, bk, kpf, world_size, link_bw, dtype_bytes, fgs_per_wg_rest) # 11. Grid geometry first_stage_size = fsf + gemm_tiles_per_stage @@ -405,19 +416,16 @@ def derive(M, N, K, world_size, link_bw, num_cus, peak_tflops, total_gemm_wgs = gemm_tiles_per_stage * num_stages # 12. Kernel time estimate (CU-work model) - avg_fetch_us = (fsf * fetch_us_stg0 + nf * max(0, num_stages - 1) * fetch_us_rest) + avg_fetch_us = fsf * fetch_us_stg0 + nf * max(0, num_stages - 1) * fetch_us_rest avg_fetch_us /= max(total_fetch_wgs, 1) est_kernel_ms, est_ideal_ms = _estimate_kernel_time( - total_gemm_wgs, gemm_wg_us_val, - total_fetch_wgs, avg_fetch_us, - num_cus, scheduling_factor) + total_gemm_wgs, gemm_wg_us_val, total_fetch_wgs, avg_fetch_us, num_cus, scheduling_factor + ) # 13. Link-limited pipeline estimate (simple model for comparison) stage_m = m_per_stage * bm - stage_comm_ms = (stage_m * K_local * (world_size - 1) * dtype_bytes - / (total_link_bw * 1e9) * 1e3) - stage_compute_ms = (2 * stage_m * N * K - / (roofline_tflops * 1e12) * 1e3) + stage_comm_ms = stage_m * K_local * (world_size - 1) * dtype_bytes / (total_link_bw * 1e9) * 1e3 + stage_compute_ms = 2 * stage_m * N * K / (roofline_tflops * 1e12) * 1e3 startup_ms = stage_comm_ms steady_ms = max(stage_comm_ms, stage_compute_ms) * max(0, num_stages - 1) drain_ms = stage_compute_ms @@ -433,29 +441,46 @@ def derive(M, N, K, world_size, link_bw, num_cus, peak_tflops, staged_a_gb = M * K * dtype_bytes / (1024**3) return dict( - block_size_m=bm, block_size_n=bn, block_size_k=bk, - group_size_m=gm, num_warps=nw, - num_fetch_sms=nf, k_per_flag=kpf, - num_fetch_stages=num_stages, first_stage_fetch_sms=fsf, + block_size_m=bm, + block_size_n=bn, + block_size_k=bk, + group_size_m=gm, + num_warps=nw, + num_fetch_sms=nf, + k_per_flag=kpf, + num_fetch_stages=num_stages, + first_stage_fetch_sms=fsf, # derived - K_local=K_local, num_m_tiles=num_m_tiles, num_tiles_n=num_tiles_n, - num_k_blocks=num_k_blocks, num_flag_groups_k=num_flag_groups_k, - m_per_stage=m_per_stage, gemm_tiles_per_stage=gemm_tiles_per_stage, - grid_size=grid_size, total_fetch_wgs=total_fetch_wgs, + K_local=K_local, + num_m_tiles=num_m_tiles, + num_tiles_n=num_tiles_n, + num_k_blocks=num_k_blocks, + num_flag_groups_k=num_flag_groups_k, + m_per_stage=m_per_stage, + gemm_tiles_per_stage=gemm_tiles_per_stage, + grid_size=grid_size, + total_fetch_wgs=total_fetch_wgs, total_gemm_wgs=total_gemm_wgs, # roofline - roofline_tflops=roofline_tflops, tile_intensity=intensity, - ridge_point=ridge, b_in_l2=b_in_l2, + roofline_tflops=roofline_tflops, + tile_intensity=intensity, + ridge_point=ridge, + b_in_l2=b_in_l2, # per-WG timing gemm_wg_us=gemm_wg_us_val, fetch_wg_us_stg0=fetch_us_stg0, fetch_wg_us_rest=fetch_us_rest, # estimates - total_remote_bytes=total_remote_bytes, total_link_bw=total_link_bw, - comm_time_ms=comm_time_ms, total_flops=total_flops, - compute_time_ms=compute_time_ms, ratio=ratio, - stage_comm_ms=stage_comm_ms, stage_compute_ms=stage_compute_ms, - pipeline_ms=pipeline_ms, sequential_ms=sequential_ms, + total_remote_bytes=total_remote_bytes, + total_link_bw=total_link_bw, + comm_time_ms=comm_time_ms, + total_flops=total_flops, + compute_time_ms=compute_time_ms, + ratio=ratio, + stage_comm_ms=stage_comm_ms, + stage_compute_ms=stage_compute_ms, + pipeline_ms=pipeline_ms, + sequential_ms=sequential_ms, est_kernel_ms=est_kernel_ms, est_ideal_ms=est_ideal_ms, standalone_gemm_ms=standalone_gemm_ms, @@ -467,6 +492,7 @@ def derive(M, N, K, world_size, link_bw, num_cus, peak_tflops, # ── Formatting helpers ─────────────────────────────────────────────────── + def _fmt_bytes(n): if n >= 1024**3: return f"{n / 1024**3:.2f} GB" @@ -487,8 +513,8 @@ def _fmt_tflops(t): # ── Analysis output ────────────────────────────────────────────────────── -def print_analysis(M, N, K, world_size, link_bw, p, passthrough_args, - bw_profiled=False): + +def print_analysis(M, N, K, world_size, link_bw, p, passthrough_args, bw_profiled=False): K_local = p["K_local"] dtype_bytes = 2 bound = "COMM-BOUND" if p["ratio"] > 1.0 else "COMPUTE-BOUND" @@ -517,32 +543,30 @@ def print_analysis(M, N, K, world_size, link_bw, p, passthrough_args, # ── Per-tile roofline ───────────────────────────────────────────── print(f"\n── Roofline {'─' * 59}") print(f"{'Tile':>14}: ({p['block_size_m']}, {p['block_size_n']}, {p['block_size_k']})") - print(f"{'Intensity':>14}: {p['tile_intensity']:.0f} FLOPs/byte " - f"{'(B in L2)' if p['b_in_l2'] else '(B from HBM)'}") + print(f"{'Intensity':>14}: {p['tile_intensity']:.0f} FLOPs/byte {'(B in L2)' if p['b_in_l2'] else '(B from HBM)'}") print(f"{'Ridge point':>14}: {p['ridge_point']:.0f} FLOPs/byte") region = "COMPUTE" if p["tile_intensity"] >= p["ridge_point"] else "MEMORY" print(f"{'Roofline':>14}: {_fmt_tflops(p['roofline_tflops'])} ({region}-bound tiles)") # ── Communication ───────────────────────────────────────────────── print(f"\n── Communication {'─' * 54}") - print(f"{'Remote bytes':>14}: {_fmt_bytes(p['total_remote_bytes'])} " - f"(from {world_size - 1} peers)") + print(f"{'Remote bytes':>14}: {_fmt_bytes(p['total_remote_bytes'])} (from {world_size - 1} peers)") bw_src = "profiled" if bw_profiled else "user" - print(f"{'Link BW':>14}: {link_bw:.1f} GB/s/link × {world_size - 1} links " - f"= {p['total_link_bw']:.0f} GB/s aggregate ({bw_src})") + print( + f"{'Link BW':>14}: {link_bw:.1f} GB/s/link × {world_size - 1} links " + f"= {p['total_link_bw']:.0f} GB/s aggregate ({bw_src})" + ) print(f"{'Comm time':>14}: {p['comm_time_ms']:.3f} ms (link-limited)") # ── Compute ─────────────────────────────────────────────────────── print(f"\n── Compute {'─' * 60}") print(f"{'Total FLOPs':>14}: {_fmt_flops(p['total_flops'])}") - print(f"{'Roofline time':>14}: {p['compute_time_ms']:.3f} ms " - f"(at {_fmt_tflops(p['roofline_tflops'])})") + print(f"{'Roofline time':>14}: {p['compute_time_ms']:.3f} ms (at {_fmt_tflops(p['roofline_tflops'])})") print(f"{'Comm/Compute':>14}: {p['ratio']:.2f}x → {bound}") # ── Per-WG timing ───────────────────────────────────────────────── print(f"\n── Per-WG Model {'─' * 55}") - print(f"{'GEMM WG':>14}: {p['gemm_wg_us']:.0f} us " - f"({p['total_flops'] / p['total_gemm_wgs'] / 1e9:.2f} GFLOPs/WG)") + print(f"{'GEMM WG':>14}: {p['gemm_wg_us']:.0f} us ({p['total_flops'] / p['total_gemm_wgs'] / 1e9:.2f} GFLOPs/WG)") print(f"{'Fetch WG stg0':>14}: {p['fetch_wg_us_stg0']:.0f} us") if p["num_fetch_stages"] > 1: print(f"{'Fetch WG rest':>14}: {p['fetch_wg_us_rest']:.0f} us") @@ -552,43 +576,47 @@ def print_analysis(M, N, K, world_size, link_bw, p, passthrough_args, print(f"\n── Pipeline {'─' * 59}") print(f"{'Stages (S)':>14}: {S}") print(f"{'M tiles/stage':>14}: {p['m_per_stage']} ({p['m_per_stage'] * p['block_size_m']} rows)") - print(f"{'GEMM WGs/stg':>14}: {p['gemm_tiles_per_stage']} " - f"({p['m_per_stage']} m-tiles × {p['num_tiles_n']} n-tiles)") - print(f"{'K flag groups':>14}: {p['num_flag_groups_k']} " - f"(k_per_flag={p['k_per_flag']})") + print( + f"{'GEMM WGs/stg':>14}: {p['gemm_tiles_per_stage']} ({p['m_per_stage']} m-tiles × {p['num_tiles_n']} n-tiles)" + ) + print(f"{'K flag groups':>14}: {p['num_flag_groups_k']} (k_per_flag={p['k_per_flag']})") print(f"{'Stage comm':>14}: {p['stage_comm_ms']:.3f} ms") print(f"{'Stage compute':>14}: {p['stage_compute_ms']:.3f} ms") # ── Grid ────────────────────────────────────────────────────────── print(f"\n── Grid Layout {'─' * 56}") - print(f"{'Stage 0':>14}: {p['first_stage_fetch_sms']} fetchers + " - f"{p['gemm_tiles_per_stage']} GEMM = " - f"{p['first_stage_fetch_sms'] + p['gemm_tiles_per_stage']} WGs") + print( + f"{'Stage 0':>14}: {p['first_stage_fetch_sms']} fetchers + " + f"{p['gemm_tiles_per_stage']} GEMM = " + f"{p['first_stage_fetch_sms'] + p['gemm_tiles_per_stage']} WGs" + ) if S > 1: - print(f"{'Stages 1..{}'.format(S - 1):>14}: {p['num_fetch_sms']} fetchers + " - f"{p['gemm_tiles_per_stage']} GEMM = " - f"{p['num_fetch_sms'] + p['gemm_tiles_per_stage']} WGs (×{S - 1})") - print(f"{'Total grid':>14}: {p['grid_size']} WGs " - f"({p['total_fetch_wgs']} fetch + {p['total_gemm_wgs']} GEMM)") + print( + f"{'Stages 1..{}'.format(S - 1):>14}: {p['num_fetch_sms']} fetchers + " + f"{p['gemm_tiles_per_stage']} GEMM = " + f"{p['num_fetch_sms'] + p['gemm_tiles_per_stage']} WGs (×{S - 1})" + ) + print(f"{'Total grid':>14}: {p['grid_size']} WGs ({p['total_fetch_wgs']} fetch + {p['total_gemm_wgs']} GEMM)") # ── Time estimates ──────────────────────────────────────────────── print(f"\n── Time Estimates {'─' * 53}") - print(f"{'CU-work lower':>14}: {p['est_ideal_ms']:.1f} ms " - f"(total WG time / {DEFAULT_NUM_CUS} CUs)") - print(f"{'Fused kernel':>14}: {p['est_kernel_ms']:.1f} ms " - f"(×{p['scheduling_factor']:.1f} scheduling overhead)") + print(f"{'CU-work lower':>14}: {p['est_ideal_ms']:.1f} ms (total WG time / {DEFAULT_NUM_CUS} CUs)") + print(f"{'Fused kernel':>14}: {p['est_kernel_ms']:.1f} ms (×{p['scheduling_factor']:.1f} scheduling overhead)") est_tflops = p["total_flops"] / (p["est_kernel_ms"] * 1e-3) / 1e12 - print(f"{'Est. TFLOPS':>14}: {est_tflops:.0f} TFLOPS " - f"({est_tflops / p['roofline_tflops'] * 100:.0f}% of roofline)") + print( + f"{'Est. TFLOPS':>14}: {est_tflops:.0f} TFLOPS ({est_tflops / p['roofline_tflops'] * 100:.0f}% of roofline)" + ) print(f"{'':>14}") - print(f"{'PyTorch est.':>14}: {p['pytorch_est_ms']:.1f} ms " - f"(all_gather {p['comm_time_ms']:.1f} + matmul {p['standalone_gemm_ms']:.1f})") + print( + f"{'PyTorch est.':>14}: {p['pytorch_est_ms']:.1f} ms " + f"(all_gather {p['comm_time_ms']:.1f} + matmul {p['standalone_gemm_ms']:.1f})" + ) if p["est_kernel_ms"] < p["pytorch_est_ms"]: speedup = p["pytorch_est_ms"] / p["est_kernel_ms"] print(f"{'Fused speedup':>14}: {speedup:.2f}x over sequential PyTorch") else: slowdown = p["est_kernel_ms"] / p["pytorch_est_ms"] - print(f"{'Fused speedup':>14}: {1/slowdown:.2f}x (slower than sequential by {slowdown:.2f}x)") + print(f"{'Fused speedup':>14}: {1 / slowdown:.2f}x (slower than sequential by {slowdown:.2f}x)") # ── Recommended parameters ──────────────────────────────────────── print(f"\n── Recommended Kernel Parameters {'─' * 38}") @@ -639,20 +667,22 @@ def main(): parser.add_argument("-m", type=int, required=True, help="M dimension (rows of output)") parser.add_argument("-n", type=int, required=True, help="N dimension (cols of output)") parser.add_argument("-k", type=int, required=True, help="K dimension (total reduction dim)") - parser.add_argument("--world_size", type=int, default=DEFAULT_WORLD_SIZE, - help="Number of GPUs") - parser.add_argument("--link_bw", type=float, default=None, - help="Per-link XGMI bandwidth in GB/s (one direction). " - "Omit to auto-profile via GPU-to-GPU copies.") - parser.add_argument("--num_cus", type=int, default=DEFAULT_NUM_CUS, - help="Number of compute units") - parser.add_argument("--peak_tflops", type=float, default=DEFAULT_PEAK_TFLOPS_FP16, - help="Peak fp16 TFLOPS") - parser.add_argument("--hbm_bw", type=float, default=DEFAULT_HBM_BW_GBPS, - help="HBM bandwidth in GB/s") - parser.add_argument("--scheduling_factor", type=float, - default=DEFAULT_SCHEDULING_FACTOR, - help="CU scheduling overhead factor (calibrated from traces)") + parser.add_argument("--world_size", type=int, default=DEFAULT_WORLD_SIZE, help="Number of GPUs") + parser.add_argument( + "--link_bw", + type=float, + default=None, + help="Per-link XGMI bandwidth in GB/s (one direction). Omit to auto-profile via GPU-to-GPU copies.", + ) + parser.add_argument("--num_cus", type=int, default=DEFAULT_NUM_CUS, help="Number of compute units") + parser.add_argument("--peak_tflops", type=float, default=DEFAULT_PEAK_TFLOPS_FP16, help="Peak fp16 TFLOPS") + parser.add_argument("--hbm_bw", type=float, default=DEFAULT_HBM_BW_GBPS, help="HBM bandwidth in GB/s") + parser.add_argument( + "--scheduling_factor", + type=float, + default=DEFAULT_SCHEDULING_FACTOR, + help="CU scheduling overhead factor (calibrated from traces)", + ) args, passthrough = parser.parse_known_args() @@ -670,13 +700,21 @@ def main(): print(" Falling back to --link_bw 50 (MI300X default)\n") link_bw = 50.0 - p = derive(args.m, args.n, args.k, args.world_size, link_bw, - args.num_cus, args.peak_tflops, args.hbm_bw, - DEFAULT_L2_SIZE_BYTES, args.scheduling_factor, - dtype_bytes=2) + p = derive( + args.m, + args.n, + args.k, + args.world_size, + link_bw, + args.num_cus, + args.peak_tflops, + args.hbm_bw, + DEFAULT_L2_SIZE_BYTES, + args.scheduling_factor, + dtype_bytes=2, + ) - print_analysis(args.m, args.n, args.k, args.world_size, link_bw, - p, passthrough, bw_profiled=bw_profiled) + print_analysis(args.m, args.n, args.k, args.world_size, link_bw, p, passthrough, bw_profiled=bw_profiled) if __name__ == "__main__": From 77eff5b17ea8a34c6c5f9e9fe66a0af10cd7f678 Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Tue, 3 Mar 2026 16:32:15 -0500 Subject: [PATCH 22/31] Reverse 2D block translate --- iris/iris.py | 47 +++++------------------------------------------ iris/x/core.py | 5 +---- iris/x/gather.py | 30 +++++++++++------------------- 3 files changed, 17 insertions(+), 65 deletions(-) diff --git a/iris/iris.py b/iris/iris.py index d061f09ea..e68adc3f0 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1807,30 +1807,6 @@ def __translate(ptr, from_rank, to_rank, heap_bases): return translated_ptr -@triton.jit -def __translate_block_2d(ptr, from_rank, to_rank, heap_bases): - """ - Pointer translation for block load/store operations. - - Note: Vectorization hints should be applied in the tile_ptr computation (core.py) - where the 2D block shape is actually created, not here in the translation. - """ - from_base = tl.load(heap_bases + from_rank) - to_base = tl.load(heap_bases + to_rank) - # convert to int to compute difference - ptr_int = tl.cast(ptr, tl.uint64) - # Find the offset from from_rank heap - offset = ptr_int - from_base - # Byte cast for byte offset addition - to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) - # Find the offset into the to_rank heap - translated_ptr_byte = to_base_byte + offset - # Cast to_base back to pointer type - translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) - - return translated_ptr - - @aggregate class DeviceContext: """ @@ -2005,16 +1981,9 @@ def initialize(context_tensor, rank, world_size, tracing: tl.constexpr = False): @triton.jit def _translate(self, ptr, from_rank, to_rank): - """Internal pointer translation between rank address spaces. - Used for atomic operations which may receive scalar pointers.""" + """Internal pointer translation between rank address spaces.""" return __translate(ptr, from_rank, to_rank, self.heap_bases) - @triton.jit - def _translate_block_2d(self, ptr, from_rank, to_rank): - """Internal pointer translation with 2D vectorization hints. - Used for block load/store operations with 2D block pointers.""" - return __translate_block_2d(ptr, from_rank, to_rank, self.heap_bases) - @triton.jit def load(self, pointer, from_rank, mask=None): """ @@ -2036,7 +2005,7 @@ def load(self, pointer, from_rank, mask=None): Example: >>> data = ctx.load(buffer + offsets, from_rank=1, mask=mask) """ - translated_ptr = self.__translate(pointer, self.rank, from_rank) + translated_ptr = self._translate(pointer, self.rank, from_rank) result = tl.load(translated_ptr, mask=mask) return result @@ -2062,7 +2031,7 @@ def store(self, pointer, value, to_rank, mask=None): Example: >>> ctx.store(buffer + offsets, values, to_rank=1, mask=mask) """ - translated_ptr = self.__translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank) tl.store(translated_ptr, value, mask=mask) @triton.jit @@ -2392,9 +2361,6 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): data from the target memory location. If the `from_rank` and `to_rank` are the same, this function performs a local load operation. - This function uses 2D vectorization hints for optimal performance with block pointers. - Minimum block size in each dimension should be >= 16. - Args: pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. to_rank (int): The rank ID to which the pointer will be translated. Must be the current rank where the pointer is local. @@ -2414,7 +2380,7 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): >>> data = iris.load(ptr, cur_rank, remote_rank, heap_bases) >>> return data """ - translated_ptr = __translate_block_2d(pointer, to_rank, from_rank, heap_bases) + translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) result = tl.load(translated_ptr, mask=mask) return result @@ -2429,9 +2395,6 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): the provided data to the target memory location. If the `from_rank` and `to_rank` are the same, this function performs a local store operation. - This function uses 2D vectorization hints for optimal performance with block pointers. - Minimum block size in each dimension should be >= 16. - Args: pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. value (Block): The tensor of elements to be stored. @@ -2452,7 +2415,7 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): >>> value = 42 >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate_block_2d(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) tl.store(translated_ptr, value, mask=mask) diff --git a/iris/x/core.py b/iris/x/core.py index 58786e79e..fee50918e 100644 --- a/iris/x/core.py +++ b/iris/x/core.py @@ -80,10 +80,7 @@ def tile_ptr(ptr, M, N, stride_m, stride_n, pid_m, pid_n, BLOCK_SIZE_M: tl.const rm, rn, mask = tile_layout(pid_m, pid_n, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N) offset = rm[:, None] * stride_m + rn[None, :] * stride_n tile_ptr = ptr + offset - # NOTE: Vectorization hints are applied at the call site (e.g., gather.py) - # rather than here, because the caller knows the block dimensions. - # Alignment IS preserved through pointer translation since symmetric heaps - # are all page-aligned, so relative offsets within the heap are maintained. + tile_ptr = tl.multiple_of(tile_ptr, (BLOCK_SIZE_M, BLOCK_SIZE_N)) return tile_ptr, mask diff --git a/iris/x/gather.py b/iris/x/gather.py index bb3fb637a..ca8bd4f9c 100644 --- a/iris/x/gather.py +++ b/iris/x/gather.py @@ -13,6 +13,7 @@ import triton import triton.language as tl +import iris from iris.iris import DeviceContext from .core import Tile, TensorView @@ -50,25 +51,16 @@ def gather( src_tile_ptr, mask = src_view.tile_ptr(tile) if source_rank == ctx.rank: - # Local load - can use vectorization hints since alignment is guaranteed - local_ptr = tl.multiple_of(src_tile_ptr, (1, tile.block_n)) - local_ptr = tl.max_contiguous(local_ptr, (1, tile.block_n)) - tile_data = tl.load(local_ptr, mask=mask) + # Local load + tile_data = tl.load(src_tile_ptr, mask=mask, other=0.0) else: - # Remote load using RMA - inline translation and apply hints AFTER translation - # Hints must be applied to the translated pointer because pointer arithmetic - # (cast to uint64, subtract, add, cast back) destroys hint metadata. - # Alignment IS preserved because symmetric heaps are all page-aligned. - from_base = tl.load(ctx.heap_bases + ctx.rank) - to_base = tl.load(ctx.heap_bases + source_rank) - ptr_int = tl.cast(src_tile_ptr, tl.uint64) - offset = ptr_int - from_base - to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) - translated_ptr_byte = to_base_byte + offset - translated_ptr = tl.cast(translated_ptr_byte, src_tile_ptr.dtype) - # Apply vectorization hints AFTER translation - translated_ptr = tl.multiple_of(translated_ptr, (1, tile.block_n)) - translated_ptr = tl.max_contiguous(translated_ptr, (1, tile.block_n)) - tile_data = tl.load(translated_ptr, mask=mask) + # Remote load using RMA + tile_data = iris.load( + src_tile_ptr, + ctx.rank, # to_rank (current rank) + source_rank, # from_rank (source rank) + ctx.heap_bases, + mask=mask, + ) return tile_data From dcafd2a669d77a3ac225a96c00e66d394110d68d Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Tue, 3 Mar 2026 16:51:16 -0500 Subject: [PATCH 23/31] Properly use iris tracing APIs --- iris/ops/all_gather_matmul_hbm_buffer.py | 127 +++++++++++++++-------- iris/tracing/events.py | 13 +++ 2 files changed, 97 insertions(+), 43 deletions(-) diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index 36010e24f..b23123fcb 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -16,7 +16,8 @@ import iris import iris.x -from iris.device_utils import read_realtime, get_xcc_id +from iris.device_utils import read_realtime +from iris.tracing.events import TraceEvent from .config import FusedConfig from .workspace import FusedWorkspace @@ -62,19 +63,13 @@ def _hbm_buffer_all_gather_matmul_kernel( NUM_FETCH_STAGES: tl.constexpr, GEMM_TILES_PER_STAGE: tl.constexpr, FIRST_STAGE_FETCH_SMS: tl.constexpr, - trace_start_ptr, - trace_end_ptr, - trace_wait_ptr, - trace_xcd_ptr, TRACE: tl.constexpr, ): pid = tl.program_id(0) acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 zero = tl.program_id(0) * 0 - if TRACE: - tl.store(trace_start_ptr + pid, read_realtime()) - tl.store(trace_xcd_ptr + pid, get_xcc_id()) + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size, tracing=TRACE) # Interleaved layout with asymmetric first stage: # [fetch0 (P)] [gemm0 (G)] [fetch1 (F)] [gemm1 (G)] ... @@ -101,7 +96,15 @@ def _hbm_buffer_all_gather_matmul_kernel( # ============================================================== stage_pid = local_pid - ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) + if TRACE: + _trace_handle = ctx.tracing.record_event_start( + event_id=TraceEvent().wg_fetch, + target_rank=cur_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=my_stage, + ) + src_view = iris.x.make_tensor_view(A_sharded, M, K_local, stride_am, stride_ak) tiles_per_m_group = NUM_FLAG_GROUPS_K * GROUP_SIZE_M @@ -146,8 +149,7 @@ def _hbm_buffer_all_gather_matmul_kernel( tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") if TRACE: - tl.store(trace_wait_ptr + pid, zero.to(tl.int64), cache_modifier=".wt") - tl.store(trace_end_ptr + pid, read_realtime(), cache_modifier=".wt") + ctx.tracing.record_event_end(_trace_handle) else: # ============================================================== @@ -173,6 +175,13 @@ def _hbm_buffer_all_gather_matmul_kernel( acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) if TRACE: + _trace_handle = ctx.tracing.record_event_start( + event_id=TraceEvent().wg_gemm, + target_rank=cur_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=my_stage, + ) _wt = zero.to(tl.int64) for k_fg in range(NUM_FLAG_GROUPS_K): @@ -213,8 +222,14 @@ def _hbm_buffer_all_gather_matmul_kernel( tl.store(C_ptrs, c, mask=c_mask, cache_modifier=".wt") if TRACE: - tl.store(trace_wait_ptr + pid, _wt) - tl.store(trace_end_ptr + pid, read_realtime(), cache_modifier=".wt") + ctx.tracing.record_event_end(_trace_handle) + ctx.tracing.record_event_start( + event_id=TraceEvent().wg_gemm_wait, + target_rank=cur_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=_wt.to(tl.int32), + ) # ========================================================================== @@ -285,6 +300,45 @@ def all_gather_matmul_hbm_buffer_preamble( return ws +_WG_FETCH = 14 +_WG_GEMM = 15 +_WG_GEMM_WAIT = 16 + + +def _extract_wg_trace(shmem, grid_size, **metadata): + """Reconstruct per-workgroup trace arrays from DeviceTracing events.""" + import numpy as np + + bufs = shmem.tracing.trace_buffers + n = min(shmem.tracing.trace_counter.item(), shmem.tracing.max_events) + + event_ids = bufs["event_id"][:n].cpu().numpy() + pids = bufs["pid"][:n].cpu().numpy() + timestamps = bufs["timestamp"][:n].cpu().numpy().astype(np.int64) + end_ts = bufs["duration_cycles"][:n].cpu().numpy().astype(np.int64) + xcc_ids = bufs["xcc_id"][:n].cpu().numpy().astype(np.int32) + pid_ns = bufs["pid_n"][:n].cpu().numpy() + + starts = torch.zeros(grid_size, dtype=torch.int64) + ends = torch.zeros(grid_size, dtype=torch.int64) + waits = torch.zeros(grid_size, dtype=torch.int64) + xcds = torch.zeros(grid_size, dtype=torch.int32) + + for i in range(n): + eid = int(event_ids[i]) + wg = int(pids[i]) + if wg >= grid_size: + continue + if eid == _WG_FETCH or eid == _WG_GEMM: + starts[wg] = int(timestamps[i]) + ends[wg] = int(end_ts[i]) + xcds[wg] = int(xcc_ids[i]) + elif eid == _WG_GEMM_WAIT: + waits[wg] = int(pid_ns[i]) + + return {"start": starts, "end": ends, "wait": waits, "xcd": xcds, "grid_size": grid_size, **metadata} + + def all_gather_matmul_hbm_buffer( shmem, output_tensor: torch.Tensor, @@ -385,17 +439,12 @@ def all_gather_matmul_hbm_buffer( total_fetch_wgs = first_stage_fetch_sms + num_fetch_sms * max(0, num_fetch_stages - 1) grid_size = first_stage_size + rest_stage_size * max(0, num_fetch_stages - 1) - # Trace buffers if trace: - trace_start = torch.zeros(grid_size, dtype=torch.int64, device=device) - trace_end = torch.zeros(grid_size, dtype=torch.int64, device=device) - trace_wait = torch.zeros(grid_size, dtype=torch.int64, device=device) - trace_xcd = torch.zeros(grid_size, dtype=torch.int32, device=device) - else: - trace_start = torch.empty(1, dtype=torch.int64, device=device) - trace_end = torch.empty(1, dtype=torch.int64, device=device) - trace_wait = torch.empty(1, dtype=torch.int64, device=device) - trace_xcd = torch.empty(1, dtype=torch.int32, device=device) + max_trace_events = grid_size * 4 + if not shmem.tracing.enabled: + shmem.tracing.enable(max_events=max_trace_events) + else: + shmem.tracing.reset() launch_kwargs = {"matrix_instr_nonkdim": 16} if num_warps is not None: @@ -443,10 +492,6 @@ def all_gather_matmul_hbm_buffer( num_fetch_stages, gemm_tiles_per_stage, first_stage_fetch_sms, - trace_start, - trace_end, - trace_wait, - trace_xcd, trace, **launch_kwargs, ) @@ -456,21 +501,17 @@ def all_gather_matmul_hbm_buffer( if trace: torch.cuda.synchronize() - workspace.trace_data = { - "start": trace_start.cpu(), - "end": trace_end.cpu(), - "wait": trace_wait.cpu(), - "xcd": trace_xcd.cpu(), - "grid_size": grid_size, - "num_fetch_sms": num_fetch_sms, - "num_fetch_stages": num_fetch_stages, - "total_fetch_wgs": total_fetch_wgs, - "num_m_tiles": num_m_tiles, - "num_tiles_n": num_tiles_n, - "first_stage_fetch_sms": first_stage_fetch_sms, - "first_stage_size": first_stage_size, - "rest_stage_size": rest_stage_size, - "gemm_tiles_per_stage": gemm_tiles_per_stage, - } + workspace.trace_data = _extract_wg_trace( + shmem, grid_size, + num_fetch_sms=num_fetch_sms, + num_fetch_stages=num_fetch_stages, + total_fetch_wgs=total_fetch_wgs, + num_m_tiles=num_m_tiles, + num_tiles_n=num_tiles_n, + first_stage_fetch_sms=first_stage_fetch_sms, + first_stage_size=first_stage_size, + rest_stage_size=rest_stage_size, + gemm_tiles_per_stage=gemm_tiles_per_stage, + ) return workspace diff --git a/iris/tracing/events.py b/iris/tracing/events.py index 4838c09d6..62d7cf8df 100644 --- a/iris/tracing/events.py +++ b/iris/tracing/events.py @@ -26,6 +26,9 @@ 11: "atomic_or", 12: "atomic_min", 13: "atomic_max", + 14: "wg_fetch", + 15: "wg_gemm", + 16: "wg_gemm_wait", } @@ -75,6 +78,11 @@ class TraceEvent: atomic_min: tl.constexpr atomic_max: tl.constexpr + # Workgroup-level profiling events + wg_fetch: tl.constexpr + wg_gemm: tl.constexpr + wg_gemm_wait: tl.constexpr + @triton.constexpr_function def __init__(self): # Data movement @@ -94,3 +102,8 @@ def __init__(self): self.atomic_or = tl.constexpr(11) self.atomic_min = tl.constexpr(12) self.atomic_max = tl.constexpr(13) + + # Workgroup-level profiling + self.wg_fetch = tl.constexpr(14) + self.wg_gemm = tl.constexpr(15) + self.wg_gemm_wait = tl.constexpr(16) From 6fdad6dad72f79befe4a83ea3fdc49b9c3cb5c56 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 3 Mar 2026 21:51:56 +0000 Subject: [PATCH 24/31] Apply Ruff auto-fixes --- iris/ops/all_gather_matmul_hbm_buffer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index b23123fcb..abe3b3936 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -502,7 +502,8 @@ def all_gather_matmul_hbm_buffer( if trace: torch.cuda.synchronize() workspace.trace_data = _extract_wg_trace( - shmem, grid_size, + shmem, + grid_size, num_fetch_sms=num_fetch_sms, num_fetch_stages=num_fetch_stages, total_fetch_wgs=total_fetch_wgs, From 08755b777bded836ef206968f3c63700f0be9b8c Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Tue, 3 Mar 2026 16:54:57 -0500 Subject: [PATCH 25/31] Remove test.sh --- benchmark/ops/all_gather_matmul/test.sh | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100755 benchmark/ops/all_gather_matmul/test.sh diff --git a/benchmark/ops/all_gather_matmul/test.sh b/benchmark/ops/all_gather_matmul/test.sh deleted file mode 100755 index 7d5ef1a98..000000000 --- a/benchmark/ops/all_gather_matmul/test.sh +++ /dev/null @@ -1,16 +0,0 @@ -HSA_NO_SCRATCH_RECLAIM=1 \ -python3 $(pwd)/benchmark.py \ - -m 2048 \ - -n 16384 \ - -k 131072 \ - --num_ranks 8 \ - --num_xcds 8 \ - --datatype fp16 \ - --block_size_m 512 \ - --block_size_n 128 \ - --block_size_k 64 \ - --group_size_m 1 \ - --benchmark \ - --b_col_major \ - -v \ - --benchmark_pytorch \ No newline at end of file From f55829349c3fcd0ecef290361e41367440aa23b5 Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Thu, 5 Mar 2026 19:38:29 -0500 Subject: [PATCH 26/31] Fix CI: restore vectorization hints, align tritonBLAS versions, remove temp files - Restore optional `hint` parameter in `__translate` and all public iris API functions (load, store, get, put, copy, atomic_*) to match main branch pattern. The previous hardcoded `tl.multiple_of(ptr, (32, 32))` assumed 2D pointers and broke all scalar-pointer atomic operations. - Align tritonBLAS commit across pyproject.toml, run_tests.sh, apptainer/iris.def, and docker/Dockerfile to cd119279f. - Remove tracked backup files (iris.py.backup, all_gather_matmul.py.with_chunked) and add gitignore patterns. - Remove unimplemented "chunked" variant from test_all_gather_matmul parametrization. - Fix test_matmul_all_reduce_via_shmem_ops dimensions (N=128->256) to match new default block_size_n=256. - Remove phantom "matmul" from iris/ops/__init__.py __all__. Co-Authored-By: Claude Opus 4.6 (1M context) --- .github/scripts/run_tests.sh | 4 +- .gitignore | 2 + apptainer/iris.def | 2 +- docker/Dockerfile | 2 +- iris/iris.py | 154 +- iris/iris.py.backup | 2255 -------------------- iris/ops/__init__.py | 1 - iris/ops/all_gather_matmul.py.with_chunked | 521 ----- iris/ops/config.py | 3 +- tests/ops/test_all_gather_matmul.py | 1 - tests/ops/test_matmul_all_reduce.py | 2 +- 11 files changed, 93 insertions(+), 2854 deletions(-) delete mode 100644 iris/iris.py.backup delete mode 100644 iris/ops/all_gather_matmul.py.with_chunked diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh index 4abf4a717..8f254b326 100755 --- a/.github/scripts/run_tests.sh +++ b/.github/scripts/run_tests.sh @@ -75,11 +75,11 @@ fi if [ ! -d \"\$TRITONBLAS_DIR\" ]; then git clone https://github.com/ROCm/tritonBLAS.git \"\$TRITONBLAS_DIR\" cd \"\$TRITONBLAS_DIR\" - git checkout 47768c93acb7f89511d797964b84544c30ab81ad + git checkout cd119279f3df543a558aa6d2cd4a3daed0b1ec7a else cd \"\$TRITONBLAS_DIR\" git fetch - git checkout 47768c93acb7f89511d797964b84544c30ab81ad + git checkout cd119279f3df543a558aa6d2cd4a3daed0b1ec7a fi # Install with dependencies pip install -e . diff --git a/.gitignore b/.gitignore index 57d842401..845d61207 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,8 @@ omni*.pdf slurm*.out *.egg-info +*.backup +*.with_chunked examples/gemm/results/* asm/ diff --git a/apptainer/iris.def b/apptainer/iris.def index a5f3c3088..a02c2c32d 100644 --- a/apptainer/iris.def +++ b/apptainer/iris.def @@ -38,7 +38,7 @@ From: rocm/pytorch:rocm7.1_ubuntu24.04_py3.13_pytorch_release_2.9.1 cd /opt git clone https://github.com/ROCm/tritonBLAS.git cd tritonBLAS - git checkout 47768c93acb7f89511d797964b84544c30ab81ad + git checkout cd119279f3df543a558aa6d2cd4a3daed0b1ec7a pip3 install -e . " diff --git a/docker/Dockerfile b/docker/Dockerfile index a0f97d1c5..9c3954f98 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -43,7 +43,7 @@ ENV PYTHONPATH=$TRITON_PATH WORKDIR /opt RUN git clone https://github.com/ROCm/tritonBLAS.git && \ cd tritonBLAS && \ - git checkout 47768c93acb7f89511d797964b84544c30ab81ad && \ + git checkout cd119279f3df543a558aa6d2cd4a3daed0b1ec7a && \ pip3 install -e . # Set up workspace diff --git a/iris/iris.py b/iris/iris.py index e68adc3f0..5791d2e76 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1779,7 +1779,7 @@ def reduce_scatter(self, output_tensor, input_tensor, op=None, group=None, async @triton.jit -def __translate(ptr, from_rank, to_rank, heap_bases): +def __translate(ptr, from_rank, to_rank, heap_bases, hint: tl.constexpr = None): from_base = tl.load(heap_bases + from_rank) to_base = tl.load(heap_bases + to_rank) # convert to int to compute difference @@ -1792,18 +1792,8 @@ def __translate(ptr, from_rank, to_rank, heap_bases): translated_ptr_byte = to_base_byte + offset # Cast to_base back to pointer type translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) - - # Optimization to vectorize the load/store - # We can't do this in general because we don't know the shape of the tensor or block sizes - # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) - - # 0 You can use this if your block sizes are multiples of 32. - # Largest vectorized load instruction is dwordx4 (128-bits) - translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) - translated_ptr = tl.max_contiguous(translated_ptr, (1, 32)) - - # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) - # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) + if hint is not None: + translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, hint), hint) return translated_ptr @@ -1980,12 +1970,12 @@ def initialize(context_tensor, rank, world_size, tracing: tl.constexpr = False): return DeviceContext(rank, world_size, heap_bases, device_tracing) @triton.jit - def _translate(self, ptr, from_rank, to_rank): + def _translate(self, ptr, from_rank, to_rank, hint: tl.constexpr = None): """Internal pointer translation between rank address spaces.""" - return __translate(ptr, from_rank, to_rank, self.heap_bases) + return __translate(ptr, from_rank, to_rank, self.heap_bases, hint) @triton.jit - def load(self, pointer, from_rank, mask=None): + def load(self, pointer, from_rank, mask=None, hint: tl.constexpr = None): """ Loads a value from the specified rank's memory location. @@ -1998,6 +1988,7 @@ def load(self, pointer, from_rank, mask=None): pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `from_rank`'s address space. from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. Returns: Block: The loaded value from the target memory location. @@ -2005,12 +1996,12 @@ def load(self, pointer, from_rank, mask=None): Example: >>> data = ctx.load(buffer + offsets, from_rank=1, mask=mask) """ - translated_ptr = self._translate(pointer, self.rank, from_rank) + translated_ptr = self._translate(pointer, self.rank, from_rank, hint) result = tl.load(translated_ptr, mask=mask) return result @triton.jit - def store(self, pointer, value, to_rank, mask=None): + def store(self, pointer, value, to_rank, mask=None, hint: tl.constexpr = None): """ Writes data to the specified rank's memory location. @@ -2024,6 +2015,7 @@ def store(self, pointer, value, to_rank, mask=None): value (Block): The tensor of elements to be stored. to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. Returns: None @@ -2031,11 +2023,11 @@ def store(self, pointer, value, to_rank, mask=None): Example: >>> ctx.store(buffer + offsets, values, to_rank=1, mask=mask) """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) tl.store(translated_ptr, value, mask=mask) @triton.jit - def get(self, from_ptr, to_ptr, from_rank, mask=None): + def get(self, from_ptr, to_ptr, from_rank, mask=None, hint: tl.constexpr = None): """ Copies data from the specified rank's memory into current rank's local memory. @@ -2049,6 +2041,7 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None): to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer to local memory in current rank where the data will be written. from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. Returns: None @@ -2056,12 +2049,12 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None): Example: >>> ctx.get(remote_ptr + offsets, local_ptr + offsets, from_rank=1, mask=mask) """ - translated_from_ptr = self._translate(from_ptr, self.rank, from_rank) + translated_from_ptr = self._translate(from_ptr, self.rank, from_rank, hint) data = tl.load(translated_from_ptr, mask=mask) tl.store(to_ptr, data, mask=mask) @triton.jit - def put(self, from_ptr, to_ptr, to_rank, mask=None): + def put(self, from_ptr, to_ptr, to_rank, mask=None, hint: tl.constexpr = None): """ Copies data from current rank's local memory to the specified rank's memory. @@ -2075,6 +2068,7 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None): to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that references memory in `to_rank`. to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. Returns: None @@ -2082,12 +2076,12 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None): Example: >>> ctx.put(local_ptr + offsets, remote_ptr + offsets, to_rank=1, mask=mask) """ - translated_to_ptr = self._translate(to_ptr, self.rank, to_rank) + translated_to_ptr = self._translate(to_ptr, self.rank, to_rank, hint) data = tl.load(from_ptr, mask=mask) tl.store(translated_to_ptr, data, mask=mask) @triton.jit - def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None): + def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, hint: tl.constexpr = None): """ Copies data from one rank's memory to another rank's memory. @@ -2127,11 +2121,15 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None): translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) + if hint is not None: + translated_src = tl.max_contiguous(tl.multiple_of(translated_src, hint), hint) + translated_dst = tl.max_contiguous(tl.multiple_of(translated_dst, hint), hint) + data = tl.load(translated_src, mask=mask) tl.store(translated_dst, data, mask=mask) @triton.jit - def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic add at the specified rank's memory location. @@ -2154,11 +2152,11 @@ def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Example: >>> old_val = ctx.atomic_add(counter, 1, to_rank=1) """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_sub(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_sub(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Atomically subtracts data from the specified rank's memory location. @@ -2178,11 +2176,11 @@ def atomic_sub(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_cas(self, pointer, cmp, val, to_rank, sem=None, scope=None): + def atomic_cas(self, pointer, cmp, val, to_rank, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic compare-and-swap at the specified rank's memory location. @@ -2203,11 +2201,11 @@ def atomic_cas(self, pointer, cmp, val, to_rank, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) @triton.jit - def atomic_xchg(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_xchg(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic exchange at the specified rank's memory location. @@ -2227,11 +2225,11 @@ def atomic_xchg(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_xor(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_xor(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic XOR at the specified rank's memory location. @@ -2251,11 +2249,11 @@ def atomic_xor(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_and(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_and(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic AND at the specified rank's memory location. @@ -2275,11 +2273,11 @@ def atomic_and(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_or(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_or(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic OR at the specified rank's memory location. @@ -2299,11 +2297,11 @@ def atomic_or(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_min(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_min(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic minimum at the specified rank's memory location. @@ -2323,11 +2321,11 @@ def atomic_min(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_max(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_max(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic maximum at the specified rank's memory location. @@ -2347,12 +2345,12 @@ def atomic_max(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def load(pointer, to_rank, from_rank, heap_bases, mask=None): +def load(pointer, to_rank, from_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Loads a value from the specified rank's memory location. @@ -2380,13 +2378,13 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): >>> data = iris.load(ptr, cur_rank, remote_rank, heap_bases) >>> return data """ - translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) + translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases, hint) result = tl.load(translated_ptr, mask=mask) return result @triton.jit -def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): +def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Writes data to the specified rank's memory location. @@ -2415,12 +2413,12 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): >>> value = 42 >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) tl.store(translated_ptr, value, mask=mask) @triton.jit -def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): +def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Copies data from the specified rank's memory into the destination rank's memory. This function performs the transfer by translating `src_ptr` from the `from_rank`'s address @@ -2466,12 +2464,16 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) + if hint is not None: + translated_src = tl.max_contiguous(tl.multiple_of(translated_src, hint), hint) + translated_dst = tl.max_contiguous(tl.multiple_of(translated_dst, hint), hint) + data = tl.load(translated_src, mask=mask) tl.store(translated_dst, data, mask=mask) @triton.jit -def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): +def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Copies data from the specified rank's memory to the current rank's local memory. @@ -2498,7 +2500,7 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): >>> to_rank = 0 >>> iris.get(remote_ptr, local_ptr, from_rank, to_rank, heap_bases) """ - translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases) + translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases, hint) data = tl.load(translated_from_ptr, mask=mask) @@ -2506,7 +2508,7 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): @triton.jit -def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): +def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Copies data from the current rank's local memory to the specified rank's memory. This function performs a memory write operation by loading data from the current @@ -2532,7 +2534,7 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): >>> to_rank = 1 >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) """ - translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) + translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases, hint) data = tl.load(from_ptr, mask=mask) @@ -2540,7 +2542,9 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): @triton.jit -def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_add( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic add at the specified rank's memory location. @@ -2571,12 +2575,14 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> increment = 5 >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_sub( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Atomically subtracts data from the specified rank's memory location. @@ -2607,12 +2613,12 @@ def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> decrement = 3 >>> old_val = iris.atomic_sub(ptr, decrement, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None): +def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None, hint: tl.constexpr = None): """ Atomically compares and exchanges the specified rank's memory location. @@ -2644,12 +2650,14 @@ def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scop >>> new_val = 42 >>> old_val = iris.atomic_cas(ptr, expected, new_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) @triton.jit -def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_xchg( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic exchange at the specified rank's memory location. @@ -2680,12 +2688,14 @@ def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=Non >>> new_value = 99 >>> old_val = iris.atomic_xchg(ptr, new_value, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_xor( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic xor at the specified rank's memory location. @@ -2716,12 +2726,14 @@ def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> mask_val = 0xFF >>> old_val = iris.atomic_xor(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_and( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic and at the specified rank's memory location. @@ -2752,12 +2764,12 @@ def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> mask_val = 0x0F >>> old_val = iris.atomic_and(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic or at the specified rank's memory location. @@ -2788,12 +2800,14 @@ def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, >>> mask_val = 0xF0 >>> old_val = iris.atomic_or(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_min( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic min at the specified rank's memory location. @@ -2824,12 +2838,14 @@ def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> new_val = 10 >>> old_val = iris.atomic_min(ptr, new_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_max( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic max at the specified rank's memory location. @@ -2860,7 +2876,7 @@ def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> new_val = 100 >>> old_val = iris.atomic_max(ptr, new_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) diff --git a/iris/iris.py.backup b/iris/iris.py.backup deleted file mode 100644 index e8932c3c8..000000000 --- a/iris/iris.py.backup +++ /dev/null @@ -1,2255 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. - -""" -Iris: Multi-GPU Communication and Memory Management Framework - -Iris is a high-performance framework that enables seamless multi-GPU programming in Triton, -enabling fine-grained communication and compute overlap natively in Triton -across multiple GPUs with SHMEM-like Remote Memory Access (RMA) capabilities. - -Key Features: -- Symmetric heap management across multiple GPUs -- High-performance atomic operations (add, cas, xchg, xor, and, or, min, max) -- Efficient load/store operations with rank-to-rank communication -- Memory allocation and deallocation utilities -- Built-in logging with rank information -- PyTorch distributed integration for distributed computing - -Example: - >>> import iris - >>> ctx = iris.iris(heap_size=2**30) # 1GB heap - >>> tensor = ctx.zeros(1024, 1024, dtype=torch.float32) -""" - -import triton -import triton.language as tl - -from iris._distributed_helpers import ( - init_distributed, - distributed_barrier, - distributed_broadcast_scalar, - distributed_broadcast_tensor, -) -from iris.hip import ( - set_device, - get_cu_count, - count_devices, -) -from iris.symmetric_heap import SymmetricHeap -import numpy as np -import math -import torch -import logging - -# Import logging functionality from the separate logging module -from .logging import logger - - -class Iris: - """ - Main Iris class for multi-GPU communication and memory management. - - This class provides a unified interface for distributed GPU operations including - memory allocation, atomic operations, and inter-rank communication. - - Args: - heap_size (int): Size of the symmetric heap in bytes. Default: 1GB (2^30) - - Example: - >>> ctx = iris.iris(heap_size=2**31) # 2GB heap - >>> print(f"Rank {ctx.cur_rank} of {ctx.num_ranks}") # Rank 0 of 1 - >>> tensor = ctx.zeros(1000, 1000, dtype=torch.float32) - """ - - def __init__(self, heap_size=1 << 30): - # Initialize distributed environment - comm, cur_rank, num_ranks = init_distributed() - num_gpus = count_devices() - - gpu_id = cur_rank % num_gpus - set_device(gpu_id) - - self.comm = comm - self.num_ranks = num_ranks - self.cur_rank = cur_rank - self.gpu_id = gpu_id - self.heap_size = heap_size - - # Initialize symmetric heap - self.heap = SymmetricHeap(heap_size, gpu_id, cur_rank, num_ranks) - self.device = f"cuda:{gpu_id}" - self.heap_bases = self.heap.get_heap_bases() - - for i in range(num_ranks): - self.debug(f"GPU {i}: Heap base {hex(int(self.heap_bases[i].item()))}") - - distributed_barrier() - - # Initialize CCL interface - self.ccl = self.CCL(self) - - # Lazy initialization for ops interface - self._ops = None - - def _log_with_rank(self, level, message): - """Helper method to log with rank information injected into the record.""" - if logger.isEnabledFor(level): - record = logging.LogRecord( - name=logger.name, level=level, pathname="", lineno=0, msg=message, args=(), exc_info=None - ) - # Inject rank information into the record - record.iris_rank = self.cur_rank - record.iris_num_ranks = self.num_ranks - logger.handle(record) - - def debug(self, message): - """ - Log a debug message with rank information. - - Args: - message (str): Human-readable message to log at debug level. - - Notes: - The log record is enriched with ``iris_rank`` and ``iris_num_ranks`` so - formatters can display the originating rank and world size. - - Example: - >>> ctx = iris.iris() - >>> iris.set_logger_level(iris.DEBUG) - >>> ctx.debug("Allocating buffers") # [Iris] [0/1] Allocating buffers - """ - self._log_with_rank(logging.DEBUG, message) - - def info(self, message): - """ - Log an info message with rank information. - - Args: - message (str): Human-readable message to log at info level. - - Example: - >>> ctx = iris.iris() - >>> ctx.info("Starting iteration 0") # [Iris] [0/1] Starting iteration 0 - """ - self._log_with_rank(logging.INFO, message) - - def warning(self, message): - """ - Log a warning message with rank information. - - Args: - message (str): Human-readable message to log at warning level. - - Example: - >>> ctx = iris.iris() - >>> ctx.warning("Memory usage is high") # [Iris] [0/1] Memory usage is high - """ - self._log_with_rank(logging.WARNING, message) - - def error(self, message): - """ - Log an error message with rank information. - - Args: - message (str): Human-readable message to log at error level. - - Example: - >>> ctx = iris.iris() - >>> ctx.error("Failed to allocate memory") # [Iris] [0/1] Failed to allocate memory - """ - self._log_with_rank(logging.ERROR, message) - - @property - def ops(self): - """ - Access fused GEMM+CCL operations. - - This property provides a namespace for high-level fused operations that combine - matrix multiplication with collective communication. Operations automatically infer - dimensions, strides, and hardware parameters from input tensors. - - Available operations: - - matmul_all_reduce: GEMM + All-Reduce - - all_gather_matmul: All-Gather + GEMM - - matmul_all_gather: GEMM + All-Gather - - matmul_reduce_scatter: GEMM + Reduce-Scatter - - Returns: - OpsNamespace: Namespace with fused operation methods - - Raises: - ImportError: If tritonBLAS is not available - - Example: - >>> ctx = iris.iris() - >>> A = ctx.randn((1024, 512), dtype=torch.float16) - >>> B = ctx.randn((512, 2048), dtype=torch.float16) - >>> output = ctx.zeros((1024, 2048), dtype=torch.float16) - >>> ctx.ops.matmul_all_reduce(output, A, B, ctx) - """ - if self._ops is None: - from iris.ops import OpsNamespace - - self._ops = OpsNamespace(self) - return self._ops - - def broadcast(self, value, source_rank=0): - """ - Broadcast a value from one rank to all ranks. - - This method automatically detects the type of value and uses the appropriate - broadcast mechanism: - - For tensors and arrays: uses efficient PyTorch distributed tensor collectives - - For scalars and other objects: uses object broadcast - - Args: - value (Any): The value to broadcast. Can be a scalar, tensor, numpy array, - or any picklable object. Only the ``source_rank`` value is used; - other ranks should pass a placeholder (e.g., ``None``). - source_rank (int): Rank id that holds the authoritative value. - - Returns: - Any: The value broadcast to all ranks. Tensors and arrays are returned as - numpy arrays; scalars and objects are returned in their original type. - - Examples: - >>> ctx = iris.iris() - >>> # Broadcasting a scalar - >>> value = 42 if ctx.cur_rank == 0 else None - >>> value = ctx.broadcast(value, source_rank=0) # All ranks get 42 - >>> - >>> # Broadcasting a tensor - >>> if ctx.cur_rank == 0: - >>> data = torch.randn(10, 10) - >>> else: - >>> data = None - >>> data = ctx.broadcast(data, source_rank=0) # All ranks get the same array - """ - # Check if the value on source_rank is a tensor or array-like - if self.cur_rank == source_rank and value is not None: - # Explicitly exclude strings and non-numeric types - if isinstance(value, (str, dict, bool)): - is_tensor = False - elif isinstance(value, torch.Tensor): - is_tensor = True - elif isinstance(value, np.ndarray): - is_tensor = True - elif isinstance(value, (list, tuple)): - # Try to convert list/tuple to tensor to check if it's numeric - try: - torch.as_tensor(value) - is_tensor = True - except (TypeError, ValueError): - is_tensor = False - else: - # For other types, try to convert and check - try: - test_array = np.asarray(value) - # Check if it's a numeric dtype that torch can handle - if np.issubdtype(test_array.dtype, np.number): - torch.as_tensor(test_array) - is_tensor = True - else: - is_tensor = False - except (TypeError, ValueError): - is_tensor = False - else: - is_tensor = False - - # Broadcast the type decision to all ranks - is_tensor = distributed_broadcast_scalar(is_tensor, source_rank) - - if is_tensor: - return distributed_broadcast_tensor(value, root=source_rank) - else: - return distributed_broadcast_scalar(value, source_rank) - - def __allocate(self, num_elements, dtype): - """Allocate memory using the symmetric heap.""" - self.debug(f"allocate: num_elements = {num_elements}, dtype = {dtype}") - return self.heap.allocate(num_elements, dtype) - - def __parse_size(self, size): - # Handle nested tuples/lists by flattening them recursively - while len(size) == 1 and isinstance(size[0], (tuple, list)): - size = size[0] - num_elements = math.prod(size) - return size, num_elements - - def zeros_like( - self, input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format - ): - """ - Returns a tensor filled with the scalar value 0, with the same size as input, allocated on the Iris symmetric heap. - - Args: - input (Tensor): the size of input will determine size of the output tensor. - - Keyword Arguments: - dtype (torch.dtype, optional): the desired data type of returned Tensor. - Default: if None, defaults to the dtype of input. - layout (torch.layout, optional): the desired layout of returned tensor. - Default: if None, defaults to the layout of input. Note: Iris tensors are always contiguous (strided). - device (torch.device, optional): the desired device of returned tensor. - Default: if None, defaults to the device of input. Must be compatible with this Iris instance. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. - Default: torch.preserve_format. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> input_tensor = ctx.ones(2, 3) - >>> zeros_tensor = ctx.zeros_like(input_tensor) - >>> print(zeros_tensor.shape) # torch.Size([2, 3]) - """ - self.debug( - f"zeros_like: input_shape = {input.shape}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" - ) - - # Use input's properties as defaults if not specified - if dtype is None: - dtype = input.dtype - if layout is None: - layout = input.layout - if device is None: - device = input.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Get the size from input tensor - size = input.size() - num_elements = input.numel() - - # Allocate new tensor with the same size - new_tensor = self.__allocate(num_elements, dtype) - new_tensor.zero_() - - # Reshape to match input size - new_tensor = new_tensor.reshape(size) - - # Apply the requested memory format - new_tensor = self.__apply_memory_format(new_tensor, size, memory_format, input) - - # Apply the requested layout - new_tensor = self.__apply_layout(new_tensor, layout) - - # Set requires_grad if specified - if requires_grad: - new_tensor.requires_grad_() - - return new_tensor - - def arange( - self, start=0, end=None, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False - ): - """ - Returns a 1-D tensor of size ⌈(end - start) / step⌉ with values from the interval [start, end) - taken with common difference step beginning from start. The tensor is allocated on the symmetric heap. - - Note: When using floating-point dtypes (especially reduced precision types like bfloat16), - the results may be affected by floating-point rounding behavior. Some values in the sequence - might not be exactly representable in certain floating-point formats, which can lead to - repeated values or unexpected rounding. For precise sequences, it is recommended to use - integer dtypes instead of floating-point dtypes. - - Note that non-integer step is subject to floating point rounding errors when comparing - against end; to avoid inconsistency, we advise subtracting a small epsilon from end in such cases. - - Args: - start (Number, optional): the starting value for the set of points. Default: 0. - end (Number): the ending value for the set of points - step (Number, optional): the gap between each pair of adjacent points. Default: 1. - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.get_default_dtype()). - If dtype is not given, infer the data type from the other input arguments. - If any of start, end, or step are floating-point, the dtype is inferred - be the default dtype, see get_default_dtype(). Otherwise, the dtype is inferred - to be torch.int64. - layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. - Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.arange(0, 10, 2) # [0, 2, 4, 6, 8] - >>> print(tensor.shape) # torch.Size([5]) - """ - self.debug(f"arange: start = {start}, end = {end}, step = {step}, dtype = {dtype}, device = {device}") - - # Handle the case where only one argument is provided (end) - if end is None: - end = start - start = 0 - - # Validate inputs - if step == 0: - raise ValueError("step must be non-zero") - - # Validate step direction consistency - if step > 0 and start >= end: - raise ValueError(f"Invalid range: start >= end with positive step (start={start}, end={end}, step={step})") - elif step < 0 and start <= end: - raise ValueError(f"Invalid range: start <= end with negative step (start={start}, end={end}, step={step})") - - # Calculate the number of elements - num_elements = math.ceil((end - start) / step) - - # Infer dtype if not provided - if dtype is None: - if any(isinstance(x, float) for x in [start, end, step]): - dtype = torch.get_default_dtype() - else: - dtype = torch.int64 - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - tensor = out - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - - target_device = tensor.device - arange_tensor = torch.arange(start, end, step, dtype=dtype, device=target_device) - - tensor[:] = arange_tensor - - tensor = self.__apply_layout(tensor, layout) - - if requires_grad: - tensor.requires_grad_() - - return tensor - - def zeros(self, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): - """ - Returns a tensor filled with the scalar value 0, with the shape defined by the variable argument size. - The tensor is allocated on the Iris symmetric heap. - - Args: - *size (int...): a sequence of integers defining the shape of the output tensor. - Can be a variable number of arguments or a collection like a list or tuple. - - Keyword Arguments: - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.zeros(2, 3) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([0., 0., 0.], device='cuda:0') - """ - self.debug(f"zeros: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with zeros - out.zero_() - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with zeros - tensor.zero_() - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def randn( - self, - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - ): - """ - Returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1 - (also called the standard normal distribution). The tensor is allocated on the Iris symmetric heap. - - .. math:: - \\text{out}_i \\sim \\mathcal{N}(0, 1) - - For complex dtypes, the tensor is i.i.d. sampled from a complex normal distribution with zero mean - and unit variance as - - .. math:: - \\text{out}_i \\sim \\mathcal{CN}(0, 1) - - This is equivalent to separately sampling the real :math:`(\\text{Re})` and imaginary :math:`(\\text{Im})` - part of :math:`\\text{out}_i` as - - .. math:: - \\text{Re}(\\text{out}_i) \\sim \\mathcal{N}(0, \\frac{1}{2}), \\quad \\text{Im}(\\text{out}_i) \\sim \\mathcal{N}(0, \\frac{1}{2}) - - The shape of the tensor is defined by the variable argument size. - - Args: - *size (int...): a sequence of integers defining the shape of the output tensor. - Can be a variable number of arguments or a collection like a list or tuple. - - Keyword Arguments: - generator (torch.Generator, optional): a pseudorandom number generator for sampling - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type (see torch.set_default_device()). - device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. - Works only for CPU tensors. Default: False. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.randn(2, 3) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([ 0.3982, -0.0059, -0.4365], device='cuda:0') - """ - self.debug( - f"randn: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" - ) - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Generate random data and copy to out tensor - random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) - out.copy_(random_data) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Generate random data and copy to tensor - random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) - tensor.copy_(random_data) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def ones(self, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): - """ - Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument size. - The tensor is allocated on the Iris symmetric heap. - - Args: - *size (int...): a sequence of integers defining the shape of the output tensor. - Can be a variable number of arguments or a collection like a list or tuple. - - Keyword Arguments: - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.ones(2, 3) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([1., 1., 1.], device='cuda:0') - """ - self.debug(f"ones: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with ones - out.fill_(1) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with ones - tensor.fill_(1) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): - """ - Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value. - The tensor is allocated on the Iris symmetric heap. - - Args: - size (int...): a list, tuple, or torch.Size of integers defining the shape of the output tensor. - fill_value (Scalar): the value to fill the output tensor with. - - Keyword Arguments: - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.full((2, 3), 3.14) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([3.1400, 3.1400, 3.1400], device='cuda:0') - """ - self.debug( - f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" - ) - - # Infer dtype from fill_value if not provided - if dtype is None: - if isinstance(fill_value, (int, float)): - if isinstance(fill_value, float): - dtype = torch.get_default_dtype() - else: - dtype = torch.int64 - else: - # For other types (like tensors), use their dtype - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with the specified value - out.fill_(fill_value) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with the specified value - tensor.fill_(fill_value) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def uniform(self, size, low=0.0, high=1.0, dtype=torch.float): - """ - Returns a tensor filled with random numbers from a uniform distribution, allocated on the Iris symmetric heap. - - Args: - size (int or tuple of ints): the size of the output tensor. - low (float, optional): the lower bound of the uniform distribution. Default: 0.0. - high (float, optional): the upper bound of the uniform distribution. Default: 1.0. - dtype (torch.dtype, optional): the desired data type of returned tensor. Default: torch.float. - - Returns: - Tensor: A tensor filled with random numbers from a uniform distribution. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.uniform((2, 3), low=0.0, high=1.0) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') - """ - self.debug(f"uniform: size = {size}, low = {low}, high = {high}, dtype = {dtype}") - size, num_elements = self.__parse_size(size) - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - tensor.uniform_(low, high) - return tensor.reshape(size) - - def empty( - self, - *size, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - ): - """ - Returns a tensor filled with uninitialized data. The shape of the tensor is defined by the variable argument size. - The tensor is allocated on the Iris symmetric heap. - - Note: - If torch.use_deterministic_algorithms() and torch.utils.deterministic.fill_uninitialized_memory are both set to True, - the output tensor is initialized to prevent any possible nondeterministic behavior from using the data as an input to an operation. - Floating point and complex tensors are filled with NaN, and integer tensors are filled with the maximum value. - - Args: - *size (int...): a sequence of integers defining the shape of the output tensor. - Can be a variable number of arguments or a collection like a list or tuple. - - Keyword Arguments: - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. - Works only for CPU tensors. Default: False. Note: Iris tensors are always on GPU. - memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. - Default: torch.contiguous_format. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.empty(2, 3) - >>> print(tensor.shape) # torch.Size([2, 3]) - """ - self.debug( - f"empty: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" - ) - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested memory format - tensor = self.__apply_memory_format(tensor, size, memory_format) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def randint( - self, *args, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False - ): - """ - Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive). - The shape of the tensor is defined by the variable argument size. - The tensor is allocated on the Iris symmetric heap. - - Note: - With the global dtype default (torch.float32), this function returns a tensor with dtype torch.int64. - - Args: - low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. - high (int): One above the highest integer to be drawn from the distribution. - size (tuple): a tuple defining the shape of the output tensor. - - Keyword Arguments: - generator (torch.Generator, optional): a pseudorandom number generator for sampling. - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): if None, this function returns a tensor with dtype torch.int64. - layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. - device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.randint(0, 10, (2, 3)) # Random integers [0, 10) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([7, 2, 9], device='cuda:0') - """ - self.debug(f"randint: args = {args}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") - - # Parse arguments to determine low, high, and size - # PyTorch randint signatures: - # randint(high, size) - where high is the upper bound and size is the shape - # randint(low, high, size) - where low and high are bounds, size is the shape - if len(args) == 2: - # randint(high, size) - high, size = args - low = 0 - elif len(args) == 3: - # randint(low, high, size) - low, high, size = args - else: - raise ValueError(f"randint expects 2 or 3 positional arguments, got {len(args)}") - - # Use default dtype if None is provided - if dtype is None: - dtype = torch.int64 - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Generate random integers using PyTorch's randint - # Use specified device or fall back to current device - target_device = device if device is not None else self.device - - # Handle generator parameter - if generator is not None: - torch.randint(low, high, size, generator=generator, out=tensor, dtype=dtype, device=target_device) - else: - torch.randint(low, high, size, out=tensor, dtype=dtype, device=target_device) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def linspace(self, start, end, steps, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): - """ - Creates a one-dimensional tensor of size steps whose values are evenly spaced from start to end, inclusive. - The tensor is allocated on the Iris symmetric heap. - - The values are: - (start, start + (end-start)/(steps-1), ..., start + (steps-2)*(end-start)/(steps-1), end) - - Args: - start (float or Tensor): the starting value for the set of points. If Tensor, it must be 0-dimensional. - end (float or Tensor): the ending value for the set of points. If Tensor, it must be 0-dimensional. - steps (int): size of the constructed tensor. - - Keyword Arguments: - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the data type to perform the computation in. - Default: if None, uses the global default dtype when both start and end are real, - and corresponding complex dtype when either is complex. - layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. - device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.linspace(0, 10, 5) # [0, 2.5, 5, 7.5, 10] - >>> print(tensor) # tensor([ 0.0000, 2.5000, 5.0000, 7.5000, 10.0000], device='cuda:0') - """ - self.debug( - f"linspace: start = {start}, end = {end}, steps = {steps}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" - ) - - # Use global default dtype if None is provided - if dtype is None: - # Check if start or end are complex numbers - start_is_complex = isinstance(start, complex) or (hasattr(start, "dtype") and torch.is_complex(start)) - end_is_complex = isinstance(end, complex) or (hasattr(end, "dtype") and torch.is_complex(end)) - - if start_is_complex or end_is_complex: - # Infer complex dtype based on default dtype - dtype = torch.complex64 if torch.get_default_dtype() == torch.float32 else torch.complex128 - else: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse steps and extract the integer value - if isinstance(steps, (tuple, list)): - if len(steps) == 1: - # Single-element tuple/list like (5,) or [5] - steps_int = steps[0] - # Handle nested tuples like ((5,),) - if isinstance(steps_int, (tuple, list)): - steps_int = steps_int[0] - else: - # Multi-element tuple/list - use __parse_size for compatibility - size, num_elements = self.__parse_size(steps) - steps_int = num_elements - else: - # steps is a single integer - steps_int = steps - - # Ensure steps_int is an integer - steps_int = int(steps_int) - size = (steps_int,) - num_elements = steps_int - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Generate linspace using PyTorch's linspace - # Use specified device or fall back to current device - target_device = device if device is not None else self.device - torch.linspace(start, end, steps_int, out=tensor, dtype=dtype, device=target_device) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def rand( - self, - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - ): - """ - Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1). - The tensor is allocated on the Iris symmetric heap. - - Args: - *size (int...): a sequence of integers defining the shape of the output tensor. - Can be a variable number of arguments or a collection like a list or tuple. - - Keyword Arguments: - generator (torch.Generator, optional): a pseudorandom number generator for sampling. - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. - Works only for CPU tensors. Default: False. Note: Iris tensors are always on GPU. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.rand(2, 3) # Random values in [0, 1) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') - """ - self.debug( - f"rand: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" - ) - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Generate random numbers using PyTorch's rand - # Use specified device (already validated and set above) - - # Handle generator parameter - if generator is not None: - torch.rand(size, generator=generator, out=tensor, dtype=dtype, device=device) - else: - torch.rand(size, out=tensor, dtype=dtype, device=device) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def __deallocate(self, pointer): - pass - - def get_heap_bases(self): - """ - Return the tensor of symmetric heap base addresses for all ranks. - - Returns: - torch.Tensor: A 1D tensor of ``uint64`` heap base addresses of size ``num_ranks`` - on the Iris device. Pass this to device-side Triton kernels that require - heap translation. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> heap_bases = ctx.get_heap_bases() - >>> print(heap_bases.shape) # torch.Size([num_ranks]) - """ - return self.heap_bases - - def barrier(self, stream=None, group=None): - """ - Synchronize ranks within the specified group and their CUDA devices. - - This first calls ``torch.cuda.synchronize()`` or ``stream.synchronize()`` to ensure the local GPU has - finished all queued work, then performs a distributed barrier so that all - ranks in the group reach the same point before proceeding. - - Args: - stream: If stream is given: wait only for that stream before barrier. If stream is None: legacy behavior (device-wide sync). - group (ProcessGroup, optional): The process group to synchronize. - If None, uses the default process group (all ranks). - - Example: - >>> ctx = iris.iris(1 << 20) - >>> ctx.barrier() # Synchronize all ranks - >>> ctx.barrier(group=my_group) # Synchronize only ranks in my_group - """ - # Wait for all GPUs to finish work - if stream is None: - torch.cuda.synchronize() - else: - stream.synchronize() - - # Distributed barrier - distributed_barrier(group=group) - - def get_device(self): - """ - Get the underlying device where the Iris symmetric heap resides. - - Returns: - torch.device: The CUDA device of Iris-managed memory. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> device = ctx.get_device() - >>> print(device) # cuda:0 - """ - return self.heap.get_device() - - def get_cu_count(self): - """ - Get the number of compute units (CUs) for the current GPU. - - Returns: - int: Number of compute units on this rank's GPU. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> cu_count = ctx.get_cu_count() - >>> print(f"GPU has {cu_count} CUs") # GPU has 304 CUs - """ - return get_cu_count(self.gpu_id) - - def get_rank(self): - """ - Get this process's rank id in the distributed communicator. - - Returns: - int: Zero-based rank id of the current process. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> rank = ctx.get_rank() - >>> print(f"This is rank {rank}") # This is rank 0 - """ - return self.cur_rank - - def get_num_ranks(self): - """ - Get the total number of ranks in the distributed communicator. - - Returns: - int: World size (number of ranks). - - Example: - >>> ctx = iris.iris(1 << 20) - >>> num_ranks = ctx.get_num_ranks() - >>> print(f"Total ranks: {num_ranks}") # Total ranks: 1 - """ - return self.num_ranks - - def __throw_if_invalid_output_tensor(self, tensor: torch.Tensor, num_elements: int, dtype: torch.dtype): - if not self.__tensor_on_device(tensor): - raise RuntimeError( - f"The output tensor is not on the same device as the Iris instance. The Iris instance is on device {self.device} but the output tensor is on device {tensor.device}" - ) - if not self.__on_symmetric_heap(tensor): - raise RuntimeError( - f"The output tensor is not on the symmetric heap. The Iris instance is on heap base {self.heap_bases[self.cur_rank]} but the output tensor is on heap base {tensor.data_ptr()}" - ) - if tensor.numel() != num_elements: - raise RuntimeError(f"The output tensor has {tensor.numel()} elements, but {num_elements} are required") - if tensor.dtype != dtype: - raise RuntimeError(f"The output tensor has dtype {tensor.dtype}, but {dtype} is required") - - def __throw_if_invalid_device(self, device): - """ - Throw a RuntimeError if the requested device is not compatible with this Iris instance. - - Args: - device: The requested device (can be string, torch.device, or None) - - Raises: - RuntimeError: If the device is not compatible - """ - if not self.__is_valid_device(device): - raise RuntimeError( - f"Device mismatch: requested device {device} but Iris instance is on device {self.device}. " - f"Iris only supports tensors on its own device." - ) - - def __apply_memory_format( - self, tensor: torch.Tensor, size: tuple, memory_format: torch.memory_format, input_tensor: torch.Tensor = None - ): - """ - Apply the requested memory format to a tensor by setting appropriate strides. - This keeps the tensor on the symmetric heap while changing how PyTorch interprets the memory layout. - - Args: - tensor: The tensor to modify - size: The tensor's size/dimensions - memory_format: The desired memory format - input_tensor: The original input tensor (needed for preserve_format detection) - """ - if memory_format == torch.contiguous_format: - # Default format, no changes needed - return tensor - elif memory_format == torch.channels_last and len(size) == 4: - # For channels_last format: preserve shape (N, C, H, W) but change strides - # channels_last strides: [C*H*W, 1, C*W, C] for shape (N, C, H, W) - N, C, H, W = size[0], size[1], size[2], size[3] - # Keep the original shape (N, C, H, W) but use channels_last strides - tensor = self.__create_tensor_with_strides(tensor, size, (C * H * W, 1, C * W, C)) - return tensor - elif memory_format == torch.channels_last_3d and len(size) == 5: - # For channels_last_3d format: preserve shape (N, C, D, H, W) but change strides - # channels_last_3d strides: [C*D*H*W, 1, C*D*W, C*W, C] for shape (N, C, D, H, W) - N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] - # Keep the original shape (N, C, D, H, W) but use channels_last_3d strides - tensor = self.__create_tensor_with_strides(tensor, size, (C * D * H * W, 1, C * D * W, C * W, C)) - return tensor - elif memory_format == torch.preserve_format: - # For preserve_format, we need to detect the input tensor's memory format - # and apply the same format to the output - if input_tensor is not None: - # Check the actual memory format of the input tensor - if len(size) == 4: - # Check if input tensor is in channels_last format by examining strides - # channels_last format has strides[1] == 1 (channels dimension is contiguous) - input_strides = input_tensor.stride() - if len(input_strides) == 4 and input_strides[1] == 1: - # Input is in channels_last format, preserve it - # Use the input tensor's actual shape, not the size parameter - input_shape = input_tensor.shape - if len(input_shape) == 4: - # Input is already in channels_last format (N, H, W, C) - new_size = input_shape - # Use the input tensor's strides directly - tensor = self.__create_tensor_with_strides(tensor, new_size, input_strides) - return tensor - elif len(size) == 5: - # Check if input tensor is in channels_last_3d format - input_strides = input_tensor.stride() - if len(input_strides) == 5 and input_strides[1] == 1: - # Input is in channels_last_3d format, preserve it - # Use the input tensor's actual shape, not the size parameter - input_shape = input_tensor.shape - if len(input_shape) == 5: - # Input is already in channels_last_3d format (N, D, H, W, C) - new_size = input_shape - # Use the input tensor's strides directly - tensor = self.__create_tensor_with_strides(tensor, new_size, input_strides) - return tensor - # If no special format detected or no input tensor provided, use contiguous format - return tensor - else: - # Unsupported format or dimension combination - self.debug( - f"Warning: Memory format {memory_format} not supported for {len(size)}D tensor, using contiguous format" - ) - # For unsupported formats, return the tensor as-is (contiguous) - return tensor - - def __create_tensor_with_strides(self, original_tensor: torch.Tensor, size: tuple, strides: tuple) -> torch.Tensor: - """ - Create a new tensor with the specified strides while keeping the data on the symmetric heap. - - Args: - original_tensor: The original tensor (source of data and heap allocation) - size: The tensor's size/dimensions - strides: The desired strides for the new memory format - - Returns: - A new tensor with the specified strides, data copied from original, on the same heap - """ - - # First, create a temporary tensor with the correct strides using PyTorch - temp_tensor = torch.empty_strided(size, strides, dtype=original_tensor.dtype, device=original_tensor.device) - - # Handle different cases based on whether size changes and what the strides indicate - if size != original_tensor.shape: - # Size is different - this might be a format change that requires permutation - # Check if this is a channels_last format by comparing strides - if len(size) == 4: - # For channels_last: expected strides are [H*W*C, 1, W*C, C] for shape (N, H, W, C) - N, H, W, C = size[0], size[1], size[2], size[3] - expected_strides = (H * W * C, 1, W * C, C) - if strides == expected_strides: - permuted = original_tensor.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) - else: - # If the size differs for other reasons, do not permute; just reshape if possible - try: - permuted = original_tensor.reshape(size) - except Exception: - raise ValueError( - "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." - ) - elif len(size) == 5: - # For channels_last_3d: expected strides are [D*H*W*C, 1, H*W*C, W*C, C] for shape (N, D, H, W, C) - N, D, H, W, C = size[0], size[1], size[2], size[3], size[4] - expected_strides = (D * H * W * C, 1, H * W * C, W * C, C) - if strides == expected_strides: - permuted = original_tensor.permute(0, 2, 3, 4, 1) # (N, C, D, H, W) -> (N, D, H, W, C) - else: - # If the size differs for other reasons, do not permute; just reshape if possible - try: - permuted = original_tensor.reshape(size) - except Exception: - raise ValueError( - "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." - ) - else: - # For other dimensions, just try to reshape - try: - permuted = original_tensor.reshape(size) - except Exception: - raise ValueError( - "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." - ) - else: - # Size is the same - this is a stride-only change (like channels_last with preserved shape) - # We need to reorder the data to match the new stride pattern - if len(size) == 4: - # Check if this is channels_last format with preserved shape - N, C, H, W = size[0], size[1], size[2], size[3] - expected_strides = (C * H * W, 1, C * W, C) - if strides == expected_strides: - permuted = original_tensor - else: - permuted = original_tensor - elif len(size) == 5: - # Check if this is channels_last_3d format with preserved shape - N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] - expected_strides = (C * D * H * W, 1, C * D * W, C * W, C) - if strides == expected_strides: - permuted = original_tensor - else: - permuted = original_tensor - else: - permuted = original_tensor - - # Copy the permuted data to the temporary tensor - temp_tensor.copy_(permuted) - - # Now allocate a new tensor on our symmetric heap - num_elements = math.prod(size) - heap_tensor = self.__allocate(num_elements, original_tensor.dtype) - - # Reshape to the desired size - heap_tensor = heap_tensor.reshape(size) - - # Copy the data from the temporary tensor to our heap tensor - heap_tensor.copy_(temp_tensor) - - # Clean up the temporary tensor - del temp_tensor - - # Now we need to create a view with the correct strides - # We can't use as_strided directly on our heap tensor, but we can - # create a new tensor with the right strides and copy the data again - final_tensor = torch.as_strided(heap_tensor, size, strides) - - return final_tensor - - def __apply_layout(self, tensor: torch.Tensor, layout: torch.layout) -> torch.Tensor: - """ - Apply the requested layout to a tensor. - - Args: - tensor: The tensor to modify - layout: The desired layout - - Returns: - Tensor with the requested layout - """ - - if layout == torch.strided: - # Strided layout is the default - no changes needed - return tensor - else: - # Only support strided layout for now - raise ValueError(f"Layout {layout} not supported. Only torch.strided is currently supported.") - - def __tensor_on_device(self, tensor: torch.Tensor): - # Get the Iris device from memory_pool.device - iris_device = self.get_device() - tensor_device = tensor.device - - # For CUDA devices, check if they're compatible - if tensor_device.type == "cuda" and iris_device.type == "cuda": - if iris_device.index is None: - return True - return tensor_device.index == iris_device.index - - # For non-CUDA devices, they must be exactly equal - return tensor_device == iris_device - - def __on_symmetric_heap(self, tensor: torch.Tensor): - """Check if a tensor is allocated on the symmetric heap.""" - return self.heap.on_symmetric_heap(tensor) - - def __is_valid_device(self, device) -> bool: - """ - Check if the requested device is compatible with this Iris instance. - - Args: - device: The requested device (can be string, torch.device, or None) - - Returns: - bool: True if the device is compatible, False otherwise - """ - if device is None: - return True # None means use default device - - # Convert device strings to torch.device objects for proper comparison - requested_device = torch.device(device) if isinstance(device, str) else device - iris_device = self.get_device() - - # Check if both are CUDA devices - if requested_device.type == "cuda" and iris_device.type == "cuda": - # Check if index matches or if requested is "cuda" (any index) - if requested_device.index is None: - return True - else: - return requested_device.index == iris_device.index - - # For non-CUDA devices, always return False - return False - - class CCL: - """ - Collective Communication Library (CCL) interface for Iris. - - Provides collective operations that can be called as methods on the Iris instance. - Example usage: - >>> shmem = iris.iris() - >>> shmem.ccl.all_to_all(output_tensor, input_tensor) - """ - - def __init__(self, iris_instance): - """ - Initialize CCL with a reference to the parent Iris instance. - - Args: - iris_instance: The parent Iris instance - """ - self._iris = iris_instance - - def all_to_all(self, output_tensor, input_tensor, group=None, async_op=False, config=None): - """ - All-to-all collective operation. - - Each rank sends a tensor chunk to each other rank and receives - a tensor chunk from each other rank. Input/output tensors should have - shape (M, N * world_size) where each chunk of N columns corresponds to one rank. - - Args: - output_tensor: Output tensor of shape (M, N * world_size) - input_tensor: Input tensor of shape (M, N * world_size) - group: ProcessGroup or None. If None, uses all ranks in shmem context. - Default: None. - async_op: If False, performs a barrier at the end. If True, returns immediately. - Default: False. - config: Config instance with kernel parameters (default: None). - If None, uses default Config values. - - Example: - >>> shmem = iris.iris() - >>> shmem.ccl.all_to_all(output_tensor, input_tensor) - - >>> # Custom configuration - >>> from iris.ccl import Config - >>> config = Config(block_size_m=128, block_size_n=32) - >>> shmem.ccl.all_to_all(output_tensor, input_tensor, config=config) - - >>> # Async operation (no barrier) - >>> shmem.ccl.all_to_all(output_tensor, input_tensor, async_op=True) - """ - from iris.ccl.all_to_all import all_to_all as _all_to_all - - _all_to_all(output_tensor, input_tensor, self._iris, group=group, async_op=async_op, config=config) - - def all_gather(self, output_tensor, input_tensor, group=None, async_op=False, config=None): - """ - All-gather collective operation. - - Each rank sends its input tensor to all ranks, and all ranks receive - and concatenate all input tensors along dimension 0 (rows), matching - torch.distributed.all_gather_into_tensor behavior. - - Args: - output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs - input_tensor: Input tensor of shape (M, N) - local rank's data to send - group: ProcessGroup or None. If None, uses all ranks in shmem context. - Default: None. - async_op: If False, performs a barrier at the end. If True, returns immediately. - Default: False. - config: Config instance with kernel parameters (default: None). - If None, uses default Config values. - - Example: - >>> shmem = iris.iris() - >>> # Input: (M, N), Output: (world_size * M, N) - >>> shmem.ccl.all_gather(output_tensor, input_tensor) - - >>> # Custom configuration - >>> from iris.ccl import Config - >>> config = Config(block_size_m=128, block_size_n=32) - >>> shmem.ccl.all_gather(output_tensor, input_tensor, config=config) - - >>> # Async operation (no barrier) - >>> shmem.ccl.all_gather(output_tensor, input_tensor, async_op=True) - """ - from iris.ccl.all_gather import all_gather as _all_gather - - _all_gather(output_tensor, input_tensor, self._iris, group=group, async_op=async_op, config=config) - - def all_reduce_preamble(self, output_tensor, input_tensor, config=None, workspace=None): - """ - Prepare reusable workspace for all-reduce. - - Args: - output_tensor: Output tensor that will receive the reduced data. - input_tensor: Input tensor providing the local contribution. - config: Optional Config describing variant parameters. - workspace: Optional existing workspace to update/reuse. - - Returns: - Workspace object that can be passed to ``all_reduce``. - """ - from iris.ccl.all_reduce import all_reduce_preamble as _all_reduce_preamble - - return _all_reduce_preamble( - output_tensor, - input_tensor, - self._iris, - config=config, - workspace=workspace, - ) - - def all_reduce( - self, output_tensor, input_tensor, op=None, group=None, async_op=False, config=None, workspace=None - ): - """ - All-reduce collective operation. - - Each rank has a local input tensor, and all ranks compute the sum of all - input tensors. The result is written to output_tensor on all ranks. - - Args: - output_tensor: Output tensor of shape (M, N) - will contain sum of all inputs - input_tensor: Input tensor of shape (M, N) - local rank's partial data - op: Reduction operation to apply. Currently only ReduceOp.SUM is supported. - Default: ReduceOp.SUM. - group: ProcessGroup or None. If None, uses all ranks in shmem context. - Default: None. - async_op: If False, performs a barrier at the end. If True, returns immediately. - Default: False. - config: Config instance with kernel parameters (default: None). - If None, uses default Config values. - Set config.all_reduce_variant to choose variant: "atomic", "ring", or "two_shot" - workspace: Optional workspace prepared by ``all_reduce_preamble`` to - reuse internal buffers across invocations. - - Example: - >>> shmem = iris.iris() - >>> shmem.ccl.all_reduce(output_tensor, input_tensor) - - >>> # Custom configuration with ring variant - >>> from iris.ccl import Config - >>> config = Config(all_reduce_variant="ring") - >>> shmem.ccl.all_reduce(output_tensor, input_tensor, config=config) - - >>> # Two-shot variant with block distribution - >>> config = Config(all_reduce_variant="two_shot", all_reduce_distribution=1) - >>> shmem.ccl.all_reduce(output_tensor, input_tensor, config=config) - - >>> # Async operation (no barrier) - >>> shmem.ccl.all_reduce(output_tensor, input_tensor, async_op=True) - """ - from iris.ccl.all_reduce import all_reduce as _all_reduce - from iris.ccl import ReduceOp - - # Default to SUM if not specified - if op is None: - op = ReduceOp.SUM - - return _all_reduce( - output_tensor, - input_tensor, - self._iris, - op=op, - group=group, - async_op=async_op, - config=config, - workspace=workspace, - ) - - def reduce_scatter(self, output_tensor, input_tensor, op=None, group=None, async_op=False, config=None): - """ - Reduce-scatter collective operation. - - Each rank reduces its assigned tiles from all ranks' inputs and stores - the result only to its own output tensor. This is similar to all-reduce - but without broadcasting the result to all ranks. - - Args: - output_tensor: Output tensor of shape (M, N) - will contain reduced tiles for this rank - input_tensor: Input tensor of shape (M, N) - local rank's partial data - op: Reduction operation to apply. Currently only ReduceOp.SUM is supported. - Default: ReduceOp.SUM. - group: ProcessGroup or None. If None, uses all ranks in shmem context. - Default: None. - async_op: If False, performs a barrier at the end. If True, returns immediately. - Default: False. - config: Config instance with kernel parameters (default: None). - If None, uses default Config values. - Only supports reduce_scatter_variant="two_shot". - - Example: - >>> shmem = iris.iris() - >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) - - >>> # Custom configuration - >>> from iris.ccl import Config - >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) - >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) - """ - from iris.ccl.reduce_scatter import reduce_scatter as _reduce_scatter - from iris.ccl import ReduceOp - - # Default to SUM if not specified - if op is None: - op = ReduceOp.SUM - - _reduce_scatter( - output_tensor, input_tensor, self._iris, op=op, group=group, async_op=async_op, config=config - ) - - -@triton.jit -def __translate(ptr, from_rank, to_rank, heap_bases): - from_base = tl.load(heap_bases + from_rank) - to_base = tl.load(heap_bases + to_rank) - # convert to int to compute difference - ptr_int = tl.cast(ptr, tl.uint64) - # Find the offset from from_rank heap - offset = ptr_int - from_base - # Byte cast for byte offset addition - to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) - # Find the offset into the to_rank heap - translated_ptr_byte = to_base_byte + offset - # Cast to_base back to pointer type - translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) - - # Optimization to vectorize the load/store - # We can't do this in general because we don't know the shape of the tensor or block sizes - # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) - - # 0 You can use this if your block sizes are multiples of 32. - # Largest vectorized load instruction is dwordx4 (128-bits) - translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) - translated_ptr = tl.max_contiguous(translated_ptr, (32, 32)) - - # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) - # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) - return translated_ptr - - -@triton.jit -def load(pointer, to_rank, from_rank, heap_bases, mask=None): - """ - Loads a value from the specified rank's memory location. - - This function performs a memory read operation by translating the pointer - from the `from_rank`'s address space to the `to_rank`'s address space and loading - data from the target memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local load operation. - - Args: - pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - to_rank (int): The rank ID to which the pointer will be translated. Must be the current rank where the pointer is local. - from_rank (int): The rank ID from which to read the data. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. - - Returns: - Block: The loaded value from the target memory location. - - Example: - >>> @triton.jit - >>> def kernel(ptr, heap_bases): - >>> # Load data from rank 1's memory into the current rank - >>> cur_rank = 0 # Current rank - >>> remote_rank = 1 # Remote rank to load from - >>> data = iris.load(ptr, cur_rank, remote_rank, heap_bases) - >>> return data - """ - translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) - result = tl.load(translated_ptr, mask=mask) - return result - - -@triton.jit -def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): - """ - Writes data to the specified rank's memory location. - - This function performs a memory write operation by translating the pointer - from the `from_rank`'s address space to the `to_rank`'s address space and storing - the provided data to the target memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local store operation. - - Args: - pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - value (Block): The tensor of elements to be stored. - from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. - to_rank (int): The rank ID to which the data will be written. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. - - Returns: - None - - Example: - >>> @triton.jit - >>> def kernel(ptr, heap_bases): - >>> # Store value 42 into rank 1's heap from rank 0 - >>> cur_rank = 0 # Current rank (source) - >>> remote_rank = 1 # Remote rank (destination) - >>> value = 42 - >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) - """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - tl.store(translated_ptr, value, mask=mask) - - -@triton.jit -def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): - """ - Copies data from the specified rank's memory into the destination rank's memory. - This function performs the transfer by translating `src_ptr` from the `from_rank`'s address - space to the `to_rank`'s address space, performing a masked load from the translated - source, and storing the loaded data to `dst_ptr` in the `to_rank` memory location. - If `from_rank` and `to_rank` are the same, this function performs a local copy operation. - It is undefined behaviour if neither `from_rank` nor `to_rank` is the `cur_rank`. - - Args: - src_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s local memory from which to read data. - dst_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `to_rank`'s local memory where the data will be written. - from_rank (int): The rank ID that owns `src_ptr` (source rank). - to_rank (int): The rank ID that will receive the data (destination rank). - cur_rank (int): The rank ID issuing the copy operation. Must be either `from_rank` or `to_rank`. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. - - Returns: - None - - Example: - >>> @triton.jit - >>> def kernel(remote_ptr, local_ptr, heap_bases): - >>> from_rank = 1 - >>> to_rank = 0 - >>> iris.copy(remote_ptr, local_ptr, from_rank, to_rank, to_rank, heap_bases) - """ - - cur_base = tl.load(heap_bases + cur_rank) - - from_base = tl.load(heap_bases + from_rank) - to_base = tl.load(heap_bases + to_rank) - - src_ptr_int = tl.cast(src_ptr, tl.uint64) - src_offset = src_ptr_int - cur_base - - dst_ptr_int = tl.cast(dst_ptr, tl.uint64) - dst_offset = dst_ptr_int - cur_base - - from_base_byte = tl.cast(from_base, tl.pointer_type(tl.int8)) - to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) - - translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) - translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) - - data = tl.load(translated_src, mask=mask) - tl.store(translated_dst, data, mask=mask) - - -@triton.jit -def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): - """ - Copies data from the specified rank's memory to the current rank's local memory. - - This function performs a memory read operation by translating the `from_ptr` - from the current rank's address space to the `from_rank`'s address space, loading data - from the `from_rank` memory location, and storing it to the local `to_ptr`. - If the `from_rank` is the same as the current rank, this function performs a local copy operation. - - Args: - from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `from_rank`'s address space. Must be the current rank where the pointer is local. - to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory where the data will be stored. - from_rank (int): The `from_rank` ID from which to read the data. - to_rank (int): The current rank ID where the data will be stored. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. - - Returns: - None - - Example: - >>> @triton.jit - >>> def kernel(remote_ptr, local_ptr, heap_bases): - >>> from_rank = 1 - >>> to_rank = 0 - >>> iris.get(remote_ptr, local_ptr, from_rank, to_rank, heap_bases) - """ - translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases) - - data = tl.load(translated_from_ptr, mask=mask) - - tl.store(to_ptr, data, mask=mask) - - -@triton.jit -def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): - """ - Copies data from the current rank's local memory to the specified rank's memory. - This function performs a memory write operation by loading data from the current - rank's `from_ptr`, translating the `to_ptr` from the current rank's address - space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. - If the `to_rank` is the same as the current rank, this function performs a local copy operation. - - Args: - from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. - to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - from_rank (int): The current rank ID from which to read the data. - to_rank (int): The `to_rank` ID to which the data will be written. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. - - Returns: - None - - Example: - >>> @triton.jit - >>> def kernel(local_ptr, remote_ptr, heap_bases): - >>> from_rank = 0 - >>> to_rank = 1 - >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) - """ - translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) - - data = tl.load(from_ptr, mask=mask) - - tl.store(translated_to_ptr, data, mask=mask) - - -@triton.jit -def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): - """ - Performs an atomic add at the specified rank's memory location. - - This function performs an atomic addition operation by translating the pointer - from the `from_rank`'s address space to the `to_rank`'s address space and atomically - adding the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local atomic addition operation. - - Args: - pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. - from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. - to_rank (int): The rank ID to which the atomic operation will be performed. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. - sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. - scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". - - Returns: - Block: The data stored at pointer before the atomic operation. - - Example: - >>> @triton.jit - >>> def kernel(ptr, heap_bases): - >>> # Atomically add 5 to rank 1's memory from rank 0 - >>> cur_rank = 0 # Current rank (source) - >>> remote_rank = 1 # Remote rank (destination) - >>> increment = 5 - >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) - """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) - - -@triton.jit -def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): - """ - Atomically subtracts data from the specified rank's memory location. - - This function performs an atomic subtraction operation by translating the pointer - from the `from_rank`'s address space to the `to_rank`'s address space and atomically - subtracting the provided data from the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local atomic subtraction operation. - - Args: - pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - val (Block): The tensor of elements to be subtracted atomically. - from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. - to_rank (int): The rank ID to which the atomic operation will be performed. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. - sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". - scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". - - Returns: - Block: The value at the memory location before the atomic subtraction. - - Example: - >>> @triton.jit - >>> def kernel(ptr, heap_bases): - >>> # Atomically subtract 3 from rank 2's memory from rank 0 - >>> cur_rank = 0 # Current rank (source) - >>> remote_rank = 2 # Remote rank (destination) - >>> decrement = 3 - >>> old_val = iris.atomic_sub(ptr, decrement, cur_rank, remote_rank, heap_bases) - """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) - - -@triton.jit -def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None): - """ - Atomically compares and exchanges the specified rank's memory location. - - This function performs an atomic compare-and-swap operation by translating the pointer - from the `from_rank`'s address space to the `to_rank`'s address space and atomically - comparing the current value with the expected value, then writing the new value if they match. - If the `from_rank` and `to_rank` are the same, this function performs a local atomic compare-and-swap operation. - - Args: - pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - cmp (Block): The expected value to be compared with the current value at the memory location. - val (Block): The new value to be written if the compare succeeds. - from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. - to_rank (int): The rank ID to which the atomic operation will be performed. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". - scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". - - Returns: - Block: The value contained at the memory location before the atomic operation attempt. - - Example: - >>> @triton.jit - >>> def kernel(ptr, heap_bases): - >>> # Compare-and-swap on rank 1's memory from rank 0 - >>> cur_rank = 0 # Current rank (source) - >>> remote_rank = 1 # Remote rank (destination) - >>> expected = 0 - >>> new_val = 42 - >>> old_val = iris.atomic_cas(ptr, expected, new_val, cur_rank, remote_rank, heap_bases) - """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) - - -@triton.jit -def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): - """ - Performs an atomic exchange at the specified rank's memory location. - - This function performs an atomic exchange operation by translating the pointer - from the `from_rank`'s address space to the `to_rank`'s address space and atomically - exchanging the current value with the provided new value. If the `from_rank` and `to_rank` are the same, - this function performs a local atomic exchange operation. - - Args: - pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. - from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. - to_rank (int): The rank ID to which the atomic operation will be performed. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. - sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. - scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". - - Returns: - Block: The data stored at pointer before the atomic operation. - - Example: - >>> @triton.jit - >>> def kernel(ptr, heap_bases): - >>> # Exchange value with rank 1's memory from rank 0 - >>> cur_rank = 0 # Current rank (source) - >>> remote_rank = 1 # Remote rank (destination) - >>> new_value = 99 - >>> old_val = iris.atomic_xchg(ptr, new_value, cur_rank, remote_rank, heap_bases) - """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) - - -@triton.jit -def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): - """ - Performs an atomic xor at the specified rank's memory location. - - This function performs an atomic xor operation by translating the pointer - from the `from_rank`'s address space to the `to_rank`'s address space and atomically - xoring the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local atomic xor operation. - - Args: - pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. - from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. - to_rank (int): The rank ID to which the atomic operation will be performed. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. - sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. - scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". - - Returns: - Block: The data stored at pointer before the atomic operation. - - Example: - >>> @triton.jit - >>> def kernel(ptr, heap_bases): - >>> # Atomically XOR with rank 1's memory from rank 0 - >>> cur_rank = 0 # Current rank (source) - >>> remote_rank = 1 # Remote rank (destination) - >>> mask_val = 0xFF - >>> old_val = iris.atomic_xor(ptr, mask_val, cur_rank, remote_rank, heap_bases) - """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) - - -@triton.jit -def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): - """ - Performs an atomic and at the specified rank's memory location. - - This function performs an atomic and operation by translating the pointer - from the `from_rank`'s address space to the `to_rank`'s address space and atomically - anding the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local atomic and operation. - - Args: - pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. - from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. - to_rank (int): The rank ID to which the atomic operation will be performed. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. - sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. - scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". - - Returns: - Block: The data stored at pointer before the atomic operation. - - Example: - >>> @triton.jit - >>> def kernel(ptr, heap_bases): - >>> # Atomically AND with rank 1's memory from rank 0 - >>> cur_rank = 0 # Current rank (source) - >>> remote_rank = 1 # Remote rank (destination) - >>> mask_val = 0x0F - >>> old_val = iris.atomic_and(ptr, mask_val, cur_rank, remote_rank, heap_bases) - """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) - - -@triton.jit -def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): - """ - Performs an atomic or at the specified rank's memory location. - - This function performs an atomic or operation by translating the pointer - from the `from_rank`'s address space to the `to_rank`'s address space and atomically - oring the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local atomic or operation. - - Args: - pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. - from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. - to_rank (int): The rank ID to which the atomic operation will be performed. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. - sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. - scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". - - Returns: - Block: The data stored at pointer before the atomic operation. - - Example: - >>> @triton.jit - >>> def kernel(ptr, heap_bases): - >>> # Atomically OR with rank 1's memory from rank 0 - >>> cur_rank = 0 # Current rank (source) - >>> remote_rank = 1 # Remote rank (destination) - >>> mask_val = 0xF0 - >>> old_val = iris.atomic_or(ptr, mask_val, cur_rank, remote_rank, heap_bases) - """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) - - -@triton.jit -def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): - """ - Performs an atomic min at the specified rank's memory location. - - This function performs an atomic min operation by translating the pointer - from the `from_rank`'s address space to the `to_rank`'s address space and atomically - performing the min on the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local atomic min operation. - - Args: - pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. - from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. - to_rank (int): The rank ID to which the atomic operation will be performed. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. - sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. - scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". - - Returns: - Block: The data stored at pointer before the atomic operation. - - Example: - >>> @triton.jit - >>> def kernel(ptr, heap_bases): - >>> # Atomically find minimum with rank 1's memory from rank 0 - >>> cur_rank = 0 # Current rank (source) - >>> remote_rank = 1 # Remote rank (destination) - >>> new_val = 10 - >>> old_val = iris.atomic_min(ptr, new_val, cur_rank, remote_rank, heap_bases) - """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) - - -@triton.jit -def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): - """ - Performs an atomic max at the specified rank's memory location. - - This function performs an atomic max operation by translating the pointer - from the `from_rank`'s address space to the `to_rank`'s address space and atomically - performing the max on the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local atomic max operation. - - Args: - pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. - from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. - to_rank (int): The rank ID to which the atomic operation will be performed. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. - sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. - scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". - - Returns: - Block: The data stored at pointer before the atomic operation. - - Example: - >>> @triton.jit - >>> def kernel(ptr, heap_bases): - >>> # Atomically find maximum with rank 1's memory from rank 0 - >>> cur_rank = 0 # Current rank (source) - >>> remote_rank = 1 # Remote rank (destination) - >>> new_val = 100 - >>> old_val = iris.atomic_max(ptr, new_val, cur_rank, remote_rank, heap_bases) - """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) - - -def iris(heap_size=1 << 30): - """ - Create and return an Iris instance with the specified heap size. - - Args: - heap_size (int): Size of the heap in bytes. Defaults to 1GB. - - Returns: - Iris: An initialized Iris instance. - - Example: - >>> import iris - >>> iris_ctx = iris.iris(2**30) # 1GB heap - >>> tensor = iris_ctx.zeros(1024, 1024) - """ - return Iris(heap_size) diff --git a/iris/ops/__init__.py b/iris/ops/__init__.py index a6ed4a659..c96fa32e5 100644 --- a/iris/ops/__init__.py +++ b/iris/ops/__init__.py @@ -173,7 +173,6 @@ def matmul_reduce_scatter(self, output_tensor, A, B, async_op=False, config=None # Namespace "OpsNamespace", # Operations - "matmul", # Simple single-GPU GEMM "matmul_all_reduce", "matmul_all_reduce_preamble", "all_gather_matmul", diff --git a/iris/ops/all_gather_matmul.py.with_chunked b/iris/ops/all_gather_matmul.py.with_chunked deleted file mode 100644 index ddc03d027..000000000 --- a/iris/ops/all_gather_matmul.py.with_chunked +++ /dev/null @@ -1,521 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. - -""" -Fused All-Gather + GEMM operation using pull pattern. - -Each rank has a column-sharded input A_sharded (M x K_local). -This operation computes C = all_gather(A_sharded) @ B by pulling -tiles from remote ranks on-demand during GEMM computation. -""" - -from typing import Optional -import torch -import triton -import triton.language as tl -import iris -import iris.x - -from tritonblas.kernels.stages.algorithms.binary import add_vector -from tritonblas.kernels.stages.algorithms.unary import convert_dtype - -from .config import FusedConfig -from .workspace import FusedWorkspace - - -@triton.jit() -def _fused_all_gather_matmul_kernel( - A_sharded, - B, - C, - bias_ptr, - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - K_local: tl.constexpr, - stride_am: tl.constexpr, - stride_ak: tl.constexpr, - stride_bk: tl.constexpr, - stride_bn: tl.constexpr, - stride_cm: tl.constexpr, - stride_cn: tl.constexpr, - stride_bias: tl.constexpr, - heap_bases: tl.tensor, - cur_rank: tl.constexpr, - world_size: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - NUM_SMS: tl.constexpr, - NUM_XCDS: tl.constexpr, - BIAS: tl.constexpr, - EVEN_K: tl.constexpr, - ALLOW_TF32: tl.constexpr, -): - """Fused all-gather + GEMM kernel using pull pattern.""" - pid = tl.program_id(0) - - # Handle multi-XCD devices - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) - - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - total_tiles = num_pid_m * num_pid_n - - tl.assume(stride_am > 0) - tl.assume(stride_ak > 0) - tl.assume(stride_bk > 0) - tl.assume(stride_bn > 0) - tl.assume(stride_cm > 0) - tl.assume(stride_cn > 0) - - acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 - - # Persistent loop over output tiles - for tile_id in range(pid, total_tiles, NUM_SMS): - # Compute tile coordinates with swizzling - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - # Compute row and column indices - rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) - rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - - # Initialize accumulator - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - - # Create DeviceContext and TensorView for gather operations - ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases) - src_view = iris.x.TensorView(A_sharded, M, K_local, stride_am, stride_ak) - - # Loop over all ranks to pull and accumulate - for source_rank_id in range(world_size): - loop_k_local = tl.cdiv(K_local, BLOCK_SIZE_K) - if not EVEN_K: - loop_k_local -= 1 - - # Loop over K dimension for this rank's shard - for k_block_idx in range(0, loop_k_local): - k_offset = k_block_idx * BLOCK_SIZE_K - - # Create tile view for this K block - tile_k = k_offset // BLOCK_SIZE_K - k_tile = iris.x.TileView(pid_m, tile_k, BLOCK_SIZE_M, BLOCK_SIZE_K) - - # Pull A tile from source_rank_id using gather primitive - a = iris.x.gather(k_tile, src_view, source_rank_id, ctx) - - # Load B tile - rk_local = k_offset + tl.arange(0, BLOCK_SIZE_K) - rk_global = (source_rank_id * K_local) + rk_local - B_ptr = B + rk_global[:, None] * stride_bk + rn[None, :] * stride_bn - b = tl.load(tl.multiple_of(B_ptr, (16, 1))) - - # Accumulate - if ALLOW_TF32: - acc = tl.dot(a, b, acc, allow_tf32=True) - else: - acc += tl.dot(a, b, allow_tf32=False) - - # Handle remaining K elements if not evenly divisible - if not EVEN_K: - k_offset = loop_k_local * BLOCK_SIZE_K - tile_k = k_offset // BLOCK_SIZE_K - k_tile = iris.x.TileView(pid_m, tile_k, BLOCK_SIZE_M, BLOCK_SIZE_K) - - # Pull A tile from source_rank_id using gather primitive - a = iris.x.gather(k_tile, src_view, source_rank_id, ctx) - - rk_local = k_offset + tl.arange(0, BLOCK_SIZE_K) - rk_global = (source_rank_id * K_local) + rk_local - rk_global_mask = rk_global < K - B_ptr = B + rk_global[:, None] * stride_bk + rn[None, :] * stride_bn - b = tl.load(tl.multiple_of(B_ptr, (16, 1)), mask=rk_global_mask[:, None], other=0.0) - - if ALLOW_TF32: - acc = tl.dot(a, b, acc, allow_tf32=True) - else: - acc += tl.dot(a, b, allow_tf32=False) - - # Add bias if provided using tritonBLAS - if BIAS: - bias_vector = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0) - acc = add_vector(acc, bias_vector, QUANTIZED=False) - - # Convert to output dtype using tritonBLAS - c = convert_dtype(acc, C.type.element_ty) - - # Store result (manual for now, tritonBLAS store has issues with our indices) - C_ptr = ( - C - + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] * stride_cm - + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_cn - ) - mask = ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] < M) & ( - (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] < N - ) - tl.store(C_ptr, c, mask=mask) - - -@triton.jit() -def _fused_chunked_all_gather_matmul_kernel( - A_sharded, - B, - C, - bias_ptr, - temp_buffer, # Temporary buffer: BLOCK_M x K x num_tiles - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - K_local: tl.constexpr, - stride_am: tl.constexpr, - stride_ak: tl.constexpr, - stride_bk: tl.constexpr, - stride_bn: tl.constexpr, - stride_cm: tl.constexpr, - stride_cn: tl.constexpr, - stride_bias: tl.constexpr, - heap_bases: tl.tensor, - cur_rank: tl.constexpr, - world_size: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - NUM_SMS: tl.constexpr, - NUM_XCDS: tl.constexpr, - BIAS: tl.constexpr, - EVEN_K: tl.constexpr, - ALLOW_TF32: tl.constexpr, -): - """ - Fused all-gather + GEMM kernel using chunked/buffered pattern. - - This variant pre-gathers all of A into a temporary buffer before computing GEMM. - Eliminates the world_size loop by using iris.x.all_gather upfront. - - Memory layout: - - temp_buffer: BLOCK_M x K x num_tiles (stores gathered A for each tile) - - Each program gathers its M-tile of A, then does GEMM - """ - pid = tl.program_id(0) - - # Handle multi-XCD devices - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) - - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - total_tiles = num_pid_m * num_pid_n - - tl.assume(stride_am > 0) - tl.assume(stride_ak > 0) - tl.assume(stride_bk > 0) - tl.assume(stride_bn > 0) - tl.assume(stride_cm > 0) - tl.assume(stride_cn > 0) - - acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 - - # Persistent loop over output tiles - for tile_id in range(pid, total_tiles, NUM_SMS): - # Compute tile coordinates with swizzling - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - # Compute row and column indices - rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) - rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - - # Buffer pointer for this tile: BLOCK_M x K for this pid_m - buffer_ptr = temp_buffer + tile_id * BLOCK_SIZE_M * K - - # Step 1: Pre-gather entire M-tile of A (BLOCK_M x K) - # Create DeviceContext and TensorView for gather operations - ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases) - src_view = iris.x.TensorView(A_sharded, M, K_local, stride_am, stride_ak) - - # Gather K-tiles from all ranks - for source_rank_id in range(world_size): - k_start = source_rank_id * K_local - # Loop over K dimension in blocks - for k_local_idx in range(0, K_local, BLOCK_SIZE_K): - k_global = k_start + k_local_idx - rk = k_global + tl.arange(0, BLOCK_SIZE_K) - rk_mask = rk < K - - tile_k = k_local_idx // BLOCK_SIZE_K - k_tile = iris.x.TileView(pid_m, tile_k, BLOCK_SIZE_M, BLOCK_SIZE_K) - - # Pull A tile from source_rank_id - a = iris.x.gather(k_tile, src_view, source_rank_id, ctx) - - # Store in buffer - buffer_A_ptr = buffer_ptr + rm[:, None] * K + rk[None, :] - tl.store(buffer_A_ptr, a, mask=rk_mask[None, :]) - - # Step 2: Standard GEMM from buffer - # Initialize accumulator - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - - # Loop over K dimension - loop_k = tl.cdiv(K, BLOCK_SIZE_K) - if EVEN_K: - for k_block_idx in range(loop_k): - k_offset = k_block_idx * BLOCK_SIZE_K - - # Load A from temp buffer - rk = k_offset + tl.arange(0, BLOCK_SIZE_K) - buffer_A_ptr = buffer_ptr + rm[:, None] * K + rk[None, :] - a = tl.load(buffer_A_ptr) - - # Load B tile - B_ptr = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - b = tl.load(tl.multiple_of(B_ptr, (16, 1))) - - # Accumulate - if ALLOW_TF32: - acc = tl.dot(a, b, acc, allow_tf32=True) - else: - acc += tl.dot(a, b, allow_tf32=False) - else: - # Handle case where K is not evenly divisible by BLOCK_SIZE_K - for k_block_idx in range(loop_k): - k_offset = k_block_idx * BLOCK_SIZE_K - - # Load A from temp buffer - rk = k_offset + tl.arange(0, BLOCK_SIZE_K) - rk_mask = rk < K - buffer_A_ptr = buffer_ptr + rm[:, None] * K + rk[None, :] - a = tl.load(buffer_A_ptr, mask=rk_mask[None, :], other=0.0) - - # Load B tile - B_ptr = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - b = tl.load(tl.multiple_of(B_ptr, (16, 1)), mask=rk_mask[:, None], other=0.0) - - if ALLOW_TF32: - acc = tl.dot(a, b, acc, allow_tf32=True) - else: - acc += tl.dot(a, b, allow_tf32=False) - - # Convert accumulator and add bias - c = convert_dtype(acc, C.type.element_ty) - if BIAS: - bias_offset = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) * stride_bias - bias_val = tl.load(bias_ptr + bias_offset) - c = add_vector(c, bias_val, 0) - - # Store result - C_ptr = ( - C - + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] * stride_cm - + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_cn - ) - mask = ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] < M) & ( - (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] < N - ) - tl.store(C_ptr, c, mask=mask) - - -def all_gather_matmul_preamble( - shmem, - A_sharded: torch.Tensor, - B: torch.Tensor, - config: Optional[FusedConfig] = None, -) -> FusedWorkspace: - """Allocate workspace for all_gather_matmul (buffer needed for chunked variant).""" - if config is None: - config = FusedConfig() - - M, K_local = A_sharded.shape - K, N = B.shape - world_size = shmem.get_num_ranks() - - expected_K = world_size * K_local - assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" - - # Detect hardware configuration - device = A_sharded.device - if config.num_sms is None: - import iris.hip - num_sms = iris.hip.get_cu_count(device.index) - else: - num_sms = config.num_sms - - if config.num_xcds == 1: - # Auto-detect XCDs if default value is used - import iris.hip - num_xcds = iris.hip.get_num_xcc(device.index) - else: - num_xcds = config.num_xcds - - # Allocate temporary buffer for chunked variant - aux_buffer = None - if config.all_gather_matmul_variant == "chunked": - # Calculate grid size to determine buffer size - num_tiles_m = (M + config.block_size_m - 1) // config.block_size_m - num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n - num_tiles = num_tiles_m * num_tiles_n - - # Allocate buffer: BLOCK_M x K x num_tiles - buffer_size = config.block_size_m * K * num_tiles - aux_buffer = torch.empty(buffer_size, dtype=A_sharded.dtype, device=device) - - return FusedWorkspace( - operation="all_gather_matmul", - shape=(M, N, K), - dtype=A_sharded.dtype, - world_size=world_size, - num_sms=num_sms, - num_xcds=num_xcds, - variant=config.all_gather_matmul_variant, - aux_buffer=aux_buffer, - prepared=True, - ) - - -def all_gather_matmul( - shmem, - output_tensor: torch.Tensor, - A_sharded: torch.Tensor, - B: torch.Tensor, - bias: Optional[torch.Tensor] = None, - async_op: bool = False, - config: Optional[FusedConfig] = None, - workspace: Optional[FusedWorkspace] = None, -) -> FusedWorkspace: - """Fused all-gather and matrix multiplication using pull pattern.""" - if config is None: - config = FusedConfig() - - M, K_local = A_sharded.shape - K, N = B.shape - world_size = shmem.get_num_ranks() - rank = shmem.get_rank() - - expected_K = world_size * K_local - assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" - assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" - - # Validate problem size against block sizes - assert M >= config.block_size_m, ( - f"M ({M}) must be >= block_size_m ({config.block_size_m}). Use smaller block sizes for small problems." - ) - assert K_local >= config.block_size_k, ( - f"K_local ({K_local}) must be >= block_size_k ({config.block_size_k}). " - f"Use smaller block sizes for small problems." - ) - assert N >= config.block_size_n, ( - f"N ({N}) must be >= block_size_n ({config.block_size_n}). Use smaller block sizes for small problems." - ) - - if workspace is None: - workspace = all_gather_matmul_preamble(shmem, A_sharded, B, config) - - stride_am, stride_ak = A_sharded.stride() - stride_bk, stride_bn = B.stride() - stride_cm, stride_cn = output_tensor.stride() - - if bias is not None: - assert bias.shape[0] == M - bias_ptr = bias - stride_bias = bias.stride()[0] if bias.dim() > 0 else 1 - use_bias = True - else: - bias_ptr = output_tensor - stride_bias = 1 - use_bias = False - - # Get hardware configuration from workspace - num_sms = workspace.num_sms - num_xcds = workspace.num_xcds - - even_k = K_local % config.block_size_k == 0 - - # Use SM-based grid (persistent kernels) - grid = (num_sms,) - - # Select kernel variant based on config - if config.all_gather_matmul_variant == "chunked": - # Chunked variant: pre-gather into buffer, then GEMM - assert workspace.aux_buffer is not None, "Chunked variant requires aux_buffer in workspace" - _fused_chunked_all_gather_matmul_kernel[grid]( - A_sharded, - B, - output_tensor, - bias_ptr, - workspace.aux_buffer, # Temporary buffer - M, - N, - K, - K_local, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_bias, - shmem.heap_bases, - rank, - world_size, - config.block_size_m, - config.block_size_n, - config.block_size_k, - config.group_size_m, - num_sms, - num_xcds, - use_bias, - even_k, - config.allow_tf32, - ) - else: - # Pull variant (default): on-demand pull from remote ranks - _fused_all_gather_matmul_kernel[grid]( - A_sharded, - B, - output_tensor, - bias_ptr, - M, - N, - K, - K_local, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_bias, - shmem.heap_bases, - rank, - world_size, - config.block_size_m, - config.block_size_n, - config.block_size_k, - config.group_size_m, - num_sms, - num_xcds, - use_bias, - even_k, - config.allow_tf32, - ) - - if not async_op: - shmem.barrier() - - return workspace diff --git a/iris/ops/config.py b/iris/ops/config.py index a92925035..c5d15349b 100644 --- a/iris/ops/config.py +++ b/iris/ops/config.py @@ -35,8 +35,7 @@ class FusedConfig: "one_shot", "two_shot", "spinlock". Default: "two_shot". all_reduce_num_rings: Number of concurrent rings (for ring variant). Default: 1. all_gather_matmul_variant: All-gather + matmul algorithm variant. Options: - "pull" (on-demand pull from remote ranks), - "chunked" (pre-gather into buffer then GEMM). + "pull" (on-demand pull from remote ranks). Default: "pull". Example: diff --git a/tests/ops/test_all_gather_matmul.py b/tests/ops/test_all_gather_matmul.py index 7dceea126..db4b21250 100644 --- a/tests/ops/test_all_gather_matmul.py +++ b/tests/ops/test_all_gather_matmul.py @@ -32,7 +32,6 @@ "variant", [ "pull", - "chunked", ], ) def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N, variant): diff --git a/tests/ops/test_matmul_all_reduce.py b/tests/ops/test_matmul_all_reduce.py index 5780b5d4d..0fd278fe0 100644 --- a/tests/ops/test_matmul_all_reduce.py +++ b/tests/ops/test_matmul_all_reduce.py @@ -112,7 +112,7 @@ def test_matmul_all_reduce_via_shmem_ops(): shmem = iris.iris(heap_size) rank = shmem.get_rank() - M, N, K = 256, 128, 64 + M, N, K = 256, 256, 64 dtype = torch.float16 A = shmem.randn((M, K), dtype=dtype) From 477b47220da79207529c40369592d093af59ccb2 Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Fri, 6 Mar 2026 19:03:44 +0000 Subject: [PATCH 27/31] Fix CI: increase default N to match FusedConfig block_size_n=256 Examples 28 (matmul_all_reduce) and 30 (matmul_all_gather) used N=128 as default, which is smaller than the new FusedConfig default block_size_n=256. This triggers assertion failures (N >= block_size_n) in CI, crashing all ranks and causing the 8-rank test to hang for 179 minutes waiting for the dead rank. Increase both examples' default N from 128 to 256 to match the new config defaults. Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/28_ops_matmul_all_reduce/example.py | 2 +- examples/30_ops_matmul_all_gather/example.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/28_ops_matmul_all_reduce/example.py b/examples/28_ops_matmul_all_reduce/example.py index acaaff85d..086ef4d70 100644 --- a/examples/28_ops_matmul_all_reduce/example.py +++ b/examples/28_ops_matmul_all_reduce/example.py @@ -26,7 +26,7 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("-m", type=int, default=512, help="Rows of A") - parser.add_argument("-n", type=int, default=128, help="Columns of B") + parser.add_argument("-n", type=int, default=256, help="Columns of B") parser.add_argument("-k", type=int, default=256, help="Inner dimension") parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") diff --git a/examples/30_ops_matmul_all_gather/example.py b/examples/30_ops_matmul_all_gather/example.py index fbb12442e..e704246cc 100644 --- a/examples/30_ops_matmul_all_gather/example.py +++ b/examples/30_ops_matmul_all_gather/example.py @@ -27,7 +27,7 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("-m", type=int, default=4096, help="Total rows (must be divisible by world_size)") - parser.add_argument("-n", type=int, default=128, help="Columns of B") + parser.add_argument("-n", type=int, default=256, help="Columns of B") parser.add_argument("-k", type=int, default=256, help="Inner dimension") parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") From 76cc30d256f7cb6e4e9fe06314959e4131b72cfc Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Fri, 6 Mar 2026 19:15:24 +0000 Subject: [PATCH 28/31] Revert "Fix CI: increase default N to match FusedConfig block_size_n=256" This reverts commit 477b47220da79207529c40369592d093af59ccb2. --- examples/28_ops_matmul_all_reduce/example.py | 2 +- examples/30_ops_matmul_all_gather/example.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/28_ops_matmul_all_reduce/example.py b/examples/28_ops_matmul_all_reduce/example.py index 086ef4d70..acaaff85d 100644 --- a/examples/28_ops_matmul_all_reduce/example.py +++ b/examples/28_ops_matmul_all_reduce/example.py @@ -26,7 +26,7 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("-m", type=int, default=512, help="Rows of A") - parser.add_argument("-n", type=int, default=256, help="Columns of B") + parser.add_argument("-n", type=int, default=128, help="Columns of B") parser.add_argument("-k", type=int, default=256, help="Inner dimension") parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") diff --git a/examples/30_ops_matmul_all_gather/example.py b/examples/30_ops_matmul_all_gather/example.py index e704246cc..fbb12442e 100644 --- a/examples/30_ops_matmul_all_gather/example.py +++ b/examples/30_ops_matmul_all_gather/example.py @@ -27,7 +27,7 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("-m", type=int, default=4096, help="Total rows (must be divisible by world_size)") - parser.add_argument("-n", type=int, default=256, help="Columns of B") + parser.add_argument("-n", type=int, default=128, help="Columns of B") parser.add_argument("-k", type=int, default=256, help="Inner dimension") parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") From 9743b13e56dd724e942fbe2b96a9928b275b2088 Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Fri, 6 Mar 2026 19:18:55 +0000 Subject: [PATCH 29/31] =?UTF-8?q?Remove=20unnecessary=20block=20size=20ass?= =?UTF-8?q?ertions=20=E2=80=94=20Triton=20handles=20masking?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Triton kernels already handle block_size > dimension via: - tl.cdiv(N, BLOCK_SIZE_N) for grid sizing - mask=(rn < N) on loads/stores - tritonblas GemmContext.reduce_axis handles K masking The assertions were preventing valid configurations (e.g., block_size_n=256 with N=128) that the kernels handle correctly. Removed for_problem() clamping too — it's unnecessary when the kernels already mask. Fixes CI failures on examples 28 and 30 which use N=128 with default FusedConfig block_size_n=256. --- iris/ops/all_gather_matmul.py | 11 ----------- iris/ops/matmul_all_gather.py | 11 ----------- iris/ops/matmul_all_reduce.py | 5 ----- 3 files changed, 27 deletions(-) diff --git a/iris/ops/all_gather_matmul.py b/iris/ops/all_gather_matmul.py index 6000f50ef..4f272825f 100644 --- a/iris/ops/all_gather_matmul.py +++ b/iris/ops/all_gather_matmul.py @@ -219,17 +219,6 @@ def all_gather_matmul( assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" # Validate problem size against block sizes - assert M >= config.block_size_m, ( - f"M ({M}) must be >= block_size_m ({config.block_size_m}). Use smaller block sizes for small problems." - ) - assert K_local >= config.block_size_k, ( - f"K_local ({K_local}) must be >= block_size_k ({config.block_size_k}). " - f"Use smaller block sizes for small problems." - ) - assert N >= config.block_size_n, ( - f"N ({N}) must be >= block_size_n ({config.block_size_n}). Use smaller block sizes for small problems." - ) - if workspace is None: workspace = all_gather_matmul_preamble(shmem, A_sharded, B, config) diff --git a/iris/ops/matmul_all_gather.py b/iris/ops/matmul_all_gather.py index ad42ac041..6b19caea4 100644 --- a/iris/ops/matmul_all_gather.py +++ b/iris/ops/matmul_all_gather.py @@ -180,17 +180,6 @@ def matmul_all_gather( assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" # Validate problem size against block sizes - assert M_local >= config.block_size_m, ( - f"M_local ({M_local}) must be >= block_size_m ({config.block_size_m}). " - f"Use smaller block sizes for small problems." - ) - assert K >= config.block_size_k, ( - f"K ({K}) must be >= block_size_k ({config.block_size_k}). Use smaller block sizes for small problems." - ) - assert N >= config.block_size_n, ( - f"N ({N}) must be >= block_size_n ({config.block_size_n}). Use smaller block sizes for small problems." - ) - # Allocate workspace if not provided if workspace is None: workspace = matmul_all_gather_preamble(shmem, A, B, config) diff --git a/iris/ops/matmul_all_reduce.py b/iris/ops/matmul_all_reduce.py index 73bea92c2..ceded7057 100644 --- a/iris/ops/matmul_all_reduce.py +++ b/iris/ops/matmul_all_reduce.py @@ -272,11 +272,6 @@ def matmul_all_reduce( if A.dtype != B.dtype or A.dtype != C.dtype: raise ValueError(f"All tensors must have same dtype, got A:{A.dtype}, B:{B.dtype}, C:{C.dtype}") - # Validate block sizes match problem dimensions - assert M >= config.block_size_m, f"M={M} too small for block_size_m={config.block_size_m}" - assert K >= config.block_size_k, f"K={K} too small for block_size_k={config.block_size_k}" - assert N >= config.block_size_n, f"N={N} too small for block_size_n={config.block_size_n}" - # Extract strides stride_am, stride_ak = A.stride() stride_bk, stride_bn = B.stride() From a86dc0400de2fff3e8b4d27cd637b2c41fc5dffc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 11 Mar 2026 16:11:23 +0000 Subject: [PATCH 30/31] Initial plan From 445b25cb095078a51c3db9fb5cbf5c5b6eec80ee Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Mar 2026 19:48:25 +0000 Subject: [PATCH 31/31] Add vectorization hints and tests for HBM buffer all-gather matmul Co-authored-by: ryanswann-amd <109695074+ryanswann-amd@users.noreply.github.com> --- iris/ops/all_gather_matmul_hbm_buffer.py | 4 +- iris/x/gather.py | 5 + .../ops/test_all_gather_matmul_hbm_buffer.py | 202 ++++++++++++++++++ 3 files changed, 210 insertions(+), 1 deletion(-) create mode 100644 tests/ops/test_all_gather_matmul_hbm_buffer.py diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index abe3b3936..2db1b6ed7 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -126,6 +126,7 @@ def _hbm_buffer_all_gather_matmul_kernel( k_block_start = k_flag_group * K_PER_FLAG rm = m_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) for k_off in range(K_PER_FLAG): k_block_global = k_block_start + k_off @@ -138,11 +139,12 @@ def _hbm_buffer_all_gather_matmul_kernel( k_tile = iris.x.TileView(pid_m_t, tile_k_t, BLOCK_SIZE_M, BLOCK_SIZE_K) rk = k_block_global * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) staged_ptrs = staged_a + rm.to(tl.int64)[:, None] * stride_sa_m + rk[None, :] * stride_sa_k for compile_rank in range(world_size): if src_rank_idx == compile_rank: - a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx) + a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx, hint=(1, BLOCK_SIZE_K)) tl.store(staged_ptrs, a_tile, cache_modifier=".cg") flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group diff --git a/iris/x/gather.py b/iris/x/gather.py index ca8bd4f9c..4e2b10cc9 100644 --- a/iris/x/gather.py +++ b/iris/x/gather.py @@ -24,6 +24,7 @@ def gather( src_view: TensorView, source_rank: tl.constexpr, ctx: DeviceContext, + hint: tl.constexpr = None, ): """ Tile-level gather from a specific rank. @@ -37,6 +38,9 @@ def gather( src_view: TensorView for source tensor on source_rank. source_rank: Specific rank to load from (constexpr). ctx: DeviceContext with rank, world_size, and heap_bases. + hint: Vectorization hint passed to tl.multiple_of / tl.max_contiguous on + the translated pointer. Use a scalar (e.g. 16) or a tuple + (e.g. (1, 16)) to indicate alignment. Defaults to None (no hint). Returns: Loaded tile data as a tensor. @@ -61,6 +65,7 @@ def gather( source_rank, # from_rank (source rank) ctx.heap_bases, mask=mask, + hint=hint, ) return tile_data diff --git a/tests/ops/test_all_gather_matmul_hbm_buffer.py b/tests/ops/test_all_gather_matmul_hbm_buffer.py new file mode 100644 index 000000000..af173ea8b --- /dev/null +++ b/tests/ops/test_all_gather_matmul_hbm_buffer.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tests for fused all_gather + matmul using the HBM staging buffer implementation. + +Each rank has A_sharded (M x K_local), B is replicated. +The operation gathers A from all ranks into a local HBM buffer and computes C = A_gathered @ B. +""" + +import pytest +import torch +import torch.distributed as dist + +import iris +from iris.ops.all_gather_matmul_hbm_buffer import ( + all_gather_matmul_hbm_buffer, + all_gather_matmul_hbm_buffer_preamble, +) +from iris.ops.config import FusedConfig + + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float16, 1e-2, 1e-2), + ], +) +@pytest.mark.parametrize( + "M,K_local,N", + [ + (128, 32, 64), + (256, 64, 128), + ], +) +@pytest.mark.parametrize( + "staged_a_layout", + [ + "k_contiguous", + "m_contiguous", + ], +) +def test_all_gather_matmul_hbm_buffer(dtype, atol, rtol, M, K_local, N, staged_a_layout): + """Test all_gather_matmul_hbm_buffer against torch all_gather + matmul.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + K = K_local * world_size # Full K dimension + + # Seed for reproducibility - different seed per rank for A_sharded + torch.manual_seed(42 + rank) + A_sharded = torch.randn(M, K_local, dtype=dtype, device=f"cuda:{rank}") + + # B must be identical on all ranks + torch.manual_seed(123) + B = torch.randn(K, N, dtype=dtype, device=f"cuda:{rank}") + + # Reference: torch all_gather + matmul + A_gathered_list = [torch.zeros(M, K_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_gathered_list, A_sharded) + A_gathered_ref = torch.cat(A_gathered_list, dim=1) # (M, K) + ref_output = torch.matmul(A_gathered_ref, B) + torch.cuda.synchronize() + + # Create shmem tensors + A_sharded_shmem = shmem.zeros((M, K_local), dtype=dtype) + A_sharded_shmem.copy_(A_sharded) + B_shmem = shmem.zeros((K, N), dtype=dtype) + B_shmem.copy_(B) + output = shmem.zeros((M, N), dtype=dtype) + + shmem.barrier() + + # Use small block sizes for small test problems + config = FusedConfig( + block_size_m=64, + block_size_n=64, + block_size_k=32, + ) + + workspace = all_gather_matmul_hbm_buffer_preamble( + shmem, A_sharded_shmem, B_shmem, config=config, staged_a_layout=staged_a_layout + ) + + all_gather_matmul_hbm_buffer( + shmem, + output, + A_sharded_shmem, + B_shmem, + config=config, + workspace=workspace, + staged_a_layout=staged_a_layout, + trace=False, + ) + + torch.cuda.synchronize() + shmem.barrier() + + max_diff = (output - ref_output).abs().max().item() + + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), ( + f"Rank {rank}: Max diff {max_diff}, expected < {atol} " + f"(staged_a_layout={staged_a_layout}, M={M}, K_local={K_local}, N={N})" + ) + + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float16, 1e-2, 1e-2), + ], +) +@pytest.mark.parametrize( + "M,K_local,N", + [ + (128, 32, 64), + ], +) +def test_all_gather_matmul_hbm_buffer_with_bias(dtype, atol, rtol, M, K_local, N): + """Test all_gather_matmul_hbm_buffer with a bias vector.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + K = K_local * world_size + + torch.manual_seed(42 + rank) + A_sharded = torch.randn(M, K_local, dtype=dtype, device=f"cuda:{rank}") + + torch.manual_seed(123) + B = torch.randn(K, N, dtype=dtype, device=f"cuda:{rank}") + + torch.manual_seed(77) + bias = torch.randn(M, dtype=dtype, device=f"cuda:{rank}") + + # Reference: torch all_gather + matmul + bias + A_gathered_list = [torch.zeros(M, K_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_gathered_list, A_sharded) + A_gathered_ref = torch.cat(A_gathered_list, dim=1) + ref_output = torch.matmul(A_gathered_ref, B) + bias[:, None] + torch.cuda.synchronize() + + # Create shmem tensors + A_sharded_shmem = shmem.zeros((M, K_local), dtype=dtype) + A_sharded_shmem.copy_(A_sharded) + B_shmem = shmem.zeros((K, N), dtype=dtype) + B_shmem.copy_(B) + bias_shmem = shmem.zeros((M,), dtype=dtype) + bias_shmem.copy_(bias) + output = shmem.zeros((M, N), dtype=dtype) + + shmem.barrier() + + config = FusedConfig( + block_size_m=64, + block_size_n=64, + block_size_k=32, + ) + + all_gather_matmul_hbm_buffer( + shmem, + output, + A_sharded_shmem, + B_shmem, + bias=bias_shmem, + config=config, + trace=False, + ) + + torch.cuda.synchronize() + shmem.barrier() + + max_diff = (output - ref_output).abs().max().item() + + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), ( + f"Rank {rank}: Max diff {max_diff}, expected < {atol} (with bias)" + ) + + +if __name__ == "__main__": + # For quick debugging + import sys + + if not dist.is_initialized(): + print("Run with: torchrun --nproc_per_node=2 tests/ops/test_all_gather_matmul_hbm_buffer.py") + sys.exit(1) + + rank = dist.get_rank() + torch.cuda.set_device(rank) + + print(f"[Rank {rank}] Testing all_gather_matmul_hbm_buffer...") + test_all_gather_matmul_hbm_buffer(torch.float16, 1e-2, 1e-2, 128, 32, 64, "k_contiguous") + print(f"[Rank {rank}] ✓ Test passed!")