Skip to content

Commit

Permalink
Fixes memory allocation for experimental backend and improves quantil…
Browse files Browse the repository at this point in the history
…e computations (#3586)

Prior to this PR, when the new/experimental backend was used for training, the temporary memory needed by the old backend was also allocated. This PR fixes the issue: the temporary memory is now allocated conditionally. This PR also changes how quantiles are computed for the new backend. The old way of computing quantiles could leave out the last few samples because of incorrect quantile thresholds. The impact on accuracy is still to be evaluated thoroughly.

Authors:
  - Vinay Deshpande (@vinaydes)

Approvers:
  - Thejaswi. N. S (@teju85)
  - Philip Hyunsu Cho (@hcho3)

URL: #3586
  • Loading branch information
vinaydes authored Mar 17, 2021
1 parent c4c4068 commit de42e7f
Show file tree
Hide file tree
Showing 9 changed files with 479 additions and 106 deletions.
60 changes: 56 additions & 4 deletions cpp/src/decisiontree/decisiontree.cu
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,21 @@ void decisionTreeClassifierFit(const raft::handle_t &handle,
uint64_t seed) {
std::shared_ptr<DecisionTreeClassifier<float>> dt_classifier =
std::make_shared<DecisionTreeClassifier<float>>();
std::unique_ptr<MLCommon::device_buffer<float>> global_quantiles_buffer =
nullptr;
float *global_quantiles = nullptr;

if (tree_params.use_experimental_backend) {
auto quantile_size = tree_params.n_bins * ncols;
global_quantiles_buffer = std::make_unique<MLCommon::device_buffer<float>>(
handle.get_device_allocator(), handle.get_stream(), quantile_size);
global_quantiles = global_quantiles_buffer->data();
DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data,
nrows, ncols, handle.get_device_allocator(),
handle.get_stream());
}
dt_classifier->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows,
unique_labels, tree, tree_params, seed);
unique_labels, tree, tree_params, seed, global_quantiles);
}

void decisionTreeClassifierFit(const raft::handle_t &handle,
Expand All @@ -172,8 +185,21 @@ void decisionTreeClassifierFit(const raft::handle_t &handle,
uint64_t seed) {
std::shared_ptr<DecisionTreeClassifier<double>> dt_classifier =
std::make_shared<DecisionTreeClassifier<double>>();
std::unique_ptr<MLCommon::device_buffer<double>> global_quantiles_buffer =
nullptr;
double *global_quantiles = nullptr;

if (tree_params.use_experimental_backend) {
auto quantile_size = tree_params.n_bins * ncols;
global_quantiles_buffer = std::make_unique<MLCommon::device_buffer<double>>(
handle.get_device_allocator(), handle.get_stream(), quantile_size);
global_quantiles = global_quantiles_buffer->data();
DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data,
nrows, ncols, handle.get_device_allocator(),
handle.get_stream());
}
dt_classifier->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows,
unique_labels, tree, tree_params, seed);
unique_labels, tree, tree_params, seed, global_quantiles);
}

void decisionTreeClassifierPredict(const raft::handle_t &handle,
Expand Down Expand Up @@ -208,8 +234,21 @@ void decisionTreeRegressorFit(const raft::handle_t &handle,
uint64_t seed) {
std::shared_ptr<DecisionTreeRegressor<float>> dt_regressor =
std::make_shared<DecisionTreeRegressor<float>>();
std::unique_ptr<MLCommon::device_buffer<float>> global_quantiles_buffer =
nullptr;
float *global_quantiles = nullptr;

if (tree_params.use_experimental_backend) {
auto quantile_size = tree_params.n_bins * ncols;
global_quantiles_buffer = std::make_unique<MLCommon::device_buffer<float>>(
handle.get_device_allocator(), handle.get_stream(), quantile_size);
global_quantiles = global_quantiles_buffer->data();
DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data,
nrows, ncols, handle.get_device_allocator(),
handle.get_stream());
}
dt_regressor->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows,
tree, tree_params, seed);
tree, tree_params, seed, global_quantiles);
}

