Add experimental __array_module__ method #4076
xref jax-ml#1565

`__array_module__` (see [NEP 37](https://numpy.org/neps/nep-0037-array-module.html)) is an experimental alternative to `__array_function__` and `__array_ufunc__` for "duck array" compatibility with NumPy that promises to be much less invasive.

Example usage:

```python
import numpy as np

def duckarray_stack(arrays):
    """This "stack" function should work with any array library, including JAX."""
    npx = np.get_array_module(*arrays)
    arrays = [npx.asarray(arr) for arr in arrays]
    shapes = {arr.shape for arr in arrays}
    if len(shapes) != 1:
        raise ValueError('all input arrays must have the same shape')
    expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays]
    return npx.concatenate(expanded_arrays, axis=0)
```

Support for this protocol has *not* yet been implemented in NumPy, but it can be tested with https://github.com/seberg/numpy-dispatch.

My reasoning for merging it into JAX (on an experimental basis with no guarantees, of course) is that:

1. It's not invasive -- the implementation is small and self-contained.
2. No backwards compatibility issues. Unlike `__array_function__` and `__array_ufunc__`, `__array_module__` will always require an explicit opt-in by libraries that use it by calling `get_array_module()`.
3. Other NumPy developers [want evidence](numpy/numpy#16935 (comment)) that this is actually feasible.
4. Scikit-Learn developers like @thomasjpfan are interested in exploring support for scikit-learn on top of NumPy-like libraries such as JAX, and experimental support for this protocol will make that easier.

Note: this PR does add `numpy-dispatch` as an optional testing requirement in order to verify that this works. If desired, we could remove this from CI, but installing numpy-dispatch (and its build requirement Cython) appears to only add a few seconds of build time.
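To make the opt-in dispatch concrete, here is a minimal sketch of how a `get_array_module()` function could resolve a module, modeled on the description in NEP 37. It is not NumPy's or numpy-dispatch's implementation. `ToyArray`, `toy_module`, and `fallback_module` are hypothetical stand-ins so the snippet runs without any array library installed.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for real array modules (e.g. jax.numpy and numpy).
toy_module = SimpleNamespace(name="toy_module")
fallback_module = SimpleNamespace(name="fallback_module")


class ToyArray:
    """Hypothetical duck array that opts in to the protocol."""

    def __array_module__(self, types):
        # Claim the call only if every participating type is one we understand.
        if all(issubclass(t, ToyArray) for t in types):
            return toy_module
        return NotImplemented


def get_array_module(*arrays, default=fallback_module):
    """Simplified dispatcher modeled on NEP 37; an illustrative assumption."""
    implementing = [a for a in arrays if hasattr(type(a), "__array_module__")]
    if not implementing:
        # No argument opted in, so fall back to the default module.
        return default
    # Unique participating types, in first-seen order.
    types = tuple(dict.fromkeys(type(a) for a in implementing))
    for array in implementing:
        module = array.__array_module__(types)
        if module is not NotImplemented:
            return module
    raise TypeError("no common array module found")
```

Because nothing happens unless a caller invokes `get_array_module()`, existing functions keep their current behavior, which is the backwards-compatibility argument made above.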
@@ -4574,6 +4575,21 @@ def _operator_round(number, ndigits=None):
setattr(DeviceArray, "nbytes", property(_nbytes))


# Experimental support for NumPy's module dispatch with NEP-37.
# Currently requires https://github.com/seberg/numpy-dispatch
_JAX_ARRAY_TYPES = (DeviceArray, core.Tracer)
I'm not quite sure if this list should include `UnshapedArray` or one of JAX's other abstract array types. Their presence in the set of types checked by `jax.numpy.ndarray` suggests yes, but on the other hand I don't think I've ever seen them in user-supplied functions transformed with JAX.
I like limiting the set for now. If it becomes clear later that `UnshapedArray` or others are necessary, they are easy enough to add.
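As a rough illustration of why the set is easy to extend, the sketch below shows an `__array_module__` method driven by a tuple of recognized types. All names here (`ConcreteArray`, `TracedArray`, `AbstractArray`, `toy_numpy`, `_RECOGNIZED_TYPES`) are hypothetical stand-ins, not JAX's classes, and the method body is an assumption based on the pattern in NEP 37 rather than the code in this PR.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the module the library would return (e.g. jax.numpy).
toy_numpy = SimpleNamespace(name="toy_numpy")


class ConcreteArray:
    """Stand-in for a concrete array type."""


class TracedArray:
    """Stand-in for a tracer-like array type."""


class AbstractArray:
    """Stand-in for an abstract array type that is deliberately left out."""


# Extending support later only requires adding a type to this tuple.
_RECOGNIZED_TYPES = (ConcreteArray, TracedArray)


def _array_module(self, types):
    # Return the module only when every participating type is recognized.
    if all(issubclass(t, _RECOGNIZED_TYPES) for t in types):
        return toy_numpy
    return NotImplemented


# Attach the same method to each recognized type.
for _cls in _RECOGNIZED_TYPES:
    _cls.__array_module__ = _array_module
```

With this layout, a call that mixes in `AbstractArray` gets `NotImplemented` back until that type is added to the tuple.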
Looks good! I'm in favor of provisionally adding this to help motivate this kind of feature in the broader Python array ecosystem.
On the JAX side, one detail we may need to think about is this from the Array module contract:
> Unimplemented functionality should simply be omitted (e.g., accessing an unimplemented function should raise `AttributeError`).

This is counter to JAX's current strategy, which is to define unimplemented functions but raise `NotImplementedError` when they are called. I'm not sure how important that detail is. I think it's probably OK to push forward as-is and address that if it becomes an issue.
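One place the two strategies differ in practice is feature detection with `hasattr()`. The following sketch contrasts them using two toy modules. `omitting`, `stubbing`, and `supports` are hypothetical names made up for the illustration, and `median` is just an arbitrary example of an unimplemented function.

```python
from types import ModuleType

# Strategy A (NEP 37 contract): the unimplemented function is simply absent.
omitting = ModuleType("omitting_module")
omitting.add = lambda x, y: x + y

# Strategy B (stub that raises): the name exists, but calling it fails.
stubbing = ModuleType("stubbing_module")
stubbing.add = lambda x, y: x + y


def _unimplemented_median(*args, **kwargs):
    raise NotImplementedError("median is not implemented")


stubbing.median = _unimplemented_median


def supports(module, name):
    """Feature detection as generic code might write it."""
    return hasattr(module, name)
```

Under strategy A, `supports(omitting, "median")` is false, so generic code can choose a fallback up front. Under strategy B the check passes and the failure only appears when the function is called.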
I don't think this detail of NEP-37 (the precise error message) matters too much. Possibly NEP-37 should be updated, since it can indeed be helpful to distinguish between a spelling error vs. an error from a function that just hasn't been implemented. Perhaps the cleanest way to do this would be a module-level `__getattr__`:

```python
# in jax/numpy/__init__.py
def __getattr__(name):
    import numpy
    if hasattr(numpy, name):
        raise AttributeError(f'numpy.{name} has not been implemented yet in jax.numpy')
    else:
        raise AttributeError(f"module 'jax.numpy' has no attribute {name}")
```
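The same idea can be tried without JAX or NumPy installed. The sketch below is a stand-in, not the proposed JAX change: it builds a toy module `toy_math` (a hypothetical name) that implements only `sqrt` and uses the standard library's `math` as the reference namespace, relying on the module-level `__getattr__` hook from PEP 562 (Python 3.7+).

```python
import math
from types import ModuleType

# Toy module that "implements" only one function from the reference namespace.
toy_math = ModuleType("toy_math")
toy_math.sqrt = math.sqrt


def _toy_getattr(name):
    # Called only when normal attribute lookup on the module fails.
    if hasattr(math, name):
        raise AttributeError(f"math.{name} has not been implemented yet in toy_math")
    raise AttributeError(f"module 'toy_math' has no attribute {name!r}")


# A __getattr__ stored in the module namespace is picked up per PEP 562.
toy_math.__getattr__ = _toy_getattr


def describe_missing(name):
    """Return the AttributeError message for a lookup, or None if it succeeds."""
    try:
        getattr(toy_math, name)
    except AttributeError as err:
        return str(err)
    return None
```

Both failure cases still raise `AttributeError`, so `hasattr(toy_math, "cos")` stays false as the NEP 37 contract expects, while the message tells a misspelling apart from a known but unimplemented function.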
@shoyer Thank you for working on this! I am looking forward to trying this out.
`__array_module__` is an experimental protocol for "duck array" compatibility that indicates how to find a "numpy compatible" module. The hope is to make it easier to write generic code that works across a range of array libraries. A full example, which should work equally well for TF-NumPy as for JAX, can be found at jax-ml/jax#4076. More examples and motivation for this protocol can be found at https://numpy.org/neps/nep-0037-array-module.html.

This design has not yet been finalized in NumPy, so at present it requires using the experimental `numpy_dispatch` module: https://github.com/seberg/numpy-dispatch.

Unlike NumPy's `__array_ufunc__` and `__array_function__` protocols, `__array_module__` by design has no backwards compatibility consequences. The protocol only controls the behavior of `numpy_dispatch.get_array_module()` or `numpy.get_array_module()` -- it does not change any existing NumPy functions.

PiperOrigin-RevId: 328181308
Change-Id: I98b648a59709ed3bc295aaf97f471a1ff7ae1f05
xref #1565

`__array_module__` (see NEP 37) is an experimental alternative to `__array_function__` and `__array_ufunc__` for "duck array" compatibility with NumPy that promises to be much less invasive.

The example usage is a generic version of `np.stack()` written on top of `concatenate`, `asarray`, `shape` attributes and array indexing.

For now, you need to use `numpy_dispatch.get_array_module()` to try this feature (with https://github.com/seberg/numpy-dispatch), but in the long term the hope is to merge this into NumPy proper as `np.get_array_module()`.

My reasoning for merging it into JAX (on an experimental basis with no guarantees, of course) is that:

1. It's not invasive -- the implementation is small and self-contained.
2. No backwards compatibility issues. Unlike `__array_function__` and `__array_ufunc__`, `__array_module__` will always require an explicit opt-in by libraries that use it by calling `get_array_module()`.
3. Other NumPy developers want evidence that this is actually feasible.
4. Scikit-Learn developers like @thomasjpfan are interested in exploring support for scikit-learn on top of NumPy-like libraries such as JAX, and experimental support for this protocol will make that easier.

Note: this PR does add `numpy-dispatch` as an optional testing requirement in order to verify that this works. If desired, we could remove this from CI, but installing numpy-dispatch appears to only add about 7 extra seconds of build time.