Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] dask cudf inplace prediction. #5512

Merged
merged 22 commits into from
Apr 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -296,19 +296,25 @@ def TestPythonGPU(args) {
sh """
${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh mgpu
"""
if (args.cuda_version != '9.0') {
echo "Running tests with cuDF..."
sh """
${dockerRun} cudf ${docker_binary} ${docker_args} tests/ci_build/test_python.sh mgpu-cudf
"""
}
} else {
echo "Using a single GPU"
sh """
${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh gpu
"""
if (args.cuda_version != '9.0') {
echo "Running tests with cuDF..."
sh """
${dockerRun} cudf ${docker_binary} ${docker_args} tests/ci_build/test_python.sh cudf
"""
}
}
// For CUDA 10.0 target, run cuDF tests too
if (args.cuda_version == '10.0') {
echo "Running tests with cuDF..."
sh """
${dockerRun} cudf ${docker_binary} ${docker_args} tests/ci_build/test_python.sh cudf
"""
}
deleteDir()
}
}
Expand Down
31 changes: 24 additions & 7 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,32 @@ def ctypes2numpy(cptr, length, dtype):

def ctypes2cupy(cptr, length, dtype):
"""Convert a ctypes pointer array to a cupy array."""
import cupy # pylint: disable=import-error
mem = cupy.zeros(length.value, dtype=dtype, order='C')
# pylint: disable=import-error
import cupy
from cupy.cuda.memory import MemoryPointer
from cupy.cuda.memory import UnownedMemory
CUPY_TO_CTYPES_MAPPING = {
cupy.float32: ctypes.c_float,
cupy.uint32: ctypes.c_uint
}
if dtype not in CUPY_TO_CTYPES_MAPPING.keys():
raise RuntimeError('Supported types: {}'.format(
CUPY_TO_CTYPES_MAPPING.keys()
))
addr = ctypes.cast(cptr, ctypes.c_void_p).value
# pylint: disable=c-extension-no-member,no-member
cupy.cuda.runtime.memcpy(
mem.__cuda_array_interface__['data'][0], addr,
length.value * ctypes.sizeof(ctypes.c_float),
cupy.cuda.runtime.memcpyDeviceToDevice)
return mem
device = cupy.cuda.runtime.pointerGetAttributes(addr).device
# The owner field is only used to keep the memory alive via reference counting.
# Because the unowned memory's lifetime is scoped within this function, we don't need it.
unownd = UnownedMemory(
addr, length.value * ctypes.sizeof(CUPY_TO_CTYPES_MAPPING[dtype]),
owner=None)
memptr = MemoryPointer(unownd, 0)
# pylint: disable=unexpected-keyword-arg
mem = cupy.ndarray((length.value, ), dtype=dtype, memptr=memptr)
assert mem.device.id == device
arr = cupy.array(mem, copy=True)
return arr


def ctypes2buffer(cptr, length):
Expand Down
7 changes: 5 additions & 2 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def concat(value): # pylint: disable=too-many-return-statements
return CUDF_concat(value, axis=0)
if lazy_isinstance(value[0], 'cupy.core.core', 'ndarray'):
import cupy # pylint: disable=import-error
# pylint: disable=c-extension-no-member,no-member
d = cupy.cuda.runtime.getDevice()
for v in value:
d_v = v.device.id
assert d_v == d, 'Concatenating arrays on different devices.'
return cupy.concatenate(value, axis=0)
return dd.multi.concat(list(value), axis=0)

Expand Down Expand Up @@ -623,8 +628,6 @@ def mapped_predict(data, is_df):
if is_df:
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
import cudf # pylint: disable=import-error
# There's an error with cudf saying `concat_cudf` got an
# unexpected argument `ignore_index`. So this is not yet working.
prediction = cudf.DataFrame({'prediction': prediction},
dtype=numpy.float32)
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/ci_build/Dockerfile.cpu
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ RUN \
wget -nv -nc https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.sh --no-check-certificate && \
bash cmake-3.13.0-Linux-x86_64.sh --skip-license --prefix=/usr && \
# Python
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/python
wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
bash Miniconda3.sh -b -p /opt/python

ENV PATH=/opt/python/bin:$PATH

Expand Down
16 changes: 6 additions & 10 deletions tests/ci_build/Dockerfile.cudf
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,16 @@ RUN \
apt-get update && \
apt-get install -y wget unzip bzip2 libgomp1 build-essential && \
# Python
wget https://repo.continuum.io/miniconda/Miniconda3-4.5.12-Linux-x86_64.sh && \
bash Miniconda3-4.5.12-Linux-x86_64.sh -b -p /opt/python
wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
bash Miniconda3.sh -b -p /opt/python

ENV PATH=/opt/python/bin:$PATH

# Create new Conda environment with cuDF and dask
# Create new Conda environment with cuDF, Dask, and cuPy
RUN \
conda create -n cudf_test -c rapidsai -c nvidia -c numba -c conda-forge -c anaconda \
cudf=0.9 python=3.7 anaconda::cudatoolkit=$CUDA_VERSION dask dask-cuda cupy

# Install other Python packages
RUN \
source activate cudf_test && \
pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz
conda create -n cudf_test -c rapidsai -c nvidia -c conda-forge -c defaults \
python=3.7 cudf cudatoolkit=$CUDA_VERSION dask dask-cuda dask-cudf cupy \
numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz

ENV GOSU_VERSION 1.10

Expand Down
10 changes: 5 additions & 5 deletions tests/ci_build/Dockerfile.gpu
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ RUN \
apt-get update && \
apt-get install -y wget unzip bzip2 libgomp1 build-essential && \
# Python
wget https://repo.continuum.io/miniconda/Miniconda3-4.5.12-Linux-x86_64.sh && \
bash Miniconda3-4.5.12-Linux-x86_64.sh -b -p /opt/python
wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
bash Miniconda3.sh -b -p /opt/python

ENV PATH=/opt/python/bin:$PATH

# Install Python packages
RUN \
pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz && \
pip install "dask[complete]" && \
conda install -c rapidsai -c nvidia -c numba -c conda-forge -c anaconda dask-cuda
conda create -n gpu_test -c rapidsai -c nvidia -c conda-forge -c defaults \
python=3.7 dask dask-cuda numpy pytest scipy scikit-learn pandas \
matplotlib wheel python-kubernetes urllib3 graphviz

ENV GOSU_VERSION 1.10

Expand Down
4 changes: 2 additions & 2 deletions tests/ci_build/Dockerfile.gpu_build
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ RUN \
$DEVTOOLSET_URL_ROOT/devtoolset-4-runtime-4.1-3.sc1.el6.x86_64.rpm \
$DEVTOOLSET_URL_ROOT/devtoolset-4-libstdc++-devel-5.3.1-6.1.el6.x86_64.rpm && \
# Python
wget https://repo.continuum.io/miniconda/Miniconda3-4.5.12-Linux-x86_64.sh && \
bash Miniconda3-4.5.12-Linux-x86_64.sh -b -p /opt/python && \
wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
bash Miniconda3.sh -b -p /opt/python && \
# CMake
wget -nv -nc https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.sh --no-check-certificate && \
bash cmake-3.13.0-Linux-x86_64.sh --skip-license --prefix=/usr
Expand Down
4 changes: 2 additions & 2 deletions tests/ci_build/Dockerfile.jvm
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ RUN \
yum -y update && \
yum install -y devtoolset-6-gcc devtoolset-6-binutils devtoolset-6-gcc-c++ && \
# Python
wget https://repo.continuum.io/miniconda/Miniconda3-4.5.12-Linux-x86_64.sh && \
bash Miniconda3-4.5.12-Linux-x86_64.sh -b -p /opt/python && \
wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
bash Miniconda3.sh -b -p /opt/python && \
# CMake
wget -nv -nc https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.sh --no-check-certificate && \
bash cmake-3.13.0-Linux-x86_64.sh --skip-license --prefix=/usr && \
Expand Down
4 changes: 2 additions & 2 deletions tests/ci_build/Dockerfile.jvm_cross
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ RUN \
apt-get update && \
apt-get install -y tar unzip wget openjdk-$JDK_VERSION-jdk libgomp1 && \
# Python
wget https://repo.continuum.io/miniconda/Miniconda3-4.5.12-Linux-x86_64.sh && \
bash Miniconda3-4.5.12-Linux-x86_64.sh -b -p /opt/python && \
wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
bash Miniconda3.sh -b -p /opt/python && \
/opt/python/bin/pip install awscli && \
# Maven
wget https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \
Expand Down
31 changes: 0 additions & 31 deletions tests/ci_build/Dockerfile.release

This file was deleted.

18 changes: 14 additions & 4 deletions tests/ci_build/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,33 @@ function install_xgboost {
# Run specified test suite
case "$suite" in
gpu)
source activate gpu_test
install_xgboost
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu
pytest -v -s -rxXs --fulltrace -m "not mgpu" tests/python-gpu
;;

mgpu)
source activate gpu_test
install_xgboost
pytest -v -s --fulltrace -m "mgpu" tests/python-gpu
pytest -v -s -rxXs --fulltrace -m "mgpu" tests/python-gpu

cd tests/distributed
./runtests-gpu.sh
cd -
pytest -v -s --fulltrace -m "mgpu" tests/python-gpu/test_gpu_with_dask.py
;;

cudf)
source activate cudf_test
install_xgboost
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_cudf.py tests/python-gpu/test_from_cupy.py
pytest -v -s -rxXs --fulltrace -m "not mgpu" \
tests/python-gpu/test_from_cudf.py tests/python-gpu/test_from_cupy.py \
tests/python-gpu/test_gpu_prediction.py
;;

mgpu-cudf)
source activate cudf_test
install_xgboost
pytest -v -s -rxXs --fulltrace -m "mgpu" tests/python-gpu/test_gpu_with_dask.py
;;

cpu)
Expand Down
2 changes: 2 additions & 0 deletions tests/python-gpu/test_gpu_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def non_decreasing(self, L):

# Test case for a bug where multiple batch predictions made on a
# test set produce incorrect results
@pytest.mark.skipif(**tm.no_sklearn())
def test_multi_predict(self):
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
Expand Down Expand Up @@ -89,6 +90,7 @@ def test_multi_predict(self):
assert np.allclose(predict0, predict1)
assert np.allclose(predict0, cpu_predict)

@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn(self):
m, n = 15000, 14
tr_size = 2500
Expand Down
19 changes: 12 additions & 7 deletions tests/python-gpu/test_gpu_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class TestDistributedGPU(unittest.TestCase):
@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.skipif(**tm.no_dask_cudf())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
def test_dask_dataframe(self):
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
Expand All @@ -51,18 +52,18 @@ def test_dask_dataframe(self):
predictions = dxgb.predict(client, out, dtrain).compute()
assert isinstance(predictions, np.ndarray)

# There's an error with cudf saying `concat_cudf` got an
# unexpected argument `ignore_index`. So the test here is just
# a placeholder.

# series_predictions = dxgb.inplace_predict(client, out, X)
# assert isinstance(series_predictions, dd.Series)
series_predictions = dxgb.inplace_predict(client, out, X)
assert isinstance(series_predictions, dd.Series)
series_predictions = series_predictions.compute()

single_node = out['booster'].predict(
xgboost.DMatrix(X.compute()))

cupy.testing.assert_allclose(single_node, predictions)
cupy.testing.assert_allclose(single_node, series_predictions)

@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
def test_dask_array(self):
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
Expand All @@ -82,8 +83,12 @@ def test_dask_array(self):
single_node = out['booster'].predict(
xgboost.DMatrix(X.compute()))
np.testing.assert_allclose(single_node, from_dmatrix)
device = cupy.cuda.runtime.getDevice()
assert device == inplace_predictions.device.id
single_node = cupy.array(single_node)
assert device == single_node.device.id
cupy.testing.assert_allclose(
cupy.array(single_node),
single_node,
inplace_predictions)


Expand Down
10 changes: 6 additions & 4 deletions tests/python-gpu/test_monotonic_constraints.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import print_function

import sys
import numpy as np
from sklearn.datasets import make_regression

import unittest
import pytest

import xgboost as xgb
sys.path.append("tests/python")
import testing as tm

rng = np.random.RandomState(1994)

Expand All @@ -20,6 +20,7 @@ def non_increasing(L):


def assert_constraint(constraint, tree_method):
from sklearn.datasets import make_regression
n = 1000
X, y = make_regression(n, random_state=rng, n_features=1, n_informative=1)
dtrain = xgb.DMatrix(X, y)
Expand All @@ -35,12 +36,13 @@ def assert_constraint(constraint, tree_method):
assert non_increasing(pred)


@pytest.mark.gpu
class TestMonotonicConstraints(unittest.TestCase):
@pytest.mark.skipif(**tm.no_sklearn())
def test_exact(self):
assert_constraint(1, 'exact')
assert_constraint(-1, 'exact')

@pytest.mark.skipif(**tm.no_sklearn())
def test_gpu_hist(self):
assert_constraint(1, 'gpu_hist')
assert_constraint(-1, 'gpu_hist')
6 changes: 3 additions & 3 deletions tests/python/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ def run_threaded_predict(X, rows, predict_func):
per_thread = 20
with ThreadPoolExecutor(max_workers=10) as e:
for i in range(0, rows, int(rows / per_thread)):
try:
if hasattr(X, 'iloc'):
predictor = X.iloc[i:i+per_thread, :]
else:
predictor = X[i:i+per_thread, ...]
except TypeError:
predictor = X.iloc[i:i+per_thread, ...]
f = e.submit(predict_func, predictor)
results.append(f)

Expand Down