Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Fixes memory allocation for experimental backend and improves quantile computations #3586

Merged
merged 40 commits into from
Mar 17, 2021
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
aa09671
Added new quantile computation code
vinaydes Feb 10, 2021
4eec348
Removing commented code
vinaydes Feb 10, 2021
d4a0d21
Adding quantile argument to decision tree fit function and other mino…
vinaydes Feb 12, 2021
1d729da
Changing RF code to handle temporary memory allocation when new backe…
vinaydes Feb 16, 2021
8af51b7
Making necessary changes in decision tree API to pass global quantile…
vinaydes Feb 16, 2021
4633080
Simplifying branching in RF fit functions and fixing deallocation
vinaydes Feb 16, 2021
e7f52ce
Dumping pre-plant data to file instead of stdout
vinaydes Feb 25, 2021
e580ac8
Removing inconsistent printing
vinaydes Mar 1, 2021
3251bb9
Computing quantiles in DT API
vinaydes Mar 3, 2021
0e3820e
Corrections related to quantile size computations and clean-up
vinaydes Mar 4, 2021
a52a109
Changing the entry point for batched backend
vinaydes Mar 4, 2021
8c4aa53
Changing the quantile variable data type to address crash issue
vinaydes Mar 4, 2021
3955f27
Merge branch 'branch-0.19' into fix-rf-memory-allocation
vinaydes Mar 5, 2021
a95a3a3
Including .h instead of .cuh as .cuh is included randomforest_impl.cuh
vinaydes Mar 7, 2021
2ea29fa
Removing debug prints from quantile computations
vinaydes Mar 7, 2021
004e71e
Minor changes
vinaydes Mar 7, 2021
a51c471
Undoing more minor changes
vinaydes Mar 7, 2021
80e402b
Undoing even more minor changes
vinaydes Mar 7, 2021
7104450
Blank line added
vinaydes Mar 7, 2021
0e1bc6a
Removing commented code
vinaydes Mar 7, 2021
774eae4
Whitespace and other minor changes
vinaydes Mar 7, 2021
7bf0e23
Matching whitespaces with branch-0.19
vinaydes Mar 7, 2021
8d94e1e
Using correct raft header and fixing call to computeQuantilesSorted
vinaydes Mar 7, 2021
0aa5c97
Removing additional debug prints
vinaydes Mar 7, 2021
8a87dbb
Fixing formatting
vinaydes Mar 7, 2021
f01e0b6
Fixing more formatting issues
vinaydes Mar 7, 2021
619ea7d
Deleting debug code
vinaydes Mar 7, 2021
51ef99e
Bringing RF classifier and regressor functions on parity with each other
vinaydes Mar 7, 2021
25118f2
Readding the experimental backend warning
vinaydes Mar 7, 2021
5634ff2
Removed unused argument from computeQuantiles()
vinaydes Mar 8, 2021
05d8420
Cleaning up computeQuantiles() arguments
vinaydes Mar 8, 2021
cf5fc87
Fixing for the case when experimental backend is not used
vinaydes Mar 8, 2021
ab63f1f
Addressing review comments
vinaydes Mar 8, 2021
6ec2503
Updating copyright year
vinaydes Mar 8, 2021
a33703c
Using std::unique_ptr as per review comment
vinaydes Mar 8, 2021
1d307fc
Using std::unique_ptr for cleaner API
vinaydes Mar 8, 2021
7911b01
Fixing new style issues
vinaydes Mar 8, 2021
b4c32b4
Adding unit test for quantile computation
vinaydes Mar 15, 2021
08ac7b6
Merge branch 'branch-0.19' into fix-rf-memory-allocation
vinaydes Mar 16, 2021
78dcdac
Replacing MLCommon with raft
vinaydes Mar 16, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 56 additions & 4 deletions cpp/src/decisiontree/decisiontree.cu
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,21 @@ void decisionTreeClassifierFit(const raft::handle_t &handle,
uint64_t seed) {
std::shared_ptr<DecisionTreeClassifier<float>> dt_classifier =
std::make_shared<DecisionTreeClassifier<float>>();
std::unique_ptr<MLCommon::device_buffer<float>> global_quantiles_buffer =
nullptr;
float *global_quantiles = nullptr;

if (tree_params.use_experimental_backend) {
auto quantile_size = tree_params.n_bins * ncols;
global_quantiles_buffer = std::make_unique<MLCommon::device_buffer<float>>(
handle.get_device_allocator(), handle.get_stream(), quantile_size);
global_quantiles = global_quantiles_buffer->data();
DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data,
nrows, ncols, handle.get_device_allocator(),
handle.get_stream());
}
dt_classifier->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows,
unique_labels, tree, tree_params, seed);
unique_labels, tree, tree_params, seed, global_quantiles);
}

