Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v2.7.0 #728

Merged
merged 17 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .github/workflows/base_test_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ jobs:
strategy:
matrix:
include:
- python-version: 3.8
pytorch-version: 1.6
torchvision-version: 0.7
- python-version: 3.9
pytorch-version: 2.3
torchvision-version: 0.18
- python-version: "3.8"
pytorch-version: "1.6"
torchvision-version: "0.7"
- python-version: "3.9"
pytorch-version: "2.5"
torchvision-version: "0.20"

steps:
- uses: actions/checkout@v2
Expand All @@ -30,7 +30,7 @@ jobs:
- name: Install dependencies
run: |
pip install .[with-hooks-cpu]
pip install "numpy<2.0" torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
pip install torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
pip install --upgrade protobuf==3.20.1
pip install six
pip install packaging
Expand Down
Binary file added docs/imgs/tcm_loss_equation.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 29 additions & 0 deletions docs/losses.md
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,35 @@ losses.SupConLoss(temperature=0.1, **kwargs)

* **loss**: The loss per element in the batch. If an element has only negative pairs or no pairs, it's ignored thanks to `AvgNonZeroReducer`. Reduction type is ```"element"```.

## ThresholdConsistentMarginLoss
[Threshold-Consistent Margin Loss for Open-World Deep Metric Learning](https://arxiv.org/pdf/2307.04047){target=_blank}

This loss acts as a form of regularization and is usually combined with another metric loss function.

```python
losses.ThresholdConsistentMarginLoss(
lambda_plus=1.0,
lambda_minus=1.0,
margin_plus=0.9,
margin_minus=0.5,
**kwargs
)
```
**Equation**:
![threshold_consistent_margin_loss](imgs/tcm_loss_equation.png)

**Parameters**:

* **lambda_plus**: The scaling coefficient for the anchor-positive part of the loss. This is $\lambda^+$ in the above equation.
* **lambda_minus**: The scaling coefficient for the anchor-negative part of the loss. This is $\lambda^-$ in the above equation.
* **margin_plus**: The minimum anchor-positive similarity to be included in the loss. This is $m^+$ in the above equation.
* **margin_minus**: The maximum anchor-negative similarity to be included in the loss. This is $m^-$ in the above equation.


**Default distance**:

- [```CosineSimilarity()```](distances.md#cosinesimilarity)
- This is the only compatible distance.

## TripletMarginLoss

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
],
python_requires=">=3.0",
install_requires=[
"numpy < 2.0",
"numpy",
"scikit-learn",
"tqdm",
"torch >= 1.6.0",
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.6.0"
__version__ = "2.7.0"
1 change: 1 addition & 0 deletions src/pytorch_metric_learning/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .sphereface_loss import SphereFaceLoss
from .subcenter_arcface_loss import SubCenterArcFaceLoss
from .supcon_loss import SupConLoss
from .tcm_loss import ThresholdConsistentMarginLoss
from .triplet_margin_loss import TripletMarginLoss
from .tuplet_margin_loss import TupletMarginLoss
from .vicreg_loss import VICRegLoss
62 changes: 62 additions & 0 deletions src/pytorch_metric_learning/losses/tcm_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch.nn.functional as F

from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils.loss_and_miner_utils import convert_to_pairs
from .base_metric_loss_function import BaseMetricLossFunction


class ThresholdConsistentMarginLoss(BaseMetricLossFunction):
"""
Implements the TCM loss from: https://arxiv.org/abs/2307.04047
"""

def __init__(
self,
lambda_plus=1.0,
lambda_minus=1.0,
margin_plus=0.9,
margin_minus=0.5,
**kwargs
):
super().__init__(**kwargs)
c_f.assert_distance_type(self, CosineSimilarity)
self.lambda_plus = lambda_plus
self.lambda_minus = lambda_minus
self.margin_plus = margin_plus
self.margin_minus = margin_minus

def get_default_distance(self):
return CosineSimilarity()

def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
ap, p, an, n = convert_to_pairs(indices_tuple, labels, ref_labels)

# calculate the similarities for positive and negative pairs
ap, p = embeddings[ap], embeddings[p]
an, n = embeddings[an], embeddings[n]

pos_sims = F.cosine_similarity(ap, p)
neg_sims = F.cosine_similarity(an, n)

# calculate the positive part
s_lte_m = pos_sims <= self.margin_plus
tcm_pos_num = ((self.margin_plus - pos_sims) * s_lte_m).sum()
tcm_pos_denom = s_lte_m.sum()
pos_tcm = 0 if s_lte_m.sum() == 0 else tcm_pos_num / tcm_pos_denom

# calculate the negative part
s_gte_m = neg_sims >= self.margin_minus
tcm_neg_num = ((neg_sims - self.margin_minus) * s_gte_m).sum()
tcm_neg_denom = s_gte_m.sum()
neg_tcm = 0 if s_gte_m.sum() == 0 else tcm_neg_num / tcm_neg_denom

# add the components for final loss
tcm_loss = self.lambda_plus * pos_tcm + self.lambda_minus * neg_tcm
return {
"loss": {
"losses": tcm_loss,
"indices": None,
"reduction_type": "already_reduced",
}
}
59 changes: 59 additions & 0 deletions tests/losses/test_tcm_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import unittest

import torch
import torch.nn.functional as F

from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.losses import (
ContrastiveLoss,
MultipleLosses,
ThresholdConsistentMarginLoss,
)

from .. import TEST_DEVICE, TEST_DTYPES


class TestThresholdConsistentMarginLoss(unittest.TestCase):
def test_tcm_loss(self):
torch.manual_seed(3459)
for dtype in TEST_DTYPES:
loss_func = MultipleLosses(
losses=[
ContrastiveLoss(
distance=CosineSimilarity(),
pos_margin=0.9,
neg_margin=0.4,
),
ThresholdConsistentMarginLoss(),
]
)
embs = torch.tensor(
[
[0.00, 1.00],
[0.43, 0.90],
[1.00, 0.00],
[0.50, 0.50],
],
device=TEST_DEVICE,
dtype=dtype,
requires_grad=True,
)
labels = torch.tensor([0, 0, 1, 1])

# Contrastive loss = 0.4866
#
# TCM loss part:
# Only pair (2, 3) is taken into account for positive part
# Positive part = 1 * ( 0.9 - 0.7071 ) / ( 1 ) = 0.1929
#
# Only pairs (1, 2) and (1, 3) are taken into account for negative part
# Negative part = 1 * ( 0.7071 - 0.5 + 0.9429 - 0.5 ) / ( 2 ) = 0.325
#
# Sum of these losses -> 0.4866 + 0.518 = 1.0046
correct_loss = torch.tensor(1.0045).to(dtype)

with torch.no_grad():
res = loss_func.forward(embs, labels)
rtol = 1e-2 if dtype == torch.float16 else 1e-5
atol = 1e-4
self.assertTrue(torch.isclose(res, correct_loss, rtol=rtol, atol=atol))
Loading