From 3c7dddb16009e5bb7f31d569c006236f33f432d8 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Mar 2026 16:13:40 +0000 Subject: [PATCH 1/4] Initial plan From e5ecf8e99342b96f2fccb23f81713984cc99a2f5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Mar 2026 16:21:32 +0000 Subject: [PATCH 2/4] Introduce GEMM+AllScatter op in iris.ops Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- iris/ops/__init__.py | 34 ++++ iris/ops/matmul_all_scatter.py | 266 +++++++++++++++++++++++++++ tests/ops/test_matmul_all_scatter.py | 181 ++++++++++++++++++ 3 files changed, 481 insertions(+) create mode 100644 iris/ops/matmul_all_scatter.py create mode 100644 tests/ops/test_matmul_all_scatter.py diff --git a/iris/ops/__init__.py b/iris/ops/__init__.py index e0d12ba5..dd2f4895 100644 --- a/iris/ops/__init__.py +++ b/iris/ops/__init__.py @@ -27,6 +27,7 @@ - all_gather_matmul: All-Gather + GEMM - matmul_all_gather: GEMM + All-Gather - matmul_reduce_scatter: GEMM + Reduce-Scatter + - matmul_all_scatter: GEMM + All-Scatter """ from .config import FusedConfig @@ -38,6 +39,7 @@ from .all_gather_matmul import all_gather_matmul, all_gather_matmul_preamble from .matmul_all_gather import matmul_all_gather from .matmul_reduce_scatter import matmul_reduce_scatter, matmul_reduce_scatter_preamble +from .matmul_all_scatter import matmul_all_scatter, matmul_all_scatter_preamble class OpsNamespace: @@ -166,6 +168,36 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False, """ return matmul_reduce_scatter(self._shmem, output_tensor, A, B, bias, async_op, config, workspace) + def matmul_all_scatter(self, output_tensor, A, B_local, bias=None, async_op=False, config=None, workspace=None): + """ + Fused matrix multiplication and all-scatter. + + Computes: output = all_scatter(A @ B_local) along N dimension + + Each rank has B_local of shape (K, N_local) where N_local = N / world_size. + The operation computes C_local = A @ B_local on each rank and scatters the + tiles to all ranks so that every rank ends up with the full C (M, N). + + Args: + output_tensor: Output tensor (M, N) where N = N_local * world_size + A: Input matrix A (M, K) - replicated across ranks + B_local: Column-sharded input matrix B (K, N_local) + bias: Optional bias vector (M,) + async_op: If False, performs barrier at end + config: Optional FusedConfig for tuning + workspace: Optional pre-allocated workspace + + Returns: + workspace: Updated workspace object + + Example: + >>> N_local = N // world_size + >>> B_local = shmem.randn((K, N_local), dtype=torch.float16) + >>> output = shmem.zeros((M, N), dtype=torch.float16) + >>> shmem.ops.matmul_all_scatter(output, A, B_local) + """ + return matmul_all_scatter(self._shmem, output_tensor, A, B_local, bias, async_op, config, workspace) + # Export public API __all__ = [ @@ -183,4 +215,6 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False, "matmul_all_gather", "matmul_reduce_scatter", "matmul_reduce_scatter_preamble", + "matmul_all_scatter", + "matmul_all_scatter_preamble", ] diff --git a/iris/ops/matmul_all_scatter.py b/iris/ops/matmul_all_scatter.py new file mode 100644 index 00000000..1dd54932 --- /dev/null +++ b/iris/ops/matmul_all_scatter.py @@ -0,0 +1,266 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Fused GEMM + All-Scatter operation using scatter pattern. + +Each rank has a column-sharded input B_local (K x N_local) and full input A (M x K). +Each rank computes C_local = A @ B_local, then scatters C_local tiles to all ranks +so that each rank ends up with the full C (M x N) where N = world_size * N_local. + +This is useful for tensor-parallel workloads where weights are column-sharded and +outputs need to be gathered along the column dimension. +""" + +from typing import Optional +import torch +import triton +import triton.language as tl +import iris +import iris.x + +from tritonblas.kernels.stages import GemmContext, ScheduleContext, make_tensor_view + +from .config import FusedConfig +from .workspace import FusedWorkspace + + +@triton.jit() +def _fused_matmul_all_scatter_kernel( + A, # (M, K) - replicated across ranks + B_local, # (K, N_local) - each rank's column shard + C_gathered, # (M, N) - gathered output (N = N_local * world_size) + bias_ptr, + M, + N, + K, + N_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm_gathered, + stride_cn_gathered, + stride_bias, + context_tensor: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + ALLOW_TF32: tl.constexpr, +): + """ + Fused GEMM + all-scatter kernel using scatter pattern. + + Computes local GEMM tile and immediately scatters to all ranks along the N dimension. + No intermediate buffer needed - direct from registers to remote memory. + """ + # ═══════════════════════════════════════════════════════════════════════ + # Create tritonblas views, context, and scheduler for GEMM + # ═══════════════════════════════════════════════════════════════════════ + tensorA = make_tensor_view(A, M, K, stride_am, stride_ak) + tensorB = make_tensor_view(B_local, K, N_local, stride_bk, stride_bn) + gemm_ctx = GemmContext( + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + num_sms=NUM_SMS, + num_xcds=NUM_XCDS, + group_size_m=GROUP_SIZE_M, + even_k=EVEN_K, + allow_tf32=ALLOW_TF32, + ) + sched = ScheduleContext(M, N_local, K, gemm_ctx) + + # Persistent loop over local tiles using scheduler + start, total, stride = sched.persistent_tile_range() + for tile_id in range(start, total, stride): + # Get tile coordinates with swizzling from scheduler + out_tile = sched.get_tile_from_idx(tile_id) + + # GEMM using tritonblas stages + acc = gemm_ctx.reduce_axis(tensorA, tensorB, out_tile) + + # Add bias if provided + if BIAS: + rm, _ = out_tile.indices() + bias_vector = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0) + acc = acc + bias_vector[:, None] + + # Convert to output dtype + c = acc.to(C_gathered.type.element_ty) + + # Create DeviceContext and destination TensorView for all-scatter + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) + dst_view = iris.x.make_tensor_view(C_gathered, M, N, stride_cm_gathered, stride_cn_gathered) + tile_obj = iris.x.Tile(out_tile.pid_m, out_tile.pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) + + # Broadcast this rank's tile to all ranks using iris.x.all_gather with dim=1. + # dim=1 places the tile at the current rank's column offset in the global output, + # so every rank receives each rank's column-shard (all-scatter along N dimension). + iris.x.all_gather(tile_obj, dst_view, dim=1, ctx=ctx) + + +def matmul_all_scatter_preamble( + shmem, + A: torch.Tensor, + B_local: torch.Tensor, + config: Optional[FusedConfig] = None, +) -> FusedWorkspace: + """Allocate workspace for matmul_all_scatter (none needed for scatter pattern).""" + if config is None: + config = FusedConfig() + + M, K = A.shape + K2, N_local = B_local.shape + world_size = shmem.get_num_ranks() + + assert K == K2, f"Inner dimensions must match: A has K={K}, B_local has K={K2}" + + N = N_local * world_size + + return FusedWorkspace( + operation="matmul_all_scatter", + shape=(M, N, K), + dtype=A.dtype, + world_size=world_size, + prepared=True, + ) + + +def matmul_all_scatter( + shmem, + output_tensor: torch.Tensor, + A: torch.Tensor, + B_local: torch.Tensor, + bias: Optional[torch.Tensor] = None, + async_op: bool = False, + config: Optional[FusedConfig] = None, + workspace: Optional[FusedWorkspace] = None, +) -> FusedWorkspace: + """ + Fused matrix multiplication and all-scatter using scatter pattern. + + Computes: output = all_scatter(A @ B_local) along N dimension + + Each rank has B_local of shape (K, N_local) where N_local = N / world_size. + The operation computes C_local = A @ B_local on each rank and immediately + broadcasts each rank's column-shard tiles to all ranks via iris.x.all_gather + (dim=1), so that every rank ends up with the full C (M, N). + + This is a single-kernel implementation - no intermediate buffer needed. + Internally this uses iris.x.all_gather(dim=1) to broadcast each rank's + computed column tiles to all other ranks at the correct N offset. + + Args: + shmem: Iris shmem context + output_tensor: Output tensor C of shape (M, N) where N = N_local * world_size + A: Input matrix A of shape (M, K) - replicated across ranks + B_local: Column-sharded input matrix B of shape (K, N_local) + bias: Optional bias vector (M,) + async_op: If False, performs barrier at end + config: Optional FusedConfig for tuning + workspace: Optional pre-allocated workspace + + Returns: + FusedWorkspace object + + Example: + >>> N_local = N // world_size + >>> B_local = shmem.randn((K, N_local), dtype=torch.float16) + >>> output = shmem.zeros((M, N), dtype=torch.float16) + >>> shmem.ops.matmul_all_scatter(output, A, B_local) + """ + if config is None: + config = FusedConfig() + + M, K = A.shape + K2, N_local = B_local.shape + world_size = shmem.get_num_ranks() + rank = shmem.get_rank() + + assert K == K2, f"Inner dimensions must match: A has K={K}, B_local has K={K2}" + + N = N_local * world_size + assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" + + # Validate problem size against block sizes + assert M >= config.block_size_m, ( + f"M ({M}) must be >= block_size_m ({config.block_size_m}). Use smaller block sizes for small problems." + ) + assert K >= config.block_size_k, ( + f"K ({K}) must be >= block_size_k ({config.block_size_k}). Use smaller block sizes for small problems." + ) + assert N_local >= config.block_size_n, ( + f"N_local ({N_local}) must be >= block_size_n ({config.block_size_n}). " + f"Use smaller block sizes for small problems." + ) + + # Allocate workspace if not provided + if workspace is None: + workspace = matmul_all_scatter_preamble(shmem, A, B_local, config) + + stride_am, stride_ak = A.stride() + stride_bk, stride_bn = B_local.stride() + stride_cm_gathered, stride_cn_gathered = output_tensor.stride() + + if bias is not None: + assert bias.shape[0] == M + bias_ptr = bias + stride_bias = bias.stride()[0] if bias.dim() > 0 else 1 + use_bias = True + else: + bias_ptr = output_tensor + stride_bias = 1 + use_bias = False + + device = A.device + num_sms = config.num_sms + if num_sms is None: + props = torch.cuda.get_device_properties(device) + num_sms = props.multi_processor_count + + even_k = K % config.block_size_k == 0 + + # Launch single fused kernel + grid = (num_sms,) + _fused_matmul_all_scatter_kernel[grid]( + A, + B_local, + output_tensor, + bias_ptr, + M, + N, + K, + N_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm_gathered, + stride_cn_gathered, + stride_bias, + shmem.get_device_context(), + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_sms, + config.num_xcds, + use_bias, + even_k, + config.allow_tf32, + ) + + if not async_op: + shmem.barrier() + + return workspace diff --git a/tests/ops/test_matmul_all_scatter.py b/tests/ops/test_matmul_all_scatter.py new file mode 100644 index 00000000..faa8694d --- /dev/null +++ b/tests/ops/test_matmul_all_scatter.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Test suite for high-level matmul_all_scatter API. + +Note: This test requires tritonBLAS to be installed. +Install with: pip install git+https://github.com/ROCm/tritonBLAS.git +""" + +import pytest +import torch +import torch.distributed as dist +import iris +import iris.ops as ops + + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float16, 0.5, 0.01), + (torch.bfloat16, 0.5, 0.01), + ], +) +@pytest.mark.parametrize( + "M, N, K", + [ + (64, 64, 32), + (512, 256, 512), + (1024, 2048, 1024), + ], +) +def test_matmul_all_scatter(dtype, atol, rtol, M, N, K): + """Test matmul_all_scatter using shmem.ops API with proper config. + + Validates against a PyTorch reference: local GEMM on each rank followed by + all_gather to concatenate column shards along the N dimension. + """ + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 # 8GB + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # N must be divisible by world_size for column-wise sharding + if N % world_size != 0: + pytest.skip(f"N={N} not divisible by world_size={world_size}") + + N_local = N // world_size + + # Skip if problem size is too small for world_size + min_block_size = 32 # Smallest block size we use + if N_local < min_block_size: + pytest.skip(f"N_local={N_local} too small for world_size={world_size} (need >= {min_block_size})") + if K < min_block_size: + pytest.skip(f"K={K} too small (need >= {min_block_size})") + if M < min_block_size: + pytest.skip(f"M={M} too small (need >= {min_block_size})") + + # Create shmem tensors directly + A = shmem.randn((M, K), dtype=dtype) + B_local = shmem.randn((K, N_local), dtype=dtype) + output = shmem.zeros((M, N), dtype=dtype) + + # Reference: compute local GEMM, then all-gather along N dimension + A_ref = A.clone() + B_local_ref = B_local.clone() + C_local_ref = torch.matmul(A_ref, B_local_ref) + C_gathered_list = [torch.zeros(M, N_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(C_gathered_list, C_local_ref) + pytorch_output = torch.cat(C_gathered_list, dim=1) # Concatenate along N dimension + torch.cuda.synchronize() + + shmem.barrier() + + # Use appropriate block sizes based on problem size + from iris.ops.config import FusedConfig + + # Select config based on actual problem dimensions + if M <= 64 or K <= 64 or N_local <= 64: + config = FusedConfig(block_size_m=32, block_size_n=32, block_size_k=32) + elif M <= 128 or K <= 128 or N_local <= 128: + config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) + elif dtype == torch.float32: + config = FusedConfig(block_size_m=128, block_size_n=128, block_size_k=64) + else: + config = FusedConfig(block_size_m=128, block_size_n=128, block_size_k=64) + + # Validate config against problem size + assert M >= config.block_size_m, f"M ({M}) must be >= block_size_m ({config.block_size_m})" + assert K >= config.block_size_k, f"K ({K}) must be >= block_size_k ({config.block_size_k})" + assert N_local >= config.block_size_n, f"N_local ({N_local}) must be >= block_size_n ({config.block_size_n})" + + # Use shmem.ops API with proper config + shmem.ops.matmul_all_scatter(output, A, B_local, config=config) + + torch.cuda.synchronize() + shmem.barrier() + + max_diff = torch.abs(output - pytorch_output).max().item() + + assert torch.allclose(output, pytorch_output, atol=atol, rtol=rtol), ( + f"Max difference: {max_diff}, expected < {atol}\n" + f"Rank {rank}: shmem.ops.matmul_all_scatter output doesn't match reference" + ) + + if rank == 0: + print(f"✓ matmul_all_scatter test passed: {dtype}, M={M}, N={N}, K={K}") + + shmem.barrier() + del shmem + import gc + + gc.collect() + + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float16, 0.5, 0.01), + (torch.bfloat16, 0.5, 0.01), + ], +) +def test_matmul_all_scatter_semantics(dtype, atol, rtol): + """ + Test that matmul_all_scatter is equivalent to: + C_local = A @ B_local (on each rank) + output = all_gather(C_local, dim=1) (concatenate along N) + """ + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + M, N, K = 128, 128, 32 + + if N % world_size != 0: + pytest.skip(f"N={N} not divisible by world_size={world_size}") + + N_local = N // world_size + + A = shmem.randn((M, K), dtype=dtype) + B_local = shmem.randn((K, N_local), dtype=dtype) + output = shmem.zeros((M, N), dtype=dtype) + + # Reference + C_local_ref = torch.matmul(A.clone(), B_local.clone()) + C_gathered_list = [torch.zeros(M, N_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(C_gathered_list, C_local_ref) + C_ref = torch.cat(C_gathered_list, dim=1) + torch.cuda.synchronize() + + config = ops.FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) + + if N_local < config.block_size_n: + config = ops.FusedConfig(block_size_m=32, block_size_n=N_local, block_size_k=32) + + from iris.ops.matmul_all_scatter import matmul_all_scatter + + matmul_all_scatter(shmem, output, A, B_local, config=config) + + torch.cuda.synchronize() + shmem.barrier() + + assert torch.allclose(output, C_ref, atol=atol, rtol=rtol), ( + f"Rank {rank}: matmul_all_scatter semantics mismatch. Max diff: {torch.abs(output - C_ref).max().item()}" + ) + + if rank == 0: + print("matmul_all_scatter semantics verified") + + shmem.barrier() + del shmem + import gc + + gc.collect() From d81a560448e9dd710da6ef8d421ab315bd65a5d9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 5 Mar 2026 01:12:15 +0000 Subject: [PATCH 3/4] Rewrite matmul_all_scatter kernel to use direct iris.store scatter pattern from example 07 Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- iris/ops/matmul_all_scatter.py | 56 ++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/iris/ops/matmul_all_scatter.py b/iris/ops/matmul_all_scatter.py index 1dd54932..2f8e8918 100644 --- a/iris/ops/matmul_all_scatter.py +++ b/iris/ops/matmul_all_scatter.py @@ -2,11 +2,11 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. """ -Fused GEMM + All-Scatter operation using scatter pattern. +Fused GEMM + All-Scatter operation. Each rank has a column-sharded input B_local (K x N_local) and full input A (M x K). -Each rank computes C_local = A @ B_local, then scatters C_local tiles to all ranks -so that each rank ends up with the full C (M x N) where N = world_size * N_local. +Each rank computes C_local = A @ B_local, then scatters its column-stripe of C to all +other ranks so that every rank ends up with the full C (M x N) where N = world_size * N_local. This is useful for tensor-parallel workloads where weights are column-sharded and outputs need to be gathered along the column dimension. @@ -17,7 +17,6 @@ import triton import triton.language as tl import iris -import iris.x from tritonblas.kernels.stages import GemmContext, ScheduleContext, make_tensor_view @@ -29,7 +28,7 @@ def _fused_matmul_all_scatter_kernel( A, # (M, K) - replicated across ranks B_local, # (K, N_local) - each rank's column shard - C_gathered, # (M, N) - gathered output (N = N_local * world_size) + C_gathered, # (M, N) - output where N = N_local * world_size bias_ptr, M, N, @@ -42,7 +41,7 @@ def _fused_matmul_all_scatter_kernel( stride_cm_gathered, stride_cn_gathered, stride_bias, - context_tensor: tl.tensor, + heap_bases: tl.tensor, cur_rank: tl.constexpr, world_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, @@ -56,10 +55,10 @@ def _fused_matmul_all_scatter_kernel( ALLOW_TF32: tl.constexpr, ): """ - Fused GEMM + all-scatter kernel using scatter pattern. + Fused GEMM + all-scatter kernel. - Computes local GEMM tile and immediately scatters to all ranks along the N dimension. - No intermediate buffer needed - direct from registers to remote memory. + Computes local GEMM tile and immediately scatters this rank's column stripe + to all ranks via iris.store. No intermediate buffer needed. """ # ═══════════════════════════════════════════════════════════════════════ # Create tritonblas views, context, and scheduler for GEMM @@ -96,15 +95,28 @@ def _fused_matmul_all_scatter_kernel( # Convert to output dtype c = acc.to(C_gathered.type.element_ty) - # Create DeviceContext and destination TensorView for all-scatter - ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) - dst_view = iris.x.make_tensor_view(C_gathered, M, N, stride_cm_gathered, stride_cn_gathered) - tile_obj = iris.x.Tile(out_tile.pid_m, out_tile.pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) + # Compute tile row/col indices with vectorization hints and bounds mask. + # Use local dimensions (M, N_local) since the GEMM covers the local tile space; + # the global offset below maps these local indices to the correct global position. + rm, rn, sub_mask = out_tile.layout(M, N_local) - # Broadcast this rank's tile to all ranks using iris.x.all_gather with dim=1. - # dim=1 places the tile at the current rank's column offset in the global output, - # so every rank receives each rank's column-shard (all-scatter along N dimension). - iris.x.all_gather(tile_obj, dst_view, dim=1, ctx=ctx) + # Global write offset: this rank owns column stripe [cur_rank*N_local, (cur_rank+1)*N_local) + global_offset = rm[:, None] * stride_cm_gathered + (rn[None, :] + cur_rank * N_local) * stride_cn_gathered + + # Write local result to this rank's own output buffer + tl.store(C_gathered + global_offset, c, mask=sub_mask) + + # Scatter this rank's column stripe to all remote ranks + for remote_rank in range(world_size): + if remote_rank != cur_rank: + iris.store( + C_gathered + global_offset, + c, + cur_rank, + remote_rank, + heap_bases, + mask=sub_mask, + ) def matmul_all_scatter_preamble( @@ -145,18 +157,16 @@ def matmul_all_scatter( workspace: Optional[FusedWorkspace] = None, ) -> FusedWorkspace: """ - Fused matrix multiplication and all-scatter using scatter pattern. + Fused matrix multiplication and all-scatter. Computes: output = all_scatter(A @ B_local) along N dimension Each rank has B_local of shape (K, N_local) where N_local = N / world_size. The operation computes C_local = A @ B_local on each rank and immediately - broadcasts each rank's column-shard tiles to all ranks via iris.x.all_gather - (dim=1), so that every rank ends up with the full C (M, N). + scatters its column stripe to all other ranks via iris.store, so that every + rank ends up with the full C (M, N). This is a single-kernel implementation - no intermediate buffer needed. - Internally this uses iris.x.all_gather(dim=1) to broadcast each rank's - computed column tiles to all other ranks at the correct N offset. Args: shmem: Iris shmem context @@ -246,7 +256,7 @@ def matmul_all_scatter( stride_cm_gathered, stride_cn_gathered, stride_bias, - shmem.get_device_context(), + shmem.get_heap_bases(), rank, world_size, config.block_size_m, From 0ac27a4b80991ed348d7c71fbe815f8987addbd3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 5 Mar 2026 01:23:02 +0000 Subject: [PATCH 4/4] Add iris.x.all_scatter primitive and use it in matmul_all_scatter; fix variable naming Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- iris/ops/__init__.py | 24 ++--- iris/ops/matmul_all_scatter.py | 139 ++++++++++++--------------- iris/x/__init__.py | 6 ++ iris/x/all_scatter.py | 63 ++++++++++++ tests/ops/test_matmul_all_scatter.py | 33 ++++--- 5 files changed, 160 insertions(+), 105 deletions(-) create mode 100644 iris/x/all_scatter.py diff --git a/iris/ops/__init__.py b/iris/ops/__init__.py index dd2f4895..1a1fe05c 100644 --- a/iris/ops/__init__.py +++ b/iris/ops/__init__.py @@ -168,20 +168,20 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False, """ return matmul_reduce_scatter(self._shmem, output_tensor, A, B, bias, async_op, config, workspace) - def matmul_all_scatter(self, output_tensor, A, B_local, bias=None, async_op=False, config=None, workspace=None): + def matmul_all_scatter(self, output, A, B_shard, bias=None, async_op=False, config=None, workspace=None): """ Fused matrix multiplication and all-scatter. - Computes: output = all_scatter(A @ B_local) along N dimension + Computes: output = all_scatter(A @ B_shard) along N dimension - Each rank has B_local of shape (K, N_local) where N_local = N / world_size. - The operation computes C_local = A @ B_local on each rank and scatters the - tiles to all ranks so that every rank ends up with the full C (M, N). + Each rank has B_shard of shape (K, N_shard) where N_shard = N / world_size. + The operation computes C_shard = A @ B_shard on each rank and scatters the + column stripe to all ranks so that every rank ends up with the full C (M, N). Args: - output_tensor: Output tensor (M, N) where N = N_local * world_size - A: Input matrix A (M, K) - replicated across ranks - B_local: Column-sharded input matrix B (K, N_local) + output: Output tensor (M, N) where N = N_shard * world_size + A: Input matrix A (M, K) — replicated across ranks + B_shard: Column-sharded weight matrix (K, N_shard) bias: Optional bias vector (M,) async_op: If False, performs barrier at end config: Optional FusedConfig for tuning @@ -191,12 +191,12 @@ def matmul_all_scatter(self, output_tensor, A, B_local, bias=None, async_op=Fals workspace: Updated workspace object Example: - >>> N_local = N // world_size - >>> B_local = shmem.randn((K, N_local), dtype=torch.float16) + >>> N_shard = N // world_size + >>> B_shard = shmem.randn((K, N_shard), dtype=torch.float16) >>> output = shmem.zeros((M, N), dtype=torch.float16) - >>> shmem.ops.matmul_all_scatter(output, A, B_local) + >>> shmem.ops.matmul_all_scatter(output, A, B_shard) """ - return matmul_all_scatter(self._shmem, output_tensor, A, B_local, bias, async_op, config, workspace) + return matmul_all_scatter(self._shmem, output, A, B_shard, bias, async_op, config, workspace) # Export public API diff --git a/iris/ops/matmul_all_scatter.py b/iris/ops/matmul_all_scatter.py index 2f8e8918..96ab721a 100644 --- a/iris/ops/matmul_all_scatter.py +++ b/iris/ops/matmul_all_scatter.py @@ -4,12 +4,13 @@ """ Fused GEMM + All-Scatter operation. -Each rank has a column-sharded input B_local (K x N_local) and full input A (M x K). -Each rank computes C_local = A @ B_local, then scatters its column-stripe of C to all -other ranks so that every rank ends up with the full C (M x N) where N = world_size * N_local. +Each rank has a column-sharded weight B_shard (K x N_shard) and a replicated +input A (M x K). Each rank computes C_shard = A @ B_shard, then scatters its +column stripe to all other ranks so that every rank ends up with the full +output C (M x N) where N = world_size * N_shard. -This is useful for tensor-parallel workloads where weights are column-sharded and -outputs need to be gathered along the column dimension. +This is useful for tensor-parallel workloads where weights are column-sharded +and the full output is needed on all ranks. """ from typing import Optional @@ -17,6 +18,7 @@ import triton import triton.language as tl import iris +import iris.x from tritonblas.kernels.stages import GemmContext, ScheduleContext, make_tensor_view @@ -26,22 +28,22 @@ @triton.jit() def _fused_matmul_all_scatter_kernel( - A, # (M, K) - replicated across ranks - B_local, # (K, N_local) - each rank's column shard - C_gathered, # (M, N) - output where N = N_local * world_size + A, # (M, K) - replicated across ranks + B_shard, # (K, N_shard) - each rank's column shard of weight matrix B + C, # (M, N) - output, N = N_shard * world_size bias_ptr, M, N, K, - N_local, + N_shard, stride_am, stride_ak, stride_bk, stride_bn, - stride_cm_gathered, - stride_cn_gathered, + stride_cm, + stride_cn, stride_bias, - heap_bases: tl.tensor, + context_tensor: tl.tensor, cur_rank: tl.constexpr, world_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, @@ -58,13 +60,13 @@ def _fused_matmul_all_scatter_kernel( Fused GEMM + all-scatter kernel. Computes local GEMM tile and immediately scatters this rank's column stripe - to all ranks via iris.store. No intermediate buffer needed. + to all ranks via ``iris.x.all_scatter``. No intermediate buffer needed. """ # ═══════════════════════════════════════════════════════════════════════ # Create tritonblas views, context, and scheduler for GEMM # ═══════════════════════════════════════════════════════════════════════ - tensorA = make_tensor_view(A, M, K, stride_am, stride_ak) - tensorB = make_tensor_view(B_local, K, N_local, stride_bk, stride_bn) + view_A = make_tensor_view(A, M, K, stride_am, stride_ak) + view_B = make_tensor_view(B_shard, K, N_shard, stride_bk, stride_bn) gemm_ctx = GemmContext( BLOCK_SIZE_M, BLOCK_SIZE_N, @@ -75,7 +77,9 @@ def _fused_matmul_all_scatter_kernel( even_k=EVEN_K, allow_tf32=ALLOW_TF32, ) - sched = ScheduleContext(M, N_local, K, gemm_ctx) + sched = ScheduleContext(M, N_shard, K, gemm_ctx) + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) + dst_view = iris.x.make_tensor_view(C, M, N, stride_cm, stride_cn) # Persistent loop over local tiles using scheduler start, total, stride = sched.persistent_tile_range() @@ -84,45 +88,26 @@ def _fused_matmul_all_scatter_kernel( out_tile = sched.get_tile_from_idx(tile_id) # GEMM using tritonblas stages - acc = gemm_ctx.reduce_axis(tensorA, tensorB, out_tile) + acc = gemm_ctx.reduce_axis(view_A, view_B, out_tile) # Add bias if provided if BIAS: rm, _ = out_tile.indices() - bias_vector = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0) - acc = acc + bias_vector[:, None] + bias_vec = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0) + acc = acc + bias_vec[:, None] # Convert to output dtype - c = acc.to(C_gathered.type.element_ty) + c_tile = acc.to(C.type.element_ty) - # Compute tile row/col indices with vectorization hints and bounds mask. - # Use local dimensions (M, N_local) since the GEMM covers the local tile space; - # the global offset below maps these local indices to the correct global position. - rm, rn, sub_mask = out_tile.layout(M, N_local) - - # Global write offset: this rank owns column stripe [cur_rank*N_local, (cur_rank+1)*N_local) - global_offset = rm[:, None] * stride_cm_gathered + (rn[None, :] + cur_rank * N_local) * stride_cn_gathered - - # Write local result to this rank's own output buffer - tl.store(C_gathered + global_offset, c, mask=sub_mask) - - # Scatter this rank's column stripe to all remote ranks - for remote_rank in range(world_size): - if remote_rank != cur_rank: - iris.store( - C_gathered + global_offset, - c, - cur_rank, - remote_rank, - heap_bases, - mask=sub_mask, - ) + # Wrap result in a Tile object and scatter to all ranks + tile_obj = iris.x.Tile(out_tile.pid_m, out_tile.pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c_tile) + iris.x.all_scatter(tile_obj, dst_view, ctx) def matmul_all_scatter_preamble( shmem, A: torch.Tensor, - B_local: torch.Tensor, + B_shard: torch.Tensor, config: Optional[FusedConfig] = None, ) -> FusedWorkspace: """Allocate workspace for matmul_all_scatter (none needed for scatter pattern).""" @@ -130,12 +115,12 @@ def matmul_all_scatter_preamble( config = FusedConfig() M, K = A.shape - K2, N_local = B_local.shape + K2, N_shard = B_shard.shape world_size = shmem.get_num_ranks() - assert K == K2, f"Inner dimensions must match: A has K={K}, B_local has K={K2}" + assert K == K2, f"Inner dimensions must match: A has K={K}, B_shard has K={K2}" - N = N_local * world_size + N = N_shard * world_size return FusedWorkspace( operation="matmul_all_scatter", @@ -148,9 +133,9 @@ def matmul_all_scatter_preamble( def matmul_all_scatter( shmem, - output_tensor: torch.Tensor, + output: torch.Tensor, A: torch.Tensor, - B_local: torch.Tensor, + B_shard: torch.Tensor, bias: Optional[torch.Tensor] = None, async_op: bool = False, config: Optional[FusedConfig] = None, @@ -159,21 +144,21 @@ def matmul_all_scatter( """ Fused matrix multiplication and all-scatter. - Computes: output = all_scatter(A @ B_local) along N dimension + Computes: output = all_scatter(A @ B_shard) along N dimension - Each rank has B_local of shape (K, N_local) where N_local = N / world_size. - The operation computes C_local = A @ B_local on each rank and immediately - scatters its column stripe to all other ranks via iris.store, so that every - rank ends up with the full C (M, N). + Each rank has B_shard of shape (K, N_shard) where N_shard = N / world_size. + The operation computes C_shard = A @ B_shard on each rank and immediately + scatters its column stripe to all other ranks via ``iris.x.all_scatter``, + so that every rank ends up with the full C (M, N). - This is a single-kernel implementation - no intermediate buffer needed. + This is a single-kernel implementation — no intermediate buffer needed. Args: shmem: Iris shmem context - output_tensor: Output tensor C of shape (M, N) where N = N_local * world_size - A: Input matrix A of shape (M, K) - replicated across ranks - B_local: Column-sharded input matrix B of shape (K, N_local) - bias: Optional bias vector (M,) + output: Output tensor C of shape (M, N) where N = N_shard * world_size + A: Input matrix A of shape (M, K) — replicated across ranks + B_shard: Column-sharded weight matrix of shape (K, N_shard) + bias: Optional bias vector of shape (M,) async_op: If False, performs barrier at end config: Optional FusedConfig for tuning workspace: Optional pre-allocated workspace @@ -182,23 +167,23 @@ def matmul_all_scatter( FusedWorkspace object Example: - >>> N_local = N // world_size - >>> B_local = shmem.randn((K, N_local), dtype=torch.float16) + >>> N_shard = N // world_size + >>> B_shard = shmem.randn((K, N_shard), dtype=torch.float16) >>> output = shmem.zeros((M, N), dtype=torch.float16) - >>> shmem.ops.matmul_all_scatter(output, A, B_local) + >>> shmem.ops.matmul_all_scatter(output, A, B_shard) """ if config is None: config = FusedConfig() M, K = A.shape - K2, N_local = B_local.shape + K2, N_shard = B_shard.shape world_size = shmem.get_num_ranks() rank = shmem.get_rank() - assert K == K2, f"Inner dimensions must match: A has K={K}, B_local has K={K2}" + assert K == K2, f"Inner dimensions must match: A has K={K}, B_shard has K={K2}" - N = N_local * world_size - assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" + N = N_shard * world_size + assert output.shape == (M, N), f"Output must be ({M}, {N}), got {output.shape}" # Validate problem size against block sizes assert M >= config.block_size_m, ( @@ -207,18 +192,18 @@ def matmul_all_scatter( assert K >= config.block_size_k, ( f"K ({K}) must be >= block_size_k ({config.block_size_k}). Use smaller block sizes for small problems." ) - assert N_local >= config.block_size_n, ( - f"N_local ({N_local}) must be >= block_size_n ({config.block_size_n}). " + assert N_shard >= config.block_size_n, ( + f"N_shard ({N_shard}) must be >= block_size_n ({config.block_size_n}). " f"Use smaller block sizes for small problems." ) # Allocate workspace if not provided if workspace is None: - workspace = matmul_all_scatter_preamble(shmem, A, B_local, config) + workspace = matmul_all_scatter_preamble(shmem, A, B_shard, config) stride_am, stride_ak = A.stride() - stride_bk, stride_bn = B_local.stride() - stride_cm_gathered, stride_cn_gathered = output_tensor.stride() + stride_bk, stride_bn = B_shard.stride() + stride_cm, stride_cn = output.stride() if bias is not None: assert bias.shape[0] == M @@ -226,7 +211,7 @@ def matmul_all_scatter( stride_bias = bias.stride()[0] if bias.dim() > 0 else 1 use_bias = True else: - bias_ptr = output_tensor + bias_ptr = output stride_bias = 1 use_bias = False @@ -242,21 +227,21 @@ def matmul_all_scatter( grid = (num_sms,) _fused_matmul_all_scatter_kernel[grid]( A, - B_local, - output_tensor, + B_shard, + output, bias_ptr, M, N, K, - N_local, + N_shard, stride_am, stride_ak, stride_bk, stride_bn, - stride_cm_gathered, - stride_cn_gathered, + stride_cm, + stride_cn, stride_bias, - shmem.get_heap_bases(), + shmem.get_device_context(), rank, world_size, config.block_size_m, diff --git a/iris/x/__init__.py b/iris/x/__init__.py index 7377fbe3..63a3a641 100644 --- a/iris/x/__init__.py +++ b/iris/x/__init__.py @@ -26,6 +26,10 @@ >>> ctx.all_gather(tile, src_view, dst_view, dim=0) >>> ctx.all_to_all(tile, src_view, dst_view, N_per_rank) >>> ctx.reduce_scatter(tile, src_view, dst_view) + >>> + >>> # Standalone all_scatter (each rank pushes its tile to all ranks) + >>> tile_with_data = iris.x.Tile(pid_m, pid_n, BLOCK_M, BLOCK_N, computed_data) + >>> iris.x.all_scatter(tile_with_data, dst_view, ctx) Example (API with AllReduceConfig for algorithm selection): >>> @triton.jit @@ -70,6 +74,7 @@ ) from .gather import gather from .all_gather import all_gather +from .all_scatter import all_scatter from .all_to_all import all_to_all from .reduce_scatter import reduce_scatter @@ -91,6 +96,7 @@ "all_reduce_spinlock", "gather", "all_gather", + "all_scatter", "all_to_all", "reduce_scatter", ] diff --git a/iris/x/all_scatter.py b/iris/x/all_scatter.py new file mode 100644 index 00000000..9b069252 --- /dev/null +++ b/iris/x/all_scatter.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tile-level all-scatter primitive for Iris. + +Each rank pushes its pre-computed tile to all other ranks at the rank's column +offset in the global output tensor. After the operation every rank holds the +full result. +""" + +import triton +import iris +from iris.iris import DeviceContext +from .core import Tile, TensorView + + +@triton.jit() +def all_scatter( + tile: Tile, + dst_view: TensorView, + ctx: DeviceContext, +): + """ + Tile-level all-scatter operation. + + Each rank scatters its pre-computed tile to all ranks (including itself) at + its column-stripe offset in the global output. Automatically derives + N_local from ``dst_view`` and ``ctx.world_size``. + + Args: + tile: Tile object containing the pre-computed data (e.g. a GEMM + result in registers). + dst_view: TensorView for the full output tensor of shape (M, N) where + ``N = N_local * world_size``. + ctx: DeviceContext carrying rank, world_size, and heap_bases. + + Layout: + Current rank's column stripe occupies + ``output[:, ctx.rank * N_local : (ctx.rank + 1) * N_local]`` + where ``N_local = dst_view.N // world_size``. + + Example:: + + tile_obj = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) + dst_view = iris.x.make_tensor_view(C, M, N, stride_cm, stride_cn) + iris.x.all_scatter(tile_obj, dst_view, ctx) + """ + N_local = dst_view.N // ctx.world_size + + # Scatter this rank's tile to all destination ranks + for dest_rank in range(ctx.world_size): + # Compute pointer at this rank's column-stripe offset + dst_ptr, combined_mask = dst_view.offset_tile_ptr(tile, offset_n=ctx.rank * N_local, src_mask=None) + + iris.store( + dst_ptr, + tile.data, + ctx.rank, + dest_rank, + ctx.heap_bases, + mask=combined_mask, + ) diff --git a/tests/ops/test_matmul_all_scatter.py b/tests/ops/test_matmul_all_scatter.py index faa8694d..9a29634d 100644 --- a/tests/ops/test_matmul_all_scatter.py +++ b/tests/ops/test_matmul_all_scatter.py @@ -61,16 +61,16 @@ def test_matmul_all_scatter(dtype, atol, rtol, M, N, K): # Create shmem tensors directly A = shmem.randn((M, K), dtype=dtype) - B_local = shmem.randn((K, N_local), dtype=dtype) + B_shard = shmem.randn((K, N_local), dtype=dtype) output = shmem.zeros((M, N), dtype=dtype) # Reference: compute local GEMM, then all-gather along N dimension A_ref = A.clone() - B_local_ref = B_local.clone() - C_local_ref = torch.matmul(A_ref, B_local_ref) - C_gathered_list = [torch.zeros(M, N_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] - dist.all_gather(C_gathered_list, C_local_ref) - pytorch_output = torch.cat(C_gathered_list, dim=1) # Concatenate along N dimension + B_shard_ref = B_shard.clone() + C_shard_ref = torch.matmul(A_ref, B_shard_ref) + C_shards = [torch.zeros(M, N_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(C_shards, C_shard_ref) + pytorch_output = torch.cat(C_shards, dim=1) # Concatenate along N dimension torch.cuda.synchronize() shmem.barrier() @@ -83,8 +83,6 @@ def test_matmul_all_scatter(dtype, atol, rtol, M, N, K): config = FusedConfig(block_size_m=32, block_size_n=32, block_size_k=32) elif M <= 128 or K <= 128 or N_local <= 128: config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) - elif dtype == torch.float32: - config = FusedConfig(block_size_m=128, block_size_n=128, block_size_k=64) else: config = FusedConfig(block_size_m=128, block_size_n=128, block_size_k=64) @@ -94,7 +92,7 @@ def test_matmul_all_scatter(dtype, atol, rtol, M, N, K): assert N_local >= config.block_size_n, f"N_local ({N_local}) must be >= block_size_n ({config.block_size_n})" # Use shmem.ops API with proper config - shmem.ops.matmul_all_scatter(output, A, B_local, config=config) + shmem.ops.matmul_all_scatter(output, A, B_shard, config=config) torch.cuda.synchronize() shmem.barrier() @@ -144,25 +142,28 @@ def test_matmul_all_scatter_semantics(dtype, atol, rtol): N_local = N // world_size + if N_local < 32: + pytest.skip(f"N_local={N_local} too small (need >= 32)") + A = shmem.randn((M, K), dtype=dtype) - B_local = shmem.randn((K, N_local), dtype=dtype) + B_shard = shmem.randn((K, N_local), dtype=dtype) output = shmem.zeros((M, N), dtype=dtype) # Reference - C_local_ref = torch.matmul(A.clone(), B_local.clone()) - C_gathered_list = [torch.zeros(M, N_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] - dist.all_gather(C_gathered_list, C_local_ref) - C_ref = torch.cat(C_gathered_list, dim=1) + C_shard_ref = torch.matmul(A.clone(), B_shard.clone()) + C_shards = [torch.zeros(M, N_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(C_shards, C_shard_ref) + C_ref = torch.cat(C_shards, dim=1) torch.cuda.synchronize() config = ops.FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) if N_local < config.block_size_n: - config = ops.FusedConfig(block_size_m=32, block_size_n=N_local, block_size_k=32) + pytest.skip(f"N_local={N_local} < block_size_n={config.block_size_n}, skipping") from iris.ops.matmul_all_scatter import matmul_all_scatter - matmul_all_scatter(shmem, output, A, B_local, config=config) + matmul_all_scatter(shmem, output, A, B_shard, config=config) torch.cuda.synchronize() shmem.barrier()