void decisionTreeClassifierFit(const raft::handle_t &handle,
Expand All @@ -172,8 +185,21 @@ void decisionTreeClassifierFit(const raft::handle_t &handle,
uint64_t seed) {
std::shared_ptr<DecisionTreeClassifier<double>> dt_classifier =
std::make_shared<DecisionTreeClassifier<double>>();
std::unique_ptr<MLCommon::device_buffer<double>> global_quantiles_buffer =
nullptr;
double *global_quantiles = nullptr;

if (tree_params.use_experimental_backend) {
auto quantile_size = tree_params.n_bins * ncols;
global_quantiles_buffer = std::make_unique<MLCommon::device_buffer<double>>(
handle.get_device_allocator(), handle.get_stream(), quantile_size);
global_quantiles = global_quantiles_buffer->data();
DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data,
nrows, ncols, handle.get_device_allocator(),
handle.get_stream());
}
dt_classifier->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows,
unique_labels, tree, tree_params, seed);
unique_labels, tree, tree_params, seed, global_quantiles);
}

void decisionTreeClassifierPredict(const raft::handle_t &handle,
Expand Down Expand Up @@ -208,8 +234,21 @@ void decisionTreeRegressorFit(const raft::handle_t &handle,
uint64_t seed) {
std::shared_ptr<DecisionTreeRegressor<float>> dt_regressor =
std::make_shared<DecisionTreeRegressor<float>>();
std::unique_ptr<MLCommon::device_buffer<float>> global_quantiles_buffer =
nullptr;
float *global_quantiles = nullptr;

if (tree_params.use_experimental_backend) {
auto quantile_size = tree_params.n_bins * ncols;
global_quantiles_buffer = std::make_unique<MLCommon::device_buffer<float>>(
handle.get_device_allocator(), handle.get_stream(), quantile_size);
global_quantiles = global_quantiles_buffer->data();
DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data,
nrows, ncols, handle.get_device_allocator(),
handle.get_stream());
}
dt_regressor->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows,
tree, tree_params, seed);
tree, tree_params, seed, global_quantiles);
}

void decisionTreeRegressorFit(const raft::handle_t &handle,
Expand All @@ -220,8 +259,21 @@ void decisionTreeRegressorFit(const raft::handle_t &handle,
uint64_t seed) {
std::shared_ptr<DecisionTreeRegressor<double>> dt_regressor =
std::make_shared<DecisionTreeRegressor<double>>();
std::unique_ptr<MLCommon::device_buffer<double>> global_quantiles_buffer =
nullptr;
double *global_quantiles = nullptr;

if (tree_params.use_experimental_backend) {
auto quantile_size = tree_params.n_bins * ncols;
global_quantiles_buffer = std::make_unique<MLCommon::device_buffer<double>>(
handle.get_device_allocator(), handle.get_stream(), quantile_size);
global_quantiles = global_quantiles_buffer->data();
DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data,
nrows, ncols, handle.get_device_allocator(),
handle.get_stream());
}
dt_regressor->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows,
tree, tree_params, seed);
tree, tree_params, seed, global_quantiles);
}

