Skip to content

Commit

Permalink
Affine input transforms should error with data of incorrect dimension…
Browse files Browse the repository at this point in the history
…, even in eval mode (#2510)

Summary:
Context: #2509 gives a clear overview

This PR:
* Checks the shape of the `X` provided to an `AffineInputTransform` when it transforms the data, regardless of whether it is updating the coefficients.

Makes some unrelated changes:
* Fixes the example in the docstring for `batched_multi_output_to_single_output`
* fixes an incorrect shape in `test_approximate_gp`
* Makes data and transform batch shapes match in `TestConverters`, since those usages will now (appropriately) error

Pull Request resolved: #2510

Reviewed By: saitcakmak

Differential Revision: D62318530

Pulled By: esantorella

fbshipit-source-id: eaa8b0410c49b17d6abbe1391bbb0750313aea23
  • Loading branch information
esantorella authored and facebook-github-bot committed Sep 9, 2024
1 parent 4dc1271 commit 33e11f4
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 20 deletions.
4 changes: 2 additions & 2 deletions botorch/models/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ def batched_multi_output_to_single_output(
Example:
>>> train_X = torch.rand(5, 2)
>>> train_Y = torch.rand(5, 2)
>>> batch_mo_gp = SingleTaskGP(train_X, train_Y)
>>> batch_so_gp = batched_multioutput_to_single_output(batch_gp)
>>> batch_mo_gp = SingleTaskGP(train_X, train_Y, outcome_transform=None)
>>> batch_so_gp = batched_multi_output_to_single_output(batch_mo_gp)
"""
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
was_training = batch_mo_model.training
Expand Down
2 changes: 1 addition & 1 deletion botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ def _transform(self, X: Tensor) -> Tensor:
Returns:
A `batch_shape x n x d`-dim tensor of transformed inputs.
"""
self._check_shape(X)
if self.learn_coefficients and self.training:
self._check_shape(X)
self._update_coefficients(X)
self._to(X)
return (X - self.offset) / self.coefficient
Expand Down
8 changes: 4 additions & 4 deletions test/acquisition/test_proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_proximal(self):
proximal_test_X = test_X.clone()
if transformed_weighting:
if input_transform is not None:
last_X = input_transform(train_X[-1])
last_X = input_transform(train_X[-1].unsqueeze(0))
proximal_test_X = input_transform(test_X)

mv_normal = MultivariateNormal(last_X, torch.diag(proximal_weights))
Expand All @@ -105,7 +105,7 @@ def test_proximal(self):
proximal_test_X = test_X.clone()
if transformed_weighting:
if input_transform is not None:
last_X = input_transform(train_X[-1])
last_X = input_transform(train_X[-1].unsqueeze(0))
proximal_test_X = input_transform(test_X)

mv_normal = MultivariateNormal(last_X, torch.diag(proximal_weights))
Expand All @@ -122,7 +122,7 @@ def test_proximal(self):
proximal_test_X = test_X.clone()
if transformed_weighting:
if input_transform is not None:
last_X = input_transform(train_X[-1])
last_X = input_transform(train_X[-1].unsqueeze(0))
proximal_test_X = input_transform(test_X)

ei = EI(test_X)
Expand All @@ -143,7 +143,7 @@ def test_proximal(self):
proximal_test_X = test_X.clone()
if transformed_weighting:
if input_transform is not None:
last_X = input_transform(train_X[-1])
last_X = input_transform(train_X[-1].unsqueeze(0))
proximal_test_X = input_transform(test_X)

qEI_prox = ProximalAcquisitionFunction(
Expand Down
2 changes: 1 addition & 1 deletion test/models/test_approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,5 +327,5 @@ def test_input_transform(self) -> None:
model.likelihood, model.model, num_data=train_X.shape[-2]
)
fit_gpytorch_mll(mll)
post = model.posterior(torch.tensor([train_X.mean()]))
post = model.posterior(torch.tensor([[train_X.mean()]]))
self.assertAllClose(post.mean[0][0], y.mean(), atol=1e-3, rtol=1e-3)
15 changes: 11 additions & 4 deletions test/models/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,21 @@ def test_model_list_to_batched(self):
batch_shape=torch.Size([3]),
)
gp1_ = SingleTaskGP(
train_X, train_Y1, input_transform=input_tf2, outcome_transform=None
train_X=train_X.unsqueeze(0),
train_Y=train_Y1.unsqueeze(0),
input_transform=input_tf2,
outcome_transform=None,
)
gp2_ = SingleTaskGP(
train_X, train_Y2, input_transform=input_tf2, outcome_transform=None
train_X=train_X.unsqueeze(0),
train_Y=train_Y2.unsqueeze(0),
input_transform=input_tf2,
outcome_transform=None,
)
list_gp = ModelListGP(gp1_, gp2_)
with self.assertRaises(UnsupportedError):
with self.assertRaisesRegex(
UnsupportedError, "Batched input_transforms are not supported."
):
model_list_to_batched(list_gp)

# test outcome transform
Expand Down Expand Up @@ -457,7 +465,6 @@ def test_batched_multi_output_to_single_output(self):
bounds=torch.tensor(
[[-1.0, -1.0], [1.0, 1.0]], device=self.device, dtype=dtype
),
batch_shape=torch.Size([2]),
)
batched_mo_model = SingleTaskGP(
train_X, train_Y, input_transform=input_tf, outcome_transform=None
Expand Down
17 changes: 12 additions & 5 deletions test/models/transforms/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ def test_normalize(self) -> None:
self.assertTrue(nlz.mins.dtype == other_dtype)
# test incompatible dimensions of specified bounds
bounds = torch.zeros(2, 3, device=self.device, dtype=dtype)
with self.assertRaises(BotorchTensorDimensionError):
with self.assertRaisesRegex(
BotorchTensorDimensionError,
"Dimensions of provided `bounds` are incompatible",
):
Normalize(d=2, bounds=bounds)

# test jitter
Expand Down Expand Up @@ -266,7 +269,12 @@ def test_normalize(self) -> None:
# test errors on wrong shape
nlz = Normalize(d=2, batch_shape=batch_shape)
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
with self.assertRaises(BotorchTensorDimensionError):
expected_msg = "Wrong input dimension. Received 1, expected 2."
with self.assertRaisesRegex(BotorchTensorDimensionError, expected_msg):
nlz(X)
# Same error in eval mode
nlz.eval()
with self.assertRaisesRegex(BotorchTensorDimensionError, expected_msg):
nlz(X)

# fixed bounds
Expand Down Expand Up @@ -328,9 +336,8 @@ def test_normalize(self) -> None:
[X.min(dim=-2, keepdim=True)[0], X.max(dim=-2, keepdim=True)[0]],
dim=-2,
)[..., indices]
self.assertTrue(
torch.allclose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)
)
self.assertAllClose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)

