diff --git a/ax/models/random/base.py b/ax/models/random/base.py index ec0578b6125..8f4941bf439 100644 --- a/ax/models/random/base.py +++ b/ax/models/random/base.py @@ -182,7 +182,7 @@ def _get_state(self) -> dict[str, Any]: state = super()._get_state() if not self.deduplicate: return state - state.update({"generated_points": self.generated_points}) + state.update({"seed": self.seed, "generated_points": self.generated_points}) return state def _gen_unconstrained( diff --git a/ax/models/tests/test_random.py b/ax/models/tests/test_random.py index 7771c21434a..a2e60922b26 100644 --- a/ax/models/tests/test_random.py +++ b/ax/models/tests/test_random.py @@ -25,6 +25,12 @@ def test_seed(self) -> None: # With no seed. self.assertIsInstance(self.random_model.seed, int) + def test_state(self) -> None: + for model in (self.random_model, RandomModel(seed=5)): + state = model._get_state() + self.assertEqual(state["seed"], model.seed) + self.assertEqual(state["generated_points"], model.generated_points) + def test_RandomModelGenSamples(self) -> None: with self.assertRaises(NotImplementedError): self.random_model._gen_samples(n=1, tunable_d=1) diff --git a/ax/models/tests/test_sobol.py b/ax/models/tests/test_sobol.py index 51b670d5f58..57030986631 100644 --- a/ax/models/tests/test_sobol.py +++ b/ax/models/tests/test_sobol.py @@ -37,7 +37,15 @@ def test_SobolGeneratorAllTunable(self) -> None: self.assertTrue(np.all(generated_points >= np_bounds[:, 0])) self.assertTrue(np.all(generated_points <= np_bounds[:, 1])) self.assertTrue(np.all(weights == 1.0)) - self.assertEqual(generator._get_state().get("init_position"), 3) + state = generator._get_state() + self.assertEqual(state.get("init_position"), 3) + self.assertEqual(state.get("seed"), generator.seed) + self.assertTrue( + np.array_equal( + state.get("generated_points"), + generator.generated_points, + ) + ) def test_SobolGeneratorFixedSpace(self) -> None: generator = SobolGenerator(seed=0, deduplicate=False) @@ -308,4 +316,9 @@ def test_SobolGeneratorDedupe(self) -> None: rounding_func=lambda x: x, ) self.assertEqual(len(generated_points), 1) - self.assertIsNotNone(generator._get_state().get("generated_points")) + self.assertTrue( + np.array_equal( + generator._get_state().get("generated_points"), + generator.generated_points, + ) + )