Skip to content

Commit

Permalink
Make sure random seed persists beyond storage
Browse files Browse the repository at this point in the history
Summary:
When the random seed is not specified in `model_kwargs`, it is set to a fixed value in `__init__`. If we do not store this generated `seed` and reload the experiment / GS, we end up continuing the random generation using a different seed. Storing `seed` in `model_state` will ensure it is stored (in last GR) and reused when the GS is reloaded.

Model state is extracted from last GR in `GS._fit_current_model`: https://www.internalfb.com/code/fbsource/[4d9fa225216d]/fbcode/ax/modelbridge/generation_strategy.py?lines=856
This gets passed down to `ModelSpec.fit` as `**model_kwargs`, which takes precedence over `ModelSpec.model_kwargs`: https://www.internalfb.com/code/fbsource/[4d9fa225216d]/fbcode/ax/modelbridge/model_spec.py?lines=131, which will ensure any `"seed": None` kwarg will get overwritten by the generated seed from last GR.

Differential Revision: D61479553
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Aug 19, 2024
1 parent 885288e commit fc90c42
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ax/models/random/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _get_state(self) -> dict[str, Any]:
state = super()._get_state()
if not self.deduplicate:
return state
state.update({"generated_points": self.generated_points})
state.update({"seed": self.seed, "generated_points": self.generated_points})
return state

def _gen_unconstrained(
Expand Down
6 changes: 6 additions & 0 deletions ax/models/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def test_seed(self) -> None:
# With no seed.
self.assertIsInstance(self.random_model.seed, int)

def test_state(self) -> None:
for model in (self.random_model, RandomModel(seed=5)):
state = model._get_state()
self.assertEqual(state["seed"], model.seed)
self.assertEqual(state["generated_points"], model.generated_points)

def test_RandomModelGenSamples(self) -> None:
with self.assertRaises(NotImplementedError):
self.random_model._gen_samples(n=1, tunable_d=1)
Expand Down
17 changes: 15 additions & 2 deletions ax/models/tests/test_sobol.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@ def test_SobolGeneratorAllTunable(self) -> None:
self.assertTrue(np.all(generated_points >= np_bounds[:, 0]))
self.assertTrue(np.all(generated_points <= np_bounds[:, 1]))
self.assertTrue(np.all(weights == 1.0))
self.assertEqual(generator._get_state().get("init_position"), 3)
state = generator._get_state()
self.assertEqual(state.get("init_position"), 3)
self.assertEqual(state.get("seed"), generator.seed)
self.assertTrue(
np.array_equal(
state.get("generated_points"),
generator.generated_points,
)
)

def test_SobolGeneratorFixedSpace(self) -> None:
generator = SobolGenerator(seed=0, deduplicate=False)
Expand Down Expand Up @@ -308,4 +316,9 @@ def test_SobolGeneratorDedupe(self) -> None:
rounding_func=lambda x: x,
)
self.assertEqual(len(generated_points), 1)
self.assertIsNotNone(generator._get_state().get("generated_points"))
self.assertTrue(
np.array_equal(
generator._get_state().get("generated_points"),
generator.generated_points,
)
)

0 comments on commit fc90c42

Please sign in to comment.