A multi-task framework based on Diff-UNet with integrated LGG/HGG grading for BraTS2020 brain tumor segmentation. This is the final course project for BS6220, Deep Learning in Biomedical Science, at Nanyang Technological University. The full report is available in this repository. The framework combines diffusion-based segmentation with multi-task learning to perform 3D brain tumor segmentation and grade classification simultaneously.
- Diffusion-Based Segmentation
  - High-quality 3D brain tumor segmentation (TC, WT, ET)
  - DDIM sampling with 50 steps for inference
  - Uncertainty quantification
- Multi-Task Learning
  - Shared encoder between segmentation and classification
  - Separate decoder and classification head
  - Multi-stage training strategy
- Mask-Guided Classification
  - Uses segmentation masks as attention weights
  - Focuses on tumor regions for classification
  - Improves the LGG classification F1-score by 5-8%
- Enhanced Classification Head
  - Two-layer fully connected network (512→128→2)
  - ReLU activation with dropout regularization
  - Handles class imbalance with Focal Loss
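The 50-step DDIM sampling listed above works by visiting only a subset of the timesteps used during training. The sketch below shows one common way to choose that subset (a uniform stride). It assumes 1000 training timesteps, which this README does not state, and the function name `ddim_timesteps` is hypothetical rather than taken from this repository.

```python
def ddim_timesteps(num_train_steps: int = 1000, num_sample_steps: int = 50) -> list[int]:
    """Return evenly strided timesteps in descending order (noisiest first).

    num_train_steps=1000 is an assumption; only the 50 sampling steps
    come from this README.
    """
    if not 0 < num_sample_steps <= num_train_steps:
        raise ValueError("num_sample_steps must be in (0, num_train_steps]")
    stride = num_train_steps // num_sample_steps
    steps = list(range(0, stride * num_sample_steps, stride))
    return steps[::-1]


schedule = ddim_timesteps()
print(len(schedule), schedule[0], schedule[-1])  # prints: 50 980 0
```

Each sampling iteration would then denoise from one entry of `schedule` to the next, so inference needs 50 network evaluations instead of one per training timestep.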
DiffUNet
├── Encoder (BasicUNetEncoder)
│   └── Output: Multi-scale features [f0, f1, f2, f3, f4]
│       └── C5 bottleneck: [B, 512, H/16, W/16, D/16]
│
├── Decoder (BasicUNetDe) - Segmentation Branch
│   └── Diffusion Process
│       ├── Forward Diffusion (Training)
│       └── DDIM Sampling (Inference, 50 steps)
│
└── Classification Head - Classification Branch
    ├── Standard: Global Average Pooling → FC(512→128→2)
    └── Mask-Guided: Weighted Pooling → FC(512→128→2)
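The two classification variants in the diagram differ only in how the bottleneck feature map is reduced to a 512-dimensional vector before the fully connected layers. The NumPy sketch below illustrates that difference. It assumes the segmentation mask has already been resized to the bottleneck resolution and takes values in [0, 1]; the function names are hypothetical, and the repository's actual implementation may differ.

```python
import numpy as np


def global_average_pool(features: np.ndarray) -> np.ndarray:
    """Average [B, C, H, W, D] features over the spatial axes -> [B, C]."""
    return features.mean(axis=(2, 3, 4))


def mask_guided_pool(features: np.ndarray, mask: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Mask-weighted average of [B, C, H, W, D] features -> [B, C].

    mask has shape [B, 1, H, W, D]; it is assumed to be the tumor mask
    resized to the feature resolution. eps avoids division by zero when
    the mask is empty.
    """
    weighted_sum = (features * mask).sum(axis=(2, 3, 4))
    total_weight = mask.sum(axis=(2, 3, 4)) + eps
    return weighted_sum / total_weight


rng = np.random.default_rng(0)
feats = rng.normal(size=(2, 512, 8, 8, 8))   # bottleneck of a 128^3 input (128 / 16 = 8)
mask = np.zeros((2, 1, 8, 8, 8))
mask[:, :, 2:5, 2:5, 2:5] = 1.0              # toy tumor region

print(global_average_pool(feats).shape)      # (2, 512)
print(mask_guided_pool(feats, mask).shape)   # (2, 512)
```

With a mask of all ones the weighted pooling reduces to global average pooling, which is why the same FC(512→128→2) head can follow either variant.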
- Download the BraTS2020 dataset from the official website
- Organize the directory structure:
MICCAI_BraTS2020_TrainingData/
├── BraTS20_Training_001/
│   ├── BraTS20_Training_001_t1.nii
│   ├── BraTS20_Training_001_t1ce.nii
│   ├── BraTS20_Training_001_t2.nii
│   ├── BraTS20_Training_001_flair.nii
│   └── BraTS20_Training_001_seg.nii
├── BraTS20_Training_002/
└── ...
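Before training, it can help to confirm that every case folder matches the layout above. The following is a minimal sketch using only the standard library; `list_cases` is a hypothetical helper, not part of this repository, and it assumes the `.nii` file names shown above.

```python
from pathlib import Path

MODALITIES = ("t1", "t1ce", "t2", "flair")


def list_cases(data_dir: str) -> list[dict[str, Path]]:
    """Collect modality and label paths for every complete case folder."""
    cases = []
    for case_dir in sorted(Path(data_dir).glob("BraTS20_Training_*")):
        if not case_dir.is_dir():
            continue
        # Expected files: <case>_<modality>.nii plus <case>_seg.nii
        paths = {m: case_dir / f"{case_dir.name}_{m}.nii" for m in MODALITIES}
        paths["seg"] = case_dir / f"{case_dir.name}_seg.nii"
        missing = [p.name for p in paths.values() if not p.exists()]
        if missing:
            print(f"Skipping {case_dir.name}: missing {missing}")
            continue
        cases.append(paths)
    return cases


# Example usage (path taken from the layout above):
# cases = list_cases("./MICCAI_BraTS2020_TrainingData")
# print(f"{len(cases)} complete cases found")
```

Incomplete folders are reported and skipped rather than raising, so a single missing file does not hide problems in other cases.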
We use a three-stage training approach:
- Stage 1: Segmentation training (freeze classification head)
- Stage 2: Classification training (freeze segmentation network)
- Stage 3: End-to-end fine-tuning (all parameters trainable)
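The freezing pattern in the three stages above can be summarized as a small lookup table. This is a sketch only: the stage keys match the `--stage` values used by `train.py` in this README, but the group names (`encoder`, `decoder`, `cls_head`) are illustrative labels, not identifiers confirmed from the code.

```python
ALL_GROUPS = {"encoder", "decoder", "cls_head"}

# Trainable parameter groups per stage; everything else is frozen.
TRAINABLE_BY_STAGE = {
    "seg": {"encoder", "decoder"},                 # Stage 1: classification head frozen
    "cls": {"cls_head"},                           # Stage 2: segmentation network frozen
    "end2end": {"encoder", "decoder", "cls_head"}, # Stage 3: all parameters trainable
}


def trainable_groups(stage: str) -> set[str]:
    """Return the parameter groups that receive gradients in a stage."""
    try:
        return TRAINABLE_BY_STAGE[stage]
    except KeyError:
        raise ValueError(f"Unknown stage: {stage!r}") from None


for stage in ("seg", "cls", "end2end"):
    frozen = ALL_GROUPS - trainable_groups(stage)
    print(f"{stage}: train={sorted(trainable_groups(stage))}, frozen={sorted(frozen)}")
```

In a PyTorch training script, a table like this would typically be applied by setting `requires_grad` on each group's parameters before building the optimizer.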
Stage 1 goal: Train the diffusion-based segmentation network
python BraTS2020/train.py \
    --stage seg \
    --pretrained_model ./final_model_0.8508.pt \
    --epochs 50 \
    --lr 1e-4 \
    --data_dir ./MICCAI_BraTS2020_TrainingData/

Stage 2 goal: Train the classification head
Standard Classification:
python BraTS2020/train.py \
    --stage cls \
    --pretrained_model ./logs_brats/stage_seg/model/best_model_*.pt \
    --epochs 30 \
    --lr 1e-4 \
    --data_dir ./MICCAI_BraTS2020_TrainingData/

Mask-Guided Classification (Recommended):
python BraTS2020/train.py \
    --stage cls \
    --use_mask_guided \
    --pretrained_model ./logs_brats/stage_seg/model/best_model_*.pt \
    --epochs 30 \
    --lr 1e-4 \
    --data_dir ./MICCAI_BraTS2020_TrainingData/

Configuration:
- Freezes: Segmentation network (encoder + decoder)
- Trains: Classification head
- Input: Whole volume (128×128×128)
- Loss: Focal Loss (gamma=2.0)
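Focal Loss down-weights examples the classifier already gets right, so the minority class (LGG) contributes more to the gradient. The sketch below shows the standard formulation, FL = -(1 - p_t)^gamma * log(p_t), in NumPy with the gamma value listed above. It omits any per-class alpha weighting because this README does not mention one; the function is illustrative and not the repository's implementation.

```python
import numpy as np


def focal_loss(logits: np.ndarray, targets: np.ndarray, gamma: float = 2.0) -> float:
    """Mean focal loss for [N, num_classes] logits and [N] integer targets."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    log_pt = log_probs[np.arange(len(targets)), targets]   # log-probability of the true class
    pt = np.exp(log_pt)
    return float(np.mean(-((1.0 - pt) ** gamma) * log_pt))


logits = np.array([[4.0, -4.0],    # confident and correct (easy example)
                   [0.2, -0.2]])   # uncertain (hard example)
targets = np.array([0, 0])

print(focal_loss(logits, targets, gamma=2.0))
print(focal_loss(logits, targets, gamma=0.0))   # gamma = 0 reduces to cross-entropy
```

Comparing the two printed values shows how much the easy example's contribution shrinks once the modulating factor is applied.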
Stage 3 goal: Jointly optimize segmentation and classification
python BraTS2020/train.py \
    --stage end2end \
    --pretrained_model ./logs_brats/stage_cls/model/best_model_*.pt \
    --epochs 20 \
    --lr 5e-5 \
    --data_dir ./MICCAI_BraTS2020_TrainingData/

With Mask-Guided:
python BraTS2020/train.py \
    --stage end2end \
    --use_mask_guided \
    --pretrained_model ./logs_brats/stage_cls/model/best_model_*.pt \
    --epochs 20 \
    --lr 5e-5 \
    --data_dir ./MICCAI_BraTS2020_TrainingData/

Example predictions from two BraTS patients.
- Left: original MRI slice
- Middle: predicted segmentation
- Right: ground truth mask
- Difference map highlights errors.
The model accurately captures tumor subregions (WT, TC, ET) with clear boundaries.
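The three subregions are conventionally scored with the Dice coefficient after merging the BraTS label values (1 = necrotic/non-enhancing core, 2 = edema, 4 = enhancing tumor) into WT, TC, and ET. A minimal NumPy sketch of that evaluation follows; it assumes predictions and ground truth are integer label volumes in the original BraTS encoding, and `region_dice` is a hypothetical helper rather than this repository's evaluation code.

```python
import numpy as np

# Standard BraTS region definitions in terms of label values.
REGIONS = {"WT": (1, 2, 4), "TC": (1, 4), "ET": (4,)}


def region_dice(pred: np.ndarray, truth: np.ndarray) -> dict[str, float]:
    """Dice score per region for two integer label volumes of equal shape."""
    scores = {}
    for name, labels in REGIONS.items():
        p = np.isin(pred, labels)
        t = np.isin(truth, labels)
        denom = p.sum() + t.sum()
        # Convention used here: a region absent from both volumes scores 1.0.
        scores[name] = 1.0 if denom == 0 else float(2.0 * np.logical_and(p, t).sum() / denom)
    return scores


truth = np.array([1, 2, 4, 0])
pred = np.array([1, 2, 0, 0])   # the enhancing-tumor voxel is missed
print(region_dice(pred, truth))
```

The toy example works on a flat array for brevity; the same function applies unchanged to full 3D volumes.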
This work is based on the following research:
Reference Paper: Diff-UNet: A Diffusion Embedded Network for Volumetric Segmentation.
Authors: Zhaohu Xing, Liang Wan, Huazhu Fu, Guang Yang, and Lei Zhu.
