Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make similar for sparse vectors more consistent and comprehensive #24260

Merged
merged 1 commit into from
Oct 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,40 @@ function nonzeroinds(x::SparseColumnView)
return y
end

similar(x::SparseVector, Tv::Type=eltype(x)) = SparseVector(x.n, copy(x.nzind), Vector{Tv}(length(x.nzval)))
function similar(x::SparseVector, ::Type{Tv}, ::Type{Ti}) where {Tv,Ti}
return SparseVector(x.n, copy!(similar(x.nzind, Ti), x.nzind), copy!(similar(x.nzval, Tv), x.nzval))
end
similar(x::SparseVector, ::Type{T}, D::Dims) where {T} = spzeros(T, D...)

## similar
#
# parent method for similar that preserves stored-entry structure (for when new and old dims match)
_sparsesimilar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}) where {TvNew,TiNew} =
SparseVector(S.n, copy!(similar(S.nzind, TiNew), S.nzind), similar(S.nzval, TvNew))
# parent method for similar that preserves nothing (for when old and new dims differ, and new is 1d)
_sparsesimilar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, dims::Dims{1}) where {TvNew,TiNew} =
SparseVector(dims..., similar(S.nzind, TiNew, 0), similar(S.nzval, TvNew, 0))
# parent method for similar that preserves storage space (for old and new dims differ, and new is 2d)
_sparsesimilar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, dims::Dims{2}) where {TvNew,TiNew} =
SparseMatrixCSC(dims..., ones(TiNew, last(dims)+1), similar(S.nzind, TiNew), similar(S.nzval, TvNew))
# The following methods hook into the AbstractArray similar hierarchy. The first method
# covers similar(A[, Tv]) calls, which preserve stored-entry structure, and the latter
# methods cover similar(A[, Tv], shape...) calls, which preserve nothing if the dims
# specify a SparseVector result and storage space if the dims specify a SparseMatrixCSC result.
similar(S::SparseVector{<:Any,Ti}, ::Type{TvNew}) where {Ti,TvNew} =
_sparsesimilar(S, TvNew, Ti)
similar(S::SparseVector{<:Any,Ti}, ::Type{TvNew}, dims::Union{Dims{1},Dims{2}}) where {Ti,TvNew} =
_sparsesimilar(S, TvNew, Ti, dims)
# The following methods cover similar(A, Tv, Ti[, shape...]) calls, which specify the
# result's index type in addition to its entry type, and aren't covered by the hooks above.
# The calls without shape again preserve stored-entry structure, whereas those with
# one-dimensional shape preserve nothing, and those with two-dimensional shape
# preserve storage space.
similar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}) where{TvNew,TiNew} =
_sparsesimilar(S, TvNew, TiNew)
similar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, dims::Union{Dims{1},Dims{2}}) where {TvNew,TiNew} =
_sparsesimilar(S, TvNew, TiNew, dims)
similar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, m::Integer) where {TvNew,TiNew} =
_sparsesimilar(S, TvNew, TiNew, (m,))
similar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, m::Integer, n::Integer) where {TvNew,TiNew} =
_sparsesimilar(S, TvNew, TiNew, (m, n))


### Construct empty sparse vector

Expand Down
71 changes: 64 additions & 7 deletions test/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1144,13 +1144,6 @@ end

mutable struct t20488 end

@testset "similar" begin
x = sparsevec(rand(3) .+ 0.1)
@test length(similar(x, t20488).nzval) == 3
@test typeof(similar(x, Float32, Int32)) == SparseVector{Float32, Int32}
@test typeof(similar(x, Float32)) == SparseVector{Float32, Int}
end

@testset "show" begin
io = IOBuffer()
show(io, MIME"text/plain"(), sparsevec(Int64[1], [1.0]))
Expand All @@ -1174,3 +1167,67 @@ end
@test areequal(spzv ./ 0.0, spzv ./ zeroh, sparsevec(zv ./ 0.0))
@test areequal(0.0 .\ spzv, zeroh .\ spzv, sparsevec(0.0 .\ zv))
end

@testset "similar for SparseVector" begin
A = SparseVector(10, Int[1, 3, 5, 7], Float64[1.0, 3.0, 5.0, 7.0])
# test similar without specifications (preserves stored-entry structure)
simA = similar(A)
@test typeof(simA) == typeof(A)
@test size(simA) == size(A)
@test simA.nzind == A.nzind
@test length(simA.nzval) == length(A.nzval)
# test similar with entry type specification (preserves stored-entry structure)
simA = similar(A, Float32)
@test typeof(simA) == SparseVector{Float32,eltype(A.nzind)}
@test size(simA) == size(A)
@test simA.nzind == A.nzind
@test length(simA.nzval) == length(A.nzval)
# test similar with entry and index type specification (preserves stored-entry structure)
simA = similar(A, Float32, Int8)
@test typeof(simA) == SparseVector{Float32,Int8}
@test size(simA) == size(A)
@test simA.nzind == A.nzind
@test length(simA.nzval) == length(A.nzval)
# test similar with Dims{1} specification (preserves nothing)
simA = similar(A, (6,))
@test typeof(simA) == typeof(A)
@test size(simA) == (6,)
@test length(simA.nzind) == 0
@test length(simA.nzval) == 0
# test similar with entry type and Dims{1} specification (preserves nothing)
simA = similar(A, Float32, (6,))
@test typeof(simA) == SparseVector{Float32,eltype(A.nzind)}
@test size(simA) == (6,)
@test length(simA.nzind) == 0
@test length(simA.nzval) == 0
# test similar with entry type, index type, and Dims{1} specification (preserves nothing)
simA = similar(A, Float32, Int8, (6,))
@test typeof(simA) == SparseVector{Float32,Int8}
@test size(simA) == (6,)
@test length(simA.nzind) == 0
@test length(simA.nzval) == 0
# test entry points to similar with entry type, index type, and non-Dims shape specification
@test similar(A, Float32, Int8, 6, 6) == similar(A, Float32, Int8, (6, 6))
@test similar(A, Float32, Int8, 6) == similar(A, Float32, Int8, (6,))
# test similar with Dims{2} specification (preserves storage space only, not stored-entry structure)
simA = similar(A, (6,6))
@test typeof(simA) == SparseMatrixCSC{eltype(A.nzval),eltype(A.nzind)}
@test size(simA) == (6,6)
@test simA.colptr == ones(eltype(A.nzind), 6+1)
@test length(simA.rowval) == length(A.nzind)
@test length(simA.nzval) == length(A.nzval)
# test similar with entry type and Dims{2} specification (preserves storage space only)
simA = similar(A, Float32, (6,6))
@test typeof(simA) == SparseMatrixCSC{Float32,eltype(A.nzind)}
@test size(simA) == (6,6)
@test simA.colptr == ones(eltype(A.nzind), 6+1)
@test length(simA.rowval) == length(A.nzind)
@test length(simA.nzval) == length(A.nzval)
# test similar with entry type, index type, and Dims{2} specification (preserves storage space only)
simA = similar(A, Float32, Int8, (6,6))
@test typeof(simA) == SparseMatrixCSC{Float32, Int8}
@test size(simA) == (6,6)
@test simA.colptr == ones(eltype(A.nzind), 6+1)
@test length(simA.rowval) == length(A.nzind)
@test length(simA.nzval) == length(A.nzval)
end