Numerically stable log_sigmoid
#1548
Conversation
LGTM
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

```
@@           Coverage Diff            @@
##             main    #1548    +/-  ##
========================================
  Coverage   86.35%   86.36%
========================================
  Files         682      683      +1
  Lines       77849    77898     +49
========================================
+ Hits        67230    67280     +50
+ Misses      10619    10618      -1
========================================
```

☔ View full report in Codecov by Sentry.
pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    /// To avoid overflow, we use the log-sum-exp trick.
    ///
    /// ```ignore
    /// log(sigmoid(x)) = log(1/(1 + exp(-x)))
    ///                 = log(1) - log(1 + exp(-x))
    ///                 = -log(1 + exp(-x))
    ///                 = -log(exp(0) + exp(-x))
    /// ```
    ///
    /// The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
    /// subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
    /// following equivalence:
    ///
    /// ```ignore
    /// log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
    /// ```
    ///
    /// This extends the range of values for which we obtain accurate results.
    fn numerically_stable_log_sigmoid<const D: usize, B: Backend>(x: Tensor<B, D>) -> Tensor<B, D> {
        // max(-x, 0)
        let max_elem = x.clone().neg().max_pair(x.zeros_like());

        // log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
        let z = (max_elem.clone().neg().exp() + (x.neg() - max_elem.clone()).exp()).log();

        z.neg() - max_elem
    }

    match B::FloatElem::precision() {
        Precision::Half => {
            let tensor_full = tensor.into_full_precision();
            let tensor_tmp = numerically_stable_log_sigmoid(tensor_full);
            Tensor::from_full_precision(tensor_tmp)
        }
        _ => numerically_stable_log_sigmoid(tensor),
    }
}
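To see why the shifted form stays finite, it helps to split on the sign of `x`. This case analysis isn't part of the PR, but it follows directly from the formula in the doc comment above:

```ignore
x >= 0:  max(-x, 0) = 0   =>  -(0 + log(exp(0) + exp(-x)))  = -log(1 + exp(-x))
x <  0:  max(-x, 0) = -x  =>  -(-x + log(exp(x) + exp(0)))  =  x - log(1 + exp(x))
```

In both branches every `exp` argument is non-positive, so nothing can overflow, and the two arms combine into `min(x, 0) - log(1 + exp(-|x|))`, which, as far as I can tell, is the same closed form the PyTorch kernel computes.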
I think a decent speedup for backends that don't implement fusion would be to move `log_sigmoid` and `sigmoid` into `burn_tensor::ops::activation` with a default implementation provided. We could then override those activations in backends that don't support fusion, such as tch and candle.
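As a rough illustration of that suggestion, here is a minimal sketch of the default-method pattern. The trait and backend names below are hypothetical stand-ins, not Burn's actual `ActivationOps` signatures:

```rust
// Hypothetical sketch: a trait with a default composed implementation that
// individual backends can override with a single fused kernel.
trait ActivationOpsSketch {
    // Default: built from primitive ops; correct everywhere, but dispatched
    // as several kernels on backends without fusion.
    fn log_sigmoid(x: f32) -> f32 {
        let max_elem = (-x).max(0.0);
        -(max_elem + ((-max_elem).exp() + (-x - max_elem).exp()).ln())
    }
}

struct ComposedBackend;
impl ActivationOpsSketch for ComposedBackend {} // inherits the default

struct FusedBackend;
impl ActivationOpsSketch for FusedBackend {
    fn log_sigmoid(x: f32) -> f32 {
        // Stand-in body: a backend like tch or candle would forward to its
        // own fused log_sigmoid kernel here instead of recomputing.
        let max_elem = (-x).max(0.0);
        -(max_elem + ((-max_elem).exp() + (-x - max_elem).exp()).ln())
    }
}
```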
Agreed, I just tackled the scope of the current `log_sigmoid` implementation, but that definitely came to mind. Btw, `sigmoid` is already in `ActivationOps`, just not `log_sigmoid` yet. Should I tackle this in a new PR or expand this one?
LGTM
While making progress on a fine-tuning classification example, I stumbled upon an issue with our `log_sigmoid` implementation, which returned `-inf` for large negative values. I first attempted the common log-sum-exp trick, which resulted in an implementation that worked on wgpu but gave me NaNs for large values near the min and max on ndarray. That's when I stumbled upon the PyTorch implementation, which goes a step further, as implemented in this PR.
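To make the failure mode concrete, here is a minimal scalar sketch comparing the two formulas. It uses plain `f32` rather than Burn tensors, and the helper names are mine:

```rust
fn naive_log_sigmoid(x: f32) -> f32 {
    // -log(1 + exp(-x)): exp(-x) overflows f32 once -x exceeds roughly 88.7.
    -(1.0 + (-x).exp()).ln()
}

fn stable_log_sigmoid(x: f32) -> f32 {
    // -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
    let max_elem = (-x).max(0.0);
    -(max_elem + ((-max_elem).exp() + (-x - max_elem).exp()).ln())
}

fn main() {
    let x = -100.0_f32;
    println!("naive:  {}", naive_log_sigmoid(x)); // -inf: exp(100) overflows
    println!("stable: {}", stable_log_sigmoid(x)); // ≈ -100.0
}
```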
Checklist

- [x] The `run-checks all` script has been executed.

Changes

Changed our `log_sigmoid` implementation to be numerically stable for large values.

Testing

Added unit tests for `log_sigmoid`.
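For reference, a hedged sketch of the kind of edge-case test this covers, written against the plain-`f32` helper from the sketch above rather than Burn's actual test harness:

```rust
fn stable_log_sigmoid(x: f32) -> f32 {
    let max_elem = (-x).max(0.0);
    -(max_elem + ((-max_elem).exp() + (-x - max_elem).exp()).ln())
}

#[test]
fn log_sigmoid_stays_finite_at_extremes() {
    for &x in &[-300.0_f32, -100.0, 0.0, 100.0, 300.0] {
        let y = stable_log_sigmoid(x);
        assert!(y.is_finite(), "log_sigmoid({x}) was not finite");
        assert!(y <= 0.0, "log of a probability must be non-positive");
    }
    // log(sigmoid(0)) = ln(0.5)
    assert!((stable_log_sigmoid(0.0) - 0.5_f32.ln()).abs() < 1e-6);
}
```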