Skip to content

[SYCL] Add support for bfloat16 conversion #4213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 26, 2021
3 changes: 3 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,9 @@ extern SYCL_EXTERNAL void
__spirv_ocl_prefetch(const __attribute__((opencl_global)) char *Ptr,
size_t NumBytes) noexcept;

extern SYCL_EXTERNAL uint16_t __spirv_ConvertFToBF16INTEL(float) noexcept;
extern SYCL_EXTERNAL float __spirv_ConvertBF16ToFINTEL(uint16_t) noexcept;

#else // if !__SYCL_DEVICE_ONLY__

template <typename dataT>
Expand Down
1 change: 1 addition & 0 deletions sycl/include/CL/sycl/feature_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace sycl {
#ifndef SYCL_EXT_ONEAPI_MATRIX
#define SYCL_EXT_ONEAPI_MATRIX 2
#endif
#define SYCL_EXT_INTEL_BF16_CONVERSION 1

} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
148 changes: 148 additions & 0 deletions sycl/include/sycl/ext/intel/experimental/bfloat16.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
//==--------- bfloat16.hpp ------- SYCL bfloat16 conversion ----------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

#include <CL/__spirv/spirv_ops.hpp>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace ext {
namespace intel {
namespace experimental {

class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
using storage_t = uint16_t;
storage_t value;

public:
bfloat16() = default;
bfloat16(const bfloat16 &) = default;
~bfloat16() = default;

// Explicit conversion functions
static storage_t from_float(const float &a) {
#if defined(__SYCL_DEVICE_ONLY__)
return __spirv_ConvertFToBF16INTEL(a);
#else
throw exception{errc::feature_not_supported,
"Bfloat16 conversion is not supported on host device"};
#endif
}
static float to_float(const storage_t &a) {
#if defined(__SYCL_DEVICE_ONLY__)
return __spirv_ConvertBF16ToFINTEL(a);
#else
throw exception{errc::feature_not_supported,
"Bfloat16 conversion is not supported on host device"};
#endif
}

// Direct initialization
bfloat16(const storage_t &a) : value(a) {}

// Implicit conversion from float to bfloat16
bfloat16(const float &a) { value = from_float(a); }

bfloat16 &operator=(const float &rhs) {
value = from_float(rhs);
return *this;
}

// Implicit conversion from bfloat16 to float
operator float() const { return to_float(value); }

// Get raw bits representation of bfloat16
operator storage_t() const { return value; }

// Logical operators (!,||,&&) are covered if we can cast to bool
explicit operator bool() { return to_float(value) != 0.0f; }

// Unary minus operator overloading
friend bfloat16 operator-(bfloat16 &lhs) {
return bfloat16{-to_float(lhs.value)};
}

// Increment and decrement operators overloading
#define OP(op) \
friend bfloat16 &operator op(bfloat16 &lhs) { \
float f = to_float(lhs.value); \
lhs.value = from_float(op f); \
return lhs; \
} \
friend bfloat16 operator op(bfloat16 &lhs, int) { \
bfloat16 old = lhs; \
operator op(lhs); \
return old; \
}
OP(++)
OP(--)
#undef OP

// Assignment operators overloading
#define OP(op) \
friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) { \
float f = static_cast<float>(lhs); \
f op static_cast<float>(rhs); \
return lhs = f; \
} \
template <typename T> \
friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) { \
float f = static_cast<float>(lhs); \
f op static_cast<float>(rhs); \
return lhs = f; \
} \
template <typename T> friend T &operator op(T &lhs, const bfloat16 &rhs) { \
float f = static_cast<float>(lhs); \
f op static_cast<float>(rhs); \
return lhs = f; \
}
OP(+=)
OP(-=)
OP(*=)
OP(/=)
#undef OP

// Binary operators overloading
#define OP(type, op) \
friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
} \
template <typename T> \
friend type operator op(const bfloat16 &lhs, const T &rhs) { \
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
} \
template <typename T> \
friend type operator op(const T &lhs, const bfloat16 &rhs) { \
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
}
OP(bfloat16, +)
OP(bfloat16, -)
OP(bfloat16, *)
OP(bfloat16, /)
OP(bool, ==)
OP(bool, !=)
OP(bool, <)
OP(bool, >)
OP(bool, <=)
OP(bool, >=)
#undef OP

// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
// for floating-point types.
};

} // namespace experimental
} // namespace intel
} // namespace ext

namespace __SYCL2020_DEPRECATED("use 'ext::intel' instead") INTEL {
using namespace ext::intel;
}
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
51 changes: 51 additions & 0 deletions sycl/test/extensions/bfloat16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// RUN: %clangxx -fsycl-device-only -S -Xclang -emit-llvm %s -o - | FileCheck %s

#include <sycl/sycl.hpp>
#include <sycl/ext/intel/experimental/bfloat16.hpp>

using sycl::ext::intel::experimental::bfloat16;

SYCL_EXTERNAL uint16_t some_bf16_intrinsic(uint16_t x, uint16_t y);

__attribute__((noinline))
float op(float a, float b) {
bfloat16 A {a};
// CHECK: [[A:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %a)
// CHECK-NOT: fptoui

bfloat16 B {b};
// CHECK: [[B:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %b)
// CHECK-NOT: fptoui

bfloat16 C = A + B;
// CHECK: [[A_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[A]])
// CHECK: [[B_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[B]])
// CHECK: [[Add:%.*]] = fadd float [[A_float]], [[B_float]]
// CHECK: [[C:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float [[Add]])
// CHECK-NOT: uitofp
// CHECK-NOT: fptoui

bfloat16 D = some_bf16_intrinsic(A, C);
// CHECK: [[D:%.*]] = tail call spir_func zeroext i16 @_Z19some_bf16_intrinsictt(i16 zeroext [[A]], i16 zeroext [[C]])
// CHECK-NOT: uitofp
// CHECK-NOT: fptoui

return D;
// CHECK: [[RetVal:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[D]])
// CHECK: ret float [[RetVal]]
// CHECK-NOT: uitofp
// CHECK-NOT: fptoui
}

int main(int argc, char *argv[]) {
float data[3] = {7.0, 8.1, 0.0};
cl::sycl::queue deviceQueue;
cl::sycl::buffer<float, 1> buf{data, cl::sycl::range<1>{3}};

deviceQueue.submit([&](cl::sycl::handler &cgh) {
auto numbers = buf.get_access<cl::sycl::access::mode::read_write>(cgh);
cgh.single_task<class simple_kernel>(
[=]() { numbers[2] = op(numbers[0], numbers[1]); });
});
return 0;
}