Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
e2c4f4b
AZ segmentation
SarahMuth Oct 23, 2024
a0f713f
updates
SarahMuth Oct 28, 2024
94f9121
Merge branch 'main' of https://github.com/computational-cell-analytic…
SarahMuth Oct 28, 2024
ac1ac00
update 2D DA
SarahMuth Oct 28, 2024
37de75d
Merge branch 'main' of https://github.com/computational-cell-analytic…
SarahMuth Oct 29, 2024
61c57fa
small updates, compartment segmentation
SarahMuth Nov 7, 2024
40e965e
Implement code for first analysis
constantinpape Nov 7, 2024
7be9ee8
2D seg with mask
SarahMuth Nov 11, 2024
b1bef7e
Merge branch 'analysis' of https://github.com/computational-cell-anal…
SarahMuth Nov 11, 2024
f85e445
spatial distribution analysis
SarahMuth Nov 11, 2024
8ef16bc
intersection between compartment boundary and AZ segmentation
SarahMuth Nov 12, 2024
e625ef7
Merge branch 'main' of https://github.com/computational-cell-analytic…
SarahMuth Nov 12, 2024
09f6c84
Update compartment postprocessing
constantinpape Nov 12, 2024
d7dbb39
Merge branch 'more-comp-seg-updates' of https://github.com/computatio…
SarahMuth Nov 12, 2024
f893d23
updating data analysis on smaller details
SarahMuth Nov 13, 2024
08c56b9
minor updates data analysis
SarahMuth Nov 13, 2024
49d1b7c
calculation of AZ area
SarahMuth Nov 14, 2024
8a515d1
corrected radius factor
SarahMuth Nov 14, 2024
b1449d2
minor changes
SarahMuth Nov 19, 2024
db89b44
evaluation of AZ seg
SarahMuth Nov 23, 2024
aa5d78e
clean up
SarahMuth Nov 23, 2024
20e429b
clean up
SarahMuth Nov 23, 2024
19f618e
clean up
SarahMuth Nov 23, 2024
84d3ec7
Merge branch 'more-inner-ear-analysis' of https://github.com/computat…
SarahMuth Nov 25, 2024
622da1e
update AZ evaluation
SarahMuth Nov 27, 2024
686b018
erosion dilation filtering of AZ
SarahMuth Nov 28, 2024
6b54e4a
stuff for revision
SarahMuth Mar 31, 2025
7d675ab
everything after 1st revision relating to training, inference, postpr…
SarahMuth May 22, 2025
f052b98
minor things
SarahMuth May 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ models/*/
run_sbatch.sbatch
slurm/
scripts/cooper/evaluation_results/
analysis_results/
scripts/cooper/training/copy_testset.py
scripts/rizzoli/upsample_data.py
scripts/cooper/training/find_rec_testset.py
scripts/rizzoli/combine_2D_slices.py
scripts/rizzoli/combine_2D_slices_raw.py
scripts/cooper/remove_h5key.py
scripts/cooper/analysis/calc_AZ_area.py
87 changes: 87 additions & 0 deletions big_to_small_pixel_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os
import numpy as np
import h5py
from glob import glob
from scipy.ndimage import zoom
from scipy.ndimage import label
from skimage.morphology import closing, ball

# Input and output folders
input_folder = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/AZ_data_after1stRevision/recorrected_length_of_AZ/wichmann_withAZ"
output_folder = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/AZ_data_after1stRevision/recorrected_length_of_AZ/wichmann_withAZ_rescaled_tomograms"
os.makedirs(output_folder, exist_ok=True)

# Define scaling factors
old_pixel_size = np.array([1.75, 1.75, 1.75])
new_pixel_size = np.array([1.55, 1.55, 1.55])
scaling_factors = old_pixel_size / new_pixel_size

# Utility function to process segmentation
def rescale_and_fix_segmentation(segmentation, scaling_factors):
"""
Rescale the segmentation and ensure labels are preserved.
Args:
segmentation (numpy.ndarray): The input segmentation array with integer labels.
scaling_factors (list or array): Scaling factors for each axis.
Returns:
numpy.ndarray: Rescaled and hole-free segmentation with preserved labels.
"""
# Rescale segmentation using nearest-neighbor interpolation
rescaled_segmentation = zoom(segmentation, scaling_factors, order=0)

# Initialize an array to hold the processed segmentation
processed_segmentation = np.zeros_like(rescaled_segmentation)

# Ensure no holes for each label
unique_labels = np.unique(rescaled_segmentation)
for label_id in unique_labels:
if label_id == 0: # Skip the background
continue

# Extract binary mask for the current label
label_mask = rescaled_segmentation == label_id

# Apply morphological closing to fill holes
closed_mask = closing(label_mask, ball(1))

# Add the processed label back to the output segmentation
processed_segmentation[closed_mask] = label_id

return processed_segmentation.astype(segmentation.dtype)


# Get all .h5 files in the specified input folder
h5_files = glob(os.path.join(input_folder, "*.h5"))
existing_files = {os.path.basename(f) for f in glob(os.path.join(output_folder, "*.h5"))}

for h5_file in h5_files:
print(f"Processing {h5_file}...")

if os.path.basename(h5_file) in existing_files:
print(f"Skipping {h5_file} as it already exists in the output folder.")
continue

# Read data from the .h5 file
with h5py.File(h5_file, "r") as f:
raw = f["raw"][:] # Assuming the dataset is named "raw"
az = f["labels/az"][:]

print(f"Original shape - raw: {raw.shape}; az: {az.shape}")

# Process raw data (tomogram) with linear interpolation
print("Rescaling raw data...")
rescaled_raw = zoom(raw, scaling_factors, order=1)

# Process az segmentation
print("Rescaling and fixing az segmentation...")
rescaled_az = rescale_and_fix_segmentation(az, scaling_factors)

# Save the processed data to a new .h5 file
output_path = os.path.join(output_folder, os.path.basename(h5_file))
with h5py.File(output_path, "w") as f:
f.create_dataset("raw", data=rescaled_raw, compression="gzip")
f.create_dataset("labels/az", data=rescaled_az, compression="gzip")

print(f"Saved rescaled data to {output_path}")

print("Processing complete. Rescaled files are saved in:", output_folder)
189 changes: 189 additions & 0 deletions scripts/cooper/AZ_segmentation_h5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import argparse
import h5py
import os
import json
from pathlib import Path

from tqdm import tqdm
from elf.io import open_file

from synaptic_reconstruction.inference.AZ import segment_AZ
from synaptic_reconstruction.inference.util import parse_tiling

def _require_output_folders(output_folder):
#seg_output = os.path.join(output_folder, "segmentations")
seg_output = output_folder
os.makedirs(seg_output, exist_ok=True)
return seg_output

def get_volume(input_path):
'''
with h5py.File(input_path) as seg_file:
input_volume = seg_file["raw"][:]
'''
with open_file(input_path, "r") as f:

# Try to automatically derive the key with the raw data.
keys = list(f.keys())
if len(keys) == 1:
key = keys[0]
elif "data" in keys:
key = "data"
elif "raw" in keys:
key = "raw"

input_volume = f[key][:]
return input_volume

def run_AZ_segmentation(input_path, output_path, model_path, mask_path, mask_key,tile_shape, halo, key_label, compartment_seg):
tiling = parse_tiling(tile_shape, halo)
print(f"using tiling {tiling}")
input = get_volume(input_path)

#check if we have a restricting mask for the segmentation
if mask_path is not None:
with open_file(mask_path, "r") as f:
mask = f[mask_key][:]
else:
mask = None

#check if intersection with compartment is necessary
if compartment_seg is None:
foreground = segment_AZ(input_volume=input, model_path=model_path, verbose=False, tiling=tiling, mask = mask)
intersection = None
else:
with open_file(compartment_seg, "r") as f:
compartment = f["/labels/compartment"][:]
foreground, intersection = segment_AZ(input_volume=input, model_path=model_path, verbose=False, tiling=tiling, mask = mask, compartment=compartment)

seg_output = _require_output_folders(output_path)
file_name = Path(input_path).stem
seg_path = os.path.join(seg_output, f"{file_name}.h5")

#check
os.makedirs(Path(seg_path).parent, exist_ok=True)

print(f"Saving results in {seg_path}")
with h5py.File(seg_path, "a") as f:
if "raw" in f:
print("raw image already saved")
else:
f.create_dataset("raw", data=input, compression="gzip")

key=f"AZ/segment_from_{key_label}"
if key in f:
print("Skipping", input_path, "because", key, "exists")
else:
f.create_dataset(key, data=foreground, compression="gzip")

if mask is not None:
if mask_key in f:
print("mask image already saved")
else:
f.create_dataset(mask_key, data = mask, compression = "gzip")

if intersection is not None:
intersection_key = "AZ/compartment_AZ_intersection"
if intersection_key in f:
print("intersection already saved")
else:
f.create_dataset(intersection_key, data = intersection, compression = "gzip")




def segment_folder(args, valid_files):
input_files = [os.path.join(root, name) for root, _, files in os.walk(args.input_path) for name in files if name.endswith(args.data_ext)]
input_files = [f for f in input_files if f in valid_files] if valid_files else input_files
print(input_files)

pbar = tqdm(input_files, desc="Run segmentation")
for input_path in pbar:

filename = os.path.basename(input_path)
try:
mask_path = os.path.join(args.mask_path, filename)
except:
print(f"Mask file not found for {input_path}")
mask_path = None

if args.compartment_seg is not None:
try:
compartment_seg = os.path.join(args.compartment_seg, os.path.splitext(filename)[0] + '.h5')
except:
print(f"compartment file not found for {input_path}")
compartment_seg = None
else:
compartment_seg = None

run_AZ_segmentation(input_path, args.output_path, args.model_path, mask_path, args.mask_key, args.tile_shape, args.halo, args.key_label, compartment_seg)

def get_dataset(json_file, input_path, sets=["test"]):
with open(json_file, 'r') as f:
data = json.load(f)
return {os.path.join(input_path, f) for dataset in sets for f in data.get(dataset, [])}


def main():
parser = argparse.ArgumentParser(description="Segment vesicles in EM tomograms.")
parser.add_argument(
"--input_path", "-i", required=True,
help="The filepath to the mrc file or the directory containing the tomogram data."
)
parser.add_argument(
"--output_path", "-o", required=True,
help="The filepath to directory where the segmentations will be saved."
)
parser.add_argument(
"--model_path", "-m", required=True, help="The filepath to the vesicle model."
)
parser.add_argument(
"--json_path", "-j", help="The filepath to the json file that stores the train, val, and test split."
)
parser.add_argument(
"--mask_path", help="The filepath to a h5 file with a mask that will be used to restrict the segmentation. Needs to be in combination with mask_key."
)
parser.add_argument(
"--mask_key", help="Key name that holds the mask segmentation"
)
parser.add_argument(
"--tile_shape", type=int, nargs=3,
help="The tile shape for prediction. Lower the tile shape if GPU memory is insufficient."
)
parser.add_argument(
"--halo", type=int, nargs=3,
help="The halo for prediction. Increase the halo to minimize boundary artifacts."
)
parser.add_argument(
"--key_label", "-k", default = "combined_vesicles",
help="Give the key name for saving the segmentation in h5."
)
parser.add_argument(
"--data_ext", "-d", default = ".h5",
help="Format extension of data to be segmented, default is .h5."
)
parser.add_argument(
"--compartment_seg", "-c", default = None,
help="Path to compartment segmentation."
"If the compartment segmentation was executed before, this will add a key to output file that stores the intersection between compartment boundary and AZ."
"Maybe need to adjust the compartment key that the segmentation is stored under"
)
args = parser.parse_args()

input_ = args.input_path
valid_files = get_dataset(args.json_path, input_) if args.json_path else None

if valid_files:
if len(valid_files) == 1:
run_AZ_segmentation(next(iter(valid_files)), args.output_path, args.model_path, args.mask_path, args.mask_key, args.tile_shape, args.halo, args.key_label, args.compartment_seg)
else:
segment_folder(args, valid_files)
elif os.path.isdir(args.input_path):
segment_folder(args, valid_files)
else:
run_AZ_segmentation(args.input_path, args.output_path, args.model_path, args.mask_path, args.mask_key, args.tile_shape, args.halo, args.key_label, args.compartment_seg)

print("Finished segmenting!")

if __name__ == "__main__":
main()
Loading