diff --git a/botorch/acquisition/input_constructors.py b/botorch/acquisition/input_constructors.py index 24d8c4bf2c..99f5d6de10 100644 --- a/botorch/acquisition/input_constructors.py +++ b/botorch/acquisition/input_constructors.py @@ -986,7 +986,6 @@ def construct_inputs_qNEHVI( first_only=True, assert_shared=True, ) - # This selects the objectives (a subset of the outcomes) and set each # objective threhsold to have the proper optimization direction. if objective is None: @@ -1014,6 +1013,8 @@ def construct_inputs_qNEHVI( else: ref_point = objective(objective_thresholds) + num_objectives = objective_thresholds[~torch.isnan(objective_thresholds)].shape[0] + return { "model": model, "ref_point": ref_point, @@ -1024,7 +1025,7 @@ def construct_inputs_qNEHVI( "X_pending": kwargs.get("X_pending"), "eta": kwargs.get("eta", 1e-3), "prune_baseline": kwargs.get("prune_baseline", True), - "alpha": kwargs.get("alpha", get_default_partitioning_alpha(model.num_outputs)), + "alpha": kwargs.get("alpha", get_default_partitioning_alpha(num_objectives)), "cache_pending": kwargs.get("cache_pending", True), "max_iep": kwargs.get("max_iep", 0), "incremental_nehvi": kwargs.get("incremental_nehvi", True), diff --git a/test/acquisition/test_input_constructors.py b/test/acquisition/test_input_constructors.py index cde708d1d2..fd563fd955 100644 --- a/test/acquisition/test_input_constructors.py +++ b/test/acquisition/test_input_constructors.py @@ -810,7 +810,7 @@ def test_construct_inputs_qNEHVI(self): X_pending=X_pending, eta=1e-2, prune_baseline=True, - alpha=0.1, + alpha=0.0, cache_pending=False, max_iep=1, incremental_nehvi=False, @@ -831,7 +831,7 @@ def test_construct_inputs_qNEHVI(self): self.assertTrue(torch.equal(kwargs["X_pending"], X_pending)) self.assertEqual(kwargs["eta"], 1e-2) self.assertTrue(kwargs["prune_baseline"]) - self.assertEqual(kwargs["alpha"], 0.1) + self.assertEqual(kwargs["alpha"], 0.0) self.assertFalse(kwargs["cache_pending"]) self.assertEqual(kwargs["max_iep"], 1) self.assertFalse(kwargs["incremental_nehvi"]) @@ -874,7 +874,7 @@ def test_construct_inputs_qNEHVI(self): training_data=self.blockX_blockY, objective_thresholds=objective_thresholds, ) - self.assertEqual(kwargs["alpha"], 1e-3) + self.assertEqual(kwargs["alpha"], 0.0) def test_construct_inputs_kg(self): current_value = torch.tensor(1.23)