Skip to content

Commit

Permalink
Merge f84282b into e51d48b
Browse files Browse the repository at this point in the history
  • Loading branch information
mgarrard authored Nov 15, 2023
2 parents e51d48b + f84282b commit 12a20d2
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 17 deletions.
2 changes: 0 additions & 2 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,6 @@ def _unset_non_persistent_state_fields(self) -> None:
self._model = None
for s in self._steps:
s._model_spec_to_gen_from = None
# TODO: @mgarrard remove once re-enabled criterion storage
s._transition_criteria = []

def __repr__(self) -> str:
"""String representation of this generation strategy."""
Expand Down
1 change: 0 additions & 1 deletion ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,6 @@ def test_save_and_load_generation_strategy(self) -> None:
)
second_client = AxClient(db_settings=db_settings)
second_client.load_experiment_from_database("unique_test_experiment")
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(second_client.generation_strategy, generation_strategy)

@patch(
Expand Down
1 change: 0 additions & 1 deletion ax/service/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,6 @@ def test_sqa_storage(self) -> None:
# Check that experiment and GS were saved.
exp, gs = scheduler._load_experiment_and_generation_strategy(experiment.name)
self.assertEqual(exp, experiment)
self.two_sobol_steps_GS._unset_non_persistent_state_fields()
self.assertEqual(gs, self.two_sobol_steps_GS)
scheduler.run_all_trials()
# Check that experiment and GS were saved and test reloading with reduced state.
Expand Down
9 changes: 1 addition & 8 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,14 +728,7 @@ def generation_step_from_json(
if gen_kwargs
else None,
index=generation_step_json.pop("index", -1),
should_deduplicate=generation_step_json.pop("should_deduplicate")
if "should_deduplicate" in generation_step_json
else False,
)
generation_step._transition_criteria = transition_criteria_from_json(
generation_step_json.pop("transition_criteria")
if "transition_criteria" in generation_step_json.keys()
else None
should_deduplicate=generation_step_json.pop("should_deduplicate", False),
)
return generation_step

Expand Down
12 changes: 8 additions & 4 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import numpy as np
import torch
from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.core.metric import Metric
from ax.core.runner import Runner
from ax.exceptions.core import AxStorageWarning
Expand Down Expand Up @@ -328,8 +327,14 @@ def test_EncodeDecode(self) -> None:
converted_object = converted_object.state_dict()
if isinstance(original_object, GenerationStrategy):
original_object._unset_non_persistent_state_fields()
if isinstance(original_object, BenchmarkMethod):
original_object.generation_strategy._unset_non_persistent_state_fields()
# for the test, completion criterion are set post init
# and therefore do not become transition critirion, unset
# for this specific test only
if "with_completion_criteria" in fake_func.keywords:
for step in original_object._steps:
step._transition_criteria = None
for step in converted_object._steps:
step._transition_criteria = None
try:
self.assertEqual(
original_object,
Expand Down Expand Up @@ -402,7 +407,6 @@ def test_DecodeGenerationStrategy(self) -> None:
decoder_registry=CORE_DECODER_REGISTRY,
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
)
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(generation_strategy, new_generation_strategy)
self.assertGreater(len(new_generation_strategy._steps), 0)
self.assertIsInstance(new_generation_strategy._steps[0].model, Models)
Expand Down
1 change: 0 additions & 1 deletion ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,6 @@ def test_EncodeDecodeGenerationStrategy(self) -> None:
# pyre-fixme[6]: For 1st param expected `int` but got `Optional[int]`.
gs_id=generation_strategy._db_id
)
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(generation_strategy, new_generation_strategy)
self.assertIsNone(generation_strategy._experiment)

Expand Down

0 comments on commit 12a20d2

Please sign in to comment.