Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop numerical tests from flaking; use assertRaisesRegex #1991

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions test/acquisition/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,29 +399,36 @@ def test_identity_mc_objective(self):


class TestLinearMCObjective(BotorchTestCase):
def test_linear_mc_objective(self):
def test_linear_mc_objective(self) -> None:
# Test passes for each seed
torch.manual_seed(torch.randint(high=1000, size=(1,)))
for dtype in (torch.float, torch.double):
weights = torch.rand(3, device=self.device, dtype=dtype)
obj = LinearMCObjective(weights=weights)
samples = torch.randn(4, 2, 3, device=self.device, dtype=dtype)
self.assertTrue(
torch.allclose(obj(samples), (samples * weights).sum(dim=-1))
)
atol = 1e-8 if dtype == torch.double else 3e-8
rtol = 1e-5 if dtype == torch.double else 4e-5
self.assertAllClose(obj(samples), samples @ weights, atol=atol, rtol=rtol)
samples = torch.randn(5, 4, 2, 3, device=self.device, dtype=dtype)
self.assertTrue(
torch.allclose(obj(samples), (samples * weights).sum(dim=-1))
self.assertAllClose(
obj(samples),
samples @ weights,
atol=atol,
rtol=rtol,
)
# make sure this errors if sample output dimensions are incompatible
with self.assertRaises(RuntimeError):
shape_mismatch_msg = "Output shape of samples not equal to that of weights"
with self.assertRaisesRegex(RuntimeError, shape_mismatch_msg):
obj(samples=torch.randn(2, device=self.device, dtype=dtype))
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(RuntimeError, shape_mismatch_msg):
obj(samples=torch.randn(1, device=self.device, dtype=dtype))
# make sure we can't construct objectives with multi-dim. weights
with self.assertRaises(ValueError):
weights_1d_msg = "weights must be a one-dimensional tensor."
with self.assertRaisesRegex(ValueError, expected_regex=weights_1d_msg):
LinearMCObjective(
weights=torch.rand(2, 3, device=self.device, dtype=dtype)
)
with self.assertRaises(ValueError):
with self.assertRaisesRegex(ValueError, expected_regex=weights_1d_msg):
LinearMCObjective(
weights=torch.tensor(1.0, device=self.device, dtype=dtype)
)
Expand Down
37 changes: 23 additions & 14 deletions test/utils/probability/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,18 @@ def test_swap_along_dim_(self):
with self.assertRaisesRegex(ValueError, "at most 1-dimensional"):
utils.swap_along_dim_(values.view(-1), i=i_lidx, j=j, dim=0)

def test_gaussian_probabilities(self):
def test_gaussian_probabilities(self) -> None:
# test passes for each possible seed
torch.manual_seed(torch.randint(high=1000, size=(1,)))
# testing Gaussian probability functions
for dtype in (torch.float, torch.double):
rtol = 1e-12 if dtype == torch.double else 1e-6
atol = rtol
n = 16
x = 3 * torch.randn(n, device=self.device, dtype=dtype)
# first, test consistency between regular and log versions
self.assertTrue(
torch.allclose(phi(x), log_phi(x).exp(), atol=atol, rtol=rtol)
)
self.assertTrue(
torch.allclose(ndtr(x), log_ndtr(x).exp(), atol=atol, rtol=rtol)
)
self.assertAllClose(phi(x), log_phi(x).exp(), atol=atol, rtol=rtol)
self.assertAllClose(ndtr(x), log_ndtr(x).exp(), atol=atol, rtol=rtol)

# test correctness of log_erfc and log_erfcx
for special_f, custom_log_f in zip(
Expand Down Expand Up @@ -291,10 +289,13 @@ def test_gaussian_probabilities(self):
self.assertTrue((a.grad.diff() < 0).all())

# testing error raising for invalid inputs
with self.assertRaises(ValueError):
a = torch.randn(3, 4, dtype=dtype, device=self.device)
b = torch.randn(3, 4, dtype=dtype, device=self.device)
a[2, 3] = b[2, 3]
a = torch.randn(3, 4, dtype=dtype, device=self.device)
b = torch.randn(3, 4, dtype=dtype, device=self.device)
a[2, 3] = b[2, 3]
with self.assertRaisesRegex(
ValueError,
"Received input tensors a, b for which not all a < b.",
):
log_prob_normal_in(a, b)

# testing gaussian hazard function
Expand All @@ -303,12 +304,20 @@ def test_gaussian_probabilities(self):
x = torch.cat((-x, x))
log_hx = standard_normal_log_hazard(x)
expected_log_hx = log_phi(x) - log_ndtr(-x)
self.assertAllClose(expected_log_hx, log_hx) # correctness
self.assertAllClose(
expected_log_hx,
log_hx,
atol=1e-8 if dtype == torch.double else 1e-7,
) # correctness
# NOTE: Could extend tests here similarly to log_erfc(x) tests above, but
# since the hazard functions are built on log_erfcx, not urgent.

with self.assertRaises(TypeError):
float16_msg = (
"only supports torch.float32 and torch.float64 dtypes, but received "
"x.dtype = torch.float16."
)
with self.assertRaisesRegex(TypeError, expected_regex=float16_msg):
log_erfc(torch.tensor(1.0, dtype=torch.float16, device=self.device))

with self.assertRaises(TypeError):
with self.assertRaisesRegex(TypeError, expected_regex=float16_msg):
log_ndtr(torch.tensor(1.0, dtype=torch.float16, device=self.device))