@@ -24,6 +24,11 @@ TEST(DataTypeTransform, CPUTransform) {
2424 paddle::framework::DataLayout::kAnyLayout ,
2525 paddle::framework::LibraryType::kPlain );
2626
27+ auto kernel_bf16 = paddle::framework::OpKernelType (
28+ paddle::framework::proto::VarType::BF16, place,
29+ paddle::framework::DataLayout::kAnyLayout ,
30+ paddle::framework::LibraryType::kPlain );
31+
2732 auto kernel_fp32 = paddle::framework::OpKernelType (
2833 paddle::framework::proto::VarType::FP32, place,
2934 paddle::framework::DataLayout::kAnyLayout ,
@@ -189,4 +194,120 @@ TEST(DataTypeTransform, CPUTransform) {
189194 static_cast <paddle::platform::float16>(in_data_bool[i]).x );
190195 }
191196 }
197+
198+ // data type transform from/to bfloat16
199+ {
200+ paddle::framework::Tensor in;
201+ paddle::framework::Tensor out;
202+
203+ paddle::platform::bfloat16* ptr =
204+ in.mutable_data <paddle::platform::bfloat16>(
205+ paddle::framework::make_ddim ({2 , 3 }), place);
206+ int data_number = 2 * 3 ;
207+
208+ for (int i = 0 ; i < data_number; ++i) {
209+ ptr[i] = i;
210+ }
211+
212+ // transform from bfloat16 to other data types
213+ paddle::framework::TransDataType (kernel_bf16, kernel_fp32, in, &out);
214+ float * out_data_float = out.data <float >();
215+ for (int i = 0 ; i < data_number; ++i) {
216+ EXPECT_EQ (out_data_float[i], static_cast <float >(ptr[i]));
217+ }
218+
219+ paddle::framework::TransDataType (kernel_bf16, kernel_fp64, in, &out);
220+ double * out_data_double = out.data <double >();
221+ for (int i = 0 ; i < data_number; ++i) {
222+ EXPECT_EQ (out_data_double[i], static_cast <double >(ptr[i]));
223+ }
224+
225+ paddle::framework::TransDataType (kernel_bf16, kernel_int32, in, &out);
226+ int * out_data_int = out.data <int >();
227+ for (int i = 0 ; i < data_number; ++i) {
228+ EXPECT_EQ (out_data_int[i], static_cast <int >(ptr[i]));
229+ }
230+
231+ paddle::framework::TransDataType (kernel_bf16, kernel_int64, in, &out);
232+ int64_t * out_data_int64 = out.data <int64_t >();
233+ for (int i = 0 ; i < data_number; ++i) {
234+ EXPECT_EQ (out_data_int64[i], static_cast <int64_t >(ptr[i]));
235+ }
236+
237+ paddle::framework::TransDataType (kernel_bf16, kernel_bool, in, &out);
238+ bool * out_data_bool = out.data <bool >();
239+ for (int i = 0 ; i < data_number; ++i) {
240+ EXPECT_EQ (out_data_bool[i], static_cast <bool >(ptr[i]));
241+ }
242+
243+ // transform float to bfloat16
244+ float * in_data_float =
245+ in.mutable_data <float >(paddle::framework::make_ddim ({2 , 3 }), place);
246+ for (int i = 0 ; i < data_number; ++i) {
247+ in_data_float[i] = i;
248+ }
249+
250+ paddle::framework::TransDataType (kernel_fp32, kernel_bf16, in, &out);
251+ ptr = out.data <paddle::platform::bfloat16>();
252+ for (int i = 0 ; i < data_number; ++i) {
253+ EXPECT_EQ (ptr[i].x ,
254+ static_cast <paddle::platform::bfloat16>(in_data_float[i]).x );
255+ }
256+
257+ // transform double to bfloat16
258+ double * in_data_double =
259+ in.mutable_data <double >(paddle::framework::make_ddim ({2 , 3 }), place);
260+ for (int i = 0 ; i < data_number; ++i) {
261+ in_data_double[i] = i;
262+ }
263+
264+ paddle::framework::TransDataType (kernel_fp64, kernel_bf16, in, &out);
265+ ptr = out.data <paddle::platform::bfloat16>();
266+ for (int i = 0 ; i < data_number; ++i) {
267+ EXPECT_EQ (ptr[i].x ,
268+ static_cast <paddle::platform::bfloat16>(in_data_double[i]).x );
269+ }
270+
271+ // transform int to bfloat16
272+ int * in_data_int =
273+ in.mutable_data <int >(paddle::framework::make_ddim ({2 , 3 }), place);
274+ for (int i = 0 ; i < data_number; ++i) {
275+ in_data_int[i] = i;
276+ }
277+
278+ paddle::framework::TransDataType (kernel_int32, kernel_bf16, in, &out);
279+ ptr = out.data <paddle::platform::bfloat16>();
280+ for (int i = 0 ; i < data_number; ++i) {
281+ EXPECT_EQ (ptr[i].x ,
282+ static_cast <paddle::platform::bfloat16>(in_data_int[i]).x );
283+ }
284+
285+ // transform int64 to bfloat16
286+ int64_t * in_data_int64 =
287+ in.mutable_data <int64_t >(paddle::framework::make_ddim ({2 , 3 }), place);
288+ for (int i = 0 ; i < data_number; ++i) {
289+ in_data_int64[i] = i;
290+ }
291+
292+ paddle::framework::TransDataType (kernel_int64, kernel_bf16, in, &out);
293+ ptr = out.data <paddle::platform::bfloat16>();
294+ for (int i = 0 ; i < data_number; ++i) {
295+ EXPECT_EQ (ptr[i].x ,
296+ static_cast <paddle::platform::bfloat16>(in_data_int64[i]).x );
297+ }
298+
299+ // transform bool to bfloat16
300+ bool * in_data_bool =
301+ in.mutable_data <bool >(paddle::framework::make_ddim ({2 , 3 }), place);
302+ for (int i = 0 ; i < data_number; ++i) {
303+ in_data_bool[i] = i;
304+ }
305+
306+ paddle::framework::TransDataType (kernel_bool, kernel_bf16, in, &out);
307+ ptr = out.data <paddle::platform::bfloat16>();
308+ for (int i = 0 ; i < data_number; ++i) {
309+ EXPECT_EQ (ptr[i].x ,
310+ static_cast <paddle::platform::bfloat16>(in_data_bool[i]).x );
311+ }
312+ }
192313}
0 commit comments