5746 tokens job: RESOURCE_EXHAUSTED (H100 80G, despite unified memory, jax flag and increase bucket size) #341

@GitCeliniHub

Hi @Augustin-Zidek

It was lovely and very helpful to meet you guys on Thursday :)

As you suggested, and in order to get my 5746-token job to run, I threw all the additional flags at it (with the help of Claude, as this is all way beyond my skills and knowledge), and I still run into RESOURCE_EXHAUSTED.

#!/bin/bash

#SBATCH --job-name=AF3-4GPU                 
#SBATCH --mail-type=END,FAIL                
#SBATCH --mail-user=xxx@crick.ac.uk     
#SBATCH --partition=gh100
#SBATCH --reservation=h100
#SBATCH --nodes=1
#SBATCH --ntasks=1                          
#SBATCH --gres=gpu:4
#SBATCH --cpus-per-task=100                 
#SBATCH --mem=0                         
#SBATCH --time=72:00:00                     
#SBATCH --output=sbatch%j.log   

ml purge
ml Singularity

# Create JAX cache directory
mkdir -p /xxx/jax_cache
chmod 777 /xxx/jax_cache

# Create temporary directory for model_config
TEMP_DIR=$(mktemp -d)
chmod 777 $TEMP_DIR

# Create model_config.py with documented sharding strategy
cat > $TEMP_DIR/model_config.py << 'EOL'
from typing import Sequence
from typing_extensions import TypeAlias

_Shape2DType: TypeAlias = tuple[int | None, int | None]

pair_transition_shard_spec: Sequence[_Shape2DType] = (
    (2048, None),
    (3072, 1024),
    (None, 512),
)
EOL

# Settings from documentation
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export TF_FORCE_UNIFIED_MEMORY=true
export XLA_CLIENT_MEM_FRACTION=3.2
export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false"

singularity exec \
    --nv \
    --bind /xxx/:/root/af_input \
    --bind /xxx/:/root/af_output \
    --bind /flask/reference/Alphafold3_dataset/model_parameters:/root/models \
    --bind /flask/reference/Alphafold3_dataset/datasets:/root/public_databases \
    --bind /xxx/jax_cache:/root/jax_cache \
    --bind $TEMP_DIR/model_config.py:/app/alphafold/alphafold3/model/model_config.py \
    /flask/apps/containers/Alphafold/3.0.0/alphafold3.sif \
    python /app/alphafold/run_alphafold.py \
    --json_path=/root/af_input/fold_input.json \
    --model_dir=/root/models \
    --db_dir=/root/public_databases \
    --output_dir=/root/af_output \
    --buckets 5746 \
    --jax_compilation_cache_dir=/root/jax_cache

# Cleanup temporary directory
rm -rf $TEMP_DIR
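
One thing that may be worth ruling out is whether the exported settings actually reach the Python process inside the container. The following is a minimal sketch using only the standard library; the helper name `report_memory_env` is hypothetical, and the variable names are simply copied from the script above. It only reports what the process sees and does not check whether JAX/XLA honours any of them.

```python
import os

# Variable names copied from the job script above. This list is an
# assumption about what is relevant; nothing here verifies that
# JAX/XLA actually reads or honours these variables.
MEMORY_ENV_VARS = (
    "XLA_PYTHON_CLIENT_PREALLOCATE",
    "TF_FORCE_UNIFIED_MEMORY",
    "XLA_CLIENT_MEM_FRACTION",
    "XLA_FLAGS",
)


def report_memory_env(env=None):
    """Return {name: value or None} for each memory-related variable."""
    env = os.environ if env is None else env
    return {name: env.get(name) for name in MEMORY_ENV_VARS}


if __name__ == "__main__":
    for name, value in report_memory_env().items():
        shown = value if value is not None else "<unset>"
        print(f"{name}={shown}")
```

Saved as a file in a bound directory and run with the same `singularity exec --nv ... python` prefix as the job, it prints one line per variable, with `<unset>` for anything that did not make it into the container.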

It runs for a few hours and then fails with:

Running model inference for seed 1...
Traceback (most recent call last):
  File "/app/alphafold/run_alphafold.py", line 699, in <module>
    app.run(main)
  File "/alphafold3_venv/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/alphafold3_venv/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 684, in main
    process_fold_input(
  File "/app/alphafold/run_alphafold.py", line 556, in process_fold_input
    all_inference_results = predict_structure(
                            ^^^^^^^^^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 373, in predict_structure
    result = model_runner.run_inference(example, rng_key)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 311, in run_inference
    result = self._model(rng_key, featurised_example)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 84554026840 bytes.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
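
To put the failed allocation in context, the byte count from the error can be converted to GB/GiB and compared with the card's capacity. This is a rough sketch under two assumptions that are not confirmed anywhere in the log: that one H100 exposes a nominal 80 GB (decimal), and that the memory fraction of 3.2 from the script acts as a multiplier on one device's memory when unified memory is active.

```python
# Byte count taken verbatim from the RESOURCE_EXHAUSTED message.
REQUESTED_BYTES = 84554026840

# Assumption: nominal capacity of one H100 80G card, in decimal GB.
GPU_MEMORY_GB = 80

# Value exported as XLA_CLIENT_MEM_FRACTION in the job script.
# Assumption: it is applied as a multiple of one device's memory.
MEM_FRACTION = 3.2


def bytes_to_gb(n_bytes):
    """Convert bytes to decimal gigabytes (10**9 bytes)."""
    return n_bytes / 1e9


def bytes_to_gib(n_bytes):
    """Convert bytes to binary gibibytes (2**30 bytes)."""
    return n_bytes / 2**30


requested_gb = bytes_to_gb(REQUESTED_BYTES)
unified_ceiling_gb = MEM_FRACTION * GPU_MEMORY_GB

exceeds_single_gpu = requested_gb > GPU_MEMORY_GB
within_unified_ceiling = requested_gb < unified_ceiling_gb

print(f"requested: {requested_gb:.2f} GB "
      f"({bytes_to_gib(REQUESTED_BYTES):.2f} GiB)")
print(f"larger than one GPU ({GPU_MEMORY_GB} GB): {exceeds_single_gpu}")
print(f"below {MEM_FRACTION} x {GPU_MEMORY_GB} GB: {within_unified_ceiling}")
```

The single allocation works out to about 84.55 GB (78.75 GiB), so on its own it is larger than the nominal memory of one card, yet well below 3.2 times that amount. If the assumptions hold, that would be consistent with the unified memory settings not taking effect for this allocation, though the log alone does not establish this.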

Any insights?
Cheers!
Celine
