
[Bugfix] Fix incorrect sync hoist for fragment buffer conditions in ThreadSync #2030

Merged
LeiWang1999 merged 1 commit into tile-ai:main from LeiWang1999:fix/thread-sync-hoist-fragment-uniform
Apr 11, 2026

Conversation

@LeiWang1999
Member

@LeiWang1999 LeiWang1999 commented Apr 10, 2026

Summary

  • Fix a bug in ConditionThreadPropertyChecker where conditions derived from fragment (local-scope) buffer loads were incorrectly classified as non-block-uniform based solely on storage scope
  • This caused ThreadSync to hoist __syncthreads() out of an if-body, removing the write-before-read synchronization between shared memory writes and TMA store reads
  • The fix removes the scope-based heuristic and relies on recursive index analysis to determine block-uniformity

Problem

When a blocksparse copy kernel uses a fragment buffer for block mask indices (e.g., a = block_mask_f[i]), the condition a >= 0 guarding the copy body is actually block-uniform — all threads in a block hold the same fragment data loaded from BlockMask[blockIdx.y, :].

However, ConditionThreadPropertyChecker::VisitExpr_(BufferLoadNode*) marked all local-scope buffer loads as non-block-uniform. This triggered the sync hoist logic, which moved __syncthreads() from between the shared memory writes and TMA store to before the if-statement — breaking the synchronization guarantee.

Before fix (incorrect):

__syncthreads();        // hoisted here (too early)
if (a >= 0) {
    write_to_shared();  // all threads
    tma_store();        // elected thread — no sync!
}

After fix (correct):

__syncthreads();        // loop-carried sync
if (a >= 0) {
    write_to_shared();  // all threads
    __syncthreads();    // correctly placed intra-iteration sync
    tma_store();        // elected thread
}
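The placement decision behind the two snippets above can be sketched as a toy model (illustrative only, not the actual ThreadSync implementation; the names are hypothetical): a __syncthreads() under a guard that diverges within a block would be reached by only some threads, which is undefined behavior, so the planner must hoist it above the guard; under a block-uniform guard it may stay inside the branch.

```cpp
#include <cassert>

// Toy model of the sync planner's placement decision (hypothetical names;
// not the real ThreadSync pass).
enum class SyncPlacement { InsideBranch, HoistedBeforeBranch };

// A __syncthreads() under a non-block-uniform guard would be executed by
// only a subset of the block's threads (undefined behavior), so it must be
// hoisted above the guard. Under a block-uniform guard it may stay inside,
// preserving the write-before-read ordering within the branch.
SyncPlacement PlanSync(bool guard_is_block_uniform) {
    return guard_is_block_uniform ? SyncPlacement::InsideBranch
                                  : SyncPlacement::HoistedBeforeBranch;
}
```

Misclassifying the fragment-derived guard `a >= 0` as non-uniform therefore forced the hoist shown in the "before" snippet, even though keeping the sync inside the branch was both legal and required for correctness.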

Root Cause

In ConditionThreadPropertyChecker::VisitExpr_(BufferLoadNode*):

if (IsThreadLocalScope(GetScope(op->buffer->data))) {
    current_.is_block_uniform = false;  // too conservative
}

This unconditionally marked fragment buffer loads as non-block-uniform. The fix removes this check and instead relies on the recursive visit of buffer load indices — if any index depends on threadIdx, VisitExpr_(VarNode*) will correctly set is_block_uniform = false.
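The index-driven analysis can be illustrated with a self-contained sketch. The `Expr` type below is a hypothetical stand-in for TVM's expression nodes (the real code works on `VarNode` / `BufferLoadNode`); the point is that uniformity of a load is decided purely by whether any index, recursively, reads a threadIdx variable, with the buffer's storage scope playing no role.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Minimal stand-in for an expression node (hypothetical; for illustration).
struct Expr {
    std::string var;                             // non-empty: a variable read
    std::vector<std::shared_ptr<Expr>> indices;  // non-empty: a buffer load
};

// A load is block-uniform unless some index (recursively) depends on a
// threadIdx.* variable; the storage scope of the buffer is irrelevant.
bool IsBlockUniform(const Expr& e) {
    if (e.var.rfind("threadIdx", 0) == 0) return false;  // VarNode case
    for (const auto& idx : e.indices) {                  // recursive index visit
        if (!IsBlockUniform(*idx)) return false;
    }
    return true;
}
```

Under this rule, a fragment load indexed by a block-uniform loop variable (e.g. `block_mask_f[i]`) stays uniform, while a load indexed by `threadIdx.x` is correctly flagged as divergent.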

Test plan

  • test_blocksparse_copy_tma — previously failing, now passes
  • test_blocksparse_copy_cp_async — passes
  • All 20 existing test_tilelang_transform_thread_sync tests pass

Summary by CodeRabbit

  • Bug Fixes
    • Improved compiler analysis for GPU thread uniformity. The compiler now more accurately determines when operations can be safely executed uniformly across threads by analyzing actual memory access patterns rather than relying solely on storage scope, leading to better code generation.

…hreadSync

The ConditionThreadPropertyChecker in ThreadSync incorrectly classified
conditions derived from fragment (local-scope) buffer loads as
non-block-uniform, solely based on storage scope. This caused the sync
planner to hoist __syncthreads() from inside the if-body to before the
if-statement, removing the write-before-read synchronization guarantee
between shared memory writes and TMA store reads.

Fragment buffers commonly hold block-uniform data when populated from
block-uniform global addresses (e.g., T.copy(BlockMask[blockIdx.y, :],
fragment)). The fix removes the scope-based heuristic and instead relies
on the recursive visit of buffer load indices — if any index depends on
threadIdx, VisitExpr_(VarNode*) will correctly mark the load as
non-block-uniform.

Before fix:
  __syncthreads();        // hoisted here (too early)
  if (a >= 0) {
      write_to_shared();  // all threads
      tma_store();        // elected thread — no sync protection!
  }

After fix:
  __syncthreads();        // loop-carried sync
  if (a >= 0) {
      write_to_shared();  // all threads
      __syncthreads();    // correctly placed intra-iteration sync
      tma_store();        // elected thread
  }
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Apr 10, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 49dab0d1-67c1-4e9d-b8cb-b5f6b81d0432

📥 Commits

Reviewing files that changed from the base of the PR and between b1a88bf and b1f1357.

📒 Files selected for processing (1)
  • src/transform/thread_storage_sync.cc

📝 Walkthrough

Modified buffer load handling in thread property checker to delegate uniformity determination to load index traversal instead of unconditionally marking thread-local buffer loads as non-uniform. Runtime dependency is always recorded during analysis.

Changes

  • Buffer Load Property Checking (src/transform/thread_storage_sync.cc): Refactored ConditionThreadPropertyChecker::VisitExpr_(const BufferLoadNode *op) to always set current_.depends_on_runtime = true and to determine block uniformity through recursive index traversal rather than immediately marking loads from thread-local storage scopes as non-uniform.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested reviewers

  • SiriusNEO
  • silentCoder-dev

Poem

🐰 A buffer hops through indices fine,
No longer forced to cross the line,
Let recursion trace the way,
Thread-local uniformity to display! ✨

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Description Check: Passed (check skipped; CodeRabbit's high-level summary is enabled)
  • Title Check: Passed (the title accurately summarizes the main bugfix, correcting the incorrect sync hoist for fragment buffer conditions in ThreadSync, which is the core change in the PR)
  • Docstring Coverage: Passed (no functions found in the changed files to evaluate; docstring coverage check skipped)


@LeiWang1999
Member Author

@regression-perf

@github-actions

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/24259869181

Results

File Original Latency Current Latency Speedup
example_mhc_post 0.110014 0.109829 1.00169
example_warp_specialize_gemm_copy_0_gemm_1 0.03984 0.039713 1.0032
example_tilelang_gemm_splitk 1.02867 1.02505 1.00352
block_sparse_attn_tilelang 0.00889963 0.00884271 1.00644
example_linear_attn_fwd 0.0365474 0.0362907 1.00707
example_linear_attn_bwd 0.152675 0.151524 1.00759
example_warp_specialize_gemm_barrierpipe_stage2 0.0407512 0.0404342 1.00784
tilelang_example_sparse_tensorcore 0.0147293 0.0146139 1.00789
example_tilelang_gemm_splitk_vectorize_atomicadd 1.04023 1.03168 1.00829
example_topk 0.0112042 0.0111024 1.00917
example_warp_specialize_gemm_copy_1_gemm_0 0.0279914 0.0277164 1.00992
example_warp_specialize_gemm_softpipe_stage2 0.0279775 0.0276939 1.01024
example_gemv 0.291694 0.288217 1.01206
example_vertical_slash_sparse_attn 0.232311 0.229305 1.01311
example_gemm 0.0227593 0.0223502 1.0183
example_mhc_pre 0.15715 0.153726 1.02228
example_tilelang_gemm_fp8 0.319121 0.310851 1.0266
example_tilelang_gemm_fp8_2xAcc 0.192805 0.18669 1.03276
example_gemm_autotune 0.0232281 0.0224609 1.03416
example_gemm_intrinsics 0.0360856 0.0348184 1.03639
example_tilelang_gemm_fp8_intrinsic 0.873606 0.842029 1.0375

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@LeiWang1999
Member Author

@regression-perf

@github-actions

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/24262243497

Results

File Original Latency Current Latency Speedup
example_warp_specialize_gemm_barrierpipe_stage2 0.0404051 0.0405162 0.997257
example_mhc_pre 0.152252 0.152549 0.998054
example_blocksparse_gemm 0.0199867 0.0200172 0.998475
example_tilelang_sparse_gqa_decode_varlen_indice 0.0159985 0.0160188 0.998733
example_tilelang_sparse_gqa_decode_varlen_mask 0.0176052 0.0176242 0.998923
example_tilelang_gemm_splitk 1.02358 1.02445 0.99915
example_linear_attn_fwd 0.0362822 0.0363121 0.999177
sparse_mla_fwd 0.127788 0.127892 0.999187
example_gqa_sink_bwd_bhsd_sliding_window 0.0255039 0.0255237 0.999225
example_mha_fwd_varlen 0.046061 0.0460965 0.999231
example_warp_specialize_gemm_softpipe_stage2 0.0276923 0.0277109 0.99933
example_mha_fwd_bhsd 0.0105963 0.0106032 0.999354
example_mha_sink_fwd_bhsd 0.0149603 0.0149694 0.999398
example_tilelang_nsa_decode 0.00683407 0.00683816 0.999402
example_tilelang_nsa_fwd 0.00703035 0.00703387 0.999499
sparse_mla_fwd_pipelined 0.0944777 0.0945235 0.999516
example_tilelang_gemm_fp8 0.310783 0.310896 0.999636
example_warp_specialize_gemm_copy_1_gemm_0 0.0276933 0.0277031 0.999646
example_gqa_bwd 0.0463806 0.0463929 0.999734
example_dequant_gemm_fp4_hopper 1.05179 1.05202 0.999774
example_fusedmoe_tilelang 0.132862 0.132888 0.9998
tilelang_example_sparse_tensorcore 0.0146071 0.0146098 0.999812
example_mha_bwd_bhsd 0.0392006 0.0392077 0.999819
example_gqa_fwd_bshd 0.0701593 0.0701702 0.999845
topk_selector 0.0538808 0.0538885 0.999857
example_gqa_bwd_tma_reduce_varlen 0.0474702 0.0474757 0.999885
example_elementwise_add 0.115482 0.115492 0.999909
example_topk 0.0111112 0.0111118 0.999944
example_convolution 1.29646 1.29651 0.999958
example_mha_sink_bwd_bhsd 0.0634405 0.0634429 0.999962
example_gemm_intrinsics 0.0348138 0.034814 0.999995
example_gemm_autotune 0.0224624 0.0224623 1
example_mha_sink_bwd_bhsd_sliding_window 0.0430716 0.043071 1.00001
example_mha_fwd_bshd 0.0257638 0.0257634 1.00002
example_dynamic 0.642802 0.642784 1.00003
example_dequant_gemm_w4a8 5.5803 5.58002 1.00005
example_gqa_sink_bwd_bhsd 0.0422227 0.0422201 1.00006
example_mha_bwd_bshd 0.0398736 0.0398711 1.00006
example_mla_decode 0.462308 0.462262 1.0001
example_gemv 0.288256 0.288218 1.00013
block_sparse_attn_tilelang 0.00884862 0.00884728 1.00015
example_linear_attn_bwd 0.15154 0.151516 1.00016
sparse_mla_bwd 0.301265 0.301213 1.00017
example_tilelang_gemm_fp8_intrinsic 0.841968 0.841818 1.00018
example_gemm 0.0224 0.0223957 1.00019
example_dequant_gemv_fp16xint4 0.0283406 0.0283344 1.00022
fp8_lighting_indexer 0.0359478 0.0359381 1.00027
example_dequant_gemm_bf16_mxfp4_hopper 0.509684 0.50954 1.00028
example_group_per_split_token_cast_to_fp8 0.0103478 0.0103435 1.00041
example_per_token_cast_to_fp8 0.00737191 0.00736876 1.00043
example_mha_sink_fwd_bhsd_sliding_window 0.0149942 0.014986 1.00055
example_convolution_autotune 0.983369 0.982825 1.00055
example_warp_specialize_gemm_copy_0_gemm_1 0.0397308 0.0397088 1.00056
example_vertical_slash_sparse_attn 0.229402 0.229274 1.00056
example_tilelang_gemm_splitk_vectorize_atomicadd 1.03291 1.03228 1.00061
example_mhc_post 0.109754 0.109683 1.00065
example_tilelang_gemm_fp8_2xAcc 0.186833 0.186709 1.00067
example_dequant_gemm_bf16_fp4_hopper 0.555849 0.555389 1.00083
example_tilelang_block_sparse_attn 0.00871885 0.00870767 1.00128

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@LeiWang1999 LeiWang1999 merged commit 90299d6 into tile-ai:main Apr 11, 2026
6 of 7 checks passed
kurisu6912 pushed a commit that referenced this pull request Apr 13, 2026
…hreadSync (#2030)
