AD fix for CorrBijector #281
ext/BijectorsZygoteExt.jl
# TODO: Remove these as soon as https://github.com/FluxML/Zygote.jl/pull/1444 is merged.
@adjoint LinearAlgebra.parent(x::LinearAlgebra.UpperTriangular) = parent(x), ȳ -> (LinearAlgebra.UpperTriangular(ȳ),)
@adjoint LinearAlgebra.parent(x::LinearAlgebra.LowerTriangular) = parent(x), ȳ -> (LinearAlgebra.LowerTriangular(ȳ),)
I've added this here to check whether the tests pass; I'll remove it as soon as the mentioned PR is merged.
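The effect of these rules can be checked by hand without Zygote. `parent` returns the full backing matrix, and the pullback maps a cotangent for that matrix back onto the triangular wrapper by discarding the entries outside the triangle. A minimal stdlib-only sketch, where `parent_pullback_upper` is a hypothetical helper that mimics the closure in the `UpperTriangular` rule above:

```julia
using LinearAlgebra

# `parent` unwraps the triangular wrapper and exposes the full backing matrix,
# including the entry below the diagonal that the wrapper itself ignores.
A = [1.0 2.0; 3.0 4.0]
U = UpperTriangular(A)
P = parent(U)  # the very same array as A, lower entry 3.0 included

# Hypothetical hand-written pullback mirroring the Zygote rule: a cotangent ȳ
# for parent(U) is mapped back to U by keeping only its upper triangle.
parent_pullback_upper(ȳ) = (UpperTriangular(ȳ),)

ȳ = ones(2, 2)
(Ū,) = parent_pullback_upper(ȳ)
```

The lower-triangular rule is the same construction with `LowerTriangular` in place of `UpperTriangular`.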
@@ -321,7 +325,7 @@ function _link_chol_lkj(W::UpperTriangular)
    return z
end

-_link_chol_lkj(W::LowerTriangular) = _link_chol_lkj(transpose(W))
+_link_chol_lkj_from_lower(W::AbstractMatrix) = _link_chol_lkj_from_upper(transpose_eager(W))
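The new definition reuses the upper-triangular implementation by eagerly transposing a lower factor. The sketch below mimics that dispatch shape with hypothetical `link_from_upper`/`link_from_lower` functions (they only collect the strictly upper entries and are not the LKJ link); `transpose_eager` follows the commit log's description of it as an alias for `permutedims`:

```julia
using LinearAlgebra

# Per the commit log, `transpose_eager` is an alias for `permutedims`:
# it returns a materialized array rather than a lazy `Transpose` wrapper.
transpose_eager(X::AbstractMatrix) = permutedims(X)

# Hypothetical stand-in for an "upper" implementation: collect the strictly
# upper-triangular entries, column by column.
function link_from_upper(U::AbstractMatrix)
    n = size(U, 1)
    return [U[i, j] for j in 2:n for i in 1:(j - 1)]
end

# The lower variant reuses the upper one after an eager transpose,
# mirroring the shape of `_link_chol_lkj_from_lower` in the diff.
link_from_lower(L::AbstractMatrix) = link_from_upper(transpose_eager(L))

# A small correlation matrix and its Cholesky factors (L == U').
Σ = [1.0 0.3 -0.2; 0.3 1.0 0.5; -0.2 0.5 1.0]
C = cholesky(Σ)
```

Because `L == U'`, both entry points see the same numbers, so only one implementation needs its own AD rule.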
This one is a bit weird: I just noticed that it has its own chain rule defined, even though the evaluation itself has no specialized implementation.
There is a reproducible Cholesky factorization failure. I couldn't find any obvious cause.
* Update chainrules.jl
* Update corr.jl
* Revert changes to transform.
Looks like in v1.6: … In the latest version: …
Also, when I try to run Bijectors.jl/test/transform.jl (lines 216 to 230 in 06b936d), I encounter:
julia> dist = LKJCholesky(3, 1)
LKJCholesky{Float64}(
d: 3
η: 1.0
uplo: L
)
julia> x = rand(dist)
Cholesky{Float64, Matrix{Float64}}
L factor:
3×3 LowerTriangular{Float64, Matrix{Float64}}:
1.0 ⋅ ⋅
0.198156 0.980171 ⋅
-0.797754 0.257329 0.545316
julia> J = ForwardDiff.jacobian(x -> link(dist, x), x.U)
ERROR: PosDefException: matrix is not positive definite; Cholesky factorization failed.
Stacktrace:
[1] checkpositivedefinite
@ /path/to/julia/stdlib/v1.9/LinearAlgebra/src/factorization.jl:18 [inlined]
[2] #cholesky!#152
@ /path/to/julia/stdlib/v1.9/LinearAlgebra/src/cholesky.jl:268 [inlined]
[3] cholesky! (repeats 2 times)
@ /path/to/julia/stdlib/v1.9/LinearAlgebra/src/cholesky.jl:266 [inlined]
[4] #cholesky#162
@ /path/to/julia/stdlib/v1.9/LinearAlgebra/src/cholesky.jl:400 [inlined]
[5] cholesky (repeats 2 times)
@ /path/to/julia/stdlib/v1.9/LinearAlgebra/src/cholesky.jl:400 [inlined]
[6] cholesky_lower(X::UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}})
@ Bijectors /path/to/Bijectors/src/utils.jl:31
[7] transform(b::Bijectors.VecCholeskyBijector, X::UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}})
@ Bijectors /path/to/Bijectors/src/bijectors/corr.jl:220
[8] Transform
@ /path/to/Bijectors/src/interface.jl:80 [inlined]
[9] link(d::LKJCholesky{Float64}, x::UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}})
@ Bijectors /path/to/Bijectors/src/Bijectors.jl:128
[10] (::var"#9#10")(x::UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}})
@ Main ./REPL[3]:1
[11] vector_mode_dual_eval!
@ /path/to/ForwardDiff/src/apiutils.jl:24 [inlined]
[12] vector_mode_jacobian(f::var"#9#10", x::UpperTriangular{Float64, Matrix{Float64}}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9, UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}}})
@ ForwardDiff /path/to/ForwardDiff/src/jacobian.jl:125
[13] jacobian(f::Function, x::UpperTriangular{Float64, Matrix{Float64}}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9, UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}}}, ::Val{true})
@ ForwardDiff /path/to/ForwardDiff/src/jacobian.jl:21
[14] jacobian(f::Function, x::UpperTriangular{Float64, Matrix{Float64}}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9, UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}}})
@ ForwardDiff /path/to/ForwardDiff/src/jacobian.jl:19
[15] jacobian(f::Function, x::UpperTriangular{Float64, Matrix{Float64}})
@ ForwardDiff /path/to/ForwardDiff/src/jacobian.jl:19
[16] top-level scope
@ REPL[3]:1
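Frames [6] and [7] of the trace suggest that `cholesky_lower` was handed the triangular factor `x.U` itself. `cholesky` in LinearAlgebra expects a symmetric (Hermitian) positive-definite matrix and throws `PosDefException` for non-Hermitian input. A minimal stdlib-only sketch of that behaviour, using `x.U` from the session above (the transpose of the printed L factor); it does not reproduce the exact Bijectors/ForwardDiff code path:

```julia
using LinearAlgebra

# x.U from the session above: a valid upper-triangular Cholesky factor.
U = UpperTriangular([1.0 0.198156 -0.797754;
                     0.0 0.980171  0.257329;
                     0.0 0.0       0.545316])

# `cholesky` expects a symmetric positive-definite matrix, not a factor.
# The factor is not symmetric, so LinearAlgebra throws PosDefException.
threw = try
    cholesky(Matrix(U))
    false
catch err
    err isa PosDefException
end

# Rebuilding the correlation matrix first and factorizing that works,
# and recovers U (up to rounding).
Σ = Symmetric(Matrix(U)' * Matrix(U))
C = cholesky(Σ)
```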
But we only call
The issue is two-fold. First, we're giving the transformation in
julia> using Bijectors
julia> dist = LKJCholesky(3, 1)
LKJCholesky{Float64}(
d: 3
η: 1.0
uplo: L
)
julia> x = rand(dist)
LinearAlgebra.Cholesky{Float64, Matrix{Float64}}
L factor:
3×3 LinearAlgebra.LowerTriangular{Float64, Matrix{Float64}}:
1.0 ⋅ ⋅
-0.107859 0.994166 ⋅
-0.38706 -0.246285 0.888554
julia> link(dist, x)
3-element Vector{Float64}:
-0.10828016710654258
-0.40833723913774234
-0.2737439074161078
julia> link(dist, x.L)
3-element Vector{Float64}:
0.0
0.0
0.0
This was caused by this line: line 31 in 5bffcfb.
Here we ended up calling
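One reading of the all-zero output, consistent with the later commit "use :L for Hermitian in `cholesky_lower`", is that `Hermitian(X)` defaults to `uplo = :U` and therefore only reads the upper triangle of its argument. For a lower-triangular input that triangle is just the diagonal. A stdlib-only sketch using the (rounded) L factor printed above; it illustrates the `Hermitian` behaviour, not the actual line in utils.jl:

```julia
using LinearAlgebra

# Lower-triangular factor, entries copied (rounded) from the x.L output above.
A = [1.0       0.0       0.0;
     -0.107859 0.994166  0.0;
     -0.38706  -0.246285 0.888554]

# Default uplo = :U: only the upper triangle is read, which for this input is
# just the diagonal, so every off-diagonal entry is silently dropped.
Hu = Hermitian(A)

# uplo = :L reads the lower triangle instead and keeps the off-diagonals.
Hl = Hermitian(A, :L)

# The Cholesky factor of the :U view is diagonal, so anything that reads its
# strictly-lower entries sees only zeros.
Fu = cholesky(Hu).L
```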
Alrighty; tests are at least passing locally for me now. Let's see if Julia 1.6 also works.
Looks like we're ready to go :)
* added cholesky_lower and cholesky_triangular
* updated PD to use new cholesky_lower and cholesky_upper
* simplified imports in BijectorsReverseDiffExt
* added ChainRules as a dep since we need the chain rules for cholesky, etc.
* forgot to update Project.toml in previous commit
* added explicit implementation of with_logabsdet_jacobian for PDBijector
* Update src/utils.jl
* added ProjectTo in rrules for cholesky_lower and cholesky_upper to be proper
* added ProjectTo for cholesky_upper too
* added transpose_eager as an alias for permutedims to allow definition of AD rules without type piracy
* allow usage of ForwardDiff gradient as ground-truth
* added AD tests for PDVecBijector
* added AD tests for PDVecBijector to runtests and commented out all other tests for the sake of reproducing ReverseDiff bug
* forgot to remove type-piracy def of ReverseDiff rule for permutedims
* use ReverseDiff.@grad instead of ReverseDiff.@grad_from_chainrules
* only define cholesky_lower and cholesky_upper rules for ReverseDiff, remove ChainRules rule defs
* formatting
* parameterise gradient test for PD bijector properly instead of using ForwardDiff as per suggestion of @devmotion
* reversed change to test_ad
* reactivate tests
* updated docstrings
* improved PDVecBijector AD tests a bit
* AD fix for CorrBijector (#281)
  * removed redundant imports to BijectorsZygoteExt
  * use cholesky_upper and cholesky_lower instead of cholesky_factor, etc.
  * added tests for CorrVecBijector
  * name testset correctly
  * use cholesky_lower and cholesky_upper instead of cholesky_factor
  * removed now-redundant cholesky_factor
  * Fix obsolete function references in tests. (#282)
    * Update chainrules.jl
    * Update corr.jl
    * Revert changes to transform.
  * removed type-piracy that has been addressed upstream and bumped Zygote version in test
  * use :L for Hermitian in `cholesky_lower`
  * fixed ForwardDiff tests for LKJCholesky
  * fixed tests for matrix dists and added tests for both values of uplo in LKJCholesky tests
  * another attempt at fixing Julia 1.6 tests
  ---------
  Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
---------
Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
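The `cholesky_lower`/`cholesky_upper` pair named in the commit log can be sketched as follows, assuming each one wraps `cholesky(Hermitian(X, uplo))` with the matching triangle. These are hypothetical definitions, not the ones in Bijectors. The last line shows why the `uplo` choice matters: each helper never reads the other triangle.

```julia
using LinearAlgebra

# Hypothetical sketch of the helper pair named in the commit log; not the
# actual Bijectors definitions. Each reads only one triangle of X.
cholesky_lower(X::AbstractMatrix) = cholesky(Hermitian(X, :L)).L
cholesky_upper(X::AbstractMatrix) = cholesky(Hermitian(X, :U)).U

# A symmetric positive-definite test matrix.
Σ = [4.0 2.0 0.6; 2.0 2.0 0.4; 0.6 0.4 1.0]

Lf = cholesky_lower(Σ)
Uf = cholesky_upper(Σ)

# Garbage in the strictly-upper triangle is ignored by cholesky_lower,
# because Hermitian(X, :L) never reads it.
Xjunk = tril(Σ) + triu(fill(99.0, 3, 3), 1)
```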
Thanks @sunxd3 and @torfjelde!
This is a sibling PR of #280. It makes use of the functionality introduced there to ensure that CorrBijector and its siblings also work as intended.