Commit 431a1f8

Update docs
1 parent 9837f8d commit 431a1f8

5 files changed

Lines changed: 148 additions & 443 deletions

_sources/autoapi/tilelang/transform/decouple_type_cast/index.rst.txt

Lines changed: 38 additions & 119 deletions
@@ -15,6 +15,8 @@ tilelang.transform.decouple_type_cast
 intermediate stage, allowing optimal vectorization for both computation and
 memory access.

+Mixed-precision is detected by the presence of Cast nodes in the loop body.
+
 Two cases are handled:

 Case 1: local → memory (store to memory with mixed types)
@@ -48,20 +50,17 @@ Attributes

 .. autoapisummary::

-tilelang.transform.decouple_type_cast.CastBufferMap
+tilelang.transform.decouple_type_cast.CastEntry


 Classes
 -------

 .. autoapisummary::

-tilelang.transform.decouple_type_cast.MixedTypeChecker
-tilelang.transform.decouple_type_cast.GlobalSharedBufferLoadCollector
-tilelang.transform.decouple_type_cast.StoreCollector
+tilelang.transform.decouple_type_cast.MemoryAccessCollector
 tilelang.transform.decouple_type_cast.DecoupleTypeCastMutator
-tilelang.transform.decouple_type_cast.StoreReplacer
-tilelang.transform.decouple_type_cast.LoadReplacer
+tilelang.transform.decouple_type_cast.AccessReplacer


 Functions
@@ -71,11 +70,7 @@ Functions

 tilelang.transform.decouple_type_cast.is_local_buffer
 tilelang.transform.decouple_type_cast.is_global_or_shared_buffer
-tilelang.transform.decouple_type_cast.validate_buffer_scope
-tilelang.transform.decouple_type_cast.has_mixed_types
-tilelang.transform.decouple_type_cast.get_global_or_shared_buffer_loads
-tilelang.transform.decouple_type_cast.has_global_or_shared_load_with_different_dtype
-tilelang.transform.decouple_type_cast.contains_seq_stmt
+tilelang.transform.decouple_type_cast.inline_let_stmts
 tilelang.transform.decouple_type_cast.extract_if_condition
 tilelang.transform.decouple_type_cast.DecoupleTypeCast

@@ -93,62 +88,40 @@ Module Contents
 Check if a buffer is a global or shared buffer.


-.. py:function:: validate_buffer_scope(buffer)
-
-Validate that buffer has a known scope.
-
-:raises ValueError: If buffer scope is unknown or empty.
-
-
-.. py:class:: MixedTypeChecker(target_dtype)
+.. py:class:: MemoryAccessCollector(loop_var)

 Bases: :py:obj:`tvm.tir.PyStmtExprVisitor`


-Check if expression contains BufferLoads with different dtypes, skipping indices.
+Collect shared/global BufferStore and BufferLoad nodes.

+Skips indices traversal so that index expressions (which may contain
+BufferLoads to index buffers) do not pollute the result.

-.. py:attribute:: target_dtype
-:value: ''
+BufferLoads in if_then_else conditions are skipped because conditions
+don't participate in the type-cast compute path.

+BufferLoads whose indices do not depend on ``loop_var`` are skipped
+because they are scalar accesses (e.g. ``b[0]``) that should remain
+in the compute loop as broadcasts.


-.. py:attribute:: found_different
-:value: False
-
-
-
-.. py:method:: visit_buffer_load_(op)
-
-
-.. py:function:: has_mixed_types(expr, target_dtype)
-
-Check if expression contains BufferLoads with different dtypes than target.
-
-If any BufferLoad in the expression has a different dtype than the target
-(store buffer's dtype), vectorization may be constrained by GCD of all dtypes.
-
-
-.. py:class:: GlobalSharedBufferLoadCollector(skip_if_then_else_cond = False)
-
-Bases: :py:obj:`tvm.tir.PyStmtExprVisitor`
+.. py:attribute:: loop_var


-Collect BufferLoads from global/shared buffers, skipping if_then_else conditions.
+.. py:attribute:: stores
+:type: list[tvm.tir.BufferStore]
+:value: []

-The condition part of if_then_else doesn't participate in type casting,
-so we skip collecting BufferLoads from there.


-.. py:attribute:: result
+.. py:attribute:: loads
 :type: list[tvm.tir.BufferLoad]
 :value: []



-.. py:attribute:: skip_if_then_else_cond
-:value: False
-
+.. py:method:: visit_buffer_store_(op)


 .. py:method:: visit_buffer_load_(op)
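
Note on the hunk above: the collection rule described for ``MemoryAccessCollector`` (do not descend into indices, skip loads whose indices do not depend on ``loop_var``) can be illustrated with a toy model. This is a minimal sketch in plain Python, not the TileLang or TVM implementation; ``Var``, ``Add``, ``Load``, ``uses_var`` and ``collect_loads`` are hypothetical names invented for the illustration, and the if_then_else rule is not modeled.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Add:
    a: object
    b: object


@dataclass(frozen=True)
class Load:
    buffer: str
    index: object  # a Var, an Add, an int constant, or another Load


def uses_var(expr, var):
    """Return True if `expr` mentions `var` anywhere, including inside indices."""
    if isinstance(expr, Var):
        return expr == var
    if isinstance(expr, Add):
        return uses_var(expr.a, var) or uses_var(expr.b, var)
    if isinstance(expr, Load):
        return uses_var(expr.index, var)
    return False  # integer constant


def collect_loads(expr, loop_var, out=None):
    """Collect loads whose index depends on loop_var, without visiting indices."""
    out = [] if out is None else out
    if isinstance(expr, Load):
        if uses_var(expr.index, loop_var):
            out.append(expr)
        # The index expression is deliberately not traversed, so a load that
        # only appears inside an index (idx[i] below) is never collected.
    elif isinstance(expr, Add):
        collect_loads(expr.a, loop_var, out)
        collect_loads(expr.b, loop_var, out)
    return out


i = Var("i")
# Toy store value: a[idx[i]] + b[0]
value = Add(Load("a", Load("idx", i)), Load("b", 0))
loads = collect_loads(value, i)
```

In this toy run only the load from ``a`` is collected: ``idx[i]`` lives inside an index and is not visited, and ``b[0]`` does not depend on the loop variable, so it would stay in the compute loop as a broadcast.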
@@ -157,58 +130,13 @@ Module Contents
 .. py:method:: visit_call_(op)


-.. py:function:: get_global_or_shared_buffer_loads(expr, skip_if_then_else_cond = False)
-
-Get BufferLoads from global/shared buffers in the expression.
-
-:param expr: The expression to search.
-:param skip_if_then_else_cond: If True, skip BufferLoads in if_then_else conditions,
-since they don't participate in type casting.
-
-
-.. py:function:: has_global_or_shared_load_with_different_dtype(expr, target_dtype)
-
-Check if expression has global/shared BufferLoad with different dtype than target.
-
-Used to detect memory→local cases where we need to insert cast buffer.
-Skips if_then_else condition since it doesn't participate in type casting.
-
-
-.. py:class:: StoreCollector
-
-Bases: :py:obj:`tvm.tir.PyStmtExprVisitor`
-
+.. py:function:: inline_let_stmts(stmt)

-Collect BufferStore nodes that need transformation, skipping indices traversal.
+Inline all LetStmt bindings in *stmt* so that downstream visitors can
+see the original BufferLoad nodes that were hidden behind Var references.

-This avoids visiting BufferLoad/BufferStore nodes inside indices, which don't
-participate in the type casting transformation.
-
-
-.. py:attribute:: local_to_memory
-:type: list[tvm.tir.BufferStore]
-:value: []
-
-
-
-.. py:attribute:: memory_to_local
-:type: list[tvm.tir.BufferStore]
-:value: []
-
-
-
-.. py:method:: visit_buffer_store_(op)
-
-
-.. py:method:: visit_buffer_load_(op)
-
-
-.. py:function:: contains_seq_stmt(stmt)
-
-Check if statement contains SeqStmt (multiple statements).
-
-When the For body has SeqStmt, the transformation is more complex
-and we skip the optimization for now.
+Used before collecting memory accesses so that BufferLoads inside LetStmt
+values are visible to ``MemoryAccessCollector``.


 .. py:function:: extract_if_condition(stmt)
@@ -218,7 +146,7 @@ Module Contents
 :returns: A tuple of (condition, inner_body). If no IfThenElse, returns (None, stmt).


-.. py:data:: CastBufferMap
+.. py:data:: CastEntry

 .. py:class:: DecoupleTypeCastMutator

@@ -227,8 +155,8 @@ Module Contents

 Mutator that decouples type cast vectorization constraints.

-This mutator transforms vectorized loops that store to memory buffers
-(global/shared) with mixed-precision expressions by inserting local
+This mutator transforms vectorized loops that have mixed-precision
+operations (detected by the presence of Cast nodes) by inserting local
 cache buffers as intermediate stages.


@@ -238,35 +166,27 @@ Module Contents



-.. py:class:: StoreReplacer(cast_buffers, loop_var)
+.. py:class:: AccessReplacer(store_entries, load_entries, loop_var)

 Bases: :py:obj:`tvm.tir.PyStmtExprMutator`


-Mutator to replace memory BufferStores with cast buffer BufferStores.
-
-
-.. py:attribute:: cast_buffers
-
-
-.. py:attribute:: loop_var
-
-
-.. py:method:: visit_buffer_store_(op)
+Mutator to replace memory BufferStores/BufferLoads with cast buffer accesses.

+Matches by both buffer and indices (structural equality) so that accesses
+like a[i] and a[i+32] from the same buffer map to different cast buffers.

-.. py:class:: LoadReplacer(cast_buffers, loop_var)

-Bases: :py:obj:`tvm.tir.PyStmtExprMutator`
+.. py:attribute:: store_entries


-Mutator to replace memory BufferLoads with cast buffer BufferLoads.
+.. py:attribute:: load_entries


-.. py:attribute:: cast_buffers
+.. py:attribute:: loop_var


-.. py:attribute:: loop_var
+.. py:method:: visit_buffer_store_(op)


 .. py:method:: visit_buffer_load_(op)
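
Note on the hunk above: the new ``AccessReplacer`` docstring says accesses are matched by both buffer and indices, so ``a[i]`` and ``a[i+32]`` get separate cast buffers. The sketch below shows that keying idea only. It is a toy model in plain Python, not the TileLang code: ``CastBufferTable`` and the generated buffer names are hypothetical, and nested tuples stand in for index expressions so that Python tuple equality plays the role of structural equality.

```python
class CastBufferTable:
    """Toy lookup table keyed by (buffer name, index expression)."""

    def __init__(self):
        self._entries = {}

    def get_or_create(self, buffer_name, index):
        # Keying on the buffer alone would merge a[i] and a[i+32];
        # including the index keeps them apart.
        key = (buffer_name, index)
        if key not in self._entries:
            self._entries[key] = f"{buffer_name}_cast_{len(self._entries)}"
        return self._entries[key]


table = CastBufferTable()
first = table.get_or_create("a", ("var", "i"))                 # a[i]
second = table.get_or_create("a", ("add", ("var", "i"), 32))   # a[i+32]
again = table.get_or_create("a", ("add", ("var", "i"), 32))    # a[i+32] again
```

Here ``first`` and ``second`` name different cast buffers, while ``again`` reuses the buffer already created for ``a[i+32]``.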
@@ -277,8 +197,7 @@ Module Contents
 Create a TVM pass that decouples type cast vectorization constraints.

 This pass inserts a local buffer as an intermediate stage for vectorized
-stores to non-local buffers (global/shared) where the store value contains
-expressions with different dtypes.
+loops where the body contains Cast nodes (mixed-precision operations).

 This allows optimal vectorization for both computation and memory access.
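
Note on the change as a whole: the trigger moves from comparing BufferLoad dtypes against the store dtype to checking whether the loop body contains a Cast node. The following is a minimal sketch of that detection rule on a toy expression tree. It assumes nothing about the real pass beyond what the docstrings state; ``Load``, ``Cast``, ``Mul`` and ``contains_cast`` are hypothetical stand-ins, not TVM classes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Load:
    buffer: str
    dtype: str


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: object


@dataclass(frozen=True)
class Mul:
    a: object
    b: object


def contains_cast(expr):
    """Return True if any node in the toy expression tree is a Cast."""
    if isinstance(expr, Cast):
        return True
    if isinstance(expr, Mul):
        return contains_cast(expr.a) or contains_cast(expr.b)
    return False  # Load leaves contain no Cast


# float32(a[i]) * b[i], where a is float16: mixed precision
mixed = Mul(Cast("float32", Load("a", "float16")), Load("b", "float32"))
# a[i] * b[i], both float32: no Cast, so the loop would be left alone
uniform = Mul(Load("a", "float32"), Load("b", "float32"))
```

Under this rule only ``mixed`` would be treated as a mixed-precision body; ``uniform`` has no Cast node and would not be rewritten.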
