From 24fba30bc7e19240a0cf1c0fd45412dfaf4ae79b Mon Sep 17 00:00:00 2001
From: Sebastian Ament
Date: Tue, 13 Jun 2023 15:35:12 -0700
Subject: [PATCH] LinearEllipticalSliceSampler with fixed feature constraints (#1883)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/1883

This adds support for fixed feature equality constraints to
`LinearEllipticalSliceSampler`.

Differential Revision: D46613288

fbshipit-source-id: 6a5a53d2ebe2ad831406b8df73bdf69d35c4cea2
---
 botorch/utils/probability/lin_ess.py   | 215 +++++++++++++++++++------
 test/utils/probability/test_lin_ess.py | 117 +++++++++++---
 2 files changed, 264 insertions(+), 68 deletions(-)

diff --git a/botorch/utils/probability/lin_ess.py b/botorch/utils/probability/lin_ess.py
index 010c92e8cd..e432cbaf34 100644
--- a/botorch/utils/probability/lin_ess.py
+++ b/botorch/utils/probability/lin_ess.py
@@ -22,7 +22,7 @@
 from __future__ import annotations
 
 import math
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from botorch.utils.sampling import PolytopeSampler
@@ -34,13 +34,10 @@
 class LinearEllipticalSliceSampler(PolytopeSampler):
     r"""Linear Elliptical Slice Sampler.
 
-    TODOs:
-    - clean up docstrings
+    Ideas:
     - Add batch support, broadcasting over parallel chains.
-    - optimize computations (if possible)
-
-    Maybe TODOs:
-    - Support degenerate domains (with zero volume)?
+    - Optimize computations if possible, potentially with torch.compile.
+    - Extend fixed features constraint to general linear equality constraints.
     """
 
     def __init__(
@@ -54,6 +51,7 @@ def __init__(
         check_feasibility: bool = False,
         burnin: int = 0,
         thinning: int = 0,
+        fixed_indices: Optional[Union[List[int], Tensor]] = None,
     ) -> None:
         r"""Initialize LinearEllipticalSliceSampler.
 
@@ -82,6 +80,10 @@ def __init__(
             burnin: Number of samples to generate upon initialization to warm up the
                 sampler.
             thinning: Number of samples to skip before returning a sample in `draw`.
+            fixed_indices: Integer list or `d`-dim Tensor representing the indices of
+                dimensions that are constrained to be fixed to the values specified in
+                the `interior_point`, which is required to be passed in conjunction with
+                `fixed_indices`.
 
         This sampler samples from a multivariante Normal `N(mean, covariance_matrix)`
         subject to linear domain constraints `A x <= b` (intersected with box bounds,
@@ -94,13 +96,28 @@
             bounds=bounds,
         )
         tkwargs = {"device": self.x0.device, "dtype": self.x0.dtype}
-        self._mean = mean
-        if covariance_matrix is not None:
-            if covariance_root is not None:
-                raise ValueError(
-                    "Provide either covariance_matrix or covariance_root, not both."
-                )
+        if covariance_matrix is not None and covariance_root is not None:
+            raise ValueError(
+                "Provide either covariance_matrix or covariance_root, not both."
+            )
+
+        # can't unpack inequality constraints directly if bounds are passed
+        A, b = self.A, self.b
+        self._Az, self._bz = A, b
+        self._is_fixed, self._not_fixed = None, None
+        mean, covariance_matrix = self._fixed_features_initialization(
+            A=A,
+            b=b,
+            interior_point=interior_point,
+            mean=mean,
+            covariance_matrix=covariance_matrix,
+            covariance_root=covariance_root,
+            fixed_indices=fixed_indices,
+        )
+        self._mean = mean
+        # Have to delay factorization until after fixed features initialization.
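+        # Factorizing the covariance into a lower-triangular Cholesky root lets the
+        # sampler operate on a standard Normal internally (see `_standardize` below).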
+        if covariance_matrix is not None:  # implies root is None
             covariance_root, info = torch.linalg.cholesky_ex(covariance_matrix)
             not_psd = torch.any(info)
             if not_psd:
                 raise ValueError(
@@ -110,19 +127,12 @@ def __init__(
                 )
         self._covariance_root = covariance_root
 
-        # For non-standard mean and covariance, we're going to rewrite the problem as
-        # sampling from a standard normal distribution subject to modified constraints.
-        # Ax - b = A @ (covar_root @ z + mean) - b
-        #        = (A @ covar_root) @ z - (b - A @ mean)
-        #        = _Az @ z - _bz
-        self._Az = (
-            self.A if self._covariance_root is None else self.A @ self._covariance_root
-        )
-        self._bz = self.b if self._mean is None else self.b - self.A @ self._mean
+        # Rewrite the constraints as a system that constrains a standard Normal.
+        self._standardization_initialization()
 
         # state of the sampler ("current point")
         self._x = self.x0.clone()
-        self._z = self._standardize(self._x)
+        self._z = self._transform(self._x)
 
         # We will need the following repeatedly, let's allocate them once
         self._zero = torch.zeros(1, **tkwargs)
@@ -135,6 +145,59 @@ def __init__(
         self.draw(burnin)
         self.thinning = thinning
 
+    def _fixed_features_initialization(
+        self,
+        A: Tensor,
+        b: Tensor,
+        interior_point: Optional[Tensor],
+        mean: Optional[Tensor],
+        covariance_matrix: Optional[Tensor],
+        covariance_root: Optional[Tensor],
+        fixed_indices: Optional[Union[List[int], Tensor]],
+    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
+        if fixed_indices is not None:
+            if interior_point is None:
+                raise ValueError(
+                    "If `fixed_indices` are provided, an interior point must also be "
+                    "provided in order to infer feasible values of the fixed features."
+                )
+            if covariance_root is not None:
+                raise ValueError(
+                    "Provide either covariance_root or fixed_indices, not both."
+                )
+            d = interior_point.shape[0]
+            is_fixed, not_fixed = get_index_tensors(fixed_indices=fixed_indices, d=d)
+            self._is_fixed = is_fixed
+            self._not_fixed = not_fixed
+            # Transforming constraint system to incorporate fixed features:
+            # A @ x - b = (A[:, fixed] @ x[fixed] + A[:, not fixed] @ x[not fixed]) - b
+            #           = A[:, not fixed] @ x[not fixed] - (b - A[:, fixed] @ x[fixed])
+            #           = Az @ z - bz
+            self._Az = A[:, not_fixed]
+            self._bz = b - A[:, is_fixed] @ interior_point[is_fixed]
+            if mean is not None:
+                mean = mean[not_fixed]
+            if covariance_matrix is not None:  # subselect active dimensions
+                covariance_matrix = covariance_matrix[
+                    not_fixed.unsqueeze(-1), not_fixed.unsqueeze(0)
+                ]
+        return mean, covariance_matrix
+
+    def _standardization_initialization(self) -> None:
+        """For non-standard mean and covariance, we're going to rewrite the problem as
+        sampling from a standard normal distribution subject to modified constraints.
+        A @ x - b = A @ (covar_root @ z + mean) - b
+                  = (A @ covar_root) @ z - (b - A @ mean)
+                  = _Az @ z - _bz
+        NOTE: We need to standardize bz before Az in the following, because it relies
+        on the untransformed Az. We can't simply use A instead because Az might have
+        been subject to the fixed features transformation.
+ """ + if self._mean is not None: + self._bz = self._bz - self._Az @ self._mean + if self._covariance_root is not None: + self._Az = self._Az @ self._covariance_root + @property def lifetime_samples(self) -> int: """The total number of samples generated by the sampler during its lifetime.""" @@ -156,22 +219,6 @@ def draw(self, n: int = 1) -> Tuple[Tensor, Tensor]: samples.append(self.step()) return torch.cat(samples, dim=-1).transpose(-1, -2) - def _unstandardize(self, z: Tensor) -> Tensor: - x = z - if self._covariance_root is not None: - x = self._covariance_root @ x - if self._mean is not None: - x = x + self._mean - return x - - def _standardize(self, x: Tensor) -> Tensor: - z = x - if self._mean is not None: - z = z - self._mean - if self._covariance_root is not None: - z = torch.linalg.solve_triangular(self._covariance_root, z, upper=False) - return z - def step(self) -> Tensor: r"""Take a step, return the new sample, update the internal state. @@ -182,7 +229,7 @@ def step(self) -> Tensor: theta = self._draw_angle(nu=nu) z = self._get_cart_coords(nu=nu, theta=theta) self._z[:] = z - x = self._unstandardize(z) + x = self._untransform(z) self._x[:] = x self._lifetime_samples += 1 if self.check_feasibility and (not self._is_feasible(self._x)): @@ -347,25 +394,103 @@ def _active_theta_and_delta(self, nu: Tensor, theta: Tensor) -> Tensor: theta_mid = torch.cat((last_mid, theta_mid, last_mid), dim=0) samples_mid = self._get_cart_coords(nu=nu, theta=theta_mid) delta_feasibility = ( - self._is_feasible(samples_mid, standardized=True).to(dtype=int).diff() + self._is_feasible(samples_mid, transformed=True).to(dtype=int).diff() ) active_indices = delta_feasibility.nonzero() return theta[active_indices], delta_feasibility[active_indices] - def _is_feasible(self, points: Tensor, standardized: bool = False) -> Tensor: + def _is_feasible(self, points: Tensor, transformed: bool = False) -> Tensor: r"""Returns a Boolean tensor indicating whether the `points` are feasible, i.e. they satisfy `A @ points <= b`, where `(A, b)` are the tensors passed as the `inequality_constraints` to the constructor of the sampler. Args: points: A `d x M`-dim tensor of points. - standardized: Wether points are assumed to be standardized by a change of + transformed: Wether points are assumed to be transformed by a change of basis, which means feasibility should be computed based on the - standardized constraint system (_A, _b), instead of (A, b). + transformed constraint system (_Az, _bz), instead of (A, b). Returns: An `M`-dim binary tensor where `True` indicates that the associated point is feasible. """ - A, b = (self._Az, self._bz) if standardized else (self.A, self.b) + A, b = (self._Az, self._bz) if transformed else (self.A, self.b) return (A @ points <= b).all(dim=0) + + def _transform(self, x: Tensor) -> Tensor: + """Transforms the input so that it is equivalent to a standard Normal variable + constrained with the modified system constraints (self._Az, self._bz). + + Args: + x: The input tensor to be transformed, usually `d x 1`-dimensional. + + Returns: + A `d x 1`-dimensional tensor of transformed values subject to the modified + system of constraints. + """ + if self._not_fixed is not None: + x = x[self._not_fixed] + return self._standardize(x) + + def _untransform(self, z: Tensor) -> Tensor: + """The inverse transform of the `_transform`, i.e. maps `z` back to the original + space where it is subject to the original constraint system (self.A, self.b). 
+
+        Args:
+            z: The transformed tensor to be un-transformed, usually `d x 1`-dimensional.
+
+        Returns:
+            A `d x 1`-dimensional tensor of un-transformed values subject to the
+            original system of constraints.
+        """
+        if self._is_fixed is None:
+            return self._unstandardize(z)
+        else:
+            x = self._x.clone()  # _x already contains the fixed values
+            x[self._not_fixed] = self._unstandardize(z)
+            return x
+
+    def _standardize(self, x: Tensor) -> Tensor:
+        """_transform helper standardizing the input `x`, which is assumed to be a
+        `d x 1`-dim Tensor, or a `len(self._not_fixed) x 1`-dim if there are fixed
+        indices.
+        """
+        z = x
+        if self._mean is not None:
+            z = z - self._mean
+        if self._covariance_root is not None:
+            z = torch.linalg.solve_triangular(self._covariance_root, z, upper=False)
+        return z
+
+    def _unstandardize(self, z: Tensor) -> Tensor:
+        """_untransform helper un-standardizing the input `z`, which is assumed to be a
+        `d x 1`-dim Tensor, or a `len(self._not_fixed) x 1`-dim if there are fixed
+        indices.
+        """
+        x = z
+        if self._covariance_root is not None:
+            x = self._covariance_root @ x
+        if self._mean is not None:
+            x = x + self._mean
+        return x
+
+
+def get_index_tensors(
+    fixed_indices: Union[List[int], Tensor], d: int
+) -> Tuple[Tensor, Tensor]:
+    """Converts `fixed_indices` into a pair of integer Tensors: the indices that are
+    fixed and the complementary indices that are not fixed.
+
+    Args:
+        fixed_indices: A list or Tensor of integer indices to fix.
+        d: The dimensionality of the Tensors to be indexed.
+
+    Returns:
+        A Tuple of integral Tensors partitioning `{0, ..., d - 1}` into indices that
+        are fixed (first tensor) and non-fixed (second tensor).
+    """
+    is_fixed = torch.as_tensor(fixed_indices)
+    dtype, device = is_fixed.dtype, is_fixed.device
+    dims = torch.arange(d, dtype=dtype, device=device)
+    not_fixed = torch.tensor([i for i in dims if i not in is_fixed])
+    return is_fixed, not_fixed
diff --git a/test/utils/probability/test_lin_ess.py b/test/utils/probability/test_lin_ess.py
index 105e3e0c82..5565509730 100644
--- a/test/utils/probability/test_lin_ess.py
+++ b/test/utils/probability/test_lin_ess.py
@@ -6,6 +6,8 @@
 from __future__ import annotations
 
+import itertools
+
 import math
 
 from unittest.mock import patch
@@ -14,6 +16,7 @@
 from botorch.exceptions.errors import BotorchError
 from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler
 from botorch.utils.testing import BotorchTestCase
+from torch import Tensor
 
 
 class TestLinearEllipticalSliceSampler(BotorchTestCase):
@@ -162,8 +165,8 @@ def test_bivariate(self):
         self.assertFalse(torch.equal(sampler._x, sampler.x0))
 
     def test_multivariate(self):
-        for dtype in (torch.float, torch.double):
-            d = 3
+        for dtype, atol in zip((torch.float, torch.double), (1e-6, 1e-12)):
+            d = 5
             tkwargs = {"device": self.device, "dtype": dtype}
             # special case: N(0, I) truncated to greater than lower_bound
             A = -torch.eye(d, **tkwargs)
@@ -207,30 +210,58 @@ def test_multivariate(self):
             # normalizing to maximal unit variance so that sem math below applies
             cov_matrix /= cov_matrix.max()
             interior_point = torch.ones_like(mean)
-            for mean_i, cov_i in [
+            means_and_covs = [
                 (None, None),
                 (mean, None),
                 (None, cov_matrix),
                 (mean, cov_matrix),
-            ]:
-                with self.subTest(mean=mean_i, cov=cov_i):
+            ]
+            fixed_indices = [None, [1, 3]]
+            for (mean_i, cov_i), ff_i in itertools.product(
+                means_and_covs,
+                fixed_indices,
+            ):
+                with self.subTest(mean=mean_i, cov=cov_i, fixed_indices=ff_i):
                     sampler = LinearEllipticalSliceSampler(
                        inequality_constraints=(A, b),
                        interior_point=interior_point,
                        check_feasibility=True,
                        mean=mean_i,
                        covariance_matrix=cov_i,
+                        fixed_indices=ff_i,
                    )
                    # checking standardized system of constraints
-                    mean_i = torch.zeros_like(mean) if mean_i is None else mean_i
-                    cov_root_i = (
-                        torch.eye(d, **tkwargs)
-                        if cov_i is None
-                        else torch.linalg.cholesky_ex(cov_i)[0]
-                    )
-                    self.assertAllClose(sampler._Az, A @ cov_root_i)
-                    self.assertAllClose(sampler._bz, b - A @ mean_i)
+                    mean_i = torch.zeros(d, 1, **tkwargs) if mean_i is None else mean_i
+                    cov_i = torch.eye(d, **tkwargs) if cov_i is None else cov_i
+
+                    # Transform the system to incorporate equality constraints and non-
+                    # standard mean and covariance.
+                    Az_i, bz_i = A, b
+                    if ff_i is None:
+                        is_fixed = []
+                        not_fixed = range(d)
+                    else:
+                        is_fixed = sampler._is_fixed
+                        not_fixed = sampler._not_fixed
+                        self.assertIsInstance(is_fixed, Tensor)
+                        self.assertIsInstance(not_fixed, Tensor)
+                        self.assertEqual(is_fixed.shape, (len(ff_i),))
+                        self.assertEqual(not_fixed.shape, (d - len(ff_i),))
+                        self.assertTrue(all(i in ff_i for i in is_fixed))
+                        self.assertFalse(any(i in ff_i for i in not_fixed))
+                        # Modifications to constraint system
+                        Az_i = A[:, not_fixed]
+                        bz_i = b - A[:, is_fixed] @ interior_point[is_fixed]
+                        mean_i = mean_i[not_fixed]
+                        cov_i = cov_i[not_fixed.unsqueeze(-1), not_fixed.unsqueeze(0)]
+
+                    cov_root_i = torch.linalg.cholesky_ex(cov_i)[0]
+                    bz_i = bz_i - Az_i @ mean_i
+                    Az_i = Az_i @ cov_root_i
+                    self.assertAllClose(sampler._Az, Az_i, atol=atol)
+                    self.assertAllClose(sampler._bz, bz_i, atol=atol)
+                    # testing standardization of non-fixed elements
                    x = torch.randn_like(mean_i)
                    z = sampler._standardize(x)
                    self.assertAllClose(
                        z,
                        torch.linalg.solve_triangular(
                            cov_root_i, x - mean_i, upper=False
                        ),
+                        atol=atol,
                    )
-                    self.assertAllClose(sampler._unstandardize(z), x)
+                    self.assertAllClose(sampler._unstandardize(z), x, atol=atol)
+
+                    # testing transformation
+                    x = torch.randn(d, 1, **tkwargs)
+                    x[is_fixed] = interior_point[is_fixed]  # fixed dimensions
+                    z = sampler._transform(x)
+                    self.assertAllClose(
+                        z,
+                        torch.linalg.solve_triangular(
+                            cov_root_i, x[not_fixed] - mean_i, upper=False
+                        ),
+                        atol=atol,
+                    )
+                    self.assertAllClose(sampler._untransform(z), x, atol=atol)
 
                    # checking rejection-free property
                    num_samples = 32
@@ -252,22 +297,27 @@ def test_multivariate(self):
                    # of 5 sigma is 1 in 1.76 million.
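+                    # i.e., the tolerance below allows five standard errors of the
+                    # (at most unit-variance) sample mean.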
                    # sem ~ 0.7 -> can differentiate from zero mean
                    sem = 5 / math.sqrt(num_samples)
-                    sample_mean = samples.mean(dim=0)
-                    self.assertAllClose(sample_mean, mean_i.squeeze(-1), atol=sem)
+                    sample_mean = samples.mean(dim=0).unsqueeze(-1)
+                    self.assertAllClose(sample_mean[not_fixed], mean_i, atol=sem)
+                    # testing the samples have correctly fixed features
+                    self.assertTrue(
+                        torch.equal(sample_mean[is_fixed], interior_point[is_fixed])
+                    )
 
-                    # checking that standardization does not change feasibility values
+                    # checking that transformation does not change feasibility values
                    X_test = 3 * torch.randn(d, num_samples, **tkwargs)
+                    X_test[is_fixed] = interior_point[is_fixed]
                    self.assertAllClose(
-                        sampler._Az @ sampler._standardize(X_test) - sampler._bz,
+                        sampler._Az @ sampler._transform(X_test) - sampler._bz,
                        A @ X_test - b,
-                        atol=1e-5,
+                        atol=atol,
                    )
                    self.assertAllClose(
                        sampler._is_feasible(
-                            sampler._standardize(X_test), standardized=True
+                            sampler._transform(X_test), transformed=True
                        ),
-                        sampler._is_feasible(X_test, standardized=False),
-                        atol=1e-5,
+                        sampler._is_feasible(X_test, transformed=False),
+                        atol=atol,
                    )
 
                    # thining and burn-in tests
@@ -311,7 +361,7 @@ def test_multivariate(self):
            rot_angle, slices = sampler._find_rotated_intersections(nu)
            self.assertEqual(rot_angle, 0.0)
            self.assertAllClose(
-                slices, torch.tensor([[0.0, 2 * torch.pi]], **tkwargs), atol=1e-6
+                slices, torch.tensor([[0.0, 2 * torch.pi]], **tkwargs), atol=atol
            )
 
            # 2) testing tangential intersection of ellipse with constraint
@@ -342,6 +392,27 @@ def test_multivariate(self):
            ):
                sampler.step()
 
+            # testing error for fixed features with no interior point
+            with self.assertRaisesRegex(
+                ValueError,
+                ".*an interior point must also be provided in order to infer feasible ",
+            ):
+                LinearEllipticalSliceSampler(
+                    inequality_constraints=(A, b),
+                    fixed_indices=[0],
+                )
+
+            with self.assertRaisesRegex(
+                ValueError,
+                "Provide either covariance_root or fixed_indices, not both.",
+            ):
+                LinearEllipticalSliceSampler(
+                    inequality_constraints=(A, b),
+                    interior_point=interior_point,
+                    fixed_indices=[0],
+                    covariance_root=torch.eye(d, **tkwargs),
+                )
+
            # high dimensional test case
            d = 128
            # this encodes order constraints on all d variables: Ax < b