Skip to content

Commit

Permalink
Make similar for sparse vectors more consistent and comprehensive. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Sacha0 authored and fredrikekre committed Oct 24, 2017
1 parent d3afdfd commit bbcace2
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 12 deletions.
39 changes: 34 additions & 5 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,40 @@ function nonzeroinds(x::SparseColumnView)
return y
end

similar(x::SparseVector, Tv::Type=eltype(x)) = SparseVector(x.n, copy(x.nzind), Vector{Tv}(length(x.nzval)))
function similar(x::SparseVector, ::Type{Tv}, ::Type{Ti}) where {Tv,Ti}
return SparseVector(x.n, copy!(similar(x.nzind, Ti), x.nzind), copy!(similar(x.nzval, Tv), x.nzval))
end
similar(x::SparseVector, ::Type{T}, D::Dims) where {T} = spzeros(T, D...)

## similar
#
# parent method for similar that preserves stored-entry structure (for when new and old dims match)
_sparsesimilar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}) where {TvNew,TiNew} =
SparseVector(S.n, copy!(similar(S.nzind, TiNew), S.nzind), similar(S.nzval, TvNew))
# parent method for similar that preserves nothing (for when old and new dims differ, and new is 1d)
_sparsesimilar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, dims::Dims{1}) where {TvNew,TiNew} =
SparseVector(dims..., similar(S.nzind, TiNew, 0), similar(S.nzval, TvNew, 0))
# parent method for similar that preserves storage space (for old and new dims differ, and new is 2d)
_sparsesimilar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, dims::Dims{2}) where {TvNew,TiNew} =
SparseMatrixCSC(dims..., ones(TiNew, last(dims)+1), similar(S.nzind, TiNew), similar(S.nzval, TvNew))
# The following methods hook into the AbstractArray similar hierarchy. The first method
# covers similar(A[, Tv]) calls, which preserve stored-entry structure, and the latter
# methods cover similar(A[, Tv], shape...) calls, which preserve nothing if the dims
# specify a SparseVector result and storage space if the dims specify a SparseMatrixCSC result.
similar(S::SparseVector{<:Any,Ti}, ::Type{TvNew}) where {Ti,TvNew} =
_sparsesimilar(S, TvNew, Ti)
similar(S::SparseVector{<:Any,Ti}, ::Type{TvNew}, dims::Union{Dims{1},Dims{2}}) where {Ti,TvNew} =
_sparsesimilar(S, TvNew, Ti, dims)
# The following methods cover similar(A, Tv, Ti[, shape...]) calls, which specify the
# result's index type in addition to its entry type, and aren't covered by the hooks above.
# The calls without shape again preserve stored-entry structure, whereas those with
# one-dimensional shape preserve nothing, and those with two-dimensional shape
# preserve storage space.
similar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}) where{TvNew,TiNew} =
_sparsesimilar(S, TvNew, TiNew)
similar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, dims::Union{Dims{1},Dims{2}}) where {TvNew,TiNew} =
_sparsesimilar(S, TvNew, TiNew, dims)
similar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, m::Integer) where {TvNew,TiNew} =
_sparsesimilar(S, TvNew, TiNew, (m,))
similar(S::SparseVector, ::Type{TvNew}, ::Type{TiNew}, m::Integer, n::Integer) where {TvNew,TiNew} =
_sparsesimilar(S, TvNew, TiNew, (m, n))


### Construct empty sparse vector

Expand Down
71 changes: 64 additions & 7 deletions test/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1135,13 +1135,6 @@ end

mutable struct t20488 end

@testset "similar" begin
x = sparsevec(rand(3) .+ 0.1)
@test length(similar(x, t20488).nzval) == 3
@test typeof(similar(x, Float32, Int32)) == SparseVector{Float32, Int32}
@test typeof(similar(x, Float32)) == SparseVector{Float32, Int}
end

@testset "show" begin
io = IOBuffer()
show(io, MIME"text/plain"(), sparsevec(Int64[1], [1.0]))
Expand All @@ -1165,3 +1158,67 @@ end
@test areequal(spzv ./ 0.0, spzv ./ zeroh, sparsevec(zv ./ 0.0))
@test areequal(0.0 .\ spzv, zeroh .\ spzv, sparsevec(0.0 .\ zv))
end

@testset "similar for SparseVector" begin
A = SparseVector(10, Int[1, 3, 5, 7], Float64[1.0, 3.0, 5.0, 7.0])
# test similar without specifications (preserves stored-entry structure)
simA = similar(A)
@test typeof(simA) == typeof(A)
@test size(simA) == size(A)
@test simA.nzind == A.nzind
@test length(simA.nzval) == length(A.nzval)
# test similar with entry type specification (preserves stored-entry structure)
simA = similar(A, Float32)
@test typeof(simA) == SparseVector{Float32,eltype(A.nzind)}
@test size(simA) == size(A)
@test simA.nzind == A.nzind
@test length(simA.nzval) == length(A.nzval)
# test similar with entry and index type specification (preserves stored-entry structure)
simA = similar(A, Float32, Int8)
@test typeof(simA) == SparseVector{Float32,Int8}
@test size(simA) == size(A)
@test simA.nzind == A.nzind
@test length(simA.nzval) == length(A.nzval)
# test similar with Dims{1} specification (preserves nothing)
simA = similar(A, (6,))
@test typeof(simA) == typeof(A)
@test size(simA) == (6,)
@test length(simA.nzind) == 0
@test length(simA.nzval) == 0
# test similar with entry type and Dims{1} specification (preserves nothing)
simA = similar(A, Float32, (6,))
@test typeof(simA) == SparseVector{Float32,eltype(A.nzind)}
@test size(simA) == (6,)
@test length(simA.nzind) == 0
@test length(simA.nzval) == 0
# test similar with entry type, index type, and Dims{1} specification (preserves nothing)
simA = similar(A, Float32, Int8, (6,))
@test typeof(simA) == SparseVector{Float32,Int8}
@test size(simA) == (6,)
@test length(simA.nzind) == 0
@test length(simA.nzval) == 0
# test entry points to similar with entry type, index type, and non-Dims shape specification
@test similar(A, Float32, Int8, 6, 6) == similar(A, Float32, Int8, (6, 6))
@test similar(A, Float32, Int8, 6) == similar(A, Float32, Int8, (6,))
# test similar with Dims{2} specification (preserves storage space only, not stored-entry structure)
simA = similar(A, (6,6))
@test typeof(simA) == SparseMatrixCSC{eltype(A.nzval),eltype(A.nzind)}
@test size(simA) == (6,6)
@test simA.colptr == ones(eltype(A.nzind), 6+1)
@test length(simA.rowval) == length(A.nzind)
@test length(simA.nzval) == length(A.nzval)
# test similar with entry type and Dims{2} specification (preserves storage space only)
simA = similar(A, Float32, (6,6))
@test typeof(simA) == SparseMatrixCSC{Float32,eltype(A.nzind)}
@test size(simA) == (6,6)
@test simA.colptr == ones(eltype(A.nzind), 6+1)
@test length(simA.rowval) == length(A.nzind)
@test length(simA.nzval) == length(A.nzval)
# test similar with entry type, index type, and Dims{2} specification (preserves storage space only)
simA = similar(A, Float32, Int8, (6,6))
@test typeof(simA) == SparseMatrixCSC{Float32, Int8}
@test size(simA) == (6,6)
@test simA.colptr == ones(eltype(A.nzind), 6+1)
@test length(simA.rowval) == length(A.nzind)
@test length(simA.nzval) == length(A.nzval)
end

0 comments on commit bbcace2

Please sign in to comment.