
Commit 3074009

Start implementing float4 loads

1 parent 2101864 commit 3074009

6 files changed: 128 additions & 36 deletions

tensorforge/backend/instructions/memory/load.py

Lines changed: 75 additions & 20 deletions
@@ -184,25 +184,23 @@ def inner(indices):
             writer(f'__syncwarp();')
             writer(f'{self._pipeline}.producer_commit();')
 
-        #if False:
-        #    writer('cooperative_groups::wait(cooperative_groups::this_thread_block());')
-
     def _write_datatransfer(self, writer, src_offset, dst_offset, index, length, nontemporal, linscale=None):
-        if not self._use_cuda_memcpy or linscale is not None or True:
-            pos = 0
-            for vecsize in [1]:
-                if src_offset % vecsize == 0:
-                    num_hops = ((length - pos * self._num_threads) // (self._num_threads * vecsize)) * vecsize
-                    self._write_hop(writer, src_offset, dst_offset, index, pos, pos + num_hops, vecsize, nontemporal, linscale)
-                    pos += num_hops
-            rest = length % self._num_threads
-            if rest > 0:
-                with writer.If(f'{self._linear_idx()} < {rest}'):
-                    self._write_hop(writer, src_offset, dst_offset, index, pos, pos+1, 1, nontemporal, linscale)
+        pos = 0
+
+        if self._use_cuda_memcpy:
+            granularities = [1]
         else:
-            dest_access_index = self._dest.access_address(self._context, index)
-            src_access_index = self._src.access_address(self._context, index)
-            writer(f'cuda::memcpy_async(cooperative_groups::this_thread_block(), &{self._dest.name}[{dst_offset} + {dest_access_index}], &{self._src.name}[{src_offset} + {src_access_index}], cuda::aligned_size_t<{self._dest.get_fptype().size()}>({length * self._dest.get_fptype().size()}), {self._pipeline});')
+            granularities = [4, 2, 1]
+
+        for vecsize in granularities:
+            if src_offset % vecsize == 0:
+                num_hops = ((length - pos * self._num_threads) // (self._num_threads * vecsize)) * vecsize
+                self._write_hop(writer, src_offset, dst_offset, index, pos, pos + num_hops, vecsize, nontemporal, linscale)
+                pos += num_hops
+        rest = length % self._num_threads
+        if rest > 0:
+            with writer.If(f'{self._linear_idx()} < {rest}'):
+                self._write_hop(writer, src_offset, dst_offset, index, pos, pos+1, 1, nontemporal, linscale)
 
     def _write_hop(self, writer, src_offset, dst_offset, index, start, end, increment, nontemporal, linscale):
         if end > start:
@@ -325,9 +323,66 @@ def gen_code_inner(self, writer: Writer) -> None:
             for dim in src_bbox.sizes():
                 total_size *= dim
 
-            for i in range(0, total_size, self._num_threads):
-                self._src.load_linear(writer, self._context, f'v{i}', i)
-                self._dest.store_linear(writer, self._context, f'v{i}', i)
+            start = 0
+            for g in [1]: #[4, 2, 1]:
+                granularity = self._num_threads * g
+                for i in range(start, total_size, granularity):
+                    self._src.load_linear(writer, self._context, f'v{i}', i, g)
+                    self._dest.store_linear(writer, self._context, f'v{i}', i, g)
+
+                start = (total_size // granularity) * granularity
+
+        elif self._context.get_vm().get_hw_descr().vendor in ['amd']:
+
+            # float4 load
+
+            # for now: use 0 1 2 3, transpose4x4
+
+            # TODO: sort into 4x4x4 blocks
+
+            lead_size = src_bbox.size(0)
+            lead_count = (lead_size + self._num_threads - 1) // self._num_threads
+
+            total_count = lead_count
+            for dim in src_bbox.sizes()[1:]:
+                total_count *= dim
+
+            start = 0
+
+            prec = 'float'
+
+            for g in [4, 2, 1]: # [4, 3, 2, 1]
+                # 4x4
+                # writer(f'const auto f{g}idx = (threadIdx.x % {g}) * {self._num_threads} + (threadIdx.x / {g}) * {g};')
+
+                writer(f'const auto f{g}idx = ((threadIdx.x / {16 // g}) % {g}) * {self._num_threads} + (threadIdx.x % {16 // g}) * {g} + (threadIdx.x / 16) * 16;')
+
+                total_count_g = (total_count // g) * g
+                for i in range(start, total_count_g, g):
+                    sidx = i // lead_count
+                    ridx = i % lead_count
+                    index = sidx * lead_size + ridx * self._num_threads
+                    writer(f'const auto v{i} = *(tensorforge::VectorT<{prec}, {g}>*)&{self._src.name}[{index} + f{g}idx];')
+
+                    args2 = ', '.join(f'v{i}[{k}]' for k in range(g))
+
+                    for k in range(g):
+                        writer(f'{prec} v{i}w{k} = 0;')
+
+                    args1 = ', '.join(f'v{i}w{k}' for k in range(g))
+
+                    if g == 4:
+                        writer(f'tensorforge::transpose16x4({args1}, {args2});')
+                    if g == 2:
+                        writer(f'tensorforge::transpose16x2({args1}, {args2});')
+                    if g == 1:
+                        writer(f'{args1} = {args2};')
+
+                    # TODO: generalize
+                    for k in range(g):
+                        writer(f'{self._dest.name}[{i + k}] = v{i}w{k};')
+
+                start = total_count_g
 
         else:
            loops = []
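To make the new AMD branch concrete: for g == 4 (so 16 // g == 4) and num_threads == 64, the writer calls above emit HIP code of roughly the following shape for the first iteration. This is a hedged reconstruction read straight off the f-strings; the names A (source) and A_reg (per-thread register destination) are illustrative, not from the commit.

// Reconstruction of the generated code for g == 4, i == 0, index == 0,
// assuming num_threads == 64: each lane vector-loads 4 consecutive floats
// at a swizzled offset, the 16x4 tile is transposed across lanes via DPP,
// and the 4 results land in per-thread registers.
const auto f4idx = ((threadIdx.x / 4) % 4) * 64 + (threadIdx.x % 4) * 4 + (threadIdx.x / 16) * 16;
const auto v0 = *(tensorforge::VectorT<float, 4>*)&A[0 + f4idx];
float v0w0 = 0;
float v0w1 = 0;
float v0w2 = 0;
float v0w3 = 0;
tensorforge::transpose16x4(v0w0, v0w1, v0w2, v0w3, v0[0], v0[1], v0[2], v0[3]);
A_reg[0] = v0w0;
A_reg[1] = v0w1;
A_reg[2] = v0w2;
A_reg[3] = v0w3;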

tensorforge/backend/instructions/memory/store.py

Lines changed: 4 additions & 1 deletion
@@ -195,7 +195,10 @@ def inner(indices):
             else:
                 self._dest.store(writer, self._context, '0', indices, allow_nontemporal)
 
-        write_loops(self._context, writer, loops, inner)
+        if not any(manual) and self._context.get_vm().get_hw_descr().vendor in ['amd'] and False:
+            pass
+        else:
+            write_loops(self._context, writer, loops, inner)
 
     def __str__(self) -> str:
         return f'{self._dest.name} = store{{r>g}}({self._src.name});'

tensorforge/backend/symbol.py

Lines changed: 15 additions & 6 deletions
@@ -444,17 +444,21 @@ def encode_values(self, pos, runIdx, writer, context: Context, variable, index:
             wrote |= self.encode_values(pos + 1, runIdx, writer, context, variable, index, nontemp, leadidx)
         return wrote
 
-    def load_linear(self, writer, context: Context, variable, index):
+    def load_linear(self, writer, context: Context, variable, index, vec = 1):
         if context.get_vm().get_lexic().simd_mode:
             writer(f'{context.get_vm().get_lexic().simd(self.get_fptype(), self.num_threads)} {variable}({index});')
         else:
             if self.stype == SymbolType.Register:
                 access = f'{self.name}[{index // self.num_threads}]'
             else:
-                access = f'{self.name}[{index} + threadIdx.x]'
-            writer(f'{self.get_fptype()} {variable} = {access};')
+                access = f'{self.name}[{index} + threadIdx.x * {vec}]'
 
-    def store_linear(self, writer, context: Context, variable, index):
+            if vec == 1:
+                writer(f'{self.get_fptype()} {variable} = {access};')
+            else:
+                writer(f'tensorforge::VectorT<{self.get_fptype()}, {vec}> {variable} = *(tensorforge::VectorT<{self.get_fptype()}, {vec}>*)&{access};')
+
+    def store_linear(self, writer, context: Context, variable, index, vec = 1):
         if context.get_vm().get_lexic().simd_mode:
             pass
             # TODO:
@@ -463,8 +467,13 @@ def store_linear(self, writer, context: Context, variable, index):
             if self.stype == SymbolType.Register:
                 access = f'{self.name}[{index // self.num_threads}]'
             else:
-                access = f'{self.name}[{index} + threadIdx.x]'
-            writer(f'{access} = {variable};')
+                access = f'{self.name}[{index} + threadIdx.x * {vec}]'
+
+            if vec == 1:
+                writer(f'{access} = {variable};')
+            else:
+                convert = f'*(tensorforge::VectorT<{self.get_fptype()}, {vec}>*)&'
+                writer(f'{convert}{access} = {convert}{variable};')
 
     def load(self, writer, context: Context, variable, index: List[Union[str, int, Immediate, Variable, LeadIndex]], nontemp):
         if self.stype == SymbolType.Data or (not self.obj.is_dense() and not isinstance(self.obj.spp, BoundingBoxSPP)):
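Concretely, for a non-register symbol holding floats, load_linear/store_linear now emit either the scalar or the vectorized access depending on vec. A hedged reconstruction from the f-strings above, with A as an illustrative symbol name and offset 0:

// vec == 1 keeps the old scalar path (note the emitted "* 1"):
float v0 = A[0 + threadIdx.x * 1];
A[0 + threadIdx.x * 1] = v0;

// vec == 4 reinterprets the address as a vector pointer, so the whole
// float4 moves in a single access:
tensorforge::VectorT<float, 4> v0 = *(tensorforge::VectorT<float, 4>*)&A[0 + threadIdx.x * 4];
*(tensorforge::VectorT<float, 4>*)&A[0 + threadIdx.x * 4] = v0;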

tensorforge/common/vm/lexic/hip_lexic.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ def get_headers(self):
         return ["hip/hip_runtime.h", "tensorforge_device/hip.h"]
 
     def get_fptype(self, fptype, length=1):
-        return f'HIP_vector_type<{fptype}, {length}>'
+        return f'tensorforge::VectorT<{fptype}, {length}>'
 
     def glb_store(self, lhs, rhs, nontemporal=False):
         if nontemporal and self._underlying_hardware == 'amd':
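The lexic now names tensorforge::VectorT instead of HIP's built-in HIP_vector_type; the type's definition is not part of this commit. Below is a minimal sketch of a definition that would satisfy the uses seen in this diff (aligned so the reinterpret-cast loads/stores are legal, subscriptable, power-of-two width). This is an assumption for orientation, not the actual header:

#include <hip/hip_runtime.h>

namespace tensorforge {

// Hypothetical sketch; the real definition lives outside this diff.
// alignas(sizeof(T) * N) is what permits the compiler to emit a single
// wide (e.g. dwordx4) load or store for the casts in the generated code.
template <typename T, int N>  // N assumed to be a power of two (1, 2, 4)
struct alignas(sizeof(T) * N) VectorT {
  T data[N];
  __host__ __device__ T& operator[](int i) { return data[i]; }
  __host__ __device__ const T& operator[](int i) const { return data[i]; }
};

} // namespace tensorforge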

tensorforge/generators/generator.py

Lines changed: 15 additions & 8 deletions
@@ -191,6 +191,8 @@ def generate_inner():
 
             if self._persistent_threading:
                 # TODO: OMP target
+                # TODO: maybe iterate over adjacent elements? (for indirect pointers)
+
                 offset = []
                 idx = i - 1
                 for ssection in reversed(self._sections[:i]):
@@ -299,12 +301,6 @@ def _emit_global_ir(self):
         self._section.shr_mem_obj = shmbuilder.get_resultant_obj()
         self._section.global_ir.extend(shmbuilder.get_instructions())
 
-        builder = GetElementPtrBuilder(self._context, self._scopes)
-        for symbol in self._scopes.get_global_scope().values():
-            if symbol.obj.addressing == Addressing.SCALAR or (symbol.obj.addressing == Addressing.NONE and symbol.stype == SymbolType.Data):
-                builder.build(symbol)
-        self._section.global_ir.extend(builder.get_instructions())
-
         # load globals to shared memory (if requested)
         if self._preload_globals:
             load_ir = []
@@ -336,6 +332,17 @@ def _emit_global_ir(self):
             self._scopes.remove_scope()
             self._preload_globals = False
 
+        builder = GetElementPtrBuilder(self._context, self._scopes)
+        for symbol in self._scopes.get_global_scope().values():
+            if symbol.obj.addressing == Addressing.SCALAR or (symbol.obj.addressing == Addressing.NONE and (symbol.stype == SymbolType.Data or not self._preload_globals)):
+                builder.build(symbol)
+        self._section.global_ir.extend(builder.get_instructions())
+
+        # pipelines
+        for symbol in self._scopes.get_global_scope().values():
+            if symbol.obj.addressing in [Addressing.STRIDED, Addressing.PTR_BASED]:
+                pass
+
         if not self._preload_globals:
             if last_barrier:
                 self._section.global_ir.append(SyncGrid(self._context))
@@ -349,8 +356,8 @@ def _emit_ir(self, descr_list):
         builder = GetElementPtrBuilder(self._context, self._scopes)
         self._scopes.add_scope()
         for symbol in self._scopes.get_global_scope().values():
-            firstptr = symbol.obj.addressing == Addressing.SCALAR or (symbol.obj.addressing == Addressing.NONE and symbol.stype == SymbolType.Data)
-            if not firstptr and not (self._preload_globals and symbol.obj.addressing == Addressing.NONE):
+            firstptr = symbol.obj.addressing == Addressing.SCALAR or symbol.obj.addressing == Addressing.NONE
+            if not firstptr:
                 builder.build(symbol)
         self._section.ir.extend(builder.get_instructions())
 
tensorforge/include/tensorforge_device/hip.h

Lines changed: 18 additions & 0 deletions
@@ -721,6 +721,24 @@ transpose16x16b32(T &w1, T &w2, T &w3, T &w4, T &w5, T &w6, T &w7, T &w8, T &w9,
   w16 = dppUpdate<0x128, 0b1111, 0b0011, true>(u8, u16);
 }
 
+template <typename T>
+__device__ __forceinline__ void transpose16x2(T &w1, T &w2, T v1, T v2) {
+  w1 = dppUpdate<0x128, 0b1111, 0b1100, true>(v2, v1);
+  w2 = dppUpdate<0x128, 0b1111, 0b0011, true>(v1, v2);
+}
+
+template <typename T>
+__device__ __forceinline__ void transpose16x4(T &w1, T &w2, T &w3, T &w4, T v1,
+                                              T v2, T v3, T v4) {
+  const T u1 = dppUpdate<0x124, 0b1111, 0b1010, true>(v2, v1);
+  const T u2 = dppUpdate<0x12c, 0b1111, 0b0101, true>(v1, v2);
+  const T u3 = dppUpdate<0x124, 0b1111, 0b1010, true>(v4, v3);
+  const T u4 = dppUpdate<0x12c, 0b1111, 0b0101, true>(v3, v4);
+
+  transpose16x2(w1, w3, u1, u3);
+  transpose16x2(w2, w4, u2, u4);
+}
+
 #define CM4STR(p1, p2, p3, p4, c, a, b) \
   "v_cndmask_b32_dpp " c ", " a ", " b CMVCC \
   " quad_perm:[" STR(p1) "," STR(p2) "," STR(p3) "," STR( \
