Skip to content

Commit

Permalink
Merge pull request #533 from simonbyrne/sb/sdiagonal
Browse files Browse the repository at this point in the history
Use LinearAlgebra.Diagonal instead of defining SDiagonal type
  • Loading branch information
mschauer authored Oct 29, 2018
2 parents 1683f79 + a0af20c commit 9368cba
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 80 deletions.
99 changes: 22 additions & 77 deletions src/SDiagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,114 +3,59 @@

import Base: ==, -, +, *, /, \, abs, real, imag, conj

@generated function scalem(a::StaticMatrix{M,N}, b::StaticVector{N}) where {M, N}
expr = vec([:(a[$j,$i]*b[$i]) for j=1:M, i=1:N])
:(@_inline_meta; let val1 = ($(expr[1])); similar_type(SMatrix{M,N},typeof(val1))(val1, $(expr[2:end]...)); end)
end
@generated function scalem(a::StaticVector{M}, b::StaticMatrix{M, N}) where {M, N}
expr = vec([:(b[$j,$i]*a[$j]) for j=1:M, i=1:N])
:(@_inline_meta; let val1 = ($(expr[1])); similar_type(SMatrix{M,N},typeof(val1))(val1, $(expr[2:end]...)); end)
end

struct SDiagonal{N,T} <: StaticMatrix{N,N,T}
diag::SVector{N,T}
SDiagonal{N,T}(diag::SVector{N,T}) where {N,T} = new(diag)
end
diagtype(::Type{SDiagonal{N,T}}) where {N, T} = SVector{N,T}
diagtype(::Type{SDiagonal{N}}) where {N} = SVector{N}
diagtype(::Type{SDiagonal}) = SVector
const SDiagonal = Diagonal{T,SVector{N,T}} where {N,T}
SDiagonal(x...) = Diagonal(SVector(x...))

# this is to deal with convert.jl
@inline (::Type{SD})(a::AbstractVector) where {SD <: SDiagonal} = SDiagonal(convert(diagtype(SD), a))
@inline (::Type{SD})(a::Tuple) where {SD <: SDiagonal} = SDiagonal(convert(diagtype(SD), a))
@inline SDiagonal(a::SVector{N,T}) where {N,T} = SDiagonal{N,T}(a)

@generated function SDiagonal(a::StaticMatrix{N,N,T}) where {N,T}
expr = [:(a[$i,$i]) for i=1:N]
:(SDiagonal{N,T}($(expr...)))
end

convert(::Type{SDiagonal{N,T}}, D::SDiagonal{N,T}) where {N,T} = D
convert(::Type{SDiagonal{N,T}}, D::SDiagonal{N}) where {N,T} = SDiagonal{N,T}(convert(SVector{N,T}, D.diag))

function getindex(D::SDiagonal{N,T}, i::Int, j::Int) where {N,T}
@boundscheck checkbounds(D, i, j)
@inbounds return ifelse(i == j, D.diag[i], zero(T))
end

# avoid linear indexing?
@propagate_inbounds function getindex(D::SDiagonal{N,T}, k::Int) where {N,T}
i, j = CartesianIndices(size(D))[k].I
D[i,j]
end
#@inline (::Type{SDiagonal{N,T}})(a::AbstractVector) where {N,T} = Diagonal(SVector{N,T}(a))
@inline (::Type{SDiagonal{N,T}})(a::Tuple) where {N,T} = Diagonal(SVector{N,T}(a))
@inline (::Type{SDiagonal{N}})(a::Tuple) where {N} = Diagonal(SVector{N}(a))

ishermitian(D::SDiagonal{N, T}) where {N,T<:Real} = true
ishermitian(D::SDiagonal) = all(D.diag .== real(D.diag))
issymmetric(D::SDiagonal) = true
isposdef(D::SDiagonal) = all(D.diag .> 0)
SDiagonal(a::SVector) = Diagonal(a)
SDiagonal(a::StaticMatrix{N,N,T}) where {N,T} = Diagonal(diag(a))

factorize(D::SDiagonal) = D
size(::Type{<:SDiagonal{N}}) where {N} = (N,N)
size(::Type{<:SDiagonal{N}}, d::Int) where {N} = d > 2 ? 1 : N

==(Da::SDiagonal, Db::SDiagonal) = Da.diag == Db.diag
-(A::SDiagonal) = SDiagonal(-A.diag)
+(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag + Db.diag)
-(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag - Db.diag)
-(A::SDiagonal, B::SMatrix) = typeof(B)(I)*A - B

