[Feature Request] Input Transformations #1652
Comments
I'll just add concrete examples of what we tried and where it fails.

- Apply one-to-many transforms at …
- Define a wrapper around `gpytorch/gpytorch/models/exact_gp.py`, line 256 (at commit `a0d8cd2`). We could maybe get around this by wrapping the call with `with debug(False)`, but that also breaks tests, so probably not a good idea.

The current proposal at pytorch/botorch#819 is to apply the transforms at …
@Balandat and @saitcakmak, is this the interface you have in mind?

```python
class MyGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        # ...
        self.transform = InputTransform(...)

    def forward(self, x):
        x = self.transform(x)
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)
```
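For concreteness, a fixed transform applied this way is just a module that rescales its input before the GP sees it. Here is a minimal torch-only sketch; the `Normalize` module below is an illustrative stand-in, not BoTorch's actual implementation:

```python
import torch

class Normalize(torch.nn.Module):
    """Illustrative min-max scaling transform mapping inputs to the unit cube.

    A stand-in for the kind of fixed input transform discussed above; this is
    NOT BoTorch's actual Normalize implementation.
    """

    def __init__(self, bounds):
        super().__init__()
        # bounds is a (2, d) tensor of per-dimension lower/upper bounds
        self.register_buffer("lower", bounds[0])
        self.register_buffer("range", bounds[1] - bounds[0])

    def forward(self, x):
        return (x - self.lower) / self.range

bounds = torch.tensor([[0.0, 10.0], [5.0, 20.0]])
tf = Normalize(bounds)
x = torch.tensor([[0.0, 10.0], [5.0, 20.0], [2.5, 15.0]])
out = tf(x)  # rows map to [0, 0], [1, 1], [0.5, 0.5]
```

In a `forward`-based setup like `MyGP` above, `self.transform` would be such a module, called on `x` before the mean and covariance modules.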
So, that is the current (pre-pytorch/botorch#819) interface on the BoTorch end, and there are some issues with it. It does not play nicely with transforms that change the shape of the input. I think

```python
def __call__(self, *args, **kwargs):
    train_inputs = (
        list(self.transform_inputs(self.train_inputs, train=True))
        if self.train_inputs is not None
        else []
    )
    inputs = [
        self.transform_inputs(i.unsqueeze(-1) if i.ndimension() == 1 else i, train=False)
        for i in args
    ]
    ...
```

… where we can handle the …
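To make the shape issue concrete, here is a hedged, torch-only sketch of a one-to-many transform (modeled loosely on the idea behind input perturbations; the class and names are illustrative, not a BoTorch API): an `(n, d)` batch becomes an `(n * k, d)` batch, so the transformed inputs no longer line up one-to-one with the `n` training targets if the transform only runs inside `forward`.

```python
import torch

class InputPerturbation(torch.nn.Module):
    """Illustrative one-to-many input transform: each point is expanded into
    k perturbed copies, so an (n, d) batch becomes an (n * k, d) batch."""

    def __init__(self, perturbations):
        super().__init__()
        self.register_buffer("perturbations", perturbations)  # shape (k, d)

    def forward(self, x):
        # (n, 1, d) + (k, d) broadcasts to (n, k, d); flatten to (n * k, d)
        expanded = x.unsqueeze(-2) + self.perturbations
        return expanded.reshape(-1, x.shape[-1])

tf = InputPerturbation(torch.tensor([[0.0, 0.0], [0.1, -0.1], [0.2, 0.2]]))
x = torch.rand(5, 2)
y = tf(x)  # shape (15, 2): 3 copies of each of the 5 input points
```

With 5 training targets but 15 transformed inputs, any code that pairs `train_inputs` with `train_targets` after the transform needs to know about the expansion, which is why handling it in `__call__` comes up.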
After some discussion with @saitcakmak, it looks like placing the input transforms in … If it's done in the forward pass, then the inducing points would be estimated on the raw, untransformed scale because of the call here.

I have a version of the current BoTorch transforms + variational GPs in pytorch/botorch#895, but it only works for fixed, non-learnable transforms (e.g., not learnable input warping) because it forcibly sets the "training inputs", and thus the model inputs, to be on the transformed scale, at least when using model-fitting utilities such as …
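The inducing-point issue can be seen with a small torch-only sketch (the normalization and numbers here are made up for illustration): if the transform only runs inside `forward`, anything initialized from the raw training inputs, such as inducing points, lives on a different scale than what the kernel actually sees.

```python
import torch

# Raw training inputs on their natural scale, e.g. [0, 10].
train_x = torch.linspace(0.0, 10.0, 100).unsqueeze(-1)

# Hypothetical fixed transform applied inside forward(): unit-interval scaling.
transform = lambda x: x / 10.0

# Inducing points are typically initialized from the *raw* inputs...
inducing_points = train_x[:16].clone()

# ...but the kernel only ever evaluates *transformed* inputs.
kernel_inputs = transform(train_x)

# The inducing points exceed the transformed input range: a scale mismatch.
scale_mismatch = inducing_points.max() > kernel_inputs.max()  # True
```

A learnable transform makes this worse: the "correct" scale for the inducing points changes every optimization step.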
## 🚀 Feature Request

Essentially, upstream BoTorch's `InputTransformation` from https://github.com/pytorch/botorch/blob/master/botorch/models/transforms/input.py to GPyTorch. This would allow adding either fixed or learnable transformations that are automatically applied when training models.

## Motivation

This makes it possible to normalize inputs, and also to combine the GP with a learnable transformation, which simplifies model setup. We currently have this in BoTorch and essentially apply the transform in the `forward` methods.

## Additional context
We recently worked on input transformations that can change the shape of the input (pytorch/botorch#819), which caused some headaches for how to best set this up without a ton of boilerplate code. We were hoping to do this in `__call__` rather than in the `forward` method, but this collides with some of GPyTorch's assumptions. Moving this functionality upstream into GPyTorch would allow us to solve these challenges more organically.

## Describe alternatives you've considered

You could do this as we do it right now, but one would have to add boilerplate transformation code to every implementation of `forward`.
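The `__call__`-versus-`forward` placement can be sketched in plain torch, independent of GP details (class and attribute names below are illustrative, not an actual GPyTorch API): a base class applies the transform once in `__call__`, so subclasses' `forward` methods carry no boilerplate.

```python
import torch

class TransformedModel(torch.nn.Module):
    """Illustrative base class that applies an input transform in __call__,
    so subclasses need no transform boilerplate in forward()."""

    def __init__(self, transform):
        super().__init__()
        self.transform = transform  # any callable on tensors

    def __call__(self, x):
        # Transform once here, then dispatch to forward() as usual.
        return super().__call__(self.transform(x))

class Doubler(TransformedModel):
    def forward(self, x):  # no transform code needed here
        return x * 2

m = Doubler(lambda x: x + 1.0)
out = m(torch.tensor([1.0]))  # (1 + 1) * 2 -> tensor([4.])
```

The catch, as discussed above, is that GPyTorch's `ExactGP.__call__` already does nontrivial work with `train_inputs`, so the transform has to be woven into that logic rather than simply prepended.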