
an experiment in handling instances with __jax_array__ #4725

Merged
merged 3 commits into master from handle-dunder-array-classes
Dec 16, 2020

Conversation

mattjj
Collaborator

@mattjj mattjj commented Oct 28, 2020

Before raising an error on an unrecognized type, first check if the object defines a __jax_array__ method. If it does, call it!

This provides a way for custom types to be auto-converted to JAX-compatible types.

We didn't add a check for __array__ because that might entail a significant change in behavior. For example, we'd be auto-converting tf.Tensor values. Maybe it's better to remain loud in those cases.

Implementing this method is not sufficient for a type to be duck-typed enough for use with jax.numpy. But it may be necessary. That is, someone trying to add a duck-typed array to be used with JAX identified a need for __jax_array__ or similar.

This feature is experimental, so it may disappear, change arbitrarily, or never be documented.

@google-cla google-cla bot added the cla: yes label Oct 28, 2020
@lukepfister
Contributor

This would be very helpful for some work I'm doing.

However, I'd expect jnp.sin(a) to also return an AlexArray. Thoughts?

@mattjj
Collaborator Author

mattjj commented Oct 28, 2020

@lukepfister hmm, I think that would be both very hard to implement and also not consistent with NumPy's handling of __array__.

@lukepfister
Contributor

@mattjj Ah, fair enough. I was thinking of the __array_ufunc__ and __array_wrap__ functionality. Then the NumPy equivalent would indeed return another AlexArray.

Would love to know if something similar is possible with jax.
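For context, the NumPy behavior being referenced can be sketched as follows (a minimal NEP 13-style `__array_ufunc__` implementation; the `AlexArray` class here is illustrative, not the actual one under discussion):

```python
import numpy as np

class AlexArray:
    def __init__(self, val):
        self.val = np.asarray(val)
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Unwrap AlexArray inputs, apply the underlying ufunc, re-wrap.
        unwrapped = [x.val if isinstance(x, AlexArray) else x for x in inputs]
        return AlexArray(getattr(ufunc, method)(*unwrapped, **kwargs))

out = np.sin(AlexArray([0.0, np.pi / 2]))
print(type(out).__name__)  # AlexArray
```

With this protocol, NumPy ufuncs like `np.sin` defer to the wrapper class, which is why the NumPy equivalent returns another `AlexArray`.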

@mattjj
Collaborator Author

mattjj commented Oct 28, 2020

@shoyer can probably help educate me about those angles! (He's probably explained them to me already and I forgot...)

The aim in this particular PR is just specifying the one-way translation from some instance of a class defining __jax_array__ to a JAX value, though it may be a first step in thinking along the lines of more general overloading.

@shoyer
Collaborator

shoyer commented Oct 29, 2020

@mattjj Ah, fair enough. I was thinking of the __array_ufunc__ and __array_wrap__ functionality. Then the NumPy equivalent would indeed return another AlexArray.

Would love to know if something similar is possible with jax.

We certainly could implement something like this. Both PyTorch and TensorFlow have simplified versions of NumPy's protocols for overriding functions.

That said, there are costs to flexibility:

  • checking for protocols like this adds a (small) amount of overhead to every function call (this mattered for NumPy, but the incremental cost is probably negligible compared to JAX's existing overhead from jit)
  • it adds a layer of indirection into the JAX source code
  • it adds a layer of indirection into user code, which now may or may not be using JAX's original implementations

On the whole, I think it might be easier to try using something like NumPy's __array_module__ protocol for this use-case, which is already implemented on JAX arrays: #4076. AlexArray could implement NumPy's protocol to define a new implementation of the numpy namespace on AlexArray objects.

In theory, we could make something like __jax_module__ in the same style as __array_module__ to enable generic versions of non-NumPy JAX functions, but I'm not sure it's really worth the trouble (see also #4117).
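The dispatch style being referenced can be sketched as follows (a simplified stand-in for NEP 37's proposed numpy.get_array_module; not NumPy's or JAX's actual implementation):

```python
import numpy as np

def get_array_module(*args, default=np):
    # Ask each argument which module implements the array API for it;
    # fall back to plain NumPy when nothing overrides.
    arg_types = tuple(type(a) for a in args)
    for arg in args:
        if hasattr(type(arg), "__array_module__"):
            module = arg.__array_module__(arg_types)
            if module is not NotImplemented:
                return module
    return default

xp = get_array_module(np.ones(3))
print(xp is np)  # True
```

A class like AlexArray could then return its own namespace from `__array_module__`, giving it a full implementation of the numpy API without JAX needing per-function hooks.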

@mattjj mattjj marked this pull request as ready for review December 15, 2020 01:13
Collaborator

@jakevdp jakevdp left a comment


This looks pretty good, but it's not clear to me whether this is the right approach in the long run. Should we instead follow numpy and dispatch these kinds of things via __array_function__?

jax/_src/numpy/lax_numpy.py
jax/interpreters/xla.py
tests/api_test.py
@mattjj
Collaborator Author

mattjj commented Dec 15, 2020

I'm not familiar with __array_function__, but after reading NEP 18 briefly my thinking is that this change is about something pretty different: this __jax_array__ change is about allowing user data types to be compatible with jax.numpy and jax.lax functions, whereas NEP 18 is about making numpy functions overloadable. Am I misunderstanding?

@shoyer
Collaborator

shoyer commented Dec 15, 2020

I think __array_function__/__array_module__ are for a complementary but distinct purpose: making it possible to overload the numpy (or perhaps jax.numpy) namespaces.

Potentially conflicting behavior would arise only if a custom array object adds __jax_array__ now and later adds __jax_array_function__, which would change the behavior of JAX functions on that object from coercing it to dispatching to a separate implementation.

@jakevdp
Collaborator

jakevdp commented Dec 15, 2020

Ah, you're right - I guess I was actually thinking of the __array_ufunc__ mechanism. The way the test case is structured made me think that this would be the more natural approach, though it would require a much deeper change to JAX.

If this is primarily about making it possible for jnp.array(x) to work for user-defined x, then this looks reasonable.

If this is primarily about making it so e.g. jnp.sin(x) works for user-defined x, then I think piggy-backing on the __array_ufunc__ mechanism might make more sense.

@AlexeyKurakin

@jakevdp
I was the original requester of this feature, and the initial motivation was to allow jnp.sin(x) to work with an x that is an instance of a user-defined class. In other words, to allow user-defined classes to be "automatically cast" to JAX types.

@mattjj
Collaborator Author

mattjj commented Dec 15, 2020

If this is primarily about making it so e.g. jnp.sin(x) works for user-defined x, then I think piggy-backing on the __array_ufunc__ mechanism might make more sense.

I think __array_ufunc__ might also be for something slightly different: it'd let the user data type decide how to dispatch a numpy (not jax.numpy) ufunc. But we're not trying to add that capability; we just want user data types to be able to present that they know how to turn themselves into jax arrays.

@AlexeyKurakin

I'm also not familiar with __array_ufunc__, but I think __jax_array__ is a reasonable way to achieve what I described.

@mattjj
Collaborator Author

mattjj commented Dec 15, 2020

It's really helpful to check through all this, since these subtle differences are pretty confusing! I think I puzzled through this 1.5 months ago when we first started this PR, but didn't write it down for posterity. So it's good to double-check!

@shoyer
Collaborator

shoyer commented Dec 15, 2020

Conceptually, I would put __array_ufunc__ in the same category as __array_function__ for overriding the behavior of functions in a namespace.

I'm not sure we want that behavior, but certainly making it possible to explicit convert arbitrary objects into JAX arrays seems useful.

@jakevdp
Collaborator

jakevdp commented Dec 15, 2020

Sounds good, I guess I was mainly just confused by the test case calling jnp.sin rather than jnp.array - it made me think the motivation for this change was different than what it is.

@mattjj
Collaborator Author

mattjj commented Dec 16, 2020

I wrote the test that way because we don't want to require people to call jnp.array to do the conversion. If we required that, then folks might as well just call their own conversion function. This is just a convenience layer so that e.g. Objax user code can avoid explicit conversion to jax arrays.

@jakevdp
Collaborator

jakevdp commented Dec 16, 2020

Got it – in that case, we'll have to do a pass on lax_numpy.py to make certain we're calling array or asarray on relevant inputs: I don't think we currently do that in all cases.

@mattjj
Collaborator Author

mattjj commented Dec 16, 2020

No, actually this PR covers what we need. We don't need to call array or asarray because we check for the __jax_array__ method at a lower level, namely on our dispatch path (that's the change in xla.py) and in core.get_aval (that's the core.py change). That's how jnp.sin works now!
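The low-level check can be sketched like this (an illustrative reconstruction, not JAX's actual core.py: an abstract-value lookup that falls back to __jax_array__ before erroring):

```python
import numpy as np

# Map recognized Python types to a toy "abstract value" summary.
pytype_aval_mappings = {
    np.ndarray: lambda x: ("ShapedArray", x.shape, x.dtype),
}

def get_aval(x):
    handler = pytype_aval_mappings.get(type(x))
    if handler is not None:
        return handler(x)
    # The fallback: unwrap via __jax_array__ and retry.
    if hasattr(x, "__jax_array__"):
        return get_aval(x.__jax_array__())
    raise TypeError(f"Value {x!r} of type {type(x)} is not a valid JAX type")

class Wrapper:
    def __init__(self, val):
        self.val = np.asarray(val)
    def __jax_array__(self):
        return self.val

print(get_aval(Wrapper([1.0, 2.0])))
```

Because the check lives in the abstraction machinery rather than in each jnp function, anything that reaches the dispatch path (like jnp.sin) benefits automatically.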

@jakevdp
Collaborator

jakevdp commented Dec 16, 2020

Agreed for anything that calls directly into lax. But there are routines in jax.numpy that don't. For example:

x = jnp.array(1)
y = AlexArray(x)

print(jnp.isscalar(x))
# True
print(jnp.isscalar(y))
# False

Similarly for jnp.dtype, jnp.shape, jnp.size, etc.

@mattjj
Collaborator Author

mattjj commented Dec 16, 2020

You're right that there might still be some unexpected gaps, though I expect we can just wait for @AlexeyKurakin to tell us about any.

For these particular examples, it's not completely clear to me that we want jnp.isscalar (or jnp.dtype, jnp.shape, etc) to change their behavior based on the existence of __jax_array__. Maybe those just show that the author of AlexArray must do more work to sufficiently duck-type JAX arrays, i.e. more than just implementing __jax_array__? Perhaps some of those start working once there's a shape attribute available?

In other words, it does seem like a nice invariant if jnp.shape(x) == jnp.shape(jnp.array(x)), but I wonder if that's up to the person doing the duck-typing rather than up to JAX.

(By the way, this is tangential, but I found it surprising based on the function name alone that dtypes.is_python_scalar(jnp.array(1)) is True! I think I get the reasoning, but it is surprising. From #4737 AIUI.)

@jakevdp
Collaborator

jakevdp commented Dec 16, 2020

Agreed that those functions are a bit of a gray area in themselves, but throughout lax_numpy.py we trigger different code paths based on their outputs, so I anticipate anyone using this feature for anything beyond basic array conversion will be in for some surprises unless we put a concerted effort into handling and testing this case throughout.

@mattjj
Collaborator Author

mattjj commented Dec 16, 2020

Well, again, it's not clear to me whether we need to do anything there, or if that's more up to the user writing the custom data type to do better duck typing.

It seems that jnp.isscalar is the odd duck:

import jax.numpy as jnp
from jax import jit, grad

class AlexArray:
  def __init__(self, jax_val):
    self.jax_val = jax_val
  def __jax_array__(self):
    return self.jax_val
  dtype = property(lambda self: self.jax_val.dtype)
  shape = property(lambda self: self.jax_val.shape)

x = jnp.array(1)
a = AlexArray(x)

for f in [jnp.dtype, jnp.shape, jnp.size]:
  print(f(x))
  print(f(a))
  print()

The output is:

int32
int32

()
()

1
1

I wonder if the implementation of dtypes.is_python_scalar should just be tweaked. I'll add a change to this PR.

In any case, let's leave the gap analysis to @AlexeyKurakin to share with us as he experiments with this feature, through issues and PRs. This experimental feature is not meant to entail sweeping changes up-front; if it did, we'd just drop it entirely! My guess is Alex would rather we land and experiment.

@mattjj
Collaborator Author

mattjj commented Dec 16, 2020

Pushed 5eb3685 to change dtypes.is_python_scalar and to test the specific issues @jakevdp brought up so far. WDYT about merging?

@jakevdp
Collaborator

jakevdp commented Dec 16, 2020

One note: size() above only works because the underlying array is size 1. If it were anything else, the size would not match and a fair number of jax.numpy routines would work incorrectly (anything relying on ravel(), for instance, which depends on calling size()). And the incompatibilities noted above are by no means comprehensive: they're just a few things that immediately came to mind.

I'm fine merging this for now, because I don't think we'll have many users depending on it. But it will require a fair bit more effort to make this something we can recommend for anything beyond simple array conversions.

@mattjj
Collaborator Author

mattjj commented Dec 16, 2020

Ah, but size can be handled just by exposing a size attribute, as in

class AlexArray:
  def __init__(self, jax_val):
    self.jax_val = jax_val
  def __jax_array__(self):
    return self.jax_val
  dtype = property(lambda self: self.jax_val.dtype)
  shape = property(lambda self: self.jax_val.shape)
  size = property(lambda self: self.jax_val.size)

So that's up to the user duck-typing, not up to JAX, IMO. We get to make up the rules of our own duck-typing game, after all :)

But if we ever do, I think it will require a fair bit more effort to make it something we can recommend for anything beyond simple array conversions.

This might be true, but it might not be; it's not clear to me yet, and the examples so far have all been pretty easy.

@mattjj mattjj added the pull ready Ready for copybara import and testing label Dec 16, 2020
@jakevdp
Collaborator

jakevdp commented Dec 16, 2020

Sure, but what else needs to be handled by the user in order to make JAX work properly with objects of this type? I can't answer that right now.

@jakevdp
Collaborator

jakevdp commented Dec 16, 2020

As a user, I would find it surprising that I need to implement shape, dtype, size, ??, ??, and ?? along with __jax_array__ in order to make things work correctly. Why not just __jax_array__ alone?

@mattjj
Collaborator Author

mattjj commented Dec 16, 2020

Well, since jnp.shape, jnp.size, and jnp.dtype are just aliases to np.shape, np.size, and np.dtype, to make the jnp versions work one has to set up on the new type exactly what is needed to make the np versions work. We're not changing anything here.

The new method __jax_array__ is for only the narrow purpose of converting to a jax array when a jax array value is needed, and nothing more. It has no implications for how jnp.shape, jnp.size, jnp.dtype etc work.

I wonder if perhaps you're reasoning by an analogy between how the NumPy __array__ method works and how __jax_array__ could work. But that analogy doesn't apply; we're choosing a narrower definition of what __jax_array__ does. In other words, when you ask "Why not just __jax_array__ alone?", that's just not the spec we're going for here. And the reasons you bring up are indeed reasons why we shouldn't pursue that spec, namely that it'd be a lot more work.
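The aliasing point can be checked directly in NumPy: np.shape and np.size read attributes on the object before falling back to array conversion, so supplying those attributes is what makes the jnp versions work, independently of __jax_array__. A small illustration:

```python
import numpy as np

class Duck:
    # np.shape and np.size return these attributes directly when present,
    # falling back to np.asarray(...) only when they are missing.
    shape = (3, 4)
    size = 12

print(np.shape(Duck()))  # (3, 4)
print(np.size(Duck()))   # 12
```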

@jakevdp
Collaborator

jakevdp commented Dec 16, 2020

The other reason I'm hesitant here is because, on balance, the added capability does not seem commensurate with the support burden. A user could just as easily create a to_jax() method and call it when they want to use a jax function. But instead of that we're creating a new jax entrypoint with an intended use that currently has no test coverage, and no estimate of the work it will take to support it properly.

On balance, it seems to me it would be better to recommend users call to_jax(). Alternatively, if we want to support __jax_array__(), we should do a good accounting of what it would actually take to support it fully before introducing an incomplete implementation.

@copybara-service copybara-service bot merged commit 0092692 into master Dec 16, 2020
@mattjj
Collaborator Author

mattjj commented Dec 16, 2020

the added capability does not seem commensurate with the support burden

no estimate of the work it will take to support it properly

What's the support burden?

To be clear, my estimate is that this won't take any additional work to support properly, as this PR is by definition all the proper support needed (modulo bugs). That is, there are no additional features in scope, i.e. no changes to how jnp.shape, jnp.size, jnp.dtype etc work.

A user could just as easily create a to_jax() method and call it when they want to use a jax function.

I believe this was discussed in some detail in the original chat thread, with me as the original proponent! But I came to believe I was wrong, and the goal now is to figure out whether we can easily avoid requiring all Objax user code to have this kind of explicit conversion.

If this change proves to have a support burden, we'll remove it. The whole point here is to experiment, as per the PR title. That's also why there's no documentation. If you have examples of a support burden, please provide them, but otherwise it's just speculation, and speculation on which we disagree.

@mattjj mattjj deleted the handle-dunder-array-classes branch December 16, 2020 06:07
@jakevdp
Collaborator

jakevdp commented Dec 16, 2020

OK, sounds good. My understanding had been that the goal here was that a custom object could be used in place of a JAX array in all jnp and lax functions, and the results would be the same as if you used the underlying array.

It sounds like you're saying that's not the goal or the promise of this feature, in which case this is fine. We should just make sure to note that clearly when documenting it.

@jakevdp
Collaborator

jakevdp commented Dec 16, 2020

I wanted to get a more quantitative idea of the aspects of this that I was concerned about, so I replaced the _CheckAgainstNumpy test utility with this:

  def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
                         check_dtypes=True, tol=None,
                         canonicalize_dtypes=True):
    args = args_maker()
    
    import jax.numpy as jnp
    class AlexArray:
      def __init__(self, jax_val):
        self.jax_val = jax_val
      def __jax_array__(self):
        return self.jax_val
      dtype = property(lambda self: self.jax_val.dtype)
      shape = property(lambda self: self.jax_val.shape)
      size = property(lambda self: self.jax_val.size)
      ndim = property(lambda self: self.jax_val.ndim)

    lax_args = [AlexArray(jnp.array(arg)) if isinstance(arg, jnp.ndarray) else arg for arg in args]
    lax_ans = lax_op(*lax_args)
    numpy_ans = numpy_reference_op(*args)
    self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes,
                        atol=tol, rtol=tol,
                        canonicalize_dtypes=canonicalize_dtypes)

When I ran the test suite with

$ pytest -n auto tests/lax_numpy_test.py

The result is:

801 failed, 2797 passed, 419 skipped

From a spot-check of the eight hundred test failures, these appear to be real failures: i.e. places where __jax_array__ is not properly handled by the jax.numpy API.

This is the support burden I'm talking about.

@mattjj
Collaborator Author

mattjj commented Dec 16, 2020

Wow, thanks for doing that! That could be very helpful as a test pattern for someone wanting to duck-type ndarrays for use with jax.numpy.

But is it just testing the extent to which the AlexArray implementation is incomplete? That is, those test failures might not have anything to do with __jax_array__, and instead just show that one needs to do more work to duck-type arrays. Indeed, we'd expect it to need things like __add__!

The goal here is not to implement another duck-typed ndarray, or claim that just implementing __jax_array__ is in any way sufficient for duck typing. As you've shown, it certainly is not!
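For instance, an AlexArray author could fill a gap like __add__ by unwrapping through __jax_array__ (a hedged sketch; NumPy stands in for jax.numpy so it runs standalone):

```python
import numpy as np

class AlexArray:
    def __init__(self, val):
        self.val = np.asarray(val)
    def __jax_array__(self):
        return self.val
    dtype = property(lambda self: self.val.dtype)
    shape = property(lambda self: self.val.shape)
    def __add__(self, other):
        # Unwrap the other operand if it also knows how to convert itself.
        if hasattr(other, "__jax_array__"):
            other = other.__jax_array__()
        return AlexArray(self.val + other)

out = AlexArray([1, 2]) + AlexArray([3, 4])
print(out.val)  # [4 6]
```

Each such dunder is the duck-type author's responsibility under this PR's narrow spec; __jax_array__ only supplies the conversion they delegate to.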

My understanding had been that the goal here was that a custom object could be used in-place of a JAX array in all jnp and lax functions

Ah, perhaps this is the crux of the issue! Actually, that's not the goal with this PR. That's Alex's ultimate goal, but in this PR we're just trying to add one missing piece needed so that he can finish his own duck type class. There's still plenty more work needed to implement a complete duck, but that's on his end.

We should just make sure to note that clearly when documenting it.

Well, I have no plans to document this. Right now it's just an experiment to see if it helps Objax. Think of it like something brand new in jax.experimental, except it can't literally live in jax.experimental.

I'll revise the OP to clarify that this is an experiment.

(It's also already been rolled back due to test failures. I've got to look into those...)

@mattjj mattjj changed the title an experiment in handling instances with __array__ an experiment in handling instances with __jax_array__ Dec 16, 2020
@AlexeyKurakin

I opened an issue to keep track of this change: #5356

@shoyer shoyer restored the handle-dunder-array-classes branch February 5, 2021 03:00
jax/interpreters/xla.py:

raise TypeError(f"No device_put handler for type: {type(x)}") from err
handler = device_put_handlers.get(type(x))
if handler:
x = canonicalize_dtype(x)
Collaborator

@shoyer shoyer Feb 5, 2021


It looks like this PR broke internal tests due to this line. In particular, some types like IntEnum are no longer accepted by device_put. For example, consider this currently valid behavior:

In [1]: import jax

In [2]: import enum

In [3]: class X(enum.IntEnum):
   ...:     Y = 1
   ...:     Z = 2
   ...:

In [4]: jax.device_put(X.Y)
Out[4]: DeviceArray(1, dtype=int32)
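A hedged reconstruction of the breakage (handlers, canonicalize_dtype, and device_put below are simplified stand-ins, not JAX's real implementations): handlers are keyed by exact type, so an IntEnum instance finds no handler unless it is first canonicalized to a NumPy scalar.

```python
import enum
import numpy as np

handlers = {np.ndarray: lambda x: x, np.int32: lambda x: x}

def canonicalize_dtype(x):
    # Canonicalization turns Python ints (including IntEnum members,
    # which subclass int) into a concrete NumPy scalar type.
    return np.int32(x) if isinstance(x, int) and not isinstance(x, bool) else x

def device_put(x):
    x = canonicalize_dtype(x)  # dropping this line breaks the lookup below
    handler = handlers.get(type(x))
    if handler is None:
        raise TypeError(f"No device_put handler for type: {type(x)}")
    return handler(x)

class X(enum.IntEnum):
    Y = 1

print(device_put(X.Y))  # 1
```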

Collaborator Author


Oh man, I think you nailed it! Is it just because I dropped the x = canonicalize_dtype(x) beforehand?

Collaborator


That was the only thing I saw that could have broken the existing code path.

@mattjj mattjj deleted the handle-dunder-array-classes branch February 6, 2021 04:13
mattjj added a commit that referenced this pull request Feb 6, 2021
Before raising an error on an unrecognized type, first check if the
object defines a __jax_array__ method. If it does, call it!

This provides a way for custom types to be auto-converted to
JAX-compatible types.

Implementing this method is not sufficient for a type to be duck-typed
enough for use with jax.numpy. But it may be necessary. That is, someone
trying to add a duck-typed array to be used with JAX identified a need
for __jax_array__ or similar. The user would still need to add lots of
other properties and methods, like dtype and shape attributes.

revives #4725 after it was rolled back. fixes #5356.
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Feb 7, 2021
Before raising an error on an unrecognized type, first check if the
object defines a __jax_array__ method. If it does, call it!

This provides a way for custom types to be auto-converted to
JAX-compatible types.

Implementing this method is not sufficient for a type to be duck-typed
enough for use with jax.numpy. But it may be necessary. That is, someone
trying to add a duck-typed array to be used with JAX identified a need
for __jax_array__ or similar. The user would still need to add lots of
other properties and methods, like dtype and shape attributes.

revives jax-ml#4725 after it was rolled back. fixes jax-ml#5356.
gnecula added a commit to gnecula/jax that referenced this pull request Oct 7, 2022
This relates to the long discussion in jax-ml#4725 and jax-ml#10065.
Labels
cla: yes pull ready Ready for copybara import and testing
5 participants