Skip to content

Commit

Permalink
Merge pull request #5263 from JuliaLang/cjh/solve-diag
Browse files Browse the repository at this point in the history
Specialized solvers for Diagonal matrices
  • Loading branch information
jiahao committed Jan 1, 2014
2 parents 48e599e + 78315a5 commit 7cfcb49
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 46 deletions.
2 changes: 1 addition & 1 deletion base/float16.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ end
for op in (:+,:-,:*,:/,:\)
@eval ($op)(a::Float16, b::Float16) = float16(($op)(float32(a), float32(b)))
end
for op in (:<,:isless)
for op in (:<,:<=,:isless)
@eval ($op)(a::Float16, b::Float16) = ($op)(float32(a), float32(b))
end
for func in (sin,cos,tan,asin,acos,atan,sinh,cosh,tanh,asinh,acosh,atanh,exp,log,exponent,sqrt)
Expand Down
94 changes: 58 additions & 36 deletions base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,38 +23,58 @@ isposdef(D::Diagonal) = all(D.diag .> 0)

+(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag + Db.diag)
-(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag - Db.diag)
-{T}(D::Diagonal{T}, M::AbstractMatrix{T}) = full(D) - M
-{T}(M::AbstractMatrix{T}, D::Diagonal{T}) = M - full(D)

*{T<:Number}(x::T, D::Diagonal) = Diagonal(x * D.diag)
*{T<:Number}(D::Diagonal, x::T) = Diagonal(D.diag * x)
/{T<:Number}(D::Diagonal, x::T) = Diagonal(D.diag / x)
*(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .* Db.diag)
*(D::Diagonal, V::Vector) = D.diag .* V
*(A::Matrix, D::Diagonal) = scale(A,D.diag)
*(D::Diagonal, A::Matrix) = scale(D.diag,A)

\(Da::Diagonal, Db::Diagonal) = Diagonal(Db.diag ./ Da.diag )
/(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag ./ Db.diag )
function A_ldiv_B!(D::Diagonal, v::Vector)
function A_ldiv_B!{T}(D::Diagonal{T}, v::AbstractVector{T})
length(v)==length(D.diag) || throw(DimensionMismatch(""))
for i=1:length(D.diag)
d = D.diag[i]
d==0 && throw(SingularException())
v[i] /= d
d==zero(T) && throw(SingularException(i))
v[i] *= inv(d)
end
v
end
function /{TA<:Number,TD<:Number}(A::Matrix{TA}, D::Diagonal{TD})
m, n = size(A)
n==length(D.diag) || throw(DimensionMismatch(""))
(m == 0 || n == 0) && return A
C = Array(typeof(one(TD)/one(TA)),size(A))
for j = 1:n
dj = D.diag[j]
dj==0 && throw(SingularException())
for i = 1:m
C[i,j] = A[i,j] / dj
end
function A_ldiv_B!{T}(D::Diagonal{T}, V::AbstractMatrix{T})
size(V,1)==length(D.diag) || throw(DimensionMismatch(""))
for i=1:length(D.diag)
d = D.diag[i]
d==zero(T) && throw(SingularException(i))
V[i,:] *= inv(d)
end
return C
V
end

conj(D::Diagonal) = Diagonal(conj(D.diag))
transpose(D::Diagonal) = D
ctranspose(D::Diagonal) = conj(D)

diag(D::Diagonal) = D.diag
trace(D::Diagonal) = sum(D.diag)
det(D::Diagonal) = prod(D.diag)
logdet{T<:Real}(D::Diagonal{T}) = sum(log(D.diag))
function logdet{T<:Complex}(D::Diagonal{T}) #Make sure branch cut is correct
x = sum(log(D.diag))
-pi<imag(x)<pi ? x : real(x)+(mod2pi(imag(x)+pi)-pi)*im
end
function \{TD<:Number,TA<:Number}(D::Diagonal{TD}, A::Matrix{TA})
m, n = size(A)
# identity matrices via eye(Diagonal{type},n)
eye{T}(::Type{Diagonal{T}}, n::Int) = Diagonal(ones(T,n))

expm(D::Diagonal) = Diagonal(exp(D.diag))
sqrtm(D::Diagonal) = Diagonal(sqrt(D.diag))

#Linear solver
function \{TD<:Number,TA<:Number}(D::Diagonal{TD}, A::AbstractArray{TA,1})
m, n = size(A,2)==1 ? (size(A,1),1) : size(A)
m==length(D.diag) || throw(DimensionMismatch(""))
(m == 0 || n == 0) && return A
C = Array(typeof(one(TD)/one(TA)),size(A))
Expand All @@ -67,29 +87,31 @@ function \{TD<:Number,TA<:Number}(D::Diagonal{TD}, A::Matrix{TA})
end
return C
end
\(Da::Diagonal, Db::Diagonal) = Diagonal(Db.diag ./ Da.diag)

