diff --git a/botorch/models/model.py b/botorch/models/model.py
index 04ec441a1f..65d408431d 100644
--- a/botorch/models/model.py
+++ b/botorch/models/model.py
@@ -657,7 +657,9 @@ def fantasize(
             if observation_noise is not None:
                 observation_noise_i = observation_noise[..., mask_i, i : i + 1]
             else:
-                sampler_i = sampler
+                sampler_i = (
+                    sampler.samplers[i] if isinstance(sampler, ListSampler) else sampler
+                )
 
             fant_model = self.models[i].fantasize(
                 X=X_i,
diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py
index 04e203facf..3b9b1eb3dc 100644
--- a/botorch/optim/initializers.py
+++ b/botorch/optim/initializers.py
@@ -22,10 +22,16 @@
 from botorch import settings
 from botorch.acquisition import analytic, monte_carlo, multi_objective
 from botorch.acquisition.acquisition import AcquisitionFunction
+from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
 from botorch.acquisition.knowledge_gradient import (
     _get_value_function,
     qKnowledgeGradient,
 )
+from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
+    _get_hv_value_function,
+    qHypervolumeKnowledgeGradient,
+    qMultiFidelityHypervolumeKnowledgeGradient,
+)
 from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
 from botorch.exceptions.warnings import (
     BadInitialCandidatesWarning,
@@ -245,6 +251,7 @@ def gen_batch_initial_conditions(
     inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
     equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
     generator: Optional[Callable[[int, int, int], Tensor]] = None,
+    fixed_X_fantasies: Optional[Tensor] = None,
 ) -> Tensor:
     r"""Generate a batch of initial conditions for random-restart optimziation.
 
@@ -278,6 +285,8 @@ def gen_batch_initial_conditions(
         generator: Callable for generating samples that are then further
             processed. It receives `n`, `q` and `seed` as arguments and
             returns a tensor of shape `n x q x d`.
+        fixed_X_fantasies: A fixed set of fantasy points to concatenate to
+            the `q` candidates being initialized along the `-2` dimension.
 
     Returns:
         A `num_restarts x q x d` tensor of initial conditions.
@@ -379,6 +388,22 @@ def gen_batch_initial_conditions(
                 dim=0,
             )
         X_rnd = fix_features(X_rnd, fixed_features=fixed_features)
+        if fixed_X_fantasies is not None:
+            if (d_f := fixed_X_fantasies.shape[-1]) != (d_r := X_rnd.shape[-1]):
+                raise BotorchTensorDimensionError(
+                    "`fixed_X_fantasies` and `bounds` must both have the same "
+                    f"trailing dimension `d`, but have {d_f} and {d_r}, "
+                    "respectively."
+                )
+            X_rnd = torch.cat(
+                [
+                    X_rnd,
+                    fixed_X_fantasies.cpu()
+                    .unsqueeze(0)
+                    .expand(X_rnd.shape[0], *fixed_X_fantasies.shape),
+                ],
+                dim=-2,
+            )
         with torch.no_grad():
             if batch_limit is None:
                 batch_limit = X_rnd.shape[0]
@@ -425,7 +450,7 @@ def gen_one_shot_kg_initial_conditions(
     This function generates initial conditions for optimizing one-shot KG using
     the maximizer of the posterior objective. Intutively, the maximizer of the
     fantasized posterior will often be close to a maximizer of the current
-    posterior. This function uses that fact to generate the initital conditions
+    posterior. This function uses that fact to generate the initial conditions
     for the fantasy points.
     Specifically, a fraction of `1 - frac_random` (see options) is generated
     by sampling from the set of maximizers of the posterior objective
     (obtained via random restart optimization) according to
@@ -528,6 +553,203 @@ def gen_one_shot_kg_initial_conditions(
     return ics
 
 
+def gen_one_shot_hvkg_initial_conditions(
+    acq_function: qHypervolumeKnowledgeGradient,
+    bounds: Tensor,
+    q: int,
+    num_restarts: int,
+    raw_samples: int,
+    fixed_features: Optional[Dict[int, float]] = None,
+    options: Optional[Dict[str, Union[bool, float, int]]] = None,
+    inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
+    equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
+) -> Optional[Tensor]:
+    r"""Generate a batch of smart initializations for qHypervolumeKnowledgeGradient.
+
+    This function generates initial conditions for optimizing one-shot HVKG using
+    the hypervolume maximizing set (of fixed size) under the posterior mean.
+    Intuitively, the hypervolume maximizing set of the fantasized posterior mean
+    will often be close to a hypervolume maximizing set under the current
+    posterior mean. This function uses that fact to generate the initial
+    conditions for the fantasy points. Specifically, a fraction of
+    `1 - frac_random` (see options) of the restarts is generated by learning
+    hypervolume maximizing sets under the current posterior mean, where each
+    hypervolume maximizing set is obtained by maximizing the hypervolume from a
+    different starting point. Given a hypervolume maximizing set, the `q`
+    candidate points are selected using the standard initialization strategy in
+    `gen_batch_initial_conditions`, with the hypervolume maximizing set held
+    fixed. For the remaining `frac_random` fraction of restarts, the fantasy
+    points as well as all `q` candidate points are chosen according to the
+    standard initialization strategy in `gen_batch_initial_conditions`.
+
+    Args:
+        acq_function: The qHypervolumeKnowledgeGradient instance to be optimized.
+        bounds: A `2 x d` tensor of lower and upper bounds for each column of
+            task features.
+        q: The number of candidates to consider.
+        num_restarts: The number of starting points for multistart acquisition
+            function optimization.
+        raw_samples: The number of raw samples to consider in the initialization
+            heuristic.
+        fixed_features: A map `{feature_index: value}` for features that
+            should be fixed to a particular value during generation.
+        options: Options for initial condition generation. These contain all
+            settings for the standard heuristic initialization from
+            `gen_batch_initial_conditions`. In addition, they contain
+            `frac_random` (the fraction of fully random fantasy points),
+            `num_inner_restarts` and `raw_inner_samples` (the number of random
+            restarts and raw samples for solving the posterior objective
+            maximization problem, respectively) and `eta` (temperature parameter
+            for sampling heuristic from posterior objective maximizers).
+        inequality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an inequality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
+        equality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an equality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
+
+    Returns:
+        A `num_restarts x q' x d` tensor that can be used as initial conditions
+        for `optimize_acqf()`. Here `q' = q + num_pseudo_points` is the total
+        number of points (candidate points plus fantasy points).
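+
+        As an illustrative calculation (numbers assumed for this example only),
+        with `q = 2`, `num_fantasies = 8`, and `num_pareto = 10`, the optimized
+        fantasy designs comprise `num_pseudo_points = num_fantasies * num_pareto
+        = 80` points, so each restart has `q' = 2 + 80 = 82` points in total.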
+
+    Example:
+        >>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
+        >>> bounds = torch.tensor([[0., 0.], [1., 1.]])
+        >>> Xinit = gen_one_shot_hvkg_initial_conditions(
+        >>>     qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
+        >>>     options={"frac_random": 0.25},
+        >>> )
+    """
+    from botorch.optim.optimize import optimize_acqf
+
+    options = options or {}
+    frac_random: float = options.get("frac_random", 0.1)
+    if not 0 < frac_random < 1:
+        raise ValueError(
+            f"frac_random must take on values in (0,1). Value: {frac_random}"
+        )
+
+    value_function = _get_hv_value_function(
+        model=acq_function.model,
+        ref_point=acq_function.ref_point,
+        objective=acq_function.objective,
+        sampler=acq_function.inner_sampler,
+        use_posterior_mean=acq_function.use_posterior_mean,
+    )
+
+    is_mf_hvkg = isinstance(acq_function, qMultiFidelityHypervolumeKnowledgeGradient)
+    if is_mf_hvkg:
+        dim = bounds.shape[-1]
+        fidelity_dims, fidelity_targets = zip(*acq_function.target_fidelities.items())
+        value_function = FixedFeatureAcquisitionFunction(
+            acq_function=value_function,
+            d=dim,
+            columns=fidelity_dims,
+            values=fidelity_targets,
+        )
+
+        non_fidelity_dims = list(set(range(dim)) - set(fidelity_dims))
+
+    num_optim_restarts = int(round(num_restarts * (1 - frac_random)))
+    fantasy_cands, fantasy_vals = optimize_acqf(
+        acq_function=value_function,
+        bounds=bounds[:, non_fidelity_dims] if is_mf_hvkg else bounds,
+        q=acq_function.num_pareto,
+        num_restarts=options.get("num_inner_restarts", 20),
+        raw_samples=options.get("raw_inner_samples", 1024),
+        fixed_features=fixed_features,
+        return_best_only=False,
+        options=options,
+        inequality_constraints=inequality_constraints,
+        equality_constraints=equality_constraints,
+        sequential=False,
+    )
+    # sample from the optimizers according to a softmax over their values
+    eta = options.get("eta", 2.0)
+    if num_optim_restarts > 0:
+        probs = torch.nn.functional.softmax(eta * standardize(fantasy_vals), dim=0)
+        idx = torch.multinomial(
+            probs,
+            num_optim_restarts * acq_function.num_fantasies,
+            replacement=True,
+        )
+        optim_ics = fantasy_cands[idx]
+        if is_mf_hvkg:
+            # add fixed features
+            optim_ics = value_function.construct_X_full(optim_ics)
+        optim_ics = optim_ics.reshape(
+            num_optim_restarts, acq_function.num_pseudo_points, bounds.shape[-1]
+        )
+
+    # get random initial conditions
+    num_random_restarts = num_restarts - num_optim_restarts
+    if num_random_restarts > 0:
+        q_aug = acq_function.get_augmented_q_batch_size(q=q)
+        base_ics = gen_batch_initial_conditions(
+            acq_function=acq_function,
+            bounds=bounds,
+            q=q_aug,
+            num_restarts=num_restarts,
+            raw_samples=raw_samples,
+            fixed_features=fixed_features,
+            options=options,
+            inequality_constraints=inequality_constraints,
+            equality_constraints=equality_constraints,
+        )
+
+        if num_optim_restarts > 0:
+            probs = torch.full(
+                (num_restarts,),
+                1.0 / num_restarts,
+                dtype=optim_ics.dtype,
+                device=optim_ics.device,
+            )
+            optim_idxr = probs.multinomial(
+                num_samples=num_optim_restarts, replacement=False
+            )
+            base_ics[optim_idxr, q:] = optim_ics
+    else:
+        # optim_ics is num_restarts x num_pseudo_points x d; add padding so
+        # that base_ics is num_restarts x (q + num_pseudo_points) x d
+        q_padding = torch.zeros(
+            optim_ics.shape[0],
+            q,
+            optim_ics.shape[-1],
+            dtype=optim_ics.dtype,
+            device=optim_ics.device,
+        )
+        base_ics = torch.cat([q_padding, optim_ics], dim=-2)
+
+    if num_optim_restarts > 0:
+        all_ics = []
+        if num_random_restarts > 0:
+            optim_idcs = optim_idxr.view(-1).tolist()
+        else:
+            optim_idcs = list(range(num_restarts))
+        for i in range(num_restarts):
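+            # For each restart: re-initialize the `q` candidates conditioned on
+            # the fixed fantasy designs in `base_ics[i, q:]` if this restart was
+            # selected for optimized fantasies; otherwise keep the fully random
+            # initialization.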
+            if i in optim_idcs:
+                # optimize the q points, given the fixed,
+                # optimized fantasy designs
+                ics = gen_batch_initial_conditions(
+                    acq_function=acq_function,
+                    bounds=bounds,
+                    q=q,
+                    num_restarts=1,
+                    raw_samples=raw_samples,
+                    fixed_features=fixed_features,
+                    options=options,
+                    inequality_constraints=inequality_constraints,
+                    equality_constraints=equality_constraints,
+                    fixed_X_fantasies=base_ics[i, q:],
+                )
+            else:
+                # ics are all randomly sampled
+                ics = base_ics[i : i + 1]
+            all_ics.append(ics)
+        return torch.cat(all_ics, dim=0)
+    return base_ics
+
+
 def gen_value_function_initial_conditions(
     acq_function: AcquisitionFunction,
     bounds: Tensor,
diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py
index f2908d37c6..c17897694e 100644
--- a/botorch/optim/optimize.py
+++ b/botorch/optim/optimize.py
@@ -22,12 +22,16 @@
     OneShotAcquisitionFunction,
 )
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
+from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
+    qHypervolumeKnowledgeGradient,
+)
 from botorch.exceptions import InputDataError, UnsupportedError
 from botorch.exceptions.warnings import OptimizationWarning
 from botorch.generation.gen import gen_candidates_scipy, TGenCandidates
 from botorch.logging import logger
 from botorch.optim.initializers import (
     gen_batch_initial_conditions,
+    gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
     TGenInitialConditions,
 )
@@ -129,6 +133,8 @@ def get_ic_generator(self) -> TGenInitialConditions:
             return self.ic_generator
         elif isinstance(self.acq_function, qKnowledgeGradient):
             return gen_one_shot_kg_initial_conditions
+        elif isinstance(self.acq_function, qHypervolumeKnowledgeGradient):
+            return gen_one_shot_hvkg_initial_conditions
         return gen_batch_initial_conditions
diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py
index 5537ac3260..d7ed46ab7c 100644
--- a/test/optim/test_initializers.py
+++ b/test/optim/test_initializers.py
@@ -20,6 +20,9 @@
     qExpectedImprovement,
     qNoisyExpectedImprovement,
 )
+from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
+    qHypervolumeKnowledgeGradient,
+)
 from botorch.acquisition.multi_objective.monte_carlo import (
     qNoisyExpectedHypervolumeImprovement,
 )
@@ -27,9 +30,11 @@
 from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
 from botorch.exceptions.warnings import BotorchWarning
 from botorch.models import SingleTaskGP
+from botorch.models.model_list_gp_regression import ModelListGP
 from botorch.optim import initialize_q_batch, initialize_q_batch_nonneg
 from botorch.optim.initializers import (
     gen_batch_initial_conditions,
+    gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
     gen_value_function_initial_conditions,
     sample_perturbed_subset_dims,
@@ -808,6 +813,88 @@ def test_error_equality_constraints_with_sample_around_best(self):
             options={"sample_around_best": True},
         )
 
+    def test_gen_batch_initial_conditions_fixed_X_fantasies(self):
+        bounds = torch.stack([torch.zeros(2), torch.ones(2)])
+        mock_acqf = MockAcquisitionFunction()
+        mock_acqf.objective = lambda y: y.squeeze(-1)
+        for dtype in (torch.float, torch.double):
+            bounds = bounds.to(device=self.device, dtype=dtype)
+            mock_acqf.X_baseline = bounds  # for testing sample_around_best
+            mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))
+            fixed_X_fantasies = torch.rand(3, 2, dtype=dtype, device=self.device)
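+            # Sweep the standard initialization options against the new
+            # `fixed_X_fantasies` argument.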
+            for (
+                nonnegative,
+                seed,
+                init_batch_limit,
+                ffs,
+                sample_around_best,
+            ) in product(
+                [True, False], [None, 1234], [None, 1], [None, {0: 0.5}], [True, False]
+            ):
+                with mock.patch.object(
+                    MockAcquisitionFunction,
+                    "__call__",
+                    wraps=mock_acqf.__call__,
+                ) as mock_acqf_call:
+                    batch_initial_conditions = gen_batch_initial_conditions(
+                        acq_function=mock_acqf,
+                        bounds=bounds,
+                        q=1,
+                        num_restarts=2,
+                        raw_samples=10,
+                        fixed_features=ffs,
+                        options={
+                            "nonnegative": nonnegative,
+                            "eta": 0.01,
+                            "alpha": 0.1,
+                            "seed": seed,
+                            "init_batch_limit": init_batch_limit,
+                            "sample_around_best": sample_around_best,
+                        },
+                        fixed_X_fantasies=fixed_X_fantasies,
+                    )
+                expected_shape = torch.Size([2, 4, 2])
+                self.assertEqual(batch_initial_conditions.shape, expected_shape)
+                self.assertEqual(batch_initial_conditions.device, bounds.device)
+                self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
+                self.assertLess(
+                    _get_max_violation_of_bounds(batch_initial_conditions, bounds),
+                    1e-6,
+                )
+                raw_samps = mock_acqf_call.call_args[0][0]
+                batch_shape = (
+                    torch.Size([20 if sample_around_best else 10])
+                    if init_batch_limit is None
+                    else torch.Size([init_batch_limit])
+                )
+                expected_raw_samps_shape = batch_shape + torch.Size([4, 2])
+                self.assertEqual(raw_samps.shape, expected_raw_samps_shape)
+
+                if ffs is not None:
+                    for idx, val in ffs.items():
+                        self.assertTrue(
+                            torch.all(batch_initial_conditions[..., 0, idx] == val)
+                        )
+                self.assertTrue(
+                    torch.equal(
+                        batch_initial_conditions[:, 1:],
+                        fixed_X_fantasies.unsqueeze(0).expand(2, 3, 2),
+                    )
+                )
+        # test wrong shape
+        msg = (
+            "`fixed_X_fantasies` and `bounds` must both have the same trailing"
+            " dimension `d`, but have 3 and 2, respectively."
+        )
+        with self.assertRaisesRegex(BotorchTensorDimensionError, msg):
+            gen_batch_initial_conditions(
+                acq_function=mock_acqf,
+                bounds=bounds,
+                q=1,
+                num_restarts=2,
+                raw_samples=10,
+                fixed_X_fantasies=torch.rand(3, 3, dtype=dtype, device=self.device),
+            )
+
 
 class TestGenOneShotKGInitialConditions(BotorchTestCase):
     def test_gen_one_shot_kg_initial_conditions(self):
@@ -865,6 +952,94 @@ def test_gen_one_shot_kg_initial_conditions(self):
         self.assertTrue(torch.all(ics[..., -n_value:, :] == 1))
 
 
+class TestGenOneShotHVKGInitialConditions(BotorchTestCase):
+    def test_gen_one_shot_hvkg_initial_conditions(self):
+        num_fantasies = 8
+        num_restarts = 4
+        raw_samples = 16
+        tkwargs = {"device": self.device}
+        for dtype in (torch.float, torch.double):
+            tkwargs["dtype"] = dtype
+            X = torch.rand(4, 2, **tkwargs)
+            Y1 = torch.rand(4, 1, **tkwargs)
+            Y2 = torch.rand(4, 1, **tkwargs)
+            m1 = SingleTaskGP(X, Y1)
+            m2 = SingleTaskGP(X, Y2)
+            model = ModelListGP(m1, m2)
+            hvkg = qHypervolumeKnowledgeGradient(
+                model=model,
+                ref_point=torch.zeros(2, **tkwargs),
+                num_fantasies=num_fantasies,
+            )
+            bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype)
+            # test option error
+            with self.assertRaises(ValueError):
+                gen_one_shot_hvkg_initial_conditions(
+                    acq_function=hvkg,
+                    bounds=bounds,
+                    q=1,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    options={"frac_random": 2.0},
+                )
+            # test generation logic
+            q = 2
+            mock_fantasy_cands = torch.ones(20, 10, 2)
+            mock_fantasy_vals = torch.randn(20)
+
+            def mock_gen_ics(*args, **kwargs):
+                fixed_X_fantasies = kwargs.get("fixed_X_fantasies")
+                if fixed_X_fantasies is None:
+                    return torch.rand(
+                        kwargs["num_restarts"], q + hvkg.num_pseudo_points, 2
+                    )
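+                # Mimic gen_batch_initial_conditions: random `q`-point
+                # candidates with the fixed fantasy designs appended along
+                # dimension -2.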
+                rand_candidates = torch.rand(
+                    1,
+                    q,
+                    2,
+                    dtype=fixed_X_fantasies.dtype,
+                    device=fixed_X_fantasies.device,
+                )
+                return torch.cat(
+                    [
+                        rand_candidates,
+                        fixed_X_fantasies.unsqueeze(0),
+                    ],
+                    dim=-2,
+                )
+
+            for frac_random in (0.1, 0.5):
+                with ExitStack() as es:
+                    mock_gbics = es.enter_context(
+                        mock.patch(
+                            "botorch.optim.initializers.gen_batch_initial_conditions",
+                            wraps=mock_gen_ics,
+                        )
+                    )
+                    mock_optacqf = es.enter_context(
+                        mock.patch(
+                            "botorch.optim.optimize.optimize_acqf",
+                            return_value=(mock_fantasy_cands, mock_fantasy_vals),
+                        )
+                    )
+                    ics = gen_one_shot_hvkg_initial_conditions(
+                        acq_function=hvkg,
+                        bounds=bounds,
+                        q=q,
+                        num_restarts=num_restarts,
+                        raw_samples=raw_samples,
+                        options={"frac_random": frac_random},
+                    )
+                    expected_call_count = 3 if frac_random == 0.5 else 4
+                    self.assertEqual(mock_gbics.call_count, expected_call_count)
+                    mock_optacqf.assert_called_once()
+                    n_value = int(round((1 - frac_random) * num_restarts))
+                    # check that there are the expected number of optimized points
+                    self.assertTrue(
+                        (ics == 1).all(dim=-1).sum()
+                        == n_value * hvkg.num_pseudo_points
+                    )
+
+
 class TestGenValueFunctionInitialConditions(BotorchTestCase):
     def test_gen_value_function_initial_conditions(self):
         num_fantasies = 2
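Usage sketch (assumptions: a toy two-output model, unit-square bounds, and small
sample counts chosen purely for illustration). With the routing change in
`optimize.py` above, `optimize_acqf` selects `gen_one_shot_hvkg_initial_conditions`
automatically for a `qHypervolumeKnowledgeGradient` acquisition function:

    import torch
    from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
        qHypervolumeKnowledgeGradient,
    )
    from botorch.models import SingleTaskGP
    from botorch.models.model_list_gp_regression import ModelListGP
    from botorch.optim.optimize import optimize_acqf

    # Toy two-objective model on the unit square (illustrative data only).
    X = torch.rand(4, 2, dtype=torch.double)
    model = ModelListGP(
        SingleTaskGP(X, torch.rand(4, 1, dtype=torch.double)),
        SingleTaskGP(X, torch.rand(4, 1, dtype=torch.double)),
    )
    hvkg = qHypervolumeKnowledgeGradient(
        model=model,
        ref_point=torch.zeros(2, dtype=torch.double),
        num_fantasies=8,
    )
    bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
    # optimize_acqf routes to gen_one_shot_hvkg_initial_conditions via
    # get_ic_generator; `frac_random` sets the fraction of fully random
    # restarts and the inner options control the hypervolume maximization.
    candidates, value = optimize_acqf(
        acq_function=hvkg,
        bounds=bounds,
        q=2,
        num_restarts=4,
        raw_samples=16,
        options={
            "frac_random": 0.25,
            "num_inner_restarts": 4,
            "raw_inner_samples": 32,
        },
    )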