jax.numpy.tanh(2) works but jax.grad(tanh)(2) fails: int32 not liked by grad #2663

Closed
durian888 opened this issue Apr 9, 2020 · 2 comments · Fixed by #2712
Labels
question Questions for the JAX team

Comments

@durian888

Hi,
This may be a little churlish to bring up, but ints are accepted by some calls in JAX and not by others. I assume this is not by design.
Example, with jax.version.__version__ == '0.1.62':

import jax.numpy as np
from jax import grad

aval = 2
print("aval=" + str(aval))
print("np.tanh(aval)=" + str(np.tanh(aval)))

try:
    grad(np.tanh)(aval)
except TypeError as e:
    print(e)
    print("sadness - grad(np.tanh)(aval) failed ")

Output:
aval=2
np.tanh(aval)=0.9640276
Primal inputs to reverse-mode differentiation must be of float or complex type, got type int32
sadness - grad(np.tanh)(aval) failed
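A minimal sketch of the usual workaround, assuming a recent JAX release rather than the 0.1.62 used in this report: pass a floating-point value so that the function grad differentiates has a float input. The names `g`, `expected` and `int_failed` are only illustrative.

```python
import math

import jax
import jax.numpy as jnp

aval = 2

# With an integer primal input, grad refuses to differentiate and raises
# TypeError, which is the behavior reported in this issue.
try:
    jax.grad(jnp.tanh)(aval)
    int_failed = False
except TypeError:
    int_failed = True

# Casting to float *before* calling grad gives it a floating-point input.
g = jax.grad(jnp.tanh)(float(aval))

# Compare with the analytic derivative d/dx tanh(x) = 1 - tanh(x)**2.
expected = 1.0 - math.tanh(aval) ** 2
print("int input rejected:", int_failed)
print("grad(tanh)(2.0) =", float(g), "analytic:", expected)
```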

@hawkinsp
Collaborator

hawkinsp commented Apr 9, 2020

Well, I think this particular example is working as intended. You can think of your original tanh function as equivalent to the following:

import jax.numpy as jnp

def f(x):
  x = jnp.array(x, jnp.float32)
  return jnp.tanh(x)

In other words, the int->float cast happens inside the function you are differentiating: you are differentiating a function that takes an integer as input and returns a float.
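A quick check of that int-to-float signature, as a sketch assuming a recent JAX install: `f` (copied from the comment) accepts an integer and produces a float32 result with the same value as calling jnp.tanh(2) directly. Differentiating works once the input handed to grad is already a float. The variable names `x_in`, `y_out` and `g_float` are hypothetical.

```python
import jax
import jax.numpy as jnp

def f(x):
    # As in the comment: the int -> float cast is part of f itself.
    x = jnp.array(x, jnp.float32)
    return jnp.tanh(x)

x_in = jnp.asarray(2)  # integer-typed input
y_out = f(2)           # float32 output

print("input dtype:", x_in.dtype)
print("output dtype:", y_out.dtype)
# f(2) matches calling jnp.tanh on the integer directly.
print("matches jnp.tanh(2):", bool(jnp.allclose(y_out, jnp.tanh(2))))

# Moving the cast outside f: grad now sees a float input and succeeds.
g_float = jax.grad(f)(2.0)
print("grad(f)(2.0) =", float(g_float))
```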

In general, the cotangent types returned by grad match the input types, so we would have to return an integer value to you, which is probably not what you wanted. An error seems preferable.

What do you think?
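On the point that grad's cotangents follow the input types: later JAX releases (this option may not exist in 0.1.62) let you opt in to integer inputs with `allow_int=True`. As documented for `jax.grad`, the gradient for an integer input then has the trivial `float0` dtype rather than an integer dtype, so it carries no numeric information. A small sketch, with `g_int` as an illustrative name:

```python
import jax
import jax.numpy as jnp

# allow_int=True tells grad to accept integer-valued primal inputs.
g_int = jax.grad(jnp.tanh, allow_int=True)(2)

# The result is not an integer gradient: it has the trivial float0 dtype,
# a placeholder with no numeric content, because an int input has no
# meaningful tangent space.
print(g_int.dtype, g_int.shape)
print(g_int.dtype == jax.dtypes.float0)
```

For an actual numeric derivative at 2, passing a float input remains the way to go.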

hawkinsp added the question (Questions for the JAX team) label Apr 9, 2020
@durian888
Author

durian888 commented Apr 9, 2020 via email
