From 54359fdf6231ea6425d96ca3870f6559607a0217 Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Thu, 30 Nov 2023 13:56:50 -0800 Subject: [PATCH] Accept any GSInterface for database updates Summary: In preparation for D51307866, accept and handle any tpye of GSInterface in save and update methods. Only standard GSs will be saved though Reviewed By: lena-kashtelyan Differential Revision: D51677487 --- ax/service/ax_client.py | 3 +- .../tests/test_with_db_settings_base.py | 23 ++++++- ax/service/utils/with_db_settings_base.py | 61 +++++++++++-------- ax/utils/testing/core_stubs.py | 20 ++++++ 4 files changed, 81 insertions(+), 26 deletions(-) diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 8bda0cfd095..3caee6a4625 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -32,6 +32,7 @@ from ax.core.arm import Arm from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.experiment import DataType, Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.map_data import MapData from ax.core.map_metric import MapMetric @@ -1730,7 +1731,7 @@ def _set_generation_strategy( def _save_generation_strategy_to_db_if_possible( self, - generation_strategy: Optional[GenerationStrategy] = None, + generation_strategy: Optional[GenerationStrategyInterface] = None, suppress_all_errors: bool = False, ) -> bool: return super()._save_generation_strategy_to_db_if_possible( diff --git a/ax/service/tests/test_with_db_settings_base.py b/ax/service/tests/test_with_db_settings_base.py index c9f0d530733..46d026ff61d 100644 --- a/ax/service/tests/test_with_db_settings_base.py +++ b/ax/service/tests/test_with_db_settings_base.py @@ -26,7 +26,11 @@ ) from ax.storage.sqa_store.structs import DBSettings from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import get_experiment, get_generator_run +from ax.utils.testing.core_stubs import ( + get_experiment, + get_generator_run, + SpecialGenerationStrategy, +) from ax.utils.testing.modeling_stubs import get_generation_strategy @@ -127,6 +131,13 @@ def test_save_generation_strategy(self) -> None: self.assertIsNotNone(loaded_gs) self.assertEqual(loaded_gs.name, generation_strategy.name) + def test_save_non_standard_generation_strategy(self) -> None: + generation_strategy = SpecialGenerationStrategy() + saved = self.with_db_settings._save_generation_strategy_to_db_if_possible( + generation_strategy + ) + self.assertFalse(saved) + def test_save_load_experiment_and_generation_strategy(self) -> None: experiment, generation_strategy = self.init_experiment_and_generation_strategy( save_generation_strategy=False @@ -169,6 +180,16 @@ def test_update_generation_strategy(self) -> None: self.assertIsNotNone(generator_run.db_id) self.assertIsNotNone(generator_run.arms[0].db_id) + def test_update_non_standard_generation_strategy(self) -> None: + generation_strategy = SpecialGenerationStrategy() + generator_run = get_generator_run() + saved = self.with_db_settings._update_generation_strategy_in_db_if_possible( + generation_strategy, [generator_run] + ) + self.assertFalse(saved) + self.assertIsNone(generator_run.db_id) + self.assertIsNone(generator_run.arms[0].db_id) + @patch(f"{WithDBSettingsBase.__module__}.STORAGE_MINI_BATCH_SIZE", 2) def test_update_generation_strategy_mini_batches(self) -> None: _, generation_strategy = self.init_experiment_and_generation_strategy() diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index c4c55359b47..dd8c0405c37 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -12,6 +12,7 @@ from ax.core.base_trial import BaseTrial from ax.core.experiment import Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.exceptions.core import ( IncompatibleDependencyVersion, @@ -155,7 +156,7 @@ def _get_experiment_and_generation_strategy_db_id( return exp_id, gs_id def _maybe_save_experiment_and_generation_strategy( - self, experiment: Experiment, generation_strategy: GenerationStrategy + self, experiment: Experiment, generation_strategy: GenerationStrategyInterface ) -> Tuple[bool, bool]: """If DB settings are set on this `WithDBSettingsBase` instance, checks whether given experiment and generation strategy are already saved and @@ -304,7 +305,7 @@ def _save_or_update_trials_and_generation_strategy_if_possible( self, experiment: Experiment, trials: List[BaseTrial], - generation_strategy: GenerationStrategy, + generation_strategy: GenerationStrategyInterface, new_generator_runs: List[GeneratorRun], reduce_state_generator_runs: bool = False, ) -> None: @@ -386,40 +387,49 @@ def _save_or_update_trials_in_db_if_possible( def _save_generation_strategy_to_db_if_possible( self, - generation_strategy: Optional[GenerationStrategy] = None, + generation_strategy: Optional[GenerationStrategyInterface] = None, suppress_all_errors: bool = False, ) -> bool: """Saves given generation strategy if DB settings are set on this - `WithDBSettingsBase` instance. + `WithDBSettingsBase` instance and the generation strategy is an + instance of `GenerationStrategy`. Args: - generation_strategy: Generation strategy to save in DB. + generation_strategy: GenerationStrategyInterface to update in DB. + For now, only instances of GenerationStrategy will be updated. + Otherwise, this function is a no-op. Returns: bool: Whether the generation strategy was saved. """ if self.db_settings_set and generation_strategy is not None: - _save_generation_strategy_to_db_if_possible( - generation_strategy=generation_strategy, - encoder=self.db_settings.encoder, - decoder=self.db_settings.decoder, - suppress_all_errors=self._suppress_all_errors, - ) - return True + # only local GenerationStrategies should need to be saved to + # the database because only they make changes locally + if isinstance(generation_strategy, GenerationStrategy): + _save_generation_strategy_to_db_if_possible( + generation_strategy=generation_strategy, + encoder=self.db_settings.encoder, + decoder=self.db_settings.decoder, + suppress_all_errors=self._suppress_all_errors, + ) + return True return False def _update_generation_strategy_in_db_if_possible( self, - generation_strategy: GenerationStrategy, + generation_strategy: GenerationStrategyInterface, new_generator_runs: List[GeneratorRun], reduce_state_generator_runs: bool = False, ) -> bool: """Updates the given generation strategy with new generator runs (and with new current generation step if applicable) if DB settings are set - on this `WithDBSettingsBase` instance. + on this `WithDBSettingsBase` instance and the generation strategy is an + instance of `GenerationStrategy`. Args: - generation_strategy: Generation strategy to update in DB. + generation_strategy: GenerationStrategyInterface to update in DB. + For now, only instances of GenerationStrategy will be updated. + Otherwise, this function is a no-op. new_generator_runs: New generator runs of this generation strategy since its last save. @@ -427,15 +437,18 @@ def _update_generation_strategy_in_db_if_possible( bool: Whether the experiment was saved. """ if self.db_settings_set: - _update_generation_strategy_in_db_if_possible( - generation_strategy=generation_strategy, - new_generator_runs=new_generator_runs, - encoder=self.db_settings.encoder, - decoder=self.db_settings.decoder, - suppress_all_errors=self._suppress_all_errors, - reduce_state_generator_runs=reduce_state_generator_runs, - ) - return True + # only local GenerationStrategies should need to be saved to + # the database because only they make changes locally + if isinstance(generation_strategy, GenerationStrategy): + _update_generation_strategy_in_db_if_possible( + generation_strategy=generation_strategy, + new_generator_runs=new_generator_runs, + encoder=self.db_settings.encoder, + decoder=self.db_settings.decoder, + suppress_all_errors=self._suppress_all_errors, + reduce_state_generator_runs=reduce_state_generator_runs, + ) + return True return False def _update_experiment_properties_in_db( diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 537b39912de..0e1e37cace0 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -31,6 +31,7 @@ from ax.core.batch_trial import AbandonedArm, BatchTrial from ax.core.data import Data from ax.core.experiment import DataType, Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.map_data import MapData, MapKeyInfo from ax.core.map_metric import MapMetric @@ -2209,3 +2210,22 @@ class CustomTestMetric(Metric): def __init__(self, name: str, test_attribute: str) -> None: self.test_attribute = test_attribute super().__init__(name=name) + + +class SpecialGenerationStrategy(GenerationStrategyInterface): + """A subclass of `GenerationStrategyInterface` to be used + for testing how methods respond to subtypes other than + `GenerationStrategy`.""" + + def __init__(self) -> None: + self._name = "special" + self._generator_runs: List[GeneratorRun] = [] + + def gen_for_multiple_trials_with_multiple_models( + self, + experiment: Experiment, + num_generator_runs: int, + data: Optional[Data] = None, + n: int = 1, + ) -> List[List[GeneratorRun]]: + return []