Skip to content

Commit

Permalink
Use tree-reduction in mosaic (#126)
Browse files Browse the repository at this point in the history
This should hopefully result in lower cluster memory usage, since intermediate results can get processed sooner (there's more parallelism, so less vulnerable to root task overproduction)
  • Loading branch information
gjoseph92 committed Feb 2, 2022
1 parent ed2b9f7 commit 92df7f9
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 8 deletions.
12 changes: 12 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import os
from hypothesis import settings
import pytest

# Ask pytest to apply its assertion rewriting to dask's testing-helper modules,
# so assertions that fail inside those helpers are reported with full detail.
# NOTE(review): this must run before those modules are first imported — confirm.
pytest.register_assert_rewrite(
    "dask.array.utils", "dask.dataframe.utils", "dask.bag.utils"
)

# Hypothesis profiles:
#   "default": the stock settings, plus `print_blob=True`.
#   "ci":      1000 examples per test, derandomized.
settings.register_profile("default", settings.default, print_blob=True)
settings.register_profile("ci", max_examples=1000, derandomize=True)

# The active profile is chosen by the HYPOTHESIS_PROFILE environment variable,
# falling back to "default" when it is unset.
settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "default"))
166 changes: 158 additions & 8 deletions stackstac/ops.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,166 @@
from typing import Hashable, Sequence, Union
from __future__ import annotations

from functools import partial
from typing import Hashable, Tuple, List, Union, Optional

import numpy as np
import dask.array as da
import xarray as xr


def _mosaic(arr, axis, reverse: bool = False, nodata: Union[int, float] = np.nan):
ax_length = arr.shape[axis]
def _mosaic_base(
arr: np.ndarray,
axis: int,
*,
reverse: bool,
nodata: Union[int, float],
keepdims: bool = False,
initial: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, bool]:
"""
Internal function to mosaic a NumPy array along an axis. Does some extra funky stuff to be useful with dask.
Parameters
----------
arr: array to mosaic, must be at least 1D
axis: axis to reduce
reverse: False = last on top, True = first on top
nodata: the value to treat as invalid
keepdims: whether to preserve the reduced axis as 1-length
initial:
Array to use as the initial layer, instead of the first index into ``arr``.
Used when chaning mosaics together.
Returns
-------
result, done
The mosaicked array, and a bool of whether mosaicing is complete (every pixel in ``result`` is valid)
"""
if isinstance(axis, tuple):
# `da.reduction` likes to turn int axes into tuples
assert len(axis) == 1, f"Multiple axes to mosaic not supported: {axis=}"
axis = axis[0]

ax_length = arr.shape[axis] if np.ndim(arr) else 1

# "normal" means last -> first, "reversed" means first -> last,
# so we apply `reversed` in the opposite case you'd expect
indices = iter(range(ax_length) if reverse else reversed(range(ax_length)))
out = np.take(arr, next(indices), axis=axis)

if initial is None:
i = next(indices)
out = np.take(arr, [i] if keepdims else i, axis=axis)
else:
out = initial

# TODO implement this with Numba, would be much faster
done = False
for i in indices:
layer = np.take(arr, i, axis=axis)
where = np.isnan(out) if np.isnan(nodata) else out == nodata
if not where.any():
# Short-circuit when mosaic is complete (no invalid pixels left)
done = True
break
layer = np.take(arr, [i] if keepdims else i, axis=axis)
out = np.where(where, layer, out)

# Final `where` may have completed the mosaic, but we wouldn't have checked
if not done:
where = np.isnan(out) if np.isnan(nodata) else out == nodata
done = not where.any()

return out, done


def _mosaic_np(
    arr: np.ndarray,
    axis: int,
    *,
    reverse: bool,
    nodata: Union[int, float],
    keepdims: bool = False,
) -> np.ndarray:
    """
    Mosaic a NumPy array along ``axis``, returning only the mosaicked array.

    Thin wrapper around ``_mosaic_base`` that forwards every argument and
    discards the ``done`` flag from its ``(result, done)`` return value.
    """
    mosaicked, _done = _mosaic_base(
        arr, axis, reverse=reverse, nodata=nodata, keepdims=keepdims
    )
    return mosaicked


