AttributeError: module 'neural_tangents' has no attribute 'utils' #173

Open
LeavesLei opened this issue Feb 22, 2023 · 8 comments
Labels
bug Something isn't working

Comments

@LeavesLei

Hi developers, I ran into the following problem when using neural-tangents:

KERNEL_FN = nt.utils.batch.batch(KERNEL_FN, batch_size=kernel_batch_size)
AttributeError: module 'neural_tangents' has no attribute 'utils'

Here are the versions of the relevant libraries:

  • neural-tangents: 0.6.2
  • scipy: 1.10.1
  • numpy: 1.24.2
@romanngg
Contributor

We had a refactoring a while ago, please try nt.batch

See https://github.com/google/neural-tangents/blob/main/neural_tangents/__init__.py for the public API
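For scripts that have to run against both the old and the refactored layout, a small lookup helper is one option. This is a minimal sketch, not part of neural-tangents: resolve_batch is a hypothetical name, and the stand-in objects below only mimic the two attribute layouts mentioned in this thread so the snippet runs without the library installed.

```python
from types import SimpleNamespace


def resolve_batch(nt_module):
    """Return the batching wrapper from either API layout.

    Hypothetical helper: prefers the top-level `batch` attribute and falls
    back to the older `utils.batch.batch` path from the original report.
    """
    if hasattr(nt_module, "batch"):
        return nt_module.batch
    return nt_module.utils.batch.batch


# Stand-ins for the two layouts (assumptions for illustration only).
def fake_batch(kernel_fn, batch_size=0):
    # The real wrapper batches the kernel computation; this one is a no-op.
    return kernel_fn


new_layout = SimpleNamespace(batch=fake_batch)
old_layout = SimpleNamespace(
    utils=SimpleNamespace(batch=SimpleNamespace(batch=fake_batch))
)

print(resolve_batch(new_layout) is resolve_batch(old_layout))
```

With the real library this would be used as KERNEL_FN = resolve_batch(nt)(KERNEL_FN, batch_size=kernel_batch_size) after import neural_tangents as nt. If only recent releases need to be supported, calling nt.batch directly is simpler.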

@LeavesLei
Author

Thanks for your fast reply. I changed nt.utils.batch.batch() to nt.batch(), but another error occurred:

Traceback (most recent call last):
  File "eval_distilled_set.py", line 190, in <module>
    main()
  File "eval_distilled_set.py", line 156, in main
    K_zz = KERNEL_FN(X_sup_reordered, X_sup_reordered)
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 471, in serial_fn
    return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 398, in serial_fn_x1
    _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 151, in _scan
    carry, y = f(carry, x)
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 387, in row_fn
    return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 151, in _scan
    carry, y = f(carry, x)
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 396, in col_fn
    return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 758, in f_pmapped
    return _f(x_or_kernel, *args_np, **kwargs_np)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:627) dnn != nullptr

where KERNEL_FN = functools.partial(kernel_fn, get=('nngp', 'ntk')).

@romanngg
Contributor

I haven't seen this error before. Does it still happen if you reduce the batch size? I sometimes encounter low-level XLA errors when running out of memory.

@LeavesLei
Author

I've reduced the batch size from 25 to 5, but the error still occurred. Since the failing check is dnn != nullptr, I suspect a mismatch between the cuDNN version and jax is causing the problem (jax-ml/jax#14480).

I am using Ubuntu 20.04, CUDA 11.4, cuDNN 8.7.0, and a TITAN V GPU (12GB).

@romanngg
Contributor

romanngg commented Feb 23, 2023

Good catch, that could be it. What are your jax and jaxlib [edit: and nvidia driver] versions?

@LeavesLei
Author

LeavesLei commented Feb 23, 2023

import jax, jaxlib
jax.__version__: 0.4.4
jaxlib.__version__: 0.4.4

NVIDIA-SMI 470.161.03, Driver Version: 470.161.03
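For reports like this one, the Python-side versions can be gathered in one go. The following is a minimal sketch using only the standard library; installed_versions is a hypothetical helper name, and the package list is simply the set of libraries mentioned in this thread. It does not cover the CUDA, cuDNN, or driver versions, which have to come from tools such as nvidia-smi.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Record missing packages instead of failing, so the report is complete.
            versions[name] = None
    return versions


# Libraries mentioned in this thread; missing ones are reported as such.
report = installed_versions(["neural-tangents", "jax", "jaxlib", "numpy", "scipy"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

Because it reads installed package metadata rather than importing the packages, this also works when importing jax itself fails.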

@romanngg
Contributor

romanngg commented Mar 6, 2023

Hm, these all seem compatible per https://docs.nvidia.com/deeplearning/cudnn/support-matrix/index.html
Have you tried updating per jax-ml/jax#14480 (comment) ?

romanngg added the bug (Something isn't working) label Mar 6, 2023

@LeavesLei
Author

Hi, Roman

Thanks for your reply, and I'll try to update the cuDNN version to solve the problem.

Best,
Shiye
