diff --git a/iris/ops/__init__.py b/iris/ops/__init__.py
index e0d12ba5..1a1fe05c 100644
--- a/iris/ops/__init__.py
+++ b/iris/ops/__init__.py
@@ -27,6 +27,7 @@
 - all_gather_matmul: All-Gather + GEMM
 - matmul_all_gather: GEMM + All-Gather
 - matmul_reduce_scatter: GEMM + Reduce-Scatter
+- matmul_all_scatter: GEMM + All-Scatter
 """
 
 from .config import FusedConfig
@@ -38,6 +39,7 @@
 from .all_gather_matmul import all_gather_matmul, all_gather_matmul_preamble
 from .matmul_all_gather import matmul_all_gather
 from .matmul_reduce_scatter import matmul_reduce_scatter, matmul_reduce_scatter_preamble
+from .matmul_all_scatter import matmul_all_scatter, matmul_all_scatter_preamble
 
 
 class OpsNamespace:
@@ -166,6 +168,36 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False,
         """
         return matmul_reduce_scatter(self._shmem, output_tensor, A, B, bias, async_op, config, workspace)
 
+    def matmul_all_scatter(self, output, A, B_shard, bias=None, async_op=False, config=None, workspace=None):
+        """
+        Fused matrix multiplication and all-scatter.
+
+        Computes: output = all_scatter(A @ B_shard) along N dimension
+
+        Each rank has B_shard of shape (K, N_shard) where N_shard = N / world_size.
+        The operation computes C_shard = A @ B_shard on each rank and scatters the
+        column stripe to all ranks so that every rank ends up with the full C (M, N).
+
+        Args:
+            output: Output tensor (M, N) where N = N_shard * world_size
+            A: Input matrix A (M, K) — replicated across ranks
+            B_shard: Column-sharded weight matrix (K, N_shard)
+            bias: Optional bias vector (M,)
+            async_op: If False, performs barrier at end
+            config: Optional FusedConfig for tuning
+            workspace: Optional pre-allocated workspace
+
+        Returns:
+            workspace: Updated workspace object
+
+        Example:
+            >>> N_shard = N // world_size
+            >>> B_shard = shmem.randn((K, N_shard), dtype=torch.float16)
+            >>> output = shmem.zeros((M, N), dtype=torch.float16)
+            >>> shmem.ops.matmul_all_scatter(output, A, B_shard)
+        """
+        return matmul_all_scatter(self._shmem, output, A, B_shard, bias, async_op, config, workspace)
+
 
 # Export public API
 __all__ = [
@@ -183,4 +215,6 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False,
     "matmul_all_gather",
     "matmul_reduce_scatter",
     "matmul_reduce_scatter_preamble",
+    "matmul_all_scatter",
+    "matmul_all_scatter_preamble",
 ]
