A language model for population genetics that infers pairwise coalescent times (TMRCA) from genotype data by reframing the problem as translation between observed mutation patterns and the latent ancestral recombination graph.
Korfmann K, Pope NS, Meleghy M, Tellier A, Kern AD (2026). Coalescence and Translation: A Language Model for Population Genetics. PNAS. [paper] [docs]
cxt is a decoder-only transformer inspired by GPT-2 that autoregressively predicts pairwise coalescence times conditioned on local mutational context. For each pair of haplotypes it computes a multi-scale site-frequency spectrum (SFS) in sliding windows, feeds it through the transformer, and outputs a discretized log-TMRCA profile across the genome. The generative process is repeated multiple times to sample from an approximate posterior over TMRCA trajectories, providing well-calibrated uncertainty estimates.
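The actual windowed-SFS featurization lives in `cxt/sfs.py`; as a rough, simplified illustration of the idea (not the package's exact encoding), a pairwise window summary might count, per window and at more than one window size, the sites where the derived allele is private to one haplotype versus shared by both:

```python
# Simplified illustration only -- the real multi-scale featurization is in cxt/sfs.py.
# For one haplotype pair, count per-window sites where the derived allele is carried
# by exactly one haplotype ("private") or by both ("shared"). Variable names and
# window sizes here are hypothetical.
import numpy as np

def pairwise_window_counts(gm_pair, pos, seq_len, window):
    """gm_pair: (2, n_sites) 0/1 haplotypes; pos: (n_sites,) positions in bp."""
    edges = np.arange(0, seq_len + window, window)
    win = np.digitize(pos, edges) - 1           # 0-based window index per site
    n_carriers = gm_pair.sum(axis=0)            # 0, 1, or 2 derived copies per site
    n_win = len(edges) - 1
    private = np.bincount(win[n_carriers == 1], minlength=n_win)[:n_win]
    shared = np.bincount(win[n_carriers == 2], minlength=n_win)[:n_win]
    return np.stack([private, shared])          # (2, n_windows)

# "Multi-scale" = the same summary at several window sizes, e.g. coarse and fine:
# coarse = pairwise_window_counts(gm[[0, 1]], pos, 1_000_000, 10_000)
# fine   = pairwise_window_counts(gm[[0, 1]], pos, 1_000_000, 2_000)
```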
Key features:
- No tree-sequence inference -- works directly on genotype matrices, VCF files, or `tskit` tree sequences.
- Approximate posterior sampling -- multiple stochastic replicates yield well-calibrated uncertainty estimates for each genomic window.
- Bias correction -- optional Bayesian diversity-based correction aligns predicted diversity with the species mutation rate.
- Broad generalization -- trained on nearly the full `stdpopsim` catalog, generalizing across diverse demographic histories and genome architectures.
- Multiple model variants -- `narrow`, `broad`, `residual`, `broad_w200`, `w200_wmissing`, and adapter-based models for different sample sizes.
- Multi-GPU inference -- shards pairs across GPUs automatically; produces over a million coalescence predictions in minutes.
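For example, sharding every pair of a sample across two GPUs uses the same `translate()` call shown in the quickstart below (device names, sample size, and block coordinates are placeholders):

```python
# Sketch: translate all haplotype pairs at once, sharded across two GPUs.
# Device names, sample size, and coordinates below are placeholders.
from itertools import combinations

import cxt

model = cxt.load_model("broad", device="cuda")
all_pairs = list(combinations(range(50), 2))   # every pair among 50 haplotypes

tmrca, index_map = cxt.translate(
    "path/to/file.vcf", model,
    blocks=[(0, 1_000_000)],
    pivot_pairs=all_pairs,
    devices=["cuda:0", "cuda:1"],  # pairs are distributed across the listed GPUs
)
```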
```bash
pip install -e .
```

Requires Python 3.10+ and PyTorch 2.0+. A GPU is recommended for inference.
```python
import cxt

# Load a pretrained model (downloads checkpoint on first use)
model = cxt.load_model("broad", device="cuda")

# Translate a tree sequence
import msprime
ts = msprime.sim_ancestry(25, population_size=1e4, sequence_length=1e6,
                          recombination_rate=1e-8, random_seed=42)
ts = msprime.mutate(ts, rate=1e-8, random_seed=42)

tmrca, index_map = cxt.translate(
    ts, model,
    blocks=[(0, 1_000_000)],
    pivot_pairs=[(0, 1), (2, 3)],
    devices=["cuda:0"],
    n_reps=15,
)
# tmrca shape: (n_items, n_reps, n_windows) -- log-TMRCA values
```

Translate directly from a VCF file:

```python
tmrca, index_map = cxt.translate(
    "path/to/file.vcf", model,
    blocks=[(0, 1_000_000)],
    pivot_pairs=[(0, 1)],
)
```

Or from a raw genotype matrix and positions array:

```python
import numpy as np

gm = np.load("genotypes.npy")    # (n_haplotypes, n_sites)
pos = np.load("positions.npy")   # (n_sites,) in bp
tmrca, index_map = cxt.translate(
    (gm, pos), model,
    blocks=[(0, 1_000_000)],
    pivot_pairs=[(0, 1)],
)
```
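Because each replicate is an independent stochastic draw, per-window posterior summaries can be computed from the returned array with plain NumPy. A minimal sketch, assuming the `(n_items, n_reps, n_windows)` shape noted above and that `tmrca` is (or has been converted to) a NumPy array:

```python
import numpy as np

# tmrca: (n_items, n_reps, n_windows) log-TMRCA draws from cxt.translate()
# (call .cpu().numpy() first if a torch tensor is returned)
post_mean = tmrca.mean(axis=1)                        # posterior mean per window
lo, hi = np.quantile(tmrca, [0.025, 0.975], axis=1)   # 95% credible band per window

# Back-transform to TMRCA on the natural scale (assuming natural-log values).
tmrca_scale = np.exp(post_mean)
```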
| Name | Preset | Layers | Description |
|---|---|---|---|
| `narrow` | `PRESETS["narrow"]` | 6 | Smaller model, faster inference |
| `broad` | `PRESETS["broad"]` | 10 | Main model, best accuracy |
| `residual` | `PRESETS["residual"]` | 10 | Predicts log-residuals from population mean |
| `broad_w200` | `PRESETS["broad_w200"]` | 10 | 200 bp windows (fine-scale) |
| `w200_wmissing` | `PRESETS["w200_wmissing"]` | 10 | 200 bp windows, handles missingness |
| `broad+adapter` | adapter on `broad` | 10 | 10-sample adapter on `broad` backbone |
| `w200_wmissing_adapter` | adapter on `w200_wmissing` | 10 | 10-sample adapter with missingness |
Pretrained checkpoints for all models are included in the repository under `checkpoints/` via Git LFS.
All model variants are trained with a single unified script:
```bash
# Broad model from scratch
python -m cxt.train --model broad --dataset-path /path/to/data --gpus 0 1 2 --epochs 2

# Fine-tune broad_w200 from broad checkpoint
python -m cxt.train --model broad_w200 --checkpoint /path/to/broad.ckpt --lr 3e-5

# Adapter training (10-sample) on frozen broad backbone
python -m cxt.train --model broad --adapter --adapter-samples 10 \
    --checkpoint /path/to/broad.ckpt --dataset-path /path/to/n10_data

# w200 with missingness support
python -m cxt.train --model w200_wmissing --checkpoint /path/to/broad_w200.ckpt --lr 3e-5

# Direct checkpoints to a specific directory
python -m cxt.train --model broad --dataset-path /path/to/data --log-dir /path/to/output
```

Full pipeline: simulate → preprocess → train → figures
- Simulate: Generate tree sequences with `cxt/simulation_ts_only.py`:

  ```bash
  python cxt/simulation_ts_only.py --scenario constant --data_dir /path/to/raw \
      --num_samples 10000 --num_processes 80
  ```

  Tree sequences (`.trees` files) are saved per scenario, enabling downstream preprocessing.

- Preprocess: Convert tree sequences to training pairs with `cxt.preprocess`:

  ```bash
  python -m cxt.preprocess --base_dir /path/to/raw --out_subdir processed \
      --window_size 2000 --num_pairs 200 --num_workers 80 --skip_existing
  ```

- Train: Run the unified training script above.
A single script bootstraps a uv virtualenv and runs the entire pipeline
(simulate → preprocess → train → figures) in an isolated directory:
```bash
./scripts/run_fresh.sh
```

All outputs go to `/sietch_colab/data_share/cxt_scratch/` by default. Override with `BASE_DIR`:

```bash
BASE_DIR=/scratch/myuser/cxt_run ./scripts/run_fresh.sh
```

Run individual stages with `./scripts/run_fresh.sh simulate`, `preprocess`, `train`, or `figures`. Multiple stages can be combined: `./scripts/run_fresh.sh train figures`.
```
cxt/
├── __init__.py             # Public API: load_model, translate, PRESETS
├── config.py               # ModelConfig, PRESETS, AdapterConfig, TrainingConfig
├── model.py                # TokenFreeDecoder (transformer architecture)
├── modules.py              # Attention, MLP, LayerNorm, MutationsToLatentSpace
├── checkpoint.py           # Model loading, checkpoint download/cache
├── translate.py            # Unified inference: translate(), generate(), multi-GPU
├── sfs.py                  # SFS computation, source building, filtering
├── correction.py           # Diversity-based bias correction (deterministic + stochastic)
├── train.py                # Unified training script (Lightning)
├── dataset.py              # PairDataset for training
├── simulation_ts_only.py   # Tree-sequence simulation CLI (msprime + stdpopsim)
├── preprocess.py           # TS → training data conversion
└── utils.py                # Grids, helpers, simulation functions
checkpoints/                # Pretrained model weights (Git LFS)
figures/                    # Paper figure reproduction scripts
scripts/
├── run_fresh.sh            # Full reproduce pipeline (simulate → figures)
├── copy_checkpoints.sh     # Copy trained checkpoints into repo for LFS commit
├── retrain_adapter.sh      # Retrain adapter models from existing backbone
└── promote_dev_to_main.sh  # Merge dev branch into main with checkpoint migration
```
Model configuration is handled through a single `ModelConfig` dataclass with presets for each variant:
```python
from cxt.config import ModelConfig, PRESETS

# Get a preset
config = PRESETS["broad"]

# Customize for inference
config = config.for_inference(batch_size=1, device="cuda")

# Customize for training
config = config.for_training(batch_size=128, device="cuda")
```

Key config flags:

- `mask_singletons`: whether to zero out singleton SFS entries (`True` for most models, `False` for `w200_wmissing`)
- `use_kv_cache`: allocate a KV cache for autoregressive decoding (set automatically by `for_inference()`)
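Because `ModelConfig` is a dataclass, individual flags can also be overridden on a preset with `dataclasses.replace`; a minimal sketch (the overridden flag value here is only an example):

```python
from dataclasses import replace

from cxt.config import PRESETS

# Start from the broad preset and override a single flag -- illustrative only.
config = replace(PRESETS["broad"], mask_singletons=False)
config = config.for_inference(batch_size=1, device="cuda")
```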
```bibtex
@article{korfmann2026cxt,
  title={Coalescence and Translation: A Language Model for Population Genetics},
  author={Korfmann, Kevin and Pope, Nathaniel S. and Meleghy, Melinda and Tellier, Aur{\'e}lien and Kern, Andrew D.},
  journal={Proceedings of the National Academy of Sciences},
  year={2026},
  doi={10.1073/pnas.XXXXXXXXXX}
}
```

See `LICENSE` for details.
