Skip to content

Commit

Permalink
test gradient for lstsq
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 18, 2020
1 parent 6c860d4 commit 966021d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
25 changes: 13 additions & 12 deletions jax/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,27 +372,27 @@ def solve(a, b):
@_wraps(onp.linalg.lstsq, lax_description=textwrap.dedent("""\
It has two important differences:
1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and shows a deprecation
warning. In jax.numpy, the default rcond is `None`.
2. In `np.linalg.lstsq` the residuals return an empty list for low-rank solutions.
Here, the residuals are returned in all cases, to make the function compatible
with jit. The non-jit compatible numpy behavior can be recovered by passing
numpy_resid=True.
1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
the default will be `None`. Here, the default rcond is `None`.
2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or over-determined
solutions. Here, the residuals are returned in all cases, to make the function
compatible with jit. The non-jit compatible numpy behavior can be recovered by
passing numpy_resid=True.
"""))
def lstsq(a, b, rcond=None, *, numpy_resid=False):
# TODO: add lstsq to lax_linalg and implement this function via those wrappers.
a, b = _promote_arg_dtypes(a, b)
if a.shape[0] != b.shape[0]:
raise ValueError("Leading dimensions of input arrays must match")
b_ndim = b.ndim
if b_ndim == 1:
b_orig_ndim = b.ndim
if b_orig_ndim == 1:
b = b[:, None]
if a.ndim != 2:
raise TypeError(
f"{a.ndim}-dimensional array given. Array must be two-dimensional")
if b.ndim != 2:
raise TypeError(
f"{b_ndim}-dimensional array given. Array must be one or two-dimensional")
f"{b_original_ndim}-dimensional array given. Array must be one or two-dimensional")
m, n = a.shape
dtype = a.dtype
if rcond is None:
Expand All @@ -405,13 +405,14 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False):
safe_s = np.where(mask, s, 1)
s_inv = np.where(mask, 1 / safe_s, 0)[:, np.newaxis]
uTb = np.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
x = np.matmul(vt.conj().T, uTb * s_inv, precision=lax.Precision.HIGHEST)
# Numpy returns empty residuals in some cases. We return residuals
x = np.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
# Numpy returns empty residuals in some cases. To allow compilation, we
# default to returning full residuals in all cases.
if numpy_resid and (rank < n or m <= n):
resid = np.asarray([])
else:
b_estimate = np.matmul(a, x, precision=lax.Precision.HIGHEST)
resid = norm(b - b_estimate, axis=0) ** 2
if b_ndim == 1:
if b_orig_ndim == 1:
x = x.ravel()
return x, resid, rank, s
5 changes: 4 additions & 1 deletion tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,10 @@ def args_maker():

self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol)
# jtu.check_grads(jnp_fun, args_maker(), order=2, atol=tol, rtol=tol)

if np.finfo(dtype).bits == 64:
# Only check grad for first argument:
jtu.check_grads(lambda *args: jnp_fun(*args)[0], args_maker(), order=2, atol=1e-2, rtol=1e-2)

# Regression test for incorrect type for eigenvalues of a complex matrix.
@jtu.skip_on_devices("tpu") # TODO(phawkins): No complex eigh implementation on TPU.
Expand Down

0 comments on commit 966021d

Please sign in to comment.