ESM + PEFT LoRA for 3Di per-residue prediction. Train ESM-2 or ESM++ models with LoRA adapters to predict 3Di structural sequences from amino acid sequences.
No installation required! Run ESM3Di directly in Google Colab with GPU support.
The Colab notebook allows you to:
- Predict 3Di sequences from amino acid FASTAs
- Choose between ESM2 and ESM++ models
- Download results as FASTA files
- Create FoldSeek databases
- 🧬 Train ESM-2 and ESM++ models for 3Di structure prediction
- 🎯 Memory-efficient training using LoRA (Low-Rank Adaptation)
- 🔧 Support for masking low-confidence positions
- ⚡ Multi-GPU training with DataParallel
- 🔀 Multi-GPU inference with automatic sharding
- 🌐 Google Colab notebook for online inference
- Create and activate the conda environment:
# Standard environment (CUDA 11.8, most GPUs):
conda env create -f environment.yml
conda activate esm3di
# For Blackwell GPUs (RTX 5090, RTX PRO 4000, etc.):
conda env create -f environment_blackwell.yml
conda activate esm3di_blackwell
Note: For exact reproducibility, use environment_frozen.yml, which has pinned versions.
- Create a virtual environment:
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
- Install dependencies:
pip install -r requirements.txt
pip install -e .
Download random AlphaFold structures for testing or training:
python -m esm3di.testdataset \
--count 10 \
--output-dir test_structures \
--seed 42
Or download specific proteins by UniProt accession:
python -m esm3di.testdataset \
--accessions P04637 P01112 P42574 \
--output-dir structures
- --count: Number of random structures to download
- --accessions: Specific UniProt accessions to download
- --output-dir: Output directory (default: alphafold_structures)
- --delay: Delay between downloads in seconds (default: 0.5)
- --seed: Random seed for reproducible sampling
- --version: AlphaFold model version (default: 4)
Note: Downloaded structures come from the AlphaFold Protein Structure Database.
Generate training data from PDB structures with pLDDT-based masking:
python -m esm3di.build_trainingset \
--pdb-dir alphafold_structures/ \
--output-prefix training_data \
--plddt-threshold 70 \
--mask-char X
This will:
- Parse PDB files to extract sequences and pLDDT scores (from B-factor column)
- Use FoldSeek to generate 3Di sequences
- Mask low-confidence positions (pLDDT < threshold) in 3Di sequences
- Output AA and masked 3Di FASTA files ready for training
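The masking step above can be sketched in a few lines of Python. This is a minimal illustration, not the code in esm3di.build_trainingset: the function name `mask_low_confidence` and the example values are hypothetical, and it assumes one pLDDT score per residue, already aligned with the 3Di string.

```python
from typing import Sequence


def mask_low_confidence(three_di: str, plddt: Sequence[float],
                        threshold: float = 70.0, mask_char: str = "X") -> str:
    """Replace 3Di states whose per-residue pLDDT is below the threshold."""
    if len(three_di) != len(plddt):
        raise ValueError("3Di sequence and pLDDT scores must have equal length")
    return "".join(
        mask_char if score < threshold else state
        for state, score in zip(three_di, plddt)
    )


# Hypothetical 3Di string and pLDDT values for a five-residue fragment.
masked = mask_low_confidence("DVLQA", [91.2, 65.0, 70.0, 42.5, 88.1])
print(masked)  # DXLXA
```

A score exactly equal to the threshold is kept, matching the "pLDDT < threshold" rule stated above.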
- --pdb-dir: Directory containing PDB files (or use --pdb-files for specific files)
- --output-prefix: Prefix for output files
- --plddt-threshold: pLDDT threshold below which to mask (default: 70)
- --mask-char: Character for masking low-confidence positions (default: X)
- --split-chains: Split multi-chain structures into separate entries
- --chain: Extract specific chain only
Output files:
- {prefix}_aa.fasta: Amino acid sequences
- {prefix}_3di_masked.fasta: 3Di sequences with masked positions
- {prefix}_stats.txt: Statistics about masking
Train a model using FASTA files with amino acid sequences and corresponding 3Di labels:
python -m esm3di.esmretrain \
--aa-fasta data/sequences.fasta \
--three-di-fasta data/3di_labels.fasta \
--hf-model facebook/esm2_t33_650M_UR50D \
--mask-label-chars "X" \
--batch-size 4 \
--epochs 10 \
--lr 1e-4 \
--out-dir checkpoints/
ESM++ models from Synthyra (via HuggingFace) are supported and offer improved performance:
# Train with ESM++ Small (333M params)
python -m esm3di.esmretrain \
--aa-fasta data/sequences.fasta \
--three-di-fasta data/3di_labels.fasta \
--hf-model Synthyra/ESMplusplus_small \
--mask-label-chars "X" \
--batch-size 4 \
--epochs 10 \
--lr 2e-4 \
--out-dir checkpoints/
# Or use ESM++ Large for better quality
python -m esm3di.esmretrain \
--hf-model Synthyra/ESMplusplus_large \
# ... other args
Available ESM++ Models:
- Synthyra/ESMplusplus_small: 333M parameters
- Synthyra/ESMplusplus_large: 575M parameters
ESM++ models provide:
- Better protein representations than ESM-2
- Faster inference
- Improved scaling and performance
- Native HuggingFace integration (no additional dependencies)
- --aa-fasta: FASTA file with amino acid sequences
- --three-di-fasta: FASTA file with matching 3Di sequences (same order and length)
- --hf-model: Model identifier. ESM-2 options: facebook/esm2_t12_35M_UR50D (35M), facebook/esm2_t30_150M_UR50D (150M), facebook/esm2_t33_650M_UR50D (650M). ESM++ options: Synthyra/ESMplusplus_small (333M), Synthyra/ESMplusplus_large (575M)
- --mask-label-chars: Characters to treat as masked (e.g., low pLDDT positions)
- --lora-r: LoRA rank (default: 8)
- --lora-alpha: LoRA scaling factor (default: 16.0)
- --batch-size: Training batch size per GPU
- --epochs: Number of training epochs
- --lr: Learning rate
- --out-dir: Directory to save checkpoints
- --multi-gpu: Enable multi-GPU training (uses all available GPUs)
- --mixed-precision: Enable FP16 mixed precision training
- --gradient-accumulation-steps: Accumulate gradients over N batches
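To give a feel for what the --lora-r option controls, the sketch below counts the trainable parameters LoRA adds to a single weight matrix. It rests on the standard LoRA formulation (a d_out x r matrix times an r x d_in matrix added to the frozen weight). Which layers esm3di actually adapts is not stated here, so the 1280 x 1280 projection (the hidden size of facebook/esm2_t33_650M_UR50D) is only an illustrative assumption, and `lora_param_counts` is a hypothetical helper.

```python
from typing import Tuple


def lora_param_counts(d_out: int, d_in: int, rank: int) -> Tuple[int, int]:
    """Return (full, lora) trainable parameter counts for one weight matrix.

    Full fine-tuning updates all d_out * d_in entries. LoRA freezes them and
    trains two low-rank factors B (d_out x rank) and A (rank x d_in).
    """
    full = d_out * d_in
    lora = rank * (d_out + d_in)
    return full, lora


# Assumed example: one 1280 x 1280 projection with the default rank of 8.
full, lora = lora_param_counts(d_out=1280, d_in=1280, rank=8)
print(f"{lora} vs {full} ({lora / full:.2%})")  # 20480 vs 1638400 (1.25%)
```

Doubling the rank doubles the adapter size, so --lora-r trades memory and capacity roughly linearly.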
For training on multiple GPUs, use the --multi-gpu flag:
python -m esm3di.esmretrain \
--aa-fasta data/sequences.fasta \
--three-di-fasta data/3di_labels.fasta \
--hf-model Synthyra/ESMplusplus_small \
--batch-size 8 \
--multi-gpu \
--mixed-precision \
--epochs 10 \
--out-dir checkpoints/
This uses torch.nn.DataParallel to automatically distribute batches across all available GPUs. The effective batch size is batch_size * num_gpus.
Multi-GPU Training Options:
- --multi-gpu: Enable DataParallel multi-GPU training
- --mixed-precision: Enable FP16 for faster training and reduced memory
- --gradient-accumulation-steps: Simulate larger batches when GPU memory is limited
- --device: Specify primary GPU (e.g., cuda:0)
Example with gradient accumulation:
# Effective batch size = 4 * 4 * 2 GPUs = 32
python -m esm3di.esmretrain \
--batch-size 4 \
--gradient-accumulation-steps 4 \
--multi-gpu \
# ... other args
For reproducible experiments, use JSON config files:
python -m esm3di.esmretrain --config config_esmpp_large.json
Example config file:
{
"aa_fasta": "data/sequences.fasta",
"three_di_fasta": "data/3di_labels.fasta",
"hf_model": "Synthyra/ESMplusplus_small",
"mask_label_chars": "X",
"use_cnn_head": true,
"batch_size": 8,
"epochs": 3,
"lr": 0.0002,
"multi_gpu": true,
"mixed_precision": true,
"out_dir": "checkpoints_esmpp"
}
Use the trained model to predict 3Di sequences:
from esm3di import predict_3di_for_fasta
results = predict_3di_for_fasta(
model_ckpt="checkpoints/epoch_10.pt",
aa_fasta="data/test_sequences.fasta",
device="cuda" # or "cpu"
)
for header, aa_seq, pred_3di in results:
print(f">{header}")
print(f"AA: {aa_seq}")
print(f"3Di: {pred_3di}")
Generate a FoldSeek-compatible database from amino acid sequences:
python -m esm3di.fastas2foldseekdb \
--aa-fasta data/proteins.fasta \
--model-ckpt checkpoints/epoch_10.pt \
--output-db my_foldseek_db
This will:
- Run ESM inference to predict 3Di sequences
- Create intermediate AA and 3Di FASTA files
- Build a FoldSeek database with both sequence and structure information
For large datasets, enable multi-GPU inference to parallelize predictions:
python -m esm3di.fastas2foldseekdb \
--aa-fasta data/large_dataset.fasta \
--model-ckpt checkpoints/epoch_10.pt \
--output-db my_foldseek_db \
--multi-gpu \
--num-gpus 4
How Multi-GPU Inference Works:
- Input sequences are sharded across GPUs using round-robin distribution
- Each GPU runs as an isolated subprocess with its own CUDA context
- Predictions are merged back into original sequence order
- FoldSeek database is built from the merged results
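The shard-and-merge pattern described above can be illustrated without any GPU code. The sketch below is an assumption about the general technique, not the implementation in esm3di.fastas2foldseekdb: `shard_round_robin` and `merge_in_order` are hypothetical names, and per-GPU inference is replaced by a trivial stand-in.

```python
from typing import List, Sequence, Tuple


def shard_round_robin(items: Sequence[str], num_shards: int) -> List[List[Tuple[int, str]]]:
    """Assign item i to shard i % num_shards, remembering its original index."""
    shards: List[List[Tuple[int, str]]] = [[] for _ in range(num_shards)]
    for idx, item in enumerate(items):
        shards[idx % num_shards].append((idx, item))
    return shards


def merge_in_order(shard_results: List[List[Tuple[int, str]]]) -> List[str]:
    """Flatten per-shard (index, value) pairs and restore the input order."""
    merged = [pair for shard in shard_results for pair in shard]
    merged.sort(key=lambda pair: pair[0])
    return [value for _, value in merged]


seqs = ["MKT", "KAL", "GDG", "AEK", "QIS"]
shards = shard_round_robin(seqs, num_shards=2)
# Stand-in for per-GPU inference: lowercase each sequence.
results = [[(idx, seq.lower()) for idx, seq in shard] for shard in shards]
print(merge_in_order(results))  # ['mkt', 'kal', 'gdg', 'aek', 'qis']
```

Carrying the original index alongside each sequence is what lets the merge step restore order regardless of which worker finishes first.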
Multi-GPU Inference Options:
- --multi-gpu: Enable multi-GPU inference
- --num-gpus: Number of GPUs to use (default: all available)
If you already have 3Di predictions:
python -m esm3di.fastas2foldseekdb \
--aa-fasta data/proteins.fasta \
--three-di-fasta data/proteins_3di.fasta \
--output-db my_foldseek_db \
--skip-inference
- --keep-fastas: Keep intermediate FASTA files after database creation
- --output-aa-fasta: Specify path for AA FASTA output
- --output-3di-fasta: Specify path for 3Di FASTA output
- --foldseek-bin: Custom path to foldseek binary
Note: FoldSeek must be installed and available in your PATH. Download from https://github.com/steineggerlab/foldseek
Both amino acid and 3Di FASTA files should have:
- Matching number of sequences
- Sequences in the same order
- Equal length for corresponding AA and 3Di sequences
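Before training, it can help to check these three requirements programmatically. The following is a small standalone sketch, not part of the esm3di package: `read_fasta` and `check_pairing` are hypothetical helpers, and comparing headers to detect ordering problems assumes both files use identical record names.

```python
import io
from typing import Iterable, List, Tuple


def read_fasta(handle: Iterable[str]) -> List[Tuple[str, str]]:
    """Parse FASTA text into (header, sequence) pairs, joining wrapped lines."""
    records: List[Tuple[str, str]] = []
    header, chunks = None, []
    for line in handle:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                records.append((header, "".join(chunks)))
            header, chunks = line[1:], []
        else:
            chunks.append(line)
    if header is not None:
        records.append((header, "".join(chunks)))
    return records


def check_pairing(aa: List[Tuple[str, str]], di: List[Tuple[str, str]]) -> List[str]:
    """Return human-readable problems; an empty list means the files pair up."""
    problems = []
    if len(aa) != len(di):
        problems.append(f"record count differs: {len(aa)} AA vs {len(di)} 3Di")
    for (aa_id, aa_seq), (di_id, di_seq) in zip(aa, di):
        if aa_id != di_id:
            problems.append(f"order mismatch: {aa_id} vs {di_id}")
        elif len(aa_seq) != len(di_seq):
            problems.append(f"{aa_id}: length {len(aa_seq)} AA vs {len(di_seq)} 3Di")
    return problems


aa = read_fasta(io.StringIO(">p1\nMKTA\n>p2\nKAL\n"))
di = read_fasta(io.StringIO(">p1\nDVLQ\n>p2\nXA\n"))
print(check_pairing(aa, di))  # ['p2: length 3 AA vs 2 3Di']
```

To run it on real files, pass open file handles instead of the `io.StringIO` objects.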
Example sequences.fasta:
>protein1
MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEK
>protein2
KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIE
Example 3di_labels.fasta:
>protein1
acbdACBDacbdACBDacbdACBDacbdACBDacbdACBDacbdACBDacbdA
>protein2
XbdACBDacbdACBDacbdACBDacbdACBDacbdACBDacbdACBDacbdACB
Note: Characters specified in --mask-label-chars (e.g., 'X') will be ignored during training.
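One common way to ignore masked labels is to map them to an index that the loss function skips. The sketch below assumes that approach and uses -100, the default ignore_index of torch.nn.CrossEntropyLoss. Whether esm3di.esmretrain does exactly this is an assumption, and `build_vocab` and `encode_labels` are hypothetical names.

```python
from typing import Dict, Iterable, List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_vocab(label_seqs: Iterable[str], mask_chars: str) -> Dict[str, int]:
    """Map every non-masked label character to a class index."""
    chars = sorted({c for seq in label_seqs for c in seq if c not in mask_chars})
    return {c: i for i, c in enumerate(chars)}


def encode_labels(seq: str, vocab: Dict[str, int], mask_chars: str) -> List[int]:
    """Encode a label string; masked positions get IGNORE_INDEX."""
    return [IGNORE_INDEX if c in mask_chars else vocab[c] for c in seq]


labels = ["acbdA", "XbdAC"]
vocab = build_vocab(labels, mask_chars="X")
print(vocab)  # {'A': 0, 'C': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5}
print(encode_labels("XbdAC", vocab, mask_chars="X"))  # [-100, 3, 5, 0, 1]
```

Positions encoded as -100 contribute nothing to the loss or its gradient, so the model is never penalized for its prediction at a low-confidence residue.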
Checkpoints are saved after each epoch and contain:
- Model state dict (including LoRA adapters)
- Label vocabulary
- Masked label characters
- Training arguments
Pre-trained checkpoints are hosted on HuggingFace Hub for easy access:
from huggingface_hub import hf_hub_download
# Download ESM2 35M checkpoint (~131MB)
checkpoint = hf_hub_download(
repo_id="cactuskid13/esm2small_3di",
filename="epoch_3.pt"
)
# Or download ESM++ BFVD checkpoint (~1.4GB)
checkpoint = hf_hub_download(
repo_id="cactuskid13/ESMpp_small_3Di",
filename="epoch_3.pt"
)
| HuggingFace Repository | Model | Size | Description |
|---|---|---|---|
| cactuskid13/esm2small_3di | ESM2 35M | ~131MB | Fast, well-tested, recommended |
| cactuskid13/ESMpp_small_3Di | ESM++ small | ~1.4GB | Trained on BFVD, better accuracy |
Recommended for most users: cactuskid13/esm2small_3di (fast, small, reliable)
Use the test script to verify that model outputs have sufficient diversity:
python test_output_diversity.py output_3di.fasta
This will check that:
- Output contains multiple unique 3Di characters
- No single character dominates more than 50% of predictions
- Output is not effectively uniform (>90% one character)
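The checks listed above amount to simple character-frequency statistics. The sketch below shows one way to compute them for a single predicted sequence. It is not the contents of test_output_diversity.py: `diversity_report` is a hypothetical helper, and the 50% and 90% thresholds are taken from the list above.

```python
from collections import Counter
from typing import Dict, Union


def diversity_report(seq_3di: str, dominate_frac: float = 0.5,
                     uniform_frac: float = 0.9) -> Dict[str, Union[int, float, str, bool]]:
    """Summarize how evenly 3Di characters are used in one prediction."""
    if not seq_3di:
        raise ValueError("empty sequence")
    counts = Counter(seq_3di)
    top_char, top_count = counts.most_common(1)[0]
    top_fraction = top_count / len(seq_3di)
    return {
        "unique_chars": len(counts),
        "top_char": top_char,
        "top_fraction": top_fraction,
        "dominated": top_fraction > dominate_frac,
        "effectively_uniform": top_fraction > uniform_frac,
    }


# Hypothetical prediction in which 'D' makes up 6 of 10 positions.
report = diversity_report("DDDDDDVVLA")
print(report["unique_chars"], report["top_char"], report["dominated"])  # 4 D True
```

A collapsed model that emits almost the same state everywhere would set both `dominated` and `effectively_uniform` to True.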
- Python ≥ 3.8
- PyTorch ≥ 2.0.0
- transformers ≥ 4.30.0
- peft ≥ 0.5.0
- CUDA-capable GPU (recommended)
- biopython (for FoldSeek database creation)
For multi-GPU training/inference:
- Multiple CUDA-capable GPUs
- Sufficient GPU memory (ESM++ small: ~4GB per GPU, ESM++ large: ~8GB per GPU)
MIT License
If you use this code, please cite the relevant papers:
- ESM-2: [Lin et al., 2022]
- ESM++: [Synthyra, 2024] - https://huggingface.co/Synthyra
- LoRA: [Hu et al., 2021]
- 3Di: [van Kempen et al., 2023]
Funded by NIH through the Pathogen Data Network.
This resource is supported by the National Institute of Allergy and Infectious Diseases of the National Institutes of Health under Award Number U24AI183840. The content is solely the responsibility of the authors and does not necessarily represent the official views of the National Institutes of Health.