Skip to content

Commit

Permalink
Light cleanup of GenerationStrategyInterface (facebook#2256)
Browse files Browse the repository at this point in the history
Summary:

Key changes: 
1. bring back to GS aspects of GSI that are not needed there and are polluting the interface,
2. make `_name` a required attribute of both GSI and GS. 

Re: 2), having it be set during the first call to `name` property was causing weird bugs in equality checks, there two GSs looked like they were equal but they weren't at the time of the initial equality check (one had the `_name` set because its `name` prop was called, and another did not yet). They would become equal during the call to `__repr__` that occurred in reporting their inequality as an error (!), because `__repr__` would call `GS.name`, which would result in `GS._name` getting set. Weird stuff!

Reviewed By: danielcohenlive, mgarrard

Differential Revision: D51441575
  • Loading branch information
Lena Kashtelyan authored and facebook-github-bot committed Mar 7, 2024
1 parent d0df866 commit 1290067
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 45 deletions.
83 changes: 42 additions & 41 deletions ax/core/generation_strategy_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,46 +5,59 @@

# pyre-strict

from __future__ import annotations

from abc import ABC, abstractmethod

from typing import List, Optional

from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.exceptions.core import UnsupportedError
from ax.utils.common.base import Base
from ax.utils.common.typeutils import not_none


class GenerationStrategyInterface(ABC, Base):
_name: Optional[str]
# All generator runs created through this generation strategy, in chronological
# order.
_generator_runs: List[GeneratorRun]
"""Interface for all generation strategies: standard Ax
``GenerationStrategy``, as well as non-standard (e.g. remote, external)
generation strategies.
NOTE: Currently in Beta; please do not use without discussion with the Ax
developers.
"""

_name: str
# Experiment, for which this generation strategy has generated trials, if
# it exists.
_experiment: Optional[Experiment] = None

def __init__(self, name: str) -> None:
self._name = name

