Skip to content
This repository has been archived by the owner on Jul 16, 2021. It is now read-only.

Aligns training and testing data #33

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions dask_xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
sparse = False
ss = False

import dask
from dask import delayed
from dask.distributed import wait, default_client
import dask.dataframe as dd
Expand Down Expand Up @@ -107,6 +108,7 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
--------
train
"""

# Break apart Dask.array/dataframe into chunks/parts
data_parts = data.to_delayed()
label_parts = labels.to_delayed()
Expand Down Expand Up @@ -158,6 +160,58 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
raise gen.Return(result)


def compute_array_chunks(arr):
    """Materialize the chunk sizes along axis 0 of a dask Array.

    Intended for arrays whose axis-0 chunk sizes are unknown (NaN), e.g.
    arrays derived from a dask DataFrame via ``.values``.

    Parameters
    ----------
    arr : dask.array.Array

    Returns
    -------
    tuple of int
        The length of each chunk along the first axis.
    """
    assert isinstance(arr, da.Array)
    parts = arr.to_delayed()
    # to_delayed() returns an ndarray of Delayed objects for chunked
    # arrays; flatten it into a plain list of parts.
    if isinstance(parts, np.ndarray):
        parts = parts.flatten().tolist()
    # Batch all chunk-length computations into a single graph execution
    # instead of one (serial) .compute() round-trip per chunk.
    chunks = dask.compute(*[part.shape[0] for part in parts])
    return tuple(chunks)


def align_training_data(client, data, labels):
    """Aligns training data and labels.

    Re-chunks/re-partitions ``labels`` along axis 0 so that its
    chunk/partition sizes match those of ``data``, which is required for
    pairing feature and label parts on the same worker.  ``labels`` is
    returned unchanged when it is already aligned.

    Parameters
    ----------
    client: dask.distributed.Client
    data: dask Array or dask DataFrame
        Training features
    labels: dask Array or dask DataFrame
        Training target

    Returns
    -------
    data : dask Array or dask DataFrame
    labels : dask Array or dask DataFrame

    Raises
    ------
    TypeError
        If ``data`` is neither a dask Array nor a dask DataFrame.
    """
    with dask.config.set(scheduler=client):
        # Compute data chunk/partition sizes
        if isinstance(data, dd._Frame):
            data_chunks = tuple(data.map_partitions(len).compute())
        elif isinstance(data, da.Array):
            # Unknown (NaN) chunk sizes must be materialized before they
            # can be compared against or used as a rechunk target.
            if any(np.isnan(sum(c)) for c in data.chunks):
                data_chunks = compute_array_chunks(data)
            else:
                data_chunks = data.chunks[0]
        else:
            # Fail fast: previously an unsupported type fell through and
            # produced a confusing NameError on data_chunks below.
            raise TypeError("'data' must be a dask Array or dask "
                            "DataFrame, got %r" % type(data))

        # Re-chunk/partition labels to match data.
        # Only rechunk if there is a size mismatch between data and labels.
        if isinstance(labels, dd._Frame):
            labels_arr = labels.to_dask_array(lengths=True)
            if labels_arr.chunks != (data_chunks,):
                labels_arr = labels_arr.rechunk({0: data_chunks})
                labels = labels_arr.to_dask_dataframe()
            # else: leave labels untouched — the array/DataFrame
            # round-trip is lossy (index, column names) and would also
            # break identity for already-aligned inputs.
        elif isinstance(labels, da.Array):
            if any(np.isnan(sum(c)) for c in labels.chunks):
                labels_chunks = compute_array_chunks(labels)
                # NOTE: mutating the private ``_chunks`` attribute records
                # the now-known sizes without rewriting the task graph.
                labels._chunks = (labels_chunks,)
            # ``labels.chunks`` is a tuple of per-axis chunk tuples, so
            # compare its first axis against the flat ``data_chunks``.
            # (Comparing the whole tuple was always unequal, forcing a
            # redundant rechunk even when sizes already matched.)
            if labels.chunks[0] != data_chunks:
                labels = labels.rechunk({0: data_chunks})

    return data, labels


def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
""" Train an XGBoost model on a Dask Cluster

