[Codegen] Warp-Uniform Code Generation for ThreadGroupStmt #96

Merged: yaoyaoding merged 3 commits into main (Mar 12, 2026)
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Warp-Uniform Code Generation for ThreadGroupStmt
Problem
The Tilus-generated GEMM kernel executed 4.6x more instructions than cuBLAS (nvjet), with only 10.5 average active threads per warp (vs. 32). The root cause: thread group selection (e.g., `single_warp(0)`, `single_thread()`) generated per-thread branching like `if (threadIdx.x / 32 == 0)`, which the GPU treats as divergent: all 32 threads in a warp fetch the branch instructions, but most are predicated off.

Solution
We introduced elect-any semantics (`thread_begin=-1`) for `ThreadGroupStmt` and updated the code generation to emit warp-uniform predicates using two new primitives:

1. `elect.sync` (PTX instruction): elects exactly one thread per warp. Used for `single_thread()` when the caller doesn't care which thread runs the body.
2. `shfl.sync.idx` (warp shuffle): broadcasts a value from one lane to all lanes in the warp. By shuffling `threadIdx / N` from lane 0, the result lives in a uniform register: all threads in the warp hold the same value, so comparisons against it produce uniform predicates that the hardware can execute without divergence.

Code generation rules
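To see why the shuffle yields a uniform predicate, here is a toy Python model of a 32-lane warp. This is purely illustrative: the function names mirror the PTX instructions, but nothing here is the actual Tilus implementation.

```python
WARP_SIZE = 32

def shfl_sync_idx(values, src_lane):
    # shfl.sync.idx: broadcast values[src_lane] to every lane in the warp.
    return [values[src_lane]] * len(values)

def elect_sync(active_lanes):
    # elect.sync: exactly one active lane per warp observes True.
    leader = min(active_lanes)
    return [lane == leader for lane in range(WARP_SIZE)]

# Threads with tid 64..95 form warp 2 of the block.
tids = list(range(64, 96))
per_lane = [tid // 32 for tid in tids]    # same value, but held per-thread
uniform = shfl_sync_idx(per_lane, 0)      # now provably identical in all lanes

# single_warp(2) can compare the broadcast value instead of branching per thread:
predicate = [v == 2 for v in uniform]
assert all(predicate)                     # uniform: no divergence inside the warp

# single_thread() with elect-any: exactly one lane runs the body.
assert sum(elect_sync(range(WARP_SIZE))) == 1
```

The point of the model: after the broadcast, every lane compares the same value, so the branch outcome is identical warp-wide and the hardware never needs per-thread predication.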
| Construct | Before | After |
| --- | --- | --- |
| `single_thread()` inside a single warp | `if (tid % 32 == 0)` | `if (elect.sync())` |
| `single_thread()` inside a multi-warp group | `if (tid % N == 0)` | `if (shfl(tid/32, lane0) == 0 && elect.sync())` |
| `single_warp(N)` / `warp_group(N, M)` | `if (tid / 32 == N)` | `if (shfl(tid/32, lane0) == N)` (uniform predicate) |
| `single_thread(0)` (explicit thread 0) | `if (tid == 0)` | `if (tid == 0)` |

The key insight from analyzing nvjet's SASS: it uses uniform registers (`UR`) and uniform predicates (`UP`) for warp dispatch. The `shfl.sync` trick forces values through the warp shuffle unit, which the GPU recognizes as uniform, enabling the same `UR`/`UP` code generation path.

API change
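A minimal sketch of the new default, using a hypothetical simplified stand-in for the real `single_thread()` (the actual definition lives in `lang/instructions/root.py` and `ir/stmt.py`):

```python
ELECT_ANY = -1  # sentinel: "any one thread per warp" (elect-any semantics)

def single_thread(thread_begin: int = ELECT_ANY) -> dict:
    # Hypothetical stand-in: returns the fields a ThreadGroupStmt would carry.
    # thread_begin == -1 lets codegen emit elect.sync instead of `tid == k`.
    return {"group_size": 1, "thread_begin": thread_begin}

assert single_thread()["thread_begin"] == ELECT_ANY  # new elect-any default
assert single_thread(0)["thread_begin"] == 0         # explicit pin to thread 0
```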
`single_thread()` now defaults to elect-any (`thread_begin=-1`). Pass `single_thread(0)` explicitly to pin to thread 0. `single_warp()` and `warp_group()` retain fixed thread assignment (required for role-based warp dispatch) but now generate uniform predicates via `shfl.sync`.

Files changed
- `ir/stmt.py`: `ThreadGroupStmt` docstring for `thread_begin=-1`
- `ir/utils/thread_group_stack.py`: accept `-1` in stack validation
- `lang/instructions/root.py`: `single_thread()` defaults to `-1`
- `ir/tools/printer.py`: print `elect_any` for `-1`
- `backends/codegen.py`: `_elect_any_cond()` plus uniform predicates for warp-aligned fixed groups
- `extensions/hidet/ir/primitives/cuda/elect.py`: new `elect_sync`, `shfl_sync_i32`, `elect_one_sync` primitives
- `tests/lang/test_thread_group.py`: test for elect-any semantics

Result
10240×10240 FP16 GEMM on Blackwell: 1.83ms → 1.75ms (~5% improvement), narrowing the gap to cuBLAS from ~8% to ~3%.
Before the optimization (`matmul_v8.py` vs cuBLAS on B200): [profiler figure omitted]
After the optimization: [profiler figure omitted]