-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
TST enable non-CPU device testing via array-api-strict #30090
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TST enable non-CPU device testing via array-api-strict #30090
Conversation
Here is the output of
on my machine with the
This same test passes with PyTorch and CUDA or MPS devices, so I suspect that the lack of device propagation in the computation of the |
I think we will need data-apis/array-api-strict#73 and data-apis/array-api-strict#72 for this PR to work |
Another issue that needs resolving scipy/scipy#21736 |
We should update the lock file to try to run the tests in this PR with the new version of EDIT: the lock files have probably already been updated in |
@betatim I started to update this PR: it discovered several device handling issues and maybe dtype related issues. I have not yet fixed them all, feel free to take over at any point :) |
To compute batch sizes and memory sizes we don't need to use the array API, we can do that math with "just" Python types. This change also fixes a slicing error that only appears with array-api-strict. Unrelated to changing to Python types.
The scipy implementation contains a bug with respect to setting the device of all the arrays it creates. This adds xlogy() to our group of functions we implement ourselves.
Using this in functions that support the xp short circuiting, so I think it makes sense to make this function look similar to get_namespace
I pinged Omar and Guillaume for reviews. You don't have to review this, but I thought it might be interesting for you two (and solve the problem that neither Oliver nor I can approve this). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This reduces the amount of `xp.asarray` that we need to convert scalars to arrays for the array API
Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
…e_deviance" This reverts commit 920932f.
@OmarManzoor @betatim after #30090 (comment), the code is simpler, and all tests pass everywhere. +1 for merge on my side. |
👍 Let's wait for the CI to complete and I'll review and merge |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an early draft PR to attempt to leverage multi device support recently merged in
array-api-strict
: data-apis/array-api-strict#59We need to wait for a release of
array-api-strict
+ a lock file update to actually get this to run on our CI.However, I think we should investigate failures early in scikit-learn because I suspect that some (most?) of them are not necessarily a problem in scikit-learn but might be bugs in
array-api-strict
's device support itself./cc @betatim