From 7d79d521a672ae5d6178a33c4784d253c812ebae Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Thu, 18 May 2023 19:20:08 -0700 Subject: [PATCH] Refactor test function test utils (#1839) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1839 Refactors the test utils for botorch test functions to make things more modular. Instead of subclassing the base test cases, we now have class Mixins for synthetic, constrained, and multi-objective problems. This means we can mix-and-match the test case as needed. This also flattens the `TestMultiObjectiveProblems` and `TestConstrainedMultiObjectiveProblems` test collections (the idea behind the test design is that we have separate test cases for each problem (potentially with different arguments) to better identify and group the tests, so this reestablishes that. Reviewed By: esantorella Differential Revision: D45969036 fbshipit-source-id: 7e9ec80ed93bab55e857761dd462545de44ceca9 --- botorch/utils/testing.py | 108 ++++++------ test/test_functions/test_multi_fidelity.py | 18 +- test/test_functions/test_multi_objective.py | 160 +++++++++++++++--- .../test_multi_objective_multi_fidelity.py | 18 +- test/test_functions/test_synthetic.py | 88 +++++++--- 5 files changed, 289 insertions(+), 103 deletions(-) diff --git a/botorch/utils/testing.py b/botorch/utils/testing.py index 0754378f8c..34947d0857 100644 --- a/botorch/utils/testing.py +++ b/botorch/utils/testing.py @@ -8,6 +8,7 @@ import math import warnings +from abc import abstractproperty from collections import OrderedDict from typing import Any, List, Optional, Tuple, Union from unittest import TestCase @@ -93,10 +94,7 @@ def assertAllClose( ) -class BaseTestProblemBaseTestCase: - - functions: List[BaseTestProblem] - +class BaseTestProblemTestCaseMixIn: def test_forward(self): for dtype in (torch.float, torch.double): for batch_shape in (torch.Size(), torch.Size([2]), torch.Size([2, 3])): @@ -113,8 +111,14 @@ def test_forward(self): ) self.assertEqual(res.shape, batch_shape + tail_shape) + @abstractproperty + def functions(self) -> List[BaseTestProblem]: + # The functions that should be tested. Typically defined as a class + # attribute on the test case subclassing this class. + pass # pragma: no cover + -class SyntheticTestFunctionBaseTestCase(BaseTestProblemBaseTestCase): +class SyntheticTestFunctionTestCaseMixin: def test_optimal_value(self): for dtype in (torch.float, torch.double): for f in self.functions: @@ -143,6 +147,52 @@ def test_optimizer(self): self.assertLess(grad.abs().max().item(), 1e-3) +class MultiObjectiveTestProblemTestCaseMixin: + def test_attributes(self): + for f in self.functions: + self.assertTrue(hasattr(f, "dim")) + self.assertTrue(hasattr(f, "num_objectives")) + self.assertEqual(f.bounds.shape, torch.Size([2, f.dim])) + + def test_max_hv(self): + for dtype in (torch.float, torch.double): + for f in self.functions: + f.to(device=self.device, dtype=dtype) + if not hasattr(f, "_max_hv"): + with self.assertRaises(NotImplementedError): + f.max_hv + else: + self.assertEqual(f.max_hv, f._max_hv) + + def test_ref_point(self): + for dtype in (torch.float, torch.double): + for f in self.functions: + f.to(dtype=dtype, device=self.device) + self.assertTrue( + torch.allclose( + f.ref_point, + torch.tensor(f._ref_point, dtype=dtype, device=self.device), + ) + ) + + +class ConstrainedTestProblemTestCaseMixin: + def test_num_constraints(self): + for f in self.functions: + self.assertTrue(hasattr(f, "num_constraints")) + + def test_evaluate_slack_true(self): + for dtype in (torch.float, torch.double): + for f in self.functions: + f.to(device=self.device, dtype=dtype) + X = unnormalize( + torch.rand(1, f.dim, device=self.device, dtype=dtype), + bounds=f.bounds, + ) + slack = f.evaluate_slack_true(X) + self.assertEqual(slack.shape, torch.Size([1, f.num_constraints])) + + class MockPosterior(Posterior): r"""Mock object that implements dummy methods and feeds through specified outputs""" @@ -368,51 +418,3 @@ def _get_test_posterior( covar = covar + torch.diag_embed(flat_diag) mtmvn = MultitaskMultivariateNormal(mean, covar, interleaved=interleaved) return GPyTorchPosterior(mtmvn) - - -class MultiObjectiveTestProblemBaseTestCase(BaseTestProblemBaseTestCase): - def test_attributes(self): - for f in self.functions: - self.assertTrue(hasattr(f, "dim")) - self.assertTrue(hasattr(f, "num_objectives")) - self.assertEqual(f.bounds.shape, torch.Size([2, f.dim])) - - def test_max_hv(self): - for dtype in (torch.float, torch.double): - for f in self.functions: - f.to(device=self.device, dtype=dtype) - if not hasattr(f, "_max_hv"): - with self.assertRaises(NotImplementedError): - f.max_hv - else: - self.assertEqual(f.max_hv, f._max_hv) - - def test_ref_point(self): - for dtype in (torch.float, torch.double): - for f in self.functions: - f.to(dtype=dtype, device=self.device) - self.assertTrue( - torch.allclose( - f.ref_point, - torch.tensor(f._ref_point, dtype=dtype, device=self.device), - ) - ) - - -class ConstrainedMultiObjectiveTestProblemBaseTestCase( - MultiObjectiveTestProblemBaseTestCase -): - def test_num_constraints(self): - for f in self.functions: - self.assertTrue(hasattr(f, "num_constraints")) - - def test_evaluate_slack_true(self): - for dtype in (torch.float, torch.double): - for f in self.functions: - f.to(device=self.device, dtype=dtype) - X = unnormalize( - torch.rand(1, f.dim, device=self.device, dtype=dtype), - bounds=f.bounds, - ) - slack = f.evaluate_slack_true(X) - self.assertEqual(slack.shape, torch.Size([1, f.num_constraints])) diff --git a/test/test_functions/test_multi_fidelity.py b/test/test_functions/test_multi_fidelity.py index 68482e6de7..d1e8257c57 100644 --- a/test/test_functions/test_multi_fidelity.py +++ b/test/test_functions/test_multi_fidelity.py @@ -9,10 +9,16 @@ AugmentedHartmann, AugmentedRosenbrock, ) -from botorch.utils.testing import BotorchTestCase, SyntheticTestFunctionBaseTestCase +from botorch.utils.testing import ( + BaseTestProblemTestCaseMixIn, + BotorchTestCase, + SyntheticTestFunctionTestCaseMixin, +) -class TestAugmentedBranin(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestAugmentedBranin( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [ AugmentedBranin(), @@ -21,7 +27,9 @@ class TestAugmentedBranin(SyntheticTestFunctionBaseTestCase, BotorchTestCase): ] -class TestAugmentedHartmann(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestAugmentedHartmann( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [ AugmentedHartmann(), @@ -30,7 +38,9 @@ class TestAugmentedHartmann(SyntheticTestFunctionBaseTestCase, BotorchTestCase): ] -class TestAugmentedRosenbrock(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestAugmentedRosenbrock( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [ AugmentedRosenbrock(), diff --git a/test/test_functions/test_multi_objective.py b/test/test_functions/test_multi_objective.py index 4b3ac32f50..6a5295b4e5 100644 --- a/test/test_functions/test_multi_objective.py +++ b/test/test_functions/test_multi_objective.py @@ -40,9 +40,10 @@ ZDT3, ) from botorch.utils.testing import ( + BaseTestProblemTestCaseMixIn, BotorchTestCase, - ConstrainedMultiObjectiveTestProblemBaseTestCase, - MultiObjectiveTestProblemBaseTestCase, + ConstrainedTestProblemTestCaseMixin, + MultiObjectiveTestProblemTestCaseMixin, ) @@ -74,7 +75,11 @@ def test_base_mo_problem(self): f.gen_pareto_front(1) -class TestBraninCurrin(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase): +class TestBraninCurrin( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, +): functions = [BraninCurrin()] def test_init(self): @@ -83,7 +88,11 @@ def test_init(self): self.assertEqual(f.dim, 2) -class TestDH(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase): +class TestDH( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, +): functions = [DH1(dim=2), DH2(dim=3), DH3(dim=4), DH4(dim=5)] dims = [2, 3, 4, 5] bounds = [ @@ -118,7 +127,11 @@ def test_function_values(self): self.assertAllClose(actual, expected) -class TestDTLZ(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase): +class TestDTLZ( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, +): functions = [ DTLZ1(dim=5, num_objectives=2), DTLZ2(dim=5, num_objectives=2), @@ -180,7 +193,11 @@ def test_gen_pareto_front(self): ) -class TestGMM(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase): +class TestGMM( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, +): functions = [GMM(num_objectives=4)] def test_init(self): @@ -226,7 +243,12 @@ def test_result(self): ) -class TestMW7(ConstrainedMultiObjectiveTestProblemBaseTestCase, BotorchTestCase): +class TestMW7( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, + ConstrainedTestProblemTestCaseMixin, +): functions = [MW7(dim=3)] def test_init(self): @@ -237,7 +259,11 @@ def test_init(self): self.assertEqual(f.dim, 3) -class TestZDT(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase): +class TestZDT( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, +): functions = [ ZDT1(dim=3, num_objectives=2), ZDT2(dim=3, num_objectives=2), @@ -304,27 +330,119 @@ def test_gen_pareto_front(self): ) -class TestMultiObjectiveProblems( - MultiObjectiveTestProblemBaseTestCase, BotorchTestCase +# ------------------ Unconstrained Multi-objective test problems ------------------ # + + +class TestCarSideImpact( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, +): + + functions = [CarSideImpact()] + + +class TestPenicillin( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, ): - functions = [CarSideImpact(), Penicillin(), ToyRobust(), VehicleSafety()] + functions = [Penicillin()] -class TestConstrainedMultiObjectiveProblems( - ConstrainedMultiObjectiveTestProblemBaseTestCase, BotorchTestCase +class TestToyRobust( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, +): + functions = [ToyRobust()] + + +class TestVehicleSafety( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, +): + functions = [VehicleSafety()] + + +# ------------------ Constrained Multi-objective test problems ------------------ # + + +class TestBNH( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, + ConstrainedTestProblemTestCaseMixin, +): + functions = [BNH()] + + +class TestSRN( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, + ConstrainedTestProblemTestCaseMixin, +): + functions = [SRN()] + + +class TestCONSTR( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, + ConstrainedTestProblemTestCaseMixin, +): + functions = [CONSTR()] + + +class TestConstrainedBraninCurrin( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, + ConstrainedTestProblemTestCaseMixin, ): functions = [ - BNH(), - SRN(), - CONSTR(), ConstrainedBraninCurrin(), - C2DTLZ2(dim=3, num_objectives=2), - DiscBrake(), - WeldedBeam(), - OSY(), ] - def test_c2dtlz2_batch_exception(self): + +class TestC2DTLZ2( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, + ConstrainedTestProblemTestCaseMixin, +): + functions = [C2DTLZ2(dim=3, num_objectives=2)] + + def test_batch_exception(self): f = C2DTLZ2(dim=3, num_objectives=2) with self.assertRaises(NotImplementedError): f.evaluate_slack_true(torch.empty(1, 1, 3)) + + +class TestDiscBrake( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, + ConstrainedTestProblemTestCaseMixin, +): + functions = [DiscBrake()] + + +class TestWeldedBeam( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, + ConstrainedTestProblemTestCaseMixin, +): + functions = [WeldedBeam()] + + +class TestOSY( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, + ConstrainedTestProblemTestCaseMixin, +): + functions = [OSY()] diff --git a/test/test_functions/test_multi_objective_multi_fidelity.py b/test/test_functions/test_multi_objective_multi_fidelity.py index 7563803a99..b4ceae1c6e 100644 --- a/test/test_functions/test_multi_objective_multi_fidelity.py +++ b/test/test_functions/test_multi_objective_multi_fidelity.py @@ -8,10 +8,18 @@ MOMFBraninCurrin, MOMFPark, ) -from botorch.utils.testing import BotorchTestCase, MultiObjectiveTestProblemBaseTestCase +from botorch.utils.testing import ( + BaseTestProblemTestCaseMixIn, + BotorchTestCase, + MultiObjectiveTestProblemTestCaseMixin, +) -class TestMOMFBraninCurrin(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase): +class TestMOMFBraninCurrin( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, +): functions = [MOMFBraninCurrin()] bounds = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]] @@ -24,7 +32,11 @@ def test_init(self): ) -class TestMOMFPark(MultiObjectiveTestProblemBaseTestCase, BotorchTestCase): +class TestMOMFPark( + BotorchTestCase, + BaseTestProblemTestCaseMixIn, + MultiObjectiveTestProblemTestCaseMixin, +): functions = [MOMFPark()] bounds = [[0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0]] diff --git a/test/test_functions/test_synthetic.py b/test/test_functions/test_synthetic.py index 457d0859d7..9c37979970 100644 --- a/test/test_functions/test_synthetic.py +++ b/test/test_functions/test_synthetic.py @@ -29,7 +29,11 @@ SyntheticTestFunction, ThreeHumpCamel, ) -from botorch.utils.testing import BotorchTestCase, SyntheticTestFunctionBaseTestCase +from botorch.utils.testing import ( + BaseTestProblemTestCaseMixIn, + BotorchTestCase, + SyntheticTestFunctionTestCaseMixin, +) from torch import Tensor @@ -46,7 +50,7 @@ class DummySyntheticTestFunctionWithOptimizers(DummySyntheticTestFunction): _optimizers = [(0, 0)] -class TestSyntheticTestFunction(BotorchTestCase): +class TestCustomBounds(BotorchTestCase): functions_with_custom_bounds = [ # Function name and the default dimension. (Ackley, 2), (Beale, 2), @@ -100,37 +104,51 @@ def test_custom_bounds(self): self.assertTrue(torch.allclose(func.bounds, bounds_tensor)) -class TestAckley(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestAckley( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [Ackley(), Ackley(negate=True), Ackley(noise_std=0.1), Ackley(dim=3)] -class TestBeale(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestBeale( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [Beale(), Beale(negate=True), Beale(noise_std=0.1)] -class TestBranin(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestBranin( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [Branin(), Branin(negate=True), Branin(noise_std=0.1)] -class TestBukin(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestBukin( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [Bukin(), Bukin(negate=True), Bukin(noise_std=0.1)] -class TestCosine8(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestCosine8( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [Cosine8(), Cosine8(negate=True), Cosine8(noise_std=0.1)] -class TestDropWave(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestDropWave( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [DropWave(), DropWave(negate=True), DropWave(noise_std=0.1)] -class TestDixonPrice(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestDixonPrice( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [ DixonPrice(), @@ -140,12 +158,16 @@ class TestDixonPrice(SyntheticTestFunctionBaseTestCase, BotorchTestCase): ] -class TestEggHolder(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestEggHolder( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [EggHolder(), EggHolder(negate=True), EggHolder(noise_std=0.1)] -class TestGriewank(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestGriewank( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [ Griewank(), @@ -155,7 +177,9 @@ class TestGriewank(SyntheticTestFunctionBaseTestCase, BotorchTestCase): ] -class TestHartmann(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestHartmann( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [ Hartmann(), @@ -174,12 +198,16 @@ def test_dimension(self): Hartmann(dim=2) -class TestHolderTable(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestHolderTable( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [HolderTable(), HolderTable(negate=True), HolderTable(noise_std=0.1)] -class TestLevy(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestLevy( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [ Levy(), @@ -191,7 +219,9 @@ class TestLevy(SyntheticTestFunctionBaseTestCase, BotorchTestCase): ] -class TestMichalewicz(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestMichalewicz( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [ Michalewicz(), @@ -206,12 +236,16 @@ class TestMichalewicz(SyntheticTestFunctionBaseTestCase, BotorchTestCase): ] -class TestPowell(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestPowell( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [Powell(), Powell(negate=True), Powell(noise_std=0.1)] -class TestRastrigin(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestRastrigin( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [ Rastrigin(), @@ -223,7 +257,9 @@ class TestRastrigin(SyntheticTestFunctionBaseTestCase, BotorchTestCase): ] -class TestRosenbrock(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestRosenbrock( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [ Rosenbrock(), @@ -235,17 +271,23 @@ class TestRosenbrock(SyntheticTestFunctionBaseTestCase, BotorchTestCase): ] -class TestShekel(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestShekel( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [Shekel(), Shekel(negate=True), Shekel(noise_std=0.1)] -class TestSixHumpCamel(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestSixHumpCamel( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [SixHumpCamel(), SixHumpCamel(negate=True), SixHumpCamel(noise_std=0.1)] -class TestStyblinskiTang(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestStyblinskiTang( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [ StyblinskiTang(), @@ -257,7 +299,9 @@ class TestStyblinskiTang(SyntheticTestFunctionBaseTestCase, BotorchTestCase): ] -class TestThreeHumpCamel(SyntheticTestFunctionBaseTestCase, BotorchTestCase): +class TestThreeHumpCamel( + BotorchTestCase, BaseTestProblemTestCaseMixIn, SyntheticTestFunctionTestCaseMixin +): functions = [ ThreeHumpCamel(),