From 92df7f9d2d47684f78a7aaeb847531b3a47a3dde Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 1 Feb 2022 22:43:24 -0700 Subject: [PATCH] Use tree-reduction in mosaic (#126) This should hopefully result in lower cluster memory usage, since intermediate results can get processed sooner (there's more parallelism, so less vulnerable to root task overproduction) --- conftest.py | 12 +++ stackstac/ops.py | 166 +++++++++++++++++++++++++++++++-- stackstac/tests/test_mosaic.py | 81 ++++++++++++++++ 3 files changed, 251 insertions(+), 8 deletions(-) create mode 100644 conftest.py create mode 100644 stackstac/tests/test_mosaic.py diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..55049a3 --- /dev/null +++ b/conftest.py @@ -0,0 +1,12 @@ +import os +from hypothesis import settings +import pytest + +pytest.register_assert_rewrite( + "dask.array.utils", "dask.dataframe.utils", "dask.bag.utils" +) + +settings.register_profile("default", settings.default, print_blob=True) +settings.register_profile("ci", max_examples=1000, derandomize=True) + +settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "default")) diff --git a/stackstac/ops.py b/stackstac/ops.py index a8570ed..adc0f4d 100644 --- a/stackstac/ops.py +++ b/stackstac/ops.py @@ -1,29 +1,166 @@ -from typing import Hashable, Sequence, Union +from __future__ import annotations + +from functools import partial +from typing import Hashable, Tuple, List, Union, Optional + import numpy as np +import dask.array as da import xarray as xr -def _mosaic(arr, axis, reverse: bool = False, nodata: Union[int, float] = np.nan): - ax_length = arr.shape[axis] +def _mosaic_base( + arr: np.ndarray, + axis: int, + *, + reverse: bool, + nodata: Union[int, float], + keepdims: bool = False, + initial: Optional[np.ndarray] = None, +) -> Tuple[np.ndarray, bool]: + """ + Internal function to mosaic a NumPy array along an axis. Does some extra funky stuff to be useful with dask. + + Parameters + ---------- + arr: array to mosaic, must be at least 1D + axis: axis to reduce + reverse: False = last on top, True = first on top + nodata: the value to treat as invalid + keepdims: whether to preserve the reduced axis as 1-length + initial: + Array to use as the initial layer, instead of the first index into ``arr``. + Used when chaning mosaics together. + + Returns + ------- + result, done + The mosaicked array, and a bool of whether mosaicing is complete (every pixel in ``result`` is valid) + """ + if isinstance(axis, tuple): + # `da.reduction` likes to turn int axes into tuples + assert len(axis) == 1, f"Multiple axes to mosaic not supported: {axis=}" + axis = axis[0] + + ax_length = arr.shape[axis] if np.ndim(arr) else 1 # "normal" means last -> first, "reversed" means first -> last, # so we apply `reversed` in the opposite case you'd expect indices = iter(range(ax_length) if reverse else reversed(range(ax_length))) - out = np.take(arr, next(indices), axis=axis) + if initial is None: + i = next(indices) + out = np.take(arr, [i] if keepdims else i, axis=axis) + else: + out = initial + + # TODO implement this with Numba, would be much faster + done = False for i in indices: - layer = np.take(arr, i, axis=axis) where = np.isnan(out) if np.isnan(nodata) else out == nodata + if not where.any(): + # Short-circuit when mosaic is complete (no invalid pixels left) + done = True + break + layer = np.take(arr, [i] if keepdims else i, axis=axis) out = np.where(where, layer, out) + + # Final `where` may have completed the mosaic, but we wouldn't have checked + if not done: + where = np.isnan(out) if np.isnan(nodata) else out == nodata + done = not where.any() + + return out, done + + +def _mosaic_np( + arr: np.ndarray, + axis: int, + *, + reverse: bool, + nodata: Union[int, float], + keepdims: bool = False, +) -> np.ndarray: + "Mosaic a NumPy array, without returning ``done``" + return _mosaic_base(arr, axis, reverse=reverse, nodata=nodata, keepdims=keepdims)[0] + + +def _mosaic_dask_aggregate( + arrs: Union[List[np.ndarray], np.ndarray], + axis: Union[int, Tuple[int]], + keepdims: bool, + *, + reverse: bool, + nodata: Union[int, float], +) -> np.ndarray | np.ndarray: + """ + Mosaic a list of NumPy arrays (or a single one) without concatenating them all first. + + Avoiding concatenation lets us avoid unnecessary copies and memory spikes. + """ + if not isinstance(arrs, list): + arrs = [arrs] + + if isinstance(axis, tuple): + # `da.reduction` likes to turn int axes into tuples + assert len(axis) == 1, f"Multiple axes to mosaic not supported: {axis=}" + axis = axis[0] + + out: Optional[np.ndarray] = None + for arr in arrs if reverse else arrs[::-1]: + # ^ Remember, `reverse` is backwards of what you'd think + if out is None and (arr.ndim == 0 or arr.shape[axis] <= 1): + # There's nothing to mosaic in this chunk (it's effectively 1-length along `axis` already). + # Skip mosaicing, just drop the axis if we're supposed to. + out = np.take(arr, 0, axis=axis) if not keepdims else arr + else: + # Multiple entries along the axis: mosaic them, using any + # mosaic we've already built up as a starting point. + out, done = _mosaic_base( + arr, + axis, + reverse=reverse, + nodata=nodata, + keepdims=keepdims, + initial=out, + ) + if done: + break + + assert out is not None, "Cannot mosaic zero arrays" return out +def _mosaic_dask( + arr: da.Array, + axis: int, + *, + reverse: bool, + nodata: Union[int, float], + split_every: Union[None, int], +) -> da.Array: + "Tree-reduction-based mosaic for Dask arrays." + return da.reduction( + arr, + partial(_mosaic_np, reverse=reverse, nodata=nodata), + partial(_mosaic_dask_aggregate, reverse=reverse, nodata=nodata), + axis=axis, + keepdims=False, + dtype=arr.dtype, + split_every=split_every, + name="mosaic", + meta=da.utils.meta_from_array(arr, ndim=arr.ndim - 1), + concatenate=False, + ) + + def mosaic( arr: xr.DataArray, - dim: Union[None, Hashable, Sequence[Hashable]] = None, - axis: Union[None, int, Sequence[int]] = 0, + dim: Union[None, Hashable] = None, + axis: Union[None, int] = 0, + *, reverse: bool = False, nodata: Union[int, float] = np.nan, + split_every: Union[None, int] = None, ): """ Flatten a dimension of a `~xarray.DataArray` by picking the first valid pixel. @@ -49,6 +186,13 @@ def mosaic( is used when the array has an integer or boolean dtype. Since NaN cannot exist in those arrays, this indicates a different ``nodata`` value needs to be used. + split_every: + For Dask arrays: how many *chunks* to mosaic together at once. + + The Dask default is 4. A higher number will mean a smaller graph, + but higher peak memory usage. + + Ignored for NumPy arrays. Returns ------- @@ -60,8 +204,14 @@ def mosaic( raise ValueError( f"Cannot use {nodata=} when mosaicing a {arr.dtype} array, since {nodata} cannot exist in the array." ) + + func = ( + partial(_mosaic_dask, split_every=split_every) + if isinstance(arr.data, da.Array) + else _mosaic_np + ) return arr.reduce( - _mosaic, + func, dim=dim, axis=axis, keep_attrs=True, diff --git a/stackstac/tests/test_mosaic.py b/stackstac/tests/test_mosaic.py new file mode 100644 index 0000000..9378f99 --- /dev/null +++ b/stackstac/tests/test_mosaic.py @@ -0,0 +1,81 @@ +from typing import Tuple +from hypothesis import given, assume, strategies as st +from hypothesis.extra import numpy as st_np +import pytest + +import xarray as xr +import dask.array as da +from dask.array.utils import assert_eq +import numpy as np + +from stackstac.ops import mosaic +from stackstac.testing import strategies as st_stc + + +@pytest.mark.parametrize("dask", [False, True]) +def test_mosaic_basic(dask: bool): + arr = np.array( + [ + [np.nan, 1, 2, np.nan], + [np.nan, 10, 20, 30], + [np.nan, 100, 200, np.nan], + ] + ) + if dask: + arr = da.from_array(arr, chunks=1) + + xarr = xr.DataArray(arr) + + fwd = mosaic(xarr, axis=0) + # Remember, forward means "last item on top" (because last item == last date == newest) + assert_eq(fwd.data, np.array([np.nan, 100, 200, 30]), equal_nan=True) + + rev = mosaic(xarr, axis=0, reverse=True) + assert_eq(rev.data, np.array([np.nan, 1, 2, 30]), equal_nan=True) + + +@pytest.mark.parametrize("dtype", [np.dtype("bool"), np.dtype("int"), np.dtype("uint")]) +def test_mosaic_dtype_error(dtype: np.dtype): + arr = xr.DataArray(np.arange(3).astype(dtype)) + with pytest.raises(ValueError, match="Cannot use"): + mosaic(arr) + + +@given( + st.data(), + st_stc.raster_dtypes, + st_np.array_shapes(max_dims=4, max_side=5), + st.booleans(), +) +def test_fuzz_mosaic( + data: st.DataObject, dtype: np.dtype, shape: Tuple[int, ...], reverse: bool +): + """ + See if we can break mosaic. + + Not testing correctness much here, since that's hard to do without rewriting a mosaic implementation. + """ + # `np.where` doesn't seem to preserve endianness. + # Even with our best efforts to add an `astype` at the end, this will fail (on little-endian systems at least?) + # with big-endian dtypes. + assume(dtype.byteorder != ">") + + fill_value = data.draw(st_np.from_dtype(dtype), label="fill_value") + arr = data.draw(st_np.arrays(dtype, shape, fill=st.just(fill_value)), label="arr") + chunkshape = data.draw( + st_np.array_shapes( + min_dims=arr.ndim, max_dims=arr.ndim, max_side=max(arr.shape) + ), label="chunkshape" + ) + axis = data.draw(st.integers(-arr.ndim + 1, arr.ndim - 1), label="axis") + + darr = da.from_array(arr, chunks=chunkshape) + split_every = data.draw(st.integers(1, darr.numblocks[axis]), label="split_every") + xarr = xr.DataArray(darr) + + result = mosaic( + xarr, axis=axis, reverse=reverse, nodata=fill_value, split_every=split_every + ) + assert result.dtype == arr.dtype + result_np = mosaic(xr.DataArray(arr), axis=axis, reverse=reverse, nodata=fill_value) + assert_eq(result.data, result_np.data, equal_nan=True)