[Refactor] Refactor DecoupleTypeCast Pass#2026

Merged
LeiWang1999 merged 13 commits into tile-ai:main from LJC00118:qwq16
Apr 14, 2026
Conversation

Collaborator

@LJC00118 commented Apr 9, 2026

Summary by CodeRabbit

  • Bug Fixes

    • More reliable mixed-precision handling by detecting explicit casts and ensuring correct insertion/reuse of temporary cast buffers; now accounts for local bindings so casts hidden behind locals are handled.
  • Refactor

    • Improved vectorization decisions so accesses whose indices don't vary with the inner loop are treated as local, enabling more accurate vectorized code generation.
  • Tests

    • Added tests covering cast-related transforms with local bindings and scalar-load scenarios to verify when cast buffers are (not) introduced.
  • Documentation

    • Updated descriptions to reflect cast-driven mixed-precision detection.

github-actions bot commented Apr 9, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

Contributor

coderabbitai bot commented Apr 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

DecoupleTypeCast now detects mixed-precision by presence of Cast nodes, inlines LetStmt bindings for analysis, gathers memory accesses with a MemoryAccessCollector, and restructures vectorized loops into three phases: copy-from, compute (with local cast buffers), and copy-to. New tests validate LetStmt and scalar-load cases.

Changes

Cohort / File(s) Summary
Tests
testing/python/transform/test_tilelang_transform_decouple_type_cast.py
Add test_local_to_memory_with_let_stmt (IR-level) and test_e2e_scalar_load_no_cast_buffer (CUDA e2e) to validate LetStmt inlining, three-phase loop emission, and that scalar-loads don't force load-side cast buffers.
Core Transform Logic
tilelang/transform/decouple_type_cast.py
Switch mixed-precision detection to Cast-node presence; inline LetStmt bindings before analysis; replace previous store/load classification with MemoryAccessCollector (skips certain load traversals); build separate load/store cast-buffer maps with reuse for overlaps; emit copy-from, compute, then copy-to phases; update docstring and control flow for vectorized loops.
Vectorization Planner
src/transform/loop_vectorize.cc
Reclassify non-local/shared buffers whose indices do not depend on the innermost loop variable as local/fragment-like (updating local_fragment_min/local_fragment_buffers) instead of treating them as memory constraints, changing vectorization planning.

Sequence Diagram

sequenceDiagram
    participant Input as PrimFunc (with LetStmt)
    participant Inliner as LetStmt Inliner
    participant Detector as Cast Detector
    participant Collector as MemoryAccessCollector
    participant Transformer as DecoupleTypeCast
    participant Output as Transformed PrimFunc

    Input->>Inliner: Inline LetStmt bindings
    Inliner->>Detector: Scan inlined tree for Cast nodes
    alt Casts found
        Detector->>Collector: Collect shared/global BufferLoad/BufferStore (skip certain BufferLoad index traversal)
        Collector->>Transformer: Provide access patterns (loads/stores)
        Transformer->>Transformer: Build store/load cast-buffer maps and merge overlaps
        Transformer->>Transformer: Emit copy-from-memory loops (reads into cast buffers)
        Transformer->>Transformer: Emit compute loop using local cast buffers
        Transformer->>Transformer: Emit copy-to-memory loops (stores from cast buffers)
        Transformer->>Output: Return transformed PrimFunc
    else No Casts
        Detector->>Output: Return original PrimFunc
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐰 I nudged the Lets to show their face,
I sniffed for Casts in every place,
Copy-out, compute in tiny beds,
Copy-back to stores and sleepy heads,
A hopping patch of code—soft pace 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 60.71% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly identifies the main change: a refactoring of the DecoupleTypeCast pass. The summary confirms significant architectural changes to this pass, including mixed-precision detection logic and memory access analysis rework.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Collaborator Author

@LJC00118 commented Apr 9, 2026

@regression-perf

Contributor

coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
tilelang/transform/decouple_type_cast.py (2)

283-285: Consider using list unpacking for clarity.

Static analysis suggests using list unpacking which is more Pythonic and slightly more efficient.

♻️ Suggested change
-        all_stmts = copy_from_loops + [compute_loop] + copy_to_loops
+        all_stmts = [*copy_from_loops, compute_loop, *copy_to_loops]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` around lines 283 - 285, The
combination of copy_from_loops, compute_loop, and copy_to_loops into all_stmts
can be clearer and more Pythonic by using list unpacking; replace the explicit
concatenation assignment "all_stmts = copy_from_loops + [compute_loop] +
copy_to_loops" with an unpacked list like "all_stmts = [*copy_from_loops,
compute_loop, *copy_to_loops]" and keep the existing conditional construction of
result using SeqStmt and the single-element fallback; update the code
surrounding all_stmts, SeqStmt, result, copy_from_loops, compute_loop, and
copy_to_loops accordingly.
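The unpacked form suggested above is behaviorally identical to concatenation; a quick sanity check, with placeholder strings standing in for the TIR For nodes:

```python
copy_from_loops = ["load_a", "load_b"]  # stand-ins for copy-from For loops
compute_loop = "compute"
copy_to_loops = ["store_c"]             # stand-ins for copy-to For loops

concat = copy_from_loops + [compute_loop] + copy_to_loops
unpacked = [*copy_from_loops, compute_loop, *copy_to_loops]
# Both build the same flat list; unpacking avoids the intermediate
# single-element list and the temporary from the first concatenation.
assert concat == unpacked == ["load_a", "load_b", "compute", "store_c"]
```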

141-161: Consider documenting the scope limitation.

The function only handles LetStmt, IfThenElse, and SeqStmt. Nested LetStmts within other statement types (e.g., nested For loops) won't be inlined. While vectorized loop bodies typically don't contain nested For loops, adding a brief note about this limitation would improve maintainability.

📝 Suggested documentation addition
 def inline_let_stmts(stmt: Stmt) -> Stmt:
     """Inline all LetStmt bindings in *stmt* so that downstream visitors can
     see the original BufferLoad nodes that were hidden behind Var references.
 
     Used before collecting memory accesses so that BufferLoads inside LetStmt
     values are visible to ``MemoryAccessCollector``.
+
+    Note: Only traverses LetStmt, IfThenElse, and SeqStmt. LetStmts nested
+    within other statement types (e.g., For loops) are not inlined.
     """
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` around lines 141 - 161, Update the
inline_let_stmts docstring to explicitly state its scope limitation: it only
recurses into LetStmt, IfThenElse, and SeqStmt and will not inline LetStmt
bindings nested inside other statement types (e.g., For loops or other custom
Stmt subclasses); mention the assumption about vectorized loop bodies not
containing nested For loops and advise that additional handlers (e.g., adding
traversal logic for For or other Stmt subclasses) are required if those cases
must be supported.
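For readers unfamiliar with let-inlining, the idea can be sketched on a toy expression AST. All classes below are invented stand-ins for illustration; the real inline_let_stmts walks tvm.tir nodes:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Load:
    buffer: str
    index: int

@dataclass(frozen=True)
class Mul:
    lhs: object
    rhs: object

@dataclass(frozen=True)
class Let:
    var: Var
    value: object
    body: object

def substitute(expr, var, value):
    """Replace occurrences of `var` in `expr` with `value`."""
    if expr == var:
        return value
    if isinstance(expr, Mul):
        return Mul(substitute(expr.lhs, var, value),
                   substitute(expr.rhs, var, value))
    return expr

def inline_lets(expr):
    """Inline Let bindings so loads hidden behind Vars become visible."""
    if isinstance(expr, Let):
        return inline_lets(
            substitute(expr.body, expr.var, inline_lets(expr.value)))
    return expr

# let factor = scale[0] in a[0] * factor  -->  a[0] * scale[0]
e = Let(Var("factor"), Load("scale", 0), Mul(Load("a", 0), Var("factor")))
```

After inlining, a collector walking the result sees the `scale` load directly instead of an opaque `factor` variable, which is exactly what MemoryAccessCollector needs.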
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@testing/python/transform/test_tilelang_transform_decouple_type_cast.py`:
- Around line 123-145: The test uses the TVM IR node class name T.Cast instead
of the callable helper T.cast in test_local_to_memory_with_let_stmt: replace the
incorrect call to T.Cast("float8_e4m3fn", a_frag[i] * factor) with a call to
T.cast passing the expression then the dtype object, i.e. use T.cast(a_frag[i] *
factor, T.float8_e4m3fn) when assigning into b_local_cast so the test matches
the transform's expected API (symbols: T.Cast, T.cast,
test_local_to_memory_with_let_stmt, b_local_cast, a_frag, factor).


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5e2b7bbd-66b5-4022-8f53-1e84f54224ed

📥 Commits

Reviewing files that changed from the base of the PR and between 86e37b7 and bea77fd.

📒 Files selected for processing (2)
  • testing/python/transform/test_tilelang_transform_decouple_type_cast.py
  • tilelang/transform/decouple_type_cast.py


github-actions bot commented Apr 9, 2026

Performance Regression Test Report

Triggered by: @LJC00118
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/24190323319

Results

File Original Latency Current Latency Speedup
example_gemm 0.0223872 0.0224355 0.997847
example_warp_specialize_gemm_barrierpipe_stage2 0.0406728 0.0407266 0.998679
example_dequant_gemv_fp16xint4 0.0283211 0.028353 0.998876
example_elementwise_add 0.115385 0.115499 0.999008
example_topk 0.011107 0.0111151 0.999269
example_gemm_autotune 0.0225392 0.0225503 0.999504
block_sparse_attn_tilelang 0.00884803 0.0088522 0.999529
example_mhc_post 0.109824 0.109854 0.999726
example_mha_fwd_bshd 0.02582 0.0258269 0.999734
tilelang_example_sparse_tensorcore 0.0146395 0.0146417 0.99985
example_warp_specialize_gemm_softpipe_stage2 0.0277123 0.0277155 0.999883
example_tilelang_gemm_fp8 0.311589 0.311614 0.999918
example_fusedmoe_tilelang 0.133139 0.133147 0.999937
example_mha_fwd_bhsd 0.0106253 0.0106258 0.999957
example_tilelang_gemm_fp8_intrinsic 0.842096 0.842123 0.999968
example_gqa_bwd_tma_reduce_varlen 0.0475822 0.0475823 0.999997
example_dequant_gemm_w4a8 5.58031 5.58031 0.999999
example_mha_bwd_bhsd 0.0392995 0.0392977 1.00004
example_dynamic 0.643569 0.643538 1.00005
example_gemm_intrinsics 0.0348708 0.0348663 1.00013
example_linear_attn_bwd 0.151772 0.151751 1.00014
example_warp_specialize_gemm_copy_1_gemm_0 0.0277128 0.0277087 1.00015
example_vertical_slash_sparse_attn 0.229633 0.229588 1.0002
example_gqa_fwd_bshd 0.0703171 0.070303 1.0002
example_mha_bwd_bshd 0.0399724 0.0399639 1.00021
example_linear_attn_fwd 0.0363282 0.0363201 1.00022
example_warp_specialize_gemm_copy_0_gemm_1 0.039934 0.0399237 1.00026
example_gqa_bwd 0.0464754 0.0464632 1.00026
example_tilelang_gemm_splitk_vectorize_atomicadd 1.03428 1.03398 1.00029
example_gemv 0.288338 0.288221 1.00041
example_tilelang_gemm_splitk 1.02634 1.02588 1.00044
example_mha_fwd_varlen 0.0462662 0.0462417 1.00053
example_tilelang_gemm_fp8_2xAcc 0.186878 0.186768 1.00059
example_dequant_gemm_fp4_hopper 1.0568 1.05593 1.00083
example_per_token_cast_to_fp8 0.0073906 0.00738083 1.00132
example_convolution_autotune 0.985233 0.983315 1.00195
example_tilelang_nsa_fwd 0.00706535 0.00704417 1.00301
example_group_per_split_token_cast_to_fp8 0.0104306 0.0103959 1.00334
example_mha_sink_fwd_bhsd_sliding_window 0.0150972 0.015026 1.00474
example_tilelang_block_sparse_attn 0.00877764 0.00872921 1.00555
example_tilelang_nsa_decode 0.00687591 0.00683728 1.00565
sparse_mla_fwd_pipelined 0.0953576 0.0947736 1.00616
example_mha_sink_bwd_bhsd_sliding_window 0.0434943 0.0432094 1.00659
example_mha_sink_fwd_bhsd 0.015076 0.0149752 1.00673
topk_selector 0.0543099 0.0539067 1.00748
example_tilelang_sparse_gqa_decode_varlen_indice 0.0161391 0.0160161 1.00768
sparse_mla_fwd 0.128942 0.127953 1.00773
fp8_lighting_indexer 0.0363481 0.0360673 1.00779
example_blocksparse_gemm 0.0202377 0.0200784 1.00793
example_mha_sink_bwd_bhsd 0.0641407 0.0636011 1.00849
example_tilelang_sparse_gqa_decode_varlen_mask 0.0177637 0.0176084 1.00882
sparse_mla_bwd 0.304641 0.301496 1.01043
example_dequant_gemm_bf16_fp4_hopper 0.562198 0.556278 1.01064
example_convolution 1.31267 1.29846 1.01094
example_mhc_pre 0.154035 0.152319 1.01127
example_gqa_sink_bwd_bhsd_sliding_window 0.0258329 0.0255442 1.0113
example_gqa_sink_bwd_bhsd 0.0427742 0.0422263 1.01298
example_dequant_gemm_bf16_mxfp4_hopper 0.516522 0.509887 1.01301
example_mla_decode 0.468578 0.462225 1.01375

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

Contributor

coderabbitai bot left a comment

🧹 Nitpick comments (4)
tilelang/transform/decouple_type_cast.py (4)

170-172: Consider avoiding unnecessary SeqStmt recreation.

The function creates a new SeqStmt even when no LetStmt inlining occurred in any of the child statements. This creates unnecessary object allocations.

♻️ Suggested improvement
     elif isinstance(stmt, SeqStmt):
         new_seq = [inline_let_stmts(s) for s in stmt.seq]
+        if all(n is o for n, o in zip(new_seq, stmt.seq)):
+            return stmt
         return SeqStmt(new_seq)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` around lines 170 - 172,
inline_let_stmts unnecessarily reconstructs a SeqStmt even when none of the
children change; modify inline_let_stmts so that when handling a SeqStmt it maps
children (stmt.seq), compares each new child to the original (or checks if any
child identity/value changed), and only returns a new SeqStmt(new_seq) if at
least one child was modified—otherwise return the original stmt to avoid extra
allocations (reference symbols: inline_let_stmts, SeqStmt, stmt.seq).

277-278: Avoid accessing private attribute _seen_load_buffers.

Accessing the private _seen_load_buffers attribute from outside the class breaks encapsulation. Consider exposing a public method or property.

♻️ Suggested improvement

Add a public method to MemoryAccessCollector:

 class MemoryAccessCollector(PyStmtExprVisitor):
     ...
+    def is_load_buffer(self, buf: Buffer) -> bool:
+        """Check if buffer was seen as a load source."""
+        return buf in self._seen_load_buffers

Then update the usage:

-        rmw_buffers = {buf: store_cast_buffers[buf] for buf in store_cast_buffers if buf in collector._seen_load_buffers}
+        rmw_buffers = {buf: store_cast_buffers[buf] for buf in store_cast_buffers if collector.is_load_buffer(buf)}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` around lines 277 - 278, Replace the
direct access to the private attribute collector._seen_load_buffers with a
public accessor on MemoryAccessCollector (e.g., a property or method named
seen_load_buffers or get_seen_load_buffers), update MemoryAccessCollector to
expose that public API, and then change the comprehension that builds
rmw_buffers to use collector.seen_load_buffers() (or the property) instead of
_seen_load_buffers so rmw_buffers, load_cast_buffers and the construction of
all_copy_from remain correct while preserving encapsulation.

295-296: Prefer list unpacking over concatenation.

Using list unpacking is more idiomatic and slightly more efficient than concatenation.

♻️ Suggested improvement
-        all_stmts = copy_from_loops + [compute_loop] + copy_to_loops
+        all_stmts = [*copy_from_loops, compute_loop, *copy_to_loops]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` around lines 295 - 296, The code
builds all_stmts by concatenating lists; replace the concatenation with list
unpacking to be more idiomatic and slightly faster: construct all_stmts using
unpacking of copy_from_loops and copy_to_loops with compute_loop in the middle
(referencing variables copy_from_loops, compute_loop, copy_to_loops), then keep
the existing result creation using SeqStmt(all_stmts) if length > 1 else
all_stmts[0].

154-174: Document that inline_let_stmts only handles top-level statements.

The function doesn't recursively descend into For, While, or other loop constructs. If a LetStmt is nested inside such constructs, it won't be inlined. This appears intentional for the use case (vectorized loop bodies typically have top-level LetStmts), but a brief comment would clarify this design choice.

📝 Suggested documentation
 def inline_let_stmts(stmt: Stmt) -> Stmt:
     """Inline all LetStmt bindings in *stmt* so that downstream visitors can
     see the original BufferLoad nodes that were hidden behind Var references.
 
     Used before collecting memory accesses so that BufferLoads inside LetStmt
     values are visible to ``MemoryAccessCollector``.
+
+    Note: Only handles LetStmt, IfThenElse, and SeqStmt at the top level.
+    LetStmts nested inside For/While loops are not inlined.
     """
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` around lines 154 - 174,
inline_let_stmts currently only inlines LetStmt bindings at the top-level and
does not descend into loop constructs like For or While (so nested LetStmts
inside loop bodies are not handled); update the function docstring or add a
short comment above inline_let_stmts clarifying that it intentionally only
handles top-level statements (LetStmt, IfThenElse, SeqStmt) and does not recurse
into loop constructs such as For or While, and mention the rationale (e.g.,
vectorized loop bodies have top-level LetStmts) so readers know this is by
design rather than an omission.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 665983a9-4bab-4416-b021-848d0b643a6f

📥 Commits

Reviewing files that changed from the base of the PR and between bea77fd and b603136.

📒 Files selected for processing (2)
  • testing/python/transform/test_tilelang_transform_decouple_type_cast.py
  • tilelang/transform/decouple_type_cast.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • testing/python/transform/test_tilelang_transform_decouple_type_cast.py

Contributor

coderabbitai bot left a comment

🧹 Nitpick comments (2)
tilelang/transform/decouple_type_cast.py (2)

138-156: Consider exposing _seen_load_buffers via a property.

The _seen_load_buffers attribute is accessed directly at line 297 (collector._seen_load_buffers). Since this is a "private" attribute (underscore prefix), consider exposing it via a read-only property for cleaner encapsulation:

♻️ Suggested property accessor
     def __init__(self, loop_var: Var):
         super().__init__()
         self.loop_var = loop_var
         self.stores: list[BufferStore] = []
         self.loads: list[BufferLoad] = []
         self._seen_load_buffers: set[Buffer] = set()
+
+    @property
+    def seen_load_buffers(self) -> set[Buffer]:
+        """Buffers that have been collected as loads."""
+        return self._seen_load_buffers

Then at line 297:

-        rmw_buffers = {buf: store_cast_buffers[buf] for buf in store_cast_buffers if buf in collector._seen_load_buffers}
+        rmw_buffers = {buf: store_cast_buffers[buf] for buf in store_cast_buffers if buf in collector.seen_load_buffers}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` around lines 138 - 156, The code is
directly accessing the private attribute _seen_load_buffers (e.g.,
collector._seen_load_buffers) which breaks encapsulation; add a read-only
property (e.g., seen_load_buffers) on the class that exposes a frozenset or
tuple view of self._seen_load_buffers and update call sites to use
collector.seen_load_buffers; locate the attribute and related methods
visit_buffer_load_ and visit_buffer_store_ to add the property and ensure the
internal set remains private while external code only reads the immutable view.

314-316: Minor: Consider using unpacking for list construction.

Per the static analysis hint (RUF005), list unpacking can be slightly more efficient than concatenation:

♻️ Suggested change
-        all_stmts = copy_from_loops + [compute_loop] + copy_to_loops
+        all_stmts = [*copy_from_loops, compute_loop, *copy_to_loops]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` around lines 314 - 316, The list
concatenation building all_stmts should use list unpacking for minor efficiency
and clarity: replace the concatenation expression that combines copy_from_loops,
compute_loop, and copy_to_loops with an unpacking form (e.g., [*copy_from_loops,
compute_loop, *copy_to_loops]) and keep the existing conditional that constructs
result = SeqStmt(all_stmts) if len(all_stmts) > 1 else all_stmts[0]; reference
names: copy_from_loops, compute_loop, copy_to_loops, all_stmts, SeqStmt.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 91b4fbc2-140a-4215-a643-bf30da752eac

📥 Commits

Reviewing files that changed from the base of the PR and between b603136 and 03bb070.

📒 Files selected for processing (2)
  • src/transform/loop_vectorize.cc
  • tilelang/transform/decouple_type_cast.py

Contributor

coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
testing/python/transform/test_tilelang_transform_decouple_type_cast.py (1)

123-145: Assert that this case still lowers through a LetStmt.

Right now _check only proves the transformed shape. If factor = scale[i] stops lowering to a LetStmt, this test still passes without exercising inline_let_stmts. Please pin that precondition with a direct structural assertion on the untransformed TIR.

Based on learnings, tests in testing/python/transform should validate structural patterns in the generated IR/source rather than only inferred behavior.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/transform/test_tilelang_transform_decouple_type_cast.py`
around lines 123 - 145, The test test_local_to_memory_with_let_stmt only
verifies shapes via _check and can pass even if the load "factor = scale[i]" no
longer lowers to a LetStmt; add an explicit structural assertion on the original
TIR to ensure the load is lowered into a LetStmt before transformation. Locate
the untransformed TIR produced in test_local_to_memory_with_let_stmt (use the
before prim_func or the pre-_check representation) and assert that a LetStmt
node exists and binds the temporary for "factor = scale[i]" (or that the exact
pattern "factor" bound from scale[i] is present), failing the test if
inline_let_stmts has removed that LetStmt; keep the assertion adjacent to where
_check is invoked so it verifies the precondition prior to running the
transform.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Around line 138-157: The collector currently deduplicates loads by Buffer only
(see visit_buffer_load_, self._seen_load_buffers) which merges distinct index
patterns (e.g., a[i] vs a[i+1]); change the logic to distinguish per-access
patterns by replacing self._seen_load_buffers: use a key built from the buffer
plus a normalized representation of op.indices (e.g., tuple of index ASTs or
hashed expressions) and deduplicate on that key when appending to self.loads, or
alternatively detect when the same Buffer is seen with non-equivalent index
patterns and bail out (return/skip transformation) so LoadReplacer cannot
incorrectly map different accesses to the same cast buffer; update any uses of
_seen_load_buffers and ensure LoadReplacer is compatible with the new per-access
keys.
- Around line 160-167: The visitor visit_call_ currently skips the condition of
a tir.if_then_else (op.op.same_as(_IF_THEN_ELSE_OP)) but still collects loads
from its true/false arms, which lets those loads be hoisted unguarded; to fix,
change the decoupling logic so loads found under an expression-level IfThenElse
are either (A) annotated with the arm predicate and have that predicate carried
into the generated copy-from loop (so the copy loop uses the predicate when
emitting guarded loads), or (B) skipped entirely for decoupling when the
IfThenElse is expression-level; implement this by having visit_call_ (and the
same handling at the other site referenced around 198-206) propagate a predicate
flag/closure along with discovered loads (or mark them as non-decouplable) and
ensure the copy-from emitter (the function that currently reuses a predicate for
statement-level IfThenElse) consults that predicate flag before hoisting a load.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 402c3e48-971b-429e-a8fb-360f7619d967

📥 Commits

Reviewing files that changed from the base of the PR and between 03bb070 and 4066f33.

📒 Files selected for processing (2)
  • testing/python/transform/test_tilelang_transform_decouple_type_cast.py
  • tilelang/transform/decouple_type_cast.py

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
tilelang/transform/decouple_type_cast.py (2)

138-160: ⚠️ Potential issue | 🔴 Critical

Don't collapse distinct vector accesses from the same buffer into one cast buffer.

This collector deduplicates loads by Buffer only. A loop like out[i] = T.cast(a[i] + a[i + 1], ...) will cache just one access, and LoadReplacer later rewrites both loads to the same a_local_cast[i], which changes semantics. The cast-buffer mapping needs to distinguish access patterns, or the pass should bail out when one buffer is referenced with non-equivalent indices.
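A minimal sketch of the per-access keying fix, using toy strings in place of TVM's `Buffer` objects and index ASTs (the names `access_key` and `LoadCollector` are illustrative, not the real pass):

```python
# Deduplicate buffer loads by (buffer, normalized indices) instead of by
# buffer alone, so a[i] and a[i + 1] each get their own cast-buffer entry.

def access_key(buffer_name, indices):
    """Build a dedup key from the buffer plus a canonical form of its indices."""
    return (buffer_name, tuple(str(idx) for idx in indices))

class LoadCollector:
    def __init__(self):
        self._seen = set()   # replaces the buffer-only _seen_load_buffers
        self.loads = []

    def visit_load(self, buffer_name, indices):
        key = access_key(buffer_name, indices)
        if key in self._seen:
            return
        self._seen.add(key)
        self.loads.append((buffer_name, indices))

collector = LoadCollector()
collector.visit_load("a", ["i"])
collector.visit_load("a", ["i + 1"])  # distinct access pattern, kept
collector.visit_load("a", ["i"])      # true duplicate, dropped
print(len(collector.loads))  # 2
```

In the real pass the index normalization would need structural equality on TIR expressions (e.g. via `tvm.ir.structural_hash`) rather than string conversion.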

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `tilelang/transform/decouple_type_cast.py` around lines 138 - 160, The
collector currently deduplicates loads by Buffer only (see visit_buffer_load_,
self._seen_load_buffers) which merges distinct index patterns (e.g., a[i] vs
a[i+1]); change the logic to distinguish per-access patterns by replacing
self._seen_load_buffers: use a key built from the buffer plus a normalized
representation of op.indices (e.g., tuple of index ASTs or hashed expressions)
and deduplicate on that key when appending to self.loads, or alternatively
detect when the same Buffer is seen with non-equivalent index patterns and bail
out (return/skip transformation) so LoadReplacer cannot incorrectly map
different accesses to the same cast buffer; update any uses of
_seen_load_buffers and ensure LoadReplacer is compatible with the new per-access
keys.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` around lines 138 - 160, The
collector currently deduplicates loads by Buffer using self._seen_load_buffers
in visit_buffer_load_, which collapses distinct index patterns (e.g., a[i] vs
a[i+1]); change the deduplication to use a key that combines the Buffer and a
normalized representation of op.indices (e.g., tuple of index ASTs, their hashed
string, or another canonical form) so self.loads collects each unique access
pattern, and update/remove usage of self._seen_load_buffers accordingly;
alternatively, detect when the same Buffer is seen with non-equivalent index
patterns and bail out of the transform (skip/abort) to avoid LoadReplacer
mapping different accesses to the same cast buffer, and ensure LoadReplacer is
adjusted to expect per-access keys if you choose the first approach.

162-169: ⚠️ Potential issue | 🔴 Critical

tir.if_then_else loads can be hoisted past their guard.

visit_call_ skips the condition and collects loads from the true/false arms, but the copy-from loop only reuses a predicate when the loop body is a statement-level IfThenElse. For an expression like c[i] = T.if_then_else(i < n, T.cast(a[i], ...), 0), the preload of a[i] becomes unconditional, which can reintroduce out-of-bounds reads that the predicate originally prevented. Either carry the arm predicate into the generated copy loop or skip decoupling for loads discovered under expression-level if_then_else.
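A sketch of option (A), tagging loads discovered under an expression-level if_then_else with the arm predicate so the copy-from emitter can guard the preload (the `Load`, `collect_if_then_else`, and `emit_copy_stmt` names are illustrative, not the actual pass):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Load:
    buffer: str
    index: str
    predicate: Optional[str] = None  # None means an unconditional preload

def collect_if_then_else(cond, true_loads, false_loads):
    """Propagate the branch condition onto loads found in either arm."""
    guarded = []
    for ld in true_loads:
        guarded.append(Load(ld.buffer, ld.index, predicate=cond))
    for ld in false_loads:
        guarded.append(Load(ld.buffer, ld.index, predicate=f"not ({cond})"))
    return guarded

def emit_copy_stmt(ld: Load) -> str:
    """The copy-from emitter consults the predicate before hoisting a load."""
    body = f"{ld.buffer}_cast[{ld.index}] = cast({ld.buffer}[{ld.index}])"
    return f"if {ld.predicate}: {body}" if ld.predicate else body

# c[i] = T.if_then_else(i < n, T.cast(a[i], ...), 0): the preload stays guarded.
loads = collect_if_then_else("i < n", [Load("a", "i")], [])
print(emit_copy_stmt(loads[0]))  # if i < n: a_cast[i] = cast(a[i])
```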

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `tilelang/transform/decouple_type_cast.py` around lines 162 - 169, The visitor
visit_call_ currently skips the condition of a tir.if_then_else
(op.op.same_as(_IF_THEN_ELSE_OP)) but still collects loads from its true/false
arms, which lets those loads be hoisted unguarded; to fix, change the decoupling
logic so loads found under an expression-level IfThenElse are either (A)
annotated with the arm predicate and have that predicate carried into the
generated copy-from loop (so the copy loop uses the predicate when emitting
guarded loads), or (B) skipped entirely for decoupling when the IfThenElse is
expression-level; implement this by having visit_call_ (and the same handling at
the other site referenced around 198-206) propagate a predicate flag/closure
along with discovered loads (or mark them as non-decouplable) and ensure the
copy-from emitter (the function that currently reuses a predicate for
statement-level IfThenElse) consults that predicate flag before hoisting a load.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` around lines 162 - 169, The visitor
visit_call_ currently skips the if_then_else condition but still records loads
from the true/false arms, allowing those loads to be hoisted unguarded; modify
visit_call_ (and the similar handling near the other site referenced around
lines ~198-206) to either attach the arm predicate to loads discovered under an
expression-level tir.if_then_else or mark those loads as non-decouplable so they
are not hoisted. Concretely: when op.op.same_as(_IF_THEN_ELSE_OP) propagate a
predicate token/closure with the loads collected from op.args[1] and op.args[2]
(or set a flag on those load entries) and update the copy-from emitter that
currently reuses a predicate for statement-level IfThenElse to consult this
predicate/flag before emitting a guarded preload; alternatively, if simpler,
skip adding loads discovered under expression-level if_then_else to the
decoupling list so they are not hoisted.
🧹 Nitpick comments (1)
tilelang/transform/decouple_type_cast.py (1)

318-318: Prefer unpacking syntax for list concatenation.

Python 3.5+ allows cleaner syntax for concatenating lists.

♻️ Cleaner syntax
-        all_stmts = copy_from_loops + [compute_loop] + copy_to_loops
+        all_stmts = [*copy_from_loops, compute_loop, *copy_to_loops]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/decouple_type_cast.py` at line 318, The line building
all_stmts uses explicit list concatenation; change it to use list unpacking for
clarity and modernization: replace the current expression that combines
copy_from_loops, compute_loop, and copy_to_loops with a single list literal
using the unpacking operator (e.g., [*copy_from_loops, compute_loop,
*copy_to_loops]) so that the variable all_stmts is created with the unpacked
elements; update the assignment where all_stmts is defined (referencing the
names all_stmts, copy_from_loops, compute_loop, copy_to_loops) accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Around line 177-198: inline_let_stmts only handles LetStmt, IfThenElse and
SeqStmt so LetStmt bindings nested in other statements (e.g., For, While,
Allocate) are not inlined and BufferLoads can be hidden from
MemoryAccessCollector; replace the current recursive implementation in
inline_let_stmts with a two-phase approach that (1) walks the input Stmt and
collects all LetStmt bindings into a mapping of vars to values and (2) calls
tir.stmt_functor.substitute (or the existing substitute wrapper) once to apply
that mapping to the whole Stmt so all statement types are handled uniformly;
keep references to inline_let_stmts, LetStmt, substitute and
MemoryAccessCollector to locate and validate the change.
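The suggested two-phase approach can be sketched with a toy AST standing in for TVM's tir nodes (the real code would collect bindings with a visitor and apply them once via `tir.stmt_functor.substitute`):

```python
from dataclasses import dataclass

@dataclass
class Let:   # let var = value in body
    var: str
    value: str
    body: object

@dataclass
class For:   # any wrapping statement (For/While/Allocate in real TIR)
    body: object

@dataclass
class Use:   # a leaf referencing a variable
    name: str

def collect_bindings(stmt, binds):
    """Phase 1: walk the whole tree, recording every Let binding."""
    if isinstance(stmt, Let):
        binds[stmt.var] = stmt.value
        collect_bindings(stmt.body, binds)
    elif isinstance(stmt, For):
        collect_bindings(stmt.body, binds)

def substitute(stmt, binds):
    """Phase 2: apply all bindings in one uniform pass."""
    if isinstance(stmt, Let):
        return substitute(stmt.body, binds)  # drop the Let, inline its body
    if isinstance(stmt, For):
        return For(substitute(stmt.body, binds))
    if isinstance(stmt, Use):
        return Use(binds.get(stmt.name, stmt.name))
    return stmt

# A binding nested under a loop is still found and inlined:
tree = For(Let("factor", "scale[i]", Use("factor")))
binds = {}
collect_bindings(tree, binds)
print(substitute(tree, binds))  # For(body=Use(name='scale[i]'))
```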

---

Duplicate comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Around line 138-160: The collector currently deduplicates loads by Buffer
using self._seen_load_buffers in visit_buffer_load_, which collapses distinct
index patterns (e.g., a[i] vs a[i+1]); change the deduplication to use a key
that combines the Buffer and a normalized representation of op.indices (e.g.,
tuple of index ASTs, their hashed string, or another canonical form) so
self.loads collects each unique access pattern, and update/remove usage of
self._seen_load_buffers accordingly; alternatively, detect when the same Buffer
is seen with non-equivalent index patterns and bail out of the transform
(skip/abort) to avoid LoadReplacer mapping different accesses to the same cast
buffer, and ensure LoadReplacer is adjusted to expect per-access keys if you
choose the first approach.
- Around line 162-169: The visitor visit_call_ currently skips the if_then_else
condition but still records loads from the true/false arms, allowing those loads
to be hoisted unguarded; modify visit_call_ (and the similar handling near the
other site referenced around lines ~198-206) to either attach the arm predicate
to loads discovered under an expression-level tir.if_then_else or mark those
loads as non-decouplable so they are not hoisted. Concretely: when
op.op.same_as(_IF_THEN_ELSE_OP) propagate a predicate token/closure with the
loads collected from op.args[1] and op.args[2] (or set a flag on those load
entries) and update the copy-from emitter that currently reuses a predicate for
statement-level IfThenElse to consult this predicate/flag before emitting a
guarded preload; alternatively, if simpler, skip adding loads discovered under
expression-level if_then_else to the decoupling list so they are not hoisted.

---

Nitpick comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Line 318: The line building all_stmts uses explicit list concatenation; change
it to use list unpacking for clarity and modernization: replace the current
expression that combines copy_from_loops, compute_loop, and copy_to_loops with a
single list literal using the unpacking operator (e.g., [*copy_from_loops,
compute_loop, *copy_to_loops]) so that the variable all_stmts is created with
the unpacked elements; update the assignment where all_stmts is defined
(referencing the names all_stmts, copy_from_loops, compute_loop, copy_to_loops)
accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 89dd194e-d6e1-478d-b3ca-79d0b56497b1

📥 Commits

Reviewing files that changed from the base of the PR and between 4066f33 and 5354e7a.

📒 Files selected for processing (1)
  • tilelang/transform/decouple_type_cast.py

@LJC00118
Collaborator Author

@regression-perf

@github-actions

Performance Regression Test Report

Triggered by: @LJC00118
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/24233699731

Results

File Original Latency Current Latency Speedup
example_dequant_gemm_w4a8 5.58051 8.75918 0.637104
example_dequant_gemm_bf16_mxfp4_hopper 0.509813 0.749487 0.680217
example_dequant_gemm_fp4_hopper 1.05273 1.3737 0.766348
example_vertical_slash_sparse_attn 0.229633 0.236028 0.972906
block_sparse_attn_tilelang 0.00885019 0.00904803 0.978133
example_mha_fwd_varlen 0.0462488 0.0466326 0.991769
example_mha_bwd_bshd 0.0399686 0.0401342 0.995873
example_tilelang_gemm_splitk 1.02781 1.03158 0.996343
example_linear_attn_fwd 0.0363503 0.0364509 0.997241
example_elementwise_add 0.1155 0.115724 0.998061
example_dynamic 0.643369 0.644418 0.998372
example_gqa_bwd_tma_reduce_varlen 0.0475751 0.0476412 0.998613
sparse_mla_fwd 0.127817 0.127994 0.998619
example_mha_bwd_bhsd 0.0392992 0.0393475 0.99877
example_tilelang_sparse_gqa_decode_varlen_indice 0.0159925 0.0160105 0.998881
example_linear_attn_bwd 0.151704 0.151868 0.998924
example_mha_sink_bwd_bhsd_sliding_window 0.0432011 0.0432335 0.999252
example_blocksparse_gemm 0.0200416 0.0200566 0.999252
example_warp_specialize_gemm_copy_1_gemm_0 0.0277053 0.0277175 0.99956
example_gqa_sink_bwd_bhsd 0.0422252 0.0422411 0.999624
example_convolution_autotune 0.983151 0.983411 0.999736
sparse_mla_fwd_pipelined 0.0947628 0.0947877 0.999738
topk_selector 0.0538818 0.0538939 0.999776
example_tilelang_block_sparse_attn 0.00873117 0.00873311 0.999778
example_mla_decode 0.462243 0.462323 0.999826
example_fusedmoe_tilelang 0.133108 0.133128 0.99985
example_gqa_sink_bwd_bhsd_sliding_window 0.025538 0.0255414 0.999868
example_warp_specialize_gemm_copy_0_gemm_1 0.0399276 0.0399309 0.999917
example_tilelang_gemm_fp8_intrinsic 0.842151 0.842172 0.999976
example_convolution 1.29853 1.29852 1
example_dequant_gemm_bf16_fp4_hopper 0.557206 0.557173 1.00006
sparse_mla_bwd 0.301464 0.30142 1.00015
example_tilelang_nsa_fwd 0.0070367 0.00703469 1.00029
example_gemm_intrinsics 0.0348691 0.0348588 1.0003
example_per_token_cast_to_fp8 0.00737298 0.00737047 1.00034
example_mha_sink_bwd_bhsd 0.0636195 0.0635959 1.00037
example_tilelang_sparse_gqa_decode_varlen_mask 0.0176282 0.0176204 1.00044
example_gqa_bwd 0.0464311 0.0464081 1.0005
example_topk 0.0111067 0.011101 1.00051
example_tilelang_gemm_splitk_vectorize_atomicadd 1.03324 1.03271 1.00051
example_warp_specialize_gemm_softpipe_stage2 0.027726 0.0277099 1.00058
example_gqa_fwd_bshd 0.0703162 0.0702645 1.00074
fp8_lighting_indexer 0.0361365 0.03611 1.00074
example_tilelang_nsa_decode 0.00684234 0.00683729 1.00074
example_gemv 0.288215 0.287985 1.0008
example_mha_sink_fwd_bhsd_sliding_window 0.015022 0.0150076 1.00096
example_gemm 0.0224109 0.0223862 1.0011
example_mha_sink_fwd_bhsd 0.014997 0.0149801 1.00113
example_mha_fwd_bhsd 0.0106278 0.0106107 1.00161
example_tilelang_gemm_fp8_2xAcc 0.186618 0.186178 1.00236
tilelang_example_sparse_tensorcore 0.0146448 0.0145924 1.00359
example_gemm_autotune 0.0225466 0.0224642 1.00367
example_tilelang_gemm_fp8 0.31174 0.310435 1.0042
example_dequant_gemv_fp16xint4 0.0283624 0.0282319 1.00462
example_mha_fwd_bshd 0.0258245 0.0256788 1.00567
example_warp_specialize_gemm_barrierpipe_stage2 0.0406722 0.0403041 1.00913
example_group_per_split_token_cast_to_fp8 0.0103899 0.0102767 1.01102

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@LJC00118
Collaborator Author

@regression-perf

@github-actions

Performance Regression Test Report

Triggered by: @LJC00118
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/24237114160

Results

File Original Latency Current Latency Speedup
example_tilelang_gemm_fp8_2xAcc 0.192267 0.328011 0.586161
example_gqa_fwd_bshd 0.0742316 0.12473 0.595136
example_mha_sink_bwd_bhsd_sliding_window 0.0456689 0.0550623 0.829404
example_dequant_gemm_fp4_hopper 1.05574 1.08731 0.970963
example_dequant_gemm_w4a8 5.57833 5.71622 0.975878
example_tilelang_gemm_splitk_vectorize_atomicadd 1.03727 1.06003 0.97853
example_tilelang_nsa_fwd 0.00714785 0.00728811 0.980754
example_tilelang_sparse_gqa_decode_varlen_mask 0.0182742 0.0183904 0.993679
example_tilelang_gemm_fp8_intrinsic 0.865888 0.871203 0.993899
example_mha_bwd_bshd 0.0420638 0.042305 0.9943
example_tilelang_block_sparse_attn 0.00900031 0.00904948 0.994566
example_linear_attn_bwd 0.154258 0.155085 0.994663
example_group_per_split_token_cast_to_fp8 0.0109157 0.0109521 0.99667
example_mha_bwd_bhsd 0.0414708 0.0415742 0.997512
example_tilelang_sparse_gqa_decode_varlen_indice 0.0165453 0.0165832 0.997716
example_dequant_gemm_bf16_fp4_hopper 0.586319 0.587188 0.99852
tilelang_example_sparse_tensorcore 0.0151606 0.0151813 0.998635
example_mha_sink_bwd_bhsd 0.0669547 0.0670234 0.998975
example_gqa_sink_bwd_bhsd 0.0445072 0.0445507 0.999023
example_topk 0.0115686 0.0115784 0.999155
fp8_lighting_indexer 0.0375168 0.0375474 0.999184
topk_selector 0.0557978 0.0558351 0.999333
example_gqa_bwd_tma_reduce_varlen 0.0499703 0.0499944 0.999518
example_mha_fwd_varlen 0.0485572 0.0485796 0.999538
example_gemv 0.302678 0.302812 0.999557
block_sparse_attn_tilelang 0.00907471 0.00907846 0.999587
example_mhc_post 0.10989 0.109935 0.99959
sparse_mla_fwd 0.132929 0.132982 0.999606
example_elementwise_add 0.115438 0.115475 0.999676
example_gemm_intrinsics 0.0368484 0.0368545 0.999834
example_tilelang_gemm_splitk 1.06029 1.06039 0.999903
example_mha_fwd_bshd 0.0268755 0.0268768 0.99995
example_dynamic 0.664071 0.664092 0.999969
example_mla_decode 0.488482 0.488492 0.99998
example_vertical_slash_sparse_attn 0.24314 0.243142 0.999991
example_dequant_gemv_fp16xint4 0.0285487 0.0285479 1.00003
example_convolution 1.37439 1.37432 1.00005
example_warp_specialize_gemm_softpipe_stage2 0.0290235 0.0290208 1.00009
example_linear_attn_fwd 0.0374377 0.0374302 1.0002
example_warp_specialize_gemm_copy_1_gemm_0 0.0290265 0.0290195 1.00024
example_fusedmoe_tilelang 0.138636 0.138579 1.00041
example_gemm_autotune 0.0238728 0.0238616 1.00047
example_warp_specialize_gemm_copy_0_gemm_1 0.0404437 0.040418 1.00064
example_dequant_gemm_bf16_mxfp4_hopper 0.529987 0.529613 1.00071
example_gqa_bwd 0.0487977 0.0487592 1.00079
sparse_mla_fwd_pipelined 0.0992168 0.0991359 1.00082
example_blocksparse_gemm 0.0210791 0.0210597 1.00092
example_gqa_sink_bwd_bhsd_sliding_window 0.0269053 0.0268631 1.00157
example_convolution_autotune 0.986944 0.985379 1.00159
example_gemm 0.0232243 0.023184 1.00174
example_mha_fwd_bhsd 0.0110623 0.0110397 1.00204
example_mhc_pre 0.15901 0.158543 1.00295
example_warp_specialize_gemm_barrierpipe_stage2 0.0411418 0.0410204 1.00296
sparse_mla_bwd 0.311944 0.311019 1.00297
example_tilelang_gemm_fp8 0.323199 0.32194 1.00391
example_per_token_cast_to_fp8 0.00745315 0.00738938 1.00863
example_mha_sink_fwd_bhsd 0.0157905 0.0155784 1.01362
example_mha_sink_fwd_bhsd_sliding_window 0.0158381 0.0156068 1.01482
example_tilelang_nsa_decode 0.00695863 0.00682358 1.01979

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@LJC00118
Collaborator Author

@regression-perf

1 similar comment
@LJC00118
Collaborator Author

@regression-perf

@github-actions

Performance Regression Test Report

Triggered by: @LJC00118
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/24261638600

Results

File Original Latency Current Latency Speedup
example_blocksparse_gemm 0.0199908 0.0200588 0.996607
example_warp_specialize_gemm_barrierpipe_stage2 0.0404544 0.0405259 0.998236
example_mha_sink_fwd_bhsd_sliding_window 0.0149773 0.0150028 0.998305
example_tilelang_gemm_splitk_vectorize_atomicadd 1.03269 1.03392 0.99881
example_per_token_cast_to_fp8 0.00736314 0.00737028 0.99903
example_elementwise_add 0.115582 0.115684 0.999116
example_warp_specialize_gemm_copy_0_gemm_1 0.0396905 0.0397245 0.999146
example_tilelang_nsa_decode 0.00683176 0.00683625 0.999343
example_group_per_split_token_cast_to_fp8 0.0103484 0.0103551 0.999353
example_tilelang_gemm_fp8_2xAcc 0.186388 0.186508 0.999357
example_tilelang_block_sparse_attn 0.00871908 0.00872465 0.999361
example_mha_fwd_varlen 0.0460861 0.0461142 0.99939
topk_selector 0.0538552 0.0538835 0.999475
example_mha_sink_bwd_bhsd_sliding_window 0.043042 0.0430623 0.999528
example_convolution_autotune 0.982961 0.983327 0.999628
example_warp_specialize_gemm_copy_1_gemm_0 0.0276941 0.0277037 0.999655
example_tilelang_gemm_splitk 1.02487 1.02522 0.99966
example_dequant_gemm_bf16_fp4_hopper 0.555394 0.55558 0.999665
example_mha_fwd_bhsd 0.0105986 0.0106014 0.999732
example_gqa_bwd_tma_reduce_varlen 0.0474613 0.0474726 0.999762
example_topk 0.0111029 0.0111055 0.999768
example_gqa_sink_bwd_bhsd 0.042224 0.0422335 0.999777
example_mhc_pre 0.152273 0.1523 0.999824
example_vertical_slash_sparse_attn 0.229324 0.22936 0.999844
example_fusedmoe_tilelang 0.132885 0.132903 0.999866
example_gqa_fwd_bshd 0.0701653 0.0701731 0.999889
example_linear_attn_fwd 0.0362935 0.0362964 0.999921
example_tilelang_gemm_fp8_intrinsic 0.841953 0.842016 0.999926
tilelang_example_sparse_tensorcore 0.0146174 0.0146181 0.999951
fp8_lighting_indexer 0.0359317 0.035933 0.999963
block_sparse_attn_tilelang 0.00884278 0.00884262 1.00002
example_dequant_gemm_bf16_mxfp4_hopper 0.509325 0.509309 1.00003
example_warp_specialize_gemm_softpipe_stage2 0.0277006 0.0276994 1.00004
example_dequant_gemm_w4a8 5.58059 5.58028 1.00006
example_mhc_post 0.10983 0.109824 1.00006
example_linear_attn_bwd 0.151508 0.151498 1.00006
sparse_mla_bwd 0.301222 0.3012 1.00007
example_convolution 1.29666 1.29654 1.00009
example_gemm_intrinsics 0.0348242 0.03482 1.00012
example_dynamic 0.642859 0.642777 1.00013
example_mla_decode 0.462275 0.462195 1.00017
example_mha_sink_bwd_bhsd 0.0634475 0.0634359 1.00018
example_tilelang_gemm_fp8 0.310956 0.310897 1.00019
example_mha_bwd_bhsd 0.0392084 0.0392005 1.0002
example_gqa_sink_bwd_bhsd_sliding_window 0.0255198 0.0255132 1.00026
example_gemm 0.0223983 0.0223923 1.00027
sparse_mla_fwd_pipelined 0.0944872 0.0944579 1.00031
example_dequant_gemv_fp16xint4 0.0283546 0.0283416 1.00046
example_gemv 0.288361 0.288225 1.00047
example_tilelang_nsa_fwd 0.00703103 0.00702728 1.00053
example_mha_bwd_bshd 0.0398819 0.0398605 1.00054
example_gqa_bwd 0.0464377 0.0464095 1.00061
example_mha_fwd_bshd 0.025763 0.0257469 1.00062
example_tilelang_sparse_gqa_decode_varlen_indice 0.0159982 0.0159881 1.00063
example_gemm_autotune 0.0224703 0.0224558 1.00065
example_tilelang_sparse_gqa_decode_varlen_mask 0.0176278 0.0176163 1.00065
sparse_mla_fwd 0.127759 0.127624 1.00106
example_dequant_gemm_fp4_hopper 1.05374 1.0517 1.00194
example_mha_sink_fwd_bhsd 0.0149704 0.0149404 1.00201

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

Skip BufferLoad/BufferStore indices when searching for Cast nodes so an
index-type conversion does not spuriously trigger the decoupling
transformation. Clarify the load-replacement table in visit_for_ — store
entries must feed into it so RMW loads map to the store-side cast buffer —
and cover the a[i] = a[i] + a[i+32] case with a regression test.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@LeiWang1999
Member

@regression-perf

@github-actions

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/24354843357

Results

File Original Latency Current Latency Speedup
example_topk 0.0110979 0.042324 0.262212
example_tilelang_gemm_fp8 0.305384 0.32153 0.949784
example_mha_sink_fwd_bhsd_sliding_window 0.0151859 0.0153028 0.992358
example_mha_sink_fwd_bhsd 0.0152418 0.0153402 0.993589
example_tilelang_gemm_splitk 1.00791 1.01327 0.994709
example_linear_attn_fwd 0.036395 0.0365321 0.996246
example_mha_sink_bwd_bhsd_sliding_window 0.043528 0.0436765 0.9966
example_gemm 0.0223123 0.022386 0.996708
example_mha_fwd_bhsd 0.0108353 0.0108663 0.997151
block_sparse_attn_tilelang 0.00880718 0.00882808 0.997633
example_tilelang_nsa_decode 0.006845 0.00685939 0.997901
example_blocksparse_gemm 0.0190939 0.019133 0.997959
example_tilelang_gemm_splitk_vectorize_atomicadd 1.01103 1.01242 0.99863
example_warp_specialize_gemm_copy_0_gemm_1 0.0373781 0.0374226 0.998809
example_warp_specialize_gemm_barrierpipe_stage2 0.0404521 0.040499 0.998842
example_dequant_gemm_bf16_fp4_hopper 0.556605 0.557038 0.999223
example_group_per_split_token_cast_to_fp8 0.0103831 0.0103895 0.999378
example_mha_fwd_varlen 0.0444443 0.0444708 0.999404
example_mha_fwd_bshd 0.0248103 0.0248218 0.999538
example_tilelang_block_sparse_attn 0.0086478 0.00865165 0.999556
example_warp_specialize_gemm_softpipe_stage2 0.0275674 0.0275795 0.999564
example_gqa_bwd 0.0461294 0.0461492 0.999569
example_gemv 0.288211 0.288323 0.999612
example_mha_bwd_bhsd 0.0392262 0.039238 0.999699
example_gqa_sink_bwd_bhsd 0.0427555 0.042765 0.999779
example_mhc_post 0.109796 0.109815 0.999823
example_gqa_fwd_bshd 0.0690133 0.0690219 0.999876
example_mha_bwd_bshd 0.0402723 0.040277 0.999883
example_dequant_gemv_fp16xint4 0.0283307 0.0283324 0.999939
example_mla_decode 0.451211 0.451231 0.999957
example_tilelang_sparse_gqa_decode_varlen_mask 0.0176014 0.017601 1.00002
example_tilelang_gemm_fp8_intrinsic 0.842075 0.842034 1.00005
example_gqa_bwd_tma_reduce_varlen 0.0463703 0.0463679 1.00005
example_per_token_cast_to_fp8 0.00737795 0.00737756 1.00005
example_dequant_gemm_w4a8 5.5806 5.58025 1.00006
example_warp_specialize_gemm_copy_1_gemm_0 0.0275837 0.0275806 1.00011
example_mha_sink_bwd_bhsd 0.0645499 0.0645377 1.00019
tilelang_example_sparse_tensorcore 0.0146423 0.0146387 1.00024
example_dynamic 0.638131 0.637958 1.00027
example_dequant_gemm_bf16_mxfp4_hopper 0.514591 0.514451 1.00027
example_convolution 1.23689 1.23654 1.00029
example_gqa_sink_bwd_bhsd_sliding_window 0.0252698 0.0252617 1.00032
example_gemm_intrinsics 0.0348678 0.0348549 1.00037
example_tilelang_gemm_fp8_2xAcc 0.133217 0.133167 1.00037
example_linear_attn_bwd 0.153229 0.153157 1.00047
example_convolution_autotune 0.981703 0.981172 1.00054
example_elementwise_add 0.115586 0.115516 1.00061
example_vertical_slash_sparse_attn 0.227609 0.227462 1.00065
example_fusedmoe_tilelang 0.133147 0.133049 1.00073
example_tilelang_sparse_gqa_decode_varlen_indice 0.0159982 0.0159808 1.00109
example_mhc_pre 0.152583 0.152411 1.00113
example_dequant_gemm_fp4_hopper 1.03514 1.03394 1.00116
example_tilelang_nsa_fwd 0.00703622 0.00702622 1.00142
example_gemm_autotune 0.0225018 0.0224333 1.00306

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

Commit 03bb070 diverted all loop-invariant global/shared accesses into
the local_fragment bucket, but that bucket's constraint is dropped by the
has_global_or_shared_buffer strategy. ComputeBufferVectorSize already
returns 1 for a reduction-like store such as shared[tx] += a[...+j], yet
that 1 was silently lost, so vectorization proceeded with vector_size=2
and emitted two scalar writes to the same shared[tx] — the second
clobbered the first and dropped a lane of the accumulation.

Only loop-invariant loads (genuine broadcast reads) are safe to divert;
stores must stay in the memory bucket so their vector_size=1 constraint
is honored and the loop is left scalar.
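The invariant described above can be sketched as follows (illustrative only; `combine_vector_sizes` is not the actual ComputeBufferVectorSize implementation):

```python
from functools import reduce
from math import gcd

def combine_vector_sizes(access_sizes):
    """The legal loop vector width must honor every access's constraint,
    so per-access widths combine via gcd and none may be silently dropped."""
    return reduce(gcd, access_sizes, 0) or 1

# A broadcast load could take width 2, but the reduction-like store
# (shared[tx] += ...) demands width 1 — dropping that constraint is the bug:
print(combine_vector_sizes([2, 1]))  # 1, loop stays scalar
```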

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@LeiWang1999
Member

@regression-perf

Each thread only touches its own shared[tx] slot, so no __syncthreads
is actually required; the test was asserting over-conservative behavior.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@github-actions

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/24382885810

Results

File Original Latency Current Latency Speedup
example_dequant_gemm_fp4_hopper 1.04128 1.07022 0.972955
example_tilelang_gemm_splitk_vectorize_atomicadd 1.02503 1.04908 0.977072
example_mha_fwd_bhsd 0.0112908 0.011492 0.98249
sparse_mla_bwd 0.302638 0.305411 0.990922
example_mha_fwd_varlen 0.0466676 0.0468572 0.995954
tilelang_example_sparse_tensorcore 0.0151147 0.0151704 0.996329
example_tilelang_nsa_decode 0.00694363 0.00696919 0.996333
example_tilelang_gemm_fp8 0.315564 0.316606 0.996706
example_mha_bwd_bhsd 0.0414286 0.0415576 0.996895
sparse_mla_fwd 0.130306 0.130707 0.996929
example_mha_bwd_bshd 0.0424219 0.042551 0.996964
example_linear_attn_fwd 0.0375046 0.0376174 0.997
example_gqa_bwd_tma_reduce_varlen 0.0487449 0.0488577 0.997691
example_gqa_fwd_bshd 0.0728936 0.0730215 0.998248
example_mha_fwd_bshd 0.0258158 0.0258534 0.998545
sparse_mla_fwd_pipelined 0.094328 0.0944437 0.998775
example_tilelang_sparse_gqa_decode_varlen_indice 0.0165182 0.0165326 0.99913
example_mha_sink_bwd_bhsd 0.0680114 0.0680471 0.999476
example_tilelang_sparse_gqa_decode_varlen_mask 0.0182666 0.0182747 0.999558
example_dequant_gemm_bf16_fp4_hopper 0.587741 0.587927 0.999682
example_mla_decode 0.476697 0.476774 0.999838
example_gqa_sink_bwd_bhsd 0.045079 0.0450816 0.999944
example_group_per_split_token_cast_to_fp8 0.0109295 0.0109295 1.00001
example_gemm_intrinsics 0.0368733 0.0368715 1.00005
example_vertical_slash_sparse_attn 0.241072 0.241019 1.00022
example_topk 0.0115689 0.0115658 1.00026
example_warp_specialize_gemm_copy_0_gemm_1 0.0379712 0.0379609 1.00027
example_gemm_autotune 0.0238291 0.0238224 1.00028
example_convolution 1.31273 1.31226 1.00036
example_tilelang_gemm_fp8_2xAcc 0.140177 0.140122 1.00039
example_mhc_post 0.109953 0.109907 1.00042
example_fusedmoe_tilelang 0.138692 0.138633 1.00042
example_mha_sink_bwd_bhsd_sliding_window 0.0456795 0.0456599 1.00043
example_tilelang_block_sparse_attn 0.00891522 0.0089112 1.00045
example_per_token_cast_to_fp8 0.00747031 0.00746679 1.00047
example_warp_specialize_gemm_softpipe_stage2 0.0288815 0.0288644 1.00059
example_convolution_autotune 0.983713 0.983088 1.00063
example_gqa_sink_bwd_bhsd_sliding_window 0.0265817 0.0265621 1.00074
block_sparse_attn_tilelang 0.0090326 0.00902509 1.00083
example_gemm 0.0230729 0.0230457 1.00118
example_warp_specialize_gemm_copy_1_gemm_0 0.0288977 0.0288586 1.00135
example_linear_attn_bwd 0.158699 0.158461 1.0015
example_gqa_bwd 0.0484673 0.0483858 1.00168
example_tilelang_gemm_fp8_intrinsic 0.88296 0.881463 1.0017
example_warp_specialize_gemm_barrierpipe_stage2 0.040911 0.0408402 1.00173
example_mha_sink_fwd_bhsd 0.0158365 0.0158023 1.00217
example_mhc_pre 0.159026 0.158681 1.00218
example_blocksparse_gemm 0.0200043 0.0199524 1.0026
example_elementwise_add 0.115694 0.115377 1.00275
example_mha_sink_fwd_bhsd_sliding_window 0.0157736 0.0157282 1.00289
example_tilelang_gemm_splitk 1.05494 1.05189 1.0029
example_dequant_gemm_bf16_mxfp4_hopper 0.535905 0.534107 1.00336
topk_selector 0.0560727 0.0557993 1.0049
example_dequant_gemv_fp16xint4 0.0288066 0.028546 1.00913
example_tilelang_nsa_fwd 0.00723338 0.0071507 1.01156
fp8_lighting_indexer 0.0677188 0.0337443 2.00682
example_gemv 0.661953 0.302704 2.1868
example_dynamic 1.48314 0.657417 2.25601
example_dequant_gemm_w4a8 12.9409 5.71321 2.26509

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@LeiWang1999 LeiWang1999 merged commit 74fc980 into tile-ai:main Apr 14, 2026
5 of 6 checks passed