NSA forward — foundation, reference implementations, compact metadata#299
Open
jduprat wants to merge 1 commit into
Open
NSA forward — foundation, reference implementations, compact metadata#299jduprat wants to merge 1 commit into
jduprat wants to merge 1 commit into
Conversation
Summary: Establish the NSA (Native Sparse Attention) module with reference implementations, compact block-sparse metadata format, and the FA4-based forward pass orchestrator. Three attention branches combined via learned gating: 1. Compressed: FA4 on mean-pooled KV (short sequence) 2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile) 3. Sliding window: FA4 with window_size_left Key components: - compress.py: compress_kv() — mean-pool + optional learned projection - select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory - gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked - sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format (last dim = k selected blocks, not n_blocks_k total). Handles both expansion (compress_block_size >= n_block_size) and contraction (with sort + dedup). - nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper - reference.py: Pure PyTorch differentiable reference for correctness validation FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries internal fork, falls back to upstream flash_attn). Uses compress_factor for compressed causal masking (not mask_mod). All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16. No performance impact — this is the foundation diff (reference implementations only, no CuteDSL fused kernels yet). Performance chart N/A for this diff. Differential Revision: D99181841
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 2, 2026
…meta-pytorch#299) Summary: Establish the NSA (Native Sparse Attention) module with reference implementations, compact block-sparse metadata format, and the FA4-based forward pass orchestrator. Three attention branches combined via learned gating: 1. Compressed: FA4 on mean-pooled KV (short sequence) 2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile) 3. Sliding window: FA4 with window_size_left Key components: - compress.py: compress_kv() — mean-pool + optional learned projection - select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory - gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked - sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format (last dim = k selected blocks, not n_blocks_k total). Handles both expansion (compress_block_size >= n_block_size) and contraction (with sort + dedup). - nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper - reference.py: Pure PyTorch differentiable reference for correctness validation FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries internal fork, falls back to upstream flash_attn). Uses compress_factor for compressed causal masking (not mask_mod). All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16. No performance impact — this is the foundation diff (reference implementations only, no CuteDSL fused kernels yet). Performance chart N/A for this diff. Differential Revision: D99181841
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 2, 2026
…meta-pytorch#299) Summary: Establish the NSA (Native Sparse Attention) module with reference implementations, compact block-sparse metadata format, and the FA4-based forward pass orchestrator. Three attention branches combined via learned gating: 1. Compressed: FA4 on mean-pooled KV (short sequence) 2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile) 3. Sliding window: FA4 with window_size_left Key components: - compress.py: compress_kv() — mean-pool + optional learned projection - select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory - gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked - sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format (last dim = k selected blocks, not n_blocks_k total). Handles both expansion (compress_block_size >= n_block_size) and contraction (with sort + dedup). - nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper - reference.py: Pure PyTorch differentiable reference for correctness validation FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries internal fork, falls back to upstream flash_attn). Uses compress_factor for compressed causal masking (not mask_mod). All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16. No performance impact — this is the foundation diff (reference implementations only, no CuteDSL fused kernels yet). Performance chart N/A for this diff. Differential Revision: D99181841
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 2, 2026
…meta-pytorch#299) Summary: Establish the NSA (Native Sparse Attention) module with reference implementations, compact block-sparse metadata format, and the FA4-based forward pass orchestrator. Three attention branches combined via learned gating: 1. Compressed: FA4 on mean-pooled KV (short sequence) 2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile) 3. Sliding window: FA4 with window_size_left Key components: - compress.py: compress_kv() — mean-pool + optional learned projection - select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory - gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked - sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format (last dim = k selected blocks, not n_blocks_k total). Handles both expansion (compress_block_size >= n_block_size) and contraction (with sort + dedup). - nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper - reference.py: Pure PyTorch differentiable reference for correctness validation FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries internal fork, falls back to upstream flash_attn). Uses compress_factor for compressed causal masking (not mask_mod). All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16. No performance impact — this is the foundation diff (reference implementations only, no CuteDSL fused kernels yet). Performance chart N/A for this diff. Differential Revision: D99181841
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 2, 2026
…meta-pytorch#299) Summary: Establish the NSA (Native Sparse Attention) module with reference implementations, compact block-sparse metadata format, and the FA4-based forward pass orchestrator. Three attention branches combined via learned gating: 1. Compressed: FA4 on mean-pooled KV (short sequence) 2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile) 3. Sliding window: FA4 with window_size_left Key components: - compress.py: compress_kv() — mean-pool + optional learned projection - select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory - gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked - sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format (last dim = k selected blocks, not n_blocks_k total). Handles both expansion (compress_block_size >= n_block_size) and contraction (with sort + dedup). - nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper - reference.py: Pure PyTorch differentiable reference for correctness validation FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries internal fork, falls back to upstream flash_attn). Uses compress_factor for compressed causal masking (not mask_mod). All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16. No performance impact — this is the foundation diff (reference implementations only, no CuteDSL fused kernels yet). Performance chart N/A for this diff. Differential Revision: D99181841
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 2, 2026
…meta-pytorch#299) Summary: Establish the NSA (Native Sparse Attention) module with reference implementations, compact block-sparse metadata format, and the FA4-based forward pass orchestrator. Three attention branches combined via learned gating: 1. Compressed: FA4 on mean-pooled KV (short sequence) 2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile) 3. Sliding window: FA4 with window_size_left Key components: - compress.py: compress_kv() — mean-pool + optional learned projection - select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory - gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked - sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format (last dim = k selected blocks, not n_blocks_k total). Handles both expansion (compress_block_size >= n_block_size) and contraction (with sort + dedup). - nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper - reference.py: Pure PyTorch differentiable reference for correctness validation FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries internal fork, falls back to upstream flash_attn). Uses compress_factor for compressed causal masking (not mask_mod). All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16. No performance impact — this is the foundation diff (reference implementations only, no CuteDSL fused kernels yet). Performance chart N/A for this diff. Differential Revision: D99181841
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 2, 2026
…meta-pytorch#299) Summary: Establish the NSA (Native Sparse Attention) module with reference implementations, compact block-sparse metadata format, and the FA4-based forward pass orchestrator. Three attention branches combined via learned gating: 1. Compressed: FA4 on mean-pooled KV (short sequence) 2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile) 3. Sliding window: FA4 with window_size_left Key components: - compress.py: compress_kv() — mean-pool + optional learned projection - select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory - gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked - sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format (last dim = k selected blocks, not n_blocks_k total). Handles both expansion (compress_block_size >= n_block_size) and contraction (with sort + dedup). - nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper - reference.py: Pure PyTorch differentiable reference for correctness validation FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries internal fork, falls back to upstream flash_attn). Uses compress_factor for compressed causal masking (not mask_mod). All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16. No performance impact — this is the foundation diff (reference implementations only, no CuteDSL fused kernels yet). Performance chart N/A for this diff. Differential Revision: D99181841
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 2, 2026
…meta-pytorch#299) Summary: Establish the NSA (Native Sparse Attention) module with reference implementations, compact block-sparse metadata format, and the FA4-based forward pass orchestrator. Three attention branches combined via learned gating: 1. Compressed: FA4 on mean-pooled KV (short sequence) 2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile) 3. Sliding window: FA4 with window_size_left Key components: - compress.py: compress_kv() — mean-pool + optional learned projection - select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory - gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked - sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format (last dim = k selected blocks, not n_blocks_k total). Handles both expansion (compress_block_size >= n_block_size) and contraction (with sort + dedup). - nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper - reference.py: Pure PyTorch differentiable reference for correctness validation FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries internal fork, falls back to upstream flash_attn). Uses compress_factor for compressed causal masking (not mask_mod). All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16. No performance impact — this is the foundation diff (reference implementations only, no CuteDSL fused kernels yet). Performance chart N/A for this diff. Differential Revision: D99181841
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 3, 2026
…meta-pytorch#299) Summary: Establish the NSA (Native Sparse Attention) module with reference implementations, compact block-sparse metadata format, and the FA4-based forward pass orchestrator. Three attention branches combined via learned gating: 1. Compressed: FA4 on mean-pooled KV (short sequence) 2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile) 3. Sliding window: FA4 with window_size_left Key components: - compress.py: compress_kv() — mean-pool + optional learned projection - select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory - gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked - sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format (last dim = k selected blocks, not n_blocks_k total). Handles both expansion (compress_block_size >= n_block_size) and contraction (with sort + dedup). - nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper - reference.py: Pure PyTorch differentiable reference for correctness validation FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries internal fork, falls back to upstream flash_attn). Uses compress_factor for compressed causal masking (not mask_mod). All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16. This is the foundation diff (reference implementations only). NSA is faster than FA4 because its algorithmic complexity is lower. {F1987648002} Differential Revision: D99181841
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 3, 2026
…meta-pytorch#299) Summary: Establish the NSA (Native Sparse Attention) module with reference implementations, compact block-sparse metadata format, and the FA4-based forward pass orchestrator. Three attention branches combined via learned gating: 1. Compressed: FA4 on mean-pooled KV (short sequence) 2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile) 3. Sliding window: FA4 with window_size_left Key components: - compress.py: compress_kv() — mean-pool + optional learned projection - select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory - gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked - sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format (last dim = k selected blocks, not n_blocks_k total). Handles both expansion (compress_block_size >= n_block_size) and contraction (with sort + dedup). - nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper - reference.py: Pure PyTorch differentiable reference for correctness validation FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries internal fork, falls back to upstream flash_attn). Uses compress_factor for compressed causal masking (not mask_mod). All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16. This is the foundation diff (reference implementations only). NSA is faster than FA4 because its algorithmic complexity is lower. {F1987648002} Differential Revision: D99181841
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 3, 2026
…meta-pytorch#299) Summary: Establish the NSA (Native Sparse Attention) module with reference implementations, compact block-sparse metadata format, and the FA4-based forward pass orchestrator. Three attention branches combined via learned gating: 1. Compressed: FA4 on mean-pooled KV (short sequence) 2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile) 3. Sliding window: FA4 with window_size_left Key components: - compress.py: compress_kv() — mean-pool + optional learned projection - select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory - gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked - sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format (last dim = k selected blocks, not n_blocks_k total). Handles both expansion (compress_block_size >= n_block_size) and contraction (with sort + dedup). - nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper - reference.py: Pure PyTorch differentiable reference for correctness validation FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries internal fork, falls back to upstream flash_attn). Uses compress_factor for compressed causal masking (not mask_mod). All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16. This is the foundation diff (reference implementations only). NSA is faster than FA4 because its algorithmic complexity is lower. {F1987648002} Differential Revision: D99181841
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary:
Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.
Three attention branches combined via learned gating:
Key components:
(last dim = k selected blocks, not n_blocks_k total). Handles both expansion
(compress_block_size >= n_block_size) and contraction (with sort + dedup).
FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).
All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.
No performance impact — this is the foundation diff (reference implementations
only, no CuteDSL fused kernels yet). Performance chart N/A for this diff.
Differential Revision: D99181841