void decisionTreeRegressorPredict(const raft::handle_t &handle,
Expand Down
79 changes: 44 additions & 35 deletions cpp/src/decisiontree/decisiontree_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include "levelalgo/levelfunc_regressor.cuh"
#include "levelalgo/metric.cuh"
#include "memory.cuh"
#include "quantile/quantile.cuh"
#include "quantile/quantile.h"
#include "treelite_util.h"

Expand Down Expand Up @@ -293,20 +292,9 @@ void DecisionTreeBase<T, L>::plant(

total_temp_mem = tempmem->totalmem;
MLCommon::TimerCPU timer;
if (tree_params.use_experimental_backend) {
if (treeid == 0) {
CUML_LOG_WARN("Using experimental backend for growing trees\n");
}
T *quantiles = tempmem->d_quantile->data();
grow_tree(tempmem->device_allocator, tempmem->host_allocator, data, treeid,
seed, ncols, nrows, labels, quantiles, (int *)rowids,
n_sampled_rows, unique_labels, tree_params, tempmem->stream,
sparsetree, this->leaf_counter, this->depth_counter);
} else {
grow_deep_tree(data, labels, rowids, n_sampled_rows, ncols,
tree_params.max_features, dinfo.NLocalrows, sparsetree,
treeid, tempmem);
}
grow_deep_tree(data, labels, rowids, n_sampled_rows, ncols,
tree_params.max_features, dinfo.NLocalrows, sparsetree, treeid,
tempmem);
train_time = timer.getElapsedSeconds();
ML::POP_RANGE();
}
Expand Down Expand Up @@ -379,7 +367,7 @@ void DecisionTreeBase<T, L>::base_fit(
const cudaStream_t stream_in, const T *data, const int ncols, const int nrows,
const L *labels, unsigned int *rowids, const int n_sampled_rows,
int unique_labels, std::vector<SparseTreeNode<T, L>> &sparsetree,
const int treeid, uint64_t seed, bool is_classifier,
const int treeid, uint64_t seed, bool is_classifier, T *d_global_quantiles,
std::shared_ptr<TemporaryMemory<T, L>> in_tempmem) {
prepare_fit_timer.reset();
const char *CRITERION_NAME[] = {"GINI", "ENTROPY", "MSE", "MAE", "END"};
Expand All @@ -406,19 +394,37 @@ void DecisionTreeBase<T, L>::base_fit(
"Unsupported criterion %s\n",
CRITERION_NAME[tree_params.split_criterion]);

if (in_tempmem != nullptr) {
tempmem = in_tempmem;
} else {
tempmem = std::make_shared<TemporaryMemory<T, L>>(
device_allocator_in, host_allocator_in, stream_in, nrows, ncols,
unique_labels, tree_params);
tree_params.quantile_per_tree = true;
if (!tree_params.use_experimental_backend) {
// Only execute for level backend as temporary memory is unused in batched
// backend.
if (in_tempmem != nullptr) {
tempmem = in_tempmem;
} else {
tempmem = std::make_shared<TemporaryMemory<T, L>>(
device_allocator_in, host_allocator_in, stream_in, nrows, ncols,
unique_labels, tree_params);
tree_params.quantile_per_tree = true;
}
}

plant(sparsetree, data, ncols, nrows, labels, rowids, n_sampled_rows,
unique_labels, treeid, seed);
if (in_tempmem == nullptr) {
tempmem.reset();
if (tree_params.use_experimental_backend) {
dinfo.NLocalrows = nrows;
dinfo.NGlobalrows = nrows;
dinfo.Ncols = ncols;
n_unique_labels = unique_labels;
if (treeid == 0) {
CUML_LOG_WARN("Using experimental backend for growing trees\n");
}
grow_tree(device_allocator_in, host_allocator_in, data, treeid, seed, ncols,
nrows, labels, d_global_quantiles, (int *)rowids, n_sampled_rows,
unique_labels, tree_params, stream_in, sparsetree,
this->leaf_counter, this->depth_counter);
} else {
plant(sparsetree, data, ncols, nrows, labels, rowids, n_sampled_rows,
unique_labels, treeid, seed);
if (in_tempmem == nullptr) {
tempmem.reset();
}
}
}

Expand All @@ -427,13 +433,13 @@ void DecisionTreeClassifier<T>::fit(
const raft::handle_t &handle, const T *data, const int ncols, const int nrows,
const int *labels, unsigned int *rowids, const int n_sampled_rows,
const int unique_labels, TreeMetaDataNode<T, int> *&tree,
DecisionTreeParams tree_parameters, uint64_t seed,
DecisionTreeParams tree_parameters, uint64_t seed, T *d_global_quantiles,
std::shared_ptr<TemporaryMemory<T, int>> in_tempmem) {
this->tree_params = tree_parameters;
this->base_fit(handle.get_device_allocator(), handle.get_host_allocator(),
handle.get_stream(), data, ncols, nrows, labels, rowids,
n_sampled_rows, unique_labels, tree->sparsetree, tree->treeid,
seed, true, in_tempmem);
seed, true, d_global_quantiles, in_tempmem);
this->set_metadata(tree);
}

Expand All @@ -445,12 +451,13 @@ void DecisionTreeClassifier<T>::fit(
const cudaStream_t stream_in, const T *data, const int ncols, const int nrows,
const int *labels, unsigned int *rowids, const int n_sampled_rows,
const int unique_labels, TreeMetaDataNode<T, int> *&tree,
DecisionTreeParams tree_parameters, uint64_t seed,
DecisionTreeParams tree_parameters, uint64_t seed, T *d_global_quantiles,
std::shared_ptr<TemporaryMemory<T, int>> in_tempmem) {
this->tree_params = tree_parameters;
this->base_fit(device_allocator_in, host_allocator_in, stream_in, data, ncols,
nrows, labels, rowids, n_sampled_rows, unique_labels,
tree->sparsetree, tree->treeid, seed, true, in_tempmem);
tree->sparsetree, tree->treeid, seed, true, d_global_quantiles,
in_tempmem);
this->set_metadata(tree);
}

Expand All @@ -459,12 +466,13 @@ void DecisionTreeRegressor<T>::fit(
const raft::handle_t &handle, const T *data, const int ncols, const int nrows,
const T *labels, unsigned int *rowids, const int n_sampled_rows,
TreeMetaDataNode<T, T> *&tree, DecisionTreeParams tree_parameters,
uint64_t seed, std::shared_ptr<TemporaryMemory<T, T>> in_tempmem) {
uint64_t seed, T *d_global_quantiles,
std::shared_ptr<TemporaryMemory<T, T>> in_tempmem) {
this->tree_params = tree_parameters;
this->base_fit(handle.get_device_allocator(), handle.get_host_allocator(),
handle.get_stream(), data, ncols, nrows, labels, rowids,
n_sampled_rows, 1, tree->sparsetree, tree->treeid, seed, false,
in_tempmem);
d_global_quantiles, in_tempmem);
this->set_metadata(tree);
}

