Skip to content

Commit

Permalink
Allow passing a function to extrema
Browse files Browse the repository at this point in the history
Currently `minimum` and `maximum` can accept a function argument, but
`extrema` cannot. This makes it consistent.
  • Loading branch information
ararslan committed Dec 8, 2018
1 parent a0bc8fd commit 3e09b6b
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 21 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ New language features
* An *exception stack* is maintained on each task to make exception handling more robust and enable root cause analysis using `catch_stack` ([#28878]).
* The experimental macro `Base.@locals` returns a dictionary of current local variable names
and values ([#29733]).
* The `extrema` function now accepts a function argument in the same manner as `minimum` and
`maximum` ([#TODO]).

Language changes
----------------
Expand Down
40 changes: 27 additions & 13 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1500,40 +1500,54 @@ julia> extrema(A, dims = (1,2))
(9, 15)
```
"""
extrema(A::AbstractArray; dims = :) = _extrema_dims(A, dims)
extrema(A::AbstractArray; dims = :) = _extrema_dims(identity, A, dims)

_extrema_dims(A::AbstractArray, ::Colon) = _extrema_itr(A)
"""
extrema(f, A::AbstractArray; dims) -> Array{Tuple}
Compute the minimum and maximum of `f` applied to each element in the given dimensions
of `A`.
!!! compat "Julia 1.1"
This method requires Julia 1.1 or later.
```
"""
extrema(f, A::AbstractArray; dims=:) = _extrema_dims(f, A, dims)

_extrema_dims(f, A::AbstractArray, ::Colon) = _extrema_itr(f, A)

function _extrema_dims(A::AbstractArray, dims)
function _extrema_dims(f, A::AbstractArray, dims)
sz = [size(A)...]
for d in dims
sz[d] = 1
end
B = Array{Tuple{eltype(A),eltype(A)}}(undef, sz...)
return extrema!(B, A)
T = promote_op(f, eltype(A))
B = Array{Tuple{T,T}}(undef, sz...)
return extrema!(f, B, A)
end

@noinline function extrema!(B, A)
@noinline function extrema!(f, B, A)
@assert !has_offset_axes(B, A)
sA = size(A)
sB = size(B)
for I in CartesianIndices(sB)
AI = A[I]
B[I] = (AI, AI)
fAI = f(A[I])
B[I] = (fAI, fAI)
end
Bmax = CartesianIndex(sB)
@inbounds @simd for I in CartesianIndices(sA)
J = min(Bmax,I)
BJ = B[J]
AI = A[I]
if AI < BJ[1]
B[J] = (AI, BJ[2])
elseif AI > BJ[2]
B[J] = (BJ[1], AI)
fAI = f(A[I])
if fAI < BJ[1]
B[J] = (fAI, BJ[2])
elseif fAI > BJ[2]
B[J] = (BJ[1], fAI)
end
end
return B
end
extrema!(B, A) = extrema!(identity, B, A)

# Show for pairs() with Cartesian indices. Needs to be here rather than show.jl for bootstrap order
function Base.showarg(io::IO, r::Iterators.Pairs{<:Integer, <:Any, <:Any, T}, toplevel) where T <: Union{AbstractVector, Tuple}
Expand Down
29 changes: 24 additions & 5 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,24 +440,43 @@ julia> extrema([9,pi,4.5])
(3.141592653589793, 9.0)
```
"""
extrema(itr) = _extrema_itr(itr)
extrema(itr) = _extrema_itr(identity, itr)

function _extrema_itr(itr)
"""
extrema(f, itr) -> Tuple
Compute both the minimum and maximum of `f` applied to each element in `itr` and return
them as a 2-tuple. Only one pass is made over `itr`.
!!! compat "Julia 1.1"
This method requires Julia 1.1 or later.
# Examples
```jldoctest
julia> extrema(sin, 0:π)
(0.0, 0.9092974268256817)
```
"""
extrema(f, itr) = _extrema_itr(f, itr)

function _extrema_itr(f, itr)
y = iterate(itr)
y === nothing && throw(ArgumentError("collection must be non-empty"))
(v, s) = y
vmin = vmax = v
vmin = vmax = f(v)
while true
y = iterate(itr, s)
y === nothing && break
(x, s) = y
vmax = max(x, vmax)
vmin = min(x, vmin)
fx = f(x)
vmax = max(fx, vmax)
vmin = min(fx, vmin)
end
return (vmin, vmax)
end

extrema(x::Real) = (x, x)
extrema(f, x::Real) = (y = f(x); (y, y))

## definitions providing basic traits of arithmetic operators ##

Expand Down
15 changes: 12 additions & 3 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,17 @@ prod2(itr) = invoke(prod, Tuple{Any}, itr)
@test maximum(5) == 5
@test minimum(5) == 5
@test extrema(5) == (5, 5)
@test extrema(abs2, 5) = (25, 25)

@test maximum([4, 3, 5, 2]) == 5
@test minimum([4, 3, 5, 2]) == 2
@test extrema([4, 3, 5, 2]) == (2, 5)
let x = [4,3,5,2]
@test maximum(x) == 5
@test minimum(x) == 2
@test extrema(x) == (2, 5)

@test maximum(abs2, x) == 25
@test minimum(abs2, x) == 4
@test extrema(abs2, x) == (4, 25)
end

@test isnan(maximum([NaN]))
@test isnan(minimum([NaN]))
Expand Down Expand Up @@ -211,6 +218,7 @@ prod2(itr) = invoke(prod, Tuple{Any}, itr)

@test maximum(abs2, 3:7) == 49
@test minimum(abs2, 3:7) == 9
@test extrema(abs2, 3:7) == (9, 49)

@test maximum(Int16[1]) === Int16(1)
@test maximum(Vector(Int16(1):Int16(100))) === Int16(100)
Expand All @@ -227,6 +235,7 @@ A = circshift(reshape(1:24,2,3,4), (0,1,1))
@test size(extrema(A,dims=1)) == size(maximum(A,dims=1))
@test size(extrema(A,dims=(1,2))) == size(maximum(A,dims=(1,2)))
@test size(extrema(A,dims=(1,2,3))) == size(maximum(A,dims=(1,2,3)))
@test extrema(x->div(x, 2), A, dims=(2,3)) == reshape([(0,11),(1,12)],2,1,1)

# any & all

Expand Down

0 comments on commit 3e09b6b

Please sign in to comment.