This repository contains scripts for training and analyzing TCR-ESM models for predicting TCR-peptide binding using ESM protein language model embeddings.
The pipeline consists of several stages:
- Model training and evaluation
- Representation analysis
- Layer-wise classifier probing
- Attention mechanism analysis
- Model interpretability experiments
- Python 3.8+
- PyTorch
- ESM (Fair-ESM)
- See `requirements.txt` for full dependencies
- Create a new conda environment (recommended):

```bash
conda create -n esm_probing python=3.10
conda activate esm_probing
```

- Install dependencies:

```bash
pip install -r requirements.txt
```

Note: The `requirements.txt` file contains all necessary dependencies, including PyTorch, ESM, and other required packages.
First, train the TCR-ESM model using the training script:
```bash
python train_model.py \
  --dataset train_ab_90_beta \
  --model_type beta \
  --epochs 200 \
  --batch_size 32 \
  --learning_rate 0.001 \
  --output_dir ./models/
```

Key parameters:
- `--dataset`: Training dataset name (e.g., train_beta_90, train_ab_90_beta)
- `--model_type`: Model type (beta, alpha, or alphabeta)
- `--epochs`: Number of training epochs
- `--batch_size`: Batch size for training
- `--learning_rate`: Learning rate for training
- `--validation_split`: Fraction of data for validation (default: 0.2)
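The `--validation_split` behavior can be sketched as below. This is illustrative only: whether the script splits randomly or honors the optional `partition` column is an assumption here, and `train_val_split` is a hypothetical helper, not a function from the repository.

```python
import random

def train_val_split(rows, validation_split=0.2, seed=0):
    """Hold out a fraction of rows for validation (mirrors --validation_split).

    Assumption: a simple seeded random split; the actual script may instead
    use the optional `partition` column.
    """
    rows = list(rows)
    random.Random(seed).shuffle(rows)  # deterministic shuffle for reproducibility
    n_val = int(len(rows) * validation_split)
    return rows[n_val:], rows[:n_val]  # (train, validation)
```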
Evaluate the trained model on the MIRA dataset:
```bash
python evaluate_beta_mira.py
```

This script:
- Loads the trained model from `./models/train_ab_90_beta/fold_2_beta_best_mcc.pt`
- Evaluates on the MIRA evaluation set
- Outputs metrics and predictions to `./results_mira_beta_evaluation.csv`
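For a quick sanity check of the exported predictions, MCC can be recomputed directly from the results CSV. The column names `binder` and `pred` in the usage comment are assumptions about the output format, not confirmed by the script:

```python
import math

def mcc_from_predictions(pairs, threshold=0.5):
    """Matthews correlation coefficient from (true_label, score) pairs."""
    tp = tn = fp = fn = 0
    for label, score in pairs:
        pred = 1 if score >= threshold else 0
        if pred == 1 and label == 1:
            tp += 1
        elif pred == 0 and label == 0:
            tn += 1
        elif pred == 1:
            fp += 1
        else:
            fn += 1
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0

# Usage (assuming `binder` and `pred` columns):
#   import csv
#   with open("results_mira_beta_evaluation.csv") as f:
#       pairs = [(int(r["binder"]), float(r["pred"])) for r in csv.DictReader(f)]
#   print(mcc_from_predictions(pairs))
```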
Analyze how different ESM layers separate positive and negative examples:
```bash
python difference_of_means.py \
  --csv ./NetTCR-2.0/data/train_ab_90_beta_balanced1.csv \
  --out layer_map.csv
```

This generates:
- Layer-wise performance metrics using difference of means
- Comparison with logistic regression baseline
- Output CSV with AUC, AP, MCC, and accuracy per layer
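The difference-of-means probe itself is simple enough to sketch. This minimal version assumes pooled per-example embeddings `X` from one ESM layer and binary labels `y`; scoring details in the actual script may differ:

```python
import numpy as np

def difference_of_means_scores(X, y):
    """Score embeddings by their projection onto the class-mean difference.

    X: (n, d) array of pooled embeddings from one ESM layer.
    y: (n,) binary labels (1 = binder, 0 = non-binder).
    Returns one score per example; higher = more binder-like.
    """
    mu_pos = X[y == 1].mean(axis=0)
    mu_neg = X[y == 0].mean(axis=0)
    direction = mu_pos - mu_neg
    direction /= np.linalg.norm(direction)  # unit vector between class means
    return X @ direction
```

Unlike the logistic regression baseline, this probe has no fitted weights beyond the two class means, which is what makes it a cheap layer-wise diagnostic.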
Perform comprehensive analysis across all layer combinations:
```bash
python classifier_sweeping.py \
  --train_csv ./NetTCR-2.0/data/train_ab_90_beta.csv \
  --eval_csv ./NetTCR-2.0/data/mira_eval_threshold90.csv \
  --grid_all \
  --pool mean \
  --cache_mode reuse \
  --out_dir ./intermediate_expts/grid_33x33_mean
```

This script:
- Trains classifiers on all CDR3 × peptide layer combinations (33×33 grid)
- Caches embeddings for efficiency
- Generates heatmaps showing performance across layer combinations
- Outputs detailed CSV with all metrics
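The embedding-caching step can be sketched as below. The file-name convention and the `compute_fn` hook are illustrative assumptions, not the script's actual cache layout:

```python
import os
import numpy as np

def cached_embeddings(cache_dir, layer, pool, compute_fn):
    """Load embeddings for a (layer, pool) pair from disk, computing them once.

    compute_fn() should return an (n, d) array. The file name below is an
    assumed convention for illustration only.
    """
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, f"layer{layer}_{pool}.npy")
    if os.path.exists(path):
        return np.load(path)  # reuse (cf. --cache_mode reuse)
    emb = compute_fn()
    np.save(path, emb)
    return emb
```

With a 33x33 grid, each layer's embeddings are needed 33 times, so computing them once and reusing the cached copy dominates the savings.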
Train models using specific intermediate layer representations:
```bash
python train_intermediate_layers.py \
  --train_csv ./NetTCR-2.0/data/train_ab_90_beta.csv \
  --eval_csv ./NetTCR-2.0/data/mira_eval_threshold90.csv \
  --cdr3_layers 20 21 22 \
  --pep_layers 20 21 22 \
  --pool mean \
  --out_dir ./layer_probe_results/
```

Parameters:
- `--cdr3_layers`: Which ESM layers to use for CDR3 encoding
- `--pep_layers`: Which ESM layers to use for peptide encoding
- `--pool`: Pooling method (mean or cls)
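The two pooling options reduce per-token representations to a single vector per sequence. A minimal sketch, assuming a `(tokens, dim)` array of ESM hidden states; whether special tokens are excluded from the mean is an assumption here, not taken from the script:

```python
import numpy as np

def pool_tokens(token_reps, method="mean"):
    """Collapse per-token ESM representations (T, d) into one vector (d,).

    "mean" averages over all token positions; "cls" takes the first
    position (ESM's BOS token).
    """
    if method == "mean":
        return token_reps.mean(axis=0)
    if method == "cls":
        return token_reps[0]
    raise ValueError(f"unknown pooling method: {method}")
```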
Analyze attention patterns to understand model decisions:
```bash
python attention_attribution.py \
  --model_path ./models/best_model.pt \
  --eval_csv ./NetTCR-2.0/data/mira_eval_threshold90.csv \
  --output_dir ./attention_analysis/
```

This script:
- Extracts attention weights from ESM
- Analyzes attention patterns for positive vs negative examples
- Generates visualization of important attention heads
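The positive-vs-negative comparison can be sketched as a per-head statistic. The summary used here (absolute difference of class means over a per-head scalar) is an assumption about the analysis, not the script's exact method:

```python
import numpy as np

def head_importance(attn_pos, attn_neg):
    """Rank attention heads by how differently they behave across classes.

    attn_pos / attn_neg: arrays of shape (n_examples, n_layers, n_heads)
    holding a per-head summary statistic (e.g. mean attention mass on
    peptide tokens; the choice of statistic is assumed).
    Returns |mean_pos - mean_neg| per (layer, head); higher values mark
    heads whose behavior separates binders from non-binders.
    """
    return np.abs(attn_pos.mean(axis=0) - attn_neg.mean(axis=0))
```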
Evaluate the importance of individual attention heads:
```bash
python ablate_heads_eval.py \
  --model_path ./models/best_model.pt \
  --eval_csv ./NetTCR-2.0/data/mira_eval_threshold90.csv \
  --combo_sizes 1 2 3 \
  --output_dir ./ablation_results/
```

All CSV files should contain the following columns:
- `CDR3b`: Beta chain CDR3 sequence (required for beta models)
- `CDR3a`: Alpha chain CDR3 sequence (required for alpha/alphabeta models)
- `peptide`: Peptide sequence
- `binder`: Binary label (1 for binder, 0 for non-binder)
- `partition`: Optional, for train/val split
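A small pre-flight check for these columns can save a failed run. `check_columns` is a hypothetical helper, not part of the repository; the per-model-type column mapping below follows the descriptions above:

```python
import csv

REQUIRED = {"peptide", "binder"}
CHAIN_COLS = {
    "beta": {"CDR3b"},
    "alpha": {"CDR3a"},
    "alphabeta": {"CDR3a", "CDR3b"},
}

def check_columns(csv_path, model_type="beta"):
    """Fail fast if a dataset CSV lacks the columns the pipeline expects."""
    with open(csv_path, newline="") as f:
        header = set(next(csv.reader(f)))
    missing = (REQUIRED | CHAIN_COLS[model_type]) - header
    if missing:
        raise ValueError(f"{csv_path} is missing columns: {sorted(missing)}")
    return header
```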
- Models: Saved in PyTorch format (.pt files)
- Evaluations: CSV files with predictions and metrics
- Visualizations: PNG heatmaps and plots
- Layer Analysis: CSV files with layer-wise metrics
- Use `--cache_mode reuse` in classifier_sweeping.py to avoid recomputing embeddings
- Start with smaller datasets to test the pipeline
- Monitor GPU memory usage, especially for batch size selection
- Check intermediate outputs to ensure proper data flow
- If OOM errors occur, reduce batch size
- Ensure all paths are relative when moving the code
- Verify CSV format matches expected columns
- Check that ESM model downloads properly on first run