Fused CuteDSL kernel for KV compression (#302)#302
Open
jduprat wants to merge 4 commits into
Open
Conversation
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 2, 2026
Summary: Replace two PyTorch mean() kernel launches (one for K, one for V) with a single fused CuteDSL kernel that processes both K and V simultaneously. Key design: - Grid: one thread block per output element (b, j, h) — total B * N_cmp * H_kv - Block: 128 threads, each handling ceil(D/128) elements - Float32 accumulation: each thread reads block_size input positions, accumulates K and V values in Float32, divides by inv_block_size, writes the mean - 2D flattening for CuteDSL: K/V reshaped to (B*N*H_kv, D), manual index decomposition (b, j, h) from bidx - W_k/W_v projections remain as torch.einsum (cuBLAS GEMM — already optimal) - Varlen path: PyTorch per-sequence loop (CuteDSL varlen kernel in future diff) - In-memory compile cache keyed by (dtype, D, block_size) No performance chart update — chart will be generated after Diff 5 (int32 overflow fix) when we can benchmark at N >= 2M. Differential Revision: D99181839
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 2, 2026
Summary: Replace two PyTorch mean() kernel launches (one for K, one for V) with a single fused CuteDSL kernel that processes both K and V simultaneously. Key design: - Grid: one thread block per output element (b, j, h) — total B * N_cmp * H_kv - Block: 128 threads, each handling ceil(D/128) elements - Float32 accumulation: each thread reads block_size input positions, accumulates K and V values in Float32, divides by inv_block_size, writes the mean - 2D flattening for CuteDSL: K/V reshaped to (B*N*H_kv, D), manual index decomposition (b, j, h) from bidx - W_k/W_v projections remain as torch.einsum (cuBLAS GEMM — already optimal) - Varlen path: PyTorch per-sequence loop (CuteDSL varlen kernel in future diff) - In-memory compile cache keyed by (dtype, D, block_size) No performance chart update — chart will be generated after Diff 5 (int32 overflow fix) when we can benchmark at N >= 2M. Differential Revision: D99181839
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 2, 2026
Summary: Replace two PyTorch mean() kernel launches (one for K, one for V) with a single fused CuteDSL kernel that processes both K and V simultaneously. Key design: - Grid: one thread block per output element (b, j, h) — total B * N_cmp * H_kv - Block: 128 threads, each handling ceil(D/128) elements - Float32 accumulation: each thread reads block_size input positions, accumulates K and V values in Float32, divides by inv_block_size, writes the mean - 2D flattening for CuteDSL: K/V reshaped to (B*N*H_kv, D), manual index decomposition (b, j, h) from bidx - W_k/W_v projections remain as torch.einsum (cuBLAS GEMM — already optimal) - Varlen path: PyTorch per-sequence loop (CuteDSL varlen kernel in future diff) - In-memory compile cache keyed by (dtype, D, block_size) No performance chart update — chart will be generated after Diff 5 (int32 overflow fix) when we can benchmark at N >= 2M. Differential Revision: D99181839
Summary: Establish the NSA (Native Sparse Attention) module with reference implementations, compact block-sparse metadata format, and the FA4-based forward pass orchestrator. Three attention branches combined via learned gating: 1. Compressed: FA4 on mean-pooled KV (short sequence) 2. Selected: FA4 with block sparsity (top-k important blocks per Q-tile) 3. Sliding window: FA4 with window_size_left Key components: - compress.py: compress_kv() — mean-pool + optional learned projection - select.py: score_and_select_blocks() — tiled scoring with O(N) peak memory - gating.py: compute_gates() + gate_and_combine() — sigmoid gating, chunked - sparsity_masks.py: build_fa4_block_sparse_tensors() — compact index format (last dim = k selected blocks, not n_blocks_k total). Handles both expansion (compress_block_size >= n_block_size) and contraction (with sort + dedup). - nsa_forward.py: nsa_forward() orchestrator + _fa4_fwd() wrapper - reference.py: Pure PyTorch differentiable reference for correctness validation FA4 dependency: imports from mslk.attention.flash_attn.interface shim (tries internal fork, falls back to upstream flash_attn). Uses compress_factor for compressed causal masking (not mask_mod). All non-FA4 accumulation paths use fp32 for numerical stability with bf16/fp16. No performance impact — this is the foundation diff (reference implementations only, no CuteDSL fused kernels yet). Performance chart N/A for this diff. Differential Revision: D99181841
Summary: Replace the multi-kernel PyTorch gating (compute_gates + gate_and_combine) with a single fused CuteDSL kernel (fused_gate_and_combine) — 4-7x faster on B200. Key design: - One warp (32 threads) per (b,n,h) row — each warp handles one output position - Warp-shuffle butterfly reduction for 3 gate dot-products (no shared memory) - elems_per_thread = D // 32, staying in registers (4 for D=128) - Sigmoid via log2-exp2 trick: uses fast hardware exp2 - All accumulation in Float32 for numerical stability with bf16/fp16 inputs - In-memory compile cache keyed by (dtype, D, has_gate_weight) When gate_proj_weight is None, skips the CuteDSL kernel entirely and returns a simple (O_cmp + O_slc + O_sld) / 3 average — avoids kernel launch overhead for the ungated case. Returns (output, gates) tuple so gates are available for the backward pass. PyTorch reference implementations (compute_gates, gate_and_combine) retained for testing and fallback. No performance chart yet — gating alone is not the bottleneck. Chart will be updated after fused scoring (Diff 3) and fused compression (Diff 4). Differential Revision: D99181847
Summary: Replace ~10 kernel launches per query tile with a GEMM-based scoring pipeline using the Q_mean algebraic identity: mean(Q @ K) = mean(Q) @ K, reducing scoring from a (q_tile_size, D) x (D, N_cmp) GEMM per tile to a single (D,) . (D, N_cmp) GEMV — 256x fewer FLOPs. Key components: - _compute_q_mean(): Single PyTorch kernel computes per-tile mean of Q in fp32. Supports both 4D fixed-length and 3D varlen (with cu_seqlens). - _score_and_topk(): GQA-aware bmm that folds GQA groups into the M dimension of the GEMM, avoiding K_cmp expansion from H_kv to H heads: (B*H_kv, n_tiles*groups, D) @ (B*H_kv, D, N_cmp). cuBLAS GEMM. - fused_score_and_select_blocks(): Unified entry for selected branch. - fused_score_and_select_all(): Computes GEMM once, derives indices for both selected and compressed branches (avoids duplicate GEMM). Chunked processing (64 Q-tiles per chunk) bounds peak memory. All scoring in fp32 for numerical stability with bf16/fp16 inputs. NSA becomes faster than dense FA4 at ~24K tokens with this optimization; reaches 12.5x speedup at 512K (was barely faster before due to scoring bottleneck). Performance chart will be updated after Diff 4. Differential Revision: D99181843
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 2, 2026
Summary: Replace two PyTorch mean() kernel launches (one for K, one for V) with a single fused CuteDSL kernel that processes both K and V simultaneously. Key design: - Grid: one thread block per output element (b, j, h) — total B * N_cmp * H_kv - Block: 128 threads, each handling ceil(D/128) elements - Float32 accumulation: each thread reads block_size input positions, accumulates K and V values in Float32, divides by inv_block_size, writes the mean - 2D flattening for CuteDSL: K/V reshaped to (B*N*H_kv, D), manual index decomposition (b, j, h) from bidx - W_k/W_v projections remain as torch.einsum (cuBLAS GEMM — already optimal) - Varlen path: PyTorch per-sequence loop (CuteDSL varlen kernel in future diff) - In-memory compile cache keyed by (dtype, D, block_size) No performance chart update — chart will be generated after Diff 5 (int32 overflow fix) when we can benchmark at N >= 2M. Differential Revision: D99181839
3da5826 to
1c54895
Compare
jduprat
added a commit
to jduprat/MSLK
that referenced
this pull request
Apr 3, 2026
Summary:
Replace two PyTorch mean() kernel launches (one for K, one for V) with a single
fused CuteDSL kernel that processes both K and V simultaneously.
Key design:
- Grid: one thread block per output element (b, j, h) — total B * N_cmp * H_kv
- Block: 128 threads, each handling ceil(D/128) elements
- Float32 accumulation: each thread reads block_size input positions, accumulates
K and V values in Float32, divides by inv_block_size, writes the mean
- 2D flattening for CuteDSL: K/V reshaped to (B*N*H_kv, D), manual index
decomposition (b, j, h) from bidx
- W_k/W_v projections remain as torch.einsum (cuBLAS GEMM — already optimal)
- Varlen path: PyTorch per-sequence loop (CuteDSL varlen kernel in future diff)
- In-memory compile cache keyed by (dtype, D, block_size)
Slightly lower cross-over point, same overall performance as previous diff.
{F1987648272}
Differential Revision: D99181839
Summary: Pull Request resolved: meta-pytorch#302 Replace two PyTorch mean() kernel launches (one for K, one for V) with a single fused CuteDSL kernel that processes both K and V simultaneously. Key design: - Grid: one thread block per output element (b, j, h) — total B * N_cmp * H_kv - Block: 128 threads, each handling ceil(D/128) elements - Float32 accumulation: each thread reads block_size input positions, accumulates K and V values in Float32, divides by inv_block_size, writes the mean - 2D flattening for CuteDSL: K/V reshaped to (B*N*H_kv, D), manual index decomposition (b, j, h) from bidx - W_k/W_v projections remain as torch.einsum (cuBLAS GEMM — already optimal) - Varlen path: PyTorch per-sequence loop (CuteDSL varlen kernel in future diff) - In-memory compile cache keyed by (dtype, D, block_size) Slightly lower cross-over point, same overall performance as previous diff. {F1987648272} Differential Revision: D99181839
1c54895 to
eff5113
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary:
Replace two PyTorch mean() kernel launches (one for K, one for V) with a single
fused CuteDSL kernel that processes both K and V simultaneously.
Key design:
K and V values in Float32, divides by inv_block_size, writes the mean
decomposition (b, j, h) from bidx
Slightly lower cross-over point, same overall performance as previous diff.
{F1987648272}
Differential Revision: D99181839