#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

# Reconstructed from a whitespace-mangled git patch series
# ("Add reduce scatter example and CCL test", examples/27_ccl_reduce_scatter/example.py).

"""
Example: iris.ccl.reduce_scatter

Each rank contributes an (M, N) tensor. The reduce-scatter collective reduces (sums)
the inputs from all ranks and partitions the result: each rank receives the reduced
values for its assigned tile partition only, with all other elements remaining zero.

Together, all ranks' outputs form a complete partition of the full reduced result.

Run with:
    torchrun --nproc_per_node=<num_gpus> --standalone example.py [--validate]
"""

import argparse
import os

import torch
import torch.distributed as dist

import iris


def parse_args():
    """Parse command-line options.

    Returns:
        dict: option name -> value, with keys ``m``, ``n``, ``heap_size``,
        ``datatype`` and ``validate``.
    """
    parser = argparse.ArgumentParser(
        description="CCL reduce-scatter example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("-m", type=int, default=1024, help="Number of rows")
    parser.add_argument("-n", type=int, default=512, help="Number of columns")
    parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size")
    parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")
    parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference")
    return vars(parser.parse_args())


def main():
    """Run the reduce-scatter collective once and optionally validate the result."""
    args = parse_args()

    # Bind this process to its GPU before any CUDA work happens.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    # gloo is used only for the validation all_reduce below; the collective
    # itself runs through the Iris heap.
    dist.init_process_group(backend="gloo")

    ctx = iris.iris(heap_size=args["heap_size"])
    rank = ctx.get_rank()
    world_size = ctx.get_num_ranks()

    dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
    dtype = dtype_map[args["datatype"]]
    M, N = args["m"], args["n"]

    # Each rank fills its input with (rank + 1), so the reduced value at any
    # element is the same known constant everywhere (see `expected` below).
    input_tensor = ctx.zeros((M, N), dtype=dtype)
    input_tensor.fill_(float(rank + 1))
    output_tensor = ctx.zeros((M, N), dtype=dtype)

    ctx.barrier()
    ctx.ccl.reduce_scatter(output_tensor, input_tensor)
    torch.cuda.synchronize()

    if rank == 0:
        ctx.info(f"reduce_scatter: world_size={world_size}, shape=({M},{N}), dtype={dtype}")

    if args["validate"]:
        # Each rank owns a partition of tiles. The value at each assigned tile is the
        # element-wise sum of all ranks' inputs: sum(r+1 for r in 0..world_size-1).
        # Tiles not assigned to this rank remain 0.
        # Summing the outputs across all ranks (all_reduce) fills every element with
        # the expected per-element sum, since the tile partition is complete.
        expected = float(world_size * (world_size + 1) // 2)

        aggregated = output_tensor.clone()
        dist.all_reduce(aggregated, op=dist.ReduceOp.SUM)
        torch.cuda.synchronize()

        assert torch.allclose(aggregated, torch.full_like(aggregated, expected), atol=0.5), (
            f"Rank {rank}: mismatch after aggregation. Got {aggregated[0, 0].item():.1f}, expected {expected:.1f}"
        )
        if rank == 0:
            ctx.info(f"Validation passed: aggregated[0,0] = {aggregated[0, 0].item():.1f}")

    ctx.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
+ """ + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 # 8GB + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Each rank fills its input with (rank + 1) + iris_input_tensor = shmem.zeros((M, N), dtype=dtype) + iris_input_tensor.fill_(float(rank + 1)) + + iris_output_tensor = shmem.zeros((M, N), dtype=dtype) + + # Run Iris reduce_scatter + shmem.barrier() + config = Config(block_size_m=32, block_size_n=64, all_reduce_distribution=distribution) + shmem.ccl.reduce_scatter(iris_output_tensor, iris_input_tensor, config=config) + torch.cuda.synchronize() + + # Validate: tiles are partitioned across ranks, so summing outputs from all ranks + # (via all_reduce) should give the element-wise sum of all inputs at every position. + expected = float(world_size * (world_size + 1) // 2) + aggregated = iris_output_tensor.clone() + dist.all_reduce(aggregated, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + + atol = 1e-3 if dtype == torch.float16 else 1e-5 + max_diff = torch.abs(aggregated - expected).max().item() + + try: + assert torch.allclose(aggregated, torch.full_like(aggregated, expected), atol=atol), ( + f"Max difference: {max_diff}, expected < {atol}\n" + f"Rank {rank}: aggregated reduce-scatter outputs don't match expected sum " + f"(distribution={distribution})" + ) + finally: + shmem.barrier() + del shmem + gc.collect() From fe1a915a70d20f954a20089c42e5fb2b8c4073d4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 8 Mar 2026 01:30:16 +0000 Subject: [PATCH 3/3] Remove redundant test_reduce_scatter.py (tests already exist in test_process_groups.py) Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- tests/ccl/test_reduce_scatter.py | 86 -------------------------------- 1 file changed, 86 deletions(-) delete mode 100644 tests/ccl/test_reduce_scatter.py diff --git 
a/tests/ccl/test_reduce_scatter.py b/tests/ccl/test_reduce_scatter.py deleted file mode 100644 index 1d2f6e8b..00000000 --- a/tests/ccl/test_reduce_scatter.py +++ /dev/null @@ -1,86 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. - -""" -Test suite for reduce-scatter collective operation. -""" - -import gc - -import pytest -import torch -import torch.distributed as dist -import iris -from iris.ccl import Config - - -@pytest.mark.parametrize( - "distribution", - [ - 0, # striding - 1, # block - ], -) -@pytest.mark.parametrize( - "dtype", - [ - torch.float16, - torch.float32, - torch.bfloat16, - ], -) -@pytest.mark.parametrize( - "M, N", - [ - (128, 64), # Small - (1024, 256), # Medium - (8192, 8192), # Large - ], -) -def test_reduce_scatter(distribution, dtype, M, N): - """Test reduce-scatter functionality. - - Each rank reduces its assigned tiles from all ranks' inputs. The tile partition - is complete: summing the outputs across all ranks (all_reduce) should yield the - element-wise sum of all inputs at every position. - """ - if not dist.is_initialized(): - pytest.skip("torch.distributed not initialized") - - heap_size = 2**33 # 8GB - shmem = iris.iris(heap_size) - rank = shmem.get_rank() - world_size = shmem.get_num_ranks() - - # Each rank fills its input with (rank + 1) - iris_input_tensor = shmem.zeros((M, N), dtype=dtype) - iris_input_tensor.fill_(float(rank + 1)) - - iris_output_tensor = shmem.zeros((M, N), dtype=dtype) - - # Run Iris reduce_scatter - shmem.barrier() - config = Config(block_size_m=32, block_size_n=64, all_reduce_distribution=distribution) - shmem.ccl.reduce_scatter(iris_output_tensor, iris_input_tensor, config=config) - torch.cuda.synchronize() - - # Validate: tiles are partitioned across ranks, so summing outputs from all ranks - # (via all_reduce) should give the element-wise sum of all inputs at every position. 
- expected = float(world_size * (world_size + 1) // 2) - aggregated = iris_output_tensor.clone() - dist.all_reduce(aggregated, op=dist.ReduceOp.SUM) - torch.cuda.synchronize() - - atol = 1e-3 if dtype == torch.float16 else 1e-5 - max_diff = torch.abs(aggregated - expected).max().item() - - try: - assert torch.allclose(aggregated, torch.full_like(aggregated, expected), atol=atol), ( - f"Max difference: {max_diff}, expected < {atol}\n" - f"Rank {rank}: aggregated reduce-scatter outputs don't match expected sum " - f"(distribution={distribution})" - ) - finally: - shmem.barrier() - del shmem - gc.collect()