an experiment in handling instances with __jax_array__ #4725
Conversation
This would be very helpful for some work I'm doing. However, I'd expect
@lukepfister hmm, I think that would be both very hard to implement and also not consistent with NumPy's handling of
@MattJ Ah, fair enough. I was thinking of the `__array_ufunc__` and `__array_function__` protocols. Would love to know if something similar is possible with jax.
@shoyer can probably help educate me about those angles! (He's probably explained them to me already and I forgot...) The aim in this particular PR is just specifying the one-way translation from some instance of a class defining `__jax_array__` into a JAX type.
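For concreteness, here is a minimal sketch of the kind of class and one-way conversion being described; the `Wrapped` name and the exact call sites are illustrative, not taken from the PR:

```python
import jax.numpy as jnp

class Wrapped:
    """Hypothetical wrapper holding a JAX array; illustrates __jax_array__."""
    def __init__(self, value):
        self.value = value

    def __jax_array__(self):
        # The one-way conversion: produce a plain JAX array on request.
        return self.value

w = Wrapped(jnp.arange(3.0))
# With this change, passing w to a jax.numpy function is intended to trigger
# __jax_array__ implicitly, e.g. jnp.sin(w) acting like jnp.sin(w.value).
print(jnp.sin(w.value))
```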
We certainly could implement something like this. Both PyTorch and TensorFlow have simplified versions of NumPy's protocols for overriding functions. That said, there are costs to flexibility:
On the whole, I think it might be easier to try using something like NumPy's existing protocols. In theory, we could make something like
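As a point of comparison, NumPy's `__array_function__` protocol (NEP 18) lets a class intercept NumPy functions called on it; a toy sketch, with the `Logged` class purely illustrative:

```python
import numpy as np

class Logged:
    """Toy wrapper that logs NumPy functions dispatched to it via NEP 18."""
    def __init__(self, value):
        self.value = np.asarray(value)

    def __array_function__(self, func, types, args, kwargs):
        print(f"dispatching {func.__name__}")
        # Unwrap Logged arguments and defer to the ordinary NumPy behavior.
        unwrapped = tuple(a.value if isinstance(a, Logged) else a for a in args)
        return func(*unwrapped, **kwargs)

print(np.sum(Logged([1, 2, 3])))  # prints "dispatching sum", then 6
```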
This looks pretty good, but it's not clear to me whether this is the right approach in the long run. Should we instead follow numpy and dispatch these kinds of things via `__array_function__`?
I'm not familiar with `__array_function__`.
I think

Potentially conflicting behavior would arise only if a custom array object adds
Ah, you're right - I guess I was actually thinking of the

If this is primarily about making it possible for

If this is primarily about making it so e.g.
@jakevdp

I think

I'm also not familiar with
It's really helpful to check through all this, since these subtle differences are pretty confusing! I think I puzzled through this 1.5 months ago when we first started this PR, but didn't write it down for posterity. So it's good to double-check!
Conceptually, I would put

I'm not sure we want that behavior, but certainly making it possible to explicitly convert arbitrary objects into JAX arrays seems useful.
Sounds good, I guess I was mainly just confused by the test case calling `jnp.sin` rather than `jnp.array` - it made me think the motivation for this change was different from what it is.
I wrote the test that way because we don't want to require people to call `jnp.array` explicitly.
Got it – in that case, we'll have to do a pass on

No, actually this PR covers what we need. We don't need to call
Agreed for anything that calls directly into

```python
x = jnp.array(1)
y = AlexArray(x)
print(jnp.isscalar(x))
# True
print(jnp.isscalar(y))
# False
```

Similar for
You're right that there might still be some unexpected gaps, though I expect we can just wait for @AlexeyKurakin to tell us about any.

For these particular examples, it's not completely clear to me that we want

In other words, it does seem like a nice invariant if

(By the way, this is tangential, but I found it surprising based on the function name alone that
Agreed that those functions are a bit of a gray area in themselves, but throughout lax_numpy.py we trigger different code paths based on their outputs, so I anticipate anyone using this feature for anything beyond basic array conversion will be in for some surprises unless we put a concerted effort into handling and testing this case throughout.
Well, again, it's not clear to me whether we need to do anything there, or if that's more up to the user writing the custom data type to do better duck typing. It seems that:

```python
import jax.numpy as jnp
from jax import jit, grad

class AlexArray:
  def __init__(self, jax_val):
    self.jax_val = jax_val
  def __jax_array__(self):
    return self.jax_val
  dtype = property(lambda self: self.jax_val.dtype)
  shape = property(lambda self: self.jax_val.shape)

x = jnp.array(1)
a = AlexArray(x)

for f in [jnp.dtype, jnp.shape, jnp.size]:
  print(f(x))
  print(f(a))
  print()
```
I wonder if the implementation of

In any case, let's leave the gap analysis to @AlexeyKurakin to share with us as he experiments with this feature, through issues and PRs. This experimental feature is not meant to entail sweeping changes up-front; if it did, we'd just drop it entirely! My guess is Alex would rather we land and experiment.
One note: I'm fine merging this for now, because I don't think we'll have many users depending on it. But it will require a fair bit more effort to make this something we can recommend for anything beyond simple array conversions.
Ah, but:

```python
class AlexArray:
  def __init__(self, jax_val):
    self.jax_val = jax_val
  def __jax_array__(self):
    return self.jax_val
  dtype = property(lambda self: self.jax_val.dtype)
  shape = property(lambda self: self.jax_val.shape)
  size = property(lambda self: self.jax_val.size)
```

So that's up to the user duck-typing, not up to JAX, IMO. We get to make up the rules of our own duck-typing game, after all :)
This might be true, but it might not be; it's not clear to me yet, and the examples so far have all been pretty easy.

Sure, but what else needs to be handled by the user in order to make JAX work properly with objects of this type? I can't answer that right now.
As a user, I would find it surprising that I need to implement
Well, since

The new method

I wonder if perhaps you're reasoning by an analogy between how the NumPy
The other reason I'm hesitant here is because, on balance, the added capability does not seem commensurate with the support burden. A user could just as easily create a

On balance, ISTM it would be better to recommend users call
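For contrast, the explicit-conversion style suggested just above would look roughly like the following; this is only a sketch of the trade-off, reusing the kind of wrapper class from the earlier snippets, not code from the PR:

```python
import jax.numpy as jnp

class AlexArray:  # same shape as the class in the snippets above
    def __init__(self, jax_val):
        self.jax_val = jax_val
    def __jax_array__(self):
        return self.jax_val

a = AlexArray(jnp.array([1.0, 2.0, 3.0]))

# Explicit conversion: user code unwraps the value at every call site.
print(jnp.sum(a.jax_val))

# Implicit conversion (what this PR enables): jnp.sum(a) would itself call
# a.__jax_array__(), so no unwrapping appears in user code.
```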
What's the support burden? To be clear, my estimate is that this won't take any additional work to support properly, as this PR is by definition all the proper support needed (modulo bugs). That is, there are no additional features in scope, i.e. no changes to how
I believe this was discussed in some detail in the original chat thread, with me as the original proponent! But I came to believe I was wrong, and the goal now is to figure out whether we can easily avoid requiring that all Objax user code have this kind of explicit conversion. If this change proves to have a support burden, we'll remove it. The whole point here is to experiment, as per the PR title. That's also why there's no documentation. If you have examples of a support burden, please provide them, but otherwise it's just speculation, and speculation on which we disagree.
OK, sounds good. My understanding had been that the goal here was that a custom object could be used in place of a JAX array in all `jax.numpy` functions. It sounds like you're saying that's not the goal or the promise of this feature, in which case this is fine. We should just make sure to note that clearly when documenting it.
I wanted to get a more quantitative idea of the aspects of this that I was concerned about, so I replaced the `_CheckAgainstNumpy` helper in the test utilities with the following:

```python
def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
                       check_dtypes=True, tol=None,
                       canonicalize_dtypes=True):
  args = args_maker()

  import jax.numpy as jnp

  class AlexArray:
    def __init__(self, jax_val):
      self.jax_val = jax_val
    def __jax_array__(self):
      return self.jax_val
    dtype = property(lambda self: self.jax_val.dtype)
    shape = property(lambda self: self.jax_val.shape)
    size = property(lambda self: self.jax_val.size)
    ndim = property(lambda self: self.jax_val.ndim)

  lax_args = [AlexArray(jnp.array(arg)) if isinstance(arg, jnp.ndarray) else arg
              for arg in args]
  lax_ans = lax_op(*lax_args)
  numpy_ans = numpy_reference_op(*args)
  self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes,
                      atol=tol, rtol=tol,
                      canonicalize_dtypes=canonicalize_dtypes)
```

When I ran the test suite with this modification, the result was roughly eight hundred test failures.
From a spot-check of the eight hundred test failures, these appear to be real failures: i.e. places where `jax.numpy` functions do not handle the wrapped objects correctly.

This is the support burden I'm talking about.
Wow, thanks for doing that! That could be very helpful as a test pattern for someone wanting to duck-type ndarrays for use with `jax.numpy`.

But is it just testing the extent to which the `AlexArray` class defined there is a complete duck type?

The goal here is not to implement another duck-typed ndarray, or claim that just implementing `__jax_array__` is sufficient for that.
Ah, perhaps this is the crux of the issue! Actually, that's not the goal with this PR. That's Alex's ultimate goal, but in this PR we're just trying to add one missing piece needed so that he can finish his own duck type class. There's still plenty more work needed to implement a complete duck, but that's on his end.
Well, I have no plans to document this. Right now it's just an experiment to see if it helps Objax. Think of it like something brand new in

I'll revise the OP to clarify that this is an experiment. (It's also already been rolled back due to test failures. I've got to look into those...)
I opened an issue to keep track of this change: #5356
```python
raise TypeError(f"No device_put handler for type: {type(x)}") from err
handler = device_put_handlers.get(type(x))
if handler:
  x = canonicalize_dtype(x)
```
It looks like this PR broke internal tests due to this line. In particular, some types like `IntEnum` are no longer accepted by `device_put`; e.g., consider this currently valid behavior:

```python
In [1]: import jax

In [2]: import enum

In [3]: class X(enum.IntEnum):
   ...:     Y = 1
   ...:     Z = 2
   ...:

In [4]: jax.device_put(X.Y)
Out[4]: DeviceArray(1, dtype=int32)
```
Oh man, I think you nailed it! Is it just because I dropped the `x = canonicalize_dtype(x)` beforehand?
That was the only thing I saw that could have broken the existing code path.
Before raising an error on an unrecognized type, first check if the object defines a `__jax_array__` method. If it does, call it! This provides a way for custom types to be auto-converted to JAX-compatible types.

Implementing this method is not sufficient for a type to be duck-typed enough for use with `jax.numpy`. But it may be necessary. That is, someone trying to add a duck-typed array to be used with JAX identified a need for `__jax_array__` or similar. The user would still need to add lots of other properties and methods, like `dtype` and `shape` attributes.

Revives #4725 after it was rolled back. Fixes #5356.
This relates to the long discussion in jax-ml#4725 and jax-ml#10065.
Before raising an error on an unrecognized type, first check if the object defines a `__jax_array__` method. If it does, call it! This provides a way for custom types to be auto-converted to JAX-compatible types.

We didn't add a check for `__array__` because that might entail a significant change in behavior. For example, we'd be auto-converting `tf.Tensor` values. Maybe it's better to remain loud in those cases.

Implementing this method is not sufficient for a type to be duck-typed enough for use with `jax.numpy`. But it may be necessary. That is, someone trying to add a duck-typed array to be used with JAX identified a need for `__jax_array__` or similar.

This feature is experimental, so it may disappear, change arbitrarily, or never be documented.
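A rough sketch of the lookup-then-fallback pattern described above; the handler table and function names here are hypothetical stand-ins, not JAX's actual `device_put` internals:

```python
# Hypothetical sketch: try a per-type handler first, then fall back to
# __jax_array__ before raising. Names here are illustrative only.
_handlers = {
    int: lambda x: ("int handler", x),
    float: lambda x: ("float handler", x),
}

def convert(x):
    handler = _handlers.get(type(x))
    if handler is not None:
        return handler(x)
    # Before raising an error on an unrecognized type, check whether the
    # object can turn itself into something we understand.
    if hasattr(x, "__jax_array__"):
        return convert(x.__jax_array__())
    raise TypeError(f"No conversion handler for type: {type(x)}")

class Wrapper:
    def __init__(self, value):
        self.value = value
    def __jax_array__(self):
        return self.value

print(convert(3))             # ('int handler', 3)
print(convert(Wrapper(2.5)))  # ('float handler', 2.5)
```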