Skip to content

Commit

Permalink
add comments and avoid 0 division
Browse files Browse the repository at this point in the history
  • Loading branch information
liopeer committed Jan 3, 2025
1 parent 18abcba commit a954bb5
Showing 1 changed file with 64 additions and 9 deletions.
73 changes: 64 additions & 9 deletions lightly/loss/detcon_loss.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from math import log
import torch
import torch.nn.functional as F
from torch import Tensor
Expand Down Expand Up @@ -100,7 +101,9 @@ class DetConBLoss(Module):
:math:`\\frac{1}{\\sqrt{\\tau}}`, the formula for the contrastive loss is
.. math::
\\mathcal{L} = \\sum_{m}\\sum_{m'} \\mathbb{1}_{m, m'} \\left[ - \\log \\frac{\\exp(v_m \\cdot v_{m'}')}{\\exp(v_m \\cdot v_{m'}') + \\sum_{n}\\exp (v_m \\cdot v_{m'}')} \\right]
\\mathcal{L} = \\sum_{m}\\sum_{m'} \\mathbb{1}_{m, m'} \\left[ - \\log \
\\frac{\\exp(v_m \\cdot v_{m'}')}{\\exp(v_m \\cdot v_{m'}') + \\sum_{n}\\exp \
(v_m \\cdot v_{m'}')} \\right]
where :math:`\\mathbb{1}_{m, m'}` is 1 if the masks are the same and 0 otherwise.
Since :math:`v_m` and :math:`v_{m'}'` stem from different branches, the loss is
Expand Down Expand Up @@ -201,10 +204,16 @@ def forward(
labels_local = F.one_hot(labels_idx, num_classes=enlarged_b)
labels_ext = F.one_hot(labels_idx, num_classes=2 * enlarged_b)

### Expand Labels ###
# labels_local at this point only points towards the diagonal of the batch, i.e.
# indicates to compare between the same samples across views.
labels_local = labels_local[:, None, :, None]
labels_ext = labels_ext[:, None, :, None]
assert labels_local.size() == (b, 1, b, 1)

# calculate similarity matrices
### Calculate Similarity Matrices ###
# tensors of shape (b, m, b, m), indicating similarities between regions across
# views and samples in the batch
logits_aa = (
torch.einsum("abk,uvk->abuv", pred_view0, target_view0_large)
/ self.temperature
Expand All @@ -221,57 +230,103 @@ def forward(
torch.einsum("abk,uvk->abuv", pred_view1, target_view0_large)
/ self.temperature
)
assert logits_aa.size() == (b, m, b, m)
assert logits_bb.size() == (b, m, b, m)
assert logits_ab.size() == (b, m, b, m)
assert logits_ba.size() == (b, m, b, m)

# determine where the masks are the same
### Find Corresponding Regions Across Views ###
same_mask_aa = _same_mask(mask_view0, mask_view0)
same_mask_bb = _same_mask(mask_view1, mask_view1)
same_mask_ab = _same_mask(mask_view0, mask_view1)
same_mask_ba = _same_mask(mask_view1, mask_view0)

# remove similarities between the same masks
labels_aa = labels_local * same_mask_aa
assert same_mask_aa.size() == (b, m, 1, m)
assert same_mask_bb.size() == (b, m, 1, m)
assert same_mask_ab.size() == (b, m, 1, m)
assert same_mask_ba.size() == (b, m, 1, m)

### Remove Similarities Between Corresponding Views But Different Regions ###
# labels_local initially compared all features across views, but we only want to
# compare the same regions across views.
labels_aa = labels_local * same_mask_aa # (b, 1, b, 1) * (b, m, 1, m)
labels_bb = labels_local * same_mask_bb
labels_ab = labels_local * same_mask_ab
labels_ba = labels_local * same_mask_ba
assert labels_aa.size() == (b, m, b, m)
assert labels_bb.size() == (b, m, b, m)
assert labels_ab.size() == (b, m, b, m)
assert labels_ba.size() == (b, m, b, m)

### Remove Logits And Lables Between The Same View ###
logits_aa = logits_aa - infinity_proxy * labels_aa
logits_bb = logits_bb - infinity_proxy * labels_bb
labels_aa = 0.0 * labels_aa
labels_bb = 0.0 * labels_bb

### Arrange Labels ###
labels_abaa = torch.cat([labels_ab, labels_aa], dim=2)
labels_babb = torch.cat([labels_ba, labels_bb], dim=2)

labels_0 = labels_abaa.view(b, m, -1)
labels_1 = labels_babb.view(b, m, -1)

### Count Number of Positives For Every Region (per sample) ###
num_positives_0 = torch.sum(labels_0, dim=-1, keepdim=True)
num_positives_1 = torch.sum(labels_1, dim=-1, keepdim=True)

### Scale The Labels By The Number of Positives To Weight Loss Value ###
labels_0 = labels_0 / torch.maximum(num_positives_0, torch.tensor(1))
labels_1 = labels_1 / torch.maximum(num_positives_1, torch.tensor(1))

### Count How Many Overlapping Regions We Have Across Views ###
obj_area_0 = torch.sum(_same_mask(mask_view0, mask_view0), dim=(2, 3))
obj_area_1 = torch.sum(_same_mask(mask_view1, mask_view1), dim=(2, 3))

# make sure we don't divide by zero
obj_area_0 = torch.maximum(obj_area_0, self.eps)
obj_area_1 = torch.maximum(obj_area_1, self.eps)
assert obj_area_0.size() == (b, m)
assert obj_area_1.size() == (b, m)

### Calculate Weights For The Loss ###
# last dim of num_positives is anyway 1, from the torch.sum above
weights_0 = torch.gt(num_positives_0[..., 0], 1e-3).float()
weights_0 = weights_0 / obj_area_0
weights_1 = torch.gt(num_positives_1[..., 0], 1e-3).float()
weights_1 = weights_1 / obj_area_1

### Arrange Logits ###
logits_abaa = torch.cat([logits_ab, logits_aa], dim=2)
logits_babb = torch.cat([logits_ba, logits_bb], dim=2)

logits_abaa = logits_abaa.view(b, m, -1)
logits_babb = logits_babb.view(b, m, -1)

### Derive Cross Entropy Loss ###
# targets/labels are are a weighted float tensor of same shape as logits,
# which is why we can't use F.cross_entropy (expects integer targets)
loss_a = _torch_manual_cross_entropy(labels_0, logits_abaa, weights_0)
loss_b = _torch_manual_cross_entropy(labels_1, logits_babb, weights_1)
loss = loss_a + loss_b
return loss


def _same_mask(mask0: Tensor, mask1: Tensor) -> Tensor:
"""Find equal masks/regions across views of the same image.
Args:
mask0: Indices corresponding to the sampled masks for the first view,
an integer tensor of shape :math:`(B, M)` with (possibly repeated)
indices in the range :math:`[0, N)`.
mask1: Indices corresponding to the sampled masks for the second view,
an integer tensor of shape (B, M) with (possibly repeated) indices in
the range :math:`[0, N)`.
Returns:
Tensor: A float tensor of shape :math:`(B, M, 1, M)` where the first :math:`M`
dimensions is for the regions/masks of the first view and the last :math:`M`
dimensions is for the regions/masks of the second view. For every sample
:math:`k` in the batch (separately), the tensor is effectively a 2D index
matrix where the entry :math:`(k, i, :, j)` is 1 if the masks :math:`m_i`
and :math:`m_j'` are the same and 0 otherwise.
"""
return (mask0[:, :, None] == mask1[:, None, :]).float()[:, :, None, :]


Expand Down

0 comments on commit a954bb5

Please sign in to comment.