
[Feature] Add TIR builtins for warp-level vote and block-level predicate sync #1858

Merged
LeiWang1999 merged 3 commits into tile-ai:main from sepcnt:main
Apr 13, 2026

@sepcnt
Contributor

@sepcnt sepcnt commented Feb 18, 2026

This pull request adds builtin support for warp-level vote/ballot intrinsics and block-wide predicated synchronization operations. #1861

These primitives are fundamental building blocks for high-performance GPU programming. They enable warp-wide early termination, fast consensus evaluation, and efficient divergence control.

Such capabilities are critical in optimized sorting algorithms, parallel partitioning, and other performance-sensitive kernels, where eliminating redundant computation at the warp level can substantially reduce execution time.

Note: On the ROCm backend, these intrinsics rely on HIP language extensions available in ROCm/HIP >= 7.0. Older ROCm releases may not provide full support for these builtins.
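
As a rough illustration of what the vote primitives compute, here is a host-side Python model of a 32-lane warp. It is not TileLang or CUDA code: the function names mirror the builtins added in this PR, but representing a warp as a list of per-lane predicates is an assumption made only for this sketch.

```python
WARP_SIZE = 32  # assumption: 32-lane CUDA warp; HIP wavefronts may be 64 wide


def _participating(mask, predicates):
    """Yield (lane, predicate) for every lane whose bit is set in mask."""
    for lane, pred in enumerate(predicates):
        if (mask >> lane) & 1:
            yield lane, pred


def any_sync(mask, predicates):
    """Return 1 if any participating lane has a non-zero predicate."""
    return int(any(p != 0 for _, p in _participating(mask, predicates)))


def all_sync(mask, predicates):
    """Return 1 if every participating lane has a non-zero predicate."""
    return int(all(p != 0 for _, p in _participating(mask, predicates)))


def ballot_sync(mask, predicates):
    """Return a bitmask with bit N set when lane N participates and votes non-zero."""
    bits = 0
    for lane, pred in _participating(mask, predicates):
        if pred != 0:
            bits |= 1 << lane
    return bits


# Only lane 0 raises its predicate.
preds = [1 if lane == 0 else 0 for lane in range(WARP_SIZE)]
print(any_sync(0xFFFFFFFF, preds))          # 1
print(all_sync(0xFFFFFFFF, preds))          # 0
print(hex(ballot_sync(0xFFFFFFFF, preds)))  # 0x1
```

Warp-wide early termination follows directly from this: if `any_sync` returns 0 for a "still has work" predicate, every lane can leave the loop together.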

Summary by CodeRabbit

  • New Features

    • Added expanded synchronization primitives: block-, warp-, and cooperative-grid syncs; warp vote/ballot/activemask; warp-match and shuffle intrinsics; predicated synchronization reductions.
  • Documentation

    • Programming guide updated with descriptions and platform-specific notes (CUDA vs HIP) and explicit return-type/behavior guidance.
  • Tests

    • New unit tests validating warp-vote/ballot, match, shuffle, and block-synchronization semantics and runtime results.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Feb 18, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: df374686-7dca-4cd6-aeba-91db5d16884d

📥 Commits

Reviewing files that changed from the base of the PR and between c856804 and 220c0e9.

📒 Files selected for processing (5)
  • docs/programming_guides/instructions.md
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/target/codegen_cuda.cc
  • src/target/codegen_hip.cc
✅ Files skipped from review due to trivial changes (1)
  • docs/programming_guides/instructions.md
🚧 Files skipped from review as they are similar to previous changes (3)
  • src/target/codegen_hip.cc
  • src/op/builtin.h
  • src/target/codegen_cuda.cc

📝 Walkthrough

Walkthrough

Adds TileLang GPU synchronization, warp vote/ballot, warp-match, and shuffle intrinsics; registers new builtins; implements CUDA/HIP codegen paths; re-exports new APIs; extends docs; and adds CUDA-focused unit tests validating emission and runtime behavior.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Docs: `docs/programming_guides/instructions.md` | Adds reference entries for block/warp/grid sync, warp vote/ballot/mask, block predicated sync reductions, warp-match, and intra-warp shuffles with CUDA/HIP notes. |
| Language APIs: `tilelang/language/builtin.py`, `tilelang/language/__init__.py` | Adds frontend helpers (any_sync, all_sync, ballot_sync, ballot, activemask, syncthreads_count, syncthreads_and, syncthreads_or, match_any_sync, match_all_sync) and changes shfl_* signatures to accept explicit mask/width; re-exports added. |
| Builtins registry: `src/op/builtin.h`, `src/op/builtin.cc` | Registers new TL builtin ops for vote/ballot/activemask, syncthreads_* reductions, shfl_sync, and match_sync with opaque call effects to constrain reorderings. |
| Codegen: `src/target/codegen_cuda.cc`, `src/target/codegen_hip.cc` | Implements target-specific lowering: CUDA emits the corresponding intrinsics such as `__any_sync` and `__ballot_sync` (with casts/ID helpers and a lambda wrapper for match_all); HIP ignores the mask for some intrinsics and emits a fatal error for the unsupported `match_*_sync` intrinsics. |
| Tests: `testing/python/language/test_tilelang_language_warp_vote.py` | Adds CUDA-only TileLang JIT tests that check emitted intrinsic names and verify per-lane/runtime results for vote/ballot, activemask, syncthreads_* reductions, shuffles, and match_* behaviors. |
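
The warp-match helpers listed above are probably the least familiar entries. The sketch below models their CUDA semantics on the host in plain Python, assuming a 32-lane warp represented as a list of per-lane values. The function names mirror the builtins, but the extra `lane` argument and the tuple return are conventions of this sketch, not the TileLang API. Per the summary above, HIP codegen rejects these ops.

```python
WARP_SIZE = 32  # assumption: 32-lane CUDA warp


def match_any_sync(mask, values, lane):
    """Bitmask of participating lanes whose value equals the calling lane's value."""
    bits = 0
    for other, value in enumerate(values):
        if (mask >> other) & 1 and value == values[lane]:
            bits |= 1 << other
    return bits


def match_all_sync(mask, values):
    """Return (mask, 1) if all participating lanes agree, otherwise (0, 0)."""
    participating = [v for lane, v in enumerate(values) if (mask >> lane) & 1]
    same = all(v == participating[0] for v in participating)
    return (mask, 1) if same else (0, 0)


# Even lanes hold 0, odd lanes hold 1.
values = [lane % 2 for lane in range(WARP_SIZE)]
print(hex(match_any_sync(0xFFFFFFFF, values, 0)))  # 0x55555555
print(hex(match_any_sync(0xFFFFFFFF, values, 1)))  # 0xaaaaaaaa
print(match_all_sync(0xFFFFFFFF, values))          # (0, 0)
print(match_all_sync(0xFFFFFFFF, [7] * WARP_SIZE)) # (4294967295, 1)
```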

Sequence Diagram(s)

sequenceDiagram
    participant Frontend as TileLang frontend
    participant Builtins as TL builtins registry
    participant CodeGen as CodeGen (CUDA / HIP)
    participant Device as GPU runtime
    Frontend->>Builtins: create CallNode for intrinsic (tl.*)
    Frontend->>CodeGen: VisitExpr_(CallNode)
    CodeGen->>CodeGen: validate args, choose target lowering (CUDA/HIP)
    CodeGen->>Device: emit kernel with target intrinsics (__any_sync, __ballot, __shfl, etc.)
    Device->>Device: execute warp/block intrinsic
    Device-->>Frontend: results stored to output buffers
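
The target split in the diagram can be pictured as a small string-emitting dispatch. The Python sketch below is only a model of that idea: `lower_call` and the three tables are hypothetical names, and the real lowering lives in C++ in `codegen_cuda.cc` / `codegen_hip.cc`. The intrinsic names are the ones quoted elsewhere in this thread; `match_all_sync` is left out because it needs the lambda wrapper on CUDA.

```python
# Intrinsic names as quoted in this thread's docstrings and commit notes.
CUDA_NAMES = {
    "any_sync": "__any_sync",
    "all_sync": "__all_sync",
    "ballot_sync": "__ballot_sync",
    "match_any_sync": "__match_any_sync",
}
HIP_NAMES = {
    "any_sync": "__any",
    "all_sync": "__all",
    "ballot_sync": "__ballot",
}
HIP_UNSUPPORTED = {"match_any_sync"}


def lower_call(op, target, mask, operand):
    """Return the source text a backend might emit for a two-argument tl op."""
    if target == "cuda":
        return f"{CUDA_NAMES[op]}({mask}, {operand})"
    if target == "hip":
        if op in HIP_UNSUPPORTED:
            # Stands in for the LOG(FATAL) path described in this PR.
            raise NotImplementedError(f"tl.{op} is not supported on HIP")
        # HIP drops the mask and votes across the full wavefront.
        return f"{HIP_NAMES[op]}({operand})"
    raise ValueError(f"unknown target: {target}")


print(lower_call("any_sync", "cuda", "0xFFFFFFFF", "pred"))  # __any_sync(0xFFFFFFFF, pred)
print(lower_call("any_sync", "hip", "0xFFFFFFFF", "pred"))   # __any(pred)
```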

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

enhancement

Poem

🐇 I hop through lanes where ballots gleam,
Masks and shuffles stitch the stream.
Threads count, match, and clap in sync,
A rabbit hops, inspects the link—
🥕 kernels hum; the kernels beam.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 52.63% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and concisely summarizes the main changes: adding TIR builtins for warp-level voting/ballot operations and block-level predicate synchronization, which aligns with the changeset across documentation, tests, and implementation files. |


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
tilelang/language/builtin.py (1)

941-959: ballot() duplicates logic instead of delegating to ballot_sync().

The docstring says "convenience wrapper around ballot_sync" but the implementation reimplements the HIP/CUDA branching rather than calling ballot_sync(0xFFFFFFFF, predicate). If the ballot_sync implementation ever changes (e.g., different cast strategy), ballot will silently diverge.

♻️ Proposed refactor to delegate to ballot_sync
 def ballot(predicate: int | PrimExpr) -> PrimExpr:
     ...
-    if _IS_HIP_AVAILABLE:
-        return tir.cast("uint32", tir.call_extern("uint64", "__ballot", predicate))
-    else:
-        return tir.call_extern("uint32", "__ballot_sync", tir.const(0xFFFFFFFF, "uint32"), predicate)
+    return ballot_sync(tir.const(0xFFFFFFFF, "uint32"), predicate)

Similarly, activemask() on HIP (line 972) reimplements the same __ballot + cast pattern. Consider delegating:

 def activemask() -> PrimExpr:
     ...
-    if _IS_HIP_AVAILABLE:
-        return tir.cast("uint32", tir.call_extern("uint64", "__ballot", tir.const(1, "int32")))
-    else:
+    if not _IS_HIP_AVAILABLE:
         return tir.call_extern("uint32", "__activemask")
+    return ballot(tir.const(1, "int32"))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/builtin.py` around lines 941 - 959, The ballot() function
duplicates the HIP/CUDA branching instead of delegating to ballot_sync; change
ballot to simply return ballot_sync(tir.const(0xFFFFFFFF, "uint32"), predicate)
(ensuring types match) so any future ballot_sync changes apply here; likewise
update activemask() to delegate to ballot_sync/activemask_sync equivalent
instead of reimplementing the __ballot + cast pattern (use the same
uint32/uint64 casting behavior that ballot_sync already implements) and remove
the duplicated call_extern/cast logic in ballot and activemask while keeping
their public signatures unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@testing/python/language/test_tilelang_language_warp_vote.py`:
- Around line 26-40: The test declares an unused tensor parameter A in
kernel_any_sync.main and has an incorrect docstring; remove A from the kernel
signature (delete the T.Tensor((1,), "int32") parameter and any references to
A), update the docstring to accurately describe that lane 0 participates in the
any_sync predicate (e.g., "Lane 0 sets the predicate; all lanes use any_sync to
see if any lane matched"), and update any test invocation that passed an A
argument so it only supplies the B tensor to kernel_any_sync.

In `@tilelang/language/builtin.py`:
- Around line 918-938: The current ballot_sync function truncates HIP's 64-bit
(__ballot) result to uint32; modify the HIP branch (where _IS_HIP_AVAILABLE is
checked) to return the raw uint64 result via tir.call_extern("uint64",
"__ballot", predicate) instead of tir.cast("uint32", ...), and update the
function docstring and the inline comment to state that on HIP the function
returns uint64 (callers should narrow to uint32 when they know wavefronts are
32-wide); reference symbols: ballot_sync, _IS_HIP_AVAILABLE, tir.call_extern,
tir.cast.


Comment thread testing/python/language/test_tilelang_language_warp_vote.py
Comment thread tilelang/language/builtin.py Outdated
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
tilelang/language/__init__.py (1)

103-110: Optional: Remove redundant # noqa: F401 directives.

Ruff flags these as unused since F401 is not enabled in the project's Ruff configuration. The same pattern exists on lines 94–102 (pre-existing, not flagged here). Consider a one-time cleanup of all such directives across lines 94–110 once the existing pattern is addressed.

🧹 Proposed cleanup (lines 103–110 only)
-from .builtin import any_sync as any_sync  # noqa: F401
-from .builtin import all_sync as all_sync  # noqa: F401
-from .builtin import ballot_sync as ballot_sync  # noqa: F401
-from .builtin import ballot as ballot  # noqa: F401
-from .builtin import activemask as activemask  # noqa: F401
-from .builtin import syncthreads_count as syncthreads_count  # noqa: F401
-from .builtin import syncthreads_and as syncthreads_and  # noqa: F401
-from .builtin import syncthreads_or as syncthreads_or  # noqa: F401
+from .builtin import any_sync as any_sync
+from .builtin import all_sync as all_sync
+from .builtin import ballot_sync as ballot_sync
+from .builtin import ballot as ballot
+from .builtin import activemask as activemask
+from .builtin import syncthreads_count as syncthreads_count
+from .builtin import syncthreads_and as syncthreads_and
+from .builtin import syncthreads_or as syncthreads_or
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/__init__.py` around lines 103 - 110, Remove the redundant
"# noqa: F401" directives from the import statements that re-export builtin
symbols; specifically edit the import lines that reference any_sync, all_sync,
ballot_sync, ballot, activemask, syncthreads_count, syncthreads_and, and
syncthreads_or in __init__.py and delete the trailing " # noqa: F401" from each
import; ensure the names remain imported (no other code changes) and run the
linter/flake check to confirm no warnings remain.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/programming_guides/instructions.md`:
- Around line 155-158: Add a brief HIP version note stating that
T.syncthreads_count, T.syncthreads_and, and T.syncthreads_or require ROCm/HIP ≥
7.0 (they are only under development in HIP ≤ 6.2); update the three entries
(T.syncthreads_count(predicate), T.syncthreads_and(predicate),
T.syncthreads_or(predicate)) to append “HIP: ROCm/hip ≥ 7.0” or expand the
existing HIP note that covers vote/ballot caveats to explicitly call out these
predicated block-wide sync functions and their minimum HIP version requirement.

In `@tilelang/language/builtin.py`:
- Around line 982-1024: The docstrings for syncthreads_count, syncthreads_and,
and syncthreads_or incorrectly claim the intrinsics map to CUDA and HIP
unconditionally; update each function's docstring (syncthreads_count,
syncthreads_and, syncthreads_or) to note the HIP/ROCm version constraints: that
the __syncthreads_count/and/or variants are under development in HIP 6.2 and
that __sync variants became available in ROCm 7.0 (and are enabled by default
there), so on older HIP/ROCm stacks these intrinsics may not be available; keep
the CUDA mapping text unchanged and add a brief sentence about the HIP version
caveat.

---

Duplicate comments:
In `@testing/python/language/test_tilelang_language_warp_vote.py`:
- Around line 26-52: The kernel declares A in kernel_any_sync -> main but never
uses it and the docstring is wrong; either remove the unused parameter A from
main and update the docstring and test_any_sync to stop allocating/passing a, or
implement the intended behavior: have lane 0 write 1 into A (e.g., inside
T.Kernel when tx == 0 set A[0]=1) and change the any_sync predicate to read
A[0]==1 so the call sites (test_any_sync's allocation/passing of a) remain
correct; update kernel_any_sync's docstring accordingly.

In `@tilelang/language/builtin.py`:
- Around line 918-974: The HIP implementations of ballot_sync, ballot, and
activemask incorrectly cast the uint64 result of __ballot to uint32, discarding
lanes 32–63 on wave-64 targets; update the HIP paths in ballot_sync, ballot, and
activemask to return 64-bit masks (use tir.call_extern with "uint64" and return
a uint64 PrimExpr, and avoid casting to "uint32"), and ensure ballot_sync's mask
parameter and any callers that expect a 32-bit return are adjusted to use 64-bit
mask types so full wave-64 lane data is preserved.
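
The truncation concern is easy to reproduce on the host. The following plain-Python sketch assumes a 64-lane wavefront in which only lanes 32 to 63 vote true; `ballot` here is a local model, not the HIP intrinsic, and the `& 0xFFFFFFFF` stands in for the cast to `uint32`.

```python
WAVE_SIZE = 64  # assumption: wave-64 HIP target


def ballot(predicates):
    """Model of a wavefront ballot: bit N is set when lane N's predicate is non-zero."""
    bits = 0
    for lane, pred in enumerate(predicates):
        if pred != 0:
            bits |= 1 << lane
    return bits


# Only the upper half of the wavefront votes true.
preds = [1 if lane >= 32 else 0 for lane in range(WAVE_SIZE)]

full = ballot(preds)
truncated = full & 0xFFFFFFFF  # models tir.cast("uint32", ...)

print(hex(full))       # 0xffffffff00000000
print(hex(truncated))  # 0x0
```

After the cast the caller sees an empty ballot even though 32 lanes voted true, which is the failure mode the comment describes.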


Comment on lines +155 to +158
Block-wide predicated sync
- `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`).
- `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`).
- `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`).
Contributor


⚠️ Potential issue | 🟡 Minor

Missing HIP version note for syncthreads_count/and/or.

The block-wide predicated sync entries do not mention the ROCm ≥ 7.0 requirement noted in the PR description. The __syncthreads_count(int), __syncthreads_and(int), and __syncthreads_or(int) functions are under development in HIP 6.2 and below. The HIP note at line 166 addresses vote/ballot HIP caveats but omits this. Consider appending a brief version note to the syncthreads_count/and/or entries or expanding the HIP note to cover them.
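
For reference, the three block-wide reductions can be modeled on the host as follows. This is a plain-Python sketch, not TileLang code: a thread block is represented as a list of per-thread predicates, the barrier itself is not modeled, and the "non-zero" results of `and`/`or` are normalized to 1 here for simplicity.

```python
def syncthreads_count(predicates):
    """Number of threads in the block whose predicate is non-zero."""
    return sum(1 for p in predicates if p != 0)


def syncthreads_and(predicates):
    """1 if every thread's predicate is non-zero, else 0."""
    return int(all(p != 0 for p in predicates))


def syncthreads_or(predicates):
    """1 if at least one thread's predicate is non-zero, else 0."""
    return int(any(p != 0 for p in predicates))


# A 128-thread block where every fourth thread raises its predicate.
block = [1 if tid % 4 == 0 else 0 for tid in range(128)]
print(syncthreads_count(block))  # 32
print(syncthreads_and(block))    # 0
print(syncthreads_or(block))     # 1
```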

📝 Suggested doc patch
 Block-wide predicated sync
-- `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`).
-- `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`).
-- `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`).
+- `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`). *(HIP: requires ROCm ≥ 7.0)*
+- `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`). *(HIP: requires ROCm ≥ 7.0)*
+- `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`). *(HIP: requires ROCm ≥ 7.0)*
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Block-wide predicated sync
- `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`).
- `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`).
- `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`).
Block-wide predicated sync
- `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`). *(HIP: requires ROCm ≥ 7.0)*
- `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`). *(HIP: requires ROCm ≥ 7.0)*
- `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`). *(HIP: requires ROCm ≥ 7.0)*
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/programming_guides/instructions.md` around lines 155 - 158, Add a brief
HIP version note stating that T.syncthreads_count, T.syncthreads_and, and
T.syncthreads_or require ROCm/HIP ≥ 7.0 (they are only under development in HIP
≤ 6.2); update the three entries (T.syncthreads_count(predicate),
T.syncthreads_and(predicate), T.syncthreads_or(predicate)) to append “HIP:
ROCm/hip ≥ 7.0” or expand the existing HIP note that covers vote/ballot caveats
to explicitly call out these predicated block-wide sync functions and their
minimum HIP version requirement.

Comment thread tilelang/language/builtin.py Outdated
@sepcnt sepcnt marked this pull request as draft February 18, 2026 17:59
@LeiWang1999
Member

@sepcnt Thanks! It would be better to wrap the instructions with T.call_intrin instead of T.call_extern.

@LeiWang1999 LeiWang1999 marked this pull request as ready for review April 13, 2026 09:47
Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
tilelang/language/builtin.py (2)

881-884: Minor: Replace EN DASH with HYPHEN-MINUS in docstring.

Static analysis flagged an ambiguous EN DASH (–) character at line 882. For consistency and to avoid potential encoding issues, use a standard hyphen-minus (-).

📝 Suggested fix
     CUDA: ``__any_sync(mask, predicate)``
-    HIP:  ``__any(predicate)`` (mask argument is ignored – HIP uses the full
+    HIP:  ``__any(predicate)`` (mask argument is ignored - HIP uses the full
           wavefront).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/builtin.py` around lines 881 - 884, Replace the ambiguous
EN DASH with a standard hyphen-minus in the docstring line describing HIP
behavior: change the phrase "mask argument is ignored – HIP uses the full
wavefront" to use a normal hyphen-minus ("mask argument is ignored - HIP uses
the full wavefront") so the text near the CUDA: __any_sync(mask, predicate) /
HIP: __any(predicate) examples uses a standard ASCII hyphen for consistency and
encoding safety.

877-895: Consider using tir.call_intrin instead of tir.call_extern for these intrinsics.

A PR comment from @LeiWang1999 recommends wrapping the instructions with T.call_intrin instead of T.call_extern. This would allow for better integration with TIR's intrinsic system and enable downstream passes to recognize and potentially optimize these operations.

However, I note that the existing shfl_xor, shfl_down, and shfl_up functions (lines 805-845) also use tir.call_extern, so the current approach is consistent with the existing codebase. If you decide to switch to call_intrin, consider updating the shuffle functions as well for uniformity.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/builtin.py` around lines 877 - 895, Replace the use of
tir.call_extern in the any_sync function with tir.call_intrin so the operation
is emitted as a TIR intrinsic (keep the return dtype int32 and pass the same
arguments: mask and predicate for CUDA path, predicate only for HIP path); also
consider updating the related shuffle helpers (shfl_xor, shfl_down, shfl_up) to
use tir.call_intrin for consistency with any_sync so downstream TIR passes can
recognize and optimize these intrinsics.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2bc0da8d-252d-4847-9258-2848fc693514

📥 Commits

Reviewing files that changed from the base of the PR and between 475eea1 and c070663.

📒 Files selected for processing (4)
  • docs/programming_guides/instructions.md
  • testing/python/language/test_tilelang_language_warp_vote.py
  • tilelang/language/__init__.py
  • tilelang/language/builtin.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/language/__init__.py

LeiWang1999 and others added 2 commits April 13, 2026 18:21
Convert the warp-vote, warp-shuffle, predicated block-sync, and warp-match
builtins from raw `tir.call_extern` wrappers to proper TIR ops registered
under `tl.*` and lowered in `codegen_cuda` / `codegen_hip`. This removes
Python-side `_IS_HIP_AVAILABLE` branching and pushes the CUDA/HIP split
into codegen, where it belongs.

* Register new tl ops: any_sync, all_sync, ballot_sync, ballot,
  activemask, syncthreads_count/and/or, shfl_sync, shfl_xor_sync,
  shfl_down_sync, shfl_up_sync, match_any_sync, match_all_sync.
* Codegen lowering on both CUDA and HIP. uint32->uint64 zero-extension
  for ballot/activemask now happens in codegen. HIP drops the mask
  argument for shfl/vote and emits LOG(FATAL) for match_*_sync.
* `__match_all_sync`'s int* pred output is hidden behind an
  immediately-invoked lambda so the wrapper stays expression-form.
* Python wrappers normalize Python int masks to uint32 TIR consts via
  `_as_uint32_mask`, so the emitted source prints `(uint)0xFFFFFFFF`
  instead of `(int64_t)4294967295`.
* Unify `shfl_xor/down/up` signatures to `(mask, value, delta, width=32)`,
  matching CUDA convention and `shfl_sync`. No in-tree callers existed.
* Add tests for match_any_sync and match_all_sync.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Resolve conflict in src/target/codegen_hip.cc: keep the warp-vote /
shuffle / match codegen branches added in this branch, and adopt main's
new unified packed-x2 dispatch (add2/sub2/mul2/fma2/max2/min2/abs2)
that replaces the standalone fadd2/fmul2/fma2 blocks.
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (2)
tilelang/language/builtin.py (1)

954-971: ⚠️ Potential issue | 🟡 Minor

Add HIP/ROCm version caveat to syncthreads_* docstrings.

These docstrings still read as unconditional HIP support. Please note the ROCm/HIP availability constraint to prevent user confusion on older stacks.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/builtin.py` around lines 954 - 971, Update the docstrings
for syncthreads_count, syncthreads_and, and syncthreads_or to explicitly note
that while these map to CUDA __syncthreads_* intrinsics, HIP/ROCm support can be
limited on older ROCm stacks and may not be available or behave differently;
edit the triple-quoted docstrings in the functions syncthreads_count,
syncthreads_and, and syncthreads_or to add a short caveat sentence about
ROCm/HIP availability and recommend checking the target ROCm version/driver for
support.
docs/programming_guides/instructions.md (1)

155-170: ⚠️ Potential issue | 🟡 Minor

Add an explicit ROCm/HIP version caveat for syncthreads_count/and/or.

This section still reads as universally available on HIP; please append a minimum-version note (ROCm/HIP >= 7.0) for T.syncthreads_count, T.syncthreads_and, and T.syncthreads_or.

📝 Suggested doc patch
- - `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`).
- - `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`).
- - `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`).
+ - `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`). *(HIP: requires ROCm/HIP ≥ 7.0)*
+ - `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`). *(HIP: requires ROCm/HIP ≥ 7.0)*
+ - `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`). *(HIP: requires ROCm/HIP ≥ 7.0)*
Do official HIP/ROCm docs indicate that `__syncthreads_count`, `__syncthreads_and`, and `__syncthreads_or` require ROCm/HIP 7.0+ (or are incomplete in HIP 6.x)?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/programming_guides/instructions.md` around lines 155 - 170, Add a
ROCm/HIP minimum-version caveat for the block-wide predicated sync entries:
update the descriptions for T.syncthreads_count, T.syncthreads_and, and
T.syncthreads_or to append a short note like “(ROCm/HIP >= 7.0 required)” and/or
a parenthetical sentence that these intrinsics are unavailable or will fail to
codegen on HIP < 7.0; ensure the same three symbol names (T.syncthreads_count,
T.syncthreads_and, T.syncthreads_or) are updated in the Warp-shuffle/Warp-match
note where syncthreads_count/and/or are discussed so the HIP caveat appears
consistently.
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_warp_vote.py (1)

42-406: Consider adding HIP-targeted tests; current coverage is CUDA-only.

The suite documents HIP behavior but only runs under @requires_cuda. Adding @requires_hip coverage for wave64 ballot/activemask and expected match_*_sync behavior would de-risk backend parity.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/language/test_tilelang_language_warp_vote.py` around lines 42
- 406, Tests only run for CUDA; add HIP-targeted equivalents so HIP-specific
behavior (wave64 ballot/activemask and match_*_sync semantics) is validated.
Duplicate the relevant test functions (e.g., test_ballot, test_activemask,
test_ballot_sync, test_match_any_sync, test_match_all_sync) or parameterize
them to run under HIP by replacing/adding
the `@tilelang.testing.requires_cuda` decorator with
`@tilelang.testing.requires_hip` (or adding a new HIP-decorated test next to the
CUDA one), and update the expected bitmask assertions in the HIP variants to
reflect wave64 behavior (upper/lower 32-bit expectations and width-specific
masks) so the tests assert the correct HIP/wave64 results.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/target/codegen_hip.cc`:
- Around line 848-857: The branches that currently emit HIP intrinsics for
tl::syncthreads_count, tl::syncthreads_and, and tl::syncthreads_or should not
unconditionally print __syncthreads_count/and/or; replace those emissions with
the same unsupported-operation handling used for tl::match_any_sync (e.g., emit
LOG(FATAL) with a clear message that tl::syncthreads_count / tl::syncthreads_and
/ tl::syncthreads_or are unsupported on HIP) instead of calling PrintExpr and
printing __syncthreads_* so compilation failures are avoided.

In `@testing/python/language/test_tilelang_language_warp_vote.py`:
- Around line 1-13: Replace all en dash characters (–) in the module docstring
with ASCII hyphens (-) to satisfy Ruff RUF002; specifically update the lines
listing intrinsics (e.g., the lines containing "T.any_sync – __any_sync / __any
(HIP)", "T.all_sync – __all_sync / __all (HIP)", and all other entries like
"T.ballot_sync – __ballot_sync→uint64", "T.activemask – __activemask→uint64",
etc.) so every occurrence of “–” becomes "-" and any similar en-dash uses
elsewhere in the docstring are normalized to ASCII hyphens.

In `@tilelang/language/builtin.py`:
- Around line 823-829: The shuffle wrapper functions (shfl_xor, shfl_down,
shfl_up, shfl_sync) read .dtype from `value` and pass `delta`, `width`, and
`srcLane` directly to tir.call_intrin, which breaks for Python literal ints;
update each wrapper to coerce Python ints into TIR expressions before use:
convert `value` (if an int) to a TIR PrimExpr (e.g. via tir.const) and use its
.dtype, and convert `delta`, `width`, `srcLane` (where applicable) via a helper
(e.g. _as_primexpr or tir.const) before passing them to tir.call_intrin; keep
the existing _as_uint32_mask(mask) usage for mask. Ensure you modify the four
functions named above so .dtype is never read from a literal int and no literal
int is passed raw into tir.call_intrin.

---

Duplicate comments:
In `@docs/programming_guides/instructions.md`:
- Around line 155-170: Add a ROCm/HIP minimum-version caveat for the block-wide
predicated sync entries: update the descriptions for T.syncthreads_count,
T.syncthreads_and, and T.syncthreads_or to append a short note like “(ROCm/HIP
>= 7.0 required)” and/or a parenthetical sentence that these intrinsics are
unavailable or will fail to codegen on HIP < 7.0; ensure the same three symbol
names (T.syncthreads_count, T.syncthreads_and, T.syncthreads_or) are updated in
the Warp-shuffle/Warp-match note where syncthreads_count/and/or are discussed so
the HIP caveat appears consistently.

In `@tilelang/language/builtin.py`:
- Around line 954-971: Update the docstrings for syncthreads_count,
syncthreads_and, and syncthreads_or to explicitly note that while these map to
CUDA __syncthreads_* intrinsics, HIP/ROCm support can be limited on older ROCm
stacks and may not be available or behave differently; edit the triple-quoted
docstrings in the functions syncthreads_count, syncthreads_and, and
syncthreads_or to add a short caveat sentence about ROCm/HIP availability and
recommend checking the target ROCm version/driver for support.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1fc66cbb-c9d3-4f11-b0d7-9e680e426870

📥 Commits

Reviewing files that changed from the base of the PR and between c070663 and c856804.

📒 Files selected for processing (7)
  • docs/programming_guides/instructions.md
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/target/codegen_cuda.cc
  • src/target/codegen_hip.cc
  • testing/python/language/test_tilelang_language_warp_vote.py
  • tilelang/language/builtin.py

Comment thread src/target/codegen_hip.cc
Comment on lines +848 to +857
  } else if (op->op.same_as(tl::syncthreads_count())) {
    ICHECK_EQ(op->args.size(), 1U)
        << "tl.syncthreads_count expects <predicate>.";
    os << "__syncthreads_count(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::syncthreads_and())) {
    ICHECK_EQ(op->args.size(), 1U) << "tl.syncthreads_and expects <predicate>.";
    os << "__syncthreads_and(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::syncthreads_or())) {
    ICHECK_EQ(op->args.size(), 1U) << "tl.syncthreads_or expects <predicate>.";
    os << "__syncthreads_or(" << PrintExpr(op->args[0]) << ")";

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🌐 Web query:

What ROCm/HIP version first fully supports HIP device intrinsics __syncthreads_count, __syncthreads_and, and __syncthreads_or on AMD GPUs?

💡 Result:

No ROCm/HIP version fully supports the HIP device intrinsics __syncthreads_count, __syncthreads_and, and __syncthreads_or on AMD GPUs. These are block-level synchronization primitives with predicates, analogous to CUDA's, but official HIP documentation and changelogs do not list them as implemented. Older docs (e.g., ROCm 5.x) explicitly state they are "under development". Recent docs (HIP 7.x) mention warp-level sync variants like ballot and sync shfl but omit these block-level ones. They appear in HIP-CPU (CPU emulation) and rocPRIM (library), but not in the core HIP language for AMD GPUs. Issues requesting them remain open (e.g., HIP-CPU #58). Warp-level sync such as __syncwarp was added in ROCm 7.0, but block sync with count/and/or was not.


🏁 Script executed:

# Check if there's any wrapper, guard, or fallback mechanism for these HIP intrinsics in the codebase
rg -B 5 -A 5 "syncthreads_count|syncthreads_and|syncthreads_or" src/target/codegen_hip.cc

# Also check for any `#ifdef` guards or version checks around this section
rg -B 10 "syncthreads_count" src/target/codegen_hip.cc | head -40

Repository: tile-ai/tilelang

Length of output: 2130


🏁 Script executed:

# Check if there's any HIP version macro or minimum version requirement set globally
rg "__HIP_|HIP_VERSION|ROCm|HIP_COMMON_HEADER_CXX_VERSION" src/target/codegen_hip.cc | head -20

# Look for any fallback definitions of these intrinsics elsewhere
fd -t f '\.(cc|h|hpp)$' src/ | xargs rg "define.*syncthreads_count|define.*syncthreads_and|define.*syncthreads_or" 2>/dev/null

Repository: tile-ai/tilelang

Length of output: 44


🏁 Script executed:

# Check the full context around lines 828-883 mentioned in the scratchpad
sed -n '828,883p' src/target/codegen_hip.cc

Repository: tile-ai/tilelang

Length of output: 3403


Replace unconditional HIP intrinsic emission with explicit unsupported error.

Lines 848–857 unconditionally emit __syncthreads_count, __syncthreads_and, and __syncthreads_or, but these intrinsics do not exist in HIP for AMD GPUs across any ROCm version. Code using these operations will fail to compile. Align with the existing pattern for unsupported operations (e.g., tl::match_any_sync): either emit a LOG(FATAL) with a clear error message or remove the code paths entirely.
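
For context, the value each of the three block-wide predicated syncs is expected to produce can be modelled on the host. This is only a reference sketch of the CUDA semantics (a barrier followed by a reduction over every thread's predicate), not TileLang or HIP code; the function names simply mirror the intrinsics, and returning exactly 1 for "non-zero" is a normalization assumed here.

```python
def syncthreads_count(predicates):
    """Number of threads in the block whose predicate is non-zero."""
    return sum(1 for p in predicates if p)


def syncthreads_and(predicates):
    """1 if every thread's predicate is non-zero, else 0 (normalized)."""
    return int(all(predicates))


def syncthreads_or(predicates):
    """1 if any thread's predicate is non-zero, else 0 (normalized)."""
    return int(any(predicates))


# A 128-thread block in which only threads 0..9 pass the predicate.
block = [tid < 10 for tid in range(128)]
count_result = syncthreads_count(block)
and_result = syncthreads_and(block)
or_result = syncthreads_or(block)
```

A backend that cannot emit the intrinsics would need either an equivalent reduction or an explicit error, which is the choice the comment above is about.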

Comment on lines +1 to +13
"""Tests for warp-vote / warp-ballot / block-sync-with-predicate intrinsics.

Covered intrinsics
------------------
T.any_sync – __any_sync / __any (HIP)
T.all_sync – __all_sync / __all (HIP)
T.ballot_sync – __ballot_sync→uint64 (CUDA, zero-ext) / __ballot uint64 (HIP, all lanes)
T.ballot – full-warp ballot_sync / __ballot uint64 (HIP, all lanes)
T.activemask – __activemask→uint64 (CUDA, zero-ext) / __ballot(1) uint64 (HIP, all lanes)
T.syncthreads_count – __syncthreads_count
T.syncthreads_and – __syncthreads_and
T.syncthreads_or – __syncthreads_or
"""

⚠️ Potential issue | 🟡 Minor

Replace EN DASH characters in docstrings to satisfy Ruff (RUF002).

Use ASCII - (HYPHEN-MINUS) instead of – (EN DASH) to avoid lint failures.

Also applies to: 192-192
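
The replacement can be done mechanically. The snippet below is a minimal sketch; `normalize_dashes` is a hypothetical helper, and the sample line is taken from the docstring quoted above.

```python
EN_DASH = "\u2013"  # the character Ruff flags as ambiguous (RUF002)


def normalize_dashes(text):
    """Replace every EN DASH with an ASCII hyphen-minus."""
    return text.replace(EN_DASH, "-")


line = "T.any_sync \u2013 __any_sync / __any (HIP)"
fixed = normalize_dashes(line)
```

Other non-ASCII characters in the docstring, such as the arrow in `__ballot_sync→uint64`, are left untouched by this helper.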

🧰 Tools
🪛 Ruff (0.15.9)

[warning] 5-5: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF002)


[warning] 6-6: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF002)


[warning] 7-7: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF002)


[warning] 8-8: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF002)


[warning] 9-9: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF002)


[warning] 10-10: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF002)


[warning] 11-11: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF002)


[warning] 12-12: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF002)

Comment on lines +823 to 829
mask: int | PrimExpr, value: int | PrimExpr | tir.Call, delta: int | PrimExpr | tir.Call, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH
):
"""XOR-swap ``value`` across lanes (``__shfl_xor_sync`` on CUDA,
``__shfl_xor`` on HIP — mask ignored on HIP).
"""
if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_down", value, offset)
else:
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xFFFFFFFF, value, offset)
return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_xor_sync"), _as_uint32_mask(mask), value, delta, width)


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, locate the file and check its structure
find . -name "builtin.py" -path "*tilelang/language*" | head -5

Repository: tile-ai/tilelang

Length of output: 90


🏁 Script executed:

# Read the relevant lines from the file to verify the claim
head -900 tilelang/language/builtin.py | tail -100

Repository: tile-ai/tilelang

Length of output: 4036


🏁 Script executed:

# Also search for the _as_uint32_mask helper and convert function
rg -n "_as_uint32_mask|def convert" tilelang/language/builtin.py | head -20

Repository: tile-ai/tilelang

Length of output: 1138


🏁 Script executed:

# Look for the complete shuffle function definitions
rg -A 5 "def shfl_xor|def shfl_down|def shfl_up|def shfl_sync" tilelang/language/builtin.py

Repository: tile-ai/tilelang

Length of output: 1307


🏁 Script executed:

# Check if convert() function exists in the file or is imported
rg -n "^from|^import|def convert" tilelang/language/builtin.py | head -30

Repository: tile-ai/tilelang

Length of output: 604


🏁 Script executed:

# Look at the imports and beginning of the file
head -50 tilelang/language/builtin.py

Repository: tile-ai/tilelang

Length of output: 1791


🏁 Script executed:

# Check if there are tests for these shuffle functions with literal int values
find . -name "*test*builtin*" -o -name "*builtin*test*" | grep -i test

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for usage of shfl_xor, shfl_down, shfl_up, shfl_sync in tests
rg "shfl_xor|shfl_down|shfl_up|shfl_sync" --type py -l | head -10

Repository: tile-ai/tilelang

Length of output: 148


🏁 Script executed:

# Check the test file for shuffle function usage
cat testing/python/language/test_tilelang_language_warp_sync.py

Repository: tile-ai/tilelang

Length of output: 1542


🏁 Script executed:

# Also check what _as_uint32_mask does and how it handles int vs PrimExpr
sed -n '811,825p' tilelang/language/builtin.py

Repository: tile-ai/tilelang

Length of output: 651


🏁 Script executed:

# Look at how tir.call_intrin handles its parameters
rg -B 5 -A 5 "call_intrin.*\.dtype" tilelang/language/builtin.py | head -50

Repository: tile-ai/tilelang

Length of output: 2211


🏁 Script executed:

# Check other functions to see how they handle dtype extraction
rg -B 3 -A 3 "tir.call_intrin" tilelang/language/builtin.py | grep -A 3 -B 3 "\.dtype"

Repository: tile-ai/tilelang

Length of output: 1729


🏁 Script executed:

# Look for how convert() is used elsewhere in the file
rg "convert\(" tilelang/language/builtin.py

Repository: tile-ai/tilelang

Length of output: 528


🏁 Script executed:

# Check if there are other places where int | PrimExpr parameters are properly converted
sed -n '17,40p' tilelang/language/builtin.py

Repository: tile-ai/tilelang

Length of output: 789


🏁 Script executed:

# Look at a complete example of how normalize_index_arg is used
rg -B 2 -A 2 "_normalize_index_arg" tilelang/language/builtin.py | head -30

Repository: tile-ai/tilelang

Length of output: 1209


🏁 Script executed:

# Final check: Look at one more case to confirm how delta/width are handled
sed -n '823,829p' tilelang/language/builtin.py

Repository: tile-ai/tilelang

Length of output: 447


🏁 Script executed:

# Verify that passing literal int to these params would also fail
python3 << 'EOF'
# Simulate what happens when int is passed to call_intrin
class MockPrimExpr:
    def __init__(self):
        self.dtype = "int32"

# This is what happens in the current code
value = 42  # literal int
try:
    dtype = value.dtype
except AttributeError as e:
    print(f"✗ Error: {e}")

# With conversion
expr = MockPrimExpr()
dtype = expr.dtype
print(f"✓ After conversion: dtype = {dtype}")
EOF

Repository: tile-ai/tilelang

Length of output: 140


Fix literal int handling in shuffle wrappers before accessing .dtype.

The value parameter accepts int | PrimExpr | tir.Call, but all four shuffle functions directly access value.dtype without converting literal integers first. This will raise AttributeError: 'int' object has no attribute 'dtype' at runtime when a Python literal is passed.

Additionally, delta, width, and srcLane parameters also accept int | PrimExpr but are passed unconverted to tir.call_intrin(), which expects TIR expressions for proper type handling.

Proposed fix
+def _as_expr(x: int | PrimExpr | tir.Call) -> PrimExpr:
+    return convert(x)
+
 def shfl_xor(
     mask: int | PrimExpr, value: int | PrimExpr | tir.Call, delta: int | PrimExpr | tir.Call, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH
 ):
@@
-    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_xor_sync"), _as_uint32_mask(mask), value, delta, width)
+    value_expr = _as_expr(value)
+    return tir.call_intrin(
+        str(value_expr.dtype),
+        tir.op.Op.get("tl.shfl_xor_sync"),
+        _as_uint32_mask(mask),
+        value_expr,
+        convert(delta),
+        convert(width),
+    )
@@
 def shfl_down(
@@
-    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_down_sync"), _as_uint32_mask(mask), value, delta, width)
+    value_expr = _as_expr(value)
+    return tir.call_intrin(
+        str(value_expr.dtype),
+        tir.op.Op.get("tl.shfl_down_sync"),
+        _as_uint32_mask(mask),
+        value_expr,
+        convert(delta),
+        convert(width),
+    )
@@
 def shfl_up(
@@
-    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_up_sync"), _as_uint32_mask(mask), value, delta, width)
+    value_expr = _as_expr(value)
+    return tir.call_intrin(
+        str(value_expr.dtype),
+        tir.op.Op.get("tl.shfl_up_sync"),
+        _as_uint32_mask(mask),
+        value_expr,
+        convert(delta),
+        convert(width),
+    )
@@
 def shfl_sync(mask: int | PrimExpr, value: int | PrimExpr, srcLane: int | PrimExpr, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH):
@@
-    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_sync"), _as_uint32_mask(mask), value, srcLane, width)
+    value_expr = _as_expr(value)
+    return tir.call_intrin(
+        str(value_expr.dtype),
+        tir.op.Op.get("tl.shfl_sync"),
+        _as_uint32_mask(mask),
+        value_expr,
+        convert(srcLane),
+        convert(width),
+    )

Also applies to: 831–838, 840–847, 866–871
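
The coercion pattern can be illustrated without TVM. In the sketch below `Const` is a hypothetical stand-in for a TIR expression and `as_expr` plays the role that `convert` plays in the proposed fix; neither is the real TileLang or TVM API, and the `int32` default is an assumption made only for this example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Const:
    """Stand-in for a TIR expression: a value plus a dtype string."""
    value: object
    dtype: str = "int32"


def as_expr(x):
    """Wrap Python ints so that .dtype is always available."""
    if isinstance(x, int):
        return Const(x)
    return x


def shfl_xor_args(mask, value, delta, width=32):
    """Build the argument tuple a shuffle wrapper would forward."""
    value_expr = as_expr(value)  # coerce before touching .dtype
    return (value_expr.dtype, mask, value_expr, as_expr(delta), as_expr(width))


from_literal = shfl_xor_args(0xFFFFFFFF, 42, 1)
from_expr = shfl_xor_args(0xFFFFFFFF, Const(1.5, "float32"), 1)
```

Coercing once at the top of each wrapper keeps the literal-int case and the expression case on the same code path, so the dtype lookup cannot raise `AttributeError`.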

@LeiWang1999 LeiWang1999 merged commit 19236b4 into tile-ai:main Apr 13, 2026
5 of 6 checks passed
Development

Successfully merging this pull request may close these issues.

[Feature Request] Add TIR builtins for warp-level vote and block-level predicate sync

2 participants