Skip to content

Commit 06b8883

Browse files
【cherry-pick stride】Set value when dstplace != srcplace and one tenosr is not con… (#75891)
* 【stride】Set value when dstplace != srcplace and one tenosr is not contiguous should add check (#75794) * add log * fix bug * fix * delete * fix conflict * fix test * change test
1 parent 256c9c0 commit 06b8883

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

paddle/fluid/pybind/eager_method.cc

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1434,7 +1434,8 @@ static PyObject* tensor_method_set_underline_tensor(TensorObject* self,
14341434
static_cast<phi::DenseTensor*>(self->tensor.impl().get());
14351435
if (self->tensor.has_allocation() && self->tensor.initialized() &&
14361436
(!dst_tensor->meta().is_contiguous() ||
1437-
!src_tensor->meta().is_contiguous())) {
1437+
!src_tensor->meta().is_contiguous()) &&
1438+
dst_tensor->place().GetType() == src_tensor->place().GetType()) {
14381439
VLOG(8) << "set_tensor() method , src or dst tensor is not contiguous ";
14391440
if (!FLAGS_use_stride_kernel) {
14401441
PADDLE_THROW(common::errors::Fatal(
@@ -1451,6 +1452,17 @@ static PyObject* tensor_method_set_underline_tensor(TensorObject* self,
14511452
dst_tensor);
14521453
}));
14531454
} else {
1455+
if (!dst_tensor->meta().is_contiguous()) {
1456+
PADDLE_THROW(common::errors::Fatal(
1457+
"dst_tensor is not contiguous and src_tesnor has different place "
1458+
"with dst_tensor, so Strided kernel "
1459+
"can't be called, please change src_tensor'place as same as "
1460+
"dst_tensor'place or change dst_tensor to be contiguous"));
1461+
} else if (!src_tensor->meta().is_contiguous()) {
1462+
VLOG(6) << "src_tensor is not contiguous, so dst_tensor will be not "
1463+
"contiguous after set_value ";
1464+
}
1465+
14541466
if (dst_tensor->place().GetType() != phi::AllocationType::UNDEFINED) {
14551467
framework::TensorCopy(*src_tensor, dst_tensor->place(), dst_tensor);
14561468
} else if (src_tensor->place().GetType() !=

test/legacy_test/test_set_value_op.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,5 +1788,36 @@ def test_value_input_is_scalar(self):
17881788
np.testing.assert_array_equal(x.grad, expected_x_grad)
17891789

17901790

1791+
@unittest.skipIf(
1792+
not core.is_compiled_with_cuda(),
1793+
"core is not compiled with CUDA",
1794+
)
1795+
class TestSetValueWithStrideError(unittest.TestCase):
1796+
def test_same_place(self):
1797+
x = paddle.ones([5, 10], device=paddle.CUDAPlace(0))
1798+
y = paddle.zeros([10, 5], device=paddle.CUDAPlace(0))
1799+
y.transpose_([1, 0])
1800+
x.set_value(y)
1801+
assert x.is_contiguous()
1802+
1803+
def test_different_place1(self):
1804+
# src place != dst place && src is not contiguous
1805+
x = paddle.ones([5, 10], device=paddle.CUDAPlace(0))
1806+
y = paddle.zeros([10, 5], device=paddle.CPUPlace())
1807+
y.transpose_([1, 0])
1808+
x.set_value(y)
1809+
assert not x.is_contiguous()
1810+
1811+
def test_different_place2(self):
1812+
# src place != dst place && dst is not contiguous
1813+
with self.assertRaises(SystemError):
1814+
x = paddle.ones([5, 4], device=paddle.CUDAPlace(0))
1815+
x.transpose_([1, 0])
1816+
y = paddle.zeros([4, 2], device=paddle.CPUPlace())
1817+
assert not x.is_contiguous()
1818+
1819+
x[:, 1:3].set_value(y)
1820+
1821+
17911822
if __name__ == '__main__':
17921823
unittest.main()

0 commit comments

Comments
 (0)