@abstractmethod
def gen_for_multiple_trials_with_multiple_models(
self,
experiment: Experiment,
num_generator_runs: int,
data: Optional[Data] = None,
# TODO[drfreund, danielcohennyc, mgarrard]: Update the format of the arguments
# below as we find the right one.
num_generator_runs: int = 1,
n: int = 1,
) -> List[List[GeneratorRun]]:
"""Produce GeneratorRuns for multiple trials at once with the possibility of
ensembling, or using multiple models per trial, getting multiple
GeneratorRuns per trial.
"""Produce ``GeneratorRun``-s for multiple trials at once with the possibility
of joining ``GeneratorRun``-s from multiple models into one ``BatchTrial``.
Args:
experiment: Experiment, for which the generation strategy is producing
a new generator run in the course of `gen`, and to which that
experiment: ``Experiment``, for which the generation strategy is producing
a new generator run in the course of ``gen``, and to which that
generator run will be added as trial(s). Information stored on the
experiment (e.g., trial statuses) is used to determine which model
will be used to produce the generator run returned from this method.
data: Optional data to be passed to the underlying model's `gen`, which
data: Optional data to be passed to the underlying model's ``gen``, which
is called within this method and actually produces the resulting
generator run. By default, data is all data on the `experiment`.
generator run. By default, data is all data on the ``experiment``.
n: Integer representing how many trials should be in the generator run
produced by this method. NOTE: Some underlying models may ignore
the ``n`` and produce a model-determined number of arms. In that
Expand All @@ -55,27 +68,24 @@ def gen_for_multiple_trials_with_multiple_models(
resuggesting points that are currently being evaluated.
Returns:
A list of lists of lists generator runs. Each outer list represents
a trial being suggested and each inner list represents a generator
run for that trial.
A list of lists of ``GeneratorRun``-s. Each outer list item represents
a ``(Batch)Trial`` being suggested, with a list of ``GeneratorRun``-s for
that trial.
"""
# When implementing your subclass' override for this method, don't forget
# to consider using "pending points", corresponding to arms in trials that
# are currently running / being evaluated/
pass

@abstractmethod
def clone_reset(self) -> GenerationStrategyInterface:
"""Returns a clone of this generation strategy with all state reset."""
pass

@property
def name(self) -> str:
"""Name of this generation strategy. Defaults to a combination of model
names provided in generation steps.
"""
if self._name is not None:
return not_none(self._name)

self._name = f"GenerationStrategy {self.db_id}"
return not_none(self._name)

@name.setter
def name(self, name: str) -> None:
"""Set generation strategy name."""
self._name = name
"""Name of this generation strategy."""
return self._name

@property
def experiment(self) -> Experiment:
Expand All @@ -91,20 +101,11 @@ def experiment(self, experiment: Experiment) -> None:
experiment passed in is the same as the one saved and log an information
statement if its not. Set the new experiment on this generation strategy.
"""
if self._experiment is None or experiment._name == self.experiment._name:
self._experiment = experiment
else:
raise ValueError(
if self._experiment is not None and experiment._name != self.experiment._name:
raise UnsupportedError(
"This generation strategy has been used for experiment "
f"{self.experiment._name} so far; cannot reset experiment"
f" to {experiment._name}. If this is a new optimization, "
f" to {experiment._name}. If this is a new experiment, "
"a new generation strategy should be created instead."
)

@property
def last_generator_run(self) -> Optional[GeneratorRun]:
"""Latest generator run produced by this generation strategy.
Returns None if no generator runs have been produced yet.
"""
# Used to restore current model when decoding a serialized GS.
return self._generator_runs[-1] if self._generator_runs else None
self._experiment = experiment
27 changes: 24 additions & 3 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ class GenerationStrategy(GenerationStrategyInterface):
strategy's name will be names of its nodes' models joined with '+'.
"""

_name: Optional[str]
_nodes: List[GenerationNode]
_curr: GenerationNode # Current node in the strategy.
# Whether all models in this GS are in Models registry enum.
Expand All @@ -123,7 +122,6 @@ def __init__(
name: Optional[str] = None,
nodes: Optional[List[GenerationNode]] = None,
) -> None:
self._name = name
self._uses_registered_models = True
self._generator_runs = []

Expand Down Expand Up @@ -158,7 +156,7 @@ def __init__(
self._seen_trial_indices_by_status = None
# Set name to an explicit value ahead of time to avoid
# adding properties during equality checks
self._name = self.name
super().__init__(name=name or self._make_default_name())

@property
def is_node_based(self) -> bool:
Expand Down Expand Up @@ -280,6 +278,14 @@ def experiment(self, experiment: Experiment) -> None:
"a new generation strategy should be created instead."
)

@property
def last_generator_run(self) -> Optional[GeneratorRun]:
"""Latest generator run produced by this generation strategy.
Returns None if no generator runs have been produced yet.
"""
# Used to restore current model when decoding a serialized GS.
return self._generator_runs[-1] if self._generator_runs else None

@property
def uses_non_registered_models(self) -> bool:
"""Whether this generation strategy involves models that are not
Expand Down Expand Up @@ -613,6 +619,21 @@ def _step_repr(self, step_str_rep: str) -> str:
step_str_rep += "])"
return step_str_rep

def _make_default_name(self) -> str:
"""Make a default name for this generation strategy; used when no name is passed
to the constructor. Makes the name from model keys on generation nodes, set on
this generation strategy, and should only be called once the nodes are set.
"""
if not self._nodes:
raise UnsupportedError(
"Cannot make a default name for a generation strategy with no nodes "
"set yet."
)
factory_names = (node.model_spec_to_gen_from.model_key for node in self._nodes)
# Trim the "get_" beginning of the factory function if it's there.
factory_names = (n[4:] if n[:4] == "get_" else n for n in factory_names)
return "+".join(factory_names)

def __repr__(self) -> str:
"""String representation of this generation strategy."""
gs_str = f"GenerationStrategy(name='{self.name}', "
Expand Down
12 changes: 11 additions & 1 deletion ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,17 @@ def test_unique_step_names(self) -> None:
self.assertEqual(gs._steps[1].node_name, "GenerationStep_1")

def test_name(self) -> None:
self.sobol_GS.name = "SomeGSName"
self.assertEqual(self.sobol_GS._name, "Sobol")
self.assertEqual(
GenerationStrategy(
steps=[
GenerationStep(model=Models.SOBOL, num_trials=5),
GenerationStep(model=Models.GPEI, num_trials=-1),
],
).name,
"Sobol+GPEI",
)
self.sobol_GS._name = "SomeGSName"
self.assertEqual(self.sobol_GS.name, "SomeGSName")

def test_validation(self) -> None:
Expand Down
7 changes: 7 additions & 0 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# pyre-strict


from __future__ import annotations

from collections import OrderedDict
from datetime import datetime, timedelta

Expand Down Expand Up @@ -2311,3 +2313,8 @@ def gen_for_multiple_trials_with_multiple_models(
n: int = 1,
) -> List[List[GeneratorRun]]:
return []

def clone_reset(self) -> SpecialGenerationStrategy:
clone = SpecialGenerationStrategy()
clone._name = self._name
return clone

0 comments on commit 1290067

Please sign in to comment.