void decisionTreeRegressorFit(const raft::handle_t &handle,
Expand All @@ -220,8 +259,21 @@ void decisionTreeRegressorFit(const raft::handle_t &handle,
uint64_t seed) {
std::shared_ptr<DecisionTreeRegressor<double>> dt_regressor =
std::make_shared<DecisionTreeRegressor<double>>();
std::unique_ptr<MLCommon::device_buffer<double>> global_quantiles_buffer =
nullptr;
double *global_quantiles = nullptr;

if (tree_params.use_experimental_backend) {
auto quantile_size = tree_params.n_bins * ncols;
global_quantiles_buffer = std::make_unique<MLCommon::device_buffer<double>>(
handle.get_device_allocator(), handle.get_stream(), quantile_size);
global_quantiles = global_quantiles_buffer->data();
DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data,
nrows, ncols, handle.get_device_allocator(),
handle.get_stream());
}
dt_regressor->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows,
tree, tree_params, seed);
tree, tree_params, seed, global_quantiles);
}

void decisionTreeRegressorPredict(const raft::handle_t &handle,
Expand Down
79 changes: 44 additions & 35 deletions cpp/src/decisiontree/decisiontree_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include "levelalgo/levelfunc_regressor.cuh"
#include "levelalgo/metric.cuh"
#include "memory.cuh"
#include "quantile/quantile.cuh"
#include "quantile/quantile.h"
#include "treelite_util.h"

Expand Down Expand Up @@ -293,20 +292,9 @@ void DecisionTreeBase<T, L>::plant(

total_temp_mem = tempmem->totalmem;
MLCommon::TimerCPU timer;
if (tree_params.use_experimental_backend) {
if (treeid == 0) {
CUML_LOG_WARN("Using experimental backend for growing trees\n");
}
T *quantiles = tempmem->d_quantile->data();
grow_tree(tempmem->device_allocator, tempmem->host_allocator, data, treeid,
seed, ncols, nrows, labels, quantiles, (int *)rowids,
n_sampled_rows, unique_labels, tree_params, tempmem->stream,
sparsetree, this->leaf_counter, this->depth_counter);
} else {
grow_deep_tree(data, labels, rowids, n_sampled_rows, ncols,
tree_params.max_features, dinfo.NLocalrows, sparsetree,
treeid, tempmem);
}
grow_deep_tree(data, labels, rowids, n_sampled_rows, ncols,
tree_params.max_features, dinfo.NLocalrows, sparsetree, treeid,
tempmem);
train_time = timer.getElapsedSeconds();
ML::POP_RANGE();
}
Expand Down Expand Up @@ -379,7 +367,7 @@ void DecisionTreeBase<T, L>::base_fit(
const cudaStream_t stream_in, const T *data, const int ncols, const int nrows,
const L *labels, unsigned int *rowids, const int n_sampled_rows,
int unique_labels, std::vector<SparseTreeNode<T, L>> &sparsetree,
const int treeid, uint64_t seed, bool is_classifier,
const int treeid, uint64_t seed, bool is_classifier, T *d_global_quantiles,
std::shared_ptr<TemporaryMemory<T, L>> in_tempmem) {
prepare_fit_timer.reset();
const char *CRITERION_NAME[] = {"GINI", "ENTROPY", "MSE", "MAE", "END"};
Expand All @@ -406,19 +394,37 @@ void DecisionTreeBase<T, L>::base_fit(
"Unsupported criterion %s\n",
CRITERION_NAME[tree_params.split_criterion]);

if (in_tempmem != nullptr) {
tempmem = in_tempmem;
} else {
tempmem = std::make_shared<TemporaryMemory<T, L>>(
device_allocator_in, host_allocator_in, stream_in, nrows, ncols,
unique_labels, tree_params);
tree_params.quantile_per_tree = true;
if (!tree_params.use_experimental_backend) {
// Only execute for level backend as temporary memory is unused in batched
// backend.
if (in_tempmem != nullptr) {
tempmem = in_tempmem;
} else {
tempmem = std::make_shared<TemporaryMemory<T, L>>(
device_allocator_in, host_allocator_in, stream_in, nrows, ncols,
unique_labels, tree_params);
tree_params.quantile_per_tree = true;
}
}

plant(sparsetree, data, ncols, nrows, labels, rowids, n_sampled_rows,
unique_labels, treeid, seed);
if (in_tempmem == nullptr) {
tempmem.reset();
if (tree_params.use_experimental_backend) {
dinfo.NLocalrows = nrows;
dinfo.NGlobalrows = nrows;
dinfo.Ncols = ncols;
n_unique_labels = unique_labels;
if (treeid == 0) {
CUML_LOG_WARN("Using experimental backend for growing trees\n");
}
grow_tree(device_allocator_in, host_allocator_in, data, treeid, seed, ncols,
nrows, labels, d_global_quantiles, (int *)rowids, n_sampled_rows,
unique_labels, tree_params, stream_in, sparsetree,
this->leaf_counter, this->depth_counter);
} else {
plant(sparsetree, data, ncols, nrows, labels, rowids, n_sampled_rows,
unique_labels, treeid, seed);
if (in_tempmem == nullptr) {
tempmem.reset();
}
}
}

Expand All @@ -427,13 +433,13 @@ void DecisionTreeClassifier<T>::fit(
const raft::handle_t &handle, const T *data, const int ncols, const int nrows,
const int *labels, unsigned int *rowids, const int n_sampled_rows,
const int unique_labels, TreeMetaDataNode<T, int> *&tree,
DecisionTreeParams tree_parameters, uint64_t seed,
DecisionTreeParams tree_parameters, uint64_t seed, T *d_global_quantiles,
std::shared_ptr<TemporaryMemory<T, int>> in_tempmem) {
this->tree_params = tree_parameters;
this->base_fit(handle.get_device_allocator(), handle.get_host_allocator(),
handle.get_stream(), data, ncols, nrows, labels, rowids,
n_sampled_rows, unique_labels, tree->sparsetree, tree->treeid,
seed, true, in_tempmem);
seed, true, d_global_quantiles, in_tempmem);
this->set_metadata(tree);
}

Expand All @@ -445,12 +451,13 @@ void DecisionTreeClassifier<T>::fit(
const cudaStream_t stream_in, const T *data, const int ncols, const int nrows,
const int *labels, unsigned int *rowids, const int n_sampled_rows,
const int unique_labels, TreeMetaDataNode<T, int> *&tree,
DecisionTreeParams tree_parameters, uint64_t seed,
DecisionTreeParams tree_parameters, uint64_t seed, T *d_global_quantiles,
std::shared_ptr<TemporaryMemory<T, int>> in_tempmem) {
this->tree_params = tree_parameters;
this->base_fit(device_allocator_in, host_allocator_in, stream_in, data, ncols,
nrows, labels, rowids, n_sampled_rows, unique_labels,
tree->sparsetree, tree->treeid, seed, true, in_tempmem);
tree->sparsetree, tree->treeid, seed, true, d_global_quantiles,
in_tempmem);
this->set_metadata(tree);
}

Expand All @@ -459,12 +466,13 @@ void DecisionTreeRegressor<T>::fit(
const raft::handle_t &handle, const T *data, const int ncols, const int nrows,
const T *labels, unsigned int *rowids, const int n_sampled_rows,
TreeMetaDataNode<T, T> *&tree, DecisionTreeParams tree_parameters,
uint64_t seed, std::shared_ptr<TemporaryMemory<T, T>> in_tempmem) {
uint64_t seed, T *d_global_quantiles,
std::shared_ptr<TemporaryMemory<T, T>> in_tempmem) {
this->tree_params = tree_parameters;
this->base_fit(handle.get_device_allocator(), handle.get_host_allocator(),
handle.get_stream(), data, ncols, nrows, labels, rowids,
n_sampled_rows, 1, tree->sparsetree, tree->treeid, seed, false,
in_tempmem);
d_global_quantiles, in_tempmem);
this->set_metadata(tree);
}

