From ed2ecf1288489d6a507fb0b04777ceeec6b700fc Mon Sep 17 00:00:00 2001
From: Sam Daulton
Date: Mon, 3 Jul 2023 15:40:33 -0700
Subject: [PATCH] speed up LCE-A kernel (#1910)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/1910

X-link: https://github.com/facebook/Ax/pull/1694

The current implementation is very slow, which is particularly problematic when the number of contexts is large. This diff yields huge speed ups (>3 orders of magnitude) by removing the nested for loops in `forward` and using batched computation instead. It also caches the context_covar in eval mode, with the additional requirement that the parameters for each context are contiguous and in the same order.

* 14x speed up with 8 contexts and 1,154x speed up with 128 contexts (CPU)
* 22x speed up with 8 contexts and 3,370x speed up with 128 contexts (CUDA)
* Without this diff, a single forward pass with 128 contexts takes 5 seconds on CPU and 12 seconds on CUDA.

Current implementation:
* 8 contexts:
  * Forward pass: 20ms (CPU), 45.2ms (CUDA)
  * Roundtrip: 39.1ms (CPU), 99.7ms (CUDA)
* 128 contexts:
  * Forward pass: 5.08s (CPU), 12s (CUDA)
  * Roundtrip: 14.2s (CPU), 26.7s (CUDA)

New implementation:
* 8 contexts:
  * Forward pass: 1.44ms (CPU), 2.05ms (CUDA)
  * Roundtrip: 2.22ms (CPU), 4.65ms (CUDA)
* 128 contexts:
  * Forward pass: 4.4ms (CPU), 3.56ms (CUDA)
  * Roundtrip: 6.97ms (CPU), 5.34ms (CUDA)

Reviewed By: Balandat

Differential Revision: D47118335

fbshipit-source-id: efa091bf7ccc30559d921195e10c1375f525cca9
---
 botorch/models/contextual.py              |   3 +-
 botorch/models/kernels/contextual_lcea.py | 157 ++++++++++++++++------
 test/models/kernels/test_contextual.py    |  37 ++++-
 3 files changed, 150 insertions(+), 47 deletions(-)

diff --git a/botorch/models/contextual.py b/botorch/models/contextual.py
index e095893ab7..2ef486d50f 100644
--- a/botorch/models/contextual.py
+++ b/botorch/models/contextual.py
@@ -66,8 +66,7 @@ def __init__(
             train_Y: (n x 1) Y training data.
             train_Yvar: (n x 1) Noise variance of Y.
             decomposition: Keys are context names. Values are the indexes of
-                parameters belong to the context. The parameter indexes are in the
-                same order across contexts.
+                parameters that belong to the context.
             cat_feature_dict: Keys are context names and values are list of categorical
                 features i.e. {"context_name" : [cat_0, ..., cat_k]}, where k is the
                 number of categorical variables. If None, we use context names in the
diff --git a/botorch/models/kernels/contextual_lcea.py b/botorch/models/kernels/contextual_lcea.py
index e138a3014f..b8c960ac26 100644
--- a/botorch/models/kernels/contextual_lcea.py
+++ b/botorch/models/kernels/contextual_lcea.py
@@ -11,11 +11,78 @@ from gpytorch.kernels.kernel import Kernel
 from gpytorch.kernels.matern_kernel import MaternKernel
 from gpytorch.priors.torch_priors import GammaPrior
-from linear_operator.operators.sum_linear_operator import SumLinearOperator
+from linear_operator.operators import DiagLinearOperator
+from linear_operator.operators.dense_linear_operator import DenseLinearOperator
 from torch import Tensor
 from torch.nn import ModuleList
 
 
+def get_order(indices: List[int]) -> List[int]:
+    r"""Get the order of each index, as integers ranging from 0 to
+    `len(indices) - 1`.
+
+    Args:
+        indices: A list of parameter indices.
+
+    Returns:
+        A list where each entry is the corresponding index modulo
+        `len(indices)`, i.e. an integer ranging from 0 to `len(indices) - 1`.
+    """
+    return [i % len(indices) for i in indices]
+
+
+def is_contiguous(indices: List[int]) -> bool:
+    r"""Check if the list of integers is contiguous.
+
+    Args:
+        indices: A list of parameter indices.
+
+    Returns:
+        A boolean indicating whether the indices are contiguous.
+    """
+    min_idx = min(indices)
+    return set(indices) == set(range(min_idx, min_idx + len(indices)))
+
+
+def get_permutation(decomposition: Dict[str, List[int]]) -> Optional[List[int]]:
+    """Construct a permutation to reorder the parameters such that:
+
+    1) the parameters for each context are contiguous.
+    2) the parameters for each context are in the same order.
+
+    Args:
+        decomposition: A dictionary mapping context names to a list of
+            parameter indices.
+
+    Returns:
+        A permutation that reorders the parameters so that (1) and (2) hold.
+        A return value of `None` means that the ordering specified in
+        `decomposition` already satisfies (1) and (2).
+    """
+    permutation = None
+    if not all(
+        is_contiguous(indices=active_parameters)
+        for active_parameters in decomposition.values()
+    ):
+        permutation = _create_new_permutation(decomposition=decomposition)
+    else:
+        same_order = True
+        expected_order = get_order(indices=next(iter(decomposition.values())))
+        for active_parameters in decomposition.values():
+            order = get_order(indices=active_parameters)
+            if order != expected_order:
+                same_order = False
+                break
+        if not same_order:
+            permutation = _create_new_permutation(decomposition=decomposition)
+    return permutation
+
+
+def _create_new_permutation(decomposition: Dict[str, List[int]]) -> List[int]:
+    # make the parameter indices contiguous and ordered within each context
+    permutation = []
+    for active_parameters in decomposition.values():
+        sorted_indices = sorted(active_parameters)
+        permutation.extend(sorted_indices)
+    return permutation
+
+
 class LCEAKernel(Kernel):
     r"""The Latent Context Embedding Additive (LCE-A) Kernel.
 
@@ -40,8 +107,7 @@ def __init__(
         r"""
         Args:
             decomposition: Keys index context names. Values are the indexes of
-                parameters belong to the context. The parameter indexes are in the same
-                order across contexts.
+                parameters that belong to the context.
             batch_shape: Batch shape as usual for gpytorch kernels. Model does not
                 support batch training. When batch_shape is non-empty, it is used for
                 loading hyper-parameter values generated from MCMC sampling.
@@ -58,26 +124,25 @@ def __init__(
             context_weight_dict: Known population weights of each context.
         """
         super().__init__(batch_shape=batch_shape)
-        self.decomposition = decomposition
         self.batch_shape = batch_shape
         self.train_embedding = train_embedding
         self._device = device
 
-        num_param = len(next(iter(decomposition.values())))
+        self.num_param = len(next(iter(decomposition.values())))
         self.context_list = list(decomposition.keys())
         self.num_contexts = len(self.context_list)
 
         # get parameter space decomposition
         for active_parameters in decomposition.values():
             # check number of parameters are same in each decomp
-            if len(active_parameters) != num_param:
+            if len(active_parameters) != self.num_param:
                 raise ValueError(
-                    "num of parameters needs to be same across all contexts"
+                    "The number of parameters needs to be the same across all contexts."
                 )
-        self._indexers = {
-            context: torch.tensor(active_params, device=self.device)
-            for context, active_params in self.decomposition.items()
-        }
+        # reorder the parameter list based on the decomposition such that the
+        # parameters for each context are contiguous and in the same order
+        # across contexts
+        self.permutation = get_permutation(decomposition=decomposition)
         # get context features and set emb dim
         self.context_cat_feature = None
         self.context_emb_feature = None
@@ -102,7 +167,7 @@ def __init__(
         # base kernel
         self.base_kernel = MaternKernel(
             nu=2.5,
-            ard_num_dims=num_param,
+            ard_num_dims=self.num_param,
             batch_shape=batch_shape,
             lengthscale_prior=GammaPrior(3.0, 6.0),
         )
@@ -303,6 +368,11 @@ def _task_embeddings_batch(self) -> Tensor:
         )
         return embeddings
 
+    def train(self, mode: bool = True) -> None:
+        super().train(mode=mode)
+        if not mode:
+            self.register_buffer("_context_covar", self._eval_context_covar())
+
     def forward(
         self,
         x1: Tensor,
@@ -315,35 +385,42 @@ def forward(
         covariance matrices together
         """
         # context covar matrix
-        context_covar = self._eval_context_covar()
+        if not self.training:
+            context_covar = self._context_covar
+        else:
+            context_covar = self._eval_context_covar()
+        if self.permutation is not None:
+            x1 = x1[..., self.permutation]
+            x2 = x2[..., self.permutation]
         # check input batch size if b x ns x n x d: expand context_covar to
         # b x ns x num_context x num_context
         if x1.dim() > context_covar.dim():
-            context_covar = context_covar.expand(*x1.shape[:1] + context_covar.shape)
-        covars = []
-        # TODO: speed computation of covariance matrix
-        for i in range(self.num_contexts):
-            for j in range(self.num_contexts):
-                context1 = self.context_list[i]
-                context2 = self.context_list[j]
-                active_params1 = self._indexers[context1]
-                active_params2 = self._indexers[context2]
-                covars.append(
-                    (
-                        context_covar.index_select(  # pyre-ignore
-                            -1, torch.tensor([j], device=self.device)
-                        ).index_select(
-                            -2, torch.tensor([i], device=self.device)
-                        )  # b x ns x 1 x 1
-                    )
-                    * self.base_kernel(
-                        x1=x1.index_select(-1, active_params1),
-                        x2=x2.index_select(-1, active_params2),
-                        diag=diag,
-                    )
-                )
+            context_covar = context_covar.expand(
+                x1.shape[:-1] + torch.Size([x2.shape[-2]]) + context_covar.shape
+            )
+        # turn last two dimensions of n x (k*d) into (n*k) x d.
+        x1_exp = x1.reshape(*x1.shape[:-2], -1, self.num_param)
+        x2_exp = x2.reshape(*x2.shape[:-2], -1, self.num_param)
+        # batch shape x n*k x n*k
+        base_covar = self.base_kernel(x1_exp, x2_exp)
+        # batch shape x n x n x k x k
+        view_shape = x1.shape[:-2] + torch.Size(
+            [
+                x1.shape[-2],
+                self.num_contexts,
+                x2.shape[-2],
+                self.num_contexts,
+            ]
+        )
+        base_covar_perm = (
+            base_covar.to_dense()
+            .view(view_shape)
+            .permute(*list(range(x1.ndim - 2)), -4, -2, -3, -1)
+        )
+        # weight the pairwise base covariances by the context covariances and
+        # sum over the two context dimensions
+        einsum_str = "...ki, ...nnki -> ...n" if diag else "...ki, ...ki -> ..."
+        covar_dense = torch.einsum(einsum_str, context_covar, base_covar_perm)
         if diag:
-            res = sum(covars)
-        else:
-            res = SumLinearOperator(*covars)
-        return res
+            return DiagLinearOperator(covar_dense)
+        return DenseLinearOperator(covar_dense)
diff --git a/test/models/kernels/test_contextual.py b/test/models/kernels/test_contextual.py
index f7982891fd..8f8a710df7 100644
--- a/test/models/kernels/test_contextual.py
+++ b/test/models/kernels/test_contextual.py
@@ -5,7 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from botorch.models.kernels.contextual_lcea import LCEAKernel
+from botorch.models.kernels.contextual_lcea import (
+    get_order,
+    get_permutation,
+    is_contiguous,
+    LCEAKernel,
+)
+
 from botorch.models.kernels.contextual_sac import SACKernel
 from botorch.utils.testing import BotorchTestCase
 from gpytorch.kernels.matern_kernel import MaternKernel
@@ -36,19 +42,20 @@ def test_SACKernel(self):
     def testLCEAKernel(self):
         decomposition = {"1": [0, 3], "2": [1, 2]}
         num_contexts = len(decomposition)
         kernel = LCEAKernel(decomposition=decomposition, batch_shape=torch.Size([]))
 
         # test init
         self.assertListEqual(kernel.context_list, ["1", "2"])
-        self.assertDictEqual(kernel.decomposition, decomposition)
         self.assertIsInstance(kernel.base_kernel, MaternKernel)
         self.assertIsInstance(kernel.task_covar_module, MaternKernel)
+        self.assertEqual(kernel.permutation, [0, 3, 1, 2])
 
         # test raise of ValueError
-        with self.assertRaises(ValueError):
+        with self.assertRaisesRegex(
+            ValueError,
+            "The number of parameters needs to be the same across all contexts.",
+        ):
             LCEAKernel(
-                decomposition={"1": [0, 3], "2": [1]}, batch_shape=torch.Size([])
+                decomposition={"1": [0, 1], "2": [2]}, batch_shape=torch.Size([])
             )
 
         # test set_outputscale_list
@@ -181,3 +188,23 @@ def testLCEAKernel(self):
         self.assertEqual(
             context_covar6.shape, torch.Size([3, num_contexts, num_contexts])
         )
+
+    def test_get_permutation(self):
+        decomp = {"a": [0, 1], "b": [2, 3]}
+        permutation = get_permutation(decomp)
+        self.assertIsNone(permutation)
+        # order mismatch
+        decomp = {"a": [1, 0], "b": [2, 3]}
+        permutation = get_permutation(decomp)
+        self.assertEqual(permutation, [0, 1, 2, 3])
+        # non-contiguous
+        decomp = {"a": [0, 2], "b": [1, 3]}
+        permutation = get_permutation(decomp)
+        self.assertEqual(permutation, [0, 2, 1, 3])
+
+    def test_is_contiguous(self):
+        self.assertFalse(is_contiguous([0, 2]))
+        self.assertTrue(is_contiguous([0, 1]))
+
+    def test_get_order(self):
+        self.assertEqual(get_order([1, 10, 3]), [1, 1, 0])
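
A quick consistency check of the batched contraction used in the new `forward`
(an editorial illustration, not part of the patch): with `context_covar` of
shape k x k and `base_covar_perm` of shape n x n x k x k, as in the diff, the
einsum should reproduce the explicit double sum over context pairs that the
removed nested loops computed. The sizes and tensor names below are
illustrative.

import torch

n, k = 5, 3  # n data points, k contexts (illustrative sizes)
context_covar = torch.randn(k, k)
base_covar_perm = torch.randn(n, n, k, k)

# batched contraction, as in the patched forward (non-diag case)
full = torch.einsum("...ki, ...ki -> ...", context_covar, base_covar_perm)

# reference: explicit double sum over context pairs, as in the removed loops
reference = torch.zeros(n, n)
for i in range(k):
    for j in range(k):
        reference = reference + context_covar[i, j] * base_covar_perm[..., i, j]
assert torch.allclose(full, reference)

# diag case: the repeated "nn" label extracts the diagonal over the data dims
diag = torch.einsum("...ki, ...nnki -> ...n", context_covar, base_covar_perm)
assert torch.allclose(diag, torch.diagonal(full))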
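
A small usage sketch for the new permutation helpers and the eval-mode cache
(again illustrative, assuming the patched module is importable):
`get_permutation` returns `None` when each context's parameters are already
contiguous and consistently ordered, and otherwise returns the index
permutation that `forward` applies to `x1` and `x2` before reshaping; calling
`eval()` goes through the new `train` override, which caches the context
covariance in the `_context_covar` buffer.

import torch
from botorch.models.kernels.contextual_lcea import get_permutation, LCEAKernel

# contiguous and consistently ordered: no reordering needed
assert get_permutation({"a": [0, 1], "b": [2, 3]}) is None

# non-contiguous contexts: the parameters are regrouped per context
assert get_permutation({"a": [0, 2], "b": [1, 3]}) == [0, 2, 1, 3]

# eval() triggers train(mode=False), which caches `_context_covar`
kernel = LCEAKernel(
    decomposition={"1": [0, 1], "2": [2, 3]}, batch_shape=torch.Size([])
)
kernel.eval()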