diff --git a/base/sparse/csparse.jl b/base/sparse/csparse.jl index f97d8972126f2..e36183bb72946 100644 --- a/base/sparse/csparse.jl +++ b/base/sparse/csparse.jl @@ -300,64 +300,4 @@ function symperm{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, pinv::Vector{Ti}) end end (C.').' # double transpose to order the columns -end - -# Based on Direct Methods for Sparse Linear Systems, T. A. Davis, SIAM, Philadelphia, Sept. 2006. -# Section 2.7: Removing entries from a matrix -# http://www.cise.ufl.edu/research/sparse/CSparse/ -function fkeep!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, f, other) - nzorig = nnz(A) - nz = 1 - colptr = A.colptr - rowval = A.rowval - nzval = A.nzval - @inbounds for j = 1:A.n - p = colptr[j] # record current position - colptr[j] = nz # set new position - while p < colptr[j+1] - if f(rowval[p], j, nzval[p], other) - nzval[nz] = nzval[p] - rowval[nz] = rowval[p] - nz += 1 - end - p += 1 - end - end - colptr[A.n + 1] = nz - nz -= 1 - if nz < nzorig - resize!(nzval, nz) - resize!(rowval, nz) - end - A -end - - -immutable DropTolFun <: Func{4} end -(::DropTolFun)(i,j,x,other) = abs(x)>other -immutable DropZerosFun <: Func{4} end -(::DropZerosFun)(i,j,x,other) = x!=0 -immutable TriuFun <: Func{4} end -(::TriuFun)(i,j,x,other) = j>=i + other -immutable TrilFun <: Func{4} end -(::TrilFun)(i,j,x,other) = i>=j - other - -droptol!(A::SparseMatrixCSC, tol) = fkeep!(A, DropTolFun(), tol) -dropzeros!(A::SparseMatrixCSC) = fkeep!(A, DropZerosFun(), nothing) -dropzeros(A::SparseMatrixCSC) = dropzeros!(copy(A)) - -function triu!(A::SparseMatrixCSC, k::Integer=0) - m,n = size(A) - if (k > 0 && k > n) || (k < 0 && -k > m) - throw(BoundsError()) - end - fkeep!(A, TriuFun(), k) -end - -function tril!(A::SparseMatrixCSC, k::Integer=0) - m,n = size(A) - if (k > 0 && k > n) || (k < 0 && -k > m) - throw(BoundsError()) - end - fkeep!(A, TrilFun(), k) -end +end \ No newline at end of file diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index 21fab807e582d..61f62b5b421e0 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -479,6 +479,93 @@ transpose{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}) = qftranspose(A, 1:A.n, Base.IdFun() ctranspose{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}) = qftranspose(A, 1:A.n, Base.ConjFun()) "See `qftranspose`" ftranspose{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, f) = qftranspose(A, 1:A.n, f) + +## fkeep! and children tril!, triu!, droptol!, dropzeros[!] + +""" + fkeep!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, f, other, trim::Bool = true) + +Keep elements of `A` for which test `f` returns `true`. `f`'s signature should be + + f{Tv,Ti}(i::Ti, j::Ti, x::Tv, other::Any) -> Bool + +where `i` and `j` are an element's row and column indices, `x` is the element's value, +and `other` is passed in from the call to `fkeep!`. This method makes a single sweep +through `A`, requiring `O(A.n, nnz(A))`-time and no space beyond that passed in. If `trim` +is `true`, this method trims `A.rowval` and `A.nzval` to length `nnz(A)` after dropping +elements. + +Performance note: As of January 2016, `f` should be a functor for this method to perform +well. This caveat may disappear when the work in `jb/functions` lands. +""" +function fkeep!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, f, other, trim::Bool = true) + An = A.n + Acolptr = A.colptr + Arowval = A.rowval + Anzval = A.nzval + + # Sweep through columns, rewriting kept elements in their new positions + # and updating the column pointers accordingly as we go. + Awritepos = 1 + oldAcolptrAj = 1 + @inbounds for Aj in 1:An + for Ak in oldAcolptrAj:(Acolptr[Aj+1]-1) + Ai = Arowval[Ak] + Ax = Anzval[Ak] + # If this element should be kept, rewrite in new position + if f(Ai, Aj, Ax, other) + if Awritepos != Ak + Arowval[Awritepos] = Ai + Anzval[Awritepos] = Ax + end + Awritepos += 1 + end + end + oldAcolptrAj = Acolptr[Aj+1] + Acolptr[Aj+1] = Awritepos + end + + # Trim A's storage if necessary and desired + if trim + Annz = Acolptr[end] - 1 + if length(Arowval) != Annz + resize!(Arowval, Annz) + end + if length(Anzval) != Annz + resize!(Anzval, Annz) + end + end + + A +end + +immutable TrilFunc <: Base.Func{4} end +immutable TriuFunc <: Base.Func{4} end +(::TrilFunc){Tv,Ti}(i::Ti, j::Ti, x::Tv, k::Integer) = i + k >= j +(::TriuFunc){Tv,Ti}(i::Ti, j::Ti, x::Tv, k::Integer) = j >= i + k +function tril!(A::SparseMatrixCSC, k::Integer = 0, trim::Bool = true) + if k > A.n-1 || k < 1-A.m + throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($(A.m),$(A.n))")) + end + fkeep!(A, TrilFunc(), k, trim) +end +function triu!(A::SparseMatrixCSC, k::Integer = 0, trim::Bool = true) + if k > A.n-1 || k < 1-A.m + throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($(A.m),$(A.n))")) + end + fkeep!(A, TriuFunc(), k, trim) +end + +immutable DroptolFunc <: Base.Func{4} end +(::DroptolFunc){Tv,Ti}(i::Ti, j::Ti, x::Tv, tol::Real) = abs(x) > tol +droptol!(A::SparseMatrixCSC, tol, trim::Bool = true) = fkeep!(A, DroptolFunc(), tol, trim) + +immutable DropzerosFunc <: Base.Func{4} end +(::DropzerosFunc){Tv,Ti}(i::Ti, j::Ti, x::Tv, other) = x != 0 +dropzeros!(A::SparseMatrixCSC, trim::Bool = true) = fkeep!(A, DropzerosFunc(), Void, trim) +dropzeros(A::SparseMatrixCSC, trim::Bool = true) = dropzeros!(copy(A), trim) + + ## Find methods function find(S::SparseMatrixCSC) diff --git a/test/sparsedir/sparse.jl b/test/sparsedir/sparse.jl index c958380a4282d..61daeccb638df 100644 --- a/test/sparsedir/sparse.jl +++ b/test/sparsedir/sparse.jl @@ -978,6 +978,12 @@ perm = randperm(10) # droptol @test Base.droptol!(A,0.01).colptr == [1,1,1,2,2,3,4,6,6,7,9] +@test isequal(Base.droptol!(sparse([1], [1], [1]), 1), SparseMatrixCSC(1,1,Int[1,1],Int[],Int[])) + +# dropzeros +A = sparse([1 2 3; 4 5 6; 7 8 9]) +A.nzval[2] = A.nzval[6] = A.nzval[7] = 0 +@test Base.dropzeros!(A).colptr == [1, 3, 5, 7] #trace @test_throws DimensionMismatch trace(sparse(ones(5,6))) @@ -1000,6 +1006,13 @@ AF = full(A) @test_throws BoundsError tril(A,-6) @test_throws BoundsError triu(A,6) @test_throws BoundsError triu(A,-6) +@test_throws ArgumentError tril!(sparse([1,2,3], [1,2,3], [1,2,3], 3, 4), 4) +@test_throws ArgumentError tril!(sparse([1,2,3], [1,2,3], [1,2,3], 3, 4), -3) +@test_throws ArgumentError triu!(sparse([1,2,3], [1,2,3], [1,2,3], 3, 4), 4) +@test_throws ArgumentError triu!(sparse([1,2,3], [1,2,3], [1,2,3], 3, 4), -3) + +# fkeep trim option +@test isequal(length(tril!(sparse([1,2,3], [1,2,3], [1,2,3], 3, 4), -1).rowval), 0) # test norm