Skip to content

Commit

Permalink
Move Symmetric/Hermitian rules and tests to own file (#322)
Browse files Browse the repository at this point in the history
* Move symmetric rules to own file

* Move symmetric tests to own file

* Increment version number
  • Loading branch information
sethaxen authored Dec 6, 2020
1 parent fa4b93a commit eb10848
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 123 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.35"
version = "0.7.36"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ include("rulesets/LinearAlgebra/blas.jl")
include("rulesets/LinearAlgebra/dense.jl")
include("rulesets/LinearAlgebra/norm.jl")
include("rulesets/LinearAlgebra/structured.jl")
include("rulesets/LinearAlgebra/symmetric.jl")
include("rulesets/LinearAlgebra/factorization.jl")

include("rulesets/Random/random.jl")
Expand Down
80 changes: 0 additions & 80 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,86 +86,6 @@ function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real})
return D * V, times_pullback
end

#####
##### `Symmetric`/`Hermitian`
#####

function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
return T(A, uplo), T(ΔA, uplo)
end

function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
Ω = T(A, uplo)
function HermOrSym_pullback(ΔΩ)
return (NO_FIELDS, _symherm_back(T, ΔΩ, Ω.uplo), DoesNotExist())
end
return Ω, HermOrSym_pullback
end

function frule((_, ΔA), TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
return TM(A), TM(_symherm_forward(A, ΔA))
end
function frule((_, ΔA), ::Type{Array}, A::LinearAlgebra.HermOrSym)
return Array(A), Array(_symherm_forward(A, ΔA))
end

function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
function Matrix_pullback(ΔΩ)
TA = _symhermtype(A)
T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)}
uplo = A.uplo
∂A = T∂A(_symherm_back(A, ΔΩ, uplo), uplo)
return NO_FIELDS, ∂A
end
return TM(A), Matrix_pullback
end
rrule(::Type{Array}, A::LinearAlgebra.HermOrSym) = rrule(Matrix, A)

# Get type (Symmetric or Hermitian) from type or matrix
_symhermtype(::Type{<:Symmetric}) = Symmetric
_symhermtype(::Type{<:Hermitian}) = Hermitian
_symhermtype(A) = _symhermtype(typeof(A))

# for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω
function _symherm_forward(A, ΔA)
TA = _symhermtype(A)
return if ΔA isa TA
ΔA
else
TA{eltype(ΔA),typeof(ΔA)}(ΔA, A.uplo)
end
end

# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A
_symherm_back(::Type{<:Symmetric}, ΔΩ, uplo) = _symmetric_back(ΔΩ, uplo)
function _symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo)
return _symmetric_back(ΔΩ, uplo)
end
_symherm_back(::Type{<:Hermitian}, ΔΩ, uplo) = _hermitian_back(ΔΩ, uplo)
_symherm_back(Ω, ΔΩ, uplo) = _symherm_back(typeof(Ω), ΔΩ, uplo)

function _symmetric_back(ΔΩ, uplo)
L, U, D = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), Diagonal(ΔΩ)
return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D
end
_symmetric_back(ΔΩ::Diagonal, uplo) = ΔΩ
_symmetric_back(ΔΩ::UpperTriangular, uplo) = Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ))
_symmetric_back(ΔΩ::LowerTriangular, uplo) = Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ)

function _hermitian_back(ΔΩ, uplo)
L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ))
return uplo == 'U' ? U .+ L' - rD : L .+ U' - rD
end
_hermitian_back(ΔΩ::Diagonal, uplo) = real.(ΔΩ)
function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo)
∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ)))
return if istriu(ΔΩ)
return Matrix(uplo == 'U' ? ∂UL : ∂UL')
else
return Matrix(uplo == 'U' ? ∂UL' : ∂UL)
end
end

#####
##### `Adjoint`
#####
Expand Down
79 changes: 79 additions & 0 deletions src/rulesets/LinearAlgebra/symmetric.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#####
##### `Symmetric`/`Hermitian`
#####

function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
return T(A, uplo), T(ΔA, uplo)
end

function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
Ω = T(A, uplo)
function HermOrSym_pullback(ΔΩ)
return (NO_FIELDS, _symherm_back(T, ΔΩ, Ω.uplo), DoesNotExist())
end
return Ω, HermOrSym_pullback
end

function frule((_, ΔA), TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
return TM(A), TM(_symherm_forward(A, ΔA))
end
function frule((_, ΔA), ::Type{Array}, A::LinearAlgebra.HermOrSym)
return Array(A), Array(_symherm_forward(A, ΔA))
end

function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
function Matrix_pullback(ΔΩ)
TA = _symhermtype(A)
T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)}
uplo = A.uplo
∂A = T∂A(_symherm_back(A, ΔΩ, uplo), uplo)
return NO_FIELDS, ∂A
end
return TM(A), Matrix_pullback
end
rrule(::Type{Array}, A::LinearAlgebra.HermOrSym) = rrule(Matrix, A)

# Get type (Symmetric or Hermitian) from type or matrix
_symhermtype(::Type{<:Symmetric}) = Symmetric
_symhermtype(::Type{<:Hermitian}) = Hermitian
_symhermtype(A) = _symhermtype(typeof(A))

# for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω
function _symherm_forward(A, ΔA)
TA = _symhermtype(A)
return if ΔA isa TA
ΔA
else
TA{eltype(ΔA),typeof(ΔA)}(ΔA, A.uplo)
end
end

# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A
_symherm_back(::Type{<:Symmetric}, ΔΩ, uplo) = _symmetric_back(ΔΩ, uplo)
function _symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo)
return _symmetric_back(ΔΩ, uplo)
end
_symherm_back(::Type{<:Hermitian}, ΔΩ, uplo) = _hermitian_back(ΔΩ, uplo)
_symherm_back(Ω, ΔΩ, uplo) = _symherm_back(typeof(Ω), ΔΩ, uplo)

