Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add storage for input constructors #2785

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def __init__(
self.best_model_selector = best_model_selector
self.should_deduplicate = should_deduplicate
self._transition_criteria = transition_criteria
self._input_constructors = input_constructors
self._input_constructors = (
input_constructors if input_constructors is not None else {}
)
self._previous_node_name = previous_node_name

@property
Expand Down
88 changes: 49 additions & 39 deletions ax/modelbridge/generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,62 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
import sys
from enum import Enum, unique
from math import ceil
from typing import Any, Dict, Optional

from ax.modelbridge.generation_node import GenerationNode


@unique
class NodeInputConstructors(Enum):
"""An enum which maps to a the name of a callable method for constructing
``GenerationNode`` inputs.

NOTE: The methods defined by this enum should all share identical signatures
and reside in this file.
"""

ALL_N = "consume_all_n"
REPEAT_N = "repeat_arm_n"
REMAINING_N = "remaining_n"

def __call__(
self,
previous_node: Optional[GenerationNode],
next_node: GenerationNode,
gs_gen_call_kwargs: Dict[str, Any],
) -> int:
"""Defines a callable method for the Enum as all values are methods"""
try:
method = getattr(sys.modules[__name__], self.value)
except AttributeError:
raise ValueError(
f"{self.value} is not defined as a method in "
"``generation_node_input_constructors.py``. Please add the method "
"to the file."
)
return method(
previous_node=previous_node,
next_node=next_node,
gs_gen_call_kwargs=gs_gen_call_kwargs,
)


@unique
class InputConstructorPurpose(Enum):
"""A simple enum to indicate the purpose of the input constructor.

Explanation of the different purposes:
N: Defines the logic to determine the number of arms to generate from the
next ``GenerationNode`` given the total number of arms expected in
this trial.
"""

N = "n"


def consume_all_n(
previous_node: Optional[GenerationNode],
next_node: GenerationNode,
Expand Down Expand Up @@ -103,42 +152,3 @@ def remaining_n(
total_n = gs_gen_call_kwargs.get("n")
# if all arms have been generated, return 0
return max(total_n - sum(len(gr.arms) for gr in grs_this_gen), 0)


@unique
class NodeInputConstructors(Enum):
"""An enum which maps to a callable method for constructing ``GenerationNode``
inputs.

NOTE: The methods defined by this enum should all share identical signatures.
"""

ALL_N = consume_all_n
REPEAT_N = repeat_arm_n
REMAINING_N = remaining_n

def __call__(
self,
previous_node: Optional[GenerationNode],
next_node: GenerationNode,
gs_gen_call_kwargs: Dict[str, Any],
) -> int:
"""Defines a callable method for the Enum as all values are methods"""
return self(
previous_node=previous_node,
next_node=next_node,
gs_gen_call_kwargs=gs_gen_call_kwargs,
)


@unique
class InputConstructorPurpose(Enum):
"""A simple enum to indicate the purpose of the input constructor.

Explanation of the different purposes:
N: Defines the logic to determine the number of arms to generate from the
next ``GenerationNode`` given the total number of arms expected in
this trial.
"""

N = "n"
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_init(self) -> None:
self.assertIs(node.best_model_selector, model_selector)

def test_input_constructor_none(self) -> None:
self.assertIsNone(self.sobol_generation_node._input_constructors)
self.assertEqual(self.sobol_generation_node._input_constructors, {})
self.assertEqual(self.sobol_generation_node.input_constructors, {})

def test_input_constructor(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_all_constructors_have_same_signature(self) -> None:
method_signature = inspect.signature(all_constructors_tested[0])
for constructor in all_constructors_tested[1:]:
with self.subTest(constructor=constructor):
func_parameters = get_type_hints(constructor)
func_parameters = get_type_hints(constructor.__call__)
self.assertEqual(
Counter(list(func_parameters.keys())),
Counter(
Expand Down
9 changes: 9 additions & 0 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,15 @@ def generation_node_from_json(
if "transition_criteria" in generation_node_json.keys()
else None
),
input_constructors=(
object_from_json(
generation_node_json.pop("input_constructors"),
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
if "input_constructors" in generation_node_json.keys()
else None
),
previous_node_name=(
generation_node_json.pop("previous_node_name")
if "previous_node_name" in generation_node_json.keys()
Expand Down
1 change: 1 addition & 0 deletions ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def generation_node_to_dict(generation_node: GenerationNode) -> dict[str, Any]:
"transition_criteria": generation_node.transition_criteria,
"model_spec_to_gen_from": generation_node._model_spec_to_gen_from,
"previous_node_name": generation_node._previous_node_name,
"input_constructors": generation_node.input_constructors,
}


Expand Down
6 changes: 6 additions & 0 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@
)
from ax.modelbridge.factory import Models
from ax.modelbridge.generation_node import GenerationNode, GenerationStep
from ax.modelbridge.generation_node_input_constructors import (
InputConstructorPurpose,
NodeInputConstructors,
)
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import ModelRegistryBase
Expand Down Expand Up @@ -312,6 +316,7 @@
"Hartmann6Metric": Hartmann6Metric,
"HierarchicalSearchSpace": HierarchicalSearchSpace,
"ImprovementGlobalStoppingStrategy": ImprovementGlobalStoppingStrategy,
"InputConstructorPurpose": InputConstructorPurpose,
"Interval": Interval,
"LifecycleStage": LifecycleStage,
"ListSurrogate": Surrogate, # For backwards compatibility
Expand All @@ -334,6 +339,7 @@
"MultiObjectiveOptimizationConfig": MultiObjectiveOptimizationConfig,
"MultiTypeExperiment": MultiTypeExperiment,
"NegativeBraninMetric": NegativeBraninMetric,
"NodeInputConstructors": NodeInputConstructors,
"NoisyFunctionMetric": NoisyFunctionMetric,
"Normalize": Normalize,
"Objective": Objective,
Expand Down
14 changes: 14 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@
"GenerationStrategy",
partial(sobol_gpei_generation_node_gs, with_previous_node=True),
),
(
"GenerationStrategy",
partial(sobol_gpei_generation_node_gs, with_input_constructors_all_n=True),
),
(
"GenerationStrategy",
partial(
sobol_gpei_generation_node_gs, with_input_constructors_remaining_n=True
),
),
(
"GenerationStrategy",
partial(sobol_gpei_generation_node_gs, with_input_constructors_repeat_n=True),
),
("GeneratorRun", get_generator_run),
("Hartmann6Metric", get_hartmann_metric),
("HierarchicalSearchSpace", get_hierarchical_search_space),
Expand Down
24 changes: 24 additions & 0 deletions ax/utils/testing/modeling_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.generation_node import GenerationNode

from ax.modelbridge.generation_node_input_constructors import (
InputConstructorPurpose,
NodeInputConstructors,
)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
Expand Down Expand Up @@ -215,6 +220,9 @@ def sobol_gpei_generation_node_gs(
with_model_selection: bool = False,
with_auto_transition: bool = False,
with_previous_node: bool = False,
with_input_constructors_all_n: bool = False,
with_input_constructors_remaining_n: bool = False,
with_input_constructors_repeat_n: bool = False,
) -> GenerationStrategy:
"""Returns a basic SOBOL+MBM GS using GenerationNodes for testing.

Expand Down Expand Up @@ -308,6 +316,22 @@ def sobol_gpei_generation_node_gs(
if with_previous_node:
mbm_node._previous_node_name = sobol_node.node_name

# test input constructors, this also leaves the mbm node with no input
# constructors which validates encoding/decoding of instances with no
# input constructors
if with_input_constructors_all_n:
sobol_node._input_constructors = {
InputConstructorPurpose.N: NodeInputConstructors.ALL_N,
}
elif with_input_constructors_remaining_n:
sobol_node._input_constructors = {
InputConstructorPurpose.N: NodeInputConstructors.REMAINING_N,
}
elif with_input_constructors_repeat_n:
sobol_node._input_constructors = {
InputConstructorPurpose.N: NodeInputConstructors.REPEAT_N,
}

sobol_mbm_GS_nodes = GenerationStrategy(
name="Sobol+MBM_Nodes",
nodes=[sobol_node, mbm_node],
Expand Down