The FastnnUNet Knowledge Distillation Module is an advanced knowledge transfer system built on the nnUNetv2 framework. It transfers knowledge from high-performance but computationally expensive teacher models (standard nnUNet) to lightweight student models, achieving significant reductions in model size and computational cost while maintaining segmentation accuracy.
- Efficient Knowledge Distillation: Transfers segmentation capabilities from standard nnUNet to lightweight models
- Multi-teacher Ensemble Learning: Supports knowledge extraction from multiple cross-validation model ensembles
- Adaptive Network Architecture: Automatically designs optimized student networks based on teacher models
- Hybrid Distillation Strategy: Joint training method combining soft labels and hard labels
- Feature Reduction Control: Configurable feature channel reduction ratio (default 50% reduction)
- DA5 Strong Data Augmentation: Advanced data augmentation strategy specifically optimized for small datasets
- Compatibility Guarantee: Fully compatible with original nnUNetv2, supporting all configurations and datasets
- Complete Training Cycle: Includes checkpoint recovery, automatic validation, and ONNX export functionality
- ResEnc Architecture Support: Enhanced support for Residual Encoder U-Net (ResEnc) architectures with improved performance for complex segmentation tasks
FastnnUNet knowledge distillation employs advanced knowledge transfer strategies, mainly including:
- **Soft Label Knowledge Transfer** (a loss sketch follows this list):
  - Extract class probability distributions rather than hard labels from the teacher models
  - Control the "softness" of the soft labels using a temperature parameter
  - Capture knowledge of inter-class correlations through a KL divergence loss
- **Architecture Optimization Design**:
  - Maintain the U-Net topology of the teacher model
  - Proportionally reduce the number of feature channels in each layer
  - Retain the teacher model's key design elements (depth, convolution types, skip connections, etc.)
- **Multi-teacher Ensemble Learning**:
  - Leverage ensemble knowledge from the probability maps of multiple teacher models (different folds) during training, instead of running inference with each of the five fold models separately and pooling their predictions
  - Remove nnUNet's original post-training step of selecting the best model among the five folds
  - Capture complementary information between models, improving generalization
- **ResEnc Architecture Support**:
  - Support for Residual Encoder U-Net architectures with deeper encoder structures
  - Enhanced feature extraction through residual connections in the encoder
  - Multiple ResEnc variants (ResEncM, ResEncL, and ResEncXL) for different complexity requirements
  - Improved performance on complex medical image segmentation tasks with challenging anatomical structures
- **DA5 Strong Data Augmentation**:
  - Advanced data augmentation strategy optimized for small datasets and challenging segmentation tasks
  - Enhanced spatial transformations with reduced interpolation order for better preservation of fine details
  - Comprehensive augmentation pipeline including rotation, scaling, elastic deformation, and intensity transformations
  - Particularly effective for medical imaging datasets with limited training samples
  - Compatible with both the standard and ResEnc distillation architectures
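The following PyTorch sketch shows how these strategies fit together. It is a minimal illustration under stated assumptions, not the module's actual implementation: the function names, the uniform averaging over teacher folds, the `temperature ** 2` scaling factor (standard in Hinton-style distillation), and the `min_channels` floor are choices made for exposition.

```python
# Illustrative sketch of the hybrid distillation strategy (not module code).
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits_per_fold, target,
                      alpha=0.3, temperature=3.0):
    """Hybrid loss: hard-label CE plus temperature-softened KL to the teachers."""
    # Hard-label term: ordinary cross-entropy against the ground truth.
    hard_loss = F.cross_entropy(student_logits, target)

    # Multi-teacher soft labels: average the temperature-softened probability
    # maps of all teacher folds (class dimension is 1 for segmentation logits).
    teacher_probs = torch.stack(
        [F.softmax(t / temperature, dim=1) for t in teacher_logits_per_fold]
    ).mean(dim=0)

    # Soft-label term: KL divergence between the softened student and teacher
    # distributions; temperature**2 keeps gradient magnitudes comparable
    # across temperatures.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        teacher_probs,
        reduction="batchmean",
    ) * temperature ** 2

    # alpha balances soft-label (teacher) knowledge against hard labels.
    return alpha * soft_loss + (1 - alpha) * hard_loss

def reduce_channels(teacher_channels, r=2, min_channels=8):
    """Shrink every stage's feature count by the reduction factor r,
    keeping the teacher's topology (depth, stages) unchanged."""
    return [max(min_channels, c // r) for c in teacher_channels]

print(reduce_channels([32, 64, 128, 256, 320], r=2))  # [16, 32, 64, 128, 160]
```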
Before using the FastnnUNet Knowledge Distillation Module, please ensure the following requirements are met:
- **Basic Environment**:
  - Python 3.7+
  - CUDA 11.0+ (recommended for GPU training)
  - PyTorch 1.11+
- **nnUNetv2 Dependencies**:
  - Ensure nnUNetv2 is correctly installed
  - Environment variables are properly configured
- **Install This Module**:

# Enter the project directory
cd FastnnUNet/distillation

# Install in development mode
pip install -e .
First, prepare the dataset and train standard nnUNetv2 models as teachers:
# Data preprocessing
nnUNetv2_plan_and_preprocess -d DATASET_ID --verify_dataset_integrity
# Train 5-fold cross-validation teacher models (Standard nnUNet)
nnUNetv2_train DATASET_ID 3d_fullres 0
nnUNetv2_train DATASET_ID 3d_fullres 1
nnUNetv2_train DATASET_ID 3d_fullres 2
nnUNetv2_train DATASET_ID 3d_fullres 3
nnUNetv2_train DATASET_ID 3d_fullres 4
# Alternative: Train ResEnc teacher models for enhanced performance
nnUNetv2_plan_and_preprocess -d DATASET_ID --verify_dataset_integrity -pl nnUNetPlannerResEncM/nnUNetPlannerResEncL/nnUNetPlannerResEncXL
# If you only intend to run 3d_fullres and 2d and have already preprocessed the data, avoid reprocessing by running:
nnUNetv2_plan_experiment -d DATASET_ID -pl nnUNetPlannerResEncM/nnUNetPlannerResEncL/nnUNetPlannerResEncXL
nnUNetv2_train DATASET_ID 3d_fullres 0 -p nnUNetResEncUNetMPlans/nnUNetResEncUNetLPlans/nnUNetResEncUNetXLPlans
nnUNetv2_train DATASET_ID 3d_fullres 1 -p nnUNetResEncUNetMPlans/nnUNetResEncUNetLPlans/nnUNetResEncUNetXLPlans
nnUNetv2_train DATASET_ID 3d_fullres 2 -p nnUNetResEncUNetMPlans/nnUNetResEncUNetLPlans/nnUNetResEncUNetXLPlans
nnUNetv2_train DATASET_ID 3d_fullres 3 -p nnUNetResEncUNetMPlans/nnUNetResEncUNetLPlans/nnUNetResEncUNetXLPlans
nnUNetv2_train DATASET_ID 3d_fullres 4 -p nnUNetResEncUNetMPlans/nnUNetResEncUNetLPlans/nnUNetResEncUNetXLPlans
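If you prefer scripting the five teacher folds instead of typing each command, a minimal Python wrapper around the CLI calls above could look like the following (the dataset ID is a placeholder, and `nnUNetv2_train` must be on your PATH):

```python
# Launch all five teacher folds sequentially via the documented nnUNetv2 CLI.
import subprocess

DATASET_ID = "1"  # placeholder: replace with your dataset ID

for fold in range(5):
    subprocess.run(
        ["nnUNetv2_train", DATASET_ID, "3d_fullres", str(fold)],
        check=True,  # stop immediately if a fold fails
    )
```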
Use the trained teacher models for knowledge distillation:
# Basic usage (using all available folds of teacher models)
nnUNetv2_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2
# Advanced usage - specify specific teacher model folds
nnUNetv2_distillation_train -d DATASET_ID -f 0 -tf 0 1 2 3 4 -a 0.3 -temp 3.0 -r 2
# Continue previous training
nnUNetv2_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 -c_continue
# With custom teacher model path
nnUNetv2_distillation_train -d DATASET_ID -f 0 -t /path/to/teacher -a 0.3 -temp 3.0 -r 2
# Disable mirroring during validation
nnUNetv2_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 -disable_mirroring
# Enable fold rotation during training
nnUNetv2_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 -rotate_folds -rotate_freq 400
# Use DA5 strong data augmentation (recommended for small datasets)
nnUNetv2_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 --use_da5
# Combine DA5 with other options
nnUNetv2_distillation_train -d DATASET_ID -f 0 -tf 0 1 2 3 4 -a 0.3 -temp 3.0 -r 2 --use_da5 -c_continue

# Basic ResEnc distillation (using ResEnc teacher models)
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2
# Advanced ResEnc distillation - specify teacher plans and folds
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -tf 0 1 2 3 4 -a 0.3 -temp 3.0 -r 2 -tpl nnUNetResEncUNetLPlans
# ResEnc distillation with different teacher plan variants
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 -tpl nnUNetResEncUNetMPlans # Medium
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 -tpl nnUNetResEncUNetLPlans # Large
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 -tpl nnUNetResEncUNetXLPlans # Extra Large
# ResEnc distillation with custom teacher model path
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -t /path/to/resenc/teacher -a 0.3 -temp 3.0 -r 2
# ResEnc distillation with different student plans (ResEnc student)
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 -spl nnUNetResEncUNetMPlans
# ResEnc distillation with block reduction strategy
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 -bs adaptive
# Continue ResEnc distillation training
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 -c_continue
# ResEnc distillation with DA5 strong data augmentation
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 --use_da5
# ResEnc distillation combining DA5 with ResEnc teacher and student
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 -tpl nnUNetResEncUNetLPlans -spl nnUNetResEncUNetMPlans --use_da5

Parameter description:

- -d, --dataset_id: Dataset ID
- -f, --fold: Fold number used to train the student model
- -tf, --teacher_folds: List of teacher model fold numbers; automatically detected if not specified
- -a, --alpha: Distillation loss weight, controls the ratio of soft-label to hard-label loss
- -temp, --temperature: Distillation temperature, controls the smoothness of soft labels
- -r, --reduction_factor: Feature reduction factor; higher values produce smaller student models
- -c_continue, --continue_training: Continue previous training
- -disable_mirroring: Disable mirror augmentation during validation
- -e, --epochs: Maximum number of training epochs (default 1000)
- -t, --teacher_model_folder: Teacher model folder path (auto-constructed if not provided)
- -tpl, --teacher_plans: Teacher model plans identifier (ResEnc specific, default: nnUNetResEncUNetLPlans)
- -spl, --student_plans: Student model plans identifier (ResEnc specific, default: nnUNetPlans)
- -bs, --block_strategy: Block reduction strategy for ResEnc (reduce/keep/increase/adaptive, default: keep)
- -rotate_folds: Enable rotating training folds periodically
- -rotate_freq: How often to rotate folds (in epochs; default: 5 for ResEnc, 400 for standard)
- -d_device: Device to use, e.g., "cuda:0"
- --use_da5: Use DA5 strong data augmentation (recommended for small datasets)
Export the trained student model to ONNX format for fast inference:
# Basic export
nnUNetv2_distillation_export_onnx -d DATASET_ID -f 0 -r 2
# Verbose information display
nnUNetv2_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -v
# Specify output path
nnUNetv2_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -o /path/to/output.onnx
# Export with dynamic input shape
nnUNetv2_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -da
# Export with custom input shape
nnUNetv2_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -is 1 1 128 128 128
# Export in nnUNet format (single channel, fixed size)
nnUNetv2_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -nn
# Export DA5-trained model
nnUNetv2_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -da5
# Export DA5-trained model with verbose output
nnUNetv2_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -da5 -v
# Export with simplified ONNX
nnUNetv2_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -sim

# Basic ResEnc model export
nnUNetv2_resenc_distillation_export_onnx -d DATASET_ID -f 0 -r 2
# ResEnc export with verbose output
nnUNetv2_resenc_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -v
# ResEnc export with custom output directory
nnUNetv2_resenc_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -o /path/to/output/dir
# ResEnc export with specific student plans
nnUNetv2_resenc_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -spl nnUNetResEncUNetMPlans
# ResEnc export with fixed batch size
nnUNetv2_resenc_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -b 1
# ResEnc export with TensorRT compatibility
nnUNetv2_resenc_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -fix
# ResEnc export with custom plans identifier
nnUNetv2_resenc_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -p nnUNetResEncUNetLPlans
# ResEnc export for DA5-trained model
nnUNetv2_resenc_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -da5
# ResEnc export for DA5-trained model with TensorRT compatibility
nnUNetv2_resenc_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -da5 -fix
# ResEnc export with simplified ONNX
nnUNetv2_resenc_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -sim

Parameter description:

- -d, --dataset_id: Dataset ID
- -f, --fold: Model fold number
- -r, --reduction_factor: Feature reduction factor; must match the value used during training
- -o, --output: Output ONNX file path (standard) or output directory (ResEnc)
- -v, --verbose: Display detailed information
- -da: Use dynamic input shape (static shape by default)
- -is: Custom input shape (b c x y z); standard export only
- -nn: Export a single-channel, fixed-size model; standard export only
- -spl, --student_plans: Student model plans identifier; ResEnc export only
- -b, --batch_size: Batch size, 0 means dynamic; ResEnc export only
- --trt: Apply TensorRT compatibility fixes; ResEnc export only
- -p, --plans: Plans identifier; ResEnc export only
- -da5: Model was trained with DA5 data augmentation (both standard and ResEnc export)
- -sim: Simplify the ONNX model
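Once exported, the student model can be run with onnxruntime. The snippet below is a minimal sketch, assuming a static-shape export with input shape 1 x 1 x 128 x 128 x 128 saved as `student.onnx` (a hypothetical file name); real inference requires nnUNet-style preprocessing and, typically, sliding-window patch aggregation.

```python
# Minimal sketch: run an exported student model with onnxruntime.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "student.onnx",  # hypothetical path to the exported model
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
input_name = session.get_inputs()[0].name

# Dummy patch; in practice, feed a normalized image patch as in nnUNet.
patch = np.random.randn(1, 1, 128, 128, 128).astype(np.float32)

logits = session.run(None, {input_name: patch})[0]  # (1, num_classes, x, y, z)
segmentation = logits.argmax(axis=1)                # per-voxel hard labels
print(segmentation.shape)
```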
For the best balance between performance and accuracy, the following parameter settings are recommended:
- **📊 Feature Reduction Factor (`-r`)** (the arithmetic behind these percentages is sketched after this list):
  - 2: Balanced mode, model size reduced by 75%, minimal accuracy loss
  - 4: Lightweight mode, model size reduced by 94%, may incur a 1-2% accuracy loss
  - 8: Ultra-lightweight mode, model size reduced by 98%, suitable for extremely resource-constrained scenarios
- **⚖️ Distillation Loss Weight (`-a`)**:
  - 0.3: Default value, suitable for most scenarios
  - 0.5: More emphasis on soft-label knowledge, suitable for complex segmentation tasks
  - 0.1: More emphasis on hard labels, suitable for simple segmentation tasks
- **🌡️ Distillation Temperature (`-temp`)**:
  - 3.0: Default value, provides moderately smooth soft labels
  - 1.0: Sharper soft labels, closer to the original predictions
  - 5.0: Very smooth soft labels, maximizes knowledge transfer
- **🔥 DA5 Data Augmentation (`--use_da5`)**:
  - Recommended for small datasets (< 100 training cases)
  - Particularly effective for challenging segmentation tasks with fine anatomical structures
  - Can improve model robustness by 2-5% on small datasets
  - Compatible with all reduction factors and architectures
  - May increase training time by 10-15% due to the more intensive augmentation
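The size percentages for `-r` follow from simple arithmetic: convolution parameters scale with the product of input and output channel counts, so dividing every channel count by r shrinks the parameter count by roughly r². The snippet below is an illustrative back-of-the-envelope check, not module code; the example logits are made up.

```python
# Back-of-the-envelope check of the -r and -temp recommendations above.
import numpy as np

# Parameters per conv layer scale with C_in * C_out, so reduction ~ 1/r^2.
for r in (2, 4, 8):
    print(f"-r {r}: ~{(1 - 1 / r**2) * 100:.1f}% fewer parameters")
# -r 2: ~75.0% fewer parameters
# -r 4: ~93.8% fewer parameters
# -r 8: ~98.4% fewer parameters

# Temperature smooths the teacher's soft labels: softmax(logits / T).
logits = np.array([4.0, 2.0, 0.5])  # made-up example logits
for T in (1.0, 3.0, 5.0):
    p = np.exp(logits / T) / np.exp(logits / T).sum()
    print(f"-temp {T}: {np.round(p, 3)}")
# Higher T gives a flatter distribution, transferring more inter-class info.
```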
If you have teacher models trained elsewhere, you can specify them using the -t parameter:
# Standard distillation with custom teacher
nnUNetv2_distillation_train -d DATASET_ID -f 0 -t /path/to/teacher/model -a 0.3 -temp 3.0 -r 2
# ResEnc distillation with custom teacher
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -t /path/to/resenc/teacher/model -a 0.3 -temp 3.0 -r 2

Choose different ResEnc architectures based on your computational requirements:
# ResEnc Medium (balanced performance and efficiency)
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -tpl nnUNetResEncUNetMPlans -a 0.3 -temp 3.0 -r 2
# ResEnc Large (higher performance for complex tasks)
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -tpl nnUNetResEncUNetLPlans -a 0.3 -temp 3.0 -r 2
# ResEnc Extra Large (maximum performance)
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -tpl nnUNetResEncUNetXLPlans -a 0.3 -temp 3.0 -r 2

Although student models are usually smaller, multi-GPU training is still supported:
# Standard distillation multi-GPU
CUDA_VISIBLE_DEVICES=0,1 nnUNetv2_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2
# ResEnc distillation multi-GPU
CUDA_VISIBLE_DEVICES=0,1 nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2
# Specify specific GPU device
nnUNetv2_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 -d_device cuda:0
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 -d_device cuda:1
# Multi-GPU training with DA5 for small datasets
CUDA_VISIBLE_DEVICES=0,1 nnUNetv2_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 --use_da5
CUDA_VISIBLE_DEVICES=0,1 nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 --use_da5

DA5 is an advanced data augmentation strategy specifically designed for small datasets and challenging segmentation tasks. It provides stronger augmentation than the default nnUNet pipeline.
- Small datasets: Datasets with fewer than 100 training cases
- Challenging segmentation tasks: Tasks with fine anatomical structures or difficult-to-segment regions
- Limited training data: When you want to maximize the use of available training samples
- Improved robustness: When you need models that generalize better to unseen data variations
- Enhanced spatial transformations: More aggressive rotation, scaling, and elastic deformation
- Reduced interpolation order: Better preservation of fine details during augmentation
- Comprehensive intensity augmentation: Advanced brightness, contrast, and gamma transformations
- Optimized for medical imaging: Specifically tuned for medical image characteristics
# Small dataset scenario (e.g., rare disease with <50 cases)
nnUNetv2_distillation_train -d DATASET_ID -f 0 -a 0.5 -temp 4.0 -r 2 --use_da5
# Challenging anatomy (e.g., small organ segmentation)
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.4 -temp 3.5 -r 2 -tpl nnUNetResEncUNetLPlans --use_da5
# Combining DA5 with other advanced features
nnUNetv2_resenc_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 --use_da5 -rotate_folds -bs adaptive

- **Out of Memory Errors**:
  - Reduce batch size: The trainer will automatically select an appropriate batch size
  - Use a larger feature reduction factor: Try `-r 4` or higher
- **Excessive Accuracy Decline**:
  - Reduce the feature reduction factor: Try `-r 2` or lower
  - Increase training epochs: Use `-e 1500` or more
  - Adjust temperature parameter: Try `-temp 2.0`
- **Training Instability**:
  - Adjust distillation weight: Try `-a 0.4` or `-a 0.2`
  - Ensure teacher model quality: Check the teacher model's performance
- **ResEnc Specific Issues**:
  - Teacher model not found: Ensure the ResEnc teacher models were trained with the correct plans identifier
  - Architecture mismatch: Verify that teacher and student plans are compatible when using ResEnc variants
  - Memory issues with large ResEnc models: Consider ResEncM instead of ResEncL/ResEncXL when GPU memory is limited
- **DA5 Related Issues**:
  - Slower training with DA5: Expected due to the more intensive augmentation; reduce batch size if needed
  - DA5 not improving results: DA5 is most effective on small datasets and may not help on large datasets (>500 cases)
  - Export issues with DA5 models: Pass the `-da5` flag when exporting models trained with DA5
  - Memory issues with DA5: The stronger augmentation may require more GPU memory; reduce batch size or use gradient accumulation
If you use FastnnUNet Knowledge Distillation in your research, please cite:
Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.
