diff --git a/examples/22_rs_rmsnorm_fp8quant_ag/README.md b/examples/22_rs_rmsnorm_fp8quant_ag/README.md
new file mode 100644
index 000000000..84c657ff4
--- /dev/null
+++ b/examples/22_rs_rmsnorm_fp8quant_ag/README.md
@@ -0,0 +1,19 @@
+
+
+# Reduce-Scatter → RMSNorm → FP8 Quantization → All-Gather benchmark using Iris
+This example implements a complete tensor processing pipeline across multiple GPUs:
+
+1. **Reduce-Scatter**: Sum the tensors element-wise across all GPUs and give each GPU one shard of rows
+2. **RMSNorm**: Apply Root Mean Square normalization to each shard
+3. **FP8 Quantization**: Quantize to 8-bit floating point (optional)
+4. **All-Gather**: Reconstruct the full tensor across all GPUs (optional)
+
+## Usage
+
+```terminal
+python benchmark.py --num_rows 8192 --num_cols 7168 --num_ranks 8 --benchmark --fp8_out --all_gather --BLOCK_M 16 --BLOCK_N 64 --num_warps 16 --num_stages 4 --waves_per_eu 4 --rmsnorm_block_size 1024 --rmsnorm_num_warps 8 --rmsnorm_num_prgms 1024 --rmsnorm_waves_per_eu 2 --fp8_block_m 64 --fp8_block_n 64 --fp8_num_warps 4 --fp8_num_stages 2 --fp8_waves_per_eu 2 --ag_block_m 64 --ag_block_n 64 --ag_num_warps 8 --ag_num_stages 3 --ag_waves_per_eu 2 --validate
+```
+
+The benchmark times each stage with CUDA events. For the reduce-scatter stage, each GPU pulls its shard from every other GPU with `iris.load`, and interconnect bandwidth is computed as the data received from remote GPUs divided by the measured time; the all-gather stage is measured the same way, except each GPU pushes its shard to the other GPUs with `iris.put`.
diff --git a/examples/22_rs_rmsnorm_fp8quant_ag/benchmark.py b/examples/22_rs_rmsnorm_fp8quant_ag/benchmark.py
new file mode 100644
index 000000000..ad331a346
--- /dev/null
+++ b/examples/22_rs_rmsnorm_fp8quant_ag/benchmark.py
@@ -0,0 +1,1177 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+
+"""
+Benchmark for the Reduce-Scatter → RMSNorm → FP8 Quantization → All-Gather pipeline.
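+
+Pipeline stages (per rank):
+
+1. Reduce-scatter along M: each rank pulls its row shard from every peer with iris.load and sums in FP32
+2. RMSNorm over the full N dimension of the local (M / world_size) x N shard
+3. Optional FP8 (e4m3) quantization of the normalized shard
+4. Optional all-gather along M: each rank pushes its shard to every peer with iris.put
+
+When --benchmark is set, each stage is timed with CUDA events. Interconnect
+bandwidth for the communication stages counts only the (world_size - 1) remote
+shards, since the local shard never crosses the interconnect.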
+""" + +import argparse +import json +import os +import random +import sys +import time + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton + +import iris + +# Import kernels from reduce_scatter_rmsnorm_quant.py +from reduce_scatter_rmsnorm_quant import ( + reduce_scatter_m_kernel, + all_gather_m_kernel, + aiter_rmsnorm, + quantize_fp8_kernel, +) + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark Reduce-Scatter → RMSNorm → FP8 Quantization", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--num_rows", type=int, default=2048, help="Number of rows (M), must be divisible by num_ranks") + parser.add_argument("--num_cols", type=int, default=2048, help="Number of columns (N)") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Data type for input/intermediate values", + ) + parser.add_argument("--fp8_out", action="store_true", help="Enable FP8 quantization after RMSNorm") + parser.add_argument("--eps", type=float, default=1e-6, help="RMSNorm epsilon for numerical stability") + parser.add_argument( + "--all_gather", action="store_true", help="Perform all-gather to reconstruct full M×N tensor across all ranks" + ) + parser.add_argument( + "--validate", action="store_true", help="Validate results against PyTorch reference implementation" + ) + parser.add_argument("--benchmark", action="store_true", help="Run performance benchmarks with GPU-side timing") + parser.add_argument("--warmup", type=int, default=10, help="Number of warmup iterations for benchmarking") + parser.add_argument("--iters", type=int, default=100, help="Number of timed iterations for benchmarking") + parser.add_argument( + "--output_file", + type=str, + default="rs_rmsnorm_results.json", + help="Output JSON file for results", + ) + parser.add_argument("--num_ranks", type=int, default=8, help="Number of ranks/GPUs") + parser.add_argument("--heap_size", type=int, default=0, help="IRIS heap size in bytes (0=auto, default: 2GB)") + parser.add_argument("--heap_size_gb", type=float, default=None, help="IRIS heap size in GB (overrides --heap_size)") + parser.add_argument("--BLOCK_M", type=int, default=16, help="Block size M") + parser.add_argument("--BLOCK_N", type=int, default=32, help="Block size N") + parser.add_argument("--GROUP_SIZE_M", type=int, default=8, help="Tile swizzle group size") + parser.add_argument("--NUM_SMS", type=int, default=None, help="Number of CUs (auto-detect if None)") + parser.add_argument("--num_warps", type=int, default=8, help="Number of warps per thread block (reduce-scatter)") + parser.add_argument("--num_stages", type=int, default=2, help="Number of pipeline stages (reduce-scatter)") + parser.add_argument("--waves_per_eu", type=int, default=0, help="Waves per execution unit (reduce-scatter, 0=auto)") + + # RMSNorm specific parameters + parser.add_argument("--rmsnorm_block_size", type=int, default=None, help="RMSNorm BLOCK_SIZE (auto-detect if None)") + parser.add_argument("--rmsnorm_num_warps", type=int, default=None, help="RMSNorm num_warps (default: 8)") + parser.add_argument("--rmsnorm_num_prgms", type=int, default=None, help="RMSNorm NUM_PRGMS (default: M_shard)") + parser.add_argument("--rmsnorm_waves_per_eu", type=int, default=None, help="RMSNorm waves_per_eu (default: 2)") + + # FP8 Quantization specific parameters + parser.add_argument( + "--fp8_block_m", type=int, 
default=None, help="FP8 BLOCK_M (default: same as reduce-scatter BLOCK_M)" + ) + parser.add_argument( + "--fp8_block_n", type=int, default=None, help="FP8 BLOCK_N (default: same as reduce-scatter BLOCK_N)" + ) + parser.add_argument("--fp8_num_warps", type=int, default=None, help="FP8 num_warps (default: 4)") + parser.add_argument("--fp8_num_stages", type=int, default=None, help="FP8 num_stages (default: 2)") + parser.add_argument("--fp8_waves_per_eu", type=int, default=None, help="FP8 waves_per_eu (default: 0)") + + # All-Gather specific parameters + parser.add_argument( + "--ag_block_m", type=int, default=None, help="All-Gather BLOCK_M (default: same as reduce-scatter)" + ) + parser.add_argument( + "--ag_block_n", type=int, default=None, help="All-Gather BLOCK_N (default: same as reduce-scatter)" + ) + parser.add_argument("--ag_num_warps", type=int, default=None, help="All-Gather num_warps (default: 4)") + parser.add_argument("--ag_num_stages", type=int, default=None, help="All-Gather num_stages (default: 2)") + parser.add_argument("--ag_waves_per_eu", type=int, default=None, help="All-Gather waves_per_eu (default: 0)") + + return vars(parser.parse_args()) + + +def run_reduce_scatter( + input_tensor, + M, + M_shard, + N, + rank, + world_size, + heap_bases, + BLOCK_M, + BLOCK_N, + GROUP_SIZE_M, + NUM_SMS, + num_warps, + num_stages, + waves_per_eu, + dtype, + device, + shmem=None, + output_buffer=None, +): + """Run reduce-scatter operation with pull-based iris.load approach.""" + # Use provided output buffer or allocate new one + if output_buffer is not None: + reduced_shard = output_buffer + elif shmem is not None: + reduced_shard = shmem.zeros((M_shard, N), dtype=dtype) + else: + # Fallback - but this won't work with IRIS operations! + raise ValueError("IRIS operations require output_buffer in IRIS shared memory") + + grid_rs = (NUM_SMS,) + + # Call kernel once - it will pull data from all source ranks using iris.load + reduce_scatter_m_kernel[grid_rs]( + input_tensor, + reduced_shard, + M, + M_shard, + N, + input_tensor.stride(0), + input_tensor.stride(1), + reduced_shard.stride(0), + reduced_shard.stride(1), + rank, + world_size, + heap_bases, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + GROUP_SIZE_M=GROUP_SIZE_M, + NUM_SMS=NUM_SMS, + num_warps=num_warps, + num_stages=num_stages, + waves_per_eu=waves_per_eu, + ) + + # Synchronize to ensure all loads and reductions complete + torch.cuda.synchronize() + if shmem is not None: + shmem.barrier() + + return reduced_shard + + +def run_rmsnorm(input_tensor, eps, device, block_size=None, num_warps=None, num_prgms=None, waves_per_eu=None): + """Run RMSNorm operation using AITer kernel.""" + M_shard, N = input_tensor.shape + dtype = input_tensor.dtype + + gamma = torch.ones(N, device=device, dtype=dtype) + output = torch.empty_like(input_tensor) + rsigma = torch.empty(M_shard, device=device, dtype=dtype) + + # Auto-detect BLOCK_SIZE + if block_size is None: + element_size = input_tensor.element_size() + max_block_size = 65536 // element_size + BLOCK_SIZE = min(max_block_size, triton.next_power_of_2(N)) + else: + BLOCK_SIZE = block_size + + # Always auto-detect USE_BLOCKED based on N and BLOCK_SIZE + USE_BLOCKED = N > BLOCK_SIZE + + # Set NUM_PRGMS (default to M_shard for full parallelism) + NUM_PRGMS = num_prgms if num_prgms is not None else M_shard + + # Set num_warps (default to 8) + final_num_warps = num_warps if num_warps is not None else 8 + + # Set waves_per_eu (default to 2) + final_waves_per_eu = waves_per_eu if waves_per_eu is not None else 2 
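+
+    # Launch one persistent program per row of the shard: each program handles
+    # rows row_start, row_start + NUM_PRGMS, ..., so NUM_PRGMS == M_shard gives
+    # exactly one row per program. USE_BLOCKED switches to a multi-pass column
+    # loop whenever a full row of N elements does not fit in one BLOCK_SIZE tile.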
+ + aiter_rmsnorm[(M_shard,)]( + input_tensor, + output, + gamma, + rsigma, + input_tensor.stride(0), + output.stride(0), + M_shard, + N, + eps, + BLOCK_SIZE=BLOCK_SIZE, + USE_BLOCKED=USE_BLOCKED, + NUM_PRGMS=NUM_PRGMS, + num_warps=final_num_warps, + waves_per_eu=final_waves_per_eu, + ) + + return output + + +def run_quantize_fp8(input_tensor, BLOCK_M, BLOCK_N, device, shmem=None): + """Run FP8 quantization.""" + M_shard, N = input_tensor.shape + + max_val = input_tensor.abs().max().item() + scale = max(max_val / 448.0, 1e-8) + scale_tensor = torch.tensor([scale], device=device, dtype=torch.float32) + + # Allocate output - always in regular CUDA memory for FP8 (IRIS may not support FP8) + if hasattr(torch, "float8_e4m3fn"): + output = torch.empty(M_shard, N, device=device, dtype=torch.float8_e4m3fn) + else: + output = torch.empty_like(input_tensor) + + grid = (triton.cdiv(M_shard, BLOCK_M), triton.cdiv(N, BLOCK_N)) + + quantize_fp8_kernel[grid]( + input_tensor, + output, + scale_tensor, + M_shard, + N, + input_tensor.stride(0), + input_tensor.stride(1), + output.stride(0), + output.stride(1), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=16, + waves_per_eu=2, + ) + + return output, scale + + +def run_all_gather( + shard, + M, + M_shard, + N, + rank, + world_size, + heap_bases, + shmem, + BLOCK_M, + BLOCK_N, + GROUP_SIZE_M, + NUM_SMS, + device, + output_buffer=None, +): + """Run all-gather operation.""" + dtype = shard.dtype + + # Use provided output buffer or allocate new one + if output_buffer is not None: + full_output = output_buffer + else: + # Allocate output in IRIS shared memory for remote writes + full_output = shmem.empty((M, N), dtype=dtype) + + grid = (NUM_SMS,) + + all_gather_m_kernel[grid]( + shard, + full_output, + M, + M_shard, + N, + shard.stride(0), + shard.stride(1), + full_output.stride(0), + full_output.stride(1), + rank, + world_size, + heap_bases, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + GROUP_SIZE_M=GROUP_SIZE_M, + NUM_SMS=NUM_SMS, + num_warps=8, + waves_per_eu=2, + ) + + return full_output + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for distributed execution.""" + # Parse arguments + M = args["num_rows"] + N = args["num_cols"] + + assert M % world_size == 0, f"M ({M}) must be divisible by world_size ({world_size})" + M_shard = M // world_size + + # Datatype + dtype_map = { + "fp16": torch.float16, + "fp32": torch.float32, + "bf16": torch.bfloat16, + } + dtype = dtype_map[args["datatype"]] + + # Calculate heap size if auto (0) or use provided value + if args.get("heap_size_gb") is not None: + # User specified GB + heap_size = int(args["heap_size_gb"] * (1024**3)) + elif args["heap_size"] == 0: + # Auto-calculate based on problem size + bytes_per_element = 2 if dtype in [torch.float16, torch.bfloat16] else 4 + fp8_bytes_per_element = 1 + + # Validation allocations: + mem_input = M * N * bytes_per_element # input_tensor + mem_rs_output = M_shard * N * bytes_per_element # reduced_shard + mem_rmsnorm = M_shard * N * bytes_per_element # rmsnorm_output + mem_fp8 = M_shard * N * fp8_bytes_per_element if args["fp8_out"] else 0 # quantized_output (as uint8) + mem_ag_output = ( + M * N * (fp8_bytes_per_element if args["fp8_out"] else bytes_per_element) if args["all_gather"] else 0 + ) + + # Benchmark allocations (if enabled): + if args.get("benchmark"): + mem_test_input = M * N * bytes_per_element # test_input + mem_test_rs = 2 * M_shard * N * bytes_per_element # test_reduced_shard (2x size) + mem_test_rmsnorm = 
M_shard * N * bytes_per_element # rmsnorm_output_bench + mem_test_fp8 = M_shard * N * fp8_bytes_per_element if args["fp8_out"] else 0 + mem_test_ag = ( + M * N * (fp8_bytes_per_element if args["fp8_out"] else bytes_per_element) if args["all_gather"] else 0 + ) + else: + mem_test_input = mem_test_rs = mem_test_rmsnorm = mem_test_fp8 = mem_test_ag = 0 + + total_mem = ( + mem_input + + mem_rs_output + + mem_rmsnorm + + mem_fp8 + + mem_ag_output + + mem_test_input + + mem_test_rs + + mem_test_rmsnorm + + mem_test_fp8 + + mem_test_ag + ) + + # Add 20% overhead for alignment (1KB per allocation) and safety margin + heap_size = int(total_mem * 1.2) + + # Ensure minimum 1GB + heap_size = max(heap_size, 1 << 30) + else: + heap_size = args["heap_size"] + + # Use gloo backend to avoid below warning for now + # backend = "nccl" if torch.cuda.is_available() else "gloo" + # /opt/venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4814: + # UserWarning: No device id is provided via `init_process_group` or `barrier + # `. Using the current device set by the user. + backend = "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + ) + + # Initialize IRIS with calculated heap size + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size_iris = shmem.get_num_ranks() + + assert world_size == world_size_iris, f"World size mismatch: {world_size} != {world_size_iris}" + + # Set device + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + + # Auto-detect NUM_SMS if not provided + if args["NUM_SMS"] is None: + cu_count = torch.cuda.get_device_properties(local_rank).multi_processor_count + NUM_SMS = cu_count + else: + NUM_SMS = args["NUM_SMS"] + + BLOCK_M = args["BLOCK_M"] + BLOCK_N = args["BLOCK_N"] + GROUP_SIZE_M = args["GROUP_SIZE_M"] + num_warps = args["num_warps"] + num_stages = args["num_stages"] + waves_per_eu = args["waves_per_eu"] + + # RMSNorm parameters - extract from args if they exist + rmsnorm_block_size = args.get("rmsnorm_block_size") + rmsnorm_num_warps = args.get("rmsnorm_num_warps") + rmsnorm_num_prgms = args.get("rmsnorm_num_prgms") + rmsnorm_waves_per_eu = args.get("rmsnorm_waves_per_eu") + + # FP8 Quantization parameters + fp8_block_m = args.get("fp8_block_m") + fp8_block_n = args.get("fp8_block_n") + fp8_num_warps = args.get("fp8_num_warps") + fp8_num_stages = args.get("fp8_num_stages") + fp8_waves_per_eu = args.get("fp8_waves_per_eu") + + # All-Gather parameters + ag_block_m = args.get("ag_block_m") + ag_block_n = args.get("ag_block_n") + ag_num_warps = args.get("ag_num_warps") + ag_num_stages = args.get("ag_num_stages") + ag_waves_per_eu = args.get("ag_waves_per_eu") + + if rank == 0: + print("Configuration:") + print(f" M={M}, N={N}, M_shard={M_shard}") + print(f" dtype={dtype}, world_size={world_size}") + print(" Reduce-Scatter:") + print(f" BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}, GROUP_SIZE_M={GROUP_SIZE_M}, NUM_SMS={NUM_SMS}") + print(f" num_warps={num_warps}, num_stages={num_stages}, waves_per_eu={waves_per_eu}") + print(" RMSNorm Parameters:") + print(f" BLOCK_SIZE: {rmsnorm_block_size or 'auto'}") + print(" USE_BLOCKED: auto (N > BLOCK_SIZE)") + print(f" num_warps: {rmsnorm_num_warps or 8}") + print(f" NUM_PRGMS: {rmsnorm_num_prgms or M_shard}") + print(f" waves_per_eu: {rmsnorm_waves_per_eu if rmsnorm_waves_per_eu is not None else 2}") + print(" FP8 Quantization Parameters:") + print(f" BLOCK_M: {fp8_block_m or BLOCK_M}") + print(f" BLOCK_N: {fp8_block_n or 
BLOCK_N}") + print(f" num_warps: {fp8_num_warps or 4}") + print(f" num_stages: {fp8_num_stages or 2}") + print(f" waves_per_eu: {fp8_waves_per_eu if fp8_waves_per_eu is not None else 0}") + print(" All-Gather Parameters:") + print(f" BLOCK_M: {ag_block_m or BLOCK_M}") + print(f" BLOCK_N: {ag_block_n or BLOCK_N}") + print(f" num_warps: {ag_num_warps or 4}") + print(f" num_stages: {ag_num_stages or 2}") + print(f" waves_per_eu: {ag_waves_per_eu if ag_waves_per_eu is not None else 0}") + print(f" FP8 output: {args['fp8_out']}") + print(f" All-gather: {args['all_gather']}") + + # Calculate memory requirements (should match auto-calculation logic) + bytes_per_element = 2 if dtype in [torch.float16, torch.bfloat16] else 4 + fp8_bytes_per_element = 1 + + # Validation memory: + mem_input = M * N * bytes_per_element + mem_rs_output = M_shard * N * bytes_per_element + mem_rmsnorm = M_shard * N * bytes_per_element + mem_fp8 = M_shard * N * fp8_bytes_per_element if args["fp8_out"] else 0 + mem_ag_output = ( + M * N * (fp8_bytes_per_element if args["fp8_out"] else bytes_per_element) if args["all_gather"] else 0 + ) + + # Benchmark memory (if enabled): + if args.get("benchmark"): + mem_test_input = M * N * bytes_per_element + mem_test_rs = 2 * M_shard * N * bytes_per_element + mem_test_rmsnorm = M_shard * N * bytes_per_element + mem_test_fp8 = M_shard * N * fp8_bytes_per_element if args["fp8_out"] else 0 + mem_test_ag = ( + M * N * (fp8_bytes_per_element if args["fp8_out"] else bytes_per_element) if args["all_gather"] else 0 + ) + else: + mem_test_input = mem_test_rs = mem_test_rmsnorm = mem_test_fp8 = mem_test_ag = 0 + + total_mem = ( + mem_input + + mem_rs_output + + mem_rmsnorm + + mem_fp8 + + mem_ag_output + + mem_test_input + + mem_test_rs + + mem_test_rmsnorm + + mem_test_fp8 + + mem_test_ag + ) + + # Add 20% overhead for alignment + estimated_heap_bytes = int(total_mem * 1.2) + estimated_heap_mb = estimated_heap_bytes / (1024 * 1024) + + heap_size_mb = heap_size / (1024**2) + print(f" Heap size: {heap_size_mb:.0f} MB {'(auto-calculated)' if args['heap_size'] == 0 else ''}") + print(f" Estimated memory needed: ~{estimated_heap_mb:.0f} MB") + + if estimated_heap_bytes > heap_size: + print("WARNING: May run out of heap memory!") + print(f"Recommended: --heap_size {estimated_heap_bytes}") + print("Or use smaller M/N values") + + # Clear GPU cache + torch.cuda.empty_cache() + + # Create input tensor + torch.manual_seed(123 + rank) + input_tensor_local = torch.randn(M, N, device=device, dtype=dtype) * (rank + 1) + + # Allocate input tensor in IRIS shared memory for remote access + input_tensor = shmem.empty((M, N), dtype=dtype) + input_tensor.copy_(input_tensor_local) + + # IRIS heap bases + heap_bases = shmem.get_heap_bases() + + # Barrier to ensure all ranks have allocated their tensors + shmem.barrier() + + # ================================================================ + # Step 1: Reduce-Scatter + # ================================================================ + # Call kernel once per rank - it will use iris.load() to pull data from all source ranks + reduced_shard = run_reduce_scatter( + input_tensor, + M, + M_shard, + N, + rank, + world_size, + heap_bases, + BLOCK_M, + BLOCK_N, + GROUP_SIZE_M, + NUM_SMS, + num_warps, + num_stages, + waves_per_eu, + dtype, + device, + shmem, + ) + + # Synchronize to ensure all ranks have completed their loads and reductions + torch.cuda.synchronize() + shmem.barrier() + + # ================================================================ + # Step 2: RMSNorm 
+ # ================================================================ + rmsnorm_output = run_rmsnorm( + reduced_shard, + args["eps"], + device, + block_size=rmsnorm_block_size, + num_warps=rmsnorm_num_warps, + num_prgms=rmsnorm_num_prgms, + waves_per_eu=rmsnorm_waves_per_eu, + ) + + # ================================================================ + # Step 3: FP8 Quantization + # ================================================================ + quantized_output = None # Initialize for validation scope + if args["fp8_out"]: + # Allocate in regular CUDA memory + quantized_output, scale = run_quantize_fp8(rmsnorm_output, BLOCK_M, BLOCK_N, device, shmem=None) + + # If all-gather is enabled, copy to IRIS memory as uint8 (workaround for FP8 dtype support) + if args["all_gather"]: + # IRIS may not fully support FP8 dtype, so copy via uint8 byte view + final_output_iris_bytes = shmem.empty((M_shard, N), dtype=torch.uint8) + quantized_bytes = quantized_output.view(torch.uint8) + final_output_iris_bytes.copy_(quantized_bytes) + final_output = final_output_iris_bytes.view(quantized_output.dtype) + else: + final_output = quantized_output + else: + # If all-gather is enabled, ensure rmsnorm_output is in IRIS memory + if args["all_gather"]: + final_output_iris = shmem.empty(rmsnorm_output.shape, dtype=rmsnorm_output.dtype) + final_output_iris.copy_(rmsnorm_output) + final_output = final_output_iris + else: + final_output = rmsnorm_output + + # ================================================================ + # Step 4: All-Gather (optional) + # ================================================================ + if args["all_gather"]: + result = run_all_gather( + final_output, + M, + M_shard, + N, + rank, + world_size, + heap_bases, + shmem, + BLOCK_M, + BLOCK_N, + GROUP_SIZE_M, + NUM_SMS, + device, + ) + torch.cuda.synchronize() + shmem.barrier() + else: + result = final_output + + # ================================================================ + # Validation + # ================================================================ + if args["validate"] and rank == 0: + print("\nValidation:") + print("Note: Validation uses initial pipeline execution (may use different params than benchmark)") + print(" For best results, ensure command-line params match tuned values\n") + + import torch.nn as nn + + # Reference computation + torch.manual_seed(123) + ref_tensors = [] + for i in range(world_size): + torch.manual_seed(123 + i) + tensor = torch.randn(M, N, device=device, dtype=dtype) * (i + 1) + ref_tensors.append(tensor) + + # Use FP32 accumulation to match kernel (more accurate than FP16) + ref_reduced = torch.zeros(M, N, device=device, dtype=torch.float32) + for tensor in ref_tensors: + ref_reduced += tensor.to(torch.float32) + + # Convert back to FP16 and extract shard + ref_shard = ref_reduced[rank * M_shard : (rank + 1) * M_shard, :].to(dtype) + + # Debug: Print sums to diagnose accumulation issues + ref_sum = ref_shard.sum(dtype=torch.float32).item() + actual_sum = reduced_shard.sum(dtype=torch.float32).item() + + # Compare reduce-scatter + rs_diff = torch.abs(ref_shard - reduced_shard) + rel_error = abs(ref_sum - actual_sum) / abs(ref_sum) * 100 + + print(f" Reduce-scatter max diff: {rs_diff.max().item():.8f}") + print(f" Reduce-scatter sum - Reference: {ref_sum:.4f}, Actual: {actual_sum:.4f}, Rel Error: {rel_error:.4f}%") + + # For FP16 with 8-rank accumulation, max diff ~0.1 is acceptable + # The key metric is the sum - should be within 0.1% relative error + if rel_error < 0.1 and rs_diff.max() < 
0.1: + print(" ✅ PASS") + else: + print(" ❌ FAIL") + + # Compare RMSNorm + rmsnorm_layer = nn.RMSNorm(N, eps=args["eps"], device=device, dtype=dtype) + ref_normed = rmsnorm_layer(ref_shard) + + # NOTE: rmsnorm_output might use different parameters than benchmark + # This is just a basic sanity check + rms_diff = torch.abs(ref_normed - rmsnorm_output) + print(f" RMSNorm max diff: {rms_diff.max().item():.8f}") + + ref_norm_sum = ref_normed.sum(dtype=torch.float32).item() + actual_norm_sum = rmsnorm_output.sum(dtype=torch.float32).item() + rms_sum_rel_err = abs(ref_norm_sum - actual_norm_sum) / abs(ref_norm_sum) * 100 + print( + f" RMSNorm sum - Reference: {ref_norm_sum:.4f}, Actual: {actual_norm_sum:.4f}, Rel Error: {rms_sum_rel_err:.4f}%" + ) + print(f" {'✅ PASS' if rms_diff.max() < 10.0 else '❌ FAIL'} (initial exec, may differ from benchmark)") + + # Compare FP8 Quantization + if args["fp8_out"] and quantized_output is not None: + # For FP8, just verify the quantization is within expected range + quant_float = quantized_output.to(torch.float32) + + print(f" FP8 Quantization range: [{quant_float.min().item():.2f}, {quant_float.max().item():.2f}]") + print(f" FP8 Quantization sum: {quant_float.sum().item():.4f}") + + # FP8 range should be within [-448, 448] and not all zeros + in_range = (quant_float.min() >= -448.0) and (quant_float.max() <= 448.0) + not_all_zero = quant_float.abs().max() > 0.01 + + print( + f" {'✅ PASS' if (in_range and not_all_zero) else '❌ FAIL'} (values in valid FP8 range and non-zero)" + ) + + # Compare All-Gather + if args["all_gather"]: + # Check value range of full gathered result + result_float = result.to(torch.float32) + result_min = result_float.min().item() + result_max = result_float.max().item() + result_sum = result_float.sum().item() + result_nonzero = (result_float.abs() > 0.01).sum().item() + + print(" All-Gather full result:") + print(f" Value range: [{result_min:.4f}, {result_max:.4f}]") + print(f" Sum: {result_sum:.4f}") + print( + f" Non-zero elements: {result_nonzero}/{result_float.numel()} ({result_nonzero / result_float.numel() * 100:.1f}%)" + ) + + # Verify that this rank's shard appears correctly in the gathered result + ag_shard_result = result[rank * M_shard : (rank + 1) * M_shard, :] + + # Convert to float32 for comparison (FP8 doesn't support some ops) + ag_result_float = ag_shard_result.to(torch.float32) + final_out_float = final_output.to(torch.float32) + + ag_diff_float = torch.abs(ag_result_float - final_out_float) + ag_sum_diff = abs(ag_result_float.sum() - final_out_float.sum()) + ag_rel_err = ag_sum_diff / abs(final_out_float.sum()) * 100 if final_out_float.sum() != 0 else 0.0 + + print( + f" All-Gather (rank {rank} shard) max diff: {ag_diff_float.max().item():.8f}, rel error: {ag_rel_err:.4f}%" + ) + + # Check if result is valid (not all zeros) + is_valid = (abs(result_sum) > 1.0) and (result_nonzero > result_float.numel() * 0.5) + if not is_valid: + print("WARNING: All-Gather result may be invalid (mostly zeros or very small values)") + + print(f" {'✅ PASS' if (ag_diff_float.max() < 0.01 and is_valid) else '❌ FAIL'}") + + # ================================================================ + # Benchmarking + # ================================================================ + if args["benchmark"]: + if rank == 0: + print(f"\nBenchmarking with {args['warmup']} warmup + {args['iters']} iterations...") + + # ---------------------------------------------------------------- + # Benchmark Reduce-Scatter + # 
---------------------------------------------------------------- + # Pre-allocate test tensors in IRIS memory (reuse to avoid re-allocation) + test_input = shmem.empty((M, N), dtype=dtype) + test_input_local = torch.randn(M, N, device=device, dtype=dtype) + test_input.copy_(test_input_local) + + # Pre-allocate output buffer in IRIS memory (M_shard × N, will be reused) + test_reduced_shard = shmem.zeros((2 * M_shard, N), dtype=dtype) + + # Warmup + for _ in range(args["warmup"]): + test_reduced_shard.zero_() + _ = run_reduce_scatter( + test_input, + M, + M_shard, + N, + rank, + world_size, + heap_bases, + BLOCK_M, + BLOCK_N, + GROUP_SIZE_M, + NUM_SMS, + num_warps, + num_stages, + waves_per_eu, + dtype, + device, + shmem=shmem, + output_buffer=test_reduced_shard, + ) + torch.cuda.synchronize() + shmem.barrier() + + # Benchmark using CUDA events for accurate GPU timing + # Call kernel directly (not through wrapper) to avoid sync overhead + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + grid_rs = (NUM_SMS,) + + start_event.record() + for _ in range(args["iters"]): + reduce_scatter_m_kernel[grid_rs]( + test_input, + test_reduced_shard, + M, + M_shard, + N, + test_input.stride(0), + test_input.stride(1), + test_reduced_shard.stride(0), + test_reduced_shard.stride(1), + rank, + world_size, + heap_bases, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + GROUP_SIZE_M=GROUP_SIZE_M, + NUM_SMS=NUM_SMS, + num_warps=num_warps, + num_stages=num_stages, + waves_per_eu=waves_per_eu, + ) + end_event.record() + + torch.cuda.synchronize() + rs_time_ms = start_event.elapsed_time(end_event) / args["iters"] + shmem.barrier() + + # ---------------------------------------------------------------- + # Benchmark RMSNorm + # ---------------------------------------------------------------- + # Allocate tensors once (not in the loop!) 
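+        # gamma (all-ones), the output buffer, and rsigma are reused across the
+        # warmup and timed iterations, so the CUDA-event timing below captures
+        # only kernel execution rather than allocator overhead.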
+ gamma_bench = torch.ones(N, device=device, dtype=dtype) + rmsnorm_output_bench = torch.empty_like(reduced_shard) + rsigma_bench = torch.empty(M_shard, device=device, dtype=dtype) + + # Determine RMSNorm parameters + if rmsnorm_block_size is None: + element_size = reduced_shard.element_size() + max_block_size = 65536 // element_size + RMSNORM_BLOCK_SIZE = min(max_block_size, triton.next_power_of_2(N)) + else: + RMSNORM_BLOCK_SIZE = rmsnorm_block_size + + RMSNORM_USE_BLOCKED = N > RMSNORM_BLOCK_SIZE # Always auto-detect + RMSNORM_NUM_PRGMS = M_shard if rmsnorm_num_prgms is None else rmsnorm_num_prgms + RMSNORM_NUM_WARPS = 8 if rmsnorm_num_warps is None else rmsnorm_num_warps + RMSNORM_WAVES_PER_EU = 2 if rmsnorm_waves_per_eu is None else rmsnorm_waves_per_eu + + # Warmup + for _ in range(args["warmup"]): + aiter_rmsnorm[(M_shard,)]( + reduced_shard, + rmsnorm_output_bench, + gamma_bench, + rsigma_bench, + reduced_shard.stride(0), + rmsnorm_output_bench.stride(0), + M_shard, + N, + args["eps"], + BLOCK_SIZE=RMSNORM_BLOCK_SIZE, + USE_BLOCKED=RMSNORM_USE_BLOCKED, + NUM_PRGMS=RMSNORM_NUM_PRGMS, + num_warps=RMSNORM_NUM_WARPS, + waves_per_eu=RMSNORM_WAVES_PER_EU, + ) + torch.cuda.synchronize() + + # Benchmark using CUDA events - call kernel directly + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(args["iters"]): + aiter_rmsnorm[(M_shard,)]( + reduced_shard, + rmsnorm_output_bench, + gamma_bench, + rsigma_bench, + reduced_shard.stride(0), + rmsnorm_output_bench.stride(0), + M_shard, + N, + args["eps"], + BLOCK_SIZE=RMSNORM_BLOCK_SIZE, + USE_BLOCKED=RMSNORM_USE_BLOCKED, + NUM_PRGMS=RMSNORM_NUM_PRGMS, + num_warps=RMSNORM_NUM_WARPS, + waves_per_eu=RMSNORM_WAVES_PER_EU, + ) + end_event.record() + + torch.cuda.synchronize() + rmsnorm_time_ms = start_event.elapsed_time(end_event) / args["iters"] + + # ---------------------------------------------------------------- + # Benchmark FP8 Quantization + # ---------------------------------------------------------------- + quant_time_ms = 0.0 + if args["fp8_out"]: + # Determine FP8 quantization parameters + FP8_BLOCK_M = fp8_block_m if fp8_block_m is not None else BLOCK_M + FP8_BLOCK_N = fp8_block_n if fp8_block_n is not None else BLOCK_N + FP8_NUM_WARPS = fp8_num_warps if fp8_num_warps is not None else 4 + FP8_NUM_STAGES = fp8_num_stages if fp8_num_stages is not None else 2 + FP8_WAVES_PER_EU = fp8_waves_per_eu if fp8_waves_per_eu is not None else 0 + + # Allocate tensors once + max_val = rmsnorm_output_bench.abs().max().item() + scale = max(max_val / 448.0, 1e-8) + scale_tensor_bench = torch.tensor([scale], device=device, dtype=torch.float32) + + if hasattr(torch, "float8_e4m3fn"): + fp8_output_bench = torch.empty(M_shard, N, device=device, dtype=torch.float8_e4m3fn) + else: + fp8_output_bench = torch.empty_like(rmsnorm_output_bench) + + grid_fp8 = (triton.cdiv(M_shard, FP8_BLOCK_M), triton.cdiv(N, FP8_BLOCK_N)) + + # Warmup + for _ in range(args["warmup"]): + quantize_fp8_kernel[grid_fp8]( + rmsnorm_output_bench, + fp8_output_bench, + scale_tensor_bench, + M_shard, + N, + rmsnorm_output_bench.stride(0), + rmsnorm_output_bench.stride(1), + fp8_output_bench.stride(0), + fp8_output_bench.stride(1), + BLOCK_M=FP8_BLOCK_M, + BLOCK_N=FP8_BLOCK_N, + num_warps=FP8_NUM_WARPS, + num_stages=FP8_NUM_STAGES, + waves_per_eu=FP8_WAVES_PER_EU, + ) + torch.cuda.synchronize() + + # Benchmark using CUDA events - call kernel directly + start_event = 
torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(args["iters"]): + quantize_fp8_kernel[grid_fp8]( + rmsnorm_output_bench, + fp8_output_bench, + scale_tensor_bench, + M_shard, + N, + rmsnorm_output_bench.stride(0), + rmsnorm_output_bench.stride(1), + fp8_output_bench.stride(0), + fp8_output_bench.stride(1), + BLOCK_M=FP8_BLOCK_M, + BLOCK_N=FP8_BLOCK_N, + num_warps=FP8_NUM_WARPS, + num_stages=FP8_NUM_STAGES, + waves_per_eu=FP8_WAVES_PER_EU, + ) + end_event.record() + + torch.cuda.synchronize() + quant_time_ms = start_event.elapsed_time(end_event) / args["iters"] + + # ---------------------------------------------------------------- + # Benchmark All-Gather + # ---------------------------------------------------------------- + ag_time_ms = 0.0 + if args["all_gather"]: + # Determine All-Gather parameters + AG_BLOCK_M = ag_block_m if ag_block_m is not None else BLOCK_M + AG_BLOCK_N = ag_block_n if ag_block_n is not None else BLOCK_N + AG_NUM_WARPS = ag_num_warps if ag_num_warps is not None else 4 + AG_NUM_STAGES = ag_num_stages if ag_num_stages is not None else 2 + AG_WAVES_PER_EU = ag_waves_per_eu if ag_waves_per_eu is not None else 0 + + # Pre-allocate output in IRIS memory (reuse to avoid heap exhaustion) + ag_output_reuse = shmem.empty((M, N), dtype=final_output.dtype) + + grid_ag = (NUM_SMS,) + + # Warmup + for _ in range(args["warmup"]): + all_gather_m_kernel[grid_ag]( + final_output, + ag_output_reuse, + M, + M_shard, + N, + final_output.stride(0), + final_output.stride(1), + ag_output_reuse.stride(0), + ag_output_reuse.stride(1), + rank, + world_size, + heap_bases, + BLOCK_M=AG_BLOCK_M, + BLOCK_N=AG_BLOCK_N, + GROUP_SIZE_M=GROUP_SIZE_M, + NUM_SMS=NUM_SMS, + num_warps=AG_NUM_WARPS, + num_stages=AG_NUM_STAGES, + waves_per_eu=AG_WAVES_PER_EU, + ) + torch.cuda.synchronize() + + # Benchmark using CUDA events - call kernel directly + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(args["iters"]): + all_gather_m_kernel[grid_ag]( + final_output, + ag_output_reuse, + M, + M_shard, + N, + final_output.stride(0), + final_output.stride(1), + ag_output_reuse.stride(0), + ag_output_reuse.stride(1), + rank, + world_size, + heap_bases, + BLOCK_M=AG_BLOCK_M, + BLOCK_N=AG_BLOCK_N, + GROUP_SIZE_M=GROUP_SIZE_M, + NUM_SMS=NUM_SMS, + num_warps=AG_NUM_WARPS, + num_stages=AG_NUM_STAGES, + waves_per_eu=AG_WAVES_PER_EU, + ) + end_event.record() + + torch.cuda.synchronize() + ag_time_ms = start_event.elapsed_time(end_event) / args["iters"] + + # ---------------------------------------------------------------- + # Calculate metrics for all components + # ---------------------------------------------------------------- + num_elements = M_shard * N + bytes_per_element = dtype.itemsize if hasattr(dtype, "itemsize") else 2 + + # Reduce-Scatter with iris.load (pull-based): + # Each rank loads M_shard×N from (world_size - 1) remote ranks + # Local read doesn't go over interconnect, so we exclude it + # Interconnect bandwidth = data transferred over network / time + rs_interconnect_bytes = M_shard * N * (world_size - 1) * bytes_per_element + rs_bandwidth_gb_s = rs_interconnect_bytes / (rs_time_ms / 1000) / 1e9 + + # RMSNorm: Read (M_shard)×N + write (M_shard)×N + bytes_processed_rmsnorm = num_elements * bytes_per_element * 2 # Read + write + rmsnorm_bandwidth_gb_s = bytes_processed_rmsnorm / (rmsnorm_time_ms / 1000) / 1e9 + + # RMSNorm TFLOPS 
(approximate) + # RMSNorm: ~3N FLOPs per element (square, sum, rsqrt, multiply) + rmsnorm_flops = num_elements * N * 3 + rmsnorm_tflops = rmsnorm_flops / (rmsnorm_time_ms / 1000) / 1e12 + + # FP8 Quantization: Read FP16/BF16 + write FP8 + quant_bandwidth_gb_s = 0.0 + fp8_bytes = 0 + if args["fp8_out"]: + # Read FP16 (2 bytes) + write FP8 (1 byte) = 3 bytes per element + fp8_bytes = num_elements * 3 + quant_bandwidth_gb_s = fp8_bytes / (quant_time_ms / 1000) / 1e9 + + # All-Gather: Each rank sends M_shard×N to (world_size - 1) remote ranks + # Local write doesn't go over interconnect, so we exclude it + # Interconnect bandwidth = data transferred over network / time + ag_bandwidth_gb_s = 0.0 + ag_interconnect_bytes = 0 + if args["all_gather"]: + # Use actual dtype of data being gathered (FP8 if quantized, otherwise FP16) + ag_bytes_per_element = fp8_output_bench.element_size() if args["fp8_out"] else bytes_per_element + ag_interconnect_bytes = M_shard * N * (world_size - 1) * ag_bytes_per_element + ag_bandwidth_gb_s = ag_interconnect_bytes / (ag_time_ms / 1000) / 1e9 + + # Calculate total bytes and time + total_bytes = rs_interconnect_bytes + bytes_processed_rmsnorm + fp8_bytes + ag_interconnect_bytes + total_time = rs_time_ms + rmsnorm_time_ms + quant_time_ms + ag_time_ms + + # Calculate total effective bandwidth + total_bandwidth_gb_s = total_bytes / (total_time / 1000) / 1e9 + + if rank == 0: + print(f"\n{'=' * 60}") + print("Benchmark Results (Rank 0)") + print(f"{'=' * 60}") + print("Configuration:") + print(f" M={M}, N={N}, M_shard={M_shard}") + print(f" dtype={args['datatype']}, world_size={world_size}") + print(f" Elements per rank: {num_elements:,}") + print("\nComponent Performance:") + print(" Reduce-Scatter:") + print(f" Time: {rs_time_ms:.3f} ms") + print(f" Interconnect BW: {rs_bandwidth_gb_s:.2f} GB/s") + print(f" Data transferred: {rs_interconnect_bytes / 1e9:.3f} GB") + print(" RMSNorm:") + print(f" Time: {rmsnorm_time_ms:.3f} ms") + print(f" Bandwidth: {rmsnorm_bandwidth_gb_s:.2f} GB/s (memory)") + print(f" TFLOPS: {rmsnorm_tflops:.2f}") + + if args["fp8_out"]: + print(" FP8 Quantization:") + print(f" Time: {quant_time_ms:.3f} ms") + print(f" Bandwidth: {quant_bandwidth_gb_s:.2f} GB/s (memory)") + + if args["all_gather"]: + print(" All-Gather:") + print(f" Time: {ag_time_ms:.3f} ms") + print(f" Interconnect BW: {ag_bandwidth_gb_s:.2f} GB/s") + print(f" Data transferred: {ag_interconnect_bytes / 1e9:.3f} GB") + + print("\nTotal Pipeline:") + print(f" Total time: {total_time:.3f} ms") + print(f" Total bandwidth: {total_bandwidth_gb_s:.2f} GB/s") + print(f" Total bytes: {total_bytes / 1e9:.3f} GB") + print(f"{'=' * 60}") + + # Save results + results = { + "M": M, + "N": N, + "M_shard": M_shard, + "world_size": world_size, + "dtype": args["datatype"], + "fp8_out": args["fp8_out"], + "all_gather": args["all_gather"], + # Reduce-Scatter metrics + "reduce_scatter_time_ms": rs_time_ms, + "reduce_scatter_bandwidth_gb_s": rs_bandwidth_gb_s, + # RMSNorm metrics + "rmsnorm_time_ms": rmsnorm_time_ms, + "rmsnorm_bandwidth_gb_s": rmsnorm_bandwidth_gb_s, + "rmsnorm_tflops": rmsnorm_tflops, + # FP8 Quantization metrics + "quant_time_ms": quant_time_ms if args["fp8_out"] else None, + "quant_bandwidth_gb_s": quant_bandwidth_gb_s if args["fp8_out"] else None, + # All-Gather metrics + "all_gather_time_ms": ag_time_ms if args["all_gather"] else None, + "all_gather_bandwidth_gb_s": ag_bandwidth_gb_s if args["all_gather"] else None, + # Total pipeline metrics + "total_time_ms": total_time, + 
"total_bandwidth_gb_s": total_bandwidth_gb_s, + "total_bytes_gb": total_bytes / 1e9, + # Configuration + "NUM_SMS": NUM_SMS, + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "GROUP_SIZE_M": GROUP_SIZE_M, + } + + with open(args["output_file"], "w") as f: + json.dump(results, f, indent=2) + + print(f"\nResults saved to {args['output_file']}") + + if rank == 0: + print(f"\nRank {rank}: Pipeline completed successfully!") + + dist.destroy_process_group() + + +def main(): + args = parse_args() + + world_size = args["num_ranks"] + + init_url = f"tcp://127.0.0.1:{random.randint(20000, 60000)}" + + print(f"Launching {world_size} processes...") + print(f"Init URL: {init_url}") + + # Spawn workers + mp.spawn( + _worker, + args=(world_size, init_url, args), + nprocs=world_size, + join=True, + ) + + print("\nAll processes completed!") + + +if __name__ == "__main__": + main() diff --git a/examples/22_rs_rmsnorm_fp8quant_ag/reduce_scatter_rmsnorm_quant.py b/examples/22_rs_rmsnorm_fp8quant_ag/reduce_scatter_rmsnorm_quant.py new file mode 100644 index 000000000..417254914 --- /dev/null +++ b/examples/22_rs_rmsnorm_fp8quant_ag/reduce_scatter_rmsnorm_quant.py @@ -0,0 +1,654 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. + +""" +Reduce-Scatter → RMSNorm → FP8 Quantization Pipeline + +Task: +- Start with M×N tensor on each of 8 GPUs (same position, different values) +- Reduce (sum) pointwise across all GPUs +- Split along M dimension: Each GPU gets (M/8)×N piece +- RMSNorm along N dimension (locally, since we have full N) +- Quantize to FP8 + +Pipeline: +1. Reduce-Scatter along M dimension: 8 M×N → Each GPU gets (M/world_size)×N +2. RMSNorm on (M/world_size)×N with full N dimension +3. FP8 Quantization +4. (Optional) All-Gather along M dimension to reconstruct full M×N + +Usage: + # Run with torchrun for multi-GPU + torchrun --nproc_per_node=8 reduce_scatter_rmsnorm_quant.py --verify + + # Or use the benchmark script which handles multi-process spawning + python benchmark.py --num_rows 8192 --num_cols 7168 --num_ranks 8 --validate +""" + +import os +import argparse + +import torch +import torch.distributed as dist +import triton +import triton.language as tl + +import iris + + +@triton.jit +def reduce_scatter_m_kernel( + input_ptr, # Local input tensor in IRIS memory: *[M, N] + output_ptr, # Output shard in IRIS memory: *[M_shard, N] + M, + M_shard, + N, + stride_im, + stride_in, + stride_om, + stride_on, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + heap_bases: tl.tensor, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, +): + """ + Reduce-scatter kernel along M dimension using pull-based approach with iris.load. + + Each rank computes its own output shard by: + - Loading the relevant portion from all ranks (including itself) + - Accumulating the sum locally + - Storing the result + + For example, rank 0 computes output[0:M_shard, :] by: + - Loading input[0:M_shard, :] from rank 0 (local) + - Loading input[0:M_shard, :] from rank 1 (remote via iris.load) + - ... + - Loading input[0:M_shard, :] from rank 7 (remote via iris.load) + - Summing all loaded data + + This kernel is called once per rank. 
+ """ + pid = tl.program_id(0) + + num_pid_m = tl.cdiv(M_shard, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + total_tiles = num_pid_m * num_pid_n + + # Persistent loop over tiles + for tile_id in range(pid, total_tiles, NUM_SMS): + # Swizzle pattern + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + # Local indices in this rank's output shard (M_shard × N) + rm_local = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # Add compiler hints + rm_local = tl.max_contiguous(tl.multiple_of(rm_local, BLOCK_M), BLOCK_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) + + # Masks + mask_m_local = rm_local < M_shard + mask_n = rn < N + mask = mask_m_local[:, None] & mask_n[None, :] + + # Calculate which rows to read from each source rank's input + # This rank (cur_rank) needs rows [cur_rank*M_shard : (cur_rank+1)*M_shard] + # from ALL source ranks + rm_global = cur_rank * M_shard + rm_local + mask_m_global = rm_global < M + load_mask = mask_m_global[:, None] & mask_n[None, :] + + # Accumulator for the sum across all ranks + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + # Pointers to the data we need from all ranks + src_ptrs = input_ptr + rm_global[:, None] * stride_im + rn[None, :] * stride_in + + # Load from all source ranks and accumulate + for src_rank in tl.static_range(world_size): + data = iris.load(src_ptrs, cur_rank, src_rank, heap_bases, mask=load_mask) + accumulator += data.to(tl.float32) + + # Store the result to output shard + output_ptrs = output_ptr + rm_local[:, None] * stride_om + rn[None, :] * stride_on + tl.store(output_ptrs, accumulator.to(output_ptr.type.element_ty), mask=mask) + + +@triton.jit +def all_gather_m_kernel( + shard_ptr, # *[M_shard, N] + out_ptr, # *[M, N] + M, + M_shard, + N, + stride_sm, + stride_sn, + stride_om, + stride_on, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + heap_bases: tl.tensor, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, +): + """ + All-gather kernel along M dimension with 1D persistent-style PID mapping. + Each rank sends its (M_shard)×N to all other ranks. 
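+
+    The destination offset is cur_rank * M_shard, so rank r's shard lands in
+    rows [r * M_shard, (r + 1) * M_shard) of every rank's output buffer. The
+    local copy uses tl.store; remote copies use iris.put, and callers are
+    expected to synchronize and barrier before reading the gathered result.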
+ """ + pid = tl.program_id(0) + + num_pid_m = tl.cdiv(M_shard, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + total_tiles = num_pid_m * num_pid_n + + # Persistent loop over tiles + for tile_id in range(pid, total_tiles, NUM_SMS): + # Swizzle pattern + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + # Local indices + rm_local = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rm_local = tl.max_contiguous(tl.multiple_of(rm_local, BLOCK_M), BLOCK_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) + mask_m_local = rm_local < M_shard + mask_n = rn < N + + # Load local shard + shard_ptrs = shard_ptr + rm_local[:, None] * stride_sm + rn[None, :] * stride_sn + shard_data = tl.load(shard_ptrs, mask=mask_m_local[:, None] & mask_n[None, :], other=0.0) + + # Send to all ranks at the appropriate M offset + for dst in range(world_size): + # Calculate global M indices + rm_global = cur_rank * M_shard + rm_local + mask_m_global = rm_global < M + + if dst == cur_rank: + # Local store + out_ptrs = out_ptr + rm_global[:, None] * stride_om + rn[None, :] * stride_on + tl.store(out_ptrs, shard_data, mask=mask_m_global[:, None] & mask_n[None, :]) + else: + # Remote store using IRIS + # iris.put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask) + # from_ptr: local source, to_ptr: remote destination + iris.put( + shard_ptr + rm_local[:, None] * stride_sm + rn[None, :] * stride_sn, # from_ptr (local source) + out_ptr + rm_global[:, None] * stride_om + rn[None, :] * stride_on, # to_ptr (remote dest) + cur_rank, + dst, + heap_bases, + mask=mask_m_global[:, None] & mask_n[None, :], + ) + + +@triton.jit +def aiter_rmsnorm( + input_ptr, + output_ptr, + g_ptr, + rsigma_ptr, + input_row_stride, + output_row_stride, + n_rows, + n_cols, + epsilon, + BLOCK_SIZE: tl.constexpr, + USE_BLOCKED: tl.constexpr, + NUM_PRGMS: tl.constexpr, +): + """RMSNorm kernel from AITer.""" + row_start = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + + if USE_BLOCKED: + for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=1): + row_input_ptr = input_ptr + row_idx * input_row_stride + row_output_ptr = output_ptr + row_idx * output_row_stride + + n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1 + sum_squares = 0.0 + for blk_idx in tl.range(0, n_cols_blks, num_stages=2): + cols = blk_idx * BLOCK_SIZE + col_offsets + input_ptrs = row_input_ptr + cols + input_ptrs = tl.multiple_of(input_ptrs, (16,)) + x = tl.load(input_ptrs, cache_modifier=".cg").to(tl.float32) + sum_squares += tl.sum(x * x, axis=0) + + cols = n_cols_blks * BLOCK_SIZE + col_offsets + mask = cols < n_cols + input_ptrs = row_input_ptr + cols + input_ptrs = tl.multiple_of(input_ptrs, (16,)) + x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) + sum_squares += tl.sum(x * x, axis=0) + + mean_square = sum_squares / n_cols + norm_factor = tl.rsqrt(mean_square + epsilon) + tl.store(rsigma_ptr + row_idx, norm_factor) + + for blk_idx in tl.range(0, n_cols_blks, num_stages=2): + cols = blk_idx * BLOCK_SIZE + col_offsets + input_ptrs = row_input_ptr + cols + input_ptrs = tl.multiple_of(input_ptrs, (16,)) + x = tl.load(input_ptrs, cache_modifier=".cg").to(tl.float32) + g_ptrs = g_ptr + cols + g = 
tl.load(g_ptrs).to(tl.float32) + rms_norm = x * norm_factor * g + output_ptrs = row_output_ptr + cols + tl.store(output_ptrs, rms_norm.to(output_ptr.type.element_ty)) + + cols = n_cols_blks * BLOCK_SIZE + col_offsets + mask = cols < n_cols + input_ptrs = row_input_ptr + cols + x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) + g_ptrs = g_ptr + cols + g = tl.load( + g_ptrs, + mask=mask, + other=0.0, + ).to(tl.float32) + rms_norm = x * norm_factor * g + output_ptrs = row_output_ptr + cols + tl.store(output_ptrs, rms_norm.to(output_ptr.type.element_ty), mask=mask) + else: + mask = col_offsets < n_cols + for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): + input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets + input_ptrs = tl.multiple_of(input_ptrs, (16,)) + row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + row_norm = row * row + row_norm = tl.sum(row_norm, axis=-1) + norm_factor = tl.math.rsqrt((row_norm / n_cols) + epsilon) + tl.store(rsigma_ptr + row_idx, norm_factor) + rms_norm = row * norm_factor * g + output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets + output_ptrs = tl.multiple_of(output_ptrs, (16,)) + tl.store(output_ptrs, rms_norm.to(output_ptr.type.element_ty), mask=mask) + + +@triton.jit +def quantize_fp8_kernel( + input_ptr, + output_ptr, + scale_ptr, + M, + N, + stride_im, + stride_in, + stride_om, + stride_on, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """FP8 quantization kernel.""" + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) + + mask = (rm[:, None] < M) & (rn[None, :] < N) + + # Load input + input_ptrs = input_ptr + rm[:, None] * stride_im + rn[None, :] * stride_in + data = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) + + # Load scale + scale = tl.load(scale_ptr) + + # Quantize + fp8_max = 448.0 + scaled = data / scale + clamped = tl.clamp(scaled, -fp8_max, fp8_max) + + # Store + output_ptrs = output_ptr + rm[:, None] * stride_om + rn[None, :] * stride_on + tl.store(output_ptrs, clamped.to(output_ptr.type.element_ty), mask=mask) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--num_rows", "--m", type=int, default=8192, help="Number of rows (M)") + parser.add_argument("--num_cols", "--n", type=int, default=7168, help="Number of columns (N)") + parser.add_argument("--num_ranks", "--world_size", type=int, default=8, help="Number of ranks") + parser.add_argument("--dtype", type=str, default="fp16", choices=["bf16", "fp16", "fp32"]) + parser.add_argument("--fp8_out", action="store_true", help="Enable FP8 quantization") + parser.add_argument("--eps", type=float, default=1e-6, help="RMSNorm epsilon") + parser.add_argument("--all_gather", action="store_true", help="All-gather at the end to reconstruct full M×N") + parser.add_argument("--verify", action="store_true", help="Verify against PyTorch reference") + args = parser.parse_args() + + M = args.num_rows + N = args.num_cols + world_size = args.num_ranks + + assert M % world_size == 0, f"M ({M}) must be divisible by world_size ({world_size})" + M_shard = M // world_size + + if args.dtype == "bf16": + dtype = torch.bfloat16 + elif args.dtype == "fp16": + dtype = 
torch.float16 + else: + dtype = torch.float32 + + # Set device + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + + cur_rank = int(os.environ.get("RANK", "0")) + actual_world_size = int(os.environ.get("WORLD_SIZE", str(world_size))) + + if actual_world_size != world_size: + print(f"Warning: WORLD_SIZE ({actual_world_size}) != requested world_size ({world_size})") + world_size = actual_world_size + assert M % world_size == 0, f"M ({M}) must be divisible by world_size ({world_size})" + M_shard = M // world_size + + print(f"Rank {cur_rank}/{world_size}: M={M}, N={N}, M_shard={M_shard}") + + # ================================================================ + # Initialize PyTorch Distributed (required for IRIS) + # ================================================================ + if not dist.is_initialized(): + # Set up distributed environment + os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1") + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500") + os.environ["RANK"] = str(cur_rank) + os.environ["WORLD_SIZE"] = str(world_size) + + dist.init_process_group(backend="gloo", rank=cur_rank, world_size=world_size) + + # ================================================================ + # Initialize IRIS for distributed communication + # ================================================================ + heap_size = 1 << 28 # 256MB + shmem = iris.iris(heap_size) + + # Get heap base addresses for all ranks + heap_bases = shmem.get_heap_bases() + + # ================================================================ + # Create input: Each rank has M×N tensor (same position, different values) + # Must be in IRIS shared memory for remote access via iris.load + # ================================================================ + torch.manual_seed(42 + cur_rank) # Different seed per rank for different values + local_input_temp = torch.randn(M, N, device=device, dtype=dtype) * (cur_rank + 1) + + # Allocate in IRIS shared memory + local_input = shmem.empty((M, N), dtype=dtype) + local_input.copy_(local_input_temp) + del local_input_temp + + print(f"Rank {cur_rank}: Input shape: {local_input.shape}") + + # Barrier to ensure all ranks have allocated their input tensors + shmem.barrier() + + # Default parameters (can be overridden via tuning) + BLOCK_M = 16 + BLOCK_N = 64 + GROUP_SIZE_M = 8 + # MI350 + NUM_SMS = 256 + + # ================================================================ + # Step 1: Reduce-Scatter along M dimension + # Sum all M×N tensors and each rank gets (M/world_size)×N piece + # ================================================================ + print(f"Rank {cur_rank}: Step 1 - Reduce-Scatter along M dimension") + + # Allocate output buffer in IRIS shared memory (must be accessible to all ranks) + reduced_shard = shmem.zeros((M_shard, N), dtype=dtype) + + grid_rs = (NUM_SMS,) + + # Call kernel once - it will use iris.load() to pull data from all source ranks + reduce_scatter_m_kernel[grid_rs]( + local_input, + reduced_shard, + M, + M_shard, + N, + local_input.stride(0), + local_input.stride(1), + reduced_shard.stride(0), + reduced_shard.stride(1), + cur_rank, + world_size, + heap_bases, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + GROUP_SIZE_M=GROUP_SIZE_M, + NUM_SMS=NUM_SMS, + num_warps=16, # Tuned for better performance + num_stages=4, + waves_per_eu=4, + ) + + # Synchronize to ensure all ranks have completed their loads and reductions + torch.cuda.synchronize() + 
shmem.barrier() + + print(f"Rank {cur_rank}: Reduce-scatter complete, shard shape: {reduced_shard.shape}") + + # ================================================================ + # Step 2: RMSNorm on (M_shard)×N with FULL N dimension + # ================================================================ + print(f"Rank {cur_rank}: Step 2 - RMSNorm on (M_shard)×N") + + gamma = torch.ones(N, device=device, dtype=dtype) + rmsnorm_output = torch.empty_like(reduced_shard) + rsigma = torch.empty(M_shard, device=device, dtype=dtype) + + # AITer RMSNorm configuration + # Note: Tuning found BLOCK_SIZE=1024 optimal for N=7168 (avoid VGPR spills with larger sizes) + BLOCK_SIZE = 1024 + USE_BLOCKED = False # Tuned: non-blocked mode is faster for moderate N + NUM_PRGMS = M_shard # Full parallelism: each program processes one row + + aiter_rmsnorm[(M_shard,)]( + reduced_shard, + rmsnorm_output, + gamma, + rsigma, + reduced_shard.stride(0), + rmsnorm_output.stride(0), + M_shard, + N, + args.eps, + BLOCK_SIZE=BLOCK_SIZE, + USE_BLOCKED=USE_BLOCKED, + NUM_PRGMS=NUM_PRGMS, + num_warps=8, # Tuned for better occupancy + waves_per_eu=2, + ) + + print(f"Rank {cur_rank}: RMSNorm complete, output shape: {rmsnorm_output.shape}") + + # ================================================================ + # Step 3: FP8 Quantization + # ================================================================ + if args.fp8_out: + print(f"Rank {cur_rank}: Step 3 - FP8 Quantization") + + # Compute scale + max_val = rmsnorm_output.abs().max() + scale = (max_val / 448.0).clamp(min=1e-8) + scale_tensor = torch.tensor([scale], device=device, dtype=torch.float32) + + # Quantize + if hasattr(torch, "float8_e4m3fn"): + quantized_output = torch.empty_like(rmsnorm_output, dtype=torch.float8_e4m3fn) + else: + quantized_output = torch.empty_like(rmsnorm_output) + + # FP8 quantization uses medium tile sizes + FP8_BLOCK_M = 64 + FP8_BLOCK_N = 64 + grid_quant = (triton.cdiv(M_shard, FP8_BLOCK_M), triton.cdiv(N, FP8_BLOCK_N)) + + quantize_fp8_kernel[grid_quant]( + rmsnorm_output, + quantized_output, + scale_tensor, + M_shard, + N, + rmsnorm_output.stride(0), + rmsnorm_output.stride(1), + quantized_output.stride(0), + quantized_output.stride(1), + BLOCK_M=FP8_BLOCK_M, + BLOCK_N=FP8_BLOCK_N, + num_warps=4, + num_stages=2, + waves_per_eu=2, + ) + + final_shard = quantized_output + print( + f"Rank {cur_rank}: Quantization complete, shape: {quantized_output.shape}, dtype: {quantized_output.dtype}" + ) + else: + final_shard = rmsnorm_output + print(f"Rank {cur_rank}: No quantization, final shard shape: {final_shard.shape}") + + # ================================================================ + # Step 4 (Optional): All-Gather along M dimension + # ================================================================ + if args.all_gather: + print(f"Rank {cur_rank}: Step 4 - All-Gather along M dimension") + + # Determine output dtype + if args.fp8_out and hasattr(torch, "float8_e4m3fn"): + out_dtype = torch.float8_e4m3fn + else: + out_dtype = dtype + + # Allocate output in IRIS shared memory + full_output = shmem.zeros((M, N), dtype=out_dtype) + + grid_ag = (NUM_SMS,) + + # All-gather uses similar parameters to reduce-scatter + AG_BLOCK_M = 64 + AG_BLOCK_N = 64 + + all_gather_m_kernel[grid_ag]( + final_shard, + full_output, + M, + M_shard, + N, + final_shard.stride(0), + final_shard.stride(1), + full_output.stride(0), + full_output.stride(1), + cur_rank, + world_size, + heap_bases, + BLOCK_M=AG_BLOCK_M, + BLOCK_N=AG_BLOCK_N, + GROUP_SIZE_M=GROUP_SIZE_M, + 
NUM_SMS=NUM_SMS, + num_warps=8, + num_stages=3, + waves_per_eu=2, + ) + + # Synchronize to ensure all ranks have completed their puts + torch.cuda.synchronize() + + print(f"Rank {cur_rank}: All-gather complete, full output shape: {full_output.shape}") + result = full_output + else: + result = final_shard + print(f"Rank {cur_rank}: Skipping all-gather, result shape: {result.shape}") + + # ================================================================ + # Verification + # ================================================================ + if args.verify and cur_rank == 0: + print("\n" + "=" * 60) + print("Verification against PyTorch reference") + print("=" * 60) + + import torch.nn as nn + + # Reference computation + torch.manual_seed(42) + ref_tensors = [] + for i in range(world_size): + torch.manual_seed(42 + i) + tensor = torch.randn(M, N, device=device, dtype=dtype) * (i + 1) + ref_tensors.append(tensor) + + # Pointwise reduce (sum) + ref_reduced = torch.zeros(M, N, device=device, dtype=dtype) + for tensor in ref_tensors: + ref_reduced += tensor + + print(f"Reference reduced sum: {ref_reduced.sum(dtype=torch.float32):.4f}") + + # Extract this rank's shard + start_row = cur_rank * M_shard + end_row = (cur_rank + 1) * M_shard + ref_shard = ref_reduced[start_row:end_row, :] + + # Compare reduce-scatter result + rs_diff = torch.abs(ref_shard - reduced_shard) + print(f"Reduce-scatter max diff: {rs_diff.max().item():.8f}") + + if rs_diff.max().item() < 1e-5: + print("✅ Reduce-scatter verification PASSED") + else: + print("❌ Reduce-scatter verification FAILED") + + # RMSNorm + rmsnorm_layer = nn.RMSNorm(N, eps=args.eps, device=device, dtype=dtype) + ref_normed = rmsnorm_layer(ref_shard) + + print(f"\nReference RMSNorm sum: {ref_normed.sum(dtype=torch.float32):.4f}") + print(f"Triton RMSNorm sum: {rmsnorm_output.sum(dtype=torch.float32):.4f}") + + rms_diff = torch.abs(ref_normed - rmsnorm_output) + print(f"RMSNorm max diff: {rms_diff.max().item():.8f}") + print(f"RMSNorm mean diff: {rms_diff.mean().item():.8f}") + + if rms_diff.max().item() < 1e-2: + print("✅ RMSNorm verification PASSED") + else: + print("❌ RMSNorm verification FAILED") + + print(f"\nRank {cur_rank}: Pipeline completed successfully!") + + # Cleanup + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == "__main__": + main()