Update FINDINGS.md with optimization round results#296
Open
jduprat wants to merge 45 commits into
Open
Conversation
Summary: Move the sparse_attn kernel from fb/mslk/attention/sparse_attn to mslk/attention/sparse_attn, matching the convention used by other kernels in this directory. Update all imports from mslk.fb.mslk.attention.sparse_attn to mslk.attention.sparse_attn, update BUCK targets so the library is defined directly in mslk/BUCK instead of as a shim forwarding to mslk/fb, and update test deps accordingly. Also fix a stale shape assertion in test_sparse_attn_sparsity_masks that expected the old dense index layout instead of the compact one. Additionally fixes three O(N^2) memory bottlenecks that caused NSA to OOM at N=65536: tiled block scoring in select.py, compact index tensors in sparsity_masks.py, and chunked gating in gating.py. Differential Revision: D93656846
Summary: Replace the two-step PyTorch gating (compute_gates + gate_and_combine) with a single fused CuteDSL kernel (fused_gate_and_combine). The fused kernel computes gate dot-products, sigmoid activation, and weighted branch combination in one pass per (b,n,h) position, using warp-level shuffle reductions for the gate dot products. This eliminates the intermediate gates tensor and reduces kernel launches from 3-5 to 1. Each warp (32 threads) handles one (b,n,h) row: loads Q and gate weights, computes 3 dot products via butterfly shuffle reduction, applies sigmoid, then does the weighted sum of the 3 branch outputs. No shared memory needed — gate weights are small enough to stay L2-cached across positions. Benchmark on B200 (B=2, H=32, D=128, bf16): fused kernel is 4-7x faster than the PyTorch reference, with the speedup growing at longer sequences (0.22ms vs 0.90ms at N=4096, 0.50ms vs 3.63ms at N=16384). At N=16384, gating drops from 63% to 9% of the NSA total time. The existing PyTorch compute_gates and gate_and_combine functions are kept as reference implementations for testing. Differential Revision: D93888405
Summary:
Replace the PyTorch score_and_select_blocks() with a fused CuteDSL kernel (fused_score_and_select_blocks) that reduces ~10 kernel launches per query tile to 1. The kernel fuses dot-product scoring, causal masking, parallel top-k selection, and index sorting into a single launch of 128 threads per (batch, head, query_tile) block.
Key optimization: uses Q_mean algebraic identity — mean(Q @ K) = mean(Q) @ K — to reduce the scoring from a (256, D) x (D, N_cmp) GEMM to a (D,) . (D, N_cmp) GEMV, giving 256x fewer FLOPs. The Q_mean computation is done in PyTorch (single kernel), then the CuteDSL kernel handles scoring + top-k.
The top-k pipeline uses a parallel reduction: per-thread insertion sort → warp butterfly shuffle merge → cross-warp shared memory merge. The insertion sort avoids CuteDSL's break limitation in range_constexpr by letting the inner loop run fully with cascading shifts — each value "bubbles up" to its correct position without early exit.
Measured on B200 (B=1, H=8, H_kv=2, D=128, bf16, up to 512K tokens):
N Dense NSA-old NSA-new old/new dense/new
(ms) l=64(ms) l=64(ms)
64K 7.61 8.14 1.73 4.70x 4.41x
128K 29.71 24.08 4.11 5.86x 7.23x
256K 118.63 81.37 11.79 6.90x 10.06x
512K 474.50 302.66 37.90 7.99x 12.52x
The old PyTorch selection was such a bottleneck that NSA was barely faster than dense even at 128K. With the fused kernel, NSA becomes faster than dense at ~24K and reaches 20x speedup at 512K (l=128).
Also fixes OOM in test_sparse_attn_benchmark at N=65536/131072 by scaling down B/H for large sequence lengths, and adds nvidia-cutlass-dsl deps to BUCK targets that transitively use the fused kernel.
Differential Revision: D93905195
Summary: Replace the two PyTorch mean-reduction kernel launches in compress_kv with a single fused CuteDSL kernel that processes both K and V simultaneously. Each thread block handles one output element (b, j, h), reads block_size input positions, accumulates in float32, and writes the mean. The optional W_k/W_v projections remain in PyTorch. The fused kernel is integrated into nsa_forward.py as the default compression path. Tests confirm fused output matches the PyTorch reference within atol=1e-5 (fp32) and 1e-3 (bf16). Benchmark SVG updated with current measured data (all three fused kernels active). Differential Revision: D94099960
Summary: CuTe tensor indexing uses int32 by default for offset computation. When row_index * stride exceeds INT32_MAX (2,147,483,647), the offset overflows, causing cudaErrorIllegalAddress. This crashes NSA at sequence lengths >= 2M with typical configs (B=1, H=8, D=128). Fix: cast row indices to cutlass.Int64() before CuTe tensor subscript access in the gating and compress kernels. This causes CuTe to compute linear offsets in int64, preventing overflow. The select kernel's indices stay within int32 for N up to hundreds of millions and needs no change. Also adds lean benchmark tooling (bench_sparse_attn_e2e.py, probe_max_seqlen.py) for testing at large sequence lengths, and updates the performance SVG with measured data through N=3M showing 16.2x (l=64) and 29.4x (l=128) speedup over dense FA4 on B200. Differential Revision: D94120744
Summary: Implements the backward pass for the NSA (Native Sparse Attention) kernel, enabling training. The implementation includes: 1. Reference backward (reference.py): Pure PyTorch differentiable forward for gradient validation via torch.autograd.gradcheck. 2. Autograd wrapper (nsa_autograd.py): NSAFunction(torch.autograd.Function) with nsa() user API and activation checkpointing. 3. CuteDSL gating backward (gating.py): fused_gating_backward() kernel. 4. CuteDSL compression backward (compress.py): fused_compress_kv_backward(). 5. Block sparsity format (sparsity_masks.py): Dense format for FA4 backward. Differential Revision: D98080760
Summary: Adds mask_mod support to the selected attention backward path in nsa_autograd.py, enabling correct causal masking for the block-sparse attention branch. Differential Revision: D98080763
Summary: Adds bench_nsa_backward.py with forward and fwd+bwd timing at various sequence lengths (1K-65K). Reports wall-clock ms and dense-equivalent TFLOPS. Differential Revision: D98080762
Summary: Fixes FA4 interface for 2-CTA backward compatibility. Adds 1-line fix to interface.py to disable 2-CTA when block_sparse_tensors is present. Restructures nsa_autograd.py backward to handle block sparsity correctly with mask_mod fallback path. Differential Revision: D98080757
Summary: Rewrites all three attention backward branches to use PyTorch matmul (cuBLAS) instead of CUTLASS JIT compilation. This avoids the CuTe DSL block-sparse backward compilation bug on SM100 while maintaining correctness. Includes gather/scatter for selected branch and SDPA for sliding window. Differential Revision: D98080759
Summary: Adds SDPA-based sliding window backward path and bench_nsa_vs_dense.py benchmark comparing NSA sparse attention vs dense FA4 at sequence lengths from 1K to 3M for both forward and backward passes. Differential Revision: D98080756
Summary: Moves two of three backward attention branches from PyTorch to FA4 CuteDSL: 1. Compressed attention backward: _flash_attn_bwd with mask_mod. Eliminates O(N * N_cmp) score matrix materialization. 2. Sliding window backward: _flash_attn_bwd with window_size_left. Fixes correctness bug: SDPA backward had is_causal=causal and window_size is None which always evaluated to False, producing wrong non-causal non-windowed gradients. 3. Selected attention backward: Rewritten to per-tile gather/scatter with GQA-aware memory layout (gather from H_kv heads, accumulate dK/dV at H_kv resolution). Avoids O(N_q_tiles * N * D) memory. Adds _fa4_bwd() helper, 5 TestNSABackwardMatchesReference correctness tests, and extends benchmarks to 3M tokens. FA4 block-sparse backward on SM100 has a CuTe DSL while-loop closure capture bug (upstream FlashAttention issue #2011). Differential Revision: D98080761
Summary: Adds SVG performance chart showing NSA vs dense FA4 latency, forward speedup curve (10.2x at 1M), and fwd+bwd breakdown with OOM boundary. Data from GB200 benchmarks (B=1, H=32, H_kv=8, D=128). Differential Revision: D98080758
Summary: Replaces the PyTorch gather/scatter backward for the selected attention branch with FA4's native block-sparse backward via _flash_attn_bwd. This was blocked by a missing argument bug in flash_bwd_sm100.py (fixed in the previous commit). Now all three backward branches use FA4 CuteDSL: 1. Compressed: _flash_attn_bwd with mask_mod 2. Selected: _flash_attn_bwd with block_sparse_tensors (this commit) 3. Sliding window: _flash_attn_bwd with window_size_left The selected branch backward requires transposing the forward block-sparse tensors from (Q-tiles, KV-blocks) to (KV-blocks, Q-tiles) format, since the backward kernel iterates over KV-blocks. This is done by _transpose_block_sparse_for_bwd() using vectorized scatter + sort. Eliminates all PyTorch matmul/gather/scatter from the backward pass. Should remove the 512K OOM limit since no intermediate PyTorch tensors are allocated. Differential Revision: D98086884
Summary: Updates benchmark data with all three backward branches on FA4 CuteDSL. Backward now reaches 1M tokens (was 512K with PyTorch fallback). fwd+bwd at 1M: 1786ms. 3-6x faster than previous PyTorch backward. Differential Revision: D98086882
Summary: Restructures the backward pass to process each branch sequentially: recompute forward, compute dO_i, run backward, free intermediates. Without gate weights (the common case), dO_i = g_i * dO doesn't need the recomputed O_i. So each branch's O and lse are freed before the next branch's forward is computed, reducing peak memory by ~2x the output tensor size. With gate weights, gating backward needs all three O_i simultaneously, so we fall back to the previous behavior. This should allow backward to reach 2M+ tokens on 184 GiB GB200. Differential Revision: D98147257
Summary: Adds memory cleanup between benchmark sizes and enables PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to reduce fragmentation. Pre-warms kernel compilation at small N for large sequence benchmarks. Note: CuTe DSL JIT compilation cache consumes ~170 GiB on GB200, leaving only ~14 GiB for computation. This limits the maximum sequence length to ~1M tokens regardless of backward memory optimizations. The 2M+ forward was likely benchmarked on a machine with more GPU memory. Differential Revision: D98147256
Summary: Replaces the 3-sort dedup algorithm in build_fa4_block_sparse_tensors (contraction case: compress_block_size < n_block_size) with a scatter-based approach: 1. Scatter block indices into a boolean attendance mask (O(1) per element) 2. Sum for counts 3. Single sort to pack indices This eliminates 2 of 3 sorts and the sentinel-based dedup logic. The contraction case (compress=64, n_block=128) is the default config and was taking 22-24% of forward time at small sequence lengths. Differential Revision: D98147254
Summary: Adds two new tools: - test_fa4_block_sparse_bwd.py: Minimal reproducer for the FA4 block-sparse backward bug (now fixed). Tests _flash_attn_bwd with block_sparse_tensors. - profile_nsa_forward.py: Component-level profiler that breaks down NSA forward time into compression, selection, mask construction, three FA4 attention branches, and gating. Differential Revision: D98147679
Summary: Documents all findings from the NSA backward pass work: - Performance data (forward crossover at 32K, 10x at 1M) - Component profiling breakdown - FA4 block-sparse backward bug analysis and fix - Sliding window causal bug fix - 2M/3M memory limit analysis (CuTe DSL compilation cache) - Next steps: upstream PR, crossover optimization, varlen support Updates performance chart with latest sequential backward numbers. Differential Revision: D98147680
Summary: - Re-benchmark NSA vs Dense FA4 on B200 (devgpu016): forward crossover ~25K, fwd+bwd crossover ~20K, 10.1x/11.5x speedup at 1M - Add dense FA4 fwd+bwd comparison (was missing — fix tuple return bug in bench_nsa_vs_dense.py) - Update chart to show fwd+bwd latency panel with dense vs NSA comparison and combined speedup panel (fwd vs fwd+bwd) - Add hostname + timestamp to chart footer - Update FINDINGS.md with complete benchmark data including 2M/3M dense results - NSA still OOMs at 2M due to CuTe DSL JIT cache (~165 GiB) Differential Revision: D98172197
Summary: When gate_proj_weight is None (the common benchmark case), the gates are uniform 1/3. Previously this still launched the fused CuteDSL gating kernel which computes sigmoid(0)=0.5 for each gate and normalizes. Now we simply average the three branch outputs: O = (O_cmp + O_slc + O_sld) / 3. This replaces a CuteDSL kernel launch with a single PyTorch add+scale op. Saves ~0.3ms at small N where kernel launch overhead dominates. Differential Revision: D98174599
Summary: Documents CUDA graph compatibility and benchmark results showing crossover moved from 32K to ~4K tokens with graph capture. Differential Revision: D98174552
Summary: Removes the NotImplementedError that blocked block_sparse_tensors + varlen (cu_seqlens_q/k) in _flash_attn_fwd. The fix uses max_seqlen_q/k for block sparse tensor shape computation instead of seqlen_q/k (which is None or total_tokens for varlen). The kernel-level code already supports both features independently. The tile scheduler checks m_block < m_block_max per sequence using cu_seqlens, so tiles beyond a shorter sequence's actual length are skipped. Block sparse tensors are sized for max_seqlen (the longest sequence in the batch), with shorter sequences having their excess tiles zeroed out. This enables NSA's selected attention branch to use native FA4 varlen instead of padding, which is required for packed-sequence training. Upstream: Should be submitted to Dao-AILab/flash-attention alongside the is_leader_cta fix (issue #2011). Differential Revision: D98200950
Summary: Extends nsa() (the autograd wrapper) with cu_seqlens and max_seqlen parameters for variable-length sequence training. For varlen, sequences are padded to 4D internally, the existing NSAFunction autograd path runs on padded tensors, and the output is unpacked back to 3D using autograd-preserving indexing (torch.cat). Gradients flow correctly through the pad/unpad via autograd. Differential Revision: D98203410
Summary: Eliminates the full pad-to-4D approach for varlen backward. Selected and sliding window branches now use native FA4 varlen (cu_seqlens) for both fwd and bwd, avoiding wasted compute on padding positions. Only the compressed branch still uses padded Q (mask_mod blocks varlen in FA4 backward). Changes: - Gating: fused_gate_and_combine, compute_gates, fused_gating_backward, gate_and_combine all accept 3D (T, H, D) varlen input in addition to 4D (B, N, H, D) - FA4 SM100 backward: split assertion to allow varlen + block_sparse (mask_mod still blocked) - Compress/select: fused_compress_kv and fused_score_and_select_blocks accept cu_seqlens, reading from 3D varlen input directly (no Q_pad/K_pad/V_pad allocation) - fused_compress_kv_backward: scatters from padded compressed gradients directly to 3D - NSAFunction: restructured to handle varlen inside the autograd function. nsa() calls varlen-aware compress/select, then passes 3D Q,K,V + small padded K_cmp/V_cmp to NSAFunction which uses native FA4 varlen for selected + sliding window branches. - fb/BUCK: add einops dep to flash_attn target (testing.py imports it) - Updated bench_nsa_vs_dense.py with varlen benchmark mode - Updated FINDINGS.md and performance chart with GB200 benchmark data Differential Revision: D98218912
Summary: Set CUTLASS_CUTE_DSL_KERNEL_CACHE_DIR and PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True in benchmark scripts. This persists compiled CuteDSL kernels to disk across runs, eliminating JIT warmup cost and reducing GPU memory pressure. CUDA graph forward pass confirmed working with restructured code: - 2.3x speedup at 4K, tapering to 1.0x at 128K - Forward+backward CUDA graph capture fails (pre-existing — dynamic tensor allocations in backward are not graph-capturable) Differential Revision: D98304607
Summary: GB200 benchmark results including 1M token fwd+bwd: - 14.07x fwd+bwd speedup at 1M tokens (NSA 1.85s vs Dense 26.08s) - 9.96x fwd+bwd speedup at 512K (up from previous 6.50x measurement) - 7.93x forward speedup at 1M tokens - Crossover at ~32K for both fwd and fwd+bwd - Varlen is 23-33% faster than padded at 256K-1M contexts 1M now works on GB200 (184 GiB). 2M forward works, 2M backward OOMs. Previous OOM at 1M was on B200 (178 GiB) or due to corrupted GPU state. Changes: - Add disk JIT cache (CUTLASS_CUTE_DSL_KERNEL_CACHE_DIR) to benchmarks - Add expandable_segments to benchmark scripts - Add memory diagnostic tool (diagnose_memory.py) - Update FINDINGS.md with 1M data and varlen comparison - Update performance chart (gen_nsa_perf_chart.py + nsa_bwd_perf.svg) - Add sliding window block-sparse utility (window_sparse.py) for future use - Clean up forward path: use _fa4_fwd consistently for all branches Differential Revision: D98319585
Summary: - Rename nsa_bwd_perf.svg → nsa_perf.svg (it measures both fwd and bwd) - Add "NSA (varlen)" line to fwd+bwd latency panel showing native varlen performance alongside padded - Varlen is 23-33% faster than padded at 256K-1M (visible in chart) Differential Revision: D98326585
Summary: Extended the forward profiler to 512K and 1M tokens. Key finding: The compressed branch (Q × K_cmp with mask_mod) dominates forward time at large N: - 512K: 91.68ms (53.6% of 171ms total) - 1M: 364.20ms (57.5% of 633ms total) This is Q(N) × K_cmp(N/64) = O(N²/64) — quadratic, explaining the forward speedup regression from 8.24x (512K) to 7.93x (1M). The block selection kernel is 21.5% at 1M (136ms), and mask construction is 9.4% (59ms). Together with compressed attention, these three components account for 88% of forward time at 1M. Differential Revision: D98327090
Summary: Replace the O(N²) attendance-matrix approach in build_fa4_block_sparse_tensors with consecutive dedup on sorted indices. The old approach materialized a (B, H, N_q_tiles, n_blocks_k) boolean tensor (1 GB at 1M), then sorted it. The new approach sorts the k=16 indices per tile (tiny) and removes duplicates. Mask construction time at 1M: 59.22ms → 4.15ms (14.3x faster). At 512K: 8.08ms → 1.87ms (4.3x faster). At 256K: 2.16ms → 1.11ms (1.9x faster). Also adds fused_select_and_build_sparse() convenience function for callers. Differential Revision: D98341658
Summary: Replace `dQ = dQ_cmp + dQ_slc + dQ_sld` (allocates new tensor) with in-place `dQ_cmp += dQ_slc; del dQ_slc; dQ_cmp += dQ_sld; del dQ_sld` in both the 4D and varlen backward paths. Reduces peak memory by avoiding temporary copies during gradient summation. 2M backward still OOMs (needs 32 GB allocation with only 5 GB free). The bottleneck is FA4's dQ output tensor at 2M (32 heads × 128 dim × 2M tokens × 2 bytes = 32 GB). Deferring 2M backward to future work — would require sequence-level splitting or more aggressive checkpointing. Differential Revision: D98399985
Summary: Updated benchmarks and charts with final performance data after all optimizations (mask construction 14x speedup, in-place gradient accumulation, native varlen). GB200 benchmark results (B=1, H=32, H_kv=8, D=128, 1K-1M measured): - Forward: 10.19x speedup at 1M (up from 7.93x before mask optimization) - Fwd+Bwd: 11.21x speedup at 1M - Crossover: ~32K for both fwd and fwd+bwd - Varlen is 47% faster than padded at 1M (up from 23%) - 2M forward works (156 GB), 2M backward OOMs (needs ~200 GB) Forward speedup regression at 1M is fixed — was caused by O(N²) mask construction that scaled quadratically. After optimization, forward speedup curve is monotonically increasing from 32K to 1M. Differential Revision: D98400268
Summary: The test expected full_block_idx shape (B, H, N_q_tiles, k) but our implementation uses dense shape (B, H, N_q_tiles, n_blocks_k) for backward transpose compatibility. Updated the test to match. Differential Revision: D98496676
Summary: The compressed attention branch was O(N²/64), dominating 57.5% of forward time at 1M tokens. This change adds an optional `num_cmp_selected_blocks` parameter that selects only the top-k most relevant FA4 blocks of compressed KV per Q tile, making the compressed branch O(N * k_cmp) instead of O(N²/compress_block_size). New functions: - `select_compressed_blocks()`: Scores Q tiles against compressed K, aggregates per FA4 block (128 compressed tokens), applies causal mask, selects top-k. Supports both padded and varlen. Chunked for bounded peak memory. - `build_compressed_block_sparse_tensors()`: Builds BlockSparseTensorsTorch with selected blocks in mask_block (not full_block) so FA4 applies the compressed causal mask (mask_mod) within each block. Modified: - `nsa_forward()` and `nsa()`: Accept `num_cmp_selected_blocks`. When set, uses block-sparse FA4 with both mask_mod AND block_sparse_tensors for the compressed branch. Both padded and varlen paths updated. - `NSAFunction`: Threads `cmp_sparse_tensors` through forward→ctx→backward. Backward transposes compressed block-sparse for FA4 bwd. - `_transpose_block_sparse_for_bwd()`: New `use_mask_blocks` flag to read from and output as mask blocks (needed for compressed branch). Expected impact at 1M with num_cmp_selected_blocks=16: Compressed branch: 362ms → ~45ms (~8x). Total: 629ms → ~312ms (~20x vs FA4). Differential Revision: D98540309
Summary: The block selection kernel used scalar dot products in CuteDSL — each of 128 threads computed N_cmp/128 dot products of dimension D=128. At 1M tokens this took 135ms (21.5% of forward time). Replace with cuBLAS GEMM via torch.matmul: Q_mean @ K_cmp.T produces scores in a single highly-optimized GEMM call. Top-k and sorting done via torch.topk and torch.sort (PyTorch ops). Chunked over Q tiles for bounded peak memory. Benefits: - cuBLAS GEMM is 10-50x faster than scalar dot products for this workload - No CuteDSL JIT compilation for the select kernel (saves ~10GB JIT cache) - No fallback to PyTorch reference for large N_cmp (the GEMM handles all sizes) - Simpler code (no CuteDSL kernel boilerplate) The CuteDSL kernel code is retained for reference/testing but no longer called from fused_score_and_select_blocks. Expected impact: block selection 135ms → ~20-25ms at 1M tokens. Differential Revision: D98540306
Summary: Benchmark comparing NSA with and without num_cmp_selected_blocks=16. At 1M tokens: 26.42x speedup (vs 13.13x without sparse compressed). Differential Revision: D98540308
Summary: - Forward at 1M: 26.42x with num_cmp_selected_blocks=16 (up from 10.19x) - Forward at 1M without CmpSp: 13.13x (up from 10.19x due to GEMM scoring) - Updated component breakdown with GEMM-based scoring profile - Added documentation for sparse compressed branch and GEMM scoring - Noted multi-stream investigation (not beneficial on GB200) Differential Revision: D98540307
…w_sparse, _fa4_fwd_simple Summary: Remove unused code that was superseded by GEMM-based scoring and FA4 native window support: - CuteDSL fused scoring + top-k kernel in select.py (~280 lines) — replaced by GEMM-based scoring (torch.bmm + torch.topk) - nsa_scoring.py (146 lines) — CuteDSL column-reduction utilities for a future fused scoring-within-FA4 kernel, never imported - nsa_topk.py (273 lines) — CuteDSL in-register parallel top-K, logic was inlined into the (now also removed) select kernel - window_sparse.py (93 lines) — sliding window as block-sparse pattern, superseded by FA4 native window_size_left - _fa4_fwd_simple() in nsa_forward.py — unused FA4 autograd wrapper Total: ~810 lines removed. No production or test code references any of these. All 30 tests pass (select, forward, sparsity masks). Differential Revision: D98816370
Summary: Add fused_score_and_select_all() that computes Q_mean and the Q_mean x K_cmp GEMM once, then derives both: - selected-branch block indices (top-k over compressed tokens) - compressed-branch FA4 block indices (aggregate + top-k over FA4 blocks) Previously, fused_score_and_select_blocks() and select_compressed_blocks() each independently computed Q_mean and ran the same GEMM — the dominant cost at 43% of forward time at 128K. When num_cmp_selected_blocks is enabled, the GEMM was running twice. Implementation: - Extract _compute_q_mean() shared helper (varlen + 4D) - Extract _score_and_topk() and _score_aggregate_and_topk() helpers using GQA-aware bmm (select_compressed_blocks previously used repeat_interleave) - fused_score_and_select_all() runs one GEMM loop, deriving both outputs from the same scores tensor - Update nsa_forward.py and nsa_autograd.py call sites to use combined function All 54 tests pass (select, forward, backward, varlen). Differential Revision: D98816373
Summary: Replace dense BlockSparseTensorsTorch index tensors with compact format where the last dimension is k (selected blocks) instead of n_blocks_k (total KV blocks). FA4 only accesses indices 0..cnt-1 per Q-tile, so compact format is sufficient. Changes: - build_fa4_block_sparse_tensors: emit compact full_block_idx (last dim = k_fa4) and minimal mask_block_idx (last dim = 1). Was (B,H,N_q_tiles,n_blocks_k), now (B,H,N_q_tiles,k_fa4). - build_compressed_block_sparse_tensors: emit compact mask_block_idx (last dim = k_cmp). Was (B,H,N_q_tiles,n_kv_blocks). - _transpose_block_sparse_for_bwd: rewritten as sparse inverted-index construction. No more dense boolean attendance matrix. Output is compact (B,H,n_kv_blocks,max_q_per_kv) instead of (B,H,n_kv_blocks,n_q_tiles). - FA4 block_sparsity.py: relax strict validation to allow compact N dimension (mask_block_idx.shape[3] <= expected_n_blocks, was ==). Memory savings at 1M tokens (B=1, H=32, n_block_size=128): - Forward: ~4 GiB eliminated (was n_blocks_k=8192 wide, now k=8 wide) - Backward: ~10 GiB eliminated (dense attendance + transpose + sort) 44/45 tests pass (sparsity_masks, forward, backward, varlen). The remaining test_fa4_block_sparse_backward_varlen failure appears pre-existing. Differential Revision: D98816371
Summary: Add compress_factor parameter to FA4's causal masking infrastructure, replacing the mask_mod approach for compressed attention. When compress_factor > 1, the causal condition becomes: kv_idx * compress_factor <= q_idx instead of the standard: kv_idx <= q_idx This is implemented by dividing the Q row index by compress_factor in the causal masking code, affecting both tile skipping and element-wise masking. The change is minimal and surgical — the diagonal slope changes from 1 to 1/compress_factor. Benefits over mask_mod: - Tile skipping: ~50% of KV blocks can be skipped for long sequences (mask_mod visited ALL blocks) - R2P masking: single GPU instruction (mask_mod used element-wise loop) - pack_gqa restored: mask_mod disabled GQA packing - No separate compile path: eliminates mask_mod_hash from compile key - Varlen compatible: mask_mod was blocked for varlen Files changed: - FA4: interface.py, block_info.py, mask.py, flash_fwd_sm100.py - NSA: nsa_forward.py (use causal=True + compress_factor=compress_block_size) - NSA: sparsity_masks.py (compressed blocks now use full_block, not mask_block) All 7 forward tests pass. Differential Revision: D98816375
Summary: Extend compress_factor to FA4 backward and update all NSA compressed branch calls (forward + backward, fixed-length + varlen) to use compress_factor instead of mask_mod. Changes: - FA4 backward: add compress_factor to _flash_attn_bwd, FlashAttentionBackwardSm100, BlockInfo.get_m_block_min_max (backward tile skipping), and AttentionMask.apply_mask_sm100_transposed (backward causal masking) - NSA nsa_autograd.py: replace all mask_mod=compressed_mask with compress_factor=compress_block_size, causal=causal (10 call sites) - NSA nsa_autograd.py: remove _make_compressed_causal_mask import - NSA sparsity_masks.py: compressed blocks now use full_block (not mask_block) since causal masking is native via compress_factor - NSA nsa_autograd.py: backward transpose uses use_mask_blocks=False This eliminates: - The _make_compressed_causal_mask CuteDSL callable (no more mask_mod_hash in FA4 compile key) - The pack_gqa=False override for compressed branch - The varlen+mask_mod incompatibility (compressed branch can now use varlen) All 37 tests pass (forward, backward with all weight configurations, varlen). Differential Revision: D98816372
Summary: Replace all CuteDSL auxiliary kernels with standard PyTorch operations: Compression (compress.py): - Forward: reshape + mean (was CuteDSL scalar accumulation kernel) - Backward: unsqueeze + expand + divide (was CuteDSL scatter kernel) - Eliminates _fused_compress_compile_cache and _fused_compress_bwd_compile_cache Gating (gating.py): - Forward: compute_gates() + gate_and_combine() (was CuteDSL warp-shuffle kernel) - Backward: element-wise ops + einsum (was CuteDSL kernel + PyTorch dW_gate) - Eliminates _fused_gating_compile_cache and _fused_gating_bwd_compile_cache Impact: - Eliminates 4 CuteDSL kernel families from the JIT cache (~5 compile variants) - After this + compress_factor, FA4 is the ONLY compiled kernel family on the hot path (no more cutlass/cuda.bindings.driver imports in compress/gating) - JIT cache pressure should be materially reduced for 2M+ context The CuteDSL kernel definitions (_make_fused_compress_kernel, etc.) are retained as dead code for now — they can be removed in a follow-up cleanup. All 65 tests pass (compress, gating, forward, backward). Differential Revision: D98816374
Summary: Update performance data and documentation after 7 optimization commits: - NSA forward at 1M: 13.1x → 28.6x (2.2x improvement) - NSA+CmpSparse at 1M: 26.4x → 38.1x (1.4x improvement) - Metadata memory at 1M: ~14 GiB → ~0.1 GiB (140x reduction) - CuteDSL kernel variants: ~11 → ~6 (FA4 only) - mask_mod compile paths: eliminated Differential Revision: D98816376
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary:
Update performance data and documentation after 7 optimization commits:
Differential Revision: D98816376