Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions paddle/fluid/pybind/eager_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ typedef SSIZE_T ssize_t;
#include "paddle/phi/core/memory/memcpy.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/core/vocab/string_array.h"
#include "pybind11/detail/internals.h"
#include "pybind11/numpy.h"
Expand All @@ -65,6 +66,7 @@ typedef SSIZE_T ssize_t;
#include "paddle/phi/core/memory/allocation/mmap_allocator.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/strided_utils.h"
#include "paddle/utils/pybind.h"

COMMON_DECLARE_bool(set_to_1d);
Expand Down Expand Up @@ -1413,15 +1415,34 @@ static PyObject* tensor_method_set_underline_tensor(TensorObject* self,
if (self->tensor.is_dense_tensor()) {
auto* dst_tensor =
static_cast<phi::DenseTensor*>(self->tensor.impl().get());
if (dst_tensor->place().GetType() != phi::AllocationType::UNDEFINED) {
framework::TensorCopy(*src_tensor, dst_tensor->place(), dst_tensor);
} else if (src_tensor->place().GetType() !=
phi::AllocationType::UNDEFINED) {
framework::TensorCopy(*src_tensor, src_tensor->place(), dst_tensor);
if (!dst_tensor->meta().is_contiguous() ||
!src_tensor->meta().is_contiguous()) {
VLOG(8) << "set_tensor() method , src or dst tensor is not contiguous";
if (!FLAGS_use_stride_kernel) {
PADDLE_THROW(common::errors::Fatal(
"FLAGS_use_stride_kernel is closed. Strided kernel "
"be called, something wrong has happened!"));
}
PD_VISIT_ALL_TYPES(
src_tensor->dtype(), "StridedTensorCopy", ([&] {
phi::StridedTensorCopy<data_t>(
*src_tensor,
common::vectorize<int64_t>(dst_tensor->dims()),
common::vectorize<int64_t>(dst_tensor->strides()),
dst_tensor->offset(),
dst_tensor);
}));
} else {
PADDLE_THROW(common::errors::Unavailable(
"The `set_tensor()` method of (Dist)Tensor get a src value with "
"undefined place"));
if (dst_tensor->place().GetType() != phi::AllocationType::UNDEFINED) {
framework::TensorCopy(*src_tensor, dst_tensor->place(), dst_tensor);
} else if (src_tensor->place().GetType() !=
phi::AllocationType::UNDEFINED) {
framework::TensorCopy(*src_tensor, src_tensor->place(), dst_tensor);
} else {
PADDLE_THROW(common::errors::Unavailable(
"The `set_tensor()` method of (Dist)Tensor get a src value with "
"undefined place"));
}
}

} else {
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/api/lib/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ DataType Tensor::type() const { return impl_->dtype(); }
phi::DataLayout Tensor::layout() const { return impl_->layout(); }

bool Tensor::is_dense_tensor() const {
if (impl_ == nullptr) {
return false;
}
return phi::DenseTensor::classof(impl_.get());
}
bool Tensor::is_dist_tensor() const {
Expand Down
63 changes: 51 additions & 12 deletions paddle/phi/api/lib/tensor_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,27 @@ limitations under the License. */
#include "paddle/phi/api/include/tensor.h"

#include "glog/logging.h"

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/tensor_base.h"
#include "paddle/common/flags.h"

#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/funcs/strided_utils.h"
// clang-format off
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#endif

COMMON_DECLARE_bool(use_stride_kernel);
namespace paddle {
namespace experimental {
// declare cast api
Expand Down Expand Up @@ -194,17 +198,52 @@ void Tensor::copy_(const Tensor &src,
return;
}
#endif
SetKernelOutput(this);
phi::MetaTensor meta_out(impl_.get());
phi::UnchangedInferMeta(
if(is_dense_tensor() && has_allocation() && src.is_dense_tensor()) {
auto dst_tensor = static_cast<phi::DenseTensor*>(impl_.get());
auto src_tensor = std::static_pointer_cast<phi::DenseTensor>(src.impl_);
if(!dst_tensor->meta().is_contiguous() ||
!src_tensor->meta().is_contiguous()) {
VLOG(8) << "Tensor::copy_ , src or dst tesnor is not contiguous";
if (!FLAGS_use_stride_kernel) {
PADDLE_THROW(common::errors::Fatal(
"FLAGS_use_stride_kernel is closed. Strided kernel "
"be called, something wrong has happened!"));
}
PD_VISIT_ALL_TYPES(src_tensor->dtype(), "StridedTensorCopy", ([&] {
phi::StridedTensorCopy<data_t>(
*src_tensor,
common::vectorize<int64_t>(dst_tensor->dims()),
common::vectorize<int64_t>(dst_tensor->strides()),
dst_tensor->offset(),
dst_tensor);
}));
} else {
SetKernelOutput(this);
phi::MetaTensor meta_out(impl_.get());
phi::UnchangedInferMeta(
MakeMetaTensor(
*(std::static_pointer_cast<phi::DenseTensor>(src.impl_))),
&meta_out);
phi::Copy(*dev_ctx,
(*(std::static_pointer_cast<phi::DenseTensor>(src.impl_))),
target_place,
blocking,
static_cast<phi::DenseTensor *>(impl_.get()));
}
} else {
SetKernelOutput(this);
phi::MetaTensor meta_out(impl_.get());
phi::UnchangedInferMeta(
MakeMetaTensor(
*(std::static_pointer_cast<phi::DenseTensor>(src.impl_))),
&meta_out);
phi::Copy(*dev_ctx,
(*(std::static_pointer_cast<phi::DenseTensor>(src.impl_))),
target_place,
blocking,
static_cast<phi::DenseTensor *>(impl_.get()));
phi::Copy(*dev_ctx,
(*(std::static_pointer_cast<phi::DenseTensor>(src.impl_))),
target_place,
blocking,
static_cast<phi::DenseTensor *>(impl_.get()));
}

} else if (kernel_type == KernelType::SELECTED_ROWS_KERNEL) {
SetSelectedRowsKernelOutput(this);
phi::MetaTensor meta_out(impl_.get());
Expand Down
12 changes: 2 additions & 10 deletions python/paddle/base/dygraph/tensor_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,20 +258,12 @@ def set_value(
self.value().process_mesh,
self.value().placements,
)
if (
isinstance(value, paddle.Tensor)
and value.is_contiguous()
and self.value().is_contiguous()
):
if isinstance(value, paddle.Tensor):
self.value().set_tensor(value)
else:
self.value().get_tensor().set(value.get_tensor())
return
if (
isinstance(value, paddle.Tensor)
and value.is_contiguous()
and self.value().is_contiguous()
):
if isinstance(value, paddle.Tensor):
self.value().set_tensor(value)
else:
self.value().get_tensor().set(
Expand Down
Loading