From 197295c84aab217bffde01a4339f1d0e8e95fb98 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 28 Jul 2024 17:35:28 +0530 Subject: [PATCH] Restrict binary ops for Diagonal and Symmetric to Number eltypes (#55251) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `(::Diagonal) + (::Symmetric)` and analogous methods were specialized in https://github.com/JuliaLang/julia/pull/35333 to return a `Symmetric`, but these only work if the `Diagonal` is also symmetric. This typically holds for arrays of numbers, but may not hold for block-diagonal and other types for which symmetry isn't guaranteed. This PR restricts the methods to arrays of `Number`s. Fixes, e.g.: ```julia julia> using StaticArrays, LinearAlgebra julia> D = Diagonal(fill(SMatrix{2,2}(1:4), 2)) 2×2 Diagonal{SMatrix{2, 2, Int64, 4}, Vector{SMatrix{2, 2, Int64, 4}}}: [1 3; 2 4] ⋅ ⋅ [1 3; 2 4] julia> S = Symmetric(D) 2×2 Symmetric{AbstractMatrix, Diagonal{SMatrix{2, 2, Int64, 4}, Vector{SMatrix{2, 2, Int64, 4}}}}: [1 3; 3 4] ⋅ ⋅ [1 3; 3 4] julia> S + D 2×2 Symmetric{AbstractMatrix, Diagonal{SMatrix{2, 2, Int64, 4}, Vector{SMatrix{2, 2, Int64, 4}}}}: [2 6; 6 8] ⋅ ⋅ [2 6; 6 8] julia> S[1,1] + D[1,1] 2×2 SMatrix{2, 2, Int64, 4} with indices SOneTo(2)×SOneTo(2): 2 6 5 8 julia> (S + D)[1,1] == S[1,1] + D[1,1] false ``` After this, ```julia julia> S + D 2×2 Matrix{AbstractMatrix{Int64}}: [2 6; 5 8] [0 0; 0 0] [0 0; 0 0] [2 6; 5 8] ``` Even with `Number`s as elements, there might be an issue with `NaN`s along the diagonal as `!issymmetric(NaN)`, but that may be a different PR. --- stdlib/LinearAlgebra/src/diagonal.jl | 4 ++-- stdlib/LinearAlgebra/test/diagonal.jl | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index b3826a2aa7f82..ff8350d8ddeb1 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -268,10 +268,10 @@ end (-)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag - Db.diag) for f in (:+, :-) - @eval function $f(D::Diagonal, S::Symmetric) + @eval function $f(D::Diagonal{<:Number}, S::Symmetric) return Symmetric($f(D, S.data), sym_uplo(S.uplo)) end - @eval function $f(S::Symmetric, D::Diagonal) + @eval function $f(S::Symmetric, D::Diagonal{<:Number}) return Symmetric($f(S.data, D), sym_uplo(S.uplo)) end @eval function $f(D::Diagonal{<:Real}, H::Hermitian) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 1a3b8d4fd0ea7..e1fc9afa5ad2e 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1335,4 +1335,16 @@ end end end +@testset "+/- with block Symmetric/Hermitian" begin + for p in ([1 2; 3 4], [1 2+im; 2-im 4+2im]) + m = SizedArrays.SizedArray{(2,2)}(p) + D = Diagonal(fill(m, 2)) + for T in (Symmetric, Hermitian) + S = T(fill(m, 2, 2)) + @test D + S == Array(D) + Array(S) + @test S + D == Array(S) + Array(D) + end + end +end + end # module TestDiagonal