From 8905db99a1604feffc7f946f126c75563dbcf310 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Tue, 23 May 2023 12:00:55 +0200 Subject: [PATCH] fix --- tests/models/sam/test_modeling_sam.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index 599ed5e384bc..2342e8010b92 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -476,7 +476,7 @@ def test_inference_mask_generation_no_point(self): scores = outputs.iou_scores.squeeze() masks = outputs.pred_masks[0, 0, 0, 0, :3] self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4)) - self.assertTrue(torch.allclose(masks, torch.tensor([-4.1807, -3.4949, -3.4483]).to(torch_device), atol=2e-4)) + self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4)) def test_inference_mask_generation_one_point_one_bb(self): model = SamModel.from_pretrained("facebook/sam-vit-base") @@ -499,7 +499,7 @@ def test_inference_mask_generation_one_point_one_bb(self): masks = outputs.pred_masks[0, 0, 0, 0, :3] self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4)) self.assertTrue( - torch.allclose(masks, torch.tensor([-12.7657, -12.3683, -12.5985]).to(torch_device), atol=2e-4) + torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4) ) def test_inference_mask_generation_batched_points_batched_images(self): @@ -540,7 +540,7 @@ def test_inference_mask_generation_batched_points_batched_images(self): ], ] ) - EXPECTED_MASKS = torch.tensor([-2.8552, -2.7990, -2.9612]) + EXPECTED_MASKS = torch.tensor([-2.8550, -2.7988, -2.9625]) self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3)) self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3)) @@ -568,7 +568,7 @@ def test_inference_mask_generation_one_point_one_bb_zero(self): outputs = model(**inputs) scores = outputs.iou_scores.squeeze() - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7892), atol=1e-4)) + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7894), atol=1e-4)) def test_inference_mask_generation_one_point(self): model = SamModel.from_pretrained("facebook/sam-vit-base")