Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make minmax faster for Float32/64 #41709

Merged
merged 7 commits into from
Jul 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -758,17 +758,34 @@ end
# Two-argument `atan` for arbitrary reals: convert both operands to floating point,
# promote them to a common type, and re-dispatch on the promoted pair.
function atan(y::Real, x::Real)
    fy, fx = promote(float(y), float(x))
    return atan(fy, fx)
end
# Catch-all for a same-type `AbstractFloat` pair: hand the operation name and the
# type to `Base.no_op_err`.
function atan(y::T, x::T) where {T<:AbstractFloat}
    return Base.no_op_err("atan", T)
end

# Branch-free maximum of two same-type floats.
# `y` is preferred when it compares greater than `x`, or when `x` has its sign bit
# set and `y` does not (this orders -0.0 below +0.0). Whichever side is preferred,
# a NaN in the other operand overrides it.
function max(x::T, y::T) where {T<:AbstractFloat}
    prefer_y = (y > x) | (signbit(y) < signbit(x))
    if_y = ifelse(isnan(x), x, y)
    if_x = ifelse(isnan(y), y, x)
    return ifelse(prefer_y, if_y, if_x)
end


# Branch-free minimum of two same-type floats.
# `y` is preferred when it compares less than `x`, or when `y` has its sign bit set
# and `x` does not (this orders -0.0 below +0.0). Whichever side is preferred, a NaN
# in the other operand overrides it.
function min(x::T, y::T) where {T<:AbstractFloat}
    prefer_y = (y < x) | (signbit(y) > signbit(x))
    if_y = ifelse(isnan(x), x, y)
    if_x = ifelse(isnan(y), y, x)
    return ifelse(prefer_y, if_y, if_x)
end
# Ordering predicate used by the float `min`/`max` fallbacks: true when `x < y`, or
# when `x` has its sign bit set and `y` does not (so -0.0 sorts before +0.0).
function _isless(x::T, y::T) where {T<:AbstractFloat}
    x < y && return true
    return signbit(x) > signbit(y)
end
# Generic minimum for same-type floats. A NaN operand is returned as-is (the first
# one wins when both are NaN); otherwise `_isless` decides which operand to return.
function min(x::T, y::T) where {T<:AbstractFloat}
    isnan(x) && return x
    isnan(y) && return y
    return _isless(x, y) ? x : y
end
# Generic maximum for same-type floats. A NaN operand is returned as-is (the first
# one wins when both are NaN); otherwise `x` is returned when `_isless(y, x)` holds
# and `y` is returned in every other case.
function max(x::T, y::T) where {T<:AbstractFloat}
    isnan(x) && return x
    isnan(y) && return y
    return _isless(y, x) ? x : y
end
# Generic `minmax` for same-type floats: the pair `(min(x, y), max(x, y))`.
function minmax(x::T, y::T) where {T<:AbstractFloat}
    lo = min(x, y)
    hi = max(x, y)
    return (lo, hi)
end

# `Float16` ordering predicate: widen both operands, subtract, and report the sign
# bit of the difference.
function _isless(x::Float16, y::Float16)
    delta = widen(x) - widen(y)
    return signbit(delta)
end

function min(x::T, y::T) where {T<:Union{Float32,Float64}}
Copy link
Contributor

@mikmoore mikmoore Jun 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the new min, max, minmax signatures be expanded to include Float16? Could either add it to the list or use the defined IEEEFloat union.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using T<:IEEEFloat everywhere that now uses
T<:Union{Float32,Float64} would be helpful. Some systems have on-chip support for Float16, so this is a win there. Systems that emulate Float16 already support min, max, and minmax, so processing there is at worst unchanged.

diff = x - y
argmin = ifelse(signbit(diff), x, y)
anynan = isnan(x)|isnan(y)
ifelse(anynan, diff, argmin)
end

# Branch-free `minmax` for same-type floats.
# If either operand is NaN, the result is that NaN duplicated (`x` takes precedence
# when both are NaN). Otherwise the operands are returned in ascending order, with
# -0.0 placed before +0.0.
function minmax(x::T, y::T) where {T<:AbstractFloat}
    x_is_nan = isnan(x)
    nan_pair = ifelse(x_is_nan, (x, x), (y, y))
    in_order = (y > x) | (signbit(x) > signbit(y))
    sorted_pair = ifelse(in_order, (x, y), (y, x))
    return ifelse(x_is_nan | isnan(y), nan_pair, sorted_pair)
