Add rules for det and logdet of Cholesky #613

Merged: 9 commits into main from dw/cholesky_det_logdet on May 18, 2022

Conversation

@devmotion (Member) commented May 10, 2022

This PR adds rules for det(::Cholesky) and logdet(::Cholesky) (currently defined in Zygote for real numbers: https://github.com/FluxML/Zygote.jl/blob/885a904ed958c74cdaa2af7a971a6b5a2da908a7/src/lib/array.jl#L744-L748), and restricts the existing rules of det to StridedMatrix{<:Number}, as already done in #245 for logdet. Edit: reverted, see below.

With these changes it will become possible to compute the gradient of det ∘ PDMats.PDMat with Zygote (see JuliaStats/PDMats.jl#159). As an example (works even after removing https://github.com/FluxML/Zygote.jl/blob/885a904ed958c74cdaa2af7a971a6b5a2da908a7/src/lib/array.jl#L744-L748):

julia> using PDMats, ChainRulesCore, LinearAlgebra

julia> @opt_out rrule(::typeof(det), ::AbstractPDMat)

julia> using Zygote

julia> x = [1. 0.2; 0.2 1.];

julia> Zygote.gradient(logdet ∘ PDMat, x)
([1.0416666666666667 -0.41666666666666674; 0.0 1.0416666666666667],)

julia> Zygote.gradient(det ∘ PDMat, x)
([1.0 -0.4000000000000001; 0.0 1.0000000000000002],)

julia> Zygote.gradient(logdet ∘ PDMat, Symmetric(x))
([1.0416666666666667 -0.20833333333333337; -0.20833333333333337 1.0416666666666667],)

julia> Zygote.gradient(det ∘ PDMat, Symmetric(x))
([1.0 -0.20000000000000004; -0.20000000000000004 1.0000000000000002],)
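
For background, the quantities above only depend on the diagonal of the Cholesky factor, which is what the new rules differentiate. A minimal illustration (not code from this PR):

using LinearAlgebra

A = [1.0 0.2; 0.2 1.0]
C = cholesky(A)

# A == U'U, so the determinant is the squared product of the factor's diagonal
det(A) ≈ prod(diag(C.factors))^2           # true
logdet(A) ≈ 2 * sum(log, diag(C.factors))  # true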

cc: @theogf @simsurace

@devmotion requested a review from @oxinabox on May 10, 2022 15:08
@mcabbott (Member) commented May 10, 2022

> existing rules of det to StridedMatrix{<:Number}

This produces:

julia> gradient(det, [1 2; 3 4]')
ERROR: Can't differentiate foreigncall expression

julia> gradient(det, SA[1 2; 3 4])
Internal error: encountered unexpected error in runtime:
BoundsError(a=Array{Core.Compiler.VarState, (2,)}[Core.Compiler.VarState(typ=Zygote.Pullback{Tuple{typeof(StaticArrays._det), StaticArrays.Size{(2, 2)}, StaticArrays.SArray{Tuple{2, 2}, Float64, 2, 4}}, Any}, undef=false), Core.Compiler.VarState(typ=Float64, undef=false)], i=(3,))

Isn't this precisely the case for which ChainRulesCore.@opt_out was invented?

Edit: as noted here. Note also that the logabsdet restriction pre-dates this.

@devmotion (Member, Author) commented May 10, 2022

The following example works as expected with the PR:

julia> Zygote.gradient(logdet, [1 0; 0 1]')
([1.0 0.0; 0.0 1.0],)

julia> Zygote.gradient(det, [1 0; 0 1]')
([1.0 0.0; 0.0 1.0],)

IMO AbstractMatrix is really just too generic here, and generally you don't want to define these rules at all for matrices but rather for factorizations - as this PR does for Cholesky. See also #468 (comment) and #456 (comment).

Implementation-wise, I also don't think the current inconsistency between det and logdet is desirable. Why should det be defined more generally if logdet is restricted to StridedMatrices? That just leads to inconsistencies and surprising issues as in the PDMats issue you linked.

@mcabbott (Member)

> The following example works

If you choose the numbers just right.

> Why should det be defined more generally if logdet is restricted to StridedMatrices?

Because it was done before @opt_out etc. Ideally it would now be fixed.

@devmotion (Member, Author)

> Ideally it would now be fixed.

Yeah, it should be fixed by defining det and logdet for LU factorizations and removing the ones for matrices completely. But this seemed like a task for a different PR.
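
To illustrate the idea (a sketch under the assumption of a square, invertible input; not part of this PR): both quantities can be read off an LU factorization, so defining the rules on LU (together with a rule for lu itself) would cover Matrix, Adjoint, and friends.

using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]
F = lu(A)

det(A) ≈ det(F)                              # true
logabsdet(A)[1] ≈ sum(log ∘ abs, diag(F.U))  # true, since L has a unit diagonal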

@oxinabox (Member)

> Yeah, it should be fixed by defining det and logdet for LU factorizations and removing the ones for matrices completely. But this seemed like a task for a different PR.

We should indeed do this.
Is there a reason not to use @opt_out until that has been done?
(I am open to hearing it).

This is kind of a balancing act right now.
Do we break StaticArrays and Adjoint matrices (which are very commonly used), or do we break PDMats, which is commonly used in probabilistic and ML applications where gradients are often wanted?
Without further convincing I am leaning towards leaving things as they are, since it is more confusing when something that used to work stops working than when something that was broken continues to be broken.
But I am open to being convinced.

@oxinabox (Member) left a comment

Splitting the type restrictions into a separate PR would let me merge the rest of this without delay.

Ω = det(x)
# TODO Performance optimization: probably there is an efficient
# way to compute this trace without doing the full computation
return Ω, Ω * tr(x \ Δx)
end
frule((_, Δx), ::typeof(det), x::Number) = (det(x), Δx)

-function rrule(::typeof(det), x::Union{Number, AbstractMatrix})
+function rrule(::typeof(det), x::Union{Number, StridedMatrix{<:Number}})
Member:

If these changes could be split out of this PR then we could merge this much faster.

Member Author:

I reverted these changes.

@devmotion (Member, Author)

Yes, I thought about these issues as well during the last few days and came basically to the same conclusion.

As discussed in the issues and comments linked above, I still think the existing rule for det is particularly bad - not only since it captures too many matrix types that it shouldn't (wrappers such as AbstractPDMat but also e.g. Symmetric), but also since it doesn't even work correctly for e.g. Matrix{Float64}, since it errors for singular matrices.

However, I think breaking all currently working cases such as StaticArray and Adjoint is definitely undesirable. I wonder if the logdet restriction to StridedMatrix broke these as well or how similar problems were avoided.

In any case, I think the best and fastest way forward is to remove the det changes from this PR and open a separate PR for them.

github-actions bot added the "needs version bump" label (Version needs to be incremented or set to -DEV in Project.toml) on May 12, 2022
github-actions bot removed the "needs version bump" label on May 12, 2022
@devmotion (Member, Author)

@oxinabox can you have another look? I removed the det changes, I think the PR should be easier to review and merge now.

@devmotion (Member, Author)

Bump 🙂

Comment on lines 558 to 560
s = conj!((2 * y) ./ _diag_view(C.factors))
function det_Cholesky_pullback(ȳ)
ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s))
Member:

Would something like this be better? There's one fewer allocation, and we don't need to assume that s is mutable.

Suggested change
-s = conj!((2 * y) ./ _diag_view(C.factors))
-function det_Cholesky_pullback(ȳ)
-    ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s))
+diagF = _diag_view(C.factors)
+function det_Cholesky_pullback(ȳ)
+    ΔC = Tangent{typeof(C)}(; factors=Diagonal(2(ȳ * conj(y)) ./ conj.(diagF)))

Member:

If the determinant is 0 (can happen if check=false is passed to cholesky), this will inject NaNs, even if the cotangent is 0. Since we try to treat cotangents as strong zeros, it would be nice to handle this case by ensuring that such NaNs end up as zeros.
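
For example (hypothetical numbers, just to illustrate the concern): with check=false a singular input leaves a zero on the diagonal of factors, and the division then produces a NaN even when the incoming cotangent is zero.

using LinearAlgebra

C = cholesky([1.0 1.0; 1.0 1.0]; check=false)  # singular, so det(C) == 0
y = det(C)
ȳ = 0.0                                        # zero cotangent coming in
2(ȳ * conj(y)) ./ conj.(diag(C.factors))       # [0.0, NaN]: 0/0 on the zero pivot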

Member Author:

> There's one fewer allocation, and we don't need to assume that s is mutable.

Seems like an improvement to me, thanks! 🙂

There's a whitespace missing in the first line of your suggestion, it seems:

Suggested change
-s = conj!((2 * y) ./ _diag_view(C.factors))
-function det_Cholesky_pullback(ȳ)
-    ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s))
+diagF = _diag_view(C.factors)
+function det_Cholesky_pullback(ȳ)
+    ΔC = Tangent{typeof(C)}(; factors=Diagonal(2(ȳ * conj(y)) ./ conj.(diagF)))

Member Author:

> If the determinant is 0

It would happen if y = 0 (the determinant) but also if ȳ = 0. Should we care about the latter case as well? Or is it correct to return NaN there?

Member Author:

Ah, never mind, of course: at least one element of diagF is zero iff y = 0. I.e., we only have to care about y = 0.

Member Author:

Maybe even a bit clearer (without having to know about precedence of operators):

Suggested change
-s = conj!((2 * y) ./ _diag_view(C.factors))
-function det_Cholesky_pullback(ȳ)
-    ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s))
+diagF = _diag_view(C.factors)
+function det_Cholesky_pullback(ȳ)
+    ΔC = Tangent{typeof(C)}(; factors=Diagonal((2 * (ȳ * conj(y))) ./ conj.(diagF)))

@devmotion (Member, Author) commented May 18, 2022

I guess something like the following could work?

# compute `x / conj(y)`, handling `x = y = 0`
function _x_divide_conj_y(x, y)
    z = x / conj(y)
    # in our case `iszero(x)` implies `iszero(y)`
    return iszero(x) ? zero(z) : z
end
function rrule(::typeof(det), C::Cholesky)
    y = det(C)
    diagF = _diag_view(C.factors)
    function det_Cholesky_pullback(ȳ)
        ΔC = Tangent{typeof(C)}(; factors=Diagonal(_x_divide_conj_y.(2 * ȳ * conj(y), diagF)))
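
A quick check of the proposed helper's behaviour (illustrative only):

_x_divide_conj_y(0.0, 0.0)  # 0.0 rather than NaN
_x_divide_conj_y(1.0, 2.0)  # 0.5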

Member Author:

@sethaxen I updated the PR and added tests for singular matrices.

@oxinabox (Member) left a comment

LGTM,
merge and tag when you are happy

src/rulesets/LinearAlgebra/factorization.jl
y = logdet(C)
diagF = _diag_view(C.factors)
function logdet_Cholesky_pullback(ȳ)
ΔC = Tangent{typeof(C)}(; factors=Diagonal((2 * ȳ) ./ conj.(diagF)))
Member:

I suspect there's something that can be done here as well to make it more NaN-safe, but I think this should not block this PR.

Member Author:

I was wondering that as well (there are some, now hidden, comments above), but it felt like we usually don't handle such things in a special way if they can only be triggered by specific cotangents and are not an immediate consequence of the inputs. Or do we?

Member:

I don't think we test this for all the rules we should, but a principle that we seem to be agreed on is that zero (co)tangents should be strong zeros (see e.g. JuliaDiff/ChainRulesCore.jl#551 (comment)).

So in this case if ȳ == 0, then the cotangent of factors should be a zero matrix. Otherwise you end up with cases like zero(logdet(cholesky(A; check=false))), which pulls back a zero cotangent through this rule, injecting NaNs into all downstream cotangents, even though the output is unrelated to the value of A.
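
Put differently (a trivial illustration): once a NaN has entered the factors cotangent, zero multiplications further downstream cannot remove it, so the zero has to be enforced inside this rule:

0.0 * NaN  # NaN, a plain zero does not absorb it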

Member Author:

I see, I was not aware of this principle. It would maybe be good to add it to the docs and possibly CRTestUtils 🙂

I'll update the PR accordingly.

Member Author:

Should be fixed in d831cd4
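
For reference, a NaN-safe variant could look roughly like this (my sketch based on the discussion above, not necessarily the exact code in d831cd4; it uses diag instead of the internal _diag_view helper, and _safe_divide / logdet_cholesky_rrule_sketch are hypothetical names):

using LinearAlgebra, ChainRulesCore

# divide, but treat a zero numerator as a strong zero
_safe_divide(x, y) = (z = x / conj(y); iszero(x) ? zero(z) : z)

function logdet_cholesky_rrule_sketch(C::Cholesky)
    y = logdet(C)
    diagF = diag(C.factors)
    function logdet_Cholesky_pullback(ȳ)
        ΔC = Tangent{typeof(C)}(; factors=Diagonal(_safe_divide.(2 * ȳ, diagF)))
        return NoTangent(), ΔC
    end
    return y, logdet_Cholesky_pullback
end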

@sethaxen (Member) left a comment

LGTM! Assuming the same tests pass that pass on main, I think this is ready to merge. Thanks, @devmotion!

@devmotion merged commit 02a8172 into main on May 18, 2022
@devmotion deleted the dw/cholesky_det_logdet branch on May 18, 2022 23:28