@@ -24,11 +24,6 @@ TEST(DataTypeTransform, CPUTransform) {
2424 paddle::framework::DataLayout::kAnyLayout ,
2525 paddle::framework::LibraryType::kPlain );
2626
27- auto kernel_bf16 = paddle::framework::OpKernelType (
28- paddle::framework::proto::VarType::BF16, place,
29- paddle::framework::DataLayout::kAnyLayout ,
30- paddle::framework::LibraryType::kPlain );
31-
3227 auto kernel_fp32 = paddle::framework::OpKernelType (
3328 paddle::framework::proto::VarType::FP32, place,
3429 paddle::framework::DataLayout::kAnyLayout ,
@@ -194,120 +189,4 @@ TEST(DataTypeTransform, CPUTransform) {
194189 static_cast <paddle::platform::float16>(in_data_bool[i]).x );
195190 }
196191 }
197-
198- // data type transform from/to bfloat16
199- {
200- paddle::framework::Tensor in;
201- paddle::framework::Tensor out;
202-
203- paddle::platform::bfloat16* ptr =
204- in.mutable_data <paddle::platform::bfloat16>(
205- paddle::framework::make_ddim ({2 , 3 }), place);
206- int data_number = 2 * 3 ;
207-
208- for (int i = 0 ; i < data_number; ++i) {
209- ptr[i] = i;
210- }
211-
212- // transform from bfloat16 to other data types
213- paddle::framework::TransDataType (kernel_bf16, kernel_fp32, in, &out);
214- float * out_data_float = out.data <float >();
215- for (int i = 0 ; i < data_number; ++i) {
216- EXPECT_EQ (out_data_float[i], static_cast <float >(ptr[i]));
217- }
218-
219- paddle::framework::TransDataType (kernel_bf16, kernel_fp64, in, &out);
220- double * out_data_double = out.data <double >();
221- for (int i = 0 ; i < data_number; ++i) {
222- EXPECT_EQ (out_data_double[i], static_cast <double >(ptr[i]));
223- }
224-
225- paddle::framework::TransDataType (kernel_bf16, kernel_int32, in, &out);
226- int * out_data_int = out.data <int >();
227- for (int i = 0 ; i < data_number; ++i) {
228- EXPECT_EQ (out_data_int[i], static_cast <int >(ptr[i]));
229- }
230-
231- paddle::framework::TransDataType (kernel_bf16, kernel_int64, in, &out);
232- int64_t * out_data_int64 = out.data <int64_t >();
233- for (int i = 0 ; i < data_number; ++i) {
234- EXPECT_EQ (out_data_int64[i], static_cast <int64_t >(ptr[i]));
235- }
236-
237- paddle::framework::TransDataType (kernel_bf16, kernel_bool, in, &out);
238- bool * out_data_bool = out.data <bool >();
239- for (int i = 0 ; i < data_number; ++i) {
240- EXPECT_EQ (out_data_bool[i], static_cast <bool >(ptr[i]));
241- }
242-
243- // transform float to bfloat16
244- float * in_data_float =
245- in.mutable_data <float >(paddle::framework::make_ddim ({2 , 3 }), place);
246- for (int i = 0 ; i < data_number; ++i) {
247- in_data_float[i] = i;
248- }
249-
250- paddle::framework::TransDataType (kernel_fp32, kernel_bf16, in, &out);
251- ptr = out.data <paddle::platform::bfloat16>();
252- for (int i = 0 ; i < data_number; ++i) {
253- EXPECT_EQ (ptr[i].x ,
254- static_cast <paddle::platform::bfloat16>(in_data_float[i]).x );
255- }
256-
257- // transform double to bfloat16
258- double * in_data_double =
259- in.mutable_data <double >(paddle::framework::make_ddim ({2 , 3 }), place);
260- for (int i = 0 ; i < data_number; ++i) {
261- in_data_double[i] = i;
262- }
263-
264- paddle::framework::TransDataType (kernel_fp64, kernel_bf16, in, &out);
265- ptr = out.data <paddle::platform::bfloat16>();
266- for (int i = 0 ; i < data_number; ++i) {
267- EXPECT_EQ (ptr[i].x ,
268- static_cast <paddle::platform::bfloat16>(in_data_double[i]).x );
269- }
270-
271- // transform int to bfloat16
272- int * in_data_int =
273- in.mutable_data <int >(paddle::framework::make_ddim ({2 , 3 }), place);
274- for (int i = 0 ; i < data_number; ++i) {
275- in_data_int[i] = i;
276- }
277-
278- paddle::framework::TransDataType (kernel_int32, kernel_bf16, in, &out);
279- ptr = out.data <paddle::platform::bfloat16>();
280- for (int i = 0 ; i < data_number; ++i) {
281- EXPECT_EQ (ptr[i].x ,
282- static_cast <paddle::platform::bfloat16>(in_data_int[i]).x );
283- }
284-
285- // transform int64 to bfloat16
286- int64_t * in_data_int64 =
287- in.mutable_data <int64_t >(paddle::framework::make_ddim ({2 , 3 }), place);
288- for (int i = 0 ; i < data_number; ++i) {
289- in_data_int64[i] = i;
290- }
291-
292- paddle::framework::TransDataType (kernel_int64, kernel_bf16, in, &out);
293- ptr = out.data <paddle::platform::bfloat16>();
294- for (int i = 0 ; i < data_number; ++i) {
295- EXPECT_EQ (ptr[i].x ,
296- static_cast <paddle::platform::bfloat16>(in_data_int64[i]).x );
297- }
298-
299- // transform bool to bfloat16
300- bool * in_data_bool =
301- in.mutable_data <bool >(paddle::framework::make_ddim ({2 , 3 }), place);
302- for (int i = 0 ; i < data_number; ++i) {
303- in_data_bool[i] = i;
304- }
305-
306- paddle::framework::TransDataType (kernel_bool, kernel_bf16, in, &out);
307- ptr = out.data <paddle::platform::bfloat16>();
308- for (int i = 0 ; i < data_number; ++i) {
309- EXPECT_EQ (ptr[i].x ,
310- static_cast <paddle::platform::bfloat16>(in_data_bool[i]).x );
311- }
312- }
313192}
0 commit comments