Commit

address review comments

lijinf2 committed Nov 28, 2023
1 parent 08ec639 commit 2cc76f5
Showing 3 changed files with 42 additions and 29 deletions.
1 change: 0 additions & 1 deletion python/cuml/linear_model/base_mg.pyx
@@ -63,7 +63,6 @@ class MGFitMixin(object):
             check_dtype = self.dtype
 
             if sparse_input:
-
                 X_m = SparseCumlArray(input_data[i][0], convert_index=np.int32)
                 _, self.n_cols = X_m.shape
             else:
8 changes: 2 additions & 6 deletions python/cuml/linear_model/logistic_regression_mg.pyx
@@ -170,13 +170,10 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
                                  "with softmax (multinomial).")
 
         if solves_classification and not solves_multiclass:
-            self._num_classes_dim = self._num_classes - 1
+            self._num_classes_dim = 1
         else:
             self._num_classes_dim = self._num_classes
 
-        if solves_classification and self._num_classes == 1:
-            self._num_classes_dim = 1
-
         if self.fit_intercept:
             coef_size = (self.n_cols + 1, self._num_classes_dim)
         else:
@@ -188,7 +185,6 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
 
     def fit(self, input_data, n_rows, n_cols, parts_rank_size, rank, convert_dtype=False):
-
         self.rank = rank
         assert len(input_data) == 1, f"Currently support only one (X, y) pair in the list. Received {len(input_data)} pairs."
         self.is_col_major = False
         order = 'F' if self.is_col_major else 'C'
@@ -215,7 +211,7 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
 
         cdef qn_params qnpams = self.solver_model.qnparams.params
 
-        sparse_input = True if isinstance(X, list) else False
+        sparse_input = isinstance(X, list)
 
         if self.dtype == np.float32:
             if sparse_input is False:
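For readers skimming the diff, here is a minimal standalone sketch (not cuml code; the helper name is hypothetical) of the consolidated coefficient-dimension logic above: binary (sigmoid) classification now always uses a single coefficient column, which also covers the single-class case that previously required its own branch.

def num_classes_dim(num_classes, solves_classification, solves_multiclass):
    """Number of coefficient columns for the QN solver (sketch)."""
    if solves_classification and not solves_multiclass:
        # Binary (sigmoid) fits one column; this also covers the
        # single-class case that previously needed a separate branch.
        return 1
    # Softmax (multinomial) and regression keep one column per class.
    return num_classes

# The old code returned num_classes - 1 here, then patched the
# num_classes == 1 case back to 1; both paths now agree:
assert num_classes_dim(2, True, False) == 1   # binary, old: 2 - 1 == 1
assert num_classes_dim(1, True, False) == 1   # single class, old special case
assert num_classes_dim(8, True, True) == 8    # multinomial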
62 changes: 40 additions & 22 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
@@ -339,12 +339,12 @@ def imp():
         datatype, nrows, ncols, n_info, n_classes=n_classes
     )
 
-    if convert_to_sparse is False:
-        # X_dask and y_dask are dask cudf
-        X_dask, y_dask = _prep_training_data(client, X, y, n_parts)
-    else:
+    if convert_to_sparse:
         # X_dask and y_dask are dask array
         X_dask, y_dask = _prep_training_data_sparse(client, X, y, n_parts)
+    else:
+        # X_dask and y_dask are dask cudf
+        X_dask, y_dask = _prep_training_data(client, X, y, n_parts)
 
     lr = cumlLBFGS_dask(
         solver="qn",
@@ -557,30 +557,48 @@ def test_elasticnet(
         ("elasticnet", 2.0, 0.2),
     ],
 )
-@pytest.mark.parametrize("datatype", [np.float32])
+@pytest.mark.parametrize("datatype", [np.float32, np.float64])
 @pytest.mark.parametrize("delayed", [True])
 @pytest.mark.parametrize("n_classes", [2, 8])
 def test_sparse_from_dense(
     fit_intercept, regularization, datatype, delayed, n_classes, client
 ):
-    penalty = regularization[0]
-    C = regularization[1]
-    l1_ratio = regularization[2]
+    penalty, C, l1_ratio = regularization
 
-    test_lbfgs(
-        nrows=1e5,
-        ncols=20,
-        n_parts=2,
-        fit_intercept=fit_intercept,
-        datatype=datatype,
-        delayed=delayed,
-        client=client,
-        penalty=penalty,
-        n_classes=n_classes,
-        C=C,
-        l1_ratio=l1_ratio,
-        convert_to_sparse=True,
-    )
+    if datatype == np.float32:
+        test_lbfgs(
+            nrows=1e5,
+            ncols=20,
+            n_parts=2,
+            fit_intercept=fit_intercept,
+            datatype=datatype,
+            delayed=delayed,
+            client=client,
+            penalty=penalty,
+            n_classes=n_classes,
+            C=C,
+            l1_ratio=l1_ratio,
+            convert_to_sparse=True,
+        )
+    else:
+        with pytest.raises(
+            RuntimeError,
+            match="dtypes other than float32 are currently not supported yet. See issue: https://github.com/rapidsai/cuml/issues/5589",
+        ):
+            test_lbfgs(
+                nrows=1e5,
+                ncols=20,
+                n_parts=2,
+                fit_intercept=fit_intercept,
+                datatype=datatype,
+                delayed=delayed,
+                client=client,
+                penalty=penalty,
+                n_classes=n_classes,
+                C=C,
+                l1_ratio=l1_ratio,
+                convert_to_sparse=True,
+            )
 
 
 @pytest.mark.parametrize("dtype", [np.float32])
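The new test follows a common pytest pattern: run the supported dtype normally and assert that the unsupported dtype raises with a matching message. A minimal self-contained sketch of that pattern, with a hypothetical fit_sparse stand-in for the cuml call under test:

import numpy as np
import pytest


def fit_sparse(datatype):
    # Hypothetical stand-in: mimics a backend that only supports float32.
    if datatype != np.float32:
        raise RuntimeError(
            "dtypes other than float32 are currently not supported yet. "
            "See issue: https://github.com/rapidsai/cuml/issues/5589"
        )
    return "fitted"


@pytest.mark.parametrize("datatype", [np.float32, np.float64])
def test_fit_sparse_dtype(datatype):
    if datatype == np.float32:
        assert fit_sparse(datatype) == "fitted"
    else:
        # match= is a regex searched against the message; a distinctive
        # substring is enough.
        with pytest.raises(RuntimeError, match="other than float32"):
            fit_sparse(datatype)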