*(x::T, D::SDiagonal) where {T<:Number} = SDiagonal(x * D.diag)
*(D::SDiagonal, x::T) where {T<:Number} = SDiagonal(D.diag * x)
/(D::SDiagonal, x::T) where {T<:Number} = SDiagonal(D.diag / x)
*(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag .* Db.diag)
*(D::SDiagonal, V::AbstractVector) = D.diag .* V
*(D::SDiagonal, V::StaticVector) = D.diag .* V
*(A::StaticMatrix, D::SDiagonal) = scalem(A,D.diag)
*(D::SDiagonal, A::StaticMatrix) = scalem(D.diag,A)
# define specific methods to avoid allocating mutable arrays
*(A::StaticMatrix, D::SDiagonal) = A .* transpose(D.diag)
*(D::SDiagonal, A::StaticMatrix) = D.diag .* A
\(D::SDiagonal, b::AbstractVector) = D.diag .\ b
\(D::SDiagonal, b::StaticVector) = D.diag .\ b # catch ambiguity

conj(D::SDiagonal) = SDiagonal(conj(D.diag))
transpose(D::SDiagonal) = D
adjoint(D::SDiagonal) = conj(D)
\(D::SDiagonal, B::StaticMatrix) = D.diag .\ B
/(B::StaticMatrix, D::SDiagonal) = B ./ transpose(D.diag)
\(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Db.diag ./ Da.diag)
/(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag ./ Db.diag )

# override to avoid copying
diag(D::SDiagonal) = D.diag
tr(D::SDiagonal) = sum(D.diag)
det(D::SDiagonal) = prod(D.diag)
logdet(D::SDiagonal{N,T}) where {N,T<:Real} = sum(log.(D.diag))
function logdet(D::SDiagonal{N,T}) where {N,T<:Complex} #Make sure branch cut is correct
x = sum(log.(D.diag))
-pi<imag(x)<pi ? x : real(x)+(mod2pi(imag(x)+pi)-pi)*im
end

# SDiagonal(I::UniformScaling) methods to replace eye
(::Type{SD})(I::UniformScaling) where {N,SD<:SDiagonal{N}} = SD(ntuple(x->I.λ, Val(N)))
(::Type{SDiagonal{N}})(I::UniformScaling) where {N} = SDiagonal{N}(ntuple(x->I.λ, Val(N)))
(::Type{SDiagonal{N,T}})(I::UniformScaling) where {N,T} = SDiagonal{N,T}(ntuple(x->I.λ, Val(N)))

# deprecate eye, keep around for as long as LinearAlgebra.eye exists
@static if VERSION < v"1.0"
@deprecate eye(::Type{SDiagonal{N,T}}) where {N,T} SDiagonal{N,T}(I)
end

one(::Type{SDiagonal{N,T}}) where {N,T} = SDiagonal(ones(SVector{N,T}))
one(::SDiagonal{N,T}) where {N,T} = SDiagonal(ones(SVector{N,T}))

Base.zero(::SDiagonal{N,T}) where {N,T} = SDiagonal(zeros(SVector{N,T}))
exp(D::SDiagonal) = SDiagonal(exp.(D.diag))
log(D::SDiagonal) = SDiagonal(log.(D.diag))
sqrt(D::SDiagonal) = SDiagonal(sqrt.(D.diag))

function LinearAlgebra.cholesky(D::SDiagonal)
any(x -> x < 0, D.diag) && throw(LinearAlgebra.PosDefException(1))
C = sqrt.(D.diag)
return Cholesky(SDiagonal(C), 'U', 0)
end

\(D::SDiagonal, B::StaticMatrix) = scalem(1 ./ D.diag, B)
/(B::StaticMatrix, D::SDiagonal) = scalem(1 ./ D.diag, B)
\(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Db.diag ./ Da.diag)
/(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag ./ Db.diag )

@generated function check_singular(D::SDiagonal{N}) where {N}
quote
Base.Cartesian.@nexprs $N i->(@inbounds iszero(D.diag[i]) && throw(LinearAlgebra.SingularException(i)))
end
end

function inv(D::SDiagonal)
check_singular(D)
SDiagonal(inv.(D.diag))
Expand Down
3 changes: 0 additions & 3 deletions test/SDiagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ using StaticArrays, Test, LinearAlgebra

@testset "Methods" begin

@test StaticArrays.scalem(@SMatrix([1 1 1;1 1 1; 1 1 1]), @SVector [1,2,3]) === @SArray [1 2 3; 1 2 3; 1 2 3]
@test StaticArrays.scalem(@SVector([1,2,3]),@SMatrix [1 1 1;1 1 1; 1 1 1])' === @SArray [1 2 3; 1 2 3; 1 2 3]

m = SDiagonal(@SVector [11, 12, 13, 14])

@test diag(m) === m.diag
Expand Down

0 comments on commit 9368cba

Please sign in to comment.