Skip to content

Commit

Permalink
Add storage for input constructors (#2785)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2785

This diff adds storage support for the input constructors that we've implemented throughout this stack (diffs 11-19). As part of this update for storage we do a few things:
1. updated NodeInputConstructors enum to store a string and then modified the call method to use that string to link to the correct method. We do this to avoid some strange behavior related to stroing the function as the enum value directly, namely being it's not registring as a enum but instead only the function.
2. Added some additional tests

Reviewed By: lena-kashtelyan

Differential Revision: D62652950
  • Loading branch information
mgarrard authored and facebook-github-bot committed Sep 25, 2024
1 parent 8ba8ce3 commit 35fdaef
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 42 deletions.
4 changes: 3 additions & 1 deletion ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def __init__(
self.best_model_selector = best_model_selector
self.should_deduplicate = should_deduplicate
self._transition_criteria = transition_criteria
self._input_constructors = input_constructors
self._input_constructors = (
input_constructors if input_constructors is not None else {}
)
self._previous_node_name = previous_node_name

@property
Expand Down
88 changes: 49 additions & 39 deletions ax/modelbridge/generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,62 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
import sys
from enum import Enum, unique
from math import ceil
from typing import Any, Dict, Optional

from ax.modelbridge.generation_node import GenerationNode


@unique
class NodeInputConstructors(Enum):
"""An enum which maps to a the name of a callable method for constructing
``GenerationNode`` inputs.
NOTE: The methods defined by this enum should all share identical signatures
and reside in this file.
"""

ALL_N = "consume_all_n"
REPEAT_N = "repeat_arm_n"
REMAINING_N = "remaining_n"

def __call__(
self,
previous_node: Optional[GenerationNode],
next_node: GenerationNode,
gs_gen_call_kwargs: Dict[str, Any],
) -> int:
"""Defines a callable method for the Enum as all values are methods"""
try:
method = getattr(sys.modules[__name__], self.value)
except AttributeError:
raise ValueError(
f"{self.value} is not defined as a method in "
"``generation_node_input_constructors.py``. Please add the method "
"to the file."
)
return method(
previous_node=previous_node,
next_node=next_node,
gs_gen_call_kwargs=gs_gen_call_kwargs,
)


@unique
class InputConstructorPurpose(Enum):
"""A simple enum to indicate the purpose of the input constructor.
Explanation of the different purposes:
N: Defines the logic to determine the number of arms to generate from the
next ``GenerationNode`` given the total number of arms expected in
this trial.
"""

N = "n"


def consume_all_n(
previous_node: Optional[GenerationNode],
next_node: GenerationNode,
Expand Down Expand Up @@ -103,42 +152,3 @@ def remaining_n(
total_n = gs_gen_call_kwargs.get("n")
# if all arms have been generated, return 0
return max(total_n - sum(len(gr.arms) for gr in grs_this_gen), 0)


@unique
class NodeInputConstructors(Enum):
"""An enum which maps to a callable method for constructing ``GenerationNode``
inputs.
NOTE: The methods defined by this enum should all share identical signatures.
"""

ALL_N = consume_all_n
REPEAT_N = repeat_arm_n
REMAINING_N = remaining_n

def __call__(
self,
previous_node: Optional[GenerationNode],
next_node: GenerationNode,
gs_gen_call_kwargs: Dict[str, Any],
) -> int:
"""Defines a callable method for the Enum as all values are methods"""
return self(
previous_node=previous_node,
next_node=next_node,
gs_gen_call_kwargs=gs_gen_call_kwargs,
)


@unique
class InputConstructorPurpose(Enum):
"""A simple enum to indicate the purpose of the input constructor.
Explanation of the different purposes:
N: Defines the logic to determine the number of arms to generate from the
next ``GenerationNode`` given the total number of arms expected in
this trial.
"""

N = "n"
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_init(self) -> None:
self.assertIs(node.best_model_selector, model_selector)

def test_input_constructor_none(self) -> None:
self.assertIsNone(self.sobol_generation_node._input_constructors)
self.assertEqual(self.sobol_generation_node._input_constructors, {})
self.assertEqual(self.sobol_generation_node.input_constructors, {})

def test_input_constructor(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_all_constructors_have_same_signature(self) -> None:
method_signature = inspect.signature(all_constructors_tested[0])
for constructor in all_constructors_tested[1:]:
with self.subTest(constructor=constructor):
func_parameters = get_type_hints(constructor)
func_parameters = get_type_hints(constructor.__call__)
self.assertEqual(
Counter(list(func_parameters.keys())),
Counter(
Expand Down
9 changes: 9 additions & 0 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,15 @@ def generation_node_from_json(
if "transition_criteria" in generation_node_json.keys()
else None
),
input_constructors=(
object_from_json(
generation_node_json.pop("input_constructors"),
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
if "input_constructors" in generation_node_json.keys()
else None
),
previous_node_name=(
generation_node_json.pop("previous_node_name")
if "previous_node_name" in generation_node_json.keys()
Expand Down
1 change: 1 addition & 0 deletions ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def generation_node_to_dict(generation_node: GenerationNode) -> dict[str, Any]:
"transition_criteria": generation_node.transition_criteria,
"model_spec_to_gen_from": generation_node._model_spec_to_gen_from,
"previous_node_name": generation_node._previous_node_name,
"input_constructors": generation_node.input_constructors,
}


Expand Down
6 changes: 6 additions & 0 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@
)
from ax.modelbridge.factory import Models
from ax.modelbridge.generation_node import GenerationNode, GenerationStep
from ax.modelbridge.generation_node_input_constructors import (
InputConstructorPurpose,
NodeInputConstructors,
)
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import ModelRegistryBase
Expand Down Expand Up @@ -312,6 +316,7 @@
"Hartmann6Metric": Hartmann6Metric,
"HierarchicalSearchSpace": HierarchicalSearchSpace,
"ImprovementGlobalStoppingStrategy": ImprovementGlobalStoppingStrategy,
"InputConstructorPurpose": InputConstructorPurpose,
"Interval": Interval,
"LifecycleStage": LifecycleStage,
"ListSurrogate": Surrogate, # For backwards compatibility
Expand All @@ -334,6 +339,7 @@
"MultiObjectiveOptimizationConfig": MultiObjectiveOptimizationConfig,
"MultiTypeExperiment": MultiTypeExperiment,
"NegativeBraninMetric": NegativeBraninMetric,
"NodeInputConstructors": NodeInputConstructors,
"NoisyFunctionMetric": NoisyFunctionMetric,
"Normalize": Normalize,
"Objective": Objective,
Expand Down
14 changes: 14 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@
"GenerationStrategy",
partial(sobol_gpei_generation_node_gs, with_previous_node=True),
),
(
"GenerationStrategy",
partial(sobol_gpei_generation_node_gs, with_input_constructors_all_n=True),
),
(
"GenerationStrategy",
partial(
sobol_gpei_generation_node_gs, with_input_constructors_remaining_n=True
),
),
(
"GenerationStrategy",
partial(sobol_gpei_generation_node_gs, with_input_constructors_repeat_n=True),
),
("GeneratorRun", get_generator_run),
("Hartmann6Metric", get_hartmann_metric),
("HierarchicalSearchSpace", get_hierarchical_search_space),
Expand Down
24 changes: 24 additions & 0 deletions ax/utils/testing/modeling_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.generation_node import GenerationNode

from ax.modelbridge.generation_node_input_constructors import (
InputConstructorPurpose,
NodeInputConstructors,
)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
Expand Down Expand Up @@ -215,6 +220,9 @@ def sobol_gpei_generation_node_gs(
with_model_selection: bool = False,
with_auto_transition: bool = False,
with_previous_node: bool = False,
with_input_constructors_all_n: bool = False,
with_input_constructors_remaining_n: bool = False,
with_input_constructors_repeat_n: bool = False,
) -> GenerationStrategy:
"""Returns a basic SOBOL+MBM GS using GenerationNodes for testing.
Expand Down Expand Up @@ -308,6 +316,22 @@ def sobol_gpei_generation_node_gs(
if with_previous_node:
mbm_node._previous_node_name = sobol_node.node_name

# test input constructors, this also leaves the mbm node with no input
# constructors which validates encoding/decoding of instances with no
# input constructors
if with_input_constructors_all_n:
sobol_node._input_constructors = {
InputConstructorPurpose.N: NodeInputConstructors.ALL_N,
}
elif with_input_constructors_remaining_n:
sobol_node._input_constructors = {
InputConstructorPurpose.N: NodeInputConstructors.REMAINING_N,
}
elif with_input_constructors_repeat_n:
sobol_node._input_constructors = {
InputConstructorPurpose.N: NodeInputConstructors.REPEAT_N,
}

sobol_mbm_GS_nodes = GenerationStrategy(
name="Sobol+MBM_Nodes",
nodes=[sobol_node, mbm_node],
Expand Down

0 comments on commit 35fdaef

Please sign in to comment.