Skip to content

Commit

Permalink
Correct and adjust tolerances of mnmg logreg pytests (rapidsai#5812)
Browse files Browse the repository at this point in the history
* FIX correct and adjust tolerances of mnmg logreg pytests

* FIX style fixes

* Update test_dask_logistic_regression.py

Co-authored-by: Bradley Dice <bdice@bradleydice.com>

* FIX style fixes

---------

Co-authored-by: Bradley Dice <bdice@bradleydice.com>
  • Loading branch information
dantegd and bdice authored Apr 4, 2024
1 parent 669fad2 commit b9e4a60
Showing 1 changed file with 64 additions and 12 deletions.
76 changes: 64 additions & 12 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def test_lbfgs(
n_classes=2,
convert_to_sparse=False,
):
tolerance = 0.005
tolerance = 0.01 if convert_to_sparse else 0.005

def imp():
import cuml.comm.serialize # NOQA
Expand Down Expand Up @@ -394,9 +394,19 @@ def array_to_numpy(ary):

if sk_solver == "lbfgs" and standardization is False:
assert len(lr_coef) == len(sk_coef)
assert array_equal(lr_coef, sk_coef, tolerance, with_sign=True)
assert array_equal(
lr_intercept, sk_intercept, tolerance, with_sign=True
lr_coef,
sk_coef,
unit_tol=tolerance,
total_tol=tolerance,
with_sign=True,
)
assert array_equal(
lr_intercept,
sk_intercept,
unit_tol=tolerance,
total_tol=tolerance,
with_sign=True,
)

# test predict
Expand Down Expand Up @@ -783,7 +793,12 @@ def to_dask_data(X_train, X_test, y_train, y_test):
)
mgon_accuracy = accuracy_score(y_test, mgon_preds)

assert array_equal(X_train_dask.compute().to_numpy(), X_train)
assert array_equal(
X_train_dask.compute().to_numpy(),
X_train,
unit_tol=tolerance,
total_tol=tolerance,
)

# run CPU with StandardScaler
# if fit_intercept is true, mean center then scale the dataset
Expand Down Expand Up @@ -821,8 +836,18 @@ def to_dask_data(X_train, X_test, y_train, y_test):
mgon_intercept_origin = mgon.intercept_.to_numpy()

if sk_solver == "lbfgs":
assert array_equal(mgon_coef_origin, cpu.coef_, tolerance)
assert array_equal(mgon_intercept_origin, cpu.intercept_, tolerance)
assert array_equal(
mgon_coef_origin,
cpu.coef_,
unit_tol=tolerance,
total_tol=tolerance,
)
assert array_equal(
mgon_intercept_origin,
cpu.intercept_,
unit_tol=tolerance,
total_tol=tolerance,
)

# running MG with standardization=False
mgoff = cumlLBFGS_dask(
Expand All @@ -849,9 +874,17 @@ def to_dask_data(X_train, X_test, y_train, y_test):
np.abs(mgon_accuracy - mgoff_accuracy) < 1e-3
)

assert array_equal(mgon_coef_origin, mgoff.coef_.to_numpy(), tolerance)
assert array_equal(
mgon_intercept_origin, mgoff.intercept_.to_numpy(), tolerance
mgon_coef_origin,
mgoff.coef_.to_numpy(),
unit_tol=tolerance,
total_tol=tolerance,
)
assert array_equal(
mgon_intercept_origin,
mgoff.intercept_.to_numpy(),
unit_tol=tolerance,
total_tol=tolerance,
)


Expand Down Expand Up @@ -888,6 +921,8 @@ def test_standardization_example(fit_intercept, regularization, client):
"max_iter": max_iter,
}

tolerance = 0.005

X, y = make_classification_dataset(
datatype, n_rows, n_cols, n_info, n_classes=n_classes
)
Expand Down Expand Up @@ -917,16 +952,33 @@ def test_standardization_example(fit_intercept, regularization, client):
lr_off = cumlLBFGS_dask(standardization=False, **est_params)
lr_off.fit(X_df_scaled, y_df)

assert array_equal(lron_coef_origin, lr_off.coef_.to_numpy())
assert array_equal(lron_intercept_origin, lr_off.intercept_.to_numpy())
assert array_equal(
lron_coef_origin,
lr_off.coef_.to_numpy(),
unit_tol=tolerance,
total_tol=tolerance,
)
assert array_equal(
lron_intercept_origin,
lr_off.intercept_.to_numpy(),
unit_tol=tolerance,
total_tol=tolerance,
)

from cuml.linear_model import LogisticRegression as SG

sg = SG(**est_params)
sg.fit(X_scaled, y)

assert array_equal(lron_coef_origin, sg.coef_)
assert array_equal(lron_intercept_origin, sg.intercept_)
assert array_equal(
lron_coef_origin, sg.coef_, unit_tol=tolerance, total_tol=tolerance
)
assert array_equal(
lron_intercept_origin,
sg.intercept_,
unit_tol=tolerance,
total_tol=tolerance,
)


@pytest.mark.mg
Expand Down

0 comments on commit b9e4a60

Please sign in to comment.