Skip to content

Commit

Permalink
Perfect unitests (PaddlePaddle#20)
Browse files Browse the repository at this point in the history
* perfect unittest

* update license

* fix bug when run tcmpt_utils_test
  • Loading branch information
YuanRisheng authored Oct 18, 2021
1 parent 37791f7 commit 28a6374
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions paddle/fluid/framework/tcmpt_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,17 @@ TEST(TcmptUtils, VarToPtTensor) {
auto* data =
value->mutable_data<int>(make_ddim({1, 1}), paddle::platform::CPUPlace());
data[0] = 123;
auto tensor_def = pt::TensorArgDef(pt::Backend::kCUDA, pt::DataLayout::kNCHW,
pt::Backend expect_backend = pt::Backend::kCPU;

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
expect_backend = pt::Backend::kCUDA;
#endif
auto tensor_def = pt::TensorArgDef(expect_backend, pt::DataLayout::kNCHW,
pt::DataType::kINT32);
// 2. test API
auto tensor_x = InputVariableToPtTensor(v, tensor_def);
// 3. check result
ASSERT_EQ(tensor_x->backend(), pt::Backend::kCUDA);
ASSERT_EQ(tensor_x->backend(), expect_backend);
ASSERT_EQ(tensor_x->type(), pt::DataType::kINT32);
}

Expand Down

0 comments on commit 28a6374

Please sign in to comment.