Skip to content

Commit 0afcf35

Browse files
Balandatfacebook-github-bot
authored andcommitted
Check that choices is non-empty in optimize_acqf_discrete (#1228)
Summary: Otherwise this produces confusing tensor dimension errors if no chocies are given. Also applies ufmt style throughout. Pull Request resolved: #1228 Reviewed By: bernardbeckerman Differential Revision: D36480928 Pulled By: Balandat fbshipit-source-id: 39d8b55ef4666265c1657525326528b8f2c1f4b2
1 parent 1b92ee8 commit 0afcf35

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

botorch/optim/optimize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
OneShotAcquisitionFunction,
1919
)
2020
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
21-
from botorch.exceptions import UnsupportedError
21+
from botorch.exceptions import InputDataError, UnsupportedError
2222
from botorch.generation.gen import gen_candidates_scipy
2323
from botorch.logging import logger
2424
from botorch.optim.initializers import (
@@ -600,6 +600,8 @@ def optimize_acqf_discrete(
600600
"Discrete optimization is not supported for"
601601
"one-shot acquisition functions."
602602
)
603+
if choices.numel() == 0:
604+
raise InputDataError("`choices` must be non-emtpy.")
603605
choices_batched = choices.unsqueeze(-2)
604606
if q > 1:
605607
candidate_list, acq_value_list = [], []

test/optim/test_optimize.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
AcquisitionFunction,
1414
OneShotAcquisitionFunction,
1515
)
16-
from botorch.exceptions import UnsupportedError
16+
from botorch.exceptions import InputDataError, UnsupportedError
1717
from botorch.optim.optimize import (
1818
_filter_infeasible,
1919
_filter_invalid,
@@ -880,6 +880,14 @@ def test_optimize_acqf_discrete(self):
880880
mock_acq_function = SquaredAcquisitionFunction()
881881
mock_acq_function.set_X_pending(None)
882882

883+
# ensure proper raising of errors if no choices
884+
with self.assertRaisesRegex(InputDataError, "`choices` must be non-emtpy."):
885+
optimize_acqf_discrete(
886+
acq_function=mock_acq_function,
887+
q=q,
888+
choices=torch.empty(0, 2),
889+
)
890+
883891
choices = torch.rand(5, 2, **tkwargs)
884892
exp_acq_vals = mock_acq_function(choices)
885893

0 commit comments

Comments
 (0)