From 8ee59bc3430de10b86a5c5ce0e2d553be0bc90d0 Mon Sep 17 00:00:00 2001 From: Alfredo Braunstein Date: Tue, 25 Dec 2018 14:06:40 +0100 Subject: [PATCH] faster circshift! for SparseMatrixCSC (#30317) * implement circshift! for SparseMatrixCSC * factor helper function shifter!, implement efficient circshift! for SparseVector * add some @inbounds for improved performance * remove allocations completely, giving a large improvement for small matrices * some renaming to avoid polluting the module namespace * remove useless reallocation and fix bug with different in/out types, better tests * avoid action if iszero(r) and/or iszero(c), move sparse vector shifting helpers to sparsevector.jl * Make shift amounts deterministic in tests, move sparse vector tests into sparsevector.jl * comment fix * for some reason, copy!(a::SparseVector, b::SparseVector) does not work (cherry picked from commit 94993e910b25f0cdd02f1c2e2d0dbd5695f43b6f) --- stdlib/SparseArrays/src/SparseArrays.jl | 2 +- stdlib/SparseArrays/src/sparsematrix.jl | 43 ++++++++++++++++++++++++ stdlib/SparseArrays/src/sparsevector.jl | 39 +++++++++++++++++++++ stdlib/SparseArrays/test/sparse.jl | 27 +++++++++++++++ stdlib/SparseArrays/test/sparsevector.jl | 22 ++++++++++++ 5 files changed, 132 insertions(+), 1 deletion(-) diff --git a/stdlib/SparseArrays/src/SparseArrays.jl b/stdlib/SparseArrays/src/SparseArrays.jl index 724b9865782a6..e122f9fa97790 100644 --- a/stdlib/SparseArrays/src/SparseArrays.jl +++ b/stdlib/SparseArrays/src/SparseArrays.jl @@ -27,7 +27,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh, vcat, hcat, hvcat, cat, imag, argmax, kron, length, log, log1p, max, min, maximum, minimum, one, promote_eltype, real, reshape, rot180, rotl90, rotr90, round, setindex!, similar, size, transpose, - vec, permute!, map, map!, Array, diff + vec, permute!, map, map!, Array, diff, circshift!, circshift using Random: GLOBAL_RNG, AbstractRNG, randsubseq, randsubseq! diff --git a/stdlib/SparseArrays/src/sparsematrix.jl b/stdlib/SparseArrays/src/sparsematrix.jl index 6b5c9b2585375..6a7ef13b507c8 100644 --- a/stdlib/SparseArrays/src/sparsematrix.jl +++ b/stdlib/SparseArrays/src/sparsematrix.jl @@ -3503,3 +3503,46 @@ end (+)(A::SparseMatrixCSC, J::UniformScaling) = A + sparse(J, size(A)...) (-)(A::SparseMatrixCSC, J::UniformScaling) = A - sparse(J, size(A)...) (-)(J::UniformScaling, A::SparseMatrixCSC) = sparse(J, size(A)...) - A + +## circular shift + +function circshift!(O::SparseMatrixCSC, X::SparseMatrixCSC, (r,c)::Base.DimsInteger{2}) + nnz = length(X.nzval) + + iszero(nnz) && return copy!(O, X) + + ##### column shift + c = mod(c, X.n) + if iszero(c) + copy!(O, X) + else + ##### readjust output + resize!(O.colptr, X.n + 1) + resize!(O.rowval, nnz) + resize!(O.nzval, nnz) + O.colptr[X.n + 1] = nnz + 1 + + # exchange left and right blocks + nleft = X.colptr[X.n - c + 1] - 1 + nright = nnz - nleft + @inbounds for i=c+1:X.n + O.colptr[i] = X.colptr[i-c] + nright + end + @inbounds for i=1:c + O.colptr[i] = X.colptr[X.n - c + i] - nleft + end + # rotate rowval and nzval by the right number of elements + circshift!(O.rowval, X.rowval, (nright,)) + circshift!(O.nzval, X.nzval, (nright,)) + end + ##### row shift + r = mod(r, X.m) + iszero(r) && return O + @inbounds for i=1:O.n + subvector_shifter!(O.rowval, O.nzval, O.colptr[i], O.colptr[i+1]-1, O.m, r) + end + return O +end + +circshift!(O::SparseMatrixCSC, X::SparseMatrixCSC, (r,)::Base.DimsInteger{1}) = circshift!(O, X, (r,0)) +circshift!(O::SparseMatrixCSC, X::SparseMatrixCSC, r::Real) = circshift!(O, X, (Integer(r),0)) diff --git a/stdlib/SparseArrays/src/sparsevector.jl b/stdlib/SparseArrays/src/sparsevector.jl index e7f0e1a2db26a..12c90991e78f2 100644 --- a/stdlib/SparseArrays/src/sparsevector.jl +++ b/stdlib/SparseArrays/src/sparsevector.jl @@ -1974,3 +1974,42 @@ function fill!(A::Union{SparseVector, SparseMatrixCSC}, x) end return A end + + + +# in-place swaps (dense) blocks start:split and split+1:fin in col +function _swap!(col::AbstractVector, start::Integer, fin::Integer, split::Integer) + split == fin && return + reverse!(col, start, split) + reverse!(col, split + 1, fin) + reverse!(col, start, fin) + return +end + + +# in-place shifts a sparse subvector by r. Used also by sparsematrix.jl +function subvector_shifter!(R::AbstractVector, V::AbstractVector, start::Integer, fin::Integer, m::Integer, r::Integer) + split = fin + @inbounds for j = start:fin + # shift positions ... + R[j] += r + if R[j] <= m + split = j + else + R[j] -= m + end + end + # ...but rowval should be sorted within columns + _swap!(R, start, fin, split) + _swap!(V, start, fin, split) +end + + +function circshift!(O::SparseVector, X::SparseVector, (r,)::Base.DimsInteger{1}) + O .= X + subvector_shifter!(O.nzind, O.nzval, 1, length(O.nzind), O.n, mod(r, X.n)) + return O +end + + +circshift!(O::SparseVector, X::SparseVector, r::Real,) = circshift!(O, X, (Integer(r),)) diff --git a/stdlib/SparseArrays/test/sparse.jl b/stdlib/SparseArrays/test/sparse.jl index 163f3b9271f1b..d6e95f1ddb9a9 100644 --- a/stdlib/SparseArrays/test/sparse.jl +++ b/stdlib/SparseArrays/test/sparse.jl @@ -2353,4 +2353,31 @@ end @test success(pipeline(cmd; stdout=stdout, stderr=stderr)) end +@testset "circshift" begin + m,n = 17,15 + A = sprand(m, n, 0.5) + for rshift in (-1, 0, 1, 10), cshift in (-1, 0, 1, 10) + shifts = (rshift, cshift) + # using dense circshift to compare + B = circshift(Matrix(A), shifts) + # sparse circshift + C = circshift(A, shifts) + @test C == B + # sparse circshift should not add structural zeros + @test nnz(C) == nnz(A) + # test circshift! + D = similar(A) + circshift!(D, A, shifts) + @test D == B + @test nnz(D) == nnz(A) + # test different in/out types + A2 = floor.(100A) + E1 = spzeros(Int64, m, n) + E2 = spzeros(Int64, m, n) + circshift!(E1, A2, shifts) + circshift!(E2, Matrix(A2), shifts) + @test E1 == E2 + end +end + end # module diff --git a/stdlib/SparseArrays/test/sparsevector.jl b/stdlib/SparseArrays/test/sparsevector.jl index 6061b21a6721c..8d0f7b7dc1b7e 100644 --- a/stdlib/SparseArrays/test/sparsevector.jl +++ b/stdlib/SparseArrays/test/sparsevector.jl @@ -1260,4 +1260,26 @@ end end end +@testset "SparseVector circshift" begin + n = 100 + v = sprand(n, 0.5) + for shift in (0,-1,1,5,-7,n+10) + x = circshift(Vector(v), shift) + w = circshift(v, shift) + @test nnz(v) == nnz(w) + @test w == x + # test circshift! + v1 = similar(v) + circshift!(v1, v, shift) + @test v1 == x + # test different in/out types + y1 = spzeros(Int64, n) + y2 = spzeros(Int64, n) + v2 = floor.(100v) + circshift!(y1, v2, shift) + circshift!(y2, Vector(v2), shift) + @test y1 == y2 + end +end + end # module