Commit 4d35e2c
Have MBM only pass acqf input constructor kwargs to BoTorch when they are not None (#2480)

Summary:

Currently, Ax passes many arguments that go unused, and BoTorch silently ignores them. In a future BoTorch release, passing unused arguments will raise an exception instead of being ignored, so Ax now forwards an input constructor kwarg only when its value is not None (and target_fidelities only when it is non-empty).
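
As a concrete illustration of the pattern, here is a minimal, self-contained sketch against public BoTorch APIs; the toy data and the choice of qNoisyExpectedImprovement are assumptions of this example, not code from the commit:

import torch
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.utils.datasets import SupervisedDataset

# Toy model and dataset standing in for Ax's fitted surrogate.
X = torch.rand(8, 2, dtype=torch.double)
Y = X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X=X, train_Y=Y)
training_data = SupervisedDataset(
    X=X, Y=Y, feature_names=["x0", "x1"], outcome_names=["y"]
)

# Ax assembles a kwargs dict in which unset options are simply None.
input_constructor_kwargs = {
    "X_baseline": X,
    "objective": None,  # not configured in this run
    "sampler": None,    # not configured in this run
}

input_constructor = get_acqf_input_constructor(qNoisyExpectedImprovement)
acqf_inputs = input_constructor(
    model=model,
    training_data=training_data,
    # The fix: forward only the kwargs that actually carry a value.
    **{k: v for k, v in input_constructor_kwargs.items() if v is not None},
)
acqf = qNoisyExpectedImprovement(**acqf_inputs)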

Reviewed By: Balandat

Differential Revision: D57387912
esantorella authored and facebook-github-bot committed May 29, 2024
1 parent ba6eb81 commit 4d35e2c
Showing 2 changed files with 103 additions and 96 deletions.
13 changes: 8 additions & 5 deletions ax/models/torch/botorch_modular/acquisition.py
@@ -293,12 +293,16 @@ def __init__(
             "constraints": get_outcome_constraint_transforms(
                 outcome_constraints=outcome_constraints
             ),
-            "target_fidelities": target_fidelities,
-            "bounds": search_space_digest.bounds,
+            "objective": objective,
+            "posterior_transform": posterior_transform,
             **acqf_model_kwarg,
             **model_deps,
             **self.options,
         }
+
+        if len(target_fidelities) > 0:
+            input_constructor_kwargs["target_fidelities"] = target_fidelities
+
         input_constructor = get_acqf_input_constructor(botorch_acqf_class)
         # Handle multi-dataset surrogates - TODO: Improve this
         # If there is only one SupervisedDataset return it alone
@@ -317,9 +321,8 @@ def __init__(
 
         acqf_inputs = input_constructor(
             training_data=training_data,
-            objective=objective,
-            posterior_transform=posterior_transform,
-            **input_constructor_kwargs,
+            bounds=search_space_digest.bounds,
+            **{k: v for k, v in input_constructor_kwargs.items() if v is not None},
         )
         self.acqf = botorch_acqf_class(**acqf_inputs)  # pyre-ignore [45]
         self.X_pending: Optional[Tensor] = unique_Xs_pending
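
Taken together, the two hunks above implement a simple policy: target_fidelities enters the kwargs dict only when it is non-empty, and None-valued entries are dropped at call time. A toy, pure-Python sketch of the mechanics (all values made up for illustration):

target_fidelities: dict = {}  # empty -> the key is never added
input_constructor_kwargs: dict = {
    "X_baseline": [[0.1, 0.2]],
    "objective": None,            # unset option
    "posterior_transform": None,  # unset option
}

if len(target_fidelities) > 0:
    input_constructor_kwargs["target_fidelities"] = target_fidelities

passed = {k: v for k, v in input_constructor_kwargs.items() if v is not None}
print(passed)  # {'X_baseline': [[0.1, 0.2]]} -- no Nones, no empty fidelity map
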
186 changes: 95 additions & 91 deletions ax/models/torch/tests/test_model.py
@@ -8,7 +8,6 @@
 
 import dataclasses
 from collections import OrderedDict
-from contextlib import ExitStack
 from copy import deepcopy
 from typing import Dict, Type
 from unittest import mock
@@ -56,6 +55,7 @@
 from botorch.utils.datasets import SupervisedDataset
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
 from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
+from pyre_extensions import none_throws
 
 
 CURRENT_PATH: str = __name__
@@ -514,15 +514,31 @@ def _test_gen(
             search_space_digest=search_space_digest,
             torch_opt_config=self.torch_opt_config,
         )
         # Assert acquisition initialized with expected arguments
         mock_init_acqf.assert_called_once_with(
             search_space_digest=search_space_digest,
             torch_opt_config=self.torch_opt_config,
             acq_options=self.acquisition_options,
         )
 
+        mock_input_constructor.assert_called_once()
+        ckwargs = mock_input_constructor.call_args[1]
+
+        # We particularly want to make sure that args that will not be used are
+        # not passed
+        expected_kwargs = {
+            "bounds",
+            "constraints",
+            "X_baseline",
+            "sampler",
+            "objective",
+            "training_data",
+            "model",
+        }
+        self.assertSetEqual(set(ckwargs.keys()), expected_kwargs)
+        for k in expected_kwargs:
+            self.assertIsNotNone(ckwargs[k], f"{k} is None")
+
         m = ckwargs["model"]
         self.assertIsInstance(m, botorch_model_class)
         self.assertEqual(m.num_outputs, 1)
@@ -535,15 +551,8 @@ def _test_gen(
                 torch.cat([ds.Y for ds in self.block_design_training_data], dim=-1),
             )
         )
-        self.assertIsNotNone(ckwargs["constraints"])
 
-        self.assertIsNone(
-            ckwargs["X_pending"],
-        )
-        self.assertIsInstance(
-            ckwargs.get("objective"),
-            GenericMCObjective,
-        )
+        self.assertIsInstance(ckwargs["objective"], GenericMCObjective)
         expected_X_baseline = _filter_X_observed(
             Xs=[dataset.X for dataset in self.block_design_training_data],
             objective_weights=self.objective_weights,
@@ -553,12 +562,7 @@ def _test_gen(
             fixed_features=self.fixed_features,
         )
         self.assertTrue(
-            torch.equal(
-                ckwargs.get("X_baseline"),
-                # pyre-fixme[6]: For 2nd param expected `Tensor` but got
-                # `Optional[Tensor]`.
-                expected_X_baseline,
-            )
+            torch.equal(ckwargs["X_baseline"], none_throws(expected_X_baseline))
         )
 
         # Assert `construct_acquisition_and_optimizer_options` called with kwargs
@@ -866,12 +870,27 @@ def test_MOO(self, _) -> None:
             search_space_digest=self.mf_search_space_digest,
             torch_opt_config=self.moo_torch_opt_config,
         )
         mock_get_outcome_constraint_transforms.assert_called_once()
         ckwargs = mock_get_outcome_constraint_transforms.call_args[1]
         oc = ckwargs["outcome_constraints"]
         self.assertTrue(torch.equal(oc[0], subset_outcome_constraints[0]))
         self.assertTrue(torch.equal(oc[1], subset_outcome_constraints[1]))
 
+        # Check input constructor args
+        ckwargs = mock_input_constructor.call_args[1]
+        expected_kwargs = {
+            "constraints",
+            "bounds",
+            "objective",
+            "training_data",
+            "X_baseline",
+            "model",
+            "objective_thresholds",
+        }
+        self.assertSetEqual(set(ckwargs.keys()), expected_kwargs)
+        for k in expected_kwargs:
+            self.assertIsNotNone(ckwargs[k], f"{k} is None")
+
         self.assertIs(model.botorch_acqf_class, qLogNoisyExpectedHypervolumeImprovement)
         mock_input_constructor.assert_called_once()
         m = ckwargs["model"]
@@ -900,20 +919,14 @@ def test_MOO(self, _) -> None:
         )
         self.assertIs(ckwargs["constraints"], constraints)
 
-        self.assertIsNone(
-            ckwargs["X_pending"],
-        )
         obj_t = gen_results.gen_metadata["objective_thresholds"]
         self.assertTrue(torch.equal(obj_t[:2], self.moo_objective_thresholds[:2]))
         self.assertTrue(np.isnan(obj_t[2].item()))
 
-        self.assertIsInstance(
-            ckwargs.get("objective"),
-            WeightedMCMultiOutputObjective,
-        )
+        self.assertIsInstance(ckwargs["objective"], WeightedMCMultiOutputObjective)
         self.assertTrue(
             torch.equal(
-                ckwargs.get("objective").weights,
+                ckwargs["objective"].weights,
                 self.moo_objective_weights[:2],
             )
         )
@@ -926,71 +939,62 @@ def test_MOO(self, _) -> None:
             fixed_features=self.fixed_features,
         )
         self.assertTrue(
-            torch.equal(
-                ckwargs.get("X_baseline"),
-                # pyre-fixme[6]: For 2nd param expected `Tensor` but got
-                # `Optional[Tensor]`.
-                expected_X_baseline,
-            )
+            torch.equal(ckwargs["X_baseline"], none_throws(expected_X_baseline))
         )
         # test inferred objective_thresholds
-        with ExitStack() as es:
-            _mock_model_infer_objective_thresholds = es.enter_context(
-                mock.patch(
-                    "ax.models.torch.botorch_modular.acquisition."
-                    "infer_objective_thresholds",
-                    return_value=torch.tensor([9.9, 3.3, float("nan")]),
-                )
-            )
         objective_weights = torch.tensor([-1.0, -1.0, 0.0])
         outcome_constraints = (
             torch.tensor([[1.0, 0.0, 0.0]]),
             torch.tensor([[10.0]]),
         )
         linear_constraints = (
             torch.tensor([[1.0, 0.0, 0.0]]),
             torch.tensor([[2.0]]),
         )
+        torch_opt_config = dataclasses.replace(
+            self.moo_torch_opt_config,
+            objective_weights=objective_weights,
+            objective_thresholds=None,
+            outcome_constraints=outcome_constraints,
+            linear_constraints=linear_constraints,
+        )
+
+        with mock.patch(
+            "ax.models.torch.botorch_modular.acquisition." "infer_objective_thresholds",
+            return_value=torch.tensor([9.9, 3.3, float("nan")]),
+        ) as _mock_model_infer_objective_thresholds:
             gen_results = model.gen(
                 n=1,
                 search_space_digest=self.mf_search_space_digest,
-                torch_opt_config=dataclasses.replace(
-                    self.moo_torch_opt_config,
-                    objective_weights=objective_weights,
-                    objective_thresholds=None,
-                    outcome_constraints=outcome_constraints,
-                    linear_constraints=linear_constraints,
-                ),
+                torch_opt_config=torch_opt_config,
             )
         expected_X_baseline = _filter_X_observed(
             Xs=[dataset.X for dataset in self.moo_training_data],
             objective_weights=objective_weights,
             outcome_constraints=outcome_constraints,
             bounds=self.search_space_digest.bounds,
             linear_constraints=linear_constraints,
             fixed_features=self.fixed_features,
         )
         ckwargs = _mock_model_infer_objective_thresholds.call_args[1]
         self.assertTrue(
             torch.equal(
                 ckwargs["objective_weights"],
                 objective_weights,
             )
         )
         oc = ckwargs["outcome_constraints"]
         self.assertTrue(torch.equal(oc[0], outcome_constraints[0]))
         self.assertTrue(torch.equal(oc[1], outcome_constraints[1]))
         m = ckwargs["model"]
         self.assertIsInstance(m, SingleTaskGP)
         self.assertIsInstance(m.likelihood, FixedNoiseGaussianLikelihood)
         self.assertEqual(m.num_outputs, 2)
         self.assertIn("objective_thresholds", gen_results.gen_metadata)
         obj_t = gen_results.gen_metadata["objective_thresholds"]
         self.assertTrue(torch.equal(obj_t[:2], torch.tensor([9.9, 3.3])))
         self.assertTrue(np.isnan(obj_t[2].item()))
 
         # test outcome constraints

         # Avoid polluting the registry for other tests; re-register correct input
         # constructor for qLogNEHVI.
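
The updated tests verify the change by inspecting the kwargs a mocked input constructor actually received. Here is a standalone sketch of that assertion pattern using plain unittest.mock; the names and placeholder values are illustrative, not Ax's:

from unittest import mock

mock_input_constructor = mock.MagicMock(name="input_constructor")

# Code under test: None-valued kwargs are filtered out before the call.
raw_kwargs = {"objective": "obj", "sampler": None, "X_baseline": "X"}
mock_input_constructor(
    training_data="data",
    **{k: v for k, v in raw_kwargs.items() if v is not None},
)

mock_input_constructor.assert_called_once()
ckwargs = mock_input_constructor.call_args[1]  # kwargs of the recorded call
assert set(ckwargs.keys()) == {"training_data", "objective", "X_baseline"}
assert all(v is not None for v in ckwargs.values())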
