
[Codegen] Warp-Uniform Code Generation for ThreadGroupStmt #96

Merged
yaoyaoding merged 3 commits into main from blackwell-gemm on Mar 12, 2026
Conversation

@yaoyaoding (Member)

Warp-Uniform Code Generation for ThreadGroupStmt

Problem

The Tilus-generated GEMM kernel executed 4.6x more instructions than cuBLAS (nvjet), with an average of only 10.5 active threads per warp (vs. 32). The root cause: thread group selection (e.g., single_warp(0), single_thread()) generated per-thread branches like if (threadIdx.x / 32 == 0), which the GPU treats as divergent — all 32 threads in a warp fetch the branch body's instructions, but most are predicated off.

Solution

We introduced elect-any semantics (thread_begin=-1) for ThreadGroupStmt and updated the code generation to emit warp-uniform predicates using two new primitives:

1. elect.sync (PTX instruction) — Elects exactly one thread per warp. Used for single_thread() when the caller doesn't care which thread runs the body.

2. shfl.sync.idx (warp shuffle) — Broadcasts a value from one lane to all lanes in the warp. By shuffling threadIdx.x / 32 from lane 0, the result lives in a uniform register: all threads in the warp hold the same value, so comparisons against it produce uniform predicates that the hardware can execute without divergence.
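The two primitives can be sketched in CUDA C++ as follows. This is a minimal illustration, assuming sm_90+ for the elect.sync PTX instruction; the function names are illustrative, not the ones in this PR's codegen:

```cuda
#include <cstdint>

// Elect exactly one thread per warp via PTX elect.sync.
// The predicate p is set only in the elected lane; we materialize it
// into a register with selp so it can drive a C-level branch.
__device__ __forceinline__ bool elect_one_sync() {
    uint32_t pred = 0;
    asm volatile(
        "{\n\t"
        ".reg .pred p;\n\t"
        "elect.sync _|p, 0xffffffff;\n\t"
        "selp.u32 %0, 1, 0, p;\n\t"
        "}"
        : "=r"(pred));
    return pred != 0;
}

// Broadcast the warp index from lane 0 through the shuffle unit.
// Every lane receives the same value, so the compiler can keep it in a
// uniform register and compare it with a uniform predicate.
__device__ __forceinline__ uint32_t uniform_warp_id() {
    return __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
}
```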

Code generation rules

| Pattern | Old codegen | New codegen |
| --- | --- | --- |
| `single_thread()` inside a single warp | `if (tid % 32 == 0)` | `if (elect.sync())` |
| `single_thread()` inside a multi-warp group | `if (tid % N == 0)` | `if (shfl(tid/32, lane0) == 0 && elect.sync())` |
| `single_warp(N)` / `warp_group(N, M)` | `if (tid / 32 == N)` | `if (shfl(tid/32, lane0) == N)` (uniform predicate) |
| `single_thread(0)` (explicit thread 0) | `if (tid == 0)` | unchanged (fixed assignment) |

The key insight from analyzing nvjet's SASS: it uses uniform registers (UR) and uniform predicates (UP) for warp dispatch. The shfl.sync trick forces values through the warp shuffle unit, which the GPU recognizes as uniform — enabling the same UR/UP code generation path.
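The before/after difference for the single_warp(N) rule can be sketched as below (warp index 2 chosen arbitrarily for illustration):

```cuda
// Before: threadIdx.x / 32 is a per-thread value, so the branch is
// treated as divergent even though it is uniform within each warp.
if (threadIdx.x / 32 == 2) {
    // warp 2 body
}

// After: routing the value through the shuffle unit tells the compiler
// it is warp-uniform, so the comparison compiles to a uniform register
// (UR) comparison and a uniform predicate (UP) — a non-divergent branch.
uint32_t wid = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
if (wid == 2) {
    // warp 2 body
}
```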

API change

single_thread() now defaults to elect-any (thread_begin=-1). Pass single_thread(0) explicitly to pin to thread 0. single_warp() and warp_group() retain fixed thread assignment (required for role-based warp dispatch) but now generate uniform predicates via shfl.sync.

Files changed

  • ir/stmt.py — ThreadGroupStmt docstring for thread_begin=-1
  • ir/utils/thread_group_stack.py — Accept -1 in stack validation
  • lang/instructions/root.py — single_thread() defaults to -1
  • ir/tools/printer.py — Print elect_any for -1
  • backends/codegen.py — _elect_any_cond() + uniform predicates for warp-aligned fixed groups
  • extensions/hidet/ir/primitives/cuda/elect.py — New elect_sync, shfl_sync_i32, elect_one_sync primitives
  • tests/lang/test_thread_group.py — Test for elect-any semantics

Result

10240×10240 FP16 GEMM on Blackwell: 1.83ms → 1.75ms (~5% improvement), narrowing the gap to cuBLAS from ~8% to ~3%.

Before the optimization (matmul_v8.py vs cublas on B200)

       m      n      k   name  latency (ms)       tflops
0   4096   4096   4096  torch      0.144464   951.371639
1   4096   4096   4096  tilus      0.161088   853.191730
2   4096   4096  14336  torch      0.443920  1083.610415
3   4096   4096  14336  tilus      0.472048  1019.041152
4   8192   8192   8192  torch      0.890416  1234.829150
5   8192   8192   8192  tilus      0.953616  1152.992017
6  10240  10240  10240  torch      1.698240  1264.534854
7  10240  10240  10240  tilus      1.833008  1171.562622

After the optimization:

       m      n      k   name  latency (ms)       tflops
0   4096   4096   4096  torch      0.144512   951.055661
1   4096   4096   4096  tilus      0.154528   889.411330
2   4096   4096  14336  torch      0.444624  1081.894693
3   4096   4096  14336  tilus      0.449600  1069.920688
4   8192   8192   8192  torch      0.891408  1233.454975
5   8192   8192   8192  tilus      0.899104  1222.897050
6  10240  10240  10240  torch      1.697792  1264.868521
7  10240  10240  10240  tilus      1.725936  1244.242923

Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
yaoyaoding mentioned this pull request on Mar 12, 2026
yaoyaoding merged commit 364534c into main on Mar 12, 2026
8 checks passed

