-
Notifications
You must be signed in to change notification settings - Fork 404
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow broadcasting across dimensions in eval mode; always require X t…
…o be at least 2d (#2518) Summary: Pull Request resolved: #2518 Context: A discussion on allowable shapes for transforms concluded: * We should not allow for broadcasting across the -1 dimension, so the first check in _check_shape should always happen. * The shapes always need to be broadcastable, so the torch.broadcast_shapes check in _check_shape should always happen. * We want to allow for broadcasting across the batch dimension in eval model, so the check that X has dimension of at least len(batch_shape) + 2 should only happen in training mode. * For clarity, we should disallow 1d X, even if broadcastable. BoTorch tends to be strict about requiring explicit dimensions, e.g. GPyTorchModel._validate_tensor_args, and that's a good thing because confusion about tensor dimensions causes a lot of pain. This diff: * Only checks that X has number of dimensions equal to 2 + the number of batch dimensions in training mode. * Disallows <2d X. Reviewed By: Balandat Differential Revision: D62404492 fbshipit-source-id: ea287effee86b9f1eb67863b21a95d6c0a9e49b3
- Loading branch information
1 parent
ebd1727
commit b58852e
Showing
2 changed files
with
39 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters