diff --git a/src/Zygote.jl b/src/Zygote.jl index 07316ba9c..72fe6faca 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -1,7 +1,7 @@ module Zygote using LinearAlgebra, Statistics -using LinearAlgebra: copytri! +using LinearAlgebra: copytri!, AbstractTriangular # This flag enables Zygote to grab extra type inference information during # compiles. When control flow is present, this can give gradient code a diff --git a/src/lib/array.jl b/src/lib/array.jl index 947ec4f44..0d8e7e901 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -171,8 +171,18 @@ _backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(xs,i),*,dims) end end -@adjoint transpose(x) = transpose(x), Δ -> (transpose(Δ),) -@adjoint Base.adjoint(x) = x', Δ -> (Δ',) +@adjoint function transpose(x) + back(Δ) = (transpose(Δ),) # generic cotangent: transpose it back + back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,) # cotangent with a single parent field: return that field + return transpose(x), back +end + +@adjoint function Base.adjoint(x) + back(Δ) = (Δ',) # generic cotangent: take its adjoint + back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,) # cotangent with a single parent field: return that field + return x', back +end + @adjoint parent(x::LinearAlgebra.Adjoint) = parent(x), ȳ -> (LinearAlgebra.Adjoint(ȳ),) @adjoint dot(x::AbstractArray, y::AbstractArray) = dot(x, y), Δ->(Δ .* y, Δ .* x) @@ -204,18 +214,55 @@ end @adjoint logabsdet(xs) = logabsdet(xs), Δ -> (Δ[1] * transpose(inv(xs)),) @adjoint function inv(A) - return inv(A), function (Δ) - Ainv = inv(A) - ∇A = - Ainv' * Δ * Ainv' - return (∇A, ) - end + Ainv = inv(A) # computed once: returned as the primal value and reused by the pullback + return Ainv, function (Δ) + ∇A = - Ainv' * Δ * Ainv' + return (∇A, ) + end end -@adjoint function \(A::AbstractMatrix, B::AbstractVecOrMat) +# Defaults for atol and rtol copied directly from LinearAlgebra. See the following for +# derivation: +# Golub, Gene H., and Victor Pereyra. "The differentiation of pseudo-inverses and nonlinear +# least squares problems whose variables separate." SIAM Journal on numerical analysis 10.2 +# (1973): 413-432. 
+@adjoint function pinv( + A::AbstractMatrix{T}; + atol::Real = 0.0, + rtol::Real = (eps(real(float(one(T))))*min(size(A)...))*iszero(atol), +) where {T} + Y = pinv(A; atol=atol, rtol=rtol) # forward the caller's tolerances instead of silently using the defaults + return Y, Δ->(-Y' * Δ * Y' + (I - A * Y) * Δ' * Y * Y' + Y' * Y * Δ' * (I - Y * A),) +end + +@adjoint function \(A::Union{Diagonal, AbstractTriangular}, B::AbstractVecOrMat) Y = A \ B return Y, function(Ȳ) - B̄ = A' \ Ȳ - return (-B̄ * Y', B̄) + B̄ = A' \ Ȳ + return (-B̄ * Y', B̄) + end +end + +@adjoint function /(A::AbstractMatrix, B::Union{Diagonal, AbstractTriangular}) + Y = A / B + return Y, function(Ȳ) + Ā = Ȳ / B' + return (Ā, -Y' * Ā) + end +end + +@adjoint function \(A::AbstractMatrix, B::AbstractVecOrMat) + Z = A \ B + return Z, function(Z̄) + B̄ = A' \ Z̄ + if size(A, 1) == size(A, 2) + return (-B̄ * Z', B̄) + else # non-square A: two extra terms (b and c) are added to the cotangent of A + a = -B̄ * Z' + b = (B - A * Z) * B̄' / A' + c = A' \ Z * (Z̄' - B̄' * A) + return (a + b + c, B̄) + end end end @@ -227,6 +274,9 @@ end # LinAlg Matrix Types # =================== +@adjoint LinearAlgebra.LowerTriangular(A) = LowerTriangular(A), Δ->(LowerTriangular(Δ),) +@adjoint LinearAlgebra.UpperTriangular(A) = UpperTriangular(A), Δ->(UpperTriangular(Δ),) + # This is basically a hack while we don't have a working `ldiv!`. 
@adjoint function \(A::Cholesky, B::AbstractVecOrMat) Y, back = Zygote.forward((U, B)->U \ (U' \ B), A.U, B) @@ -236,14 +286,6 @@ end end end -@adjoint function /(A::AbstractMatrix, B::AbstractMatrix) - Y = A / B - return Y, function(Ȳ) - Ā = Ȳ / B' - return (Ā, -Y' * Ā) - end -end - _symmetric_back(Δ) = UpperTriangular(Δ) + LowerTriangular(Δ)' - Diagonal(Δ) _symmetric_back(Δ::Union{Diagonal, UpperTriangular}) = Δ @adjoint function Symmetric(A::AbstractMatrix) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 5060cb7a2..cc4fd9a61 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -139,21 +139,79 @@ end @test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4)) end +@testset "(p)inv" begin + rng, P, Q = MersenneTwister(123456), 13, 11 + A, B, C = randn(rng, P, Q), randn(rng, P, P), randn(rng, Q, P) # tall, square and wide inputs, all drawn from the seeded rng + @test gradtest(pinv, A) + @test gradtest(inv, B) + @test gradtest(pinv, C) +end + @testset "backsolve" begin - rng, P, Q = MersenneTwister(123456), 10, 9 + rng, M, P, Q = MersenneTwister(123456), 13, 10, 9 X, Y, y = randn(rng, P, P), randn(rng, P, Q), randn(rng, P) - - # \ - @test gradtest(X -> X \ Y, X) - @test gradtest(Y -> X \ Y, Y) - @test gradtest(X -> X \ y, X) - @test gradtest(y -> X \ y, y) + A, B = randn(rng, P, M), randn(rng, P, Q) # wide and tall rectangular systems, both drawn from the seeded rng + D = collect(Diagonal(randn(rng, P))) + L = collect(LowerTriangular(randn(rng, P, P))) + L[diagind(L)] .= 1 .+ 0.01 .* randn(rng, P) + U = collect(UpperTriangular(randn(rng, P, P))) + U[diagind(U)] .= 1 .+ 0.01 .* randn(rng, P) + + # \ (Dense square) + @test gradtest(\, X, Y) + @test gradtest(\, X, y) + + # \ (Dense rectangular) + @test gradtest(\, A, Y) + @test gradtest(\, A, y) + @test gradtest(\, B, Y) + @test gradtest(\, B, y) + + # \ (Diagonal) + @test gradtest(\, D, Y) + @test gradtest(\, D, y) + @test gradtest((D, Y)-> Diagonal(D) \ Y, D, Y) + @test gradtest((D, Y)-> Diagonal(D) \ Y, D, y) + + # \ (LowerTriangular) + @test gradtest(\, L, Y) + @test gradtest(\, L, y) + @test gradtest((L, Y) -> LowerTriangular(L) \ Y, L, 
Y) + @test gradtest((L, Y) -> LowerTriangular(L) \ Y, L, y) + + # \ (UpperTriangular) + @test gradtest(\, U, Y) + @test gradtest(\, U, y) + @test gradtest((U, Y) -> UpperTriangular(U) \ Y, U, Y) + @test gradtest((U, Y) -> UpperTriangular(U) \ Y, U, y) # / - @test gradtest(X -> Y' / X, X) - @test gradtest(Y -> Y' / X, Y) - @test gradtest(X -> y' / X, X) - @test gradtest(y -> y' / X, y) + @test gradtest(/, Y', X) + @test gradtest((y, X)->y' / X, y, X) + + # / (rectangular) + @test gradtest(/, Y', A') + @test gradtest((y, A)->y' / A', y, A) + @test gradtest(/, Y', B') + @test gradtest((y, A)->y' / A', y, B) + + # / (Diagonal) + @test gradtest((D, Y) -> Y' / D, D, Y) + @test gradtest((D, Y) -> Y' / D, D, y) + @test gradtest((D, Y)-> Y' / Diagonal(D), D, Y) + @test gradtest((D, Y)-> Y' / Diagonal(D), D, y) + + # / (LowerTriangular) + @test gradtest((L, Y) -> Y' / L, L, Y) + @test gradtest((L, Y) -> Y' / L, L, y) + @test gradtest((L, Y) -> Y' / LowerTriangular(L), L, Y) + @test gradtest((L, Y) -> Y' / LowerTriangular(L), L, y) + + # / (UpperTriangular) + @test gradtest((U, Y) -> Y' / U, U, Y) + @test gradtest((U, Y) -> Y' / U, U, y) + @test gradtest((U, Y) -> Y' / UpperTriangular(U), U, Y) + @test gradtest((U, Y) -> Y' / UpperTriangular(U), U, y) @testset "Cholesky" begin