[Refactor] Refactor DecoupleTypeCast Pass#2026
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in your CodeRabbit settings.
📝 Walkthrough

DecoupleTypeCast now detects mixed precision by the presence of Cast nodes, inlines LetStmt bindings for analysis, gathers memory accesses with a MemoryAccessCollector, and restructures vectorized loops into three phases: copy-from, compute (with local cast buffers), and copy-to. New tests validate LetStmt and scalar-load cases.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Input as PrimFunc (with LetStmt)
    participant Inliner as LetStmt Inliner
    participant Detector as Cast Detector
    participant Collector as MemoryAccessCollector
    participant Transformer as DecoupleTypeCast
    participant Output as Transformed PrimFunc
    Input->>Inliner: Inline LetStmt bindings
    Inliner->>Detector: Scan inlined tree for Cast nodes
    alt Casts found
        Detector->>Collector: Collect shared/global BufferLoad/BufferStore (skip certain BufferLoad index traversal)
        Collector->>Transformer: Provide access patterns (loads/stores)
        Transformer->>Transformer: Build store/load cast-buffer maps and merge overlaps
        Transformer->>Transformer: Emit copy-from-memory loops (reads into cast buffers)
        Transformer->>Transformer: Emit compute loop using local cast buffers
        Transformer->>Transformer: Emit copy-to-memory loops (stores from cast buffers)
        Transformer->>Output: Return transformed PrimFunc
    else No Casts
        Detector->>Output: Return original PrimFunc
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

@regression-perf
Actionable comments posted: 1
🧹 Nitpick comments (2)
tilelang/transform/decouple_type_cast.py (2)
283-285: Consider using list unpacking for clarity. Static analysis suggests list unpacking, which is more Pythonic and slightly more efficient.

♻️ Suggested change

```diff
- all_stmts = copy_from_loops + [compute_loop] + copy_to_loops
+ all_stmts = [*copy_from_loops, compute_loop, *copy_to_loops]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` around lines 283 - 285, The combination of copy_from_loops, compute_loop, and copy_to_loops into all_stmts can be clearer and more Pythonic by using list unpacking; replace the explicit concatenation assignment "all_stmts = copy_from_loops + [compute_loop] + copy_to_loops" with an unpacked list like "all_stmts = [*copy_from_loops, compute_loop, *copy_to_loops]" and keep the existing conditional construction of result using SeqStmt and the single-element fallback; update the code surrounding all_stmts, SeqStmt, result, copy_from_loops, compute_loop, and copy_to_loops accordingly.
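As a quick illustration of the suggested pattern, here is a standalone sketch with plain lists standing in for the TIR statement lists (`copy_from_loops`, `compute_loop`, `copy_to_loops` are just placeholder values, not the real TVM objects):

```python
# Plain strings stand in for the TIR statement objects from the pass.
copy_from_loops = ["copy_from_a", "copy_from_b"]
compute_loop = "compute"
copy_to_loops = ["copy_to_c"]

# Concatenation builds throwaway intermediate lists:
concatenated = copy_from_loops + [compute_loop] + copy_to_loops

# Unpacking builds the result in a single list display (RUF005's preference):
unpacked = [*copy_from_loops, compute_loop, *copy_to_loops]

assert concatenated == unpacked == ["copy_from_a", "copy_from_b", "compute", "copy_to_c"]
```

Both forms produce the same list; the unpacked form simply avoids the intermediate allocations.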
141-161: Consider documenting the scope limitation. The function only handles `LetStmt`, `IfThenElse`, and `SeqStmt`. Nested `LetStmt`s within other statement types (e.g., nested `For` loops) won't be inlined. While vectorized loop bodies typically don't contain nested `For` loops, adding a brief note about this limitation would improve maintainability.

📝 Suggested documentation addition

```diff
 def inline_let_stmts(stmt: Stmt) -> Stmt:
     """Inline all LetStmt bindings in *stmt* so that downstream visitors
     can see the original BufferLoad nodes that were hidden behind Var
     references.

     Used before collecting memory accesses so that BufferLoads inside
     LetStmt values are visible to ``MemoryAccessCollector``.
+
+    Note: Only traverses LetStmt, IfThenElse, and SeqStmt. LetStmts nested
+    within other statement types (e.g., For loops) are not inlined.
     """
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` around lines 141 - 161, Update the inline_let_stmts docstring to explicitly state its scope limitation: it only recurses into LetStmt, IfThenElse, and SeqStmt and will not inline LetStmt bindings nested inside other statement types (e.g., For loops or other custom Stmt subclasses); mention the assumption about vectorized loop bodies not containing nested For loops and advise that additional handlers (e.g., adding traversal logic for For or other Stmt subclasses) are required if those cases must be supported.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@testing/python/transform/test_tilelang_transform_decouple_type_cast.py`:
- Around line 123-145: The test uses the TVM IR node class name T.Cast instead
of the callable helper T.cast in test_local_to_memory_with_let_stmt: replace the
incorrect call to T.Cast("float8_e4m3fn", a_frag[i] * factor) with a call to
T.cast passing the expression then the dtype object, i.e. use T.cast(a_frag[i] *
factor, T.float8_e4m3fn) when assigning into b_local_cast so the test matches
the transform's expected API (symbols: T.Cast, T.cast,
test_local_to_memory_with_let_stmt, b_local_cast, a_frag, factor).
---
Nitpick comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Around line 283-285: The combination of copy_from_loops, compute_loop, and
copy_to_loops into all_stmts can be clearer and more Pythonic by using list
unpacking; replace the explicit concatenation assignment "all_stmts =
copy_from_loops + [compute_loop] + copy_to_loops" with an unpacked list like
"all_stmts = [*copy_from_loops, compute_loop, *copy_to_loops]" and keep the
existing conditional construction of result using SeqStmt and the single-element
fallback; update the code surrounding all_stmts, SeqStmt, result,
copy_from_loops, compute_loop, and copy_to_loops accordingly.
- Around line 141-161: Update the inline_let_stmts docstring to explicitly state
its scope limitation: it only recurses into LetStmt, IfThenElse, and SeqStmt and
will not inline LetStmt bindings nested inside other statement types (e.g., For
loops or other custom Stmt subclasses); mention the assumption about vectorized
loop bodies not containing nested For loops and advise that additional handlers
(e.g., adding traversal logic for For or other Stmt subclasses) are required if
those cases must be supported.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5e2b7bbd-66b5-4022-8f53-1e84f54224ed
📒 Files selected for processing (2)
- testing/python/transform/test_tilelang_transform_decouple_type_cast.py
- tilelang/transform/decouple_type_cast.py
Performance Regression Test Report

Triggered by: @LJC00118

Results
Artifacts
🧹 Nitpick comments (4)
tilelang/transform/decouple_type_cast.py (4)
170-172: Consider avoiding unnecessary `SeqStmt` recreation. The function creates a new `SeqStmt` even when no `LetStmt` inlining occurred in any of the child statements. This creates unnecessary object allocations.

♻️ Suggested improvement

```diff
 elif isinstance(stmt, SeqStmt):
     new_seq = [inline_let_stmts(s) for s in stmt.seq]
+    if all(n is o for n, o in zip(new_seq, stmt.seq)):
+        return stmt
     return SeqStmt(new_seq)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` around lines 170 - 172, inline_let_stmts unnecessarily reconstructs a SeqStmt even when none of the children change; modify inline_let_stmts so that when handling a SeqStmt it maps children (stmt.seq), compares each new child to the original (or checks if any child identity/value changed), and only returns a new SeqStmt(new_seq) if at least one child was modified—otherwise return the original stmt to avoid extra allocations (reference symbols: inline_let_stmts, SeqStmt, stmt.seq).
277-278: Avoid accessing private attribute `_seen_load_buffers`. Accessing the private `_seen_load_buffers` attribute from outside the class breaks encapsulation. Consider exposing a public method or property.

♻️ Suggested improvement

Add a public method to `MemoryAccessCollector`:

```diff
 class MemoryAccessCollector(PyStmtExprVisitor):
     ...
+    def is_load_buffer(self, buf: Buffer) -> bool:
+        """Check if buffer was seen as a load source."""
+        return buf in self._seen_load_buffers
```

Then update the usage:

```diff
- rmw_buffers = {buf: store_cast_buffers[buf] for buf in store_cast_buffers if buf in collector._seen_load_buffers}
+ rmw_buffers = {buf: store_cast_buffers[buf] for buf in store_cast_buffers if collector.is_load_buffer(buf)}
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` around lines 277 - 278, Replace the direct access to the private attribute collector._seen_load_buffers with a public accessor on MemoryAccessCollector (e.g., a property or method named seen_load_buffers or get_seen_load_buffers), update MemoryAccessCollector to expose that public API, and then change the comprehension that builds rmw_buffers to use collector.seen_load_buffers() (or the property) instead of _seen_load_buffers so rmw_buffers, load_cast_buffers and the construction of all_copy_from remain correct while preserving encapsulation.
295-296: Prefer list unpacking over concatenation. Using list unpacking is more idiomatic and slightly more efficient than concatenation.

♻️ Suggested improvement

```diff
- all_stmts = copy_from_loops + [compute_loop] + copy_to_loops
+ all_stmts = [*copy_from_loops, compute_loop, *copy_to_loops]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` around lines 295 - 296, The code builds all_stmts by concatenating lists; replace the concatenation with list unpacking to be more idiomatic and slightly faster: construct all_stmts using unpacking of copy_from_loops and copy_to_loops with compute_loop in the middle (referencing variables copy_from_loops, compute_loop, copy_to_loops), then keep the existing result creation using SeqStmt(all_stmts) if length > 1 else all_stmts[0].
154-174: Document that `inline_let_stmts` only handles top-level statements. The function doesn't recursively descend into `For`, `While`, or other loop constructs. If a `LetStmt` is nested inside such constructs, it won't be inlined. This appears intentional for the use case (vectorized loop bodies typically have top-level LetStmts), but a brief comment would clarify this design choice.

📝 Suggested documentation

```diff
 def inline_let_stmts(stmt: Stmt) -> Stmt:
     """Inline all LetStmt bindings in *stmt* so that downstream visitors
     can see the original BufferLoad nodes that were hidden behind Var
     references.

     Used before collecting memory accesses so that BufferLoads inside
     LetStmt values are visible to ``MemoryAccessCollector``.
+
+    Note: Only handles LetStmt, IfThenElse, and SeqStmt at the top level.
+    LetStmts nested inside For/While loops are not inlined.
     """
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` around lines 154 - 174, inline_let_stmts currently only inlines LetStmt bindings at the top-level and does not descend into loop constructs like For or While (so nested LetStmts inside loop bodies are not handled); update the function docstring or add a short comment above inline_let_stmts clarifying that it intentionally only handles top-level statements (LetStmt, IfThenElse, SeqStmt) and does not recurse into loop constructs such as For or While, and mention the rationale (e.g., vectorized loop bodies have top-level LetStmts) so readers know this is by design rather than an omission.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Around line 170-172: inline_let_stmts unnecessarily reconstructs a SeqStmt
even when none of the children change; modify inline_let_stmts so that when
handling a SeqStmt it maps children (stmt.seq), compares each new child to the
original (or checks if any child identity/value changed), and only returns a new
SeqStmt(new_seq) if at least one child was modified—otherwise return the
original stmt to avoid extra allocations (reference symbols: inline_let_stmts,
SeqStmt, stmt.seq).
- Around line 277-278: Replace the direct access to the private attribute
collector._seen_load_buffers with a public accessor on MemoryAccessCollector
(e.g., a property or method named seen_load_buffers or get_seen_load_buffers),
update MemoryAccessCollector to expose that public API, and then change the
comprehension that builds rmw_buffers to use collector.seen_load_buffers() (or
the property) instead of _seen_load_buffers so rmw_buffers, load_cast_buffers
and the construction of all_copy_from remain correct while preserving
encapsulation.
- Around line 295-296: The code builds all_stmts by concatenating lists; replace
the concatenation with list unpacking to be more idiomatic and slightly faster:
construct all_stmts using unpacking of copy_from_loops and copy_to_loops with
compute_loop in the middle (referencing variables copy_from_loops, compute_loop,
copy_to_loops), then keep the existing result creation using SeqStmt(all_stmts)
if length > 1 else all_stmts[0].
- Around line 154-174: inline_let_stmts currently only inlines LetStmt bindings
at the top-level and does not descend into loop constructs like For or While (so
nested LetStmts inside loop bodies are not handled); update the function
docstring or add a short comment above inline_let_stmts clarifying that it
intentionally only handles top-level statements (LetStmt, IfThenElse, SeqStmt)
and does not recurse into loop constructs such as For or While, and mention the
rationale (e.g., vectorized loop bodies have top-level LetStmts) so readers know
this is by design rather than an omission.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 665983a9-4bab-4416-b021-848d0b643a6f
📒 Files selected for processing (2)
- testing/python/transform/test_tilelang_transform_decouple_type_cast.py
- tilelang/transform/decouple_type_cast.py
🚧 Files skipped from review as they are similar to previous changes (1)
- testing/python/transform/test_tilelang_transform_decouple_type_cast.py
🧹 Nitpick comments (2)
tilelang/transform/decouple_type_cast.py (2)
138-156: Consider exposing `_seen_load_buffers` via a property. The `_seen_load_buffers` attribute is accessed directly at line 297 (`collector._seen_load_buffers`). Since this is a "private" attribute (underscore prefix), consider exposing it via a read-only property for cleaner encapsulation:

♻️ Suggested property accessor

```diff
     def __init__(self, loop_var: Var):
         super().__init__()
         self.loop_var = loop_var
         self.stores: list[BufferStore] = []
         self.loads: list[BufferLoad] = []
         self._seen_load_buffers: set[Buffer] = set()
+
+    @property
+    def seen_load_buffers(self) -> set[Buffer]:
+        """Buffers that have been collected as loads."""
+        return self._seen_load_buffers
```

Then at line 297:

```diff
- rmw_buffers = {buf: store_cast_buffers[buf] for buf in store_cast_buffers if buf in collector._seen_load_buffers}
+ rmw_buffers = {buf: store_cast_buffers[buf] for buf in store_cast_buffers if buf in collector.seen_load_buffers}
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` around lines 138 - 156, The code is directly accessing the private attribute _seen_load_buffers (e.g., collector._seen_load_buffers) which breaks encapsulation; add a read-only property (e.g., seen_load_buffers) on the class that exposes a frozenset or tuple view of self._seen_load_buffers and update call sites to use collector.seen_load_buffers; locate the attribute and related methods visit_buffer_load_ and visit_buffer_store_ to add the property and ensure the internal set remains private while external code only reads the immutable view.
314-316: Minor: Consider using unpacking for list construction. Per the static analysis hint (RUF005), list unpacking can be slightly more efficient than concatenation:

♻️ Suggested change

```diff
- all_stmts = copy_from_loops + [compute_loop] + copy_to_loops
+ all_stmts = [*copy_from_loops, compute_loop, *copy_to_loops]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` around lines 314 - 316, The list concatenation building all_stmts should use list unpacking for minor efficiency and clarity: replace the concatenation expression that combines copy_from_loops, compute_loop, and copy_to_loops with an unpacking form (e.g., [*copy_from_loops, compute_loop, *copy_to_loops]) and keep the existing conditional that constructs result = SeqStmt(all_stmts) if len(all_stmts) > 1 else all_stmts[0]; reference names: copy_from_loops, compute_loop, copy_to_loops, all_stmts, SeqStmt.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Around line 138-156: The code is directly accessing the private attribute
_seen_load_buffers (e.g., collector._seen_load_buffers) which breaks
encapsulation; add a read-only property (e.g., seen_load_buffers) on the class
that exposes a frozenset or tuple view of self._seen_load_buffers and update
call sites to use collector.seen_load_buffers; locate the attribute and related
methods visit_buffer_load_ and visit_buffer_store_ to add the property and
ensure the internal set remains private while external code only reads the
immutable view.
- Around line 314-316: The list concatenation building all_stmts should use list
unpacking for minor efficiency and clarity: replace the concatenation expression
that combines copy_from_loops, compute_loop, and copy_to_loops with an unpacking
form (e.g., [*copy_from_loops, compute_loop, *copy_to_loops]) and keep the
existing conditional that constructs result = SeqStmt(all_stmts) if
len(all_stmts) > 1 else all_stmts[0]; reference names: copy_from_loops,
compute_loop, copy_to_loops, all_stmts, SeqStmt.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 91b4fbc2-140a-4215-a643-bf30da752eac
📒 Files selected for processing (2)
- src/transform/loop_vectorize.cc
- tilelang/transform/decouple_type_cast.py
Actionable comments posted: 2
🧹 Nitpick comments (1)
testing/python/transform/test_tilelang_transform_decouple_type_cast.py (1)
123-145: Assert that this case still lowers through a `LetStmt`. Right now `_check` only proves the transformed shape. If `factor = scale[i]` stops lowering to a `LetStmt`, this test still passes without exercising `inline_let_stmts`. Please pin that precondition with a direct structural assertion on the untransformed TIR.

Based on learnings, tests in `testing/python/transform` should validate structural patterns in the generated IR/source rather than only inferred behavior.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/transform/test_tilelang_transform_decouple_type_cast.py` around lines 123 - 145, The test test_local_to_memory_with_let_stmt only verifies shapes via _check and can pass even if the load "factor = scale[i]" no longer lowers to a LetStmt; add an explicit structural assertion on the original TIR to ensure the load is lowered into a LetStmt before transformation. Locate the untransformed TIR produced in test_local_to_memory_with_let_stmt (use the before prim_func or the pre-_check representation) and assert that a LetStmt node exists and binds the temporary for "factor = scale[i]" (or that the exact pattern "factor" bound from scale[i] is present), failing the test if inline_let_stmts has removed that LetStmt; keep the assertion adjacent to where _check is invoked so it verifies the precondition prior to running the transform.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Around line 138-157: The collector currently deduplicates loads by Buffer only
(see visit_buffer_load_, self._seen_load_buffers) which merges distinct index
patterns (e.g., a[i] vs a[i+1]); change the logic to distinguish per-access
patterns by replacing self._seen_load_buffers: use a key built from the buffer
plus a normalized representation of op.indices (e.g., tuple of index ASTs or
hashed expressions) and deduplicate on that key when appending to self.loads, or
alternatively detect when the same Buffer is seen with non-equivalent index
patterns and bail out (return/skip transformation) so LoadReplacer cannot
incorrectly map different accesses to the same cast buffer; update any uses of
_seen_load_buffers and ensure LoadReplacer is compatible with the new per-access
keys.
- Around line 160-167: The visitor visit_call_ currently skips the condition of
a tir.if_then_else (op.op.same_as(_IF_THEN_ELSE_OP)) but still collects loads
from its true/false arms, which lets those loads be hoisted unguarded; to fix,
change the decoupling logic so loads found under an expression-level IfThenElse
are either (A) annotated with the arm predicate and have that predicate carried
into the generated copy-from loop (so the copy loop uses the predicate when
emitting guarded loads), or (B) skipped entirely for decoupling when the
IfThenElse is expression-level; implement this by having visit_call_ (and the
same handling at the other site referenced around 198-206) propagate a predicate
flag/closure along with discovered loads (or mark them as non-decouplable) and
ensure the copy-from emitter (the function that currently reuses a predicate for
statement-level IfThenElse) consults that predicate flag before hoisting a load.
---
Nitpick comments:
In `@testing/python/transform/test_tilelang_transform_decouple_type_cast.py`:
- Around line 123-145: The test test_local_to_memory_with_let_stmt only verifies
shapes via _check and can pass even if the load "factor = scale[i]" no longer
lowers to a LetStmt; add an explicit structural assertion on the original TIR to
ensure the load is lowered into a LetStmt before transformation. Locate the
untransformed TIR produced in test_local_to_memory_with_let_stmt (use the before
prim_func or the pre-_check representation) and assert that a LetStmt node
exists and binds the temporary for "factor = scale[i]" (or that the exact
pattern "factor" bound from scale[i] is present), failing the test if
inline_let_stmts has removed that LetStmt; keep the assertion adjacent to where
_check is invoked so it verifies the precondition prior to running the
transform.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 402c3e48-971b-429e-a8fb-360f7619d967
📒 Files selected for processing (2)
- testing/python/transform/test_tilelang_transform_decouple_type_cast.py
- tilelang/transform/decouple_type_cast.py
Actionable comments posted: 1
♻️ Duplicate comments (2)
tilelang/transform/decouple_type_cast.py (2)
138-160: ⚠️ Potential issue | 🔴 Critical

Don't collapse distinct vector accesses from the same buffer into one cast buffer.

This collector deduplicates loads by `Buffer` only. A loop like `out[i] = T.cast(a[i] + a[i + 1], ...)` will cache just one access, and `LoadReplacer` later rewrites both loads to the same `a_local_cast[i]`, which changes semantics. The cast-buffer mapping needs to distinguish access patterns, or the pass should bail out when one buffer is referenced with non-equivalent indices.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tilelang/transform/decouple_type_cast.py` around lines 138 - 160, The collector currently deduplicates loads by Buffer only (see visit_buffer_load_, self._seen_load_buffers) which merges distinct index patterns (e.g., a[i] vs a[i+1]); change the logic to distinguish per-access patterns by replacing self._seen_load_buffers: use a key built from the buffer plus a normalized representation of op.indices (e.g., tuple of index ASTs or hashed expressions) and deduplicate on that key when appending to self.loads, or alternatively detect when the same Buffer is seen with non-equivalent index patterns and bail out (return/skip transformation) so LoadReplacer cannot incorrectly map different accesses to the same cast buffer; update any uses of _seen_load_buffers and ensure LoadReplacer is compatible with the new per-access keys.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` around lines 138 - 160, The collector currently deduplicates loads by Buffer using self._seen_load_buffers in visit_buffer_load_, which collapses distinct index patterns (e.g., a[i] vs a[i+1]); change the deduplication to use a key that combines the Buffer and a normalized representation of op.indices (e.g., tuple of index ASTs, their hashed string, or another canonical form) so self.loads collects each unique access pattern, and update/remove usage of self._seen_load_buffers accordingly; alternatively, detect when the same Buffer is seen with non-equivalent index patterns and bail out of the transform (skip/abort) to avoid LoadReplacer mapping different accesses to the same cast buffer, and ensure LoadReplacer is adjusted to expect per-access keys if you choose the first approach.
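The per-access-key idea from this finding can be sketched in plain Python. The `Buffer`/`BufferLoad` classes and `access_key` helper below are illustrative mocks, not the real `tvm.tir` types; the point is only that keying on the buffer plus a normalized form of the indices keeps `a[i]` and `a[i + 1]` distinct:

```python
from dataclasses import dataclass

# Minimal stand-ins for TVM's Buffer/BufferLoad; names are hypothetical.
@dataclass(frozen=True)
class Buffer:
    name: str

@dataclass
class BufferLoad:
    buffer: Buffer
    indices: tuple  # index expressions, plain strings here for illustration

def access_key(load: BufferLoad) -> tuple:
    # Key on the buffer *and* a normalized rendering of the indices, so
    # a[i] and a[i + 1] deduplicate separately.
    return (load.buffer, tuple(str(idx) for idx in load.indices))

a = Buffer("a")
loads = [BufferLoad(a, ("i",)), BufferLoad(a, ("i + 1",)), BufferLoad(a, ("i",))]

seen, unique = set(), []
for ld in loads:
    key = access_key(ld)
    if key not in seen:
        seen.add(key)
        unique.append(ld)

# Buffer-only dedup would keep a single access; per-access dedup keeps both
# index patterns while still collapsing the repeated a[i].
assert len({ld.buffer for ld in loads}) == 1
assert len(unique) == 2
```

In the real collector, the normalized index form would need to use TVM's structural equality (or a canonicalized expression hash) rather than `str`.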
162-169: ⚠️ Potential issue | 🔴 Critical

`tir.if_then_else` loads can be hoisted past their guard.

`visit_call_` skips the condition and collects loads from the true/false arms, but the copy-from loop only reuses a predicate when the loop body is a statement-level `IfThenElse`. For an expression like `c[i] = T.if_then_else(i < n, T.cast(a[i], ...), 0)`, the preload of `a[i]` becomes unconditional, which can reintroduce out-of-bounds reads that the predicate originally prevented. Either carry the arm predicate into the generated copy loop or skip decoupling for loads discovered under expression-level `if_then_else`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tilelang/transform/decouple_type_cast.py` around lines 162 - 169, The visitor visit_call_ currently skips the condition of a tir.if_then_else (op.op.same_as(_IF_THEN_ELSE_OP)) but still collects loads from its true/false arms, which lets those loads be hoisted unguarded; to fix, change the decoupling logic so loads found under an expression-level IfThenElse are either (A) annotated with the arm predicate and have that predicate carried into the generated copy-from loop (so the copy loop uses the predicate when emitting guarded loads), or (B) skipped entirely for decoupling when the IfThenElse is expression-level; implement this by having visit_call_ (and the same handling at the other site referenced around 198-206) propagate a predicate flag/closure along with discovered loads (or mark them as non-decouplable) and ensure the copy-from emitter (the function that currently reuses a predicate for statement-level IfThenElse) consults that predicate flag before hoisting a load.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` around lines 162 - 169, The visitor visit_call_ currently skips the if_then_else condition but still records loads from the true/false arms, allowing those loads to be hoisted unguarded; modify visit_call_ (and the similar handling near the other site referenced around lines ~198-206) to either attach the arm predicate to loads discovered under an expression-level tir.if_then_else or mark those loads as non-decouplable so they are not hoisted. Concretely: when op.op.same_as(_IF_THEN_ELSE_OP) propagate a predicate token/closure with the loads collected from op.args[1] and op.args[2] (or set a flag on those load entries) and update the copy-from emitter that currently reuses a predicate for statement-level IfThenElse to consult this predicate/flag before emitting a guarded preload; alternatively, if simpler, skip adding loads discovered under expression-level if_then_else to the decoupling list so they are not hoisted.
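Option (A) from this finding, carrying the arm predicate into the generated copy-from loop, can be sketched as follows. `GuardedLoad` and `emit_copy_from` are hypothetical names for illustration; the real pass would emit TIR statements rather than strings:

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical record pairing a discovered load with the guard it was
# found under; not the actual pass's data structures.
@dataclass
class GuardedLoad:
    buffer: str
    index: str
    predicate: Optional[str]  # None means the load is unconditional

def emit_copy_from(load: GuardedLoad) -> str:
    body = f"{load.buffer}_local_cast[{load.index}] = {load.buffer}[{load.index}]"
    if load.predicate is not None:
        # Carry the arm predicate into the copy-from loop so the preload
        # stays guarded exactly as the original if_then_else expression was.
        return f"if {load.predicate}: {body}"
    return body

guarded = GuardedLoad("a", "i", "i < n")
plain = GuardedLoad("b", "i", None)

assert emit_copy_from(guarded) == "if i < n: a_local_cast[i] = a[i]"
assert emit_copy_from(plain) == "b_local_cast[i] = b[i]"
```

The alternative (B), simply marking such loads non-decouplable, trades the extra bookkeeping for a narrower transform.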
🧹 Nitpick comments (1)
tilelang/transform/decouple_type_cast.py (1)
318-318: Prefer unpacking syntax for list concatenation. Python 3.5+ allows cleaner syntax for concatenating lists.

♻️ Cleaner syntax

```diff
- all_stmts = copy_from_loops + [compute_loop] + copy_to_loops
+ all_stmts = [*copy_from_loops, compute_loop, *copy_to_loops]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/decouple_type_cast.py` at line 318, The line building all_stmts uses explicit list concatenation; change it to use list unpacking for clarity and modernization: replace the current expression that combines copy_from_loops, compute_loop, and copy_to_loops with a single list literal using the unpacking operator (e.g., [*copy_from_loops, compute_loop, *copy_to_loops]) so that the variable all_stmts is created with the unpacked elements; update the assignment where all_stmts is defined (referencing the names all_stmts, copy_from_loops, compute_loop, copy_to_loops) accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Around line 177-198: inline_let_stmts only handles LetStmt, IfThenElse and
SeqStmt so LetStmt bindings nested in other statements (e.g., For, While,
Allocate) are not inlined and BufferLoads can be hidden from
MemoryAccessCollector; replace the current recursive implementation in
inline_let_stmts with a two-phase approach that (1) walks the input Stmt and
collects all LetStmt bindings into a mapping of vars to values and (2) calls
tir.stmt_functor.substitute (or the existing substitute wrapper) once to apply
that mapping to the whole Stmt so all statement types are handled uniformly;
keep references to inline_let_stmts, LetStmt, substitute and
MemoryAccessCollector to locate and validate the change.
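The two-phase collect-then-substitute approach suggested above can be sketched on a toy AST. The `Var`/`Let`/`For`/`Assign` classes and both helper functions are illustrative stand-ins; the real implementation would collect bindings over `tvm.tir` nodes and apply them with `tir.stmt_functor.substitute`:

```python
from dataclasses import dataclass

# Tiny mock AST standing in for tvm.tir statement nodes.
@dataclass
class Var:
    name: str

@dataclass
class Let:
    var: Var
    value: object
    body: object

@dataclass
class For:
    loop_var: str
    body: object

@dataclass
class Assign:
    target: str
    value: object

def collect_bindings(stmt, out):
    # Phase 1: walk every statement type and record Let bindings,
    # including ones nested inside loops.
    if isinstance(stmt, Let):
        out[stmt.var.name] = stmt.value
        collect_bindings(stmt.body, out)
    elif isinstance(stmt, For):
        collect_bindings(stmt.body, out)

def substitute(stmt, bindings):
    # Phase 2: one uniform substitution pass over the whole tree,
    # dropping Let wrappers and rewriting bound Var references.
    if isinstance(stmt, Let):
        return substitute(stmt.body, bindings)
    if isinstance(stmt, For):
        return For(stmt.loop_var, substitute(stmt.body, bindings))
    if isinstance(stmt, Assign):
        v = stmt.value
        if isinstance(v, Var) and v.name in bindings:
            return Assign(stmt.target, bindings[v.name])
        return stmt
    return stmt

# A Let nested inside a For: case-by-case recursion over only
# LetStmt/IfThenElse/SeqStmt would miss it; two phases inline it.
tree = For("i", Let(Var("factor"), "scale[i]", Assign("b[i]", Var("factor"))))
bindings = {}
collect_bindings(tree, bindings)
result = substitute(tree, bindings)
assert isinstance(result.body, Assign) and result.body.value == "scale[i]"
```

This handles all statement types uniformly, which is exactly why the prompt recommends a single substitution call over per-type recursion.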
---
Duplicate comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Around lines 138-160: The collector currently deduplicates loads by Buffer
using self._seen_load_buffers in visit_buffer_load_, which collapses distinct
index patterns (e.g., a[i] vs a[i+1]). Change the deduplication to use a key
that combines the Buffer and a normalized representation of op.indices (e.g.,
a tuple of index ASTs, their hashed string, or another canonical form) so
self.loads collects each unique access pattern, and update or remove usage of
self._seen_load_buffers accordingly. Alternatively, detect when the same
Buffer is seen with non-equivalent index patterns and bail out of the
transform (skip/abort) to avoid LoadReplacer mapping different accesses to the
same cast buffer. If you choose the first approach, ensure LoadReplacer is
adjusted to expect per-access keys.
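The first approach can be sketched as follows (a toy Python collector, not the real TIR visitor; `index_key` and the string canonicalization are illustrative stand-ins for structural hashing of index expressions):

```python
def index_key(indices):
    # Canonicalize each index expression; real code might use
    # structural hashing of TIR exprs instead of str().
    return tuple(str(i) for i in indices)

class MemoryAccessCollector:
    """Deduplicate by (buffer, index pattern), not by buffer alone."""

    def __init__(self):
        self.loads = []
        self._seen = set()

    def visit_buffer_load(self, buffer_name, indices):
        key = (buffer_name, index_key(indices))
        if key in self._seen:
            return
        self._seen.add(key)
        self.loads.append((buffer_name, indices))

c = MemoryAccessCollector()
c.visit_buffer_load("a", ["i"])
c.visit_buffer_load("a", ["i + 1"])  # distinct pattern, kept
c.visit_buffer_load("a", ["i"])      # duplicate pattern, skipped
assert len(c.loads) == 2
```

With a per-access key, `a[i]` and `a[i+1]` each get their own entry, so a later replacement table cannot silently merge them into one cast buffer.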
- Around lines 162-169: visit_call_ currently skips the if_then_else condition
but still records loads from the true/false arms, allowing those loads to be
hoisted unguarded. Modify visit_call_ (and the similar handling near the other
site referenced around lines ~198-206) to either attach the arm predicate to
loads discovered under an expression-level tir.if_then_else, or mark those
loads as non-decouplable so they are not hoisted. Concretely: when
op.op.same_as(_IF_THEN_ELSE_OP), propagate a predicate token/closure with the
loads collected from op.args[1] and op.args[2] (or set a flag on those load
entries), and update the copy-from emitter that currently reuses a predicate
for statement-level IfThenElse to consult this predicate/flag before emitting
a guarded preload. Alternatively, if simpler, skip adding loads discovered
under expression-level if_then_else to the decoupling list so they are not
hoisted.
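The predicate-propagation variant can be sketched with a toy collector (tuples stand in for TIR expressions; `_IF_THEN_ELSE_OP` matching is modeled by a tag): each load discovered under an arm carries that arm's predicate, so the copy-from emitter can guard the preload instead of hoisting it unconditionally.

```python
class Collector:
    """Tag loads found under expression-level if_then_else with the
    predicate of the arm they came from (None = unconditional)."""

    def __init__(self):
        self.loads = []    # list of (load_expr, predicate_or_None)
        self._pred = None  # predicate active for the current subtree

    def visit(self, expr):
        if isinstance(expr, tuple) and expr[0] == "load":
            self.loads.append((expr, self._pred))
        elif isinstance(expr, tuple) and expr[0] == "ite":
            _, cond, true_arm, false_arm = expr
            outer = self._pred
            # The condition itself is still skipped, matching the
            # existing behavior; only the arms are visited.
            self._pred = ("cond", cond)
            self.visit(true_arm)
            self._pred = ("not", cond)
            self.visit(false_arm)
            self._pred = outer

c = Collector()
c.visit(("ite", "i < n", ("load", "a", "i"), ("load", "b", "i")))
assert c.loads[0] == (("load", "a", "i"), ("cond", "i < n"))
assert c.loads[1] == (("load", "b", "i"), ("not", "i < n"))
```

A copy-from emitter would then wrap each preload whose predicate is non-None in the corresponding guard, exactly as it already does for statement-level IfThenElse.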
---
Nitpick comments:
In `@tilelang/transform/decouple_type_cast.py`:
- Line 318: The line building all_stmts uses explicit list concatenation;
change it to use list unpacking for clarity and modernization. Replace the
expression that combines copy_from_loops, compute_loop, and copy_to_loops with
a single list literal using the unpacking operator (e.g., [*copy_from_loops,
compute_loop, *copy_to_loops]) so that all_stmts is created with the unpacked
elements, and update the assignment where all_stmts is defined accordingly.
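The two spellings are equivalent; strings stand in for the TIR statements here:

```python
copy_from_loops = ["copy_a_cast", "copy_b_cast"]
compute_loop = "compute"
copy_to_loops = ["copy_c_back"]

# Before: all_stmts = copy_from_loops + [compute_loop] + copy_to_loops
all_stmts = [*copy_from_loops, compute_loop, *copy_to_loops]
assert all_stmts == ["copy_a_cast", "copy_b_cast", "compute", "copy_c_back"]
```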
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 89dd194e-d6e1-478d-b3ca-79d0b56497b1
📒 Files selected for processing (1)
tilelang/transform/decouple_type_cast.py
|
@regression-perf |
Performance Regression Test Report
Triggered by: @LJC00118
Results
Artifacts
|
|
@regression-perf |
Performance Regression Test Report
Triggered by: @LJC00118
Results
Artifacts
|
|
@regression-perf |
1 similar comment
|
@regression-perf |
Performance Regression Test Report
Triggered by: @LJC00118
Results
Artifacts
|
Skip BufferLoad/BufferStore indices when searching for Cast nodes, so an index-type conversion does not spuriously trigger the decoupling transformation. Clarify the load-replacement table in visit_for_: store entries must feed into it so that RMW loads map to the store-side cast buffer. Cover the a[i] = a[i] + a[i+32] case with a regression test. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
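The RMW case can be modeled in plain Python (ignoring the actual fp16/fp32 dtype cast, which is what the real pass introduces): the decoupled copy-from/compute/copy-to phases must route the read of a[i] through the same store-side buffer that the store writes, while a[i+32] gets its own load-side buffer.

```python
n = 64
a = list(range(n))

# Fused reference semantics: a[i] = a[i] + a[i+32] for i in [0, 32).
expected = a.copy()
for i in range(32):
    expected[i] = expected[i] + expected[i + 32]

# Decoupled: copy-from into local "cast buffers", compute, copy-to.
lhs_local = a[:32]               # store-side buffer, serves a[i] reads
rhs_local = a[32:]               # load-side buffer, serves a[i+32] reads
for i in range(32):              # compute touches only local buffers
    lhs_local[i] = lhs_local[i] + rhs_local[i]
out = a.copy()
out[:32] = lhs_local             # copy-to

assert out == expected           # e.g., out[5] == 5 + 37 == 42
```

If the a[i] read were instead mapped to a separate load-side buffer than the one the store writes, the two would have to be kept coherent across iterations; sharing the store-side buffer sidesteps that entirely for this pattern.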
|
@regression-perf |
Performance Regression Test Report
Triggered by: @LeiWang1999
Results
Artifacts
|
Commit 03bb070 diverted all loop-invariant global/shared accesses into the local_fragment bucket, but that bucket's constraint is dropped by the has_global_or_shared_buffer strategy. ComputeBufferVectorSize already returns 1 for a reduction-like store such as shared[tx] += a[...+j], yet that 1 was silently lost, so vectorization proceeded with vector_size=2 and emitted two scalar writes to the same shared[tx] slot; the second clobbered the first and dropped a lane of the accumulation. Only loop-invariant loads (genuine broadcast reads) are safe to divert; stores must stay in the memory bucket so their vector_size=1 constraint is honored and the loop is left scalar. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
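The clobbering hazard can be modeled in a few lines of plain Python (a toy model, not the generated CUDA): shared[tx] += a[j] over j is a reduction into one slot, so treating the loop as two independent lanes writing the same address loses a lane, while the scalar (vector_size=1) form accumulates correctly.

```python
a = [3.0, 4.0]

# Correct scalar loop (vector_size = 1): sequential accumulation.
shared = 0.0
for j in range(2):
    shared = shared + a[j]
assert shared == 7.0

# Broken vectorized form (vector_size = 2): both lanes read the same
# initial value and write back to the same slot, so the last write
# clobbers the first and one addend is lost.
init = 0.0
lane_results = [init + a[j] for j in range(2)]  # both lanes read init
shared_bad = lane_results[-1]                   # second write wins
assert shared_bad == 4.0                        # a[0] was dropped
```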
|
@regression-perf |
Each thread only touches its own shared[tx] slot, so no __syncthreads is actually required; the test was asserting over-conservative behavior. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Performance Regression Test Report
Triggered by: @LeiWang1999
Results
Artifacts
|
Summary by CodeRabbit
Bug Fixes
Refactor
Tests
Documentation