From 9da922e716665957b157e044bb19c8787b4a948a Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Mon, 23 Sep 2024 23:04:32 -0700 Subject: [PATCH] add `AuxiliaryExperimentCheck` transition criterion Summary: add `AuxiliaryExperimentCheck` transition criterion Differential Revision: D63035595 --- ax/modelbridge/transition_criterion.py | 96 +++++++++++++++++++++++++- ax/storage/json_store/decoder.py | 21 ++++-- ax/storage/json_store/registry.py | 3 + 3 files changed, 112 insertions(+), 8 deletions(-) diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index d8f5248dc97..d3a4f9f8264 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -9,9 +9,11 @@ from logging import Logger from typing import Optional +from ax.core.auxiliary import AuxiliaryExperimentPurpose + from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment -from ax.exceptions.core import DataRequiredError +from ax.exceptions.core import DataRequiredError, UserInputError from ax.exceptions.generation_strategy import MaxParallelismReachedException from ax.utils.common.base import SortableBase @@ -615,6 +617,98 @@ def block_continued_generation_error( pass +class AuxiliaryExperimentCheck(TransitionCriterion): + """A class to transition from one GenerationNode to another by checking if certain + type of Auxiliary Expeirment exists. + + Args: + transition_to: The name of the GenerationNode the GenerationStrategy should + transition to next. + auxiliary_experiment_purpose_to_include: auxiliary experiment purpose we + expect to have. Condition is met when both inclusion exclusion check pass. + auxiliary_experiment_purpose_to_exclude: auxiliary experiment purpose we + expect to not have. Condition is met when both inclusion exclusion check + pass. + block_transition_if_unmet: A flag to prevent the node from completing and + being able to transition to another node. Ex: This criterion defaults to + setting this to True to ensure we validate a GeneratorRun is generated by + the current GenerationNode. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ``BatchTrial`` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ``GenerationNodes``. This flag should be set to + True for the last node in a set of ``GenerationNodes`` expected to + create a given ``BatchTrial``. + """ + + def __init__( + self, + transition_to: str, + auxiliary_experiment_purpose_to_include: Optional[ + AuxiliaryExperimentPurpose + ] = None, + auxiliary_experiment_purpose_to_exclude: Optional[ + AuxiliaryExperimentPurpose + ] = None, + block_transition_if_unmet: Optional[bool] = True, + continue_trial_generation: Optional[bool] = True, + ) -> None: + super().__init__( + transition_to=transition_to, + block_transition_if_unmet=block_transition_if_unmet, + continue_trial_generation=continue_trial_generation, + ) + + if ( + auxiliary_experiment_purpose_to_include is None + and auxiliary_experiment_purpose_to_exclude is None + ): + raise UserInputError( + f"{self.__class__} cannot have both " + "`auxiliary_experiment_purpose_to_include` and " + "`auxiliary_experiment_purpose_to_exclude` be `None`." + ) + self.auxiliary_experiment_purpose_to_include = ( + auxiliary_experiment_purpose_to_include + ) + self.auxiliary_experiment_purpose_to_exclude = ( + auxiliary_experiment_purpose_to_exclude + ) + + def is_met( + self, + experiment: Experiment, + trials_from_node: Optional[set[int]] = None, + node_that_generated_last_gr: Optional[str] = None, + curr_node_name: Optional[str] = None, + ) -> bool: + """Check if the experiment has auxiliary experiments for certain purpose.""" + check_pass = True + + aux_exp_purposes = set(experiment.auxiliary_experiments_by_purpose.keys()) + if self.auxiliary_experiment_purpose_to_include is not None: + check_pass = check_pass and ( + self.auxiliary_experiment_purpose_to_include in aux_exp_purposes + ) + + if self.auxiliary_experiment_purpose_to_exclude is not None: + check_pass = check_pass and ( + self.auxiliary_experiment_purpose_to_exclude not in aux_exp_purposes + ) + + return check_pass + + def block_continued_generation_error( + self, + node_name: Optional[str], + model_name: Optional[str], + experiment: Optional[Experiment], + trials_from_node: Optional[set[int]] = None, + ) -> None: + """Error to be raised if the `block_gen_if_met` flag is set to True.""" + pass + + # TODO: Deprecate once legacy usecase is updated class MinimumTrialsInStatus(TransitionCriterion): """ diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 79df2f3dd7f..afc78d26da0 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -38,7 +38,11 @@ ) from ax.modelbridge.model_spec import ModelSpec from ax.modelbridge.registry import _decode_callables_from_references -from ax.modelbridge.transition_criterion import TransitionCriterion, TrialBasedCriterion +from ax.modelbridge.transition_criterion import ( + AuxiliaryExperimentCheck, + TransitionCriterion, + TrialBasedCriterion, +) from ax.models.torch.botorch_modular.model import SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.storage.json_store.decoders import ( @@ -164,7 +168,6 @@ def object_from_json( # pyre-fixme[9, 24]: Generic type `type` expects 1 type parameter, use # `typing.Type[]` to avoid runtime subscripting errors. _class: type = decoder_registry[_type] - if isclass(_class) and issubclass(_class, Enum): # to access enum members by name, use item access return _class[object_json["name"]] @@ -251,10 +254,14 @@ def object_from_json( object_json["outcome_transform_options"] = ( outcome_transform_options_json ) - elif isclass(_class) and issubclass(_class, TrialBasedCriterion): - # TrialBasedCriterion contain a list of `TrialStatus` for args. - # This list needs to be unpacked by hand to properly retain the types. - return trial_transition_criteria_from_json( + elif isclass(_class) and ( + issubclass(_class, TrialBasedCriterion) + or issubclass(_class, AuxiliaryExperimentCheck) + ): + # TrialBasedCriterion contains a list of `TrialStatus` for args. + # AuxiliaryExperimentCheck contains AuxiliaryExperimentPurpose objects + # They need to be unpacked by hand to properly retain the types. + return unpack_transition_criteria_from_json( class_=_class, transition_criteria_json=object_json, decoder_registry=decoder_registry, @@ -349,7 +356,7 @@ def generator_run_from_json( return generator_run -def trial_transition_criteria_from_json( +def unpack_transition_criteria_from_json( # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to # avoid runtime subscripting errors. class_: type, diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index be2d164a3b1..4665810a669 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -88,6 +88,7 @@ from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transition_criterion import ( AutoTransitionAfterGen, + AuxiliaryExperimentCheck, MaxGenerationParallelism, MaxTrials, MinimumPreferenceOccurances, @@ -210,6 +211,7 @@ MinTrials: transition_criterion_to_dict, MinimumTrialsInStatus: transition_criterion_to_dict, MinimumPreferenceOccurances: transition_criterion_to_dict, + AuxiliaryExperimentCheck: transition_criterion_to_dict, ModelSpec: model_spec_to_dict, MultiObjective: multi_objective_to_dict, MultiObjectiveOptimizationConfig: multi_objective_optimization_config_to_dict, @@ -326,6 +328,7 @@ "MinTrials": MinTrials, "MinimumTrialsInStatus": MinimumTrialsInStatus, "MinimumPreferenceOccurances": MinimumPreferenceOccurances, + "AuxiliaryExperimentCheck": AuxiliaryExperimentCheck, "Models": Models, "ModelRegistryBase": ModelRegistryBase, "ModelSpec": ModelSpec,