kevinlu4588/ESMProbing
TCR-ESM Analysis Pipeline

This repository contains scripts for training and analyzing TCR-ESM models, which predict TCR-peptide binding from ESM protein language model embeddings.

Overview

The pipeline consists of several stages:

  1. Model training and evaluation
  2. Representation analysis
  3. Layer-wise classifier probing
  4. Attention mechanism analysis
  5. Model interpretability experiments

Prerequisites

  • Python 3.8+
  • PyTorch
  • ESM (Fair-ESM)
  • See requirements.txt for full dependencies

Setup

Creating the Conda Environment (Added on 8/27/2025 for better reproducibility)

  1. Create a new conda environment (recommended):
conda create -n esm_probing python=3.10
conda activate esm_probing
  2. Install dependencies:
pip install -r requirements.txt

Note: The requirements.txt file contains all necessary dependencies including PyTorch, ESM, and other required packages.

Pipeline Steps

1. Train the Base Model

First, train the TCR-ESM model using the training script:

python train_model.py \
  --dataset train_ab_90_beta \
  --model_type beta \
  --epochs 200 \
  --batch_size 32 \
  --learning_rate 0.001 \
  --output_dir ./models/

Key parameters:

  • --dataset: Training dataset name (e.g., train_beta_90, train_ab_90_beta)
  • --model_type: Model type (beta, alpha, or alphabeta)
  • --epochs: Number of training epochs
  • --batch_size: Batch size for training
  • --validation_split: Fraction of data for validation (default: 0.2)

2. Evaluate the Model

Evaluate the trained model on the MIRA dataset:

python evaluate_beta_mira.py

This script:

  • Loads the trained model from ./models/train_ab_90_beta/fold_2_beta_best_mcc.pt
  • Evaluates on MIRA evaluation set
  • Outputs metrics and predictions to ./results_mira_beta_evaluation.csv
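Model selection above is based on MCC (the checkpoint is named fold_2_beta_best_mcc.pt). As a reference for that metric, here is a minimal, dependency-free sketch of the standard Matthews correlation coefficient for binary predictions; the function name is ours, not the script's:

```python
import math

def matthews_corrcoef(y_true, y_pred):
    """Standard MCC from binary labels and binary predictions."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # Convention: MCC is 0 when any confusion-matrix margin is empty.
    return 0.0 if denom == 0 else (tp * tn - fp * fn) / denom
```

MCC ranges from -1 (total disagreement) through 0 (chance) to 1 (perfect prediction), which makes it a robust selection criterion on imbalanced binder/non-binder data.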

3. Analyze Layer Representations (Difference of Means)

Analyze how different ESM layers separate positive and negative examples:

python difference_of_means.py \
  --csv ./NetTCR-2.0/data/train_ab_90_beta_balanced1.csv \
  --out layer_map.csv

This generates:

  • Layer-wise performance metrics using difference of means
  • Comparison with logistic regression baseline
  • Output CSV with AUC, AP, MCC, and accuracy per layer
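The core idea of the difference-of-means probe can be sketched in a few lines: estimate the mean embedding of the positive and negative classes at a given layer, then score each example by projecting it onto the class-difference direction. This is an illustrative sketch, not the script's actual code; names and the plain-list representation are ours:

```python
def diff_of_means_scores(embeddings, labels):
    """Score examples along the (mu_pos - mu_neg) direction.

    embeddings: list of equal-length feature vectors (one per example)
    labels: matching 0/1 binder labels
    """
    dim = len(embeddings[0])
    pos = [e for e, y in zip(embeddings, labels) if y == 1]
    neg = [e for e, y in zip(embeddings, labels) if y == 0]
    mu_pos = [sum(v[i] for v in pos) / len(pos) for i in range(dim)]
    mu_neg = [sum(v[i] for v in neg) / len(neg) for i in range(dim)]
    direction = [p - n for p, n in zip(mu_pos, mu_neg)]
    # Higher score = closer to the positive-class mean along this direction.
    return [sum(x * d for x, d in zip(e, direction)) for e in embeddings]
```

Thresholding these scores yields the layer-wise AUC/AP/MCC numbers, and comparing them to a logistic regression trained on the same features shows how much of the signal is linearly recoverable from a single direction.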

4. Layer-wise Classifier Sweep

Perform comprehensive analysis across all layer combinations:

python classifier_sweeping.py \
  --train_csv ./NetTCR-2.0/data/train_ab_90_beta.csv \
  --eval_csv ./NetTCR-2.0/data/mira_eval_threshold90.csv \
  --grid_all \
  --pool mean \
  --cache_mode reuse \
  --out_dir ./intermediate_expts/grid_33x33_mean

This script:

  • Trains classifiers on all CDR3 × peptide layer combinations (33×33 grid)
  • Caches embeddings for efficiency
  • Generates heatmaps showing performance across layer combinations
  • Outputs detailed CSV with all metrics
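The sweep-with-caching pattern described above can be sketched as follows. `embed` and `train_and_score` are hypothetical stand-ins for the script's internals; the point is that each layer's embeddings are computed once and reused across the whole grid:

```python
from itertools import product

def sweep(num_layers, embed, train_and_score):
    """Evaluate every (cdr3_layer, pep_layer) pair, caching per-layer embeddings."""
    cache = {}  # layer index -> embeddings, computed at most once

    def cached(layer):
        if layer not in cache:
            cache[layer] = embed(layer)
        return cache[layer]

    results = {}
    for c_layer, p_layer in product(range(num_layers), repeat=2):
        results[(c_layer, p_layer)] = train_and_score(cached(c_layer), cached(p_layer))
    return results
```

With 33 layers this turns 33 x 33 = 1089 classifier fits into only 33 embedding passes, which is why --cache_mode reuse matters for runtime.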

5. Train Intermediate Layer Models

Train models using specific intermediate layer representations:

python train_intermediate_layers.py \
  --train_csv ./NetTCR-2.0/data/train_ab_90_beta.csv \
  --eval_csv ./NetTCR-2.0/data/mira_eval_threshold90.csv \
  --cdr3_layers 20 21 22 \
  --pep_layers 20 21 22 \
  --pool mean \
  --out_dir ./layer_probe_results/

Parameters:

  • --cdr3_layers: Which ESM layers to use for CDR3 encoding
  • --pep_layers: Which ESM layers to use for peptide encoding
  • --pool: Pooling method (mean or cls)
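The two --pool modes can be sketched as follows, assuming per-residue embeddings of shape (seq_len, dim) represented here as nested lists; the function name is illustrative:

```python
def pool(residue_embeddings, mode="mean"):
    """Collapse per-residue embeddings to one fixed-size vector.

    "cls": take the first (beginning-of-sequence) token's embedding.
    "mean": average each dimension across all residues.
    """
    if mode == "cls":
        return residue_embeddings[0]
    n = len(residue_embeddings)
    dim = len(residue_embeddings[0])
    return [sum(tok[i] for tok in residue_embeddings) / n for i in range(dim)]
```

Mean pooling uses signal from every residue and is a common default for probing; CLS pooling relies on the model having concentrated sequence-level information into the first token.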

6. Attention Attribution Analysis

Analyze attention patterns to understand model decisions:

python attention_attribution.py \
  --model_path ./models/best_model.pt \
  --eval_csv ./NetTCR-2.0/data/mira_eval_threshold90.csv \
  --output_dir ./attention_analysis/

This script:

  • Extracts attention weights from ESM
  • Analyzes attention patterns for positive vs negative examples
  • Generates visualization of important attention heads
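One simple way to surface "important" heads, in the spirit of the comparison above, is to summarize each head's attention per example (e.g. total attention mass on the peptide) and rank heads by the gap between positive and negative examples. This is a hedged sketch of that idea, not the script's actual method; `head_scores` is a hypothetical per-example, per-head summary:

```python
def rank_heads_by_gap(head_scores, labels):
    """Rank heads by |mean(positives) - mean(negatives)| of their summary score.

    head_scores: list of per-example vectors, one entry per attention head
    labels: matching 0/1 binder labels
    """
    num_heads = len(head_scores[0])
    gaps = []
    for h in range(num_heads):
        pos = [s[h] for s, y in zip(head_scores, labels) if y == 1]
        neg = [s[h] for s, y in zip(head_scores, labels) if y == 0]
        gaps.append(sum(pos) / len(pos) - sum(neg) / len(neg))
    # Heads whose behavior differs most between classes come first.
    return sorted(range(num_heads), key=lambda h: abs(gaps[h]), reverse=True)
```

Heads near the top of this ranking are natural candidates for the ablation study in the next step.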

7. Head Ablation Studies

Evaluate the importance of individual attention heads:

python ablate_heads_eval.py \
  --model_path ./models/best_model.pt \
  --eval_csv ./NetTCR-2.0/data/mira_eval_threshold90.csv \
  --combo_sizes 1 2 3 \
  --output_dir ./ablation_results/

Data Format

All CSV files should contain the following columns:

  • CDR3b: Beta chain CDR3 sequence (required for beta models)
  • CDR3a: Alpha chain CDR3 sequence (required for alpha/alphabeta models)
  • peptide: Peptide sequence
  • binder: Binary label (1 for binder, 0 for non-binder)
  • partition: Optional, for train/val split
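A quick header check against the columns listed above can catch format problems before a long run. This is a small helper of our own, not part of the repository; column names follow the Data Format section:

```python
import csv

REQUIRED = {"peptide", "binder"}
CHAIN_COLS = {"beta": {"CDR3b"}, "alpha": {"CDR3a"}, "alphabeta": {"CDR3a", "CDR3b"}}

def missing_columns(csv_lines, model_type="beta"):
    """Return the expected columns absent from the CSV header line.

    csv_lines: any iterable of CSV text lines (e.g. an open file object)
    """
    header = set(next(csv.reader(csv_lines)))
    return sorted((REQUIRED | CHAIN_COLS[model_type]) - header)
```

For example, `missing_columns(open("data.csv"), "alphabeta")` returns an empty list when the file has CDR3a, CDR3b, peptide, and binder columns.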

Output Files

  • Models: Saved in PyTorch format (.pt files)
  • Evaluations: CSV files with predictions and metrics
  • Visualizations: PNG heatmaps and plots
  • Layer Analysis: CSV files with layer-wise metrics

Tips

  1. Use --cache_mode reuse in classifier_sweeping.py to avoid recomputing embeddings
  2. Start with smaller datasets to test the pipeline
  3. Monitor GPU memory usage, especially for batch size selection
  4. Check intermediate outputs to ensure proper data flow

Troubleshooting

  • If OOM errors occur, reduce batch size
  • Use relative paths so the code keeps working when the repository is moved
  • Verify CSV format matches expected columns
  • Check that ESM model downloads properly on first run
