Skip to content

Commit

Permalink
qLogNEHVI's input constructor
Browse files Browse the repository at this point in the history
Summary: This commit adds qLogNEHVI's input constructor, enabling integration with Ax in a follow-up diff.

Differential Revision: D50345278

fbshipit-source-id: aa255309a47c29419ac60fd18b3f48057cb5fff9
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Oct 17, 2023
1 parent 73611ca commit b2272e0
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 4 deletions.
63 changes: 60 additions & 3 deletions botorch/acquisition/input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@
qExpectedHypervolumeImprovement,
qNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.multi_objective.logei import (
qLogNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.multi_objective.objective import (
AnalyticMultiOutputObjective,
IdentityAnalyticMultiOutputObjective,
Expand Down Expand Up @@ -688,8 +691,8 @@ def construct_inputs_qLogNEI(
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details, on this
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
fat: Toggles the logarithmic / linear asymptotic behavior of the smooth
approximation to the ReLU.
fat: Toggles the use of the fat-tailed non-linearities to smoothly approximate
the constraints indicator function.
tau_max: Temperature parameter controlling the sharpness of the smooth
approximations to max.
tau_relu: Temperature parameter controlling the sharpness of the smooth
Expand Down Expand Up @@ -961,6 +964,7 @@ def construct_inputs_qNEHVI(
sampler: Optional[MCSampler] = None,
X_pending: Optional[Tensor] = None,
eta: float = 1e-3,
fat: bool = False,
mc_samples: int = 128,
qmc: bool = True,
prune_baseline: bool = True,
Expand All @@ -969,7 +973,7 @@ def construct_inputs_qNEHVI(
incremental_nehvi: bool = True,
cache_root: bool = True,
) -> Dict[str, Any]:
r"""Construct kwargs for `qNoisyExpectedHypervolumeImprovement` constructor."""
r"""Construct kwargs for `qNoisyExpectedHypervolumeImprovement`'s constructor."""
if X_baseline is None:
X_baseline = _get_dataset_field(
training_data,
Expand Down Expand Up @@ -1010,6 +1014,7 @@ def construct_inputs_qNEHVI(
"constraints": constraints,
"X_pending": X_pending,
"eta": eta,
"fat": fat,
"prune_baseline": prune_baseline,
"alpha": alpha,
"cache_pending": cache_pending,
Expand All @@ -1019,6 +1024,58 @@ def construct_inputs_qNEHVI(
}


@acqf_input_constructor(qLogNoisyExpectedHypervolumeImprovement)
def construct_inputs_qLogNEHVI(
    model: Model,
    training_data: MaybeDict[SupervisedDataset],
    objective_thresholds: Tensor,
    objective: Optional[MCMultiOutputObjective] = None,
    X_baseline: Optional[Tensor] = None,
    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
    alpha: Optional[float] = None,
    sampler: Optional[MCSampler] = None,
    X_pending: Optional[Tensor] = None,
    eta: float = 1e-3,
    fat: bool = True,
    mc_samples: int = 128,
    qmc: bool = True,
    prune_baseline: bool = True,
    cache_pending: bool = True,
    max_iep: int = 0,
    incremental_nehvi: bool = True,
    cache_root: bool = True,
    tau_relu: float = TAU_RELU,
    tau_max: float = TAU_MAX,
) -> Dict[str, Any]:
    r"""Construct kwargs for `qLogNoisyExpectedHypervolumeImprovement`'s constructor.

    Delegates all shared arguments to `construct_inputs_qNEHVI` and adds the
    temperature parameters specific to the log-space acquisition function.
    Note that `fat` defaults to `True` here, unlike `construct_inputs_qNEHVI`
    where it defaults to `False`, since the LogEI family uses fat-tailed
    smooth approximations by default.
    """
    return {
        # Reuse the qNEHVI constructor for everything the two share.
        **construct_inputs_qNEHVI(
            model=model,
            training_data=training_data,
            objective_thresholds=objective_thresholds,
            objective=objective,
            X_baseline=X_baseline,
            constraints=constraints,
            alpha=alpha,
            sampler=sampler,
            X_pending=X_pending,
            eta=eta,
            fat=fat,
            mc_samples=mc_samples,
            qmc=qmc,
            prune_baseline=prune_baseline,
            cache_pending=cache_pending,
            max_iep=max_iep,
            incremental_nehvi=incremental_nehvi,
            cache_root=cache_root,
        ),
        # Temperature parameters controlling the sharpness of the smooth
        # approximations to ReLU and max — specific to the log variant.
        "tau_relu": tau_relu,
        "tau_max": tau_max,
    }


@acqf_input_constructor(qMaxValueEntropy)
def construct_inputs_qMES(
model: Model,
Expand Down
20 changes: 19 additions & 1 deletion test/acquisition/test_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
qExpectedHypervolumeImprovement,
qNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.multi_objective.logei import (
qLogNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.multi_objective.multi_output_risk_measures import (
MultiOutputExpectation,
)
Expand Down Expand Up @@ -928,7 +931,13 @@ def test_construct_inputs_qEHVI(self) -> None:
self.assertEqual(sampler.seed, 1234)

def test_construct_inputs_qNEHVI(self) -> None:
c = get_acqf_input_constructor(qNoisyExpectedHypervolumeImprovement)
self._test_construct_inputs_qNEHVI(qNoisyExpectedHypervolumeImprovement)

def test_construct_inputs_qLogNEHVI(self) -> None:
    # qLogNEHVI's input constructor delegates to qNEHVI's, so it is
    # exercised via the shared helper, parameterized by acquisition class.
    self._test_construct_inputs_qNEHVI(qLogNoisyExpectedHypervolumeImprovement)

def _test_construct_inputs_qNEHVI(self, acqf_class: Type[AcquisitionFunction]):
c = get_acqf_input_constructor(acqf_class)
objective_thresholds = torch.rand(2)

# Test defaults
Expand All @@ -953,6 +962,15 @@ def test_construct_inputs_qNEHVI(self) -> None:
self.assertTrue(kwargs["incremental_nehvi"])
self.assertTrue(kwargs["cache_root"])

if acqf_class == qLogNoisyExpectedHypervolumeImprovement:
self.assertEqual(kwargs["tau_relu"], TAU_RELU)
self.assertEqual(kwargs["tau_max"], TAU_MAX)
self.assertEqual(kwargs["fat"], True)
else:
self.assertNotIn("tau_relu", kwargs)
self.assertNotIn("tau_max", kwargs)
self.assertEqual(kwargs["fat"], False)

# Test check for block designs
mock_model = mock.Mock()
mock_model.num_outputs = 2
Expand Down

0 comments on commit b2272e0

Please sign in to comment.