
LayerStack.from_sublayer_builder temporarily uses about twice the memory #122

@JEM-Mosig

When I use the config

import jax

# Import paths below assume the penzai 0.2 module layout.
from penzai.models.transformer.variants import llamalike_common
from penzai.models.transformer.variants.llamalike_common import (
    AttentionTypeGlobalCausal,
    LlamalikeTransformerConfig,
)

config = LlamalikeTransformerConfig(
    num_kv_heads=32,
    query_head_multiplier=1,
    embedding_dim=4096,
    projection_dim=128,
    mlp_hidden_dim=11008,
    num_decoder_blocks=32,
    vocab_size=32007,
    mlp_variant='swiglu',
    tie_embedder_and_logits=False,
    rope_wavelength=10000.0,
    rms_norm_eps=1e-05,
    attention_type=AttentionTypeGlobalCausal(),
    use_post_attn_norm=False,
    use_post_ffw_norm=False,
    final_logit_softcap=None,
    attn_logits_soft_cap=None,
    query_scaling_factor='default',
    parameter_dtype=jax.numpy.bfloat16,
    activation_dtype=jax.numpy.float32,
    use_layer_stack=True,
)

to run

llamalike_common.build_llamalike_transformer(config, jax.random.key(123))

I get XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2885681152 bytes, even though the fully built model should only take about 13 GB (roughly 6.7B parameters in bfloat16). When I run this on CPU instead, the build succeeds, but I see the same transient spike in memory consumption before it settles down to the expected value.
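
For reference, here is one way to observe the peak on CPU. This is a sketch of my own measurement harness, not part of penzai; it assumes Linux, where resource.getrusage reports peak RSS in KiB, and uses the JAX_PLATFORMS environment variable to force the CPU backend:

import os
os.environ["JAX_PLATFORMS"] = "cpu"  # force the CPU backend before importing jax

import resource

import jax
from penzai.models.transformer.variants import llamalike_common

# config defined as above
model = llamalike_common.build_llamalike_transformer(config, jax.random.key(123))

# Peak resident set size of the process so far; Linux reports it in KiB.
peak_gb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * 1024 / 1e9
print(f"peak RSS during build: {peak_gb:.1f} GB")  # ~2x the ~13 GB of parameters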

The spike does not occur when use_layer_stack=False.
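
Possibly relevant: 2885681152 bytes is exactly 32 * 4096 * 11008 * 2 bytes, i.e. one of the swiglu MLP projections stacked across all 32 decoder blocks in bfloat16. So my guess, purely an assumption on my part and not verified against the LayerStack.from_sublayer_builder implementation, is that the per-layer parameters and their stacked copies are briefly alive at the same time, along the lines of this plain-JAX sketch:

import jax
import jax.numpy as jnp

def build_one_layer(rng):
    # Stand-in for building the largest parameter of one decoder block.
    return jax.random.normal(rng, (4096, 11008), dtype=jnp.bfloat16)

rngs = jax.random.split(jax.random.key(0), 32)
per_layer = [build_one_layer(r) for r in rngs]  # 32 arrays of ~90 MB each
stacked = jnp.stack(per_layer)  # allocates a second full copy (2885681152 bytes)
                                # while the per-layer arrays are still referenced

If that is what happens, dropping each unstacked parameter as soon as it has been copied into the stack would remove the 2x peak.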
