Skip to content

Commit

Permalink
Fix dataframe neighbor indexing for sknnr estimators (#26)
Browse files Browse the repository at this point in the history
* Fix dataframe indexing for sknnr estimators

The previous fix for feature name warnings was to prevent fitting
wrapped estimators with dataframe data, but this had a side effect
of breaking any functionality that depended on fitting with
dataframes, like returning dataframe indexes in sknnr estimators.

I removed the array conversion to restore dataframe indexing, and
added a function that suppresses missing feature name warnings
in wrapped functions. By wrapping the functions applied by
apply_gufunc (i.e. predict and kneighbors), we can suppress the
warning when it arises at compute time.

I also refactored the utils to avoid some circular dependency issues.

* Test returning df indices with kneighbors

* Check index of all first neighbors

* Refactor out duplicated gufunc args
  • Loading branch information
aazuspan authored Jun 26, 2024
1 parent d757ad7 commit b28a91b
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 83 deletions.
14 changes: 4 additions & 10 deletions src/sknnr_spatial/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@
from typing_extensions import Literal, overload

from .types import EstimatorType
from .utils.estimator import (
AttrWrapper,
check_wrapper_implements,
image_or_fallback,
is_fitted,
)
from .utils.image import get_image_wrapper
from .utils.estimator import is_fitted
from .utils.image import get_image_wrapper, image_or_fallback
from .utils.wrapper import AttrWrapper, check_wrapper_implements

if TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -115,9 +111,7 @@ def fit(self, X, y=None, **kwargs) -> ImageEstimator[EstimatorType]:
# to (n_samples,), which has a consistent output shape.
y = y.squeeze()

# Cast X to array before fitting to prevent the estimator from storing feature
# names. We implement our own feature name checks that are image-compatible.
self._wrapped = self._wrapped.fit(np.asarray(X), y, **kwargs)
self._wrapped = self._wrapped.fit(X, y, **kwargs)
fitted_feature_names = _get_feature_names(X)

self._wrapped_meta = FittedMetadata(
Expand Down
35 changes: 24 additions & 11 deletions src/sknnr_spatial/image/_dask_backed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sklearn.utils.validation import check_is_fitted

from ..types import DaskBackedType
from ..utils.estimator import suppress_feature_name_warnings
from ._base import ImageWrapper

if TYPE_CHECKING:
Expand Down Expand Up @@ -51,14 +52,11 @@ def predict(
estimator_type = getattr(estimator, "_estimator_type", "")
output_dtype = ESTIMATOR_OUTPUT_DTYPES.get(estimator_type, np.float64)

y_pred = da.apply_gufunc(
y_pred = self._apply_gufunc(
estimator._wrapped.predict,
signature,
self.preprocessor.flat,
axis=self.preprocessor.flat_band_dim,
output_dtypes=[output_dtype],
signature=signature,
output_sizes=output_sizes,
allow_rechunk=True,
output_dtypes=[output_dtype],
)

# Reshape from (n_samples,) to (n_samples, 1)
Expand All @@ -84,14 +82,11 @@ def kneighbors(
signature = "(x)->(k)" if not return_distance else "(x)->(k),(k)"
output_dtypes: list[type] = [int] if not return_distance else [float, int]

result = da.apply_gufunc(
result = self._apply_gufunc(
estimator._wrapped.kneighbors,
signature,
self.preprocessor.flat,
signature=signature,
output_sizes={"k": k},
output_dtypes=output_dtypes,
axis=self.preprocessor.flat_band_dim,
allow_rechunk=True,
n_neighbors=n_neighbors,
return_distance=return_distance,
**kneighbors_kwargs,
Expand All @@ -106,3 +101,21 @@ def kneighbors(
return dist, nn

return self.preprocessor.unflatten(result, var_names=var_names)

def _apply_gufunc(self, func, *, signature, output_sizes, output_dtypes, **kwargs):
    """Run ``func`` over the flattened image with ``dask.array.apply_gufunc``.

    ``func`` is first wrapped so that sklearn's "missing feature names" warning
    is silenced: the image is passed as an unnamed array, and feature name
    checks are left to the wrapper. Any extra ``kwargs`` are passed through to
    ``apply_gufunc`` unchanged.
    """
    quiet_func = suppress_feature_name_warnings(func)
    flat_image = self.preprocessor.flat
    band_axis = self.preprocessor.flat_band_dim

    return da.apply_gufunc(
        quiet_func,
        signature,
        flat_image,
        output_sizes=output_sizes,
        output_dtypes=output_dtypes,
        axis=band_axis,
        allow_rechunk=True,
        **kwargs,
    )
74 changes: 13 additions & 61 deletions src/sknnr_spatial/utils/estimator.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,7 @@
from functools import wraps
from typing import Callable, Generic
import warnings

from sklearn.base import BaseEstimator
from sklearn.utils.validation import NotFittedError, check_is_fitted
from typing_extensions import Concatenate, ParamSpec, TypeVar

from ..image._base import ImageType
from ..types import AnyType
from .image import is_image_type

RT = TypeVar("RT")
P = ParamSpec("P")


class AttrWrapper(Generic[AnyType]):
    """A transparent object wrapper that accesses a wrapped object's attributes.

    Attribute lookups that fail on the wrapper itself are forwarded to the
    wrapped object, and ``__dict__`` exposes the wrapped object's ``__dict__``.
    """

    # The object that missing attribute lookups are delegated to.
    _wrapped: AnyType

    def __init__(self, wrapped: AnyType):
        """Store the object to wrap."""
        self._wrapped = wrapped

    def __getattr__(self, name: str):
        """Return attribute ``name`` from the wrapped object.

        Python only calls this after normal lookup on the wrapper has failed.
        Raises ``AttributeError`` if the wrapper has no wrapped object yet or
        the wrapped object lacks the attribute.
        """
        # If `_wrapped` itself is missing (e.g. on an instance created with
        # `__new__` before `__init__` has run), reading `self._wrapped` below
        # would re-enter `__getattr__` until a RecursionError is raised.
        if name == "_wrapped":
            msg = f"{type(self).__name__!r} object has no attribute {name!r}"
            raise AttributeError(msg)

        return getattr(self._wrapped, name)

    @property
    def __dict__(self):
        """Return the wrapped object's ``__dict__``."""
        # NOTE(review): presumably so that `vars()`-style introspection sees the
        # wrapped object's attributes -- confirm against callers.
        return self._wrapped.__dict__


GenericWrapper = TypeVar("GenericWrapper", bound=AttrWrapper)


def check_wrapper_implements(
    func: Callable[Concatenate[GenericWrapper, P], RT],
) -> Callable[Concatenate[GenericWrapper, P], RT]:
    """Decorate a wrapper method so it raises if the wrapped object lacks it.

    The check looks for an attribute on ``self._wrapped`` with the same name as
    the decorated method and raises ``NotImplementedError`` when it is absent.
    """

    @wraps(func)
    def wrapper(self: GenericWrapper, *args, **kwargs):
        # Happy path first: the wrapped object provides the method.
        if hasattr(self._wrapped, func.__name__):
            return func(self, *args, **kwargs)

        wrapped_class = self._wrapped.__class__.__name__
        msg = f"{wrapped_class} does not implement {func.__name__}."
        raise NotImplementedError(msg)

    return wrapper


def image_or_fallback(
    func: Callable[Concatenate[GenericWrapper, ImageType, P], RT],
) -> Callable[Concatenate[GenericWrapper, ImageType, P], RT]:
    """Decorate a wrapper method to defer to the wrapped object for non-images.

    When ``X_image`` is an image type the decorated method runs; otherwise the
    same-named method of ``self._wrapped`` is called with the same arguments.
    """

    @wraps(func)
    def wrapper(self: GenericWrapper, X_image: ImageType, *args, **kwargs):
        if is_image_type(X_image):
            return func(self, X_image, *args, **kwargs)

        # Not an image: hand the call to the wrapped object's own method.
        fallback = getattr(self._wrapped, func.__name__)
        return fallback(X_image, *args, **kwargs)

    return wrapper


def is_fitted(estimator: BaseEstimator) -> bool:
Expand All @@ -71,3 +11,15 @@ def is_fitted(estimator: BaseEstimator) -> bool:
return True
except NotFittedError:
return False


def suppress_feature_name_warnings(func):
    """Wrap ``func`` so that missing feature name warnings are suppressed.

    The returned callable forwards all arguments to ``func`` and returns its
    result, ignoring any warning whose message starts with
    "X does not have valid feature names" while ``func`` runs. Other warnings
    are unaffected. The wrapper keeps ``func``'s name and docstring.
    """
    # Imported locally so this helper does not depend on the module's import list.
    from functools import wraps

    msg = "X does not have valid feature names"

    @wraps(func)
    def wrapper(*args, **kwargs):
        # NOTE: `catch_warnings` temporarily swaps the process-wide warning
        # filters, which the Python docs describe as not thread-safe.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=msg)
            return func(*args, **kwargs)

    return wrapper
24 changes: 23 additions & 1 deletion src/sknnr_spatial/utils/image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from functools import wraps
from typing import Callable

import numpy as np
import xarray as xr
from typing_extensions import Any
from typing_extensions import Any, Concatenate, ParamSpec, TypeVar

from ..image._base import ImagePreprocessor, ImageType, ImageWrapper
from ..image.dataarray import DataArrayWrapper
from ..image.dataset import DatasetWrapper
from ..image.ndarray import NDArrayWrapper
from .wrapper import GenericWrapper

RT = TypeVar("RT")
P = ParamSpec("P")


def is_image_type(X: Any) -> bool:
Expand All @@ -20,6 +27,21 @@ def is_image_type(X: Any) -> bool:
return False


def image_or_fallback(
    func: Callable[Concatenate[GenericWrapper, ImageType, P], RT],
) -> Callable[Concatenate[GenericWrapper, ImageType, P], RT]:
    """Decorate a wrapper method to defer to the wrapped object for non-images.

    When ``X_image`` is an image type the decorated method runs; otherwise the
    same-named method of ``self._wrapped`` is called with the same arguments.
    """

    @wraps(func)
    def wrapper(self: GenericWrapper, X_image: ImageType, *args, **kwargs):
        if is_image_type(X_image):
            return func(self, X_image, *args, **kwargs)

        # Not an image: hand the call to the wrapped object's own method.
        fallback = getattr(self._wrapped, func.__name__)
        return fallback(X_image, *args, **kwargs)

    return wrapper


def get_image_wrapper(X_image: ImageType) -> type[ImageWrapper]:
"""Get an ImageWrapper subclass for a given image."""
if isinstance(X_image, np.ndarray):
Expand Down
45 changes: 45 additions & 0 deletions src/sknnr_spatial/utils/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from functools import wraps
from typing import Callable, Generic

from typing_extensions import Concatenate, ParamSpec, TypeVar

from ..types import AnyType

RT = TypeVar("RT")
P = ParamSpec("P")


class AttrWrapper(Generic[AnyType]):
    """A transparent object wrapper that accesses a wrapped object's attributes.

    Attribute lookups that fail on the wrapper itself are forwarded to the
    wrapped object, and ``__dict__`` exposes the wrapped object's ``__dict__``.
    """

    # The object that missing attribute lookups are delegated to.
    _wrapped: AnyType

    def __init__(self, wrapped: AnyType):
        """Store the object to wrap."""
        self._wrapped = wrapped

    def __getattr__(self, name: str):
        """Return attribute ``name`` from the wrapped object.

        Python only calls this after normal lookup on the wrapper has failed.
        Raises ``AttributeError`` if the wrapper has no wrapped object yet or
        the wrapped object lacks the attribute.
        """
        # If `_wrapped` itself is missing (e.g. on an instance created with
        # `__new__` before `__init__` has run), reading `self._wrapped` below
        # would re-enter `__getattr__` until a RecursionError is raised.
        if name == "_wrapped":
            msg = f"{type(self).__name__!r} object has no attribute {name!r}"
            raise AttributeError(msg)

        return getattr(self._wrapped, name)

    @property
    def __dict__(self):
        """Return the wrapped object's ``__dict__``."""
        # NOTE(review): presumably so that `vars()`-style introspection sees the
        # wrapped object's attributes -- confirm against callers.
        return self._wrapped.__dict__


GenericWrapper = TypeVar("GenericWrapper", bound=AttrWrapper)


def check_wrapper_implements(
    func: Callable[Concatenate[GenericWrapper, P], RT],
) -> Callable[Concatenate[GenericWrapper, P], RT]:
    """Decorate a wrapper method so it raises if the wrapped object lacks it.

    The check looks for an attribute on ``self._wrapped`` with the same name as
    the decorated method and raises ``NotImplementedError`` when it is absent.
    """

    @wraps(func)
    def wrapper(self: GenericWrapper, *args, **kwargs):
        # Happy path first: the wrapped object provides the method.
        if hasattr(self._wrapped, func.__name__):
            return func(self, *args, **kwargs)

        wrapped_class = self._wrapped.__class__.__name__
        msg = f"{wrapped_class} does not implement {func.__name__}."
        raise NotImplementedError(msg)

    return wrapper
41 changes: 41 additions & 0 deletions tests/test_sknnr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
import xarray as xr
from sknnr import GNNRegressor

from sknnr_spatial import wrap

from .image_utils import parametrize_model_data


@parametrize_model_data(image_types=(xr.DataArray,))
def test_kneighbors_returns_df_index(model_data):
    """Check that kneighbors can return dataframe indices for sknnr estimators."""
    # Dummy plot data, shifted away from the origin
    plot_X = np.random.rand(10, 3) + 10.0
    plot_y = np.random.rand(10, 3)

    # Put the first plot at the origin and use an all-zero image, so that the
    # first plot is the nearest neighbor of every pixel
    plot_X[0] = [0, 0, 0]
    zero_image = np.zeros((2, 2, 3))

    # Convert model data to the correct types
    X_image, X, y = model_data.set(X_image=zero_image, X=plot_X, y=plot_y)

    # Shift the dataframe index so it can be told apart from the array index
    df_index_offset = 999
    X.index += df_index_offset

    est = wrap(GNNRegressor()).fit(X, y)
    shared_kwargs = {"return_distance": False}
    idx = est.kneighbors(X_image, return_dataframe_index=False, **shared_kwargs)
    df_idx = est.kneighbors(X_image, return_dataframe_index=True, **shared_kwargs)

    assert idx.shape == df_idx.shape

    # Every pixel's first neighbor is the first plot: position 0 in the array,
    # and the offset label in the dataframe index
    first_idx = idx.sel(variable="k1")
    first_df_idx = df_idx.sel(variable="k1")
    assert (first_idx == 0).all().compute()
    assert (first_df_idx == df_index_offset).all().compute()

0 comments on commit b28a91b

Please sign in to comment.