Skip to content

Commit

Permalink
Make unit tests for mixed integer problems less repetitive (#2781)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2781

* Loop over six cases instead of repeating code
* Have a mock wrap the original function so the original function still gets called

Reviewed By: saitcakmak

Differential Revision: D63326855

fbshipit-source-id: f4afbf54618c5b7affcabb2414799cebc0bf042e
  • Loading branch information
esantorella authored and facebook-github-bot committed Sep 24, 2024
1 parent 7a0ffb5 commit 8ba8ce3
Showing 1 changed file with 60 additions and 62 deletions.
122 changes: 60 additions & 62 deletions ax/benchmark/tests/problems/test_mixed_integer_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

# pyre-strict

from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import torch
from ax.benchmark.benchmark_problem import BenchmarkProblem

from ax.benchmark.problems.synthetic.discretized.mixed_integer import (
get_discrete_ackley,
Expand All @@ -19,7 +20,7 @@
from ax.core.parameter import ParameterType
from ax.core.trial import Trial
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.common.typeutils import checked_cast
from botorch.test_functions.synthetic import Ackley, Hartmann, Rosenbrock


Expand Down Expand Up @@ -61,63 +62,60 @@ def test_problems(self) -> None:
self.assertGreaterEqual(problem.optimal_value, problem_cls().optimal_value)

# Test that they match correctly to the original problems.
# Hartmann - evaluate at 0 - should correspond to 0.
runner = checked_cast(BotorchTestProblemRunner, get_discrete_hartmann().runner)
mock_call = MagicMock(return_value=torch.tensor(0.0))
runner.test_problem.evaluate_true = mock_call
trial = Trial(experiment=MagicMock())
arm = Arm(parameters={f"x{i + 1}": 0.0 for i in range(6)}, name="--")
trial.add_arm(arm)
runner.run(trial)
actual = mock_call.call_args[0][0]
self.assertTrue(torch.allclose(actual, torch.zeros(6, dtype=actual.dtype)))
# Evaluate at 3, 3, 19, 19, 1, 1 - corresponds to 1.
arm = not_none(trial.arm)
arm._parameters = {
"x1": 3,
"x2": 3,
"x3": 19,
"x4": 19,
"x5": 1.0,
"x6": 1.0,
}
runner.run(trial)
actual = mock_call.call_args[0][0]
self.assertTrue(torch.allclose(actual, torch.ones(6, dtype=actual.dtype)))
# Ackley - evaluate at 0 - corresponds to 0.
runner = checked_cast(BotorchTestProblemRunner, get_discrete_ackley().runner)
runner.test_problem.evaluate_true = mock_call
arm._parameters = {f"x{i+1}": 0.0 for i in range(13)}
runner.run(trial)
actual = mock_call.call_args[0][0]
self.assertTrue(torch.allclose(actual, torch.zeros(13, dtype=actual.dtype)))
# Evaluate at 2 x 5, 4 x 5, 1.0 x 3 - corresponds to 1.
arm._parameters = {
**{f"x{i+1}": 2 for i in range(0, 5)},
**{f"x{i+1}": 4 for i in range(5, 10)},
**{f"x{i+1}": 1.0 for i in range(10, 13)},
}
runner.run(trial)
actual = mock_call.call_args[0][0]
self.assertTrue(torch.allclose(actual, torch.ones(13, dtype=actual.dtype)))
# Rosenbrock - evaluate at 0 - corresponds to -5.0.
runner = checked_cast(
BotorchTestProblemRunner, get_discrete_rosenbrock().runner
)
runner.test_problem.evaluate_true = mock_call
arm._parameters = {f"x{i+1}": 0.0 for i in range(10)}
runner.run(trial)
actual = mock_call.call_args[0][0]
self.assertTrue(
torch.allclose(actual, torch.full((10,), -5.0, dtype=actual.dtype))
)
# Evaluate at 3 x 6, 1.0 x 4 - corresponds to 10.0.
arm._parameters = {
**{f"x{i+1}": 3 for i in range(0, 6)},
**{f"x{i+1}": 1.0 for i in range(6, 10)},
}
runner.run(trial)
actual = mock_call.call_args[0][0]
self.assertTrue(
torch.allclose(actual, torch.full((10,), 10.0, dtype=actual.dtype))
)
cases: list[tuple[BenchmarkProblem, dict[str, float], torch.Tensor]] = [
(
get_discrete_hartmann(),
{f"x{i+1}": 0.0 for i in range(6)},
torch.zeros(6, dtype=torch.double),
),
(
get_discrete_hartmann(),
{"x1": 3, "x2": 3, "x3": 19, "x4": 19, "x5": 1.0, "x6": 1.0},
torch.ones(6, dtype=torch.double),
),
(
get_discrete_ackley(),
{f"x{i+1}": 0.0 for i in range(13)},
torch.zeros(13, dtype=torch.double),
),
(
get_discrete_ackley(),
{
**{f"x{i+1}": 2 for i in range(0, 5)},
**{f"x{i+1}": 4 for i in range(5, 10)},
**{f"x{i+1}": 1.0 for i in range(10, 13)},
},
torch.ones(13, dtype=torch.double),
),
(
get_discrete_rosenbrock(),
{f"x{i+1}": 0.0 for i in range(10)},
torch.full((10,), -5.0, dtype=torch.double),
),
(
get_discrete_rosenbrock(),
{
**{f"x{i+1}": 3 for i in range(0, 6)},
**{f"x{i+1}": 1.0 for i in range(6, 10)},
},
torch.full((10,), 10.0, dtype=torch.double),
),
]

for problem, params, expected_arg in cases:

runner = checked_cast(BotorchTestProblemRunner, problem.runner)
trial = Trial(experiment=MagicMock())
# pyre-fixme: Incompatible parameter type [6]: In call
# `Arm.__init__`, for argument `parameters`, expected `Dict[str,
# Union[None, bool, float, int, str]]` but got `dict[str, float]`.
arm = Arm(parameters=params, name="--")
trial.add_arm(arm)
with patch.object(
runner.test_problem,
attribute="evaluate_true",
wraps=runner.test_problem.evaluate_true,
) as mock_call:
runner.run(trial)
actual = mock_call.call_args[0][0]
self.assertTrue(torch.allclose(actual, expected_arg))

0 comments on commit 8ba8ce3

Please sign in to comment.