From 52a63ed95005d80e73ce39e9d09fa2e0714759e3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 8 Jun 2020 21:04:08 +0200 Subject: [PATCH 01/14] Add `logsumexp_onepass` --- src/basicfuns.jl | 51 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index d2518db..46e0e25 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -207,7 +207,7 @@ end Return `log(exp(x) + exp(y))`, avoiding intermediate overflow/undeflow, and handling non-finite values. """ function logaddexp(x::Real, y::Real) - # ensure Δ = 0 if x = y = Inf + # ensure Δ = 0 if x = y = ± Inf Δ = ifelse(x == y, zero(x - y), abs(x - y)) max(x, y) + log1pexp(-Δ) end @@ -224,14 +224,14 @@ logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y)) """ logsumexp(X) -Compute `log(sum(exp, X))`, evaluated avoiding intermediate overflow/undeflow. +Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate over- and +underflow. `X` should be an iterator of real numbers. + +See also: [`logsumexp_onepass`](@ref) """ -function logsumexp(X) - isempty(X) && return log(sum(X)) - reduce(logaddexp, X) -end +logsumexp(X) = logsumexp_onepass(X) function logsumexp(X::AbstractArray{T}; dims=:) where {T<:Real} # Do not use log(zero(T)) directly to avoid issues with ForwardDiff (#82) u = reduce(max, X, dims=dims, init=oftype(log(zero(T)), -Inf)) @@ -246,6 +246,45 @@ function logsumexp(X::AbstractArray{T}; dims=:) where {T<:Real} end end +""" + logsumexp_onepass(X) + +Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate under- and +overflow. + +In contrast to [`logsumexp`](@ref) the result is computed using a single pass over the data. +`X` should be an iterator of real numbers. + +# References + +[Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html) +""" +function logsumexp_onepass(X) + isempty(X) && return log(sum(X)) + + # initialize maximum value and accumulated sum for the first iterate + x, state = iterate(X) + xmax = x + r = exp(zero(x)) + r_one = r + + # for all other iterates + while (next = iterate(X, state)) !== nothing + x, state = next + + # update maximum value and accumulated sum + if x < xmax + r += exp(x - xmax) + elseif x > xmax + r = r_one + r * exp(xmax - x) + xmax = x + else # ensure finite values if x = xmax = ± Inf + r += r_one + end + end + + return xmax + log(r) +end """ softmax!(r::AbstractArray, x::AbstractArray) From 5983de984b519ec6da91daac401e1b16f94d5e41 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 8 Jun 2020 21:04:34 +0200 Subject: [PATCH 02/14] Add tests --- test/basicfuns.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index ce34ea4..5132895 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -1,4 +1,5 @@ using StatsFuns, Test +using StatsFuns: logsumexp_onepass @testset "xlogx & xlogy" begin @test iszero(xlogx(0)) @@ -97,8 +98,10 @@ end @test logaddexp(10002, 10003) ≈ 10000 + logaddexp(2.0, 3.0) @test logsumexp([1.0, 2.0, 3.0]) ≈ 3.40760596444438 + @test logsumexp_onepass([1.0, 2.0, 3.0]) ≈ 3.40760596444438 @test logsumexp((1.0, 2.0, 3.0)) ≈ 3.40760596444438 @test logsumexp([1.0, 2.0, 3.0] .+ 1000.) ≈ 1003.40760596444438 + @test logsumexp_onepass([1.0, 2.0, 3.0] .+ 1000.) ≈ 1003.40760596444438 @test logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=1) ≈ [3.40760596444438 1003.40760596444438] @test logsumexp([[1.0 2.0 3.0]; [1.0 2.0 3.0] .+ 1000.]; dims=2) ≈ [3.40760596444438, 1003.40760596444438] @@ -114,6 +117,7 @@ end for (arguments, result) in cases @test logaddexp(arguments...) ≡ result @test logsumexp(arguments) ≡ result + @test logsumexp_onepass(arguments) ≡ result end end @@ -137,6 +141,10 @@ end @test isnan(logsumexp([NaN, 9.0])) @test isnan(logsumexp([NaN, Inf])) @test isnan(logsumexp([NaN, -Inf])) + + @test isnan(logsumexp_onepass([NaN, 9.0])) + @test isnan(logsumexp_onepass([NaN, Inf])) + @test isnan(logsumexp_onepass([NaN, -Inf])) end @testset "softmax" begin From c9e13966a8db0ee07ba7fb125cbd0de16baf1376 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 5 Sep 2020 00:38:48 +0200 Subject: [PATCH 03/14] Use `mapreduce` & handle generators, CUDA, and NaN --- src/basicfuns.jl | 70 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 47 insertions(+), 23 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 46e0e25..3492204 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -232,9 +232,12 @@ underflow. See also: [`logsumexp_onepass`](@ref) """ logsumexp(X) = logsumexp_onepass(X) -function logsumexp(X::AbstractArray{T}; dims=:) where {T<:Real} - # Do not use log(zero(T)) directly to avoid issues with ForwardDiff (#82) - u = reduce(max, X, dims=dims, init=oftype(log(zero(T)), -Inf)) +logsumexp(X::AbstractArray{<:Real}; dims=:) = _logsumexp(X, dims) + +_logsumexp(X::AbstractArray{<:Real}, ::Colon) = logsumexp_onepass(X) +function _logsumexp(X::AbstractArray{<:Real}, dims) + # Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82) + u = reduce(max, X, dims=dims, init=oftype(log(zero(eltype(X))), -Inf)) u isa AbstractArray || isfinite(u) || return float(u) let u=u # avoid https://github.com/JuliaLang/julia/issues/15276 # TODO: remove the branch when JuliaLang/julia#31020 is merged. @@ -260,32 +263,53 @@ In contrast to [`logsumexp`](@ref) the result is computed using a single pass ov [Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html) """ function logsumexp_onepass(X) + # fallback for empty collections isempty(X) && return log(sum(X)) - # initialize maximum value and accumulated sum for the first iterate - x, state = iterate(X) - xmax = x - r = exp(zero(x)) - r_one = r - - # for all other iterates - while (next = iterate(X, state)) !== nothing - x, state = next - - # update maximum value and accumulated sum - if x < xmax - r += exp(x - xmax) - elseif x > xmax - r = r_one + r * exp(xmax - x) - xmax = x - else # ensure finite values if x = xmax = ± Inf - r += r_one - end - end + # perform single pass over the data + xmax, r = _logsumexp_onepass(X, Base.IteratorEltype(X)) return xmax + log(r) end +# with initial element: required by CUDA +function _logsumexp_onepass(X, ::Base.HasEltype) + # compute initial element + FT = float(eltype(X)) + init = (FT(-Inf), zero(FT)) + r_one = one(FT) + + # perform single pass over the data + return mapreduce(_logsumexp_onepass_op, X; init=init) do x + return float(x), r_one + end +end + +# without initial element +function _logsumexp_onepass(X, ::Base.EltypeUnknown) + return mapreduce(_logsumexp_onepass_op, X) do x + _x = float(x) + return _x, one(_x) + end +end + +function _logsumexp_onepass_op((xmax1, r1)::T, (xmax2, r2)::T) where {T<:Tuple} + if xmax1 < xmax2 + xmax = xmax2 + a = exp(xmax1 - xmax2) + r = r2 + ifelse(isone(r1), a, r1 * a) + elseif xmax1 > xmax2 + xmax = xmax1 + a = exp(xmax2 - xmax1) + r = r1 + ifelse(isone(r2), a, r2 * a) + else # ensure finite values if x = xmax = ± Inf + xmax = ifelse(isnan(xmax1), xmax1, xmax2) + r = r1 + r2 + end + + return xmax, r +end + """ softmax!(r::AbstractArray, x::AbstractArray) From 70c1b88978b8af29bd5fcc034265b8a5bbebfa09 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 5 Sep 2020 00:39:37 +0200 Subject: [PATCH 04/14] Add test --- test/basicfuns.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 5132895..b2e4643 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -145,6 +145,11 @@ end @test isnan(logsumexp_onepass([NaN, 9.0])) @test isnan(logsumexp_onepass([NaN, Inf])) @test isnan(logsumexp_onepass([NaN, -Inf])) + + # issue #63 + a = logsumexp(i for i in range(-500, stop = 10, length = 1000) if true) + b = logsumexp(range(-500, stop = 10, length = 1000)) + @test a == b end @testset "softmax" begin From 31356bd8fb5d92dd458354f50e021fc09fcecee5 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 5 Sep 2020 16:44:54 +0200 Subject: [PATCH 05/14] Simplify implementation of `logsumexp` reduction over dimensions --- src/basicfuns.jl | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 3492204..338719f 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -238,15 +238,7 @@ _logsumexp(X::AbstractArray{<:Real}, ::Colon) = logsumexp_onepass(X) function _logsumexp(X::AbstractArray{<:Real}, dims) # Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82) u = reduce(max, X, dims=dims, init=oftype(log(zero(eltype(X))), -Inf)) - u isa AbstractArray || isfinite(u) || return float(u) - let u=u # avoid https://github.com/JuliaLang/julia/issues/15276 - # TODO: remove the branch when JuliaLang/julia#31020 is merged. - if u isa AbstractArray - u .+ log.(sum(exp.(X .- u); dims=dims)) - else - u + log(sum(x -> exp(x-u), X)) - end - end + return u .+ log.(sum(exp.(X .- u); dims=dims)) end """ From 44eec2523bd234f9a9015a159570d390826ed04f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 5 Sep 2020 16:47:38 +0200 Subject: [PATCH 06/14] Remove comment Co-authored-by: Milan Bouchet-Valat --- src/basicfuns.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 338719f..b2f4c16 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -258,7 +258,6 @@ function logsumexp_onepass(X) # fallback for empty collections isempty(X) && return log(sum(X)) - # perform single pass over the data xmax, r = _logsumexp_onepass(X, Base.IteratorEltype(X)) return xmax + log(r) From 51ca6db07041f658ed45b6c225a49a458d247f62 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 5 Sep 2020 16:57:28 +0200 Subject: [PATCH 07/14] Update documentation --- src/basicfuns.jl | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index b2f4c16..e3536a9 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -227,33 +227,36 @@ logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y)) Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate over- and underflow. -`X` should be an iterator of real numbers. +`X` should be an iterator of real numbers. The result is computed using a single pass over +the data. -See also: [`logsumexp_onepass`](@ref) +# References + +[Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html) """ logsumexp(X) = logsumexp_onepass(X) -logsumexp(X::AbstractArray{<:Real}; dims=:) = _logsumexp(X, dims) - -_logsumexp(X::AbstractArray{<:Real}, ::Colon) = logsumexp_onepass(X) -function _logsumexp(X::AbstractArray{<:Real}, dims) - # Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82) - u = reduce(max, X, dims=dims, init=oftype(log(zero(eltype(X))), -Inf)) - return u .+ log.(sum(exp.(X .- u); dims=dims)) -end """ - logsumexp_onepass(X) + logsumexp(X::AbstractArray{<:Real}[; dims=:]) -Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate under- and -overflow. +Compute `log.(sum(exp.(X); dims=dims))` in a numerically stable way that avoids +intermediate over- and underflow. -In contrast to [`logsumexp`](@ref) the result is computed using a single pass over the data. -`X` should be an iterator of real numbers. +If `dims = :`, then the result is computed using a single pass over the data. # References [Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html) """ +logsumexp(X::AbstractArray{<:Real}; dims=:) = _logsumexp(X, dims) + +_logsumexp(X::AbstractArray{<:Real}, ::Colon) = logsumexp_onepass(X) +function _logsumexp(X::AbstractArray{<:Real}, dims) + # Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82) + u = reduce(max, X, dims=dims, init=oftype(log(zero(eltype(X))), -Inf)) + return u .+ log.(sum(exp.(X .- u); dims=dims)) +end + function logsumexp_onepass(X) # fallback for empty collections isempty(X) && return log(sum(X)) From 7937e09f86e9b418f7e3204823dccc4b873e8730 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 5 Sep 2020 16:58:37 +0200 Subject: [PATCH 08/14] Remove `logsumexp_onepass` tests since it is the default now --- test/basicfuns.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index b2e4643..8c47b37 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -98,10 +98,8 @@ end @test logaddexp(10002, 10003) ≈ 10000 + logaddexp(2.0, 3.0) @test logsumexp([1.0, 2.0, 3.0]) ≈ 3.40760596444438 - @test logsumexp_onepass([1.0, 2.0, 3.0]) ≈ 3.40760596444438 @test logsumexp((1.0, 2.0, 3.0)) ≈ 3.40760596444438 @test logsumexp([1.0, 2.0, 3.0] .+ 1000.) ≈ 1003.40760596444438 - @test logsumexp_onepass([1.0, 2.0, 3.0] .+ 1000.) ≈ 1003.40760596444438 @test logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=1) ≈ [3.40760596444438 1003.40760596444438] @test logsumexp([[1.0 2.0 3.0]; [1.0 2.0 3.0] .+ 1000.]; dims=2) ≈ [3.40760596444438, 1003.40760596444438] @@ -117,7 +115,6 @@ end for (arguments, result) in cases @test logaddexp(arguments...) ≡ result @test logsumexp(arguments) ≡ result - @test logsumexp_onepass(arguments) ≡ result end end @@ -142,10 +139,6 @@ end @test isnan(logsumexp([NaN, Inf])) @test isnan(logsumexp([NaN, -Inf])) - @test isnan(logsumexp_onepass([NaN, 9.0])) - @test isnan(logsumexp_onepass([NaN, Inf])) - @test isnan(logsumexp_onepass([NaN, -Inf])) - # issue #63 a = logsumexp(i for i in range(-500, stop = 10, length = 1000) if true) b = logsumexp(range(-500, stop = 10, length = 1000)) From b79aceba1b76c90a404867b8b41c127ffcff9d76 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 6 Sep 2020 12:32:33 +0200 Subject: [PATCH 09/14] Apply suggestions Co-authored-by: Milan Bouchet-Valat --- src/basicfuns.jl | 2 +- test/basicfuns.jl | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index e3536a9..2fa421c 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -237,7 +237,7 @@ the data. logsumexp(X) = logsumexp_onepass(X) """ - logsumexp(X::AbstractArray{<:Real}[; dims=:]) + logsumexp(X::AbstractArray{<:Real}; dims=:) Compute `log.(sum(exp.(X); dims=dims))` in a numerically stable way that avoids intermediate over- and underflow. diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 8c47b37..f4972e0 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -1,5 +1,4 @@ using StatsFuns, Test -using StatsFuns: logsumexp_onepass @testset "xlogx & xlogy" begin @test iszero(xlogx(0)) @@ -139,10 +138,9 @@ end @test isnan(logsumexp([NaN, Inf])) @test isnan(logsumexp([NaN, -Inf])) - # issue #63 - a = logsumexp(i for i in range(-500, stop = 10, length = 1000) if true) - b = logsumexp(range(-500, stop = 10, length = 1000)) - @test a == b + # logsumexp with general iterables (issue #63) + xs = range(-500, stop = 10, length = 1000) + @test logsumexp(x for x in xs) == logsumexp(xs) end @testset "softmax" begin From 621c4f659bde482daa8234e8146dd62e5ebbcde8 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 6 Sep 2020 13:15:04 +0200 Subject: [PATCH 10/14] Add additional check for abstract types --- src/basicfuns.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 2fa421c..e6280eb 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -268,6 +268,10 @@ end # with initial element: required by CUDA function _logsumexp_onepass(X, ::Base.HasEltype) + # do not perform type computations if element type is abstract + T = eltype(X) + isconcretetype(T) || return _logsumexp_onepass(X, Base.EltypeUnknown()) + # compute initial element FT = float(eltype(X)) init = (FT(-Inf), zero(FT)) From d0aa84aba9fc6ab6660e9fb193d538ddf5416a5c Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 6 Sep 2020 13:15:44 +0200 Subject: [PATCH 11/14] Update documentation --- src/basicfuns.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index e6280eb..80b04e4 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -266,7 +266,7 @@ function logsumexp_onepass(X) return xmax + log(r) end -# with initial element: required by CUDA +# initial element is required by CUDA (ideally we would never use this method) function _logsumexp_onepass(X, ::Base.HasEltype) # do not perform type computations if element type is abstract T = eltype(X) @@ -277,13 +277,12 @@ function _logsumexp_onepass(X, ::Base.HasEltype) init = (FT(-Inf), zero(FT)) r_one = one(FT) - # perform single pass over the data return mapreduce(_logsumexp_onepass_op, X; init=init) do x return float(x), r_one end end -# without initial element +# without initial element (ideally we would always use this method) function _logsumexp_onepass(X, ::Base.EltypeUnknown) return mapreduce(_logsumexp_onepass_op, X) do x _x = float(x) @@ -291,17 +290,19 @@ function _logsumexp_onepass(X, ::Base.EltypeUnknown) end end +# all inputs are provided as floating point numbers +# `r1` and `r2` are one if `xmax1` and `xmax2` are new elements (no partial sums) function _logsumexp_onepass_op((xmax1, r1)::T, (xmax2, r2)::T) where {T<:Tuple} if xmax1 < xmax2 xmax = xmax2 a = exp(xmax1 - xmax2) - r = r2 + ifelse(isone(r1), a, r1 * a) + r = r2 + (isone(r1) ? a : r1 * a) # avoid expensive multiplication for new elements elseif xmax1 > xmax2 xmax = xmax1 a = exp(xmax2 - xmax1) - r = r1 + ifelse(isone(r2), a, r2 * a) + r = r1 + (isone(r2) ? a : r2 * a) # avoid expensive multiplication for new elements else # ensure finite values if x = xmax = ± Inf - xmax = ifelse(isnan(xmax1), xmax1, xmax2) + xmax = isnan(xmax1) ? xmax1 : xmax2 r = r1 + r2 end From 0dff349c87057b5a404d70029342449f19b01d1b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 7 Sep 2020 12:23:44 +0200 Subject: [PATCH 12/14] Use one-pass algorithm also for reductions along subsets of dimensions --- src/basicfuns.jl | 82 ++++++++++++++++++++++++++---------------------- 1 file changed, 44 insertions(+), 38 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 80b04e4..b3e24ee 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -234,7 +234,7 @@ the data. [Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html) """ -logsumexp(X) = logsumexp_onepass(X) +logsumexp(X) = _logsumexp_onepass(X) """ logsumexp(X::AbstractArray{<:Real}; dims=:) @@ -242,7 +242,7 @@ logsumexp(X) = logsumexp_onepass(X) Compute `log.(sum(exp.(X); dims=dims))` in a numerically stable way that avoids intermediate over- and underflow. -If `dims = :`, then the result is computed using a single pass over the data. +The result is computed using a single pass over the data. # References @@ -250,62 +250,68 @@ If `dims = :`, then the result is computed using a single pass over the data. """ logsumexp(X::AbstractArray{<:Real}; dims=:) = _logsumexp(X, dims) -_logsumexp(X::AbstractArray{<:Real}, ::Colon) = logsumexp_onepass(X) +_logsumexp(X::AbstractArray{<:Real}, ::Colon) = _logsumexp_onepass(X) function _logsumexp(X::AbstractArray{<:Real}, dims) # Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82) - u = reduce(max, X, dims=dims, init=oftype(log(zero(eltype(X))), -Inf)) - return u .+ log.(sum(exp.(X .- u); dims=dims)) + FT = float(eltype(X)) + xmax_r = reduce(_logsumexp_onepass_op, X; dims=dims, init=(FT(-Inf), zero(FT))) + return @. first(xmax_r) + log1p(last(xmax_r)) end -function logsumexp_onepass(X) +function _logsumexp_onepass(X) # fallback for empty collections isempty(X) && return log(sum(X)) - xmax, r = _logsumexp_onepass(X, Base.IteratorEltype(X)) - - return xmax + log(r) + return xmax + log1p(r) end -# initial element is required by CUDA (ideally we would never use this method) +# initial element is required by CUDA (otherwise we could remove this method) function _logsumexp_onepass(X, ::Base.HasEltype) # do not perform type computations if element type is abstract T = eltype(X) isconcretetype(T) || return _logsumexp_onepass(X, Base.EltypeUnknown()) - # compute initial element - FT = float(eltype(X)) - init = (FT(-Inf), zero(FT)) - r_one = one(FT) + FT = float(T) + return reduce(_logsumexp_onepass_op, X; init=(FT(-Inf), zero(FT))) +end - return mapreduce(_logsumexp_onepass_op, X; init=init) do x - return float(x), r_one - end +# without initial element (without CUDA support we could always use this method) +_logsumexp_onepass(X, ::Base.EltypeUnknown)::Tuple = reduce(_logsumexp_onepass_op, X) + +## Reductions for one-pass algorithm: avoid expensive multiplications if numbers are reduced + +# reduce two numbers +function _logsumexp_onepass_op(x1, x2) + a = x1 == x2 ? zero(x1 - x2) : -abs(x1 - x2) + xmax = x1 > x2 ? oftype(a, x1) : oftype(a, x2) + r = exp(a) + return xmax, r end -# without initial element (ideally we would always use this method) -function _logsumexp_onepass(X, ::Base.EltypeUnknown) - return mapreduce(_logsumexp_onepass_op, X) do x - _x = float(x) - return _x, one(_x) +# reduce a number and a partial sum +function _logsumexp_onepass_op(x, (xmax, r)::Tuple) + a = x == xmax ? zero(x - xmax) : -abs(x - xmax) + if x > xmax + _xmax = oftype(a, x) + _r = (r + one(r)) * exp(a) + else + _xmax = oftype(a, xmax) + _r = r + exp(a) end + return _xmax, _r end - -# all inputs are provided as floating point numbers -# `r1` and `r2` are one if `xmax1` and `xmax2` are new elements (no partial sums) -function _logsumexp_onepass_op((xmax1, r1)::T, (xmax2, r2)::T) where {T<:Tuple} - if xmax1 < xmax2 - xmax = xmax2 - a = exp(xmax1 - xmax2) - r = r2 + (isone(r1) ? a : r1 * a) # avoid expensive multiplication for new elements - elseif xmax1 > xmax2 - xmax = xmax1 - a = exp(xmax2 - xmax1) - r = r1 + (isone(r2) ? a : r2 * a) # avoid expensive multiplication for new elements - else # ensure finite values if x = xmax = ± Inf - xmax = isnan(xmax1) ? xmax1 : xmax2 - r = r1 + r2 +_logsumexp_onepass_op(xmax_r::Tuple, x) = _logsumexp_onepass_op(x, xmax_r) + +# reduce two partial sums +function _logsumexp_onepass_op((xmax1, r1)::Tuple, (xmax2, r2)::Tuple) + a = xmax1 == xmax2 ? zero(xmax1 - xmax2) : -abs(xmax1 - xmax2) + if xmax1 > xmax2 + xmax = oftype(a, xmax1) + r = r1 + (r2 + one(r2)) * exp(a) + else + xmax = oftype(a, xmax2) + r = r2 + (r1 + one(r1)) * exp(a) end - return xmax, r end From 70995c504066f368f557dd6959bdb7d852e95350 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 7 Sep 2020 14:52:35 +0200 Subject: [PATCH 13/14] Add function barrier --- src/basicfuns.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index b3e24ee..0834812 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -261,22 +261,25 @@ end function _logsumexp_onepass(X) # fallback for empty collections isempty(X) && return log(sum(X)) - xmax, r = _logsumexp_onepass(X, Base.IteratorEltype(X)) - return xmax + log1p(r) + return _logsumexp_onepass_result(_logsumexp_onepass_reduce(X, Base.IteratorEltype(X))) end -# initial element is required by CUDA (otherwise we could remove this method) -function _logsumexp_onepass(X, ::Base.HasEltype) +# function barrier for reductions with single element and without initial element +_logsumexp_onepass_result(x) = float(x) +_logsumexp_onepass_result((xmax, r)::Tuple) = xmax + log1p(r) + +# iterables with known element type +function _logsumexp_onepass_reduce(X, ::Base.HasEltype) # do not perform type computations if element type is abstract T = eltype(X) - isconcretetype(T) || return _logsumexp_onepass(X, Base.EltypeUnknown()) + isconcretetype(T) || return _logsumexp_onepass_reduce(X, Base.EltypeUnknown()) FT = float(T) return reduce(_logsumexp_onepass_op, X; init=(FT(-Inf), zero(FT))) end -# without initial element (without CUDA support we could always use this method) -_logsumexp_onepass(X, ::Base.EltypeUnknown)::Tuple = reduce(_logsumexp_onepass_op, X) +# iterables without known element type +_logsumexp_onepass_reduce(X, ::Base.EltypeUnknown) = reduce(_logsumexp_onepass_op, X) ## Reductions for one-pass algorithm: avoid expensive multiplications if numbers are reduced From 09bd2934175e24f05f43ab1c4345403e4ff772af Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 7 Sep 2020 14:52:51 +0200 Subject: [PATCH 14/14] Add more tests --- test/basicfuns.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index f4972e0..aff76fa 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -96,13 +96,18 @@ end @test logaddexp(2.0, 3.0) ≈ log(exp(2.0) + exp(3.0)) @test logaddexp(10002, 10003) ≈ 10000 + logaddexp(2.0, 3.0) - @test logsumexp([1.0, 2.0, 3.0]) ≈ 3.40760596444438 - @test logsumexp((1.0, 2.0, 3.0)) ≈ 3.40760596444438 + @test @inferred(logsumexp([1.0])) == 1.0 + @test @inferred(logsumexp((x for x in [1.0]))) == 1.0 + @test @inferred(logsumexp([1.0, 2.0, 3.0])) ≈ 3.40760596444438 + @test @inferred(logsumexp((1.0, 2.0, 3.0))) ≈ 3.40760596444438 @test logsumexp([1.0, 2.0, 3.0] .+ 1000.) ≈ 1003.40760596444438 - @test logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=1) ≈ [3.40760596444438 1003.40760596444438] - @test logsumexp([[1.0 2.0 3.0]; [1.0 2.0 3.0] .+ 1000.]; dims=2) ≈ [3.40760596444438, 1003.40760596444438] - @test logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=[1,2]) ≈ [1003.4076059644444] + @test @inferred(logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=1)) ≈ [3.40760596444438 1003.40760596444438] + @test @inferred(logsumexp([[1.0 2.0 3.0]; [1.0 2.0 3.0] .+ 1000.]; dims=2)) ≈ [3.40760596444438, 1003.40760596444438] + @test @inferred(logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=[1,2])) ≈ [1003.4076059644444] + + # check underflow + @test logsumexp([1e-20, log(1e-20)]) ≈ 2e-20 let cases = [([-Inf, -Inf], -Inf), # correct handling of all -Inf ([-Inf, -Inf32], -Inf), # promotion @@ -140,7 +145,7 @@ end # logsumexp with general iterables (issue #63) xs = range(-500, stop = 10, length = 1000) - @test logsumexp(x for x in xs) == logsumexp(xs) + @test @inferred(logsumexp(x for x in xs)) == logsumexp(xs) end @testset "softmax" begin