Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 54 additions & 12 deletions micro_sam/sam_annotator/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import zarr
import numpy as np

from napari.layers import Image
from qtpy.QtWidgets import QWidget

import torch.nn as nn
Expand All @@ -19,7 +20,6 @@
from micro_sam.instance_segmentation import AMGBase, get_decoder
from micro_sam.precompute_state import cache_amg_state, cache_is_state

from napari.layers import Image
from segment_anything import SamPredictor

try:
Expand All @@ -37,6 +37,37 @@ def __call__(cls, *args, **kwargs):
return cls._instances[cls]


# TODO: this should be refactored once we have decided on which models to support.
# (Likely only SAM2 models)
def _get_sam_model(model_type, ndim, device, checkpoint_path, use_cli):
    """Create a SAM or SAM2 predictor for the annotator.

    Args:
        model_type: The model identifier. Names starting with "h" (e.g. "hvit_b")
            select a SAM2 model; all other names select a SAM1 model.
        ndim: The dimensionality of the input data. Must be 2 or 3.
        device: The device for the model. Only used for SAM1 models.
        checkpoint_path: Optional path to custom model weights. Only used for SAM1 models.
        use_cli: Whether this is called from the CLI; if so the download
            progress bar is suppressed.

    Returns:
        The predictor and the model state (an empty dict for SAM2 models).

    Raises:
        ValueError: If a SAM2 model is requested with an invalid `ndim`.
    """
    if model_type.startswith("h"):  # i.e. SAM2 models.
        from micro_sam.v2.util import get_sam2_model

        if ndim == 2:  # Get the SAM2 model and prepare the image predictor.
            model = get_sam2_model(model_type=model_type, input_type="images")
            # Prepare the SAM2 predictor.
            from sam2.sam2_image_predictor import SAM2ImagePredictor
            predictor = SAM2ImagePredictor(model)
        elif ndim == 3:  # Get the SAM2 video predictor.
            predictor = get_sam2_model(model_type=model_type, input_type="videos")
        else:
            # Previously a bare 'raise ValueError'; give the user an actionable message.
            raise ValueError(f"Invalid 'ndim': {ndim}. Expected either 2 or 3.")
        # SAM2 models do not carry extra cached state.
        state = {}

    else:
        def progress_bar_factory(model_type):
            # Show a progress bar while the model weights are being downloaded.
            pbar = tqdm(desc=f"Downloading '{model_type}'. This may take a while")
            return pbar

        predictor, state = util.get_sam_model(
            device=device, model_type=model_type,
            checkpoint_path=checkpoint_path, return_state=True,
            progress_bar_factory=None if use_cli else progress_bar_factory,
        )

    return predictor, state


@dataclass
class AnnotatorState(metaclass=Singleton):

Expand Down Expand Up @@ -84,6 +115,10 @@ class AnnotatorState(metaclass=Singleton):
previous_features: Optional[np.ndarray] = None
previous_labels: Optional[np.ndarray] = None

# Interactive segmentation class for 'micro-sam2'.
interactive_segmenter: Optional[Any] = None # TODO: Create a base class and add it here.
is_sam2: Optional[bool] = None # Whether this is a SAM1 or SAM2 model.

