diff --git a/python/paddle/fluid/tests/unittests/test_cosine_embedding_loss.py b/python/paddle/fluid/tests/unittests/test_cosine_embedding_loss.py
new file mode 100644
index 0000000000000..f95089a4cdb51
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_cosine_embedding_loss.py
@@ -0,0 +1,328 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from __future__ import print_function
+
+import paddle
+import paddle.static as static
+import numpy as np
+import unittest
+
+
+def cosine_embedding_loss(input1, input2, label, margin=0.5, reduction='mean'):
+    z = (input1 * input2).sum(axis=-1)
+    mag_square1 = np.square(input1).sum(axis=-1) + 10e-12
+    mag_square2 = np.square(input2).sum(axis=-1) + 10e-12
+    denom = np.sqrt(mag_square1 * mag_square2)
+    cos = z / denom
+    zeros = np.zeros_like(cos)
+    pos = 1 - cos
+    neg = np.clip(cos - margin, a_min=0, a_max=np.inf)
+    out_pos = np.where(label == 1, pos, zeros)
+    out_neg = np.where(label == -1, neg, zeros)
+    out = out_pos + out_neg
+    if reduction == 'none':
+        return out
+    if reduction == 'mean':
+        return np.mean(out)
+    elif reduction == 'sum':
+        return np.sum(out)
+
+
+class TestFunctionCosineEmbeddingLoss(unittest.TestCase):
+
+    def setUp(self):
+        self.input1_np = np.random.random(size=(5, 3)).astype(np.float64)
+        self.input2_np = np.random.random(size=(5, 3)).astype(np.float64)
+        a = np.array([-1, -1, -1]).astype(np.int32)
+        b = np.array([1, 1]).astype(np.int32)
+        self.label_np = np.concatenate((a, b), axis=0)
+        np.random.shuffle(self.label_np)
+
+    def run_dynamic(self):
+        input1 = paddle.to_tensor(self.input1_np)
+        input2 = paddle.to_tensor(self.input2_np)
+        label = paddle.to_tensor(self.label_np)
+        dy_result = paddle.nn.functional.cosine_embedding_loss(input1,
+                                                               input2,
+                                                               label,
+                                                               margin=0.5,
+                                                               reduction='mean')
+        expected1 = cosine_embedding_loss(self.input1_np,
+                                          self.input2_np,
+                                          self.label_np,
+                                          margin=0.5,
+                                          reduction='mean')
+        self.assertTrue(np.allclose(dy_result.numpy(), expected1))
+        self.assertEqual(dy_result.shape, [1])
+
+        dy_result = paddle.nn.functional.cosine_embedding_loss(input1,
+                                                               input2,
+                                                               label,
+                                                               margin=0.5,
+                                                               reduction='sum')
+        expected2 = cosine_embedding_loss(self.input1_np,
+                                          self.input2_np,
+                                          self.label_np,
+                                          margin=0.5,
+                                          reduction='sum')
+
+        self.assertTrue(np.allclose(dy_result.numpy(), expected2))
+        self.assertEqual(dy_result.shape, [1])
+
+        dy_result = paddle.nn.functional.cosine_embedding_loss(input1,
+                                                               input2,
+                                                               label,
+                                                               margin=0.5,
+                                                               reduction='none')
+        expected3 = cosine_embedding_loss(self.input1_np,
+                                          self.input2_np,
+                                          self.label_np,
+                                          margin=0.5,
+                                          reduction='none')
+
+        self.assertTrue(np.allclose(dy_result.numpy(), expected3))
+        self.assertEqual(dy_result.shape, [5])
+
+    def run_static(self, use_gpu=False):
+        input1 = static.data(name='input1', shape=[5, 3], dtype='float64')
+        input2 = static.data(name='input2', shape=[5, 3], dtype='float64')
+        label = static.data(name='label', shape=[5], dtype='int32')
+        result0 = paddle.nn.functional.cosine_embedding_loss(input1,
+                                                             input2,
+                                                             label,
+                                                             margin=0.5,
+                                                             reduction='none')
+        result1 = paddle.nn.functional.cosine_embedding_loss(input1,
+                                                             input2,
+                                                             label,
+                                                             margin=0.5,
+                                                             reduction='sum')
+        result2 = paddle.nn.functional.cosine_embedding_loss(input1,
+                                                             input2,
+                                                             label,
+                                                             margin=0.5,
+                                                             reduction='mean')
+
+        place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
+        exe = static.Executor(place)
+        exe.run(static.default_startup_program())
+        static_result = exe.run(feed={
+            "input1": self.input1_np,
+            "input2": self.input2_np,
+            "label": self.label_np
+        },
+                                fetch_list=[result0, result1, result2])
+        expected = cosine_embedding_loss(self.input1_np,
+                                         self.input2_np,
+                                         self.label_np,
+                                         margin=0.5,
+                                         reduction='none')
+
+        self.assertTrue(np.allclose(static_result[0], expected))
+        expected = cosine_embedding_loss(self.input1_np,
+                                         self.input2_np,
+                                         self.label_np,
+                                         margin=0.5,
+                                         reduction='sum')
+
+        self.assertTrue(np.allclose(static_result[1], expected))
+        expected = cosine_embedding_loss(self.input1_np,
+                                         self.input2_np,
+                                         self.label_np,
+                                         margin=0.5,
+                                         reduction='mean')
+
+        self.assertTrue(np.allclose(static_result[2], expected))
+
+    def test_cpu(self):
+        paddle.disable_static(place=paddle.CPUPlace())
+        self.run_dynamic()
+        paddle.enable_static()
+
+        with static.program_guard(static.Program()):
+            self.run_static()
+
+    def test_gpu(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        paddle.disable_static(place=paddle.CUDAPlace(0))
+        self.run_dynamic()
+        paddle.enable_static()
+
+        with static.program_guard(static.Program()):
+            self.run_static(use_gpu=True)
+
+    def test_errors(self):
+        paddle.disable_static()
+        input1 = paddle.to_tensor(self.input1_np)
+        input2 = paddle.to_tensor(self.input2_np)
+        label = paddle.to_tensor(self.label_np)
+
+        def test_label_shape_error():
+            label = paddle.to_tensor(
+                np.random.randint(low=0, high=2, size=(2, 3)))
+            paddle.nn.functional.cosine_embedding_loss(input1,
+                                                       input2,
+                                                       label,
+                                                       margin=0.5,
+                                                       reduction='mean')
+
+        self.assertRaises(ValueError, test_label_shape_error)
+
+        def test_input_different_shape_error():
+            input1 = paddle.to_tensor(self.input1_np[0])
+            label = paddle.to_tensor(np.array([1]))
+            paddle.nn.functional.cosine_embedding_loss(input1,
+                                                       input2,
+                                                       label,
+                                                       margin=0.5,
+                                                       reduction='mean')
+
+        self.assertRaises(ValueError, test_input_different_shape_error)
+
+        def test_input_shape2D_error():
+            input1 = paddle.to_tensor(
+                np.random.random(size=(2, 3, 4)).astype(np.float64))
+            input2 = paddle.to_tensor(
+                np.random.random(size=(2, 3, 4)).astype(np.float64))
+            paddle.nn.functional.cosine_embedding_loss(input1,
+                                                       input2,
+                                                       label,
+                                                       margin=0.5,
+                                                       reduction='mean')
+
+        self.assertRaises(ValueError, test_input_shape2D_error)
+
+        def test_label_value_error():
+            label = paddle.to_tensor(np.ndarray([-1, -2]))
+            paddle.nn.functional.cosine_embedding_loss(input1,
+                                                       input2,
+                                                       label,
+                                                       margin=0.5,
+                                                       reduction='mean')
+
+        self.assertRaises(ValueError, test_label_value_error)
+
+        def test_input_type_error():
+            input1 = paddle.to_tensor(self.input1_np.astype(np.int64))
+            paddle.nn.functional.cosine_embedding_loss(input1,
+                                                       input2,
+                                                       label,
+                                                       margin=0.5,
+                                                       reduction='mean')
+
+        self.assertRaises(ValueError, test_input_type_error)
+
+        def test_label_type_error():
+            label = paddle.to_tensor(self.label_np.astype(np.int16))
+            paddle.nn.functional.cosine_embedding_loss(input1,
+                                                       input2,
+                                                       label,
+                                                       margin=0.5,
+                                                       reduction='mean')
+
+        self.assertRaises(ValueError, test_label_type_error)
+
+
+class TestClassCosineEmbeddingLoss(unittest.TestCase):
+
+    def setUp(self):
+        self.input1_np = np.random.random(size=(10, 3)).astype(np.float32)
+        self.input2_np = np.random.random(size=(10, 3)).astype(np.float32)
+        a = np.array([-1, -1, -1, -1, -1]).astype(np.int64)
+        b = np.array([1, 1, 1, 1, 1]).astype(np.int64)
+        self.label_np = np.concatenate((a, b), axis=0)
+        np.random.shuffle(self.label_np)
+        self.input1_np_1D = np.random.random(size=10).astype(np.float32)
+        self.input2_np_1D = np.random.random(size=10).astype(np.float32)
+        self.label_np_1D = np.array([1]).astype(np.int64)
+
+    def run_dynamic(self):
+        input1 = paddle.to_tensor(self.input1_np)
+        input2 = paddle.to_tensor(self.input2_np)
+        label = paddle.to_tensor(self.label_np)
+        CosineEmbeddingLoss = paddle.nn.CosineEmbeddingLoss(margin=0.5,
+                                                            reduction='mean')
+        dy_result = CosineEmbeddingLoss(input1, input2, label)
+        expected1 = cosine_embedding_loss(self.input1_np,
+                                          self.input2_np,
+                                          self.label_np,
+                                          margin=0.5,
+                                          reduction='mean')
+        self.assertTrue(np.allclose(dy_result.numpy(), expected1))
+        self.assertEqual(dy_result.shape, [1])
+
+        input1_1D = paddle.to_tensor(self.input1_np_1D)
+        input2_1D = paddle.to_tensor(self.input2_np_1D)
+        label_1D = paddle.to_tensor(self.label_np_1D)
+        dy_result = CosineEmbeddingLoss(input1_1D, input2_1D, label_1D)
+        expected2 = cosine_embedding_loss(self.input1_np_1D,
+                                          self.input2_np_1D,
+                                          self.label_np_1D,
+                                          margin=0.5,
+                                          reduction='mean')
+        self.assertTrue(np.allclose(dy_result.numpy(), expected2))
+
+    def run_static(self):
+        input1 = static.data(name='input1', shape=[10, 3], dtype='float32')
+        input2 = static.data(name='input2', shape=[10, 3], dtype='float32')
+        label = static.data(name='label', shape=[10], dtype='int64')
+        CosineEmbeddingLoss = paddle.nn.CosineEmbeddingLoss(margin=0.5,
+                                                            reduction='mean')
+        result = CosineEmbeddingLoss(input1, input2, label)
+
+        place = paddle.CPUPlace()
+        exe = static.Executor(place)
+        exe.run(static.default_startup_program())
+        static_result = exe.run(feed={
+            "input1": self.input1_np,
+            "input2": self.input2_np,
+            "label": self.label_np
+        },
+                                fetch_list=[result])
+        expected = cosine_embedding_loss(self.input1_np,
+                                         self.input2_np,
+                                         self.label_np,
+                                         margin=0.5,
+                                         reduction='mean')
+
+        self.assertTrue(np.allclose(static_result[0], expected))
+
+    def test_cpu(self):
+        paddle.disable_static(place=paddle.CPUPlace())
+        self.run_dynamic()
+        paddle.enable_static()
+
+        with static.program_guard(static.Program()):
+            self.run_static()
+
+    def test_errors(self):
+
+        def test_margin_error():
+            CosineEmbeddingLoss = paddle.nn.CosineEmbeddingLoss(
+                margin=2, reduction='mean')
+
+        self.assertRaises(ValueError, test_margin_error)
+
+        def test_reduction_error():
+            CosineEmbeddingLoss = paddle.nn.CosineEmbeddingLoss(
+                margin=0.5, reduction='reduce_mean')
+
+        self.assertRaises(ValueError, test_reduction_error)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index de416ca8093d7..20b176d7c7365 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -107,6 +107,7 @@
 from .layer.loss import CTCLoss  # noqa: F401
 from .layer.loss import SmoothL1Loss  # noqa: F401
 from .layer.loss import HingeEmbeddingLoss  # noqa: F401
+from .layer.loss import CosineEmbeddingLoss  # noqa: F401
 from .layer.norm import BatchNorm  # noqa: F401
 from .layer.norm import SyncBatchNorm  # noqa: F401
 from .layer.norm import GroupNorm  # noqa: F401
@@ -311,5 +312,6 @@ def weight_norm(*args):
     'MaxUnPool3D',
     'HingeEmbeddingLoss',
     'Identity',
+    'CosineEmbeddingLoss',
     'RReLU',
 ]
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 5e4d0dd3558f5..5de8c775ad7f4 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -90,6 +90,7 @@
 from .loss import square_error_cost  # noqa: F401
 from .loss import ctc_loss  # noqa: F401
 from .loss import hinge_embedding_loss  # noqa: F401
+from .loss import cosine_embedding_loss  # noqa: F401
 from .norm import batch_norm  # noqa: F401
 from .norm import instance_norm  # noqa: F401
 from .norm import layer_norm  # noqa: F401
@@ -229,5 +230,6 @@
     'class_center_sample',
     'sparse_attention',
     'fold',
+    'cosine_embedding_loss',
     'rrelu',
 ]
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index e6a3fdb464caf..58a8bb6538351 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -2763,3 +2763,113 @@ def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None):
         return paddle.sum(loss, name=name)
     elif reduction == 'none':
         return loss
+
+
+def cosine_embedding_loss(input1,
+                          input2,
+                          label,
+                          margin=0,
+                          reduction='mean',
+                          name=None):
+    r"""
+    This operator computes the cosine embedding loss of Tensor ``input1``, ``input2`` and ``label`` as follows.
+
+    If label = 1, then the loss value can be calculated as follows:
+
+    .. math::
+        Out = 1 - cos(input1, input2)
+
+    If label = -1, then the loss value can be calculated as follows:
+
+    .. math::
+        Out = max(0, cos(input1, input2) - margin)
+
+    The operator cos can be described as follows:
+
+    .. math::
+        cos(x1, x2) = \frac{x1 \cdot{} x2}{\Vert x1 \Vert_2 * \Vert x2 \Vert_2}
+
+    Parameters:
+        input1 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, 'M' means the length of input array.
+            Available dtypes are float32, float64.
+        input2 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, 'M' means the length of input array.
+            Available dtypes are float32, float64.
+        label (Tensor): tensor with shape: [N] or [1]. The target labels values should be -1 or 1.
+            Available dtypes are int32, int64, float32, float64.
+        margin (float, optional): Should be a number from :math:`-1` to :math:`1`;
+            :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the
+            default value is :math:`0`.
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+            ``'mean'``: the sum of the output will be divided by the number of elements in the output,
+            ``'sum'``: the output will be summed.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor, the cosine embedding loss of Tensor ``input1``, ``input2`` and ``label``.
+            If `reduction` is ``'none'``, the shape of output loss is [N], the same as ``input1``.
+            If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].
+
+    Examples:
+        .. code-block:: python
+            :name: code-example1
+
+            import paddle
+
+            input1 = paddle.to_tensor([[1.6, 1.2, -0.5], [3.2, 2.6, -5.8]], 'float32')
+            input2 = paddle.to_tensor([[0.5, 0.5, -1.8], [2.3, -1.4, 1.1]], 'float32')
+            label = paddle.to_tensor([1, -1], 'int64')
+
+            output = paddle.nn.functional.cosine_embedding_loss(input1, input2, label, margin=0.5, reduction='mean')
+            print(output)  # [0.21155193]
+
+            output = paddle.nn.functional.cosine_embedding_loss(input1, input2, label, margin=0.5, reduction='sum')
+            print(output)  # [0.42310387]
+
+            output = paddle.nn.functional.cosine_embedding_loss(input1, input2, label, margin=0.5, reduction='none')
+            print(output)  # [0.42310387, 0.        ]
+
+    """
+    if len(label.shape) != 1:
+        raise ValueError(
+            "1D target tensor expected, multi-target not supported")
+
+    if input1.shape != input2.shape:
+        raise ValueError(
+            "the shape of input tensor 1 should be equal to input tensor 2, but found inputs with "
+            "different sizes")
+
+    if len(input1.shape) > 2:
+        raise ValueError(
+            "1D target tensor expects 1D or 2D input tensors, but got input tensors with more than 2 dimensions"
+        )
+
+    if input1.dtype not in [paddle.float32, paddle.float64]:
+        raise ValueError(
+            "The data type of input Variable must be 'float32' or 'float64'")
+    if label.dtype not in [
+            paddle.int32, paddle.int64, paddle.float32, paddle.float64
+    ]:
+        raise ValueError(
+            "The data type of label Variable must be 'int32', 'int64', 'float32', 'float64'"
+        )
+
+    prod_sum = (input1 * input2).sum(axis=-1)
+    mag_square1 = paddle.square(input1).sum(axis=-1) + 10e-12
+    mag_square2 = paddle.square(input2).sum(axis=-1) + 10e-12
+    denom = paddle.sqrt(mag_square1 * mag_square2)
+    cos = prod_sum / denom
+    zeros = paddle.zeros_like(cos)
+    pos = 1 - cos
+    neg = paddle.clip(cos - margin, min=0)
+    out_pos = paddle.where(label == 1, pos, zeros)
+    out_neg = paddle.where(label == -1, neg, zeros)
+    out = out_pos + out_neg
+
+    if reduction == 'none':
+        return out
+    if reduction == 'mean':
+        return paddle.mean(out, name=name)
+    elif reduction == 'sum':
+        return paddle.sum(out, name=name)
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index c720ec7d1be07..0ec60ef473805 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1309,3 +1309,95 @@ def forward(self, input, label):
                                       reduction=self.reduction,
                                       margin=self.margin,
                                       name=self.name)
+
+
+class CosineEmbeddingLoss(Layer):
+    r"""
+    This interface is used to construct a callable object of the ``CosineEmbeddingLoss`` class.
+    The CosineEmbeddingLoss layer measures the cosine embedding loss between input predictions ``input1``, ``input2``
+    and target labels ``label`` with values 1 or -1. This is used for measuring whether two inputs are similar or
+    dissimilar and is typically used for learning nonlinear embeddings or semi-supervised learning.
+    The cosine embedding loss can be described as:
+
+    If label = 1, then the loss value can be calculated as follows:
+
+    .. math::
+        Out = 1 - cos(input1, input2)
+
+    If label = -1, then the loss value can be calculated as follows:
+
+    .. math::
+        Out = max(0, cos(input1, input2) - margin)
+
+    The operator cos can be described as follows:
+
+    .. math::
+        cos(x1, x2) = \frac{x1 \cdot{} x2}{\Vert x1 \Vert_2 * \Vert x2 \Vert_2}
+
+    Parameters:
+        margin (float, optional): Should be a number from :math:`-1` to :math:`1`;
+            :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the
+            default value is :math:`0`.
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+            ``'mean'``: the sum of the output will be divided by the number of
+            elements in the output, ``'sum'``: the output will be summed.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Shape:
+        input1 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, 'M' means the length of input array.
+            Available dtypes are float32, float64.
+        input2 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, 'M' means the length of input array.
+            Available dtypes are float32, float64.
+        label (Tensor): tensor with shape: [N] or [1]. The target labels values should be -1 or 1.
+            Available dtypes are int32, int64, float32, float64.
+        output (Tensor): the cosine embedding loss of Tensor ``input1``, ``input2`` and ``label``.
+            If `reduction` is ``'none'``, the shape of output loss is [N], the same as ``input1``.
+            If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].
+
+    Examples:
+        .. code-block:: python
+            :name: code-example1
+
+            import paddle
+
+            input1 = paddle.to_tensor([[1.6, 1.2, -0.5], [3.2, 2.6, -5.8]], 'float32')
+            input2 = paddle.to_tensor([[0.5, 0.5, -1.8], [2.3, -1.4, 1.1]], 'float32')
+            label = paddle.to_tensor([1, -1], 'int64')
+
+            cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='mean')
+            output = cosine_embedding_loss(input1, input2, label)
+            print(output)  # [0.21155193]
+
+            cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='sum')
+            output = cosine_embedding_loss(input1, input2, label)
+            print(output)  # [0.42310387]
+
+            cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='none')
+            output = cosine_embedding_loss(input1, input2, label)
+            print(output)  # [0.42310387, 0.        ]
+
+    """
+
+    def __init__(self, margin=0, reduction='mean', name=None):
+        if margin > 1 or margin < -1:
+            raise ValueError(
+                "The value of 'margin' should be in the interval of [-1, 1], but received %f, which is not allowed."
+                % margin)
+        if reduction not in ['sum', 'mean', 'none']:
+            raise ValueError(
+                "The value of 'reduction' should be 'sum', 'mean' or "
+                "'none', but received %s, which is not allowed." % reduction)
+        super(CosineEmbeddingLoss, self).__init__()
+        self.margin = margin
+        self.reduction = reduction
+        self.name = name
+
+    def forward(self, input1, input2, label):
+        return F.cosine_embedding_loss(input1,
+                                       input2,
+                                       label,
+                                       margin=self.margin,
+                                       reduction=self.reduction,
+                                       name=self.name)
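
Note for reviewers: the snippet below is a quick end-to-end sanity check of the new API outside the unittest harness. It is a minimal sketch that assumes this patch is applied to a Paddle 2.x dygraph build; the seed, shapes and margin are arbitrary. It recomputes the expected value with the same numpy formulation as the test helper above, including the `10e-12` stabilizer (which equals `1e-11`, not `1e-12`; it is at least consistent between the reference and the implementation, but the intent may be worth double-checking), and verifies that gradients flow through the functional form.

```python
import numpy as np
import paddle

np.random.seed(0)
x1 = np.random.random(size=(4, 8)).astype(np.float32)
x2 = np.random.random(size=(4, 8)).astype(np.float32)
y = np.array([1, -1, 1, -1]).astype(np.int64)

# numpy reference, same formulation as the helper in the new test file;
# labels here are strictly +1/-1, so a two-branch where() is equivalent
# to the out_pos + out_neg decomposition used there.
cos = (x1 * x2).sum(axis=-1) / np.sqrt(
    (np.square(x1).sum(axis=-1) + 10e-12) *
    (np.square(x2).sum(axis=-1) + 10e-12))
expected = np.where(y == 1, 1 - cos, np.clip(cos - 0.5, 0, None)).mean()

input1 = paddle.to_tensor(x1, stop_gradient=False)
input2 = paddle.to_tensor(x2)
label = paddle.to_tensor(y)

loss = paddle.nn.functional.cosine_embedding_loss(
    input1, input2, label, margin=0.5, reduction='mean')
assert np.allclose(loss.numpy(), expected)

# The Layer form wraps the same functional call, so both should agree.
layer = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='mean')
assert np.allclose(layer(input1, input2, label).numpy(), loss.numpy())

# The loss is composed purely of existing differentiable primitives
# (no new C++ kernel), so autograd works out of the box.
loss.backward()
print(input1.grad.shape)  # [4, 8]
```

Because the implementation is pure Python over already-registered ops, no new kernel or backward registration is required, which is why the CPU/GPU coverage in the tests above is limited to exercising the composition rather than a dedicated operator.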