Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion httomo/data/dataset_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import time
import h5py
from typing import List, Literal, Optional, Tuple, Union
from httomo.data.hdf._utils.reslice import reslice
from httomo.data.hdf._utils.reslice import reslice, reslice_memory_estimator
from httomo.data.padding import extrapolate_after, extrapolate_before
from httomo.runner.auxiliary_data import AuxiliaryData
from httomo.runner.dataset import DataSetBlock
Expand Down Expand Up @@ -286,6 +286,10 @@ def __init__(
start = time.perf_counter()
self._data = self._reslice(source.slicing_dim, slicing_dim, source_data)
end = time.perf_counter()
log_once(
f"reslice_memory_estimator: {reslice_memory_estimator(source_data.shape, source_data.dtype, source.slicing_dim, slicing_dim, self._comm)}",
level=logging.DEBUG,
)
if slicing_dim == 1:
log_once(
f"Slicing axis change (reslice) from projection to sinogram took {(end - start):.9f}s.",
Expand Down
81 changes: 81 additions & 0 deletions httomo/data/hdf/_utils/reslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,84 @@ def reslice(

start_idx = 0 if comm.rank == 0 else split_indices[comm.rank - 1]
return new_data, next_slice_dim, start_idx


def reslice_memory_estimator(
data_shape: Tuple[int, int, int],
dtype: numpy.dtype,
current_slice_dim: int,
next_slice_dim: int,
comm: Comm,
) -> Tuple[int, int]:
rank = comm.rank
nprocs = comm.size
itemsize = numpy.dtype(dtype).itemsize

split_sizes = []
length = data_shape[next_slice_dim]
split_indices = [round((length / nprocs) * r) for r in range(1, nprocs)]

prev_idx = 0
for i in range(nprocs):
next_idx = split_indices[i] if i < len(split_indices) else length
split_shape = list(data_shape)
split_shape[next_slice_dim] = next_idx - prev_idx
split_sizes.append(numpy.prod(split_shape) * itemsize)
prev_idx = next_idx

all_split_sizes = comm.allgather(split_sizes)
recv_sizes = [all_split_sizes[p][rank] for p in range(nprocs)]

output_shape = list(data_shape)
output_shape[current_slice_dim] = sum(
recv_sizes[p]
// (
itemsize
* numpy.prod([data_shape[d] for d in range(3) if d != next_slice_dim])
)
for p in range(nprocs)
)
output_size = numpy.prod(output_shape) * itemsize

max_send_buffer = max(split_sizes)
max_recv_buffer = max(recv_sizes)

from httomo.data.mpiutil import _mpi_max_elements

max_elements = _mpi_max_elements - 1
max_transfer_elements = max(
max(split_sizes) // itemsize, max(recv_sizes) // itemsize
)

needs_chunking = max_transfer_elements > max_elements

if needs_chunking:
chunk_overhead_send = max_send_buffer
chunk_overhead_recv = max_recv_buffer
else:
chunk_overhead_send = 0
chunk_overhead_recv = 0

# The final values for the peak allocation sizes before, during, and after the ring
# algorithm have been kept in for the sake of completeness. However, the values that matter
# most are the allocations that the reslice algorithm require, namely:
# - what the ring algorithm allocates
# - what the output size allocated is
#
# peak_before_ring = input_size + output_size
#
# peak_during_ring = (
# peak_before_ring
# + max_send_buffer # Temporary send buffer
# + max_recv_buffer # Temporary recv buffer
# + chunk_overhead_send # Flattened send array (if chunking)
# + chunk_overhead_recv # Flattened recv array (if chunking)
# )
#
# peak_after_ring = input_size + output_size

ring_algorithm_allocations = (
max_send_buffer + max_recv_buffer + chunk_overhead_send + chunk_overhead_recv
)

return (ring_algorithm_allocations, output_size)
Loading