Skip to content

Fix Flash Attention 3 API compatibility for window size parameters#2704

Open
jhvmhg wants to merge 8 commits intoNVIDIA:mainfrom
jhvmhg:fix/flash_attn3_support_CP
Open

Fix Flash Attention 3 API compatibility for window size parameters#2704
jhvmhg wants to merge 8 commits intoNVIDIA:mainfrom
jhvmhg:fix/flash_attn3_support_CP

Conversation

@jhvmhg
Copy link

@jhvmhg jhvmhg commented Feb 25, 2026

Replace single window_size parameter with window_size_left and window_size_right in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.

  • Update function signature in flash_attn_interface
  • Maintain backward compatibility where possible
  • Ensure consistency with Flash Attention v2 implementation

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

  1. Fix window size parameters in flash_attn_fwd - Replaces the single window_size parameter with separate window_size_left and window_size_right parameters to match the updated flash-attn v2.7.0+ API.
  2. Fix causal parameter naming in flash_attn_bwd - Renames causal to is_causal in the backward function signature for consistency with the latest flash-attn interface.

Motivation:

The flash-attn library v2.7.0+ introduced breaking API changes that cause compatibility issues with TransformerEngine's Flash Attention 3 integration. These updates ensure seamless operation with newer versions of the flash-attn library while maintaining correctness of both forward and backward attention computations.

Related API Changes:

flash-attn v2.7.0+ split window_size into window_size_left and window_size_right
flash-attn v3+ renamed causal parameter to is_causal in backward pass

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Replace single window_size parameter with window_size_left and window_size_right
    in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.
  • Rename causal parameter to is_causal in flash_attn_bwd function to align
    with flash-attn v3

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Replace single window_size parameter with window_size_left and window_size_right
in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.

- Update function signature in flash_attn_interface
- Maintain backward compatibility where possible
- Ensure consistency with Flash Attention v2 implementation

Signed-off-by: Chaoyang Mei <1192554423@qq.com>
Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 25, 2026

Greptile Summary

Updates Flash Attention 3 integration to use the new API introduced in flash-attn v2.7.0+, replacing single window_size parameter with window_size_left/window_size_right and using is_causal instead of causal in backward pass.

Key changes:

  • Reorganizes version checks so Flash Attn 3 uses new API style (window_size_left/window_size_right)
  • Adds conditional logic to set is_causal for Flash Attn 3 and causal for older versions in backward functions
  • Properly converts ctx.attn_mask_type string to boolean using "causal" in ctx.attn_mask_type
  • Changes are applied consistently across all affected functions: cp_p2p_fwd_flash_attn, cp_p2p_bwd_flash_attn, AttnFuncWithCPAndKVP2P, AttnFuncWithCPAndKVAllGather, and AttnFuncWithCPAndQKVOA2A

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk - it's a well-executed API compatibility fix with correct logic
  • Changes are straightforward API parameter updates applied consistently across all functions. The conditional logic correctly maps Flash Attn versions to appropriate APIs, type conversions are correct, and the previous review comment has been addressed.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Updates Flash Attention 3 to use new API: window_size_left/window_size_right instead of window_size, and is_causal instead of causal. Changes are consistent across all functions.

Last reviewed commit: 1c66c29

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 25, 2026

Additional Comments (1)

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
removed causal parameter but other flash_attn_bwd calls in this file (lines 3222, 3832) still pass it - verify this inconsistency is intentional

Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@jhvmhg jhvmhg force-pushed the fix/flash_attn3_support_CP branch from a245229 to f9752ca Compare February 25, 2026 07:54
Copy link
Author

@jhvmhg jhvmhg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix Flash Attention 3 backward API parameter naming

Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

jhvmhg and others added 2 commits February 25, 2026 15:56
Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Replace keyword arguments with positional arguments in flash_attn_fwd and
flash_attn_bwd to abstract away parameter naming differences (causal vs
is_causal) between flash-attn versions. This provides a more robust
interface that is resilient to future API changes in the flash-attn library.

- Convert window_size_left, window_size_right, and causal parameters to
  positional args in both forward and backward functions
- Eliminate version-specific parameter naming dependencies
- Simplify compatibility handling across flash-attn v2.7.0+ variants

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

softmax_lse_per_step[i],
*fa_backward_args_thd,
causal="causal" in ctx.attn_mask_type,
ctx.attn_mask_type,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx.attn_mask_type is a string (e.g., "causal", "no_mask"), not a boolean. Should be "causal" in ctx.attn_mask_type to convert to boolean.

Suggested change
ctx.attn_mask_type,
"causal" in ctx.attn_mask_type,

@jhvmhg jhvmhg closed this Feb 25, 2026
Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v3 API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@jhvmhg jhvmhg reopened this Feb 25, 2026
@jhvmhg jhvmhg closed this Feb 25, 2026
@jhvmhg jhvmhg reopened this Feb 25, 2026
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

@ptrendx ptrendx requested a review from cyanguwa February 25, 2026 19:57
@cyanguwa cyanguwa requested a review from mk-61 February 26, 2026 00:04
@cyanguwa
Copy link
Collaborator

@mk-61 I think the changes look good, but could you please follow through with the CI, especially the L3_FA_version tests, to make sure the new changes pass the SWA tests for FA3? Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants