Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies:
- torch_em >=0.8
- tqdm
- timm
- trackastra
- trackastra >=0.5.3
- xarray
- zarr
- pip:
Expand Down
2 changes: 1 addition & 1 deletion micro_sam/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.7.5"
__version__ = "1.7.6"
16 changes: 14 additions & 2 deletions micro_sam/multi_dimensional_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os
import multiprocessing as mp
import warnings
from concurrent import futures
from typing import Dict, List, Optional, Union, Tuple

Expand Down Expand Up @@ -569,8 +570,19 @@ def _filter_lineages(lineages, tracking_result):
def _tracking_impl(timeseries, segmentation, mode, min_time_extent, output_folder=None):
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Trackastra.from_pretrained("general_2d", device=device)
lineage_graph, _ = model.track(timeseries, segmentation, mode=mode)
result = model.track(timeseries, segmentation, mode=mode)
try:
lineage_graph, _ = result
except ValueError:
lineage_graph = result

track_data, parent_graph, _ = graph_to_napari_tracks(lineage_graph)
if track_data.size == 0:
warnings.warn("Tracking result is empty.")
tracking_result = np.zeros_like(segmentation)
lineages = []
return tracking_result, lineages

node_to_track, lineages = _extract_tracks_and_lineages(segmentation, track_data, parent_graph)
tracking_result = recolor_segmentation(segmentation, node_to_track)

Expand Down Expand Up @@ -625,7 +637,7 @@ def track_across_frames(
"""
if Trackastra is None:
raise RuntimeError(
"The automatic tracking functionality requires trackastra. You can install it via 'pip install trackastra'."
"Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
)

_, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)
Expand Down
37 changes: 34 additions & 3 deletions test/test_automatic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@

import numpy as np
import torch
from scipy.ndimage import shift
from skimage.draw import disk
from skimage.measure import label as connected_components

import micro_sam.util as util

try:
from trackastra.model import Trackastra # noqa
WITH_TRACKASTRA = True
except ImportError:
WITH_TRACKASTRA = False

HAVE_CUDA = torch.cuda.is_available()


Expand Down Expand Up @@ -43,9 +50,15 @@ def write_object(center, radius):
def _get_3d_inputs(cls, shape):
mask, image = cls._get_2d_inputs(shape[-2:])

# Create volumes by stacking the input image and respective mask.
volume = np.stack([image] * shape[0])
labels = np.stack([mask] * shape[0])
# Create volumes by stacking the input image and respective mask with small shifts.
labels, volume = [mask], [image]
for _ in range(shape[0] - 1):
shift_vector = (np.random.randint(1, 4), np.random.randint(1, 4))
labels.append(shift(mask, shift=shift_vector, order=0, mode="constant", cval=0))
volume.append(shift(image, shift=shift_vector, order=0, mode="constant", cval=0))

volume = np.stack(volume)
labels = np.stack(labels)
return labels, volume

@classmethod
Expand Down Expand Up @@ -104,6 +117,7 @@ def test_instance_segmentation_with_decoder_2d(self):
predictor=predictor, segmenter=segmenter, input_path=image, ndim=2,
)
self.assertEqual(mask.shape, instances.shape)
self.assertGreater(instances.max(), 0)

def test_tiled_instance_segmentation_with_decoder_2d(self):
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
Expand All @@ -118,6 +132,7 @@ def test_tiled_instance_segmentation_with_decoder_2d(self):
batch_size=2,
)
self.assertEqual(mask.shape, instances.shape)
self.assertGreater(instances.max(), 0)

@unittest.skipUnless(HAVE_CUDA, "Skipping long running tests unless we have a GPU.")
def test_automatic_mask_generator_3d(self):
Expand All @@ -131,6 +146,7 @@ def test_automatic_mask_generator_3d(self):
predictor=predictor, segmenter=segmenter, input_path=volume, ndim=3,
)
self.assertEqual(labels.shape, instances.shape)
self.assertGreater(instances.max(), 0)

@unittest.skipUnless(HAVE_CUDA, "Skipping long running tests unless we have a GPU.")
def test_tiled_automatic_mask_generator_3d(self):
Expand All @@ -145,6 +161,7 @@ def test_tiled_automatic_mask_generator_3d(self):
ndim=3, tile_shape=self.tile_shape, halo=self.halo,
)
self.assertEqual(labels.shape, instances.shape)
self.assertGreater(instances.max(), 0)

def test_instance_segmentation_with_decoder_3d(self):
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
Expand All @@ -157,6 +174,19 @@ def test_instance_segmentation_with_decoder_3d(self):
predictor=predictor, segmenter=segmenter, input_path=volume, ndim=3,
)
self.assertEqual(labels.shape, instances.shape)
self.assertGreater(instances.max(), 0)

@unittest.skipUnless(WITH_TRACKASTRA, "Needs trackastra")
def test_automatic_tracking(self):
from micro_sam.automatic_segmentation import automatic_tracking, get_predictor_and_segmenter

labels, volume = self.labels, self.volume
predictor, segmenter = get_predictor_and_segmenter(
model_type=self.model_type_ais, segmentation_mode="ais", is_tiled=False
)
instances, _ = automatic_tracking(predictor=predictor, segmenter=segmenter, input_path=volume)
self.assertEqual(labels.shape, instances.shape)
self.assertGreater(instances.max(), 0)

def test_tiled_instance_segmentation_with_decoder_3d(self):
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
Expand All @@ -170,6 +200,7 @@ def test_tiled_instance_segmentation_with_decoder_3d(self):
ndim=3, tile_shape=self.tile_shape, halo=self.halo,
)
self.assertEqual(labels.shape, instances.shape)
self.assertGreater(instances.max(), 0)


if __name__ == "__main__":
Expand Down
Loading