Expand All @@ -475,11 +483,12 @@ void DecisionTreeRegressor<T>::fit(
const cudaStream_t stream_in, const T *data, const int ncols, const int nrows,
const T *labels, unsigned int *rowids, const int n_sampled_rows,
TreeMetaDataNode<T, T> *&tree, DecisionTreeParams tree_parameters,
uint64_t seed, std::shared_ptr<TemporaryMemory<T, T>> in_tempmem) {
uint64_t seed, T *d_global_quantiles,
std::shared_ptr<TemporaryMemory<T, T>> in_tempmem) {
this->tree_params = tree_parameters;
this->base_fit(device_allocator_in, host_allocator_in, stream_in, data, ncols,
nrows, labels, rowids, n_sampled_rows, 1, tree->sparsetree,
tree->treeid, seed, false, in_tempmem);
tree->treeid, seed, false, d_global_quantiles, in_tempmem);
this->set_metadata(tree);
}

Expand Down
11 changes: 6 additions & 5 deletions cpp/src/decisiontree/decisiontree_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class DecisionTreeBase {
const int nrows, const L *labels, unsigned int *rowids,
const int n_sampled_rows, int unique_labels,
std::vector<SparseTreeNode<T, L>> &sparsetree, const int treeid,
uint64_t seed, bool is_classifier,
uint64_t seed, bool is_classifier, T *d_global_quantiles,
std::shared_ptr<TemporaryMemory<T, L>> in_tempmem);