end
# Branch-free maximum for `Float32`/`Float64`.
# The sign bit of `x - y` selects the larger operand; subtracting also orders the
# signed zeros (-0.0 - 0.0 is -0.0, so +0.0 is picked). If either operand is NaN the
# difference itself is NaN and is returned instead.
function max(x::T, y::T) where {T<:Union{Float32,Float64}}
    delta = x - y
    has_nan = isnan(x) | isnan(y)
    larger = ifelse(signbit(delta), y, x)
    return ifelse(has_nan, delta, larger)
end

function minmax(x::T, y::T) where {T<:Union{Float32,Float64}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this detailed definition necessary? It seems like a more generic
minmax(x::T, y::T) where {T<:Union{Float32,Float64}} = min(x, y), max(x, y)
would perform identically thanks to inlining.

Copy link
Member Author

@N5N3 N5N3 Jun 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My local benchmark does show a performance difference.

julia> a = randn(1024); b = randn(1024); z = min.(a,b); zz = minmax.(a,b);

julia> using BenchmarkTools
[ Info: Precompiling BenchmarkTools [6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf]

julia> @benchmark $zz .= minmax.($a, $b)
BenchmarkTools.Trial: 10000 samples with 198 evaluations.
 Range (min … max):  440.404 ns …   1.189 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     442.424 ns              ┊ GC (median):    0.00%        
 Time  (mean ± σ):   457.627 ns ± 45.207 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █    
  █▂▃▁▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▁
  440 ns          Histogram: frequency by time          702 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> f(x, y) = min(x, y), max(x, y)
f (generic function with 1 method)

julia> @benchmark $zz .= f.($a, $b)
BenchmarkTools.Trial: 10000 samples with 194 evaluations.
 Range (min … max):  497.423 ns …   3.276 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     498.969 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   515.506 ns ± 55.169 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▄▂ ▆▅▄▂▁                                                    ▁
  ██████████▇▆▆▅▄▃▆▆▆▅▄▅▄▄▅▃▅▅▃▃▃▂▄▃▃▃▄▃▄▄▄▄▄▄▄▄▄▄▄▅▄▅▄▅▄▅▄▄▃▄ █
  497 ns        Histogram: log(frequency) by time       759 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

Their LLVM IR differs only in instruction order. I'm not sure why that matters, though.

diff = x - y
sdiff = signbit(diff)
min, max = ifelse(sdiff, x, y), ifelse(sdiff, y, x)
anynan = isnan(x)|isnan(y)
ifelse(anynan, diff, min), ifelse(anynan, diff, max)
end

"""
ldexp(x, n)
Expand Down
27 changes: 14 additions & 13 deletions base/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import
cosh, sinh, tanh, sech, csch, coth, acosh, asinh, atanh, lerpi,
cbrt, typemax, typemin, unsafe_trunc, floatmin, floatmax, rounding,
setrounding, maxintfloat, widen, significand, frexp, tryparse, iszero,
isone, big, _string_n, decompose
isone, big, _string_n, decompose, minmax

import ..Rounding: rounding_raw, setrounding_raw

Expand Down Expand Up @@ -697,20 +697,21 @@ function log1p(x::BigFloat)
return z
end

function max(x::BigFloat, y::BigFloat)
isnan(x) && return x
isnan(y) && return y
z = BigFloat()
ccall((:mpfr_max, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Ref{BigFloat}, MPFRRoundingMode), z, x, y, ROUNDING_MODE[])
return z
# `min` and `max` on `BigFloat` are left to the generic `AbstractFloat` fallback.
# `minmax` and `_extrema_rf` get dedicated methods so the comparison is not repeated.
#
# Return `(smaller, larger)`; a NaN operand is returned in both slots, with `x`
# taking precedence when both operands are NaN.
function minmax(x::BigFloat, y::BigFloat)
    if isnan(x)
        return (x, x)
    elseif isnan(y)
        return (y, y)
    end
    return Base.Math._isless(x, y) ? (x, y) : (y, x)
end

function min(x::BigFloat, y::BigFloat)
isnan(x) && return x
isnan(y) && return y
z = BigFloat()
ccall((:mpfr_min, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Ref{BigFloat}, MPFRRoundingMode), z, x, y, ROUNDING_MODE[])
return z
# Combine two `(min, max)` pairs of `BigFloat`s into one.
# Only the first element of each pair is tested for NaN; when it is NaN the whole
# pair is returned unchanged (`x` takes precedence over `y`).
function Base._extrema_rf(x::NTuple{2,BigFloat}, y::NTuple{2,BigFloat})
    xlo, xhi = x
    ylo, yhi = y
    if isnan(xlo)
        return x
    elseif isnan(ylo)
        return y
    end
    lo = Base.Math._isless(xlo, ylo) ? xlo : ylo
    hi = Base.Math._isless(xhi, yhi) ? yhi : xhi
    return lo, hi
end

function modf(x::BigFloat)
Expand Down
9 changes: 8 additions & 1 deletion base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -855,8 +855,15 @@ end
ExtremaMap(::Type{T}) where {T} = ExtremaMap{Type{T}}(T)
@inline (f::ExtremaMap)(x) = (y = f.f(x); (y, y))

# TODO: optimize for inputs <: AbstractFloat
# Generic reducer for `(min, max)` pairs: keep the smaller of the two minima and the
# larger of the two maxima.
@inline function _extrema_rf((min1, max1), (min2, max2))
    lo = min(min1, min2)
    hi = max(max1, max2)
    return (lo, hi)
end
# Branch-free reducer for `(min, max)` pairs of IEEE floats.
# The sign bit of each difference selects the operand to keep. NaN is detected on the
# first elements only, and in that case the (NaN) difference `x1 - y1` fills both
# output slots.
# NOTE(review): this presumably relies on both slots of a pair being NaN together —
# confirm against how the pairs are produced.
function _extrema_rf(x::NTuple{2,T}, y::NTuple{2,T}) where {T<:IEEEFloat}
    xlo, xhi = x
    ylo, yhi = y
    dlo = xlo - ylo
    dhi = xhi - yhi
    has_nan = isnan(xlo) | isnan(ylo)
    lo = ifelse(signbit(dlo), xlo, ylo)
    hi = ifelse(signbit(dhi), yhi, xhi)
    return ifelse(has_nan, dlo, lo), ifelse(has_nan, dlo, hi)
end

## findmax, findmin, argmax & argmin

Expand Down
90 changes: 62 additions & 28 deletions test/numbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,34 +95,68 @@ end
@test max(1) === 1
@test minmax(1) === (1, 1)
@test minmax(5, 3) == (3, 5)
@test minmax(3., 5.) == (3., 5.)
@test minmax(5., 3.) == (3., 5.)
@test minmax(3., NaN) ≣ (NaN, NaN)
@test minmax(NaN, 3) ≣ (NaN, NaN)
@test minmax(Inf, NaN) ≣ (NaN, NaN)
@test minmax(NaN, Inf) ≣ (NaN, NaN)
@test minmax(-Inf, NaN) ≣ (NaN, NaN)
@test minmax(NaN, -Inf) ≣ (NaN, NaN)
@test minmax(NaN, NaN) ≣ (NaN, NaN)
@test min(-0.0,0.0) === min(0.0,-0.0)
@test max(-0.0,0.0) === max(0.0,-0.0)
@test minmax(-0.0,0.0) === minmax(0.0,-0.0)
@test max(-3.2, 5.1) == max(5.1, -3.2) == 5.1
@test min(-3.2, 5.1) == min(5.1, -3.2) == -3.2
@test max(-3.2, Inf) == max(Inf, -3.2) == Inf
@test max(-3.2, NaN) ≣ max(NaN, -3.2) ≣ NaN
@test min(5.1, Inf) == min(Inf, 5.1) == 5.1
@test min(5.1, -Inf) == min(-Inf, 5.1) == -Inf
@test min(5.1, NaN) ≣ min(NaN, 5.1) ≣ NaN
@test min(5.1, -NaN) ≣ min(-NaN, 5.1) ≣ NaN
@test minmax(-3.2, 5.1) == (min(-3.2, 5.1), max(-3.2, 5.1))
@test minmax(-3.2, Inf) == (min(-3.2, Inf), max(-3.2, Inf))
@test minmax(-3.2, NaN) ≣ (min(-3.2, NaN), max(-3.2, NaN))
@test (max(Inf,NaN), max(-Inf,NaN), max(Inf,-NaN), max(-Inf,-NaN)) ≣ (NaN,NaN,NaN,NaN)
@test (max(NaN,Inf), max(NaN,-Inf), max(-NaN,Inf), max(-NaN,-Inf)) ≣ (NaN,NaN,NaN,NaN)
@test (min(Inf,NaN), min(-Inf,NaN), min(Inf,-NaN), min(-Inf,-NaN)) ≣ (NaN,NaN,NaN,NaN)
@test (min(NaN,Inf), min(NaN,-Inf), min(-NaN,Inf), min(-NaN,-Inf)) ≣ (NaN,NaN,NaN,NaN)
@test minmax(-Inf,NaN) ≣ (min(-Inf,NaN), max(-Inf,NaN))
# `Top(T, op, x, y)` converts both operands to `T` (the conversion is broadcast, so a
# tuple operand is converted element by element) and then applies `op` to the results.
# `Top(T, op)` returns a two-argument closure with `T` and `op` fixed.
Top(T, op, x, y) = op(T.(x), T.(y))
Top(T, op) = (x, y) -> Top(T, op, x, y)
# Plain `==` under another name; presumably so it stays reachable once `==` is
# rebound inside the loop below.
_compare(x, y) = x == y
# Run the same min/max/minmax assertions once per floating-point type.
for T in (Float16, Float32, Float64, BigFloat)
# Rebind the functions and comparison operators used by the assertions so that their
# operands are converted to `T` before the underlying operation runs.
minmax = Top(T,Base.minmax)
min = Top(T,Base.min)
max = Top(T,Base.max)
(==) = Top(T,_compare)
(===) = Top(T,Base.isequal) # we only use === to compare -0.0/0.0, `isequal` should be equivalent
@test minmax(3., 5.) == (3., 5.)
@test minmax(5., 3.) == (3., 5.)
@test minmax(3., NaN) ≣ (NaN, NaN)
@test minmax(NaN, 3) ≣ (NaN, NaN)
@test minmax(Inf, NaN) ≣ (NaN, NaN)
@test minmax(NaN, Inf) ≣ (NaN, NaN)
@test minmax(-Inf, NaN) ≣ (NaN, NaN)
@test minmax(NaN, -Inf) ≣ (NaN, NaN)
@test minmax(NaN, NaN) ≣ (NaN, NaN)
@test min(-0.0,0.0) === min(0.0,-0.0)
@test max(-0.0,0.0) === max(0.0,-0.0)
@test minmax(-0.0,0.0) === minmax(0.0,-0.0)
@test max(-3.2, 5.1) == max(5.1, -3.2) == 5.1
@test min(-3.2, 5.1) == min(5.1, -3.2) == -3.2
@test max(-3.2, Inf) == max(Inf, -3.2) == Inf
@test max(-3.2, NaN) ≣ max(NaN, -3.2) ≣ NaN
@test min(5.1, Inf) == min(Inf, 5.1) == 5.1
@test min(5.1, -Inf) == min(-Inf, 5.1) == -Inf
@test min(5.1, NaN) ≣ min(NaN, 5.1) ≣ NaN
@test min(5.1, -NaN) ≣ min(-NaN, 5.1) ≣ NaN
@test minmax(-3.2, 5.1) == (min(-3.2, 5.1), max(-3.2, 5.1))
@test minmax(-3.2, Inf) == (min(-3.2, Inf), max(-3.2, Inf))
@test minmax(-3.2, NaN) ≣ (min(-3.2, NaN), max(-3.2, NaN))
@test (max(Inf,NaN), max(-Inf,NaN), max(Inf,-NaN), max(-Inf,-NaN)) ≣ (NaN,NaN,NaN,NaN)
@test (max(NaN,Inf), max(NaN,-Inf), max(-NaN,Inf), max(-NaN,-Inf)) ≣ (NaN,NaN,NaN,NaN)
@test (min(Inf,NaN), min(-Inf,NaN), min(Inf,-NaN), min(-Inf,-NaN)) ≣ (NaN,NaN,NaN,NaN)
@test (min(NaN,Inf), min(NaN,-Inf), min(-NaN,Inf), min(-NaN,-Inf)) ≣ (NaN,NaN,NaN,NaN)
@test minmax(-Inf,NaN) ≣ (min(-Inf,NaN), max(-Inf,NaN))
end
end
@testset "Base._extrema_rf for float" begin
    for T in (Float16, Float32, Float64, BigFloat)
        # Values listed in ascending order, so index order is value order.
        ordered = T[-Inf, -5, -0.0, 0.0, 3, Inf]
        unordered = T[NaN, -NaN]
        idx = eachindex(ordered)
        # Ordered operands: the result takes the lower first index and the higher
        # second index.
        for i1 in idx, i2 in idx, j1 in idx, j2 in idx
            lhs = (ordered[i1], ordered[i2])
            rhs = (ordered[j1], ordered[j2])
            expected = (ordered[min(i1, j1)], ordered[max(i2, j2)])
            @test Base._extrema_rf(lhs, rhs) === expected
        end
        # Exactly one NaN pair: it must come back unchanged from either position.
        for nan in unordered, j1 in idx, j2 in idx
            nanpair = (nan, nan)
            rhs = (ordered[j1], ordered[j2])
            @test Base._extrema_rf(nanpair, rhs) === nanpair
            @test Base._extrema_rf(rhs, nanpair) === nanpair
        end
        # Two NaN pairs: either one is an acceptable result.
        for a in unordered, b in unordered
            lhs = (a, a)
            rhs = (b, b)
            got = Base._extrema_rf(lhs, rhs)
            @test got === lhs || got === rhs
        end
    end
end
@testset "fma" begin
let x = Int64(7)^7
Expand Down