Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor gpu_hist split evaluation #5610

Merged
merged 6 commits on Apr 29, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cub
Submodule cub updated 143 files
10 changes: 5 additions & 5 deletions src/tree/constraints.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ struct ValueConstraint {
/*!
 * \brief Size the monotone constraint list stored in \p param.
 *
 * Resizes param->monotone_constraints to \p num_feature entries, passing 0 as
 * the fill value for any entries the resize creates.
 *
 * \param param       Training parameters whose constraint list is resized.
 * \param num_feature Requested number of entries.
 */
inline static void Init(tree::TrainParam *param, unsigned num_feature) {
  auto &constraints = param->monotone_constraints;
  constraints.resize(num_feature, 0);
}
template <typename ParamT>
XGBOOST_DEVICE inline double CalcWeight(const ParamT &param, tree::GradStats stats) const {
template <typename ParamT, typename GpairT>
XGBOOST_DEVICE inline double CalcWeight(const ParamT &param, GpairT stats) const {
trivialfis marked this conversation as resolved.
Show resolved Hide resolved
double w = xgboost::tree::CalcWeight(param, stats);
if (w < lower_bound) {
return lower_bound;
Expand Down Expand Up @@ -63,9 +63,9 @@ struct ValueConstraint {
return wleft >= wright ? gain : negative_infinity;
}
}

inline void SetChild(const tree::TrainParam &param, bst_uint split_index,
tree::GradStats left, tree::GradStats right, ValueConstraint *cleft,
template <typename GpairT>
void SetChild(const tree::TrainParam &param, bst_uint split_index,
GpairT left, GpairT right, ValueConstraint *cleft,
ValueConstraint *cright) {
int c = param.monotone_constraints.at(split_index);
*cleft = *this;
Expand Down
261 changes: 261 additions & 0 deletions src/tree/gpu_hist/evaluate_splits.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#include "evaluate_splits.cuh"
#include <limits>

namespace xgboost {
namespace tree {

/*!
 * \brief Loss change of one candidate split with constraints applied, trying
 *        both default directions for the missing-value gradient.
 *
 * Two split gains are computed through value_constraint.CalcSplitGain: one with
 * \p missing added to the left child (\p scan + \p missing) and one with it
 * left out of the left child (\p scan only).  The right child is always
 * \p parent_sum minus the left child.  The larger gain wins; ties go to the
 * "missing left" direction.
 *
 * \param scan             Gradient sum placed on the left side before missing
 *                         values are considered.
 * \param missing          Gradient sum attributed to missing values.
 * \param parent_sum       Gradient sum of the node being split.
 * \param param            Training parameters forwarded to the gain functions.
 * \param constraint       Constraint value forwarded to CalcSplitGain.
 * \param value_constraint Object whose CalcSplitGain evaluates each direction.
 * \param missing_left_out Output: true when the "missing left" gain was chosen.
 *
 * \return The chosen split gain minus CalcGain(param, parent_sum).
 */
template <typename GradientPairT>
XGBOOST_DEVICE float LossChangeMissing(const GradientPairT& scan,
                                       const GradientPairT& missing,
                                       const GradientPairT& parent_sum,
                                       const GPUTrainingParam& param,
                                       int constraint,
                                       const ValueConstraint& value_constraint,
                                       bool& missing_left_out) {  // NOLINT
  float parent_gain = CalcGain(param, parent_sum);
  float missing_left_gain = value_constraint.CalcSplitGain(
      param, constraint, GradStats(scan + missing),
      GradStats(parent_sum - (scan + missing)));
  float missing_right_gain = value_constraint.CalcSplitGain(
      param, constraint, GradStats(scan), GradStats(parent_sum - scan));

  // `>=` keeps the original tie-break: equal gains send missing values left.
  missing_left_out = missing_left_gain >= missing_right_gain;
  return (missing_left_out ? missing_left_gain : missing_right_gain) -
         parent_gain;
}

/*!
 * \brief Block-wide sum of all histogram bins of one feature.
 *
 * Every thread accumulates the bins at offsets threadIdx.x,
 * threadIdx.x + BLOCK_THREADS, ... and the per-thread partial sums are then
 * combined with ReduceT.  Thread 0 publishes the total through shared memory
 * and, after a barrier, every thread returns it.
 *
 * Contains __syncthreads(), so all threads of the block must call it.
 *
 * \tparam BLOCK_THREADS Stride between the bins visited by one thread
 *                       (presumably equal to blockDim.x -- confirm at the
 *                       launch site).
 * \tparam ReduceT       Block reduce type, constructed from
 *                       temp_storage->sum_reduce.
 * \tparam TempStorageT  Type providing the `sum_reduce` temporary storage.
 * \tparam GradientSumT  Element type of the histogram; a default-constructed
 *                       value is used as the contribution of idle threads.
 *
 * \param feature_histogram Bins to be summed.
 * \param temp_storage      Temporary storage for the block reduction.
 *
 * \return Sum of all elements of \p feature_histogram, identical on every
 *         thread.
 */
template <int BLOCK_THREADS, typename ReduceT, typename TempStorageT,
          typename GradientSumT>
__device__ GradientSumT
ReduceFeature(common::Span<const GradientSumT> feature_histogram,
              TempStorageT* temp_storage) {
  __shared__ cub::Uninitialized<GradientSumT> uninitialized_sum;
  GradientSumT& shared_sum = uninitialized_sum.Alias();

  GradientSumT local_sum = GradientSumT();
  // Walk the bins by index rather than by advancing a raw pointer: forming a
  // pointer beyond one-past-the-end of the span is undefined behaviour, even
  // if it is never dereferenced.
  auto const* bins = feature_histogram.data();
  size_t const n_bins = feature_histogram.size();
  for (size_t tile_begin = 0; tile_begin < n_bins;
       tile_begin += BLOCK_THREADS) {
    size_t const idx = tile_begin + threadIdx.x;
    // Threads past the end of the last tile contribute a default value.
    GradientSumT bin = idx < n_bins ? bins[idx] : GradientSumT();
    local_sum += bin;
  }
  local_sum = ReduceT(temp_storage->sum_reduce).Reduce(local_sum, cub::Sum());
  // Reduction result is stored in thread 0.
  if (threadIdx.x == 0) {
    shared_sum = local_sum;
  }
  // Make thread 0's write visible before any thread reads shared_sum.
  __syncthreads();
  return shared_sum;
}

/*!
 * \brief Evaluate the candidate splits of one feature and offer the best one
 *        of each tile to \p best_split.
 *
 * The feature's bins are the range
 * [feature_segments[fidx], feature_segments[fidx + 1]) of the histogram.  They
 * are processed in tiles of BLOCK_THREADS bins.  For each tile the block runs
 * an exclusive prefix sum, each in-range thread computes the loss change of
 * its candidate, an arg-max reduction picks the winning thread, and that
 * thread calls best_split->Update().
 *
 * Contains block-wide barriers and cub block primitives, so all threads of the
 * block must call it together.
 *
 * \param fidx         Feature index to evaluate.
 * \param inputs       Histogram, cut values and parameters of the node.
 * \param best_split   Candidate that winning threads update.
 * \param temp_storage Storage providing `sum_reduce`, `scan` and `max_reduce`.
 */
template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
          typename MaxReduceT, typename TempStorageT, typename GradientSumT>
__device__ void EvaluateFeature(
    int fidx, EvaluateSplitInputs<GradientSumT> inputs,
    DeviceSplitCandidate* best_split,  // shared memory storing best split
    TempStorageT* temp_storage         // temp memory for cub operations
) {
  // The segment pointers delimit the bins that belong to this feature.
  uint32_t bin_begin = inputs.feature_segments[fidx];    // beginning bin
  uint32_t bin_end = inputs.feature_segments[fidx + 1];  // one past last bin

  // Total gradient over this feature's bins.
  auto const feature_hist =
      inputs.gradient_histogram.subspan(bin_begin, bin_end - bin_begin);
  GradientSumT const feature_sum =
      ReduceFeature<BLOCK_THREADS, ReduceT, TempStorageT, GradientSumT>(
          feature_hist, temp_storage);

  // Whatever the parent holds beyond the feature's bins is treated as the
  // gradient of missing values.
  GradientSumT const missing = inputs.parent_sum - feature_sum;
  float const null_gain = -std::numeric_limits<bst_float>::infinity();

  // NOTE(review): prefix_op presumably carries the running total from one tile
  // to the next -- confirm against SumCallbackOp.
  SumCallbackOp<GradientSumT> prefix_op = SumCallbackOp<GradientSumT>();
  for (int tile_begin = bin_begin; tile_begin < bin_end;
       tile_begin += BLOCK_THREADS) {
    auto const my_bin = tile_begin + threadIdx.x;
    bool const in_range = my_bin < bin_end;

    // Load this thread's bin (or a default value past the end), then replace
    // it in place with the exclusive prefix sum.
    GradientSumT scanned =
        in_range ? inputs.gradient_histogram[my_bin] : GradientSumT();
    ScanT(temp_storage->scan)
        .ExclusiveScan(scanned, scanned, cub::Sum(), prefix_op);

    // Whether the gradient of missing values is put to the left side.
    bool missing_left = true;
    // Out-of-range threads keep -inf so they can never win the arg-max.
    float gain = null_gain;
    if (in_range) {
      gain = LossChangeMissing(scanned, missing, inputs.parent_sum,
                               inputs.param, inputs.monotonic_constraints[fidx],
                               inputs.value_constraint, missing_left);
    }

    // temp_storage is reused by the reduction below.
    __syncthreads();

    // Arg-max over (thread id, gain); the result is taken from thread 0 and
    // broadcast through shared memory.
    cub::KeyValuePair<int, float> candidate(threadIdx.x, gain);
    cub::KeyValuePair<int, float> winner =
        MaxReduceT(temp_storage->max_reduce).Reduce(candidate, cub::ArgMax());

    __shared__ cub::KeyValuePair<int, float> block_max;
    if (threadIdx.x == 0) {
      block_max = winner;
    }

    __syncthreads();

    // Only the winning thread touches best_split.
    if (threadIdx.x == block_max.key) {
      // The scan is exclusive, so the split sits just before this bin.
      int split_gidx = my_bin - 1;
      // A split before the first bin of the feature uses min_fvalue instead
      // of a cut value.
      float fvalue = split_gidx < static_cast<int>(bin_begin)
                         ? inputs.min_fvalue[fidx]
                         : inputs.feature_values[split_gidx];
      GradientSumT left = missing_left ? scanned + missing : scanned;
      GradientSumT right = inputs.parent_sum - left;
      best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue,
                         fidx, GradientPair(left), GradientPair(right),
                         inputs.param);
    }
    // Finish with block_max and best_split before the next tile starts.
    __syncthreads();
  }
}

/*!
 * \brief Kernel computing the best split of one (node, feature) pair per
 *        thread block.
 *
 * Launch layout: a 1-D grid with one block for each entry of left.feature_set
 * followed by one block for each entry of right.feature_set.  The cub block
 * primitives are instantiated for BLOCK_THREADS threads, so the block size is
 * expected to equal BLOCK_THREADS.  Only statically declared shared memory is
 * used.
 *
 * \param left           Inputs for blocks [0, left.feature_set.size()).
 * \param right          Inputs for the remaining blocks.
 * \param out_candidates Receives one candidate per block, at index blockIdx.x.
 */
template <int BLOCK_THREADS, typename GradientSumT>
__global__ void EvaluateSplitsKernel(
    EvaluateSplitInputs<GradientSumT> left,
    EvaluateSplitInputs<GradientSumT> right,
    common::Span<DeviceSplitCandidate> out_candidates) {
  // KeyValuePair here used as threadIdx.x -> gain_value
  using ArgMaxT = cub::KeyValuePair<int, float>;
  using SumReduceT = cub::BlockReduce<GradientSumT, BLOCK_THREADS>;
  using MaxReduceT = cub::BlockReduce<ArgMaxT, BLOCK_THREADS>;
  using BlockScanT =
      cub::BlockScan<GradientSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>;

  // The three cub primitives share one piece of storage.
  union TempStorage {
    typename BlockScanT::TempStorage scan;
    typename MaxReduceT::TempStorage max_reduce;
    typename SumReduceT::TempStorage sum_reduce;
  };

  // Aligned && shared storage for best_split
  __shared__ cub::Uninitialized<DeviceSplitCandidate> uninitialized_split;
  DeviceSplitCandidate& best_split = uninitialized_split.Alias();
  __shared__ TempStorage temp_storage;

  // Thread 0 resets the shared candidate; the barrier publishes it.
  if (threadIdx.x == 0) {
    best_split = DeviceSplitCandidate();
  }
  __syncthreads();

  // The leading blocks serve the left node, the rest the right node.
  auto const n_left_features = left.feature_set.size();
  bool const is_left = blockIdx.x < n_left_features;
  EvaluateSplitInputs<GradientSumT>& inputs = is_left ? left : right;

  // One block for each feature. Features are sampled, so fidx != blockIdx.x
  auto const feature_slot =
      is_left ? blockIdx.x : blockIdx.x - n_left_features;
  int fidx = inputs.feature_set[feature_slot];

  EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT>(
      fidx, inputs, &best_split, &temp_storage);

  __syncthreads();

  // Record best loss for each feature
  if (threadIdx.x == 0) {
    out_candidates[blockIdx.x] = best_split;
  }
}

/*!
 * \brief "Addition" of two split candidates that keeps the one with the
 *        strictly larger loss_chg.
 *
 * When loss_chg is equal (or the comparison is otherwise false) \p b is
 * returned.  Defining operator+ this way lets a sum reduction select the best
 * candidate.
 */
__device__ DeviceSplitCandidate operator+(const DeviceSplitCandidate& a,
                                          const DeviceSplitCandidate& b) {
  if (a.loss_chg > b.loss_chg) {
    return a;
  }
  return b;
}

/*!
 * \brief Find the best split of up to two nodes in one pass.
 *
 * Launches EvaluateSplitsKernel with one block per sampled feature of \p left
 * followed by one block per sampled feature of \p right, then reduces the
 * per-feature candidates of each node with a segmented reduction.  The
 * reduction relies on operator+ for DeviceSplitCandidate, which keeps the
 * candidate with the larger loss_chg.
 *
 * \param out_splits Output.  Element 0 receives the best candidate over the
 *                   features of \p left and element 1 the best over the
 *                   features of \p right.  Only min(out_splits.size(), 2)
 *                   elements are written, so a span of size 1 (as forwarded
 *                   by EvaluateSingleSplit) is never written out of bounds.
 * \param left       Inputs of the first node.
 * \param right      Inputs of the second node; its feature_set may be empty.
 */
template <typename GradientSumT>
void EvaluateSplits(common::Span<DeviceSplitCandidate> out_splits,
                    EvaluateSplitInputs<GradientSumT> left,
                    EvaluateSplitInputs<GradientSumT> right) {
  size_t const n_left_features = left.feature_set.size();
  size_t combined_num_features = n_left_features + right.feature_set.size();
  dh::TemporaryArray<DeviceSplitCandidate> feature_best_splits(
      combined_num_features);
  // One block for each feature
  uint32_t constexpr kBlockThreads = 256;
  dh::LaunchKernel {uint32_t(combined_num_features), kBlockThreads, 0}(
      EvaluateSplitsKernel<kBlockThreads, GradientSumT>, left, right,
      dh::ToSpan(feature_best_splits));

  // Reduce to get best candidate for left and right child over all features.
  // Segment boundaries: [0, n_left_features) and
  // [n_left_features, combined_num_features).  Only the two sizes are
  // captured, not the whole input struct.
  auto reduce_offset =
      dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0llu),
                                        [=] __device__(size_t idx) -> size_t {
                                          if (idx == 0) {
                                            return 0;
                                          }
                                          if (idx == 1) {
                                            return n_left_features;
                                          }
                                          if (idx == 2) {
                                            return combined_num_features;
                                          }
                                          return 0;
                                        });
  // Never ask for more output segments than the caller's span can hold; a
  // hard-coded 2 wrote past the end of a single-element span.
  int const num_segments =
      out_splits.size() < 2 ? static_cast<int>(out_splits.size()) : 2;
  // The first call only queries the required temporary storage size.
  size_t temp_storage_bytes = 0;
  cub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_bytes,
                                  feature_best_splits.data(), out_splits.data(),
                                  num_segments, reduce_offset,
                                  reduce_offset + 1);
  dh::TemporaryArray<int8_t> temp(temp_storage_bytes);
  cub::DeviceSegmentedReduce::Sum(temp.data().get(), temp_storage_bytes,
                                  feature_best_splits.data(), out_splits.data(),
                                  num_segments, reduce_offset,
                                  reduce_offset + 1);
}

/*!
 * \brief Evaluate the best split of a single node.
 *
 * Forwards to EvaluateSplits with \p input as the left node and a
 * value-initialised EvaluateSplitInputs as the right node.
 *
 * \param out_split Output span handed through to EvaluateSplits.
 * \param input     Inputs of the node to evaluate.
 */
template <typename GradientSumT>
void EvaluateSingleSplit(common::Span<DeviceSplitCandidate> out_split,
                         EvaluateSplitInputs<GradientSumT> input) {
  EvaluateSplitInputs<GradientSumT> const no_right_node{};
  EvaluateSplits(out_split, input, no_right_node);
}

// Explicit instantiations of the two host entry points for both gradient sum
// types (GradientPair and GradientPairPrecise).  The templates are defined in
// this translation unit, so these instantiations are what other translation
// units link against.
template void EvaluateSplits<GradientPair>(
common::Span<DeviceSplitCandidate> out_splits,
EvaluateSplitInputs<GradientPair> left,
EvaluateSplitInputs<GradientPair> right);
template void EvaluateSplits<GradientPairPrecise>(
common::Span<DeviceSplitCandidate> out_splits,
EvaluateSplitInputs<GradientPairPrecise> left,
EvaluateSplitInputs<GradientPairPrecise> right);
template void EvaluateSingleSplit<GradientPair>(
common::Span<DeviceSplitCandidate> out_split,
EvaluateSplitInputs<GradientPair> input);
template void EvaluateSingleSplit<GradientPairPrecise>(
common::Span<DeviceSplitCandidate> out_split,
EvaluateSplitInputs<GradientPairPrecise> input);
} // namespace tree
} // namespace xgboost
37 changes: 37 additions & 0 deletions src/tree/gpu_hist/evaluate_splits.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#ifndef EVALUATE_SPLITS_CUH_
#define EVALUATE_SPLITS_CUH_
#include <xgboost/span.h>
#include "../../data/ellpack_page.cuh"
#include "../constraints.cuh"
#include "../updater_gpu_common.cuh"

namespace xgboost {
namespace tree {

/*!
 * \brief Everything split evaluation needs to know about one tree node.
 *
 * The struct is passed by value to the split evaluation functions; its spans
 * are non-owning views.
 *
 * \tparam GradientSumT Gradient sum type of the histogram and of parent_sum.
 */
template <typename GradientSumT>
struct EvaluateSplitInputs {
  // Node index (presumably the id of the node being split -- confirm with the
  // caller).
  int nidx;
  // Gradient sum of the node; the share not covered by a feature's bins is
  // treated as the gradient of missing values.
  GradientSumT parent_sum;
  // Training parameters forwarded to the gain computation and to the split
  // candidate update.
  GPUTrainingParam param;
  // Feature indices to evaluate for this node, one per thread block.
  common::Span<const bst_feature_t> feature_set;
  // Bin range of feature f is [feature_segments[f], feature_segments[f + 1]).
  common::Span<const uint32_t> feature_segments;
  // Split value associated with each histogram bin.
  common::Span<const float> feature_values;
  // Per-feature split value used when the split lies before the feature's
  // first bin.
  common::Span<const float> min_fvalue;
  // Gradient histogram, indexed by global bin index.
  common::Span<const GradientSumT> gradient_histogram;
  // Constraint object used to compute the split gain.
  ValueConstraint value_constraint;
  // Per-feature constraint value, indexed by feature index.
  common::Span<const int> monotonic_constraints;
};
/*!
 * \brief Find the best split of two nodes at once.
 *
 * \param out_splits Output: element 0 is the best candidate over the features
 *                   of \p left, element 1 the best over the features of
 *                   \p right.
 * \param left       Inputs of the first node.
 * \param right      Inputs of the second node.
 */
template <typename GradientSumT>
void EvaluateSplits(common::Span<DeviceSplitCandidate> out_splits,
                    EvaluateSplitInputs<GradientSumT> left,
                    EvaluateSplitInputs<GradientSumT> right);
/*!
 * \brief Find the best split of a single node.
 *
 * \param out_split Output span that receives the best candidate.
 * \param input     Inputs of the node to evaluate.
 */
template <typename GradientSumT>
void EvaluateSingleSplit(common::Span<DeviceSplitCandidate> out_split,
EvaluateSplitInputs<GradientSumT> input);
} // namespace tree
} // namespace xgboost

#endif // EVALUATE_SPLITS_CUH_
Loading