Skip to content

Commit

Permalink
Specialize findmax/findmin on SparseVector, fixes JuliaLang#42823 (Ju…
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored and LilithHafner committed Feb 22, 2022
1 parent 9572117 commit edd78f1
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 3 deletions.
6 changes: 4 additions & 2 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2172,8 +2172,10 @@ end
_isless_fm(a, b) = b == b && ( a != a || isless(a, b) )
_isgreater_fm(a, b) = b == b && ( a != a || isless(b, a) )

findmin(A::AbstractSparseMatrixCSC{Tv,Ti}, region) where {Tv,Ti} = _findr(_isless_fm, A, region, Tv)
findmax(A::AbstractSparseMatrixCSC{Tv,Ti}, region) where {Tv,Ti} = _findr(_isgreater_fm, A, region, Tv)
findmin(A::AbstractSparseMatrixCSC{Tv}, region::Union{Integer,Tuple{Integer},NTuple{2,Integer}}) where {Tv} =
_findr(_isless_fm, A, region, Tv)
findmax(A::AbstractSparseMatrixCSC{Tv}, region::Union{Integer,Tuple{Integer},NTuple{2,Integer}}) where {Tv} =
_findr(_isgreater_fm, A, region, Tv)
findmin(A::AbstractSparseMatrixCSC) = (r=findmin(A,(1,2)); (r[1][1], r[2][1]))
findmax(A::AbstractSparseMatrixCSC) = (r=findmax(A,(1,2)); (r[1][1], r[2][1]))

Expand Down
21 changes: 21 additions & 0 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,27 @@ end

minimum(x::AbstractSparseVector) = minimum(identity, x)

for (fun, comp, word) in ((:findmin, :(<), "minimum"), (:findmax, :(>), "maximum"))
@eval function $fun(f, x::AbstractSparseVector{T}) where {T}
n = length(x)
n > 0 || throw(ArgumentError($word * " over empty array is not allowed"))
nzvals = nonzeros(x)
m = length(nzvals)
m == 0 && return zero(T), firstindex(x)
val, index = $fun(f, nzvals)
m == n && return val, index
nzinds = nonzeroinds(x)
zeroval = f(zero(T))
$comp(val, zeroval) && return val, nzinds[index]
# we need to find the first zero, which could be stored or implicit
# we try to avoid findfirst(iszero, x)
sindex = findfirst(iszero, nzvals) # first stored zero, if any
zindex = findfirst(i -> i < nzinds[i], eachindex(nzinds)) # first non-stored zero
index = isnothing(sindex) ? zindex : min(sindex, zindex)
return zeroval, index
end
end

norm(x::SparseVectorUnion, p::Real=2) = norm(nonzeros(x), p)

### linalg.jl
Expand Down
38 changes: 37 additions & 1 deletion stdlib/SparseArrays/test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -816,10 +816,14 @@ end
@test norm(x, Inf) == 3.5
end

@testset "maximum, minimum" begin
@testset "maximum, minimum, findmax, findmin" begin
let x = spv_x1
@test maximum(x) == 3.5
@test findmax(x) == findmax(Vector(x)) == (3.5, 6)
@test findmax(x -> -x, x) == findmax(-x) == (0.75, 5)
@test minimum(x) == -0.75
@test findmin(x) == findmin(Vector(x)) == (-0.75, 5)
@test findmin(x -> -x, x) == findmin(-x) == (-3.5, 6)
@test maximum(abs, x) == 3.5
@test minimum(abs, x) == 0.0
@test @inferred(minimum(t -> true, x)) === true
Expand All @@ -832,21 +836,51 @@ end

let x = abs.(spv_x1)
@test maximum(x) == 3.5
@test findmax(x) == findmax(Vector(x)) == (3.5, 6)
@test findmax(abs2, x) == findmax(abs2.(x)) == findmax(Vector(abs2.(x)))
@test minimum(x) == 0.0
@test findmin(x) == findmin(Vector(x)) == (0.0, 1)
@test findmin(abs2, x) == findmin(abs2.(x)) == findmin(Vector(abs2.(x)))
end

let x = -abs.(spv_x1)
@test maximum(x) == 0.0
@test findmax(x) == findmax(Vector(x)) == (0.0, 1)
@test minimum(x) == -3.5
@test findmin(x) == findmin(Vector(x)) == (-3.5, 6)
end

let x = SparseVector(3, [1, 2, 3], [-4.5, 2.5, 3.5])
@test maximum(x) == 3.5
@test findmax(x) == findmax(Vector(x)) == (3.5, 3)
@test minimum(x) == -4.5
@test findmin(x) == findmin(Vector(x)) == (-4.5, 1)
@test maximum(abs, x) == 4.5
@test minimum(abs, x) == 2.5
end

let x = SparseVector(3, [1, 2, 3], [4.5, 0.0, 3.5])
@test minimum(x) == 0.0
@test findmin(x) == findmin(Vector(x)) == (0.0, 2)
end

let x = SparseVector(3, [1, 2, 3], [-4.5, 0.0, -3.5])
@test maximum(x) == 0.0
@test findmax(x) == findmax(Vector(x)) == (0.0, 2)
end

for i in (2, 3)
let x = SparseVector(4, [1, i, 4], [4.5, 0.0, 3.5])
@test minimum(x) == 0.0
@test findmin(x) == findmin(Vector(x)) == (0.0, 2)
end

let x = SparseVector(4, [1, i, 4], [-4.5, 0.0, -3.5])
@test maximum(x) == 0.0
@test findmax(x) == findmax(Vector(x)) == (0.0, 2)
end
end

let x = spzeros(Float64, 8)
@test maximum(x) == 0.0
@test minimum(x) == 0.0
Expand All @@ -861,6 +895,8 @@ end
let x = spzeros(Float64, 0)
@test_throws ArgumentError minimum(t -> true, x)
@test_throws ArgumentError maximum(t -> true, x)
@test_throws ArgumentError findmin(x)
@test_throws ArgumentError findmax(x)
end
end

Expand Down

0 comments on commit edd78f1

Please sign in to comment.