jax.numpy.tanh(2) works but jax.grad(tanh)(2) fails - int32 not liked by grad #2663
Well, I think this particular example is working as intended. You can think of your original `tanh` call as equivalent to a function that first casts its argument to `jnp.float32` and then applies `jnp.tanh`. In other words, the int->float cast happens inside the function you are differentiating. That is, you are differentiating a function that takes an integer as input and returns a float. In general the cotangent types returned by `grad` match the input types, so we would have to return an integer value to you, which probably isn't what you wanted. An error seems preferable. What do you think?
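The equivalence described here can be written out as a runnable sketch (the explicit cast to `jnp.float32` inside `f` mirrors what `jnp.tanh` does implicitly for integer inputs):

```python
import jax
import jax.numpy as jnp

def f(x):
    # The int -> float cast happens inside the function being
    # differentiated, so f maps an int to a float:
    x = jnp.array(x, jnp.float32)
    return jnp.tanh(x)

# With a float input, grad behaves as expected: d/dx tanh(x) = 1 - tanh(x)**2
print(jax.grad(f)(2.0))  # roughly 0.0707
```

With an integer input, `grad(f)` would have to produce an integer cotangent, which is why an error is raised instead.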
Hi Peter,

Thank you for the quick response. Your comment is quite reasonable, and I get that the cotangent code is driven by the function and its inputs, so the results are understandable.

I can think of at least two ways forward:

1. Take the expedient approach and simply make it clear in the docs (apologies if this is already discussed there and I missed it) that the input type controls the output type of the cotangent, and that since integer gradients are not sensible, inputs should be floats or the gradient code will throw an error.
2. Take a potentially more hazardous approach and handle ints more gracefully by casting them to a default float type before constructing the cotangent.

Given that there are probably better things to spend one's time on, and that approach 2 may give rise to some nasty subtle issues, option 1 seems the more sensible way forward.

Dominic
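A user-side sketch of the casting idea in option 2 is possible today as a wrapper (the helper name `grad_with_cast` is hypothetical, and `float32` is assumed as the default float type):

```python
import jax
import jax.numpy as jnp

def grad_with_cast(f):
    # Hypothetical helper: cast the (possibly integer) input to a default
    # float type before taking the gradient, along the lines of option 2.
    def wrapped(x):
        return jax.grad(f)(jnp.asarray(x, dtype=jnp.float32))
    return wrapped

print(grad_with_cast(jnp.tanh)(2))  # no error, even for an int argument
```

Doing this inside `grad` itself would silently change cotangent types, which is the "nasty subtle issues" concern above.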
Hi,

This may be a little churlish to bring up, but ints are acceptable for some calls in JAX yet not for others. I assume this is not by design.

Example (jax.version.__version__ == '0.1.62'):

Output:
```
aval=2
np.tanh(aval)=0.9640276
Primal inputs to reverse-mode differentiation must be of float or complex type, got type int32
sadness - grad(np.tanh)(aval) failed
```
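The original script did not survive the scrape; a minimal reproduction that produces output along these lines (reconstructed, with `np` aliasing `jax.numpy` as the log suggests) would be:

```python
import jax
import jax.numpy as np  # the log above aliases jax.numpy as np

aval = 2  # int32 input
print(f"aval={aval}")
print(f"np.tanh(aval)={np.tanh(aval)}")  # forward pass works: tanh casts to float
try:
    jax.grad(np.tanh)(aval)
except TypeError as e:  # grad rejects integer primal inputs
    print(e)
    print("sadness - grad(np.tanh)(aval) failed")
```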