Skip to content

Commit

Permalink
2022-06-01_pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
yangguohao committed Jun 1, 2022
1 parent 2a4ed81 commit a6c17b0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 27 deletions.
34 changes: 16 additions & 18 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2232,13 +2232,13 @@ def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None):


def triplet_margin_with_distance_loss(input,
positive,
negative,
distance_function = None,
margin=1.0,
swap=False,
reduction='mean',
name=None):
positive,
negative,
distance_function=None,
margin=1.0,
swap=False,
reduction='mean',
name=None):
r"""
Measures the triplet loss given an input
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
Expand Down Expand Up @@ -2275,7 +2275,7 @@ def triplet_margin_with_distance_loss(input,
distance_function (callable, optional): Quantifies the distance between two tensors. if not specified, 2 norm functions will be used.
margin (float, optional):Default: :math:`1`.A nonnegative margin representing the minimum difference
margin (float, optional):Default: :math:`1`.A nonnegative margin representing the minimum difference
between the positive and negative distances required for the loss to be 0.
swap (bool, optional):The distance swap changes the negative distance to the swap distance (distance between positive samples
Expand Down Expand Up @@ -2313,10 +2313,9 @@ def triplet_margin_with_distance_loss(input,
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'triplet_margin_with_distance_loss' "
"should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))
raise ValueError("'reduction' in 'triplet_margin_with_distance_loss' "
"should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))
if margin < 0:
raise ValueError(
"The margin between positive samples and negative samples should be greater than 0."
Expand All @@ -2329,11 +2328,10 @@ def triplet_margin_with_distance_loss(input,
check_variable_and_dtype(negative, 'negative', ['float32', 'float64'],
'triplet_margin_with_distance_loss')

if not(input.shape==positive.shape==negative.shape):
raise ValueError(
"input's shape must equal to "
"positive's shape and "
"negative's shape")
if not (input.shape == positive.shape == negative.shape):
raise ValueError("input's shape must equal to "
"positive's shape and "
"negative's shape")

distance_function = distance_function if distance_function is not None \
else paddle.nn.PairwiseDistance(2)
Expand All @@ -2350,7 +2348,7 @@ def triplet_margin_with_distance_loss(input,
"The positive distance or negative distance should be greater than 0, "
"The distance functions should be checked.")

loss = paddle.clip(positive_dist-negative_dist+margin, min=0.0)
loss = paddle.clip(positive_dist - negative_dist + margin, min=0.0)

if reduction == 'mean':
return paddle.mean(loss, name=name)
Expand Down
25 changes: 16 additions & 9 deletions python/paddle/nn/layer/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,7 +1361,7 @@ class TripletMarginWithDistanceLoss(Layer):
negative (Tensor):Negative tensor, the data type is float32 or float64.
The shape of label is the same as the shape of input.
output(Tensor): The tensor variable storing the triplet_margin_with_distance_loss of input and positive and negative.
output(Tensor): The tensor variable storing the triplet_margin_with_distance_loss of input and positive and negative.
Return:
A callable object of TripletMarginWithDistanceLoss
Expand All @@ -1386,7 +1386,13 @@ class TripletMarginWithDistanceLoss(Layer):
# Tensor([0.19165580])
"""
def __init__(self, distance_function=None, margin=1.0, swap=False, reduction: str = 'mean', name=None):

def __init__(self,
distance_function=None,
margin=1.0,
swap=False,
reduction: str='mean',
name=None):
super(TripletMarginWithDistanceLoss, self).__init__()
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
Expand All @@ -1400,10 +1406,11 @@ def __init__(self, distance_function=None, margin=1.0, swap=False, reduction: st
self.name = name

def forward(self, input, positive, negative):
return F.triplet_margin_with_distance_loss(input,
positive,
negative,
margin=self.margin,
swap=self.swap,
reduction=self.reduction,
name=self.name)
return F.triplet_margin_with_distance_loss(
input,
positive,
negative,
margin=self.margin,
swap=self.swap,
reduction=self.reduction,
name=self.name)

0 comments on commit a6c17b0

Please sign in to comment.