
log1pexp #37

Merged
merged 38 commits into JuliaStats:master from cossio:log1pexp on Mar 12, 2022

Conversation

@cossio (Contributor Author) commented Mar 7, 2022

Benchmarks.

using BenchmarkTools, LogExpFunctions
log1pexp_old(x::Float64) = x < 18 ? log1p(exp(x)) : x < 33.3 ? x + exp(-x) : oftype(exp(-x), x) # before this PR
log1pexp_new(x::Float64) = x ≤ -37 ? exp(x) : x ≤ 18 ? log1p(exp(x)) : x ≤ 33.3 ? x + exp(-x) : float(x) # this PR


julia> @benchmark log1pexp_old(x) setup=(x=(rand() - 0.5) * 100) samples=10^6                                                                                                                                                             
BenchmarkTools.Trial: 240558 samples with 998 evaluations.                  
 Range (min … max):   9.675 ns … 136.391 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     19.892 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   19.025 ns ±   7.230 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

    ▁█                       ▂▄                   ▅             
  ▃███▆▃▁▁▃▁▃▄█▂▂▁▁▁▁▁▁▁▅▃▆▆▅██▁▃▁▁▁▁▁▁▁▂▁▃▁▄▂▁▆▁▁█▁▁▂▁▁▁▁▁▁▁▁ ▂
  9.68 ns         Histogram: frequency by time         35.2 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.            


julia> @benchmark log1pexp_new(x) setup=(x=(rand() - 0.5) * 100) samples=10^6                                                                                                                                                             
BenchmarkTools.Trial: 281295 samples with 1000 evaluations.                 
 Range (min … max):   2.794 ns … 343.354 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     18.832 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   16.156 ns ±   8.600 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █                               ▄ ▄                           
  █▂▁▁▁▁▁▁▁▁▁▁▁▆▇▂▇▆▃▂▂▁▁▁▁▁▁▁▁▁▁▁█▂█▆▁▃▁▂▁▁▁▁▁▁▁▁█▁█▁▄▁▂▁▁▂▁▁ ▂
  2.79 ns         Histogram: frequency by time         32.1 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

@cossio (Contributor Author) commented Mar 7, 2022

Benchmark comparing with the oftype(...) version (instead of float(...)) suggested by @tpapp.

log1pexp_oftype(x::Float64) = x ≤ -37 ? exp(x) : x ≤ 18 ? log1p(exp(x)) : x ≤ 33.3 ? x + exp(-x) : oftype(exp(x), x)


julia> @benchmark log1pexp_new(x) setup=(x=(rand() - 0.5) * 100) samples=10^6                                                                                                                                                             
BenchmarkTools.Trial: 272085 samples with 1000 evaluations.
 Range (min … max):   2.460 ns …   1.768 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     19.383 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   16.689 ns ± 11.719 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █                ▁                 ▃▂                        
  █▄▂▂▂▂▂▂▂▂▂▂▁▂▇▅▄█▇▅▄▂▂▂▂▂▂▂▂▂▂▂▂▆▃██▂▇▃▆▂▂▂▂▂▂▂▂▅▂█▂▇▂▅▃▂▄ ▃
  2.46 ns         Histogram: frequency by time        31.2 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.


julia> @benchmark log1pexp_oftype(x) setup=(x=(rand() - 0.5) * 100) samples=10^6                                                                                                                                                          
BenchmarkTools.Trial: 233832 samples with 998 evaluations.
 Range (min … max):   9.697 ns … 833.891 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     20.884 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   19.594 ns ±   8.968 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

     ▅█▇▆               ▆ █ ▄          ▁ ▂                      
  ▂▅▇█████▃▁▁▁▁▁▁▁▁▁▁▂▄▆█▂█▅█▂▃▁▁▁▁▁▂▅▁█▁█▁▇▄▁▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  9.7 ns          Histogram: frequency by time         39.7 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

The oftype(...) version seems slower, probably because of the extra exp(x) call needed only to obtain the return type.
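A minimal sketch of the difference (illustrative only; whether the compiler elides the pure exp call is exactly what the benchmark above probes):

# For x::Float64, `float` is the identity, so the final branch does no work:
branch_float(x::Float64) = float(x)            # returns x unchanged
# `oftype(exp(x), x)` evaluates exp(x) merely to learn its return type:
branch_oftype(x::Float64) = oftype(exp(x), x)  # may pay for an unused exp(x)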

@tpapp (Collaborator) left a comment

LGTM.

@devmotion (Member) left a comment

Can you update the ChainRules tests as well? IIRC there we also test every branch (a sketch follows below).

My main concern is whether there is a licensing issue, since this is based on (the vignette of) Rmpfr, which uses the GPL. I would prefer if someone who's more familiar with licenses (I'm definitely not) could confirm that there are no problems here and we don't have to use the GPL.
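A hedged sketch of such per-branch tests (illustrative; the repository's actual tests may differ):

using ChainRulesTestUtils, LogExpFunctions
# one probe point per branch of the Float64 implementation
for x in (-40.0, -10.0, 20.0, 40.0)
    test_frule(log1pexp, x)
    test_rrule(log1pexp, x)
end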

@cossio (Contributor Author) commented Mar 7, 2022

The ChainRules tests fail for Float32. But I noticed that logistic itself has no tests for Float32.

Are we sure logistic(::Float32) is accurate enough?
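A hedged way to probe this (hypothetical helper, not part of the PR; compares against a BigFloat reference rounded to Float32):

using LogExpFunctions
# maximum error of logistic(::Float32), in ulps, over [-20, 20]
function max_ulp_error()
    maxerr = 0.0
    for x in -20f0:0.01f0:20f0
        ref = Float32(inv(1 + exp(-big(x))))  # correctly rounded reference
        maxerr = max(maxerr, abs(logistic(x) - ref) / eps(ref))
    end
    return maxerr
end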

@cossio (Contributor Author) commented Mar 7, 2022

I unified the log1pexp implementations for Float32 and Float64. Here is how I computed the branch bounds (approximately) for the different approximations:

# uses the identity log1p(exp(x)) = log1p(exp(-|x|)) + max(x, 0)
function log1pexp(x::Real)
    t = log1p(exp(-abs(x)))
    return x ≤ 0 ? t : t + x
end
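(As an aside, a quick illustration of why this symmetric form is used: the naive form overflows for large x.)

julia> log1p(exp(800.0))  # naive form: exp overflows
Inf

julia> log1pexp(800.0)    # symmetric form stays finite
800.0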

# Scan downward from 0 to find where `exp(x)` becomes exact (to the nearest float):
xs = 0:-0.01:-100
for T in (Float16, Float32, Float64)
    for x in xs
        correct = T(log1pexp(big(x)))
        if iszero(correct - exp(T(x)))
            println("Found crossing with `exp(x)` at x = $x for $T, in the interval $(extrema(xs))")
            break
        end
    end
end

# Scan upward from 0 to find where `x + exp(-x)` becomes exact:
xs = 0:0.01:100
for T in (Float16, Float32, Float64)
    for x in xs
        correct = T(log1pexp(big(x)))
        if iszero(correct - (T(x) + exp(-T(x))))
            println("Found crossing with `x + exp(-x)` at x = $x for $T, in the interval $(extrema(xs))")
            break
        end
    end
end

# Scan upward from 0 to find where plain `x` becomes exact:
xs = 0:0.01:100
for T in (Float16, Float32, Float64)
    for x in xs
        if iszero(T(log1pexp(big(x))) - T(x))
            println("Found crossing with `x` at x = $x for $T, in the interval $(extrema(xs))")
            break
        end
    end
end

I removed the ChainRules tests for Float32, since the chain rule itself is not specialized (it relies on logistic).
The accuracy of the specialized log1pexp should instead be tested directly on log1pexp (see below).
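A hedged sketch of such a direct accuracy test (illustrative; the PR's actual tests may differ, and the tolerances here are deliberately loose):

using Test, LogExpFunctions
@testset "log1pexp accuracy" begin
    for T in (Float16, Float32, Float64), x in -100.0:0.5:100.0
        # reference: evaluate in BigFloat, then round to T
        @test log1pexp(T(x)) ≈ T(log1p(exp(big(x)))) atol = floatmin(T) rtol = 2eps(T)
    end
end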

@cossio force-pushed the log1pexp branch 2 times, most recently from 844564f to 5328ccb on March 7, 2022
@cossio (Contributor Author) commented Mar 7, 2022

@devmotion Note that the current log1pexp(x::Real) (on master, not in this PR) can give wrong results because nothing restricts x to Float64. For example:

julia> log1pexp(Float16(16)) # current master
Inf16

julia> log1pexp(Float16(16)) # after this PR
Float16(16.0)

This happens because master uses branch bounds tailored to Float64: for Float16, exp(Float16(16)) already overflows to Inf16 (e^16 ≈ 8.9e6 exceeds floatmax(Float16) = 65504), so the log1p(exp(x)) branch returns Inf16.

Does this example persuade you?

I agree this is a different issue, but I could try to fix it also here. I'd say that fixing wrong results takes priority over a (minor) performance regression.

@cossio (Contributor Author) commented Mar 7, 2022

As to the performance issue, let's compare the current implementation on master with the generic log1pexp(x::Real) of this PR:

using BenchmarkTools

# generic fallback in this PR
function log1pexp_new(x::Real)
    t = log1p(exp(-abs(x)))
    return x ≤ 0 ? t : t + x
end

# current master
log1pexp_old(x::Real) = x < 18.0 ? log1p(exp(x)) : x < 33.3 ? x + exp(-x) : oftype(exp(-x), x)


julia> @benchmark log1pexp_old(x) setup=(x=(rand() - 0.5) * 100) samples=10^6                                                                                                                                                             
BenchmarkTools.Trial: 278000 samples with 999 evaluations.     
 Range (min … max):   9.042 ns … 56.144 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     17.880 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   16.501 ns ±  6.132 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▆█▇▆▄▃▃ ▂▁▇▆▄▃▃▁▂▃▂▁    ▃▆█▅▇▄▅▁▃ ▂▂▁▃  ▁   ▃ ▅▇ ▆▂ ▅▁ ▃  ▂ ▄
  ████████████████████▇▇▆▇███████████████▇█▆▇▇█▆██▇██▇██▇██▆█ █
  9.04 ns      Histogram: log(frequency) by time      29.1 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.                                                                                                                                                                                            


julia> @benchmark log1pexp_new(x) setup=(x=(rand() - 0.5) * 100) samples=10^6                                                                                                                                                             
BenchmarkTools.Trial: 218739 samples with 997 evaluations.
 Range (min … max):  12.472 ns … 1.042 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     21.477 ns             ┊ GC (median):    0.00%
 Time  (mean ± σ):   21.094 ns ± 8.230 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

                          ▇ █                                 
  ▁▃▄▄▂▅▇▁▂▁▁▁▁▁▁▁▂▁▇▂▇▄▇▂█▁█▂▂▂▁▂▁▁▂▂▁▁▂▁▁▂▁▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁ ▂                                                                                                                                                                            
  12.5 ns        Histogram: frequency by time          36 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

@devmotion (Member) commented:

No, it doesn't convince me. The Float16 issue is already fixed by the PR. IMO there's no reason to change the fallback in this PR; it's a separate issue that has to be handled carefully to avoid performance regressions. This PR is about adding optimizations/fixes for Float64, Float32, and Float16. LogExpFunctions is a central package in the Julia ecosystem, so we should ensure that, at least for common floating-point types other than Float64, downstream packages see no sudden surprising regressions. This deserves a separate issue, PR, and possibly downstream changes, and should not be part of this PR.

@cossio (Contributor Author) commented Mar 7, 2022

I simplified the implementation once more.

Updated benchmarks:

julia> @benchmark log1pexp(x) setup=(x=(rand() - 0.5) * 100) samples=10^6        # after this PR                                                                                                                                                          
BenchmarkTools.Trial: 261384 samples with 1000 evaluations.
 Range (min … max):   2.439 ns … 57.839 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     20.283 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   17.496 ns ±  7.741 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▇                                ▃ ▅█ ▄                      
  █▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▂▄▄▅▄▃▃▂▂▂▂▂▂▂▂▂▂█▂██▂█▅▆▄▂▂▂▂▂▃▃▂▃▂▂▃▂▃▂▂▂ ▃
  2.44 ns         Histogram: frequency by time        32.3 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

whereas the current master takes about 9 ns at minimum (see the histograms above).

@cossio (Contributor Author) commented Mar 9, 2022

Unfortunately this version:

@inline function _log1pexp_thresholds(x::Real)
    prec = precision(x)
    logtwo = oftype(x, IrrationalConstants.logtwo)
    x0 = -prec * logtwo
    x1 = (prec - 1) * logtwo / 2
    x2 = -x0 - log(-x0) * (1 + 1 / x0) # approximate root of e^-x == x * ϵ/2 via asymptotics of Lambert's W function
    return (x0, x1, x2)
end

is not compiled away on Julia 1.0; @code_typed shows a lot of generated code. Julia 1.6 is much better, but the computation is still not completely elided. Amazing how Julia 1.7 kills it 😄

This supports hard-coding the thresholds for the standard float types (see the check below). Not sure whether we should worry about regressions on Julia 1.0 / 1.6, though.
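For reference, a quick way to inspect the constant folding (a sketch; `thresholds_probe` is just a hypothetical wrapper, and the printed IR depends on the Julia version):

using InteractiveUtils  # for @code_typed outside the REPL
# On Julia ≥ 1.7 this should collapse to returning a constant tuple;
# on older versions the threshold arithmetic remains visible in the typed IR.
thresholds_probe() = _log1pexp_thresholds(1.0)
@code_typed thresholds_probe()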

cossio and others added 3 commits on March 9, 2022 (co-authored by David Widmann)
@devmotion (Member) left a comment

I think this looks good; I only have a question left about the Float64 and Float16 values (and I think it would also be fine to make the hardcoded values consistent with the generic fallback definition, but I don't have a strong opinion there).

In general, I now think the PR should be an improvement for all inputs x where float(x) is a fixed-precision number, since then (at least on recent Julia versions) the compiler can optimize away the thresholds even for non-standard types. It's only a bit problematic for variable-precision numbers: in contrast to the @generated version, the results will be correct, but since the thresholds have to be recomputed every time, it might cause performance regressions (see the snippet below). IMO this is a bit annoying but much better than silently returning wrong results, and it can be fixed in the same way as done here for BigFloat.

Can you update the version number as well?
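For context on why the thresholds cannot simply be hard-coded for BigFloat: its precision is a runtime property (a minimal demonstration using only Base APIs):

julia> precision(BigFloat)  # default precision
256

julia> setprecision(BigFloat, 512) do
           precision(big(1.0))
       end
512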

@cossio (Contributor Author) commented Mar 9, 2022

@devmotion I attach here the calculation of the thresholds I am using in this PR.

log1pexp.pdf

@devmotion (Member) commented:

It seems the approximations in your notes are slightly different from the values in the PR?

julia> log(eps(Float64))
-36.04365338911715

julia> -log(2*eps(Float64)) / 2
17.675253104278607

@cossio (Contributor Author) commented Mar 9, 2022

The epsilon in the notes is eps(T) / 2, to ensure correct rounding to the nearest float (see the quick check below).
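Plugging ε = eps(T)/2 into the same formulas (a quick check; values rounded):

log(eps(Float64) / 2)   # ≈ -36.7368  (vs log(eps(Float64))      ≈ -36.0437)
-log(eps(Float64)) / 2  # ≈  18.0218  (vs -log(2*eps(Float64))/2 ≈  17.6753)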

@cossio changed the title from "log1pexp(x) for x < -37" to "log1pexp" on Mar 9, 2022
@cossio (Contributor Author) commented Mar 11, 2022

Merge?
Or are we waiting for JuliaDiff/ForwardDiff.jl#580 first?

@tpapp (Collaborator) commented Mar 11, 2022

My understanding is that the current solution is broken without that PR, so for now I would suggest waiting for it.

@cossio, in the meantime, can you please add a @test_broken for ForwardDiff.Duals to verify this? We can simply change it once the related PR is merged.

@devmotion (Member) commented Mar 11, 2022

No, it's not broken; we don't need the PR: #37 (comment). It would still be good to add tests for it, I think (a sketch follows below).

I was only waiting in case @tpapp had some additional comments. I summarized mine above and think the change is worth it, even though it might cause performance regressions for variable-precision number types.
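A hedged sketch of such a test (hypothetical; assumes ForwardDiff can differentiate through the implementation, which is what the linked discussion is about):

using Test, ForwardDiff, LogExpFunctions
@testset "log1pexp with ForwardDiff.Dual" begin
    # d/dx log1pexp(x) == logistic(x); probe a point in each branch
    for x in (-40.0, -10.0, 0.5, 20.0, 40.0)
        @test ForwardDiff.derivative(log1pexp, x) ≈ logistic(x)
    end
end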

@tpapp (Collaborator) commented Mar 11, 2022

@devmotion: thanks for the clarification. In that case, please feel free to merge.

@devmotion merged commit 4e50c54 into JuliaStats:master on Mar 12, 2022
@cossio deleted the log1pexp branch on March 12, 2022 at 00:40
Successfully merging this pull request may close these issues: log1pexp(x) when x < -37