Skip to content

NSA backward — benchmarks and performance documentation (#307)#307

Open
jduprat wants to merge 9 commits into
meta-pytorch:mainfrom
jduprat:export-D99181849
Open

NSA backward — benchmarks and performance documentation (#307)#307
jduprat wants to merge 9 commits into
meta-pytorch:mainfrom
jduprat:export-D99181849

Conversation

@jduprat
Copy link
Copy Markdown
Contributor

@jduprat jduprat commented Apr 2, 2026

Summary:

Add backward performance measurement, benchmarking, and visualization utilities.

Benchmark scripts:

  • bench_nsa_backward.py: Forward + fwd+bwd timing at N=1K-65K, reporting
    wall-clock ms and dense-equivalent TFLOPS
  • bench_nsa_vs_dense.py: NSA vs dense FA4 latency from 1K to 3M tokens,
    the primary metric for this project
  • bench_sparse_compressed.py: Compressed branch benchmarks
  • bench_2m_3m.py: Large sequence benchmarks (N >= 2M)

Performance visualization:

  • docs/gen_nsa_perf_chart.py: SVG chart generator showing NSA vs dense FA4
    latency, forward speedup curve, and fwd+bwd breakdown with OOM boundary

gc.collect + expandable_segments + pre-warmup in all benchmarks for
accurate timing at large N.

No performance impact on fwd pass.
fwd+bwd reaching 5-7x at the same scale. Crossover point at ~24K tokens.

{F1987648647}

Differential Revision: D99181849

@meta-codesync
Copy link
Copy Markdown

meta-codesync Bot commented Apr 2, 2026

@jduprat has exported this pull request. If you are a Meta employee, you can view the originating Diff in D99181849.

jduprat added 9 commits April 2, 2026 07:06
Summary:
Establish the NSA (Native Sparse Attention) module with reference implementations,
compact block-sparse metadata format, and the FA4-based forward pass orchestrator.

Three attention branches combined via learned gating:
1. Compressed: FA4 on mean-pooled KV (short sequence)
2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile)
3. Sliding window: FA4 with window_size_left

Key components:
- compress.py: compress_kv() — mean-pool + optional learned projection
- select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory
- gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked
- sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format
  (last dim = k selected blocks, not n_blocks_k total). Handles both expansion
  (compress_block_size >= n_block_size) and contraction (with sort + dedup).
- nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper
- reference.py: Pure PyTorch differentiable reference for correctness validation

FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries
internal fork, falls back to upstream flash_attn). Uses compress_factor for
compressed causal masking (not mask_mod).

All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16.

No performance impact — this is the foundation diff (reference implementations
only, no CuteDSL fused kernels yet). Performance chart N/A for this diff.

Differential Revision: D99181841
Summary:
Replace the multi-kernel PyTorch gating (compute_gates + gate_and_combine) with
a single fused CuteDSL kernel (fused_gate_and_combine) — 4-7x faster on B200.

Key design:
- One warp (32 threads) per (b,n,h) row — each warp handles one output position
- Warp-shuffle butterfly reduction for 3 gate dot-products (no shared memory)
- elems_per_thread = D // 32, staying in registers (4 for D=128)
- Sigmoid via log2-exp2 trick: uses fast hardware exp2
- All accumulation in Float32 for numerical stability with bf16/fp16 inputs
- In-memory compile cache keyed by (dtype, D, has_gate_weight)

When gate_proj_weight is None, skips the CuteDSL kernel entirely and returns
a simple (O_cmp + O_slc + O_sld) / 3 average — avoids kernel launch overhead
for the ungated case.

Returns (output, gates) tuple so gates are available for the backward pass.

PyTorch reference implementations (compute_gates, gate_and_combine) retained
for testing and fallback.

No performance chart yet — gating alone is not the bottleneck. Chart will be
updated after fused scoring (Diff 3) and fused compression (Diff 4).

Differential Revision: D99181847
Summary:
Replace ~10 kernel launches per query tile with a GEMM-based scoring pipeline
using the Q_mean algebraic identity: mean(Q @ K) = mean(Q) @ K, reducing
scoring from a (q_tile_size, D) x (D, N_cmp) GEMM per tile to a single
(D,) . (D, N_cmp) GEMV — 256x fewer FLOPs.

