This repository contains the experiment code for the paper *mHC-lite: You Don't Need 20 Sinkhorn-Knopp Iterations*. The codebase is adapted from nanoGPT.
Install the required packages:

```
pip install torch numpy transformers datasets tiktoken wandb tqdm einops
```

To prepare the datasets, enter the corresponding dataset folder and run `prepare.py`:
```
cd data/shakespeare_char
python prepare.py
cd ../fineweb_edu
python prepare.py
cd ../openwebtext
python prepare.py
```

Data preparation typically takes ~30 minutes (depending on your machine and disk speed).
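To sanity-check the output, you can peek at the prepared binaries. This is a minimal sketch assuming `prepare.py` follows nanoGPT's convention of writing `train.bin`/`val.bin` as flat `uint16` token arrays; the path and dtype below are that assumption, not something this README specifies:

```python
import numpy as np

# Memory-map the prepared training split (assumed nanoGPT-style uint16 tokens)
# without loading it all into RAM, then print a few basic statistics.
tokens = np.memmap("data/openwebtext/train.bin", dtype=np.uint16, mode="r")
print(f"train split: {len(tokens):,} tokens")
print("first 16 token ids:", tokens[:16])
```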
To train a model, run `train.py`. Use `torchrun` to enable distributed training (see the original nanoGPT project for details). You can combine multiple config files to specify the dataset, model scale, and method (see the example config sketch after the list):
- Model scales:
  - S: `config/small_model.py`
  - M: `config/medium_model.py`
  - L: `config/large_model.py`
- Methods:
  - HC: `config/with_hc.py`
  - mHC: `config/with_mhc.py`
  - mHC-lite: `config/with_mhc_lite.py`
  - Residual: (default)
- Datasets:
  - OpenWebText: `config/train_owt.py`
  - FineWeb-Edu: `config/train_fineweb_edu.py`
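For context, nanoGPT-style config files are plain Python scripts whose top-level assignments override the defaults in `train.py`, with files given later on the command line taking precedence. A hypothetical example of a scale config (the variable names follow nanoGPT; the values are illustrative, not this repo's actual small-model settings):

```python
# Hypothetical scale config in nanoGPT's style: each assignment overrides
# the corresponding default in train.py. Values here are illustrative only.
n_layer = 12
n_head = 12
n_embd = 768
dropout = 0.0
```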
For example, to train a small (S) model with mHC-lite on OpenWebText:
```
torchrun --standalone --nproc_per_node=8 train.py \
    config/train_owt.py config/small_model.py config/with_mhc_lite.py
```

- Set `--nproc_per_node` to the number of GPUs you have.
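For a single-GPU run, upstream nanoGPT also supports launching `train.py` directly without `torchrun`; assuming this repo preserves that behavior:

```
python train.py config/train_owt.py config/small_model.py config/with_mhc_lite.py
```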
Alternatively, run `run.sh` to reproduce all experiments reported in Table 1 of the paper.
Run `train_analysis.py` with `config/with_mhc_analysis.py` to perform analysis using a checkpoint.
Experiment runs automatically create output directories and save checkpoints. For analysis runs, specify the checkpoint directory via `--out_dir` so the script can resume from it. Analysis can only be performed on checkpoints with mHC enabled.
To analyze a checkpoint from a small model trained on OpenWebText, set `--out_dir` to `out-owt-small-mhc`:
```
python train_analysis.py \
    config/train_owt.py config/with_mhc_analysis.py config/small_model.py \
    --out_dir=out-owt-small-mhc --init_from=resume
```

After the analysis run, results will be saved to `log_out/infos.pkl`. Then run:

```
python -m analyze.h_and_nu
```

This produces the analysis figures in `analyze/`.
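If you want to inspect the raw results rather than the rendered figures, here is a minimal sketch (this assumes `log_out/infos.pkl` is an ordinary pickle file; its internal structure is not documented here, so print what is there before relying on any keys):

```python
import pickle

# Load the saved analysis results and report their top-level structure.
with open("log_out/infos.pkl", "rb") as f:
    infos = pickle.load(f)
print(type(infos))
if isinstance(infos, dict):
    print("keys:", list(infos.keys()))
```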
This codebase is adapted from nanoGPT.
Our Hyper-Connection implementation is from the hyper-connections library. Note that hyper-connections also provides an mHC implementation; however, since it does not exactly match the details described in the mHC paper, we implemented our own version.
We would also like to thank this mHC reproduction, which (to the best of our knowledge) is the earliest public reproduction of mHC. While we do not directly use code from it, some of our early explorations were inspired by that project.