From 8248107dd7eb6fe3b76b6d01c68bac65fdfd0d60 Mon Sep 17 00:00:00 2001
From: GGBond8488 <33050871+GGBond8488@users.noreply.github.com>
Date: Fri, 22 Sep 2023 11:14:48 +0800
Subject: [PATCH] add inplace api transpose_, t_, normal_, cauchy_, geometric_ (#57093)

* add inplace
* fix transpose inplace error
* fix error
* fix
* fix
* add gaussian inplace kernel
* change cauchy_ geometric_ impl
* fix typo
* add test
* remove gaussian test
* fix sample code error
* fix sample code
* fix sample code error
---
 paddle/phi/api/yaml/backward.yaml             |  11 ++
 paddle/phi/api/yaml/generator/api_base.py     |   1 +
 paddle/phi/api/yaml/legacy_ops.yaml           |   1 +
 paddle/phi/api/yaml/ops.yaml                  |  13 ++
 .../cpu/gaussian_inplace_grad_kernel.cc       |  41 +++++
 paddle/phi/kernels/cpu/gaussian_kernel.cc     |  31 ++++
 .../kernels/gaussian_inplace_grad_kernel.h    |  29 ++++
 paddle/phi/kernels/gaussian_kernel.h          |   8 +
 .../gpu/gaussian_inplace_grad_kernel.cu       |  44 +++++
 paddle/phi/kernels/gpu/gaussian_kernel.cu     |  32 ++++
 paddle/phi/kernels/stride/transpose_kernel.cc |   2 -
 python/paddle/__init__.py                     |  10 ++
 python/paddle/tensor/__init__.py              |  13 ++
 python/paddle/tensor/creation.py              |  69 ++++++++
 python/paddle/tensor/linalg.py                |  32 ++++
 python/paddle/tensor/random.py                |  73 ++++++++
 test/legacy_test/test_cauchy_inplace.py       | 139 ++++++++++++++++
 test/legacy_test/test_geometric_inplace.py    | 143 ++++++++++++++++
 test/legacy_test/test_inplace.py              |  48 ++++++
 test/legacy_test/test_normal_inplace.py       | 156 ++++++++++++++++++
 20 files changed, 894 insertions(+), 2 deletions(-)
 create mode 100644 paddle/phi/kernels/cpu/gaussian_inplace_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/gaussian_inplace_grad_kernel.h
 create mode 100644 paddle/phi/kernels/gpu/gaussian_inplace_grad_kernel.cu
 create mode 100644 test/legacy_test/test_cauchy_inplace.py
 create mode 100644 test/legacy_test/test_geometric_inplace.py
 create mode 100644 test/legacy_test/test_normal_inplace.py

diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 2f48bb80478e6c..66f5056320950e 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -939,6 +939,17 @@
     composite : gather_nd_grad(x, index, out_grad, x_grad)
     no_need_buffer : x
 
+- backward_op : gaussian_inplace_grad
+  forward : gaussian_inplace(Tensor x, float mean=0, float std=1.0, int seed=0) -> Tensor(out)
+  args : (Tensor out_grad, float mean=0, float std=1.0, int seed=0)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [out_grad]
+  kernel :
+    func : gaussian_inplace_grad
+    inplace : (out_grad -> x_grad)
+
 - backward_op : gelu_grad
   forward : gelu(Tensor x, bool approximate) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, bool approximate)
diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py
index cbf4ed1dab8371..5e7cff92131712 100644
--- a/paddle/phi/api/yaml/generator/api_base.py
+++ b/paddle/phi/api/yaml/generator/api_base.py
@@ -1223,6 +1223,7 @@ def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False):
                 "unsqueeze",
                 "reshape",
                 "flatten",
+                "transpose",
             ]:
                 i = 0
                 for kernel_out in outputs_args:
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index a647e02b35ef25..14daf99fd7f13a 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -1065,6 +1065,7 @@
     func : TransposeInferMeta
   kernel :
     func : transpose
+  inplace : (x -> out)
   backward : transpose_grad
 
 - op : tril
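Usage sketch (not part of the patch): with the `inplace : (x -> out)` entry above, the yaml code generator emits a `transpose_` API that mutates its input and hands it back. Assuming a dygraph build that includes this change, the observable behavior is:

    import paddle

    x = paddle.randn([2, 3])
    y = paddle.transpose_(x, perm=[1, 0])  # rewrites x in place, no new allocation

    assert y is x        # the inplace API returns its own input
    print(x.shape)       # [3, 2]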
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index c93f94c2b3320a..fdada46699d26a 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1052,6 +1052,19 @@
     func : gather_tree
     data_type : ids
 
+- op : gaussian_inplace
+  args: (Tensor x, float mean=0, float std=1.0, int seed=0)
+  output: Tensor(out)
+  infer_meta:
+    func: UnchangedInferMeta
+    param: [x]
+  kernel:
+    func: gaussian_inplace
+    data_type: x
+    backend : x
+  inplace: (x -> out)
+  backward: gaussian_inplace_grad
+
 - op : gelu
   args : (Tensor x, bool approximate = false)
   output : Tensor(out)
diff --git a/paddle/phi/kernels/cpu/gaussian_inplace_grad_kernel.cc b/paddle/phi/kernels/cpu/gaussian_inplace_grad_kernel.cc
new file mode 100644
index 00000000000000..5913e5b1a4e569
--- /dev/null
+++ b/paddle/phi/kernels/cpu/gaussian_inplace_grad_kernel.cc
@@ -0,0 +1,41 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/gaussian_inplace_grad_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GaussianInplaceGradKernel(const Context& ctx,
+                               const DenseTensor& out_grad UNUSED,
+                               float mean UNUSED,
+                               float std UNUSED,
+                               int seed UNUSED,
+                               DenseTensor* x_grad) {
+  if (x_grad) {
+    // The sampled output does not depend on x, so the gradient is all zeros.
+    auto* data = ctx.template Alloc<T>(x_grad);
+    std::fill(data, data + x_grad->numel(), T(0));
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(gaussian_inplace_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::GaussianInplaceGradKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/cpu/gaussian_kernel.cc b/paddle/phi/kernels/cpu/gaussian_kernel.cc
index 2eb783c695b65f..00ed6aaf357409 100644
--- a/paddle/phi/kernels/cpu/gaussian_kernel.cc
+++ b/paddle/phi/kernels/cpu/gaussian_kernel.cc
@@ -48,7 +48,38 @@ void GaussianKernel(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void GaussianInplaceKernel(const Context& dev_ctx,
+                           const DenseTensor& x,
+                           float mean,
+                           float std,
+                           int seed,
+                           DenseTensor* out) {
+  T* data = dev_ctx.template Alloc<T>(out);
+  std::normal_distribution<T> dist(mean, std);
+
+  int64_t size = out->numel();
+  std::shared_ptr<std::mt19937_64> engine;
+  if (seed) {
+    engine = std::make_shared<std::mt19937_64>();
+    engine->seed(seed);
+  } else {
+    engine = dev_ctx.GetGenerator()->GetCPUEngine();
+  }
+
+  for (int64_t i = 0; i < size; ++i) {
+    data[i] = dist(*engine);
+  }
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(
     gaussian, CPU, ALL_LAYOUT, phi::GaussianKernel, float, double) {}
+
+PD_REGISTER_KERNEL(gaussian_inplace,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::GaussianInplaceKernel,
+                   float,
+                   double) {}
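A quick way to sanity-check the CPU kernel from Python (a sketch using the `gaussian_` wrapper this patch adds in random.py; sample statistics are approximate by nature):

    import paddle

    paddle.set_device('cpu')
    x = paddle.zeros([10000], dtype='float32')
    paddle.tensor.random.gaussian_(x, mean=2.0, std=0.5, seed=42)

    print(float(x.mean()), float(x.std()))  # ≈ 2.0, ≈ 0.5 up to sampling noise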
diff --git a/paddle/phi/kernels/gaussian_inplace_grad_kernel.h b/paddle/phi/kernels/gaussian_inplace_grad_kernel.h
new file mode 100644
index 00000000000000..447b7199f695e5
--- /dev/null
+++ b/paddle/phi/kernels/gaussian_inplace_grad_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GaussianInplaceGradKernel(const Context& ctx,
+                               const DenseTensor& out_grad,
+                               float mean,
+                               float std,
+                               int seed,
+                               DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gaussian_kernel.h b/paddle/phi/kernels/gaussian_kernel.h
index a04c8802cf3854..5c24d9eb6eb146 100644
--- a/paddle/phi/kernels/gaussian_kernel.h
+++ b/paddle/phi/kernels/gaussian_kernel.h
@@ -29,4 +29,12 @@ void GaussianKernel(const Context& ctx,
                     DataType dtype,
                     DenseTensor* out);
 
+template <typename T, typename Context>
+void GaussianInplaceKernel(const Context& ctx,
+                           const DenseTensor& x,
+                           float mean,
+                           float std,
+                           int seed,
+                           DenseTensor* out);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/gaussian_inplace_grad_kernel.cu b/paddle/phi/kernels/gpu/gaussian_inplace_grad_kernel.cu
new file mode 100644
index 00000000000000..d2bb9c31fa67da
--- /dev/null
+++ b/paddle/phi/kernels/gpu/gaussian_inplace_grad_kernel.cu
@@ -0,0 +1,44 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/gaussian_inplace_grad_kernel.h"
+
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GaussianInplaceGradKernel(const Context& ctx,
+                               const DenseTensor& out_grad,
+                               float mean,
+                               float std,
+                               int seed,
+                               DenseTensor* x_grad) {
+  auto dims = vectorize(x_grad->dims());
+  float value = static_cast<float>(0.0f);
+  phi::FullKernel<T, Context>(ctx, dims, value, phi::DataType::UNDEFINED, x_grad);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(gaussian_inplace_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::GaussianInplaceGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/gaussian_kernel.cu b/paddle/phi/kernels/gpu/gaussian_kernel.cu
index d0f839bd677d47..6e5c7ee63ce531 100644
--- a/paddle/phi/kernels/gpu/gaussian_kernel.cu
+++ b/paddle/phi/kernels/gpu/gaussian_kernel.cu
@@ -76,6 +76,29 @@ void GaussianKernel(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void GaussianInplaceKernel(const Context& dev_ctx,
+                           const DenseTensor& x,
+                           float mean,
+                           float std,
+                           int seed,
+                           DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  if (seed == 0) {
+    // use global Generator seed
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    funcs::normal_distribution<MT> dist;
+    funcs::normal_transform<MT> trans(static_cast<MT>(mean),
+                                      static_cast<MT>(std));
+    funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
+  } else {
+    // use OP seed
+    auto func =
+        GaussianGenerator<T>(static_cast<T>(mean), static_cast<T>(std), seed);
+    IndexKernel<T, GaussianGenerator<T>>(dev_ctx, out, func);
+  }
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(gaussian,
@@ -86,3 +109,12 @@ PD_REGISTER_KERNEL(gaussian,
                    phi::dtype::bfloat16,
                    float,
                    double) {}
+
+PD_REGISTER_KERNEL(gaussian_inplace,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::GaussianInplaceKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/stride/transpose_kernel.cc b/paddle/phi/kernels/stride/transpose_kernel.cc
index 748beb5194d4ae..1fedb515ef020b 100644
--- a/paddle/phi/kernels/stride/transpose_kernel.cc
+++ b/paddle/phi/kernels/stride/transpose_kernel.cc
@@ -33,11 +33,9 @@ void TransposeStridedKernel(const Context& ctx,
   auto meta = out->meta();
   auto in_stride = x.strides();
-  auto in_dims = x.dims();
   meta.strides = in_stride;
 
   for (int i = 0; i < static_cast<int>(formated_axis.size()); i++) {
     meta.strides[i] = in_stride[formated_axis[i]];
-    meta.dims[i] = in_dims[formated_axis[i]];
   }
 
   meta.offset = x.offset();
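A note on the seed branch in the kernels above: seed == 0 consumes the global generator, while a non-zero seed takes a fixed per-op generator path, so repeated fills with the same explicit seed should reproduce the same samples. A minimal sketch of that assumption (not part of the patch):

    import paddle

    x = paddle.zeros([4])
    y = paddle.zeros([4])

    # same non-zero seed -> same per-op generator state -> identical draws
    paddle.tensor.random.gaussian_(x, seed=100)
    paddle.tensor.random.gaussian_(y, seed=100)
    assert bool((x == y).all())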
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 2fe1eecf21ff88..0c168b44c4a640 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -122,12 +122,16 @@
 from .tensor.creation import tril_indices  # noqa: F401
 from .tensor.creation import triu_indices  # noqa: F401
 from .tensor.creation import polar  # noqa: F401
+from .tensor.creation import geometric_  # noqa: F401
+from .tensor.creation import cauchy_  # noqa: F401
 from .tensor.linalg import matmul  # noqa: F401
 from .tensor.linalg import dot  # noqa: F401
 from .tensor.linalg import norm  # noqa: F401
 from .tensor.linalg import transpose  # noqa: F401
+from .tensor.linalg import transpose_  # noqa: F401
 from .tensor.linalg import dist  # noqa: F401
 from .tensor.linalg import t  # noqa: F401
+from .tensor.linalg import t_  # noqa: F401
 from .tensor.linalg import cdist  # noqa: F401
 from .tensor.linalg import cross  # noqa: F401
 from .tensor.linalg import cholesky  # noqa: F401
@@ -381,6 +385,7 @@
 from .tensor.random import multinomial  # noqa: F401
 from .tensor.random import standard_normal  # noqa: F401
 from .tensor.random import normal  # noqa: F401
+from .tensor.random import normal_  # noqa: F401
 from .tensor.random import uniform  # noqa: F401
 from .tensor.random import randn  # noqa: F401
 from .tensor.random import rand  # noqa: F401
@@ -505,6 +510,7 @@
     'allclose',
     'isclose',
     't',
+    't_',
     'add',
     'subtract',
     'diag',
@@ -556,6 +562,7 @@
     'any',
     'slice',
     'normal',
+    'normal_',
     'logsumexp',
     'full',
     'unsqueeze',
@@ -736,6 +743,9 @@
     'tanh',
     'tanh_',
     'transpose',
+    'transpose_',
+    'cauchy_',
+    'geometric_',
     'randn',
     'strided_slice',
     'unique',
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 841925f8b7ff88..b728392b0452de 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -43,6 +43,8 @@
 from .creation import empty_like  # noqa: F401
 from .creation import complex  # noqa: F401
 from .creation import polar  # noqa: F401
+from .creation import cauchy_  # noqa: F401
+from .creation import geometric_  # noqa: F401
 from .linalg import matmul  # noqa: F401
 from .linalg import dot  # noqa: F401
 from .linalg import cov  # noqa: F401
@@ -51,9 +53,11 @@
 from .linalg import pca_lowrank  # noqa: F401
 from .linalg import cond  # noqa: F401
 from .linalg import transpose  # noqa: F401
+from .linalg import transpose_  # noqa: F401
 from .linalg import lstsq  # noqa: F401
 from .linalg import dist  # noqa: F401
 from .linalg import t  # noqa: F401
+from .linalg import t_  # noqa: F401
 from .linalg import cross  # noqa: F401
 from .linalg import cholesky  # noqa: F401
 from .linalg import bmm  # noqa: F401
@@ -327,6 +331,7 @@
 from .random import multinomial  # noqa: F401
 from .random import standard_normal  # noqa: F401
 from .random import normal  # noqa: F401
+from .random import normal_  # noqa: F401
 from .random import uniform  # noqa: F401
 from .random import uniform_  # noqa: F401
 from .random import randn  # noqa: F401
@@ -381,9 +386,12 @@
     'norm',
     'cond',
     'transpose',
+    'cauchy_',
+    'geometric_',
     'lstsq',
     'dist',
     't',
+    't_',
     'cross',
     'cholesky',
     'bmm',
@@ -558,6 +566,10 @@
     'stack',
     'strided_slice',
     'transpose',
+    'transpose_',
+    'cauchy_',
+    'geometric_',
+    'tan_',
     'unique',
     'unique_consecutive',
     'unsqueeze',
@@ -673,6 +685,7 @@
     'i1e',
     'polygamma',
     'polygamma_',
+    'normal_',
 ]
 
 # this list used in math_op_patch.py for magic_method bind
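The creation.py changes below implement both new samplers by transforming an existing fill: cauchy_ applies the Cauchy quantile transform loc + scale * tan(pi * (u - 0.5)), and geometric_ applies log(u) / log1p(-p) to a uniform fill. A minimal NumPy sketch of the same transforms (illustrative only; note the patch feeds normal_() through the tan transform, relying on tan's periodicity to approximate a uniform draw, whereas the textbook construction below uses a uniform directly):

    import numpy as np

    rng = np.random.default_rng(0)
    u = rng.uniform(np.finfo(np.float32).tiny, 1.0, size=1_000_000)

    # Cauchy quantile transform: median lands at loc, HWHM at scale
    loc, scale = 1.0, 2.0
    print(np.median(loc + scale * np.tan(np.pi * (u - 0.5))))  # ~ 1.0

    # "continuous geometric": log(u) / log1p(-p)
    p = 0.3
    print((np.log(u) / np.log1p(-p)).mean())  # ~ -1/log(0.7) ≈ 2.80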
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index f764fbb45996d2..3f543ea29d0037 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -39,6 +39,7 @@
     _get_paddle_place,
     convert_np_dtype_to_dtype_,
     core,
+    dygraph_only,
     in_dynamic_mode,
     in_dynamic_or_pir_mode,
     in_pir_mode,
@@ -2655,3 +2656,71 @@ def polar(abs, angle, name=None):
     )
 
     return paddle.complex(abs * paddle.cos(angle), abs * paddle.sin(angle))
+
+
+@dygraph_only
+def cauchy_(x, loc=0, scale=1, name=None):
+    """Fills the tensor with numbers drawn from the Cauchy distribution.
+
+    Args:
+        x (Tensor): the tensor to be filled. The data type is float32 or float64.
+        loc (scalar, optional): Location of the peak of the distribution. The data type is float32 or float64.
+        scale (scalar, optional): The half-width at half-maximum (HWHM). The data type is float32 or float64. Must be positive.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        Tensor: input tensor with numbers drawn from the Cauchy distribution.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> x = paddle.randn([3, 4])
+            >>> x.cauchy_(1, 2)
+            >>> # doctest: +SKIP('random check')
+            >>> print(x)
+            Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [[ 3.80087137,  2.25415039,  2.77960515,  7.64125967],
+             [ 0.76541221,  2.74023032,  1.99383152, -0.12685823],
+             [ 1.45228469,  1.76275957, -4.30458832,  34.74880219]])
+
+    """
+    x.normal_()
+    loc = paddle.to_tensor(loc).astype(x.dtype)
+    half = paddle.to_tensor(0.5).astype(x.dtype)
+    x.subtract_(half).scale_(np.pi).tan_().scale_(scale).add_(loc)
+    return x
+
+
+@dygraph_only
+def geometric_(x, probs, name=None):
+    """Fills the tensor with numbers drawn from the Geometric distribution.
+
+    Args:
+        x (Tensor): the tensor to be filled. The data type is float32 or float64.
+        probs (Real|Tensor): Probability parameter.
+            The value of probs must be positive. When the parameter is a tensor, probs is the probability of success for each trial.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        Tensor: input tensor with numbers drawn from the Geometric distribution.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> x = paddle.randn([3, 4])
+            >>> x.geometric_(0.3)
+            >>> # doctest: +SKIP('random check')
+            >>> print(x)
+            Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [[2.42739224, 4.78268528, 1.23302543, 3.76555204],
+             [1.38877118, 0.16075331, 0.16401523, 2.47349310],
+             [1.72872102, 2.76533413, 0.33410925, 1.63351011]])
+
+    """
+    tiny = np.finfo(dtype=convert_dtype(x.dtype)).tiny
+    probs = paddle.to_tensor(probs).astype(x.dtype)
+    x.uniform_(min=float(tiny), max=float(1))
+    x.log_().divide_(paddle.log1p(-probs))
+    return x
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 286dcd261d8fef..acc59d6385e570 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -17,6 +17,7 @@
 import paddle
 from paddle import _C_ops
 from paddle.common_ops_import import VarDesc
+from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
 
 from ..base.data_feeder import check_dtype, check_type, check_variable_and_dtype
 from ..common_ops_import import Variable
@@ -130,6 +131,16 @@ def transpose(x, perm, name=None):
     return out
 
 
+@inplace_apis_in_dygraph_only
+def transpose_(x, perm, name=None):
+    r"""
+    Inplace version of ``transpose`` API, the output Tensor will be inplaced with input ``x``.
+    Please refer to :ref:`api_paddle_transpose`.
+    """
+    if in_dynamic_mode():
+        return _C_ops.transpose_(x, perm)
+
+
 def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
     """
     Applies matrix multiplication to two tensors. `matmul` follows
@@ -1394,6 +1405,27 @@ def t(input, name=None):
     return out
 
 
+@inplace_apis_in_dygraph_only
+def t_(input, name=None):
+    r"""
+    Inplace version of ``t`` API, the output Tensor will be inplaced with input ``input``.
+    Please refer to :ref:`api_paddle_t`.
+    """
+    if len(input.shape) > 2:
+        raise ValueError(
+            "Input(input) only supports an N-D (N<=2) tensor, but received "
+            "length of Input(input) is %s. Perhaps you can use paddle."
+            "tensor.transpose() instead." % len(input.shape)
+        )
+    if in_dynamic_mode():
+        if len(input.shape) <= 1:
+            return input
+        # 2-D tensor
+        perm = [1, 0]
+        out = _C_ops.transpose_(input, perm)
+        return out
+
+
 def cross(x, y, axis=9, name=None):
     """
     Computes the cross product between two tensors along an axis.
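A small dygraph sketch of the t_ semantics added above (an illustration, not part of the patch): 0-D and 1-D inputs are returned unchanged, while a 2-D input is transposed in place via _C_ops.transpose_:

    import paddle

    v = paddle.randn([5])
    assert paddle.t_(v) is v            # 0-D/1-D inputs are returned as-is

    m = paddle.randn([10, 20])
    mt = paddle.t_(m)                   # 2-D input: transposed in place
    assert mt is m and m.shape == [20, 10]
    # >2-D inputs raise ValueError; use paddle.transpose_ with an explicit perm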
diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py
index f32978ca50706e..46ee4ff6920b91 100644
--- a/python/paddle/tensor/random.py
+++ b/python/paddle/tensor/random.py
@@ -394,6 +394,40 @@ def gaussian(shape, mean=0.0, std=1.0, seed=0, dtype=None, name=None):
     return out
 
 
+@dygraph_only
+def gaussian_(x, mean=0.0, std=1.0, seed=0, name=None):
+    """
+    This is the inplace version of OP ``gaussian``, which returns a Tensor filled
+    with random values sampled from a Gaussian distribution. The output Tensor will
+    be inplaced with input ``x``. Please refer to :ref:`api_tensor_gaussian`.
+
+    Args:
+        x (Tensor): The input tensor to be filled with random values.
+        mean (float|int, optional): Mean of the output tensor, default is 0.0.
+        std (float|int, optional): Standard deviation of the output tensor, default
+            is 1.0.
+        seed (int, optional): Random seed of generator.
+        name (str, optional): The default value is None. Normally there is no
+            need for user to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
+    Returns:
+        Tensor: The input tensor x filled with random values sampled from a Gaussian
+        distribution.
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> x = paddle.randn([3, 4])
+            >>> paddle.tensor.random.gaussian_(x)
+            >>> # doctest: +SKIP('random check')
+            >>> print(x)
+            Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [[ 0.86384124,  0.67328387,  0.21874231, -0.12615913],
+             [ 0.69844258,  0.42084831, -0.42476156, -0.00072985],
+             [ 1.72819555,  1.87785017,  0.48915744,  0.09235018]])
+    """
+    return _C_ops.gaussian_inplace_(x, float(mean), float(std), int(seed))
+
+
 def standard_normal(shape, dtype=None, name=None):
     """
     Returns a Tensor filled with random values sampled from a standard
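normal_ (added in the next hunk) simply forwards to gaussian_ with the default seed=0, so reproducibility comes from the global generator. A sketch of that assumption (hypothetical session; it presumes paddle.seed resets the generator consumed when seed=0 and no other RNG consumers run in between):

    import paddle

    paddle.seed(2023)
    a = paddle.zeros([3]).normal_()
    paddle.seed(2023)
    b = paddle.zeros([3]).normal_()

    assert bool((a == b).all())  # same global seed, same draw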
@@ -627,6 +661,45 @@ def normal(mean=0.0, std=1.0, shape=None, name=None):
     return out
 
 
+@dygraph_only
+def normal_(x, mean=0.0, std=1.0, name=None):
+    """
+    This is the inplace version of API ``normal``, which returns a Tensor filled
+    with random values sampled from a normal distribution. The output Tensor will
+    be inplaced with input ``x``. Please refer to :ref:`api_tensor_normal`.
+
+    Args:
+        x (Tensor): The input tensor to be filled with random values.
+        mean (float|Tensor, optional): The mean of the output Tensor's normal distribution.
+            If ``mean`` is float, all elements of the output Tensor share the same mean.
+            If ``mean`` is a Tensor (data type supports float32, float64), it has per-element means.
+            Default is 0.0.
+        std (float|Tensor, optional): The standard deviation of the output Tensor's normal distribution.
+            If ``std`` is float, all elements of the output Tensor share the same standard deviation.
+            If ``std`` is a Tensor (data type supports float32, float64), it has per-element standard deviations.
+            Default is 1.0.
+        name (str, optional): The default value is None. Normally there is no
+            need for user to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
+    Returns:
+        A Tensor filled with random values sampled from a normal distribution with ``mean`` and ``std``.
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> x = paddle.randn([3, 4])
+            >>> x.normal_()
+            >>> # doctest: +SKIP('random check')
+            >>> print(x)
+            Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [[ 0.06132207,  1.11349595,  0.41906244, -0.24858207],
+             [-1.85169315, -1.50370061,  1.73954511,  0.13331604],
+             [ 1.66359663, -0.55764782, -0.59911072, -0.57773495]])
+
+    """
+    return gaussian_(x, mean=mean, std=std)
+
+
 def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
     """
     Returns a Tensor filled with random values sampled from a uniform
diff --git a/test/legacy_test/test_cauchy_inplace.py b/test/legacy_test/test_cauchy_inplace.py
new file mode 100644
index 00000000000000..7c2b05bc64729d
--- /dev/null
+++ b/test/legacy_test/test_cauchy_inplace.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle import base
+
+
+class TestCauchyInplaceDtype(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def test_cauchy_type(self):
+        def test_fp32():
+            tensor_fp32 = paddle.ones(self.shape, dtype=paddle.float32)
+            tensor_fp32.cauchy_()
+            self.assertEqual(tensor_fp32.dtype, paddle.float32)
+
+        def test_fp64():
+            tensor_fp64 = paddle.ones(self.shape, paddle.float64)
+            tensor_fp64.cauchy_()
+            self.assertEqual(tensor_fp64.dtype, paddle.float64)
+
+        places = ['cpu']
+        if base.core.is_compiled_with_cuda():
+            places.append('gpu')
+        for place in places:
+            paddle.set_device(place)
+            test_fp32()
+            test_fp64()
+
+
+class TestCauchyIsInplace(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def test_cauchy_inplace_op_is_inplace(self):
+        tensor_a = paddle.ones(self.shape)
+        tensor_b = tensor_a.cauchy_()
+        self.assertTrue(tensor_a is tensor_b)
+
+
+class TestCauchyInplaceSeedIsZero(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def test_cauchy_inplace_op_not_equal(self):
+        tensor = paddle.ones(self.shape)
+        tensor.cauchy_()
+        tensor_data_first = tensor.numpy()
+        tensor.cauchy_()
+        tensor_data_second = tensor.numpy()
+        self.assertFalse((tensor_data_first == tensor_data_second).all())
+
+
+class TestCauchyInplaceOpShape(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def test_cauchy_inplace_op_shape(self):
+        tensor = paddle.ones(self.shape)
+        tensor.cauchy_()
+        tensor_shape_np = np.array(tensor.shape)
+        origin_shape = np.array(self.shape)
+        self.assertTrue((tensor_shape_np == origin_shape).all())
+
+
+class TestCauchyInplaceDistribution(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+        self.loc = -3
+        self.scale = 5
+
+    def test_cauchy_inplace_distribution(self):
+        tensor = paddle.ones(self.shape)
+        tensor.cauchy_(loc=self.loc, scale=self.scale)
+        median = tensor.median()
+        np.testing.assert_allclose(median, self.loc, atol=1e-1)
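TestCauchyInplaceDistribution checks the median rather than the mean because the Cauchy distribution has no finite mean or variance; its median equals loc. A standalone NumPy check of the same property (a sketch, not part of the test suite):

    import numpy as np

    rng = np.random.default_rng(7)
    u = rng.uniform(size=1_000_000)
    samples = -3 + 5 * np.tan(np.pi * (u - 0.5))

    print(np.median(samples))  # ≈ -3, the loc parameter
    print(samples.mean())      # unstable: the Cauchy mean does not exist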
+
+
+class TestCauchyInplaceEmptyTensor(unittest.TestCase):
+    def test_cauchy_inplace_op_empty_tensor(self):
+        places = ['cpu']
+        if base.core.is_compiled_with_cuda():
+            places.append('gpu')
+        test_shapes = [(200, 1), (1, 200)]
+        for place in places:
+            paddle.set_device(place)
+            for test_shape in test_shapes:
+                tensor = paddle.empty(shape=test_shape)
+                tensor.cauchy_()
+                tensor_shape_np = np.array(tensor.shape)
+                origin_shape = np.array(test_shape)
+                self.assertTrue((tensor_shape_np == origin_shape).all())
+
+
+class TestCauchyInplaceGrad(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def run_(self):
+        def test_grad():
+            tensor_a = paddle.ones(self.shape)
+            tensor_a.stop_gradient = False
+            tensor_b = tensor_a * 0.5
+            tensor_b.retain_grads()
+            tensor_b.cauchy_()
+            loss = tensor_b.sum()
+            loss.backward()
+            cauchy_grad = tensor_b.grad.numpy()
+            self.assertTrue((cauchy_grad == 0).all())
+
+        places = ['cpu']
+        if base.core.is_compiled_with_cuda():
+            places.append('gpu')
+        for place in places:
+            paddle.set_device(place)
+            test_grad()
+
+    def test_cauchy_inplace_grad(self):
+        self.run_()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/legacy_test/test_geometric_inplace.py b/test/legacy_test/test_geometric_inplace.py
new file mode 100644
index 00000000000000..20f39621f2490a
--- /dev/null
+++ b/test/legacy_test/test_geometric_inplace.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
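The *Grad tests above hinge on the gaussian_inplace_grad kernels added earlier in this patch: the sampled output is independent of the op's input, so the backward pass writes all-zero gradients, which is exactly what the tests assert. A compact standalone check mirroring the test logic (a sketch):

    import paddle

    x = paddle.ones([4])
    x.stop_gradient = False
    y = x * 2.0
    y.retain_grads()
    y.normal_()         # overwrite y; the samples do not depend on x
    y.sum().backward()

    print(y.grad)       # all zeros, produced by gaussian_inplace_grad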
+
+import unittest
+
+import numpy as np
+import scipy.stats
+
+import paddle
+from paddle import base
+
+
+class TestGeometricInplaceDtype(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def test_geometric_type(self):
+        def test_fp32():
+            tensor_fp32 = paddle.ones(self.shape, dtype=paddle.float32)
+            tensor_fp32.geometric_(probs=0.3)
+            self.assertEqual(tensor_fp32.dtype, paddle.float32)
+
+        def test_fp64():
+            tensor_fp64 = paddle.ones(self.shape, paddle.float64)
+            tensor_fp64.geometric_(probs=0.3)
+            self.assertEqual(tensor_fp64.dtype, paddle.float64)
+
+        places = ['cpu']
+        if base.core.is_compiled_with_cuda():
+            places.append('gpu')
+        for place in places:
+            paddle.set_device(place)
+            test_fp32()
+            test_fp64()
+
+
+class TestGeometricIsInplace(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def test_geometric_inplace_op_is_inplace(self):
+        tensor_a = paddle.ones(self.shape)
+        tensor_b = tensor_a.geometric_(probs=0.3)
+        self.assertTrue(tensor_a is tensor_b)
+
+
+class TestGeometricInplaceSeedIsZero(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def test_geometric_inplace_op_not_equal(self):
+        tensor = paddle.ones(self.shape)
+        tensor.geometric_(probs=0.3)
+        tensor_data_first = tensor.numpy()
+        tensor.geometric_(probs=0.3)
+        tensor_data_second = tensor.numpy()
+        self.assertFalse((tensor_data_first == tensor_data_second).all())
+
+
+class TestGeometricInplaceOpShape(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def test_geometric_inplace_op_shape(self):
+        tensor = paddle.ones(self.shape)
+        tensor.geometric_(probs=0.3)
+        tensor_shape_np = np.array(tensor.shape)
+        origin_shape = np.array(self.shape)
+        self.assertTrue((tensor_shape_np == origin_shape).all())
+
+
+class TestGeometricInplaceDistribution(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+        self.probs = 0.3
+
+    def test_geometric_inplace_distribution(self):
+        a = paddle.ones(self.shape)
+        a.geometric_(self.probs)
+        np.testing.assert_allclose(
+            a.mean(axis=0), scipy.stats.geom.mean(self.probs), rtol=0.7, atol=0
+        )
+        np.testing.assert_allclose(
+            a.var(axis=0), scipy.stats.geom.var(self.probs), rtol=0.7, atol=0
+        )
+
+
+class TestGeometricInplaceEmptyTensor(unittest.TestCase):
+    def test_geometric_inplace_op_empty_tensor(self):
+        places = ['cpu']
+        if base.core.is_compiled_with_cuda():
+            places.append('gpu')
+        test_shapes = [(200, 1), (1, 200)]
+        for place in places:
+            paddle.set_device(place)
+            for test_shape in test_shapes:
+                tensor = paddle.empty(shape=test_shape)
+                tensor.geometric_(probs=0.3)
+                tensor_shape_np = np.array(tensor.shape)
+                origin_shape = np.array(test_shape)
+                self.assertTrue((tensor_shape_np == origin_shape).all())
+
+
+class TestGeometricInplaceGrad(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def run_(self):
+        def test_grad():
+            tensor_a = paddle.ones(self.shape)
+            tensor_a.stop_gradient = False
+            tensor_b = tensor_a * 0.5
+            tensor_b.retain_grads()
+            tensor_b.geometric_(probs=0.3)
+            loss = tensor_b.sum()
+            loss.backward()
+            geometric_grad = tensor_b.grad.numpy()
+            self.assertTrue((geometric_grad == 0).all())
+
+        places = ['cpu']
+        if base.core.is_compiled_with_cuda():
+            places.append('gpu')
+        for place in places:
+            paddle.set_device(place)
+            test_grad()
+
+    def test_geometric_inplace_grad(self):
+        self.run_()
+
+
+if __name__ == '__main__':
+    unittest.main()
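A note on the loose rtol=0.7 in TestGeometricInplaceDistribution: scipy.stats.geom describes the discrete geometric law (mean 1/p, variance (1-p)/p^2), while geometric_ samples the continuous variant log(u)/log1p(-p), an exponential with rate -log(1-p) and mean -1/log(1-p) ≈ 2.80 for p = 0.3 versus 1/p ≈ 3.33. The generous tolerance absorbs that gap. A sketch of the comparison (illustrative only):

    import numpy as np
    import scipy.stats

    p = 0.3
    rng = np.random.default_rng(0)
    u = rng.uniform(np.finfo(np.float32).tiny, 1.0, size=1_000_000)

    print((np.log(u) / np.log1p(-p)).mean())  # ≈ -1/log(1-p) ≈ 2.80
    print(scipy.stats.geom.mean(p))           # 1/p ≈ 3.33 for the discrete law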
diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py
index 676977ba2ac484..e3f1de1048e113 100644
--- a/test/legacy_test/test_inplace.py
+++ b/test/legacy_test/test_inplace.py
@@ -1434,5 +1434,53 @@ def non_inplace_api_processing(self, var):
         return paddle.multiply(var, self.y)
 
 
+class TestDygraphInplaceT(TestDygraphInplaceWithContinuous):
+    def init_data(self):
+        self.input_var_numpy = np.random.uniform(-5, 5, [10, 20])
+        self.dtype = "float32"
+
+    def inplace_api_processing(self, var):
+        return paddle.t_(var)
+
+    def non_inplace_api_processing(self, var):
+        return paddle.t(var)
+
+    def test_forward_version(self):
+        with paddle.base.dygraph.guard():
+            var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
+            self.assertEqual(var.inplace_version, 0)
+
+            inplace_var = self.inplace_api_processing(var)
+            self.assertEqual(var.inplace_version, 1)
+
+            inplace_var[0] = 2
+            self.assertEqual(var.inplace_version, 1)
+
+            inplace_var = self.inplace_api_processing(inplace_var)
+            self.assertEqual(var.inplace_version, 2)
+
+
+class TestDygraphInplaceTranspose(TestDygraphInplaceWithContinuous):
+    def inplace_api_processing(self, var):
+        return paddle.transpose_(var, [1, 0, 2])
+
+    def non_inplace_api_processing(self, var):
+        return paddle.transpose(var, [1, 0, 2])
+
+    def test_forward_version(self):
+        with paddle.base.dygraph.guard():
+            var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
+            self.assertEqual(var.inplace_version, 0)
+
+            inplace_var = self.inplace_api_processing(var)
+            self.assertEqual(var.inplace_version, 1)
+
+            inplace_var[0] = 2
+            self.assertEqual(var.inplace_version, 1)
+
+            inplace_var = self.inplace_api_processing(inplace_var)
+            self.assertEqual(var.inplace_version, 2)
+
+
 if __name__ == '__main__':
     unittest.main()
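These overrides exercise Paddle's inplace_version counter, which the dygraph engine uses to detect tensors mutated after being captured for backward: each inplace API call on a tensor bumps its version. A sketch of the observable behavior in dygraph mode (illustrative, not part of the test file):

    import paddle

    x = paddle.randn([10, 20])
    print(x.inplace_version)      # 0
    paddle.t_(x)
    print(x.inplace_version)      # 1
    paddle.transpose_(x, [1, 0])
    print(x.inplace_version)      # 2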
diff --git a/test/legacy_test/test_normal_inplace.py b/test/legacy_test/test_normal_inplace.py
new file mode 100644
index 00000000000000..dc693a6652561c
--- /dev/null
+++ b/test/legacy_test/test_normal_inplace.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle import base
+
+
+def output_hist(out):
+    hist, _ = np.histogram(out, range=(-1, 1))
+    hist = hist.astype("float32")
+    hist /= float(out.size)
+    prob = 0.1 * np.ones(10)
+    return hist, prob
+
+
+class TestNormalRandomInplaceOpDtype(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def test_normal_inplace_op_dtype(self):
+        def test_fp32():
+            tensor_fp32 = paddle.ones(self.shape, dtype=paddle.float32)
+            tensor_fp32.normal_()
+            self.assertEqual(tensor_fp32.dtype, paddle.float32)
+
+        def test_fp64():
+            tensor_fp64 = paddle.ones(self.shape, paddle.float64)
+            tensor_fp64.normal_()
+            self.assertEqual(tensor_fp64.dtype, paddle.float64)
+
+        places = ['cpu']
+        if base.core.is_compiled_with_cuda():
+            places.append('gpu')
+        for place in places:
+            paddle.set_device(place)
+            test_fp32()
+            test_fp64()
+
+
+class TestNormalRandomInplaceOpIsInplace(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def test_normal_inplace_op_is_inplace(self):
+        tensor_a = paddle.ones(self.shape)
+        tensor_b = tensor_a.normal_()
+        self.assertTrue(tensor_a is tensor_b)
+
+
+class TestNormalRandomInplaceOpSeedIsZero(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def test_normal_inplace_op_not_equal(self):
+        tensor = paddle.ones(self.shape)
+        tensor.normal_()
+        tensor_data_first = tensor.numpy()
+        tensor.normal_()
+        tensor_data_second = tensor.numpy()
+        self.assertFalse((tensor_data_first == tensor_data_second).all())
+
+
+class TestNormalRandomInplaceOpShape(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def test_normal_inplace_op_shape(self):
+        tensor = paddle.ones(self.shape)
+        tensor.normal_()
+        tensor_shape_np = np.array(tensor.shape)
+        origin_shape = np.array(self.shape)
+        self.assertTrue((tensor_shape_np == origin_shape).all())
+
+
+class TestNormalRandomInplaceOpDistribution(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+        self.mean = -3
+        self.std = 5
+
+    def test_normal_inplace_op_distribution(self):
+        tensor = paddle.ones(self.shape)
+        tensor.normal_(self.mean, self.std)
+
+        ones = paddle.ones(self.shape)
+        zeros = paddle.zeros(self.shape)
+        all_num = self.shape[0] * self.shape[1]
+
+        std_probs = [0.68, 0.95, 0.997]
+        for index, prob in enumerate(std_probs):
+            left = self.mean - (index + 1) * self.std
+            right = self.mean + (index + 1) * self.std
+            cond = paddle.logical_and(tensor >= left, tensor <= right)
+            c_sum = paddle.where(cond, ones, zeros).sum()
+            np.testing.assert_allclose((c_sum / all_num), prob, 1e-2)
+
+
+class TestNormalRandomInplaceOpEmptyTensor(unittest.TestCase):
+    def test_normal_inplace_op_empty_tensor(self):
+        places = ['cpu']
+        if base.core.is_compiled_with_cuda():
+            places.append('gpu')
+        test_shapes = [(200, 0), (0, 200)]
+        for place in places:
+            paddle.set_device(place)
+            for test_shape in test_shapes:
+                tensor = paddle.empty(shape=test_shape)
+                tensor.normal_()
+                tensor_shape_np = np.array(tensor.shape)
+                origin_shape = np.array(test_shape)
+                self.assertTrue((tensor_shape_np == origin_shape).all())
+
+
+class TestNormalRandomInplaceGrad(unittest.TestCase):
+    def setUp(self):
+        self.shape = (1000, 784)
+
+    def run_(self):
+        def test_grad():
+            tensor_a = paddle.ones(self.shape)
+            tensor_a.stop_gradient = False
+            tensor_b = tensor_a * 0.5
+            tensor_b.retain_grads()
+            tensor_b.normal_(mean=-2, std=2)
+            loss = tensor_b.sum()
+            loss.backward()
+            normal_grad = tensor_b.grad.numpy()
+            self.assertTrue((normal_grad == 0).all())
+
+        places = ['cpu']
+        if base.core.is_compiled_with_cuda():
+            places.append('gpu')
+        for place in places:
+            paddle.set_device(place)
+            test_grad()
+
+    def test_normal_inplace_grad(self):
+        self.run_()
+
+
+if __name__ == '__main__':
+    unittest.main()