Skip to content

Fix flaky dropout tests#3289

Open
micmelesse wants to merge 1 commit into
mainfrom
micmelesse/fix_fa_main
Open

Fix flaky dropout tests#3289
micmelesse wants to merge 1 commit into
mainfrom
micmelesse/fix_fa_main

Conversation

@micmelesse
Copy link
Copy Markdown
Contributor

@micmelesse micmelesse commented May 20, 2026

Motivation

This pr fixes a few flaky tests that have been failing on main due to mis-compilation of tl.where. The tl.where calls are replaced with semantically equivalent code.

Technical Details

Test Plan

I tested locally on an MI350 machine.

Test Result

Submission Checklist

@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3289 --add-label <label>

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR aims to eliminate flakiness in dropout-related FlashAttention backward paths by removing tl.where usage in dropout masking (reported to miscompile) and replacing it with equivalent mask-multiply logic in Triton kernels.

Changes:

  • Replaced tl.where(dropout_mask, x, 0.0) * dropout_scale with x * dropout_mask.to(x.dtype) * dropout_scale in multiple backward inner loops.
  • Applied the change consistently across split-k and atomic variants of the backward kernels.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@micmelesse micmelesse force-pushed the micmelesse/fix_fa_main branch from 22928ce to a915762 Compare May 21, 2026 20:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants