Unsolvable jax dependency #321

Open
adiaconu11 opened this issue May 23, 2024 · 2 comments

Comments

@adiaconu11

Following the most recent commit, the jax.tree module is now used instead of the jax.tree_* functions. This change requires jax >= 0.4.25. However, many parts of the repository still rely on old or deprecated APIs. For instance, if you install 0.4.25 you might get something like:

AttributeError: module 'jax.random' has no attribute 'KeyArray'

This is because jax.random.KeyArray was removed in jax 0.4.24, so avoiding this error requires jax <= 0.4.23. Obviously this contradicts the requirement above.
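
To make the conflict concrete, here is a minimal sketch that checks a few jax release numbers against both bounds at once. The helper names (parse_version, satisfies_both) and the candidate list are illustrative assumptions, not part of acme or jax, and the parser only handles plain X.Y.Z strings.

```python
# Sketch: show that ">= 0.4.25" and "<= 0.4.23" cannot both hold.
# All names here are hypothetical helpers written for this illustration.

def parse_version(text):
    """Turn a plain 'X.Y.Z' string into a tuple of ints for comparison."""
    return tuple(int(part) for part in text.split("."))

# Lower bound implied by the use of jax.tree, upper bound implied by KeyArray.
LOWER_BOUND = parse_version("0.4.25")
UPPER_BOUND = parse_version("0.4.23")

def satisfies_both(text):
    """True only if the version meets both constraints simultaneously."""
    return LOWER_BOUND <= parse_version(text) <= UPPER_BOUND

# A handful of release numbers around the boundary (illustrative list).
candidates = ["0.4.23", "0.4.24", "0.4.25", "0.4.26"]
compatible = [v for v in candidates if satisfies_both(v)]
print(compatible)
```

Since the lower bound is above the upper bound, no version can pass, which is why the resolver has nothing to pick.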

Lastly, there is still the issue with DeviceArray and ShardedDeviceArray. Both were replaced by plain jax.Array back in jax 0.4.0! In the current state of the repo you basically need to add lines like:

jax.interpreters.xla.DeviceArray = jax.Array in order to be able to even import acme...
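
For anyone who needs a stopgap, the aliasing trick above can be wrapped so that it only patches names that are actually missing. This is a rough sketch, not a recommended fix: add_missing_alias is a hypothetical helper, and the demo runs on a stand-in namespace so that it works without jax installed. The commented lines show how it might be applied to a real install, assuming jax.interpreters.xla is reachable after import jax as in the line above.

```python
import types

def add_missing_alias(namespace, name, target):
    """Set namespace.<name> = target only when the attribute is absent.

    Returns True if the alias was added, False if the name already existed.
    """
    if hasattr(namespace, name):
        return False
    setattr(namespace, name, target)
    return True

# Stand-ins so the sketch runs anywhere (these are NOT the real jax objects).
FakeArray = type("Array", (), {})
fake_xla = types.SimpleNamespace()

patched = add_missing_alias(fake_xla, "DeviceArray", FakeArray)
patched_again = add_missing_alias(fake_xla, "DeviceArray", object)

# With a real install, the equivalent calls before `import acme` might be:
#   import jax
#   add_missing_alias(jax.interpreters.xla, "DeviceArray", jax.Array)
#   add_missing_alias(jax.random, "KeyArray", jax.Array)
```

Checking with hasattr first means the shim does nothing on older jax versions where the original names still exist.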

@adiaconu11

The jax.Array issue can be solved by installing chex==0.1.7 instead of the default 0.1.6. This should be updated in acme's requirements.
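
When comparing setups, it helps to report which versions are actually installed. A small sketch using only the standard library follows; installed_versions is a hypothetical helper, and the distribution name "dm-acme" is my assumption for how acme is published on PyPI.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

# "dm-acme" is assumed to be the PyPI distribution name for acme.
report = installed_versions(["jax", "chex", "dm-acme"])
for name, ver in report.items():
    print(f"{name}: {ver if ver is not None else 'not installed'}")
```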

@lkoelman

Bumping into similar problems. It is very hard to install a functional stack.
