Skip to content

Commit

Permalink
Reroute algebraic functions for Symmetric/Hermitian through trian…
Browse files Browse the repository at this point in the history
…gular (#52942)

This ensures that only the triangular indices are accessed for strided
parent matrices. Fix #52895

```julia
julia> M = Matrix{Complex{BigFloat}}(undef, 2, 2);

julia> M[1,1] = M[2,2] = M[1,2] = 2;

julia> H = Hermitian(M)
2×2 Hermitian{Complex{BigFloat}, Matrix{Complex{BigFloat}}}:
 2.0+0.0im  2.0+0.0im
 2.0-0.0im  2.0+0.0im

julia> H + H # works after this
2×2 Hermitian{Complex{BigFloat}, Matrix{Complex{BigFloat}}}:
 4.0+0.0im  4.0+0.0im
 4.0-0.0im  4.0+0.0im
```
This also provides a speed-up in several common cases (allocations
mentioned only when they differ):
```julia
julia> H = Hermitian(rand(ComplexF64,1000,1000));

julia> H2 = Hermitian(rand(ComplexF64,1000,1000),:L);
```
| Operation | master | PR |
| ---- | ---- | ---- |
|`-H` |2.247 ms | 1.384 ms |
| `real(H)` |1.544 ms |1.175 ms |
|`H + H` |2.288 ms |1.978 ms |
|`H + H2` |5.139 ms |3.287 ms |
| `isdiag(H)` |23.042 ns (1 allocation: 16 bytes) |16.778 ns (0
allocations: 0 bytes) |

I'm not entirely certain why `isdiag(H)` allocates on master, as union
splitting should handle this automatically, but manually splitting the
union appears to help.
  • Loading branch information
jishnub authored Feb 9, 2024
1 parent 63e95d4 commit 27b31d1
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 27 deletions.
72 changes: 45 additions & 27 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,34 @@ end
end
end

_conjugation(::Symmetric) = transpose
_conjugation(::Hermitian) = adjoint

diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo))
diag(A::Hermitian) = hermitian.(diag(parent(A)), sym_uplo(A.uplo))

isdiag(A::HermOrSym) = isdiag(A.uplo == 'U' ? UpperTriangular(A.data) : LowerTriangular(A.data))
function applytri(f, A::HermOrSym)
if A.uplo == 'U'
f(UpperTriangular(A.data))
else
f(LowerTriangular(A.data))
end
end

function applytri(f, A::HermOrSym, B::HermOrSym)
if A.uplo == B.uplo == 'U'
f(UpperTriangular(A.data), UpperTriangular(B.data))
elseif A.uplo == B.uplo == 'L'
f(LowerTriangular(A.data), LowerTriangular(B.data))
elseif A.uplo == 'U'
f(UpperTriangular(A.data), UpperTriangular(_conjugation(B)(B.data)))
else # A.uplo == 'L'
f(UpperTriangular(_conjugation(A)(A.data)), UpperTriangular(B.data))
end
end
parentof_applytri(f, args...) = applytri(parent f, args...)

isdiag(A::HermOrSym) = applytri(isdiag, A)

# For A<:Union{Symmetric,Hermitian}, similar(A[, neweltype]) should yield a matrix with the same
# symmetry type, uplo flag, and underlying storage type as A. The following methods cover these cases.
Expand Down Expand Up @@ -314,8 +338,8 @@ Hermitian{T,S}(A::Hermitian) where {T,S<:AbstractMatrix{T}} = Hermitian{T,S}(con
AbstractMatrix{T}(A::Hermitian) where {T} = Hermitian(convert(AbstractMatrix{T}, A.data), sym_uplo(A.uplo))
AbstractMatrix{T}(A::Hermitian{T}) where {T} = copy(A)

copy(A::Symmetric{T,S}) where {T,S} = (B = copy(A.data); Symmetric{T,typeof(B)}(B,A.uplo))
copy(A::Hermitian{T,S}) where {T,S} = (B = copy(A.data); Hermitian{T,typeof(B)}(B,A.uplo))
copy(A::Symmetric) = (Symmetric(parentof_applytri(copy, A), sym_uplo(A.uplo)))
copy(A::Hermitian) = (Hermitian(parentof_applytri(copy, A), sym_uplo(A.uplo)))

function copyto!(dest::Symmetric, src::Symmetric)
if src.uplo == dest.uplo
Expand Down Expand Up @@ -389,9 +413,9 @@ transpose(A::Hermitian) = Transpose(A)

real(A::Symmetric{<:Real}) = A
real(A::Hermitian{<:Real}) = A
real(A::Symmetric) = Symmetric(real(A.data), sym_uplo(A.uplo))
real(A::Hermitian) = Hermitian(real(A.data), sym_uplo(A.uplo))
imag(A::Symmetric) = Symmetric(imag(A.data), sym_uplo(A.uplo))
real(A::Symmetric) = Symmetric(parentof_applytri(real, A), sym_uplo(A.uplo))
real(A::Hermitian) = Hermitian(parentof_applytri(real, A), sym_uplo(A.uplo))
imag(A::Symmetric) = Symmetric(parentof_applytri(imag, A), sym_uplo(A.uplo))

Base.copy(A::Adjoint{<:Any,<:Symmetric}) =
Symmetric(copy(adjoint(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U))
Expand All @@ -401,8 +425,9 @@ Base.copy(A::Transpose{<:Any,<:Hermitian}) =
tr(A::Symmetric) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations)
tr(A::Hermitian) = real(tr(A.data))

Base.conj(A::HermOrSym) = typeof(A)(conj(A.data), A.uplo)
Base.conj!(A::HermOrSym) = typeof(A)(conj!(A.data), A.uplo)
Base.conj(A::Symmetric) = Symmetric(parentof_applytri(conj, A), sym_uplo(A.uplo))
Base.conj(A::Hermitian) = Hermitian(parentof_applytri(conj, A), sym_uplo(A.uplo))
Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo)

