diff --git a/src/sparsematrix.jl b/src/sparsematrix.jl index 526e8d02..af69dea3 100644 --- a/src/sparsematrix.jl +++ b/src/sparsematrix.jl @@ -887,9 +887,8 @@ intermediate CSR forms and require `length(csrrowptr) >= m + 1`, `length(csrcolval) >= length(I)`, and `length(csrnzval >= length(I))`. Input array `klasttouch`, workspace for the second stage, requires `length(klasttouch) >= n`. Optional input arrays `csccolptr`, `cscrowval`, and `cscnzval` constitute storage for the -returned CSC form `S`. `csccolptr` requires `length(csccolptr) >= n + 1`. If necessary, -`cscrowval` and `cscnzval` are automatically resized to satisfy -`length(cscrowval) >= nnz(S)` and `length(cscnzval) >= nnz(S)`; hence, if `nnz(S)` is +returned CSC form `S`. If necessary, these are resized automatically to satisfy +`length(csccolptr) = n + 1`, `length(cscrowval) = nnz(S)` and `length(cscnzval) = nnz(S)`; hence, if `nnz(S)` is unknown at the outset, passing in empty vectors of the appropriate type (`Vector{Ti}()` and `Vector{Tv}()` respectively) suffices, or calling the `sparse!` method neglecting `cscrowval` and `cscnzval`. @@ -900,6 +899,7 @@ representation of the result's transpose. You may reuse the input arrays' storage (`I`, `J`, `V`) for the output arrays (`csccolptr`, `cscrowval`, `cscnzval`). For example, you may call `sparse!(I, J, V, csrrowptr, csrcolval, csrnzval, I, J, V)`. +Note that they will be resized to satisfy the conditions above. For the sake of efficiency, this method performs no argument checking beyond `1 <= I[k] <= m` and `1 <= J[k] <= n`. Use with care. Testing with `--check-bounds=yes` @@ -954,6 +954,9 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti}, end # This completes the unsorted-row, has-repeats CSR form's construction + # The output array csccolptr can now be resized safely even if aliased with I + resize!(csccolptr, n + 1) + # Sweep through the CSR form, simultaneously (1) calculating the CSC form's column # counts and storing them shifted forward by one in csccolptr; (2) detecting repeated # entries; and (3) repacking the CSR form with the repeated entries combined. @@ -998,10 +1001,13 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti}, Base.hastypemax(Ti) && (countsum <= typemax(Ti) || throw(ArgumentError("more than typemax(Ti)-1 == $(typemax(Ti)-1) entries"))) end - # Now knowing the CSC form's entry count, resize cscrowval and cscnzval if necessary + # Now knowing the CSC form's entry count, resize cscrowval and cscnzval + # Note: This is done unconditionally to appease the buffer checks in the SparseMatrixCSC + # constructor. If these checks are lifted this resizing is only needed if the + # buffers are too short. csccolptr is resized above. cscnnz = countsum - Tj(1) - length(cscrowval) < cscnnz && resize!(cscrowval, cscnnz) - length(cscnzval) < cscnnz && resize!(cscnzval, cscnnz) + resize!(cscrowval, cscnnz) + resize!(cscnzval, cscnnz) # Finally counting-sort the row and nonzero values from the CSR form into cscrowval and # cscnzval. Tracking write positions in csccolptr corrects the column pointers. diff --git a/test/sparse.jl b/test/sparse.jl index df92c69b..bbc03b4c 100644 --- a/test/sparse.jl +++ b/test/sparse.jl @@ -3398,4 +3398,79 @@ using Base: swaprows!, swapcols! end end +@testset "sparse!" begin + using SparseArrays: sparse!, getcolptr, getrowval, nonzeros + + function same_structure(A, B) + return all(getfield(A, f) == getfield(B, f) for f in (:m, :n, :colptr, :rowval)) + end + + function allocate_arrays(m, n) + N = round(Int, 0.5 * m * n) + Tv, Ti = Float64, Int + I = Ti[rand(1:m) for _ in 1:N]; I = Ti[I; I] + J = Ti[rand(1:n) for _ in 1:N]; J = Ti[J; J] + V = Tv.(I) + csrrowptr = Vector{Ti}(undef, m + 1) + csrcolval = Vector{Ti}(undef, length(I)) + csrnzval = Vector{Tv}(undef, length(I)) + klasttouch = Vector{Ti}(undef, n) + csccolptr = Vector{Ti}(undef, n + 1) + cscrowval = Vector{Ti}() + cscnzval = Vector{Tv}() + return I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr, cscrowval, cscnzval + end + + for (m, n) in ((10, 5), (5, 10), (10, 10)) + # Passing csr vectors + I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval = allocate_arrays(m, n) + S = sparse(I, J, V, m, n) + S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval) + @test S == S! + @test same_structure(S, S!) + + # Passing csr vectors + csccolptr + I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr = allocate_arrays(m, n) + S = sparse(I, J, V, m, n) + S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr) + @test S == S! + @test same_structure(S, S!) + @test getcolptr(S!) === csccolptr + + # Passing csr vectors, and csc vectors + I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr, cscrowval, cscnzval = + allocate_arrays(m, n) + S = sparse(I, J, V, m, n) + S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval, + csccolptr, cscrowval, cscnzval) + @test S == S! + @test same_structure(S, S!) + @test getcolptr(S!) === csccolptr + @test getrowval(S!) === cscrowval + @test nonzeros(S!) === cscnzval + + # Passing csr vectors, and csc vectors of insufficient lengths + I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr, cscrowval, cscnzval = + allocate_arrays(m, n) + S = sparse(I, J, V, m, n) + S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval, + resize!(csccolptr, 0), resize!(cscrowval, 0), resize!(cscnzval, 0)) + @test S == S! + @test same_structure(S, S!) + @test getcolptr(S!) === csccolptr + @test getrowval(S!) === cscrowval + @test nonzeros(S!) === cscnzval + + # Passing csr vectors, and csc vectors aliased with I, J, V + I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval = allocate_arrays(m, n) + S = sparse(I, J, V, m, n) + S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval, I, J, V) + @test S == S! + @test same_structure(S, S!) + @test getcolptr(S!) === I + @test getrowval(S!) === J + @test nonzeros(S!) === V + end +end + end # module