public:
Expand Down Expand Up @@ -138,7 +138,7 @@ class DecisionTreeClassifier : public DecisionTreeBase<T, int> {
const int nrows, const int *labels, unsigned int *rowids,
const int n_sampled_rows, const int unique_labels,
TreeMetaDataNode<T, int> *&tree, DecisionTreeParams tree_parameters,
uint64_t seed,
uint64_t seed, T *d_quantiles,
std::shared_ptr<TemporaryMemory<T, int>> in_tempmem = nullptr);

//This fit function does not take handle, used by RF
Expand All @@ -148,7 +148,8 @@ class DecisionTreeClassifier : public DecisionTreeBase<T, int> {
const int nrows, const int *labels, unsigned int *rowids,
const int n_sampled_rows, const int unique_labels,
TreeMetaDataNode<T, int> *&tree, DecisionTreeParams tree_parameters,
uint64_t seed, std::shared_ptr<TemporaryMemory<T, int>> in_tempmem);
uint64_t seed, T *d_quantiles,
std::shared_ptr<TemporaryMemory<T, int>> in_tempmem);

private:
void grow_deep_tree(const T *data, const int *labels, unsigned int *rowids,
Expand All @@ -166,7 +167,7 @@ class DecisionTreeRegressor : public DecisionTreeBase<T, T> {
void fit(const raft::handle_t &handle, const T *data, const int ncols,
const int nrows, const T *labels, unsigned int *rowids,
const int n_sampled_rows, TreeMetaDataNode<T, T> *&tree,
DecisionTreeParams tree_parameters, uint64_t seed,
DecisionTreeParams tree_parameters, uint64_t seed, T *d_quantiles,
std::shared_ptr<TemporaryMemory<T, T>> in_tempmem = nullptr);

//This fit function does not take handle. Used by RF
Expand All @@ -175,7 +176,7 @@ class DecisionTreeRegressor : public DecisionTreeBase<T, T> {
const cudaStream_t stream_in, const T *data, const int ncols,
const int nrows, const T *labels, unsigned int *rowids,
const int n_sampled_rows, TreeMetaDataNode<T, T> *&tree,
DecisionTreeParams tree_parameters, uint64_t seed,
DecisionTreeParams tree_parameters, uint64_t seed, T *d_quantiles,
std::shared_ptr<TemporaryMemory<T, T>> in_tempmem);

private:
Expand Down
66 changes: 65 additions & 1 deletion cpp/src/decisiontree/quantile/quantile.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/

#pragma once
#include <raft/cudart_utils.h>
#include <cub/cub.cuh>
#include <raft/cuda_utils.cuh>
#include "quantile.h"

#include <common/nvtx.hpp>
Expand Down Expand Up @@ -183,5 +183,69 @@ void preprocess_quantile(const T *data, const unsigned int *rowids,
return;
}

/**
 * @brief Samples `n_bins` quantile values from an already sorted array.
 *
 * One thread handles one bin: thread `tid` copies the element located at the
 * end of the (tid + 1)-th of `n_bins` equal-width slices of `sorted_data`
 * into `quantiles[tid]`. Threads with `tid >= n_bins` do nothing, so the
 * launch may be over-provisioned.
 *
 * @tparam T element type of the data and of the quantiles
 * @param[out] quantiles   output array with room for `n_bins` values
 * @param[in]  n_bins      number of quantiles to produce
 * @param[in]  sorted_data input array of `length` values; the indexing below
 *                         only yields quantiles if it is sorted ascending
 * @param[in]  length      number of elements in `sorted_data`; must be > 0,
 *                         otherwise there is no valid element to read
 */
template <typename T>
__global__ void computeQuantilesSorted(T *quantiles, const int n_bins,
                                       const T *sorted_data, const int length) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid >= n_bins) {
    return;
  }

  const double bin_width = static_cast<double>(length) / n_bins;
  // Old way of computing quantiles. Kept here for comparison.
  // To be deleted eventually
  // int index = (tid + 1) * floor(bin_width) - 1;
  int index = int(round((tid + 1) * bin_width)) - 1;
  // When length < n_bins, bin_width < 1 and the rounded position of the first
  // bins can be 0, which makes `index` -1. Clamp to the valid range so we
  // never read outside sorted_data.
  index = min(max(0, index), length - 1);

  quantiles[tid] = sorted_data[index];
}

