-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve accuracy of logistic #94
Conversation
Impressive notes! |
Thanks! Ok to merge this? |
Calling float specific functions like |
Ok, I'll make that change. |
Do you have an example of failing with |
I used |
function logistic(x::Real) | ||
e = exp(x) | ||
lower, upper = _logistic_bounds(x) | ||
ifelse( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like a quite convoluted way of writing it over just if
, else
etc?
function logistic(x::Real)
e = exp(x)
lower, upper = _logistic_bounds(x)
if x < lower
return zero(x)
elseif x > upper
return one(x)
else
return e / (one(x) + e)
end
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably worth checking @code_llvm
; it's true that LLVM now does the conversion of if-else
=> ifelse
automatically pretty well these days.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They seem to produce different output at the @code_llvm
and @code_native
stages consistent with my earlier experience that ifelse
generates LLVM select
and if
generates LLVM br
, but benchmarking suggests their performance is very similar in scalar applications and broadcasting: https://gist.github.com/johnmyleswhite/548f4eb18a028a237d52ae06811ca33c
Note that the benchmarks also suggest that changing from the old implementation to the new implementation causes performance to drop because of processing subnormals in a mirror image: the old formulation was fast for z = -710.0
and slow for z = +710.0
, whereas the new formulation is slow for z = -710.0
and fast for z = +710.0
.
Happy to choose whichever form people prefer. I tend to defer to the side of assuming select
is more likely to play well with SIMD, but it doesn't seem like that issue applies in this case for the benchmarks I've run.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Happy to choose whichever form people prefer.
Personally, I don't really care, just felt like the if
version was easier to read and I couldn't find a benchmark where it was slower. I guess if you correctly predict the branch version is faster and otherwise the branchless is faster? 🤷♂️
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I think in the absence of useful evidence either way I'll maintain my existing superstitious faith in the importance of the LLVM select
instruction. Revising in the future would be trivial.
Made changes requested by @andreasnoack. Didn't add test to avoid adding more dependencies. |
Any objections to merging? I have the permissions to do it for myself, but given my general disengagement from the repo, it doesn't feel fair for me to use those permissions without approval from those who are more engaged than I am. |
Can this be simplified to this? function logistic(x::Real)
t = exp(-abs(x))
ifelse(x ≥ 0, inv(one(t) + t), t / (one(t) + t))
end This simple version passes all the tests given in this PR. If there is no test that can distinguish then I guess I can make a PR with this simpler version (which automatically supports e.g. BigFloats). What do you think? |
Using the (very nice!) function julia> evaluate_errors(logistic, range(-744.4400719213812, -log(floatmax(Float64)), length=10_000))
Frequency of Exact Results: 0.9964
Average Error: 0.0
Maximum Error: 5.0e-324
Average Number of Incorrect Bits: 0.0079
julia> evaluate_errors(logistic, range(-log(floatmax(Float64)), -log(eps(1.0)), length=10_000))
Frequency of Exact Results: 0.8654
Average Error: 2.727970920400209e-18
Maximum Error: 1.1102230246251565e-16
Average Number of Incorrect Bits: 0.2609
julia> evaluate_errors(logistic, range(-log(eps(1.0)), 36.7368005696771, length=10_000))
Frequency of Exact Results: 0.415
Average Error: 6.494804694057165e-17
Maximum Error: 1.1102230246251565e-16
Average Number of Incorrect Bits: 0.585 In the last two regions the implementation of this PR is better. I'm gonna leave this comment here anyway for future reference. |
Change the logistic function to increase accuracy for subnormal values (e.g.
logistic(-740.0)
). See my notes for extended details.