Skip to content

Commit 229b836

Browse files
david-cortestrivialfis
authored andcommitted
fix intercept initialization for mnlogit
1 parent c5ba21d commit 229b836

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

src/objective/multiclass_obj.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,16 @@ class SoftmaxMultiClassObj : public ObjFunction {
214214
collective::SafeColl(status);
215215
CHECK_GE(sum_weight, kRtEps);
216216
linalg::VecScaDiv(this->ctx_, intercept, sum_weight);
217+
218+
double sum_intercepts = 0.;
219+
for (std::int64_t ix = 0; ix < n_classes; ix++) {
220+
intercept(ix) = std::log(intercept(ix));
221+
sum_intercepts += intercept(ix);
222+
}
223+
const double mean_intercepts = sum_intercepts / static_cast<double>(n_classes);
224+
for (std::int64_t ix = 0; ix < n_classes; ix++) {
225+
intercept(ix) -= mean_intercepts;
226+
}
217227
}
218228

219229
private:

0 commit comments

Comments
 (0)