Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Add Array API compatibility to MinMaxScaler #26243

Merged
merged 27 commits into from
Aug 18, 2023

Conversation

betatim
Copy link
Member

@betatim betatim commented Apr 21, 2023

Reference Issues/PRs

Towards #26024

What does this implement/fix? Explain your changes.

This enables MinMaxScaler to work with Array API compatible arrays. Most of the changes are replacing np with xp (which represents the namespace the array the user passed in belongs to). Had to implement some helpers like nanmin and nanmax.

Any other comments?

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fairly straight forward. We'll need tests to make sure PyTorch works on the CI and gives correct results.

For completeness, inverse_transform can also be updated with array_api.

@betatim betatim marked this pull request as ready for review April 26, 2023 12:02
@betatim
Copy link
Member Author

betatim commented Apr 26, 2023

Done and done.

For the tests I more or less took the tests that exist for LinearDiscriminantAnalysis. They are quite similar but not identical. I was thinking I'll make a new PR that tries to make a set of "common tests". Probably via adding a new estimator tag that can be used to discover things that support the array API. WDYT?

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am uneasy about supporting xp.float16. Adding a float16 property to _ArrayAPIWrapper feels like we are "extending the Array API spec".

My preference is to be strict and only allow float16 when the arrays are ndarray. This means we do not support float16 for numpy.array_api Arrays, torch, cupy.array_api, etc.

@betatim
Copy link
Member Author

betatim commented May 2, 2023

My preference is to be strict and only allow float16 when the arrays are ndarray. This means we do not support float16 for numpy.array_api Arrays, torch, cupy.array_api, etc.

Fine for me. Though as far as I can see numpy.array_api and cupy.array_api will be the only namespaces for which we won't have float16. For cupy, torch and numpy it exists.

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I am mostly onboard with supporting float16.

@ogrisel ogrisel mentioned this pull request May 5, 2023
7 tasks
Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Besides the following and @thomasjpfan's comment on isneginf, LGTM.

Not sure how to efficiently cope with the lack of a generic isneginf though.

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beyond making atol stricter in the tests and the following suggestions to improve comments / docstrings, LGTM!

@@ -357,6 +368,35 @@ def _expit(X):
return 1.0 / (1.0 + xp.exp(-X))


def _isneginf(X):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO for Tim: open issue in Array API about asymmetry between isinf and isneginf

@ogrisel
Copy link
Member

ogrisel commented Jun 14, 2023

I think this PR needs to be updated to leverage the new common test recently merged in main.

@betatim betatim force-pushed the minmax-array-api branch from f28eab5 to 8e8e965 Compare June 29, 2023 12:15
@github-actions
Copy link

github-actions bot commented Jun 29, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 54e7f5f. Link to the linter CI: here

@betatim
Copy link
Member Author

betatim commented Jul 4, 2023

I think this is ready to go now. I'd like to add more tests a la:

@parametrize_with_checks(
    [
        MinMaxScaler()
    ],
    check_yielder=_yield_array_api_checks,
)
def test_array_api_compliance(check, estimator, request):
    check(estimator, check_values=True)

but this uses things from #26315. Either we wait to merge this PR until #26315 has landed or we make another PR later to add those tests. I'm fine with either. Maybe waiting for #26315 is cleaner/requires less bookkeeping work.

@ogrisel
Copy link
Member

ogrisel commented Jul 6, 2023

I just reverted the change made to parametrize_with_checks because in the end I no longer needed it in #26315. Instead I used a cross-product of regular pytest.mark.parametrize:

https://github.com/scikit-learn/scikit-learn/pull/26315/files#diff-7fdbfa7b6564912adf698653f0faa9248c7ad600cdccf607451b5c500dbf6b4cR731-R755

Note: the order of the decorators matters for the name of the tests.

@betatim betatim force-pushed the minmax-array-api branch from 547f487 to 2d690a7 Compare July 14, 2023 09:49
Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update. Still LGTM.

@betatim betatim added Waiting for Second Reviewer First reviewer is done, need a second one! and removed Waiting for Reviewer labels Jul 20, 2023
@betatim
Copy link
Member Author

betatim commented Jul 20, 2023

@thomasjpfan what do you think?

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the estimator to:

Estimators with support for `Array API`-compatible inputs
=========================================================

I ran this on my local machine with CuPy + PyTorch and the new tests pass.

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comment, otherwise LGTM

@thomasjpfan thomasjpfan changed the title Add Array API compatibility to MinMaxScaler ENH Add Array API compatibility to MinMaxScaler Aug 7, 2023
@thomasjpfan thomasjpfan enabled auto-merge (squash) August 18, 2023 12:47
@thomasjpfan thomasjpfan merged commit 061f877 into scikit-learn:main Aug 18, 2023
@betatim betatim deleted the minmax-array-api branch August 18, 2023 15:35
TamaraAtanasoska pushed a commit to TamaraAtanasoska/scikit-learn that referenced this pull request Aug 21, 2023
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
akaashpatelmns pushed a commit to akaashp2000/scikit-learn that referenced this pull request Aug 25, 2023
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
REDVM pushed a commit to REDVM/scikit-learn that referenced this pull request Nov 16, 2023
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants