
Add onepass algorithm for logsumexp #97

Merged: 14 commits, Sep 23, 2020
89 changes: 73 additions & 16 deletions src/basicfuns.jl
@@ -207,7 +207,7 @@ end
Return `log(exp(x) + exp(y))`, avoiding intermediate overflow/underflow, and handling non-finite values.
"""
function logaddexp(x::Real, y::Real)
# ensure Δ = 0 if x = y = Inf
# ensure Δ = 0 if x = y = ± Inf
Δ = ifelse(x == y, zero(x - y), abs(x - y))
max(x, y) + log1pexp(-Δ)
end
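The guard on `Δ` can be sketched in plain Python to see why it matters; `log1pexp` and `logaddexp` below are hypothetical stand-ins for the package's functions, not its actual API:

```python
import math

def log1pexp(x):
    # log(1 + exp(x)); the branch keeps exp() from overflowing for x > 0
    return math.log1p(math.exp(x)) if x <= 0 else x + math.log1p(math.exp(-x))

def logaddexp(x, y):
    # force the gap to 0 when x == y, so that x = y = ±inf does not
    # produce inf - inf = nan inside abs(x - y)
    delta = 0.0 if x == y else abs(x - y)
    return max(x, y) + log1pexp(-delta)
```

Without the guard, `logaddexp(-inf, -inf)` would evaluate `abs(-inf - (-inf)) = nan` and return `nan` instead of `-inf`.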
@@ -224,28 +224,85 @@ logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y))
"""
logsumexp(X)

Compute `log(sum(exp, X))`, evaluated avoiding intermediate overflow/underflow.
Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate over- and
underflow.

`X` should be an iterator of real numbers. The result is computed using a single pass over
the data.

# References

[Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
"""
logsumexp(X) = logsumexp_onepass(X)
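As a quick numerical illustration (plain Python, not part of the package), shifting by the maximum is what rescues inputs whose exponentials underflow:

```python
import math

xs = [-746.0, -746.0]  # exp(-746) underflows to 0.0 in double precision

naive_inner = sum(math.exp(x) for x in xs)  # 0.0: all information is lost
# math.log(naive_inner) would raise a math domain error here

# shift by the maximum so the largest exponent is exactly 0
m = max(xs)
stable = m + math.log(sum(math.exp(x - m) for x in xs))  # -746 + log(2)
```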

`X` should be an iterator of real numbers.
"""
function logsumexp(X)
logsumexp(X::AbstractArray{<:Real}[; dims=:])

Compute `log.(sum(exp.(X); dims=dims))` in a numerically stable way that avoids
intermediate over- and underflow.

If `dims = :`, then the result is computed using a single pass over the data.

# References

[Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
"""
logsumexp(X::AbstractArray{<:Real}; dims=:) = _logsumexp(X, dims)

_logsumexp(X::AbstractArray{<:Real}, ::Colon) = logsumexp_onepass(X)
function _logsumexp(X::AbstractArray{<:Real}, dims)
# Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82)
u = reduce(max, X, dims=dims, init=oftype(log(zero(eltype(X))), -Inf))
return u .+ log.(sum(exp.(X .- u); dims=dims))
end
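The reduction over `dims` can be mimicked in plain Python for the column-wise case; `logsumexp_cols` below is a hypothetical helper written only to illustrate the shift-by-maximum reduction, seeded with `-inf` like the `init` above:

```python
import math

def logsumexp_cols(rows):
    # column-wise logsumexp of a matrix given as a list of equal-length rows
    ncols = len(rows[0])
    # per-column maximum, seeded with -inf like `init = -Inf` above
    u = [float("-inf")] * ncols
    for row in rows:
        u = [max(a, b) for a, b in zip(u, row)]
    # accumulate sum(exp(x - u)) per column, then shift back
    sums = [0.0] * ncols
    for row in rows:
        for j, x in enumerate(row):
            sums[j] += math.exp(x - u[j])
    return [uj + math.log(s) for uj, s in zip(u, sums)]
```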

function logsumexp_onepass(X)
# fallback for empty collections
isempty(X) && return log(sum(X))
reduce(logaddexp, X)

xmax, r = _logsumexp_onepass(X, Base.IteratorEltype(X))

return xmax + log(r)
end
function logsumexp(X::AbstractArray{T}; dims=:) where {T<:Real}
# Do not use log(zero(T)) directly to avoid issues with ForwardDiff (#82)
u = reduce(max, X, dims=dims, init=oftype(log(zero(T)), -Inf))
u isa AbstractArray || isfinite(u) || return float(u)
let u=u # avoid https://github.com/JuliaLang/julia/issues/15276
# TODO: remove the branch when JuliaLang/julia#31020 is merged.
if u isa AbstractArray
u .+ log.(sum(exp.(X .- u); dims=dims))
else
u + log(sum(x -> exp(x-u), X))
end

# with initial element: required by CUDA
Member: I guess it's also required by Array when the input is empty, right?

devmotion (Member, Author): No, empty inputs are handled by `log(sum(X))`, as before. Actually we never want to use this method except for GPU arrays: the type calculations are doomed to fail, e.g., for arrays that are not concretely typed, and hence it is much safer not to rely on them and instead just compute the output. I'm not sure why the type calculations in https://github.com/JuliaGPU/GPUArrays.jl/blob/356c5fe3f83f76f8139b4edf5536cbd08d48da7f/src/host/mapreduce.jl#L49-L52 fail; I can check if this can be fixed from our side. Alternatively, maybe another package such as NNlib should provide a GPU-compatible overload of StatsFuns.logsumexp?

Member: OK. They fail because neutral_element is only defined for a few particular functions. Another approach would be to use the type of the first element, since we already handle the case where the array is empty. That way we wouldn't need two different methods depending on whether the eltype is known.

Given that it seems doable, it would be nice to have something that works for all arrays without another package having to overload the method.

devmotion (Member, Author), Sep 5, 2020: IMO it is just bad that GPUArrays requires us to provide an initial or neutral element. We really never want to do that, regardless of whether there is a first element or not. Retrieving the first element is bad for any stateful iterator, and is a scalar indexing operation on the GPU (which leads to warnings when calling the function, and we cannot remove them without depending on CUDA). Additionally, the docstring of mapreduce says that

> It is unspecified whether init is used for non-empty collections.

so there is really no point in providing `init` here, apart from GPU compatibility. The problem is that we cannot restrict our implementation with the `init` hack to GPU arrays only without taking a dependency on CUDA. Hence I don't think it is completely unreasonable to provide a GPU-compatible implementation somewhere else. An alternative would be to not use the one-pass algorithm as the default algorithm (for arrays).

Member: Retrieving the first element should be OK for any AbstractArray (contrary to some iterables). It's too bad that CUDA prints warnings in that case.

Let's keep what you have now: since the method only accepts arrays with T<:Real element type, float(T) will (almost) certainly work.

function _logsumexp_onepass(X, ::Base.HasEltype)
# compute initial element
FT = float(eltype(X))
init = (FT(-Inf), zero(FT))
r_one = one(FT)

# perform single pass over the data
return mapreduce(_logsumexp_onepass_op, X; init=init) do x
return float(x), r_one
end
end

# without initial element
function _logsumexp_onepass(X, ::Base.EltypeUnknown)
return mapreduce(_logsumexp_onepass_op, X) do x
_x = float(x)
return _x, one(_x)
end
end

function _logsumexp_onepass_op((xmax1, r1)::T, (xmax2, r2)::T) where {T<:Tuple}
if xmax1 < xmax2
xmax = xmax2
a = exp(xmax1 - xmax2)
r = r2 + ifelse(isone(r1), a, r1 * a)
elseif xmax1 > xmax2
xmax = xmax1
a = exp(xmax2 - xmax1)
r = r1 + ifelse(isone(r2), a, r2 * a)
else # ensure finite values if x = xmax = ± Inf
xmax = ifelse(isnan(xmax1), xmax1, xmax2)
r = r1 + r2
end

return xmax, r
end
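The merge step and the reduction transliterate naturally to Python, which makes the invariant visible: each accumulator `(xmax, r)` represents `xmax + log(r)`, with `r = sum(exp(x - xmax))` over the elements seen so far. This is a sketch with illustrative names, and it drops the `isone` micro-optimization of the Julia code:

```python
import math
from functools import reduce

def onepass_op(acc1, acc2):
    # merge two (xmax, r) accumulators; mirrors _logsumexp_onepass_op above
    xmax1, r1 = acc1
    xmax2, r2 = acc2
    if xmax1 < xmax2:
        # rescale the sum attached to the smaller maximum
        return xmax2, r2 + r1 * math.exp(xmax1 - xmax2)
    elif xmax1 > xmax2:
        return xmax1, r1 + r2 * math.exp(xmax2 - xmax1)
    else:
        # equal maxima, including xmax = ±inf: just add the partial sums;
        # keep a NaN maximum if one appeared
        return (xmax1 if math.isnan(xmax1) else xmax2), r1 + r2

def logsumexp_onepass(xs):
    # each element becomes the trivial accumulator (x, 1.0)
    xmax, r = reduce(onepass_op, ((float(x), 1.0) for x in xs))
    return xmax + math.log(r)
```

Because `onepass_op` is associative, the same operator works for a left-to-right streaming pass and for a pairwise or parallel reduction.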

"""
softmax!(r::AbstractArray, x::AbstractArray)
6 changes: 6 additions & 0 deletions test/basicfuns.jl
@@ -1,4 +1,5 @@
using StatsFuns, Test
using StatsFuns: logsumexp_onepass

@testset "xlogx & xlogy" begin
@test iszero(xlogx(0))
@@ -137,6 +138,11 @@ end
@test isnan(logsumexp([NaN, 9.0]))
@test isnan(logsumexp([NaN, Inf]))
@test isnan(logsumexp([NaN, -Inf]))

# issue #63
a = logsumexp(i for i in range(-500, stop = 10, length = 1000) if true)
b = logsumexp(range(-500, stop = 10, length = 1000))
@test a == b
end

@testset "softmax" begin