Skip to content

Commit 95e1434

Browse files
authored
Add bfloat16 data type (#25402)
1 parent 3ba7b9b commit 95e1434

19 files changed

+832
-63
lines changed

paddle/fluid/framework/data_layout_transform.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
116116
return platform::to_void_cast(tensor.data<unsigned char>());
117117
case mkldnn::memory::data_type::s32:
118118
return platform::to_void_cast(tensor.data<int32_t>());
119+
case mkldnn::memory::data_type::bf16:
120+
return platform::to_void_cast(tensor.data<paddle::platform::bfloat16>());
119121
default:
120122
PADDLE_THROW(
121123
platform::errors::InvalidArgument("Wrong mkldnn type provided."));

paddle/fluid/framework/data_layout_transform.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
6161
{DataTypeTrait<float>::DataType(), MKLDNNDataType::f32},
6262
{DataTypeTrait<int8_t>::DataType(), MKLDNNDataType::s8},
6363
{DataTypeTrait<uint8_t>::DataType(), MKLDNNDataType::u8},
64-
{DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32}};
64+
{DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32},
65+
{DataTypeTrait<platform::bfloat16>::DataType(), MKLDNNDataType::bf16}};
6566
auto iter = dict.find(static_cast<int>(type));
6667
if (iter != dict.end()) return iter->second;
6768
return MKLDNNDataType::undef;
@@ -74,6 +75,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
7475
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
7576
const OpKernelType& expected_kernel_type,
7677
const Tensor& in, Tensor* out);
78+
79+
void* GetDataFromTensor(const Tensor& tensor, MKLDNNDataType type);
80+
7781
#endif
7882

7983
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);

paddle/fluid/framework/data_layout_transform_test.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,17 @@ TEST(DataTransform, DataLayoutFunction) {
4343
EXPECT_TRUE(in.layout() == paddle::framework::DataLayout::kNHWC);
4444
EXPECT_TRUE(in.dims() == paddle::framework::make_ddim({2, 3, 1, 2}));
4545
}
46+
47+
#ifdef PADDLE_WITH_MKLDNN
48+
TEST(DataTransform, GetDataFromTensorDNNL) {
49+
auto place = paddle::platform::CPUPlace();
50+
paddle::framework::Tensor in = paddle::framework::Tensor();
51+
in.mutable_data<paddle::platform::bfloat16>(
52+
paddle::framework::make_ddim({2, 3, 1, 2}), place);
53+
54+
void* in_data =
55+
paddle::framework::GetDataFromTensor(in, dnnl::memory::data_type::bf16);
56+
EXPECT_EQ(in_data, paddle::platform::to_void_cast(
57+
in.data<paddle::platform::bfloat16>()));
58+
}
59+
#endif

paddle/fluid/framework/data_type.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <unordered_map>
1919

2020
using float16 = paddle::platform::float16;
21+
using bfloat16 = paddle::platform::bfloat16;
2122

