include/xgboost/linalg.h (4 changes: 2 additions & 2 deletions)
@@ -273,7 +273,7 @@ enum Order : std::uint8_t {
* some functions expect data types that can be used in everywhere (update prediction
* cache for example).
*/
-template <typename T, int32_t kDim>
+template <typename T, std::int32_t kDim>
class TensorView {
public:
using ShapeT = std::size_t[kDim];
@@ -300,7 +300,7 @@ class TensorView {
}
}

-template <size_t old_dim, size_t new_dim, int32_t D, typename I>
+template <size_t old_dim, size_t new_dim, std::int32_t D, typename I>
LINALG_HD size_t MakeSliceDim(std::size_t new_shape[D], std::size_t new_stride[D],
detail::RangeTag<I> &&range) const {
static_assert(new_dim < D);
plugin/sycl/common/linalg_op.cc (20 changes: 0 additions & 20 deletions)
@@ -12,7 +12,6 @@
#include <sycl/sycl.hpp>

namespace xgboost::sycl::linalg {
-
void SmallHistogram(Context const* ctx, xgboost::linalg::MatrixView<float const> indices,
xgboost::common::OptionalWeights const& weights,
xgboost::linalg::VectorView<float> bins) {
@@ -30,23 +29,4 @@ void SmallHistogram(Context const* ctx, xgboost::linalg::MatrixView<float const>
});
}).wait();
}
-
-void VecScaMul(Context const* ctx, xgboost::linalg::VectorView<float> x, double mul) {
-  sycl::DeviceManager device_manager;
-  auto* qu = device_manager.GetQueue(ctx->Device());
-
-  qu->submit([&](::sycl::handler& cgh) {
-    cgh.parallel_for<>(::sycl::range<1>(x.Size()),
-                       [=](::sycl::id<1> pid) {
-                         const size_t i = pid[0];
-                         const_cast<float&>(x(i)) *= mul;
-                       });
-  }).wait();
-}
}  // namespace xgboost::sycl::linalg
-
-namespace xgboost::linalg::sycl_impl {
-void VecScaMul(Context const* ctx, xgboost::linalg::VectorView<float> x, double mul) {
-  xgboost::sycl::linalg::VecScaMul(ctx, x, mul);
-}
-}  // namespace xgboost::linalg::sycl_impl
plugin/sycl/common/linalg_op.h (14 changes: 0 additions & 14 deletions)
@@ -8,8 +8,6 @@
#include <vector>
#include <utility>

#include "../../../src/common/linalg_op.h"

#include "../data.h"
#include "../device_manager.h"

@@ -99,17 +97,5 @@ bool Validate(DeviceOrd device, TensorView<T, D> t, Fn&& fn) {

} // namespace linalg
}  // namespace sycl
-
-namespace linalg {
-template <typename T, int32_t D, typename Fn>
-void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
-  if (ctx->IsSycl()) {
-    sycl::linalg::ElementWiseKernel(t, fn);
-  } else {
-    ElementWiseKernelHost(t, ctx->Threads(), fn);
-  }
-}
-
-}  // namespace linalg
} // namespace xgboost
#endif // PLUGIN_SYCL_COMMON_LINALG_OP_H_
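Note on the removal above: with the generic dispatcher gone from the plugin header, call sites presumably select the device-specific path themselves. A minimal sketch of such a call site, reusing the ElementWiseKernel(t, fn) overload kept earlier in this header and the ElementWiseKernelHost helper referenced by the deleted code (the RunElementWise wrapper itself is hypothetical):

    // Hypothetical call-site dispatch, mirroring the overload deleted above.
    template <typename T, std::int32_t D, typename Fn>
    void RunElementWise(xgboost::Context const* ctx, xgboost::linalg::TensorView<T, D> t, Fn&& fn) {
      if (ctx->IsSycl()) {
        xgboost::sycl::linalg::ElementWiseKernel(t, fn);  // SYCL device path
      } else {
        xgboost::linalg::ElementWiseKernelHost(t, ctx->Threads(), fn);  // CPU path
      }
    }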
plugin/sycl/tree/hist_updater.cc (14 changes: 7 additions & 7 deletions)
@@ -12,6 +12,7 @@
#include "../../src/tree/common_row_partitioner.h"

#include "../common/hist_util.h"
#include "xgboost/linalg.h"
#include "../../src/collective/allreduce.h"

namespace xgboost {
@@ -34,8 +35,8 @@ void HistUpdater<GradientSumT>::ReduceHists(const std::vector<int>& sync_ids,
qu_->memcpy(reduce_buffer_.data() + i * nbins, psrc, nbins*sizeof(GradientPairT)).wait();
}

-  auto buffer_vec = linalg::MakeVec(reinterpret_cast<GradientSumT*>(reduce_buffer_.data()),
-                                    2 * nbins * sync_ids.size());
+  auto buffer_vec = ::xgboost::linalg::MakeVec(
+      reinterpret_cast<GradientSumT*>(reduce_buffer_.data()), 2 * nbins * sync_ids.size());
auto rc = collective::Allreduce(ctx_, buffer_vec, collective::Op::kSum);
SafeColl(rc);

@@ -361,10 +362,9 @@ void HistUpdater<GradientSumT>::Update(
builder_monitor_.Stop("Update");
}

-template<typename GradientSumT>
+template <typename GradientSumT>
bool HistUpdater<GradientSumT>::UpdatePredictionCache(
-    const DMatrix* data,
-    linalg::MatrixView<float> out_preds) {
+    const DMatrix* data, ::xgboost::linalg::MatrixView<float> out_preds) {
CHECK(out_preds.Device().IsSycl());
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
// conjunction with Update().
@@ -723,8 +723,8 @@ void HistUpdater<GradientSumT>::InitNewNode(int nid,
}).wait_and_throw();
}
auto rc = collective::Allreduce(
-      ctx_, linalg::MakeVec(reinterpret_cast<GradientSumT*>(&grad_stat), 2),
-      collective::Op::kSum);
+      ctx_, ::xgboost::linalg::MakeVec(reinterpret_cast<GradientSumT*>(&grad_stat), 2),
+      collective::Op::kSum);
SafeColl(rc);
snode_host_[nid].stats = grad_stat;
} else {
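Aside on the 2 * nbins factor in the ReduceHists hunk above: each GradientPairT packs a gradient and a hessian of type GradientSumT, so reinterpreting the histogram buffer exposes a flat scalar array that collective::Allreduce can sum element-wise. A self-contained sketch of the layout assumption (GradPair below is a stand-in, not the real class):

    // Stand-in for GradientPairT: one histogram bin, a gradient/hessian pair.
    template <typename T>
    struct GradPair {
      T grad;
      T hess;
    };

    // The reinterpret_cast in ReduceHists relies on this layout property:
    // a buffer of nbins pairs is viewable as 2 * nbins scalars.
    static_assert(sizeof(GradPair<float>) == 2 * sizeof(float),
                  "each bin contributes exactly two scalars to the allreduce");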
plugin/sycl/tree/hist_updater.h (6 changes: 3 additions & 3 deletions)
@@ -1,5 +1,5 @@
/*!
- * Copyright 2017-2024 by Contributors
+ * Copyright 2017-2025, XGBoost Contributors
* \file hist_updater.h
*/
#ifndef PLUGIN_SYCL_TREE_HIST_UPDATER_H_
@@ -8,6 +8,7 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
+#include <xgboost/linalg.h>  // for MatrixView
#include <xgboost/tree_updater.h>
#pragma GCC diagnostic pop

@@ -80,8 +81,7 @@ class HistUpdater {
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
RegTree *p_tree);

-  bool UpdatePredictionCache(const DMatrix* data,
-                             linalg::MatrixView<float> p_out_preds);
+  bool UpdatePredictionCache(const DMatrix* data, ::xgboost::linalg::MatrixView<float> p_out_preds);

void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);
plugin/sycl/tree/updater_quantile_hist.cc (18 changes: 8 additions & 10 deletions)
@@ -60,14 +60,12 @@ void QuantileHistMaker::SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>* pim
}
}

-template<typename GradientSumT>
-void QuantileHistMaker::CallUpdate(
-    const std::unique_ptr<HistUpdater<GradientSumT>>& pimpl,
-    xgboost::tree::TrainParam const *param,
-    linalg::Matrix<GradientPair> *gpair,
-    DMatrix *dmat,
-    xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
-    const std::vector<RegTree *> &trees) {
+template <typename GradientSumT>
+void QuantileHistMaker::CallUpdate(const std::unique_ptr<HistUpdater<GradientSumT>> &pimpl,
+                                   xgboost::tree::TrainParam const *param,
+                                   ::xgboost::linalg::Matrix<GradientPair> *gpair, DMatrix *dmat,
+                                   xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
+                                   const std::vector<RegTree *> &trees) {
for (auto tree : trees) {
pimpl->Update(param, gmat_, *(gpair->Data()), dmat, out_position, tree);
}
@@ -107,8 +105,8 @@ void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param, GradientC
p_last_dmat_ = dmat;
}

-bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
-                                              linalg::MatrixView<float> out_preds) {
+bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data,
+                                              ::xgboost::linalg::MatrixView<float> out_preds) {
if (param_.subsample < 1.0f) return false;

if (hist_precision_ == HistPrecision::fp32) {
plugin/sycl/tree/updater_quantile_hist.h (4 changes: 2 additions & 2 deletions)
@@ -53,7 +53,7 @@ class QuantileHistMaker: public TreeUpdater {
const std::vector<RegTree*>& trees) override;

bool UpdatePredictionCache(const DMatrix* data,
-                             linalg::MatrixView<float> out_preds) override;
+                             ::xgboost::linalg::MatrixView<float> out_preds) override;

void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
@@ -90,7 +90,7 @@ class QuantileHistMaker: public TreeUpdater {
template<typename GradientSumT>
void CallUpdate(const std::unique_ptr<HistUpdater<GradientSumT>>& builder,
xgboost::tree::TrainParam const *param,
-                  linalg::Matrix<GradientPair> *gpair,
+                  ::xgboost::linalg::Matrix<GradientPair> *gpair,
DMatrix *dmat,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees);
src/common/linalg_op.cu (9 changes: 1 addition & 8 deletions)
@@ -1,9 +1,7 @@
/**
* Copyright 2025, XGBoost Contributors
*/
-#include <thrust/for_each.h>                    // for for_each_n
-#include <thrust/iterator/counting_iterator.h>  // for make_counting_iterator
-#include <thrust/scan.h>                        // for inclusive_scan
+#include <thrust/scan.h>  // for inclusive_scan

#include <cstddef> // for size_t

@@ -15,11 +13,6 @@
#include "xgboost/linalg.h" // for VectorView

namespace xgboost::linalg::cuda_impl {
-void VecScaMul(Context const* ctx, linalg::VectorView<float> x, double mul) {
-  thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), x.Size(),
-                     [=] XGBOOST_DEVICE(std::size_t i) mutable { x(i) = x(i) * mul; });
-}
-
void SmallHistogram(Context const* ctx, linalg::MatrixView<float const> indices,
common::OptionalWeights const& d_weights, linalg::VectorView<float> bins) {
auto n_bins = bins.Size();
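The deleted VecScaMul scaled a vector in place via thrust::for_each_n. The same operation can presumably be phrased through the TransformKernel helper added to linalg_op.cuh below; a minimal sketch under that assumption (the ScaleInPlace wrapper is hypothetical, not part of this PR):

    // Hypothetical replacement call site, assuming cuda_impl::TransformKernel from
    // src/common/linalg_op.cuh below; ctx is a CUDA-ordinal Context and x a
    // device-resident VectorView<float>.
    void ScaleInPlace(xgboost::Context const* ctx, xgboost::linalg::VectorView<float> x,
                      double mul) {
      xgboost::linalg::cuda_impl::TransformKernel(
          ctx, x, [=] XGBOOST_DEVICE(float v) { return static_cast<float>(v * mul); });
    }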
src/common/linalg_op.cuh (84 changes: 52 additions & 32 deletions)
@@ -4,15 +4,21 @@
#ifndef XGBOOST_COMMON_LINALG_OP_CUH_
#define XGBOOST_COMMON_LINALG_OP_CUH_

-#include <cstdint>  // for int32_t
-#include <cstdlib>  // for size_t
-#include <tuple>    // for apply
+#include <thrust/iterator/counting_iterator.h>  // for counting_iterator
+#include <thrust/iterator/zip_iterator.h>       // for make_zip_iterator
+#include <thrust/transform.h>                   // for transform
+
+#include <cstdint>            // for int32_t
+#include <cstdlib>            // for size_t
+#include <cuda/std/iterator>  // for iterator_traits
+#include <cuda/std/tuple>     // for get
+#include <tuple>              // for apply

#include "cuda_context.cuh"
#include "device_helpers.cuh" // for LaunchN
#include "linalg_op.h"
#include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for TensorView
#include "type.h" // for GetValueT
#include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for TensorView

namespace xgboost::linalg {
namespace cuda_impl {
@@ -40,17 +46,22 @@ struct ElementWiseImpl<T, 1> {
template <typename T, std::int32_t D, typename Fn>
void ElementWiseKernel(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
-  cuda_impl::ElementWiseImpl<T, D>{}(t, fn, s);
+  ElementWiseImpl<T, D>{}(t, fn, s);
}

-void VecScaMul(Context const* ctx, linalg::VectorView<float> x, double mul);
-}  // namespace cuda_impl

-template <typename T, int32_t D, typename Fn>
-void ElementWiseTransformDevice(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
+template <typename T, std::int32_t D, typename Fn>
+void TransformIdxKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
+  auto s = ctx->CUDACtx()->Stream();
if (t.Contiguous()) {
auto ptr = t.Values().data();
-    dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });
+    auto it =
+        thrust::make_zip_iterator(thrust::make_counting_iterator(static_cast<std::size_t>(0)), ptr);
+    using Tuple = typename cuda::std::iterator_traits<common::GetValueT<decltype(it)>>::value_type;
+    thrust::transform(ctx->CUDACtx()->CTP(), it, it + t.Size(), ptr,
+                      [=] XGBOOST_DEVICE(Tuple const& tup) {
+                        return fn(cuda::std::get<0>(tup), cuda::std::get<1>(tup));
+                      });
} else {
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
T& v = std::apply(t, UnravelIndex(i, t.Shape()));
@@ -59,44 +70,53 @@ void ElementWiseTransformDevice(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nu
}
}

-template <typename T, int32_t D, typename Fn>
-void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
-  ctx->IsCUDA() ? cuda_impl::ElementWiseKernel(t, fn)
-                : ElementWiseKernelHost(t, ctx->Threads(), fn);
+template <typename T, std::int32_t D, typename Fn>
+void TransformKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
+  dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
+  auto s = ctx->CUDACtx()->Stream();
+  if (t.Contiguous()) {
+    auto ptr = t.Values().data();
+    thrust::transform(ctx->CUDACtx()->CTP(), ptr, ptr + t.Size(), ptr,
+                      [=] XGBOOST_DEVICE(T const& v) { return fn(v); });
+  } else {
+    dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
+      T& v = std::apply(t, UnravelIndex(i, t.Shape()));
+      v = fn(v);
+    });
+  }
}
+}  // namespace cuda_impl

namespace detail {
-template <typename T, std::int32_t kDim>
+template <typename T, std::int32_t D>
struct IterOp {
-  TensorView<T, kDim> v;
-  XGBOOST_DEVICE T& operator()(std::size_t i) {
-    return std::apply(v, UnravelIndex(i, v.Shape()));
-  }
+  TensorView<T, D> v;
+  XGBOOST_DEVICE T& operator()(std::size_t i) { return std::apply(v, UnravelIndex(i, v.Shape())); }
};
} // namespace detail

// naming: thrust begin
// returns a thrust iterator for a tensor view.
-template <typename T, std::int32_t kDim>
-auto tcbegin(TensorView<T, kDim> v) {  // NOLINT
+template <typename T, std::int32_t D>
+auto tcbegin(TensorView<T, D> v) {  // NOLINT
return thrust::make_transform_iterator(
thrust::make_counting_iterator(0ul),
-      detail::IterOp<std::add_const_t<std::remove_const_t<T>>, kDim>{v});
+      detail::IterOp<std::add_const_t<std::remove_const_t<T>>, D>{v});
}

-template <typename T, std::int32_t kDim>
-auto tcend(TensorView<T, kDim> v) {  // NOLINT
+template <typename T, std::int32_t D>
+auto tcend(TensorView<T, D> v) {  // NOLINT
return tcbegin(v) + v.Size();
}

-template <typename T, std::int32_t kDim>
-auto tbegin(TensorView<T, kDim> v) {  // NOLINT
+template <typename T, std::int32_t D>
+auto tbegin(TensorView<T, D> v) {  // NOLINT
return thrust::make_transform_iterator(thrust::make_counting_iterator(0ul),
-                                         detail::IterOp<std::remove_const_t<T>, kDim>{v});
+                                         detail::IterOp<std::remove_const_t<T>, D>{v});
}

-template <typename T, std::int32_t kDim>
-auto tend(TensorView<T, kDim> v) {  // NOLINT
+template <typename T, std::int32_t D>
+auto tend(TensorView<T, D> v) {  // NOLINT
return tbegin(v) + v.Size();
}
} // namespace xgboost::linalg
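A short usage sketch for the helpers above (call sites are illustrative, not taken from this PR; assume ctx is a CUDA-ordinal Context and view a device-resident TensorView<float, 1>):

    // Index-aware transform: fn receives (flat index, current value).
    xgboost::linalg::cuda_impl::TransformIdxKernel(
        ctx, view, [] XGBOOST_DEVICE(std::size_t i, float v) { return v + static_cast<float>(i); });

    // Value-only transform: fn receives just the current value.
    xgboost::linalg::cuda_impl::TransformKernel(
        ctx, view, [] XGBOOST_DEVICE(float v) { return v * 0.5f; });

    // tcbegin/tcend adapt a (possibly non-contiguous) view into thrust-compatible
    // iterators, e.g. for a reduction over all elements:
    float sum = thrust::reduce(ctx->CUDACtx()->CTP(), xgboost::linalg::tcbegin(view),
                               xgboost::linalg::tcend(view), 0.0f);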