Skip to content

Commit

Permalink
Fix input constructors for qMaxValueEntropy and `qMultiFidelityKnow…
Browse files Browse the repository at this point in the history
…ledgeGradient` (#1989)

Summary:
X-link: facebook/Ax#1788

Pull Request resolved: #1989

The qMaxValueEntropy input constructor produces the kwarg "objective," but qMaxValueEntropy doesn't accept "objective."

Reviewed By: SebastianAment

Differential Revision: D48480259

fbshipit-source-id: 5bd391f7ea6a0a00413b6977a27f83ff88b56821
  • Loading branch information
esantorella authored and facebook-github-bot committed Aug 21, 2023
1 parent 946f0b4 commit b576f8d
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 32 deletions.
10 changes: 1 addition & 9 deletions botorch/acquisition/input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,6 @@ def acqf_input_constructor(
def decorator(method):
method_kwargs = allow_only_specific_variable_kwargs(method)
for acqf_cls_ in acqf_cls:
_register_acqf_input_constructor(
acqf_cls=acqf_cls_, input_constructor=method_kwargs
)
ACQF_INPUT_CONSTRUCTOR_REGISTRY[acqf_cls_] = method_kwargs
return method

Expand Down Expand Up @@ -1028,7 +1025,6 @@ def construct_inputs_qMES(
model: Model,
training_data: MaybeDict[SupervisedDataset],
bounds: List[Tuple[float, float]],
objective: Optional[MCAcquisitionObjective] = None,
candidate_size: int = 1000,
maximize: bool = True,
# TODO: qMES also supports other inputs, such as num_fantasies
Expand All @@ -1042,7 +1038,6 @@ def construct_inputs_qMES(
return {
"model": model,
"candidate_set": _bounds[0] + (_bounds[1] - _bounds[0]) * _rvs,
"objective": objective,
"maximize": maximize,
}

Expand Down Expand Up @@ -1071,7 +1066,6 @@ def construct_inputs_mf_base(
)

return {
"target_fidelities": target_fidelities,
"cost_aware_utility": cost_aware_utility,
"expand": lambda X: expand_trace_observations(
X=X,
Expand Down Expand Up @@ -1145,7 +1139,6 @@ def construct_inputs_qMFKG(
bounds=bounds,
objective=objective,
posterior_transform=posterior_transform,
target_fidelities=target_fidelities,
num_fantasies=num_fantasies,
)

Expand Down Expand Up @@ -1184,8 +1177,6 @@ def construct_inputs_qMFMES(
model=model,
training_data=training_data,
bounds=bounds,
objective=objective,
posterior_transform=posterior_transform,
candidate_size=candidate_size,
maximize=maximize,
)
Expand All @@ -1205,6 +1196,7 @@ def construct_inputs_qMFMES(
**inputs_mf,
**inputs_qmes,
"current_value": current_value.detach().cpu().max(),
"target_fidelities": target_fidelities,
}


Expand Down
158 changes: 135 additions & 23 deletions test/acquisition/test_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import math

from typing import Callable
from typing import Any, Callable, List, Type
from unittest import mock

import torch
Expand Down Expand Up @@ -96,7 +96,7 @@ class DummyAcquisitionFunction(AcquisitionFunction):


class InputConstructorBaseTestCase:
def setUp(self):
def setUp(self) -> None:
self.mock_model = MockModel(
posterior=MockPosterior(mean=None, variance=None, base_shape=(1,))
)
Expand All @@ -113,13 +113,13 @@ def setUp(self):


class TestInputConstructorUtils(InputConstructorBaseTestCase, BotorchTestCase):
def test_field_is_shared(self):
def test_field_is_shared(self) -> None:
self.assertTrue(_field_is_shared(self.blockX_multiY, "X"))
self.assertFalse(_field_is_shared(self.blockX_multiY, "Y"))
with self.assertRaisesRegex(AttributeError, "has no field"):
self.assertFalse(_field_is_shared(self.blockX_multiY, "foo"))

def test_get_best_f_analytic(self):
def test_get_best_f_analytic(self) -> None:
with self.assertRaisesRegex(
NotImplementedError, "Currently only block designs are supported."
):
Expand Down Expand Up @@ -147,7 +147,7 @@ def test_get_best_f_analytic(self):
best_f_expected = post_tf.evaluate(multi_Y).max()
self.assertEqual(best_f_tf, best_f_expected)

def test_get_best_f_mc(self):
def test_get_best_f_mc(self) -> None:
with self.assertRaisesRegex(
NotImplementedError, "Currently only block designs are supported."
):
Expand Down Expand Up @@ -264,7 +264,7 @@ def test_construct_inputs_posterior_mean(self) -> None:
acqf = PosteriorMean(**kwargs)
self.assertIs(acqf.model, mock_model)

def test_construct_inputs_best_f(self):
def test_construct_inputs_best_f(self) -> None:
for acqf_cls in [
ExpectedImprovement,
LogExpectedImprovement,
Expand Down Expand Up @@ -682,7 +682,7 @@ def test_construct_inputs_qUCB(self) -> None:
class TestMultiObjectiveAcquisitionFunctionInputConstructors(
InputConstructorBaseTestCase, BotorchTestCase
):
def test_construct_inputs_EHVI(self):
def test_construct_inputs_EHVI(self) -> None:
c = get_acqf_input_constructor(ExpectedHypervolumeImprovement)
mock_model = mock.Mock()
objective_thresholds = torch.rand(6)
Expand Down Expand Up @@ -794,7 +794,7 @@ def test_construct_inputs_EHVI(self):
self.assertIsInstance(partitioning, FastNondominatedPartitioning)
self.assertTrue(torch.equal(partitioning.ref_point, expected_obj_t))

def test_construct_inputs_qEHVI(self):
def test_construct_inputs_qEHVI(self) -> None:
c = get_acqf_input_constructor(qExpectedHypervolumeImprovement)
objective_thresholds = torch.rand(2)

Expand Down Expand Up @@ -892,7 +892,7 @@ def test_construct_inputs_qEHVI(self):
self.assertEqual(sampler.sample_shape, torch.Size([16]))
self.assertEqual(sampler.seed, 1234)

def test_construct_inputs_qNEHVI(self):
def test_construct_inputs_qNEHVI(self) -> None:
c = get_acqf_input_constructor(qNoisyExpectedHypervolumeImprovement)
objective_thresholds = torch.rand(2)

Expand Down Expand Up @@ -1011,7 +1011,7 @@ def test_construct_inputs_qNEHVI(self):
)
self.assertEqual(kwargs["alpha"], 0.0)

def test_construct_inputs_kg(self):
def test_construct_inputs_kg(self) -> None:
current_value = torch.tensor(1.23)
with mock.patch(
target="botorch.acquisition.input_constructors.optimize_objective",
Expand All @@ -1031,10 +1031,11 @@ def test_construct_inputs_kg(self):
self.assertEqual(kwargs["num_fantasies"], 33)
self.assertEqual(kwargs["current_value"], current_value)

def test_construct_inputs_mes(self):
def test_construct_inputs_mes(self) -> None:
func = get_acqf_input_constructor(qMaxValueEntropy)
model = SingleTaskGP(train_X=torch.ones((3, 2)), train_Y=torch.zeros((3, 1)))
kwargs = func(
model=mock.Mock(),
model=model,
training_data=self.blockX_blockY,
objective=LinearMCObjective(torch.rand(2)),
bounds=self.bounds,
Expand All @@ -1049,7 +1050,10 @@ def test_construct_inputs_mes(self):
[int(s) for s in kwargs["candidate_set"].shape], [17, len(self.bounds)]
)

def test_construct_inputs_mf_base(self):
acqf = qMaxValueEntropy(**kwargs)
self.assertIs(acqf.model, model)

def test_construct_inputs_mf_base(self) -> None:
target_fidelities = {0: 0.123}
fidelity_weights = {0: 0.456}
cost_intercept = 0.789
Expand All @@ -1063,10 +1067,8 @@ def test_construct_inputs_mf_base(self):
num_trace_observations=num_trace_observations,
)

self.assertEqual(kwargs["target_fidelities"], target_fidelities)

X = torch.rand(3, 2)
self.assertTrue(isinstance(kwargs["expand"], Callable))
self.assertIsInstance(kwargs["expand"], Callable)
self.assertTrue(
torch.equal(
kwargs["expand"](X),
Expand All @@ -1078,7 +1080,7 @@ def test_construct_inputs_mf_base(self):
)
)

self.assertTrue(isinstance(kwargs["project"], Callable))
self.assertIsInstance(kwargs["project"], Callable)
self.assertTrue(
torch.equal(
kwargs["project"](X),
Expand All @@ -1103,13 +1105,13 @@ def test_construct_inputs_mf_base(self):
with self.assertRaisesRegex(
RuntimeError, "Must provide the same indices for"
):
_ = construct_inputs_mf_base(
construct_inputs_mf_base(
target_fidelities={0: 1.0},
fidelity_weights={1: 0.5},
cost_intercept=cost_intercept,
)

def test_construct_inputs_mfkg(self):
def test_construct_inputs_mfkg(self) -> None:
constructor_args = {
"model": None,
"training_data": self.blockX_blockY,
Expand All @@ -1136,15 +1138,16 @@ def test_construct_inputs_mfkg(self):
inputs_test = {"foo": 0, "bar": 1}
self.assertEqual(inputs_mfkg, inputs_test)

def test_construct_inputs_mfmes(self):
def test_construct_inputs_mfmes(self) -> None:
target_fidelities = {0: 0.987}
constructor_args = {
"model": None,
"training_data": self.blockX_blockY,
"objective": None,
"bounds": self.bounds,
"num_fantasies": 123,
"candidate_size": 17,
"target_fidelities": {0: 0.987},
"target_fidelities": target_fidelities,
"fidelity_weights": {0: 0.654},
"cost_intercept": 0.321,
}
Expand All @@ -1165,10 +1168,15 @@ def test_construct_inputs_mfmes(self):
qMultiFidelityMaxValueEntropy
)
inputs_mfmes = input_constructor(**constructor_args)
inputs_test = {"foo": 0, "bar": 1, "current_value": current_value}
inputs_test = {
"foo": 0,
"bar": 1,
"current_value": current_value,
"target_fidelities": target_fidelities,
}
self.assertEqual(inputs_mfmes, inputs_test)

def test_construct_inputs_jes(self):
def test_construct_inputs_jes(self) -> None:
func = get_acqf_input_constructor(qJointEntropySearch)
# we need to run optimize_posterior_samples, so we sort of need
# a real model as there is no other (apparent) option
Expand All @@ -1193,3 +1201,107 @@ def test_construct_inputs_jes(self):
# of shape N x D and outputs are N x 1
self.assertEqual(len(kwargs["optimal_inputs"].shape), 2)
self.assertEqual(len(kwargs["optimal_outputs"].shape), 2)
qJointEntropySearch(**kwargs)


class TestInstantiationFromInputConstructor(
InputConstructorBaseTestCase, BotorchTestCase
):
def _test_constructor_base(
self, classes: List[Type[AcquisitionFunction]], **input_constructor_kwargs: Any
) -> None:
for cls_ in classes:
with self.subTest(cls_.__name__, cls_=cls_):
acqf_kwargs = get_acqf_input_constructor(cls_)(
**input_constructor_kwargs
)
# no assertions; we are just testing that this doesn't error
cls_(**acqf_kwargs)

def test_constructors_like_PosteriorMean(self) -> None:
classes = [PosteriorMean, UpperConfidenceBound, qUpperConfidenceBound]
self._test_constructor_base(classes=classes, model=self.mock_model)

def test_constructors_like_ExpectedImprovement(self) -> None:
classes = [
ExpectedImprovement,
LogExpectedImprovement,
ProbabilityOfImprovement,
LogProbabilityOfImprovement,
NoisyExpectedImprovement,
LogNoisyExpectedImprovement,
qExpectedImprovement,
qLogExpectedImprovement,
qNoisyExpectedImprovement,
qLogNoisyExpectedImprovement,
qProbabilityOfImprovement,
]
model = FixedNoiseGP(
train_X=torch.rand((4, 2)),
train_Y=torch.rand((4, 1)),
train_Yvar=torch.ones((4, 1)),
)
self._test_constructor_base(
classes=classes, model=model, training_data=self.blockX_blockY
)

def test_constructors_like_qNEHVI(self) -> None:
objective_thresholds = torch.tensor([0.1, 0.2])
model = SingleTaskGP(train_X=torch.rand((3, 2)), train_Y=torch.rand((3, 2)))
# The EHVI and qEHVI input constructors are not working
classes = [
qNoisyExpectedHypervolumeImprovement,
# ExpectedHypervolumeImprovement,
# qExpectedHypervolumeImprovement,
]
self._test_constructor_base(
classes=classes,
model=model,
training_data=self.blockX_blockY,
objective_thresholds=objective_thresholds,
)

def test_constructors_like_qMaxValueEntropy(self) -> None:
bounds = torch.ones((1, 2))
classes = [qMaxValueEntropy, qKnowledgeGradient]
self._test_constructor_base(
classes=classes,
model=SingleTaskGP(train_X=torch.rand((3, 1)), train_Y=torch.rand((3, 1))),
training_data=self.blockX_blockY,
bounds=bounds,
)

def test_constructors_like_qMultiFidelityKnowledgeGradient(self) -> None:
classes = [
qMultiFidelityKnowledgeGradient,
# currently the input constructor for qMFMVG is not working
# qMultiFidelityMaxValueEntropy
]
self._test_constructor_base(
classes=classes,
model=SingleTaskGP(train_X=torch.rand((3, 1)), train_Y=torch.rand((3, 1))),
training_data=self.blockX_blockY,
bounds=torch.ones((1, 2)),
target_fidelities={0: 0.987},
)

def test_eubo(self) -> None:
model = SingleTaskGP(train_X=torch.rand((3, 2)), train_Y=torch.rand((3, 2)))
pref_model = self.mock_model
pref_model.dim = 2
pref_model.datapoints = torch.tensor([])

classes = [AnalyticExpectedUtilityOfBestOption]
self._test_constructor_base(
classes=classes,
model=model,
pref_model=pref_model,
)

def test_qjes(self) -> None:
model = SingleTaskGP(self.blockX_blockY[0].X(), self.blockX_blockY[0].Y())
self._test_constructor_base(
classes=[qJointEntropySearch],
model=model,
bounds=self.bounds,
)

0 comments on commit b576f8d

Please sign in to comment.