Skip to content

Commit

Permalink
WIP restore dimension order after applying ufunc
Browse files Browse the repository at this point in the history
  • Loading branch information
aazuspan committed Jun 29, 2024
1 parent dfab8cc commit 123f2ab
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
1 change: 0 additions & 1 deletion src/sknnr_spatial/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ def kneighbors(
output_dtypes=[float, int] if return_distance else [int],
output_sizes={"k": k},
output_coords={"k": list(range(1, k + 1))},
output_names=["dist", "nn"] if return_distance else ["nn"],
n_neighbors=k,
return_distance=return_distance,
**kneighbors_kwargs,
Expand Down
25 changes: 15 additions & 10 deletions src/sknnr_spatial/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,16 @@ def _validate_nodata_vals(self, nodata_vals: NoDataType) -> NDArray | None:
def _postprocess(
self,
result: xr.DataArray,
output_coords: dict[str, list[str | int]] | None = None,
output_coords: dict[str, list[str | int]],
) -> xr.DataArray:
"""Process the output of an applied ufunc"""
if output_coords is not None:
result = result.assign_coords(output_coords)
var_dim = list(output_coords.keys())[0]

return result
# apply_gufunc swaps dimension order, so we need to restore it back to
# (band, y, x).
return result.transpose(var_dim, ...)

def apply_ufunc_across_bands(
self,
Expand All @@ -254,11 +257,16 @@ def apply_ufunc_across_bands(
"""
image = self.image

# TODO: Decide on reasonable defaults
output_dims = output_dims or [["output"]]
output_dims = output_dims or [["variable"]]
n_outputs = len(output_dims)
# Fall back to float output if unknown
output_dtypes = output_dtypes or [np.float32] * n_outputs
output_sizes = output_sizes or {"output": 1}
# If output sizes are not provided, assume a single output coordinate
output_sizes = output_sizes or {"variable": 1}
# Default to sequential coordinates for each output dimension, if not provided
output_coords = output_coords or {
k: list(range(s)) for k, s in output_sizes.items()
}

def ufunc(x):
return _ImageChunk(
Expand Down Expand Up @@ -326,13 +334,10 @@ def _validate_nodata_vals(self, nodata_vals: NoDataType) -> NDArray | None:
def _postprocess(
self,
result: xr.DataArray,
output_coords: dict[str, list[str | int]] | None = None,
output_coords: dict[str, list[str | int]],
) -> xr.Dataset:
"""Process the output of an applied ufunc"""
result = super()._postprocess(result, output_coords=output_coords)

# TODO: Once I get the ufunc to respect the dim order, use the band_dim instead
# var_dim = result.dims[self.band_dim]
var_dim = result.dims[-1]

var_dim = result.dims[self.band_dim]
return result.to_dataset(dim=var_dim)

0 comments on commit 123f2ab

Please sign in to comment.