Skip to content

Commit

Permalink
Refactor test function test utils (pytorch#1839)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1839

Refactors the test utils for botorch test functions to make things more modular.

Instead of subclassing the base test cases, we now have class Mixins for synthetic, constrained, and multi-objective problems. This means we can mix-and-match the test case as needed.

This also flattens the `TestMultiObjectiveProblems` and `TestConstrainedMultiObjectiveProblems` test collections (the idea behind the test design is that we have separate test cases for each problem (potentially with different arguments) to better identify and group the tests, so this reestablishes that.

Differential Revision: D45969036

fbshipit-source-id: 2b3c00813185b93bfd55874f34b330ed7280d52b
  • Loading branch information
Balandat authored and facebook-github-bot committed May 18, 2023
1 parent 8c9d54b commit 5f7a444
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 100 deletions.
98 changes: 48 additions & 50 deletions botorch/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def assertAllClose(
)


class BaseTestProblemBaseTestCase:
class BaseTestProblemTestCaseMixIn:

functions: List[BaseTestProblem]

Expand All @@ -114,7 +114,7 @@ def test_forward(self):
self.assertEqual(res.shape, batch_shape + tail_shape)


class SyntheticTestFunctionBaseTestCase(BaseTestProblemBaseTestCase):
class SyntheticTestFunctionTestCaseMixin:
def test_optimal_value(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
Expand Down Expand Up @@ -143,6 +143,52 @@ def test_optimizer(self):
self.assertLess(grad.abs().max().item(), 1e-3)


class MultiObjectiveTestProblemTestCaseMixin:
def test_attributes(self):
for f in self.functions:
self.assertTrue(hasattr(f, "dim"))
self.assertTrue(hasattr(f, "num_objectives"))
self.assertEqual(f.bounds.shape, torch.Size([2, f.dim]))

def test_max_hv(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(device=self.device, dtype=dtype)
if not hasattr(f, "_max_hv"):
with self.assertRaises(NotImplementedError):
f.max_hv
else:
self.assertEqual(f.max_hv, f._max_hv)

def test_ref_point(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(dtype=dtype, device=self.device)
self.assertTrue(
torch.allclose(
f.ref_point,
torch.tensor(f._ref_point, dtype=dtype, device=self.device),
)
)


class ConstrainedTestProblemTestCaseMixin:
def test_num_constraints(self):
for f in self.functions:
self.assertTrue(hasattr(f, "num_constraints"))

def test_evaluate_slack_true(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(device=self.device, dtype=dtype)
X = unnormalize(
torch.rand(1, f.dim, device=self.device, dtype=dtype),
bounds=f.bounds,
)
slack = f.evaluate_slack_true(X)
self.assertEqual(slack.shape, torch.Size([1, f.num_constraints]))


class MockPosterior(Posterior):
r"""Mock object that implements dummy methods and feeds through specified outputs"""

Expand Down Expand Up @@ -368,51 +414,3 @@ def _get_test_posterior(
covar = covar + torch.diag_embed(flat_diag)
mtmvn = MultitaskMultivariateNormal(mean, covar, interleaved=interleaved)
return GPyTorchPosterior(mtmvn)


class MultiObjectiveTestProblemBaseTestCase(BaseTestProblemBaseTestCase):
def test_attributes(self):
for f in self.functions:
self.assertTrue(hasattr(f, "dim"))
self.assertTrue(hasattr(f, "num_objectives"))
self.assertEqual(f.bounds.shape, torch.Size([2, f.dim]))

def test_max_hv(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(device=self.device, dtype=dtype)
if not hasattr(f, "_max_hv"):
with self.assertRaises(NotImplementedError):
f.max_hv
else:
self.assertEqual(f.max_hv, f._max_hv)

def test_ref_point(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(dtype=dtype, device=self.device)
self.assertTrue(
torch.allclose(
f.ref_point,
torch.tensor(f._ref_point, dtype=dtype, device=self.device),
)
)


class ConstrainedMultiObjectiveTestProblemBaseTestCase(
MultiObjectiveTestProblemBaseTestCase
):
def test_num_constraints(self):
for f in self.functions:
self.assertTrue(hasattr(f, "num_constraints"))

def test_evaluate_slack_true(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(device=self.device, dtype=dtype)
X = unnormalize(
torch.rand(1, f.dim, device=self.device, dtype=dtype),
bounds=f.bounds,
)
slack = f.evaluate_slack_true(X)
self.assertEqual(slack.shape, torch.Size([1, f.num_constraints]))
18 changes: 14 additions & 4 deletions test/test_functions/test_multi_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@
AugmentedHartmann,
AugmentedRosenbrock,
)
from botorch.utils.testing import BotorchTestCase, SyntheticTestFunctionBaseTestCase
from botorch.utils.testing import (
BaseTestProblemTestCaseMixIn,
BotorchTestCase,
SyntheticTestFunctionTestCaseMixin,
)


class TestAugmentedBranin(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
class TestAugmentedBranin(
BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin
):

functions = [
AugmentedBranin(),
Expand All @@ -21,7 +27,9 @@ class TestAugmentedBranin(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
]


class TestAugmentedHartmann(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
class TestAugmentedHartmann(
BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin
):

functions = [
AugmentedHartmann(),
Expand All @@ -30,7 +38,9 @@ class TestAugmentedHartmann(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
]


class TestAugmentedRosenbrock(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
class TestAugmentedRosenbrock(
BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin
):

functions = [
AugmentedRosenbrock(),
Expand Down
160 changes: 139 additions & 21 deletions test/test_functions/test_multi_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@
ZDT3,
)
from botorch.utils.testing import (
BaseTestProblemTestCaseMixIn,
BotorchTestCase,
ConstrainedMultiObjectiveTestProblemBaseTestCase,
MultiObjectiveTestProblemBaseTestCase,
ConstrainedTestProblemTestCaseMixin,
MultiObjectiveTestProblemTestCaseMixin,
)


Expand Down Expand Up @@ -74,7 +75,11 @@ def test_base_mo_problem(self):
f.gen_pareto_front(1)


class TestBraninCurrin(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase):
class TestBraninCurrin(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [BraninCurrin()]

def test_init(self):
Expand All @@ -83,7 +88,11 @@ def test_init(self):
self.assertEqual(f.dim, 2)


class TestDH(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase):
class TestDH(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [DH1(dim=2), DH2(dim=3), DH3(dim=4), DH4(dim=5)]
dims = [2, 3, 4, 5]
bounds = [
Expand Down Expand Up @@ -118,7 +127,11 @@ def test_function_values(self):
self.assertAllClose(actual, expected)


class TestDTLZ(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase):
class TestDTLZ(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [
DTLZ1(dim=5, num_objectives=2),
DTLZ2(dim=5, num_objectives=2),
Expand Down Expand Up @@ -180,7 +193,11 @@ def test_gen_pareto_front(self):
)


class TestGMM(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase):
class TestGMM(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [GMM(num_objectives=4)]

def test_init(self):
Expand Down Expand Up @@ -226,7 +243,12 @@ def test_result(self):
)


class TestMW7(ConstrainedMultiObjectiveTestProblemBaseTestCase, BotorchTestCase):
class TestMW7(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [MW7(dim=3)]

def test_init(self):
Expand All @@ -237,7 +259,11 @@ def test_init(self):
self.assertEqual(f.dim, 3)


class TestZDT(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase):
class TestZDT(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [
ZDT1(dim=3, num_objectives=2),
ZDT2(dim=3, num_objectives=2),
Expand Down Expand Up @@ -304,27 +330,119 @@ def test_gen_pareto_front(self):
)


class TestMultiObjectiveProblems(
MultiObjectiveTestProblemBaseTestCase, BotorchTestCase
# ------------------ Unconstrained Multi-objective test problems ------------------ #


class TestCarSideImpact(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):

functions = [CarSideImpact()]


class TestPenicillin(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [CarSideImpact(), Penicillin(), ToyRobust(), VehicleSafety()]
functions = [Penicillin()]


class TestConstrainedMultiObjectiveProblems(
ConstrainedMultiObjectiveTestProblemBaseTestCase, BotorchTestCase
class TestToyRobust(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [ToyRobust()]


class TestVehicleSafety(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [VehicleSafety()]


# ------------------ Constrained Multi-objective test problems ------------------ #


class TestBNH(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [BNH()]


class TestSRN(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [SRN()]


class TestCONSTR(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [CONSTR()]


class TestConstrainedBraninCurrin(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [
BNH(),
SRN(),
CONSTR(),
ConstrainedBraninCurrin(),
C2DTLZ2(dim=3, num_objectives=2),
DiscBrake(),
WeldedBeam(),
OSY(),
]

def test_c2dtlz2_batch_exception(self):

class TestC2DTLZ2(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [C2DTLZ2(dim=3, num_objectives=2)]

def test_batch_exception(self):
f = C2DTLZ2(dim=3, num_objectives=2)
with self.assertRaises(NotImplementedError):
f.evaluate_slack_true(torch.empty(1, 1, 3))


class TestDiscBrake(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [DiscBrake()]


class TestWeldedBeam(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [WeldedBeam()]


class TestOSY(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [OSY()]
Loading

0 comments on commit 5f7a444

Please sign in to comment.