Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve accuracy of logistic #94

Merged
merged 1 commit into from
May 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,20 @@ Return `x * log(y)` for `y > 0` with correct limit at `x = 0`.
xlogy(x::T, y::T) where {T<:Real} = x > zero(T) ? x * log(y) : zero(log(x))
xlogy(x::Real, y::Real) = xlogy(promote(x, y)...)

# The following bounds are precomputed versions of the following abstract
# function, but the implicit interface for AbstractFloat doesn't uniformly
# enforce that all floating point types implement nextfloat and prevfloat.
# @inline function _logistic_bounds(x::AbstractFloat)
# (
# logit(nextfloat(zero(float(x)))),
# logit(prevfloat(one(float(x)))),
# )
# end

@inline _logistic_bounds(x::Float16) = (Float16(-16.64), Float16(7.625))
@inline _logistic_bounds(x::Float32) = (-103.27893f0, 16.635532f0)
@inline _logistic_bounds(x::Float64) = (-744.4400719213812, 36.7368005696771)

"""
logistic(x::Real)

Expand All @@ -33,6 +47,20 @@ Its inverse is the [`logit`](@ref) function.
"""
logistic(x::Real) = inv(exp(-x) + one(x))

function logistic(x::Union{Float16, Float32, Float64})
e = exp(x)
lower, upper = _logistic_bounds(x)
ifelse(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like a quite convoluted way of writing it over just if, else etc?

function logistic(x::Real)
    e = exp(x)
    lower, upper = _logistic_bounds(x)
    if x < lower
        return zero(x)
    elseif x > upper
        return one(x)
    else
        return e / (one(x) + e)
    end
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably worth checking @code_llvm; it's true that LLVM now does the conversion of if-else => ifelse automatically pretty well these days.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They seem to produce different output at the @code_llvm and @code_native stages consistent with my earlier experience that ifelse generates LLVM select and if generates LLVM br, but benchmarking suggests their performance is very similar in scalar applications and broadcasting: https://gist.github.com/johnmyleswhite/548f4eb18a028a237d52ae06811ca33c

Note that the benchmarks also suggest that changing from the old implementation to the new implementation causes performance to drop because of processing subnormals in a mirror image: the old formulation was fast for z = -710.0 and slow for z = +710.0, whereas the new formulation is slow for z = -710.0 and fast for z = +710.0.

Happy to choose whichever form people prefer. I tend to defer to the side of assuming select is more likely to play well with SIMD, but it doesn't seem like that issue applies in this case for the benchmarks I've run.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to choose whichever form people prefer.

Personally, I don't really care, just felt like the if version was easier to read and I couldn't find a benchmark where it was slower. I guess if you correctly predict the branch version is faster and otherwise the branchless is faster? 🤷‍♂️

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think in the absence of useful evidence either way I'll maintain my existing superstitious faith in the importance of the LLVM select instruction. Revising in the future would be trivial.

x < lower,
zero(x),
ifelse(
x > upper,
one(x),
e / (one(x) + e)
)
)
end

"""
logit(p::Real)

Expand Down
4 changes: 4 additions & 0 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ end

@testset "logistic & logit" begin
@test logistic(2) ≈ 1.0 / (1.0 + exp(-2.0))
@test logistic(-750.0) === 0.0
@test logistic(-740.0) > 0.0
@test logistic(+36.0) < 1.0
@test logistic(+750.0) === 1.0
@test iszero(logit(0.5))
@test logit(logistic(2)) ≈ 2.0
end
Expand Down