File tree 2 files changed +12
-2
lines changed 2 files changed +12
-2
lines changed Original file line number Diff line number Diff line change 18
18
OneShotAcquisitionFunction ,
19
19
)
20
20
from botorch .acquisition .knowledge_gradient import qKnowledgeGradient
21
- from botorch .exceptions import UnsupportedError
21
+ from botorch .exceptions import InputDataError , UnsupportedError
22
22
from botorch .generation .gen import gen_candidates_scipy
23
23
from botorch .logging import logger
24
24
from botorch .optim .initializers import (
@@ -600,6 +600,8 @@ def optimize_acqf_discrete(
600
600
"Discrete optimization is not supported for"
601
601
"one-shot acquisition functions."
602
602
)
603
+ if choices .numel () == 0 :
604
+ raise InputDataError ("`choices` must be non-emtpy." )
603
605
choices_batched = choices .unsqueeze (- 2 )
604
606
if q > 1 :
605
607
candidate_list , acq_value_list = [], []
Original file line number Diff line number Diff line change 13
13
AcquisitionFunction ,
14
14
OneShotAcquisitionFunction ,
15
15
)
16
- from botorch .exceptions import UnsupportedError
16
+ from botorch .exceptions import InputDataError , UnsupportedError
17
17
from botorch .optim .optimize import (
18
18
_filter_infeasible ,
19
19
_filter_invalid ,
@@ -880,6 +880,14 @@ def test_optimize_acqf_discrete(self):
880
880
mock_acq_function = SquaredAcquisitionFunction ()
881
881
mock_acq_function .set_X_pending (None )
882
882
883
+ # ensure proper raising of errors if no choices
884
+ with self .assertRaisesRegex (InputDataError , "`choices` must be non-emtpy." ):
885
+ optimize_acqf_discrete (
886
+ acq_function = mock_acq_function ,
887
+ q = q ,
888
+ choices = torch .empty (0 , 2 ),
889
+ )
890
+
883
891
choices = torch .rand (5 , 2 , ** tkwargs )
884
892
exp_acq_vals = mock_acq_function (choices )
885
893
You can’t perform that action at this time.
0 commit comments