Added inter document masking for manual and flash attention. #434

BlueCrescent wants to merge 9 commits into main from
Conversation
Pull request overview
Adds inter-document masking support to the GPT-2 attention stack so sequences containing multiple concatenated documents can prevent cross-document attention for both manual attention and DAO flash attention (via varlen).
Changes:
- Added `CausalSelfAttention.prepare_inter_document_masking()` and threaded optional masking info through attention execution paths.
- Implemented DAO flash varlen execution path to support document-wise masking/splitting without padding leakage.
- Added extensive unit tests for inter-document masking behaviors and edge cases.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| tests/models/test_causal_self_attention.py | Adds comprehensive tests for inter-document masking across manual and DAO flash attention implementations. |
| src/modalities/models/gpt2/gpt2_model.py | Implements inter-document masking preparation, DAO flash varlen execution, and integrates masking into GPT2LLM/GPT2Block/CausalSelfAttention. |
```python
torch.manual_seed(0)  # FIXME remove or do within tests?
```
Setting torch.manual_seed(0) at module import time mutates global RNG state for the entire test run, which can make unrelated tests order-dependent. Prefer seeding inside the specific tests that need determinism (or via a fixture) instead of at import scope.
Suggested change:
```diff
- torch.manual_seed(0)  # FIXME remove or do within tests?
```
```python
device = self.c_proj.weight.device
if self.attention_impl == AttentionImplementation.MANUAL:
    batch_size = len(in_batch_seq_lens)
    attn_mask = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
    for i, doc_seq_lens in enumerate(in_batch_seq_lens):
        doc_boundaries = torch.cumsum(torch.tensor([0] + doc_seq_lens, device=device), dim=0)
        for j in range(len(doc_boundaries) - 1):
            start_idx = doc_boundaries[j]
            end_idx = doc_boundaries[j + 1]
            attn_mask[i, start_idx:end_idx, start_idx:end_idx] = True
    return attn_mask
```
prepare_inter_document_masking() validates in_batch_seq_lens for the dao_flash path but not for the manual path. If a batch item's subsequence lengths sum to more than max_seq_len (or contain invalid values), the manual path can produce incorrect masks or raise a low-level indexing error. Consider applying the same validation (e.g., reuse _build_concatenated_lengths_tensor checks) for the manual implementation too.
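The suggested change can be sketched as a standalone version of the manual-path mask construction that applies the same length validation as the dao_flash path (function name and exact error messages here are assumptions, not the project's code):

```python
import torch

def build_manual_inter_document_mask(in_batch_seq_lens, max_seq_len):
    # Illustrative sketch: block-diagonal boolean mask with the same validation
    # the dao_flash path performs via _build_concatenated_lengths_tensor.
    batch_size = len(in_batch_seq_lens)
    attn_mask = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool)
    for i, doc_seq_lens in enumerate(in_batch_seq_lens):
        if sum(doc_seq_lens) > max_seq_len:
            raise ValueError(
                f"Sum of subsequence lengths ({sum(doc_seq_lens)}) exceeds "
                f"max_seq_len ({max_seq_len}) for batch index {i}."
            )
        start = 0
        for length in doc_seq_lens:
            # Each document may only attend within its own block.
            attn_mask[i, start:start + length, start:start + length] = True
            start += length
    return attn_mask

mask = build_manual_inter_document_mask([[2, 3]], max_seq_len=5)
```

With the check in place, an over-long batch item fails with a clear `ValueError` instead of producing a silently wrong mask or a low-level indexing error.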
```python
attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None):
    Optional tensor containing masking information for inter-document attention.

Returns:
    torch.Tensor: The output tensor.
```
execute_attention() now accepts attention_masking_information, but inter-document masking is only applied in the MANUAL / DAO_FLASH implementations. For PYTORCH_FLASH, the mask is currently silently ignored (it always passes attn_mask=None). To avoid surprising behavior, consider raising NotImplementedError (or asserting the mask is None) when attention_impl is PYTORCH_FLASH and a mask is provided.
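A guard along the lines the reviewer suggests could look like this (the enum mirrors the `AttentionImplementation` values referenced in the review; names and values are assumptions):

```python
from enum import Enum

class AttentionImplementation(Enum):
    MANUAL = "manual"
    PYTORCH_FLASH = "pytorch_flash"
    DAO_FLASH = "dao_flash"

def check_inter_document_mask_support(attention_impl, attention_masking_information):
    # PYTORCH_FLASH currently always passes attn_mask=None, so a provided mask
    # would be silently dropped; raising makes the limitation explicit.
    if (
        attention_impl == AttentionImplementation.PYTORCH_FLASH
        and attention_masking_information is not None
    ):
        raise NotImplementedError(
            "Inter-document masking is not supported for PYTORCH_FLASH attention."
        )
```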
- …ce lengths required for inter document attention masking.
- … manual attention. - Also applied some review comments.
- …ed eod_token_id to eos_token_id
- …doc string Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 11 comments.
src/modalities/config/config.py (Outdated)
```python
@model_validator(mode="before")
def check_sub_seq_lengths_and_eos_token(cls, values):
    sub_seq_lengths_key = values.get("sub_seq_lengths_key")
    eos_token_id = values.get("eos_token_id")
    if (sub_seq_lengths_key is None) != (eos_token_id is None):
        raise ValueError("Either both or neither of sub_seq_lengths_key and eos_token_id must be provided.")
    return values

@model_validator(mode="before")
def check_padding_token_and_sub_seq_lengths(cls, values):
    padding_token_id = values.get("padding_token_id")
    sub_seq_lengths_key = values.get("sub_seq_lengths_key")
    if padding_token_id is not None and sub_seq_lengths_key is None:
        raise ValueError("If padding_token_id is provided, sub_seq_lengths_key must also be provided.")
    return values
```
The validator uses mode="before" which means it operates on raw input values before Pydantic parsing. However, the validation logic uses XOR ((sub_seq_lengths_key is None) != (eos_token_id is None)) which might be less clear to readers than explicitly checking both conditions. Additionally, consider using mode="after" for clearer semantics since you're checking already-parsed field values, not raw input.
Suggested change:
```python
@model_validator(mode="after")
def check_sub_seq_lengths_and_eos_token(cls, model: "GPT2LLMCollateFnConfig") -> "GPT2LLMCollateFnConfig":
    sub_seq_lengths_key = model.sub_seq_lengths_key
    eos_token_id = model.eos_token_id
    if (sub_seq_lengths_key is None and eos_token_id is not None) or (
        sub_seq_lengths_key is not None and eos_token_id is None
    ):
        raise ValueError(
            "Either both or neither of sub_seq_lengths_key and eos_token_id must be provided."
        )
    return model

@model_validator(mode="after")
def check_padding_token_and_sub_seq_lengths(cls, model: "GPT2LLMCollateFnConfig") -> "GPT2LLMCollateFnConfig":
    padding_token_id = model.padding_token_id
    sub_seq_lengths_key = model.sub_seq_lengths_key
    if padding_token_id is not None and sub_seq_lengths_key is None:
        raise ValueError("If padding_token_id is provided, sub_seq_lengths_key must also be provided.")
    return model
```
tests/models/test_gpt2_collator.py (Outdated)
```python
def test_gpt2_collate_sub_seq_lengths_without_eos():
    collator = GPT2LLMCollateFn(
        sample_key="input_ids",
        target_key="labels",
        sub_seq_lengths_key="sub_seq_lengths",
        eos_token_id=99,
    )
    batch = [
        {"input_ids": torch.tensor([10, 11, 12, 13, 14])},
        {"input_ids": torch.tensor([20, 21, 22, 23, 24])},
    ]

    result = collator(batch)

    assert result.samples["sub_seq_lengths"] == [[5], [5]]


def test_gpt2_collate_sub_seq_lengths_with_eos():
    collator = GPT2LLMCollateFn(
        sample_key="input_ids",
        target_key="labels",
        sub_seq_lengths_key="sub_seq_lengths",
        eos_token_id=99,
    )
    batch = [
        {"input_ids": torch.tensor([1, 99, 2, 3, 99])},
        {"input_ids": torch.tensor([7, 8, 9, 99, 10])},
    ]

    result = collator(batch)

    assert result.samples["sub_seq_lengths"] == [[2, 3], [4, 1]]
```
Please verify the test expectations are correct. For the input [10, 11, 12, 13, 14], after the shift operation (sample_tensor[:, :-1]) the sequence becomes [10, 11, 12, 13] with length 4. When there's no EOS token, the collator returns [len(seq)] which should be [4], but the test expects [[5], [5]]. Similarly, for test_gpt2_collate_sub_seq_lengths_with_eos, trace through the logic to ensure the expected values match the implementation. If the tests are passing, please add comments explaining the counter-intuitive behavior.
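The discrepancy flagged here can be reproduced with a small helper (illustrative, not the project's code): lengths computed on the full input sequence differ from lengths computed after the shift drops the last token.

```python
import torch

def eos_split_lengths(seq, eos_token_id):
    # Split a 1D token sequence into sub-sequence lengths at EOS positions;
    # a trailing segment without EOS is included.
    lengths, current = [], 0
    for tok in seq.tolist():
        current += 1
        if tok == eos_token_id:
            lengths.append(current)
            current = 0
    if current > 0:
        lengths.append(current)
    return lengths

sample = torch.tensor([1, 99, 2, 3, 99])
full_lengths = eos_split_lengths(sample, 99)          # on the full sequence: [2, 3]
shifted_lengths = eos_split_lengths(sample[:-1], 99)  # after dropping the last token: [2, 2]
```

The test expectation `[2, 3]` matches the full-sequence computation, which suggests the collator computes lengths before the shift; a comment in the test documenting this would resolve the confusion.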
```python
torch.manual_seed(0)  # FIXME remove or do within tests?
```
The FIXME comment suggests uncertainty about whether to keep the global torch.manual_seed(0) or move it within tests. Having a global seed can affect test isolation, as tests may depend on the order they're run. It's better practice to set the seed within each test that needs deterministic behavior, as is already done in several tests (e.g., lines 501, 529, etc.). Consider removing this global seed setting and ensuring each test that requires determinism sets its own seed.
Suggested change:
```diff
- torch.manual_seed(0)  # FIXME remove or do within tests?
```
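One way to keep determinism without mutating global RNG state at import time is to pass an explicit generator to the ops that need it (a minimal sketch; the helper name is an assumption):

```python
import torch

def seeded_randn(shape, seed=0):
    # A local generator leaves torch's global RNG state untouched, so test
    # outcomes no longer depend on execution order.
    gen = torch.Generator().manual_seed(seed)
    return torch.randn(*shape, generator=gen)

a = seeded_randn((2, 4))
b = seeded_randn((2, 4))  # same seed, same values
```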
```python
# TODO: use drop out also without absolute position embedding?
h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h

# TODO: Handle this in case of pipeline parallelism.
```
The TODO comment indicates that pipeline parallelism (PP) handling for inter-document masking is not yet implemented. This is acknowledged in the PR description. However, it would be helpful to add more detail about what needs to be done. For example, in PP, attention_masking_information would need to be computed on the first pipeline stage and passed to subsequent stages, or each stage would need access to sub_seq_lengths. Consider expanding this TODO with specific implementation notes.
Suggested change:
```python
# TODO: Pipeline parallelism (PP) handling for inter-document masking.
# In the non-PP case we compute `attention_masking_information` once using the
# attention module of the first block (`self.transformer.h["0"].attn`) and reuse
# it for all layers. In a PP setup, only a subset of blocks live on each
# pipeline stage, so this logic needs to be adapted:
# * Option 1: Compute `attention_masking_information` on the first PP stage
#   (which must have access to `sub_seq_lengths`) and propagate it along the
#   pipeline together with the hidden states so that later stages can pass it
#   to their local blocks.
# * Option 2: Ensure that each PP stage has access to `sub_seq_lengths` and
#   recompute the same `attention_masking_information` locally using the
#   first block on that stage (or a dedicated helper) so that all stages use
#   identical masking.
# When implementing PP support, choose one strategy consistent with the
# existing PP activation/metadata communication mechanism and ensure that all
# stages see the same `attention_masking_information`.
```
```python
if is_causal:
    temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)  # device added
    if attn_mask is None:
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    elif attn_mask.dtype == torch.bool:
        if attn_mask.dim() == 3:
            combined_mask = temp_mask.unsqueeze(0) & attn_mask
        else:
            combined_mask = temp_mask & attn_mask
        fully_masked = ~combined_mask.any(dim=-1)
        attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf"))
    else:
        if attn_mask.dim() == 3:
            temp_mask = temp_mask.unsqueeze(0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias += attn_mask
elif attn_mask is not None:
```
The previous implementation asserted `attn_mask is None` when `is_causal` was set; that assertion has been removed, and the function now combines the causal mask with a provided `attn_mask`. The logic appears correct for combining causal and inter-document masks, but this is a behavioral change from the previous implementation, which enforced mutual exclusivity, and it is not noted in the PR description.
```python
elif attn_mask.dtype == torch.bool:
    if attn_mask.dim() == 3:
        combined_mask = temp_mask.unsqueeze(0) & attn_mask
    else:
        combined_mask = temp_mask & attn_mask
    fully_masked = ~combined_mask.any(dim=-1)
    attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf"))
else:
    if attn_mask.dim() == 3:
        temp_mask = temp_mask.unsqueeze(0)
    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    attn_bias += attn_mask
```
The fully_masked variable is computed to identify rows that have no valid attention positions after combining causal and inter-document masks. However, this is only used when attn_mask.dtype is torch.bool within the is_causal branch. If attn_mask is a float mask, fully_masked will remain None, which means the special handling for fully masked rows won't apply. This could lead to NaN values in attention weights after softmax on fully masked rows when using float masks. Consider computing fully_masked for float masks as well, or document this limitation.
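Extending the fully-masked-row handling to float masks could be sketched as follows, assuming float masks use `-inf` for disallowed positions (function name and the exact recovery strategy are assumptions):

```python
import torch

def safe_attention_weights(attn_bias):
    # Rows whose bias is -inf everywhere produce NaNs after softmax
    # (exp(-inf) sums to zero). Detect them directly on the float bias
    # and zero their output, covering float masks as well as bool masks.
    weights = torch.softmax(attn_bias, dim=-1)
    fully_masked = torch.isinf(attn_bias).all(dim=-1)
    return weights.masked_fill(fully_masked.unsqueeze(-1), 0.0)

bias = torch.tensor([[0.0, float("-inf")],
                     [float("-inf"), float("-inf")]])
weights = safe_attention_weights(bias)
```

The first row keeps its regular softmax weights, while the fully masked second row yields zeros instead of NaNs.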
```python
if len(eos_positions) == 0:
    assert (
        self.padding_token_id is None or seq[0] != self.padding_token_id
    ), "Sequence starts with padding token"
```
The assertion message "Sequence starts with padding token" is not very informative. It doesn't explain why this is a problem or what the user should do to fix it. Consider improving the error message to explain that sequences cannot start with padding tokens because it would result in invalid sub-sequence length computation, and suggest how to fix the data (e.g., "Invalid sequence: cannot start with padding token. Please ensure padding is only at the end of sequences after EOS tokens.").
Suggested change:
```python
    ), (
        "Invalid sequence: cannot start with padding token. This prevents valid "
        "sub-sequence length computation when no EOS token is present. Please ensure "
        "padding is only applied at the end of sequences, typically after EOS tokens."
    )
```
```python
attention_masking_information = self.transformer.h["0"].attn.prepare_inter_document_masking(
    in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len
)
```
The attention_masking_information is computed once using the first layer's attention module (self.transformer.h["0"].attn) and then reused for all layers. While this is efficient, if different layers use different attention implementations (which the current code doesn't seem to support but could be a future feature), this could cause issues. Consider adding a comment explaining that all layers must use the same attention implementation for inter-document masking to work correctly, or add validation to ensure this.
Suggested change:
```python
# Inter-document masking is prepared once using the first layer's attention implementation
# and reused for all layers. This requires that all layers share the same attention type.
first_attn = self.transformer.h["0"].attn
attention_masking_information = first_attn.prepare_inter_document_masking(
    in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len
)
# Validate that all layers use a compatible attention implementation for inter-document masking.
for _layer_idx, block in getattr(self.transformer, "h").items():
    if type(block.attn) is not type(first_attn):
        raise RuntimeError(
            "All transformer layers must use the same attention implementation for "
            "inter-document masking to work correctly."
        )
```
```python
    max_seq_len: The maximum allowed sequence length (number of subsequences and
        total length constraints are validated against this value).
    device: The torch device on which to allocate the output tensor.
Returns:
    A tensor of shape (batch_size, max_seq_len) containing the subsequence lengths
    for each batch item, padded with zeros beyond the number of subsequences.
Raises:
    ValueError: If a batch item has more subsequences than max_seq_len or if the
        sum of its subsequence lengths exceeds max_seq_len.
"""
batch_size = len(in_batch_seq_lens)
concatenated_lengths = torch.zeros((batch_size, max_seq_len), dtype=torch.int32, device=device)
for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens):
    if len(doc_seq_lens) > max_seq_len:
        raise ValueError(
            f"Number of subsequences ({len(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
            f"for batch index {batch_idx}."
        )
    if sum(doc_seq_lens) > max_seq_len:
        raise ValueError(
            f"Sum of subsequence lengths ({sum(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
            f"for batch index {batch_idx}."
```
The validation checks if len(doc_seq_lens) > max_seq_len and if sum(doc_seq_lens) > max_seq_len, raising ValueError in both cases. However, the error messages refer to "number of subsequences" and "sum of subsequence lengths" both being compared to max_seq_len. This is potentially confusing because max_seq_len should represent the total sequence length (sum of lengths), not the number of subsequences. The first check (number of subsequences > max_seq_len) might be overly restrictive - you could have many single-token documents that total less than max_seq_len. Consider clarifying the semantics of max_seq_len or adjusting the validation logic.
Suggested change:
```python
    max_seq_len: An upper bound applied per batch item to both
        (a) the number of subsequences and (b) the total sum of subsequence
        lengths (in tokens). It also defines the second dimension of the
        returned tensor.
    device: The torch device on which to allocate the output tensor.
Returns:
    A tensor of shape (batch_size, max_seq_len) containing the subsequence lengths
    for each batch item, padded with zeros beyond the number of subsequences.
    The second dimension (max_seq_len) therefore bounds the maximum number of
    subsequences represented per batch item.
Raises:
    ValueError: If a batch item has more subsequences than the allowed maximum
        number (max_seq_len) or if the sum of its subsequence lengths exceeds the
        allowed maximum total length (max_seq_len).
"""
batch_size = len(in_batch_seq_lens)
concatenated_lengths = torch.zeros((batch_size, max_seq_len), dtype=torch.int32, device=device)
for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens):
    if len(doc_seq_lens) > max_seq_len:
        raise ValueError(
            f"Number of subsequences ({len(doc_seq_lens)}) exceeds the maximum allowed "
            f"number of subsequences ({max_seq_len}) for batch index {batch_idx}."
        )
    if sum(doc_seq_lens) > max_seq_len:
        raise ValueError(
            f"Sum of subsequence lengths ({sum(doc_seq_lens)}) exceeds the maximum allowed "
            f"total length ({max_seq_len}) for batch index {batch_idx}."
```
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
What does this PR do?
Adds inter document masking for manual and flash attention.
General Changes
- Added `prepare_inter_document_masking()` to CausalSelfAttention, which computes 3D attention masks for manual attention and cu_seqlens for DAO flash attention. The inputs are the sub-sequence lengths of each sequence; thus, padded sequences are also supported.
- In the `forward()` call, inter-document masking is applied.

Breaking Changes
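The cu_seqlens layout mentioned for DAO flash attention can be sketched as cumulative document boundaries over the flattened batch (an illustrative helper; the name and exact construction in the PR may differ):

```python
import torch

def build_cu_seqlens(in_batch_seq_lens):
    # Flatten all document lengths across the batch and prefix-sum them, so
    # consecutive entries delimit one document in the packed token stream.
    flat = [length for doc_lens in in_batch_seq_lens for length in doc_lens]
    cu_seqlens = torch.zeros(len(flat) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(torch.tensor(flat, dtype=torch.int32), dim=0)
    return cu_seqlens

cu = build_cu_seqlens([[2, 3], [5]])  # documents of length 2, 3, and 5
```

A varlen kernel then attends within each `[cu[i], cu[i+1])` slice, which is what prevents cross-document attention without building a dense mask.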
Checklist before submitting final PR
- Tests pass (python tests/tests.py)
- Changelog updated (CHANGELOG_DEV.md)