Skip to content

Commit 6bb3a5d

Browse files
david-cortestrivialfis
authored andcommitted
fix intercept initialization for mnlogit
1 parent e739915 commit 6bb3a5d

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

src/objective/multiclass_obj.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,16 @@ class SoftmaxMultiClassObj : public ObjFunction {
209209
collective::SafeColl(status);
210210
CHECK_GE(sum_weight, kRtEps);
211211
linalg::VecScaDiv(this->ctx_, intercept, sum_weight);
212+
213+
double sum_intercepts = 0.;
214+
for (std::int64_t ix = 0; ix < n_classes; ix++) {
215+
intercept(ix) = std::log(intercept(ix));
216+
sum_intercepts += intercept(ix);
217+
}
218+
const double mean_intercepts = sum_intercepts / static_cast<double>(n_classes);
219+
for (std::int64_t ix = 0; ix < n_classes; ix++) {
220+
intercept(ix) -= mean_intercepts;
221+
}
212222
}
213223

214224
private:

0 commit comments

Comments
 (0)