
ValueError: Non-hashable static arguments are not supported #7

Closed
cifkao opened this issue Dec 1, 2020 · 8 comments

@cifkao
Contributor

cifkao commented Dec 1, 2020

I am trying to train the Transformer on the listops task using the command from the readme, but I get the following error:

I1201 18:31:38.370242 23342722882432 input_pipeline.py:75] Finished processing vocab size=14
Traceback (most recent call last):
  File "lra_benchmarks/listops/train.py", line 288, in <module>
    app.run(main)
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "lra_benchmarks/listops/train.py", line 193, in main
    model_kwargs)
jax._src.traceback_util.FilteredStackTrace: ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'dict'>, {'vocab_size': 16, 'emb_dim': 512, 'num_heads': 8, 'num_layers': 6, 'qkv_dim': 512, 'mlp_dim': 2048, 'max_len': 2000, 'classifier': True, 'num_classes': 10}. The error was:
TypeError: unhashable type: 'dict'

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "lra_benchmarks/listops/train.py", line 288, in <module>
    app.run(main)
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "lra_benchmarks/listops/train.py", line 193, in main
    model_kwargs)
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 133, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/jax/api.py", line 371, in f_jitted
    return cpp_jitted_f(*args, **kwargs)
ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'dict'>, {'vocab_size': 16, 'emb_dim': 512, 'num_heads': 8, 'num_layers': 6, 'qkv_dim': 512, 'mlp_dim': 2048, 'max_len': 2000, 'classifier': True, 'num_classes': 10}. The error was:
TypeError: unhashable type: 'dict'

I have jax==0.2.6 and jaxlib==0.1.57+cuda101.
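A minimal sketch that reproduces the same class of failure outside the LRA code, assuming a recent JAX install; the names `scale` and `config` are hypothetical. A plain `dict` is marked as static in `jax.jit`, so JAX tries to hash it when the function is called. The exact exception text differs between JAX versions, so the sketch only checks that the call is rejected.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical function: `config` is a plain dict marked as a static argument.
@functools.partial(jax.jit, static_argnums=(1,))
def scale(x, config):
  return x * config["factor"]


def call_with_dict_static_arg():
  """Returns True if passing a dict as a static argument is rejected."""
  try:
    scale(jnp.ones(3), {"factor": 2.0})
  except (ValueError, TypeError) as err:
    # Recent JAX versions raise ValueError ("Non-hashable static arguments ...").
    print(type(err).__name__, err)
    return True
  return False


if __name__ == "__main__":
  print(call_with_dict_static_arg())
```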

@ppham27
Collaborator

ppham27 commented Dec 1, 2020

Let me take a look at this. My first guess is that you're running on a newer version of Jax than what we tested on, but I will try to verify later today.

@cifkao
Contributor Author

cifkao commented Dec 1, 2020

@ppham27 I tried downgrading Jax to 0.2.4 from requirements.txt, but that didn't help. Maybe I need to downgrade jaxlib too?

@ppham27
Collaborator

ppham27 commented Dec 1, 2020

Okay, I thought it was the same as google/flax#587, but if it isn't I will dig deeper.

@ppham27
Collaborator

ppham27 commented Dec 1, 2020

So, the VM that I was testing on had Jax 0.2.4, and the command was working fine. When I upgraded to 0.2.6, I got the same error you have. Are you sure you cleared your old install correctly?

In any case, I'm going to try to get a version that works with 0.2.6 by tomorrow.
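Why a dict trips this check: `jax.jit` caches compiled functions, and the values of static arguments are part of the cache key, so they must be hashable. The pure-Python sketch below uses `functools.lru_cache` only as an analogy for that caching requirement (it is not how JAX is implemented); `build` and `try_build` are hypothetical names. The same trick that makes the cache accept the settings, converting the dict to a tuple of items, is one way to make a static argument hashable.

```python
import functools


@functools.lru_cache(maxsize=None)
def build(num_layers, settings):
  """Hypothetical 'compile' step whose results are cached by argument value."""
  return f"model(layers={num_layers}, settings={settings})"


def try_build(settings):
  """Returns the cached result, or the error name if `settings` is unhashable."""
  try:
    return build(6, settings)
  except TypeError as err:
    # lru_cache must hash its arguments, much as jit must hash static arguments.
    return type(err).__name__


kwargs = {"num_heads": 8, "qkv_dim": 512}
print(try_build(kwargs))  # prints TypeError
# A tuple of (key, value) pairs is hashable when every value is hashable.
print(try_build(tuple(sorted(kwargs.items()))))
```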

@cifkao
Contributor Author

cifkao commented Dec 2, 2020

Not sure how to clear the old install correctly... I just did this:

pip uninstall jax jaxlib
pip install --upgrade jax==0.2.4 jaxlib==0.1.57+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html

and checked that python -c 'import jax; print(jax.__version__)' prints 0.2.4. Still the same result.

What are your jaxlib and TF versions?
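Since the outcome here depended on the exact jaxlib build as well as the jax version, it can help to report all the relevant versions at once. A small sketch using the standard library's `importlib.metadata` (available from Python 3.8; the Python 3.7 environment in this thread would need the `importlib_metadata` backport instead). The package list is an assumption; adjust it to your setup.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "flax", "tensorflow")):
  """Maps each distribution name to its installed version, or None if absent."""
  found = {}
  for name in packages:
    try:
      # For CUDA builds of jaxlib this includes the local tag, e.g. "+cuda101".
      found[name] = version(name)
    except PackageNotFoundError:
      found[name] = None
  return found


if __name__ == "__main__":
  for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```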

@ppham27
Collaborator

ppham27 commented Dec 2, 2020

Starting from a clean Python virtualenv:

pip install -r requirements.txt
pip install jaxlib==0.1.56+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip uninstall jax
pip install jax==0.2.4

The Python version is 3.6.9. This is on an Ubuntu GCP machine with 8x V100s.

If you can't wait for the fix to be pushed, this should work with Jax 0.2.6. Replace

# Fails on the newer JAX/jaxlib: model_kwargs (a dict) is a static argument,
# and jit has to hash its static arguments.
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def create_model(key, flax_module, input_shape, model_kwargs):
  module = flax_module.partial(**model_kwargs)
  with nn.stochastic(key):
    _, initial_params = module.init_by_shape(key, [(input_shape, jnp.float32)])
    model = nn.Model(module, initial_params)
  return model

with

def create_model(key, flax_module, input_shape, model_kwargs):
  """Creates and initializes the model."""

  @functools.partial(jax.jit, backend='cpu')
  def _create_model(key):
    module = flax_module.partial(**model_kwargs)
    with nn.stochastic(key):
      _, initial_params = module.init_by_shape(key,
                                               [(input_shape, jnp.float32)])
      model = nn.Model(module, initial_params)
    return model

  return _create_model(key)
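The replacement works because `flax_module`, `input_shape` and `model_kwargs` are no longer arguments of the jitted function: the inner `_create_model` closes over them, so only `key` goes through `jit` and nothing needs to be hashed. A rough sketch of the same closure pattern in plain JAX, without the legacy `flax.nn` API; `make_init_fn`, the parameter names and the shapes are hypothetical.

```python
import jax
import jax.numpy as jnp


def make_init_fn(model_kwargs, input_shape):
  """Returns a jitted initializer that closes over the (unhashable) config dict."""

  @jax.jit
  def init(key):
    # model_kwargs and input_shape come from the enclosing scope, so they are
    # treated as Python constants during tracing and are never hashed by jit.
    in_dim = input_shape[-1]
    w_key, _ = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (in_dim, model_kwargs["emb_dim"])),
        "b": jnp.zeros((model_kwargs["emb_dim"],)),
    }

  return init


init = make_init_fn({"emb_dim": 4, "num_heads": 2}, input_shape=(1, 3))
params = init(jax.random.PRNGKey(0))
print(params["w"].shape, params["b"].shape)  # (3, 4) (4,)
```

Each call to `make_init_fn` creates a fresh jitted function and therefore traces again. That is cheap for a one-off initializer like `create_model`, but would be wasteful inside a training loop.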

@cifkao
Contributor Author

cifkao commented Dec 2, 2020

Thanks, downgrading to jaxlib==0.1.56+cuda101 (from 0.1.57) seems to have fixed it for me.
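Pinning jaxlib works, but another option, not one proposed in this thread, is to keep `static_argnums` and pass a hashable copy of the keyword arguments, such as a sorted tuple of `(key, value)` pairs, rebuilding the dict inside the jitted function. This assumes every value is itself hashable, which holds for the ints and bools in the `model_kwargs` printed in the traceback. `freeze` and `apply_config` are hypothetical names, and the sketch uses plain JAX rather than Flax.

```python
import functools

import jax
import jax.numpy as jnp


def freeze(kwargs):
  """Converts a flat dict with hashable values into a hashable tuple of items."""
  return tuple(sorted(kwargs.items()))


@functools.partial(jax.jit, static_argnums=(1,))
def apply_config(x, frozen_kwargs):
  # Rebuild the dict at trace time; frozen_kwargs itself is what jit hashes.
  kwargs = dict(frozen_kwargs)
  return x * kwargs["num_heads"] + kwargs["qkv_dim"]


model_kwargs = {"num_heads": 8, "qkv_dim": 512}
print(apply_config(jnp.ones(2), freeze(model_kwargs)))  # [520. 520.]
```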

ppham27 self-assigned this Dec 2, 2020
@ppham27
Collaborator

ppham27 commented Dec 2, 2020

Fixed by 1309003.

ppham27 closed this as completed Dec 2, 2020