Skip to content

Commit

Permalink
Accept any GSInterface for database updates
Browse files Browse the repository at this point in the history
Summary: In preparation for D51307866, accept and handle any tpye of GSInterface in save and update methods.  Only standard GSs will be saved though

Reviewed By: lena-kashtelyan

Differential Revision: D51677487
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Nov 30, 2023
1 parent 7c7b32d commit 54359fd
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 26 deletions.
3 changes: 2 additions & 1 deletion ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import DataType, Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.map_data import MapData
from ax.core.map_metric import MapMetric
Expand Down Expand Up @@ -1730,7 +1731,7 @@ def _set_generation_strategy(

def _save_generation_strategy_to_db_if_possible(
self,
generation_strategy: Optional[GenerationStrategy] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
suppress_all_errors: bool = False,
) -> bool:
return super()._save_generation_strategy_to_db_if_possible(
Expand Down
23 changes: 22 additions & 1 deletion ax/service/tests/test_with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
)
from ax.storage.sqa_store.structs import DBSettings
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_experiment, get_generator_run
from ax.utils.testing.core_stubs import (
get_experiment,
get_generator_run,
SpecialGenerationStrategy,
)
from ax.utils.testing.modeling_stubs import get_generation_strategy


Expand Down Expand Up @@ -127,6 +131,13 @@ def test_save_generation_strategy(self) -> None:
self.assertIsNotNone(loaded_gs)
self.assertEqual(loaded_gs.name, generation_strategy.name)

def test_save_non_standard_generation_strategy(self) -> None:
generation_strategy = SpecialGenerationStrategy()
saved = self.with_db_settings._save_generation_strategy_to_db_if_possible(
generation_strategy
)
self.assertFalse(saved)

def test_save_load_experiment_and_generation_strategy(self) -> None:
experiment, generation_strategy = self.init_experiment_and_generation_strategy(
save_generation_strategy=False
Expand Down Expand Up @@ -169,6 +180,16 @@ def test_update_generation_strategy(self) -> None:
self.assertIsNotNone(generator_run.db_id)
self.assertIsNotNone(generator_run.arms[0].db_id)

def test_update_non_standard_generation_strategy(self) -> None:
generation_strategy = SpecialGenerationStrategy()
generator_run = get_generator_run()
saved = self.with_db_settings._update_generation_strategy_in_db_if_possible(
generation_strategy, [generator_run]
)
self.assertFalse(saved)
self.assertIsNone(generator_run.db_id)
self.assertIsNone(generator_run.arms[0].db_id)

@patch(f"{WithDBSettingsBase.__module__}.STORAGE_MINI_BATCH_SIZE", 2)
def test_update_generation_strategy_mini_batches(self) -> None:
_, generation_strategy = self.init_experiment_and_generation_strategy()
Expand Down
61 changes: 37 additions & 24 deletions ax/service/utils/with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ax.core.base_trial import BaseTrial
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.exceptions.core import (
IncompatibleDependencyVersion,
Expand Down Expand Up @@ -155,7 +156,7 @@ def _get_experiment_and_generation_strategy_db_id(
return exp_id, gs_id

def _maybe_save_experiment_and_generation_strategy(
self, experiment: Experiment, generation_strategy: GenerationStrategy
self, experiment: Experiment, generation_strategy: GenerationStrategyInterface
) -> Tuple[bool, bool]:
"""If DB settings are set on this `WithDBSettingsBase` instance, checks
whether given experiment and generation strategy are already saved and
Expand Down Expand Up @@ -304,7 +305,7 @@ def _save_or_update_trials_and_generation_strategy_if_possible(
self,
experiment: Experiment,
trials: List[BaseTrial],
generation_strategy: GenerationStrategy,
generation_strategy: GenerationStrategyInterface,
new_generator_runs: List[GeneratorRun],
reduce_state_generator_runs: bool = False,
) -> None:
Expand Down Expand Up @@ -386,56 +387,68 @@ def _save_or_update_trials_in_db_if_possible(

def _save_generation_strategy_to_db_if_possible(
self,
generation_strategy: Optional[GenerationStrategy] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
suppress_all_errors: bool = False,
) -> bool:
"""Saves given generation strategy if DB settings are set on this
`WithDBSettingsBase` instance.
`WithDBSettingsBase` instance and the generation strategy is an
instance of `GenerationStrategy`.
Args:
generation_strategy: Generation strategy to save in DB.
generation_strategy: GenerationStrategyInterface to update in DB.
For now, only instances of GenerationStrategy will be updated.
Otherwise, this function is a no-op.
Returns:
bool: Whether the generation strategy was saved.
"""
if self.db_settings_set and generation_strategy is not None:
_save_generation_strategy_to_db_if_possible(
generation_strategy=generation_strategy,
encoder=self.db_settings.encoder,
decoder=self.db_settings.decoder,
suppress_all_errors=self._suppress_all_errors,
)
return True
# only local GenerationStrategies should need to be saved to
# the database because only they make changes locally
if isinstance(generation_strategy, GenerationStrategy):
_save_generation_strategy_to_db_if_possible(
generation_strategy=generation_strategy,
encoder=self.db_settings.encoder,
decoder=self.db_settings.decoder,
suppress_all_errors=self._suppress_all_errors,
)
return True
return False

def _update_generation_strategy_in_db_if_possible(
self,
generation_strategy: GenerationStrategy,
generation_strategy: GenerationStrategyInterface,
new_generator_runs: List[GeneratorRun],
reduce_state_generator_runs: bool = False,
) -> bool:
"""Updates the given generation strategy with new generator runs (and with
new current generation step if applicable) if DB settings are set
on this `WithDBSettingsBase` instance.
on this `WithDBSettingsBase` instance and the generation strategy is an
instance of `GenerationStrategy`.
Args:
generation_strategy: Generation strategy to update in DB.
generation_strategy: GenerationStrategyInterface to update in DB.
For now, only instances of GenerationStrategy will be updated.
Otherwise, this function is a no-op.
new_generator_runs: New generator runs of this generation strategy
since its last save.
Returns:
bool: Whether the experiment was saved.
"""
if self.db_settings_set:
_update_generation_strategy_in_db_if_possible(
generation_strategy=generation_strategy,
new_generator_runs=new_generator_runs,
encoder=self.db_settings.encoder,
decoder=self.db_settings.decoder,
suppress_all_errors=self._suppress_all_errors,
reduce_state_generator_runs=reduce_state_generator_runs,
)
return True
# only local GenerationStrategies should need to be saved to
# the database because only they make changes locally
if isinstance(generation_strategy, GenerationStrategy):
_update_generation_strategy_in_db_if_possible(
generation_strategy=generation_strategy,
new_generator_runs=new_generator_runs,
encoder=self.db_settings.encoder,
decoder=self.db_settings.decoder,
suppress_all_errors=self._suppress_all_errors,
reduce_state_generator_runs=reduce_state_generator_runs,
)
return True
return False

def _update_experiment_properties_in_db(
Expand Down
20 changes: 20 additions & 0 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ax.core.batch_trial import AbandonedArm, BatchTrial
from ax.core.data import Data
from ax.core.experiment import DataType, Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.map_data import MapData, MapKeyInfo
from ax.core.map_metric import MapMetric
Expand Down Expand Up @@ -2209,3 +2210,22 @@ class CustomTestMetric(Metric):
def __init__(self, name: str, test_attribute: str) -> None:
self.test_attribute = test_attribute
super().__init__(name=name)


class SpecialGenerationStrategy(GenerationStrategyInterface):
"""A subclass of `GenerationStrategyInterface` to be used
for testing how methods respond to subtypes other than
`GenerationStrategy`."""

def __init__(self) -> None:
self._name = "special"
self._generator_runs: List[GeneratorRun] = []

def gen_for_multiple_trials_with_multiple_models(
self,
experiment: Experiment,
num_generator_runs: int,
data: Optional[Data] = None,
n: int = 1,
) -> List[List[GeneratorRun]]:
return []

0 comments on commit 54359fd

Please sign in to comment.