[Feature] Adopt FlexAttention native FA4 backend (BACKEND="FLASH") to replace manual CuTeDSL integration #30

@cicirori

Summary

PyTorch has officially released a native FlashAttention-4 (FA4) backend for FlexAttention, selected via kernel_options={"BACKEND": "FLASH"}. The backend generates the CuTeDSL score/mask functions automatically and JIT-instantiates the FA4 kernels, with reported speedups of 1.2×–3.2× over the Triton backend on Hopper and Blackwell GPUs.

TorchSpec currently has two separate code paths:

  • flex_attention backend — uses the Triton-based FlexAttention with compile_friendly_flex_attention
  • fa_experimental backend — manually imports flash_attn.cute.flash_attn_func and wires up Eagle3 mask_mod through CuTeDSL directly (LlamaFlashAttentionMasked)

The new native pathway could unify these two backends into a single FlexAttention path with an optional BACKEND="FLASH" flag, reducing code complexity while getting FA4-level performance.
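One way to picture the unified path is a single lookup from the configured backend name to the kernel_options passed to flex_attention. The sketch below is illustrative only: `resolve_kernel_options` and the backend names used as keys are hypothetical, not existing TorchSpec identifiers. Only the `{"BACKEND": "FLASH"}` option itself comes from the text above.

```python
from typing import Any, Dict

# Hypothetical mapping from a config-level backend name to the
# kernel_options dict handed to flex_attention.
_KERNEL_OPTIONS: Dict[str, Dict[str, Any]] = {
    "flex_attention": {},                # no override: PyTorch picks its default backend
    "flex_flash": {"BACKEND": "FLASH"},  # request the native FA4 backend
}


def resolve_kernel_options(backend: str) -> Dict[str, Any]:
    """Return a fresh kernel_options dict for the given backend name."""
    if backend not in _KERNEL_OPTIONS:
        raise ValueError(
            f"unknown attention backend {backend!r}; "
            f"expected one of {sorted(_KERNEL_OPTIONS)}"
        )
    # Copy so callers can mutate the result without touching the table.
    return dict(_KERNEL_OPTIONS[backend])
```

With a helper like this, both backends would share the same call site and differ only in the dict passed as kernel_options.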

Motivation

  1. Simpler code — The current fa_experimental path requires manual CuTeDSL integration, custom forward/backward wiring, compilation patching (_patch_cutlass_compilation), and pre-compilation warmup (precompile_flash_attn_masked). The native pathway handles all of this automatically via torch.compile.

  2. Better Blackwell performance — The blog reports 2.2×–3.2× speedups on GB200 vs Triton FlexAttention. Since TorchSpec already targets SM100, this is directly relevant.

  3. Maintainability — The manual CuTeDSL integration is tightly coupled to flash_attn.cute internals (e.g., _flash_attn_fwd, _flash_attn_bwd, BlockSparseTensorsTorch). The native pathway is a stable PyTorch API.

Proposed Changes

  • Add a new attention backend option (e.g., flex_flash) that uses flex_attention with kernel_options={"BACKEND": "FLASH"}
  • Reuse the existing score_mod / mask_mod functions from generate_eagle3_mask — they should work as-is
  • Benchmark against the current fa_experimental backend to validate performance parity
  • If performance is equivalent or better, consider deprecating fa_experimental in favor of the unified path
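For the benchmarking step, a small timing helper along these lines might be enough to compare the two backends. This is a minimal sketch, not an existing TorchSpec utility: `median_latency_ms` and its parameters are hypothetical names. The warmup runs are there to keep one-time compilation and JIT cost out of the measurement.

```python
import statistics
import time
from typing import Callable, Optional


def median_latency_ms(
    fn: Callable[[], object],
    *,
    warmup: int = 3,
    iters: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Median wall-clock latency of fn() in milliseconds.

    sync, if given, is called before the clock is read so that
    asynchronous work (e.g. queued GPU kernels) is included.
    """
    # Warmup calls absorb one-time costs such as compilation.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()

    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)
```

On GPU, one would pass sync=torch.cuda.synchronize and wrap each backend's forward call in a zero-argument lambda, then compare the two medians across the sequence lengths TorchSpec actually trains on.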

Example

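from functools import partial

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Compile FlexAttention with the native FA4 backend selected.
flex_flash = torch.compile(
    partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
    dynamic=False,
)

# The existing Eagle3 mask_mod is expected to work directly.
# generate_eagle3_mask is the existing TorchSpec helper; seq_lengths, B, H,
# Q_LEN, KV_LEN, query, key and value are assumed to be defined by the caller.
mask_mod = generate_eagle3_mask(seq_lengths, Q_LEN, KV_LEN)
block_mask = create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device="cuda")
out = flex_flash(query, key, value, block_mask=block_mask)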
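For readers less familiar with FlexAttention, the mask_mod passed to create_block_mask is a callable of the form mask_mod(b, h, q_idx, kv_idx) that returns whether a query position may attend to a key position. The sketch below shows that contract with a generic padded causal mask. It is not the Eagle3 mask, and `make_padded_causal_mask_mod` is a hypothetical name used only for illustration.

```python
def make_padded_causal_mask_mod(seq_lengths):
    """Build a mask_mod: causal attention limited to each sequence's valid length.

    seq_lengths is indexable by batch index. In real use it would be an
    integer tensor on the same device as the attention inputs.
    """

    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx               # no attending to future positions
        in_bounds = kv_idx < seq_lengths[b]    # ignore padded key positions
        return causal & in_bounds

    return mask_mod
```

Whether a particular mask_mod is accepted unchanged by the FLASH backend still needs to be verified against the existing Triton path, which is what the proposed benchmark and parity check would cover.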
