Fix contrastive loss with temperature test (#509)
Summary:
For some reason, when we run pytest on a subdirectory of tests/, PyTorch is unable to find GPUs even when they are present. But when we run `pytest tests/` on the full test directory, PyTorch does see the GPU.

This means that in the latter case the contrastive_loss_with_temperature test fails due to tensors being on different devices (previously only the model params were moved to the current device, but the embeddings were not). Moving the embeddings to the GPU alone would not help, because random initialization on GPU differs from random initialization on CPU, so the test would give different results across different types of hardware.
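As a rough sketch of that failure mode (illustrative only, not part of this PR; assumes a CUDA device is available), a module whose params live on the GPU cannot consume inputs left on the CPU. The nn.Linear here is a hypothetical stand-in for the loss module's learnable params:

import torch
import torch.nn as nn

# Hypothetical repro: params moved to the GPU, inputs left on the CPU
# (assumes torch.cuda.is_available() is True).
model = nn.Linear(5, 5).to("cuda")
x = torch.randn(3, 5)  # CPU tensor
try:
    model(x)
except RuntimeError as err:
    # PyTorch reports something like "Expected all tensors to be on the same device"
    print(err)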

Since the contrastive loss distributed tests are intended to run only on GPU, we leave those as-is. But for the remaining test cases we explicitly move everything to CPU to ensure consistent results.
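For context on why pinning to CPU makes the numbers reproducible: with an identical seed, PyTorch's CPU and CUDA generators are independent streams, so seeded random tensors generally differ across devices. A minimal sketch (assumes a CUDA device is available):

import torch

# Same seed, different devices: the CPU and CUDA RNGs are separate streams,
# so the "same" random tensor differs depending on where it is created.
torch.manual_seed(1234)
cpu_emb = torch.randn(3, 5)
torch.manual_seed(1234)
gpu_emb = torch.randn(3, 5, device="cuda")
print(torch.allclose(cpu_emb, gpu_emb.cpu()))  # typically False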

Pull Request resolved: #509

Reviewed By: kartikayk

Differential Revision: D50975554

Pulled By: ebsmothers

fbshipit-source-id: 71cab826f6fa6643f409643939e4a2fda24efaab
ebsmothers authored and facebook-github-bot committed Nov 9, 2023
1 parent a33a8b8 commit b933b8e
Showing 1 changed file with 22 additions and 21 deletions.
tests/modules/losses/test_contrastive_loss_with_temperature.py
@@ -25,7 +25,6 @@
 from torchmultimodal.modules.losses.contrastive_loss_with_temperature import (
     ContrastiveLossWithTemperature,
 )
-from torchmultimodal.utils.common import get_current_device
 from torchmultimodal.utils.distributed import BackpropType

@@ -71,56 +70,58 @@ def text_encoder(self, text_dim, embedding_dim):
     def image_encoder(self, image_dim, embedding_dim):
         return nn.Linear(image_dim, embedding_dim)
 
-    def test_local_loss(self):
+    @pytest.fixture()
+    def device(self):
+        return torch.device("cpu")
+
+    def test_local_loss(self, device):
         torch.manual_seed(1234)
         clip_loss = ContrastiveLossWithTemperature()
-        clip_loss = clip_loss.to(get_current_device())
-        embeddings_a = torch.randn(3, 5)
-        embeddings_b = torch.randn(3, 5)
+        clip_loss = clip_loss.to(device)
+        embeddings_a = torch.randn(3, 5, device=device)
+        embeddings_b = torch.randn(3, 5, device=device)
         loss = clip_loss(embeddings_a=embeddings_a, embeddings_b=embeddings_b)
 
         assert_expected(loss.item(), 9.8753, rtol=0, atol=1e-3)
 
-    def test_temperature_clamp_max(self):
+    def test_temperature_clamp_max(self, device):
         torch.manual_seed(1234)
         clip_loss_at_max = ContrastiveLossWithTemperature(
             logit_scale=2, logit_scale_max=2
-        ).to(get_current_device())
+        ).to(device)
         clip_loss_above_max = ContrastiveLossWithTemperature(
             logit_scale=3, logit_scale_max=2
-        ).to(get_current_device())
-        embeddings_a = torch.randn(3, 5)
-        embeddings_b = torch.randn(3, 5)
+        ).to(device)
+        embeddings_a = torch.randn(3, 5, device=device)
+        embeddings_b = torch.randn(3, 5, device=device)
         loss_at_max = clip_loss_at_max(embeddings_a, embeddings_b).item()
         loss_above_max = clip_loss_above_max(embeddings_a, embeddings_b).item()
         assert_expected(loss_above_max, loss_at_max, rtol=0, atol=1e-3)
 
-    def test_temperature_clamp_min(self):
+    def test_temperature_clamp_min(self, device):
         torch.manual_seed(1234)
         clip_loss_at_min = ContrastiveLossWithTemperature(
             logit_scale=2, logit_scale_min=2
-        ).to(get_current_device())
+        ).to(device)
         clip_loss_below_min = ContrastiveLossWithTemperature(
             logit_scale=1, logit_scale_min=2
-        ).to(get_current_device())
-        embeddings_a = torch.randn(3, 5)
-        embeddings_b = torch.randn(3, 5)
+        ).to(device)
+        embeddings_a = torch.randn(3, 5, device=device)
+        embeddings_b = torch.randn(3, 5, device=device)
         loss_at_min = clip_loss_at_min(embeddings_a, embeddings_b).item()
         loss_below_min = clip_loss_below_min(embeddings_a, embeddings_b).item()
         assert_expected(loss_below_min, loss_at_min, rtol=0, atol=1e-3)
 
-    def test_loss_with_ce_kwargs(self):
+    def test_loss_with_ce_kwargs(self, device):
         torch.manual_seed(1234)
         clip_loss = ContrastiveLossWithTemperature()
-        clip_loss = clip_loss.to(get_current_device())
-        embeddings_a = torch.randn(3, 5)
-        embeddings_b = torch.randn(3, 5)
+        clip_loss = clip_loss.to(device)
+        embeddings_a = torch.randn(3, 5, device=device)
+        embeddings_b = torch.randn(3, 5, device=device)
         loss = clip_loss(
             embeddings_a=embeddings_a,
             embeddings_b=embeddings_b,
             cross_entropy_kwargs={"label_smoothing": 0.1},
         )
 
         assert_expected(loss.item(), 10.2524, rtol=0, atol=1e-3)
 
     def test_temperature_clamp_invalid(self):
