Commit
Merge pull request #214 from AlistairSymonds/dev_reproject_block
Add support for blocked and parallel reprojection in ``reproject_interp``
astrofrog authored Sep 6, 2022
2 parents 89178da + 85c9ffb commit 3f29f13
Showing 3 changed files with 298 additions and 11 deletions.
47 changes: 37 additions & 10 deletions reproject/interpolation/high_level.py
@@ -1,8 +1,9 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import os

from astropy.utils import deprecated_renamed_argument

from ..utils import parse_input_data, parse_output_projection
from ..utils import parse_input_data, parse_output_projection, reproject_blocked
from .core import _reproject_full

__all__ = ['reproject_interp']
@@ -17,16 +18,12 @@
@deprecated_renamed_argument('independent_celestial_slices', None, since='0.6')
def reproject_interp(input_data, output_projection, shape_out=None, hdu_in=0,
order='bilinear', independent_celestial_slices=False,
output_array=None, return_footprint=True,
roundtrip_coords=True):
output_array=None, return_footprint=True, output_footprint=None,
block_size=None, parallel=False, roundtrip_coords=True):
"""
Reproject data to a new projection using interpolation (this is typically
the fastest way to reproject an image).
The output pixel grid is transformed to the input pixel grid, and the
data values in ``input_data`` interpolated on to these coordinates to get
the reprojected data on the output grid.
Parameters
----------
input_data
@@ -69,6 +66,16 @@ def reproject_interp(input_data, output_projection, shape_out=None, hdu_in=0,
extremely large files.
return_footprint : bool
Whether to return the footprint in addition to the output array.
block_size : None or tuple of (int, int)
If not None, a blocked projection will be performed, where the output space
is reprojected one block at a time. This is useful in memory-limited
scenarios, such as dealing with very large arrays or high-resolution output
spaces.
parallel : bool or int
Flag for the parallel implementation. If ``True``, a parallel implementation
is chosen and the number of processes is set automatically to the number of
logical CPUs detected on the machine. If ``False``, a serial implementation
is chosen. If the flag is a positive integer ``n`` greater than one, a
parallel implementation using ``n`` processes is chosen.
roundtrip_coords : bool
Whether to verify that coordinate transformations are defined in both
directions.
@@ -90,6 +97,26 @@ def reproject_interp(input_data, output_projection, shape_out=None, hdu_in=0,
if isinstance(order, str):
order = ORDER[order]

return _reproject_full(array_in, wcs_in, wcs_out, shape_out=shape_out,
order=order, array_out=output_array,
return_footprint=return_footprint, roundtrip_coords=roundtrip_coords)
# if either of these is not the default, a blocked method must be used
if block_size is not None or parallel is not False:
# if parallel is set but block_size isn't, choose a block size such that
# each process gets one block
if parallel is not False and block_size is None:
# use a list so individual dimensions can be reassigned below
# (shape_out may be a tuple)
block_size = list(shape_out)
# each process gets an equal-sized strip of the output area to process
block_size[0] = shape_out[0] // os.cpu_count()

# since modern systems can have many CPU cores, some sanity clamping is
# needed to avoid zero-length block sizes when the CPU count is greater
# than the side of the image
for dim_idx in range(min(len(shape_out), 2)):
if block_size[dim_idx] == 0:
block_size[dim_idx] = shape_out[dim_idx]

return reproject_blocked(_reproject_full, array_in=array_in, wcs_in=wcs_in, wcs_out=wcs_out,
shape_out=shape_out, output_array=output_array, parallel=parallel,
block_size=block_size, return_footprint=return_footprint,
output_footprint=output_footprint)
else:
return _reproject_full(array_in, wcs_in, wcs_out, shape_out=shape_out,
order=order, array_out=output_array,
return_footprint=return_footprint, roundtrip_coords=roundtrip_coords)
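
For reference, a minimal usage sketch of the new keywords (the file names and
block sizes below are hypothetical placeholders, not part of this change):

from astropy.io import fits
from reproject import reproject_interp

# 'input.fits' and 'reference.fits' are placeholder file names
hdu_in = fits.open('input.fits')[0]
ref_header = fits.getheader('reference.fits')

# Serial blocked reprojection: the output space is computed 250x250 pixels
# at a time, keeping peak memory usage low for large outputs
array, footprint = reproject_interp(hdu_in, ref_header, block_size=[250, 250])

# Parallel reprojection: with block_size left as None, the output is split
# into one horizontal strip per logical CPU, each handled by its own process
array, footprint = reproject_interp(hdu_in, ref_header, parallel=True)

# Parallel with an explicit process count and block size
array, footprint = reproject_interp(hdu_in, ref_header,
                                    parallel=2, block_size=[100, 100])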
86 changes: 86 additions & 0 deletions reproject/interpolation/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -581,3 +581,89 @@ def test_identity_with_offset(roundtrip_coords):
expected = np.pad(array_in, 1, 'constant', constant_values=np.nan)

assert_allclose(expected, array_out, atol=1e-10)


@pytest.mark.parametrize('parallel', [True, 2, False])
@pytest.mark.parametrize('block_size', [[10, 10], [500, 500], [500, 100], None])
def test_blocked_against_single(parallel, block_size):

# Ensure that when we break a reprojection down into multiple discrete blocks
# it has the same result as if all pixels were reprojected at once

hdu1 = fits.open(get_pkg_data_filename('galactic_center/gc_2mass_k.fits'))[0]
hdu2 = fits.open(get_pkg_data_filename('galactic_center/gc_msx_e.fits'))[0]
array_test = None
footprint_test = None

# the warnings import and ignore are needed to keep pytest happy when running
# with older versions of astropy which don't have this fix:
# https://github.com/astropy/astropy/pull/12844
# All the warning code should be removed once old versions are no longer in use.
# Using a context manager ensures the warnings are only ignored for the blocked function
import warnings
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=FITSFixedWarning)

# this filter is needed to avoid the following warning when np.as_strided() is
# called in wcs_utils.unbroadcast(); it only shows up with py3.8, numpy 1.17,
# astropy 4.0.*:
# DeprecationWarning: Numpy has detected that you (may be) writing to an array with
# overlapping memory from np.broadcast_arrays. If this is intentional
# set the WRITEABLE flag True or make a copy immediately before writing.
# We do call as_strided with writeable=True as it recommends, and the warning
# only shows up with the 10px test case, so this is assumed to be a numpy bug in
# the detection code that was fixed in later versions. The pixel values all
# still match in the end; the warning only surfaces because pytest clears the
# standard python warning filters by default, so the warnings are treated as
# the exceptions they're implemented on
if block_size == [10, 10]:
warnings.simplefilter('ignore', category=DeprecationWarning)

array_test, footprint_test = reproject_interp(hdu2, hdu1.header,
parallel=parallel, block_size=block_size)

array_reference, footprint_reference = reproject_interp(hdu2, hdu1.header,
parallel=False, block_size=None)

np.testing.assert_allclose(array_test, array_reference, equal_nan=True)
np.testing.assert_allclose(footprint_test, footprint_reference, equal_nan=True)


def test_blocked_corner_cases():

"""
When doing blocked there are a few checks designed to sanity clamp/preserve
values. Even though the blocking process only tiles in a 2d manner 3d information
about the image needs to be preserved and transformed correctly. Additonally
when automatically determining block size based on CPU cores zeros can appear on
machines where num_cores > x or y dim of output image. So make sure it correctly
functions when 0 block size goes in
"""

# Read in the input cube
hdu_in = fits.open(
get_pkg_data_filename('data/equatorial_3d.fits', package='reproject.tests'))[0]

# Define the output header - this should be the same for all versions of
# this test to make sure we can use a single reference file.
header_out = hdu_in.header.copy()
header_out['NAXIS1'] = 10
header_out['NAXIS2'] = 9
header_out['CTYPE1'] = 'GLON-SIN'
header_out['CTYPE2'] = 'GLAT-SIN'
header_out['CRVAL1'] = 163.16724
header_out['CRVAL2'] = -15.777405
header_out['CRPIX1'] = 6
header_out['CRPIX2'] = 5

array_reference = reproject_interp(hdu_in, header_out, return_footprint=False)

array_test = None

# same reason as test above for FITSFixedWarning
import warnings
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=FITSFixedWarning)

array_test = reproject_interp(hdu_in, header_out, parallel=True,
block_size=[0, 4], return_footprint=False)

np.testing.assert_allclose(array_test, array_reference, equal_nan=True, verbose=True)
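
As a worked sketch of the automatic block-size selection and clamping
exercised by this test (pick_block_size is a hypothetical helper name; the
logic mirrors the block in high_level.py above):

import os

def pick_block_size(shape_out, cpu_count=None):
    # When parallel is set but block_size is not, each process gets one
    # horizontal strip of the output
    cpu_count = cpu_count or os.cpu_count()
    block_size = list(shape_out)
    block_size[0] = shape_out[0] // cpu_count
    # Clamp zero-length dimensions back to the full output extent, which
    # happens when there are more cores than rows in the output
    for dim_idx in range(min(len(shape_out), 2)):
        if block_size[dim_idx] == 0:
            block_size[dim_idx] = shape_out[dim_idx]
    return block_size

print(pick_block_size((9, 10), cpu_count=16))      # [9, 10]: 9 // 16 == 0, so clamped
print(pick_block_size((1024, 2048), cpu_count=8))  # [128, 2048]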
176 changes: 175 additions & 1 deletion reproject/utils.py
@@ -1,9 +1,12 @@
from concurrent import futures

import astropy.nddata
import numpy as np
from astropy.io import fits
from astropy.io.fits import CompImageHDU, HDUList, Header, ImageHDU, PrimaryHDU
from astropy.wcs import WCS
from astropy.wcs.wcsapi import BaseHighLevelWCS
from astropy.wcs.wcsapi import BaseHighLevelWCS, SlicedLowLevelWCS
from astropy.wcs.wcsapi.high_level_wcs_wrapper import HighLevelWCSWrapper

__all__ = ['parse_input_data', 'parse_input_shape', 'parse_input_weights',
'parse_output_projection']
@@ -133,3 +136,174 @@ def parse_output_projection(output_projection, shape_out=None, output_array=None
raise ValueError("The shape of the output image should not be an "
"empty tuple")
return wcs_out, shape_out


def _block(reproject_func, array_in, wcs_in, wcs_out_sub, shape_out, i_range, j_range,
return_footprint):
"""
Implementation function that handles reprojecting subsets blocks of pixels
from an input image and holds metadata about where to reinsert when done.
Parameters
----------
reproject_func
One the existing reproject functions implementing a reprojection algorithm
that that will be used be used to perform reprojection
array_in
Data following the same format as expected by underlying reproject_func,
expected to `~numpy.ndarray` when used from reproject_blocked()
wcs_in: `~astropy.wcs.WCS`
WCS object corresponding to array_in
wcs_out_sub:
Output WCS image will be projected to. Normally will correspond to subset of
total output image when used by repoject_blocked()
shape_out:
Passed to reproject_func() alongside WCS out to determine image size
i_range:
Passed through unmodified, used to determine where to reinsert block
j_range:
Passed through unmodified, used to determine where to reinsert block
"""

result = reproject_func(array_in, wcs_in, wcs_out_sub,
shape_out=shape_out, return_footprint=return_footprint)

res_arr = None
res_fp = None

if return_footprint:
res_arr, res_fp = result
else:
res_arr = result

return {'i': i_range, 'j': j_range, 'res_arr': res_arr, 'res_fp': res_fp}


def reproject_blocked(reproject_func, array_in, wcs_in, shape_out, wcs_out, block_size,
output_array=None,
return_footprint=True, output_footprint=None, parallel=True):
"""
Implementation function that handles reprojecting subsets blocks of pixels
from an input image and holds metadata about where to reinsert when done.
Parameters
----------
reproject_func
One the existing reproject functions implementing a reprojection algorithm
that that will be used be used to perform reprojection
array_in
Data following the same format as expected by underlying reproject_func,
expected to `~numpy.ndarray` when used from reproject_blocked()
wcs_in: `~astropy.wcs.WCS`
WCS object corresponding to array_in
shape_out: tuple
Passed to reproject_func() alongside WCS out to determine image size
wcs_out: `~astropy.wcs.WCS`
Output WCS image will be projected to. Normally will correspond to subset of
total output image when used by repoject_blocked()
block_size: tuple
The size of blocks in terms of output array pixels that each block will handle
reprojecting. Extending out from (0,0) coords positively, block sizes
are clamped to output space edges when a block would extend past edge
output_array : None or `~numpy.ndarray`
An array in which to store the reprojected data. This can be any numpy
array including a memory map, which may be helpful when dealing with
extremely large files.
return_footprint : bool
Whether to return the footprint in addition to the output array.
output_footprint : None or `~numpy.ndarray`
An array in which to store the footprint of reprojected data. This can be
any numpy array including a memory map, which may be helpful when dealing with
extremely large files.
parallel : bool or int
Flag for the parallel implementation. If ``True``, a parallel implementation
is chosen and the number of processes is set automatically to the number of
logical CPUs detected on the machine. If ``False``, a serial implementation
is chosen. If the flag is a positive integer ``n`` greater than one, a
parallel implementation using ``n`` processes is chosen.
"""

if output_array is None:
output_array = np.zeros(shape_out, dtype=float)
if output_footprint is None and return_footprint:
output_footprint = np.zeros(shape_out, dtype=float)

# set up variables needed for multiprocessing if required
proc_pool = None
blocks_futures = []

if parallel or type(parallel) is int:
if type(parallel) is int:
if parallel <= 0:
raise ValueError("The number of processors to use must be strictly positive")
else:
proc_pool = futures.ProcessPoolExecutor(max_workers=parallel)
else:
proc_pool = futures.ProcessPoolExecutor()

# This will iterate over the output space, generating slices of its
# WCS and either processing and reinserting them immediately, or, for the
# parallel implementation, submitting them to workers and then waiting to
# reinsert each block as the workers complete it
for imin in range(0, output_array.shape[0], block_size[0]):
imax = min(imin + block_size[0], output_array.shape[0])
for jmin in range(0, output_array.shape[1], block_size[1]):
jmax = min(jmin + block_size[1], output_array.shape[1])
shape_out_sub = (imax - imin, jmax - jmin)
# if the output has more than two dims, just append them on the end of the
# shape so it still matches the base WCS
for dim in range(2, len(output_array.shape)):
shape_out_sub = shape_out_sub + (output_array.shape[dim],)

slices = [slice(imin, imax), slice(jmin, jmax)]
wcs_out_sub = HighLevelWCSWrapper(SlicedLowLevelWCS(wcs_out, slices=slices))

if proc_pool is None:
# if sequential, process the block and reinsert it into the main array immediately
completed_block = _block(reproject_func=reproject_func, array_in=array_in,
wcs_in=wcs_in,
wcs_out_sub=wcs_out_sub, shape_out=shape_out_sub,
return_footprint=return_footprint,
j_range=(jmin, jmax), i_range=(imin, imax))

output_array[imin:imax, jmin:jmax] = completed_block['res_arr'][:]
if return_footprint:
output_footprint[imin:imax, jmin:jmax] = completed_block['res_fp'][:]

else:
# if parallel, just submit all work items and move on to waiting for them to complete
future = proc_pool.submit(_block, reproject_func=reproject_func, array_in=array_in,
wcs_in=wcs_in, wcs_out_sub=wcs_out_sub,
shape_out=shape_out_sub,
return_footprint=return_footprint, j_range=(jmin, jmax),
i_range=(imin, imax))
blocks_futures.append(future)

# If a parallel implementation is being used, the blocks have not been
# reassembled yet; do so now, as each block call completes in the worker
# processes
if proc_pool is not None:
completed_future_count = 0
for completed_future in futures.as_completed(blocks_futures):
completed_block = completed_future.result()
i_range = completed_block['i']
j_range = completed_block['j']
output_array[i_range[0]:i_range[1], j_range[0]:j_range[1]] = (
completed_block['res_arr'][:])

if return_footprint:
footprint_block = completed_block['res_fp'][:]
output_footprint[i_range[0]:i_range[1], j_range[0]:j_range[1]] = footprint_block

completed_future_count += 1
idx = blocks_futures.index(completed_future)
# ensure memory used by returned data is freed
completed_future._result = None
del blocks_futures[idx], completed_future
proc_pool.shutdown()
del blocks_futures

if return_footprint:
return output_array, output_footprint
else:
return output_array
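
The key step above is slicing the output WCS one block at a time; a minimal
standalone sketch of that mechanism, assuming wcs_out implements the astropy
low-level WCS API (slice_output_wcs is a hypothetical helper name):

from astropy.wcs.wcsapi import SlicedLowLevelWCS
from astropy.wcs.wcsapi.high_level_wcs_wrapper import HighLevelWCSWrapper

def slice_output_wcs(wcs_out, imin, imax, jmin, jmax):
    # Pixel (0, 0) of the sliced WCS corresponds to pixel (imin, jmin) of
    # the full output, so each block can be reprojected independently and
    # reinserted at [imin:imax, jmin:jmax] afterwards
    slices = [slice(imin, imax), slice(jmin, jmax)]
    return HighLevelWCSWrapper(SlicedLowLevelWCS(wcs_out, slices=slices))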
