diff --git a/python/paddle/distribution/multinomial.py b/python/paddle/distribution/multinomial.py index c57511da5c672..f8e8ab3a0ba68 100644 --- a/python/paddle/distribution/multinomial.py +++ b/python/paddle/distribution/multinomial.py @@ -130,7 +130,12 @@ def log_prob(self, value): logits, value = paddle.broadcast_tensors( [paddle.log(self.probs), value] ) - logits[(value == 0) & (paddle.isinf(logits))] = 0 + if paddle.in_dynamic_mode(): + logits[(value == 0) & (paddle.isinf(logits))] = 0 + else: + logits = paddle.static.setitem( + logits, (value == 0) & (paddle.isinf(logits)), 0 + ) return ( paddle.lgamma(value.sum(-1) + 1) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index ec3f235a18b8e..849805b4965c5 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2295,7 +2295,14 @@ def __getitem__(self, item): return _getitem_impl_(self, item) def __setitem__(self, item, value): - return _setitem_impl_(self, item, value) + from .dygraph.base import in_declarative_mode + + if in_declarative_mode(): + return _setitem_impl_(self, item, value) + else: + raise RuntimeError( + "In static mode, the __setitem__ (looks like: x[indices] = values) should not be used. Please use x = paddle.static.setitem(x, indices, values)" + ) def get_value(self, scope=None): """ diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 4f7c94a5acb85..f9a7ecc274bea 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -215,7 +215,9 @@ def set_item(self, tensor_origin, value): out = paddle.scatter(t_reshape, index_1d, value_1d) if tensor_type is not None: out = out.astype(tensor_type) - tensor_origin[:] = out.reshape(tensor_origin.shape) + tensor_origin = _setitem_impl_( + tensor_origin, ..., out.reshape(tensor_origin.shape) + ) return tensor_origin @@ -617,7 +619,7 @@ def _setitem_for_tensor_array(var, item, value): If item is case (1), we perform paddle.tensor.array_write, in other cases, we raise a NotImplementedError. """ - from ..framework import LayerHelper, core + from .framework import Variable assert ( @@ -632,7 +634,7 @@ def _setitem_for_tensor_array(var, item, value): item = paddle.cast(to_static_variable(item), dtype='int64') value = to_static_variable(value) - array_write(x=value, i=item, array=var) + return array_write(x=value, i=item, array=var) else: raise NotImplementedError( "Only support __setitem__ by Int/Variable in tensor_array, but gets {}".format( @@ -807,17 +809,31 @@ def _setitem_impl_(var, item, value): if paddle.in_dynamic_mode(): var._bump_inplace_version() + output = var + else: + helper = paddle.fluid.layer_helper.LayerHelper('set_value', **locals()) + output = helper.create_variable_for_type_inference(dtype=var.dtype) cur_block = default_main_program().current_block() cur_block.append_op( type="set_value", inputs=inputs, - outputs={'Out': var}, + outputs={'Out': output}, attrs=attrs, inplace_map={"Input": "Out"}, ) - return var + if not paddle.in_dynamic_mode(): + # map var to the new output + from paddle.jit.dy2static.program_translator import ( + ProgramTranslator, + ) + + ProgramTranslator.get_instance()._params_map.add( + cur_block.program, var.desc.id(), output + ) + + return output # the item is a tensor of bool @@ -848,11 +864,19 @@ def idx_not_empty(var, item, value): gather_val = gather_nd(var, idx) gather_val_new = value - gather_val out = scatter_nd_add(var, idx, gather_val_new) - var[:] = out + var = _setitem_impl_(var, ..., out) + return var + + def idx_is_empty(var): + return var from paddle.static.nn import cond # If all the bool index is False, just do nothing - cond(item.any(), lambda: idx_not_empty(var, item, value)) + var = cond( + item.any(), + lambda: idx_not_empty(var, item, value), + lambda: idx_is_empty(var), + ) return var diff --git a/python/paddle/incubate/optimizer/functional/lbfgs.py b/python/paddle/incubate/optimizer/functional/lbfgs.py index fe27d123efee9..06d8ba748c018 100644 --- a/python/paddle/incubate/optimizer/functional/lbfgs.py +++ b/python/paddle/incubate/optimizer/functional/lbfgs.py @@ -178,16 +178,23 @@ def body( shape=[], fill_value=(head - 1).mod(history_size), dtype='int64' ) - def cond(i, q): + def cond(i, q, ai_vec): return i != tail - def body(i, q): - ai_vec[i] = rhok_vec[i] * paddle.dot(sk_vec[i], q) + def body(i, q, ai_vec): + if paddle.in_dynamic_mode(): + ai_vec[i] = rhok_vec[i] * paddle.dot(sk_vec[i], q) + else: + ai_vec = paddle.static.setitem( + ai_vec, i, rhok_vec[i] * paddle.dot(sk_vec[i], q) + ) q = q - ai_vec[i] * yk_vec[i] i = (i - 1).mod(history_size) - return i, q + return i, q, ai_vec - paddle.static.nn.while_loop(cond=cond, body=body, loop_vars=[i, q]) + paddle.static.nn.while_loop( + cond=cond, body=body, loop_vars=[i, q, ai_vec] + ) r = paddle.matmul(H0, q) @@ -234,10 +241,14 @@ def body(i, r): lambda: paddle.full(shape=[1], fill_value=1000.0, dtype=dtype), lambda: 1.0 / rhok_inv, ) - - sk_vec[head] = sk - yk_vec[head] = yk - rhok_vec[head] = rhok + if paddle.in_dynamic_mode(): + sk_vec[head] = sk + yk_vec[head] = yk + rhok_vec[head] = rhok + else: + sk_vec = paddle.static.setitem(sk_vec, head, sk) + yk_vec = paddle.static.setitem(yk_vec, head, yk) + rhok_vec = paddle.static.setitem(rhok_vec, head, rhok) head = (head + 1) % history_size def true_fn(tail): diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index 39029708320a3..b1823785cda5c 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -20,7 +20,7 @@ _convert_into_variable, in_declarative_mode, ) -from paddle.fluid.framework import Variable, core +from paddle.fluid.framework import Variable, core, default_main_program from paddle.fluid.layers import control_flow from paddle.fluid.layers.control_flow import while_loop @@ -48,6 +48,19 @@ def convert_load(x): TODO:(@xiongkun) may run convert_load in dygraph mode, which should be fixed. """ return _convert_into_variable(x) + + # get the new output of the var + if in_declarative_mode() and isinstance(x, Variable): + cur_block = default_main_program().current_block() + + from paddle.jit.dy2static.program_translator import ProgramTranslator + + new_var = ProgramTranslator.get_instance()._params_map.get( + cur_block.program, x.desc.id() + ) + if new_var is not None: + return new_var + return x diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index c0ea50e1e263f..db612474c87d2 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -1125,6 +1125,36 @@ def _program_hash(self, program): return id(program) +class ParametersMap: + def __init__(self): + self.params_dict = {} + + @synchronized + def add(self, program, id, param): + """use the default_program as key, append param the parameter list.""" + key = self._program_hash(program) + if key not in self.params_dict: + self.params_dict[key] = {} + + params = self.params_dict[key] + params[id] = param + + def get(self, program, id): + params = self.params_dict.get(self._program_hash(program)) + if params is None: + return None + if id in params.keys(): + return params[id] + return None + + def _program_hash(self, program): + """ + because program is not deleted while calling from_func_spec. + so it's ok to use id(program) + """ + return id(program) + + class FallbackProgramLayer: __slots__ = [ '_instance', @@ -1386,6 +1416,7 @@ def __init__(self): self._initialized = True self._program_cache = ProgramCache() self._params_recorder = ParametersRecorder() + self._params_map = ParametersMap() self.enable_to_static = True def enable(self, enable_to_static): diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index 084579a58e591..94b91db295cb9 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -37,6 +37,7 @@ from ..fluid import Scope # noqa: F401 from .input import data # noqa: F401 from .input import InputSpec # noqa: F401 +from .input import setitem # noqa: F401 from ..tensor.creation import create_parameter # noqa: F401 from ..tensor.creation import create_global_var # noqa: F401 diff --git a/python/paddle/static/input.py b/python/paddle/static/input.py index af281d6394715..ab8f80c8879aa 100644 --- a/python/paddle/static/input.py +++ b/python/paddle/static/input.py @@ -18,6 +18,8 @@ from paddle.fluid.framework import convert_np_dtype_to_dtype_, static_only from paddle.fluid.layer_helper import LayerHelper +from ..fluid.variable_index import _setitem_impl_ + __all__ = [] @@ -342,3 +344,28 @@ def __eq__(self, other): def __ne__(self, other): return not self == other + + +def setitem(x, index, value): + """ + x(Tensor): input Tensor. + index(Scalar|Tuple|List|Tensor): Where should be set value. + value(Scalar|Tensor): The value which is going to be set. + + [How to write index?] + 1. ':' -> slice(), + (1) a[:]=v -> setitem(a, slice(None,None,None), v) + (2) a[1::2] -> setitem(a, slice(1,None,2), v) + + 2. if there are multiple indexes for axes, use TUPLE (Not LIST) to pack them. + (1) a[1, 2]=v -> setitem(a, (1, 2), v) + (2) a[[1,2],[2,3]]=v -> setitem(a, ([1,2],[2,3]), v) + (3) a[1,:, 3] = v -> setitem(a, (1, slice(None,None,None),3), v) + (4) a[1, ..., 2]=v -> setitem(a, (1, ..., 2), v) + + 3. You can always use TUPLE as index input, even there is only one index. + (1) a[Tensor([10,10])]=v -> setitem(a, (Tensor([10,10]),), v) + (2) a[1] = v -> setitem(a, (1,), v) + """ + + return _setitem_impl_(x, index, value) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 613ba5f84eaf2..42b8d02ead727 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5788,11 +5788,26 @@ def vander(x, n=None, increasing=False, name=None): res = paddle.empty([x.shape[0], n], dtype=x.dtype) - if n > 0: - res[:, 0] = paddle.to_tensor([1], dtype=x.dtype) - if n > 1: - res[:, 1:] = x[:, None] - res[:, 1:] = paddle.cumprod(res[:, 1:], dim=-1) + if paddle.in_dynamic_mode(): + if n > 0: + res[:, 0] = paddle.to_tensor([1], dtype=x.dtype) + if n > 1: + res[:, 1:] = x[:, None] + res[:, 1:] = paddle.cumprod(res[:, 1:], dim=-1) + else: + if n > 0: + res = paddle.static.setitem( + res, (slice(None), 0), paddle.to_tensor([1], dtype=x.dtype) + ) + if n > 1: + res = paddle.static.setitem( + res, (slice(None), slice(1, None)), x[:, None] + ) + res = paddle.static.setitem( + res, + (slice(None), slice(1, None)), + paddle.cumprod(res[:, 1:], dim=-1), + ) res = res[:, ::-1] if not increasing else res return res diff --git a/python/paddle/vision/transforms/functional_tensor.py b/python/paddle/vision/transforms/functional_tensor.py index e2b7f3cc734d1..e7b57a011baeb 100644 --- a/python/paddle/vision/transforms/functional_tensor.py +++ b/python/paddle/vision/transforms/functional_tensor.py @@ -222,12 +222,12 @@ def _affine_grid(theta, w, h, ow, oh): base_grid = paddle.ones((1, oh, ow, 3), dtype=theta.dtype) x_grid = paddle.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, ow) - base_grid[..., 0] = x_grid if paddle.in_dynamic_mode(): y_grid = paddle.linspace( -oh * 0.5 + d, oh * 0.5 + d - 1, oh ).unsqueeze_(-1) + base_grid[..., 0] = x_grid base_grid[..., 1] = y_grid tmp = paddle.to_tensor([0.5 * w, 0.5 * h]) else: @@ -236,7 +236,8 @@ def _affine_grid(theta, w, h, ow, oh): y_grid = paddle.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, oh).unsqueeze( -1 ) - base_grid[..., 1] = y_grid + base_grid = paddle.static.setitem(base_grid, (..., 0), x_grid) + base_grid = paddle.static.setitem(base_grid, (..., 1), y_grid) tmp = paddle.assign(np.array([0.5 * w, 0.5 * h], dtype="float32")) scaled_theta = theta.transpose((0, 2, 1)) / tmp @@ -397,6 +398,17 @@ def rotate( 0.0, ] matrix = paddle.to_tensor(matrix, place=img.place) + + matrix[2] += ( + matrix[0] * (-rotn_center[0] - post_trans[0]) + + matrix[1] * (-rotn_center[1] - post_trans[1]) + + rotn_center[0] + ) + matrix[5] += ( + matrix[3] * (-rotn_center[0] - post_trans[0]) + + matrix[4] * (-rotn_center[1] - post_trans[1]) + + rotn_center[1] + ) else: angle = angle / 180 * math.pi matrix = paddle.concat( @@ -409,16 +421,22 @@ def rotate( paddle.zeros([1]), ] ) - - matrix[2] += matrix[0] * (-rotn_center[0] - post_trans[0]) + matrix[1] * ( - -rotn_center[1] - post_trans[1] - ) - matrix[5] += matrix[3] * (-rotn_center[0] - post_trans[0]) + matrix[4] * ( - -rotn_center[1] - post_trans[1] - ) - - matrix[2] += rotn_center[0] - matrix[5] += rotn_center[1] + matrix = paddle.static.setitem( + matrix, + 2, + matrix[2] + + matrix[0] * (-rotn_center[0] - post_trans[0]) + + matrix[1] * (-rotn_center[1] - post_trans[1]) + + rotn_center[0], + ) + matrix = paddle.static.setitem( + matrix, + 5, + matrix[5] + + matrix[3] * (-rotn_center[0] - post_trans[0]) + + matrix[4] * (-rotn_center[1] - post_trans[1]) + + rotn_center[1], + ) matrix = matrix.reshape((1, 2, 3)) @@ -621,7 +639,12 @@ def erase(img, i, j, h, w, v, inplace=False): if not inplace: img = img.clone() - img[..., i : i + h, j : j + w] = v + if paddle.in_dynamic_mode(): + img[..., i : i + h, j : j + w] = v + else: + img = paddle.static.setitem( + img, (..., slice(i, i + h), slice(j, j + w)), v + ) return img diff --git a/test/dygraph_to_static/test_setitem.py b/test/dygraph_to_static/test_setitem.py new file mode 100644 index 0000000000000..93b8c5d7936b4 --- /dev/null +++ b/test/dygraph_to_static/test_setitem.py @@ -0,0 +1,182 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle + + +class TestSetItemBase(unittest.TestCase): + def setUp(self) -> None: + pass + + def init_data(self): + paddle.seed(2023) + x = paddle.randn([4, 8, 16, 32]) + x.stop_gradient = False + return x + + def init_func(self): + def foo(x): + y = x + 1 + y[:, 2] = x[:, 2] + 99 + return y + + return foo + + def test_case(self): + func = self.init_func() + dy_res = self.run_dygrah(func) + st_res = self.run_to_static(func) + + for dy_out, st_out in zip(dy_res, st_res): + np.testing.assert_allclose(dy_out.numpy(), st_out.numpy()) + + def run_dygrah(self, func): + x = self.init_data() + y = func(x) + x_grad = paddle.grad(y, x)[0] + return y, x_grad + + def run_to_static(self, func): + func = paddle.jit.to_static(func) + return self.run_dygrah(func) + + +class TestCase1(TestSetItemBase): + def init_func(self): + def foo(x): + y = x + 1 + y[2] = x[2] + 99 # (2, ) + return y + + return foo + + +class TestCase2(TestSetItemBase): + def init_func(self): + def foo(x): + y = x + 1 + y[:] = x[:] + 99 # slice(None,None,None) + return y + + return foo + + +class TestCase3(TestSetItemBase): + def init_func(self): + def foo(x): + y = x + 1 + y[1::2] = x[1::2] + 99 # slice(1,None,2) + return y + + return foo + + +class TestCase4(TestSetItemBase): + def init_func(self): + def foo(x): + y = x + 1 + y[1, 2] = x[1, 2] + 99 # (1, 2) + return y + + return foo + + +class TestCase5(TestSetItemBase): + def init_func(self): + def foo(x): + y = x + 1 + y[[1, 2], [2, 3]] = x[[1, 2], [2, 3]] + 99 # ([1,2],[2,3]) + return y + + return foo + + +class TestCase6(TestSetItemBase): + def init_func(self): + def foo(x): + y = x + 1 + y[1, :, 3] = x[1, :, 3] + 99 # slice(None,None,None),3) + return y + + return foo + + +class TestCase7(TestSetItemBase): + def init_func(self): + def foo(x): + y = x + 1 + y[1, ..., 2] = x[1, ..., 2] + 99 # (1, ..., 2) + return y + + return foo + + +class TestCase8(TestSetItemBase): + def init_func(self): + def foo(x): + y = x + 1 + index = paddle.to_tensor([1, 2], dtype="int64") + y[index] = x[index] + 99 # Tensor([1,2]) + return y + + return foo + + +class TestCase9(TestSetItemBase): + def init_func(self): + def foo(x): + y = x + 1 + one = paddle.to_tensor(1, dtype="int64") + two = paddle.to_tensor(2, dtype="int64") + y[one, :, :, 2] = x[1, :, :, two] + 100 # Tensor(1), Tensor(2) + return y + + return foo + + +class TestCase10(TestSetItemBase): + def init_func(self): + def foo(x): + y = x + 1 + y[..., 4:6] = y[..., 4:6] * 10000 + return y + + return foo + + +class TestCase11(TestSetItemBase): + # Test gradient of value tensor + def init_func(self): + def foo(x, value): + y = x + 1 + y[2, 4] = value + return y + + return foo + + def run_dygrah(self, func): + x = self.init_data() + value = paddle.ones((16, 32)) + value.stop_gradient = False + y = func(x, value) + x_grad, value_grad = paddle.grad(y, [x, value]) + return y, x_grad, value_grad + + +if __name__ == '__main__': + unittest.main() diff --git a/test/dygraph_to_static/test_slice.py b/test/dygraph_to_static/test_slice.py index b4fd5f25d7ca8..59af18b706ce0 100644 --- a/test/dygraph_to_static/test_slice.py +++ b/test/dygraph_to_static/test_slice.py @@ -129,8 +129,12 @@ def run_dygraph_mode(self): return self._run(to_static=False) def _run(self, to_static): - paddle.jit.enable_to_static(to_static) - res = self.dygraph_func(self.input) + func = ( + paddle.jit.to_static(self.dygraph_func) + if to_static + else self.dygraph_func + ) + res = func(self.input) return res.numpy() def run_static_mode(self): diff --git a/test/legacy_test/test_program_converter.py b/test/legacy_test/test_program_converter.py index a9926b8c5c6d6..9a9c49df01b68 100644 --- a/test/legacy_test/test_program_converter.py +++ b/test/legacy_test/test_program_converter.py @@ -81,7 +81,8 @@ def test_int32(self): with paddle.static.program_guard(mp, sp): x = paddle.ones([3, 4], dtype=paddle.int32) patch = np.array([41, 42]).astype(np.int32) - x[:1, :2] = patch + index = (slice(None, 1), slice(None, 2)) + x = paddle.static.setitem(x, index, patch) x_input = np.ones([3, 4], dtype=np.int32) x_output = x_input.copy() @@ -110,10 +111,12 @@ def test_int64(self): patch = np.array( [np.iinfo(np.int64).max, np.iinfo(np.int64).min] ).astype(np.int64) - x[:1, :2] = patch + index = (slice(None, 1), slice(None, 2)) + x = paddle.static.setitem(x, index, patch) x_input = np.ones([3, 4], dtype=np.int64) x_output = x_input.copy() + x_output[:1, :2] = patch self.fetch_list = [x.name] @@ -142,7 +145,8 @@ def test_float32(self): patch = np.array( [np.finfo(np.float32).max, np.finfo(np.float32).min] ).astype(np.float32) - x[:1, :2] = patch + index = (slice(None, 1), slice(None, 2)) + x = paddle.static.setitem(x, index, patch) x_input = np.ones([3, 4], dtype=np.float32) x_output = x_input.copy() @@ -171,7 +175,8 @@ def test_float64(self): patch = np.array( [np.finfo(np.float64).max, np.finfo(np.float64).min] ).astype(np.float64) - x[:1, :2] = patch + index = (slice(None, 1), slice(None, 2)) + x = paddle.static.setitem(x, index, patch) x_input = np.ones([3, 4], dtype=np.float64) x_output = x_input.copy() @@ -200,7 +205,8 @@ def test_float16(self): patch = np.array( [np.finfo(np.float16).max, np.finfo(np.float16).min] ).astype(np.float16) - x[:1, :2] = patch + index = (slice(None, 1), slice(None, 2)) + x = paddle.static.setitem(x, index, patch) x_input = np.ones([3, 4], dtype=np.float16) x_output = x_input.copy() @@ -227,7 +233,8 @@ def test_bool(self): with paddle.static.program_guard(mp, sp): x = paddle.ones([3, 4], dtype=paddle.bool) patch = np.array([True, False]) - x[:1, :2] = patch + index = (slice(None, 1), slice(None, 2)) + x = paddle.static.setitem(x, index, patch) x_input = np.ones([3, 4], dtype=bool) x_output = x_input.copy() @@ -257,7 +264,8 @@ def test_complex64(self): paddle.ones([3, 4], dtype=paddle.float32), ) patch = np.array([42.1 + 42.1j, 42.2 + 42.2j]).astype(np.complex64) - x[:1, :2] = patch + index = (slice(None, 1), slice(None, 2)) + x = paddle.static.setitem(x, index, patch) x_input = (np.ones([3, 4]) + 1j * np.ones([3, 4])).astype(np.complex64) x_output = x_input.copy() @@ -282,7 +290,8 @@ def test_complex128(self): np.finfo(np.float64).min + 1j * np.finfo(np.float64).max, ] ).astype(np.complex128) - x[:1, :2] = patch + index = (slice(None, 1), slice(None, 2)) + x = paddle.static.setitem(x, index, patch) x_input = (np.ones([3, 4]) + 1j * np.ones([3, 4])).astype(np.complex128) x_output = x_input.copy() diff --git a/test/legacy_test/test_set_value_op.py b/test/legacy_test/test_set_value_op.py index c87d5fd4729e5..d57ef686dfaf2 100644 --- a/test/legacy_test/test_set_value_op.py +++ b/test/legacy_test/test_set_value_op.py @@ -46,6 +46,10 @@ def set_dtype(self): def _call_setitem(self, x): x[0, 0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (0, 0), self.value) + return x + def _get_answer(self): self.data[0, 0] = self.value @@ -55,7 +59,7 @@ def _run_static(self): paddle.enable_static() with paddle.static.program_guard(self.program): x = paddle.ones(shape=self.shape, dtype=self.dtype) - self._call_setitem(x) + x = self._call_setitem_static_api(x) exe = paddle.static.Executor(paddle.CPUPlace()) out = exe.run(self.program, fetch_list=[x]) @@ -94,6 +98,10 @@ class TestSetValueItemInt(TestSetValueApi): def _call_setitem(self, x): x[0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, 0, self.value) + return x + def _get_answer(self): self.data[0] = self.value @@ -104,6 +112,10 @@ class TestSetValueItemSlice(TestSetValueApi): def _call_setitem(self, x): x[0:2] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(0, 2), self.value) + return x + def _get_answer(self): self.data[0:2] = self.value @@ -112,6 +124,10 @@ class TestSetValueItemSlice2(TestSetValueApi): def _call_setitem(self, x): x[0:-1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(0, -1), self.value) + return x + def _get_answer(self): self.data[0:-1] = self.value @@ -120,6 +136,10 @@ class TestSetValueItemSlice3(TestSetValueApi): def _call_setitem(self, x): x[0:-1, 0:2] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (slice(0, -1), slice(0, 2)), self.value) + return x + def _get_answer(self): self.data[0:-1, 0:2] = self.value @@ -128,6 +148,12 @@ class TestSetValueItemSlice4(TestSetValueApi): def _call_setitem(self, x): x[0:, 1:2, :] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(0, None), slice(1, 2), slice(None)), self.value + ) + return x + def _get_answer(self): self.data[0:, 1:2, :] = self.value @@ -136,6 +162,12 @@ class TestSetValueItemSlice5(TestSetValueApi): def _call_setitem(self, x): x[0:, 1:1, :] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(0, None), slice(1, 1), slice(None)), self.value + ) + return x + def _get_answer(self): self.data[0:, 1:1, :] = self.value @@ -153,6 +185,19 @@ def body(i, x): i = paddle.zeros(shape=(1,), dtype='int32') i, x = paddle.static.nn.while_loop(cond, body, [i, x]) + def _call_setitem_static_api(self, x): + def cond(i, x): + return i < 1 + + def body(i, x): + x = paddle.static.setitem(x, i, self.value) + i = i + 1 + return i, x + + i = paddle.zeros(shape=(1,), dtype='int32') + i, x = paddle.static.nn.while_loop(cond, body, [i, x]) + return x + def _get_answer(self): self.data[0] = self.value @@ -165,6 +210,10 @@ def set_shape(self): def _call_setitem(self, x): x[0:2:2] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(0, 2, 2), self.value) + return x + def _get_answer(self): self.data[0:2:2] = self.value @@ -176,6 +225,10 @@ def set_shape(self): def _call_setitem(self, x): x[0:-1:3] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(0, -1, 3), self.value) + return x + def _get_answer(self): self.data[0:-1:3] = self.value @@ -184,6 +237,12 @@ class TestSetValueItemSliceStep3(TestSetValueApi): def _call_setitem(self, x): x[0:-1, 0:2, ::2] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(0, -1), slice(0, 2), slice(None, None, 2)), self.value + ) + return x + def _get_answer(self): self.data[0:-1, 0:2, ::2] = self.value @@ -192,6 +251,12 @@ class TestSetValueItemSliceStep4(TestSetValueApi): def _call_setitem(self, x): x[0:, 1:2:2, :] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(0, None), slice(1, 2, 2), slice(None)), self.value + ) + return x + def _get_answer(self): self.data[0:, 1:2:2, :] = self.value @@ -207,6 +272,10 @@ def set_value(self): def _call_setitem(self, x): x[5:2:-1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(5, 2, -1), self.value) + return x + def _get_answer(self): self.data[5:2:-1] = self.value @@ -221,6 +290,10 @@ def set_value(self): def _call_setitem(self, x): x[1::-1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(1, None, -1), self.value) + return x + def _get_answer(self): self.data[1::-1] = self.value @@ -235,6 +308,10 @@ def set_value(self): def _call_setitem(self, x): x[::-1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(None, None, -1), self.value) + return x + def _get_answer(self): self.data[::-1] = self.value @@ -246,6 +323,12 @@ def set_shape(self): def _call_setitem(self, x): x[2:0:-1, 0:2, ::-1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(2, 0, -1), slice(0, 2), slice(None, None, -1)), self.value + ) + return x + def _get_answer(self): self.data[2:0:-1, 0:2, ::-1] = self.value @@ -257,6 +340,12 @@ class TestSetValueItemEllipsis1(TestSetValueApi): def _call_setitem(self, x): x[0:, ..., 1:] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(0, None), ..., slice(1, None)), self.value + ) + return x + def _get_answer(self): self.data[0:, ..., 1:] = self.value @@ -265,6 +354,10 @@ class TestSetValueItemEllipsis2(TestSetValueApi): def _call_setitem(self, x): x[0:, ...] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (slice(0, None), ...), self.value) + return x + def _get_answer(self): self.data[0:, ...] = self.value @@ -273,6 +366,10 @@ class TestSetValueItemEllipsis3(TestSetValueApi): def _call_setitem(self, x): x[..., 1:] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (..., slice(1, None)), self.value) + return x + def _get_answer(self): self.data[..., 1:] = self.value @@ -281,6 +378,10 @@ class TestSetValueItemEllipsis4(TestSetValueApi): def _call_setitem(self, x): x[...] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, ..., self.value) + return x + def _get_answer(self): self.data[...] = self.value @@ -288,49 +389,84 @@ def _get_answer(self): # 1.4 item is Paddle Tensor class TestSetValueItemTensor(TestSetValueApi): def _call_setitem(self, x): - zero = paddle.full([1], 0, dtype="int32") + zero = paddle.full([], 0, dtype="int32") x[zero] = self.value + def _call_setitem_static_api(self, x): + zero = paddle.full([], 0, dtype="int32") + x = paddle.static.setitem(x, zero, self.value) + return x + def _get_answer(self): self.data[0] = self.value class TestSetValueItemTensor2(TestSetValueApi): def _call_setitem(self, x): - zero = paddle.full([1], 0, dtype="int32") - two = paddle.full([1], 2, dtype="int64") + zero = paddle.full([], 0, dtype="int32") + two = paddle.full([], 2, dtype="int64") x[zero:two] = self.value + def _call_setitem_static_api(self, x): + zero = paddle.full([], 0, dtype="int32") + two = paddle.full([], 2, dtype="int64") + x = paddle.static.setitem(x, slice(zero, two), self.value) + return x + def _get_answer(self): self.data[0:2] = self.value class TestSetValueItemTensor3(TestSetValueApi): def _call_setitem(self, x): - zero = paddle.full([1], 0, dtype="int32") - two = paddle.full([1], 2, dtype="int64") + zero = paddle.full([], 0, dtype="int32") + two = paddle.full([], 2, dtype="int64") x[zero:-1, 0:two] = self.value + def _call_setitem_static_api(self, x): + zero = paddle.full([], 0, dtype="int32") + two = paddle.full([], 2, dtype="int64") + x = paddle.static.setitem( + x, (slice(zero, -1), slice(0, two)), self.value + ) + return x + def _get_answer(self): self.data[0:-1, 0:2] = self.value class TestSetValueItemTensor4(TestSetValueApi): def _call_setitem(self, x): - zero = paddle.full([1], 0, dtype="int32") - two = paddle.full([1], 2, dtype="int64") + zero = paddle.full([], 0, dtype="int32") + two = paddle.full([], 2, dtype="int64") x[0:-1, zero:2, 0:6:two] = self.value + def _call_setitem_static_api(self, x): + zero = paddle.full([], 0, dtype="int32") + two = paddle.full([], 2, dtype="int64") + x = paddle.static.setitem( + x, (slice(0, -1), slice(zero, 2), slice(0, 6, two)), self.value + ) + return x + def _get_answer(self): self.data[0:-1, 0:2, ::2] = self.value class TestSetValueItemTensor5(TestSetValueApi): def _call_setitem(self, x): - zero = paddle.full([1], 0, dtype="int32") - two = paddle.full([1], 2, dtype="int64") + zero = paddle.full([], 0, dtype="int32") + two = paddle.full([], 2, dtype="int64") x[zero:, 1:2:two, :] = self.value + def _call_setitem_static_api(self, x): + zero = paddle.full([], 0, dtype="int32") + two = paddle.full([], 2, dtype="int64") + x = paddle.static.setitem( + x, (slice(zero, None), slice(1, 2, two)), self.value + ) + return x + def _get_answer(self): self.data[0:, 1:2:2, :] = self.value @@ -340,10 +476,20 @@ def set_shape(self): self.shape = [3, 4, 5] def _call_setitem(self, x): - minus1 = paddle.full([1], -1, dtype="int32") - zero = paddle.full([1], 0, dtype="int32") + minus1 = paddle.full([], -1, dtype="int32") + zero = paddle.full([], 0, dtype="int32") x[2:zero:minus1, 0:2, 10:-6:minus1] = self.value + def _call_setitem_static_api(self, x): + minus1 = paddle.full([], -1, dtype="int32") + zero = paddle.full([], 0, dtype="int64") + x = paddle.static.setitem( + x, + (slice(2, zero, minus1), slice(0, 2), slice(10, -6, minus1)), + self.value, + ) + return x + def _get_answer(self): self.data[2:0:-1, 0:2, ::-1] = self.value @@ -353,6 +499,10 @@ class TestSetValueItemNone1(TestSetValueApi): def _call_setitem(self, x): x[None] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, None, self.value) + return x + def _get_answer(self): self.data[None] = self.value @@ -361,6 +511,10 @@ class TestSetValueItemNone2(TestSetValueApi): def _call_setitem(self, x): x[0, None, 1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (0, None, 1), self.value) + return x + def _get_answer(self): self.data[0, None, 1] = self.value @@ -369,6 +523,10 @@ class TestSetValueItemNone3(TestSetValueApi): def _call_setitem(self, x): x[:, None, None, 1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (slice(None), None, None, 1), self.value) + return x + def _get_answer(self): self.data[:, None, None, 1] = self.value @@ -377,6 +535,10 @@ class TestSetValueItemNone4(TestSetValueApi): def _call_setitem(self, x): x[0, 0, None, 1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (0, 0, None, 1), self.value) + return x + def _get_answer(self): self.data[0, 0, None, 1] = self.value @@ -385,6 +547,10 @@ class TestSetValueItemNone5(TestSetValueApi): def _call_setitem(self, x): x[0, None, 0, None, 1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (0, None, 0, None, 1), self.value) + return x + def _get_answer(self): self.data[0, None, 0, None, 1] = self.value @@ -393,6 +559,10 @@ class TestSetValueItemNone6(TestSetValueApi): def _call_setitem(self, x): x[None, 0, 0, None, 0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (None, 0, 0, None, 0), self.value) + return x + def _get_answer(self): self.data[None, 0, 0, None, 0] = self.value @@ -401,6 +571,12 @@ class TestSetValueItemNone7(TestSetValueApi): def _call_setitem(self, x): x[:, None, 1] = np.zeros(self.shape)[:, None, 0] + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(None), None, 1), np.zeros(self.shape)[:, None, 0] + ) + return x + def _get_answer(self): self.data[:, None, 1] = np.zeros(self.shape)[:, None, 0] @@ -409,6 +585,12 @@ class TestSetValueItemNone8(TestSetValueApi): def _call_setitem(self, x): x[:, 1, None] = np.zeros(self.shape)[:, 0, None] + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(None), 1, None), np.zeros(self.shape)[:, 0, None] + ) + return x + def _get_answer(self): self.data[:, 1, None] = np.zeros(self.shape)[:, 0, None] @@ -417,6 +599,14 @@ class TestSetValueItemNone9(TestSetValueApi): def _call_setitem(self, x): x[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None] + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, + (None, slice(None), 1, ..., None), + np.zeros(self.shape)[0, 0, :, None], + ) + return x + def _get_answer(self): self.data[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None] @@ -425,15 +615,29 @@ class TestSetValueItemNone10(TestSetValueApi): def _call_setitem(self, x): x[..., None, :, None] = np.zeros(self.shape)[..., None, :, None] + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, + (..., None, slice(None), None), + np.zeros(self.shape)[..., None, :, None], + ) + return x + def _get_answer(self): self.data[..., None, :, None] = np.zeros(self.shape)[..., None, :, None] -# 1.5 item is list or Tensor of bol +# 1.5 item is list or Tensor of bool +# NOTE(zoooo0820): Currently, 1-D List is same to Tuple. +# The semantic of index will be modified later. class TestSetValueItemBool1(TestSetValueApi): def _call_setitem(self, x): x[[True, False]] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, [True, False], self.value) + return x + def _get_answer(self): self.data[[True, False]] = self.value @@ -442,6 +646,10 @@ class TestSetValueItemBool2(TestSetValueApi): def _call_setitem(self, x): x[[False, False]] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, [False, False], self.value) + return x + def _get_answer(self): self.data[[False, False]] = self.value @@ -450,6 +658,10 @@ class TestSetValueItemBool3(TestSetValueApi): def _call_setitem(self, x): x[[False, True]] = np.zeros(self.shape[2]) + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, [False, True], np.zeros(self.shape[2])) + return x + def _get_answer(self): self.data[[False, True]] = np.zeros(self.shape[2]) @@ -459,6 +671,11 @@ def _call_setitem(self, x): idx = paddle.assign(np.array([False, True])) x[idx] = np.zeros(self.shape[2]) + def _call_setitem_static_api(self, x): + idx = paddle.assign(np.array([False, True])) + x = paddle.static.setitem(x, idx, np.zeros(self.shape[2])) + return x + def _get_answer(self): self.data[np.array([False, True])] = np.zeros(self.shape[2]) @@ -470,6 +687,13 @@ def _call_setitem(self, x): ) x[idx] = self.value + def _call_setitem_static_api(self, x): + idx = paddle.assign( + np.array([[False, True, False], [True, True, False]]) + ) + x = paddle.static.setitem(x, idx, self.value) + return x + def _get_answer(self): self.data[ np.array([[False, True, False], [True, True, False]]) @@ -481,6 +705,11 @@ def _call_setitem(self, x): x[0, ...] = 0 x[x > 0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (0, ...), 0) + x = paddle.static.setitem(x, x > 0, self.value) + return x + def _get_answer(self): self.data[0, ...] = 0 self.data[self.data > 0] = self.value @@ -803,9 +1032,14 @@ def set_dtype(self): self.dtype = "int32" def _call_setitem(self, x): - value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype) + value = paddle.full(shape=[], fill_value=3, dtype=self.dtype) x[0, 1] = value + def _call_setitem_static_api(self, x): + value = paddle.full(shape=[], fill_value=3, dtype=self.dtype) + x = paddle.static.setitem(x, (0, 1), value) + return x + def _get_answer(self): self.data[0, 1] = 3 @@ -827,9 +1061,14 @@ def set_dtype(self): self.dtype = "int64" def _call_setitem(self, x): - value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype) + value = paddle.full(shape=[], fill_value=3, dtype=self.dtype) x[0, 1] = value + def _call_setitem_static_api(self, x): + value = paddle.full(shape=[], fill_value=3, dtype=self.dtype) + x = paddle.static.setitem(x, (0, 1), value) + return x + def _get_answer(self): self.data[0, 1] = 3 @@ -851,9 +1090,14 @@ def set_dtype(self): self.dtype = "float32" def _call_setitem(self, x): - value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype) + value = paddle.full(shape=[], fill_value=3, dtype=self.dtype) x[0, 1] = value + def _call_setitem_static_api(self, x): + value = paddle.full(shape=[], fill_value=3, dtype=self.dtype) + x = paddle.static.setitem(x, (0, 1), value) + return x + def _get_answer(self): self.data[0, 1] = 3 @@ -875,9 +1119,14 @@ def set_dtype(self): self.dtype = "float64" def _call_setitem(self, x): - value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype) + value = paddle.full(shape=[], fill_value=3, dtype=self.dtype) x[0, 1] = value + def _call_setitem_static_api(self, x): + value = paddle.full(shape=[], fill_value=3, dtype=self.dtype) + x = paddle.static.setitem(x, (0, 1), value) + return x + def _get_answer(self): self.data[0, 1] = 3 @@ -899,9 +1148,14 @@ def set_dtype(self): self.dtype = "bool" def _call_setitem(self, x): - value = paddle.full(shape=[1], fill_value=False, dtype=self.dtype) + value = paddle.full(shape=[], fill_value=False, dtype=self.dtype) x[0, 1] = value + def _call_setitem_static_api(self, x): + value = paddle.full(shape=[], fill_value=False, dtype=self.dtype) + x = paddle.static.setitem(x, (0, 1), value) + return x + def _get_answer(self): self.data[0, 1] = False @@ -925,6 +1179,10 @@ def set_value(self): def _call_setitem(self, x): x[0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, 0, self.value) + return x + def _get_answer(self): self.data[0] = self.value @@ -936,6 +1194,10 @@ def set_value(self): def _call_setitem(self, x): x[0:1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(0, 1), self.value) + return x + def _get_answer(self): self.data[0:1] = self.value @@ -949,6 +1211,10 @@ def set_value(self): def _call_setitem(self, x): x[0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, 0, self.value) + return x + def _get_answer(self): self.data[0] = self.value @@ -964,6 +1230,10 @@ def set_value(self): def _call_setitem(self, x): x[0] = paddle.assign(self.value) # x is Paddle.Tensor + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, 0, paddle.assign(self.value)) + return x + def _get_answer(self): self.data[0] = self.value @@ -978,6 +1248,12 @@ def set_shape(self): def _call_setitem(self, x): x[:, 0] = paddle.assign(self.value) # x is Paddle.Tensor + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(None), 0), paddle.assign(self.value) + ) + return x + def _get_answer(self): self.data[:, 0] = self.value @@ -997,6 +1273,10 @@ def _call_setitem(self, x): def _get_answer(self): self.data[:, 0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (slice(None), 0), self.value) + return x + def test_api(self): places = ['cpu'] if paddle.is_compiled_with_cuda(): @@ -1030,7 +1310,10 @@ def _value_type_error(self): ): x = paddle.ones(shape=self.shape, dtype=self.dtype) value = [1] - x[0] = value + if paddle.in_dynamic_mode(): + x[0] = value + else: + x = paddle.static.setitem(x, 0, value) def _dtype_error(self): with self.assertRaisesRegex( @@ -1043,7 +1326,10 @@ def _dtype_error(self): def _step_error(self): with self.assertRaisesRegex(ValueError, "step can not be 0"): x = paddle.ones(shape=self.shape, dtype=self.dtype) - x[0:1:0] = self.value + if paddle.in_dynamic_mode(): + x[0:1:0] = self.value + else: + x = paddle.static.setitem(x, slice(0, 1, 0), self.value) def _ellipsis_error(self): with self.assertRaisesRegex( @@ -1059,24 +1345,33 @@ def _ellipsis_error(self): def _bool_list_error(self): with self.assertRaises(TypeError): x = paddle.ones(shape=self.shape, dtype=self.dtype) - x[[True, False, 0]] = 0 + if paddle.in_dynamic_mode(): + x[[True, False, 0]] = 0 + else: + x = paddle.static.setitem(x, [True, False, 0], 0) with self.assertRaises(IndexError): x = paddle.ones(shape=self.shape, dtype=self.dtype) - x[[True, False], [True, False]] = 0 + if paddle.in_dynamic_mode(): + x[[True, False], [True, False]] = 0 + else: + x = paddle.static.setitem(x, ([True, False], [True, False]), 0) def _bool_tensor_error(self): with self.assertRaises(IndexError): x = paddle.ones(shape=self.shape, dtype=self.dtype) idx = paddle.assign([True, False, True]) - x[idx] = 0 + if paddle.in_dynamic_mode(): + x[idx] = 0 + else: + x = paddle.static.setitem(x, idx, 0) def _broadcast_mismatch(self): program = paddle.static.Program() with paddle.static.program_guard(program): x = paddle.ones(shape=self.shape, dtype=self.dtype) value = np.array([3, 4, 5, 6, 7]) - x[0] = value + x = paddle.static.setitem(x, 0, value) exe = paddle.static.Executor(paddle.CPUPlace()) with self.assertRaises(ValueError): exe.run(program) @@ -1104,7 +1399,10 @@ def forward(self, x, y): y = self.conv(y) var = y.flatten() - x[0, :, 0, 0] = var + if paddle.in_dynamic_mode(): + x[0, :, 0, 0] = var + else: + x = paddle.static.setitem(x, (0, slice(None), 0, 0), var) loss = paddle.mean(x) return loss, var, x @@ -1131,7 +1429,7 @@ def test_static(self): z = paddle.add(x, y) var = y[0, :] - z[0, :] = var + z = paddle.static.setitem(z, (0, slice(None)), var) prediction = paddle.static.nn.fc(x=z, size=2, activation='softmax') diff --git a/test/legacy_test/test_var_base.py b/test/legacy_test/test_var_base.py index f96f25281d800..b0d53788e57dd 100644 --- a/test/legacy_test/test_var_base.py +++ b/test/legacy_test/test_var_base.py @@ -1365,12 +1365,10 @@ def set_dtype(self): def _test(self, value): paddle.disable_static() - self.assertEqual(self.tensor_x.inplace_version, 0) id_origin = id(self.tensor_x) index_1 = paddle.to_tensor(np.array([True, False, False, False])) self.tensor_x[index_1] = value - self.assertEqual(self.tensor_x.inplace_version, 1) if isinstance(value, (int, float)): result = np.zeros((2, 3)).astype(self.dtype) + value @@ -1383,13 +1381,11 @@ def _test(self, value): index_2 = paddle.to_tensor(np.array([False, True, False, False])) self.tensor_x[index_2] = value - self.assertEqual(self.tensor_x.inplace_version, 2) np.testing.assert_array_equal(self.tensor_x[1].numpy(), result) self.assertEqual(id_origin, id(self.tensor_x)) index_3 = paddle.to_tensor(np.array([True, True, True, True])) self.tensor_x[index_3] = value - self.assertEqual(self.tensor_x.inplace_version, 3) np.testing.assert_array_equal(self.tensor_x[3].numpy(), result) self.assertEqual(id_origin, id(self.tensor_x)) diff --git a/test/legacy_test/test_variable.py b/test/legacy_test/test_variable.py index dcbfedc0f6a23..f7338ce07e425 100644 --- a/test/legacy_test/test_variable.py +++ b/test/legacy_test/test_variable.py @@ -844,8 +844,7 @@ def run_setitem_list_index(self, array, index, value_np): name='value', shape=value_np.shape, dtype='float32' ) - x[index] = value - y = x + y = paddle.static.setitem(x, index, value) place = paddle.fluid.CPUPlace() prog = paddle.static.default_main_program() @@ -1042,9 +1041,8 @@ def test_static_graph_tensor_index_setitem_muti_dim(self): name='index_2', shape=index2.shape, dtype='int32' ) - x1[index_1, index_2] = value - x2[index_1] = value - + x1_out = paddle.static.setitem(x1, (index_1, index_2), value) + x2_out = paddle.static.setitem(x2, index_1, value) place = ( paddle.fluid.CPUPlace() if not paddle.fluid.core.is_compiled_with_cuda() @@ -1055,7 +1053,7 @@ def test_static_graph_tensor_index_setitem_muti_dim(self): exe = paddle.static.Executor(place) exe.run(paddle.static.default_startup_program()) - fetch_list = [x1.name, x2.name] + fetch_list = [x1_out.name, x2_out.name] setitem_pp = exe.run( prog, @@ -1124,10 +1122,10 @@ def test_static_graph_array_index_muti_dim(self): name='x2', shape=array.shape, dtype='float32' ) - x1[index_mod1, index_mod2] = 1 - x2[index_mod1] = 2.5 - y1 = x1[index_mod2, index_mod1] - y2 = x2[index_mod2] + x1_out = paddle.static.setitem(x1, (index_mod1, index_mod2), 1) + x2_out = paddle.static.setitem(x2, index_mod1, 2.5) + y1 = x1_out[index_mod2, index_mod1] + y2 = x2_out[index_mod2] place = ( paddle.fluid.CPUPlace() if not paddle.fluid.core.is_compiled_with_cuda() @@ -1137,7 +1135,7 @@ def test_static_graph_array_index_muti_dim(self): prog = paddle.static.default_main_program() exe = paddle.static.Executor(place) exe.run(paddle.static.default_startup_program()) - fetch_list = [x1.name, x2.name, y1.name, y2.name] + fetch_list = [x1_out.name, x2_out.name, y1.name, y2.name] setitem_pp = exe.run( prog, diff --git a/test/legacy_test/test_zero_dim_tensor.py b/test/legacy_test/test_zero_dim_tensor.py index 6080327a2da95..6f47f2d46b57a 100644 --- a/test/legacy_test/test_zero_dim_tensor.py +++ b/test/legacy_test/test_zero_dim_tensor.py @@ -3141,7 +3141,7 @@ def test_setitem(self): x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) x.stop_gradient = False out = x * 2 - out[1, 2, 3, 4] = 10 + out = paddle.static.setitem(out, (1, 2, 3, 4), 10) paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() res = self.exe.run(prog, fetch_list=[out, x.grad_name]) @@ -3162,7 +3162,7 @@ def test_setitem(self): x.stop_gradient = False indice = paddle.full([], 1, dtype='int32') out = x * 1 - out[indice, indice] = 0.5 + out = paddle.static.setitem(out, (indice, indice), 0.5) paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() res = self.exe.run(prog, fetch_list=[out, x.grad_name]) @@ -3181,7 +3181,7 @@ def test_setitem(self): v.stop_gradient = False indice = paddle.full([], 1, dtype='int32') out = x * 1 - out[indice] = v + out = paddle.static.setitem(out, indice, v) paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() res = self.exe.run(prog, fetch_list=[out, x.grad_name, v.grad_name]) diff --git a/test/xpu/test_set_value_op_xpu.py b/test/xpu/test_set_value_op_xpu.py index 71ca556fb83a3..3bf665f76a32c 100644 --- a/test/xpu/test_set_value_op_xpu.py +++ b/test/xpu/test_set_value_op_xpu.py @@ -61,6 +61,10 @@ def set_dtype(self): def _call_setitem(self, x): x[0, 0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (0, 0), self.value) + return x + def _get_answer(self): self.data[0, 0] = self.value @@ -69,7 +73,7 @@ def _run_static(self): paddle.enable_static() with paddle.static.program_guard(self.program): x = paddle.ones(shape=self.shape, dtype=self.dtype) - self._call_setitem(x) + x = self._call_setitem_static_api(x) exe = paddle.static.Executor(self.place) out = exe.run(self.program, fetch_list=[x]) @@ -107,6 +111,10 @@ class XPUTestSetValueItemInt(XPUTestSetValueApi): def _call_setitem(self, x): x[0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, 0, self.value) + return x + def _get_answer(self): self.data[0] = self.value @@ -117,6 +125,10 @@ def set_shape(self): def _call_setitem(self, x): x[0, 3, 4] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (0, 3, 4), self.value) + return x + def _get_answer(self): self.data[0, 3, 4] = self.value @@ -127,6 +139,10 @@ def set_shape(self): def _call_setitem(self, x): x[1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (1), self.value) + return x + def _get_answer(self): self.data[1] = self.value @@ -136,6 +152,10 @@ class XPUTestSetValueItemSlice(XPUTestSetValueApi): def _call_setitem(self, x): x[0:2] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(0, 2), self.value) + return x + def _get_answer(self): self.data[0:2] = self.value @@ -143,6 +163,10 @@ class XPUTestSetValueItemSlice2(XPUTestSetValueApi): def _call_setitem(self, x): x[0:-1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(0, -1), self.value) + return x + def _get_answer(self): self.data[0:-1] = self.value @@ -150,6 +174,12 @@ class XPUTestSetValueItemSlice3(XPUTestSetValueApi): def _call_setitem(self, x): x[0:-1, 0:2] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(0, -1), slice(0, 2)), self.value + ) + return x + def _get_answer(self): self.data[0:-1, 0:2] = self.value @@ -157,6 +187,14 @@ class XPUTestSetValueItemSlice4(XPUTestSetValueApi): def _call_setitem(self, x): x[0:, 1:2, :] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, + (slice(0, None), slice(1, 2), slice(None, None, None)), + self.value, + ) + return x + def _get_answer(self): self.data[0:, 1:2, :] = self.value @@ -164,6 +202,12 @@ class XPUTestSetValueItemSlice5(XPUTestSetValueApi): def _call_setitem(self, x): x[0:, 1:1, :] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(0), slice(1, 1), slice(None, None, None)), self.value + ) + return x + def _get_answer(self): self.data[0:, 1:1, :] = self.value @@ -186,6 +230,19 @@ def body(i, x): i = paddle.zeros(shape=(1,), dtype='int32') i, x = paddle.static.nn.while_loop(cond, body, [i, x]) + def _call_setitem_static_api(self, x): + def cond(i, x): + return i < 1 + + def body(i, x): + x = paddle.static.setitem(x, i, self.value) + i = i + 1 + return i, x + + i = paddle.zeros(shape=(1,), dtype='int32') + i, x = paddle.static.nn.while_loop(cond, body, [i, x]) + return x + def _get_answer(self): self.data[0] = self.value @@ -197,6 +254,10 @@ def set_shape(self): def _call_setitem(self, x): x[0:2:2] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(0, 2, 2), self.value) + return x + def _get_answer(self): self.data[0:2:2] = self.value @@ -207,6 +268,10 @@ def set_shape(self): def _call_setitem(self, x): x[0:-1:3] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(0, -1, 3), self.value) + return x + def _get_answer(self): self.data[0:-1:3] = self.value @@ -214,6 +279,12 @@ class XPUTestSetValueItemSliceStep3(XPUTestSetValueApi): def _call_setitem(self, x): x[0:-1, 0:2, ::2] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(0, -1), slice(0, 2), slice(None, None, 2)), self.value + ) + return x + def _get_answer(self): self.data[0:-1, 0:2, ::2] = self.value @@ -221,6 +292,14 @@ class XPUTestSetValueItemSliceStep4(XPUTestSetValueApi): def _call_setitem(self, x): x[0:, 1:2:2, :] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, + (slice(0, None), slice(1, 2, 2), slice(None, None, None)), + self.value, + ) + return x + def _get_answer(self): self.data[0:, 1:2:2, :] = self.value @@ -241,6 +320,10 @@ def set_value(self): def _call_setitem(self, x): x[5:2:-1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(5, 2, -1), self.value) + return x + def _get_answer(self): self.data[5:2:-1] = self.value @@ -261,6 +344,10 @@ def set_value(self): def _call_setitem(self, x): x[1::-1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(1, None, -1), self.value) + return x + def _get_answer(self): self.data[1::-1] = self.value @@ -280,6 +367,10 @@ def set_value(self): def _call_setitem(self, x): x[::-1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(None, None, -1), self.value) + return x + def _get_answer(self): self.data[::-1] = self.value @@ -290,6 +381,14 @@ def set_shape(self): def _call_setitem(self, x): x[2:0:-1, 0:2, ::-1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, + (slice(2, 0, -1), slice(0, 2), slice(None, None, -1)), + self.value, + ) + return x + def _get_answer(self): self.data[2:0:-1, 0:2, ::-1] = self.value @@ -302,11 +401,15 @@ def set_shape(self): def _call_setitem(self, x): x[2:-1:-2] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(2, -1, -2), self.value) + return x + def _get_answer(self): paddle.enable_static() with paddle.static.program_guard(self.program): x = paddle.ones(shape=self.shape, dtype=self.dtype) - self._call_setitem(x) + x = self._call_setitem_static_api(x) exe = paddle.static.Executor(paddle.CPUPlace()) self.data = exe.run(self.program, fetch_list=[x]) @@ -334,6 +437,12 @@ class XPUTestSetValueItemEllipsis1(XPUTestSetValueApi): def _call_setitem(self, x): x[0:, ..., 1:] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(0, None), ..., slice(1, None)), self.value + ) + return x + def _get_answer(self): self.data[0:, ..., 1:] = self.value @@ -341,6 +450,10 @@ class XPUTestSetValueItemEllipsis2(XPUTestSetValueApi): def _call_setitem(self, x): x[0:, ...] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (slice(0, None), ...), self.value) + return x + def _get_answer(self): self.data[0:, ...] = self.value @@ -348,6 +461,10 @@ class XPUTestSetValueItemEllipsis3(XPUTestSetValueApi): def _call_setitem(self, x): x[..., 1:] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (..., slice(1, None)), self.value) + return x + def _get_answer(self): self.data[..., 1:] = self.value @@ -355,6 +472,10 @@ class XPUTestSetValueItemEllipsis4(XPUTestSetValueApi): def _call_setitem(self, x): x[...] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (...), self.value) + return x + def _get_answer(self): self.data[...] = self.value @@ -370,6 +491,11 @@ def _call_setitem(self, x): zero = paddle.full([1], 0, dtype="int32") x[zero] = self.value + def _call_setitem_static_api(self, x): + zero = paddle.full([1], 0, dtype="int32") + x = paddle.static.setitem(x, zero, self.value) + return x + def _get_answer(self): self.data[0] = self.value @@ -385,6 +511,12 @@ def _call_setitem(self, x): two = paddle.full([1], 2, dtype="int64") x[zero:two] = self.value + def _call_setitem_static_api(self, x): + zero = paddle.full([1], 0, dtype="int32") + two = paddle.full([1], 2, dtype="int64") + x = paddle.static.setitem(x, slice(zero, two), self.value) + return x + def _get_answer(self): self.data[0:2] = self.value @@ -400,6 +532,14 @@ def _call_setitem(self, x): two = paddle.full([1], 2, dtype="int64") x[zero:-1, 0:two] = self.value + def _call_setitem_static_api(self, x): + zero = paddle.full([1], 0, dtype="int32") + two = paddle.full([1], 2, dtype="int64") + x = paddle.static.setitem( + x, (slice(zero, -1), slice(0, two)), self.value + ) + return x + def _get_answer(self): self.data[0:-1, 0:2] = self.value @@ -415,6 +555,14 @@ def _call_setitem(self, x): two = paddle.full([1], 2, dtype="int64") x[0:-1, zero:2, 0:6:two] = self.value + def _call_setitem_static_api(self, x): + zero = paddle.full([1], 0, dtype="int32") + two = paddle.full([1], 2, dtype="int64") + x = paddle.static.setitem( + x, (slice(0, -1), slice(zero, 2), slice(0, 6, two)), self.value + ) + return x + def _get_answer(self): self.data[0:-1, 0:2, ::2] = self.value @@ -430,6 +578,16 @@ def _call_setitem(self, x): two = paddle.full([1], 2, dtype="int64") x[zero:, 1:2:two, :] = self.value + def _call_setitem_static_api(self, x): + zero = paddle.full([1], 0, dtype="int32") + two = paddle.full([1], 2, dtype="int64") + x = paddle.static.setitem( + x, + (slice(zero, None), slice(1, 2, two), slice(None, None, None)), + self.value, + ) + return x + def _get_answer(self): self.data[0:, 1:2:2, :] = self.value @@ -448,6 +606,16 @@ def _call_setitem(self, x): zero = paddle.full([1], 0, dtype="int32") x[2:zero:minus1, 0:2, 10:-6:minus1] = self.value + def _call_setitem_static_api(self, x): + minus1 = paddle.full([1], -1, dtype="int32") + zero = paddle.full([1], 0, dtype="int32") + x = paddle.static.setitem( + x, + (slice(2, zero, minus1), slice(0, 2), slice(10, -6, minus1)), + self.value, + ) + return x + def _get_answer(self): self.data[2:0:-1, 0:2, ::-1] = self.value @@ -462,6 +630,10 @@ def set_dtype(self): def _call_setitem(self, x): x[None] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, None, self.value) + return x + def _get_answer(self): self.data[None] = self.value @@ -475,6 +647,10 @@ def set_dtype(self): def _call_setitem(self, x): x[0, None, 1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (0, None, 1), self.value) + return x + def _get_answer(self): self.data[0, None, 1] = self.value @@ -488,6 +664,12 @@ def set_dtype(self): def _call_setitem(self, x): x[:, None, None, 1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(None, None, None), None, None, 1), self.value + ) + return x + def _get_answer(self): self.data[:, None, None, 1] = self.value @@ -501,6 +683,10 @@ def set_dtype(self): def _call_setitem(self, x): x[0, 0, None, 1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (0, 0, None, 1), self.value) + return x + def _get_answer(self): self.data[0, 0, None, 1] = self.value @@ -514,6 +700,10 @@ def set_dtype(self): def _call_setitem(self, x): x[0, None, 0, None, 1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (0, None, 0, None, 1), self.value) + return x + def _get_answer(self): self.data[0, None, 0, None, 1] = self.value @@ -527,6 +717,10 @@ def set_dtype(self): def _call_setitem(self, x): x[None, 0, 0, None, 0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (None, 0, 0, None, 0), self.value) + return x + def _get_answer(self): self.data[None, 0, 0, None, 0] = self.value @@ -540,6 +734,14 @@ def set_dtype(self): def _call_setitem(self, x): x[:, None, 1] = np.zeros(self.shape)[:, None, 0] + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, + (slice(None, None, None), None, 1), + np.zeros(self.shape)[:, None, 0], + ) + return x + def _get_answer(self): self.data[:, None, 1] = np.zeros(self.shape)[:, None, 0] @@ -553,6 +755,14 @@ def set_dtype(self): def _call_setitem(self, x): x[:, 1, None] = np.zeros(self.shape)[:, 0, None] + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, + (slice(None, None, None), 1, None), + np.zeros(self.shape)[:, 0, None], + ) + return x + def _get_answer(self): self.data[:, 1, None] = np.zeros(self.shape)[:, 0, None] @@ -566,6 +776,14 @@ def set_dtype(self): def _call_setitem(self, x): x[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None] + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, + (None, slice(None, None, None), 1, ..., None), + np.zeros(self.shape)[0, 0, :, None], + ) + return x + def _get_answer(self): self.data[None, :, 1, ..., None] = np.zeros(self.shape)[ 0, 0, :, None @@ -581,6 +799,14 @@ def set_dtype(self): def _call_setitem(self, x): x[..., None, :, None] = np.zeros(self.shape)[..., None, :, None] + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, + (..., None, slice(None, None, None), None), + np.zeros(self.shape)[..., None, :, None], + ) + return x + def _get_answer(self): self.data[..., None, :, None] = np.zeros(self.shape)[ ..., None, :, None @@ -597,6 +823,10 @@ def set_dtype(self): def _call_setitem(self, x): x[[True, False]] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, [True, False], self.value) + return x + def _get_answer(self): self.data[[True, False]] = self.value @@ -610,6 +840,10 @@ def set_dtype(self): def _call_setitem(self, x): x[[False, False]] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, [False, False], self.value) + return x + def _get_answer(self): self.data[[False, False]] = self.value @@ -623,6 +857,10 @@ def set_dtype(self): def _call_setitem(self, x): x[[False, True]] = np.zeros(self.shape[2]) + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, [False, True], np.zeros(self.shape[2])) + return x + def _get_answer(self): self.data[[False, True]] = np.zeros(self.shape[2]) @@ -637,6 +875,11 @@ def _call_setitem(self, x): idx = paddle.assign(np.array([False, True])) x[idx] = np.zeros(self.shape[2]) + def _call_setitem_static_api(self, x): + idx = paddle.assign(np.array([False, True])) + x = paddle.static.setitem(x, idx, np.zeros(self.shape[2])) + return x + def _get_answer(self): self.data[np.array([False, True])] = np.zeros(self.shape[2]) @@ -653,6 +896,13 @@ def _call_setitem(self, x): ) x[idx] = self.value + def _call_setitem_static_api(self, x): + idx = paddle.assign( + np.array([[False, True, False], [True, True, False]]) + ) + x = paddle.static.setitem(x, idx, self.value) + return x + def _get_answer(self): self.data[ np.array([[False, True, False], [True, True, False]]) @@ -669,6 +919,11 @@ def _call_setitem(self, x): x[0, ...] = 0 x[x > 0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, (0, ...), 0) + x = paddle.static.setitem(x, x > 0, self.value) + return x + def _get_answer(self): self.data[0, ...] = 0 self.data[self.data > 0] = self.value @@ -684,6 +939,11 @@ def _call_setitem(self, x): value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype) x[0, 1] = value + def _call_setitem_static_api(self, x): + value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype) + x = paddle.static.setitem(x, (0, 1), value) + return x + def _get_answer(self): self.data[0, 1] = 3 @@ -706,6 +966,11 @@ def _call_setitem(self, x): value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype) x[0, 1] = value + def _call_setitem_static_api(self, x): + value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype) + x = paddle.static.setitem(x, (0, 1), value) + return x + def _get_answer(self): self.data[0, 1] = 3 @@ -728,6 +993,11 @@ def _call_setitem(self, x): value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype) x[0, 1] = value + def _call_setitem_static_api(self, x): + value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype) + x = paddle.static.setitem(x, (0, 1), value) + return x + def _get_answer(self): self.data[0, 1] = 3 @@ -752,6 +1022,13 @@ def _call_setitem(self, x): ) x[0, 1] = value + def _call_setitem_static_api(self, x): + value = paddle.full( + shape=[1], fill_value=False, dtype=self.dtype + ) + x = paddle.static.setitem(x, (0, 1), value) + return x + def _get_answer(self): self.data[0, 1] = False @@ -779,6 +1056,10 @@ def set_value(self): def _call_setitem(self, x): x[0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, 0, self.value) + return x + def _get_answer(self): self.data[0] = self.value @@ -795,6 +1076,10 @@ def set_value(self): def _call_setitem(self, x): x[0:1] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, slice(0, 1), self.value) + return x + def _get_answer(self): self.data[0:1] = self.value @@ -813,6 +1098,10 @@ def set_value(self): def _call_setitem(self, x): x[0] = self.value + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, 0, self.value) + return x + def _get_answer(self): self.data[0] = self.value @@ -833,6 +1122,10 @@ def set_value(self): def _call_setitem(self, x): x[0] = paddle.assign(self.value) # x is Paddle.Tensor + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, 0, paddle.assign(self.value)) + return x + def _get_answer(self): self.data[0] = self.value @@ -852,6 +1145,12 @@ def set_shape(self): def _call_setitem(self, x): x[:, 0] = paddle.assign(self.value) # x is Paddle.Tensor + def _call_setitem_static_api(self, x): + x = paddle.static.setitem( + x, (slice(None, None, None), 0), paddle.assign(self.value) + ) + return x + def _get_answer(self): self.data[:, 0] = self.value @@ -864,7 +1163,10 @@ def _value_type_error(self): ): x = paddle.ones(shape=self.shape, dtype=self.dtype) value = [1] - x[0] = value + if paddle.in_dynamic_mode(): + x[0] = value + else: + x = paddle.static.setitem(x, 0, value) def _dtype_error(self): with self.assertRaisesRegex( @@ -877,7 +1179,10 @@ def _dtype_error(self): def _step_error(self): with self.assertRaisesRegex(ValueError, "step can not be 0"): x = paddle.ones(shape=self.shape, dtype=self.dtype) - x[0:1:0] = self.value + if paddle.in_dynamic_mode(): + x[0:1:0] = self.value + else: + x = paddle.static.setitem(x, slice(0, 1, 0), self.value) def _ellipsis_error(self): with self.assertRaisesRegex( @@ -893,24 +1198,35 @@ def _ellipsis_error(self): def _bool_list_error(self): with self.assertRaises(TypeError): x = paddle.ones(shape=self.shape, dtype=self.dtype) - x[[True, False, 0]] = 0 + if paddle.in_dynamic_mode(): + x[[True, False, 0]] = 0 + else: + x = paddle.static.setitem(x, [True, False, 0], 0) with self.assertRaises(IndexError): x = paddle.ones(shape=self.shape, dtype=self.dtype) - x[[True, False], [True, False]] = 0 + if paddle.in_dynamic_mode(): + x[[True, False], [True, False]] = 0 + else: + x = paddle.static.setitem( + x, ([True, False], [True, False]), 0 + ) def _bool_tensor_error(self): with self.assertRaises(IndexError): x = paddle.ones(shape=self.shape, dtype=self.dtype) idx = paddle.assign([True, False, True]) - x[idx] = 0 + if paddle.in_dynamic_mode(): + x[idx] = 0 + else: + x = paddle.static.setitem(x, idx, 0) def _broadcast_mismatch(self): program = paddle.static.Program() with paddle.static.program_guard(program): x = paddle.ones(shape=self.shape, dtype=self.dtype) value = np.array([3, 4, 5, 6, 7]) - x[0] = value + x = paddle.static.setitem(x, 0, value) exe = paddle.static.Executor(paddle.XPUPlace(0)) with self.assertRaises(ValueError): exe.run(program) @@ -952,7 +1268,7 @@ def test_static(self): z = paddle.add(x, y) var = y[0, :] - z[0, :] = var + z = paddle.static.setitem(z, (0, slice(None)), var) prediction = paddle.static.nn.fc( x=z, size=2, activation='softmax'