From 99efc9132e0cecbf10522484581fa67a2a8ccef4 Mon Sep 17 00:00:00 2001 From: Stephan Hilb Date: Thu, 20 Dec 2018 00:40:45 +0100 Subject: [PATCH] stdlib/SparseArrays: fix scalar setindex! for vector eltype (#29331) fix #29034 --- stdlib/SparseArrays/src/sparsematrix.jl | 9 +++++++-- stdlib/SparseArrays/test/sparse.jl | 6 ++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/stdlib/SparseArrays/src/sparsematrix.jl b/stdlib/SparseArrays/src/sparsematrix.jl index 1832edfea1c22..ae964af2cbc5f 100644 --- a/stdlib/SparseArrays/src/sparsematrix.jl +++ b/stdlib/SparseArrays/src/sparsematrix.jl @@ -2370,7 +2370,11 @@ getindex(A::SparseMatrixCSC, I::AbstractVector{<:Integer}, J::AbstractVector{Boo getindex(A::SparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVector{<:Integer}) = A[findall(I),J] ## setindex! -function setindex!(A::SparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integer) where {Tv,Ti<:Integer} + +# dispatch helper for #29034 +setindex!(A::SparseMatrixCSC, _v, _i::Integer, _j::Integer) = _setindex_scalar!(A, _v, _i, _j) + +function _setindex_scalar!(A::SparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integer) where {Tv,Ti<:Integer} v = convert(Tv, _v) i = convert(Ti, _i) j = convert(Ti, _j) @@ -2565,7 +2569,8 @@ end _to_same_csc(::SparseMatrixCSC{Tv, Ti}, V::AbstractMatrix, I...) where {Tv,Ti} = convert(SparseMatrixCSC{Tv,Ti}, V) _to_same_csc(::SparseMatrixCSC{Tv, Ti}, V::AbstractVector, I...) where {Tv,Ti} = convert(SparseMatrixCSC{Tv,Ti}, reshape(V, map(length, I))) -setindex!(A::SparseMatrixCSC{Tv}, B::AbstractVecOrMat, I::Integer, J::Integer) where {Tv} = setindex!(A, convert(Tv, B), I, J) +setindex!(A::SparseMatrixCSC{Tv}, B::AbstractVecOrMat, I::Integer, J::Integer) where {Tv} = _setindex_scalar!(A, B, I, J) + function setindex!(A::SparseMatrixCSC{Tv,Ti}, V::AbstractVecOrMat, Ix::Union{Integer, AbstractVector{<:Integer}, Colon}, Jx::Union{Integer, AbstractVector{<:Integer}, Colon}) where {Tv,Ti<:Integer} @assert !has_offset_axes(A, V, Ix, Jx) (I, J) = Base.ensure_indexable(to_indices(A, (Ix, Jx))) diff --git a/stdlib/SparseArrays/test/sparse.jl b/stdlib/SparseArrays/test/sparse.jl index 019b979fdccd7..16decca0debd3 100644 --- a/stdlib/SparseArrays/test/sparse.jl +++ b/stdlib/SparseArrays/test/sparse.jl @@ -2027,6 +2027,12 @@ end @test nnz(A) == 1 end +@testset "setindex with vector eltype (#29034)" begin + A = sparse([1], [1], [Vector{Float64}(undef, 3)], 3, 3) + A[1,1] = [1.0, 2.0, 3.0] + @test A[1,1] == [1.0, 2.0, 3.0] +end + @testset "show" begin io = IOBuffer() show(io, MIME"text/plain"(), sparse(Int64[1], Int64[1], [1.0]))