A unified framework for training, evaluating, and benchmarking
point cloud segmentation methods on intraoral 3D dental scans.
Models • Install • Quick Start • Training • Evaluation • Extend • Contributing
This repository provides a modular, extensible benchmark for 3D tooth segmentation from intraoral scans. It includes:
- 18+ model implementations spanning dental-specific architectures, general point cloud backbones, and self-supervised pre-training methods
- A model registry with `@register_model` decorators — adding a new architecture requires zero changes to the training loop (sketched below)
- Unified `train.py` and `test.py` scripts that handle all models through a single CLI
- YAML-based configuration with sensible defaults and per-model overrides
- Copy-and-customize templates for new models and datasets
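For orientation, the registry boils down to a decorator that populates a name → class map which `train.py` queries at startup. A minimal sketch of the pattern (illustrative only — see `models/registry.py` for the actual implementation):

```python
# Minimal sketch of the registry pattern; names below are illustrative,
# not the repository's exact API (see models/registry.py).
_REGISTRY = {}

def register_model(name):
    """Decorator that maps a CLI name to a model class."""
    def wrapper(cls):
        _REGISTRY[name] = cls
        return cls
    return wrapper

def build_model(name, config):
    """Factory: look up a registered class and instantiate it."""
    if name not in _REGISTRY:
        raise KeyError(f"Unknown model '{name}'. Available: {sorted(_REGISTRY)}")
    return _REGISTRY[name](config)
```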
Attribution: This codebase is built upon ToothGroupNetwork by Team CGIP. We thank the original authors for their excellent baseline and data processing pipeline.
```
boilerplate_segmentation/
├── train.py                     # Unified training entry point (all models)
├── test.py                      # Unified evaluation entry point (all models)
├── smoke_test.py                # Pipeline verification (no GPU/data needed)
├── config_loader.py             # YAML + legacy Python config loader
│
├── configs/                     # YAML configuration files
│   ├── default.yaml             # Base config (inherited by all)
│   ├── dgcnn.yaml               # Per-model overrides
│   ├── pointnet.yaml
│   └── ...
│
├── models/                      # Model wrappers + modules
│   ├── registry.py              # @register_model decorator & factory
│   ├── base_model.py            # Abstract base class
│   ├── new_model_template.py    # Template for adding new models
│   ├── dgcnn_model.py           # DGCNN wrapper
│   ├── pointnet_model.py        # PointNet wrapper
│   ├── ...                      # (18 model wrappers total)
│   └── modules/                 # Neural network implementations
│       ├── dgcnn_module.py
│       └── ...
│
├── datasets/                    # Dataset classes
│   ├── base_dataset.py          # Abstract base dataset
│   ├── dental_dataset.py        # Dental scan dataset
│   └── new_dataset_template.py  # Template for adding new datasets
│
├── train_configs/               # Legacy Python configs (backward compat)
├── inference_pipelines/         # Model-specific inference pipelines
├── external_libs/               # PointNet2, PointOps CUDA extensions
│
├── generator.py                 # Legacy dataset (backward compat)
├── trainer.py                   # Training loop implementation
├── runner.py                    # Training orchestrator
├── loss_meter.py                # Loss aggregation utilities
├── augmentator.py               # Point cloud augmentations
├── gen_utils.py                 # General utilities
├── ops_utils.py                 # Point cloud operations
├── preprocess_data.py           # Raw mesh → preprocessed .npy
├── eval_visualize_results.py    # Metric computation & visualization
└── predict_utils.py             # Inference utilities
```
| Category | Model | Registry Name | Reference |
|---|---|---|---|
| Our Method | DentalMAE (Pretrain) | `dental_mae_pretrain` | — |
| | DentalMAE-Seg (Fine-tune) | `dental_mae_seg` | — |
| Dental-Specific | TGNet-FPS (Challenge Winner) | `tgnet_fps` | ToothGroupNetwork |
| | TGNet-BDL | `tgnet_bdl` | ToothGroupNetwork |
| | TSegNet | `tsegnet` | Paper |
| | TSegFormer | `tsegformer` | Paper |
| | MeshSegNet | `meshsegnet` | Paper |
| | TeethGNN | `teethgnn` | — |
| | HiCA | `hica` | — |
| | SGTNet | `sgtnet` | — |
| | SGTCNet | `sgtcnet` | — |
| | TSGCNet | `tsgcnet` | — |
| | Fast TGCN | `fast_tgcn` | — |
| | UpToothSeg | `uptoothseg` | — |
| | Dilated Tooth Seg | `dilated_tooth_seg` | — |
| General Backbones | PointNet | `pointnet` | Paper |
| | PointNet++ | `pointnetpp` | Paper |
| | DGCNN | `dgcnn` | Paper |
| | Point Transformer | `pointtransformer` | Paper |
List all models:
```bash
python train.py --list_models
```
- Python 3.8+
- PyTorch 1.7.1+ with CUDA 11.0+
- Ubuntu 18.04+ (tested)
```bash
# 1. Clone the repository
git clone https://github.com/your-username/tooth-segmentation-benchmark.git
cd tooth-segmentation-benchmark

# 2. Create a virtual environment (recommended)
conda create -n tooth_seg python=3.10 -y
conda activate tooth_seg

# 3. Install PyTorch (adjust for your CUDA version)
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118

# 4. Install dependencies
pip install wandb open3d multimethod termcolor trimesh easydict pyyaml scikit-learn

# 5. Install PointOps (required for Point Transformer & DentalMAE)
cd external_libs/pointops && python setup.py install && cd ../..

# 6. Verify installation
python smoke_test.py
```

Convert raw `.obj` meshes to preprocessed `.npy` point clouds:
```bash
python preprocess_data.py \
    --source_obj_data_path data_obj_parent_directory \
    --source_json_data_path data_json_parent_directory \
    --save_data_path path/to/preprocessed_data
```
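If you need to adapt preprocessing to other scan formats, the core mesh → `.npy` step is conceptually simple. A hedged sketch using `trimesh` (the real `preprocess_data.py` additionally handles the label JSONs, sampling, and normalization):

```python
# Conceptual sketch of the mesh -> .npy conversion; illustrative only.
import numpy as np
import trimesh

# process=False preserves the original vertex order; a multi-geometry OBJ
# would return a Scene instead of a Trimesh and need extra handling.
mesh = trimesh.load("00OMSZGW_lower.obj", process=False)
points = np.asarray(mesh.vertices, dtype=np.float32)         # (N, 3) XYZ
normals = np.asarray(mesh.vertex_normals, dtype=np.float32)  # (N, 3) normals
feat = np.concatenate([points, normals], axis=1)             # (N, 6) = input_feat
np.save("00OMSZGW_lower.npy", feat)
```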
```bash
# Train DGCNN with YAML config
python train.py \
    --model_name dgcnn \
    --config configs/dgcnn.yaml \
    --experiment_name dgcnn_baseline \
    --data_dir path/to/preprocessed_data \
    --epochs 200

# Train with WandB logging
python train.py \
    --model_name pointnet \
    --config configs/pointnet.yaml \
    --experiment_name pointnet_v1 \
    --data_dir path/to/preprocessed_data
```
```bash
python test.py \
    --model_name dgcnn \
    --config configs/dgcnn.yaml \
    --checkpoint ckpts/dgcnn_baseline_val.h5 \
    --data_dir path/to/preprocessed_data \
    --test_split base_name_test_fold.txt \
    --save_predictions results/dgcnn/
```

The `train.py` script handles all models via the model registry:
```bash
python train.py --model_name <MODEL> --config <CONFIG> [OPTIONS]
```

| Argument | Description | Default |
|---|---|---|
| `--model_name` | Model name from registry | `dgcnn` |
| `--config` | Path to `.yaml` or `.py` config | `configs/dgcnn.yaml` |
| `--experiment_name` | Name for checkpoints & logging | `experiment` |
| `--data_dir` | Preprocessed data directory | — |
| `--train_split` | Train split txt file | `base_name_train_fold.txt` |
| `--val_split` | Validation split txt file | `base_name_val_fold.txt` |
| `--epochs` | Number of epochs (overrides config) | `200` |
| `--batch_size` | Batch size (overrides config) | `1` |
| `--lr` | Learning rate (overrides config) | — |
| `--resume` | Checkpoint path to resume from | — |
| `--wandb_off` | Disable WandB logging | `False` |
| `--device` | Force device (`cuda`/`cpu`) | `auto` |
| `--val_every` | Validate every N epochs | `1` |
```bash
# Phase 1: Self-supervised pre-training
python train.py \
    --model_name dental_mae_pretrain \
    --config configs/default.yaml \
    --experiment_name mae_pretrain \
    --data_dir path/to/unlabeled_data

# Phase 2: Supervised fine-tuning
python train.py \
    --model_name dental_mae_seg \
    --config configs/default.yaml \
    --experiment_name mae_finetune \
    --data_dir path/to/labeled_data
```
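How the fine-tuning phase picks up the pre-trained weights depends on the model wrapper; a common pattern is to transfer encoder weights and leave the segmentation head randomly initialized. A sketch under the assumption that checkpoint keys are prefixed with `encoder.` (the key names and checkpoint layout here are illustrative, not the repository's exact format):

```python
import torch

def load_pretrained_encoder(seg_model, ckpt_path):
    """Transfer encoder weights from a pre-training checkpoint into an
    instantiated segmentation model; strict=False leaves the segmentation
    head at its random initialization."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state = ckpt.get("model_state_dict", ckpt)  # assumed checkpoint layout
    encoder_state = {k: v for k, v in state.items() if k.startswith("encoder.")}
    missing, unexpected = seg_model.load_state_dict(encoder_state, strict=False)
    print(f"transferred {len(encoder_state)} tensors; "
          f"{len(missing)} params left at random init")
    return seg_model
```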
The original `start_train.py` is still available for backward compatibility:

```bash
python start_train.py \
    --model_name dgcnn \
    --config_path train_configs/dgcnn.py \
    --experiment_name dgcnn_exp \
    ...
```

Evaluation uses the same registry-driven CLI:

```bash
python test.py --model_name <MODEL> --config <CONFIG> --checkpoint <CKPT> [OPTIONS]
```

| Argument | Description |
|---|---|
| `--checkpoint` | Path to model checkpoint (`.h5`) |
| `--test_split` | Test split txt file |
| `--save_predictions` | Directory to save prediction JSONs |
| `--num_classes` | Number of classes (default: 17) |
Output metrics:
- Mean IoU (Intersection over Union)
- Mean F1 (Dice Score)
- Accuracy (overall and per-class)
- Per-class IoU and F1 breakdown
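For reference, per-class IoU and F1 follow the standard confusion-count definitions. A minimal NumPy sketch (the repository's exact computation lives in `eval_visualize_results.py`):

```python
import numpy as np

def per_class_iou_f1(pred, gt, num_classes=17):
    """Compute per-class IoU and F1 from flat label arrays; classes absent
    from both prediction and ground truth are skipped via NaN."""
    ious, f1s = [], []
    for c in range(num_classes):
        tp = np.sum((pred == c) & (gt == c))
        fp = np.sum((pred == c) & (gt != c))
        fn = np.sum((pred != c) & (gt == c))
        ious.append(tp / (tp + fp + fn) if (tp + fp + fn) else np.nan)
        f1s.append(2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else np.nan)
    return np.nanmean(ious), np.nanmean(f1s), ious, f1s
```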
To visualize predictions against ground truth for a single scan:

```bash
python eval_visualize_results.py \
    --mesh_path path/to/obj_file \
    --gt_json_path path/to/gt_json_file \
    --pred_json_path path/to/predicted_json_file
```

To add a new model, copy the template:

```bash
cp models/new_model_template.py models/my_model_model.py
```

Edit `my_model_model.py` — fill in the TODO markers:
```python
from models.registry import register_model
from models.base_model import BaseModel

@register_model("my_model")  # ← Name used in CLI
class MyModel(BaseModel):
    def __init__(self, config, module=None):
        if module is None:
            module = MyModule  # ← Your nn.Module class
        super().__init__(config, module)

    def get_loss(self, outputs, gt):
        # ← Define your loss computation
        ...

    def step(self, batch_idx, batch_item, phase="train"):
        # ← Forward pass + optimization
        ...
```

Create a config:

```bash
cp configs/dgcnn.yaml configs/my_model.yaml
# Edit optimizer, scheduler, model_parameter as needed
```

Train:

```bash
python train.py --model_name my_model --config configs/my_model.yaml ...
```

That's it. No changes to `train.py`, `test.py`, or any other file. The `@register_model` decorator handles everything.
To add a new dataset, copy the template:

```bash
cp datasets/new_dataset_template.py datasets/my_dataset.py
```

Your dataset must inherit from `BaseSegDataset` and return a dict with:
| Key | Shape | Description |
|---|---|---|
| `feat` | `[C, N]` | Point features (channels-first) |
| `gt_seg_label` | `[1, N]` | Segmentation labels |
| `category` | `[2]` | One-hot jaw category |
| `mesh_path` | `str` | Source file path |
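A minimal subclass might look like the following. This is a sketch under assumed hooks (`self.file_list`, a label column stored in the `.npy`); follow the TODOs in `new_dataset_template.py` for the real API:

```python
# Illustrative dataset returning the required dict; BaseSegDataset's actual
# hooks may differ from the assumptions made here.
import numpy as np
import torch
from datasets.base_dataset import BaseSegDataset

class MyDataset(BaseSegDataset):
    def __getitem__(self, idx):
        path = self.file_list[idx]               # assumed: one entry per split line
        data = np.load(path)                     # preprocessed .npy array
        feat = torch.from_numpy(data[:, :6]).T   # [C, N], channels-first
        labels = torch.from_numpy(data[:, 6:7]).long().T  # [1, N], assumed column
        jaw = torch.tensor([1.0, 0.0])           # [2] one-hot jaw category
        return {
            "feat": feat.float(),
            "gt_seg_label": labels,
            "category": jaw,
            "mesh_path": path,
        }
```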
We use the 3DTeethSeg'22 Challenge dataset.
```
data_obj_parent_directory/
├── 00OMSZGW/
│   ├── 00OMSZGW_lower.obj
│   └── 00OMSZGW_upper.obj
└── ...

data_json_parent_directory/
├── 00OMSZGW/
│   ├── 00OMSZGW_lower.json
│   └── 00OMSZGW_upper.json
└── ...
```
For self-supervised pre-training, download the unlabeled scans from OneDrive.
Run the smoke test to verify the entire pipeline without real data or GPU:
```bash
python smoke_test.py
```

This tests: model registry (18 models), YAML config loading & merging, dataset instantiation with dummy data, DataLoader batching, loss infrastructure, and CLI script importability.
We welcome contributions! Here's how to get started:
- Fork the repository
- Create a branch for your feature: `git checkout -b feature/my-new-model`
- Implement your changes (see Adding New Models)
- Test with `python smoke_test.py`
- Submit a Pull Request with a clear description
| Area | Difficulty | Description |
|---|---|---|
| New Model | ★★ | Add a new segmentation architecture |
| New Dataset | ★★ | Add support for another dental dataset |
| Metrics | ★ | Add boundary F1, Hausdorff distance |
| Visualization | ★★ | Add 3D prediction visualization tools |
| Documentation | ★ | Improve docstrings and examples |
| Performance | ★★★ | Multi-GPU training, mixed precision |
- Follow PEP 8 conventions
- Add docstrings to all public functions and classes
- Use the model registry — never add `if`/`elif` chains for new models
- Write a corresponding YAML config for every new model
- Ensure `python smoke_test.py` passes before submitting
```yaml
# configs/my_model.yaml — Inherits from configs/default.yaml
tr_set:
  optimizer:
    NAME: "adam"        # adam | sgd | adamw
    lr: 0.001
    weight_decay: 0.0001
  scheduler:
    sched: "cosine"     # cosine | exp | step
    full_steps: 40
    min_lr: 0.00001
model_parameter:
  input_feat: 6         # XYZ + normals
  num_classes: 17       # 16 teeth + gingiva
training:
  epochs: 200
  val_every: 1
wandb:
  wandb_on: false
  project: "tooth_seg"
```

CLI arguments override YAML values (e.g. `--lr 0.0005` overrides `optimizer.lr`).
This project is released for academic and research purposes. Please refer to the original ToothGroupNetwork repository for license details.
- ToothGroupNetwork: Lim, H., et al. "3D Dental Segmentation via Tooth Group Network." MICCAI 2022. GitHub
- 3DTeethSeg Challenge: grand-challenge.org
- PointNet: Qi, C.R., et al. "PointNet: Deep Learning on Point Sets." CVPR 2017.
- PointNet++: Qi, C.R., et al. "PointNet++: Deep Hierarchical Feature Learning." NeurIPS 2017.
- DGCNN: Wang, Y., et al. "Dynamic Graph CNN for Learning on Point Clouds." TOG 2019.
- Point Transformer: Zhao, H., et al. "Point Transformer." ICCV 2021.
This codebase is built upon the excellent work of Team CGIP's ToothGroupNetwork. We are grateful for their open-source contribution which made this benchmark possible.
If you find this work useful, please consider giving it a ⭐