2223
namespace paddle {
2324
namespace framework {

paddle/fluid/framework/data_type.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ limitations under the License. */
1717
#include <typeindex>
1818
#include "paddle/fluid/framework/framework.pb.h"
1919
#include "paddle/fluid/platform/enforce.h"
20+
21+
#include "paddle/fluid/platform/bfloat16.h"
2022
#include "paddle/fluid/platform/float16.h"
2123

2224
namespace paddle {
@@ -36,15 +38,16 @@ struct DataTypeTrait<void> {
3638
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
3739
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
3840

39-
#define _ForEachDataType_(callback) \
40-
_ForEachDataTypeHelper_(callback, float, FP32); \
41-
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
42-
_ForEachDataTypeHelper_(callback, double, FP64); \
43-
_ForEachDataTypeHelper_(callback, int, INT32); \
44-
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
45-
_ForEachDataTypeHelper_(callback, bool, BOOL); \
46-
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
47-
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
41+
#define _ForEachDataType_(callback) \
42+
_ForEachDataTypeHelper_(callback, float, FP32); \
43+
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
44+
_ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
45+
_ForEachDataTypeHelper_(callback, double, FP64); \
46+
_ForEachDataTypeHelper_(callback, int, INT32); \
47+
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
48+
_ForEachDataTypeHelper_(callback, bool, BOOL); \
49+
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
50+
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
4851
_ForEachDataTypeHelper_(callback, int8_t, INT8)
4952

5053
#define _ForEachDataTypeSmall_(callback) \

paddle/fluid/framework/data_type_test.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,25 @@ TEST(DataType, float16) {
3838
std::string type = "::paddle::platform::float16";
3939
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
4040
}
41+
42+
TEST(DataType, bfloat16) {
43+
using paddle::framework::Tensor;
44+
using paddle::platform::CPUPlace;
45+
using paddle::platform::bfloat16;
46+
namespace f = paddle::framework;
47+
f::proto::VarType::Type dtype = f::proto::VarType::BF16;
48+
49+
Tensor tensor;
50+
CPUPlace cpu;
51+
tensor.mutable_data(cpu, dtype);
52+
53+
// test bf16 tensor
54+
EXPECT_EQ(tensor.type(), f::ToDataType(typeid(bfloat16)));
55+
56+
// test bf16 size
57+
EXPECT_EQ(f::SizeOfType(dtype), 2u);
58+
59+
// test debug info
60+
std::string type = "::paddle::platform::bfloat16";
61+
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
62+
}

paddle/fluid/framework/data_type_transform.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
7777
framework::VisitDataType(dst_type,
7878
CastDataType<platform::float16>(in, out, ctx));
7979
break;
80+
case proto::VarType::BF16:
81+
framework::VisitDataType(dst_type,
82+
CastDataType<platform::bfloat16>(in, out, ctx));
83+
break;
8084
case proto::VarType::FP32:
8185
framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
8286
break;

paddle/fluid/framework/data_type_transform_test.cc

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ TEST(DataTypeTransform, CPUTransform) {
2424
paddle::framework::DataLayout::kAnyLayout,
2525
paddle::framework::LibraryType::kPlain);
2626

27+
auto kernel_bf16 = paddle::framework::OpKernelType(
28+
paddle::framework::proto::VarType::BF16, place,
29+
paddle::framework::DataLayout::kAnyLayout,
30+
paddle::framework::LibraryType::kPlain);
31+
2732
auto kernel_fp32 = paddle::framework::OpKernelType(
2833
paddle::framework::proto::VarType::FP32, place,
2934
paddle::framework::DataLayout::kAnyLayout,
@@ -189,4 +194,120 @@ TEST(DataTypeTransform, CPUTransform) {
189194
static_cast<paddle::platform::float16>(in_data_bool[i]).x);
190195
}
191196
}
197+
198+
// data type transform from/to bfloat16
199+
{
200+
paddle::framework::Tensor in;
201+
paddle::framework::Tensor out;
202+
203+
paddle::platform::bfloat16* ptr =
204+
in.mutable_data<paddle::platform::bfloat16>(
205+
paddle::framework::make_ddim({2, 3}), place);
206+
int data_number = 2 * 3;
207+
208+
for (int i = 0; i < data_number; ++i) {
209+
ptr[i] = i;
210+
}
211+
212+
// transform from bfloat16 to other data types
213+
paddle::framework::TransDataType(kernel_bf16, kernel_fp32, in, &out);
214+
float* out_data_float = out.data<float>();
215+
for (int i = 0; i < data_number; ++i) {
216+
EXPECT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
217+
}
218+
219+
paddle::framework::TransDataType(kernel_bf16, kernel_fp64, in, &out);
220+
double* out_data_double = out.data<double>();
221+
for (int i = 0; i < data_number; ++i) {
222+
EXPECT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
223+
}
224+
225+
paddle::framework::TransDataType(kernel_bf16, kernel_int32, in, &out);
226+
int* out_data_int = out.data<int>();
227+
for (int i = 0; i < data_number; ++i) {
228+
EXPECT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
229+
}
230+
231+
paddle::framework::TransDataType(kernel_bf16, kernel_int64, in, &out);
232+
int64_t* out_data_int64 = out.data<int64_t>();
233+
for (int i = 0; i < data_number; ++i) {
234+
EXPECT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
235+
}
236+
237+
paddle::framework::TransDataType(kernel_bf16, kernel_bool, in, &out);
238+
bool* out_data_bool = out.data<bool>();
239+
for (int i = 0; i < data_number; ++i) {
240+
EXPECT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
241+
}
242+
243+
// transform float to bfloat16
244+
float* in_data_float =
245+
in.mutable_data<float>(paddle::framework::make_ddim({2, 3}), place);
246+
for (int i = 0; i < data_number; ++i) {
247+
in_data_float[i] = i;
248+
}
249+
250+
paddle::framework::TransDataType(kernel_fp32, kernel_bf16, in, &out);
251+
ptr = out.data<paddle::platform::bfloat16>();
252+
for (int i = 0; i < data_number; ++i) {
253+
EXPECT_EQ(ptr[i].x,
254+
static_cast<paddle::platform::bfloat16>(in_data_float[i]).x);
255+
}
256+
257+
// transform double to bfloat16
258+
double* in_data_double =
259+
in.mutable_data<double>(paddle::framework::make_ddim({2, 3}), place);
260+
for (int i = 0; i < data_number; ++i) {
261+
in_data_double[i] = i;
262+
}
263+
264+
paddle::framework::TransDataType(kernel_fp64, kernel_bf16, in, &out);
265+
ptr = out.data<paddle::platform::bfloat16>();
266+
for (int i = 0; i < data_number; ++i) {
267+
EXPECT_EQ(ptr[i].x,
268+
static_cast<paddle::platform::bfloat16>(in_data_double[i]).x);
269+
}
270+
271+
// transform int to bfloat16
272+
int* in_data_int =
273+
in.mutable_data<int>(paddle::framework::make_ddim({2, 3}), place);
274+
for (int i = 0; i < data_number; ++i) {
275+
in_data_int[i] = i;
276+
}
277+
278+
paddle::framework::TransDataType(kernel_int32, kernel_bf16, in, &out);
279+
ptr = out.data<paddle::platform::bfloat16>();
280+
for (int i = 0; i < data_number; ++i) {
281+
EXPECT_EQ(ptr[i].x,
282+
static_cast<paddle::platform::bfloat16>(in_data_int[i]).x);
283+
}
284+
285+
// transform int64 to bfloat16
286+
int64_t* in_data_int64 =
287+
in.mutable_data<int64_t>(paddle::framework::make_ddim({2, 3}), place);
288+
for (int i = 0; i < data_number; ++i) {
289+
in_data_int64[i] = i;
290+
}
291+
292+
paddle::framework::TransDataType(kernel_int64, kernel_bf16, in, &out);
293+
ptr = out.data<paddle::platform::bfloat16>();
294+
for (int i = 0; i < data_number; ++i) {
295+
EXPECT_EQ(ptr[i].x,
296+
static_cast<paddle::platform::bfloat16>(in_data_int64[i]).x);
297+
}
298+
299+
// transform bool to bfloat16
300+
bool* in_data_bool =
301+
in.mutable_data<bool>(paddle::framework::make_ddim({2, 3}), place);
302+
for (int i = 0; i < data_number; ++i) {
303+
in_data_bool[i] = i;
304+
}
305+
306+
paddle::framework::TransDataType(kernel_bool, kernel_bf16, in, &out);
307+
ptr = out.data<paddle::platform::bfloat16>();
308+
for (int i = 0; i < data_number; ++i) {
309+
EXPECT_EQ(ptr[i].x,
310+
static_cast<paddle::platform::bfloat16>(in_data_bool[i]).x);
311+
}
312+
}
192313
}

paddle/fluid/framework/details/nan_inf_utils_detail.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
167167
// more detail see: 180 page of
168168
// https://www.openmp.org/wp-content/uploads/OpenMP4.0.0.pdf
169169
#pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in)
170+
#pragma omp declare reduction(+ : paddle::platform::bfloat16 : omp_out += \
171+
omp_in)
170172
#endif
171173

172174
template <typename T>

paddle/fluid/framework/dlpack_tensor.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ template <typename T>
2323
static ::DLDataType GetDLDataTypeCode() {
2424
::DLDataType dtype;
2525
if (std::is_same<T, platform::float16>::value ||
26+
std::is_same<T, platform::bfloat16>::value ||
2627
std::is_floating_point<T>::value) {
2728
dtype.code = kDLFloat;
2829
} else if (std::is_unsigned<T>::value) {

0 commit comments

Comments (0)