Skip to content

Commit

Permalink
Add input constructors as optional field on GenerationNode (#2754)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2754

This diff adds the input constructors property to GenerationNode and a simple test on the property.

Follow up diffs:
- add the two additional input constructors
- update the transition logic to leverage this
- storage --> let's do this once we all like the 3 input constructors
- update the input constructors to handle the case where n isn't provided as a kwarg
- add test that all cases of input constructors are handled per Liz's suggestion

Reviewed By: lena-kashtelyan

Differential Revision: D62310414

fbshipit-source-id: fa69a50b60eee70ef396aeaa4a3637ac632a4ec9
  • Loading branch information
mgarrard authored and facebook-github-bot committed Sep 24, 2024
1 parent 2a9cf98 commit 665d9dd
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
31 changes: 29 additions & 2 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from logging import Logger
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

# Module-level import to avoid circular dependency b/w this file and
# generation_strategy.py
Expand All @@ -29,6 +29,7 @@
from ax.exceptions.generation_strategy import GenerationStrategyRepeatedPoints
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.best_model_selector import BestModelSelector

from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase
from ax.modelbridge.transition_criterion import (
Expand Down Expand Up @@ -100,7 +101,12 @@ class GenerationNode(SerializationMixin, SortableBase):
_model_spec_to_gen_from: Optional[ModelSpec] = None
# TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping?
_transition_criteria: Optional[Sequence[TransitionCriterion]]

_input_constructors: Optional[
Dict[
modelbridge.generation_node_input_constructors.InputConstructorPurpose,
modelbridge.generation_node_input_constructors.NodeInputConstructors,
]
]
# [TODO] Handle experiment passing more eloquently by enforcing experiment
# attribute is set in generation strategies class
_generation_strategy: Optional[
Expand All @@ -114,6 +120,12 @@ def __init__(
best_model_selector: Optional[BestModelSelector] = None,
should_deduplicate: bool = False,
transition_criteria: Optional[Sequence[TransitionCriterion]] = None,
input_constructors: Optional[
Dict[
modelbridge.generation_node_input_constructors.InputConstructorPurpose,
modelbridge.generation_node_input_constructors.NodeInputConstructors,
]
] = None,
) -> None:
self._node_name = node_name
# Check that the model specs have unique model keys.
Expand All @@ -128,6 +140,7 @@ def __init__(
self.best_model_selector = best_model_selector
self.should_deduplicate = should_deduplicate
self._transition_criteria = transition_criteria
self._input_constructors = input_constructors

@property
def node_name(self) -> str:
Expand Down Expand Up @@ -174,6 +187,20 @@ def transition_criteria(self) -> Sequence[TransitionCriterion]:
"""
return [] if self._transition_criteria is None else self._transition_criteria

@property
def input_constructors(
self,
) -> Optional[
Dict[
modelbridge.generation_node_input_constructors.InputConstructorPurpose,
modelbridge.generation_node_input_constructors.NodeInputConstructors,
]
]:
"""Returns the input constructors that will be used to determine any dynamic
inputs to this ``GenerationNode``.
"""
return self._input_constructors

@property
def experiment(self) -> Experiment:
"""Returns the experiment associated with this GenerationStrategy"""
Expand Down
23 changes: 23 additions & 0 deletions ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
GenerationStep,
MISSING_MODEL_SELECTOR_MESSAGE,
)
from ax.modelbridge.generation_node_input_constructors import (
InputConstructorPurpose,
NodeInputConstructors,
)
from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import MaxTrials
Expand Down Expand Up @@ -77,6 +81,25 @@ def test_init(self) -> None:
self.assertEqual(node.model_specs, mbm_specs)
self.assertIs(node.best_model_selector, model_selector)

def test_input_constructor_none(self) -> None:
self.assertIsNone(self.sobol_generation_node.input_constructors)
self.assertIsNone(self.sobol_generation_node._input_constructors)

def test_input_constructor(self) -> None:
node = GenerationNode(
node_name="test",
model_specs=[self.sobol_model_spec],
input_constructors={InputConstructorPurpose.N: NodeInputConstructors.ALL_N},
)
self.assertEqual(
node.input_constructors,
{InputConstructorPurpose.N: NodeInputConstructors.ALL_N},
)
self.assertEqual(
node._input_constructors,
{InputConstructorPurpose.N: NodeInputConstructors.ALL_N},
)

def test_fit(self) -> None:
dat = self.branin_experiment.lookup_data()
with patch.object(
Expand Down

0 comments on commit 665d9dd

Please sign in to comment.