Skip to content

Commit

Permalink
Add input constructor for qHypervolumeKnowledgeGradient (#2501)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2501

Adds new input constructors for `qHypervolumeKnowledgeGradient`.

Differential Revision: D62046832
  • Loading branch information
ltiao authored and facebook-github-bot committed Sep 10, 2024
1 parent 58e2ab0 commit 16a5177
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 0 deletions.
54 changes: 54 additions & 0 deletions botorch/acquisition/input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
qExpectedHypervolumeImprovement,
qNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
_get_hv_value_function,
qHypervolumeKnowledgeGradient,
)
from botorch.acquisition.multi_objective.logei import (
qLogExpectedHypervolumeImprovement,
qLogNoisyExpectedHypervolumeImprovement,
Expand Down Expand Up @@ -1270,6 +1274,56 @@ def construct_inputs_qKG(
return inputs_qkg


@acqf_input_constructor(qHypervolumeKnowledgeGradient)
def construct_inputs_qHVKG(
model: Model,
training_data: MaybeDict[SupervisedDataset],
bounds: list[tuple[float, float]],
objective_thresholds: Tensor,
objective: Optional[MCMultiOutputObjective] = None,
posterior_transform: Optional[PosteriorTransform] = None,
num_fantasies: int = 8,
num_pareto: int = 10,
**optimize_objective_kwargs: TOptimizeObjectiveKwargs,
) -> dict[str, Any]:
r"""Construct kwargs for `qKnowledgeGradient` constructor."""

X = _get_dataset_field(training_data, "X", first_only=True)
_bounds = torch.as_tensor(bounds, dtype=X.dtype, device=X.device)

if objective is None:
ref_point = objective_thresholds
elif isinstance(objective, RiskMeasureMCObjective):
ref_point = objective.preprocessing_function(objective_thresholds)

Check warning on line 1297 in botorch/acquisition/input_constructors.py

View check run for this annotation

Codecov / codecov/patch

botorch/acquisition/input_constructors.py#L1296-L1297

Added lines #L1296 - L1297 were not covered by tests
else:
ref_point = objective(objective_thresholds)

Check warning on line 1299 in botorch/acquisition/input_constructors.py

View check run for this annotation

Codecov / codecov/patch

botorch/acquisition/input_constructors.py#L1299

Added line #L1299 was not covered by tests

acq_function = _get_hv_value_function(
model=model,
ref_point=ref_point,
use_posterior_mean=True,
)

_, current_value = optimize_objective(
model=model,
bounds=_bounds.t(),
q=num_pareto,
acq_function=acq_function,
objective=objective,
posterior_transform=posterior_transform,
**optimize_objective_kwargs,
)

return {
"model": model,
"objective": objective,
"ref_point": ref_point,
"num_fantasies": num_fantasies,
"num_pareto": num_pareto,
"current_value": current_value.detach().cpu().max(),
}


@acqf_input_constructor(qMultiFidelityKnowledgeGradient)
def construct_inputs_qMFKG(
model: Model,
Expand Down
54 changes: 54 additions & 0 deletions test/acquisition/test_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@
qExpectedHypervolumeImprovement,
qNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
qHypervolumeKnowledgeGradient,
)

from botorch.acquisition.multi_objective.logei import (
qLogExpectedHypervolumeImprovement,
qLogNoisyExpectedHypervolumeImprovement,
Expand Down Expand Up @@ -1291,6 +1295,39 @@ def test_construct_inputs_kg(self) -> None:
)
self.assertNotIn("current_value", kwargs)

def test_construct_inputs_hvkg(self) -> None:
current_value = torch.tensor(1.23)
objective_thresholds = torch.rand(2)

with mock.patch(
target="botorch.acquisition.input_constructors.optimize_acqf",
return_value=(None, current_value),
):
get_kwargs = get_acqf_input_constructor(qHypervolumeKnowledgeGradient)
kwargs = get_kwargs(
model=mock.Mock(),
training_data=self.blockX_blockY,
objective_thresholds=objective_thresholds,
bounds=self.bounds,
num_fantasies=33,
num_pareto=11,
)
self.assertLessEqual(
{
"model",
"ref_point",
"num_fantasies",
"num_pareto",
"objective",
"current_value",
},
set(kwargs.keys()),
)
self.assertEqual(kwargs["num_fantasies"], 33)
self.assertEqual(kwargs["num_pareto"], 11)
self.assertEqual(kwargs["current_value"], current_value)
self.assertTrue(torch.equal(kwargs["ref_point"], objective_thresholds))

def test_construct_inputs_mes(self) -> None:
func = get_acqf_input_constructor(qMaxValueEntropy)
n, d, m = 5, 2, 1
Expand Down Expand Up @@ -1603,6 +1640,23 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
"training_data": self.blockX_blockY,
},
)

X = torch.rand(3, 2)
Y1 = torch.rand(3, 1)
Y2 = torch.rand(3, 1)
m1 = SingleTaskGP(X, Y1)
m2 = SingleTaskGP(X, Y2)
model_list = ModelListGP(m1, m2)

self.cases["HV Look-ahead"] = (
[qHypervolumeKnowledgeGradient],
{
"model": model_list,
"training_data": self.blockX_blockY,
"bounds": bounds,
"objective_thresholds": objective_thresholds,
},
)
pref_model = self.mock_model
pref_model.dim = 2
pref_model.datapoints = torch.tensor([])
Expand Down

0 comments on commit 16a5177

Please sign in to comment.