Skip to content

Commit

Permalink
implement circshift! for SparseMatrixCS
Browse files Browse the repository at this point in the history
  • Loading branch information
abraunst committed Dec 13, 2018
1 parent 560e829 commit b48e876
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
2 changes: 1 addition & 1 deletion stdlib/SparseArrays/src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
vcat, hcat, hvcat, cat, imag, argmax, kron, length, log, log1p, max, min,
maximum, minimum, one, promote_eltype, real, reshape, rot180,
rotl90, rotr90, round, setindex!, similar, size, transpose,
vec, permute!, map, map!, Array, diff
vec, permute!, map, map!, Array, diff, circshift!, circshift

using Random: GLOBAL_RNG, AbstractRNG, randsubseq, randsubseq!

Expand Down
10 changes: 10 additions & 0 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3506,3 +3506,13 @@ end
(+)(A::SparseMatrixCSC, J::UniformScaling) = A + sparse(J, size(A)...)
(-)(A::SparseMatrixCSC, J::UniformScaling) = A - sparse(J, size(A)...)
(-)(J::UniformScaling, A::SparseMatrixCSC) = sparse(J, size(A)...) - A

## circular shift

function circshift!(O::SparseMatrixCSC, X::SparseMatrixCSC, (r,c)::Base.DimsInteger{2})
I, J, V = findnz(X)
O .= sparse(mod1.(I .+ r, X.m), mod1.(J .+ c, X.n), V, X.m, X.n)
end

circshift!(O::SparseMatrixCSC, X::SparseMatrixCSC, (r,)::Base.DimsInteger{1}) = circshift!(O, X, (r,0))
circshift!(O::SparseMatrixCSC, X::SparseMatrixCSC, r::Real) = circshift!(O, X, (Integer(r),0))
12 changes: 12 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2359,4 +2359,16 @@ end
@test one(A) isa SparseMatrixCSC{Int}
end

@testset "circshift" begin
A = sprand(40, 30, 0.3)
@test nnz(A) == nnz(circshift(A,(1,2)))
@test circshift(A, (3,4)) == circshift(Matrix(A), (3,4))
@test circshift(A, (7,-4)) == circshift(Matrix(A), (7,-4))
@test circshift(A, 14) == circshift(Matrix(A), 14)
O = similar(A)
circshift!(O, A, (11,13))
@test nnz(O) == nnz(A)
@test O == circshift(A, (11,13))
end

end # module

0 comments on commit b48e876

Please sign in to comment.