template <typename T>
void computeQuantiles(
T *quantiles, int n_bins, const T *data, int n_rows, int n_cols,
const std::shared_ptr<MLCommon::deviceAllocator> device_allocator,
vinaydes marked this conversation as resolved.
Show resolved Hide resolved
cudaStream_t stream) {
// Determine temporary device storage requirements
MLCommon::device_buffer<char> *d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;

MLCommon::device_buffer<T> *single_column_sorted;
single_column_sorted =
new MLCommon::device_buffer<T>(device_allocator, stream, n_rows);

vinaydes marked this conversation as resolved.
Show resolved Hide resolved
CUDA_CHECK(cub::DeviceRadixSort::SortKeys(d_temp_storage, temp_storage_bytes,
data, single_column_sorted->data(),
n_rows, 0, 8 * sizeof(T), stream));

// Allocate temporary storage for sorting
d_temp_storage = new MLCommon::device_buffer<char>(device_allocator, stream,
vinaydes marked this conversation as resolved.
Show resolved Hide resolved
temp_storage_bytes);

// Compute quantiles column by column
for (int col = 0; col < n_cols; col++) {
int col_offset = col * n_rows;
int quantile_offset = col * n_bins;

CUDA_CHECK(cub::DeviceRadixSort::SortKeys(
(void *)d_temp_storage->data(), temp_storage_bytes, &data[col_offset],
single_column_sorted->data(), n_rows, 0, 8 * sizeof(T), stream));

int blocks = raft::ceildiv(n_bins, 128);

computeQuantilesSorted<<<blocks, 128, 0, stream>>>(
&quantiles[quantile_offset], n_bins, single_column_sorted->data(),
n_rows);

CUDA_CHECK(cudaGetLastError());
}

single_column_sorted->release(stream);
d_temp_storage->release(stream);

delete single_column_sorted;
delete d_temp_storage;

return;
}

} // namespace DecisionTree
} // namespace ML
8 changes: 7 additions & 1 deletion cpp/src/decisiontree/quantile/quantile.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,5 +26,11 @@ void preprocess_quantile(const T *data, const unsigned int *rowids,
const int rowoffset, const int nbins,
std::shared_ptr<TemporaryMemory<T, L>> tempmem);

/**
 * @brief Computes `n_bins` quantiles for each column of `data`.
 *
 * @tparam T element type of the data and of the quantiles
 * @param[out] quantiles        output array; quantiles of column `col` are
 *                              written starting at `quantiles[col * n_bins]`
 * @param[in]  n_bins           number of quantiles computed per column
 * @param[in]  data             input array; column `col` is read starting at
 *                              `data[col * n_rows]`
 * @param[in]  n_rows           number of rows (elements per column)
 * @param[in]  n_cols           number of columns
 * @param[in]  device_allocator allocator used for temporary device buffers
 * @param[in]  stream           CUDA stream on which the work is enqueued
 */
template <typename T>
void computeQuantiles(
T *quantiles, int n_bins, const T *data, int n_rows, int n_cols,
const std::shared_ptr<MLCommon::deviceAllocator> device_allocator,
cudaStream_t stream);

} // namespace DecisionTree
} // namespace ML
Loading