diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index 3063e52c3ea9e..918b5f2c01e9c 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -2142,7 +2142,7 @@ def assign(x, output=None):
              [2.5 2.5]]
     """
     # speed up
-    if x is output and isinstance(x, Variable):
+    if x is output and isinstance(x, (Variable, paddle.pir.OpResult)):
         return x
 
     input = x
@@ -2150,7 +2150,16 @@ def assign(x, output=None):
     check_type(
         input,
         'input',
-        (Variable, np.ndarray, list, tuple, float, int, bool),
+        (
+            Variable,
+            paddle.pir.OpResult,
+            np.ndarray,
+            list,
+            tuple,
+            float,
+            int,
+            bool,
+        ),
         'assign',
     )
 
@@ -2163,12 +2172,17 @@ def assign(x, output=None):
     # but in_dynamic_mode()==False under @to_static, which means
     # isinstance(Tensor, Variable) == False. It will cause return None
     # after this api.
-    if isinstance(input, (Variable, core.eager.Tensor)):
+    if isinstance(input, (Variable, core.eager.Tensor, paddle.pir.OpResult)):
         if in_dynamic_mode():
             if output is None:
                 output = _C_ops.assign(input)
             else:
                 _C_ops.assign_out_(input, output)
+        elif in_pir_mode():
+            if output is None:
+                output = _C_ops.assign(input)
+            else:
+                output = _C_ops.assign_out_(input, output)
         else:
             check_dtype(
                 input.dtype,
@@ -2196,19 +2210,25 @@ def assign(x, output=None):
             )
     elif isinstance(input, np.ndarray):
         # We now support the form of [var, VAR...] if the Var.shape=[1,]
-        if len(input.shape) > 0 and any(isinstance(x, Variable) for x in input):
+        if len(input.shape) > 0 and any(
+            isinstance(x, (Variable, paddle.pir.OpResult)) for x in input
+        ):
             # We only deal with the case where the list is nested one level, convert all scalars into variables, and then use stack to process. It is necessary to ensure the consistency of types.
             if not all(
                 x.shape == (1,)
                 for x in input
-                if isinstance(x, (Variable, core.eager.Tensor))
+                if isinstance(
+                    x, (Variable, core.eager.Tensor, paddle.pir.OpResult)
+                )
             ):
                 raise TypeError(
                     "Unsupport paddle.assign([Variable, Variable...]) with non-scalar variable."
                 )
 
             def convert_scalar(x):
-                if not isinstance(x, (Variable, core.eager.Tensor)):
+                if not isinstance(
+                    x, (Variable, core.eager.Tensor, paddle.pir.OpResult)
+                ):
                     return assign(x)
                 return x
 
@@ -2233,16 +2253,33 @@ def convert_scalar(x):
                 "it to float32"
             )
             dtype = core.VarDesc.VarType.FP32
-        if dtype == core.VarDesc.VarType.BOOL:
+
+        if dtype == core.DataType.FLOAT64:
+            # Setting FP64 numpy data is not supported in Paddle, so we
+            # use FP32 here
+            warnings.warn(
+                "paddle.assign doesn't support float64 input now due "
+                "to current platform protobuf data limitation, we convert "
+                "it to float32"
+            )
+            dtype = core.DataType.FLOAT32
+
+        if dtype == core.VarDesc.VarType.BOOL or dtype == core.DataType.BOOL:
             value_name = "bool_values"
             values = [int(v) for v in input.flat]
-        elif dtype == core.VarDesc.VarType.FP32:
+        elif (
+            dtype == core.VarDesc.VarType.FP32 or dtype == core.DataType.FLOAT32
+        ):
             value_name = "fp32_values"
             values = [float(v) for v in input.flat]
-        elif dtype == core.VarDesc.VarType.INT32:
+        elif (
+            dtype == core.VarDesc.VarType.INT32 or dtype == core.DataType.INT32
+        ):
             value_name = "int32_values"
             values = [int(v) for v in input.flat]
-        elif dtype == core.VarDesc.VarType.INT64:
+        elif (
+            dtype == core.VarDesc.VarType.INT64 or dtype == core.DataType.INT64
+        ):
             value_name = "int64_values"
             values = [int(v) for v in input.flat]
         else:
@@ -2256,16 +2293,25 @@ def convert_scalar(x):
                 "The size of input is too big. Please consider "
                 "saving it to file and 'load_op' to load it"
             )
-        if in_dynamic_mode():
+        if in_dynamic_or_pir_mode():
             if output is None:
                 output = zeros(list(input.shape), dtype)
-            _C_ops.assign_value_(
-                output,
-                list(input.shape),
-                dtype,
-                values,
-                _current_expected_place(),
-            )
+            if in_dynamic_mode():
+                _C_ops.assign_value_(
+                    output,
+                    list(input.shape),
+                    dtype,
+                    values,
+                    _current_expected_place(),
+                )
+            else:
+                output = _C_ops.assign_value_(
+                    output,
+                    list(input.shape),
+                    dtype,
+                    values,
+                    _current_expected_place(),
+                )
         else:
             if output is None:
                 output = helper.create_variable_for_type_inference(
diff --git a/test/legacy_test/test_assign_op.py b/test/legacy_test/test_assign_op.py
index 991c96bdf7849..4a9ff9308f7b8 100644
--- a/test/legacy_test/test_assign_op.py
+++ b/test/legacy_test/test_assign_op.py
@@ -42,12 +42,12 @@ def init_input_configs(self):
 
     def test_forward(self):
         paddle.enable_static()
-        self.check_output()
+        self.check_output(check_new_ir=True)
         paddle.disable_static()
 
     def test_backward(self):
         paddle.enable_static()
-        self.check_grad(['X'], 'Out', check_prim=True)
+        self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)
         paddle.disable_static()
 
 
@@ -71,12 +71,12 @@ def setUp(self):
 
     def test_forward(self):
         paddle.enable_static()
-        self.check_output()
+        self.check_output(check_new_ir=True)
         paddle.disable_static()
 
     def test_backward(self):
         paddle.enable_static()
-        self.check_grad(['X'], 'Out', check_prim=True)
+        self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)
         paddle.disable_static()
 
 
@@ -97,12 +97,12 @@ def setUp(self):
 
     def test_forward(self):
         paddle.enable_static()
-        self.check_output()
+        self.check_output(check_new_ir=True)
         paddle.disable_static()
 
     def test_backward(self):
         paddle.enable_static()
-        self.check_grad(['X'], 'Out', check_prim=True)
+        self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)
         paddle.disable_static()
 
 
diff --git a/test/legacy_test/test_assign_value_op.py b/test/legacy_test/test_assign_value_op.py
index 88bb60edbbc3c..b0963b51b2485 100644
--- a/test/legacy_test/test_assign_value_op.py
+++ b/test/legacy_test/test_assign_value_op.py
@@ -25,7 +25,12 @@ def assign_value_wrapper(
     shape=[], dtype=base.core.VarDesc.VarType.FP32, values=0.0
 ):
-    tensor = paddle.Tensor()
+    if paddle.framework.in_dynamic_mode():
+        tensor = paddle.Tensor()
+    else:
+        np_type = paddle.base.data_feeder._PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
+        tensor = paddle.zeros(list(shape), np_type)
+        dtype = paddle.pir.core.convert_np_dtype_to_dtype_(np_type)
     return paddle._C_ops.assign_value_(
         tensor, shape, dtype, values, framework._current_expected_place()
     )
 
@@ -49,7 +54,7 @@ def init_data(self):
         self.attrs["fp32_values"] = [float(v) for v in self.value.flat]
 
     def test_forward(self):
-        self.check_output(check_cinn=True)
+        self.check_output(check_cinn=True, check_new_ir=True)
 
 
 class TestAssignValueOp2(TestAssignValueOp):
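
A minimal usage sketch (not part of the patch) of the two public-API paths this change touches: the Tensor fast path that dispatches to assign/assign_out_, and the numpy path that dispatches to assign_value_. It is shown in eager mode, which the new PIR branches mirror; variable names are illustrative only.

import numpy as np
import paddle

# Tensor input: dispatches to _C_ops.assign / _C_ops.assign_out_
x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
y = paddle.assign(x)          # copy into a new tensor

out = paddle.zeros([2, 2])
paddle.assign(x, output=out)  # copy into an existing tensor

# numpy input: dispatches to _C_ops.assign_value_
# (bool/float32/int32/int64 only; float64 is downcast to float32 with a warning)
z = paddle.assign(np.ones((2, 2), dtype="float32"))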