diff --git a/Project.toml b/Project.toml index 18018dd..690eb4a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiableFactorizations" uuid = "f7876f94-e99c-4755-b0c6-59dc4ff4934d" authors = ["Mohamed Tarek and contributors"] -version = "0.2.0" +version = "0.2.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -12,7 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [compat] ChainRulesCore = "1" ComponentArrays = "0.11, 0.12, 0.13" -ImplicitDifferentiation = "0.4" +ImplicitDifferentiation = "0.5" julia = "1" [extras] diff --git a/src/DifferentiableFactorizations.jl b/src/DifferentiableFactorizations.jl index 0d1263f..2948533 100644 --- a/src/DifferentiableFactorizations.jl +++ b/src/DifferentiableFactorizations.jl @@ -6,61 +6,62 @@ using LinearAlgebra, ImplicitDifferentiation, ComponentArrays, ChainRulesCore # QR -function qr_conditions(A, x, _) - (; Q, R) = x - return vcat( - vec(UpperTriangular(Q' * Q) + LowerTriangular(R) - I - Diagonal(R)), - vec(Q * R - A), - ) +function qr_conditions(A, x) + (; Q, R) = x + return vcat( + vec(UpperTriangular(Q' * Q) + LowerTriangular(R) - I - Diagonal(R)), + vec(Q * R - A), + ) end function qr_forward(A) - qr_res = qr(A) - Q = copy(qr_res.Q[:, 1:size(A, 2)]) - (; R) = qr_res - return ComponentVector(; Q, R), 0 + qr_res = qr(A) + Q = copy(qr_res.Q[:, 1:size(A, 2)]) + (; R) = qr_res + return ComponentVector(; Q, R) end -const _diff_qr = ImplicitFunction(qr_forward, qr_conditions) +const _diff_qr = ImplicitFunction(qr_forward, qr_conditions, DirectLinearSolver(), nothing) function diff_qr(A) - (; Q, R) = _diff_qr(A)[1] - return (; Q, R) + (; Q, R) = _diff_qr(A) + return (; Q, R) end # Cholesky -function cholesky_conditions(A, U, _) - return vec(UpperTriangular(U' * U) + LowerTriangular(U) - UpperTriangular(A) - Diagonal(U)) +function cholesky_conditions(A, U) + return vec( + UpperTriangular(U' * U) + LowerTriangular(U) - UpperTriangular(A) - Diagonal(U), + ) end function cholesky_forward(A) - ch_res = cholesky(A) - return ch_res.U, 0 + ch_res = cholesky(A) + return ch_res.U end -const _diff_cholesky = ImplicitFunction(cholesky_forward, cholesky_conditions) +const _diff_cholesky = + ImplicitFunction(cholesky_forward, cholesky_conditions, DirectLinearSolver(), nothing) function diff_cholesky(A) - U = _diff_cholesky(A)[1] - return (; L = U', U) + U = _diff_cholesky(A) + return (; L = U', U) end # LU -function lu_conditions(A, LU, _) - (; L, U, p) = LU - pint = convert(Vector{Int}, p) - return vcat( - vec(UpperTriangular(L) + LowerTriangular(U) - Diagonal(U) - I), - vec(L * U - A[pint, :]), - p, - ) +function lu_conditions(A, LU, p) + (; L, U) = LU + return vcat( + vec(UpperTriangular(L) + LowerTriangular(U) - Diagonal(U) - I), + vec(L * U - A[p, :]), + ) end function lu_forward(A) - lu_res = lu(A) - (; L, U, p) = lu_res - return ComponentVector(; L, U, p), 0 + lu_res = lu(A) + (; L, U, p) = lu_res + return ComponentVector(; L, U), p end -const _diff_lu = ImplicitFunction(lu_forward, lu_conditions) +const _diff_lu = ImplicitFunction(lu_forward, lu_conditions, DirectLinearSolver(), nothing) function diff_lu(A) - temp = _diff_lu(A)[1] - (; L, U, p) = temp - return (; L, U, p = convert(Vector{Int}, p)) + temp, p = _diff_lu(A) + (; L, U) = temp + return (; L, U, p) end # Eigen @@ -68,129 +69,130 @@ end comp_vec(A) = ComponentVector((; A)) comp_vec(A, B) = ComponentVector((; A, B)) function ChainRulesCore.rrule(::typeof(comp_vec), A) - out = comp_vec(A) - T = typeof(out) - return out, Δ -> begin - _Δ = convert(T, Δ) - (NoTangent(), _Δ.A) - end + out = comp_vec(A) + T = typeof(out) + return out, Δ -> begin + _Δ = convert(T, Δ) + (NoTangent(), _Δ.A) + end end function ChainRulesCore.rrule(::typeof(comp_vec), A, B) - out = comp_vec(A, B) - T = typeof(out) - return out, Δ -> begin - _Δ = convert(T, Δ) - (NoTangent(), _Δ.A, _Δ.B) - end -end - -function eigen_conditions(AB, sV, _) - (; s, V) = sV - (; A) = AB - if hasproperty(AB, :B) - (; B) = AB - else - B = I - end - return vcat( - vec(A * V - B * V * Diagonal(s)), - diag(V' * B * V) .- 1, - ) + out = comp_vec(A, B) + T = typeof(out) + return out, Δ -> begin + _Δ = convert(T, Δ) + (NoTangent(), _Δ.A, _Δ.B) + end +end + +function eigen_conditions(AB, sV) + (; s, V) = sV + (; A) = AB + if hasproperty(AB, :B) + (; B) = AB + else + B = I + end + return vcat(vec(A * V - B * V * Diagonal(s)), diag(V' * B * V) .- 1) end function eigen_forward(AB) - (; A) = AB - if hasproperty(AB, :B) - (; B) = AB - eig_res = eigen(A, B) - else - eig_res = eigen(A) - end - s = eig_res.values - V = eig_res.vectors - return ComponentVector(; s, V), 0 -end - -const _diff_eigen = ImplicitFunction(eigen_forward, eigen_conditions) + (; A) = AB + if hasproperty(AB, :B) + (; B) = AB + eig_res = eigen(A, B) + else + eig_res = eigen(A) + end + s = eig_res.values + V = eig_res.vectors + return ComponentVector(; s, V) +end + +const _diff_eigen = + ImplicitFunction(eigen_forward, eigen_conditions, DirectLinearSolver(), nothing) function diff_eigen(A) - (; s, V) = _diff_eigen(comp_vec(A))[1] - return (; s , V) + (; s, V) = _diff_eigen(comp_vec(A)) + return (; s, V) end function diff_eigen(A, B) - (; s, V) = _diff_eigen(comp_vec(A, B))[1] - return (; s , V) + (; s, V) = _diff_eigen(comp_vec(A, B)) + return (; s, V) end -function schur_conditions(A, Z_T, _) - (; Z, T) = Z_T - return vcat( - vec(Z' * A * Z - T), - vec(Z' * Z - I + LowerTriangular(T) - Diagonal(T)), - ) +function schur_conditions(A, Z_T) + (; Z, T) = Z_T + return vcat(vec(Z' * A * Z - T), vec(Z' * Z - I + LowerTriangular(T) - Diagonal(T))) end function schur_forward(A) - schur_res = schur(A) - (; Z, T) = schur_res - return ComponentVector(; Z, T), 0 + schur_res = schur(A) + (; Z, T) = schur_res + return ComponentVector(; Z, T) end -const _diff_schur = ImplicitFunction(schur_forward, schur_conditions) +const _diff_schur = + ImplicitFunction(schur_forward, schur_conditions, DirectLinearSolver(), nothing) function bidiag(v1, v2) - return Bidiagonal(v1, v2, :L) + return Bidiagonal(v1, v2, :L) end function ChainRulesCore.rrule(::typeof(bidiag), v1, v2) - bidiag(v1, v2), Δ -> begin - NoTangent(), diag(Δ), diag(Δ, -1) - end -end - -function gen_schur_conditions(AB, left_right_S_T, _) - (; left, right, S, T) = left_right_S_T - (; A, B) = AB - return vcat( - vec(left * S * right' - A), - vec(left * T * right' - B), - vec(UpperTriangular(left' * left) - I + LowerTriangular(S) - bidiag(diag(S), diag(S, -1) .+ (diag(S, -1) .* diag(T, 1)))), - vec(UpperTriangular(right' * right) - I + LowerTriangular(T) - Diagonal(T)), - ) + bidiag(v1, v2), Δ -> begin + NoTangent(), diag(Δ), diag(Δ, -1) + end +end + +function gen_schur_conditions(AB, left_right_S_T) + (; left, right, S, T) = left_right_S_T + (; A, B) = AB + return vcat( + vec(left * S * right' - A), + vec(left * T * right' - B), + vec( + UpperTriangular(left' * left) - I + LowerTriangular(S) - + bidiag(diag(S), diag(S, -1) .+ (diag(S, -1) .* diag(T, 1))), + ), + vec(UpperTriangular(right' * right) - I + LowerTriangular(T) - Diagonal(T)), + ) end function gen_schur_forward(AB) - (; A, B) = AB - schur_res = schur(A, B) - (; left, right, S, T) = schur_res - return ComponentVector(; left, right, S, T), 0 + (; A, B) = AB + schur_res = schur(A, B) + (; left, right, S, T) = schur_res + return ComponentVector(; left, right, S, T) end -const _diff_gen_schur = ImplicitFunction(gen_schur_forward, gen_schur_conditions) +const _diff_gen_schur = + ImplicitFunction(gen_schur_forward, gen_schur_conditions, DirectLinearSolver(), nothing) function diff_schur(A, B) - (; left, right, S, T) = _diff_gen_schur(comp_vec(A, B))[1] - return (; left, right, S, T) + (; left, right, S, T) = _diff_gen_schur(comp_vec(A, B)) + return (; left, right, S, T) end function diff_schur(A) - (; Z, T) = _diff_schur(A)[1] - return (; Z, T) + (; Z, T) = _diff_schur(A) + return (; Z, T) end # SVD -function svd_conditions(A, USV, _) - (; U, S, V) = USV - VtV = V' * V - return vcat( - vec(U * Diagonal(S) * V' - A), - vec(UpperTriangular(VtV) + LowerTriangular(U' * U) - 2I), - diag(VtV) .- 1, - ) +function svd_conditions(A, USV) + (; U, S, V) = USV + VtV = V' * V + return vcat( + vec(U * Diagonal(S) * V' - A), + vec(UpperTriangular(VtV) + LowerTriangular(U' * U) - 2I), + diag(VtV) .- 1, + ) end function svd_forward(A) - svd_res = svd(A) - (; U, S, V) = svd_res - return ComponentVector(; U, S, V), 0 + svd_res = svd(A) + (; U, S, V) = svd_res + return ComponentVector(; U, S, V) end -const _diff_svd = ImplicitFunction(svd_forward, svd_conditions) +const _diff_svd = + ImplicitFunction(svd_forward, svd_conditions, DirectLinearSolver(), nothing) function diff_svd(A) - (; U, S, V) = _diff_svd(A)[1] - return (; U, S , V) + (; U, S, V) = _diff_svd(A) + return (; U, S, V) end end diff --git a/test/runtests.jl b/test/runtests.jl index d69ba92..78b79aa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,11 +1,12 @@ -using DifferentiableFactorizations, Test, Zygote, FiniteDifferences, LinearAlgebra, ComponentArrays, Random +using DifferentiableFactorizations, + Test, Zygote, FiniteDifferences, LinearAlgebra, ComponentArrays, Random Random.seed!(1) const nreps = 3 const tol = 1e-8 @testset "Cholesky" begin - for _ in 1:nreps + for _ = 1:nreps A = rand(3, 3) f1(A) = diff_cholesky(A' * A + 2I).U @@ -21,7 +22,7 @@ const tol = 1e-8 end @testset "LU" begin - for _ in 1:nreps + for _ = 1:nreps A = rand(3, 3) f1(A) = vec(diff_lu(A).U) @@ -37,9 +38,9 @@ end end @testset "QR" begin - for _ in 1:nreps + for _ = 1:nreps A = rand(3, 2) - + f1(A) = vec(diff_qr(A).Q) zjac1 = Zygote.jacobian(f1, A)[1] fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, A)[1] @@ -53,7 +54,7 @@ end end @testset "Eigen" begin - for _ in 1:nreps + for _ = 1:nreps A = rand(3, 3) B = rand(3, 3) AB = ComponentVector(; A, B) @@ -91,7 +92,7 @@ end end @testset "SVD" begin - for _ in 1:nreps + for _ = 1:nreps A = rand(3, 3) f1(A) = diff_svd(A).S @@ -112,7 +113,7 @@ end end @testset "Schur" begin - for _ in 1:nreps + for _ = 1:nreps A = randn(3, 3) A = A' + A + I f1(A) = vec(diff_schur(A).Z) @@ -128,7 +129,7 @@ end end @testset "Generalized Schur" begin - for _ in 1:nreps + for _ = 1:nreps A = randn(3, 3) A = A' + A + I B = rand(3, 3)