def initialize_predictor(
self,
image_data,
Expand All @@ -104,18 +139,11 @@ def initialize_predictor(
use_cli=False,
):
assert ndim in (2, 3)
self.is_sam2 = model_type.startswith("h")

# Initialize the model if necessary.
if predictor is None:
def progress_bar_factory(model_type):
pbar = tqdm(desc=f"Downloading '{model_type}'. This may take a while")
return pbar

self.predictor, state = util.get_sam_model(
device=device, model_type=model_type,
checkpoint_path=checkpoint_path, return_state=True,
progress_bar_factory=None if use_cli else progress_bar_factory,
)
self.predictor, state = _get_sam_model(model_type, ndim, device, checkpoint_path, use_cli)
if prefer_decoder and "decoder_state" in state and model_type != "vit_b_medical_imaging":
self.decoder = get_decoder(
image_encoder=self.predictor.model.image_encoder,
Expand All @@ -132,8 +160,13 @@ def progress_bar_factory(model_type):
self.image_embeddings = save_path
self.embedding_path = None # setting this to 'None' as we do not have embeddings cached.

else: # otherwise, compute the image embeddings.
self.image_embeddings = util.precompute_image_embeddings(
else: # Otherwise, compute the image embeddings.
if self.is_sam2:
from micro_sam.v2.util import precompute_image_embeddings as _comp_embed_fn
else:
_comp_embed_fn = util.precompute_image_embeddings

self.image_embeddings = _comp_embed_fn(
predictor=self.predictor,
input_=image_data,
save_path=save_path,
Expand All @@ -146,6 +179,13 @@ def progress_bar_factory(model_type):
)
self.embedding_path = save_path

# Let's prepare the interactive segmentation class.
if self.is_sam2 and ndim == 3:
from micro_sam.v2.prompt_based_segmentation import PromptableSegmentation3D
self.interactive_segmenter = PromptableSegmentation3D(
predictor=self.predictor, volume=image_data, volume_embeddings=self.image_embeddings,
)

# If we have an embedding path the data signature has already been computed,
# and we can read it from there.
if save_path is not None and isinstance(save_path, str):
Expand Down Expand Up @@ -260,4 +300,6 @@ def reset_state(self):
self.committed_lineages = None
self.z_range = None
self.data_signature = None
self.interactive_segmenter = None
self.is_sam2 = None
# Note: we don't clear the widgets here, because they are fixed for a viewer session.
165 changes: 134 additions & 31 deletions micro_sam/sam_annotator/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,11 @@ def _get_model_size_options(self):
# We store the actual model names mapped to UI labels.
self.model_size_mapping = {}
if self.model_family == "Natural Images (SAM)":
self.model_size_options = list(self._model_size_map .values())
self.model_size_options = list(self._model_size_map.values())
self.model_size_mapping = {self._model_size_map[k]: f"vit_{k}" for k in self._model_size_map.keys()}
elif self.model_family == "Natural Images (SAM2)":
self.model_size_options = list(self._model_size_map.values())
self.model_size_mapping = {self._model_size_map[k]: f"hvit_{k}" for k in self._model_size_map.keys()}
else:
model_suffix = self.supported_dropdown_maps[self.model_family]
self.model_size_options = []
Expand Down Expand Up @@ -278,7 +281,10 @@ def _update_model_type(self):
size_key = next(
(k for k, v in self._model_size_map.items() if v == self.model_size), "b"
)
self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family]
if "SAM2" in self.model_family:
self.model_type = f"hvit_{size_key}"
else:
self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family]

self.model_size_dropdown.setCurrentText(self.model_size) # Apply the selected text to the dropdown

Expand All @@ -293,6 +299,7 @@ def _create_model_section(self, default_model: str = util._DEFAULT_MODEL, create
# Create a list of support dropdown values and correspond them to suffixes.
self.supported_dropdown_maps = {
"Natural Images (SAM)": "",
"Natural Images (SAM2)": "_sam2",
"Light Microscopy": "_lm",
"Electron Microscopy": "_em_organelles",
"Medical Imaging": "_medical_imaging",
Expand Down Expand Up @@ -343,7 +350,10 @@ def _create_model_size_section(self):

def _validate_model_type_and_custom_weights(self):
# Let's get all model combination stuff into the desired `model_type` structure.
self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family]
if "SAM2" in self.model_family:
self.model_type = "hvit_" + self.model_size[0]
else:
self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family]

# For 'custom_weights', we remove the displayed text on top of the drop-down menu.
if self.custom_weights:
Expand Down Expand Up @@ -452,12 +462,19 @@ def clear_volume(viewer: "napari.viewer.Viewer", all_slices: bool = True) -> Non
viewer: The napari viewer.
all_slices: Choose whether to clear the annotations for all or only the current slice.
"""
state = AnnotatorState()

if all_slices:
vutil.clear_annotations(viewer)
else:
i = int(viewer.dims.point[0])
vutil.clear_annotations_slice(viewer, i=i)

# If it's a SAM2 promptable segmentation workflow,
# we should reset the prompts after clear annotations has been clicked.
if state.interactive_segmenter is not None:
state.interactive_segmenter.reset_predictor()

# Perform garbage collection.
gc.collect()

Expand Down Expand Up @@ -547,6 +564,10 @@ def _commit_impl(viewer, layer, preserve_mode, preservation_threshold):
viewer.layers["committed_objects"].data[bb][mask] = seg[mask]
viewer.layers["committed_objects"].refresh()

# If it's a SAM2 promptable segmentation workflow, we should reset the prompts after commit has been clicked.
if state.interactive_segmenter is not None:
state.interactive_segmenter.reset_predictor()

return id_offset, seg, mask, bb


Expand Down Expand Up @@ -1010,12 +1031,24 @@ def segment(viewer: "napari.viewer.Viewer", batched: bool = False) -> None:
boxes, masks = vutil.shape_layer_to_prompts(viewer.layers["prompts"], shape)
points, labels = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], with_stop_annotation=False)

predictor = AnnotatorState().predictor
image_embeddings = AnnotatorState().image_embeddings
seg = vutil.prompt_segmentation(
predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings,
multiple_box_prompts=True, batched=batched, previous_segmentation=viewer.layers["current_object"].data,
)
state = AnnotatorState()
predictor = state.predictor
image_embeddings = state.image_embeddings

if state.is_sam2:
from micro_sam.v2.prompt_based_segmentation import promptable_segmentation_2d
seg = promptable_segmentation_2d(
predictor=predictor,
points=points,
labels=labels,
boxes=boxes,
masks=masks,
)
else:
seg = vutil.prompt_segmentation(
predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings,
multiple_box_prompts=True, batched=batched, previous_segmentation=viewer.layers["current_object"].data,
)

# no prompts were given or prompts were invalid, skip segmentation
if seg is None:
Expand Down Expand Up @@ -1053,10 +1086,22 @@ def segment_slice(viewer: "napari.viewer.Viewer") -> None:
points, labels = point_prompts

state = AnnotatorState()
seg = vutil.prompt_segmentation(
state.predictor, points, labels, boxes, masks, shape, multiple_box_prompts=False,
image_embeddings=state.image_embeddings, i=z,
)

if state.is_sam2:
# Use the segment_slice method for SAM2.
boxes = [box[[1, 0, 3, 2]] for box in boxes]
seg = state.interactive_segmenter.segment_slice(
frame_idx=z,
points=points[:, ::-1].copy(),
labels=labels,
boxes=boxes,
masks=masks
)
else:
seg = vutil.prompt_segmentation(
state.predictor, points, labels, boxes, masks, shape, multiple_box_prompts=False,
image_embeddings=state.image_embeddings, i=z,
)

# no prompts were given or prompts were invalid, skip segmentation
if seg is None:
Expand Down Expand Up @@ -1450,9 +1495,16 @@ def pbar_init(total, description):
prefer_decoder = False

state.initialize_predictor(
image_data, model_type=self.model_type, save_path=save_path, ndim=ndim,
device=self.device, checkpoint_path=self.custom_weights, tile_shape=tile_shape, halo=halo,
prefer_decoder=prefer_decoder, pbar_init=pbar_init,
image_data,
model_type=self.model_type,
save_path=save_path,
ndim=ndim,
device=self.device,
checkpoint_path=self.custom_weights,
tile_shape=tile_shape,
halo=halo,
prefer_decoder=prefer_decoder,
pbar_init=pbar_init,
pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
)
pbar_signals.pbar_stop.emit()
Expand Down Expand Up @@ -1536,6 +1588,14 @@ def _create_settings(self):
)
setting_values.layout().addLayout(layout)

# Create the UI element in form of a checkbox for multi-object segmentation.
self.batched = False
setting_values.layout().addWidget(
self._add_boolean_param(
"batched", self.batched, title="batched", tooltip=get_tooltip("segmentnd", "batched")
)
)

# Create the UI element for the motion smoothing (if we have the tracking widget).
if self.tracking:
self.motion_smoothing = 0.5
Expand Down Expand Up @@ -1611,24 +1671,67 @@ def volumetric_segmentation_impl():
pbar_signals.pbar_total.emit(shape[0])
pbar_signals.pbar_description.emit("Segment object")

# Step 1: Segment all slices with prompts.
seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts(
state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
state.image_embeddings, shape,
update_progress=lambda update: pbar_signals.pbar_update.emit(update),
)
if state.is_sam2:
# Prepare the prompts
point_prompts = self._viewer.layers["point_prompts"]
box_prompts = self._viewer.layers["prompts"]
z_values_points = np.round(point_prompts.data[:, 0])
z_values_boxes = np.concatenate(
[box[:1, 0] for box in box_prompts.data]
) if box_prompts.data else np.zeros(0, dtype="int")

# Whether the user decide to provide batched prompts for multi-object segmentation.
is_batched = bool(self.batched)

# Let's do points first.
for curr_z_values_point in z_values_points:
# Extract the point prompts from the points layer first.
points, labels = vutil.point_layer_to_prompts(layer=point_prompts, i=curr_z_values_point)

# Add prompts one after the other.
[
state.interactive_segmenter.add_point_prompts(
frame_ids=curr_z_values_point,
points=np.array([curr_point]),
point_labels=np.array([curr_label]),
object_id=i if is_batched else None,
) for i, (curr_point, curr_label) in enumerate(zip(points, labels), start=1)
]

# Next, we add box prompts.
for curr_z_values_box in z_values_boxes:
# Extract the box prompts from the shapes layer first.
boxes, _ = vutil.shape_layer_to_prompts(
layer=box_prompts, shape=state.image_shape, i=curr_z_values_box,
)

# Add prompts one after the other.
state.interactive_segmenter.add_box_prompts(frame_ids=curr_z_values_box, boxes=boxes)

# Propagate the prompts throughout the volume and combine the propagated segmentations.
seg = state.interactive_segmenter.predict()

else:
# Step 1: Segment all slices with prompts.
seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts(
state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
state.image_embeddings, shape,
update_progress=lambda update: pbar_signals.pbar_update.emit(update),
)

# Step 2: Segment the rest of the volume based on projecting prompts.
seg, (z_min, z_max) = segment_mask_in_volume(
seg, state.predictor, state.image_embeddings, slices,
stop_lower, stop_upper,
iou_threshold=self.iou_threshold, projection=self.projection,
box_extension=self.box_extension,
update_progress=lambda update: pbar_signals.pbar_update.emit(update),
)

state.z_range = (z_min, z_max)

# Step 2: Segment the rest of the volume based on projecting prompts.
seg, (z_min, z_max) = segment_mask_in_volume(
seg, state.predictor, state.image_embeddings, slices,
stop_lower, stop_upper,
iou_threshold=self.iou_threshold, projection=self.projection,
box_extension=self.box_extension,
update_progress=lambda update: pbar_signals.pbar_update.emit(update),
)
pbar_signals.pbar_stop.emit()

state.z_range = (z_min, z_max)
return seg

def update_segmentation(seg):
Expand Down
3 changes: 2 additions & 1 deletion micro_sam/sam_annotator/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,8 @@ def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, devic

# Update the index for model size, eg. 'base', 'tiny', etc.
size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"}
model_size = size_map[model_type[4]]
size_idx = 5 if model_type.startswith("h") else 4
model_size = size_map[model_type[size_idx]]

index = widget.model_size_dropdown.findText(model_size)
if index > 0:
Expand Down
Empty file added micro_sam/v2/__init__.py
Empty file.
Empty file added micro_sam/v2/models/__init__.py
Empty file.
Loading
Loading