
Derivative of the determinant of a singular matrix is NaN #2380

Closed
shyoshyo opened this issue Mar 8, 2020 · 26 comments
Labels: bug (Something isn't working)

@shyoshyo commented Mar 8, 2020

Consider the following code

from jax import grad
import jax.numpy as np

x = np.zeros(shape=(4,4))
d = grad(np.linalg.det)

print(d(x))

The derivative of np.linalg.det(x) with respect to x is obviously the zero matrix, but NaNs are output instead.

shyoshyo changed the title from "Exception raised when computing the derivative of singular matrix" to "Exception raised when computing the derivative of the determinant of a singular matrix" on Mar 8, 2020
@j-towns (Contributor) commented Mar 9, 2020

I guess my replies on the equivalent Autograd issue are relevant here too.

@shoyer (Collaborator) commented Mar 9, 2020

This doesn't raise an exception based on values (most JAX APIs cannot). Instead I see a matrix of all NaNs:

[[nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]]
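
As an aside, JAX's NaN-debugging mode can turn this kind of value-dependent failure into an exception at the op that first produces a NaN. A minimal sketch, using the documented jax_debug_nans flag:

from jax import config, grad
import jax.numpy as jnp

# With NaN debugging enabled, the op that produces the NaN raises a
# FloatingPointError instead of silently propagating NaNs.
config.update("jax_debug_nans", True)

x = jnp.zeros((4, 4))
grad(jnp.linalg.det)(x)  # raises FloatingPointError at the offending op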

shoyer changed the title from "Exception raised when computing the derivative of the determinant of a singular matrix" to "Derivative of the determinant of a singular matrix is NaN" on Mar 9, 2020
@shyoshyo (Author) commented:

> This doesn't raise an exception based on values (most JAX APIs cannot). Instead I see a matrix of all NaNs:
>
> [[nan nan nan nan]
>  [nan nan nan nan]
>  [nan nan nan nan]
>  [nan nan nan nan]]

Yes, you're right. Sorry, my mistake: I was testing autograd when I thought I was testing jax.

@mattjj (Collaborator) commented Mar 10, 2020

Should we just add a np.where / lax.select where the predicate is whether the determinant is zero?

mattjj added the bug label on Mar 10, 2020
@j-towns (Contributor) commented Mar 11, 2020

> Should we just add a np.where / lax.select where the predicate is whether the determinant is zero?

Sure, but we need to work out what formula to use to compute the derivative in the case det(x) == 0, because IIUC the derivative is not necessarily zero in that case.

The derivative of the determinant is the cofactor matrix (the transpose of the adjugate), which, as someone on SO points out, can be computed using the SVD, even in the case det(x) == 0. Presumably we want to avoid doing that computation unless det(x) == 0, not sure if we're able to do that automatically though (because it sounds like value-dependent control flow).

We could have a kwarg on the det function which enables correct derivatives for det(x) == 0 at the cost of computing the svd of x.
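
To make the SVD route concrete, here is a hedged sketch (an illustration with hypothetical names, not what ended up in JAX) of computing the gradient of det, i.e. the cofactor matrix, from the SVD of a single square matrix; it stays finite even when det(x) == 0:

import jax.numpy as jnp

def det_gradient_via_svd(a):
  # d(det(a))/da = cofactor(a) = sign * U @ diag(g) @ V^T, where a = U diag(s) V^T,
  # g[i] = prod of s[j] for j != i, and sign = det(U V^T) is +/-1.
  u, s, vt = jnp.linalg.svd(a)
  n = s.shape[-1]
  # Masked product so g is well defined even when some singular values are exactly zero.
  mask = ~jnp.eye(n, dtype=bool)
  g = jnp.prod(jnp.where(mask, s[None, :], 1.0), axis=-1)
  sign = jnp.linalg.det(u @ vt)
  return sign * (u * g[None, :]) @ vt

For example, det_gradient_via_svd(jnp.zeros((4, 4))) is the zero matrix, with no NaNs.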

@j-towns (Contributor) commented Mar 11, 2020

Actually there might be an easier fix: just use the LU directly to compute det instead of going via slogdet. I'll have a go at implementing it.

That doesn't work, because the lu jvp also produces NaNs for singular input. I thought you might be able to fix this simply by defining

# WARNING: you also need to compute the correct sign of the determinant, which I
# haven't bothered to do here.
@_wraps(onp.linalg.det)
def det(a):
  lu, _ = lax_linalg.lu(a)
  diag = np.diagonal(lu, axis1=-2, axis2=-1)
  return np.prod(diag)

instead of the current definition in terms of slogdet.
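
For reference, a hedged sketch of the missing sign term that the WARNING above mentions, using the pivot-parity trick that also appears in the code later in this thread (same assumed scope as the snippet above, i.e. onp, _wraps, lax_linalg and np available):

@_wraps(onp.linalg.det)
def det(a):
  lu, pivots = lax_linalg.lu(a)
  diag = np.diagonal(lu, axis1=-2, axis2=-1)
  # Each row swap recorded in the pivots flips the sign of the determinant.
  parity = np.count_nonzero(pivots != np.arange(a.shape[-1]), axis=-1)
  sign = np.array(-2 * (parity % 2) + 1, dtype=diag.dtype)
  return sign * np.prod(diag, axis=-1)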

Perhaps we should aim to fix this on the lu level then, but I'm not sure whether the lu derivative is well defined for singular input (whereas I'm pretty confident the determinant derivative is well defined for singular input).

@shoyer (Collaborator) commented Mar 13, 2020

Seems like it might make sense to define some custom JVP rules?

e.g., using any of these identities from the Matrix Cookbook:
[image: Matrix Cookbook identities for determinant derivatives, e.g. ∂(det(X)) = det(X) tr(X⁻¹ ∂X) and ∂ det(X)/∂X = det(X) (X⁻¹)ᵀ]

Direct derivative rules can often be much more efficient than differentiating the matrix factorization itself, at least they are for matrix solves.
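
A minimal sketch of such a custom JVP rule, built on the det(X) tr(X⁻¹ ∂X) identity above (an illustration, not JAX's actual implementation; note it still breaks down when X is singular):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def det(a):
  return jnp.linalg.det(a)

@det.defjvp
def det_jvp(primals, tangents):
  a, = primals
  da, = tangents
  y = det(a)
  # d(det(a)) = det(a) * tr(a^{-1} da); use a solve rather than an explicit inverse.
  return y, y * jnp.trace(jnp.linalg.solve(a, da), axis1=-2, axis2=-1)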

@dpfau (Contributor) commented Mar 25, 2020

The issue here is that the formula for the gradient of the determinant involves a product of the determinant and the matrix inverse. For a singular matrix, that would be basically 0*inf, which is why you get NaNs. I do have an implementation of the gradient of the determinant of a rank n-1 matrix that works directly with the LU decomposition. I don't think it works for generic low-rank matrices though.

Also, I'm sure this issue is unrelated to #2510.
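
A small illustration of the 0*inf behaviour described above (the exact inf/nan pattern may vary by backend):

import jax.numpy as jnp

a = jnp.zeros((4, 4))
print(jnp.linalg.det(a))                      # 0.0
print(jnp.linalg.inv(a))                      # inf/nan entries rather than an exception
print(jnp.linalg.det(a) * jnp.linalg.inv(a))  # 0 * inf -> nan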

@shoyer (Collaborator) commented Mar 25, 2020

> The issue here is that the formula for the gradient of the determinant involves a product of the determinant and the matrix inverse. For a singular matrix, that would be basically 0*inf, which is why you get NaNs.

Good point. Just to throw out the first idea that comes to mind: would using a pseudo-inverse instead of an inverse make sense here?

@dpfau (Contributor) commented Mar 25, 2020

Tried it, didn't work. Here's my code for the cofactor (transpose of the adjugate) that works for rank n-1 matrices:

from jax import lax
from jax import lax_linalg
from jax import ops
import jax.numpy.lax_numpy as np
import jax.numpy.linalg as linalg


def solve(a, b):
  """Compute cof(a)^T*b. Equivalent to det(a)*solve(a, b) for nonsingular mat.

  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.det to compute the gradient of the determinant in a way that
  is well defined even for rank n-1 matrices.
  * assumes a is at least rank n-1
  * assumes u_{nn} is the element set to 0 in singular cases.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A vector/matrix, or batch of vectors/matrices of the same dimension as a.

  Returns:
    cofactor(a)^T*b, aka adjugate(a)*b
  """
  a, b = linalg._promote_arg_dtypes(np.asarray(a), np.asarray(b))  # pylint: disable=protected-access
  a_shape = np.shape(a)
  b_shape = np.shape(b)
  a_ndims = len(a_shape)
  b_ndims = len(b_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2] and b_ndims >= 1):
    msg = ("The arguments to cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, k] or b=[..., m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))

  if a_shape[-1] == 1:
    return b

  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots = lax_linalg.lu(a)
  dtype = lax.dtype(a)

  m = a_shape[-1]

  # Numpy treats the RHS as a (batched) vector if the number of dimensions
  # differ by 1. Otherwise, broadcasting rules apply.
  x = b[..., None] if a_ndims == b_ndims + 1 else b

  batch_dims = lax.broadcast_shapes(lu.shape[:-2], x.shape[:-2])
  x = np.broadcast_to(x, batch_dims + x.shape[-2:])
  lu = np.broadcast_to(lu, batch_dims + lu.shape[-2:])

  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = np.diagonal(lu, axis1=-2, axis2=-1)
  parity = np.count_nonzero(pivots != np.arange(a_shape[-1]), axis=-1)
  sign = np.array(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[..., -1] contains the full determinant and
  # partial_det[..., -2] contains det(U) / U_{nn}.
  partial_det = np.cumprod(diag, axis=-1) * sign[..., None]
  lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])

  permutation = lax_linalg.lu_pivots_to_permutation(pivots, m)
  permutation = np.broadcast_to(permutation, batch_dims + (m,))
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1,)))
  x = x[iotas[:-1] + (permutation, slice(None))]

  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  x = ops.index_update(x, ops.index[..., :-1, :],
                       x[..., :-1, :] * partial_det[..., -1, None, None])
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)

  return x[..., 0] if a_ndims == b_ndims + 1 else x
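
As a hypothetical usage sketch (not part of the function above): since the gradient of det(a) is the cofactor matrix, it can be read off by passing the identity as b and transposing the result:

def det_gradient(a):
  # d(det(a))/da = cofactor(a) = adjugate(a)^T = solve(a, identity)^T
  eye = np.broadcast_to(np.eye(a.shape[-1], dtype=a.dtype), a.shape)
  return np.swapaxes(solve(a, eye), -1, -2)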

@dpfau (Contributor) commented Mar 25, 2020

In the case that the matrix rank is less than n-1, the gradient of the determinant will be identically zero. So we could add a lax.cond to the above function that checks if there is more than one zero on the diagonal of lu. The reason I hadn't done this yet was that I was not sure what tolerance triangular_solve uses to determine if a diagonal element is zero or not (presumably we should match that tolerance).
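
A hedged sketch of that guard (the function name and tolerance are illustrative assumptions, not what was merged): count the near-zero entries on the diagonal of lu and fall back to a zero gradient when there is more than one:

from jax import lax
import jax.numpy as jnp

def guard_rank_deficient(lu_diag, candidate_grad, tol=1e-12):
  # rank < n-1 corresponds to more than one (near-)zero pivot on the diagonal of U.
  num_zero = jnp.sum(jnp.abs(lu_diag) <= tol)
  return lax.cond(num_zero > 1,
                  lambda g: jnp.zeros_like(g),
                  lambda g: g,
                  candidate_grad)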

@shoyer (Collaborator) commented Mar 25, 2020

Just a note RE "This function borrows heavily from jax.numpy.linalg.solve": I updated the implementation of jax.numpy.linalg.solve a month or so ago in #2220, so it looks pretty different now. That would probably speed up gradients of your function; not sure if it would change the numerics.

@dpfau (Contributor) commented Mar 25, 2020

Thanks, I'll take a look.

Also, looking at triangular_solve, it seems like there are different primitives for different backends. Does anyone have any idea how to safely check if the backend considers the matrix singular? Should we just do a try/catch that returns zero if the backend considers the matrix to be singular?

@dpfau (Contributor) commented Mar 25, 2020

Though maybe you misunderstood my comment in the code. It borrows heavily from the forward computation of solve, not the gradient of solve. Did that change at all?

@dpfau (Contributor) commented Mar 25, 2020

Ah I guess you said it might change the gradient of my solve. Since solve here computes the gradient of the determinant, that would be the gradient of the gradient in my case.

@shoyer (Collaborator) commented Mar 25, 2020 via email

@dpfau (Contributor) commented Mar 25, 2020

I do, however, need second derivatives of the determinant, so it may be useful for me after all.

@j-towns (Contributor) commented Mar 26, 2020

I thought the following would be a nice quick work-around, the idea being to let the jvp rule of prod handle the awkwardness of differentiating a product of many terms. But it doesn't seem to support second derivatives yet (see below).

import jax.numpy.linalg as la
import jax.numpy as np

def correct_derivatives_det(x):
  _, s, _ = la.svd(x)
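  # note: np.prod(s) is |det(x)|; the sign of the determinant is lost with this work-around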
  return np.prod(s)

When you try to compute a second derivative you get

NotImplementedError: Forward-mode differentiation rule for 'reduce_window' not implemented

That's an issue that's already documented here, maybe this will provide a bit of extra motivation to implement that rule...

EDIT: To be clear, I'm suggesting the above as a temporary work-around, not a permanent fix. I think the correct thing to do is probably to fix the derivative of lu and then implement det in terms of lu (without bothering with a custom det derivative), because afaict that would be straightforward and pretty fast. There are also cholesky and qr if lu turns out to be complicated to fix.

@dpfau (Contributor) commented Mar 26, 2020

The function I shared above already works. I'm just working on integrating it into JAX.

@j-towns (Contributor) commented Mar 26, 2020

I'd be interested to know if that approach is faster than using the derivative of the lu decomposition. In the long term we should aim for an approach that works for all matrix ranks and differentiation orders (since that shouldn't be too difficult once we have 2nd and higher order np.prod derivatives).

@j-towns (Contributor) commented Mar 26, 2020

Just thinking also that defining det directly in terms of lu/svd/cholesky will be less numerically stable than using slogdet (just checked and Numpy computes det using slogdet). So that's another reason to maybe prefer keeping the current det implementation and adding a custom jvp.

@dpfau (Contributor) commented Mar 26, 2020

I'm actually having a weirdly difficult time reproducing this bug. For anything other than an identically zero matrix, even if the determinant is still zero to within numerical precision, the gradient often still works fine. I think this is due to loss of numerical precision that, weirdly, helps us in this case. Still, I've managed to find a few cases where the existing implementation returns NaN and the new version works.

@dpfau (Contributor) commented Mar 26, 2020

I've got a sort-of-working implementation now in my forked repo at github.com/dpfau/jax (see linalg.py and linalg_test.py). At the moment the issues are:

* I can't simultaneously define a custom_jvp and custom_vjp rule, and if I try to only use a custom_jvp rule, certain ops can't be transposed (I think scatter is the first to fail)
* Something seems to be failing in the higher-order derivatives as well

However, it sounds like the JAX team is still sorting out some other bugs with the new custom derivatives, so I'll wait until those are resolved.

hawkinsp added a commit to hawkinsp/jax that referenced this issue Apr 3, 2020

Reimplement np.cumsum and np.cumprod in terms of a parallel prefix scan.

Unlike the existing implementation based on lax.reduce_window, this implementation is O(n log n) instead of O(n^2) and is arbitrarily differentiable.

Fixes jax-ml#1212, jax-ml#2418, jax-ml#2542.
May help with issue jax-ml#2380.
hawkinsp added a commit that referenced this issue Apr 3, 2020
Reimplement np.cumsum and np.cumprod in terms of a parallel prefix scan. (#2596)

* Reimplement np.cumsum and np.cumprod in terms of a parallel prefix scan.

Unlike the existing implementation based on lax.reduce_window, this implementation is O(n log n) instead of O(n^2) and is arbitrarily differentiable.

Fixes #1212, #2418, #2542.
May help with issue #2380.

* Relax gradient test tolerance.
@j-towns (Contributor) commented Apr 6, 2020

I think #2597 being merged should mean that the code in #2380 (comment) will work for all orders of derivative and all matrix ranks 😊

NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this issue Apr 13, 2020 (jax-ml#2596: Reimplement np.cumsum and np.cumprod in terms of a parallel prefix scan.)
@dpfau (Contributor) commented Apr 25, 2020

This is now fixed by PR #2809. Please close this issue.

@mattjj (Collaborator) commented Apr 25, 2020

Woohoooooo thanks @dpfau!

mattjj closed this as completed on Apr 25, 2020