From cdd668d18b2a7e35bed09b7a2b2fca40e5fd2067 Mon Sep 17 00:00:00 2001
From: Sait Cakmak
Date: Wed, 31 Aug 2022 14:03:54 -0700
Subject: [PATCH] Move input transforms to GPyTorch

Summary:
This diff presents a minimal implementation of input transforms in GPyTorch.

What this does:
* Moves `transform_inputs` from the BoTorch `Model` class to the GPyTorch `GP`
  class, with some modifications to explicitly identify whether the given
  inputs are train or test inputs.
* Modifies `InputTransform.forward` to take an `is_training_input` argument,
  rather than checking `self.training`, when deciding whether to apply the
  transforms that have `transform_on_train=True` (a sketch follows below).
* Removes the `preprocess_transform` method, since it is no longer needed.
* For `ExactGP` models, transforms both train and test inputs in `__call__`.
  For `train_inputs` it always uses `is_training_input=True`. For generic
  `inputs`, it uses `is_training_input=self.training`, which signals that
  these are training inputs when the model is in `train` mode and test inputs
  when the model is in `eval` mode.
* For `ApproximateGP` models, applies the transform to `inputs` in `__call__`
  using `is_training_input=self.training`. This again signifies whether the
  given inputs are train or test inputs based on the mode of the model. Note
  that this NEVER transforms `inducing_points`, which fixes the previous bug
  where `inducing_points` got transformed in `train` mode but not in `eval`
  mode. It is expected that the user will define inducing points in the
  appropriate space (typically the normalized space / unit cube).
* For BoTorch's `SingleTaskVariationalGP`, moves the `input_transform`
  attribute down to `_SingleTaskVariationalGP`, which is the actual
  `ApproximateGP` instance. This makes the transform accessible from GPyTorch.

What this doesn't do:
* It doesn't do anything about `DeterministicModel`s. Those will still need to
  handle their own transforms, which is not implemented here. If we make
  `Model` inherit from `GP`, we can keep the existing setup with very minimal
  changes.
* It does not clean up the call sites of `self.transform_inputs`. That method
  is just made into a no-op here, and the clean-up is left for later.
* It does not upstream the abstract `InputTransform` classes to GPyTorch. That
  will be done if we decide to go forward with this design.
* It does not touch `PairwiseGP`. `PairwiseGP` makes some non-standard use of
  input transforms, so it needs an audit to make sure things still work.
* I didn't look into `ApproximateGP.fantasize`; this may need changes similar
  to `ExactGP.get_fantasy_model`.
* It does not support `PyroGP` or `DeepGP`.
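
To make the new contract concrete, here is a minimal sketch. `SimpleNormalize`
and the usage at the bottom are hypothetical stand-ins written for this
summary; only the `forward(X, is_training_input)` dispatch mirrors the actual
change to `InputTransform` in this diff (the `transform_on_fantasize` branch
is omitted for brevity):

    import torch
    from torch import Tensor

    class SimpleNormalize(torch.nn.Module):
        # Hypothetical transform mapping inputs from `bounds` to the unit cube.
        transform_on_train: bool = True
        transform_on_eval: bool = True

        def __init__(self, bounds: Tensor) -> None:
            super().__init__()
            # `bounds` is a `2 x d` tensor of lower and upper bounds.
            self.lower, self.upper = bounds[0], bounds[1]

        def transform(self, X: Tensor) -> Tensor:
            return (X - self.lower) / (self.upper - self.lower)

        def forward(self, X: Tensor, is_training_input: bool) -> Tensor:
            # The caller states explicitly whether `X` is a training input,
            # instead of the transform inspecting `self.training`.
            if is_training_input:
                if self.transform_on_train:
                    return self.transform(X)
            elif self.transform_on_eval:
                return self.transform(X)
            return X

    tf = SimpleNormalize(bounds=torch.tensor([[0.0], [10.0]]))
    X = torch.rand(5, 1) * 10.0
    # `ExactGP.__call__` would pass `is_training_input=True` for
    # `train_inputs` and `is_training_input=self.training` otherwise.
    X_train = tf(X, is_training_input=True)   # always transformed here
    X_test = tf(X, is_training_input=False)   # transformed iff transform_on_eval
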
Differential Revision: D39147547

fbshipit-source-id: ed2745b0ff666a13764759e1511a139c228d1d39
---
 botorch/models/approximate_gp.py           |  6 +-
 botorch/models/model.py                    | 79 +++-------------------
 botorch/models/model_list_gp_regression.py | 10 ---
 botorch/models/transforms/input.py         | 53 ++-------------
 4 files changed, 20 insertions(+), 128 deletions(-)

diff --git a/botorch/models/approximate_gp.py b/botorch/models/approximate_gp.py
index e9ead0c239..2f124073ad 100644
--- a/botorch/models/approximate_gp.py
+++ b/botorch/models/approximate_gp.py
@@ -168,6 +168,7 @@ def __init__(
         variational_distribution: Optional[_VariationalDistribution] = None,
         variational_strategy: Type[_VariationalStrategy] = VariationalStrategy,
         inducing_points: Optional[Union[Tensor, int]] = None,
+        input_transform: Optional[InputTransform] = None,
     ) -> None:
         r"""
         Args:
@@ -252,6 +253,8 @@ def __init__(
         super().__init__(variational_strategy=variational_strategy)
         self.mean_module = mean_module
         self.covar_module = covar_module
+        if input_transform is not None:
+            self.input_transform = input_transform
 
     def forward(self, X) -> MultivariateNormal:
         mean_x = self.mean_module(X)
@@ -373,14 +376,13 @@ def __init__(
             variational_distribution=variational_distribution,
             variational_strategy=variational_strategy,
             inducing_points=inducing_points,
+            input_transform=input_transform,
         )
 
         super().__init__(model=model, likelihood=likelihood, num_outputs=num_outputs)
 
         if outcome_transform is not None:
             self.outcome_transform = outcome_transform
-        if input_transform is not None:
-            self.input_transform = input_transform
 
         # for model fitting utilities
         # TODO: make this a flag?
diff --git a/botorch/models/model.py b/botorch/models/model.py
index da0645e2f3..43bddafb55 100644
--- a/botorch/models/model.py
+++ b/botorch/models/model.py
@@ -36,19 +36,8 @@ class Model(Module, ABC):
 
     Model cannot be used directly; it only defines an API for other BoTorch
     models.
-
-    Args:
-        _has_transformed_inputs: A boolean denoting whether `train_inputs` are
-            currently stored as transformed or not.
-        _original_train_inputs: A Tensor storing the original train inputs for
-            use in `_revert_to_original_inputs`. Note that this is necessary
-            since the transform / untransform cycle introduces numerical errors
-            which lead to upstream errors during training.
     """
 
-    _has_transformed_inputs: bool = False
-    _original_train_inputs: Optional[Tensor] = None
-
     @abstractmethod
     def posterior(
         self,
@@ -199,57 +188,11 @@ def transform_inputs(
         Returns:
             A tensor of transformed inputs
         """
-        if input_transform is not None:
-            input_transform.to(X)
-            return input_transform(X)
-        try:
-            return self.input_transform(X)
-        except AttributeError:
-            return X
-
-    def _set_transformed_inputs(self) -> None:
-        r"""Update training inputs with transformed inputs."""
-        if hasattr(self, "input_transform") and not self._has_transformed_inputs:
-            if hasattr(self, "train_inputs"):
-                self._original_train_inputs = self.train_inputs[0]
-                with torch.no_grad():
-                    X_tf = self.input_transform.preprocess_transform(
-                        self.train_inputs[0]
-                    )
-                self.set_train_data(X_tf, strict=False)
-                self._has_transformed_inputs = True
-            else:
-                warnings.warn(
-                    "Could not update `train_inputs` with transformed inputs "
-                    f"since {self.__class__.__name__} does not have a `train_inputs` "
-                    "attribute. Make sure that the `input_transform` is applied to "
-                    "both the train inputs and test inputs.",
-                    RuntimeWarning,
-                )
-
-    def _revert_to_original_inputs(self) -> None:
-        r"""Revert training inputs back to original."""
-        if hasattr(self, "input_transform") and self._has_transformed_inputs:
-            self.set_train_data(self._original_train_inputs, strict=False)
-            self._has_transformed_inputs = False
-
-    def eval(self) -> Model:
-        r"""Puts the model in `eval` mode and sets the transformed inputs."""
-        self._set_transformed_inputs()
-        return super().eval()
-
-    def train(self, mode: bool = True) -> Model:
-        r"""Puts the model in `train` mode and reverts to the original inputs.
-
-        Args:
-            mode: A boolean denoting whether to put in `train` or `eval` mode.
-                If `False`, model is put in `eval` mode.
-        """
-        if mode:
-            self._revert_to_original_inputs()
-        else:
-            self._set_transformed_inputs()
-        return super().train(mode=mode)
+        warnings.warn(
+            "`Model.transform_inputs` is deprecated. Input transforms are applied at GPyTorch model `__call__` instead.",
+            DeprecationWarning,
+        )
+        return X
 
 
 class ModelList(Model):
@@ -413,10 +356,8 @@ def transform_inputs(self, X: Tensor) -> List[Tensor]:
         Returns:
             A list of tensors of transformed inputs.
         """
-        transformed_X_list = []
-        for model in self.models:
-            try:
-                transformed_X_list.append(model.input_transform(X))
-            except AttributeError:
-                transformed_X_list.append(X)
-        return transformed_X_list
+        warnings.warn(
+            "`Model.transform_inputs` is deprecated. Input transforms are applied at GPyTorch model `__call__` instead.",
+            DeprecationWarning,
+        )
+        return [X for _ in self.models]
diff --git a/botorch/models/model_list_gp_regression.py b/botorch/models/model_list_gp_regression.py
index 2c62b92e97..1283767008 100644
--- a/botorch/models/model_list_gp_regression.py
+++ b/botorch/models/model_list_gp_regression.py
@@ -114,13 +114,3 @@ def subset_output(self, idcs: List[int]) -> ModelListGP:
             The current model, subset to the specified output indices.
         """
         return self.__class__(*[deepcopy(self.models[i]) for i in idcs])
-
-    def _set_transformed_inputs(self) -> None:
-        r"""Update training inputs with transformed inputs."""
-        for m in self.models:
-            m._set_transformed_inputs()
-
-    def _revert_to_original_inputs(self) -> None:
-        r"""Revert training inputs back to original."""
-        for m in self.models:
-            m._revert_to_original_inputs()
diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py
index fb084da5f1..5498e5e19e 100644
--- a/botorch/models/transforms/input.py
+++ b/botorch/models/transforms/input.py
@@ -55,16 +55,20 @@ class InputTransform(ABC):
     transform_on_train: bool
    transform_on_fantasize: bool
 
-    def forward(self, X: Tensor) -> Tensor:
+    def forward(self, X: Tensor, is_training_input: bool) -> Tensor:
         r"""Transform the inputs to a model.
 
         Args:
             X: A `batch_shape x n x d`-dim tensor of inputs.
+            is_training_input: A boolean denoting whether the input is a training input.
+                If true, only the transforms with `transform_on_train=True` are applied.
+                Otherwise, the transform will be applied based on `transform_on_eval`
+                and `transform_on_fantasize` options.
 
         Returns:
             A `batch_shape x n' x d`-dim tensor of transformed inputs.
         """
-        if self.training:
+        if is_training_input:
             if self.transform_on_train:
                 return self.transform(X)
         elif self.transform_on_eval:
@@ -123,33 +127,6 @@ def equals(self, other: InputTransform) -> bool:
             )
         )
 
-    def preprocess_transform(self, X: Tensor) -> Tensor:
-        r"""Apply transforms for preprocessing inputs.
-
-        The main use cases for this method are 1) to preprocess training data
-        before calling `set_train_data` and 2) preprocess `X_baseline` for noisy
-        acquisition functions so that `X_baseline` is "preprocessed" with the
-        same transformations as the cached training inputs.
-
-        Args:
-            X: A `batch_shape x n x d`-dim tensor of inputs.
-
-        Returns:
-            A `batch_shape x n x d`-dim tensor of (transformed) inputs.
-        """
-        if self.transform_on_train:
-            # We need to disable learning of bounds here.
-            # See why: https://github.com/pytorch/botorch/issues/1078.
-            if hasattr(self, "learn_bounds"):
-                learn_bounds = self.learn_bounds
-                self.learn_bounds = False
-                result = self.transform(X)
-                self.learn_bounds = learn_bounds
-                return result
-            else:
-                return self.transform(X)
-        return X
-
 
 class ChainedInputTransform(InputTransform, ModuleDict):
     r"""An input transform representing the chaining of individual transforms."""
@@ -224,24 +201,6 @@ def equals(self, other: InputTransform) -> bool:
             t1 == t2 for t1, t2 in zip(self.values(), other.values())
         )
 
-    def preprocess_transform(self, X: Tensor) -> Tensor:
-        r"""Apply transforms for preprocessing inputs.
-
-        The main use cases for this method are 1) to preprocess training data
-        before calling `set_train_data` and 2) preprocess `X_baseline` for noisy
-        acquisition functions so that `X_baseline` is "preprocessed" with the
-        same transformations as the cached training inputs.
-
-        Args:
-            X: A `batch_shape x n x d`-dim tensor of inputs.
-
-        Returns:
-            A `batch_shape x n x d`-dim tensor of (transformed) inputs.
-        """
-        for tf in self.values():
-            X = tf.preprocess_transform(X)
-        return X
-
 
 class ReversibleInputTransform(InputTransform, ABC):
     r"""An abstract class for a reversible input transform.