Overrides of NumPy functions on JAX arrays #1565
Comments
Adherence to NEP 13 and NEP 18 would make it useful to integrate JAX into projects that rely on them for portability. Specifically, we're looking to integrate JAX with scale-out systems like dask and particle-physics libraries like https://github.com/scikit-hep/awkward-array. @jpivarski can probably comment better on the technical details, but we'd very much be a passionate user :)
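For context, here is a small sketch of what NEP 18 dispatch buys a downstream project: code written only against NumPy's public API keeps working when the inputs come from a library that implements `__array_function__`. The snippet uses dask.array as the stand-in backend (assuming a NumPy/dask version with NEP 18 enabled); the point of this issue is that JAX arrays do not yet participate in this dispatch.

```python
# Library-agnostic code written only against NumPy's public API. With NEP 18
# (__array_function__) implemented by the input's array library, these NumPy
# calls are routed to that library instead of NumPy's own implementations.
import numpy as np
import dask.array as da


def center(x):
    # np.mean and np.concatenate dispatch on the type of x via __array_function__.
    return np.concatenate([x - np.mean(x), x])


x = da.ones((4,), chunks=2)  # a dask array
result = center(x)           # still a lazy dask array; NumPy never materializes it
```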
I love the idea of xarray with JAX as the backend... it would be so awesome!
An example:

```python
from scipy import stats
import numpy as np

N = lambda x: stats.norm.cdf(x)


def test(a, b):
    return N((b - a) / np.sqrt(a))
```

Jake's function (in the issue mentioned above) is meant only for illustrative purposes. Would it be possible to override all the …
In the context of a large software effort for the LHC (http://iris-hep.org), we are discussing this, as @lukasheinrich mentioned above. We have jagged arrays, and we have been able to override …
As a minimal example, this should work:
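(What follows is a hypothetical reconstruction of the kind of snippet being referenced here; the `JaggedArray` class is a stand-in for an awkward-array-style type and is not from the original comment.)

```python
# JAX only accepts input types it has registered handlers for, so an unknown
# array-like is rejected rather than dispatched back to the object, which is
# what NEP 13 / NEP 18 style protocols would allow.
import jax.numpy as jnp


class JaggedArray:  # stand-in for an awkward-array-like type
    def __init__(self, data):
        self.data = data


try:
    jnp.sqrt(JaggedArray([[1.0], [4.0, 9.0]]))
except TypeError as err:
    print(err)  # JAX reports that it cannot handle this input type
```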
The error message suggests that there are pluggable "abstraction handlers". If there is a well-defined protocol, we could maybe implement one for …
* Add experimental `__array_module__` method

  xref #1565

  `__array_module__` (see [NEP 37](https://numpy.org/neps/nep-0037-array-module.html)) is an experimental alternative to `__array_function__` and `__array_ufunc__` for "duck array" compatibility with NumPy that promises to be much less invasive. Example usage:

  ```python
  import numpy as np


  def duckarray_stack(arrays):
      """This "stack" function should work with any array library, including JAX."""
      npx = np.get_array_module(*arrays)
      arrays = [npx.asarray(arr) for arr in arrays]
      shapes = {arr.shape for arr in arrays}
      if len(shapes) != 1:
          raise ValueError('all input arrays must have the same shape')
      expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays]
      return npx.concatenate(expanded_arrays, axis=0)
  ```

  Support for this protocol has *not* yet been implemented in NumPy, but it can be tested with https://github.com/seberg/numpy-dispatch.

  My reasoning for merging it into JAX (on an experimental basis with no guarantees, of course) is that:

  1. It's not invasive -- the implementation is small and self-contained.
  2. No backwards compatibility issues. Unlike `__array_function__` and `__array_ufunc__`, `__array_module__` will always require an explicit opt-in by libraries that use it by calling `get_array_module()`.
  3. Other NumPy developers [want evidence](numpy/numpy#16935 (comment)) that this is actually feasible.
  4. Scikit-Learn developers like @thomasjpfan are interested in exploring supporting scikit-learn on top of NumPy-like libraries like JAX, and experimental support for this protocol will make that easier.

  Note: this PR does add `numpy-dispatch` as an optional testing requirement in order to verify that this works. If desired, we could remove this from CI, but installing numpy-dispatch (and its build requirement Cython) appears to only add a few seconds of build time.

* don't explicitly list cython
* remove `UnshapedArray` from `_JAX_ARRAY_TYPES`
* Remove incorrect note about metaclasses
* remove unnecessary `numpy_dispatch.ensure_dispatching()`
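For readers who want to see the dispatch end to end, here is a sketch of what it looks like with JAX arrays. It assumes the experimental `numpy-dispatch` package exposes `get_array_module` (NumPy itself does not yet provide `np.get_array_module`), so treat it as illustrative only rather than as this PR's test code.

```python
# Illustrative sketch: uses the experimental numpy-dispatch package because
# NumPy does not ship np.get_array_module (NEP 37 is still a draft).
import jax.numpy as jnp
import numpy_dispatch  # assumed entry point for get_array_module

x = jnp.arange(3.0)
y = jnp.ones(3)

# get_array_module inspects __array_module__ on its arguments and returns the
# matching namespace -- here jax.numpy -- so generic code can call into it.
npx = numpy_dispatch.get_array_module(x, y)
stacked = npx.concatenate([x[npx.newaxis, ...], y[npx.newaxis, ...]], axis=0)
print(type(stacked))  # a JAX array, not a numpy.ndarray
```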
If this automated, or at least simplified, porting scikit-learn to JAX, this would be huge!
Edit: never mind this comment! I updated JAX to find that …
JAX has …
I'm curious whether NEP 47 is supported (or planned) for JAX. It would be nice to transparently use xarray over JAX primitives.
That's excellent, thank you! Looks like it's shaping up brilliantly. I'm especially happy that the linear algebra primitives are almost all done!
@raj-magesh I'm excited too! The JAX team is finishing it so fast.
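For anyone wondering what NEP 47 style code looks like in practice, here is a small sketch of array API standard consumption. It assumes the input implements `__array_namespace__()`, which recent JAX releases provide on `jax.Array`; the same function then also works with other libraries that implement the standard.

```python
# Sketch of NEP 47 / array API standard usage: the function asks the array for
# its namespace instead of importing numpy or jax.numpy directly.
import jax.numpy as jnp


def standardize(x):
    xp = x.__array_namespace__()  # a jax.numpy-compatible namespace for JAX inputs
    return (x - xp.mean(x)) / xp.std(x)


print(standardize(jnp.arange(5.0)))  # computed by JAX, returns a JAX array
```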
NumPy has protocols, based on the `__array_ufunc__` and `__array_function__` methods, that allow for overriding what NumPy functions like `np.sin()` and `np.concatenate()` do when called on other array types.

In practice, this means users can write `import numpy as np` to get NumPy functions that work on JAX arrays instead of needing to write `import jax.numpy as np`.

It might make sense to implement these methods on JAX's array objects. A working prototype of this can be found in #611.
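To make the mechanism concrete, here is a minimal sketch of how an `__array_function__` override can forward NumPy API calls to `jax.numpy`. This is not the #611 prototype; the `JaxBackedArray` wrapper class is hypothetical and only illustrates the dispatch path.

```python
# Minimal, illustrative __array_function__ implementation: NumPy functions
# called on this wrapper are looked up by name on jax.numpy and run there.
import numpy as onp
import jax.numpy as jnp


class JaxBackedArray:
    def __init__(self, value):
        self.value = jnp.asarray(value)

    def __array_function__(self, func, types, args, kwargs):
        jax_func = getattr(jnp, func.__name__, None)
        if jax_func is None:
            return NotImplemented  # signal that we don't handle this function
        # Unwrap any of our wrappers among the positional arguments.
        unwrapped = tuple(a.value if isinstance(a, JaxBackedArray) else a
                          for a in args)
        return jax_func(*unwrapped, **kwargs)


x = JaxBackedArray([1.0, 2.0, 3.0])
print(onp.mean(x))       # dispatches to jnp.mean via __array_function__
print(onp.transpose(x))  # dispatches to jnp.transpose
```

A complete implementation would also need `__array_ufunc__` to cover ufuncs such as `np.sin`, plus care around nested argument structures; #611 shows the actual surface area involved.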
Reasons to do this:

- Users can write `import numpy as np` and it will probably work. This is particularly advantageous for third-party libraries (e.g., for projects like opt-einsum or xarray) that want to support multiple backends in a clean, composable way.

Reasons not to do this:

- It becomes easier to silently convert JAX arrays into NumPy arrays, e.g., via `onp.asarray()`. #611 (Implement overrides of NumPy's public API on JAX arrays) includes a handful of examples of this internally in JAX.

Decision by @mattjj and myself: We're not going to merge this yet, because it's not clear that anyone would even use it and it imposes a maintenance burden.
If you have compelling use-cases, please speak up. We could relatively easily make this happen, but would need someone who could commit to being a passionate user first.