conj(D::Diagonal) = Diagonal(conj(D.diag))
transpose(D::Diagonal) = D
ctranspose(D::Diagonal) = conj(D)

diag(D::Diagonal) = D.diag
det(D::Diagonal) = prod(D.diag)
logdet(D::Diagonal) = sum(log(D.diag))
function inv{T<:BlasFloat}(D::Diagonal{T})
function inv{T}(D::Diagonal{T})
Di = similar(D.diag)
for i = 1:length(D.diag)
D.diag[i] == 0 && throw(SingularException(i))
Di[i] = 1 / D.diag[i]
D.diag[i]==zero(T) && throw(SingularException(i))
Di[i]=inv(D.diag[i])
end
Diagonal(Di)
end
inv(D::Diagonal) = inv(Diagonal(float(D.diag)))

svdvals(D::Diagonal) = sort(D.diag, rev = true)
eigvals(D::Diagonal) = sort(D.diag)

expm(D::Diagonal) = Diagonal(exp(D.diag))
sqrtm(D::Diagonal) = Diagonal(sqrt(D.diag))
#Eigensystem
eigvals{T<:Number}(D::Diagonal{T}) = D.diag
eigvals(D::Diagonal) = [eigvals(x) for x in D.diag] #For block matrices, etc.
eigvecs(D::Diagonal) = eye(D)
eigfact(D::Diagonal) = Eigen(eigvals(D), eigvecs(D))

# identity matrices via eye(Diagonal{type},n)
eye{T}(::Type{Diagonal{T}}, n::Int) = Diagonal(ones(T,n))
#Singular system
svdvals(D::Diagonal) = sort(D.diag, rev = true)
function svdfact(D::Diagonal, thin=true)
S = abs(D.diag)
piv = sortperm(S, rev=true)
U = full(Diagonal(D.diag./S))
Up= hcat([U[:,i] for i=1:length(D.diag)][piv]...)
V = eye(D)
Vp= hcat([V[:,i] for i=1:length(D.diag)][piv]...)
SVD(Up, S[piv], Vp')
end
43 changes: 34 additions & 9 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,15 +462,6 @@ for elty in (Float32, Float64, Complex64, Complex128, Int)
@test_approx_eq iWv iFv
end

# Diagonal
D = Diagonal(d)
DM = diagm(d)
@test_approx_eq D*v DM*v
@test_approx_eq D*U DM*U
@test_approx_eq D\v DM\v
@test_approx_eq D\U DM\U
@test_approx_eq det(D) det(DM)

# Test det(A::Matrix)
# In the long run, these tests should step through Strang's
# axiomatic definition of determinants.
Expand Down Expand Up @@ -686,6 +677,40 @@ for elty in (Float32, Float64, Complex64, Complex128)
end
end

#Diagonal matrices
n=12
for relty in (Float16, Float32, Float64, BigFloat), elty in (relty, Complex{relty})
d=convert(Vector{elty}, randn(n))
v=convert(Vector{elty}, randn(n))
U=convert(Matrix{elty}, randn(n,n))
if elty <: Complex
d+=im*convert(Vector{elty}, randn(n))
v+=im*convert(Vector{elty}, randn(n))
U+=im*convert(Matrix{elty}, randn(n,n))
end
D = Diagonal(d)
DM = diagm(d)
@test_approx_eq_eps D*v DM*v n*eps(relty)*(elty<:Complex ? 2:1)
@test_approx_eq_eps D*U DM*U n^2*eps(relty)*(elty<:Complex ? 2:1)
if relty != BigFloat
@test_approx_eq_eps D\v DM\v n*eps(relty)*(elty<:Complex ? 2:1)
@test_approx_eq_eps D\U DM\U n^2*eps(relty)*(elty<:Complex ? 2:1)
end
for func in (det, trace)
@test_approx_eq_eps func(D) func(DM) n^2*eps(relty)
end
if relty != BigFloat && relty != Float16
for func in (expm,)
@test_approx_eq_eps func(D) func(DM) n^2*eps(relty)
end
end
if elty <: Complex && relty != BigFloat && relty != Float16
for func in (logdet, sqrtm)
@test_approx_eq_eps func(D) func(DM) n^2*eps(relty)
end
end
end

# Test gglse
for elty in (Float32, Float64, Complex64, Complex128)
A = convert(Array{elty, 2}, [1 1 1 1; 1 3 1 1; 1 -1 3 1; 1 1 1 3; 1 1 1 -1])
Expand Down

0 comments on commit 7cfcb49

Please sign in to comment.