Skip to content

Commit

Permalink
Remove return in test_lbfgs (#5875)
Browse files Browse the repository at this point in the history
Similar to #5819, pytest functions should not `return` and seems unneeded in this test

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #5875
  • Loading branch information
mroeschke authored Apr 30, 2024
1 parent 1609fcd commit 89d5371
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,14 +303,7 @@ def assert_params(
)


@pytest.mark.mg
@pytest.mark.parametrize("nrows", [1e5])
@pytest.mark.parametrize("ncols", [20])
@pytest.mark.parametrize("n_parts", [2, 23])
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("datatype", [np.float32])
@pytest.mark.parametrize("delayed", [True, False])
def test_lbfgs(
def _test_lbfgs(
nrows,
ncols,
n_parts,
Expand Down Expand Up @@ -428,9 +421,25 @@ def array_to_numpy(ary):
return lr


@pytest.mark.mg
@pytest.mark.parametrize("n_parts", [2, 23])
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("delayed", [True, False])
def test_lbfgs(n_parts, fit_intercept, delayed, client):
_test_lbfgs(
nrows=1e5,
ncols=20,
n_parts=n_parts,
fit_intercept=fit_intercept,
datatype=np.float32,
delayed=delayed,
client=client,
)


@pytest.mark.parametrize("fit_intercept", [False, True])
def test_noreg(fit_intercept, client):
lr = test_lbfgs(
lr = _test_lbfgs(
nrows=1e5,
ncols=20,
n_parts=23,
Expand Down Expand Up @@ -494,7 +503,7 @@ def assert_small(X, y, n_classes):
@pytest.mark.parametrize("n_classes", [8])
def test_n_classes(n_parts, fit_intercept, n_classes, client):
nrows = int(1e5) if n_classes < 5 else int(2e5)
lr = test_lbfgs(
lr = _test_lbfgs(
nrows=nrows,
ncols=20,
n_parts=n_parts,
Expand All @@ -517,7 +526,7 @@ def test_n_classes(n_parts, fit_intercept, n_classes, client):
@pytest.mark.parametrize("C", [1.0, 10.0])
def test_l1(fit_intercept, datatype, delayed, n_classes, C, client):
nrows = int(1e5) if n_classes < 5 else int(2e5)
lr = test_lbfgs(
lr = _test_lbfgs(
nrows=nrows,
ncols=20,
n_parts=2,
Expand Down Expand Up @@ -545,7 +554,7 @@ def test_elasticnet(
fit_intercept, datatype, delayed, n_classes, l1_ratio, client
):
nrows = int(1e5) if n_classes < 5 else int(2e5)
lr = test_lbfgs(
lr = _test_lbfgs(
nrows=nrows,
ncols=20,
n_parts=2,
Expand Down Expand Up @@ -585,7 +594,7 @@ def test_sparse_from_dense(

nrows = int(1e5) if n_classes < 5 else int(2e5)
run_test = partial(
test_lbfgs,
_test_lbfgs,
nrows=nrows,
ncols=20,
n_parts=2,
Expand Down Expand Up @@ -699,7 +708,7 @@ def test_standardization_on_normal_dataset(
nrows = int(1e5) if n_classes < 5 else int(2e5)

# test correctness compared with scikit-learn
test_lbfgs(
_test_lbfgs(
nrows=nrows,
ncols=20,
n_parts=2,
Expand Down

0 comments on commit 89d5371

Please sign in to comment.