Skip to content

Commit

Permalink
support one label training
Browse files Browse the repository at this point in the history
  • Loading branch information
lijinf2 committed Nov 14, 2023
1 parent 5ab3cae commit 2adfc37
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 2 deletions.
2 changes: 1 addition & 1 deletion cpp/src/glm/qn/mg/qn_mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ inline void qn_fit_x_mg(const raft::handle_t& handle,

switch (pams.loss) {
case QN_LOSS_LOGISTIC: {
ASSERT(C == 2, "qn_mg.cuh: logistic loss invalid C");
ASSERT(C > 0, "qn_mg.cuh: logistic loss invalid C");
ML::GLM::detail::LogisticLoss<T> loss(handle, D, pams.fit_intercept);
ML::GLM::opg::qn_fit_mg<T, decltype(loss)>(
handle, pams, loss, X, y, Z, w0_data, f, num_iters, n_samples, rank, n_ranks);
Expand Down
9 changes: 8 additions & 1 deletion python/cuml/dask/linear_model/logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,13 @@ def _func_fit(f, data, n_rows, n_cols, partsToSizes, rank):
for p in partsToSizes:
aggregated_partsToSizes[p[0]][1] += p[1]

return f.fit(
ret_status = f.fit(
[(inp_X, inp_y)], n_rows, n_cols, aggregated_partsToSizes, rank
)

if len(f.classes_) == 1:
raise ValueError(
f"This solver needs samples of at least 2 classes in the data, but the data contains only one class: {f.classes_[0]}"
)

return ret_status
4 changes: 4 additions & 0 deletions python/cuml/linear_model/logistic_regression_mg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
else:
self._num_classes_dim = self._num_classes

if solves_classification and self._num_classes == 1:
self._num_classes_dim = 1

if self.fit_intercept:
coef_size = (self.n_cols + 1, self._num_classes_dim)
else:
Expand Down Expand Up @@ -207,6 +210,7 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
self._num_classes = len(self.classes_)
self.loss = "sigmoid" if self._num_classes <= 2 else "softmax"
self.prepare_for_fit(self._num_classes)

cdef uintptr_t mat_coef_ptr = self.coef_.ptr

cdef qn_params qnpams = self.solver_model.qnparams.params
Expand Down
24 changes: 24 additions & 0 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,3 +618,27 @@ def test_sparse_nlp20news(dtype, nlp_20news, client):

cuml_score = accuracy_score(y_test, preds.tolist())
assert cuml_score >= acceptable_score


@pytest.mark.parametrize("fit_intercept", [False, True])
def test_exception_one_label(fit_intercept, client):
    """Check that fitting on single-class labels fails with a clear error.

    Builds a four-sample dataset whose labels are all ``1.0`` and verifies
    that both the dask multi-GPU LogisticRegression and scikit-learn's
    LogisticRegression reject it with the same "at least 2 classes" message.

    Parameters
    ----------
    fit_intercept : bool
        Forwarded to both estimators (parametrized over False and True).
    client : fixture
        Dask client fixture passed through to ``_prep_training_data``.
    """
    # Local imports, matching the style this test already used.
    import re

    from cuml.dask.linear_model import LogisticRegression as cumlLBFGS_dask
    from sklearn.linear_model import LogisticRegression

    n_parts = 2
    datatype = "float32"

    X = np.array([(1, 2), (1, 3), (2, 1), (3, 1)], datatype)
    # Every sample carries the same label, so only one class is present.
    y = np.array([1.0, 1.0, 1.0, 1.0], datatype)
    X_df, y_df = _prep_training_data(client, X, y, n_parts)

    # ``pytest.raises(match=...)`` treats its argument as a regular
    # expression, so the fixed part of the message is escaped to match
    # literally. The class label is matched loosely because its rendering may
    # vary between libraries/NumPy versions (e.g. ``1.0`` vs
    # ``np.float32(1.0)``).
    err_prefix = (
        "This solver needs samples of at least 2 classes in the data, "
        "but the data contains only one class:"
    )
    err_msg = re.escape(err_prefix) + r".*1\.0"

    mg = cumlLBFGS_dask(fit_intercept=fit_intercept, verbose=6)
    # NOTE(review): the worker-side check raises ValueError; the dask layer is
    # presumably expected to surface it on the client as RuntimeError -- confirm.
    with pytest.raises(RuntimeError, match=err_msg):
        mg.fit(X_df, y_df)

    lr = LogisticRegression(fit_intercept=fit_intercept)
    with pytest.raises(ValueError, match=err_msg):
        lr.fit(X, y)

0 comments on commit 2adfc37

Please sign in to comment.