From 080087a152e2d48892110acc6d3c98fbd09bf3cc Mon Sep 17 00:00:00 2001
From: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Date: Mon, 4 May 2026 18:02:02 -0300
Subject: [PATCH 1/2] fix: attention cache type and num heads

---
 lib/bumblebee/layers/decoder.ex     | 43 ++++++++++++++++++++++++++---
 lib/bumblebee/layers/transformer.ex | 39 ++++++++++++++++++++++----
 lib/bumblebee/text/gemma.ex         |  1 +
 lib/bumblebee/text/gemma3_text.ex   |  1 +
 lib/bumblebee/text/llama.ex         |  1 +
 lib/bumblebee/text/mistral.ex       |  1 +
 lib/bumblebee/text/phi.ex           |  1 +
 lib/bumblebee/text/phi3.ex          |  1 +
 lib/bumblebee/text/qwen3.ex         |  3 ++
 lib/bumblebee/text/smollm3.ex       |  1 +
 10 files changed, 83 insertions(+), 9 deletions(-)

diff --git a/lib/bumblebee/layers/decoder.ex b/lib/bumblebee/layers/decoder.ex
index e08c7c0c..f5113e5a 100644
--- a/lib/bumblebee/layers/decoder.ex
+++ b/lib/bumblebee/layers/decoder.ex
@@ -42,6 +42,12 @@ defmodule Bumblebee.Layers.Decoder do
 
     * `:decoder_num_attention_heads` - the number of decoder attention heads
 
+    * `:decoder_num_key_value_heads` - the number of decoder key-value
+      attention heads. Defaults to `:decoder_num_attention_heads`
+
+    * `:attention_cache_type` - the type of the key-value cache tensors.
+      Defaults to `{:bf, 16}`
+
     * `:encoder_num_attention_heads` - the number of encoder attention
       heads (for cross attention)
 
@@ -52,6 +58,12 @@ defmodule Bumblebee.Layers.Decoder do
   def init_cache(batch_size, max_length, opts \\ []) do
     hidden_size = Keyword.fetch!(opts, :hidden_size)
     decoder_num_attention_heads = Keyword.fetch!(opts, :decoder_num_attention_heads)
+
+    decoder_num_key_value_heads =
+      opts[:decoder_num_key_value_heads] || decoder_num_attention_heads
+
+    attention_cache_type = opts[:attention_cache_type] || {:bf, 16}
+
     decoder_num_blocks = Keyword.fetch!(opts, :decoder_num_blocks)
     encoder_num_attention_heads = opts[:encoder_num_attention_heads]
     encoder_sequence_length = opts[:encoder_sequence_length]
@@ -60,7 +72,13 @@ defmodule Bumblebee.Layers.Decoder do
       opts[:attention_head_size] || div(hidden_size, decoder_num_attention_heads)
 
     self_attention =
-      attention_cache(batch_size, max_length, decoder_num_attention_heads, decoder_head_size)
+      attention_cache(
+        batch_size,
+        max_length,
+        decoder_num_key_value_heads,
+        decoder_head_size,
+        attention_cache_type
+      )
 
     cross_attention =
       if encoder_sequence_length do
@@ -71,7 +89,8 @@ defmodule Bumblebee.Layers.Decoder do
           batch_size,
           encoder_sequence_length,
           encoder_num_attention_heads,
-          encoder_head_size
+          encoder_head_size,
+          attention_cache_type
         )
       else
        %Axon.None{}
      end
@@ -89,9 +108,9 @@ defmodule Bumblebee.Layers.Decoder do
     %{blocks: blocks, offset: offset, attention_mask: attention_mask}
   end
 
-  defp attention_cache(batch_size, sequence_length, num_heads, head_size) do
+  defp attention_cache(batch_size, sequence_length, num_heads, head_size, type) do
     shape = {batch_size, sequence_length, num_heads, head_size}
-    zeros = Nx.broadcast(0.0, shape)
+    zeros = Nx.broadcast(Nx.tensor(0, type: type), shape)
     %{key: zeros, value: zeros}
   end
 
@@ -204,10 +223,26 @@ defmodule Bumblebee.Layers.Decoder do
     {Axon.nx(block_cache, & &1.self_attention), Axon.nx(block_cache, & &1.cross_attention)}
   end
 
+  @doc """
+  Retrieves self-attention cache from a block cache.
+  """
+  def get_self_attention_cache(block_cache) do
+    Axon.nx(block_cache, & &1.self_attention)
+  end
+
   @doc """
   Puts updated self-attention and cross-attention cache entries in
   the decoder block cache.
""" + def put_attention_caches(block_cache, self_attention_cache, %Axon.None{}) do + Axon.layer( + fn block_cache, self_attention_cache, _opts -> + %{block_cache | self_attention: self_attention_cache} + end, + [block_cache, self_attention_cache] + ) + end + def put_attention_caches(block_cache, self_attention_cache, cross_attention_cache) do Axon.layer( fn block_cache, self_attention_cache, cross_attention_cache, _opts -> diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 188b0ffe..d5392ea7 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -32,6 +32,12 @@ defmodule Bumblebee.Layers.Transformer do This enables per-layer attention patterns like Gemma 3's alternating local/global attention (5 local layers followed by 1 global layer) + * `:output_hidden_states` - whether to accumulate hidden states from all blocks. + Defaults to `true` + + * `:output_attentions` - whether to accumulate attention weights from all blocks. + Defaults to `true` + * `:name` - the prefix for layer names For all other options (including required options) see `block/2`. @@ -82,7 +88,9 @@ defmodule Bumblebee.Layers.Transformer do cross_hidden_state: nil, cross_attention_mask: Layers.none(), cross_attention_head_mask: Layers.none(), - cache: Layers.none() + cache: Layers.none(), + output_hidden_states: true, + output_attentions: true ] ) @@ -97,6 +105,8 @@ defmodule Bumblebee.Layers.Transformer do cache = opts[:cache] rotary_embedding = opts[:rotary_embedding] attention_window_size = opts[:attention_window_size] + output_hidden_states = opts[:output_hidden_states] + output_attentions = opts[:output_attentions] block_opts = Keyword.take(opts, block_opts_keys) @@ -164,9 +174,24 @@ defmodule Bumblebee.Layers.Transformer do %{ hidden_state: hidden_state, - hidden_states: Layers.append(state.hidden_states, hidden_state), - attentions: Layers.append(state.attentions, attention), - cross_attentions: Layers.append(state.cross_attentions, cross_attention), + hidden_states: + if output_hidden_states do + Layers.append(state.hidden_states, hidden_state) + else + state.hidden_states + end, + attentions: + if output_attentions do + Layers.append(state.attentions, attention) + else + state.attentions + end, + cross_attentions: + if output_attentions do + Layers.append(state.cross_attentions, cross_attention) + else + state.cross_attentions + end, attention_relative_bias: attention_relative_bias, cache: cache } @@ -416,7 +441,11 @@ defmodule Bumblebee.Layers.Transformer do end {self_attention_cache, cross_attention_cache} = - Layers.Decoder.get_attention_caches(block_cache) + if is_nil(cross_hidden_state) do + {Layers.Decoder.get_self_attention_cache(block_cache), %Axon.None{}} + else + Layers.Decoder.get_attention_caches(block_cache) + end # Self-attention, shortcut connection, normalization and dropout diff --git a/lib/bumblebee/text/gemma.ex b/lib/bumblebee/text/gemma.ex index cc04731a..951d575d 100644 --- a/lib/bumblebee/text/gemma.ex +++ b/lib/bumblebee/text/gemma.ex @@ -178,6 +178,7 @@ defmodule Bumblebee.Text.Gemma do hidden_size: spec.hidden_size, attention_head_size: spec.attention_head_size, decoder_num_attention_heads: spec.num_attention_heads, + decoder_num_key_value_heads: spec.num_key_value_heads, decoder_num_blocks: spec.num_blocks ) end diff --git a/lib/bumblebee/text/gemma3_text.ex b/lib/bumblebee/text/gemma3_text.ex index afe1502a..4b751e92 100644 --- a/lib/bumblebee/text/gemma3_text.ex +++ b/lib/bumblebee/text/gemma3_text.ex @@ 
       hidden_size: spec.hidden_size,
       attention_head_size: spec.attention_head_size,
       decoder_num_attention_heads: spec.num_attention_heads,
+      decoder_num_key_value_heads: spec.num_key_value_heads,
       decoder_num_blocks: spec.num_blocks
     )
   end
diff --git a/lib/bumblebee/text/llama.ex b/lib/bumblebee/text/llama.ex
index 18141523..75e5c8a3 100644
--- a/lib/bumblebee/text/llama.ex
+++ b/lib/bumblebee/text/llama.ex
@@ -182,6 +182,7 @@ defmodule Bumblebee.Text.Llama do
       hidden_size: spec.hidden_size,
       attention_head_size: spec.attention_head_size,
       decoder_num_attention_heads: spec.num_attention_heads,
+      decoder_num_key_value_heads: spec.num_key_value_heads,
       decoder_num_blocks: spec.num_blocks
     )
   end
diff --git a/lib/bumblebee/text/mistral.ex b/lib/bumblebee/text/mistral.ex
index afbe26d9..19a38e1b 100644
--- a/lib/bumblebee/text/mistral.ex
+++ b/lib/bumblebee/text/mistral.ex
@@ -165,6 +165,7 @@ defmodule Bumblebee.Text.Mistral do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       decoder_num_attention_heads: spec.num_attention_heads,
+      decoder_num_key_value_heads: spec.num_key_value_heads,
       decoder_num_blocks: spec.num_blocks
     )
   end
diff --git a/lib/bumblebee/text/phi.ex b/lib/bumblebee/text/phi.ex
index 0d7b4250..385a4c3d 100644
--- a/lib/bumblebee/text/phi.ex
+++ b/lib/bumblebee/text/phi.ex
@@ -170,6 +170,7 @@ defmodule Bumblebee.Text.Phi do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       decoder_num_attention_heads: spec.num_attention_heads,
+      decoder_num_key_value_heads: spec.num_key_value_heads,
       decoder_num_blocks: spec.num_blocks
     )
   end
diff --git a/lib/bumblebee/text/phi3.ex b/lib/bumblebee/text/phi3.ex
index 9348ad5b..24a67103 100644
--- a/lib/bumblebee/text/phi3.ex
+++ b/lib/bumblebee/text/phi3.ex
@@ -184,6 +184,7 @@ defmodule Bumblebee.Text.Phi3 do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       decoder_num_attention_heads: spec.num_attention_heads,
+      decoder_num_key_value_heads: spec.num_key_value_heads,
       decoder_num_blocks: spec.num_blocks
     )
   end
diff --git a/lib/bumblebee/text/qwen3.ex b/lib/bumblebee/text/qwen3.ex
index 568dd4e6..eff25ef3 100644
--- a/lib/bumblebee/text/qwen3.ex
+++ b/lib/bumblebee/text/qwen3.ex
@@ -184,6 +184,7 @@ defmodule Bumblebee.Text.Qwen3 do
       hidden_size: spec.hidden_size,
       attention_head_size: spec.attention_head_size,
       decoder_num_attention_heads: spec.num_attention_heads,
+      decoder_num_key_value_heads: spec.num_key_value_heads,
       decoder_num_blocks: spec.num_blocks
     )
   end
@@ -373,6 +374,8 @@ defmodule Bumblebee.Text.Qwen3 do
       ],
       query_norm: query_norm,
       key_norm: key_norm,
+      output_hidden_states: false,
+      output_attentions: false,
       name: join(name, "blocks")
     )
   end
diff --git a/lib/bumblebee/text/smollm3.ex b/lib/bumblebee/text/smollm3.ex
index b0a21ae5..96722765 100644
--- a/lib/bumblebee/text/smollm3.ex
+++ b/lib/bumblebee/text/smollm3.ex
@@ -212,6 +212,7 @@ defmodule Bumblebee.Text.SmolLM3 do
       hidden_size: spec.hidden_size,
       attention_head_size: spec.attention_head_size,
       decoder_num_attention_heads: spec.num_attention_heads,
+      decoder_num_key_value_heads: spec.num_key_value_heads,
       decoder_num_blocks: spec.num_blocks
     )
   end

From ccd6f4a7c99d340c3e9f6b5026ecf41f725d10c8 Mon Sep 17 00:00:00 2001
From: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Date: Mon, 4 May 2026 22:31:22 -0300
Subject: [PATCH 2/2] refactor: pass type options for init_cache

---
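For reference, a minimal sketch of the `Layers.Decoder.init_cache/3` call that the previous patch enables for grouped-query attention models (the option names come from this series; the concrete sizes are illustrative rather than taken from any particular checkpoint):

    cache =
      Bumblebee.Layers.Decoder.init_cache(1, 1024,
        hidden_size: 4096,
        attention_head_size: 128,
        decoder_num_attention_heads: 32,
        # GQA: key/value buffers are sized by the key-value head count
        decoder_num_key_value_heads: 8,
        decoder_num_blocks: 32,
        # defaults to {:bf, 16} when omitted
        attention_cache_type: {:f, 32}
      )

Each block's self-attention entry then holds key and value tensors of shape {1, 1024, 8, 128} in f32, instead of {1, 1024, 32, 128}. The new `:output_hidden_states`/`:output_attentions` block options follow the same pattern as the Qwen3 usage above: passing `false` skips accumulating those per-block outputs.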
 lib/bumblebee/audio/whisper.ex           |  5 +++--
 lib/bumblebee/multimodal/blip.ex         |  5 +++--
 lib/bumblebee/text/bart.ex               |  5 +++--
 lib/bumblebee/text/bert.ex               |  5 +++--
 lib/bumblebee/text/blenderbot.ex         |  5 +++--
 lib/bumblebee/text/blip_text.ex          |  5 +++--
 lib/bumblebee/text/gemma.ex              |  5 +++--
 lib/bumblebee/text/gemma3_text.ex        |  5 +++--
 lib/bumblebee/text/generation.ex         | 24 +++++++++++-------------
 lib/bumblebee/text/gpt2.ex               |  5 +++--
 lib/bumblebee/text/gpt_big_code.ex       |  5 +++--
 lib/bumblebee/text/gpt_neo_x.ex          |  5 +++--
 lib/bumblebee/text/llama.ex              |  5 +++--
 lib/bumblebee/text/m2m100.ex             |  5 +++--
 lib/bumblebee/text/mbart.ex              |  5 +++--
 lib/bumblebee/text/mistral.ex            |  5 +++--
 lib/bumblebee/text/modernbert_decoder.ex |  5 +++--
 lib/bumblebee/text/phi.ex                |  5 +++--
 lib/bumblebee/text/phi3.ex               |  5 +++--
 lib/bumblebee/text/qwen3.ex              |  5 +++--
 lib/bumblebee/text/roberta.ex            |  5 +++--
 lib/bumblebee/text/smollm3.ex            |  5 +++--
 lib/bumblebee/text/t5.ex                 |  5 +++--
 mix.exs                                  | 12 ++++++------
 mix.lock                                 | 12 ++++++------
 25 files changed, 89 insertions(+), 69 deletions(-)

diff --git a/lib/bumblebee/audio/whisper.ex b/lib/bumblebee/audio/whisper.ex
index 9b78c29d..6b7a5ba8 100644
--- a/lib/bumblebee/audio/whisper.ex
+++ b/lib/bumblebee/audio/whisper.ex
@@ -227,7 +227,7 @@ defmodule Bumblebee.Audio.Whisper do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, inputs) do
+  def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
     encoder_sequence_length =
       if encoder_hidden_state = inputs["encoder_hidden_state"] do
         Nx.axis_size(encoder_hidden_state, 1)
@@ -238,7 +238,8 @@ defmodule Bumblebee.Audio.Whisper do
       decoder_num_attention_heads: spec.decoder_num_attention_heads,
       encoder_num_attention_heads: spec.encoder_num_attention_heads,
       decoder_num_blocks: spec.decoder_num_blocks,
-      encoder_sequence_length: encoder_sequence_length
+      encoder_sequence_length: encoder_sequence_length,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/multimodal/blip.ex b/lib/bumblebee/multimodal/blip.ex
index 26a523c4..9a80b38c 100644
--- a/lib/bumblebee/multimodal/blip.ex
+++ b/lib/bumblebee/multimodal/blip.ex
@@ -178,7 +178,8 @@ defmodule Bumblebee.Multimodal.Blip do
         %{vision_spec: vision_spec, text_spec: text_spec},
         batch_size,
         max_length,
-        inputs
+        inputs,
+        opts \\ []
       ) do
     num_patches = div(vision_spec.image_size, vision_spec.patch_size) ** 2
     encoder_sequence_length = num_patches + 1
@@ -193,7 +194,7 @@ defmodule Bumblebee.Multimodal.Blip do
       }
       |> Map.reject(&match?({_, nil}, &1))
 
-    text_spec.__struct__.init_cache(text_spec, batch_size, max_length, inputs)
+    text_spec.__struct__.init_cache(text_spec, batch_size, max_length, inputs, opts)
   end
diff --git a/lib/bumblebee/text/bart.ex b/lib/bumblebee/text/bart.ex
index b0e4e272..7ad9b10c 100644
--- a/lib/bumblebee/text/bart.ex
+++ b/lib/bumblebee/text/bart.ex
@@ -417,7 +417,7 @@ defmodule Bumblebee.Text.Bart do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, inputs) do
+  def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
     encoder_sequence_length =
       if encoder_hidden_state = inputs["encoder_hidden_state"] do
         Nx.axis_size(encoder_hidden_state, 1)
@@ -428,7 +428,8 @@ defmodule Bumblebee.Text.Bart do
       decoder_num_attention_heads: spec.decoder_num_attention_heads,
       encoder_num_attention_heads: spec.encoder_num_attention_heads,
       decoder_num_blocks: spec.decoder_num_blocks,
-      encoder_sequence_length: encoder_sequence_length
+      encoder_sequence_length: encoder_sequence_length,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/bert.ex b/lib/bumblebee/text/bert.ex
index 775fb459..6520cec0 100644
--- a/lib/bumblebee/text/bert.ex
+++ b/lib/bumblebee/text/bert.ex
@@ -374,7 +374,7 @@ defmodule Bumblebee.Text.Bert do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, inputs) do
+  def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
     encoder_sequence_length =
       if encoder_hidden_state = inputs["encoder_hidden_state"] do
         Nx.axis_size(encoder_hidden_state, 1)
@@ -385,7 +385,8 @@ defmodule Bumblebee.Text.Bert do
       decoder_num_attention_heads: spec.num_attention_heads,
       encoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_blocks: spec.num_blocks,
-      encoder_sequence_length: encoder_sequence_length
+      encoder_sequence_length: encoder_sequence_length,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/blenderbot.ex b/lib/bumblebee/text/blenderbot.ex
index 18898e8c..935fd5f5 100644
--- a/lib/bumblebee/text/blenderbot.ex
+++ b/lib/bumblebee/text/blenderbot.ex
@@ -269,7 +269,7 @@ defmodule Bumblebee.Text.Blenderbot do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, inputs) do
+  def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
     encoder_sequence_length =
       if encoder_hidden_state = inputs["encoder_hidden_state"] do
         Nx.axis_size(encoder_hidden_state, 1)
@@ -280,7 +280,8 @@ defmodule Bumblebee.Text.Blenderbot do
       decoder_num_attention_heads: spec.decoder_num_attention_heads,
       encoder_num_attention_heads: spec.encoder_num_attention_heads,
       decoder_num_blocks: spec.decoder_num_blocks,
-      encoder_sequence_length: encoder_sequence_length
+      encoder_sequence_length: encoder_sequence_length,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/blip_text.ex b/lib/bumblebee/text/blip_text.ex
index e21fa191..847a98d6 100644
--- a/lib/bumblebee/text/blip_text.ex
+++ b/lib/bumblebee/text/blip_text.ex
@@ -182,7 +182,7 @@ defmodule Bumblebee.Text.BlipText do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, inputs) do
+  def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
     encoder_sequence_length =
       if encoder_hidden_state = inputs["encoder_hidden_state"] do
         Nx.axis_size(encoder_hidden_state, 1)
@@ -193,7 +193,8 @@ defmodule Bumblebee.Text.BlipText do
       decoder_num_attention_heads: spec.num_attention_heads,
       encoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_blocks: spec.num_blocks,
-      encoder_sequence_length: encoder_sequence_length
+      encoder_sequence_length: encoder_sequence_length,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/gemma.ex b/lib/bumblebee/text/gemma.ex
index 951d575d..778cf088 100644
--- a/lib/bumblebee/text/gemma.ex
+++ b/lib/bumblebee/text/gemma.ex
@@ -173,13 +173,14 @@ defmodule Bumblebee.Text.Gemma do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, _inputs) do
+  def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       attention_head_size: spec.attention_head_size,
       decoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_key_value_heads: spec.num_key_value_heads,
-      decoder_num_blocks: spec.num_blocks
+      decoder_num_blocks: spec.num_blocks,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/gemma3_text.ex b/lib/bumblebee/text/gemma3_text.ex
index 4b751e92..6cf4783f 100644
--- a/lib/bumblebee/text/gemma3_text.ex
+++ b/lib/bumblebee/text/gemma3_text.ex
@@ -209,13 +209,14 @@ defmodule Bumblebee.Text.Gemma3Text do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, _inputs) do
+  def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       attention_head_size: spec.attention_head_size,
       decoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_key_value_heads: spec.num_key_value_heads,
-      decoder_num_blocks: spec.num_blocks
+      decoder_num_blocks: spec.num_blocks,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 669c1b7e..1ea4e638 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -12,7 +12,8 @@ defmodule Bumblebee.Text.Generation do
               spec :: Bumblebee.ModelSpec.t(),
               batch_size :: pos_integer(),
               max_length :: pos_integer(),
-              inputs :: map()
+              inputs :: map(),
+              opts :: keyword()
             ) :: cache()
 
   @doc """
@@ -42,9 +43,10 @@ defmodule Bumblebee.Text.Generation do
   @doc """
   Initializes an opaque cache input for iterative inference.
   """
-  @spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map()) :: cache()
-  def init_cache(%module{} = spec, batch_size, max_length, inputs) do
-    module.init_cache(spec, batch_size, max_length, inputs)
+  @spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map(), keyword()) ::
+          cache()
+  def init_cache(%module{} = spec, batch_size, max_length, inputs, opts \\ []) do
+    module.init_cache(spec, batch_size, max_length, inputs, opts)
   end
 
   @doc """
@@ -313,17 +315,13 @@ defmodule Bumblebee.Text.Generation do
       |> Map.put(prefix <> "position_ids", position_ids)
 
     batch_size = Nx.axis_size(input_ids, 0)
-    cache = init_cache(spec, batch_size, max_length, inputs)
 
     output_policy = model_output_policy(model)
-
-    # Cast all float cache tensors to match the model output. This way
-    # we make sure the cache we pass as input has the same types as
-    # the updated cache returned from the model
-    cache =
-      Bumblebee.Utils.Nx.map(cache, fn tensor ->
-        Axon.MixedPrecision.cast(output_policy, tensor, :output)
-      end)
+    # Use the compute precision as the cache type. The key/value tensors are
+    # produced by projection layers running in compute precision, so this
+    # matches what the model will actually return for the cache.
+    cache_type = output_policy.compute || {:f, 32}
+    cache = init_cache(spec, batch_size, max_length, inputs, cache_type: cache_type)
 
     Map.put(inputs, "cache", cache)
   end
diff --git a/lib/bumblebee/text/gpt2.ex b/lib/bumblebee/text/gpt2.ex
index f23aa032..9a3d79f4 100644
--- a/lib/bumblebee/text/gpt2.ex
+++ b/lib/bumblebee/text/gpt2.ex
@@ -278,7 +278,7 @@ defmodule Bumblebee.Text.Gpt2 do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, inputs) do
+  def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
     encoder_sequence_length =
       if encoder_hidden_state = inputs["encoder_hidden_state"] do
         Nx.axis_size(encoder_hidden_state, 1)
@@ -289,7 +289,8 @@ defmodule Bumblebee.Text.Gpt2 do
       decoder_num_attention_heads: spec.num_attention_heads,
       encoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_blocks: spec.num_blocks,
-      encoder_sequence_length: encoder_sequence_length
+      encoder_sequence_length: encoder_sequence_length,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/gpt_big_code.ex b/lib/bumblebee/text/gpt_big_code.ex
index db4dd558..81b2555f 100644
--- a/lib/bumblebee/text/gpt_big_code.ex
+++ b/lib/bumblebee/text/gpt_big_code.ex
@@ -282,7 +282,7 @@ defmodule Bumblebee.Text.GptBigCode do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, inputs) do
+  def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
     encoder_sequence_length =
       if encoder_hidden_state = inputs["encoder_hidden_state"] do
         Nx.axis_size(encoder_hidden_state, 1)
@@ -293,7 +293,8 @@ defmodule Bumblebee.Text.GptBigCode do
       decoder_num_attention_heads: spec.num_attention_heads,
       encoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_blocks: spec.num_blocks,
-      encoder_sequence_length: encoder_sequence_length
+      encoder_sequence_length: encoder_sequence_length,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/gpt_neo_x.ex b/lib/bumblebee/text/gpt_neo_x.ex
index 1f99a09c..d3d956b5 100644
--- a/lib/bumblebee/text/gpt_neo_x.ex
+++ b/lib/bumblebee/text/gpt_neo_x.ex
@@ -159,11 +159,12 @@ defmodule Bumblebee.Text.GptNeoX do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, _inputs) do
+  def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       decoder_num_attention_heads: spec.num_attention_heads,
-      decoder_num_blocks: spec.num_blocks
+      decoder_num_blocks: spec.num_blocks,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/llama.ex b/lib/bumblebee/text/llama.ex
index 75e5c8a3..68d53ea9 100644
--- a/lib/bumblebee/text/llama.ex
+++ b/lib/bumblebee/text/llama.ex
@@ -177,13 +177,14 @@ defmodule Bumblebee.Text.Llama do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, _inputs) do
+  def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       attention_head_size: spec.attention_head_size,
       decoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_key_value_heads: spec.num_key_value_heads,
-      decoder_num_blocks: spec.num_blocks
+      decoder_num_blocks: spec.num_blocks,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/m2m100.ex b/lib/bumblebee/text/m2m100.ex
index 5f064741..d11e9b3e 100644
--- a/lib/bumblebee/text/m2m100.ex
+++ b/lib/bumblebee/text/m2m100.ex
@@ -268,7 +268,7 @@ defmodule Bumblebee.Text.M2m100 do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, inputs) do
+  def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
     encoder_sequence_length =
       if encoder_hidden_state = inputs["encoder_hidden_state"] do
         Nx.axis_size(encoder_hidden_state, 1)
@@ -279,7 +279,8 @@ defmodule Bumblebee.Text.M2m100 do
       decoder_num_attention_heads: spec.decoder_num_attention_heads,
       encoder_num_attention_heads: spec.encoder_num_attention_heads,
       decoder_num_blocks: spec.decoder_num_blocks,
-      encoder_sequence_length: encoder_sequence_length
+      encoder_sequence_length: encoder_sequence_length,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/mbart.ex b/lib/bumblebee/text/mbart.ex
index 3b93a0a7..a9214fb7 100644
--- a/lib/bumblebee/text/mbart.ex
+++ b/lib/bumblebee/text/mbart.ex
@@ -414,7 +414,7 @@ defmodule Bumblebee.Text.Mbart do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, inputs) do
+  def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
     encoder_sequence_length =
       if encoder_hidden_state = inputs["encoder_hidden_state"] do
         Nx.axis_size(encoder_hidden_state, 1)
@@ -425,7 +425,8 @@ defmodule Bumblebee.Text.Mbart do
       decoder_num_attention_heads: spec.decoder_num_attention_heads,
       encoder_num_attention_heads: spec.encoder_num_attention_heads,
       decoder_num_blocks: spec.decoder_num_blocks,
-      encoder_sequence_length: encoder_sequence_length
+      encoder_sequence_length: encoder_sequence_length,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/mistral.ex b/lib/bumblebee/text/mistral.ex
index 19a38e1b..615b37d1 100644
--- a/lib/bumblebee/text/mistral.ex
+++ b/lib/bumblebee/text/mistral.ex
@@ -161,12 +161,13 @@ defmodule Bumblebee.Text.Mistral do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, _inputs) do
+  def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       decoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_key_value_heads: spec.num_key_value_heads,
-      decoder_num_blocks: spec.num_blocks
+      decoder_num_blocks: spec.num_blocks,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/modernbert_decoder.ex b/lib/bumblebee/text/modernbert_decoder.ex
index b061bccb..57250f7b 100644
--- a/lib/bumblebee/text/modernbert_decoder.ex
+++ b/lib/bumblebee/text/modernbert_decoder.ex
@@ -171,12 +171,13 @@ defmodule Bumblebee.Text.ModernBertDecoder do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, _inputs) do
+  def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
     Layers.Decoder.init_cache(batch_size, max_length,
      hidden_size: spec.hidden_size,
       attention_head_size: div(spec.hidden_size, spec.num_attention_heads),
       decoder_num_attention_heads: spec.num_attention_heads,
-      decoder_num_blocks: spec.num_blocks
+      decoder_num_blocks: spec.num_blocks,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/phi.ex b/lib/bumblebee/text/phi.ex
index 385a4c3d..56bc51f1 100644
--- a/lib/bumblebee/text/phi.ex
+++ b/lib/bumblebee/text/phi.ex
@@ -166,12 +166,13 @@ defmodule Bumblebee.Text.Phi do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, _inputs) do
+  def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       decoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_key_value_heads: spec.num_key_value_heads,
-      decoder_num_blocks: spec.num_blocks
+      decoder_num_blocks: spec.num_blocks,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/phi3.ex b/lib/bumblebee/text/phi3.ex
index 24a67103..dedef929 100644
--- a/lib/bumblebee/text/phi3.ex
+++ b/lib/bumblebee/text/phi3.ex
@@ -180,12 +180,13 @@ defmodule Bumblebee.Text.Phi3 do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, _inputs) do
+  def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       decoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_key_value_heads: spec.num_key_value_heads,
-      decoder_num_blocks: spec.num_blocks
+      decoder_num_blocks: spec.num_blocks,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/qwen3.ex b/lib/bumblebee/text/qwen3.ex
index eff25ef3..028cce91 100644
--- a/lib/bumblebee/text/qwen3.ex
+++ b/lib/bumblebee/text/qwen3.ex
@@ -179,13 +179,14 @@ defmodule Bumblebee.Text.Qwen3 do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, _inputs) do
+  def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       attention_head_size: spec.attention_head_size,
       decoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_key_value_heads: spec.num_key_value_heads,
-      decoder_num_blocks: spec.num_blocks
+      decoder_num_blocks: spec.num_blocks,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/roberta.ex b/lib/bumblebee/text/roberta.ex
index 4db674f3..4f8bbfeb 100644
--- a/lib/bumblebee/text/roberta.ex
+++ b/lib/bumblebee/text/roberta.ex
@@ -330,7 +330,7 @@ defmodule Bumblebee.Text.Roberta do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, inputs) do
+  def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
     encoder_sequence_length =
       if encoder_hidden_state = inputs["encoder_hidden_state"] do
         Nx.axis_size(encoder_hidden_state, 1)
@@ -341,7 +341,8 @@ defmodule Bumblebee.Text.Roberta do
       decoder_num_attention_heads: spec.num_attention_heads,
       encoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_blocks: spec.num_blocks,
-      encoder_sequence_length: encoder_sequence_length
+      encoder_sequence_length: encoder_sequence_length,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/smollm3.ex b/lib/bumblebee/text/smollm3.ex
index 96722765..559380b5 100644
--- a/lib/bumblebee/text/smollm3.ex
+++ b/lib/bumblebee/text/smollm3.ex
@@ -207,13 +207,14 @@ defmodule Bumblebee.Text.SmolLM3 do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, _inputs) do
+  def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       attention_head_size: spec.attention_head_size,
       decoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_key_value_heads: spec.num_key_value_heads,
-      decoder_num_blocks: spec.num_blocks
+      decoder_num_blocks: spec.num_blocks,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/lib/bumblebee/text/t5.ex b/lib/bumblebee/text/t5.ex
index ba92cd3d..2cbc8044 100644
--- a/lib/bumblebee/text/t5.ex
+++ b/lib/bumblebee/text/t5.ex
@@ -198,7 +198,7 @@ defmodule Bumblebee.Text.T5 do
   end
 
   @impl true
-  def init_cache(spec, batch_size, max_length, inputs) do
+  def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
     encoder_sequence_length =
       if encoder_hidden_state = inputs["encoder_hidden_state"] do
         Nx.axis_size(encoder_hidden_state, 1)
@@ -210,7 +210,8 @@ defmodule Bumblebee.Text.T5 do
       decoder_num_attention_heads: spec.decoder_num_attention_heads,
       encoder_num_attention_heads: spec.encoder_num_attention_heads,
       decoder_num_blocks: spec.decoder_num_blocks,
-      encoder_sequence_length: encoder_sequence_length
+      encoder_sequence_length: encoder_sequence_length,
+      attention_cache_type: opts[:cache_type]
     )
   end
diff --git a/mix.exs b/mix.exs
index 1d50ec9d..58c1399b 100644
--- a/mix.exs
+++ b/mix.exs
@@ -34,12 +34,12 @@ defmodule Bumblebee.MixProject do
       {:axon, "~> 0.7.0"},
       # {:axon, github: "elixir-nx/axon", override: true},
       {:tokenizers, "~> 0.4"},
-      {:nx, "~> 0.9.0 or ~> 0.10.0"},
-      {:exla, ">= 0.0.0", only: [:dev, :test]},
-      {:torchx, ">= 0.0.0", only: [:dev, :test]},
-      # {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
-      # {:exla, github: "elixir-nx/nx", sparse: "exla", override: true, only: [:dev, :test]},
-      # {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]},
+      # {:nx, "~> 0.9.0 or ~> 0.10.0"},
+      # {:exla, ">= 0.0.0", only: [:dev, :test]},
+      # {:torchx, ">= 0.0.0", only: [:dev, :test]},
+      {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
+      {:exla, github: "elixir-nx/nx", sparse: "exla", override: true, only: [:dev, :test]},
+      {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]},
       {:nx_image, "~> 0.1.0"},
       {:unpickler, "~> 0.1.0"},
       {:safetensors, "~> 0.1.3"},
diff --git a/mix.lock b/mix.lock
index a8d7496b..5b767f31 100644
--- a/mix.lock
+++ b/mix.lock
@@ -11,8 +11,8 @@
   "earmark_parser": {:hex, :earmark_parser, "1.4.44", "f20830dd6b5c77afe2b063777ddbbff09f9759396500cdbe7523efd58d7a339c", [:mix], [], "hexpm", "4778ac752b4701a5599215f7030989c989ffdc4f6df457c5f36938cc2d2a2750"},
   "elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"},
   "ex_doc": {:hex, :ex_doc, "0.39.1", "e19d356a1ba1e8f8cfc79ce1c3f83884b6abfcb79329d435d4bbb3e97ccc286e", [:mix], [{:earmark_parser, "~> 1.4.44", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "8abf0ed3e3ca87c0847dfc4168ceab5bedfe881692f1b7c45f4a11b232806865"},
-  "exla": {:hex, :exla, "0.10.0", "93e7d75a774fbc06ce05b96de20c4b01bda413b315238cb3c727c09a05d2bc3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:fine, "~> 0.1.0", [hex: :fine, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.9.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "16fffdb64667d7f0a3bc683fdcd2792b143a9b345e4b1f1d5cd50330c63d8119"},
-  "fine": {:hex, :fine, "0.1.4", "b19a89c1476c7c57afb5f9314aed5960b5bc95d5277de4cb5ee8e1d1616ce379", [:mix], [], "hexpm", "be3324cc454a42d80951cf6023b9954e9ff27c6daa255483b3e8d608670303f5"},
+  "exla": {:git, "https://github.com/elixir-nx/nx.git", "fb89cf4f998728007d3b02f92ea118e0a13fe60b", [sparse: "exla"]},
"fb89cf4f998728007d3b02f92ea118e0a13fe60b", [sparse: "exla"]}, + "fine": {:hex, :fine, "0.1.6", "4bf7151493443c454aac9f2fa2f34f5fefd0346a83fb5586a016c4a135c63247", [:mix], [], "hexpm", "5638eb4495488e885ebec167fa57973e5c35e1a50c344eb7666c90ec1c4e3b12"}, "jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"}, "makeup": {:hex, :makeup, "1.2.1", "e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"}, "makeup_elixir": {:hex, :makeup_elixir, "1.0.1", "e928a4f984e795e41e3abd27bfc09f51db16ab8ba1aebdba2b3a575437efafc2", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "7284900d412a3e5cfd97fdaed4f5ed389b8f2b4cb49efc0eb3bd10e2febf9507"}, @@ -20,7 +20,7 @@ "mime": {:hex, :mime, "2.0.7", "b8d739037be7cd402aee1ba0306edfdef982687ee7e9859bee6198c1e7e2f128", [:mix], [], "hexpm", "6171188e399ee16023ffc5b76ce445eb6d9672e2e241d2df6050f3c771e80ccd"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.2", "8efba0122db06df95bfaa78f791344a89352ba04baedd3849593bfce4d0dc1c6", [:mix], [], "hexpm", "4b21398942dda052b403bbe1da991ccd03a053668d147d53fb8c4e0efe09c973"}, "nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"}, - "nx": {:hex, :nx, "0.10.0", "128e4a094cb790f663e20e1334b127c1f2a4df54edfb8b13c22757ec33133b4f", [:mix], [{:complex, "~> 0.6", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3db8892c124aeee091df0e6fbf8e5bf1b81f502eb0d4f5ba63e6378ebcae7da4"}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "1c3d86eb635e136eab7d5a1a12a57794c200d204", [sparse: "nx"]}, "nx_image": {:hex, :nx_image, "0.1.2", "0c6e3453c1dc30fc80c723a54861204304cebc8a89ed3b806b972c73ee5d119d", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "9161863c42405ddccb6dbbbeae078ad23e30201509cc804b3b3a7c9e98764b81"}, "nx_signal": {:hex, :nx_signal, "0.2.0", "e1ca0318877b17c81ce8906329f5125f1e2361e4c4235a5baac8a95ee88ea98e", [:mix], [{:nx, "~> 0.6", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "7247e5e18a177a59c4cb5355952900c62fdeadeb2bad02a9a34237b68744e2bb"}, "plug": {:hex, :plug, "1.18.1", "5067f26f7745b7e31bc3368bc1a2b818b9779faa959b49c934c17730efc911cf", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2 or ~> 2.0", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "57a57db70df2b422b564437d2d33cf8d33cd16339c1edb190cd11b1a3a546cc2"}, @@ -32,10 +32,10 @@ "rustler_precompiled": {:hex, :rustler_precompiled, "0.8.3", "4e741024b0b097fe783add06e53ae9a6f23ddc78df1010f215df0c02915ef5a8", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", 
"c23f5f33cb6608542de4d04faf0f0291458c352a4648e4d28d17ee1098cddcc4"}, "safetensors": {:hex, :safetensors, "0.1.3", "7ff3c22391e213289c713898481d492c9c28a49ab1d0705b72630fb8360426b2", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "fe50b53ea59fde4e723dd1a2e31cfdc6013e69343afac84c6be86d6d7c562c14"}, "stb_image": {:hex, :stb_image, "0.6.10", "76975279e2a130f53dc670bf6f6b1cdc4fbd7ab6293053e88e7fb6a7eae0e836", [:make, :mix], [{:cc_precompiler, "~> 0.1", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.8", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "26125372cfeda209084d3670417fab6819cfccd0e66c657678ecc48314369e8d"}, - "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, + "telemetry": {:hex, :telemetry, "1.4.1", "ab6de178e2b29b58e8256b92b382ea3f590a47152ca3651ea857a6cae05ac423", [:rebar3], [], "hexpm", "2172e05a27531d3d31dd9782841065c50dd5c3c7699d95266b2edd54c2dafa1c"}, "tokenizers": {:hex, :tokenizers, "0.5.1", "b0975d92b4ee5b18e8f47b5d65b9d5f1e583d9130189b1a2620401af4e7d4b35", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "5f08d97cc7f2ed3d71d370d68120da6d3de010948ccf676c9c0eb591ba4bacc9"}, - "torchx": {:hex, :torchx, "0.10.2", "4b8529bfc4b0e641232497c99ef6d2508e652198840b212373333361352f0bae", [:mix], [{:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "cad541c64df8ddcbf50d9b0f212961632361a03050c8e01493f0fc8d4fed96d9"}, + "torchx": {:git, "https://github.com/elixir-nx/nx.git", "fb89cf4f998728007d3b02f92ea118e0a13fe60b", [sparse: "torchx"]}, "unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"}, "unzip": {:hex, :unzip, "0.13.0", "bf5ec6ac6063c69e6ec54c8b4a3b8dcd7a2719d28d10d7025776ab107957cde9", [:mix], [], "hexpm", "4bcb9892ecbf2042606b43ab685a1bffe03c14003e6246f5453db2c829237fd9"}, - "xla": {:hex, :xla, "0.9.1", "cca0040ff94902764007a118871bfc667f1a0085d4a5074533a47d6b58bec61e", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "eb5e443ae5391b1953f253e051f2307bea183b59acee138053a9300779930daf"}, + "xla": {:hex, :xla, "0.10.0", "41121e9f011456242d3a79b9289910ce43419be0b0e7ebe67cc1292c6b3f232f", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "f57d91aea6e661b52bf12239316c598679e9170628122bbd941235f040122bc6"}, }