Skip to content

Commit

Permalink
Fixed issues with place
Browse files Browse the repository at this point in the history
  • Loading branch information
jim19930609 committed Dec 28, 2021
1 parent 3298585 commit 235503f
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
10 changes: 1 addition & 9 deletions paddle/pten/core/dense_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,22 +265,14 @@ const paddle::platform::Place& DenseTensor::place() const {
storage_,
paddle::platform::errors::PreconditionNotMet(
"Tensor not initialized yet when Tensor::place() is called."));
PADDLE_ENFORCE_NOT_NULL(
storage_->data_shared(),
paddle::platform::errors::PreconditionNotMet(
"Tensor not initialized yet when Tensor::place() is called."));
return storage_->data_shared()->place();
return storage_->place();
}

paddle::framework::proto::VarType::Type DenseTensor::type() const {
PADDLE_ENFORCE_NOT_NULL(
storage_,
paddle::platform::errors::PreconditionNotMet(
"Tensor not initialized yet when Tensor::type() is called."));
PADDLE_ENFORCE_NOT_NULL(
storage_->data_shared(),
paddle::platform::errors::PreconditionNotMet(
"Tensor not initialized yet when Tensor::type() is called."));
return TransToProtoVarType(meta_.dtype);
}

Expand Down
9 changes: 9 additions & 0 deletions paddle/pten/core/dense_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,15 @@ class DenseTensor : public TensorBase,

dnnl::memory::format_tag format_ = dnnl::memory::format_tag::undef;
#endif

/* ------------------------------ */
/* From framework::LoDTensor */
/* ------------------------------ */
/* The following members & interfaces were copied from framework::Tensor,
so as to facilitate the unification of different Tensors
Will be adjusted/removed/moved in the near future
*/
};

} // namespace pten
3 changes: 1 addition & 2 deletions paddle/pten/tests/core/test_dense_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ TEST(dense_tensor, meta) {

TEST(dense_tensor, def_ctor) {
DenseTensor tensor_0;
CHECK(!tensor_0.valid());
CHECK(tensor_0.valid());
}

TEST(dense_tensor, ctor) {
Expand Down Expand Up @@ -97,7 +97,6 @@ TEST(dense_tensor, ctor) {
check_dense_tensor(tensor_0, meta);

DenseTensor tensor_2(make_intrusive<TensorStorage>(alloc), meta);
CHECK(tensor_2.data<int8_t>() == nullptr);
CHECK_NOTNULL(tensor_2.mutable_data<int8_t>());
check_dense_tensor(tensor_2, meta);
}
Expand Down

0 comments on commit 235503f

Please sign in to comment.