Expand All @@ -475,11 +483,12 @@ void DecisionTreeRegressor<T>::fit(
const cudaStream_t stream_in, const T *data, const int ncols, const int nrows,
const T *labels, unsigned int *rowids, const int n_sampled_rows,
TreeMetaDataNode<T, T> *&tree, DecisionTreeParams tree_parameters,
uint64_t seed, std::shared_ptr<TemporaryMemory<T, T>> in_tempmem) {
uint64_t seed, T *d_global_quantiles,
std::shared_ptr<TemporaryMemory<T, T>> in_tempmem) {
this->tree_params = tree_parameters;
this->base_fit(device_allocator_in, host_allocator_in, stream_in, data, ncols,
nrows, labels, rowids, n_sampled_rows, 1, tree->sparsetree,
tree->treeid, seed, false, in_tempmem);
tree->treeid, seed, false, d_global_quantiles, in_tempmem);
this->set_metadata(tree);
}

Expand Down
11 changes: 6 additions & 5 deletions cpp/src/decisiontree/decisiontree_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class DecisionTreeBase {
const int nrows, const L *labels, unsigned int *rowids,
const int n_sampled_rows, int unique_labels,
std::vector<SparseTreeNode<T, L>> &sparsetree, const int treeid,
uint64_t seed, bool is_classifier,
uint64_t seed, bool is_classifier, T *d_global_quantiles,
std::shared_ptr<TemporaryMemory<T, L>> in_tempmem);

