feat: add init cache type opts #455
Conversation
```elixir
output_hidden_states: false,
output_attentions: false,
```
Do we need these new options? We prune these by default, and in order to actually return them in the model output, the user needs to opt in by configuring global layer options.
I think we can drop these. I must've missed them in my self-review; my focus was on the cache typing.
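For context, a minimal sketch of the opt-in path the comment above refers to, assuming the spec-level options (whether this goes through `Bumblebee.configure` or Axon's global layer options may depend on the version; the model name is illustrative):

```elixir
# Opt in to the extra outputs via the spec, not via init_cache options.
{:ok, spec} = Bumblebee.load_spec({:hf, "gpt2"})
spec = Bumblebee.configure(spec, output_hidden_states: true, output_attentions: true)
```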
```elixir
if is_nil(cross_hidden_state) do
  {Layers.Decoder.get_self_attention_cache(block_cache), %Axon.None{}}
```
Why do we need a separate function? If cross attention is not enabled, then `get_attention_caches` already returns none in the second element (or rather, a model that compiles to none).
bumblebee/lib/bumblebee/layers/decoder.ex, lines 218 to 224 in ccd6f4a
It always returns the cross-attention element. Do you mean that cross attention is always `Axon.None`?
Also, being eager means fewer `Axon.nx` calls, which reduces the overall graph in `Nx.Defn.Evaluator`.
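A hypothetical sketch of the eager shortcut being discussed, assuming a private helper wrapping the two paths (the helper name is made up; `get_self_attention_cache` and `get_attention_caches` are from the diff and comments above):

```elixir
# Hypothetical helper: when cross attention is disabled we know the answer
# eagerly, so we can return %Axon.None{} directly instead of building an
# Axon graph node that merely compiles to none.
defp fetch_caches(block_cache, cross_hidden_state) do
  if is_nil(cross_hidden_state) do
    {Layers.Decoder.get_self_attention_cache(block_cache), %Axon.None{}}
  else
    # Otherwise defer to the existing accessor, which returns both caches.
    Layers.Decoder.get_attention_caches(block_cache)
  end
end
```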
```elixir
# produced by projection layers running in compute precision, so this
# matches what the model will actually return for the cache.
cache_type = output_policy.compute || {:f, 32}
cache = init_cache(spec, batch_size, max_length, inputs, cache_type: cache_type)
```
Do we get anything from passing it downstream instead of casting as above?
It means a smaller memory footprint: we allocate bf16 directly, rather than allocating f32 and then downcasting to bf16.
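A rough back-of-the-envelope to illustrate the saving (the shape is made up; a real cache holds one such entry per decoder block for keys and for values, so the difference compounds):

```elixir
# Cache entry shape: {batch, heads, max_length, head_dim}.
shape = {8, 16, 1024, 64}
elems = Tuple.product(shape)

f32_mib = elems * 4 / 1_048_576   # 32.0 MiB peak if allocated as f32 first
bf16_mib = elems * 2 / 1_048_576  # 16.0 MiB if allocated as bf16 directly
```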
```elixir
# Use the compute precision as the cache type. The key/value tensors are
# produced by projection layers running in compute precision, so this
# matches what the model will actually return for the cache.
cache_type = output_policy.compute || {:f, 32}
```
IIRC the cache value returned from attention layers is cast using `:output` precision (since it's the layer output). That's why we cast to `:output` here.
I'm not really sure how to model this with a mixed precision policy. It may be that we don't want to cast the cache at any point, but then we don't have the granularity to specify that, since it's a specific input/output.
I went with using the policy, but what I originally wanted was to introduce a new explicit parameter for the cache type. When I just used `:output`, at least in my use case, things ended up using f32 for the cache instead of bf16 like I wanted.
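For reference, a sketch of the kind of policy where the distinction matters, using Axon's mixed precision API (the precision values are illustrative):

```elixir
policy =
  Axon.MixedPrecision.create_policy(
    params: {:f, 32},
    compute: {:bf, 16},
    output: {:f, 32}
  )

# The PR derives the cache type from :compute, so bf16 with this policy.
# Deriving it from :output would have yielded f32 instead.
cache_type = policy.compute || {:f, 32}
```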
This PR adds cache type options to init_cache, along with num_heads, for better flexibility in text generation models.