def _mosaic_dask_aggregate(
arrs: Union[List[np.ndarray], np.ndarray],
axis: Union[int, Tuple[int]],
keepdims: bool,
*,
reverse: bool,
nodata: Union[int, float],
) -> np.ndarray | np.ndarray:
"""
Mosaic a list of NumPy arrays (or a single one) without concatenating them all first.
Avoiding concatenation lets us avoid unnecessary copies and memory spikes.
"""
if not isinstance(arrs, list):
arrs = [arrs]

if isinstance(axis, tuple):
# `da.reduction` likes to turn int axes into tuples
assert len(axis) == 1, f"Multiple axes to mosaic not supported: {axis=}"
axis = axis[0]

out: Optional[np.ndarray] = None
for arr in arrs if reverse else arrs[::-1]:
# ^ Remember, `reverse` is backwards of what you'd think
if out is None and (arr.ndim == 0 or arr.shape[axis] <= 1):
# There's nothing to mosaic in this chunk (it's effectively 1-length along `axis` already).
# Skip mosaicing, just drop the axis if we're supposed to.
out = np.take(arr, 0, axis=axis) if not keepdims else arr
else:
# Multiple entries along the axis: mosaic them, using any
# mosaic we've already built up as a starting point.
out, done = _mosaic_base(
arr,
axis,
reverse=reverse,
nodata=nodata,
keepdims=keepdims,
initial=out,
)
if done:
break

assert out is not None, "Cannot mosaic zero arrays"
return out


def _mosaic_dask(
    arr: da.Array,
    axis: int,
    *,
    reverse: bool,
    nodata: Union[int, float],
    split_every: Union[None, int],
) -> da.Array:
    """
    Tree-reduction-based mosaic for Dask arrays.

    Builds a ``da.reduction`` in which ``_mosaic_np`` is the per-chunk function
    and ``_mosaic_dask_aggregate`` is the aggregate function, both bound to the
    same ``reverse`` and ``nodata``. ``concatenate=False`` is passed so chunks
    are not concatenated before reaching the aggregate step. The result keeps
    the dtype of ``arr`` and has one fewer dimension.
    """
    shared_kwargs = dict(reverse=reverse, nodata=nodata)
    chunk_func = partial(_mosaic_np, **shared_kwargs)
    aggregate_func = partial(_mosaic_dask_aggregate, **shared_kwargs)
    # The reduced axis is dropped, so the output meta has one fewer dimension.
    reduced_meta = da.utils.meta_from_array(arr, ndim=arr.ndim - 1)

    return da.reduction(
        arr,
        chunk_func,
        aggregate_func,
        axis=axis,
        keepdims=False,
        dtype=arr.dtype,
        split_every=split_every,
        name="mosaic",
        meta=reduced_meta,
        concatenate=False,
    )


