@@ -1414,43 +1414,42 @@ static PyObject* tensor_method_set_underline_tensor(TensorObject* self,
14141414 if (self->tensor .is_dense_tensor ()) {
14151415 auto * dst_tensor =
14161416 static_cast <phi::DenseTensor*>(self->tensor .impl ().get ());
1417- // if (self->tensor.has_allocation() && self->tensor.initialized() &&
1418- // !dst_tensor->meta().is_contiguous() ||
1419- // !src_tensor->meta().is_contiguous()) {
1420- // VLOG(8) << "set_tensor() method , src or dst tensor is not
1421- // contiguous"; if (!FLAGS_use_stride_kernel) {
1422- // PADDLE_THROW(common::errors::Fatal(
1423- // "FLAGS_use_stride_kernel is closed. Strided kernel "
1424- // "be called, something wrong has happened!"));
1425- // }
1426- // PD_VISIT_ALL_TYPES(
1427- // src_tensor->dtype(), "StridedTensorCopy", ([&] {
1428- // phi::StridedTensorCopy<data_t>(
1429- // *src_tensor,
1430- // common::vectorize<int64_t>(dst_tensor->dims()),
1431- // common::vectorize<int64_t>(dst_tensor->strides()),
1432- // dst_tensor->offset(),
1433- // dst_tensor);
1434- // }));
1435- // } else {
1436- if (dst_tensor->place ().GetType () != phi::AllocationType::UNDEFINED) {
1437- framework::TensorCopy (*src_tensor, dst_tensor->place (), dst_tensor);
1438- } else if (src_tensor->place ().GetType () !=
1439- phi::AllocationType::UNDEFINED) {
1440- framework::TensorCopy (*src_tensor, src_tensor->place (), dst_tensor);
1417+ if (self->tensor .has_allocation () && self->tensor .initialized () &&
1418+ (!dst_tensor->meta ().is_contiguous () ||
1419+ !src_tensor->meta ().is_contiguous ())) {
1420+ VLOG (8 ) << " set_tensor() method , src or dst tensor is not contiguous " ;
1421+ if (!FLAGS_use_stride_kernel) {
1422+ PADDLE_THROW (common::errors::Fatal (
1423+ " FLAGS_use_stride_kernel is closed. Strided kernel "
1424+ " be called, something wrong has happened!" ));
1425+ }
1426+ PD_VISIT_ALL_TYPES (
1427+ src_tensor->dtype (), " StridedTensorCopy" , ([&] {
1428+ phi::StridedTensorCopy<data_t >(
1429+ *src_tensor,
1430+ common::vectorize<int64_t >(dst_tensor->dims ()),
1431+ common::vectorize<int64_t >(dst_tensor->strides ()),
1432+ dst_tensor->offset (),
1433+ dst_tensor);
1434+ }));
14411435 } else {
1442- PADDLE_THROW (common::errors::Unavailable (
1443- " The `set_tensor()` method of (Dist)Tensor get a src value with "
1444- " undefined place" ));
1436+ if (dst_tensor->place ().GetType () != phi::AllocationType::UNDEFINED) {
1437+ framework::TensorCopy (*src_tensor, dst_tensor->place (), dst_tensor);
1438+ } else if (src_tensor->place ().GetType () !=
1439+ phi::AllocationType::UNDEFINED) {
1440+ framework::TensorCopy (*src_tensor, src_tensor->place (), dst_tensor);
1441+ } else {
1442+ PADDLE_THROW (common::errors::Unavailable (
1443+ " The `set_tensor()` method of (Dist)Tensor get a src value with "
1444+ " undefined place" ));
1445+ }
14451446 }
1446- // }
14471447
14481448 } else {
14491449 PADDLE_THROW (common::errors::Unavailable (
14501450 " The `set_tensor()` method of non DenseTensor get a DenseTensor src "
14511451 " value" ));
14521452 }
1453-
14541453 } else if (value.is_dist_tensor ()) {
14551454#ifdef PADDLE_WITH_DISTRIBUTE
14561455 auto * src_tensor =
@@ -1484,7 +1483,6 @@ static PyObject* tensor_method_set_underline_tensor(TensorObject* self,
14841483 " current PaddlePaddle, please recompile and installPaddlePaddle "
14851484 " with the option of `WITH_DISTRIBUTE=ON`." ));
14861485#endif
1487-
14881486 } else {
14891487 PADDLE_THROW (common::errors::Unavailable (
14901488 " The `set_tensor()` method of (Dist)Tensor get a non "
0 commit comments