Skip to content

Commit

Permalink
Refactor test function test utils
Browse files Browse the repository at this point in the history
Differential Revision: D45969036

fbshipit-source-id: 80034418bc73f899b7844b78bd3ed97407b37eec
  • Loading branch information
Balandat authored and facebook-github-bot committed May 18, 2023
1 parent 8c9d54b commit c796611
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 100 deletions.
98 changes: 48 additions & 50 deletions botorch/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def assertAllClose(
)


class BaseTestProblemBaseTestCase:
class BaseTestProblemTestCaseMixIn:

functions: List[BaseTestProblem]

Expand All @@ -114,7 +114,7 @@ def test_forward(self):
self.assertEqual(res.shape, batch_shape + tail_shape)


class SyntheticTestFunctionBaseTestCase(BaseTestProblemBaseTestCase):
class SyntheticTestFunctionTestCaseMixin:
def test_optimal_value(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
Expand Down Expand Up @@ -143,6 +143,52 @@ def test_optimizer(self):
self.assertLess(grad.abs().max().item(), 1e-3)


class MultiObjectiveTestProblemTestCaseMixin:
def test_attributes(self):
for f in self.functions:
self.assertTrue(hasattr(f, "dim"))
self.assertTrue(hasattr(f, "num_objectives"))
self.assertEqual(f.bounds.shape, torch.Size([2, f.dim]))

def test_max_hv(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(device=self.device, dtype=dtype)
if not hasattr(f, "_max_hv"):
with self.assertRaises(NotImplementedError):
f.max_hv
else:
self.assertEqual(f.max_hv, f._max_hv)

def test_ref_point(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(dtype=dtype, device=self.device)
self.assertTrue(
torch.allclose(
f.ref_point,
torch.tensor(f._ref_point, dtype=dtype, device=self.device),
)
)


class ConstrainedTestProblemTestCaseMixin:
def test_num_constraints(self):
for f in self.functions:
self.assertTrue(hasattr(f, "num_constraints"))

def test_evaluate_slack_true(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(device=self.device, dtype=dtype)
X = unnormalize(
torch.rand(1, f.dim, device=self.device, dtype=dtype),
bounds=f.bounds,
)
slack = f.evaluate_slack_true(X)
self.assertEqual(slack.shape, torch.Size([1, f.num_constraints]))


class MockPosterior(Posterior):
r"""Mock object that implements dummy methods and feeds through specified outputs"""

Expand Down Expand Up @@ -368,51 +414,3 @@ def _get_test_posterior(
covar = covar + torch.diag_embed(flat_diag)
mtmvn = MultitaskMultivariateNormal(mean, covar, interleaved=interleaved)
return GPyTorchPosterior(mtmvn)


class MultiObjectiveTestProblemBaseTestCase(BaseTestProblemBaseTestCase):
def test_attributes(self):
for f in self.functions:
self.assertTrue(hasattr(f, "dim"))
self.assertTrue(hasattr(f, "num_objectives"))
self.assertEqual(f.bounds.shape, torch.Size([2, f.dim]))

def test_max_hv(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(device=self.device, dtype=dtype)
if not hasattr(f, "_max_hv"):
with self.assertRaises(NotImplementedError):
f.max_hv
else:
self.assertEqual(f.max_hv, f._max_hv)

def test_ref_point(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(dtype=dtype, device=self.device)
self.assertTrue(
torch.allclose(
f.ref_point,
torch.tensor(f._ref_point, dtype=dtype, device=self.device),
)
)


class ConstrainedMultiObjectiveTestProblemBaseTestCase(
MultiObjectiveTestProblemBaseTestCase
):
def test_num_constraints(self):
for f in self.functions:
self.assertTrue(hasattr(f, "num_constraints"))

def test_evaluate_slack_true(self):
for dtype in (torch.float, torch.double):
for f in self.functions:
f.to(device=self.device, dtype=dtype)
X = unnormalize(
torch.rand(1, f.dim, device=self.device, dtype=dtype),
bounds=f.bounds,
)
slack = f.evaluate_slack_true(X)
self.assertEqual(slack.shape, torch.Size([1, f.num_constraints]))
18 changes: 14 additions & 4 deletions test/test_functions/test_multi_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@
AugmentedHartmann,
AugmentedRosenbrock,
)
from botorch.utils.testing import BotorchTestCase, SyntheticTestFunctionBaseTestCase
from botorch.utils.testing import (
BaseTestProblemTestCaseMixIn,
BotorchTestCase,
SyntheticTestFunctionTestCaseMixin,
)


class TestAugmentedBranin(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
class TestAugmentedBranin(
BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin
):

functions = [
AugmentedBranin(),
Expand All @@ -21,7 +27,9 @@ class TestAugmentedBranin(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
]


class TestAugmentedHartmann(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
class TestAugmentedHartmann(
BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin
):

functions = [
AugmentedHartmann(),
Expand All @@ -30,7 +38,9 @@ class TestAugmentedHartmann(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
]


class TestAugmentedRosenbrock(SyntheticTestFunctionBaseTestCase, BotorchTestCase):
class TestAugmentedRosenbrock(
BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin
):

functions = [
AugmentedRosenbrock(),
Expand Down
160 changes: 139 additions & 21 deletions test/test_functions/test_multi_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@
ZDT3,
)
from botorch.utils.testing import (
BaseTestProblemTestCaseMixIn,
BotorchTestCase,
ConstrainedMultiObjectiveTestProblemBaseTestCase,
MultiObjectiveTestProblemBaseTestCase,
ConstrainedTestProblemTestCaseMixin,
MultiObjectiveTestProblemTestCaseMixin,
)


Expand Down Expand Up @@ -74,7 +75,11 @@ def test_base_mo_problem(self):
f.gen_pareto_front(1)


class TestBraninCurrin(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase):
class TestBraninCurrin(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [BraninCurrin()]

def test_init(self):
Expand All @@ -83,7 +88,11 @@ def test_init(self):
self.assertEqual(f.dim, 2)


class TestDH(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase):
class TestDH(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [DH1(dim=2), DH2(dim=3), DH3(dim=4), DH4(dim=5)]
dims = [2, 3, 4, 5]
bounds = [
Expand Down Expand Up @@ -118,7 +127,11 @@ def test_function_values(self):
self.assertAllClose(actual, expected)


class TestDTLZ(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase):
class TestDTLZ(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [
DTLZ1(dim=5, num_objectives=2),
DTLZ2(dim=5, num_objectives=2),
Expand Down Expand Up @@ -180,7 +193,11 @@ def test_gen_pareto_front(self):
)


class TestGMM(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase):
class TestGMM(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [GMM(num_objectives=4)]

def test_init(self):
Expand Down Expand Up @@ -226,7 +243,12 @@ def test_result(self):
)


class TestMW7(ConstrainedMultiObjectiveTestProblemBaseTestCase, BotorchTestCase):
class TestMW7(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [MW7(dim=3)]

def test_init(self):
Expand All @@ -237,7 +259,11 @@ def test_init(self):
self.assertEqual(f.dim, 3)


class TestZDT(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase):
class TestZDT(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [
ZDT1(dim=3, num_objectives=2),
ZDT2(dim=3, num_objectives=2),
Expand Down Expand Up @@ -304,27 +330,119 @@ def test_gen_pareto_front(self):
)


class TestMultiObjectiveProblems(
MultiObjectiveTestProblemBaseTestCase, BotorchTestCase
# ------------------ Unconstrained Multi-objective test problems ------------------ #


class TestCarSideImpact(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):

functions = [CarSideImpact()]


class TestPenicillin(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [CarSideImpact(), Penicillin(), ToyRobust(), VehicleSafety()]
functions = [Penicillin()]


class TestConstrainedMultiObjectiveProblems(
ConstrainedMultiObjectiveTestProblemBaseTestCase, BotorchTestCase
class TestToyRobust(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [ToyRobust()]


class TestVehicleSafety(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
):
functions = [VehicleSafety()]


# ------------------ Constrained Multi-objective test problems ------------------ #


class TestBNH(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [BNH()]


class TestSRN(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [SRN()]


class TestCONSTR(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [CONSTR()]


class TestConstrainedBraninCurrin(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [
BNH(),
SRN(),
CONSTR(),
ConstrainedBraninCurrin(),
C2DTLZ2(dim=3, num_objectives=2),
DiscBrake(),
WeldedBeam(),
OSY(),
]

def test_c2dtlz2_batch_exception(self):

class TestC2DTLZ2(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [C2DTLZ2(dim=3, num_objectives=2)]

def test_batch_exception(self):
f = C2DTLZ2(dim=3, num_objectives=2)
with self.assertRaises(NotImplementedError):
f.evaluate_slack_true(torch.empty(1, 1, 3))


class TestDiscBrake(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [DiscBrake()]


class TestWeldedBeam(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [WeldedBeam()]


class TestOSY(
BotorchTestCase,
BaseTestProblemTestCaseMixIn,
MultiObjectiveTestProblemTestCaseMixin,
ConstrainedTestProblemTestCaseMixin,
):
functions = [OSY()]
Loading

0 comments on commit c796611

Please sign in to comment.