Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store info in Cholesky type #21976

Merged
merged 3 commits into from
May 30, 2017
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 50 additions & 29 deletions base/linalg/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
# through the Hermitian and Symmetric views or exact symmetric or Hermitian elements which
# is checked for and an error is thrown if the check fails.

# The internal structure is as follows
# - _chol! returns the factor and info without checking positive definiteness
# - chol/chol! returns the factor and checks for positive definiteness
# - cholfact/cholfact! returns Cholesky with checking positive definiteness

# FixMe? The dispatch below seems overly complicated. One simplification could be to
# merge the two Cholesky types into one. It would remove the need for Val completely but
# the cost would be extra unnecessary/unused fields for the unpivoted Cholesky and runtime
Expand All @@ -27,9 +32,12 @@
struct Cholesky{T,S<:AbstractMatrix} <: Factorization{T}
factors::S
uplo::Char
info::BlasInt
end
Cholesky{T}(A::AbstractMatrix{T}, uplo::Symbol) = Cholesky{T,typeof(A)}(A, char_uplo(uplo))
Cholesky{T}(A::AbstractMatrix{T}, uplo::Char) = Cholesky{T,typeof(A)}(A, uplo)
Cholesky{T}(A::AbstractMatrix{T}, uplo::Symbol, info::BlasInt) =
Cholesky{T,typeof(A)}(A, char_uplo(uplo), info)
Cholesky{T}(A::AbstractMatrix{T}, uplo::Char, info::BlasInt) =
Cholesky{T,typeof(A)}(A, uplo, info)

struct CholeskyPivoted{T,S<:AbstractMatrix} <: Factorization{T}
factors::S
Expand All @@ -49,11 +57,11 @@ end
## BLAS/LAPACK element types
function _chol!(A::StridedMatrix{<:BlasFloat}, ::Type{UpperTriangular})
C, info = LAPACK.potrf!('U', A)
return @assertposdef UpperTriangular(C) info
return UpperTriangular(C), info
end
function _chol!(A::StridedMatrix{<:BlasFloat}, ::Type{LowerTriangular})
C, info = LAPACK.potrf!('L', A)
return @assertposdef LowerTriangular(C) info
return LowerTriangular(C), info
end

## Non BLAS/LAPACK element types (generic)
Expand All @@ -64,7 +72,10 @@ function _chol!(A::AbstractMatrix, ::Type{UpperTriangular})
for i = 1:k - 1
A[k,k] -= A[i,k]'A[i,k]
end
Akk = _chol!(A[k,k], UpperTriangular)
Akk, info = _chol!(A[k,k], UpperTriangular)
if info != 0
return UpperTriangular(A), info
end
A[k,k] = Akk
AkkInv = inv(Akk')
for j = k + 1:n
Expand All @@ -75,7 +86,7 @@ function _chol!(A::AbstractMatrix, ::Type{UpperTriangular})
end
end
end
return UpperTriangular(A)
return UpperTriangular(A), convert(BlasInt, 0) # TODO: If we get here, do we know A is pos. def?
end
function _chol!(A::AbstractMatrix, ::Type{LowerTriangular})
n = checksquare(A)
Expand All @@ -84,7 +95,10 @@ function _chol!(A::AbstractMatrix, ::Type{LowerTriangular})
for i = 1:k - 1
A[k,k] -= A[k,i]*A[k,i]'
end
Akk = _chol!(A[k,k], LowerTriangular)
Akk, info = _chol!(A[k,k], LowerTriangular)
if info != 0
return LowerTriangular(A), info
end
A[k,k] = Akk
AkkInv = inv(Akk)
for j = 1:k
Expand All @@ -99,30 +113,33 @@ function _chol!(A::AbstractMatrix, ::Type{LowerTriangular})
end
end
end
return LowerTriangular(A)
return LowerTriangular(A), convert(BlasInt, 0) # TODO: If we get here, do we know A is pos. def?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd think so. What could be wrong?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay thanks. Just wanted to make sure, will remove the comment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was paranoid so confirmed with experiments:

for N in (10, 100, 1000), i in 1:1000
    A = rand(N, N); B = A'A
    for M in (A, B)
        C, info = invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{LowerTriangular}}, copy(M), LowerTriangular)
        infoblas = cholfact(Hermitian(M, :L)).info
        info == infoblas == 0 || (info > 0 && infoblas > 0) || error()
    end
    for M in (A, B)
        C, info = invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{UpperTriangular}}, copy(M), UpperTriangular)
        infoblas = cholfact(Hermitian(M, :U)).info
        info == infoblas == 0 || (info > 0 && infoblas > 0) || error()
    end
end

end

