diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index da46b95100..31707039ef 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -429,11 +429,11 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( // place for optimization. constexpr size_t kSortIndicesUpperThreshold = 15'000'000; - // Sorting only pays off when there are enough indices to amortize + // Sorting only pays off when there are enough indices to amortize // the sorting cost, and the crossover point depends on the dtype. constexpr size_t kSortIndicesLowerThresholdLowPrec = 1'000'000; constexpr size_t kSortIndicesLowerThresholdFullPrec = 2'000'000; - + const bool is_low_precision = first_input.dtype().itemsize() <= 2; const size_t kSortIndicesLowerThreshold = is_low_precision ? kSortIndicesLowerThresholdLowPrec @@ -441,8 +441,14 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( const bool use_sorted_indices_for_bwd = (num_total_indices >= kSortIndicesLowerThreshold) && (num_total_indices < kSortIndicesUpperThreshold); + + // Only use caching and contiguous warp dispatch when there are sufficiently many + // indices (for fp32). Always on for fp16 + const bool enable_cache_and_contig_for_bwd = + is_low_precision || (num_total_indices >= kSortIndicesLowerThreshold); #else const bool use_sorted_indices_for_bwd = false; + const bool enable_cache_and_contig_for_bwd = false; (void)num_total_indices; #endif @@ -474,6 +480,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( reinterpret_cast(warp_offsets_group), reinterpret_cast(num_cols_group), warp_offset, + enable_cache_and_contig_for_bwd, }; auto saved_data_t = at::empty( {sizeof(saved_data) / sizeof(int64_t)}, @@ -560,6 +567,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( const int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[5]); int64_t total_num_warps = saved_data_ptr[6]; + const bool enable_cache_and_contig = saved_data_ptr[7]; // We checked in forward that all output rows are the same for all member // in the group @@ -769,8 +777,10 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true); #ifdef USE_ROCM - constexpr bool use_contiguous_warps = true; - constexpr bool use_cache = true; + // enable_cache_and_contig is computed in the forward (see the heuristic + // there) + const bool use_contiguous_warps = enable_cache_and_contig; + const bool use_cache = enable_cache_and_contig; #else constexpr bool use_contiguous_warps = false; constexpr bool use_cache = false;