public:
Expand Down Expand Up @@ -140,7 +140,7 @@ class DecisionTreeClassifier : public DecisionTreeBase<T, int> {
const int nrows, const int *labels, unsigned int *rowids,
const int n_sampled_rows, const int unique_labels,
TreeMetaDataNode<T, int> *&tree, DecisionTreeParams tree_parameters,
uint64_t seed,
uint64_t seed, T *d_quantiles,
std::shared_ptr<TemporaryMemory<T, int>> in_tempmem = nullptr);

//This fit function does not take handle. Used by RF
Expand All @@ -150,7 +150,8 @@ class DecisionTreeClassifier : public DecisionTreeBase<T, int> {
const int nrows, const int *labels, unsigned int *rowids,
const int n_sampled_rows, const int unique_labels,
TreeMetaDataNode<T, int> *&tree, DecisionTreeParams tree_parameters,
uint64_t seed, std::shared_ptr<TemporaryMemory<T, int>> in_tempmem);
uint64_t seed, T *d_quantiles,
std::shared_ptr<TemporaryMemory<T, int>> in_tempmem);

private:
void grow_deep_tree(const T *data, const int *labels, unsigned int *rowids,
Expand All @@ -168,7 +169,7 @@ class DecisionTreeRegressor : public DecisionTreeBase<T, T> {
void fit(const raft::handle_t &handle, const T *data, const int ncols,
const int nrows, const T *labels, unsigned int *rowids,
const int n_sampled_rows, TreeMetaDataNode<T, T> *&tree,
DecisionTreeParams tree_parameters, uint64_t seed,
DecisionTreeParams tree_parameters, uint64_t seed, T *d_quantiles,
std::shared_ptr<TemporaryMemory<T, T>> in_tempmem = nullptr);

//This fit function does not take handle. Used by RF
Expand All @@ -177,7 +178,7 @@ class DecisionTreeRegressor : public DecisionTreeBase<T, T> {
const cudaStream_t stream_in, const T *data, const int ncols,
const int nrows, const T *labels, unsigned int *rowids,
const int n_sampled_rows, TreeMetaDataNode<T, T> *&tree,
DecisionTreeParams tree_parameters, uint64_t seed,
DecisionTreeParams tree_parameters, uint64_t seed, T *d_quantiles,
std::shared_ptr<TemporaryMemory<T, T>> in_tempmem);

private:
Expand Down
66 changes: 65 additions & 1 deletion cpp/src/decisiontree/quantile/quantile.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@
*/

#pragma once
#include <raft/cudart_utils.h>
#include <cub/cub.cuh>
#include <raft/cuda_utils.cuh>
#include <raft/mr/device/allocator.hpp>
#include <raft/mr/device/buffer.hpp>
#include "quantile.h"

#include <common/nvtx.hpp>

namespace ML {
namespace DecisionTree {

using device_allocator = raft::mr::device::allocator;
template <typename T>
using device_buffer = raft::mr::device::buffer<T>;

template <typename T>
__global__ void allcolsampler_kernel(const T *__restrict__ data,
const unsigned int *__restrict__ rowids,
Expand Down Expand Up @@ -183,5 +189,63 @@ void preprocess_quantile(const T *data, const unsigned int *rowids,
return;
}

/**
 * @brief Picks `n_bins` quantile thresholds out of one already-sorted column.
 *
 * Thread `tid` (flat 1-D global index) writes `quantiles[tid]` with the
 * element of `sorted_data` found at position
 * `round((tid + 1) * length / n_bins) - 1`, clamped to `[0, length - 1]`.
 *
 * Launch layout: 1-D grid of 1-D blocks providing at least `n_bins` threads;
 * surplus threads exit without touching memory. No shared memory is used.
 *
 * @param[out] quantiles    output array, written at indices `[0, n_bins)`
 * @param[in]  n_bins       number of quantiles to produce
 * @param[in]  sorted_data  input column, read at indices `[0, length)`;
 *                          expected to be sorted by the caller
 * @param[in]  length       number of elements in `sorted_data`; when it is
 *                          not positive nothing is written
 */
template <typename T>
__global__ void computeQuantilesSorted(T *quantiles, const int n_bins,
                                       const T *sorted_data, const int length) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  // Surplus threads of the last block, or an empty column: nothing to read.
  if (tid >= n_bins || length <= 0) return;

  double bin_width = static_cast<double>(length) / n_bins;
  int index = int(round((tid + 1) * bin_width)) - 1;
  // Old way of computing quantiles. Kept here for comparison.
  // To be deleted eventually
  // int index = (tid + 1) * floor(bin_width) - 1;

  // When length < n_bins, bin_width is below 1 and the rounded position of
  // the first bins can be 0, which would give index == -1. Clamp so the read
  // always stays inside sorted_data (the upper bound is purely defensive).
  index = min(max(index, 0), length - 1);
  quantiles[tid] = sorted_data[index];
}

/**
 * @brief Computes `n_bins` quantiles for every column of `data`.
 *
 * Each column is radix-sorted into a scratch buffer with
 * cub::DeviceRadixSort::SortKeys and then sampled by computeQuantilesSorted.
 * Column `col` is read from `data[col * n_rows]` onwards (`n_rows`
 * elements) and its quantiles are written to `quantiles[col * n_bins]`
 * onwards (`n_bins` elements).
 *
 * All sorts and kernel launches are enqueued on `stream`; this function does
 * not synchronize the stream itself.
 * NOTE(review): the scratch buffers are released when the function returns;
 * this presumably relies on the allocator's deallocation being ordered on
 * `stream` - confirm against the device_buffer implementation.
 *
 * @param[out] quantiles         device array of at least `n_bins * n_cols`
 *                               elements
 * @param[in]  n_bins            number of quantiles per column
 * @param[in]  data              device array of at least `n_rows * n_cols`
 *                               elements
 * @param[in]  n_rows            number of elements per column
 * @param[in]  n_cols            number of columns
 * @param[in]  device_allocator  allocator used for the scratch buffers
 * @param[in]  stream            CUDA stream on which all work is enqueued
 */
template <typename T>
void computeQuantiles(T *quantiles, int n_bins, const T *data, int n_rows,
                      int n_cols,
                      const std::shared_ptr<deviceAllocator> device_allocator,
                      cudaStream_t stream) {
  // Nothing to compute for an empty problem. This also avoids a zero-block
  // kernel launch (n_bins == 0), which is an invalid launch configuration.
  if (n_bins <= 0 || n_rows <= 0 || n_cols <= 0) return;

  // Scratch buffer receiving one sorted column at a time
  std::unique_ptr<device_buffer<T>> single_column_sorted =
    std::make_unique<device_buffer<T>>(device_allocator, stream, n_rows);

  // Determine temporary device storage requirements: with a null workspace
  // pointer CUB only reports the required size in temp_storage_bytes.
  size_t temp_storage_bytes = 0;
  CUDA_CHECK(cub::DeviceRadixSort::SortKeys(nullptr, temp_storage_bytes, data,
                                            single_column_sorted->data(),
                                            n_rows, 0, 8 * sizeof(T), stream));

  // Allocate temporary storage for sorting (reused for every column)
  std::unique_ptr<device_buffer<char>> d_temp_storage =
    std::make_unique<device_buffer<char>>(device_allocator, stream,
                                          temp_storage_bytes);

  // The launch configuration does not depend on the column
  const int threads_per_block = 128;
  const int blocks = raft::ceildiv(n_bins, threads_per_block);

  // Compute quantiles column by column
  for (int col = 0; col < n_cols; col++) {
    // 64-bit offsets: col * n_rows overflows int once the matrix holds more
    // than 2^31 - 1 elements.
    const size_t col_offset = static_cast<size_t>(col) * n_rows;
    const size_t quantile_offset = static_cast<size_t>(col) * n_bins;

    CUDA_CHECK(cub::DeviceRadixSort::SortKeys(
      (void *)d_temp_storage->data(), temp_storage_bytes, &data[col_offset],
      single_column_sorted->data(), n_rows, 0, 8 * sizeof(T), stream));

    computeQuantilesSorted<<<blocks, threads_per_block, 0, stream>>>(
      &quantiles[quantile_offset], n_bins, single_column_sorted->data(),
      n_rows);

    // Surface launch-configuration errors immediately
    CUDA_CHECK(cudaGetLastError());
  }

  return;
}

} // namespace DecisionTree
} // namespace ML
Loading

0 comments on commit de42e7f

Please sign in to comment.