Skip to content

Commit

Permalink
'2022_03_27'
Browse files Browse the repository at this point in the history
  • Loading branch information
yangguohao committed Mar 30, 2022
1 parent 352ec13 commit e2117ab
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python/paddle/fluid/tests/unittests/test_triplet_margin_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,26 @@ def test_TripletMarginLoss_error(self):
reduction="unsupport reduction")
paddle.enable_static()

def test_TripletMarginLoss_dimension(self):
paddle.disable_static()

input = paddle.to_tensor([[0.1, 0.3],[1, 2]], dtype='float32')
positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
negative = paddle.to_tensor([[0.2, 0.1]], dtype='float32')
self.assertRaises(
ValueError,
paddle.nn.functional.triplet_margin_loss,
input=input,
positive=positive,
negative=negative,)
TMLoss = paddle.nn.TripletMarginLoss()
self.assertRaises(
ValueError,
TMLoss,
input=input,
positive=positive,
negative=negative,)
paddle.enable_static()

if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2250,6 +2250,16 @@ def triplet_margin_loss(input,positive,negative,
check_variable_and_dtype(negative, 'negative', ['float32', 'float64'],
'triplet_margin_loss')

# reshape to [batch_size, N]
input = input.flatten(start_axis=1,stop_axis=-1)
positive = positive.flatten(start_axis=1,stop_axis=-1)
negative = negative.flatten(start_axis=1,stop_axis=-1)
if not(input.shape==positive.shape==negative.shape):
raise ValueError(
"input's shape must equal to "
"positive's shape and "
"negative's shape")

distance_function = paddle.nn.PairwiseDistance(p, epsilon=eps)
positive_dist = distance_function(input, positive)
negative_dist = distance_function(input, negative)
Expand Down

0 comments on commit e2117ab

Please sign in to comment.