From 4837cf9d2e8d3425532cbea3b3693fa5238888da Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Wed, 17 Oct 2018 17:57:27 +0200 Subject: [PATCH] fix #29392: HermOrSym should preserve structure when scaled with Numbers. --- stdlib/LinearAlgebra/src/symmetric.jl | 13 ++++---- stdlib/LinearAlgebra/test/symmetric.jl | 44 ++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index b5f6204361230..0111b3faf4f93 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -440,12 +440,13 @@ mul!(C::StridedMatrix{T}, A::StridedMatrix{T}, B::Hermitian{T,<:StridedMatrix}) *(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, B::AbstractTriangular) = adjA.parent * B *(A::AbstractTriangular, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = A * adjB.parent -for T in (:Symmetric, :Hermitian), op in (:*, :/) - # Deal with an ambiguous case - @eval ($op)(A::$T, x::Bool) = ($T)(($op)(A.data, x), sym_uplo(A.uplo)) - S = T == :Hermitian ? :Real : :Number - @eval ($op)(A::$T, x::$S) = ($T)(($op)(A.data, x), sym_uplo(A.uplo)) -end +# Scaling with Number +*(A::Symmetric, x::Number) = Symmetric(A.data*x, sym_uplo(A.uplo)) +*(x::Number, A::Symmetric) = Symmetric(x*A.data, sym_uplo(A.uplo)) +*(A::Hermitian, x::Real) = Hermitian(A.data*x, sym_uplo(A.uplo)) +*(x::Real, A::Hermitian) = Hermitian(x*A.data, sym_uplo(A.uplo)) +/(A::Symmetric, x::Number) = Symmetric(A.data/x, sym_uplo(A.uplo)) +/(A::Hermitian, x::Real) = Hermitian(A.data/x, sym_uplo(A.uplo)) function factorize(A::HermOrSym{T}) where T TT = typeof(sqrt(oneunit(T))) diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index 9e9814895d7e6..e240172303052 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -517,4 +517,48 @@ end @test Hermitian(A, :U)[1,1] == Hermitian(A, :L)[1,1] == real(A[1,1]) end +@testset "issue #29392: SymOrHerm scaled with Number" begin + R = rand(Float64, 2, 2); C = rand(ComplexF64, 2, 2) + # Symmetric * Real, Real * Symmetric + A = Symmetric(R); x = 2.0 + @test (A * x)::Symmetric == (x * A)::Symmetric + A = Symmetric(C); x = 2.0 + @test (A * x)::Symmetric == (x * A)::Symmetric + # Symmetric * Complex, Complex * Symmetrics + A = Symmetric(R); x = 2.0im + @test (A * x)::Symmetric == (x * A)::Symmetric + A = Symmetric(C); x = 2.0im + @test (A * x)::Symmetric == (x * A)::Symmetric + # Hermitian * Real, Real * Hermitian + A = Hermitian(R); x = 2.0 + @test (A * x)::Hermitian == (x * A)::Hermitian + A = Hermitian(C); x = 2.0 + @test (A * x)::Hermitian == (x * A)::Hermitian + # Hermitian * Complex, Complex * Hermitian + A = Hermitian(R); x = 2.0im + @test (A * x)::Matrix == (x * A)::Matrix + A = Hermitian(C); x = 2.0im + @test (A * x)::Matrix == (x * A)::Matrix + # Symmetric / Real + A = Symmetric(R); x = 2.0 + @test (A / x)::Symmetric == Matrix(A) / x + A = Symmetric(C); x = 2.0 + @test (A / x)::Symmetric == Matrix(A) / x + # Symmetric / Complex + A = Symmetric(R); x = 2.0im + @test (A / x)::Symmetric == Matrix(A) / x + A = Symmetric(C); x = 2.0im + @test (A / x)::Symmetric == Matrix(A) / x + # Hermitian / Real + A = Hermitian(R); x = 2.0 + @test (A / x)::Hermitian == Matrix(A) / x + A = Hermitian(C); x = 2.0 + @test (A / x)::Hermitian == Matrix(A) / x + # Hermitian / Complex + A = Hermitian(R); x = 2.0im + @test (A / x)::Matrix == Matrix(A) / x + A = Hermitian(C); x = 2.0im + @test (A / x)::Matrix == Matrix(A) / x +end + end # module TestSymmetric