From 1d4e0d028638ceebcb0ba407f819ec2f99dde73a Mon Sep 17 00:00:00 2001 From: Kristoffer Carlsson Date: Wed, 13 Jul 2016 13:38:45 -0700 Subject: [PATCH] Revise setindex!(A::SparseMatrixCSC, x, I::AbstractVector) such that it no longer expunges stored entries on zero assignment. Also add tests. --- base/sparse/sparsematrix.jl | 21 +++++++++------------ test/sparsedir/sparse.jl | 14 ++++++++++++++ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index 949d54b7e76f8d..087e1c6714b24b 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -2937,7 +2937,7 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval; szA = size(A) colptrB = colptrA; rowvalB = rowvalA; nzvalB = nzvalA - nadd = ndel = 0 + nadd = 0 bidx = aidx = 1 S = issorted(I) ? (1:n) : sortperm(I) @@ -2963,8 +2963,8 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto r2 = Int(colptrA[col+1] - 1) # copy from last position till current column - if (nadd > 0) || (ndel > 0) - colptrB[(lastcol+1):col] = colptrA[(lastcol+1):col] .+ (nadd - ndel) + if (nadd > 0) + colptrB[(lastcol+1):col] = colptrA[(lastcol+1):col] .+ nadd copylen = r1 - aidx if copylen > 0 copy!(rowvalB, bidx, rowvalA, aidx, copylen) @@ -2981,7 +2981,7 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto if r1 <= r2 copylen = searchsortedfirst(rowvalA, row, r1, r2, Forward) - r1 if (copylen > 0) - if (nadd > 0) || (ndel > 0) + if (nadd > 0) copy!(rowvalB, bidx, rowvalA, r1, copylen) copy!(nzvalB, bidx, nzvalA, r1, copylen) end @@ -2992,12 +2992,13 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto end # 0: no change, 1: update, 2: delete, 3: add new - mode = ((r1 <= r2) && (rowvalA[r1] == row)) ? ((v == 0) ? 2 : 1) : ((v == 0) ? 0 : 3) + mode = ((r1 <= r2) && (rowvalA[r1] == row)) ? 1 : ((v == 0) ? 0 : 3) - if (mode > 1) && (nadd == 0) && (ndel == 0) + if (mode > 1) && (nadd == 0) # copy storage to take changes colptrA = copy(colptrB) memreq = (x == 0) ? 0 : n + # see comment/TODO for same statement in preceding logical setindex! method rowvalA = copy(rowvalB) nzvalA = copy(nzvalB) resize!(rowvalB, length(rowvalA)+memreq) @@ -3009,10 +3010,6 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto bidx += 1 aidx += 1 r1 += 1 - elseif mode == 2 - r1 += 1 - aidx += 1 - ndel += 1 elseif mode == 3 rowvalB[bidx] = row nzvalB[bidx] = v @@ -3022,8 +3019,8 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto end # copy the rest - @inbounds if (nadd > 0) || (ndel > 0) - colptrB[(lastcol+1):end] = colptrA[(lastcol+1):end] .+ (nadd - ndel) + @inbounds if (nadd > 0) + colptrB[(lastcol+1):end] = colptrA[(lastcol+1):end] .+ nadd r1 = colptrA[end]-1 copylen = r1 - aidx + 1 if copylen > 0 diff --git a/test/sparsedir/sparse.jl b/test/sparsedir/sparse.jl index 33a00dbb1c10ae..47e94324ba2dac 100644 --- a/test/sparsedir/sparse.jl +++ b/test/sparsedir/sparse.jl @@ -699,6 +699,20 @@ let A = speye(Int, 5), I=1:10, X=reshape([trues(10); falses(15)],5,5) @test A[I] == A[X] == [1,0,0,0,0,0,1,0,0,0] A[I] = [1:10;] @test A[I] == A[X] == collect(1:10) + A[I] = zeros(Int, 10) + @test nnz(A) == 13 + @test countnz(A) == 3 + @test A[I] == A[X] == zeros(Int, 10) + c = collect(11:20); c[1] = c[3] = 0 + A[I] = c + @test nnz(A) == 13 + @test countnz(A) == 11 + @test A[I] == A[X] == c + A = speye(Int, 5) + A[I] = c + @test nnz(A) == 12 + @test countnz(A) == 11 + @test A[I] == A[X] == c end let S = sprand(50, 30, 0.5, x->round(Int,rand(x)*100)), I = sprand(Bool, 50, 30, 0.2)