diff --git a/.ci/test.sh b/.ci/test.sh
index 435614bb826f..7fd09d8d20d9 100755
--- a/.ci/test.sh
+++ b/.ci/test.sh
@@ -132,6 +132,16 @@ if [[ $TASK == "gpu" ]]; then
exit 0
fi
cmake -DUSE_GPU=ON -DOpenCL_INCLUDE_DIR=$AMDAPPSDK_PATH/include/ ..
+elif [[ $TASK == "cuda" ]]; then
+ sed -i'.bak' 's/std::string device_type = "cpu";/std::string device_type = "cuda";/' $BUILD_DIRECTORY/include/LightGBM/config.h
+ grep -q 'std::string device_type = "cuda"' $BUILD_DIRECTORY/include/LightGBM/config.h || exit -1 # make sure that changes were really done
+ if [[ $METHOD == "pip" ]]; then
+ cd $BUILD_DIRECTORY/python-package && python setup.py sdist || exit -1
+ pip install --user $BUILD_DIRECTORY/python-package/dist/lightgbm-$LGB_VER.tar.gz -v --install-option=--cuda || exit -1
+ pytest $BUILD_DIRECTORY/tests/python_package_test || exit -1
+ exit 0
+ fi
+ cmake -DUSE_CUDA=ON ..
elif [[ $TASK == "mpi" ]]; then
if [[ $METHOD == "pip" ]]; then
cd $BUILD_DIRECTORY/python-package && python setup.py sdist || exit -1
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 78c6c0d18efb..b2e206fe5fd9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,10 +1,16 @@
if(USE_GPU OR APPLE)
cmake_minimum_required(VERSION 3.2)
+elseif(USE_CUDA)
+ cmake_minimum_required(VERSION 3.16)
else()
cmake_minimum_required(VERSION 2.8)
endif()
-PROJECT(lightgbm)
+if(USE_CUDA)
+ PROJECT(lightgbm LANGUAGES C CXX CUDA)
+else()
+ PROJECT(lightgbm LANGUAGES C CXX)
+endif()
OPTION(USE_MPI "Enable MPI-based parallel learning" OFF)
OPTION(USE_OPENMP "Enable OpenMP" ON)
@@ -12,6 +18,7 @@ OPTION(USE_GPU "Enable GPU-accelerated training" OFF)
OPTION(USE_SWIG "Enable SWIG to generate Java API" OFF)
OPTION(USE_HDFS "Enable HDFS support (EXPERIMENTAL)" OFF)
OPTION(USE_TIMETAG "Set to ON to output time costs" OFF)
+OPTION(USE_CUDA "Enable CUDA-accelerated training (EXPERIMENTAL)" OFF)
OPTION(USE_DEBUG "Set to ON for Debug mode" OFF)
OPTION(BUILD_STATIC_LIB "Build static library" OFF)
OPTION(BUILD_FOR_R "Set to ON if building lib_lightgbm for use with the R package" OFF)
@@ -94,6 +101,10 @@ else()
ADD_DEFINITIONS(-DUSE_SOCKET)
endif(USE_MPI)
+if(USE_CUDA)
+ SET(USE_OPENMP ON CACHE BOOL "CUDA requires OpenMP" FORCE)
+endif(USE_CUDA)
+
if(USE_OPENMP)
find_package(OpenMP REQUIRED)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
@@ -123,6 +134,67 @@ if(USE_GPU)
ADD_DEFINITIONS(-DUSE_GPU)
endif(USE_GPU)
+if(USE_CUDA)
+ find_package(CUDA REQUIRED)
+ include_directories(${CUDA_INCLUDE_DIRS})
+ LIST(APPEND CMAKE_CUDA_FLAGS -Xcompiler=${OpenMP_CXX_FLAGS} -Xcompiler=-fPIC -Xcompiler=-Wall)
+ CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS 6.0 6.1 6.2 7.0 7.5+PTX)
+
+ LIST(APPEND CMAKE_CUDA_FLAGS ${CUDA_ARCH_FLAGS})
+ if(USE_DEBUG)
+ SET(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g")
+ else()
+ SET(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -lineinfo")
+ endif()
+ string(REPLACE ";" " " CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
+ message(STATUS "CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
+
+ ADD_DEFINITIONS(-DUSE_CUDA)
+ if (NOT DEFINED CMAKE_CUDA_STANDARD)
+ set(CMAKE_CUDA_STANDARD 11)
+ set(CMAKE_CUDA_STANDARD_REQUIRED ON)
+ endif()
+
+ set(BASE_DEFINES
+ -DPOWER_FEATURE_WORKGROUPS=12
+ -DUSE_CONSTANT_BUF=0
+ )
+ set(ALLFEATS_DEFINES
+ ${BASE_DEFINES}
+ -DENABLE_ALL_FEATURES
+ )
+ set(FULLDATA_DEFINES
+ ${ALLFEATS_DEFINES}
+ -DIGNORE_INDICES
+ )
+
+ message(STATUS "ALLFEATS_DEFINES: ${ALLFEATS_DEFINES}")
+ message(STATUS "FULLDATA_DEFINES: ${FULLDATA_DEFINES}")
+
+ function(add_histogram hsize hname hadd hconst hdir)
+ add_library(histo${hsize}${hname} OBJECT src/treelearner/kernels/histogram${hsize}.cu)
+ set_target_properties(histo${hsize}${hname} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
+ if(hadd)
+ list(APPEND histograms histo${hsize}${hname})
+ set(histograms ${histograms} PARENT_SCOPE)
+ endif()
+ target_compile_definitions(
+ histo${hsize}${hname} PRIVATE
+ -DCONST_HESSIAN=${hconst}
+ ${hdir}
+ )
+ endfunction()
+
+ foreach (hsize _16_64_256)
+ add_histogram("${hsize}" "_sp_const" "True" "1" "${BASE_DEFINES}")
+ add_histogram("${hsize}" "_sp" "True" "0" "${BASE_DEFINES}")
+ add_histogram("${hsize}" "-allfeats_sp_const" "False" "1" "${ALLFEATS_DEFINES}")
+ add_histogram("${hsize}" "-allfeats_sp" "False" "0" "${ALLFEATS_DEFINES}")
+ add_histogram("${hsize}" "-fulldata_sp_const" "True" "1" "${FULLDATA_DEFINES}")
+ add_histogram("${hsize}" "-fulldata_sp" "True" "0" "${FULLDATA_DEFINES}")
+ endforeach()
+endif(USE_CUDA)
+
if(USE_HDFS)
find_package(JNI REQUIRED)
find_path(HDFS_INCLUDE_DIR hdfs.h REQUIRED)
@@ -228,6 +300,9 @@ file(GLOB SOURCES
src/objective/*.cpp
src/network/*.cpp
src/treelearner/*.cpp
+if(USE_CUDA)
+ src/treelearner/*.cu
+endif(USE_CUDA)
)
add_executable(lightgbm src/main.cpp ${SOURCES})
@@ -303,6 +378,19 @@ if(USE_GPU)
TARGET_LINK_LIBRARIES(_lightgbm ${OpenCL_LIBRARY} ${Boost_LIBRARIES})
endif(USE_GPU)
+if(USE_CUDA)
+ set_target_properties(lightgbm PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON)
+ TARGET_LINK_LIBRARIES(
+ lightgbm
+ ${histograms}
+ )
+ set_target_properties(_lightgbm PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON)
+ TARGET_LINK_LIBRARIES(
+ _lightgbm
+ ${histograms}
+ )
+endif(USE_CUDA)
+
if(USE_HDFS)
TARGET_LINK_LIBRARIES(lightgbm ${HDFS_CXX_LIBRARIES})
TARGET_LINK_LIBRARIES(_lightgbm ${HDFS_CXX_LIBRARIES})
diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 14d7a8098cf8..dcd1353e1525 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -1120,7 +1120,13 @@ GPU Parameters
- ``gpu_use_dp`` :raw-html:`🔗︎`, default = ``false``, type = bool
- - set this to ``true`` to use double precision math on GPU (by default single precision is used)
+   - set this to ``true`` to use double precision math on GPU (by default single precision is used in the OpenCL implementation and double precision is used in the CUDA implementation)
+
+- ``num_gpu`` :raw-html:`🔗︎`, default = ``1``, type = int, constraints: ``num_gpu > 0``
+
+ - number of GPUs
+
+   - **Note**: can be used only in the CUDA implementation
.. end params list
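For reference, a minimal sketch of how the new ``device_type=cuda`` and ``num_gpu`` parameters are consumed once a CUDA-enabled build is installed. It is illustrative only and not part of the patch; the dataset path and parameter values are placeholders, while the C API entry points (``LGBM_DatasetCreateFromFile``, ``LGBM_BoosterCreate``, ``LGBM_BoosterUpdateOneIter``) are the existing ones.

    #include <LightGBM/c_api.h>
    #include <cstdio>

    int main() {
      DatasetHandle train_data = nullptr;
      // "train.txt" is a placeholder path; max_bin=63 follows the warning added
      // in cuda_tree_learner.cpp about the best-performing bin sizes.
      if (LGBM_DatasetCreateFromFile("train.txt", "max_bin=63", nullptr, &train_data) != 0) {
        std::fprintf(stderr, "failed to load dataset\n");
        return 1;
      }
      BoosterHandle booster = nullptr;
      // device_type=cuda selects the new CUDA tree learner; num_gpu spreads the
      // dense feature groups across two devices (gpu_use_dp is forced to true).
      LGBM_BoosterCreate(train_data, "objective=binary device_type=cuda num_gpu=2", &booster);
      int is_finished = 0;
      for (int i = 0; i < 10 && !is_finished; ++i) {
        LGBM_BoosterUpdateOneIter(booster, &is_finished);
      }
      LGBM_BoosterFree(booster);
      LGBM_DatasetFree(train_data);
      return 0;
    }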
diff --git a/include/LightGBM/bin.h b/include/LightGBM/bin.h
index 4f320698c831..987279e47716 100644
--- a/include/LightGBM/bin.h
+++ b/include/LightGBM/bin.h
@@ -288,6 +288,9 @@ class Bin {
/*! \brief Number of all data */
virtual data_size_t num_data() const = 0;
+ /*! \brief Get data pointer */
+ virtual void* get_data() = 0;
+
virtual void ReSize(data_size_t num_data) = 0;
/*!
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index bfcb09a40049..5e1902613905 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -965,9 +965,14 @@ struct Config {
// desc = **Note**: refer to `GPU Targets <./GPU-Targets.rst#query-opencl-devices-in-your-system>`__ for more details
int gpu_device_id = -1;
- // desc = set this to ``true`` to use double precision math on GPU (by default single precision is used)
+  // desc = set this to ``true`` to use double precision math on GPU (by default single precision is used in the OpenCL implementation and double precision is used in the CUDA implementation)
bool gpu_use_dp = false;
+ // check = >0
+ // desc = number of GPUs
+  // desc = **Note**: can be used only in the CUDA implementation
+ int num_gpu = 1;
+
#pragma endregion
#pragma endregion
diff --git a/include/LightGBM/cuda/cuda_utils.h b/include/LightGBM/cuda/cuda_utils.h
new file mode 100644
index 000000000000..1054e09daf18
--- /dev/null
+++ b/include/LightGBM/cuda/cuda_utils.h
@@ -0,0 +1,24 @@
+/*!
+ * Copyright (c) 2020 IBM Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+#ifndef LIGHTGBM_CUDA_CUDA_UTILS_H_
+#define LIGHTGBM_CUDA_CUDA_UTILS_H_
+
+#ifdef USE_CUDA
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <LightGBM/utils/log.h>
+
+#define CUDASUCCESS_OR_FATAL(ans) { gpuAssert((ans), __FILE__, __LINE__); }
+inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) {
+ if (code != cudaSuccess) {
+ LightGBM::Log::Fatal("[CUDA] %s %s %d\n", cudaGetErrorString(code), file, line);
+ if (abort) exit(code);
+ }
+}
+
+#endif // USE_CUDA
+
+#endif // LIGHTGBM_CUDA_CUDA_UTILS_H_
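A small illustration of how CUDASUCCESS_OR_FATAL is intended to wrap CUDA runtime calls (the helper below is made up for the example; the macro and gpuAssert are the ones defined above):

    #include <LightGBM/cuda/cuda_utils.h>

    // Allocate and zero a device buffer, aborting with a logged file/line
    // message on any CUDA runtime error, as the CUDA tree learner does for
    // its own buffers.
    static float* AllocateDeviceScores(int device_id, size_t num_elements) {
      float* device_ptr = nullptr;
      CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id));
      CUDASUCCESS_OR_FATAL(cudaMalloc(&device_ptr, num_elements * sizeof(float)));
      CUDASUCCESS_OR_FATAL(cudaMemset(device_ptr, 0, num_elements * sizeof(float)));
      return device_ptr;
    }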
diff --git a/include/LightGBM/cuda/vector_cudahost.h b/include/LightGBM/cuda/vector_cudahost.h
new file mode 100644
index 000000000000..f81cc4dd905f
--- /dev/null
+++ b/include/LightGBM/cuda/vector_cudahost.h
@@ -0,0 +1,86 @@
+/*!
+ * Copyright (c) 2020 IBM Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+#ifndef LIGHTGBM_CUDA_VECTOR_CUDAHOST_H_
+#define LIGHTGBM_CUDA_VECTOR_CUDAHOST_H_
+
+#include <LightGBM/utils/common.h>
+
+#ifdef USE_CUDA
+#include <cuda.h>
+#include <cuda_runtime.h>
+#endif
+#include <stdio.h>
+
+enum LGBM_Device {
+ lgbm_device_cpu,
+ lgbm_device_gpu,
+ lgbm_device_cuda
+};
+
+enum Use_Learner {
+ use_cpu_learner,
+ use_gpu_learner,
+ use_cuda_learner
+};
+
+namespace LightGBM {
+
+class LGBM_config_ {
+ public:
+ static int current_device; // Default: lgbm_device_cpu
+ static int current_learner; // Default: use_cpu_learner
+};
+
+
+template <class T>
+struct CHAllocator {
+  typedef T value_type;
+  CHAllocator() {}
+  template <class U> CHAllocator(const CHAllocator<U>& other);
+  T* allocate(std::size_t n) {
+ T* ptr;
+ if (n == 0) return NULL;
+ #ifdef USE_CUDA
+ if (LGBM_config_::current_device == lgbm_device_cuda) {
+ cudaError_t ret = cudaHostAlloc(&ptr, n*sizeof(T), cudaHostAllocPortable);
+ if (ret != cudaSuccess) {
+ Log::Warning("Defaulting to malloc in CHAllocator!!!");
+        ptr = reinterpret_cast<T*>(_mm_malloc(n*sizeof(T), 16));
+ }
+ } else {
+      ptr = reinterpret_cast<T*>(_mm_malloc(n*sizeof(T), 16));
+ }
+ #else
+    ptr = reinterpret_cast<T*>(_mm_malloc(n*sizeof(T), 16));
+ #endif
+ return ptr;
+ }
+
+ void deallocate(T* p, std::size_t n) {
+ (void)n; // UNUSED
+ if (p == NULL) return;
+ #ifdef USE_CUDA
+ if (LGBM_config_::current_device == lgbm_device_cuda) {
+ cudaPointerAttributes attributes;
+ cudaPointerGetAttributes(&attributes, p);
+ if ((attributes.type == cudaMemoryTypeHost) && (attributes.devicePointer != NULL)) {
+ cudaFreeHost(p);
+ }
+ } else {
+ _mm_free(p);
+ }
+ #else
+ _mm_free(p);
+ #endif
+ }
+};
+template <class T, class U>
+bool operator==(const CHAllocator<T>&, const CHAllocator<U>&);
+template <class T, class U>
+bool operator!=(const CHAllocator<T>&, const CHAllocator<U>&);
+
+} // namespace LightGBM
+
+#endif // LIGHTGBM_CUDA_VECTOR_CUDAHOST_H_
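The allocator is meant to be dropped into ordinary std::vector declarations (as done for gradients_/hessians_ in gbdt.h and data_ in dense_bin.hpp), so the backing storage becomes CUDA pinned host memory whenever the active device is CUDA. A minimal sketch, assuming a CUDA-enabled build linked against the library:

    #include <LightGBM/cuda/vector_cudahost.h>
    #include <vector>

    int main() {
      // Mirrors what Config::Set() / Application do when device_type=cuda.
      LightGBM::LGBM_config_::current_device = lgbm_device_cuda;

      // Storage comes from cudaHostAlloc (page-locked memory), falling back to
      // _mm_malloc with a warning if pinned allocation fails.
      std::vector<float, LightGBM::CHAllocator<float>> gradients(1 << 20, 0.0f);

      // Because the memory is pinned, it can be handed directly to
      // cudaMemcpyAsync without an extra staging copy.
      return 0;
    }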
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index 3cf82c2aa1d3..2c6d74caa1d6 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -589,6 +589,14 @@ class Dataset {
return feature_groups_[i]->is_multi_val_;
}
+ inline size_t FeatureGroupSizesInByte(int group) const {
+ return feature_groups_[group]->FeatureGroupSizesInByte();
+ }
+
+ inline void* FeatureGroupData(int group) const {
+ return feature_groups_[group]->FeatureGroupData();
+ }
+
inline double RealThreshold(int i, uint32_t threshold) const {
const int group = feature2group_[i];
const int sub_feature = feature2subfeature_[i];
diff --git a/include/LightGBM/feature_group.h b/include/LightGBM/feature_group.h
index 2b17e98bb9c1..3ba5c143f85b 100644
--- a/include/LightGBM/feature_group.h
+++ b/include/LightGBM/feature_group.h
@@ -228,6 +228,17 @@ class FeatureGroup {
return bin_data_->GetIterator(min_bin, max_bin, most_freq_bin);
}
+ inline size_t FeatureGroupSizesInByte() {
+ return bin_data_->SizesInByte();
+ }
+
+ inline void* FeatureGroupData() {
+ if (is_multi_val_) {
+ return nullptr;
+ }
+ return bin_data_->get_data();
+ }
+
inline data_size_t Split(int sub_feature, const uint32_t* threshold,
int num_threshold, bool default_left,
const data_size_t* data_indices, data_size_t cnt,
diff --git a/python-package/setup.py b/python-package/setup.py
index 9104bc2694b8..f8782fba47a0 100644
--- a/python-package/setup.py
+++ b/python-package/setup.py
@@ -87,7 +87,7 @@ def silent_call(cmd, raise_error=False, error_msg=''):
return 1
-def compile_cpp(use_mingw=False, use_gpu=False, use_mpi=False,
+def compile_cpp(use_mingw=False, use_gpu=False, use_cuda=False, use_mpi=False,
use_hdfs=False, boost_root=None, boost_dir=None,
boost_include_dir=None, boost_librarydir=None,
opencl_include_dir=None, opencl_library=None,
@@ -115,6 +115,8 @@ def compile_cpp(use_mingw=False, use_gpu=False, use_mpi=False,
cmake_cmd.append("-DOpenCL_INCLUDE_DIR={0}".format(opencl_include_dir))
if opencl_library:
cmake_cmd.append("-DOpenCL_LIBRARY={0}".format(opencl_library))
+ elif use_cuda:
+ cmake_cmd.append("-DUSE_CUDA=ON")
if use_mpi:
cmake_cmd.append("-DUSE_MPI=ON")
if nomp:
@@ -188,6 +190,7 @@ class CustomInstall(install):
user_options = install.user_options + [
('mingw', 'm', 'Compile with MinGW'),
('gpu', 'g', 'Compile GPU version'),
+ ('cuda', None, 'Compile CUDA version'),
('mpi', None, 'Compile MPI version'),
('nomp', None, 'Compile version without OpenMP support'),
('hdfs', 'h', 'Compile HDFS version'),
@@ -205,6 +208,7 @@ def initialize_options(self):
install.initialize_options(self)
self.mingw = 0
self.gpu = 0
+ self.cuda = 0
self.boost_root = None
self.boost_dir = None
self.boost_include_dir = None
@@ -228,7 +232,7 @@ def run(self):
open(LOG_PATH, 'wb').close()
if not self.precompile:
copy_files(use_gpu=self.gpu)
- compile_cpp(use_mingw=self.mingw, use_gpu=self.gpu, use_mpi=self.mpi,
+ compile_cpp(use_mingw=self.mingw, use_gpu=self.gpu, use_cuda=self.cuda, use_mpi=self.mpi,
use_hdfs=self.hdfs, boost_root=self.boost_root, boost_dir=self.boost_dir,
boost_include_dir=self.boost_include_dir, boost_librarydir=self.boost_librarydir,
opencl_include_dir=self.opencl_include_dir, opencl_library=self.opencl_library,
diff --git a/src/application/application.cpp b/src/application/application.cpp
index 21163a5a30ea..d9be76d67c90 100644
--- a/src/application/application.cpp
+++ b/src/application/application.cpp
@@ -11,6 +11,7 @@
#include
#include
#include
+#include <LightGBM/cuda/vector_cudahost.h>
#include
#include
#include
@@ -38,6 +39,10 @@ Application::Application(int argc, char** argv) {
if (config_.data.size() == 0 && config_.task != TaskType::kConvertModel) {
Log::Fatal("No training/prediction data, application quit");
}
+
+ if (config_.device_type == std::string("cuda")) {
+ LGBM_config_::current_device = lgbm_device_cuda;
+ }
}
Application::~Application() {
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index 03f5fe25d554..fcb7185a1512 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -17,6 +17,9 @@
namespace LightGBM {
+int LGBM_config_::current_device = lgbm_device_cpu;
+int LGBM_config_::current_learner = use_cpu_learner;
+
GBDT::GBDT()
: iter_(0),
train_data_(nullptr),
@@ -58,6 +61,10 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
es_first_metric_only_ = config_->first_metric_only;
shrinkage_rate_ = config_->learning_rate;
+ if (config_->device_type == std::string("cuda")) {
+ LGBM_config_::current_learner = use_cuda_learner;
+ }
+
// load forced_splits file
if (!config->forcedsplits_filename.empty()) {
std::ifstream forced_splits_file(config->forcedsplits_filename.c_str());
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index a84b321531f1..0d38385d5f06 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -8,6 +8,7 @@
#include
#include
#include
+#include <LightGBM/cuda/vector_cudahost.h>
#include
#include
@@ -479,10 +480,19 @@ class GBDT : public GBDTBase {
   std::vector<std::unique_ptr<Tree>> models_;
/*! \brief Max feature index of training data*/
int max_feature_idx_;
+
+#ifdef USE_CUDA
+  /*! \brief First order derivative of training data */
+  std::vector<score_t, CHAllocator<score_t>> gradients_;
+  /*! \brief Second order derivative of training data */
+  std::vector<score_t, CHAllocator<score_t>> hessians_;
+#else
   /*! \brief First order derivative of training data */
   std::vector<score_t, Common::AlignmentAllocator<score_t, kAlignedSize>> gradients_;
-  /*! \brief Secend order derivative of training data */
+  /*! \brief Second order derivative of training data */
   std::vector<score_t, Common::AlignmentAllocator<score_t, kAlignedSize>> hessians_;
+#endif
+
/*! \brief Store the indices of in-bag data */
   std::vector<data_size_t, Common::AlignmentAllocator<data_size_t, kAlignedSize>> bag_data_indices_;
/*! \brief Number of in-bag data */
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 61b3038e660b..a389e8e47b1d 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -1611,10 +1611,14 @@ int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* hess,
int* is_finished) {
API_BEGIN();
-  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
#ifdef SCORE_T_USE_DOUBLE
+ (void) handle; // UNUSED VARIABLE
+ (void) grad; // UNUSED VARIABLE
+ (void) hess; // UNUSED VARIABLE
+ (void) is_finished; // UNUSED VARIABLE
Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
#else
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
if (ref_booster->TrainOneIter(grad, hess)) {
*is_finished = 1;
} else {
diff --git a/src/io/config.cpp b/src/io/config.cpp
index d569a7401e17..6878896deb58 100644
--- a/src/io/config.cpp
+++ b/src/io/config.cpp
@@ -4,6 +4,7 @@
*/
#include
+#include <LightGBM/cuda/vector_cudahost.h>
#include
#include
#include
@@ -126,6 +127,8 @@ void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
*device_type = "cpu";
} else if (value == std::string("gpu")) {
*device_type = "gpu";
+ } else if (value == std::string("cuda")) {
+ *device_type = "cuda";
} else {
Log::Fatal("Unknown device type %s", value.c_str());
}
@@ -206,6 +209,9 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
GetMetricType(params, &metric);
GetObjectiveType(params, &objective);
GetDeviceType(params, &device_type);
+ if (device_type == std::string("cuda")) {
+ LGBM_config_::current_device = lgbm_device_cuda;
+ }
GetTreeLearnerType(params, &tree_learner);
GetMembersFromString(params);
@@ -319,11 +325,18 @@ void Config::CheckParamConflict() {
num_leaves = static_cast(full_num_leaves);
}
}
- // force col-wise for gpu
- if (device_type == std::string("gpu")) {
+ // force col-wise for gpu & CUDA
+ if (device_type == std::string("gpu") || device_type == std::string("cuda")) {
force_col_wise = true;
force_row_wise = false;
}
+
+ // force gpu_use_dp for CUDA
+ if (device_type == std::string("cuda") && !gpu_use_dp) {
+ Log::Warning("CUDA currently requires double precision calculations.");
+ gpu_use_dp = true;
+ }
+
// min_data_in_leaf must be at least 2 if path smoothing is active. This is because when the split is calculated
// the count is calculated using the proportion of hessian in the leaf which is rounded up to nearest int, so it can
// be 1 when there is actually no data in the leaf. In rare cases this can cause a bug because with path smoothing the
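To make the effect of these checks concrete, a small sketch that drives Config directly (illustrative only; Str2Map and Set are the existing helpers the C API uses):

    #include <LightGBM/config.h>
    #include <iostream>

    int main() {
      // "cuda" is now accepted by GetDeviceType(); CheckParamConflict() then
      // forces col-wise histogram construction and double precision for it.
      auto params = LightGBM::Config::Str2Map("device_type=cuda num_gpu=2");
      LightGBM::Config config;
      config.Set(params);
      std::cout << config.device_type << " "     // "cuda"
                << config.gpu_use_dp << " "      // 1 (forced on, with a warning)
                << config.num_gpu << std::endl;  // 2
      return 0;
    }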
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index b14af67fd30e..ad102020322d 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -296,6 +300,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"gpu_platform_id",
"gpu_device_id",
"gpu_use_dp",
+ "num_gpu",
});
return params;
}
@@ -611,6 +612,9 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {
+  GetInt(params, "num_gpu", &num_gpu);
+  CHECK_GT(num_gpu, 0);
+
diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
 #include
+#include <LightGBM/cuda/vector_cudahost.h>
#include
#include
#include
@@ -334,13 +335,24 @@ void Dataset::Construct(std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
"constant.");
}
auto features_in_group = NoGroup(used_features);
+
+ auto is_sparse = io_config.is_enable_sparse;
+ if (io_config.device_type == std::string("cuda")) {
+ LGBM_config_::current_device = lgbm_device_cuda;
+ if (is_sparse) {
+ Log::Warning("Using sparse features with CUDA is currently not supported.");
+ }
+ is_sparse = false;
+ }
+
   std::vector<int8_t> group_is_multi_val(used_features.size(), 0);
if (io_config.enable_bundle && !used_features.empty()) {
+ bool lgbm_is_gpu_used = io_config.device_type == std::string("gpu") || io_config.device_type == std::string("cuda");
features_in_group = FastFeatureBundling(
*bin_mappers, sample_non_zero_indices, sample_values, num_per_col,
         num_sample_col, static_cast<data_size_t>(total_sample_cnt),
- used_features, num_data_, io_config.device_type == std::string("gpu"),
- io_config.is_enable_sparse, &group_is_multi_val);
+ used_features, num_data_, lgbm_is_gpu_used,
+ is_sparse, &group_is_multi_val);
}
num_features_ = 0;
diff --git a/src/io/dense_bin.hpp b/src/io/dense_bin.hpp
index e821fe32f08d..4a1cc43fa79a 100644
--- a/src/io/dense_bin.hpp
+++ b/src/io/dense_bin.hpp
@@ -7,6 +7,7 @@
#define LIGHTGBM_IO_DENSE_BIN_HPP_
#include
+#include <LightGBM/cuda/vector_cudahost.h>
#include
#include
@@ -364,6 +365,8 @@ class DenseBin : public Bin {
data_size_t num_data() const override { return num_data_; }
+ void* get_data() override { return data_.data(); }
+
void FinishLoad() override {
if (IS_4BIT) {
if (buf_.empty()) {
@@ -458,7 +461,11 @@ class DenseBin : public Bin {
private:
data_size_t num_data_;
+#ifdef USE_CUDA
+  std::vector<VAL_T, CHAllocator<VAL_T>> data_;
+#else
   std::vector<VAL_T, Common::AlignmentAllocator<VAL_T, kAlignedSize>> data_;
+#endif
std::vector buf_;
DenseBin(const DenseBin& other)
diff --git a/src/io/sparse_bin.hpp b/src/io/sparse_bin.hpp
index 07f57c4480a2..1fc076576095 100644
--- a/src/io/sparse_bin.hpp
+++ b/src/io/sparse_bin.hpp
@@ -409,6 +409,8 @@ class SparseBin : public Bin {
data_size_t num_data() const override { return num_data_; }
+ void* get_data() override { return nullptr; }
+
void FinishLoad() override {
// get total non zero size
size_t pair_cnt = 0;
diff --git a/src/treelearner/cuda_kernel_launcher.cu b/src/treelearner/cuda_kernel_launcher.cu
new file mode 100644
index 000000000000..8ceb5b813c9c
--- /dev/null
+++ b/src/treelearner/cuda_kernel_launcher.cu
@@ -0,0 +1,171 @@
+/*!
+ * Copyright (c) 2020 IBM Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+#ifdef USE_CUDA
+
+#include "cuda_kernel_launcher.h"
+
+#include
+
+#include
+
+#include
+
+namespace LightGBM {
+
+void cuda_histogram(
+ int histogram_size,
+ data_size_t leaf_num_data,
+ data_size_t num_data,
+ bool use_all_features,
+ bool is_constant_hessian,
+ int num_workgroups,
+ cudaStream_t stream,
+ uint8_t* arg0,
+ uint8_t* arg1,
+ data_size_t arg2,
+ data_size_t* arg3,
+ data_size_t arg4,
+ score_t* arg5,
+ score_t* arg6,
+ score_t arg6_const,
+ char* arg7,
+ volatile int* arg8,
+ void* arg9,
+ size_t exp_workgroups_per_feature) {
+  if (histogram_size == 16) {
+    if (leaf_num_data == num_data) {
+      if (use_all_features) {
+        if (!is_constant_hessian)
+          histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+        else
+          histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+      } else {
+        if (!is_constant_hessian)
+          histogram16_fulldata<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+        else
+          histogram16_fulldata<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+      }
+    } else {
+      if (use_all_features) {
+        // all features seem to be always enabled, so this should be the same as fulldata
+        if (!is_constant_hessian)
+          histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+        else
+          histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+      } else {
+        if (!is_constant_hessian)
+          histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+        else
+          histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+      }
+    }
+  } else if (histogram_size == 64) {
+    if (leaf_num_data == num_data) {
+      if (use_all_features) {
+        if (!is_constant_hessian)
+          histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+        else
+          histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+      } else {
+        if (!is_constant_hessian)
+          histogram64_fulldata<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+        else
+          histogram64_fulldata<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+      }
+    } else {
+      if (use_all_features) {
+        // all features seem to be always enabled, so this should be the same as fulldata
+        if (!is_constant_hessian)
+          histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+        else
+          histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+      } else {
+        if (!is_constant_hessian)
+          histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+        else
+          histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+      }
+    }
+  } else {
+    if (leaf_num_data == num_data) {
+      if (use_all_features) {
+        if (!is_constant_hessian)
+          histogram256<<<num_workgroups, 256, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+        else
+          histogram256<<<num_workgroups, 256, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+      } else {
+        if (!is_constant_hessian)
+          histogram256_fulldata<<<num_workgroups, 256, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+        else
+          histogram256_fulldata<<<num_workgroups, 256, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+      }
+    } else {
+      if (use_all_features) {
+        // all features seem to be always enabled, so this should be the same as fulldata
+        if (!is_constant_hessian)
+          histogram256<<<num_workgroups, 256, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+        else
+          histogram256<<<num_workgroups, 256, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+      } else {
+        if (!is_constant_hessian)
+          histogram256<<<num_workgroups, 256, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+        else
+          histogram256<<<num_workgroups, 256, 0, stream>>>(arg0, arg1, arg2,
+                  arg3, arg4, arg5,
+                  arg6_const, arg7, arg8, static_cast<acc_type*>(arg9), exp_workgroups_per_feature);
+      }
+    }
+  }
+}
+
+} // namespace LightGBM
+
+#endif // USE_CUDA
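The dispatch above keeps roughly 256 threads of work per logical workgroup: the 16-bin kernels are launched with 16x as many blocks of 16 threads, the 64-bin kernels with 4x as many blocks of 64 threads, and the 256-bin kernels with one 256-thread block per workgroup. A small standalone sketch of that bookkeeping (the helper is illustrative, not part of the patch):

    #include <cstdio>

    struct LaunchDims { int grid_blocks; int block_threads; };

    // Mirrors the <<<grid, block>>> configurations chosen in cuda_histogram().
    static LaunchDims HistogramLaunchDims(int histogram_size, int num_workgroups) {
      if (histogram_size == 16) return {16 * num_workgroups, 16};
      if (histogram_size == 64) return {4 * num_workgroups, 64};
      return {num_workgroups, 256};
    }

    int main() {
      for (int hsize : {16, 64, 256}) {
        LaunchDims d = HistogramLaunchDims(hsize, /*num_workgroups=*/8);
        std::printf("histogram%-3d -> grid=%3d blocks, block=%3d threads (total %d)\n",
                    hsize, d.grid_blocks, d.block_threads,
                    d.grid_blocks * d.block_threads);
      }
      return 0;
    }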
diff --git a/src/treelearner/cuda_kernel_launcher.h b/src/treelearner/cuda_kernel_launcher.h
new file mode 100644
index 000000000000..0714e05b2f2d
--- /dev/null
+++ b/src/treelearner/cuda_kernel_launcher.h
@@ -0,0 +1,70 @@
+/*!
+ * Copyright (c) 2020 IBM Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+#ifndef LIGHTGBM_TREELEARNER_CUDA_KERNEL_LAUNCHER_H_
+#define LIGHTGBM_TREELEARNER_CUDA_KERNEL_LAUNCHER_H_
+
+#ifdef USE_CUDA
+#include <chrono>
+#include "kernels/histogram_16_64_256.hu" // kernel, acc_type, data_size_t, uchar, score_t
+
+namespace LightGBM {
+
+struct ThreadData {
+ // device id
+ int device_id;
+ // parameters for cuda_histogram
+ int histogram_size;
+ data_size_t leaf_num_data;
+ data_size_t num_data;
+ bool use_all_features;
+ bool is_constant_hessian;
+ int num_workgroups;
+ cudaStream_t stream;
+ uint8_t* device_features;
+ uint8_t* device_feature_masks;
+ data_size_t* device_data_indices;
+ score_t* device_gradients;
+ score_t* device_hessians;
+ score_t hessians_const;
+ char* device_subhistograms;
+ volatile int* sync_counters;
+ void* device_histogram_outputs;
+ size_t exp_workgroups_per_feature;
+ // cuda events
+ cudaEvent_t* kernel_start;
+ cudaEvent_t* kernel_wait_obj;
+  std::chrono::duration<double, std::milli>* kernel_input_wait_time;
+ // copy histogram
+ size_t output_size;
+ char* host_histogram_output;
+ cudaEvent_t* histograms_wait_obj;
+};
+
+
+void cuda_histogram(
+ int histogram_size,
+ data_size_t leaf_num_data,
+ data_size_t num_data,
+ bool use_all_features,
+ bool is_constant_hessian,
+ int num_workgroups,
+ cudaStream_t stream,
+ uint8_t* arg0,
+ uint8_t* arg1,
+ data_size_t arg2,
+ data_size_t* arg3,
+ data_size_t arg4,
+ score_t* arg5,
+ score_t* arg6,
+ score_t arg6_const,
+ char* arg7,
+ volatile int* arg8,
+ void* arg9,
+ size_t exp_workgroups_per_feature);
+
+} // namespace LightGBM
+
+#endif // USE_CUDA
+#endif // LIGHTGBM_TREELEARNER_CUDA_KERNEL_LAUNCHER_H_
diff --git a/src/treelearner/cuda_tree_learner.cpp b/src/treelearner/cuda_tree_learner.cpp
new file mode 100644
index 000000000000..16569eef257f
--- /dev/null
+++ b/src/treelearner/cuda_tree_learner.cpp
@@ -0,0 +1,974 @@
+/*!
+ * Copyright (c) 2020 IBM Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+#ifdef USE_CUDA
+#include "cuda_tree_learner.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+
+#include "../io/dense_bin.hpp"
+
+namespace LightGBM {
+
+#define cudaMemcpy_DEBUG 0 // 1: DEBUG cudaMemcpy
+#define ResetTrainingData_DEBUG 0 // 1: Debug ResetTrainingData
+
+#define CUDA_DEBUG 0
+
+static void *launch_cuda_histogram(void *thread_data) {
+  ThreadData td = *(reinterpret_cast<ThreadData*>(thread_data));
+ int device_id = td.device_id;
+ CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id));
+
+ // launch cuda kernel
+ cuda_histogram(td.histogram_size,
+ td.leaf_num_data, td.num_data, td.use_all_features,
+ td.is_constant_hessian, td.num_workgroups, td.stream,
+ td.device_features,
+ td.device_feature_masks,
+ td.num_data,
+ td.device_data_indices,
+ td.leaf_num_data,
+ td.device_gradients,
+ td.device_hessians, td.hessians_const,
+ td.device_subhistograms, td.sync_counters,
+ td.device_histogram_outputs,
+ td.exp_workgroups_per_feature);
+
+ CUDASUCCESS_OR_FATAL(cudaGetLastError());
+
+ return NULL;
+}
+
+CUDATreeLearner::CUDATreeLearner(const Config* config)
+ :SerialTreeLearner(config) {
+ use_bagging_ = false;
+ nthreads_ = 0;
+ if (config->gpu_use_dp && USE_DP_FLOAT) {
+ Log::Info("LightGBM using CUDA trainer with DP float!!");
+ } else {
+ Log::Info("LightGBM using CUDA trainer with SP float!!");
+ }
+}
+
+CUDATreeLearner::~CUDATreeLearner() {
+}
+
+
+void CUDATreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
+ // initialize SerialTreeLearner
+ SerialTreeLearner::Init(train_data, is_constant_hessian);
+
+ // some additional variables needed for GPU trainer
+ num_feature_groups_ = train_data_->num_feature_groups();
+
+ // Initialize GPU buffers and kernels: get device info
+ InitGPU(config_->num_gpu);
+}
+
+// some functions used for debugging the GPU histogram construction
+#if CUDA_DEBUG > 0
+
+void PrintHistograms(hist_t* h, size_t size) {
+ double total_hess = 0;
+ for (size_t i = 0; i < size; ++i) {
+ printf("%03lu=%9.3g,%9.3g\t", i, GET_GRAD(h, i), GET_HESS(h, i));
+ if ((i & 3) == 3)
+ printf("\n");
+ total_hess += GET_HESS(h, i);
+ }
+ printf("\nSum hessians: %9.3g\n", total_hess);
+}
+
+union Float_t {
+ int64_t i;
+ double f;
+ static int64_t ulp_diff(Float_t a, Float_t b) {
+ return abs(a.i - b.i);
+ }
+};
+
+int CompareHistograms(hist_t* h1, hist_t* h2, size_t size, int feature_id, int dp_flag, int const_flag) {
+ int i;
+ int retval = 0;
+  printf("Comparing Histograms, feature_id = %d, size = %d\n", feature_id, static_cast<int>(size));
+ if (dp_flag) { // double precision
+ double af, bf;
+ int64_t ai, bi;
+    for (i = 0; i < static_cast<int>(size); ++i) {
+ af = GET_GRAD(h1, i);
+ bf = GET_GRAD(h2, i);
+ if ((((std::fabs(af - bf))/af) >= 1e-6) && ((std::fabs(af - bf)) >= 1e-6)) {
+ printf("i = %5d, h1.grad %13.6lf, h2.grad %13.6lf\n", i, af, bf);
+ ++retval;
+ }
+ if (const_flag) {
+        ai = GET_HESS((reinterpret_cast<int64_t *>(h1)), i);
+        bi = GET_HESS((reinterpret_cast<int64_t *>(h2)), i);
+ if (ai != bi) {
+ printf("i = %5d, h1.hess %" PRId64 ", h2.hess %" PRId64 "\n", i, ai, bi);
+ ++retval;
+ }
+ } else {
+ af = GET_HESS(h1, i);
+ bf = GET_HESS(h2, i);
+ if (((std::fabs(af - bf))/af) >= 1e-6) {
+ printf("i = %5d, h1.hess %13.6lf, h2.hess %13.6lf\n", i, af, bf);
+ ++retval;
+ }
+ }
+ }
+ } else { // single precision
+ float af, bf;
+ int ai, bi;
+    for (i = 0; i < static_cast<int>(size); ++i) {
+ af = GET_GRAD(h1, i);
+ bf = GET_GRAD(h2, i);
+ if ((((std::fabs(af - bf))/af) >= 1e-6) && ((std::fabs(af - bf)) >= 1e-6)) {
+ printf("i = %5d, h1.grad %13.6f, h2.grad %13.6f\n", i, af, bf);
+ ++retval;
+ }
+ if (const_flag) {
+ ai = GET_HESS(h1, i);
+ bi = GET_HESS(h2, i);
+ if (ai != bi) {
+ printf("i = %5d, h1.hess %d, h2.hess %d\n", i, ai, bi);
+ ++retval;
+ }
+ } else {
+ af = GET_HESS(h1, i);
+ bf = GET_HESS(h2, i);
+ if (((std::fabs(af - bf))/af) >= 1e-5) {
+ printf("i = %5d, h1.hess %13.6f, h2.hess %13.6f\n", i, af, bf);
+ ++retval;
+ }
+ }
+ }
+ }
+ printf("DONE Comparing Histograms...\n");
+ return retval;
+}
+#endif
+
+int CUDATreeLearner::GetNumWorkgroupsPerFeature(data_size_t leaf_num_data) {
+  // we roughly want 256 workgroups per device, and we have num_dense_feature_groups_ dense feature groups.
+  // also guarantee that there are at least 2K examples per workgroup
+  double x = 256.0 / num_dense_feature_groups_;
+
+  int exp_workgroups_per_feature = static_cast<int>(ceil(log2(x)));
+  double t = leaf_num_data / 1024.0;
+
+  Log::Debug("We can have at most %d workgroups per feature4 for efficiency reasons\n"
+             "Best workgroup size per feature for full utilization is %d\n", static_cast<int>(ceil(t)), (1 << exp_workgroups_per_feature));
+
+  exp_workgroups_per_feature = std::min(exp_workgroups_per_feature, static_cast<int>(ceil(log(static_cast<double>(t))/log(2.0))));
+ if (exp_workgroups_per_feature < 0)
+ exp_workgroups_per_feature = 0;
+ if (exp_workgroups_per_feature > kMaxLogWorkgroupsPerFeature)
+ exp_workgroups_per_feature = kMaxLogWorkgroupsPerFeature;
+
+ return exp_workgroups_per_feature;
+}
+
+void CUDATreeLearner::GPUHistogram(data_size_t leaf_num_data, bool use_all_features) {
+ // we have already copied ordered gradients, ordered hessians and indices to GPU
+ // decide the best number of workgroups working on one feature4 tuple
+ // set work group size based on feature size
+ // each 2^exp_workgroups_per_feature workgroups work on a feature4 tuple
+ int exp_workgroups_per_feature = GetNumWorkgroupsPerFeature(leaf_num_data);
+  std::vector<int> num_gpu_workgroups;
+  ThreadData *thread_data = reinterpret_cast<ThreadData*>(_mm_malloc(sizeof(ThreadData) * num_gpu_, 16));
+
+ for (int device_id = 0; device_id < num_gpu_; ++device_id) {
+ int num_gpu_feature_groups = num_gpu_feature_groups_[device_id];
+ int num_workgroups = (1 << exp_workgroups_per_feature) * num_gpu_feature_groups;
+ num_gpu_workgroups.push_back(num_workgroups);
+ if (num_workgroups > preallocd_max_num_wg_[device_id]) {
+ preallocd_max_num_wg_.at(device_id) = num_workgroups;
+ CUDASUCCESS_OR_FATAL(cudaFree(device_subhistograms_[device_id]));
+      CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_subhistograms_[device_id]), static_cast<size_t>(num_workgroups * dword_features_ * device_bin_size_ * (3 * hist_bin_entry_sz_ / 2))));
+ }
+ // set thread_data
+ SetThreadData(thread_data, device_id, histogram_size_, leaf_num_data, use_all_features,
+ num_workgroups, exp_workgroups_per_feature);
+ }
+
+ for (int device_id = 0; device_id < num_gpu_; ++device_id) {
+    if (pthread_create(cpu_threads_[device_id], NULL, launch_cuda_histogram, reinterpret_cast<void *>(&thread_data[device_id]))) {
+ Log::Fatal("Error in creating threads.");
+ }
+ }
+
+ /* Wait for the threads to finish */
+ for (int device_id = 0; device_id < num_gpu_; ++device_id) {
+ if (pthread_join(*(cpu_threads_[device_id]), NULL)) {
+ Log::Fatal("Error in joining threads.");
+ }
+ }
+
+ for (int device_id = 0; device_id < num_gpu_; ++device_id) {
+ // copy the results asynchronously. Size depends on if double precision is used
+
+ size_t output_size = num_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_;
+ size_t host_output_offset = offset_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_;
+
+    CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(reinterpret_cast<char*>(host_histogram_outputs_) + host_output_offset, device_histogram_outputs_[device_id], output_size, cudaMemcpyDeviceToHost, stream_[device_id]));
+ CUDASUCCESS_OR_FATAL(cudaEventRecord(histograms_wait_obj_[device_id], stream_[device_id]));
+ }
+}
+
+
+template <typename HistType>
+void CUDATreeLearner::WaitAndGetHistograms(FeatureHistogram* leaf_histogram_array) {
+  HistType* hist_outputs = reinterpret_cast<HistType *>(host_histogram_outputs_);
+
+ #pragma omp parallel for schedule(static, num_gpu_)
+ for (int device_id = 0; device_id < num_gpu_; ++device_id) {
+ // when the output is ready, the computation is done
+ CUDASUCCESS_OR_FATAL(cudaEventSynchronize(histograms_wait_obj_[device_id]));
+ }
+
+  HistType* histograms = reinterpret_cast<HistType*>(leaf_histogram_array[0].RawData() - kHistOffset);
+ #pragma omp parallel for schedule(static)
+ for (int i = 0; i < num_dense_feature_groups_; ++i) {
+ if (!feature_masks_[i]) {
+ continue;
+ }
+ int dense_group_index = dense_feature_group_map_[i];
+ auto old_histogram_array = histograms + train_data_->GroupBinBoundary(dense_group_index) * 2;
+ int bin_size = train_data_->FeatureGroupNumBin(dense_group_index);
+
+ for (int j = 0; j < bin_size; ++j) {
+ GET_GRAD(old_histogram_array, j) = GET_GRAD(hist_outputs, i * device_bin_size_+ j);
+ GET_HESS(old_histogram_array, j) = GET_HESS(hist_outputs, i * device_bin_size_+ j);
+ }
+ }
+}
+
+void CUDATreeLearner::CountDenseFeatureGroups() {
+ num_dense_feature_groups_ = 0;
+
+ for (int i = 0; i < num_feature_groups_; ++i) {
+ if (!train_data_->IsMultiGroup(i)) {
+ num_dense_feature_groups_++;
+ }
+ }
+ if (!num_dense_feature_groups_) {
+    Log::Warning("GPU acceleration is disabled because no non-trivial dense features can be found");
+ }
+}
+
+void CUDATreeLearner::prevAllocateGPUMemory() {
+ // how many feature-group tuples we have
+ // leave some safe margin for prefetching
+ // 256 work-items per workgroup. Each work-item prefetches one tuple for that feature
+
+ allocated_num_data_ = std::max(num_data_ + 256 * (1 << kMaxLogWorkgroupsPerFeature), allocated_num_data_);
+
+ // clear sparse/dense maps
+ dense_feature_group_map_.clear();
+ sparse_feature_group_map_.clear();
+
+  // do nothing if there is no dense feature
+ if (!num_dense_feature_groups_) {
+ return;
+ }
+
+ // calculate number of feature groups per gpu
+ num_gpu_feature_groups_.resize(num_gpu_);
+ offset_gpu_feature_groups_.resize(num_gpu_);
+ int num_features_per_gpu = num_dense_feature_groups_ / num_gpu_;
+ int remain_features = num_dense_feature_groups_ - num_features_per_gpu * num_gpu_;
+
+ int offset = 0;
+
+ for (int i = 0; i < num_gpu_; ++i) {
+ offset_gpu_feature_groups_.at(i) = offset;
+ num_gpu_feature_groups_.at(i) = (i < remain_features) ? num_features_per_gpu + 1 : num_features_per_gpu;
+ offset += num_gpu_feature_groups_.at(i);
+ }
+
+ feature_masks_.resize(num_dense_feature_groups_);
+ Log::Debug("Resized feature masks");
+
+ ptr_pinned_feature_masks_ = feature_masks_.data();
+ Log::Debug("Memset pinned_feature_masks_");
+ memset(ptr_pinned_feature_masks_, 0, num_dense_feature_groups_);
+
+ // histogram bin entry size depends on the precision (single/double)
+ hist_bin_entry_sz_ = 2 * (config_->gpu_use_dp ? sizeof(hist_t) : sizeof(gpu_hist_t)); // two elements in this "size"
+
+  CUDASUCCESS_OR_FATAL(cudaHostAlloc(reinterpret_cast<void **>(&host_histogram_outputs_), static_cast<size_t>(num_dense_feature_groups_ * device_bin_size_ * hist_bin_entry_sz_), cudaHostAllocPortable));
+
+ nthreads_ = std::min(omp_get_max_threads(), num_dense_feature_groups_ / dword_features_);
+ nthreads_ = std::max(nthreads_, 1);
+}
+
+// allocate GPU memory for each GPU
+void CUDATreeLearner::AllocateGPUMemory() {
+ #pragma omp parallel for schedule(static, num_gpu_)
+
+ for (int device_id = 0; device_id < num_gpu_; ++device_id) {
+    // do nothing if there is no gpu feature
+ int num_gpu_feature_groups = num_gpu_feature_groups_[device_id];
+ if (num_gpu_feature_groups) {
+ CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id));
+
+ // allocate memory for all features
+ if (device_features_[device_id] != NULL) {
+ CUDASUCCESS_OR_FATAL(cudaFree(device_features_[device_id]));
+ }
+
+      CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_features_[device_id]), static_cast<size_t>(num_gpu_feature_groups * num_data_ * sizeof(uint8_t))));
+ Log::Debug("Allocated device_features_ addr=%p sz=%lu", device_features_[device_id], num_gpu_feature_groups * num_data_);
+
+ // allocate space for gradients and hessians on device
+ // we will copy gradients and hessians in after ordered_gradients_ and ordered_hessians_ are constructed
+ if (device_gradients_[device_id] != NULL) {
+ CUDASUCCESS_OR_FATAL(cudaFree(device_gradients_[device_id]));
+ }
+
+ if (device_hessians_[device_id] != NULL) {
+ CUDASUCCESS_OR_FATAL(cudaFree(device_hessians_[device_id]));
+ }
+
+ if (device_feature_masks_[device_id] != NULL) {
+ CUDASUCCESS_OR_FATAL(cudaFree(device_feature_masks_[device_id]));
+ }
+
+      CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_gradients_[device_id]), static_cast<size_t>(allocated_num_data_ * sizeof(score_t))));
+      CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_hessians_[device_id]), static_cast<size_t>(allocated_num_data_ * sizeof(score_t))));
+
+      CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_feature_masks_[device_id]), static_cast<size_t>(num_gpu_feature_groups)));
+
+ // copy indices to the device
+ if (device_data_indices_[device_id] != NULL) {
+ CUDASUCCESS_OR_FATAL(cudaFree(device_data_indices_[device_id]));
+ }
+
+      CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_data_indices_[device_id]), static_cast<size_t>(allocated_num_data_ * sizeof(data_size_t))));
+ CUDASUCCESS_OR_FATAL(cudaMemsetAsync(device_data_indices_[device_id], 0, allocated_num_data_ * sizeof(data_size_t), stream_[device_id]));
+
+ Log::Debug("Memset device_data_indices_");
+
+ // create output buffer, each feature has a histogram with device_bin_size_ bins,
+ // each work group generates a sub-histogram of dword_features_ features.
+ if (!device_subhistograms_[device_id]) {
+ // only initialize once here, as this will not need to change when ResetTrainingData() is called
+        CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_subhistograms_[device_id]), static_cast<size_t>(preallocd_max_num_wg_[device_id] * dword_features_ * device_bin_size_ * (3 * hist_bin_entry_sz_ / 2))));
+
+ Log::Debug("created device_subhistograms_: %p", device_subhistograms_[device_id]);
+ }
+
+ // create atomic counters for inter-group coordination
+ CUDASUCCESS_OR_FATAL(cudaFree(sync_counters_[device_id]));
+      CUDASUCCESS_OR_FATAL(cudaMalloc(&(sync_counters_[device_id]), static_cast<size_t>(num_gpu_feature_groups * sizeof(int))));
+ CUDASUCCESS_OR_FATAL(cudaMemsetAsync(sync_counters_[device_id], 0, num_gpu_feature_groups * sizeof(int), stream_[device_id]));
+
+ // The output buffer is allocated to host directly, to overlap compute and data transfer
+ CUDASUCCESS_OR_FATAL(cudaFree(device_histogram_outputs_[device_id]));
+      CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_histogram_outputs_[device_id]), static_cast<size_t>(num_gpu_feature_groups * device_bin_size_ * hist_bin_entry_sz_)));
+ }
+ }
+}
+
+void CUDATreeLearner::ResetGPUMemory() {
+ // clear sparse/dense maps
+ dense_feature_group_map_.clear();
+ sparse_feature_group_map_.clear();
+}
+
+void CUDATreeLearner::copyDenseFeature() {
+ if (num_feature_groups_ == 0) {
+ LGBM_config_::current_learner = use_cpu_learner;
+ return;
+ }
+
+ Log::Debug("Started copying dense features from CPU to GPU");
+  // find the dense feature-groups and group them into the Feature4 data structure (several feature-groups packed into 4 bytes)
+ size_t copied_feature = 0;
+ // set device info
+ int device_id = 0;
+ uint8_t* device_features = device_features_[device_id];
+ CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id));
+ Log::Debug("Started copying dense features from CPU to GPU - 1");
+
+ for (int i = 0; i < num_feature_groups_; ++i) {
+ // looking for dword_features_ non-sparse feature-groups
+ if (!train_data_->IsMultiGroup(i)) {
+ dense_feature_group_map_.push_back(i);
+ auto sizes_in_byte = train_data_->FeatureGroupSizesInByte(i);
+ void* tmp_data = train_data_->FeatureGroupData(i);
+ Log::Debug("Started copying dense features from CPU to GPU - 2");
+ CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(&device_features[copied_feature * num_data_], tmp_data, sizes_in_byte, cudaMemcpyHostToDevice, stream_[device_id]));
+ Log::Debug("Started copying dense features from CPU to GPU - 3");
+ copied_feature++;
+ // reset device info
+      if (copied_feature == static_cast<size_t>(num_gpu_feature_groups_[device_id])) {
+ CUDASUCCESS_OR_FATAL(cudaEventRecord(features_future_[device_id], stream_[device_id]));
+ device_id += 1;
+ copied_feature = 0;
+ if (device_id < num_gpu_) {
+ device_features = device_features_[device_id];
+ CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id));
+ }
+ }
+ } else {
+ sparse_feature_group_map_.push_back(i);
+ }
+ }
+}
+
+
+
+// InitGPU w/ num_gpu
+void CUDATreeLearner::InitGPU(int num_gpu) {
+ // Get the max bin size, used for selecting best GPU kernel
+ max_num_bin_ = 0;
+
+ #if CUDA_DEBUG >= 1
+ printf("bin_size: ");
+ #endif
+ for (int i = 0; i < num_feature_groups_; ++i) {
+ if (train_data_->IsMultiGroup(i)) {
+ continue;
+ }
+ #if CUDA_DEBUG >= 1
+ printf("%d, ", train_data_->FeatureGroupNumBin(i));
+ #endif
+ max_num_bin_ = std::max(max_num_bin_, train_data_->FeatureGroupNumBin(i));
+ }
+ #if CUDA_DEBUG >= 1
+ printf("\n");
+ #endif
+
+ if (max_num_bin_ <= 16) {
+ device_bin_size_ = 16;
+ histogram_size_ = 16;
+ dword_features_ = 1;
+ } else if (max_num_bin_ <= 64) {
+ device_bin_size_ = 64;
+ histogram_size_ = 64;
+ dword_features_ = 1;
+ } else if (max_num_bin_ <= 256) {
+ Log::Debug("device_bin_size_ = 256");
+ device_bin_size_ = 256;
+ histogram_size_ = 256;
+ dword_features_ = 1;
+ } else {
+ Log::Fatal("bin size %d cannot run on GPU", max_num_bin_);
+ }
+ if (max_num_bin_ == 65) {
+    Log::Warning("Setting max_bin to 63 is suggested for best performance");
+ }
+ if (max_num_bin_ == 17) {
+    Log::Warning("Setting max_bin to 15 is suggested for best performance");
+ }
+
+ // get num_dense_feature_groups_
+ CountDenseFeatureGroups();
+
+ if (num_gpu > num_dense_feature_groups_) num_gpu = num_dense_feature_groups_;
+
+ // initialize GPU
+ int gpu_count;
+
+ CUDASUCCESS_OR_FATAL(cudaGetDeviceCount(&gpu_count));
+ num_gpu_ = (gpu_count < num_gpu) ? gpu_count : num_gpu;
+
+ // set cpu threads
+  cpu_threads_ = reinterpret_cast<pthread_t **>(_mm_malloc(sizeof(pthread_t *)*num_gpu_, 16));
+  for (int device_id = 0; device_id < num_gpu_; ++device_id) {
+    cpu_threads_[device_id] = reinterpret_cast<pthread_t *>(_mm_malloc(sizeof(pthread_t), 16));
+ }
+
+ // resize device memory pointers
+ device_features_.resize(num_gpu_);
+ device_gradients_.resize(num_gpu_);
+ device_hessians_.resize(num_gpu_);
+ device_feature_masks_.resize(num_gpu_);
+ device_data_indices_.resize(num_gpu_);
+ sync_counters_.resize(num_gpu_);
+ device_subhistograms_.resize(num_gpu_);
+ device_histogram_outputs_.resize(num_gpu_);
+
+ // create stream & events to handle multiple GPUs
+ preallocd_max_num_wg_.resize(num_gpu_, 1024);
+ stream_.resize(num_gpu_);
+ hessians_future_.resize(num_gpu_);
+ gradients_future_.resize(num_gpu_);
+ indices_future_.resize(num_gpu_);
+ features_future_.resize(num_gpu_);
+ kernel_start_.resize(num_gpu_);
+ kernel_wait_obj_.resize(num_gpu_);
+ histograms_wait_obj_.resize(num_gpu_);
+
+ for (int i = 0; i < num_gpu_; ++i) {
+ CUDASUCCESS_OR_FATAL(cudaSetDevice(i));
+ CUDASUCCESS_OR_FATAL(cudaStreamCreate(&(stream_[i])));
+ CUDASUCCESS_OR_FATAL(cudaEventCreate(&(hessians_future_[i])));
+ CUDASUCCESS_OR_FATAL(cudaEventCreate(&(gradients_future_[i])));
+ CUDASUCCESS_OR_FATAL(cudaEventCreate(&(indices_future_[i])));
+ CUDASUCCESS_OR_FATAL(cudaEventCreate(&(features_future_[i])));
+ CUDASUCCESS_OR_FATAL(cudaEventCreate(&(kernel_start_[i])));
+ CUDASUCCESS_OR_FATAL(cudaEventCreate(&(kernel_wait_obj_[i])));
+ CUDASUCCESS_OR_FATAL(cudaEventCreate(&(histograms_wait_obj_[i])));
+ }
+
+ allocated_num_data_ = 0;
+ prevAllocateGPUMemory();
+
+ AllocateGPUMemory();
+
+ copyDenseFeature();
+}
+
+Tree* CUDATreeLearner::Train(const score_t* gradients, const score_t *hessians) {
+ Tree *ret = SerialTreeLearner::Train(gradients, hessians);
+ return ret;
+}
+
+void CUDATreeLearner::ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) {
+ // check data size
+ data_size_t old_allocated_num_data = allocated_num_data_;
+
+ SerialTreeLearner::ResetTrainingDataInner(train_data, is_constant_hessian, reset_multi_val_bin);
+
+ #if ResetTrainingData_DEBUG == 1
+ serial_time = std::chrono::steady_clock::now() - start_serial_time;
+ #endif
+
+ num_feature_groups_ = train_data_->num_feature_groups();
+
+  // GPU memory has to be reallocated because the data may have changed
+ #if ResetTrainingData_DEBUG == 1
+ auto start_alloc_gpu_time = std::chrono::steady_clock::now();
+ #endif
+
+ // AllocateGPUMemory only when the number of data increased
+ int old_num_feature_groups = num_dense_feature_groups_;
+ CountDenseFeatureGroups();
+ if ((old_allocated_num_data < (num_data_ + 256 * (1 << kMaxLogWorkgroupsPerFeature))) || (old_num_feature_groups < num_dense_feature_groups_)) {
+ prevAllocateGPUMemory();
+ AllocateGPUMemory();
+ } else {
+ ResetGPUMemory();
+ }
+
+ copyDenseFeature();
+
+ #if ResetTrainingData_DEBUG == 1
+ alloc_gpu_time = std::chrono::steady_clock::now() - start_alloc_gpu_time;
+ #endif
+
+  // setup GPU kernel arguments after allocating all the buffers
+ #if ResetTrainingData_DEBUG == 1
+ auto start_set_arg_time = std::chrono::steady_clock::now();
+ #endif
+
+ #if ResetTrainingData_DEBUG == 1
+ set_arg_time = std::chrono::steady_clock::now() - start_set_arg_time;
+ reset_training_data_time = std::chrono::steady_clock::now() - start_reset_training_data_time;
+ Log::Info("reset_training_data_time: %f secs.", reset_training_data_time.count() * 1e-3);
+ Log::Info("serial_time: %f secs.", serial_time.count() * 1e-3);
+ Log::Info("alloc_gpu_time: %f secs.", alloc_gpu_time.count() * 1e-3);
+ Log::Info("set_arg_time: %f secs.", set_arg_time.count() * 1e-3);
+ #endif
+}
+
+void CUDATreeLearner::BeforeTrain() {
+ #if cudaMemcpy_DEBUG == 1
+  std::chrono::duration<double, std::milli> device_hessians_time = std::chrono::milliseconds(0);
+  std::chrono::duration<double, std::milli> device_gradients_time = std::chrono::milliseconds(0);
+ #endif
+
+ SerialTreeLearner::BeforeTrain();
+
+ #if CUDA_DEBUG >= 2
+ printf("CUDATreeLearner::BeforeTrain() Copying initial full gradients and hessians to device\n");
+ #endif
+
+ // Copy initial full hessians and gradients to GPU.
+ // We start copying as early as possible, instead of at ConstructHistogram().
+ if ((hessians_ != NULL) && (gradients_ != NULL)) {
+ if (!use_bagging_ && num_dense_feature_groups_) {
+ Log::Debug("CudaTreeLearner::BeforeTrain() No baggings, dense_feature_groups_=%d", num_dense_feature_groups_);
+
+ for (int device_id = 0; device_id < num_gpu_; ++device_id) {
+ if (!(share_state_->is_constant_hessian)) {
+ Log::Debug("CUDATreeLearner::BeforeTrain(): Starting hessians_ -> device_hessians_");
+
+ #if cudaMemcpy_DEBUG == 1
+ auto start_device_hessians_time = std::chrono::steady_clock::now();
+ #endif
+
+ CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_hessians_[device_id], hessians_, num_data_*sizeof(score_t), cudaMemcpyHostToDevice, stream_[device_id]));
+
+ CUDASUCCESS_OR_FATAL(cudaEventRecord(hessians_future_[device_id], stream_[device_id]));
+
+ #if cudaMemcpy_DEBUG == 1
+ device_hessians_time = std::chrono::steady_clock::now() - start_device_hessians_time;
+ #endif
+
+ Log::Debug("queued copy of device_hessians_");
+ }
+
+ #if cudaMemcpy_DEBUG == 1
+ auto start_device_gradients_time = std::chrono::steady_clock::now();
+ #endif
+
+ CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_gradients_[device_id], gradients_, num_data_ * sizeof(score_t), cudaMemcpyHostToDevice, stream_[device_id]));
+ CUDASUCCESS_OR_FATAL(cudaEventRecord(gradients_future_[device_id], stream_[device_id]));
+
+ #if cudaMemcpy_DEBUG == 1
+ device_gradients_time = std::chrono::steady_clock::now() - start_device_gradients_time;
+ #endif
+
+ Log::Debug("CUDATreeLearner::BeforeTrain: issued gradients_ -> device_gradients_");
+ }
+ }
+ }
+
+ // use bagging
+ if ((hessians_ != NULL) && (gradients_ != NULL)) {
+ if (data_partition_->leaf_count(0) != num_data_ && num_dense_feature_groups_) {
+      // On GPU, we start copying indices, gradients and hessians now, instead of at ConstructHistogram()
+ // copy used gradients and hessians to ordered buffer
+ const data_size_t* indices = data_partition_->indices();
+ data_size_t cnt = data_partition_->leaf_count(0);
+
+ // transfer the indices to GPU
+ for (int device_id = 0; device_id < num_gpu_; ++device_id) {
+ CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_data_indices_[device_id], indices, cnt * sizeof(*indices), cudaMemcpyHostToDevice, stream_[device_id]));
+ CUDASUCCESS_OR_FATAL(cudaEventRecord(indices_future_[device_id], stream_[device_id]));
+
+ if (!(share_state_->is_constant_hessian)) {
+          CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_hessians_[device_id], const_cast<void *>(reinterpret_cast<const void *>(&(hessians_[0]))), num_data_ * sizeof(score_t), cudaMemcpyHostToDevice, stream_[device_id]));
+ CUDASUCCESS_OR_FATAL(cudaEventRecord(hessians_future_[device_id], stream_[device_id]));
+ }
+
+        CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_gradients_[device_id], const_cast<void *>(reinterpret_cast<const void *>(&(gradients_[0]))), num_data_ * sizeof(score_t), cudaMemcpyHostToDevice, stream_[device_id]));
+ CUDASUCCESS_OR_FATAL(cudaEventRecord(gradients_future_[device_id], stream_[device_id]));
+ }
+ }
+ }
+}
+
+bool CUDATreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
+ int smaller_leaf;
+
+ data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
+ data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
+
+ // only have root
+ if (right_leaf < 0) {
+ smaller_leaf = -1;
+ } else if (num_data_in_left_child < num_data_in_right_child) {
+ smaller_leaf = left_leaf;
+ } else {
+ smaller_leaf = right_leaf;
+ }
+
+ // Copy indices, gradients and hessians as early as possible
+ if (smaller_leaf >= 0 && num_dense_feature_groups_) {
+ // only need to initialize for smaller leaf
+ // Get leaf boundary
+ const data_size_t* indices = data_partition_->indices();
+ data_size_t begin = data_partition_->leaf_begin(smaller_leaf);
+ data_size_t end = begin + data_partition_->leaf_count(smaller_leaf);
+
+ for (int device_id = 0; device_id < num_gpu_; ++device_id) {
+ CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_data_indices_[device_id], &indices[begin], (end-begin) * sizeof(data_size_t), cudaMemcpyHostToDevice, stream_[device_id]));
+ CUDASUCCESS_OR_FATAL(cudaEventRecord(indices_future_[device_id], stream_[device_id]));
+ }
+ }
+
+ const bool ret = SerialTreeLearner::BeforeFindBestSplit(tree, left_leaf, right_leaf);
+
+ return ret;
+}
+
+bool CUDATreeLearner::ConstructGPUHistogramsAsync(
+    const std::vector<int8_t>& is_feature_used,
+ const data_size_t* data_indices, data_size_t num_data) {
+ if (num_data <= 0) {
+ return false;
+ }
+
+ // do nothing if no features can be processed on GPU
+ if (!num_dense_feature_groups_) {
+ Log::Debug("no dense feature groups, returning");
+ return false;
+ }
+
+ // copy data indices if it is not null
+ if (data_indices != nullptr && num_data != num_data_) {
+ for (int device_id = 0; device_id < num_gpu_; ++device_id) {
+ CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_data_indices_[device_id], data_indices, num_data * sizeof(data_size_t), cudaMemcpyHostToDevice, stream_[device_id]));
+ CUDASUCCESS_OR_FATAL(cudaEventRecord(indices_future_[device_id], stream_[device_id]));
+ }
+ }
+
+  // convert the indices in is_feature_used to feature-group indices
+  std::vector<int8_t> is_feature_group_used(num_feature_groups_, 0);
+
+ #pragma omp parallel for schedule(static, 1024) if (num_features_ >= 2048)
+ for (int i = 0; i < num_features_; ++i) {
+ if (is_feature_used[i]) {
+ int feature_group = train_data_->Feature2Group(i);
+ is_feature_group_used[feature_group] = (train_data_->FeatureGroupNumBin(feature_group) <= 16) ? 2 : 1;
+ }
+ }
+
+ // construct the feature masks for dense feature-groups
+ int used_dense_feature_groups = 0;
+ #pragma omp parallel for schedule(static, 1024) reduction(+:used_dense_feature_groups) if (num_dense_feature_groups_ >= 2048)
+ for (int i = 0; i < num_dense_feature_groups_; ++i) {
+ if (is_feature_group_used[dense_feature_group_map_[i]]) {
+ feature_masks_[i] = is_feature_group_used[dense_feature_group_map_[i]];
+ ++used_dense_feature_groups;
+ } else {
+ feature_masks_[i] = 0;
+ }
+ }
+ bool use_all_features = ((used_dense_feature_groups == num_dense_feature_groups_) && (data_indices != nullptr));
+ // if no feature group is used, just return and do not use GPU
+ if (used_dense_feature_groups == 0) {
+ return false;
+ }
+
+ // if not all feature groups are used, we need to transfer the feature mask to GPU
+ // otherwise, we will use a specialized GPU kernel with all feature groups enabled
+
+ // We now copy even if all features are used.
+ #pragma omp parallel for schedule(static, num_gpu_)
+ for (int device_id = 0; device_id < num_gpu_; ++device_id) {
+ int offset = offset_gpu_feature_groups_[device_id];
+ CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_feature_masks_[device_id], ptr_pinned_feature_masks_ + offset, num_gpu_feature_groups_[device_id] , cudaMemcpyHostToDevice, stream_[device_id]));
+ }
+
+ // All data have been prepared, now run the GPU kernel
+ GPUHistogram(num_data, use_all_features);
+
+ return true;
+}
+
+void CUDATreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
+ std::vector<int8_t> is_sparse_feature_used(num_features_, 0);
+ std::vector<int8_t> is_dense_feature_used(num_features_, 0);
+ int num_dense_features = 0, num_sparse_features = 0;
+
+ #pragma omp parallel for schedule(static)
+ for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
+ if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
+ if (!is_feature_used[feature_index]) continue;
+ if (train_data_->IsMultiGroup(train_data_->Feature2Group(feature_index))) {
+ is_sparse_feature_used[feature_index] = 1;
+ num_sparse_features++;
+ } else {
+ is_dense_feature_used[feature_index] = 1;
+ num_dense_features++;
+ }
+ }
+
+ // construct smaller leaf
+ hist_t* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - kHistOffset;
+
+ // Check workgroups per feature4 tuple..
+ int exp_workgroups_per_feature = GetNumWorkgroupsPerFeature(smaller_leaf_splits_->num_data_in_leaf());
+
+ // if the workgroup per feature is 1 (2^0), return as the work is too small for a GPU
+ if (exp_workgroups_per_feature == 0) {
+ return SerialTreeLearner::ConstructHistograms(is_feature_used, use_subtract);
+ }
+
+ // ConstructGPUHistogramsAsync returns true if any feature groups were dispatched to the GPU
+ bool is_gpu_used = ConstructGPUHistogramsAsync(is_feature_used,
+ nullptr, smaller_leaf_splits_->num_data_in_leaf());
+
+ // then construct sparse features on CPU
+ // We set data_indices to null to avoid rebuilding ordered gradients/hessians
+ if (num_sparse_features > 0) {
+ train_data_->ConstructHistograms(is_sparse_feature_used,
+ smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
+ gradients_, hessians_,
+ ordered_gradients_.data(), ordered_hessians_.data(),
+ share_state_.get(),
+ ptr_smaller_leaf_hist_data);
+ }
+
+ // wait for GPU to finish, only if GPU is actually used
+ if (is_gpu_used) {
+ if (config_->gpu_use_dp) {
+ // use double precision
+ WaitAndGetHistograms<hist_t>(smaller_leaf_histogram_array_);
+ } else {
+ // use single precision
+ WaitAndGetHistograms<gpu_hist_t>(smaller_leaf_histogram_array_);
+ }
+ }
+
+ // Compare the GPU histogram with the CPU histogram; useful for debugging problems in the GPU code
+ // #define CUDA_DEBUG_COMPARE
+#ifdef CUDA_DEBUG_COMPARE
+ printf("Start Comparing_Histogram between GPU and CPU, num_dense_feature_groups_ = %d\n", num_dense_feature_groups_);
+ bool compare = true;
+ for (int i = 0; i < num_dense_feature_groups_; ++i) {
+ if (!feature_masks_[i])
+ continue;
+ int dense_feature_group_index = dense_feature_group_map_[i];
+ size_t size = train_data_->FeatureGroupNumBin(dense_feature_group_index);
+ hist_t* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - kHistOffset;
+ hist_t* current_histogram = ptr_smaller_leaf_hist_data + train_data_->GroupBinBoundary(dense_feature_group_index) * 2;
+ hist_t* gpu_histogram = new hist_t[size * 2];
+ data_size_t num_data = smaller_leaf_splits_->num_data_in_leaf();
+ printf("Comparing histogram for feature %d, num_data %d, num_data_ = %d, %lu bins\n", dense_feature_group_index, num_data, num_data_, size);
+ std::copy(current_histogram, current_histogram + size * 2, gpu_histogram);
+ std::memset(current_histogram, 0, size * sizeof(hist_t) * 2);
+ if (train_data_->FeatureGroupBin(dense_feature_group_index) == nullptr) {
+ continue;
+ }
+ if (num_data == num_data_) {
+ if (share_state_->is_constant_hessian) {
+ printf("ConstructHistogram(): num_data == num_data_ is_constant_hessian\n");
+ train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
+ 0,
+ num_data,
+ gradients_,
+ current_histogram);
+ } else {
+ printf("ConstructHistogram(): num_data == num_data_\n");
+ train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
+ 0,
+ num_data,
+ gradients_, hessians_,
+ current_histogram);
+ }
+ } else {
+ if (share_state_->is_constant_hessian) {
+ printf("ConstructHistogram(): is_constant_hessian\n");
+ train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
+ smaller_leaf_splits_->data_indices(),
+ 0,
+ num_data,
+ gradients_,
+ current_histogram);
+ } else {
+ printf("ConstructHistogram(): 4, num_data = %d, num_data_ = %d\n", num_data, num_data_);
+ train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
+ smaller_leaf_splits_->data_indices(),
+ 0,
+ num_data,
+ gradients_, hessians_,
+ current_histogram);
+ }
+ }
+ int retval;
+ if ((num_data != num_data_) && compare) {
+ retval = CompareHistograms(gpu_histogram, current_histogram, size, dense_feature_group_index, config_->gpu_use_dp, share_state_->is_constant_hessian);
+ printf("CompareHistograms reports %d errors\n", retval);
+ compare = false;
+ }
+ retval = CompareHistograms(gpu_histogram, current_histogram, size, dense_feature_group_index, config_->gpu_use_dp, share_state_->is_constant_hessian);
+ printf("CompareHistograms reports %d errors\n", retval);
+ std::copy(gpu_histogram, gpu_histogram + size * 2, current_histogram);
+ delete [] gpu_histogram;
+ }
+ printf("End Comparing Histogram between GPU and CPU\n");
+ fflush(stderr);
+ fflush(stdout);
+#endif
+
+ if (larger_leaf_histogram_array_ != nullptr && !use_subtract) {
+ // construct larger leaf
+ hist_t* ptr_larger_leaf_hist_data = larger_leaf_histogram_array_[0].RawData() - kHistOffset;
+
+ is_gpu_used = ConstructGPUHistogramsAsync(is_feature_used,
+ larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf());
+
+ // then construct sparse features on CPU
+ // We set data_indices to null to avoid rebuilding ordered gradients/hessians
+ if (num_sparse_features > 0) {
+ train_data_->ConstructHistograms(is_sparse_feature_used,
+ larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf(),
+ gradients_, hessians_,
+ ordered_gradients_.data(), ordered_hessians_.data(),
+ share_state_.get(),
+ ptr_larger_leaf_hist_data);
+ }
+
+ // wait for GPU to finish, only if GPU is actually used
+ if (is_gpu_used) {
+ if (config_->gpu_use_dp) {
+ // use double precision
+ WaitAndGetHistograms<hist_t>(larger_leaf_histogram_array_);
+ } else {
+ // use single precision
+ WaitAndGetHistograms<gpu_hist_t>(larger_leaf_histogram_array_);
+ }
+ }
+ }
+}
+
+void CUDATreeLearner::FindBestSplits(const Tree* tree) {
+ SerialTreeLearner::FindBestSplits(tree);
+
+#if CUDA_DEBUG >= 3
+ for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
+ if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
+ if (parent_leaf_histogram_array_ != nullptr
+ && !parent_leaf_histogram_array_[feature_index].is_splittable()) {
+ smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
+ continue;
+ }
+ size_t bin_size = train_data_->FeatureNumBin(feature_index) + 1;
+ printf("CUDATreeLearner::FindBestSplits() Feature %d bin_size=%zd smaller leaf:\n", feature_index, bin_size);
+ PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
+ if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->leaf_index() < 0) { continue; }
+ printf("CUDATreeLearner::FindBestSplits() Feature %d bin_size=%zd larger leaf:\n", feature_index, bin_size);
+
+ PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
+ }
+#endif
+}
+
+void CUDATreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
+ const SplitInfo& best_split_info = best_split_per_leaf_[best_Leaf];
+#if CUDA_DEBUG >= 2
+ printf("Splitting leaf %d with feature %d thresh %d gain %f stat %f %f %f %f\n", best_Leaf, best_split_info.feature, best_split_info.threshold, best_split_info.gain, best_split_info.left_sum_gradient, best_split_info.right_sum_gradient, best_split_info.left_sum_hessian, best_split_info.right_sum_hessian);
+#endif
+ SerialTreeLearner::Split(tree, best_Leaf, left_leaf, right_leaf);
+ if (Network::num_machines() == 1) {
+ // do some sanity check for the GPU algorithm
+ if (best_split_info.left_count < best_split_info.right_count) {
+ if ((best_split_info.left_count != smaller_leaf_splits_->num_data_in_leaf()) ||
+ (best_split_info.right_count!= larger_leaf_splits_->num_data_in_leaf())) {
+ Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf());
+ }
+ } else {
+ if ((best_split_info.left_count != larger_leaf_splits_->num_data_in_leaf()) ||
+ (best_split_info.right_count!= smaller_leaf_splits_->num_data_in_leaf())) {
+ Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf());
+ }
+ }
+ }
+}
+
+} // namespace LightGBM
+#undef cudaMemcpy_DEBUG
+#endif // USE_CUDA
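Note: the copy-then-record pattern used throughout the learner above (one cudaMemcpyAsync per device, immediately followed by cudaEventRecord on the same stream) is what lets host-to-device transfers overlap with CPU-side histogram work; the kernel launch later waits on the recorded event instead of synchronizing the whole device. The standalone sketch below illustrates only that pattern; the buffer, stream, and event names are placeholders, not the learner's members.

// Sketch of the async-copy + event pattern (illustrative names, not LightGBM's).
#include <cuda_runtime.h>
#include <cstdio>

#define CUDA_CHECK(call) \
  do { cudaError_t e = (call); if (e != cudaSuccess) { \
    fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(e)); return 1; } } while (0)

int main() {
  const int n = 1 << 20;
  float* host = nullptr;
  float* dev = nullptr;
  CUDA_CHECK(cudaMallocHost(&host, n * sizeof(float)));  // pinned memory enables truly async copies
  CUDA_CHECK(cudaMalloc(&dev, n * sizeof(float)));

  cudaStream_t stream;
  cudaEvent_t copy_done;
  CUDA_CHECK(cudaStreamCreate(&stream));
  CUDA_CHECK(cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));

  // enqueue the copy and mark its completion with an event
  CUDA_CHECK(cudaMemcpyAsync(dev, host, n * sizeof(float), cudaMemcpyHostToDevice, stream));
  CUDA_CHECK(cudaEventRecord(copy_done, stream));

  // CPU work can proceed here; later work queued on `stream` (or another stream
  // that calls cudaStreamWaitEvent) will not start before the copy finishes.
  CUDA_CHECK(cudaEventSynchronize(copy_done));

  CUDA_CHECK(cudaEventDestroy(copy_done));
  CUDA_CHECK(cudaStreamDestroy(stream));
  CUDA_CHECK(cudaFree(dev));
  CUDA_CHECK(cudaFreeHost(host));
  return 0;
}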
diff --git a/src/treelearner/cuda_tree_learner.h b/src/treelearner/cuda_tree_learner.h
new file mode 100644
index 000000000000..442c2f53ea01
--- /dev/null
+++ b/src/treelearner/cuda_tree_learner.h
@@ -0,0 +1,265 @@
+/*!
+ * Copyright (c) 2020 IBM Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+#ifndef LIGHTGBM_TREELEARNER_CUDA_TREE_LEARNER_H_
+#define LIGHTGBM_TREELEARNER_CUDA_TREE_LEARNER_H_
+
+#include <LightGBM/utils/random.h>
+#include <LightGBM/utils/array_args.h>
+#include <LightGBM/dataset.h>
+#include <LightGBM/feature_group.h>
+#include <LightGBM/tree.h>
+
+#include <string>
+#include <cstdio>
+#include <vector>
+#include <random>
+#include <memory>
+#include <chrono>
+#ifdef USE_CUDA
+#include <cuda_runtime.h>
+#endif
+
+#include "feature_histogram.hpp"
+#include "serial_tree_learner.h"
+#include "data_partition.hpp"
+#include "split_info.hpp"
+#include "leaf_splits.hpp"
+
+#ifdef USE_CUDA
+#include <LightGBM/cuda/vector_cudahost.h>
+#include "cuda_kernel_launcher.h"
+
+
+using json11::Json;
+
+namespace LightGBM {
+
+/*!
+* \brief CUDA-based parallel learning algorithm.
+*/
+class CUDATreeLearner: public SerialTreeLearner {
+ public:
+ explicit CUDATreeLearner(const Config* tree_config);
+ ~CUDATreeLearner();
+ void Init(const Dataset* train_data, bool is_constant_hessian) override;
+ void ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) override;
+ Tree* Train(const score_t* gradients, const score_t *hessians);
+ void SetBaggingData(const Dataset* subset, const data_size_t* used_indices, data_size_t num_data) override {
+ SerialTreeLearner::SetBaggingData(subset, used_indices, num_data);
+ if (subset == nullptr && used_indices != nullptr) {
+ if (num_data != num_data_) {
+ use_bagging_ = true;
+ return;
+ }
+ }
+ use_bagging_ = false;
+ }
+
+ protected:
+ void BeforeTrain() override;
+ bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
+ void FindBestSplits(const Tree* tree) override;
+ void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
+ void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
+
+ private:
+ typedef float gpu_hist_t;
+
+ /*!
+ * \brief Find the best number of workgroups processing one feature for maximizing efficiency
+ * \param leaf_num_data The number of data examples on the current leaf being processed
+ * \return Log2 of the best number for workgroups per feature, in range 0...kMaxLogWorkgroupsPerFeature
+ */
+ int GetNumWorkgroupsPerFeature(data_size_t leaf_num_data);
+
+ /*!
+ * \brief Initialize GPU device
+ * \param num_gpu: number of maximum gpus
+ */
+ void InitGPU(int num_gpu);
+
+ /*!
+ * \brief Allocate memory for GPU computation // alloc only
+ */
+ void CountDenseFeatureGroups(); // compute num_dense_feature_group
+ void prevAllocateGPUMemory(); // compute CPU-side param calculation & Pin HostMemory
+ void AllocateGPUMemory();
+
+ /*!
+ * \brief Reset GPU memory
+ */
+ void ResetGPUMemory();
+
+ /*!
+ * \brief Copy dense features from CPU to GPU
+ */
+ void copyDenseFeature();
+
+ /*!
+ * \brief Compute GPU feature histogram for the current leaf.
+ * Indices, gradients and hessians have been copied to the device.
+ * \param leaf_num_data Number of data on current leaf
+ * \param use_all_features Set to true to not use feature masks, with a faster kernel
+ */
+ void GPUHistogram(data_size_t leaf_num_data, bool use_all_features);
+
+ void SetThreadData(ThreadData* thread_data, int device_id, int histogram_size,
+ int leaf_num_data, bool use_all_features,
+ int num_workgroups, int exp_workgroups_per_feature) {
+ ThreadData* td = &thread_data[device_id];
+ td->device_id = device_id;
+ td->histogram_size = histogram_size;
+ td->leaf_num_data = leaf_num_data;
+ td->num_data = num_data_;
+ td->use_all_features = use_all_features;
+ td->is_constant_hessian = share_state_->is_constant_hessian;
+ td->num_workgroups = num_workgroups;
+ td->stream = stream_[device_id];
+ td->device_features = device_features_[device_id];
+ td->device_feature_masks = reinterpret_cast<uint8_t *>(device_feature_masks_[device_id]);
+ td->device_data_indices = device_data_indices_[device_id];
+ td->device_gradients = device_gradients_[device_id];
+ td->device_hessians = device_hessians_[device_id];
+ td->hessians_const = hessians_[0];
+ td->device_subhistograms = device_subhistograms_[device_id];
+ td->sync_counters = sync_counters_[device_id];
+ td->device_histogram_outputs = device_histogram_outputs_[device_id];
+ td->exp_workgroups_per_feature = exp_workgroups_per_feature;
+
+ td->kernel_start = &(kernel_start_[device_id]);
+ td->kernel_wait_obj = &(kernel_wait_obj_[device_id]);
+ td->kernel_input_wait_time = &(kernel_input_wait_time_[device_id]);
+
+ size_t output_size = num_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_;
+ size_t host_output_offset = offset_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_;
+ td->output_size = output_size;
+ td->host_histogram_output = reinterpret_cast<char *>(host_histogram_outputs_) + host_output_offset;
+ td->histograms_wait_obj = &(histograms_wait_obj_[device_id]);
+ }
+
+ /*!
+ * \brief Wait for GPU kernel execution and read histogram
+ * \param histograms Destination of histogram results from GPU.
+ */
+ template <typename HistType>
+ void WaitAndGetHistograms(FeatureHistogram* leaf_histogram_array);
+
+ /*!
+ * \brief Construct GPU histogram asynchronously.
+ * Interface is similar to Dataset::ConstructHistograms().
+ * \param is_feature_used A predicate vector for enabling each feature
+ * \param data_indices Array of data example IDs to be included in histogram, will be copied to GPU.
+ * Set to nullptr to skip copy to GPU.
+ * \param num_data Number of data examples to be included in histogram
+ * \return true if GPU kernel is launched, false if GPU is not used
+ */
+ bool ConstructGPUHistogramsAsync(
+ const std::vector<int8_t>& is_feature_used,
+ const data_size_t* data_indices, data_size_t num_data);
+
+ /*! \brief Log2 of max number of workgroups per feature */
+ const int kMaxLogWorkgroupsPerFeature = 10;  // 2^10
+ /*! \brief Max total number of workgroups with preallocated workspace.
+ * If we use more than this number of workgroups, we have to reallocate subhistograms */
+ std::vector<int> preallocd_max_num_wg_;
+
+ /*! \brief True if bagging is used */
+ bool use_bagging_;
+
+ /*! \brief GPU command queue object */
+ std::vector<cudaStream_t> stream_;
+
+ /*! \brief total number of feature-groups */
+ int num_feature_groups_;
+ /*! \brief total number of dense feature-groups, which will be processed on GPU */
+ int num_dense_feature_groups_;
+ std::vector<int> num_gpu_feature_groups_;
+ std::vector<int> offset_gpu_feature_groups_;
+ /*! \brief On GPU we read one DWORD (4-byte) of features of one example once.
+ * With bin size > 16, there are 4 features per DWORD.
+ * With bin size <=16, there are 8 features per DWORD.
+ */
+ int dword_features_;
+ /*! \brief Max number of bins of training data, used to determine
+ * which GPU kernel to use */
+ int max_num_bin_;
+ /*! \brief Used GPU kernel bin size (64, 256) */
+ int histogram_size_;
+ int device_bin_size_;
+ /*! \brief Size of histogram bin entry, depending if single or double precision is used */
+ size_t hist_bin_entry_sz_;
+ /*! \brief Indices of all dense feature-groups */
+ std::vector<int> dense_feature_group_map_;
+ /*! \brief Indices of all sparse feature-groups */
+ std::vector<int> sparse_feature_group_map_;
+ /*! \brief GPU memory object holding the training data */
+ std::vector<uint8_t*> device_features_;
+ /*! \brief GPU memory object holding the ordered gradient */
+ std::vector<score_t*> device_gradients_;
+ /*! \brief Pointer to pinned memory of ordered gradient */
+ void * ptr_pinned_gradients_ = nullptr;
+ /*! \brief GPU memory object holding the ordered hessian */
+ std::vector<score_t*> device_hessians_;
+ /*! \brief Pointer to pinned memory of ordered hessian */
+ void * ptr_pinned_hessians_ = nullptr;
+ /*! \brief A vector of feature mask. 1 = feature used, 0 = feature not used */
+ std::vector<char> feature_masks_;
+ /*! \brief GPU memory object holding the feature masks */
+ std::vector<char*> device_feature_masks_;
+ /*! \brief Pointer to pinned memory of feature masks */
+ char* ptr_pinned_feature_masks_ = nullptr;
+ /*! \brief GPU memory object holding indices of the leaf being processed */
+ std::vector<data_size_t*> device_data_indices_;
+ /*! \brief GPU memory object holding counters for workgroup coordination */
+ std::vector<int*> sync_counters_;
+ /*! \brief GPU memory object holding temporary sub-histograms per workgroup */
+ std::vector<char*> device_subhistograms_;
+ /*! \brief Host memory object for histogram output (GPU will write to Host memory directly) */
+ std::vector<void*> device_histogram_outputs_;
+ /*! \brief Host memory pointer for histogram outputs */
+ void *host_histogram_outputs_;
+ /*! CUDA waitlist object for waiting for data transfer before kernel execution */
+ std::vector<cudaEvent_t> kernel_wait_obj_;
+ /*! CUDA waitlist object for reading output histograms after kernel execution */
+ std::vector<cudaEvent_t> histograms_wait_obj_;
+ /*! CUDA Asynchronous waiting object for copying indices */
+ std::vector<cudaEvent_t> indices_future_;
+ /*! Asynchronous waiting object for copying gradients */
+ std::vector<cudaEvent_t> gradients_future_;
+ /*! Asynchronous waiting object for copying hessians */
+ std::vector<cudaEvent_t> hessians_future_;
+ /*! Asynchronous waiting object for copying dense features */
+ std::vector<cudaEvent_t> features_future_;
+
+ // host-side buffer for converting feature data into feature4 data
+ int nthreads_; // number of Feature4* vector on host4_vecs_
+ std::vector<cudaEvent_t> kernel_start_;
+ std::vector<float> kernel_time_; // measure histogram kernel time
+ std::vector<std::chrono::duration<double, std::milli>> kernel_input_wait_time_;
+ int num_gpu_;
+ int allocated_num_data_; // allocated data instances
+ pthread_t **cpu_threads_; // pthread, 1 cpu thread / gpu
+};
+
+} // namespace LightGBM
+#else // USE_CUDA
+
+// When GPU support is not compiled in, quit with an error message
+
+namespace LightGBM {
+
+class CUDATreeLearner: public SerialTreeLearner {
+ public:
+ #pragma warning(disable : 4702)
+ explicit CUDATreeLearner(const Config* tree_config) : SerialTreeLearner(tree_config) {
+ Log::Fatal("CUDA Tree Learner was not enabled in this build.\n"
+ "Please recompile with CMake option -DUSE_CUDA=1");
+ }
+};
+
+} // namespace LightGBM
+
+#endif // USE_CUDA
+#endif // LIGHTGBM_TREELEARNER_CUDA_TREE_LEARNER_H_
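Note: the dword_features_ field documented above encodes the packing scheme the kernels depend on: when a feature group has at most 16 bins, a value fits in 4 bits, so two consecutive examples share one byte and the kernel recovers a value with (feature >> ((ind & 1) << 2)) & 0xf. A minimal host-side sketch of that unpacking follows; the low-nibble/high-nibble layout shown is an assumption made for illustration only.

// Illustration of 4-bit ("bin size <= 16") feature unpacking, mirroring the kernel expression.
#include <cstdint>
#include <cstdio>

int main() {
  // two consecutive examples packed into one byte: example 0 in the low nibble,
  // example 1 in the high nibble (assumed layout for this sketch)
  uint8_t packed = (0xB << 4) | 0x3;
  for (uint32_t ind = 0; ind < 2; ++ind) {
    uint8_t feature = (packed >> ((ind & 1) << 2)) & 0xf;  // shift by 0 or 4 bits
    printf("example %u -> bin %u\n", ind, feature);        // prints 3, then 11
  }
  return 0;
}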
diff --git a/src/treelearner/data_parallel_tree_learner.cpp b/src/treelearner/data_parallel_tree_learner.cpp
index 0d6f9df251b6..30d8df84acf6 100644
--- a/src/treelearner/data_parallel_tree_learner.cpp
+++ b/src/treelearner/data_parallel_tree_learner.cpp
@@ -256,6 +256,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, in
}
// instantiate template classes, otherwise linker cannot find the code
+template class DataParallelTreeLearner<CUDATreeLearner>;
template class DataParallelTreeLearner<GPUTreeLearner>;
template class DataParallelTreeLearner<SerialTreeLearner>;
diff --git a/src/treelearner/feature_parallel_tree_learner.cpp b/src/treelearner/feature_parallel_tree_learner.cpp
index c5202f3d706d..f4edfe03dc16 100644
--- a/src/treelearner/feature_parallel_tree_learner.cpp
+++ b/src/treelearner/feature_parallel_tree_learner.cpp
@@ -77,6 +77,7 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(
}
// instantiate template classes, otherwise linker cannot find the code
+template class FeatureParallelTreeLearner<CUDATreeLearner>;
template class FeatureParallelTreeLearner<GPUTreeLearner>;
template class FeatureParallelTreeLearner<SerialTreeLearner>;
} // namespace LightGBM
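Note: the two hunks above rely on explicit template instantiation, as the in-source comments say — the parallel learners are defined in .cpp files, so every learner type they wrap must be instantiated there or the linker reports undefined symbols. A generic sketch of the idiom, with illustrative names rather than LightGBM's:

// widget.h -- only the declaration is visible to other translation units
template <typename T>
class Widget {
 public:
  T doubled(T x) const;
};

// widget.cpp -- the definition lives here, so each specialization that other
// files use must be instantiated explicitly, otherwise the linker cannot find it
template <typename T>
T Widget<T>::doubled(T x) const { return x + x; }

template class Widget<int>;     // emits Widget<int>::doubled
template class Widget<double>;  // emits Widget<double>::doubled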
diff --git a/src/treelearner/gpu_tree_learner.cpp b/src/treelearner/gpu_tree_learner.cpp
index 43ccadfd176f..df90aafb945c 100644
--- a/src/treelearner/gpu_tree_learner.cpp
+++ b/src/treelearner/gpu_tree_learner.cpp
@@ -52,7 +52,7 @@ void PrintHistograms(hist_t* h, size_t size) {
double total_hess = 0;
for (size_t i = 0; i < size; ++i) {
printf("%03lu=%9.3g,%9.3g\t", i, GET_GRAD(h, i), GET_HESS(h, i));
- if ((i & 2) == 2)
+ if ((i & 3) == 3)
printf("\n");
total_hess += GET_HESS(h, i);
}
@@ -1068,10 +1068,10 @@ void GPUTreeLearner::FindBestSplits(const Tree* tree) {
}
size_t bin_size = train_data_->FeatureNumBin(feature_index) + 1;
printf("Feature %d smaller leaf:\n", feature_index);
- PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - 1, bin_size);
+ PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; }
printf("Feature %d larger leaf:\n", feature_index);
- PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - 1, bin_size);
+ PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
}
#endif
}
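Note: the PrintHistograms fix above changes the newline predicate from (i & 2) == 2 to (i & 3) == 3, which is true exactly when i % 4 == 3, i.e. after every fourth printed entry; the old test also fired at i = 2, 6, 10, ... and broke the four-entries-per-line layout. A quick standalone check of the two predicates:

#include <cstdio>

int main() {
  for (unsigned i = 0; i < 8; ++i) {
    // old predicate fires at 2, 3, 6, 7; new predicate fires only at 3 and 7
    printf("i=%u old=%d new=%d\n", i, (i & 2) == 2, (i & 3) == 3);
  }
  return 0;
}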
diff --git a/src/treelearner/kernels/histogram_16_64_256.cu b/src/treelearner/kernels/histogram_16_64_256.cu
new file mode 100644
index 000000000000..105ccbb62032
--- /dev/null
+++ b/src/treelearner/kernels/histogram_16_64_256.cu
@@ -0,0 +1,949 @@
+/*!
+ * Copyright (c) 2020 IBM Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+
+#include <LightGBM/meta.h>
+
+#include <cstdio>
+#include <cstring>
+
+#include "histogram_16_64_256.hu"
+
+namespace LightGBM {
+
+// atomic add for float number in local memory
+inline __device__ void atomic_local_add_f(acc_type *addr, const acc_type val) {
+ atomicAdd(addr, static_cast<acc_type>(val));
+}
+
+// histogram16 stuff
+#ifdef ENABLE_ALL_FEATURES
+#ifdef IGNORE_INDICES
+#define KERNEL_NAME histogram16_fulldata
+#else // IGNORE_INDICES
+#define KERNEL_NAME histogram16
+#endif // IGNORE_INDICES
+#else // ENABLE_ALL_FEATURES
+#error "ENABLE_ALL_FEATURES should always be 1"
+#define KERNEL_NAME histogram16
+#endif // ENABLE_ALL_FEATURES
+#define NUM_BINS 16
+#define LOCAL_MEM_SIZE ((sizeof(unsigned int) + 2 * sizeof(acc_type)) * NUM_BINS)
+
+// this function will be called by histogram16
+// we have one sub-histogram of one feature in local memory, and need to read others
+inline void __device__ within_kernel_reduction16x4(const acc_type* __restrict__ feature_sub_hist,
+ const unsigned int skip_id,
+ const unsigned int old_val_cont_bin0,
+ const uint16_t num_sub_hist,
+ acc_type* __restrict__ output_buf,
+ acc_type* __restrict__ local_hist,
+ const size_t power_feature_workgroups) {
+ const uint16_t ltid = threadIdx.x;
+ acc_type grad_bin = local_hist[ltid * 2];
+ acc_type hess_bin = local_hist[ltid * 2 + 1];
+ unsigned int* __restrict__ local_cnt = reinterpret_cast<unsigned int *>(local_hist + 2 * NUM_BINS);
+
+ unsigned int cont_bin;
+ if (power_feature_workgroups != 0) {
+ cont_bin = ltid ? local_cnt[ltid] : old_val_cont_bin0;
+ } else {
+ cont_bin = local_cnt[ltid];
+ }
+ uint16_t i;
+
+ if (power_feature_workgroups != 0) {
+ // add all sub-histograms for feature
+ const acc_type* __restrict__ p = feature_sub_hist + ltid;
+ for (i = 0; i < skip_id; ++i) {
+ grad_bin += *p; p += NUM_BINS;
+ hess_bin += *p; p += NUM_BINS;
+ cont_bin += as_acc_int_type(*p); p += NUM_BINS;
+ }
+
+ // skip the counters we already have
+ p += 3 * NUM_BINS;
+
+ for (i = i + 1; i < num_sub_hist; ++i) {
+ grad_bin += *p; p += NUM_BINS;
+ hess_bin += *p; p += NUM_BINS;
+ cont_bin += as_acc_int_type(*p); p += NUM_BINS;
+ }
+ }
+ __syncthreads();
+
+ output_buf[ltid * 2 + 0] = grad_bin;
+ output_buf[ltid * 2 + 1] = hess_bin;
+}
+
+#if USE_CONSTANT_BUF == 1
+__kernel void KERNEL_NAME(__global const uchar* restrict feature_data_base,
+ __constant const uchar* restrict feature_masks __attribute__((max_constant_size(65536))),
+ const data_size_t feature_size,
+ __constant const data_size_t* restrict data_indices __attribute__((max_constant_size(65536))),
+ const data_size_t num_data,
+ __constant const score_t* restrict ordered_gradients __attribute__((max_constant_size(65536))),
+#if CONST_HESSIAN == 0
+ __constant const score_t* restrict ordered_hessians __attribute__((max_constant_size(65536))),
+#else
+ const score_t const_hessian,
+#endif
+ char* __restrict__ output_buf,
+ volatile int * sync_counters,
+ acc_type* __restrict__ hist_buf_base,
+ const size_t power_feature_workgroups) {
+#else
+__global__ void KERNEL_NAME(const uchar* feature_data_base,
+ const uchar* __restrict__ feature_masks,
+ const data_size_t feature_size,
+ const data_size_t* data_indices,
+ const data_size_t num_data,
+ const score_t* ordered_gradients,
+#if CONST_HESSIAN == 0
+ const score_t* ordered_hessians,
+#else
+ const score_t const_hessian,
+#endif
+ char* __restrict__ output_buf,
+ volatile int * sync_counters,
+ acc_type* __restrict__ hist_buf_base,
+ const size_t power_feature_workgroups) {
+#endif
+ // allocate the local memory array aligned with float2, to guarantee correct alignment on NVIDIA platforms
+ // otherwise a "Misaligned Address" exception may occur
+ __shared__ float2 shared_array[LOCAL_MEM_SIZE/sizeof(float2)];
+ const unsigned int gtid = blockIdx.x * blockDim.x + threadIdx.x;
+ const uint16_t ltid = threadIdx.x;
+ const uint16_t lsize = NUM_BINS; // get_local_size(0);
+ const uint16_t group_id = blockIdx.x;
+
+ // clear local (shared) memory: LOCAL_MEM_SIZE bytes per workgroup
+ unsigned int *ptr = reinterpret_cast<unsigned int *>(shared_array);
+ for (int i = ltid; i < LOCAL_MEM_SIZE/sizeof(unsigned int); i += lsize) {
+ ptr[i] = 0;
+ }
+ __syncthreads();
+ // gradient/hessian histograms
+ // assumed to start at a 32 * 4 = 128-byte boundary
+ // total size: 2 * NUM_BINS * sizeof(acc_type)
+ // organization: each feature/grad/hessian is at a different bank,
+ // as independent of the feature value as possible
+ acc_type *gh_hist = reinterpret_cast<acc_type *>(shared_array);
+
+ // counter histogram
+ // total size: NUM_BINS * sizeof(unsigned int)
+ unsigned int *cnt_hist = reinterpret_cast<unsigned int *>(gh_hist + 2 * NUM_BINS);
+
+ // odd threads (1, 3, ...) compute histograms for hessians first
+ // even thread (0, 2, ...) compute histograms for gradients first
+ // etc.
+ uchar is_hessian_first = ltid & 1;
+
+ uint16_t feature_id = group_id >> power_feature_workgroups;
+
+ // each 2^POWER_FEATURE_WORKGROUPS workgroups process on one feature (compile-time constant)
+ // feature_size is the number of examples per feature
+ const uchar *feature_data = feature_data_base + feature_id * feature_size;
+
+ // size of threads that process this feature4
+ const unsigned int subglobal_size = lsize * (1 << power_feature_workgroups);
+
+ // equivalent thread ID in this subgroup for this feature4
+ const unsigned int subglobal_tid = gtid - feature_id * subglobal_size;
+
+
+ data_size_t ind;
+ data_size_t ind_next;
+ #ifdef IGNORE_INDICES
+ ind = subglobal_tid;
+ #else
+ ind = data_indices[subglobal_tid];
+ #endif
+
+ // extract feature mask, when a byte is set to 0, that feature is disabled
+ uchar feature_mask = feature_masks[feature_id];
+ // exit if the feature is masked
+ if (!feature_mask) {
+ return;
+ } else {
+ feature_mask = feature_mask - 1; // now 1 for 4-bit features, 0 for 8-bit features
+ }
+
+ // STAGE 1: read feature data, and gradient and hessian
+ // first half of the threads read feature data from global memory
+ // We will prefetch data into the "next" variable at the beginning of each iteration
+ uchar feature;
+ uchar feature_next;
+ uint16_t bin;
+
+ feature = feature_data[ind >> feature_mask];
+ if (feature_mask) {
+ feature = (feature >> ((ind & 1) << 2)) & 0xf;
+ }
+ bin = feature;
+ acc_type grad_bin = 0.0f, hess_bin = 0.0f;
+ acc_type *addr_bin;
+
+ // store gradient and hessian
+ score_t grad, hess;
+ score_t grad_next, hess_next;
+ grad = ordered_gradients[ind];
+ #if CONST_HESSIAN == 0
+ hess = ordered_hessians[ind];
+ #endif
+
+ // there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4
+ for (unsigned int i = subglobal_tid; i < num_data; i += subglobal_size) {
+ // prefetch the next iteration variables
+ // we don't need a boundary check because we have made the buffer large enough
+ int i_next = i + subglobal_size;
+ #ifdef IGNORE_INDICES
+ // we need to check the bounds here
+ ind_next = i_next < num_data ? i_next : i;
+ #else
+ ind_next = data_indices[i_next];
+ #endif
+
+ grad_next = ordered_gradients[ind_next];
+ #if CONST_HESSIAN == 0
+ hess_next = ordered_hessians[ind_next];
+ #endif
+
+ // STAGE 2: accumulate gradient and hessian
+ if (bin != feature) {
+ addr_bin = gh_hist + bin * 2 + is_hessian_first;
+ #if CONST_HESSIAN == 0
+ acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin;
+ atomic_local_add_f(addr_bin, acc_bin);
+
+ addr_bin = addr_bin + 1 - 2 * is_hessian_first;
+ acc_bin = is_hessian_first ? grad_bin : hess_bin;
+ atomic_local_add_f(addr_bin, acc_bin);
+
+ #elif CONST_HESSIAN == 1
+ atomic_local_add_f(addr_bin, grad_bin);
+ #endif
+
+ bin = feature;
+ grad_bin = grad;
+ hess_bin = hess;
+ } else {
+ grad_bin += grad;
+ hess_bin += hess;
+ }
+
+ // prefetch the next iteration variables
+ feature_next = feature_data[ind_next >> feature_mask];
+
+ // STAGE 3: accumulate counter
+ atomicAdd(cnt_hist + feature, 1);
+
+ // STAGE 4: update next stat
+ grad = grad_next;
+ hess = hess_next;
+ if (!feature_mask) {
+ feature = feature_next;
+ } else {
+ feature = (feature_next >> ((ind_next & 1) << 2)) & 0xf;
+ }
+ }
+
+
+ addr_bin = gh_hist + bin * 2 + is_hessian_first;
+ #if CONST_HESSIAN == 0
+ acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin;
+ atomic_local_add_f(addr_bin, acc_bin);
+
+ addr_bin = addr_bin + 1 - 2 * is_hessian_first;
+ acc_bin = is_hessian_first ? grad_bin : hess_bin;
+ atomic_local_add_f(addr_bin, acc_bin);
+
+ #elif CONST_HESSIAN == 1
+ atomic_local_add_f(addr_bin, grad_bin);
+ #endif
+ __syncthreads();
+
+ #if CONST_HESSIAN == 1
+ // make a final reduction
+ gh_hist[ltid * 2] += gh_hist[ltid * 2 + 1];
+ gh_hist[ltid * 2 + 1] = const_hessian * cnt_hist[ltid]; // counter move to this position
+ __syncthreads();
+ #endif
+
+#if POWER_FEATURE_WORKGROUPS != 0
+ acc_type *__restrict__ output = reinterpret_cast<acc_type *>(output_buf) + group_id * 3 * NUM_BINS;
+ // write gradients and hessians
+ acc_type *__restrict__ ptr_f = output;
+ for (uint16_t i = ltid; i < 2 * NUM_BINS; i += lsize) {
+ // even threads read gradients, odd threads read hessians
+ acc_type value = gh_hist[i];
+ ptr_f[(i & 1) * NUM_BINS + (i >> 1)] = value;
+ }
+ // write counts
+ acc_int_type *__restrict__ ptr_i = reinterpret_cast<acc_int_type *>(output + 2 * NUM_BINS);
+ for (uint16_t i = ltid; i < NUM_BINS; i += lsize) {
+ unsigned int value = cnt_hist[i];
+ ptr_i[i] = value;
+ }
+ __syncthreads();
+ __threadfence();
+ unsigned int * counter_val = cnt_hist;
+ // backup the old value
+ unsigned int old_val = *counter_val;
+ if (ltid == 0) {
+ // all workgroups processing the same feature add this counter
+ *counter_val = atomicAdd(const_cast<int *>(sync_counters + feature_id), 1);
+ }
+ // make sure everyone in this workgroup is here
+ __syncthreads();
+ // everyone in this workgroup: if we are the last workgroup, then do reduction!
+ if (*counter_val == (1 << power_feature_workgroups) - 1) {
+ if (ltid == 0) {
+ sync_counters[feature_id] = 0;
+ }
+#else
+ }
+ // only 1 work group, no need to increase counter
+ // the reduction will become a simple copy
+ {
+ unsigned int old_val; // dummy
+#endif
+ // locate our feature's block in output memory
+ unsigned int output_offset = (feature_id << power_feature_workgroups);
+ acc_type const * __restrict__ feature_subhists =
+ reinterpret_cast<const acc_type *>(output_buf) + output_offset * 3 * NUM_BINS;
+ // skip reading the data already in local memory
+ unsigned int skip_id = group_id - output_offset;
+ // locate output histogram location for this feature4
+ acc_type *__restrict__ hist_buf = hist_buf_base + feature_id * 2 * NUM_BINS;
+
+ within_kernel_reduction16x4(feature_subhists, skip_id, old_val, 1 << power_feature_workgroups, hist_buf, reinterpret_cast<acc_type *>(shared_array), power_feature_workgroups);
+ }
+}
+
+// end of histogram16 stuff
+
+// histogram64 stuff
+#undef KERNEL_NAME
+#undef NUM_BINS
+#undef LOCAL_MEM_SIZE
+#ifdef ENABLE_ALL_FEATURES
+#ifdef IGNORE_INDICES
+#define KERNEL_NAME histogram64_fulldata
+#else // IGNORE_INDICES
+#define KERNEL_NAME histogram64 // ENABLE_ALL_FEATURES is set to 1 in the header if it is not defined
+// #define KERNEL_NAME histogram64_allfeats
+#endif // IGNORE_INDICES
+#else // ENABLE_ALL_FEATURES
+#error "ENABLE_ALL_FEATURES should always be 1"
+#define KERNEL_NAME histogram64
+#endif // ENABLE_ALL_FEATURES
+#define NUM_BINS 64
+#define LOCAL_MEM_SIZE ((sizeof(unsigned int) + 2 * sizeof(acc_type)) * NUM_BINS)
+
+// this function will be called by histogram64
+// we have one sub-histogram of one feature in local memory, and need to read others
+inline void __device__ within_kernel_reduction64x4(const acc_type* __restrict__ feature_sub_hist,
+ const unsigned int skip_id,
+ const unsigned int old_val_cont_bin0,
+ const uint16_t num_sub_hist,
+ acc_type* __restrict__ output_buf,
+ acc_type* __restrict__ local_hist,
+ const size_t power_feature_workgroups) {
+ const uint16_t ltid = threadIdx.x;
+ acc_type grad_bin = local_hist[ltid * 2];
+ acc_type hess_bin = local_hist[ltid * 2 + 1];
+ unsigned int* __restrict__ local_cnt = reinterpret_cast<unsigned int *>(local_hist + 2 * NUM_BINS);
+
+ unsigned int cont_bin;
+ if (power_feature_workgroups != 0) {
+ cont_bin = ltid ? local_cnt[ltid] : old_val_cont_bin0;
+ } else {
+ cont_bin = local_cnt[ltid];
+ }
+ uint16_t i;
+
+ if (power_feature_workgroups != 0) {
+ // add all sub-histograms for feature
+ const acc_type* __restrict__ p = feature_sub_hist + ltid;
+ for (i = 0; i < skip_id; ++i) {
+ grad_bin += *p; p += NUM_BINS;
+ hess_bin += *p; p += NUM_BINS;
+ cont_bin += as_acc_int_type(*p); p += NUM_BINS;
+ }
+
+ // skip the counters we already have
+ p += 3 * NUM_BINS;
+
+ for (i = i + 1; i < num_sub_hist; ++i) {
+ grad_bin += *p; p += NUM_BINS;
+ hess_bin += *p; p += NUM_BINS;
+ cont_bin += as_acc_int_type(*p); p += NUM_BINS;
+ }
+ }
+ __syncthreads();
+
+ output_buf[ltid * 2 + 0] = grad_bin;
+ output_buf[ltid * 2 + 1] = hess_bin;
+}
+
+#if USE_CONSTANT_BUF == 1
+__kernel void KERNEL_NAME(__global const uchar* restrict feature_data_base,
+ __constant const uchar* restrict feature_masks __attribute__((max_constant_size(65536))),
+ const data_size_t feature_size,
+ __constant const data_size_t* restrict data_indices __attribute__((max_constant_size(65536))),
+ const data_size_t num_data,
+ __constant const score_t* restrict ordered_gradients __attribute__((max_constant_size(65536))),
+#if CONST_HESSIAN == 0
+ __constant const score_t* restrict ordered_hessians __attribute__((max_constant_size(65536))),
+#else
+ const score_t const_hessian,
+#endif
+ char* __restrict__ output_buf,
+ volatile int * sync_counters,
+ acc_type* __restrict__ hist_buf_base,
+ const size_t power_feature_workgroups) {
+#else
+__global__ void KERNEL_NAME(const uchar* feature_data_base,
+ const uchar* __restrict__ feature_masks,
+ const data_size_t feature_size,
+ const data_size_t* data_indices,
+ const data_size_t num_data,
+ const score_t* ordered_gradients,
+#if CONST_HESSIAN == 0
+ const score_t* ordered_hessians,
+#else
+ const score_t const_hessian,
+#endif
+ char* __restrict__ output_buf,
+ volatile int * sync_counters,
+ acc_type* __restrict__ hist_buf_base,
+ const size_t power_feature_workgroups) {
+#endif
+ // allocate the local memory array aligned with float2, to guarantee correct alignment on NVIDIA platforms
+ // otherwise a "Misaligned Address" exception may occur
+ __shared__ float2 shared_array[LOCAL_MEM_SIZE/sizeof(float2)];
+ const unsigned int gtid = blockIdx.x * blockDim.x + threadIdx.x;
+ const uint16_t ltid = threadIdx.x;
+ const uint16_t lsize = NUM_BINS; // get_local_size(0);
+ const uint16_t group_id = blockIdx.x;
+
+ // clear local (shared) memory: LOCAL_MEM_SIZE bytes per workgroup
+ unsigned int *ptr = reinterpret_cast<unsigned int *>(shared_array);
+ for (int i = ltid; i < LOCAL_MEM_SIZE/sizeof(unsigned int); i += lsize) {
+ ptr[i] = 0;
+ }
+ __syncthreads();
+ // gradient/hessian histograms
+ // assumed to start at a 32 * 4 = 128-byte boundary
+ // total size: 2 * NUM_BINS * sizeof(acc_type)
+ // organization: each feature/grad/hessian is at a different bank,
+ // as independent of the feature value as possible
+ acc_type *gh_hist = reinterpret_cast<acc_type *>(shared_array);
+
+ // counter histogram
+ // total size: NUM_BINS * sizeof(unsigned int)
+ unsigned int *cnt_hist = reinterpret_cast<unsigned int *>(gh_hist + 2 * NUM_BINS);
+
+ // odd threads (1, 3, ...) compute histograms for hessians first
+ // even thread (0, 2, ...) compute histograms for gradients first
+ // etc.
+ uchar is_hessian_first = ltid & 1;
+
+ uint16_t feature_id = group_id >> power_feature_workgroups;
+
+ // each 2^POWER_FEATURE_WORKGROUPS workgroups process on one feature (compile-time constant)
+ // feature_size is the number of examples per feature
+ const uchar *feature_data = feature_data_base + feature_id * feature_size;
+
+ // size of threads that process this feature4
+ const unsigned int subglobal_size = lsize * (1 << power_feature_workgroups);
+
+ // equivalent thread ID in this subgroup for this feature4
+ const unsigned int subglobal_tid = gtid - feature_id * subglobal_size;
+
+ data_size_t ind;
+ data_size_t ind_next;
+ #ifdef IGNORE_INDICES
+ ind = subglobal_tid;
+ #else
+ ind = data_indices[subglobal_tid];
+ #endif
+
+ // extract feature mask, when a byte is set to 0, that feature is disabled
+ uchar feature_mask = feature_masks[feature_id];
+ // exit if the feature is masked
+ if (!feature_mask) {
+ return;
+ } else {
+ feature_mask = feature_mask - 1; // now 1 for 4-bit features, 0 for 8-bit features
+ }
+
+ // STAGE 1: read feature data, and gradient and hessian
+ // first half of the threads read feature data from global memory
+ // We will prefetch data into the "next" variable at the beginning of each iteration
+ uchar feature;
+ uchar feature_next;
+ uint16_t bin;
+
+ feature = feature_data[ind >> feature_mask];
+ if (feature_mask) {
+ feature = (feature >> ((ind & 1) << 2)) & 0xf;
+ }
+ bin = feature;
+ acc_type grad_bin = 0.0f, hess_bin = 0.0f;
+ acc_type *addr_bin;
+
+ // store gradient and hessian
+ score_t grad, hess;
+ score_t grad_next, hess_next;
+ grad = ordered_gradients[ind];
+ #if CONST_HESSIAN == 0
+ hess = ordered_hessians[ind];
+ #endif
+
+ // there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4
+ for (unsigned int i = subglobal_tid; i < num_data; i += subglobal_size) {
+ // prefetch the next iteration variables
+ // we don't need a boundary check because we have made the buffer large enough
+ int i_next = i + subglobal_size;
+ #ifdef IGNORE_INDICES
+ // we need to check the bounds here
+ ind_next = i_next < num_data ? i_next : i;
+ #else
+ ind_next = data_indices[i_next];
+ #endif
+
+ grad_next = ordered_gradients[ind_next];
+ #if CONST_HESSIAN == 0
+ hess_next = ordered_hessians[ind_next];
+ #endif
+
+ // STAGE 2: accumulate gradient and hessian
+ if (bin != feature) {
+ addr_bin = gh_hist + bin * 2 + is_hessian_first;
+ #if CONST_HESSIAN == 0
+ acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin;
+ atomic_local_add_f(addr_bin, acc_bin);
+
+ addr_bin = addr_bin + 1 - 2 * is_hessian_first;
+ acc_bin = is_hessian_first ? grad_bin : hess_bin;
+ atomic_local_add_f(addr_bin, acc_bin);
+
+ #elif CONST_HESSIAN == 1
+ atomic_local_add_f(addr_bin, grad_bin);
+ #endif
+
+ bin = feature;
+ grad_bin = grad;
+ hess_bin = hess;
+ } else {
+ grad_bin += grad;
+ hess_bin += hess;
+ }
+
+ // prefetch the next iteration variables
+ feature_next = feature_data[ind_next >> feature_mask];
+
+ // STAGE 3: accumulate counter
+ atomicAdd(cnt_hist + feature, 1);
+
+ // STAGE 4: update next stat
+ grad = grad_next;
+ hess = hess_next;
+ if (!feature_mask) {
+ feature = feature_next;
+ } else {
+ feature = (feature_next >> ((ind_next & 1) << 2)) & 0xf;
+ }
+ }
+
+ addr_bin = gh_hist + bin * 2 + is_hessian_first;
+ #if CONST_HESSIAN == 0
+ acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin;
+ atomic_local_add_f(addr_bin, acc_bin);
+
+ addr_bin = addr_bin + 1 - 2 * is_hessian_first;
+ acc_bin = is_hessian_first ? grad_bin : hess_bin;
+ atomic_local_add_f(addr_bin, acc_bin);
+
+ #elif CONST_HESSIAN == 1
+ atomic_local_add_f(addr_bin, grad_bin);
+ #endif
+ __syncthreads();
+
+ #if CONST_HESSIAN == 1
+ // make a final reduction
+ gh_hist[ltid * 2] += gh_hist[ltid * 2 + 1];
+ gh_hist[ltid * 2 + 1] = const_hessian * cnt_hist[ltid]; // counter move to this position
+ __syncthreads();
+ #endif
+
+#if POWER_FEATURE_WORKGROUPS != 0
+ acc_type *__restrict__ output = reinterpret_cast<acc_type *>(output_buf) + group_id * 3 * NUM_BINS;
+ // write gradients and hessians
+ acc_type *__restrict__ ptr_f = output;
+ for (uint16_t i = ltid; i < 2 * NUM_BINS; i += lsize) {
+ // even threads read gradients, odd threads read hessians
+ acc_type value = gh_hist[i];
+ ptr_f[(i & 1) * NUM_BINS + (i >> 1)] = value;
+ }
+ // write counts
+ acc_int_type *__restrict__ ptr_i = reinterpret_cast<acc_int_type *>(output + 2 * NUM_BINS);
+ for (uint16_t i = ltid; i < NUM_BINS; i += lsize) {
+ unsigned int value = cnt_hist[i];
+ ptr_i[i] = value;
+ }
+ __syncthreads();
+ __threadfence();
+ unsigned int * counter_val = cnt_hist;
+ // backup the old value
+ unsigned int old_val = *counter_val;
+ if (ltid == 0) {
+ // all workgroups processing the same feature add this counter
+ *counter_val = atomicAdd(const_cast<int *>(sync_counters + feature_id), 1);
+ }
+ // make sure everyone in this workgroup is here
+ __syncthreads();
+ // everyone in this workgroup: if we are the last workgroup, then do reduction!
+ if (*counter_val == (1 << power_feature_workgroups) - 1) {
+ if (ltid == 0) {
+ sync_counters[feature_id] = 0;
+ }
+#else
+ }
+ // only 1 work group, no need to increase counter
+ // the reduction will become a simple copy
+ {
+ unsigned int old_val; // dummy
+#endif
+ // locate our feature's block in output memory
+ unsigned int output_offset = (feature_id << power_feature_workgroups);
+ acc_type const * __restrict__ feature_subhists =
+ reinterpret_cast<const acc_type *>(output_buf) + output_offset * 3 * NUM_BINS;
+ // skip reading the data already in local memory
+ unsigned int skip_id = group_id - output_offset;
+ // locate output histogram location for this feature4
+ acc_type *__restrict__ hist_buf = hist_buf_base + feature_id * 2 * NUM_BINS;
+
+ within_kernel_reduction64x4(feature_subhists, skip_id, old_val, 1 << power_feature_workgroups, hist_buf, reinterpret_cast<acc_type *>(shared_array), power_feature_workgroups);
+ }
+}
+
+// end of histogram64 stuff
+
+// histogram256 stuff
+#undef KERNEL_NAME
+#undef NUM_BINS
+#undef LOCAL_MEM_SIZE
+#ifdef ENABLE_ALL_FEATURES
+#ifdef IGNORE_INDICES
+#define KERNEL_NAME histogram256_fulldata
+#else // IGNORE_INDICES
+#define KERNEL_NAME histogram256 // ENABLE_ALL_FEATURES is set to 1 in the header if it is not defined
+// #define KERNEL_NAME histogram256_allfeats
+#endif // IGNORE_INDICES
+#else // ENABLE_ALL_FEATURES
+#error "ENABLE_ALL_FEATURES should always be 1"
+#define KERNEL_NAME histogram256
+#endif // ENABLE_ALL_FEATURES
+#define NUM_BINS 256
+#define LOCAL_MEM_SIZE ((sizeof(unsigned int) + 2 * sizeof(acc_type)) * NUM_BINS)
+
+// this function will be called by histogram256
+// we have one sub-histogram of one feature in local memory, and need to read others
+inline void __device__ within_kernel_reduction256x4(const acc_type* __restrict__ feature_sub_hist,
+ const unsigned int skip_id,
+ const unsigned int old_val_cont_bin0,
+ const uint16_t num_sub_hist,
+ acc_type* __restrict__ output_buf,
+ acc_type* __restrict__ local_hist,
+ const size_t power_feature_workgroups) {
+ const uint16_t ltid = threadIdx.x;
+ acc_type grad_bin = local_hist[ltid * 2];
+ acc_type hess_bin = local_hist[ltid * 2 + 1];
+ unsigned int* __restrict__ local_cnt = reinterpret_cast<unsigned int *>(local_hist + 2 * NUM_BINS);
+
+ unsigned int cont_bin;
+ if (power_feature_workgroups != 0) {
+ cont_bin = ltid ? local_cnt[ltid] : old_val_cont_bin0;
+ } else {
+ cont_bin = local_cnt[ltid];
+ }
+ uint16_t i;
+
+ if (power_feature_workgroups != 0) {
+ // add all sub-histograms for feature
+ const acc_type* __restrict__ p = feature_sub_hist + ltid;
+ for (i = 0; i < skip_id; ++i) {
+ grad_bin += *p; p += NUM_BINS;
+ hess_bin += *p; p += NUM_BINS;
+ cont_bin += as_acc_int_type(*p); p += NUM_BINS;
+ }
+
+ // skip the counters we already have
+ p += 3 * NUM_BINS;
+
+ for (i = i + 1; i < num_sub_hist; ++i) {
+ grad_bin += *p; p += NUM_BINS;
+ hess_bin += *p; p += NUM_BINS;
+ cont_bin += as_acc_int_type(*p); p += NUM_BINS;
+ }
+ }
+
+ __syncthreads();
+
+ output_buf[ltid * 2 + 0] = grad_bin;
+ output_buf[ltid * 2 + 1] = hess_bin;
+}
+
+#if USE_CONSTANT_BUF == 1
+__kernel void KERNEL_NAME(__global const uchar* restrict feature_data_base,
+ __constant const uchar* restrict feature_masks __attribute__((max_constant_size(65536))),
+ const data_size_t feature_size,
+ __constant const data_size_t* restrict data_indices __attribute__((max_constant_size(65536))),
+ const data_size_t num_data,
+ __constant const score_t* restrict ordered_gradients __attribute__((max_constant_size(65536))),
+#if CONST_HESSIAN == 0
+ __constant const score_t* restrict ordered_hessians __attribute__((max_constant_size(65536))),
+#else
+ const score_t const_hessian,
+#endif
+ char* __restrict__ output_buf,
+ volatile int * sync_counters,
+ acc_type* __restrict__ hist_buf_base,
+ const size_t power_feature_workgroups) {
+#else
+__global__ void KERNEL_NAME(const uchar* feature_data_base,
+ const uchar* __restrict__ feature_masks,
+ const data_size_t feature_size,
+ const data_size_t* data_indices,
+ const data_size_t num_data,
+ const score_t* ordered_gradients,
+#if CONST_HESSIAN == 0
+ const score_t* ordered_hessians,
+#else
+ const score_t const_hessian,
+#endif
+ char* __restrict__ output_buf,
+ volatile int * sync_counters,
+ acc_type* __restrict__ hist_buf_base,
+ const size_t power_feature_workgroups) {
+#endif
+ // allocate the local memory array aligned with float2, to guarantee correct alignment on NVIDIA platforms
+ // otherwise a "Misaligned Address" exception may occur
+ __shared__ float2 shared_array[LOCAL_MEM_SIZE/sizeof(float2)];
+ const unsigned int gtid = blockIdx.x * blockDim.x + threadIdx.x;
+ const uint16_t ltid = threadIdx.x;
+ const uint16_t lsize = NUM_BINS; // get_local_size(0);
+ const uint16_t group_id = blockIdx.x;
+
+ // clear local (shared) memory: LOCAL_MEM_SIZE bytes per workgroup
+ unsigned int *ptr = reinterpret_cast<unsigned int *>(shared_array);
+ for (int i = ltid; i < LOCAL_MEM_SIZE/sizeof(unsigned int); i += lsize) {
+ ptr[i] = 0;
+ }
+ __syncthreads();
+ // gradient/hessian histograms
+ // assumed to start at a 32 * 4 = 128-byte boundary
+ // total size: 2 * NUM_BINS * sizeof(acc_type)
+ // organization: each feature/grad/hessian is at a different bank,
+ // as independent of the feature value as possible
+ acc_type *gh_hist = reinterpret_cast<acc_type *>(shared_array);
+
+ // counter histogram
+ // total size: NUM_BINS * sizeof(unsigned int)
+ unsigned int *cnt_hist = reinterpret_cast<unsigned int *>(gh_hist + 2 * NUM_BINS);
+
+ // odd threads (1, 3, ...) compute histograms for hessians first
+ // even thread (0, 2, ...) compute histograms for gradients first
+ // etc.
+ uchar is_hessian_first = ltid & 1;
+
+ uint16_t feature_id = group_id >> power_feature_workgroups;
+
+ // each 2^POWER_FEATURE_WORKGROUPS workgroups process on one feature (compile-time constant)
+ // feature_size is the number of examples per feature
+ const uchar *feature_data = feature_data_base + feature_id * feature_size;
+
+ // size of threads that process this feature4
+ const unsigned int subglobal_size = lsize * (1 << power_feature_workgroups);
+
+ // equivalent thread ID in this subgroup for this feature4
+ const unsigned int subglobal_tid = gtid - feature_id * subglobal_size;
+
+ data_size_t ind;
+ data_size_t ind_next;
+ #ifdef IGNORE_INDICES
+ ind = subglobal_tid;
+ #else
+ ind = data_indices[subglobal_tid];
+ #endif
+
+ // extract feature mask, when a byte is set to 0, that feature is disabled
+ uchar feature_mask = feature_masks[feature_id];
+ // exit if the feature is masked
+ if (!feature_mask) {
+ return;
+ } else {
+ feature_mask = feature_mask - 1; // now 1 for 4-bit features, 0 for 8-bit features
+ }
+
+ // STAGE 1: read feature data, and gradient and hessian
+ // first half of the threads read feature data from global memory
+ // We will prefetch data into the "next" variable at the beginning of each iteration
+ uchar feature;
+ uchar feature_next;
+ uint16_t bin;
+
+ feature = feature_data[ind >> feature_mask];
+ if (feature_mask) {
+ feature = (feature >> ((ind & 1) << 2)) & 0xf;
+ }
+ bin = feature;
+ acc_type grad_bin = 0.0f, hess_bin = 0.0f;
+ acc_type *addr_bin;
+
+ // store gradient and hessian
+ score_t grad, hess;
+ score_t grad_next, hess_next;
+ grad = ordered_gradients[ind];
+ #if CONST_HESSIAN == 0
+ hess = ordered_hessians[ind];
+ #endif
+
+ // there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4
+ for (unsigned int i = subglobal_tid; i < num_data; i += subglobal_size) {
+ // prefetch the next iteration variables
+ // we don't need a boundary check because we have made the buffer large enough
+ int i_next = i + subglobal_size;
+ #ifdef IGNORE_INDICES
+ // we need to check the bounds here
+ ind_next = i_next < num_data ? i_next : i;
+ #else
+ ind_next = data_indices[i_next];
+ #endif
+
+ grad_next = ordered_gradients[ind_next];
+ #if CONST_HESSIAN == 0
+ hess_next = ordered_hessians[ind_next];
+ #endif
+ // STAGE 2: accumulate gradient and hessian
+ if (bin != feature) {
+ addr_bin = gh_hist + bin * 2 + is_hessian_first;
+ #if CONST_HESSIAN == 0
+ acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin;
+ atomic_local_add_f(addr_bin, acc_bin);
+
+ addr_bin = addr_bin + 1 - 2 * is_hessian_first;
+ acc_bin = is_hessian_first ? grad_bin : hess_bin;
+ atomic_local_add_f(addr_bin, acc_bin);
+
+ #elif CONST_HESSIAN == 1
+ atomic_local_add_f(addr_bin, grad_bin);
+ #endif
+
+ bin = feature;
+ grad_bin = grad;
+ hess_bin = hess;
+ } else {
+ grad_bin += grad;
+ hess_bin += hess;
+ }
+
+ // prefetch the next iteration variables
+ feature_next = feature_data[ind_next >> feature_mask];
+
+ // STAGE 3: accumulate counter
+ atomicAdd(cnt_hist + feature, 1);
+
+ // STAGE 4: update next stat
+ grad = grad_next;
+ hess = hess_next;
+ if (!feature_mask) {
+ feature = feature_next;
+ } else {
+ feature = (feature_next >> ((ind_next & 1) << 2)) & 0xf;
+ }
+ }
+
+ addr_bin = gh_hist + bin * 2 + is_hessian_first;
+ #if CONST_HESSIAN == 0
+ acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin;
+ atomic_local_add_f(addr_bin, acc_bin);
+
+ addr_bin = addr_bin + 1 - 2 * is_hessian_first;
+ acc_bin = is_hessian_first ? grad_bin : hess_bin;
+
+ atomic_local_add_f(addr_bin, acc_bin);
+
+ #elif CONST_HESSIAN == 1
+ atomic_local_add_f(addr_bin, grad_bin);
+ #endif
+ __syncthreads();
+
+ #if CONST_HESSIAN == 1
+ // make a final reduction
+ gh_hist[ltid * 2] += gh_hist[ltid * 2 + 1];
+ gh_hist[ltid * 2 + 1] = const_hessian * cnt_hist[ltid]; // counter move to this position
+ __syncthreads();
+ #endif
+
+#if POWER_FEATURE_WORKGROUPS != 0
+ acc_type *__restrict__ output = reinterpret_cast<acc_type *>(output_buf) + group_id * 3 * NUM_BINS;
+ // write gradients and hessians
+ acc_type *__restrict__ ptr_f = output;
+ for (uint16_t i = ltid; i < 2 * NUM_BINS; i += lsize) {
+ // even threads read gradients, odd threads read hessians
+ acc_type value = gh_hist[i];
+ ptr_f[(i & 1) * NUM_BINS + (i >> 1)] = value;
+ }
+ // write counts
+ acc_int_type *__restrict__ ptr_i = reinterpret_cast<acc_int_type *>(output + 2 * NUM_BINS);
+ for (uint16_t i = ltid; i < NUM_BINS; i += lsize) {
+ unsigned int value = cnt_hist[i];
+ ptr_i[i] = value;
+ }
+ __syncthreads();
+ __threadfence();
+ unsigned int * counter_val = cnt_hist;
+ // backup the old value
+ unsigned int old_val = *counter_val;
+ if (ltid == 0) {
+ // all workgroups processing the same feature add this counter
+ *counter_val = atomicAdd(const_cast<int *>(sync_counters + feature_id), 1);
+ }
+ // make sure everyone in this workgroup is here
+ __syncthreads();
+ // everyone in this workgroup: if we are the last workgroup, then do reduction!
+ if (*counter_val == (1 << power_feature_workgroups) - 1) {
+ if (ltid == 0) {
+ sync_counters[feature_id] = 0;
+ }
+#else
+ }
+ // only 1 work group, no need to increase counter
+ // the reduction will become a simple copy
+ {
+ unsigned int old_val; // dummy
+#endif
+ // locate our feature's block in output memory
+ unsigned int output_offset = (feature_id << power_feature_workgroups);
+ acc_type const * __restrict__ feature_subhists =
+ reinterpret_cast<const acc_type *>(output_buf) + output_offset * 3 * NUM_BINS;
+ // skip reading the data already in local memory
+ unsigned int skip_id = group_id - output_offset;
+ // locate output histogram location for this feature4
+ acc_type *__restrict__ hist_buf = hist_buf_base + feature_id * 2 * NUM_BINS;
+
+ within_kernel_reduction256x4(feature_subhists, skip_id, old_val, 1 << power_feature_workgroups, hist_buf, reinterpret_cast<acc_type *>(shared_array), power_feature_workgroups);
+ }
+}
+
+// end of histogram256 stuff
+
+} // namespace LightGBM
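Note: each kernel above coordinates its workgroups through sync_counters: every block writes its partial histogram, issues a __threadfence(), atomically increments the feature's counter, and the block that observes the final count performs the cross-block reduction. The device-side sketch below shows the same "last block finalizes" pattern on a plain sum; the names, the 256-thread block size, and the pre-allocated partials buffer are illustrative assumptions, not LightGBM's code.

#include <cuda_runtime.h>

// Launch with 256 threads per block; `partials` must hold gridDim.x floats.
__global__ void block_sum(const float* in, float* partials, float* out,
                          int* sync_counter, int n) {
  __shared__ float cache[256];
  __shared__ bool am_last;
  int tid = threadIdx.x;
  int gid = blockIdx.x * blockDim.x + tid;
  cache[tid] = (gid < n) ? in[gid] : 0.0f;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // in-block tree reduction
    if (tid < s) cache[tid] += cache[tid + s];
    __syncthreads();
  }
  if (tid == 0) {
    partials[blockIdx.x] = cache[0];
    __threadfence();                               // make the partial visible first
    int prev = atomicAdd(sync_counter, 1);
    am_last = (prev == gridDim.x - 1);             // last block to arrive
  }
  __syncthreads();
  if (am_last) {
    // only the last block sums the per-block partials
    float total = 0.0f;
    for (int b = tid; b < gridDim.x; b += blockDim.x) total += partials[b];
    cache[tid] = total;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
      if (tid < s) cache[tid] += cache[tid + s];
      __syncthreads();
    }
    if (tid == 0) { *out = cache[0]; *sync_counter = 0; }  // reset for reuse
  }
}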
diff --git a/src/treelearner/kernels/histogram_16_64_256.hu b/src/treelearner/kernels/histogram_16_64_256.hu
new file mode 100644
index 000000000000..8e3d3a5ec782
--- /dev/null
+++ b/src/treelearner/kernels/histogram_16_64_256.hu
@@ -0,0 +1,161 @@
+/*!
+ * Copyright (c) 2020 IBM Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+
+#ifndef LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
+#define LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
+
+#include "LightGBM/meta.h"
+
+namespace LightGBM {
+
+// use double precision or not
+#ifndef USE_DP_FLOAT
+#define USE_DP_FLOAT 1
+#endif
+
+// ignore hessian, and use the local memory for hessian as an additional bank for gradient
+#ifndef CONST_HESSIAN
+#define CONST_HESSIAN 0
+#endif
+
+typedef unsigned char uchar;
+
+template <typename T>
+__device__ double as_double(const T t) {
+ static_assert(sizeof(T) == sizeof(double), "size mismatch");
+ double d;
+ memcpy(&d, &t, sizeof(T));
+ return d;
+}
+template <typename T>
+__device__ unsigned long long as_ulong_ulong(const T t) {
+ static_assert(sizeof(T) == sizeof(unsigned long long), "size mismatch");
+ unsigned long long u;
+ memcpy(&u, &t, sizeof(T));
+ return u;
+}
+template <typename T>
+__device__ float as_float(const T t) {
+ static_assert(sizeof(T) == sizeof(float), "size mismatch");
+ float f;
+ memcpy(&f, &t, sizeof(T));
+ return f;
+}
+template <typename T>
+__device__ unsigned int as_uint(const T t) {
+ static_assert(sizeof(T) == sizeof(unsigned int), "size_mismatch");
+ unsigned int u;
+ memcpy(&u, &t, sizeof(T));
+ return u;
+}
+template <typename T>
+__device__ uchar4 as_uchar4(const T t) {
+ static_assert(sizeof(T) == sizeof(uchar4), "size mismatch");
+ uchar4 u;
+ memcpy(&u, &t, sizeof(T));
+ return u;
+}
+
+#if USE_DP_FLOAT == 1
+typedef double acc_type;
+typedef unsigned long long acc_int_type;
+#define as_acc_type as_double
+#define as_acc_int_type as_ulong_ulong
+#else
+typedef float acc_type;
+typedef unsigned int acc_int_type;
+#define as_acc_type as_float
+#define as_acc_int_type as_uint
+#endif
+
+// use all features and do not use feature mask
+#ifndef ENABLE_ALL_FEATURES
+#define ENABLE_ALL_FEATURES 1
+#endif
+
+// define all of the different kernels
+
+#define DECLARE_CONST_BUF(name) \
+__global__ void name(const uchar* __restrict__ feature_data_base, \
+ const uchar* __restrict__ feature_masks,\
+ const data_size_t feature_size,\
+ const data_size_t* restrict data_indices, \
+ const data_size_t num_data, \
+ const score_t* restrict ordered_gradients, \
+ const score_t* restrict ordered_hessians,\
+ char* __restrict__ output_buf,\
+ volatile int * sync_counters,\
+ acc_type* __restrict__ hist_buf_base, \
+ const size_t power_feature_workgroups);
+
+
+#define DECLARE_CONST_HES_CONST_BUF(name) \
+__global__ void name(const uchar* __restrict__ feature_data_base, \
+ const uchar* __restrict__ feature_masks,\
+ const data_size_t feature_size,\
+ const data_size_t* __restrict__ data_indices, \
+ const data_size_t num_data, \
+ const score_t* __restrict__ ordered_gradients, \
+ const score_t const_hessian,\
+ char* __restrict__ output_buf,\
+ volatile int * sync_counters,\
+ acc_type* __restrict__ hist_buf_base, \
+ const size_t power_feature_workgroups);
+
+
+
+#define DECLARE_CONST_HES(name) \
+__global__ void name(const uchar* feature_data_base, \
+ const uchar* __restrict__ feature_masks,\
+ const data_size_t feature_size,\
+ const data_size_t* data_indices, \
+ const data_size_t num_data, \
+ const score_t* ordered_gradients, \
+ const score_t const_hessian,\
+ char* __restrict__ output_buf, \
+ volatile int * sync_counters,\
+ acc_type* __restrict__ hist_buf_base, \
+ const size_t power_feature_workgroups);
+
+
+#define DECLARE(name) \
+__global__ void name(const uchar* feature_data_base, \
+ const uchar* __restrict__ feature_masks,\
+ const data_size_t feature_size,\
+ const data_size_t* data_indices, \
+ const data_size_t num_data, \
+ const score_t* ordered_gradients, \
+ const score_t* ordered_hessians,\
+ char* __restrict__ output_buf, \
+ volatile int * sync_counters,\
+ acc_type* __restrict__ hist_buf_base, \
+ const size_t power_feature_workgroups);
+
+
+DECLARE_CONST_HES(histogram16_allfeats);
+DECLARE_CONST_HES(histogram16_fulldata);
+DECLARE_CONST_HES(histogram16);
+DECLARE(histogram16_allfeats);
+DECLARE(histogram16_fulldata);
+DECLARE(histogram16);
+
+DECLARE_CONST_HES(histogram64_allfeats);
+DECLARE_CONST_HES(histogram64_fulldata);
+DECLARE_CONST_HES(histogram64);
+DECLARE(histogram64_allfeats);
+DECLARE(histogram64_fulldata);
+DECLARE(histogram64);
+
+DECLARE_CONST_HES(histogram256_allfeats);
+DECLARE_CONST_HES(histogram256_fulldata);
+DECLARE_CONST_HES(histogram256);
+DECLARE(histogram256_allfeats);
+DECLARE(histogram256_fulldata);
+DECLARE(histogram256);
+
+} // namespace LightGBM
+
+#endif // LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
+
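
The as_* templates above are bit-for-bit reinterpretation helpers (CUDA has no OpenCL-style as_double/as_uint built-ins), and acc_type/acc_int_type select the matching 64- or 32-bit pair depending on USE_DP_FLOAT. One typical use for such a pair is emulating a floating-point atomic add through atomicCAS on the integer view of the accumulator; the sketch below shows the idea with the helpers declared above. The function name atomic_add_via_cas is hypothetical and does not appear in this patch.

    // Hypothetical helper (not part of the patch): emulate atomicAdd on double via
    // atomicCAS, using the as_double / as_ulong_ulong bit-casts declared above.
    __device__ double atomic_add_via_cas(double *addr, double val) {
      unsigned long long *iaddr = reinterpret_cast<unsigned long long *>(addr);
      unsigned long long old_bits = *iaddr;
      unsigned long long assumed;
      do {
        assumed = old_bits;
        const double updated = as_double(assumed) + val;         // bits -> double, then add
        old_bits = atomicCAS(iaddr, assumed, as_ulong_ulong(updated));
      } while (assumed != old_bits);                              // retry if another thread raced us
      return as_double(old_bits);
    }

With USE_DP_FLOAT set to 0, the same shape works on float/unsigned int through the as_acc_type / as_acc_int_type aliases.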
diff --git a/src/treelearner/parallel_tree_learner.h b/src/treelearner/parallel_tree_learner.h
index 137697408e8d..2001f2e0dfeb 100644
--- a/src/treelearner/parallel_tree_learner.h
+++ b/src/treelearner/parallel_tree_learner.h
@@ -12,6 +12,7 @@
#include
#include
+#include "cuda_tree_learner.h"
#include "gpu_tree_learner.h"
#include "serial_tree_learner.h"
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index 22b353952eea..92f269304419 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -326,7 +326,16 @@ void SerialTreeLearner::FindBestSplits(const Tree* tree) {
is_feature_used[feature_index] = 1;
}
bool use_subtract = parent_leaf_histogram_array_ != nullptr;
+
+#ifdef USE_CUDA
+ if (LGBM_config_::current_learner == use_cpu_learner) {
+ SerialTreeLearner::ConstructHistograms(is_feature_used, use_subtract);
+ } else {
+ ConstructHistograms(is_feature_used, use_subtract);
+ }
+#else
ConstructHistograms(is_feature_used, use_subtract);
+#endif
FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
}
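
The USE_CUDA branch above works because SerialTreeLearner::ConstructHistograms(...) is a qualified call: it bypasses virtual dispatch and always runs the CPU implementation, while the unqualified call still resolves to the CUDA learner's override. In the patch the choice is driven by LGBM_config_::current_learner == use_cpu_learner; the sketch below illustrates the dispatch pattern with made-up class names and a plain boolean standing in for that check.

    #include <iostream>

    // Illustrative only: a qualified call (Base::f()) skips the virtual override,
    // letting a derived GPU learner fall back to the CPU path at runtime.
    struct SerialLearner {
      virtual ~SerialLearner() = default;
      virtual void ConstructHistograms() { std::cout << "CPU histograms\n"; }
      void FindBestSplits(bool force_cpu) {
        if (force_cpu) {
          SerialLearner::ConstructHistograms();  // qualified: always the CPU version
        } else {
          ConstructHistograms();                 // virtual: CUDA override if present
        }
      }
    };

    struct CudaLearner : SerialLearner {
      void ConstructHistograms() override { std::cout << "CUDA histograms\n"; }
    };

    int main() {
      CudaLearner learner;
      learner.FindBestSplits(/*force_cpu=*/true);   // prints "CPU histograms"
      learner.FindBestSplits(/*force_cpu=*/false);  // prints "CUDA histograms"
    }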
diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h
index e6ac8e3ad09c..59ba770fb95e 100644
--- a/src/treelearner/serial_tree_learner.h
+++ b/src/treelearner/serial_tree_learner.h
@@ -8,6 +8,7 @@
#include
#include
#include
+#include
#include
#include
#include
@@ -201,6 +202,11 @@ class SerialTreeLearner: public TreeLearner {
std::vector