Skip to content

Commit

Permalink
Revise setindex!(A::SparseMatrixCSC, x, I::AbstractVector) such that …
Browse files Browse the repository at this point in the history
…it no longer expunges stored entries on zero assignment. Also add tests.
  • Loading branch information
KristofferC authored and Sacha0 committed Jul 15, 2016
1 parent 606a8e4 commit 6d3e198
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
23 changes: 10 additions & 13 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2942,7 +2942,7 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto

colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval; szA = size(A)
colptrB = colptrA; rowvalB = rowvalA; nzvalB = nzvalA
nadd = ndel = 0
nadd = 0
bidx = aidx = 1

S = issorted(I) ? (1:n) : sortperm(I)
Expand All @@ -2968,8 +2968,8 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
r2 = Int(colptrA[col+1] - 1)

# copy from last position till current column
if (nadd > 0) || (ndel > 0)
colptrB[(lastcol+1):col] = colptrA[(lastcol+1):col] .+ (nadd - ndel)
if (nadd > 0)
colptrB[(lastcol+1):col] = colptrA[(lastcol+1):col] .+ nadd
copylen = r1 - aidx
if copylen > 0
copy!(rowvalB, bidx, rowvalA, aidx, copylen)
Expand All @@ -2986,7 +2986,7 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
if r1 <= r2
copylen = searchsortedfirst(rowvalA, row, r1, r2, Forward) - r1
if (copylen > 0)
if (nadd > 0) || (ndel > 0)
if (nadd > 0)
copy!(rowvalB, bidx, rowvalA, r1, copylen)
copy!(nzvalB, bidx, nzvalA, r1, copylen)
end
Expand All @@ -2996,13 +2996,14 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
end
end

# 0: no change, 1: update, 2: delete, 3: add new
mode = ((r1 <= r2) && (rowvalA[r1] == row)) ? ((v == 0) ? 2 : 1) : ((v == 0) ? 0 : 3)
# 0: no change, 1: update, 2: add new
mode = ((r1 <= r2) && (rowvalA[r1] == row)) ? 1 : ((v == 0) ? 0 : 2)

if (mode > 1) && (nadd == 0) && (ndel == 0)
if (mode > 1) && (nadd == 0)
# copy storage to take changes
colptrA = copy(colptrB)
memreq = (x == 0) ? 0 : n
# see comment/TODO for same statement in preceding logical setindex! method
rowvalA = copy(rowvalB)
nzvalA = copy(nzvalB)
resize!(rowvalB, length(rowvalA)+memreq)
Expand All @@ -3015,10 +3016,6 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
aidx += 1
r1 += 1
elseif mode == 2
r1 += 1
aidx += 1
ndel += 1
elseif mode == 3
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
Expand All @@ -3027,8 +3024,8 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
end

# copy the rest
@inbounds if (nadd > 0) || (ndel > 0)
colptrB[(lastcol+1):end] = colptrA[(lastcol+1):end] .+ (nadd - ndel)
@inbounds if (nadd > 0)
colptrB[(lastcol+1):end] = colptrA[(lastcol+1):end] .+ nadd
r1 = colptrA[end]-1
copylen = r1 - aidx + 1
if copylen > 0
Expand Down
14 changes: 14 additions & 0 deletions test/sparsedir/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,20 @@ let A = speye(Int, 5), I=1:10, X=reshape([trues(10); falses(15)],5,5)
@test A[I] == A[X] == [1,0,0,0,0,0,1,0,0,0]
A[I] = [1:10;]
@test A[I] == A[X] == collect(1:10)
A[I] = zeros(Int, 10)
@test nnz(A) == 13
@test countnz(A) == 3
@test A[I] == A[X] == zeros(Int, 10)
c = collect(11:20); c[1] = c[3] = 0
A[I] = c
@test nnz(A) == 13
@test countnz(A) == 11
@test A[I] == A[X] == c
A = speye(Int, 5)
A[I] = c
@test nnz(A) == 12
@test countnz(A) == 11
@test A[I] == A[X] == c
end

let S = sprand(50, 30, 0.5, x->round(Int,rand(x)*100)), I = sprand(Bool, 50, 30, 0.2)
Expand Down

0 comments on commit 6d3e198

Please sign in to comment.