Skip to content

Commit

Permalink
Refactor test function test utils (#1839)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1839

Refactors the test utils for botorch test functions to make things more modular.

Instead of subclassing the base test cases, we now have class Mixins for synthetic, constrained, and multi-objective problems. This means we can mix-and-match the test case as needed.

This also flattens the `TestMultiObjectiveProblems` and `TestConstrainedMultiObjectiveProblems` test collections (the idea behind the test design is that we have separate test cases for each problem (potentially with different arguments) to better identify and group the tests, so this reestablishes that.

Reviewed By: esantorella

Differential Revision: D45969036

fbshipit-source-id: 888fbbd861d020a3c0e8cc777bb8849ee3cde462
  • Loading branch information
Balandat authored and facebook-github-bot committed May 19, 2023
1 parent 70d0c63 commit 98b4194
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 103 deletions.
108 changes: 55 additions & 53 deletions botorch/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import math
import warnings
from abc import abstractproperty
from collections import OrderedDict
from typing import Any, List, Optional, Tuple, Union
from unittest import TestCase
Expand Down Expand Up @@ -93,10 +94,7 @@ def assertAllClose(
)


class BaseTestProblemBaseTestCase:

functions: List[BaseTestProblem]

class BaseTestProblemTestCaseMixIn:
def test_forward(self):
for dtype in (torch.float, torch.double):
for batch_shape in (torch.Size(), torch.Size([2]), torch.Size([2, 3])):
Expand All @@ -113,8 +111,14 @@ def test_forward(self):
)
self.assertEqual(res.shape, batch_shape + tail_shape)

@abstractproperty
def functions(self) -> List[BaseTestProblem]:
# The functions that should be tested. Typically defined as a class
# attribute on the test case subclassing this class.
pass # pragma: no cover


class SyntheticTestFunctionBaseTestCase(BaseTestProblemBaseTestCase):
class SyntheticTestFunctionTestCaseMixin:
def test_optimal_value(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
Expand Down Expand Up @@ -143,6 +147,52 @@ def test_optimizer(self):
self.assertLess(grad.abs().max().item(), 1e-3)


class MultiObjectiveTestProblemTestCaseMixin:
def test_attributes(self):
for f in self.functions:
self.assertTrue(hasattr(f, "dim"))
self.assertTrue(hasattr(f, "num_objectives"))
self.assertEqual(f.bounds.shape, torch.Size([2, f.dim]))

def test_max_hv(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(device=self.device, dtype=dtype)
if not hasattr(f, "_max_hv"):
with self.assertRaises(NotImplementedError):
f.max_hv
else:
self.assertEqual(f.max_hv, f._max_hv)

def test_ref_point(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(dtype=dtype, device=self.device)
self.assertTrue(
torch.allclose(
f.ref_point,
torch.tensor(f._ref_point, dtype=dtype, device=self.device),
)
)


class ConstrainedTestProblemTestCaseMixin:
def test_num_constraints(self):
for f in self.functions:
self.assertTrue(hasattr(f, "num_constraints"))

def test_evaluate_slack_true(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(device=self.device, dtype=dtype)
X = unnormalize(
torch.rand(1, f.dim, device=self.device, dtype=dtype),
bounds=f.bounds,
)
slack = f.evaluate_slack_true(X)
self.assertEqual(slack.shape, torch.Size([1, f.num_constraints]))


class MockPosterior(Posterior):
r"""Mock object that implements dummy methods and feeds through specified outputs"""

Expand Down Expand Up @@ -368,51 +418,3 @@ def _get_test_posterior(
covar = covar + torch.diag_embed(flat_diag)
mtmvn = MultitaskMultivariateNormal(mean, covar, interleaved=interleaved)
return GPyTorchPosterior(mtmvn)


class MultiObjectiveTestProblemBaseTestCase(BaseTestProblemBaseTestCase):
def test_attributes(self):
for f in self.functions:
self.assertTrue(hasattr(f, "dim"))
self.assertTrue(hasattr(f, "num_objectives"))
self.assertEqual(f.bounds.shape, torch.Size([2, f.dim]))

def test_max_hv(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(device=self.device, dtype=dtype)
if not hasattr(f, "_max_hv"):
with self.assertRaises(NotImplementedError):
f.max_hv
else:
self.assertEqual(f.max_hv, f._max_hv)

def test_ref_point(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(dtype=dtype, device=self.device)
self.assertTrue(
torch.allclose(
f.ref_point,
torch.tensor(f._ref_point, dtype=dtype, device=self.device),
)
)


class ConstrainedMultiObjectiveTestProblemBaseTestCase(
MultiObjectiveTestProblemBaseTestCase
):
def test_num_constraints(self):
for f in self.functions:
self.assertTrue(hasattr(f, "num_constraints"))

def test_evaluate_slack_true(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(device=self.device, dtype=dtype)
X = unnormalize(
torch.rand(1, f.dim, device=self.device, dtype=dtype),
bounds=f.bounds,
)
slack = f.evaluate_slack_true(X)
self.assertEqual(slack.shape, torch.Size([1, f.num_constraints]))
18 changes: 14 additions & 4 deletions test/test_functions/test_multi_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@
AugmentedHartmann,
AugmentedRosenbrock,
)
from botorch.utils.testing import BotorchTestCase, SyntheticTestFunctionBaseTestCase
from botorch.utils.testing import (
BaseTestProblemTestCaseMixIn,
BotorchTestCase,
SyntheticTestFunctionTestCaseMixin,
)


class TestAugmentedBranin(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
class TestAugmentedBranin(
BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin
):

functions = [
AugmentedBranin(),
Expand All @@ -21,7 +27,9 @@ class TestAugmentedBranin(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
]


class TestAugmentedHartmann(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
class TestAugmentedHartmann(
BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin
):

functions = [
AugmentedHartmann(),
Expand All @@ -30,7 +38,9 @@ class TestAugmentedHartmann(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
]


class TestAugmentedRosenbrock(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
class TestAugmentedRosenbrock(
BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin
):

functions = [
AugmentedRosenbrock(),
Expand Down
Loading

0 comments on commit 98b4194

Please sign in to comment.