# [Bug] load_state_dict doesn't invalidate cached transformed inputs #1471
## Comments
This is just a product of how these transforms are implemented. Their attributes are not cached the way GPyTorch caches the …

For the particular example of …, it also seems like the differences you see in your example are not due to some caching on the side of …

Nevermind, I was printing the …
Ok, here's what's happening. After the model training, an …
Another thing to note is that currently, in …
Cool, I'm looking forward to watching that PR. Maybe this is obvious at this point, but a quick fix to the current codebase would be to add to …:

```python
def _load_from_state_dict(self, *args, **kwargs):
    self._revert_to_original_inputs()
    super()._load_from_state_dict(*args, **kwargs)
```

(Here I'm just copying the approach used in gpytorch.)
## 🐛 Bug

BoTorch models do not properly support `model.load_state_dict` when using `input_transform`s with trained parameters. After calling `model.load_state_dict`, the model continues using cached transformed inputs that were computed with the previous parameters. GPyTorch doesn't have this bug in its caching; it intentionally clears all caches whenever loading the state dict.
Workaround: call `model.train()` before calling `model.load_state_dict()`.
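Concretely, the workaround is just a mode switch around the load (a minimal sketch; `model2` and `sd` are placeholders matching the repro sketch below):

```python
# Workaround sketch: entering train mode reverts train_inputs to the
# original, untransformed values, so nothing stale survives the load.
model2.train()
model2.load_state_dict(sd)
model2.eval()  # re-caches transformed inputs under the loaded parameters
```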
## To reproduce
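The reporter's exact script isn't shown here; the following is a minimal sketch of the setup described above. The model choice, data, and the `Normalize(learn_bounds=True)` transform are assumptions for illustration, not the original code:

```python
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood

torch.manual_seed(0)
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)

def make_model():
    # learn_bounds=True gives the input transform trained parameters.
    return SingleTaskGP(
        train_X, train_Y, input_transform=Normalize(d=2, learn_bounds=True)
    )

# Train a model; the Normalize bounds are learned along with the other
# hyperparameters, and eval() caches transformed train inputs.
model = make_model()
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))
model.eval()
sd = model.state_dict()

# Fresh model: evaluate once so it caches inputs transformed with its
# *initial* parameters, then load the trained state dict.
model2 = make_model()
model2.eval()
_ = model2.posterior(train_X)
model2.load_state_dict(sd)  # parameters update; the cache does not

# Expected: True (identical posteriors). Actual: False, since model2
# still uses train inputs cached under the old transform parameters.
print(torch.allclose(
    model.posterior(train_X).mean, model2.posterior(train_X).mean
))
```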
## Actual output

…

## Expected output

…
## System information

- BoTorch 0.7.2
- GPyTorch 1.9.0
- PyTorch 1.13.0
- macOS 13.0