From 49023b50eb22ce1276c381bd0cf88cf2f269ed0e Mon Sep 17 00:00:00 2001 From: Sakse <17059936+dalum@users.noreply.github.com> Date: Thu, 13 Dec 2018 18:01:51 +0800 Subject: [PATCH] Preserve types when adding/subtracting Herm/Sym/UniformScaling (#29500) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Preserve types when adding/subtracting Herm/Sym/UniformScaling * Make `real(::SymOrHerm{<:Real})` consistent with `real(::Array)`. * Fix embarrassing ambiguity * More tests, remove imag(::Hermitian), simplify code * Remove `.λ` --- stdlib/LinearAlgebra/src/symmetric.jl | 14 ++++++++++ stdlib/LinearAlgebra/src/uniformscaling.jl | 25 ++++++++++++++++++ stdlib/LinearAlgebra/test/symmetric.jl | 29 ++++++++++++++++++--- stdlib/LinearAlgebra/test/uniformscaling.jl | 10 +++++++ 4 files changed, 75 insertions(+), 3 deletions(-) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 9932f36c11ce0..1aea051c84259 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -330,6 +330,12 @@ transpose(A::Hermitian{<:Real}) = A adjoint(A::Symmetric) = Adjoint(A) transpose(A::Hermitian) = Transpose(A) +real(A::Symmetric{<:Real}) = A +real(A::Hermitian{<:Real}) = A +real(A::Symmetric) = Symmetric(real(A.data), sym_uplo(A.uplo)) +real(A::Hermitian) = Hermitian(real(A.data), sym_uplo(A.uplo)) +imag(A::Symmetric) = Symmetric(imag(A.data), sym_uplo(A.uplo)) + Base.copy(A::Adjoint{<:Any,<:Hermitian}) = copy(A.parent) Base.copy(A::Transpose{<:Any,<:Symmetric}) = copy(A.parent) Base.copy(A::Adjoint{<:Any,<:Symmetric}) = @@ -394,6 +400,14 @@ end (-)(A::Symmetric{Tv,S}) where {Tv,S} = Symmetric{Tv,S}(-A.data, A.uplo) (-)(A::Hermitian{Tv,S}) where {Tv,S} = Hermitian{Tv,S}(-A.data, A.uplo) +## Addition/subtraction +for f in (:+, :-) + @eval $f(A::Symmetric, B::Symmetric) = Symmetric($f(A.data, B), sym_uplo(A.uplo)) + @eval $f(A::Hermitian, B::Hermitian) = Hermitian($f(A.data, B), sym_uplo(A.uplo)) + @eval $f(A::Hermitian, B::Symmetric{<:Real}) = Hermitian($f(A.data, B), sym_uplo(A.uplo)) + @eval $f(A::Symmetric{<:Real}, B::Hermitian) = Hermitian($f(A.data, B), sym_uplo(A.uplo)) +end + ## Matvec mul!(y::StridedVector{T}, A::Symmetric{T,<:StridedMatrix}, x::StridedVector{T}) where {T<:BlasFloat} = BLAS.symv!(A.uplo, one(T), A.data, x, zero(T), y) diff --git a/stdlib/LinearAlgebra/src/uniformscaling.jl b/stdlib/LinearAlgebra/src/uniformscaling.jl index fd46e6c1a2fce..a97a54ea745c0 100644 --- a/stdlib/LinearAlgebra/src/uniformscaling.jl +++ b/stdlib/LinearAlgebra/src/uniformscaling.jl @@ -112,6 +112,31 @@ for (t1, t2) in ((:UnitUpperTriangular, :UpperTriangular), end end +# Adding a complex UniformScaling to the diagonal of a Hermitian +# matrix breaks the hermiticity, if the UniformScaling is non-real. +# However, to preserve type stability, we do not special-case a +# UniformScaling{<:Complex} that happens to be real. +function (+)(A::Hermitian{T,S}, J::UniformScaling{<:Complex}) where {T,S} + A_ = copytri!(copy(parent(A)), A.uplo) + B = convert(AbstractMatrix{Base._return_type(+, Tuple{eltype(A), typeof(J)})}, A_) + @inbounds for i in diagind(B) + B[i] += J + end + return B +end + +function (-)(J::UniformScaling{<:Complex}, A::Hermitian{T,S}) where {T,S} + A_ = copytri!(copy(parent(A)), A.uplo) + B = convert(AbstractMatrix{Base._return_type(+, Tuple{eltype(A), typeof(J)})}, A_) + @inbounds for i in eachindex(B) + B[i] = -B[i] + end + @inbounds for i in diagind(B) + B[i] += J + end + return B +end + function (+)(A::AbstractMatrix, J::UniformScaling) checksquare(A) B = copy_oftype(A, Base._return_type(+, Tuple{eltype(A), typeof(J)})) diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index e240172303052..67a7213ed3083 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -90,6 +90,17 @@ end @test (-Hermitian(aherm))::typeof(Hermitian(aherm)) == -aherm end + @testset "Addition and subtraction for Symmetric/Hermitian matrices" begin + for f in (+, -) + @test (f(Symmetric(asym), Symmetric(aposs)))::typeof(Symmetric(asym)) == f(asym, aposs) + @test (f(Hermitian(aherm), Hermitian(apos)))::typeof(Hermitian(aherm)) == f(aherm, apos) + @test (f(Symmetric(real(asym)), Hermitian(aherm)))::typeof(Hermitian(aherm)) == f(real(asym), aherm) + @test (f(Hermitian(aherm), Symmetric(real(asym))))::typeof(Hermitian(aherm)) == f(aherm, real(asym)) + @test (f(Symmetric(asym), Hermitian(aherm))) == f(asym, aherm) + @test (f(Hermitian(aherm), Symmetric(asym))) == f(aherm, asym) + end + end + @testset "getindex and unsafe_getindex" begin @test aherm[1,1] == Hermitian(aherm)[1,1] @test asym[1,1] == Symmetric(asym)[1,1] @@ -153,6 +164,21 @@ end @test transpose(H) == Hermitian(copy(transpose(aherm))) end end + + @testset "real, imag" begin + S = Symmetric(asym) + H = Hermitian(aherm) + @test issymmetric(real(S)) + @test ishermitian(real(H)) + if eltya <: Real + @test real(S) === S == asym + @test real(H) === H == aherm + elseif eltya <: Complex + @test issymmetric(imag(S)) + @test !ishermitian(imag(H)) + end + end + end @testset "linalg unary ops" begin @@ -415,9 +441,6 @@ end @test T([true false; false true]) .+ true == T([2 1; 1 2]) end - - @test_throws ArgumentError Hermitian(X) + 2im*I - @test_throws ArgumentError Hermitian(X) - 2im*I end @testset "Issue #21981" begin diff --git a/stdlib/LinearAlgebra/test/uniformscaling.jl b/stdlib/LinearAlgebra/test/uniformscaling.jl index 887162d03dc20..f22812e0486fc 100644 --- a/stdlib/LinearAlgebra/test/uniformscaling.jl +++ b/stdlib/LinearAlgebra/test/uniformscaling.jl @@ -178,6 +178,16 @@ let @test @inferred(J - T) == J - Array(T) @test @inferred(T\I) == inv(T) + if isa(A, Array) + T = Hermitian(randn(3,3)) + else + T = Hermitian(view(randn(3,3), 1:3, 1:3)) + end + @test @inferred(T + J) == Array(T) + J + @test @inferred(J + T) == J + Array(T) + @test @inferred(T - J) == Array(T) - J + @test @inferred(J - T) == J - Array(T) + @test @inferred(I\A) == A @test @inferred(A\I) == inv(A) @test @inferred(λ\I) === UniformScaling(1/λ)