Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
ba7028b
Add initial cache modifiers code and docs
mawad-amd Sep 10, 2025
276713b
Add test
mawad-amd Sep 10, 2025
6f6818f
Apply Ruff auto-fixes
github-actions[bot] Sep 10, 2025
9ad63a0
Use `None` for default value
mawad-amd Sep 13, 2025
677c966
Apply Ruff auto-fixes
github-actions[bot] Sep 13, 2025
af3592d
Cleanup the test
mawad-amd Sep 14, 2025
8a411a2
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
ff26f96
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Sep 14, 2025
8f76d95
Check return value
mawad-amd Sep 14, 2025
87fb74a
Remove volatile from store
mawad-amd Sep 14, 2025
162ec39
Add test store
mawad-amd Sep 14, 2025
01da6ca
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
9a27ead
Add put/get modifiers
mawad-amd Sep 14, 2025
99ee66c
Add tests for put and get cache modifiers
mawad-amd Sep 14, 2025
74d0133
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
e76d4c5
Test default values
mawad-amd Sep 14, 2025
451ee99
Fix default value docstring
mawad-amd Sep 14, 2025
b524f40
Fix tests
mawad-amd Sep 14, 2025
b8bd8a7
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
0a157b5
Sync cache modifiers branch with main and add cache modifiers to copy…
Copilot Oct 11, 2025
b127b91
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Oct 11, 2025
c9f314f
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Oct 24, 2025
e74aacd
Fix device mismatch in test_copy_cache_modifiers assertions (#271)
Copilot Oct 29, 2025
88970ee
Fix pointer arithmetic in test_copy_cache_modifiers (#273)
Copilot Oct 30, 2025
e3426f5
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Dec 19, 2025
1827481
Disable remote store cache modifiers
mawad-amd Dec 28, 2025
1edad3f
Add tests for both local read/remote
mawad-amd Dec 28, 2025
369d854
Apply Ruff auto-fixes
github-actions[bot] Dec 28, 2025
3c9cc49
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Jan 31, 2026
ed5b397
Add `other` parameter to distributed memory operations (#343)
Copilot Feb 3, 2026
b218be0
Merge branch 'main' into muhaawad/cache-modifiers
neoblizz Feb 4, 2026
e2a4bb6
Initial plan
Copilot Feb 4, 2026
20f9990
Fix parameter order in test_get_other_triton.py
Copilot Feb 4, 2026
b1c357e
Fix parameter order in test_get_cache_modifiers.py and test_get_trito…
Copilot Feb 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions iris/experimental/iris_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,15 @@ def _translate(self, ptr, from_rank, to_rank):
return translated_ptr

@gluon.jit
def load(self, pointer, from_rank, mask=None):
def load(self, pointer, from_rank, mask=None, other=None):
"""
Loads a value from the specified rank's memory location to the current rank.

Args:
pointer: Pointer in the `from_rank`'s address space
from_rank: The rank ID from which to read the data
mask: Optional mask for conditional loading
other: Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined.

Returns:
The loaded value from the target memory location
Expand All @@ -159,7 +160,7 @@ def load(self, pointer, from_rank, mask=None):
>>> data = ctx.load(buffer + offsets, 1, mask=mask)
"""
translated_ptr = self._translate(pointer, self.cur_rank, from_rank)
result = gl.load(translated_ptr, mask=mask)
result = gl.load(translated_ptr, mask=mask, other=other)
return result

@gluon.jit
Expand All @@ -181,7 +182,7 @@ def store(self, pointer, value, to_rank, mask=None):
gl.store(translated_ptr, value, mask=mask)

@gluon.jit
def get(self, from_ptr, to_ptr, from_rank, mask=None):
def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None):
"""
Copies data from the specified rank's memory to the current rank's local memory.

Expand All @@ -190,17 +191,18 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None):
to_ptr: Pointer to local memory in current rank
from_rank: The rank ID from which to read the data
mask: Optional mask for conditional operations
other: Value substituted for masked-out elements during the load operation. If not provided, the loaded value for masked-out elements is undefined. Masked-out elements are never written to `to_ptr`, because the store uses the same mask.

Example:
>>> # Copy from rank 1 to current rank's local memory
>>> ctx.get(remote_ptr + offsets, local_ptr + offsets, 1, mask=mask)
"""
translated_from_ptr = self._translate(from_ptr, self.cur_rank, from_rank)
data = gl.load(translated_from_ptr, mask=mask)
data = gl.load(translated_from_ptr, mask=mask, other=other)
gl.store(to_ptr, data, mask=mask)

@gluon.jit
def put(self, from_ptr, to_ptr, to_rank, mask=None):
def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None):
"""
Copies data from the current rank's local memory to the specified rank's memory.

Expand All @@ -209,17 +211,18 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None):
to_ptr: Pointer to remote memory in `to_rank`'s address space
to_rank: The rank ID to which the data will be written
mask: Optional mask for conditional operations
other: Value substituted for masked-out elements during the load operation. If not provided, the loaded value for masked-out elements is undefined. Masked-out elements are never written to `to_ptr`, because the store uses the same mask.

Example:
>>> # Copy from current rank's local memory to rank 1
>>> ctx.put(local_ptr + offsets, remote_ptr + offsets, 1, mask=mask)
"""
translated_to_ptr = self._translate(to_ptr, self.cur_rank, to_rank)
data = gl.load(from_ptr, mask=mask)
data = gl.load(from_ptr, mask=mask, other=other)
gl.store(translated_to_ptr, data, mask=mask)

@gluon.jit
def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None):
def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None):
"""
Copies data from the specified rank's memory into the destination rank's memory.

Expand All @@ -235,6 +238,7 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None):
from_rank: The rank ID that owns `src_ptr` (source rank)
to_rank: The rank ID that will receive the data (destination rank)
mask: Optional mask for conditional operations
other: Value substituted for masked-out elements during the load operation. If not provided, the loaded value for masked-out elements is undefined. Masked-out elements are never written to `dst_ptr`, because the store uses the same mask.

Example:
>>> # Copy from rank 1 to rank 0 (current rank must be either 1 or 0)
Expand All @@ -256,7 +260,7 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None):
translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype)
translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype)

data = gl.load(translated_src, mask=mask)
data = gl.load(translated_src, mask=mask, other=other)
gl.store(translated_dst, data, mask=mask)

@gluon.jit
Expand Down
128 changes: 115 additions & 13 deletions iris/iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -1725,7 +1725,7 @@ def __translate(ptr, from_rank, to_rank, heap_bases):


@triton.jit
def load(pointer, to_rank, from_rank, heap_bases, mask=None):
def load(pointer, to_rank, from_rank, heap_bases, mask=None, other=None, cache_modifier=None, volatile=False):
"""
Loads a value from the specified rank's memory location.

