
[Bug] KroneckerMultiTaskGP incompatible with batch_cross_validation or batched models #2553

@slishak-PX


🐛 Bug

Fitting a KroneckerMultiTaskGP with batch_cross_validation fails with a RuntimeError saying that the batch shapes of the LinearOperators are incompatible for a Kronecker product.

The reproduction below is adapted from the batch-mode cross-validation tutorial (https://botorch.org/tutorials/batch_mode_cross_validation).

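Constructing the batched model directly from the cross-validation folds, without going through batch_cross_validation, raises the same error. This suggests that the underlying problem is that KroneckerMultiTaskGP does not correctly support batched training data: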

```python
# cv_folds is built with gen_loo_cv_folds as in the reproduction below
gp = KroneckerMultiTaskGP(cv_folds.train_X, cv_folds.train_Y)
gp.posterior(cv_folds.test_X)
```

To reproduce

**Code snippet to reproduce**

```python
import math

import torch
from botorch.cross_validation import batch_cross_validation, gen_loo_cv_folds
from botorch.models import KroneckerMultiTaskGP, SingleTaskGP
from botorch.models.transforms.input import Normalize
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

# Originally run on "cuda:0"; fall back to CPU so the snippet runs without a GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float64
torch.manual_seed(3)

# 20 noisy samples of 4 sine outputs with different frequencies: train_Y is 20 x 4
sigma = math.sqrt(0.2)
coefs = torch.linspace(1, 4, 4, dtype=dtype, device=device).view(1, -1)
train_X = torch.linspace(0, 1, 20, dtype=dtype, device=device).view(-1, 1)
train_Y_noiseless = torch.sin(train_X * coefs * math.pi)
train_Y = train_Y_noiseless + sigma * torch.randn_like(train_Y_noiseless)

# Leave-one-out folds: train_X is 20 x 19 x 1 and train_Y is 20 x 19 x 4
cv_folds = gen_loo_cv_folds(train_X, train_Y)

# If this is False, a SingleTaskGP is used
multi_task = True

cv_results = batch_cross_validation(
    model_cls=KroneckerMultiTaskGP if multi_task else SingleTaskGP,
    mll_cls=ExactMarginalLogLikelihood,
    cv_folds=cv_folds,
    model_init_kwargs={
        "input_transform": Normalize(d=train_X.shape[-1]),
    },
)
```
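The leading dimension of 20 in the error message comes from the fold batching: leave-one-out cross-validation on 20 points yields 20 training sets of 19 points each, stacked along a new leading batch dimension. The pure-PyTorch sketch below reproduces those shapes with a hypothetical helper, `loo_folds`; it is an illustration of the fold layout, not BoTorch's `gen_loo_cv_folds` implementation, and it assumes folds are stacked in the order of the held-out point.

```python
import math

import torch


def loo_folds(train_X: torch.Tensor, train_Y: torch.Tensor):
    """Hypothetical helper (not BoTorch's gen_loo_cv_folds) that stacks LOO folds.

    train_X is n x d and train_Y is n x m; every returned tensor gets a new
    leading batch dimension of size n, one entry per fold.
    """
    n = train_X.shape[-2]
    # Row i of the identity marks the single point held out in fold i
    test_mask = torch.eye(n, dtype=torch.bool)
    train_mask = ~test_mask

    def split(t: torch.Tensor):
        # View the data once per fold: n x n x k
        tiled = t.expand(n, -1, -1)
        # Boolean indexing flattens the two leading dims; reshape back per fold
        train = tiled[train_mask].view(n, n - 1, t.shape[-1])
        test = tiled[test_mask].view(n, 1, t.shape[-1])
        return train, test

    fold_train_X, fold_test_X = split(train_X)
    fold_train_Y, fold_test_Y = split(train_Y)
    return fold_train_X, fold_test_X, fold_train_Y, fold_test_Y


# Same data sizes as the reproduction: 20 points, 1 input, 4 outputs
coefs = torch.linspace(1, 4, 4, dtype=torch.float64).view(1, -1)
train_X = torch.linspace(0, 1, 20, dtype=torch.float64).view(-1, 1)
train_Y = torch.sin(train_X * coefs * math.pi)

fold_train_X, fold_test_X, fold_train_Y, fold_test_Y = loo_folds(train_X, train_Y)
print(fold_train_X.shape)  # torch.Size([20, 19, 1])
print(fold_train_Y.shape)  # torch.Size([20, 19, 4])

# A stationary kernel starts from pairwise distances: one 19 x 19 matrix per fold,
# the same shape as the first operator in the error message
pairwise = torch.cdist(fold_train_X, fold_train_X)
print(pairwise.shape)  # torch.Size([20, 19, 19])
```

Any kernel evaluated on these folds therefore produces a (20, 19, 19) data covariance, which is the first shape reported in the RuntimeError.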

**Stack trace/error message**

```
File /opt/conda/envs/python3/lib/python3.10/site-packages/gpytorch/kernels/multitask_kernel.py:53, in MultitaskKernel.forward(self, x1, x2, diag, last_dim_is_batch, **params)
     51     covar_i = covar_i.repeat(*x1.shape[:-2], 1, 1)
     52 covar_x = to_linear_operator(self.data_covar_module.forward(x1, x2, **params))
---> 53 res = KroneckerProductLinearOperator(covar_x, covar_i)
     54 return res.diagonal(dim1=-1, dim2=-2) if diag else res

File /opt/conda/envs/python3/lib/python3.10/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43     else:
     44         new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File /opt/conda/envs/python3/lib/python3.10/site-packages/linear_operator/operators/kronecker_product_linear_operator.py:81, in KroneckerProductLinearOperator.__init__(self, *linear_ops)
     79     batch_broadcast_shape = torch.broadcast_shapes(*(linear_op.batch_shape for linear_op in linear_ops))
     80 except RuntimeError:
---> 81     raise RuntimeError(
     82         "Batch shapes of LinearOperators "
     83         f"({', '.join([str(tuple(linear_op.shape)) for linear_op in linear_ops])}) "
     84         "are incompatible for a Kronecker product."
     85     )
     87 if len(batch_broadcast_shape):  # Otherwise all linear_ops are non-batch, and we don't need to expand
     88     # NOTE: we must explicitly call requires_grad on each of these arguments
     89     # for the automatic _bilinear_derivative to work in torch.autograd.Functions
     90     linear_ops = tuple(
     91         linear_op._expand_batch(batch_broadcast_shape).requires_grad_(linear_op.requires_grad)
     92         for linear_op in linear_ops
     93     )

RuntimeError: Batch shapes of LinearOperators ((20, 19, 19), (400, 4, 4)) are incompatible for a Kronecker product.
```
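In the traceback, the 400 in the task covariance shape is 20 x 20. One plausible reading (an assumption, checked only against the frame shown above) is that the task covariance of the batched model already has batch shape (20,), and line 51 of multitask_kernel.py then calls `repeat` with `x1.shape[:-2]` (here `(20,)`), which tiles that existing batch dimension a second time instead of broadcasting to it. The PyTorch-only sketch below replays those shapes with identity stand-in matrices to contrast `Tensor.repeat` with `Tensor.expand`; it is an illustration of the shape arithmetic, not a patch for GPyTorch or BoTorch.

```python
import torch

# Shapes taken from the error message: 20 LOO folds, 19 training points, 4 tasks
n_folds, n_train, n_tasks = 20, 19, 4

# Stand-in for the data covariance: one 19 x 19 matrix per fold
covar_x = torch.eye(n_train).expand(n_folds, n_train, n_train)
# Assumption: the task covariance already carries the fold batch dimension
covar_i = torch.eye(n_tasks).expand(n_folds, n_tasks, n_tasks)

# Mirrors line 51 of multitask_kernel.py: repeat() tiles the existing batch dim
repeated = covar_i.repeat(*covar_x.shape[:-2], 1, 1)
print(repeated.shape)  # torch.Size([400, 4, 4])

# The broadcast check in KroneckerProductLinearOperator.__init__ then fails
try:
    torch.broadcast_shapes(covar_x.shape[:-2], repeated.shape[:-2])
    mismatch = False
except RuntimeError:
    mismatch = True
print(mismatch)  # True

# expand() broadcasts instead of tiling, so an already-batched matrix keeps its shape
expanded = covar_i.expand(*covar_x.shape[:-2], n_tasks, n_tasks)
print(torch.broadcast_shapes(covar_x.shape[:-2], expanded.shape[:-2]))  # torch.Size([20])
```

If this reading is right, the mismatch only appears when the task covariance is itself batched, which would explain why a non-batched KroneckerMultiTaskGP works while the cross-validation folds do not.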

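Expected Behavior

batch_cross_validation should fit KroneckerMultiTaskGP on the batched cross-validation folds without error, and KroneckerMultiTaskGP should accept batched training data such as cv_folds.train_X and cv_folds.train_Y.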

System information


  • BoTorch 0.12.0 (also occurs with 0.11.3, although this code requires 0.12 due to the lack of Standardize)
  • GPyTorch 1.13 (also occurs with 1.12)
  • PyTorch 2.0.1+cu117
  • Linux

