Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add AuxiliaryExperimentCheck transition criterion #2778

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion ax/modelbridge/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from logging import Logger
from typing import Optional

from ax.core.auxiliary import AuxiliaryExperimentPurpose

from ax.core.base_trial import TrialStatus
from ax.core.experiment import Experiment
from ax.exceptions.core import DataRequiredError
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.exceptions.generation_strategy import MaxParallelismReachedException

from ax.utils.common.base import SortableBase
Expand Down Expand Up @@ -615,6 +617,98 @@ def block_continued_generation_error(
pass


class AuxiliaryExperimentCheck(TransitionCriterion):
"""A class to transition from one GenerationNode to another by checking if certain
type of Auxiliary Expeirment exists.

Args:
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to next.
auxiliary_experiment_purpose_to_include: auxiliary experiment purpose we
expect to have. Condition is met when both inclusion exclusion check pass.
auxiliary_experiment_purpose_to_exclude: auxiliary experiment purpose we
expect to not have. Condition is met when both inclusion exclusion check
pass.
block_transition_if_unmet: A flag to prevent the node from completing and
being able to transition to another node. Ex: This criterion defaults to
setting this to True to ensure we validate a GeneratorRun is generated by
the current GenerationNode.
complete_trial_generation: A flag to indicate that all generation for a given
trial is completed. This is necessary because in ``BatchTrial`` there
are multiple arms per trial, and we enable generation of arms within a
batch from different ``GenerationNodes``. This flag should be set to
True for the last node in a set of ``GenerationNodes`` expected to
create a given ``BatchTrial``.
"""

def __init__(
self,
transition_to: str,
auxiliary_experiment_purpose_to_include: Optional[
AuxiliaryExperimentPurpose
] = None,
auxiliary_experiment_purpose_to_exclude: Optional[
AuxiliaryExperimentPurpose
] = None,
block_transition_if_unmet: Optional[bool] = True,
continue_trial_generation: Optional[bool] = True,
) -> None:
super().__init__(
transition_to=transition_to,
block_transition_if_unmet=block_transition_if_unmet,
continue_trial_generation=continue_trial_generation,
)

if (
auxiliary_experiment_purpose_to_include is None
and auxiliary_experiment_purpose_to_exclude is None
):
raise UserInputError(
f"{self.__class__} cannot have both "
"`auxiliary_experiment_purpose_to_include` and "
"`auxiliary_experiment_purpose_to_exclude` be `None`."
)
self.auxiliary_experiment_purpose_to_include = (
auxiliary_experiment_purpose_to_include
)
self.auxiliary_experiment_purpose_to_exclude = (
auxiliary_experiment_purpose_to_exclude
)

def is_met(
self,
experiment: Experiment,
trials_from_node: Optional[set[int]] = None,
node_that_generated_last_gr: Optional[str] = None,
curr_node_name: Optional[str] = None,
) -> bool:
"""Check if the experiment has auxiliary experiments for certain purpose."""
check_pass = True

aux_exp_purposes = set(experiment.auxiliary_experiments_by_purpose.keys())
if self.auxiliary_experiment_purpose_to_include is not None:
check_pass = check_pass and (
self.auxiliary_experiment_purpose_to_include in aux_exp_purposes
)

if self.auxiliary_experiment_purpose_to_exclude is not None:
check_pass = check_pass and (
self.auxiliary_experiment_purpose_to_exclude not in aux_exp_purposes
)

return check_pass

def block_continued_generation_error(
self,
node_name: Optional[str],
model_name: Optional[str],
experiment: Optional[Experiment],
trials_from_node: Optional[set[int]] = None,
) -> None:
"""Error to be raised if the `block_gen_if_met` flag is set to True."""
pass


# TODO: Deprecate once legacy usecase is updated
class MinimumTrialsInStatus(TransitionCriterion):
"""
Expand Down
21 changes: 14 additions & 7 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
)
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import _decode_callables_from_references
from ax.modelbridge.transition_criterion import TransitionCriterion, TrialBasedCriterion
from ax.modelbridge.transition_criterion import (
AuxiliaryExperimentCheck,
TransitionCriterion,
TrialBasedCriterion,
)
from ax.models.torch.botorch_modular.model import SurrogateSpec
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.storage.json_store.decoders import (
Expand Down Expand Up @@ -164,7 +168,6 @@ def object_from_json(
# pyre-fixme[9, 24]: Generic type `type` expects 1 type parameter, use
# `typing.Type[<base type>]` to avoid runtime subscripting errors.
_class: type = decoder_registry[_type]

if isclass(_class) and issubclass(_class, Enum):
# to access enum members by name, use item access
return _class[object_json["name"]]
Expand Down Expand Up @@ -251,10 +254,14 @@ def object_from_json(
object_json["outcome_transform_options"] = (
outcome_transform_options_json
)
elif isclass(_class) and issubclass(_class, TrialBasedCriterion):
# TrialBasedCriterion contain a list of `TrialStatus` for args.
# This list needs to be unpacked by hand to properly retain the types.
return trial_transition_criteria_from_json(
elif isclass(_class) and (
issubclass(_class, TrialBasedCriterion)
or issubclass(_class, AuxiliaryExperimentCheck)
):
# TrialBasedCriterion contains a list of `TrialStatus` for args.
# AuxiliaryExperimentCheck contains AuxiliaryExperimentPurpose objects
# They need to be unpacked by hand to properly retain the types.
return unpack_transition_criteria_from_json(
class_=_class,
transition_criteria_json=object_json,
decoder_registry=decoder_registry,
Expand Down Expand Up @@ -349,7 +356,7 @@ def generator_run_from_json(
return generator_run


def trial_transition_criteria_from_json(
def unpack_transition_criteria_from_json(
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
# avoid runtime subscripting errors.
class_: type,
Expand Down
3 changes: 3 additions & 0 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transition_criterion import (
AutoTransitionAfterGen,
AuxiliaryExperimentCheck,
MaxGenerationParallelism,
MaxTrials,
MinimumPreferenceOccurances,
Expand Down Expand Up @@ -210,6 +211,7 @@
MinTrials: transition_criterion_to_dict,
MinimumTrialsInStatus: transition_criterion_to_dict,
MinimumPreferenceOccurances: transition_criterion_to_dict,
AuxiliaryExperimentCheck: transition_criterion_to_dict,
ModelSpec: model_spec_to_dict,
MultiObjective: multi_objective_to_dict,
MultiObjectiveOptimizationConfig: multi_objective_optimization_config_to_dict,
Expand Down Expand Up @@ -326,6 +328,7 @@
"MinTrials": MinTrials,
"MinimumTrialsInStatus": MinimumTrialsInStatus,
"MinimumPreferenceOccurances": MinimumPreferenceOccurances,
"AuxiliaryExperimentCheck": AuxiliaryExperimentCheck,
"Models": Models,
"ModelRegistryBase": ModelRegistryBase,
"ModelSpec": ModelSpec,
Expand Down