ST_MTGP model for MTGP MOO (#1962)
Summary:
Pull Request resolved: #1962

- Replaced the legacy Ax model ST_MTGP_NEHVI with the new modular BoTorch model ST_MTGP.
- In MOO problems, the objective count used for the default partitioning alpha now depends only on the number of objectives (the non-NaN objective thresholds), not on the number of constraints; see the sketch below.

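A minimal sketch of the new objective-counting logic, assuming the get_default_partitioning_alpha import path used by input_constructors.py; the tensor values are illustrative only and not taken from this commit:

    import torch
    from botorch.acquisition.multi_objective.utils import get_default_partitioning_alpha

    # Constraint slots carry NaN thresholds, so counting the non-NaN entries
    # gives the number of objectives rather than the total number of model
    # outputs (objectives + constraints).
    objective_thresholds = torch.tensor([0.5, 1.0, float("nan")])  # 2 objectives, 1 constraint
    num_objectives = objective_thresholds[~torch.isnan(objective_thresholds)].shape[0]  # 2

    # The default alpha is now derived from num_objectives instead of
    # model.num_outputs, matching the change in construct_inputs_qNEHVI.
    alpha = get_default_partitioning_alpha(num_objectives)

The updated tests below reflect this: in the default case, the expected alpha drops from 1e-3 to 0.0 once only the objectives are counted.
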
Reviewed By: saitcakmak

Differential Revision: D47739378

fbshipit-source-id: 9d0bab44d589c1e1d61c47f8be430c122539b884
Authored by Jelena Markovic-Voronov, committed by facebook-github-bot on Aug 1, 2023
1 parent ff3a394 · commit 3b674eb
Showing 2 changed files with 6 additions and 5 deletions.
botorch/acquisition/input_constructors.py (5 changes: 3 additions, 2 deletions)
@@ -986,7 +986,6 @@ def construct_inputs_qNEHVI(
         first_only=True,
         assert_shared=True,
     )
-
     # This selects the objectives (a subset of the outcomes) and set each
     # objective threhsold to have the proper optimization direction.
     if objective is None:
@@ -1014,6 +1013,8 @@ def construct_inputs_qNEHVI(
     else:
         ref_point = objective(objective_thresholds)
 
+    num_objectives = objective_thresholds[~torch.isnan(objective_thresholds)].shape[0]
+
     return {
         "model": model,
         "ref_point": ref_point,
@@ -1024,7 +1025,7 @@
         "X_pending": kwargs.get("X_pending"),
         "eta": kwargs.get("eta", 1e-3),
         "prune_baseline": kwargs.get("prune_baseline", True),
-        "alpha": kwargs.get("alpha", get_default_partitioning_alpha(model.num_outputs)),
+        "alpha": kwargs.get("alpha", get_default_partitioning_alpha(num_objectives)),
         "cache_pending": kwargs.get("cache_pending", True),
         "max_iep": kwargs.get("max_iep", 0),
         "incremental_nehvi": kwargs.get("incremental_nehvi", True),
test/acquisition/test_input_constructors.py (6 changes: 3 additions, 3 deletions)
@@ -810,7 +810,7 @@ def test_construct_inputs_qNEHVI(self):
             X_pending=X_pending,
             eta=1e-2,
             prune_baseline=True,
-            alpha=0.1,
+            alpha=0.0,
             cache_pending=False,
             max_iep=1,
             incremental_nehvi=False,
@@ -831,7 +831,7 @@ def test_construct_inputs_qNEHVI(self):
         self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
         self.assertEqual(kwargs["eta"], 1e-2)
         self.assertTrue(kwargs["prune_baseline"])
-        self.assertEqual(kwargs["alpha"], 0.1)
+        self.assertEqual(kwargs["alpha"], 0.0)
         self.assertFalse(kwargs["cache_pending"])
         self.assertEqual(kwargs["max_iep"], 1)
         self.assertFalse(kwargs["incremental_nehvi"])
@@ -874,7 +874,7 @@ def test_construct_inputs_qNEHVI(self):
             training_data=self.blockX_blockY,
             objective_thresholds=objective_thresholds,
         )
-        self.assertEqual(kwargs["alpha"], 1e-3)
+        self.assertEqual(kwargs["alpha"], 0.0)
 
     def test_construct_inputs_kg(self):
         current_value = torch.tensor(1.23)
