Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reroute algebraic functions for Symmetric/Hermitian through triangular #52942

Merged
merged 8 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 45 additions & 27 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,34 @@ end
end
end

_conjugation(::Symmetric) = transpose
_conjugation(::Hermitian) = adjoint

diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo))
diag(A::Hermitian) = hermitian.(diag(parent(A)), sym_uplo(A.uplo))

isdiag(A::HermOrSym) = isdiag(A.uplo == 'U' ? UpperTriangular(A.data) : LowerTriangular(A.data))
function applytri(f, A::HermOrSym)
if A.uplo == 'U'
f(UpperTriangular(A.data))
else
f(LowerTriangular(A.data))
end
end

function applytri(f, A::HermOrSym, B::HermOrSym)
if A.uplo == B.uplo == 'U'
f(UpperTriangular(A.data), UpperTriangular(B.data))
elseif A.uplo == B.uplo == 'L'
f(LowerTriangular(A.data), LowerTriangular(B.data))
elseif A.uplo == 'U'
f(UpperTriangular(A.data), UpperTriangular(_conjugation(B)(B.data)))
else # A.uplo == 'L'
f(UpperTriangular(_conjugation(A)(A.data)), UpperTriangular(B.data))
end
end
parentof_applytri(f, args...) = applytri(parent ∘ f, args...)

isdiag(A::HermOrSym) = applytri(isdiag, A)

# For A<:Union{Symmetric,Hermitian}, similar(A[, neweltype]) should yield a matrix with the same
# symmetry type, uplo flag, and underlying storage type as A. The following methods cover these cases.
Expand Down Expand Up @@ -314,8 +338,8 @@ Hermitian{T,S}(A::Hermitian) where {T,S<:AbstractMatrix{T}} = Hermitian{T,S}(con
AbstractMatrix{T}(A::Hermitian) where {T} = Hermitian(convert(AbstractMatrix{T}, A.data), sym_uplo(A.uplo))
AbstractMatrix{T}(A::Hermitian{T}) where {T} = copy(A)

copy(A::Symmetric{T,S}) where {T,S} = (B = copy(A.data); Symmetric{T,typeof(B)}(B,A.uplo))
copy(A::Hermitian{T,S}) where {T,S} = (B = copy(A.data); Hermitian{T,typeof(B)}(B,A.uplo))
copy(A::Symmetric) = (Symmetric(parentof_applytri(copy, A), sym_uplo(A.uplo)))
copy(A::Hermitian) = (Hermitian(parentof_applytri(copy, A), sym_uplo(A.uplo)))

function copyto!(dest::Symmetric, src::Symmetric)
if src.uplo == dest.uplo
Expand Down Expand Up @@ -389,9 +413,9 @@ transpose(A::Hermitian) = Transpose(A)

real(A::Symmetric{<:Real}) = A
real(A::Hermitian{<:Real}) = A
real(A::Symmetric) = Symmetric(real(A.data), sym_uplo(A.uplo))
real(A::Hermitian) = Hermitian(real(A.data), sym_uplo(A.uplo))
imag(A::Symmetric) = Symmetric(imag(A.data), sym_uplo(A.uplo))
real(A::Symmetric) = Symmetric(parentof_applytri(real, A), sym_uplo(A.uplo))
real(A::Hermitian) = Hermitian(parentof_applytri(real, A), sym_uplo(A.uplo))
imag(A::Symmetric) = Symmetric(parentof_applytri(imag, A), sym_uplo(A.uplo))

Base.copy(A::Adjoint{<:Any,<:Symmetric}) =
Symmetric(copy(adjoint(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U))
Expand All @@ -401,8 +425,9 @@ Base.copy(A::Transpose{<:Any,<:Hermitian}) =
tr(A::Symmetric) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations)
tr(A::Hermitian) = real(tr(A.data))

Base.conj(A::HermOrSym) = typeof(A)(conj(A.data), A.uplo)
Base.conj!(A::HermOrSym) = typeof(A)(conj!(A.data), A.uplo)
Base.conj(A::Symmetric) = Symmetric(parentof_applytri(conj, A), sym_uplo(A.uplo))
Base.conj(A::Hermitian) = Hermitian(parentof_applytri(conj, A), sym_uplo(A.uplo))
Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo)

# tril/triu
function tril(A::Hermitian, k::Integer=0)
Expand Down Expand Up @@ -496,21 +521,14 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni
end
end

(-)(A::Symmetric) = Symmetric(-A.data, sym_uplo(A.uplo))
(-)(A::Hermitian) = Hermitian(-A.data, sym_uplo(A.uplo))
(-)(A::Symmetric) = Symmetric(parentof_applytri(-, A), sym_uplo(A.uplo))
(-)(A::Hermitian) = Hermitian(parentof_applytri(-, A), sym_uplo(A.uplo))

