[dask][tests] reduce code duplication in Dask tests (#3828)
StrikerRUS authored Jan 24, 2021
1 parent 5a4fec6 commit ac706e1
Showing 1 changed file with 51 additions and 47 deletions.
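
The commit applies one pattern throughout the file: the hyperparameters shared by a Dask estimator and its local reference estimator move into a single `params` dict that is unpacked with `**` at both construction sites, so the two models can no longer drift apart. A minimal, self-contained sketch of the pattern (the `dlgbm` alias and `listen_port` fixture are borrowed from the test module; the Dask half is left as a comment because it needs a running cluster):

import lightgbm
import numpy as np

# Shared hyperparameters are declared once...
params = {
    "n_estimators": 10,
    "num_leaves": 10
}

X = np.random.rand(100, 5)
y = np.random.randint(0, 2, size=100)

# ...and unpacked into the local reference estimator...
local_classifier = lightgbm.LGBMClassifier(**params)
local_classifier.fit(X, y)

# ...as well as the distributed one, which keeps only its
# Dask-specific arguments inline:
# dask_classifier = dlgbm.DaskLGBMClassifier(
#     time_out=5,
#     local_listen_port=listen_port,
#     **params
# )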
98 changes: 51 additions & 47 deletions tests/python_package_test/test_dask.py
@@ -133,19 +133,22 @@ def test_classifier(output, centers, client, listen_port):
         centers=centers
     )
 
+    params = {
+        "n_estimators": 10,
+        "num_leaves": 10
+    }
     dask_classifier = dlgbm.DaskLGBMClassifier(
         time_out=5,
         local_listen_port=listen_port,
-        n_estimators=10,
-        num_leaves=10
+        **params
     )
     dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
     p1 = dask_classifier.predict(dX)
     p1_proba = dask_classifier.predict_proba(dX).compute()
     s1 = accuracy_score(dy, p1)
     p1 = p1.compute()
 
-    local_classifier = lightgbm.LGBMClassifier(n_estimators=10, num_leaves=10)
+    local_classifier = lightgbm.LGBMClassifier(**params)
     local_classifier.fit(X, y, sample_weight=w)
     p2 = local_classifier.predict(X)
     p2_proba = local_classifier.predict_proba(X)
@@ -169,20 +172,20 @@ def test_classifier_pred_contrib(output, centers, client, listen_port):
         centers=centers
     )
 
+    params = {
+        "n_estimators": 10,
+        "num_leaves": 10
+    }
     dask_classifier = dlgbm.DaskLGBMClassifier(
         time_out=5,
         local_listen_port=listen_port,
         tree_learner='data',
-        n_estimators=10,
-        num_leaves=10
+        **params
     )
     dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
     preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute()
 
-    local_classifier = lightgbm.LGBMClassifier(
-        n_estimators=10,
-        num_leaves=10
-    )
+    local_classifier = lightgbm.LGBMClassifier(**params)
     local_classifier.fit(X, y, sample_weight=w)
     local_preds_with_contrib = local_classifier.predict(X, pred_contrib=True)
 
@@ -243,16 +246,19 @@ def test_classifier_local_predict(client, listen_port):
         output='array'
     )
 
+    params = {
+        "n_estimators": 10,
+        "num_leaves": 10
+    }
     dask_classifier = dlgbm.DaskLGBMClassifier(
         time_out=5,
         local_port=listen_port,
-        n_estimators=10,
-        num_leaves=10
+        **params
     )
     dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
     p1 = dask_classifier.to_local().predict(dX)
 
-    local_classifier = lightgbm.LGBMClassifier(n_estimators=10, num_leaves=10)
+    local_classifier = lightgbm.LGBMClassifier(**params)
     local_classifier.fit(X, y, sample_weight=w)
     p2 = local_classifier.predict(X)
 
@@ -270,20 +276,23 @@ def test_regressor(output, client, listen_port):
         output=output
     )
 
+    params = {
+        "random_state": 42,
+        "num_leaves": 10
+    }
     dask_regressor = dlgbm.DaskLGBMRegressor(
         time_out=5,
         local_listen_port=listen_port,
-        seed=42,
-        num_leaves=10,
-        tree='data'
+        tree='data',
+        **params
     )
     dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
     p1 = dask_regressor.predict(dX)
     if output != 'dataframe':
         s1 = r2_score(dy, p1)
     p1 = p1.compute()
 
-    local_regressor = lightgbm.LGBMRegressor(seed=42, num_leaves=10)
+    local_regressor = lightgbm.LGBMRegressor(**params)
     local_regressor.fit(X, y, sample_weight=w)
     s2 = local_regressor.score(X, y)
     p2 = local_regressor.predict(X)
@@ -306,20 +315,20 @@ def test_regressor_pred_contrib(output, client, listen_port):
         output=output
     )
 
+    params = {
+        "n_estimators": 10,
+        "num_leaves": 10
+    }
     dask_regressor = dlgbm.DaskLGBMRegressor(
         time_out=5,
         local_listen_port=listen_port,
         tree_learner='data',
-        n_estimators=10,
-        num_leaves=10
+        **params
     )
     dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
     preds_with_contrib = dask_regressor.predict(dX, pred_contrib=True).compute()
 
-    local_regressor = lightgbm.LGBMRegressor(
-        n_estimators=10,
-        num_leaves=10
-    )
+    local_regressor = lightgbm.LGBMRegressor(**params)
     local_regressor.fit(X, y, sample_weight=w)
     local_preds_with_contrib = local_regressor.predict(X, pred_contrib=True)
 
@@ -341,26 +350,23 @@ def test_regressor_quantile(output, client, listen_port, alpha):
         output=output
     )
 
+    params = {
+        "objective": "quantile",
+        "alpha": alpha,
+        "random_state": 42,
+        "n_estimators": 10,
+        "num_leaves": 10
+    }
     dask_regressor = dlgbm.DaskLGBMRegressor(
         local_listen_port=listen_port,
-        seed=42,
-        objective='quantile',
-        alpha=alpha,
-        n_estimators=10,
-        num_leaves=10,
-        tree_learner_type='data_parallel'
+        tree_learner_type='data_parallel',
+        **params
     )
     dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
     p1 = dask_regressor.predict(dX).compute()
     q1 = np.count_nonzero(y < p1) / y.shape[0]
 
-    local_regressor = lightgbm.LGBMRegressor(
-        seed=42,
-        objective='quantile',
-        alpha=alpha,
-        n_estimatores=10,
-        num_leaves=10
-    )
+    local_regressor = lightgbm.LGBMRegressor(**params)
     local_regressor.fit(X, y, sample_weight=w)
     p2 = local_regressor.predict(X)
     q2 = np.count_nonzero(y < p2) / y.shape[0]
@@ -377,7 +383,7 @@ def test_regressor_local_predict(client, listen_port):
 
     dask_regressor = dlgbm.DaskLGBMRegressor(
         local_listen_port=listen_port,
-        seed=42,
+        random_state=42,
         n_estimators=10,
         num_leaves=10,
         tree_type='data'
@@ -407,25 +413,23 @@ def test_ranker(output, client, listen_port, group):
 
     # use many trees + leaves to overfit, help ensure that dask data-parallel strategy matches that of
     # serial learner. See https://github.com/microsoft/LightGBM/issues/3292#issuecomment-671288210.
+    params = {
+        "random_state": 42,
+        "n_estimators": 50,
+        "num_leaves": 20,
+        "min_child_samples": 1
+    }
     dask_ranker = dlgbm.DaskLGBMRanker(
         time_out=5,
         local_listen_port=listen_port,
         tree_learner_type='data_parallel',
-        n_estimators=50,
-        num_leaves=20,
-        seed=42,
-        min_child_samples=1
+        **params
     )
     dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg, client=client)
     rnkvec_dask = dask_ranker.predict(dX)
     rnkvec_dask = rnkvec_dask.compute()
 
-    local_ranker = lightgbm.LGBMRanker(
-        n_estimators=50,
-        num_leaves=20,
-        seed=42,
-        min_child_samples=1
-    )
+    local_ranker = lightgbm.LGBMRanker(**params)
     local_ranker.fit(X, y, sample_weight=w, group=g)
     rnkvec_local = local_ranker.predict(X)
 
@@ -453,7 +457,7 @@ def test_ranker_local_predict(output, client, listen_port, group):
         tree_learner='data',
         n_estimators=10,
         num_leaves=10,
-        seed=42,
+        random_state=42,
         min_child_samples=1
     )
     dask_ranker = dask_ranker.fit(dX, dy, group=dg, client=client)
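Two details make the refactor behavior-preserving. Keyword unpacking is exactly equivalent to spelling the arguments out inline, which the scikit-learn API makes easy to verify; and the non-standard `seed` keyword is replaced by the scikit-learn-style `random_state` constructor argument, which LightGBM's parameter table lists as an alias of the same underlying seed parameter, so the fitted models should be unchanged. A quick check of the first point:

import lightgbm

params = {
    "n_estimators": 10,
    "num_leaves": 10
}

# Constructing from the unpacked dict yields an estimator with the
# same configuration as passing the arguments inline.
a = lightgbm.LGBMClassifier(**params)
b = lightgbm.LGBMClassifier(n_estimators=10, num_leaves=10)
assert a.get_params() == b.get_params()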
