Skip to content

Commit

Permalink
update the cpu kernels, test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 committed Jan 21, 2022
1 parent 1fd1604 commit 344b58f
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 28 deletions.
12 changes: 7 additions & 5 deletions paddle/pten/api/lib/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,13 @@ Tensor::mutable_data<paddle::platform::float16>();
template <typename T>
T *Tensor::mutable_data(const PlaceType &place) {
auto inner_place = ConvertExtPlaceToInnerPlace(place);
PADDLE_ENFORCE_EQ(
platform::is_same_place(inner_place, impl_->place()),
true,
platform::errors::Unimplemented("Modification of tensor place through "
"mutable_data is not supported now"));
if (impl_->initialized()) {
PADDLE_ENFORCE_EQ(
platform::is_same_place(inner_place, impl_->place()),
true,
platform::errors::Unimplemented("Modification of tensor place through "
"mutable_data is not supported now"));
}
if (is_dense_tensor()) {
return std::dynamic_pointer_cast<pten::DenseTensor>(impl_)->mutable_data<T>(
inner_place);
Expand Down
4 changes: 0 additions & 4 deletions paddle/pten/core/dense_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,6 @@ const paddle::platform::Place& DenseTensor::place() const {
}

// Returns this tensor's meta_.dtype converted by TransToProtoVarType into a
// paddle::framework::proto::VarType::Type value.
paddle::framework::proto::VarType::Type DenseTensor::type() const {
// Guard: enforce that holder_ is non-null before answering; a failure is
// reported as a PreconditionNotMet error carrying the message below.
// NOTE(review): this text comes from a diff hunk whose header reports
// "0 additions & 4 deletions", so these four lines appear to be the ones the
// commit removes -- confirm against the actual file before relying on the
// check being present.
PADDLE_ENFORCE_NOT_NULL(
holder_,
paddle::platform::errors::PreconditionNotMet(
"Tensor not initialized yet when DenseTensor::type() is called."));
// The returned value is derived from meta_.dtype only; holder_ is not read
// by this statement.
return TransToProtoVarType(meta_.dtype);
}

Expand Down
16 changes: 7 additions & 9 deletions paddle/pten/kernels/cpu/copy_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,16 @@ void Copy(const Context& dev_ctx,
DenseTensor* dst) {
auto* src_ptr = src.data();
const auto& src_place = src.place();
const auto& dst_place = dst->place();

VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
<< src_place;

dst->ResizeAndAllocate(src.dims());
auto* dst_ptr = dst->mutable_data(dst_place);
dst->Resize(src.dims());
auto* dst_ptr = dst->mutable_data(src_place);

if (src_ptr == dst_ptr && src_place == dst_place) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
<< dst_place;
<< src_place;
return;
}
VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr;
Expand All @@ -51,9 +50,8 @@ void Copy(const Context& dev_ctx,
auto size = src.numel() *
paddle::framework::SizeOfType(TransToProtoVarType(src.dtype()));

if (paddle::platform::is_cpu_place(src_place) &&
paddle::platform::is_cpu_place(dst_place)) {
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
if (paddle::platform::is_cpu_place(src_place)) {
paddle::memory::Copy(src_place, dst_ptr, src_place, src_ptr, size);
}
}

Expand Down
4 changes: 2 additions & 2 deletions paddle/pten/kernels/reshape_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ void ReshapeKernel(const Context& dev_ctx,
const ScalarArray& shape,
DenseTensor* out) {
auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData());
if (x.data() == out->data() && x.numel() == out->numel()) {
if (x.Holder() == out->Holder()) {
out->ResizeAndAllocate(out_meta.dims);
return;
}
pten::Copy(dev_ctx, x, false, out);
out->ResizeAndAllocate(out_meta.dims);
out->Resize(out_meta.dims);
out->ResetLoD(x.lod());
}

Expand Down
8 changes: 0 additions & 8 deletions paddle/pten/tests/api/test_empty_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,8 @@ TEST(API, empty_like) {
ASSERT_EQ(out.dims().size(), 2);
ASSERT_EQ(out.dims()[0], 3);
ASSERT_EQ(out.numel(), 6);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
}

TEST(API, empty1) {
Expand All @@ -77,10 +75,8 @@ TEST(API, empty1) {
ASSERT_EQ(out.shape().size(), 2UL);
ASSERT_EQ(out.shape()[0], 2);
ASSERT_EQ(out.numel(), 6);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
}

TEST(API, empty2) {
Expand All @@ -104,10 +100,8 @@ TEST(API, empty2) {
ASSERT_EQ(out.shape().size(), 2UL);
ASSERT_EQ(out.shape()[0], 2);
ASSERT_EQ(out.numel(), 4);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
}

TEST(API, empty3) {
Expand All @@ -118,10 +112,8 @@ TEST(API, empty3) {
ASSERT_EQ(out.shape().size(), 2UL);
ASSERT_EQ(out.shape()[0], 2);
ASSERT_EQ(out.numel(), 6);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::INT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
}

} // namespace tests
Expand Down

0 comments on commit 344b58f

Please sign in to comment.