From b58852e28528b1ed88058f9cbe369869c21556f7 Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Tue, 10 Sep 2024 07:15:09 -0700 Subject: [PATCH] Allow broadcasting across dimensions in eval mode; always require X to be at least 2d (#2518) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/2518 Context: A discussion on allowable shapes for transforms concluded: * We should not allow for broadcasting across the -1 dimension, so the first check in _check_shape should always happen. * The shapes always need to be broadcastable, so the torch.broadcast_shapes check in _check_shape should always happen. * We want to allow for broadcasting across the batch dimension in eval mode, so the check that X has dimension of at least len(batch_shape) + 2 should only happen in training mode. * For clarity, we should disallow 1d X, even if broadcastable. BoTorch tends to be strict about requiring explicit dimensions, e.g. GPyTorchModel._validate_tensor_args, and that's a good thing because confusion about tensor dimensions causes a lot of pain. This diff: * Only checks that X has number of dimensions equal to 2 + the number of batch dimensions in training mode. * Disallows <2d X. Reviewed By: Balandat Differential Revision: D62404492 fbshipit-source-id: ea287effee86b9f1eb67863b21a95d6c0a9e49b3 --- botorch/models/transforms/input.py | 6 ++++- test/models/transforms/test_input.py | 37 +++++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index a8bdb256f6..329e407c50 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -462,9 +462,13 @@ def _check_shape(self, X: Tensor) -> None: f"Wrong input dimension. Received {X.size(-1)}, " f"expected {self.offset.size(-1)}." ) + if X.ndim < 2: + raise BotorchTensorDimensionError( + f"`X` must have at least 2 dimensions, but has {X.ndim}." 
+ ) n = len(self.batch_shape) + 2 - if X.ndim < n: + if self.training and X.ndim < n: raise ValueError( f"`X` must have at least {n} dimensions, {n - 2} batch and 2 innate" f" , but has {X.ndim}." diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py index 5537c72cd3..0d09dab09e 100644 --- a/test/models/transforms/test_input.py +++ b/test/models/transforms/test_input.py @@ -240,9 +240,19 @@ def test_normalize(self) -> None: X = torch.cat((torch.randn(4, 1), torch.zeros(4, 1)), dim=-1) X = X.to(self.device) self.assertEqual(torch.isfinite(nlz(X)).sum(), X.numel()) - with self.assertRaisesRegex(ValueError, r"must have at least \d+ dim"): + with self.assertRaisesRegex( + BotorchTensorDimensionError, r"must have at least 2 dimensions" + ): nlz(torch.randn(X.shape[-1], dtype=dtype)) + # using unbatched X to train batched transform + nlz = Normalize(d=2, min_range=1e-4, batch_shape=torch.Size([3])) + X = torch.rand(4, 2) + with self.assertRaisesRegex( + ValueError, "must have at least 3 dimensions, 1 batch and 2 innate" + ): + nlz(X) + # basic usage for batch_shape in (torch.Size(), torch.Size([3])): # learned bounds @@ -341,7 +351,10 @@ def test_normalize(self) -> None: # test errors on wrong shape nlz = Normalize(d=2, batch_shape=batch_shape) X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype) - with self.assertRaises(BotorchTensorDimensionError): + with self.assertRaisesRegex( + BotorchTensorDimensionError, + "Wrong input dimension. 
Received 1, expected 2.", + ): nlz(X) # test equals @@ -403,6 +416,22 @@ def test_normalize(self) -> None: expected_X = torch.tensor([[1.5, 0.75]], device=self.device, dtype=dtype) self.assertAllClose(nlzd_X, expected_X) + # Test broadcasting across batch dimensions in eval mode + x = torch.tensor( + [[0.0, 2.0], [3.0, 5.0]], device=self.device, dtype=dtype + ).unsqueeze(-1) + self.assertEqual(x.shape, torch.Size([2, 2, 1])) + nlz = Normalize(d=1, batch_shape=torch.Size([2])) + nlz(x) + nlz.eval() + x2 = torch.tensor([[1.0]], device=self.device, dtype=dtype) + nlzd_x2 = nlz.transform(x2) + self.assertEqual(nlzd_x2.shape, torch.Size([2, 1, 1])) + self.assertAllClose( + nlzd_x2.squeeze(), + torch.tensor([0.5, -1.0], dtype=dtype, device=self.device), + ) + def test_standardize(self) -> None: for dtype in (torch.float, torch.double): # basic init @@ -459,7 +488,9 @@ def test_standardize(self) -> None: X = torch.cat((torch.randn(4, 1), torch.zeros(4, 1)), dim=-1) X = X.to(self.device, dtype=dtype) self.assertEqual(torch.isfinite(stdz(X)).sum(), X.numel()) - with self.assertRaisesRegex(ValueError, r"must have at least \d+ dim"): + with self.assertRaisesRegex( + BotorchTensorDimensionError, r"must have at least \d+ dim" + ): stdz(torch.randn(X.shape[-1], dtype=dtype)) # basic usage