Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions bdsf/rmsimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,50 @@
pass
from .functions import read_image_from_file

from concurrent.futures import ThreadPoolExecutor
import numpy
from scipy import ndimage


def mapcoord_threaded(a, axs, *args, ncores=8, **kwargs):
"""Threaded map_coordinates on cartesian coordinate grid (meshgrid)

:param a: Array to be regridded

:param axs: List of axes onto which to regrid. Result is gridded
to equivalent of meshgrid(*axs)

"""
output = kwargs.get("output", None)
kwargs["output"] = None
# Prefilter only once to awoid repeated work in the workers. See
# _interpolation.py in scipy
order = kwargs.get("order", 3)
if order > 1:
a = ndimage.spline_filter(a, order,
output=numpy.float64,
mode=kwargs.get("mode", "constant"))

def tworker(cl1):
# Construct the subset of meshgrid to which worker has been
# applied.
# NB: The axis reversal is specific to this program. The indexing parameter
# does not do exactly the same thing
cl = numpy.meshgrid( * ([cl1]+axs[1:]))[-1::-1]
return ndimage.map_coordinates(a,
cl,
# NB we pulled the pre-filter out
prefilter=False,
*args, **kwargs)

with ThreadPoolExecutor(max_workers=ncores) as te:
res=te.map(tworker,
axs[0])
res=numpy.hstack(list(res))
if output is not None:
numpy.copyto(output, res)
return res


class Op_rmsimage(Op):
"""Calculate rms & noise maps
Expand Down Expand Up @@ -555,8 +599,7 @@ def map_2d(self, arr, out_mean, out_rms, mask=False,
"""
mask_small = mask
axes, mean_map1, rms_map1 = self.rms_mean_map(arr, mask_small, kappa, box, ncores)
ax = [self.remap_axis(ashp, axv) for (ashp, axv) in zip(arr.shape, axes)]
ax = N.meshgrid(*ax[-1::-1])
ax = [self.remap_axis(ashp, axv) for (ashp, axv) in zip(arr.shape, axes)][-1::-1]
pt_src_scale = box[0]
if do_adapt:
out_rms2 = N.zeros(rms_map1.shape, dtype=N.float32)
Expand All @@ -569,10 +612,9 @@ def map_2d(self, arr, out_mean, out_rms, mask=False,
axes2mod = axes2[:]
axes2mod[0] = axes2[0]/arr.shape[0]*mean_map1.shape[0]
axes2mod[1] = axes2[1]/arr.shape[1]*mean_map1.shape[1]
ax2 = [self.remap_axis(ashp, axv) for (ashp, axv) in zip(out_rms2.shape, axes2mod)]
ax2 = N.meshgrid(*ax2[-1::-1])
nd.map_coordinates(rms_map2, ax2[-1::-1], order=interp, output=out_rms2)
nd.map_coordinates(mean_map2, ax2[-1::-1], order=interp, output=out_mean2)
ax2 = [self.remap_axis(ashp, axv) for (ashp, axv) in zip(out_rms2.shape, axes2mod)][-1::-1]
mapcoord_threaded(rms_map2, ax2, order=interp, output=out_rms2, ncores=ncores)
mapcoord_threaded(mean_map2, ax2, order=interp, output=out_mean2, ncores=ncores)
rms_map = out_rms2
mean_map = out_mean2

Expand Down Expand Up @@ -615,8 +657,8 @@ def map_2d(self, arr, out_mean, out_rms, mask=False,

# Interpolate to image coords
mylog = mylogger.logging.getLogger(logname+"Rmsimage")
nd.map_coordinates(rms_map, ax[-1::-1], order=interp, output=out_rms)
nd.map_coordinates(mean_map, ax[-1::-1], order=interp, output=out_mean)
mapcoord_threaded(rms_map, ax, order=interp, output=out_rms, ncores=ncores)
mapcoord_threaded(mean_map, ax, order=interp, output=out_mean, ncores=ncores)

# Apply mask to mean_map and rms_map by setting masked values to NaN
if isinstance(mask, N.ndarray):
Expand Down