diff --git a/base/reduce.jl b/base/reduce.jl index f7558d11b21be..f0ae3f055019f 100644 --- a/base/reduce.jl +++ b/base/reduce.jl @@ -428,6 +428,6 @@ function count(pred::Union(Function,Func{1}), a::AbstractArray) end immutable NotEqZero <: Func{1} end -call(::NotEqZero, x) = x != 0 +call(::NotEqZero, x) = x != zero(x) countnz(a) = count(NotEqZero(), a) diff --git a/base/sparse/csparse.jl b/base/sparse/csparse.jl index 4794e3c511b13..7b0b5e7312f21 100644 --- a/base/sparse/csparse.jl +++ b/base/sparse/csparse.jl @@ -29,7 +29,7 @@ function sparse{Tv,Ti<:Integer}(I::AbstractVector{Ti}, J::AbstractVector{Ti}, Rnz[1] = 1 nz = 0 for k=1:N - if V[k] != 0 + if V[k] != zero(Tv) Rnz[I[k]+1] += 1 nz += 1 end @@ -49,7 +49,7 @@ function sparse{Tv,Ti<:Integer}(I::AbstractVector{Ti}, J::AbstractVector{Ti}, ((iind > 0) && (jind > 0)) || throw(BoundsError()) p = Wj[iind] Vk = V[k] - if Vk != 0 + if Vk != zero(Tv) Wj[iind] += 1 Rx[p] = Vk Ri[p] = jind diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index 5279bf04ab0b8..a4a642f6bd4c3 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -314,7 +314,7 @@ function findn{Tv,Ti}(S::SparseMatrixCSC{Tv,Ti}) count = 1 @inbounds for col = 1 : S.n, k = S.colptr[col] : (S.colptr[col+1]-1) - if S.nzval[k] != 0 + if S.nzval[k] != zero(Tv) I[count] = S.rowval[k] J[count] = col count += 1 @@ -338,7 +338,7 @@ function findnz{Tv,Ti}(S::SparseMatrixCSC{Tv,Ti}) count = 1 @inbounds for col = 1 : S.n, k = S.colptr[col] : (S.colptr[col+1]-1) - if S.nzval[k] != 0 + if S.nzval[k] != zero(Tv) I[count] = S.rowval[k] J[count] = col V[count] = S.nzval[k] @@ -1224,7 +1224,7 @@ function setindex!{T,Ti}(A::SparseMatrixCSC{T,Ti}, v, i0::Integer, i1::Integer) v = convert(T, v) r1 = int(A.colptr[i1]) r2 = int(A.colptr[i1+1]-1) - if v == 0 #either do nothing or delete entry if it exists + if v == zero(T) #either do nothing or delete entry if it exists if r1 <= r2 r1 = searchsortedfirst(A.rowval, i0, r1, r2, Forward) if (r1 <= r2) && (A.rowval[r1] == i0) diff --git a/test/sparse.jl b/test/sparse.jl index f03efcf4a4944..dc7d0977ea53f 100644 --- a/test/sparse.jl +++ b/test/sparse.jl @@ -472,3 +472,20 @@ end # issue #8976 @test conj(sparse([1im])) == sparse(conj([1im])) @test conj!(sparse([1im])) == sparse(conj!([1im])) + +# Test proper handling of zeros for user types +immutable SpTestVal + value::Float64 +end +(==)(x::SpTestVal,y::SpTestVal) = (x.value == y.value) +Base.zero(x::SpTestVal) = SpTestVal(0) +Base.zero(::Type{SpTestVal}) = SpTestVal(0) +A = sparse([1,2,3],[1,2,3],[SpTestVal(1),SpTestVal(0),SpTestVal(3)]) +@test nnz(A) == 2 # zeros should be stripped by sparse +A[2,2] = SpTestVal(0) +@test nnz(A) == 2 # zeros should be stripped by setindex +A = SparseMatrixCSC(3,3,[1,2,3,4],[1,2,3], + [SpTestVal(1.0),SpTestVal(0.0),SpTestVal(3.0)]) +@test countnz(A) == 2 +r,c,v = findnz(A) +@test length(r) == length(c) == length(v) == 2 \ No newline at end of file