Skip to content

Commit

Permalink
Add consume all n input constructor (#2750)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2750

This diff adds the first of 3 planned input constructors -- this one is the simplest one which returns the full n

Follow up diffs:
- add the two additional input constructors
- update the generation node file to take this as input
- update the transition logic to leverage this
- storage --> let's do this once we all like the 3 input constructors
- update the input constructors to handle the case where n isn't provided as a kwarg

Reviewed By: lena-kashtelyan

Differential Revision: D62247100

fbshipit-source-id: 52262391484c27931837e0ced79b2d2cfa29344b
  • Loading branch information
mgarrard authored and facebook-github-bot committed Sep 24, 2024
1 parent d738685 commit 2a9cf98
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 0 deletions.
75 changes: 75 additions & 0 deletions ax/modelbridge/generation_node_input_constructors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from enum import Enum, unique
from typing import Any, Dict, Optional

from ax.modelbridge.generation_node import GenerationNode


def consume_all_n(
previous_node: Optional[GenerationNode],
next_node: GenerationNode,
gs_gen_call_kwargs: Dict[str, Any],
) -> int:
"""Generate total requested number of arms from the next node.
Example: Initial exploration with Sobol will generate all arms from a
single sobol node.
Args:
previous_node: The previous node in the ``GenerationStrategy``. This is the node
that is being transition away from, and is provided for easy access to
properties of this node.
next_node: The next node in the ``GenerationStrategy``. This is the node that
will leverage the inputs defined by this input constructor.
gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s
gen call.
Returns:
The total number of requested arms from the next node.
"""
# TODO: @mgarrard handle case where n isn't specified
if gs_gen_call_kwargs.get("n") is None:
raise NotImplementedError(
f"Currently `{consume_all_n.__name__}` only supports cases where n is "
"specified"
)
return gs_gen_call_kwargs.get("n")


@unique
class NodeInputConstructors(Enum):
"""An enum which maps to a callable method for constructing ``GenerationNode``
inputs.
"""

ALL_N = consume_all_n

def __call__(
self,
previous_node: Optional[GenerationNode],
next_node: GenerationNode,
gs_gen_call_kwargs: Dict[str, Any],
) -> int:
"""Defines a callable method for the Enum as all values are methods"""
return self(
previous_node=previous_node,
next_node=next_node,
gs_gen_call_kwargs=gs_gen_call_kwargs,
)


@unique
class InputConstructorPurpose(Enum):
"""A simple enum to indicate the purpose of the input constructor.
Explanation of the different purposes:
N: Defines the logic to determine the number of arms to generate from the
next ``GenerationNode`` given the total number of arms expected in
this trial.
"""

N = "n"
47 changes: 47 additions & 0 deletions ax/modelbridge/tests/test_generation_node_input_constructors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_node_input_constructors import NodeInputConstructors
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase


class TestGenerationNodeInputConstructors(TestCase):
def setUp(self) -> None:
super().setUp()
self.sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
model_kwargs={"init_position": 3},
model_gen_kwargs={"some_gen_kwarg": "some_value"},
)
self.sobol_generation_node = GenerationNode(
node_name="test", model_specs=[self.sobol_model_spec]
)

def test_consume_all_n_constructor(self) -> None:
"""Test that the consume_all_n_constructor returns full n."""
num_to_gen = NodeInputConstructors.ALL_N(
previous_node=None,
next_node=self.sobol_generation_node,
gs_gen_call_kwargs={"n": 5},
)

self.assertEqual(num_to_gen, 5)

def test_consume_all_n_constructor_no_n(self) -> None:
"""Test raise error if n is not specified."""
with self.assertRaisesRegex(
NotImplementedError,
"`consume_all_n` only supports cases where n is specified",
):
_ = NodeInputConstructors.ALL_N(
previous_node=None,
next_node=self.sobol_generation_node,
gs_gen_call_kwargs={},
)
6 changes: 6 additions & 0 deletions sphinx/source/modelbridge.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ Transition Criterion
:undoc-members:
:show-inheritance:

Generation Node Input Constructors
.. automodule:: ax.modelbridge.generation_node_input_constructors
:members:
:undoc-members:
:show-inheritance:

Registry
~~~~~~~~
.. automodule:: ax.modelbridge.registry
Expand Down

0 comments on commit 2a9cf98

Please sign in to comment.