Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion SYCL/ESIMD/ext_math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ struct InitDataInRange0_5 {

// --- Math operation identification

enum class MathOp { sin, cos, exp, sqrt, inv, log, rsqrt };
enum class MathOp { sin, cos, exp, sqrt, inv, log, rsqrt, trunc };

// --- Template functions calculating given math operation on host and device

Expand All @@ -71,6 +71,16 @@ template <MathOp Op> float HostMathFunc(float X);
} \
}

// The same as above but adds explicit template parameter for esimd_##Op.
#define DEFINE_ESIMD_RT_OP(Op, HostOp) \
template <> float HostMathFunc<MathOp::Op>(float X) { return HostOp(X); } \
template <int VL> struct DeviceMathFunc<VL, MathOp::Op> { \
simd<float, VL> \
operator()(const simd<float, VL> &X) const SYCL_ESIMD_FUNCTION { \
return esimd_##Op<float, VL>(X); \
} \
}

#define DEFINE_SIMD_OVERLOADED_STD_SYCL_OP(Op, HostOp) \
template <> float HostMathFunc<MathOp::Op>(float X) { return HostOp(X); } \
template <int VL> struct DeviceMathFunc<VL, MathOp::Op> { \
Expand All @@ -87,6 +97,7 @@ DEFINE_SIMD_OVERLOADED_STD_SYCL_OP(log, log);
DEFINE_ESIMD_OP(inv, 1.0f /);
DEFINE_ESIMD_OP(sqrt, sqrt);
DEFINE_ESIMD_OP(rsqrt, 1.0f / sqrt);
DEFINE_ESIMD_RT_OP(trunc, trunc);

// --- Generic kernel calculating an extended math operation on array elements

Expand Down Expand Up @@ -182,6 +193,7 @@ template <int VL> bool test(queue &Q) {
Pass &= test<MathOp::cos, VL>(Q, "cos", InitDataFuncWide{});
Pass &= test<MathOp::exp, VL>(Q, "exp", InitDataInRange0_5{});
Pass &= test<MathOp::log, VL>(Q, "log", InitDataFuncWide{});
Pass &= test<MathOp::trunc, VL>(Q, "trunc", InitDataFuncWide{});
return Pass;
}

Expand Down