We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c5ba21d commit 229b836Copy full SHA for 229b836
src/objective/multiclass_obj.cu
@@ -214,6 +214,16 @@ class SoftmaxMultiClassObj : public ObjFunction {
214
collective::SafeColl(status);
215
CHECK_GE(sum_weight, kRtEps);
216
linalg::VecScaDiv(this->ctx_, intercept, sum_weight);
217
+
218
+ double sum_intercepts = 0.;
219
+ for (std::int64_t ix = 0; ix < n_classes; ix++) {
220
+ intercept(ix) = std::log(intercept(ix));
221
+ sum_intercepts += intercept(ix);
222
+ }
223
+ const double mean_intercepts = sum_intercepts / static_cast<double>(n_classes);
224
225
+ intercept(ix) -= mean_intercepts;
226
227
}
228
229
private:
0 commit comments