Fix dataframe indexing for sknnr estimators #26

Merged (4 commits) on Jun 26, 2024

aazuspan
Contributor

This closes #25 and HOPEFULLY solves feature name warnings for the last time.

The previous fix for feature name warnings (#23) was to prevent fitting wrapped estimators with dataframe data, but this had an unintended side effect of breaking any functionality that depended on fitting with dataframes, like returning dataframe indexes in sknnr estimators.

I removed the array conversion to restore dataframe indexing, and added a function suppress_feature_name_warnings that suppresses missing feature name warnings in wrapped functions. By wrapping the functions applied by apply_gufunc (e.g. predict and kneighbors), we can suppress the warning when it arises at compute time, rather than during the eager call to apply_gufunc. We'll need to use that function any time that sklearn runs feature name checks on chunked arrays (e.g. transform in #16), so it may be worth abstracting some of that repeated code out eventually.
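The compute-time point above can be illustrated with a small stdlib-only sketch. This is not the code from this PR: the body of `suppress_feature_name_warnings` is a guess at one way such a wrapper could look, the warning text is assumed to match scikit-learn's `UserWarning`, and `fake_predict`, `lazy_apply`, and `count_warnings` are hypothetical stand-ins for a fitted estimator's `predict`, a lazy `apply_gufunc` call, and a test helper.

```python
import functools
import warnings

# Assumed start of scikit-learn's missing-feature-name warning message.
FEATURE_NAME_MSG = "X does not have valid feature names"


def suppress_feature_name_warnings(func):
    """Return a wrapper that ignores missing-feature-name warnings raised by func."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=FEATURE_NAME_MSG, category=UserWarning
            )
            return func(*args, **kwargs)

    return wrapper


def fake_predict(X):
    """Hypothetical stand-in for estimator.predict that always emits the warning."""
    warnings.warn(
        f"{FEATURE_NAME_MSG}, but FakeEstimator was fitted with feature names",
        UserWarning,
    )
    return [x * 2 for x in X]


def lazy_apply(func, X):
    """Hypothetical stand-in for a lazy call: func only runs when the thunk is called."""
    return lambda: func(X)


def count_warnings(thunk):
    """Run the thunk ("compute") and report how many warnings escaped."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        result = thunk()
    return result, len(caught)


# Suppressing only around the eager call does not help, because the warning
# is raised later, when the thunk actually runs.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message=FEATURE_NAME_MSG)
    eager_only = lazy_apply(fake_predict, [1, 2])

# Wrapping the function itself moves the suppression to compute time.
wrapped_lazy = lazy_apply(suppress_feature_name_warnings(fake_predict), [1, 2])
```

Calling `count_warnings(eager_only)` still reports the warning, while `count_warnings(wrapped_lazy)` reports none, because the filter is installed inside the function that runs at compute time.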

I also shuffled the utils modules around to avoid some circular dependency issues. I didn't put a ton of thought into this, so there might be a better way to organize things.

The previous fix for feature name warnings was to prevent fitting
wrapped estimators with dataframe data, but this had a side effect
of breaking any functionality that depended on fitting with
dataframes, like returning dataframe indexes in sknnr estimators.

I removed the array conversion to restore dataframe indexing, and
added a function that suppresses missing feature name warnings
in wrapped functions. By wrapping the functions applied by
apply_gufunc (i.e. predict and kneighbors), we can suppress the
warning when it arises at compute time.

I also refactored the utils to avoid some circular dependency issues.
@aazuspan added the bug label (Something isn't working) on Jun 24, 2024
@aazuspan requested a review from @grovduck on Jun 24, 2024, 22:01
@aazuspan self-assigned this on Jun 24, 2024
Member

@grovduck grovduck left a comment

Great changes here! It also has the effect of simplifying the code, since we no longer have to deal with the array with stripped names until the last possible moment. Very minor stuff from me - otherwise looks great!

> I also shuffled the utils modules around to avoid some circular dependency issues. I didn't put a ton of thought into this, so there might be a better way to organize things.

Your organization of the utility functions makes sense to me.

src/sknnr_spatial/utils/estimator.py (resolved)
Comment on lines 56 to 59
# If the wrapped estimator was fit with a dataframe, it will warn about
# missing feature names because this passes unnamed arrays. Suppress that
# and let the wrapper handle feature name checks.
suppress_feature_name_warnings(estimator._wrapped.predict),
Member

The placement of the comment here and on lines 92-95 caught me a bit off guard. One, because it's the identical comment and presumably it might be repeated for other functions (e.g. transform). You mention this in the PR description. Two, because the comment itself is within the apply_gufunc function call, it might be read as if it applies to all arguments. Could there be a private static method in DaskBackedWrapper like:

class DaskBackedWrapper(ImageWrapper[DaskBackedType]):

    @staticmethod
    def _suppressed_estimator_function(func):
        # If the wrapped estimator was fit with a dataframe, it will warn about
        # missing feature names because this passes unnamed arrays. Suppress that
        # and let the wrapper handle feature name checks.
        return suppress_feature_name_warnings(func)

    def predict(self, ...):
        ...
        y_pred = da.apply_gufunc(self._suppressed_estimator_function(estimator._wrapped.predict), ...)

You were suggesting abstracting out the repeated code, so perhaps you're already ahead of me on this one.

Contributor Author

Agreed, it felt weird to put the same comment in both places! I like the static method suggestion, but saw an opportunity to reduce a little more duplication by refactoring out the shared arguments between the apply_gufunc calls into a private _apply_gufunc method, which also allowed for only suppressing warnings in one spot.
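A rough sketch of the shape of that refactor, under stated assumptions: `apply_gufunc_stub` is a hypothetical placeholder for `da.apply_gufunc` so the example runs without dask, `WrapperSketch` and `NoisyEstimator` are invented names, and the shared keyword arguments and signature strings are illustrative values rather than the ones used in this PR.

```python
import functools
import warnings

CALLS = []  # records (signature, kwargs) passed to the stub, for inspection


def suppress_feature_name_warnings(func):
    """Assumed implementation: ignore missing-feature-name warnings from func."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message="X does not have valid feature names"
            )
            return func(*args, **kwargs)

    return wrapper


def apply_gufunc_stub(func, signature, *args, **kwargs):
    """Hypothetical placeholder for da.apply_gufunc: runs func eagerly."""
    CALLS.append((signature, kwargs))
    return func(*args)


class WrapperSketch:
    def __init__(self, estimator, apply=apply_gufunc_stub):
        self._wrapped = estimator
        self._apply = apply

    def _apply_gufunc(self, func, signature, *args, **kwargs):
        # Arguments shared by every call live here (illustrative values), and
        # this is the single place where the warning is suppressed.
        shared = {"axis": -1, "allow_rechunk": True}
        return self._apply(
            suppress_feature_name_warnings(func),
            signature,
            *args,
            **{**shared, **kwargs},
        )

    def predict(self, X):
        return self._apply_gufunc(self._wrapped.predict, "(x)->(y)", X)

    def kneighbors(self, X):
        return self._apply_gufunc(self._wrapped.kneighbors, "(x)->(k),(k)", X)


class NoisyEstimator:
    """Hypothetical estimator whose methods always emit the warning."""

    def _warn(self):
        warnings.warn(
            "X does not have valid feature names, but NoisyEstimator was "
            "fitted with feature names",
            UserWarning,
        )

    def predict(self, X):
        self._warn()
        return [sum(row) for row in X]

    def kneighbors(self, X):
        self._warn()
        return [[0.0] for _ in X], [[0] for _ in X]
```

With this layout, adding another method such as `transform` would only need one more call to `_apply_gufunc`, without repeating the suppression or its explanatory comment.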

Let me know what you think of that choice. Also, maybe there's a more specific name than _apply_gufunc? I thought about _apply_estimator_gufunc or _apply_sklearn_gufunc or something similar, but I guess there's no reason this wouldn't work with other functions.

Member

> Agreed, it felt weird to put the same comment in both places! I like the static method suggestion, but saw an opportunity to reduce a little more duplication by refactoring out the shared arguments between the apply_gufunc calls into a private _apply_gufunc method, which also allowed for only suppressing warnings in one spot.

As always, yours is the better solution 😉. This looks good to me and, as you say, takes out a bit more duplication.

> Also, maybe there's a more specific name than _apply_gufunc? I thought about _apply_estimator_gufunc or _apply_sklearn_gufunc or something similar, but I guess there's no reason this wouldn't work with other functions.

The current name seems good to me. It's clear that it's just an enhancement of da.apply_gufunc and should be able to be used with other functions as well. I like keeping it as a more generic name. Out of curiosity, I assume much of this code goes away if you move to xarray.apply_ufunc (in a future PR), so this seems like reasonable naming for the short term at least.

tests/test_sknnr.py (outdated, resolved)
@grovduck
Member

From what I can tell, this all seems good to me. OK by me to merge.

@aazuspan
Contributor Author

Thanks @grovduck!

@aazuspan merged commit b28a91b into main on Jun 26, 2024 (5 checks passed)
@aazuspan deleted the fix-df-index branch on Jun 26, 2024, 19:05
Development

Successfully merging this pull request may close these issues.

Feature name fix is incompatible with return_dataframe_index in sknnr
2 participants