## Addition/subtraction
for f ∈ (:+, :-), (Wrapper, conjugation) ∈ ((:Hermitian, :adjoint), (:Symmetric, :transpose))
@eval begin
function $f(A::$Wrapper, B::$Wrapper)
if A.uplo == B.uplo
return $Wrapper($f(parent(A), parent(B)), sym_uplo(A.uplo))
elseif A.uplo == 'U'
return $Wrapper($f(parent(A), $conjugation(parent(B))), :U)
else
return $Wrapper($f($conjugation(parent(A)), parent(B)), :U)
end
end
for f ∈ (:+, :-), Wrapper ∈ (:Hermitian, :Symmetric)
@eval function $f(A::$Wrapper, B::$Wrapper)
uplo = A.uplo == B.uplo ? sym_uplo(A.uplo) : (:U)
$Wrapper(parentof_applytri($f, A, B), uplo)
end
end

Expand Down Expand Up @@ -555,12 +573,12 @@ function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector)
end

# Scaling with Number
*(A::Symmetric, x::Number) = Symmetric(A.data*x, sym_uplo(A.uplo))
*(x::Number, A::Symmetric) = Symmetric(x*A.data, sym_uplo(A.uplo))
*(A::Hermitian, x::Real) = Hermitian(A.data*x, sym_uplo(A.uplo))
*(x::Real, A::Hermitian) = Hermitian(x*A.data, sym_uplo(A.uplo))
/(A::Symmetric, x::Number) = Symmetric(A.data/x, sym_uplo(A.uplo))
/(A::Hermitian, x::Real) = Hermitian(A.data/x, sym_uplo(A.uplo))
*(A::Symmetric, x::Number) = Symmetric(parentof_applytri(y -> y * x, A), sym_uplo(A.uplo))
*(x::Number, A::Symmetric) = Symmetric(parentof_applytri(y -> x * y, A), sym_uplo(A.uplo))
*(A::Hermitian, x::Real) = Hermitian(parentof_applytri(y -> y * x, A), sym_uplo(A.uplo))
*(x::Real, A::Hermitian) = Hermitian(parentof_applytri(y -> x * y, A), sym_uplo(A.uplo))
/(A::Symmetric, x::Number) = Symmetric(parentof_applytri(y -> y/x, A), sym_uplo(A.uplo))
/(A::Hermitian, x::Real) = Hermitian(parentof_applytri(y -> y/x, A), sym_uplo(A.uplo))

factorize(A::HermOrSym) = _factorize(A)
function _factorize(A::HermOrSym{T}; check::Bool=true) where T
Expand Down
43 changes: 43 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,42 @@ end
end
end

@testset "non-isbits algebra" begin
for ST in (Symmetric, Hermitian), uplo in (:L, :U)
M = Matrix{Complex{BigFloat}}(undef,2,2)
M[1,1] = rand()
M[2,2] = rand()
M[1+(uplo==:L), 1+(uplo==:U)] = rand(ComplexF64)
S = ST(M, uplo)
MS = Matrix(S)
@test real(S) == real(MS)
@test imag(S) == imag(MS)
@test conj(S) == conj(MS)
@test conj!(copy(S)) == conj(MS)
@test -S == -MS
@test S + S == MS + MS
@test S - S == MS - MS
@test S*2 == 2*S == 2*MS
@test S/2 == MS/2
end
@testset "mixed uplo" begin
Mu = Matrix{Complex{BigFloat}}(undef,2,2)
Mu[1,1] = Mu[2,2] = 3
Mu[1,2] = 2 + 3im
Ml = Matrix{Complex{BigFloat}}(undef,2,2)
Ml[1,1] = Ml[2,2] = 4
Ml[2,1] = 4 + 5im
for ST in (Symmetric, Hermitian)
Su = ST(Mu, :U)
MSu = Matrix(Su)
Sl = ST(Ml, :L)
MSl = Matrix(Sl)
@test Su + Sl == Sl + Su == MSu + MSl
@test Su - Sl == -(Sl - Su) == MSu - MSl
end
end
end

# bug identified in PR #52318: dot products of quaternionic Hermitian matrices,
# or any number type where conj(a)*conj(b) ≠ conj(a*b):
@testset "dot Hermitian quaternion #52318" begin
Expand Down Expand Up @@ -932,4 +968,11 @@ end
end
end

@testset "conj for immutable" begin
S = Symmetric(reshape((1:16)*im, 4, 4))
@test conj(S) == conj(Array(S))
H = Hermitian(reshape((1:16)*im, 4, 4))
@test conj(H) == conj(Array(H))
end

end # module TestSymmetric