Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add logsubexp(x,y) #95

Merged
merged 16 commits into from
May 21, 2020
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
*.jl.cov
*.jl.mem

cossio marked this conversation as resolved.
Show resolved Hide resolved
Manifest.toml
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StatsFuns"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.4"
version = "0.9.5"

[deps]
Rmath = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
Expand Down
1 change: 1 addition & 0 deletions src/StatsFuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export
log1pmx, # log(1 + x) - x
logmxp1, # log(x) - x + 1
logaddexp, # log(exp(x) + exp(y))
logsubexp, # log(abs(e^x - e^y))
logsumexp, # log(sum(exp(x)))
softmax, # exp(x_i) / sum(exp(x)), for i
softmax!, # inplace softmax
Expand Down
40 changes: 27 additions & 13 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,32 @@
"""
xlogx(x::Real)
cossio marked this conversation as resolved.
Show resolved Hide resolved

Return `x * log(x)` for `x ≥ 0`, handling `x = 0` by taking the downward limit.
Computes `x * log(x)`, with the correct limit at `x = 0`.
cossio marked this conversation as resolved.
Show resolved Hide resolved

```jldoctest
julia> StatsFuns.xlogx(0)
0.0
```
"""
xlogx(x::Real) = x > zero(x) ? x * log(x) : zero(log(x))
function xlogx(x)
cossio marked this conversation as resolved.
Show resolved Hide resolved
result = x * log(x)
ifelse(iszero(x), zero(result), result)
cossio marked this conversation as resolved.
Show resolved Hide resolved
end

"""
xlogy(x::Real, y::Real)

Return `x * log(y)` for `y > 0` with correct limit at `x = 0`.
Computes `x * log(y)`, with the correct limit at `x = 0`.
cossio marked this conversation as resolved.
Show resolved Hide resolved

```jldoctest
julia> StatsFuns.xlogy(0, 0)
0.0
```
"""
xlogy(x::T, y::T) where {T<:Real} = x > zero(T) ? x * log(y) : zero(log(x))
xlogy(x::Real, y::Real) = xlogy(promote(x, y)...)
function xlogy(x, y)
result = x * log(y)
ifelse(iszero(x), zero(result), result)
end

# The following bounds are precomputed versions of the following abstract
# function, but the implicit interface for AbstractFloat doesn't uniformly
Expand Down Expand Up @@ -196,16 +206,20 @@ end

Return `log(exp(x) + exp(y))`, avoiding intermediate overflow/undeflow, and handling non-finite values.
"""
function logaddexp(x::T, y::T) where T<:Real
# x or y is NaN => NaN
# x or y is +Inf => +Inf
# x or y is -Inf => other value
isfinite(x) && isfinite(y) || return max(x,y)
x > y ? x + log1p(exp(y - x)) : y + log1p(exp(x - y))
function logaddexp(x::Real, y::Real)
# ensure Δ = 0 if x = y = Inf
cossio marked this conversation as resolved.
Show resolved Hide resolved
Δ = ifelse(x == y, zero(x - y), abs(x - y))
max(x, y) + log1pexp(-Δ)
end
logaddexp(x::Real, y::Real) = logaddexp(promote(x, y)...)

Base.@deprecate logsumexp(x::Real, y::Real) logaddexp(x,y)
Base.@deprecate logsumexp(x::Real, y::Real) logaddexp(x, y)

"""
logsubexp(x, y)

Return `log(abs(e^x - e^y))`, preserving numerical accuracy.
"""
logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y))
cossio marked this conversation as resolved.
Show resolved Hide resolved

"""
logsumexp(X)
Expand Down
28 changes: 25 additions & 3 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ using StatsFuns, Test
@testset "xlogx & xlogy" begin
@test iszero(xlogx(0))
@test xlogx(2) ≈ 2.0 * log(2.0)
@test_throws DomainError xlogx(-1)

@test iszero(xlogy(0, 1))
@test xlogy(2, 3) ≈ 2.0 * log(3.0)
@test_throws DomainError xlogy(1, -1)
# we allow negative `x`, https://github.com/JuliaStats/StatsFuns.jl/pull/95#discussion_r427558736
@test xlogy(-2, 3) == -xlogy(2, 3)
end

@testset "logistic & logit" begin
Expand Down Expand Up @@ -88,15 +92,33 @@ end
([-Inf, Inf], Inf),
([-Inf, 9.0], 9.0),
([Inf, 9.0], Inf),
([NaN, 9.0], NaN), # NaN propagation
([NaN, Inf], NaN), # NaN propagation
([NaN, -Inf], NaN), # NaN propagation
([0, 0], log(2.0))] # non-float arguments
for (arguments, result) in cases
@test logaddexp(arguments...) ≡ result
@test logsumexp(arguments) ≡ result
end
end

@test isnan(logsubexp(Inf, Inf))
@test isnan(logsubexp(-Inf, -Inf))
@test logsubexp(Inf, 9.0) ≡ Inf
@test logsubexp(-Inf, 9.0) ≡ 9.0
@test logsubexp(1f2, 1f2) ≡ -Inf32
@test logsubexp(0, 0) ≡ -Inf
@test logsubexp(3, 2) ≈ 2.541324854612918108978

# NaN propagation
@test isnan(logaddexp(NaN, 9.0))
@test isnan(logaddexp(NaN, Inf))
@test isnan(logaddexp(NaN, -Inf))

@test isnan(logsubexp(NaN, 9.0))
@test isnan(logsubexp(NaN, Inf))
@test isnan(logsubexp(NaN, -Inf))

@test isnan(logsumexp([NaN, 9.0]))
@test isnan(logsumexp([NaN, Inf]))
@test isnan(logsumexp([NaN, -Inf]))
end

@testset "softmax" begin
Expand Down