diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 377dce99578cc..e126f21b1a5b6 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -855,13 +855,13 @@ def INT_NVVM_FABS_D : F_MATH_1<"abs.f64 \t$dst, $src0;", Float64Regs, // Abs, Neg bf16, bf16x2 // -def INT_NVVM_ABS_BF16 : F_MATH_1<"abs.bf16 \t$dst, $dst;", Int16Regs, +def INT_NVVM_ABS_BF16 : F_MATH_1<"abs.bf16 \t$dst, $src0;", Int16Regs, Int16Regs, int_nvvm_abs_bf16, [hasPTX70, hasSM80]>; -def INT_NVVM_ABS_BF16X2 : F_MATH_1<"abs.bf16x2 \t$dst, $dst;", Int32Regs, +def INT_NVVM_ABS_BF16X2 : F_MATH_1<"abs.bf16x2 \t$dst, $src0;", Int32Regs, Int32Regs, int_nvvm_abs_bf16x2, [hasPTX70, hasSM80]>; -def INT_NVVM_NEG_BF16 : F_MATH_1<"neg.bf16 \t$dst, $dst;", Int16Regs, +def INT_NVVM_NEG_BF16 : F_MATH_1<"neg.bf16 \t$dst, $src0;", Int16Regs, Int16Regs, int_nvvm_neg_bf16, [hasPTX70, hasSM80]>; -def INT_NVVM_NEG_BF16X2 : F_MATH_1<"neg.bf16x2 \t$dst, $dst;", Int32Regs, +def INT_NVVM_NEG_BF16X2 : F_MATH_1<"neg.bf16x2 \t$dst, $src0;", Int32Regs, Int32Regs, int_nvvm_neg_bf16x2, [hasPTX70, hasSM80]>; // diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index e7b660f9d29e7..3c6ce1c639960 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -793,6 +793,30 @@ extern SYCL_EXTERNAL __ocl_vec_t<_Float16, 8> extern SYCL_EXTERNAL __ocl_vec_t<_Float16, 16> __clc_native_exp2(__ocl_vec_t<_Float16, 16>); +#define __CLC_BF16(...) \ + extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fabs( \ + __VA_ARGS__) noexcept; \ + extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmin( \ + __VA_ARGS__, __VA_ARGS__) noexcept; \ + extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmax( \ + __VA_ARGS__, __VA_ARGS__) noexcept; \ + extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fma( \ + __VA_ARGS__, __VA_ARGS__, __VA_ARGS__) noexcept; + +#define __CLC_BF16_SCAL_VEC(TYPE) \ + __CLC_BF16(TYPE) \ + __CLC_BF16(__ocl_vec_t) \ + __CLC_BF16(__ocl_vec_t) \ + __CLC_BF16(__ocl_vec_t) \ + __CLC_BF16(__ocl_vec_t) \ + __CLC_BF16(__ocl_vec_t) + +__CLC_BF16_SCAL_VEC(uint16_t) +__CLC_BF16_SCAL_VEC(uint32_t) + +#undef __CLC_BF16_SCAL_VEC +#undef __CLC_BF16 + #else // if !__SYCL_DEVICE_ONLY__ template diff --git a/sycl/include/CL/sycl.hpp b/sycl/include/CL/sycl.hpp index a7278b275aa9d..8ca7d28223cad 100644 --- a/sycl/include/CL/sycl.hpp +++ b/sycl/include/CL/sycl.hpp @@ -60,6 +60,7 @@ #if SYCL_EXT_ONEAPI_BACKEND_LEVEL_ZERO #include #endif +#include #include #include #include diff --git a/sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp b/sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp new file mode 100644 index 0000000000000..88737c6c668d5 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { +namespace ext { +namespace oneapi { + +namespace detail { + +template struct is_bf16_storage_type { + static constexpr int value = false; +}; + +template <> struct is_bf16_storage_type { + static constexpr int value = true; +}; + +template <> struct is_bf16_storage_type { + static constexpr int value = true; +}; + +template struct is_bf16_storage_type> { + static constexpr int value = true; +}; + +template struct is_bf16_storage_type> { + static constexpr int value = true; +}; + +} // namespace detail + +template +std::enable_if_t::value, T> fabs(T x) { +#ifdef __SYCL_DEVICE_ONLY__ + return __clc_fabs(x); +#else + throw runtime_error("bf16 is not supported on host device.", + PI_INVALID_DEVICE); +#endif +} +template +std::enable_if_t::value, T> fmin(T x, T y) { +#ifdef __SYCL_DEVICE_ONLY__ + return __clc_fmin(x, y); +#else + throw runtime_error("bf16 is not supported on host device.", + PI_INVALID_DEVICE); +#endif +} +template +std::enable_if_t::value, T> fmax(T x, T y) { +#ifdef __SYCL_DEVICE_ONLY__ + return __clc_fmax(x, y); +#else + throw runtime_error("bf16 is not supported on host device.", + PI_INVALID_DEVICE); +#endif +} +template +std::enable_if_t::value, T> fma(T x, T y, T z) { +#ifdef __SYCL_DEVICE_ONLY__ + return __clc_fma(x, y, z); +#else + throw runtime_error("bf16 is not supported on host device.", + PI_INVALID_DEVICE); +#endif +} + +} // namespace oneapi +} // namespace ext +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl)