
[CUDA] Implement MaskedScatter #3151

Merged
zcbenz merged 8 commits into ml-explore:main from Lyxot:cuda/masked_scatter
Mar 15, 2026

Conversation

Contributor

@Lyxot Lyxot commented Feb 20, 2026

Proposed changes

This PR adds CUDA support for MaskedScatter.

Changed files

  • mlx/backend/cuda/indexing.cpp: implemented CUDA MaskedScatter::eval_gpu using the CUDA JIT module path.
  • mlx/backend/cuda/device/scatter.cuh: added the JIT device kernel masked_scatter_assign<...> used by CUDA masked scatter.
  • mlx/backend/cuda/scan.cu: refactored scan execution into reusable scan_gpu_inplace(...) and updated Scan::eval_gpu to delegate to it.
  • mlx/backend/cuda/scan.h: added declaration for scan_gpu_inplace(...).
  • python/tests/cuda_skip.py, tests/ops_tests.cpp, tests/autograd_tests.cpp: removed CUDA skip entries for masked-scatter-related tests.

Validation

  • Python:
    • python -m pytest python/tests/test_ops.py -k masked_scatter -q passed.
    • python -m pytest python/tests/test_vmap.py -k vmap_masked_scatter -q passed.
    • python -m pytest python/tests/test_array.py -k setitem_with_boolean_mask -q passed.
  • C++:
    • build/tests/tests -tc="test masked_scatter,test masked_scatter autograd" passed.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Copilot AI review requested due to automatic review settings February 20, 2026 19:41

Copilot AI left a comment


Pull request overview

This PR implements CUDA support for the MaskedScatter operation, which scatters values from a source array into a destination array at positions specified by a boolean mask. The implementation follows the existing Metal backend pattern and properly integrates with the CUDA backend infrastructure.

Changes:

  • Converted indexing.cpp to indexing.cu and added full CUDA MaskedScatter::eval_gpu implementation with masked_assign kernel
  • Refactored scan launch logic into reusable scan_gpu_inplace function with new header file
  • Removed CUDA skip entries for masked-scatter-related tests

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.

Show a summary per file
File Description
mlx/backend/cuda/indexing.cu Implemented masked_assign CUDA kernel and MaskedScatter::eval_gpu method; converted from .cpp to .cu
mlx/backend/cuda/scan.cu Refactored scan logic into scan_gpu_inplace function for reuse in MaskedScatter
mlx/backend/cuda/scan.h Added header declaring scan_gpu_inplace function
mlx/backend/cuda/primitives.cpp Removed NO_GPU(MaskedScatter) macro to enable CUDA support
mlx/backend/cuda/CMakeLists.txt Updated build to compile indexing.cu instead of indexing.cpp
tests/ops_tests.cpp Removed CUDA skip guard from masked_scatter tests
tests/autograd_tests.cpp Removed CUDA skip guard from masked_scatter autograd tests
python/tests/cuda_skip.py Removed three masked-scatter-related test entries from skip list
Comments suppressed due to low confidence (1)

mlx/backend/cuda/indexing.cu:80

  • The masked_assign kernel uses a signed 32-bit IdxT together with stride = static_cast<IdxT>(blockDim.x) * gridDim.x * gridDim.y * gridDim.z, which can overflow when mask_flat.size() approaches INT32_MAX, causing stride to wrap negative while total remains positive. In that case, the loop for (IdxT idx = thread_id; idx < total; idx += stride) can revisit elements with negative idx values and read or write mask[idx], scatter_offsets[idx], and out[idx] out of bounds, leading to GPU memory corruption and potential data exposure or code execution in contexts that rely on untrusted shapes. To address this, ensure the index type used in this kernel cannot overflow for the chosen grid/block configuration: for example, use an unsigned or 64-bit index consistently for IdxT when computing block_id, stride, and indexing, or constrain gridDim/blockDim so their product fits safely in the index type.


Collaborator

@zcbenz zcbenz left a comment


