
[CUDA] Add L1 regression objective for cuda_exp (#5457)
* add L1 regression objective for cuda_exp

* remove RenewTreeOutputCUDA from CUDARegressionL2loss

* remove mutable and use CUDAVector

* remove whitespace

* remove TODO and document in #5459
shiyu1994 authored Sep 1, 2022
1 parent e02ddc4 commit d78b6bc
Showing 6 changed files with 624 additions and 31 deletions.
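
For context on the headline change: an L1 (absolute error) objective produces a gradient equal to the sign of the residual and a constant hessian, and its leaf outputs are typically renewed with residual percentiles afterwards, which is presumably why the percentile and prefix-sum device routines below are added. A minimal sketch of the gradient step, assuming LightGBM's score_t/label_t/data_size_t types; the kernel name and launch shape are illustrative, not the commit's actual implementation:

__global__ void SketchL1GradientsKernel(const double* scores, const label_t* labels,
                                        const data_size_t num_data,
                                        score_t* gradients, score_t* hessians) {
  const data_size_t i = static_cast<data_size_t>(blockIdx.x * blockDim.x + threadIdx.x);
  if (i < num_data) {
    const double diff = scores[i] - static_cast<double>(labels[i]);
    // sign(residual): +1, -1, or 0
    gradients[i] = static_cast<score_t>((diff > 0.0) - (diff < 0.0));
    hessians[i] = 1.0f;  // L1 has no informative second derivative; use a constant
  }
}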
103 changes: 103 additions & 0 deletions include/LightGBM/cuda/cuda_algorithms.hpp
@@ -107,6 +107,9 @@ __device__ __forceinline__ T ShufflePrefixSumExclusive(T value, T* shared_mem_buf
template <typename T>
void ShufflePrefixSumGlobal(T* values, size_t len, T* block_prefix_sum_buffer);

template <typename VAL_T, typename REDUCE_T, typename INDEX_T>
void GlobalInclusiveArgPrefixSum(const INDEX_T* sorted_indices, const VAL_T* in_values, REDUCE_T* out_values, REDUCE_T* block_buffer, size_t n);
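// Editor's note on the declaration above (semantics inferred from usage below,
// stated as an assumption): computes out_values[i] = sum_{j <= i} of
// in_values[sorted_indices[j]], i.e. an inclusive prefix sum of in_values
// gathered through sorted_indices; block_buffer is scratch for combining
// per-block partial sums.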

template <typename T>
__device__ __forceinline__ T ShuffleReduceSumWarp(T value, const data_size_t len) {
if (len > 0) {
@@ -384,12 +387,112 @@ __device__ void BitonicArgSortDevice(const VAL_T* values, INDEX_T* indices, cons
}
}

template <typename VAL_T, typename INDEX_T, bool ASCENDING>
void BitonicArgSortGlobal(const VAL_T* values, INDEX_T* indices, const size_t len);
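// Editor's note: an arg-sort; values is left in place and indices is permuted
// so that values[indices[i]] is ordered according to ASCENDING.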

template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceSumGlobal(const VAL_T* values, size_t n, REDUCE_T* block_buffer);

template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceDotProdGlobal(const VAL_T* values1, const VAL_T* values2, size_t n, REDUCE_T* block_buffer);
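// Editor's note (an assumption from the names): multi-block reductions that
// sum the n values, or the elementwise products values1[i] * values2[i],
// using block_buffer as scratch for per-block partial results.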

template <typename VAL_T, typename REDUCE_VAL_T, typename INDEX_T>
__device__ void ShuffleSortedPrefixSumDevice(const VAL_T* in_values,
const INDEX_T* sorted_indices,
REDUCE_VAL_T* out_values,
const INDEX_T num_data) {
__shared__ REDUCE_VAL_T shared_buffer[32];
const INDEX_T num_data_per_thread = (num_data + static_cast<INDEX_T>(blockDim.x) - 1) / static_cast<INDEX_T>(blockDim.x);
const INDEX_T start = num_data_per_thread * static_cast<INDEX_T>(threadIdx.x);
const INDEX_T end = min(start + num_data_per_thread, num_data);
REDUCE_VAL_T thread_sum = 0;
for (INDEX_T index = start; index < end; ++index) {
thread_sum += static_cast<REDUCE_VAL_T>(in_values[sorted_indices[index]]);
}
__syncthreads();
// Exclusive scan of the per-thread chunk totals: each thread's base offset is
// the sum of the chunk totals of all preceding threads.
const REDUCE_VAL_T thread_base = ShufflePrefixSumExclusive<REDUCE_VAL_T>(thread_sum, shared_buffer);
// Carry a running sum through this thread's chunk so every output slot
// receives a true inclusive prefix-sum value.
REDUCE_VAL_T running_sum = thread_base;
for (INDEX_T index = start; index < end; ++index) {
running_sum += static_cast<REDUCE_VAL_T>(in_values[sorted_indices[index]]);
out_values[index] = running_sum;
}
__syncthreads();
}

template <typename VAL_T, typename INDEX_T, typename WEIGHT_T, typename WEIGHT_REDUCE_T, bool ASCENDING, bool USE_WEIGHT>
__global__ void PercentileGlobalKernel(const VAL_T* values,
const WEIGHT_T* weights,
const INDEX_T* sorted_indices,
const WEIGHT_REDUCE_T* weights_prefix_sum,
const double alpha,
const INDEX_T len,
VAL_T* out_value) {
if (!USE_WEIGHT) {
const double float_pos = (1.0f - alpha) * len;
const INDEX_T pos = static_cast<INDEX_T>(float_pos);
if (pos < 1) {
*out_value = values[sorted_indices[0]];
} else if (pos >= len) {
*out_value = values[sorted_indices[len - 1]];
} else {
const double bias = float_pos - static_cast<double>(pos);
const VAL_T v1 = values[sorted_indices[pos - 1]];
const VAL_T v2 = values[sorted_indices[pos]];
*out_value = static_cast<VAL_T>(v1 - (v1 - v2) * bias);
}
} else {
const WEIGHT_REDUCE_T threshold = weights_prefix_sum[len - 1] * (1.0f - alpha);
__shared__ INDEX_T pos;
if (threadIdx.x == 0) {
pos = len;
}
__syncthreads();
for (INDEX_T index = static_cast<INDEX_T>(threadIdx.x); index < len; index += static_cast<INDEX_T>(blockDim.x)) {
if (weights_prefix_sum[index] > threshold && (index == 0 || weights_prefix_sum[index - 1] <= threshold)) {
pos = index;
}
}
__syncthreads();
pos = min(pos, len - 1);
if (pos == 0 || pos == len - 1) {
// At a boundary there is no neighboring order statistic to interpolate with,
// so take the sorted value directly and stop.
*out_value = values[sorted_indices[pos]];
return;
}
// Linearly interpolate between the two order statistics that bracket the
// weight threshold.
const VAL_T v1 = values[sorted_indices[pos - 1]];
const VAL_T v2 = values[sorted_indices[pos]];
*out_value = static_cast<VAL_T>(v1 - (v1 - v2) * (threshold - weights_prefix_sum[pos - 1]) / (weights_prefix_sum[pos] - weights_prefix_sum[pos - 1]));
}
}
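// Editor's worked example for the unweighted branch: with len = 4 and
// alpha = 0.5, float_pos = (1.0 - 0.5) * 4 = 2.0, so pos = 2 and bias = 0.0,
// and the kernel returns values[sorted_indices[1]] exactly; a fractional
// float_pos instead blends the two neighboring order statistics linearly.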

template <typename VAL_T, typename INDEX_T, typename WEIGHT_T, typename WEIGHT_REDUCE_T, bool ASCENDING, bool USE_WEIGHT>
void PercentileGlobal(const VAL_T* values,
const WEIGHT_T* weights,
INDEX_T* indices,
WEIGHT_REDUCE_T* weights_prefix_sum,
WEIGHT_REDUCE_T* weights_prefix_sum_buffer,
const double alpha,
const INDEX_T len,
VAL_T* cuda_out_value) {
if (len <= 1) {
// Zero or one element: nothing to sort or interpolate.
CopyFromCUDADeviceToCUDADevice<VAL_T>(cuda_out_value, values, 1, __FILE__, __LINE__);
return;
}
BitonicArgSortGlobal<VAL_T, INDEX_T, ASCENDING>(values, indices, len);
SynchronizeCUDADevice(__FILE__, __LINE__);
if (USE_WEIGHT) {
GlobalInclusiveArgPrefixSum<WEIGHT_T, WEIGHT_REDUCE_T, INDEX_T>(indices, weights, weights_prefix_sum, weights_prefix_sum_buffer, static_cast<size_t>(len));
}
SynchronizeCUDADevice(__FILE__, __LINE__);
PercentileGlobalKernel<VAL_T, INDEX_T, WEIGHT_T, WEIGHT_REDUCE_T, ASCENDING, USE_WEIGHT><<<1, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(values, weights, indices, weights_prefix_sum, alpha, len, cuda_out_value);
SynchronizeCUDADevice(__FILE__, __LINE__);
}
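// Hypothetical call site (buffers, types, and alpha are assumptions, not from
// this commit): compute the median of n unweighted device-resident doubles.
//   PercentileGlobal<double, data_size_t, label_t, double, true, false>(
//       cuda_values, nullptr, cuda_indices, nullptr, nullptr, 0.5, n, cuda_median);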

template <typename VAL_T, typename INDEX_T, typename WEIGHT_T, typename REDUCE_WEIGHT_T, bool ASCENDING, bool USE_WEIGHT>
__device__ VAL_T PercentileDevice(const VAL_T* values,
const WEIGHT_T* weights,
INDEX_T* indices,
REDUCE_WEIGHT_T* weights_prefix_sum,
const double alpha,
const INDEX_T len);


} // namespace LightGBM

#endif // USE_CUDA_EXP
2 changes: 1 addition & 1 deletion include/LightGBM/cuda/cuda_utils.h
@@ -167,7 +167,7 @@ class CUDAVector {
return host_vector;
}

-  T* RawData() {
+  T* RawData() const {
return data_;
}
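
The const qualifier added above is what lets code holding only a const CUDAVector& obtain the raw device pointer for kernel launches. A hypothetical illustration (this helper is not part of the commit):

void LaunchOnBuffer(const CUDAVector<double>& buffer) {
  double* ptr = buffer.RawData();  // compiles because RawData() is now const-qualified
  // ... hand ptr to a kernel launch ...
}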

