Skip to content

Commit 94cb963

Browse files
fix
1 parent ce95f84 commit 94cb963

File tree

1 file changed

+28
-30
lines changed

1 file changed

+28
-30
lines changed

paddle/fluid/pybind/eager_method.cc

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,43 +1414,42 @@ static PyObject* tensor_method_set_underline_tensor(TensorObject* self,
14141414
if (self->tensor.is_dense_tensor()) {
14151415
auto* dst_tensor =
14161416
static_cast<phi::DenseTensor*>(self->tensor.impl().get());
1417-
// if (self->tensor.has_allocation() && self->tensor.initialized() &&
1418-
// !dst_tensor->meta().is_contiguous() ||
1419-
// !src_tensor->meta().is_contiguous()) {
1420-
// VLOG(8) << "set_tensor() method , src or dst tensor is not
1421-
// contiguous"; if (!FLAGS_use_stride_kernel) {
1422-
// PADDLE_THROW(common::errors::Fatal(
1423-
// "FLAGS_use_stride_kernel is closed. Strided kernel "
1424-
// "be called, something wrong has happened!"));
1425-
// }
1426-
// PD_VISIT_ALL_TYPES(
1427-
// src_tensor->dtype(), "StridedTensorCopy", ([&] {
1428-
// phi::StridedTensorCopy<data_t>(
1429-
// *src_tensor,
1430-
// common::vectorize<int64_t>(dst_tensor->dims()),
1431-
// common::vectorize<int64_t>(dst_tensor->strides()),
1432-
// dst_tensor->offset(),
1433-
// dst_tensor);
1434-
// }));
1435-
// } else {
1436-
if (dst_tensor->place().GetType() != phi::AllocationType::UNDEFINED) {
1437-
framework::TensorCopy(*src_tensor, dst_tensor->place(), dst_tensor);
1438-
} else if (src_tensor->place().GetType() !=
1439-
phi::AllocationType::UNDEFINED) {
1440-
framework::TensorCopy(*src_tensor, src_tensor->place(), dst_tensor);
1417+
if (self->tensor.has_allocation() && self->tensor.initialized() &&
1418+
(!dst_tensor->meta().is_contiguous() ||
1419+
!src_tensor->meta().is_contiguous())) {
1420+
VLOG(8) << "set_tensor() method , src or dst tensor is not contiguous ";
1421+
if (!FLAGS_use_stride_kernel) {
1422+
PADDLE_THROW(common::errors::Fatal(
1423+
"FLAGS_use_stride_kernel is closed. Strided kernel "
1424+
"be called, something wrong has happened!"));
1425+
}
1426+
PD_VISIT_ALL_TYPES(
1427+
src_tensor->dtype(), "StridedTensorCopy", ([&] {
1428+
phi::StridedTensorCopy<data_t>(
1429+
*src_tensor,
1430+
common::vectorize<int64_t>(dst_tensor->dims()),
1431+
common::vectorize<int64_t>(dst_tensor->strides()),
1432+
dst_tensor->offset(),
1433+
dst_tensor);
1434+
}));
14411435
} else {
1442-
PADDLE_THROW(common::errors::Unavailable(
1443-
"The `set_tensor()` method of (Dist)Tensor get a src value with "
1444-
"undefined place"));
1436+
if (dst_tensor->place().GetType() != phi::AllocationType::UNDEFINED) {
1437+
framework::TensorCopy(*src_tensor, dst_tensor->place(), dst_tensor);
1438+
} else if (src_tensor->place().GetType() !=
1439+
phi::AllocationType::UNDEFINED) {
1440+
framework::TensorCopy(*src_tensor, src_tensor->place(), dst_tensor);
1441+
} else {
1442+
PADDLE_THROW(common::errors::Unavailable(
1443+
"The `set_tensor()` method of (Dist)Tensor get a src value with "
1444+
"undefined place"));
1445+
}
14451446
}
1446-
//}
14471447

14481448
} else {
14491449
PADDLE_THROW(common::errors::Unavailable(
14501450
"The `set_tensor()` method of non DenseTensor get a DenseTensor src "
14511451
"value"));
14521452
}
1453-
14541453
} else if (value.is_dist_tensor()) {
14551454
#ifdef PADDLE_WITH_DISTRIBUTE
14561455
auto* src_tensor =
@@ -1484,7 +1483,6 @@ static PyObject* tensor_method_set_underline_tensor(TensorObject* self,
14841483
"current PaddlePaddle, please recompile and installPaddlePaddle "
14851484
"with the option of `WITH_DISTRIBUTE=ON`."));
14861485
#endif
1487-
14881486
} else {
14891487
PADDLE_THROW(common::errors::Unavailable(
14901488
"The `set_tensor()` method of (Dist)Tensor get a non "

0 commit comments

Comments
 (0)