Skip to content

Commit

Permalink
Make sure we add newly gen-d points to pending between calls to candi…
Browse files Browse the repository at this point in the history
…date generation during `_gen_multiple` (#1281)

Summary: Pull Request resolved: #1281

Reviewed By: Balandat

Differential Revision: D41484720

fbshipit-source-id: fb571c3566edd53e41b3c96ce0c6eda5a771ebbe
  • Loading branch information
Lena Kashtelyan authored and facebook-github-bot committed Dec 20, 2022
1 parent 79221c6 commit 32fbe65
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 72 deletions.
83 changes: 70 additions & 13 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

from ax.core.data import Data # Perhaps need to use `AbstractDataFrameData`?
from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError

from ax.exceptions.generation_strategy import GenerationStrategyRepeatedPoints
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.completion_criterion import CompletionCriterion
from ax.modelbridge.cross_validation import BestModelSelector, CVDiagnostics, CVResult
Expand All @@ -32,6 +35,12 @@
`GenerationNode._pick_fitted_model_to_gen_from` (usually called
by `GenerationNode.gen`.
"""
MAX_GEN_DRAWS = 5
MAX_GEN_DRAWS_EXCEEDED_MESSAGE = (
f"GenerationStrategy exceeded `MAX_GEN_DRAWS` of {MAX_GEN_DRAWS} while trying to "
"generate a unique parameterization. This indicates that the search space has "
"likely been fully explored, or that the sweep has converged."
)


class GenerationNode:
Expand All @@ -40,12 +49,14 @@ class GenerationNode:
"""

model_specs: List[ModelSpec]
should_deduplicate: bool
_model_spec_to_gen_from: Optional[ModelSpec] = None

def __init__(
self,
model_specs: List[ModelSpec],
best_model_selector: Optional[BestModelSelector] = None,
should_deduplicate: bool = False,
) -> None:
# While `GenerationNode` only handles a single `ModelSpec` in the `gen`
# and `_pick_fitted_model_to_gen_from` methods, we validate the
Expand All @@ -54,6 +65,7 @@ def __init__(
# method to bypass that validation.
self.model_specs = model_specs
self.best_model_selector = best_model_selector
self.should_deduplicate = should_deduplicate

@property
def model_spec_to_gen_from(self) -> ModelSpec:
Expand Down Expand Up @@ -142,6 +154,9 @@ def gen(
self,
n: Optional[int] = None,
pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
max_gen_draws_for_deduplication: int = MAX_GEN_DRAWS,
arms_by_signature_for_deduplication: Optional[Dict[str, Arm]] = None,
**model_gen_kwargs: Any,
) -> GeneratorRun:
"""Picks a fitted model, from which to generate candidates (via
``self._pick_fitted_model_to_gen_from``) and generates candidates
Expand All @@ -156,24 +171,47 @@ def gen(
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
max_gen_draws_for_deduplication: TODO
model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``;
these override any pre-specified in ``ModelSpec.model_gen_kwargs``.
NOTE: Models must have been fit prior to calling ``gen``.
NOTE: Some underlying models may ignore the ``n`` argument and produce a
model-determined number of arms. In that case this method will also output
a generator run with number of arms (that can differ from ``n``).
"""
model_spec = self.model_spec_to_gen_from
return model_spec.gen(
# If `n` is not specified, ensure that the `None` value does not
# override the one set in `model_spec.model_gen_kwargs`.
n=model_spec.model_gen_kwargs.get("n")
if n is None and model_spec.model_gen_kwargs
else n,
# For `pending_observations`, prefer the input to this function, as
# `pending_observations` are dynamic throughout the experiment and thus
# unlikely to be specified in `model_spec.model_gen_kwargs`.
pending_observations=pending_observations,
)
should_generate_run = True
generator_run = None
n_gen_draws = 0
# Keep generating until each of `generator_run.arms` is not a duplicate
# of a previous arm, if `should_deduplicate is True`
while should_generate_run:
if n_gen_draws > max_gen_draws_for_deduplication:
raise GenerationStrategyRepeatedPoints(MAX_GEN_DRAWS_EXCEEDED_MESSAGE)
generator_run = model_spec.gen(
# If `n` is not specified, ensure that the `None` value does not
# override the one set in `model_spec.model_gen_kwargs`.
n=model_spec.model_gen_kwargs.get("n")
if n is None and model_spec.model_gen_kwargs
else n,
# For `pending_observations`, prefer the input to this function, as
# `pending_observations` are dynamic throughout the experiment and thus
# unlikely to be specified in `model_spec.model_gen_kwargs`.
pending_observations=pending_observations,
**model_gen_kwargs,
)

should_generate_run = (
self.should_deduplicate
and arms_by_signature_for_deduplication
and any(
arm.signature in arms_by_signature_for_deduplication
for arm in generator_run.arms
)
)
n_gen_draws += 1
return not_none(generator_run)

def _pick_fitted_model_to_gen_from(self) -> ModelSpec:
"""Select one model to generate from among the fitted models on this
Expand Down Expand Up @@ -322,7 +360,9 @@ def __post_init__(self) -> None:
model_kwargs=self.model_kwargs,
model_gen_kwargs=self.model_gen_kwargs,
)
super().__init__(model_specs=[model_spec])
super().__init__(
model_specs=[model_spec], should_deduplicate=self.should_deduplicate
)

@property
def model_spec(self) -> ModelSpec:
Expand All @@ -335,3 +375,20 @@ def model_name(self) -> str:
@property
def _unique_id(self) -> str:
return str(self.index)

def gen(
self,
n: Optional[int] = None,
pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
max_gen_draws_for_deduplication: int = MAX_GEN_DRAWS,
arms_by_signature_for_deduplication: Optional[Dict[str, Arm]] = None,
**model_gen_kwargs: Any,
) -> GeneratorRun:
gr = super().gen(
n=n,
pending_observations=pending_observations,
max_gen_draws_for_deduplication=max_gen_draws_for_deduplication,
arms_by_signature_for_deduplication=arms_by_signature_for_deduplication,
)
gr._generation_step_index = self.index
return gr
72 changes: 21 additions & 51 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import pandas as pd
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.data import Data
from ax.core.experiment import Experiment
Expand All @@ -22,11 +21,11 @@
from ax.exceptions.core import DataRequiredError, NoDataError, UserInputError
from ax.exceptions.generation_strategy import (
GenerationStrategyCompleted,
GenerationStrategyRepeatedPoints,
MaxParallelismReachedException,
)
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_node import GenerationStep
from ax.modelbridge.modelbridge_utils import extend_pending_observations
from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase
from ax.utils.common.base import Base
from ax.utils.common.logger import _round_floats_for_logging, get_logger
Expand Down Expand Up @@ -420,7 +419,7 @@ def _gen_multiple(
data: Optional[Data] = None,
n: int = 1,
pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
**kwargs: Any,
**model_gen_kwargs: Any,
) -> List[GeneratorRun]:
"""Produce multiple generator runs at once, to be made into multiple
trials on the experiment.
Expand Down Expand Up @@ -452,6 +451,9 @@ def _gen_multiple(
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
model_gen_kwargs: Keyword arguments that are passed through to
``GenerationStep.gen``, which will pass them through to
``ModelSpec.gen``, which will pass them to ``ModelBridge.gen``.
"""
self.experiment = experiment
self._maybe_move_to_next_step()
Expand All @@ -471,20 +473,16 @@ def _gen_multiple(
)

generator_runs = []
pending_observations = deepcopy(pending_observations) or {}
for _ in range(num_generator_runs):
try:
generator_run = _gen_from_generation_step(
generation_step=self._curr,
input_max_gen_draws=MAX_GEN_DRAWS,
generator_run = self._curr.gen(
n=n,
pending_observations=pending_observations,
model_gen_kwargs=kwargs,
should_deduplicate=self._curr.should_deduplicate,
arms_by_signature=self.experiment.arms_by_signature,
arms_by_signature_for_deduplication=experiment.arms_by_signature,
**model_gen_kwargs,
)
generator_run._generation_step_index = self._curr.index
self._generator_runs.append(generator_run)
generator_runs.append(generator_run)

except DataRequiredError as err:
# Model needs more data, so we log the error and return
# as many generator runs as we were able to produce, unless
Expand All @@ -494,6 +492,17 @@ def _gen_multiple(
logger.debug(f"Model required more data: {err}.") # pragma: no cover
break # pragma: no cover

self._generator_runs.append(generator_run)
generator_runs.append(generator_run)

# Extend the `pending_observation` with newly generated point(s)
# to avoid repeating them.
extend_pending_observations(
experiment=experiment,
pending_observations=pending_observations,
generator_run=generator_run,
)

return generator_runs

# ------------------------- Model selection logic helpers. -------------------------
Expand Down Expand Up @@ -826,42 +835,3 @@ def _register_trial_data_update(self, trial: BaseTrial) -> None:
"Updating completed trials with new data is not yet supported for "
"generation strategies that leverage `model.update` functionality."
)


def _gen_from_generation_step(
input_max_gen_draws: int,
generation_step: GenerationStep,
n: int,
pending_observations: Optional[Dict[str, List[ObservationFeatures]]],
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
model_gen_kwargs: Any,
should_deduplicate: bool,
arms_by_signature: Dict[str, Arm],
) -> GeneratorRun:
"""Produces a ``GeneratorRun`` with ``n`` arms using the provided ``model``. if
``should_deduplicate is True``, these arms are deduplicated against previous arms
using rejection sampling before returning. If more than ``input_max_gen_draws``
samples are generated during deduplication, this function produces a
``GenerationStrategyRepeatedPoints`` exception.
"""
# TODO[drfreund]: Consider moving dedulication to generation step itself.
# NOTE: Might need to revisit the behavior of deduplication when
# generating multi-arm generator runs (to be made into batch trials).
should_generate_run = True
generator_run = None
n_gen_draws = 0
# Keep generating until each of `generator_run.arms` is not a duplicate
# of a previous arm, if `should_deduplicate is True`
while should_generate_run:
if n_gen_draws > input_max_gen_draws:
raise GenerationStrategyRepeatedPoints(MAX_GEN_DRAWS_EXCEEDED_MESSAGE)
generator_run = generation_step.gen(
n=n,
pending_observations=pending_observations,
**model_gen_kwargs,
)
should_generate_run = should_deduplicate and any(
arm.signature in arms_by_signature for arm in generator_run.arms
)
n_gen_draws += 1
return not_none(generator_run)
17 changes: 17 additions & 0 deletions ax/modelbridge/modelbridge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.optimization_config import (
Expand Down Expand Up @@ -766,6 +767,22 @@ def get_pending_observation_features_based_on_trial_status(
return dict(pending_features) if any(x for x in pending_features.values()) else None


def extend_pending_observations(
experiment: Experiment,
pending_observations: Dict[str, List[ObservationFeatures]],
generator_run: GeneratorRun,
) -> None:
"""Extend given pending observations dict (from metric name to observations
that are pending for that metric), with arms in a given generator run.
"""
for m in experiment.metrics:
if m not in pending_observations:
pending_observations[m] = []
pending_observations[m].extend(
ObservationFeatures.from_arm(a) for a in generator_run.arms
)


def get_pareto_frontier_and_configs(
modelbridge: modelbridge_module.torch.TorchModelBridge,
observation_features: List[ObservationFeatures],
Expand Down
Loading

0 comments on commit 32fbe65

Please sign in to comment.