@@ -184,25 +184,23 @@ def inner(indices):
184184 writer (f'__syncwarp();' )
185185 writer (f'{ self ._pipeline } .producer_commit();' )
186186
187- #if False:
188- # writer('cooperative_groups::wait(cooperative_groups::this_thread_block());')
189-
190187 def _write_datatransfer (self , writer , src_offset , dst_offset , index , length , nontemporal , linscale = None ):
191- if not self ._use_cuda_memcpy or linscale is not None or True :
192- pos = 0
193- for vecsize in [1 ]:
194- if src_offset % vecsize == 0 :
195- num_hops = ((length - pos * self ._num_threads ) // (self ._num_threads * vecsize )) * vecsize
196- self ._write_hop (writer , src_offset , dst_offset , index , pos , pos + num_hops , vecsize , nontemporal , linscale )
197- pos += num_hops
198- rest = length % self ._num_threads
199- if rest > 0 :
200- with writer .If (f'{ self ._linear_idx ()} < { rest } ' ):
201- self ._write_hop (writer , src_offset , dst_offset , index , pos , pos + 1 , 1 , nontemporal , linscale )
188+ pos = 0
189+
190+ if self ._use_cuda_memcpy :
191+ granularities = [1 ]
202192 else :
203- dest_access_index = self ._dest .access_address (self ._context , index )
204- src_access_index = self ._src .access_address (self ._context , index )
205- writer (f'cuda::memcpy_async(cooperative_groups::this_thread_block(), &{ self ._dest .name } [{ dst_offset } + { dest_access_index } ], &{ self ._src .name } [{ src_offset } + { src_access_index } ], cuda::aligned_size_t<{ self ._dest .get_fptype ().size ()} >({ length * self ._dest .get_fptype ().size ()} ), { self ._pipeline } );' )
193+ granularities = [4 , 2 , 1 ]
194+
195+ for vecsize in granularities :
196+ if src_offset % vecsize == 0 :
197+ num_hops = ((length - pos * self ._num_threads ) // (self ._num_threads * vecsize )) * vecsize
198+ self ._write_hop (writer , src_offset , dst_offset , index , pos , pos + num_hops , vecsize , nontemporal , linscale )
199+ pos += num_hops
200+ rest = length % self ._num_threads
201+ if rest > 0 :
202+ with writer .If (f'{ self ._linear_idx ()} < { rest } ' ):
203+ self ._write_hop (writer , src_offset , dst_offset , index , pos , pos + 1 , 1 , nontemporal , linscale )
206204
207205 def _write_hop (self , writer , src_offset , dst_offset , index , start , end , increment , nontemporal , linscale ):
208206 if end > start :
@@ -325,9 +323,66 @@ def gen_code_inner(self, writer: Writer) -> None:
325323 for dim in src_bbox .sizes ():
326324 total_size *= dim
327325
328- for i in range (0 , total_size , self ._num_threads ):
329- self ._src .load_linear (writer , self ._context , f'v{ i } ' , i )
330- self ._dest .store_linear (writer , self ._context , f'v{ i } ' , i )
326+ start = 0
327+ for g in [1 ]: #[4, 2, 1]:
328+ granularity = self ._num_threads * g
329+ for i in range (start , total_size , granularity ):
330+ self ._src .load_linear (writer , self ._context , f'v{ i } ' , i , g )
331+ self ._dest .store_linear (writer , self ._context , f'v{ i } ' , i , g )
332+
333+ start = (total_size // granularity ) * granularity
334+
335+ elif self ._context .get_vm ().get_hw_descr ().vendor in ['amd' ]:
336+
337+ # float4 load
338+
339+ # for now: use 0 1 2 3, transpose4x4
340+
341+ # TODO: sort into 4x4x4 blocks
342+
343+ lead_size = src_bbox .size (0 )
344+ lead_count = (lead_size + self ._num_threads - 1 ) // self ._num_threads
345+
346+ total_count = lead_count
347+ for dim in src_bbox .sizes ()[1 :]:
348+ total_count *= dim
349+
350+ start = 0
351+
352+ prec = 'float'
353+
354+ for g in [4 , 2 , 1 ]: # [4, 3, 2, 1]
355+ # 4x4
356+ # writer(f'const auto f{g}idx = (threadIdx.x % {g}) * {self._num_threads} + (threadIdx.x / {g}) * {g};')
357+
358+ writer (f'const auto f{ g } idx = ((threadIdx.x / { 16 // g } ) % { g } ) * { self ._num_threads } + (threadIdx.x % { 16 // g } ) * { g } + (threadIdx.x / 16) * 16;' )
359+
360+ total_count_g = (total_count // g ) * g
361+ for i in range (start , total_count_g , g ):
362+ sidx = i // lead_count
363+ ridx = i % lead_count
364+ index = sidx * lead_size + ridx * self ._num_threads
365+ writer (f'const auto v{ i } = *(tensorforge::VectorT<{ prec } , { g } >*)&{ self ._src .name } [{ index } + f{ g } idx];' )
366+
367+ args2 = ', ' .join (f'v{ i } [{ k } ]' for k in range (g ))
368+
369+ for k in range (g ):
370+ writer (f'{ prec } v{ i } w{ k } = 0;' )
371+
372+ args1 = ', ' .join (f'v{ i } w{ k } ' for k in range (g ))
373+
374+ if g == 4 :
375+ writer (f'tensorforge::transpose16x4({ args1 } , { args2 } );' )
376+ if g == 2 :
377+ writer (f'tensorforge::transpose16x2({ args1 } , { args2 } );' )
378+ if g == 1 :
379+ writer (f'{ args1 } = { args2 } ;' )
380+
381+ # TODO: generalize
382+ for k in range (g ):
383+ writer (f'{ self ._dest .name } [{ i + k } ] = v{ i } w{ k } ;' )
384+
385+ start = total_count_g
331386
332387 else :
333388 loops = []
0 commit comments