ARGformer

Learning on ancestral recombination graphs with transformers

ARGformer is a transformer encoder based on ModernBERT for Ancestral Recombination Graph (ARG) data. It uses the FlexBERT architecture with YAML-based configuration.

The codebase builds upon MosaicBERT and its Flash Attention 2 fork, both released under the Apache 2.0 license.

For ModernBERT details, see the release blog post.

Setup

Install pixi

pixi is a fast alternative to tools like conda and pip for managing environments and Python packages. If you do not have it installed, you can do so with:

curl -fsSL https://pixi.sh/install.sh | sh

Once pixi is installed, create and enter the environment with:

pixi install
pixi shell

Verify that the solve produced a CUDA-enabled PyTorch build before installing FlashAttention:

python -c "import torch; print(torch.__version__); print(torch.version.cuda); print(torch.backends.cuda.is_built())"

The base pixi environment does not include flash-attn. Install it separately inside the pixi shell if you want FlashAttention 2 support:

pip install "flash_attn==2.6.3" --no-build-isolation

For H100 GPUs, optionally install Flash Attention 3:

git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install

Overview

ARGformer supports:

  • Pretraining: Masked language modeling on ARG sequences
  • Contrastive Learning: Fine-tuning for retrieval and similarity tasks
  • Embeddings: Extracting embeddings for downstream analysis
  • Retrieval: Finding similar sequences in large corpora
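Pretraining uses the standard BERT-style masked language modeling recipe inherited from MosaicBERT/ModernBERT. As a rough illustration of the objective (not the repository's exact implementation), a fraction of non-special tokens is replaced with a mask token and the model is trained to recover the originals:

```python
import random

def mask_tokens(tokens, mask_token="[MASK]", mask_prob=0.15, seed=0):
    """BERT-style masking sketch: hide ~15% of ordinary tokens and
    record the originals as labels; None marks positions that the
    loss ignores (the usual -100 convention)."""
    rng = random.Random(seed)
    masked, labels = [], []
    for tok in tokens:
        if tok not in ("[PAD]", "[CLS]", "[SEP]") and rng.random() < mask_prob:
            masked.append(mask_token)
            labels.append(tok)   # predict the original token here
        else:
            masked.append(tok)
            labels.append(None)  # position ignored by the loss
    return masked, labels

# A toy ARG node-ID sequence; real sequences come pretokenized (see Data Format).
seq = ["[CLS]", "n17", "n4", "n231", "n4", "[SEP]"]
masked, labels = mask_tokens(seq, seed=3)
```

Production implementations also replace some masked positions with random tokens or leave them unchanged (the 80/10/10 split); that detail is omitted here for brevity.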

Data Format

ARG data structure:

/path/to/arg/data/
├── train/
│   ├── tokenized_train_sequences_and_vocab.pkl
│   └── labels.pkl  # Optional: for contrastive learning
└── val/
    ├── tokenized_val_sequences_and_vocab.pkl
    └── labels.pkl  # Optional: for contrastive learning

The ARGDataset class supports pretokenized sequences with vocabulary mappings for node IDs and special tokens ([PAD], [CLS], [SEP]).
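As a minimal sketch of consuming such a file, assuming the pickle holds a dict with `"sequences"` (lists of token IDs) and `"vocab"` (token-to-ID mapping) — the actual layout is defined by `ARGDataset` in the repository:

```python
import pickle

def load_and_pad(path, max_len=128):
    """Load a hypothetical tokenized-sequences pickle and pad/truncate
    each sequence to max_len with the [PAD] token ID."""
    with open(path, "rb") as f:
        data = pickle.load(f)
    vocab = data["vocab"]
    pad_id = vocab["[PAD]"]
    padded = [
        (seq + [pad_id] * max_len)[:max_len]  # pad short, truncate long
        for seq in data["sequences"]
    ]
    return padded, vocab
```

In practice these padded ID lists would be batched into tensors by the data loader configured in `train_loader` / `eval_loader`.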

Workflow

1. Prepare Pretraining Data

Extract sequences from tree files using src/data/prepare_data_pretrain.py:

python src/data/prepare_data_pretrain.py

Edit the script to configure input paths, output directory, and train/val split.

2. Pretraining

Configure yamls/mlm.yaml with dataset paths and model parameters, then run:

composer main.py yamls/mlm.yaml

3. Contrastive Fine-tuning

Configure yamls/contrastive.yaml with pretrained checkpoint path and run:

python sequence_contrastive.py yamls/contrastive.yaml

4. Extract Embeddings

python embeddings.py [arguments]

See the script for usage examples.

5. Retrieval

python retrieve.py [arguments]

See the script for usage examples.
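Conceptually, retrieval ranks corpus embeddings by cosine similarity to a query embedding. The sketch below shows that core operation in plain Python; it is illustrative only and not how retrieve.py is implemented:

```python
import math

def top_k(query, corpus, k=3):
    """Return indices of the k corpus vectors most cosine-similar
    to the query vector."""
    def cos(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        na = math.sqrt(sum(x * x for x in a))
        nb = math.sqrt(sum(x * x for x in b))
        return dot / (na * nb)
    ranked = sorted(enumerate(corpus), key=lambda iv: cos(query, iv[1]), reverse=True)
    return [i for i, _ in ranked[:k]]
```

For large corpora, an approximate-nearest-neighbor index (e.g. FAISS) would replace this exhaustive scan.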

Configuration

Training uses composer with YAML configuration files in yamls/:

  • mlm.yaml: Pretraining configuration
  • contrastive.yaml: Contrastive learning configuration

Key configuration sections:

  • model: Model architecture and checkpoint paths
  • train_loader / eval_loader: Dataset paths and data loading settings
  • optimizer / scheduler: Training hyperparameters
  • loggers: WandB logging configuration

Citation

If you use ARGformer in your research, please cite our paper:

@article{bonet2026argformer,
  title={{ARGformer}: learning on ancestral recombination graphs with transformers},
  author={Bonet, David and Shanks, Cole and Cara, Mar&ccedil;al Comajoan and Abante, Jordi and Ioannidis, Alexander G},
  journal={bioRxiv},
  pages={2026--02},
  year={2026},
  doi={10.64898/2026.02.11.705405},
  publisher={Cold Spring Harbor Laboratory}
}
