diff --git a/botorch/acquisition/input_constructors.py b/botorch/acquisition/input_constructors.py index b1f0b39acd..dba1d55f29 100644 --- a/botorch/acquisition/input_constructors.py +++ b/botorch/acquisition/input_constructors.py @@ -1532,6 +1532,7 @@ def optimize_objective( model: Model, bounds: Tensor, q: int, + acq_function: Optional[AcquisitionFunction] = None, objective: Optional[MCAcquisitionObjective] = None, posterior_transform: Optional[PosteriorTransform] = None, linear_constraints: Optional[tuple[Tensor, Tensor]] = None, @@ -1574,18 +1575,21 @@ def optimize_objective( if optimizer_options is None: optimizer_options = {} - if objective is not None: - sampler_cls = SobolQMCNormalSampler if qmc else IIDNormalSampler - acq_function = qSimpleRegret( - model=model, - objective=objective, - posterior_transform=posterior_transform, - sampler=sampler_cls(sample_shape=torch.Size([mc_samples]), seed=seed_inner), - ) - else: - acq_function = PosteriorMean( - model=model, posterior_transform=posterior_transform - ) + if acq_function is None: + if objective is None: + acq_function = PosteriorMean( + model=model, posterior_transform=posterior_transform + ) + else: + sampler_cls = SobolQMCNormalSampler if qmc else IIDNormalSampler + acq_function = qSimpleRegret( + model=model, + objective=objective, + posterior_transform=posterior_transform, + sampler=sampler_cls( + sample_shape=torch.Size([mc_samples]), seed=seed_inner + ), + ) if fixed_features: acq_function = FixedFeatureAcquisitionFunction( diff --git a/test/acquisition/test_input_constructors.py b/test/acquisition/test_input_constructors.py index c8a3aa530b..a8e876eaf3 100644 --- a/test/acquisition/test_input_constructors.py +++ b/test/acquisition/test_input_constructors.py @@ -220,6 +220,16 @@ def test_optimize_objective(self, mock_optimize_acqf): mock_model = self.mock_model bounds = torch.rand(2, len(self.bounds)) + with self.subTest("scalarObjective_acqusitionFunction"): + optimize_objective( + model=mock_model, + bounds=bounds, + q=1, + acq_function=UpperConfidenceBound(model=mock_model, beta=0.1), + ) + kwargs = mock_optimize_acqf.call_args[1] + self.assertIsInstance(kwargs["acq_function"], UpperConfidenceBound) + A = torch.rand(1, bounds.shape[-1]) b = torch.zeros([1, 1]) idx = A[0].nonzero(as_tuple=False).squeeze()