-
Notifications
You must be signed in to change notification settings - Fork 312
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add consume all n input constructor (#2750)
Summary: Pull Request resolved: #2750 This diff adds the first of 3 planned input constructors -- this one is the simplest one which returns the full n Follow up diffs: - add the two additional input constructors - update the generation node file to take this as input - update the transition logic to leverage this - storage --> let's do this once we all like the 3 input constructors - update the input constructors to handle the case where n isn't provided as a kwarg Reviewed By: lena-kashtelyan Differential Revision: D62247100 fbshipit-source-id: 52262391484c27931837e0ced79b2d2cfa29344b
- Loading branch information
1 parent
d738685
commit 2a9cf98
Showing
3 changed files
with
128 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
from enum import Enum, unique | ||
from typing import Any, Dict, Optional | ||
|
||
from ax.modelbridge.generation_node import GenerationNode | ||
|
||
|
||
def consume_all_n( | ||
previous_node: Optional[GenerationNode], | ||
next_node: GenerationNode, | ||
gs_gen_call_kwargs: Dict[str, Any], | ||
) -> int: | ||
"""Generate total requested number of arms from the next node. | ||
Example: Initial exploration with Sobol will generate all arms from a | ||
single sobol node. | ||
Args: | ||
previous_node: The previous node in the ``GenerationStrategy``. This is the node | ||
that is being transition away from, and is provided for easy access to | ||
properties of this node. | ||
next_node: The next node in the ``GenerationStrategy``. This is the node that | ||
will leverage the inputs defined by this input constructor. | ||
gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s | ||
gen call. | ||
Returns: | ||
The total number of requested arms from the next node. | ||
""" | ||
# TODO: @mgarrard handle case where n isn't specified | ||
if gs_gen_call_kwargs.get("n") is None: | ||
raise NotImplementedError( | ||
f"Currently `{consume_all_n.__name__}` only supports cases where n is " | ||
"specified" | ||
) | ||
return gs_gen_call_kwargs.get("n") | ||
|
||
|
||
@unique | ||
class NodeInputConstructors(Enum): | ||
"""An enum which maps to a callable method for constructing ``GenerationNode`` | ||
inputs. | ||
""" | ||
|
||
ALL_N = consume_all_n | ||
|
||
def __call__( | ||
self, | ||
previous_node: Optional[GenerationNode], | ||
next_node: GenerationNode, | ||
gs_gen_call_kwargs: Dict[str, Any], | ||
) -> int: | ||
"""Defines a callable method for the Enum as all values are methods""" | ||
return self( | ||
previous_node=previous_node, | ||
next_node=next_node, | ||
gs_gen_call_kwargs=gs_gen_call_kwargs, | ||
) | ||
|
||
|
||
@unique | ||
class InputConstructorPurpose(Enum): | ||
"""A simple enum to indicate the purpose of the input constructor. | ||
Explanation of the different purposes: | ||
N: Defines the logic to determine the number of arms to generate from the | ||
next ``GenerationNode`` given the total number of arms expected in | ||
this trial. | ||
""" | ||
|
||
N = "n" |
47 changes: 47 additions & 0 deletions
47
ax/modelbridge/tests/test_generation_node_input_constructors.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from ax.modelbridge.generation_node import GenerationNode | ||
from ax.modelbridge.generation_node_input_constructors import NodeInputConstructors | ||
from ax.modelbridge.model_spec import ModelSpec | ||
from ax.modelbridge.registry import Models | ||
from ax.utils.common.testutils import TestCase | ||
|
||
|
||
class TestGenerationNodeInputConstructors(TestCase): | ||
def setUp(self) -> None: | ||
super().setUp() | ||
self.sobol_model_spec = ModelSpec( | ||
model_enum=Models.SOBOL, | ||
model_kwargs={"init_position": 3}, | ||
model_gen_kwargs={"some_gen_kwarg": "some_value"}, | ||
) | ||
self.sobol_generation_node = GenerationNode( | ||
node_name="test", model_specs=[self.sobol_model_spec] | ||
) | ||
|
||
def test_consume_all_n_constructor(self) -> None: | ||
"""Test that the consume_all_n_constructor returns full n.""" | ||
num_to_gen = NodeInputConstructors.ALL_N( | ||
previous_node=None, | ||
next_node=self.sobol_generation_node, | ||
gs_gen_call_kwargs={"n": 5}, | ||
) | ||
|
||
self.assertEqual(num_to_gen, 5) | ||
|
||
def test_consume_all_n_constructor_no_n(self) -> None: | ||
"""Test raise error if n is not specified.""" | ||
with self.assertRaisesRegex( | ||
NotImplementedError, | ||
"`consume_all_n` only supports cases where n is specified", | ||
): | ||
_ = NodeInputConstructors.ALL_N( | ||
previous_node=None, | ||
next_node=self.sobol_generation_node, | ||
gs_gen_call_kwargs={}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters