Skip to content

Commit 56c84d1

Browse files
committed
[Fix] Correct an incorrectly changed string literal
1 parent 5509736 commit 56c84d1

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,7 +1725,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
17251725
"""
17261726
# Use loads as default
17271727
if buffer is None:
1728-
buffer = self.applys if "outputs" not in str(index) else self.dma_loads
1728+
buffer = self.applys if "tmp" not in str(index) else self.dma_loads
17291729

17301730
# TODO.
17311731
kg_tile_desc = self.kernel_group.tile_desc
@@ -1736,7 +1736,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
17361736
total_dims = [int(str(i)[5:]) for i in self.itervars]
17371737
local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane)
17381738
local_dims.sort() # Assume that smaller index is placed in the outer loop
1739-
indirect_dims = [f"{i}" for i in index.free_symbols if "outputs" in str(i)]
1739+
indirect_dims = [f"{i}" for i in index.free_symbols if "tmp" in str(i)]
17401740
for indirect_dim in indirect_dims:
17411741
index = index.replace(sympy.Symbol(indirect_dim), 0)
17421742

@@ -1992,7 +1992,7 @@ def get_mask(self):
19921992
return mask_shape, mask_var
19931993

19941994
def convert_indirect_indexing(self, index :sympy.Expr):
1995-
if "outputs" not in str(index):
1995+
if "tmp" not in str(index):
19961996
return index, None
19971997

19981998
# Note: In case of indirect indexing, dimensions should be divisible by tile size
@@ -2003,7 +2003,7 @@ def convert_indirect_indexing(self, index :sympy.Expr):
20032003
raise mlir_common.RecompileSignal(f"Indirect access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})")
20042004

20052005
# Process start
2006-
indirect_dims = [str(dim) for dim in index.free_symbols if "outputs" in str(dim)]
2006+
indirect_dims = [str(dim) for dim in index.free_symbols if "tmp" in str(dim)]
20072007
indirect_dims.sort()
20082008
first_dim = indirect_dims[0]
20092009
spad_vars = dict()
@@ -2051,7 +2051,7 @@ def convert_indirect_indexing(self, index :sympy.Expr):
20512051

20522052
# Apply stride
20532053
for arg in index.args:
2054-
if "outputs" not in str(arg):
2054+
if "tmp" not in str(arg):
20552055
continue
20562056
if arg.is_Mul and arg.args[0].is_number:
20572057
coeff_dtype = self.var_info[spad_vars[str(arg.args[1])]][1]

0 commit comments

Comments
 (0)