# tril/triu
function tril(A::Hermitian, k::Integer=0)
Expand Down Expand Up @@ -496,21 +521,14 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni
end
end

(-)(A::Symmetric) = Symmetric(-A.data, sym_uplo(A.uplo))
(-)(A::Hermitian) = Hermitian(-A.data, sym_uplo(A.uplo))
(-)(A::Symmetric) = Symmetric(parentof_applytri(-, A), sym_uplo(A.uplo))
(-)(A::Hermitian) = Hermitian(parentof_applytri(-, A), sym_uplo(A.uplo))

## Addition/subtraction
for f (:+, :-), (Wrapper, conjugation) ((:Hermitian, :adjoint), (:Symmetric, :transpose))
@eval begin
function $f(A::$Wrapper, B::$Wrapper)
if A.uplo == B.uplo
return $Wrapper($f(parent(A), parent(B)), sym_uplo(A.uplo))
elseif A.uplo == 'U'
return $Wrapper($f(parent(A), $conjugation(parent(B))), :U)
else
return $Wrapper($f($conjugation(parent(A)), parent(B)), :U)
end
end
for f (:+, :-), Wrapper (:Hermitian, :Symmetric)
@eval function $f(A::$Wrapper, B::$Wrapper)
uplo = A.uplo == B.uplo ? sym_uplo(A.uplo) : (:U)
$Wrapper(parentof_applytri($f, A, B), uplo)
end
end

Expand Down Expand Up @@ -555,12 +573,12 @@ function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector)
end

# Scaling with Number
*(A::Symmetric, x::Number) = Symmetric(A.data*x, sym_uplo(A.uplo))
*(x::Number, A::Symmetric) = Symmetric(x*A.data, sym_uplo(A.uplo))
*(A::Hermitian, x::Real) = Hermitian(A.data*x, sym_uplo(A.uplo))
*(x::Real, A::Hermitian) = Hermitian(x*A.data, sym_uplo(A.uplo))
/(A::Symmetric, x::Number) = Symmetric(A.data/x, sym_uplo(A.uplo))
/(A::Hermitian, x::Real) = Hermitian(A.data/x, sym_uplo(A.uplo))
*(A::Symmetric, x::Number) = Symmetric(parentof_applytri(y -> y * x, A), sym_uplo(A.uplo))
*(x::Number, A::Symmetric) = Symmetric(parentof_applytri(y -> x * y, A), sym_uplo(A.uplo))
*(A::Hermitian, x::Real) = Hermitian(parentof_applytri(y -> y * x, A), sym_uplo(A.uplo))
*(x::Real, A::Hermitian) = Hermitian(parentof_applytri(y -> x * y, A), sym_uplo(A.uplo))
/(A::Symmetric, x::Number) = Symmetric(parentof_applytri(y -> y/x, A), sym_uplo(A.uplo))
/(A::Hermitian, x::Real) = Hermitian(parentof_applytri(y -> y/x, A), sym_uplo(A.uplo))

factorize(A::HermOrSym) = _factorize(A)
function _factorize(A::HermOrSym{T}; check::Bool=true) where T
Expand Down
43 changes: 43 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,42 @@ end
end
end

@testset "non-isbits algebra" begin
for ST in (Symmetric, Hermitian), uplo in (:L, :U)
M = Matrix{Complex{BigFloat}}(undef,2,2)
M[1,1] = rand()
M[2,2] = rand()
M[1+(uplo==:L), 1+(uplo==:U)] = rand(ComplexF64)
S = ST(M, uplo)
MS = Matrix(S)
@test real(S) == real(MS)
@test imag(S) == imag(MS)
@test conj(S) == conj(MS)
@test conj!(copy(S)) == conj(MS)
@test -S == -MS
@test S + S == MS + MS
@test S - S == MS - MS
@test S*2 == 2*S == 2*MS
@test S/2 == MS/2
end
@testset "mixed uplo" begin
Mu = Matrix{Complex{BigFloat}}(undef,2,2)
Mu[1,1] = Mu[2,2] = 3
Mu[1,2] = 2 + 3im
Ml = Matrix{Complex{BigFloat}}(undef,2,2)
Ml[1,1] = Ml[2,2] = 4
Ml[2,1] = 4 + 5im
for ST in (Symmetric, Hermitian)
Su = ST(Mu, :U)
MSu = Matrix(Su)
Sl = ST(Ml, :L)
MSl = Matrix(Sl)
@test Su + Sl == Sl + Su == MSu + MSl
@test Su - Sl == -(Sl - Su) == MSu - MSl
end
end
end

# bug identified in PR #52318: dot products of quaternionic Hermitian matrices,
# or any number type where conj(a)*conj(b) ≠ conj(a*b):
@testset "dot Hermitian quaternion #52318" begin
Expand Down Expand Up @@ -932,4 +968,11 @@ end
end
end

@testset "conj for immutable" begin
S = Symmetric(reshape((1:16)*im, 4, 4))
@test conj(S) == conj(Array(S))
H = Hermitian(reshape((1:16)*im, 4, 4))
@test conj(H) == conj(Array(H))
end

end # module TestSymmetric

0 comments on commit 27b31d1

Please sign in to comment.