[Bug] KroneckerMultiTaskGP incompatible with batch_cross_validation or batched models #2553

Open
slishak-PX opened this issue Sep 26, 2024 · 2 comments
Labels
bug Something isn't working


@slishak-PX

🐛 Bug

When trying to use batch_cross_validation to fit a KroneckerMultiTaskGP, a batch shape error occurs.

The code snippet below uses a similar example to https://botorch.org/tutorials/batch_mode_cross_validation.

The following snippet also causes the same error, suggesting that the problem is that KroneckerMultiTaskGP does not correctly support batching.

gp = KroneckerMultiTaskGP(cv_folds.train_X, cv_folds.train_Y)
gp.posterior(cv_folds.test_X)
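For context, KroneckerMultiTaskGP represents the joint covariance over training points and tasks as a Kronecker product of a data covariance (points x points) and a task covariance (tasks x tasks). In a batched model, every batch element needs its own product, so the batch shapes of the two factors have to be compatible. The dense, pure-Python sketch below only illustrates that structure: kron and batched_kron are hypothetical helpers, not GPyTorch or linear_operator code, and unlike the real operator batched_kron does not broadcast; it simply requires equal batch sizes.

```python
from typing import List

Matrix = List[List[float]]


def kron(a: Matrix, b: Matrix) -> Matrix:
    """Dense Kronecker product: entry (i, j) is a[i // rb][j // cb] * b[i % rb][j % cb]."""
    ra, ca = len(a), len(a[0])
    rb, cb = len(b), len(b[0])
    return [
        [a[i // rb][j // cb] * b[i % rb][j % cb] for j in range(ca * cb)]
        for i in range(ra * rb)
    ]


def batched_kron(batch_a: List[Matrix], batch_b: List[Matrix]) -> List[Matrix]:
    """One Kronecker product per batch element; the two batch sizes must agree."""
    if len(batch_a) != len(batch_b):
        raise ValueError(
            f"batch sizes {len(batch_a)} and {len(batch_b)} are incompatible"
        )
    return [kron(a, b) for a, b in zip(batch_a, batch_b)]


data_cov = [[1.0, 0.5], [0.5, 1.0]]  # covariance between 2 training points
task_cov = [[2.0, 0.0], [0.0, 3.0]]  # covariance between 2 tasks
joint = kron(data_cov, task_cov)     # (2 * 2) x (2 * 2) joint covariance
print(len(joint), len(joint[0]))     # 4 4
```

A batch of 20 data covariances paired with a batch of 400 task covariances has no one-to-one pairing, which is the situation the error message below describes.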

To reproduce

** Code snippet to reproduce **

import math
import torch
from botorch.cross_validation import batch_cross_validation, gen_loo_cv_folds
from botorch.models import SingleTaskGP, KroneckerMultiTaskGP
from botorch.models.transforms.input import Normalize
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # fall back to CPU if no GPU
dtype = torch.float64
torch.manual_seed(3)

sigma = math.sqrt(0.2)
coefs = torch.linspace(1, 4, 4, dtype=dtype, device=device).view(1, -1)
train_X = torch.linspace(0, 1, 20, dtype=dtype, device=device).view(-1, 1)
train_Y_noiseless = torch.sin(train_X * coefs * math.pi)
train_Y = train_Y_noiseless + sigma * torch.randn_like(train_Y_noiseless)

cv_folds = gen_loo_cv_folds(train_X, train_Y)

# If this is False, a SingleTaskGP is used
multi_task = True

cv_results = batch_cross_validation(
    model_cls=KroneckerMultiTaskGP if multi_task else SingleTaskGP,
    mll_cls=ExactMarginalLogLikelihood,
    cv_folds=cv_folds,
    model_init_kwargs={
        "input_transform": Normalize(d=train_X.shape[-1]),
    },
)
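gen_loo_cv_folds turns the n training points into n leave-one-out folds stacked along a new leading batch dimension, so with 20 points each fold trains on 19 points and tests on 1. This matches the (20, 19, 19) data covariance in the error below. The pure-Python sketch below mimics only that fold structure on plain lists; loo_folds is a hypothetical helper, not BoTorch's implementation.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def loo_folds(points: Sequence[T]) -> List[Tuple[List[T], List[T]]]:
    """Return one (train, test) pair per point, holding that point out of train."""
    folds = []
    for i in range(len(points)):
        train = [p for j, p in enumerate(points) if j != i]  # the other n - 1 points
        test = [points[i]]                                    # the single held-out point
        folds.append((train, test))
    return folds


# 20 inputs evenly spaced on [0, 1], like torch.linspace(0, 1, 20) in the repro
xs = [k / 19 for k in range(20)]
folds = loo_folds(xs)
print(len(folds), len(folds[0][0]), len(folds[0][1]))  # 20 19 1
```

Each fold becomes one element of the model's batch, which is why the model fitted by batch_cross_validation (and the one built directly from cv_folds.train_X) is a batched model with batch shape (20,).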

** Stack trace/error message **

File /opt/conda/envs/python3/lib/python3.10/site-packages/gpytorch/kernels/multitask_kernel.py:53, in MultitaskKernel.forward(self, x1, x2, diag, last_dim_is_batch, **params)
     51     covar_i = covar_i.repeat(*x1.shape[:-2], 1, 1)
     52 covar_x = to_linear_operator(self.data_covar_module.forward(x1, x2, **params))
---> 53 res = KroneckerProductLinearOperator(covar_x, covar_i)
     54 return res.diagonal(dim1=-1, dim2=-2) if diag else res

File /opt/conda/envs/python3/lib/python3.10/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43     else:
     44         new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File /opt/conda/envs/python3/lib/python3.10/site-packages/linear_operator/operators/kronecker_product_linear_operator.py:81, in KroneckerProductLinearOperator.__init__(self, *linear_ops)
     79     batch_broadcast_shape = torch.broadcast_shapes(*(linear_op.batch_shape for linear_op in linear_ops))
     80 except RuntimeError:
---> 81     raise RuntimeError(
     82         "Batch shapes of LinearOperators "
     83         f"({', '.join([str(tuple(linear_op.shape)) for linear_op in linear_ops])}) "
     84         "are incompatible for a Kronecker product."
     85     )
     87 if len(batch_broadcast_shape):  # Otherwise all linear_ops are non-batch, and we don't need to expand
     88     # NOTE: we must explicitly call requires_grad on each of these arguments
     89     # for the automatic _bilinear_derivative to work in torch.autograd.Functions
     90     linear_ops = tuple(
     91         linear_op._expand_batch(batch_broadcast_shape).requires_grad_(linear_op.requires_grad)
     92         for linear_op in linear_ops
     93     )

RuntimeError: Batch shapes of LinearOperators ((20, 19, 19), (400, 4, 4)) are incompatible for a Kronecker product.
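The two shapes in the error hint at where the extra batch dimension comes from: 400 = 20 x 20. One plausible reading of line 51 of MultitaskKernel.forward is that covar_i, the task covariance, already carries the 20-fold batch dimension of the batched model, and Tensor.repeat then tiles it 20 more times along that same dimension. The sketch below reproduces only the shape arithmetic under that assumption. repeat_shape and broadcast_shapes are hypothetical pure-Python helpers that follow the shape rules of torch.Tensor.repeat and torch.broadcast_shapes. That the task covariance has batch shape (20,) here is an assumption; the traceback does not show it directly.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def repeat_shape(shape: Sequence[int], repeats: Sequence[int]) -> Tuple[int, ...]:
    """Result shape of tensor.repeat(*repeats): left-pad shape with 1s, then multiply."""
    if len(repeats) < len(shape):
        raise ValueError("repeat needs at least as many sizes as the tensor has dims")
    padded = (1,) * (len(repeats) - len(shape)) + tuple(shape)
    return tuple(s * r for s, r in zip(padded, repeats))


def broadcast_shapes(*shapes: Sequence[int]) -> Tuple[int, ...]:
    """Broadcast right to left; aligned sizes must be equal or 1."""
    result = []
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise RuntimeError(f"shapes {shapes} cannot be broadcast")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


x1_batch = (20,)              # batch shape of the inputs: 20 LOO folds
covar_x_shape = (20, 19, 19)  # data covariance, as reported in the error
covar_i_shape = (20, 4, 4)    # ASSUMPTION: task covariance already batched per fold

tiled = repeat_shape(covar_i_shape, (*x1_batch, 1, 1))
print(tiled)  # (400, 4, 4), the second shape in the error

try:
    broadcast_shapes(covar_x_shape[:-2], tiled[:-2])  # (20,) vs (400,)
except RuntimeError as err:
    print(err)

# For comparison: without the extra tiling, the batch shapes line up.
print(broadcast_shapes(covar_x_shape[:-2], covar_i_shape[:-2]))  # (20,)
```

If this reading is right, the factor of 20 comes from tiling a tensor that is already batched, rather than from the data covariance.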

Expected Behavior

System information


  • BoTorch 0.12.0 (also occurs with 0.11.3, although this code requires 0.12 due to the lack of Standardize)
  • GPyTorch 1.13 (also occurs with 1.12)
  • PyTorch 2.0.1+cu117
  • Linux


@slishak-PX slishak-PX added the bug Something isn't working label Sep 26, 2024
@sdaulton
Contributor

Hi @slishak-PX,

Thanks for flagging this. The docstring for batch_cross_validation says that MultiTaskGPs are not supported, but we should raise an UnsupportedError too. I'll put up a PR to add that.
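A guard of this kind could check the model class up front, before any folds are fitted. The sketch below is hypothetical and self-contained. The stub classes stand in for BoTorch's SingleTaskGP, MultiTaskGP and KroneckerMultiTaskGP, UnsupportedError stands in for BoTorch's UnsupportedError, and check_cv_model_cls is an illustrative name, not the change made in the PR.

```python
class UnsupportedError(Exception):
    """Stand-in for BoTorch's UnsupportedError."""


# Stub classes standing in for the real BoTorch models (hypothetical).
class SingleTaskGP: ...
class MultiTaskGP: ...
class KroneckerMultiTaskGP: ...


# Multi-task model classes, which the batch_cross_validation docstring says are unsupported.
UNSUPPORTED_CV_MODELS = (MultiTaskGP, KroneckerMultiTaskGP)


def check_cv_model_cls(model_cls: type) -> None:
    """Fail fast with a clear message instead of a later batch-shape error."""
    if issubclass(model_cls, UNSUPPORTED_CV_MODELS):
        raise UnsupportedError(
            f"batch_cross_validation does not support {model_cls.__name__}."
        )


check_cv_model_cls(SingleTaskGP)  # passes silently
```

Checking the class rather than an instance lets the error surface before any fold is constructed.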

> The following snippet also causes the same error, suggesting that the problem is that KroneckerMultiTaskGP does not correctly support batching.

Interesting. Have you looked into this any deeper? If you have an idea of the fix, I'd be happy to review a PR. Also curious if @jandylin has any idea

@slishak-PX
Author

> Thanks for flagging this. The docstring for batch_cross_validation says that MultiTaskGPs are not supported, but we should raise an UnsupportedError too. I'll put up a PR to add that.

Sorry, I should have read more carefully. I think my attention was drawn by the "WARNING" at the top, and then I forgot to read the rest!

> Interesting. Have you looked into this any deeper? If you have an idea of the fix, I'd be happy to review a PR. Also curious if @jandylin has any idea

I haven't investigated more deeply yet, but I've tried evaluating the posterior with inputs of multiple batch dimensions, and the same error appears.

sdaulton added a commit to sdaulton/botorch that referenced this issue Sep 26, 2024
Summary:
see title.

pytorch#2553

Differential Revision: D63464560
sdaulton added a commit to sdaulton/botorch that referenced this issue Sep 26, 2024
Summary:
Pull Request resolved: pytorch#2554

see title.

pytorch#2553

Reviewed By: saitcakmak, Balandat

Differential Revision: D63464560
facebook-github-bot pushed a commit that referenced this issue Sep 26, 2024
Summary:
Pull Request resolved: #2554

see title.

#2553

Reviewed By: saitcakmak, Balandat

Differential Revision: D63464560

fbshipit-source-id: b02dce2fbe7b29467cae7991f2fd1acfefbfee41