From 287847936f9c952c6cdbf87fa5ecbbbbd4d1f605 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 2 Sep 2024 20:12:39 +0530 Subject: [PATCH] Reroute Symmetric/Hermitian + Diagonal through triangular (#55605) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This should fix the `Diagonal`-related issue from https://github.com/JuliaLang/julia/issues/55590, although the `SymTridiagonal` one still remains. ```julia julia> using LinearAlgebra julia> a = Matrix{BigFloat}(undef, 2,2) 2×2 Matrix{BigFloat}: #undef #undef #undef #undef julia> a[1] = 1; a[3] = 1; a[4] = 1 1 julia> a = Hermitian(a) 2×2 Hermitian{BigFloat, Matrix{BigFloat}}: 1.0 1.0 1.0 1.0 julia> b = Symmetric(a) 2×2 Symmetric{BigFloat, Matrix{BigFloat}}: 1.0 1.0 1.0 1.0 julia> c = Diagonal([1,1]) 2×2 Diagonal{Int64, Vector{Int64}}: 1 ⋅ ⋅ 1 julia> a+c 2×2 Hermitian{BigFloat, Matrix{BigFloat}}: 2.0 1.0 1.0 2.0 julia> b+c 2×2 Symmetric{BigFloat, Matrix{BigFloat}}: 2.0 1.0 1.0 2.0 ``` (cherry picked from commit 39f2ad1e941fd092b906d60e8d23c004e8ee5b7e) --- stdlib/LinearAlgebra/src/diagonal.jl | 15 --------------- stdlib/LinearAlgebra/src/special.jl | 19 +++++++++++++++++++ stdlib/LinearAlgebra/test/special.jl | 13 +++++++++++++ 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index d27da450205f7..bea8183546d90 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -250,21 +250,6 @@ end (+)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag + Db.diag) (-)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag - Db.diag) -for f in (:+, :-) - @eval function $f(D::Diagonal{<:Number}, S::Symmetric) - return Symmetric($f(D, S.data), sym_uplo(S.uplo)) - end - @eval function $f(S::Symmetric, D::Diagonal{<:Number}) - return Symmetric($f(S.data, D), sym_uplo(S.uplo)) - end - @eval function $f(D::Diagonal{<:Real}, H::Hermitian) - return Hermitian($f(D, H.data), sym_uplo(H.uplo)) - end - @eval function $f(H::Hermitian, D::Diagonal{<:Real}) - return Hermitian($f(H.data, D), sym_uplo(H.uplo)) - end -end - (*)(x::Number, D::Diagonal) = Diagonal(x * D.diag) (*)(D::Diagonal, x::Number) = Diagonal(D.diag * x) (/)(D::Diagonal, x::Number) = Diagonal(D.diag / x) diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index 1363708fb515f..2a3b6bf2a8e00 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -264,6 +264,25 @@ function (-)(A::UniformScaling, B::Diagonal) Diagonal(Ref(A) .- B.diag) end +for f in (:+, :-) + @eval function $f(D::Diagonal{<:Number}, S::Symmetric) + uplo = sym_uplo(S.uplo) + return Symmetric(parentof_applytri($f, Symmetric(D, uplo), S), uplo) + end + @eval function $f(S::Symmetric, D::Diagonal{<:Number}) + uplo = sym_uplo(S.uplo) + return Symmetric(parentof_applytri($f, S, Symmetric(D, uplo)), uplo) + end + @eval function $f(D::Diagonal{<:Real}, H::Hermitian) + uplo = sym_uplo(H.uplo) + return Hermitian(parentof_applytri($f, Hermitian(D, uplo), H), uplo) + end + @eval function $f(H::Hermitian, D::Diagonal{<:Real}) + uplo = sym_uplo(H.uplo) + return Hermitian(parentof_applytri($f, H, Hermitian(D, uplo)), uplo) + end +end + ## Diagonal construction from UniformScaling Diagonal{T}(s::UniformScaling, m::Integer) where {T} = Diagonal{T}(fill(T(s.λ), m)) Diagonal(s::UniformScaling, m::Integer) = Diagonal{eltype(s)}(s, m) diff --git a/stdlib/LinearAlgebra/test/special.jl b/stdlib/LinearAlgebra/test/special.jl index 7e96af369e310..148ccd85efeb1 100644 --- a/stdlib/LinearAlgebra/test/special.jl +++ b/stdlib/LinearAlgebra/test/special.jl @@ -536,4 +536,17 @@ end @test v * S isa Matrix end +@testset "Partly filled Hermitian and Diagonal algebra" begin + D = Diagonal([1,2]) + for S in (Symmetric, Hermitian), uplo in (:U, :L) + M = Matrix{BigInt}(undef, 2, 2) + M[1,1] = M[2,2] = M[1+(uplo == :L), 1 + (uplo == :U)] = 3 + H = S(M, uplo) + HM = Matrix(H) + @test H + D == D + H == HM + D + @test H - D == HM - D + @test D - H == D - HM + end +end + end # module TestSpecial