
cusolver bug (maybe SVD?) #1355

Closed
clemisch opened this issue Sep 16, 2019 · 1 comment

@clemisch
Contributor

While rerunning the repro from #1259 (batched SVD) on jax 0.1.45, I got a new error.

Repro:

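import jax
import jax.numpy as np
import numpy as onp

# 100,000 random 3x3 matrices: a batch of small SVD problems
x_host = onp.random.rand(100000, 3, 3).astype(onp.float32)
x_gpu = np.array(x_host)  # copy the batch to the default (GPU) device

# Batched SVD: vmap over the leading axis, compiled with jit
svd_batch = jax.jit(jax.vmap(np.linalg.svd, 0, 0))

u1, s1, v1 = onp.linalg.svd(x_host)  # NumPy reference on the host
u2, s2, v2 = np.linalg.svd(x_gpu) # Error
u3, s3, v3 = svd_batch(x_gpu)     # Error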

Error:

[...]
~/.local/lib/python3.5/site-packages/jaxlib/cusolver.py in gesvd(c, a, full_matrices, compute_uv)
    195   dtype = a_shape.element_type()
    196   b = 1
--> 197   m, n = a_shape.dimensions()
    198   singular_vals_dtype = _real_type(dtype)
    199 

ValueError: too many values to unpack (expected 2)

This repro worked fine when I originally posted that issue, so I suspect this is a new bug.
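The failing line in the traceback unpacks the operand shape as `m, n`, which only works for a rank-2 matrix, while the repro passes a batch with shape (100000, 3, 3), i.e. three dimensions. A minimal sketch of a batch-aware shape split follows; `split_matrix_dims` is a hypothetical helper (not part of jaxlib), and deriving the batch count `b` as the product of the leading dimensions is an assumption about how a batched kernel would consume it.

```python
import math


def split_matrix_dims(dims):
    """Split an array shape into (batch_dims, b, m, n).

    Hypothetical helper: the last two axes are the matrix, and `b` is the
    number of matrices, i.e. the product of all leading (batch) dimensions.
    Requires Python 3.8+ for math.prod.
    """
    if len(dims) < 2:
        raise ValueError("expected at least 2 dimensions, got %d" % len(dims))
    *batch_dims, m, n = dims      # last two axes form each matrix
    b = math.prod(batch_dims)     # math.prod(()) == 1 for a single matrix
    return tuple(batch_dims), b, m, n


dims = (100000, 3, 3)  # shape of x_gpu in the repro

# A fixed two-way unpack, as in the traceback, rejects the batched shape.
try:
    m, n = dims
except ValueError as err:
    print("ValueError:", err)

print(split_matrix_dims(dims))  # ((100000,), 100000, 3, 3)
```

A plain rank-2 input still works: `split_matrix_dims((3, 3))` gives `((), 1, 3, 3)`.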

@clemisch
Contributor Author

I compiled and installed the latest jaxlib from source, and now it works 🤔
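Once the repro runs, the three results can be cross-checked. Singular vectors are only unique up to sign, so comparing `u1` with `u2` element-wise can fail even when both are correct; a safer check compares the singular values and the reconstruction `u @ diag(s) @ vh`. The sketch below uses NumPy only; `svd_matches` is a hypothetical helper, and the `atol=1e-4` tolerance is an assumption chosen for float32 inputs.

```python
import numpy as onp


def svd_matches(x, u, s, vh, atol=1e-4):
    """Return True if (u, s, vh) is a valid SVD of the batch x.

    Compares singular values and the reconstruction instead of u and vh
    directly, since singular vectors may differ in sign between libraries.
    Assumes square matrices (or full_matrices=False), so u has as many
    columns as s has entries.
    """
    recon = (u * s[..., None, :]) @ vh           # scales column j of u by s[j]
    s_ref = onp.linalg.svd(x, compute_uv=False)  # NumPy reference values
    return bool(onp.allclose(s, s_ref, atol=atol)
                and onp.allclose(recon, x, atol=atol))


x_host = onp.random.rand(1000, 3, 3).astype(onp.float32)
u1, s1, v1 = onp.linalg.svd(x_host)
print(svd_matches(x_host, u1, s1, v1))
```

The JAX outputs can be checked the same way after converting them to NumPy arrays, e.g. `svd_matches(x_host, onp.asarray(u2), onp.asarray(s2), onp.asarray(v2))`.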
