-
Notifications
You must be signed in to change notification settings - Fork 37
Port cache_modifier, volatile, and other to DeviceContext and Gluon APIs #471
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
fd2a3aa
4e1fa9b
791e624
b46c8a7
f5cfba2
59f07b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -144,7 +144,7 @@ def _translate(self, ptr, from_rank, to_rank): | |
| return translated_ptr | ||
|
|
||
| @gluon.jit | ||
| def load(self, pointer, from_rank, mask=None, other=None): | ||
| def load(self, pointer, from_rank, mask=None, other=None, cache_modifier=None, volatile=False): | ||
| """ | ||
| Loads a value from the specified rank's memory location to the current rank. | ||
|
|
||
|
|
@@ -153,6 +153,17 @@ def load(self, pointer, from_rank, mask=None, other=None): | |
| from_rank: The rank ID from which to read the data | ||
| mask: Optional mask for conditional loading | ||
| other: Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. | ||
| cache_modifier (str, optional): Controls cache behavior of the load. | ||
|
|
||
| Supported values: | ||
| - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. | ||
| - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. | ||
| - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. | ||
| - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. | ||
| Ensures global coherence by invalidating stale GPU cache lines. | ||
|
|
||
| volatile (bool, optional): If True, disables compiler optimizations that | ||
| could reorder or eliminate the load. Defaults to False. | ||
|
|
||
| Returns: | ||
| The loaded value from the target memory location | ||
|
|
@@ -162,11 +173,11 @@ def load(self, pointer, from_rank, mask=None, other=None): | |
| >>> data = ctx.load(buffer + offsets, 1, mask=mask) | ||
| """ | ||
| translated_ptr = self._translate(pointer, self.cur_rank, from_rank) | ||
| result = gl.load(translated_ptr, mask=mask, other=other) | ||
| result = gl.load(translated_ptr, mask=mask, other=other, cache_modifier=cache_modifier, volatile=volatile) | ||
|
||
| return result | ||
|
|
||
| @gluon.jit | ||
| def store(self, pointer, value, to_rank, mask=None): | ||
| def store(self, pointer, value, to_rank, mask=None, cache_modifier=None): | ||
|
||
| """ | ||
| Writes data from the current rank to the specified rank's memory location. | ||
|
|
||
|
|
@@ -175,16 +186,25 @@ def store(self, pointer, value, to_rank, mask=None): | |
| value: The value to store | ||
| to_rank: The rank ID to which the data will be written | ||
| mask: Optional mask for conditional storing | ||
| cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: | ||
|
|
||
| - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. | ||
| - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. | ||
| - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. | ||
| - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. | ||
| - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. | ||
|
|
||
| Example: | ||
| >>> # Store from current rank to rank 1 | ||
| >>> ctx.store(buffer + offsets, values, 1, mask=mask) | ||
| """ | ||
| translated_ptr = self._translate(pointer, self.cur_rank, to_rank) | ||
| gl.store(translated_ptr, value, mask=mask) | ||
| gl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) | ||
|
|
||
| @gluon.jit | ||
| def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None): | ||
| def get( | ||
| self, from_ptr, to_ptr, from_rank, mask=None, other=None, load_cache_modifier=None, store_cache_modifier=None | ||
| ): | ||
|
Comment on lines
204
to
+207
|
||
| """ | ||
| Copies data from the specified rank's memory to the current rank's local memory. | ||
|
|
||
|
|
@@ -194,17 +214,31 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None): | |
| from_rank: The rank ID from which to read the data | ||
| mask: Optional mask for conditional operations | ||
| other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. | ||
| load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: | ||
| - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. | ||
| - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. | ||
| - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. | ||
| - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. | ||
|
|
||
| store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: | ||
| - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. | ||
| - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. | ||
| - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. | ||
| - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. | ||
| - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. | ||
|
|
||
| Example: | ||
| >>> # Copy from rank 1 to current rank's local memory | ||
| >>> ctx.get(remote_ptr + offsets, local_ptr + offsets, 1, mask=mask) | ||
| """ | ||
| translated_from_ptr = self._translate(from_ptr, self.cur_rank, from_rank) | ||
| data = gl.load(translated_from_ptr, mask=mask, other=other) | ||
| gl.store(to_ptr, data, mask=mask) | ||
| data = gl.load(translated_from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) | ||
| gl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) | ||
|
Comment on lines
+235
to
+236
|
||
|
|
||
| @gluon.jit | ||
| def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None): | ||
| def put( | ||
| self, from_ptr, to_ptr, to_rank, mask=None, other=None, load_cache_modifier=None, store_cache_modifier=None | ||
| ): | ||
| """ | ||
| Copies data from the current rank's local memory to the specified rank's memory. | ||
|
|
||
|
|
@@ -214,17 +248,39 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None): | |
| to_rank: The rank ID to which the data will be written | ||
| mask: Optional mask for conditional operations | ||
| other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. | ||
| load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: | ||
| - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. | ||
| - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. | ||
| - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. | ||
| - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. | ||
|
|
||
| store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: | ||
| - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. | ||
| - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. | ||
| - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. | ||
| - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. | ||
| - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. | ||
|
|
||
| Example: | ||
| >>> # Copy from current rank's local memory to rank 1 | ||
| >>> ctx.put(local_ptr + offsets, remote_ptr + offsets, 1, mask=mask) | ||
| """ | ||
| translated_to_ptr = self._translate(to_ptr, self.cur_rank, to_rank) | ||
| data = gl.load(from_ptr, mask=mask, other=other) | ||
| gl.store(translated_to_ptr, data, mask=mask) | ||
| data = gl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) | ||
| gl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) | ||
|
|
||
| @gluon.jit | ||
| def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None): | ||
| def copy( | ||
| self, | ||
| src_ptr, | ||
| dst_ptr, | ||
| from_rank, | ||
| to_rank, | ||
| mask=None, | ||
| other=None, | ||
| load_cache_modifier=None, | ||
| store_cache_modifier=None, | ||
| ): | ||
| """ | ||
| Copies data from the specified rank's memory into the destination rank's memory. | ||
|
|
||
|
|
@@ -241,6 +297,18 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None): | |
| to_rank: The rank ID that will receive the data (destination rank) | ||
| mask: Optional mask for conditional operations | ||
| other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. | ||
| load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: | ||
| - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. | ||
| - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. | ||
| - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. | ||
| - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. | ||
|
|
||
| store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: | ||
| - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. | ||
| - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. | ||
| - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. | ||
| - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. | ||
| - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. | ||
|
|
||
| Example: | ||
| >>> # Copy from rank 1 to rank 0 (current rank must be either 1 or 0) | ||
|
|
@@ -262,8 +330,8 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None): | |
| translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) | ||
| translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) | ||
|
|
||
| data = gl.load(translated_src, mask=mask, other=other) | ||
| gl.store(translated_dst, data, mask=mask) | ||
| data = gl.load(translated_src, mask=mask, other=other, cache_modifier=load_cache_modifier) | ||
| gl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) | ||
|
|
||
| @gluon.jit | ||
| def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to Triton, Gluon language load/store cache control arguments are typically compile-time constants. These new parameters (
cache_modifier,volatile,load_cache_modifier,store_cache_modifier) are not marked asgl.constexprin@gluon.jitmethods, which can lead to compilation/type errors when passing strings/bools. Consider annotating them asgl.constexpr(and defaulting toNone/Falseas you do now).