We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e739915 commit 6bb3a5dCopy full SHA for 6bb3a5d
src/objective/multiclass_obj.cu
@@ -209,6 +209,16 @@ class SoftmaxMultiClassObj : public ObjFunction {
209
collective::SafeColl(status);
210
CHECK_GE(sum_weight, kRtEps);
211
linalg::VecScaDiv(this->ctx_, intercept, sum_weight);
212
+
213
+ double sum_intercepts = 0.;
214
+ for (std::int64_t ix = 0; ix < n_classes; ix++) {
215
+ intercept(ix) = std::log(intercept(ix));
216
+ sum_intercepts += intercept(ix);
217
+ }
218
+ const double mean_intercepts = sum_intercepts / static_cast<double>(n_classes);
219
220
+ intercept(ix) -= mean_intercepts;
221
222
}
223
224
private:
0 commit comments