Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure random seed persists beyond storage #2671

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from unittest import mock
from unittest.mock import MagicMock, patch

import numpy as np
from ax.core.arm import Arm
from ax.core.base_trial import TrialStatus
from ax.core.experiment import Experiment
Expand Down Expand Up @@ -448,6 +449,7 @@ def test_do_not_enforce_min_observations(self) -> None:
def test_sobol_GPEI_strategy(self) -> None:
exp = get_branin_experiment()
self.assertEqual(self.sobol_GPEI_GS.name, "Sobol+GPEI")
expected_seed = None
for i in range(7):
g = self.sobol_GPEI_GS.gen(exp)
exp.new_trial(generator_run=g).run()
Expand All @@ -470,7 +472,7 @@ def test_sobol_GPEI_strategy(self) -> None:
self.assertEqual(
mkw,
{
"seed": None,
"seed": expected_seed,
"deduplicate": True,
"init_position": i,
"scramble": True,
Expand All @@ -491,14 +493,17 @@ def test_sobol_GPEI_strategy(self) -> None:
"fit_on_init": True,
},
)
ms = g._model_state_after_gen
self.assertIsNotNone(ms)
# Generated points are randomized, so just checking that they are there.
self.assertIn("generated_points", ms)
# Remove the randomized generated points to compare the rest.
ms = ms.copy()
del ms["generated_points"]
self.assertEqual(ms, {"init_position": i + 1})
ms = not_none(g._model_state_after_gen).copy()
# Compare the model state to Sobol state.
sobol_model = not_none(self.sobol_GPEI_GS.model).model
self.assertTrue(
np.array_equal(
ms.pop("generated_points"), sobol_model.generated_points
)
)
# Replace expected seed with the one generated in __init__.
expected_seed = sobol_model.seed
self.assertEqual(ms, {"init_position": i + 1, "seed": expected_seed})
# Check completeness error message when GS should be done.
with self.assertRaises(GenerationStrategyCompleted):
g = self.sobol_GPEI_GS.gen(exp)
Expand Down Expand Up @@ -1212,6 +1217,7 @@ def test_gs_with_generation_nodes(self) -> None:
"Simple test of a SOBOL + GPEI GenerationStrategy composed of GenerationNodes"
exp = get_branin_experiment()
self.assertEqual(self.sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes")
expected_seed = None

for i in range(7):
g = self.sobol_GPEI_GS_nodes.gen(exp)
Expand All @@ -1235,7 +1241,7 @@ def test_gs_with_generation_nodes(self) -> None:
self.assertEqual(
mkw,
{
"seed": None,
"seed": expected_seed,
"deduplicate": True,
"init_position": i,
"scramble": True,
Expand All @@ -1256,14 +1262,17 @@ def test_gs_with_generation_nodes(self) -> None:
"fit_on_init": True,
},
)
ms = g._model_state_after_gen
self.assertIsNotNone(ms)
# Generated points are randomized, so just checking that they are there.
self.assertIn("generated_points", ms)
# Remove the randomized generated points to compare the rest.
ms = ms.copy()
del ms["generated_points"]
self.assertEqual(ms, {"init_position": i + 1})
ms = not_none(g._model_state_after_gen).copy()
# Compare the model state to Sobol state.
sobol_model = not_none(self.sobol_GPEI_GS_nodes.model).model
self.assertTrue(
np.array_equal(
ms.pop("generated_points"), sobol_model.generated_points
)
)
# Replace expected seed with the one generated in __init__.
expected_seed = sobol_model.seed
self.assertEqual(ms, {"init_position": i + 1, "seed": expected_seed})

def test_clone_reset_nodes(self) -> None:
"""Test that node-based generation strategy is appropriately reset
Expand Down
4 changes: 1 addition & 3 deletions ax/models/random/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,7 @@ def gen(
@copy_doc(Model._get_state)
def _get_state(self) -> dict[str, Any]:
state = super()._get_state()
if not self.deduplicate:
return state
state.update({"generated_points": self.generated_points})
state.update({"seed": self.seed, "generated_points": self.generated_points})
return state

def _gen_unconstrained(
Expand Down
6 changes: 6 additions & 0 deletions ax/models/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def test_seed(self) -> None:
# With no seed.
self.assertIsInstance(self.random_model.seed, int)

def test_state(self) -> None:
for model in (self.random_model, RandomModel(seed=5)):
state = model._get_state()
self.assertEqual(state["seed"], model.seed)
self.assertEqual(state["generated_points"], model.generated_points)

def test_RandomModelGenSamples(self) -> None:
with self.assertRaises(NotImplementedError):
self.random_model._gen_samples(n=1, tunable_d=1)
Expand Down
17 changes: 15 additions & 2 deletions ax/models/tests/test_sobol.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@ def test_SobolGeneratorAllTunable(self) -> None:
self.assertTrue(np.all(generated_points >= np_bounds[:, 0]))
self.assertTrue(np.all(generated_points <= np_bounds[:, 1]))
self.assertTrue(np.all(weights == 1.0))
self.assertEqual(generator._get_state().get("init_position"), 3)
state = generator._get_state()
self.assertEqual(state.get("init_position"), 3)
self.assertEqual(state.get("seed"), generator.seed)
self.assertTrue(
np.array_equal(
state.get("generated_points"),
generator.generated_points,
)
)

def test_SobolGeneratorFixedSpace(self) -> None:
generator = SobolGenerator(seed=0, deduplicate=False)
Expand Down Expand Up @@ -308,4 +316,9 @@ def test_SobolGeneratorDedupe(self) -> None:
rounding_func=lambda x: x,
)
self.assertEqual(len(generated_points), 1)
self.assertIsNotNone(generator._get_state().get("generated_points"))
self.assertTrue(
np.array_equal(
generator._get_state().get("generated_points"),
generator.generated_points,
)
)
3 changes: 2 additions & 1 deletion ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,10 @@ def test_ExperimentSaveAndLoadReducedState(
bkw = gr._bridge_kwargs
self.assertIsNotNone(bkw)
self.assertEqual(len(bkw), 9)
# This has seed, generated points and init position.
ms = gr._model_state_after_gen
self.assertIsNotNone(ms)
self.assertEqual(len(ms), 2)
self.assertEqual(len(ms), 3)
gm = gr._gen_metadata
self.assertIsNotNone(gm)
self.assertEqual(len(gm), 0)
Expand Down