Add validity checks for MoE FlatMM scatter and enable bf16 hardware atomic-add #3236

Merged
msaffari-amd merged 5 commits into develop from moe-flatmm-scatter-validity
Nov 28, 2025

@msaffari-amd msaffari-amd commented Nov 19, 2025

Proposed changes

Summary

This PR fixes crashes in MoE FlatMM operations when using bf16 data type on CDNA3 (gfx942) GPUs and enables bf16 hardware atomic operations. The crashes occurred with specific NumTokens values (powers of 2 ≥ 512) due to out-of-bounds atomic write operations.

Problem

Invalid/padding tokens in MoE operations are marked with scatter_token_id = NumTokens, which generates scatter offsets pointing exactly at the buffer end (out of bounds). When atomic write operations attempt to access these addresses:

  • fp16: Buffer atomics gracefully handle OOB → works ✓
  • bf16 on gfx950: Buffer atomics available → works ✓
  • bf16 on gfx942: Only global atomics available → crashes
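
To make the failure mode concrete, the sketch below models the offset arithmetic on the host. It assumes a row-major output buffer with one row per token; `scatter_row_offset`, `in_bounds`, and `row_stride` are hypothetical names used only for illustration and are not taken from the kernel.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helper: element offset of the first element of the output row
// selected by scatter_token_id, assuming a row-major layout.
constexpr std::size_t scatter_row_offset(std::uint32_t scatter_token_id, std::size_t row_stride)
{
    return static_cast<std::size_t>(scatter_token_id) * row_stride;
}

// Hypothetical helper: true when `offset` falls inside a buffer that holds
// num_tokens rows of row_stride elements each.
constexpr bool in_bounds(std::size_t offset, std::uint32_t num_tokens, std::size_t row_stride)
{
    return offset < static_cast<std::size_t>(num_tokens) * row_stride;
}
```

With `num_tokens = 512` and `row_stride = 8`, a padding token with `scatter_token_id = 512` maps to offset 4096, which equals the buffer size and is therefore one element past the last valid index (4095).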

Solution

1. Validity Flag Tracking

Added validity flag tracking to scatter operations using the designed tile_scatter_gather API:

  • Calculate validity for each scatter: scatter_token_id < NumTokens
  • Pass validity array to make_tile_scatter_gather()
  • Invalid tokens automatically skipped during writes
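
The three steps above can be sketched as a host-side model. This is not the ck_tile implementation: `compute_validity` and `guarded_scatter_add` are hypothetical stand-ins, and the real kernel passes the flags to `make_tile_scatter_gather()`, whose exact signature is not shown in this PR description.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the per-scatter validity computation:
// a token is valid when scatter_token_id < NumTokens.
template <std::size_t N>
std::array<bool, N> compute_validity(const std::array<std::uint32_t, N>& scatter_token_ids,
                                     std::uint32_t num_tokens)
{
    std::array<bool, N> valid{};
    for(std::size_t i = 0; i < N; ++i)
        valid[i] = scatter_token_ids[i] < num_tokens;
    return valid;
}

// Host-side model of a guarded scatter-add: rows whose flag is false are
// skipped, so no out-of-bounds address is ever formed for padding tokens.
template <std::size_t N>
void guarded_scatter_add(std::vector<float>& out,
                         std::size_t row_stride,
                         const std::array<std::uint32_t, N>& scatter_token_ids,
                         const std::array<bool, N>& valid,
                         const std::vector<float>& rows)
{
    for(std::size_t i = 0; i < N; ++i)
    {
        if(!valid[i])
            continue; // padding token: skip the write entirely
        for(std::size_t c = 0; c < row_stride; ++c)
            out[scatter_token_ids[i] * row_stride + c] += rows[i * row_stride + c];
    }
}
```

The point of carrying a flag rather than clamping the index is that a clamped padding row would still be accumulated into a real token's output; skipping the write avoids both the crash and the corrupted sum.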

2. Enable bf16 Hardware Atomic Builtin

Previously, atomic_add<bf16x2_t> was missing the #if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN check and always used the software CAS fallback. It now uses the hardware builtin when one is available.
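
For readers unfamiliar with the fallback path, a compare-and-swap loop over a packed pair of bf16 values can be modeled on the host roughly as follows. This is a sketch under stated assumptions, not the code in generic_memory_space_atomic.hpp: `bf16_to_float`, `float_to_bf16`, and `atomic_add_bf16x2_cas` are hypothetical names, `std::atomic` stands in for the device atomics, and NaN handling is omitted. The hardware path is not shown because it requires a device compiler.

```cpp
#include <atomic>
#include <cstdint>
#include <cstring>

// bf16 is the upper 16 bits of an IEEE-754 binary32 value.
inline float bf16_to_float(std::uint16_t h)
{
    const std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

// Round-to-nearest-even conversion (NaN inputs are not handled in this sketch).
inline std::uint16_t float_to_bf16(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    bits += 0x7FFFu + ((bits >> 16) & 1u);
    return static_cast<std::uint16_t>(bits >> 16);
}

// Software fallback model: retry a 32-bit compare-and-swap until both packed
// lanes (low 16 bits and high 16 bits) are updated in one step.
inline void atomic_add_bf16x2_cas(std::atomic<std::uint32_t>& dst, float lo, float hi)
{
    std::uint32_t expected = dst.load(std::memory_order_relaxed);
    std::uint32_t desired  = 0;
    do
    {
        const float new_lo = bf16_to_float(static_cast<std::uint16_t>(expected & 0xFFFFu)) + lo;
        const float new_hi = bf16_to_float(static_cast<std::uint16_t>(expected >> 16)) + hi;
        desired = static_cast<std::uint32_t>(float_to_bf16(new_lo)) |
                  (static_cast<std::uint32_t>(float_to_bf16(new_hi)) << 16);
        // On failure, compare_exchange_weak reloads `expected`, so the sums are recomputed.
    } while(!dst.compare_exchange_weak(expected, desired, std::memory_order_relaxed));
}
```

Each update in this loop costs a load, two conversions in each direction, and at least one compare-and-swap, and it retries under contention, which is why a single hardware packed-add instruction is preferable when the target provides one.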

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

@msaffari-amd msaffari-amd self-assigned this Nov 20, 2025
@ROCm ROCm deleted a comment from Copilot AI Nov 20, 2025
@aosewski aosewski requested a review from Copilot November 20, 2025 14:40
Copilot AI left a comment

Pull Request Overview

This PR fixes crashes in MoE FlatMM operations when using bf16 data type on CDNA3 (gfx942) GPUs by adding validity checks for scatter operations and enabling bf16 hardware atomic-add operations.

Key changes:

  • Added validity flag tracking to prevent out-of-bounds atomic writes for invalid/padding tokens
  • Enabled hardware atomic builtin for bf16x2_t when available instead of always using software CAS fallback
  • Fixed typo in tensor size calculation for non-InputGemm case

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp | Added validity flag tracking for scatter operations and fixed tensor size typo |
| include/ck_tile/core/arch/generic_memory_space_atomic.hpp | Enabled bf16 hardware atomic builtin when available |
| example/ck_tile/18_flatmm/run_moe_flatmm_example.inc | Marked unused variable with [[maybe_unused]] |

Comment thread example/ck_tile/18_flatmm/run_moe_flatmm_example.inc Outdated
Comment thread include/ck_tile/core/arch/generic_memory_space_atomic.hpp
@msaffari-amd msaffari-amd merged commit f875ab0 into develop Nov 28, 2025
21 checks passed
@msaffari-amd msaffari-amd deleted the moe-flatmm-scatter-validity branch November 28, 2025 08:43
AviralGoelAMD pushed a commit that referenced this pull request Nov 28, 2025
…tomic-add (#3236)

* Add validity checks for MoE FlatMM scatter and enable bf16 hardware atomic

* correct clang-format

* removed unused rtol_atol variable from example code

* clang format correction

* remove unused variable max_accumulated_value from example