Closed
Labels: bug (Something isn't working)
Description
🐛 Bug
When trying to use batch_cross_validation to fit a KroneckerMultiTaskGP, a batch shape error occurs.
The reproduction code further down is adapted from the batch-mode cross-validation tutorial at https://botorch.org/tutorials/batch_mode_cross_validation.
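For context on the batch shapes in the error: leave-one-out cross-validation over the 20 training points produces a batch of 20 folds, each with 19 training points and 1 held-out test point. The sketch below builds folds of that shape with plain PyTorch. loo_folds_sketch is a hypothetical helper written only for illustration (it is not BoTorch's gen_loo_cv_folds), and the RBF Gram matrix is just a stand-in for the data covariance of one model per fold.

```python
import math

import torch


def loo_folds_sketch(X: torch.Tensor, Y: torch.Tensor):
    """Hypothetical helper: batched leave-one-out folds, one fold per data point."""
    n = X.shape[-2]
    # held_out[i, j] is True when point j is the test point of fold i
    held_out = torch.eye(n, dtype=torch.bool, device=X.device)
    # Copy the full data set once per fold: shapes (n, n, d) and (n, n, m)
    X_rep = X.expand(n, *X.shape)
    Y_rep = Y.expand(n, *Y.shape)
    # Boolean indexing returns rows in fold order, so view() restores the fold batch dim
    train_X = X_rep[~held_out].view(n, n - 1, X.shape[-1])
    train_Y = Y_rep[~held_out].view(n, n - 1, Y.shape[-1])
    test_X = X_rep[held_out].view(n, 1, X.shape[-1])
    test_Y = Y_rep[held_out].view(n, 1, Y.shape[-1])
    return train_X, train_Y, test_X, test_Y


# Same layout as the report: 20 one-dimensional inputs, 4 outputs (tasks)
X = torch.linspace(0, 1, 20, dtype=torch.float64).view(-1, 1)
coefs = torch.linspace(1, 4, 4, dtype=torch.float64).view(1, -1)
Y = torch.sin(X * coefs * math.pi)

train_X, train_Y, test_X, test_Y = loo_folds_sketch(X, Y)

# Stand-in data covariance: an RBF Gram matrix over each fold's 19 training inputs
K = torch.exp(-0.5 * torch.cdist(train_X, train_X) ** 2)

print(train_X.shape)  # torch.Size([20, 19, 1])
print(train_Y.shape)  # torch.Size([20, 19, 4])
print(test_X.shape)   # torch.Size([20, 1, 1])
print(K.shape)        # torch.Size([20, 19, 19])
```

Batched this way, a per-fold data covariance has shape (20, 19, 19), which matches the first operator shape in the error message.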
The following snippet also raises the same error, suggesting that KroneckerMultiTaskGP does not correctly support batching:

```python
gp = KroneckerMultiTaskGP(cv_folds.train_X, cv_folds.train_Y)
gp.posterior(cv_folds.test_X)
```

To reproduce
**Code snippet to reproduce**

```python
import math

import torch
from botorch.cross_validation import batch_cross_validation, gen_loo_cv_folds
from botorch.models import KroneckerMultiTaskGP, SingleTaskGP
from botorch.models.transforms.input import Normalize
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

# The original report ran on cuda:0; fall back to CPU when no GPU is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float64
torch.manual_seed(3)

# 20 one-dimensional inputs and 4 noisy sinusoidal outputs (one per task)
sigma = math.sqrt(0.2)
coefs = torch.linspace(1, 4, 4, dtype=dtype, device=device).view(1, -1)
train_X = torch.linspace(0, 1, 20, dtype=dtype, device=device).view(-1, 1)
train_Y_noiseless = torch.sin(train_X * coefs * math.pi)
train_Y = train_Y_noiseless + sigma * torch.randn_like(train_Y_noiseless)

# Leave-one-out folds: 20 folds, each with 19 training points and 1 test point
cv_folds = gen_loo_cv_folds(train_X, train_Y)

# If this is False, a SingleTaskGP is used
multi_task = True
cv_results = batch_cross_validation(
    model_cls=KroneckerMultiTaskGP if multi_task else SingleTaskGP,
    mll_cls=ExactMarginalLogLikelihood,
    cv_folds=cv_folds,
    model_init_kwargs={
        "input_transform": Normalize(d=train_X.shape[-1]),
    },
)
```

**Stack trace/error message**
```
File /opt/conda/envs/python3/lib/python3.10/site-packages/gpytorch/kernels/multitask_kernel.py:53, in MultitaskKernel.forward(self, x1, x2, diag, last_dim_is_batch, **params)
     51 covar_i = covar_i.repeat(*x1.shape[:-2], 1, 1)
     52 covar_x = to_linear_operator(self.data_covar_module.forward(x1, x2, **params))
---> 53 res = KroneckerProductLinearOperator(covar_x, covar_i)
     54 return res.diagonal(dim1=-1, dim2=-2) if diag else res

File /opt/conda/envs/python3/lib/python3.10/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43 else:
     44 new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File /opt/conda/envs/python3/lib/python3.10/site-packages/linear_operator/operators/kronecker_product_linear_operator.py:81, in KroneckerProductLinearOperator.__init__(self, *linear_ops)
     79 batch_broadcast_shape = torch.broadcast_shapes(*(linear_op.batch_shape for linear_op in linear_ops))
     80 except RuntimeError:
---> 81 raise RuntimeError(
     82 "Batch shapes of LinearOperators "
     83 f"({', '.join([str(tuple(linear_op.shape)) for linear_op in linear_ops])}) "
     84 "are incompatible for a Kronecker product."
     85 )
     87 if len(batch_broadcast_shape): # Otherwise all linear_ops are non-batch, and we don't need to expand
     88 # NOTE: we must explicitly call requires_grad on each of these arguments
     89 # for the automatic _bilinear_derivative to work in torch.autograd.Functions
     90 linear_ops = tuple(
     91 linear_op._expand_batch(batch_broadcast_shape).requires_grad_(linear_op.requires_grad)
     92 for linear_op in linear_ops
     93 )

RuntimeError: Batch shapes of LinearOperators ((20, 19, 19), (400, 4, 4)) are incompatible for a Kronecker product.
```
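The shape arithmetic behind this error can be reproduced with plain PyTorch. The data covariance has batch shape (20,), one model per leave-one-out fold, while the task covariance ends up with batch shape (400,). One plausible reading, assuming the task covariance already carries the (20,) model batch shape before line 51 of multitask_kernel.py runs, is that Tensor.repeat tiles that existing dimension (20 * 20 = 400) instead of broadcasting it, so the torch.broadcast_shapes check in KroneckerProductLinearOperator.__init__ fails. The sketch below uses identity matrices as stand-ins for both Kronecker factors (only the shapes matter), and batch_shapes_compatible is a hypothetical helper that mirrors the check shown in the trace. It is an illustration of the shapes, not a patch for GPyTorch or BoTorch.

```python
import torch


def batch_shapes_compatible(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Hypothetical helper mirroring the broadcast check in the trace above."""
    try:
        torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])
        return True
    except RuntimeError:
        return False


batch_shape = torch.Size([20])  # one model per leave-one-out fold
n_train, num_tasks = 19, 4

# Stand-ins for the two Kronecker factors; values are irrelevant, only shapes matter
covar_x = torch.eye(n_train).expand(*batch_shape, n_train, n_train)  # (20, 19, 19)
# Assumption: the task covariance already has the model batch shape, (20, 4, 4)
covar_i = torch.eye(num_tasks).expand(*batch_shape, num_tasks, num_tasks)

# x1.shape[:-2] for batched training inputs of shape (20, 19, 1)
x1_batch_shape = batch_shape

# Tensor.repeat tiles existing dimensions: (20, 4, 4) repeated (20, 1, 1) gives (400, 4, 4)
repeated = covar_i.repeat(*x1_batch_shape, 1, 1)

# Tensor.expand broadcasts instead of tiling, so the batch shape stays (20,)
expanded = covar_i.expand(*x1_batch_shape, num_tasks, num_tasks)

print(repeated.shape)                              # torch.Size([400, 4, 4])
print(batch_shapes_compatible(covar_x, repeated))  # False
print(batch_shapes_compatible(covar_x, expanded))  # True
```

The (400, 4, 4) produced by repeat matches the second operator shape in the error message, which is why the Kronecker product rejects the pair.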
Expected Behavior
System information
- BoTorch 0.12.0 (also occurs with 0.11.3, although this code requires 0.12 due to the lack of Standardize)
- GPyTorch 1.13 (also occurs with 1.12)
- PyTorch 2.0.1+cu117
- Linux