Skip to content
44 changes: 41 additions & 3 deletions botorch/posteriors/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ def __init__(
distribution: Posterior multivariate normal distribution.
joint_covariance_matrix: Joint test train covariance matrix over the entire
tensor.
train_train_covar: Covariance matrix of train points in the data space.
test_obs_covar: Covariance matrix of test x train points in the data space.
test_train_covar: Covariance matrix of test x train points in the data
space.
train_diff: Difference between train mean and train responses.
test_mean: Test mean response.
train_train_covar: Covariance matrix of train points in the data space.
train_noise: Training noise covariance.
test_noise: Only used if posterior should contain observation noise.
Testing noise covariance.
Expand Down Expand Up @@ -226,7 +228,9 @@ def rsample_from_base_samples(
train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples
)
train_covar_plus_noise = self.train_train_covar + self.train_noise
obs_solve = train_covar_plus_noise.solve(obs_minus_samples.unsqueeze(-1))
obs_solve = _permute_solve(
train_covar_plus_noise, obs_minus_samples.unsqueeze(-1)
)

# and multiply the test-observed matrix against the result of the solve
updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1)
Expand Down Expand Up @@ -286,3 +290,37 @@ def _draw_from_base_covar(
res = covar_root.matmul(base_samples)

return res.squeeze(-1)


def _permute_solve(A: LinearOperator, b: Tensor) -> LinearOperator:
    r"""Solve the batched linear system AX = b, where b is a batched column
    vector.

    The solve is carried out after swapping the largest batch dimension of b
    with the final dimension, which results in a more efficient matrix-matrix
    solve. The swap is undone on the solution, so the result has the same
    dimension order as b.

    This ideally should be handled upstream (in GPyTorch, linear_operator or
    PyTorch), after which any uses of this method can be replaced with
    `A.solve(b)`.

    Args:
        A: LinearOperator of shape (n, n)
        b: Tensor of shape (..., n, 1)

    Returns:
        The result of `A.solve`, restored to the dimension order of b, i.e. of
        shape (..., n, 1).
    """
    # Permutation that swaps the largest batch dimension with the final
    # dimension (more efficient than unsqueezing). It stays the identity when
    # b has no batch dimensions.
    perm = list(range(b.ndim))
    if b.ndim > 2:
        # `max` returns the first maximal item, so ties go to the leftmost dim.
        largest_batch_dim, _ = max(enumerate(b.shape[:-2]), key=lambda t: t[1])
        perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1]

    x_p = A.solve(b.permute(*perm))

    # `perm` is either the identity or a single swap of two positions, both of
    # which are their own inverse. Applying it again therefore undoes the
    # permutation without having to compute an explicit inverse.
    return x_p.permute(*perm)
24 changes: 22 additions & 2 deletions test/posteriors/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import torch
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.models.multitask import KroneckerMultiTaskGP
from botorch.posteriors.multitask import MultitaskGPPosterior
from botorch.posteriors.multitask import _permute_solve, MultitaskGPPosterior
from botorch.sampling.normal import IIDNormalSampler
from botorch.utils.testing import BotorchTestCase
from linear_operator.operators import to_linear_operator


def get_posterior_test_cases(
Expand Down Expand Up @@ -41,7 +42,6 @@ def get_posterior_test_cases(


class TestMultitaskGPPosterior(BotorchTestCase):

def _test_MultitaskGPPosterior(self, dtype: torch.dtype) -> None:
post_list = get_posterior_test_cases(device=self.device, dtype=dtype)
sample_shaping = torch.Size([5, 3])
Expand Down Expand Up @@ -189,3 +189,23 @@ def test_draw_from_base_covar(self):
base_samples = torch.randn(4, 10, 1, device=self.device)
with self.assertRaises(RuntimeError):
res = posterior._draw_from_base_covar(sym_mat, base_samples)


class TestPermuteSolve(BotorchTestCase):
    # Compares `_permute_solve` against a plain `torch.linalg.solve`.

    def test_permute_solve_tensor(self):
        tkwargs = {"device": self.device, "dtype": torch.float64}

        # Random PSD matrix, formed as root @ root^T
        root = torch.randn(32, 32, **tkwargs)
        A = torch.mm(root, root.t())

        # Random batched column vector
        b = torch.randn(4, 1, 32, 1, **tkwargs)

        # First the batched right-hand side, then a single un-batched column
        # taken from it; the permuted solve must agree with the standard solve
        # in both cases.
        for rhs in (b, b[0, 0, :, :]):
            permuted = _permute_solve(to_linear_operator(A), rhs)
            expected = torch.linalg.solve(A, rhs)
            self.assertAllClose(permuted, expected)
Loading