
Commit c13c814

Start optimizing the memory loading

1 parent: f032633

5 files changed: 74 additions & 13 deletions

tensorforge/backend/instructions/builders/multilinear_builder.py

Lines changed: 11 additions & 8 deletions

@@ -42,7 +42,7 @@ def __init__(self,
         self._dest_regs = None
 
         self._use_registers_always = self._context.get_vm().get_hw_descr().vendor in ['amd']
-        self._preload_registers = False #self._context.get_vm().get_hw_descr().vendor in ['amd']
+        self._preload_registers = self._context.get_vm().get_hw_descr().vendor in ['amd']
         self._deferred_stores = {}
         self._temporaries = {}
 
@@ -112,10 +112,10 @@ def _make_load_op(self, i):
             self._loaders_cache[self._mem_regions[i]] = load_op
             self._instructions.append(load_op)
         else:
-            if self._preload_registers and self._ops[i].symbol.obj.is_dense() and not (self._ops[i].symbol in self._loaders_cache.keys()):
+            if self._preload_registers and self._ops[i].symbol.obj.is_dense():
                 # only register-preload dense matrices for now
                 self._mem_regions[i], load_op = self._make_loader_and_symbol_reg(self._ops[i].symbol, is_transpose=self._descr.permute[i])
-                self._loaders_cache[self._ops[i].symbol] = load_op
+                self._deferred_stores[self._ops[i].symbol.name] = self._mem_regions[i].symbol, self._mem_regions[i].symbol
                 self._instructions.append(load_op)
             else:
                 # Note: operand will reside in glb. mem for gemm operation
@@ -204,16 +204,19 @@ def _alloc_register_array(self):
 
         # TODO: shrink to enumerate(self._dest_obj.bbox.sizes())
         if self._add:
-            sizes = self._get_target_symbol().data_view._bbox.sizes()
+            bbox = self._get_target_symbol().data_view._bbox
        else:
-            sizes = self._dest_obj.bbox.sizes()
+            bbox = self._dest_obj.bbox
 
-        for d, dim in enumerate(sizes):
+        for d in range(bbox.rank()):
+            dim = bbox.size(d)
             if d not in lead_dim or threads == 0:
                 regsize *= dim
             else:
-                regsize *= (dim + threads - 1) // threads
-                threads //= dim
+                r_start = bbox.lower()[d] // threads
+                r_end = (bbox.upper()[d] + threads - 1) // threads
+                regsize *= r_end - r_start
+                threads //= dim  # TODO?
         name = self._name_registers()
         regmem = RegMemObject(name, regsize)
         registers = Symbol(name=name, stype=SymbolType.Register, obj=regmem)
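
The register-sizing change in _alloc_register_array is the subtle part of this hunk: instead of ceil-dividing each leading dimension by the thread count, it now derives the per-thread register count from the bbox bounds. A standalone sketch of the new loop (names mirror the hunk; the bbox interface — rank(), size(), lower(), upper() — is assumed from its usage there):

def registers_per_thread(bbox, lead_dim, threads):
    # Sketch of the new sizing logic from _alloc_register_array.
    regsize = 1
    for d in range(bbox.rank()):
        dim = bbox.size(d)
        if d not in lead_dim or threads == 0:
            # non-leading dimensions are kept whole in each thread
            regsize *= dim
        else:
            # leading dimensions are split across threads, so only the
            # slice of the bbox index range one thread owns is counted
            r_start = bbox.lower()[d] // threads
            r_end = (bbox.upper()[d] + threads - 1) // threads
            regsize *= r_end - r_start
            threads //= dim  # marked "TODO?" in the commit itself
    return regsize

Compared with the old (dim + threads - 1) // threads, this appears to reserve registers for the absolute index range a thread touches, so a bbox whose bounds are not aligned to the thread count no longer under-counts threads that straddle the boundary.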

tensorforge/backend/instructions/memory/load.py

Lines changed: 22 additions & 3 deletions

@@ -16,7 +16,10 @@ def _find_next_coprime(number, conumber):
        if math.gcd(i, conumber) == 1:
            return i
 
-class GlbToShrLoader(AbstractShrMemWrite):
+class LoadInstruction:
+    pass
+
+class GlbToShrLoader(AbstractShrMemWrite, LoadInstruction):
    def __init__(self, **kwargs):
        super(GlbToShrLoader, self).__init__(kwargs['context'])
        self._dest = kwargs['dest']
@@ -51,7 +54,7 @@ def __init__(self, **kwargs):
        self._shr_mem.add_user(self)
        self._is_ready: bool = False
 
-        self._use_cuda_memcpy = False #self._context.get_vm().get_hw_descr().vendor == 'nvidia'
+        self._use_cuda_memcpy = self._context.get_vm().get_hw_descr().vendor == 'nvidia'
 
        if self._permute is None:
            self._permute = [i for i in range(len(self._src.obj.shape))]
@@ -178,6 +181,7 @@ def inner(indices):
            loop.__exit__(None, None, None)
 
        if self._use_cuda_memcpy:
+            writer(f'__syncwarp();')
            writer(f'{self._pipeline}.producer_commit();')
 
        #if False:
@@ -271,7 +275,7 @@ def get_headers(self) -> List[str]:
    def __str__(self):
        return f'{self._dest.name} = load{{g>s}}({self._src.name}[{", ".join(str(p) for p in self._permute)}])'
 
-class GlbToRegLoader(MemoryInstruction):
+class GlbToRegLoader(MemoryInstruction, LoadInstruction):
    def __init__(self,
                 context: Context,
                 src: Symbol,
@@ -327,3 +331,18 @@ def inner(indices):
 
    def __str__(self) -> str:
        return f'{self._dest.name} = load{{g>r}}({self._src.name});'
+
+class LoadWait(MemoryInstruction, LoadInstruction):
+    def __init__(self, instr):
+        super(LoadWait, self).__init__(instr._context)
+        self._instr = instr
+        self._is_ready = True
+
+    def gen_code_inner(self, writer: Writer) -> None:
+        if isinstance(self._instr, GlbToShrLoader):
+            if self._instr._use_cuda_memcpy:
+                writer(f'{self._instr._pipeline}.consumer_wait();')
+                writer(f'{self._instr._pipeline}.consumer_release();')
+
+    def __str__(self) -> str:
+        return f'wait({self._instr});'
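
Together with the new LoadWait class, this turns each shared-memory load into a split-phase operation: GlbToShrLoader issues the async copy and commits it to the cuda::pipeline, and the matching wait is emitted separately, only where the data is first consumed. A toy sketch of the sequence the two instructions emit (a minimal stand-in Writer that just prints; not the real tensorforge Writer API):

class Writer:
    def __call__(self, line: str) -> None:
        print(line)

writer = Writer()
pipeline = 'pipeline'

# emitted by GlbToShrLoader.gen_code_inner after the copy loop
writer(f'__syncwarp();')
writer(f'{pipeline}.producer_commit();')
# ... compute instructions scheduled here overlap the copy in flight ...
# emitted by the matching LoadWait right before the data is used
writer(f'{pipeline}.consumer_wait();')
writer(f'{pipeline}.consumer_release();')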

tensorforge/backend/opt/memmove.py

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+from typing import List
+from .abstract import AbstractTransformer, Context, AbstractInstruction
+from tensorforge.backend.instructions.compute import ComputeInstruction
+from tensorforge.backend.instructions.memory import AbstractShrMemWrite, MemoryInstruction
+from tensorforge.backend.instructions.memory.load import LoadInstruction, LoadWait
+from tensorforge.backend.instructions.ptr_manip import GetElementPtr
+from tensorforge.backend.symbol import SymbolType
+
+class MoveLoads(AbstractTransformer):
+    def __init__(self,
+                 context: Context,
+                 instructions: List[AbstractInstruction]):
+        super(MoveLoads, self).__init__(context, instructions)
+
+    def apply(self) -> None:
+        instrsOut = []
+        stored = []
+        for instr in reversed(self._instrs):
+            if not isinstance(instr, ComputeInstruction):
+                while len(stored) > 0:
+                    delayed = stored.pop()
+                    instrsOut += [delayed]
+            if isinstance(instr, LoadInstruction):
+                instrsOut += [LoadWait(instr)]
+                while len(stored) > 0:
+                    delayed = stored.pop()
+                    instrsOut += [delayed]
+                stored.append(instr)
+            else:
+                instrsOut += [instr]
+        instrsOut += stored[::-1]
+
+        self._instrs = instrsOut[::-1]
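
MoveLoads walks the instruction stream backwards: every load is held in stored (which hoists it earlier in program order), a LoadWait is dropped at the load's original position, and held loads are flushed at the first instruction that is not pure compute. A toy rerun of the same traversal (Load, Wait, and Gemm are hypothetical stand-ins, not the real instruction classes):

from dataclasses import dataclass

@dataclass
class Load:   # stands in for LoadInstruction
    name: str

@dataclass
class Wait:   # stands in for LoadWait
    load: Load

@dataclass
class Gemm:   # stands in for ComputeInstruction
    name: str

def move_loads(instrs):
    # same backward traversal as MoveLoads.apply(), on the toy classes
    out, stored = [], []
    for instr in reversed(instrs):
        if not isinstance(instr, Gemm):
            while stored:
                out.append(stored.pop())
        if isinstance(instr, Load):
            out.append(Wait(instr))
            while stored:
                out.append(stored.pop())
            stored.append(instr)
        else:
            out.append(instr)
    out += stored[::-1]
    return out[::-1]

print(move_loads([Load('A'), Gemm('A'), Load('B'), Gemm('B')]))
# [Load A, Wait A, Load B, Gemm A, Wait B, Gemm B]:
# Load B is now issued before Gemm A, so the copy overlaps the compute,
# and its wait lands right before Gemm B consumes the data.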

tensorforge/backend/opt/optimizer.py

Lines changed: 5 additions & 1 deletion

@@ -8,7 +8,7 @@
 from .shr_mem_analyzer import ShrMemOpt
 from .sync_block import SyncThreadsOpt
 from .remove_redundancy import RemoveRedundancyOpt
-
+from .memmove import MoveLoads
 
 class OptimizationStage:
    def __init__(self,
@@ -24,6 +24,10 @@ def __init__(self,
        self._num_threads = num_threads
 
    def optimize(self):
+        opt = MoveLoads(self._context, self._instrs)
+        opt.apply()
+        self._instrs = opt.get_instructions()
+
        opt = LivenessAnalysis(self._context, self._instrs)
        opt.apply()
        live_map: Dict[int, Set[Symbol]] = opt.get_live_map()

tensorforge/generators/generator.py

Lines changed: 3 additions & 1 deletion

@@ -181,6 +181,8 @@ def _generate_kernel(self):
 
        def generate_inner():
            with writer.If(f'{self._get_flag_guard(writer, i)}'):
+                if self._context.get_vm().get_hw_descr().vendor == 'nvidia':
+                    writer(f'cuda::pipeline<cuda::thread_scope_thread> pipeline = cuda::make_pipeline();')
                for instruction in section.ir:
                    if instruction.is_ready():
                        instruction.gen_code(writer)
@@ -430,7 +432,7 @@ def _populate_global_scope(self):
        for matrix in self._matrix_list:
            if matrix not in self._tmp_list:
                # temporary. For now, take only the selector matrices
-                if matrix.has_values() and len(matrix.get_values()) < 16:
+                if matrix.has_values() and len(matrix.get_values()) < 16 and False:
                    stype = SymbolType.Data
                elif matrix.addressing == Addressing.SCALAR:
                    stype = SymbolType.Scalar
