From f2bb45d232e4eb1d5ecbcd0b119d91041dd3e5ad Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 4 Nov 2021 16:15:53 +0200 Subject: [PATCH 1/8] remove `@adjoint function cholesky` --- src/lib/array.jl | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 15b994564..b35883e59 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -540,35 +540,6 @@ end @adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A), Δ -> (convert(S, Δ),) -@adjoint function cholesky(Σ::Real) - C = cholesky(Σ) - return C, Δ::NamedTuple->(Δ.factors[1, 1] / (2 * C.U[1, 1]),) -end - -@adjoint function cholesky(Σ::Diagonal; check = true) - C = cholesky(Σ, check = check) - return C, Δ::NamedTuple -> begin - issuccess(C) || throw(PosDefException(C.info)) - return Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing - end -end - -# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra." -@adjoint function cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}; check = true) - C = cholesky(Σ, check = check) - return C, function(Δ::NamedTuple) - issuccess(C) || throw(PosDefException(C.info)) - U, Ū = C.U, Δ.factors - Σ̄ = similar(U.data) - Σ̄ = mul!(Σ̄, Ū, U') - Σ̄ = copytri!(Σ̄, 'U') - Σ̄ = ldiv!(U, Σ̄) - Σ̄ = BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄) - Σ̄[diagind(Σ̄)] ./= 2 - return (UpperTriangular(Σ̄),) - end -end - @adjoint function lyap(A::AbstractMatrix, C::AbstractMatrix) X = lyap(A, C) return X, function (X̄) From 131c5c82a9c653a836f1545cbac9c687ab7507f8 Mon Sep 17 00:00:00 2001 From: ST John Date: Fri, 17 Jun 2022 19:28:36 +0300 Subject: [PATCH 2/8] increase ChainRules lower bound to 1.35.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ae5c974d9..8de31f1a4 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" -ChainRules = "1.5" +ChainRules = "1.35.3" ChainRulesCore = "1.9" ChainRulesTestUtils = "1" DiffRules = "1.4" From 644a5dd874f769bca7d6e4aa314e08f36bf4bd27 Mon Sep 17 00:00:00 2001 From: ST John Date: Sat, 18 Jun 2022 10:47:23 +0300 Subject: [PATCH 3/8] bump julia compat to 1.6 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8de31f1a4..a35854934 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ NaNMath = "0.3, 1" Requires = "1.1" SpecialFunctions = "1.6, 2" ZygoteRules = "0.2.1" -julia = "1.3" +julia = "1.6" [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" From 206a6000397cc47ebf3c2b5c3c79c201e6eca310 Mon Sep 17 00:00:00 2001 From: ST John Date: Sat, 18 Jun 2022 10:58:30 +0300 Subject: [PATCH 4/8] remove failing test --- test/gradcheck.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index ac0dd28bf..0b024a51f 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -654,7 +654,6 @@ end g(X) = cholesky(X * X' + I) @test Zygote.pullback(g, X)[2]((factors=LowerTriangular(X),))[1] ≈ Zygote.pullback(g, X)[2]((factors=Matrix(LowerTriangular(X)),))[1] - @test_throws PosDefException Zygote.pullback(X -> cholesky(X, check = false), X)[2]((factors=X,)) # https://github.com/FluxML/Zygote.jl/issues/932 @test gradcheck(rand(5, 5), rand(5)) do A, x From 984a25c01adca0e9b197c235a986dcb9ed5c1396 Mon Sep 17 00:00:00 2001 From: ST John Date: Sat, 18 Jun 2022 11:08:55 +0300 Subject: [PATCH 5/8] add Hermitian cholesky test --- test/gradcheck.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 0b024a51f..0ce7f6ae4 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -819,6 +819,18 @@ end @test back′(C̄)[1] isa Diagonal @test diag(back′(C̄)[1]) ≈ diag(back(C̄)[1]) end + @testset "cholesky - Hermitian" begin + rng, N = MersenneTwister(123456), 3 + A = randn(rng, N, N) + im * randn(rng, N, N) + H = Hermitian(A * A' + I) + Hmat = Matrix(H) + y, back = Zygote.pullback(cholesky, Hmat) + y′, back′ = Zygote.pullback(cholesky, H) + C̄ = (factors=randn(rng, N, N),) + @test back′(C̄)[1] isa Hermitian + @test gradtest(B->cholesky(Hermitian(B)).U, A * A' + I) + @test gradtest(B->logdet(cholesky(Hermitian(B))), A * A' + I) + end end @testset "lyap" begin From d13be2e84358e449e45ea7e3b86e8db793fd70ea Mon Sep 17 00:00:00 2001 From: st-- Date: Sat, 18 Jun 2022 21:48:44 +0300 Subject: [PATCH 6/8] Update test/gradcheck.jl Co-authored-by: David Widmann --- test/gradcheck.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 0ce7f6ae4..46e40b87e 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -821,7 +821,7 @@ end end @testset "cholesky - Hermitian" begin rng, N = MersenneTwister(123456), 3 - A = randn(rng, N, N) + im * randn(rng, N, N) + A = randn(rng, Complex{Float64}, N, N) H = Hermitian(A * A' + I) Hmat = Matrix(H) y, back = Zygote.pullback(cholesky, Hmat) From c8df3f07d326437d35a31f0da60190388e9dbc14 Mon Sep 17 00:00:00 2001 From: ST John Date: Sat, 18 Jun 2022 21:49:26 +0300 Subject: [PATCH 7/8] bump julia minimum version in github action ci.yml --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bab7876a5..887c985c8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: version: - - '1.3' # Replace this with the minimum Julia version that your package supports. + - '1.6' # Replace this with the minimum Julia version that your package supports. - '1' # automatically expands to the latest stable 1.x release of Julia - 'nightly' os: From 815e8dd76f3618640ad9305f54ce85fc3a634d8f Mon Sep 17 00:00:00 2001 From: ST John Date: Mon, 20 Jun 2022 13:01:04 +0300 Subject: [PATCH 8/8] fix test --- test/gradcheck.jl | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 46e40b87e..182f2b666 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -819,7 +819,7 @@ end @test back′(C̄)[1] isa Diagonal @test diag(back′(C̄)[1]) ≈ diag(back(C̄)[1]) end - @testset "cholesky - Hermitian" begin + @testset "cholesky - Hermitian{Complex}" begin rng, N = MersenneTwister(123456), 3 A = randn(rng, Complex{Float64}, N, N) H = Hermitian(A * A' + I) @@ -827,9 +827,23 @@ end y, back = Zygote.pullback(cholesky, Hmat) y′, back′ = Zygote.pullback(cholesky, H) C̄ = (factors=randn(rng, N, N),) + @test only(back′(C̄)) isa Hermitian + # gradtest does not support complex gradients, even though the pullback exists + d = only(back(C̄)) + d′ = only(back′(C̄)) + @test (d + d')/2 ≈ d′ + end + @testset "cholesky - Hermitian{Real}" begin + rng, N = MersenneTwister(123456), 3 + A = randn(rng, N, N) + H = Hermitian(A * A' + I) + Hmat = Matrix(H) + y, back = Zygote.pullback(cholesky, Hmat) + y′, back′ = Zygote.pullback(cholesky, H) + C̄ = (factors=randn(rng, N, N),) @test back′(C̄)[1] isa Hermitian - @test gradtest(B->cholesky(Hermitian(B)).U, A * A' + I) - @test gradtest(B->logdet(cholesky(Hermitian(B))), A * A' + I) + @test gradtest(B->cholesky(Hermitian(B)).U, Hmat) + @test gradtest(B->logdet(cholesky(Hermitian(B))), Hmat) end end