diff --git a/src/sknnr_spatial/estimator.py b/src/sknnr_spatial/estimator.py index aebfd28..d0942b2 100644 --- a/src/sknnr_spatial/estimator.py +++ b/src/sknnr_spatial/estimator.py @@ -266,7 +266,6 @@ def kneighbors( output_dtypes=[float, int] if return_distance else [int], output_sizes={"k": k}, output_coords={"k": list(range(1, k + 1))}, - output_names=["dist", "nn"] if return_distance else ["nn"], n_neighbors=k, return_distance=return_distance, **kneighbors_kwargs, diff --git a/src/sknnr_spatial/image.py b/src/sknnr_spatial/image.py index a0ffa31..f862d8f 100644 --- a/src/sknnr_spatial/image.py +++ b/src/sknnr_spatial/image.py @@ -226,13 +226,16 @@ def _validate_nodata_vals(self, nodata_vals: NoDataType) -> NDArray | None: def _postprocess( self, result: xr.DataArray, - output_coords: dict[str, list[str | int]] | None = None, + output_coords: dict[str, list[str | int]], ) -> xr.DataArray: """Process the output of an applied ufunc""" if output_coords is not None: result = result.assign_coords(output_coords) + var_dim = list(output_coords.keys())[0] - return result + # apply_gufunc swaps dimension order, so we need to restore it back to + # (band, y, x). + return result.transpose(var_dim, ...) def apply_ufunc_across_bands( self, @@ -254,11 +257,16 @@ def apply_ufunc_across_bands( """ image = self.image - # TODO: Decide on reasonable defaults - output_dims = output_dims or [["output"]] + output_dims = output_dims or [["variable"]] n_outputs = len(output_dims) + # Fall back to float output if unknown output_dtypes = output_dtypes or [np.float32] * n_outputs - output_sizes = output_sizes or {"output": 1} + # If output sizes are not provided, assume a single output coordinate + output_sizes = output_sizes or {"variable": 1} + # Default to sequential coordinates for each output dimension, if not provided + output_coords = output_coords or { + k: list(range(s)) for k, s in output_sizes.items() + } def ufunc(x): return _ImageChunk( @@ -326,13 +334,10 @@ def _validate_nodata_vals(self, nodata_vals: NoDataType) -> NDArray | None: def _postprocess( self, result: xr.DataArray, - output_coords: dict[str, list[str | int]] | None = None, + output_coords: dict[str, list[str | int]], ) -> xr.Dataset: """Process the output of an applied ufunc""" result = super()._postprocess(result, output_coords=output_coords) - # TODO: Once I get the ufunc to respect the dim order, use the band_dim instead - # var_dim = result.dims[self.band_dim] - var_dim = result.dims[-1] - + var_dim = result.dims[self.band_dim] return result.to_dataset(dim=var_dim)