diff --git a/iris/ops/matmul_all_scatter.py b/iris/ops/matmul_all_scatter.py
new file mode 100644
index 00000000..96ab721a
--- /dev/null
+++ b/iris/ops/matmul_all_scatter.py
@@ -0,0 +1,261 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+
+"""
+Fused GEMM + All-Scatter operation.
+
+Each rank has a column-sharded weight B_shard (K x N_shard) and a replicated
+input A (M x K). Each rank computes C_shard = A @ B_shard, then scatters its
+column stripe to all ranks so that every rank ends up with the full
+output C (M x N) where N = world_size * N_shard.
+
+This is useful for tensor-parallel workloads where weights are column-sharded
+and the full output is needed on all ranks.
+"""
+
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+import iris
+import iris.x
+
+from tritonblas.kernels.stages import GemmContext, ScheduleContext, make_tensor_view
+
+from .config import FusedConfig
+from .workspace import FusedWorkspace
+
+
+@triton.jit()
+def _fused_matmul_all_scatter_kernel(
+    A,  # (M, K) - replicated across ranks
+    B_shard,  # (K, N_shard) - each rank's column shard of weight matrix B
+    C,  # (M, N) - output, N = N_shard * world_size
+    bias_ptr,
+    M,
+    N,
+    K,
+    N_shard,
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_bias,
+    context_tensor: tl.tensor,
+    cur_rank: tl.constexpr,
+    world_size: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+    NUM_XCDS: tl.constexpr,
+    BIAS: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    ALLOW_TF32: tl.constexpr,
+):
+    """
+    Fused GEMM + all-scatter kernel.
+
+    Computes local GEMM tile and immediately scatters this rank's column stripe
+    to all ranks via ``iris.x.all_scatter``. No intermediate buffer needed.
+    """
+    # ═══════════════════════════════════════════════════════════════════════
+    # Create tritonblas views, context, and scheduler for GEMM
+    # ═══════════════════════════════════════════════════════════════════════
+    view_A = make_tensor_view(A, M, K, stride_am, stride_ak)
+    view_B = make_tensor_view(B_shard, K, N_shard, stride_bk, stride_bn)
+    gemm_ctx = GemmContext(
+        BLOCK_SIZE_M,
+        BLOCK_SIZE_N,
+        BLOCK_SIZE_K,
+        num_sms=NUM_SMS,
+        num_xcds=NUM_XCDS,
+        group_size_m=GROUP_SIZE_M,
+        even_k=EVEN_K,
+        allow_tf32=ALLOW_TF32,
+    )
+    sched = ScheduleContext(M, N_shard, K, gemm_ctx)
+    ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size)
+    dst_view = iris.x.make_tensor_view(C, M, N, stride_cm, stride_cn)
+
+    # Persistent loop over local tiles using scheduler
+    start, total, stride = sched.persistent_tile_range()
+    for tile_id in range(start, total, stride):
+        # Get tile coordinates with swizzling from scheduler
+        out_tile = sched.get_tile_from_idx(tile_id)
+
+        # GEMM using tritonblas stages
+        acc = gemm_ctx.reduce_axis(view_A, view_B, out_tile)
+
+        # Add bias if provided
+        if BIAS:
+            rm, _ = out_tile.indices()
+            bias_vec = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0)
+            acc = acc + bias_vec[:, None]
+
+        # Convert to output dtype
+        c_tile = acc.to(C.type.element_ty)
+
+        # Wrap result in a Tile object and scatter to all ranks
+        tile_obj = iris.x.Tile(out_tile.pid_m, out_tile.pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c_tile)
+        iris.x.all_scatter(tile_obj, dst_view, ctx)
+
+
+def matmul_all_scatter_preamble(
+    shmem,
+    A: torch.Tensor,
+    B_shard: torch.Tensor,
+    config: Optional[FusedConfig] = None,
+) -> FusedWorkspace:
+    """Allocate workspace for matmul_all_scatter (none needed for scatter pattern)."""
+    if config is None:
+        config = FusedConfig()
+
+    M, K = A.shape
+    K2, N_shard = B_shard.shape
+    world_size = shmem.get_num_ranks()
+
+    assert K == K2, f"Inner dimensions must match: A has K={K}, B_shard has K={K2}"
+
+    N = N_shard * world_size
+
+    return FusedWorkspace(
+        operation="matmul_all_scatter",
+        shape=(M, N, K),
+        dtype=A.dtype,
+        world_size=world_size,
+        prepared=True,
+    )
+
+
+def matmul_all_scatter(
+    shmem,
+    output: torch.Tensor,
+    A: torch.Tensor,
+    B_shard: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    async_op: bool = False,
+    config: Optional[FusedConfig] = None,
+    workspace: Optional[FusedWorkspace] = None,
+) -> FusedWorkspace:
+    """
+    Fused matrix multiplication and all-scatter.
+
+    Computes: output = all_scatter(A @ B_shard) along N dimension
+
+    Each rank has B_shard of shape (K, N_shard) where N_shard = N / world_size.
+    The operation computes C_shard = A @ B_shard on each rank and immediately
+    scatters its column stripe to all ranks via ``iris.x.all_scatter``,
+    so that every rank ends up with the full C (M, N).
+
+    This is a single-kernel implementation — no intermediate buffer needed.
+
+    Args:
+        shmem: Iris shmem context
+        output: Output tensor C of shape (M, N) where N = N_shard * world_size
+        A: Input matrix A of shape (M, K) — replicated across ranks
+        B_shard: Column-sharded weight matrix of shape (K, N_shard)
+        bias: Optional bias vector of shape (M,)
+        async_op: If False, performs barrier at end
+        config: Optional FusedConfig for tuning
+        workspace: Optional pre-allocated workspace
+
+    Returns:
+        FusedWorkspace object
+
+    Example:
+        >>> N_shard = N // world_size
+        >>> B_shard = shmem.randn((K, N_shard), dtype=torch.float16)
+        >>> output = shmem.zeros((M, N), dtype=torch.float16)
+        >>> shmem.ops.matmul_all_scatter(output, A, B_shard)
+    """
+    if config is None:
+        config = FusedConfig()
+
+    M, K = A.shape
+    K2, N_shard = B_shard.shape
+    world_size = shmem.get_num_ranks()
+    rank = shmem.get_rank()
+
+    assert K == K2, f"Inner dimensions must match: A has K={K}, B_shard has K={K2}"
+
+    N = N_shard * world_size
+    assert output.shape == (M, N), f"Output must be ({M}, {N}), got {output.shape}"
+
+    # Validate problem size against block sizes
+    assert M >= config.block_size_m, (
+        f"M ({M}) must be >= block_size_m ({config.block_size_m}). Use smaller block sizes for small problems."
+    )
+    assert K >= config.block_size_k, (
+        f"K ({K}) must be >= block_size_k ({config.block_size_k}). Use smaller block sizes for small problems."
+    )
+    assert N_shard >= config.block_size_n, (
+        f"N_shard ({N_shard}) must be >= block_size_n ({config.block_size_n}). "
+        f"Use smaller block sizes for small problems."
+    )
+
+    # Allocate workspace if not provided
+    if workspace is None:
+        workspace = matmul_all_scatter_preamble(shmem, A, B_shard, config)
+
+    stride_am, stride_ak = A.stride()
+    stride_bk, stride_bn = B_shard.stride()
+    stride_cm, stride_cn = output.stride()
+
+    if bias is not None:
+        assert bias.shape[0] == M
+        bias_ptr = bias
+        stride_bias = bias.stride()[0] if bias.dim() > 0 else 1
+        use_bias = True
+    else:
+        bias_ptr = output
+        stride_bias = 1
+        use_bias = False
+
+    device = A.device
+    num_sms = config.num_sms
+    if num_sms is None:
+        props = torch.cuda.get_device_properties(device)
+        num_sms = props.multi_processor_count
+
+    even_k = K % config.block_size_k == 0
+
+    # Launch single fused kernel
+    grid = (num_sms,)
+    _fused_matmul_all_scatter_kernel[grid](
+        A,
+        B_shard,
+        output,
+        bias_ptr,
+        M,
+        N,
+        K,
+        N_shard,
+        stride_am,
+        stride_ak,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        stride_bias,
+        shmem.get_device_context(),
+        rank,
+        world_size,
+        config.block_size_m,
+        config.block_size_n,
+        config.block_size_k,
+        config.group_size_m,
+        num_sms,
+        config.num_xcds,
+        use_bias,
+        even_k,
+        config.allow_tf32,
+    )
+
+    if not async_op:
+        shmem.barrier()
+
+    return workspace
diff --git a/iris/x/__init__.py b/iris/x/__init__.py
index 7377fbe3..63a3a641 100644
--- a/iris/x/__init__.py
+++ b/iris/x/__init__.py
@@ -26,6 +26,10 @@
     >>> ctx.all_gather(tile, src_view, dst_view, dim=0)
     >>> ctx.all_to_all(tile, src_view, dst_view, N_per_rank)
     >>> ctx.reduce_scatter(tile, src_view, dst_view)
