Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions httomo/method_wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import httomo.method_wrappers.datareducer
import httomo.method_wrappers.dezinging
import httomo.method_wrappers.distortion_correction
import httomo.method_wrappers.seam_blender
import httomo.method_wrappers.images
import httomo.method_wrappers.reconstruction
import httomo.method_wrappers.rotation
Expand Down
56 changes: 56 additions & 0 deletions httomo/method_wrappers/seam_blender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Dict, Optional

from mpi4py.MPI import Comm

from httomo.method_wrappers.generic import GenericMethodWrapper
from httomo.preview import PreviewConfig
from httomo.runner.methods_repository_interface import MethodRepository


class SeamBlenderWrapper(GenericMethodWrapper):
"""
Wrapper for seam blender.
"""

@classmethod
def should_select_this_class(cls, module_path: str, method_name: str) -> bool:
return "seam_blend" in method_name

@classmethod
def requires_preview(cls) -> bool:
return True

def __init__(
self,
method_repository: MethodRepository,
module_path: str,
method_name: str,
comm: Comm,
preview_config: PreviewConfig,
save_result: Optional[bool] = None,
output_mapping: Dict[str, str] = {},
**kwargs,
):
super().__init__(
method_repository,
module_path,
method_name,
comm,
save_result,
output_mapping,
**kwargs,
)
self._update_params_from_preview(preview_config)

def _update_params_from_preview(self, preview_config: PreviewConfig) -> None:
"""
Extract information from preview config to modify seam index parameter required for `seam_blend_stitched_data` method.
"""
SHIFT_PARAM_NAME = "shift_seam_index"
det_x_start = preview_config.detector_x.start

self.append_config_params(
{
SHIFT_PARAM_NAME: det_x_start,
}
)
7 changes: 7 additions & 0 deletions httomo/runner/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,18 @@ def execute(self) -> None:
def _pipeline_inspector(self) -> None:
phase_in_pipeline = False
minus_log_in_pipeline = False
seam_blend_in_pipeline = False
for i, method in enumerate(self.pipeline._methods):
if "phase" in method.module_path:
phase_in_pipeline = True
if "minus_log" == method.method_name:
minus_log_in_pipeline = True
if "seam_blend" in method.method_name:
seam_blend_in_pipeline = True
if "dark_flat" in method.method_name and seam_blend_in_pipeline:
raise RuntimeError(
"Seam blending method should be used only after the dark/flat field correction"
)
if "rotation" in method.module_path:
if self.pipeline[i - 1].method_name == "data_reducer":
log_once(
Expand Down
73 changes: 73 additions & 0 deletions tests/method_wrappers/test_seam_blender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest
from mpi4py import MPI
from pytest_mock import MockerFixture

from httomo.method_wrappers.seam_blender import SeamBlenderWrapper
from httomo.preview import PreviewConfig, PreviewDimConfig
from tests.testing_utils import make_mock_repo


@pytest.mark.parametrize(
"method_name, expected_result",
[("seam_blend", True), ("other_method", False)],
ids=["should-select", "shouldn't-select"],
)
def test_class_only_selected_for_methods_with_seam_blend_in_name(
method_name: str, expected_result: bool
):
assert (
SeamBlenderWrapper.should_select_this_class("dummy.module.path", method_name)
is expected_result
)


def test_requires_preview_is_true():
assert SeamBlenderWrapper.requires_preview() is True


@pytest.mark.parametrize(
"preview_config",
[
PreviewConfig(
angles=PreviewDimConfig(start=0, stop=180),
detector_y=PreviewDimConfig(start=0, stop=128),
detector_x=PreviewDimConfig(start=0, stop=160),
),
PreviewConfig(
angles=PreviewDimConfig(start=0, stop=180),
detector_y=PreviewDimConfig(start=0, stop=128),
detector_x=PreviewDimConfig(start=5, stop=155),
),
],
ids=["no_cropping", "crop_det_x_both_ends"],
)
def test_sets_shiftx_params_correctly(
preview_config: PreviewConfig, mocker: MockerFixture
):
MODULE_PATH = "dummy.module.path"
METHOD_NAME = "seam_blend_dummy"
COMM = MPI.COMM_WORLD

# Patch method function import that occurs when the wrapper object is created, to instead
# import the below dummy method function
class FakeModule:
def seam_blend_dummy(shift_seam_index): # type: ignore
return shift_seam_index

mocker.patch(
"httomo.method_wrappers.generic.import_module", return_value=FakeModule
)

wrapper = SeamBlenderWrapper(
method_repository=make_mock_repo(mocker),
module_path=MODULE_PATH,
method_name=METHOD_NAME,
comm=COMM,
preview_config=preview_config,
)

expected_shift_values = preview_config.detector_x.start

SHIFT_PARAM_NAME = "shift_seam_index"
assert SHIFT_PARAM_NAME in wrapper.config_params
assert wrapper.config_params[SHIFT_PARAM_NAME] == expected_shift_values
Loading