Skip to content

Commit

Permalink
add AuxiliaryExperimentCheck transition criterion
Browse files Browse the repository at this point in the history
Summary: add `AuxiliaryExperimentCheck` transition criterion

Differential Revision: D63035595
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed Sep 24, 2024
1 parent 549adf0 commit 9da922e
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 8 deletions.
96 changes: 95 additions & 1 deletion ax/modelbridge/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from logging import Logger
from typing import Optional

from ax.core.auxiliary import AuxiliaryExperimentPurpose

from ax.core.base_trial import TrialStatus
from ax.core.experiment import Experiment
from ax.exceptions.core import DataRequiredError
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.exceptions.generation_strategy import MaxParallelismReachedException

from ax.utils.common.base import SortableBase
Expand Down Expand Up @@ -615,6 +617,98 @@ def block_continued_generation_error(
pass


class AuxiliaryExperimentCheck(TransitionCriterion):
"""A class to transition from one GenerationNode to another by checking if certain
type of Auxiliary Expeirment exists.
Args:
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to next.
auxiliary_experiment_purpose_to_include: auxiliary experiment purpose we
expect to have. Condition is met when both inclusion exclusion check pass.
auxiliary_experiment_purpose_to_exclude: auxiliary experiment purpose we
expect to not have. Condition is met when both inclusion exclusion check
pass.
block_transition_if_unmet: A flag to prevent the node from completing and
being able to transition to another node. Ex: This criterion defaults to
setting this to True to ensure we validate a GeneratorRun is generated by
the current GenerationNode.
complete_trial_generation: A flag to indicate that all generation for a given
trial is completed. This is necessary because in ``BatchTrial`` there
are multiple arms per trial, and we enable generation of arms within a
batch from different ``GenerationNodes``. This flag should be set to
True for the last node in a set of ``GenerationNodes`` expected to
create a given ``BatchTrial``.
"""

def __init__(
self,
transition_to: str,
auxiliary_experiment_purpose_to_include: Optional[
AuxiliaryExperimentPurpose
] = None,
auxiliary_experiment_purpose_to_exclude: Optional[
AuxiliaryExperimentPurpose
] = None,
block_transition_if_unmet: Optional[bool] = True,
continue_trial_generation: Optional[bool] = True,
) -> None:
super().__init__(
transition_to=transition_to,
block_transition_if_unmet=block_transition_if_unmet,
continue_trial_generation=continue_trial_generation,
)

if (
auxiliary_experiment_purpose_to_include is None
and auxiliary_experiment_purpose_to_exclude is None
):
raise UserInputError(
f"{self.__class__} cannot have both "
"`auxiliary_experiment_purpose_to_include` and "
"`auxiliary_experiment_purpose_to_exclude` be `None`."
)
self.auxiliary_experiment_purpose_to_include = (
auxiliary_experiment_purpose_to_include
)
self.auxiliary_experiment_purpose_to_exclude = (
auxiliary_experiment_purpose_to_exclude
)

def is_met(
self,
experiment: Experiment,
trials_from_node: Optional[set[int]] = None,
node_that_generated_last_gr: Optional[str] = None,
curr_node_name: Optional[str] = None,
) -> bool:
"""Check if the experiment has auxiliary experiments for certain purpose."""
check_pass = True

aux_exp_purposes = set(experiment.auxiliary_experiments_by_purpose.keys())
if self.auxiliary_experiment_purpose_to_include is not None:
check_pass = check_pass and (
self.auxiliary_experiment_purpose_to_include in aux_exp_purposes
)

if self.auxiliary_experiment_purpose_to_exclude is not None:
check_pass = check_pass and (
self.auxiliary_experiment_purpose_to_exclude not in aux_exp_purposes
)

return check_pass

def block_continued_generation_error(
self,
node_name: Optional[str],
model_name: Optional[str],
experiment: Optional[Experiment],
trials_from_node: Optional[set[int]] = None,
) -> None:
"""Error to be raised if the `block_gen_if_met` flag is set to True."""
pass


# TODO: Deprecate once legacy usecase is updated
class MinimumTrialsInStatus(TransitionCriterion):
"""
Expand Down
21 changes: 14 additions & 7 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
)
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import _decode_callables_from_references
from ax.modelbridge.transition_criterion import TransitionCriterion, TrialBasedCriterion
from ax.modelbridge.transition_criterion import (
AuxiliaryExperimentCheck,
TransitionCriterion,
TrialBasedCriterion,
)
from ax.models.torch.botorch_modular.model import SurrogateSpec
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.storage.json_store.decoders import (
Expand Down Expand Up @@ -164,7 +168,6 @@ def object_from_json(
# pyre-fixme[9, 24]: Generic type `type` expects 1 type parameter, use
# `typing.Type[<base type>]` to avoid runtime subscripting errors.
_class: type = decoder_registry[_type]

if isclass(_class) and issubclass(_class, Enum):
# to access enum members by name, use item access
return _class[object_json["name"]]
Expand Down Expand Up @@ -251,10 +254,14 @@ def object_from_json(
object_json["outcome_transform_options"] = (
outcome_transform_options_json
)
elif isclass(_class) and issubclass(_class, TrialBasedCriterion):
# TrialBasedCriterion contain a list of `TrialStatus` for args.
# This list needs to be unpacked by hand to properly retain the types.
return trial_transition_criteria_from_json(
elif isclass(_class) and (
issubclass(_class, TrialBasedCriterion)
or issubclass(_class, AuxiliaryExperimentCheck)
):
# TrialBasedCriterion contains a list of `TrialStatus` for args.
# AuxiliaryExperimentCheck contains AuxiliaryExperimentPurpose objects
# They need to be unpacked by hand to properly retain the types.
return unpack_transition_criteria_from_json(
class_=_class,
transition_criteria_json=object_json,
decoder_registry=decoder_registry,
Expand Down Expand Up @@ -349,7 +356,7 @@ def generator_run_from_json(
return generator_run


def trial_transition_criteria_from_json(
def unpack_transition_criteria_from_json(
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
# avoid runtime subscripting errors.
class_: type,
Expand Down
3 changes: 3 additions & 0 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transition_criterion import (
AutoTransitionAfterGen,
AuxiliaryExperimentCheck,
MaxGenerationParallelism,
MaxTrials,
MinimumPreferenceOccurances,
Expand Down Expand Up @@ -210,6 +211,7 @@
MinTrials: transition_criterion_to_dict,
MinimumTrialsInStatus: transition_criterion_to_dict,
MinimumPreferenceOccurances: transition_criterion_to_dict,
AuxiliaryExperimentCheck: transition_criterion_to_dict,
ModelSpec: model_spec_to_dict,
MultiObjective: multi_objective_to_dict,
MultiObjectiveOptimizationConfig: multi_objective_optimization_config_to_dict,
Expand Down Expand Up @@ -326,6 +328,7 @@
"MinTrials": MinTrials,
"MinimumTrialsInStatus": MinimumTrialsInStatus,
"MinimumPreferenceOccurances": MinimumPreferenceOccurances,
"AuxiliaryExperimentCheck": AuxiliaryExperimentCheck,
"Models": Models,
"ModelRegistryBase": ModelRegistryBase,
"ModelSpec": ModelSpec,
Expand Down

0 comments on commit 9da922e

Please sign in to comment.