function _symmetric_back(ΔΩ, uplo)
L, U, D = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), Diagonal(ΔΩ)
return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D
end
_symmetric_back(ΔΩ::Diagonal, uplo) = ΔΩ
_symmetric_back(ΔΩ::UpperTriangular, uplo) = Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ))
_symmetric_back(ΔΩ::LowerTriangular, uplo) = Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ)

function _hermitian_back(ΔΩ, uplo)
L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ))
return uplo == 'U' ? U .+ L' - rD : L .+ U' - rD
end
_hermitian_back(ΔΩ::Diagonal, uplo) = real.(ΔΩ)
function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo)
∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ)))
return if istriu(ΔΩ)
return Matrix(uplo == 'U' ? ∂UL : ∂UL')
else
return Matrix(uplo == 'U' ? ∂UL' : ∂UL)
end
end
42 changes: 0 additions & 42 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,48 +104,6 @@
end
end
end
@testset "$(SymHerm)(::AbstractMatrix{$T}, :$(uplo))" for
SymHerm in (Symmetric, Hermitian),
T in (Float64, ComplexF64),
uplo in (:U, :L)

N = 3
@testset "frule" begin
x = randn(T, N, N)
Δx = randn(T, N, N)
# can't use frule_test here because it doesn't yet ignore nothing tangents
Ω = SymHerm(x, uplo)
Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo)
@test Ω_ad == Ω
∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx))
@test ∂Ω_ad ∂Ω_fd
end
@testset "rrule" begin
x = randn(T, N, N)
∂x = randn(T, N, N)
ΔΩ = randn(T, N, N)
@testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
rrule_test(SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing))
end
@testset "back(::Diagonal)" begin
rrule_test(SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing))
end
end
end
@testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array),
SymHerm in (Symmetric, Hermitian),
T in (Float64, ComplexF64),
uplo in (:U, :L)

N = 3
x = SymHerm(randn(T, N, N), uplo)
Δx = randn(T, N, N)
∂x = SymHerm(randn(T, N, N), uplo)
ΔΩ = f(SymHerm(randn(T, N, N), uplo))
frule_test(f, (x, Δx))
frule_test(f, (x, SymHerm(Δx, uplo)))
rrule_test(f, ΔΩ, (x, ∂x))
end
@testset "$f" for f in (Adjoint, adjoint, Transpose, transpose)
n = 5
m = 3
Expand Down
44 changes: 44 additions & 0 deletions test/rulesets/LinearAlgebra/symmetric.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
@testset "Symmetric/Hermitian rules" begin
@testset "$(SymHerm)(::AbstractMatrix{$T}, :$(uplo))" for
SymHerm in (Symmetric, Hermitian),
T in (Float64, ComplexF64),
uplo in (:U, :L)

N = 3
@testset "frule" begin
x = randn(T, N, N)
Δx = randn(T, N, N)
# can't use frule_test here because it doesn't yet ignore nothing tangents
Ω = SymHerm(x, uplo)
Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo)
@test Ω_ad == Ω
∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx))
@test ∂Ω_ad ∂Ω_fd
end
@testset "rrule" begin
x = randn(T, N, N)
∂x = randn(T, N, N)
ΔΩ = randn(T, N, N)
@testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
rrule_test(SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing))
end
@testset "back(::Diagonal)" begin
rrule_test(SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing))
end
end
end
@testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array),
SymHerm in (Symmetric, Hermitian),
T in (Float64, ComplexF64),
uplo in (:U, :L)

N = 3
x = SymHerm(randn(T, N, N), uplo)
Δx = randn(T, N, N)
∂x = SymHerm(randn(T, N, N), uplo)
ΔΩ = f(SymHerm(randn(T, N, N), uplo))
frule_test(f, (x, Δx))
frule_test(f, (x, SymHerm(Δx, uplo)))
rrule_test(f, ΔΩ, (x, ∂x))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ println("Testing ChainRules.jl")
include_test("rulesets/LinearAlgebra/dense.jl")
include_test("rulesets/LinearAlgebra/norm.jl")
include_test("rulesets/LinearAlgebra/structured.jl")
include_test("rulesets/LinearAlgebra/symmetric.jl")
include_test("rulesets/LinearAlgebra/factorization.jl")
include_test("rulesets/LinearAlgebra/blas.jl")
end
Expand Down

2 comments on commit eb10848

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/25905

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.36 -m "<description of version>" eb10848119db1efa4234f5ad3476094ca29c744d
git push origin v0.7.36

Also, note the warning: Version 0.7.36 skips over 0.7.35
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

Please sign in to comment.