Expand Down Expand Up @@ -187,6 +241,7 @@ def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
--------
predict
"""
data, labels = align_training_data(client, data, labels)
return client.sync(_train, client, params, data,
labels, dmatrix_kwargs, **kwargs)

Expand Down
53 changes: 53 additions & 0 deletions dask_xgboost/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from distributed.utils_test import gen_cluster, loop, cluster # noqa

import dask_xgboost as dxgb
from dask_xgboost.core import align_training_data

# Workaround for conflict with distributed 1.23.0
# https://github.com/dask/dask-xgboost/pull/27#issuecomment-417474734
Expand Down Expand Up @@ -158,6 +159,23 @@ def test_basic(c, s, a, b):
assert ((predictions > 0.5) != labels).sum() < 2


@pytest.mark.parametrize('X, y', [ # noqa
    # Every DataFrame/Array combination of features (chunksize=5) and
    # target (chunksize=6), i.e. deliberately mismatched partition sizes.
    (dd.from_pandas(df, chunksize=5),
     dd.from_pandas(labels, chunksize=6)),
    (dd.from_pandas(df, chunksize=5).values,
     dd.from_pandas(labels, chunksize=6)),
    (dd.from_pandas(df, chunksize=5),
     dd.from_pandas(labels, chunksize=6).values),
    (dd.from_pandas(df, chunksize=5).values,
     dd.from_pandas(labels, chunksize=6).values),
])
def test_unequal_partition_lengths(loop, X, y): # noqa
    """Fitting succeeds when features and target have unequal
    partition/chunk sizes.

    The test passes if ``fit`` raises nothing — presumably because
    ``train`` aligns the inputs via ``align_training_data`` first
    (see core.py); confirm against the trained model if stronger
    guarantees are needed.
    """
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop):
            clf = dxgb.XGBClassifier()
            clf.fit(X, y)


@gen_cluster(client=True, timeout=None, check_new_threads=False)
def test_dmatrix_kwargs(c, s, a, b):
xgb.rabit.init() # workaround for "Doing rabit call after Finalize"
Expand Down Expand Up @@ -269,3 +287,38 @@ def f(part):
yield dxgb.train(c, param, df, df.x)

assert 'foo' in str(info.value)


def test_align_training_data_dataframe(loop): # noqa
    """``align_training_data`` re-partitions a dask DataFrame target so
    its partition lengths match the feature DataFrame's, while keeping
    both outputs as dask DataFrames.
    """
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:
            # Deliberately mismatched partition sizes (5 vs 6).
            X = dd.from_pandas(df, chunksize=5)
            y = dd.from_pandas(labels, chunksize=6)

            # Sanity check: the fixtures really are misaligned before
            # the call under test.
            X_partition_lengths = tuple(X.map_partitions(len).compute())
            y_partition_lengths = tuple(y.map_partitions(len).compute())
            assert X_partition_lengths != y_partition_lengths

            X_align, y_align = align_training_data(client, X, y)
            # Both outputs remain dask DataFrames (no silent conversion
            # to dask Arrays) ...
            assert isinstance(X_align, dd._Frame)
            assert isinstance(y_align, dd._Frame)

            # ... and are now partitioned identically.
            X_partition_lengths = tuple(X_align.map_partitions(len).compute())
            y_partition_lengths = tuple(y_align.map_partitions(len).compute())
            assert X_partition_lengths == y_partition_lengths


@pytest.mark.parametrize('equal_partitions', [True, False]) # noqa
def test_align_training_data_rechunk(loop, equal_partitions): # noqa
with cluster() as (s, [a, b]):
with Client(s['address'], loop=loop) as client:
X = dd.from_pandas(df, chunksize=5)
if equal_partitions:
y = dd.from_pandas(labels, chunksize=5)
else:
y = dd.from_pandas(labels, chunksize=6)

X_align, y_align = align_training_data(client, X, y)
assert X_align is X
if equal_partitions:
assert y_align is y