diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 55179c9a4f0..1f3e6e4ff24 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -151,7 +151,9 @@ def __init__( self.best_model_selector = best_model_selector self.should_deduplicate = should_deduplicate self._transition_criteria = transition_criteria - self._input_constructors = input_constructors + self._input_constructors = ( + input_constructors if input_constructors is not None else {} + ) self._previous_node_name = previous_node_name @property diff --git a/ax/modelbridge/generation_node_input_constructors.py b/ax/modelbridge/generation_node_input_constructors.py index dbb07c9e1e5..d2787c3b55c 100644 --- a/ax/modelbridge/generation_node_input_constructors.py +++ b/ax/modelbridge/generation_node_input_constructors.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +import sys from enum import Enum, unique from math import ceil from typing import Any, Dict, Optional @@ -11,6 +12,54 @@ from ax.modelbridge.generation_node import GenerationNode +@unique +class NodeInputConstructors(Enum): + """An enum which maps to a the name of a callable method for constructing + ``GenerationNode`` inputs. + + NOTE: The methods defined by this enum should all share identical signatures + and reside in this file. + """ + + ALL_N = "consume_all_n" + REPEAT_N = "repeat_arm_n" + REMAINING_N = "remaining_n" + + def __call__( + self, + previous_node: Optional[GenerationNode], + next_node: GenerationNode, + gs_gen_call_kwargs: Dict[str, Any], + ) -> int: + """Defines a callable method for the Enum as all values are methods""" + try: + method = getattr(sys.modules[__name__], self.value) + except AttributeError: + raise ValueError( + f"{self.value} is not defined as a method in " + "``generation_node_input_constructors.py``. Please add the method " + "to the file." + ) + return method( + previous_node=previous_node, + next_node=next_node, + gs_gen_call_kwargs=gs_gen_call_kwargs, + ) + + +@unique +class InputConstructorPurpose(Enum): + """A simple enum to indicate the purpose of the input constructor. + + Explanation of the different purposes: + N: Defines the logic to determine the number of arms to generate from the + next ``GenerationNode`` given the total number of arms expected in + this trial. + """ + + N = "n" + + def consume_all_n( previous_node: Optional[GenerationNode], next_node: GenerationNode, @@ -103,42 +152,3 @@ def remaining_n( total_n = gs_gen_call_kwargs.get("n") # if all arms have been generated, return 0 return max(total_n - sum(len(gr.arms) for gr in grs_this_gen), 0) - - -@unique -class NodeInputConstructors(Enum): - """An enum which maps to a callable method for constructing ``GenerationNode`` - inputs. - - NOTE: The methods defined by this enum should all share identical signatures. - """ - - ALL_N = consume_all_n - REPEAT_N = repeat_arm_n - REMAINING_N = remaining_n - - def __call__( - self, - previous_node: Optional[GenerationNode], - next_node: GenerationNode, - gs_gen_call_kwargs: Dict[str, Any], - ) -> int: - """Defines a callable method for the Enum as all values are methods""" - return self( - previous_node=previous_node, - next_node=next_node, - gs_gen_call_kwargs=gs_gen_call_kwargs, - ) - - -@unique -class InputConstructorPurpose(Enum): - """A simple enum to indicate the purpose of the input constructor. - - Explanation of the different purposes: - N: Defines the logic to determine the number of arms to generate from the - next ``GenerationNode`` given the total number of arms expected in - this trial. - """ - - N = "n" diff --git a/ax/modelbridge/tests/test_generation_node.py b/ax/modelbridge/tests/test_generation_node.py index a76f4668bc9..c744fb12533 100644 --- a/ax/modelbridge/tests/test_generation_node.py +++ b/ax/modelbridge/tests/test_generation_node.py @@ -82,7 +82,7 @@ def test_init(self) -> None: self.assertIs(node.best_model_selector, model_selector) def test_input_constructor_none(self) -> None: - self.assertIsNone(self.sobol_generation_node._input_constructors) + self.assertEqual(self.sobol_generation_node._input_constructors, {}) self.assertEqual(self.sobol_generation_node.input_constructors, {}) def test_input_constructor(self) -> None: diff --git a/ax/modelbridge/tests/test_generation_node_input_constructors.py b/ax/modelbridge/tests/test_generation_node_input_constructors.py index f735f935cb7..20d8c629ad5 100644 --- a/ax/modelbridge/tests/test_generation_node_input_constructors.py +++ b/ax/modelbridge/tests/test_generation_node_input_constructors.py @@ -157,7 +157,7 @@ def test_all_constructors_have_same_signature(self) -> None: method_signature = inspect.signature(all_constructors_tested[0]) for constructor in all_constructors_tested[1:]: with self.subTest(constructor=constructor): - func_parameters = get_type_hints(constructor) + func_parameters = get_type_hints(constructor.__call__) self.assertEqual( Counter(list(func_parameters.keys())), Counter( diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 79df2f3dd7f..6184477812e 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -671,6 +671,15 @@ def generation_node_from_json( if "transition_criteria" in generation_node_json.keys() else None ), + input_constructors=( + object_from_json( + generation_node_json.pop("input_constructors"), + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + if "input_constructors" in generation_node_json.keys() + else None + ), previous_node_name=( generation_node_json.pop("previous_node_name") if "previous_node_name" in generation_node_json.keys() diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 0ec00288c69..9d7516d72be 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -443,6 +443,7 @@ def generation_node_to_dict(generation_node: GenerationNode) -> dict[str, Any]: "transition_criteria": generation_node.transition_criteria, "model_spec_to_gen_from": generation_node._model_spec_to_gen_from, "previous_node_name": generation_node._previous_node_name, + "input_constructors": generation_node.input_constructors, } diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index be2d164a3b1..a99866124bd 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -82,6 +82,10 @@ ) from ax.modelbridge.factory import Models from ax.modelbridge.generation_node import GenerationNode, GenerationStep +from ax.modelbridge.generation_node_input_constructors import ( + InputConstructorPurpose, + NodeInputConstructors, +) from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.model_spec import ModelSpec from ax.modelbridge.registry import ModelRegistryBase @@ -312,6 +316,7 @@ "Hartmann6Metric": Hartmann6Metric, "HierarchicalSearchSpace": HierarchicalSearchSpace, "ImprovementGlobalStoppingStrategy": ImprovementGlobalStoppingStrategy, + "InputConstructorPurpose": InputConstructorPurpose, "Interval": Interval, "LifecycleStage": LifecycleStage, "ListSurrogate": Surrogate, # For backwards compatibility @@ -334,6 +339,7 @@ "MultiObjectiveOptimizationConfig": MultiObjectiveOptimizationConfig, "MultiTypeExperiment": MultiTypeExperiment, "NegativeBraninMetric": NegativeBraninMetric, + "NodeInputConstructors": NodeInputConstructors, "NoisyFunctionMetric": NoisyFunctionMetric, "Normalize": Normalize, "Objective": Objective, diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 27072919f38..d33f246e9c1 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -188,6 +188,20 @@ "GenerationStrategy", partial(sobol_gpei_generation_node_gs, with_previous_node=True), ), + ( + "GenerationStrategy", + partial(sobol_gpei_generation_node_gs, with_input_constructors_all_n=True), + ), + ( + "GenerationStrategy", + partial( + sobol_gpei_generation_node_gs, with_input_constructors_remaining_n=True + ), + ), + ( + "GenerationStrategy", + partial(sobol_gpei_generation_node_gs, with_input_constructors_repeat_n=True), + ), ("GeneratorRun", get_generator_run), ("Hartmann6Metric", get_hartmann_metric), ("HierarchicalSearchSpace", get_hierarchical_search_space), diff --git a/ax/utils/testing/modeling_stubs.py b/ax/utils/testing/modeling_stubs.py index ecf1e5d091f..f5d1d15b25c 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -25,6 +25,11 @@ from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.modelbridge.factory import get_sobol from ax.modelbridge.generation_node import GenerationNode + +from ax.modelbridge.generation_node_input_constructors import ( + InputConstructorPurpose, + NodeInputConstructors, +) from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.model_spec import ModelSpec from ax.modelbridge.registry import Models @@ -215,6 +220,9 @@ def sobol_gpei_generation_node_gs( with_model_selection: bool = False, with_auto_transition: bool = False, with_previous_node: bool = False, + with_input_constructors_all_n: bool = False, + with_input_constructors_remaining_n: bool = False, + with_input_constructors_repeat_n: bool = False, ) -> GenerationStrategy: """Returns a basic SOBOL+MBM GS using GenerationNodes for testing. @@ -308,6 +316,22 @@ def sobol_gpei_generation_node_gs( if with_previous_node: mbm_node._previous_node_name = sobol_node.node_name + # test input constructors, this also leaves the mbm node with no input + # constructors which validates encoding/decoding of instances with no + # input constructors + if with_input_constructors_all_n: + sobol_node._input_constructors = { + InputConstructorPurpose.N: NodeInputConstructors.ALL_N, + } + elif with_input_constructors_remaining_n: + sobol_node._input_constructors = { + InputConstructorPurpose.N: NodeInputConstructors.REMAINING_N, + } + elif with_input_constructors_repeat_n: + sobol_node._input_constructors = { + InputConstructorPurpose.N: NodeInputConstructors.REPEAT_N, + } + sobol_mbm_GS_nodes = GenerationStrategy( name="Sobol+MBM_Nodes", nodes=[sobol_node, mbm_node],