
[Bug] load_state_dict doesn't invalidate cached transformed inputs #1471

Open · mrcslws opened this issue Nov 2, 2022 · 4 comments
Labels: bug (Something isn't working), WIP (Work in Progress)
mrcslws commented Nov 2, 2022

🐛 Bug

botorch models do not properly support model.load_state_dict when using input_transforms with trained parameters. After performing model.load_state_dict, the model continues using cached transformed inputs that were computed with the previous parameters.

gpytorch doesn't have this bug in its own caching; it intentionally clears all caches whenever the state dict is loaded.

Workaround: call model.train() before calling model.load_state_dict().

To reproduce

import copy

import botorch
import gpytorch
import torch


train_X = torch.tensor([[0., 0.],
                        [0.1, 0.1],
                        [0.5, 0.5]])
train_Y = torch.tensor([[0.],
                        [0.5],
                        [1.0]])
test_X = torch.tensor([[0.3, 0.3]])

model = botorch.models.SingleTaskGP(
    train_X, train_Y,
    # This is one example input transform that stores trained parameters in the
    # state dict
    input_transform=botorch.models.transforms.Warp(indices=[0, 1]))

state_dict = copy.deepcopy(model.state_dict())

# Check initial behavior
model.eval()
print(f"Before: {model(test_X).mean.item()}")

# Train model, adjusting the Warp parameters and caching transformed inputs
botorch.fit_gpytorch_mll(
    gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
)

# Workaround: uncomment following line
# model.train()

# Revert to original parameters
model.load_state_dict(state_dict)

# Verify that output matches original output
model.eval()
print(f"After: {model(test_X).mean.item()}")

Actual output

Before: 0.2642167806625366
After: 0.21983212232589722

Expected output

Before: 0.2642167806625366
After: 0.2642167806625366

System information

BoTorch 0.7.2
GPyTorch 1.9.0
PyTorch 1.13.0
MacOS 13.0

@mrcslws added the bug (Something isn't working) label Nov 2, 2022
saitcakmak (Contributor) commented Nov 2, 2022

This is just a product of how these transforms are implemented. Their attributes are not cached the way GPyTorch caches train_train_covar etc.; they are buffers (or, in the case of Warp, learnable parameters). They are included in the output of model.state_dict(), and likewise need to be included in the model.load_state_dict(state_dict) call if you want them to get updated.

Warp in particular is a trainable transform: its parameters are trained alongside the other model hyperparameters. So invalidating these on load_state_dict would not lead to any meaningful outcome when you call the model (unless you retrain the model afterwards).
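
As a minimal sketch of that point (continuing the repro above; the "input_transform" key prefix is an assumption about how the Warp submodule is registered on SingleTaskGP):

    # Hedged sketch, continuing the repro above. Assumes the Warp transform is
    # registered as the "input_transform" submodule, so its learnable parameters
    # appear under that prefix in the state dict.
    for name, value in model.state_dict().items():
        if name.startswith("input_transform"):
            print(name, value)

    # Reloading the saved state dict should restore these entries to their
    # original values, alongside the other model hyperparameters.
    model.load_state_dict(state_dict)
    for name, value in state_dict.items():
        if name.startswith("input_transform"):
            assert torch.allclose(model.state_dict()[name], value)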

It also seems like the differences you see in your example are not due to any caching on the side of Warp or its buffers. When you call model.load_state_dict(state_dict), you do reset the parameters of Warp to the originals (which can be verified by checking the state_dict). What changes on the train() call is that GPyTorch deletes the prediction_strategy. So the caches that stick around are actually the caches on model.prediction_strategy.

Nevermind, I was printing the prediction_strategy in the wrong spot.

saitcakmak (Contributor)

OK, here's what's happening. After model training, an mll.eval() call triggers this bit of code that updates model.train_inputs with the transformed inputs. When you reload the state dict, you reset the input transform but not the transformed inputs, hence the odd behavior. The model.train() call reverts the train inputs to the originals, fixing this mismatch. This behavior will be cleaned up as part of #1372, which I will get back to in the next couple of weeks.
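
To see this concretely, here's a hedged sketch, picking up the repro right after fit_gpytorch_mll (the exact layout of model.train_inputs[0] is an assumption):

    # Hedged sketch, continuing the repro right after fit_gpytorch_mll. The
    # mll.eval() call during fitting is expected to leave the warped
    # (transformed) inputs in model.train_inputs rather than the raw train_X.
    print(torch.allclose(model.train_inputs[0], train_X))  # expected: False

    # model.train() reverts train_inputs to the original, untransformed values,
    # which is why the workaround in the bug report fixes the mismatch.
    model.train()
    print(torch.allclose(model.train_inputs[0], train_X))  # expected: True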

saitcakmak (Contributor)

Another thing to note: currently, in eval mode, the input transforms are only applied if you call the model through model.posterior. Otherwise, the input transforms will only be applied when the model is in train mode, leading to buggy behavior. This is also being cleaned up as part of #1372.
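
A hedged sketch of that difference, continuing the repro (the two call paths below are only expected to diverge once the Warp parameters have moved away from the identity):

    # Hedged sketch, continuing the repro. In eval mode, model.posterior(test_X)
    # applies the input transform to test_X before predicting, while calling the
    # model directly does not, so the two means can differ.
    model.eval()
    direct_mean = model(test_X).mean               # transform not applied to test_X
    posterior_mean = model.posterior(test_X).mean  # transform applied internally
    print(direct_mean.item(), posterior_mean.item())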

mrcslws (Author) commented Nov 2, 2022

Cool, I'm looking forward to watching that PR. Maybe this is obvious at this point, but a quick fix to the current codebase would be to add to botorch.models.Model:

    def _load_from_state_dict(self, *args, **kwargs):
        self._revert_to_original_inputs()
        super()._load_from_state_dict(*args, **kwargs)

(Here I'm just copying the approach used in gpytorch.)
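
For anyone who wants to try this without patching botorch itself, a hedged sketch using a throwaway subclass (PatchedSingleTaskGP is a hypothetical name; this assumes _revert_to_original_inputs is available on the model, as in current botorch):

    # Hedged sketch of the proposed override, applied to a throwaway subclass
    # (PatchedSingleTaskGP is a hypothetical name) rather than botorch.models.Model.
    class PatchedSingleTaskGP(botorch.models.SingleTaskGP):
        def _load_from_state_dict(self, *args, **kwargs):
            # Restore the untransformed train_inputs before loading parameters,
            # mirroring how gpytorch clears its caches on state-dict load.
            self._revert_to_original_inputs()
            super()._load_from_state_dict(*args, **kwargs)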

@esantorella added the WIP (Work in Progress) label Jan 30, 2023