Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch NDArrayImage dimension order to (band, y, x) #32

Merged
merged 4 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/sknnr_spatial/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,15 @@ def _load_rasters_to_dataset(


def _load_rasters_to_array(file_paths: list[Path]) -> NDArray:
"""Load a list of rasters as a numpy array."""
"""Load single-band rasters as a multi-band numpy array of shape (band, y, x)."""
arr = None
for path in file_paths:
with rasterio.open(path) as src:
band = src.read(1)
arr = band if arr is None else np.dstack((arr, band))
# Add a band dimension to the array to allow concatenation
band = band[np.newaxis, ...]

arr = band if arr is None else np.concatenate((arr, band), axis=0)

return arr

Expand Down Expand Up @@ -103,7 +106,8 @@ def load_swo_ecoplot(
Parameters
----------
as_dataset : bool, default=False
If True, return the image data as an `xarray.Dataset` instead of a Numpy array.
If True, return the image data as an `xarray.Dataset`. Otherwise, return a
Numpy array of shape (bands, y, x).
large_rasters : bool, default=False
If True, load the 2048x4096 version of the image data. Otherwise, load the
128x128 version.
Expand All @@ -115,8 +119,8 @@ def load_swo_ecoplot(
Returns
-------
tuple
Image data as either a numpy array or `xarray.Dataset`, and plot data as X and
y dataframes.
Image data as either a numpy array of shape (bands, y, x) or `xarray.Dataset`,
and plot data as X and y dataframes.

Notes
-----
Expand All @@ -135,7 +139,7 @@ def load_swo_ecoplot(
>>> from sknnr_spatial.datasets import load_swo_ecoplot
>>> X_image, X, y = load_swo_ecoplot()
>>> print(X_image.shape)
(128, 128, 18)
(18, 128, 128)

Load the 2048x4096 image data as an xarray Dataset:

Expand Down
206 changes: 97 additions & 109 deletions src/sknnr_spatial/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@


class _ImageChunk:
"""A chunk of an NDArray in shape (y, x, band)."""
"""
A chunk of an NDArray in shape (y, x, band).

Note that this dimension order is different from the (band, y, x) order used by
rasterio, rioxarray, and elsewhere in sknnr-spatial. This is because `_ImageChunk`
is called via `xr.apply_ufunc` which automatically swaps the core dimension to the
last axis, resulting in arrays of (y, x, band).
"""

band_dim = -1
aazuspan marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self, array: NDArray, nodata_vals: list[float] | None = None, nan_fill=0.0
Expand All @@ -24,7 +33,7 @@ def __init__(

def _mask_nodata(self, flat_image: NDArray) -> NDArray:
"""
Set NaNs in the flat image where NoData values are present.
Set NaNs in the flat (pixels, band) image where NoData values are present.
"""
# Skip allocating a mask if the image is not float and NoData wasn't given
if (
Expand All @@ -45,13 +54,13 @@ def _mask_nodata(self, flat_image: NDArray) -> NDArray:
mask |= self.flat_array == self.nodata_vals

# Set the mask where any band contains NoData
flat_image[mask.max(axis=-1)] = np.nan
flat_image[mask.max(axis=self.band_dim)] = np.nan

return flat_image

def _preprocess(self, array: NDArray, nan_fill: float = 0.0) -> NDArray:
"""Preprocess the chunk by flattening to (pixels, bands) and filling NaNs."""
flat = array.reshape(-1, array.shape[-1])
flat = array.reshape(-1, array.shape[self.band_dim])
if nan_fill is not None:
flat[np.isnan(flat)] = nan_fill

Expand Down Expand Up @@ -88,8 +97,8 @@ def apply(
class Image(Generic[ImageType], ABC):
"""A wrapper around a multi-band image"""

band_dim_name: str
band_dim: int
band_dim_name: str | None = None
band_dim: int = 0
band_names: NDArray

def __init__(self, image: ImageType, nodata_vals: NoDataType = None):
Expand Down Expand Up @@ -129,24 +138,79 @@ def _validate_nodata_vals(self, nodata_vals: NoDataType) -> NDArray | None:

return np.asarray(nodata_vals)

def apply_ufunc_across_bands(
    self,
    func: Callable[Concatenate[NDArray, P], NDArray],
    *,
    output_dims: list[list[str]],
    output_dtypes: list[np.dtype] | None = None,
    output_sizes: dict[str, int] | None = None,
    output_coords: dict[str, list[str | int]] | None = None,
    nan_fill: float = 0.0,
    mask_nodata: bool = True,
    **ufunc_kwargs,
) -> ImageType | tuple[ImageType]:
    """
    Apply a universal function to all bands of the image.

    The image is passed through `_preprocess_ufunc_input`, handed to
    `xr.apply_ufunc` with the band dimension as the core dimension, and each
    output is passed through `_postprocess_ufunc_output`. If the image is backed
    by a Dask array, the computation will be parallelized across spatial chunks.

    Parameters
    ----------
    func : callable
        Function forwarded to `_ImageChunk.apply` for each array that
        `xr.apply_ufunc` provides.
    output_dims : list[list[str]]
        Core dimensions of each output, passed as `output_core_dims`. More than
        one entry means `func` returns a tuple and a tuple is returned here.
    output_dtypes : list[np.dtype], optional
        Output dtypes, passed through to `xr.apply_ufunc`.
    output_sizes : dict[str, int], optional
        Sizes of the new output dimensions, passed through `dask_gufunc_kwargs`.
    output_coords : dict[str, list[str | int]], optional
        Coordinates handed to `_postprocess_ufunc_output`. If not given (or empty)
        and `output_sizes` is provided, sequential integer coordinates are
        generated for each output dimension.
    nan_fill : float, default=0.0
        Forwarded to `_ImageChunk` as its `nan_fill` argument.
    mask_nodata : bool, default=True
        Forwarded to `_ImageChunk.apply` as its `mask_nodata` argument.
    **ufunc_kwargs
        Additional keyword arguments forwarded to `_ImageChunk.apply`.

    Returns
    -------
    ImageType or tuple[ImageType]
        The postprocessed result, or a tuple of postprocessed results when
        `output_dims` describes more than one output.
    """
    n_outputs = len(output_dims)
    returns_tuple = n_outputs > 1

    if output_sizes is not None and not output_coords:
        # Default to sequential coordinates for each output dimension
        output_coords = {
            dim: list(range(size)) for dim, size in output_sizes.items()
        }

    def ufunc(x):
        chunk = _ImageChunk(x, nodata_vals=self.nodata_vals, nan_fill=nan_fill)
        return chunk.apply(
            func,
            returns_tuple=returns_tuple,
            mask_nodata=mask_nodata,
            **ufunc_kwargs,
        )

    result = xr.apply_ufunc(
        ufunc,
        self._preprocess_ufunc_input(self.image),
        dask="parallelized",
        input_core_dims=[[self.band_dim_name]],
        exclude_dims={self.band_dim_name},
        output_core_dims=output_dims,
        output_dtypes=output_dtypes,
        keep_attrs=True,
        dask_gufunc_kwargs={
            "output_sizes": output_sizes,
            "allow_rechunk": True,
        },
    )

    if returns_tuple:
        return tuple(
            self._postprocess_ufunc_output(x, output_coords=output_coords)
            for x in result
        )

    return self._postprocess_ufunc_output(result, output_coords=output_coords)

def _preprocess_ufunc_input(self, image: ImageType) -> ImageType:
"""
Preprocess the input of an applied ufunc. No-op unless overridden by subclasses.
"""
return image

@abstractmethod
def _postprocess_ufunc_output(
    self,
    result: ImageType,
    output_coords: dict[str, list[str | int]] | None = None,
) -> ImageType:
    """
    Postprocess the output of an applied ufunc.

    This method should be overridden by subclasses to handle any necessary
    transformations to the output data, e.g. transposing dimensions.
    """

@staticmethod
Expand All @@ -165,41 +229,25 @@ def from_image(image: Any, nodata_vals: NoDataType = None) -> Image:


class NDArrayImage(Image):
    """An image stored in a Numpy NDArray of shape (band, y, x)."""

    # Plain arrays carry no band labels, so the band names are empty.
    band_names = np.array([])

    def __init__(self, image: NDArray, nodata_vals: NoDataType = None):
        super().__init__(image, nodata_vals=nodata_vals)

    def _preprocess_ufunc_input(self, image: NDArray) -> NDArray:
        """Preprocess the image by transposing to (y, x, band) for apply_ufunc."""
        # Copy to avoid mutating the original image
        return image.copy().transpose(1, 2, 0)

    def _postprocess_ufunc_output(self, result: NDArray, output_coords=None) -> NDArray:
        """
        Postprocess the ufunc output by transposing back to (band, y, x).

        `output_coords` is accepted for interface compatibility and is not used.
        """
        return result.transpose(2, 0, 1)


class DataArrayImage(Image):
band_dim = 0
"""An image stored in an xarray DataArray of shape (band, y, x)."""

def __init__(self, image: xr.DataArray, nodata_vals: NoDataType = None):
super().__init__(image, nodata_vals=nodata_vals)
Expand All @@ -225,82 +273,22 @@ def _validate_nodata_vals(self, nodata_vals: NoDataType) -> NDArray | None:

return None

def _postprocess_ufunc_output(
    self,
    result: xr.DataArray,
    output_coords: dict[str, list[str | int]] | None = None,
) -> xr.DataArray:
    """
    Process the ufunc output by assigning coordinates and transposing.

    Parameters
    ----------
    result : xr.DataArray
        Output of the applied ufunc, with the new dimension as the last axis.
    output_coords : dict[str, list[str | int]], optional
        Coordinates to assign to the result. If None, no coordinates are
        assigned.

    Returns
    -------
    xr.DataArray
        The result with its last dimension moved to the front.
    """
    if output_coords is not None:
        result = result.assign_coords(output_coords)

    # Transpose from (y, x, band) to (band, y, x). The leading dimension is taken
    # from the result itself so this also works when no coordinates were given.
    return result.transpose(result.dims[-1], ...)


class DatasetImage(DataArrayImage):
"""An image stored in an xarray Dataset of shape (y, x) with bands as variables."""

def __init__(self, image: xr.Dataset, nodata_vals: NoDataType = None):
# The image itself will be stored as a DataArray, but keep the Dataset for
# metadata like _FillValues.
Expand Down Expand Up @@ -329,13 +317,13 @@ def _validate_nodata_vals(self, nodata_vals: NoDataType) -> NDArray | None:
# Fall back to the DataArray logic for handling NoData
return super()._validate_nodata_vals(nodata_vals)

def _postprocess_ufunc_output(
    self,
    result: xr.DataArray,
    output_coords: dict[str, list[str | int]] | None = None,
) -> xr.Dataset:
    """
    Process the ufunc output converting from DataArray to Dataset.

    The parent implementation is applied first, then the dimension at index
    `self.band_dim` of its result is split into Dataset variables.
    """
    result = super()._postprocess_ufunc_output(result, output_coords=output_coords)

    var_dim = result.dims[self.band_dim]
    return result.to_dataset(dim=var_dim)
Loading
Loading