@@ -15,6 +15,8 @@ tilelang.transform.decouple_type_cast
1515 intermediate stage, allowing optimal vectorization for both computation and
1616 memory access.
1717
18+ Mixed precision is detected by the presence of Cast nodes in the loop body.
19+
1820 Two cases are handled:
1921
2022 Case 1: local → memory (store to memory with mixed types)
@@ -48,20 +50,17 @@ Attributes
4850
4951.. autoapisummary::
5052
51- tilelang.transform.decouple_type_cast.CastBufferMap
53+ tilelang.transform.decouple_type_cast.CastEntry
5254
5355
5456Classes
5557-------
5658
5759.. autoapisummary::
5860
59- tilelang.transform.decouple_type_cast.MixedTypeChecker
60- tilelang.transform.decouple_type_cast.GlobalSharedBufferLoadCollector
61- tilelang.transform.decouple_type_cast.StoreCollector
61+ tilelang.transform.decouple_type_cast.MemoryAccessCollector
6262 tilelang.transform.decouple_type_cast.DecoupleTypeCastMutator
63- tilelang.transform.decouple_type_cast.StoreReplacer
64- tilelang.transform.decouple_type_cast.LoadReplacer
63+ tilelang.transform.decouple_type_cast.AccessReplacer
6564
6665
6766Functions
@@ -71,11 +70,7 @@ Functions
7170
7271 tilelang.transform.decouple_type_cast.is_local_buffer
7372 tilelang.transform.decouple_type_cast.is_global_or_shared_buffer
74- tilelang.transform.decouple_type_cast.validate_buffer_scope
75- tilelang.transform.decouple_type_cast.has_mixed_types
76- tilelang.transform.decouple_type_cast.get_global_or_shared_buffer_loads
77- tilelang.transform.decouple_type_cast.has_global_or_shared_load_with_different_dtype
78- tilelang.transform.decouple_type_cast.contains_seq_stmt
73+ tilelang.transform.decouple_type_cast.inline_let_stmts
7974 tilelang.transform.decouple_type_cast.extract_if_condition
8075 tilelang.transform.decouple_type_cast.DecoupleTypeCast
8176
@@ -93,62 +88,40 @@ Module Contents
9388 Check if a buffer is a global or shared buffer.
9489
9590
96- .. py:function:: validate_buffer_scope(buffer)
97-
98- Validate that buffer has a known scope.
99-
100- :raises ValueError: If buffer scope is unknown or empty.
101-
102-
103- .. py:class:: MixedTypeChecker(target_dtype)
91+ .. py:class:: MemoryAccessCollector(loop_var)
10492
10593 Bases: :py:obj:`tvm.tir.PyStmtExprVisitor`
10694
10795
108- Check if expression contains BufferLoads with different dtypes, skipping indices.
96+ Collect shared/global BufferStore and BufferLoad nodes.
10997
98+ Skips indices traversal so that index expressions (which may contain
99+ BufferLoads to index buffers) do not pollute the result.
110100
111- .. py:attribute:: target_dtype
112- :value: ''
101+ BufferLoads in if_then_else conditions are skipped because conditions
102+ don't participate in the type-cast compute path.
113103
104+ BufferLoads whose indices do not depend on ``loop_var`` are skipped
105+ because they are scalar accesses (e.g. ``b[0]``) that should remain
106+ in the compute loop as broadcasts.
114107
115108
116- .. py:attribute:: found_different
117- :value: False
118-
119-
120-
121- .. py:method:: visit_buffer_load_(op)
122-
123-
124- .. py:function:: has_mixed_types(expr, target_dtype)
125-
126- Check if expression contains BufferLoads with different dtypes than target.
127-
128- If any BufferLoad in the expression has a different dtype than the target
129- (store buffer's dtype), vectorization may be constrained by GCD of all dtypes.
130-
131-
132- .. py:class:: GlobalSharedBufferLoadCollector(skip_if_then_else_cond=False)
133-
134- Bases: :py:obj:`tvm.tir.PyStmtExprVisitor`
109+ .. py:attribute:: loop_var
135110
136111
137- Collect BufferLoads from global/shared buffers, skipping if_then_else conditions.
112+ .. py:attribute:: stores
113+ :type: list[tvm.tir.BufferStore]
114+ :value: []
138115
139- The condition part of if_then_else doesn't participate in type casting,
140- so we skip collecting BufferLoads from there.
141116
142117
143- .. py:attribute:: result
118+ .. py:attribute:: loads
144119 :type: list[tvm.tir.BufferLoad]
145120 :value: []
146121
147122
148123
149- .. py:attribute:: skip_if_then_else_cond
150- :value: False
151-
124+ .. py:method:: visit_buffer_store_(op)
152125
153126
154127 .. py:method:: visit_buffer_load_(op)
@@ -157,58 +130,13 @@ Module Contents
157130 .. py:method:: visit_call_(op)
158131
159132
160- .. py:function:: get_global_or_shared_buffer_loads(expr, skip_if_then_else_cond=False)
161-
162- Get BufferLoads from global/shared buffers in the expression.
163-
164- :param expr: The expression to search.
165- :param skip_if_then_else_cond: If True, skip BufferLoads in if_then_else conditions,
166- since they don't participate in type casting.
167-
168-
169- .. py:function:: has_global_or_shared_load_with_different_dtype(expr, target_dtype)
170-
171- Check if expression has global/shared BufferLoad with different dtype than target.
172-
173- Used to detect memory→local cases where we need to insert cast buffer.
174- Skips if_then_else condition since it doesn't participate in type casting.
175-
176-
177- .. py:class:: StoreCollector
178-
179- Bases: :py:obj:`tvm.tir.PyStmtExprVisitor`
180-
133+ .. py:function:: inline_let_stmts(stmt)
181134
182- Collect BufferStore nodes that need transformation, skipping indices traversal.
135+ Inline all LetStmt bindings in *stmt* so that downstream visitors can
136+ see the original BufferLoad nodes that were hidden behind Var references.
183137
184- This avoids visiting BufferLoad/BufferStore nodes inside indices, which don't
185- participate in the type casting transformation.
186-
187-
188- .. py:attribute:: local_to_memory
189- :type: list[tvm.tir.BufferStore]
190- :value: []
191-
192-
193-
194- .. py:attribute:: memory_to_local
195- :type: list[tvm.tir.BufferStore]
196- :value: []
197-
198-
199-
200- .. py:method:: visit_buffer_store_(op)
201-
202-
203- .. py:method:: visit_buffer_load_(op)
204-
205-
206- .. py:function:: contains_seq_stmt(stmt)
207-
208- Check if statement contains SeqStmt (multiple statements).
209-
210- When the For body has SeqStmt, the transformation is more complex
211- and we skip the optimization for now.
138+ Used before collecting memory accesses so that BufferLoads inside LetStmt
139+ values are visible to ``MemoryAccessCollector``.
212140
213141
214142.. py:function:: extract_if_condition(stmt)
@@ -218,7 +146,7 @@ Module Contents
218146 :returns: A tuple of (condition, inner_body). If no IfThenElse, returns (None, stmt).
219147
220148
221- .. py:data:: CastBufferMap
149+ .. py:data:: CastEntry
222150
223151 .. py:class:: DecoupleTypeCastMutator
224152
@@ -227,8 +155,8 @@ Module Contents
227155
228156 Mutator that decouples type cast vectorization constraints.
229157
230- This mutator transforms vectorized loops that store to memory buffers
231- (global/shared) with mixed-precision expressions by inserting local
158+ This mutator transforms vectorized loops that have mixed-precision
159+ operations (detected by the presence of Cast nodes) by inserting local
232160 cache buffers as intermediate stages.
233161
234162
@@ -238,35 +166,27 @@ Module Contents
238166
239167
240168
241- .. py:class:: StoreReplacer(cast_buffers, loop_var)
169+ .. py:class:: AccessReplacer(store_entries, load_entries, loop_var)
242170
243171 Bases: :py:obj:`tvm.tir.PyStmtExprMutator`
244172
245173
246- Mutator to replace memory BufferStores with cast buffer BufferStores.
247-
248-
249- .. py:attribute:: cast_buffers
250-
251-
252- .. py:attribute:: loop_var
253-
254-
255- .. py:method:: visit_buffer_store_(op)
174+ Mutator to replace memory BufferStores/BufferLoads with cast buffer accesses.
256175
176+ Matches by both buffer and indices (structural equality) so that accesses
177+ like a[i] and a[i+32] from the same buffer map to different cast buffers.
257178
258- .. py:class:: LoadReplacer(cast_buffers, loop_var)
259179
260- Bases: :py:obj:`tvm.tir.PyStmtExprMutator`
180+ .. py:attribute:: store_entries
261181
262182
263- Mutator to replace memory BufferLoads with cast buffer BufferLoads.
183+ .. py:attribute:: load_entries
264184
265185
266- .. py:attribute:: cast_buffers
186+ .. py:attribute:: loop_var
267187
268188
269- .. py:attribute:: loop_var
189+ .. py:method:: visit_buffer_store_(op)
270190
271191
272192 .. py:method:: visit_buffer_load_(op)
@@ -277,8 +197,7 @@ Module Contents
277197 Create a TVM pass that decouples type cast vectorization constraints.
278198
279199 This pass inserts a local buffer as an intermediate stage for vectorized
280- stores to non-local buffers (global/shared) where the store value contains
281- expressions with different dtypes.
200+ loops where the body contains Cast nodes (mixed-precision operations).
282201
283202 This allows optimal vectorization for both computation and memory access.
284203