[CUDA] Implement MaskedScatter #3151
Conversation
Pull request overview
This PR implements CUDA support for the MaskedScatter operation, which scatters values from a source array into a destination array at positions specified by a boolean mask. The implementation follows the existing Metal backend pattern and properly integrates with the CUDA backend infrastructure.
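The operation's semantics can be sketched with NumPy's boolean-mask assignment as a stand-in (the arrays and names below are illustrative, not the MLX API): source values are consumed in order and written at the positions where the mask is `True`.

```python
import numpy as np

# Illustrative sketch of masked-scatter semantics (NumPy stand-in,
# not the MLX API): src values fill the True positions of mask, in order.
dst = np.zeros(6, dtype=np.int32)
mask = np.array([True, False, True, True, False, False])
src = np.array([10, 20, 30], dtype=np.int32)

dst[mask] = src[: mask.sum()]  # scatter src into dst where mask is True
print(dst.tolist())  # → [10, 0, 20, 30, 0, 0]
```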
Changes:
- Converted `indexing.cpp` to `indexing.cu` and added a full CUDA `MaskedScatter::eval_gpu` implementation with a `masked_assign` kernel
- Refactored scan launch logic into a reusable `scan_gpu_inplace` function with a new header file
- Removed CUDA skip entries for masked-scatter-related tests
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| mlx/backend/cuda/indexing.cu | Implemented masked_assign CUDA kernel and MaskedScatter::eval_gpu method; converted from .cpp to .cu |
| mlx/backend/cuda/scan.cu | Refactored scan logic into scan_gpu_inplace function for reuse in MaskedScatter |
| mlx/backend/cuda/scan.h | Added header declaring scan_gpu_inplace function |
| mlx/backend/cuda/primitives.cpp | Removed NO_GPU(MaskedScatter) macro to enable CUDA support |
| mlx/backend/cuda/CMakeLists.txt | Updated build to compile indexing.cu instead of indexing.cpp |
| tests/ops_tests.cpp | Removed CUDA skip guard from masked_scatter tests |
| tests/autograd_tests.cpp | Removed CUDA skip guard from masked_scatter autograd tests |
| python/tests/cuda_skip.py | Removed three masked-scatter-related test entries from skip list |
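The file summary shows `MaskedScatter` reusing the scan machinery. A plausible reason, sketched here with NumPy as a stand-in (the variable names are illustrative), is that an exclusive prefix sum over the boolean mask gives each `True` position the index of the source element to scatter there:

```python
import numpy as np

# Hedged sketch: an exclusive scan over the mask produces per-position
# source offsets for a masked scatter (NumPy stand-in for the GPU scan).
mask = np.array([1, 0, 1, 1, 0], dtype=np.int32)
offsets = np.cumsum(mask) - mask  # exclusive prefix sum of the mask

# offsets[i] is the source index for destination i (valid where mask[i] == 1)
print(offsets.tolist())  # → [0, 1, 1, 2, 3]
```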
Comments suppressed due to low confidence (1)
mlx/backend/cuda/indexing.cu:80
- The `masked_assign` kernel uses a signed 32-bit `IdxT` together with `stride = static_cast<IdxT>(blockDim.x) * gridDim.x * gridDim.y * gridDim.z`, which can overflow when `mask_flat.size()` approaches `INT32_MAX`, causing `stride` to wrap negative while `total` remains positive. In that case, the loop `for (IdxT idx = thread_id; idx < total; idx += stride)` can revisit with negative `idx` values and read/write `mask[idx]`, `scatter_offsets[idx]`, and `out[idx]` out of bounds, leading to GPU memory corruption and potential data exposure or code execution in contexts that rely on untrusted shapes. To address this, ensure the index type used in this kernel cannot overflow for the chosen grid/block configuration (e.g., use an unsigned or 64-bit index consistently for `IdxT` when computing `block_id`, `stride`, and indexing, or otherwise constrain `gridDim`/`blockDim` so their product fits safely in the index type).
zcbenz
left a comment
Looks good to me, would like another review before merging.
Force-pushed from 39962fe to f5693f7
nastya236
left a comment
Looks good to me as well!
nastya236
left a comment
As I said, it looks great. Thanks for your contribution.
Could you please provide bandwidth numbers for masked scatter kernel for a range of shapes?
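One way to produce the requested numbers, sketched here with NumPy standing in for the GPU kernel (the function name and byte accounting are illustrative, not MLX's benchmark harness):

```python
import time
import numpy as np

# Hedged sketch of estimating effective bandwidth for a masked-scatter-style
# op across shapes. NumPy stands in for the GPU kernel; bandwidth_gbps and
# its byte accounting are illustrative assumptions, not an MLX utility.
def bandwidth_gbps(n, trials=5, seed=0):
    rng = np.random.default_rng(seed)
    dst = np.zeros(n, dtype=np.float32)
    mask = rng.random(n) > 0.5
    src = rng.random(int(mask.sum())).astype(np.float32)

    start = time.perf_counter()
    for _ in range(trials):
        dst[mask] = src  # the masked-scatter step being timed
    elapsed = (time.perf_counter() - start) / trials

    # Bytes touched per pass: read the mask, read the source values,
    # and write one destination element per True in the mask.
    bytes_moved = mask.nbytes + 2 * src.nbytes
    return bytes_moved / elapsed / 1e9

for n in (1 << 16, 1 << 20):
    print(f"n={n}: {bandwidth_gbps(n):.2f} GB/s")
```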
@nastya236 bench result from
Force-pushed from 19571bc to e70c1fc
@nastya236 What changed:
On my setup, MLX is now faster than Torch in all benchmark cases, and the prior large-shape degradation trend is resolved.
Could you please take another look?
Thanks for the update! I will look as soon as possible.
I think the performance improvement is impressive, but the new kernel code is complicated and really hard to review, and we also want to avoid using the CUB device APIs, as the overhead of graph capture easily eliminates the performance gain. I would say the initial PR was already good enough: the code was clean and achieved reasonable performance. Further optimizations should be judged against real-world needs; otherwise we end up with a large piece of code that is hard to maintain and never used. @nastya236 What do you think if we just merge the initial version?
I agree with you @zcbenz. I think the initial masked scatter is slow for larger shapes because of the scan. Just out of curiosity I checked PyTorch's masked scatter, and it is identical to what was proposed by @Lyxot initially. I think we can merge the first version of masked scatter, and if needed, the scan improvement should be a separate PR. @Lyxot thanks for exploring the faster approach.
@nastya236 @zcbenz Should we preserve e70c1fc and f1c2a0b? These two small commits provide a ~1.1x-1.4x performance improvement over the initial version, while keeping the implementation relatively simple.
Force-pushed from fead885 to 7370a22
zcbenz
left a comment
The new changes look good to me, thanks for updating the PR!


Proposed changes
This PR adds CUDA support for `MaskedScatter`.

Changed files

- `mlx/backend/cuda/indexing.cpp`: implemented CUDA `MaskedScatter::eval_gpu` using the CUDA JIT module path.
- `mlx/backend/cuda/device/scatter.cuh`: added the JIT device kernel `masked_scatter_assign<...>` used by CUDA masked scatter.
- `mlx/backend/cuda/scan.cu`: refactored scan execution into a reusable `scan_gpu_inplace(...)` and updated `Scan::eval_gpu` to delegate to it.
- `mlx/backend/cuda/scan.h`: added a declaration for `scan_gpu_inplace(...)`.
- `python/tests/cuda_skip.py`, `tests/ops_tests.cpp`, `tests/autograd_tests.cpp`: removed CUDA skip entries for masked-scatter-related tests.

Validation
- `python -m pytest python/tests/test_ops.py -k masked_scatter -q` passed.
- `python -m pytest python/tests/test_vmap.py -k vmap_masked_scatter -q` passed.
- `python -m pytest python/tests/test_array.py -k setitem_with_boolean_mask -q` passed.
- `build/tests/tests -tc="test masked_scatter,test masked_scatter autograd"` passed.

Checklist
Put an `x` in the boxes that apply.

- [x] I ran `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes