ENH more stable gradient of CrossEntropy #6327
Conversation
Thanks for this! I'll defer to @shiyu1994 and @guolinke to review. Until then, can you please update this to the latest…
Thanks for the contribution. I just left a few comments about the correctness of the hessian computation.
src/objective/xentropy_objective.hpp (Outdated)

```cpp
if (score[i] > -37.0) {
  const double exp_tmp = std::exp(-score[i]);
  gradients[i] = static_cast<score_t>(((1.0f - label_[i]) - label_[i] * exp_tmp) / (1.0f + exp_tmp));
  hessians[i] = static_cast<score_t>(exp_tmp / (1 + exp_tmp) * (1 + exp_tmp));
```
The `/` followed by a `*` simply returns `exp_tmp`, which is not the expected hessian.
```diff
- hessians[i] = static_cast<score_t>(exp_tmp / (1 + exp_tmp) * (1 + exp_tmp));
+ hessians[i] = static_cast<score_t>(exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp)));
```
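To make the precedence issue concrete, here is a minimal standalone check (my own sketch, not code from this PR) showing that the unparenthesized expression reduces to `exp_tmp`, while the suggested fix matches the sigmoid hessian `p * (1 - p)`:

```cpp
// Minimal check (not part of the PR): with left-to-right evaluation,
// a / b * b is (a / b) * b, which is just a up to rounding.
#include <cmath>
#include <cstdio>

int main() {
  const double score = 2.0;
  const double exp_tmp = std::exp(-score);
  const double buggy = exp_tmp / (1 + exp_tmp) * (1 + exp_tmp);    // evaluates to exp_tmp
  const double fixed = exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp));  // intended hessian
  const double p = 1.0 / (1.0 + exp_tmp);                          // sigmoid(score)
  std::printf("buggy   = %.17g\n", buggy);                         // ~0.135 (== exp(-2))
  std::printf("fixed   = %.17g\n", fixed);                         // ~0.105
  std::printf("p*(1-p) = %.17g\n", p * (1.0 - p));                 // agrees with fixed up to rounding
}
```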
Is it possible that `hessians[i] = static_cast<score_t>(exp_tmp / (1 + exp_tmp) / (1 + exp_tmp));` could be more numerically stable?
First of all, yes, I forgot the parentheses. Thanks for spotting it. It is surprising that all the tests still pass (with this bug).

As for stability: `exp_tmp / (1 + exp_tmp) / (1 + exp_tmp)` is more numerically stable in the sense that it could prevent overflow. But in this branch `exp_tmp < exp(37) ≈ 1e16`, and squaring that stays well within even single-precision range (max ≈ 3e38); note that `exp_tmp` is double precision anyway.
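As a quick sanity check of those magnitudes, a small standalone sketch (mine, not code from this PR), using the `-37.0` threshold from the diff above:

```cpp
// Sketch backing up the bound: in the score[i] > -37.0 branch,
// exp_tmp = exp(-score[i]) < exp(37) ~ 1.17e16, so (1 + exp_tmp)^2
// (~1.37e32) fits comfortably even in single precision (FLT_MAX ~ 3.4e38).
#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
  const double exp_tmp = std::exp(37.0);            // worst case in this branch
  const double sq = (1 + exp_tmp) * (1 + exp_tmp);  // ~1.37e32, no overflow
  std::printf("exp(37)         = %g\n", exp_tmp);
  std::printf("(1 + exp(37))^2 = %g\n", sq);
  std::printf("FLT_MAX         = %g\n", static_cast<double>(FLT_MAX));
  std::printf("hessian at -37  = %g\n", exp_tmp / sq);  // finite, ~8.5e-17
}
```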
Thanks. Could you also fix the hessian calculation in the else branch?
```cpp
} else {
  const double exp_tmp = std::exp(score[i]);
  gradients[i] = static_cast<score_t>(exp_tmp - label_[i]);
  hessians[i] = static_cast<score_t>(exp_tmp);
```
```diff
- hessians[i] = static_cast<score_t>(exp_tmp);
+ hessians[i] = static_cast<score_t>(exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp)));
```
This is not needed: here `exp_tmp < 1e-16` is tiny, so `(1 + exp_tmp)` is just 1. Stated differently, the implemented formula is the first-order Taylor series in `exp_tmp`.
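A tiny standalone sketch (mine, not from this PR) confirming that rounding behavior at the branch boundary:

```cpp
// Sketch illustrating the else branch (score[i] <= -37.0):
// exp_tmp = exp(score[i]) <= exp(-37) ~ 8.5e-17 is below half an ulp of 1.0,
// so 1.0 + exp_tmp rounds to exactly 1.0 and the full formula
// exp_tmp / (1 + exp_tmp)^2 returns exp_tmp bit-for-bit.
#include <cmath>
#include <cstdio>

int main() {
  const double exp_tmp = std::exp(-37.0);  // ~8.5e-17 < DBL_EPSILON / 2
  const double full = exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp));
  std::printf("1 + exp_tmp == 1 ? %d\n", (1.0 + exp_tmp) == 1.0);  // prints 1
  std::printf("full == exp_tmp ? %d\n", full == exp_tmp);          // prints 1
}
```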
I see. That makes sense.
But maybe it would still be better to write the original calculation formula explicitly to avoid ambiguity?
What do you mean by "ambiguity"?
Writing the formula out explicitly would not avoid the branch, and the simplified form is a tiny bit more efficient.
Similar to scikit-learn/scikit-learn#28048: there is a small runtime cost to pay, but gradient computation is not the main bottleneck of histogram-based gradient boosting.