Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check that sparse matrix is valid before constructing a CHOLMOD.Sparse #20464

Merged
merged 1 commit into from
Feb 14, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 38 additions & 16 deletions base/sparse/cholmod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -858,34 +858,45 @@ convert(::Type{Dense}, A::Sparse) = sparse_to_dense(A)

# This constructior assumes zero based colptr and rowval
function (::Type{Sparse}){Tv<:VTypes}(m::Integer, n::Integer,
colptr::Vector{SuiteSparse_long}, rowval::Vector{SuiteSparse_long},
colptr0::Vector{SuiteSparse_long}, rowval0::Vector{SuiteSparse_long},
nzval::Vector{Tv}, stype)
# check if columns are sorted
# checks
## length of input
if length(colptr0) <= n
throw(ArgumentError("length of colptr0 must be at least n + 1 = $(n + 1) but was $(length(colptr0))"))
end
if colptr0[n + 1] > length(rowval0)
throw(ArgumentError("length of rowval0 is $(length(rowval0)) but value of colptr0 requires length to be at least $(colptr0[n + 1])"))
end
if colptr0[n + 1] > length(nzval)
throw(ArgumentError("length of nzval is $(length(nzval)) but value of colptr0 requires length to be at least $(colptr0[n + 1])"))
end
## columns are sorted
iss = true
for i = 2:length(colptr)
if !issorted(view(rowval, colptr[i - 1] + 1:colptr[i]))
for i = 2:length(colptr0)
if !issorted(view(rowval0, colptr0[i - 1] + 1:colptr0[i]))
iss = false
break
end
end

o = allocate_sparse(m, n, length(nzval), iss, true, stype, Tv)
o = allocate_sparse(m, n, colptr0[n + 1], iss, true, stype, Tv)
s = unsafe_load(o.p)

unsafe_copy!(s.p, pointer(colptr), length(colptr))
unsafe_copy!(s.i, pointer(rowval), length(rowval))
unsafe_copy!(s.x, pointer(nzval), length(nzval))
unsafe_copy!(s.p, pointer(colptr0), n + 1)
unsafe_copy!(s.i, pointer(rowval0), colptr0[n + 1])
unsafe_copy!(s.x, pointer(nzval) , colptr0[n + 1])

@isok check_sparse(o)

return o
end

function (::Type{Sparse})(m::Integer, n::Integer,
colptr::Vector{SuiteSparse_long},
rowval::Vector{SuiteSparse_long},
colptr0::Vector{SuiteSparse_long},
rowval0::Vector{SuiteSparse_long},
nzval::Vector{<:VTypes})
o = Sparse(m, n, colptr, rowval, nzval, 0)
o = Sparse(m, n, colptr0, rowval0, nzval, 0)

# sort indices
sort!(o)
Expand All @@ -898,15 +909,26 @@ function (::Type{Sparse})(m::Integer, n::Integer,
end

function (::Type{Sparse}){Tv<:VTypes}(A::SparseMatrixCSC{Tv,SuiteSparse_long}, stype::Integer)
o = allocate_sparse(A.m, A.n, length(A.nzval), true, true, stype, Tv)
## Check length of input. This should never fail but see #20024
if length(A.colptr) <= A.n
throw(ArgumentError("length of colptr must be at least size(A,2) + 1 = $(A.n + 1) but was $(length(A.colptr))"))
end
if nnz(A) > length(A.rowval)
throw(ArgumentError("length of rowval is $(length(A.rowval)) but value of colptr requires length to be at least $(nnz(A))"))
end
if nnz(A) > length(A.nzval)
throw(ArgumentError("length of nzval is $(length(A.nzval)) but value of colptr requires length to be at least $(nnz(A))"))
end

o = allocate_sparse(A.m, A.n, nnz(A), true, true, stype, Tv)
s = unsafe_load(o.p)
for i = 1:length(A.colptr)
for i = 1:(A.n + 1)
unsafe_store!(s.p, A.colptr[i] - 1, i)
end
for i = 1:length(A.rowval)
for i = 1:nnz(A)
unsafe_store!(s.i, A.rowval[i] - 1, i)
end
unsafe_copy!(s.x, pointer(A.nzval), length(A.nzval))
unsafe_copy!(s.x, pointer(A.nzval), nnz(A))

@isok check_sparse(o)

Expand Down Expand Up @@ -1248,7 +1270,7 @@ function getLd!(S::SparseMatrixCSC)
d = Array{eltype(S)}(size(S, 1))
fill!(d, 0)
col = 1
for k = 1:length(S.nzval)
for k = 1:nnz(S)
while k >= S.colptr[col+1]
col += 1
end
Expand Down
2 changes: 1 addition & 1 deletion base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ julia> nnz(A)
3
```
"""
nnz(S::SparseMatrixCSC) = Int(S.colptr[end]-1)
nnz(S::SparseMatrixCSC) = Int(S.colptr[S.n + 1]-1)
Copy link
Contributor

@tkelman tkelman Feb 7, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the distinction here could use a more direct test that doesn't go through cholmod-specific methods

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added some tests.

countnz(S::SparseMatrixCSC) = countnz(S.nzval)
count(S::SparseMatrixCSC) = count(S.nzval)

Expand Down
10 changes: 10 additions & 0 deletions test/sparse/cholmod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -692,3 +692,13 @@ let A = sprandn(10, 10, 0.1)
end
end

@testset "Check inputs to Sparse. Related to #20024" for A in (
SparseMatrixCSC(2, 2, [1, 2], CHOLMOD.SuiteSparse_long[], Float64[]),
SparseMatrixCSC(2, 2, [1, 2, 3], CHOLMOD.SuiteSparse_long[1], Float64[]),
SparseMatrixCSC(2, 2, [1, 2, 3], CHOLMOD.SuiteSparse_long[], Float64[1.0]),
SparseMatrixCSC(2, 2, [1, 2, 3], CHOLMOD.SuiteSparse_long[1], Float64[1.0]))

@test_throws ArgumentError CHOLMOD.Sparse(size(A)..., A.colptr - 1, A.rowval - 1, A.nzval)
@test_throws ArgumentError CHOLMOD.Sparse(A)
end

13 changes: 13 additions & 0 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1738,3 +1738,16 @@ end
@test length(c.rowval) == 9
@test length(c.nzval) == 9
end

@testset "check buffers" for n in 1:3
colptr = [1,2,3,4]
rowval = [1,2,3]
nzval1 = ones(0)
nzval2 = ones(3)
A = SparseMatrixCSC(n, n, colptr, rowval, nzval1)
@test nnz(A) == n
@test_throws BoundsError A[n,n]
A = SparseMatrixCSC(n, n, colptr, rowval, nzval2)
@test nnz(A) == n
@test A == eye(n)
end