diff --git a/base/sparse/sparsevector.jl b/base/sparse/sparsevector.jl index a92e7f6eaf8ab..62f7cc61617e2 100644 --- a/base/sparse/sparsevector.jl +++ b/base/sparse/sparsevector.jl @@ -58,11 +58,40 @@ function nonzeroinds(x::SparseColumnView) return y end -similar(x::SparseVector, Tv::Type=eltype(x)) = SparseVector(x.n, copy(x.nzind), Vector{Tv}(length(x.nzval))) -function similar(x::SparseVector, ::Type{Tv}, ::Type{Ti}) where {Tv,Ti} - return SparseVector(x.n, copy!(similar(x.nzind, Ti), x.nzind), copy!(similar(x.nzval, Tv), x.nzval)) -end -similar(x::SparseVector, ::Type{T}, D::Dims) where {T} = spzeros(T, D...) + +## similar +# +# parent method for similar that preserves stored-entry structure (for when new and old dims match) +_sparsesimilar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}) where {TvNew,TiNew} = + SparseVector(S.n, copy!(similar(S.nzind, TiNew), S.nzind), similar(S.nzval, TvNew)) +# parent method for similar that preserves nothing (for when old and new dims differ, and new is 1d) +_sparsesimilar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, dims::Dims{1}) where {TvNew,TiNew} = + SparseVector(dims..., similar(S.nzind, TiNew, 0), similar(S.nzval, TvNew, 0)) +# parent method for similar that preserves storage space (for old and new dims differ, and new is 2d) +_sparsesimilar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, dims::Dims{2}) where {TvNew,TiNew} = + SparseMatrixCSC(dims..., ones(TiNew, last(dims)+1), similar(S.nzind, TiNew), similar(S.nzval, TvNew)) +# The following methods hook into the AbstractArray similar hierarchy. The first method +# covers similar(A[, Tv]) calls, which preserve stored-entry structure, and the latter +# methods cover similar(A[, Tv], shape...) calls, which preserve nothing if the dims +# specify a SparseVector result and storage space if the dims specify a SparseMatrixCSC result. +similar(S::SparseVector{<:Any,Ti}, ::Type{TvNew}) where {Ti,TvNew} = + _sparsesimilar(S, TvNew, Ti) +similar(S::SparseVector{<:Any,Ti}, ::Type{TvNew}, dims::Union{Dims{1},Dims{2}}) where {Ti,TvNew} = + _sparsesimilar(S, TvNew, Ti, dims) +# The following methods cover similar(A, Tv, Ti[, shape...]) calls, which specify the +# result's index type in addition to its entry type, and aren't covered by the hooks above. +# The calls without shape again preserve stored-entry structure, whereas those with +# one-dimensional shape preserve nothing, and those with two-dimensional shape +# preserve storage space. +similar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}) where{TvNew,TiNew} = + _sparsesimilar(S, TvNew, TiNew) +similar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, dims::Union{Dims{1},Dims{2}}) where {TvNew,TiNew} = + _sparsesimilar(S, TvNew, TiNew, dims) +similar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, m::Integer) where {TvNew,TiNew} = + _sparsesimilar(S, TvNew, TiNew, (m,)) +similar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, m::Integer, n::Integer) where {TvNew,TiNew} = + _sparsesimilar(S, TvNew, TiNew, (m, n)) + ### Construct empty sparse vector diff --git a/test/sparse/sparsevector.jl b/test/sparse/sparsevector.jl index 6935071f86355..52b520476665f 100644 --- a/test/sparse/sparsevector.jl +++ b/test/sparse/sparsevector.jl @@ -1144,13 +1144,6 @@ end mutable struct t20488 end -@testset "similar" begin - x = sparsevec(rand(3) .+ 0.1) - @test length(similar(x, t20488).nzval) == 3 - @test typeof(similar(x, Float32, Int32)) == SparseVector{Float32, Int32} - @test typeof(similar(x, Float32)) == SparseVector{Float32, Int} -end - @testset "show" begin io = IOBuffer() show(io, MIME"text/plain"(), sparsevec(Int64[1], [1.0])) @@ -1174,3 +1167,67 @@ end @test areequal(spzv ./ 0.0, spzv ./ zeroh, sparsevec(zv ./ 0.0)) @test areequal(0.0 .\ spzv, zeroh .\ spzv, sparsevec(0.0 .\ zv)) end + +@testset "similar for SparseVector" begin + A = SparseVector(10, Int[1, 3, 5, 7], Float64[1.0, 3.0, 5.0, 7.0]) + # test similar without specifications (preserves stored-entry structure) + simA = similar(A) + @test typeof(simA) == typeof(A) + @test size(simA) == size(A) + @test simA.nzind == A.nzind + @test length(simA.nzval) == length(A.nzval) + # test similar with entry type specification (preserves stored-entry structure) + simA = similar(A, Float32) + @test typeof(simA) == SparseVector{Float32,eltype(A.nzind)} + @test size(simA) == size(A) + @test simA.nzind == A.nzind + @test length(simA.nzval) == length(A.nzval) + # test similar with entry and index type specification (preserves stored-entry structure) + simA = similar(A, Float32, Int8) + @test typeof(simA) == SparseVector{Float32,Int8} + @test size(simA) == size(A) + @test simA.nzind == A.nzind + @test length(simA.nzval) == length(A.nzval) + # test similar with Dims{1} specification (preserves nothing) + simA = similar(A, (6,)) + @test typeof(simA) == typeof(A) + @test size(simA) == (6,) + @test length(simA.nzind) == 0 + @test length(simA.nzval) == 0 + # test similar with entry type and Dims{1} specification (preserves nothing) + simA = similar(A, Float32, (6,)) + @test typeof(simA) == SparseVector{Float32,eltype(A.nzind)} + @test size(simA) == (6,) + @test length(simA.nzind) == 0 + @test length(simA.nzval) == 0 + # test similar with entry type, index type, and Dims{1} specification (preserves nothing) + simA = similar(A, Float32, Int8, (6,)) + @test typeof(simA) == SparseVector{Float32,Int8} + @test size(simA) == (6,) + @test length(simA.nzind) == 0 + @test length(simA.nzval) == 0 + # test entry points to similar with entry type, index type, and non-Dims shape specification + @test similar(A, Float32, Int8, 6, 6) == similar(A, Float32, Int8, (6, 6)) + @test similar(A, Float32, Int8, 6) == similar(A, Float32, Int8, (6,)) + # test similar with Dims{2} specification (preserves storage space only, not stored-entry structure) + simA = similar(A, (6,6)) + @test typeof(simA) == SparseMatrixCSC{eltype(A.nzval),eltype(A.nzind)} + @test size(simA) == (6,6) + @test simA.colptr == ones(eltype(A.nzind), 6+1) + @test length(simA.rowval) == length(A.nzind) + @test length(simA.nzval) == length(A.nzval) + # test similar with entry type and Dims{2} specification (preserves storage space only) + simA = similar(A, Float32, (6,6)) + @test typeof(simA) == SparseMatrixCSC{Float32,eltype(A.nzind)} + @test size(simA) == (6,6) + @test simA.colptr == ones(eltype(A.nzind), 6+1) + @test length(simA.rowval) == length(A.nzind) + @test length(simA.nzval) == length(A.nzval) + # test similar with entry type, index type, and Dims{2} specification (preserves storage space only) + simA = similar(A, Float32, Int8, (6,6)) + @test typeof(simA) == SparseMatrixCSC{Float32, Int8} + @test size(simA) == (6,6) + @test simA.colptr == ones(eltype(A.nzind), 6+1) + @test length(simA.rowval) == length(A.nzind) + @test length(simA.nzval) == length(A.nzval) +end