From bf2c29f1eee66c632d34581b35b7559b298d14ed Mon Sep 17 00:00:00 2001 From: Pranav Thulasiram Bhat Date: Tue, 16 Aug 2016 22:52:19 +0530 Subject: [PATCH] Define find and findnz for SparseVector (#18049) --- base/sparse/sparsevector.jl | 51 +++++++++++++++++++++++++++++++++- test/sparsedir/sparsevector.jl | 7 +++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/base/sparse/sparsevector.jl b/base/sparse/sparsevector.jl index cd002a2574bdb..6ecc55e9fe574 100644 --- a/base/sparse/sparsevector.jl +++ b/base/sparse/sparsevector.jl @@ -2,7 +2,7 @@ ### Common definitions -import Base: scalarmax, scalarmin, sort +import Base: scalarmax, scalarmin, sort, find, findnz ### The SparseVector @@ -560,6 +560,55 @@ function getindex{Tv}(A::SparseMatrixCSC{Tv}, I::AbstractVector) SparseVector(n, rowvalB, nzvalB) end +function find{Tv,Ti}(x::SparseVector{Tv,Ti}) + numnz = nnz(x) + I = Array(Ti, numnz) + + nzind = x.nzind + nzval = x.nzval + + count = 1 + @inbounds for i = 1 : numnz + if nzval[i] != 0 + I[count] = nzind[i] + count += 1 + end + end + + count -= 1 + if numnz != count + deleteat!(I, (count+1):numnz) + end + + return I +end + +function findnz{Tv,Ti}(x::SparseVector{Tv,Ti}) + numnz = nnz(x) + + I = Array(Ti, numnz) + V = Array(Tv, numnz) + + nzind = x.nzind + nzval = x.nzval + + count = 1 + @inbounds for i = 1 : numnz + if nzval[i] != 0 + I[count] = nzind[i] + V[count] = nzval[i] + count += 1 + end + end + + count -= 1 + if numnz != count + deleteat!(I, (count+1):numnz) + deleteat!(V, (count+1):numnz) + end + + return (I, V) +end ### Generic functions operating on AbstractSparseVector diff --git a/test/sparsedir/sparsevector.jl b/test/sparsedir/sparsevector.jl index b9079c95807be..fc91b9ba3f11d 100644 --- a/test/sparsedir/sparsevector.jl +++ b/test/sparsedir/sparsevector.jl @@ -225,6 +225,13 @@ let x = SparseVector(10, [2, 7, 9], [2.0, 7.0, 9.0]) @test Base.SparseArrays.dropstored!(x, 5) == SparseVector(10, [7, 9], [7.0, 9.0]) end +# find and findnz tests +@test find(spv_x1) == find(x1_full) +@test findnz(spv_x1) == (find(x1_full), filter(x->x!=0, x1_full)) +let xc = SparseVector(8, [2, 3, 5], [1.25, 0, -0.75]), fc = full(xc) + @test find(xc) == find(fc) + @test findnz(xc) == ([2, 5], [1.25, -0.75]) +end ### Array manipulation