diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index 54103745a8..c9c4ecd1eb 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -244,6 +244,7 @@ def gen_batch_initial_conditions( options: Optional[Dict[str, Union[bool, float, int]]] = None, inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None, equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None, + generator: Optional[Callable[[int, int, int], Tensor]] = None, ) -> Tensor: r"""Generate a batch of initial conditions for random-restart optimziation. @@ -274,6 +275,9 @@ def gen_batch_initial_conditions( equality constraints: A list of tuples (indices, coefficients, rhs), with each tuple encoding an inequality constraint of the form `\sum_i (X[indices[i]] * coefficients[i]) = rhs`. + generator: Callable for generating samples that are then further + processed. It receives `n`, `q` and `seed` as arguments and + returns a tensor of shape `n x q x d`. Returns: A `num_restarts x q x d` tensor of initial conditions. @@ -297,6 +301,11 @@ def gen_batch_initial_conditions( "Option 'sample_around_best' is not supported when equality" "constraints are present." ) + if sample_around_best and generator: + raise UnsupportedError( + "Option 'sample_around_best' is not supported when custom " + "generator is be used." + ) seed: Optional[int] = options.get("seed") batch_limit: Optional[int] = options.get( "init_batch_limit", options.get("batch_limit") @@ -327,7 +336,9 @@ def gen_batch_initial_conditions( while factor < max_factor: with warnings.catch_warnings(record=True) as ws: n = raw_samples * factor - if inequality_constraints is None and equality_constraints is None: + if generator is not None: + X_rnd = generator(n, q, seed) + elif inequality_constraints is None and equality_constraints is None: if effective_dim <= SobolEngine.MAXDIM: X_rnd = draw_sobol_samples(bounds=bounds_cpu, n=n, q=q, seed=seed) else: diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py index 065f1c0816..b120afba44 100644 --- a/botorch/optim/optimize.py +++ b/botorch/optim/optimize.py @@ -716,9 +716,12 @@ def optimize_acqf_list( options: Optional[Dict[str, Union[bool, float, int, str]]] = None, inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None, equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None, + nonlinear_inequality_constraints: Optional[List[Callable]] = None, fixed_features: Optional[Dict[int, float]] = None, fixed_features_list: Optional[List[Dict[int, float]]] = None, post_processing_func: Optional[Callable[[Tensor], Tensor]] = None, + ic_generator: Optional[TGenInitialConditions] = None, + ic_gen_kwargs: Optional[Dict] = None, ) -> Tuple[Tensor, Tensor]: r"""Generate a list of candidates from a list of acquisition functions. @@ -741,6 +744,14 @@ def optimize_acqf_list( equality constraints: A list of tuples (indices, coefficients, rhs), with each tuple encoding an inequality constraint of the form `\sum_i (X[indices[i]] * coefficients[i]) = rhs` + nonlinear_inequality_constraints: A list of callables with that represent + non-linear inequality constraints of the form `callable(x) >= 0`. Each + callable is expected to take a `(num_restarts) x q x d`-dim tensor as an + input and return a `(num_restarts) x q`-dim tensor with the constraint + values. The constraints will later be passed to SLSQP. You need to pass in + `batch_initial_conditions` in this case. Using non-linear inequality + constraints also requires that `batch_limit` is set to 1, which will be + done automatically if not specified in `options`. fixed_features: A map `{feature_index: value}` for features that should be fixed to a particular value during generation. fixed_features_list: A list of maps `{feature_index: value}`. The i-th @@ -749,6 +760,13 @@ def optimize_acqf_list( post_processing_func: A function that post-processes an optimization result appropriately (i.e., according to `round-trip` transformations). + ic_generator: Function for generating initial conditions. Not needed when + `batch_initial_conditions` are provided. Defaults to + `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition + functions and `gen_batch_initial_conditions` otherwise. Must be specified + for nonlinear inequality constraints. + ic_gen_kwargs: Additional keyword arguments passed to function specified by + `ic_generator` Returns: A two-element tuple containing @@ -784,10 +802,14 @@ def optimize_acqf_list( options=options or {}, inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, + nonlinear_inequality_constraints=nonlinear_inequality_constraints, fixed_features_list=fixed_features_list, post_processing_func=post_processing_func, + ic_generator=ic_generator, + ic_gen_kwargs=ic_gen_kwargs, ) else: + ic_gen_kwargs = ic_gen_kwargs or {} candidate, acq_value = optimize_acqf( acq_function=acq_function, bounds=bounds, @@ -797,10 +819,13 @@ def optimize_acqf_list( options=options or {}, inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, + nonlinear_inequality_constraints=nonlinear_inequality_constraints, fixed_features=fixed_features, post_processing_func=post_processing_func, return_best_only=True, sequential=False, + ic_generator=ic_generator, + **ic_gen_kwargs, ) candidate_list.append(candidate) acq_value_list.append(acq_value) @@ -818,8 +843,11 @@ def optimize_acqf_mixed( options: Optional[Dict[str, Union[bool, float, int, str]]] = None, inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None, equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None, + nonlinear_inequality_constraints: Optional[List[Callable]] = None, post_processing_func: Optional[Callable[[Tensor], Tensor]] = None, batch_initial_conditions: Optional[Tensor] = None, + ic_generator: Optional[TGenInitialConditions] = None, + ic_gen_kwargs: Optional[Dict] = None, **kwargs: Any, ) -> Tuple[Tensor, Tensor]: r"""Optimize over a list of fixed_features and returns the best solution. @@ -847,11 +875,26 @@ def optimize_acqf_mixed( equality constraints: A list of tuples (indices, coefficients, rhs), with each tuple encoding an inequality constraint of the form `\sum_i (X[indices[i]] * coefficients[i]) = rhs` + nonlinear_inequality_constraints: A list of callables with that represent + non-linear inequality constraints of the form `callable(x) >= 0`. Each + callable is expected to take a `(num_restarts) x q x d`-dim tensor as an + input and return a `(num_restarts) x q`-dim tensor with the constraint + values. The constraints will later be passed to SLSQP. You need to pass in + `batch_initial_conditions` in this case. Using non-linear inequality + constraints also requires that `batch_limit` is set to 1, which will be + done automatically if not specified in `options`. post_processing_func: A function that post-processes an optimization result appropriately (i.e., according to `round-trip` transformations). batch_initial_conditions: A tensor to specify the initial conditions. Set this if you do not want to use default initialization strategy. + ic_generator: Function for generating initial conditions. Not needed when + `batch_initial_conditions` are provided. Defaults to + `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition + functions and `gen_batch_initial_conditions` otherwise. Must be specified + for nonlinear inequality constraints. + ic_gen_kwargs: Additional keyword arguments passed to function specified by + `ic_generator` kwargs: kwargs do nothing. This is provided so that the same arguments can be passed to different acquisition functions without raising an error. @@ -873,6 +916,8 @@ def optimize_acqf_mixed( ) _raise_deprecation_warning_if_kwargs("optimize_acqf_mixed", kwargs) + ic_gen_kwargs = ic_gen_kwargs or {} + if q == 1: ff_candidate_list, ff_acq_value_list = [], [] for fixed_features in fixed_features_list: @@ -885,10 +930,13 @@ def optimize_acqf_mixed( options=options or {}, inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, + nonlinear_inequality_constraints=nonlinear_inequality_constraints, fixed_features=fixed_features, post_processing_func=post_processing_func, batch_initial_conditions=batch_initial_conditions, + ic_generator=ic_generator, return_best_only=True, + **ic_gen_kwargs, ) ff_candidate_list.append(candidate) ff_acq_value_list.append(acq_value) @@ -914,8 +962,11 @@ def optimize_acqf_mixed( options=options or {}, inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, + nonlinear_inequality_constraints=nonlinear_inequality_constraints, post_processing_func=post_processing_func, batch_initial_conditions=batch_initial_conditions, + ic_generator=ic_generator, + ic_gen_kwargs=ic_gen_kwargs, ) candidates = torch.cat([candidates, candidate], dim=-2) acq_function.set_X_pending( diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 3a2b5c4326..59c2194590 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -40,7 +40,7 @@ transform_intra_point_constraint, ) from botorch.sampling.normal import IIDNormalSampler -from botorch.utils.sampling import draw_sobol_samples +from botorch.utils.sampling import draw_sobol_samples, manual_seed from botorch.utils.testing import ( BotorchTestCase, MockAcquisitionFunction, @@ -595,12 +595,94 @@ def test_gen_batch_initial_conditions_interpoint_constraints(self): batch_initial_conditions[1, 2, 0], ) + def test_gen_batch_initial_conditions_generator(self): + mock_acqf = MockAcquisitionFunction() + mock_acqf.objective = lambda y: y.squeeze(-1) + for dtype in (torch.float, torch.double): + bounds = torch.tensor( + [[0, 0, 0], [1, 1, 1]], device=self.device, dtype=dtype + ) + for nonnegative, seed, init_batch_limit, ffs in product( + [True, False], [None, 1234], [None, 1], [None, {0: 0.5}] + ): + + def generator(n: int, q: int, seed: int): + with manual_seed(seed): + X_rnd_nlzd = torch.rand( + n, + q, + bounds.shape[-1], + dtype=bounds.dtype, + device=self.device, + ) + X_rnd = bounds[0] + (bounds[1] - bounds[0]) * X_rnd_nlzd + X_rnd[..., -1] = 0.42 + return X_rnd + + mock_acqf = MockAcquisitionFunction() + with mock.patch.object( + MockAcquisitionFunction, + "__call__", + wraps=mock_acqf.__call__, + ): + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=mock_acqf, + bounds=bounds, + q=2, + num_restarts=4, + raw_samples=10, + generator=generator, + fixed_features=ffs, + options={ + "nonnegative": nonnegative, + "eta": 0.01, + "alpha": 0.1, + "seed": seed, + "init_batch_limit": init_batch_limit, + }, + ) + expected_shape = torch.Size([4, 2, 3]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertTrue((batch_initial_conditions[..., -1] == 0.42).all()) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) + + def test_error_generator_with_sample_around_best(self): + tkwargs = {"device": self.device, "dtype": torch.double} + + def generator(n: int, q: int, seed: int): + return torch.rand(n, q, 3).to(**tkwargs) + + with self.assertRaisesRegex( + UnsupportedError, + "Option 'sample_around_best' is not supported when custom " + "generator is be used.", + ): + gen_batch_initial_conditions( + MockAcquisitionFunction(), + bounds=torch.tensor([[0, 0], [1, 1]], **tkwargs), + q=1, + num_restarts=1, + raw_samples=1, + generator=generator, + options={"sample_around_best": True}, + ) + def test_error_equality_constraints_with_sample_around_best(self): tkwargs = {"device": self.device, "dtype": torch.double} # this will give something that does not respect the constraints # TODO: it would be good to have a utils function to check if the # constraints are obeyed - with self.assertRaises(UnsupportedError) as e: + with self.assertRaisesRegex( + UnsupportedError, + "Option 'sample_around_best' is not supported when equality" + "constraints are present.", + ): gen_batch_initial_conditions( MockAcquisitionFunction(), bounds=torch.tensor([[0, 0], [1, 1]], **tkwargs), @@ -616,10 +698,6 @@ def test_error_equality_constraints_with_sample_around_best(self): ], options={"sample_around_best": True}, ) - self.assertTrue( - "Option 'sample_around_best' is not supported when equality" - "constraints are present." in str(e.exception) - ) class TestGenOneShotKGInitialConditions(BotorchTestCase): diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index 4b4cff3c17..19a0fec666 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -1360,6 +1360,8 @@ def test_optimize_acqf_mixed_q1(self, mock_optimize_acqf): "batch_initial_conditions": None, "return_best_only": True, "sequential": False, + "ic_generator": None, + "nonlinear_inequality_constraints": None, } for i in range(len(call_args_list)): expected_call_args["fixed_features"] = fixed_features_list[i]