Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify blocked reprojection implementation by using dask and improve efficiency of parallel reprojection #314

Merged
merged 22 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion reproject/interpolation/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def _reproject_full(
array_out=None,
return_footprint=True,
roundtrip_coords=True,
output_footprint=None,
):
"""
Reproject n-dimensional data to a new projection using interpolation.
Expand All @@ -100,6 +101,9 @@ def _reproject_full(
if array_out is None:
array_out = np.empty(shape_out)

if output_footprint is None:
output_footprint = np.empty(shape_out)

array_out_loopable = array_out
if len(array.shape) == wcs_in.low_level_wcs.pixel_n_dim:
# We don't need to broadcast the transformation over any extra
Expand Down Expand Up @@ -149,6 +153,7 @@ def _reproject_full(
# also contains this data and has the user's desired output shape.

if return_footprint:
return array_out, (~np.isnan(array_out)).astype(float)
output_footprint[:] = (~np.isnan(array_out)).astype(float)
return array_out, output_footprint
else:
return array_out
18 changes: 3 additions & 15 deletions reproject/interpolation/high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from astropy.utils import deprecated_renamed_argument

from ..utils import parse_input_data, parse_output_projection, reproject_blocked
from ..utils import _reproject_blocked, parse_input_data, parse_output_projection
from .core import _reproject_full

__all__ = ["reproject_interp"]
Expand Down Expand Up @@ -116,20 +116,7 @@ def reproject_interp(

# if either of these are not default, it means a blocked method must be used
if block_size is not None or parallel is not False:
# if parallel is set but block size isn't, we'll choose
# block size so each thread gets one block each
if parallel is not False and block_size is None:
block_size = list(shape_out)
# each thread gets an equal sized strip of output area to process
block_size[-2] = shape_out[-2] // os.cpu_count()

# given we have cases where modern system have many cpu cores some sanity clamping is
# to avoid 0 length block sizes when num_cpu_cores is greater than the side of the image
for dim_idx in range(min(len(shape_out), 2)):
if block_size[dim_idx] == 0:
block_size[dim_idx] = shape_out[dim_idx]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed this to instead let dask decide how to chunk the array, though we might want to provide a keyword argument that specifies the typical number of elements in a chunk.


return reproject_blocked(
return _reproject_blocked(
_reproject_full,
array_in=array_in,
wcs_in=wcs_in,
Expand All @@ -151,4 +138,5 @@ def reproject_interp(
array_out=output_array,
return_footprint=return_footprint,
roundtrip_coords=roundtrip_coords,
output_footprint=output_footprint,
)
113 changes: 48 additions & 65 deletions reproject/interpolation/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,9 +718,11 @@ def test_blocked_broadcast_reprojection(input_extra_dims, output_shape, parallel


@pytest.mark.parametrize("parallel", [True, 2, False])
@pytest.mark.parametrize("block_size", [[40, 40], [500, 500], [500, 100], None])
@pytest.mark.parametrize("block_size", [[500, 500], [500, 100], None])
@pytest.mark.parametrize("return_footprint", [False, True])
@pytest.mark.parametrize("existing_outputs", [False, True])
@pytest.mark.remote_data
def test_blocked_against_single(parallel, block_size):
def test_blocked_against_single(parallel, block_size, return_footprint, existing_outputs):
# Ensure when we break a reprojection down into multiple discrete blocks
# it has the same result as if all pixels where reprejcted at once

Expand All @@ -729,6 +731,19 @@ def test_blocked_against_single(parallel, block_size):
array_test = None
footprint_test = None

shape_out = (720, 721)

if existing_outputs:
output_array_test = np.zeros(shape_out)
output_footprint_test = np.zeros(shape_out)
output_array_reference = np.zeros(shape_out)
output_footprint_reference = np.zeros(shape_out)
else:
output_array_test = None
output_footprint_test = None
output_array_reference = None
output_footprint_reference = None

# the warning import and ignore is needed to keep pytest happy when running with
# older versions of astropy which don't have this fix:
# https://github.com/astropy/astropy/pull/12844
Expand All @@ -738,72 +753,40 @@ def test_blocked_against_single(parallel, block_size):

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FITSFixedWarning)

# this one is needed to avoid the following warning from when the np.as_strided() is
# called in wcs_utils.unbroadcast(), only shows up with py3.8, numpy1.17, astropy 4.0.*:
# DeprecationWarning: Numpy has detected that you (may be) writing to an array with
# overlapping memory from np.broadcast_arrays. If this is intentional
# set the WRITEABLE flag True or make a copy immediately before writing.
# We do call as_strided with writeable=True as it recommends and only shows up with the 10px
# testcase so assuming a numpy bug in the detection code which was fixed in later version.
# The pixel values all still match in the end, only shows up due to pytest clearing
# the standard python warning filters by default and failing as the warnings are now
# treated as the exceptions they're implemented on
if block_size == [10, 10]:
warnings.simplefilter("ignore", category=DeprecationWarning)

array_test, footprint_test = reproject_interp(
hdu2, hdu1.header, parallel=parallel, block_size=block_size
result_test = reproject_interp(
hdu2,
hdu1.header,
parallel=parallel,
block_size=block_size,
return_footprint=return_footprint,
output_array=output_array_test,
output_footprint=output_footprint_test,
)

array_reference, footprint_reference = reproject_interp(
hdu2, hdu1.header, parallel=False, block_size=None
result_reference = reproject_interp(
hdu2,
hdu1.header,
parallel=False,
block_size=None,
return_footprint=return_footprint,
output_array=output_array_reference,
output_footprint=output_footprint_reference,
)

np.testing.assert_allclose(array_test, array_reference, equal_nan=True)
np.testing.assert_allclose(footprint_test, footprint_reference, equal_nan=True)


@pytest.mark.remote_data
def test_blocked_corner_cases():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is no longer relevant if we don't try and set the chunk size ourselves.

"""
When doing blocked there are a few checks designed to sanity clamp/preserve
values. Even though the blocking process only tiles in a 2d manner 3d information
about the image needs to be preserved and transformed correctly. Additonally
when automatically determining block size based on CPU cores zeros can appear on
machines where num_cores > x or y dim of output image. So make sure it correctly
functions when 0 block size goes in
"""

# Read in the input cube
hdu_in = fits.open(get_pkg_data_filename("data/equatorial_3d.fits", package="reproject.tests"))[
0
]

# Define the output header - this should be the same for all versions of
# this test to make sure we can use a single reference file.
header_out = hdu_in.header.copy()
header_out["NAXIS1"] = 10
header_out["NAXIS2"] = 9
header_out["CTYPE1"] = "GLON-SIN"
header_out["CTYPE2"] = "GLAT-SIN"
header_out["CRVAL1"] = 163.16724
header_out["CRVAL2"] = -15.777405
header_out["CRPIX1"] = 6
header_out["CRPIX2"] = 5

array_reference = reproject_interp(hdu_in, header_out, return_footprint=False)

array_test = None

# same reason as test above for FITSFixedWarning
import warnings

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FITSFixedWarning)
if return_footprint:
array_test, footprint_test = result_test
array_reference, footprint_reference = result_reference
else:
array_test = result_test
array_reference = result_reference

array_test = reproject_interp(
hdu_in, header_out, parallel=True, block_size=[0, 4], return_footprint=False
)
if existing_outputs:
assert array_test is output_array_test
assert array_reference is output_array_reference
if return_footprint:
assert footprint_test is output_footprint_test
assert footprint_reference is output_footprint_reference

np.testing.assert_allclose(array_test, array_reference, equal_nan=True, verbose=True)
np.testing.assert_allclose(array_test, array_reference, equal_nan=True)
if return_footprint:
np.testing.assert_allclose(footprint_test, footprint_reference, equal_nan=True)
Loading