Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/essreduce/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dynamic = ["version"]

dependencies = [
"sciline>=25.11.0",
"scipp>=26.3.0",
"scipp>=26.3.1",
"scippneutron>=25.11.1",
"scippnexus>=25.06.0",
]
Expand Down
24 changes: 10 additions & 14 deletions packages/essreduce/src/ess/reduce/nexus/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,25 +280,21 @@ class NeXusTransformationChain(

@dataclass
class NeXusTransformation(Generic[Component, RunType]):
value: sc.Variable
"""A NeXus transformation computed from a transformation chain.

If the transformation is time-dependent, it is stored as a data array
with a 'time' coordinate.
Otherwise, the transformation is stored as a variable.
"""

value: sc.Variable | sc.DataArray

@staticmethod
def from_chain(
chain: NeXusTransformationChain[Component, RunType],
) -> 'NeXusTransformation[Component, RunType]':
"""
Convert a transformation chain to a single transformation.

As transformation chains may be time-dependent, this method will need to select
a specific time point to convert to a single transformation. This may include
averaging as well as threshold checks. This is not implemented yet and we
therefore currently raise an error if the transformation chain does not compute
to a scalar.
"""
if chain.transformations.sizes != {}:
raise ValueError(f"Expected scalar transformation, got {chain}")
transform = chain.compute()
return NeXusTransformation(value=transform)
"""Convert a transformation chain to a single transformation."""
return NeXusTransformation(value=chain.compute())


class RawChoppers(
Expand Down
83 changes: 68 additions & 15 deletions packages/essreduce/src/ess/reduce/nexus/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from copy import deepcopy
from typing import Any, TypeVar

import numpy as np
import sciline
import sciline.typing
import scipp as sc
Expand Down Expand Up @@ -275,26 +276,51 @@ def load_nexus_data(


def get_transformation_chain(
detector: NeXusComponent[Component, RunType],
component: NeXusComponent[Component, RunType],
) -> NeXusTransformationChain[Component, RunType]:
"""
Extract the transformation chain from a NeXus detector group.
Extract the transformation chain from a NeXus component group.

Parameters
----------
detector:
NeXus detector group.
component:
NeXus component group.
"""
chain = detector['depends_on']
chain = component['depends_on']
return NeXusTransformationChain[Component, RunType](chain)


def _time_filter(transform: sc.DataArray) -> sc.Variable:
def _collapse_runs(transform: sc.DataArray, dim: str) -> sc.DataArray:
    """Collapse runs of equal values into a single value.

    Consecutive entries along ``dim`` whose values are approximately equal are
    reduced to the first entry of each run, and the ``dim`` coordinate is
    converted to bin-edges extended to cover the remainder of the measurement.
    """
    values = transform.values
    # Mark entries whose value differs from the previous one; the first entry
    # always starts a new run.
    starts_new_run = np.hstack([True, ~np.isclose(values[:-1], values[1:])])
    run_starts = np.flatnonzero(starts_new_run)
    if run_starts.shape == transform.shape:
        # Every entry starts a new run: nothing to collapse, so skip the
        # expensive fancy indexing below.
        return transform
    # Keep only the first entry of each run.
    collapsed = transform[run_starts]

    # Turn the coordinate into bin-edges. Bin-edges are left-inclusive, so the
    # existing coord values serve as the left edges; append one far-future edge
    # so the last run covers the whole remaining measurement.
    final = collapsed.coords[dim][-1]
    collapsed.coords[dim] = sc.concat(
        [
            collapsed.coords[dim],
            # Surely, no experiment will last more than 10 years...
            final + sc.scalar(10, unit='Y').to(unit=final.unit),
        ],
        dim=dim,
    )

    return collapsed


def _time_filter(transform: sc.DataArray) -> sc.Variable | sc.DataArray:
if transform.ndim == 0 or transform.sizes == {'time': 1}:
return transform.data.squeeze()
raise ValueError(
f"Transform is time-dependent: {transform}, but no filter is provided."
)
return _collapse_runs(transform, dim='time')


def to_transformation(
Expand Down Expand Up @@ -369,6 +395,10 @@ def get_calibrated_detector(
The data array is reshaped to the logical detector shape, by folding the data
array along the detector_number dimension.

The output contains pixel positions computed from ``transform`` and ``offset``.
If ``transform`` is time-dependent, the output contains a 'time' dimension
and coordinate corresponding to the time coordinate of ``transform``.

Parameters
----------
detector:
Expand Down Expand Up @@ -401,9 +431,17 @@ def get_calibrated_detector(
else:
transform_value = transform.value
position = transform_value * offsets
return EmptyDetector[RunType](
da.assign_coords(position=position + offset.to(unit=position.unit))
)

position = position + offset.to(unit=position.unit)
if isinstance(position, sc.DataArray): # time-dependent transform
# Store position and time as separate coords because we can't store data arrays.
return EmptyDetector[RunType](
da.broadcast(
dims=['time', *da.dims], shape=[position.sizes['time'], *da.shape]
).assign_coords(position=position.data, time=position.coords['time'])
)

return EmptyDetector[RunType](da.assign_coords(position=position))


def assemble_detector_data(
Expand All @@ -422,13 +460,29 @@ def assemble_detector_data(
neutron_data:
Neutron data array (events or histogram).
"""
if neutron_data.bins is not None:
detector_coords = dict(detector.coords)
if neutron_data.is_binned:
neutron_data = nexus.group_event_data(
event_data=neutron_data, detector_number=detector.coords['detector_number']
)
if 'time' in detector.dims:
# Give the neutron data a 'time' dimension matching the times in the
# detector data. Preserve the `event_time_zero` event coord.
# This is needed to add time-dependent detector coords and masks below.
neutron_data = neutron_data.bin(
event_time_zero=detector_coords['time'].rename(time='event_time_zero')
).rename_dims(event_time_zero='time')
neutron_data.coords['time'] = neutron_data.coords.pop('event_time_zero')
else:
position = detector_coords.get('position')
if position is not None and 'time' in position.dims:
raise NotImplementedError(
"Time-dependent positions are not yet supported for histogram data."
)

return RawDetector[RunType](
_add_variances(neutron_data)
.assign_coords(detector.coords)
.assign_coords(detector_coords)
.assign_masks(detector.masks)
)

Expand Down Expand Up @@ -659,7 +713,6 @@ def load_source_metadata_from_nexus(
definitions["NXdetector"] = _StrippedDetector
definitions["NXmonitor"] = _StrippedMonitor


_common_providers = (
gravity_vector_neg_y,
file_path_to_file_spec,
Expand Down
33 changes: 23 additions & 10 deletions packages/essreduce/tests/nexus/workflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,29 @@ def test_can_compute_position_of_group(depends_on: snx.TransformationChain) -> N
assert_identical(workflow.compute_position(trans), position)


def test_can_compute_position_of_group_time_dependent(
    time_dependent_depends_on: snx.TransformationChain,
) -> None:
    # Expected positions for each of the three time points in the fixture.
    times = sc.array(dims=['time'], values=[0.0, 1.0, 2.0], unit='s')
    expected = sc.DataArray(
        sc.vectors(
            dims=['time'],
            values=[[1.0, 1.0, 0.0], [1.0, 2.0, 0.0], [1.0, 3.0, 0.0]],
            unit='m',
        ),
        coords={'time': times},
    )

    source = workflow.NeXusComponent[snx.NXsource, SampleRun](
        sc.DataGroup(depends_on=time_dependent_depends_on)
    )
    transformation = workflow.to_transformation(
        workflow.get_transformation_chain(source),
        interval=TimeInterval(slice(None, None)),
    )
    assert_identical(workflow.compute_position(transformation), expected)


def test_to_transform_with_positional_time_interval(
time_dependent_depends_on: snx.TransformationChain,
) -> None:
Expand Down Expand Up @@ -172,16 +195,6 @@ def test_to_transform_with_label_based_time_interval_single_point(
assert sc.identical(transform * origin, sc.vector([1.0, 3.0, 0.0], unit='m'))


def test_to_transform_raises_if_interval_does_not_yield_unique_value(
time_dependent_depends_on: snx.TransformationChain,
) -> None:
with pytest.raises(ValueError, match='Transform is time-dependent'):
workflow.to_transformation(
time_dependent_depends_on,
TimeInterval(slice(sc.scalar(0.1, unit='s'), sc.scalar(1.9, unit='s'))),
)


def test_given_no_sample_load_nexus_sample_returns_group_with_origin_depends_on(
loki_tutorial_sample_run_60250: Path,
) -> None:
Expand Down
Loading