leaked trace error when using custom_jvp, scan, eqx.Module, and jit #745
Comments
Okay, one way to get around is to use
|
It turned out that using
|
The issue is this bit" You've marked an array (inside Try using |
Thank you for the suggestion! I modified the custom vjp rule below, but evaluating the gradient raised "ValueError: Received keyword tangent". How can I pass the function
|
Also, the example in the docs doesn't seem to be very robust:

    @eqx.filter_custom_jvp
    def call(x, y, *, fn):
        return fn(x, y)

    @call.def_jvp
    def call_jvp(primals, tangents, *, fn):
        x, y = primals
        tx, ty = tangents
        primal_out = call(x, y, fn=fn)
        tangent_out = tx**2 + ty
        return primal_out, tangent_out

    eqx.filter_grad(call)(jnp.array(2.0), jnp.array(3.0), fn=lambda a, b: a + b)

fails. (The function is arbitrary; none I tried worked.)
|
See #745 (comment). Also improves the documentation into a larger example, to help make clear why some tangent may be `None`.
Thanks both! |
Hi @patrick-kidger, thank you for checking and fixing the reference. I ran the example but got an error like the one below. The error goes away if I change
|
Whoops! You're completely correct. I must have tweaked that after I checked it ran. Thank you for the heads up -- now fixed in #754! |
@patrick-kidger
|
In this case it is the (If you need to differentiate multiple things then wrap them together into a tuple; more generally into any PyTree.) |
@patrick-kidger
Interestingly, when I evaluated the derivative with respect to both arguments, it worked.
I guess something is going on, but I have no clue what... |
Your JVP has the line In this case Note that all of the above is true regardless of whether you use |
Understood, thank you for the explanation! |
Hi all,
I encountered a leaked trace error when using custom_jvp, scan, eqx.Module, and jit at the same time with the latest versions of JAX and Equinox (JAX = 0.4.28, equinox = 0.11.4). I did not have such an error before with older versions, but I could not reproduce it. The example below gives a leaked trace error when evaluating the gradient of the function loss, which takes an eqx.Module instance as an argument. It seems like params in the class Fun is leaked. There was no error if I used a function as the argument instead. I would appreciate any help!
Updated: params in the class Fun was originally a Python float, but it is now a JAX array.