feat: add init cache type opts #455
Changes from all commits, in `Bumblebee.Text.Generation`:

```diff
@@ -12,7 +12,8 @@ defmodule Bumblebee.Text.Generation do
               spec :: Bumblebee.ModelSpec.t(),
               batch_size :: pos_integer(),
               max_length :: pos_integer(),
-              inputs :: map()
+              inputs :: map(),
+              opts :: keyword()
             ) :: cache()

   @doc """
```
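Since the callback now takes a trailing keyword list, implementing modules gain an extra argument. A minimal sketch of what an implementation could look like (the module name, cache shape, and default type are hypothetical, and only the callback relevant to this diff is shown):

```elixir
defmodule MyModel do
  @behaviour Bumblebee.Text.Generation

  # Hypothetical implementation of the extended callback. The :cache_type
  # option is the one introduced in this PR; {:f, 32} as the default here
  # is an assumption.
  @impl true
  def init_cache(_spec, batch_size, max_length, _inputs, opts) do
    cache_type = Keyword.get(opts, :cache_type, {:f, 32})

    # Illustrative key/value buffers allocated directly in the requested type.
    key = Nx.broadcast(Nx.tensor(0.0, type: cache_type), {batch_size, max_length, 64})
    %{key: key, value: key}
  end
end
```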
```diff
@@ -42,9 +43,10 @@ defmodule Bumblebee.Text.Generation do
   @doc """
   Initializes an opaque cache input for iterative inference.
   """
-  @spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map()) :: cache()
-  def init_cache(%module{} = spec, batch_size, max_length, inputs) do
-    module.init_cache(spec, batch_size, max_length, inputs)
+  @spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map(), keyword()) ::
+          cache()
+  def init_cache(%module{} = spec, batch_size, max_length, inputs, opts \\ []) do
+    module.init_cache(spec, batch_size, max_length, inputs, opts)
   end

   @doc """
```
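On the caller side, the new trailing options make it possible to request a cache type explicitly. A hedged usage sketch, assuming `spec` is a loaded `Bumblebee.ModelSpec` and `inputs` is a prepared input map (both placeholders), using the `:cache_type` option from the hunk below:

```elixir
# spec and inputs are placeholders for a loaded model spec and input map.
# Request a bf16 cache up front instead of allocating f32 and casting later.
cache = Bumblebee.Text.Generation.init_cache(spec, 1, 512, inputs, cache_type: {:bf, 16})
```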
```diff
@@ -313,17 +315,13 @@ defmodule Bumblebee.Text.Generation do
       |> Map.put(prefix <> "position_ids", position_ids)

     batch_size = Nx.axis_size(input_ids, 0)
-    cache = init_cache(spec, batch_size, max_length, inputs)

     output_policy = model_output_policy(model)

-    # Cast all float cache tensors to match the model output. This way
-    # we make sure the cache we pass as input has the same types as
-    # the updated cache returned from the model
-    cache =
-      Bumblebee.Utils.Nx.map(cache, fn tensor ->
-        Axon.MixedPrecision.cast(output_policy, tensor, :output)
-      end)
+    # Use the compute precision as the cache type. The key/value tensors are
+    # produced by projection layers running in compute precision, so this
+    # matches what the model will actually return for the cache.
+    cache_type = output_policy.compute || {:f, 32}
+
+    cache = init_cache(spec, batch_size, max_length, inputs, cache_type: cache_type)

     Map.put(inputs, "cache", cache)
   end
```

**Member** (on `cache_type = output_policy.compute || {:f, 32}`):

IIRC the cache value returned from attention layers is cast using … I'm not really sure how to model this with mixed precision policy. It may be that we don't want to cast the cache at any point, but then we don't have the granularity to specify that, since it's a specific input/output.

**Contributor (Author):**

I went with using the policy, but originally what I wanted was to introduce a new explicit parameter for the cache type. When I just used the output policy, at least in my use case, things ended up with f32 for the cache instead of bf16 like I wanted.
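For context on where `output_policy.compute` comes from, here is a small sketch using the standard `Axon.MixedPrecision` policy API (the concrete types are illustrative; the `|| {:f, 32}` fallback mirrors the diff's handling of a policy with no compute type set):

```elixir
policy =
  Axon.MixedPrecision.create_policy(
    params: {:f, 32},
    compute: {:bf, 16},
    output: {:f, 32}
  )

cache_type = policy.compute || {:f, 32}
# => {:bf, 16} -- the cache is allocated in the compute type, whereas casting
# with the :output variable type (the replaced code) would yield {:f, 32},
# matching the author's observation above.
```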
**Member** (on `cache = init_cache(spec, batch_size, max_length, inputs, cache_type: cache_type)`):

Do we get anything from passing it downstream instead of casting as above?

**Contributor (Author):**

It means we can fit in a smaller memory footprint: we allocate bf16 directly, instead of allocating f32 and then downcasting to bf16.
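A back-of-the-envelope sketch of that footprint argument, with hypothetical cache dimensions (the layer count, heads, and head size below are illustrative, not taken from this PR):

```elixir
# Hypothetical decoder cache: 32 layers, each storing key and value tensors
# of shape {batch, heads, max_length, head_size}.
batch = 1
{layers, heads, max_length, head_size} = {32, 32, 1024, 128}
elements = layers * 2 * batch * heads * max_length * head_size

bytes_f32 = elements * 4
bytes_bf16 = elements * 2
# Allocating bf16 up front peaks at half the memory (here ~512 MiB vs ~1 GiB)
# compared to allocating f32 and only then downcasting to bf16.
```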
**Member:**

Why do we need a separate function? If cross attention is not enabled then `get_attention_caches` already returns none in the second element (or rather, a model that compiles to none).

**Contributor (Author):**

See `bumblebee/lib/bumblebee/layers/decoder.ex`, lines 218 to 224 in `ccd6f4a`. It always returns cross-attention. Do you mean that cross attention is always `Axon.None`?

Also, being eager means fewer `Axon.nx` calls, which reduces the overall graph in `Nx.Defn.Evaluator`.
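To illustrate that eager-vs-graph point, a hedged sketch (the names are illustrative, not the code under discussion): building the cache eagerly with plain `Nx` produces a concrete tensor, while wrapping the same construction in `Axon.nx/2` adds a layer node that `Nx.Defn.Evaluator` must traverse at run time.

```elixir
# Placeholders for this sketch: `input` is an Axon node, and `batch_size`,
# `max_length`, and `cache_type` are known up front.

# Eager: a concrete Nx tensor, created once, adding nothing to the Axon graph.
cache = Nx.broadcast(Nx.tensor(0.0, type: cache_type), {batch_size, max_length})

# Deferred: each Axon.nx/2 call adds a node that Nx.Defn.Evaluator must walk
# whenever the model runs.
cache_node =
  Axon.nx(input, fn x ->
    Nx.broadcast(Nx.tensor(0.0, type: cache_type), {Nx.axis_size(x, 0), max_length})
  end)
```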