From b5d8fefc77922f2d1859cfe6685f2455d52bad8e Mon Sep 17 00:00:00 2001
From: Dmitry Razdoburdin <>
Date: Thu, 27 Feb 2025 02:56:23 -0800
Subject: [PATCH 1/2] fix for sycl iGPU

---
 plugin/sycl/common/transform.h                | 25 +++++++++++++------
 src/common/transform.h                        |  4 +--
 src/objective/aft_obj.cu                      |  4 +--
 src/objective/hinge.cu                        |  2 +-
 src/objective/multiclass_obj.cu               | 25 +++++++++++++------
 src/objective/regression_obj.cu               | 13 +++++-----
 src/tree/split_evaluator.h                    |  3 ++-
 tests/cpp/common/test_transform_range.cc      |  5 ++--
 tests/cpp/plugin/test_sycl_transform_range.cc |  3 ++-
 9 files changed, 54 insertions(+), 30 deletions(-)

diff --git a/plugin/sycl/common/transform.h b/plugin/sycl/common/transform.h
index 261d71f2330d..81fed3f22f6a 100644
--- a/plugin/sycl/common/transform.h
+++ b/plugin/sycl/common/transform.h
@@ -20,13 +20,24 @@ void LaunchSyclKernel(DeviceOrd device, Functor&& _func, xgboost::common::Range
   auto* qu = device_manager.GetQueue(device);
 
   size_t size = *(_range.end());
-  qu->submit([&](::sycl::handler& cgh) {
-    cgh.parallel_for<>(::sycl::range<1>(size),
-                       [=](::sycl::id<1> pid) {
-      const size_t idx = pid[0];
-      const_cast(_func)(idx, _spans...);
-    });
-  }).wait();
+  const bool has_fp64_support = qu->get_device().has(::sycl::aspect::fp64);
+  if (has_fp64_support) {
+    qu->submit([&](::sycl::handler& cgh) {
+      cgh.parallel_for<>(::sycl::range<1>(size),
+                         [=](::sycl::id<1> pid) {
+        const size_t idx = pid[0];
+        const_cast(_func)(idx, std::true_type(), _spans...);
+      });
+    }).wait();
+  } else {
+    qu->submit([&](::sycl::handler& cgh) {
+      cgh.parallel_for<>(::sycl::range<1>(size),
+                         [=](::sycl::id<1> pid) {
+        const size_t idx = pid[0];
+        const_cast(_func)(idx, std::false_type(), _spans...);
+      });
+    }).wait();
+  }
 }
 
 }  // namespace common
diff --git a/src/common/transform.h b/src/common/transform.h
index e23ffb5398c7..1699a4889000 100644
--- a/src/common/transform.h
+++ b/src/common/transform.h
@@ -37,7 +37,7 @@ template
 __global__ void LaunchCUDAKernel(Functor _func, Range _range,
                                  SpanType... _spans) {
   for (auto i : dh::GridStrideRange(*_range.begin(), *_range.end())) {
-    _func(i, _spans...);
+    _func(i, std::true_type(), _spans...);
   }
 }
 #endif  // defined(__CUDACC__)
@@ -184,7 +184,7 @@ class Transform {
   void LaunchCPU(Functor func, HDV *...vectors) const {
     omp_ulong end = static_cast(*(range_.end()));
     SyncHost(vectors...);
-    ParallelFor(end, n_threads_, [&](omp_ulong idx) { func(idx, UnpackHDV(vectors)...); });
+    ParallelFor(end, n_threads_, [&](omp_ulong idx) { func(idx, std::true_type(), UnpackHDV(vectors)...); });
   }
 
  private:
diff --git a/src/objective/aft_obj.cu b/src/objective/aft_obj.cu
index 3ad9ca847db7..a3c4fe596626 100644
--- a/src/objective/aft_obj.cu
+++ b/src/objective/aft_obj.cu
@@ -45,7 +45,7 @@ class AFTObj : public ObjFunction {
                        linalg::Matrix* out_gpair, size_t ndata, DeviceOrd device,
                        bool is_null_weight, float aft_loss_distribution_scale) {
     common::Transform<>::Init(
-        [=] XGBOOST_DEVICE(size_t _idx,
+        [=] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support,
                            common::Span _out_gpair,
                            common::Span _preds,
                            common::Span _labels_lower_bound,
@@ -104,7 +104,7 @@ class AFTObj : public ObjFunction {
   void PredTransform(HostDeviceVector *io_preds) const override {
     // Trees give us a prediction in log scale, so exponentiate
     common::Transform<>::Init(
-        [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) {
+        [] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support, common::Span _preds) {
           _preds[_idx] = exp(_preds[_idx]);
         },
         common::Range{0, static_cast(io_preds->Size())}, this->ctx_->Threads(),
diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu
index a850df09ea06..f2ef1474cc92 100644
--- a/src/objective/hinge.cu
+++ b/src/objective/hinge.cu
@@ -85,7 +85,7 @@ class HingeObj : public FitIntercept {
 
   void PredTransform(HostDeviceVector *io_preds) const override {
     common::Transform<>::Init(
-        [] XGBOOST_DEVICE(std::size_t _idx, common::Span _preds) {
+        [] XGBOOST_DEVICE(std::size_t _idx, auto has_fp64_support, common::Span _preds) {
           _preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
         },
         common::Range{0, static_cast(io_preds->Size()), 1}, this->ctx_->Threads(),
diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu
index 1a3df38841bd..0a7e3309c334 100644
--- a/src/objective/multiclass_obj.cu
+++ b/src/objective/multiclass_obj.cu
@@ -75,7 +75,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
     }
 
     common::Transform<>::Init(
-        [=] XGBOOST_DEVICE(size_t idx,
+        [=] XGBOOST_DEVICE(size_t idx, auto has_fp64_support,
                            common::Span gpair,
                            common::Span labels,
                            common::Span preds,
@@ -86,8 +86,16 @@ class SoftmaxMultiClassObj : public ObjFunction {
           // Part of Softmax function
           bst_float wmax = std::numeric_limits::min();
           for (auto const i : point) { wmax = fmaxf(i, wmax); }
-          double wsum = 0.0f;
-          for (auto const i : point) { wsum += expf(i - wmax); }
+
+          float wsum = 0.0f;
+          if constexpr (has_fp64_support) {
+            double wsum_fp64 = 0;
+            for (auto const i : point) { wsum_fp64 += expf(i - wmax); }
+            wsum = static_cast(wsum_fp64);
+          } else {
+            for (auto const i : point) { wsum += expf(i - wmax); }
+          }
+
           auto label = labels[idx];
           if (label < 0 || label >= nclass) {
             _label_correct[0] = 0;
@@ -96,11 +104,11 @@ class SoftmaxMultiClassObj : public ObjFunction {
           bst_float wt = is_null_weight ? 1.0f : weights[idx];
           for (int k = 0; k < nclass; ++k) {
             // Computation duplicated to avoid creating a cache.
-            bst_float p = expf(point[k] - wmax) / static_cast(wsum);
+            bst_float p = expf(point[k] - wmax) / wsum;
             const float eps = 1e-16f;
-            const bst_float h = fmax(2.0f * p * (1.0f - p) * wt, eps);
+            const bst_float h = 2.0f * p * (1.0f - p) * wt;
             p = label == k ? p - 1.0f : p;
-            gpair[idx * nclass + k] = GradientPair(p * wt, h);
+            gpair[idx * nclass + k] = GradientPair(p * wt, h < eps ? eps : h);
           }
         }, common::Range{0, ndata}, ctx_->Threads(), device)
         .Eval(out_gpair->Data(), info.labels.Data(), &preds, &info.weights_, &label_correct_);
@@ -129,7 +137,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
     auto device = io_preds->Device();
     if (prob) {
       common::Transform<>::Init(
-          [=] XGBOOST_DEVICE(size_t _idx, common::Span _preds) {
+          [=] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support, common::Span _preds) {
            common::Span point = _preds.subspan(_idx * nclass, nclass);
            common::Softmax(point.begin(), point.end());
@@ -142,7 +150,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
      max_preds.SetDevice(device);
      max_preds.Resize(ndata);
      common::Transform<>::Init(
-          [=] XGBOOST_DEVICE(size_t _idx, common::Span _preds,
+          [=] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support,
+                             common::Span _preds,
                              common::Span _max_preds) {
            common::Span point = _preds.subspan(_idx * nclass, nclass);
diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu
index b5e57199f969..60ac801874cd 100644
--- a/src/objective/regression_obj.cu
+++ b/src/objective/regression_obj.cu
@@ -142,7 +142,8 @@ class RegLossObj : public FitInterceptGlmLike {
 
     common::Transform<>::Init(
         [block_size, ndata, n_targets] XGBOOST_DEVICE(
-            size_t data_block_idx, common::Span _additional_input,
+            size_t data_block_idx, auto has_fp64_support,
+            common::Span _additional_input,
             common::Span _out_gpair,
             common::Span _preds,
             common::Span _labels,
@@ -179,7 +180,7 @@ class RegLossObj : public FitInterceptGlmLike {
 
   void PredTransform(HostDeviceVector *io_preds) const override {
     common::Transform<>::Init(
-        [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) {
+        [] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support, common::Span _preds) {
          _preds[_idx] = Loss::PredTransform(_preds[_idx]);
        },
        common::Range{0, static_cast(io_preds->Size())}, this->ctx_->Threads(),
@@ -360,7 +361,7 @@ class PoissonRegression : public FitInterceptGlmLike {
     }
     bst_float max_delta_step = param_.max_delta_step;
     common::Transform<>::Init(
-        [=] XGBOOST_DEVICE(size_t _idx,
+        [=] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support,
                            common::Span _label_correct,
                            common::Span _out_gpair,
                            common::Span _preds,
@@ -387,7 +388,7 @@ class PoissonRegression : public FitInterceptGlmLike {
   }
   void PredTransform(HostDeviceVector *io_preds) const override {
     common::Transform<>::Init(
-        [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) {
+        [] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support, common::Span _preds) {
          _preds[_idx] = expf(_preds[_idx]);
        },
        common::Range{0, static_cast(io_preds->Size())}, this->ctx_->Threads(),
@@ -566,7 +567,7 @@ class TweedieRegression : public FitInterceptGlmLike {
 
     const float rho = param_.tweedie_variance_power;
     common::Transform<>::Init(
-        [=] XGBOOST_DEVICE(size_t _idx,
+        [=] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support,
                            common::Span _label_correct,
                            common::Span _out_gpair,
                            common::Span _preds,
@@ -597,7 +598,7 @@ class TweedieRegression : public FitInterceptGlmLike {
   }
   void PredTransform(HostDeviceVector *io_preds) const override {
     common::Transform<>::Init(
-        [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) {
+        [] XGBOOST_DEVICE(size_t _idx, auto has_fp64_support, common::Span _preds) {
          _preds[_idx] = expf(_preds[_idx]);
        },
        common::Range{0, static_cast(io_preds->Size())}, this->ctx_->Threads(),
diff --git a/src/tree/split_evaluator.h b/src/tree/split_evaluator.h
index f417ff8984ae..c3a3d5f4bd10 100644
--- a/src/tree/split_evaluator.h
+++ b/src/tree/split_evaluator.h
@@ -180,7 +180,8 @@ class TreeEvaluator {
     }
 
     common::Transform<>::Init(
-        [=] XGBOOST_DEVICE(size_t, common::Span lower,
+        [=] XGBOOST_DEVICE(size_t, auto has_fp64_support,
+                           common::Span lower,
                            common::Span upper,
                            common::Span monotone) {
           lower[leftid] = lower[nodeid];
diff --git a/tests/cpp/common/test_transform_range.cc b/tests/cpp/common/test_transform_range.cc
index 4fc06f63907e..fcfd24ffb491 100644
--- a/tests/cpp/common/test_transform_range.cc
+++ b/tests/cpp/common/test_transform_range.cc
@@ -25,7 +25,8 @@ constexpr DeviceOrd TransformDevice() {
 
 template
 struct TestTransformRange {
-  void XGBOOST_DEVICE operator()(std::size_t _idx, Span _out, Span _in) {
+  template
+  void XGBOOST_DEVICE operator()(std::size_t _idx, kBoolConst has_fp64_support, Span _out, Span _in) {
     _out[_idx] = _in[_idx];
   }
 };
@@ -59,7 +60,7 @@ TEST(TransformDeathTest, Exception) {
   const HostDeviceVector in_vec{h_in, DeviceOrd::CPU()};
   EXPECT_DEATH(
       {
-        Transform<>::Init([](size_t idx, common::Span _in) { _in[idx + 1]; },
+        Transform<>::Init([](size_t idx, auto has_fp64_support, common::Span _in) { _in[idx + 1]; },
                           Range(0, static_cast(kSize)), AllThreadsForTest(),
                           DeviceOrd::CPU())
             .Eval(&in_vec);
diff --git a/tests/cpp/plugin/test_sycl_transform_range.cc b/tests/cpp/plugin/test_sycl_transform_range.cc
index bfae073ac7d3..25abe6d4f76a 100644
--- a/tests/cpp/plugin/test_sycl_transform_range.cc
+++ b/tests/cpp/plugin/test_sycl_transform_range.cc
@@ -19,7 +19,8 @@ namespace xgboost::common {
 
 template
 struct TestTransformRange {
-  void operator()(std::size_t _idx, Span _out, Span _in) {
+  template
+  void operator()(std::size_t _idx, kBoolConst has_fp64_support, Span _out, Span _in) {
    _out[_idx] = _in[_idx];
   }
 };

From fd5556f5be2fbfb61e24435677ad899cabdfc89c Mon Sep 17 00:00:00 2001
From: Dmitry Razdoburdin <>
Date: Thu, 27 Feb 2025 03:04:18 -0800
Subject: [PATCH 2/2] linting

---
 src/common/transform.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/common/transform.h b/src/common/transform.h
index 1699a4889000..ff45c0b8523d 100644
--- a/src/common/transform.h
+++ b/src/common/transform.h
@@ -184,7 +184,8 @@ class Transform {
   void LaunchCPU(Functor func, HDV *...vectors) const {
     omp_ulong end = static_cast(*(range_.end()));
     SyncHost(vectors...);
-    ParallelFor(end, n_threads_, [&](omp_ulong idx) { func(idx, std::true_type(), UnpackHDV(vectors)...); });
+    ParallelFor(end, n_threads_, [&](omp_ulong idx) { func(idx, std::true_type(),
+                                                           UnpackHDV(vectors)...); });
  }
 
  private:
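The pattern the two patches introduce is the same everywhere: the host queries fp64 support once, turns the answer into a std::true_type / std::false_type tag, and the per-element functor branches on that tag with "if constexpr", so a double-precision accumulator is only compiled for devices that actually report ::sycl::aspect::fp64, while the CUDA and CPU launchers always pass std::true_type(). Below is a minimal, self-contained sketch of that compile-time dispatch; it is illustrative only and assumes names (SoftmaxDenominator, device_has_fp64) that are not part of the XGBoost sources.

// dispatch_sketch.cc -- illustrative sketch, not part of the patch.
#include <cmath>
#include <cstdio>
#include <type_traits>
#include <vector>

// Generic functor: the tag parameter is std::true_type or std::false_type, so
// the branch is resolved at compile time and no fp64 instructions are emitted
// when only the false_type path is instantiated.
struct SoftmaxDenominator {
  template <typename HasFp64>
  float operator()(HasFp64, const std::vector<float>& point, float wmax) const {
    if constexpr (HasFp64::value) {
      double wsum = 0.0;  // wide accumulator for fp64-capable devices
      for (float v : point) { wsum += std::exp(v - wmax); }
      return static_cast<float>(wsum);
    } else {
      float wsum = 0.0f;  // fp32-only fallback, e.g. iGPUs without fp64
      for (float v : point) { wsum += std::exp(v - wmax); }
      return wsum;
    }
  }
};

int main() {
  std::vector<float> point{0.5f, 1.5f, -0.25f};
  float wmax = 1.5f;
  // In the patch this flag comes from qu->get_device().has(::sycl::aspect::fp64);
  // here it is just a stand-in value.
  bool device_has_fp64 = true;
  SoftmaxDenominator softmax_denominator;
  // The runtime bool is converted into a compile-time tag exactly once, at the
  // dispatch point; inside the functor the branch no longer exists at runtime.
  float wsum = device_has_fp64
                   ? softmax_denominator(std::true_type{}, point, wmax)
                   : softmax_denominator(std::false_type{}, point, wmax);
  std::printf("wsum = %f\n", wsum);
  return 0;
}

Passing the tag as an extra functor argument (rather than specializing each functor) keeps a single lambda per objective, which is why every XGBOOST_DEVICE lambda in the patch simply gains an "auto has_fp64_support" parameter and only the SYCL launcher selects between the two tags at runtime.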