Skip to content

Commit

Permalink
make Diagonal+Symmetric return Symmetric
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Apr 1, 2020
1 parent 351d7e9 commit d9d2f25
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 13 deletions.
39 changes: 26 additions & 13 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,32 @@ end
(-)(A::Diagonal) = Diagonal(-A.diag)
(+)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag + Db.diag)
(-)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag - Db.diag)
for f in (:+, :-)
@eval function $f(D::Diagonal, S::Symmetric)
SD = copy_oftype(S.data, promote_op($f, eltype(D), eltype(S)))
di = diagind(SD)
SD[di] .= ($f).(D.diag, SD[di])
return Symmetric(SD, sym_uplo(S.uplo))
end
@eval function $f(S::Symmetric, D::Diagonal)
SD = copy_oftype(S.data, promote_op($f, eltype(D), eltype(S)))
di = diagind(SD)
SD[di] .= ($f).(SD[di], D.diag)
return Symmetric(SD, sym_uplo(S.uplo))
end
@eval function $f(D::Diagonal{<:Real}, H::Hermitian)
HD = copy_oftype(H.data, promote_op($f, eltype(D), eltype(H)))
di = diagind(HD)
HD[di] .= ($f).(D.diag, HD[di])
return Hermitian(HD, sym_uplo(H.uplo))
end
@eval function $f(H::Hermitian, D::Diagonal{<:Real})
HD = copy_oftype(H.data, promote_op($f, eltype(D), eltype(H)))
di = diagind(HD)
HD[di] .= ($f).(HD[di], D.diag)
return Hermitian(HD, sym_uplo(H.uplo))
end
end

(*)(x::Number, D::Diagonal) = Diagonal(x * D.diag)
(*)(D::Diagonal, x::Number) = Diagonal(D.diag * x)
Expand Down Expand Up @@ -680,19 +706,6 @@ function getproperty(C::Cholesky{<:Any,<:Diagonal}, d::Symbol)
end

Base._sum(A::Diagonal, ::Colon) = sum(A.diag)
function Base._sum(A::Diagonal, dims::Integer)
res = Base.reducedim_initarray(A, dims, zero(eltype(A)))
if dims <= 2
for i = 1:length(A.diag)
@inbounds res[i] = A.diag[i]
end
else
for i = 1:length(A.diag)
@inbounds res[i,i] = A.diag[i]
end
end
res
end

function logabsdet(A::Diagonal)
mapreduce(x -> (log(abs(x)), sign(x)), ((d1, s1), (d2, s2)) -> (d1 + d2, s1 * s2),
Expand Down
13 changes: 13 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,19 @@ Random.seed!(1)
A = rand(elty, n, n)
Asym = Symmetric(A + transpose(A), :U)
Aherm = Hermitian(A + adjoint(A), :U)
for op in (+, -)
@test op(Asym, D) isa Symmetric
@test Array(op(Asym, D)) Array(Symmetric(op(Array(Asym), Array(D))))
@test op(D, Asym) isa Symmetric
@test Array(op(D, Asym)) Array(Symmetric(op(Array(D), Array(Asym))))
if !(elty <: Real)
Dr = real(D)
@test op(Aherm, Dr) isa Hermitian
@test Array(op(Aherm, Dr)) Array(Hermitian(op(Array(Aherm), Array(Dr))))
@test op(Dr, Asym) isa Hermitian
@test Array(op(Dr, Asym)) Array(Hermitian(op(Array(Dr), Array(Aherm))))
end
end
@test Array(D*Transpose(Asym)) Array(D) * Array(transpose(Asym))
@test Array(D*Adjoint(Asym)) Array(D) * Array(adjoint(Asym))
@test Array(D*Transpose(Aherm)) Array(D) * Array(transpose(Aherm))
Expand Down

0 comments on commit d9d2f25

Please sign in to comment.