LinearEllipticalSliceSampler with fixed feature constraints (#1883)
Summary:
Pull Request resolved: #1883

This adds support for fixed feature equality constraints to `LinearEllipticalSliceSampler`.
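For illustration, a minimal usage sketch of the new argument (not part of the commit; the two-dimensional toy domain and values are hypothetical, and the arguments follow the constructor signature in the diff below):

import torch
from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler

# Unit-square domain A @ x <= b, with the first coordinate fixed to its
# value in `interior_point` (0.5) via the new `fixed_indices` argument.
A = torch.tensor([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
b = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
sampler = LinearEllipticalSliceSampler(
    inequality_constraints=(A, b),
    interior_point=torch.tensor([[0.5], [0.5]]),
    fixed_indices=[0],
)
samples = sampler.draw(n=8)  # 8 x 2 tensor; column 0 is 0.5 in every row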

Differential Revision: D46613288

fbshipit-source-id: 6a5a53d2ebe2ad831406b8df73bdf69d35c4cea2
SebastianAment authored and facebook-github-bot committed Jun 13, 2023
1 parent e7b1994 commit 24fba30
Showing 2 changed files with 264 additions and 68 deletions.
215 changes: 170 additions & 45 deletions botorch/utils/probability/lin_ess.py
@@ -22,7 +22,7 @@
from __future__ import annotations

import math
from typing import List, Optional, Tuple, Union

import torch
from botorch.utils.sampling import PolytopeSampler
@@ -34,13 +34,10 @@
class LinearEllipticalSliceSampler(PolytopeSampler):
r"""Linear Elliptical Slice Sampler.
Ideas:
- Add batch support, broadcasting over parallel chains.
Maybe TODOs:
- Support degenerate domains (with zero volume)?
- Optimize computations if possible, potentially with torch.compile.
- Extend fixed features constraint to general linear equality constraints.
"""

def __init__(
@@ -54,6 +51,7 @@ def __init__(
check_feasibility: bool = False,
burnin: int = 0,
thinning: int = 0,
fixed_indices: Optional[Union[List[int], Tensor]] = None,
) -> None:
r"""Initialize LinearEllipticalSliceSampler.
@@ -82,6 +80,10 @@ def __init__(
burnin: Number of samples to generate upon initialization to warm up the
sampler.
thinning: Number of samples to skip before returning a sample in `draw`.
fixed_indices: Integer list or `d`-dim Tensor representing the indices of
dimensions that are constrained to be fixed to the values specified in
the `interior_point`, which is required to be passed in conjunction with
`fixed_indices`.
This sampler samples from a multivariate Normal `N(mean, covariance_matrix)`
subject to linear domain constraints `A x <= b` (intersected with box bounds,
@@ -94,13 +96,28 @@
bounds=bounds,
)
tkwargs = {"device": self.x0.device, "dtype": self.x0.dtype}
if covariance_matrix is not None and covariance_root is not None:
raise ValueError(
"Provide either covariance_matrix or covariance_root, not both."
)

# can't unpack inequality constraints directly if bounds are passed
A, b = self.A, self.b
self._Az, self._bz = A, b
self._is_fixed, self._not_fixed = None, None
mean, covariance_matrix = self._fixed_features_initialization(
A=A,
b=b,
interior_point=interior_point,
mean=mean,
covariance_matrix=covariance_matrix,
covariance_root=covariance_root,
fixed_indices=fixed_indices,
)

self._mean = mean
# Have to delay factorization until after fixed features initialization.
if covariance_matrix is not None: # implies root is None
covariance_root, info = torch.linalg.cholesky_ex(covariance_matrix)
not_psd = torch.any(info)
if not_psd:
@@ -110,19 +127,12 @@ def __init__(
)
self._covariance_root = covariance_root

# Rewrite the constraints as a system that constrains a standard Normal.
self._standardization_initialization()

# state of the sampler ("current point")
self._x = self.x0.clone()
self._z = self._transform(self._x)

# We will need the following repeatedly, let's allocate them once
self._zero = torch.zeros(1, **tkwargs)
@@ -135,6 +145,59 @@ def __init__(
self.draw(burnin)
self.thinning = thinning

def _fixed_features_initialization(
self,
A: Tensor,
b: Tensor,
interior_point: Optional[Tensor],
mean: Optional[Tensor],
covariance_matrix: Optional[Tensor],
covariance_root: Optional[Tensor],
fixed_indices: Optional[Union[List[int], Tensor]],
) -> Tuple[Optional[Tensor], Optional[Tensor]]:
if fixed_indices is not None:
if interior_point is None:
raise ValueError(
"If `fixed_indices` are provided, an interior point must also be "
"provided in order to infer feasible values of the fixed features."
)
if covariance_root is not None:
raise ValueError(
"Provide either covariance_root or fixed_indices, not both."
)
d = interior_point.shape[0]
is_fixed, not_fixed = get_index_tensors(fixed_indices=fixed_indices, d=d)
self._is_fixed = is_fixed
self._not_fixed = not_fixed
# Transforming constraint system to incorporate fixed features:
# A @ x - b = (A[:, fixed] @ x[fixed] + A[:, not fixed] @ x[not fixed]) - b
# = A[:, not fixed] @ x[not fixed] - (b - A[:, fixed] @ x[fixed])
# = Az @ z - bz
self._Az = A[:, not_fixed]
self._bz = b - A[:, is_fixed] @ interior_point[is_fixed]
if mean is not None:
mean = mean[not_fixed]
if covariance_matrix is not None: # subselect active dimensions
covariance_matrix = covariance_matrix[
not_fixed.unsqueeze(-1), not_fixed.unsqueeze(0)
]
return mean, covariance_matrix
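To make the rewrite above concrete, here is a small standalone check (illustrative only, with made-up values) that the reduced system `(Az, bz)` agrees with the original `(A, b)` once the fixed coordinate is substituted in:

import torch

d = 3
A, b = torch.randn(4, d), torch.randn(4, 1)
x = torch.randn(d, 1)  # x[1] plays the role of the fixed feature
is_fixed, not_fixed = torch.tensor([1]), torch.tensor([0, 2])
Az = A[:, not_fixed]
bz = b - A[:, is_fixed] @ x[is_fixed]
# Same constraint residuals, so feasibility is preserved:
assert torch.allclose(A @ x - b, Az @ x[not_fixed] - bz, atol=1e-6)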

def _standardization_initialization(self) -> None:
"""For non-standard mean and covariance, we're going to rewrite the problem as
sampling from a standard normal distribution subject to modified constraints.
A @ x - b = A @ (covar_root @ z + mean) - b
= (A @ covar_root) @ z - (b - A @ mean)
= _Az @ z - _bz
NOTE: We need to standardize bz before Az in the following, because it relies
on the untransformed Az. We can't simply use A instead because Az might have
been subject to the fixed features transformation.
"""
if self._mean is not None:
self._bz = self._bz - self._Az @ self._mean
if self._covariance_root is not None:
self._Az = self._Az @ self._covariance_root
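The identity in the docstring can be verified directly; a quick sketch with arbitrary values (not part of the commit):

import torch

d, k = 3, 5
A, b = torch.randn(k, d), torch.randn(k, 1)
mean = torch.randn(d, 1)
root = torch.linalg.cholesky(torch.eye(d) + 0.1 * torch.ones(d, d))
z = torch.randn(d, 1)
x = root @ z + mean
# bz must be formed from the pre-transformation Az (here, A itself):
Az, bz = A @ root, b - A @ mean
assert torch.allclose(A @ x - b, Az @ z - bz, atol=1e-6)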

@property
def lifetime_samples(self) -> int:
"""The total number of samples generated by the sampler during its lifetime."""
@@ -156,22 +219,6 @@ def draw(self, n: int = 1) -> Tuple[Tensor, Tensor]:
samples.append(self.step())
return torch.cat(samples, dim=-1).transpose(-1, -2)

def step(self) -> Tensor:
r"""Take a step, return the new sample, update the internal state.
@@ -182,7 +229,7 @@ def step(self) -> Tensor:
theta = self._draw_angle(nu=nu)
z = self._get_cart_coords(nu=nu, theta=theta)
self._z[:] = z
x = self._untransform(z)
self._x[:] = x
self._lifetime_samples += 1
if self.check_feasibility and (not self._is_feasible(self._x)):
@@ -347,25 +394,103 @@ def _active_theta_and_delta(self, nu: Tensor, theta: Tensor) -> Tensor:
theta_mid = torch.cat((last_mid, theta_mid, last_mid), dim=0)
samples_mid = self._get_cart_coords(nu=nu, theta=theta_mid)
delta_feasibility = (
self._is_feasible(samples_mid, transformed=True).to(dtype=int).diff()
)
active_indices = delta_feasibility.nonzero()
return theta[active_indices], delta_feasibility[active_indices]

def _is_feasible(self, points: Tensor, transformed: bool = False) -> Tensor:
r"""Returns a Boolean tensor indicating whether the `points` are feasible,
i.e. they satisfy `A @ points <= b`, where `(A, b)` are the tensors passed
as the `inequality_constraints` to the constructor of the sampler.
Args:
points: A `d x M`-dim tensor of points.
transformed: Whether points are assumed to be transformed by a change of
basis, which means feasibility should be computed based on the
transformed constraint system (_Az, _bz), instead of (A, b).
Returns:
An `M`-dim binary tensor where `True` indicates that the associated
point is feasible.
"""
A, b = (self._Az, self._bz) if transformed else (self.A, self.b)
return (A @ points <= b).all(dim=0)
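Since points are stored column-wise, a `d x M` batch is checked with a single matrix product; a toy example (values hypothetical):

import torch

A = torch.tensor([[1.0], [-1.0]])  # one-dimensional domain 0 <= x <= 1
b = torch.tensor([[1.0], [0.0]])
points = torch.tensor([[0.5, 2.0, -0.1]])  # d x M = 1 x 3
(A @ points <= b).all(dim=0)  # tensor([ True, False, False])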

def _transform(self, x: Tensor) -> Tensor:
"""Transforms the input so that it is equivalent to a standard Normal variable
constrained by the modified constraint system (self._Az, self._bz).
Args:
x: The input tensor to be transformed, usually `d x 1`-dimensional.
Returns:
A `d x 1`-dimensional tensor of transformed values subject to the modified
system of constraints.
"""
if self._not_fixed is not None:
x = x[self._not_fixed]
return self._standardize(x)

def _untransform(self, z: Tensor) -> Tensor:
"""The inverse transform of the `_transform`, i.e. maps `z` back to the original
space where it is subject to the original constraint system (self.A, self.b).
Args:
z: The transformed tensor to be un-transformed, usually `d x 1`-dimensional.
Returns:
A `d x 1`-dimensional tensor of un-transformed values subject to the
original system of constraints.
"""
if self._is_fixed is None:
return self._unstandardize(z)
else:
x = self._x.clone() # _x already contains the fixed values
x[self._not_fixed] = self._unstandardize(z)
return x

def _standardize(self, x: Tensor) -> Tensor:
"""_transform helper standardizing the input `x`, which is assumed to be a
`d x 1`-dim Tensor, or a `len(self._not_fixed) x 1`-dim if there are fixed
indices.
"""
z = x
if self._mean is not None:
z = z - self._mean
if self._covariance_root is not None:
z = torch.linalg.solve_triangular(self._covariance_root, z, upper=False)
return z

def _unstandardize(self, z: Tensor) -> Tensor:
"""_untransform helper un-standardizing the input `z`, which is assumed to be a
`d x 1`-dim Tensor, or a `len(self._not_fixed) x 1`-dim if there are fixed
indices.
"""
x = z
if self._covariance_root is not None:
x = self._covariance_root @ x
if self._mean is not None:
x = x + self._mean
return x
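`_standardize` and `_unstandardize` are mutual inverses whenever the covariance root is lower triangular, as produced by the Cholesky factorization above; a sketch of the round trip (illustrative values):

import torch

d = 4
mean = torch.randn(d, 1)
root = torch.linalg.cholesky(torch.eye(d) + 0.2 * torch.ones(d, d))
x = torch.randn(d, 1)
z = torch.linalg.solve_triangular(root, x - mean, upper=False)  # standardize
assert torch.allclose(root @ z + mean, x, atol=1e-6)            # unstandardize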


def get_index_tensors(
fixed_indices: Union[List[int], Tensor], d: int
) -> Tuple[Tensor, Tensor]:
"""Converts `fixed_indices` to a `d`-dim integral Tensor that is True at indices
that are contained in `fixed_indices` and False otherwise.
Args:
fixed_indices: A list or Tensoro of integer indices to fix.
d: The dimensionality of the Tensors to be indexed.
Returns:
A Tuple of integral Tensors partitioning [1, d] into indices that are fixed
(first tensor) and non-fixed (second tensor).
"""
is_fixed = torch.as_tensor(fixed_indices)
dtype, device = is_fixed.dtype, is_fixed.device
dims = torch.arange(d, dtype=dtype, device=device)
not_fixed = torch.tensor([i for i in dims if i not in is_fixed])
return is_fixed, not_fixed
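For example, given the function above:

>>> get_index_tensors(fixed_indices=[0, 2], d=4)
(tensor([0, 2]), tensor([1, 3]))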