+    >>>
+    >>> # Standalone all_scatter (each rank pushes its tile to all ranks)
+    >>> tile_with_data = iris.x.Tile(pid_m, pid_n, BLOCK_M, BLOCK_N, computed_data)
+    >>> iris.x.all_scatter(tile_with_data, dst_view, ctx)
 
 Example (API with AllReduceConfig for algorithm selection):
     >>> @triton.jit
@@ -70,6 +74,7 @@
 )
 from .gather import gather
 from .all_gather import all_gather
+from .all_scatter import all_scatter
 from .all_to_all import all_to_all
 from .reduce_scatter import reduce_scatter
 
@@ -91,6 +96,7 @@
     "all_reduce_spinlock",
     "gather",
     "all_gather",
+    "all_scatter",
     "all_to_all",
     "reduce_scatter",
 ]
diff --git a/iris/x/all_scatter.py b/iris/x/all_scatter.py
new file mode 100644
index 00000000..9b069252
--- /dev/null
+++ b/iris/x/all_scatter.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
+
+"""
+Tile-level all-scatter primitive for Iris.
+
+Each rank pushes its pre-computed tile to all ranks at the rank's column
+offset in the global output tensor. After the operation every rank holds the
+full result.
+"""
+
+import triton
+
+import iris
+from iris.iris import DeviceContext
+
+from .core import Tile, TensorView
+
+
+@triton.jit()
+def all_scatter(
+    tile: Tile,
+    dst_view: TensorView,
+    ctx: DeviceContext,
+):
+    """
+    Tile-level all-scatter operation.
+
+    Each rank scatters its pre-computed tile to all ranks (including itself) at
+    its column-stripe offset in the global output. Automatically derives
+    N_local from ``dst_view`` and ``ctx.world_size``.
+
+    Args:
+        tile: Tile object containing the pre-computed data (e.g. a GEMM
+            result in registers).
+        dst_view: TensorView for the full output tensor of shape (M, N) where
+            ``N = N_local * world_size``.
+        ctx: DeviceContext carrying rank, world_size, and heap_bases.
+
+    Layout:
+        Current rank's column stripe occupies
+        ``output[:, ctx.rank * N_local : (ctx.rank + 1) * N_local]``
+        where ``N_local = dst_view.N // world_size``.
+
+    Example::
+
+        tile_obj = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c)
+        dst_view = iris.x.make_tensor_view(C, M, N, stride_cm, stride_cn)
+        iris.x.all_scatter(tile_obj, dst_view, ctx)
+    """
+    N_local = dst_view.N // ctx.world_size
+
+    # The destination pointer depends only on this rank's stripe, so hoist it.
+    dst_ptr, combined_mask = dst_view.offset_tile_ptr(tile, offset_n=ctx.rank * N_local, src_mask=None)
+
+    # Scatter this rank's tile to all destination ranks (including itself)
+    for dest_rank in range(ctx.world_size):
+        iris.store(
+            dst_ptr,
+            tile.data,
+            ctx.rank,
+            dest_rank,
+            ctx.heap_bases,
+            mask=combined_mask,
+        )
diff --git a/tests/ops/test_matmul_all_scatter.py b/tests/ops/test_matmul_all_scatter.py
new file mode 100644
index 00000000..9a29634d
--- /dev/null
+++ b/tests/ops/test_matmul_all_scatter.py
@@ -0,0 +1,182 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+
+"""
+Test suite for high-level matmul_all_scatter API.
+
+Note: This test requires tritonBLAS to be installed.
+Install with: pip install git+https://github.com/ROCm/tritonBLAS.git
+"""
+
+import pytest
+import torch
+import torch.distributed as dist
+
+import iris
+import iris.ops as ops
+
+
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        (torch.float16, 0.5, 0.01),
+        (torch.bfloat16, 0.5, 0.01),
+    ],
+)
+@pytest.mark.parametrize(
+    "M, N, K",
+    [
+        (64, 64, 32),
+        (512, 256, 512),
+        (1024, 2048, 1024),
+    ],
+)
+def test_matmul_all_scatter(dtype, atol, rtol, M, N, K):
+    """Test matmul_all_scatter using shmem.ops API with proper config.
+
+    Validates against a PyTorch reference: local GEMM on each rank followed by
+    all_gather to concatenate column shards along the N dimension.
+    """
+    if not dist.is_initialized():
+        pytest.skip("torch.distributed not initialized")
+
+    heap_size = 2**33  # 8GB
+    shmem = iris.iris(heap_size)
+    rank = shmem.get_rank()
+    world_size = shmem.get_num_ranks()
+
+    # N must be divisible by world_size for column-wise sharding
+    if N % world_size != 0:
+        pytest.skip(f"N={N} not divisible by world_size={world_size}")
+
+    N_local = N // world_size
+
+    # Skip if problem size is too small for world_size
+    min_block_size = 32  # Smallest block size we use
+    if N_local < min_block_size:
+        pytest.skip(f"N_local={N_local} too small for world_size={world_size} (need >= {min_block_size})")
+    if K < min_block_size:
+        pytest.skip(f"K={K} too small (need >= {min_block_size})")
+    if M < min_block_size:
+        pytest.skip(f"M={M} too small (need >= {min_block_size})")
+
+    # Create shmem tensors directly
+    A = shmem.randn((M, K), dtype=dtype)
+    B_shard = shmem.randn((K, N_local), dtype=dtype)
+    output = shmem.zeros((M, N), dtype=dtype)
+
+    # Reference: compute local GEMM, then all-gather along N dimension
+    A_ref = A.clone()
+    B_shard_ref = B_shard.clone()
+    C_shard_ref = torch.matmul(A_ref, B_shard_ref)
+    C_shards = [torch.zeros(M, N_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)]
+    dist.all_gather(C_shards, C_shard_ref)
+    pytorch_output = torch.cat(C_shards, dim=1)  # Concatenate along N dimension
+    torch.cuda.synchronize()
+
+    shmem.barrier()
+
+    # Use appropriate block sizes based on problem size
+    from iris.ops.config import FusedConfig
+
+    # Select config based on actual problem dimensions
+    if M <= 64 or K <= 64 or N_local <= 64:
+        config = FusedConfig(block_size_m=32, block_size_n=32, block_size_k=32)
+    elif M <= 128 or K <= 128 or N_local <= 128:
+        config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32)
+    else:
+        config = FusedConfig(block_size_m=128, block_size_n=128, block_size_k=64)
+
+    # Validate config against problem size
+    assert M >= config.block_size_m, f"M ({M}) must be >= block_size_m ({config.block_size_m})"
+    assert K >= config.block_size_k, f"K ({K}) must be >= block_size_k ({config.block_size_k})"
+    assert N_local >= config.block_size_n, f"N_local ({N_local}) must be >= block_size_n ({config.block_size_n})"
+
+    # Use shmem.ops API with proper config
+    shmem.ops.matmul_all_scatter(output, A, B_shard, config=config)
+
+    torch.cuda.synchronize()
+    shmem.barrier()
+
+    max_diff = torch.abs(output - pytorch_output).max().item()
+
+    assert torch.allclose(output, pytorch_output, atol=atol, rtol=rtol), (
+        f"Max difference: {max_diff}, expected < {atol}\n"
+        f"Rank {rank}: shmem.ops.matmul_all_scatter output doesn't match reference"
+    )
+
+    if rank == 0:
+        print(f"✓ matmul_all_scatter test passed: {dtype}, M={M}, N={N}, K={K}")
+
+    shmem.barrier()
+    del shmem
+    import gc
+
+    gc.collect()
+
+
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        (torch.float16, 0.5, 0.01),
+        (torch.bfloat16, 0.5, 0.01),
+    ],
+)
+def test_matmul_all_scatter_semantics(dtype, atol, rtol):
+    """
+    Test that matmul_all_scatter is equivalent to:
+        C_local = A @ B_local                  (on each rank)
+        output = all_gather(C_local, dim=1)    (concatenate along N)
+    """
+    if not dist.is_initialized():
+        pytest.skip("torch.distributed not initialized")
+
+    heap_size = 2**33
+    shmem = iris.iris(heap_size)
+    rank = shmem.get_rank()
+    world_size = shmem.get_num_ranks()
+
+    M, N, K = 128, 128, 32
+
+    if N % world_size != 0:
+        pytest.skip(f"N={N} not divisible by world_size={world_size}")
+
+    N_local = N // world_size
+
+    if N_local < 32:
+        pytest.skip(f"N_local={N_local} too small (need >= 32)")
+
+    A = shmem.randn((M, K), dtype=dtype)
+    B_shard = shmem.randn((K, N_local), dtype=dtype)
+    output = shmem.zeros((M, N), dtype=dtype)
+
+    # Reference
+    C_shard_ref = torch.matmul(A.clone(), B_shard.clone())
+    C_shards = [torch.zeros(M, N_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)]
+    dist.all_gather(C_shards, C_shard_ref)
+    C_ref = torch.cat(C_shards, dim=1)
+    torch.cuda.synchronize()
+
+    config = ops.FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32)
+
+    if N_local < config.block_size_n:
+        pytest.skip(f"N_local={N_local} < block_size_n={config.block_size_n}, skipping")
+
+    from iris.ops.matmul_all_scatter import matmul_all_scatter
+
+    matmul_all_scatter(shmem, output, A, B_shard, config=config)
+
+    torch.cuda.synchronize()
+    shmem.barrier()
+
+    assert torch.allclose(output, C_ref, atol=atol, rtol=rtol), (
+        f"Rank {rank}: matmul_all_scatter semantics mismatch. Max diff: {torch.abs(output - C_ref).max().item()}"
+    )
+
+    if rank == 0:
+        print("matmul_all_scatter semantics verified")
+
+    shmem.barrier()
+    del shmem
+    import gc
+
+    gc.collect()