Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add logsubexp(x,y) #95

Merged
merged 16 commits into from
May 21, 2020
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*.jl.cov
*.jl.mem
Manifest.toml
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StatsFuns"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.4"
version = "0.9.5"

[deps]
Rmath = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
Expand Down
1 change: 1 addition & 0 deletions src/StatsFuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export
log1pmx, # log(1 + x) - x
logmxp1, # log(x) - x + 1
logaddexp, # log(exp(x) + exp(y))
logsubexp, # log(abs(e^x - e^y))
logsumexp, # log(sum(exp(x)))
softmax, # exp(x_i) / sum(exp(x)), for i
softmax!, # inplace softmax
Expand Down
40 changes: 27 additions & 13 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,32 @@
"""
xlogx(x::Real)
cossio marked this conversation as resolved.
Show resolved Hide resolved

Return `x * log(x)` for `x ≥ 0`, handling `x = 0` by taking the downward limit.
Compute `x * log(x)`, returning zero if `x = 0`.

```jldoctest
julia> StatsFuns.xlogx(0)
0.0
```
"""
xlogx(x::Real) = x > zero(x) ? x * log(x) : zero(log(x))
function xlogx(x::Number)
result = x * log(x)
ifelse(iszero(x), zero(result), result)
cossio marked this conversation as resolved.
Show resolved Hide resolved
end

"""
xlogy(x::Real, y::Real)

Return `x * log(y)` for `y > 0` with correct limit at `x = 0`.
Compute `x * log(y)`, returning zero if `x = 0`.

```jldoctest
julia> StatsFuns.xlogy(0, 0)
0.0
```
"""
xlogy(x::T, y::T) where {T<:Real} = x > zero(T) ? x * log(y) : zero(log(x))
xlogy(x::Real, y::Real) = xlogy(promote(x, y)...)
function xlogy(x::Number, y::Number)
result = x * log(y)
ifelse(iszero(x) && !isnan(y), zero(result), result)
end

# The following bounds are precomputed versions of the following abstract
# function, but the implicit interface for AbstractFloat doesn't uniformly
Expand Down Expand Up @@ -196,16 +206,20 @@ end

Return `log(exp(x) + exp(y))`, avoiding intermediate overflow/undeflow, and handling non-finite values.
"""
function logaddexp(x::T, y::T) where T<:Real
# x or y is NaN => NaN
# x or y is +Inf => +Inf
# x or y is -Inf => other value
isfinite(x) && isfinite(y) || return max(x,y)
x > y ? x + log1p(exp(y - x)) : y + log1p(exp(x - y))
function logaddexp(x::Real, y::Real)
# ensure Δ = 0 if x = y = Inf
cossio marked this conversation as resolved.
Show resolved Hide resolved
Δ = ifelse(x == y, zero(x - y), abs(x - y))
max(x, y) + log1pexp(-Δ)
end
logaddexp(x::Real, y::Real) = logaddexp(promote(x, y)...)

Base.@deprecate logsumexp(x::Real, y::Real) logaddexp(x,y)
Base.@deprecate logsumexp(x::Real, y::Real) logaddexp(x, y)

"""
logsubexp(x, y)

Return `log(abs(e^x - e^y))`, preserving numerical accuracy.
"""
logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y))
cossio marked this conversation as resolved.
Show resolved Hide resolved

"""
logsumexp(X)
Expand Down
48 changes: 44 additions & 4 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,35 @@ using StatsFuns, Test
@testset "xlogx & xlogy" begin
@test iszero(xlogx(0))
@test xlogx(2) ≈ 2.0 * log(2.0)
@test_throws DomainError xlogx(-1)
@test isnan(xlogx(NaN))

@test iszero(xlogy(0, 1))
@test xlogy(2, 3) ≈ 2.0 * log(3.0)
@test_throws DomainError xlogy(1, -1)
@test isnan(xlogy(NaN, 2))
@test isnan(xlogy(2, NaN))
@test isnan(xlogy(0, NaN))

# Since we allow complex/negative values, test for them. See comments in:
cossio marked this conversation as resolved.
Show resolved Hide resolved
# https://github.com/JuliaStats/StatsFuns.jl/pull/95

@test xlogx(1 + im) == (1 + im) * log(1 + im)
@test isnan(xlogx(NaN + im))
@test isnan(xlogx(1 + NaN * im))

@test xlogy(-2, 3) == -xlogy(2, 3)
@test xlogy(1 + im, 3) == (1 + im) * log(3)
@test xlogy(1 + im, 2 + im) == (1 + im) * log(2 + im)
@test isnan(xlogy(1 + NaN * im, -1 + im))
@test isnan(xlogy(0, -1 + NaN * im))
@test isnan(xlogy(Inf + im * NaN, 1))
@test isnan(xlogy(0 + im * 0, NaN))
@test iszero(xlogy(0 + im * 0, 0 + im * Inf))
end

@testset "logistic & logit" begin
@test logistic(2) ≈ 1.0 / (1.0 + exp(-2.0))
@test logistic(2) ≈ 1.0 / (1.0 + exp(-2.0))
@test logistic(-750.0) === 0.0
@test logistic(-740.0) > 0.0
@test logistic(+36.0) < 1.0
Expand Down Expand Up @@ -88,15 +110,33 @@ end
([-Inf, Inf], Inf),
([-Inf, 9.0], 9.0),
([Inf, 9.0], Inf),
([NaN, 9.0], NaN), # NaN propagation
([NaN, Inf], NaN), # NaN propagation
([NaN, -Inf], NaN), # NaN propagation
([0, 0], log(2.0))] # non-float arguments
for (arguments, result) in cases
@test logaddexp(arguments...) ≡ result
@test logsumexp(arguments) ≡ result
end
end

@test isnan(logsubexp(Inf, Inf))
@test isnan(logsubexp(-Inf, -Inf))
@test logsubexp(Inf, 9.0) ≡ Inf
@test logsubexp(-Inf, 9.0) ≡ 9.0
@test logsubexp(1f2, 1f2) ≡ -Inf32
@test logsubexp(0, 0) ≡ -Inf
@test logsubexp(3, 2) ≈ 2.541324854612918108978

# NaN propagation
@test isnan(logaddexp(NaN, 9.0))
@test isnan(logaddexp(NaN, Inf))
@test isnan(logaddexp(NaN, -Inf))

@test isnan(logsubexp(NaN, 9.0))
@test isnan(logsubexp(NaN, Inf))
@test isnan(logsubexp(NaN, -Inf))

@test isnan(logsumexp([NaN, 9.0]))
@test isnan(logsumexp([NaN, Inf]))
@test isnan(logsumexp([NaN, -Inf]))
end

@testset "softmax" begin
Expand Down