Skip to content

Commit

Permalink
Bugfix for getindex(::AbstractSparseVector,::Abstract<Vector,Array>) …
Browse files Browse the repository at this point in the history
…that makes the methods work for sparsevectors S with eltype(S.nzind) != Int. Added corresponding test (#24548)

Also fixed getindex(A::SparseMatrixCSC{Tv}, I::AbstractArray) and getindex(A::SparseMatrixCSC{Tv}, I::AbstractVector) methods so that integer indexing arrays of output have same type as input. Added corresponding tests.
  • Loading branch information
Pbellive authored and KristofferC committed Nov 14, 2017
1 parent c14adf0 commit 82ae10d
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 12 deletions.
6 changes: 3 additions & 3 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2299,7 +2299,7 @@ function getindex(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector, J::AbstractVecto
end
end

function getindex(A::SparseMatrixCSC{Tv}, I::AbstractArray) where Tv
function getindex(A::SparseMatrixCSC{Tv,Ti}, I::AbstractArray) where {Tv,Ti}
szA = size(A)
nA = szA[1]*szA[2]
colptrA = A.colptr
Expand All @@ -2310,8 +2310,8 @@ function getindex(A::SparseMatrixCSC{Tv}, I::AbstractArray) where Tv
outm = size(I,1)
outn = size(I,2)
szB = (outm, outn)
colptrB = zeros(Int, outn+1)
rowvalB = Vector{Int}(n)
colptrB = zeros(Ti, outn+1)
rowvalB = Vector{Ti}(n)
nzvalB = Vector{Tv}(n)

colB = 1
Expand Down
12 changes: 6 additions & 6 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ function getindex(A::SparseMatrixCSC{Tv}, I::AbstractUnitRange) where Tv
SparseVector(n, rowvalB, nzvalB)
end

function getindex(A::SparseMatrixCSC{Tv}, I::AbstractVector) where Tv
function getindex(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector) where {Tv,Ti}
szA = size(A)
nA = szA[1]*szA[2]
colptrA = A.colptr
Expand All @@ -649,7 +649,7 @@ function getindex(A::SparseMatrixCSC{Tv}, I::AbstractVector) where Tv

n = length(I)
nnzB = min(n, nnz(A))
rowvalB = Vector{Int}(nnzB)
rowvalB = Vector{Ti}(nnzB)
nzvalB = Vector{Tv}(nnzB)

idxB = 1
Expand Down Expand Up @@ -784,15 +784,15 @@ end

getindex(x::AbstractSparseVector, I::AbstractVector{Bool}) = x[find(I)]
getindex(x::AbstractSparseVector, I::AbstractArray{Bool}) = x[find(I)]
@inline function getindex(x::AbstractSparseVector, I::AbstractVector)
@inline function getindex(x::AbstractSparseVector{Tv,Ti}, I::AbstractVector) where {Tv,Ti}
# SparseMatrixCSC has a nicely optimized routine for this; punt
S = SparseMatrixCSC(x.n, 1, [1,length(x.nzind)+1], x.nzind, x.nzval)
S = SparseMatrixCSC(x.n, 1, Ti[1,length(x.nzind)+1], x.nzind, x.nzval)
S[I, 1]
end

function getindex(x::AbstractSparseVector, I::AbstractArray)
function getindex(x::AbstractSparseVector{Tv,Ti}, I::AbstractArray) where {Tv,Ti}
# punt to SparseMatrixCSC
S = SparseMatrixCSC(x.n, 1, [1,length(x.nzind)+1], x.nzind, x.nzval)
S = SparseMatrixCSC(x.n, 1, Ti[1,length(x.nzind)+1], x.nzind, x.nzval)
S[I]
end

Expand Down
7 changes: 7 additions & 0 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,13 @@ end
@test S1290[end] == (S1290[1] + S1290[2,2])
@test 6 == sum(diag(S1290))
@test Array(S1290)[[3,1],1] == Array(S1290[[3,1],1])

# check that indexing with an abstract array returns matrix
# with same colptr and rowval eltypes as input. Tests PR 24548
r1 = S1290[[5,9]]
r2 = S1290[[1 2;5 9]]
@test isa(r1, SparseVector{Int64,UInt8})
@test isa(r2, SparseMatrixCSC{Int64,UInt8})
# end
end

Expand Down
24 changes: 21 additions & 3 deletions test/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,25 @@ end
I = rand(1:length(x), 20)
r = x[I]
@test isa(r, SparseVector{Float64,Int})
@test all(nonzeros(r) .!= 0.0)
@test all(!iszero, nonzeros(r))
@test Array(r) == Array(x)[I]
end

# issue 24534
let x = convert(SparseVector{Float64,UInt32},sprandn(100,0.5))
I = rand(1:length(x), 20)
r = x[I]
@test isa(r, SparseVector{Float64,UInt32})
@test all(!iszero, nonzeros(r))
@test Array(r) == Array(x)[I]
end

# issue 24534
let x = convert(SparseVector{Float64,UInt32},sprandn(100,0.5))
I = rand(1:length(x), 20,1)
r = x[I]
@test isa(r, SparseMatrixCSC{Float64,UInt32})
@test all(!iszero, nonzeros(r))
@test Array(r) == Array(x)[I]
end
end
Expand All @@ -187,7 +205,7 @@ end
bI[I] = true
r = x[1,bI]
@test isa(r, SparseVector{Float64,Int})
@test all(nonzeros(r) .!= 0.0)
@test all(!iszero, nonzeros(r))
@test Array(r) == Array(x)[1,bI]
end

Expand All @@ -197,7 +215,7 @@ end
bI[I] = true
r = x[bI]
@test isa(r, SparseVector{Float64,Int})
@test all(nonzeros(r) .!= 0.0)
@test all(!iszero, nonzeros(r))
@test Array(r) == Array(x)[bI]
end
end
Expand Down

0 comments on commit 82ae10d

Please sign in to comment.