Feature: Add FastAP Loss #199

Merged: 5 commits, Mar 28, 2023
1 change: 1 addition & 0 deletions docs/source/api/index.rst
@@ -128,6 +128,7 @@ Implementations
~softmax_loss.SoftmaxLoss
~triplet_loss.TripletLoss
~circle_loss.CircleLoss
~fastap_loss.FastAPLoss

Extras
++++++
1 change: 1 addition & 0 deletions quaterion/loss/__init__.py
@@ -1,6 +1,7 @@
from quaterion.loss.arcface_loss import ArcFaceLoss
from quaterion.loss.circle_loss import CircleLoss
from quaterion.loss.contrastive_loss import ContrastiveLoss
from quaterion.loss.fast_ap_loss import FastAPLoss
from quaterion.loss.group_loss import GroupLoss
from quaterion.loss.multiple_negatives_ranking_loss import MultipleNegativesRankingLoss
from quaterion.loss.online_contrastive_loss import OnlineContrastiveLoss
2 changes: 1 addition & 1 deletion quaterion/loss/circle_loss.py
@@ -24,7 +24,7 @@ def __init__(
scale_factor: Optional[float] = 256,
distance_metric_name: Optional[Distance] = Distance.COSINE,
):
super(GroupLoss, self).__init__()
super(GroupLoss, self).__init__(distance_metric_name=distance_metric_name)
self.margin = margin
self.scale_factor = scale_factor
self.op = 1 + self.margin
112 changes: 112 additions & 0 deletions quaterion/loss/fast_ap_loss.py
@@ -0,0 +1,112 @@
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from quaterion.distances import Distance
from quaterion.loss.group_loss import GroupLoss
from quaterion.utils import get_anchor_negative_mask, get_anchor_positive_mask


class FastAPLoss(GroupLoss):
"""FastAP Loss

Adaptation from https://github.com/kunhe/FastAP-metric-learning.

Further information:
https://cs-people.bu.edu/fcakir/papers/fastap_cvpr2019.pdf.
"Deep Metric Learning to Rank"
Fatih Cakir(*), Kun He(*), Xide Xia, Brian Kulis, and Stan Sclaroff
IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019

Args:
num_bins: The number of soft histogram bins for calculating average precision. The paper suggests using 10.
"""

def __init__(self, num_bins: Optional[int] = 10):
# Euclidean distance is the only compatible distance metric for FastAP Loss
super(GroupLoss, self).__init__(distance_metric_name=Distance.EUCLIDEAN)
self.num_bins = num_bins

def get_config_dict(self) -> Dict[str, Any]:
"""Config used in saving and loading purposes.

Config object has to be JSON-serializable.

Returns:
Dict[str, Any]: JSON-serializable dict of params
"""
config = super().get_config_dict()
config.update({"num_bins": self.num_bins})

return config

def forward(
self,
embeddings: Tensor,
groups: Tensor,
) -> Tensor:
"""Compute loss value.

Args:
embeddings: shape: (batch_size, vector_length) - Batch of embeddings.
groups: shape: (batch_size,) - Batch of labels associated with `embeddings`.
Returns:
Tensor: Scalar loss value.
"""

_warn = "Batch sizes of embeddings and groups don't match."

batch_size = groups.size()[0] # batch size
assert embeddings.size()[0] == batch_size, _warn

device = embeddings.device # get the device of the embeddings tensor

# 1. get positive and negative masks
pos_mask = get_anchor_positive_mask(groups).to(
device
) # (batch_size, batch_size)
neg_mask = get_anchor_negative_mask(groups).to(
device
) # (batch_size, batch_size)
n_pos = torch.sum(pos_mask, dim=1) # Sum over all columns (for each row)

# 2. compute squared Euclidean distance matrix from the normalized embeddings
embeddings = F.normalize(embeddings, p=2, dim=1).to(
device
) # normalize embeddings
dist_matrix = (
self.distance_metric.distance_matrix(embeddings).to(device) ** 2
) # (batch_size, batch_size)

# 3. estimate discrete histograms
histogram_delta = torch.tensor(4.0 / self.num_bins, device=device)
mid_points = torch.linspace(
0.0, 4.0, steps=self.num_bins + 1, device=device
).view(-1, 1, 1)

pulse = F.relu(
input=1 - torch.abs(dist_matrix - mid_points) / histogram_delta
).to(
device
) # max(0, input)

pos_hist = torch.t(torch.sum(pulse * pos_mask, dim=2)).to(
device
) # positive histograms
neg_hist = torch.t(torch.sum(pulse * neg_mask, dim=2)).to(
device
) # negative histograms

total_pos_hist = torch.cumsum(pos_hist, dim=1).to(device)
total_hist = torch.cumsum(pos_hist + neg_hist, dim=1).to(device)

# 4. compute FastAP
FastAP = pos_hist * total_pos_hist / total_hist
FastAP[torch.isnan(FastAP) | torch.isinf(FastAP)] = 0
FastAP = torch.sum(FastAP, 1) / n_pos
FastAP = FastAP[~torch.isnan(FastAP)]
loss = 1 - torch.mean(FastAP)

return loss
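A minimal usage sketch (not part of the diff; the toy batch is hypothetical): FastAPLoss is a GroupLoss, so it takes a batch of embeddings and their group labels and returns a scalar, roughly 1 minus the soft-histogram estimate of mean average precision. Assuming quaterion is installed with this change applied:

import torch

from quaterion.loss import FastAPLoss

# 10 bins is the value suggested in the paper
loss_fn = FastAPLoss(num_bins=10)

# toy batch: 6 embeddings of dimension 4, three groups of two samples each
embeddings = torch.randn(6, 4, requires_grad=True)
groups = torch.tensor([0, 0, 1, 1, 2, 2])

loss = loss_fn(embeddings=embeddings, groups=groups)  # scalar tensor
loss.backward()  # differentiable thanks to the soft histogram binning
print(loss.item())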
16 changes: 12 additions & 4 deletions quaterion/utils/utils.py
@@ -1,4 +1,4 @@
from typing import Iterable, Sized, Union
from typing import Iterable, Optional, Sized, Union

import torch
import tqdm
@@ -109,17 +109,21 @@ def get_triplet_mask(labels: torch.Tensor) -> torch.Tensor:


def get_anchor_positive_mask(
labels_a: torch.Tensor, labels_b: torch.Tensor
labels_a: torch.Tensor, labels_b: Optional[torch.Tensor] = None
) -> torch.BoolTensor:
"""Creates a 2D mask of valid anchor-positive pairs.

Args:
labels_a (torch.Tensor): Labels associated with embeddings in the batch A. Shape: (batch_size_a,)
labels_b (torch.Tensor): Labels associated with embeddings in the batch B. Shape: (batch_size_b,)
If `labels_b` is None, `labels_a` is used in its place.

Returns:
torch.Tensor: Anchor-positive mask. Shape: (batch_size_a, batch_size_b)
"""
if labels_b is None:
labels_b = labels_a

# Shape: (batch_size_a, batch_size_b)
mask = labels_a.expand(labels_b.shape[0], labels_a.shape[0]).t() == labels_b.expand(
labels_a.shape[0], labels_b.shape[0]
@@ -139,17 +143,21 @@ def get_anchor_positive_mask(


def get_anchor_negative_mask(
labels_a: torch.Tensor, labels_b: torch.Tensor
labels_a: torch.Tensor, labels_b: Optional[torch.Tensor] = None
) -> torch.BoolTensor:
"""Creates a 2D mask of valid anchor-negative pairs.

Args:
labels_a (torch.Tensor): Labels associated with embeddings in the batch A. Shape: (batch_size_a,)
labels_b (torch.Tensor): Labels associated with embeddings in the batch B. Shape: (batch_size_b,)
labels_b (torch.Tensor): Labels associated with embeddings in the batch B. Shape: (batch_size_b,).
If `labels_b` is None, `labels_a` is used in its place.

Returns:
torch.Tensor: Anchor-negative mask. Shape: (batch_size_a, batch_size_b)
"""
if labels_b is None:
labels_b = labels_a

# Shape: (batch_size_a, batch_size_b)
mask = labels_a.expand(labels_b.shape[0], labels_a.shape[0]).t() != labels_b.expand(
labels_a.shape[0], labels_b.shape[0]
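A small sketch of the new default in the mask helpers (toy labels only; diagonal handling is whatever the existing implementation does): when labels_b is omitted, each helper compares the batch of labels against itself, which is how FastAPLoss above calls them.

import torch

from quaterion.utils import get_anchor_negative_mask, get_anchor_positive_mask

labels = torch.tensor([1, 2, 1])

# labels_b defaults to labels_a, so both masks have shape (3, 3)
pos_mask = get_anchor_positive_mask(labels)  # True where labels match
neg_mask = get_anchor_negative_mask(labels)  # True where labels differ

print(pos_mask)
print(neg_mask)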
127 changes: 127 additions & 0 deletions tests/eval/losses/test_fast_ap_loss.py
@@ -0,0 +1,127 @@
from typing import Any, Optional

import torch
import torch.nn.functional as F

from quaterion.loss import FastAPLoss

####################################
# Official Implementation #
####################################
# From https://github.com/kunhe/FastAP-metric-learning/blob/master/pytorch/FastAP_loss.py
# This code is copied from the official implementation to compare our results. It's copied under the MIT license.


def soft_binning(
D: torch.Tensor, mid: torch.Tensor, Delta: torch.Tensor
) -> torch.Tensor:
y = 1 - torch.abs(D - mid) / Delta
return torch.max(torch.tensor([0], dtype=D.dtype).to(D.device), y)


class OfficialFastAP(torch.autograd.Function):
"""
FastAP - autograd function definition

This class implements the FastAP loss from the following paper:
"Deep Metric Learning to Rank",
F. Cakir, K. He, X. Xia, B. Kulis, S. Sclaroff. CVPR 2019
"""

@staticmethod
def forward(
ctx: Any, input: torch.Tensor, target: torch.Tensor, num_bins: int
) -> torch.Tensor:
"""
Args:
input: torch.Tensor(N x embed_dim), embedding matrix
target: torch.Tensor(N x 1), class labels
num_bins: int, number of bins in distance histogram
"""
N = target.size()[0]
assert input.size()[0] == N, "Batch size doesn't match!"

# 1. get affinity matrix
Y = target.unsqueeze(1)
Aff = 2 * (Y == Y.t()).type(input.dtype) - 1
Aff.masked_fill_(
torch.eye(N, N).bool().to(input.device), 0
) # set diagonal to 0

I_pos = (Aff > 0).type(input.dtype).to(input.device)
I_neg = (Aff < 0).type(input.dtype).to(input.device)
N_pos = torch.sum(I_pos, 1)

# 2. compute distances from embeddings
# squared Euclidean distance with range [0,4]
dist2 = 2 - 2 * torch.mm(input, input.t())
# 3. estimate discrete histograms
Delta = torch.tensor(4.0 / num_bins).to(input.device)
Z = torch.linspace(0.0, 4.0, steps=num_bins + 1).to(input.device)
L = Z.size()[0]
h_pos = torch.zeros((N, L), dtype=input.dtype).to(input.device)
h_neg = torch.zeros((N, L), dtype=input.dtype).to(input.device)
for idx in range(L):
pulse = soft_binning(dist2, Z[idx], Delta)
h_pos[:, idx] = torch.sum(pulse * I_pos, 1)
h_neg[:, idx] = torch.sum(pulse * I_neg, 1)

H_pos = torch.cumsum(h_pos, 1)
h = h_pos + h_neg
H = torch.cumsum(h, 1)

# 4. compute FastAP
FastAP = h_pos * H_pos / H
FastAP[torch.isnan(FastAP) | torch.isinf(FastAP)] = 0
FastAP = torch.sum(FastAP, 1) / N_pos
FastAP = FastAP[~torch.isnan(FastAP)]
loss = 1 - torch.mean(FastAP)

return loss


class OfficialFastAPLoss(torch.nn.Module):
"""
FastAP - loss layer definition

This class implements the FastAP loss from the following paper:
"Deep Metric Learning to Rank",
F. Cakir, K. He, X. Xia, B. Kulis, S. Sclaroff. CVPR 2019
"""

def __init__(self, num_bins: Optional[int] = 10):
super(OfficialFastAPLoss, self).__init__()
self.num_bins = num_bins

def forward(self, batch: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return OfficialFastAP.apply(batch, labels, self.num_bins)


class TestFastAPLoss:
embeddings = torch.Tensor(
[
[0.0, -1.0, 0.5],
[0.1, 2.0, 0.5],
[0.0, 2.3, 0.2],
[1.0, 0.0, 0.9],
[1.2, -1.2, 0.01],
[-0.7, 0.0, 1.5],
]
)

groups = torch.Tensor([1, 2, 3, 3, 2, 1])

def test_batch_all(self):
num_bins = 5
loss = FastAPLoss(num_bins)

actual_loss = loss.forward(embeddings=self.embeddings, groups=self.groups)

assert actual_loss.shape == torch.Size([])

expected_loss = OfficialFastAPLoss(num_bins)(
F.normalize(self.embeddings), labels=self.groups
)

rtol = 1e-2 if actual_loss.dtype == torch.float16 else 1e-5
assert torch.isclose(expected_loss, actual_loss, rtol=rtol)