Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dask quantity fixes #72

Merged
merged 7 commits into from
Nov 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog/72.bugfix.1.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fixed a bug where the way we dealt with `~astropy.units.Quantity` objects was inconsistent with
`~dask.array.Array` objects in newer versions of `~numpy`. The `pre_check_hook` optional keyword
argument has also been removed from `~sunkit_image.time_lag.time_lag` and `post_check_hook`
has been renamed to `array_check_hook` and now accepts two arguments.
2 changes: 2 additions & 0 deletions changelog/72.bugfix.2.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed a bug where a `~astropy.units.UnitConversionError` was thrown if a non-dimensionless
`~astropy.units.Quantity` object was input for the signal in `~sunkit_image.time_lag.cross_correlation`.
9 changes: 6 additions & 3 deletions examples/calculating_time_lags.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,13 @@ def gaussian_pulse(x, x0, sigma):
# In practice, these data cubes are often very large, sometimes many
# GB, such that doing operations like these on them can be prohibitively
# expensive. All of these operations can be parallelized and distributed
# easily by passing in the intensity cubes as Dask arrays.
# easily by passing in the intensity cubes as Dask arrays. Note that we
# strip the units off of our signal arrays before creating the Dask arrays
# from them, as creating a Dask array from an `~astropy.units.Quantity` may
# result in undefined behavior.

s_a = dask.array.from_array(s_a, chunks=s_a.shape[:1] + (5, 5))
s_b = dask.array.from_array(s_b, chunks=s_b.shape[:1] + (5, 5))
s_a = dask.array.from_array(s_a.value, chunks=s_a.shape[:1] + (5, 5))
s_b = dask.array.from_array(s_b.value, chunks=s_b.shape[:1] + (5, 5))
tl_map = time_lag(s_a, s_b, time)
print(tl_map)

Expand Down
20 changes: 20 additions & 0 deletions sunkit_image/tests/test_time_lag.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,26 @@ def test_dask_numpy_consistent(shape_in):
assert u.allclose(max_cc, max_cc_dask.compute(), rtol=0.0, atol=None)


@pytest.mark.parametrize(
    "shape_in",
    [
        ((20, 5, 5)),
        ((100, 10)),
        ((1000, 1)),
    ],
)
def test_quantity_numpy_consistent(shape_in):
    # Quantity inputs for the signals should be accepted and should produce
    # results equivalent to the same computation on bare numpy arrays.
    signal_a = np.random.rand(*shape_in) * u.ct / u.s
    signal_b = np.random.rand(*shape_in) * u.ct / u.s
    times = np.linspace(0, 1, shape_in[0]) * u.s
    for func in (time_lag, max_cross_correlation):
        from_arrays = func(signal_a.value, signal_b.value, times)
        from_quantities = func(signal_a, signal_b, times)
        assert u.allclose(from_arrays, from_quantities, rtol=0.0, atol=None)


