namespace device {

// Threads per block for the reduction kernels; blockReduce derives its
// warp count as BlockSize / warpSize, and the host launcher passes this
// same value as the launch block size.
constexpr int BlockSize = 1024;
// Each thread accumulates up to this many input elements before the
// block-wide reduction, so one block covers BlockSize * ItemsPerThread items.
constexpr int ItemsPerThread = 4;
1616template <typename T>
1717struct Sum {
1818 T defaultValue{0 };
@@ -98,15 +98,17 @@ __device__ __forceinline__ T blockReduce(T val, T* shmem, OperationT operation)
9898 const int warpId = threadIdx.x / warpSize;
9999
100100 val = warpReduce (val, operation);
101- if (laneId == 0 )
101+ if (laneId == 0 ) {
102102 shmem[warpId] = val;
103+ }
103104 __syncthreads ();
104105
105- const int numWarps = WorkGroupSize / warpSize;
106+ const int numWarps = BlockSize / warpSize;
106107 val = (threadIdx.x < numWarps) ? shmem[laneId] : operation.defaultValue ;
107108
108- if (warpId == 0 )
109+ if (warpId == 0 ) {
109110 val = warpReduce (val, operation);
111+ }
110112
111113 return val;
112114}
@@ -120,20 +122,20 @@ __global__ void initKernel(T* result, OperationT operation) {
120122}
121123
122124template <typename AccT, typename VecT, typename OperationT>
123- __launch_bounds__ (WorkGroupSize ) void __global__ kernel_reduce (
125+ __launch_bounds__ (BlockSize ) void __global__ kernel_reduce (
124126 AccT* result, const VecT* vector, size_t size, bool overrideResult, OperationT operation) {
125127
126128 // Maximum block size 1024, warp size 32 so 1024/32 = 32 chosen
127129 // For AMD, warp size 64, 1024/64 = 16, but 32 should work with a few idle memory addresses
128130 __shared__ AccT shmem[32 ];
129131
130132 AccT threadAcc = operation.defaultValue ;
131- size_t blockBaseIdx = blockIdx.x * (WorkGroupSize * ItemsPerWorkItem );
133+ size_t blockBaseIdx = blockIdx.x * (BlockSize * ItemsPerThread );
132134 size_t threadBaseIdx = blockBaseIdx + threadIdx.x ;
133135
134136#pragma unroll
135- for (int i = 0 ; i < ItemsPerWorkItem ; i++) {
136- size_t idx = threadBaseIdx + i * WorkGroupSize ;
137+ for (int i = 0 ; i < ItemsPerThread ; i++) {
138+ size_t idx = threadBaseIdx + i * BlockSize ;
137139 if (idx < size) {
138140 threadAcc = operation (threadAcc, static_cast <AccT>(ntload (&vector[idx])));
139141 }
@@ -156,8 +158,8 @@ void Algorithms::reduceVector(AccT* result,
156158 void * streamPtr) {
157159 auto * stream = reinterpret_cast <internals::DeviceStreamT>(streamPtr);
158160
159- size_t totalItems = WorkGroupSize * ItemsPerWorkItem ;
160- size_t numBlocks = (size + totalItems - 1 ) / totalItems;
161+ const size_t totalItems = BlockSize * ItemsPerThread ;
162+ const size_t numBlocks = (size + totalItems - 1 ) / totalItems;
161163
162164 if (overrideResult) {
163165 switch (type) {
@@ -175,17 +177,17 @@ void Algorithms::reduceVector(AccT* result,
175177
176178 switch (type) {
177179 case ReductionType::Add: {
178- kernel_reduce<<<numBlocks, WorkGroupSize , 0 , stream>>>(
180+ kernel_reduce<<<numBlocks, BlockSize , 0 , stream>>>(
179181 result, buffer, size, overrideResult, device::Sum<AccT>());
180182 break ;
181183 }
182184 case ReductionType::Max: {
183- kernel_reduce<<<numBlocks, WorkGroupSize , 0 , stream>>>(
185+ kernel_reduce<<<numBlocks, BlockSize , 0 , stream>>>(
184186 result, buffer, size, overrideResult, device::Max<AccT>());
185187 break ;
186188 }
187189 case ReductionType::Min: {
188- kernel_reduce<<<numBlocks, WorkGroupSize , 0 , stream>>>(
190+ kernel_reduce<<<numBlocks, BlockSize , 0 , stream>>>(
189191 result, buffer, size, overrideResult, device::Min<AccT>());
190192 break ;
191193 }