Skip to content

Commit

Permalink
add == for vectors
Browse files Browse the repository at this point in the history
Closes #246
  • Loading branch information
SobhanMP committed Sep 6, 2022
1 parent dfcc48a commit 753cf98
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
39 changes: 39 additions & 0 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,45 @@ end

### Generic functions operating on AbstractSparseVector

## Explicit efficient comparisons with vectors

function ==(A::AbstractCompressedVector,
B::AbstractCompressedVector)
# Different sizes are always different
size(A) size(B) && return false
# Compare nonzero elements
i, j = 1, 1
@inbounds while i <= nnz(A) && j <= nnz(B)
if nonzeroinds(A)[i] == nonzeroinds(B)[j]
nonzeros(A)[i] == nonzeros(B)[j] || return false
i += 1
j += 1
elseif nonzeroinds(A)[i] <= nonzeroinds(B)[j]
iszero(nonzeros(A)[i]) || return false
i += 1
else # nonzeroinds(A)[i] >= nonzeroinds(B)[j]
iszero(nonzeros(B)[j]) || return false
j += 1
end
end

@inbounds for k in i:nnz(A)
iszero(nonzeros(A)[k]) || return false
end

@inbounds for k in j:nnz(B)
iszero(nonzeros(B)[k]) || return false
end

return true
end

==(A::Transpose{<:Any,<:AbstractCompressedVector},
B::Transpose{<:Any,<:AbstractCompressedVector}) = transpose(A) == transpose(B)

==(A::Adjoint{<:Any,<:AbstractCompressedVector},
B::Adjoint{<:Any,<:AbstractCompressedVector}) = adjoint(A) == adjoint(B)

### getindex

function _spgetindex(m::Int, nzind::AbstractVector{Ti}, nzval::AbstractVector{Tv}, i::Integer) where {Tv,Ti}
Expand Down
16 changes: 16 additions & 0 deletions test/sparsematrix_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -530,4 +530,20 @@ Base.transpose(x::Counting) = Counting(transpose(x.elt))
end
end


@testset "Issue #246" begin
for t in [Int, UInt8, Float64]
a = Counting.(sprand(t, 100, 0.5))
b = Counting.(sprand(t, 100, 0.5))
for m in [identity, transpose, adjoint]
ma = m(a)
mb = m(b)

resetcounter()
ma == mb
@test getcounter() <= nnz(a) + nnz(b)
end
end
end

end # module

0 comments on commit 753cf98

Please sign in to comment.