Skip to content

Commit

Permalink
Basic tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Sep 9, 2019
1 parent 1688432 commit 61b4051
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 137 deletions.
4 changes: 3 additions & 1 deletion python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,8 @@ def _init_from_npy2d(self, mat, missing, nthread):
and type if memory use is a concern.
"""
if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray must be 2 dimensional')
raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
mat.shape)
# flatten the array by rows and ensure it is float32.
# we try to avoid data copies if possible (reshape returns a view when possible
# and we explicitly tell np.array to try and avoid copying)
Expand Down Expand Up @@ -1017,6 +1018,7 @@ def feature_types(self, feature_types):


def distributed_dispatch(worker_predict):
'''Decides whether distributed predict function should be used.'''
try:
from .distributed import dispatch_predict
return dispatch_predict(worker_predict)
Expand Down
104 changes: 61 additions & 43 deletions python-package/xgboost/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,20 @@
import logging
from threading import Thread
from .compat import distributed_get_worker
import dask
from dask.array import Array as DA
from dask.dataframe import DataFrame as DDF
from dask.dataframe import Series as DS
from dask.distributed import Client, get_client, wait
from dask.distributed import comm
try:
from dask.distributed import Client, get_client, wait
except ImportError:
Client = None
get_client = None
wait = None
try:
from dask.distributed import comm
except ImportError:
comm = None
from . import rabit
from .core import DMatrix
from .tracker import RabitTracker
Expand All @@ -28,7 +37,6 @@

def _start_tracker(host, n_workers):
""" Start Rabit tracker """
print('_start_tracker')
env = {'DMLC_NUM_WORKER': n_workers}
rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
env.update(rabit_context.slave_envs())
Expand All @@ -41,14 +49,15 @@ def _start_tracker(host, n_workers):


class RabitContext:
    '''A context controlling rabit initialization and finalization.

    Parameters
    ----------
    args:
        Arguments forwarded unchanged to ``rabit.init`` when the context is
        entered.
    '''

    def __init__(self, args):
        self.args = args

    def __enter__(self):
        rabit.init(self.args)
        logging.debug('-------------- rabit say hello ------------------')

    def __exit__(self, *args):
        # The exception triple is accepted but unused; nothing is returned, so
        # an exception raised inside the ``with`` body is not suppressed.
        rabit.finalize()
        logging.debug('--------------- rabit say bye ------------------')

Expand All @@ -67,10 +76,18 @@ def concat(L):
". Got %s" % type(L[0]))


def map_local_data(X, y, weights=None):
async def map_local_data(X, y, weights=None):
client = get_client()
X_parts = X.to_delayed()
y_parts = y.to_delayed()

if isinstance(X_parts, numpy.ndarray):
assert X_parts.shape[1] == 1
X_parts = X_parts.flatten().tolist()
if isinstance(y_parts, numpy.ndarray):
assert y_parts.ndim == 1 or y_parts.shape[1] == 1
y_parts = y_parts.flatten().tolist()

assert len(X_parts) == len(
y_parts), 'Partitions between X and y are not consistent'

Expand Down Expand Up @@ -157,7 +174,7 @@ def __init__(self,

self.X = X
self.y = y
self.weight = weight
self.weights = weight

assert type(X) in (DDF, DA)
assert type(y) in (DDF, DS, DA)
Expand All @@ -181,8 +198,8 @@ async def from_dataframe(self):
assert len(X_parts) == len(
y_parts), 'Partitions between X and y are not consistent'

if self.weight:
w_parts = self.weight.to_delayed()
if self.weights:
w_parts = self.weights.to_delayed()
assert len(X_parts) == len(
w_parts), 'Partitions between X and weight are not consistent.'
parts = list(map(delayed, zip(X_parts, y_parts, w_parts)))
Expand All @@ -203,10 +220,9 @@ async def from_dataframe(self):
for key, workers in who_has.items():
worker_map[first(workers)].append(key_to_partition[key])
self.worker_map = worker_map
print('from_dataframe.worker_map:', self.worker_map)

async def from_array(self):
    '''Populate ``self.worker_map`` from ``self.X``, ``self.y`` and
    ``self.weights``.

    ``map_local_data`` is a coroutine function, so its result has to be
    awaited; assigning the bare call would store a coroutine object instead
    of the mapping.
    '''
    self.worker_map = await map_local_data(self.X, self.y, self.weights)

def get_worker_data(self, worker):
'''Get data that is local to worker.
Expand All @@ -219,14 +235,14 @@ def get_worker_data(self, worker):
list_of_parts = self.worker_map[worker.address]
list_of_parts = [p.result() for p in list_of_parts]

if self.weight:
if self.weights:
data, labels, weights = zip(*list_of_parts)
else:
data, labels = zip(*list_of_parts)

data = concat(data)
labels = concat(labels)
if self.weight:
if self.weights:
weights = concat(weights)
else:
weights = None
Expand All @@ -240,10 +256,10 @@ def get_worker_data(self, worker):
return dmatrix

def num_row(self):
    '''Return the stored row count (``self.n_rows``) as-is.'''
    # NOTE(review): assumes n_rows is already a concrete value rather than a
    # lazy dask object -- confirm against where it is assigned.
    return self.n_rows

def num_col(self):
    '''Return the stored column count (``self.n_cols``) as-is.'''
    # NOTE(review): assumes n_cols is already a concrete value rather than a
    # lazy dask object -- confirm against where it is assigned.
    return self.n_cols


def train(worker_train,
Expand All @@ -266,7 +282,6 @@ def train(worker_train,

host = comm.get_address_host(client.scheduler.address)
worker_map = dtrain.worker_map
print('train.worker_map:', worker_map)

env = client.run_on_scheduler(_start_tracker,
host.strip('/:'),
Expand Down Expand Up @@ -297,20 +312,22 @@ def dispatched_train(worker_id):

futures = client.map(dispatched_train, range(len(worker_map)),
workers=list(worker_map.keys()))
return futures
bsts = client.gather(futures)
return list(filter(lambda bst: bst is not None, bsts))[0]


def dispatch_training(func):
'''Decides whether to use distributed training.'''
def dispatcher(params,
dtrain,
num_boost_round=10,
evals=(),
*args,
**kwargs):
if isinstance(dtrain, DaskDMatrix):
train(func, parameters=params, dtrain=dtrain,
num_boost_round=num_boost_round, evals=evals, *args,
**kwargs)
return train(func, parameters=params, dtrain=dtrain,
num_boost_round=num_boost_round, evals=evals, *args,
**kwargs)
else:
return func(params, dtrain, num_boost_round, evals, *args,
**kwargs)
Expand All @@ -319,33 +336,34 @@ def dispatcher(params,


def dispatch_predict(func):
    '''Decides whether to use distributed predict.

    Parameters
    ----------
    func: callable
        The wrapped predict function.  It is called as
        ``func(booster, data, *args, **kwargs)`` whenever `data` is not a
        DaskDMatrix.

    Returns
    -------
    dispatcher: callable
        A function with signature ``(booster, data, *args, **kwargs)`` that
        routes to the distributed path for DaskDMatrix input and to `func`
        otherwise.
    '''

    def predict(booster, data, *args, **kwargs):
        '''Predict using booster.

        Parameters
        ----------
        booster: xgboost.Booster
            A trained model.
        data: DaskDMatrix
            Input data used for prediction.
        '''
        # Runs on a worker: fetch only the partitions local to that worker.
        worker = distributed_get_worker()
        local_data = data.get_worker_data(worker)
        return booster.predict(local_data, *args, **kwargs)

    def dispatcher(booster, data, *args, **kwargs):
        if isinstance(data, DaskDMatrix):
            client: Client = get_client()

            worker_map = data.worker_map

            # One task per entry in the worker map, pinned to those workers.
            # NOTE(review): Client.map passes each element of the iterable as
            # the first positional argument of `predict`, while `booster` is
            # also given by keyword, and any extra *args here would be taken
            # as further iterables -- confirm this call shape is intended.
            futures = client.map(predict, range(len(worker_map)),
                                 workers=list(worker_map.keys()),
                                 booster=booster, data=data,
                                 *args, **kwargs)
            # NOTE(review): `futures` are distributed futures, not dask
            # arrays -- verify that dask.array.stack accepts them directly.
            prediction = dask.array.stack(futures, axis=0)
            return prediction
        else:
            return func(booster, data, *args, **kwargs)
    return dispatcher
1 change: 1 addition & 0 deletions python-package/xgboost/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def _train_internal(params, dtrain,


def distributed_dispatch(worker_train):
'''Decides whether to use distributed training.'''
try:
from .distributed import dispatch_training
return dispatch_training(worker_train)
Expand Down
146 changes: 53 additions & 93 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,53 @@
import testing as tm
import pytest
import xgboost as xgb
import numpy as np
import sys

if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)

try:
from distributed.utils_test import client, loop, cluster_fixture
import dask.dataframe as dd
import dask.array as da
except ImportError:
pass

pytestmark = pytest.mark.skipif(**tm.no_dask())


def run_train():
# Contains one label equal to rank
dmat = xgb.DMatrix([[0]], label=[xgb.rabit.get_rank()])
bst = xgb.train({"eta": 1.0, "lambda": 0.0}, dmat, 1)
pred = bst.predict(dmat)
expected_result = np.average(range(xgb.rabit.get_world_size()))
assert all(p == expected_result for p in pred)


def test_train(client):
# Train two workers, the first has label 0, the second has label 1
# If they build the model together the output should be 0.5
xgb.dask.run(client, run_train)
# Run again to check we can have multiple sessions
xgb.dask.run(client, run_train)


def run_create_dmatrix(X, y, weights):
dmat = xgb.dask.create_worker_dmatrix(X, y, weight=weights)
# Expect this worker to get two partitions and concatenate them
assert dmat.num_row() == 50


def test_dask_dataframe(client):
n = 10
m = 100
partition_size = 25
X = dd.from_array(np.random.random((m, n)), partition_size)
y = dd.from_array(np.random.random(m), partition_size)
weights = dd.from_array(np.random.random(m), partition_size)
xgb.dask.run(client, run_create_dmatrix, X, y, weights)


def test_dask_array(client):
n = 10
m = 100
partition_size = 25
X = da.random.random((m, n), partition_size)
y = da.random.random(m, partition_size)
weights = da.random.random(m, partition_size)
xgb.dask.run(client, run_create_dmatrix, X, y, weights)


def run_get_local_data(X, y):
X_local = xgb.dask.get_local_data(X)
y_local = xgb.dask.get_local_data(y)
assert (X_local.shape == (50, 10))
assert (y_local.shape == (50,))


def test_get_local_data(client):
n = 10
m = 100
partition_size = 25
X = da.random.random((m, n), partition_size)
y = da.random.random(m, partition_size)
xgb.dask.run(client, run_get_local_data, X, y)


def run_sklearn():
# Contains one label equal to rank
X = [[0]]
y = [xgb.rabit.get_rank()]
model = xgb.XGBRegressor(learning_rate=1.0)
model.fit(X, y)
pred = model.predict(X)
expected_result = np.average(range(xgb.rabit.get_world_size()))
assert all(p == expected_result for p in pred)
return pred


def test_sklearn(client):
result = xgb.dask.run(client, run_sklearn)
print(result)
import testing as tm
import pytest
import xgboost as xgb
import sys

if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)

try:
from distributed.utils_test import client, loop, cluster_fixture
import dask.dataframe as dd
import dask.array as da
from xgboost.distributed import DaskDMatrix
except ImportError:
pass

pytestmark = pytest.mark.skipif(**tm.no_dask())


def test_dask_dataframe(c):
    '''Train on a DaskDMatrix built from dask dataframes and check that
    prediction yields a dask array.'''
    # NOTE(review): the fixture imported at the top of this file is named
    # `client`; confirm a fixture called `c` is actually provided.
    n = 10
    m = 1000
    partition_size = 20

    X = da.random.random((m, n), partition_size)
    y = da.random.random(m, partition_size)

    # Convert the random dask arrays into dask dataframe inputs.
    X = dd.from_dask_array(X)
    y = dd.from_array(y)

    dtrain = DaskDMatrix(X, y)
    bst: xgb.Booster = xgb.train({}, dtrain, num_boost_round=2)

    prediction = bst.predict(dtrain)

    assert isinstance(prediction, da.Array)


def test_from_array(c):
    '''Train on a DaskDMatrix built directly from dask arrays and check that
    prediction yields a dask array.'''
    # NOTE(review): the fixture imported at the top of this file is named
    # `client`; confirm a fixture called `c` is actually provided.
    n = 10
    m = 1000
    partition_size = 20
    X = da.random.random((m, n), partition_size)
    y = da.random.random(m, partition_size)

    dtrain = DaskDMatrix(X, y)
    bst: xgb.Booster = xgb.train({}, dtrain)

    prediction = bst.predict(dtrain)

    assert isinstance(prediction, da.Array)

0 comments on commit 61b4051

Please sign in to comment.