From 307c314b72fea4748b4b1371b095ee7e10ccd0c5 Mon Sep 17 00:00:00 2001 From: John Myles White Date: Tue, 5 May 2020 21:23:04 -0400 Subject: [PATCH] Improve accuracy of logistic --- src/basicfuns.jl | 28 ++++++++++++++++++++++++++++ test/basicfuns.jl | 4 ++++ 2 files changed, 32 insertions(+) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 51c7941..e9bacc2 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -21,6 +21,20 @@ Return `x * log(y)` for `y > 0` with correct limit at `x = 0`. xlogy(x::T, y::T) where {T<:Real} = x > zero(T) ? x * log(y) : zero(log(x)) xlogy(x::Real, y::Real) = xlogy(promote(x, y)...) +# The following bounds are precomputed versions of the following abstract +# function, but the implicit interface for AbstractFloat doesn't uniformly +# enforce that all floating point types implement nextfloat and prevfloat. +# @inline function _logistic_bounds(x::AbstractFloat) +# ( +# logit(nextfloat(zero(float(x)))), +# logit(prevfloat(one(float(x)))), +# ) +# end + +@inline _logistic_bounds(x::Float16) = (Float16(-16.64), Float16(7.625)) +@inline _logistic_bounds(x::Float32) = (-103.27893f0, 16.635532f0) +@inline _logistic_bounds(x::Float64) = (-744.4400719213812, 36.7368005696771) + """ logistic(x::Real) @@ -33,6 +47,20 @@ Its inverse is the [`logit`](@ref) function. """ logistic(x::Real) = inv(exp(-x) + one(x)) +function logistic(x::Union{Float16, Float32, Float64}) + e = exp(x) + lower, upper = _logistic_bounds(x) + ifelse( + x < lower, + zero(x), + ifelse( + x > upper, + one(x), + e / (one(x) + e) + ) + ) +end + """ logit(p::Real) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 0bec7c6..f52289a 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -10,6 +10,10 @@ end @testset "logistic & logit" begin @test logistic(2) ≈ 1.0 / (1.0 + exp(-2.0)) + @test logistic(-750.0) === 0.0 + @test logistic(-740.0) > 0.0 + @test logistic(+36.0) < 1.0 + @test logistic(+750.0) === 1.0 @test iszero(logit(0.5)) @test logit(logistic(2)) ≈ 2.0 end