Typecheck CO2Regularizer
philippmwirth committed Dec 30, 2024
1 parent d3529c2 commit b7c424d
Showing 4 changed files with 60 additions and 64 deletions.
19 changes: 11 additions & 8 deletions lightly/loss/regularizer/co2.py
@@ -6,11 +6,13 @@
 from typing import Sequence, Union

 import torch
+from torch import Tensor
+from torch.nn import Module

 from lightly.models.modules.memory_bank import MemoryBankModule


-class CO2Regularizer(MemoryBankModule):
+class CO2Regularizer(Module):
     """Implementation of the CO2 regularizer [0] for self-supervised learning.

     - [0] CO2, 2021, https://arxiv.org/abs/2010.02217
@@ -62,7 +64,8 @@ def __init__(
             memory_bank_size:
                 Size of the memory bank.
         """
-        super(CO2Regularizer, self).__init__(size=memory_bank_size)
+        super().__init__()
+        self.memory_bank = MemoryBankModule(size=memory_bank_size)
         # Try-catch the KLDivLoss construction for backwards compatibility
         self.log_target = True
         try:
@@ -74,7 +77,7 @@ def __init__(
         self.t_consistency = t_consistency
         self.alpha = alpha

-    def forward(self, out0: torch.Tensor, out1: torch.Tensor):
+    def forward(self, out0: Tensor, out1: Tensor) -> Tensor:
         """Computes the CO2 regularization term for two model outputs.

         Args:
@@ -93,7 +96,7 @@ def forward(self, out0: torch.Tensor, out1: torch.Tensor):

         # Update the memory bank with out1 and get negatives (if memory bank size > 0)
         # If the memory_bank size is 0, negatives will be None
-        out1, negatives = super(CO2Regularizer, self).forward(out1, update=True)
+        out1, negatives = self.memory_bank.forward(out1, update=True)

         # Get log probabilities
         p = self._get_pseudo_labels(out0, out1, negatives)
@@ -106,11 +109,11 @@ def forward(self, out0: torch.Tensor, out1: torch.Tensor):
             # Can't use log_target because of early torch version
             div = self.kl_div(p, torch.exp(q)) + self.kl_div(q, torch.exp(p))

-        return self.alpha * 0.5 * div
+        return torch.tensor(self.alpha * 0.5 * div)

     def _get_pseudo_labels(
-        self, out0: torch.Tensor, out1: torch.Tensor, negatives: torch.Tensor = None
-    ):
+        self, out0: Tensor, out1: Tensor, negatives: Union[Tensor, None] = None
+    ) -> Tensor:
         """Computes the soft pseudo labels across negative samples.

         Args:
@@ -140,7 +143,7 @@ def _get_pseudo_labels(
             # Remove elements on the diagonal
             # l_neg has shape bsz x (bsz - 1)
             l_neg = l_neg.masked_select(
-                ~torch.eye(batch_size, dtype=bool, device=l_neg.device)
+                ~torch.eye(batch_size, dtype=torch.bool, device=l_neg.device)
             ).view(batch_size, batch_size - 1)
         else:
             # Use memory bank as negative samples
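
The net effect in this file is composition over inheritance: CO2Regularizer is now a plain Module that owns a MemoryBankModule rather than subclassing it, while the public call signature stays the same. A minimal usage sketch, mirroring the calls in the new test file below (batch and feature sizes are arbitrary):

import torch

from lightly.loss.regularizer import CO2Regularizer

# Negatives come from the batch itself when no memory bank is used ...
reg = CO2Regularizer(memory_bank_size=0)
# ... or from a memory bank of 4096 entries with feature dimension 32.
reg_mb = CO2Regularizer(memory_bank_size=(4096, 32))

out0 = torch.randn(8, 32)  # projections of the first view
out1 = torch.randn(8, 32)  # projections of the second view

loss = reg(out0, out1)  # scalar consistency regularization term
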
2 changes: 0 additions & 2 deletions pyproject.toml
@@ -197,7 +197,6 @@ exclude = '''(?x)(
     lightly/loss/sym_neg_cos_sim_loss.py |
     lightly/loss/vicregl_loss.py |
     lightly/loss/dcl_loss.py |
-    lightly/loss/regularizer/co2.py |
     lightly/loss/barlow_twins_loss.py |
     lightly/data/dataset.py |
     lightly/data/collate.py |
@@ -245,7 +244,6 @@ exclude = '''(?x)(
     tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py |
     tests/loss/test_DINOLoss.py |
     tests/loss/test_VICRegLLoss.py |
-    tests/loss/test_CO2Regularizer.py |
     tests/loss/test_DCLLoss.py |
     tests/loss/test_barlow_twins_loss.py |
     tests/loss/test_SymNegCosineSimilarityLoss.py |
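
With the regularizer and its test removed from the exclude list, mypy now type-checks both, which is what motivates the signature changes above: a parameter defaulting to None needs an explicitly optional annotation, since recent mypy versions (0.990 and later) reject implicit optionals by default. A minimal sketch of the pattern, with a hypothetical function name:

from typing import Union

import torch
from torch import Tensor


# "negatives: Tensor = None" would be rejected by mypy because None is not
# a valid Tensor; spelling out the optional type, as the diff above does,
# is accepted.
def get_negatives(negatives: Union[Tensor, None] = None) -> Tensor:
    return negatives if negatives is not None else torch.empty(0)
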
54 changes: 0 additions & 54 deletions tests/loss/test_CO2Regularizer.py

This file was deleted.

49 changes: 49 additions & 0 deletions tests/loss/test_co2.py
@@ -0,0 +1,49 @@
+import pytest
+import torch
+
+from lightly.loss.regularizer import CO2Regularizer
+
+
+class TestCO2Regularizer:
+    @pytest.mark.parametrize("bsz", range(1, 20))
+    def test_forward_pass_no_memory_bank(self, bsz: int) -> None:
+        reg = CO2Regularizer(memory_bank_size=0)
+        batch_1 = torch.randn((bsz, 32))
+        batch_2 = torch.randn((bsz, 32))
+
+        # symmetry
+        l1 = reg(batch_1, batch_2)
+        l2 = reg(batch_2, batch_1)
+        assert l1 == pytest.approx(l2)
+
+    @pytest.mark.parametrize("bsz", range(1, 20))
+    def test_forward_pass_memory_bank(self, bsz: int) -> None:
+        reg = CO2Regularizer(memory_bank_size=(4096, 32))
+        batch_1 = torch.randn((bsz, 32))
+        batch_2 = torch.randn((bsz, 32))
+
+        l1 = reg(batch_1, batch_2)
+        assert l1 > 0
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
+    @pytest.mark.parametrize("bsz", range(1, 20))
+    def test_forward_pass_cuda_no_memory_bank(self, bsz: int) -> None:
+        reg = CO2Regularizer(memory_bank_size=0)
+        batch_1 = torch.randn((bsz, 32)).cuda()
+        batch_2 = torch.randn((bsz, 32)).cuda()
+
+        # symmetry
+        l1 = reg(batch_1, batch_2)
+        l2 = reg(batch_2, batch_1)
+        assert l1 == pytest.approx(l2)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
+    @pytest.mark.parametrize("bsz", range(1, 20))
+    def test_forward_pass_cuda_memory_bank(self, bsz: int) -> None:
+        reg = CO2Regularizer(memory_bank_size=(4096, 32))
+        batch_1 = torch.randn((bsz, 32)).cuda()
+        batch_2 = torch.randn((bsz, 32)).cuda()
+
+        # positive regularization term
+        l1 = reg(batch_1, batch_2)
+        assert l1 > 0

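A note on the symmetry assertions: without a memory bank, swapping the two batches simply swaps the pseudo-label distributions p and q inside the regularizer, and the returned value alpha * 0.5 * (KL(p, q) + KL(q, p)) is unchanged. A standalone sketch of that invariance, not using the lightly API and with reduction settings chosen for the sketch:

import torch
import torch.nn.functional as F

# Two log-probability distributions standing in for the pseudo labels.
p = F.log_softmax(torch.randn(4, 7), dim=1)
q = F.log_softmax(torch.randn(4, 7), dim=1)

kl_div = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)

# The symmetrized divergence is invariant under swapping p and q, which is
# why reg(batch_1, batch_2) == reg(batch_2, batch_1) in the tests above.
div_pq = kl_div(p, q) + kl_div(q, p)
div_qp = kl_div(q, p) + kl_div(p, q)
assert torch.allclose(div_pq, div_qp)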