TensorFlow 2.6 error with JAX/FLAX implementation #14265
Thanks a lot for the issue @stefan-it! Would it be fine for you to stick with TensorFlow 2.5.0 for now?
Hi all, I am seeing a similar issue, and I think it is related to this issue: Hopefully it will be solved by TF 2.6.2:
Hello there 👋 I ran into roughly the same problem on a CI build job today; it was not occurring yesterday. So I investigated, and the culprit seems to be keras 2.7 rather than TensorFlow: keras-team/keras#15585. On my end, the solution was to constrain the version of keras to
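Because the breakage showed up when a new keras release was resolved alongside an unchanged TensorFlow pin, it can help to log the exact package versions a CI job actually installed before digging into a stack trace. A minimal sketch, assuming Python 3.8+ (for the standard-library importlib.metadata module); the helper name installed_versions is hypothetical, and it only reports versions, it does not pin or change anything.

```python
from importlib import metadata


def installed_versions(packages=("tensorflow", "keras", "jax", "flax")):
    """Return {distribution name: installed version string, or None}.

    A distribution that is not installed maps to None instead of raising,
    so the same check works on machines with partial stacks.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Print one line per package, e.g. to include in a CI log or bug report.
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Comparing this output between a passing and a failing CI run makes a silently upgraded transitive dependency (such as keras here) easy to spot.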
@avital @skye @marcvanzee, it seems there is a problem with the new keras release and JAX on TPU. Could you maybe check? :-)
Hi @patrickvonplaten, I checked the issue. Perhaps I missed something, but it doesn't look like a Flax/JAX/TPU issue to me. I could indeed reproduce the problem on my machine with a similar stack trace, but from reading the stack trace, it seems like there is a conflict in importing modules from Keras. What makes you think this is related to JAX on TPU?
I managed to reproduce this issue by installing from keras import optimizers (suggested in keras-team/keras#15579).
Gotcha! Sorry, yes, in this case it does not seem to be related to JAX/Flax at all, but
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi guys,
this is probably a TPU-related bug; it appears when using the JAX/Flax implementation in combination with TensorFlow versions 2.6.0 and 2.6.1:
I could reproduce it using the run_mlm_flax.py example, e.g. with:
It does not appear when using TensorFlow version 2.5.0. I'm using the latest master versions of both Transformers and Datasets.