From e0dd7f88c005da5fc1b0b90fe4e0a15d3c970f84 Mon Sep 17 00:00:00 2001
From: Li Li
Date: Fri, 7 Nov 2025 21:54:01 +0000
Subject: [PATCH 1/8] optimize group_index_select_or_add_2d_kernel
---
.../src/sparse_ops/sparse_group_index.cu | 251 ++++++++++++++----
fbgemm_gpu/test/sparse/index_select_test.py | 30 ++-
2 files changed, 234 insertions(+), 47 deletions(-)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index 96c57cde68..c05c99de13 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -51,59 +51,218 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
const int64_t num_work_rows, // number of rows to work on per member
const int64_t group_size) {
const auto total_num_warps = warp_offsets_group[group_size];
- int32_t num_cols = 0;
- int32_t warps_per_row = 0;
-
- if constexpr (!USE_VAR_COLS) {
- num_cols = num_cols_group[0];
- warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
- }
+ // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch.
+ if (USE_INDEX_SELECT) {
+ for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+ warp_id < total_num_warps;
+ warp_id += gridDim.x * blockDim.y) {
+ int32_t member_id, member_warp_id, num_cols, warps_per_row;
+ if (USE_VAR_COLS) {
+ __shared__ int member_ids[kMaxThreads / kWarpSize];
+ if (threadIdx.x == 0) {
+ binary_search_range(
+ &member_ids[threadIdx.y],
+ warp_offsets_group + 1,
+ warp_id,
+ group_size);
+ }
+ syncwarp();
+ member_id = member_ids[threadIdx.y];
+ num_cols = num_cols_group[member_id];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_warp_id = warp_id - warp_offsets_group[member_id];
+ } else {
+ // All columns are the same
+ num_cols = num_cols_group[0];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_id = warp_id / (warps_per_row * num_work_rows);
+ member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+ }
+ const auto row = member_warp_id / warps_per_row;
+ const auto col_offset =
+ ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
+ (threadIdx.x * UNROLL_FACTOR);
+ scalar_t* input =
+ reinterpret_cast(input_ptrs[member_id]) + col_offset;
+ scalar_t* output =
+ reinterpret_cast(output_ptrs[member_id]) + col_offset;
- for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
- warp_id < total_num_warps;
- warp_id += gridDim.x * blockDim.y) {
- int32_t member_id = 0;
- int32_t member_warp_id = 0;
- if constexpr (USE_VAR_COLS) {
- __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
- if (threadIdx.x == 0) {
- binary_search_range(
- &member_ids[threadIdx.y],
- warp_offsets_group + 1,
- warp_id,
- group_size);
+ index_t* indices = reinterpret_cast(indices_ptrs[member_id]);
+ const index_t idx = indices[row];
+#pragma unroll
+ for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+ output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
}
- syncwarp();
- member_id = member_ids[threadIdx.y];
- num_cols = num_cols_group[member_id];
- warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
- member_warp_id = warp_id - warp_offsets_group[member_id];
- } else {
- // All columns are the same
- member_id = warp_id / (warps_per_row * num_work_rows);
- member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
}
- const auto row = member_warp_id / warps_per_row;
- const auto col_offset =
- ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
- (threadIdx.x * UNROLL_FACTOR);
- scalar_t* input =
- reinterpret_cast(input_ptrs[member_id]) + col_offset;
- scalar_t* output =
- reinterpret_cast(output_ptrs[member_id]) + col_offset;
-
- index_t* indices = reinterpret_cast(indices_ptrs[member_id]);
- const index_t idx = indices[row];
+ } else {
+ // Cache a handful of scatter destinations per warp so we can merge
+ // consecutive updates that hit the same index before touching global memory.
+ constexpr int kCacheSlots = 2;
+ index_t cached_idx[kCacheSlots];
+ scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR];
+ bool cached_valid[kCacheSlots];
#pragma unroll
- for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
- // Compile time conditional
- if constexpr (USE_INDEX_SELECT) {
- output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+ for (int slot = 0; slot < kCacheSlots; ++slot) {
+ cached_valid[slot] = false;
+ }
+ int32_t active_member_id = -1;
+ int32_t active_num_cols = 0;
+ int32_t active_col_offset = -1;
+ scalar_t* active_input_base = nullptr;
+ scalar_t* active_output_base = nullptr;
+ index_t* active_indices = nullptr;
+
+ auto flush_cache = [&](scalar_t* out_base,
+ int32_t num_cols,
+ int32_t col_offset) {
+ if (!out_base) {
+ return;
+ }
+#pragma unroll
+ for (int slot = 0; slot < kCacheSlots; ++slot) {
+ if (!cached_valid[slot]) {
+ continue;
+ }
+ const int64_t row_offset =
+ static_cast(cached_idx[slot]) * num_cols;
+#pragma unroll
+ for (int j = 0; j < UNROLL_FACTOR; ++j) {
+ const int32_t col = col_offset + j;
+ if (col >= num_cols) {
+ break;
+ }
+ gpuAtomicAddNoReturn(
+ out_base + row_offset + col, cached_vals[slot][j]);
+ }
+ cached_valid[slot] = false;
+ }
+ };
+
+ for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+ warp_id < total_num_warps;
+ warp_id += gridDim.x * blockDim.y) {
+ int32_t member_id, member_warp_id, num_cols, warps_per_row;
+ if (USE_VAR_COLS) {
+ __shared__ int member_ids[kMaxThreads / kWarpSize];
+ if (threadIdx.x == 0) {
+ binary_search_range(
+ &member_ids[threadIdx.y],
+ warp_offsets_group + 1,
+ warp_id,
+ group_size);
+ }
+ syncwarp();
+ member_id = member_ids[threadIdx.y];
+ num_cols = num_cols_group[member_id];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_warp_id = warp_id - warp_offsets_group[member_id];
} else {
- gpuAtomicAddNoReturn(
- &output[idx * num_cols + i], input[row * num_cols + i]);
+ // All columns are the same
+ num_cols = num_cols_group[0];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_id = warp_id / (warps_per_row * num_work_rows);
+ member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+ }
+ const int64_t row = member_warp_id / warps_per_row;
+ const int32_t col_offset =
+ static_cast(((member_warp_id % warps_per_row)
+ << LOG_COLS_PER_WARP) +
+ (threadIdx.x * UNROLL_FACTOR));
+
+ const bool member_changed = member_id != active_member_id;
+ const bool num_cols_changed =
+ member_changed ? false : (num_cols != active_num_cols);
+ const bool col_changed =
+ member_changed ? false : (col_offset != active_col_offset);
+ if (member_changed || num_cols_changed || col_changed) {
+ flush_cache(active_output_base, active_num_cols, active_col_offset);
+ active_member_id = member_id;
+ active_num_cols = num_cols;
+ active_col_offset = col_offset;
+ active_input_base =
+ reinterpret_cast(input_ptrs[member_id]);
+ active_output_base =
+ reinterpret_cast(output_ptrs[member_id]);
+ active_indices =
+ reinterpret_cast(indices_ptrs[member_id]);
+ }
+
+ if (col_offset >= active_num_cols) {
+ continue;
+ }
+
+ const index_t idx = active_indices[row];
+
+ scalar_t local_vals[UNROLL_FACTOR];
+#pragma unroll
+ for (int j = 0; j < UNROLL_FACTOR; ++j) {
+ local_vals[j] = static_cast(0);
+ }
+ const int64_t input_offset =
+ static_cast(row) * active_num_cols + active_col_offset;
+#pragma unroll
+ for (int j = 0; j < UNROLL_FACTOR; ++j) {
+ const int32_t col = active_col_offset + j;
+ if (col >= active_num_cols) {
+ break;
+ }
+ local_vals[j] = active_input_base[input_offset + j];
+ }
+
+ bool appended = false;
+#pragma unroll
+ for (int slot = 0; slot < kCacheSlots; ++slot) {
+ if (cached_valid[slot] && cached_idx[slot] == idx) {
+#pragma unroll
+ for (int j = 0; j < UNROLL_FACTOR; ++j) {
+ const int32_t col = active_col_offset + j;
+ if (col >= active_num_cols) {
+ break;
+ }
+ cached_vals[slot][j] += local_vals[j];
+ }
+ appended = true;
+ break;
+ }
+ }
+
+ if (!appended) {
+ int slot_to_use = -1;
+#pragma unroll
+ for (int slot = 0; slot < kCacheSlots; ++slot) {
+ if (!cached_valid[slot]) {
+ slot_to_use = slot;
+ break;
+ }
+ }
+ if (slot_to_use == -1) {
+ slot_to_use = 0;
+ const int64_t row_offset =
+ static_cast(cached_idx[slot_to_use]) *
+ active_num_cols;
+#pragma unroll
+ for (int j = 0; j < UNROLL_FACTOR; ++j) {
+ const int32_t col = active_col_offset + j;
+ if (col >= active_num_cols) {
+ break;
+ }
+ gpuAtomicAddNoReturn(
+ active_output_base + row_offset + col,
+ cached_vals[slot_to_use][j]);
+ }
+ cached_valid[slot_to_use] = false;
+ }
+
+ cached_idx[slot_to_use] = idx;
+#pragma unroll
+ for (int j = 0; j < UNROLL_FACTOR; ++j) {
+ cached_vals[slot_to_use][j] = local_vals[j];
+ }
+ cached_valid[slot_to_use] = true;
}
}
+
+ flush_cache(active_output_base, active_num_cols, active_col_offset);
}
}
diff --git a/fbgemm_gpu/test/sparse/index_select_test.py b/fbgemm_gpu/test/sparse/index_select_test.py
index 6c61b77bf8..ff4b264e05 100644
--- a/fbgemm_gpu/test/sparse/index_select_test.py
+++ b/fbgemm_gpu/test/sparse/index_select_test.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python3
+ #!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -239,6 +239,34 @@ def compare_tensor_groups(
{"rtol": 1e-02, "atol": 1e-02} if dtype == torch.half else {},
)
+ @unittest.skipIf(not gpu_available, "CUDA not available")
+ def test_group_index_select_dim0_duplicate_gradients(self) -> None:
+ device = torch.device("cuda")
+ dtype = torch.float
+
+ num_rows = 4
+ num_cols = 9
+ indices = torch.tensor([0, 1, 2, 1, 0, 2], dtype=torch.long, device=device)
+
+ input_tensor = torch.randn(
+ (num_rows, num_cols), dtype=dtype, device=device
+ ).requires_grad_(True)
+
+ output_group = torch.ops.fbgemm.group_index_select_dim0(
+ [input_tensor], [indices]
+ )
+ output = output_group[0]
+
+ grad = torch.arange(
+ output.numel(), dtype=dtype, device=device
+ ).view_as(output)
+ output.backward(grad)
+
+ ref_grad = torch.zeros_like(input_tensor)
+ ref_grad.index_add_(0, indices, grad)
+
+ torch.testing.assert_close(input_tensor.grad, ref_grad)
+
@given(
num_inputs=st.integers(0, 100),
max_input_rows=st.integers(2, 32),
From c28c65e3f43cd17a95b5ff41fc46a732b3448caf Mon Sep 17 00:00:00 2001
From: Shreyashri Biswas
Date: Tue, 11 Nov 2025 19:06:00 +0000
Subject: [PATCH 2/8] removed EMULATED_WARP_SIZE
---
fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 16 ++++------------
1 file changed, 4 insertions(+), 12 deletions(-)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index c05c99de13..dc69df2a19 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -12,18 +12,10 @@ using Tensor = at::Tensor;
namespace fbgemm_gpu {
-#ifdef USE_ROCM
-// The wave size is forced to be 32 on ROCm devices in favor
-// of granularity losses reduction.
-constexpr int EMULATED_WARP_SIZE = 32;
-#else
-constexpr int EMULATED_WARP_SIZE = kWarpSize;
-#endif
-
// TODO: Update UNROLL_FACTOR
constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
- GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
+ GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
@@ -287,13 +279,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
at::cuda::OptionalCUDAGuard device_guard(device);
// Partition work based on num_work_rows
- uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
+ uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
uint32_t max_grid_size =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
uint32_t grid_size = std::min(
cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
max_grid_size);
- dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);
+ dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
FBGEMM_LAUNCH_KERNEL( \
@@ -340,4 +332,4 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
#undef INVOKE_GROUP_INDEX_SELECT_OR_ADD
}
-} // namespace fbgemm_gpu
+} // namespace fbgemm_gpu
\ No newline at end of file
From 55ab1c4cb33ac2238b9c4d2d320343d8559738e5 Mon Sep 17 00:00:00 2001
From: Shreyashri Biswas
Date: Tue, 11 Nov 2025 19:40:46 +0000
Subject: [PATCH 3/8] added ROCm guards to index_select kernel
---
.../src/sparse_ops/sparse_group_index.cu | 51 ++++++++++++++++++-
1 file changed, 50 insertions(+), 1 deletion(-)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index dc69df2a19..f00cbcf0d2 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -43,7 +43,10 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
const int64_t num_work_rows, // number of rows to work on per member
const int64_t group_size) {
const auto total_num_warps = warp_offsets_group[group_size];
+
+
// USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch.
+ #ifdef USE_ROCM
if (USE_INDEX_SELECT) {
for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
warp_id < total_num_warps;
@@ -63,10 +66,42 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
num_cols = num_cols_group[member_id];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_warp_id = warp_id - warp_offsets_group[member_id];
- } else {
+ }
+#else
+ int32_t num_cols = 0;
+ int32_t warps_per_row = 0;
+
+ if constexpr (!USE_VAR_COLS) {
+ num_cols = num_cols_group[0];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ }
+
+ for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+ warp_id < total_num_warps;
+ warp_id += gridDim.x * blockDim.y) {
+ int32_t member_id = 0;
+ int32_t member_warp_id = 0;
+ if constexpr (USE_VAR_COLS) {
+ __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
+ if (threadIdx.x == 0) {
+ binary_search_range(
+ &member_ids[threadIdx.y],
+ warp_offsets_group + 1,
+ warp_id,
+ group_size);
+ }
+ syncwarp();
+ member_id = member_ids[threadIdx.y];
+ num_cols = num_cols_group[member_id];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_warp_id = warp_id - warp_offsets_group[member_id];
+#endif
+ else {
// All columns are the same
+ #ifdef USE_ROCM
num_cols = num_cols_group[0];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ #endif
member_id = warp_id / (warps_per_row * num_work_rows);
member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
}
@@ -83,10 +118,15 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
const index_t idx = indices[row];
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+ #ifndef USE_ROCM
+ if constexpr (USE_INDEX_SELECT) {
+ #endif
output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+#ifdef USE_ROCM
}
}
} else {
+
// Cache a handful of scatter destinations per warp so we can merge
// consecutive updates that hit the same index before touching global memory.
constexpr int kCacheSlots = 2;
@@ -148,7 +188,9 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
num_cols = num_cols_group[member_id];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_warp_id = warp_id - warp_offsets_group[member_id];
+#endif
} else {
+#ifdef USE_ROCM
// All columns are the same
num_cols = num_cols_group[0];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
@@ -257,6 +299,13 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
flush_cache(active_output_base, active_num_cols, active_col_offset);
}
}
+#else
+ gpuAtomicAddNoReturn(
+ &output[idx * num_cols + i], input[row * num_cols + i]);
+ }
+ }
+#endif
+
DLL_PUBLIC void group_index_select_or_add_cuda(
const int64_t* input_ptrs,
From 34d4ca7fa4f5f3c9b6a4448bfb5399a73eb1b96d Mon Sep 17 00:00:00 2001
From: Shreyashri Biswas
Date: Tue, 11 Nov 2025 20:28:02 +0000
Subject: [PATCH 4/8] removed group index test
---
fbgemm_gpu/test/sparse/index_select_test.py | 32 ++-------------------
1 file changed, 2 insertions(+), 30 deletions(-)
diff --git a/fbgemm_gpu/test/sparse/index_select_test.py b/fbgemm_gpu/test/sparse/index_select_test.py
index ff4b264e05..c057297ff9 100644
--- a/fbgemm_gpu/test/sparse/index_select_test.py
+++ b/fbgemm_gpu/test/sparse/index_select_test.py
@@ -1,4 +1,4 @@
- #!/usr/bin/env python3
+#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -239,35 +239,7 @@ def compare_tensor_groups(
{"rtol": 1e-02, "atol": 1e-02} if dtype == torch.half else {},
)
- @unittest.skipIf(not gpu_available, "CUDA not available")
- def test_group_index_select_dim0_duplicate_gradients(self) -> None:
- device = torch.device("cuda")
- dtype = torch.float
-
- num_rows = 4
- num_cols = 9
- indices = torch.tensor([0, 1, 2, 1, 0, 2], dtype=torch.long, device=device)
-
- input_tensor = torch.randn(
- (num_rows, num_cols), dtype=dtype, device=device
- ).requires_grad_(True)
-
- output_group = torch.ops.fbgemm.group_index_select_dim0(
- [input_tensor], [indices]
- )
- output = output_group[0]
-
- grad = torch.arange(
- output.numel(), dtype=dtype, device=device
- ).view_as(output)
- output.backward(grad)
-
- ref_grad = torch.zeros_like(input_tensor)
- ref_grad.index_add_(0, indices, grad)
-
- torch.testing.assert_close(input_tensor.grad, ref_grad)
-
- @given(
+ @given(
num_inputs=st.integers(0, 100),
max_input_rows=st.integers(2, 32),
max_cols_factor=st.integers(2, 256),
From c33702c4f6382b946e06ccdf53ce2bd0c501d843 Mon Sep 17 00:00:00 2001
From: Shreyashri Biswas
Date: Tue, 11 Nov 2025 22:05:04 +0000
Subject: [PATCH 5/8] added GROUP_INDEX_SELECT_COLS_PER_WARP swap logic for fwd
and bwd
---
.../src/sparse_ops/sparse_group_index.cu | 91 ++++++++++++-------
1 file changed, 56 insertions(+), 35 deletions(-)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index f00cbcf0d2..dcf3028b0a 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -12,12 +12,19 @@ using Tensor = at::Tensor;
namespace fbgemm_gpu {
+#ifdef USE_ROCM
+// The wave size is forced to be 32 on ROCm devices in favor
+// of granularity losses reduction.
+constexpr int EMULATED_WARP_SIZE = 32;
+#else
+constexpr int EMULATED_WARP_SIZE = kWarpSize;
+#endif
+
// TODO: Update UNROLL_FACTOR
constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
- GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
+ GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
-// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
log2_calc::value;
@@ -25,6 +32,20 @@ int get_group_index_select_cols_per_warp() {
return GROUP_INDEX_SELECT_COLS_PER_WARP;
}
+#ifdef USE_ROCM
+constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_ADD =
+ GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
+constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_ADD =
+ log2_calc::value;
+#else
+constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_ADD =
+ GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
+constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_ADD =
+ log2_calc::value;
+#endif
+
+
+
template <
typename index_t,
typename scalar_t,
@@ -42,18 +63,18 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
const int32_t* num_cols_group,
const int64_t num_work_rows, // number of rows to work on per member
const int64_t group_size) {
- const auto total_num_warps = warp_offsets_group[group_size];
-
-
+ const auto total_num_warps = warp_offsets_group[group_size];
// USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch.
#ifdef USE_ROCM
if (USE_INDEX_SELECT) {
for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
warp_id < total_num_warps;
warp_id += gridDim.x * blockDim.y) {
- int32_t member_id, member_warp_id, num_cols, warps_per_row;
+ int32_t num_cols, warps_per_row;
+ int32_t member_id = 0;
+ int32_t member_warp_id = 0;
if (USE_VAR_COLS) {
- __shared__ int member_ids[kMaxThreads / kWarpSize];
+ __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
if (threadIdx.x == 0) {
binary_search_range(
&member_ids[threadIdx.y],
@@ -67,35 +88,35 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_warp_id = warp_id - warp_offsets_group[member_id];
}
-#else
- int32_t num_cols = 0;
- int32_t warps_per_row = 0;
+ #else
+ int32_t num_cols = 0;
+ int32_t warps_per_row = 0;
- if constexpr (!USE_VAR_COLS) {
- num_cols = num_cols_group[0];
- warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
- }
+ if constexpr (!USE_VAR_COLS) {
+ num_cols = num_cols_group[0];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ }
- for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+ for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
warp_id < total_num_warps;
warp_id += gridDim.x * blockDim.y) {
- int32_t member_id = 0;
- int32_t member_warp_id = 0;
- if constexpr (USE_VAR_COLS) {
- __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
- if (threadIdx.x == 0) {
- binary_search_range(
- &member_ids[threadIdx.y],
- warp_offsets_group + 1,
- warp_id,
- group_size);
- }
- syncwarp();
- member_id = member_ids[threadIdx.y];
- num_cols = num_cols_group[member_id];
- warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
- member_warp_id = warp_id - warp_offsets_group[member_id];
-#endif
+ int32_t member_id = 0;
+ int32_t member_warp_id = 0;
+ if constexpr (USE_VAR_COLS) {
+ __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
+ if (threadIdx.x == 0) {
+ binary_search_range(
+ &member_ids[threadIdx.y],
+ warp_offsets_group + 1,
+ warp_id,
+ group_size);
+ }
+ syncwarp();
+ member_id = member_ids[threadIdx.y];
+ num_cols = num_cols_group[member_id];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_warp_id = warp_id - warp_offsets_group[member_id];
+ #endif
else {
// All columns are the same
#ifdef USE_ROCM
@@ -126,9 +147,9 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
}
}
} else {
-
// Cache a handful of scatter destinations per warp so we can merge
// consecutive updates that hit the same index before touching global memory.
+
constexpr int kCacheSlots = 2;
index_t cached_idx[kCacheSlots];
scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR];
@@ -344,8 +365,8 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
USE_INDEX_SELECT, \
USE_VAR_COLS, \
GROUP_INDEX_SELECT_UNROLL_FACTOR, \
- GROUP_INDEX_SELECT_COLS_PER_WARP, \
- GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \
+ USE_INDEX_SELECT ? GROUP_INDEX_SELECT_COLS_PER_WARP : GROUP_INDEX_SELECT_COLS_PER_WARP_ADD, \
+ USE_INDEX_SELECT ? GROUP_INDEX_SELECT_LOG_COLS_PER_WARP : GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_ADD>), \
grid_size, \
block_size, \
0, \
From cca797b5ef97eb79a4150651c33cc97a17bcb8a9 Mon Sep 17 00:00:00 2001
From: Shreyashri Biswas
Date: Wed, 12 Nov 2025 01:41:05 +0000
Subject: [PATCH 6/8] added separate variables for add and select mode
---
.../src/sparse_ops/sparse_group_index.cu | 149 +++++++++---------
1 file changed, 78 insertions(+), 71 deletions(-)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index dcf3028b0a..641ea11b27 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -13,7 +13,7 @@ using Tensor = at::Tensor;
namespace fbgemm_gpu {
#ifdef USE_ROCM
-// The wave size is forced to be 32 on ROCm devices in favor
+// The wave size is forced to be 32 on ROCm devices in favor
// of granularity losses reduction.
constexpr int EMULATED_WARP_SIZE = 32;
#else
@@ -22,29 +22,29 @@ constexpr int EMULATED_WARP_SIZE = kWarpSize;
// TODO: Update UNROLL_FACTOR
constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
-constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
+
+// SELECT (fwd): use EMULATED_WARP_SIZE
+constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_FWD =
GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
+constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_FWD =
+ log2_calc::value;
+
+// ADD (bwd): use kWarpSize
+constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_BWD =
+ GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
+constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_BWD =
+ log2_calc::value;
-constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
- log2_calc::value;
int get_group_index_select_cols_per_warp() {
- return GROUP_INDEX_SELECT_COLS_PER_WARP;
+ return GROUP_INDEX_SELECT_COLS_PER_WARP_BWD;
}
-#ifdef USE_ROCM
-constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_ADD =
- GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
-constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_ADD =
- log2_calc::value;
-#else
-constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_ADD =
- GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
-constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_ADD =
- log2_calc::value;
-#endif
-
-
+// New: explicit selector
+int get_group_index_select_cols_per_warp(bool use_index_select) {
+ return use_index_select ? GROUP_INDEX_SELECT_COLS_PER_WARP_FWD
+ : GROUP_INDEX_SELECT_COLS_PER_WARP_BWD;
+}
template <
typename index_t,
@@ -63,16 +63,15 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
const int32_t* num_cols_group,
const int64_t num_work_rows, // number of rows to work on per member
const int64_t group_size) {
- const auto total_num_warps = warp_offsets_group[group_size];
+ const auto total_num_warps = warp_offsets_group[group_size];
+
+#ifdef USE_ROCM
// USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch.
- #ifdef USE_ROCM
if (USE_INDEX_SELECT) {
for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
warp_id < total_num_warps;
warp_id += gridDim.x * blockDim.y) {
- int32_t num_cols, warps_per_row;
- int32_t member_id = 0;
- int32_t member_warp_id = 0;
+ int32_t member_id, member_warp_id, num_cols, warps_per_row;
if (USE_VAR_COLS) {
__shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
if (threadIdx.x == 0) {
@@ -87,42 +86,10 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
num_cols = num_cols_group[member_id];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_warp_id = warp_id - warp_offsets_group[member_id];
- }
- #else
- int32_t num_cols = 0;
- int32_t warps_per_row = 0;
-
- if constexpr (!USE_VAR_COLS) {
- num_cols = num_cols_group[0];
- warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
- }
-
- for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
- warp_id < total_num_warps;
- warp_id += gridDim.x * blockDim.y) {
- int32_t member_id = 0;
- int32_t member_warp_id = 0;
- if constexpr (USE_VAR_COLS) {
- __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
- if (threadIdx.x == 0) {
- binary_search_range(
- &member_ids[threadIdx.y],
- warp_offsets_group + 1,
- warp_id,
- group_size);
- }
- syncwarp();
- member_id = member_ids[threadIdx.y];
- num_cols = num_cols_group[member_id];
- warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
- member_warp_id = warp_id - warp_offsets_group[member_id];
- #endif
- else {
+ } else {
// All columns are the same
- #ifdef USE_ROCM
num_cols = num_cols_group[0];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
- #endif
member_id = warp_id / (warps_per_row * num_work_rows);
member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
}
@@ -139,17 +106,12 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
const index_t idx = indices[row];
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
- #ifndef USE_ROCM
- if constexpr (USE_INDEX_SELECT) {
- #endif
output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
-#ifdef USE_ROCM
}
}
} else {
// Cache a handful of scatter destinations per warp so we can merge
// consecutive updates that hit the same index before touching global memory.
-
constexpr int kCacheSlots = 2;
index_t cached_idx[kCacheSlots];
scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR];
@@ -209,9 +171,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
num_cols = num_cols_group[member_id];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_warp_id = warp_id - warp_offsets_group[member_id];
-#endif
} else {
-#ifdef USE_ROCM
// All columns are the same
num_cols = num_cols_group[0];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
@@ -319,14 +279,56 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
flush_cache(active_output_base, active_num_cols, active_col_offset);
}
-}
-#else
- gpuAtomicAddNoReturn(
+#else // Original CUDA implementation
+ for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+ warp_id < total_num_warps;
+ warp_id += gridDim.x * blockDim.y) {
+ int32_t member_id, member_warp_id, num_cols, warps_per_row;
+ if (USE_VAR_COLS) {
+ __shared__ int member_ids[kMaxThreads / kWarpSize];
+ if (threadIdx.x == 0) {
+ binary_search_range(
+ &member_ids[threadIdx.y],
+ warp_offsets_group + 1,
+ warp_id,
+ group_size);
+ }
+ syncwarp();
+ member_id = member_ids[threadIdx.y];
+ num_cols = num_cols_group[member_id];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_warp_id = warp_id - warp_offsets_group[member_id];
+ } else {
+ // All columns are the same
+ num_cols = num_cols_group[0];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_id = warp_id / (warps_per_row * num_work_rows);
+ member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+ }
+ const auto row = member_warp_id / warps_per_row;
+ const auto col_offset =
+ ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
+ (threadIdx.x * UNROLL_FACTOR);
+ scalar_t* input =
+ reinterpret_cast(input_ptrs[member_id]) + col_offset;
+ scalar_t* output =
+ reinterpret_cast(output_ptrs[member_id]) + col_offset;
+
+ index_t* indices = reinterpret_cast(indices_ptrs[member_id]);
+ const index_t idx = indices[row];
+#pragma unroll
+ for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+ // Compile time conditional
+ if (USE_INDEX_SELECT) {
+ output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+ } else {
+ gpuAtomicAddNoReturn(
&output[idx * num_cols + i], input[row * num_cols + i]);
+ }
}
}
-#endif
-
+#endif // USE_ROCM
+}
DLL_PUBLIC void group_index_select_or_add_cuda(
const int64_t* input_ptrs,
@@ -349,7 +351,7 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
at::cuda::OptionalCUDAGuard device_guard(device);
// Partition work based on num_work_rows
- uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
+ uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
uint32_t max_grid_size =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
uint32_t grid_size = std::min(
@@ -365,8 +367,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
USE_INDEX_SELECT, \
USE_VAR_COLS, \
GROUP_INDEX_SELECT_UNROLL_FACTOR, \
- USE_INDEX_SELECT ? GROUP_INDEX_SELECT_COLS_PER_WARP : GROUP_INDEX_SELECT_COLS_PER_WARP_ADD, \
- USE_INDEX_SELECT ? GROUP_INDEX_SELECT_LOG_COLS_PER_WARP : GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_ADD>), \
+ (USE_INDEX_SELECT \
+ ? GROUP_INDEX_SELECT_COLS_PER_WARP_FWD \
+ : GROUP_INDEX_SELECT_COLS_PER_WARP_BWD), \
+ /* LOG_COLS_PER_WARP */ \
+ (USE_INDEX_SELECT \
+ ? GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_FWD \
+ : GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_BWD)>), \
grid_size, \
block_size, \
0, \
From c87df8429108c36ed93a096f254ba84a31f8048c Mon Sep 17 00:00:00 2001
From: Shreyashri Biswas
Date: Wed, 12 Nov 2025 02:10:53 +0000
Subject: [PATCH 7/8] removed comments
---
fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 14 +++-----------
fbgemm_gpu/test/sparse/index_select_test.py | 2 +-
2 files changed, 4 insertions(+), 12 deletions(-)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index 641ea11b27..729331dd26 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -35,12 +35,10 @@ constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP_BWD =
constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP_BWD =
log2_calc<GROUP_INDEX_SELECT_COLS_PER_WARP_BWD>::value;
-
int get_group_index_select_cols_per_warp() {
return GROUP_INDEX_SELECT_COLS_PER_WARP_BWD;
}
-// New: explicit selector
int get_group_index_select_cols_per_warp(bool use_index_select) {
return use_index_select ? GROUP_INDEX_SELECT_COLS_PER_WARP_FWD
: GROUP_INDEX_SELECT_COLS_PER_WARP_BWD;
@@ -61,7 +59,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
const int64_t* indices_ptrs,
const int64_t* warp_offsets_group,
const int32_t* num_cols_group,
- const int64_t num_work_rows, // number of rows to work on per member
+ const int64_t num_work_rows,
const int64_t group_size) {
const auto total_num_warps = warp_offsets_group[group_size];
@@ -87,7 +85,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_warp_id = warp_id - warp_offsets_group[member_id];
} else {
- // All columns are the same
num_cols = num_cols_group[0];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_id = warp_id / (warps_per_row * num_work_rows);
@@ -110,8 +107,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
}
}
} else {
- // Cache a handful of scatter destinations per warp so we can merge
- // consecutive updates that hit the same index before touching global memory.
constexpr int kCacheSlots = 2;
index_t cached_idx[kCacheSlots];
scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR];
@@ -299,7 +294,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_warp_id = warp_id - warp_offsets_group[member_id];
} else {
- // All columns are the same
num_cols = num_cols_group[0];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_id = warp_id / (warps_per_row * num_work_rows);
@@ -327,7 +321,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
}
}
}
-#endif // USE_ROCM
+#endif
}
DLL_PUBLIC void group_index_select_or_add_cuda(
@@ -349,9 +343,7 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
}
at::cuda::OptionalCUDAGuard device_guard(device);
-
- // Partition work based on num_work_rows
- uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
+ uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
uint32_t max_grid_size =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
uint32_t grid_size = std::min(
diff --git a/fbgemm_gpu/test/sparse/index_select_test.py b/fbgemm_gpu/test/sparse/index_select_test.py
index c057297ff9..6c61b77bf8 100644
--- a/fbgemm_gpu/test/sparse/index_select_test.py
+++ b/fbgemm_gpu/test/sparse/index_select_test.py
@@ -239,7 +239,7 @@ def compare_tensor_groups(
{"rtol": 1e-02, "atol": 1e-02} if dtype == torch.half else {},
)
- @given(
+ @given(
num_inputs=st.integers(0, 100),
max_input_rows=st.integers(2, 32),
max_cols_factor=st.integers(2, 256),
From 735b8035ae3de118413e26fa29f1c772dbc6be00 Mon Sep 17 00:00:00 2001
From: Shreyashri Biswas
Date: Wed, 12 Nov 2025 02:33:53 +0000
Subject: [PATCH 8/8] fix grid and block size
---
.../src/sparse_ops/sparse_group_index.cu | 21 +++++++++++++++----
1 file changed, 17 insertions(+), 4 deletions(-)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index 729331dd26..722c204dd8 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -62,10 +62,11 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
const int64_t num_work_rows,
const int64_t group_size) {
const auto total_num_warps = warp_offsets_group[group_size];
-
+ int32_t num_cols = 0;
+ int32_t warps_per_row = 0;
#ifdef USE_ROCM
// USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch.
- if (USE_INDEX_SELECT) {
+ if constexpr (USE_INDEX_SELECT) {
for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
warp_id < total_num_warps;
warp_id += gridDim.x * blockDim.y) {
@@ -343,13 +344,25 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
}
at::cuda::OptionalCUDAGuard device_guard(device);
- uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
+ uint32_t num_warps_per_threadblock;
+ dim3 block_size;
+
+ if (use_index_select) {
+ // Forward pass uses EMULATED_WARP_SIZE
+ num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
+ block_size = dim3(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);
+ } else {
+ // Backward pass uses kWarpSize
+ num_warps_per_threadblock = kMaxThreads / kWarpSize;
+ block_size = dim3(kWarpSize, num_warps_per_threadblock, 1);
+ }
+
uint32_t max_grid_size =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
uint32_t grid_size = std::min(
cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
max_grid_size);
- dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
+
#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
FBGEMM_LAUNCH_KERNEL( \