Add rules for `det` and `logdet` of Cholesky #613

Conversation

Force-pushed from 8a8ad3a to 54995d6
This produces:

```julia
julia> gradient(det, [1 2; 3 4]')
ERROR: Can't differentiate foreigncall expression

julia> gradient(det, SA[1 2; 3 4])
Internal error: encountered unexpected error in runtime:
BoundsError(a=Array{Core.Compiler.VarState, (2,)}[Core.Compiler.VarState(typ=Zygote.Pullback{Tuple{typeof(StaticArrays._det), StaticArrays.Size{(2, 2)}, StaticArrays.SArray{Tuple{2, 2}, Float64, 2, 4}}, Any}, undef=false), Core.Compiler.VarState(typ=Float64, undef=false)], i=(3,))
```

Isn't this precisely the case for which Edit: as noted here. Note also that
The following example works as expected with the PR:

```julia
julia> Zygote.gradient(logdet, [1 0; 0 1]')
([1.0 0.0; 0.0 1.0],)

julia> Zygote.gradient(det, [1 0; 0 1]')
([1.0 0.0; 0.0 1.0],)
```

IMO Implementation-wise, I also don't think the current inconsistency between
If you choose the numbers just right.

Because it was done before
Yeah, it should be fixed by defining `det` and `logdet` for LU factorizations and removing the ones for matrices completely. But this seemed like a task for a different PR.
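For context, a sketch of what that could look like: `LinearAlgebra` already defines `det` and `logdet` directly on `LU` factorizations, so rules could attach to those methods instead of to plain matrices (the example matrix here is chosen arbitrarily):

```julia
using LinearAlgebra

# `det`/`logdet` are already defined for LU factorizations, so rules
# could target these methods instead of the matrix methods.
A = [2.0 1.0; 1.0 3.0]  # arbitrary example matrix
F = lu(A)
det_F = det(F)          # product of diag(F.U), adjusted by the permutation sign
logdet_F = logdet(F)
```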
We should indeed do this. This is kinda a balancing act right now.
Splitting the type restrictions into a separate PR would let me merge the rest of this without delay.
src/rulesets/LinearAlgebra/dense.jl (Outdated)
```diff
     Ω = det(x)
     # TODO Performance optimization: probably there is an efficient
     # way to compute this trace without doing the full computation within
     return Ω, Ω * tr(x \ Δx)
 end
 frule((_, Δx), ::typeof(det), x::Number) = (det(x), Δx)

-function rrule(::typeof(det), x::Union{Number, AbstractMatrix})
+function rrule(::typeof(det), x::Union{Number, StridedMatrix{<:Number}})
```
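For illustration (not part of the PR): the restriction matters because wrapper types such as `Adjoint` and structured matrices are not subtypes of `StridedMatrix`, so after this change they would no longer hit this rule and would instead be differentiated through their own `det` methods:

```julia
using LinearAlgebra

A = [1 2; 3 4]
plain_is_strided = A isa StridedMatrix                 # true: plain arrays are strided
adjoint_is_strided = A' isa StridedMatrix              # false: the Adjoint wrapper is not
diagonal_is_strided = Diagonal([1, 2]) isa StridedMatrix  # false: structured matrices are not
```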
If these changes could be split out of this PR then we could merge this much faster.
I reverted these changes.
Yes, I thought about these issues as well during the last few days and came basically to the same conclusion. As discussed in the issues and comments linked above, I still think the existing rule for However, I think break all currently working cases such as

In any case, I think the best and fastest way forward is to remove the
@oxinabox can you have another look? I removed the

Bump 🙂
```julia
s = conj!((2 * y) ./ _diag_view(C.factors))
function det_Cholesky_pullback(ȳ)
    ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s))
```
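For reference, a quick sanity check of the formula this pullback is based on: for a real Cholesky factorization, `det(C) == prod(diag(C.factors))^2`, so the partial with respect to each diagonal entry `d_i` is `2 * det(C) / d_i` (the `2y / d_i` scaling stored in `s` above). A finite-difference check on an arbitrary positive definite example:

```julia
using LinearAlgebra

A = [4.0 2.0; 2.0 3.0]        # arbitrary positive definite example
C = cholesky(A)
d = diag(C.factors)
y = det(C)                    # equals prod(d)^2 for a real factorization
g = 2y ./ d                   # analytic partials ∂det/∂d_i

# finite-difference check on the first diagonal entry
ϵ = 1e-7
d2 = copy(d); d2[1] += ϵ
fd = (prod(d2)^2 - y) / ϵ
```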
Would something like this be better? There's one fewer allocation, and we don't need to assume that `s` is mutable.
```diff
-s = conj!((2 * y) ./ _diag_view(C.factors))
-function det_Cholesky_pullback(ȳ)
-    ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s))
+diagF = _diag_view(C.factors)
+function det_Cholesky_pullback(ȳ)
+    ΔC = Tangent{typeof(C)}(; factors=Diagonal(2(ȳ * conj(y)) ./ conj.(diagF)))
```
If the determinant is 0 (can happen if `check=false` is passed to `cholesky`), this will inject `NaN`s, even if the cotangent is 0. Since we try to treat cotangents as strong zeros, it would be nice to handle this case by ensuring that such `NaN`s end up as zeros.
> There's one fewer allocation, and we don't need to assume that `s` is mutable.

Seems like an improvement to me, thanks! 🙂

There's a whitespace missing in the first line of your suggestion, it seems:
```diff
-s = conj!((2 * y) ./ _diag_view(C.factors))
-function det_Cholesky_pullback(ȳ)
-    ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s))
+diagF = _diag_view(C.factors)
+function det_Cholesky_pullback(ȳ)
+    ΔC = Tangent{typeof(C)}(; factors=Diagonal(2(ȳ * conj(y)) ./ conj.(diagF)))
```
> If the determinant is 0

It would happen if `y = 0` (the determinant) but also if `ȳ = 0`. Should we care about the latter case as well? Or is it correct to return `NaN` there?
Ah, never mind, of course: at least one element of `diagF` is zero iff `y = 0`. I.e., we only have to care about `y = 0`.
Maybe even a bit clearer (without having to know about precedence of operators):

```diff
-s = conj!((2 * y) ./ _diag_view(C.factors))
-function det_Cholesky_pullback(ȳ)
-    ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s))
+diagF = _diag_view(C.factors)
+function det_Cholesky_pullback(ȳ)
+    ΔC = Tangent{typeof(C)}(; factors=Diagonal((2 * (ȳ * conj(y))) ./ conj.(diagF)))
```
I guess something like the following could work?

```julia
# compute `x / conj(y)`, handling `x = y = 0`
function _x_divide_conj_y(x, y)
    z = x / conj(y)
    # in our case `iszero(x)` implies `iszero(y)`
    return iszero(x) ? zero(z) : z
end

function rrule(::typeof(det), C::Cholesky)
    y = det(C)
    diagF = _diag_view(C.factors)
    function det_Cholesky_pullback(ȳ)
        ΔC = Tangent{typeof(C)}(; factors=Diagonal(_x_divide_conj_y.(2 * ȳ * conj(y), diagF)))
```
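A quick check of the behaviour this helper is after (using the name from the suggestion above): plain division turns `0/0` into `NaN`, while the guard returns a strong zero:

```julia
# helper from the suggestion above: `x / conj(y)` with `x = y = 0` mapped to zero
function _x_divide_conj_y(x, y)
    z = x / conj(y)
    return iszero(x) ? zero(z) : z
end

plain = 0.0 / 0.0                      # NaN
guarded = _x_divide_conj_y(0.0, 0.0)   # strong zero instead of NaN
regular = _x_divide_conj_y(1.0, 4.0)   # ordinary division is unchanged
```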
@sethaxen I updated the PR and added tests for singular matrices.
LGTM, merge and tag when you are happy.
```julia
y = logdet(C)
diagF = _diag_view(C.factors)
function logdet_Cholesky_pullback(ȳ)
    ΔC = Tangent{typeof(C)}(; factors=Diagonal((2 * ȳ) ./ conj.(diagF)))
```
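As a sanity check of this formula: for a real factorization `logdet(C) == 2 * sum(log, diag(C.factors))`, so `∂logdet/∂d_i = 2/d_i`, matching the `(2 * ȳ) ./ conj.(diagF)` scaling above. A finite-difference check on an arbitrary positive definite example:

```julia
using LinearAlgebra

A = [4.0 2.0; 2.0 3.0]         # arbitrary positive definite example
C = cholesky(A)
d = diag(C.factors)
y = logdet(C)                  # equals 2 * sum(log, d)
g = 2 ./ d                     # analytic partials ∂logdet/∂d_i

# finite-difference check on the first diagonal entry
ϵ = 1e-7
d2 = copy(d); d2[1] += ϵ
fd = (2 * sum(log, d2) - y) / ϵ
```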
I suspect there's something that can be done here as well to make it more `NaN`-safe, but I think this should not block this PR.
I was wondering that as well (there are some - now hidden - comments above) but it felt like usually we don't handle such things in a special way if it can only be triggered by specific cotangents but is not an immediate consequence of the inputs. Or do we?
I don't think we test this for all the rules we should, but a principle that we seem to be agreed on is that zero (co)tangents should be strong zeros (see e.g. JuliaDiff/ChainRulesCore.jl#551 (comment)). So in this case if `ȳ == 0`, then the cotangent of `factors` should be a zero matrix. Otherwise you end up with cases like `zero(logdet(cholesky(A; check=false)))`, which pulls back a zero cotangent through this rule, injecting `NaN`s into all downstream cotangents, even though the output is unrelated to the value of `A`.
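To see the failure mode concretely: with `check=false` a singular matrix leaves a zero on the diagonal of `C.factors`, so the unguarded pullback divides a zero cotangent by zero (a sketch of the problem, not the PR's final code):

```julia
using LinearAlgebra

A = [1.0 1.0; 1.0 1.0]               # singular PSD matrix
C = cholesky(A; check=false)         # does not throw, but factors contain a zero
diagF = diag(C.factors)
ȳ = 0.0                              # zero cotangent
unguarded = (2 * ȳ) ./ conj.(diagF)  # contains NaN despite ȳ == 0
```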
I see, I was not aware of this principle. It would maybe be good to add it to the docs and possibly CRTestUtils 🙂

I'll update the PR accordingly.
Should be fixed in d831cd4
LGTM! Assuming the same tests pass that pass on main, I think this is ready to merge. Thanks, @devmotion!
This PR adds rules for `det(::Cholesky)` and `logdet(::Cholesky)` (currently defined in Zygote for real numbers: https://github.com/FluxML/Zygote.jl/blob/885a904ed958c74cdaa2af7a971a6b5a2da908a7/src/lib/array.jl#L744-L748), and restricts the existing rules of `det` to `StridedMatrix{<:Number}`, as already done in #245 for `logdet`. Edit: reverted, see below.

With these changes it will become possible to compute the gradient of `det ∘ PDMats.PDMat` with Zygote (see JuliaStats/PDMats.jl#159). As an example (works even after removing https://github.com/FluxML/Zygote.jl/blob/885a904ed958c74cdaa2af7a971a6b5a2da908a7/src/lib/array.jl#L744-L748):

cc: @theogf @simsurace