diff --git a/httomo/method_wrappers/__init__.py b/httomo/method_wrappers/__init__.py index 8342200cb..890fefac5 100644 --- a/httomo/method_wrappers/__init__.py +++ b/httomo/method_wrappers/__init__.py @@ -9,6 +9,7 @@ import httomo.method_wrappers.datareducer import httomo.method_wrappers.dezinging import httomo.method_wrappers.distortion_correction +import httomo.method_wrappers.seam_blender import httomo.method_wrappers.images import httomo.method_wrappers.reconstruction import httomo.method_wrappers.rotation diff --git a/httomo/method_wrappers/seam_blender.py b/httomo/method_wrappers/seam_blender.py new file mode 100644 index 000000000..def8974d9 --- /dev/null +++ b/httomo/method_wrappers/seam_blender.py @@ -0,0 +1,56 @@ +from typing import Dict, Optional + +from mpi4py.MPI import Comm + +from httomo.method_wrappers.generic import GenericMethodWrapper +from httomo.preview import PreviewConfig +from httomo.runner.methods_repository_interface import MethodRepository + + +class SeamBlenderWrapper(GenericMethodWrapper): + """ + Wrapper for seam blender. + """ + + @classmethod + def should_select_this_class(cls, module_path: str, method_name: str) -> bool: + return "seam_blend" in method_name + + @classmethod + def requires_preview(cls) -> bool: + return True + + def __init__( + self, + method_repository: MethodRepository, + module_path: str, + method_name: str, + comm: Comm, + preview_config: PreviewConfig, + save_result: Optional[bool] = None, + output_mapping: Dict[str, str] = {}, + **kwargs, + ): + super().__init__( + method_repository, + module_path, + method_name, + comm, + save_result, + output_mapping, + **kwargs, + ) + self._update_params_from_preview(preview_config) + + def _update_params_from_preview(self, preview_config: PreviewConfig) -> None: + """ + Extract information from preview config to modify seam index parameter required for `seam_blend_stitched_data` method. + """ + SHIFT_PARAM_NAME = "shift_seam_index" + det_x_start = preview_config.detector_x.start + + self.append_config_params( + { + SHIFT_PARAM_NAME: det_x_start, + } + ) diff --git a/httomo/runner/task_runner.py b/httomo/runner/task_runner.py index af90c976c..7af8dbdb6 100644 --- a/httomo/runner/task_runner.py +++ b/httomo/runner/task_runner.py @@ -81,11 +81,18 @@ def execute(self) -> None: def _pipeline_inspector(self) -> None: phase_in_pipeline = False minus_log_in_pipeline = False + seam_blend_in_pipeline = False for i, method in enumerate(self.pipeline._methods): if "phase" in method.module_path: phase_in_pipeline = True if "minus_log" == method.method_name: minus_log_in_pipeline = True + if "seam_blend" in method.method_name: + seam_blend_in_pipeline = True + if "dark_flat" in method.method_name and seam_blend_in_pipeline: + raise RuntimeError( + "Seam blending method should be used only after the dark/flat field correction" + ) if "rotation" in method.module_path: if self.pipeline[i - 1].method_name == "data_reducer": log_once( diff --git a/tests/method_wrappers/test_seam_blender.py b/tests/method_wrappers/test_seam_blender.py new file mode 100644 index 000000000..c4cd53ccb --- /dev/null +++ b/tests/method_wrappers/test_seam_blender.py @@ -0,0 +1,73 @@ +import pytest +from mpi4py import MPI +from pytest_mock import MockerFixture + +from httomo.method_wrappers.seam_blender import SeamBlenderWrapper +from httomo.preview import PreviewConfig, PreviewDimConfig +from tests.testing_utils import make_mock_repo + + +@pytest.mark.parametrize( + "method_name, expected_result", + [("seam_blend", True), ("other_method", False)], + ids=["should-select", "shouldn't-select"], +) +def test_class_only_selected_for_methods_with_seam_blend_in_name( + method_name: str, expected_result: bool +): + assert ( + SeamBlenderWrapper.should_select_this_class("dummy.module.path", method_name) + is expected_result + ) + + +def test_requires_preview_is_true(): + assert SeamBlenderWrapper.requires_preview() is True + + +@pytest.mark.parametrize( + "preview_config", + [ + PreviewConfig( + angles=PreviewDimConfig(start=0, stop=180), + detector_y=PreviewDimConfig(start=0, stop=128), + detector_x=PreviewDimConfig(start=0, stop=160), + ), + PreviewConfig( + angles=PreviewDimConfig(start=0, stop=180), + detector_y=PreviewDimConfig(start=0, stop=128), + detector_x=PreviewDimConfig(start=5, stop=155), + ), + ], + ids=["no_cropping", "crop_det_x_both_ends"], +) +def test_sets_shiftx_params_correctly( + preview_config: PreviewConfig, mocker: MockerFixture +): + MODULE_PATH = "dummy.module.path" + METHOD_NAME = "seam_blend_dummy" + COMM = MPI.COMM_WORLD + + # Patch method function import that occurs when the wrapper object is created, to instead + # import the below dummy method function + class FakeModule: + def seam_blend_dummy(shift_seam_index): # type: ignore + return shift_seam_index + + mocker.patch( + "httomo.method_wrappers.generic.import_module", return_value=FakeModule + ) + + wrapper = SeamBlenderWrapper( + method_repository=make_mock_repo(mocker), + module_path=MODULE_PATH, + method_name=METHOD_NAME, + comm=COMM, + preview_config=preview_config, + ) + + expected_shift_values = preview_config.detector_x.start + + SHIFT_PARAM_NAME = "shift_seam_index" + assert SHIFT_PARAM_NAME in wrapper.config_params + assert wrapper.config_params[SHIFT_PARAM_NAME] == expected_shift_values