# test errors on wrong shape
nlz = Normalize(d=2, batch_shape=batch_shape)
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
Expand Down
8 changes: 5 additions & 3 deletions test_community/models/test_gp_regression_multisource.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def _get_model_and_data(
None if train_Yvar else get_gaussian_likelihood_with_gamma_prior()
),
}
model = SingleTaskAugmentedGP(**model_kwargs, **extra_model_kwargs)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=OptimizationWarning)
model = SingleTaskAugmentedGP(**model_kwargs, **extra_model_kwargs)
return model, model_kwargs

def test_data_init(self):
Expand Down Expand Up @@ -139,8 +141,8 @@ def test_get_reliable_observation(self):
self.assertListEqual(res.tolist(), true_res.tolist())

def test_gp(self):
bounds = torch.tensor([[-1.0], [1.0]])
d = 5
bounds = torch.stack((torch.full((d - 1,), -1), torch.ones(d - 1)))
for batch_shape, dtype, use_octf, use_intf, train_Yvar in itertools.product(
(torch.Size(), torch.Size([2])),
(torch.float, torch.double),
Expand All @@ -151,7 +153,7 @@ def test_gp(self):
tkwargs = {"device": self.device, "dtype": dtype}
octf = Standardize(m=1, batch_shape=torch.Size()) if use_octf else None
intf = (
Normalize(d=1, bounds=bounds.to(**tkwargs), transform_on_train=True)
Normalize(d=d - 1, bounds=bounds.to(**tkwargs), transform_on_train=True)
if use_intf
else None
)
Expand Down

0 comments on commit 33e11f4

Please sign in to comment.