speed up LCE-A kernel (#1910)
Summary:
Pull Request resolved: #1910

X-link: facebook/Ax#1694

The current implementation is very slow, which is particularly problematic when the number of contexts is large.

This diff yields huge speed-ups (more than three orders of magnitude) by removing the nested for loops in `forward` and using batched computation instead (see the sketch after the benchmarks below). It also caches `context_covar` in eval mode, with the additional requirement that each context's parameters be contiguous and in the same order across contexts.

* 14x speed-up with 8 contexts and 1,154x speed-up with 128 contexts (CPU)
* 22x speed-up with 8 contexts and 3,370x speed-up with 128 contexts (CUDA)
* Without this diff, a single forward pass with 128 contexts takes 5 seconds on CPU and 12 seconds on CUDA.

Current implementation:

* 8 contexts:
  * Forward pass: 20ms (CPU), 45.2ms (CUDA)
  * Roundtrip: 39.1ms (CPU), 99.7ms (CUDA)
* 128 contexts:
  * Forward pass: 5.08s (CPU), 12s (CUDA)
  * Roundtrip: 14.2s (CPU), 26.7s (CUDA)

New implementation:
* 8 contexts:
  * Forward pass: 1.44ms (CPU), 2.05ms (CUDA)
  * Roundtrip: 2.22ms (CPU), 4.65ms (CUDA)
* 128 contexts:
  * Forward pass: 4.4ms (CPU), 3.56ms (CUDA)
  * Roundtrip: 6.97ms (CPU), 5.34ms (CUDA)
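
For intuition, a minimal self-contained sketch of the idea (illustrative only, not the BoTorch code itself; the sizes, the `rbf` stand-in for the Matern base kernel, and all variable names here are assumptions):

    import torch

    # Hypothetical sizes: n points, k contexts, d parameters per context.
    n, k, d = 5, 3, 2
    x = torch.randn(n, k * d)          # parameters assumed contiguous per context
    context_covar = torch.randn(k, k)  # stand-in for the learned context covariance

    def rbf(a, b):
        # Stand-in for the base kernel; any kernel with this signature works.
        return torch.exp(-torch.cdist(a, b) ** 2)

    # Old approach: k * k base-kernel evaluations in nested Python loops.
    slow = sum(
        context_covar[i, j] * rbf(x[:, i * d : (i + 1) * d], x[:, j * d : (j + 1) * d])
        for i in range(k)
        for j in range(k)
    )

    # New approach: one base-kernel evaluation on (n*k) x d blocks, then one einsum.
    x_exp = x.reshape(-1, d)                          # (n*k) x d
    base = rbf(x_exp, x_exp)                          # (n*k) x (n*k)
    base = base.view(n, k, n, k).permute(0, 2, 1, 3)  # n x n x k x k
    fast = torch.einsum("kl,mnkl->mn", context_covar, base)

    assert torch.allclose(slow, fast, atol=1e-5)

The loop version calls the base kernel k^2 times; the batched version makes a single call and folds the context weighting into one einsum. The contiguity/ordering requirement exists so that a single reshape can split each point into its k per-context parameter blocks.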

Reviewed By: Balandat

Differential Revision: D47118335

fbshipit-source-id: efa091bf7ccc30559d921195e10c1375f525cca9
sdaulton authored and facebook-github-bot committed Jul 3, 2023
1 parent 28b1b2b commit ed2ecf1
Showing 3 changed files with 150 additions and 47 deletions.
botorch/models/contextual.py (3 changes: 1 addition & 2 deletions)

@@ -66,8 +66,7 @@ def __init__(
            train_Y: (n x 1) Y training data.
            train_Yvar: (n x 1) Noise variance of Y.
            decomposition: Keys are context names. Values are the indexes of
-               parameters belong to the context. The parameter indexes are in the
-               same order across contexts.
+               parameters belong to the context.
            cat_feature_dict: Keys are context names and values are list of categorical
                features i.e. {"context_name" : [cat_0, ..., cat_k]}, where k is the
                number of categorical variables. If None, we use context names in the
botorch/models/kernels/contextual_lcea.py (157 changes: 117 additions & 40 deletions)

@@ -11,11 +11,78 @@
from gpytorch.kernels.kernel import Kernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.priors.torch_priors import GammaPrior
-from linear_operator.operators.sum_linear_operator import SumLinearOperator
+from linear_operator.operators import DiagLinearOperator
+from linear_operator.operators.dense_linear_operator import DenseLinearOperator
from torch import Tensor
from torch.nn import ModuleList


def get_order(indices: List[int]) -> List[int]:
    r"""Get the order indices as integers ranging from 0 to the number of indices.

    Args:
        indices: A list of parameter indices.

    Returns:
        A list of integers ranging from 0 to the number of indices.
    """
    return [i % len(indices) for i in indices]


def is_contiguous(indices: List[int]) -> bool:
    r"""Check if the list of integers is contiguous.

    Args:
        indices: A list of parameter indices.

    Returns:
        A boolean indicating whether the indices are contiguous.
    """
    min_idx = min(indices)
    return set(indices) == set(range(min_idx, min_idx + len(indices)))


def get_permutation(decomposition: Dict[str, List[int]]) -> Optional[List[int]]:
    """Construct a permutation to reorder the parameters such that:

    1) the parameters for each context are contiguous;
    2) the parameters for each context are in the same order.

    Args:
        decomposition: A dictionary mapping context names to a list of
            parameters.

    Returns:
        A permutation to reorder the parameters for (1) and (2).
        Returning `None` means that the ordering specified in `decomposition`
        satisfies (1) and (2).
    """
    permutation = None
    if not all(
        is_contiguous(indices=active_parameters)
        for active_parameters in decomposition.values()
    ):
        permutation = _create_new_permutation(decomposition=decomposition)
    else:
        same_order = True
        expected_order = get_order(indices=next(iter(decomposition.values())))
        for active_parameters in decomposition.values():
            order = get_order(indices=active_parameters)
            if order != expected_order:
                same_order = False
                break
        if not same_order:
            permutation = _create_new_permutation(decomposition=decomposition)
    return permutation


def _create_new_permutation(decomposition: Dict[str, List[int]]) -> List[int]:
    # make contiguous and ordered
    permutation = []
    for active_parameters in decomposition.values():
        sorted_indices = sorted(active_parameters)
        permutation.extend(sorted_indices)
    return permutation
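

# For example (hypothetical decompositions, mirroring the unit tests below):
#   get_permutation({"a": [0, 1], "b": [2, 3]})  ->  None  (contiguous, same order)
#   get_permutation({"a": [1, 0], "b": [2, 3]})  ->  [0, 1, 2, 3]  (order mismatch)
#   get_permutation({"a": [0, 2], "b": [1, 3]})  ->  [0, 2, 1, 3]  (non-contiguous)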


class LCEAKernel(Kernel):
    r"""The Latent Context Embedding Additive (LCE-A) Kernel.
@@ -40,8 +107,7 @@ def __init__(
        r"""
        Args:
            decomposition: Keys index context names. Values are the indexes of
-               parameters belong to the context. The parameter indexes are in the same
-               order across contexts.
+               parameters belong to the context.
            batch_shape: Batch shape as usual for gpytorch kernels. Model does not
                support batch training. When batch_shape is non-empty, it is used for
                loading hyper-parameter values generated from MCMC sampling.
@@ -58,26 +124,25 @@ def __init__(
            context_weight_dict: Known population weights of each context.
        """
        super().__init__(batch_shape=batch_shape)
        self.decomposition = decomposition
        self.batch_shape = batch_shape
        self.train_embedding = train_embedding
        self._device = device

-       num_param = len(next(iter(decomposition.values())))
+       self.num_param = len(next(iter(decomposition.values())))
        self.context_list = list(decomposition.keys())
        self.num_contexts = len(self.context_list)

        # get parameter space decomposition
        for active_parameters in decomposition.values():
            # check number of parameters are same in each decomp
-           if len(active_parameters) != num_param:
+           if len(active_parameters) != self.num_param:
                raise ValueError(
-                   "num of parameters needs to be same across all contexts"
+                   "The number of parameters needs to be same across all contexts."
                )
-       self._indexers = {
-           context: torch.tensor(active_params, device=self.device)
-           for context, active_params in self.decomposition.items()
-       }
+       # reorder the parameter list based on decomposition such that
+       # parameters for each context are contiguous and in the same order for each
+       # context
+       self.permutation = get_permutation(decomposition=decomposition)
        # get context features and set emb dim
        self.context_cat_feature = None
        self.context_emb_feature = None
@@ -102,7 +167,7 @@ def __init__(
        # base kernel
        self.base_kernel = MaternKernel(
            nu=2.5,
-           ard_num_dims=num_param,
+           ard_num_dims=self.num_param,
            batch_shape=batch_shape,
            lengthscale_prior=GammaPrior(3.0, 6.0),
        )
@@ -303,6 +368,11 @@ def _task_embeddings_batch(self) -> Tensor:
        )
        return embeddings

    def train(self, mode: bool = True) -> None:
        super().train(mode=mode)
        if not mode:
            self.register_buffer("_context_covar", self._eval_context_covar())
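        # (In eval mode the context covariance is thus computed once and cached;
        # `forward` below reuses `_context_covar` instead of recomputing it on
        # every call.)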

    def forward(
        self,
        x1: Tensor,
@@ -315,35 +385,42 @@ def forward(
        covariance matrices together
        """
        # context covar matrix
-       context_covar = self._eval_context_covar()
+       if not self.training:
+           context_covar = self._context_covar
+       else:
+           context_covar = self._eval_context_covar()
+       if self.permutation is not None:
+           x1 = x1[..., self.permutation]
+           x2 = x2[..., self.permutation]
        # check input batch size if b x ns x n x d: expand context_covar to
        # b x ns x num_context x num_context
        if x1.dim() > context_covar.dim():
-           context_covar = context_covar.expand(*x1.shape[:1] + context_covar.shape)
-       covars = []
-       # TODO: speed computation of covariance matrix
-       for i in range(self.num_contexts):
-           for j in range(self.num_contexts):
-               context1 = self.context_list[i]
-               context2 = self.context_list[j]
-               active_params1 = self._indexers[context1]
-               active_params2 = self._indexers[context2]
-               covars.append(
-                   (
-                       context_covar.index_select(  # pyre-ignore
-                           -1, torch.tensor([j], device=self.device)
-                       ).index_select(
-                           -2, torch.tensor([i], device=self.device)
-                       )  # b x ns x 1 x 1
-                   )
-                   * self.base_kernel(
-                       x1=x1.index_select(-1, active_params1),
-                       x2=x2.index_select(-1, active_params2),
-                       diag=diag,
-                   )
-               )
+           context_covar = context_covar.expand(
+               x1.shape[:-1] + torch.Size([x2.shape[-2]]) + context_covar.shape
+           )
+       # turn last two dimensions of n x (k*d) into (n*k) x d.
+       x1_exp = x1.reshape(*x1.shape[:-2], -1, self.num_param)
+       x2_exp = x2.reshape(*x2.shape[:-2], -1, self.num_param)
+       # batch shape x n*k x n*k
+       base_covar = self.base_kernel(x1_exp, x2_exp)
+       # batch shape x n x n x k x k
+       view_shape = x1.shape[:-2] + torch.Size(
+           [
+               x1.shape[-2],
+               self.num_contexts,
+               x2.shape[-2],
+               self.num_contexts,
+           ]
+       )
+       base_covar_perm = (
+           base_covar.to_dense()
+           .view(view_shape)
+           .permute(*list(range(x1.ndim - 2)), -4, -2, -3, -1)
+       )
+       # then weight by the context kernel
+       einsum_str = "...kk, ...nnkk -> ...n" if diag else "...kk, ...kk -> ..."
+       covar_dense = torch.einsum(einsum_str, context_covar, base_covar_perm)
        if diag:
-           res = sum(covars)
-       else:
-           res = SumLinearOperator(*covars)
-       return res
+           return DiagLinearOperator(covar_dense)
+       return DenseLinearOperator(covar_dense)
test/models/kernels/test_contextual.py (37 changes: 32 additions & 5 deletions)

@@ -5,7 +5,13 @@
# LICENSE file in the root directory of this source tree.

import torch
-from botorch.models.kernels.contextual_lcea import LCEAKernel
+from botorch.models.kernels.contextual_lcea import (
+    get_order,
+    get_permutation,
+    is_contiguous,
+    LCEAKernel,
+)

from botorch.models.kernels.contextual_sac import SACKernel
from botorch.utils.testing import BotorchTestCase
from gpytorch.kernels.matern_kernel import MaternKernel
@@ -36,19 +42,20 @@ def test_SACKernel(self):
    def testLCEAKernel(self):
        decomposition = {"1": [0, 3], "2": [1, 2]}
        num_contexts = len(decomposition)
        kernel = LCEAKernel(decomposition=decomposition, batch_shape=torch.Size([]))
        # test init
        self.assertListEqual(kernel.context_list, ["1", "2"])
        self.assertDictEqual(kernel.decomposition, decomposition)
        self.assertIsInstance(kernel.base_kernel, MaternKernel)
        self.assertIsInstance(kernel.task_covar_module, MaternKernel)
+       self.assertEqual(kernel.permutation, [0, 3, 1, 2])

        # test raise of ValueError
-       with self.assertRaises(ValueError):
+       with self.assertRaisesRegex(
+           ValueError, "The number of parameters needs to be same across all contexts."
+       ):
            LCEAKernel(
-               decomposition={"1": [0, 3], "2": [1]}, batch_shape=torch.Size([])
+               decomposition={"1": [0, 1], "2": [2]}, batch_shape=torch.Size([])
            )

        # test set_outputscale_list
@@ -181,3 +188,23 @@ def testLCEAKernel(self):
        self.assertEqual(
            context_covar6.shape, torch.Size([3, num_contexts, num_contexts])
        )

    def test_get_permutation(self):
        decomp = {"a": [0, 1], "b": [2, 3]}
        permutation = get_permutation(decomp)
        self.assertIsNone(permutation)
        # order mismatch
        decomp = {"a": [1, 0], "b": [2, 3]}
        permutation = get_permutation(decomp)
        self.assertEqual(permutation, [0, 1, 2, 3])
        # non-contiguous
        decomp = {"a": [0, 2], "b": [1, 3]}
        permutation = get_permutation(decomp)
        self.assertEqual(permutation, [0, 2, 1, 3])

    def test_is_contiguous(self):
        self.assertFalse(is_contiguous([0, 2]))
        self.assertTrue(is_contiguous([0, 1]))

    def test_get_order(self):
        self.assertEqual(get_order([1, 10, 3]), [1, 1, 0])