Looks good to me, would like another review before merging.

Comment thread mlx/backend/cuda/indexing.cpp Outdated
@Lyxot Lyxot force-pushed the cuda/masked_scatter branch from 39962fe to f5693f7 Compare February 27, 2026 10:50
@Lyxot Lyxot requested a review from zcbenz February 27, 2026 10:50
@nastya236 nastya236 self-requested a review February 27, 2026 16:05
nastya236
nastya236 previously approved these changes Feb 27, 2026
Collaborator

@nastya236 nastya236 left a comment


Looks good to me as well!

@nastya236 nastya236 self-requested a review February 27, 2026 20:19
@nastya236 nastya236 dismissed their stale review February 28, 2026 00:25

Re-review

Collaborator

@nastya236 nastya236 left a comment


As I said, looks great, thanks for your contribution.
Could you please provide bandwidth numbers for the masked scatter kernel across a range of shapes?

Contributor Author

Lyxot commented Feb 28, 2026

@nastya236 Benchmark results from benchmarks/python/masked_scatter.py:
[figure: NVIDIA_GeForce_RTX_4070_SUPER_masked_scatter_float32_page-0001]

@Lyxot Lyxot requested a review from nastya236 March 2, 2026 10:55
@Lyxot Lyxot force-pushed the cuda/masked_scatter branch from 19571bc to e70c1fc Compare March 4, 2026 13:21
Contributor Author

Lyxot commented Mar 5, 2026

@nastya236
I followed up on the large-shape regression and implemented an optimization pass for masked_scatter.

What changed:

  • Added a contiguous fast path with a fused masked-scatter kernel.
  • Increased per-thread work and reduced memory traffic in the hot path.
  • Reworked prefix handling for that path (tile count + prefix offsets) to avoid the previous per-element offset overhead.

On my setup, MLX is now faster than Torch in all benchmark cases, and the prior large-shape degradation trend is resolved.

[figure: NVIDIA_GeForce_RTX_4070_SUPER_masked_scatter_float32]

Could you please take another look?

@nastya236
Collaborator

Thanks for the update! I will look as soon as possible.

Collaborator

zcbenz commented Mar 13, 2026

I think the performance improvement is impressive, but the new kernel code is complicated and really hard to review. We also want to avoid the CUB device APIs, since the overhead of graph capture easily eliminates the performance gain.

I would say the initial PR was already good enough: the code was clean and achieved reasonable performance. Further optimizations should be judged against real-world needs; otherwise we end up with a large piece of code that is hard to maintain and never used.

@nastya236 What do you think if we just merge the initial version?

Collaborator

nastya236 commented Mar 13, 2026

I agree with you @zcbenz. I think the initial masked scatter is slow for larger shapes because of the scan. Just out of curiosity I checked PyTorch's masked scatter, and it is identical to what @Lyxot proposed initially. I think we can merge the first version of masked scatter, and if needed, scan improvements should be a separate PR.

@Lyxot thanks for exploring the faster approach.

Contributor Author

Lyxot commented Mar 14, 2026

@nastya236 @zcbenz Should we preserve e70c1fc and f1c2a0b? These two small commits provide a ~1.1x to 1.4x performance improvement over the initial version while keeping the implementation relatively simple.

@Lyxot Lyxot force-pushed the cuda/masked_scatter branch from fead885 to 7370a22 Compare March 14, 2026 07:03
Collaborator

@zcbenz zcbenz left a comment


The new changes look good to me, thanks for updating the PR!

Comment thread mlx/backend/cuda/device/scatter.cuh Outdated
Comment thread mlx/backend/cuda/indexing.cpp Outdated
Comment thread mlx/backend/cuda/device/scatter.cuh Outdated
Comment thread mlx/backend/cuda/scan.h Outdated
Collaborator

@nastya236 nastya236 left a comment


Looks great thank you! Lets merge when the tests are done.

@zcbenz zcbenz merged commit 0bdbfdb into ml-explore:main Mar 15, 2026
16 checks passed