def mosaic(
arr: xr.DataArray,
dim: Union[None, Hashable, Sequence[Hashable]] = None,
axis: Union[None, int, Sequence[int]] = 0,
dim: Union[None, Hashable] = None,
axis: Union[None, int] = 0,
*,
reverse: bool = False,
nodata: Union[int, float] = np.nan,
split_every: Union[None, int] = None,
):
"""
Flatten a dimension of a `~xarray.DataArray` by picking the first valid pixel.
Expand All @@ -49,6 +186,13 @@ def mosaic(
is used when the array has an integer or boolean dtype. Since NaN
cannot exist in those arrays, this indicates a different ``nodata``
value needs to be used.
split_every:
For Dask arrays: how many *chunks* to mosaic together at once.
The Dask default is 4. A higher number will mean a smaller graph,
but higher peak memory usage.
Ignored for NumPy arrays.
Returns
-------
Expand All @@ -60,8 +204,14 @@ def mosaic(
raise ValueError(
f"Cannot use {nodata=} when mosaicing a {arr.dtype} array, since {nodata} cannot exist in the array."
)

func = (
partial(_mosaic_dask, split_every=split_every)
if isinstance(arr.data, da.Array)
else _mosaic_np
)
return arr.reduce(
_mosaic,
func,
dim=dim,
axis=axis,
keep_attrs=True,
Expand Down
81 changes: 81 additions & 0 deletions stackstac/tests/test_mosaic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import Tuple
from hypothesis import given, assume, strategies as st
from hypothesis.extra import numpy as st_np
import pytest

import xarray as xr
import dask.array as da
from dask.array.utils import assert_eq
import numpy as np

from stackstac.ops import mosaic
from stackstac.testing import strategies as st_stc


@pytest.mark.parametrize("dask", [False, True])
def test_mosaic_basic(dask: bool):
    """Mosaic a small 3x4 array along axis 0, in both directions, with NumPy and Dask backing."""
    layers = np.array(
        [
            [np.nan, 1, 2, np.nan],
            [np.nan, 10, 20, 30],
            [np.nan, 100, 200, np.nan],
        ]
    )
    data = da.from_array(layers, chunks=1) if dask else layers
    xarr = xr.DataArray(data)

    # Remember, forward means "last item on top" (because last item == last date == newest)
    expected_fwd = np.array([np.nan, 100, 200, 30])
    fwd = mosaic(xarr, axis=0)
    assert_eq(fwd.data, expected_fwd, equal_nan=True)

    # Reversed: the first item along the axis wins wherever it is valid.
    expected_rev = np.array([np.nan, 1, 2, 30])
    rev = mosaic(xarr, axis=0, reverse=True)
    assert_eq(rev.data, expected_rev, equal_nan=True)


@pytest.mark.parametrize("dtype", [np.dtype("bool"), np.dtype("int"), np.dtype("uint")])
def test_mosaic_dtype_error(dtype: np.dtype):
    """With the default ``nodata``, ``mosaic`` raises ``ValueError`` for bool and integer arrays."""
    values = np.arange(3).astype(dtype)
    non_float = xr.DataArray(values)
    with pytest.raises(ValueError, match="Cannot use"):
        mosaic(non_float)


@given(
    st.data(),
    st_stc.raster_dtypes,
    st_np.array_shapes(max_dims=4, max_side=5),
    st.booleans(),
)
def test_fuzz_mosaic(
    data: st.DataObject, dtype: np.dtype, shape: Tuple[int, ...], reverse: bool
):
    """
    See if we can break mosaic.

    Not testing correctness much here, since that's hard to do without rewriting a mosaic implementation.
    Instead, the Dask-backed result is checked against the NumPy-backed result for the
    same array, axis, direction and ``nodata`` value, and the dtype is checked to be preserved.
    """
    # `np.where` doesn't seem to preserve endianness.
    # Even with our best efforts to add an `astype` at the end, this will fail (on little-endian systems at least?)
    # with big-endian dtypes.
    assume(dtype.byteorder != ">")

    # The array's fill value doubles as the `nodata` value passed to `mosaic` below.
    fill_value = data.draw(st_np.from_dtype(dtype), label="fill_value")
    arr = data.draw(st_np.arrays(dtype, shape, fill=st.just(fill_value)), label="arr")
    # Chunk shape has the same number of dimensions as `arr`; each side is at most the
    # largest side of `arr`.
    chunkshape = data.draw(
        st_np.array_shapes(
            min_dims=arr.ndim, max_dims=arr.ndim, max_side=max(arr.shape)
        ), label="chunkshape"
    )
    # Axis to mosaic over, drawn from [-ndim + 1, ndim - 1].
    axis = data.draw(st.integers(-arr.ndim + 1, arr.ndim - 1), label="axis")

    darr = da.from_array(arr, chunks=chunkshape)
    # `split_every` ranges from 1 up to the number of blocks along the mosaicked axis.
    split_every = data.draw(st.integers(1, darr.numblocks[axis]), label="split_every")
    xarr = xr.DataArray(darr)

    # Dask-backed mosaic.
    result = mosaic(
        xarr, axis=axis, reverse=reverse, nodata=fill_value, split_every=split_every
    )
    assert result.dtype == arr.dtype
    # NumPy-backed mosaic of the same data serves as the reference.
    result_np = mosaic(xr.DataArray(arr), axis=axis, reverse=reverse, nodata=fill_value)
    assert_eq(result.data, result_np.data, equal_nan=True)

0 comments on commit 92df7f9

Please sign in to comment.