ENH more stable gradient of CrossEntropy #6327
Changes from 4 commits
@@ -75,21 +75,54 @@ class CrossEntropy: public ObjectiveFunction {
   }

   void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
+    // z = expit(score) = 1 / (1 + exp(-score))
+    // gradient = z - label = expit(score) - label
+    // Numerically more stable, see http://fa.bianp.net/blog/2019/evaluate_logistic/
+    //   if score < 0:
+    //     exp_tmp = exp(score)
+    //     return ((1 - label) * exp_tmp - label) / (1 + exp_tmp)
+    //   else:
+    //     exp_tmp = exp(-score)
+    //     return ((1 - label) - label * exp_tmp) / (1 + exp_tmp)
+    // Note that optimal speed would be achieved, at the cost of precision, by
+    //   return expit(score) - y_true
+    // i.e. no "if else" and a custom inline implementation of expit.
+    // The case distinction score < 0 in the stable implementation does not
+    // provide significantly better precision, apart from protecting against overflow of exp(..).
+    // The branch (if else), however, can incur runtime costs of up to 30%.
+    // Instead, we help branch prediction by almost always ending up in the first if clause
+    // and making the second branch (else) a bit simpler. This has the exact same
+    // precision but is faster than the stable implementation.
+    // As branching criterion, we use the same cutoff as in log1pexp, see link above.
+    // Note that the maximal value to get gradient = -1 with label = 1 is -37.439198610162731
+    // (based on mpmath), and scipy.special.logit(np.finfo(float).eps) ~ -36.04365.
     if (weights_ == nullptr) {
       // compute pointwise gradients and Hessians with implied unit weights
       #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
       for (data_size_t i = 0; i < num_data_; ++i) {
-        const double z = 1.0f / (1.0f + std::exp(-score[i]));
-        gradients[i] = static_cast<score_t>(z - label_[i]);
-        hessians[i] = static_cast<score_t>(z * (1.0f - z));
+        if (score[i] > -37.0) {
+          const double exp_tmp = std::exp(-score[i]);
+          gradients[i] = static_cast<score_t>(((1.0f - label_[i]) - label_[i] * exp_tmp) / (1.0f + exp_tmp));
+          hessians[i] = static_cast<score_t>(exp_tmp / (1 + exp_tmp) * (1 + exp_tmp));
+        } else {
+          const double exp_tmp = std::exp(score[i]);
+          gradients[i] = static_cast<score_t>(exp_tmp - label_[i]);
+          hessians[i] = static_cast<score_t>(exp_tmp);
Review thread:
> Suggested change: […]
> This is not needed as […]
> I see. That makes sense.
> But maybe it would still be better to write the original calculation formula explicitly to avoid ambiguity?
> What do you mean with "ambiguity"?
+        }
       }
     } else {
       // compute pointwise gradients and Hessians with given weights
       #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
       for (data_size_t i = 0; i < num_data_; ++i) {
-        const double z = 1.0f / (1.0f + std::exp(-score[i]));
-        gradients[i] = static_cast<score_t>((z - label_[i]) * weights_[i]);
-        hessians[i] = static_cast<score_t>(z * (1.0f - z) * weights_[i]);
+        if (score[i] > -37.0) {
+          const double exp_tmp = std::exp(-score[i]);
+          gradients[i] = static_cast<score_t>(((1.0f - label_[i]) - label_[i] * exp_tmp) / (1.0f + exp_tmp) * weights_[i]);
+          hessians[i] = static_cast<score_t>(exp_tmp / (1 + exp_tmp) * (1 + exp_tmp) * weights_[i]);
+        } else {
+          const double exp_tmp = std::exp(score[i]);
+          gradients[i] = static_cast<score_t>((exp_tmp - label_[i]) * weights_[i]);
+          hessians[i] = static_cast<score_t>(exp_tmp * weights_[i]);
+        }
       }
     }
   }
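For reference, the quantity being computed: with z = expit(score) = 1 / (1 + exp(-score)), the gradient of the cross-entropy loss w.r.t. the raw score is z - label and the hessian is z * (1 - z). Below is a minimal standalone C++ sketch of the same branching scheme; the function name StableGradHess and the plain double interface are illustrative, not part of this PR, and the hessian line uses the corrected parenthesization discussed in the review thread further down.

```cpp
#include <cmath>
#include <cstdio>

// Stable gradient/hessian of cross-entropy w.r.t. the raw score:
// gradient = expit(score) - label, hessian = expit(score) * (1 - expit(score)).
// The cutoff -37.0 matches the one used for log1pexp (see the blog post linked
// in the diff); below it, expit(score) equals exp(score) to double precision.
void StableGradHess(double score, double label, double* grad, double* hess) {
  if (score > -37.0) {
    // exp_tmp = exp(-score) stays below exp(37) ~ 1.2e16, so no overflow.
    const double exp_tmp = std::exp(-score);
    *grad = ((1.0 - label) - label * exp_tmp) / (1.0 + exp_tmp);
    *hess = exp_tmp / ((1.0 + exp_tmp) * (1.0 + exp_tmp));  // note the parentheses
  } else {
    // Here expit(score) ~ exp(score), so both formulas simplify.
    const double exp_tmp = std::exp(score);
    *grad = exp_tmp - label;
    *hess = exp_tmp;
  }
}

int main() {
  double g, h;
  for (double s : {-50.0, -1.0, 0.0, 1.0, 50.0}) {
    StableGradHess(s, /*label=*/1.0, &g, &h);
    std::printf("score=%6.1f  grad=%.17g  hess=%.17g\n", s, g, h);
  }
}
```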
Review thread on the hessian computation:
> The / followed by a * simply returns exp_tmp, which is not the expected hessian.
> Is it possible that hessians[i] = static_cast<score_t>(exp_tmp / (1 + exp_tmp) / (1 + exp_tmp)); could be more numerically stable?
> First of all, yes, I forgot the parentheses. Thanks for spotting it. It is surprising that all the tests still pass (with this bug). As for (exp_tmp / (1 + exp_tmp) / (1 + exp_tmp)): it is more numerically stable in the sense that it could prevent overflow. But exp_tmp < exp(37) ≈ 1e16, squaring that is within even single precision range (max ~3e38), and note that exp_tmp is in double precision anyway.
> Thanks. Could you also fix the hessian calculation in the else branch?