[Feature] Add TIR builtins for warp-level vote and block-level predicate sync #1858
LeiWang1999 merged 3 commits into tile-ai:main
📝 Walkthrough
Adds TileLang GPU synchronization, warp vote/ballot, warp-match, and shuffle intrinsics; registers new builtins; implements CUDA/HIP codegen paths; re-exports new APIs; extends docs; and adds CUDA-focused unit tests validating emission and runtime behavior.
Sequence Diagram(s)
sequenceDiagram
participant Frontend as TileLang frontend
participant Builtins as TL builtins registry
participant CodeGen as CodeGen (CUDA / HIP)
participant Device as GPU runtime
Frontend->>Builtins: create CallNode for intrinsic (tl.*)
Frontend->>CodeGen: VisitExpr_(CallNode)
CodeGen->>CodeGen: validate args, choose target lowering (CUDA/HIP)
CodeGen->>Device: emit kernel with target intrinsics (__any_sync, __ballot, __shfl, etc.)
Device->>Device: execute warp/block intrinsic
Device-->>Frontend: results stored to output buffers
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Actionable comments posted: 2
🧹 Nitpick comments (1)
tilelang/language/builtin.py (1)
941-959: `ballot()` duplicates logic instead of delegating to `ballot_sync()`.
The docstring says "convenience wrapper around `ballot_sync`", but the implementation reimplements the HIP/CUDA branching rather than calling `ballot_sync(0xFFFFFFFF, predicate)`. If the `ballot_sync` implementation ever changes (e.g., a different cast strategy), `ballot` will silently diverge.
♻️ Proposed refactor to delegate to ballot_sync
     def ballot(predicate: int | PrimExpr) -> PrimExpr:
         ...
    -    if _IS_HIP_AVAILABLE:
    -        return tir.cast("uint32", tir.call_extern("uint64", "__ballot", predicate))
    -    else:
    -        return tir.call_extern("uint32", "__ballot_sync", tir.const(0xFFFFFFFF, "uint32"), predicate)
    +    return ballot_sync(tir.const(0xFFFFFFFF, "uint32"), predicate)
Similarly, `activemask()` on HIP (line 972) reimplements the same `__ballot` + cast pattern. Consider delegating:
     def activemask() -> PrimExpr:
         ...
    -    if _IS_HIP_AVAILABLE:
    -        return tir.cast("uint32", tir.call_extern("uint64", "__ballot", tir.const(1, "int32")))
    -    else:
    +    if not _IS_HIP_AVAILABLE:
             return tir.call_extern("uint32", "__activemask")
    +    return ballot(tir.const(1, "int32"))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/builtin.py` around lines 941 - 959, The ballot() function duplicates the HIP/CUDA branching instead of delegating to ballot_sync; change ballot to simply return ballot_sync(tir.const(0xFFFFFFFF, "uint32"), predicate) (ensuring types match) so any future ballot_sync changes apply here; likewise update activemask() to delegate to ballot_sync/activemask_sync equivalent instead of reimplementing the __ballot + cast pattern (use the same uint32/uint64 casting behavior that ballot_sync already implements) and remove the duplicated call_extern/cast logic in ballot and activemask while keeping their public signatures unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@testing/python/language/test_tilelang_language_warp_vote.py`:
- Around line 26-40: The test declares an unused tensor parameter A in
kernel_any_sync.main and has an incorrect docstring; remove A from the kernel
signature (delete the T.Tensor((1,), "int32") parameter and any references to
A), update the docstring to accurately describe that lane 0 participates in the
any_sync predicate (e.g., "Lane 0 sets the predicate; all lanes use any_sync to
see if any lane matched"), and update any test invocation that passed an A
argument so it only supplies the B tensor to kernel_any_sync.
In `@tilelang/language/builtin.py`:
- Around line 918-938: The current ballot_sync function truncates HIP's 64-bit
(__ballot) result to uint32; modify the HIP branch (where _IS_HIP_AVAILABLE is
checked) to return the raw uint64 result via tir.call_extern("uint64",
"__ballot", predicate) instead of tir.cast("uint32", ...), and update the
function docstring and the inline comment to state that on HIP the function
returns uint64 (callers should narrow to uint32 when they know wavefronts are
32-wide); reference symbols: ballot_sync, _IS_HIP_AVAILABLE, tir.call_extern,
tir.cast.
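To make the truncation concern concrete, here is a small plain-Python sketch. It does not use TileLang or HIP; `ballot_mask` is a hypothetical helper that only models the bit layout of a ballot result (bit i set when lane i votes true), under the assumption of a 64-lane wavefront.

```python
def ballot_mask(predicates):
    """Pack per-lane predicates into an integer bitmask (bit i = lane i)."""
    mask = 0
    for lane, pred in enumerate(predicates):
        if pred:
            mask |= 1 << lane
    return mask


# Assumed wave64 scenario: only lanes 40 and 63 vote true.
wave64 = [lane in (40, 63) for lane in range(64)]

full = ballot_mask(wave64)       # what a 64-bit ballot result would carry
truncated = full & 0xFFFFFFFF    # what survives a cast to uint32

print(hex(full))       # 0x8000010000000000
print(hex(truncated))  # 0x0
```

In this scenario every voting lane sits in the upper half of the wavefront, so the narrowed value reports that no lane voted at all.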
Actionable comments posted: 2
🧹 Nitpick comments (1)
tilelang/language/__init__.py (1)
103-110: Optional: Remove redundant `# noqa: F401` directives.
Ruff flags these as unused since `F401` is not enabled in the project's Ruff configuration. The same pattern exists on lines 94–102 (pre-existing, not flagged here). Consider a one-time cleanup of all such directives across lines 94–110 once the existing pattern is addressed.
🧹 Proposed cleanup (lines 103–110 only)
    -from .builtin import any_sync as any_sync  # noqa: F401
    -from .builtin import all_sync as all_sync  # noqa: F401
    -from .builtin import ballot_sync as ballot_sync  # noqa: F401
    -from .builtin import ballot as ballot  # noqa: F401
    -from .builtin import activemask as activemask  # noqa: F401
    -from .builtin import syncthreads_count as syncthreads_count  # noqa: F401
    -from .builtin import syncthreads_and as syncthreads_and  # noqa: F401
    -from .builtin import syncthreads_or as syncthreads_or  # noqa: F401
    +from .builtin import any_sync as any_sync
    +from .builtin import all_sync as all_sync
    +from .builtin import ballot_sync as ballot_sync
    +from .builtin import ballot as ballot
    +from .builtin import activemask as activemask
    +from .builtin import syncthreads_count as syncthreads_count
    +from .builtin import syncthreads_and as syncthreads_and
    +from .builtin import syncthreads_or as syncthreads_or
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/__init__.py` around lines 103 - 110, Remove the redundant "# noqa: F401" directives from the import statements that re-export builtin symbols; specifically edit the import lines that reference any_sync, all_sync, ballot_sync, ballot, activemask, syncthreads_count, syncthreads_and, and syncthreads_or in __init__.py and delete the trailing " # noqa: F401" from each import; ensure the names remain imported (no other code changes) and run the linter/flake check to confirm no warnings remain.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/programming_guides/instructions.md`:
- Around line 155-158: Add a brief HIP version note stating that
T.syncthreads_count, T.syncthreads_and, and T.syncthreads_or require ROCm/HIP ≥
7.0 (they are only under development in HIP ≤ 6.2); update the three entries
(T.syncthreads_count(predicate), T.syncthreads_and(predicate),
T.syncthreads_or(predicate)) to append “HIP: ROCm/hip ≥ 7.0” or expand the
existing HIP note that covers vote/ballot caveats to explicitly call out these
predicated block-wide sync functions and their minimum HIP version requirement.
In `@tilelang/language/builtin.py`:
- Around line 982-1024: The docstrings for syncthreads_count, syncthreads_and,
and syncthreads_or incorrectly claim the intrinsics map to CUDA and HIP
unconditionally; update each function's docstring (syncthreads_count,
syncthreads_and, syncthreads_or) to note the HIP/ROCm version constraints: that
the __syncthreads_count/and/or variants are under development in HIP 6.2 and
that __sync variants became available in ROCm 7.0 (and are enabled by default
there), so on older HIP/ROCm stacks these intrinsics may not be available; keep
the CUDA mapping text unchanged and add a brief sentence about the HIP version
caveat.
---
Duplicate comments:
In `@testing/python/language/test_tilelang_language_warp_vote.py`:
- Around line 26-52: The kernel declares A in kernel_any_sync -> main but never
uses it and the docstring is wrong; either remove the unused parameter A from
main and update the docstring and test_any_sync to stop allocating/passing a, or
implement the intended behavior: have lane 0 write 1 into A (e.g., inside
T.Kernel when tx == 0 set A[0]=1) and change the any_sync predicate to read
A[0]==1 so the call sites (test_any_sync's allocation/passing of a) remain
correct; update kernel_any_sync's docstring accordingly.
In `@tilelang/language/builtin.py`:
- Around line 918-974: The HIP implementations of ballot_sync, ballot, and
activemask incorrectly cast the uint64 result of __ballot to uint32, discarding
lanes 32–63 on wave-64 targets; update the HIP paths in ballot_sync, ballot, and
activemask to return 64-bit masks (use tir.call_extern with "uint64" and return
a uint64 PrimExpr, and avoid casting to "uint32"), and ensure ballot_sync's mask
parameter and any callers that expect a 32-bit return are adjusted to use 64-bit
mask types so full wave-64 lane data is preserved.
Block-wide predicated sync
- `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`).
- `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`).
- `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`).
Missing HIP version note for syncthreads_count/and/or.
The block-wide predicated sync entries do not mention the ROCm ≥ 7.0 requirement noted in the PR description. The __syncthreads_count(int), __syncthreads_and(int), and __syncthreads_or(int) functions are under development in HIP 6.2 and below. The HIP note at line 166 addresses vote/ballot HIP caveats but omits this. Consider appending a brief version note to the syncthreads_count/and/or entries or expanding the HIP note to cover them.
📝 Suggested doc patch
Block-wide predicated sync
-- `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`).
-- `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`).
-- `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`).
+- `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`). *(HIP: requires ROCm ≥ 7.0)*
+- `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`). *(HIP: requires ROCm ≥ 7.0)*
+- `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`). *(HIP: requires ROCm ≥ 7.0)*
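For readers unfamiliar with the three predicated barriers documented in this patch, the following plain-Python sketch models only the value each one returns for a block, given every thread's predicate. It is an illustration under assumptions, not TileLang code: the barrier itself is not modeled, the function names merely mirror the documented ones, and the real intrinsics are only specified to return a non-zero value for `and`/`or`, which this sketch normalizes to 1.

```python
def syncthreads_count(predicates):
    """Number of threads in the block whose predicate is non-zero."""
    return sum(1 for p in predicates if p != 0)


def syncthreads_and(predicates):
    """1 if every thread's predicate is non-zero, else 0."""
    return int(all(p != 0 for p in predicates))


def syncthreads_or(predicates):
    """1 if at least one thread's predicate is non-zero, else 0."""
    return int(any(p != 0 for p in predicates))


# Assumed example block: 128 threads, predicate true on every fourth thread.
block = [int(tid % 4 == 0) for tid in range(128)]

print(syncthreads_count(block))  # 32
print(syncthreads_and(block))    # 0
print(syncthreads_or(block))     # 1
```

Every thread in the block receives the same return value, which is what makes these calls usable for block-wide early-exit decisions.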
@sepcnt Thanks and would be better to wrap the instruction with |
🧹 Nitpick comments (2)
tilelang/language/builtin.py (2)
881-884: Minor: Replace EN DASH with HYPHEN-MINUS in docstring.
Static analysis flagged an ambiguous EN DASH (`–`) character at line 882. For consistency and to avoid potential encoding issues, use a standard hyphen-minus (`-`).
📝 Suggested fix
     CUDA: ``__any_sync(mask, predicate)``
    -HIP: ``__any(predicate)`` (mask argument is ignored – HIP uses the full
    +HIP: ``__any(predicate)`` (mask argument is ignored - HIP uses the full
     wavefront).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/builtin.py` around lines 881 - 884, Replace the ambiguous EN DASH with a standard hyphen-minus in the docstring line describing HIP behavior: change the phrase "mask argument is ignored – HIP uses the full wavefront" to use a normal hyphen-minus ("mask argument is ignored - HIP uses the full wavefront") so the text near the CUDA: __any_sync(mask, predicate) / HIP: __any(predicate) examples uses a standard ASCII hyphen for consistency and encoding safety.
877-895: Consider using `tir.call_intrin` instead of `tir.call_extern` for these intrinsics.
A PR comment from @LeiWang1999 recommends wrapping the instructions with `T.call_intrin` instead of `T.call_extern`. This would allow for better integration with TIR's intrinsic system and enable downstream passes to recognize and potentially optimize these operations.
However, I note that the existing `shfl_xor`, `shfl_down`, and `shfl_up` functions (lines 805-845) also use `tir.call_extern`, so the current approach is consistent with the existing codebase. If you decide to switch to `call_intrin`, consider updating the shuffle functions as well for uniformity.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/builtin.py` around lines 877 - 895, Replace the use of tir.call_extern in the any_sync function with tir.call_intrin so the operation is emitted as a TIR intrinsic (keep the return dtype int32 and pass the same arguments: mask and predicate for CUDA path, predicate only for HIP path); also consider updating the related shuffle helpers (shfl_xor, shfl_down, shfl_up) to use tir.call_intrin for consistency with any_sync so downstream TIR passes can recognize and optimize these intrinsics.
📒 Files selected for processing (4)
- docs/programming_guides/instructions.md
- testing/python/language/test_tilelang_language_warp_vote.py
- tilelang/language/__init__.py
- tilelang/language/builtin.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tilelang/language/__init__.py
Convert the warp-vote, warp-shuffle, predicated block-sync, and warp-match builtins from raw `tir.call_extern` wrappers to proper TIR ops registered under `tl.*` and lowered in `codegen_cuda` / `codegen_hip`. This removes Python-side `_IS_HIP_AVAILABLE` branching and pushes the CUDA/HIP split into codegen, where it belongs.

* Register new tl ops: any_sync, all_sync, ballot_sync, ballot, activemask, syncthreads_count/and/or, shfl_sync, shfl_xor_sync, shfl_down_sync, shfl_up_sync, match_any_sync, match_all_sync.
* Codegen lowering on both CUDA and HIP. uint32->uint64 zero-extension for ballot/activemask now happens in codegen. HIP drops the mask argument for shfl/vote and emits LOG(FATAL) for match_*_sync.
* `__match_all_sync`'s int* pred output is hidden behind an immediately-invoked lambda so the wrapper stays expression-form.
* Python wrappers normalize Python int masks to uint32 TIR consts via `_as_uint32_mask`, so the emitted source prints `(uint)0xFFFFFFFF` instead of `(int64_t)4294967295`.
* Unify `shfl_xor/down/up` signatures to `(mask, value, delta, width=32)`, matching CUDA convention and `shfl_sync`. No in-tree callers existed.
* Add tests for match_any_sync and match_all_sync.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
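The CUDA/HIP split described in this commit message can be pictured with a toy string-emitting model in plain Python. This is only a sketch under assumptions: the actual lowering lives in the C++ codegen and is not reproduced here; `as_uint32_mask` and `lower_vote` are hypothetical names; the mask formatting for values other than `0xFFFFFFFF` is a guess; and the uint64 handling for ballot results is deliberately left out.

```python
def as_uint32_mask(mask: int) -> str:
    """Hypothetical stand-in for `_as_uint32_mask`: range-check, then print as uint32."""
    if not 0 <= mask <= 0xFFFFFFFF:
        raise ValueError(f"mask {mask:#x} does not fit in uint32")
    return f"(uint)0x{mask:08X}"


# (CUDA name, HIP name) pairs as quoted in the review discussion.
VOTE_OPS = {
    "any_sync": ("__any_sync", "__any"),
    "all_sync": ("__all_sync", "__all"),
    "ballot_sync": ("__ballot_sync", "__ballot"),
}
HIP_UNSUPPORTED = {"match_any_sync", "match_all_sync"}


def lower_vote(op: str, target: str, mask: int, predicate: str) -> str:
    """Return the source text a backend might emit for a vote op (toy model)."""
    if target not in ("cuda", "hip"):
        raise ValueError(f"unknown target {target!r}")
    if target == "hip" and op in HIP_UNSUPPORTED:
        # Mirrors the described behaviour of failing loudly for match_*_sync on HIP.
        raise NotImplementedError(f"tl.{op} is not supported on HIP")
    if op not in VOTE_OPS:
        raise ValueError(f"tl.{op} is not modeled in this sketch")
    cuda_name, hip_name = VOTE_OPS[op]
    mask_text = as_uint32_mask(mask)  # validated on both targets
    if target == "cuda":
        return f"{cuda_name}({mask_text}, {predicate})"
    return f"{hip_name}({predicate})"  # HIP drops the mask argument


print(lower_vote("any_sync", "cuda", 0xFFFFFFFF, "p"))  # __any_sync((uint)0xFFFFFFFF, p)
print(lower_vote("any_sync", "hip", 0xFFFFFFFF, "p"))   # __any(p)
```

Keeping the target decision in one lowering function, rather than in each Python wrapper, is the design point the commit message argues for: the frontend builds the same op regardless of which backend later prints it.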
Resolve conflict in src/target/codegen_hip.cc: keep the warp-vote / shuffle / match codegen branches added in this branch, and adopt main's new unified packed-x2 dispatch (add2/sub2/mul2/fma2/max2/min2/abs2) that replaces the standalone fadd2/fmul2/fma2 blocks.
Actionable comments posted: 3
♻️ Duplicate comments (2)
tilelang/language/builtin.py (1)
954-971: ⚠️ Potential issue | 🟡 Minor
Add HIP/ROCm version caveat to `syncthreads_*` docstrings.
These docstrings still read as unconditional HIP support. Please note the ROCm/HIP availability constraint to prevent user confusion on older stacks.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/builtin.py` around lines 954 - 971, Update the docstrings for syncthreads_count, syncthreads_and, and syncthreads_or to explicitly note that while these map to CUDA __syncthreads_* intrinsics, HIP/ROCm support can be limited on older ROCm stacks and may not be available or behave differently; edit the triple-quoted docstrings in the functions syncthreads_count, syncthreads_and, and syncthreads_or to add a short caveat sentence about ROCm/HIP availability and recommend checking the target ROCm version/driver for support.
docs/programming_guides/instructions.md (1)
155-170: ⚠️ Potential issue | 🟡 Minor
Add an explicit ROCm/HIP version caveat for `syncthreads_count/and/or`.
This section still reads as universally available on HIP; please append a minimum-version note (ROCm/HIP >= 7.0) for `T.syncthreads_count`, `T.syncthreads_and`, and `T.syncthreads_or`.
📝 Suggested doc patch
    - - `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`).
    - - `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`).
    - - `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`).
    + - `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`). *(HIP: requires ROCm/HIP ≥ 7.0)*
    + - `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`). *(HIP: requires ROCm/HIP ≥ 7.0)*
    + - `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`). *(HIP: requires ROCm/HIP ≥ 7.0)*
🌐 Web query: Do official HIP/ROCm docs indicate that `__syncthreads_count`, `__syncthreads_and`, and `__syncthreads_or` require ROCm/HIP 7.0+ (or are incomplete in HIP 6.x)?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/programming_guides/instructions.md` around lines 155 - 170, Add a ROCm/HIP minimum-version caveat for the block-wide predicated sync entries: update the descriptions for T.syncthreads_count, T.syncthreads_and, and T.syncthreads_or to append a short note like “(ROCm/HIP >= 7.0 required)” and/or a parenthetical sentence that these intrinsics are unavailable or will fail to codegen on HIP < 7.0; ensure the same three symbol names (T.syncthreads_count, T.syncthreads_and, T.syncthreads_or) are updated in the Warp-shuffle/Warp-match note where syncthreads_count/and/or are discussed so the HIP caveat appears consistently.
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_warp_vote.py (1)
42-406: Consider adding HIP-targeted tests; current coverage is CUDA-only.
The suite documents HIP behavior but only runs under `@requires_cuda`. Adding `@requires_hip` coverage for wave64 ballot/activemask and expected `match_*_sync` behavior would de-risk backend parity.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/language/test_tilelang_language_warp_vote.py` around lines 42 - 406, Tests only run for CUDA; add HIP-targeted equivalents so HIP-specific behavior (wave64 ballot/activemask and match_*_sync semantics) is validated. Duplicate the relevant test functions (e.g., test_ballot, test_activemask, test_ballot_sync, test_match_any_sync, test_match_all_sync, test_activemask, test_match_any_sync) or parameterize them to run under HIP by replacing/adding the `@tilelang.testing.requires_cuda` decorator with `@tilelang.testing.requires_hip` (or adding a new HIP-decorated test next to the CUDA one), and update the expected bitmask assertions in the HIP variants to reflect wave64 behavior (upper/lower 32-bit expectations and width-specific masks) so the tests assert the correct HIP/wave64 results.
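When writing width-specific assertions of the kind suggested here, it can help to compute reference values on the host. The sketch below is plain Python with hypothetical helper names (`expected_ballot`, `expected_match_any`); it is not part of the test suite, and the match helper models the CUDA-side semantics only, since the HIP backend is described as rejecting `match_*_sync`.

```python
def expected_ballot(predicate, width):
    """Bitmask with bit i set when predicate(i) is true, for lanes 0..width-1."""
    return sum(1 << lane for lane in range(width) if predicate(lane))


def expected_match_any(values):
    """For each lane, the bitmask of all lanes holding the same value."""
    groups = {}
    for lane, value in enumerate(values):
        groups[value] = groups.get(value, 0) | (1 << lane)
    return [groups[value] for value in values]


def is_even(lane):
    return lane % 2 == 0


print(hex(expected_ballot(is_even, 32)))  # 0x55555555
print(hex(expected_ballot(is_even, 64)))  # 0x5555555555555555

# Assumed 32-lane warp where each lane's value is its parity.
peers = expected_match_any([lane % 2 for lane in range(32)])
print(hex(peers[0]), hex(peers[1]))  # 0x55555555 0xaaaaaaaa
```

A HIP variant of a ballot test would compare against the width-64 reference, while the CUDA variant keeps the width-32 one.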
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/target/codegen_hip.cc`:
- Around line 848-857: The branches that currently emit HIP intrinsics for
tl::syncthreads_count, tl::syncthreads_and, and tl::syncthreads_or should not
unconditionally print __syncthreads_count/and/or; replace those emissions with
the same unsupported-operation handling used for tl::match_any_sync (e.g., emit
LOG(FATAL) with a clear message that tl::syncthreads_count / tl::syncthreads_and
/ tl::syncthreads_or are unsupported on HIP) instead of calling PrintExpr and
printing __syncthreads_* so compilation failures are avoided.
In `@testing/python/language/test_tilelang_language_warp_vote.py`:
- Around line 1-13: Replace all en dash characters (–) in the module docstring
with ASCII hyphens (-) to satisfy Ruff RUF002; specifically update the lines
listing intrinsics (e.g., the lines containing "T.any_sync – __any_sync / __any
(HIP)", "T.all_sync – __all_sync / __all (HIP)", and all other entries like
"T.ballot_sync – __ballot_sync→uint64", "T.activemask – __activemask→uint64",
etc.) so every occurrence of “–” becomes "-" and any similar en-dash uses
elsewhere in the docstring are normalized to ASCII hyphens.
In `@tilelang/language/builtin.py`:
- Around line 823-829: The shuffle wrapper functions (tl.shfl_xor_sync,
tl.shfl_xor, tl.shfl_up_sync, tl.shfl_down_sync) assume .dtype on `value` and
pass `delta`, `width`, `srcLane` directly to tir.call_intrin, which breaks for
Python literal ints; update each wrapper to coerce Python ints/Primitives into
TIR expressions before use: convert `value` (if an int) to a tir.const/TIR
PrimExpr and use its .dtype, and convert `delta`, `width`, `srcLane` (where
applicable) via a helper (e.g. _as_primexpr or tir.const) before passing them to
tir.call_intrin; keep the existing _as_uint32_mask(mask) usage for mask. Ensure
you modify the four functions named above so no literal int is accessed for
.dtype or passed raw into tir.call_intrin.
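The coercion requested for the shuffle wrappers can be illustrated without TVM. In the sketch below, `Const`, `as_expr`, and `shfl_xor_sync_args` are hypothetical stand-ins invented for this example; real code would build TIR expressions instead. The point is only the order of operations: normalize every argument first, and read `.dtype` afterwards.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Const:
    """Hypothetical stand-in for a TIR constant expression."""
    value: int
    dtype: str


def as_expr(x, dtype="int32"):
    """Pass expressions through unchanged; wrap plain Python ints in a constant."""
    if hasattr(x, "dtype"):
        return x
    if isinstance(x, int):
        return Const(int(x), dtype)
    raise TypeError(f"cannot coerce {type(x).__name__} to an expression")


def shfl_xor_sync_args(mask, value, delta, width=32):
    """Normalize arguments so a literal `value` no longer breaks `.dtype` access."""
    value = as_expr(value)
    return (
        value.dtype,              # result dtype follows the shuffled value
        as_expr(mask, "uint32"),  # masks are kept as uint32
        value,
        as_expr(delta),
        as_expr(width),
    )


args = shfl_xor_sync_args(0xFFFFFFFF, 7, 1)
print(args[0])  # int32
print(args[2])  # Const(value=7, dtype='int32')
```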
📒 Files selected for processing (7)
- docs/programming_guides/instructions.md
- src/op/builtin.cc
- src/op/builtin.h
- src/target/codegen_cuda.cc
- src/target/codegen_hip.cc
- testing/python/language/test_tilelang_language_warp_vote.py
- tilelang/language/builtin.py
    } else if (op->op.same_as(tl::syncthreads_count())) {
      ICHECK_EQ(op->args.size(), 1U)
          << "tl.syncthreads_count expects <predicate>.";
      os << "__syncthreads_count(" << PrintExpr(op->args[0]) << ")";
    } else if (op->op.same_as(tl::syncthreads_and())) {
      ICHECK_EQ(op->args.size(), 1U) << "tl.syncthreads_and expects <predicate>.";
      os << "__syncthreads_and(" << PrintExpr(op->args[0]) << ")";
    } else if (op->op.same_as(tl::syncthreads_or())) {
      ICHECK_EQ(op->args.size(), 1U) << "tl.syncthreads_or expects <predicate>.";
      os << "__syncthreads_or(" << PrintExpr(op->args[0]) << ")";
🧩 Analysis chain
🌐 Web query:
What ROCm/HIP version first fully supports HIP device intrinsics __syncthreads_count, __syncthreads_and, and __syncthreads_or on AMD GPUs?
💡 Result:
No ROCm/HIP version fully supports the HIP device intrinsics `__syncthreads_count`, `__syncthreads_and`, and `__syncthreads_or` on AMD GPUs. These are block-level synchronization primitives with predicates, analogous to CUDA's, but official HIP documentation and changelogs do not list them as implemented. Older docs (e.g., ROCm 5.x) explicitly state they are "under development". Recent docs (HIP 7.x) mention warp-level sync variants like ballot and sync shfl but omit these block-level ones. They appear in HIP-CPU (CPU emulation) and rocPRIM (library), but not in the core HIP language for AMD GPUs. Issues requesting them remain open (e.g., HIP-CPU #58). Warp-level sync such as `__syncwarp` was added in ROCm 7.0, but not block sync with count/and/or.
Citations:
- 1: https://rocm.docs.amd.com/projects/HIP/en/develop/reference/cpp_language_extensions.html
- 2: [Feature]: missing synchronisation functions ROCm/HIP-CPU#58
- 3: [HIP][device] 4 __shfl_sync functions are missing ROCm/hip#1491
- 4: https://github.com/ROCm-Developer-Tools/HIP-CPU/blob/master/src/include/hip/detail/intrinsics.hpp
- 5: https://github.com/ROCm/rocm-systems/blob/develop/projects/clr/CHANGELOG.md
- 6: https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md
- 7: https://rocm.docs.amd.com/projects/HIP/en/docs-5.7.0/reference/kernel_language.html
- 8: https://github.com/GPUOpen-ProfessionalCompute-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md
- 9: https://rocm.docs.amd.com/projects/HIP/en/docs-7.0.0/reference/hip%5Fruntime%5Fapi/modules/cooperative_groups_reference.html
- 10: https://rocm.docs.amd.com/projects/HIP/en/latest/reference/api_syntax.html
🏁 Script executed:
# Check if there's any wrapper, guard, or fallback mechanism for these HIP intrinsics in the codebase
rg -B 5 -A 5 "syncthreads_count|syncthreads_and|syncthreads_or" src/target/codegen_hip.cc
# Also check for any `#ifdef` guards or version checks around this section
rg -B 10 "syncthreads_count" src/target/codegen_hip.cc | head -40
Repository: tile-ai/tilelang
Length of output: 2130
🏁 Script executed:
# Check if there's any HIP version macro or minimum version requirement set globally
rg "__HIP_|HIP_VERSION|ROCm|HIP_COMMON_HEADER_CXX_VERSION" src/target/codegen_hip.cc | head -20
# Look for any fallback definitions of these intrinsics elsewhere
fd -t f '\.(cc|h|hpp)$' src/ | xargs rg "define.*syncthreads_count|define.*syncthreads_and|define.*syncthreads_or" 2>/dev/null
Repository: tile-ai/tilelang
Length of output: 44
🏁 Script executed:
# Check the full context around lines 828-883 mentioned in the scratchpad
sed -n '828,883p' src/target/codegen_hip.cc
Repository: tile-ai/tilelang
Length of output: 3403
Replace unconditional HIP intrinsic emission with explicit unsupported error.
Lines 848–857 unconditionally emit __syncthreads_count, __syncthreads_and, and __syncthreads_or, but these intrinsics do not exist in HIP for AMD GPUs across any ROCm version. Code using these operations will fail to compile. Align with the existing pattern for unsupported operations (e.g., tl::match_any_sync): either emit a LOG(FATAL) with a clear error message or remove the code paths entirely.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/target/codegen_hip.cc` around lines 848 - 857, The branches that
currently emit HIP intrinsics for tl::syncthreads_count, tl::syncthreads_and,
and tl::syncthreads_or should not unconditionally print
__syncthreads_count/and/or; replace those emissions with the same
unsupported-operation handling used for tl::match_any_sync (e.g., emit
LOG(FATAL) with a clear message that tl::syncthreads_count / tl::syncthreads_and
/ tl::syncthreads_or are unsupported on HIP) instead of calling PrintExpr and
printing __syncthreads_* so compilation failures are avoided.
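For readers less familiar with the three block-level builtins discussed in this comment, their CUDA semantics can be modelled on the host. The following is a minimal pure-Python reference sketch, not TileLang, CUDA, or HIP code: the names `block_count`, `block_and`, and `block_or` are hypothetical, each list element stands for one thread's predicate, and the barrier itself is not modelled.

```python
def block_count(predicates):
    """Model of __syncthreads_count: how many threads passed a non-zero predicate."""
    return sum(1 for p in predicates if p != 0)


def block_and(predicates):
    """Model of __syncthreads_and: 1 only if every thread's predicate is non-zero."""
    return int(all(p != 0 for p in predicates))


def block_or(predicates):
    """Model of __syncthreads_or: 1 if at least one thread's predicate is non-zero."""
    return int(any(p != 0 for p in predicates))


if __name__ == "__main__":
    # A 128-thread block in which every fourth thread sets its predicate.
    preds = [tid % 4 == 0 for tid in range(128)]
    print(block_count(preds), block_and(preds), block_or(preds))
```

A model like this can serve as the expected-value oracle when checking a backend that has to emulate these operations rather than emit a native intrinsic.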
| """Tests for warp-vote / warp-ballot / block-sync-with-predicate intrinsics. | ||
|
|
||
| Covered intrinsics | ||
| ------------------ | ||
| T.any_sync – __any_sync / __any (HIP) | ||
| T.all_sync – __all_sync / __all (HIP) | ||
| T.ballot_sync – __ballot_sync→uint64 (CUDA, zero-ext) / __ballot uint64 (HIP, all lanes) | ||
| T.ballot – full-warp ballot_sync / __ballot uint64 (HIP, all lanes) | ||
| T.activemask – __activemask→uint64 (CUDA, zero-ext) / __ballot(1) uint64 (HIP, all lanes) | ||
| T.syncthreads_count – __syncthreads_count | ||
| T.syncthreads_and – __syncthreads_and | ||
| T.syncthreads_or – __syncthreads_or | ||
| """ |
Replace EN DASH characters in docstrings to satisfy Ruff (RUF002).
Use ASCII - instead of – to avoid lint failures.
Also applies to: 192-192
🧰 Tools
🪛 Ruff (0.15.9)
[warning] 5-5: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF002)
[warning] 6-6: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF002)
[warning] 7-7: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF002)
[warning] 8-8: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF002)
[warning] 9-9: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF002)
[warning] 10-10: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF002)
[warning] 11-11: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF002)
[warning] 12-12: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF002)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@testing/python/language/test_tilelang_language_warp_vote.py` around lines 1 -
13, Replace all en dash characters (–) in the module docstring with ASCII
hyphens (-) to satisfy Ruff RUF002; specifically update the lines listing
intrinsics (e.g., the lines containing "T.any_sync – __any_sync / __any (HIP)",
"T.all_sync – __all_sync / __all (HIP)", and all other entries like
"T.ballot_sync – __ballot_sync→uint64", "T.activemask – __activemask→uint64",
etc.) so every occurrence of “–” becomes "-" and any similar en-dash uses
elsewhere in the docstring are normalized to ASCII hyphens.
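The substitution requested here is mechanical. A minimal sketch, assuming the only offending character is U+2013 (the helper name `normalize_dashes` is hypothetical and not part of the PR):

```python
EN_DASH = "\u2013"  # the character Ruff reports under RUF002


def normalize_dashes(text: str) -> str:
    """Replace every EN DASH with an ASCII hyphen-minus."""
    return text.replace(EN_DASH, "-")


if __name__ == "__main__":
    line = "T.any_sync \u2013 __any_sync / __any (HIP)"
    print(normalize_dashes(line))
```

Other non-ASCII characters in the docstring, such as the arrows, are left untouched by this helper.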
| mask: int | PrimExpr, value: int | PrimExpr | tir.Call, delta: int | PrimExpr | tir.Call, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH | ||
| ): | ||
| """XOR-swap ``value`` across lanes (``__shfl_xor_sync`` on CUDA, | ||
| ``__shfl_xor`` on HIP — mask ignored on HIP). | ||
| """ | ||
| if _IS_HIP_AVAILABLE: | ||
| return tir.call_extern(value.dtype, "__shfl_down", value, offset) | ||
| else: | ||
| return tir.call_extern(value.dtype, "__shfl_down_sync", 0xFFFFFFFF, value, offset) | ||
| return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_xor_sync"), _as_uint32_mask(mask), value, delta, width) | ||
|
|
🧩 Analysis chain
🏁 Script executed:
# First, locate the file and check its structure
find . -name "builtin.py" -path "*tilelang/language*" | head -5
Repository: tile-ai/tilelang
Length of output: 90
🏁 Script executed:
# Read the relevant lines from the file to verify the claim
head -900 tilelang/language/builtin.py | tail -100
Repository: tile-ai/tilelang
Length of output: 4036
🏁 Script executed:
# Also search for the _as_uint32_mask helper and convert function
rg -n "_as_uint32_mask|def convert" tilelang/language/builtin.py | head -20
Repository: tile-ai/tilelang
Length of output: 1138
🏁 Script executed:
# Look for the complete shuffle function definitions
rg -A 5 "def shfl_xor|def shfl_down|def shfl_up|def shfl_sync" tilelang/language/builtin.py
Repository: tile-ai/tilelang
Length of output: 1307
🏁 Script executed:
# Check if convert() function exists in the file or is imported
rg -n "^from|^import|def convert" tilelang/language/builtin.py | head -30
Repository: tile-ai/tilelang
Length of output: 604
🏁 Script executed:
# Look at the imports and beginning of the file
head -50 tilelang/language/builtin.py
Repository: tile-ai/tilelang
Length of output: 1791
🏁 Script executed:
# Check if there are tests for these shuffle functions with literal int values
find . -name "*test*builtin*" -o -name "*builtin*test*" | grep -i test
Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for usage of shfl_xor, shfl_down, shfl_up, shfl_sync in tests
rg "shfl_xor|shfl_down|shfl_up|shfl_sync" --type py -l | head -10
Repository: tile-ai/tilelang
Length of output: 148
🏁 Script executed:
# Check the test file for shuffle function usage
cat testing/python/language/test_tilelang_language_warp_sync.py
Repository: tile-ai/tilelang
Length of output: 1542
🏁 Script executed:
# Also check what _as_uint32_mask does and how it handles int vs PrimExpr
sed -n '811,825p' tilelang/language/builtin.py
Repository: tile-ai/tilelang
Length of output: 651
🏁 Script executed:
# Look at how tir.call_intrin handles its parameters
rg -B 5 -A 5 "call_intrin.*\.dtype" tilelang/language/builtin.py | head -50
Repository: tile-ai/tilelang
Length of output: 2211
🏁 Script executed:
# Check other functions to see how they handle dtype extraction
rg -B 3 -A 3 "tir.call_intrin" tilelang/language/builtin.py | grep -A 3 -B 3 "\.dtype"
Repository: tile-ai/tilelang
Length of output: 1729
🏁 Script executed:
# Look for how convert() is used elsewhere in the file
rg "convert\(" tilelang/language/builtin.py
Repository: tile-ai/tilelang
Length of output: 528
🏁 Script executed:
# Check if there are other places where int | PrimExpr parameters are properly converted
sed -n '17,40p' tilelang/language/builtin.py
Repository: tile-ai/tilelang
Length of output: 789
🏁 Script executed:
# Look at a complete example of how normalize_index_arg is used
rg -B 2 -A 2 "_normalize_index_arg" tilelang/language/builtin.py | head -30
Repository: tile-ai/tilelang
Length of output: 1209
🏁 Script executed:
# Final check: Look at one more case to confirm how delta/width are handled
sed -n '823,829p' tilelang/language/builtin.py
Repository: tile-ai/tilelang
Length of output: 447
🏁 Script executed:
# Verify that passing literal int to these params would also fail
python3 << 'EOF'
# Simulate what happens when int is passed to call_intrin
class MockPrimExpr:
    def __init__(self):
        self.dtype = "int32"

# This is what happens in the current code
value = 42  # literal int
try:
    dtype = value.dtype
except AttributeError as e:
    print(f"✗ Error: {e}")

# With conversion
expr = MockPrimExpr()
dtype = expr.dtype
print(f"✓ After conversion: dtype = {dtype}")
EOF
Repository: tile-ai/tilelang
Length of output: 140
Fix literal int handling in shuffle wrappers before accessing .dtype.
The value parameter accepts int | PrimExpr | tir.Call, but all four shuffle functions directly access value.dtype without converting literal integers first. This will raise AttributeError: 'int' object has no attribute 'dtype' at runtime when a Python literal is passed.
Additionally, delta, width, and srcLane parameters also accept int | PrimExpr but are passed unconverted to tir.call_intrin(), which expects TIR expressions for proper type handling.
Proposed fix
+def _as_expr(x: int | PrimExpr | tir.Call) -> PrimExpr:
+ return convert(x)
+
def shfl_xor(
mask: int | PrimExpr, value: int | PrimExpr | tir.Call, delta: int | PrimExpr | tir.Call, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH
):
@@
- return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_xor_sync"), _as_uint32_mask(mask), value, delta, width)
+ value_expr = _as_expr(value)
+ return tir.call_intrin(
+ str(value_expr.dtype),
+ tir.op.Op.get("tl.shfl_xor_sync"),
+ _as_uint32_mask(mask),
+ value_expr,
+ convert(delta),
+ convert(width),
+ )
@@
def shfl_down(
@@
- return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_down_sync"), _as_uint32_mask(mask), value, delta, width)
+ value_expr = _as_expr(value)
+ return tir.call_intrin(
+ str(value_expr.dtype),
+ tir.op.Op.get("tl.shfl_down_sync"),
+ _as_uint32_mask(mask),
+ value_expr,
+ convert(delta),
+ convert(width),
+ )
@@
def shfl_up(
@@
- return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_up_sync"), _as_uint32_mask(mask), value, delta, width)
+ value_expr = _as_expr(value)
+ return tir.call_intrin(
+ str(value_expr.dtype),
+ tir.op.Op.get("tl.shfl_up_sync"),
+ _as_uint32_mask(mask),
+ value_expr,
+ convert(delta),
+ convert(width),
+ )
@@
def shfl_sync(mask: int | PrimExpr, value: int | PrimExpr, srcLane: int | PrimExpr, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH):
@@
- return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_sync"), _as_uint32_mask(mask), value, srcLane, width)
+ value_expr = _as_expr(value)
+ return tir.call_intrin(
+ str(value_expr.dtype),
+ tir.op.Op.get("tl.shfl_sync"),
+ _as_uint32_mask(mask),
+ value_expr,
+ convert(srcLane),
+ convert(width),
+ )
Also applies to: 831–838, 840–847, 866–871
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/language/builtin.py` around lines 823 - 829, The shuffle wrapper
functions (shfl_xor, shfl_down, shfl_up, shfl_sync) read `value.dtype` and pass
`delta`, `width`, and `srcLane` directly to tir.call_intrin, which breaks for
Python literal ints; update each wrapper to coerce Python ints into TIR
expressions before use: convert `value` (if an int) to a TIR PrimExpr (e.g. via
convert or tir.const) and use its .dtype, and convert `delta`, `width`, and
`srcLane` (where applicable) the same way before passing them to
tir.call_intrin; keep the existing _as_uint32_mask(mask) usage for mask. Ensure
all four functions are updated so that no literal int has .dtype accessed on it
or is passed raw into tir.call_intrin.
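The coercion pattern described in this comment can be illustrated without TVM. The sketch below is a stand-in only: `Expr`, `as_expr`, and `shfl_xor_model` are hypothetical names, `Expr` merely mimics an object that carries a `dtype`, and the default of `int32` for Python ints is an assumption of the sketch rather than a statement about what `convert` does in TileLang.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Expr:
    """Stand-in for a TIR expression: it carries a dtype and a payload."""
    dtype: str
    value: object


def as_expr(x, dtype="int32"):
    """Pass existing expressions through; wrap Python ints so they gain a dtype."""
    if isinstance(x, Expr):
        return x
    if isinstance(x, int):
        return Expr(dtype, x)
    raise TypeError(f"cannot coerce {type(x).__name__} to Expr")


def shfl_xor_model(mask, value, delta, width=32):
    """Coerce every operand first, then read dtype from the coerced value."""
    value_expr = as_expr(value)
    return (
        value_expr.dtype,
        "shfl_xor",
        as_expr(mask, "uint32"),
        value_expr,
        as_expr(delta),
        as_expr(width),
    )


if __name__ == "__main__":
    # A literal int no longer raises AttributeError on .dtype.
    print(shfl_xor_model(0xFFFFFFFF, 42, 1)[0])
    print(shfl_xor_model(0xFFFFFFFF, Expr("float32", "x"), 1)[0])
```

The point of the design is ordering: coercion happens before any attribute access, so the wrapper behaves the same for literals and for expressions.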
This pull request adds builtin support for warp-level vote/ballot intrinsics and block-wide predicated synchronization operations. #1861
These primitives are fundamental building blocks for high-performance GPU programming. They enable warp-wide early termination, fast consensus evaluation, and efficient divergence control.
Such capabilities are critical in optimized sorting algorithms, parallel partitioning, and other performance-sensitive kernels, where eliminating redundant computation at the warp level can substantially reduce execution time.
Note: On the ROCm backend, these intrinsics rely on HIP language extensions available in ROCm/HIP >= 7.0. Older ROCm releases may not provide full support for these builtins.
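As a rough illustration of what the warp-level vote primitives compute, the following pure-Python model mirrors the CUDA semantics of `__ballot_sync`, `__any_sync`, and `__all_sync` for a 32-lane warp. It is a host-side sketch under stated assumptions, not the TileLang API: the function names are reused only for readability, `predicates` is assumed to hold exactly one entry per lane, and all lanes named in `mask` are assumed to be active.

```python
WARP_SIZE = 32  # assumption: a 32-lane warp, as on CUDA


def ballot_sync(mask, predicates):
    """Bit i of the result is set when lane i is in mask and its predicate is non-zero."""
    bits = 0
    for lane in range(WARP_SIZE):
        if (mask >> lane) & 1 and predicates[lane]:
            bits |= 1 << lane
    return bits


def any_sync(mask, predicates):
    """1 if at least one participating lane has a non-zero predicate."""
    return int(ballot_sync(mask, predicates) != 0)


def all_sync(mask, predicates):
    """1 if every participating lane has a non-zero predicate."""
    participating = mask & ((1 << WARP_SIZE) - 1)
    return int(ballot_sync(mask, predicates) == participating)


if __name__ == "__main__":
    # Only the lower half of the warp still has work to do.
    preds = [lane < 16 for lane in range(WARP_SIZE)]
    full = 0xFFFFFFFF
    print(hex(ballot_sync(full, preds)), any_sync(full, preds), all_sync(full, preds))
```

Warp-wide early termination corresponds to the case where `any_sync` returns 0: no lane has remaining work, so the whole warp can leave the loop together.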