Skip to content

Commit ab8115a

Browse files
committed
sys tag.
1 parent cdbf3a8 commit ab8115a

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

src/common/linalg_op.h

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,18 @@ auto end(TensorView<T, kDim>& v) { // NOLINT
125125
return begin(v) + v.Size();
126126
}
127127

128+
// A tag to work around the one definition rule.
129+
template <bool kWithCuda, bool kWithSycl>
130+
struct SysTagImpl {};
131+
132+
#if defined(__CUDACC__)
133+
using SysTag = SysTagImpl<true, false>;
134+
#elif defined(SYCL_LANGUAGE_VERSION)
135+
using SysTag = SysTagImpl<false, true>;
136+
#else
137+
using SysTag = SysTagImpl<false, false>;
138+
#endif
139+
128140
/**
129141
* @brief Elementwise kernel without a return type.
130142
*
@@ -136,8 +148,8 @@ auto end(TensorView<T, kDim>& v) { // NOLINT
136148
* @param t Input array.
137149
* @param fn Transformation function.
138150
*/
139-
template <typename T, std::int32_t D, typename Fn, bool CompiledWithCuda = WITH_CUDA()>
140-
void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
151+
template <typename T, std::int32_t D, typename Fn>
152+
void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn, SysTag = SysTag{}) {
141153
ctx->DispatchDevice([&] { cpu_impl::ElementWiseKernel(t, ctx->Threads(), std::forward<Fn>(fn)); },
142154
[&] {
143155
#if defined(__CUDACC__)
@@ -167,8 +179,8 @@ void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
167179
* @param t Input array.
168180
* @param fn Transformation function, must return type T.
169181
*/
170-
template <typename T, std::int32_t D, typename Fn, bool CompiledWithCuda = WITH_CUDA()>
171-
void TransformIdxKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
182+
template <typename T, std::int32_t D, typename Fn>
183+
void TransformIdxKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn, SysTag = SysTag{}) {
172184
ctx->DispatchDevice([&] { cpu_impl::TransformIdxKernel(t, ctx->Threads(), fn); },
173185
[&] {
174186
#if defined(__CUDACC__)
@@ -192,8 +204,8 @@ void TransformIdxKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
192204
* @brief Elementwise transform, with the element itself as input. Rest is the same as @ref
193205
* TransformIdxKernel
194206
*/
195-
template <typename T, std::int32_t D, typename Fn, bool CompiledWithCuda = WITH_CUDA()>
196-
void TransformKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
207+
template <typename T, std::int32_t D, typename Fn>
208+
void TransformKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn, SysTag = SysTag{}) {
197209
ctx->DispatchDevice([&] { cpu_impl::TransformKernel(t, ctx->Threads(), fn); },
198210
[&] {
199211
#if defined(__CUDACC__)
@@ -214,17 +226,18 @@ void TransformKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
214226
}
215227

216228
// vector-scalar multiplication
217-
inline void VecScaMul(Context const* ctx, linalg::VectorView<float> x, double mul) {
229+
inline void VecScaMul(Context const* ctx, linalg::VectorView<float> x, double mul, SysTag = SysTag{}) {
218230
CHECK_EQ(x.Device().ordinal, ctx->Device().ordinal);
219231
TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return v * mul; });
220232
}
221233

222234
// vector-scalar division
223-
inline void VecScaDiv(Context const* ctx, linalg::VectorView<float> x, double div) {
235+
inline void VecScaDiv(Context const* ctx, linalg::VectorView<float> x, double div,
236+
SysTag = SysTag{}) {
224237
return VecScaMul(ctx, x, 1.0 / div);
225238
}
226239

227-
inline void LogE(Context const* ctx, linalg::VectorView<float> x) {
240+
inline void LogE(Context const* ctx, linalg::VectorView<float> x, SysTag = SysTag{}) {
228241
CHECK_EQ(x.Device().ordinal, ctx->Device().ordinal);
229242
TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return log(v); });
230243
}

0 commit comments

Comments
 (0)