Update benchmarks to 1M tokens, add memory diagnostics#297
Open
jduprat wants to merge 28 commits into
Open
Conversation
Summary: Move the sparse_attn kernel from fb/mslk/attention/sparse_attn to mslk/attention/sparse_attn, matching the convention used by other kernels in this directory. Update all imports from mslk.fb.mslk.attention.sparse_attn to mslk.attention.sparse_attn, update BUCK targets so the library is defined directly in mslk/BUCK instead of as a shim forwarding to mslk/fb, and update test deps accordingly. Also fix a stale shape assertion in test_sparse_attn_sparsity_masks that expected the old dense index layout instead of the compact one. Additionally fixes three O(N^2) memory bottlenecks that caused NSA to OOM at N=65536: tiled block scoring in select.py, compact index tensors in sparsity_masks.py, and chunked gating in gating.py. Differential Revision: D93656846
Summary: Replace the two-step PyTorch gating (compute_gates + gate_and_combine) with a single fused CuteDSL kernel (fused_gate_and_combine). The fused kernel computes gate dot-products, sigmoid activation, and weighted branch combination in one pass per (b,n,h) position, using warp-level shuffle reductions for the gate dot products. This eliminates the intermediate gates tensor and reduces kernel launches from 3-5 to 1. Each warp (32 threads) handles one (b,n,h) row: loads Q and gate weights, computes 3 dot products via butterfly shuffle reduction, applies sigmoid, then does the weighted sum of the 3 branch outputs. No shared memory needed — gate weights are small enough to stay L2-cached across positions. Benchmark on B200 (B=2, H=32, D=128, bf16): fused kernel is 4-7x faster than the PyTorch reference, with the speedup growing at longer sequences (0.22ms vs 0.90ms at N=4096, 0.50ms vs 3.63ms at N=16384). At N=16384, gating drops from 63% to 9% of the NSA total time. The existing PyTorch compute_gates and gate_and_combine functions are kept as reference implementations for testing. Differential Revision: D93888405
Summary:
Replace the PyTorch score_and_select_blocks() with a fused CuteDSL kernel (fused_score_and_select_blocks) that reduces ~10 kernel launches per query tile to 1. The kernel fuses dot-product scoring, causal masking, parallel top-k selection, and index sorting into a single launch of 128 threads per (batch, head, query_tile) block.
Key optimization: uses Q_mean algebraic identity — mean(Q @ K) = mean(Q) @ K — to reduce the scoring from a (256, D) x (D, N_cmp) GEMM to a (D,) . (D, N_cmp) GEMV, giving 256x fewer FLOPs. The Q_mean computation is done in PyTorch (single kernel), then the CuteDSL kernel handles scoring + top-k.
The top-k pipeline uses a parallel reduction: per-thread insertion sort → warp butterfly shuffle merge → cross-warp shared memory merge. The insertion sort avoids CuteDSL's break limitation in range_constexpr by letting the inner loop run fully with cascading shifts — each value "bubbles up" to its correct position without early exit.
Measured on B200 (B=1, H=8, H_kv=2, D=128, bf16, up to 512K tokens):
N Dense NSA-old NSA-new old/new dense/new
(ms) l=64(ms) l=64(ms)
64K 7.61 8.14 1.73 4.70x 4.41x
128K 29.71 24.08 4.11 5.86x 7.23x
256K 118.63 81.37 11.79 6.90x 10.06x
512K 474.50 302.66 37.90 7.99x 12.52x
The old PyTorch selection was such a bottleneck that NSA was barely faster than dense even at 128K. With the fused kernel, NSA becomes faster than dense at ~24K and reaches 20x speedup at 512K (l=128).
Also fixes OOM in test_sparse_attn_benchmark at N=65536/131072 by scaling down B/H for large sequence lengths, and adds nvidia-cutlass-dsl deps to BUCK targets that transitively use the fused kernel.
Differential Revision: D93905195
Summary: Replace the two PyTorch mean-reduction kernel launches in compress_kv with a single fused CuteDSL kernel that processes both K and V simultaneously. Each thread block handles one output element (b, j, h), reads block_size input positions, accumulates in float32, and writes the mean. The optional W_k/W_v projections remain in PyTorch. The fused kernel is integrated into nsa_forward.py as the default compression path. Tests confirm fused output matches the PyTorch reference within atol=1e-5 (fp32) and 1e-3 (bf16). Benchmark SVG updated with current measured data (all three fused kernels active). Differential Revision: D94099960
Summary: CuTe tensor indexing uses int32 by default for offset computation. When row_index * stride exceeds INT32_MAX (2,147,483,647), the offset overflows, causing cudaErrorIllegalAddress. This crashes NSA at sequence lengths >= 2M with typical configs (B=1, H=8, D=128). Fix: cast row indices to cutlass.Int64() before CuTe tensor subscript access in the gating and compress kernels. This causes CuTe to compute linear offsets in int64, preventing overflow. The select kernel's indices stay within int32 for N up to hundreds of millions and needs no change. Also adds lean benchmark tooling (bench_sparse_attn_e2e.py, probe_max_seqlen.py) for testing at large sequence lengths, and updates the performance SVG with measured data through N=3M showing 16.2x (l=64) and 29.4x (l=128) speedup over dense FA4 on B200. Differential Revision: D94120744
Summary: Implements the backward pass for the NSA (Native Sparse Attention) kernel, enabling training. The implementation includes: 1. Reference backward (reference.py): Pure PyTorch differentiable forward for gradient validation via torch.autograd.gradcheck. 2. Autograd wrapper (nsa_autograd.py): NSAFunction(torch.autograd.Function) with nsa() user API and activation checkpointing. 3. CuteDSL gating backward (gating.py): fused_gating_backward() kernel. 4. CuteDSL compression backward (compress.py): fused_compress_kv_backward(). 5. Block sparsity format (sparsity_masks.py): Dense format for FA4 backward. Differential Revision: D98080760
Summary: Adds mask_mod support to the selected attention backward path in nsa_autograd.py, enabling correct causal masking for the block-sparse attention branch. Differential Revision: D98080763
Summary: Adds bench_nsa_backward.py with forward and fwd+bwd timing at various sequence lengths (1K-65K). Reports wall-clock ms and dense-equivalent TFLOPS. Differential Revision: D98080762
Summary: Fixes FA4 interface for 2-CTA backward compatibility. Adds 1-line fix to interface.py to disable 2-CTA when block_sparse_tensors is present. Restructures nsa_autograd.py backward to handle block sparsity correctly with mask_mod fallback path. Differential Revision: D98080757
Summary: Rewrites all three attention backward branches to use PyTorch matmul (cuBLAS) instead of CUTLASS JIT compilation. This avoids the CuTe DSL block-sparse backward compilation bug on SM100 while maintaining correctness. Includes gather/scatter for selected branch and SDPA for sliding window. Differential Revision: D98080759
Summary: Adds SDPA-based sliding window backward path and bench_nsa_vs_dense.py benchmark comparing NSA sparse attention vs dense FA4 at sequence lengths from 1K to 3M for both forward and backward passes. Differential Revision: D98080756
Summary: Moves two of three backward attention branches from PyTorch to FA4 CuteDSL: 1. Compressed attention backward: _flash_attn_bwd with mask_mod. Eliminates O(N * N_cmp) score matrix materialization. 2. Sliding window backward: _flash_attn_bwd with window_size_left. Fixes correctness bug: SDPA backward had is_causal=causal and window_size is None which always evaluated to False, producing wrong non-causal non-windowed gradients. 3. Selected attention backward: Rewritten to per-tile gather/scatter with GQA-aware memory layout (gather from H_kv heads, accumulate dK/dV at H_kv resolution). Avoids O(N_q_tiles * N * D) memory. Adds _fa4_bwd() helper, 5 TestNSABackwardMatchesReference correctness tests, and extends benchmarks to 3M tokens. FA4 block-sparse backward on SM100 has a CuTe DSL while-loop closure capture bug (upstream FlashAttention issue #2011). Differential Revision: D98080761
Summary: Adds SVG performance chart showing NSA vs dense FA4 latency, forward speedup curve (10.2x at 1M), and fwd+bwd breakdown with OOM boundary. Data from GB200 benchmarks (B=1, H=32, H_kv=8, D=128). Differential Revision: D98080758
Summary: Replaces the PyTorch gather/scatter backward for the selected attention branch with FA4's native block-sparse backward via _flash_attn_bwd. This was blocked by a missing argument bug in flash_bwd_sm100.py (fixed in the previous commit). Now all three backward branches use FA4 CuteDSL: 1. Compressed: _flash_attn_bwd with mask_mod 2. Selected: _flash_attn_bwd with block_sparse_tensors (this commit) 3. Sliding window: _flash_attn_bwd with window_size_left The selected branch backward requires transposing the forward block-sparse tensors from (Q-tiles, KV-blocks) to (KV-blocks, Q-tiles) format, since the backward kernel iterates over KV-blocks. This is done by _transpose_block_sparse_for_bwd() using vectorized scatter + sort. Eliminates all PyTorch matmul/gather/scatter from the backward pass. Should remove the 512K OOM limit since no intermediate PyTorch tensors are allocated. Differential Revision: D98086884
Summary: Updates benchmark data with all three backward branches on FA4 CuteDSL. Backward now reaches 1M tokens (was 512K with PyTorch fallback). fwd+bwd at 1M: 1786ms. 3-6x faster than previous PyTorch backward. Differential Revision: D98086882
Summary: Restructures the backward pass to process each branch sequentially: recompute forward, compute dO_i, run backward, free intermediates. Without gate weights (the common case), dO_i = g_i * dO doesn't need the recomputed O_i. So each branch's O and lse are freed before the next branch's forward is computed, reducing peak memory by ~2x the output tensor size. With gate weights, gating backward needs all three O_i simultaneously, so we fall back to the previous behavior. This should allow backward to reach 2M+ tokens on 184 GiB GB200. Differential Revision: D98147257
Summary: Adds memory cleanup between benchmark sizes and enables PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to reduce fragmentation. Pre-warms kernel compilation at small N for large sequence benchmarks. Note: CuTe DSL JIT compilation cache consumes ~170 GiB on GB200, leaving only ~14 GiB for computation. This limits the maximum sequence length to ~1M tokens regardless of backward memory optimizations. The 2M+ forward was likely benchmarked on a machine with more GPU memory. Differential Revision: D98147256
Summary: Replaces the 3-sort dedup algorithm in build_fa4_block_sparse_tensors (contraction case: compress_block_size < n_block_size) with a scatter-based approach: 1. Scatter block indices into a boolean attendance mask (O(1) per element) 2. Sum for counts 3. Single sort to pack indices This eliminates 2 of 3 sorts and the sentinel-based dedup logic. The contraction case (compress=64, n_block=128) is the default config and was taking 22-24% of forward time at small sequence lengths. Differential Revision: D98147254
Summary: Adds two new tools: - test_fa4_block_sparse_bwd.py: Minimal reproducer for the FA4 block-sparse backward bug (now fixed). Tests _flash_attn_bwd with block_sparse_tensors. - profile_nsa_forward.py: Component-level profiler that breaks down NSA forward time into compression, selection, mask construction, three FA4 attention branches, and gating. Differential Revision: D98147679
Summary: Documents all findings from the NSA backward pass work: - Performance data (forward crossover at 32K, 10x at 1M) - Component profiling breakdown - FA4 block-sparse backward bug analysis and fix - Sliding window causal bug fix - 2M/3M memory limit analysis (CuTe DSL compilation cache) - Next steps: upstream PR, crossover optimization, varlen support Updates performance chart with latest sequential backward numbers. Differential Revision: D98147680
Summary: - Re-benchmark NSA vs Dense FA4 on B200 (devgpu016): forward crossover ~25K, fwd+bwd crossover ~20K, 10.1x/11.5x speedup at 1M - Add dense FA4 fwd+bwd comparison (was missing — fix tuple return bug in bench_nsa_vs_dense.py) - Update chart to show fwd+bwd latency panel with dense vs NSA comparison and combined speedup panel (fwd vs fwd+bwd) - Add hostname + timestamp to chart footer - Update FINDINGS.md with complete benchmark data including 2M/3M dense results - NSA still OOMs at 2M due to CuTe DSL JIT cache (~165 GiB) Differential Revision: D98172197
Summary: When gate_proj_weight is None (the common benchmark case), the gates are uniform 1/3. Previously this still launched the fused CuteDSL gating kernel which computes sigmoid(0)=0.5 for each gate and normalizes. Now we simply average the three branch outputs: O = (O_cmp + O_slc + O_sld) / 3. This replaces a CuteDSL kernel launch with a single PyTorch add+scale op. Saves ~0.3ms at small N where kernel launch overhead dominates. Differential Revision: D98174599
Summary: Documents CUDA graph compatibility and benchmark results showing crossover moved from 32K to ~4K tokens with graph capture. Differential Revision: D98174552
Summary: Removes the NotImplementedError that blocked block_sparse_tensors + varlen (cu_seqlens_q/k) in _flash_attn_fwd. The fix uses max_seqlen_q/k for block sparse tensor shape computation instead of seqlen_q/k (which is None or total_tokens for varlen). The kernel-level code already supports both features independently. The tile scheduler checks m_block < m_block_max per sequence using cu_seqlens, so tiles beyond a shorter sequence's actual length are skipped. Block sparse tensors are sized for max_seqlen (the longest sequence in the batch), with shorter sequences having their excess tiles zeroed out. This enables NSA's selected attention branch to use native FA4 varlen instead of padding, which is required for packed-sequence training. Upstream: Should be submitted to Dao-AILab/flash-attention alongside the is_leader_cta fix (issue #2011). Differential Revision: D98200950
Summary: Extends nsa() (the autograd wrapper) with cu_seqlens and max_seqlen parameters for variable-length sequence training. For varlen, sequences are padded to 4D internally, the existing NSAFunction autograd path runs on padded tensors, and the output is unpacked back to 3D using autograd-preserving indexing (torch.cat). Gradients flow correctly through the pad/unpad via autograd. Differential Revision: D98203410
Summary: Eliminates the full pad-to-4D approach for varlen backward. Selected and sliding window branches now use native FA4 varlen (cu_seqlens) for both fwd and bwd, avoiding wasted compute on padding positions. Only the compressed branch still uses padded Q (mask_mod blocks varlen in FA4 backward). Changes: - Gating: fused_gate_and_combine, compute_gates, fused_gating_backward, gate_and_combine all accept 3D (T, H, D) varlen input in addition to 4D (B, N, H, D) - FA4 SM100 backward: split assertion to allow varlen + block_sparse (mask_mod still blocked) - Compress/select: fused_compress_kv and fused_score_and_select_blocks accept cu_seqlens, reading from 3D varlen input directly (no Q_pad/K_pad/V_pad allocation) - fused_compress_kv_backward: scatters from padded compressed gradients directly to 3D - NSAFunction: restructured to handle varlen inside the autograd function. nsa() calls varlen-aware compress/select, then passes 3D Q,K,V + small padded K_cmp/V_cmp to NSAFunction which uses native FA4 varlen for selected + sliding window branches. - fb/BUCK: add einops dep to flash_attn target (testing.py imports it) - Updated bench_nsa_vs_dense.py with varlen benchmark mode - Updated FINDINGS.md and performance chart with GB200 benchmark data Differential Revision: D98218912
Summary: Set CUTLASS_CUTE_DSL_KERNEL_CACHE_DIR and PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True in benchmark scripts. This persists compiled CuteDSL kernels to disk across runs, eliminating JIT warmup cost and reducing GPU memory pressure. CUDA graph forward pass confirmed working with restructured code: - 2.3x speedup at 4K, tapering to 1.0x at 128K - Forward+backward CUDA graph capture fails (pre-existing — dynamic tensor allocations in backward are not graph-capturable) Differential Revision: D98304607
Summary:
GB200 benchmark results including 1M token fwd+bwd:
- 14.07x fwd+bwd speedup at 1M tokens (NSA 1.85s vs Dense 26.08s)
- 9.96x fwd+bwd speedup at 512K (up from previous 6.50x measurement)
- 7.93x forward speedup at 1M tokens
- Crossover at ~32K for both fwd and fwd+bwd
- Varlen is 23-33% faster than padded at 256K-1M contexts
1M now works on GB200 (184 GiB). 2M forward works, 2M backward OOMs.
Previous OOM at 1M was on B200 (178 GiB) or due to corrupted GPU state.
Changes:
- Add disk JIT cache (CUTLASS_CUTE_DSL_KERNEL_CACHE_DIR) to benchmarks
- Add expandable_segments to benchmark scripts
- Add memory diagnostic tool (diagnose_memory.py)
- Update FINDINGS.md with 1M data and varlen comparison
- Update performance chart (gen_nsa_perf_chart.py + nsa_bwd_perf.svg)
- Add sliding window block-sparse utility (window_sparse.py) for future use
- Clean up forward path: use _fa4_fwd consistently for all branches
{F1987280652}
Differential Revision: D98319585
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary:
GB200 benchmark results including 1M token fwd+bwd:
1M now works on GB200 (184 GiB). 2M forward works, 2M backward OOMs.
Previous OOM at 1M was on B200 (178 GiB) or due to corrupted GPU state.
Changes:
{F1987280652}
Differential Revision: D98319585