Jax tracer TypeError in multiplication #2556
@Saransh-cpp, when you have access to the Awkward repo (see https://github.com/scikit-hep/awkward/invitations), please assign yourself to this issue. Possible outcomes for this issue include:
I'm just organizing responsibilities at this point.
I tried removing the intentional error, and everything seems to work:

```python
In [1]: import jax
   ...: import numpy as np
   ...:
   ...: a = np.array([1, 2, 3], dtype=float)
   ...:
   ...: def f(x):
   ...:     return np.sum(np.sum(x) * x)
   ...:
   ...: f(a)
   ...: jax.grad(f)(a)
Out[1]: Array([12., 12., 12.], dtype=float32)

In [2]: import jax
   ...: import awkward as ak
   ...:
   ...: ak.jax.register_and_check()
   ...:
   ...: a = ak.Array([[1.0, 2, 3], [5, 6]], backend="jax")
   ...:
   ...: def f(x):
   ...:     return ak.sum(ak.sum(x) * x)
   ...:
   ...: f(a)
   ...: jax.grad(f)(a)
Out[2]: <Array [[34.0, 34.0, 34.0], [34.0, 34.0]] type='2 * var * float32'>
```

Surprisingly, all the tests pass too. @agoose77, given that you tried this above but ended up closing the PR, could you please let me know why this intentional error is required? Thanks!
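The gradients printed in the two sessions can be checked by hand. Writing $S = \sum_k x_k$ for the inner sum (in the Awkward case, `ak.sum(x)` with no `axis` reduces over every entry of the jagged array), the function being differentiated is just $S^2$:

```latex
f(x) = \sum_j \Big(\sum_k x_k\Big)\, x_j = S \sum_j x_j = S^2,
\qquad
\frac{\partial f}{\partial x_i} = 2S \quad \text{for every } i.
```

With $S = 1 + 2 + 3 = 6$ this gives 12 for each entry, and with $S = 1 + 2 + 3 + 5 + 6 = 17$ for the jagged array it gives 34, matching both `Out[1]` and `Out[2]`.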
@Saransh-cpp, at the time I recall thinking that seeing tracers at that point was a symptom of a problem that we needed to fix, rather than something to solve by just removing the error. However, my memory of this is not brilliant. I would caution that our tests aren't very detailed here, but I nevertheless support removing the guard until you find something that tells you otherwise! Maybe I misremembered, or it's actually not a problem after all...
Oh, I see, thanks! I'll create a PR removing the error to see whether everything still works.
I was browsing through JAX issues and stumbled upon this. It looks like this check was added in #1763. The idea with v2 of Awkward was precisely to work around issues like these, where tracers could be treated as normal, filled, concrete buffers, although during intermediate reverse-mode differentiation operations they store only the "traced metadata" of JAX buffers. This check should be safe to remove. The only place that needs to distinguish between "Tracers" and "ConcreteBuffers" is the kernel invocation, where we leave the Python layer and go into the C layer.

(Going off-topic here.) When I wrote this, there was no way to write custom kernels that could be fed into XLA, the way CuPy custom kernels can be. However, they do provide that now here (Wow!). Maybe it's worthwhile to investigate if we can reuse the

One more interesting point here is that a lot of operations use

This is really interesting, and I am very happy to see some use cases pop up for the Awkward JAX integration. @Saransh-cpp, thank you for taking this up!
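To illustrate the tracer vs. concrete-buffer distinction described above, here is a minimal plain-JAX sketch with no Awkward involved. `to_host_buffer` is a hypothetical stand-in for "leaving the Python layer to go into the C layer"; it is not an Awkward or JAX API. The sketch assumes only standard JAX behavior: code built from `jax.numpy` operations runs the same on concrete arrays and on tracers, while asking a tracer for a real NumPy buffer raises `jax.errors.TracerArrayConversionError` (a `TypeError` subclass) under `jax.grad`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_host_buffer(x):
    # Hypothetical stand-in for handing data to a compiled (C) kernel:
    # such a kernel needs real memory, so we materialize a NumPy array.
    return np.asarray(x)


def traceable(x):
    # Only JAX operations: works on concrete arrays and on tracers alike.
    return jnp.sum(jnp.sum(x) * x)


def leaves_jax(x):
    # Crosses into "non-JAX" code before continuing.
    buf = to_host_buffer(x)
    return jnp.sum(x) + float(buf.sum())


x = jnp.array([1.0, 2.0, 3.0])

# Concrete input: materializing the buffer is fine.
concrete_value = leaves_jax(x)

# Under jax.grad the argument is a tracer; purely traceable code differentiates.
grad_traceable = jax.grad(traceable)(x)

# Converting the tracer to a NumPy buffer fails, since it holds no concrete data.
try:
    jax.grad(leaves_jax)(x)
    conversion_failed = False
except jax.errors.TracerArrayConversionError:
    conversion_failed = True

print(grad_traceable)
print(conversion_failed)
```

This mirrors the point in the comment: a guard that rejects tracers everywhere is stricter than necessary, because only the step that needs concrete memory actually cares.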
Interesting, thanks for the explanation! |
Version of Awkward Array
2.2.4
Description and code to reproduce
While trying out Awkward + JAX, I ran into the following:
results in
This seems to originate from the `ak.sum(x) * x` piece. A corresponding NumPy version (without any Awkward) works fine:

Relevant library versions: