Document for device ordinal. (#9398)
- Rewrite the GPU demos; the notebook is converted to a script to avoid committing additional PNG plots.
- Add the GPU demos to the Sphinx gallery.
- Add the RMM demos to the Sphinx gallery.
- Add a test for firing threads with different device ordinals.
trivialfis authored Jul 22, 2023
1 parent 22b0a55 commit 275da17
Showing 32 changed files with 343 additions and 390 deletions.
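For context, the ``device`` parameter documented in this commit replaces the older ``gpu_id``/``gpu_hist`` pair and also accepts a device ordinal. A minimal sketch of the intended usage (an illustration, assuming an XGBoost build with CUDA support):

import xgboost as xgb

# Select the processor with a single parameter; an ordinal picks a specific GPU.
clf_cpu = xgb.XGBClassifier(tree_method="hist", device="cpu")
clf_gpu = xgb.XGBClassifier(tree_method="hist", device="cuda")      # default GPU
clf_gpu1 = xgb.XGBClassifier(tree_method="hist", device="cuda:1")   # second GPU
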
10 changes: 1 addition & 9 deletions demo/c-api/basic/c-api-demo.c
@@ -53,15 +53,7 @@ int main() {
// configure the training
// available parameters are described here:
// https://xgboost.readthedocs.io/en/latest/parameter.html
safe_xgboost(XGBoosterSetParam(booster, "tree_method", use_gpu ? "gpu_hist" : "hist"));
if (use_gpu) {
// set the GPU to use;
// this is not necessary, but provided here as an illustration
safe_xgboost(XGBoosterSetParam(booster, "gpu_id", "0"));
} else {
// avoid evaluating objective and metric on a GPU
safe_xgboost(XGBoosterSetParam(booster, "gpu_id", "-1"));
}
safe_xgboost(XGBoosterSetParam(booster, "device", use_gpu ? "cuda" : "cpu"));

safe_xgboost(XGBoosterSetParam(booster, "objective", "binary:logistic"));
safe_xgboost(XGBoosterSetParam(booster, "min_child_weight", "1"));
5 changes: 0 additions & 5 deletions demo/gpu_acceleration/README.md

This file was deleted.

8 changes: 8 additions & 0 deletions demo/gpu_acceleration/README.rst
@@ -0,0 +1,8 @@
:orphan:

GPU Acceleration Demo
=====================

This is a collection of demonstration scripts showcasing basic GPU usage in XGBoost. Please
see :doc:`/gpu/index` for more information. There are also demonstrations of distributed GPU
training using dask or spark.
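A minimal sketch of what distributed GPU training with dask can look like (an illustration, not part of this commit; assumes ``dask_cuda`` and ``dask.distributed`` are installed and at least one GPU is available):

from dask import array as da
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

import xgboost as xgb

if __name__ == "__main__":
    with LocalCUDACluster() as cluster, Client(cluster) as client:
        # Synthetic dask arrays stand in for a real distributed dataset.
        X = da.random.random((100_000, 20))
        y = da.random.randint(0, 2, size=100_000)
        dtrain = xgb.dask.DaskDMatrix(client, X, label=y)
        xgb.dask.train(
            client,
            {"objective": "binary:logistic", "tree_method": "hist", "device": "cuda"},
            dtrain,
            num_boost_round=10,
        )
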
56 changes: 32 additions & 24 deletions demo/gpu_acceleration/cover_type.py
@@ -1,41 +1,49 @@
"""
Using xgboost on GPU devices
============================
Shows how to train a model on the `forest cover type
<https://archive.ics.uci.edu/ml/datasets/covertype>`_ dataset using GPU
acceleration. The forest cover type dataset has 581,012 rows and 54 features, making it
time-consuming to process. We compare the run-time and accuracy of the GPU and CPU
histogram algorithms.

In addition, the demo showcases using the GPU together with other GPU-related libraries,
including cupy and cuml. These libraries are not strictly required.
"""
import time

import cupy as cp
from cuml.model_selection import train_test_split
from sklearn.datasets import fetch_covtype
from sklearn.model_selection import train_test_split

import xgboost as xgb

# Fetch dataset using sklearn
cov = fetch_covtype()
X = cov.data
y = cov.target
X, y = fetch_covtype(return_X_y=True)
X = cp.array(X)
y = cp.array(y)
y -= y.min()

# Create 0.75/0.25 train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, train_size=0.75,
random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.25, train_size=0.75, random_state=42
)

# Specify sufficient boosting iterations to reach a minimum
num_round = 3000

# Leave most parameters as default
param = {'objective': 'multi:softmax', # Specify multiclass classification
'num_class': 8, # Number of possible output classes
'tree_method': 'gpu_hist' # Use GPU accelerated algorithm
}

# Convert input data from numpy to XGBoost format
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

gpu_res = {} # Store accuracy result
tmp = time.time()
clf = xgb.XGBClassifier(device="cuda", n_estimators=num_round)
# Train model
xgb.train(param, dtrain, num_round, evals=[(dtest, 'test')], evals_result=gpu_res)
print("GPU Training Time: %s seconds" % (str(time.time() - tmp)))
start = time.time()
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
gpu_res = clf.evals_result()
print("GPU Training Time: %s seconds" % (str(time.time() - start)))

# Repeat for CPU algorithm
tmp = time.time()
param['tree_method'] = 'hist'
cpu_res = {}
xgb.train(param, dtrain, num_round, evals=[(dtest, 'test')], evals_result=cpu_res)
print("CPU Training Time: %s seconds" % (str(time.time() - tmp)))
clf = xgb.XGBClassifier(device="cpu", n_estimators=num_round)
start = time.time()
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
cpu_res = clf.evals_result()
print("CPU Training Time: %s seconds" % (str(time.time() - start)))
211 changes: 0 additions & 211 deletions demo/gpu_acceleration/shap.ipynb

This file was deleted.

55 changes: 55 additions & 0 deletions demo/gpu_acceleration/tree_shap.py
@@ -0,0 +1,55 @@
"""
Use GPU to speed up SHAP value computation
==========================================
Demonstrates using GPU acceleration to compute SHAP values for feature importance.
"""
import shap
from sklearn.datasets import fetch_california_housing

import xgboost as xgb

# Fetch dataset using sklearn
data = fetch_california_housing()
print(data.DESCR)
X = data.data
y = data.target

num_round = 500

param = {
"eta": 0.05,
"max_depth": 10,
"tree_method": "hist",
"device": "cuda",
}

# GPU accelerated training
dtrain = xgb.DMatrix(X, label=y, feature_names=data.feature_names)
model = xgb.train(param, dtrain, num_round)

# Compute shap values using GPU with xgboost
model.set_param({"device": "cuda"})
shap_values = model.predict(dtrain, pred_contribs=True)

# Compute shap interaction values using GPU
shap_interaction_values = model.predict(dtrain, pred_interactions=True)


# shap will call the GPU accelerated version as long as the device parameter is set to
# "cuda"
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

# visualize the first prediction's explanation
shap.force_plot(
explainer.expected_value,
shap_values[0, :],
X[0, :],
feature_names=data.feature_names,
matplotlib=True,
)

# Show a summary of feature importance
shap.summary_plot(shap_values, X, plot_type="bar", feature_names=data.feature_names)
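A side note, continuing from the variables in the demo above and phrased as an assumption worth verifying rather than a statement about this commit: the native ``pred_contribs=True`` output carries one extra trailing column holding the bias term, so it can be compared against shap's ``TreeExplainer`` output after dropping that column.

import numpy as np

# Sketch: native contributions have shape (n_samples, n_features + 1); the last
# column is the expected value, and the remaining columns should agree with shap.
contribs = model.predict(dtrain, pred_contribs=True)
np.testing.assert_allclose(contribs[:, :-1], shap_values, rtol=1e-5, atol=1e-5)
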
3 changes: 1 addition & 2 deletions demo/nvflare/horizontal/custom/trainer.py
@@ -70,8 +70,7 @@ def _do_training(self, fl_ctx: FLContext):
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
if self._use_gpus:
self.log_info(fl_ctx, f'Training with GPU {rank}')
param['tree_method'] = 'gpu_hist'
param['gpu_id'] = rank
param['device'] = f"cuda:{rank}"

# Specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
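The nvflare trainer above maps each rank to its own GPU through ``device=f"cuda:{rank}"``, and the commit message mentions a test that fires threads with different device ordinals. A hypothetical sketch of that pattern (not the actual test added by this commit; it assumes at least two GPUs):

from concurrent.futures import ThreadPoolExecutor

import numpy as np
import xgboost as xgb


def train_on(ordinal: int) -> xgb.Booster:
    rng = np.random.default_rng(ordinal)
    X, y = rng.random((1000, 10)), rng.integers(0, 2, 1000)
    dtrain = xgb.DMatrix(X, label=y)
    # Each thread trains on its own GPU, selected via the device ordinal.
    params = {"tree_method": "hist", "device": f"cuda:{ordinal}"}
    return xgb.train(params, dtrain, num_boost_round=10)


with ThreadPoolExecutor(max_workers=2) as pool:
    boosters = list(pool.map(train_on, [0, 1]))
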
47 changes: 0 additions & 47 deletions demo/rmm_plugin/README.md

This file was deleted.

51 changes: 51 additions & 0 deletions demo/rmm_plugin/README.rst
@@ -0,0 +1,51 @@
Using XGBoost with RAPIDS Memory Manager (RMM) plugin (EXPERIMENTAL)
====================================================================

The `RAPIDS Memory Manager (RMM) <https://github.com/rapidsai/rmm>`__ library provides a
collection of efficient memory allocators for NVIDIA GPUs. It is now possible to use
XGBoost with memory allocators provided by RMM by enabling the RMM integration plugin.

The demos in this directory highlight one RMM allocator in particular: **the pool
sub-allocator**. This allocator addresses the slow speed of ``cudaMalloc()`` by
allocating a large chunk of memory upfront. Subsequent allocations draw from this pool
of already allocated memory and thus avoid the overhead of calling ``cudaMalloc()``
directly. See the slides from `this GTC talk
<https://on-demand.gputechconf.com/gtc/2015/presentation/S5530-Stephen-Jones.pdf>`_ for
more details.

Before running the demos, ensure that XGBoost is compiled with the RMM plugin enabled. To do
this, run CMake with the option ``-DPLUGIN_RMM=ON`` (``-DUSE_CUDA=ON`` is also required):

.. code-block:: sh

  cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON
  make -j$(nproc)

CMake will attempt to locate the RMM library in your build environment. You may choose to
build RMM from source or install it with the Conda package manager. If CMake cannot find
RMM, specify its location with the CMake prefix:

.. code-block:: sh

  # If using Conda:
  cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX
  # If RMM is installed in a custom location
  cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON -DCMAKE_PREFIX_PATH=/path/to/rmm

********************************
Informing XGBoost about RMM pool
********************************

When XGBoost is compiled with RMM, most large memory allocations go through the RMM
allocators, but some small allocations in performance-critical areas use a different
caching allocator so that we retain finer control over memory allocation behavior. Users
can override this behavior and force all allocations to use RMM by setting the global
configuration ``use_rmm``:

.. code-block:: python

  with xgb.config_context(use_rmm=True):
      clf = xgb.XGBClassifier(tree_method="hist", device="cuda")

Depending on the choice of memory pool size and the type of allocator, this may have a
negative performance impact.
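Putting the two steps together, a sketch of creating a pool and opting in to RMM for all allocations (an illustration; the 1 GiB pool size is an arbitrary example value):

import rmm
from sklearn.datasets import make_classification

import xgboost as xgb

# Re-initialise RMM with a pool sub-allocator, reserving 1 GiB up front.
rmm.reinitialize(pool_allocator=True, initial_pool_size=2**30)

X, y = make_classification(n_samples=10_000, n_features=20)
with xgb.config_context(use_rmm=True):
    clf = xgb.XGBClassifier(tree_method="hist", device="cuda", n_estimators=10)
    clf.fit(X, y)
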
40 changes: 26 additions & 14 deletions demo/rmm_plugin/rmm_mgpu_with_dask.py
@@ -1,3 +1,7 @@
"""
Using rmm with Dask
===================
"""
import dask
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
@@ -11,25 +15,33 @@ def main(client):
# xgb.set_config(use_rmm=True)

X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
# In practice one should prefer loading the data with dask collections instead of using
# `from_array`.
# In practice one should prefer loading the data with dask collections instead of
# using `from_array`.
X = dask.array.from_array(X)
y = dask.array.from_array(y)
dtrain = xgb.dask.DaskDMatrix(client, X, label=y)

params = {'max_depth': 8, 'eta': 0.01, 'objective': 'multi:softprob', 'num_class': 3,
'tree_method': 'gpu_hist', 'eval_metric': 'merror'}
output = xgb.dask.train(client, params, dtrain, num_boost_round=100,
evals=[(dtrain, 'train')])
bst = output['booster']
history = output['history']
for i, e in enumerate(history['train']['merror']):
print(f'[{i}] train-merror: {e}')
params = {
"max_depth": 8,
"eta": 0.01,
"objective": "multi:softprob",
"num_class": 3,
"tree_method": "hist",
"eval_metric": "merror",
"device": "cuda",
}
output = xgb.dask.train(
client, params, dtrain, num_boost_round=100, evals=[(dtrain, "train")]
)
bst = output["booster"]
history = output["history"]
for i, e in enumerate(history["train"]["merror"]):
print(f"[{i}] train-merror: {e}")


if __name__ == '__main__':
# To use RMM pool allocator with a GPU Dask cluster, just add rmm_pool_size option to
# LocalCUDACluster constructor.
with LocalCUDACluster(rmm_pool_size='2GB') as cluster:
if __name__ == "__main__":
# To use RMM pool allocator with a GPU Dask cluster, just add rmm_pool_size option
# to LocalCUDACluster constructor.
with LocalCUDACluster(rmm_pool_size="2GB") as cluster:
with Client(cluster) as client:
main(client)
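One usage note, stated as an assumption rather than something this demo does: because each dask worker runs in its own process, the ``use_rmm`` flag would have to be set inside the workers as well, for example:

# Sketch: enable the RMM-backed allocator in every worker process before training.
client.run(xgb.set_config, use_rmm=True)
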
7 changes: 6 additions & 1 deletion demo/rmm_plugin/rmm_singlegpu.py
@@ -1,3 +1,7 @@
"""
Using rmm on a single node device
=================================
"""
import rmm
from sklearn.datasets import make_classification

@@ -16,7 +20,8 @@
"eta": 0.01,
"objective": "multi:softprob",
"num_class": 3,
"tree_method": "gpu_hist",
"tree_method": "hist",
"device": "cuda",
}
# XGBoost will automatically use the RMM pool allocator
bst = xgb.train(params, dtrain, num_boost_round=100, evals=[(dtrain, "train")])
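The collapsed lines of this demo configure the RMM allocator before the parameters shown above; a sketch of one common way to write that setup (an assumption, since the diff hides those lines):

import rmm

# Route device allocations through a pool sub-allocator backed by cudaMalloc.
mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource())
rmm.mr.set_current_device_resource(mr)
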
2 changes: 2 additions & 0 deletions doc/.gitignore
@@ -6,3 +6,5 @@ doxygen
parser.py
*.pyc
web-data
# generated by doxygen
tmp
11 changes: 9 additions & 2 deletions doc/conf.py
@@ -19,7 +19,6 @@
import tarfile
import urllib.request
import warnings
from subprocess import call
from urllib.error import HTTPError

from sh.contrib import git
@@ -148,12 +147,20 @@ def is_readthedocs_build():

sphinx_gallery_conf = {
# path to your example scripts
"examples_dirs": ["../demo/guide-python", "../demo/dask", "../demo/aft_survival"],
"examples_dirs": [
"../demo/guide-python",
"../demo/dask",
"../demo/aft_survival",
"../demo/gpu_acceleration",
"../demo/rmm_plugin"
],
# path to where to save gallery generated output
"gallery_dirs": [
"python/examples",
"python/dask-examples",
"python/survival-examples",
"python/gpu-examples",
"python/rmm-examples",
],
"matplotlib_animations": True,
}
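As general sphinx-gallery background rather than part of this diff: every script listed under ``examples_dirs`` needs a module docstring whose first lines form an RST title, which is why the converted demos above begin with docstrings like the following:

"""
Using xgboost on GPU devices
============================

Body text of the example, rendered as the description on the gallery page.
"""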