15 changes: 15 additions & 0 deletions tests/acceptance/conftest.py
@@ -0,0 +1,15 @@
"""Shared fixtures for acceptance tests.

Session-scoped fixtures avoid redundant model loads across test files.
All models used here must be in the CI cache (see .github/workflows/checks.yml).
"""

import pytest


@pytest.fixture(scope="session")
def gpt2_model():
"""Session-scoped HookedTransformer gpt2 with default weight processing."""
from transformer_lens import HookedTransformer

return HookedTransformer.from_pretrained("gpt2", device="cpu")
30 changes: 30 additions & 0 deletions tests/acceptance/test_generate_batch.py
@@ -0,0 +1,30 @@
"""Tests that batched HookedTransformer generation matches individual generation."""


def test_ht_generate_batch_matches_individual(gpt2_model):
"""Batched generate() should match one-by-one generate() for left-padded inputs."""
prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"]
individual_outputs = [gpt2_model.generate(p, verbose=False, do_sample=False) for p in prompts]

batched_outputs = gpt2_model.generate(prompts, verbose=False, do_sample=False)
for i, prompt in enumerate(prompts):
assert (
individual_outputs[i] == batched_outputs[i]
), f"Prompt {i} mismatch:\n individual: {individual_outputs[i]}\n batched: {batched_outputs[i]}"


def test_ht_generate_batch_without_kv_cache(gpt2_model):
"""Same test with use_past_kv_cache=False."""
prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"]
individual_outputs = [
gpt2_model.generate(p, verbose=False, do_sample=False, use_past_kv_cache=False)
for p in prompts
]

batched_outputs = gpt2_model.generate(
prompts, verbose=False, do_sample=False, use_past_kv_cache=False
)
for i, prompt in enumerate(prompts):
assert (
individual_outputs[i] == batched_outputs[i]
), f"Prompt {i} mismatch:\n individual: {individual_outputs[i]}\n batched: {batched_outputs[i]}"
63 changes: 55 additions & 8 deletions transformer_lens/HookedTransformer.py
@@ -1936,8 +1936,9 @@ def generate(
implying usage of self.cfg.default_prepend_bos (default is True unless specified
otherwise). Pass True or False to override the default.
padding_side (Union[Literal["left", "right"], None], optional): Overrides
self.tokenizer.padding_side. Specifies which side to pad when tokenizing
multiple strings of different lengths. For batched list inputs, left-padding
is forced internally for correct generation behavior.
return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'),
a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the
input was ('input').
@@ -1974,20 +1975,53 @@ def generate(
else:
return_type = "embeds"

# initial_attention_mask is always computed so that single-prompt and
# batched generation go through the same masked code path, producing
# consistent results for the same prompt regardless of batching.
initial_attention_mask: Optional[torch.Tensor] = None
_is_batched_list = isinstance(input, list) and len(input) > 1

if isinstance(input, (str, list)):
input_type = "str"
# If text, convert to tokens (batch_size=1)
assert (
self.tokenizer is not None
), "Must provide a tokenizer if passing a string to the model"
if _is_batched_list:
# Force left-padding for batched generation so real tokens
# are flush-right and logits[:, -1, :] is always correct.
input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side="left")
else:
input = self.to_tokens(
input, prepend_bos=prepend_bos, padding_side=padding_side
)
elif input.ndim == 2:
input_type = "tokens"
else:
input_type = "embeds"

input_tokens = input if input_type in ["str", "tokens"] else None
batch_size, ctx_length = input.shape[0], input.shape[1]

# Compute the initial attention mask. For batched inputs with padding, this
# correctly masks pad tokens. For single/unpadded inputs it is all-ones, which
# matches the no-mask code path while still ensuring both cases go through the
# same PosEmbed/attention logic for consistent results.
if input_tokens is not None and self.tokenizer is not None:
_prepend_bos = (
self.cfg.default_prepend_bos
if prepend_bos is USE_DEFAULT_VALUE
else (False if prepend_bos is None else prepend_bos)
)
# Temporarily set padding_side="left" so get_attention_mask
# scans for leading pads (matching the left-padded tokens).
_orig_padding_side = self.tokenizer.padding_side
if _is_batched_list:
self.tokenizer.padding_side = "left"
initial_attention_mask = utils.get_attention_mask(
self.tokenizer, input_tokens, _prepend_bos
)
if _is_batched_list:
self.tokenizer.padding_side = _orig_padding_side
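# As a rough illustration (hypothetical token ids; a batch of two prompts of
# unequal length, left-padded to a common length of 6):
#   input_tokens:           [[pad, pad, bos, 17, 42, 99],
#                            [bos, 11, 23, 35, 47, 59]]
#   initial_attention_mask: [[  0,   0,   1,  1,  1,  1],
#                            [  1,   1,   1,  1,  1,  1]]
# With this layout, logits[:, -1, :] always corresponds to the last real token
# of every prompt.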
device = get_device_for_block_index(0, self.cfg)
input = input.to(device)
if use_past_kv_cache:
@@ -2062,10 +2096,20 @@
for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
pos_offset = self.get_pos_offset(past_kv_cache, batch_size)

# Extend the initial attention mask with 1s for generated tokens.
attention_mask: Optional[torch.Tensor] = None
if initial_attention_mask is not None:
n_new = len(sampled_tokens_list)
if n_new > 0:
ones = torch.ones(
batch_size,
n_new,
dtype=initial_attention_mask.dtype,
device=device,
)
attention_mask = torch.cat([initial_attention_mask.to(device), ones], dim=1)
else:
attention_mask = initial_attention_mask.to(device)
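# Continuing the hypothetical example above: after two generated tokens the
# extended mask becomes
#   [[0, 0, 1, 1, 1, 1, 1, 1],
#    [1, 1, 1, 1, 1, 1, 1, 1]]
# i.e. the initial mask plus one appended 1 per sampled token.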
residual, shortformer_pos_embed = self.get_residual(
embeds,
pos_offset,
Expand All @@ -2089,6 +2133,7 @@ def generate(
past_kv_cache=past_kv_cache,
start_at_layer=start_at_layer,
shortformer_pos_embed=shortformer_pos_embed,
attention_mask=attention_mask,
)
else:
logits = self.forward(
@@ -2099,6 +2144,7 @@
past_kv_cache=past_kv_cache,
start_at_layer=start_at_layer,
shortformer_pos_embed=shortformer_pos_embed,
attention_mask=attention_mask,
)
else:
# We input the entire sequence, as a [batch, pos] tensor, since we aren't using
@@ -2110,6 +2156,7 @@
padding_side=padding_side,
start_at_layer=start_at_layer,
shortformer_pos_embed=shortformer_pos_embed,
attention_mask=attention_mask,
)
final_logits = logits[:, -1, :]

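As a rough standalone sketch of why batched string inputs are forced to left-padding (assuming gpt2 on CPU and greedy decoding; this mirrors the acceptance tests rather than the library's exact internals): with right-padding, the shorter prompt's last position would be a pad token, so logits[:, -1, :] for that row would predict from padding rather than from its final real word.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2", device="cpu")
prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"]

left = model.to_tokens(prompts, padding_side="left")  # shorter row is front-padded
logits = model(left)                                  # [batch, pos, d_vocab]
next_tokens = logits[:, -1, :].argmax(dim=-1)         # greedy next token for each prompt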
64 changes: 48 additions & 16 deletions transformer_lens/model_bridge/bridge.py
@@ -2121,9 +2121,9 @@ def generate(
prepend_bos: Accepted for API compatibility but not applied during generation.
The HF model expects tokens in its native format (tokenizer defaults).
Overriding BOS can silently degrade generation quality.
padding_side: Which side to pad when tokenizing multiple strings of different
lengths. For batched list inputs, left-padding is forced internally for
correct generation behavior. Defaults to None (tokenizer default).
return_type: The type of output to return - 'input', 'str', or 'tokens'
verbose: Not used in Bridge (kept for API compatibility)
output_logits: If True, return a ModelOutput with sequences and logits tuple
@@ -2135,10 +2135,9 @@ def generate(
Generated sequence as string, list of strings, or tensor depending on input type and return_type.
If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes.
"""
# prepend_bos is intentionally not applied during generation.
# The HF model expects tokens in its native format. Overriding BOS can silently
# degrade quality.
if prepend_bos is not None:
import warnings

@@ -2149,27 +2148,28 @@
"resulting tensor to generate().",
stacklevel=2,
)
# padding_side is handled internally: for batched list inputs, left-padding
# is forced to ensure correct generation. See _is_batched_list logic below.

# Stateful dispatch is decided after input parsing so we can fall back
# to hf_generate() for input types the stateful loop doesn't handle.
is_stateful_model = getattr(self.cfg, "is_stateful", False)

_is_batched_list = isinstance(input, list) and len(input) > 1

_generate_from_embeds = False
if isinstance(input, str):
input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
input_type = "str"
elif isinstance(input, list):
# Force left-padding for batched generation so real tokens are
# flush-right and logits[:, -1, :] is always the last real token.
if _is_batched_list:
_orig_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = "left"
input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
if _is_batched_list:
self.tokenizer.padding_side = _orig_padding_side
input_type = "list"
elif isinstance(input, torch.Tensor) and input.is_floating_point():
# inputs_embeds: pre-computed embeddings (e.g., from multimodal models)
@@ -2307,6 +2307,30 @@ def generate(
)
else:
forward_kwargs: Dict[str, Any] = {}
# Compute attention mask and position_ids for batched
# inputs with padding. With no mask, HF models default
# to all-ones and attend to pad tokens as if they were
# real tokens.
if (
_is_batched_list
and self.tokenizer is not None
and self.tokenizer.pad_token_id is not None
):
# Temp-swap to "left" so get_attention_mask scans
# for leading pads (matching the left-padded tokens).
_prev_side = self.tokenizer.padding_side
self.tokenizer.padding_side = "left"
attn_mask = utils.get_attention_mask(
self.tokenizer,
current_tokens,
prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
).to(self.cfg.device)
self.tokenizer.padding_side = _prev_side
forward_kwargs["attention_mask"] = attn_mask
# Adjust position_ids for left-padding so pad
# tokens don't consume real position embeddings.
position_ids = attn_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attn_mask == 0, 1)
forward_kwargs["position_ids"] = position_ids
# Pass multimodal inputs only on the first step — the vision
# encoder processes the image once, embedding it into the
# token sequence. This includes pixel_values plus any extra
@@ -2346,6 +2370,10 @@
[input_seq_pos], device=self.cfg.device
)
forward_kwargs["cache_position"] = cache_position
if "position_ids" in forward_kwargs:
forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
:, -1:
]
logits = self(
current_tokens[:, -1:],
return_type="logits",
Expand All @@ -2356,6 +2384,10 @@ def generate(
if _hf_kv_cache is not None:
# Cached step: pass only the last token + cache
forward_kwargs["past_key_values"] = _hf_kv_cache
if "position_ids" in forward_kwargs:
forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
:, -1:
]
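# Illustration (hypothetical): full position_ids [[1, 1, 0, 1, 2, 3]] is
# sliced to [[3]], matching the single new token passed with the KV cache.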
logits = self(
current_tokens[:, -1:],
return_type="logits",