Add explicit derivative for jax.numpy.linalg.pinv. (#2794)
* Add explicit derivative for jax.numpy.linalg.pinv.

* Fix type confusion problems in the JVP rule for SVD that meant it produced 64-bit tangents for 32-bit primals.
hawkinsp authored Apr 23, 2020
1 parent c3ab1fc commit 8fe3c59
Showing 3 changed files with 50 additions and 15 deletions.
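
As a rough, user-level sketch of what the two fixes amount to (not part of the commit; the array values are illustrative, and the aliases follow the convention used in this diff, where `np` is `jax.numpy` and `onp` is classic NumPy):

import jax
import jax.numpy as np   # jax.numpy, matching the aliases in the diff below
import numpy as onp      # classic NumPy

a = onp.arange(6., dtype=onp.float32).reshape(3, 2)
da = onp.ones_like(a)

# SVD JVP fix: tangents now keep the primal dtype, so differentiating the SVD
# of a float32 matrix yields float32 tangents rather than float64 ones.
_, (du, ds, dvh) = jax.jvp(lambda x: np.linalg.svd(x, full_matrices=False),
                           (a,), (da,))
assert ds.dtype == onp.float32

# pinv fix: with the explicit JVP rule, jax.numpy.linalg.pinv can be
# differentiated directly (forward or reverse mode).
jac = jax.jacobian(np.linalg.pinv)(a)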
jax/lax_linalg.py (9 changes: 5 additions & 4 deletions)
@@ -852,17 +852,18 @@ def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
   s_dim = s[..., None, :]
   dS = np.matmul(np.matmul(Ut, dA), V)
   ds = np.real(np.diagonal(dS, 0, -2, -1))
-  F = 1 / (np.square(s_dim) - np.square(_T(s_dim)) + np.eye(k)) - np.eye(k)
+  F = 1 / (np.square(s_dim) - np.square(_T(s_dim)) + np.eye(k, dtype=A.dtype))
+  F = F - np.eye(k, dtype=A.dtype)
   dSS = s_dim * dS
   SdS = _T(s_dim) * dS
   dU = np.matmul(U, F * (dSS + _T(dSS)))
   dV = np.matmul(V, F * (SdS + _T(SdS)))

-  m, n = A.shape[-2], A.shape[-1]
+  m, n = A.shape[-2:]
   if m > n:
-    dU = dU + np.matmul(np.eye(m) - np.matmul(U, Ut), np.matmul(dA, V)) / s_dim
+    dU = dU + np.matmul(np.eye(m, dtype=A.dtype) - np.matmul(U, Ut), np.matmul(dA, V)) / s_dim
   if n > m:
-    dV = dV + np.matmul(np.eye(n) - np.matmul(V, Vt), np.matmul(_H(dA), U)) / s_dim
+    dV = dV + np.matmul(np.eye(n, dtype=A.dtype) - np.matmul(V, Vt), np.matmul(_H(dA), U)) / s_dim
   return (s, U, Vt), (ds, dU, _T(dV))

 def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute_uv):
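The lax_linalg.py changes all thread the primal's dtype into np.eye. As a minimal sketch of why that matters (illustrative only, not part of the commit; assumes 64-bit mode is enabled, as in the JAX test suite, where eye() defaults to float64):

from jax.config import config
config.update("jax_enable_x64", True)  # enable 64-bit mode before use

import jax.numpy as np

s = np.ones((3,), dtype=np.float32)
# Without an explicit dtype, eye() is float64 under x64 mode, so the mixed
# expression is promoted to float64, the "type confusion" the commit fixes:
den = np.square(s[None, :]) - np.square(s[:, None]) + np.eye(3)
# Passing the primal's dtype, as the patched JVP rule now does, keeps float32:
den_fixed = np.square(s[None, :]) - np.square(s[:, None]) + np.eye(3, dtype=s.dtype)
print(den.dtype, den_fixed.dtype)  # float64 float32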
jax/numpy/linalg.py (36 changes: 26 additions & 10 deletions)
@@ -33,6 +33,7 @@
 from ..third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve

 _T = lambda x: np.swapaxes(x, -1, -2)
+_H = lambda x: np.conj(np.swapaxes(x, -1, -2))


 def _promote_arg_dtypes(*args):
@@ -188,32 +189,47 @@ def eigvalsh(a, UPLO='L'):
   return w


+@partial(custom_jvp, nondiff_argnums=(1,))
 @_wraps(onp.linalg.pinv, lax_description=textwrap.dedent("""\
     It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
     default `rcond` is `1e-15`. Here the default is
     `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.
     """))
 def pinv(a, rcond=None):
-  # ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
+  # Uses same algorithm as
+  # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
   a = np.conj(a)
-  # copied from https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L442
   if rcond is None:
-      max_rows_cols = max(a.shape[-2:])
-      rcond = 10. * max_rows_cols * np.finfo(a.dtype).eps
+    max_rows_cols = max(a.shape[-2:])
+    rcond = 10. * max_rows_cols * np.finfo(a.dtype).eps
   rcond = np.asarray(rcond)
   u, s, v = svd(a, full_matrices=False)
   # Singular values less than or equal to ``rcond * largest_singular_value``
   # are set to zero.
   cutoff = rcond[..., np.newaxis] * np.amax(s, axis=-1, keepdims=True)
-  large = s > cutoff
-  s = np.divide(1, s)
-  s = np.where(large, s, 0)
-  vT = np.swapaxes(v, -1, -2)
-  uT = np.swapaxes(u, -1, -2)
-  res = np.matmul(vT, np.multiply(s[..., np.newaxis], uT))
+  s = np.where(s > cutoff, s, np.inf)
+  res = np.matmul(_T(v), np.divide(_T(u), s[..., np.newaxis]))
   return lax.convert_element_type(res, a.dtype)
+
+
+@pinv.defjvp
+def _pinv_jvp(rcond, primals, tangents):
+  # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
+  # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
+  # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
+  # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
+  a, = primals
+  a_dot, = tangents
+  p = pinv(a, rcond=rcond)
+  m, n = a.shape[-2:]
+  # TODO(phawkins): on TPU, we would need to opt into high precision here.
+  # TODO(phawkins): consider if this can be simplified in the Hermitian case.
+  p_dot = -p @ a_dot @ p
+  p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (np.eye(m, dtype=a.dtype) - a @ p)
+  p_dot = p_dot + (np.eye(n, dtype=a.dtype) - p @ a) @ _H(a_dot) @ _H(p) @ p
+  return p, p_dot


 @_wraps(onp.linalg.inv)
 def inv(a):
   if np.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
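For reference, the identity the new `_pinv_jvp` rule implements is the classical derivative of the Moore-Penrose pseudo-inverse (Golub & Pereyra 1973, as cited in the code comment). Writing $A^{+}$ for the pseudo-inverse of the $m \times n$ matrix $A$, $\dot{A}$ for its tangent, and assuming locally constant rank:

$$\dot{A}^{+} = -A^{+}\dot{A}A^{+} + A^{+}A^{+\mathsf{H}}\dot{A}^{\mathsf{H}}\left(I_m - AA^{+}\right) + \left(I_n - A^{+}A\right)\dot{A}^{\mathsf{H}}A^{+\mathsf{H}}A^{+},$$

which maps term-for-term onto the three `p_dot` updates above, with `_H` supplying the conjugate transpose.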
tests/linalg_test.py (20 changes: 19 additions & 1 deletion)
@@ -703,7 +703,7 @@ def args_maker():
{"testcase_name":
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype, "rng_factory": rng_factory}
for shape in [(1, 1), (4, 4), (2, 70, 7), (2000, 7), (7, 10000), (70, 7, 2)]
for shape in [(1, 1), (4, 4), (2, 70, 7), (2000, 7), (7, 1000), (70, 7, 2)]
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_devices("tpu") # SVD is not implemented on the TPU backend
@@ -716,6 +716,24 @@ def testPinv(self, shape, dtype, rng_factory):
     self._CheckAgainstNumpy(onp.linalg.pinv, np.linalg.pinv, args_maker,
                             check_dtypes=True, tol=1e-2)
     self._CompileAndCheck(np.linalg.pinv, args_maker, check_dtypes=True)
+    # TODO(phawkins): 1e-1 seems like a very loose tolerance.
+    jtu.check_grads(np.linalg.pinv, args_maker(), 2, rtol=1e-1)
+
+
+  def testPinvGradIssue2792(self):
+    def f(p):
+      a = np.array([[0., 0.], [-p, 1.]], np.float32) * 1 / (1 + p**2)
+      return np.linalg.pinv(a)
+    j = jax.jacobian(f)(np.float32(2.))
+    self.assertAllClose(np.array([[0., -1.], [0., 0.]], np.float32), j,
+                        check_dtypes=True)
+
+    expected = np.array([[[[-1., 0.], [0., 0.]], [[0., -1.], [0., 0.]]],
+                         [[[0., 0.], [-1., 0.]], [[0., 0.], [0., -1.]]]],
+                        dtype=np.float32)
+    self.assertAllClose(
+        expected, jax.jacobian(np.linalg.pinv)(np.eye(2, dtype=np.float32)),
+        check_dtypes=True)

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_n={}".format(
