initializer for one-shot hvkg (pytorch#1982)
Summary:

Adds an initializer for optimizing one-shot HVKG, based on optimizing the hypervolume under the current posterior mean.
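
As a usage sketch (the fitted multi-output `model`, `bounds`, and `ref_point` below are assumed, not part of this diff): `optimize_acqf` now picks this initializer automatically for `qHypervolumeKnowledgeGradient` acquisition functions, and it can also be passed explicitly via `ic_generator`:

```python
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
    qHypervolumeKnowledgeGradient,
)
from botorch.optim import optimize_acqf
from botorch.optim.initializers import gen_one_shot_hvkg_initial_conditions

# `model`, `bounds` (a `2 x d` tensor), and `ref_point` are assumed to exist.
qHVKG = qHypervolumeKnowledgeGradient(model, ref_point=ref_point)
candidate, value = optimize_acqf(
    acq_function=qHVKG,
    bounds=bounds,
    q=1,
    num_restarts=10,
    raw_samples=512,
    ic_generator=gen_one_shot_hvkg_initial_conditions,  # also the default for HVKG
    options={"frac_random": 0.25},
)
```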

Reviewed By: Balandat

Differential Revision: D48230663
sdaulton authored and facebook-github-bot committed Nov 9, 2023
1 parent b309e24 commit 8a6db2a
Showing 5 changed files with 437 additions and 9 deletions.
6 changes: 3 additions & 3 deletions botorch/acquisition/fixed_feature.py
@@ -150,7 +150,7 @@ def forward(self, X: Tensor):
             by adding `values` in the appropriate places (see
             `_construct_X_full`).
         """
-        X_full = self._construct_X_full(X)
+        X_full = self.construct_X_full(X)
         return self.acq_func(X_full)

     @property
@@ -168,11 +168,11 @@ def X_pending(self):
     def X_pending(self, X_pending: Optional[Tensor]):
         r"""Sets the `X_pending` of the base acquisition function."""
         if X_pending is not None:
-            self.acq_func.X_pending = self._construct_X_full(X_pending)
+            self.acq_func.X_pending = self.construct_X_full(X_pending)
         else:
             self.acq_func.X_pending = X_pending

-    def _construct_X_full(self, X: Tensor) -> Tensor:
+    def construct_X_full(self, X: Tensor) -> Tensor:
         r"""Constructs the full input for the base acquisition function.

         Args:
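
Making `construct_X_full` public matters below: the new HVKG initializer calls it to splice target fidelities back into optimized designs. A minimal sketch of what it does (`base_acqf` is an assumed stand-in for any acquisition function over `d=3` inputs):

```python
import torch
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction

# Fix column 2 of a d=3 input space to the value 0.5.
ff_acqf = FixedFeatureAcquisitionFunction(
    acq_function=base_acqf, d=3, columns=[2], values=[0.5]
)
X = torch.rand(4, 1, 2)  # 4 batches of q=1 points over the 2 free columns
X_full = ff_acqf.construct_X_full(X)  # inserts the fixed column: shape 4 x 1 x 3
```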
4 changes: 3 additions & 1 deletion botorch/models/model.py
@@ -656,7 +656,9 @@ def fantasize(
         if observation_noise is not None:
             observation_noise_i = observation_noise[..., mask_i, i : i + 1]
         else:
-            sampler_i = sampler
+            sampler_i = (
+                sampler.samplers[i] if isinstance(sampler, ListSampler) else sampler
+            )

         fant_model = self.models[i].fantasize(
             X=X_i,
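
The `fantasize` change lets each sub-model draw fantasies from its own sampler when a `ListSampler` is supplied (a plain sampler is still shared across sub-models as before). A sketch, assuming a two-output model list `model_list`, inputs `X`, and a boolean `mask`:

```python
import torch
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import SobolQMCNormalSampler

sampler = ListSampler(
    SobolQMCNormalSampler(sample_shape=torch.Size([8])),
    SobolQMCNormalSampler(sample_shape=torch.Size([8])),
)
# Sub-model i now fantasizes with sampler.samplers[i].
fant_model = model_list.fantasize(X=X, sampler=sampler, evaluation_mask=mask)
```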
236 changes: 231 additions & 5 deletions botorch/optim/initializers.py
@@ -22,10 +22,16 @@
 from botorch import settings
 from botorch.acquisition import analytic, monte_carlo, multi_objective
 from botorch.acquisition.acquisition import AcquisitionFunction
+from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
 from botorch.acquisition.knowledge_gradient import (
     _get_value_function,
     qKnowledgeGradient,
 )
+from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
+    _get_hv_value_function,
+    qHypervolumeKnowledgeGradient,
+    qMultiFidelityHypervolumeKnowledgeGradient,
+)
 from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
 from botorch.exceptions.warnings import (
     BadInitialCandidatesWarning,
@@ -245,6 +251,7 @@ def gen_batch_initial_conditions(
     inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
     equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
     generator: Optional[Callable[[int, int, int], Tensor]] = None,
+    fixed_X_fantasies: Optional[Tensor] = None,
 ) -> Tensor:
     r"""Generate a batch of initial conditions for random-restart optimization.
@@ -278,6 +285,11 @@
         generator: Callable for generating samples that are then further
             processed. It receives `n`, `q` and `seed` as arguments and
             returns a tensor of shape `n x q x d`.
+        fixed_X_fantasies: A fixed set of fantasy points to concatenate to
+            the `q` candidates being initialized along the `-2` dimension. The
+            shape should be `num_pseudo_points x d`. E.g., this should be
+            `num_fantasies x d` for KG and `num_fantasies*num_pareto x d`
+            for HVKG.

     Returns:
         A `num_restarts x q x d` tensor of initial conditions.
@@ -379,6 +391,22 @@
                 dim=0,
             )
         X_rnd = fix_features(X_rnd, fixed_features=fixed_features)
+        if fixed_X_fantasies is not None:
+            if (d_f := fixed_X_fantasies.shape[-1]) != (d_r := X_rnd.shape[-1]):
+                raise BotorchTensorDimensionError(
+                    "`fixed_X_fantasies` and `bounds` must both have the same "
+                    f"trailing dimension `d`, but have {d_f} and {d_r}, "
+                    "respectively."
+                )
+            X_rnd = torch.cat(
+                [
+                    X_rnd,
+                    fixed_X_fantasies.cpu()
+                    .unsqueeze(0)
+                    .expand(X_rnd.shape[0], *fixed_X_fantasies.shape),
+                ],
+                dim=-2,
+            )
         with torch.no_grad():
             if batch_limit is None:
                 batch_limit = X_rnd.shape[0]
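
Shape-wise, the concatenation above appends the same fixed fantasy block to every raw restart; a toy check (sizes made up for illustration):

```python
import torch

num_restarts, q, d, num_pseudo_points = 4, 3, 2, 6
X_rnd = torch.rand(num_restarts, q, d)
fixed_X_fantasies = torch.rand(num_pseudo_points, d)
X_full = torch.cat(
    [X_rnd, fixed_X_fantasies.unsqueeze(0).expand(num_restarts, -1, -1)],
    dim=-2,
)
assert X_full.shape == (num_restarts, q + num_pseudo_points, d)
```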
@@ -425,7 +453,7 @@ def gen_one_shot_kg_initial_conditions(
     This function generates initial conditions for optimizing one-shot KG using
     the maximizer of the posterior objective. Intuitively, the maximizer of the
     fantasized posterior will often be close to a maximizer of the current
-    posterior. This function uses that fact to generate the initital conditions
+    posterior. This function uses that fact to generate the initial conditions
     for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
     options) is generated by sampling from the set of maximizers of the
     posterior objective (obtained via random restart optimization) according to
@@ -436,7 +464,7 @@
     strategy in `gen_batch_initial_conditions`.

     Args:
-        acq_function: The qKnowledgeGradient instance to be optimized.
+        acq_function: The qHypervolumeKnowledgeGradient instance to be optimized.
         bounds: A `2 x d` tensor of lower and upper bounds for each column of
             task features.
         q: The number of candidates to consider.
@@ -467,10 +495,10 @@
         of points (candidate points plus fantasy points).

     Example:
-        >>> qKG = qKnowledgeGradient(model, num_fantasies=64)
+        >>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point, num_fantasies=64)
         >>> bounds = torch.tensor([[0., 0.], [1., 1.]])
-        >>> Xinit = gen_one_shot_kg_initial_conditions(
-        >>>     qKG, bounds, q=3, num_restarts=10, raw_samples=512,
+        >>> Xinit = gen_one_shot_hvkg_initial_conditions(
+        >>>     qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
         >>>     options={"frac_random": 0.25},
         >>> )
     """
@@ -528,6 +556,204 @@ def gen_one_shot_kg_initial_conditions(
     return ics


def gen_one_shot_hvkg_initial_conditions(
acq_function: qHypervolumeKnowledgeGradient,
bounds: Tensor,
q: int,
num_restarts: int,
raw_samples: int,
fixed_features: Optional[Dict[int, float]] = None,
options: Optional[Dict[str, Union[bool, float, int]]] = None,
inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
) -> Optional[Tensor]:
r"""Generate a batch of smart initializations for qHypervolumeKnowledgeGradient.
This function generates initial conditions for optimizing one-shot HVKG using
the hypervolume maximizing set (of fixed size) under the posterior mean.
Intuitively, the hypervolume maximizing set of the fantasized posterior mean
will often be close to a hypervolume maximizing set under the current posterior
mean. This function uses that fact to generate the initial conditions
for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
options) of the restarts are generated by learning hypervolume maximizing sets
under the current posterior mean, where each hypervolume maximizing set is
obtained by maximizing the hypervolume from a different starting point. Given
a hypervolume maximizing set, the `q` candidate points are selected according
to the standard initialization strategy in `gen_batch_initial_conditions`, with
the hypervolume maximizing set held fixed. For the remaining `frac_random`
fraction of restarts, the fantasy points as well as all `q` candidate points
are chosen according to the standard initialization strategy in
`gen_batch_initial_conditions`.
Args:
acq_function: The qHypervolumeKnowledgeGradient instance to be optimized.
bounds: A `2 x d` tensor of lower and upper bounds for each column of
task features.
q: The number of candidates to consider.
num_restarts: The number of starting points for multistart acquisition
function optimization.
raw_samples: The number of raw samples to consider in the initialization
heuristic.
fixed_features: A map `{feature_index: value}` for features that
should be fixed to a particular value during generation.
options: Options for initial condition generation. These contain all
settings for the standard heuristic initialization from
`gen_batch_initial_conditions`. In addition, they contain
`frac_random` (the fraction of fully random fantasy points),
`num_inner_restarts` and `raw_inner_samples` (the number of random
restarts and raw samples for solving the inner posterior-mean
hypervolume maximization problem, respectively) and `eta` (temperature
parameter for the sampling heuristic over hypervolume maximizers).
inequality_constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
equality_constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an equality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
Returns:
A `num_restarts x q' x d` tensor that can be used as initial conditions
for `optimize_acqf()`. Here `q' = q + num_fantasies * num_pareto` is the
total number of points (candidate points plus fantasy points).
Example:
>>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
>>> bounds = torch.tensor([[0., 0.], [1., 1.]])
>>> Xinit = gen_one_shot_hvkg_initial_conditions(
>>> qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
>>> options={"frac_random": 0.25},
>>> )
"""
from botorch.optim.optimize import optimize_acqf

options = options or {}
frac_random: float = options.get("frac_random", 0.1)
if not 0 < frac_random < 1:
raise ValueError(
f"frac_random must take on values in (0,1). Value: {frac_random}"
)

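# The value function scores a set of `num_pareto` points by the hypervolume
# of their objectives under the current posterior (the posterior mean, if
# `use_posterior_mean=True`).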
value_function = _get_hv_value_function(
model=acq_function.model,
ref_point=acq_function.ref_point,
objective=acq_function.objective,
sampler=acq_function.inner_sampler,
use_posterior_mean=acq_function.use_posterior_mean,
)

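# For multi-fidelity HVKG, fix the fidelity features to their target values,
# so that the hypervolume is maximized over the design dimensions only.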
is_mf_hvkg = isinstance(acq_function, qMultiFidelityHypervolumeKnowledgeGradient)
if is_mf_hvkg:
dim = bounds.shape[-1]
fidelity_dims, fidelity_targets = zip(*acq_function.target_fidelities.items())
value_function = FixedFeatureAcquisitionFunction(
acq_function=value_function,
d=dim,
columns=fidelity_dims,
values=fidelity_targets,
)

non_fidelity_dims = list(set(range(dim)) - set(fidelity_dims))

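# Learn hypervolume maximizing sets under the current posterior mean via
# multi-restart optimization of the value function.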
num_optim_restarts = int(round(num_restarts * (1 - frac_random)))
fantasy_cands, fantasy_vals = optimize_acqf(
acq_function=value_function,
bounds=bounds[:, non_fidelity_dims] if is_mf_hvkg else bounds,
q=acq_function.num_pareto,
num_restarts=options.get("num_inner_restarts", 20),
raw_samples=options.get("raw_inner_samples", 1024),
fixed_features=fixed_features,
return_best_only=False,
options=options,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
sequential=False,
)
# sampling from the optimizers
eta = options.get("eta", 2.0)
if num_optim_restarts > 0:
probs = torch.nn.functional.softmax(eta * standardize(fantasy_vals))
idx = torch.multinomial(
probs,
num_optim_restarts * acq_function.num_fantasies,
replacement=True,
)
optim_ics = fantasy_cands[idx]
if is_mf_hvkg:
# add fixed features
optim_ics = value_function.construct_X_full(optim_ics)
optim_ics = optim_ics.reshape(
num_optim_restarts, acq_function.num_pseudo_points, bounds.shape[-1]
)

# get random initial conditions
num_random_restarts = num_restarts - num_optim_restarts
if num_random_restarts > 0:
q_aug = acq_function.get_augmented_q_batch_size(q=q)
base_ics = gen_batch_initial_conditions(
acq_function=acq_function,
bounds=bounds,
q=q_aug,
num_restarts=num_restarts,
raw_samples=raw_samples,
fixed_features=fixed_features,
options=options,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
)

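# Splice the optimized fantasy sets into a randomly chosen subset of the
# restarts; the other restarts keep fully random fantasy points.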
if num_optim_restarts > 0:
probs = torch.full(
(num_restarts,),
1.0 / num_restarts,
dtype=optim_ics.dtype,
device=optim_ics.device,
)
optim_idxr = probs.multinomial(
num_samples=num_optim_restarts, replacement=False
)
base_ics[optim_idxr, q:] = optim_ics
else:
# optim_ics is num_restarts x num_pseudo_points x d
# add padding so that base_ics is num_restarts x q+num_pseudo_points x d
q_padding = torch.zeros(
optim_ics.shape[0],
q,
optim_ics.shape[-1],
dtype=optim_ics.dtype,
device=optim_ics.device,
)
base_ics = torch.cat([q_padding, optim_ics], dim=-2)

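# Rebuild the initial conditions restart by restart: restarts that received
# optimized fantasy sets keep them fixed while their `q` candidate points are
# re-initialized; purely random restarts are used as-is.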
if num_optim_restarts > 0:
all_ics = []
if num_random_restarts > 0:
optim_idcs = optim_idxr.view(-1).tolist()
else:
optim_idcs = list(range(num_restarts))
for i in list(range(num_restarts)):
if i in optim_idcs:
# optimize the q points,
# given fixed, optimized fantasy designs
ics = gen_batch_initial_conditions(
acq_function=acq_function,
bounds=bounds,
q=q,
num_restarts=1,
raw_samples=raw_samples,
fixed_features=fixed_features,
options=options,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
fixed_X_fantasies=base_ics[i, q:],
)
else:
# ics are all randomly sampled
ics = base_ics[i : i + 1]
all_ics.append(ics)
return torch.cat(all_ics, dim=0)

return base_ics
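
For intuition on the `eta` sampling heuristic above, here is a self-contained sketch of softmax-weighted selection over standardized values (toy numbers, not from the diff):

```python
import torch
from botorch.utils.transforms import standardize

fantasy_vals = torch.tensor([0.1, 0.5, 0.4, 0.9])  # toy hypervolume values
eta = 2.0
# Higher-value maximizing sets are proportionally more likely to seed restarts.
probs = torch.softmax(eta * standardize(fantasy_vals), dim=0)
idx = torch.multinomial(probs, num_samples=8, replacement=True)
```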


def gen_value_function_initial_conditions(
acq_function: AcquisitionFunction,
bounds: Tensor,
6 changes: 6 additions & 0 deletions botorch/optim/optimize.py
@@ -22,12 +22,16 @@
     OneShotAcquisitionFunction,
 )
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
+from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
+    qHypervolumeKnowledgeGradient,
+)
 from botorch.exceptions import InputDataError, UnsupportedError
 from botorch.exceptions.warnings import OptimizationWarning
 from botorch.generation.gen import gen_candidates_scipy, TGenCandidates
 from botorch.logging import logger
 from botorch.optim.initializers import (
     gen_batch_initial_conditions,
+    gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
     TGenInitialConditions,
 )
@@ -129,6 +133,8 @@ def get_ic_generator(self) -> TGenInitialConditions:
             return self.ic_generator
         elif isinstance(self.acq_function, qKnowledgeGradient):
             return gen_one_shot_kg_initial_conditions
+        elif isinstance(self.acq_function, qHypervolumeKnowledgeGradient):
+            return gen_one_shot_hvkg_initial_conditions
         return gen_batch_initial_conditions
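
Since `qMultiFidelityHypervolumeKnowledgeGradient` subclasses `qHypervolumeKnowledgeGradient` (as the `isinstance` check inside the new initializer suggests), this dispatch also routes the multi-fidelity variant to `gen_one_shot_hvkg_initial_conditions`.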

