vmap svd not working #1397

Closed
proteneer opened this issue Sep 25, 2019 · 4 comments
Labels
bug Something isn't working


@proteneer
Contributor

proteneer commented Sep 25, 2019

import jax
import jax.numpy as np

dij = np.array([
    [0., 2., 3., 2.],
    [2., 0., 4., 2.],
    [3., 4., 0., 1.],
    [2., 2., 1., 0.],
])

def cmdscale_3(D):
    '''
    D: distance matrix (N,N)
    returns: coordinates given the distance matrix (N,3)
    '''
    # Generate Gramian Matrix (B)
    n = len(D)
    H = np.eye(n) - np.ones((n, n))/n
    B = -H.dot(D**2).dot(H)/2

    # perform SVD
    u, s, v = np.linalg.svd(B)
    # singular values s are always real and non-negative, so np.sqrt(s) is real
    x = u.dot(np.diag(np.sqrt(s)))
    return x[:, :3]/10

cmdscale_3(dij)
dijs = np.array([dij, dij])
bfn = jax.vmap(cmdscale_3)
print(dijs.shape)
bfn(dijs)

This gives the following error:

  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/core.py", line 133, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/interpreters/batching.py", line 115, in process_primitive
    val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/lax_linalg.py", line 798, in svd_batching_rule
    outs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/core.py", line 130, in bind
    return self.impl(*args, **kwargs)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/lax_linalg.py", line 726, in svd_impl
    compute_uv=compute_uv)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/interpreters/xla.py", line 123, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *abstract_args, **params)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/interpreters/xla.py", line 136, in xla_primitive_callable
    built_c = primitive_computation(prim, *xla_shapes, **params)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/interpreters/xla.py", line 150, in primitive_computation
    rule(c, *xla_args, **new_params)  # return val set as a side-effect on c
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/lax_linalg.py", line 784, in _svd_cpu_gpu_translation_rule
    compute_uv=compute_uv)
  File "jaxlib/lapack.pyx", line 1073, in lapack.gesdd
ValueError: too many values to unpack (expected 2)

Is this unsupported, just buggy, or am I doing something stupid?

jax==0.1.46
jaxlib==0.1.28
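The traceback shows the failure inside svd_batching_rule, which re-binds svd_p on the whole batched array, so the 3-D operand reaches lapack.gesdd in jaxlib 0.1.28. One possible stopgap is jax.lax.map, which evaluates a function one matrix at a time instead of batching it. This is a minimal sketch under the assumption (untested against that jaxlib version) that unbatched SVD works, as the plain cmdscale_3(dij) call suggests. The helper name svd_one_at_a_time is hypothetical, and the code uses the current jnp alias for jax.numpy.

```python
import jax
import jax.numpy as jnp


def svd_one_at_a_time(batch):
    """Apply jnp.linalg.svd to each (m, n) matrix in a (batch, m, n) array.

    Hypothetical helper: jax.lax.map scans over the leading axis, so the SVD
    inside only ever sees an unbatched 2-D operand (no vmap batching rule).
    """
    return jax.lax.map(jnp.linalg.svd, batch)


A = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4))
u, s, vh = svd_one_at_a_time(A)  # shapes (2, 4, 4), (2, 4), (2, 4, 4)
```

For the code in the report, the analogous call would be jax.lax.map(cmdscale_3, dijs) in place of bfn(dijs). Because lax.map runs sequentially, it gives up the parallelism a genuinely batched SVD could provide.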

@proteneer
Contributor Author

(Note that eigh works fine under vmap.)

@mattjj added the bug label on Sep 25, 2019
@hawkinsp
Collaborator

I think this is actually already fixed in jaxlib at head (by #1314), but we haven't made a new binary release since then. Certainly it works for me at head. I guess we should make a new release. Until then, you could build jaxlib from source?

@proteneer
Contributor Author

Good to know. I can trivially implement SVD with eigh for now, so this isn't a blocker. Once jaxlib has been upgraded, I'll update.

Thanks again!
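The eigh-based workaround mentioned above could look like the following sketch, a hypothetical variant named cmdscale_3_eigh. Because B is symmetric, its eigendecomposition can replace the SVD. Assuming B is positive semidefinite (a Euclidean D) with distinct eigenvalues, the result agrees with cmdscale_3 up to column signs. For a non-Euclidean D the two can differ, since the SVD orders components by the absolute value of each eigenvalue, while this sketch keeps the largest eigenvalues and clips negative ones to zero.

```python
import jax
import jax.numpy as jnp


def cmdscale_3_eigh(D):
    """Classical MDS coordinates (N, 3) from a distance matrix D (N, N).

    Hypothetical eigh-based variant of cmdscale_3: B is symmetric, so its
    eigendecomposition stands in for the SVD.
    """
    n = D.shape[0]
    H = jnp.eye(n) - jnp.ones((n, n)) / n
    B = -H @ (D ** 2) @ H / 2
    # eigh returns eigenvalues in ascending order; reverse to descending.
    w, V = jnp.linalg.eigh(B)
    w, V = w[::-1], V[:, ::-1]
    # Negative eigenvalues (non-Euclidean D, round-off) are clipped to zero
    # before the square root; V * sqrt(w) scales each eigenvector column.
    x = V * jnp.sqrt(jnp.maximum(w, 0.0))
    return x[:, :3] / 10


dij = jnp.array([
    [0., 2., 3., 2.],
    [2., 0., 4., 2.],
    [3., 4., 0., 1.],
    [2., 2., 1., 0.],
])
coords = jax.vmap(cmdscale_3_eigh)(jnp.stack([dij, dij]))  # shape (2, 4, 3)
```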

@hawkinsp
Collaborator

I pushed a new jaxlib (0.1.29) in which this issue should be fixed. Hope that helps!
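With a jaxlib that includes the fix, the original jax.vmap call should go through. A quick sanity check, sketched in current JAX style (jnp for jax.numpy, random test matrices rather than the report's data), compares batched singular values against a per-matrix Python loop:

```python
import jax
import jax.numpy as jnp


def singular_values(a):
    # compute_uv=False returns only the singular values, in descending order.
    return jnp.linalg.svd(a, compute_uv=False)


A = jax.random.normal(jax.random.PRNGKey(1), (5, 4, 4))
batched = jax.vmap(singular_values)(A)               # shape (5, 4)
looped = jnp.stack([singular_values(a) for a in A])  # one call per matrix
```

Singular values are compared rather than u and v because the singular vectors are only defined up to sign.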
