From e0dd7f88c005da5fc1b0b90fe4e0a15d3c970f84 Mon Sep 17 00:00:00 2001 From: Li Li Date: Fri, 7 Nov 2025 21:54:01 +0000 Subject: [PATCH 1/8] opt group_index_select_or_add_2d_kernel --- .../src/sparse_ops/sparse_group_index.cu | 251 ++++++++++++++---- fbgemm_gpu/test/sparse/index_select_test.py | 30 ++- 2 files changed, 234 insertions(+), 47 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index 96c57cde68..c05c99de13 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -51,59 +51,218 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int64_t num_work_rows, // number of rows to work on per member const int64_t group_size) { const auto total_num_warps = warp_offsets_group[group_size]; - int32_t num_cols = 0; - int32_t warps_per_row = 0; - - if constexpr (!USE_VAR_COLS) { - num_cols = num_cols_group[0]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; - } + // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch. + if (USE_INDEX_SELECT) { + for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; + warp_id < total_num_warps; + warp_id += gridDim.x * blockDim.y) { + int32_t member_id, member_warp_id, num_cols, warps_per_row; + if (USE_VAR_COLS) { + __shared__ int member_ids[kMaxThreads / kWarpSize]; + if (threadIdx.x == 0) { + binary_search_range( + &member_ids[threadIdx.y], + warp_offsets_group + 1, + warp_id, + group_size); + } + syncwarp(); + member_id = member_ids[threadIdx.y]; + num_cols = num_cols_group[member_id]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_warp_id = warp_id - warp_offsets_group[member_id]; + } else { + // All columns are the same + num_cols = num_cols_group[0]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_id = warp_id / (warps_per_row * num_work_rows); + member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); + } + const auto row = member_warp_id / warps_per_row; + const auto col_offset = + ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + + (threadIdx.x * UNROLL_FACTOR); + scalar_t* input = + reinterpret_cast(input_ptrs[member_id]) + col_offset; + scalar_t* output = + reinterpret_cast(output_ptrs[member_id]) + col_offset; - for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; - warp_id < total_num_warps; - warp_id += gridDim.x * blockDim.y) { - int32_t member_id = 0; - int32_t member_warp_id = 0; - if constexpr (USE_VAR_COLS) { - __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; - if (threadIdx.x == 0) { - binary_search_range( - &member_ids[threadIdx.y], - warp_offsets_group + 1, - warp_id, - group_size); + index_t* indices = reinterpret_cast(indices_ptrs[member_id]); + const index_t idx = indices[row]; +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { + output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); } - syncwarp(); - member_id = member_ids[threadIdx.y]; - num_cols = num_cols_group[member_id]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; - member_warp_id = warp_id - warp_offsets_group[member_id]; - } else { - // All columns are the same - member_id = warp_id / (warps_per_row * num_work_rows); - member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); } - const auto row = member_warp_id / warps_per_row; - const auto col_offset = - ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + - (threadIdx.x * UNROLL_FACTOR); - scalar_t* input = - reinterpret_cast(input_ptrs[member_id]) + col_offset; - scalar_t* output = - reinterpret_cast(output_ptrs[member_id]) + col_offset; - - index_t* indices = reinterpret_cast(indices_ptrs[member_id]); - const index_t idx = indices[row]; + } else { + // Cache a handful of scatter destinations per warp so we can merge + // consecutive updates that hit the same index before touching global memory. + constexpr int kCacheSlots = 2; + index_t cached_idx[kCacheSlots]; + scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR]; + bool cached_valid[kCacheSlots]; #pragma unroll - for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { - // Compile time conditional - if constexpr (USE_INDEX_SELECT) { - output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); + for (int slot = 0; slot < kCacheSlots; ++slot) { + cached_valid[slot] = false; + } + int32_t active_member_id = -1; + int32_t active_num_cols = 0; + int32_t active_col_offset = -1; + scalar_t* active_input_base = nullptr; + scalar_t* active_output_base = nullptr; + index_t* active_indices = nullptr; + + auto flush_cache = [&](scalar_t* out_base, + int32_t num_cols, + int32_t col_offset) { + if (!out_base) { + return; + } +#pragma unroll + for (int slot = 0; slot < kCacheSlots; ++slot) { + if (!cached_valid[slot]) { + continue; + } + const int64_t row_offset = + static_cast(cached_idx[slot]) * num_cols; +#pragma unroll + for (int j = 0; j < UNROLL_FACTOR; ++j) { + const int32_t col = col_offset + j; + if (col >= num_cols) { + break; + } + gpuAtomicAddNoReturn( + out_base + row_offset + col, cached_vals[slot][j]); + } + cached_valid[slot] = false; + } + }; + + for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; + warp_id < total_num_warps; + warp_id += gridDim.x * blockDim.y) { + int32_t member_id, member_warp_id, num_cols, warps_per_row; + if (USE_VAR_COLS) { + __shared__ int member_ids[kMaxThreads / kWarpSize]; + if (threadIdx.x == 0) { + binary_search_range( + &member_ids[threadIdx.y], + warp_offsets_group + 1, + warp_id, + group_size); + } + syncwarp(); + member_id = member_ids[threadIdx.y]; + num_cols = num_cols_group[member_id]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_warp_id = warp_id - warp_offsets_group[member_id]; } else { - gpuAtomicAddNoReturn( - &output[idx * num_cols + i], input[row * num_cols + i]); + // All columns are the same + num_cols = num_cols_group[0]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_id = warp_id / (warps_per_row * num_work_rows); + member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); + } + const int64_t row = member_warp_id / warps_per_row; + const int32_t col_offset = + static_cast(((member_warp_id % warps_per_row) + << LOG_COLS_PER_WARP) + + (threadIdx.x * UNROLL_FACTOR)); + + const bool member_changed = member_id != active_member_id; + const bool num_cols_changed = + member_changed ? false : (num_cols != active_num_cols); + const bool col_changed = + member_changed ? false : (col_offset != active_col_offset); + if (member_changed || num_cols_changed || col_changed) { + flush_cache(active_output_base, active_num_cols, active_col_offset); + active_member_id = member_id; + active_num_cols = num_cols; + active_col_offset = col_offset; + active_input_base = + reinterpret_cast(input_ptrs[member_id]); + active_output_base = + reinterpret_cast(output_ptrs[member_id]); + active_indices = + reinterpret_cast(indices_ptrs[member_id]); + } + + if (col_offset >= active_num_cols) { + continue; + } + + const index_t idx = active_indices[row]; + + scalar_t local_vals[UNROLL_FACTOR]; +#pragma unroll + for (int j = 0; j < UNROLL_FACTOR; ++j) { + local_vals[j] = static_cast(0); + } + const int64_t input_offset = + static_cast(row) * active_num_cols + active_col_offset; +#pragma unroll + for (int j = 0; j < UNROLL_FACTOR; ++j) { + const int32_t col = active_col_offset + j; + if (col >= active_num_cols) { + break; + } + local_vals[j] = active_input_base[input_offset + j]; + } + + bool appended = false; +#pragma unroll + for (int slot = 0; slot < kCacheSlots; ++slot) { + if (cached_valid[slot] && cached_idx[slot] == idx) { +#pragma unroll + for (int j = 0; j < UNROLL_FACTOR; ++j) { + const int32_t col = active_col_offset + j; + if (col >= active_num_cols) { + break; + } + cached_vals[slot][j] += local_vals[j]; + } + appended = true; + break; + } + } + + if (!appended) { + int slot_to_use = -1; +#pragma unroll + for (int slot = 0; slot < kCacheSlots; ++slot) { + if (!cached_valid[slot]) { + slot_to_use = slot; + break; + } + } + if (slot_to_use == -1) { + slot_to_use = 0; + const int64_t row_offset = + static_cast(cached_idx[slot_to_use]) * + active_num_cols; +#pragma unroll + for (int j = 0; j < UNROLL_FACTOR; ++j) { + const int32_t col = active_col_offset + j; + if (col >= active_num_cols) { + break; + } + gpuAtomicAddNoReturn( + active_output_base + row_offset + col, + cached_vals[slot_to_use][j]); + } + cached_valid[slot_to_use] = false; + } + + cached_idx[slot_to_use] = idx; +#pragma unroll + for (int j = 0; j < UNROLL_FACTOR; ++j) { + cached_vals[slot_to_use][j] = local_vals[j]; + } + cached_valid[slot_to_use] = true; } } + + flush_cache(active_output_base, active_num_cols, active_col_offset); } } diff --git a/fbgemm_gpu/test/sparse/index_select_test.py b/fbgemm_gpu/test/sparse/index_select_test.py index 6c61b77bf8..ff4b264e05 100644 --- a/fbgemm_gpu/test/sparse/index_select_test.py +++ b/fbgemm_gpu/test/sparse/index_select_test.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 + #!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # @@ -239,6 +239,34 @@ def compare_tensor_groups( {"rtol": 1e-02, "atol": 1e-02} if dtype == torch.half else {}, ) + @unittest.skipIf(not gpu_available, "CUDA not available") + def test_group_index_select_dim0_duplicate_gradients(self) -> None: + device = torch.device("cuda") + dtype = torch.float + + num_rows = 4 + num_cols = 9 + indices = torch.tensor([0, 1, 2, 1, 0, 2], dtype=torch.long, device=device) + + input_tensor = torch.randn( + (num_rows, num_cols), dtype=dtype, device=device + ).requires_grad_(True) + + output_group = torch.ops.fbgemm.group_index_select_dim0( + [input_tensor], [indices] + ) + output = output_group[0] + + grad = torch.arange( + output.numel(), dtype=dtype, device=device + ).view_as(output) + output.backward(grad) + + ref_grad = torch.zeros_like(input_tensor) + ref_grad.index_add_(0, indices, grad) + + torch.testing.assert_close(input_tensor.grad, ref_grad) + @given( num_inputs=st.integers(0, 100), max_input_rows=st.integers(2, 32), From c28c65e3f43cd17a95b5ff41fc46a732b3448caf Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Tue, 11 Nov 2025 19:06:00 +0000 Subject: [PATCH 2/8] removed EMULATED_WARP_SIZE --- fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index c05c99de13..dc69df2a19 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -12,18 +12,10 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { -#ifdef USE_ROCM -// The wave size is forced to be 32 on ROCm devices in favor -// of granularity losses reduction. -constexpr int EMULATED_WARP_SIZE = 32; -#else -constexpr int EMULATED_WARP_SIZE = kWarpSize; -#endif - // TODO: Update UNROLL_FACTOR constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1; constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP = - GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE; + GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize; // GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP = @@ -287,13 +279,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda( at::cuda::OptionalCUDAGuard device_guard(device); // Partition work based on num_work_rows - uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE; + uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize; uint32_t max_grid_size = at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8; uint32_t grid_size = std::min( cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock), max_grid_size); - dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1); + dim3 block_size(kWarpSize, num_warps_per_threadblock, 1); #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \ FBGEMM_LAUNCH_KERNEL( \ @@ -340,4 +332,4 @@ DLL_PUBLIC void group_index_select_or_add_cuda( #undef INVOKE_GROUP_INDEX_SELECT_OR_ADD } -} // namespace fbgemm_gpu +} // namespace fbgemm_gpu \ No newline at end of file From 55ab1c4cb33ac2238b9c4d2d320343d8559738e5 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Tue, 11 Nov 2025 19:40:46 +0000 Subject: [PATCH 3/8] added rocm guards to index_select kernel --- .../src/sparse_ops/sparse_group_index.cu | 51 ++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index dc69df2a19..f00cbcf0d2 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -43,7 +43,10 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int64_t num_work_rows, // number of rows to work on per member const int64_t group_size) { const auto total_num_warps = warp_offsets_group[group_size]; + + // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch. + #ifdef USE_ROCM if (USE_INDEX_SELECT) { for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; warp_id < total_num_warps; @@ -63,10 +66,42 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( num_cols = num_cols_group[member_id]; warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_warp_id = warp_id - warp_offsets_group[member_id]; - } else { + } +#else + int32_t num_cols = 0; + int32_t warps_per_row = 0; + + if constexpr (!USE_VAR_COLS) { + num_cols = num_cols_group[0]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + } + + for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; + warp_id < total_num_warps; + warp_id += gridDim.x * blockDim.y) { + int32_t member_id = 0; + int32_t member_warp_id = 0; + if constexpr (USE_VAR_COLS) { + __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; + if (threadIdx.x == 0) { + binary_search_range( + &member_ids[threadIdx.y], + warp_offsets_group + 1, + warp_id, + group_size); + } + syncwarp(); + member_id = member_ids[threadIdx.y]; + num_cols = num_cols_group[member_id]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_warp_id = warp_id - warp_offsets_group[member_id]; +#endif + else { // All columns are the same + #ifdef USE_ROCM num_cols = num_cols_group[0]; warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + #endif member_id = warp_id / (warps_per_row * num_work_rows); member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); } @@ -83,10 +118,15 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const index_t idx = indices[row]; #pragma unroll for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { + #ifndef USE_ROCM + if constexpr (USE_INDEX_SELECT) { + #endif output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); +#ifdef USE_ROCM } } } else { + // Cache a handful of scatter destinations per warp so we can merge // consecutive updates that hit the same index before touching global memory. constexpr int kCacheSlots = 2; @@ -148,7 +188,9 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( num_cols = num_cols_group[member_id]; warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_warp_id = warp_id - warp_offsets_group[member_id]; +#endif } else { +#ifdef USE_ROCM // All columns are the same num_cols = num_cols_group[0]; warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; @@ -257,6 +299,13 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( flush_cache(active_output_base, active_num_cols, active_col_offset); } } +#else + gpuAtomicAddNoReturn( + &output[idx * num_cols + i], input[row * num_cols + i]); + } + } +#endif + DLL_PUBLIC void group_index_select_or_add_cuda( const int64_t* input_ptrs, From 34d4ca7fa4f5f3c9b6a4448bfb5399a73eb1b96d Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Tue, 11 Nov 2025 20:28:02 +0000 Subject: [PATCH 4/8] removed group index test --- fbgemm_gpu/test/sparse/index_select_test.py | 32 ++------------------- 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/fbgemm_gpu/test/sparse/index_select_test.py b/fbgemm_gpu/test/sparse/index_select_test.py index ff4b264e05..c057297ff9 100644 --- a/fbgemm_gpu/test/sparse/index_select_test.py +++ b/fbgemm_gpu/test/sparse/index_select_test.py @@ -1,4 +1,4 @@ - #!/usr/bin/env python3 +#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # @@ -239,35 +239,7 @@ def compare_tensor_groups( {"rtol": 1e-02, "atol": 1e-02} if dtype == torch.half else {}, ) - @unittest.skipIf(not gpu_available, "CUDA not available") - def test_group_index_select_dim0_duplicate_gradients(self) -> None: - device = torch.device("cuda") - dtype = torch.float - - num_rows = 4 - num_cols = 9 - indices = torch.tensor([0, 1, 2, 1, 0, 2], dtype=torch.long, device=device) - - input_tensor = torch.randn( - (num_rows, num_cols), dtype=dtype, device=device - ).requires_grad_(True) - - output_group = torch.ops.fbgemm.group_index_select_dim0( - [input_tensor], [indices] - ) - output = output_group[0] - - grad = torch.arange( - output.numel(), dtype=dtype, device=device - ).view_as(output) - output.backward(grad) - - ref_grad = torch.zeros_like(input_tensor) - ref_grad.index_add_(0, indices, grad) - - torch.testing.assert_close(input_tensor.grad, ref_grad) - - @given( + @given( num_inputs=st.integers(0, 100), max_input_rows=st.integers(2, 32), max_cols_factor=st.integers(2, 256), From c33702c4f6382b946e06ccdf53ce2bd0c501d843 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Tue, 11 Nov 2025 22:05:04 +0000 Subject: [PATCH 5/8] added GROUP_INDEX_SELECT_COLS_PER_WARP swap logic for fwd and bwd --- .../src/sparse_ops/sparse_group_index.cu | 91 ++++++++++++------- 1 file changed, 56 insertions(+), 35 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index f00cbcf0d2..dcf3028b0a 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -12,12 +12,19 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { +#ifdef USE_ROCM +// The wave size is forced to be 32 on ROCm devices in favor +// of granularity losses reduction. +constexpr int EMULATED_WARP_SIZE = 32; +#else +constexpr int EMULATED_WARP_SIZE = kWarpSize; +#endif + // TODO: Update UNROLL_FACTOR constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1; constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP = - GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize; + GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE; -// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP = log2_calc::value; @@ -25,6 +32,20 @@ int get_group_index_select_cols_per_warp() { return GROUP_INDEX_SELECT_COLS_PER_WARP; } +#ifdef USE_ROCM +constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_ADD = + GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize; +constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_ADD = + log2_calc::value; +#else +constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_ADD = + GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE; +constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_ADD = + log2_calc::value; +#endif + + + template < typename index_t, typename scalar_t, @@ -42,18 +63,18 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int32_t* num_cols_group, const int64_t num_work_rows, // number of rows to work on per member const int64_t group_size) { - const auto total_num_warps = warp_offsets_group[group_size]; - - + const auto total_num_warps = warp_offsets_group[group_size]; // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch. #ifdef USE_ROCM if (USE_INDEX_SELECT) { for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; warp_id < total_num_warps; warp_id += gridDim.x * blockDim.y) { - int32_t member_id, member_warp_id, num_cols, warps_per_row; + int32_t num_cols, warps_per_row; + int32_t member_id = 0; + int32_t member_warp_id = 0; if (USE_VAR_COLS) { - __shared__ int member_ids[kMaxThreads / kWarpSize]; + __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; if (threadIdx.x == 0) { binary_search_range( &member_ids[threadIdx.y], @@ -67,35 +88,35 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_warp_id = warp_id - warp_offsets_group[member_id]; } -#else - int32_t num_cols = 0; - int32_t warps_per_row = 0; + #else + int32_t num_cols = 0; + int32_t warps_per_row = 0; - if constexpr (!USE_VAR_COLS) { - num_cols = num_cols_group[0]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; - } + if constexpr (!USE_VAR_COLS) { + num_cols = num_cols_group[0]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + } - for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; + for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; warp_id < total_num_warps; warp_id += gridDim.x * blockDim.y) { - int32_t member_id = 0; - int32_t member_warp_id = 0; - if constexpr (USE_VAR_COLS) { - __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; - if (threadIdx.x == 0) { - binary_search_range( - &member_ids[threadIdx.y], - warp_offsets_group + 1, - warp_id, - group_size); - } - syncwarp(); - member_id = member_ids[threadIdx.y]; - num_cols = num_cols_group[member_id]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; - member_warp_id = warp_id - warp_offsets_group[member_id]; -#endif + int32_t member_id = 0; + int32_t member_warp_id = 0; + if constexpr (USE_VAR_COLS) { + __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; + if (threadIdx.x == 0) { + binary_search_range( + &member_ids[threadIdx.y], + warp_offsets_group + 1, + warp_id, + group_size); + } + syncwarp(); + member_id = member_ids[threadIdx.y]; + num_cols = num_cols_group[member_id]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_warp_id = warp_id - warp_offsets_group[member_id]; + #endif else { // All columns are the same #ifdef USE_ROCM @@ -126,9 +147,9 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( } } } else { - // Cache a handful of scatter destinations per warp so we can merge // consecutive updates that hit the same index before touching global memory. + constexpr int kCacheSlots = 2; index_t cached_idx[kCacheSlots]; scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR]; @@ -344,8 +365,8 @@ DLL_PUBLIC void group_index_select_or_add_cuda( USE_INDEX_SELECT, \ USE_VAR_COLS, \ GROUP_INDEX_SELECT_UNROLL_FACTOR, \ - GROUP_INDEX_SELECT_COLS_PER_WARP, \ - GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \ + USE_INDEX_SELECT ? GROUP_INDEX_SELECT_COLS_PER_WARP : GROUP_INDEX_SELECT_COLS_PER_WARP_ADD, \ + USE_INDEX_SELECT ? GROUP_INDEX_SELECT_LOG_COLS_PER_WARP : GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_ADD>), \ grid_size, \ block_size, \ 0, \ From cca797b5ef97eb79a4150651c33cc97a17bcb8a9 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Wed, 12 Nov 2025 01:41:05 +0000 Subject: [PATCH 6/8] added seperate var for add and select mode --- .../src/sparse_ops/sparse_group_index.cu | 149 +++++++++--------- 1 file changed, 78 insertions(+), 71 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index dcf3028b0a..641ea11b27 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -13,7 +13,7 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { #ifdef USE_ROCM -// The wave size is forced to be 32 on ROCm devices in favor +// The wave size is forced to be 32 on ROCm devices in favor // of granularity losses reduction. constexpr int EMULATED_WARP_SIZE = 32; #else @@ -22,29 +22,29 @@ constexpr int EMULATED_WARP_SIZE = kWarpSize; // TODO: Update UNROLL_FACTOR constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1; -constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP = + +// SELECT (fwd): use EMULATED_WARP_SIZE +constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_FWD = GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE; +constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_FWD = + log2_calc::value; + +// ADD (bwd): use kWarpSize +constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_BWD = + GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize; +constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_BWD = + log2_calc::value; -constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP = - log2_calc::value; int get_group_index_select_cols_per_warp() { - return GROUP_INDEX_SELECT_COLS_PER_WARP; + return GROUP_INDEX_SELECT_COLS_PER_WARP_BWD; } -#ifdef USE_ROCM -constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_ADD = - GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize; -constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_ADD = - log2_calc::value; -#else -constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_ADD = - GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE; -constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_ADD = - log2_calc::value; -#endif - - +// New: explicit selector +int get_group_index_select_cols_per_warp(bool use_index_select) { + return use_index_select ? GROUP_INDEX_SELECT_COLS_PER_WARP_FWD + : GROUP_INDEX_SELECT_COLS_PER_WARP_BWD; +} template < typename index_t, @@ -63,16 +63,15 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int32_t* num_cols_group, const int64_t num_work_rows, // number of rows to work on per member const int64_t group_size) { - const auto total_num_warps = warp_offsets_group[group_size]; + const auto total_num_warps = warp_offsets_group[group_size]; + +#ifdef USE_ROCM // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch. - #ifdef USE_ROCM if (USE_INDEX_SELECT) { for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; warp_id < total_num_warps; warp_id += gridDim.x * blockDim.y) { - int32_t num_cols, warps_per_row; - int32_t member_id = 0; - int32_t member_warp_id = 0; + int32_t member_id, member_warp_id, num_cols, warps_per_row; if (USE_VAR_COLS) { __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; if (threadIdx.x == 0) { @@ -87,42 +86,10 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( num_cols = num_cols_group[member_id]; warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_warp_id = warp_id - warp_offsets_group[member_id]; - } - #else - int32_t num_cols = 0; - int32_t warps_per_row = 0; - - if constexpr (!USE_VAR_COLS) { - num_cols = num_cols_group[0]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; - } - - for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; - warp_id < total_num_warps; - warp_id += gridDim.x * blockDim.y) { - int32_t member_id = 0; - int32_t member_warp_id = 0; - if constexpr (USE_VAR_COLS) { - __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; - if (threadIdx.x == 0) { - binary_search_range( - &member_ids[threadIdx.y], - warp_offsets_group + 1, - warp_id, - group_size); - } - syncwarp(); - member_id = member_ids[threadIdx.y]; - num_cols = num_cols_group[member_id]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; - member_warp_id = warp_id - warp_offsets_group[member_id]; - #endif - else { + } else { // All columns are the same - #ifdef USE_ROCM num_cols = num_cols_group[0]; warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; - #endif member_id = warp_id / (warps_per_row * num_work_rows); member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); } @@ -139,17 +106,12 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const index_t idx = indices[row]; #pragma unroll for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { - #ifndef USE_ROCM - if constexpr (USE_INDEX_SELECT) { - #endif output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); -#ifdef USE_ROCM } } } else { // Cache a handful of scatter destinations per warp so we can merge // consecutive updates that hit the same index before touching global memory. - constexpr int kCacheSlots = 2; index_t cached_idx[kCacheSlots]; scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR]; @@ -209,9 +171,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( num_cols = num_cols_group[member_id]; warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_warp_id = warp_id - warp_offsets_group[member_id]; -#endif } else { -#ifdef USE_ROCM // All columns are the same num_cols = num_cols_group[0]; warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; @@ -319,14 +279,56 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( flush_cache(active_output_base, active_num_cols, active_col_offset); } -} -#else - gpuAtomicAddNoReturn( +#else // Original CUDA implementation + for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; + warp_id < total_num_warps; + warp_id += gridDim.x * blockDim.y) { + int32_t member_id, member_warp_id, num_cols, warps_per_row; + if (USE_VAR_COLS) { + __shared__ int member_ids[kMaxThreads / kWarpSize]; + if (threadIdx.x == 0) { + binary_search_range( + &member_ids[threadIdx.y], + warp_offsets_group + 1, + warp_id, + group_size); + } + syncwarp(); + member_id = member_ids[threadIdx.y]; + num_cols = num_cols_group[member_id]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_warp_id = warp_id - warp_offsets_group[member_id]; + } else { + // All columns are the same + num_cols = num_cols_group[0]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_id = warp_id / (warps_per_row * num_work_rows); + member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); + } + const auto row = member_warp_id / warps_per_row; + const auto col_offset = + ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + + (threadIdx.x * UNROLL_FACTOR); + scalar_t* input = + reinterpret_cast(input_ptrs[member_id]) + col_offset; + scalar_t* output = + reinterpret_cast(output_ptrs[member_id]) + col_offset; + + index_t* indices = reinterpret_cast(indices_ptrs[member_id]); + const index_t idx = indices[row]; +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { + // Compile time conditional + if (USE_INDEX_SELECT) { + output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); + } else { + gpuAtomicAddNoReturn( &output[idx * num_cols + i], input[row * num_cols + i]); + } } } -#endif - +#endif // USE_ROCM +} DLL_PUBLIC void group_index_select_or_add_cuda( const int64_t* input_ptrs, @@ -349,7 +351,7 @@ DLL_PUBLIC void group_index_select_or_add_cuda( at::cuda::OptionalCUDAGuard device_guard(device); // Partition work based on num_work_rows - uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize; + uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize; uint32_t max_grid_size = at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8; uint32_t grid_size = std::min( @@ -365,8 +367,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda( USE_INDEX_SELECT, \ USE_VAR_COLS, \ GROUP_INDEX_SELECT_UNROLL_FACTOR, \ - USE_INDEX_SELECT ? GROUP_INDEX_SELECT_COLS_PER_WARP : GROUP_INDEX_SELECT_COLS_PER_WARP_ADD, \ - USE_INDEX_SELECT ? GROUP_INDEX_SELECT_LOG_COLS_PER_WARP : GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_ADD>), \ + (USE_INDEX_SELECT \ + ? GROUP_INDEX_SELECT_COLS_PER_WARP_FWD \ + : GROUP_INDEX_SELECT_COLS_PER_WARP_BWD), \ + /* LOG_COLS_PER_WARP */ \ + (USE_INDEX_SELECT \ + ? GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_FWD \ + : GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_BWD)>), \ grid_size, \ block_size, \ 0, \ From c87df8429108c36ed93a096f254ba84a31f8048c Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Wed, 12 Nov 2025 02:10:53 +0000 Subject: [PATCH 7/8] removed comments --- fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 14 +++----------- fbgemm_gpu/test/sparse/index_select_test.py | 2 +- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index 641ea11b27..729331dd26 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -35,12 +35,10 @@ constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_BWD = constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_BWD = log2_calc::value; - int get_group_index_select_cols_per_warp() { return GROUP_INDEX_SELECT_COLS_PER_WARP_BWD; } -// New: explicit selector int get_group_index_select_cols_per_warp(bool use_index_select) { return use_index_select ? GROUP_INDEX_SELECT_COLS_PER_WARP_FWD : GROUP_INDEX_SELECT_COLS_PER_WARP_BWD; @@ -61,7 +59,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int64_t* indices_ptrs, const int64_t* warp_offsets_group, const int32_t* num_cols_group, - const int64_t num_work_rows, // number of rows to work on per member + const int64_t num_work_rows, const int64_t group_size) { const auto total_num_warps = warp_offsets_group[group_size]; @@ -87,7 +85,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_warp_id = warp_id - warp_offsets_group[member_id]; } else { - // All columns are the same num_cols = num_cols_group[0]; warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_id = warp_id / (warps_per_row * num_work_rows); @@ -110,8 +107,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( } } } else { - // Cache a handful of scatter destinations per warp so we can merge - // consecutive updates that hit the same index before touching global memory. constexpr int kCacheSlots = 2; index_t cached_idx[kCacheSlots]; scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR]; @@ -299,7 +294,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_warp_id = warp_id - warp_offsets_group[member_id]; } else { - // All columns are the same num_cols = num_cols_group[0]; warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_id = warp_id / (warps_per_row * num_work_rows); @@ -327,7 +321,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( } } } -#endif // USE_ROCM +#endif } DLL_PUBLIC void group_index_select_or_add_cuda( @@ -349,9 +343,7 @@ DLL_PUBLIC void group_index_select_or_add_cuda( } at::cuda::OptionalCUDAGuard device_guard(device); - - // Partition work based on num_work_rows - uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize; + uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize; uint32_t max_grid_size = at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8; uint32_t grid_size = std::min( diff --git a/fbgemm_gpu/test/sparse/index_select_test.py b/fbgemm_gpu/test/sparse/index_select_test.py index c057297ff9..6c61b77bf8 100644 --- a/fbgemm_gpu/test/sparse/index_select_test.py +++ b/fbgemm_gpu/test/sparse/index_select_test.py @@ -239,7 +239,7 @@ def compare_tensor_groups( {"rtol": 1e-02, "atol": 1e-02} if dtype == torch.half else {}, ) - @given( + @given( num_inputs=st.integers(0, 100), max_input_rows=st.integers(2, 32), max_cols_factor=st.integers(2, 256), From 735b8035ae3de118413e26fa29f1c772dbc6be00 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Wed, 12 Nov 2025 02:33:53 +0000 Subject: [PATCH 8/8] fix grid and block size --- .../src/sparse_ops/sparse_group_index.cu | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index 729331dd26..722c204dd8 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -62,10 +62,11 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int64_t num_work_rows, const int64_t group_size) { const auto total_num_warps = warp_offsets_group[group_size]; - + int32_t num_cols = 0; + int32_t warps_per_row = 0; #ifdef USE_ROCM // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch. - if (USE_INDEX_SELECT) { + if constexpr (USE_INDEX_SELECT) { for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; warp_id < total_num_warps; warp_id += gridDim.x * blockDim.y) { @@ -343,13 +344,25 @@ DLL_PUBLIC void group_index_select_or_add_cuda( } at::cuda::OptionalCUDAGuard device_guard(device); - uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize; + uint32_t num_warps_per_threadblock; + dim3 block_size; + + if (use_index_select) { + // Forward pass uses EMULATED_WARP_SIZE + num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE; + block_size = dim3(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1); + } else { + // Backward pass uses kWarpSize + num_warps_per_threadblock = kMaxThreads / kWarpSize; + block_size = dim3(kWarpSize, num_warps_per_threadblock, 1); + } + uint32_t max_grid_size = at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8; uint32_t grid_size = std::min( cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock), max_grid_size); - dim3 block_size(kWarpSize, num_warps_per_threadblock, 1); + #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \ FBGEMM_LAUNCH_KERNEL( \