diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index ecda4a9e6c5..289884908df 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -439,11 +439,9 @@ def should_transition_to_next_node( raise NotImplementedError( "Cannot currently select between multiple nodes to transition to." ) - elif len(next_nodes) == 1: - return True, next_nodes[0] else: - # Will transition to the next node in the list. - return True, None + return True, next_nodes[0] + return False, None def generator_run_limit(self, supress_generation_errors: bool = True) -> int: diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index 0fe769fc0c7..ca3ce27123b 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -1284,12 +1284,14 @@ def test_gs_with_nodes_and_blocking_criteria(self) -> None: threshold=3, block_gen_if_met=True, block_transition_if_unmet=True, + transition_to="GPEI_node", ), MinTrials( threshold=2, only_in_statuses=[TrialStatus.COMPLETED], block_gen_if_met=False, block_transition_if_unmet=True, + transition_to="GPEI_node", ), ], ) diff --git a/tutorials/external_generation_node.ipynb b/tutorials/external_generation_node.ipynb index ecf19930490..c1aac0b259d 100644 --- a/tutorials/external_generation_node.ipynb +++ b/tutorials/external_generation_node.ipynb @@ -223,9 +223,11 @@ " model_specs=[ModelSpec(Models.SOBOL)],\n", " transition_criteria=[\n", " MaxTrials(\n", - " # This specifies the maximum number of trials to generate from this node.\n", + " # This specifies the maximum number of trials to generate from this node, \n", + " # and the next node in the strategy.\n", " threshold=5,\n", " block_transition_if_unmet=True,\n", + " transition_to=\"RandomForest\"\n", " )\n", " ],\n", " ),\n",