Skip to content

Commit

Permalink
Don't allow unused **kwargs in input_constructors except for a define…
Browse files Browse the repository at this point in the history
…d set of exceptions (facebook#1772)

Summary:
Pull Request resolved: facebook#1772

X-link: pytorch/botorch#1872

[x] Remove unused arguments from input constructors and related functions. The idea is especially not to let unused keyword arguments disappear into `**kwargs` and be silently ignored
[x] add arguments to some input constructors so they don't need any `**kwargs`
[x] Add a decorator that ensures that each input constructor can accept a certain set of keyword arguments, even if those are not used by the constructor, while still erroring on unexpected keyword arguments
[ ] Prevent arguments from having different defaults in the input constructors than in the corresponding acquisition functions

Reviewed By: SebastianAment

Differential Revision: D46519588

fbshipit-source-id: 44ea95d82cd50a3ecebc248759ea5d1db9a3bf51
  • Loading branch information
esantorella authored and facebook-github-bot committed Aug 11, 2023
1 parent 1cc89e9 commit 5ac79cc
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 40 deletions.
2 changes: 1 addition & 1 deletion ax/models/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(
]

# Subset model only to the outcomes we need for the optimization.
if self.options.get(Keys.SUBSET_MODEL, True):
if self.options.pop(Keys.SUBSET_MODEL, True):
subset_model_results = subset_model(
model=primary_surrogate.model,
objective_weights=torch_opt_config.objective_weights,
Expand Down
6 changes: 4 additions & 2 deletions ax/models/torch/botorch_modular/sebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def __init__(

tkwargs = {"dtype": surrogate.dtype, "device": surrogate.device}
options = options or {}
self.penalty_name: str = options.get("penalty", "L0_norm")
self.target_point: Tensor = options.get("target_point", None)
self.penalty_name: str = options.pop("penalty", "L0_norm")
self.target_point: Tensor = options.pop("target_point", None)
if self.target_point is None:
raise ValueError("please provide target point.")
self.target_point.to(**tkwargs) # pyre-ignore
Expand Down Expand Up @@ -93,6 +93,8 @@ def __init__(
)

# instantiate botorch_acqf_class
if not issubclass(botorch_acqf_class, qExpectedHypervolumeImprovement):
raise ValueError("botorch_acqf_class must be qEHVI to use SEBO")
super().__init__(
surrogates={"sebo": surrogate_f},
search_space_digest=search_space_digest,
Expand Down
84 changes: 49 additions & 35 deletions ax/models/torch/tests/test_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from ax.exceptions.core import AxWarning, SearchSpaceExhausted
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.utils import SubsetModelData
from ax.models.torch.utils import (
_get_X_pending_and_observed,
subset_model,
SubsetModelData,
)
from ax.models.torch_base import TorchOptConfig
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -124,7 +128,7 @@ def setUp(self) -> None:
)
self.linear_constraints = None
self.fixed_features = {1: 2.0}
self.options = {"best_f": 0.0, "cache_root": False, "prune_baseline": False}
self.options = {"cache_root": False, "prune_baseline": False}
self.inequality_constraints = [
(torch.tensor([0, 1], **tkwargs), torch.tensor([-1.0, 1.0], **tkwargs), 1)
]
Expand Down Expand Up @@ -159,41 +163,25 @@ def tearDown(self) -> None:
# Avoid polluting the registry for other tests.
ACQF_INPUT_CONSTRUCTOR_REGISTRY.pop(DummyAcquisitionFunction)

@mock.patch(f"{ACQUISITION_PATH}._get_X_pending_and_observed")
@mock.patch(
f"{ACQUISITION_PATH}.subset_model",
# pyre-fixme[6]: For 1st param expected `Model` but got `None`.
# pyre-fixme[6]: For 5th param expected `Tensor` but got `None`.
return_value=SubsetModelData(None, torch.ones(1), None, None, None),
)
@mock.patch(f"{ACQUISITION_PATH}.get_botorch_objective_and_transform")
@mock.patch(
f"{CURRENT_PATH}.Acquisition.compute_model_dependencies",
return_value={"current_value": 1.2},
)
@mock.patch(
f"{DummyAcquisitionFunction.__module__}.DummyAcquisitionFunction.__init__",
return_value=None,
)
def test_init(
self,
mock_botorch_acqf_class: Mock,
mock_compute_model_deps: Mock,
mock_get_objective_and_transform: Mock,
mock_subset_model: Mock,
mock_get_X: Mock,
) -> None:
def test_init_raises_when_missing_acqf_cls(self) -> None:
with self.assertRaisesRegex(TypeError, ".* missing .* 'botorch_acqf_class'"):
# pyre-fixme[20]: Argument `botorch_acqf_class` expected.
# pyre-ignore[20]: Argument `botorch_acqf_class` expected.
Acquisition(
surrogates={"surrogate": self.surrogate},
search_space_digest=self.search_space_digest,
torch_opt_config=self.torch_opt_config,
)

botorch_objective = LinearMCObjective(weights=torch.tensor([1.0]))
mock_get_objective_and_transform.return_value = (botorch_objective, None)
mock_get_X.return_value = (self.pending_observations[0], self.X[:1])
@mock.patch(
f"{ACQUISITION_PATH}._get_X_pending_and_observed",
wraps=_get_X_pending_and_observed,
)
@mock.patch(f"{ACQUISITION_PATH}.subset_model", wraps=subset_model)
def test_init(
self,
mock_subset_model: Mock,
mock_get_X: Mock,
) -> None:
acquisition = Acquisition(
surrogates={"surrogate": self.surrogate},
search_space_digest=self.search_space_digest,
Expand Down Expand Up @@ -224,10 +212,36 @@ def test_init(
outcome_constraints=self.outcome_constraints,
objective_thresholds=self.objective_thresholds,
)
mock_subset_model.reset_mock()
mock_get_objective_and_transform.reset_mock()
self.mock_input_constructor.reset_mock()
mock_botorch_acqf_class.reset_mock()

@mock.patch(f"{ACQUISITION_PATH}._get_X_pending_and_observed")
@mock.patch(
f"{ACQUISITION_PATH}.subset_model",
# pyre-fixme[6]: For 1st param expected `Model` but got `None`.
# pyre-fixme[6]: For 5th param expected `Tensor` but got `None`.
return_value=SubsetModelData(None, torch.ones(1), None, None, None),
)
@mock.patch(
f"{ACQUISITION_PATH}.get_botorch_objective_and_transform",
)
@mock.patch(
f"{CURRENT_PATH}.Acquisition.compute_model_dependencies",
return_value={"eta": 0.1},
)
@mock.patch(
f"{DummyAcquisitionFunction.__module__}.DummyAcquisitionFunction.__init__",
return_value=None,
)
def test_init_with_subset_model_false(
self,
mock_botorch_acqf_class: Mock,
mock_compute_model_deps: Mock,
mock_get_objective_and_transform: Mock,
mock_subset_model: Mock,
mock_get_X: Mock,
) -> None:
botorch_objective = LinearMCObjective(weights=torch.tensor([1.0]))
mock_get_objective_and_transform.return_value = (botorch_objective, None)
mock_get_X.return_value = (self.pending_observations[0], self.X[:1])
self.options[Keys.SUBSET_MODEL] = False
with mock.patch(
f"{ACQUISITION_PATH}.get_outcome_constraint_transforms",
Expand All @@ -249,7 +263,7 @@ def test_init(
self.assertIs(ckwargs["outcome_constraints"], self.outcome_constraints)
self.assertTrue(torch.equal(ckwargs["X_observed"], self.X[:1]))
# Check final `acqf` creation
model_deps = {Keys.CURRENT_VALUE: 1.2}
model_deps = {"eta": 0.1}
self.mock_input_constructor.assert_called_once()
mock_botorch_acqf_class.assert_called_once()
_, ckwargs = self.mock_input_constructor.call_args
Expand Down
5 changes: 3 additions & 2 deletions ax/utils/testing/benchmark_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def get_sobol_gpei_benchmark_method() -> BenchmarkMethod:
num_trials=-1,
model_kwargs={
"surrogate": Surrogate(SingleTaskGP),
# TODO: tests should better reflect defaults and not
# re-implement this logic. qLogNEI is a default now, not
# qNEI
"botorch_acqf_class": qNoisyExpectedImprovement,
},
model_gen_kwargs={
Expand All @@ -77,8 +80,6 @@ def get_sobol_gpei_benchmark_method() -> BenchmarkMethod:
},
Keys.ACQF_KWARGS: {
"prune_baseline": True,
"qmc": True,
"mc_samples": 512,
},
}
},
Expand Down

0 comments on commit 5ac79cc

Please sign in to comment.