## Numbers
function _chol!(x::Number, uplo)
rx = real(x)
if rx != abs(x)
throw(ArgumentError("x must be positive semidefinite"))
end
rxr = sqrt(rx)
convert(promote_type(typeof(x), typeof(rxr)), rxr)
rxr = sqrt(abs(rx))
rval = convert(promote_type(typeof(x), typeof(rxr)), rxr)
rx == abs(x) ? (rval, convert(BlasInt, 0)) : (rval, convert(BlasInt, 1))
end

chol!(x::Number, uplo) = ((C, info) = _chol!(x, uplo); @assertposdef C info)

non_hermitian_error(f) = throw(ArgumentError("matrix is not symmetric/" *
"Hermitian. This error can be avoided by calling $f(Hermitian(A)) " *
"which will ignore either the upper or lower triangle of the matrix."))

# chol!. Destructive methods for computing Cholesky factor of real symmetric or Hermitian
# matrix
chol!(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}) =
_chol!(A.uplo == 'U' ? A.data : LinAlg.copytri!(A.data, 'L', true), UpperTriangular)
function chol!(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix})
C, info = _chol!(A.uplo == 'U' ? A.data : LinAlg.copytri!(A.data, 'L', true), UpperTriangular)
@assertposdef C info
end
function chol!(A::StridedMatrix)
ishermitian(A) || non_hermitian_error("chol!")
return _chol!(A, UpperTriangular)
C, info = _chol!(A, UpperTriangular)
@assertposdef C info
end


