
Jax tracer TypeError in multiplication #2556

Closed
alexander-held opened this issue Jul 2, 2023 · 6 comments · Fixed by #3013
Labels: autodiff (Issue related to auto-differentiation), bug (The problem described is something that must be fixed)

Comments

@alexander-held
Member

Version of Awkward Array

2.2.4

Description and code to reproduce

While trying out Awkward + JAX, I ran into the following:

import jax
import awkward as ak

ak.jax.register_and_check()

a = ak.Array([[1.0, 2, 3], [5, 6]], backend="jax")

def f(x):
    return ak.sum(ak.sum(x) * x)

f(a)
jax.grad(f)(a)

results in

[...]
TypeError: Jax tracers cannot be used with `ak.from_arraylib`

This error occurred while calling

    numpy.multiply.__call__(
        JVPTracer-instance
        <Array [[...], [...]] type='2 * var * float32'>
    )

This seems to originate from the ak.sum(x) * x piece.

A corresponding numpy version (without any awkward) works fine:

import jax
import numpy as np

a = np.array([1, 2, 3], dtype=float)

def f(x):
    return np.sum(np.sum(x) * x)

f(a)
jax.grad(f)(a)

Relevant library versions:

awkward                       2.2.4
awkward-cpp                   17
jax                           0.4.13
jaxlib                        0.4.13
numpy                         1.25.0
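For reference, the function in both reproductions is f(x) = sum(sum(x) * x), which is algebraically sum(x)**2. Its gradient is therefore 2 * sum(x) in every component: 12 for [1, 2, 3], and 34 for the Awkward example, whose values sum to 17. The sketch below is a sanity check on those expected numbers rather than part of the original report. It checks `jax.grad` against the closed form using `jax.numpy` directly, so neither NumPy's nor Awkward's dispatch is involved.

```python
import jax
import jax.numpy as jnp


def f(x):
    # sum(sum(x) * x) is algebraically sum(x) ** 2
    return jnp.sum(jnp.sum(x) * x)


def expected_grad(x):
    # d/dx_i of sum(x) ** 2 is 2 * sum(x), the same value for every i
    return jnp.full_like(x, 2 * jnp.sum(x))


# The NumPy example's values, and the flattened values of the Awkward example.
for values in ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0, 5.0, 6.0]):
    x = jnp.array(values)
    g = jax.grad(f)(x)
    print(values, "->", g)
    assert jnp.allclose(g, expected_grad(x))
```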
alexander-held added the "bug (unverified)" label (The problem described would be a bug, but needs to be triaged) on Jul 2, 2023
agoose77 added the "bug" label and removed the "bug (unverified)" label on Jul 2, 2023
jpivarski added the "autodiff" label on Oct 2, 2023
@jpivarski
Member

@Saransh-cpp, when you have access to the Awkward repo (see https://github.com/scikit-hep/awkward/invitations), please assign yourself to this issue. Possible outcomes for this issue include:

  • testing it and realizing that it's not broken anymore, and closing it without a fix
  • having some good justification for why it's not important anymore, and closing it without a fix
  • fixing it

I'm just organizing responsibilities at this point.

Saransh-cpp self-assigned this on Jan 20, 2024
@Saransh-cpp
Member

I tried removing the intentional error, and everything seems to work:

In [1]: import jax
   ...: import numpy as np
   ...: 
   ...: a = np.array([1, 2, 3], dtype=float)
   ...: 
   ...: def f(x):
   ...:     return np.sum(np.sum(x) * x)
   ...: 
   ...: f(a)
   ...: jax.grad(f)(a)
Out[1]: Array([12., 12., 12.], dtype=float32)

In [2]: import jax
   ...: import awkward as ak
   ...: 
   ...: ak.jax.register_and_check()
   ...: 
   ...: a = ak.Array([[1.0, 2, 3], [5, 6]], backend="jax")
   ...: 
   ...: def f(x):
   ...:     return ak.sum(ak.sum(x) * x)
   ...: 
   ...: f(a)
   ...: jax.grad(f)(a)
Out[2]: <Array [[34.0, 34.0, 34.0], [34.0, 34.0]] type='2 * var * float32'>

Surprisingly, all the tests pass too. @agoose77, given that you tried this above but ended up closing the PR, could you please let me know why this intentional error is required? Thanks!

@agoose77
Collaborator

agoose77 commented Feb 8, 2024

@Saransh-cpp, at the time, I recall thinking that seeing tracers at that point was a symptom of a problem that we needed to fix, rather than something to be solved by just removing the error. However, my memory of this is not brilliant.

I would caution that our tests aren't very detailed here, but I nevertheless support removing the guard until you find something that tells you otherwise! Maybe I misremembered, or it's actually not a problem after all...
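"Seeing tracers at that point" can be illustrated in plain JAX. Inside `jax.grad`, the differentiated function receives JAX tracer objects instead of concrete arrays, so a check of the form `isinstance(x, jax.core.Tracer)` fires only on the traced path. The sketch below uses a hypothetical guard, `reject_tracers`, that is loosely modelled on the kind of check behind the "Jax tracers cannot be used with `ak.from_arraylib`" error. It is not Awkward's actual implementation.

```python
import jax
import jax.numpy as jnp


def reject_tracers(buffer):
    # Hypothetical guard, loosely modelled on the check discussed in this thread:
    # refuse to treat a JAX tracer as if it were a concrete buffer.
    if isinstance(buffer, jax.core.Tracer):
        raise TypeError("Jax tracers cannot be used here")
    return buffer


def f_guarded(x):
    return jnp.sum(jnp.sum(reject_tracers(x)) * x)


def f_unguarded(x):
    return jnp.sum(jnp.sum(x) * x)


x = jnp.array([1.0, 2.0, 3.0])

# Outside of any transformation the guard sees a concrete array and passes.
print(f_guarded(x))

# Under jax.grad the argument is a tracer, so the guard raises...
try:
    jax.grad(f_guarded)(x)
except TypeError as err:
    print("guarded:", err)

# ...while the same computation without the guard differentiates fine.
print("unguarded:", jax.grad(f_unguarded)(x))
```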

@Saransh-cpp
Member

Oh, I see, thanks! I'll create a PR removing the error to see if everything goes right.

@sw15h
Collaborator

sw15h commented Mar 26, 2024

I was browsing through JAX issues and stumbled upon this. It looks like this check was added in #1763. The idea with v2 of Awkward was precisely to work around issues like these: tracers can be treated as normal "filled, concrete buffers", even though, in intermediate reverse-mode differentiation operations, they store only the "traced metadata" of JAX buffers. This check should be safe to remove. The only place that would need to distinguish between "Tracers" and "ConcreteBuffers" is the kernel invocation, where we leave the Python layer and enter the C layer.

(Going off-topic here:) When I wrote this, there was no way to write custom kernels that could be fed into XLA, the way CuPy custom kernels can be. However, they do provide that now (Wow!). Maybe it's worthwhile to investigate whether we can reuse the CuPy kernels, although we would have to write forward and backward implementations of each of the kernels.
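What "forward and backward implementations" of a kernel means can be sketched with `jax.custom_vjp`. The toy below is an assumption, not Awkward code. It represents a jagged array as flat `content` plus a per-element `segment_ids` array (a hypothetical stand-in for Awkward's offsets). It then defines a per-sublist sum whose forward pass uses `jax.ops.segment_sum` and whose hand-written backward pass gives each element its sublist's cotangent. In a real integration, the forward body would be where a custom kernel is called.

```python
import jax
import jax.numpy as jnp


def make_segmented_sum(segment_ids, num_segments):
    """Build a per-sublist sum with explicit forward and backward rules.

    segment_ids (hypothetical layout) maps each element of the flat content
    to the index of the sublist it belongs to.
    """

    @jax.custom_vjp
    def segmented_sum(content):
        # Forward "kernel": sum the flat content of each sublist.
        return jax.ops.segment_sum(content, segment_ids, num_segments=num_segments)

    def fwd(content):
        # No residuals are needed: the backward rule only uses segment_ids.
        return segmented_sum(content), None

    def bwd(residuals, cotangent):
        # Backward "kernel": every element receives its sublist's cotangent.
        return (cotangent[segment_ids],)

    segmented_sum.defvjp(fwd, bwd)
    return segmented_sum


# Toy jagged array [[1, 2, 3], [5, 6]] as flat content plus segment ids.
content = jnp.array([1.0, 2.0, 3.0, 5.0, 6.0])
segment_ids = jnp.array([0, 0, 0, 1, 1])
per_list_sum = make_segmented_sum(segment_ids, num_segments=2)
weights = jnp.array([10.0, 100.0])

print(per_list_sum(content))
print(jax.grad(lambda c: jnp.sum(weights * per_list_sum(c)))(content))
```

The closure over `segment_ids` keeps the integer index array out of the differentiated arguments, so the backward rule only has to return a cotangent for `content`.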

One more interesting point: a lot of operations use from_arraylib to convert back to a high-level Awkward Array, yet the tests were unable to catch this. That is on me, as I couldn't find enough bandwidth to write a good test suite for Awkward-JAX interop. This could be another good project to undertake.

This is really interesting, and I am very happy to see some use cases pop up for the Awkward JAX integration. @Saransh-cpp, thank you for taking this up!

@Saransh-cpp
Member

Interesting, thanks for the explanation!

5 participants