Skip to content

Commit

Permalink
Fix -, conj, and conj! for sparse matrices with invalid entries…
Browse files Browse the repository at this point in the history
… in `nzval` (#31187)
  • Loading branch information
martinholters authored and KristofferC committed Apr 20, 2019
1 parent 11b8f59 commit c1e1824
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
18 changes: 14 additions & 4 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1558,12 +1558,22 @@ sparse(s::UniformScaling, dims::Dims{2}) = SparseMatrixCSC(s, dims)
sparse(s::UniformScaling, m::Integer, n::Integer) = sparse(s, Dims((m, n)))

# TODO: More appropriate location?
conj!(A::SparseMatrixCSC) = (@inbounds broadcast!(conj, A.nzval, A.nzval); A)
(-)(A::SparseMatrixCSC) = SparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), map(-, A.nzval))
function conj!(A::SparseMatrixCSC)
map!(conj, nzvalview(A), nzvalview(A))
return A
end
function (-)(A::SparseMatrixCSC)
nzval = similar(A.nzval)
map!(-, view(nzval, 1:nnz(A)), nzvalview(A))
return SparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), nzval)
end

# the rest of real, conj, imag are handled correctly via AbstractArray methods
conj(A::SparseMatrixCSC{<:Complex}) =
SparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), conj(A.nzval))
function conj(A::SparseMatrixCSC{<:Complex})
nzval = similar(A.nzval)
map!(conj, view(nzval, 1:nnz(A)), nzvalview(A))
return SparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), nzval)
end
imag(A::SparseMatrixCSC{Tv,Ti}) where {Tv<:Real,Ti} = spzeros(Tv, Ti, A.m, A.n)

## Binary arithmetic and boolean operators
Expand Down
12 changes: 12 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2313,4 +2313,16 @@ end
@test m2.module == SparseArrays
end

@testset "unary operations on matrices where length(nzval)>nnz" begin
# this should create a sparse matrix with length(nzval)>nnz
A = SparseMatrixCSC(Complex{BigInt}[1+im 2+2im]')'[1:1, 2:2]
# ...ensure it does! If necessary, the test needs to be updated to use
# another mechanism to create a suitable A.
@assert length(A.nzval) > nnz(A)
@test -A == fill(-2-2im, 1, 1)
@test conj(A) == fill(2-2im, 1, 1)
conj!(A)
@test A == fill(2-2im, 1, 1)
end

end # module

0 comments on commit c1e1824

Please sign in to comment.