Skip to content
This repository has been archived by the owner on Jul 16, 2021. It is now read-only.

Remove use of underscored methods #14

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions dask_xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from toolz import first, assoc
from tornado import gen
from dask import delayed
from distributed.client import _wait, default_client
from distributed.client import wait, default_client
from distributed.utils import sync
import xgboost as xgb

Expand Down Expand Up @@ -105,7 +105,7 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
# Arrange parts into pairs. This enforces co-locality
parts = list(map(delayed, zip(data_parts, label_parts)))
parts = client.compute(parts) # Start computation in the background
yield _wait(parts)
yield wait(parts)

# Because XGBoost-python doesn't yet allow iterative training, we need to
# find the locations of all chunks and map them to particular Dask workers
Expand All @@ -119,19 +119,20 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):

# Start the XGBoost tracker on the Dask scheduler
host, port = parse_host_port(client.scheduler.address)
env = yield client._run_on_scheduler(start_tracker,
host.strip('/:'),
len(worker_map))
env = yield client.run_on_scheduler(start_tracker,
host.strip('/:'),
len(worker_map))

# Tell each worker to train on the chunks/parts that it has locally
futures = [client.submit(train_part, env,
assoc(params, 'nthreads', ncores[worker]),
list_of_parts, workers=worker,
allow_other_workers=True,
dmatrix_kwargs=dmatrix_kwargs, **kwargs)
for worker, list_of_parts in worker_map.items()]

# Get the results, only one will be non-None
results = yield client._gather(futures)
results = yield client.gather(futures)
result = [v for v in results if v][0]
raise gen.Return(result)

Expand Down