diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh index 7e1f8c1e6d..837807136a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh @@ -69,11 +69,52 @@ __inline__ __device__ void inclusive_sum_scan_kernel( const int block_id, const bool is_multi_block, const int signal) { +// ROCm path +#ifdef USE_ROCM // Perform scan within a block cub::BlockScan(temp_storage) .InclusiveSum(arr, arr); - // Perform stream scan across blocks + // Perform scan across blocks + if (is_multi_block) { + const bool is_last_thread = + threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD; + // The thread that holds the last entry in the block does synchronization + if (is_last_thread) { + scalar_t block_prev_local = 0; + if (block_id != 0) { + // Spin wait for the previous block to write the sum value + while (atomicAdd(&block_flags[block_id - 1], 0) < signal) + ; + + // Get sum from the previous block + *block_prev = block_prev_local = block_sums[block_id - 1]; + } + + // Write sum to global memory for the next block to consume + const int scope = (num_entries_per_block - 1) % ITEMS_PER_THREAD; + block_sums[block_id] = block_prev_local + arr[scope]; + __threadfence(); + // Set a flag to notify the next block + atomicExch(&block_flags[block_id], signal); + } + + __syncthreads(); + + if (block_id != 0) { + scalar_t block_prev_local = *block_prev; + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + arr[i] += block_prev_local; + } + } + } +#else + // CUDA path + // Perform scan across blocks + cub::BlockScan(temp_storage) + .InclusiveSum(arr, arr); + + // Perform scan across blocks if (is_multi_block) { // The thread that holds the last entry in the block does synchronization if (threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD) { @@ -104,6 +145,6 @@ __inline__ __device__ void inclusive_sum_scan_kernel( } } } +#endif } - } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index c795d19ecd..bb05c61206 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include #include "common.cuh" using Tensor = at::Tensor; @@ -17,7 +18,8 @@ template < typename index_t, typename acc_t, int NUM_THREADS_PER_BLOCK, - int MAX_ENTRIES_PER_BLOCK> + int MAX_ENTRIES_PER_BLOCK, + int ENTRIES_PER_THREAD> __global__ void index_select_scalar_cumsum_kernel( pta::PackedTensorAccessor32 output, pta::PackedTensorAccessor32 output_cumsum, @@ -31,6 +33,81 @@ __global__ void index_select_scalar_cumsum_kernel( acc_t* block_sums) { typedef cub::BlockScan BlockScan; __shared__ typename BlockScan::TempStorage bs_temp_storage; + __shared__ acc_t block_prefix; + +// ROCm path +#ifdef USE_ROCM + const int output_batch_size = indices.size(0); + const int num_entries = num_batches * output_batch_size; + const bool multi_block = gridDim.x > 1; + const int block_entries = blockIdx.x == gridDim.x - 1 + ? last_block_num_entries + : MAX_ENTRIES_PER_BLOCK; + const int block_entry_start = blockIdx.x * MAX_ENTRIES_PER_BLOCK; + const int remaining_entries = num_entries - block_entry_start; + const int num_entries_per_block = remaining_entries > 0 + ? (remaining_entries < block_entries ? remaining_entries : block_entries) + : 0; + + const int base_entry = block_entry_start + threadIdx.x * ENTRIES_PER_THREAD; + acc_t local_data[ENTRIES_PER_THREAD]; + + #pragma unroll + for (int i = 0; i < ENTRIES_PER_THREAD; ++i) { + const int entry = base_entry + i; + if (entry < num_entries) { + const int bid = entry / output_batch_size; + const int idx_in_batch = entry - bid * output_batch_size; + const int bid_base = bid * input_batch_size; + const index_t sel_idx = indices[idx_in_batch]; + local_data[i] = __builtin_nontemporal_load(&input[bid_base + sel_idx]); + output[entry] = local_data[i]; + } else { + local_data[i] = 0; + } + } + + // Faster path for single block + if (!multi_block) { + if (num_entries_per_block > 0) { + BlockScan(bs_temp_storage).InclusiveSum(local_data, local_data); + } + if (base_entry < num_entries) { + #pragma unroll + for (int i = 0; i < ENTRIES_PER_THREAD; ++i) { + const int entry = base_entry + i; + if (entry < num_entries) { + output_cumsum[entry] = local_data[i]; + } + } + } + return; + } + + if (num_entries_per_block > 0) { + inclusive_sum_scan_kernel( + local_data, + bs_temp_storage, + block_flags, + block_sums, + &block_prefix, + num_entries_per_block, + blockIdx.x, + multi_block, + 1); + } + + if (base_entry < num_entries) { + #pragma unroll + for (int i = 0; i < ENTRIES_PER_THREAD; ++i) { + const int entry = base_entry + i; + if (entry < num_entries) { + output_cumsum[entry] = local_data[i]; + } + } + } +#else + // CUDA path __shared__ acc_t smem[MAX_ENTRIES_PER_BLOCK]; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const int output_batch_size = indices.size(0); @@ -65,6 +142,7 @@ __global__ void index_select_scalar_cumsum_kernel( if (tid < num_batches * output_batch_size) { output_cumsum[tid] = *local_data; } +#endif } template < @@ -183,62 +261,151 @@ class KeyedJaggedIndexSelectDim1GPUOp const int num_batches = lengths.numel() / batch_size; const int num_output_lengths = num_batches * indices.numel(); const int MAX_CUMSUM_ENTRIES_PER_BLOCK = 256; +#ifdef USE_ROCM + const int num_entries_per_thread[] = {4, 2, 1}; + int entries_per_thread = 1; + for (int i : num_entries_per_thread) { + if (indices.numel() % i == 0) { + entries_per_thread = i; + break; + } + } +#else + constexpr int ENTRIES_PER_THREAD = 1; auto grid_size = cuda_calc_xblock_count( num_output_lengths, MAX_CUMSUM_ENTRIES_PER_BLOCK); +#endif Tensor output_offsets = at::empty({num_batches * indices.numel()}, offsets.options()); Tensor output_lengths = at::empty({num_batches * indices.numel()}, lengths.options()); - Tensor block_flags, block_sums; - if (grid_size > 1) { - block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt)); - block_sums = at::empty({grid_size}, output_offsets.options()); - } - // Do index select and cumsum - AT_DISPATCH_INDEX_TYPES( - lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] { - using length_t = index_t; - AT_DISPATCH_INDEX_TYPES( - offsets.scalar_type(), - "index_select_scalar_cumsum_wrapper_2", - [&] { - using offset_t = index_t; - AT_DISPATCH_INDEX_TYPES( - indices.scalar_type(), - "index_select_scalar_cumsum_wrapper_3", - [&] { - FBGEMM_LAUNCH_KERNEL( - (index_select_scalar_cumsum_kernel< - length_t, - index_t, - offset_t, - MAX_CUMSUM_ENTRIES_PER_BLOCK, - MAX_CUMSUM_ENTRIES_PER_BLOCK>), - grid_size, - MAX_CUMSUM_ENTRIES_PER_BLOCK, - 0, - at::cuda::getCurrentCUDAStream(), - PTA_B(output_lengths, length_t, 1, 32), - PTA_B(output_offsets, offset_t, 1, 32), - PTA_B(lengths, length_t, 1, 32), - PTA_B(indices, index_t, 1, 32), - num_batches, - batch_size, - num_output_lengths - - MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1), - grid_size > 1 ? block_flags.data_ptr() : nullptr, - grid_size > 1 ? block_sums.data_ptr() - : nullptr); - }); - }); - }); +#ifdef USE_ROCM + // ROCm path + auto dispatch_cumsum = [&](auto vec_tag) { + constexpr int ENTRIES_PER_THREAD = decltype(vec_tag)::value; + constexpr int ENTRIES_PER_BLOCK = + MAX_CUMSUM_ENTRIES_PER_BLOCK * ENTRIES_PER_THREAD; + const auto rocm_grid_size = + (num_output_lengths + ENTRIES_PER_BLOCK - 1) / ENTRIES_PER_BLOCK; + + if (rocm_grid_size == 0) + return; + + if (rocm_grid_size > 1) { + block_flags = at::zeros({rocm_grid_size}, lengths.options().dtype(at::kInt)); + block_sums = at::empty({rocm_grid_size}, output_offsets.options()); + } + + AT_DISPATCH_INDEX_TYPES( + lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] { + using length_t = index_t; + AT_DISPATCH_INDEX_TYPES( + offsets.scalar_type(), + "index_select_scalar_cumsum_wrapper_2", + [&] { + using offset_t = index_t; + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), + "index_select_scalar_cumsum_wrapper_3", + [&] { + FBGEMM_LAUNCH_KERNEL( + (index_select_scalar_cumsum_kernel< + length_t, + index_t, + offset_t, + MAX_CUMSUM_ENTRIES_PER_BLOCK, + ENTRIES_PER_BLOCK, + ENTRIES_PER_THREAD>), + rocm_grid_size, + MAX_CUMSUM_ENTRIES_PER_BLOCK, + 0, + at::cuda::getCurrentCUDAStream(), + PTA_B(output_lengths, length_t, 1, 32), + PTA_B(output_offsets, offset_t, 1, 32), + PTA_B(lengths, length_t, 1, 32), + PTA_B(indices, index_t, 1, 32), + num_batches, + batch_size, + num_output_lengths - + ENTRIES_PER_BLOCK * (rocm_grid_size - 1), + rocm_grid_size > 1 + ? block_flags.data_ptr() + : nullptr, + rocm_grid_size > 1 + ? block_sums.data_ptr() + : nullptr); + }); + }); + }); + }; + + switch (entries_per_thread) { + case 4: + dispatch_cumsum(std::integral_constant{}); + break; + case 2: + dispatch_cumsum(std::integral_constant{}); + break; + default: + dispatch_cumsum(std::integral_constant{}); + break; + } +#else + // CUDA path + if (grid_size > 1) { + block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt)); + block_sums = at::empty({grid_size}, output_offsets.options()); + } + + AT_DISPATCH_INDEX_TYPES( + lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] { + using length_t = index_t; + AT_DISPATCH_INDEX_TYPES( + offsets.scalar_type(), + "index_select_scalar_cumsum_wrapper_2", + [&] { + using offset_t = index_t; + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), + "index_select_scalar_cumsum_wrapper_3", + [&] { + FBGEMM_LAUNCH_KERNEL( + (index_select_scalar_cumsum_kernel< + length_t, + index_t, + offset_t, + MAX_CUMSUM_ENTRIES_PER_BLOCK, + MAX_CUMSUM_ENTRIES_PER_BLOCK, + ENTRIES_PER_THREAD>), + grid_size, + MAX_CUMSUM_ENTRIES_PER_BLOCK, + 0, + at::cuda::getCurrentCUDAStream(), + PTA_B(output_lengths, length_t, 1, 32), + PTA_B(output_offsets, offset_t, 1, 32), + PTA_B(lengths, length_t, 1, 32), + PTA_B(indices, index_t, 1, 32), + num_batches, + batch_size, + num_output_lengths - + MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1), + grid_size > 1 + ? block_flags.data_ptr() + : nullptr, + grid_size > 1 + ? block_sums.data_ptr() + : nullptr); + }); + }); + }); +#endif const int64_t num_outputs = (selected_lengths_sum.has_value()) - ? selected_lengths_sum.value().guard_int(__FILE__, __LINE__) - : output_offsets[output_offsets.numel() - 1].item(); + ? selected_lengths_sum.value().guard_int(__FILE__, __LINE__) + : output_offsets[output_offsets.numel() - 1].item(); Tensor output = at::empty({num_outputs}, values.options()); Tensor output_weights; if (weights.has_value()) {