Skip to content

Commit

Permalink
Adds test that all constructors and purposes are tested and that they…
Browse files Browse the repository at this point in the history
… all share the same signature (#2764)

Summary:
Pull Request resolved: #2764

This is a requested follow up by Liz/Sait/Daniel to (1) ensure all constructor methods share the same signature and (2) ensure that they are all tested.

For two, this still relies on best effort to ensure that when you add a new constructor you actually add the relevant tests, but it will ensure that the new constructor has the correct signature.

These input constructors are more simple than those used in the modeling layer so i do believe they should be less error prone generally

Follow up diffs:
- storage for input constructors --> let's do this once we all like the 3 input constructors
- update the input constructors to handle the case where n isn't provided as a kwarg
- decide if it's more legible for the circular import avoider to be on inputconstructors instead of the generationnode file

Reviewed By: lena-kashtelyan

Differential Revision: D62553827

fbshipit-source-id: d1eddb4038e96f39fdc98dfa1a73ef0c29e528a4
  • Loading branch information
mgarrard authored and facebook-github-bot committed Sep 24, 2024
1 parent 64d2e00 commit 549adf0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
2 changes: 2 additions & 0 deletions ax/modelbridge/generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def remaining_n(
class NodeInputConstructors(Enum):
"""An enum which maps to a callable method for constructing ``GenerationNode``
inputs.
NOTE: The methods defined by this enum should all share identical signatures.
"""

ALL_N = consume_all_n
Expand Down
46 changes: 45 additions & 1 deletion ax/modelbridge/tests/test_generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@

# pyre-strict

import inspect
from collections import Counter
from typing import Any, Dict, get_type_hints, Optional

from ax.core.arm import Arm
from ax.core.generator_run import GeneratorRun
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_node_input_constructors import NodeInputConstructors
from ax.modelbridge.generation_node_input_constructors import (
InputConstructorPurpose,
NodeInputConstructors,
)
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -126,3 +133,40 @@ def test_no_n_provided_error_remaining_n(self) -> None:
next_node=self.sobol_generation_node,
gs_gen_call_kwargs={},
)


class TestInstantiationFromNodeInputConstructor(TestCase):
"""Class to test that all node input constructors can be instantiated and are
being tested."""

def setUp(self) -> None:
super().setUp()
self.constructor_cases = {
"ALl_N": NodeInputConstructors.ALL_N,
"REPEAT_N": NodeInputConstructors.REPEAT_N,
"REMAINING_N": NodeInputConstructors.REMAINING_N,
}
self.purpose_cases = {
"N": InputConstructorPurpose.N,
}

def test_all_constructors_have_same_signature(self) -> None:
"""Test that all node input constructors methods have the same signature
and that the parameters are of the expected types"""
all_constructors_tested = list(self.constructor_cases.values())
method_signature = inspect.signature(all_constructors_tested[0])
for constructor in all_constructors_tested[1:]:
with self.subTest(constructor=constructor):
func_parameters = get_type_hints(constructor)
self.assertEqual(
Counter(list(func_parameters.keys())),
Counter(
["previous_node", "next_node", "gs_gen_call_kwargs", "return"]
),
)
self.assertEqual(
func_parameters["previous_node"], Optional[GenerationNode]
)
self.assertEqual(func_parameters["next_node"], GenerationNode)
self.assertEqual(func_parameters["gs_gen_call_kwargs"], Dict[str, Any])
self.assertEqual(method_signature, inspect.signature(constructor))

0 comments on commit 549adf0

Please sign in to comment.