When I use the config

```python
config = LlamalikeTransformerConfig(
    num_kv_heads=32,
    query_head_multiplier=1,
    embedding_dim=4096,
    projection_dim=128,
    mlp_hidden_dim=11008,
    num_decoder_blocks=32,
    vocab_size=32007,
    mlp_variant='swiglu',
    tie_embedder_and_logits=False,
    rope_wavelength=10000.0,
    rms_norm_eps=1e-05,
    attention_type=AttentionTypeGlobalCausal(),
    use_post_attn_norm=False,
    use_post_ffw_norm=False,
    final_logit_softcap=None,
    attn_logits_soft_cap=None,
    query_scaling_factor='default',
    parameter_dtype=jax.numpy.bfloat16,
    activation_dtype=jax.numpy.float32,
    use_layer_stack=True,
)
```
to run

```python
llamalike_common.build_llamalike_transformer(config, jax.random.key(123))
```

I get `XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2885681152 bytes`, even though the loaded model should only take up about 13 GB. When I run this on CPU, I also see a spike in memory consumption before it settles down to the expected value.

This does not happen when `use_layer_stack=False`.
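For what it's worth, the failing allocation size lines up exactly with a single MLP projection weight stacked across all decoder blocks in this config, which makes me suspect the spike comes from materializing per-layer weights into one stacked array during initialization. This is just my reading of the numbers, not confirmed against the penzai internals:

```python
# Arithmetic check: the OOM allocation matches one stacked MLP projection.
# Assumption: parameters are bfloat16, i.e. 2 bytes each (per parameter_dtype).
num_decoder_blocks = 32
mlp_hidden_dim = 11008
embedding_dim = 4096
bf16_bytes = 2

# One (num_decoder_blocks, mlp_hidden_dim, embedding_dim) tensor:
nbytes = num_decoder_blocks * mlp_hidden_dim * embedding_dim * bf16_bytes
print(nbytes)  # 2885681152 -- exactly the size in the error message
```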