diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index bc3cee0670e4f..0614007fb995d 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -237,7 +237,6 @@ func : ElementwiseInferMeta kernel : func : elementwise_pow - inplace: (x -> out) backward : elementwise_pow_grad - op : embedding diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 28e52692d3afe..4e4d4803b21bc 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -13,6 +13,7 @@ kernel : func : abs data_type : x + inplace: (x -> out) backward : abs_grad - op : accuracy @@ -26,20 +27,22 @@ - op : acos args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : acos + inplace: (x -> out) backward : acos_grad - op : acosh args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : acosh + inplace: (x -> out) backward : acosh_grad - op : adagrad_ @@ -90,12 +93,13 @@ - op : addmm args : (Tensor input, Tensor x, Tensor y, float beta=1.0, float alpha=1.0) - output : Tensor + output : Tensor(out) infer_meta : func : AddmmInferMeta kernel : func : addmm data_type : x + inplace: (input -> out) backward : addmm_grad - op : affine_grid @@ -176,34 +180,37 @@ - op : asin args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : asin + inplace: (x -> out) backward : asin_grad - op : asinh args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : asinh + inplace: (x -> out) backward : asinh_grad - op : atan args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : atan + inplace: (x -> out) backward : atan_grad - op : atan2 args : (Tensor x, Tensor y) - output : Tensor + output : Tensor(out) infer_meta : func : Atan2InferMeta kernel : @@ -212,11 +219,12 @@ - op : atanh args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : atanh + inplace: (x -> out) backward : atanh_grad - op : auc @@ -524,20 +532,22 @@ - op : cos args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : cos + inplace: (x -> out) backward : cos_grad - op : cosh args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : cosh + inplace: (x -> out) backward : cosh_grad - op : crop @@ -756,11 +766,12 @@ - op : erf args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : erf + inplace : (x -> out) backward : erf_grad - op : erfinv @@ -806,12 +817,13 @@ - op : expm1 args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta param : [x] kernel : func : expm1 + inplace: (x -> out) backward : expm1_grad - op : fft_c2c @@ -2250,20 +2262,22 @@ - op : sin args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : sin + inplace: (x -> out) backward : sin_grad - op : sinh args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : sinh + inplace: (x -> out) backward : sinh_grad - op : slogdet @@ -2409,11 +2423,12 @@ - op : tan args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : tan + inplace: (x -> out) backward : tan_grad - op : tanh diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 0b4f051bb4599..ae5af6ce9ae47 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -116,7 +116,10 @@ template struct SinFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Sine()); + // Note(GGBond8488): Since Eigen3.3, Behavior like {A = (B * A).cwiseAbs()} + // will give wrong result, details see + // http://eigen.tuxfamily.org/dox/group__TopicAliasing.html + out.device(d) = x.unaryExpr(Sine()).eval(); } }; @@ -448,7 +451,7 @@ template struct CosFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Cosine()); + out.device(d) = x.unaryExpr(Cosine()).eval(); } }; @@ -762,7 +765,10 @@ template struct TanFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Tangent()); + // Note(GGBond8488): Since Eigen3.3, Behavior like {A = (B * A).cwiseAbs()} + // will give wrong result, details see + // http://eigen.tuxfamily.org/dox/group__TopicAliasing.html + out.device(d) = x.unaryExpr(Tangent()).eval(); } }; @@ -795,7 +801,7 @@ template struct SinhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Sinh()); + out.device(d) = x.unaryExpr(Sinh()).eval(); } }; @@ -804,7 +810,7 @@ template struct CoshFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Cosh()); + out.device(d) = x.unaryExpr(Cosh()).eval(); } }; @@ -855,7 +861,7 @@ template struct AcosFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Acos()); + out.device(d) = x.unaryExpr(Acos()).eval(); } }; @@ -892,7 +898,7 @@ template struct AsinFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Asin()); + out.device(d) = x.unaryExpr(Asin()).eval(); } }; @@ -929,7 +935,7 @@ template struct AtanFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Atan()); + out.device(d) = x.unaryExpr(Atan()).eval(); } }; @@ -977,7 +983,7 @@ template struct AcoshFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Acosh()); + out.device(d) = x.unaryExpr(Acosh()).eval(); } }; @@ -1014,7 +1020,7 @@ template struct AsinhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Asinh()); + out.device(d) = x.unaryExpr(Asinh()).eval(); } }; @@ -1051,7 +1057,7 @@ template struct AtanhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Atanh()); + out.device(d) = x.unaryExpr(Atanh()).eval(); } }; diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 4963ad8b51160..99c1aa35fd671 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -203,14 +203,21 @@ from .tensor.manipulation import index_put_ # noqa: F401 from .tensor.manipulation import unflatten # noqa: F401 from .tensor.math import abs # noqa: F401 +from .tensor.math import abs_ # noqa: F401 from .tensor.math import acos # noqa: F401 +from .tensor.math import acos_ # noqa: F401 from .tensor.math import asin # noqa: F401 +from .tensor.math import asin_ # noqa: F401 from .tensor.math import atan # noqa: F401 +from .tensor.math import atan_ # noqa: F401 from .tensor.math import atan2 # noqa: F401 from .tensor.math import ceil # noqa: F401 from .tensor.math import cos # noqa: F401 +from .tensor.math import cos_ # noqa: F401 from .tensor.math import tan # noqa: F401 +from .tensor.math import tan_ # noqa: F401 from .tensor.math import cosh # noqa: F401 +from .tensor.math import cosh_ # noqa: F401 from .tensor.math import cumsum # noqa: F401 from .tensor.math import cummax # noqa: F401 from .tensor.math import cummin # noqa: F401 @@ -219,6 +226,7 @@ from .tensor.math import logit # noqa: F401 from .tensor.math import exp # noqa: F401 from .tensor.math import expm1 # noqa: F401 +from .tensor.math import expm1_ # noqa: F401 from .tensor.math import floor # noqa: F401 from .tensor.math import increment # noqa: F401 from .tensor.math import log # noqa: F401 @@ -235,9 +243,12 @@ from .tensor.math import scale # noqa: F401 from .tensor.math import sign # noqa: F401 from .tensor.math import sin # noqa: F401 +from .tensor.math import sin_ # noqa: F401 from .tensor.math import sinh # noqa: F401 +from .tensor.math import sinh_ # noqa: F401 from .tensor.math import sqrt # noqa: F401 from .tensor.math import square # noqa: F401 +from .tensor.math import square_ # noqa: F401 from .tensor.math import stanh # noqa: F401 from .tensor.math import sum # noqa: F401 from .tensor.math import nan_to_num # noqa: F401 @@ -269,7 +280,9 @@ from .tensor.math import inverse # noqa: F401 from .tensor.math import log1p # noqa: F401 from .tensor.math import erf # noqa: F401 +from .tensor.math import erf_ # noqa: F401 from .tensor.math import addmm # noqa: F401 +from .tensor.math import addmm_ # noqa: F401 from .tensor.math import clip # noqa: F401 from .tensor.math import trace # noqa: F401 from .tensor.math import diagonal # noqa: F401 @@ -285,8 +298,11 @@ from .tensor.math import neg # noqa: F401 from .tensor.math import lgamma # noqa: F401 from .tensor.math import acosh # noqa: F401 +from .tensor.math import acosh_ # noqa: F401 from .tensor.math import asinh # noqa: F401 +from .tensor.math import asinh_ # noqa: F401 from .tensor.math import atanh # noqa: F401 +from .tensor.math import atanh_ # noqa: F401 from .tensor.math import lerp # noqa: F401 from .tensor.math import erfinv # noqa: F401 from .tensor.math import rad2deg # noqa: F401 @@ -431,6 +447,7 @@ 'complex64', 'complex128', 'addmm', + 'addmm_', 'allclose', 'isclose', 't', @@ -468,7 +485,9 @@ 'where', 'log1p', 'cos', + 'cos_', 'tan', + 'tan_', 'mean', 'mode', 'mv', @@ -543,6 +562,7 @@ 'less_equal', 'triu', 'sin', + 'sin_', 'dist', 'cdist', 'unbind', @@ -560,6 +580,7 @@ 'is_grad_enabled', 'mod', 'abs', + 'abs_', 'tril', 'pow', 'pow_', @@ -571,12 +592,15 @@ 'matmul', 'seed', 'acos', + 'acos_', 'logical_xor', 'exp', 'expm1', + 'expm1_', 'bernoulli', 'poisson', 'sinh', + 'sinh_', 'round', 'DataParallel', 'argmin', @@ -590,9 +614,11 @@ 'inner', 'outer', 'square', + 'square_', 'divide', 'ceil', 'atan', + 'atan_', 'atan2', 'rad2deg', 'deg2rad', @@ -618,6 +644,7 @@ 'dot', 'increment', 'erf', + 'erf_', 'bmm', 'chunk', 'tolist', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 95623f145b63d..ccd61d7bb2114 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -141,14 +141,21 @@ from .manipulation import index_put_ # noqa: F401 from .manipulation import unflatten # noqa: F401 from .math import abs # noqa: F401 +from .math import abs_ # noqa: F401 from .math import acos # noqa: F401 +from .math import acos_ # noqa: F401 from .math import asin # noqa: F401 +from .math import asin_ # noqa: F401 from .math import atan # noqa: F401 +from .math import atan_ # noqa: F401 from .math import ceil # noqa: F401 from .math import ceil_ # noqa: F401 from .math import cos # noqa: F401 +from .math import cos_ # noqa: F401 from .math import tan # noqa: F401 +from .math import tan_ # noqa: F401 from .math import cosh # noqa: F401 +from .math import cosh_ # noqa: F401 from .math import cumsum # noqa: F401 from .math import cummax # noqa: F401 from .math import cummin # noqa: F401 @@ -175,7 +182,9 @@ from .math import scale_ # noqa: F401 from .math import sign # noqa: F401 from .math import sin # noqa: F401 +from .math import sin_ # noqa: F401 from .math import sinh # noqa: F401 +from .math import sinh_ # noqa: F401 from .math import sqrt # noqa: F401 from .math import sqrt_ # noqa: F401 from .math import square # noqa: F401 @@ -216,6 +225,7 @@ from .math import log1p # noqa: F401 from .math import erf # noqa: F401 from .math import addmm # noqa: F401 +from .math import addmm_ # noqa: F401 from .math import clip # noqa: F401 from .math import clip_ # noqa: F401 from .math import trace # noqa: F401 @@ -234,8 +244,11 @@ from .math import lgamma # noqa: F401 from .math import diagonal # noqa: F401 from .math import acosh # noqa: F401 +from .math import acosh_ # noqa: F401 from .math import asinh # noqa: F401 +from .math import asinh_ # noqa: F401 from .math import atanh # noqa: F401 +from .math import atanh_ # noqa: F401 from .math import lerp # noqa: F401 from .math import lerp_ # noqa: F401 from .math import erfinv # noqa: F401 @@ -421,6 +434,7 @@ 'log1p', 'erf', 'addmm', + 'addmm_', 'clip', 'clip_', 'trace', diff --git a/python/paddle/tensor/layer_function_generator.py b/python/paddle/tensor/layer_function_generator.py index bdf8ac6e30d87..955e2b13ec548 100644 --- a/python/paddle/tensor/layer_function_generator.py +++ b/python/paddle/tensor/layer_function_generator.py @@ -14,7 +14,6 @@ import re import string -import warnings from io import StringIO from paddle import _C_ops, _legacy_C_ops @@ -352,22 +351,14 @@ def func(x, name=None): else: op = getattr(_legacy_C_ops, inplace_op_type) return op(x) - else: - warnings.warn( - "In static graph mode, {}() is the same as {}() and does not perform inplace operation.".format( - inplace_op_type, origin_op_type - ) - ) - return generate_activation_fn(origin_op_type)(x, name) func.__name__ = inplace_op_type func.__doc__ = """ Inplace version of ``{}`` API, the output Tensor will be inplaced with input ``x``. -Please refer to :ref:`api_fluid_layers_{}`. +Please refer to :ref:`api_paddle_{}`. """.format( origin_op_type, origin_op_type ) - return func diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 974db208cbaec..613ba5f84eaf2 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -43,20 +43,31 @@ from .layer_function_generator import generate_layer_fn, templatedoc from .manipulation import cast from .ops import abs # noqa: F401 +from .ops import abs_ # noqa: F401 from .ops import acos # noqa: F401 +from .ops import acos_ # noqa: F401 from .ops import acosh # noqa: F401 +from .ops import acosh_ # noqa: F401 from .ops import asin # noqa: F401 +from .ops import asin_ # noqa: F401 from .ops import asinh # noqa: F401 +from .ops import asinh_ # noqa: F401 from .ops import atan # noqa: F401 +from .ops import atan_ # noqa: F401 from .ops import atanh # noqa: F401 +from .ops import atanh_ # noqa: F401 from .ops import ceil # noqa: F401 from .ops import ceil_ # noqa: F401 from .ops import cos # noqa: F401 +from .ops import cos_ # noqa: F401 from .ops import cosh # noqa: F401 +from .ops import cosh_ # noqa: F401 from .ops import erf # noqa: F401 +from .ops import erf_ # noqa: F401 from .ops import exp # noqa: F401 from .ops import exp_ # noqa: F401 from .ops import expm1 # noqa: F401 +from .ops import expm1_ # noqa: F401 from .ops import floor # noqa: F401 from .ops import floor_ # noqa: F401 from .ops import reciprocal # noqa: F401 @@ -68,11 +79,15 @@ from .ops import sigmoid # noqa: F401 from .ops import sigmoid_ # noqa: F401 from .ops import sin # noqa: F401 +from .ops import sin_ # noqa: F401 from .ops import sinh # noqa: F401 +from .ops import sinh_ # noqa: F401 from .ops import sqrt # noqa: F401 from .ops import sqrt_ # noqa: F401 from .ops import square # noqa: F401 +from .ops import square_ # noqa: F401 from .ops import tan # noqa: F401 +from .ops import tan_ # noqa: F401 __all__ = [] @@ -482,12 +497,8 @@ def pow_(x, y, name=None): """ if isinstance(y, (int, float)): return _C_ops.pow_(x, y) - elif isinstance(y, (paddle.Tensor, Variable)): - return _C_ops.elementwise_pow_(x, y) else: - raise TypeError( - 'y must be scalar or tensor type, but received: %s ' % (type(y)) - ) + raise TypeError('y must be scalar type, but received: %s ' % (type(y))) OP_NAMEMAPPING = { @@ -2043,6 +2054,66 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None): return out +@inplace_apis_in_dygraph_only +def addmm_(input, x, y, beta=1.0, alpha=1.0, name=None): + """ + Inplace version of ``addmm`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_label_addmm`. + """ + input_shape = input.shape + x_shape = x.shape + y_shape = y.shape + if not len(x_shape) == len(y_shape) == 2: + raise ValueError( + "The dimention of x, y should be 2 but receive x's shape: {}, y's shape: {}".format( + x_shape, y_shape + ) + ) + if x_shape[1] != y_shape[0]: + raise ValueError( + "The input Variable x's width must be equal with Variable y' height. But received x's shape = {}, y's shape = {}.".format( + x_shape, y_shape + ) + ) + if len(input_shape) == 2: + if input_shape[0] != x_shape[0]: + if input_shape[0] != 1: + raise ValueError( + "When x's dimension[0] is not equal with input's dimension[0], input's dimension[0] must be 1 but got {}".format( + input_shape[0] + ) + ) + if input_shape[1] != y_shape[1] and input_shape[1] != 1: + raise ValueError( + "When y's dimension[1] is not equal with input's dimension[1], input's dimension[1] must be 1 but got {}".format( + input_shape[1] + ) + ) + if input_shape[1] != y_shape[1]: + if input_shape[1] != 1: + raise ValueError( + "When y's dimension[1] is not equal with input's dimension[1], input's dimension[1] must be 1 but got {}".format( + input_shape[1] + ) + ) + elif len(input_shape) == 1: + if input_shape[0] not in (y_shape[1], 1): + raise ValueError( + "The input's shape: {} is not broadcastable with [x.shape[0], y.shape[1]]: [{},{}]".format( + input_shape, x_shape[0], y_shape[1] + ) + ) + else: + raise ValueError( + "The dimention of input should be 2 or 1 but receive input's shape: {}".format( + input_shape + ) + ) + + if in_dynamic_mode(): + return _C_ops.addmm_(input, x, y, beta, alpha) + + def renorm(x, p, axis, max_norm): """ **renorm** diff --git a/python/paddle/tensor/ops.py b/python/paddle/tensor/ops.py index 76bc9b6c4a6cc..f19f844f49a5b 100644 --- a/python/paddle/tensor/ops.py +++ b/python/paddle/tensor/ops.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. + +from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only + from .. import _C_ops from ..fluid.data_feeder import check_variable_and_dtype from ..framework import LayerHelper, in_dynamic_mode @@ -47,6 +50,21 @@ 'round_', 'reciprocal_', 'sigmoid_', + 'abs_', + 'sin_', + 'sinh_', + 'asin_', + 'asinh_', + 'cos_', + 'cosh_', + 'acos_', + 'acosh_', + 'tan_', + 'atan_', + 'atanh_', + 'expm1_', + 'erf_', + 'square_', ] __all__ = [] @@ -76,7 +94,9 @@ _new_OP = _OP if _OP in __deprecated_func_name__: _new_OP = __deprecated_func_name__[_OP] - _func = generate_inplace_fn(_OP) + func = generate_inplace_fn(_OP) + func.__module__ = __name__ + _func = inplace_apis_in_dygraph_only(func) globals()[_OP] = _func add_sample_code( diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index 91a569d34c62b..d1dbc00dd55a4 100644 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import unittest import numpy as np @@ -123,6 +124,14 @@ def test_inplace_api(self): inplace_var[0] = 2.0 np.testing.assert_array_equal(var.numpy(), inplace_var.numpy()) + def test_forward_result(self): + var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + no_inplace_var = self.non_inplace_api_processing(var) + inplace_var = self.inplace_api_processing(var) + np.testing.assert_array_equal( + no_inplace_var.numpy(), inplace_var.numpy() + ) + def test_forward_version(self): with paddle.fluid.dygraph.guard(): var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) @@ -241,6 +250,52 @@ def test_backward_success_2(self): np.testing.assert_array_equal(grad_var_a_inplace, grad_var_a) +class TestDygraphInplaceWithContinuous(TestDygraphInplace): + def init_data(self): + self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1]) + self.dtype = "float32" + + def set_np_compare_func(self): + np_array_equal_with_nan = functools.partial( + np.array_equal, equal_nan=True + ) + self.np_compare = np_array_equal_with_nan + + def non_inplace_api_processing(self, var): + return paddle.sin(var) + + def inplace_api_processing(self, var): + return paddle.sin_(var) + + def test_continuous_inplace_backward(self): + # The api that only relies on input to calculate the gradient will copy input before + # the inpalce calculation, so here supports continuous inpalce backward calculation. + grad_var_a, grad_var_a_inplace = 0, 1 + with paddle.fluid.dygraph.guard(): + var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + var_a.stop_gradient = False + + var_b = var_a**2 + var_c = self.inplace_api_processing(var_b) + var_d = self.inplace_api_processing(var_c) + loss = var_d.sum() + loss.backward() + grad_var_a_inplace = var_a.grad.numpy() + + with paddle.fluid.dygraph.guard(): + var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + var_a.stop_gradient = False + + var_b = var_a**2 + var_c = self.non_inplace_api_processing(var_b) + var_d = self.non_inplace_api_processing(var_c) + loss = var_d.sum() + loss.backward() + grad_var_a = var_a.grad.numpy() + + self.assertTrue(self.np_compare(grad_var_a_inplace, grad_var_a)) + + class TestDygraphInplaceUnsqueeze(TestDygraphInplace): def non_inplace_api_processing(self, var): return paddle.unsqueeze(var, -1) @@ -506,5 +561,141 @@ def test_getitem_before_inplace(self): loss.backward() +class TestDygraphInplaceAsin(TestDygraphInplaceWithContinuous): + def non_inplace_api_processing(self, var): + return paddle.asin(var) + + def inplace_api_processing(self, var): + return paddle.asin_(var) + + +class TestDygraphInplaceSinh(TestDygraphInplaceWithContinuous): + def non_inplace_api_processing(self, var): + return paddle.sinh(var) + + def inplace_api_processing(self, var): + return paddle.sinh_(var) + + +class TestDygraphInplaceAsinh(TestDygraphInplaceWithContinuous): + def non_inplace_api_processing(self, var): + return paddle.asinh(var) + + def inplace_api_processing(self, var): + return paddle.asinh_(var) + + +class TestDygraphInplaceAbs(TestDygraphInplaceWithContinuous): + def non_inplace_api_processing(self, var): + return paddle.abs(var) + + def inplace_api_processing(self, var): + return paddle.abs_(var) + + +class TestDygraphInplaceCos(TestDygraphInplaceWithContinuous): + def non_inplace_api_processing(self, var): + return paddle.cos(var) + + def inplace_api_processing(self, var): + return paddle.cos_(var) + + +class TestDygraphInplaceCosh(TestDygraphInplaceWithContinuous): + def non_inplace_api_processing(self, var): + return paddle.cosh(var) + + def inplace_api_processing(self, var): + return paddle.cosh_(var) + + +class TestDygraphInplaceAcos(TestDygraphInplaceWithContinuous): + def non_inplace_api_processing(self, var): + return paddle.acos(var) + + def inplace_api_processing(self, var): + return paddle.acos_(var) + + +class TestDygraphInplaceAcosh(TestDygraphInplaceWithContinuous): + def non_inplace_api_processing(self, var): + return paddle.acosh(var) + + def inplace_api_processing(self, var): + return paddle.acosh_(var) + + +class TestDygraphInplaceTan(TestDygraphInplaceWithContinuous): + def non_inplace_api_processing(self, var): + return paddle.tan(var) + + def inplace_api_processing(self, var): + return paddle.tan_(var) + + +class TestDygraphInplaceATan(TestDygraphInplaceWithContinuous): + def non_inplace_api_processing(self, var): + return paddle.atan(var) + + def inplace_api_processing(self, var): + return paddle.atan_(var) + + +class TestDygraphInplaceATanh(TestDygraphInplaceWithContinuous): + def non_inplace_api_processing(self, var): + return paddle.atanh(var) + + def inplace_api_processing(self, var): + return paddle.atanh_(var) + + +class TestDygraphInplaceAddMM(TestDygraphInplaceWithContinuous): + def init_data(self): + self.input_var_numpy = np.random.uniform(-5, 5, [10, 10]) + self.dtype = "float32" + self.x = paddle.randn([10, 10], dtype="float32") + self.y = paddle.randn([10, 10], dtype="float32") + + def non_inplace_api_processing(self, var): + return paddle.addmm(var, x=self.x, y=self.y) + + def inplace_api_processing(self, var): + return paddle.addmm_(var, x=self.x, y=self.y) + + def test_errors(self): + var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + x1 = paddle.randn([10]) + self.assertRaises(ValueError, paddle.addmm_, var, x1, self.y) + + y1 = paddle.randn([12, 10]) + self.assertRaises(ValueError, paddle.addmm_, var, self.x, y1) + x2 = paddle.randn([12, 10]) + self.assertRaises(ValueError, paddle.addmm_, var, x2, self.y) + var1 = paddle.randn([1, 5]) + self.assertRaises(ValueError, paddle.addmm_, var1, x2, self.y) + y2 = paddle.randn([10, 12]) + self.assertRaises(ValueError, paddle.addmm_, var, self.x, y2) + var2 = paddle.randn([6]) + self.assertRaises(ValueError, paddle.addmm_, var2, self.x, self.y) + var3 = paddle.randn([2, 3, 4]) + self.assertRaises(ValueError, paddle.addmm_, var3, self.x, self.y) + + +class TestDygraphInplacePowerScalar(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.pow_(var, 2) + + def non_inplace_api_processing(self, var): + return paddle.pow(var, 2) + + def test_type_error(self): + var = paddle.to_tensor(self.input_var_numpy, dtype=self.dtype) + with self.assertRaisesRegex( + TypeError, + 'y must be scalar type, but received: %s ' % (type([2])), + ): + paddle.pow_(var, [2]) + + if __name__ == '__main__': unittest.main() diff --git a/test/legacy_test/test_pow.py b/test/legacy_test/test_pow.py index e829230492eee..011593b3e874e 100755 --- a/test/legacy_test/test_pow.py +++ b/test/legacy_test/test_pow.py @@ -15,7 +15,6 @@ import unittest import numpy as np -from test_inplace import TestDygraphInplace import paddle from paddle.fluid import core @@ -214,40 +213,5 @@ def test_errors(self): self.assertRaises(TypeError, paddle.pow, x, str(y)) -class TestInplacePowerScalar(TestDygraphInplace): - def set_np_compare_func(self): - self.np_compare = np.allclose - - def inplace_api_processing(self, var): - return paddle.pow_(var, 2) - - def non_inplace_api_processing(self, var): - return paddle.pow(var, 2) - - -class TestInplacePowerTensor(TestDygraphInplace): - def init_data(self): - self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1]) - self.dtype = "float32" - self.y = paddle.ones([10, 20, 1], dtype="float32") * 2 - - def set_np_compare_func(self): - self.np_compare = np.allclose - - def inplace_api_processing(self, var): - return paddle.pow_(var, self.y) - - def non_inplace_api_processing(self, var): - return paddle.pow(var, self.y) - - def test_type_error(self): - var = paddle.to_tensor(self.input_var_numpy, dtype=self.dtype) - with self.assertRaisesRegex( - TypeError, - 'y must be scalar or tensor type, but received: %s ' % (type([2])), - ): - paddle.pow_(var, [2]) - - if __name__ == '__main__': unittest.main()