5 changes: 3 additions & 2 deletions lib/bumblebee/audio/whisper.ex
@@ -227,7 +227,7 @@ defmodule Bumblebee.Audio.Whisper do
end

@impl true
def init_cache(spec, batch_size, max_length, inputs) do
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
encoder_sequence_length =
if encoder_hidden_state = inputs["encoder_hidden_state"] do
Nx.axis_size(encoder_hidden_state, 1)
@@ -238,7 +238,8 @@
decoder_num_attention_heads: spec.decoder_num_attention_heads,
encoder_num_attention_heads: spec.encoder_num_attention_heads,
decoder_num_blocks: spec.decoder_num_blocks,
encoder_sequence_length: encoder_sequence_length
encoder_sequence_length: encoder_sequence_length,
attention_cache_type: opts[:cache_type]
)
end

43 changes: 39 additions & 4 deletions lib/bumblebee/layers/decoder.ex
@@ -42,6 +42,12 @@ defmodule Bumblebee.Layers.Decoder do

* `:decoder_num_attention_heads` - the number of decoder attention heads

* `:decoder_num_key_value_heads` - the number of decoder key-value
attention heads. Defaults to `:decoder_num_attention_heads`

* `:attention_cache_type` - the type of the key-value cache tensors.
Defaults to `{:bf, 16}`

* `:encoder_num_attention_heads` - the number of encoder attention heads
(for cross attention)

@@ -52,6 +58,12 @@
def init_cache(batch_size, max_length, opts \\ []) do
hidden_size = Keyword.fetch!(opts, :hidden_size)
decoder_num_attention_heads = Keyword.fetch!(opts, :decoder_num_attention_heads)

decoder_num_key_value_heads =
opts[:decoder_num_key_value_heads] || decoder_num_attention_heads

attention_cache_type = opts[:attention_cache_type] || {:bf, 16}

decoder_num_blocks = Keyword.fetch!(opts, :decoder_num_blocks)
encoder_num_attention_heads = opts[:encoder_num_attention_heads]
encoder_sequence_length = opts[:encoder_sequence_length]
@@ -60,7 +72,13 @@
opts[:attention_head_size] || div(hidden_size, decoder_num_attention_heads)

self_attention =
attention_cache(batch_size, max_length, decoder_num_attention_heads, decoder_head_size)
attention_cache(
batch_size,
max_length,
decoder_num_key_value_heads,
decoder_head_size,
attention_cache_type
)

cross_attention =
if encoder_sequence_length do
@@ -71,7 +89,8 @@
batch_size,
encoder_sequence_length,
encoder_num_attention_heads,
encoder_head_size
encoder_head_size,
attention_cache_type
)
else
%Axon.None{}
@@ -89,9 +108,9 @@
%{blocks: blocks, offset: offset, attention_mask: attention_mask}
end

defp attention_cache(batch_size, sequence_length, num_heads, head_size) do
defp attention_cache(batch_size, sequence_length, num_heads, head_size, type) do
shape = {batch_size, sequence_length, num_heads, head_size}
zeros = Nx.broadcast(0.0, shape)
zeros = Nx.broadcast(Nx.tensor(0, type: type), shape)
%{key: zeros, value: zeros}
end

@@ -204,10 +223,26 @@ defmodule Bumblebee.Layers.Decoder do
{Axon.nx(block_cache, & &1.self_attention), Axon.nx(block_cache, & &1.cross_attention)}
end

@doc """
Retrieves self-attention cache from a block cache.
"""
def get_self_attention_cache(block_cache) do
Axon.nx(block_cache, & &1.self_attention)
end

@doc """
Puts updated self-attention and cross-attention cache entries in the
decoder block cache.
"""
def put_attention_caches(block_cache, self_attention_cache, %Axon.None{}) do
Axon.layer(
fn block_cache, self_attention_cache, _opts ->
%{block_cache | self_attention: self_attention_cache}
end,
[block_cache, self_attention_cache]
)
end

def put_attention_caches(block_cache, self_attention_cache, cross_attention_cache) do
Axon.layer(
fn block_cache, self_attention_cache, cross_attention_cache, _opts ->
39 changes: 34 additions & 5 deletions lib/bumblebee/layers/transformer.ex
@@ -32,6 +32,12 @@ defmodule Bumblebee.Layers.Transformer do
This enables per-layer attention patterns like Gemma 3's alternating
local/global attention (5 local layers followed by 1 global layer)

* `:output_hidden_states` - whether to accumulate hidden states from all blocks.
Defaults to `true`

* `:output_attentions` - whether to accumulate attention weights from all blocks.
Defaults to `true`

* `:name` - the prefix for layer names

For all other options (including required options) see `block/2`.
@@ -82,7 +88,9 @@
cross_hidden_state: nil,
cross_attention_mask: Layers.none(),
cross_attention_head_mask: Layers.none(),
cache: Layers.none()
cache: Layers.none(),
output_hidden_states: true,
output_attentions: true
]
)

@@ -97,6 +105,8 @@
cache = opts[:cache]
rotary_embedding = opts[:rotary_embedding]
attention_window_size = opts[:attention_window_size]
output_hidden_states = opts[:output_hidden_states]
output_attentions = opts[:output_attentions]

block_opts = Keyword.take(opts, block_opts_keys)

@@ -164,9 +174,24 @@

%{
hidden_state: hidden_state,
hidden_states: Layers.append(state.hidden_states, hidden_state),
attentions: Layers.append(state.attentions, attention),
cross_attentions: Layers.append(state.cross_attentions, cross_attention),
hidden_states:
if output_hidden_states do
Layers.append(state.hidden_states, hidden_state)
else
state.hidden_states
end,
attentions:
if output_attentions do
Layers.append(state.attentions, attention)
else
state.attentions
end,
cross_attentions:
if output_attentions do
Layers.append(state.cross_attentions, cross_attention)
else
state.cross_attentions
end,
attention_relative_bias: attention_relative_bias,
cache: cache
}
@@ -416,7 +441,11 @@
end

{self_attention_cache, cross_attention_cache} =
Layers.Decoder.get_attention_caches(block_cache)
if is_nil(cross_hidden_state) do
{Layers.Decoder.get_self_attention_cache(block_cache), %Axon.None{}}
Comment on lines +444 to +445
Member

Why do we need a separate function? If cross attention is not enabled then get_attention_caches already returns none in the second element (or rather a model that compiles to none).

Contributor Author

@doc """
Retrieves self-attention and cross-attention caches from a block
cache.
"""
def get_attention_caches(block_cache) do
{Axon.nx(block_cache, & &1.self_attention), Axon.nx(block_cache, & &1.cross_attention)}
end

It always returns cross-attention. Do you mean that cross attention is always Axon.None?

Also, being eager means fewer Axon.nx calls, which reduces the overall graph in Nx.Defn.Evaluator
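
A minimal sketch of the distinction, assuming block_cache is an Axon node holding the per-block cache map (as in the diff above):

# Lazy: an Axon.nx/2 node that reads the entry at run time, adding one more
# node for Nx.Defn.Evaluator to walk even when it always yields Axon.None.
lazy_cross_cache = Axon.nx(block_cache, & &1.cross_attention)

# Eager: when the block has no cross-attention at all, the absence is already
# known at graph-construction time, so a literal struct avoids the extra node.
eager_cross_cache = %Axon.None{}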

else
Layers.Decoder.get_attention_caches(block_cache)
end

# Self-attention, shortcut connection, normalization and dropout

5 changes: 3 additions & 2 deletions lib/bumblebee/multimodal/blip.ex
@@ -178,7 +178,8 @@ defmodule Bumblebee.Multimodal.Blip do
%{vision_spec: vision_spec, text_spec: text_spec},
batch_size,
max_length,
inputs
inputs,
opts \\ []
) do
num_patches = div(vision_spec.image_size, vision_spec.patch_size) ** 2
encoder_sequence_length = num_patches + 1
@@ -193,7 +194,7 @@
}
|> Map.reject(&match?({_, nil}, &1))

text_spec.__struct__.init_cache(text_spec, batch_size, max_length, inputs)
text_spec.__struct__.init_cache(text_spec, batch_size, max_length, inputs, opts)
end

@impl true
5 changes: 3 additions & 2 deletions lib/bumblebee/text/bart.ex
@@ -417,7 +417,7 @@ defmodule Bumblebee.Text.Bart do
end

@impl true
def init_cache(spec, batch_size, max_length, inputs) do
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
encoder_sequence_length =
if encoder_hidden_state = inputs["encoder_hidden_state"] do
Nx.axis_size(encoder_hidden_state, 1)
@@ -428,7 +428,8 @@
decoder_num_attention_heads: spec.decoder_num_attention_heads,
encoder_num_attention_heads: spec.encoder_num_attention_heads,
decoder_num_blocks: spec.decoder_num_blocks,
encoder_sequence_length: encoder_sequence_length
encoder_sequence_length: encoder_sequence_length,
attention_cache_type: opts[:cache_type]
)
end

5 changes: 3 additions & 2 deletions lib/bumblebee/text/bert.ex
@@ -374,7 +374,7 @@ defmodule Bumblebee.Text.Bert do
end

@impl true
def init_cache(spec, batch_size, max_length, inputs) do
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
encoder_sequence_length =
if encoder_hidden_state = inputs["encoder_hidden_state"] do
Nx.axis_size(encoder_hidden_state, 1)
@@ -385,7 +385,8 @@
decoder_num_attention_heads: spec.num_attention_heads,
encoder_num_attention_heads: spec.num_attention_heads,
decoder_num_blocks: spec.num_blocks,
encoder_sequence_length: encoder_sequence_length
encoder_sequence_length: encoder_sequence_length,
attention_cache_type: opts[:cache_type]
)
end

5 changes: 3 additions & 2 deletions lib/bumblebee/text/blenderbot.ex
@@ -269,7 +269,7 @@ defmodule Bumblebee.Text.Blenderbot do
end

@impl true
def init_cache(spec, batch_size, max_length, inputs) do
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
encoder_sequence_length =
if encoder_hidden_state = inputs["encoder_hidden_state"] do
Nx.axis_size(encoder_hidden_state, 1)
@@ -280,7 +280,8 @@
decoder_num_attention_heads: spec.decoder_num_attention_heads,
encoder_num_attention_heads: spec.encoder_num_attention_heads,
decoder_num_blocks: spec.decoder_num_blocks,
encoder_sequence_length: encoder_sequence_length
encoder_sequence_length: encoder_sequence_length,
attention_cache_type: opts[:cache_type]
)
end

5 changes: 3 additions & 2 deletions lib/bumblebee/text/blip_text.ex
@@ -182,7 +182,7 @@ defmodule Bumblebee.Text.BlipText do
end

@impl true
def init_cache(spec, batch_size, max_length, inputs) do
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
encoder_sequence_length =
if encoder_hidden_state = inputs["encoder_hidden_state"] do
Nx.axis_size(encoder_hidden_state, 1)
@@ -193,7 +193,8 @@
decoder_num_attention_heads: spec.num_attention_heads,
encoder_num_attention_heads: spec.num_attention_heads,
decoder_num_blocks: spec.num_blocks,
encoder_sequence_length: encoder_sequence_length
encoder_sequence_length: encoder_sequence_length,
attention_cache_type: opts[:cache_type]
)
end

6 changes: 4 additions & 2 deletions lib/bumblebee/text/gemma.ex
@@ -173,12 +173,14 @@ defmodule Bumblebee.Text.Gemma do
end

@impl true
def init_cache(spec, batch_size, max_length, _inputs) do
def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
Layers.Decoder.init_cache(batch_size, max_length,
hidden_size: spec.hidden_size,
attention_head_size: spec.attention_head_size,
decoder_num_attention_heads: spec.num_attention_heads,
decoder_num_blocks: spec.num_blocks
decoder_num_key_value_heads: spec.num_key_value_heads,
decoder_num_blocks: spec.num_blocks,
attention_cache_type: opts[:cache_type]
)
end

6 changes: 4 additions & 2 deletions lib/bumblebee/text/gemma3_text.ex
@@ -209,12 +209,14 @@ defmodule Bumblebee.Text.Gemma3Text do
end

@impl true
def init_cache(spec, batch_size, max_length, _inputs) do
def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
Layers.Decoder.init_cache(batch_size, max_length,
hidden_size: spec.hidden_size,
attention_head_size: spec.attention_head_size,
decoder_num_attention_heads: spec.num_attention_heads,
decoder_num_blocks: spec.num_blocks
decoder_num_key_value_heads: spec.num_key_value_heads,
decoder_num_blocks: spec.num_blocks,
attention_cache_type: opts[:cache_type]
)
end

24 changes: 11 additions & 13 deletions lib/bumblebee/text/generation.ex
@@ -12,7 +12,8 @@ defmodule Bumblebee.Text.Generation do
spec :: Bumblebee.ModelSpec.t(),
batch_size :: pos_integer(),
max_length :: pos_integer(),
inputs :: map()
inputs :: map(),
opts :: keyword()
) :: cache()

@doc """
@@ -42,9 +43,10 @@
@doc """
Initializes an opaque cache input for iterative inference.
"""
@spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map()) :: cache()
def init_cache(%module{} = spec, batch_size, max_length, inputs) do
module.init_cache(spec, batch_size, max_length, inputs)
@spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map(), keyword()) ::
cache()
def init_cache(%module{} = spec, batch_size, max_length, inputs, opts \\ []) do
module.init_cache(spec, batch_size, max_length, inputs, opts)
end

@doc """
@@ -313,17 +315,13 @@
|> Map.put(prefix <> "position_ids", position_ids)

batch_size = Nx.axis_size(input_ids, 0)
cache = init_cache(spec, batch_size, max_length, inputs)

output_policy = model_output_policy(model)

# Cast all float cache tensors to match the model output. This way
# we make sure the cache we pass as input has the same types as
# the updated cache returned from the model
cache =
Bumblebee.Utils.Nx.map(cache, fn tensor ->
Axon.MixedPrecision.cast(output_policy, tensor, :output)
end)
# Use the compute precision as the cache type. The key/value tensors are
# produced by projection layers running in compute precision, so this
# matches what the model will actually return for the cache.
cache_type = output_policy.compute || {:f, 32}
Member

IIRC the cache value returned from attention layers is cast using :output precision (since it's the layer output). That's why we cast as output here.

I'm not really sure how to model this with mixed precision policy. It may be that we don't want to cast cache at any point, but then we don't have granularity to specify that, since it's a specific input/output.

Contributor Author

I went with using the policy, but originally what I wanted was to introduce a new explicit parameter for the cache type. When I just used the output precision, at least in my use case, things ended up using f32 for the cache instead of the bf16 I wanted.

cache = init_cache(spec, batch_size, max_length, inputs, cache_type: cache_type)
Member

Do we get anything from passing it downstream instead of casting as above?

Contributor Author

@polvalente polvalente May 5, 2026

It means we can fit in a smaller memory footprint: we allocate bf16 up front instead of allocating f32 and then downcasting to bf16.
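
For a rough sense of the numbers, a back-of-the-envelope sketch (the cache shape below is invented for illustration and not taken from this changeset):

# One key (or value) cache tensor of shape {batch, max_length, num_heads, head_size}.
shape = {1, 1024, 8, 64}

f32_cache = Nx.broadcast(Nx.tensor(0, type: {:f, 32}), shape)
bf16_cache = Nx.broadcast(Nx.tensor(0, type: {:bf, 16}), shape)

# Allocating bf16 up front halves the footprint and never materializes an f32 copy.
Nx.byte_size(f32_cache)  #=> 2_097_152
Nx.byte_size(bf16_cache) #=> 1_048_576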


Map.put(inputs, "cache", cache)
end
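
A hypothetical call site for the new option, assuming the five-argument init_cache lands as written above (the checkpoint name and sizes are placeholders):

{:ok, spec} = Bumblebee.load_spec({:hf, "gpt2"})

batch_size = 1
max_length = 128

# Request bf16 key/value cache tensors when building the opaque cache input.
cache =
  Bumblebee.Text.Generation.init_cache(spec, batch_size, max_length, %{},
    cache_type: {:bf, 16}
  )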
5 changes: 3 additions & 2 deletions lib/bumblebee/text/gpt2.ex
@@ -278,7 +278,7 @@ defmodule Bumblebee.Text.Gpt2 do
end

@impl true
def init_cache(spec, batch_size, max_length, inputs) do
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
encoder_sequence_length =
if encoder_hidden_state = inputs["encoder_hidden_state"] do
Nx.axis_size(encoder_hidden_state, 1)
@@ -289,7 +289,8 @@
decoder_num_attention_heads: spec.num_attention_heads,
encoder_num_attention_heads: spec.num_attention_heads,
decoder_num_blocks: spec.num_blocks,
encoder_sequence_length: encoder_sequence_length
encoder_sequence_length: encoder_sequence_length,
attention_cache_type: opts[:cache_type]
)
end

5 changes: 3 additions & 2 deletions lib/bumblebee/text/gpt_big_code.ex
@@ -282,7 +282,7 @@ defmodule Bumblebee.Text.GptBigCode do
end

@impl true
def init_cache(spec, batch_size, max_length, inputs) do
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
encoder_sequence_length =
if encoder_hidden_state = inputs["encoder_hidden_state"] do
Nx.axis_size(encoder_hidden_state, 1)
@@ -293,7 +293,8 @@
decoder_num_attention_heads: spec.num_attention_heads,
encoder_num_attention_heads: spec.num_attention_heads,
decoder_num_blocks: spec.num_blocks,
encoder_sequence_length: encoder_sequence_length
encoder_sequence_length: encoder_sequence_length,
attention_cache_type: opts[:cache_type]
)
end
