Custom derivative for np.linalg.det #2809
@@ -148,12 +148,112 @@ def _slogdet_jvp(primals, tangents):

```python
  return (sign, ans), (sign_dot, ans_dot)


def _cofactor_solve(a, b):
  """Equivalent to det(a)*solve(a, b) for nonsingular matrices.

  Intermediate function used for the jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
    prod(diag(u)) * u^-1 @ l^-1 @ p^-1 @ b =
    prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then for the last row of y,
    y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii})
          = x_{n} * prod_{i=1...n-1}(u_{ii})
  so by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the remaining rows y_{i} for i != n, we simply
  multiply x_{i} by det(a) for all i != n, which will be zero if
  rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and a matrix of zeros is returned instead, since the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This still returns the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices, of the same shape as a.

  Returns:
    det(a) and cofactor(a)^T @ b, aka adjugate(a) @ b.
  """
  a = _promote_arg_dtypes(np.asarray(a))
  b = _promote_arg_dtypes(np.asarray(b))
  a_shape = np.shape(a)
  b_shape = np.shape(b)
  a_ndims = len(a_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
          and b_shape[-2:] == a_shape[-2:]):
    msg = ("The arguments to _cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  if a_shape[-1] == 1:
    # Batched indexing (a[..., 0, 0] rather than a[0, 0]) so that batches
    # of 1x1 matrices are handled correctly.
    return a[..., 0, 0], b
  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix. The diagonal of l is set to ones without loss of
  # generality.
  lu, pivots = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = np.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = np.broadcast_to(lu, batch_dims + lu.shape[-2:])
```
Review thread on lines +211 to +213:

> You probably already have this working, but I'll note that you could probably simplify this considerably if you make use of …

> This is taken pretty much line-by-line from an older version of …

> Actually you'd previously mentioned some things about changes to …

> From a quick look - it seems pretty complicated! I'll keep this in mind for the future - but as you said, this is working now.

> Right, we're using … I'll have to think a little bit more about the custom_linear_solve thing....

> I'm not familiar with …

> That's roughly my understanding. We can iterate on that in follow-up PRs though!
```python
  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = np.diagonal(lu, axis1=-2, axis2=-1)
  parity = np.count_nonzero(pivots != np.arange(a_shape[-1]), axis=-1)
  sign = np.array(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[..., -1] contains the full determinant and
  # partial_det[..., -2] contains det(u) / u_{nn}.
  partial_det = np.cumprod(diag, axis=-1) * sign[..., None]
  lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
  permutation = lax_linalg.lu_pivots_to_permutation(pivots, a_shape[-1])
  permutation = np.broadcast_to(permutation, batch_dims + (a_shape[-1],))
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1,)))
  # filter out any matrices that are not full rank
  d = np.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = np.any(np.logical_or(np.isnan(d), np.isinf(d)), axis=-1)
  d = np.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
  x = np.where(d, np.zeros_like(x), x)  # first filter
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  x = np.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                      x[..., -1:, :]), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = np.where(d, np.zeros_like(x), x)  # second filter

  return partial_det[..., -1], x
```
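As a sanity check on the docstring's LU identity, the two sides can be compared numerically for a nonsingular matrix using public `jax.scipy.linalg` APIs. This is just an illustrative sketch, not code from the diff; the `det(p)` factor supplies the permutation sign that the implementation tracks via `parity`:

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu, solve_triangular

a = jnp.array([[2., 1.],
               [1., 3.]])  # an arbitrary nonsingular matrix
b = jnp.eye(2)

# Direct computation of det(a) * solve(a, b).
direct = jnp.linalg.det(a) * jnp.linalg.solve(a, b)

# The same quantity via the LU route described in the docstring.
p, l, u = lu(a)
x = jnp.linalg.solve(p @ l, b)
via_lu = (jnp.linalg.det(p) * jnp.prod(jnp.diag(u))
          * solve_triangular(u, x, lower=False))

assert jnp.allclose(direct, via_lu)
```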
```python
@custom_jvp
@_wraps(onp.linalg.det)
def det(a):
  sign, logdet = slogdet(a)
  return sign * np.exp(logdet)


@det.defjvp
def _det_jvp(primals, tangents):
  x, = primals
  g, = tangents
  y, z = _cofactor_solve(x, g)
  return y, np.trace(z, axis1=-1, axis2=-2)
```
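The JVP rule here is Jacobi's formula, d det(x) = tr(adjugate(x) @ dx): `_cofactor_solve(x, g)` returns det(x) along with adjugate(x) @ g, so the tangent is just the trace of `z`. A minimal sketch of the motivating case, a rank n-1 matrix where differentiating through `slogdet` would yield NaNs (this assumes a jax build that includes this change):

```python
import jax
import jax.numpy as jnp

# A 3x3 matrix of rank 2 (i.e. rank n-1): its determinant is zero, but
# the gradient of det, the cofactor matrix, is a well-defined rank 1
# matrix rather than NaN.
a = jnp.array([[1., 1., 1.],
               [1., 1., 1.],
               [1., 2., 3.]])

print(jnp.linalg.det(a))            # 0.0
print(jax.grad(jnp.linalg.det)(a))  # [[ 1. -2.  1.]
                                    #  [-1.  2. -1.]
                                    #  [ 0.  0.  0.]]
```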
```python
@_wraps(onp.linalg.eig)
def eig(a):
  a = _promote_arg_dtypes(np.asarray(a))
```
Review thread:

> Do you have a reference on the method you use here for calculating the adjugate?

> Nope. Derived it myself. I could add a short description to the docstring?

> Done.
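For readers unfamiliar with the term: the adjugate is the transpose of the cofactor matrix, and for nonsingular matrices it satisfies adjugate(a) = det(a) * inv(a), which is the quantity `_cofactor_solve(a, b)` applies to `b`. A small illustrative check, not code from this PR:

```python
import jax.numpy as jnp

# For a 2x2 matrix [[a, b], [c, d]], the adjugate is [[d, -b], [-c, a]].
m = jnp.array([[1., 2.],
               [3., 4.]])
adjugate = jnp.linalg.det(m) * jnp.linalg.inv(m)
print(adjugate)  # ~[[ 4., -2.], [-3., 1.]]
```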