-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix dataframe neighbor indexing for
sknnr
estimators (#26)
* Fix dataframe indexing for sknnr estimators The previous fix for feature name warnings was to prevent fitting wrapped estimators with dataframe data, but this had a side effect of breaking any functionality that depended on fitting with dataframes, like returning dataframe indexes in sknnr estimators. I removed the array conversion to restore dataframe indexing, and added a function that that suppresses missing feature name warnings in wrapped functions. By wrapping the functions applied by apply_gufunc (i.e. predict and kneighbors), we can suppress the warning when it arises at compute time. I also refactored the utils to avoid some circular dependency issues. * Test returning df indices with kneighbors * Check index of all first neighbors * Refactor out duplicated gufunc args
- Loading branch information
Showing
6 changed files
with
150 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from functools import wraps | ||
from typing import Callable, Generic | ||
|
||
from typing_extensions import Concatenate, ParamSpec, TypeVar | ||
|
||
from ..types import AnyType | ||
|
||
RT = TypeVar("RT") | ||
P = ParamSpec("P") | ||
|
||
|
||
class AttrWrapper(Generic[AnyType]): | ||
"""A transparent object wrapper that accesses a wrapped object's attributes.""" | ||
|
||
_wrapped: AnyType | ||
|
||
def __init__(self, wrapped: AnyType): | ||
self._wrapped = wrapped | ||
|
||
def __getattr__(self, name: str): | ||
return getattr(self._wrapped, name) | ||
|
||
@property | ||
def __dict__(self): | ||
return self._wrapped.__dict__ | ||
|
||
|
||
GenericWrapper = TypeVar("GenericWrapper", bound=AttrWrapper) | ||
|
||
|
||
def check_wrapper_implements( | ||
func: Callable[Concatenate[GenericWrapper, P], RT], | ||
) -> Callable[Concatenate[GenericWrapper, P], RT]: | ||
"""Decorator that raises if the wrapped instance doesn't implement the method.""" | ||
|
||
@wraps(func) | ||
def wrapper(self: GenericWrapper, *args, **kwargs): | ||
if not hasattr(self._wrapped, func.__name__): | ||
wrapped_class = self._wrapped.__class__.__name__ | ||
msg = f"{wrapped_class} does not implement {func.__name__}." | ||
raise NotImplementedError(msg) | ||
|
||
return func(self, *args, **kwargs) | ||
|
||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import numpy as np | ||
import xarray as xr | ||
from sknnr import GNNRegressor | ||
|
||
from sknnr_spatial import wrap | ||
|
||
from .image_utils import parametrize_model_data | ||
|
||
|
||
@parametrize_model_data(image_types=(xr.DataArray,)) | ||
def test_kneighbors_returns_df_index(model_data): | ||
"""Test that sknnr estimators return dataframe indices.""" | ||
# Create dummy plot data | ||
X = np.random.rand(10, 3) + 10.0 | ||
y = np.random.rand(10, 3) | ||
|
||
# Create an image of zeros and set the first plot to zeros to ensure that the | ||
# first index is the nearest neighbor to all pixels | ||
X_image = np.zeros((2, 2, 3)) | ||
X[0] = [0, 0, 0] | ||
|
||
# Convert model data to the correct types | ||
X_image, X, y = model_data.set( | ||
X_image=X_image, | ||
X=X, | ||
y=y, | ||
) | ||
|
||
# Offset the dataframe index to differentiate it from the array index | ||
df_index_offset = 999 | ||
X.index += df_index_offset | ||
|
||
est = wrap(GNNRegressor()).fit(X, y) | ||
idx = est.kneighbors(X_image, return_distance=False, return_dataframe_index=False) | ||
df_idx = est.kneighbors(X_image, return_distance=False, return_dataframe_index=True) | ||
|
||
assert idx.shape == df_idx.shape | ||
|
||
# The first neighbor should be the first index for all pixels | ||
assert (idx.sel(variable="k1") == 0).all().compute() | ||
assert (df_idx.sel(variable="k1") == df_index_offset).all().compute() |