177 changes: 177 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
microSAM is a state-of-the-art tool for interactive and automatic microscopy segmentation based on the Segment Anything Model (SAM).
The GUI is built as a napari plugin.
We are currently updating the tool to support SAM2, and we are refactoring and extending the GUI in this context.
You will implement tasks related to this project.

## Common Commands

### Code Quality
```bash
# Format code with black
black micro_sam/

# Lint with ruff (pass --fix to apply auto-fixes)
ruff check micro_sam/
```

### Tests
```bash
# Run all tests
# Note: the full suite takes a very long time; only run the tests relevant to what you develop.
pytest
```

## Code Architecture

### Module Organization

**Core Segmentation (micro_sam/):**
- `util.py` - Model loading, device management, embeddings, preprocessing utilities
- `inference.py` - Batched/tiled inference for large images
- `prompt_based_segmentation.py` - Interactive segmentation with point/box/mask prompts
- `automatic_segmentation.py` - High-level API for automatic segmentation workflows
- `instance_segmentation.py` - Core automatic segmentation implementations (AMG, AIS, APG)
- `multi_dimensional_segmentation.py` - 3D volume and temporal tracking segmentation

**SAM v2 Support (micro_sam/v2/):**
- SAM v2 uses Hiera backbone (hvit_t, hvit_s, hvit_b, hvit_l) with temporal/video capabilities
- `v2/util.py` - SAM2-specific model loading and configuration
- `v2/prompt_based_segmentation.py` - Wrapper for SAM2 2D/3D predictions
- `v2/models/_video_predictor.py` - Video/tracking predictor
- Model type prefixes: SAM v1 = `vit_*`, SAM v2 = `hvit_*`

**Napari UI (micro_sam/sam_annotator/):**
- `_annotator.py` - Base annotator with napari layer/widget/keybinding setup
- `annotator.py`, `annotator_tracking.py` - UIs for interactive and automatic segmentation / tracking
- `_state.py` - Singleton state manager (predictor, embeddings, AMG generators)
- `_widgets.py` - Qt widgets for embedding/segmentation/tracking controls
- `image_series_annotator.py` - Multi-image batch annotation workflow

**Training/Finetuning (micro_sam/training/):**
- `sam_trainer.py` - Base trainer extending torch_em.DefaultTrainer
- `joint_sam_trainer.py`, `simple_sam_trainer.py`, `semantic_sam_trainer.py` - Specialized trainers
- `training.py` - High-level training orchestration and CONFIGURATIONS registry
- `util.py` - Training data conversion and model loading

**Models (micro_sam/models/):**
- `build_sam.py` - Factory for SAM v1 models (vit_b, vit_l, vit_h, vit_t via MobileSAM)
- `peft_sam.py` - Parameter-efficient fine-tuning (LoRA, FacT, SSF, AdaptFormer)
- `sam_3d_wrapper.py` - 3D-compatible SAM wrapper
- `simple_sam_3d_wrapper.py` - Simplified 3D segmentation model

### Key Architectural Patterns

**Three Automatic Segmentation Modes:**

1. **AMG (Automatic Mask Generator)** - Default, grid-based prompting
- Classes: `AutomaticMaskGenerator`, `TiledAutomaticMaskGenerator`
- No decoder required
- Factory: `get_instance_segmentation_generator(mode="amg")`

2. **AIS (Instance Segmentation with Decoder)** - UNETR decoder-based
- Classes: `InstanceSegmentationWithDecoder`, `TiledInstanceSegmentationWithDecoder`
- Requires trained decoder checkpoint
- Factory: `get_instance_segmentation_generator(mode="ais")`

3. **APG (Automatic Prompt Generator)** - Decoder + iterative refinement
- Classes: `AutomaticPromptGenerator`, `TiledAutomaticPromptGenerator`
- Extends AIS with prompt refinement
- Factory: `get_instance_segmentation_generator(mode="apg")`

All modes support tiling for large images via `inference.batched_tiled_inference()`.
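The factory dispatch described above can be sketched as a plain lookup table. This is an illustrative stand-in, not the real `get_instance_segmentation_generator` implementation; the class names come from the list above, but the actual signature and return values may differ.

```python
# Hypothetical sketch of the mode -> generator dispatch described above.
# The real factory returns instantiated generator objects, not names.
GENERATORS = {
    ("amg", False): "AutomaticMaskGenerator",
    ("amg", True): "TiledAutomaticMaskGenerator",
    ("ais", False): "InstanceSegmentationWithDecoder",
    ("ais", True): "TiledInstanceSegmentationWithDecoder",
    ("apg", False): "AutomaticPromptGenerator",
    ("apg", True): "TiledAutomaticPromptGenerator",
}


def select_generator(mode: str, is_tiled: bool = False) -> str:
    """Return the generator class name for a given mode and tiling flag."""
    try:
        return GENERATORS[(mode, is_tiled)]
    except KeyError:
        raise ValueError(f"Unknown segmentation mode: {mode!r}") from None
```

A single `(mode, is_tiled)` key keeps the tiled and untiled variants symmetric, which mirrors how the three modes each have a `Tiled*` counterpart.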

**Precomputation and Caching:**
- `util.precompute_image_embeddings()` - Compute and cache embeddings
- `util.set_precomputed()` - Load precomputed embeddings
- `precompute_state.py` - CLI and batch precomputation
- Saves to zarr/h5 format for fast loading
- Embeddings stored as ImageEmbeddings dict with 'features', 'input_size', 'original_size'
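The `ImageEmbeddings` layout can be sketched as follows. The three keys come from the text above; the helper function and example sizes are assumptions for illustration (in practice the features are produced by the SAM image encoder and cached to zarr/h5).

```python
# Minimal sketch of the ImageEmbeddings dict layout described above.
def make_image_embeddings(features, input_size, original_size):
    """Bundle precomputed features with the sizes needed to undo SAM's resizing."""
    return {
        "features": features,            # encoder output, e.g. shape (1, 256, 64, 64)
        "input_size": input_size,        # size after SAM preprocessing, e.g. (1024, 1024)
        "original_size": original_size,  # raw image size, e.g. (512, 512)
    }


embeddings = make_image_embeddings(
    features=[[0.0]], input_size=(1024, 1024), original_size=(512, 512)
)
```

Keeping `input_size` and `original_size` alongside the features is what allows masks predicted at the model's resolution to be mapped back onto the original image.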

**Prompt Generators (for training):**
- `PromptGeneratorBase` - Abstract interface
- `PointAndBoxPromptGenerator` - Samples prompts from ground-truth masks
- `IterativePromptGenerator` - Adapts prompts based on prediction errors
- Used by trainers for curriculum learning

**PEFT Surgery (Parameter-Efficient Fine-Tuning):**
- `PEFT_Sam` wrapper enables freezing most parameters
- Strategies: LoRA, FacT, SSF, AdaptFormer, ClassicalSurgery
- Configured via `models.peft_sam.PEFT_Sam(sam_model, rank=4, peft_module="lora")`
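The "surgery" amounts to freezing everything except the injected PEFT parameters. The sketch below illustrates that pattern on a flat list of parameter names; the names are made up for the example and this is not the real `PEFT_Sam` code.

```python
# Illustrative sketch of PEFT surgery: freeze all parameters except those
# belonging to the PEFT module (here identified by a name substring).
def apply_peft_surgery(param_names, peft_module="lora"):
    """Return {name: requires_grad} with only PEFT parameters left trainable."""
    return {name: (peft_module in name) for name in param_names}


grads = apply_peft_surgery([
    "image_encoder.blocks.0.attn.qkv.weight",         # frozen backbone weight
    "image_encoder.blocks.0.attn.qkv.lora_a.weight",  # trainable LoRA weight
    "mask_decoder.iou_head.weight",                   # frozen decoder weight
])
```

In the real wrapper the trainable set depends on the chosen strategy (LoRA, FacT, SSF, AdaptFormer, ClassicalSurgery) and the `rank` hyperparameter.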

### Data Flow

**Interactive Annotation:**
```
User input (napari) → AnnotatorState → Predictor.predict() →
Update napari layers → Display result
```

**Automatic Segmentation:**
```
Image → util.precompute_image_embeddings() →
AMG/AIS/APG.initialize() → generator.generate() →
Instance masks
```

**Training:**
```
DataLoader → ConvertToSamInputs → SamTrainer →
Iterative prompting → Loss (Dice + IoU MSE) →
Save checkpoint with decoder_state
```
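The loss named in the flow above (Dice + IoU MSE) can be sketched in pure Python. The weighting and exact reduction are assumptions; the real trainer computes this over batched tensors.

```python
# Sketch of the combined training loss: Dice on the predicted mask plus an MSE
# term on the predicted IoU (the model also learns to estimate its own mask quality).
def dice_loss(pred, target, eps=1e-6):
    """1 - Dice coefficient over flat binary mask lists."""
    inter = sum(p * t for p, t in zip(pred, target))
    denom = sum(pred) + sum(target)
    return 1.0 - (2.0 * inter + eps) / (denom + eps)


def iou_mse(pred_iou, true_iou):
    """Squared error between predicted and actual IoU of the mask."""
    return (pred_iou - true_iou) ** 2


def combined_loss(pred, target, pred_iou, true_iou):
    return dice_loss(pred, target) + iou_mse(pred_iou, true_iou)
```

A perfect mask with a perfect IoU estimate drives both terms to zero, so the loss jointly supervises mask shape and the quality head.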

### Important Implementation Notes

**Model Registry and Loading:**
- SAM v1 models downloaded from Facebook via `util.get_sam_model()`
- SAM v2 models via `v2.util.get_sam2_model()`
- Finetuned models available: `vit_b_lm` (light microscopy), `vit_b_em_organelles`, etc.
- Model type determines architecture automatically

**State Management:**
- `AnnotatorState` is a singleton (metaclass-based)
- Shared across all annotator instances
- Contains predictor, embeddings, AMG generator, decoder

**Tiling Strategy:**
- Enabled by `is_tiled=True` in factory functions
- Applies halos for overlap handling
- Merges results avoiding duplicates
- Critical for large images (>2048px)
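The tiling arithmetic can be sketched along one axis: tiles cover the image, each tile is *read* with a halo (clipped at the border) but *written* without it, so overlapping context is available while merged results stay duplicate-free. This is an illustrative sketch, not the `inference.batched_tiled_inference` implementation.

```python
# Sketch of halo-based tiling along one axis.
def tile_with_halo(length, tile, halo):
    """Yield (read_start, read_stop, write_start, write_stop) per tile."""
    tiles = []
    for start in range(0, length, tile):
        stop = min(start + tile, length)
        read_start = max(start - halo, 0)   # extend the read region by the halo,
        read_stop = min(stop + halo, length)  # clipped at the image border
        tiles.append((read_start, read_stop, start, stop))
    return tiles


# A 2048px axis with 1024px tiles and a 256px halo gives two overlapping reads
# whose write regions still partition the axis exactly.
```

The write regions form an exact partition of the axis, which is what avoids duplicate instances when tile results are merged.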

**SAM v1 vs v2 Routing:**
- Model type prefix determines version: `vit_*` → SAM v1, `hvit_*` → SAM v2
- `_state._get_sam_model()` handles version selection
- SAM v2 adds video/tracking capabilities via video predictor
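The prefix-based routing can be sketched as a standalone helper. The real selection happens inside `_state._get_sam_model()` and also loads the model; this version only shows the dispatch.

```python
# Sketch of version routing by model-type prefix, as described above.
def sam_version(model_type: str) -> int:
    """Return 2 for SAM v2 (hvit_*), 1 for SAM v1 (vit_*)."""
    if model_type.startswith("hvit_"):
        return 2  # SAM v2, Hiera backbone
    if model_type.startswith("vit_"):
        return 1  # SAM v1
    raise ValueError(f"Unknown model type: {model_type!r}")
```

Note that finetuned variants such as `vit_b_lm` keep the `vit_` prefix, so they route to SAM v1 without special-casing.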

**Training Checkpoint Format:**
```python
{
    'model': model_state_dict,
    'decoder_state': decoder_weights,  # Optional, for AIS/APG modes
    'config': model_config,
    'epoch': int,
    'optimizer': optimizer_state,
}
```
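Consuming that layout can be sketched as below; plain dicts stand in for the state dicts that `torch.load` would return, and the example values are made up.

```python
# Sketch of reading the checkpoint layout above. 'decoder_state' is optional:
# only AIS/APG checkpoints carry decoder weights.
def load_decoder_state(checkpoint):
    """Return the decoder weights if present, else None (AMG-only checkpoint)."""
    return checkpoint.get("decoder_state")


ckpt = {
    "model": {"image_encoder.weight": [0.0]},
    "decoder_state": {"decoder.out.weight": [0.0]},
    "config": {"model_type": "vit_b"},
    "epoch": 10,
    "optimizer": {},
}
```

Using `dict.get` rather than indexing keeps loading code working for both AMG checkpoints (no decoder) and AIS/APG checkpoints (with decoder).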

**Decoder Integration:**
- UNETR decoder predicts: center distances, boundaries, foreground probability
- Loaded via `get_decoder()` from checkpoint's `decoder_state`
- Used by AIS and APG modes for prompt generation

### Testing Guidelines

- Use Google-style docstrings for new code
- Write unit tests for new functionality
- GUI tests should use `make_napari_viewer_proxy` fixture
- Mark slow tests: `@pytest.mark.slow`
- Mark GUI tests: `@pytest.mark.gui`
- Coverage reports generated automatically with pytest-cov

### Environment Variables

- `PYTORCH_ENABLE_MPS_FALLBACK` - Enable Apple Silicon MPS fallback
4 changes: 4 additions & 0 deletions RELEASE_OVERVIEW.md
@@ -1,5 +1,9 @@
# Release Overview

**New in version 1.7.1**

Fixes minor issues in 1.7.0 (related to trackastra, automatic segmentation, and the training functions) and adds a new documentation section for our new automatic segmentation pipeline, APG.

**New in version 1.7.0**

Updates to the automatic instance segmentation pipeline (introduces APG - automatic prompt generation).
2 changes: 1 addition & 1 deletion doc/deprecated/development.md
@@ -5,7 +5,7 @@
This software consists of four different python (sub-)modules:
- The top-level `micro_sam` module implements general purpose functionality for using Segment Anything for multi-dimensional data.
- `micro_sam.evaluation` provides functionality to evaluate Segment Anything models on (microscopy) segmentation tasks.
- `micro_sam.traning` implements the training functionality to finetune Segment Anything for custom segmentation datasets.
- `micro_sam.training` implements the training functionality to finetune Segment Anything for custom segmentation datasets.
- `micro_sam.sam_annotator` implements the interactive annotation tools.

## Annotation Tools
5 changes: 5 additions & 0 deletions doc/faq.md
@@ -187,6 +187,11 @@ You can then use those models with the custom checkpoint option, see answer 15 f
### 18. I would like to evaluate the instance segmentation quantitatively. Can you suggest how to do that?
`micro-sam` supports a `micro_sam.evaluate` CLI, which computes the mean segmentation accuracy (introduced in the Pascal VOC challenge) of the predicted instance segmentation with the corresponding ground-truth annotations. Please see our paper (`Methods` -> `Inference and Evaluation` for more details about it) and `$ micro_sam.evaluate -h` for more details about the evaluation CLI.

### 19. I get `RuntimeError: GET was unable to find an engine to execute this computation` on a V100 GPU (*"or any older GPU"*).
This is a known issue for the combination of older-generation GPUs (e.g. V100s) and PyTorch builds compiled with the latest CUDA toolkit (e.g. CUDA 12.9 with PyTorch 2.8 has been confirmed to throw this error on V100s).
Here is what you can do to solve this issue:
- Use a PyTorch/CUDA build that is known to work with V100s, for example CUDA 12.1 or 11.8 with a compatible PyTorch version (please check your installed CUDA drivers).
- Run on the CPU (slower, but works).

## Fine-tuning questions

18 changes: 10 additions & 8 deletions examples/annotator_2d.py
@@ -2,7 +2,7 @@

import imageio.v3 as imageio
from micro_sam.util import get_cache_directory
from micro_sam.sam_annotator import annotator_2d
from micro_sam.sam_annotator import annotator
from micro_sam.sample_data import fetch_hela_2d_example_data, fetch_livecell_example_data, fetch_wholeslide_example_data


@@ -24,9 +24,9 @@ def livecell_annotator(use_finetuned_model):
model_type = "vit_b_lm"
else:
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell.zarr")
model_type = "vit_h"
model_type = "vit_b"

annotator_2d(image, embedding_path, model_type=model_type, precompute_amg_state=True)
annotator(image, embedding_path=embedding_path, model_type=model_type, precompute_amg_state=True)


def hela_2d_annotator(use_finetuned_model):
@@ -40,9 +40,9 @@ def hela_2d_annotator(use_finetuned_model):
model_type = "vit_b_lm"
else:
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d.zarr")
model_type = "vit_h"
model_type = "vit_b"

annotator_2d(image, embedding_path, model_type=model_type)
annotator(image, embedding_path=embedding_path, model_type=model_type)


def wholeslide_annotator(use_finetuned_model):
@@ -59,9 +59,11 @@ def wholeslide_annotator(use_finetuned_model):
model_type = "vit_b_lm"
else:
embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings.zarr")
model_type = "vit_h"
model_type = "vit_b"

annotator_2d(image, embedding_path, tile_shape=(1024, 1024), halo=(256, 256), model_type=model_type)
annotator(
image, embedding_path=embedding_path, tile_shape=(1024, 1024), halo=(256, 256), model_type=model_type
)


def main():
@@ -80,6 +82,6 @@ def main():

# The corresponding CLI call for hela_2d_annotator:
# (replace with cache directory on your machine)
# $ micro_sam.annotator_2d -i /home/pape/.cache/micro_sam/sample_data/hela-2d-image.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-hela2d.zarr # noqa
# $ micro_sam.annotator -i /home/pape/.cache/micro_sam/sample_data/hela-2d-image.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-hela2d.zarr # noqa
if __name__ == "__main__":
main()
6 changes: 3 additions & 3 deletions examples/annotator_3d.py
@@ -1,7 +1,7 @@
import os

from elf.io import open_file
from micro_sam.sam_annotator import annotator_3d
from micro_sam.sam_annotator import annotator
from micro_sam.sample_data import fetch_3d_example_data
from micro_sam.util import get_cache_directory

@@ -28,7 +28,7 @@ def em_3d_annotator(use_finetuned_model):
precompute_amg_state = False

# start the annotator, cache the embeddings
annotator_3d(raw, embedding_path, model_type=model_type, precompute_amg_state=precompute_amg_state)
annotator(raw, embedding_path=embedding_path, model_type=model_type, precompute_amg_state=precompute_amg_state)


def main():
@@ -40,6 +40,6 @@ def main():

# The corresponding CLI call for em_3d_annotator:
# (replace with cache directory on your machine)
# $ micro_sam.annotator_3d -i /home/pape/.cache/micro_sam/sample_data/lucchi_pp.zip.unzip/Lucchi++/Test_In -k *.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-lucchi.zarr # noqa
# $ micro_sam.annotator -i /home/pape/.cache/micro_sam/sample_data/lucchi_pp.zip.unzip/Lucchi++/Test_In -k *.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-lucchi.zarr # noqa
if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions examples/finetuning/annotator_with_finetuned_model.py
@@ -1,6 +1,6 @@
import imageio.v3 as imageio

from micro_sam.sam_annotator import annotator_2d
from micro_sam.sam_annotator import annotator


def run_annotator_with_finetuned_model():
@@ -20,7 +20,7 @@ def run_annotator_with_finetuned_model():
model_type = "vit_b" # We finetune a vit_b in the example script.

# Run the 2d annotator with the custom model.
annotator_2d(im, model_type=model_type, embedding_path=embedding_path, checkpoint=checkpoint)
annotator(im, model_type=model_type, embedding_path=embedding_path, checkpoint=checkpoint)


if __name__ == "__main__":
6 changes: 2 additions & 4 deletions examples/finetuning/finetune_hela.py
@@ -1,13 +1,11 @@
import os
import numpy as np

import torch

import torch_em
from torch_em.transform.label import PerObjectDistanceTransform

import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model
from micro_sam.util import export_custom_sam_model, get_device
from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data


@@ -82,7 +80,7 @@ def run_training(checkpoint_name, model_type, train_instance_segmentation):
batch_size = 1 # the training batch size
patch_shape = (1, 512, 512) # the size of patches for training
n_objects_per_batch = 25 # the number of objects per batch that will be sampled
device = torch.device("cuda") # the device used for training
device = get_device() # the device used for training

# Get the dataloaders.
train_loader = get_dataloader("train", patch_shape, batch_size, train_instance_segmentation)
5 changes: 3 additions & 2 deletions examples/image_series_annotator.py
@@ -18,11 +18,12 @@ def series_annotation(use_finetuned_model):
model_type = "vit_b_lm"
else:
embedding_path = os.path.join(EMBEDDING_CACHE, "series-embeddings")
model_type = "vit_h"
model_type = "vit_b"

example_data = fetch_image_series_example_data(DATA_CACHE)
image_folder_annotator(
example_data, "./series-segmentation-result",
example_data,
output_folder="./series-segmentation-result",
pattern="*.tif",
embedding_path=embedding_path,
model_type=model_type,
4 changes: 2 additions & 2 deletions examples/object_classifier.py
@@ -82,7 +82,7 @@ def _get_livecell_data():

embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell-vit_b_lm.zarr")

# This is the vit-b-lm segmentation and a test annotaiton.
# This is the vit-b-lm segmentation and a test annotation.
segmentation = imageio.imread("./clf-test-data/livecell-test-seg.tif")
annotations = imageio.imread("./clf-test-data/livecell-test-annotations.tif")

@@ -97,7 +97,7 @@ def _get_wholeslide_data():

embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings-vit_b_lm.zarr")

# This is the vit-b-lm segmentation and a test annotaiton.
# This is the vit-b-lm segmentation and a test annotation.
segmentation = imageio.imread("./clf-test-data/whole-slide-seg.tif")
annotations = imageio.imread("./clf-test-data/wholeslide-annotations.tif")

2 changes: 1 addition & 1 deletion micro_sam/__version__.py
@@ -1 +1 @@
__version__ = "1.7.0"
__version__ = "1.7.1"
6 changes: 3 additions & 3 deletions micro_sam/automatic_segmentation.py
@@ -273,11 +273,11 @@ def automatic_instance_segmentation(

# Allow opening the automatic segmentation in the annotator for further annotation, if desired.
if annotate:
from micro_sam.sam_annotator import annotator_2d, annotator_3d
annotator_function = annotator_2d if ndim == 2 else annotator_3d
from micro_sam.sam_annotator import annotator

viewer = annotator_function(
viewer = annotator(
image=image_data,
ndim=ndim,
model_type=predictor.model_name,
embedding_path=image_embeddings, # Providing the precomputed image embeddings.
segmentation_result=instances, # Initializes the automatic segmentation to the annotator.
2 changes: 1 addition & 1 deletion micro_sam/bioimageio/model_export.py
@@ -283,7 +283,7 @@ def export_sam_model(
Args:
image: The image for generating test data.
label_image: The segmentation correspoding to `image`.
label_image: The segmentation corresponding to `image`.
It is used to derive prompt inputs for the model.
model_type: The type of the SAM model.
name: The name of the exported model.
2 changes: 1 addition & 1 deletion micro_sam/instance_segmentation.py
@@ -1408,7 +1408,7 @@ def prompt_function(foreground, center_distances, boundary_distances, **kwargs)

# 2.) Apply the predictor to the prompts.
if prompts is None: # No prompts were derived, we can't do much further and return empty masks.
return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_egmentation" else []
return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_segmentation" else []
else:
predictions = batched_inference(
self._predictor,