From b933b8e7351073dcae5b218319e4ea6e87df7b90 Mon Sep 17 00:00:00 2001 From: ebsmothers Date: Thu, 9 Nov 2023 09:35:53 -0800 Subject: [PATCH] Fix contrastive loss with temperature test (#509) Summary: For some reason when we run pytest on a subdirectory of tests/ PyTorch is unable to find GPUs if they are present. But when we run just `pytest tests/` PyTorch does see the GPU. This means that in the latter case the contrastive_loss_with_temperature test fails due to tensors on different devices (previously only model params were moved to the current device but the embeddings were not). But random initialization on GPU is different than on CPU so the test will give different results across different types of hardware. Since the contrastive loss distributed tests are intended to only run on GPU, we leave those as is. But for the remaining test cases we explicitly move everything to CPU so we can ensure consistent results. Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/509 Reviewed By: kartikayk Differential Revision: D50975554 Pulled By: ebsmothers fbshipit-source-id: 71cab826f6fa6643f409643939e4a2fda24efaab --- .../test_contrastive_loss_with_temperature.py | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/tests/modules/losses/test_contrastive_loss_with_temperature.py b/tests/modules/losses/test_contrastive_loss_with_temperature.py index bea1e7a7d..55b0c4c2a 100644 --- a/tests/modules/losses/test_contrastive_loss_with_temperature.py +++ b/tests/modules/losses/test_contrastive_loss_with_temperature.py @@ -25,7 +25,6 @@ from torchmultimodal.modules.losses.contrastive_loss_with_temperature import ( ContrastiveLossWithTemperature, ) -from torchmultimodal.utils.common import get_current_device from torchmultimodal.utils.distributed import BackpropType @@ -71,56 +70,58 @@ def text_encoder(self, text_dim, embedding_dim): def image_encoder(self, image_dim, embedding_dim): return nn.Linear(image_dim, embedding_dim) - def test_local_loss(self): + @pytest.fixture() + def device(self): + return torch.device("cpu") + + def test_local_loss(self, device): torch.manual_seed(1234) clip_loss = ContrastiveLossWithTemperature() - clip_loss = clip_loss.to(get_current_device()) - embeddings_a = torch.randn(3, 5) - embeddings_b = torch.randn(3, 5) + clip_loss = clip_loss.to(device) + embeddings_a = torch.randn(3, 5, device=device) + embeddings_b = torch.randn(3, 5, device=device) loss = clip_loss(embeddings_a=embeddings_a, embeddings_b=embeddings_b) - assert_expected(loss.item(), 9.8753, rtol=0, atol=1e-3) - def test_temperature_clamp_max(self): + def test_temperature_clamp_max(self, device): torch.manual_seed(1234) clip_loss_at_max = ContrastiveLossWithTemperature( logit_scale=2, logit_scale_max=2 - ).to(get_current_device()) + ).to(device) clip_loss_above_max = ContrastiveLossWithTemperature( logit_scale=3, logit_scale_max=2 - ).to(get_current_device()) - embeddings_a = torch.randn(3, 5) - embeddings_b = torch.randn(3, 5) + ).to(device) + embeddings_a = torch.randn(3, 5, device=device) + embeddings_b = torch.randn(3, 5, device=device) loss_at_max = clip_loss_at_max(embeddings_a, embeddings_b).item() loss_above_max = clip_loss_above_max(embeddings_a, embeddings_b).item() assert_expected(loss_above_max, loss_at_max, rtol=0, atol=1e-3) - def test_temperature_clamp_min(self): + def test_temperature_clamp_min(self, device): torch.manual_seed(1234) clip_loss_at_min = ContrastiveLossWithTemperature( logit_scale=2, logit_scale_min=2 - ).to(get_current_device()) + ).to(device) clip_loss_below_min = ContrastiveLossWithTemperature( logit_scale=1, logit_scale_min=2 - ).to(get_current_device()) - embeddings_a = torch.randn(3, 5) - embeddings_b = torch.randn(3, 5) + ).to(device) + embeddings_a = torch.randn(3, 5, device=device) + embeddings_b = torch.randn(3, 5, device=device) loss_at_min = clip_loss_at_min(embeddings_a, embeddings_b).item() loss_below_min = clip_loss_below_min(embeddings_a, embeddings_b).item() assert_expected(loss_below_min, loss_at_min, rtol=0, atol=1e-3) - def test_loss_with_ce_kwargs(self): + def test_loss_with_ce_kwargs(self, device): torch.manual_seed(1234) clip_loss = ContrastiveLossWithTemperature() - clip_loss = clip_loss.to(get_current_device()) - embeddings_a = torch.randn(3, 5) - embeddings_b = torch.randn(3, 5) + clip_loss = clip_loss.to(device) + embeddings_a = torch.randn(3, 5, device=device) + embeddings_b = torch.randn(3, 5, device=device) loss = clip_loss( embeddings_a=embeddings_a, embeddings_b=embeddings_b, cross_entropy_kwargs={"label_smoothing": 0.1}, ) - assert_expected(loss.item(), 10.2524, rtol=0, atol=1e-3) def test_temperature_clamp_invalid(self):