Skip to content

Commit

Permalink
2022-04-05
Browse files Browse the repository at this point in the history
  • Loading branch information
yangguohao committed Apr 5, 2022
1 parent a0919de commit a15eae3
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def call_TripletMarginDistanceLoss_layer(input,
margin=0.3,
swap=False,
reduction='mean',):
triplet_margin_with_distance_loss = paddle.nn.TripletMarginDistanceLoss(distance_function=distance_function,
triplet_margin_with_distance_loss = paddle.nn.TripletMarginWithDistanceLoss(distance_function=distance_function,
margin=margin,swap=swap,reduction=reduction)
res = triplet_margin_with_distance_loss(input=input,positive=positive,negative=negative,)
return res
Expand Down Expand Up @@ -130,7 +130,7 @@ def calc_triplet_margin_distance_loss(input,


class TestTripletMarginLoss(unittest.TestCase):
def test_TripletMarginLoss(self):
def test_TripletMarginDistanceLoss(self):
input = np.random.uniform(0.1, 0.8, size=(20, 30)).astype(np.float64)
positive = np.random.randint(0, 2, size=(20, 30)).astype(np.float64)
negative = np.random.randint(0, 2, size=(20, 30)).astype(np.float64)
Expand Down Expand Up @@ -171,7 +171,7 @@ def test_TripletMarginDistanceLoss_error(self):
paddle.disable_static()
self.assertRaises(
ValueError,
paddle.nn.TripletMarginDistanceLoss,
paddle.nn.TripletMarginWithDistanceLoss,
reduction="unsupport reduction")
input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
Expand Down Expand Up @@ -220,7 +220,7 @@ def distance_function_2(x1,x2):
functional=True)
self.assertTrue(np.allclose(static_functional, dy_functional))

def test_TripletMarginLoss_dimension(self):
def test_TripletMarginDistanceLoss_dimension(self):
paddle.disable_static()

input = paddle.to_tensor([[0.1, 0.3], [1, 2]], dtype='float32')
Expand All @@ -232,13 +232,6 @@ def test_TripletMarginLoss_dimension(self):
input=input,
positive=positive,
negative=negative, )
TMDLoss = paddle.nn.TripletMarginDistanceLoss
self.assertRaises(
ValueError,
TMDLoss,
input=input,
positive=positive,
negative=negative, )
paddle.enable_static()

if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from paddle import _C_ops
from paddle import in_dynamic_mode
from paddle.framework import core
from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode
from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode,_non_static_mode
__all__ = []


Expand Down Expand Up @@ -2272,7 +2272,7 @@ def triplet_margin_with_distance_loss(input,positive,negative,distance_function
raise ValueError(
"margin should not smaller than 0"
)
if not in_dynamic_mode():
if not _non_static_mode():
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'triplet_margin_loss')
check_variable_and_dtype(positive, 'positive', ['float32', 'float64'],
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/nn/layer/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,7 +1300,7 @@ def forward(self, input, label):
name=self.name)


class TripletMarginDistanceLoss(Layer):
class TripletMarginWithDistanceLoss(Layer):
"""
Creates a criterion that measures the triplet loss given an input
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
Expand Down Expand Up @@ -1361,7 +1361,7 @@ class TripletMarginDistanceLoss(Layer):
"""
def __init__(self, distance_function=None,margin: float = 1.0, swap: bool = False,reduction: str = 'mean'):
super(TripletMarginDistanceLoss, self).__init__()
super(TripletMarginWithDistanceLoss, self).__init__()
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but "
Expand Down

0 comments on commit a15eae3

Please sign in to comment.