
🦷 3D Tooth Segmentation Benchmark

A unified framework for training, evaluating, and benchmarking
point cloud segmentation methods on intraoral 3D dental scans.

Models • Install • Quick Start • Training • Evaluation • Extend • Contributing


📋 Overview

This repository provides a modular, extensible benchmark for 3D tooth segmentation from intraoral scans. It includes:

  • 18+ model implementations spanning dental-specific architectures, general point cloud backbones, and self-supervised pre-training methods
  • A model registry with @register_model decorators β€” adding a new architecture requires zero changes to the training loop
  • Unified train.py and test.py scripts that handle all models through a single CLI
  • YAML-based configuration with sensible defaults and per-model overrides
  • Copy-and-customize templates for new models and datasets

Attribution: This codebase is built upon ToothGroupNetwork by Team CGIP. We thank the original authors for their excellent baseline and data processing pipeline.


πŸ— Architecture

boilerplate_segmentation/
β”œβ”€β”€ train.py                    # Unified training entry point (all models)
β”œβ”€β”€ test.py                     # Unified evaluation entry point (all models)
β”œβ”€β”€ smoke_test.py               # Pipeline verification (no GPU/data needed)
β”œβ”€β”€ config_loader.py            # YAML + legacy Python config loader
β”‚
β”œβ”€β”€ configs/                    # YAML configuration files
β”‚   β”œβ”€β”€ default.yaml            #   Base config (inherited by all)
β”‚   β”œβ”€β”€ dgcnn.yaml              #   Per-model overrides
β”‚   β”œβ”€β”€ pointnet.yaml
β”‚   └── ...
β”‚
β”œβ”€β”€ models/                     # Model wrappers + modules
β”‚   β”œβ”€β”€ registry.py             #   @register_model decorator & factory
β”‚   β”œβ”€β”€ base_model.py           #   Abstract base class
β”‚   β”œβ”€β”€ new_model_template.py   #   Template for adding new models
β”‚   β”œβ”€β”€ dgcnn_model.py          #   DGCNN wrapper
β”‚   β”œβ”€β”€ pointnet_model.py       #   PointNet wrapper
β”‚   β”œβ”€β”€ ...                     #   (18 model wrappers total)
β”‚   └── modules/                #   Neural network implementations
β”‚       β”œβ”€β”€ dgcnn_module.py
β”‚       └── ...
β”‚
β”œβ”€β”€ datasets/                   # Dataset classes
β”‚   β”œβ”€β”€ base_dataset.py         #   Abstract base dataset
β”‚   β”œβ”€β”€ dental_dataset.py       #   Dental scan dataset
β”‚   └── new_dataset_template.py #   Template for adding new datasets
β”‚
β”œβ”€β”€ train_configs/              # Legacy Python configs (backward compat)
β”œβ”€β”€ inference_pipelines/        # Model-specific inference pipelines
β”œβ”€β”€ external_libs/              # PointNet2, PointOps CUDA extensions
β”‚
β”œβ”€β”€ generator.py                # Legacy dataset (backward compat)
β”œβ”€β”€ trainer.py                  # Training loop implementation
β”œβ”€β”€ runner.py                   # Training orchestrator
β”œβ”€β”€ loss_meter.py               # Loss aggregation utilities
β”œβ”€β”€ augmentator.py              # Point cloud augmentations
β”œβ”€β”€ gen_utils.py                # General utilities
β”œβ”€β”€ ops_utils.py                # Point cloud operations
β”œβ”€β”€ preprocess_data.py          # Raw mesh β†’ preprocessed .npy
β”œβ”€β”€ eval_visualize_results.py   # Metric computation & visualization
└── predict_utils.py            # Inference utilities

🧠 Supported Models

| Category | Model | Registry Name | Reference |
|---|---|---|---|
| Our Method | DentalMAE (Pretrain) | dental_mae_pretrain | — |
| | DentalMAE-Seg (Fine-tune) | dental_mae_seg | — |
| Dental-Specific | TGNet-FPS (Challenge Winner) | tgnet_fps | ToothGroupNetwork |
| | TGNet-BDL | tgnet_bdl | ToothGroupNetwork |
| | TSegNet | tsegnet | Paper |
| | TSegFormer | tsegformer | Paper |
| | MeshSegNet | meshsegnet | Paper |
| | TeethGNN | teethgnn | — |
| | HiCA | hica | — |
| | SGTNet | sgtnet | — |
| | SGTCNet | sgtcnet | — |
| | TSGCNet | tsgcnet | — |
| | Fast TGCN | fast_tgcn | — |
| | UpToothSeg | uptoothseg | — |
| | Dilated Tooth Seg | dilated_tooth_seg | — |
| General Backbones | PointNet | pointnet | Paper |
| | PointNet++ | pointnetpp | Paper |
| | DGCNN | dgcnn | Paper |
| | Point Transformer | pointtransformer | Paper |

List all models: python train.py --list_models
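Several registry names above (e.g. tgnet_fps) refer to farthest point sampling (FPS), a standard way to subsample a point cloud while keeping good spatial coverage. As a rough illustration only (the repository ships optimized CUDA kernels for this in external_libs), a minimal NumPy version looks like:

```python
import numpy as np

def farthest_point_sampling(points: np.ndarray, k: int) -> np.ndarray:
    """Greedily pick k point indices, each maximally far from those chosen so far."""
    n = points.shape[0]
    chosen = np.zeros(k, dtype=np.int64)
    # Distance from every point to its nearest already-chosen point.
    dist = np.full(n, np.inf)
    chosen[0] = 0  # deterministic start; a random start point is also common
    for i in range(1, k):
        d = np.linalg.norm(points - points[chosen[i - 1]], axis=1)
        dist = np.minimum(dist, d)
        chosen[i] = int(np.argmax(dist))  # farthest point from the chosen set
    return chosen

# Usage: subsample 1024 points from a synthetic 24000-point cloud
cloud = np.random.rand(24000, 3)
idx = farthest_point_sampling(cloud, 1024)
sampled = cloud[idx]
```

This O(n·k) loop is far slower than the CUDA implementation but produces the same kind of index set.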


⚙ Installation

Requirements

  • Python 3.8+
  • PyTorch 1.7.1+ with CUDA 11.0+
  • Ubuntu 18.04+ (tested)

Setup

# 1. Clone the repository
git clone https://github.com/your-username/tooth-segmentation-benchmark.git
cd tooth-segmentation-benchmark

# 2. Create a virtual environment (recommended)
conda create -n tooth_seg python=3.10 -y
conda activate tooth_seg

# 3. Install PyTorch (adjust for your CUDA version)
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118

# 4. Install dependencies
pip install wandb open3d multimethod termcolor trimesh easydict pyyaml scikit-learn

# 5. Install PointOps (required for Point Transformer & DentalMAE)
cd external_libs/pointops && python setup.py install && cd ../..

# 6. Verify installation
python smoke_test.py

🚀 Quick Start

1. Preprocess Raw Data

Convert raw .obj meshes to preprocessed .npy point clouds:

python preprocess_data.py \
  --source_obj_data_path data_obj_parent_directory \
  --source_json_data_path data_json_parent_directory \
  --save_data_path path/to/preprocessed_data
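At its core, this step reads each mesh and stores its points as a NumPy array. As a bare-bones illustration only (the real preprocess_data.py also handles labels, normals, and sampling), converting the vertex lines of a Wavefront .obj to .npy can be sketched as:

```python
import os
import tempfile
import numpy as np

def obj_to_npy(obj_path: str, npy_path: str) -> np.ndarray:
    """Parse 'v x y z' vertex lines from a Wavefront .obj and save them as .npy."""
    vertices = []
    with open(obj_path) as f:
        for line in f:
            if line.startswith("v "):  # vertex positions only; skip vn/vt/f lines
                vertices.append([float(x) for x in line.split()[1:4]])
    points = np.asarray(vertices, dtype=np.float32)  # shape [N, 3]
    np.save(npy_path, points)
    return points

# Usage: convert a tiny two-vertex example mesh
tmp_dir = tempfile.mkdtemp()
obj_path = os.path.join(tmp_dir, "scan.obj")
with open(obj_path, "w") as f:
    f.write("v 0.0 0.0 0.0\nv 1.0 2.0 3.0\nf 1 2 1\n")
points = obj_to_npy(obj_path, os.path.join(tmp_dir, "scan.npy"))
```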

2. Train a Model

# Train DGCNN with YAML config
python train.py \
  --model_name dgcnn \
  --config configs/dgcnn.yaml \
  --experiment_name dgcnn_baseline \
  --data_dir path/to/preprocessed_data \
  --epochs 200

# Train with WandB logging
python train.py \
  --model_name pointnet \
  --config configs/pointnet.yaml \
  --experiment_name pointnet_v1 \
  --data_dir path/to/preprocessed_data

3. Evaluate

python test.py \
  --model_name dgcnn \
  --config configs/dgcnn.yaml \
  --checkpoint ckpts/dgcnn_baseline_val.h5 \
  --data_dir path/to/preprocessed_data \
  --test_split base_name_test_fold.txt \
  --save_predictions results/dgcnn/

πŸ‹ Training

Unified Training Script

The train.py script handles all models via the model registry:

python train.py --model_name <MODEL> --config <CONFIG> [OPTIONS]
| Argument | Description | Default |
|---|---|---|
| --model_name | Model name from registry | dgcnn |
| --config | Path to .yaml or .py config | configs/dgcnn.yaml |
| --experiment_name | Name for checkpoints & logging | experiment |
| --data_dir | Preprocessed data directory | — |
| --train_split | Train split txt file | base_name_train_fold.txt |
| --val_split | Validation split txt file | base_name_val_fold.txt |
| --epochs | Number of epochs (overrides config) | 200 |
| --batch_size | Batch size (overrides config) | 1 |
| --lr | Learning rate (overrides config) | — |
| --resume | Checkpoint path to resume from | — |
| --wandb_off | Disable WandB logging | False |
| --device | Force device (cuda/cpu) | auto |
| --val_every | Validate every N epochs | 1 |

DentalMAE Two-Phase Training

# Phase 1: Self-supervised pre-training
python train.py \
  --model_name dental_mae_pretrain \
  --config configs/default.yaml \
  --experiment_name mae_pretrain \
  --data_dir path/to/unlabeled_data

# Phase 2: Supervised fine-tuning
python train.py \
  --model_name dental_mae_seg \
  --config configs/default.yaml \
  --experiment_name mae_finetune \
  --data_dir path/to/labeled_data

Legacy Training Script

The original start_train.py is still available for backward compatibility:

python start_train.py \
  --model_name dgcnn \
  --config_path train_configs/dgcnn.py \
  --experiment_name dgcnn_exp \
  ...

📊 Evaluation

Unified Test Script

python test.py --model_name <MODEL> --config <CONFIG> --checkpoint <CKPT> [OPTIONS]
| Argument | Description |
|---|---|
| --checkpoint | Path to model checkpoint (.h5) |
| --test_split | Test split txt file |
| --save_predictions | Directory to save prediction JSONs |
| --num_classes | Number of classes (default: 17) |

Output metrics:

  • Mean IoU (Intersection over Union)
  • Mean F1 (Dice Score)
  • Accuracy (overall and per-class)
  • Per-class IoU and F1 breakdown

Visualization

python eval_visualize_results.py \
  --mesh_path path/to/obj_file \
  --gt_json_path path/to/gt_json_file \
  --pred_json_path path/to/predicted_json_file

🔧 Adding New Models

Step 1: Create the Model Wrapper

cp models/new_model_template.py models/my_model_model.py

Edit my_model_model.py — fill in the TODO markers:

from models.registry import register_model
from models.base_model import BaseModel

@register_model("my_model")          # ← Name used in CLI
class MyModel(BaseModel):
    def __init__(self, config, module=None):
        if module is None:
            module = MyModule            # ← Your nn.Module class
        super().__init__(config, module)

    def get_loss(self, outputs, gt):
        # ← Define your loss computation
        ...

    def step(self, batch_idx, batch_item, phase="train"):
        # ← Forward pass + optimization
        ...

Step 2: Create a Config

cp configs/dgcnn.yaml configs/my_model.yaml
# Edit optimizer, scheduler, model_parameter as needed

Step 3: Train

python train.py --model_name my_model --config configs/my_model.yaml ...

That's it. No changes to train.py, test.py, or any other file. The @register_model decorator handles everything.


📦 Adding New Datasets

cp datasets/new_dataset_template.py datasets/my_dataset.py

Your dataset must inherit from BaseSegDataset and return a dict with:

| Key | Shape | Description |
|---|---|---|
| feat | [C, N] | Point features (channels-first) |
| gt_seg_label | [1, N] | Segmentation labels |
| category | [2] | One-hot jaw category |
| mesh_path | str | Source file path |
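To illustrate that contract (with hypothetical names; the real base class lives in datasets/base_dataset.py), a minimal __getitem__ might assemble the dict like this, using NumPy arrays as stand-ins for tensors:

```python
import numpy as np

class ToyDentalDataset:
    """Minimal stand-in honoring the BaseSegDataset output contract."""

    def __init__(self, num_points: int = 2048, num_feat: int = 6):
        self.num_points = num_points
        self.num_feat = num_feat

    def __len__(self):
        return 4  # a real dataset would index its split file here

    def __getitem__(self, idx: int) -> dict:
        n, c = self.num_points, self.num_feat
        return {
            "feat": np.random.rand(c, n).astype(np.float32),   # [C, N], channels-first
            "gt_seg_label": np.zeros((1, n), dtype=np.int64),  # [1, N] per-point labels
            "category": np.array([1, 0], dtype=np.float32),    # one-hot jaw (e.g. lower)
            "mesh_path": f"scan_{idx}.obj",                    # provenance for debugging
        }

# Usage: one sample, shaped as the trainer expects
item = ToyDentalDataset()[0]
```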

πŸ“ Dataset

Labeled Dataset (Training & Evaluation)

We use the 3DTeethSeg'22 Challenge dataset.

data_obj_parent_directory/
├── 00OMSZGW/
│   ├── 00OMSZGW_lower.obj
│   └── 00OMSZGW_upper.obj
└── ...

data_json_parent_directory/
├── 00OMSZGW/
│   ├── 00OMSZGW_lower.json
│   └── 00OMSZGW_upper.json
└── ...

Unlabeled Dataset (DentalMAE Pre-training)

For self-supervised pre-training, download the unlabeled scans from OneDrive.


🧪 Testing the Pipeline

Run the smoke test to verify the entire pipeline without real data or GPU:

python smoke_test.py

This tests: model registry (18 models), YAML config loading & merging, dataset instantiation with dummy data, DataLoader batching, loss infrastructure, and CLI script importability.


🀝 Contributing

We welcome contributions! Here's how to get started:

Workflow

  1. Fork the repository
  2. Create a branch for your feature: git checkout -b feature/my-new-model
  3. Implement your changes (see Adding New Models)
  4. Test with python smoke_test.py
  5. Submit a Pull Request with a clear description

Contribution Ideas

| Area | Difficulty | Description |
|---|---|---|
| 🧠 New Model | ⭐⭐ | Add a new segmentation architecture |
| 📊 New Dataset | ⭐⭐ | Add support for another dental dataset |
| 📈 Metrics | ⭐ | Add boundary F1, Hausdorff distance |
| 🔍 Visualization | ⭐⭐ | Add 3D prediction visualization tools |
| 📝 Documentation | ⭐ | Improve docstrings and examples |
| ⚡ Performance | ⭐⭐⭐ | Multi-GPU training, mixed precision |

Code Style

  • Follow PEP 8 conventions
  • Add docstrings to all public functions and classes
  • Use the model registry β€” never add if/elif chains for new models
  • Write a corresponding YAML config for every new model
  • Ensure python smoke_test.py passes before submitting

📖 Configuration

YAML Config Structure

# configs/my_model.yaml — Inherits from configs/default.yaml

tr_set:
  optimizer:
    NAME: "adam"         # adam | sgd | adamw
    lr: 0.001
    weight_decay: 0.0001
  scheduler:
    sched: "cosine"      # cosine | exp | step
    full_steps: 40
    min_lr: 0.00001

model_parameter:
  input_feat: 6          # XYZ + normals
  num_classes: 17         # 16 teeth + gingiva

training:
  epochs: 200
  val_every: 1

wandb:
  wandb_on: false
  project: "tooth_seg"

CLI arguments override YAML values (e.g. --lr 0.0005 overrides optimizer.lr).
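The override order (defaults, then the model YAML, then CLI flags) amounts to a recursive dictionary merge. A sketch of the idea using plain dicts (the actual config_loader.py may differ in detail):

```python
def deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict where override wins, recursing into nested dicts."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)  # merge subtrees
        else:
            merged[key] = value  # scalar or new key: override replaces base
    return merged

# Usage: defaults <- per-model YAML <- CLI flags
default_cfg = {"tr_set": {"optimizer": {"NAME": "adam", "lr": 0.001}},
               "training": {"epochs": 200}}
model_cfg = {"tr_set": {"optimizer": {"lr": 0.0005}}}   # e.g. from configs/my_model.yaml
cli_cfg = {"training": {"epochs": 50}}                  # e.g. from --epochs 50

cfg = deep_merge(deep_merge(default_cfg, model_cfg), cli_cfg)
# cfg["tr_set"]["optimizer"] == {"NAME": "adam", "lr": 0.0005}
```

Note that untouched keys (NAME here) survive each merge, which is what lets per-model YAMLs stay short.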


📄 License

This project is released for academic and research purposes. Please refer to the original ToothGroupNetwork repository for license details.


📚 References

  • ToothGroupNetwork: Lim, H., et al. "3D Dental Segmentation via Tooth Group Network." MICCAI 2022. GitHub
  • 3DTeethSeg Challenge: grand-challenge.org
  • PointNet: Qi, C.R., et al. "PointNet: Deep Learning on Point Sets." CVPR 2017.
  • PointNet++: Qi, C.R., et al. "PointNet++: Deep Hierarchical Feature Learning." NeurIPS 2017.
  • DGCNN: Wang, Y., et al. "Dynamic Graph CNN for Learning on Point Clouds." TOG 2019.
  • Point Transformer: Zhao, H., et al. "Point Transformer." ICCV 2021.

πŸ™ Acknowledgements

This codebase is built upon the excellent work of Team CGIP's ToothGroupNetwork. We are grateful for their open-source contribution which made this benchmark possible.


If you find this work useful, please consider giving it a ⭐