Key components:
- _compute_q_mean(): Single PyTorch kernel computes per-tile mean of Q in
  fp32. Supports both 4D fixed-length and 3D varlen (with cu_seqlens).
- _score_and_topk(): GQA-aware bmm that folds GQA groups into the M
  dimension of the GEMM, avoiding K_cmp expansion from H_kv to H heads:
  (B*H_kv, n_tiles*groups, D) @ (B*H_kv, D, N_cmp). cuBLAS GEMM.
- fused_score_and_select_blocks(): Unified entry for selected branch.
- fused_score_and_select_all(): Computes GEMM once, derives indices for
  both selected and compressed branches (avoids duplicate GEMM).

Chunked processing (64 Q-tiles per chunk) bounds peak memory.
All scoring in fp32 for numerical stability with bf16/fp16 inputs.

NSA becomes faster than dense FA4 at ~24K tokens with this optimization;
reaches 12.5x speedup at 512K (was barely faster before due to scoring
bottleneck). Performance chart will be updated after Diff 4.

Differential Revision: D99181843
Summary:
Replace two PyTorch mean() kernel launches (one for K, one for V) with a single
fused CuteDSL kernel that processes both K and V simultaneously.

Key design:
- Grid: one thread block per output element (b, j, h) — total B * N_cmp * H_kv
- Block: 128 threads, each handling ceil(D/128) elements
- Float32 accumulation: each thread reads block_size input positions, accumulates
  K and V values in Float32, divides by inv_block_size, writes the mean
- 2D flattening for CuteDSL: K/V reshaped to (B*N*H_kv, D), manual index
  decomposition (b, j, h) from bidx
- W_k/W_v projections remain as torch.einsum (cuBLAS GEMM — already optimal)
- Varlen path: PyTorch per-sequence loop (CuteDSL varlen kernel in future diff)
- In-memory compile cache keyed by (dtype, D, block_size)

No performance chart update — chart will be generated after Diff 5 (int32
overflow fix) when we can benchmark at N >= 2M.

Differential Revision: D99181839
Summary:
CuTe tensor indexing uses int32 by default for offset computation. When
row_index * stride exceeds INT32_MAX (2,147,483,647), the offset overflows,
causing cudaErrorIllegalAddress. This crashes NSA at sequence lengths >= 2M
with typical configs (B=1, H=8, D=128).

Fix: cast row indices to cutlass.Int64() BEFORE CuTe tensor subscript access
in the fused gating kernel. This causes CuTe to compute linear offsets in
int64, preventing overflow.

Also adds E2E benchmark and memory probe utilities:
- bench_sparse_attn_e2e.py: Lean benchmark measuring Dense FA4 vs NSA at
  N = [2M, 4M, 8M, 16M]. Manages memory with gc.collect + empty_cache.
- probe_max_seqlen.py: Binary search for max sequence length on one GPU.

With the overflow fix, NSA achieves 16.2x speedup (l=64) and 29.4x (l=128)
over dense FA4 at N=3M on B200.

Differential Revision: D99181853
Summary:
Implement backward kernels for the two auxiliary NSA operations: KV compression
and output gating.

Compression backward (fused_compress_kv_backward):
- W_k/W_v projection gradients via torch.einsum in fp32 (cuBLAS GEMM)
- Mean-pool scatter: broadcast dK_cmp/dV_cmp back to original positions with
  1/block_size scaling
- Varlen variant: _fused_compress_kv_backward_varlen with per-sequence scatter

Gating backward (fused_gating_backward):
- Pure PyTorch implementation using fp32 throughout
- Computes dO_cmp, dO_slc, dO_sld (gradient routing: dO_i = g_i * dO)
- Sigmoid derivative: d_logit_i = (dO · O_i) * g_i * (1 - g_i)
- dQ_gate = d_logit @ W (cuBLAS GEMM)
- dW_gate = d_logit.T @ Q (cuBLAS cross-row reduction GEMM)
- dgate_proj_weight computation stays as torch.einsum — already optimal as
  cuBLAS GEMM, not on critical path

