
Bug: False "UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32" #1841

Open
eadadi opened this issue Sep 22, 2024 · 1 comment

eadadi commented Sep 22, 2024

Python 3.9.19 (main, May  6 2024, 19:43:03)
[GCC 11.2.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import tensorflow_probability.substrates.jax as tfp
>>> tfd = tfp.distributions
>>> rng = jax.random.PRNGKey(0)
>>> tfd.OneHotCategorical(logits=jax.random.normal(key=rng,shape=(3,4)), dtype=jax.numpy.float32).sample(seed=rng)
/home/user/anaconda3/envs/ml_exp/lib/python3.9/site-packages/jax/_src/numpy/array_methods.py:68: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return lax_numpy.astype(arr, dtype, copy=copy, device=device)
Array([[1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 0., 1.]], dtype=float32)

I explicitly set dtype=jnp.float32 and did not request an integer dtype anywhere, but I still received the warning.
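The message is JAX's generic warning for requesting a 64-bit dtype while 64-bit mode (the jax_enable_x64 option named in the warning) is disabled. That suggests an internal cast to int64 somewhere in the sampling path, not the float32 output dtype, is the trigger; where exactly TFP performs that cast is an assumption not confirmed here. The minimal sketch below uses plain JAX only (`cast_to_int64` is an illustrative helper, not part of JAX or TFP) to show that an `astype` to `jnp.int64` under the default configuration produces the same warning and an int32 result.

```python
import warnings

import jax
import jax.numpy as jnp

# Assumption: the default configuration from the report, with 64-bit mode off.
jax.config.update("jax_enable_x64", False)


def cast_to_int64(x):
    """Illustrative helper: cast x to int64 and record any UserWarnings JAX emits."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure repeated warnings are not deduplicated
        y = x.astype(jnp.int64)
    return y, [w for w in caught if issubclass(w.category, UserWarning)]


y, caught = cast_to_int64(jnp.ones((3, 4), dtype=jnp.float32))
print(y.dtype)  # int32: the int64 request was truncated
for w in caught:
    print(w.message)
```

The float32 input is irrelevant here; any array cast to int64 with 64-bit mode off hits the same check, which is why the requested float32 sample dtype does not prevent the warning.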

eadadi commented Sep 22, 2024

I will also report this bug in the JAX repository.
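Until this is addressed upstream, one possible stopgap, assuming the warning is harmless for this use case, is to silence only this specific message around the affected call. The helper below (`call_without_x64_truncation_warning` is a hypothetical name, not a JAX or TFP API) uses the standard `warnings` module to ignore UserWarnings whose text starts with "Explicitly requested dtype", leaving all other warnings untouched.

```python
import re
import warnings

import jax
import jax.numpy as jnp

# Assumption: default 32-bit mode, as in the report.
jax.config.update("jax_enable_x64", False)

# Matches the beginning of JAX's truncation warning shown in the report.
_TRUNCATION_MSG = re.escape("Explicitly requested dtype")


def call_without_x64_truncation_warning(fn, *args, **kwargs):
    """Hypothetical helper: call fn, ignoring only JAX's 64-bit truncation warning."""
    with warnings.catch_warnings():  # restores the previous filters on exit
        warnings.filterwarnings("ignore", message=_TRUNCATION_MSG, category=UserWarning)
        return fn(*args, **kwargs)


x = jnp.ones((3, 4), dtype=jnp.float32)
y = call_without_x64_truncation_warning(lambda a: a.astype(jnp.int64), x)
print(y.dtype)
```

In the report's setting it could wrap the sampling call, e.g. `call_without_x64_truncation_warning(dist.sample, seed=rng)` where `dist` is the OneHotCategorical instance; that call was not run here. Enabling jax_enable_x64 would also remove the warning, but it changes JAX's default dtypes globally, so a narrow filter is the less invasive option.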
