-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
This PR adds JAX as a new nplike #1399
Conversation
Codecov Report
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this!
Some large-scale refactoring things:
- The
Jax
nplike is not much different from theCupy
one, and theCupy
one is in theak.nplike
module. I think theJax
one could go here (and it needs to be possible to declare the class without importingjax
. Whenever you have something that depends on two attributes, like "is an nplike" and "is for JAX", there's always a question of whether to organize them such that all the nplikes are together or to organize them such that all the JAX-related stuff is together. Let's organize it with all nplikes together, because that will enforce that they behave the same way with respect to when the third-party library gets imported. Some nplikes, like the TypeTracer, are in other places, but that's because the TypeTracer is very different from the nplikes for third-party libraries. - Modules in
ak._v2.operations.*.*
should each provide one high-level function. See, for instance,ak_from_parquet
andak_metadata_from_parquet
. There's a lot of overlap between these two, but they're in separate files to maintain that structure. So please splitak_from_lib
back intoak_from_cupy
andak_from_jax
(same for the "to" direction). - Sharing code in the implementations of
from_X
andto_X
is a good idea, and maybe it would be possible to extend this shared code to NumPy as well as CuPy and JAX. The shared implementation sounds important enough to go intoak._v2._util
as functions named something like "from_arraylib
" and "to_arraylib
". Passing in themodule
object is a good idea, too: it could be done in both, the "from" as well as the "to". If it ever becomes necessary to handle special cases for differences between NumPy, CuPy, and JAX, we can check themodule.__name__
to know which one we have (without accidentally triggering an import of one of the others).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You'll probably need to merge main to get the Windows fix. But when tests pass, you can merge it!
No description provided.