Add soft margin variant to Hard Triplet Loss (#178)
* Make args and docstrings consistent

* Add soft margin variant

* Actually assert semi-hard mining test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
monatis and pre-commit-ci[bot] authored Sep 1, 2022
1 parent 83f9123 commit 760020a
Showing 2 changed files with 22 additions and 19 deletions.
39 changes: 21 additions & 18 deletions quaterion/loss/triplet_loss.py
@@ -19,23 +19,24 @@
 class TripletLoss(GroupLoss):
     """Implements Triplet Loss as defined in https://arxiv.org/abs/1503.03832
 
-    It supports batch-all and batch-hard strategies for online triplet mining.
+    It supports batch-all, batch-hard and batch-semihard strategies for online triplet mining.
 
     Args:
         margin: Margin value to push negative examples
-            apart. Optional, defaults to `0.5`.
+            apart.
         distance_metric_name: Name of the distance function, e.g.,
-            :class:`~quaterion.distances.Distance`. Optional, defaults to
-            :attr:`~quaterion.distances.Distance.COSINE`.
-        mining (str, optional): Triplet mining strategy. One of
-            `"all"`, `"hard"`, `"semi_hard"`. Defaults to `"hard"`.
+            :class:`~quaterion.distances.Distance`.
+        mining: Triplet mining strategy. One of
+            `"all"`, `"hard"`, `"semi_hard"`.
+        soft: If `True`, use soft margin variant of Hard Triplet Loss. Ignored in all other cases.
     """
 
     def __init__(
         self,
-        margin: Optional[float] = 1.0,
-        distance_metric_name: Distance = Distance.COSINE,
+        margin: Optional[float] = 0.5,
+        distance_metric_name: Optional[Distance] = Distance.COSINE,
         mining: Optional[str] = "hard",
+        soft: Optional[bool] = False,
     ):
         mining_types = ["all", "hard", "semi_hard"]
         if mining not in mining_types:
@@ -46,14 +47,12 @@ def __init__(
 
         self._margin = margin
         self._mining = mining
+        self._soft = soft
 
     def get_config_dict(self):
         config = super().get_config_dict()
         config.update(
-            {
-                "margin": self._margin,
-                "mining": self._mining,
-            }
+            {"margin": self._margin, "mining": self._mining, "soft": self._soft}
        )
 
         return config
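
For reference, a minimal usage sketch of the new `soft` flag. The `from quaterion.loss import TripletLoss` import path and the `loss_fn(embeddings, groups)` call assume Quaterion's public `GroupLoss` interface; the tensors and group labels are made up for illustration.

import torch

from quaterion.distances import Distance
from quaterion.loss import TripletLoss

# Soft-margin hard triplet mining: softplus replaces the fixed margin,
# so the `margin` value is not used when soft=True.
loss_fn = TripletLoss(
    distance_metric_name=Distance.COSINE,
    mining="hard",
    soft=True,
)

embeddings = torch.rand(8, 32)  # batch of 8 embeddings with 32 dimensions
groups = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])  # group label per embedding
loss = loss_fn(embeddings, groups)  # scalar tensor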
@@ -95,12 +94,16 @@ def _hard_triplet_loss(
             hardest_negative_dists = anchor_negative_dists.min(dim=1)[0]
 
             # combine hardest positives and hardest negatives
-            triplet_loss = F.relu(
-                # Division by the minimal distance between negative samples scales target distances
-                # # and prevents vector collapse
-                (hardest_positive_dists - hardest_negative_dists)
-                / hardest_negative_dists.mean()
-                + self._margin
+            triplet_loss = (  # SoftPlus is a smooth approximation to the ReLU function and is always positive
+                F.softplus(hardest_positive_dists - hardest_negative_dists)
+                if self._soft
+                else F.relu(
+                    # Division by the minimal distance between negative samples scales target distances
+                    # # and prevents vector collapse
+                    (hardest_positive_dists - hardest_negative_dists)
+                    / hardest_negative_dists.mean()
+                    + self._margin
+                )
             )
 
             # get scalar loss value
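
For reference, a standalone sketch of the two formulations this hunk switches between; the function and variable names are illustrative, not part of the library.

import torch
import torch.nn.functional as F

def hard_triplet_loss_from_dists(
    hardest_positive_dists: torch.Tensor,
    hardest_negative_dists: torch.Tensor,
    margin: float = 0.5,
    soft: bool = False,
) -> torch.Tensor:
    if soft:
        # Soft margin: softplus(d_ap - d_an) = log(1 + exp(d_ap - d_an)),
        # a smooth, always-positive replacement for relu(d_ap - d_an + margin).
        per_anchor = F.softplus(hardest_positive_dists - hardest_negative_dists)
    else:
        # Fixed margin, with distances scaled by the mean hardest-negative
        # distance to discourage embedding collapse.
        per_anchor = F.relu(
            (hardest_positive_dists - hardest_negative_dists)
            / hardest_negative_dists.mean()
            + margin
        )
    return per_anchor.mean()

Softplus approaches relu for large differences but stays differentiable and strictly positive, which removes the need to tune a margin.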
2 changes: 1 addition & 1 deletion tests/eval/losses/test_triplet_loss.py
@@ -62,4 +62,4 @@ def test_semi_hard(self):
             groups_b=groups_b,
         )
 
-        print(loss_res)
+        assert loss_res.shape == torch.Size([])
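
The replaced `print` did not verify anything; the new line asserts that the semi-hard loss reduces to a 0-dimensional (scalar) tensor. A small self-contained illustration of that check:

import torch

loss_res = torch.tensor([0.3, 0.7]).mean()  # mean-reduction yields a 0-d tensor
assert loss_res.shape == torch.Size([])  # a scalar tensor has an empty shape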
