Skip to content

Commit

Permalink
WIP set gufunc params based on estimator metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
aazuspan committed Jun 28, 2024
1 parent ef25886 commit 792feb5
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 26 deletions.
44 changes: 23 additions & 21 deletions src/sknnr_spatial/estimator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
from warnings import warn

import numpy as np
Expand All @@ -21,6 +21,12 @@

from .types import ImageType, NoDataType

ESTIMATOR_OUTPUT_DTYPES: dict[str, np.dtype] = {
"classifier": np.int32,
"clusterer": np.int32,
"regressor": np.float64,
}


@dataclass
class FittedMetadata:
Expand Down Expand Up @@ -155,18 +161,23 @@ def predict(
y_image : Numpy or Xarray image with 3 dimensions (y, x, targets)
The predicted values.
"""

output_dim_name = "variable"
image = Image(X_image, nodata_vals=nodata_vals)

# TODO: Re-implement once Image can parse band names
# self._check_feature_names(wrapper.preprocessor.band_names)

# Any estimator with an undefined type should fall back to floating
# point for safety.
estimator_type = getattr(self._wrapped, "_estimator_type", "")
output_dtype = ESTIMATOR_OUTPUT_DTYPES.get(estimator_type, np.float64)

return image.apply_ufunc_across_bands(
suppress_feature_name_warnings(self._wrapped.predict),
# TODO: Set these correctly based on image properties
output_dims=[["variable"]],
output_dtypes=[float],
output_sizes={"variable": 25},
output_dims=[[output_dim_name]],
output_dtypes=[output_dtype],
output_sizes={output_dim_name: self._wrapped_meta.n_targets},
output_coords={output_dim_name: list(self._wrapped_meta.target_names)},
**predict_kwargs,
)

Expand Down Expand Up @@ -244,27 +255,18 @@ def kneighbors(
Indices of the nearest points in the population matrix.
"""
image = Image(X_image, nodata_vals=nodata_vals)
k = n_neighbors or cast(int, getattr(self._wrapped, "n_neighbors", 5))

# TODO: Re-implement
# self._check_feature_names(wrapper.preprocessor.band_names)

# TODO: Get the correct values for these
n_neighbors = n_neighbors if n_neighbors is not None else 5
output_sizes = {"k": n_neighbors}

if return_distance:
output_core_dims = [["k"], ["k"]]
output_dtypes = [float, int]
else:
output_core_dims = [["k"]]
output_dtypes = [int]

return image.apply_ufunc_across_bands(
suppress_feature_name_warnings(self._wrapped.kneighbors),
output_dims=output_core_dims,
output_dtypes=output_dtypes,
output_sizes=output_sizes,
n_neighbors=n_neighbors,
output_dims=[["k"], ["k"]] if return_distance else [["k"]],
output_dtypes=[float, int] if return_distance else [int],
output_sizes={"k": k},
output_coords={"k": list(range(1, k + 1))},
n_neighbors=k,
return_distance=return_distance,
**kneighbors_kwargs,
)
Expand Down
24 changes: 19 additions & 5 deletions src/sknnr_spatial/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,11 @@ def apply_ufunc_across_bands(
output_dims: list[list[str]] | None = None,
output_dtypes: list[np.dtype] | None = None,
output_sizes: dict[str, int] | None = None,
output_coords: dict[str, list[str | int]] | None = None,
nan_fill: float = 0.0,
mask_nodata: bool = True,
**ufunc_kwargs,
) -> ImageType:
) -> ImageType | tuple[ImageType]:
"""
Apply a universal function to all bands of the image.
Expand All @@ -117,7 +118,7 @@ def apply_ufunc_across_bands(
def ufunc(x):
return _ImageChunk(x, nodata_vals=self.nodata_vals).apply(
func,
returns_tuple=len(output_dims) > 1,
returns_tuple=n_outputs > 1,
nan_fill=nan_fill,
mask_nodata=mask_nodata,
**ufunc_kwargs,
Expand All @@ -127,11 +128,9 @@ def ufunc(x):
return ufunc(image)

if isinstance(image, xr.Dataset):
# TODO: Convert back to dataset after predicting
image = image.to_dataarray()

# TODO: Assign target dim names
return xr.apply_ufunc(
result = xr.apply_ufunc(
ufunc,
image,
dask="parallelized",
Expand All @@ -144,3 +143,18 @@ def ufunc(x):
allow_rechunk=True,
),
)

def postprocess(x):
if output_coords is not None:
x = x.assign_coords(output_coords)

# TODO: Convert back to dataset

return x

if n_outputs > 1:
result = tuple(postprocess(x) for x in result)
else:
result = postprocess(result)

return result

0 comments on commit 792feb5

Please sign in to comment.