@pytest.mark.parametrize(
"shape_a,shape_b,lags,exception",
[
Expand Down
75 changes: 44 additions & 31 deletions sunkit_image/time_lag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@

import astropy.units as u

# Dask is an optional dependency: record at import time whether it is
# available so the rest of the module can branch on DASK_INSTALLED.
try:
    import dask.array  # imported here so that Dask is not a hard requirement

    DASK_INSTALLED = True
except ImportError:
    DASK_INSTALLED = False

__all__ = [
"cross_correlation",
"get_lags",
Expand Down Expand Up @@ -101,11 +109,16 @@ def cross_correlation(signal_a, signal_b, lags: u.s):
# Reverse the first timeseries
signal_a = signal_a[::-1]
# Normalize by mean and standard deviation
fill_value = signal_a.max()
std_a = signal_a.std(axis=0)
std_a = np.where(std_a == 0, 1, std_a) # avoid dividing by zero
# Avoid dividing by zero by replacing with some non-zero dummy value. Note that
# what this value is does not matter as it will be multiplied by zero anyway
# since std_dev == 0 any place that signal - signal_mean == 0. We use the max
# of the signal as the fill_value in order to support Quantities.
std_a = np.where(std_a == 0, fill_value, std_a)
v_a = (signal_a - signal_a.mean(axis=0)[np.newaxis]) / std_a[np.newaxis]
std_b = signal_b.std(axis=0)
std_b = np.where(std_b == 0, 1, std_b)
std_b = np.where(std_b == 0, fill_value, std_b)
v_b = (signal_b - signal_b.mean(axis=0)[np.newaxis]) / std_b[np.newaxis]
# Cross-correlation is inverse of product of FFTS (by convolution theorem)
fft_a = np.fft.rfft(v_a, axis=0, n=lags.shape[0])
Expand All @@ -118,28 +131,32 @@ def cross_correlation(signal_a, signal_b, lags: u.s):
def _get_bounds_indices(lags, bounds):
# The start and stop indices are computed in this way
# because Dask does not like "fancy" multidimensional indexing
start = 0
stop = lags.shape[0] + 1
if bounds is not None:
(indices,) = np.where(np.logical_and(lags >= bounds[0], lags <= bounds[1]))
start = indices[0]
stop = indices[-1] + 1
else:
start = 0
stop = lags.shape[0] + 1
return start, stop


def _dask_check(signal, lags):
# This is to avoid having to specify time as a Dask array so that it
# can be specified as a quantity and so that the time lag can be
# returned as a Dask quantity and not computed eagerly
try:
import dask.array # do this here so that Dask is not a hard requirement
except ImportError:
return lags
if isinstance(signal, dask.array.Array):
return dask.array.from_array(lags, chunks=lags.shape)
def _dask_check(lags, indices):
# In order for the time lag to be returned as a Dask array, the lags array,
# which is, in general, a Quantity, must also be a Dask array.
# This function is needed for two reasons:
# 1. astropy.units.Quantity and dask.array.Array do not play nice together, and their behavior
# seems to vary from one numpy version to the next. To avoid this ill-defined
# behavior, we will do all of our Dask-ing on a Dask array created from the
# bare numpy array and re-attach the units at the end.
# 2. Dask arrays do not like "fancy" multidimensional indexing. Therefore, we must
# flatten the indices first and then reshape the time lag array in order to
# preserve the laziness of the array evaluation.
if DASK_INSTALLED and isinstance(indices, dask.array.Array):
lags_lazy = dask.array.from_array(lags.value, chunks=lags.shape)
lags_unit = lags.unit
return lags_lazy[indices.flatten()].reshape(indices.shape) * lags_unit
else:
return lags
return lags[indices]


@u.quantity_input
Expand All @@ -158,7 +175,7 @@ def time_lag(signal_a, signal_b, time: u.s, lag_bounds: (u.s, None) = None, **kw
where :math:`\mathcal{C}_{AB}` is the cross-correlation as a function of
lag (computed in :func:`cross_correlation`). Qualitatively, this can be
thought of as how much `signal_a` needs to be shifted in time to best
"match" `signal_b`. Note that the sign of :math:`\\tau_{AB}` is determined
"match" `signal_b`. Note that the sign of :math:`\tau_{AB}`` is determined
by the ordering of the two signals such that,

.. math::
Expand All @@ -181,15 +198,15 @@ def time_lag(signal_a, signal_b, time: u.s, lag_bounds: (u.s, None) = None, **kw

Other Parameters
----------------
pre_check_hook : function
Function to apply to `lags` array prior to selecting maximum lags. This
is usful when `signal_a` and `signal_b` are of a type besides `~numpy.ndarray`.
This function should accept `signal_a` and `lags` and return an array that
looks like `lags`.
post_check_hook : function
array_check_hook : function
Function to apply to the resulting time lag result. This should take in the
result of the time lag selection and return something that an array that looks
like the time lag selection.
`lags` array and the indices that specify the location of the maximum of the
cross-correlation and return an array that has used those indices to select
the `lags` which maximize the cross-correlation. As an example, if `lags`
and `indices` are both `~numpy.ndarray` objects, this would just return
`lags[indices]`. It is probably only necessary to specify this if you
are working with arrays that are something other than a `~numpy.ndarray`
or `~dask.array.Array` object.

Returns
-------
Expand All @@ -208,16 +225,12 @@ def time_lag(signal_a, signal_b, time: u.s, lag_bounds: (u.s, None) = None, **kw
ApJ, 753, 35, 2012
(https://doi.org/10.1088/0004-637X/753/1/35)
"""
pre_check = kwargs.get("pre_check_hook", _dask_check)
post_check = kwargs.get("post_check_hook", lambda x: x)
array_check = kwargs.get("array_check_hook", _dask_check)
lags = get_lags(time)
cc = cross_correlation(signal_a, signal_b, lags)
start, stop = _get_bounds_indices(lags, lag_bounds)
i_max_cc = cc[start:stop].argmax(axis=0)
# The flatten + reshape is needed here because Dask does not like
# "fancy" multidimensional indexing
lags = pre_check(signal_a, lags)
return post_check(lags[start:stop][i_max_cc.flatten()].reshape(i_max_cc.shape))
return array_check(lags[start:stop], i_max_cc)


@u.quantity_input
Expand Down