ValueError: Non-hashable static arguments are not supported #7
Comments
Let me take a look at this. My first guess is that you're running a newer version of Jax than the one we tested on, but I will try to verify later today.
@ppham27 I tried downgrading Jax to 0.2.4, the version in requirements.txt, but that didn't help. Maybe I need to downgrade jaxlib too?
Okay, I thought it was the same issue as google/flax#587, but if it isn't, I will dig deeper.
So, the VM that I was testing on had Jax 0.2.4, and the command was working fine. When I upgraded to 0.2.6, I got the same error you have. Are you sure you cleared your old install correctly? In any case, I'm going to try to get a version that works with 0.2.6 by tomorrow.
Not sure how to clear the old install correctly... I just ran:
pip uninstall jax jaxlib
pip install --upgrade jax==0.2.4 jaxlib==0.1.57+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html
and checked that
What are your jaxlib and TF versions?
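To answer the version question, here is a minimal sketch that reports the installed versions of the packages involved. It assumes Python 3.8+ for `importlib.metadata`; on the Python 3.6.9 mentioned in this thread, `pip show jax jaxlib tensorflow` or `pip freeze` gives the same information. `installed_versions` is a hypothetical helper name, and the package list assumes the distribution names `jax`, `jaxlib`, `tensorflow` and `flax` (a GPU TensorFlow build may be installed under a different name).

```python
from importlib import metadata  # requires Python 3.8+


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(
            ["jax", "jaxlib", "tensorflow", "flax"]).items():
        print(f"{name}: {version or 'not installed'}")
```

Running this in the same virtualenv as the training script shows whether an old jaxlib wheel is still installed after a downgrade.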
From a clean install in a Python virtualenv:
Python version is 3.6.9. This is on an Ubuntu GCP machine with 8x V100s.
If you can't wait for the fix to be pushed, this should work with Jax 0.2.6. Replace lines 55 to 61 of long-range-arena/lra_benchmarks/listops/train.py (at commit 9407f98)
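with:
def create_model(key, flax_module, input_shape, model_kwargs):
  """Creates and initializes the model."""
  # Only `key` goes through jax.jit; flax_module, input_shape and the
  # non-hashable model_kwargs dict are captured by the closure instead of
  # being treated as static arguments. Assumes functools, jax,
  # jax.numpy (as jnp) and flax's legacy nn module are imported in train.py.

  @functools.partial(jax.jit, backend='cpu')
  def _create_model(key):
    module = flax_module.partial(**model_kwargs)
    with nn.stochastic(key):
      _, initial_params = module.init_by_shape(key,
                                               [(input_shape, jnp.float32)])
      model = nn.Model(module, initial_params)
    return model

  return _create_model(key)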
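The replacement above sidesteps the error by never handing `model_kwargs` to `jax.jit`. `jax.jit` caches compiled code keyed on its static arguments, so static arguments must be hashable, and a `dict` is not. If you would rather keep the kwargs as an explicit static argument, another common option (not the one used in this thread) is to convert the dict into a hashable value first. This is a minimal sketch assuming `model_kwargs` is a flat dict whose values are themselves hashable (ints, floats, strings, bools); `freeze_kwargs` and `thaw_kwargs` are hypothetical helper names.

```python
def freeze_kwargs(kwargs):
    """Convert a flat kwargs dict into a hashable, order-independent tuple."""
    # Keys in a dict are unique, so sorting never has to compare values.
    frozen = tuple(sorted(kwargs.items()))
    try:
        hash(frozen)  # fails if any value is unhashable (e.g. a list)
    except TypeError as err:
        raise TypeError(f"kwargs contain a non-hashable value: {err}") from None
    return frozen


def thaw_kwargs(frozen):
    """Rebuild the original dict from the output of freeze_kwargs."""
    return dict(frozen)
```

Inside a jitted function that receives the frozen tuple as a static argument, the module could then be built with something like `flax_module.partial(**thaw_kwargs(frozen_kwargs))`. Sorting the items means two dicts with the same contents produce equal tuples, so they should hit the same cached compilation.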
Thanks, downgrading to …
Fixed by commit 1309003.
I am trying to train the Transformer on the ListOps task using the command from the README, but I get the following error: ValueError: Non-hashable static arguments are not supported
I have jax==0.2.6 and jaxlib==0.1.57+cuda101.