Initial implementation of np.linalg.lstsq() via SVD (jax-ml#2744)
jakevdp authored and Jamie Townsend committed May 14, 2020
1 parent 043434d commit ca539ab
Showing 3 changed files with 94 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
@@ -313,6 +313,7 @@ jax.numpy.linalg
eigvals
eigvalsh
inv
lstsq
matrix_power
matrix_rank
multi_dot
53 changes: 53 additions & 0 deletions jax/numpy/linalg.py
@@ -483,3 +483,56 @@ def solve(a, b):
for func in get_module_functions(np.linalg):
  if func.__name__ not in globals():
    globals()[func.__name__] = _not_implemented(func)


@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
    It has two important differences:

    1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
       the default will be `None`. Here, the default rcond is `None`.
    2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or over-determined
       solutions. Here, the residuals are returned in all cases, to make the function
       compatible with jit. The non-jit compatible numpy behavior can be recovered by
       passing numpy_resid=True.

    The lstsq function does not currently have a custom JVP rule, so the gradient is
    poorly behaved for some inputs, particularly for low-rank `a`.
    """))
def lstsq(a, b, rcond=None, *, numpy_resid=False):
  # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
  # TODO: add custom jvp rule for more robust lstsq differentiation
  a, b = _promote_arg_dtypes(a, b)
  if a.shape[0] != b.shape[0]:
    raise ValueError("Leading dimensions of input arrays must match")
  b_orig_ndim = b.ndim
  if b_orig_ndim == 1:
    b = b[:, None]
  if a.ndim != 2:
    raise TypeError(
        f"{a.ndim}-dimensional array given. Array must be two-dimensional")
  if b.ndim != 2:
    raise TypeError(
        f"{b_orig_ndim}-dimensional array given. Array must be one or two-dimensional")
  m, n = a.shape
  dtype = a.dtype
  if rcond is None:
    rcond = jnp.finfo(dtype).eps * max(n, m)
  elif rcond < 0:
    rcond = jnp.finfo(dtype).eps
  u, s, vt = svd(a, full_matrices=False)
  mask = s >= rcond * s[0]
  rank = mask.sum()
  safe_s = jnp.where(mask, s, 1)
  s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
  uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
  x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
  # Numpy returns empty residuals in some cases. To allow compilation, we
  # default to returning full residuals in all cases.
  if numpy_resid and (rank < n or m <= n):
    resid = jnp.asarray([])
  else:
    b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
    resid = norm(b - b_estimate, axis=0) ** 2
  if b_orig_ndim == 1:
    x = x.ravel()
  return x, resid, rank, s
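
The jit-compatibility point from the description above can be illustrated with a short usage sketch (not part of this diff; the array values and the `solve` name are arbitrary examples):

import jax
import jax.numpy as jnp

# Arbitrary example data: an over-determined (3, 2) system.
a = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])
b = jnp.array([7.0, 8.0, 9.0])

# Default rcond=None uses eps * max(m, n) as the relative cutoff on singular values.
x, resid, rank, s = jnp.linalg.lstsq(a, b)

# Residuals always have a fixed shape with the default numpy_resid=False,
# so the call can be wrapped in jit directly.
solve = jax.jit(lambda a, b: jnp.linalg.lstsq(a, b)[0])
x_jit = solve(a, b)

# numpy_resid=True recovers NumPy's empty residuals for rank-deficient or
# under-determined problems, at the cost of jit compatibility, since the
# result shape then depends on the computed rank.
x2, resid2, rank2, s2 = jnp.linalg.lstsq(a, b, numpy_resid=True)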
40 changes: 40 additions & 0 deletions tests/linalg_test.py
@@ -842,6 +842,46 @@ def testMultiDot(self, shapes, dtype, rng_factory):
    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
                          atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_lhs={}_rhs={}_lowrank={}_rcond={}".format(
               jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               lowrank, rcond),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "lowrank": lowrank, "rcond": rcond, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 6), (4,)),
          ((6, 6), (6, 1)),
          ((8, 6), (8, 4)),
      ]
      for lowrank in [True, False]
      for rcond in [-1, None, 0.5]
      for dtype in float_types + complex_types
      for rng_factory in [jtu.rand_default]))
  @jtu.skip_on_devices("tpu")  # SVD not implemented on TPU.
  def testLstsq(self, lhs_shape, rhs_shape, dtype, lowrank, rcond, rng_factory):
    rng = rng_factory(self.rng())
    _skip_if_unsupported_type(dtype)
    onp_fun = partial(np.linalg.lstsq, rcond=rcond)
    jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
    jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
    tol = {np.float32: 1e-6, np.float64: 1e-12,
           np.complex64: 1e-6, np.complex128: 1e-12}
    def args_maker():
      lhs = rng(lhs_shape, dtype)
      if lowrank and lhs_shape[1] > 1:
        lhs[:, -1] = lhs[:, :-1].mean(1)
      return [lhs, rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy_resid, args_maker, check_dtypes=False, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol)

    # Disabled because grad is flaky for low-rank inputs.
    # TODO:
    # jtu.check_grads(lambda *args: jnp_fun(*args)[0], args_maker(), order=2, atol=1e-2, rtol=1e-2)
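
For context on the lowrank cases and the disabled gradient check, here is a small standalone sketch (not from the commit; the shapes and seed are arbitrary) of the kind of rank-deficient input that args_maker builds:

import numpy as np
import jax.numpy as jnp

rng = np.random.RandomState(0)
lhs = rng.randn(8, 3).astype(np.float32)
lhs[:, -1] = lhs[:, :-1].mean(1)   # last column is the mean of the others, as in args_maker
rhs = rng.randn(8, 4).astype(np.float32)

x, resid, rank, s = jnp.linalg.lstsq(jnp.asarray(lhs), jnp.asarray(rhs))
# The smallest singular value falls below the rcond cutoff, so rank should
# come out as 2 and the primal solution stays finite. Differentiating through
# the SVD near a masked singular value is what makes the commented-out
# check_grads call above flaky, hence it is left as a TODO.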

  # Regression test for incorrect type for eigenvalues of a complex matrix.
  @jtu.skip_on_devices("tpu")  # TODO(phawkins): No complex eigh implementation on TPU.
  def testIssue669(self):