Expand Down Expand Up @@ -184,7 +201,7 @@ julia> chol(16)
4.0
```
"""
chol(x::Number, args...) = _chol!(x, nothing)
chol(x::Number, args...) = ((C, info) = _chol!(x, nothing); @assertposdef C info)



Expand All @@ -193,9 +210,11 @@ chol(x::Number, args...) = _chol!(x, nothing)
## No pivoting
function cholfact!(A::RealHermSymComplexHerm, ::Type{Val{false}})
if A.uplo == 'U'
Cholesky(_chol!(A.data, UpperTriangular).data, 'U')
CU, info = _chol!(A.data, UpperTriangular)
Cholesky(CU.data, 'U', info)
else
Cholesky(_chol!(A.data, LowerTriangular).data, 'L')
CL, info = _chol!(A.data, LowerTriangular)
Cholesky(CL.data, 'L', info)
end
end

Expand Down Expand Up @@ -354,14 +373,15 @@ end

## Number
function cholfact(x::Number, uplo::Symbol=:U)
xf = fill(chol(x), 1, 1)
Cholesky(xf, uplo)
C, info = _chol!(x, uplo)
xf = fill(C, 1, 1)
Cholesky(xf, uplo, info)
end


function convert(::Type{Cholesky{T}}, C::Cholesky) where T
Cnew = convert(AbstractMatrix{T}, C.factors)
Cholesky{T, typeof(Cnew)}(Cnew, C.uplo)
Cholesky{T, typeof(Cnew)}(Cnew, C.uplo, C.info)
end
convert(::Type{Factorization{T}}, C::Cholesky{T}) where {T} = C
convert(::Type{Factorization{T}}, C::Cholesky) where {T} = convert(Cholesky{T}, C)
Expand All @@ -386,7 +406,7 @@ convert(::Type{Matrix}, F::CholeskyPivoted) = convert(Array, convert(AbstractArr
convert(::Type{Array}, F::CholeskyPivoted) = convert(Matrix, F)
full(F::CholeskyPivoted) = convert(AbstractArray, F)

copy(C::Cholesky) = Cholesky(copy(C.factors), C.uplo)
copy(C::Cholesky) = Cholesky(copy(C.factors), C.uplo, C.info)
copy(C::CholeskyPivoted) = CholeskyPivoted(copy(C.factors), C.uplo, C.piv, C.rank, C.tol, C.info)

size(C::Union{Cholesky, CholeskyPivoted}) = size(C.factors)
Expand Down Expand Up @@ -417,7 +437,7 @@ show(io::IO, C::Cholesky{<:Any,<:AbstractMatrix}) =
(println(io, "$(typeof(C)) with factor:");show(io,C[:UL]))

A_ldiv_B!(C::Cholesky{T,<:AbstractMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} =
LAPACK.potrs!(C.uplo, C.factors, B)
@assertposdef LAPACK.potrs!(C.uplo, C.factors, B) C.info

function A_ldiv_B!(C::Cholesky{<:Any,<:AbstractMatrix}, B::StridedVecOrMat)
if C.uplo == 'L'
Expand Down Expand Up @@ -465,16 +485,18 @@ function A_ldiv_B!(C::CholeskyPivoted, B::StridedMatrix)
end

function det(C::Cholesky)
C.info == 0 || throw(PosDefException(C.info))
dd = one(real(eltype(C)))
for i in 1:size(C.factors,1)
@inbounds for i in 1:size(C.factors,1)
dd *= real(C.factors[i,i])^2
end
dd
end

function logdet(C::Cholesky)
C.info == 0 || throw(PosDefException(C.info))
dd = zero(real(eltype(C)))
for i in 1:size(C.factors,1)
@inbounds for i in 1:size(C.factors,1)
dd += log(real(C.factors[i,i]))
end
dd + dd # instead of 2.0dd which can change the type
Expand Down Expand Up @@ -505,10 +527,9 @@ function logdet(C::CholeskyPivoted)
end

inv!(C::Cholesky{<:BlasFloat,<:StridedMatrix}) =
copytri!(LAPACK.potri!(C.uplo, C.factors), C.uplo, true)
@assertposdef copytri!(LAPACK.potri!(C.uplo, C.factors), C.uplo, true) C.info

inv(C::Cholesky{<:BlasFloat,<:StridedMatrix}) =
inv!(copy(C))
inv(C::Cholesky{<:BlasFloat,<:StridedMatrix}) = inv!(copy(C))

function inv(C::CholeskyPivoted)
chkfullrank(C)
Expand Down
8 changes: 5 additions & 3 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -770,10 +770,12 @@ function factorize(A::StridedMatrix{T}) where T
return UpperTriangular(A)
end
if herm
try
return cholfact(A)
cf = cholfact(A)
if cf.info == 0
return cf
else
return factorize(Hermitian(A))
end
return factorize(Hermitian(A))
end
if sym
return factorize(Symmetric(A))
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2963,10 +2963,10 @@ for (posv, potrf, potri, potrs, pstrf, elty, rtyp) in
(Ptr{UInt8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}),
&uplo, &size(A,1), A, &lda, info)
chkargsok(info[])
#info[1]>0 means the leading minor of order info[i] is not positive definite
#info[] > 0 means the leading minor of order info[] is not positive definite
#ordinarily, throw Exception here, but return error code here
#this simplifies isposdef! and factorize
return A, info[]
return A, info[] # info stored in Cholesky
end

# SUBROUTINE DPOTRI( UPLO, N, A, LDA, INFO )
Expand Down
24 changes: 17 additions & 7 deletions test/linalg/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ debug = false

using Base.Test

using Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted
using Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted, PosDefException

n = 10

Expand Down Expand Up @@ -60,7 +60,7 @@ for eltya in (Float32, Float64, Complex64, Complex128, BigFloat, Int)

apos = apd[1,1] # test chol(x::Number), needs x>0
@test all(x -> x ≈ √apos, cholfact(apos).factors)
@test_throws ArgumentError chol(-one(eltya))
@test_throws PosDefException chol(-one(eltya))

if eltya <: Real
capds = cholfact(apds)
Expand Down Expand Up @@ -194,10 +194,9 @@ end

begin
# Cholesky factor of Matrix with non-commutative elements, here 2x2-matrices

X = Matrix{Float64}[0.1*rand(2,2) for i in 1:3, j = 1:3]
L = full(Base.LinAlg._chol!(X*X', LowerTriangular))
U = full(Base.LinAlg._chol!(X*X', UpperTriangular))
L = full(Base.LinAlg._chol!(X*X', LowerTriangular)[1])
U = full(Base.LinAlg._chol!(X*X', UpperTriangular)[1])
XX = full(X*X')

@test sum(sum(norm, L*L' - XX)) < eps()
Expand All @@ -212,8 +211,8 @@ for elty in (Float32, Float64, Complex{Float32}, Complex{Float64})
A = randn(5,5)
end
A = convert(Matrix{elty}, A'A)
@test full(cholfact(A)[:L]) ≈ full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{LowerTriangular}}, copy(A), LowerTriangular))
@test full(cholfact(A)[:U]) ≈ full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{UpperTriangular}}, copy(A), UpperTriangular))
@test full(cholfact(A)[:L]) ≈ full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{LowerTriangular}}, copy(A), LowerTriangular)[1])
@test full(cholfact(A)[:U]) ≈ full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{UpperTriangular}}, copy(A), UpperTriangular)[1])
end

# Test up- and downdates
Expand Down Expand Up @@ -272,3 +271,14 @@ end

# Fail for non-BLAS element types
@test_throws ArgumentError cholfact!(Hermitian(rand(Float16, 5,5)), Val{true})

@testset "throw for non positive matrix" begin
for T in (Float32, Float64, Complex64, Complex128)
A = T[1 2; 2 1]; B = T[1, 1]
C = cholfact(A)
@show typeof(A), typeof(B), typeof(C.factors)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

debugging output left in?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, sry

@test_throws PosDefException C\B
@test_throws PosDefException det(C)
@test_throws PosDefException logdet(C)
end
end