From 9ae71a16c686b102affdebff840dd1c92586783e Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Thu, 15 Dec 2022 22:23:26 +0800
Subject: [PATCH] Define CUDA Context.

We will transition to a non-default, non-blocking CUDA stream.
---
 include/xgboost/context.h                 | 16 +++++++---
 src/common/cuda_context.cuh               | 28 +++++++++++++++++
 src/context.cc                            |  9 +++++-
 src/context.cu                            | 14 +++++++++
 src/data/data.cu                          | 38 +++++++++++++----------
 src/learner.cc                            |  2 +-
 src/tree/gpu_hist/histogram.cu            | 12 +++----
 src/tree/gpu_hist/histogram.cuh           |  9 +++---
 src/tree/updater_gpu_hist.cu              |  9 +++---
 tests/cpp/tree/gpu_hist/test_histogram.cu | 36 ++++++++++-----------
 tests/cpp/tree/test_gpu_hist.cu           |  9 +++---
 11 files changed, 120 insertions(+), 62 deletions(-)
 create mode 100644 src/common/cuda_context.cuh
 create mode 100644 src/context.cu

diff --git a/include/xgboost/context.h b/include/xgboost/context.h
index 66ad1d4bb7f7..aaa1e3eb88b3 100644
--- a/include/xgboost/context.h
+++ b/include/xgboost/context.h
@@ -8,15 +8,14 @@
 #include
 #include
+#include <memory>  // std::shared_ptr
 #include

 namespace xgboost {

-struct Context : public XGBoostParameter<Context> {
- private:
-  // cached value for CFS CPU limit. (used in containerized env)
-  std::int32_t cfs_cpu_count_;  // NOLINT
+struct CUDAContext;
+struct Context : public XGBoostParameter<Context> {
  public:
   // Constant representing the device ID of CPU.
   static std::int32_t constexpr kCpuId = -1;
@@ -51,6 +50,7 @@ struct Context : public XGBoostParameter<Context> {
   bool IsCPU() const { return gpu_id == kCpuId; }
   bool IsCUDA() const { return !IsCPU(); }
+  CUDAContext const* CUDACtx() const;

   // declare parameters
   DMLC_DECLARE_PARAMETER(Context) {
@@ -73,6 +73,14 @@ struct Context : public XGBoostParameter<Context> {
         .set_default(false)
         .describe("Enable checking whether parameters are used or not.");
   }
+
+ private:
+  // Mutable for lazy initialization of the CUDA context, so that loading the library does not
+  // initialize CUDA.  shared_ptr is used instead of unique_ptr because it is difficult to define
+  // a p_impl with unique_ptr while hiding the CUDA code from the host compiler.
+  mutable std::shared_ptr<CUDAContext> cuctx_;
+  // cached value for CFS CPU limit. (used in containerized env)
+  std::int32_t cfs_cpu_count_;  // NOLINT
 };
 }  // namespace xgboost

diff --git a/src/common/cuda_context.cuh b/src/common/cuda_context.cuh
new file mode 100644
index 000000000000..9056c1b5e032
--- /dev/null
+++ b/src/common/cuda_context.cuh
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2022 by XGBoost Contributors
+ */
+#ifndef XGBOOST_COMMON_CUDA_CONTEXT_CUH_
+#define XGBOOST_COMMON_CUDA_CONTEXT_CUH_
+#include
+
+#include "device_helpers.cuh"
+
+namespace xgboost {
+struct CUDAContext {
+ private:
+  dh::XGBCachingDeviceAllocator<char> caching_alloc_;
+  dh::XGBDeviceAllocator<char> alloc_;
+
+ public:
+  /**
+   * \brief Caching thrust policy.
+   */
+  auto CTP() const { return thrust::cuda::par(caching_alloc_).on(dh::DefaultStream()); }
+  /**
+   * \brief Thrust policy without caching allocator.
+ */ + auto TP() const { return thrust::cuda::par(alloc_).on(dh::DefaultStream()); } + auto Stream() const { return dh::DefaultStream(); } +}; +} // namespace xgboost +#endif // XGBOOST_COMMON_CUDA_CONTEXT_CUH_ diff --git a/src/context.cc b/src/context.cc index 571aa943ea04..437c16f1df7a 100644 --- a/src/context.cc +++ b/src/context.cc @@ -5,7 +5,7 @@ */ #include -#include "common/common.h" +#include "common/common.h" // AssertGPUSupport #include "common/threading_utils.h" namespace xgboost { @@ -59,4 +59,11 @@ std::int32_t Context::Threads() const { } return n_threads; } + +#if !defined(XGBOOST_USE_CUDA) +CUDAContext const* Context::CUDACtx() const { + common::AssertGPUSupport(); + return nullptr; +} +#endif // defined(XGBOOST_USE_CUDA) } // namespace xgboost diff --git a/src/context.cu b/src/context.cu new file mode 100644 index 000000000000..bc2f3714706b --- /dev/null +++ b/src/context.cu @@ -0,0 +1,14 @@ +/** + * Copyright 2022 by XGBoost Contributors + */ +#include "common/cuda_context.cuh" // CUDAContext +#include "xgboost/context.h" + +namespace xgboost { +CUDAContext const* Context::CUDACtx() const { + if (!cuctx_) { + cuctx_.reset(new CUDAContext{}); + } + return cuctx_.get(); +} +} // namespace xgboost diff --git a/src/data/data.cu b/src/data/data.cu index e983f75dc76a..4dedc7d24c4e 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -1,18 +1,19 @@ -/*! - * Copyright 2019-2021 by XGBoost Contributors +/** + * Copyright 2019-2022 by XGBoost Contributors * * \file data.cu * \brief Handles setting metainfo from array interface. */ -#include "xgboost/data.h" -#include "xgboost/logging.h" -#include "xgboost/json.h" -#include "array_interface.h" +#include "../common/cuda_context.cuh" #include "../common/device_helpers.cuh" #include "../common/linalg_op.cuh" +#include "array_interface.h" #include "device_adapter.cuh" #include "simple_dmatrix.h" #include "validation.h" +#include "xgboost/data.h" +#include "xgboost/json.h" +#include "xgboost/logging.h" namespace xgboost { namespace { @@ -25,7 +26,7 @@ auto SetDeviceToPtr(void const* ptr) { } template -void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { +void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tensor* p_out) { ArrayInterface array(arr_interface); if (array.n == 0) { p_out->SetDevice(0); @@ -43,15 +44,19 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { // set data data->Resize(array.n); dh::safe_cuda(cudaMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T), - cudaMemcpyDefault)); + cudaMemcpyDefault, ctx->Stream())); }); return; } p_out->Reshape(array.shape); auto t = p_out->View(ptr_device); - linalg::ElementWiseTransformDevice(t, [=] __device__(size_t i, T) { - return linalg::detail::Apply(TypedIndex{array}, linalg::UnravelIndex(i, array.shape)); - }); + linalg::ElementWiseTransformDevice( + t, + [=] __device__(size_t i, T) { + return linalg::detail::Apply(TypedIndex{array}, + linalg::UnravelIndex(i, array.shape)); + }, + ctx->Stream()); } void CopyGroupInfoImpl(ArrayInterface<1> column, std::vector* out) { @@ -115,14 +120,13 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector* p_ } } // namespace -// Context is not used until we have CUDA stream. 
-void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) { +void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json array) { // multi-dim float info if (key == "base_margin") { - CopyTensorInfoImpl(array, &base_margin_); + CopyTensorInfoImpl(ctx.CUDACtx(), array, &base_margin_); return; } else if (key == "label") { - CopyTensorInfoImpl(array, &labels); + CopyTensorInfoImpl(ctx.CUDACtx(), array, &labels); auto ptr = labels.Data()->ConstDevicePointer(); auto valid = thrust::none_of(thrust::device, ptr, ptr + labels.Size(), data::LabelsCheck{}); CHECK(valid) << "Label contains NaN, infinity or a value too large."; @@ -142,7 +146,7 @@ void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) { } // float info linalg::Tensor t; - CopyTensorInfoImpl(array, &t); + CopyTensorInfoImpl(ctx.CUDACtx(), array, &t); if (key == "weight") { this->weights_ = std::move(*t.Data()); auto ptr = weights_.ConstDevicePointer(); @@ -156,7 +160,7 @@ void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) { this->feature_weights = std::move(*t.Data()); auto d_feature_weights = feature_weights.ConstDeviceSpan(); auto valid = - thrust::none_of(thrust::device, d_feature_weights.data(), + thrust::none_of(ctx.CUDACtx()->CTP(), d_feature_weights.data(), d_feature_weights.data() + d_feature_weights.size(), data::WeightsCheck{}); CHECK(valid) << "Feature weight must be greater than 0."; } else { diff --git a/src/learner.cc b/src/learner.cc index 22639904980c..e6f00bffcec5 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -35,7 +35,7 @@ #include "common/version.h" #include "xgboost/base.h" #include "xgboost/c_api.h" -#include "xgboost/context.h" +#include "xgboost/context.h" // Context #include "xgboost/data.h" #include "xgboost/feature_map.h" #include "xgboost/gbm.h" diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index 650ea2b457a1..f02fb909ea75 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -267,12 +267,12 @@ __global__ void __launch_bounds__(kBlockThreads) } } -void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, +void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, common::Span gpair, common::Span d_ridx, - common::Span histogram, - GradientQuantiser rounding, bool force_global_memory) { + common::Span histogram, GradientQuantiser rounding, + bool force_global_memory) { // decide whether to use shared memory int device = 0; dh::safe_cuda(cudaGetDevice(&device)); @@ -318,9 +318,9 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, min(grid_size, unsigned(common::DivRoundUp(items_per_group, kMinItemsPerBlock))); - dh::LaunchKernel {dim3(grid_size, num_groups), - static_cast(kBlockThreads), smem_size}( - kernel, matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding); + dh::LaunchKernel{dim3(grid_size, num_groups), static_cast(kBlockThreads), smem_size, + ctx->Stream()} (kernel, matrix, feature_groups, d_ridx, histogram.data(), + gpair.data(), rounding); }; if (shared) { diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh index d2f11853e777..5c3c955d1108 100644 --- a/src/tree/gpu_hist/histogram.cuh +++ b/src/tree/gpu_hist/histogram.cuh @@ -5,9 +5,9 @@ #define HISTOGRAM_CUH_ #include -#include "feature_groups.cuh" - +#include "../../common/cuda_context.cuh" #include "../../data/ellpack_page.cuh" +#include "feature_groups.cuh" namespace 
xgboost { namespace tree { @@ -56,12 +56,11 @@ public: } }; -void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, +void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, common::Span gpair, common::Span ridx, - common::Span histogram, - GradientQuantiser rounding, + common::Span histogram, GradientQuantiser rounding, bool force_global_memory = false); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index caa33d87fd13..85371672639c 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -20,6 +20,7 @@ #include "../common/io.h" #include "../common/timer.h" #include "../data/ellpack_page.cuh" +#include "../common/cuda_context.cuh" // CUDAContext #include "constraints.cuh" #include "driver.h" #include "gpu_hist/evaluate_splits.cuh" @@ -344,9 +345,9 @@ struct GPUHistMakerDevice { void BuildHist(int nidx) { auto d_node_hist = hist.GetNodeHistogram(nidx); auto d_ridx = row_partitioner->GetRows(nidx); - BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id), - feature_groups->DeviceAccessor(ctx_->gpu_id), gpair, - d_ridx, d_node_hist, *quantiser); + BuildGradientHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->gpu_id), + feature_groups->DeviceAccessor(ctx_->gpu_id), gpair, d_ridx, d_node_hist, + *quantiser); } // Attempt to do subtraction trick @@ -646,7 +647,7 @@ struct GPUHistMakerDevice { return quantiser.ToFixedPoint(gpair); }); GradientPairInt64 root_sum_quantised = - dh::Reduce(thrust::cuda::par(alloc), gpair_it, gpair_it + gpair.size(), + dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(), GradientPairInt64{}, thrust::plus{}); using ReduceT = typename decltype(root_sum_quantised)::ValueT; collective::Allreduce( diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index 227a6d69bfb8..95fe66138333 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -11,6 +11,7 @@ namespace xgboost { namespace tree { void TestDeterministicHistogram(bool is_dense, int shm_size) { + Context ctx = CreateEmptyGenericParam(0); size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16; float constexpr kLower = -1e-2, kUpper = 1e2; @@ -34,9 +35,9 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) { sizeof(GradientPairInt64)); auto quantiser = GradientQuantiser(gpair.DeviceSpan()); - BuildGradientHistogram(page->GetDeviceAccessor(0), - feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), - ridx, d_histogram, quantiser); + BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0), + feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), ridx, d_histogram, + quantiser); std::vector histogram_h(num_bins); dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(), @@ -48,10 +49,9 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) { auto d_new_histogram = dh::ToSpan(new_histogram); auto quantiser = GradientQuantiser(gpair.DeviceSpan()); - BuildGradientHistogram(page->GetDeviceAccessor(0), - feature_groups.DeviceAccessor(0), - gpair.DeviceSpan(), ridx, d_new_histogram, - quantiser); + BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0), + feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), ridx, + d_new_histogram, quantiser); std::vector new_histogram_h(num_bins); dh::safe_cuda(cudaMemcpy(new_histogram_h.data(), d_new_histogram.data(), @@ -71,10 
+71,9 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) { FeatureGroups single_group(page->Cuts()); dh::device_vector baseline(num_bins); - BuildGradientHistogram(page->GetDeviceAccessor(0), - single_group.DeviceAccessor(0), - gpair.DeviceSpan(), ridx, dh::ToSpan(baseline), - quantiser); + BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0), + single_group.DeviceAccessor(0), gpair.DeviceSpan(), ridx, + dh::ToSpan(baseline), quantiser); std::vector baseline_h(num_bins); dh::safe_cuda(cudaMemcpy(baseline_h.data(), baseline.data().get(), @@ -115,6 +114,7 @@ void ValidateCategoricalHistogram(size_t n_categories, common::SpanGetBatches(batch_param)) { auto* page = batch.Impl(); FeatureGroups single_group(page->Cuts()); - BuildGradientHistogram(page->GetDeviceAccessor(0), - single_group.DeviceAccessor(0), - gpair.DeviceSpan(), ridx, dh::ToSpan(cat_hist), - quantiser); + BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0), + single_group.DeviceAccessor(0), gpair.DeviceSpan(), ridx, + dh::ToSpan(cat_hist), quantiser); } /** @@ -148,10 +147,9 @@ void TestGPUHistogramCategorical(size_t num_categories) { for (auto const &batch : encode_m->GetBatches(batch_param)) { auto* page = batch.Impl(); FeatureGroups single_group(page->Cuts()); - BuildGradientHistogram(page->GetDeviceAccessor(0), - single_group.DeviceAccessor(0), - gpair.DeviceSpan(), ridx, dh::ToSpan(encode_hist), - quantiser); + BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0), + single_group.DeviceAccessor(0), gpair.DeviceSpan(), ridx, + dh::ToSpan(encode_hist), quantiser); } std::vector h_cat_hist(cat_hist.size()); diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 1758c872f286..100a4c393ae2 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -109,11 +109,10 @@ void TestBuildHist(bool use_shared_memory_histograms) { maker.gpair = gpair.DeviceSpan(); maker.quantiser.reset(new GradientQuantiser(maker.gpair)); - BuildGradientHistogram( - page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0), - gpair.DeviceSpan(), maker.row_partitioner->GetRows(0), - maker.hist.GetNodeHistogram(0), *maker.quantiser, - !use_shared_memory_histograms); + BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0), + maker.feature_groups->DeviceAccessor(0), gpair.DeviceSpan(), + maker.row_partitioner->GetRows(0), maker.hist.GetNodeHistogram(0), + *maker.quantiser, !use_shared_memory_histograms); DeviceHistogramStorage<>& d_hist = maker.hist;
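Note on the design this patch introduces: CUDAContext bundles the CUDA stream (Stream(), currently still dh::DefaultStream()) with Thrust execution policies (CTP() with a caching allocator, TP() without), and Context::CUDACtx() constructs it lazily behind a shared_ptr so that loading the library never initializes the CUDA runtime, with a CPU-only stub in context.cc that calls AssertGPUSupport(). The sketch below is a minimal, self-contained illustration of that shape, not XGBoost code: DeviceCtx, HostCtx, Policy(), and SumOnDevice are invented names, and the sketch already creates the non-default, non-blocking stream that the commit message says the codebase will transition to.

// Minimal sketch (illustrative only, compiled as a .cu file): a host-side
// context lazily creates a device-side context that owns a non-blocking CUDA
// stream and exposes a Thrust execution policy bound to that stream.
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/system/cuda/execution_policy.h>

#include <memory>

#include <cuda_runtime.h>

struct DeviceCtx {
  cudaStream_t stream{nullptr};

  DeviceCtx() {
    // Non-blocking: work on this stream does not implicitly synchronize with
    // the legacy default stream.
    cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
  }
  ~DeviceCtx() { cudaStreamDestroy(stream); }

  // Thrust execution policy pinned to this stream; the patch's CTP() variant
  // additionally attaches a caching allocator for Thrust's temporary storage.
  auto Policy() const { return thrust::cuda::par.on(stream); }
};

struct HostCtx {
  // Lazily constructed, so that merely creating a HostCtx (or loading the
  // library) does not touch the CUDA runtime.
  DeviceCtx const* DeviceContext() const {
    if (!device_ctx_) {
      device_ctx_ = std::make_shared<DeviceCtx>();
    }
    return device_ctx_.get();
  }

 private:
  mutable std::shared_ptr<DeviceCtx> device_ctx_;
};

// Example call site: the algorithm runs on the context's stream rather than
// the default stream.  thrust::reduce blocks until the host-side result is ready.
double SumOnDevice(HostCtx const& ctx, thrust::device_vector<double> const& values) {
  auto const* cuctx = ctx.DeviceContext();
  return thrust::reduce(cuctx->Policy(), values.cbegin(), values.cend(), 0.0);
}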
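At the call sites the migration is mechanical: work that previously went to the default stream now takes the stream or policy from ctx.CUDACtx(), as in the cudaMemcpyAsync, thrust::none_of, dh::Reduce, and dh::LaunchKernel hunks above. Below is a hedged, self-contained sketch of that call-site shape; InvalidLabel and SetLabels are hypothetical stand-ins for data::LabelsCheck and MetaInfo::SetInfoFromCUDA, and the stream is passed in directly rather than through a context object.

// Call-site sketch (illustrative only): copy and validate labels entirely on a
// caller-provided stream instead of the default stream.
#include <thrust/device_vector.h>
#include <thrust/logical.h>
#include <thrust/system/cuda/execution_policy.h>

#include <cstddef>

#include <cuda_runtime.h>

// Predicate mirroring the "label contains NaN or infinity" validation.
struct InvalidLabel {
  __device__ bool operator()(float v) const { return isnan(v) || isinf(v); }
};

// Copies n labels from src (host or device memory) into out and validates
// them, issuing every operation on the given stream.
bool SetLabels(cudaStream_t stream, float const* src, std::size_t n,
               thrust::device_vector<float>* out) {
  out->resize(n);
  cudaMemcpyAsync(thrust::raw_pointer_cast(out->data()), src, n * sizeof(float),
                  cudaMemcpyDefault, stream);
  // none_of is enqueued on the same stream after the copy, so stream ordering
  // guarantees the copy has completed by the time the labels are checked.
  return thrust::none_of(thrust::cuda::par.on(stream), out->cbegin(), out->cend(),
                         InvalidLabel{});
}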