diff --git a/ax/benchmark/tests/problems/test_mixed_integer_problems.py b/ax/benchmark/tests/problems/test_mixed_integer_problems.py index fa6cb400515..975f775df73 100644 --- a/ax/benchmark/tests/problems/test_mixed_integer_problems.py +++ b/ax/benchmark/tests/problems/test_mixed_integer_problems.py @@ -5,9 +5,10 @@ # pyre-strict -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import torch +from ax.benchmark.benchmark_problem import BenchmarkProblem from ax.benchmark.problems.synthetic.discretized.mixed_integer import ( get_discrete_ackley, @@ -19,7 +20,7 @@ from ax.core.parameter import ParameterType from ax.core.trial import Trial from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from botorch.test_functions.synthetic import Ackley, Hartmann, Rosenbrock @@ -61,63 +62,60 @@ def test_problems(self) -> None: self.assertGreaterEqual(problem.optimal_value, problem_cls().optimal_value) # Test that they match correctly to the original problems. - # Hartmann - evaluate at 0 - should correspond to 0. - runner = checked_cast(BotorchTestProblemRunner, get_discrete_hartmann().runner) - mock_call = MagicMock(return_value=torch.tensor(0.0)) - runner.test_problem.evaluate_true = mock_call - trial = Trial(experiment=MagicMock()) - arm = Arm(parameters={f"x{i + 1}": 0.0 for i in range(6)}, name="--") - trial.add_arm(arm) - runner.run(trial) - actual = mock_call.call_args[0][0] - self.assertTrue(torch.allclose(actual, torch.zeros(6, dtype=actual.dtype))) - # Evaluate at 3, 3, 19, 19, 1, 1 - corresponds to 1. - arm = not_none(trial.arm) - arm._parameters = { - "x1": 3, - "x2": 3, - "x3": 19, - "x4": 19, - "x5": 1.0, - "x6": 1.0, - } - runner.run(trial) - actual = mock_call.call_args[0][0] - self.assertTrue(torch.allclose(actual, torch.ones(6, dtype=actual.dtype))) - # Ackley - evaluate at 0 - corresponds to 0. - runner = checked_cast(BotorchTestProblemRunner, get_discrete_ackley().runner) - runner.test_problem.evaluate_true = mock_call - arm._parameters = {f"x{i+1}": 0.0 for i in range(13)} - runner.run(trial) - actual = mock_call.call_args[0][0] - self.assertTrue(torch.allclose(actual, torch.zeros(13, dtype=actual.dtype))) - # Evaluate at 2 x 5, 4 x 5, 1.0 x 3 - corresponds to 1. - arm._parameters = { - **{f"x{i+1}": 2 for i in range(0, 5)}, - **{f"x{i+1}": 4 for i in range(5, 10)}, - **{f"x{i+1}": 1.0 for i in range(10, 13)}, - } - runner.run(trial) - actual = mock_call.call_args[0][0] - self.assertTrue(torch.allclose(actual, torch.ones(13, dtype=actual.dtype))) - # Rosenbrock - evaluate at 0 - corresponds to -5.0. - runner = checked_cast( - BotorchTestProblemRunner, get_discrete_rosenbrock().runner - ) - runner.test_problem.evaluate_true = mock_call - arm._parameters = {f"x{i+1}": 0.0 for i in range(10)} - runner.run(trial) - actual = mock_call.call_args[0][0] - self.assertTrue( - torch.allclose(actual, torch.full((10,), -5.0, dtype=actual.dtype)) - ) - # Evaluate at 3 x 6, 1.0 x 4 - corresponds to 10.0. - arm._parameters = { - **{f"x{i+1}": 3 for i in range(0, 6)}, - **{f"x{i+1}": 1.0 for i in range(6, 10)}, - } - runner.run(trial) - actual = mock_call.call_args[0][0] - self.assertTrue( - torch.allclose(actual, torch.full((10,), 10.0, dtype=actual.dtype)) - ) + cases: list[tuple[BenchmarkProblem, dict[str, float], torch.Tensor]] = [ + ( + get_discrete_hartmann(), + {f"x{i+1}": 0.0 for i in range(6)}, + torch.zeros(6, dtype=torch.double), + ), + ( + get_discrete_hartmann(), + {"x1": 3, "x2": 3, "x3": 19, "x4": 19, "x5": 1.0, "x6": 1.0}, + torch.ones(6, dtype=torch.double), + ), + ( + get_discrete_ackley(), + {f"x{i+1}": 0.0 for i in range(13)}, + torch.zeros(13, dtype=torch.double), + ), + ( + get_discrete_ackley(), + { + **{f"x{i+1}": 2 for i in range(0, 5)}, + **{f"x{i+1}": 4 for i in range(5, 10)}, + **{f"x{i+1}": 1.0 for i in range(10, 13)}, + }, + torch.ones(13, dtype=torch.double), + ), + ( + get_discrete_rosenbrock(), + {f"x{i+1}": 0.0 for i in range(10)}, + torch.full((10,), -5.0, dtype=torch.double), + ), + ( + get_discrete_rosenbrock(), + { + **{f"x{i+1}": 3 for i in range(0, 6)}, + **{f"x{i+1}": 1.0 for i in range(6, 10)}, + }, + torch.full((10,), 10.0, dtype=torch.double), + ), + ] + + for problem, params, expected_arg in cases: + + runner = checked_cast(BotorchTestProblemRunner, problem.runner) + trial = Trial(experiment=MagicMock()) + # pyre-fixme: Incompatible parameter type [6]: In call + # `Arm.__init__`, for argument `parameters`, expected `Dict[str, + # Union[None, bool, float, int, str]]` but got `dict[str, float]`. + arm = Arm(parameters=params, name="--") + trial.add_arm(arm) + with patch.object( + runner.test_problem, + attribute="evaluate_true", + wraps=runner.test_problem.evaluate_true, + ) as mock_call: + runner.run(trial) + actual = mock_call.call_args[0][0] + self.assertTrue(torch.allclose(actual, expected_arg))