Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

This PR adds JAX as a new nplike #1399

Merged
merged 15 commits into from
Apr 12, 2022
Merged

This PR adds JAX as a new nplike #1399

merged 15 commits into from
Apr 12, 2022

Conversation

sw15h
Copy link
Collaborator

@sw15h sw15h commented Apr 6, 2022

No description provided.

@codecov
Copy link

codecov bot commented Apr 6, 2022

Codecov Report

Merging #1399 (d0d7b9c) into main (edfce38) will decrease coverage by 0.07%.
The diff coverage is 31.29%.

Impacted Files Coverage Δ
src/awkward/_v2/_connect/cuda/__init__.py 0.00% <0.00%> (ø)
src/awkward/_v2/_connect/jax/__init__.py 0.00% <0.00%> (ø)
src/awkward/_v2/operations/convert/ak_from_cupy.py 50.00% <0.00%> (+23.33%) ⬆️
src/awkward/_v2/operations/convert/ak_from_jax.py 50.00% <0.00%> (-25.00%) ⬇️
src/awkward/_v2/operations/convert/ak_to_cupy.py 33.33% <0.00%> (+23.95%) ⬆️
src/awkward/_v2/operations/convert/ak_to_jax.py 33.33% <0.00%> (-41.67%) ⬇️
src/awkward/_v2/operations/describe/ak_backend.py 10.00% <0.00%> (-2.50%) ⬇️
src/awkward/_v2/_util.py 75.11% <37.06%> (-8.22%) ⬇️
src/awkward/_v2/contents/numpyarray.py 90.42% <50.00%> (-0.15%) ⬇️
...rc/awkward/_v2/operations/convert/ak_from_numpy.py 100.00% <100.00%> (+27.77%) ⬆️
... and 1 more

@sw15h sw15h requested a review from jpivarski April 7, 2022 08:45
Copy link
Member

@jpivarski jpivarski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this!

Some large-scale refactoring things:

  • The Jax nplike is not much different from the Cupy one, and the Cupy one is in the ak.nplike module. I think the Jax one could go here (and it needs to be possible to declare the class without importing jax. Whenever you have something that depends on two attributes, like "is an nplike" and "is for JAX", there's always a question of whether to organize them such that all the nplikes are together or to organize them such that all the JAX-related stuff is together. Let's organize it with all nplikes together, because that will enforce that they behave the same way with respect to when the third-party library gets imported. Some nplikes, like the TypeTracer, are in other places, but that's because the TypeTracer is very different from the nplikes for third-party libraries.
  • Modules in ak._v2.operations.*.* should each provide one high-level function. See, for instance, ak_from_parquet and ak_metadata_from_parquet. There's a lot of overlap between these two, but they're in separate files to maintain that structure. So please split ak_from_lib back into ak_from_cupy and ak_from_jax (same for the "to" direction).
  • Sharing code in the implementations of from_X and to_X is a good idea, and maybe it would be possible to extend this shared code to NumPy as well as CuPy and JAX. The shared implementation sounds important enough to go into ak._v2._util as functions named something like "from_arraylib" and "to_arraylib". Passing in the module object is a good idea, too: it could be done in both, the "from" as well as the "to". If it ever becomes necessary to handle special cases for differences between NumPy, CuPy, and JAX, we can check the module.__name__ to know which one we have (without accidentally triggering an import of one of the others).

src/awkward/nplike.py Outdated Show resolved Hide resolved
src/awkward/_v2/_connect/jax/__init__.py Outdated Show resolved Hide resolved
tests-jax/test_1399-from-jax.py Outdated Show resolved Hide resolved
tests-jax/test_1399-from-jax.py Outdated Show resolved Hide resolved
src/awkward/_v2/_connect/jax/nplike.py Outdated Show resolved Hide resolved
src/awkward/_v2/_connect/jax/nplike.py Outdated Show resolved Hide resolved
Copy link
Member

@jpivarski jpivarski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll probably need to merge main to get the Windows fix. But when tests pass, you can merge it!

src/awkward/_v2/_util.py Outdated Show resolved Hide resolved
@sw15h sw15h merged commit 3f5a4cc into main Apr 12, 2022
@sw15h sw15h deleted the swishdiff/jax-nplike branch April 12, 2022 09:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants