Skip to content

Commit

Permalink
[dask] Fix missing value for scikit-learn interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 20, 2020
1 parent 3cf665d commit 824b351
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 12 deletions.
14 changes: 10 additions & 4 deletions doc/tutorials/dask.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,14 @@ Basic functionalities including training and generating predictions for regressi
classification are implemented. But there are still some other limitations we haven't
addressed yet.

- Label encoding for Scikit-Learn classifier.
- Ranking
- Label encoding for Scikit-Learn classifier may not be supported. Meaning that user need
to encode their training labels into discrete values first.
- Ranking is not supported right now.
- Empty worker is not well supported by classifier. If the training hangs for classifier
with a warning about empty DMatrix, please consider balancing your data first. But
regressor works fine with empty DMatrix.
- Callback functions are not tested.
- To use cross validation one needs to explicitly train different models instead of using
a functional API like ``xgboost.cv``.
- Only ``GridSearchCV`` from ``scikit-learn`` is supported for dask interface. Meaning
that we can distribute data among workers but have to train one model at a time. If you
want to scale up grid searching with model parallelism by ``dask-ml``, please consider
using normal ``scikit-learn`` interface like `xgboost.XGBRegressor` for now.
23 changes: 15 additions & 8 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ def map_function(func):
return predictions


def _evaluation_matrices(client, validation_set, sample_weights):
def _evaluation_matrices(client, validation_set, sample_weights, missing):
'''
Parameters
----------
Expand All @@ -597,7 +597,8 @@ def _evaluation_matrices(client, validation_set, sample_weights):
for i, e in enumerate(validation_set):
w = (sample_weights[i]
if sample_weights is not None else None)
dmat = DaskDMatrix(client=client, data=e[0], label=e[1], weight=w)
dmat = DaskDMatrix(client=client, data=e[0], label=e[1], weight=w,
missing=missing)
evals.append((dmat, 'validation_{}'.format(i)))
else:
evals = None
Expand Down Expand Up @@ -672,10 +673,12 @@ def fit(self,
verbose=True):
_assert_dask_support()
dtrain = DaskDMatrix(client=self.client,
data=X, label=y, weight=sample_weights)
data=X, label=y, weight=sample_weights,
missing=self.missing)
params = self.get_xgb_params()
evals = _evaluation_matrices(self.client,
eval_set, sample_weight_eval_set)
eval_set, sample_weight_eval_set,
self.missing)

results = train(self.client, params, dtrain,
num_boost_round=self.get_num_boosting_rounds(),
Expand All @@ -688,7 +691,8 @@ def fit(self,

def predict(self, data): # pylint: disable=arguments-differ
_assert_dask_support()
test_dmatrix = DaskDMatrix(client=self.client, data=data)
test_dmatrix = DaskDMatrix(client=self.client, data=data,
missing=self.missing)
pred_probs = predict(client=self.client,
model=self.get_booster(), data=test_dmatrix)
return pred_probs
Expand All @@ -711,7 +715,8 @@ def fit(self,
verbose=True):
_assert_dask_support()
dtrain = DaskDMatrix(client=self.client,
data=X, label=y, weight=sample_weights)
data=X, label=y, weight=sample_weights,
missing=self.missing)
params = self.get_xgb_params()

# pylint: disable=attribute-defined-outside-init
Expand All @@ -728,7 +733,8 @@ def fit(self,
params["objective"] = "binary:logistic"

evals = _evaluation_matrices(self.client,
eval_set, sample_weight_eval_set)
eval_set, sample_weight_eval_set,
self.missing)
results = train(self.client, params, dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals, verbose_eval=verbose)
Expand All @@ -739,7 +745,8 @@ def fit(self,

def predict(self, data): # pylint: disable=arguments-differ
_assert_dask_support()
test_dmatrix = DaskDMatrix(client=self.client, data=data)
test_dmatrix = DaskDMatrix(client=self.client, data=data,
missing=self.missing)
pred_probs = predict(client=self.client,
model=self.get_booster(), data=test_dmatrix)
return pred_probs
52 changes: 52 additions & 0 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,58 @@ def test_from_dask_array():
assert np.all(single_node_predt == from_arr.compute())


def test_dask_missing_value_reg():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
X_0 = np.ones((20 // 2, kCols))
X_1 = np.zeros((20 // 2, kCols))
X = np.concatenate([X_0, X_1], axis=0)
np.random.shuffle(X)
X = da.from_array(X)
X = X.rechunk(20, 1)
y = da.random.randint(0, 3, size=20)
y.rechunk(20)
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2,
missing=0.0)
regressor.client = client
regressor.set_params(tree_method='hist')
regressor.fit(X, y, eval_set=[(X, y)])
dd_predt = regressor.predict(X).compute()

np_X = X.compute()
np_predt = regressor.get_booster().predict(
xgb.DMatrix(np_X, missing=0.0))
np.testing.assert_allclose(np_predt, dd_predt)


def test_dask_missing_value_cls():
# Multi-class doesn't handle empty DMatrix well. So we use lesser workers.
with LocalCluster(n_workers=2) as cluster:
with Client(cluster) as client:
X_0 = np.ones((kRows // 2, kCols))
X_1 = np.zeros((kRows // 2, kCols))
X = np.concatenate([X_0, X_1], axis=0)
np.random.shuffle(X)
X = da.from_array(X)
X = X.rechunk(20, None)
y = da.random.randint(0, 3, size=kRows)
y = y.rechunk(20, 1)
cls = xgb.dask.DaskXGBClassifier(verbosity=1, n_estimators=2,
tree_method='hist',
missing=0.0)
cls.client = client
cls.fit(X, y, eval_set=[(X, y)])
dd_predt = cls.predict(X).compute()

np_X = X.compute()
np_predt = cls.get_booster().predict(
xgb.DMatrix(np_X, missing=0.0))
np.testing.assert_allclose(np_predt, dd_predt)

cls = xgb.dask.DaskXGBClassifier()
assert hasattr(cls, 'missing')


def test_dask_regressor():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
Expand Down

0 comments on commit 824b351

Please sign in to comment.