diff --git a/ax/models/torch/tests/test_sebo.py b/ax/models/torch/tests/test_sebo.py index c0a7b86cb68..5b44a300ae3 100644 --- a/ax/models/torch/tests/test_sebo.py +++ b/ax/models/torch/tests/test_sebo.py @@ -254,8 +254,8 @@ def test_optimize_l0_homotopy( "raw_samples": 16, }, ) - self.assertEqual(candidate, torch.zeros(1, **tkwargs)) - self.assertEqual(acqf_val, 5 * torch.ones(1, **tkwargs)) + self.assertTrue(torch.allclose(candidate, torch.zeros(1, **tkwargs))) + self.assertTrue(torch.allclose(acqf_val, 5 * torch.ones(1, **tkwargs))) self.assertEqual(weights, torch.ones(1, **tkwargs)) @mock.patch(f"{SEBOACQUISITION_PATH}.optimize_acqf_homotopy")