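```python
import jax
import jax.numpy as jnp

dij = jnp.array([
    [0., 2., 3., 2.],
    [2., 0., 4., 2.],
    [3., 4., 0., 1.],
    [2., 2., 1., 0.],
])

def cmdscale_3(D):
    '''
    D: distance matrix (N, N)
    returns: coordinates given the distance matrix (N, 3)
    '''
    # Generate the Gramian matrix (B) by double-centering the squared distances
    n = len(D)
    H = jnp.eye(n) - jnp.ones((n, n)) / n
    B = -H.dot(D ** 2).dot(H) / 2
    # Perform SVD; singular values are always real and non-negative, and for
    # symmetric B they equal the absolute values of B's eigenvalues
    u, s, v = jnp.linalg.svd(B)
    x = u.dot(jnp.diag(jnp.sqrt(s)))
    return x[:, :3] / 10  # keep the top 3 coordinates, scaled by 1/10

cmdscale_3(dij)  # the unbatched call works

dijs = jnp.array([dij, dij])
bfn = jax.vmap(cmdscale_3)
print(dijs.shape)  # (2, 4, 4)
bfn(dijs)          # the batched call raises the error below
```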
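For reference, here is a minimal sketch of the same classical MDS construction (double-centering the squared distances into a Gram matrix B, then scaling its leading eigenvectors by the square roots of the eigenvalues). It swaps in `jnp.linalg.eigh` instead of SVD because B is symmetric; that substitution, the helper names `classical_mds` and `pairwise_dist`, and the test points are assumptions for illustration, not the reporter's code. Using eigenvalues directly matters when D is not Euclidean: in `dij` above, points 1 and 2 (0-based) are at distance 4 while being within 2 and 1 of point 3, which violates the triangle inequality, so B has a negative eigenvalue whose magnitude SVD would report as a positive singular value.

```python
import jax
import jax.numpy as jnp


def pairwise_dist(X):
    """Euclidean distance matrix (N, N) for points X of shape (N, d)."""
    diff = X[:, None, :] - X[None, :, :]
    return jnp.sqrt(jnp.sum(diff ** 2, axis=-1))


def classical_mds(D, k):
    """Classical MDS: embed an (N, N) distance matrix D into k dimensions."""
    n = D.shape[0]
    H = jnp.eye(n) - jnp.ones((n, n)) / n   # centering matrix
    B = -H @ (D ** 2) @ H / 2               # Gram matrix of the centered points
    w, V = jnp.linalg.eigh(B)               # eigenvalues in ascending order
    w_top = w[::-1][:k]                     # k largest eigenvalues
    V_top = V[:, ::-1][:, :k]               # matching eigenvectors
    # Negative eigenvalues mean D is not Euclidean; clamp them to zero.
    return V_top * jnp.sqrt(jnp.maximum(w_top, 0.0))


# Distances computed from known 3-D points, so an exact embedding exists.
points = jnp.array([[0., 0., 0.],
                    [1., 0., 0.],
                    [0., 2., 0.],
                    [0., 0., 3.],
                    [1., 1., 1.]])
D = pairwise_dist(points)
X = classical_mds(D, 3)
print(jnp.max(jnp.abs(pairwise_dist(X) - D)))  # float32 round-off only

# Batched over a stack of distance matrices; k is a Python int, so it stays static.
Xb = jax.vmap(lambda d: classical_mds(d, 3))(jnp.stack([D, D]))
print(Xb.shape)  # (2, 5, 3)
```

The recovered coordinates differ from `points` by a rotation, reflection, and translation, but their pairwise distances match D, which is all classical MDS guarantees.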
This gives the following error:

```
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/core.py", line 133, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/interpreters/batching.py", line 115, in process_primitive
    val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/lax_linalg.py", line 798, in svd_batching_rule
    outs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/core.py", line 130, in bind
    return self.impl(*args, **kwargs)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/lax_linalg.py", line 726, in svd_impl
    compute_uv=compute_uv)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/interpreters/xla.py", line 123, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *abstract_args, **params)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/interpreters/xla.py", line 136, in xla_primitive_callable
    built_c = primitive_computation(prim, *xla_shapes, **params)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/interpreters/xla.py", line 150, in primitive_computation
    rule(c, *xla_args, **new_params) # return val set as a side-effect on c
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/lax_linalg.py", line 784, in _svd_cpu_gpu_translation_rule
    compute_uv=compute_uv)
  File "jaxlib/lapack.pyx", line 1073, in lapack.gesdd
ValueError: too many values to unpack (expected 2)
```
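The final `ValueError` is Python's generic tuple-unpacking error. The traceback shows `svd_batching_rule` re-binding `svd_p` on the whole batched input, which then reaches the `lapack.gesdd` wrapper, so a plausible reading (an inference, not confirmed from the jaxlib source) is that this wrapper unpacked the operand shape into two names and received a 3-D batched shape such as `(2, 4, 4)`. The function below is a hypothetical stand-in that reproduces only that unpacking pattern; it is not jaxlib's code.

```python
# Hypothetical stand-in for a wrapper that handles a single (M, N) matrix only.
# This is NOT jaxlib's implementation, just the same shape-unpacking pattern.
def unbatched_shape(shape):
    m, n = shape  # fails for a batched (B, M, N) shape
    return m, n


print(unbatched_shape((4, 4)))  # (4, 4)

try:
    unbatched_shape((2, 4, 4))  # the shape of the batched input in this report
except ValueError as e:
    print(e)  # message starts with "too many values to unpack"
```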
Is this unsupported, just buggy, or am I doing something stupid?
jax==0.1.46
jaxlib==0.1.28
I think this is already fixed in jaxlib at head (by #1314), but we haven't made a new binary release since then. It certainly works for me at head. I guess we should make a new release; until then, could you build jaxlib from source?
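A minimal sketch for checking whether an installed jax/jaxlib pair handles batched SVD under `jax.vmap` (check the jaxlib version separately, for example with `pip show jaxlib`). The random 4x4 batch is an assumption chosen to mirror the 4x4 matrices above. The Python-loop variant calls the SVD once per matrix without batching it, which should sidestep the batching rule seen in the traceback on an older jaxlib; it is a hedged workaround, not a fix endorsed in this thread.

```python
import jax
import jax.numpy as jnp

print(jax.__version__)  # the report used jax 0.1.46 with jaxlib 0.1.28

key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (2, 4, 4))  # a batch of two 4x4 matrices

# Batched SVD through jax.vmap: the call pattern that failed in the report.
s_vmap = jax.vmap(lambda m: jnp.linalg.svd(m, compute_uv=False))(A)

# Fallback that never batches the SVD: one unbatched call per matrix, then stack.
s_loop = jnp.stack([jnp.linalg.svd(m, compute_uv=False) for m in A])

print(s_vmap.shape)  # (2, 4)
print(bool(jnp.allclose(s_vmap, s_loop, atol=1e-4)))
```

The loop is slower for large batches because it issues one call per matrix, so it is only worth keeping until a jaxlib release with the fix is installed.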