No performance chart impact — these are auxiliary backward kernels needed by
the autograd function (Diff 8).

Differential Revision: D99181842
Summary:
Build the FA4 backward infrastructure: the _fa4_bwd() wrapper, the sparse index
inversion for block-sparse backward, and compressed-branch sparsity tensors.

Key components:
- _fa4_bwd(): Wrapper importing _flash_attn_bwd from the mslk.attention.flash_attn
  shim. Supports block_sparse_tensors, compress_factor, mask_mod, varlen.
  compress_factor for PR#2418 — graceful fallback if not supported.

- _transpose_block_sparse_for_bwd(): Sparse inverted-index construction that
  converts forward (Q-tile → KV-blocks) to backward (KV-block → Q-tiles) format.
  O(total_entries) time and memory, no dense intermediate. Python loops over B*H
  (small: typically 1*32=32) — will be vectorized in a future optimization diff.

- build_compressed_block_sparse_tensors(): Builds FA4 block-sparse tensors for
  the compressed branch. All blocks in full_block (causal masking handled by
  compress_factor, not mask_mod). Compact index format.

No performance chart impact — backward infrastructure only, will be wired into
the autograd function in Diff 8.

Differential Revision: D99181838
Summary:
Wire everything together into a trainable NSA function with full forward + backward.

NSAFunction(torch.autograd.Function) + nsa() user API:
- Forward: CUDA stream overlap for 3 independent FA4 branches (compressed,
  selected, sliding window). Activation checkpointing saves gates + config,
  recomputes FA4 outputs in backward.
- Backward: All 3 branches use FA4 from the start (no dead-end detours):
  * Compressed: _fa4_bwd() with compress_factor + optional block_sparse_tensors
  * Selected: _fa4_bwd() with block_sparse_tensors (transposed index from Diff 7)
  * Sliding window: _fa4_bwd() with window_size_left
- Sequential per-branch backward execution to reduce peak memory — no CUDA
  streams in backward. Each branch freed with del after computing grads.
- fp32 gradient accumulation with sequential upcasting: dQ_acc = dQ_cmp.float();
  dQ_acc += dQ_slc.float(); etc. Each bf16 summand freed after upcast-add.

Varlen support:
- _nsa_forward_varlen(): Pad compressed branch to 4D (mask_mod blocks varlen
  in FA4 backward), native FA4 varlen for selected + sliding window.
- _nsa_backward_varlen(): Same pad/unpad strategy for backward.
- _pad_to_4d() / _unpad_to_3d() utilities (Python loops — will be vectorized
  in future optimization diff).

No performance chart update — chart will be generated in Diff 9 (benchmarks).

Differential Revision: D99181840
…#307)

Summary:
Pull Request resolved: meta-pytorch#307

Add backward performance measurement, benchmarking, and visualization utilities.

Benchmark scripts:
- bench_nsa_backward.py: Forward + fwd+bwd timing at N=1K-65K, reporting
  wall-clock ms and dense-equivalent TFLOPS
- bench_nsa_vs_dense.py: NSA vs dense FA4 latency from 1K to 3M tokens,
  the primary metric for this project
- bench_sparse_compressed.py: Compressed branch benchmarks
- bench_2m_3m.py: Large sequence benchmarks (N >= 2M)

Performance visualization:
- docs/gen_nsa_perf_chart.py: SVG chart generator showing NSA vs dense FA4
  latency, forward speedup curve, and fwd+bwd breakdown with OOM boundary

gc.collect + expandable_segments + pre-warmup in all benchmarks for
accurate timing at large N.

No performance impact on fwd pass.
fwd+bwd reaching 5-7x at the same scale. Crossover point at ~24K tokens.

 {F1987648647}

Differential Revision: D99181849
@jduprat jduprat force-pushed the export-D99181849 branch from d6a9bc3 to d5ad4d9 Compare April 3, 2026 01:01
@meta-codesync meta-codesync Bot changed the title NSA backward — benchmarks and performance documentation NSA backward — benchmarks and performance documentation (#307) Apr 3, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant