diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 004a85ac..f2dc3608 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -142,7 +142,7 @@ def _translate(self, ptr, from_rank, to_rank): return translated_ptr @gluon.jit - def load(self, pointer, from_rank, mask=None): + def load(self, pointer, from_rank, mask=None, other=None): """ Loads a value from the specified rank's memory location to the current rank. @@ -150,6 +150,7 @@ def load(self, pointer, from_rank, mask=None): pointer: Pointer in the `from_rank`'s address space from_rank: The rank ID from which to read the data mask: Optional mask for conditional loading + other: Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. Returns: The loaded value from the target memory location @@ -159,7 +160,7 @@ def load(self, pointer, from_rank, mask=None): >>> data = ctx.load(buffer + offsets, 1, mask=mask) """ translated_ptr = self._translate(pointer, self.cur_rank, from_rank) - result = gl.load(translated_ptr, mask=mask) + result = gl.load(translated_ptr, mask=mask, other=other) return result @gluon.jit @@ -181,7 +182,7 @@ def store(self, pointer, value, to_rank, mask=None): gl.store(translated_ptr, value, mask=mask) @gluon.jit - def get(self, from_ptr, to_ptr, from_rank, mask=None): + def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None): """ Copies data from the specified rank's memory to the current rank's local memory. @@ -190,17 +191,18 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None): to_ptr: Pointer to local memory in current rank from_rank: The rank ID from which to read the data mask: Optional mask for conditional operations + other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Example: >>> # Copy from rank 1 to current rank's local memory >>> ctx.get(remote_ptr + offsets, local_ptr + offsets, 1, mask=mask) """ translated_from_ptr = self._translate(from_ptr, self.cur_rank, from_rank) - data = gl.load(translated_from_ptr, mask=mask) + data = gl.load(translated_from_ptr, mask=mask, other=other) gl.store(to_ptr, data, mask=mask) @gluon.jit - def put(self, from_ptr, to_ptr, to_rank, mask=None): + def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None): """ Copies data from the current rank's local memory to the specified rank's memory. @@ -209,17 +211,18 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None): to_ptr: Pointer to remote memory in `to_rank`'s address space to_rank: The rank ID to which the data will be written mask: Optional mask for conditional operations + other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Example: >>> # Copy from current rank's local memory to rank 1 >>> ctx.put(local_ptr + offsets, remote_ptr + offsets, 1, mask=mask) """ translated_to_ptr = self._translate(to_ptr, self.cur_rank, to_rank) - data = gl.load(from_ptr, mask=mask) + data = gl.load(from_ptr, mask=mask, other=other) gl.store(translated_to_ptr, data, mask=mask) @gluon.jit - def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None): + def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None): """ Copies data from the specified rank's memory into the destination rank's memory. @@ -235,6 +238,7 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None): from_rank: The rank ID that owns `src_ptr` (source rank) to_rank: The rank ID that will receive the data (destination rank) mask: Optional mask for conditional operations + other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Example: >>> # Copy from rank 1 to rank 0 (current rank must be either 1 or 0) @@ -256,7 +260,7 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None): translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) - data = gl.load(translated_src, mask=mask) + data = gl.load(translated_src, mask=mask, other=other) gl.store(translated_dst, data, mask=mask) @gluon.jit diff --git a/iris/iris.py b/iris/iris.py index 3cd96a5a..239fd070 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1725,7 +1725,7 @@ def __translate(ptr, from_rank, to_rank, heap_bases): @triton.jit -def load(pointer, to_rank, from_rank, heap_bases, mask=None): +def load(pointer, to_rank, from_rank, heap_bases, mask=None, other=None, cache_modifier=None, volatile=False): """ Loads a value from the specified rank's memory location. @@ -1734,12 +1734,29 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): data from the target memory location. If the `from_rank` and `to_rank` are the same, this function performs a local load operation. + The `cache_modifier` parameter controls instruction-level cache behavior + by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits + in the global load instruction. These affect cache usage across the CU, + L2, and last-level caches. + Args: pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. to_rank (int): The rank ID to which the pointer will be translated. Must be the current rank where the pointer is local. from_rank (int): The rank ID from which to read the data. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. Defaults to None. + cache_modifier (str, optional): Controls cache behavior of the load. + + Supported values: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + Ensures global coherence by invalidating stale GPU cache lines. + + volatile (bool, optional): If True, disables compiler optimizations that + could reorder or eliminate the load. Returns: Block: The loaded value from the target memory location. @@ -1754,12 +1771,12 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): >>> return data """ translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) - result = tl.load(translated_ptr, mask=mask) + result = tl.load(translated_ptr, mask=mask, other=other, cache_modifier=cache_modifier, volatile=volatile) return result @triton.jit -def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): +def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, cache_modifier=None): """ Writes data to the specified rank's memory location. @@ -1768,6 +1785,11 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): the provided data to the target memory location. If the `from_rank` and `to_rank` are the same, this function performs a local store operation. + The `cache_modifier` parameter controls instruction-level cache behavior + by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits + in the global store instruction. These affect cache usage across the CU (L1), + L2, and last-level cache (LLC), following the CDNA ISA. + Args: pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. value (Block): The tensor of elements to be stored. @@ -1775,6 +1797,13 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): to_rank (int): The rank ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None @@ -1789,11 +1818,22 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - tl.store(translated_ptr, value, mask=mask) + tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) @triton.jit -def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): +def copy( + src_ptr, + dst_ptr, + from_rank, + to_rank, + cur_rank, + heap_bases, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, +): """ Copies data from the specified rank's memory into the destination rank's memory. This function performs the transfer by translating `src_ptr` from the `from_rank`'s address @@ -1810,6 +1850,20 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): cur_rank (int): The rank ID issuing the copy operation. Must be either `from_rank` or `to_rank`. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. + + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None @@ -1839,12 +1893,22 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) - data = tl.load(translated_src, mask=mask) - tl.store(translated_dst, data, mask=mask) + data = tl.load(translated_src, mask=mask, other=other, cache_modifier=load_cache_modifier) + tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit -def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): +def get( + from_ptr, + to_ptr, + from_rank, + to_rank, + heap_bases, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, +): """ Copies data from the specified rank's memory to the current rank's local memory. @@ -1860,6 +1924,20 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): to_rank (int): The current rank ID where the data will be stored. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. + + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None @@ -1873,13 +1951,23 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): """ translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases) - data = tl.load(translated_from_ptr, mask=mask) + data = tl.load(translated_from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) - tl.store(to_ptr, data, mask=mask) + tl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit -def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): +def put( + from_ptr, + to_ptr, + from_rank, + to_rank, + heap_bases, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, +): """ Copies data from the current rank's local memory to the specified rank's memory. This function performs a memory write operation by loading data from the current @@ -1894,6 +1982,20 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): to_rank (int): The `to_rank` ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. + + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None @@ -1907,9 +2009,9 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): """ translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) - data = tl.load(from_ptr, mask=mask) + data = tl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) - tl.store(translated_to_ptr, data, mask=mask) + tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit diff --git a/tests/unittests/test_copy_cache_modifiers.py b/tests/unittests/test_copy_cache_modifiers.py new file mode 100644 index 00000000..b7c278ea --- /dev/null +++ b/tests/unittests/test_copy_cache_modifiers.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris +from itertools import product + + +@triton.jit +def copy_kernel_local_read_remote_write( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Copy from local memory to remote memory (local read, remote write)""" + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + # Copy from current rank to other ranks + for target_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * cur_rank + dest_data = results + BLOCK_SIZE * cur_rank + if load_cache_modifier is None and store_cache_modifier is None: + iris.copy(src_data + offsets, dest_data + offsets, cur_rank, target_rank, cur_rank, heap_bases, mask=mask) + elif load_cache_modifier is None: + iris.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + cur_rank, + heap_bases, + mask=mask, + store_cache_modifier=store_cache_modifier, + ) + elif store_cache_modifier is None: + iris.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + ) + else: + iris.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +@triton.jit +def copy_kernel_remote_read_local_write( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Copy from remote memory to local memory (remote read, local write)""" + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + # Copy from other ranks to current rank + for source_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * source_rank + dest_data = results + BLOCK_SIZE * source_rank + if load_cache_modifier is None and store_cache_modifier is None: + iris.copy(src_data + offsets, dest_data + offsets, source_rank, cur_rank, cur_rank, heap_bases, mask=mask) + elif load_cache_modifier is None: + iris.copy( + src_data + offsets, + dest_data + offsets, + source_rank, + cur_rank, + cur_rank, + heap_bases, + mask=mask, + store_cache_modifier=store_cache_modifier, + ) + elif store_cache_modifier is None: + iris.copy( + src_data + offsets, + dest_data + offsets, + source_rank, + cur_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + ) + else: + iris.copy( + src_data + offsets, + dest_data + offsets, + source_rank, + cur_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +# Define cache modifiers for load and store operations +LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +# Remote stores (cross-GPU IPC) cannot use cache modifier bits +# Only default (None or empty string) works - cache bits break coherency +STORE_CACHE_MODIFIERS_REMOTE_WRITE = [None, ""] +# For testing remote reads (which work with all load modifiers), +# we can use all store modifiers since the store is local +LOAD_CACHE_MODIFIERS_REMOTE_READ = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS_LOCAL_WRITE = [None, "", ".wb", ".cg", ".cs", ".wt"] + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS_REMOTE_WRITE)) +) +def test_copy_local_read_remote_write(load_cache_modifier, store_cache_modifier): + """Test copy: local read → remote write + + Direction: from_rank=cur_rank (local), to_rank=other (remote) + - Load: from LOCAL memory (all cache modifiers should work) + - Store: to REMOTE memory (only None/"" work, cache bits break coherency) + + This tests that load cache modifiers work for local reads. + """ + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + grid = lambda meta: (1,) + copy_kernel_local_read_remote_write[grid]( + data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + + shmem.barrier() + + # Verify results - each rank copies its data to all other ranks + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device) + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", + list(product(LOAD_CACHE_MODIFIERS_REMOTE_READ, STORE_CACHE_MODIFIERS_LOCAL_WRITE)), +) +def test_copy_remote_read_local_write(load_cache_modifier, store_cache_modifier): + """Test copy: remote read → local write + + Direction: from_rank=other (remote), to_rank=cur_rank (local) + - Load: from REMOTE memory (test if cache modifiers work for remote reads) + - Store: to LOCAL memory (all cache modifiers should work) + + This tests whether load cache modifiers work for remote reads. + """ + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + grid = lambda meta: (1,) + copy_kernel_remote_read_local_write[grid]( + data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + + shmem.barrier() + + # Verify results - each rank pulls data from all ranks + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device) + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) diff --git a/tests/unittests/test_get_cache_modifiers.py b/tests/unittests/test_get_cache_modifiers.py new file mode 100644 index 00000000..d256fcd4 --- /dev/null +++ b/tests/unittests/test_get_cache_modifiers.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris +from itertools import product + + +@triton.jit +def get_kernel( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + acc = tl.zeros([BLOCK_SIZE], dtype=data.type.element_ty) + + # Loop over all ranks, get the stored data with cache modifiers + # We test default values set by the function when parameters are None + for target_rank in range(num_ranks): + if load_cache_modifier is None and store_cache_modifier is None: + iris.get(data + offsets, results + offsets, target_rank, cur_rank, heap_bases, mask=mask) + elif load_cache_modifier is None: + iris.get( + data + offsets, + results + offsets, + target_rank, + cur_rank, + heap_bases, + mask=mask, + store_cache_modifier=store_cache_modifier, + ) + elif store_cache_modifier is None: + iris.get( + data + offsets, + results + offsets, + target_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + ) + else: + iris.get( + data + offsets, + results + offsets, + target_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + acc += tl.load(results + offsets, mask=mask) + + # Store the accumulated value back to the output + tl.store(results + offsets, acc, mask=mask) + + +# Define cache modifiers for load and store operations +LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_get_cache_modifiers(load_cache_modifier, store_cache_modifier): + """Test get (copy from other rank) with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = lambda meta: (1,) + get_kernel[grid]( + data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + shmem.barrier() + + # Verify the result - should get data from all ranks (including self) + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * num_ranks + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"GET test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + print("Expected:", expected) + print("Actual:", results) + raise diff --git a/tests/unittests/test_get_other_triton.py b/tests/unittests/test_get_other_triton.py new file mode 100644 index 00000000..6b83ac0d --- /dev/null +++ b/tests/unittests/test_get_other_triton.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris + + +@triton.jit +def get_with_other_kernel( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + other_value: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Create a mask that is False for half the elements + mask = offsets < BLOCK_SIZE // 2 + + acc = tl.zeros([BLOCK_SIZE], dtype=data.type.element_ty) + + # Loop over all ranks, get the stored data. + # load to local register, accumulate. + for target_rank in range(num_ranks): + iris.get(data + offsets, results + offsets, target_rank, cur_rank, heap_bases, mask=mask, other=other_value) + acc += tl.load(results + offsets) + + # Store the accumulated value back to the output. + tl.store(results + offsets, acc) + + +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "BLOCK_SIZE", + [ + 8, + 16, + 32, + ], +) +def test_get_other_api(dtype, BLOCK_SIZE): + # TODO: Adjust heap size. + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + data = shmem.ones(BLOCK_SIZE, dtype=dtype) + results = shmem.zeros_like(data) + + # Use -1 as the "other" value for masked-out elements + other_value = -1.0 + + shmem.barrier() + + grid = lambda meta: (1,) + get_with_other_kernel[grid](data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, other_value) + shmem.barrier() + + # Verify the results + # First half should contain loaded values accumulated from all ranks (num_ranks * 1.0) + # Second half should contain accumulated "other" values (num_ranks * -1.0) + expected = torch.zeros(BLOCK_SIZE, dtype=dtype, device="cuda") + expected[: BLOCK_SIZE // 2] = num_ranks * 1.0 + expected[BLOCK_SIZE // 2 :] = num_ranks * other_value + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + finally: + # Final barrier to ensure all ranks complete before test cleanup + # This helps with test isolation when running multiple tests + # Note: shmem.barrier() already does cuda.synchronize() + shmem.barrier() + # Explicitly delete the shmem instance to trigger cleanup + del shmem + # Force garbage collection to ensure IPC handles are cleaned up + import gc + + gc.collect() diff --git a/tests/unittests/test_get_triton.py b/tests/unittests/test_get_triton.py index b19cf235..f4004691 100644 --- a/tests/unittests/test_get_triton.py +++ b/tests/unittests/test_get_triton.py @@ -31,7 +31,7 @@ def get_kernel( # Loop over all ranks, get the stored data. # load to local register, accumulate. for target_rank in range(num_ranks): - iris.get(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask) + iris.get(data + offsets, results + offsets, target_rank, cur_rank, heap_bases, mask=mask) acc += tl.load(results + offsets, mask=mask) # Store the accumulated value back to the output. diff --git a/tests/unittests/test_load_cache_modifiers.py b/tests/unittests/test_load_cache_modifiers.py new file mode 100644 index 00000000..5c147300 --- /dev/null +++ b/tests/unittests/test_load_cache_modifiers.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris +from itertools import product + + +@triton.jit +def kernel( + data, + results, + source_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + cache_modifier: tl.constexpr, + volatile: tl.constexpr, +): + pid = tl.program_id(0) + + partner = int((source_rank + num_ranks // 2) % num_ranks) + # Compute start index of this block + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Guard for out-of-bounds accesses + mask = offsets < BLOCK_SIZE + + if cache_modifier is None: + result = iris.load(data + offsets, source_rank, partner, heap_bases, mask=mask, volatile=volatile) + else: + result = iris.load( + data + offsets, + source_rank, + partner, + heap_bases, + mask=mask, + cache_modifier=cache_modifier, + volatile=volatile, + ) + + tl.store(results + offsets, result, mask=mask) + + +# Define cache modifiers and volatile options +CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +VOLATILE_OPTIONS = [False, True] + + +@pytest.mark.parametrize("cache_modifier,volatile", list(product(CACHE_MODIFIERS, VOLATILE_OPTIONS))) +def test_load_cache_modifiers(cache_modifier, volatile): + """Test load with various cache modifiers and volatile settings.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + source_rank = shmem.get_rank() + partner = int((source_rank + num_ranks // 2) % num_ranks) + + BLOCK_SIZE = 16 + data = shmem.full((BLOCK_SIZE,), source_rank, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = lambda meta: (1,) + kernel[grid](data, results, source_rank, num_ranks, BLOCK_SIZE, heap_bases, cache_modifier, volatile) + shmem.barrier() + + # Verify the result + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * partner + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise diff --git a/tests/unittests/test_load_other_triton.py b/tests/unittests/test_load_other_triton.py new file mode 100644 index 00000000..f9fe0e00 --- /dev/null +++ b/tests/unittests/test_load_other_triton.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris + + +@triton.jit +def load_with_other_kernel( + data, + results, + source_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + other_value: tl.constexpr, +): + pid = tl.program_id(0) + + partner = int((source_rank + num_ranks // 2) % num_ranks) + # Compute start index of this block + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Create a mask that is False for half the elements + mask = offsets < BLOCK_SIZE // 2 + + # Load with mask and other parameter + result = iris.load(data + offsets, source_rank, partner, heap_bases, mask=mask, other=other_value) + tl.store(results + offsets, result) + + +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "BLOCK_SIZE", + [ + 8, + 16, + 32, + ], +) +def test_load_other_api(dtype, BLOCK_SIZE): + # TODO: Adjust heap size. + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + source_rank = shmem.get_rank() + partner = int((source_rank + num_ranks // 2) % num_ranks) + + # Fill data with partner rank value + data = shmem.full((BLOCK_SIZE,), partner, dtype=dtype) + results = shmem.zeros_like(data) + + # Use -1 as the "other" value for masked-out elements + other_value = -1.0 + + shmem.barrier() + + grid = lambda meta: (1,) + load_with_other_kernel[grid](data, results, source_rank, num_ranks, BLOCK_SIZE, heap_bases, other_value) + shmem.barrier() + + # Verify the result + # First half should contain loaded values (partner rank) + # Second half should contain the "other" value (-1.0) + expected = torch.zeros(BLOCK_SIZE, dtype=dtype, device="cuda") + expected[: BLOCK_SIZE // 2] = partner + expected[BLOCK_SIZE // 2 :] = other_value + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + finally: + # Final barrier to ensure all ranks complete before test cleanup + # This helps with test isolation when running multiple tests + # Note: shmem.barrier() already does cuda.synchronize() + shmem.barrier() + # Explicitly delete the shmem instance to trigger cleanup + del shmem + # Force garbage collection to ensure IPC handles are cleaned up + import gc + + gc.collect() diff --git a/tests/unittests/test_put_cache_modifiers.py b/tests/unittests/test_put_cache_modifiers.py new file mode 100644 index 00000000..01b48037 --- /dev/null +++ b/tests/unittests/test_put_cache_modifiers.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris +from itertools import product + + +@triton.jit +def put_kernel( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + # Put data to all ranks with cache modifiers + # We test default values set by the function when parameters are None + for target_rank in range(num_ranks): + if load_cache_modifier is None and store_cache_modifier is None: + iris.put(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask) + elif load_cache_modifier is None: + iris.put( + data + offsets, + results + offsets, + cur_rank, + target_rank, + heap_bases, + mask=mask, + store_cache_modifier=store_cache_modifier, + ) + elif store_cache_modifier is None: + iris.put( + data + offsets, + results + offsets, + cur_rank, + target_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + ) + else: + iris.put( + data + offsets, + results + offsets, + cur_rank, + target_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +# Define cache modifiers for load and store operations +LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_put_cache_modifiers(load_cache_modifier, store_cache_modifier): + """Test put (copy to other rank) with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = lambda meta: (1,) + put_kernel[grid]( + data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + shmem.barrier() + + # Verify the result - should have the data that was put + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + print("Expected:", expected) + print("Actual:", results) + raise diff --git a/tests/unittests/test_put_other_triton.py b/tests/unittests/test_put_other_triton.py new file mode 100644 index 00000000..db78bfc3 --- /dev/null +++ b/tests/unittests/test_put_other_triton.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris + + +@triton.jit +def put_with_other_kernel( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + other_value: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Create a mask that is False for half the elements + mask = offsets < BLOCK_SIZE // 2 + + # Put data in all ranks with mask and other parameter + # The "other" value will be used for masked-out elements during the load from data + for target_rank in range(num_ranks): + iris.put(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask, other=other_value) + + +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "BLOCK_SIZE", + [ + 8, + 16, + 32, + ], +) +def test_put_other_api(dtype, BLOCK_SIZE): + # TODO: Adjust heap size. + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + # Fill data with ones + data = shmem.ones(BLOCK_SIZE, dtype=dtype) + results = shmem.zeros_like(data) + + # Use -1 as the "other" value for masked-out elements + other_value = -1.0 + + shmem.barrier() + + grid = lambda meta: (1,) + put_with_other_kernel[grid](data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, other_value) + shmem.barrier() + + # Verify the results + # First half should contain the value 1.0 (from data) + # Second half should contain the "other" value (-1.0) since mask was False + expected = torch.zeros(BLOCK_SIZE, dtype=dtype, device="cuda") + expected[: BLOCK_SIZE // 2] = 1.0 + expected[BLOCK_SIZE // 2 :] = other_value + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + finally: + # Final barrier to ensure all ranks complete before test cleanup + # This helps with test isolation when running multiple tests + # Note: shmem.barrier() already does cuda.synchronize() + shmem.barrier() + # Explicitly delete the shmem instance to trigger cleanup + del shmem + # Force garbage collection to ensure IPC handles are cleaned up + import gc + + gc.collect() diff --git a/tests/unittests/test_store_cache_modifiers.py b/tests/unittests/test_store_cache_modifiers.py new file mode 100644 index 00000000..892a09fb --- /dev/null +++ b/tests/unittests/test_store_cache_modifiers.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris + + +@triton.jit +def kernel( + data, + results, + destination_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + cache_modifier: tl.constexpr, +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < BLOCK_SIZE + + # Load the data from src for this block + value = tl.load(data + offsets, mask=mask) + + # Store data to all ranks with the specified cache modifier + for dst_rank in range(num_ranks): + if cache_modifier is None: + iris.store(results + offsets, value, destination_rank, dst_rank, heap_bases, mask=mask) + else: + iris.store( + results + offsets, + value, + destination_rank, + dst_rank, + heap_bases, + mask=mask, + cache_modifier=cache_modifier, + ) + + +# Define cache modifiers for store operations +# Based on the provided cache modifier descriptions +CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] + + +@pytest.mark.parametrize("cache_modifier", CACHE_MODIFIERS) +def test_store_cache_modifiers(cache_modifier): + """Test store with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + destination_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + src = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros_like(src) + + shmem.barrier() + + grid = lambda meta: (1,) + kernel[grid](src, results, destination_rank, num_ranks, BLOCK_SIZE, heap_bases, cache_modifier) + shmem.barrier() + + # Verify the result + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise