
Multi-Task Brain Tumor Segmentation and Grading

A multi-task framework based on Diff-UNet that integrates LGG/HGG grading with BraTS2020 brain tumor segmentation. This project is the final course project for BS6220, Deep Learning in Biomedical Science, at Nanyang Technological University; the full report is included in this repository. The framework combines diffusion-based segmentation with multi-task learning to perform 3D brain tumor segmentation and grade classification simultaneously.


Core Capabilities

  1. Diffusion-Based Segmentation

    • High-quality 3D segmentation of tumor core (TC), whole tumor (WT), and enhancing tumor (ET)
    • DDIM sampling with 50 steps for inference
    • Uncertainty quantification
  2. Multi-Task Learning

    • Shared encoder between segmentation and classification
    • Separate decoder and classification head
    • Multi-stage training strategy
  3. Mask-Guided Classification

    • Leverages segmentation masks as attention weights
    • Focuses on tumor regions for classification
    • Improves LGG classification F1-score by 5-8%
  4. Enhanced Classification Head

    • Two-layer fully connected network (512→128→2)
    • ReLU activation with dropout regularization
    • Handles class imbalance with Focal Loss
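The 512→128→2 head described above can be sketched as a plain forward pass. This is a minimal numpy illustration, not the repository's implementation: the function name and weight shapes are assumptions, and dropout is omitted because it is only active during training.

```python
import numpy as np

def classification_head(features, w1, b1, w2, b2):
    """Sketch of the two-layer head: GAP -> FC(512->128) -> ReLU -> FC(128->2).

    features: [512, H, W, D] bottleneck feature volume for one sample.
    Dropout (training-only) is omitted from this inference-time sketch.
    """
    pooled = features.mean(axis=(1, 2, 3))      # global average pooling -> [512]
    hidden = np.maximum(pooled @ w1 + b1, 0.0)  # FC 512->128 with ReLU
    logits = hidden @ w2 + b2                   # FC 128->2 (LGG vs HGG logits)
    return logits
```

At inference the two logits are typically passed through a softmax to obtain LGG/HGG probabilities.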

Overall Structure

DiffUNet
├── Encoder (BasicUNetEncoder)
│   └── Output: Multi-scale features [f0, f1, f2, f3, f4]
│       └── C5 bottleneck: [B, 512, H/16, W/16, D/16]
│
├── Decoder (BasicUNetDe) - Segmentation Branch
│   └── Diffusion Process
│       ├── Forward Diffusion (Training)
│       └── DDIM Sampling (Inference, 50 steps)
│
└── Classification Head - Classification Branch
    ├── Standard: Global Average Pooling → FC(512→128→2)
    └── Mask-Guided: Weighted Pooling → FC(512→128→2)
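The difference between the two pooling branches above is only in how spatial locations are weighted. The following numpy sketch shows the mask-guided variant under the assumption (not stated in the code itself) that the predicted mask has been downsampled to the bottleneck resolution; function and argument names are hypothetical.

```python
import numpy as np

def mask_guided_pool(features, mask, eps=1e-6):
    """Weighted pooling: average bottleneck features using the (downsampled)
    predicted tumor mask as spatial attention weights.

    features: [C, H, W, D]; mask: [H, W, D] with values in [0, 1].
    eps keeps the division stable when the mask is near-empty.
    """
    w = mask[None]                                    # broadcast to [1, H, W, D]
    pooled = (features * w).sum(axis=(1, 2, 3)) / (w.sum() + eps)
    return pooled                                     # [C] tumor-focused descriptor
```

With a uniform mask this reduces to standard global average pooling, which is why the two branches share the same FC(512→128→2) head.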

BraTS2020 Dataset

  1. Download the BraTS2020 dataset from the official website

  2. Organize directory structure:

MICCAI_BraTS2020_TrainingData/
├── BraTS20_Training_001/
│   ├── BraTS20_Training_001_t1.nii
│   ├── BraTS20_Training_001_t1ce.nii
│   ├── BraTS20_Training_001_t2.nii
│   ├── BraTS20_Training_001_flair.nii
│   └── BraTS20_Training_001_seg.nii
├── BraTS20_Training_002/
└── ...
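Before training, it can help to verify that every case folder follows the layout above. This is a small stdlib-only helper sketch (the function name is hypothetical, not part of the repository):

```python
from pathlib import Path

MODALITIES = ("t1", "t1ce", "t2", "flair", "seg")

def check_brats_case(case_dir):
    """Return the list of modality suffixes whose .nii file is missing for one
    case directory, e.g. MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/."""
    case_dir = Path(case_dir)
    return [m for m in MODALITIES
            if not (case_dir / f"{case_dir.name}_{m}.nii").exists()]
```

Iterating `check_brats_case` over every `BraTS20_Training_*` subdirectory and reporting non-empty results catches incomplete downloads early.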

Three-Stage Training Strategy

We use a three-stage training approach:

  1. Stage 1: Segmentation training (freeze classification head)
  2. Stage 2: Classification training (freeze segmentation network)
  3. Stage 3: End-to-end fine-tuning (all parameters trainable)
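The freeze/unfreeze schedule above amounts to selecting which parameter groups receive gradients in each stage. A minimal sketch of that mapping, assuming the model's parameters split into hypothetical `encoder`, `decoder`, and `cls_head` groups (in PyTorch this would translate to setting `requires_grad` per parameter):

```python
# Trainable parameter groups per stage (component names are assumptions).
STAGE_TRAINABLE = {
    "seg":     {"encoder", "decoder"},              # classification head frozen
    "cls":     {"cls_head"},                        # segmentation network frozen
    "end2end": {"encoder", "decoder", "cls_head"},  # everything fine-tuned
}

def set_trainable(param_groups, stage):
    """Map {param_name: group} to {param_name: trainable?} for a given stage."""
    trainable = STAGE_TRAINABLE[stage]
    return {name: group in trainable for name, group in param_groups.items()}
```

Note the learning-rate schedule in the commands below follows the same logic: a smaller rate (5e-5) is used in the end-to-end stage so joint fine-tuning does not destroy the stage-wise pretraining.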

Stage 1: Segmentation Training

Goal: Train the diffusion-based segmentation network

python BraTS2020/train.py \
    --stage seg \
    --pretrained_model ./final_model_0.8508.pt \
    --epochs 50 \
    --lr 1e-4 \
    --data_dir ./MICCAI_BraTS2020_TrainingData/

Stage 2: Classification Training

Goal: Train the classification head

Standard Classification:

python BraTS2020/train.py \
    --stage cls \
    --pretrained_model ./logs_brats/stage_seg/model/best_model_*.pt \
    --epochs 30 \
    --lr 1e-4 \
    --data_dir ./MICCAI_BraTS2020_TrainingData/

Mask-Guided Classification (Recommended):

python BraTS2020/train.py \
    --stage cls \
    --use_mask_guided \
    --pretrained_model ./logs_brats/stage_seg/model/best_model_*.pt \
    --epochs 30 \
    --lr 1e-4 \
    --data_dir ./MICCAI_BraTS2020_TrainingData/

Configuration:

  • Freezes: Segmentation network (encoder + decoder)
  • Trains: Classification head
  • Input: Whole volume (128×128×128)
  • Loss: Focal Loss (gamma=2.0)
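The Focal Loss named above down-weights easy examples so the rarer LGG class is not drowned out. A minimal binary sketch (the `alpha` class weight is an assumption; the repository may or may not use one):

```python
import math

def focal_loss(p, y, gamma=2.0, alpha=0.5, eps=1e-7):
    """Binary focal loss for one sample: -alpha_t * (1 - p_t)**gamma * log(p_t).

    p: predicted probability of the positive class (e.g. HGG); y: label in {0, 1}.
    Confident correct predictions (p_t near 1) are down-weighted by the
    (1 - p_t)**gamma factor, which with gamma=2.0 focuses training on hard cases.
    """
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t + eps)
```

Setting `gamma=0` recovers (weighted) cross-entropy, which is a quick sanity check on any implementation.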

Stage 3: End-to-End Fine-tuning

Goal: Jointly optimize segmentation and classification

python BraTS2020/train.py \
    --stage end2end \
    --pretrained_model ./logs_brats/stage_cls/model/best_model_*.pt \
    --epochs 20 \
    --lr 5e-5 \
    --data_dir ./MICCAI_BraTS2020_TrainingData/

With Mask-Guided:

python BraTS2020/train.py \
    --stage end2end \
    --use_mask_guided \
    --pretrained_model ./logs_brats/stage_cls/model/best_model_*.pt \
    --epochs 20 \
    --lr 5e-5 \
    --data_dir ./MICCAI_BraTS2020_TrainingData/

Segmentation Visualization

Example predictions from two BraTS patients.

  • Left: original MRI slice
  • Middle: predicted segmentation
  • Right: ground truth mask
  • A difference map highlights prediction errors.

Segmentation Example

The model accurately captures tumor subregions (WT, TC, ET) with clear boundaries.


Acknowledgments

This work is based on the following research:

Reference paper: Zhaohu Xing, Liang Wan, Huazhu Fu, Guang Yang, and Lei Zhu. "Diff-UNet: A Diffusion Embedded Network for Volumetric Segmentation."
