Skip to content

Commit

Permalink
fix issues after rebase on master
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 6, 2020
1 parent 953e266 commit b199179
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
18 changes: 9 additions & 9 deletions jax/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def solve(a, b):
globals()[func.__name__] = _not_implemented(func)


@_wraps(onp.linalg.lstsq, lax_description=textwrap.dedent("""\
@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
It has two important differences:
1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
Expand Down Expand Up @@ -513,22 +513,22 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False):
m, n = a.shape
dtype = a.dtype
if rcond is None:
rcond = np.finfo(dtype).eps * max(n, m)
rcond = jnp.finfo(dtype).eps * max(n, m)
elif rcond < 0:
rcond = np.finfo(dtype).eps
rcond = jnp.finfo(dtype).eps
u, s, vt = svd(a, full_matrices=False)
mask = s >= rcond * s[0]
rank = mask.sum()
safe_s = np.where(mask, s, 1)
s_inv = np.where(mask, 1 / safe_s, 0)[:, np.newaxis]
uTb = np.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
x = np.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
safe_s = jnp.where(mask, s, 1)
s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
# Numpy returns empty residuals in some cases. To allow compilation, we
# default to returning full residuals in all cases.
if numpy_resid and (rank < n or m <= n):
resid = np.asarray([])
resid = jnp.asarray([])
else:
b_estimate = np.matmul(a, x, precision=lax.Precision.HIGHEST)
b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
resid = norm(b - b_estimate, axis=0) ** 2
if b_orig_ndim == 1:
x = x.ravel()
Expand Down
14 changes: 7 additions & 7 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,20 +862,20 @@ def testMultiDot(self, shapes, dtype, rng_factory):
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_devices("tpu") # SVD not implemented on TPU.
def testLstsq(self, lhs_shape, rhs_shape, dtype, lowrank, rcond, rng_factory):
rng = rng_factory()
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
onp_fun = partial(onp.linalg.lstsq, rcond=rcond)
jnp_fun = partial(np.linalg.lstsq, rcond=rcond)
jnp_fun_numpy = partial(np.linalg.lstsq, rcond=rcond, numpy_resid=True)
tol = {onp.float32: 1e-6, onp.float64: 1e-12,
onp.complex64: 1e-6, onp.complex128: 1e-12}
onp_fun = partial(np.linalg.lstsq, rcond=rcond)
jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
tol = {np.float32: 1e-6, np.float64: 1e-12,
np.complex64: 1e-6, np.complex128: 1e-12}
def args_maker():
lhs = rng(lhs_shape, dtype)
if lowrank and lhs_shape[1] > 1:
lhs[:, -1] = lhs[:, :-1].mean(1)
return [lhs, rng(rhs_shape, dtype)]

self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy, args_maker, check_dtypes=False, tol=tol)
self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy_resid, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol)

# Disabled because grad is flaky for low-rank inputs.
Expand Down

0 comments on commit b199179

Please sign in to comment.