Fused CuteDSL gating kernel (#300)

Open: jduprat wants to merge 2 commits into meta-pytorch:main from jduprat:export-D99181847

@jduprat (Contributor) commented Apr 2, 2026

Summary:

Replace the multi-kernel PyTorch gating (compute_gates + gate_and_combine) with
a single fused CuteDSL kernel (fused_gate_and_combine) — 4-7x faster on B200.

Key design:

  • One warp (32 threads) per (b,n,h) row — each warp handles one output position
  • Warp-shuffle butterfly reduction for 3 gate dot-products (no shared memory)
  • elems_per_thread = D // 32, staying in registers (4 for D=128)
  • Sigmoid via log2-exp2 trick: uses fast hardware exp2
  • All accumulation in Float32 for numerical stability with bf16/fp16 inputs
  • In-memory compile cache keyed by (dtype, D, has_gate_weight)
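The per-row scheme in the list above can be illustrated with a lane-level simulation in plain Python. This is only a sketch under stated assumptions, not the CuteDSL kernel: the helper names (`butterfly_allreduce`, `sigmoid_exp2`, `gate_row`, `combine`) are hypothetical, the gate logits are assumed to be dot products between a length-D feature row `x` and three weight rows `w`, the combine step is assumed to be a gate-weighted sum, and the "log2-exp2 trick" is read as rewriting exp(-x) as 2^(-x * log2(e)). Python floats are double precision, so the Float32 accumulation is not reproduced here.

```python
import math

WARP_SIZE = 32                  # threads (lanes) per warp
LOG2_E = math.log2(math.e)      # constant that turns exp(x) into exp2(x * LOG2_E)


def butterfly_allreduce(lane_vals):
    """Simulate an XOR-shuffle butterfly sum across 32 lanes.

    After log2(32) = 5 steps every lane holds the full sum, with no
    shared-memory staging.
    """
    vals = list(lane_vals)
    offset = WARP_SIZE // 2
    while offset >= 1:
        # Each lane adds the value held by its partner lane (lane XOR offset).
        vals = [vals[lane] + vals[lane ^ offset] for lane in range(WARP_SIZE)]
        offset //= 2
    return vals


def sigmoid_exp2(x):
    """sigmoid(x) = 1 / (1 + 2**(-x * log2(e))), i.e. exp expressed via exp2."""
    return 1.0 / (1.0 + 2.0 ** (-x * LOG2_E))


def gate_row(x, w, D=128):
    """Compute three gates for one (b, n, h) row of length D (assumed layout)."""
    elems_per_thread = D // WARP_SIZE   # 4 for D=128
    gates = []
    for g in range(3):
        # Lane t owns elements [t * elems_per_thread, (t + 1) * elems_per_thread).
        partial = [
            sum(x[t * elems_per_thread + j] * w[g][t * elems_per_thread + j]
                for j in range(elems_per_thread))
            for t in range(WARP_SIZE)
        ]
        logit = butterfly_allreduce(partial)[0]
        gates.append(sigmoid_exp2(logit))
    return gates


def combine(gates, o_cmp, o_slc, o_sld):
    """Gate-weighted sum of the three branch outputs (assumed combine rule)."""
    g_cmp, g_slc, g_sld = gates
    return [g_cmp * a + g_slc * b + g_sld * c
            for a, b, c in zip(o_cmp, o_slc, o_sld)]
```

Because every lane ends up with the same reduced value, each lane can apply the gates to its own `elems_per_thread` output elements without a broadcast step.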

When gate_proj_weight is None, skips the CuteDSL kernel entirely and returns
a simple (O_cmp + O_slc + O_sld) / 3 average — avoids kernel launch overhead
for the ungated case.

Returns (output, gates) tuple so gates are available for the backward pass.
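The dispatch described above might look roughly like the following plain-Python sketch. The function name `dispatch_gate_and_combine` and the `fused_kernel` callable are hypothetical, the inputs are flat lists rather than tensors, and returning `None` for the gates on the ungated path is an assumption: the description does not say what that path returns for gates.

```python
def dispatch_gate_and_combine(o_cmp, o_slc, o_sld,
                              gate_proj_weight=None, fused_kernel=None):
    """Return an (output, gates) tuple, skipping the kernel when ungated.

    Hypothetical wrapper: `fused_kernel` stands in for the compiled CuteDSL
    kernel and is expected to return (output, gates) itself.
    """
    if gate_proj_weight is None:
        # Ungated case: plain average of the three branches, no kernel launch.
        output = [(a + b + c) / 3.0 for a, b, c in zip(o_cmp, o_slc, o_sld)]
        return output, None  # assumption: no gates to hand to the backward pass
    if fused_kernel is None:
        raise ValueError("a fused kernel is required when gate_proj_weight is given")
    return fused_kernel(o_cmp, o_slc, o_sld, gate_proj_weight)


# Usage: the ungated path never touches the kernel.
output, gates = dispatch_gate_and_combine([3.0, 6.0], [0.0, 3.0], [0.0, 0.0])
```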

PyTorch reference implementations (compute_gates, gate_and_combine) retained
for testing and fallback.

No end-to-end performance impact: the kernel itself is 4-7x faster, but gating alone is not the bottleneck.
{F1987648122}

Differential Revision: D99181847

@meta-cla meta-cla Bot added the cla signed label Apr 2, 2026

meta-codesync Bot commented Apr 2, 2026

@jduprat has exported this pull request. If you are a Meta employee, you can view the originating Diff in D99181847.

jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 2, 2026
Summary:

Replace the multi-kernel PyTorch gating (compute_gates + gate_and_combine) with
a single fused CuteDSL kernel (fused_gate_and_combine) — 4-7x faster on B200.

Key design:
- One warp (32 threads) per (b,n,h) row — each warp handles one output position
- Warp-shuffle butterfly reduction for 3 gate dot-products (no shared memory)
- elems_per_thread = D // 32, staying in registers (4 for D=128)
- Sigmoid via log2-exp2 trick: uses fast hardware exp2
- All accumulation in Float32 for numerical stability with bf16/fp16 inputs
- In-memory compile cache keyed by (dtype, D, has_gate_weight)

When gate_proj_weight is None, skips the CuteDSL kernel entirely and returns
a simple (O_cmp + O_slc + O_sld) / 3 average — avoids kernel launch overhead
for the ungated case.

Returns (output, gates) tuple so gates are available for the backward pass.

PyTorch reference implementations (compute_gates, gate_and_combine) retained
for testing and fallback.

No performance chart yet — gating alone is not the bottleneck. Chart will be
updated after fused scoring (Diff 3) and fused compression (Diff 4).

Differential Revision: D99181847
jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 2, 2026
Summary:
Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left
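As an illustration of the first branch, mean-pooling a K or V sequence into a shorter one can be sketched in plain Python. The function name `mean_pool_kv` is hypothetical, and the choices of non-overlapping blocks and of averaging a partial tail block over its actual length are assumptions; the optional learned projection is left out.

```python
def mean_pool_kv(rows, block_size):
    """Mean-pool a sequence of vectors over consecutive blocks.

    `rows` is a list of equal-length vectors (one per token). Blocks are
    assumed to be non-overlapping; a shorter tail block is averaged over
    the rows it actually contains.
    """
    pooled = []
    for start in range(0, len(rows), block_size):
        block = rows[start:start + block_size]
        dim = len(block[0])
        pooled.append([sum(r[d] for r in block) / len(block) for d in range(dim)])
    return pooled


# Usage: 5 tokens with block_size=2 give 3 compressed positions.
compressed = mean_pool_kv(
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]], 2)
```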

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation
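The expansion and contraction cases mentioned for sparsity_masks.py can be sketched for a single Q-tile as follows. This is a guess at the index arithmetic, not the code in this diff: the function name `to_n_block_indices` is hypothetical, and the block sizes are assumed to divide each other evenly.

```python
def to_n_block_indices(selected, compress_block_size, n_block_size):
    """Map selected compress-block indices to attention n-block indices.

    Expansion (compress_block_size >= n_block_size): each selected block
    covers several n-blocks, so every index fans out to a contiguous range.
    Contraction (compress_block_size < n_block_size): several selected
    blocks can land in the same n-block, so the result is sorted and
    deduplicated.
    """
    if compress_block_size >= n_block_size:
        ratio = compress_block_size // n_block_size  # assumes exact divisibility
        expanded = []
        for c in selected:
            expanded.extend(range(c * ratio, (c + 1) * ratio))
        return expanded
    ratio = n_block_size // compress_block_size      # assumes exact divisibility
    return sorted({c // ratio for c in selected})


# Usage: expansion keeps the selection order, contraction sorts and dedups.
expanded = to_n_block_indices([2, 0], 128, 64)           # [4, 5, 0, 1]
contracted = to_n_block_indices([5, 4, 0, 9], 32, 128)   # [0, 1, 2]
```

In the contraction case the output can be shorter than the input, which is why a compact format has to tolerate fewer than k valid entries per row.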

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

No performance impact — this is the foundation diff (reference implementations
only, no CuteDSL fused kernels yet). Performance chart N/A for this diff.

Differential Revision: D99181841
@meta-codesync meta-codesync Bot changed the title from "Fused CuteDSL gating kernel" to "Fused CuteDSL gating kernel (#300)" Apr 3, 2026
jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 3, 2026
@jduprat jduprat force-pushed the export-D99181847 branch from b735721 to 3f1f8ec Compare April 3, 2026 00:46
jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 3, 2026
Summary:
Pull Request resolved: meta-pytorch#300

@jduprat jduprat force-pushed the export-D99181847 branch from 3f1f8ec to 7afab26 Compare April 3, 2026 00:51

1 participant