Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,20 +429,26 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
// place for optimization.
constexpr size_t kSortIndicesUpperThreshold = 15'000'000;

// Sorting only pays off when there are enough indices to amortize
// Sorting only pays off when there are enough indices to amortize
// the sorting cost, and the crossover point depends on the dtype.
constexpr size_t kSortIndicesLowerThresholdLowPrec = 1'000'000;
constexpr size_t kSortIndicesLowerThresholdFullPrec = 2'000'000;

const bool is_low_precision = first_input.dtype().itemsize() <= 2;
const size_t kSortIndicesLowerThreshold = is_low_precision
? kSortIndicesLowerThresholdLowPrec
: kSortIndicesLowerThresholdFullPrec;
const bool use_sorted_indices_for_bwd =
(num_total_indices >= kSortIndicesLowerThreshold) &&
(num_total_indices < kSortIndicesUpperThreshold);

// Only use caching and contiguous warp dispatch when there are sufficiently many
// indices (for fp32). Always on for fp16
const bool enable_cache_and_contig_for_bwd =
is_low_precision || (num_total_indices >= kSortIndicesLowerThreshold);
#else
const bool use_sorted_indices_for_bwd = false;
const bool enable_cache_and_contig_for_bwd = false;
(void)num_total_indices;
#endif

Expand Down Expand Up @@ -474,6 +480,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
reinterpret_cast<int64_t>(warp_offsets_group),
reinterpret_cast<int64_t>(num_cols_group),
warp_offset,
enable_cache_and_contig_for_bwd,
};
auto saved_data_t = at::empty(
{sizeof(saved_data) / sizeof(int64_t)},
Expand Down Expand Up @@ -560,6 +567,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
const int32_t* num_cols_group =
reinterpret_cast<const int32_t*>(saved_data_ptr[5]);
int64_t total_num_warps = saved_data_ptr[6];
const bool enable_cache_and_contig = saved_data_ptr[7];

// We checked in forward that all output rows are the same for all member
// in the group
Expand Down Expand Up @@ -769,8 +777,10 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true);

#ifdef USE_ROCM
constexpr bool use_contiguous_warps = true;
constexpr bool use_cache = true;
// enable_cache_and_contig is computed in the forward (see the heuristic
// there)
const bool use_contiguous_warps = enable_cache_and_contig;
const bool use_cache = enable_cache_and_contig;
#else
constexpr bool use_contiguous_warps = false;
constexpr bool use_cache = false;
Expand Down