Skip to content

NSA forward — foundation, reference implementations, compact metadata#299

Open
jduprat wants to merge 1 commit into
meta-pytorch:mainfrom
jduprat:export-D99181841
Open

NSA forward — foundation, reference implementations, compact metadata#299
jduprat wants to merge 1 commit into
meta-pytorch:mainfrom
jduprat:export-D99181841

Conversation

@jduprat
Copy link
Copy Markdown
Contributor

@jduprat jduprat commented Apr 2, 2026

Summary:
Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:

  1. Compressed: FA4 on mean-pooled KV (short sequence)
  2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
  3. Sliding window: FA4 with window_size_left

Key components:

  • compress.py: compress_kv() — mean-pool + optional learned projection
  • select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
  • gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
  • sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
    (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
    (compress_block_size >= n_block_size) and contraction (with sort + dedup).
  • nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
  • reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

No performance impact — this is the foundation diff (reference implementations
only, no CuteDSL fused kernels yet). Performance chart N/A for this diff.

Differential Revision: D99181841

Summary:
Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

No performance impact — this is the foundation diff (reference implementations
only, no CuteDSL fused kernels yet). Performance chart N/A for this diff.

Differential Revision: D99181841
@meta-cla meta-cla Bot added the cla signed label Apr 2, 2026
@meta-codesync
Copy link
Copy Markdown

meta-codesync Bot commented Apr 2, 2026

@jduprat has exported this pull request. If you are a Meta employee, you can view the originating Diff in D99181841.

jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 2, 2026
…meta-pytorch#299)

Summary:

Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

No performance impact — this is the foundation diff (reference implementations
only, no CuteDSL fused kernels yet). Performance chart N/A for this diff.

Differential Revision: D99181841
jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 2, 2026
…meta-pytorch#299)

Summary:

Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

No performance impact — this is the foundation diff (reference implementations
only, no CuteDSL fused kernels yet). Performance chart N/A for this diff.

Differential Revision: D99181841
jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 2, 2026
…meta-pytorch#299)

Summary:

Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

No performance impact — this is the foundation diff (reference implementations
only, no CuteDSL fused kernels yet). Performance chart N/A for this diff.

Differential Revision: D99181841
jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 2, 2026
…meta-pytorch#299)

Summary:

Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

No performance impact — this is the foundation diff (reference implementations
only, no CuteDSL fused kernels yet). Performance chart N/A for this diff.

Differential Revision: D99181841
jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 2, 2026
…meta-pytorch#299)

Summary:

Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

No performance impact — this is the foundation diff (reference implementations
only, no CuteDSL fused kernels yet). Performance chart N/A for this diff.

Differential Revision: D99181841
jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 2, 2026
…meta-pytorch#299)

Summary:

Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

No performance impact — this is the foundation diff (reference implementations
only, no CuteDSL fused kernels yet). Performance chart N/A for this diff.

Differential Revision: D99181841
jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 2, 2026
…meta-pytorch#299)

Summary:

Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

No performance impact — this is the foundation diff (reference implementations
only, no CuteDSL fused kernels yet). Performance chart N/A for this diff.

Differential Revision: D99181841
jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 3, 2026
…meta-pytorch#299)

Summary:

Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

This is the foundation diff (reference implementations only).
NSA is faster than FA4 because its algorithmic complexity is lower.

 {F1987648002}

Differential Revision: D99181841
jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 3, 2026
…meta-pytorch#299)

Summary:

Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

This is the foundation diff (reference implementations only).
NSA is faster than FA4 because its algorithmic complexity is lower.

 {F1987648002}

Differential Revision: D99181841
jduprat added a commit to jduprat/MSLK that referenced this pull request Apr 3, 2026
…meta-pytorch#299)

Summary:

Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

This is the foundation diff (reference implementations only).
NSA is faster than FA4 because its algorithmic complexity is lower.

 {F1987648002}

Differential Revision: D99181841
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant