diff --git a/benchmark/examples/benchmark_moe.py b/benchmark/examples/benchmark_moe.py index 160db9e3..c3818fa2 100644 --- a/benchmark/examples/benchmark_moe.py +++ b/benchmark/examples/benchmark_moe.py @@ -133,10 +133,17 @@ def parse_args(): "unfused", "fused_grouped_matmul_convert_ep_to_dp", "fused_convert_dp_to_ep_grouped_matmul", + "wg_fused_grouped_matmul_convert_ep_to_dp", "fused_convert_dp_to_ep_grouped_matmul__grouped_matmul_convert_ep_to_dp", ], help="MoE fusion mode selector", ) + parser.add_argument( + "--gemm_sms", + type=int, + default=None, + help="Override GEMM_SMS for WG-specialized variant (default: auto)", + ) return parser.parse_args() @@ -162,6 +169,7 @@ def _run_dist_once( n_expts_act, shmem, fusion_config, + gemm_sms=None, ): return mixture_of_expt_epsharded( x_dp_local, @@ -172,6 +180,7 @@ def _run_dist_once( n_expts_act, shmem, fusion_config=fusion_config, + gemm_sms=gemm_sms, ) @@ -248,6 +257,7 @@ def _worker(rank: int, world_size: int, init_url: str, args): args.n_expts_act, shmem, fusion_config, + args.gemm_sms, ) if args.validate or args.compare_single_gpu: @@ -259,7 +269,7 @@ def _worker(rank: int, world_size: int, init_url: str, args): dist.all_gather_into_tensor(y_tri, z_dp_local.contiguous()) if args.breakdown: - N_BREAKDOWN_ITERS = 5 + N_BREAKDOWN_ITERS = 10 stage_ms = {} for _ in range(N_BREAKDOWN_ITERS): shmem.heap.allocator.heap_offset = sweep_heap_base @@ -274,6 +284,7 @@ def _worker(rank: int, world_size: int, init_url: str, args): shmem, fusion_config=fusion_config, timing_dict=td, + gemm_sms=args.gemm_sms, ) if rank == 0: for j in range(1, len(td)): @@ -281,10 +292,13 @@ def _worker(rank: int, world_size: int, init_url: str, args): ms = td[j - 1][1].elapsed_time(td[j][1]) stage_ms.setdefault(key, []).append(ms) if rank == 0: - print( - " [breakdown bpe={}] ".format(bpe) - + " ".join("{}={:.2f}ms".format(k, sum(v) / len(v)) for k, v in stage_ms.items()) - ) + total_avg = sum(sum(v) / len(v) for v in stage_ms.values()) + parts = [] + for k, v in stage_ms.items(): + avg = sum(v) / len(v) + pct = 100 * avg / total_avg if total_avg > 0 else 0 + parts.append("{}={:.2f}ms ({:.1f}%)".format(k, avg, pct)) + print(" [breakdown bpe={} total={:.2f}ms] ".format(bpe, total_avg) + " ".join(parts)) result = { "world_size": ws, diff --git a/examples/31_expert_sharded_moe/combine.py b/examples/31_expert_sharded_moe/combine.py index 8498b32f..5f6ec11b 100644 --- a/examples/31_expert_sharded_moe/combine.py +++ b/examples/31_expert_sharded_moe/combine.py @@ -54,6 +54,7 @@ def _convert_ep_to_dp( dst_indx_local = dst_indx_global - dst_rank * n_slots_per_rank offs_n = tl.arange(0, BLOCK) + offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK) for start_n in range(0, src_shape_n, BLOCK): mask_n = start_n + offs_n < src_shape_n src = tl.load( @@ -64,7 +65,7 @@ def _convert_ep_to_dp( dst_off = dst_indx_local * dst_stride_m + start_n + offs_n for r in tl.static_range(N_RANKS): if dst_rank == r: - iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n) + iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n, hint=16) def convert_ep_to_dp(src, expt_assignment, expt_indx, topk_indx, shmem): diff --git a/examples/31_expert_sharded_moe/dispatch.py b/examples/31_expert_sharded_moe/dispatch.py index 9b589d91..55c491c1 100644 --- a/examples/31_expert_sharded_moe/dispatch.py +++ b/examples/31_expert_sharded_moe/dispatch.py @@ -42,6 +42,7 @@ def _convert_dp_to_ep( off_m_local = pid_m offs_n = tl.arange(0, BLOCK) + offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK) for act in tl.static_range(N_EXPT_ACT): dst_row = tl.load(dst_row_indx_ptr + off_m_global * dst_row_indx_stride_m + act) @@ -66,7 +67,7 @@ def _convert_dp_to_ep( dst_off = dst_row * dst_stride_m + start_n + offs_n for r in tl.static_range(N_RANKS): if dst_rank == r: - iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n) + iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n, hint=16) def convert_dp_to_ep(src, expt_assignment, expt_indx, gate_indx, shmem): diff --git a/examples/31_expert_sharded_moe/fused_dp_to_ep_matmul.py b/examples/31_expert_sharded_moe/fused_dp_to_ep_matmul.py new file mode 100644 index 00000000..520811e5 --- /dev/null +++ b/examples/31_expert_sharded_moe/fused_dp_to_ep_matmul.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +""" +Fused DP->EP dispatch + expert matmul. + +This module fuses: + convert_dp_to_ep(...) + + grouped_matmul(y_ep_local, w_ep_local, b_ep_local, ...) + +into a single Triton kernel that: + 1) resolves, for each expert-sorted row, which rank owns the source token + 2) gathers the activation row from the owning rank via iris.load (prologue) + 3) computes a tiled GEMM (BLOCK_M x BLOCK_N via tl.dot) + 4) stores the output locally in expert-sorted order (epilogue) + +Grid: (n_n_tiles * n_local_experts,) -- same tiling as grouped_matmul. +Each program loops over M-tiles for one (expert, N-tile) pair. For each +M-tile, it uses combine_indx (col_sorted_indx) to map expert-sorted +positions back to global tokens, determines the owning rank, and pulls +the activation data from that rank's iris heap via per-rank masked 2D +iris.load. + +Prerequisites: + x_dp_local must be copied to the iris heap before launch so that remote + ranks can access it. All ranks allocate the same shape at the same heap + offset (symmetric allocation), making pointer translation correct. +""" + +import torch +import triton +import triton.language as tl +import iris + +from ragged_metadata import RaggedTensorMetadata + + +@triton.jit +def _fused_dp_to_ep_matmul_kernel( + y_ptr, + y_stride_m, + y_stride_n, + x_shmem_ptr, + x_stride_m, + x_stride_k, + w_ptr, + w_stride_e, + w_stride_k, + w_stride_n, + b_ptr, + b_stride_e, + b_stride_n, + slice_offs_ptr, + slice_sizes_ptr, + combine_indx_ptr, + n_local_experts, + n_tokens_local, + n_expts_act, + K, + N, + heap_bases, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + HAS_BIAS: tl.constexpr, + CUR_RANK: tl.constexpr, + N_RANKS: tl.constexpr, +): + pid = tl.program_id(0) + n_n_tiles = tl.cdiv(N, BLOCK_N) + + local_expert_id = pid // n_n_tiles + pid_n = pid % n_n_tiles + + if local_expert_id >= n_local_experts: + return + + local_expert_id_64 = local_expert_id.to(tl.int64) + slice_off = tl.load(slice_offs_ptr + local_expert_id_64).to(tl.int64) + slice_size = tl.load(slice_sizes_ptr + local_expert_id_64) + if slice_size == 0: + return + + n_m_tiles = tl.cdiv(slice_size, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_N), BLOCK_N) + mask_n = offs_n < N + + for pid_m in range(0, n_m_tiles): + offs_m_local = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = slice_off + offs_m_local + mask_m = offs_m_local < slice_size + + # -- Prologue: resolve source rank and local row for each row. -- + flat_idxs = tl.load(combine_indx_ptr + offs_m, mask=mask_m, other=-1) + row_valid = mask_m & (flat_idxs >= 0) + + safe_flat = tl.where(row_valid, flat_idxs, tl.zeros_like(flat_idxs)) + token_ids = safe_flat // n_expts_act + src_ranks = token_ids // n_tokens_local + local_rows = token_ids % n_tokens_local + + # -- Body: tiled GEMM with per-rank remote gather. -- + acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + for start_k in range(0, K, BLOCK_K): + offs_k = start_k + tl.arange(0, BLOCK_K) + mask_k = offs_k < K + + # Build X tile by gathering from each rank's x_dp_local on heap. + x_ptrs = x_shmem_ptr + local_rows[:, None] * x_stride_m + offs_k[None, :] * x_stride_k + x_tile = tl.zeros([BLOCK_M, BLOCK_K], dtype=x_shmem_ptr.dtype.element_ty) + for r in tl.static_range(N_RANKS): + rank_mask = row_valid & (src_ranks == r) + load_mask = rank_mask[:, None] & mask_k[None, :] + if r == CUR_RANK: + loaded = tl.load(x_ptrs, mask=load_mask, other=0.0) + else: + loaded = iris.load(x_ptrs, CUR_RANK, r, heap_bases, mask=load_mask, hint=(1, 16)) + x_tile = tl.where(load_mask, loaded, x_tile) + + w_ptrs = ( + w_ptr + local_expert_id_64 * w_stride_e + offs_k[:, None] * w_stride_k + offs_n[None, :] * w_stride_n + ) + w = tl.load(w_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0) + + acc += tl.dot(x_tile, w) + + if HAS_BIAS: + b_ptrs = b_ptr + local_expert_id_64 * b_stride_e + offs_n * b_stride_n + bias = tl.load(b_ptrs, mask=mask_n, other=0.0) + acc += bias[None, :] + + # -- Epilogue: store output locally (expert-sorted order). -- + y_ptrs = y_ptr + offs_m[:, None] * y_stride_m + offs_n[None, :] * y_stride_n + tl.store(y_ptrs, acc.to(y_ptr.dtype.element_ty), mask=mask_m[:, None] & mask_n[None, :]) + + +def fused_dp_to_ep_matmul( + x_dp_local: torch.Tensor, + w_ep_local: torch.Tensor, + b_ep_local: torch.Tensor | None, + combine_indx: torch.Tensor, + n_expts_act: int, + shmem, + ragged_metadata: RaggedTensorMetadata, +) -> torch.Tensor: + """Gather tokens from remote ranks and compute expert matmul in one kernel. + + Replaces the standalone convert_dp_to_ep + grouped_matmul sequence. + Each GEMM tile's input rows are pulled directly from the owning rank's + iris heap via per-rank masked 2D iris.load. + + Args: + x_dp_local: (n_tokens_local, d_model) local token activations. + w_ep_local: (n_local_experts, K, N) local expert weights. + b_ep_local: (n_local_experts, N) local expert biases or None. + combine_indx: (n_total_slots,) col_sorted_indx mapping expert-sorted + positions back to global flat token*k indices. + n_expts_act: k (experts per token). + shmem: iris.Iris instance. + ragged_metadata: local-expert-view ragged metadata (slice_offs, slice_sizes). + + Returns: + (n_total_slots, N) output in expert-sorted order (same as grouped_matmul). + """ + n_tokens_local, d_model = x_dp_local.shape + n_local_experts = w_ep_local.shape[0] + n_total_slots = combine_indx.shape[0] + K = d_model + N = d_model + + # Place x_dp_local on the iris heap so remote ranks can read it. + x_shmem = shmem.zeros((n_tokens_local, d_model), dtype=x_dp_local.dtype) + x_shmem.copy_(x_dp_local) + shmem.barrier() + + y = torch.zeros((n_total_slots, N), dtype=x_dp_local.dtype, device=x_dp_local.device) + + BLOCK_M = 128 + BLOCK_N = min(triton.next_power_of_2(N), 128) + BLOCK_K = min(triton.next_power_of_2(K), 64) + + n_n_tiles = triton.cdiv(N, BLOCK_N) + grid = (n_n_tiles * n_local_experts,) + + _fused_dp_to_ep_matmul_kernel[grid]( + y, + y.stride(0), + y.stride(1), + x_shmem, + x_shmem.stride(0), + x_shmem.stride(1), + w_ep_local, + w_ep_local.stride(0), + w_ep_local.stride(1), + w_ep_local.stride(2), + b_ep_local if b_ep_local is not None else x_dp_local, + b_ep_local.stride(0) if b_ep_local is not None else 0, + b_ep_local.stride(1) if b_ep_local is not None else 0, + ragged_metadata.slice_offs, + ragged_metadata.slice_sizes, + combine_indx, + n_local_experts, + n_tokens_local, + n_expts_act, + K, + N, + shmem.get_heap_bases(), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + HAS_BIAS=(b_ep_local is not None), + CUR_RANK=shmem.get_rank(), + N_RANKS=shmem.get_num_ranks(), + num_warps=8, + num_stages=2, + matrix_instr_nonkdim=16, + kpack=1, + ) + + return y diff --git a/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py b/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py index 44e68e4c..ac163d1a 100644 --- a/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py +++ b/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py @@ -131,7 +131,7 @@ def _fused_exp_matmul_ep_to_dp_kernel( if r == SRC_RANK: tl.store(dst_ptrs_2d, out, mask=store_mask) else: - iris.store(dst_ptrs_2d, out, SRC_RANK, r, heap_bases, mask=store_mask) + iris.store(dst_ptrs_2d, out, SRC_RANK, r, heap_bases, mask=store_mask, hint=(1, 16)) def fused_exp_matmul_ep_to_dp( @@ -213,8 +213,9 @@ def fused_exp_matmul_ep_to_dp( N_RANKS=shmem.get_num_ranks(), num_warps=8, num_stages=2, + matrix_instr_nonkdim=16, + kpack=1, ) - torch.cuda.synchronize() shmem.barrier() return dst_local diff --git a/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp_wg.py b/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp_wg.py new file mode 100644 index 00000000..4d5e1502 --- /dev/null +++ b/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp_wg.py @@ -0,0 +1,302 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +""" +WG-specialized fused expert matmul + EP->DP combine. + +Workgroup-specialized variant of fused_exp_matmul_ep_to_dp that splits CUs +into persistent GEMM and communication paths within a single kernel: + + GEMM CUs (pid < GEMM_SMS): + Compute tiled GEMM per expert, write to intermediate buffer, signal lock. + Comm CUs (pid >= GEMM_SMS): + Spin-wait on lock, load GEMM output, scatter to token-owning ranks via + per-rank masked iris.store. + +This overlaps GEMM compute with cross-rank scatter communication. +Inspired by examples/10_gemm_all_scatter_wg_specialization. + +Grid: (NUM_SMS,) -- one persistent program per CU. +Lock granularity: one lock per (expert, N-tile, M-tile) triple. +""" + +import math + +import torch +import triton +import triton.language as tl +import iris + +from ragged_metadata import RaggedTensorMetadata + + +@triton.jit +def _wg_fused_exp_matmul_ep_to_dp_kernel( + dst_ptr, + dst_stride_m, + x_ptr, + x_stride_m, + x_stride_k, + w_ptr, + w_stride_e, + w_stride_k, + w_stride_n, + b_ptr, + b_stride_e, + b_stride_n, + y_buf_ptr, + y_stride_m, + y_stride_n, + slice_offs_ptr, + slice_sizes_ptr, + expt_filter_ptr, + expt_filter_stride_m, + expt_indx_ptr, + topk_indx_ptr, + n_local_experts, + n_slots_per_rank, + K, + N, + max_m_tiles, + heap_bases, + locks_ptr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + HAS_BIAS: tl.constexpr, + SRC_RANK: tl.constexpr, + N_RANKS: tl.constexpr, + GEMM_SMS: tl.constexpr, + NUM_SMS: tl.constexpr, +): + pid = tl.program_id(0) + n_n_tiles = tl.cdiv(N, BLOCK_N) + total_en_pairs = n_n_tiles * n_local_experts + + if pid < GEMM_SMS: + # ====== GEMM PATH ====== + for en_pair in range(pid, total_en_pairs, GEMM_SMS): + local_expert_id = en_pair // n_n_tiles + pid_n = en_pair % n_n_tiles + + if local_expert_id < n_local_experts: + local_expert_id_64 = local_expert_id.to(tl.int64) + slice_off = tl.load(slice_offs_ptr + local_expert_id_64).to(tl.int64) + slice_size = tl.load(slice_sizes_ptr + local_expert_id_64) + + if slice_size > 0: + n_m_tiles = tl.cdiv(slice_size, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_N), BLOCK_N) + mask_n = offs_n < N + + for pid_m in range(0, n_m_tiles): + offs_m_local = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = slice_off + offs_m_local + mask_m = offs_m_local < slice_size + + acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + for start_k in range(0, K, BLOCK_K): + offs_k = start_k + tl.arange(0, BLOCK_K) + mask_k = offs_k < K + + x_ptrs = x_ptr + offs_m[:, None] * x_stride_m + offs_k[None, :] * x_stride_k + x = tl.load(x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0) + + w_ptrs = ( + w_ptr + + local_expert_id_64 * w_stride_e + + offs_k[:, None] * w_stride_k + + offs_n[None, :] * w_stride_n + ) + w = tl.load(w_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0) + + acc += tl.dot(x, w) + + if HAS_BIAS: + b_ptrs = b_ptr + local_expert_id_64 * b_stride_e + offs_n * b_stride_n + bias = tl.load(b_ptrs, mask=mask_n, other=0.0) + acc += bias[None, :] + + out = acc.to(y_buf_ptr.dtype.element_ty) + + y_ptrs = y_buf_ptr + offs_m[:, None] * y_stride_m + offs_n[None, :] * y_stride_n + tl.store(y_ptrs, out, mask=mask_m[:, None] & mask_n[None, :], cache_modifier=".wt") + + tl.debug_barrier() + lock_idx = en_pair * max_m_tiles + pid_m + tl.store(locks_ptr + lock_idx, 1, cache_modifier=".wt") + + else: + # ====== COMMUNICATION PATH ====== + COMM_SMS = NUM_SMS - GEMM_SMS + comm_pid = pid - GEMM_SMS + + for en_pair in range(comm_pid, total_en_pairs, COMM_SMS): + local_expert_id = en_pair // n_n_tiles + pid_n = en_pair % n_n_tiles + + if local_expert_id < n_local_experts: + local_expert_id_64 = local_expert_id.to(tl.int64) + slice_off = tl.load(slice_offs_ptr + local_expert_id_64).to(tl.int64) + slice_size = tl.load(slice_sizes_ptr + local_expert_id_64) + + if slice_size > 0: + n_m_tiles = tl.cdiv(slice_size, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_N), BLOCK_N) + mask_n = offs_n < N + + for pid_m in range(0, n_m_tiles): + lock_idx = en_pair * max_m_tiles + pid_m + while tl.load(locks_ptr + lock_idx, cache_modifier=".cv", volatile=True) != 1: + pass + + offs_m_local = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = slice_off + offs_m_local + mask_m = offs_m_local < slice_size + + dst_indx_globals = tl.load(topk_indx_ptr + offs_m, mask=mask_m, other=-1) + valid_dst = mask_m & (dst_indx_globals >= 0) + + safe_dst_indx = tl.where(valid_dst, dst_indx_globals, tl.zeros_like(dst_indx_globals)) + dst_expt_indxs = tl.load(expt_indx_ptr + safe_dst_indx, mask=valid_dst, other=0).to(tl.int32) + + expt_filter_ptr_local = expt_filter_ptr + SRC_RANK * expt_filter_stride_m + has_dst_expts = ( + ( + tl.load(expt_filter_ptr_local + dst_expt_indxs // 32, mask=valid_dst, other=0) + >> (dst_expt_indxs % 32) + ) + & 1 + ).to(tl.int1) + + row_valid = valid_dst & has_dst_expts + dst_ranks = dst_indx_globals // n_slots_per_rank + dst_indx_locals = dst_indx_globals - dst_ranks * n_slots_per_rank + dst_indx_locals = tl.where(row_valid, dst_indx_locals, tl.zeros_like(dst_indx_locals)) + + y_ptrs = y_buf_ptr + offs_m[:, None] * y_stride_m + offs_n[None, :] * y_stride_n + out = tl.load(y_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0.0) + + dst_ptrs_2d = dst_ptr + dst_indx_locals[:, None] * dst_stride_m + offs_n[None, :] + for r in tl.static_range(N_RANKS): + rank_mask = row_valid & (dst_ranks == r) + store_mask = rank_mask[:, None] & mask_n[None, :] + if r == SRC_RANK: + tl.store(dst_ptrs_2d, out, mask=store_mask) + else: + iris.store(dst_ptrs_2d, out, SRC_RANK, r, heap_bases, mask=store_mask, hint=(1, 16)) + + +def wg_fused_exp_matmul_ep_to_dp( + x_ep_local: torch.Tensor, + w_ep_local: torch.Tensor, + b_ep_local: torch.Tensor | None, + expt_assignment, + expt_map_local: torch.Tensor, + expt_indx_flat: torch.Tensor, + combine_indx: torch.Tensor, + shmem, + ragged_metadata: RaggedTensorMetadata | None = None, + gemm_sms: int | None = None, +) -> torch.Tensor: + """WG-specialized fused expert matmul + EP->DP scatter. + + Same semantics as fused_exp_matmul_ep_to_dp but uses persistent kernel + with workgroup specialization to overlap GEMM with scatter communication. + + Args: + x_ep_local: (n_total_slots, d_model) dispatched activations. + w_ep_local: (n_local_experts, d_model, d_model) local expert weights. + b_ep_local: (n_local_experts, d_model) local expert biases or None. + expt_assignment: ExptAssignment with bitmask for ownership check. + expt_map_local: (n_expts_tot,) global expert -> local expert id or -1. + expt_indx_flat: (n_total_slots,) flat global expert ids by token-slot. + combine_indx: (n_total_slots,) col_sorted_indx. + shmem: iris.Iris instance. + ragged_metadata: local-expert-view ragged metadata. + gemm_sms: Number of CUs for GEMM path. Default: 2^floor(log2(cu_count)). + + Returns: + (n_slots_per_rank, d_model) DP-local combined output. + """ + expt_bitmask = expt_assignment.expt_bitmask + n_total_slots, d_model = x_ep_local.shape + n_local_experts = w_ep_local.shape[0] + n_slots_per_rank = n_total_slots // shmem.get_num_ranks() + K = d_model + N = d_model + + BLOCK_M = 128 + BLOCK_N = min(triton.next_power_of_2(N), 128) + BLOCK_K = min(triton.next_power_of_2(K), 64) + + max_slice_size = int(ragged_metadata.slice_sizes.max().item()) + max_m_tiles = triton.cdiv(max_slice_size, BLOCK_M) + n_n_tiles = triton.cdiv(N, BLOCK_N) + + if max_m_tiles == 0: + dst_local = shmem.zeros((n_slots_per_rank, d_model), dtype=x_ep_local.dtype) + shmem.barrier() + shmem.barrier() + return dst_local + + device = x_ep_local.device + cu_count = torch.cuda.get_device_properties(device).multi_processor_count + num_sms = cu_count + if gemm_sms is None: + gemm_sms = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + + y_buf = torch.zeros((n_total_slots, N), dtype=x_ep_local.dtype, device=device) + dst_local = shmem.zeros((n_slots_per_rank, d_model), dtype=x_ep_local.dtype) + n_locks = n_n_tiles * n_local_experts * max_m_tiles + locks = torch.zeros(n_locks, dtype=torch.int32, device=device) + + shmem.barrier() + + _wg_fused_exp_matmul_ep_to_dp_kernel[(num_sms,)]( + dst_local, + dst_local.stride(0), + x_ep_local, + x_ep_local.stride(0), + x_ep_local.stride(1), + w_ep_local, + w_ep_local.stride(0), + w_ep_local.stride(1), + w_ep_local.stride(2), + b_ep_local if b_ep_local is not None else x_ep_local, + b_ep_local.stride(0) if b_ep_local is not None else 0, + b_ep_local.stride(1) if b_ep_local is not None else 0, + y_buf, + y_buf.stride(0), + y_buf.stride(1), + ragged_metadata.slice_offs, + ragged_metadata.slice_sizes, + expt_bitmask, + expt_bitmask.stride(0), + expt_indx_flat, + combine_indx, + n_local_experts, + n_slots_per_rank, + K, + N, + max_m_tiles, + shmem.get_heap_bases(), + locks, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + HAS_BIAS=(b_ep_local is not None), + SRC_RANK=shmem.get_rank(), + N_RANKS=shmem.get_num_ranks(), + GEMM_SMS=gemm_sms, + NUM_SMS=num_sms, + num_warps=8, + num_stages=2, + matrix_instr_nonkdim=16, + kpack=1, + ) + + shmem.barrier() + return dst_local diff --git a/examples/31_expert_sharded_moe/moe.py b/examples/31_expert_sharded_moe/moe.py index 1905e12c..580fb4bd 100644 --- a/examples/31_expert_sharded_moe/moe.py +++ b/examples/31_expert_sharded_moe/moe.py @@ -24,6 +24,8 @@ from combine import convert_ep_to_dp from grouped_matmul import grouped_matmul from fused_exp_matmul_ep_to_dp import fused_exp_matmul_ep_to_dp +from fused_exp_matmul_ep_to_dp_wg import wg_fused_exp_matmul_ep_to_dp +from fused_dp_to_ep_matmul import fused_dp_to_ep_matmul from reduce import reduce @@ -45,10 +47,12 @@ def _allgather_push_kernel( ): pid = tl.program_id(0) offs = pid * BLOCK + tl.arange(0, BLOCK) + offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK), BLOCK) mask = offs < src_numel data = tl.load(src_ptr + offs, mask=mask) for r in tl.static_range(N_RANKS): - iris.store(dst_ptr + dst_offset + offs, data, CUR_RANK, r, heap_bases, mask=mask) + dst = dst_ptr + dst_offset + offs + iris.store(dst, data, CUR_RANK, r, heap_bases, mask=mask, hint=16) def _allgather_iris(local_tensor, shmem): @@ -163,6 +167,7 @@ class MoeFusionConfig: fuse_convert_dp_to_ep_grouped_matmul: bool = False fuse_grouped_matmul_convert_ep_to_dp: bool = False + fuse_grouped_matmul_convert_ep_to_dp_wg: bool = False def mode_name(self) -> str: parts: list[str] = [] @@ -170,6 +175,8 @@ def mode_name(self) -> str: parts.append("convert_dp_to_ep_grouped_matmul") if self.fuse_grouped_matmul_convert_ep_to_dp: parts.append("grouped_matmul_convert_ep_to_dp") + if self.fuse_grouped_matmul_convert_ep_to_dp_wg: + parts.append("wg_grouped_matmul_convert_ep_to_dp") if not parts: return "unfused" return "fused_" + "__".join(parts) @@ -182,6 +189,8 @@ def from_mode_name(name: str) -> "MoeFusionConfig": return MoeFusionConfig(fuse_grouped_matmul_convert_ep_to_dp=True) if name == "fused_convert_dp_to_ep_grouped_matmul": return MoeFusionConfig(fuse_convert_dp_to_ep_grouped_matmul=True) + if name == "wg_fused_grouped_matmul_convert_ep_to_dp": + return MoeFusionConfig(fuse_grouped_matmul_convert_ep_to_dp_wg=True) if name == "fused_convert_dp_to_ep_grouped_matmul__grouped_matmul_convert_ep_to_dp": return MoeFusionConfig( fuse_convert_dp_to_ep_grouped_matmul=True, @@ -200,6 +209,7 @@ def mixture_of_expt_epsharded( shmem, fusion_config: MoeFusionConfig | None = None, timing_dict: dict | None = None, + gemm_sms: int | None = None, ): """Expert-parallel MoE forward using iris symmetric heap. @@ -262,49 +272,35 @@ def _tick(label): _tick("metadata") # ------------------------------------------------------------------ - # Step 4: DP -> EP dispatch (all-to-all via iris.store) - # ------------------------------------------------------------------ - y_ep_local = convert_dp_to_ep( - x_dp_local, - expt_assignment, - active_indx, - dispatch_indx, - shmem, - ) - _tick("dispatch") - - # ------------------------------------------------------------------ - # Step 5: Remap ragged metadata to local expert view + # Step 4-6: Dispatch, matmul, combine (select fused/unfused variants) # ------------------------------------------------------------------ expt_map = expt_assignment.expt_map[rank, :].contiguous() y_ep_local_metadata = remap_ragged_tensor_metadata(x_global_metadata, expt_map) fusion_config = fusion_config or MoeFusionConfig() - if fusion_config.fuse_convert_dp_to_ep_grouped_matmul: - raise NotImplementedError("Fusion mode convert_dp_to_ep_grouped_matmul is not implemented yet.") + n_fusions_active = sum( + [ + fusion_config.fuse_convert_dp_to_ep_grouped_matmul, + fusion_config.fuse_grouped_matmul_convert_ep_to_dp, + fusion_config.fuse_grouped_matmul_convert_ep_to_dp_wg, + ] + ) + if n_fusions_active > 1: + raise ValueError("At most one fusion mode may be enabled at a time.") - # ------------------------------------------------------------------ - # grouped_matmul + convert_ep_to_dp (select fused/unfused variant) - # ------------------------------------------------------------------ flat_expt_indx = active_indx.to(torch.int32).reshape(-1) - if fusion_config.fuse_grouped_matmul_convert_ep_to_dp: - torch.cuda.synchronize() - shmem.barrier() - y_dp_local = fused_exp_matmul_ep_to_dp( - y_ep_local, + + if fusion_config.fuse_convert_dp_to_ep_grouped_matmul: + y_ep_local = fused_dp_to_ep_matmul( + x_dp_local, w_ep_local, b_ep_local, - expt_assignment, - expt_map, - flat_expt_indx, combine_indx, + n_expts_act, shmem, ragged_metadata=y_ep_local_metadata, ) - _tick("fused_matmul_scatter") - else: - y_ep_local = grouped_matmul(y_ep_local, w_ep_local, b_ep_local, y_ep_local_metadata) - _tick("matmul") + _tick("fused_gather_matmul") y_dp_local = convert_ep_to_dp( y_ep_local, expt_assignment, @@ -313,6 +309,54 @@ def _tick(label): shmem, ) _tick("combine") + else: + y_ep_local = convert_dp_to_ep( + x_dp_local, + expt_assignment, + active_indx, + dispatch_indx, + shmem, + ) + _tick("dispatch") + + if fusion_config.fuse_grouped_matmul_convert_ep_to_dp: + y_dp_local = fused_exp_matmul_ep_to_dp( + y_ep_local, + w_ep_local, + b_ep_local, + expt_assignment, + expt_map, + flat_expt_indx, + combine_indx, + shmem, + ragged_metadata=y_ep_local_metadata, + ) + _tick("fused_matmul_scatter") + elif fusion_config.fuse_grouped_matmul_convert_ep_to_dp_wg: + y_dp_local = wg_fused_exp_matmul_ep_to_dp( + y_ep_local, + w_ep_local, + b_ep_local, + expt_assignment, + expt_map, + flat_expt_indx, + combine_indx, + shmem, + ragged_metadata=y_ep_local_metadata, + gemm_sms=gemm_sms, + ) + _tick("wg_fused_matmul_scatter") + else: + y_ep_local = grouped_matmul(y_ep_local, w_ep_local, b_ep_local, y_ep_local_metadata) + _tick("matmul") + y_dp_local = convert_ep_to_dp( + y_ep_local, + expt_assignment, + flat_expt_indx, + combine_indx, + shmem, + ) + _tick("combine") # ------------------------------------------------------------------ # Step 8: Reduce (unweighted sum, masked) diff --git a/iris/iris.py b/iris/iris.py index 13e8c51f..0f1073b3 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1237,31 +1237,16 @@ def reduce_scatter(self, output_tensor, input_tensor, op=None, group=None, async @triton.jit -def __translate(ptr, from_rank, to_rank, heap_bases): +def __translate(ptr, from_rank, to_rank, heap_bases, hint: tl.constexpr = None): from_base = tl.load(heap_bases + from_rank) to_base = tl.load(heap_bases + to_rank) - # convert to int to compute difference ptr_int = tl.cast(ptr, tl.uint64) - # Find the offset from from_rank heap offset = ptr_int - from_base - # Byte cast for byte offset addition to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) - # Find the offset into the to_rank heap translated_ptr_byte = to_base_byte + offset - # Cast to_base back to pointer type translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) - - # Optimization to vectorize the load/store - # We can't do this in general because we don't know the shape of the tensor or block sizes - # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) - - # 0 You can use this if your block sizes are multiples of 32. - # Largest vectorized load instruction is dwordx4 (128-bits) - # translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) - # translated_ptr = tl.max_contiguous(translated_ptr, (1, 32)) - - # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) - # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) + if hint is not None: + translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, hint), hint) return translated_ptr @@ -1438,12 +1423,12 @@ def initialize(context_tensor, rank, world_size, tracing: tl.constexpr = False): return DeviceContext(rank, world_size, heap_bases, device_tracing) @triton.jit - def _translate(self, ptr, from_rank, to_rank): + def _translate(self, ptr, from_rank, to_rank, hint: tl.constexpr = None): """Internal pointer translation between rank address spaces.""" - return __translate(ptr, from_rank, to_rank, self.heap_bases) + return __translate(ptr, from_rank, to_rank, self.heap_bases, hint) @triton.jit - def load(self, pointer, from_rank, mask=None): + def load(self, pointer, from_rank, mask=None, hint: tl.constexpr = None): """ Loads a value from the specified rank's memory location. @@ -1456,6 +1441,7 @@ def load(self, pointer, from_rank, mask=None): pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `from_rank`'s address space. from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint for the translated pointer. Defaults to None. Returns: Block: The loaded value from the target memory location. @@ -1463,12 +1449,12 @@ def load(self, pointer, from_rank, mask=None): Example: >>> data = ctx.load(buffer + offsets, from_rank=1, mask=mask) """ - translated_ptr = self._translate(pointer, self.rank, from_rank) + translated_ptr = self._translate(pointer, self.rank, from_rank, hint) result = tl.load(translated_ptr, mask=mask) return result @triton.jit - def store(self, pointer, value, to_rank, mask=None): + def store(self, pointer, value, to_rank, mask=None, hint: tl.constexpr = None): """ Writes data to the specified rank's memory location. @@ -1489,11 +1475,11 @@ def store(self, pointer, value, to_rank, mask=None): Example: >>> ctx.store(buffer + offsets, values, to_rank=1, mask=mask) """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) tl.store(translated_ptr, value, mask=mask) @triton.jit - def get(self, from_ptr, to_ptr, from_rank, mask=None): + def get(self, from_ptr, to_ptr, from_rank, mask=None, hint: tl.constexpr = None): """ Copies data from the specified rank's memory into current rank's local memory. @@ -1514,12 +1500,12 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None): Example: >>> ctx.get(remote_ptr + offsets, local_ptr + offsets, from_rank=1, mask=mask) """ - translated_from_ptr = self._translate(from_ptr, self.rank, from_rank) + translated_from_ptr = self._translate(from_ptr, self.rank, from_rank, hint) data = tl.load(translated_from_ptr, mask=mask) tl.store(to_ptr, data, mask=mask) @triton.jit - def put(self, from_ptr, to_ptr, to_rank, mask=None): + def put(self, from_ptr, to_ptr, to_rank, mask=None, hint: tl.constexpr = None): """ Copies data from current rank's local memory to the specified rank's memory. @@ -1540,12 +1526,12 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None): Example: >>> ctx.put(local_ptr + offsets, remote_ptr + offsets, to_rank=1, mask=mask) """ - translated_to_ptr = self._translate(to_ptr, self.rank, to_rank) + translated_to_ptr = self._translate(to_ptr, self.rank, to_rank, hint) data = tl.load(from_ptr, mask=mask) tl.store(translated_to_ptr, data, mask=mask) @triton.jit - def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None): + def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, hint: tl.constexpr = None): """ Copies data from one rank's memory to another rank's memory. @@ -1585,11 +1571,15 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None): translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) + if hint is not None: + translated_src = tl.max_contiguous(tl.multiple_of(translated_src, hint), hint) + translated_dst = tl.max_contiguous(tl.multiple_of(translated_dst, hint), hint) + data = tl.load(translated_src, mask=mask) tl.store(translated_dst, data, mask=mask) @triton.jit - def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic add at the specified rank's memory location. @@ -1612,11 +1602,11 @@ def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Example: >>> old_val = ctx.atomic_add(counter, 1, to_rank=1) """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_sub(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_sub(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Atomically subtracts data from the specified rank's memory location. @@ -1636,11 +1626,11 @@ def atomic_sub(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_cas(self, pointer, cmp, val, to_rank, sem=None, scope=None): + def atomic_cas(self, pointer, cmp, val, to_rank, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic compare-and-swap at the specified rank's memory location. @@ -1661,11 +1651,11 @@ def atomic_cas(self, pointer, cmp, val, to_rank, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) @triton.jit - def atomic_xchg(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_xchg(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic exchange at the specified rank's memory location. @@ -1685,11 +1675,11 @@ def atomic_xchg(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_xor(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_xor(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic XOR at the specified rank's memory location. @@ -1709,11 +1699,11 @@ def atomic_xor(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_and(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_and(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic AND at the specified rank's memory location. @@ -1733,11 +1723,11 @@ def atomic_and(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_or(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_or(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic OR at the specified rank's memory location. @@ -1757,11 +1747,11 @@ def atomic_or(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_min(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_min(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic minimum at the specified rank's memory location. @@ -1781,11 +1771,11 @@ def atomic_min(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit - def atomic_max(self, pointer, val, to_rank, mask=None, sem=None, scope=None): + def atomic_max(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic maximum at the specified rank's memory location. @@ -1805,12 +1795,12 @@ def atomic_max(self, pointer, val, to_rank, mask=None, sem=None, scope=None): Returns: Block: The data stored at pointer before the atomic operation. """ - translated_ptr = self._translate(pointer, self.rank, to_rank) + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def load(pointer, to_rank, from_rank, heap_bases, mask=None): +def load(pointer, to_rank, from_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Loads a value from the specified rank's memory location. @@ -1825,6 +1815,7 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): from_rank (int): The rank ID from which to read the data. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: Block: The loaded value from the target memory location. @@ -1838,13 +1829,13 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): >>> data = iris.load(ptr, cur_rank, remote_rank, heap_bases) >>> return data """ - translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) + translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases, hint) result = tl.load(translated_ptr, mask=mask) return result @triton.jit -def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): +def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Writes data to the specified rank's memory location. @@ -1860,6 +1851,7 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): to_rank (int): The rank ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: None @@ -1873,12 +1865,12 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): >>> value = 42 >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) tl.store(translated_ptr, value, mask=mask) @triton.jit -def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): +def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Copies data from the specified rank's memory into the destination rank's memory. This function performs the transfer by translating `src_ptr` from the `from_rank`'s address @@ -1895,6 +1887,7 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): cur_rank (int): The rank ID issuing the copy operation. Must be either `from_rank` or `to_rank`. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointers. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: None @@ -1924,12 +1917,16 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) + if hint is not None: + translated_src = tl.max_contiguous(tl.multiple_of(translated_src, hint), hint) + translated_dst = tl.max_contiguous(tl.multiple_of(translated_dst, hint), hint) + data = tl.load(translated_src, mask=mask) tl.store(translated_dst, data, mask=mask) @triton.jit -def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): +def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Copies data from the specified rank's memory to the current rank's local memory. @@ -1945,6 +1942,7 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): to_rank (int): The current rank ID where the data will be stored. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: None @@ -1956,7 +1954,7 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): >>> to_rank = 0 >>> iris.get(remote_ptr, local_ptr, from_rank, to_rank, heap_bases) """ - translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases) + translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases, hint) data = tl.load(translated_from_ptr, mask=mask) @@ -1964,7 +1962,7 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): @triton.jit -def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): +def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.constexpr = None): """ Copies data from the current rank's local memory to the specified rank's memory. This function performs a memory write operation by loading data from the current @@ -1979,6 +1977,7 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): to_rank (int): The `to_rank` ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: None @@ -1990,7 +1989,7 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): >>> to_rank = 1 >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) """ - translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) + translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases, hint) data = tl.load(from_ptr, mask=mask) @@ -1998,7 +1997,9 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): @triton.jit -def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_add( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic add at the specified rank's memory location. @@ -2016,6 +2017,7 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2029,12 +2031,14 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> increment = 5 >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_sub( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Atomically subtracts data from the specified rank's memory location. @@ -2052,6 +2056,7 @@ def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The value at the memory location before the atomic subtraction. @@ -2065,12 +2070,12 @@ def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> decrement = 3 >>> old_val = iris.atomic_sub(ptr, decrement, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None): +def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None, hint: tl.constexpr = None): """ Atomically compares and exchanges the specified rank's memory location. @@ -2088,6 +2093,7 @@ def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scop heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The value contained at the memory location before the atomic operation attempt. @@ -2102,12 +2108,14 @@ def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scop >>> new_val = 42 >>> old_val = iris.atomic_cas(ptr, expected, new_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) @triton.jit -def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_xchg( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic exchange at the specified rank's memory location. @@ -2125,6 +2133,7 @@ def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=Non mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2138,12 +2147,14 @@ def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=Non >>> new_value = 99 >>> old_val = iris.atomic_xchg(ptr, new_value, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_xor( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic xor at the specified rank's memory location. @@ -2161,6 +2172,7 @@ def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2174,12 +2186,14 @@ def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> mask_val = 0xFF >>> old_val = iris.atomic_xor(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_and( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic and at the specified rank's memory location. @@ -2197,6 +2211,7 @@ def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2210,12 +2225,12 @@ def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> mask_val = 0x0F >>> old_val = iris.atomic_and(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic or at the specified rank's memory location. @@ -2233,6 +2248,7 @@ def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2246,12 +2262,14 @@ def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, >>> mask_val = 0xF0 >>> old_val = iris.atomic_or(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_min( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic min at the specified rank's memory location. @@ -2269,6 +2287,7 @@ def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2282,12 +2301,14 @@ def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> new_val = 10 >>> old_val = iris.atomic_min(ptr, new_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_max( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic max at the specified rank's memory location. @@ -2305,6 +2326,7 @@ def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2318,7 +2340,7 @@ def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> new_val = 100 >>> old_val = iris.atomic_max(ptr, new_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) diff --git a/tests/examples/test_expert_sharded_moe.py b/tests/examples/test_expert_sharded_moe.py index 1da6e891..3023eb3c 100644 --- a/tests/examples/test_expert_sharded_moe.py +++ b/tests/examples/test_expert_sharded_moe.py @@ -37,7 +37,15 @@ def _load_module(module_name: str, file_path: Path): @pytest.mark.parametrize("n_tokens,d_model,n_expts_act", [(128, 64, 2)]) -@pytest.mark.parametrize("fusion_mode", ["unfused", "fused_grouped_matmul_convert_ep_to_dp"]) +@pytest.mark.parametrize( + "fusion_mode", + [ + "unfused", + "fused_grouped_matmul_convert_ep_to_dp", + "fused_convert_dp_to_ep_grouped_matmul", + "wg_fused_grouped_matmul_convert_ep_to_dp", + ], +) def test_expert_sharded_moe_matches_reference(n_tokens, d_model, n_expts_act, fusion_mode): if not dist.is_initialized(): pytest.skip("torch.distributed not initialized")