-
Notifications
You must be signed in to change notification settings - Fork 210
add DFLASH block mask support for SDPA #477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,34 @@ | |
| create_block_mask = None | ||
|
|
||
|
|
||
| def create_dflash_sdpa_mask(anchor_positions, block_keep_mask, S, block_size, device): | ||
| B, N = anchor_positions.shape | ||
| Q_LEN = N * block_size | ||
| KV_LEN = S + N * block_size | ||
|
|
||
| q_indices = torch.arange(Q_LEN, device=device).view(1, 1, -1, 1) # (1, 1, Q_LEN, 1) | ||
| kv_indices = torch.arange(KV_LEN, device=device).view( | ||
| 1, 1, 1, -1 | ||
| ) # (1, 1, 1, KV_LEN) | ||
|
|
||
| q_block_ids = q_indices // block_size | ||
|
|
||
| anchor_expanded = anchor_positions.view(B, 1, N, 1).repeat_interleave( | ||
| block_size, dim=2 | ||
| ) | ||
|
|
||
| mask_context = (kv_indices < S) & (kv_indices < anchor_expanded) | ||
|
|
||
| is_draft = kv_indices >= S | ||
| kv_block_ids = (kv_indices - S) // block_size | ||
| mask_draft = is_draft & (q_block_ids == kv_block_ids) | ||
|
|
||
| valid_block = block_keep_mask.view(B, 1, N, 1).repeat_interleave(block_size, dim=2) | ||
|
|
||
| final_mask = (mask_context | mask_draft) & valid_block | ||
| return final_mask | ||
|
|
||
|
|
||
| def create_dflash_block_mask( | ||
| anchor_positions: torch.Tensor, | ||
| block_keep_mask: torch.Tensor, | ||
|
|
@@ -207,13 +235,22 @@ def forward( | |
| draft_position_ids = self._create_position_ids(anchor_positions) | ||
| full_position_ids = torch.cat([context_position_ids, draft_position_ids], dim=1) | ||
|
|
||
| dflash_attn_mask = create_dflash_block_mask( | ||
| anchor_positions=anchor_positions, | ||
| block_keep_mask=block_keep_mask, | ||
| S=seq_len, | ||
| block_size=self.block_size, | ||
| device=device, | ||
| ) | ||
| if self.attention_backend == "flex_attention": | ||
| dflash_attn_mask = create_dflash_block_mask( | ||
| anchor_positions=anchor_positions, | ||
| block_keep_mask=block_keep_mask, | ||
| S=seq_len, | ||
| block_size=self.block_size, | ||
| device=device, | ||
| ) | ||
| else: | ||
| dflash_attn_mask = create_dflash_sdpa_mask( | ||
| anchor_positions=anchor_positions, | ||
| block_keep_mask=block_keep_mask, | ||
| S=seq_len, | ||
| block_size=self.block_size, | ||
| device=device, | ||
| ) | ||
|
Comment on lines
+238
to
+253
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To improve readability and reduce code duplication, you can refactor this if self.attention_backend == "flex_attention":
mask_fn = create_dflash_block_mask
else:
mask_fn = create_dflash_sdpa_mask
dflash_attn_mask = mask_fn(
anchor_positions=anchor_positions,
block_keep_mask=block_keep_mask,
S=seq_len,
block_size=self.block_size,
device=device,
) |
||
|
|
||
| output_hidden = self.draft_model( | ||
| position_ids=full_position_ids, | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,265 @@ | ||
| import unittest | ||
|
|
||
| import torch | ||
|
|
||
| from specforge.core.dflash import create_dflash_block_mask, create_dflash_sdpa_mask | ||
|
|
||
|
|
||
| def _reference_dflash_mask(anchor_positions, block_keep_mask, S, block_size, device): | ||
| """Element-level reference mask mirroring the mask_mod inside create_dflash_block_mask. | ||
|
|
||
| This uses plain Python loops so correctness is obvious by inspection. | ||
| """ | ||
| B, N = anchor_positions.shape | ||
| Q_LEN = N * block_size | ||
| KV_LEN = S + N * block_size | ||
|
|
||
| mask = torch.zeros(B, 1, Q_LEN, KV_LEN, dtype=torch.bool, device=device) | ||
| for b in range(B): | ||
| for q_idx in range(Q_LEN): | ||
| q_block_id = q_idx // block_size | ||
| anchor_pos = anchor_positions[b, q_block_id].item() | ||
| is_valid = block_keep_mask[b, q_block_id].item() | ||
| if not is_valid: | ||
| continue | ||
| for kv_idx in range(KV_LEN): | ||
| is_context = kv_idx < S | ||
| ctx_visible = is_context and (kv_idx < anchor_pos) | ||
|
|
||
| is_draft = kv_idx >= S | ||
| kv_block_id = (kv_idx - S) // block_size | ||
| draft_visible = is_draft and (q_block_id == kv_block_id) | ||
|
|
||
| if ctx_visible or draft_visible: | ||
| mask[b, 0, q_idx, kv_idx] = True | ||
| return mask | ||
|
|
||
|
|
||
| class TestDFlashMask(unittest.TestCase): | ||
|
|
||
| def setUp(self): | ||
| torch.manual_seed(42) | ||
| self.device = torch.device("cuda") | ||
|
|
||
| def _compare_masks(self, anchor_positions, block_keep_mask, S, block_size): | ||
| """Compare create_dflash_sdpa_mask against element-level reference (ground truth).""" | ||
| anchor_positions = anchor_positions.to(self.device) | ||
| block_keep_mask = block_keep_mask.to(self.device) | ||
|
|
||
| sdpa_mask = create_dflash_sdpa_mask( | ||
| anchor_positions=anchor_positions, | ||
| block_keep_mask=block_keep_mask, | ||
| S=S, | ||
| block_size=block_size, | ||
| device=self.device, | ||
| ) | ||
|
|
||
| ref_mask = _reference_dflash_mask( | ||
| anchor_positions=anchor_positions, | ||
| block_keep_mask=block_keep_mask, | ||
| S=S, | ||
| block_size=block_size, | ||
| device=self.device, | ||
| ) | ||
|
|
||
| self.assertEqual( | ||
| sdpa_mask.shape, | ||
| ref_mask.shape, | ||
| f"Shape mismatch: sdpa {sdpa_mask.shape} vs ref {ref_mask.shape}", | ||
| ) | ||
| self.assertTrue( | ||
| torch.equal(sdpa_mask, ref_mask), | ||
| f"Mask mismatch with S={S}, block_size={block_size}, " | ||
| f"anchors={anchor_positions.tolist()}, keep={block_keep_mask.tolist()}\n" | ||
| f"Diff positions: {(sdpa_mask != ref_mask).nonzero(as_tuple=False).tolist()}", | ||
| ) | ||
|
|
||
| def _compare_block_mask_consistency( | ||
| self, anchor_positions, block_keep_mask, S, block_size | ||
| ): | ||
| """Verify create_dflash_block_mask block-level mask is consistent with reference.""" | ||
| anchor_positions = anchor_positions.to(self.device) | ||
| block_keep_mask = block_keep_mask.to(self.device) | ||
|
|
||
| block_mask = create_dflash_block_mask( | ||
| anchor_positions=anchor_positions, | ||
| block_keep_mask=block_keep_mask, | ||
| S=S, | ||
| block_size=block_size, | ||
| device=self.device, | ||
| ) | ||
|
|
||
| ref_mask = _reference_dflash_mask( | ||
| anchor_positions=anchor_positions, | ||
| block_keep_mask=block_keep_mask, | ||
| S=S, | ||
| block_size=block_size, | ||
| device=self.device, | ||
| ) | ||
|
|
||
| dense_blocks = block_mask.to_dense() # (B, H, Q_blocks, KV_blocks) | ||
| BM_BLOCK = 128 | ||
| B, N = anchor_positions.shape | ||
| Q_LEN = N * block_size | ||
| KV_LEN = S + N * block_size | ||
| n_q_blocks = (Q_LEN + BM_BLOCK - 1) // BM_BLOCK | ||
| n_kv_blocks = (KV_LEN + BM_BLOCK - 1) // BM_BLOCK | ||
|
|
||
| ref_int = ref_mask.squeeze(1).int() # (B, Q_LEN, KV_LEN) | ||
| for b in range(B): | ||
| for qi in range(n_q_blocks): | ||
| for ki in range(n_kv_blocks): | ||
| q_start = qi * BM_BLOCK | ||
| q_end = min(q_start + BM_BLOCK, Q_LEN) | ||
| k_start = ki * BM_BLOCK | ||
| k_end = min(k_start + BM_BLOCK, KV_LEN) | ||
| has_nonzero = ref_int[b, q_start:q_end, k_start:k_end].any().item() | ||
| block_val = dense_blocks[b, 0, qi, ki].item() | ||
| if has_nonzero: | ||
| self.assertEqual( | ||
| block_val, | ||
| 1, | ||
| f"Block ({qi},{ki}) for batch {b} should be 1 but got 0", | ||
| ) | ||
|
|
||
| def test_basic_single_batch_single_block(self): | ||
| """Single batch, single draft block.""" | ||
| anchor_positions = torch.tensor([[64]]) | ||
| block_keep_mask = torch.tensor([[True]]) | ||
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) | ||
|
|
||
| def test_basic_single_batch_multi_block(self): | ||
| """Single batch, multiple draft blocks.""" | ||
| anchor_positions = torch.tensor([[32, 64, 96]]) | ||
| block_keep_mask = torch.tensor([[True, True, True]]) | ||
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) | ||
|
|
||
| def test_multi_batch(self): | ||
| """Multiple batches with different anchors.""" | ||
| anchor_positions = torch.tensor([[16, 48, 80], [32, 64, 100]]) | ||
| block_keep_mask = torch.tensor([[True, True, True], [True, True, True]]) | ||
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) | ||
|
|
||
| def test_invalid_blocks(self): | ||
| """Some blocks are masked out (block_keep_mask=False).""" | ||
| anchor_positions = torch.tensor([[20, 50, 80, 110]]) | ||
| block_keep_mask = torch.tensor([[True, False, True, False]]) | ||
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) | ||
|
|
||
| def test_all_blocks_invalid(self): | ||
| """All blocks invalid — mask should be all zeros.""" | ||
| anchor_positions = torch.tensor([[30, 60]]) | ||
| block_keep_mask = torch.tensor([[False, False]]) | ||
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) | ||
|
|
||
| def test_anchor_at_zero(self): | ||
| """Anchor at position 0 — no context tokens visible.""" | ||
| anchor_positions = torch.tensor([[0, 64]]) | ||
| block_keep_mask = torch.tensor([[True, True]]) | ||
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) | ||
|
|
||
| def test_anchor_at_boundary(self): | ||
| """Anchor exactly at S — all context tokens visible.""" | ||
| anchor_positions = torch.tensor([[128]]) | ||
| block_keep_mask = torch.tensor([[True]]) | ||
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) | ||
|
|
||
| def test_large_block_size(self): | ||
| """Larger draft block size.""" | ||
| anchor_positions = torch.tensor([[50, 150]]) | ||
| block_keep_mask = torch.tensor([[True, True]]) | ||
| self._compare_masks(anchor_positions, block_keep_mask, S=256, block_size=16) | ||
|
|
||
| def test_block_size_1(self): | ||
| """Minimal block_size=1.""" | ||
| anchor_positions = torch.tensor([[10, 30, 50]]) | ||
| block_keep_mask = torch.tensor([[True, True, True]]) | ||
| self._compare_masks(anchor_positions, block_keep_mask, S=64, block_size=1) | ||
|
|
||
| def test_mixed_validity_multi_batch(self): | ||
| """Multi-batch with mixed block validity patterns.""" | ||
| anchor_positions = torch.tensor([[10, 40, 70, 100], [20, 50, 80, 110]]) | ||
| block_keep_mask = torch.tensor( | ||
| [[True, False, True, True], [False, True, False, True]] | ||
| ) | ||
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=8) | ||
|
|
||
| def test_various_context_lengths(self): | ||
| """Sweep over various context lengths.""" | ||
| for S in [64, 128, 256, 512]: | ||
| with self.subTest(S=S): | ||
| anchor_positions = torch.tensor([[S // 4, S // 2, 3 * S // 4]]) | ||
| block_keep_mask = torch.tensor([[True, True, True]]) | ||
| self._compare_masks( | ||
| anchor_positions, block_keep_mask, S=S, block_size=4 | ||
| ) | ||
|
|
||
| def test_various_block_sizes(self): | ||
| """Sweep over various draft block sizes.""" | ||
| for block_size in [1, 2, 4, 8, 16]: | ||
| with self.subTest(block_size=block_size): | ||
| anchor_positions = torch.tensor([[32, 80]]) | ||
| block_keep_mask = torch.tensor([[True, True]]) | ||
| self._compare_masks( | ||
| anchor_positions, block_keep_mask, S=128, block_size=block_size | ||
| ) | ||
|
|
||
| def test_many_blocks(self): | ||
| """Large number of draft blocks.""" | ||
| N = 32 | ||
| anchors = torch.arange(10, 10 + N * 4, 4).unsqueeze(0) | ||
| keep = torch.ones(1, N, dtype=torch.bool) | ||
| keep[0, ::3] = False | ||
| self._compare_masks(anchors, keep, S=256, block_size=4) | ||
|
|
||
| def test_consecutive_anchors(self): | ||
| """Anchors placed consecutively.""" | ||
| anchor_positions = torch.tensor([[0, 1, 2, 3]]) | ||
| block_keep_mask = torch.tensor([[True, True, True, True]]) | ||
| self._compare_masks(anchor_positions, block_keep_mask, S=64, block_size=4) | ||
|
|
||
| def test_random_stress(self): | ||
| """Randomized stress test with multiple random configurations.""" | ||
| rng = torch.Generator().manual_seed(123) | ||
| for trial in range(5): | ||
| with self.subTest(trial=trial): | ||
| B = torch.randint(1, 4, (1,), generator=rng).item() | ||
| N = torch.randint(1, 8, (1,), generator=rng).item() | ||
| S = 64 * torch.randint(1, 5, (1,), generator=rng).item() | ||
| block_size = [1, 2, 4, 8][ | ||
| torch.randint(0, 4, (1,), generator=rng).item() | ||
| ] | ||
|
|
||
| anchor_positions = torch.stack( | ||
| [ | ||
| torch.randperm(S, generator=rng)[:N].sort().values | ||
| for _ in range(B) | ||
| ] | ||
| ) | ||
| block_keep_mask = torch.rand(B, N, generator=rng) > 0.3 | ||
|
|
||
| self._compare_masks( | ||
| anchor_positions, block_keep_mask, S=S, block_size=block_size | ||
| ) | ||
|
|
||
| def test_block_mask_consistency(self): | ||
| """Verify BlockMask block-level mask is consistent with element-level reference.""" | ||
| anchor_positions = torch.tensor([[32, 64, 96]]) | ||
| block_keep_mask = torch.tensor([[True, True, True]]) | ||
| self._compare_block_mask_consistency( | ||
| anchor_positions, block_keep_mask, S=128, block_size=4 | ||
| ) | ||
|
|
||
| def test_block_mask_consistency_mixed(self): | ||
| """Verify BlockMask consistency with mixed validity.""" | ||
| anchor_positions = torch.tensor([[10, 40, 70, 100], [20, 50, 80, 110]]) | ||
| block_keep_mask = torch.tensor( | ||
| [[True, False, True, True], [False, True, False, True]] | ||
| ) | ||
| self._compare_block_mask_consistency( | ||
| anchor_positions, block_keep_mask, S=128, block_size=8 | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main(verbosity=2) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This new function can be improved by:
create_dflash_block_mask.maskvariable on line 27.Here is a suggested implementation that incorporates these changes.