Skip to content

Commit

Permalink
Fix storage for instances where a generation node has no TC (#2784)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2784

During dev of storage for input constructors I realized this may be a problem for TC as well. This diff ensures that nodes with no TC are properly stored and able to be reloaded.

Reviewed By: saitcakmak

Differential Revision: D63348895
  • Loading branch information
mgarrard authored and facebook-github-bot committed Sep 25, 2024
1 parent 6762e36 commit 2e15052
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
14 changes: 7 additions & 7 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,10 @@ class GenerationNode(SerializationMixin, SortableBase):
# Optional specifications
_model_spec_to_gen_from: Optional[ModelSpec] = None
# TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping?
_transition_criteria: Optional[Sequence[TransitionCriterion]]
_input_constructors: Optional[
Dict[
modelbridge.generation_node_input_constructors.InputConstructorPurpose,
modelbridge.generation_node_input_constructors.NodeInputConstructors,
]
_transition_criteria: Sequence[TransitionCriterion]
_input_constructors: Dict[
modelbridge.generation_node_input_constructors.InputConstructorPurpose,
modelbridge.generation_node_input_constructors.NodeInputConstructors,
]
_previous_node_name: Optional[str] = None

Expand Down Expand Up @@ -150,7 +148,9 @@ def __init__(
self.model_specs = model_specs
self.best_model_selector = best_model_selector
self.should_deduplicate = should_deduplicate
self._transition_criteria = transition_criteria
self._transition_criteria = (
transition_criteria if transition_criteria is not None else []
)
self._input_constructors = (
input_constructors if input_constructors is not None else {}
)
Expand Down
4 changes: 4 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@
"GenerationStrategy",
partial(sobol_gpei_generation_node_gs, with_input_constructors_repeat_n=True),
),
(
"GenerationStrategy",
partial(sobol_gpei_generation_node_gs, with_unlimited_gen_mbm=True),
),
("GeneratorRun", get_generator_run),
("Hartmann6Metric", get_hartmann_metric),
("HierarchicalSearchSpace", get_hierarchical_search_space),
Expand Down
8 changes: 8 additions & 0 deletions ax/utils/testing/modeling_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def sobol_gpei_generation_node_gs(
with_input_constructors_all_n: bool = False,
with_input_constructors_remaining_n: bool = False,
with_input_constructors_repeat_n: bool = False,
with_unlimited_gen_mbm: bool = False,
) -> GenerationStrategy:
"""Returns a basic SOBOL+MBM GS using GenerationNodes for testing.
Expand Down Expand Up @@ -303,6 +304,13 @@ def sobol_gpei_generation_node_gs(
model_specs=mbm_model_specs,
best_model_selector=best_model_selector,
)
elif with_unlimited_gen_mbm:
# no TC defined is equivalent to unlimited gen
mbm_node = GenerationNode(
node_name="MBM_node",
model_specs=mbm_model_specs,
best_model_selector=best_model_selector,
)
else:
mbm_node = GenerationNode(
node_name="MBM_node",
Expand Down

0 comments on commit 2e15052

Please sign in to comment.