Expand All @@ -1734,12 +1734,29 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None):
data from the target memory location. If the `from_rank` and `to_rank` are the same,
this function performs a local load operation.

The `cache_modifier` parameter controls instruction-level cache behavior
by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits
in the global load instruction. These affect cache usage across the CU (L1),
L2, and last-level cache (LLC).

Args:
pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local.
to_rank (int): The rank ID to which the pointer will be translated. Must be the current rank where the pointer is local.
from_rank (int): The rank ID from which to read the data.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None.
other (Block, optional): Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. Defaults to None.
cache_modifier (str, optional): Controls cache behavior of the load.

Supported values:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.
Ensures global coherence by invalidating stale GPU cache lines.

volatile (bool, optional): If True, disables compiler optimizations that
could reorder or eliminate the load. Defaults to False.

Returns:
Block: The loaded value from the target memory location.
Expand All @@ -1754,12 +1771,12 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None):
>>> return data
"""
translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases)
result = tl.load(translated_ptr, mask=mask)
result = tl.load(translated_ptr, mask=mask, other=other, cache_modifier=cache_modifier, volatile=volatile)
return result


@triton.jit
def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, cache_modifier=None):
"""
Writes data to the specified rank's memory location.

Expand All @@ -1768,13 +1785,25 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
the provided data to the target memory location. If the `from_rank` and `to_rank` are the same,
this function performs a local store operation.

