Added inter document masking for manual and flash attention.#434

Open
BlueCrescent wants to merge 9 commits into main from inter_document_masking_for_attention

Conversation

@BlueCrescent (Member)

What does this PR do?

Adds inter-document masking for manual and flash attention.

General Changes

  • Added prepare_inter_document_masking() to CausalSelfAttention, which computes 3D attention masks for manual attention and cu_seqlens for DAO flash attention. The inputs are the sub-sequence lengths of each sequence; thus, padded sequences are also supported.
  • When masking information is provided to the attention's forward() call, inter-document masking is applied.
  • Added thorough tests for CausalSelfAttention.
  • Integrated into GPT2Model.
  • TODO: Test GPT2Model, support PP, create corresponding dataloader
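For context, the cumulative offsets that flash-attention varlen kernels consume can be derived from per-batch sub-sequence lengths as in the following minimal sketch (the helper name `build_cu_seqlens` is illustrative, not the PR's code):

```python
import torch

def build_cu_seqlens(in_batch_seq_lens: list[list[int]]) -> torch.Tensor:
    """Flatten per-batch document lengths into cumulative offsets (cu_seqlens),
    the format flash-attention varlen kernels expect: one boundary per document,
    starting at 0 and ending at the total token count."""
    flat_lens = [length for doc_lens in in_batch_seq_lens for length in doc_lens]
    return torch.cumsum(torch.tensor([0] + flat_lens, dtype=torch.int32), dim=0)

# Two sequences with documents of lengths [2, 3] and [4, 1]:
cu_seqlens = build_cu_seqlens([[2, 3], [4, 1]])  # -> [0, 2, 5, 9, 10]
```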

Breaking Changes

  • None. If no inter-document sequence lengths are provided, the behavior should remain unchanged.

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

Copilot AI (Contributor) left a comment

Pull request overview

Adds inter-document masking support to the GPT-2 attention stack so sequences containing multiple concatenated documents can prevent cross-document attention for both manual attention and DAO flash attention (via varlen).

Changes:

  • Added CausalSelfAttention.prepare_inter_document_masking() and threaded optional masking info through attention execution paths.
  • Implemented DAO flash varlen execution path to support document-wise masking/splitting without padding leakage.
  • Added extensive unit tests for inter-document masking behaviors and edge cases.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.

File Description
tests/models/test_causal_self_attention.py Adds comprehensive tests for inter-document masking across manual and DAO flash attention implementations.
src/modalities/models/gpt2/gpt2_model.py Implements inter-document masking preparation, DAO flash varlen execution, and integrates masking into GPT2LLM/GPT2Block/CausalSelfAttention.


Comment on lines +21 to 23
torch.manual_seed(0) # FIXME remove or do within tests?


Copilot AI Feb 23, 2026

Setting torch.manual_seed(0) at module import time mutates global RNG state for the entire test run, which can make unrelated tests order-dependent. Prefer seeding inside the specific tests that need determinism (or via a fixture) instead of at import scope.

Suggested change (delete this line):
torch.manual_seed(0) # FIXME remove or do within tests?
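Per-test seeding can be done with a pytest fixture, as the comment suggests; a sketch (fixture and test names are illustrative):

```python
import pytest
import torch

@pytest.fixture
def seeded_rng() -> None:
    # Seed only for tests that request this fixture, keeping the rest of the
    # suite independent of test execution order.
    torch.manual_seed(0)

def test_deterministic_draw(seeded_rng: None) -> None:
    first = torch.randn(4)
    torch.manual_seed(0)
    second = torch.randn(4)
    # Re-seeding reproduces the same draw, so the test is deterministic.
    assert torch.equal(first, second)
```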

Comment on lines +523 to +533
device = self.c_proj.weight.device
if self.attention_impl == AttentionImplementation.MANUAL:
    batch_size = len(in_batch_seq_lens)
    attn_mask = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
    for i, doc_seq_lens in enumerate(in_batch_seq_lens):
        doc_boundaries = torch.cumsum(torch.tensor([0] + doc_seq_lens, device=device), dim=0)
        for j in range(len(doc_boundaries) - 1):
            start_idx = doc_boundaries[j]
            end_idx = doc_boundaries[j + 1]
            attn_mask[i, start_idx:end_idx, start_idx:end_idx] = True
    return attn_mask
Copilot AI Feb 23, 2026

prepare_inter_document_masking() validates in_batch_seq_lens for the dao_flash path but not for the manual path. If a batch item's subsequence lengths sum to more than max_seq_len (or contain invalid values), the manual path can produce incorrect masks or raise a low-level indexing error. Consider applying the same validation (e.g., reuse _build_concatenated_lengths_tensor checks) for the manual implementation too.
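The concern can be demonstrated standalone: without validation, sub-sequence lengths that sum to more than max_seq_len are silently clipped by the slice assignment rather than raising (a self-contained re-run of the quoted construction, not the PR's test code):

```python
import torch

# Documents of lengths [3, 3] sum to 6, but max_seq_len is only 4.
in_batch_seq_lens = [[3, 3]]
max_seq_len = 4
attn_mask = torch.zeros((1, max_seq_len, max_seq_len), dtype=torch.bool)
for i, doc_seq_lens in enumerate(in_batch_seq_lens):
    doc_boundaries = torch.cumsum(torch.tensor([0] + doc_seq_lens), dim=0)
    for j in range(len(doc_boundaries) - 1):
        start_idx, end_idx = doc_boundaries[j], doc_boundaries[j + 1]
        # The second document's block (rows/cols 3:6) is silently clipped
        # to 3:4 by the 4x4 tensor bounds; no error is raised.
        attn_mask[i, start_idx:end_idx, start_idx:end_idx] = True
```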

Comment on lines +788 to 792
attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None):
    Optional tensor containing masking information for inter-document attention.

Returns:
    torch.Tensor: The output tensor.
Copilot AI Feb 23, 2026

execute_attention() now accepts attention_masking_information, but inter-document masking is only applied in the MANUAL / DAO_FLASH implementations. For PYTORCH_FLASH, the mask is currently silently ignored (it always passes attn_mask=None). To avoid surprising behavior, consider raising NotImplementedError (or asserting the mask is None) when attention_impl is PYTORCH_FLASH and a mask is provided.
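A guard along the suggested lines could look like this (a sketch; the enum mirrors the repo's AttentionImplementation, and the helper name is illustrative):

```python
from enum import Enum

class AttentionImplementation(str, Enum):
    # Illustrative mirror of the repo's enum.
    MANUAL = "manual"
    PYTORCH_FLASH = "pytorch_flash"
    DAO_FLASH = "dao_flash"

def check_masking_support(attention_impl, attention_masking_information) -> None:
    # Fail loudly instead of silently dropping the mask for implementations
    # that do not consume it.
    if (
        attention_impl == AttentionImplementation.PYTORCH_FLASH
        and attention_masking_information is not None
    ):
        raise NotImplementedError(
            "Inter-document masking is not supported for the PYTORCH_FLASH "
            "attention implementation."
        )
```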

@BlueCrescent BlueCrescent marked this pull request as ready for review February 25, 2026 10:42
@BlueCrescent BlueCrescent requested a review from Copilot February 25, 2026 10:42
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 11 comments.



Comment on lines 450 to 464
@model_validator(mode="before")
def check_sub_seq_lengths_and_eos_token(cls, values):
    sub_seq_lengths_key = values.get("sub_seq_lengths_key")
    eos_token_id = values.get("eos_token_id")
    if (sub_seq_lengths_key is None) != (eos_token_id is None):
        raise ValueError("Either both or neither of sub_seq_lengths_key and eos_token_id must be provided.")
    return values

@model_validator(mode="before")
def check_padding_token_and_sub_seq_lengths(cls, values):
    padding_token_id = values.get("padding_token_id")
    sub_seq_lengths_key = values.get("sub_seq_lengths_key")
    if padding_token_id is not None and sub_seq_lengths_key is None:
        raise ValueError("If padding_token_id is provided, sub_seq_lengths_key must also be provided.")
    return values
Copilot AI Feb 25, 2026

The validator uses mode="before" which means it operates on raw input values before Pydantic parsing. However, the validation logic uses XOR ((sub_seq_lengths_key is None) != (eos_token_id is None)) which might be less clear to readers than explicitly checking both conditions. Additionally, consider using mode="after" for clearer semantics since you're checking already-parsed field values, not raw input.

Suggested change — replace:

@model_validator(mode="before")
def check_sub_seq_lengths_and_eos_token(cls, values):
    sub_seq_lengths_key = values.get("sub_seq_lengths_key")
    eos_token_id = values.get("eos_token_id")
    if (sub_seq_lengths_key is None) != (eos_token_id is None):
        raise ValueError("Either both or neither of sub_seq_lengths_key and eos_token_id must be provided.")
    return values

@model_validator(mode="before")
def check_padding_token_and_sub_seq_lengths(cls, values):
    padding_token_id = values.get("padding_token_id")
    sub_seq_lengths_key = values.get("sub_seq_lengths_key")
    if padding_token_id is not None and sub_seq_lengths_key is None:
        raise ValueError("If padding_token_id is provided, sub_seq_lengths_key must also be provided.")
    return values

with:

@model_validator(mode="after")
def check_sub_seq_lengths_and_eos_token(cls, model: "GPT2LLMCollateFnConfig") -> "GPT2LLMCollateFnConfig":
    sub_seq_lengths_key = model.sub_seq_lengths_key
    eos_token_id = model.eos_token_id
    if (sub_seq_lengths_key is None and eos_token_id is not None) or (
        sub_seq_lengths_key is not None and eos_token_id is None
    ):
        raise ValueError(
            "Either both or neither of sub_seq_lengths_key and eos_token_id must be provided."
        )
    return model

@model_validator(mode="after")
def check_padding_token_and_sub_seq_lengths(cls, model: "GPT2LLMCollateFnConfig") -> "GPT2LLMCollateFnConfig":
    padding_token_id = model.padding_token_id
    sub_seq_lengths_key = model.sub_seq_lengths_key
    if padding_token_id is not None and sub_seq_lengths_key is None:
        raise ValueError("If padding_token_id is provided, sub_seq_lengths_key must also be provided.")
    return model

Comment on lines 20 to 51
def test_gpt2_collate_sub_seq_lengths_without_eos():
    collator = GPT2LLMCollateFn(
        sample_key="input_ids",
        target_key="labels",
        sub_seq_lengths_key="sub_seq_lengths",
        eos_token_id=99,
    )
    batch = [
        {"input_ids": torch.tensor([10, 11, 12, 13, 14])},
        {"input_ids": torch.tensor([20, 21, 22, 23, 24])},
    ]

    result = collator(batch)

    assert result.samples["sub_seq_lengths"] == [[5], [5]]


def test_gpt2_collate_sub_seq_lengths_with_eos():
    collator = GPT2LLMCollateFn(
        sample_key="input_ids",
        target_key="labels",
        sub_seq_lengths_key="sub_seq_lengths",
        eos_token_id=99,
    )
    batch = [
        {"input_ids": torch.tensor([1, 99, 2, 3, 99])},
        {"input_ids": torch.tensor([7, 8, 9, 99, 10])},
    ]

    result = collator(batch)

    assert result.samples["sub_seq_lengths"] == [[2, 3], [4, 1]]
Copilot AI Feb 25, 2026

Please verify the test expectations are correct. For the input [10, 11, 12, 13, 14], after the shift operation (sample_tensor[:, :-1]) the sequence becomes [10, 11, 12, 13] with length 4. When there's no EOS token, the collator returns [len(seq)] which should be [4], but the test expects [[5], [5]]. Similarly, for test_gpt2_collate_sub_seq_lengths_with_eos, trace through the logic to ensure the expected values match the implementation. If the tests are passing, please add comments explaining the counter-intuitive behavior.
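The shift in question can be traced in isolation (token id 99 as EOS, matching the tests; whether the collator splits the raw or the shifted sequence determines which expectations are right):

```python
import torch

raw = torch.tensor([[1, 99, 2, 3, 99]])
shifted = raw[:, :-1]  # model input after the label shift: [1, 99, 2, 3]
eos_positions = (shifted[0] == 99).nonzero().flatten().tolist()

# Split the *shifted* sequence at EOS positions into segment lengths.
lengths = []
prev = 0
for pos in eos_positions:
    lengths.append(pos + 1 - prev)
    prev = pos + 1
if prev < shifted.shape[1]:
    lengths.append(shifted.shape[1] - prev)
# Splitting the shifted sequence gives [2, 2], while the test expects [2, 3]
# (which corresponds to splitting the unshifted sequence).
```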

Comment on lines +21 to 23
torch.manual_seed(0) # FIXME remove or do within tests?


Copilot AI Feb 25, 2026

The FIXME comment suggests uncertainty about whether to keep the global torch.manual_seed(0) or move it within tests. Having a global seed can affect test isolation, as tests may depend on the order they're run. It's better practice to set the seed within each test that needs deterministic behavior, as is already done in several tests (e.g., lines 501, 529, etc.). Consider removing this global seed setting and ensuring each test that requires determinism sets its own seed.

Suggested change (delete this line):
torch.manual_seed(0) # FIXME remove or do within tests?

# TODO: use drop out also without absolute position embedding?
h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h

# TODO: Handle this in case of pipeline parallelism.
Copilot AI Feb 25, 2026

The TODO comment indicates that pipeline parallelism (PP) handling for inter-document masking is not yet implemented. This is acknowledged in the PR description. However, it would be helpful to add more detail about what needs to be done. For example, in PP, attention_masking_information would need to be computed on the first pipeline stage and passed to subsequent stages, or each stage would need access to sub_seq_lengths. Consider expanding this TODO with specific implementation notes.

Suggested change — replace:

# TODO: Handle this in case of pipeline parallelism.

with:

# TODO: Pipeline parallelism (PP) handling for inter-document masking.
# In the non-PP case we compute `attention_masking_information` once using the
# attention module of the first block (`self.transformer.h["0"].attn`) and reuse
# it for all layers. In a PP setup, only a subset of blocks live on each
# pipeline stage, so this logic needs to be adapted:
#   * Option 1: Compute `attention_masking_information` on the first PP stage
#     (which must have access to `sub_seq_lengths`) and propagate it along the
#     pipeline together with the hidden states so that later stages can pass it
#     to their local blocks.
#   * Option 2: Ensure that each PP stage has access to `sub_seq_lengths` and
#     recompute the same `attention_masking_information` locally using the
#     first block on that stage (or a dedicated helper) so that all stages use
#     identical masking.
# When implementing PP support, choose one strategy consistent with the
# existing PP activation/metadata communication mechanism and ensure that all
# stages see the same `attention_masking_information`.

Comment on lines 1280 to +1296
 if is_causal:
-    assert attn_mask is None
     temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)  # device added
-    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
-    attn_bias.to(query.dtype)
-
-if attn_mask is not None:
+    if attn_mask is None:
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+    elif attn_mask.dtype == torch.bool:
+        if attn_mask.dim() == 3:
+            combined_mask = temp_mask.unsqueeze(0) & attn_mask
+        else:
+            combined_mask = temp_mask & attn_mask
+        fully_masked = ~combined_mask.any(dim=-1)
+        attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf"))
+    else:
+        if attn_mask.dim() == 3:
+            temp_mask = temp_mask.unsqueeze(0)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+        attn_bias += attn_mask
+elif attn_mask is not None:
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment at line 1280 says "assert attn_mask is None" was removed, but this assertion was actually removed from the code. When is_causal is True and attn_mask is provided, the function now combines them. However, the assertion removal is not noted in the PR description. The logic appears correct for combining causal and inter-document masks, but it's a behavioral change from the previous implementation that enforced mutual exclusivity.

Comment on lines +1284 to +1295
elif attn_mask.dtype == torch.bool:
    if attn_mask.dim() == 3:
        combined_mask = temp_mask.unsqueeze(0) & attn_mask
    else:
        combined_mask = temp_mask & attn_mask
    fully_masked = ~combined_mask.any(dim=-1)
    attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf"))
else:
    if attn_mask.dim() == 3:
        temp_mask = temp_mask.unsqueeze(0)
    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    attn_bias += attn_mask
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fully_masked variable is computed to identify rows that have no valid attention positions after combining causal and inter-document masks. However, this is only used when attn_mask.dtype is torch.bool within the is_causal branch. If attn_mask is a float mask, fully_masked will remain None, which means the special handling for fully masked rows won't apply. This could lead to NaN values in attention weights after softmax on fully masked rows when using float masks. Consider computing fully_masked for float masks as well, or document this limitation.
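The failure mode is easy to reproduce: softmax over a row whose bias is -inf everywhere yields NaN. A sketch of both the symptom and one possible detection for float masks (variable names are illustrative):

```python
import torch

# A fully masked row: every position's bias is -inf.
attn_bias = torch.full((1, 3), float("-inf"))
# exp(-inf) = 0 everywhere, so the softmax normalizer is 0 and 0/0 -> NaN.
weights = torch.softmax(attn_bias, dim=-1)

# One possible detection for float masks, analogous to the bool-mask path:
# a row is fully masked iff every entry of its bias is -inf.
fully_masked = torch.isneginf(attn_bias).all(dim=-1)
```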

if len(eos_positions) == 0:
    assert (
        self.padding_token_id is None or seq[0] != self.padding_token_id
    ), "Sequence starts with padding token"
Copilot AI Feb 25, 2026

The assertion message "Sequence starts with padding token" is not very informative. It doesn't explain why this is a problem or what the user should do to fix it. Consider improving the error message to explain that sequences cannot start with padding tokens because it would result in invalid sub-sequence length computation, and suggest how to fix the data (e.g., "Invalid sequence: cannot start with padding token. Please ensure padding is only at the end of sequences after EOS tokens.").

Suggested change — replace:

), "Sequence starts with padding token"

with:

), (
    "Invalid sequence: cannot start with padding token. This prevents valid "
    "sub-sequence length computation when no EOS token is present. Please ensure "
    "padding is only applied at the end of sequences, typically after EOS tokens."
)

Comment on lines +1237 to +1239
attention_masking_information = self.transformer.h["0"].attn.prepare_inter_document_masking(
    in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len
)
Copilot AI Feb 25, 2026

The attention_masking_information is computed once using the first layer's attention module (self.transformer.h["0"].attn) and then reused for all layers. While this is efficient, if different layers use different attention implementations (which the current code doesn't seem to support but could be a future feature), this could cause issues. Consider adding a comment explaining that all layers must use the same attention implementation for inter-document masking to work correctly, or add validation to ensure this.

Suggested change — replace:

attention_masking_information = self.transformer.h["0"].attn.prepare_inter_document_masking(
    in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len
)

with:

# Inter-document masking is prepared once using the first layer's attention implementation
# and reused for all layers. This requires that all layers share the same attention type.
first_attn = self.transformer.h["0"].attn
attention_masking_information = first_attn.prepare_inter_document_masking(
    in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len
)
# Validate that all layers use a compatible attention implementation for inter-document masking.
for _layer_idx, block in getattr(self.transformer, "h").items():
    if type(block.attn) is not type(first_attn):
        raise RuntimeError(
            "All transformer layers must use the same attention implementation for "
            "inter-document masking to work correctly."
        )

Comment on lines +559 to +580
    max_seq_len: The maximum allowed sequence length (number of subsequences and
        total length constraints are validated against this value).
    device: The torch device on which to allocate the output tensor.
Returns:
    A tensor of shape (batch_size, max_seq_len) containing the subsequence lengths
    for each batch item, padded with zeros beyond the number of subsequences.
Raises:
    ValueError: If a batch item has more subsequences than max_seq_len or if the
        sum of its subsequence lengths exceeds max_seq_len.
"""
batch_size = len(in_batch_seq_lens)
concatenated_lengths = torch.zeros((batch_size, max_seq_len), dtype=torch.int32, device=device)
for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens):
    if len(doc_seq_lens) > max_seq_len:
        raise ValueError(
            f"Number of subsequences ({len(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
            f"for batch index {batch_idx}."
        )
    if sum(doc_seq_lens) > max_seq_len:
        raise ValueError(
            f"Sum of subsequence lengths ({sum(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
            f"for batch index {batch_idx}."
Copilot AI Feb 25, 2026

The validation checks if len(doc_seq_lens) > max_seq_len and if sum(doc_seq_lens) > max_seq_len, raising ValueError in both cases. However, the error messages refer to "number of subsequences" and "sum of subsequence lengths" both being compared to max_seq_len. This is potentially confusing because max_seq_len should represent the total sequence length (sum of lengths), not the number of subsequences. The first check (number of subsequences > max_seq_len) might be overly restrictive - you could have many single-token documents that total less than max_seq_len. Consider clarifying the semantics of max_seq_len or adjusting the validation logic.

Suggested change — replace:

    max_seq_len: The maximum allowed sequence length (number of subsequences and
        total length constraints are validated against this value).
    device: The torch device on which to allocate the output tensor.
Returns:
    A tensor of shape (batch_size, max_seq_len) containing the subsequence lengths
    for each batch item, padded with zeros beyond the number of subsequences.
Raises:
    ValueError: If a batch item has more subsequences than max_seq_len or if the
        sum of its subsequence lengths exceeds max_seq_len.
"""
batch_size = len(in_batch_seq_lens)
concatenated_lengths = torch.zeros((batch_size, max_seq_len), dtype=torch.int32, device=device)
for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens):
    if len(doc_seq_lens) > max_seq_len:
        raise ValueError(
            f"Number of subsequences ({len(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
            f"for batch index {batch_idx}."
        )
    if sum(doc_seq_lens) > max_seq_len:
        raise ValueError(
            f"Sum of subsequence lengths ({sum(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
            f"for batch index {batch_idx}."

with:

    max_seq_len: An upper bound applied per batch item to both
        (a) the number of subsequences and (b) the total sum of subsequence
        lengths (in tokens). It also defines the second dimension of the
        returned tensor.
    device: The torch device on which to allocate the output tensor.
Returns:
    A tensor of shape (batch_size, max_seq_len) containing the subsequence lengths
    for each batch item, padded with zeros beyond the number of subsequences.
    The second dimension (max_seq_len) therefore bounds the maximum number of
    subsequences represented per batch item.
Raises:
    ValueError: If a batch item has more subsequences than the allowed maximum
        number (max_seq_len) or if the sum of its subsequence lengths exceeds the
        allowed maximum total length (max_seq_len).
"""
batch_size = len(in_batch_seq_lens)
concatenated_lengths = torch.zeros((batch_size, max_seq_len), dtype=torch.int32, device=device)
for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens):
    if len(doc_seq_lens) > max_seq_len:
        raise ValueError(
            f"Number of subsequences ({len(doc_seq_lens)}) exceeds the maximum allowed "
            f"number of subsequences ({max_seq_len}) for batch index {batch_idx}."
        )
    if sum(doc_seq_lens) > max_seq_len:
        raise ValueError(
            f"Sum of subsequence lengths ({sum(doc_seq_lens)}) exceeds the maximum allowed "
            f"total length ({max_seq_len}) for batch index {batch_idx}."