From 3e3ab9085183d43c674fa93b934466c58d288fcc Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Wed, 20 May 2026 12:20:06 +0000 Subject: [PATCH 1/2] sparse_ops_gpu.cpp: gate cache and contig_warps for backward by dtype + workload size Compute enable_cache_and_contig_for_bwd in the forward and pass it through saved_data to the backward. The flag is on for fp16/bf16 unconditionally (software CAS-loop atomicAdd makes the cache+contig wins large), and for fp32 only when num_total_indices >= the lower sort threshold (cache+contig per-warp overhead is otherwise not offset by the savings). The gating intentionally uses only the lower sort threshold, not the upper one: when the workload is too large for sort to be worthwhile, cache+contig still help because the kernel runtime dominates the per-warp overhead. --- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 29 +++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index da46b95100..b76741e167 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -429,11 +429,11 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( // place for optimization. constexpr size_t kSortIndicesUpperThreshold = 15'000'000; - // Sorting only pays off when there are enough indices to amortize + // Sorting only pays off when there are enough indices to amortize // the sorting cost, and the crossover point depends on the dtype. constexpr size_t kSortIndicesLowerThresholdLowPrec = 1'000'000; constexpr size_t kSortIndicesLowerThresholdFullPrec = 2'000'000; - + const bool is_low_precision = first_input.dtype().itemsize() <= 2; const size_t kSortIndicesLowerThreshold = is_low_precision ? kSortIndicesLowerThresholdLowPrec @@ -441,8 +441,22 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( const bool use_sorted_indices_for_bwd = (num_total_indices >= kSortIndicesLowerThreshold) && (num_total_indices < kSortIndicesUpperThreshold); + + // Cache+contig in the backward kernel pay off whenever the per-warp + // sequence of indices has runs of duplicates or spatial locality. For + // low-precision dtypes (fp16/bf16) the wins are large -- atomicAdd is a + // software CAS-loop emulation -- so we enable them unconditionally. For + // fp32 the hardware-native atomicAdd is cheap, so they only pay off when + // the workload is large enough to amortize the per-warp cache logic + // overhead. We use the same lower threshold as the sort heuristic, but + // intentionally ignore the upper threshold: when the workload is too + // large for sort to be worthwhile, cache+contig still help because the + // kernel runtime dominates. + const bool enable_cache_and_contig_for_bwd = + is_low_precision || (num_total_indices >= kSortIndicesLowerThreshold); #else const bool use_sorted_indices_for_bwd = false; + const bool enable_cache_and_contig_for_bwd = false; (void)num_total_indices; #endif @@ -474,6 +488,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( reinterpret_cast(warp_offsets_group), reinterpret_cast(num_cols_group), warp_offset, + enable_cache_and_contig_for_bwd, }; auto saved_data_t = at::empty( {sizeof(saved_data) / sizeof(int64_t)}, @@ -560,6 +575,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( const int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[5]); int64_t total_num_warps = saved_data_ptr[6]; + const bool enable_cache_and_contig = saved_data_ptr[7]; // We checked in forward that all output rows are the same for all member // in the group @@ -769,8 +785,13 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true); #ifdef USE_ROCM - constexpr bool use_contiguous_warps = true; - constexpr bool use_cache = true; + // enable_cache_and_contig is computed in the forward (see the heuristic + // there): true for fp16/bf16 always, and for fp32 when total_indices is + // above the lower threshold. It is intentionally decoupled from + // use_sorted_indices so that cache+contig stay on for large fp32 + // workloads even when sort is disabled by the upper threshold. + const bool use_contiguous_warps = enable_cache_and_contig; + const bool use_cache = enable_cache_and_contig; #else constexpr bool use_contiguous_warps = false; constexpr bool use_cache = false; From df04cd45ed7de61f9f477b5272fd6a342c2a3fce Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Wed, 20 May 2026 16:09:38 +0000 Subject: [PATCH 2/2] sparse_ops_gpu.cpp: edits comments for cache+contig gating --- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index b76741e167..31707039ef 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -442,16 +442,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( (num_total_indices >= kSortIndicesLowerThreshold) && (num_total_indices < kSortIndicesUpperThreshold); - // Cache+contig in the backward kernel pay off whenever the per-warp - // sequence of indices has runs of duplicates or spatial locality. For - // low-precision dtypes (fp16/bf16) the wins are large -- atomicAdd is a - // software CAS-loop emulation -- so we enable them unconditionally. For - // fp32 the hardware-native atomicAdd is cheap, so they only pay off when - // the workload is large enough to amortize the per-warp cache logic - // overhead. We use the same lower threshold as the sort heuristic, but - // intentionally ignore the upper threshold: when the workload is too - // large for sort to be worthwhile, cache+contig still help because the - // kernel runtime dominates. + // Only use caching and contiguous warp dispatch when there are sufficiently many + // indices (for fp32). Always on for fp16 const bool enable_cache_and_contig_for_bwd = is_low_precision || (num_total_indices >= kSortIndicesLowerThreshold); #else @@ -786,10 +778,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( #ifdef USE_ROCM // enable_cache_and_contig is computed in the forward (see the heuristic - // there): true for fp16/bf16 always, and for fp32 when total_indices is - // above the lower threshold. It is intentionally decoupled from - // use_sorted_indices so that cache+contig stay on for large fp32 - // workloads even when sort is disabled by the upper threshold. + // there) const bool use_contiguous_warps = enable_cache_and_contig; const bool use_cache = enable_cache_and_contig; #else