diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 97f2aeb..47fb202 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -407,8 +407,8 @@ jobs: # The following tests are excluded from CI and should be run locally: # # Slow/GPU-intensive tests: -# - tests/test_register_images_ants.py (slow, computationally intensive) -# - tests/test_register_images_icon.py (requires CUDA for ICON) +# - tests/test_register_images_ANTS.py (slow, computationally intensive) +# - tests/test_register_images_ICON.py (requires CUDA for ICON) # - tests/test_transform_tools.py (depends on slow registration tests) # - tests/test_segment_chest_total_segmentator.py (requires CUDA for TotalSegmentator) # @@ -421,7 +421,7 @@ jobs: # pytest tests/ -v --run-slow # Run all slow tests # pytest tests/ -v --run-gpu --run-slow # GPU + slow (typical local dev profile) # pytest tests/ -v --run-simpleware --run-gpu --run-slow # Full Simpleware coverage -# pytest tests/test_register_images_ants.py -v --run-slow +# pytest tests/test_register_images_ANTS.py -v --run-slow # # Self-hosted GPU runner enables ALL buckets via --run-all # (--run-gpu --run-slow --run-simpleware --run-physicsnemo --run-experiments --run-tutorials). diff --git a/.gitignore b/.gitignore index be88ce1..c117323 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ network_weights *.gz *.mat *.mhd +*.raw *.zip *.nii *.tfm diff --git a/README.md b/README.md index 79b08fb..a3bd654 100644 --- a/README.md +++ b/README.md @@ -134,7 +134,7 @@ print(WorkflowConvertImageToUSD.__name__) - **Registration Classes**: Multiple registration methods for different use cases - Image-to-Image Registration: - `RegisterImagesICON`: Deep learning-based registration using Icon algorithm - - `RegisterImagesANTs`: Classical deformable registration using ANTs + - `RegisterImagesANTS`: Classical deformable registration using ANTs - `RegisterTimeSeriesImages`: Specialized time series registration for 4D CT - Model-to-Image/Model Registration: - `RegisterModelsPCA`: PCA-based statistical shape model registration @@ -354,7 +354,7 @@ if "lung" in masks: ### Image Registration ```python -from physiomotion4d import RegisterImagesICON, RegisterImagesANTs, RegisterTimeSeriesImages +from physiomotion4d import RegisterImagesICON, RegisterImagesANTS, RegisterTimeSeriesImages import itk # Option 1: Icon deep learning registration (GPU-accelerated) @@ -364,14 +364,14 @@ registerer.set_fixed_image(itk.imread("reference_frame.mha")) results = registerer.register(itk.imread("target_frame.mha")) # Option 2: ANTs classical registration -registerer = RegisterImagesANTs() +registerer = RegisterImagesANTS() registerer.set_fixed_image(itk.imread("reference_frame.mha")) results = registerer.register(itk.imread("target_frame.mha")) # Option 3: Time series registration for 4D CT time_series_reg = RegisterTimeSeriesImages( reference_index=0, - registration_method='icon' # or 'ants' + registration_method='ICON' # or 'ANTS' ) transforms = time_series_reg.register_time_series( image_filenames=["time00.mha", "time01.mha", "time02.mha"] @@ -735,7 +735,7 @@ will change, and flag any coordinate-system or shape implications. Use `/impl` for end-to-end implementation: read → summarize → plan → diff → lint. ```text -/impl add set_regularization_weight() to RegisterImagesANTs +/impl add set_regularization_weight() to RegisterImagesANTS ``` ```text @@ -757,7 +757,7 @@ Use `/test-feature` to get a test plan and a complete pytest file using syntheti ``` ```text -/test-feature RegisterImagesANTs with a pair of small synthetic ITK images +/test-feature RegisterImagesANTS with a pair of small synthetic ITK images ``` The agent will state image shapes and axis orders in every test docstring, wire @@ -772,7 +772,7 @@ Use `/doc-feature` after modifying a public API to refresh docstrings and regene the API map. ```text -/doc-feature update docstrings for RegisterImagesANTs after adding set_regularization_weight +/doc-feature update docstrings for RegisterImagesANTS after adding set_regularization_weight ``` The agent will update affected docstrings in NumPy style, add shape/axis annotations diff --git a/docs/API_MAP.md b/docs/API_MAP.md index 43f9ba6..cdcc0dc 100644 --- a/docs/API_MAP.md +++ b/docs/API_MAP.md @@ -28,14 +28,9 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ ## experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py -- `def segment_images(src_data_dirs, src_data_files)` (line 63): Segment each image with SegmentHeartSimpleware and save labelmaps. +- `def segment_images(src_data_dirs, src_data_files)` (line 64): Segment each image with SegmentHeartSimpleware and save labelmaps. -## experiments/LongitudinalRegistration/1-finetune_icon.py - -- `def get_segmented_images(src_data_dirs, src_data_files)` (line 81): Segment each image with SegmentHeartSimpleware and save labelmaps. -- `def get_mask_images(src_data_dirs, src_data_files)` (line 117): Get mask images for each image. - -## experiments/LongitudinalRegistration/2-run_registration_comparison.py +## experiments/LongitudinalRegistration/3-run_registration_method_comparison.py - **class MethodSpec** (line 49): Registration method plus optional ICON checkpoint. - **class ImageArtifacts** (line 58): Input files associated with one image volume. @@ -61,31 +56,7 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ - `def parse_iterations(value)` (line 561): Parse comma-separated multi-resolution iteration counts. - `def main()` (line 566): Run the longitudinal registration comparison experiment. -## experiments/LongitudinalRegistration/recon_4d_icon_eval.py - -- **class MethodSpec** (line 67): Output label and optional ICON checkpoint for one registration run. -- **class TimepointArtifacts** (line 75): File paths for one gated time-point: image, labelmap, landmarks. -- `def nii_stem(path)` (line 84): Return the stem of a ``.nii.gz`` (or single-suffix) file. -- `def timepoint_from_name(path)` (line 91): Extract the gated time-point tag (``g###``) from a filename. -- `def select_reference_index(num_frames, percentile)` (line 99): Return the frame index closest to ``percentile`` of the series. -- `def discover_subject(subject_id, timepoint_base_dir, segmentation_base_dir, exclude_tokens)` (line 108): Discover gated images plus their labelmap and landmark companions. -- `def discover_subject_ids(timepoint_base_dir, segmentation_base_dir)` (line 149): Return subject IDs that have both gated and segmentation directories. -- `def read_landmarks(path)` (line 166): Read physical LPS landmarks from a ``Name,X,Y,Z`` CSV file. -- `def write_landmarks(path, landmarks)` (line 179): Write physical LPS landmarks to a ``Name,X,Y,Z`` CSV file. -- `def transform_landmarks(landmarks, transform)` (line 189): Apply an ITK physical-space transform to landmark coordinates. -- `def landmark_errors(source, target)` (line 202): Return per-landmark Euclidean errors in millimeters. -- `def summarize_errors(errors, prefix)` (line 212): Summarize landmark errors for one comparison mode. -- `def dice_by_label(labelmap_a, labelmap_b)` (line 230): Compute Dice for every non-zero label present in either 3D labelmap. -- `def summarize_dice(scores)` (line 249): Summarize per-label Dice scores into mean and minimum. -- `def write_error_details(path, subject_id, method_name, timepoint, mode, errors)` (line 261): Append per-landmark errors to the long-form detail CSV. -- `def read_error_details(path)` (line 290): Read the long-form per-landmark error CSV. -- `def print_summary_table(detail_file)` (line 299): Print a high-level table comparing methods for each landmark mode. -- `def write_summary(path, rows)` (line 366): Write the wide-form summary CSV. -- `def mask_from_labelmap(labelmap)` (line 378): Return a uint8 binary mask covering the non-zero labels. -- `def run_method_for_subject(subject_id, timepoint_artifacts, reference_index, method_spec, output_dir, icon_iterations, run_resegmentation, error_detail_file)` (line 386): Run one ICON method for one subject and return per-timepoint rows. -- `def main()` (line 544): Run the ICON default-vs-finetuned comparison experiment. - -## experiments/LongitudinalRegistration/recon_4d_run.py +## experiments/LongitudinalRegistration/experiment_recon_4d.py - `def register_time_series(reference_image_file, source_image_dir, source_image_files, registration_method)` (line 75) @@ -439,6 +410,27 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ - `def register_slices(reg_tool, reg_tool_name, fixed_image, images, files_indx, reference_image_num, reference_image_reg_use_identity, portion_of_prior_to_use=0.0)` (line 62) +## results/icon_finetuned/icon_finetuned_model/finetune.py + +- `def loss_to_dict(loss_object)` (line 33) +- `def augment(batch)` (line 50): Apply random affine augmentation to all spatial data in a batch dict. +- `def finetune_multi(config, data_loader, val_data_loaders_dict, data_fields)` (line 304): Unified finetuning loop. +- `def main(argv=None)` (line 432) + +## results/results/icon_finetuned/icon_finetuned_model/finetune.py + +- `def loss_to_dict(loss_object)` (line 33) +- `def augment(batch)` (line 50): Apply random affine augmentation to all spatial data in a batch dict. +- `def finetune_multi(config, data_loader, val_data_loaders_dict, data_fields)` (line 304): Unified finetuning loop. +- `def main(argv=None)` (line 432) + +## results/results/icon_finetuned/icon_finetuned_model-1/finetune.py + +- `def loss_to_dict(loss_object)` (line 33) +- `def augment(batch)` (line 50): Apply random affine augmentation to all spatial data in a batch dict. +- `def finetune_multi(config, data_loader, val_data_loaders_dict, data_fields)` (line 304): Unified finetuning loop. +- `def main(argv=None)` (line 432) + ## src/physiomotion4d/anatomy_taxonomy.py - **class AnatomyGroup** (line 26): One named anatomy group together with the organ labels it contains. @@ -542,6 +534,15 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ - `def convert_array_to_image_of_vectors(self, arr_data, reference_image, ptype=itk.D)` (line 218): Convert a numpy array to an ITK image of vector type. - `def flip_image(self, in_image, in_mask=None, flip_x=False, flip_y=False, flip_z=False, flip_and_make_identity=False)` (line 249): Flip the image and mask. +## src/physiomotion4d/landmark_tools.py + +- **class LandmarkTools** (line 28): Read and write anatomical landmarks in LPS world coordinates. + - `def __init__(self, log_level=logging.INFO)` (line 51): Initialize the LandmarkTools class. + - `def read_landmarks_3dslicer(self, path)` (line 59): Read landmarks from a 3D Slicer Markups JSON (``.mrk.json``) file. + - `def write_landmarks_3dslicer(self, landmarks, path)` (line 112): Write landmarks to a 3D Slicer Markups JSON file in LPS. + - `def read_landmarks_csv(self, path)` (line 148): Read landmarks from a CSV file with header ``Name,x,y,z`` (LPS). + - `def write_landmarks_csv(self, landmarks, path)` (line 192): Write landmarks to a CSV file with header ``Name,x,y,z`` (LPS). + ## src/physiomotion4d/physiomotion4d_base.py - **class ClassNameFilter** (line 38): Filter to show logs only from specific class names. @@ -563,13 +564,13 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ ## src/physiomotion4d/register_images_ants.py -- **class RegisterImagesANTs** (line 24): ANTs-based deformable image registration implementation. +- **class RegisterImagesANTS** (line 24): ANTs-based deformable image registration implementation. - `def __init__(self, log_level=logging.INFO)` (line 70): Initialize the ANTs image registration class. - `def set_number_of_iterations(self, number_of_iterations)` (line 85): Set the number of iterations for ANTs registration. - `def set_transform_type(self, transform_type)` (line 94): Set the type of transform to use for registration. - `def set_metric(self, metric)` (line 106): Set the similarity metric to use for registration. - - `def itk_affine_transform_to_ants_transform(self, itk_tfm)` (line 316): Convert ITK affine/rigid transform to ANTs affine transform. - - `def itk_transform_to_antsfile(self, itk_tfm, reference_image, output_filename)` (line 409): Convert ITK transform to ANTs transform file. + - `def itk_affine_transform_to_ANTS_transform(self, itk_tfm)` (line 316): Convert ITK affine/rigid transform to ANTs affine transform. + - `def itk_transform_to_ANTSfile(self, itk_tfm, reference_image, output_filename)` (line 409): Convert ITK transform to ANTs transform file. - `def registration_method(self, moving_image, moving_mask=None, moving_labelmap=None, moving_image_pre=None, initial_forward_transform=None)` (line 509): Register moving image to fixed image using ANTs registration algorithm. ## src/physiomotion4d/register_images_base.py @@ -597,21 +598,21 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ ## src/physiomotion4d/register_images_icon.py -- **class RegisterImagesICON** (line 32): ICON-based deformable image registration implementation. - - `def __init__(self, log_level=logging.INFO)` (line 68): Initialize the ICON image registration class. - - `def set_weights_path(self, weights_path)` (line 86): Set a custom weights file for the uniGradICON network. - - `def set_number_of_iterations(self, number_of_iterations)` (line 100): Set the number of iterations for ICON registration. - - `def set_multi_modality(self, enable)` (line 108): Enable or disable multi-modality registration. - - `def set_mass_preservation(self, enable)` (line 125): Enable or disable mass preservation constraint. - - `def preprocess(self, image, modality='ct')` (line 142): Preprocess the image for ICON registration. - - `def registration_method(self, moving_image, moving_mask=None, moving_labelmap=None, moving_image_pre=None, initial_forward_transform=None)` (line 162): Register moving image to fixed image using ICON registration algorithm. - - `def finetune(self, image_pairs, output_model_filename, mask_pairs=None, epochs=1, learning_rate=DEFAULT_FINETUNE_LEARNING_RATE)` (line 386): Fine-tune the ICON network on a cohort of image pairs. +- **class RegisterImagesICON** (line 30): ICON-based deformable image registration implementation. + - `def __init__(self, log_level=logging.INFO)` (line 66): Initialize the ICON image registration class. + - `def set_weights_path(self, weights_path)` (line 84): Set a custom weights file for the uniGradICON network. + - `def set_number_of_iterations(self, number_of_iterations)` (line 98): Set the number of iterations for ICON registration. + - `def set_multi_modality(self, enable)` (line 106): Enable or disable multi-modality registration. + - `def set_mass_preservation(self, enable)` (line 123): Enable or disable mass preservation constraint. + - `def preprocess(self, image, modality='ct')` (line 140): Preprocess the image for ICON registration. + - `def registration_method(self, moving_image, moving_mask=None, moving_labelmap=None, moving_image_pre=None, initial_forward_transform=None)` (line 160): Register moving image to fixed image using ICON registration algorithm. + - `def create_mask(labelmap, dilation_mm=5.0)` (line 350): Create a binary registration mask from a labelmap. ## src/physiomotion4d/register_models_distance_maps.py - **class RegisterModelsDistanceMaps** (line 61): Register anatomical models using mask-based deformable registration. - `def __init__(self, moving_model, fixed_model, reference_image, roi_dilation_mm=20, log_level=logging.INFO)` (line 118): Initialize mask-based model registration. - - `def register(self, transform_type='Deformable', use_icon=False, icon_iterations=50)` (line 225): Perform mask-based registration of moving model to fixed model. + - `def register(self, transform_type='Deformable', use_ICON=False, icon_iterations=50)` (line 225): Perform mask-based registration of moving model to fixed model. ## src/physiomotion4d/register_models_icp.py @@ -644,9 +645,9 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ ## src/physiomotion4d/register_time_series_images.py - **class RegisterTimeSeriesImages** (line 31): Register a time series of images to a fixed image. - - `def __init__(self, registration_method='ants', log_level=logging.INFO)` (line 90): Initialize the time series image registration class. - - `def set_number_of_iterations_ants(self, number_of_iterations_ants)` (line 127): Set the number of iterations for ANTs registration. - - `def set_number_of_iterations_icon(self, number_of_iterations_icon)` (line 138): Set the number of iterations for ICON registration. + - `def __init__(self, registration_method='ANTS', log_level=logging.INFO)` (line 90): Initialize the time series image registration class. + - `def set_number_of_iterations_ANTS(self, number_of_iterations_ANTS)` (line 127): Set the number of iterations for ANTs registration. + - `def set_number_of_iterations_ICON(self, number_of_iterations_ICON)` (line 138): Set the number of iterations for ICON registration. - `def set_number_of_iterations_greedy(self, number_of_iterations_greedy)` (line 146): Set the number of iterations for Greedy registration. - `def set_smooth_prior_transform_sigma(self, smooth_prior_transform_sigma)` (line 157): Set the sigma for smoothing the prior transform. - `def set_mask_dilation(self, mask_dilation_mm)` (line 167): Set the dilation of the fixed and moving image masks. @@ -685,8 +686,8 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ - `def set_trim_branches(self, trim_branches)` (line 118): Enable trimming of pulmonary and great-vessel branches. - `def set_simpleware_executable_path(self, path)` (line 133): Set the path to the Simpleware Medical console executable. - `def segmentation_method(self, preprocessed_image)` (line 146): Run Simpleware Medical ASCardio segmentation on the preprocessed image. - - `def get_landmarks(self)` (line 346): Get the landmarks. - - `def trim_branches(self, labelmap_image)` (line 350): Trim pulmonary and great-vessel branches back to the cardiac region. + - `def get_landmarks(self)` (line 392): Get the landmarks. + - `def trim_branches(self, labelmap_image)` (line 396): Trim pulmonary and great-vessel branches back to the cardiac region. ## src/physiomotion4d/test_tools.py @@ -852,6 +853,18 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ - `def set_pca_number_of_components(self, n)` (line 103): Set number of PCA components to retain. - `def run_workflow(self)` (line 313): Run the full pipeline and return a dictionary of results (no file I/O). +## src/physiomotion4d/workflow_fine_tune_icon_registration.py + +- **class WorkflowFineTuneICONRegistration** (line 53): Fine-tune uniGradICON on paired 3D images and apply the fine-tuned weights. + - `def __init__(self, subject_image_files, output_dir, fine_tune_name, subject_ids=None, subject_segmentation_files=None, subject_mask_files=None, subject_landmark_files=None, epochs=2000, batch_size=4, learning_rate=5e-05, input_shape=(175, 175, 175), similarity='lncc', lambda_value=1.5, dice_loss_weight=0.5, lncc_sigma=5, ct_window=(-1000.0, 1000.0), is_ct=True, gpus=None, eval_period=10, save_period=50, mask_dilation_mm=5.0, mask_dir=None, unigradicon_src_path=None, log_level=logging.INFO)` (line 135): Initialize the ICON fine-tuning workflow. + - `def uses_segmentations(self)` (line 313): Whether at least one segmentation file is supplied for training. + - `def uses_masks(self)` (line 321): Whether the dataset will have a ``mask`` field on every kept entry. + - `def prepare_dataset(self)` (line 387): Write the uniGradICON dataset JSON from the configured file lists. + - `def prepare_config(self, dataset_json_path=None)` (line 492): Write the uniGradICON fine-tuning YAML config. + - `def expected_weights_path(self)` (line 557): Return the path uniGradICON writes its final checkpoint to. + - `def run_fine_tuning(self)` (line 572): Build configs and launch ``unigradicon.finetuning.finetune``. + - `def apply_registration(self, reference_image, moving_images, weights_path=None, reference_segmentation=None, reference_landmarks=None, moving_segmentations=None, moving_landmarks=None, number_of_iterations=20, modality='ct')` (line 629): Register each moving image to the reference using fine-tuned ICON weights. + ## src/physiomotion4d/workflow_fit_statistical_model_to_patient.py - **class WorkflowFitStatisticalModelToPatient** (line 56): Register anatomical models using multi-stage ICP, mask-based, and image-based @@ -863,17 +876,17 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ - `def set_use_mask_to_image_registration(self, use_mask_to_image_registration, template_labelmap=None, template_labelmap_organ_mesh_ids=None, template_labelmap_organ_extra_ids=None, template_labelmap_background_ids=None)` (line 427): Set whether to use mask-to-image registration. - `def register_model_to_model_icp(self)` (line 501): Perform ICP alignment of template model to patient model. - `def register_model_to_model_pca(self)` (line 559): Perform PCA-based registration after ICP alignment. - - `def register_mask_to_mask(self, use_icon_refinement=False)` (line 685): Perform mask-based deformable registration of model to patient model. - - `def register_labelmap_to_image(self, use_icon_refinement=False)` (line 753): Perform labelmap-to-image refinement. + - `def register_mask_to_mask(self, use_ICON_refinement=False)` (line 685): Perform mask-based deformable registration of model to patient model. + - `def register_labelmap_to_image(self, use_ICON_refinement=False)` (line 753): Perform labelmap-to-image refinement. - `def transform_model(self, base_model=None)` (line 873): Apply registration transforms to the model. - - `def run_workflow(self, use_icon_registration_refinement=False)` (line 938): Execute the complete multi-stage registration workflow. + - `def run_workflow(self, use_ICON_registration_refinement=False)` (line 938): Execute the complete multi-stage registration workflow. ## src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py - **class WorkflowReconstructHighres4DCT** (line 35): Reconstruct high-resolution 4D CT from time series and reference image. - - `def __init__(self, time_series_images, fixed_image, reference_frame=0, register_reference=False, registration_method='ants_icon', log_level=logging.INFO)` (line 92): Initialize the high-resolution 4D CT reconstruction workflow. - - `def set_number_of_iterations_ants(self, number_of_iterations_ants)` (line 174): Set the number of iterations for ANTs registration. - - `def set_number_of_iterations_icon(self, number_of_iterations_icon)` (line 185): Set the number of iterations for ICON registration. + - `def __init__(self, time_series_images, fixed_image, reference_frame=0, register_reference=False, registration_method='ANTS_ICON', log_level=logging.INFO)` (line 92): Initialize the high-resolution 4D CT reconstruction workflow. + - `def set_number_of_iterations_ANTS(self, number_of_iterations_ANTS)` (line 174): Set the number of iterations for ANTs registration. + - `def set_number_of_iterations_ICON(self, number_of_iterations_ICON)` (line 185): Set the number of iterations for ICON registration. - `def set_prior_weight(self, prior_weight)` (line 193): Set the weight for temporal smoothing with prior transforms. - `def set_modality(self, modality)` (line 209): Set the imaging modality for registration optimization. - `def set_mask_dilation(self, mask_dilation_mm)` (line 217): Set the dilation of the fixed and moving image masks. @@ -894,13 +907,13 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ - `def download_test_data(test_directories)` (line 425): Download Slicer-Heart-CT data. - `def test_images(download_test_data, test_directories)` (line 452): Convert and resample 4D NRRD data; return pre-resampled time points. - `def test_labelmaps(segmenter_total_segmentator, test_images, test_directories)` (line 505): Segment each time point with TotalSegmentator and return result dicts. -- `def test_transforms(registrar_ants, test_images, test_directories)` (line 546): Perform ANTs registration and return results. +- `def test_transforms(registrar_ANTS, test_images, test_directories)` (line 546): Perform ANTs registration and return results. - `def segmenter_total_segmentator()` (line 601): Create a SegmentChestTotalSegmentator instance. - `def segmenter_simpleware()` (line 607): Create a SegmentHeartSimpleware instance. - `def contour_tools()` (line 613): Create a ContourTools instance. -- `def registrar_ants()` (line 619): Create a RegisterImagesANTs instance. +- `def registrar_ANTS()` (line 619): Create a RegisterImagesANTS instance. - `def registrar_greedy()` (line 625): Create a RegisterImagesGreedy instance. -- `def registrar_icon()` (line 631): Create a RegisterImagesICON instance. +- `def registrar_ICON()` (line 631): Create a RegisterImagesICON instance. - `def transform_tools()` (line 637): Create a TransformTools instance. ## tests/test_anatomy_taxonomy.py @@ -1028,23 +1041,23 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ ## tests/test_register_images_ants.py -- **class TestRegisterImagesANTs** (line 24): Test suite for ANTs-based image registration. - - `def test_registrar_initialization(self, registrar_ants)` (line 27): Test that RegisterImagesANTs initializes correctly. - - `def test_set_modality(self, registrar_ants)` (line 35): Test setting imaging modality. - - `def test_set_fixed_image(self, registrar_ants, test_images)` (line 45): Test setting fixed image. - - `def test_register_without_mask(self, registrar_ants, test_images, test_directories)` (line 58): Test basic registration without masks. - - `def test_register_with_mask(self, registrar_ants, test_images, test_directories)` (line 112): Test registration with binary masks. - - `def test_transform_application(self, registrar_ants, test_images, test_directories)` (line 205): Test applying registration transforms to images. - - `def test_preprocess_images(self, registrar_ants, test_images)` (line 259): Test image preprocessing. - - `def test_registration_with_initial_transform(self, registrar_ants, test_images, test_directories)` (line 277): Test registration with initial transform. - - `def test_multiple_registrations(self, registrar_ants, test_images)` (line 312): Test running multiple registrations in sequence. - - `def test_transform_types(self, registrar_ants, test_images)` (line 340): Test that transforms are correct ITK types. - - `def test_image_conversion_cycle_scalar(self, registrar_ants, test_images)` (line 368): Test round-trip conversion: ITK image -> ANTs -> ITK for scalar images. - - `def test_image_conversion_cycle_different_dtypes(self, registrar_ants, test_images)` (line 444): Test round-trip conversion with different data types. - - `def test_image_conversion_preserves_metadata(self, registrar_ants)` (line 476): Test that image conversion preserves all metadata. - - `def test_transform_conversion_cycle_affine(self, registrar_ants, test_images)` (line 523): Test round-trip conversion: ITK affine transform -> ANTs -> ITK. - - `def test_transform_conversion_cycle_displacement_field(self, registrar_ants, test_images)` (line 629): Test round-trip conversion: ITK displacement field -> ANTs -> ITK. - - `def test_transform_conversion_with_composite(self, registrar_ants, test_images)` (line 713): Test conversion of composite transforms. +- **class TestRegisterImagesANTS** (line 24): Test suite for ANTs-based image registration. + - `def test_registrar_initialization(self, registrar_ANTS)` (line 27): Test that RegisterImagesANTS initializes correctly. + - `def test_set_modality(self, registrar_ANTS)` (line 35): Test setting imaging modality. + - `def test_set_fixed_image(self, registrar_ANTS, test_images)` (line 45): Test setting fixed image. + - `def test_register_without_mask(self, registrar_ANTS, test_images, test_directories)` (line 58): Test basic registration without masks. + - `def test_register_with_mask(self, registrar_ANTS, test_images, test_directories)` (line 112): Test registration with binary masks. + - `def test_transform_application(self, registrar_ANTS, test_images, test_directories)` (line 205): Test applying registration transforms to images. + - `def test_preprocess_images(self, registrar_ANTS, test_images)` (line 259): Test image preprocessing. + - `def test_registration_with_initial_transform(self, registrar_ANTS, test_images, test_directories)` (line 277): Test registration with initial transform. + - `def test_multiple_registrations(self, registrar_ANTS, test_images)` (line 312): Test running multiple registrations in sequence. + - `def test_transform_types(self, registrar_ANTS, test_images)` (line 340): Test that transforms are correct ITK types. + - `def test_image_conversion_cycle_scalar(self, registrar_ANTS, test_images)` (line 368): Test round-trip conversion: ITK image -> ANTs -> ITK for scalar images. + - `def test_image_conversion_cycle_different_dtypes(self, registrar_ANTS, test_images)` (line 444): Test round-trip conversion with different data types. + - `def test_image_conversion_preserves_metadata(self, registrar_ANTS)` (line 476): Test that image conversion preserves all metadata. + - `def test_transform_conversion_cycle_affine(self, registrar_ANTS, test_images)` (line 523): Test round-trip conversion: ITK affine transform -> ANTs -> ITK. + - `def test_transform_conversion_cycle_displacement_field(self, registrar_ANTS, test_images)` (line 629): Test round-trip conversion: ITK displacement field -> ANTs -> ITK. + - `def test_transform_conversion_with_composite(self, registrar_ANTS, test_images)` (line 713): Test conversion of composite transforms. ## tests/test_register_images_greedy.py @@ -1060,20 +1073,20 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ ## tests/test_register_images_icon.py - **class TestRegisterImagesICON** (line 23): Test suite for ICON-based image registration. - - `def test_registrar_initialization(self, registrar_icon)` (line 26): Test that RegisterImagesICON initializes correctly. - - `def test_set_modality(self, registrar_icon)` (line 39): Test setting imaging modality. - - `def test_set_number_of_iterations(self, registrar_icon)` (line 49): Test setting number of iterations. - - `def test_set_fixed_image(self, registrar_icon, test_images)` (line 61): Test setting fixed image. - - `def test_set_mass_preservation(self, registrar_icon)` (line 74): Test setting mass preservation flag. - - `def test_set_multi_modality(self, registrar_icon)` (line 86): Test setting multi-modality flag. - - `def test_register_without_mask(self, registrar_icon, test_images, test_directories)` (line 96): Test basic ICON registration without masks. - - `def test_register_with_mask(self, registrar_icon, test_images, test_directories)` (line 151): Test ICON registration with binary masks. - - `def test_transform_application(self, registrar_icon, test_images, test_directories)` (line 245): Test applying ICON registration transforms to images. - - `def test_inverse_consistency(self, registrar_icon, test_images)` (line 299): Test ICON's inverse consistency property. - - `def test_preprocess_images(self, registrar_icon, test_images)` (line 345): Test image preprocessing for ICON. - - `def test_registration_with_initial_transform(self, registrar_icon, test_images, test_directories)` (line 363): Test ICON registration with initial transform. - - `def test_transform_types(self, registrar_icon, test_images)` (line 399): Test that ICON transforms are correct ITK types. - - `def test_different_iteration_counts(self, registrar_icon, test_images)` (line 440): Test ICON with different iteration counts. + - `def test_registrar_initialization(self, registrar_ICON)` (line 26): Test that RegisterImagesICON initializes correctly. + - `def test_set_modality(self, registrar_ICON)` (line 39): Test setting imaging modality. + - `def test_set_number_of_iterations(self, registrar_ICON)` (line 49): Test setting number of iterations. + - `def test_set_fixed_image(self, registrar_ICON, test_images)` (line 61): Test setting fixed image. + - `def test_set_mass_preservation(self, registrar_ICON)` (line 74): Test setting mass preservation flag. + - `def test_set_multi_modality(self, registrar_ICON)` (line 86): Test setting multi-modality flag. + - `def test_register_without_mask(self, registrar_ICON, test_images, test_directories)` (line 96): Test basic ICON registration without masks. + - `def test_register_with_mask(self, registrar_ICON, test_images, test_directories)` (line 151): Test ICON registration with binary masks. + - `def test_transform_application(self, registrar_ICON, test_images, test_directories)` (line 245): Test applying ICON registration transforms to images. + - `def test_inverse_consistency(self, registrar_ICON, test_images)` (line 299): Test ICON's inverse consistency property. + - `def test_preprocess_images(self, registrar_ICON, test_images)` (line 345): Test image preprocessing for ICON. + - `def test_registration_with_initial_transform(self, registrar_ICON, test_images, test_directories)` (line 363): Test ICON registration with initial transform. + - `def test_transform_types(self, registrar_ICON, test_images)` (line 399): Test that ICON transforms are correct ITK types. + - `def test_different_iteration_counts(self, registrar_ICON, test_images)` (line 440): Test ICON with different iteration counts. ## tests/test_register_models_pca.py @@ -1084,8 +1097,8 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ ## tests/test_register_time_series_images.py - **class TestRegisterTimeSeriesImages** (line 24): Test suite for time series image registration. - - `def test_registrar_initialization_ants(self)` (line 29): Test that RegisterTimeSeriesImages initializes correctly with ANTs. - - `def test_registrar_initialization_icon(self)` (line 43): Test that RegisterTimeSeriesImages initializes correctly with ICON. + - `def test_registrar_initialization_ANTS(self)` (line 29): Test that RegisterTimeSeriesImages initializes correctly with ANTs. + - `def test_registrar_initialization_ICON(self)` (line 43): Test that RegisterTimeSeriesImages initializes correctly with ICON. - `def test_registrar_initialization_greedy(self)` (line 57): Test that RegisterTimeSeriesImages initializes correctly with Greedy. - `def test_registrar_initialization_invalid_method(self)` (line 73): Test that invalid registration method raises error. - `def test_set_modality(self)` (line 80): Test setting imaging modality. @@ -1099,7 +1112,7 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ - `def test_register_time_series_error_invalid_starting_index(self, test_images)` (line 337): Test that error is raised for invalid starting index. - `def test_register_time_series_error_invalid_prior_portion(self, test_images)` (line 360): Test that error is raised for invalid prior portion value. - `def test_transform_application_time_series(self, test_images, test_directories)` (line 385): Test applying transforms from time series registration. - - `def test_register_time_series_icon(self, test_images)` (line 437): Test time series registration with ICON method. + - `def test_register_time_series_ICON(self, test_images)` (line 437): Test time series registration with ICON method. - `def test_register_time_series_with_mask(self, test_images, test_directories)` (line 462): Test time series registration with fixed image mask. - `def test_bidirectional_registration(self, test_images)` (line 507): Test that bidirectional registration works correctly. @@ -1221,6 +1234,32 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ - `def test_create_usd_files_passes_times_per_second(monkeypatch, tmp_path)` (line 14): Workflow forwards FPS to VTK-to-USD for shape (X, Y, Z, T) outputs. +## tests/test_workflow_fine_tune_icon_registration.py + +- `def two_subject_dataset(tmp_path)` (line 43): Two patients, two frames each, with matching labelmaps on disk. +- `def test_init_requires_output_dir_and_name(tmp_path)` (line 80): output_dir and fine_tune_name are required positional args. +- `def test_init_rejects_empty_image_files(tmp_path)` (line 88): Empty subject list raises immediately. +- `def test_init_rejects_mismatched_companion_lengths(tmp_path)` (line 98): Mask/seg/landmark lists must match subject_image_files shape exactly. +- `def test_init_rejects_duplicate_subject_ids(tmp_path)` (line 109): Duplicate subject IDs collapse paired groups, so reject them up front. +- `def test_init_rejects_mismatched_subject_ids_length(tmp_path)` (line 120): subject_ids must have one entry per subject. +- `def test_uses_segmentations_and_uses_masks_flags(tmp_path)` (line 131): The two helper flags reflect supplied companions independently. +- `def test_create_mask_thresholds_and_dilates()` (line 160): Single-voxel labelmap becomes a binary mask whose dilation grows it. +- `def test_prepare_dataset_uses_real_subject_ids(two_subject_dataset)` (line 185): Subject IDs round-trip from the caller into every dataset entry. +- `def test_prepare_dataset_skips_frames_with_missing_segmentation(tmp_path)` (line 207): A frame with no seg available is dropped when use_label is required. +- `def test_prepare_dataset_uses_explicit_mask_over_derived(tmp_path)` (line 234): When subject_mask_files supplies a mask, it overrides the derived one. +- `def test_prepare_dataset_mask_only_no_segmentations(tmp_path)` (line 263): Mask-only input: entries have ``mask`` but no ``segmentation`` field. +- `def test_prepare_dataset_derives_mask_next_to_labelmap_by_default(two_subject_dataset)` (line 286): Derived masks land next to each labelmap when ``mask_dir`` is not set. +- `def test_prepare_dataset_derives_mask_under_explicit_mask_dir(two_subject_dataset, tmp_path)` (line 312): Explicit ``mask_dir`` collects every derived mask in that single folder. +- `def test_prepare_dataset_raises_on_missing_image_file(tmp_path)` (line 336): Image existence is a hard requirement; missing image aborts the build. +- `def test_prepare_config_emits_uniGradICON_yaml(two_subject_dataset)` (line 353): YAML config matches uniGradICON's expected structure when seg is present. +- `def test_prepare_config_flags_off_when_no_companions(tmp_path)` (line 390): Without seg or mask, ``use_label`` and ``loss_function_masking`` are False. +- `def test_prepare_config_requires_dataset_json(tmp_path)` (line 411): Calling prepare_config without first preparing the dataset is an error. +- `def test_expected_weights_path_layout(tmp_path)` (line 428): Weights land at ``output_dir//_model/checkpoints/...``. +- `def test_run_fine_tuning_invokes_unigradicon_subprocess(monkeypatch, two_subject_dataset)` (line 447): run_fine_tuning launches the uniGradICON finetune module with the YAML path. +- `def test_run_fine_tuning_without_unigradicon_src(monkeypatch, two_subject_dataset)` (line 489): When unigradicon_src_path is None, PYTHONPATH is not prefixed. +- `def test_apply_registration_rejects_empty_moving(tmp_path)` (line 519): apply_registration validates inputs before touching the registrar. +- `def test_apply_registration_rejects_mismatched_companions(tmp_path)` (line 533): moving_segmentations / moving_landmarks length must match moving_images. + ## tests/test_workflow_fit_statistical_model_to_patient.py - `def test_auto_generate_mask_accumulates_multilabel_models(monkeypatch)` (line 19): Multi-model masks accumulate label IDs instead of overwriting prior labels. diff --git a/docs/api/index.rst b/docs/api/index.rst index 2b6d870..a3909bd 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -40,7 +40,7 @@ By Category **Image Registration** * :class:`~physiomotion4d.RegisterImagesBase` - Base registration class - * :class:`~physiomotion4d.RegisterImagesANTs` - ANTs registration + * :class:`~physiomotion4d.RegisterImagesANTS` - ANTs registration * :class:`~physiomotion4d.RegisterImagesICON` - Icon deep learning registration * :class:`~physiomotion4d.RegisterTimeSeriesImages` - 4D time series registration diff --git a/docs/api/registration/ants.rst b/docs/api/registration/ants.rst index c7deda9..721707d 100644 --- a/docs/api/registration/ants.rst +++ b/docs/api/registration/ants.rst @@ -5,13 +5,13 @@ ANTs Registration .. module:: physiomotion4d.register_images_ants .. currentmodule:: physiomotion4d -``RegisterImagesANTs`` provides optimization-based deformable image +``RegisterImagesANTS`` provides optimization-based deformable image registration through ANTs. Class Reference =============== -.. autoclass:: RegisterImagesANTs +.. autoclass:: RegisterImagesANTS :members: :undoc-members: :show-inheritance: @@ -23,12 +23,12 @@ Basic Registration import itk - from physiomotion4d import RegisterImagesANTs + from physiomotion4d import RegisterImagesANTS fixed = itk.imread("reference.mha") moving = itk.imread("moving.mha") - registrar = RegisterImagesANTs() + registrar = RegisterImagesANTS() registrar.set_modality("ct") registrar.set_transform_type("SyN") registrar.set_number_of_iterations([30, 15, 7]) diff --git a/docs/api/registration/index.rst b/docs/api/registration/index.rst index 8a472b0..08e7126 100644 --- a/docs/api/registration/index.rst +++ b/docs/api/registration/index.rst @@ -19,7 +19,7 @@ PhysioMotion4D image registration classes align moving 3D images to a fixed Common Result Shape =================== -``RegisterImagesANTs.register()`` and ``RegisterImagesICON.register()`` return: +``RegisterImagesANTS.register()`` and ``RegisterImagesICON.register()`` return: * ``forward_transform`` * ``inverse_transform`` @@ -35,12 +35,12 @@ Basic Example import itk - from physiomotion4d import RegisterImagesANTs + from physiomotion4d import RegisterImagesANTS fixed = itk.imread("reference.mha") moving = itk.imread("moving.mha") - registrar = RegisterImagesANTs() + registrar = RegisterImagesANTS() registrar.set_modality("ct") registrar.set_fixed_image(fixed) diff --git a/docs/api/registration/time_series.rst b/docs/api/registration/time_series.rst index 5a0a09a..ac6ccaa 100644 --- a/docs/api/registration/time_series.rst +++ b/docs/api/registration/time_series.rst @@ -6,7 +6,7 @@ Time-Series Registration .. currentmodule:: physiomotion4d ``RegisterTimeSeriesImages`` registers ordered 3D image phases to a reference -frame using ANTs, Greedy, ICON, or combined ``ants_icon`` / ``greedy_icon`` +frame using ANTs, Greedy, ICON, or combined ``ANTS_ICON`` / ``greedy_ICON`` methods. Class Reference @@ -29,7 +29,7 @@ Basic Usage images = [itk.imread(f"phase_{idx:02d}.mha") for idx in range(10)] - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") registrar.set_fixed_image(images[0]) result = registrar.register_time_series( diff --git a/docs/api/workflows.rst b/docs/api/workflows.rst index 6b20255..08c0cf5 100644 --- a/docs/api/workflows.rst +++ b/docs/api/workflows.rst @@ -157,7 +157,7 @@ High-Resolution 4D CT Reconstruction time_series_images=time_series_images, fixed_image=time_series_images[0], reference_frame=0, - registration_method="ants", + registration_method="ANTS", ) result = workflow.run_workflow(upsample_to_fixed_resolution=True) diff --git a/docs/architecture.rst b/docs/architecture.rst index 58d09aa..5825ae2 100644 --- a/docs/architecture.rst +++ b/docs/architecture.rst @@ -25,7 +25,7 @@ Data Flow v RegisterTimeSeriesImages | | - | +--> RegisterImagesANTs / RegisterImagesICON + | +--> RegisterImagesANTS / RegisterImagesICON v SegmentChestTotalSegmentator / SegmentHeartSimpleware | diff --git a/docs/cli_scripts/4dct_reconstruction.rst b/docs/cli_scripts/4dct_reconstruction.rst index 811437e..800b8c6 100644 --- a/docs/cli_scripts/4dct_reconstruction.rst +++ b/docs/cli_scripts/4dct_reconstruction.rst @@ -47,8 +47,8 @@ Registration Options physiomotion4d-reconstruct-highres-4d-ct \ --time-series-images frame_*.mha \ --fixed-image highres_reference.mha \ - --registration-method ants \ - --ants-iterations 30 15 7 3 \ + --registration-method ANTS \ + --ANTS-iterations 30 15 7 3 \ --prior-weight 0.5 \ --output-dir ./results diff --git a/docs/cli_scripts/fit_statistical_model_to_patient.rst b/docs/cli_scripts/fit_statistical_model_to_patient.rst index 3d065d3..5ed0a07 100644 --- a/docs/cli_scripts/fit_statistical_model_to_patient.rst +++ b/docs/cli_scripts/fit_statistical_model_to_patient.rst @@ -119,7 +119,7 @@ Registration Configuration Enable mask-to-image refinement registration. Requires ``--template-labelmap`` and template label IDs. Disabled by default. -``--use-icon-refinement`` +``--use-ICON-refinement`` Enable ICON deep learning registration refinement (default: disabled) Output Options diff --git a/docs/developer/registration_images.rst b/docs/developer/registration_images.rst index 66e01e5..dab4078 100644 --- a/docs/developer/registration_images.rst +++ b/docs/developer/registration_images.rst @@ -12,12 +12,12 @@ Basic Pattern import itk - from physiomotion4d import RegisterImagesANTs + from physiomotion4d import RegisterImagesANTS fixed = itk.imread("fixed.mha") moving = itk.imread("moving.mha") - registrar = RegisterImagesANTs() + registrar = RegisterImagesANTS() registrar.set_modality("ct") registrar.set_fixed_image(fixed) @@ -38,7 +38,7 @@ Time Series images = [itk.imread(f"phase_{idx:02d}.mha") for idx in range(10)] - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") registrar.set_fixed_image(images[0]) result = registrar.register_time_series( moving_images=images, diff --git a/docs/examples.rst b/docs/examples.rst index 283e018..108b07a 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -66,7 +66,7 @@ phase images: time_series_images=time_series_images, fixed_image=fixed_image, reference_frame=0, - registration_method="ants", + registration_method="ANTS", ) result = workflow.run_workflow(upsample_to_fixed_resolution=True) @@ -176,10 +176,10 @@ Advanced registration with multiple stages: .. code-block:: python - from physiomotion4d import RegisterImagesANTs + from physiomotion4d import RegisterImagesANTS import itk - registerer = RegisterImagesANTs() + registerer = RegisterImagesANTS() fixed = itk.imread("reference.mha") moving = itk.imread("target.mha") diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 4c76264..2ba1f92 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -244,7 +244,7 @@ Script ``tutorials/tutorial_08_dirlab_pca_time_series.py`` Workflow - ``RegisterTimeSeriesImages`` with ``registration_method='ants_icon'`` and + ``RegisterTimeSeriesImages`` with ``registration_method='ANTS_ICON'`` and ``TransformTools`` Dataset diff --git a/experiments/Heart-Create_Statistical_Model/3-registration_based_correspondence.py b/experiments/Heart-Create_Statistical_Model/3-registration_based_correspondence.py index 2baf2ef..59179df 100644 --- a/experiments/Heart-Create_Statistical_Model/3-registration_based_correspondence.py +++ b/experiments/Heart-Create_Statistical_Model/3-registration_based_correspondence.py @@ -114,7 +114,7 @@ # This performs progressive multi-stage registration: rigid → affine → SyN deformable result = registrar.register( transform_type="Deformable", # Uses ANTs SyN (Symmetric Normalization) - use_icon=False, # Set to True for additional ICON deep learning refinement + use_ICON=False, # Set to True for additional ICON deep learning refinement ) forward_transform = result["forward_transform"] @@ -341,6 +341,6 @@ # 1. Rigid alignment # 2. Affine transformation # 3. SyN deformable registration (diffeomorphic) -# - Setting `use_icon=True` in the `register()` call would add ICON deep learning refinement after SyN +# - Setting `use_ICON=True` in the `register()` call would add ICON deep learning refinement after SyN # - The `roi_dilation_mm` parameter controls the dilation of the ROI mask (default 20mm) # - SyN registration provides smooth, invertible deformation fields for anatomical correspondence diff --git a/experiments/Heart-GatedCT_To_USD/1-register_images.py b/experiments/Heart-GatedCT_To_USD/1-register_images.py index 2ac45cf..5ac88e8 100644 --- a/experiments/Heart-GatedCT_To_USD/1-register_images.py +++ b/experiments/Heart-GatedCT_To_USD/1-register_images.py @@ -4,7 +4,7 @@ import itk -from physiomotion4d.register_images_ants import RegisterImagesANTs +from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator from physiomotion4d.test_tools import TestTools from physiomotion4d.transform_tools import TransformTools @@ -76,7 +76,7 @@ ) # %% - reg = RegisterImagesANTs() + reg = RegisterImagesANTS() reg.set_mask_dilation(5) reg.set_number_of_iterations([10, 5, 2]) diff --git a/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py b/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py index 95b9991..addadea 100644 --- a/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py +++ b/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py @@ -16,7 +16,7 @@ from itk import TubeTK as ttk from physiomotion4d.test_tools import TestTools -from physiomotion4d.register_images_ants import RegisterImagesANTs +from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_greedy import RegisterImagesGreedy from physiomotion4d.register_images_icon import RegisterImagesICON @@ -106,7 +106,7 @@ # --- ANTs (deformable SyN) --- try: - reg_a = RegisterImagesANTs() + reg_a = RegisterImagesANTS() reg_a.set_modality("ct") reg_a.set_transform_type("Deformable") reg_a.set_number_of_iterations([10, 5, 2]) # reduced for speed diff --git a/experiments/Heart-Simpleware_Segmentation/simpleware_heart_segmentation.py b/experiments/Heart-Simpleware_Segmentation/simpleware_heart_segmentation.py index 0def33e..19589a2 100644 --- a/experiments/Heart-Simpleware_Segmentation/simpleware_heart_segmentation.py +++ b/experiments/Heart-Simpleware_Segmentation/simpleware_heart_segmentation.py @@ -24,14 +24,17 @@ # %% import logging import os +import tkinter as tk +from tkinter import filedialog import itk import matplotlib.pyplot as plt import numpy as np import pyvista as pv -from physiomotion4d.test_tools import TestTools +from physiomotion4d.landmark_tools import LandmarkTools from physiomotion4d.segment_heart_simpleware import SegmentHeartSimpleware +from physiomotion4d.test_tools import TestTools _HERE = os.path.dirname(os.path.abspath(__file__)) @@ -57,9 +60,18 @@ # Load a cardiac CT image for segmentation. This should be a 3D volume containing the heart. # %% -input_image_path = os.path.join( - _HERE, "..", "..", "data", "CHOP-Valve4D", "CT", "RVOT28-Dias.nii.gz" -) +if TestTools.running_as_test(): + input_image_path = os.path.join( + _HERE, "..", "..", "data", "CHOP-Valve4D", "CT", "RVOT28-Dias.nii.gz" + ) +else: + root = tk.Tk() + root.withdraw() + input_image_path = filedialog.askopenfilename( + title="Select a cardiac CT image", + filetypes=[("NIfTI", "*.nii.gz"), ("MetaIO", "*.mhd"), ("All files", "*.*")], + ) + root.destroy() # Load the image try: @@ -79,6 +91,7 @@ # Display a few slices of the input image to verify it loaded correctly. # %% +image_array = None if input_image is not None: # Get numpy array from ITK image image_array = itk.array_from_image(input_image) @@ -87,6 +100,8 @@ fig, axes = plt.subplots(1, 3, figsize=(15, 5)) # Axial slice (middle) + dirs = input_image.GetDirection() + print(f"Direction: {dirs}") axial_slice = image_array[image_array.shape[0] // 2, :, :] axes[0].imshow(axial_slice, cmap="gray", vmin=-200, vmax=400) axes[0].set_title("Axial View") @@ -158,8 +173,8 @@ try: # Perform segmentation - # Set contrast_enhanced_study=True if your CT scan used contrast agent - result = segmenter.segment(input_image, contrast_enhanced_study=True) + # For Simpleware, set contrast_enhanced_study=False always! + result = segmenter.segment(input_image, contrast_enhanced_study=False) print("\nSegmentation completed successfully!") @@ -191,6 +206,18 @@ os.path.join(output_dir, "contrast_mask_simpleware.nii.gz"), compression=True, ) + itk.imwrite( + input_image, + os.path.join(output_dir, "input_image_simpleware.nii.gz"), + compression=True, + ) + + # Save landmarks + print("\nSaving landmarks...") + landmarks = segmenter.get_landmarks() + LandmarkTools().write_landmarks_3dslicer( + landmarks=landmarks, path=os.path.join(output_dir, "landmarks.mrk.json") + ) except FileNotFoundError as e: print(f"\nError: {e}") @@ -269,7 +296,6 @@ # %% if result is not None and input_image is not None: # Get arrays - image_array = itk.array_from_image(input_image) labelmap_array = itk.array_from_image(result["labelmap"]) heart_array = itk.array_from_image(result["heart"]) vessels_array = itk.array_from_image(result["major_vessels"]) diff --git a/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient.py b/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient.py index 7412633..a16b9b3 100644 --- a/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient.py +++ b/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient.py @@ -190,7 +190,7 @@ # Perform deformable registration print("Starting deformable mask-to-mask registration...") - m2m_results = registrar.register_mask_to_mask(use_icon_refinement=False) + m2m_results = registrar.register_mask_to_mask(use_ICON_refinement=False) m2m_inverse_transform = m2m_results["inverse_transform"] m2m_forward_transform = m2m_results["forward_transform"] m2m_model_surface = m2m_results["registered_template_model_surface"] diff --git a/experiments/LongitudinalRegistration/.gitignore b/experiments/LongitudinalRegistration/.gitignore new file mode 100644 index 0000000..19960a9 --- /dev/null +++ b/experiments/LongitudinalRegistration/.gitignore @@ -0,0 +1 @@ +uniGradICON diff --git a/experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py b/experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py new file mode 100644 index 0000000..834df73 --- /dev/null +++ b/experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py @@ -0,0 +1,116 @@ +# %% [markdown] +# # Segment and Landmark Duke 4D Gated CT Data +# +# Runs Simpleware ASCardio segmentation on each gated time-point image and +# stores the labelmap plus extracted landmarks for later registration +# accuracy experiments. +# +# Each `gated_nii//` directory contains one patient's time-point +# images. The script maps each `ref_images/pm00*.nii.gz` file to the matching +# `gated_nii//` directory by the first six filename characters. +# +# Segmentation labelmaps and landmarks are written to: +# `d:/PhysioMotion4D/duke_data/simple_ascardio//` +# +# Output files follow the input stem: +# - `_labelmap.nii.gz` +# - `_landmark.csv` +# + +# %% +import os +from pathlib import Path + +import itk + +from physiomotion4d import SegmentHeartSimpleware +from physiomotion4d.landmark_tools import LandmarkTools + +# %% +# Discover data (mirrors recon_4d.py) +######################################################## + +ref_data_dir = "d:/PhysioMotion4D/duke_data/ref_images" +src_data_dir_base = "d:/PhysioMotion4D/duke_data/gated_nii" +segmentation_dir_base = "d:/PhysioMotion4D/duke_data/simple_ascardio" + +ref_files = [ + os.path.join(ref_data_dir, f) + for f in sorted(os.listdir(ref_data_dir)) + if f.startswith("pm00") and f.endswith(".nii.gz") +] + +print(f"Found {len(ref_files)} reference images") + +src_data_dirs = [] +src_data_files = [] +for ref_file in ref_files: + src_dir = os.path.join(src_data_dir_base, os.path.basename(ref_file)[:6]) + src_data_dirs.append(src_dir) + + file_list = sorted(os.listdir(src_dir)) + valid_file_list = [f for f in file_list if "nop" not in f and f.endswith(".nii.gz")] + src_data_files.append(valid_file_list) + +print(f"Found {len(src_data_dirs)} source data directories") +for d, fs in zip(src_data_dirs, src_data_files, strict=True): + print(f" {d}: {len(fs)} files") + +# %% +# Function to segment images and save labelmaps and landmarks +######################################################## + + +def segment_images( + src_data_dirs: list[str], + src_data_files: list[list[str]], +) -> dict[str, str]: + """Segment each image with SegmentHeartSimpleware and save labelmaps. + + Skips images whose labelmap file already exists. Returns a mapping from + image path to labelmap path for all images. + + Args: + src_data_dirs: List of per-patient source directories. + src_data_files: List of per-patient filename lists. + + Returns: + Dict mapping absolute image path -> absolute labelmap path. + """ + segmenter = SegmentHeartSimpleware() + image_to_labelmap: dict[str, str] = {} + + for src_dir, files in zip(src_data_dirs, src_data_files, strict=True): + print(f"Segmenting {src_dir}...") + subject_id = os.path.basename(src_dir) + labelmap_dir = Path(segmentation_dir_base) / subject_id + labelmap_dir.mkdir(parents=True, exist_ok=True) + for f in files: + print(f" Segmenting {f}...") + labelmap_path = labelmap_dir / f.replace(".nii.gz", "_labelmap.nii.gz") + landmark_path = labelmap_dir / f.replace(".nii.gz", "_landmark.mrk.json") + + if not os.path.exists(labelmap_path) or not os.path.exists(landmark_path): + image_path = os.path.join(src_dir, f) + input_image = itk.imread(image_path, pixel_type=itk.F) + results = segmenter.segment(input_image, contrast_enhanced_study=False) + labelmap = results["labelmap"] + itk.imwrite(labelmap, str(labelmap_path), compression=True) + + landmarks = segmenter.get_landmarks() + LandmarkTools().write_landmarks_3dslicer(landmarks, landmark_path) + + image_to_labelmap[os.path.join(src_dir, f)] = str(labelmap_path) + + return image_to_labelmap + + +# %% +# Segment each image and save labelmaps +######################################################## + +print("\nSegmenting images...") +image_to_labelmap = segment_images(src_data_dirs, src_data_files) +print(f"Segmentation complete. {len(image_to_labelmap)} labelmaps available.\n") + +# %% diff --git a/experiments/LongitudinalRegistration/1-finetune_icon.py b/experiments/LongitudinalRegistration/1-finetune_icon.py new file mode 100644 index 0000000..968a078 --- /dev/null +++ b/experiments/LongitudinalRegistration/1-finetune_icon.py @@ -0,0 +1,186 @@ +# %% [markdown] +# # Fine-tune uniGradICON on Duke 4D Gated CT Data +# +# Discovers per-patient gated CT images and their precomputed +# SegmentHeartSimpleware labelmaps and applies the project-wide fixed 80/20 +# train/test split (sort patients in ``ref_data_dir`` by filename; the first +# 80% are train, the last 20% are test). The train cohort is handed to +# :class:`WorkflowFineTuneICONRegistration`, which builds the paired dataset +# JSON, YAML config, and derived loss-function masks, then launches +# ``unigradicon.finetuning.finetune`` as a subprocess. +# +# ``2-recon_4d_icon_eval.py`` re-derives the same split from the same sorted +# patient list — no cached split file is needed. +# +# Each patient directory under ``src_data_dir_base`` is one ``subject_id``; +# all of that patient's gated time-point frames form a paired training group. +# Frames whose labelmap is missing on disk are dropped from the dataset. + +# %% +import os +from pathlib import Path + +import itk + +from physiomotion4d import WorkflowFineTuneICONRegistration +from physiomotion4d.register_images_icon import RegisterImagesICON + +# %% [markdown] +# ## 1. Configure data, output locations, and the train/test split + +# %% +ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") +src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") +segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") + +# Where the workflow writes the dataset JSON, YAML config, derived masks, and +# the uniGradICON ``checkpoints/`` tree. experiment_dir resolves to +# ``output_dir / fine_tune_name``. +output_dir = Path("./results") +fine_tune_name = "icon_finetuned" + +# Fixed train/test split: sort patients in ``ref_data_dir`` by filename; +# first 80% are train, last 20% are test. ``2-recon_4d_icon_eval.py`` applies +# the same rule so the two scripts agree without a cached split record. +train_fraction = 0.8 + +# Local clone of uniGradICON (feat-add-finetuning branch) — prepended to +# PYTHONPATH so the subprocess picks up the local source instead of the +# installed package. Set to ``None`` to use the pip-installed unigradicon. +unigradicon_src_path: Path | None = Path(__file__).parent / "uniGradICON" / "src" + +# %% [markdown] +# ## 2. Enumerate patients and apply the fixed 80/20 split +# +# Sort ``ref_data_dir`` by filename to produce the canonical patient order. +# The first 80% become the train cohort; the last 20% are the held-out test +# cohort that ``2-recon_4d_icon_eval.py`` will evaluate. + +# %% +ref_files = sorted( + p + for p in ref_data_dir.iterdir() + if p.name.startswith("pm00") and p.suffixes[-2:] == [".nii", ".gz"] +) +all_patient_ids = [p.name[:6] for p in ref_files] +print(f"Found {len(all_patient_ids)} patients under {ref_data_dir}") + +if len(all_patient_ids) < 2: + raise FileNotFoundError( + f"Need at least 2 patients to form a train/test split; " + f"discovered {len(all_patient_ids)} under {ref_data_dir}" + ) + +n_train = max( + 1, + min(len(all_patient_ids) - 1, round(train_fraction * len(all_patient_ids))), +) +train_subjects = all_patient_ids[:n_train] +test_subjects = all_patient_ids[n_train:] +print(f" Train (first {n_train}): {train_subjects}") +print(f" Test (last {len(test_subjects)}): {test_subjects}") + +# %% [markdown] +# ## 3. Gather the train cohort's gated frames and labelmaps +# +# For each train-cohort patient, list gated frames in +# ``src_data_dir_base / `` (excluding ``"nop"`` non-gated +# references) and pair each frame with its +# ``_labelmap.nii.gz`` under ``segmentation_dir_base / ``. +# Patients with no source directory or no valid frames are skipped here only +# — they remain part of the canonical train list above, but contribute no +# training data. Missing labelmaps are recorded as ``None`` so the workflow +# skips just that frame. + +# %% +train_image_files: list[list[str]] = [] +train_segmentation_files: list[list[str | None]] = [] +valid_train_subjects: list[str] = [] + +for patient_id in train_subjects: + src_dir = src_data_dir_base / patient_id + seg_dir = segmentation_dir_base / patient_id + + if not src_dir.is_dir(): + print(f" Skipping {patient_id}: source dir {src_dir} not found") + continue + + frame_names = sorted( + f for f in os.listdir(src_dir) if "nop" not in f and f.endswith(".nii.gz") + ) + if not frame_names: + print(f" Skipping {patient_id}: no valid frames in {src_dir}") + continue + + image_paths = [str(src_dir / f) for f in frame_names] + seg_paths: list[str | None] = [] + for f in frame_names: + labelmap = seg_dir / f.replace(".nii.gz", "_labelmap.nii.gz") + seg_paths.append(str(labelmap) if labelmap.exists() else None) + + train_image_files.append(image_paths) + train_segmentation_files.append(seg_paths) + valid_train_subjects.append(patient_id) + + n_seg = sum(1 for s in seg_paths if s is not None) + print(f" {patient_id}: {len(image_paths)} frames, {n_seg} with labelmap") + +# %% [markdown] +# ## 4. Pre-compute loss-function masks next to each labelmap +# +# Use :meth:`RegisterImagesICON.create_mask` (``>0`` threshold + 5 mm +# physical-radius dilation) to derive each frame's binary heart-ROI mask and +# write it as ``_mask.nii.gz`` in the labelmap's own directory. +# Pre-computing here means the workflow does not have to re-derive masks +# during ``run_fine_tuning`` and the same masks are reused by downstream +# evaluation scripts. + +# %% +mask_dilation_mm = 5.0 +train_mask_files: list[list[str | None]] = [] +for image_paths, seg_paths in zip( + train_image_files, train_segmentation_files, strict=True +): + mask_paths: list[str | None] = [] + for seg_path in seg_paths: + if seg_path is None: + mask_paths.append(None) + continue + seg_p = Path(seg_path) + stem = seg_p.name + stem = stem[:-7] if stem.endswith(".nii.gz") else seg_p.stem + mask_p = seg_p.parent / f"{stem}_mask.nii.gz" + if not mask_p.exists(): + mask = RegisterImagesICON.create_mask( + itk.imread(str(seg_p)), dilation_mm=mask_dilation_mm + ) + itk.imwrite(mask, str(mask_p), compression=True) + mask_paths.append(str(mask_p)) + train_mask_files.append(mask_paths) + +# %% [markdown] +# ## 5. Fine-tune uniGradICON on the train cohort +# +# The workflow consumes both the labelmaps (for paired-with-seg training and +# ``use_label``) and the pre-computed masks (for ``loss_function_masking``) +# and launches ``unigradicon.finetuning.finetune`` as a subprocess. The +# final checkpoint lands at +# :meth:`WorkflowFineTuneICONRegistration.expected_weights_path`, which is +# the default ``--finetuned-weights-path`` read by ``2-recon_4d_icon_eval.py``. + +# %% +workflow = WorkflowFineTuneICONRegistration( + subject_image_files=train_image_files, + output_dir=output_dir, + fine_tune_name=fine_tune_name, + subject_ids=valid_train_subjects, + subject_segmentation_files=train_segmentation_files, + subject_mask_files=train_mask_files, + mask_dilation_mm=mask_dilation_mm, + unigradicon_src_path=unigradicon_src_path, + epochs=100, +) + +weights_path = workflow.run_fine_tuning() +print(f"\nFine-tuning complete. Expected weights at: {weights_path}") +print(f"Held-out test cohort (for 2-recon_4d_icon_eval.py): {test_subjects}") diff --git a/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py b/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py new file mode 100644 index 0000000..83ff66c --- /dev/null +++ b/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py @@ -0,0 +1,263 @@ +# %% [markdown] +# # Evaluate ICON default vs finetuned weights on held-out longitudinal CT +# +# Enumerates the Duke patient cohort by sorting ``ref_images/`` and uses the +# *last 20%* of patients as the held-out test set — the same fixed split +# applied by ``1-finetune_icon.py`` (first 80% train, last 20% test). For +# each test subject the 70th-percentile gated frame is selected as the +# reference and every other frame is registered to it twice with +# ``RegisterTimeSeriesImages``: once with the default uniGradICON weights and +# once with the finetuned checkpoint from ``1-finetune_icon.py``. The +# resampler-convention inverse transform (which maps moving-grid points back +# to reference-grid points) is applied to each time-point's precomputed +# landmarks to land them in reference space, and the Euclidean error against +# the reference landmarks is recorded. +# +# Run interactively cell-by-cell; all paths are hard-coded. + +# %% +import csv +import re +from pathlib import Path +from typing import Optional + +import itk +import numpy as np + +from physiomotion4d import RegisterTimeSeriesImages +from physiomotion4d.landmark_tools import LandmarkTools +from physiomotion4d.register_images_icon import RegisterImagesICON + +# %% [markdown] +# ## 1. Hard-coded paths and configuration + +# %% +ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") +timepoint_base_dir = Path("d:/PhysioMotion4D/duke_data/gated_nii") +segmentation_base_dir = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") +output_dir = Path("./results") +finetuned_weights_path = Path( + "./results/icon_finetuned/checkpoints/Finetune_multi_final.trch" +) + +train_fraction = 0.8 +icon_iterations = 20 +reference_percentile = 0.70 +exclude_tokens = ("nop", "dia", "sys", "_ref") +timepoint_re = re.compile(r"_g(?P[0-9]{3})") + +methods: list[tuple[str, Optional[Path]]] = [ + ("icon_default", None), + ("icon_finetuned", finetuned_weights_path), +] + +output_dir.mkdir(parents=True, exist_ok=True) +detail_file = output_dir / "landmark_errors_by_point.csv" +summary_file = output_dir / "registration_summary.csv" +if detail_file.exists(): + detail_file.unlink() + +# %% [markdown] +# ## 2. Derive the held-out test cohort +# +# The fixed split is: sort ``ref_data_dir`` by filename, take the *first* +# 80% of patients as train, the *last* 20% as test. ``1-finetune_icon.py`` +# applies the same rule so the two scripts agree without any cached record. + +# %% +ref_files = sorted( + p + for p in ref_data_dir.iterdir() + if p.name.startswith("pm00") and p.suffixes[-2:] == [".nii", ".gz"] +) +all_patient_ids = [p.name[:6] for p in ref_files] +n_train = max( + 1, min(len(all_patient_ids) - 1, round(train_fraction * len(all_patient_ids))) +) +test_subjects = all_patient_ids[n_train:] +print( + f"Cohort: {len(all_patient_ids)} patients; " + f"first {n_train} train, last {len(test_subjects)} test." +) +print(f"Held-out test subjects: {test_subjects}") + +# %% [markdown] +# ## 3. Reader instance used in the per-frame inner loop +# +# Landmarks are read with :meth:`LandmarkTools.read_landmarks_3dslicer` — +# they were written as ``_landmark.mrk.json`` (3D Slicer Markups JSON, +# LPS) by ``0-cardiacGatedCT_segment_and_landmark.py``. Binary registration +# masks come from :meth:`RegisterImagesICON.create_mask` (``>0`` threshold +# plus 5 mm dilation by default), matching the loss-function masks used +# during fine-tuning in ``1-finetune_icon.py``. + +# %% +landmark_tools = LandmarkTools() + + +# %% [markdown] +# ## 4. Register and score every test subject under both ICON methods + +# %% +summary_rows: list[dict[str, object]] = [] + +for subject_id in test_subjects: + source_dir = timepoint_base_dir / subject_id + seg_dir = segmentation_base_dir / subject_id + + image_files = [ + p + for p in sorted(source_dir.glob("*.nii.gz")) + if not any(t in p.name for t in exclude_tokens) + ] + stems = [p.name[:-7] for p in image_files] + labelmap_files = [seg_dir / f"{s}_labelmap.nii.gz" for s in stems] + landmark_files = [seg_dir / f"{s}_landmark.mrk.json" for s in stems] + timepoints = [timepoint_re.search(p.name).group("timepoint") for p in image_files] + + reference_index = int(round(reference_percentile * (len(image_files) - 1))) + print( + f"\nSubject {subject_id}: {len(image_files)} time points, " + f"reference index {reference_index} (g{timepoints[reference_index]})" + ) + + fixed_image = itk.imread(str(image_files[reference_index]), pixel_type=itk.F) + fixed_mask = RegisterImagesICON.create_mask( + itk.imread(str(labelmap_files[reference_index])) + ) + reference_landmarks = landmark_tools.read_landmarks_3dslicer( + landmark_files[reference_index] + ) + + moving_images = [itk.imread(str(p), pixel_type=itk.F) for p in image_files] + moving_masks = [ + RegisterImagesICON.create_mask(itk.imread(str(p))) for p in labelmap_files + ] + + for method_name, weights_path in methods: + print(f" Method: {method_name}") + registrar = RegisterTimeSeriesImages(registration_method="ICON") + registrar.set_modality("ct") + registrar.set_fixed_image(fixed_image) + registrar.set_fixed_mask(fixed_mask) + registrar.set_number_of_iterations_ICON(icon_iterations) + if weights_path is not None: + registrar.registrar_ICON.set_weights_path(str(weights_path)) + + result = registrar.register_time_series( + moving_images=moving_images, + moving_masks=moving_masks, + moving_labelmaps=None, + reference_frame=reference_index, + register_reference=False, + prior_weight=0.0, + ) + + method_dir = output_dir / method_name / subject_id + method_dir.mkdir(parents=True, exist_ok=True) + + for index, image_file in enumerate(image_files): + if index == reference_index: + continue + timepoint = timepoints[index] + timepoint_dir = method_dir / timepoint + timepoint_dir.mkdir(parents=True, exist_ok=True) + + inverse_transform = result["inverse_transforms"][index] + itk.transformwrite( + result["forward_transforms"][index], + str(timepoint_dir / "time_to_reference.hdf"), + compression=True, + ) + itk.transformwrite( + inverse_transform, + str(timepoint_dir / "reference_to_time.hdf"), + compression=True, + ) + + # inverse_transform follows the ITK resampler convention — it maps + # moving-grid points back to reference-grid points, which is what + # we need to warp time-point landmarks into reference space. + timepoint_landmarks = landmark_tools.read_landmarks_3dslicer( + landmark_files[index] + ) + shared = sorted(timepoint_landmarks.keys() & reference_landmarks.keys()) + errors: list[tuple[str, float]] = [] + for name in shared: + warped = inverse_transform.TransformPoint(timepoint_landmarks[name]) + err = float( + np.linalg.norm( + np.asarray(warped, dtype=np.float64) + - np.asarray(reference_landmarks[name], dtype=np.float64) + ) + ) + errors.append((name, err)) + + with detail_file.open("a", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + if fh.tell() == 0: + writer.writerow( + ["subject_id", "method", "timepoint", "name", "error_mm"] + ) + for name, err in errors: + writer.writerow([subject_id, method_name, timepoint, name, err]) + + values = np.asarray([e for _, e in errors], dtype=np.float64) + summary_rows.append( + { + "subject_id": subject_id, + "method": method_name, + "reference_timepoint": timepoints[reference_index], + "timepoint": timepoint, + "loss": float(result["losses"][index]), + "n_landmarks": int(values.size), + "mean_mm": float(np.mean(values)) if values.size else "", + "median_mm": float(np.median(values)) if values.size else "", + "max_mm": float(np.max(values)) if values.size else "", + } + ) + +# %% [markdown] +# ## 5. Write the wide-form per-timepoint summary CSV + +# %% +with summary_file.open("w", newline="", encoding="utf-8") as fh: + writer = csv.DictWriter(fh, fieldnames=list(summary_rows[0].keys())) + writer.writeheader() + writer.writerows(summary_rows) +print(f"Wrote summary: {summary_file}") +print(f"Wrote landmark details: {detail_file}") + +# %% [markdown] +# ## 6. Per-method aggregate table over all test subjects + +# %% +groups: dict[str, list[float]] = {} +with detail_file.open(newline="", encoding="utf-8") as fh: + for row in csv.DictReader(fh): + groups.setdefault(row["method"], []).append(float(row["error_mm"])) + +header = ( + f"{'Method':<18}{'N':>8}{'Mean (mm)':>12}" + f"{'Median (mm)':>14}{'P95 (mm)':>12}{'Max (mm)':>12}" +) +print() +print("=" * len(header)) +print(f"Landmark error summary ({len(test_subjects)} test subjects)") +print("=" * len(header)) +print(header) +print("-" * len(header)) +for method_name, _ in methods: + arr = np.asarray(groups.get(method_name, []), dtype=np.float64) + if arr.size == 0: + print(f"{method_name:<18}{0:>8}{'':>12}{'':>14}{'':>12}{'':>12}") + continue + print( + f"{method_name:<18}" + f"{arr.size:>8}" + f"{float(np.mean(arr)):>12.3f}" + f"{float(np.median(arr)):>14.3f}" + f"{float(np.percentile(arr, 95)):>12.3f}" + f"{float(np.max(arr)):>12.3f}" + ) +print("=" * len(header)) diff --git a/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py b/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py new file mode 100644 index 0000000..5467cad --- /dev/null +++ b/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py @@ -0,0 +1,700 @@ +"""Compare longitudinal cardiac CT registration methods with landmarks. + +The experiment registers each gated time-point image to the high-resolution +reference image for the same subject. Input images are 3D CT volumes in LPS +world space. Landmarks are CSV rows with physical LPS coordinates +``Name,X,Y,Z`` in millimeters. + +Two accuracy modes are written: +1. Direct landmarks: reference landmarks are transformed into each time-point + image space with the inverse registration transform and compared to the + precomputed time-point landmarks. +2. Re-segmented landmarks: the reference image is warped into each time-point + image space, re-segmented with Simpleware, and the newly extracted landmarks + are compared to the precomputed time-point landmarks. +""" + +from __future__ import annotations + +import argparse +import csv +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import itk +import numpy as np + +from physiomotion4d import ( + RegisterTimeSeriesImages, + SegmentHeartSimpleware, + TransformTools, +) + +DEFAULT_REF_DIR = Path("d:/PhysioMotion4D/duke_data/ref_images") +DEFAULT_TIMEPOINT_BASE_DIR = Path("d:/PhysioMotion4D/duke_data/gated_nii") +DEFAULT_SEGMENTATION_BASE_DIR = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") +DEFAULT_OUTPUT_DIR = Path("d:/PhysioMotion4D/duke_data/longitudinal_registration") +DEFAULT_EXCLUDE_TOKENS = ("nop", "dia", "sys", "_ref") +DEFAULT_SEGMENTATION_DIR = "results-labelmaps_and_landmarks" +DEFAULT_METHODS = ("ANTS", "greedy", "icon_default", "ants_icon_default") +TIMEPOINT_RE = re.compile(r"_g(?P[0-9]{3})") + + +Landmarks = dict[str, tuple[float, float, float]] + + +@dataclass(frozen=True) +class MethodSpec: + """Registration method plus optional ICON checkpoint.""" + + output_name: str + registration_method: str + icon_weights_path: Optional[Path] = None + + +@dataclass(frozen=True) +class ImageArtifacts: + """Input files associated with one image volume.""" + + image_file: Path + landmark_file: Optional[Path] + labelmap_file: Optional[Path] + timepoint: str + + +def nii_stem(path: Path) -> str: + """Return a stable stem for ``.nii.gz`` or single-suffix files.""" + if path.name.endswith(".nii.gz"): + return path.name[:-7] + return path.stem + + +def timepoint_from_name(path: Path) -> str: + """Extract the gated time-point tag from a filename.""" + match = TIMEPOINT_RE.search(path.name) + if match: + return match.group("timepoint") + return nii_stem(path) + + +def first_existing(paths: list[Path]) -> Optional[Path]: + """Return the first existing path from a candidate list.""" + for path in paths: + if path.exists(): + return path + return None + + +def landmark_candidates( + image_file: Path, + segmentation_dir: str, + artifact_dir: Optional[Path], +) -> list[Path]: + """Return likely landmark CSV paths for an image.""" + stem = nii_stem(image_file) + parent = image_file.parent + seg_parent = parent / segmentation_dir + candidates = [ + parent / f"{stem}_landmark.csv", + parent / f"{stem}_landmarks.csv", + seg_parent / f"{stem}_landmark.csv", + seg_parent / f"{stem}_landmarks.csv", + ] + if artifact_dir is not None: + candidates = [ + artifact_dir / f"{stem}_landmark.csv", + artifact_dir / f"{stem}_landmarks.csv", + *candidates, + ] + return candidates + + +def labelmap_candidates( + image_file: Path, + segmentation_dir: str, + artifact_dir: Optional[Path], +) -> list[Path]: + """Return likely labelmap paths for an image.""" + stem = nii_stem(image_file) + parent = image_file.parent + seg_parent = parent / segmentation_dir + candidates = [ + parent / f"{stem}_labelmap.nii.gz", + seg_parent / f"{stem}_labelmap.nii.gz", + ] + if artifact_dir is not None: + candidates = [artifact_dir / f"{stem}_labelmap.nii.gz", *candidates] + return candidates + + +def image_artifacts( + image_file: Path, + segmentation_dir: str, + artifact_dir: Optional[Path] = None, +) -> ImageArtifacts: + """Find landmarks and labelmaps associated with one image.""" + return ImageArtifacts( + image_file=image_file, + landmark_file=first_existing( + landmark_candidates(image_file, segmentation_dir, artifact_dir) + ), + labelmap_file=first_existing( + labelmap_candidates(image_file, segmentation_dir, artifact_dir) + ), + timepoint=timepoint_from_name(image_file), + ) + + +def read_landmarks(path: Path) -> Landmarks: + """Read physical LPS landmarks from ``Name,X,Y,Z`` CSV.""" + landmarks: Landmarks = {} + with path.open(newline="", encoding="utf-8-sig") as fh: + for row in csv.DictReader(fh): + landmarks[row["Name"]] = ( + float(row["X"]), + float(row["Y"]), + float(row["Z"]), + ) + return landmarks + + +def write_landmarks(path: Path, landmarks: Landmarks) -> None: + """Write physical LPS landmarks to ``Name,X,Y,Z`` CSV.""" + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + writer.writerow(["Name", "X", "Y", "Z"]) + for name, coords in sorted(landmarks.items()): + writer.writerow([name, coords[0], coords[1], coords[2]]) + + +def transform_landmarks(landmarks: Landmarks, transform: itk.Transform) -> Landmarks: + """Apply an ITK physical-space transform to landmark coordinates.""" + transformed: Landmarks = {} + for name, point in landmarks.items(): + transformed_point = transform.TransformPoint(point) + transformed[name] = ( + float(transformed_point[0]), + float(transformed_point[1]), + float(transformed_point[2]), + ) + return transformed + + +def landmark_errors(source: Landmarks, target: Landmarks) -> dict[str, float]: + """Return per-landmark Euclidean errors in millimeters.""" + errors: dict[str, float] = {} + for name in sorted(source.keys() & target.keys()): + source_point = np.asarray(source[name], dtype=np.float64) + target_point = np.asarray(target[name], dtype=np.float64) + errors[name] = float(np.linalg.norm(source_point - target_point)) + return errors + + +def summarize_errors(errors: dict[str, float], prefix: str) -> dict[str, object]: + """Summarize landmark errors for one comparison mode.""" + if not errors: + return { + f"{prefix}_landmarks": 0, + f"{prefix}_mean_mm": "", + f"{prefix}_median_mm": "", + f"{prefix}_max_mm": "", + } + values = np.asarray(list(errors.values()), dtype=np.float64) + return { + f"{prefix}_landmarks": len(errors), + f"{prefix}_mean_mm": float(np.mean(values)), + f"{prefix}_median_mm": float(np.median(values)), + f"{prefix}_max_mm": float(np.max(values)), + } + + +def write_error_details( + path: Path, + subject_id: str, + method_name: str, + timepoint: str, + mode: str, + errors: dict[str, float], +) -> None: + """Append per-landmark errors to the detail CSV.""" + exists = path.exists() + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", newline="", encoding="utf-8") as fh: + fieldnames = ["subject_id", "method", "timepoint", "mode", "name", "error_mm"] + writer = csv.DictWriter(fh, fieldnames=fieldnames) + if not exists: + writer.writeheader() + for name, error in sorted(errors.items()): + writer.writerow( + { + "subject_id": subject_id, + "method": method_name, + "timepoint": timepoint, + "mode": mode, + "name": name, + "error_mm": error, + } + ) + + +def dice_by_label( + labelmap_a: itk.Image, + labelmap_b: itk.Image, +) -> dict[int, float]: + """Compute Dice scores for labels present in either 3D labelmap.""" + arr_a = itk.array_from_image(labelmap_a) + arr_b = itk.array_from_image(labelmap_b) + if arr_a.shape != arr_b.shape: + return {} + labels = sorted(set(np.unique(arr_a)).union(set(np.unique(arr_b))) - {0}) + scores: dict[int, float] = {} + for label in labels: + mask_a = arr_a == label + mask_b = arr_b == label + denom = int(mask_a.sum() + mask_b.sum()) + if denom > 0: + scores[int(label)] = float( + 2.0 * np.logical_and(mask_a, mask_b).sum() / denom + ) + return scores + + +def summarize_dice(scores: dict[int, float]) -> dict[str, object]: + """Summarize per-label Dice scores.""" + if not scores: + return {"dice_labels": 0, "dice_mean": "", "dice_min": ""} + values = np.asarray(list(scores.values()), dtype=np.float64) + return { + "dice_labels": len(scores), + "dice_mean": float(np.mean(values)), + "dice_min": float(np.min(values)), + } + + +def discover_subjects( + reference_dir: Path, + timepoint_base_dir: Path, + reference_pattern: str, + timepoint_pattern: str, + exclude_tokens: tuple[str, ...], + segmentation_dir: str, + segmentation_base_dir: Optional[Path], +) -> list[tuple[str, ImageArtifacts, list[ImageArtifacts]]]: + """Discover reference and time-point files for each subject.""" + if not reference_dir.exists(): + raise FileNotFoundError(f"Reference image directory not found: {reference_dir}") + if not timepoint_base_dir.exists(): + raise FileNotFoundError( + f"Time-point image base directory not found: {timepoint_base_dir}" + ) + + subjects: list[tuple[str, ImageArtifacts, list[ImageArtifacts]]] = [] + for reference_file in sorted(reference_dir.glob(reference_pattern)): + subject_id = reference_file.name[:6] + source_dir = timepoint_base_dir / subject_id + if not source_dir.exists(): + raise FileNotFoundError( + f"No time-point directory for {subject_id}: {source_dir}" + ) + artifact_dir = None + if segmentation_base_dir is not None: + candidate_dir = segmentation_base_dir / subject_id + if candidate_dir.exists(): + artifact_dir = candidate_dir + + reference_in_source = source_dir / reference_file.name + reference_artifacts = image_artifacts( + reference_in_source if reference_in_source.exists() else reference_file, + segmentation_dir, + artifact_dir, + ) + + timepoint_files = [ + path + for path in sorted(source_dir.glob(timepoint_pattern)) + if not any(token in path.name for token in exclude_tokens) + ] + timepoints = [ + image_artifacts(path, segmentation_dir, artifact_dir) + for path in timepoint_files + if path.is_file() + ] + subjects.append((subject_id, reference_artifacts, timepoints)) + return subjects + + +def build_method_specs( + method_names: list[str], + finetuned_weights_path: Optional[Path], +) -> list[MethodSpec]: + """Map output method labels to registrar methods and optional weights.""" + specs: list[MethodSpec] = [] + for method_name in method_names: + if method_name == "ANTS": + specs.append(MethodSpec(method_name, "ANTS")) + elif method_name == "greedy": + specs.append(MethodSpec(method_name, "greedy")) + elif method_name == "icon_default": + specs.append(MethodSpec(method_name, "ICON")) + elif method_name == "ants_icon_default": + specs.append(MethodSpec(method_name, "ANTS_ICON")) + elif method_name == "greedy_icon_default": + specs.append(MethodSpec(method_name, "greedy_ICON")) + elif method_name == "icon_finetuned": + specs.append(MethodSpec(method_name, "ICON", finetuned_weights_path)) + elif method_name == "ants_icon_finetuned": + specs.append(MethodSpec(method_name, "ANTS_ICON", finetuned_weights_path)) + elif method_name == "greedy_icon_finetuned": + specs.append(MethodSpec(method_name, "greedy_ICON", finetuned_weights_path)) + else: + raise ValueError(f"Unknown method: {method_name}") + + for spec in specs: + if "finetuned" in spec.output_name and spec.icon_weights_path is None: + raise ValueError(f"{spec.output_name} requires --finetuned-weights-path") + return specs + + +def configure_registrar( + method_spec: MethodSpec, + fixed_image: itk.Image, + fixed_labelmap: Optional[itk.Image], + ants_iterations: list[int], + greedy_iterations: list[int], + icon_iterations: int, +) -> RegisterTimeSeriesImages: + """Create and configure the time-series registrar.""" + registrar = RegisterTimeSeriesImages( + registration_method=method_spec.registration_method + ) + registrar.set_modality("ct") + registrar.set_fixed_image(fixed_image) + registrar.set_fixed_labelmap(fixed_labelmap) + registrar.set_number_of_iterations_ANTS(ants_iterations) + registrar.set_number_of_iterations_greedy(greedy_iterations) + registrar.set_number_of_iterations_ICON(icon_iterations) + if method_spec.icon_weights_path is not None: + registrar.registrar_ICON.set_weights_path(str(method_spec.icon_weights_path)) + return registrar + + +def write_summary(path: Path, rows: list[dict[str, object]]) -> None: + """Write experiment summary rows.""" + if not rows: + return + path.parent.mkdir(parents=True, exist_ok=True) + fieldnames = list(rows[0].keys()) + with path.open("w", newline="", encoding="utf-8") as fh: + writer = csv.DictWriter(fh, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + +def run_method_for_subject( + subject_id: str, + reference_artifacts: ImageArtifacts, + timepoint_artifacts: list[ImageArtifacts], + method_spec: MethodSpec, + output_dir: Path, + run_resegmentation: bool, + ants_iterations: list[int], + greedy_iterations: list[int], + icon_iterations: int, + error_detail_file: Path, +) -> list[dict[str, object]]: + """Run one registration method for one subject and return summary rows.""" + if reference_artifacts.landmark_file is None: + raise FileNotFoundError( + f"Missing reference landmarks for {reference_artifacts.image_file}" + ) + + fixed_image = itk.imread(str(reference_artifacts.image_file), pixel_type=itk.F) + fixed_labelmap = None + if reference_artifacts.labelmap_file is not None: + fixed_labelmap = itk.imread(str(reference_artifacts.labelmap_file)) + + moving_images = [ + itk.imread(str(artifacts.image_file), pixel_type=itk.F) + for artifacts in timepoint_artifacts + ] + moving_labelmaps = None + if all(artifacts.labelmap_file is not None for artifacts in timepoint_artifacts): + moving_labelmaps = [ + itk.imread(str(artifacts.labelmap_file)) + for artifacts in timepoint_artifacts + ] + + registrar = configure_registrar( + method_spec, + fixed_image, + fixed_labelmap, + ants_iterations, + greedy_iterations, + icon_iterations, + ) + + result = registrar.register_time_series( + moving_images=moving_images, + moving_labelmaps=moving_labelmaps, + reference_frame=0, + register_reference=True, + prior_weight=0.0, + ) + + reference_landmarks = read_landmarks(reference_artifacts.landmark_file) + transform_tools = TransformTools() + segmenter = SegmentHeartSimpleware() if run_resegmentation else None + subject_method_dir = output_dir / method_spec.output_name / subject_id + subject_method_dir.mkdir(parents=True, exist_ok=True) + + rows: list[dict[str, object]] = [] + for index, artifacts in enumerate(timepoint_artifacts): + timepoint_dir = subject_method_dir / artifacts.timepoint + timepoint_dir.mkdir(parents=True, exist_ok=True) + + forward_transform = result["forward_transforms"][index] + inverse_transform = result["inverse_transforms"][index] + loss = result["losses"][index] + + forward_file = timepoint_dir / "time_to_reference.hdf" + inverse_file = timepoint_dir / "reference_to_time.hdf" + itk.transformwrite(forward_transform, str(forward_file), compression=True) + itk.transformwrite(inverse_transform, str(inverse_file), compression=True) + + moving_to_reference = transform_tools.transform_image( + moving_images[index], + forward_transform, + fixed_image, + ) + moving_to_reference_file = timepoint_dir / "time_to_reference.mha" + itk.imwrite( + moving_to_reference, str(moving_to_reference_file), compression=True + ) + + reference_to_time = transform_tools.transform_image( + fixed_image, + inverse_transform, + moving_images[index], + ) + reference_to_time_file = timepoint_dir / "reference_to_time.mha" + itk.imwrite(reference_to_time, str(reference_to_time_file), compression=True) + + row: dict[str, object] = { + "subject_id": subject_id, + "method": method_spec.output_name, + "timepoint": artifacts.timepoint, + "moving_image": str(artifacts.image_file), + "forward_transform": str(forward_file), + "inverse_transform": str(inverse_file), + "loss": float(loss), + } + + if artifacts.landmark_file is not None: + timepoint_landmarks = read_landmarks(artifacts.landmark_file) + direct_landmarks = transform_landmarks( + reference_landmarks, + inverse_transform, + ) + direct_errors = landmark_errors(direct_landmarks, timepoint_landmarks) + write_error_details( + error_detail_file, + subject_id, + method_spec.output_name, + artifacts.timepoint, + "direct", + direct_errors, + ) + row.update(summarize_errors(direct_errors, "direct")) + else: + row.update(summarize_errors({}, "direct")) + + if run_resegmentation and segmenter is not None: + segmentation = segmenter.segment( + reference_to_time, + contrast_enhanced_study=False, + ) + warped_labelmap = segmentation["labelmap"] + warped_labelmap_file = timepoint_dir / "reference_to_time_labelmap.nii.gz" + itk.imwrite(warped_labelmap, str(warped_labelmap_file), compression=True) + reseg_landmarks = segmenter.get_landmarks() + reseg_landmark_file = timepoint_dir / "reference_to_time_landmark.csv" + write_landmarks(reseg_landmark_file, reseg_landmarks) + row["resegmented_labelmap"] = str(warped_labelmap_file) + row["resegmented_landmarks"] = str(reseg_landmark_file) + + if artifacts.landmark_file is not None: + timepoint_landmarks = read_landmarks(artifacts.landmark_file) + reseg_errors = landmark_errors(reseg_landmarks, timepoint_landmarks) + write_error_details( + error_detail_file, + subject_id, + method_spec.output_name, + artifacts.timepoint, + "resegmented", + reseg_errors, + ) + row.update(summarize_errors(reseg_errors, "resegmented")) + else: + row.update(summarize_errors({}, "resegmented")) + + if artifacts.labelmap_file is not None: + timepoint_labelmap = itk.imread(str(artifacts.labelmap_file)) + row.update( + summarize_dice(dice_by_label(warped_labelmap, timepoint_labelmap)) + ) + else: + row.update(summarize_dice({})) + else: + row["resegmented_labelmap"] = "" + row["resegmented_landmarks"] = "" + row.update(summarize_errors({}, "resegmented")) + row.update(summarize_dice({})) + + rows.append(row) + + return rows + + +def parse_iterations(value: str) -> list[int]: + """Parse comma-separated multi-resolution iteration counts.""" + return [int(item.strip()) for item in value.split(",") if item.strip()] + + +def main() -> int: + """Run the longitudinal registration comparison experiment.""" + parser = argparse.ArgumentParser( + description="Compare ANTS, Greedy, and ICON longitudinal registration." + ) + parser.add_argument("--reference-dir", type=Path, default=DEFAULT_REF_DIR) + parser.add_argument( + "--timepoint-base-dir", + type=Path, + default=DEFAULT_TIMEPOINT_BASE_DIR, + ) + parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR) + parser.add_argument( + "--segmentation-base-dir", + type=Path, + default=DEFAULT_SEGMENTATION_BASE_DIR, + help="Directory with per-subject precomputed *_labelmap and *_landmark files.", + ) + parser.add_argument("--reference-pattern", default="pm00*.nii.gz") + parser.add_argument("--timepoint-pattern", default="*.nii.gz") + parser.add_argument("--segmentation-dir", default=DEFAULT_SEGMENTATION_DIR) + parser.add_argument( + "--exclude-token", + action="append", + default=list(DEFAULT_EXCLUDE_TOKENS), + help="Filename token to exclude from time-point inputs.", + ) + parser.add_argument( + "--methods", + nargs="+", + default=None, + help="Methods to run. Defaults include finetuned methods when weights are set.", + ) + parser.add_argument("--finetuned-weights-path", type=Path, default=None) + parser.add_argument("--max-subjects", type=int, default=None) + parser.add_argument("--max-timepoints", type=int, default=None) + parser.add_argument("--ANTS-iterations", default="30,15,7,3") + parser.add_argument("--greedy-iterations", default="30,15,7,3") + parser.add_argument("--ICON-iterations", type=int, default=20) + parser.add_argument( + "--skip-resegmentation", + action="store_true", + help="Skip Simpleware re-segmentation mode.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Validate discovered files and planned methods without registration.", + ) + args = parser.parse_args() + + method_names = args.methods + if method_names is None: + method_names = list(DEFAULT_METHODS) + method_names.append("greedy_icon_default") + if args.finetuned_weights_path is not None: + method_names.extend( + [ + "icon_finetuned", + "ants_icon_finetuned", + "greedy_icon_finetuned", + ] + ) + + method_specs = build_method_specs(method_names, args.finetuned_weights_path) + subjects = discover_subjects( + args.reference_dir, + args.timepoint_base_dir, + args.reference_pattern, + args.timepoint_pattern, + tuple(args.exclude_token), + args.segmentation_dir, + args.segmentation_base_dir, + ) + if args.max_subjects is not None: + subjects = subjects[: args.max_subjects] + + if args.dry_run: + for subject_id, reference_artifacts, timepoint_artifacts in subjects: + if args.max_timepoints is not None: + timepoint_artifacts = timepoint_artifacts[: args.max_timepoints] + missing_landmarks = sum( + artifacts.landmark_file is None for artifacts in timepoint_artifacts + ) + missing_labelmaps = sum( + artifacts.labelmap_file is None for artifacts in timepoint_artifacts + ) + print( + f"{subject_id}: {len(timepoint_artifacts)} time points, " + f"reference_landmarks={reference_artifacts.landmark_file is not None}, " + f"reference_labelmap={reference_artifacts.labelmap_file is not None}, " + f"missing_time_landmarks={missing_landmarks}, " + f"missing_time_labelmaps={missing_labelmaps}" + ) + print("Methods: " + ", ".join(spec.output_name for spec in method_specs)) + return 0 + + summary_rows: list[dict[str, object]] = [] + detail_file = args.output_dir / "landmark_errors_by_point.csv" + if detail_file.exists(): + detail_file.unlink() + + for subject_id, reference_artifacts, timepoint_artifacts in subjects: + if args.max_timepoints is not None: + timepoint_artifacts = timepoint_artifacts[: args.max_timepoints] + if not timepoint_artifacts: + raise ValueError(f"No time-point images found for {subject_id}") + print( + f"Running {subject_id}: {len(timepoint_artifacts)} time points, " + f"{len(method_specs)} methods" + ) + for method_spec in method_specs: + print(f" Method: {method_spec.output_name}") + rows = run_method_for_subject( + subject_id, + reference_artifacts, + timepoint_artifacts, + method_spec, + args.output_dir, + not args.skip_resegmentation, + parse_iterations(args.ants_iterations), + parse_iterations(args.greedy_iterations), + args.icon_iterations, + detail_file, + ) + summary_rows.extend(rows) + write_summary(args.output_dir / "registration_summary.csv", summary_rows) + + print(f"Wrote summary: {args.output_dir / 'registration_summary.csv'}") + print(f"Wrote landmark details: {detail_file}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/experiments/LongitudinalRegistration/experiment_recon_4d.py b/experiments/LongitudinalRegistration/experiment_recon_4d.py new file mode 100644 index 0000000..7d9ce26 --- /dev/null +++ b/experiments/LongitudinalRegistration/experiment_recon_4d.py @@ -0,0 +1,191 @@ +# %% [markdown] +# # 4D CT Reconstruction Using RegisterTimeSeriesImages Class +# +# This script demonstrates the use of the `RegisterTimeSeriesImages` class +# to register a time series of CT images to a common reference frame. +# +# This is a refactored version of `reconstruct_4d_ct.ipynb` that uses +# the new class-based approach,including: +# - Registration of time series images using ANTs, Greedy, ICON, or combined +# ANTs/Greedy + ICON methods +# - Reconstruction of time series using the `reconstruct_time_series()` method +# - Optional upsampling to fixed image resolution while preserving spatial positioning +# + +# %% +# Import necessary libraries +######################################################## + +import os + +import itk +import numpy as np +from physiomotion4d import RegisterTimeSeriesImages + +# %% +# Identify reference images +######################################################## + +ref_data_dir = "d:/PhysioMotion4D/duke_data/ref_images" +src_data_dir_base = "d:/PhysioMotion4D/duke_data/gated_nii" +dest_data_dir_base = "d:/PhysioMotion4D/duke_data/recon4d" + +ref_files = [ + os.path.join(ref_data_dir, f) + for f in sorted(os.listdir(ref_data_dir)) + if f.startswith("pm00") and f.endswith(".nii.gz") +] + +print(f"Found {len(ref_files)} reference images") + +# %% +# Identify source data directories and files using reference image names +######################################################## + + +print(os.path.basename(ref_files[0])[:6]) +src_data_dirs = [] +src_data_files = [] +for ref_file in ref_files: + src_dir = os.path.join(src_data_dir_base, os.path.basename(ref_file)[:6]) + src_data_dirs.append(src_dir) + + file_list = sorted(os.listdir(src_dir)) + valid_file_list = [ + f + for f in file_list + if "dia" not in f + and "nop" not in f + and "sys" not in f + and f.endswith(".nii.gz") + ] + src_data_files.append(valid_file_list) + +print(f"Found {len(src_data_dirs)} source data directories") +for d, fs in zip(src_data_dirs, src_data_files): + print(f"{d}: {len(fs)} files") + for f in fs: + print(f" {f}") + +# %% +# Define registration function +######################################################## + + +def register_time_series( + reference_image_file: str, + source_image_dir: str, + source_image_files: list[str], + registration_method: str, +) -> None: + # ANTs registration + if registration_method in ["ANTS", "greedy"]: + number_of_iterations = [30, 15, 7, 3] + elif registration_method == "ICON": + number_of_iterations = 20 + elif registration_method in ["ANTS_ICON", "greedy_ICON"]: + number_of_iterations = [[30, 15, 7, 3], 20] + else: + raise ValueError(f"Invalid registration method: {registration_method}") + + # Create output dir + output_dir = os.path.join( + dest_data_dir_base, registration_method, os.path.basename(source_image_dir) + ) + os.makedirs(output_dir, exist_ok=True) + + # Read the reference image as the fixed image + fixed_image = itk.imread(reference_image_file, pixel_type=itk.F) + + images = [] + for file in source_image_files: + img = itk.imread(os.path.join(source_image_dir, file), pixel_type=itk.F) + images.append(img) + + reference_image_num = 7 + register_start_to_reference = True + if reference_image_file in source_image_files: + reference_image_num = source_image_files.index(reference_image_file) + register_start_to_reference = False + + portion_of_prior_transform_to_init_next_transform = 0.0 + + # Register the time series + registrar = RegisterTimeSeriesImages(registration_method=registration_method) + registrar.set_modality("ct") + registrar.set_fixed_image(fixed_image) + if registration_method == "ANTS": + registrar.set_number_of_iterations_ANTS(number_of_iterations) + elif registration_method == "greedy": + registrar.set_number_of_iterations_greedy(number_of_iterations) + elif registration_method == "ICON": + registrar.set_number_of_iterations_ICON(number_of_iterations) + elif registration_method == "ANTS_ICON": + registrar.set_number_of_iterations_ANTS(number_of_iterations[0]) + registrar.set_number_of_iterations_ICON(number_of_iterations[1]) + elif registration_method == "greedy_ICON": + registrar.set_number_of_iterations_greedy(number_of_iterations[0]) + registrar.set_number_of_iterations_ICON(number_of_iterations[1]) + else: + raise ValueError(f"Invalid registration method: {registration_method}") + + result = registrar.register_time_series( + moving_images=images, + reference_frame=reference_image_num, + register_reference=register_start_to_reference, + prior_weight=portion_of_prior_transform_to_init_next_transform, + ) + + upsampled_images = registrar.reconstruct_time_series( + moving_images=images, + inverse_transforms=result["inverse_transforms"], + upsample_to_fixed_resolution=True, + ) + + losses = result["losses"] + print("Registration complete!") + print(f" Average loss: {np.mean(losses):.6f}") + print(f" Min loss: {np.min(losses):.6f}") + print(f" Max loss: {np.max(losses):.6f}") + print("") + print("Saving results...") + output_file_basename = os.path.basename(reference_image_file)[:6] + for i, fwd_transform in enumerate(result["forward_transforms"]): + time_point_index = source_image_files[i].index("_g") + 2 + time_point = source_image_files[i][time_point_index : time_point_index + 3] + + output_file = f"{output_file_basename}_{time_point}_fwd.hdf" + itk.transformwrite( + fwd_transform, + os.path.join(output_dir, output_file), + compression=True, + ) + + inv_transform = result["inverse_transforms"][i] + output_file = f"{output_file_basename}_{time_point}_inv.hdf" + itk.transformwrite( + inv_transform, + os.path.join(output_dir, output_file), + compression=True, + ) + + output_file = f"{output_file_basename}_{time_point}_hrr.mha" + itk.imwrite( + upsampled_images[i], + os.path.join(output_dir, output_file), + compression=True, + ) + + +# %% +# Register time series +######################################################## + +for ref_file, src_dir, src_files in zip(ref_files, src_data_dirs, src_data_files): + register_time_series(ref_file, src_dir, src_files, "ANTS") + register_time_series(ref_file, src_dir, src_files, "greedy") + register_time_series(ref_file, src_dir, src_files, "ICON") + register_time_series(ref_file, src_dir, src_files, "ANTS_ICON") + register_time_series(ref_file, src_dir, src_files, "greedy_ICON") + +# %% diff --git a/experiments/LongitudinalRegistration/setup.sh b/experiments/LongitudinalRegistration/setup.sh new file mode 100644 index 0000000..5cfcf80 --- /dev/null +++ b/experiments/LongitudinalRegistration/setup.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +# Clone uniGradICON (feat-add-finetuning branch required for fine-tuning support) +if [ ! -d "uniGradICON" ]; then + git clone -b feat-add-finetuning https://github.com/uncbiag/uniGradICON.git +else + echo "uniGradICON/ already exists, skipping clone." +fi + +# Create venv if needed +if [ ! -d "venv" ]; then + python -m venv venv +fi + +# Detect venv Python path (Windows vs Linux/Mac) +if [ -f "venv/Scripts/python" ]; then + PYTHON="venv/Scripts/python" +else + PYTHON="venv/bin/python" +fi + +# Install all dependencies (including editable physiomotion4d and uniGradICON) +"$PYTHON" -m pip install uv +"$PYTHON" -m uv pip install -e . diff --git a/experiments/Reconstruct4DCT/reconstruct_4d_ct.py b/experiments/Reconstruct4DCT/reconstruct_4d_ct.py index eb189df..011a6fc 100644 --- a/experiments/Reconstruct4DCT/reconstruct_4d_ct.py +++ b/experiments/Reconstruct4DCT/reconstruct_4d_ct.py @@ -5,7 +5,7 @@ import itk import numpy as np -from physiomotion4d import RegisterImagesANTs, TransformTools +from physiomotion4d import RegisterImagesANTS, TransformTools _HERE = os.path.dirname(os.path.abspath(__file__)) @@ -32,14 +32,14 @@ num_files = len(files) reference_image_num = num_files // 2 # reg_method_data = zip(["ICON"], [RegisterImagesICON()], [2]) - reg_method_data = zip(["ANTs"], [RegisterImagesANTs()], [[20, 10, 2]]) + reg_method_data = zip(["ANTs"], [RegisterImagesANTS()], [[20, 10, 2]]) else: num_files = len(files) files_indx = list(range(num_files)) reference_image_num = 7 - reg_method_data = zip(["ANTs"], [RegisterImagesANTs()], [[30, 15, 5]]) + reg_method_data = zip(["ANTs"], [RegisterImagesANTS()], [[30, 15, 5]]) # reg_method_data = zip(["ICON"], [RegisterImagesICON()], [20]) - # reg_method_data = zip(["ICON","ANTs"], [RegisterImagesICON(), RegisterImagesANTs()], [20, [40, 20, 10]]) + # reg_method_data = zip(["ICON","ANTs"], [RegisterImagesICON(), RegisterImagesANTS()], [20, [40, 20, 10]]) reference_image_file = os.path.join( data_dir, f"slice_{files_indx[reference_image_num]:03d}.mha" @@ -296,7 +296,7 @@ def register_slices( files = [] files_indx = [] for f in sorted(os.listdir(_RESULTS_DIR)): - if f.endswith(".hdf") and f.startswith("slice_ANTs_forward_"): + if f.endswith(".hdf") and f.startswith("slice_ANTS_forward_"): files.append(os.path.join(_RESULTS_DIR, f)) files_indx.append(int(f.split("_")[3].split(".")[0])) @@ -311,7 +311,7 @@ def register_slices( for i in range(num_files): print(files_indx[i]) inverse_transform = itk.transformread( - os.path.join(_RESULTS_DIR, f"slice_ANTs_inverse_{files_indx[i]:03d}.hdf") + os.path.join(_RESULTS_DIR, f"slice_ANTS_inverse_{files_indx[i]:03d}.hdf") )[0] inverse_image = tfm_tool.convert_transform_to_displacement_field( @@ -321,7 +321,7 @@ def register_slices( ) itk.imwrite( inverse_image, - os.path.join(_RESULTS_DIR, f"slice_ANTs_inverse_{files_indx[i]:03d}_hdf.mha"), + os.path.join(_RESULTS_DIR, f"slice_ANTS_inverse_{files_indx[i]:03d}_hdf.mha"), compression=True, ) @@ -333,7 +333,7 @@ def register_slices( itk.imwrite( inverse_grid_image, os.path.join( - _RESULTS_DIR, f"slice_fixed_ANTs_inverse_grid_{files_indx[i]:03d}.mha" + _RESULTS_DIR, f"slice_fixed_ANTS_inverse_grid_{files_indx[i]:03d}.mha" ), compression=True, ) diff --git a/experiments/Reconstruct4DCT/reconstruct_4d_ct_class.py b/experiments/Reconstruct4DCT/reconstruct_4d_ct_class.py index 43fd937..70bd246 100644 --- a/experiments/Reconstruct4DCT/reconstruct_4d_ct_class.py +++ b/experiments/Reconstruct4DCT/reconstruct_4d_ct_class.py @@ -54,7 +54,7 @@ reference_image_num = num_files // 2 # Registration parameters - only ANTs for quick run - registration_methods = ["ants", "icon", "ants_icon"] + registration_methods = ["ANTS", "ICON", "ANTS_ICON"] number_of_iterations_list = [[8, 4, 1], 5, [[8, 4, 1], 5]] # For ANTs and ICON else: print("=== FULL RUN MODE ===") @@ -63,12 +63,12 @@ reference_image_num = 7 # Registration parameters - both ANTs and ICON for full run - registration_methods = ["ants"] # , "icon", "ants_icon"] + registration_methods = ["ANTS"] # , "ICON", "ANTS_ICON"] number_of_iterations_list = [ [30, 15, 7, 3], ] # For ANTs # 20, # For ICON - # [[30, 15, 7, 3], 20], # For ants_icon + # [[30, 15, 7, 3], 20], # For ANTS_ICON # ] # Common parameters @@ -144,13 +144,13 @@ registrar.set_fixed_image(fixed_image) # Set iterations based on registration method - if registration_method == "ants": - registrar.set_number_of_iterations_ants(number_of_iterations) - elif registration_method == "icon": - registrar.set_number_of_iterations_icon(number_of_iterations) - elif registration_method == "ants_icon": - registrar.set_number_of_iterations_ants(number_of_iterations[0]) - registrar.set_number_of_iterations_icon(number_of_iterations[1]) + if registration_method == "ANTS": + registrar.set_number_of_iterations_ANTS(number_of_iterations) + elif registration_method == "ICON": + registrar.set_number_of_iterations_ICON(number_of_iterations) + elif registration_method == "ANTS_ICON": + registrar.set_number_of_iterations_ANTS(number_of_iterations[0]) + registrar.set_number_of_iterations_ICON(number_of_iterations[1]) # Perform registration result = registrar.register_time_series( diff --git a/pyproject.toml b/pyproject.toml index a7715c0..cefcec4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,8 @@ keywords = [ "ai", "deep-learning", "totalsegmentator", - "icon", - "ants", + "ICON", + "ANTS", "physiological-motion", "4d-visualization" ] @@ -255,6 +255,7 @@ module = [ "unigradicon.*", "vtk", "vtk.*", + "yaml", ] ignore_missing_imports = true diff --git a/src/physiomotion4d/__init__.py b/src/physiomotion4d/__init__.py index a2b0820..5d00e32 100644 --- a/src/physiomotion4d/__init__.py +++ b/src/physiomotion4d/__init__.py @@ -43,10 +43,11 @@ # Utility classes from .image_tools import ImageTools +from .landmark_tools import LandmarkTools # Base classes from .physiomotion4d_base import PhysioMotion4DBase -from .register_images_ants import RegisterImagesANTs +from .register_images_ants import RegisterImagesANTS from .register_images_greedy import RegisterImagesGreedy # Registration classes @@ -73,6 +74,7 @@ from .workflow_convert_vtk_to_usd import WorkflowConvertVTKToUSD from .workflow_reconstruct_highres_4d_ct import WorkflowReconstructHighres4DCT from .workflow_create_statistical_model import WorkflowCreateStatisticalModel +from .workflow_fine_tune_icon_registration import WorkflowFineTuneICONRegistration from .workflow_fit_statistical_model_to_patient import ( WorkflowFitStatisticalModelToPatient, ) @@ -83,6 +85,7 @@ "WorkflowConvertImageToUSD", "WorkflowConvertVTKToUSD", "WorkflowCreateStatisticalModel", + "WorkflowFineTuneICONRegistration", "WorkflowReconstructHighres4DCT", "WorkflowFitStatisticalModelToPatient", # Segmentation classes @@ -92,7 +95,7 @@ # Registration classes "RegisterImagesBase", "RegisterImagesICON", - "RegisterImagesANTs", + "RegisterImagesANTS", "RegisterImagesGreedy", "RegisterTimeSeriesImages", "RegisterModelsPCA", @@ -103,6 +106,7 @@ "PhysioMotion4DBase", # Utility classes "ImageTools", + "LandmarkTools", "TestTools", "TransformTools", "USDTools", diff --git a/src/physiomotion4d/cli/fit_statistical_model_to_patient.py b/src/physiomotion4d/cli/fit_statistical_model_to_patient.py index 46809c8..51d6b70 100644 --- a/src/physiomotion4d/cli/fit_statistical_model_to_patient.py +++ b/src/physiomotion4d/cli/fit_statistical_model_to_patient.py @@ -53,7 +53,7 @@ def main() -> int: --template-model heart_model.vtu \\ --patient-models lv.vtp rv.vtp \\ --patient-image patient_ct.nii.gz \\ - --use-icon-refinement \\ + --use-ICON-refinement \\ --output-dir ./results """, ) @@ -133,7 +133,7 @@ def main() -> int: help="Enable mask-to-image refinement (requires --template-labelmap and label IDs)", ) parser.add_argument( - "--use-icon-refinement", + "--use-ICON-refinement", action="store_true", default=False, help="Enable ICON registration refinement (default: disabled)", @@ -258,7 +258,7 @@ def main() -> int: print("\nStarting registration pipeline...") print("=" * 70) result = workflow.run_workflow( - use_icon_registration_refinement=args.use_icon_refinement, + use_ICON_registration_refinement=args.use_ICON_refinement, ) # Save results diff --git a/src/physiomotion4d/cli/reconstruct_highres_4d_ct.py b/src/physiomotion4d/cli/reconstruct_highres_4d_ct.py index 84103ee..295de5e 100644 --- a/src/physiomotion4d/cli/reconstruct_highres_4d_ct.py +++ b/src/physiomotion4d/cli/reconstruct_highres_4d_ct.py @@ -47,17 +47,17 @@ def main() -> int: %(prog)s \\ --time-series-images frame_*.mha \\ --fixed-image highres.mha \\ - --registration-method ants_icon \\ - --ants-iterations 30 15 7 3 \\ - --icon-iterations 20 \\ + --registration-method ANTS_ICON \\ + --ANTS-iterations 30 15 7 3 \\ + --ICON-iterations 20 \\ --output-dir ./results # Reconstruction with ICON only %(prog)s \\ --time-series-images frame_*.mha \\ --fixed-image highres.mha \\ - --registration-method icon \\ - --icon-iterations 50 \\ + --registration-method ICON \\ + --ICON-iterations 50 \\ --output-dir ./results """, ) @@ -81,9 +81,9 @@ def main() -> int: # Registration configuration parser.add_argument( "--registration-method", - choices=["ants", "icon", "ants_icon"], - default="ants_icon", - help="Registration method to use (default: ants_icon)", + choices=["ANTS", "ICON", "ANTS_ICON"], + default="ANTS_ICON", + help="Registration method to use (default: ANTS_ICON)", ) parser.add_argument( "--reference-frame", @@ -106,13 +106,13 @@ def main() -> int: # Registration iterations parser.add_argument( - "--ants-iterations", + "--ANTS-iterations", nargs="+", type=int, help="ANTs multi-resolution iterations (e.g., 30 15 7 3). Default: [30, 15, 7, 3]", ) parser.add_argument( - "--icon-iterations", + "--ICON-iterations", type=int, help="ICON fine-tuning iterations. Default: 20", ) @@ -292,14 +292,14 @@ def main() -> int: # Set number of iterations based on registration method and CLI arguments if args.ants_iterations: - workflow.set_number_of_iterations_ants(args.ants_iterations) + workflow.set_number_of_iterations_ANTS(args.ants_iterations) else: - workflow.set_number_of_iterations_ants([30, 15, 7, 3]) + workflow.set_number_of_iterations_ANTS([30, 15, 7, 3]) if args.icon_iterations: - workflow.set_number_of_iterations_icon(args.icon_iterations) + workflow.set_number_of_iterations_ICON(args.icon_iterations) else: - workflow.set_number_of_iterations_icon(20) + workflow.set_number_of_iterations_ICON(20) except (ValueError, RuntimeError, OSError) as e: print(f"Error initializing workflow: {e}") diff --git a/src/physiomotion4d/landmark_tools.py b/src/physiomotion4d/landmark_tools.py new file mode 100644 index 0000000..01c3f01 --- /dev/null +++ b/src/physiomotion4d/landmark_tools.py @@ -0,0 +1,205 @@ +""" +Tools for reading and writing anatomical landmarks. + +This module provides the :class:`LandmarkTools` class with utilities for +reading and writing point landmarks in 3D Slicer's Markups JSON +(``.mrk.json``) format and in a simple CSV format. Landmarks are kept in +memory in LPS world coordinates (ITK's native frame, matching the rest of +the platform). RAS files are converted to LPS on read; outputs are always +written in LPS. +""" + +import csv +import json +import logging +from pathlib import Path + +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase + +LandmarkDict = dict[str, tuple[float, float, float]] + +# Slicer Markups JSON schema URL emitted in the file envelope on write. +_MRK_JSON_SCHEMA = ( + "https://raw.githubusercontent.com/Slicer/Slicer/main/" + "Modules/Loadable/Markups/Resources/Schema/markups-schema-v1.0.3.json#" +) + + +class LandmarkTools(PhysioMotion4DBase): + """ + Read and write anatomical landmarks in LPS world coordinates. + + Landmarks are represented in memory as a dictionary keyed by label, + with each value a three-tuple ``(x, y, z)`` of LPS millimeter + coordinates:: + + { + 'apex': (x, y, z), + 'base': (x, y, z), + ... + } + + Positions are always in LPS. RAS input files are converted on read; + outputs are always written in LPS. + + Example: + >>> tools = LandmarkTools() + >>> landmarks = tools.read_landmarks_3dslicer('points.mrk.json') + >>> tools.write_landmarks_csv(landmarks, 'points.csv') + """ + + def __init__(self, log_level: int | str = logging.INFO): + """Initialize the LandmarkTools class. + + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + + def read_landmarks_3dslicer(self, path: str | Path) -> LandmarkDict: + """Read landmarks from a 3D Slicer Markups JSON (``.mrk.json``) file. + + Reads the first markup node from the file and returns its control + points as a ``{label: (x, y, z)}`` dictionary. Other Slicer fields + (``id``, ``description``, ``orientation``, ...) are discarded. + + Coordinates are returned in LPS. If the file declares RAS (or the + legacy numeric codes ``'0'`` for LPS and ``'1'`` for RAS), each + position is converted by negating its X and Y components. + + Args: + path: Path to the ``.mrk.json`` file. + + Returns: + Dict mapping landmark label to ``(x, y, z)`` tuple in LPS. + + Raises: + ValueError: If the file contains no markups, declares an + unrecognized coordinate system, or has a control point + without a 3D position. + """ + with open(path, encoding="utf-8") as f: + data = json.load(f) + + markups = data.get("markups", []) + if not markups: + raise ValueError(f"No markups found in {path}") + markup = markups[0] + + coord_sys = str(markup.get("coordinateSystem", "LPS")).upper() + if coord_sys in ("RAS", "1"): + flip = True + elif coord_sys in ("LPS", "0"): + flip = False + else: + raise ValueError(f"Unrecognized coordinateSystem {coord_sys!r} in {path}") + + landmarks: LandmarkDict = {} + for cp in markup.get("controlPoints", []): + pos = cp.get("position") + label = cp.get("label", "") + if pos is None or len(pos) < 3: + raise ValueError( + f"Control point {label!r} in {path} has no 3D position" + ) + x, y, z = float(pos[0]), float(pos[1]), float(pos[2]) + if flip: + x, y = -x, -y + landmarks[label] = (x, y, z) + + return landmarks + + def write_landmarks_3dslicer( + self, landmarks: LandmarkDict, path: str | Path + ) -> None: + """Write landmarks to a 3D Slicer Markups JSON file in LPS. + + Wraps the landmarks in the Slicer Markups schema envelope and + writes the result to disk. The output always declares + ``coordinateSystem == 'LPS'``; positions are written verbatim, so + the caller must ensure they are already in LPS. + + Args: + landmarks: Dict mapping label to ``(x, y, z)`` tuple, as + returned by :meth:`read_landmarks_3dslicer` or + :meth:`read_landmarks_csv`. + path: Output ``.mrk.json`` path. + """ + control_points = [ + { + "label": label, + "position": [float(pos[0]), float(pos[1]), float(pos[2])], + } + for label, pos in landmarks.items() + ] + markup = { + "type": "Fiducial", + "coordinateSystem": "LPS", + "coordinateUnits": "mm", + "controlPoints": control_points, + } + data = { + "@schema": _MRK_JSON_SCHEMA, + "markups": [markup], + } + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=4) + + def read_landmarks_csv(self, path: str | Path) -> LandmarkDict: + """Read landmarks from a CSV file with header ``Name,x,y,z`` (LPS). + + Coordinates are assumed to be in LPS. The returned dictionary + matches the in-memory format used by + :meth:`read_landmarks_3dslicer`, so the readers and writers are + interchangeable. + + Args: + path: Path to the CSV file. The first row must be the header + ``Name,x,y,z`` (case-insensitive, surrounding whitespace + tolerated); subsequent rows are ``label,x,y,z``. + + Returns: + Dict mapping landmark label to ``(x, y, z)`` tuple in LPS. + + Raises: + ValueError: If the file is empty, has the wrong header, or + contains a malformed row. + """ + landmarks: LandmarkDict = {} + with open(path, encoding="utf-8", newline="") as f: + reader = csv.reader(f) + header = next(reader, None) + if header is None: + raise ValueError(f"Empty CSV file: {path}") + normalized = [h.strip().lower() for h in header] + if normalized[:4] != ["name", "x", "y", "z"]: + raise ValueError( + f'Expected header "Name,x,y,z" in {path}, got {header!r}' + ) + for row in reader: + if not row or all(not c.strip() for c in row): + continue + if len(row) < 4: + raise ValueError(f"Malformed landmark row in {path}: {row!r}") + landmarks[row[0].strip()] = ( + float(row[1]), + float(row[2]), + float(row[3]), + ) + + return landmarks + + def write_landmarks_csv(self, landmarks: LandmarkDict, path: str | Path) -> None: + """Write landmarks to a CSV file with header ``Name,x,y,z`` (LPS). + + Positions are written verbatim and assumed to be in LPS. + + Args: + landmarks: Dict mapping label to ``(x, y, z)`` tuple. + path: Output CSV path. + """ + with open(path, "w", encoding="utf-8", newline="") as f: + writer = csv.writer(f) + writer.writerow(["Name", "x", "y", "z"]) + for label, pos in landmarks.items(): + writer.writerow([label, pos[0], pos[1], pos[2]]) diff --git a/src/physiomotion4d/register_images_ants.py b/src/physiomotion4d/register_images_ants.py index 1142917..182afaf 100644 --- a/src/physiomotion4d/register_images_ants.py +++ b/src/physiomotion4d/register_images_ants.py @@ -1,6 +1,6 @@ """ANTs-based image registration implementation. -This module provides the RegisterImagesANTs class, a concrete implementation of +This module provides the RegisterImagesANTS class, a concrete implementation of RegisterImagesBase that uses the Advanced Normalization Tools (ANTs) algorithm for image registration. It supports both affine and deformable (SyN) registration for aligning medical images, particularly useful for 4D cardiac CT registration. @@ -21,7 +21,7 @@ from physiomotion4d.transform_tools import TransformTools -class RegisterImagesANTs(RegisterImagesBase): +class RegisterImagesANTS(RegisterImagesBase): """ANTs-based deformable image registration implementation. This class extends RegisterImagesBase to provide deformable image registration @@ -58,7 +58,7 @@ class RegisterImagesANTs(RegisterImagesBase): metric (str): Similarity metric to use ('CC', 'Mattes', or 'MeanSquares', default: 'CC') Example: - >>> registrar = RegisterImagesANTs() + >>> registrar = RegisterImagesANTS() >>> registrar.set_modality('ct') >>> registrar.set_fixed_image(reference_image) >>> registrar.set_transform_type('Affine') @@ -286,15 +286,15 @@ def _antsfile_to_itk_displacement_field_transform( Returns: itk.DisplacementFieldTransform: ITK displacement field transform """ - disp_field_tfm_ants = ants.read_transform( + disp_field_tfm_ANTS = ants.read_transform( ants_transform_file, precision="double" ) - disp_field_ants = ants.transform_to_displacement_field( - disp_field_tfm_ants, + disp_field_ANTS = ants.transform_to_displacement_field( + disp_field_tfm_ANTS, self._itk_to_ants_image(ref_image, dtype="float"), ) - disp_field_itk_raw = self._ants_to_itk_image(disp_field_ants) + disp_field_itk_raw = self._ants_to_itk_image(disp_field_ANTS) # Convert to the correct Image[Vector[D, 3], 3] type for DisplacementFieldTransform # Use ImageTools helper to convert array to vector image with correct type @@ -313,7 +313,7 @@ def _antsfile_to_itk_displacement_field_transform( return disp_tfm - def itk_affine_transform_to_ants_transform( + def itk_affine_transform_to_ANTS_transform( self, itk_tfm: itk.Transform ) -> ants.ANTsTransform: """Convert ITK affine/rigid transform to ANTs affine transform. @@ -342,9 +342,9 @@ def itk_affine_transform_to_ants_transform( >>> affine_itk = itk.AffineTransform[itk.D, 3].New() >>> affine_itk.SetIdentity() >>> # Convert to ANTs - >>> affine_ants = registrar.itk_affine_transform_to_ants_transform(affine_itk) + >>> affine_ANTS = registrar.itk_affine_transform_to_ANTS_transform(affine_itk) >>> # Use in ANTs operations - >>> result = ants.apply_ants_transform(affine_ants, moving_image) + >>> result = ants.apply_ants_transform(affine_ANTS, moving_image) """ # Get dimension of the transform dimension = itk_tfm.GetInputSpaceDimension() @@ -406,7 +406,7 @@ def itk_affine_transform_to_ants_transform( return ants_tfm - def itk_transform_to_antsfile( + def itk_transform_to_ANTSfile( self, itk_tfm: itk.Transform, reference_image: itk.Image, @@ -440,13 +440,13 @@ def itk_transform_to_antsfile( >>> # Convert ITK affine transform to ANTs file >>> affine_itk = itk.AffineTransform[itk.D, 3].New() >>> affine_itk.SetIdentity() - >>> transform_files = registrar.itk_transform_to_antsfile( + >>> transform_files = registrar.itk_transform_to_ANTSfile( ... affine_itk, reference_image, 'initial_transform.mat' ... ) >>> >>> # Use in registration >>> result = ants.registration( - ... fixed=fixed_ants, moving=moving_ants, initial_transform=transform_files + ... fixed=fixed_ANTS, moving=moving_ANTS, initial_transform=transform_files ... ) """ if isinstance(itk_tfm, itk.DisplacementFieldTransform) or isinstance( @@ -468,7 +468,7 @@ def itk_transform_to_antsfile( self.log_info("Wrote ANTs displacement field to: %s", output_filename) return [output_filename] - ants_tfm = self.itk_affine_transform_to_ants_transform(itk_tfm) + ants_tfm = self.itk_affine_transform_to_ANTS_transform(itk_tfm) if ".mat" not in output_filename: output_filename = os.path.splitext(output_filename)[0] + ".mat" @@ -588,7 +588,7 @@ def registration_method( initial_transform: str | list[str] = "identity" if initial_forward_transform is not None: self.log_info("Converting initial ITK transform to ANTs format...") - initial_transform = self.itk_transform_to_antsfile( + initial_transform = self.itk_transform_to_ANTSfile( itk_tfm=initial_forward_transform, reference_image=self.fixed_image, output_filename="initial_transform_temp.mat", diff --git a/src/physiomotion4d/register_images_icon.py b/src/physiomotion4d/register_images_icon.py index 8f60a2a..28bd03d 100644 --- a/src/physiomotion4d/register_images_icon.py +++ b/src/physiomotion4d/register_images_icon.py @@ -1,6 +1,6 @@ """Icon-based image registration implementation. -This module provides the RegisterImagesIcon class, a concrete implementation of +This module provides the RegisterImagesICON class, a concrete implementation of RegisterImagesBase that uses the Icon (Inverse Consistent Image Registration) algorithm with deep learning models. It supports both masked and unmasked registration for aligning medical images, particularly useful for 4D cardiac CT registration. @@ -10,9 +10,7 @@ """ import logging -from collections.abc import Sequence -from pathlib import Path -from typing import Any, Optional, Union +from typing import Optional, Union import icon_registration as icon import icon_registration.itk_wrapper @@ -348,6 +346,42 @@ def _image_to_resized_tensor( tensor, size=shape[2:], mode="trilinear", align_corners=False ) + @staticmethod + def create_mask(labelmap: itk.Image, dilation_mm: float = 5.0) -> itk.Image: + """Create a binary registration mask from a labelmap. + + Thresholds the labelmap at ``>0`` (so every non-zero label becomes + foreground) and dilates the result by ``dilation_mm`` millimeters of + physical radius. The radius is converted into per-axis voxel counts + from the labelmap's spacing so the dilation is physically isotropic + even on anisotropic grids; each per-axis count is clamped to at least + 1 voxel when ``dilation_mm > 0``. + + Args: + labelmap: Multi-label or binary ``itk.Image``. Any non-zero voxel + is treated as foreground. + dilation_mm: Physical radius of the binary dilation in + millimeters. Pass ``0`` (or negative) to skip dilation and + return the raw ``>0`` mask. Default 5.0 mm. + + Returns: + ``itk.Image[itk.UC, 3]`` binary mask in the same physical space as + ``labelmap`` (origin, spacing, direction copied from the input). + """ + arr = (itk.array_from_image(labelmap) > 0).astype(np.uint8) + mask = itk.image_from_array(arr) + mask.CopyInformation(labelmap) + if dilation_mm <= 0: + return mask + spacing = labelmap.GetSpacing() + radius = itk.Size[3]() + for i in range(3): + radius[i] = max(1, int(round(dilation_mm / float(spacing[i])))) + structuring_element = itk.FlatStructuringElement[3].Ball(radius) + return itk.binary_dilate_image_filter( + mask, kernel=structuring_element, foreground_value=1 + ) + def _mask_to_resized_tensor( self, mask: itk.Image, shape: torch.Size ) -> torch.Tensor: @@ -382,111 +416,3 @@ def _mask_to_resized_tensor( arr = np.array(mask) tensor = torch.Tensor(arr).to(icon.config.device)[None, None] return F.interpolate(tensor, size=shape[2:], mode="nearest") - - def finetune( - self, - image_pairs: Sequence[tuple[itk.Image, itk.Image]], - output_model_filename: str, - mask_pairs: Optional[Sequence[tuple[itk.Image, itk.Image]]] = None, - epochs: int = 1, - learning_rate: float = DEFAULT_FINETUNE_LEARNING_RATE, - ) -> dict[str, Any]: - """Fine-tune the ICON network on a cohort of image pairs. - - Unlike ``register()``, this method *persistently* updates the in-memory - network weights and saves the resulting inner-network state_dict to - disk. The starting weights are whatever ``set_weights_path()`` was - configured to (default UniGradICON pretrained weights if never set). - - This method intentionally does not call ``register()`` / the ICON - ``register_pair`` helpers, because those wrap their optimization in a - ``state_dict`` save/restore that discards any persistent weight - changes. Here the optimizer steps act directly on ``self.net`` and - are not undone. - - Args: - image_pairs: Sequence of (fixed_image, moving_image) itk image - pairs to fine-tune on. Caller can build this from a list of - images with ``itertools.combinations(images, 2)``. - output_model_filename: Path where the fine-tuned inner-network - state_dict will be saved. Suitable for reloading via - ``set_weights_path()`` on a new RegisterImagesICON instance. - mask_pairs: Optional sequence of (fixed_mask, moving_mask) itk - image pairs, same length as ``image_pairs``. When provided, - ICON's masked-loss path is used for every pair. - epochs: Number of full passes over ``image_pairs``. Default 1. - learning_rate: Adam learning rate. Default matches - ``icon_registration``'s per-pair fine-tune rate. - - Returns: - Dict with keys: - - ``output_model_filename``: str path of saved checkpoint - - ``losses``: list[float], one entry per (epoch, pair) - - ``epochs``: int - - ``number_of_pairs``: int - - Raises: - ValueError: If ``image_pairs`` is empty or ``mask_pairs`` length - does not match ``image_pairs``. - """ - if len(image_pairs) == 0: - raise ValueError("At least one image pair is required for finetune.") - if mask_pairs is not None and len(mask_pairs) != len(image_pairs): - raise ValueError("mask_pairs must match image_pairs length.") - if epochs < 1: - raise ValueError("epochs must be >= 1.") - - self._ensure_net() - assert self.net is not None - self.net.to(icon.config.device) - self.net.train() - - optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate) - shape = self.net.identity_map.shape - losses: list[float] = [] - - for epoch_index in range(epochs): - for pair_index, (fixed_image, moving_image) in enumerate(image_pairs): - fixed_pre = self.preprocess(fixed_image, modality=self.modality) - moving_pre = self.preprocess(moving_image, modality=self.modality) - - fixed_resized = self._image_to_resized_tensor(fixed_pre, shape) - moving_resized = self._image_to_resized_tensor(moving_pre, shape) - - forward_kwargs: dict[str, torch.Tensor] = {} - if mask_pairs is not None: - fixed_mask, moving_mask = mask_pairs[pair_index] - forward_kwargs["mask_A"] = self._mask_to_resized_tensor( - fixed_mask, shape - ) - forward_kwargs["mask_B"] = self._mask_to_resized_tensor( - moving_mask, shape - ) - - optimizer.zero_grad() - loss_tuple = self.net(fixed_resized, moving_resized, **forward_kwargs) - loss = loss_tuple[0] - loss.backward() - optimizer.step() - - loss_value = float(loss.detach().cpu().item()) - losses.append(loss_value) - self.log_info( - "finetune epoch %d pair %d/%d loss=%.6f", - epoch_index, - pair_index + 1, - len(image_pairs), - loss_value, - ) - - output_path = Path(output_model_filename) - output_path.parent.mkdir(parents=True, exist_ok=True) - torch.save(self.net.regis_net.state_dict(), output_path) - self.weights_path = str(output_path) - - return { - "output_model_filename": str(output_path), - "losses": losses, - "epochs": epochs, - "number_of_pairs": len(image_pairs), - } diff --git a/src/physiomotion4d/register_models_distance_maps.py b/src/physiomotion4d/register_models_distance_maps.py index 7e0a936..4b8090e 100644 --- a/src/physiomotion4d/register_models_distance_maps.py +++ b/src/physiomotion4d/register_models_distance_maps.py @@ -37,7 +37,7 @@ ... reference_image=reference_image, ... roi_dilation_mm=20, ... ) - >>> result = registrar.register(mode='deformable', use_icon=True, icon_iterations=50) + >>> result = registrar.register(mode='deformable', use_ICON=True, icon_iterations=50) >>> >>> # Access results >>> aligned_model = result['registered_model'] @@ -53,7 +53,7 @@ from physiomotion4d.contour_tools import ContourTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -from physiomotion4d.register_images_ants import RegisterImagesANTs +from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.transform_tools import TransformTools @@ -84,8 +84,8 @@ class RegisterModelsDistanceMaps(PhysioMotion4DBase): roi_dilation_mm (float): Dilation amount in mm for ROI mask transform_tools (TransformTools): Transform utility instance contour_tools (ContourTools): Model utility instance - registrar_ants (RegisterImagesANTs): ANTs registration instance - registrar_icon (RegisterImagesICON): ICON registration instance + registrar_ANTS (RegisterImagesANTS): ANTs registration instance + registrar_ICON (RegisterImagesICON): ICON registration instance forward_transform (itk.CompositeTransform): Optimized moving→fixed transform inverse_transform (itk.CompositeTransform): Optimized fixed→moving transform registered_model (pv.PolyData): Aligned moving model @@ -107,7 +107,7 @@ class RegisterModelsDistanceMaps(PhysioMotion4DBase): >>> >>> # Or run deformable with ICON refinement >>> result = registrar.register( - ... mode='deformable', use_ants=False, use_icon=True, icon_iterations=50 + ... mode='deformable', use_ANTS=False, use_ICON=True, icon_iterations=50 ... ) >>> >>> # Get aligned model and transforms @@ -150,10 +150,10 @@ def __init__( self.contour_tools = ContourTools() # Registration instances - self.registrar_ants = RegisterImagesANTs(log_level=log_level) - self.registrar_icon = RegisterImagesICON(log_level=log_level) - self.registrar_icon.set_modality("ct") - self.registrar_icon.set_multi_modality(False) + self.registrar_ANTS = RegisterImagesANTS(log_level=log_level) + self.registrar_ICON = RegisterImagesICON(log_level=log_level) + self.registrar_ICON.set_modality("ct") + self.registrar_ICON.set_multi_modality(False) # Generated masks (will be created during registration) self.fixed_mask_image: Optional[itk.Image] = None @@ -225,7 +225,7 @@ def _create_masks_from_models(self) -> None: def register( self, transform_type: str = "Deformable", - use_icon: bool = False, + use_ICON: bool = False, icon_iterations: int = 50, ) -> dict: """Perform mask-based registration of moving model to fixed model. @@ -249,8 +249,8 @@ def register( Args: transform_type: Registration transform type - 'None', 'Rigid', 'Affine', or 'Deformable'. Default: 'Deformable' - use_icon: Whether to apply ICON registration refinement after ANTs. Default: False - icon_iterations: Number of ICON optimization iterations if use_icon=True. Default: 50 + use_ICON: Whether to apply ICON registration refinement after ANTs. Default: False + icon_iterations: Number of ICON optimization iterations if use_ICON=True. Default: 50 Returns: Dictionary containing: @@ -270,7 +270,7 @@ def register( >>> >>> # Deformable registration with ICON refinement >>> result = registrar.register( - ... transform_type='Deformable', use_icon=True, icon_iterations=100 + ... transform_type='Deformable', use_ICON=True, icon_iterations=100 ... ) """ if transform_type not in ["None", "Rigid", "Affine", "Deformable"]: @@ -288,64 +288,64 @@ def register( transform_type, ) - inverse_transform_ants = None - forward_transform_ants = None + inverse_transform_ANTS = None + forward_transform_ANTS = None if transform_type != "None": - self.registrar_ants.set_fixed_image(self.fixed_mask_image) - self.registrar_ants.set_fixed_mask(self.fixed_mask_roi_image) + self.registrar_ANTS.set_fixed_image(self.fixed_mask_image) + self.registrar_ANTS.set_fixed_mask(self.fixed_mask_roi_image) - self.registrar_ants.set_transform_type(transform_type) - self.registrar_ants.set_metric("MeanSquares") + self.registrar_ANTS.set_transform_type(transform_type) + self.registrar_ANTS.set_metric("MeanSquares") - result_ants = self.registrar_ants.register( + result_ANTS = self.registrar_ANTS.register( moving_image=self.moving_mask_image, moving_mask=self.moving_mask_roi_image, ) - inverse_transform_ants = result_ants["inverse_transform"] - forward_transform_ants = result_ants["forward_transform"] + inverse_transform_ANTS = result_ANTS["inverse_transform"] + forward_transform_ANTS = result_ANTS["forward_transform"] else: identity_transform = itk.AffineTransform[itk.D, 3].New() identity_transform.SetIdentity() - inverse_transform_ants = identity_transform - forward_transform_ants = identity_transform + inverse_transform_ANTS = identity_transform + forward_transform_ANTS = identity_transform # Initialize composite transforms - self.forward_transform = forward_transform_ants - self.inverse_transform = inverse_transform_ants + self.forward_transform = forward_transform_ANTS + self.inverse_transform = inverse_transform_ANTS # Optional ICON refinement - if use_icon: + if use_ICON: self.log_info( "Performing ICON refinement registration (%d iterations)...", icon_iterations, ) # Transform masks with ANTs result for ICON input - moving_mask_ants_transformed = self.transform_tools.transform_image( + moving_mask_ANTS_transformed = self.transform_tools.transform_image( self.moving_mask_image, - forward_transform_ants, + forward_transform_ANTS, self.reference_image, interpolation_method="linear", ) # Configure ICON - self.registrar_icon.set_number_of_iterations(icon_iterations) - self.registrar_icon.set_fixed_image(self.fixed_mask_image) - self.registrar_icon.set_fixed_mask(self.fixed_mask_roi_image) + self.registrar_ICON.set_number_of_iterations(icon_iterations) + self.registrar_ICON.set_fixed_image(self.fixed_mask_image) + self.registrar_ICON.set_fixed_mask(self.fixed_mask_roi_image) # ICON registration - result_icon = self.registrar_icon.register( - moving_image=moving_mask_ants_transformed, + result_ICON = self.registrar_ICON.register( + moving_image=moving_mask_ANTS_transformed, moving_mask=self.moving_mask_roi_image, ) - inverse_transform_icon = result_icon["inverse_transform"] - forward_transform_icon = result_icon["forward_transform"] + inverse_transform_ICON = result_ICON["inverse_transform"] + forward_transform_ICON = result_ICON["forward_transform"] # Compose ANTs and ICON transforms composed_forward = ( self.transform_tools.combine_displacement_field_transforms( - forward_transform_ants, - forward_transform_icon, + forward_transform_ANTS, + forward_transform_ICON, reference_image=self.reference_image, mode="compose", ) @@ -353,8 +353,8 @@ def register( composed_inverse = ( self.transform_tools.combine_displacement_field_transforms( - inverse_transform_icon, - inverse_transform_ants, + inverse_transform_ICON, + inverse_transform_ANTS, reference_image=self.reference_image, mode="compose", ) diff --git a/src/physiomotion4d/register_time_series_images.py b/src/physiomotion4d/register_time_series_images.py index 349e4d6..5e14f5a 100644 --- a/src/physiomotion4d/register_time_series_images.py +++ b/src/physiomotion4d/register_time_series_images.py @@ -13,18 +13,18 @@ import itk -from physiomotion4d.register_images_ants import RegisterImagesANTs +from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_base import RegisterImagesBase from physiomotion4d.register_images_greedy import RegisterImagesGreedy from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.transform_tools import TransformTools REGISTRATION_METHODS: list[str] = [ - "ants", + "ANTS", "greedy", - "icon", - "ants_icon", - "greedy_icon", + "ICON", + "ANTS_ICON", + "greedy_ICON", ] @@ -52,20 +52,20 @@ class RegisterTimeSeriesImages(RegisterImagesBase): - Returns all transforms and loss values for the entire series Attributes: - registration_method_name (str): Registration method in use ('ants', - 'greedy', 'icon', 'ants_icon', or 'greedy_icon'). - registrar_ants (RegisterImagesANTs): Internal ANTs registrar. + registration_method_name (str): Registration method in use ('ANTS', + 'greedy', 'ICON', 'ANTS_ICON', or 'greedy_ICON'). + registrar_ANTS (RegisterImagesANTS): Internal ANTs registrar. registrar_greedy (RegisterImagesGreedy): Internal Greedy registrar. - registrar_icon (RegisterImagesICON): Internal ICON registrar (also used - as the refinement stage for 'ants_icon' and 'greedy_icon'). + registrar_ICON (RegisterImagesICON): Internal ICON registrar (also used + as the refinement stage for 'ANTS_ICON' and 'greedy_ICON'). transform_tools (TransformTools): Utility for transform operations. Example: >>> # Register a cardiac CT time series - >>> registrar = RegisterTimeSeriesImages(registration_method='ants') + >>> registrar = RegisterTimeSeriesImages(registration_method='ANTS') >>> registrar.set_modality('ct') >>> registrar.set_fixed_image(fixed_image) - >>> registrar.set_number_of_iterations_ants([40, 20, 10]) + >>> registrar.set_number_of_iterations_ANTS([40, 20, 10]) >>> >>> # Register all time points to fixed image >>> result = registrar.register_time_series( @@ -88,14 +88,14 @@ class RegisterTimeSeriesImages(RegisterImagesBase): """ def __init__( - self, registration_method: str = "ants", log_level: int | str = logging.INFO + self, registration_method: str = "ANTS", log_level: int | str = logging.INFO ) -> None: """Initialize the time series image registration class. Args: registration_method (str): Registration method to use. - Options: 'ants', 'greedy', 'icon', 'ants_icon', or - 'greedy_icon'. Default: 'ants' + Options: 'ANTS', 'greedy', 'ICON', 'ANTS_ICON', or + 'greedy_ICON'. Default: 'ANTS' log_level: Logging level (default: logging.INFO) Raises: @@ -103,16 +103,16 @@ def __init__( """ super().__init__(log_level=log_level) - self.registration_method_name: str = registration_method.lower() + self.registration_method_name: str = registration_method - self.registrar_ants = RegisterImagesANTs(log_level=log_level) + self.registrar_ANTS = RegisterImagesANTS(log_level=log_level) self.registrar_greedy = RegisterImagesGreedy(log_level=log_level) - self.registrar_icon = RegisterImagesICON(log_level=log_level) + self.registrar_ICON = RegisterImagesICON(log_level=log_level) # Set default iterations based on registration method - self.number_of_iterations_ants: list[int] = [40, 20, 10] + self.number_of_iterations_ANTS: list[int] = [40, 20, 10] self.number_of_iterations_greedy: list[int] = [40, 20, 10] - self.number_of_iterations_icon: int = 50 + self.number_of_iterations_ICON: int = 50 if self.registration_method_name not in REGISTRATION_METHODS: raise ValueError( @@ -124,24 +124,24 @@ def __init__( self.smooth_prior_transform_sigma: float = 0.5 - def set_number_of_iterations_ants( - self, number_of_iterations_ants: list[int] + def set_number_of_iterations_ANTS( + self, number_of_iterations_ANTS: list[int] ) -> None: """Set the number of iterations for ANTs registration. Args: - number_of_iterations_ants: List of iterations for ANTs multi-resolution + number_of_iterations_ANTS: List of iterations for ANTs multi-resolution (e.g., [40, 20, 10] for three resolution levels) """ - self.number_of_iterations_ants = number_of_iterations_ants + self.number_of_iterations_ANTS = number_of_iterations_ANTS - def set_number_of_iterations_icon(self, number_of_iterations_icon: int) -> None: + def set_number_of_iterations_ICON(self, number_of_iterations_ICON: int) -> None: """Set the number of iterations for ICON registration. Args: - number_of_iterations_icon: Number of fine-tuning steps for ICON + number_of_iterations_ICON: Number of fine-tuning steps for ICON """ - self.number_of_iterations_icon = number_of_iterations_icon + self.number_of_iterations_ICON = number_of_iterations_ICON def set_number_of_iterations_greedy( self, number_of_iterations_greedy: list[int] @@ -268,10 +268,10 @@ def register_time_series( calling this method. Example: - >>> registrar = RegisterTimeSeriesImages(registration_method='ants') + >>> registrar = RegisterTimeSeriesImages(registration_method='ANTS') >>> registrar.set_fixed_image(fixed_image) >>> registrar.set_fixed_mask(fixed_mask) # Optional - >>> registrar.set_number_of_iterations_ants([30, 15, 5]) + >>> registrar.set_number_of_iterations_ANTS([30, 15, 5]) >>> >>> # Use new intuitive parameter names >>> result = registrar.register_time_series( @@ -294,13 +294,13 @@ def register_time_series( if self.fixed_image is None: raise ValueError("Fixed image must be set before registering time series") - if self.registration_method_name == "ants": - self.registrar_ants.set_fixed_image(self.fixed_image) - self.registrar_ants.set_modality(self.modality) - self.registrar_ants.set_mask_dilation(self.mask_dilation_mm) - self.registrar_ants.set_number_of_iterations(self.number_of_iterations_ants) - self.registrar_ants.set_fixed_mask(self.fixed_mask) - self.registrar_ants.set_fixed_labelmap(self.fixed_labelmap) + if self.registration_method_name == "ANTS": + self.registrar_ANTS.set_fixed_image(self.fixed_image) + self.registrar_ANTS.set_modality(self.modality) + self.registrar_ANTS.set_mask_dilation(self.mask_dilation_mm) + self.registrar_ANTS.set_number_of_iterations(self.number_of_iterations_ANTS) + self.registrar_ANTS.set_fixed_mask(self.fixed_mask) + self.registrar_ANTS.set_fixed_labelmap(self.fixed_labelmap) elif self.registration_method_name == "greedy": self.registrar_greedy.set_fixed_image(self.fixed_image) self.registrar_greedy.set_modality(self.modality) @@ -310,23 +310,23 @@ def register_time_series( ) self.registrar_greedy.set_fixed_mask(self.fixed_mask) self.registrar_greedy.set_fixed_labelmap(self.fixed_labelmap) - elif self.registration_method_name == "icon": - self.registrar_icon.set_fixed_image(self.fixed_image) - self.registrar_icon.set_modality(self.modality) - self.registrar_icon.set_mask_dilation(self.mask_dilation_mm) - self.registrar_icon.set_number_of_iterations(self.number_of_iterations_icon) - self.registrar_icon.set_fixed_mask(self.fixed_mask) - self.registrar_icon.set_fixed_labelmap(self.fixed_labelmap) - elif self.registration_method_name in ["ants_icon", "greedy_icon"]: - if self.registration_method_name == "ants_icon": - self.registrar_ants.set_fixed_image(self.fixed_image) - self.registrar_ants.set_modality(self.modality) - self.registrar_ants.set_mask_dilation(self.mask_dilation_mm) - self.registrar_ants.set_number_of_iterations( - self.number_of_iterations_ants + elif self.registration_method_name == "ICON": + self.registrar_ICON.set_fixed_image(self.fixed_image) + self.registrar_ICON.set_modality(self.modality) + self.registrar_ICON.set_mask_dilation(self.mask_dilation_mm) + self.registrar_ICON.set_number_of_iterations(self.number_of_iterations_ICON) + self.registrar_ICON.set_fixed_mask(self.fixed_mask) + self.registrar_ICON.set_fixed_labelmap(self.fixed_labelmap) + elif self.registration_method_name in ["ANTS_ICON", "greedy_ICON"]: + if self.registration_method_name == "ANTS_ICON": + self.registrar_ANTS.set_fixed_image(self.fixed_image) + self.registrar_ANTS.set_modality(self.modality) + self.registrar_ANTS.set_mask_dilation(self.mask_dilation_mm) + self.registrar_ANTS.set_number_of_iterations( + self.number_of_iterations_ANTS ) - self.registrar_ants.set_fixed_mask(self.fixed_mask) - self.registrar_ants.set_fixed_labelmap(self.fixed_labelmap) + self.registrar_ANTS.set_fixed_mask(self.fixed_mask) + self.registrar_ANTS.set_fixed_labelmap(self.fixed_labelmap) else: self.registrar_greedy.set_fixed_image(self.fixed_image) self.registrar_greedy.set_modality(self.modality) @@ -336,12 +336,12 @@ def register_time_series( ) self.registrar_greedy.set_fixed_mask(self.fixed_mask) self.registrar_greedy.set_fixed_labelmap(self.fixed_labelmap) - self.registrar_icon.set_fixed_image(self.fixed_image) - self.registrar_icon.set_modality(self.modality) - self.registrar_icon.set_mask_dilation(self.mask_dilation_mm) - self.registrar_icon.set_number_of_iterations(self.number_of_iterations_icon) - self.registrar_icon.set_fixed_mask(self.fixed_mask) - self.registrar_icon.set_fixed_labelmap(self.fixed_labelmap) + self.registrar_ICON.set_fixed_image(self.fixed_image) + self.registrar_ICON.set_modality(self.modality) + self.registrar_ICON.set_mask_dilation(self.mask_dilation_mm) + self.registrar_ICON.set_number_of_iterations(self.number_of_iterations_ICON) + self.registrar_ICON.set_fixed_mask(self.fixed_mask) + self.registrar_ICON.set_fixed_labelmap(self.fixed_labelmap) num_images = len(moving_images) @@ -388,8 +388,8 @@ def register_time_series( if moving_labelmaps is not None else None ) - if self.registration_method_name == "ants": - result = self.registrar_ants.register( + if self.registration_method_name == "ANTS": + result = self.registrar_ANTS.register( moving_images[reference_frame], moving_mask=reference_mask, moving_labelmap=reference_labelmap, @@ -400,16 +400,16 @@ def register_time_series( moving_mask=reference_mask, moving_labelmap=reference_labelmap, ) - elif self.registration_method_name == "icon": - result = self.registrar_icon.register( + elif self.registration_method_name == "ICON": + result = self.registrar_ICON.register( moving_images[reference_frame], moving_mask=reference_mask, moving_labelmap=reference_labelmap, ) - elif self.registration_method_name in ["ants_icon", "greedy_icon"]: + elif self.registration_method_name in ["ANTS_ICON", "greedy_ICON"]: registrar_initial = ( - self.registrar_ants - if self.registration_method_name == "ants_icon" + self.registrar_ANTS + if self.registration_method_name == "ANTS_ICON" else self.registrar_greedy ) result = registrar_initial.register( @@ -418,7 +418,7 @@ def register_time_series( moving_labelmap=reference_labelmap, ) forward_initial = result["forward_transform"] - result = self.registrar_icon.register( + result = self.registrar_ICON.register( moving_images[reference_frame], moving_mask=reference_mask, moving_labelmap=reference_labelmap, @@ -474,8 +474,8 @@ def register_time_series( ) # Try registration with identity initialization - if self.registration_method_name == "ants": - result_init_identity = self.registrar_ants.register( + if self.registration_method_name == "ANTS": + result_init_identity = self.registrar_ANTS.register( moving_image=moving_image, moving_mask=moving_mask, moving_labelmap=moving_labelmap, @@ -486,16 +486,16 @@ def register_time_series( moving_mask=moving_mask, moving_labelmap=moving_labelmap, ) - elif self.registration_method_name == "icon": - result_init_identity = self.registrar_icon.register( + elif self.registration_method_name == "ICON": + result_init_identity = self.registrar_ICON.register( moving_image=moving_image, moving_mask=moving_mask, moving_labelmap=moving_labelmap, ) - elif self.registration_method_name in ["ants_icon", "greedy_icon"]: + elif self.registration_method_name in ["ANTS_ICON", "greedy_ICON"]: registrar_initial = ( - self.registrar_ants - if self.registration_method_name == "ants_icon" + self.registrar_ANTS + if self.registration_method_name == "ANTS_ICON" else self.registrar_greedy ) result_init_identity = registrar_initial.register( @@ -504,7 +504,7 @@ def register_time_series( moving_labelmap=moving_labelmap, ) forward_initial = result_init_identity["forward_transform"] - result_init_identity = self.registrar_icon.register( + result_init_identity = self.registrar_ICON.register( moving_image=moving_image, moving_mask=moving_mask, moving_labelmap=moving_labelmap, @@ -521,8 +521,8 @@ def register_time_series( # Select best result based on prior usage if prior_weight > 0.0: # Try with prior transform initialization - if self.registration_method_name == "ants": - result_init_prior = self.registrar_ants.register( + if self.registration_method_name == "ANTS": + result_init_prior = self.registrar_ANTS.register( moving_image=moving_image, moving_mask=moving_mask, moving_labelmap=moving_labelmap, @@ -535,17 +535,17 @@ def register_time_series( moving_labelmap=moving_labelmap, initial_forward_transform=prior_forward, ) - elif self.registration_method_name == "icon": - result_init_prior = self.registrar_icon.register( + elif self.registration_method_name == "ICON": + result_init_prior = self.registrar_ICON.register( moving_image=moving_image, moving_mask=moving_mask, moving_labelmap=moving_labelmap, initial_forward_transform=prior_forward, ) - elif self.registration_method_name in ["ants_icon", "greedy_icon"]: + elif self.registration_method_name in ["ANTS_ICON", "greedy_ICON"]: registrar_initial = ( - self.registrar_ants - if self.registration_method_name == "ants_icon" + self.registrar_ANTS + if self.registration_method_name == "ANTS_ICON" else self.registrar_greedy ) result_init_prior = registrar_initial.register( @@ -555,7 +555,7 @@ def register_time_series( initial_forward_transform=prior_forward, ) forward_initial = result_init_prior["forward_transform"] - result_init_prior = self.registrar_icon.register( + result_init_prior = self.registrar_ICON.register( moving_image=moving_image, moving_mask=moving_mask, moving_labelmap=moving_labelmap, @@ -644,7 +644,7 @@ def reconstruct_time_series( ValueError: If lengths of moving_images and inverse_transforms don't match Example: - >>> registrar = RegisterTimeSeriesImages(registration_method='ants') + >>> registrar = RegisterTimeSeriesImages(registration_method='ANTS') >>> registrar.set_fixed_image(fixed_image) >>> >>> result = registrar.register_time_series( @@ -766,8 +766,8 @@ def registration_method( Returns: dict: Registration result with forward_transform, inverse_transform, and loss """ - if self.registration_method_name == "ants": - res = self.registrar_ants.registration_method( + if self.registration_method_name == "ANTS": + res = self.registrar_ANTS.registration_method( moving_image=moving_image, moving_mask=moving_mask, moving_labelmap=moving_labelmap, @@ -792,8 +792,8 @@ def registration_method( "inverse_transform": cast(itk.Transform, res["inverse_transform"]), "loss": float(cast(float, res["loss"])), } - if self.registration_method_name == "icon": - res = self.registrar_icon.registration_method( + if self.registration_method_name == "ICON": + res = self.registrar_ICON.registration_method( moving_image=moving_image, moving_mask=moving_mask, moving_labelmap=moving_labelmap, @@ -805,10 +805,10 @@ def registration_method( "inverse_transform": cast(itk.Transform, res["inverse_transform"]), "loss": float(cast(float, res["loss"])), } - if self.registration_method_name in ["ants_icon", "greedy_icon"]: + if self.registration_method_name in ["ANTS_ICON", "greedy_ICON"]: registrar_initial = ( - self.registrar_ants - if self.registration_method_name == "ants_icon" + self.registrar_ANTS + if self.registration_method_name == "ANTS_ICON" else self.registrar_greedy ) initial_res = registrar_initial.registration_method( @@ -819,7 +819,7 @@ def registration_method( initial_forward_transform=initial_forward_transform, ) forward_initial = initial_res["forward_transform"] - icon_res = self.registrar_icon.registration_method( + icon_res = self.registrar_ICON.registration_method( moving_image=moving_image, moving_mask=moving_mask, moving_labelmap=moving_labelmap, diff --git a/src/physiomotion4d/segment_heart_simpleware.py b/src/physiomotion4d/segment_heart_simpleware.py index d1af243..15a1031 100644 --- a/src/physiomotion4d/segment_heart_simpleware.py +++ b/src/physiomotion4d/segment_heart_simpleware.py @@ -275,17 +275,18 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: sz = sz[::-1] labelmap_array = np.zeros(sz, dtype=np.uint8) interior_array = np.zeros(sz, dtype=np.uint8) + mask_image = None for mask_id, mask_name in self.taxonomy.all_labels().items(): output_file = os.path.join(tmp_dir, f"mask_{mask_name}.mhd") if os.path.exists(output_file): mask_image = itk.imread(output_file) mask_array = itk.GetArrayFromImage(mask_image).astype(np.uint8) + tmp_array = (mask_array > 128).astype(np.uint8) if mask_id in mask_ids_of_interior_regions: - tmp_array = (mask_array > 128).astype(np.uint8) interior_array = np.where( interior_array == 0, tmp_array, interior_array ) - mask_array = (mask_array > 128) * mask_id + mask_array = tmp_array * mask_id labelmap_array = np.where( labelmap_array == 0, mask_array, labelmap_array ) @@ -314,6 +315,7 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: landmarks_file, ) + # Dilate the interior regions to simulate 3mm myocardium (heart) interior_image = itk.GetImageFromArray(interior_array.astype(np.uint8)) interior_image.CopyInformation(preprocessed_image) imMath = tube.ImageMath.New(interior_image) @@ -335,6 +337,50 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: "ensure the ASCardio module ran successfully." ) + if mask_image is not None: + in_direction = np.array(preprocessed_image.GetDirection()) + out_direction = np.array(mask_image.GetDirection()) + flip = [False, False, False] + for i in range(3): + if np.sign(out_direction[i, i]) != np.sign(in_direction[i, i]): + self.log_info(f"Flipping labelmap array along {i}-axis") + labelmap_array = np.flip(labelmap_array, axis=(2 - i)) + flip[i] = True + origin = np.array(mask_image.GetOrigin()) + edge = np.array( + mask_image.TransformIndexToPhysicalPoint( + mask_image.GetLargestPossibleRegion().GetSize() + ) + ) + self.log_debug(f"Origin {origin} Edge {edge}") + point = np.zeros(3) + for landmark_name, landmark_position in self.landmarks.items(): + for i in range(3): + point[i] = landmark_position[i] + + self.log_debug(f"{landmark_name} {point}") + for i in range(3): + if in_direction[i, i] < 0: + self.log_debug( + f" Flipping {i} from {point[i]} " + f"with edge {edge[i]} and origin {origin[i]}" + ) + if i < 2: + point[i] = -origin[i] + (-origin[i] - point[i]) + else: + point[i] = edge[i] - (point[i] - origin[i]) + elif i < 2: + point[i] = -point[i] + self.log_debug(f" New point {point}") + # convert ras to lps as used by this project + point[0] = -point[0] + point[1] = -point[1] + self.landmarks[landmark_name] = ( + float(point[0]), + float(point[1]), + float(point[2]), + ) + labelmap_image = itk.GetImageFromArray(labelmap_array.astype(np.uint8)) labelmap_image.CopyInformation(preprocessed_image) diff --git a/src/physiomotion4d/workflow_convert_image_to_usd.py b/src/physiomotion4d/workflow_convert_image_to_usd.py index d4f828d..640cf68 100644 --- a/src/physiomotion4d/workflow_convert_image_to_usd.py +++ b/src/physiomotion4d/workflow_convert_image_to_usd.py @@ -19,7 +19,7 @@ from physiomotion4d.contour_tools import ContourTools from physiomotion4d.convert_image_4d_to_3d import ConvertImage4DTo3D from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -from physiomotion4d.register_images_ants import RegisterImagesANTs +from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_base import RegisterImagesBase from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.segment_anatomy_base import SegmentAnatomyBase @@ -161,7 +161,7 @@ def __init__( self.registrar: RegisterImagesBase if self.registration_method == "ANTS": self.log_info("Initializing ANTs registration...") - ants_registrar = RegisterImagesANTs(log_level=log_level) + ants_registrar = RegisterImagesANTS(log_level=log_level) ants_registrar.set_modality("ct") ants_registrar.set_transform_type("Deformable") if ( diff --git a/src/physiomotion4d/workflow_create_statistical_model.py b/src/physiomotion4d/workflow_create_statistical_model.py index 1958f96..f5a8354 100644 --- a/src/physiomotion4d/workflow_create_statistical_model.py +++ b/src/physiomotion4d/workflow_create_statistical_model.py @@ -191,7 +191,7 @@ def _step3_deformable_correspondence(self) -> None: ) result = registrar.register( transform_type="Deformable", - use_icon=False, + use_ICON=False, ) new_aligned_models.append(result["registered_model"]) diff --git a/src/physiomotion4d/workflow_fine_tune_icon_registration.py b/src/physiomotion4d/workflow_fine_tune_icon_registration.py new file mode 100644 index 0000000..223a5c0 --- /dev/null +++ b/src/physiomotion4d/workflow_fine_tune_icon_registration.py @@ -0,0 +1,865 @@ +"""Fine-tune uniGradICON registration and apply the fine-tuned weights. + +This module provides :class:`WorkflowFineTuneICONRegistration`, which packages +the two halves of the longitudinal-registration ICON fine-tuning experiment +from ``experiments/LongitudinalRegistration``: + +1. **Fine-tuning**: build a paired dataset JSON and YAML config from per-subject + lists of image files (with optional segmentation labelmaps and landmark CSVs) + and launch ``unigradicon.finetuning.finetune`` as a subprocess. Mirrors + ``experiments/LongitudinalRegistration/1-finetune_icon.py``. +2. **Apply**: load a fine-tuned uniGradICON checkpoint and register a list of + moving images to a single reference image using + :class:`RegisterTimeSeriesImages` (ICON backend). Mirrors the per-subject + registration loop in + ``experiments/LongitudinalRegistration/recon_4d_icon_eval.py``. + +Conventions: + - Fine-tuning is file-based: it reads images/labelmaps/landmarks from disk + because ``unigradicon.finetuning.finetune`` is launched as a subprocess + that consumes JSON paths. + - Apply is in-memory: takes ``itk.Image`` inputs in LPS space and + ``dict[name, (x, y, z)]`` landmark dictionaries. Segmentations are + resampled with nearest-neighbor interpolation; images use linear + interpolation. + - The ``inverse_transform`` returned by ICON is a resampler-convention + transform that maps moving-grid points back to reference-grid points; + ``forward_transform`` is the inverse direction (reference grid → + moving grid). Landmarks are warped using ``TransformPoint`` and + images/segmentations are resampled via + :meth:`TransformTools.transform_image`. +""" + +import json +import logging +import os +import subprocess +import sys +from pathlib import Path +from typing import Any, Optional, Union + +import itk +import numpy as np +import yaml + +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase +from physiomotion4d.register_images_icon import RegisterImagesICON +from physiomotion4d.register_time_series_images import RegisterTimeSeriesImages +from physiomotion4d.transform_tools import TransformTools + +Landmarks = dict[str, tuple[float, float, float]] + + +class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): + """Fine-tune uniGradICON on paired 3D images and apply the fine-tuned weights. + + The workflow has two stages that can be used together or independently: + + **Stage 1: Fine-tuning** (file-based) + Build a paired dataset JSON and YAML config from per-subject lists of + image, segmentation, and landmark files, then launch + ``unigradicon.finetuning.finetune`` as a subprocess. Each subject's + time-point images form one paired group (they share a ``subject_id``). + + **Stage 2: Apply** (in-memory) + Register a list of moving images to a single reference image using the + fine-tuned ICON weights and return both directions of the warp: + + - moving images / segmentations / landmarks warped into reference space + - the reference image / segmentation / landmarks warped into each + moving-image space + + Attributes: + subject_image_files (list[list[str]]): Per-subject lists of image + paths. Images within one inner list share a subject_id during + fine-tuning. + output_dir (Path): Directory where dataset JSON, YAML config, derived + masks, and the uniGradICON ``checkpoints/`` tree are written. + fine_tune_name (str): Sub-directory name for the experiment outputs. + subject_ids (Optional[list[str]]): One ID per subject (e.g. patient + identifiers). Written into the dataset JSON's ``subject_id`` + field; falls back to synthetic ``subject_NNNN`` when ``None``. + subject_segmentation_files (Optional[list[list[Optional[str]]]]): + Per-subject multi-label segmentation/labelmap paths aligned with + ``subject_image_files``. ``None`` (or per-image ``None``) means no + segmentation for that image. If supplied for at least one image, + paired-with-seg training is enabled. + subject_mask_files (Optional[list[list[Optional[str]]]]): + Per-subject binary mask paths aligned with ``subject_image_files``. + When supplied for a frame these masks are used directly for + loss-function masking; otherwise masks are derived from + ``subject_segmentation_files``. + subject_landmark_files (Optional[list[list[Optional[str]]]]): + Per-subject landmark CSV paths (``Name,X,Y,Z`` format) aligned with + ``subject_image_files``. Recorded in the dataset JSON for + traceability; not consumed by uniGradICON fine-tuning itself. + mask_dilation_mm (float): Millimeters of physical-radius binary + dilation applied to the >0 labelmap when deriving the loss-masking + binary mask via :meth:`RegisterImagesICON.create_mask`. + mask_dir (Optional[Path]): Directory where derived binary masks are + written and looked up. ``None`` (default) writes each derived + mask next to its source labelmap on disk. + registrar (RegisterTimeSeriesImages): ICON-backend registrar used in + :meth:`apply_registration`. + transform_tools (TransformTools): Utility for resampling images and + segmentations. + + Example: + >>> # Stage 1: fine-tune + >>> workflow = WorkflowFineTuneICONRegistration( + ... subject_image_files=[ + ... ['pm0001/g000.nii.gz', 'pm0001/g050.nii.gz'], + ... ['pm0002/g000.nii.gz', 'pm0002/g050.nii.gz'], + ... ], + ... output_dir=Path('d:/PhysioMotion4D/icon_finetuned'), + ... fine_tune_name='duke_4d_gated_icon_ft', + ... subject_segmentation_files=[ + ... ['pm0001/g000_labelmap.nii.gz', 'pm0001/g050_labelmap.nii.gz'], + ... ['pm0002/g000_labelmap.nii.gz', 'pm0002/g050_labelmap.nii.gz'], + ... ], + ... ) + >>> weights_path = workflow.run_fine_tuning() + >>> + >>> # Stage 2: apply + >>> result = workflow.apply_registration( + ... reference_image=ref_image, + ... moving_images=moving_images, + ... weights_path=weights_path, + ... reference_segmentation=ref_seg, + ... moving_segmentations=moving_segs, + ... ) + >>> warped_to_ref = result['moving_to_reference_images'] + >>> warped_to_moving = result['reference_to_moving_images'] + """ + + def __init__( + self, + subject_image_files: list[list[str]], + output_dir: Path, + fine_tune_name: str, + subject_ids: Optional[list[str]] = None, + subject_segmentation_files: Optional[list[list[Optional[str]]]] = None, + subject_mask_files: Optional[list[list[Optional[str]]]] = None, + subject_landmark_files: Optional[list[list[Optional[str]]]] = None, + epochs: int = 2000, + batch_size: int = 4, + learning_rate: float = 5e-5, + input_shape: tuple[int, int, int] = (175, 175, 175), + similarity: str = "lncc", + lambda_value: float = 1.5, + dice_loss_weight: float = 0.5, + lncc_sigma: int = 5, + ct_window: tuple[float, float] = (-1000.0, 1000.0), + is_ct: bool = True, + gpus: Optional[list[int]] = None, + eval_period: int = 10, + save_period: int = 50, + mask_dilation_mm: float = 5.0, + mask_dir: Optional[Path] = None, + unigradicon_src_path: Optional[Path] = None, + log_level: Union[int, str] = logging.INFO, + ) -> None: + """Initialize the ICON fine-tuning workflow. + + Args: + subject_image_files: Per-subject lists of image file paths. Each + inner list groups frames belonging to one subject; all of those + frames share a ``subject_id`` for paired training. + output_dir: Directory for dataset JSON, YAML config, derived masks, + and the uniGradICON checkpoint tree. + fine_tune_name: Sub-directory name for the experiment outputs + (used as the uniGradICON ``experiment.name`` stem). + subject_ids: One ID per subject, in the same order as + ``subject_image_files``. Written verbatim into the dataset + JSON's ``subject_id`` field so paired training groups frames + that share an ID. ``None`` falls back to synthetic IDs of the + form ``subject_0000``, ``subject_0001``, ... Must be unique. + subject_segmentation_files: Per-subject multi-label segmentation + (labelmap) paths matching ``subject_image_files``. ``None`` + disables paired-with-seg training (no ``use_label``). + Individual ``None`` entries inside the inner lists skip just + those frames when paired-with-seg training is enabled. + subject_mask_files: Per-subject binary mask paths matching + ``subject_image_files``. When supplied these are used directly + for ICON loss-function masking; otherwise masks are derived + from ``subject_segmentation_files`` via a >0 threshold and + dilation by ``mask_dilation_mm``. Per-image ``None`` + entries fall back to the derived mask for that frame (or skip + it if no segmentation is available either). + subject_landmark_files: Per-subject landmark CSV paths matching + ``subject_image_files``. Stored in the dataset JSON for + traceability; not consumed by uniGradICON fine-tuning. + epochs: uniGradICON ``training.epochs``. + batch_size: uniGradICON ``training.batch_size``. + learning_rate: uniGradICON ``training.learning_rate``. + input_shape: uniGradICON ``training.input_shape`` (voxels, X/Y/Z). + similarity: uniGradICON ``training.similarity`` metric (e.g. ``lncc``). + lambda_value: uniGradICON ``training.lambda`` regularization weight. + dice_loss_weight: uniGradICON ``training.dice_loss_weight``. + lncc_sigma: uniGradICON ``training.lncc_sigma``. + ct_window: uniGradICON dataset ``ct_window`` ``[low, high]`` in HU. + is_ct: Whether the dataset is CT (passes through to dataset config). + gpus: GPU device indices for training. Defaults to ``[0]``. + eval_period: uniGradICON ``training.eval_period``. + save_period: uniGradICON ``training.save_period``. + mask_dilation_mm: Physical radius (millimeters) of binary + dilation applied to the >0 labelmap when deriving the + loss-masking binary mask via + :meth:`RegisterImagesICON.create_mask`. Ignored when no + segmentations are supplied. Default 5.0 mm. + mask_dir: Directory where derived binary masks are written and + looked up. ``None`` (default) writes each derived mask next + to its source labelmap on disk + (``/_mask.nii.gz``). An explicit + path puts all derived masks in that single directory. + unigradicon_src_path: Optional path to a local uniGradICON source + tree to prepend to ``PYTHONPATH`` when running fine-tuning. + Useful for using a checked-out copy instead of the installed + package. + log_level: Logging level (``logging.DEBUG``, ``logging.INFO``, ...). + + Raises: + ValueError: If ``subject_image_files`` is empty. + ValueError: If ``subject_segmentation_files``, + ``subject_mask_files``, or ``subject_landmark_files`` is + provided with a shape that does not match + ``subject_image_files``. + """ + super().__init__( + class_name="WorkflowFineTuneICONRegistration", log_level=log_level + ) + + if not subject_image_files: + raise ValueError("subject_image_files must not be empty") + + if subject_ids is not None: + if len(subject_ids) != len(subject_image_files): + raise ValueError( + f"subject_ids length ({len(subject_ids)}) must match " + f"subject_image_files length ({len(subject_image_files)})" + ) + if len(set(subject_ids)) != len(subject_ids): + raise ValueError(f"subject_ids must be unique, got {subject_ids}") + + self._validate_companion_shape( + subject_image_files, + subject_segmentation_files, + "subject_segmentation_files", + ) + self._validate_companion_shape( + subject_image_files, subject_mask_files, "subject_mask_files" + ) + self._validate_companion_shape( + subject_image_files, subject_landmark_files, "subject_landmark_files" + ) + + self.subject_image_files = subject_image_files + self.subject_ids = subject_ids + self.subject_segmentation_files = subject_segmentation_files + self.subject_mask_files = subject_mask_files + self.subject_landmark_files = subject_landmark_files + + self.output_dir = Path(output_dir).resolve() + self.fine_tune_name = fine_tune_name + self.experiment_dir = self.output_dir / fine_tune_name + self.mask_dir: Optional[Path] = Path(mask_dir) if mask_dir is not None else None + + self.epochs = epochs + self.batch_size = batch_size + self.learning_rate = learning_rate + self.input_shape = tuple(input_shape) + self.similarity = similarity + self.lambda_value = lambda_value + self.dice_loss_weight = dice_loss_weight + self.lncc_sigma = lncc_sigma + self.ct_window = tuple(ct_window) + self.is_ct = is_ct + self.gpus = list(gpus) if gpus is not None else [0] + self.eval_period = eval_period + self.save_period = save_period + self.mask_dilation_mm = float(mask_dilation_mm) + self.unigradicon_src_path = ( + Path(unigradicon_src_path) if unigradicon_src_path is not None else None + ) + + self.transform_tools = TransformTools() + self.registrar: Optional[RegisterTimeSeriesImages] = None + + self._dataset_json_path: Optional[Path] = None + self._config_yaml_path: Optional[Path] = None + + @staticmethod + def _validate_companion_shape( + image_files: list[list[str]], + companion: Optional[list[list[Optional[str]]]], + name: str, + ) -> None: + """Confirm a companion list has the same nested shape as ``image_files``.""" + if companion is None: + return + if len(companion) != len(image_files): + raise ValueError( + f"{name} length ({len(companion)}) must match " + f"subject_image_files length ({len(image_files)})" + ) + for i, (images, items) in enumerate(zip(image_files, companion, strict=True)): + if len(items) != len(images): + raise ValueError( + f"{name}[{i}] length ({len(items)}) must match " + f"subject_image_files[{i}] length ({len(images)})" + ) + + @property + def uses_segmentations(self) -> bool: + """Whether at least one segmentation file is supplied for training. + + Drives the uniGradICON ``training.use_label`` flag. + """ + return self._any_non_none(self.subject_segmentation_files) + + @property + def uses_masks(self) -> bool: + """Whether the dataset will have a ``mask`` field on every kept entry. + + True when explicit masks are supplied OR when segmentations are supplied + (since masks are then derived). Drives the uniGradICON + ``training.loss_function_masking`` flag. + """ + return self._any_non_none(self.subject_mask_files) or self.uses_segmentations + + @staticmethod + def _any_non_none( + companion: Optional[list[list[Optional[str]]]], + ) -> bool: + """Return True when ``companion`` contains at least one non-``None`` entry.""" + if companion is None: + return False + for inner in companion: + for item in inner: + if item is not None: + return True + return False + + @staticmethod + def _posix(path: Union[str, Path]) -> str: + """Return a forward-slashed string path (uniGradICON expects POSIX paths).""" + return str(path).replace("\\", "/") + + def _derive_mask(self, labelmap_path: Union[str, Path]) -> Path: + """Create (or reuse) a dilated binary mask from a multi-label labelmap. + + Threshold the labelmap at ``>0`` and dilate by ``mask_dilation_mm`` mm + of physical radius via :meth:`RegisterImagesICON.create_mask` to widen + the ROI for loss-function masking. + + When :attr:`mask_dir` is ``None`` (the default) the mask is written + next to the source labelmap as + ``/_mask.nii.gz``. Otherwise it goes + under :attr:`mask_dir`. Existing masks on disk are reused unmodified. + + Args: + labelmap_path: Path to a multi-label ``itk.Image`` on disk. + + Returns: + Path to the binary mask file on disk. + """ + labelmap_path = Path(labelmap_path) + stem = labelmap_path.name + if stem.endswith(".nii.gz"): + stem = stem[: -len(".nii.gz")] + else: + stem = labelmap_path.stem + target_dir = ( + self.mask_dir if self.mask_dir is not None else labelmap_path.parent + ) + target_dir.mkdir(parents=True, exist_ok=True) + mask_path = target_dir / f"{stem}_mask.nii.gz" + if mask_path.exists(): + return mask_path + + labelmap = itk.imread(str(labelmap_path)) + mask = RegisterImagesICON.create_mask( + labelmap, dilation_mm=self.mask_dilation_mm + ) + itk.imwrite(mask, str(mask_path), compression=True) + return mask_path + + def prepare_dataset(self) -> Path: + """Write the uniGradICON dataset JSON from the configured file lists. + + Builds one entry per image with ``image``, optional ``segmentation``, + optional ``mask``, optional ``landmarks`` (path only), and a + ``subject_id`` derived from the inner-list index. + + Masks are taken from ``subject_mask_files`` when supplied for a frame; + otherwise they are derived from ``subject_segmentation_files`` via a + >0 threshold and ``mask_dilation_mm`` mm dilation. Frames are + skipped (with a log warning) when a required companion (segmentation + for paired-with-seg training, or mask for loss-function masking) is + missing. + + Returns: + Path to the dataset JSON written under :attr:`experiment_dir`. + + Raises: + FileNotFoundError: If an image listed in ``subject_image_files`` + does not exist on disk. + """ + self.experiment_dir.mkdir(parents=True, exist_ok=True) + use_seg = self.uses_segmentations + use_mask = self.uses_masks + + dataset_entries: list[dict[str, str]] = [] + for subject_index, image_files in enumerate(self.subject_image_files): + subject_id = ( + self.subject_ids[subject_index] + if self.subject_ids is not None + else f"subject_{subject_index:04d}" + ) + seg_list = ( + self.subject_segmentation_files[subject_index] + if self.subject_segmentation_files is not None + else [None] * len(image_files) + ) + mask_list = ( + self.subject_mask_files[subject_index] + if self.subject_mask_files is not None + else [None] * len(image_files) + ) + landmark_list = ( + self.subject_landmark_files[subject_index] + if self.subject_landmark_files is not None + else [None] * len(image_files) + ) + + for image_file, seg_file, mask_file, landmark_file in zip( + image_files, seg_list, mask_list, landmark_list, strict=True + ): + image_path = Path(image_file) + if not image_path.exists(): + raise FileNotFoundError(f"Image not found: {image_path}") + + entry: dict[str, str] = { + "image": self._posix(image_path), + "subject_id": subject_id, + } + + if use_seg: + if seg_file is None or not Path(seg_file).exists(): + self.log_warning( + "Skipping %s: segmentation missing for paired-with-seg " + "training (seg=%s)", + image_path, + seg_file, + ) + continue + entry["segmentation"] = self._posix(seg_file) + + if use_mask: + if mask_file is not None and Path(mask_file).exists(): + resolved_mask: Path = Path(mask_file) + elif seg_file is not None and Path(seg_file).exists(): + resolved_mask = self._derive_mask(seg_file) + else: + self.log_warning( + "Skipping %s: neither explicit mask nor segmentation " + "available to derive a loss-function mask " + "(mask=%s, seg=%s)", + image_path, + mask_file, + seg_file, + ) + continue + entry["mask"] = self._posix(resolved_mask) + + if landmark_file is not None: + entry["landmarks"] = self._posix(landmark_file) + + dataset_entries.append(entry) + + dataset_json_path = self.experiment_dir / f"{self.fine_tune_name}_dataset.json" + with dataset_json_path.open("w") as fh: + json.dump({"data": dataset_entries}, fh, indent=2) + + self.log_info( + "Wrote dataset JSON %s with %d entries", + dataset_json_path, + len(dataset_entries), + ) + self._dataset_json_path = dataset_json_path + return dataset_json_path + + def prepare_config(self, dataset_json_path: Optional[Path] = None) -> Path: + """Write the uniGradICON fine-tuning YAML config. + + Args: + dataset_json_path: Path to the dataset JSON to reference. Defaults + to the JSON last produced by :meth:`prepare_dataset`. + + Returns: + Path to the YAML config written under :attr:`experiment_dir`. + + Raises: + ValueError: If no dataset JSON path is available. + """ + if dataset_json_path is None: + dataset_json_path = self._dataset_json_path + if dataset_json_path is None: + raise ValueError( + "dataset_json_path not provided and prepare_dataset() has not " + "been called yet" + ) + + experiment_name = self.experiment_dir / f"{self.fine_tune_name}_model" + + config: dict[str, Any] = { + "experiment": { + "name": self._posix(experiment_name), + "model_weights": "unigradicon", + }, + "training": { + "batch_size": self.batch_size, + "gpus": self.gpus, + "epochs": self.epochs, + "eval_period": self.eval_period, + "save_period": self.save_period, + "learning_rate": self.learning_rate, + "input_shape": list(self.input_shape), + "similarity": self.similarity, + "lambda": self.lambda_value, + "dice_loss_weight": self.dice_loss_weight, + "lncc_sigma": self.lncc_sigma, + "loss_function_masking": self.uses_masks, + "use_label": self.uses_segmentations, + "roi_masking": False, + }, + "datasets": [ + { + "name": self.fine_tune_name, + "weight": 1.0, + "type": "paired", + "json_file": self._posix(dataset_json_path), + "is_ct": self.is_ct, + "ct_window": list(self.ct_window), + "shuffle": True, + "use_cache": True, + } + ], + } + + config_yaml_path = self.experiment_dir / f"{self.fine_tune_name}_config.yaml" + with config_yaml_path.open("w") as fh: + yaml.dump(config, fh, default_flow_style=False, sort_keys=False) + self.log_info("Wrote config YAML %s", config_yaml_path) + self._config_yaml_path = config_yaml_path + return config_yaml_path + + def expected_weights_path(self) -> Path: + """Return the path uniGradICON writes its final checkpoint to. + + ``unigradicon.finetuning.finetune`` writes + ``/checkpoints/Finetune_multi_final.trch`` at the end of + training. Used both as the return value of :meth:`run_fine_tuning` and + as a default in :meth:`apply_registration`. + """ + return ( + self.experiment_dir + / f"{self.fine_tune_name}_model" + / "checkpoints" + / "Finetune_multi_final.trch" + ) + + def run_fine_tuning(self) -> Path: + """Build configs and launch ``unigradicon.finetuning.finetune``. + + Equivalent to running + ``prepare_dataset()`` → ``prepare_config()`` → subprocess launch. Any + existing dataset JSON or YAML in :attr:`experiment_dir` is overwritten. + + Returns: + Path to the expected final checkpoint + (``Finetune_multi_final.trch``). The file is written by the + subprocess and exists only after a successful run. + + Raises: + subprocess.CalledProcessError: If the fine-tuning subprocess exits + with a non-zero status. + """ + self.log_section("FINE-TUNING UNIGRADICON", width=70) + + dataset_json_path = self.prepare_dataset() + config_yaml_path = self.prepare_config(dataset_json_path) + + env = os.environ.copy() + env["PYTHONUTF8"] = "1" + if self.unigradicon_src_path is not None: + env["PYTHONPATH"] = ( + str(self.unigradicon_src_path) + os.pathsep + env.get("PYTHONPATH", "") + ) + + cmd = [ + sys.executable, + "-m", + "unigradicon.finetuning.finetune", + "--config", + str(config_yaml_path), + ] + self.log_info("Launching fine-tuning subprocess: %s", " ".join(cmd)) + subprocess.run(cmd, check=True, env=env) + + weights_path = self.expected_weights_path() + self.log_info("Fine-tuning complete. Expected weights at %s", weights_path) + return weights_path + + @staticmethod + def _transform_landmarks( + landmarks: Landmarks, transform: itk.Transform + ) -> Landmarks: + """Apply ``transform.TransformPoint`` to every landmark in physical LPS space.""" + transformed: Landmarks = {} + for name, point in landmarks.items(): + new_point = transform.TransformPoint(point) + transformed[name] = ( + float(new_point[0]), + float(new_point[1]), + float(new_point[2]), + ) + return transformed + + def apply_registration( + self, + reference_image: itk.Image, + moving_images: list[itk.Image], + weights_path: Optional[Union[str, Path]] = None, + reference_segmentation: Optional[itk.Image] = None, + reference_landmarks: Optional[Landmarks] = None, + moving_segmentations: Optional[list[Optional[itk.Image]]] = None, + moving_landmarks: Optional[list[Optional[Landmarks]]] = None, + number_of_iterations: int = 20, + modality: str = "ct", + ) -> dict[str, Any]: + """Register each moving image to the reference using fine-tuned ICON weights. + + For every moving image this method: + + - Runs ICON registration ``moving → reference``. When a moving + segmentation is provided, a binary heart-ROI mask is derived from it + and passed as the registration mask so the ICON loss only sees the + ROI; the same is done for the reference segmentation (used as the + fixed mask). + - Warps the moving image, segmentation, and landmarks into reference + space using ``forward_transform``. Segmentations use nearest-neighbor + interpolation. Landmarks use ``inverse_transform.TransformPoint`` + (resampler-convention transform: maps moving-grid points back to + reference-grid points). + - Warps the reference image, segmentation, and landmarks into each + moving-image space using ``inverse_transform`` for image/segmentation + resampling and ``forward_transform.TransformPoint`` for landmarks. + + Args: + reference_image: Fixed (reference) ``itk.Image`` in LPS. + moving_images: List of moving ``itk.Image`` instances to register + to ``reference_image``. + weights_path: Path to a uniGradICON checkpoint (e.g. + ``Finetune_multi_final.trch``). ``None`` uses the default + pretrained uniGradICON weights. + reference_segmentation: Optional multi-label labelmap aligned with + ``reference_image``. Used to derive the fixed-image mask and + returned warped into each moving-image space. + reference_landmarks: Optional ``{name: (x, y, z)}`` landmark dict in + LPS that will be warped into each moving-image space. + moving_segmentations: Optional per-moving multi-label labelmaps + aligned with ``moving_images``. Used to derive per-moving + masks and returned warped into reference space. Per-image + ``None`` entries are allowed. + moving_landmarks: Optional per-moving landmark dicts in LPS. Each + set is warped into reference space. Per-image ``None`` entries + are allowed. + number_of_iterations: ICON fine-tuning iterations per registration. + modality: Imaging modality passed through to the underlying ICON + registrar (``'ct'`` or ``'mri'``). + + Returns: + dict with: + + - ``forward_transforms`` (``list[itk.Transform]``): per-moving + transforms mapping reference grid → moving grid (used to + resample moving → reference). + - ``inverse_transforms`` (``list[itk.Transform]``): per-moving + transforms mapping moving grid → reference grid (used to + resample reference → moving). + - ``losses`` (``list[float]``): per-moving registration loss. + - ``moving_to_reference_images`` (``list[itk.Image]``): each + moving image resampled onto the reference grid. + - ``moving_to_reference_segmentations`` (``list[Optional[itk.Image]]``): + each moving segmentation resampled onto the reference grid + with nearest-neighbor interpolation. ``None`` when the input + was ``None``. + - ``moving_to_reference_landmarks`` (``list[Optional[Landmarks]]``): + each moving landmark set warped into reference space. + ``None`` when the input was ``None``. + - ``reference_to_moving_images`` (``list[itk.Image]``): the + reference image resampled onto each moving grid. + - ``reference_to_moving_segmentations`` (``list[Optional[itk.Image]]``): + the reference segmentation resampled onto each moving grid + with nearest-neighbor interpolation. ``None`` for every + entry when ``reference_segmentation`` was ``None``. + - ``reference_to_moving_landmarks`` (``list[Optional[Landmarks]]``): + reference landmarks warped into each moving space. ``None`` + for every entry when ``reference_landmarks`` was ``None``. + + Raises: + ValueError: If ``moving_images`` is empty. + ValueError: If ``moving_segmentations`` or ``moving_landmarks`` is + supplied with a length that does not match ``moving_images``. + """ + if not moving_images: + raise ValueError("moving_images must not be empty") + num_moving = len(moving_images) + if moving_segmentations is not None and len(moving_segmentations) != num_moving: + raise ValueError( + f"moving_segmentations length ({len(moving_segmentations)}) must " + f"match moving_images length ({num_moving})" + ) + if moving_landmarks is not None and len(moving_landmarks) != num_moving: + raise ValueError( + f"moving_landmarks length ({len(moving_landmarks)}) must match " + f"moving_images length ({num_moving})" + ) + + self.log_section("APPLYING FINE-TUNED ICON REGISTRATION", width=70) + self.log_info("Number of moving images: %d", num_moving) + if weights_path is None: + self.log_info("ICON weights: ") + else: + self.log_info("ICON weights: %s", weights_path) + + fixed_mask = ( + RegisterImagesICON.create_mask( + reference_segmentation, dilation_mm=self.mask_dilation_mm + ) + if reference_segmentation is not None + else None + ) + moving_masks: Optional[list[Optional[itk.Image]]] = None + if moving_segmentations is not None: + moving_masks = [ + ( + RegisterImagesICON.create_mask( + seg, dilation_mm=self.mask_dilation_mm + ) + if seg is not None + else None + ) + for seg in moving_segmentations + ] + + self.registrar = RegisterTimeSeriesImages( + registration_method="ICON", log_level=self.log_level + ) + self.registrar.set_modality(modality) + self.registrar.set_fixed_image(reference_image) + self.registrar.set_fixed_mask(fixed_mask) + self.registrar.set_number_of_iterations_ICON(number_of_iterations) + if weights_path is not None: + self.registrar.registrar_ICON.set_weights_path(str(weights_path)) + + result = self.registrar.register_time_series( + moving_images=moving_images, + moving_masks=moving_masks, + moving_labelmaps=None, + reference_frame=0, + register_reference=True, + prior_weight=0.0, + ) + forward_transforms = result["forward_transforms"] + inverse_transforms = result["inverse_transforms"] + losses = result["losses"] + + moving_to_reference_images: list[itk.Image] = [] + moving_to_reference_segmentations: list[Optional[itk.Image]] = [] + moving_to_reference_landmarks: list[Optional[Landmarks]] = [] + reference_to_moving_images: list[itk.Image] = [] + reference_to_moving_segmentations: list[Optional[itk.Image]] = [] + reference_to_moving_landmarks: list[Optional[Landmarks]] = [] + + for index in range(num_moving): + forward_tfm = forward_transforms[index] + inverse_tfm = inverse_transforms[index] + moving_image = moving_images[index] + + moving_to_reference_images.append( + self.transform_tools.transform_image( + moving_image, forward_tfm, reference_image + ) + ) + reference_to_moving_images.append( + self.transform_tools.transform_image( + reference_image, inverse_tfm, moving_image + ) + ) + + moving_seg = ( + moving_segmentations[index] + if moving_segmentations is not None + else None + ) + if moving_seg is not None: + moving_to_reference_segmentations.append( + self.transform_tools.transform_image( + moving_seg, + forward_tfm, + reference_image, + interpolation_method="nearest", + ) + ) + else: + moving_to_reference_segmentations.append(None) + + if reference_segmentation is not None: + reference_to_moving_segmentations.append( + self.transform_tools.transform_image( + reference_segmentation, + inverse_tfm, + moving_image, + interpolation_method="nearest", + ) + ) + else: + reference_to_moving_segmentations.append(None) + + moving_lms = ( + moving_landmarks[index] if moving_landmarks is not None else None + ) + if moving_lms is not None: + moving_to_reference_landmarks.append( + self._transform_landmarks(moving_lms, inverse_tfm) + ) + else: + moving_to_reference_landmarks.append(None) + + if reference_landmarks is not None: + reference_to_moving_landmarks.append( + self._transform_landmarks(reference_landmarks, forward_tfm) + ) + else: + reference_to_moving_landmarks.append(None) + + self.log_info( + "Average ICON loss: %.6f (min %.6f, max %.6f)", + float(np.mean(losses)), + float(np.min(losses)), + float(np.max(losses)), + ) + + return { + "forward_transforms": forward_transforms, + "inverse_transforms": inverse_transforms, + "losses": losses, + "moving_to_reference_images": moving_to_reference_images, + "moving_to_reference_segmentations": moving_to_reference_segmentations, + "moving_to_reference_landmarks": moving_to_reference_landmarks, + "reference_to_moving_images": reference_to_moving_images, + "reference_to_moving_segmentations": reference_to_moving_segmentations, + "reference_to_moving_landmarks": reference_to_moving_landmarks, + } diff --git a/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py index 3de73d2..a533f4f 100644 --- a/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py +++ b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py @@ -31,7 +31,7 @@ from physiomotion4d.contour_tools import ContourTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -from physiomotion4d.register_images_ants import RegisterImagesANTs +from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.register_models_distance_maps import RegisterModelsDistanceMaps from physiomotion4d.register_models_icp import RegisterModelsICP @@ -92,8 +92,8 @@ class WorkflowFitStatisticalModelToPatient(PhysioMotion4DBase): mask_dilation_mm (float): Dilation for mask generation roi_dilation_mm (float): Dilation for ROI mask transform_tools (TransformTools): Transform utilities - registrar_icon (RegisterImagesICON): ICON registration instance - registrar_ants (RegisterImagesANTs): ANTs registration instance + registrar_ICON (RegisterImagesICON): ICON registration instance + registrar_ANTS (RegisterImagesANTS): ANTs registration instance use_pca_registration (bool): Whether PCA registration is enabled (set via set_use_pca_registration) pca_model (dict): PCA model dict when PCA enabled; same structure as WorkflowCreateStatisticalModel output pca_number_of_modes (int): Number of PCA modes when PCA enabled @@ -216,14 +216,14 @@ def __init__( ptype=itk.F, ) - self.registrar_ants = RegisterImagesANTs() - self.registrar_ants.set_number_of_iterations([5, 2, 5]) + self.registrar_ANTS = RegisterImagesANTS() + self.registrar_ANTS.set_number_of_iterations([5, 2, 5]) # Icon registration for final mask-to-image step - self.registrar_icon = RegisterImagesICON() - self.registrar_icon.set_modality("ct") - self.registrar_icon.set_mass_preservation(False) - self.registrar_icon.set_multi_modality(True) - self.registrar_icon.set_number_of_iterations(50) + self.registrar_ICON = RegisterImagesICON() + self.registrar_ICON.set_modality("ct") + self.registrar_ICON.set_mass_preservation(False) + self.registrar_ICON.set_multi_modality(True) + self.registrar_ICON.set_number_of_iterations(50) # Mask configuration (auto-generated) self.template_model_mask = None @@ -270,7 +270,7 @@ def __init__( self.m2i_template_model_surface: Optional[pv.PolyData] = None self.m2i_template_labelmap: Optional[itk.Image] = None - self.use_icon_registration_refinement = False + self.use_ICON_registration_refinement = False # Final result self.registered_template_model: Optional[pv.DataSet] = None @@ -683,7 +683,7 @@ def register_model_to_model_pca(self) -> dict: } def register_mask_to_mask( - self, use_icon_refinement: bool = False + self, use_ICON_refinement: bool = False ) -> Optional[dict]: """Perform mask-based deformable registration of model to patient model. @@ -719,7 +719,7 @@ def register_mask_to_mask( # Run deformable registration mask_result = mask_registrar.register( transform_type="Deformable", - use_icon=use_icon_refinement, + use_ICON=use_ICON_refinement, ) # Store results @@ -751,7 +751,7 @@ def register_mask_to_mask( } def register_labelmap_to_image( - self, use_icon_refinement: bool = False + self, use_ICON_refinement: bool = False ) -> Optional[dict]: """Perform labelmap-to-image refinement. @@ -814,22 +814,22 @@ def register_labelmap_to_image( ) patient_roi = self._auto_generate_roi_mask(patient_mask) - self.registrar_ants.set_fixed_image(self.patient_image) - self.registrar_ants.set_fixed_mask(patient_roi) + self.registrar_ANTS.set_fixed_image(self.patient_image) + self.registrar_ANTS.set_fixed_mask(patient_roi) - result = self.registrar_ants.register( + result = self.registrar_ANTS.register( moving_image=labelmap, moving_mask=labelmap_roi ) self.m2i_inverse_transform = result["inverse_transform"] self.m2i_forward_transform = result["forward_transform"] - if use_icon_refinement: + if use_ICON_refinement: # Configure Icon registration - self.registrar_icon.set_fixed_image(self.patient_image) - self.registrar_icon.set_fixed_mask(patient_roi) + self.registrar_ICON.set_fixed_image(self.patient_image) + self.registrar_ICON.set_fixed_mask(patient_roi) # Perform Icon registration - result = self.registrar_icon.register( + result = self.registrar_ICON.register( initial_forward_transform=self.m2i_forward_transform, moving_image=labelmap, moving_mask=labelmap_roi, @@ -937,7 +937,7 @@ def transform_model( def run_workflow( self, - use_icon_registration_refinement: bool = False, + use_ICON_registration_refinement: bool = False, ) -> dict: """Execute the complete multi-stage registration workflow. @@ -950,7 +950,7 @@ def run_workflow( set via set_use_mask_to_image_registration(True, ...). Args: - use_icon_registration_refinement: Whether to include icon registration + use_ICON_registration_refinement: Whether to include icon registration refinement stage. Default: False Returns: @@ -960,7 +960,7 @@ def run_workflow( "STARTING COMPLETE MODEL-TO-IMAGE-AND-MODEL REGISTRATION WORKFLOW", width=70 ) - self.use_icon_registration_refinement = use_icon_registration_refinement + self.use_ICON_registration_refinement = use_ICON_registration_refinement # Stage 1: ICP alignment self.register_model_to_model_icp() @@ -971,13 +971,13 @@ def run_workflow( # Stage 3: Optional Mask-to-mask deformable registration if self.use_m2m_registration: self.register_mask_to_mask( - use_icon_refinement=use_icon_registration_refinement + use_ICON_refinement=use_ICON_registration_refinement ) # Stage 4: Optional mask-to-image refinement if self.use_m2i_registration: self.register_labelmap_to_image( - use_icon_refinement=use_icon_registration_refinement + use_ICON_refinement=use_ICON_registration_refinement ) _ = self.transform_model() diff --git a/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py b/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py index 31e6c42..ebfb732 100644 --- a/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py +++ b/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py @@ -41,7 +41,7 @@ class WorkflowReconstructHighres4DCT(PhysioMotion4DBase): **Registration Pipeline:** 1. **Time Series Registration**: Register each time-series image to the - high-resolution reference using RegisterTimeSeriesImages with ants_icon method + high-resolution reference using RegisterTimeSeriesImages with ANTS_ICON method 2. **Reconstruction**: Apply inverse transforms to reconstruct high-resolution time series 3. **Optional Upsampling**: Resample to isotropic high resolution @@ -58,7 +58,7 @@ class WorkflowReconstructHighres4DCT(PhysioMotion4DBase): register_reference (bool): Whether to register reference frame prior_weight (float): Weight for temporal smoothing (0.0-1.0) upsample_to_fixed_resolution (bool): Whether to upsample reconstruction - registration_method (str): Registration method ('ants', 'icon', or 'ants_icon') + registration_method (str): Registration method ('ANTS', 'ICON', or 'ANTS_ICON') number_of_iterations: Iterations for registration registrar (RegisterTimeSeriesImages): Internal registration object forward_transforms (list[itk.Transform]): Forward transforms (moving → fixed) @@ -72,12 +72,12 @@ class WorkflowReconstructHighres4DCT(PhysioMotion4DBase): ... time_series_images=lowres_images, ... fixed_image=highres_reference, ... reference_frame=3, - ... registration_method='ants_icon', + ... registration_method='ANTS_ICON', ... ) >>> >>> # Configure registration parameters - >>> workflow.set_number_of_iterations_ants([30, 15, 7]) - >>> workflow.set_number_of_iterations_icon(20) + >>> workflow.set_number_of_iterations_ANTS([30, 15, 7]) + >>> workflow.set_number_of_iterations_ICON(20) >>> workflow.set_prior_weight(0.5) >>> >>> # Run complete workflow @@ -95,7 +95,7 @@ def __init__( fixed_image: itk.Image, reference_frame: int = 0, register_reference: bool = False, - registration_method: str = "ants_icon", + registration_method: str = "ANTS_ICON", log_level: int | str = logging.INFO, ): """Initialize the high-resolution 4D CT reconstruction workflow. @@ -111,7 +111,7 @@ def __init__( to the fixed image. If False, use identity transform for reference. Default: False registration_method (str, optional): Registration method to use. - Options: 'ants', 'icon', or 'ants_icon'. Default: 'ants_icon' + Options: 'ANTS', 'ICON', or 'ANTS_ICON'. Default: 'ANTS_ICON' log_level: Logging level (logging.DEBUG, logging.INFO, etc.). Default: logging.INFO @@ -135,9 +135,9 @@ def __init__( f"[0, {len(time_series_images) - 1}]" ) - if registration_method not in ["ants", "icon", "ants_icon"]: + if registration_method not in ["ANTS", "ICON", "ANTS_ICON"]: raise ValueError( - f"registration_method must be 'ants', 'icon', or 'ants_icon', " + f"registration_method must be 'ANTS', 'ICON', or 'ANTS_ICON', " f"got '{registration_method}'" ) @@ -157,8 +157,8 @@ def __init__( self.moving_masks: Optional[list[Optional[itk.Image]]] = None # Set default number of iterations based on registration method - self.number_of_iterations_ants: list[int] = [30, 15, 7, 3] - self.number_of_iterations_icon: int = 20 + self.number_of_iterations_ANTS: list[int] = [30, 15, 7, 3] + self.number_of_iterations_ICON: int = 20 # Initialize registrar self.registrar = RegisterTimeSeriesImages( @@ -171,24 +171,24 @@ def __init__( self.losses: Optional[list[float]] = None self.reconstructed_images: Optional[list[itk.Image]] = None - def set_number_of_iterations_ants( - self, number_of_iterations_ants: list[int] + def set_number_of_iterations_ANTS( + self, number_of_iterations_ANTS: list[int] ) -> None: """Set the number of iterations for ANTs registration. Args: - number_of_iterations_ants: List of iterations for ANTs multi-resolution + number_of_iterations_ANTS: List of iterations for ANTs multi-resolution (e.g., [30, 15, 7, 3] for four resolution levels) """ - self.number_of_iterations_ants = number_of_iterations_ants + self.number_of_iterations_ANTS = number_of_iterations_ANTS - def set_number_of_iterations_icon(self, number_of_iterations_icon: int) -> None: + def set_number_of_iterations_ICON(self, number_of_iterations_ICON: int) -> None: """Set the number of iterations for ICON registration. Args: - number_of_iterations_icon: Number of fine-tuning steps for ICON + number_of_iterations_ICON: Number of fine-tuning steps for ICON """ - self.number_of_iterations_icon = number_of_iterations_icon + self.number_of_iterations_ICON = number_of_iterations_ICON def set_prior_weight(self, prior_weight: float) -> None: """Set the weight for temporal smoothing with prior transforms. @@ -277,8 +277,8 @@ def register_time_series(self) -> dict: self.registrar.set_fixed_image(self.fixed_image) self.registrar.set_modality(self.modality) self.registrar.set_mask_dilation(self.mask_dilation_mm) - self.registrar.set_number_of_iterations_ants(self.number_of_iterations_ants) - self.registrar.set_number_of_iterations_icon(self.number_of_iterations_icon) + self.registrar.set_number_of_iterations_ANTS(self.number_of_iterations_ANTS) + self.registrar.set_number_of_iterations_ICON(self.number_of_iterations_ICON) self.registrar.set_fixed_mask(self.fixed_mask) self.log_info(f"Registration method: {self.registration_method}") @@ -286,8 +286,8 @@ def register_time_series(self) -> dict: self.log_info(f"Reference frame: {self.reference_frame}") self.log_info(f"Register reference: {self.register_reference}") self.log_info(f"Prior weight: {self.prior_weight}") - self.log_info(f"Number of iterations (ANTs): {self.number_of_iterations_ants}") - self.log_info(f"Number of iterations (ICON): {self.number_of_iterations_icon}") + self.log_info(f"Number of iterations (ANTs): {self.number_of_iterations_ANTS}") + self.log_info(f"Number of iterations (ICON): {self.number_of_iterations_ICON}") # Perform registration result = self.registrar.register_time_series( diff --git a/tests/README.md b/tests/README.md index 4876470..decd103 100644 --- a/tests/README.md +++ b/tests/README.md @@ -185,8 +185,8 @@ test_download_heart_data ↓ test_convert_image_4d_to_3d ↓ ↓ - ↓ ├─→ test_register_images_ants ──→ test_transform_tools - ↓ ├─→ test_register_images_icon + ↓ ├─→ test_register_images_ANTS ──→ test_transform_tools + ↓ ├─→ test_register_images_ICON ↓ ↓ test_segment_chest_total_segmentator ────→ test_contour_tools ↓ diff --git a/tests/conftest.py b/tests/conftest.py index 290faad..d8415eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ from physiomotion4d.contour_tools import ContourTools from physiomotion4d.convert_image_4d_to_3d import ConvertImage4DTo3D from physiomotion4d.data_download_tools import DataDownloadTools -from physiomotion4d.register_images_ants import RegisterImagesANTs +from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_greedy import RegisterImagesGreedy from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator @@ -544,7 +544,7 @@ def test_labelmaps( @pytest.fixture(scope="session") def test_transforms( - registrar_ants: RegisterImagesANTs, + registrar_ANTS: RegisterImagesANTS, test_images: list[Any], test_directories: dict[str, Path], ) -> dict[str, Any]: @@ -578,8 +578,8 @@ def test_transforms( fixed_image = test_images[7] moving_image = test_images[1] - registrar_ants.set_fixed_image(fixed_image) - result = registrar_ants.register(moving_image=moving_image) + registrar_ANTS.set_fixed_image(fixed_image) + result = registrar_ANTS.register(moving_image=moving_image) inverse_transform = result["inverse_transform"] forward_transform = result["forward_transform"] @@ -616,9 +616,9 @@ def contour_tools() -> ContourTools: @pytest.fixture(scope="session") -def registrar_ants() -> RegisterImagesANTs: - """Create a RegisterImagesANTs instance.""" - return RegisterImagesANTs() +def registrar_ANTS() -> RegisterImagesANTS: + """Create a RegisterImagesANTS instance.""" + return RegisterImagesANTS() @pytest.fixture(scope="session") @@ -628,7 +628,7 @@ def registrar_greedy() -> RegisterImagesGreedy: @pytest.fixture(scope="session") -def registrar_icon() -> RegisterImagesICON: +def registrar_ICON() -> RegisterImagesICON: """Create a RegisterImagesICON instance.""" return RegisterImagesICON() diff --git a/tests/test_register_images_ants.py b/tests/test_register_images_ants.py index 4d425ac..8102c58 100644 --- a/tests/test_register_images_ants.py +++ b/tests/test_register_images_ants.py @@ -16,54 +16,54 @@ import numpy as np import pytest -from physiomotion4d.register_images_ants import RegisterImagesANTs +from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.transform_tools import TransformTools @pytest.mark.slow -class TestRegisterImagesANTs: +class TestRegisterImagesANTS: """Test suite for ANTs-based image registration.""" - def test_registrar_initialization(self, registrar_ants: RegisterImagesANTs) -> None: - """Test that RegisterImagesANTs initializes correctly.""" - assert registrar_ants is not None, "Registrar not initialized" - assert hasattr(registrar_ants, "fixed_image"), "Missing fixed_image attribute" - assert hasattr(registrar_ants, "fixed_mask"), "Missing fixed_mask attribute" + def test_registrar_initialization(self, registrar_ANTS: RegisterImagesANTS) -> None: + """Test that RegisterImagesANTS initializes correctly.""" + assert registrar_ANTS is not None, "Registrar not initialized" + assert hasattr(registrar_ANTS, "fixed_image"), "Missing fixed_image attribute" + assert hasattr(registrar_ANTS, "fixed_mask"), "Missing fixed_mask attribute" - print("\nANTs registrar initialized successfully") + print("\nANTS registrar initialized successfully") - def test_set_modality(self, registrar_ants: RegisterImagesANTs) -> None: + def test_set_modality(self, registrar_ANTS: RegisterImagesANTS) -> None: """Test setting imaging modality.""" - registrar_ants.set_modality("ct") - assert registrar_ants.modality == "ct", "Modality not set correctly" + registrar_ANTS.set_modality("ct") + assert registrar_ANTS.modality == "ct", "Modality not set correctly" - registrar_ants.set_modality("mr") - assert registrar_ants.modality == "mr", "Modality change failed" + registrar_ANTS.set_modality("mr") + assert registrar_ANTS.modality == "mr", "Modality change failed" print("\nModality setting works correctly") def test_set_fixed_image( - self, registrar_ants: RegisterImagesANTs, test_images: list[Any] + self, registrar_ANTS: RegisterImagesANTS, test_images: list[Any] ) -> None: """Test setting fixed image.""" fixed_image = test_images[0] - registrar_ants.set_fixed_image(fixed_image) - assert registrar_ants.fixed_image is not None, "Fixed image not set" + registrar_ANTS.set_fixed_image(fixed_image) + assert registrar_ANTS.fixed_image is not None, "Fixed image not set" print("\nFixed image set successfully") - print(f" Image size: {itk.size(registrar_ants.fixed_image)}") - print(f" Image spacing: {itk.spacing(registrar_ants.fixed_image)}") + print(f" Image size: {itk.size(registrar_ANTS.fixed_image)}") + print(f" Image spacing: {itk.spacing(registrar_ANTS.fixed_image)}") def test_register_without_mask( self, - registrar_ants: RegisterImagesANTs, + registrar_ANTS: RegisterImagesANTS, test_images: list[Any], test_directories: dict[str, Path], ) -> None: """Test basic registration without masks.""" output_dir = test_directories["output"] - reg_output_dir = output_dir / "registration_ants" + reg_output_dir = output_dir / "registration_ANTS" reg_output_dir.mkdir(exist_ok=True) # Set up registration @@ -74,11 +74,11 @@ def test_register_without_mask( print(f" Fixed image: {itk.size(fixed_image)}") print(f" Moving image: {itk.size(moving_image)}") - registrar_ants.set_modality("ct") - registrar_ants.set_fixed_image(fixed_image) + registrar_ANTS.set_modality("ct") + registrar_ANTS.set_fixed_image(fixed_image) # Register - result = registrar_ants.register(moving_image=moving_image) + result = registrar_ANTS.register(moving_image=moving_image) # Verify result is a dictionary assert isinstance(result, dict), "Result should be a dictionary" @@ -111,13 +111,13 @@ def test_register_without_mask( def test_register_with_mask( self, - registrar_ants: RegisterImagesANTs, + registrar_ANTS: RegisterImagesANTS, test_images: list[Any], test_directories: dict[str, Path], ) -> None: """Test registration with binary masks.""" output_dir = test_directories["output"] - reg_output_dir = output_dir / "registration_ants" + reg_output_dir = output_dir / "registration_ANTS" reg_output_dir.mkdir(exist_ok=True) fixed_image = test_images[0] @@ -168,12 +168,12 @@ def test_register_with_mask( print(f" Moving mask voxels: {np.sum(moving_mask_arr)}") # Set up registration with masks - registrar_ants.set_modality("ct") - registrar_ants.set_fixed_image(fixed_image) - registrar_ants.set_fixed_mask(fixed_mask) + registrar_ANTS.set_modality("ct") + registrar_ANTS.set_fixed_image(fixed_image) + registrar_ANTS.set_fixed_mask(fixed_mask) # Register - result = registrar_ants.register( + result = registrar_ANTS.register( moving_image=moving_image, moving_mask=moving_mask ) @@ -204,22 +204,22 @@ def test_register_with_mask( def test_transform_application( self, - registrar_ants: RegisterImagesANTs, + registrar_ANTS: RegisterImagesANTS, test_images: list[Any], test_directories: dict[str, Path], ) -> None: """Test applying registration transforms to images.""" output_dir = test_directories["output"] - reg_output_dir = output_dir / "registration_ants" + reg_output_dir = output_dir / "registration_ANTS" reg_output_dir.mkdir(exist_ok=True) fixed_image = test_images[0] moving_image = test_images[1] # Register - registrar_ants.set_modality("ct") - registrar_ants.set_fixed_image(fixed_image) - result = registrar_ants.register(moving_image=moving_image) + registrar_ANTS.set_modality("ct") + registrar_ANTS.set_fixed_image(fixed_image) + result = registrar_ANTS.register(moving_image=moving_image) forward_transform = result["forward_transform"] @@ -257,7 +257,7 @@ def test_transform_application( print(f" Saved to: {reg_output_dir / 'ants_registered_image.mha'}") def test_preprocess_images( - self, registrar_ants: RegisterImagesANTs, test_images: list[Any] + self, registrar_ANTS: RegisterImagesANTS, test_images: list[Any] ) -> None: """Test image preprocessing.""" test_image = test_images[0] @@ -266,7 +266,7 @@ def test_preprocess_images( print(f" Original spacing: {itk.spacing(test_image)}") # Preprocess - preprocessed = registrar_ants.preprocess(test_image, modality="ct") + preprocessed = registrar_ANTS.preprocess(test_image, modality="ct") assert preprocessed is not None, "Preprocessed image is None" @@ -276,13 +276,13 @@ def test_preprocess_images( def test_registration_with_initial_transform( self, - registrar_ants: RegisterImagesANTs, + registrar_ANTS: RegisterImagesANTS, test_images: list[Any], test_directories: dict[str, Path], ) -> None: """Test registration with initial transform.""" output_dir = test_directories["output"] - reg_output_dir = output_dir / "registration_ants" + reg_output_dir = output_dir / "registration_ANTS" reg_output_dir.mkdir(exist_ok=True) fixed_image = test_images[0] @@ -295,10 +295,10 @@ def test_registration_with_initial_transform( print("\nRegistering with initial transform...") print(" Initial offset: [-5.0, -5.0, -5.0]") - registrar_ants.set_modality("ct") - registrar_ants.set_fixed_image(fixed_image) + registrar_ANTS.set_modality("ct") + registrar_ANTS.set_fixed_image(fixed_image) - result = registrar_ants.register( + result = registrar_ANTS.register( moving_image=moving_image, initial_forward_transform=initial_tfm_forward, ) @@ -310,7 +310,7 @@ def test_registration_with_initial_transform( print("Registration with initial transform complete") def test_multiple_registrations( - self, registrar_ants: RegisterImagesANTs, test_images: list[Any] + self, registrar_ANTS: RegisterImagesANTS, test_images: list[Any] ) -> None: """Test running multiple registrations in sequence.""" fixed_image = test_images[0] @@ -318,13 +318,13 @@ def test_multiple_registrations( print("\nRunning multiple registrations...") - registrar_ants.set_modality("ct") - registrar_ants.set_fixed_image(fixed_image) + registrar_ANTS.set_modality("ct") + registrar_ANTS.set_fixed_image(fixed_image) results = [] for i in range(2): print(f" Registration {i + 1}...") - result = registrar_ants.register(moving_image=moving_image) + result = registrar_ANTS.register(moving_image=moving_image) results.append(result) assert isinstance(result, dict), f"Result {i + 1} should be a dictionary" @@ -338,15 +338,15 @@ def test_multiple_registrations( print(f"Multiple registrations complete: {len(results)} runs") def test_transform_types( - self, registrar_ants: RegisterImagesANTs, test_images: list[Any] + self, registrar_ANTS: RegisterImagesANTS, test_images: list[Any] ) -> None: """Test that transforms are correct ITK types.""" fixed_image = test_images[0] moving_image = test_images[1] - registrar_ants.set_modality("ct") - registrar_ants.set_fixed_image(fixed_image) - result = registrar_ants.register(moving_image=moving_image) + registrar_ANTS.set_modality("ct") + registrar_ANTS.set_fixed_image(fixed_image) + result = registrar_ANTS.register(moving_image=moving_image) inverse_transform = result["inverse_transform"] forward_transform = result["forward_transform"] @@ -366,7 +366,7 @@ def test_transform_types( print(f" forward_transform: {type(forward_transform).__name__}") def test_image_conversion_cycle_scalar( - self, registrar_ants: RegisterImagesANTs, test_images: list[Any] + self, registrar_ANTS: RegisterImagesANTS, test_images: list[Any] ) -> None: """Test round-trip conversion: ITK image -> ANTs -> ITK for scalar images.""" original_image = test_images[0] @@ -383,7 +383,7 @@ def test_image_conversion_cycle_scalar( original_direction = itk.array_from_matrix(original_image.GetDirection()) # Convert ITK -> ANTs - ants_image = registrar_ants._itk_to_ants_image(original_image, dtype="float") + ants_image = registrar_ANTS._itk_to_ants_image(original_image, dtype="float") # Verify ANTs image assert ants_image is not None, "ANTs image is None" @@ -393,7 +393,7 @@ def test_image_conversion_cycle_scalar( print(f" ANTs image shape: {ants_image.shape}") # Convert ANTs -> ITK - recovered_image = registrar_ants._ants_to_itk_image(ants_image) + recovered_image = registrar_ANTS._ants_to_itk_image(ants_image) # Verify recovered image assert recovered_image is not None, "Recovered image is None" @@ -442,7 +442,7 @@ def test_image_conversion_cycle_scalar( print("Scalar image conversion cycle successful") def test_image_conversion_cycle_different_dtypes( - self, registrar_ants: RegisterImagesANTs, test_images: list[Any] + self, registrar_ANTS: RegisterImagesANTS, test_images: list[Any] ) -> None: """Test round-trip conversion with different data types.""" original_image = test_images[0] @@ -455,11 +455,11 @@ def test_image_conversion_cycle_different_dtypes( print(f" Testing dtype: {dtype}") # Convert ITK -> ANTs with specified dtype - ants_image = registrar_ants._itk_to_ants_image(original_image, dtype=dtype) + ants_image = registrar_ANTS._itk_to_ants_image(original_image, dtype=dtype) assert ants_image is not None, f"ANTs image is None for dtype {dtype}" # Convert ANTs -> ITK - recovered_image = registrar_ants._ants_to_itk_image(ants_image) + recovered_image = registrar_ANTS._ants_to_itk_image(ants_image) assert recovered_image is not None, ( f"Recovered image is None for dtype {dtype}" ) @@ -474,7 +474,7 @@ def test_image_conversion_cycle_different_dtypes( print("All dtype conversions successful") def test_image_conversion_preserves_metadata( - self, registrar_ants: RegisterImagesANTs + self, registrar_ANTS: RegisterImagesANTS ) -> None: """Test that image conversion preserves all metadata.""" print("\nTesting metadata preservation in image conversion...") @@ -502,8 +502,8 @@ def test_image_conversion_preserves_metadata( print(f" Test image origin: {origin}") # Convert ITK -> ANTs -> ITK - ants_image = registrar_ants._itk_to_ants_image(test_image) - recovered_image = registrar_ants._ants_to_itk_image(ants_image) + ants_image = registrar_ANTS._itk_to_ants_image(test_image) + recovered_image = registrar_ANTS._ants_to_itk_image(ants_image) # Verify all metadata recovered_size = [int(s) for s in itk.size(recovered_image)] @@ -521,7 +521,7 @@ def test_image_conversion_preserves_metadata( print("Metadata preservation verified") def test_transform_conversion_cycle_affine( - self, registrar_ants: RegisterImagesANTs, test_images: list[Any] + self, registrar_ANTS: RegisterImagesANTS, test_images: list[Any] ) -> None: """Test round-trip conversion: ITK affine transform -> ANTs -> ITK.""" reference_image = test_images[0] @@ -560,7 +560,7 @@ def test_transform_conversion_cycle_affine( # Convert ITK -> ANTs file with tempfile.TemporaryDirectory() as tmpdir: temp_tfm_file = os.path.join(tmpdir, "temp_transform.mat") - transform_files = registrar_ants.itk_transform_to_antsfile( + transform_files = registrar_ANTS.itk_transform_to_ANTSfile( affine_tfm, reference_image, temp_tfm_file ) assert len(transform_files) == 1, "Should return one transform file" @@ -577,13 +577,13 @@ def test_transform_conversion_cycle_affine( # Convert back ANTs -> ITK # Affine transforms are stored as affine in ANTs, so read back as affine if ants_tfm.transform_type == "AffineTransform": - recovered_tfm = registrar_ants._antsfile_to_itk_affine_transform( + recovered_tfm = registrar_ANTS._antsfile_to_itk_affine_transform( transform_files[0] ) else: # For displacement field transforms recovered_tfm = ( - registrar_ants._antsfile_to_itk_displacement_field_transform( + registrar_ANTS._antsfile_to_itk_displacement_field_transform( transform_files[0], reference_image ) ) @@ -627,7 +627,7 @@ def test_transform_conversion_cycle_affine( print("Affine transform conversion cycle successful") def test_transform_conversion_cycle_displacement_field( - self, registrar_ants: RegisterImagesANTs, test_images: list[Any] + self, registrar_ANTS: RegisterImagesANTS, test_images: list[Any] ) -> None: """Test round-trip conversion: ITK displacement field -> ANTs -> ITK.""" reference_image = test_images[0] @@ -666,7 +666,7 @@ def test_transform_conversion_cycle_displacement_field( # Convert ITK -> ANTs file with tempfile.TemporaryDirectory() as tmpdir: temp_tfm_file = os.path.join(tmpdir, "temp_disp_transform.mat") - transform_files = registrar_ants.itk_transform_to_antsfile( + transform_files = registrar_ANTS.itk_transform_to_ANTSfile( disp_tfm, reference_image, temp_tfm_file ) assert len(transform_files) == 1, "Should return one transform file" @@ -679,7 +679,7 @@ def test_transform_conversion_cycle_displacement_field( # Convert back ANTs -> ITK recovered_tfm = ( - registrar_ants._antsfile_to_itk_displacement_field_transform( + registrar_ANTS._antsfile_to_itk_displacement_field_transform( transform_files[0], reference_image ) ) @@ -711,7 +711,7 @@ def test_transform_conversion_cycle_displacement_field( print("Displacement field transform conversion cycle successful") def test_transform_conversion_with_composite( - self, registrar_ants: RegisterImagesANTs, test_images: list[Any] + self, registrar_ANTS: RegisterImagesANTS, test_images: list[Any] ) -> None: """Test conversion of composite transforms.""" reference_image = test_images[0] @@ -750,7 +750,7 @@ def test_transform_conversion_with_composite( # Convert to ANTs file with tempfile.TemporaryDirectory() as tmpdir: temp_tfm_file = os.path.join(tmpdir, "temp_composite_transform.mat") - transform_files = registrar_ants.itk_transform_to_antsfile( + transform_files = registrar_ANTS.itk_transform_to_ANTSfile( composite_tfm, reference_image, temp_tfm_file ) assert len(transform_files) == 1, "Should return one transform file" diff --git a/tests/test_register_images_greedy.py b/tests/test_register_images_greedy.py index 932bb2c..dfbe314 100644 --- a/tests/test_register_images_greedy.py +++ b/tests/test_register_images_greedy.py @@ -2,7 +2,7 @@ """ Tests for Greedy-based image registration. -Uses the same fixtures as test_register_images_ants (converted 3D CT images). +Uses the same fixtures as test_register_images_ANTS (converted 3D CT images). Requires the picsl-greedy package and test data. """ diff --git a/tests/test_register_images_icon.py b/tests/test_register_images_icon.py index 797d738..f89c29a 100644 --- a/tests/test_register_images_icon.py +++ b/tests/test_register_images_icon.py @@ -23,85 +23,85 @@ class TestRegisterImagesICON: """Test suite for ICON-based image registration.""" - def test_registrar_initialization(self, registrar_icon: RegisterImagesICON) -> None: + def test_registrar_initialization(self, registrar_ICON: RegisterImagesICON) -> None: """Test that RegisterImagesICON initializes correctly.""" - assert registrar_icon is not None, "Registrar not initialized" - assert hasattr(registrar_icon, "fixed_image"), "Missing fixed_image attribute" - assert hasattr(registrar_icon, "fixed_mask"), "Missing fixed_mask attribute" - assert hasattr(registrar_icon, "number_of_iterations"), ( + assert registrar_ICON is not None, "Registrar not initialized" + assert hasattr(registrar_ICON, "fixed_image"), "Missing fixed_image attribute" + assert hasattr(registrar_ICON, "fixed_mask"), "Missing fixed_mask attribute" + assert hasattr(registrar_ICON, "number_of_iterations"), ( "Missing number_of_iterations attribute" ) - assert hasattr(registrar_icon, "net"), "Missing net attribute (ICON network)" + assert hasattr(registrar_ICON, "net"), "Missing net attribute (ICON network)" print("\nICON registrar initialized successfully") - print(f" Default iterations: {registrar_icon.number_of_iterations}") + print(f" Default iterations: {registrar_ICON.number_of_iterations}") - def test_set_modality(self, registrar_icon: RegisterImagesICON) -> None: + def test_set_modality(self, registrar_ICON: RegisterImagesICON) -> None: """Test setting imaging modality.""" - registrar_icon.set_modality("ct") - assert registrar_icon.modality == "ct", "Modality not set correctly" + registrar_ICON.set_modality("ct") + assert registrar_ICON.modality == "ct", "Modality not set correctly" - registrar_icon.set_modality("mr") - assert registrar_icon.modality == "mr", "Modality change failed" + registrar_ICON.set_modality("mr") + assert registrar_ICON.modality == "mr", "Modality change failed" print("\nModality setting works correctly") - def test_set_number_of_iterations(self, registrar_icon: RegisterImagesICON) -> None: + def test_set_number_of_iterations(self, registrar_ICON: RegisterImagesICON) -> None: """Test setting number of iterations.""" - registrar_icon.set_number_of_iterations(10) - assert registrar_icon.number_of_iterations == 10, "Number of iterations not set" + registrar_ICON.set_number_of_iterations(10) + assert registrar_ICON.number_of_iterations == 10, "Number of iterations not set" - registrar_icon.set_number_of_iterations(5) - assert registrar_icon.number_of_iterations == 5, ( + registrar_ICON.set_number_of_iterations(5) + assert registrar_ICON.number_of_iterations == 5, ( "Number of iterations update failed" ) print("\nNumber of iterations setting works correctly") def test_set_fixed_image( - self, registrar_icon: RegisterImagesICON, test_images: list[Any] + self, registrar_ICON: RegisterImagesICON, test_images: list[Any] ) -> None: """Test setting fixed image.""" fixed_image = test_images[0] - registrar_icon.set_fixed_image(fixed_image) - assert registrar_icon.fixed_image is not None, "Fixed image not set" + registrar_ICON.set_fixed_image(fixed_image) + assert registrar_ICON.fixed_image is not None, "Fixed image not set" print("\nFixed image set successfully") - print(f" Image size: {itk.size(registrar_icon.fixed_image)}") - print(f" Image spacing: {itk.spacing(registrar_icon.fixed_image)}") + print(f" Image size: {itk.size(registrar_ICON.fixed_image)}") + print(f" Image spacing: {itk.spacing(registrar_ICON.fixed_image)}") - def test_set_mass_preservation(self, registrar_icon: RegisterImagesICON) -> None: + def test_set_mass_preservation(self, registrar_ICON: RegisterImagesICON) -> None: """Test setting mass preservation flag.""" - registrar_icon.set_mass_preservation(True) - assert registrar_icon.use_mass_preservation, "Mass preservation not set" + registrar_ICON.set_mass_preservation(True) + assert registrar_ICON.use_mass_preservation, "Mass preservation not set" - registrar_icon.set_mass_preservation(False) - assert not registrar_icon.use_mass_preservation, ( + registrar_ICON.set_mass_preservation(False) + assert not registrar_ICON.use_mass_preservation, ( "Mass preservation update failed" ) print("\nMass preservation setting works correctly") - def test_set_multi_modality(self, registrar_icon: RegisterImagesICON) -> None: + def test_set_multi_modality(self, registrar_ICON: RegisterImagesICON) -> None: """Test setting multi-modality flag.""" - registrar_icon.set_multi_modality(True) - assert registrar_icon.use_multi_modality, "Multi-modality not set" + registrar_ICON.set_multi_modality(True) + assert registrar_ICON.use_multi_modality, "Multi-modality not set" - registrar_icon.set_multi_modality(False) - assert not registrar_icon.use_multi_modality, "Multi-modality update failed" + registrar_ICON.set_multi_modality(False) + assert not registrar_ICON.use_multi_modality, "Multi-modality update failed" print("\nMulti-modality setting works correctly") def test_register_without_mask( self, - registrar_icon: RegisterImagesICON, + registrar_ICON: RegisterImagesICON, test_images: list[Any], test_directories: dict[str, Path], ) -> None: """Test basic ICON registration without masks.""" output_dir = test_directories["output"] - reg_output_dir = output_dir / "registration_icon" + reg_output_dir = output_dir / "registration_ICON" reg_output_dir.mkdir(exist_ok=True) # Set up registration @@ -112,12 +112,12 @@ def test_register_without_mask( print(f" Fixed image: {itk.size(fixed_image)}") print(f" Moving image: {itk.size(moving_image)}") - registrar_icon.set_modality("ct") - registrar_icon.set_fixed_image(fixed_image) - registrar_icon.set_number_of_iterations(2) # Use fewer iterations for testing + registrar_ICON.set_modality("ct") + registrar_ICON.set_fixed_image(fixed_image) + registrar_ICON.set_number_of_iterations(2) # Use fewer iterations for testing # Register - result = registrar_icon.register(moving_image=moving_image) + result = registrar_ICON.register(moving_image=moving_image) # Verify result is a dictionary assert isinstance(result, dict), "Result should be a dictionary" @@ -150,13 +150,13 @@ def test_register_without_mask( def test_register_with_mask( self, - registrar_icon: RegisterImagesICON, + registrar_ICON: RegisterImagesICON, test_images: list[Any], test_directories: dict[str, Path], ) -> None: """Test ICON registration with binary masks.""" output_dir = test_directories["output"] - reg_output_dir = output_dir / "registration_icon" + reg_output_dir = output_dir / "registration_ICON" reg_output_dir.mkdir(exist_ok=True) fixed_image = test_images[0] @@ -207,13 +207,13 @@ def test_register_with_mask( print(f" Moving mask voxels: {np.sum(moving_mask_arr)}") # Set up registration with masks - registrar_icon.set_modality("ct") - registrar_icon.set_fixed_image(fixed_image) - registrar_icon.set_fixed_mask(fixed_mask) - registrar_icon.set_number_of_iterations(2) + registrar_ICON.set_modality("ct") + registrar_ICON.set_fixed_image(fixed_image) + registrar_ICON.set_fixed_mask(fixed_mask) + registrar_ICON.set_number_of_iterations(2) # Register - result = registrar_icon.register( + result = registrar_ICON.register( moving_image=moving_image, moving_mask=moving_mask ) @@ -244,23 +244,23 @@ def test_register_with_mask( def test_transform_application( self, - registrar_icon: RegisterImagesICON, + registrar_ICON: RegisterImagesICON, test_images: list[Any], test_directories: dict[str, Path], ) -> None: """Test applying ICON registration transforms to images.""" output_dir = test_directories["output"] - reg_output_dir = output_dir / "registration_icon" + reg_output_dir = output_dir / "registration_ICON" reg_output_dir.mkdir(exist_ok=True) fixed_image = test_images[0] moving_image = test_images[1] # Register - registrar_icon.set_modality("ct") - registrar_icon.set_fixed_image(fixed_image) - registrar_icon.set_number_of_iterations(2) - result = registrar_icon.register(moving_image=moving_image) + registrar_ICON.set_modality("ct") + registrar_ICON.set_fixed_image(fixed_image) + registrar_ICON.set_number_of_iterations(2) + result = registrar_ICON.register(moving_image=moving_image) forward_transform = result["forward_transform"] @@ -297,7 +297,7 @@ def test_transform_application( print(f" Saved to: {reg_output_dir / 'icon_registered_image.mha'}") def test_inverse_consistency( - self, registrar_icon: RegisterImagesICON, test_images: list[Any] + self, registrar_ICON: RegisterImagesICON, test_images: list[Any] ) -> None: """Test ICON's inverse consistency property.""" fixed_image = test_images[0] @@ -305,10 +305,10 @@ def test_inverse_consistency( print("\nTesting inverse consistency...") - registrar_icon.set_modality("ct") - registrar_icon.set_fixed_image(fixed_image) - registrar_icon.set_number_of_iterations(2) - result = registrar_icon.register(moving_image=moving_image) + registrar_ICON.set_modality("ct") + registrar_ICON.set_fixed_image(fixed_image) + registrar_ICON.set_number_of_iterations(2) + result = registrar_ICON.register(moving_image=moving_image) inverse_transform = result["inverse_transform"] forward_transform = result["forward_transform"] @@ -343,7 +343,7 @@ def test_inverse_consistency( assert error < 5.0, f"Inverse consistency error too large: {error:.2f} mm" def test_preprocess_images( - self, registrar_icon: RegisterImagesICON, test_images: list[Any] + self, registrar_ICON: RegisterImagesICON, test_images: list[Any] ) -> None: """Test image preprocessing for ICON.""" test_image = test_images[0] @@ -352,7 +352,7 @@ def test_preprocess_images( print(f" Original spacing: {itk.spacing(test_image)}") # Preprocess - preprocessed = registrar_icon.preprocess(test_image, modality="ct") + preprocessed = registrar_ICON.preprocess(test_image, modality="ct") assert preprocessed is not None, "Preprocessed image is None" @@ -362,13 +362,13 @@ def test_preprocess_images( def test_registration_with_initial_transform( self, - registrar_icon: RegisterImagesICON, + registrar_ICON: RegisterImagesICON, test_images: list[Any], test_directories: dict[str, Path], ) -> None: """Test ICON registration with initial transform.""" output_dir = test_directories["output"] - reg_output_dir = output_dir / "registration_icon" + reg_output_dir = output_dir / "registration_ICON" reg_output_dir.mkdir(exist_ok=True) fixed_image = test_images[0] @@ -381,11 +381,11 @@ def test_registration_with_initial_transform( print("\nRegistering with initial transform...") print(" Initial offset: [-5.0, -5.0, -5.0]") - registrar_icon.set_modality("ct") - registrar_icon.set_fixed_image(fixed_image) - registrar_icon.set_number_of_iterations(2) + registrar_ICON.set_modality("ct") + registrar_ICON.set_fixed_image(fixed_image) + registrar_ICON.set_number_of_iterations(2) - result = registrar_icon.register( + result = registrar_ICON.register( moving_image=moving_image, initial_forward_transform=initial_tfm_forward, ) @@ -397,16 +397,16 @@ def test_registration_with_initial_transform( print("Registration with initial transform complete") def test_transform_types( - self, registrar_icon: RegisterImagesICON, test_images: list[Any] + self, registrar_ICON: RegisterImagesICON, test_images: list[Any] ) -> None: """Test that ICON transforms are correct ITK types.""" fixed_image = test_images[0] moving_image = test_images[1] - registrar_icon.set_modality("ct") - registrar_icon.set_fixed_image(fixed_image) - registrar_icon.set_number_of_iterations(2) - result = registrar_icon.register(moving_image=moving_image) + registrar_ICON.set_modality("ct") + registrar_ICON.set_fixed_image(fixed_image) + registrar_ICON.set_number_of_iterations(2) + result = registrar_ICON.register(moving_image=moving_image) inverse_transform = result["inverse_transform"] forward_transform = result["forward_transform"] @@ -438,14 +438,14 @@ def test_transform_types( print(f" forward_transform: {type(forward_transform).__name__}") def test_different_iteration_counts( - self, registrar_icon: RegisterImagesICON, test_images: list[Any] + self, registrar_ICON: RegisterImagesICON, test_images: list[Any] ) -> None: """Test ICON with different iteration counts.""" fixed_image = test_images[0] moving_image = test_images[1] - registrar_icon.set_modality("ct") - registrar_icon.set_fixed_image(fixed_image) + registrar_ICON.set_modality("ct") + registrar_ICON.set_fixed_image(fixed_image) iteration_counts = [1, 2, 5] results = [] @@ -454,8 +454,8 @@ def test_different_iteration_counts( for num_iter in iteration_counts: print(f" Running with {num_iter} iterations...") - registrar_icon.set_number_of_iterations(num_iter) - result = registrar_icon.register(moving_image=moving_image) + registrar_ICON.set_number_of_iterations(num_iter) + result = registrar_ICON.register(moving_image=moving_image) results.append(result) assert isinstance(result, dict), "Result should be a dictionary" diff --git a/tests/test_register_time_series_images.py b/tests/test_register_time_series_images.py index d37d053..24f3a1b 100644 --- a/tests/test_register_time_series_images.py +++ b/tests/test_register_time_series_images.py @@ -26,29 +26,29 @@ class TestRegisterTimeSeriesImages: _class_name = "registration_time_series_images" - def test_registrar_initialization_ants(self) -> None: + def test_registrar_initialization_ANTS(self) -> None: """Test that RegisterTimeSeriesImages initializes correctly with ANTs.""" - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") assert registrar is not None, "Registrar not initialized" - assert registrar.registration_method_name == "ants", "Method not set correctly" - assert registrar.registrar_ants is not None, ( + assert registrar.registration_method_name == "ANTS", "Method not set correctly" + assert registrar.registrar_ANTS is not None, ( "Internal ANTs registrar not created" ) - assert registrar.registrar_icon is not None, ( + assert registrar.registrar_ICON is not None, ( "Internal ICON registrar not created" ) print("\nTime series registrar initialized with ANTs") - def test_registrar_initialization_icon(self) -> None: + def test_registrar_initialization_ICON(self) -> None: """Test that RegisterTimeSeriesImages initializes correctly with ICON.""" - registrar = RegisterTimeSeriesImages(registration_method="icon") + registrar = RegisterTimeSeriesImages(registration_method="ICON") assert registrar is not None, "Registrar not initialized" - assert registrar.registration_method_name == "icon", "Method not set correctly" - assert registrar.registrar_ants is not None, ( + assert registrar.registration_method_name == "ICON", "Method not set correctly" + assert registrar.registrar_ANTS is not None, ( "Internal ANTs registrar not created" ) - assert registrar.registrar_icon is not None, ( + assert registrar.registrar_ICON is not None, ( "Internal ICON registrar not created" ) @@ -64,7 +64,7 @@ def test_registrar_initialization_greedy(self) -> None: assert registrar.registrar_greedy is not None, ( "Internal Greedy registrar not created" ) - assert registrar.registrar_icon is not None, ( + assert registrar.registrar_ICON is not None, ( "Internal ICON registrar not created" ) @@ -79,7 +79,7 @@ def test_registrar_initialization_invalid_method(self) -> None: def test_set_modality(self) -> None: """Test setting imaging modality.""" - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") registrar.set_modality("ct") assert registrar.modality == "ct", "Modality not set correctly" @@ -87,7 +87,7 @@ def test_set_modality(self) -> None: def test_set_fixed_image(self, test_images: list[Any]) -> None: """Test setting fixed image.""" - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") fixed_image = test_images[0] registrar.set_fixed_image(fixed_image) @@ -98,11 +98,11 @@ def test_set_fixed_image(self, test_images: list[Any]) -> None: def test_set_number_of_iterations(self) -> None: """Test setting number of iterations.""" - registrar_ants = RegisterTimeSeriesImages(registration_method="ants") - iterations_ants = [30, 15, 5] + registrar_ANTS = RegisterTimeSeriesImages(registration_method="ANTS") + iterations_ANTS = [30, 15, 5] - registrar_ants.set_number_of_iterations_ants(iterations_ants) - assert registrar_ants.number_of_iterations_ants == iterations_ants, ( + registrar_ANTS.set_number_of_iterations_ANTS(iterations_ANTS) + assert registrar_ANTS.number_of_iterations_ANTS == iterations_ANTS, ( "ANTs iterations not set correctly" ) @@ -114,11 +114,11 @@ def test_set_number_of_iterations(self) -> None: "Greedy iterations not set correctly" ) - registrar_icon = RegisterTimeSeriesImages(registration_method="icon") - iterations_icon = 50 + registrar_ICON = RegisterTimeSeriesImages(registration_method="ICON") + iterations_ICON = 50 - registrar_icon.set_number_of_iterations_icon(iterations_icon) - assert registrar_icon.number_of_iterations_icon == iterations_icon, ( + registrar_ICON.set_number_of_iterations_ICON(iterations_ICON) + assert registrar_ICON.number_of_iterations_ICON == iterations_ICON, ( "ICON iterations not set correctly" ) @@ -136,10 +136,10 @@ def test_register_time_series_basic( print(f" Fixed image: {itk.size(fixed_image)}") print(f" Number of moving images: {len(moving_images)}") - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") registrar.set_modality("ct") registrar.set_fixed_image(fixed_image) - registrar.set_number_of_iterations_ants([20, 10, 2]) + registrar.set_number_of_iterations_ANTS([20, 10, 2]) result = registrar.register_time_series( moving_images=moving_images, @@ -217,10 +217,10 @@ def test_register_time_series_with_prior( print(f" Number of moving images: {len(moving_images)}") print(" Using prior transform weight: 0.5") - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") registrar.set_modality("ct") registrar.set_fixed_image(fixed_image) - registrar.set_number_of_iterations_ants([20, 10, 2]) + registrar.set_number_of_iterations_ANTS([20, 10, 2]) result = registrar.register_time_series( moving_images=moving_images, @@ -274,10 +274,10 @@ def test_register_time_series_identity_start(self, test_images: list[Any]) -> No print("\nRegistering time series (identity start)...") - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") registrar.set_modality("ct") registrar.set_fixed_image(fixed_image) - registrar.set_number_of_iterations_ants([20, 10, 2]) + registrar.set_number_of_iterations_ANTS([20, 10, 2]) result = registrar.register_time_series( moving_images=moving_images, @@ -302,10 +302,10 @@ def test_register_time_series_different_starting_indices( print("\nTesting different starting indices...") - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") registrar.set_modality("ct") registrar.set_fixed_image(fixed_image) - registrar.set_number_of_iterations_ants([10, 5, 1]) + registrar.set_number_of_iterations_ANTS([10, 5, 1]) # Test starting from beginning, middle, and end for starting_index in [0, 1]: @@ -325,7 +325,7 @@ def test_register_time_series_different_starting_indices( def test_register_time_series_error_no_fixed_image(self) -> None: """Test that error is raised if fixed image not set.""" - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") moving_images = [None, None, None] # Dummy list @@ -338,7 +338,7 @@ def test_register_time_series_error_invalid_starting_index( self, test_images: list[Any] ) -> None: """Test that error is raised for invalid starting index.""" - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") registrar.set_fixed_image(test_images[0]) moving_images = test_images[1:4] @@ -361,7 +361,7 @@ def test_register_time_series_error_invalid_prior_portion( self, test_images: list[Any] ) -> None: """Test that error is raised for invalid prior portion value.""" - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") registrar.set_fixed_image(test_images[0]) moving_images = test_images[1:4] @@ -391,10 +391,10 @@ def test_transform_application_time_series( print("\nTesting transform application...") - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") registrar.set_modality("ct") registrar.set_fixed_image(fixed_image) - registrar.set_number_of_iterations_ants([20, 10, 2]) + registrar.set_number_of_iterations_ANTS([20, 10, 2]) result = registrar.register_time_series( moving_images=moving_images, @@ -434,17 +434,17 @@ def test_transform_application_time_series( "transform_application_time_series_0.mha", ) - def test_register_time_series_icon(self, test_images: list[Any]) -> None: + def test_register_time_series_ICON(self, test_images: list[Any]) -> None: """Test time series registration with ICON method.""" fixed_image = test_images[0] moving_images = test_images[1:3] print("\nTesting time series registration with ICON...") - registrar = RegisterTimeSeriesImages(registration_method="icon") + registrar = RegisterTimeSeriesImages(registration_method="ICON") registrar.set_modality("ct") registrar.set_fixed_image(fixed_image) - registrar.set_number_of_iterations_icon(5) # ICON uses single int + registrar.set_number_of_iterations_ICON(5) # ICON uses single int result = registrar.register_time_series( moving_images=moving_images, @@ -487,11 +487,11 @@ def test_register_time_series_with_mask( print("\nTesting time series registration with mask...") print(f" Mask voxels: {np.sum(fixed_mask_arr)}") - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") registrar.set_modality("ct") registrar.set_fixed_image(fixed_image) registrar.set_fixed_mask(fixed_mask) - registrar.set_number_of_iterations_ants([20, 10, 2]) + registrar.set_number_of_iterations_ANTS([20, 10, 2]) result = registrar.register_time_series( moving_images=moving_images, @@ -513,10 +513,10 @@ def test_bidirectional_registration(self, test_images: list[Any]) -> None: print(f" Total images: {len(moving_images)}") print(" Starting from middle (index 2)") - registrar = RegisterTimeSeriesImages(registration_method="ants") + registrar = RegisterTimeSeriesImages(registration_method="ANTS") registrar.set_modality("ct") registrar.set_fixed_image(fixed_image) - registrar.set_number_of_iterations_ants([20, 10, 2]) + registrar.set_number_of_iterations_ANTS([20, 10, 2]) result = registrar.register_time_series( moving_images=moving_images, diff --git a/tests/test_transform_tools.py b/tests/test_transform_tools.py index 6efc4a0..3b12635 100644 --- a/tests/test_transform_tools.py +++ b/tests/test_transform_tools.py @@ -2,7 +2,7 @@ """ Test for transform tools functionality. -This test depends on test_register_images_ants and uses registration +This test depends on test_register_images_ANTS and uses registration transforms to test transform manipulation and application. """ diff --git a/tests/test_workflow_fine_tune_icon_registration.py b/tests/test_workflow_fine_tune_icon_registration.py new file mode 100644 index 0000000..11cde75 --- /dev/null +++ b/tests/test_workflow_fine_tune_icon_registration.py @@ -0,0 +1,554 @@ +"""Unit tests for WorkflowFineTuneICONRegistration. + +Exercises constructor validation, ``prepare_dataset`` / ``prepare_config`` +file generation, mask derivation, and the ``run_fine_tuning`` subprocess +launch. GPU-heavy paths (real uniGradICON training, ``apply_registration``) +are not exercised here — only their input-validation guards. +""" + +from __future__ import annotations + +import json +import logging +import subprocess +import sys +from pathlib import Path +from typing import Any, Optional + +import itk +import numpy as np +import pytest +import yaml + +from physiomotion4d.register_images_icon import RegisterImagesICON +from physiomotion4d.workflow_fine_tune_icon_registration import ( + WorkflowFineTuneICONRegistration, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_image(path: Path, value: int = 1) -> None: + """Write a 3x3x3 ``uint8`` ITK image with a single foreground voxel at center.""" + arr = np.zeros((3, 3, 3), dtype=np.uint8) + arr[1, 1, 1] = value + img = itk.image_from_array(arr) + itk.imwrite(img, str(path), compression=True) + + +@pytest.fixture +def two_subject_dataset(tmp_path: Path) -> dict[str, Any]: + """Two patients, two frames each, with matching labelmaps on disk.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + output_dir = tmp_path / "ft_out" + + subject_image_files: list[list[str]] = [] + subject_segmentation_files: list[list[Optional[str]]] = [] + for patient_id in ("pm0001", "pm0002"): + pdir = data_dir / patient_id + pdir.mkdir() + images: list[str] = [] + segs: list[Optional[str]] = [] + for frame in ("g000", "g050"): + image_path = pdir / f"{patient_id}_{frame}.nii.gz" + label_path = pdir / f"{patient_id}_{frame}_labelmap.nii.gz" + _make_image(image_path) + _make_image(label_path) + images.append(str(image_path)) + segs.append(str(label_path)) + subject_image_files.append(images) + subject_segmentation_files.append(segs) + + return { + "output_dir": output_dir, + "fine_tune_name": "test_exp", + "subject_ids": ["pm0001", "pm0002"], + "subject_image_files": subject_image_files, + "subject_segmentation_files": subject_segmentation_files, + } + + +# --------------------------------------------------------------------------- +# Construction / validation +# --------------------------------------------------------------------------- + + +def test_init_requires_output_dir_and_name(tmp_path: Path) -> None: + """output_dir and fine_tune_name are required positional args.""" + with pytest.raises(TypeError): + WorkflowFineTuneICONRegistration( + subject_image_files=[["a.nii.gz"]], + ) + + +def test_init_rejects_empty_image_files(tmp_path: Path) -> None: + """Empty subject list raises immediately.""" + with pytest.raises(ValueError, match="must not be empty"): + WorkflowFineTuneICONRegistration( + subject_image_files=[], + output_dir=tmp_path, + fine_tune_name="x", + ) + + +def test_init_rejects_mismatched_companion_lengths(tmp_path: Path) -> None: + """Mask/seg/landmark lists must match subject_image_files shape exactly.""" + with pytest.raises(ValueError, match="subject_mask_files\\[0\\] length"): + WorkflowFineTuneICONRegistration( + subject_image_files=[["a.nii.gz", "b.nii.gz"]], + output_dir=tmp_path, + fine_tune_name="x", + subject_mask_files=[["m.nii.gz"]], + ) + + +def test_init_rejects_duplicate_subject_ids(tmp_path: Path) -> None: + """Duplicate subject IDs collapse paired groups, so reject them up front.""" + with pytest.raises(ValueError, match="unique"): + WorkflowFineTuneICONRegistration( + subject_image_files=[["a"], ["b"]], + output_dir=tmp_path, + fine_tune_name="x", + subject_ids=["same", "same"], + ) + + +def test_init_rejects_mismatched_subject_ids_length(tmp_path: Path) -> None: + """subject_ids must have one entry per subject.""" + with pytest.raises(ValueError, match="subject_ids length"): + WorkflowFineTuneICONRegistration( + subject_image_files=[["a"]], + output_dir=tmp_path, + fine_tune_name="x", + subject_ids=["a", "b"], + ) + + +def test_uses_segmentations_and_uses_masks_flags(tmp_path: Path) -> None: + """The two helper flags reflect supplied companions independently.""" + base: dict[str, Any] = { + "subject_image_files": [["a"]], + "output_dir": tmp_path, + "fine_tune_name": "x", + } + none_wf = WorkflowFineTuneICONRegistration(**base) + assert not none_wf.uses_segmentations + assert not none_wf.uses_masks + + seg_only = WorkflowFineTuneICONRegistration( + **base, subject_segmentation_files=[["seg.nii.gz"]] + ) + assert seg_only.uses_segmentations + assert seg_only.uses_masks # derived from segs + + mask_only = WorkflowFineTuneICONRegistration( + **base, subject_mask_files=[["mask.nii.gz"]] + ) + assert not mask_only.uses_segmentations + assert mask_only.uses_masks + + +# --------------------------------------------------------------------------- +# RegisterImagesICON.create_mask (in-memory dilation, used by the workflow) +# --------------------------------------------------------------------------- + + +def test_create_mask_thresholds_and_dilates() -> None: + """Single-voxel labelmap becomes a binary mask whose dilation grows it.""" + arr = np.zeros((5, 5, 5), dtype=np.uint8) + arr[2, 2, 2] = 3 # non-zero label id + labelmap = itk.image_from_array(arr) + # Unit isotropic spacing so dilation_mm == voxel radius. + labelmap.SetSpacing([1.0, 1.0, 1.0]) + + no_dilate = RegisterImagesICON.create_mask(labelmap, dilation_mm=0.0) + no_dilate_arr = itk.array_from_image(no_dilate) + assert set(np.unique(no_dilate_arr).tolist()) == {0, 1} + assert int(no_dilate_arr.sum()) == 1 + + dilated = RegisterImagesICON.create_mask(labelmap, dilation_mm=1.0) + dilated_arr = itk.array_from_image(dilated) + assert int(dilated_arr.sum()) > 1 + # Original foreground voxel stays foreground. + assert dilated_arr[2, 2, 2] == 1 + + +# --------------------------------------------------------------------------- +# prepare_dataset +# --------------------------------------------------------------------------- + + +def test_prepare_dataset_uses_real_subject_ids( + two_subject_dataset: dict[str, Any], +) -> None: + """Subject IDs round-trip from the caller into every dataset entry.""" + workflow = WorkflowFineTuneICONRegistration( + log_level=logging.CRITICAL, **two_subject_dataset + ) + dataset_json_path = workflow.prepare_dataset() + + payload = json.loads(dataset_json_path.read_text(encoding="utf-8")) + entries = payload["data"] + assert len(entries) == 4 + ids = {entry["subject_id"] for entry in entries} + assert ids == {"pm0001", "pm0002"} + for entry in entries: + assert set(entry).issuperset({"image", "segmentation", "mask", "subject_id"}) + # Paths are forward-slashed for uniGradICON. + assert "\\" not in entry["image"] + assert "\\" not in entry["segmentation"] + assert "\\" not in entry["mask"] + + +def test_prepare_dataset_skips_frames_with_missing_segmentation( + tmp_path: Path, +) -> None: + """A frame with no seg available is dropped when use_label is required.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + img_a = data_dir / "img_a.nii.gz" + img_b = data_dir / "img_b.nii.gz" + seg_a = data_dir / "seg_a.nii.gz" + _make_image(img_a) + _make_image(img_b) + _make_image(seg_a) + + workflow = WorkflowFineTuneICONRegistration( + subject_image_files=[[str(img_a), str(img_b)]], + output_dir=tmp_path / "out", + fine_tune_name="exp", + subject_segmentation_files=[[str(seg_a), None]], + log_level=logging.CRITICAL, + ) + dataset_json_path = workflow.prepare_dataset() + + entries = json.loads(dataset_json_path.read_text(encoding="utf-8"))["data"] + assert len(entries) == 1 + assert entries[0]["image"].endswith("img_a.nii.gz") + + +def test_prepare_dataset_uses_explicit_mask_over_derived(tmp_path: Path) -> None: + """When subject_mask_files supplies a mask, it overrides the derived one.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + image = data_dir / "image.nii.gz" + seg = data_dir / "seg.nii.gz" + explicit_mask = data_dir / "explicit_mask.nii.gz" + _make_image(image) + _make_image(seg) + _make_image(explicit_mask) + + workflow = WorkflowFineTuneICONRegistration( + subject_image_files=[[str(image)]], + output_dir=tmp_path / "out", + fine_tune_name="exp", + subject_segmentation_files=[[str(seg)]], + subject_mask_files=[[str(explicit_mask)]], + log_level=logging.CRITICAL, + ) + dataset_json_path = workflow.prepare_dataset() + entry = json.loads(dataset_json_path.read_text(encoding="utf-8"))["data"][0] + + assert entry["mask"].endswith("explicit_mask.nii.gz") + assert entry["segmentation"].endswith("seg.nii.gz") + # No derived mask file was created because the explicit one was used. + derived_mask = data_dir / "seg_mask.nii.gz" + assert not derived_mask.exists() + + +def test_prepare_dataset_mask_only_no_segmentations(tmp_path: Path) -> None: + """Mask-only input: entries have ``mask`` but no ``segmentation`` field.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + image = data_dir / "image.nii.gz" + mask = data_dir / "mask.nii.gz" + _make_image(image) + _make_image(mask) + + workflow = WorkflowFineTuneICONRegistration( + subject_image_files=[[str(image)]], + output_dir=tmp_path / "out", + fine_tune_name="exp", + subject_mask_files=[[str(mask)]], + log_level=logging.CRITICAL, + ) + entry = json.loads(workflow.prepare_dataset().read_text(encoding="utf-8"))["data"][ + 0 + ] + assert "mask" in entry + assert "segmentation" not in entry + + +def test_prepare_dataset_derives_mask_next_to_labelmap_by_default( + two_subject_dataset: dict[str, Any], +) -> None: + """Derived masks land next to each labelmap when ``mask_dir`` is not set.""" + workflow = WorkflowFineTuneICONRegistration( + log_level=logging.CRITICAL, **two_subject_dataset + ) + assert workflow.mask_dir is None + workflow.prepare_dataset() + + seg_files = [ + Path(s) + for inner in workflow.subject_segmentation_files or [] + for s in inner + if s is not None + ] + derived = [s.parent / f"{s.name[: -len('.nii.gz')]}_mask.nii.gz" for s in seg_files] + for mask_path in derived: + assert mask_path.exists(), f"missing derived mask: {mask_path}" + assert len(derived) == 4 + # Sanity: derived masks are binary with at least one foreground voxel. + arr = itk.array_from_image(itk.imread(str(derived[0]))) + assert set(np.unique(arr).tolist()).issubset({0, 1}) + assert int(arr.sum()) >= 1 + + +def test_prepare_dataset_derives_mask_under_explicit_mask_dir( + two_subject_dataset: dict[str, Any], tmp_path: Path +) -> None: + """Explicit ``mask_dir`` collects every derived mask in that single folder.""" + explicit_mask_dir = tmp_path / "explicit_masks" + workflow = WorkflowFineTuneICONRegistration( + log_level=logging.CRITICAL, + mask_dir=explicit_mask_dir, + **two_subject_dataset, + ) + workflow.prepare_dataset() + derived = list(explicit_mask_dir.glob("*_mask.nii.gz")) + assert len(derived) == 4 + # None of the labelmap-adjacent locations should have been written to. + seg_files = [ + Path(s) + for inner in workflow.subject_segmentation_files or [] + for s in inner + if s is not None + ] + for s in seg_files: + assert not (s.parent / f"{s.name[: -len('.nii.gz')]}_mask.nii.gz").exists() + + +def test_prepare_dataset_raises_on_missing_image_file(tmp_path: Path) -> None: + """Image existence is a hard requirement; missing image aborts the build.""" + workflow = WorkflowFineTuneICONRegistration( + subject_image_files=[[str(tmp_path / "does_not_exist.nii.gz")]], + output_dir=tmp_path / "out", + fine_tune_name="exp", + log_level=logging.CRITICAL, + ) + with pytest.raises(FileNotFoundError, match="Image not found"): + workflow.prepare_dataset() + + +# --------------------------------------------------------------------------- +# prepare_config +# --------------------------------------------------------------------------- + + +def test_prepare_config_emits_uniGradICON_yaml( + two_subject_dataset: dict[str, Any], +) -> None: + """YAML config matches uniGradICON's expected structure when seg is present.""" + workflow = WorkflowFineTuneICONRegistration( + log_level=logging.CRITICAL, + epochs=10, + batch_size=2, + learning_rate=1e-4, + input_shape=(64, 64, 64), + gpus=[1], + **two_subject_dataset, + ) + dataset_json = workflow.prepare_dataset() + config_yaml = workflow.prepare_config(dataset_json) + + config = yaml.safe_load(config_yaml.read_text(encoding="utf-8")) + assert config["experiment"]["model_weights"] == "unigradicon" + assert config["experiment"]["name"].endswith("test_exp_model") + training = config["training"] + assert training["epochs"] == 10 + assert training["batch_size"] == 2 + assert training["learning_rate"] == 1e-4 + assert training["input_shape"] == [64, 64, 64] + assert training["gpus"] == [1] + # Driven by data availability. + assert training["use_label"] is True + assert training["loss_function_masking"] is True + assert training["roi_masking"] is False + + dataset_cfg = config["datasets"][0] + assert dataset_cfg["type"] == "paired" + assert dataset_cfg["is_ct"] is True + assert dataset_cfg["json_file"].endswith("test_exp_dataset.json") + assert "\\" not in dataset_cfg["json_file"] + + +def test_prepare_config_flags_off_when_no_companions(tmp_path: Path) -> None: + """Without seg or mask, ``use_label`` and ``loss_function_masking`` are False.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + image = data_dir / "image.nii.gz" + _make_image(image) + + workflow = WorkflowFineTuneICONRegistration( + subject_image_files=[[str(image)]], + output_dir=tmp_path / "out", + fine_tune_name="exp", + log_level=logging.CRITICAL, + ) + dataset_json = workflow.prepare_dataset() + config = yaml.safe_load( + workflow.prepare_config(dataset_json).read_text(encoding="utf-8") + ) + assert config["training"]["use_label"] is False + assert config["training"]["loss_function_masking"] is False + + +def test_prepare_config_requires_dataset_json(tmp_path: Path) -> None: + """Calling prepare_config without first preparing the dataset is an error.""" + workflow = WorkflowFineTuneICONRegistration( + subject_image_files=[["a"]], + output_dir=tmp_path, + fine_tune_name="x", + log_level=logging.CRITICAL, + ) + with pytest.raises(ValueError, match="prepare_dataset"): + workflow.prepare_config() + + +# --------------------------------------------------------------------------- +# expected_weights_path +# --------------------------------------------------------------------------- + + +def test_expected_weights_path_layout(tmp_path: Path) -> None: + """Weights land at ``output_dir//_model/checkpoints/...``.""" + workflow = WorkflowFineTuneICONRegistration( + subject_image_files=[["a"]], + output_dir=tmp_path, + fine_tune_name="exp", + log_level=logging.CRITICAL, + ) + expected = workflow.expected_weights_path() + assert expected == ( + tmp_path / "exp" / "exp_model" / "checkpoints" / "Finetune_multi_final.trch" + ) + + +# --------------------------------------------------------------------------- +# run_fine_tuning (subprocess is monkey-patched) +# --------------------------------------------------------------------------- + + +def test_run_fine_tuning_invokes_unigradicon_subprocess( + monkeypatch: pytest.MonkeyPatch, + two_subject_dataset: dict[str, Any], +) -> None: + """run_fine_tuning launches the uniGradICON finetune module with the YAML path.""" + captured: dict[str, Any] = {} + + def fake_run( + cmd: list[str], + *, + check: bool, + env: dict[str, str], + ) -> subprocess.CompletedProcess[bytes]: + captured["cmd"] = cmd + captured["check"] = check + captured["env"] = env + return subprocess.CompletedProcess(args=cmd, returncode=0) + + monkeypatch.setattr(subprocess, "run", fake_run) + + unigradicon_src = two_subject_dataset["output_dir"].parent / "fake_unigradicon_src" + workflow = WorkflowFineTuneICONRegistration( + log_level=logging.CRITICAL, + unigradicon_src_path=unigradicon_src, + **two_subject_dataset, + ) + weights = workflow.run_fine_tuning() + + assert captured["check"] is True + assert captured["cmd"][0] == sys.executable + assert captured["cmd"][1:4] == ["-m", "unigradicon.finetuning.finetune", "--config"] + yaml_arg = Path(captured["cmd"][4]) + assert yaml_arg.exists() + assert yaml_arg.name == "test_exp_config.yaml" + + # Environment overrides. + assert captured["env"]["PYTHONUTF8"] == "1" + assert str(unigradicon_src) in captured["env"]["PYTHONPATH"] + + assert weights == workflow.expected_weights_path() + + +def test_run_fine_tuning_without_unigradicon_src( + monkeypatch: pytest.MonkeyPatch, + two_subject_dataset: dict[str, Any], +) -> None: + """When unigradicon_src_path is None, PYTHONPATH is not prefixed.""" + + def fake_run( + cmd: list[str], + *, + check: bool, + env: dict[str, str], + ) -> subprocess.CompletedProcess[bytes]: + # No leading entry referencing a "fake" src tree. + assert "fake_unigradicon_src" not in env.get("PYTHONPATH", "") + return subprocess.CompletedProcess(args=cmd, returncode=0) + + monkeypatch.setattr(subprocess, "run", fake_run) + + workflow = WorkflowFineTuneICONRegistration( + log_level=logging.CRITICAL, + **two_subject_dataset, + ) + workflow.run_fine_tuning() + + +# --------------------------------------------------------------------------- +# apply_registration — validation guards only (skips real registration) +# --------------------------------------------------------------------------- + + +def test_apply_registration_rejects_empty_moving(tmp_path: Path) -> None: + """apply_registration validates inputs before touching the registrar.""" + workflow = WorkflowFineTuneICONRegistration( + subject_image_files=[["a"]], + output_dir=tmp_path, + fine_tune_name="x", + log_level=logging.CRITICAL, + ) + arr = np.zeros((3, 3, 3), dtype=np.float32) + ref = itk.image_from_array(arr) + with pytest.raises(ValueError, match="moving_images must not be empty"): + workflow.apply_registration(reference_image=ref, moving_images=[]) + + +def test_apply_registration_rejects_mismatched_companions(tmp_path: Path) -> None: + """moving_segmentations / moving_landmarks length must match moving_images.""" + workflow = WorkflowFineTuneICONRegistration( + subject_image_files=[["a"]], + output_dir=tmp_path, + fine_tune_name="x", + log_level=logging.CRITICAL, + ) + ref = itk.image_from_array(np.zeros((3, 3, 3), dtype=np.float32)) + mov = itk.image_from_array(np.zeros((3, 3, 3), dtype=np.float32)) + with pytest.raises(ValueError, match="moving_segmentations length"): + workflow.apply_registration( + reference_image=ref, + moving_images=[mov], + moving_segmentations=[], + ) + with pytest.raises(ValueError, match="moving_landmarks length"): + workflow.apply_registration( + reference_image=ref, + moving_images=[mov], + moving_landmarks=[], + ) diff --git a/tutorials/tutorial_01_heart_gated_ct_to_usd.py b/tutorials/tutorial_01_heart_gated_ct_to_usd.py index 19281cb..acb6c32 100644 --- a/tutorials/tutorial_01_heart_gated_ct_to_usd.py +++ b/tutorials/tutorial_01_heart_gated_ct_to_usd.py @@ -49,7 +49,7 @@ contour extraction -> USD export. - SegmentChestTotalSegmentator (segment_chest_total_segmentator.py): Deep-learning segmentation of 117 anatomical structures (used internally). -- RegisterImagesICON / RegisterImagesANTs (register_images_icon.py / _ants.py): +- RegisterImagesICON / RegisterImagesANTS (register_images_icon.py / _ants.py): Frame-to-frame image registration (used internally). - ContourTools (contour_tools.py): Extracts and transforms surface meshes from segmentation masks (used internally). diff --git a/tutorials/tutorial_06_reconstruct_highres_4d_ct.py b/tutorials/tutorial_06_reconstruct_highres_4d_ct.py index 1fcbe9c..09ddc09 100644 --- a/tutorials/tutorial_06_reconstruct_highres_4d_ct.py +++ b/tutorials/tutorial_06_reconstruct_highres_4d_ct.py @@ -45,7 +45,7 @@ OUTPUT_DIR = TUTORIALS_DIR / "output" / "tutorial_06" BASELINES_DIR = REPO_ROOT / "tests" / "baselines" MAX_FRAMES = 4 - REGISTRATION_METHOD = "ants" + REGISTRATION_METHOD = "ANTS" LOG_LEVEL = logging.INFO # %% @@ -59,10 +59,10 @@ if test_mode: max_frames = min(MAX_FRAMES, 3) - number_of_iterations_ants = [1, 0] + number_of_iterations_ANTS = [1, 0] else: max_frames = MAX_FRAMES - number_of_iterations_ants = [30, 15, 7, 3] + number_of_iterations_ANTS = [30, 15, 7, 3] output_dir.mkdir(parents=True, exist_ok=True) @@ -87,7 +87,7 @@ log_level=log_level, ) workflow.set_modality("ct") - workflow.set_number_of_iterations_ants(number_of_iterations_ants) + workflow.set_number_of_iterations_ANTS(number_of_iterations_ANTS) # %% # Workflow execution diff --git a/tutorials/tutorial_08_dirlab_pca_time_series.py b/tutorials/tutorial_08_dirlab_pca_time_series.py index 3a544f2..232e8ed 100644 --- a/tutorials/tutorial_08_dirlab_pca_time_series.py +++ b/tutorials/tutorial_08_dirlab_pca_time_series.py @@ -120,7 +120,7 @@ def run_tutorial() -> dict[str, Any]: fixed_image = time_series[0] registrar = RegisterTimeSeriesImages( - registration_method="ants_icon", + registration_method="ANTS_ICON", log_level=log_level, ) registrar.set_modality("ct")