The `cache_modifier` parameter controls instruction-level cache behavior
by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits
in the global store instruction. These affect cache usage across the CU (L1),
L2, and last-level cache (LLC), following the CDNA ISA.

Args:
pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. The `from_rank` must be the current rank, where the pointer is local.
value (Block): The tensor of elements to be stored.
from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local.
to_rank (int): The rank ID to which the data will be written.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None.
cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:

- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None
Expand All @@ -1789,11 +1818,22 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
>>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases)
"""
translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases)
tl.store(translated_ptr, value, mask=mask)
tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier)


@triton.jit
def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
def copy(
src_ptr,
dst_ptr,
from_rank,
to_rank,
cur_rank,
heap_bases,
mask=None,
other=None,
load_cache_modifier=None,
store_cache_modifier=None,
):
"""
Copies data from the specified rank's memory into the destination rank's memory.
This function performs the transfer by translating `src_ptr` from the `from_rank`'s address
Expand All @@ -1810,6 +1850,20 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
cur_rank (int): The rank ID issuing the copy operation. Must be either `from_rank` or `to_rank`.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None.
other (Block, optional): Value substituted for masked-out elements during the load operation. If not provided, the loaded value for masked-out elements is undefined. Masked-out elements are never stored to dst_ptr, because the store uses the same mask. Defaults to None.

load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None
Expand Down Expand Up @@ -1839,12 +1893,22 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype)
translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype)

data = tl.load(translated_src, mask=mask)
tl.store(translated_dst, data, mask=mask)
data = tl.load(translated_src, mask=mask, other=other, cache_modifier=load_cache_modifier)
tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier)


@triton.jit
def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
def get(
from_ptr,
to_ptr,
from_rank,
to_rank,
heap_bases,
mask=None,
other=None,
load_cache_modifier=None,
store_cache_modifier=None,
):
"""
Copies data from the specified rank's memory to the current rank's local memory.

Expand All @@ -1860,6 +1924,20 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
to_rank (int): The current rank ID where the data will be stored.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None.
other (Block, optional): Value substituted for masked-out elements during the load operation. If not provided, the loaded value for masked-out elements is undefined. Masked-out elements are never stored to to_ptr, because the store uses the same mask. Defaults to None.

load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None
Expand All @@ -1873,13 +1951,23 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
"""
translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases)

data = tl.load(translated_from_ptr, mask=mask)
data = tl.load(translated_from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier)

tl.store(to_ptr, data, mask=mask)
tl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier)


@triton.jit
def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
def put(
from_ptr,
to_ptr,
from_rank,
to_rank,
heap_bases,
mask=None,
other=None,
load_cache_modifier=None,
store_cache_modifier=None,
):
"""
Copies data from the current rank's local memory to the specified rank's memory.
This function performs a memory write operation by loading data from the current
Expand All @@ -1894,6 +1982,20 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
to_rank (int): The `to_rank` ID to which the data will be written.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None.
other (Block, optional): Value substituted for masked-out elements during the load operation. If not provided, the loaded value for masked-out elements is undefined. Masked-out elements are never stored to to_ptr, because the store uses the same mask. Defaults to None.

load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None
Expand All @@ -1907,9 +2009,9 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
"""
translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases)

data = tl.load(from_ptr, mask=mask)
data = tl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier)

tl.store(translated_to_ptr, data, mask=mask)
tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier)


@triton.jit
Expand Down
Loading
Loading