Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ThompsonSampling acquisition function #2443

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions botorch/acquisition/thompson_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from botorch.acquisition.analytic import AcquisitionFunction
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.model import Model
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
from botorch.utils.transforms import t_batch_mode_transform
from torch import Tensor


BATCH_SIZE_CHANGE_ERROR = """The batch size of PathwiseThompsonSampling should \
not change during a forward pass - was {}, now {}. Please re-initialize the \
acquisition if you want to change the batch size."""


class PathwiseThompsonSampling(AcquisitionFunction):
r"""Single-outcome Thompson Sampling packaged as an (analytic)
acquisition function. Querying the acquisition function gives the summed
values of one or more draws from a pathwise drawn posterior sample, and thus
it maximization yields one (or multiple) Thompson sample(s).

Example:
>>> model = SingleTaskGP(train_X, train_Y)
>>> TS = PathwiseThompsonSampling(model)
"""

def __init__(
self,
model: Model,
posterior_transform: Optional[PosteriorTransform] = None,
) -> None:
r"""Single-outcome TS.

Args:
model: A fitted GP model.
posterior_transform: A PosteriorTransform. If using a multi-output model,
a PosteriorTransform that transforms the multi-output posterior into a
single-output posterior is required.
"""
if model._is_fully_bayesian:
raise NotImplementedError(
"PathwiseThompsonSampling is not supported for fully Bayesian models",
)

super().__init__(model=model)
self.batch_size: Optional[int] = None

def redraw(self) -> None:
self.samples = get_matheron_path_model(
model=self.model, sample_shape=torch.Size([self.batch_size])
)

@t_batch_mode_transform()
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate the pathwise posterior sample draws on the candidate set X.

Args:
X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.

Returns:
A `(b1 x ... bk) x [num_models for fully bayesian]`-dim tensor of
evaluations on the posterior sample draws.
"""
batch_size = X.shape[-2]
q_dim = -2

# batch_shape x q x 1 x d
X = X.unsqueeze(-2)
if self.batch_size is None:
self.batch_size = batch_size
self.redraw()
elif self.batch_size != batch_size:
raise ValueError(
BATCH_SIZE_CHANGE_ERROR.format(self.batch_size, batch_size)
)

# posterior_values.shape post-squeeze:
# batch_shape x q x m
posterior_values = self.samples(X).squeeze(-2)
# sum over batch dim and squeeze num_objectives dim (-1)
return posterior_values.sum(q_dim).squeeze(-1)
5 changes: 5 additions & 0 deletions sphinx/source/acquisition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ Risk Measures
.. automodule:: botorch.acquisition.risk_measures
:members:

Thompson Sampling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.thompson_sampling
:members:

Multi-Output Risk Measures
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.multi_output_risk_measures
Expand Down
136 changes: 136 additions & 0 deletions test/acquisition/test_thompson_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from itertools import product

import torch
from botorch.acquisition.thompson_sampling import PathwiseThompsonSampling
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model import Model
from botorch.models.transforms.outcome import Standardize
from botorch.utils.testing import BotorchTestCase


def get_model(train_X, train_Y, standardize_model):
if standardize_model:
outcome_transform = Standardize(m=1)

else:
outcome_transform = None
model = SingleTaskGP(
train_X=train_X,
train_Y=train_Y,
outcome_transform=outcome_transform,
)
return model


def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):

mcmc_samples = {
"lengthscale": torch.rand(num_samples, 1, dim, **tkwargs),
"outputscale": torch.rand(num_samples, **tkwargs),
"mean": torch.randn(num_samples, **tkwargs),
}
if infer_noise:
mcmc_samples["noise"] = torch.rand(num_samples, 1, **tkwargs)
return mcmc_samples


def get_fully_bayesian_model(
train_X,
train_Y,
num_models,
**tkwargs,
):

model = SaasFullyBayesianSingleTaskGP(
train_X=train_X,
train_Y=train_Y,
)
mcmc_samples = _get_mcmc_samples(
num_samples=num_models,
dim=train_X.shape[-1],
infer_noise=True,
**tkwargs,
)
model.load_mcmc_samples(mcmc_samples)
return model


class TestPathwiseThompsonSampling(BotorchTestCase):
def _test_thompson_sampling_base(self, model: Model):
acq = PathwiseThompsonSampling(
model=model,
)
X_observed = model.train_inputs[0]
input_dim = X_observed.shape[-1]
test_X = torch.rand(4, 1, input_dim).to(X_observed)
# re-draw samples and expect other output
acq_pass = acq(test_X)
self.assertTrue(acq_pass.shape == test_X.shape[:-2])

acq_pass1 = acq(test_X)
self.assertAllClose(acq_pass1, acq(test_X))
acq.redraw()
acq_pass2 = acq(test_X)
self.assertFalse(torch.allclose(acq_pass1, acq_pass2))

def _test_thompson_sampling_batch(self, model: Model):
X_observed = model.train_inputs[0]
input_dim = X_observed.shape[-1]
batch_acq = PathwiseThompsonSampling(
model=model,
)
self.assertEqual(batch_acq.batch_size, None)
test_X = torch.rand(4, 5, input_dim).to(X_observed)
batch_acq(test_X)
self.assertEqual(batch_acq.batch_size, 5)
test_X = torch.rand(4, 7, input_dim).to(X_observed)
with self.assertRaisesRegex(
ValueError,
"The batch size of PathwiseThompsonSampling should not "
"change during a forward pass - was 5, now 7. Please re-initialize "
"the acquisition if you want to change the batch size.",
):
batch_acq(test_X)

batch_acq2 = PathwiseThompsonSampling(model)
test_X = torch.rand(4, 7, 1, input_dim).to(X_observed)
self.assertEqual(batch_acq2(test_X).shape, test_X.shape[:-2])

batch_acq3 = PathwiseThompsonSampling(model)
test_X = torch.rand(4, 7, 3, input_dim).to(X_observed)
self.assertEqual(batch_acq3(test_X).shape, test_X.shape[:-2])

def test_thompson_sampling_single_task(self):
input_dim = 2
num_objectives = 1
for dtype, standardize_model in product(
(torch.float32, torch.float64), (True, False)
):
tkwargs = {"device": self.device, "dtype": dtype}
train_X = torch.rand(4, input_dim, **tkwargs)
train_Y = 10 * torch.rand(4, num_objectives, **tkwargs)
model = get_model(train_X, train_Y, standardize_model=standardize_model)
self._test_thompson_sampling_base(model)
self._test_thompson_sampling_batch(model)

def test_thompson_sampling_fully_bayesian(self):
input_dim = 2
num_objectives = 1
tkwargs = {"device": self.device, "dtype": torch.float64}
train_X = torch.rand(4, input_dim, **tkwargs)
train_Y = 10 * torch.rand(4, num_objectives, **tkwargs)

fb_model = get_fully_bayesian_model(train_X, train_Y, num_models=3, **tkwargs)
with self.assertRaisesRegex(
NotImplementedError,
"PathwiseThompsonSampling is not supported for fully Bayesian models",
):
PathwiseThompsonSampling(model=fb_model)
Loading