Derivative of the determinant of a singular matrix is NaN #2380
Comments
I guess my replies on the equivalent Autograd issue are relevant here too. |
This doesn't raise an exception based on values (most JAX APIs cannot). Instead I see a matrix of all NaNs:
|
Yes, you're right. Sorry, my mistake: I wrongly tested autograd as if I were testing jax. |
Should we just add a |
Sure, but we need to work out what formula to use to compute the derivative in the singular case. The derivative of the determinant is equal to the adjugate matrix, which, as someone on SO points out, can be computed using the SVD even in the singular case. We could have a kwarg on the det function which enables correct derivatives for singular matrices. |
That doesn't work as things stand. One option would be to define det as

```python
# WARNING: you also need to compute the correct sign of the determinant which I haven't
# bothered to do there.
@_wraps(onp.linalg.det)
def det(a):
  lu, _ = lax_linalg.lu(a)
  diag = np.diagonal(lu, axis1=-2, axis2=-1)
  return np.prod(diag)
```

instead of the current definition in terms of slogdet, so that differentiating at a singular matrix goes through the JVP rule for prod. Perhaps we should aim to fix this on the jax.numpy side. |
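Regarding the missing sign in the snippet above: a minimal sketch (not from the thread) of how it could be recovered from the LU pivots, assuming a version of jax.lax.linalg.lu that returns (lu, pivots, permutation):

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

def det_via_lu(a):
    # Determinant from an LU factorization, including the sign from row pivoting:
    #   det(A) = (-1)^(number of row swaps) * prod(diag(U)).
    lu, pivots, _ = lax_linalg.lu(a)
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    # Each pivot entry that differs from its own index records one row swap.
    parity = jnp.count_nonzero(pivots != jnp.arange(pivots.shape[-1]), axis=-1)
    sign = jnp.where(parity % 2 == 0, 1.0, -1.0).astype(diag.dtype)
    return sign * jnp.prod(diag, axis=-1)
```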
Seems like it might make sense to define some custom JVP rules, e.g. using the determinant identities from the Matrix Cookbook (restated below). Direct derivative rules can often be much more efficient than differentiating the matrix factorization itself; at least they are for matrix solves. |
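For reference, the standard Matrix Cookbook identities for the determinant are

$$\partial \det(X) = \det(X)\,\operatorname{tr}\!\left(X^{-1}\,\partial X\right), \qquad \frac{\partial \det(X)}{\partial X} = \det(X)\,X^{-\top} = \operatorname{adj}(X)^{\top}.$$

The adjugate form on the right stays well-defined when $X$ is singular; the $\det(X)\,X^{-\top}$ form is the $0 \cdot \infty$ product discussed in the next comment.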
The issue here is that the formula for the gradient of the determinant involves a product of the determinant and the matrix inverse. For a singular matrix, that would be basically 0*inf, which is why you get NaNs. I do have an implementation of the gradient of the determinant of a rank n-1 matrix that works directly with the LU decomposition. I don't think it works for generic low-rank matrices though. Also, I'm sure this issue is unrelated to #2510. |
Good point. Just to throw out the first idea that comes to mind: would using a pseudo-inverse instead of an inverse make sense here? |
Tried it, didn't work. Here's my code for the cofactor (transpose of the adjugate) that works for rank n-1 matrices:
|
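The LU-based snippet referred to in the previous comment did not come through; purely as an illustration of the SVD route mentioned a few comments earlier (my own sketch, not the author's code, unbatched only), a cofactor matrix can be computed without ever inverting the input:

```python
import jax.numpy as jnp

def cofactor_svd(a):
    # Cofactor matrix C = adj(A)^T from the SVD A = U diag(s) V^T:
    #   adj(A) = det(U V^T) * V * adj(Sigma) * U^T,  with adj(Sigma)_ii = prod_{j != i} s_j,
    # so C = det(U V^T) * U * adj(Sigma) * V^T. No division by s occurs, so this stays
    # finite for singular A.
    u, s, vt = jnp.linalg.svd(a)
    n = s.shape[-1]
    # prods[i] = product of all singular values except s[i], computed by masking
    # rather than dividing, so zero singular values are handled exactly.
    prods = jnp.prod(jnp.where(jnp.eye(n, dtype=bool), 1.0, s[None, :]), axis=-1)
    orient = jnp.linalg.det(u @ vt)  # +/- 1: orientation factor det(U V^T)
    return orient * (u * prods[None, :]) @ vt
```

Consistent with the next comment, when the rank is at most n-2 every prods[i] contains a zero factor, so the cofactor (and hence the gradient of the determinant) is identically zero.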
In the case that the matrix rank is less than n-1, the gradient of the determinant will be identically zero. So we could add a |
Just a note RE "This function borrows heavily from jax.numpy.linalg.solve": I updated the implementation of solve recently. |
Thanks, I'll take a look. Also, looking at triangular_solve, it seems like there are different primitives for different backends. Does anyone have any idea how to safely check if the backend considers the matrix singular? Should we just do a try/catch that returns zero if the backend considers the matrix to be singular? |
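Not an answer from the thread, but one JAX-idiomatic option (a sketch with hypothetical names, assuming you already have the LU diagonal in hand): because exceptions can't depend on traced values under jit, a value-level mask with jnp.where is usually used in place of try/except:

```python
import jax.numpy as jnp

def mask_if_singular(value, lu_diag):
    # Hypothetical helper: return `value` unchanged unless the factorization indicates
    # a singular matrix (a zero on the diagonal of U), in which case return zeros.
    # This works under jit because the condition is data, not Python control flow.
    is_singular = jnp.any(lu_diag == 0)
    return jnp.where(is_singular, jnp.zeros_like(value), value)
```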
Though maybe you misunderstood my comment in the code. It borrows heavily from the forward computation of solve. |
Ah I guess you said it might change the gradient of my solve. Since solve here computes the gradient of the determinant, that would be the gradient of the gradient in my case. |
Good point, my change probably isn't relevant for your gradient rule.
|
I do, however, need second derivatives of the determinant, so it may be useful for me after all. |
I thought the following would be a nice quick work-around, the idea being to let the JVP rule of prod handle the awkwardness of differentiating a product of many terms. But it doesn't seem to support second derivatives yet (see below).

```python
import jax.numpy.linalg as la
import jax.numpy as np

def correct_derivatives_det(x):
    _, s, _ = la.svd(x)
    return np.prod(s)
```

When you try to compute a second derivative you get an error. That's an issue that's already documented here; maybe this will provide a bit of extra motivation to implement that rule... EDIT: To be clear, I'm suggesting the above as a temporary work-around, not a permanent fix; I think the correct thing to do is probably to fix the derivative of |
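A usage sketch of the work-around above (the example matrix is mine; the comments about behaviour restate claims made elsewhere in this thread rather than anything verified here):

```python
import jax
import jax.numpy as jnp
import jax.numpy.linalg as la

def correct_derivatives_det(x):
    _, s, _ = la.svd(x)
    return jnp.prod(s)  # NB: this is |det(x)|; the sign is not recovered

x = jnp.array([[1.0, 2.0],
               [2.0, 4.0]])  # rank 1, so det(x) == 0 and the true gradient is finite

# First derivative: finite for this singular matrix, per the comment above.
print(jax.grad(correct_derivatives_det)(x))

# Second derivative: the step that originally failed, and which the cumprod-based
# prod derivative (#2596 / #2597, later in this thread) is said to fix.
print(jax.hessian(correct_derivatives_det)(x))
```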
The function I shared above already works. I'm just working on integrating it into JAX. |
I'd be interested to know if that approach is faster than using the derivative of the lu decomposition. In the long term we should aim for an approach that works for all matrix ranks and differentiation orders (which shouldn't be too difficult once we have 2nd and higher order derivatives for prod). |
Just thinking also that defining det directly in terms of lu/svd/cholesky will be less numerically stable than using slogdet (just checked and Numpy computes det using slogdet). So that's another reason to maybe prefer keeping the current det implementation and adding a custom jvp. |
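To make the stability point concrete (my own illustration, not from the thread): slogdet keeps the magnitude in log space, so it survives products that underflow when multiplied out directly.

```python
import jax.numpy as jnp

def det_via_slogdet(a):
    # The slogdet-based pattern: only exponentiate at the very end.
    sign, logabsdet = jnp.linalg.slogdet(a)
    return sign * jnp.exp(logabsdet)

a = 1e-3 * jnp.eye(200)
sign, logabsdet = jnp.linalg.slogdet(a)
print(logabsdet)                  # about -1381.6: finite and accurate
print(jnp.prod(jnp.diagonal(a)))  # direct product of the 200 factors underflows to 0.0
```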
I'm actually having a weirdly difficult time reproducing this bug. For anything other than an identically zero matrix, even if the determinant is still zero to within numerical precision, the gradient often still works fine. I think this is due to loss of numerical precision that, weirdly, helps us in this case. Still, I've managed to find a few cases where the existing implementation returns NaN and the new version works. |
I've got a sort-of-working implementation now in my forked repo at github.com/dpfau/jax. However, it sounds like the JAX team is still sorting out some other bugs involved with the new custom derivatives, so I'll wait until that is sorted out. |
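As an illustration of the shape such a custom-derivative rule might take (a sketch, not the implementation in that fork; first-order and unbatched only), one could pair jax.custom_jvp with the SVD-based cofactor sketched earlier:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def det(a):
    return jnp.linalg.det(a)

@det.defjvp
def _det_jvp(primals, tangents):
    (a,), (da,) = primals, tangents
    # d det(A) = tr(adj(A) dA) = sum(cofactor(A) * dA); no matrix inverse appears, so the
    # tangent stays finite for singular A (and is identically zero when rank(A) <= n - 2).
    u, s, vt = jnp.linalg.svd(a)
    n = s.shape[-1]
    prods = jnp.prod(jnp.where(jnp.eye(n, dtype=bool), 1.0, s[None, :]), axis=-1)
    cof = jnp.linalg.det(u @ vt) * (u * prods[None, :]) @ vt
    return jnp.linalg.det(a), jnp.sum(cof * da)
```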
Reimplement np.cumsum and np.cumprod in terms of a parallel prefix scan (#2596). Unlike the existing implementation based on lax.reduce_window, this implementation is O(n log n) instead of O(n^2) and is arbitrarily differentiable. Fixes #1212, #2418, #2542. May help with issue #2380. Also relaxes a gradient test tolerance.
I think #2597 being merged should mean that the code in #2380 (comment) will work for all orders of derivative and all matrix ranks 😊 |
This is now fixed by PR #2809. Please close this issue. |
Woohoooooo thanks @dpfau! |
Consider the following code (a reproduction sketch follows below): the derivative of np.linalg.det(x) against x is obviously a zero matrix, but NaNs were output.
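The original code block from the report did not come through; a minimal reproduction in the same spirit (the choice of x is mine) would be:

```python
import jax
import jax.numpy as jnp

x = jnp.zeros((3, 3))  # rank 0: every cofactor vanishes, so d det(x)/dx should be all zeros

print(jnp.linalg.det(x))            # 0.0
print(jax.grad(jnp.linalg.det)(x))  # at the time of this report: a matrix of NaNs
```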