Objective function evaluation on GPU with minimal PCIe transfers (#2935)
* Added GPU objective function and no-copy interface.

- xgboost::HostDeviceVector<T> syncs automatically between host and device (see the usage sketch below)
- no-copy interfaces have been added
- default implementations just sync the data to host
  and call the implementations with std::vector
- GPU objective function, predictor, histogram updater process data
  directly on GPU
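
A minimal usage sketch of the lazy-sync behaviour (illustrative only; it assumes a CUDA build and the HostDeviceVector interface added in src/common/host_device_vector.h below):

```cpp
#include <vector>
#include "../../src/common/host_device_vector.h"

void sketch(int device) {
  // Construct with device = -1 for host storage, or with a GPU ordinal.
  xgboost::HostDeviceVector<xgboost::bst_float> preds(100, device);
  // data_h() returns the host copy, copying from the device only if the
  // device copy is the more recent one.
  std::vector<xgboost::bst_float>& host = preds.data_h();
  host[0] = 1.0f;
  // ptr_d() returns a device pointer, copying host data across PCIe only
  // if the host copy is the more recent one; otherwise it is a no-op.
  xgboost::bst_float* device_ptr = preds.ptr_d(device);
  (void)device_ptr;
}
```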
teju85 authored and RAMitchell committed Jan 12, 2018
1 parent a187ed6 commit 84ab74f
Showing 23 changed files with 1,036 additions and 127 deletions.
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
@@ -57,6 +57,7 @@
#include "../src/learner.cc"
#include "../src/logging.cc"
#include "../src/common/common.cc"
#include "../src/common/host_device_vector.cc"
#include "../src/common/hist_util.cc"

// c_api
3 changes: 3 additions & 0 deletions doc/parameter.md
@@ -165,6 +165,9 @@ Specify the learning task and the corresponding learning objective. The objective options are below:
- "reg:logistic" --logistic regression
- "binary:logistic" --logistic regression for binary classification, output probability
- "binary:logitraw" --logistic regression for binary classification, output score before logistic transformation
- "gpu:reg:linear", "gpu:reg:logistic", "gpu:binary:logistic", gpu:binary:logitraw" --versions
of the corresponding objective functions evaluated on the GPU; note that like the GPU histogram algorithm,
they can only be used when the entire training session uses the same dataset
- "count:poisson" --poisson regression for count data, output mean of poisson distribution
- max_delta_step is set to 0.7 by default in poisson regression (used to safeguard optimization)
- "multi:softmax" --set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
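
As an illustration, a GPU objective is selected like any other objective. A hedged sketch using the C API (pairing with the grow_gpu_hist updater is an assumption for keeping data on the device, not a requirement of these objectives):

```cpp
#include <xgboost/c_api.h>

void ConfigureGpuObjective(BoosterHandle booster) {
  // Evaluate the linear regression objective on the GPU; the whole
  // training session must then use the same dataset.
  XGBoosterSetParam(booster, "objective", "gpu:reg:linear");
  // Optional (assumption): the GPU histogram updater keeps gradients on
  // the device between boosting iterations, minimizing PCIe transfers.
  XGBoosterSetParam(booster, "updater", "grow_gpu_hist");
}
```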
8 changes: 8 additions & 0 deletions include/xgboost/gbm.h
@@ -18,6 +18,7 @@
#include "./data.h"
#include "./objective.h"
#include "./feature_map.h"
#include "../../src/common/host_device_vector.h"

namespace xgboost {
/*!
@@ -70,6 +71,10 @@ class GradientBooster {
virtual void DoBoost(DMatrix* p_fmat,
std::vector<bst_gpair>* in_gpair,
ObjFunction* obj = nullptr) = 0;
virtual void DoBoost(DMatrix* p_fmat,
HostDeviceVector<bst_gpair>* in_gpair,
ObjFunction* obj = nullptr);

/*!
* \brief generate predictions for given feature matrix
* \param dmat feature matrix
@@ -80,6 +85,9 @@
virtual void PredictBatch(DMatrix* dmat,
std::vector<bst_float>* out_preds,
unsigned ntree_limit = 0) = 0;
virtual void PredictBatch(DMatrix* dmat,
HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit = 0);
/*!
* \brief online prediction function, predict score for one instance at a time
* NOTE: use the batch prediction interface if possible, batch prediction is usually
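
The commit message notes that the non-pure overloads above default to syncing the data to the host and delegating to the std::vector versions. A sketch of what that means for PredictBatch (the actual definition lives in the corresponding .cc file, which this excerpt does not show):

```cpp
void GradientBooster::PredictBatch(DMatrix* dmat,
                                   HostDeviceVector<bst_float>* out_preds,
                                   unsigned ntree_limit) {
  // data_h() performs a device-to-host copy only if the host copy is
  // stale; the existing std::vector-based path then runs unchanged.
  this->PredictBatch(dmat, &out_preds->data_h(), ntree_limit);
}
```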
13 changes: 13 additions & 0 deletions include/xgboost/objective.h
@@ -14,8 +14,11 @@
#include <functional>
#include "./data.h"
#include "./base.h"
#include "../../src/common/host_device_vector.h"


namespace xgboost {

/*! \brief interface of objective function */
class ObjFunction {
public:
@@ -45,6 +48,11 @@ class ObjFunction {
const MetaInfo& info,
int iteration,
std::vector<bst_gpair>* out_gpair) = 0;
virtual void GetGradient(HostDeviceVector<bst_float>* preds,
const MetaInfo& info,
int iteration,
HostDeviceVector<bst_gpair>* out_gpair);

/*! \return the default evaluation metric for the objective */
virtual const char* DefaultEvalMetric() const = 0;
// the following functions are optional; most of the time the default implementation is good enough
@@ -53,6 +61,8 @@
* \param io_preds prediction values, saves to this vector as well
*/
virtual void PredTransform(std::vector<bst_float> *io_preds) {}
virtual void PredTransform(HostDeviceVector<bst_float> *io_preds);

/*!
* \brief transform prediction values, this is only called when Eval is called,
* usually it redirects to PredTransform
@@ -61,6 +71,9 @@
virtual void EvalTransform(std::vector<bst_float> *io_preds) {
this->PredTransform(io_preds);
}
virtual void EvalTransform(HostDeviceVector<bst_float> *io_preds) {
this->PredTransform(io_preds);
}
/*!
* \brief transform probability value back to margin
* this is used to transform user-set base_score back to margin
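
The same delegation pattern applies to the objective's new GetGradient overload; a sketch of the default described in the commit message (the actual definition is not part of this excerpt), which a GPU objective overrides to work on device memory directly:

```cpp
void ObjFunction::GetGradient(HostDeviceVector<bst_float>* preds,
                              const MetaInfo& info, int iteration,
                              HostDeviceVector<bst_gpair>* out_gpair) {
  // Sync both vectors to the host and reuse the std::vector overload.
  this->GetGradient(preds->data_h(), info, iteration, &out_gpair->data_h());
}
```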
21 changes: 17 additions & 4 deletions include/xgboost/predictor.h
@@ -13,6 +13,7 @@
#include <utility>
#include <vector>
#include "../../src/gbm/gbtree_model.h"
#include "../../src/common/host_device_vector.h"

// Forward declarations
namespace xgboost {
@@ -51,10 +52,6 @@
const std::vector<std::shared_ptr<DMatrix>>& cache);

/**
* \fn virtual void Predictor::PredictBatch( DMatrix* dmat,
* std::vector<bst_float>* out_preds, const gbm::GBTreeModel &model, int
* tree_begin, unsigned ntree_limit = 0) = 0;
*
* \brief Generate batch predictions for a given feature matrix. May use
* cached predictions if available instead of calculating from scratch.
*
@@ -70,6 +67,22 @@
const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) = 0;

/**
* \brief Generate batch predictions for a given feature matrix. May use
* cached predictions if available instead of calculating from scratch.
*
* \param [in,out] dmat Feature matrix.
* \param [in,out] out_preds The output preds.
* \param model The model to predict from.
* \param tree_begin The tree begin index.
* \param ntree_limit (Optional) The ntree limit. 0 means do not
* limit trees.
*/

virtual void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) = 0;

/**
* \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel
* &model, std::vector<std::unique_ptr<TreeUpdater> >* updaters, int
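
Unlike gbm.h and objective.h above, both PredictBatch overloads here are pure virtual, so each predictor implements the HostDeviceVector variant itself. A CPU-side predictor can simply fill the host copy, as in this hypothetical helper (illustration only, not part of this commit):

```cpp
inline void PredictBatchViaHost(Predictor* predictor, DMatrix* dmat,
                                HostDeviceVector<bst_float>* out_preds,
                                const gbm::GBTreeModel& model,
                                int tree_begin, unsigned ntree_limit = 0) {
  // Write into the host copy; lazy syncing covers any later device access.
  predictor->PredictBatch(dmat, &out_preds->data_h(), model, tree_begin,
                          ntree_limit);
}
```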
7 changes: 7 additions & 0 deletions include/xgboost/tree_updater.h
@@ -16,6 +16,7 @@
#include "./base.h"
#include "./data.h"
#include "./tree_model.h"
#include "../../src/common/host_device_vector.h"

namespace xgboost {
/*!
@@ -42,6 +43,9 @@ class TreeUpdater {
virtual void Update(const std::vector<bst_gpair>& gpair,
DMatrix* data,
const std::vector<RegTree*>& trees) = 0;
virtual void Update(HostDeviceVector<bst_gpair>* gpair,
DMatrix* data,
const std::vector<RegTree*>& trees);

/*!
* \brief determines whether updater has enough knowledge about a given dataset
@@ -57,6 +61,9 @@
std::vector<bst_float>* out_preds) {
return false;
}
virtual bool UpdatePredictionCache(const DMatrix* data,
HostDeviceVector<bst_float>* out_preds);

/*!
* \brief Create a tree updater given name
* \param name Name of the tree updater.
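
A sketch of the defaults implied by the commit message for the two new virtuals (their actual definitions are outside this excerpt):

```cpp
void TreeUpdater::Update(HostDeviceVector<bst_gpair>* gpair, DMatrix* data,
                         const std::vector<RegTree*>& trees) {
  // Sync gradients to the host and fall back to the std::vector path.
  this->Update(gpair->data_h(), data, trees);
}

bool TreeUpdater::UpdatePredictionCache(const DMatrix* data,
                                        HostDeviceVector<bst_float>* out_preds) {
  // Delegate to the std::vector overload, which returns false by default.
  return this->UpdatePredictionCache(data, &out_preds->data_h());
}
```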
7 changes: 7 additions & 0 deletions src/common/device_helpers.cuh
@@ -484,6 +484,13 @@ class bulk_allocator {
}

public:
bulk_allocator() {}
// prevent accidental copying, moving or assignment of this object
bulk_allocator(const bulk_allocator<MemoryT>&) = delete;
bulk_allocator(bulk_allocator<MemoryT>&&) = delete;
void operator=(const bulk_allocator<MemoryT>&) = delete;
void operator=(bulk_allocator<MemoryT>&&) = delete;

~bulk_allocator() {
for (size_t i = 0; i < d_ptr.size(); i++) {
if (!(d_ptr[i] == nullptr)) {
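
The copy and move operations are deleted because the destructor frees every pointer in d_ptr: a copied bulk_allocator would free the same device memory twice, so deleting all five operations keeps ownership of the pooled allocation unique.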
54 changes: 54 additions & 0 deletions src/common/host_device_vector.cc
@@ -0,0 +1,54 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#ifndef XGBOOST_USE_CUDA

// dummy implementation of HostDeviceVector in case CUDA is not used

#include <xgboost/base.h>
#include "./host_device_vector.h"

namespace xgboost {

template <typename T>
struct HostDeviceVectorImpl {
explicit HostDeviceVectorImpl(size_t size) : data_h_(size) {}
std::vector<T> data_h_;
};

template <typename T>
HostDeviceVector<T>::HostDeviceVector(size_t size, int device) : impl_(nullptr) {
impl_ = new HostDeviceVectorImpl<T>(size);
}

template <typename T>
HostDeviceVector<T>::~HostDeviceVector() {
HostDeviceVectorImpl<T>* tmp = impl_;
impl_ = nullptr;
delete tmp;
}

template <typename T>
size_t HostDeviceVector<T>::size() const { return impl_->data_h_.size(); }

template <typename T>
int HostDeviceVector<T>::device() const { return -1; }

template <typename T>
T* HostDeviceVector<T>::ptr_d(int device) { return nullptr; }

template <typename T>
std::vector<T>& HostDeviceVector<T>::data_h() { return impl_->data_h_; }

template <typename T>
void HostDeviceVector<T>::resize(size_t new_size, int new_device) {
impl_->data_h_.resize(new_size);
}

// explicit instantiations are required, as HostDeviceVector isn't header-only
template class HostDeviceVector<bst_float>;
template class HostDeviceVector<bst_gpair>;

} // namespace xgboost

#endif
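
In this CPU-only fallback, device() always returns -1 and ptr_d() returns nullptr, so when XGBoost is built without CUDA all access goes through data_h() and the no-copy interfaces degrade to the plain std::vector paths.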
135 changes: 135 additions & 0 deletions src/common/host_device_vector.cu
@@ -0,0 +1,135 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#include "./host_device_vector.h"
#include "./device_helpers.cuh"

namespace xgboost {

template <typename T>
struct HostDeviceVectorImpl {
HostDeviceVectorImpl(size_t size, int device)
: device_(device), on_d_(device >= 0) {
if (on_d_) {
dh::safe_cuda(cudaSetDevice(device_));
data_d_.resize(size);
} else {
data_h_.resize(size);
}
}
HostDeviceVectorImpl(const HostDeviceVectorImpl<T>&) = delete;
HostDeviceVectorImpl(HostDeviceVectorImpl<T>&&) = delete;
void operator=(const HostDeviceVectorImpl<T>&) = delete;
void operator=(HostDeviceVectorImpl<T>&&) = delete;

size_t size() const { return on_d_ ? data_d_.size() : data_h_.size(); }

int device() const { return device_; }

T* ptr_d(int device) {
lazy_sync_device(device);
return data_d_.data().get();
}
thrust::device_ptr<T> tbegin(int device) {
return thrust::device_ptr<T>(ptr_d(device));
}
thrust::device_ptr<T> tend(int device) {
auto begin = tbegin(device);
return begin + size();
}
std::vector<T>& data_h() {
lazy_sync_host();
return data_h_;
}
void resize(size_t new_size, int new_device) {
if (new_size == this->size() && new_device == device_)
return;
device_ = new_device;
// keep the data on the host if it is already there and non-empty, or if
// no device is specified; otherwise (e.g. an empty host vector with a
// device set), resize on the device instead
if (!on_d_ && (data_h_.size() > 0 || device_ == -1)) {
data_h_.resize(new_size);
} else {
dh::safe_cuda(cudaSetDevice(device_));
data_d_.resize(new_size);
on_d_ = true;
}
}

void lazy_sync_host() {
if (!on_d_)
return;
if (data_h_.size() != this->size())
data_h_.resize(this->size());
dh::safe_cuda(cudaSetDevice(device_));
thrust::copy(data_d_.begin(), data_d_.end(), data_h_.begin());
on_d_ = false;
}

void lazy_sync_device(int device) {
if (on_d_)
return;
if (device != device_) {
CHECK_EQ(device_, -1);
device_ = device;
}
if (data_d_.size() != this->size()) {
dh::safe_cuda(cudaSetDevice(device_));
data_d_.resize(this->size());
}
dh::safe_cuda(cudaSetDevice(device_));
thrust::copy(data_h_.begin(), data_h_.end(), data_d_.begin());
on_d_ = true;
}

std::vector<T> data_h_;
thrust::device_vector<T> data_d_;
// true if there is an up-to-date copy of data on device, false otherwise
bool on_d_;
int device_;
};

template <typename T>
HostDeviceVector<T>::HostDeviceVector(size_t size, int device) : impl_(nullptr) {
impl_ = new HostDeviceVectorImpl<T>(size, device);
}

template <typename T>
HostDeviceVector<T>::~HostDeviceVector() {
HostDeviceVectorImpl<T>* tmp = impl_;
impl_ = nullptr;
delete tmp;
}

template <typename T>
size_t HostDeviceVector<T>::size() const { return impl_->size(); }

template <typename T>
int HostDeviceVector<T>::device() const { return impl_->device(); }

template <typename T>
T* HostDeviceVector<T>::ptr_d(int device) { return impl_->ptr_d(device); }

template <typename T>
thrust::device_ptr<T> HostDeviceVector<T>::tbegin(int device) {
return impl_->tbegin(device);
}

template <typename T>
thrust::device_ptr<T> HostDeviceVector<T>::tend(int device) {
return impl_->tend(device);
}

template <typename T>
std::vector<T>& HostDeviceVector<T>::data_h() { return impl_->data_h(); }

template <typename T>
void HostDeviceVector<T>::resize(size_t new_size, int new_device) {
impl_->resize(new_size, new_device);
}

// explicit instantiations are required, as HostDeviceVector isn't header-only
template class HostDeviceVector<bst_float>;
template class HostDeviceVector<bst_gpair>;

} // namespace xgboost
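
A usage sketch for the device-side iterators (assumes a CUDA build and a valid device 0):

```cpp
#include <thrust/reduce.h>
#include "../../src/common/host_device_vector.h"

float SumOnDevice(xgboost::HostDeviceVector<xgboost::bst_float>* vec) {
  // tbegin()/tend() lazily copy host data to the device if required; the
  // reduction then runs entirely on the GPU with no further PCIe traffic.
  return thrust::reduce(vec->tbegin(0), vec->tend(0), 0.0f);
}
```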